mirror of
				https://github.com/aykhans/slash-e.git
				synced 2025-10-24 22:10:58 +00:00 
			
		
		
		
	refactor: migration auth api to v1
This commit is contained in:
		| @@ -1,4 +1,4 @@ | |||||||
| package server | package v1 | ||||||
| 
 | 
 | ||||||
| import ( | import ( | ||||||
| 	"encoding/json" | 	"encoding/json" | ||||||
| @@ -6,31 +6,48 @@ import ( | |||||||
| 	"net/http" | 	"net/http" | ||||||
| 
 | 
 | ||||||
| 	"github.com/boojack/shortify/api" | 	"github.com/boojack/shortify/api" | ||||||
| 	"github.com/google/uuid" | 	"github.com/boojack/shortify/server/auth" | ||||||
| 
 | 	"github.com/boojack/shortify/store" | ||||||
| 	"github.com/labstack/echo/v4" | 	"github.com/labstack/echo/v4" | ||||||
| 	"golang.org/x/crypto/bcrypt" | 	"golang.org/x/crypto/bcrypt" | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
| func (s *Server) registerAuthRoutes(g *echo.Group) { | var ( | ||||||
|  | 	userIDContextKey = "user-id" | ||||||
|  | ) | ||||||
|  | 
 | ||||||
|  | func getUserIDContextKey() string { | ||||||
|  | 	return userIDContextKey | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | type SignUpRequest struct { | ||||||
|  | 	Username string `json:"username"` | ||||||
|  | 	Password string `json:"password"` | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | type SignInRequest struct { | ||||||
|  | 	Username string `json:"username"` | ||||||
|  | 	Password string `json:"password"` | ||||||
|  | } | ||||||
|  | 
 | ||||||
|  | func (s *APIV1Service) registerAuthRoutes(g *echo.Group, secret string) { | ||||||
| 	g.POST("/auth/signin", func(c echo.Context) error { | 	g.POST("/auth/signin", func(c echo.Context) error { | ||||||
| 		ctx := c.Request().Context() | 		ctx := c.Request().Context() | ||||||
| 		signin := &api.Signin{} | 		signin := &SignInRequest{} | ||||||
| 		if err := json.NewDecoder(c.Request().Body).Decode(signin); err != nil { | 		if err := json.NewDecoder(c.Request().Body).Decode(signin); err != nil { | ||||||
| 			return echo.NewHTTPError(http.StatusBadRequest, "Malformatted signin request").SetInternal(err) | 			return echo.NewHTTPError(http.StatusBadRequest, "Malformatted signin request").SetInternal(err) | ||||||
| 		} | 		} | ||||||
| 
 | 
 | ||||||
| 		userFind := &api.UserFind{ | 		user, err := s.Store.GetUserV1(ctx, &store.FindUser{ | ||||||
| 			Email: &signin.Email, | 			Username: &signin.Username, | ||||||
| 		} | 		}) | ||||||
| 		user, err := s.Store.FindUser(ctx, userFind) |  | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			return echo.NewHTTPError(http.StatusInternalServerError, fmt.Sprintf("Failed to find user by email %s", signin.Email)).SetInternal(err) | 			return echo.NewHTTPError(http.StatusInternalServerError, fmt.Sprintf("Failed to find user by username %s", signin.Username)).SetInternal(err) | ||||||
| 		} | 		} | ||||||
| 		if user == nil { | 		if user == nil { | ||||||
| 			return echo.NewHTTPError(http.StatusUnauthorized, fmt.Sprintf("User not found with email %s", signin.Email)) | 			return echo.NewHTTPError(http.StatusUnauthorized, fmt.Sprintf("User not found with username %s", signin.Username)) | ||||||
| 		} else if user.RowStatus == api.Archived { | 		} else if user.RowStatus == store.Archived { | ||||||
| 			return echo.NewHTTPError(http.StatusForbidden, fmt.Sprintf("User has been archived with email %s", signin.Email)) | 			return echo.NewHTTPError(http.StatusForbidden, fmt.Sprintf("User has been archived with username %s", signin.Username)) | ||||||
| 		} | 		} | ||||||
| 
 | 
 | ||||||
| 		// Compare the stored hashed password, with the hashed version of the password that was received. | 		// Compare the stored hashed password, with the hashed version of the password that was received. | ||||||
| @@ -39,34 +56,28 @@ func (s *Server) registerAuthRoutes(g *echo.Group) { | |||||||
| 			return echo.NewHTTPError(http.StatusUnauthorized, "Incorrect password").SetInternal(err) | 			return echo.NewHTTPError(http.StatusUnauthorized, "Incorrect password").SetInternal(err) | ||||||
| 		} | 		} | ||||||
| 
 | 
 | ||||||
| 		if err = setUserSession(c, user); err != nil { | 		if err := auth.GenerateTokensAndSetCookies(c, user, secret); err != nil { | ||||||
| 			return echo.NewHTTPError(http.StatusInternalServerError, "Failed to set signin session").SetInternal(err) | 			return echo.NewHTTPError(http.StatusInternalServerError, "Failed to generate tokens").SetInternal(err) | ||||||
| 		} | 		} | ||||||
| 		return c.JSON(http.StatusOK, composeResponse(user)) | 		return c.JSON(http.StatusOK, user) | ||||||
| 	}) | 	}) | ||||||
| 
 | 
 | ||||||
| 	g.POST("/auth/signup", func(c echo.Context) error { | 	g.POST("/auth/signup", func(c echo.Context) error { | ||||||
| 		ctx := c.Request().Context() | 		ctx := c.Request().Context() | ||||||
| 		signup := &api.Signup{} | 		signup := &SignUpRequest{} | ||||||
| 		if err := json.NewDecoder(c.Request().Body).Decode(signup); err != nil { | 		if err := json.NewDecoder(c.Request().Body).Decode(signup); err != nil { | ||||||
| 			return echo.NewHTTPError(http.StatusBadRequest, "Malformatted signup request").SetInternal(err) | 			return echo.NewHTTPError(http.StatusBadRequest, "Malformatted signup request").SetInternal(err) | ||||||
| 		} | 		} | ||||||
| 
 | 
 | ||||||
| 		userCreate := &api.UserCreate{ | 		user := &store.User{ | ||||||
| 			Email:       signup.Email, | 			Username: signup.Username, | ||||||
| 			DisplayName: signup.DisplayName, | 			Nickname: signup.Username, | ||||||
| 			Password:    signup.Password, |  | ||||||
| 			OpenID:      genUUID(), |  | ||||||
| 		} | 		} | ||||||
| 		if err := userCreate.Validate(); err != nil { |  | ||||||
| 			return echo.NewHTTPError(http.StatusBadRequest, "Invalid user create format.").SetInternal(err) |  | ||||||
| 		} |  | ||||||
| 
 |  | ||||||
| 		passwordHash, err := bcrypt.GenerateFromPassword([]byte(signup.Password), bcrypt.DefaultCost) | 		passwordHash, err := bcrypt.GenerateFromPassword([]byte(signup.Password), bcrypt.DefaultCost) | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			return echo.NewHTTPError(http.StatusInternalServerError, "Failed to generate password hash").SetInternal(err) | 			return echo.NewHTTPError(http.StatusInternalServerError, "Failed to generate password hash").SetInternal(err) | ||||||
| 		} | 		} | ||||||
| 		userCreate.PasswordHash = string(passwordHash) | 		user.PasswordHash = string(passwordHash) | ||||||
| 
 | 
 | ||||||
| 		existingUsers, err := s.Store.FindUserList(ctx, &api.UserFind{}) | 		existingUsers, err := s.Store.FindUserList(ctx, &api.UserFind{}) | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| @@ -74,35 +85,26 @@ func (s *Server) registerAuthRoutes(g *echo.Group) { | |||||||
| 		} | 		} | ||||||
| 		// The first user to sign up is an admin by default. | 		// The first user to sign up is an admin by default. | ||||||
| 		if len(existingUsers) == 0 { | 		if len(existingUsers) == 0 { | ||||||
| 			userCreate.Role = api.RoleAdmin | 			user.Role = store.RoleAdmin | ||||||
| 		} else { | 		} else { | ||||||
| 			userCreate.Role = api.RoleUser | 			user.Role = store.RoleUser | ||||||
| 		} | 		} | ||||||
| 
 | 
 | ||||||
| 		user, err := s.Store.CreateUser(ctx, userCreate) | 		user, err = s.Store.CreateUserV1(ctx, user) | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| 			return echo.NewHTTPError(http.StatusInternalServerError, "Failed to create user").SetInternal(err) | 			return echo.NewHTTPError(http.StatusInternalServerError, "Failed to create user").SetInternal(err) | ||||||
| 		} | 		} | ||||||
| 
 | 
 | ||||||
| 		err = setUserSession(c, user) | 		if err := auth.GenerateTokensAndSetCookies(c, user, secret); err != nil { | ||||||
| 		if err != nil { | 			return echo.NewHTTPError(http.StatusInternalServerError, "Failed to generate tokens").SetInternal(err) | ||||||
| 			return echo.NewHTTPError(http.StatusInternalServerError, "Failed to set signup session").SetInternal(err) |  | ||||||
| 		} | 		} | ||||||
| 
 | 
 | ||||||
| 		return c.JSON(http.StatusOK, composeResponse(user)) | 		return c.JSON(http.StatusOK, user) | ||||||
| 	}) | 	}) | ||||||
| 
 | 
 | ||||||
| 	g.POST("/auth/logout", func(c echo.Context) error { | 	g.POST("/auth/logout", func(c echo.Context) error { | ||||||
| 		err := removeUserSession(c) | 		auth.RemoveTokensAndCookies(c) | ||||||
| 		if err != nil { |  | ||||||
| 			return echo.NewHTTPError(http.StatusInternalServerError, "Failed to set logout session").SetInternal(err) |  | ||||||
| 		} |  | ||||||
| 
 |  | ||||||
| 		c.Response().WriteHeader(http.StatusOK) | 		c.Response().WriteHeader(http.StatusOK) | ||||||
| 		return nil | 		return nil | ||||||
| 	}) | 	}) | ||||||
| } | } | ||||||
| 
 |  | ||||||
| func genUUID() string { |  | ||||||
| 	return uuid.New().String() |  | ||||||
| } |  | ||||||
| @@ -10,8 +10,8 @@ const ( | |||||||
| 	Archived RowStatus = "ARCHIVED" | 	Archived RowStatus = "ARCHIVED" | ||||||
| ) | ) | ||||||
|  |  | ||||||
| func (e RowStatus) String() string { | func (status RowStatus) String() string { | ||||||
| 	switch e { | 	switch status { | ||||||
| 	case Normal: | 	case Normal: | ||||||
| 		return "NORMAL" | 		return "NORMAL" | ||||||
| 	case Archived: | 	case Archived: | ||||||
|   | |||||||
| @@ -2,8 +2,10 @@ package v1 | |||||||
|  |  | ||||||
| import ( | import ( | ||||||
| 	"fmt" | 	"fmt" | ||||||
|  | 	"net/http" | ||||||
| 	"net/mail" | 	"net/mail" | ||||||
|  |  | ||||||
|  | 	"github.com/boojack/shortify/store" | ||||||
| 	"github.com/labstack/echo/v4" | 	"github.com/labstack/echo/v4" | ||||||
| ) | ) | ||||||
|  |  | ||||||
| @@ -17,8 +19,8 @@ const ( | |||||||
| 	RoleUser Role = "USER" | 	RoleUser Role = "USER" | ||||||
| ) | ) | ||||||
|  |  | ||||||
| func (e Role) String() string { | func (r Role) String() string { | ||||||
| 	switch e { | 	switch r { | ||||||
| 	case RoleAdmin: | 	case RoleAdmin: | ||||||
| 		return "ADMIN" | 		return "ADMIN" | ||||||
| 	case RoleUser: | 	case RoleUser: | ||||||
| @@ -39,7 +41,6 @@ type User struct { | |||||||
| 	Email           string         `json:"email"` | 	Email           string         `json:"email"` | ||||||
| 	DisplayName     string         `json:"displayName"` | 	DisplayName     string         `json:"displayName"` | ||||||
| 	PasswordHash    string         `json:"-"` | 	PasswordHash    string         `json:"-"` | ||||||
| 	OpenID          string         `json:"openId"` |  | ||||||
| 	Role            Role           `json:"role"` | 	Role            Role           `json:"role"` | ||||||
| 	UserSettingList []*UserSetting `json:"userSettingList"` | 	UserSettingList []*UserSetting `json:"userSettingList"` | ||||||
| } | } | ||||||
| @@ -49,7 +50,6 @@ type UserCreate struct { | |||||||
| 	DisplayName  string `json:"displayName"` | 	DisplayName  string `json:"displayName"` | ||||||
| 	Password     string `json:"password"` | 	Password     string `json:"password"` | ||||||
| 	PasswordHash string `json:"-"` | 	PasswordHash string `json:"-"` | ||||||
| 	OpenID       string `json:"-"` |  | ||||||
| 	Role         Role   `json:"-"` | 	Role         Role   `json:"-"` | ||||||
| } | } | ||||||
|  |  | ||||||
| @@ -77,32 +77,35 @@ type UserPatch struct { | |||||||
| 	Email        *string `json:"email"` | 	Email        *string `json:"email"` | ||||||
| 	DisplayName  *string `json:"displayName"` | 	DisplayName  *string `json:"displayName"` | ||||||
| 	Password     *string `json:"password"` | 	Password     *string `json:"password"` | ||||||
| 	ResetOpenID  *bool   `json:"resetOpenId"` |  | ||||||
| 	PasswordHash *string `json:"-"` | 	PasswordHash *string `json:"-"` | ||||||
| 	OpenID       *string `json:"-"` |  | ||||||
| } |  | ||||||
|  |  | ||||||
| type UserFind struct { |  | ||||||
| 	ID *int `json:"id"` |  | ||||||
|  |  | ||||||
| 	// Standard fields |  | ||||||
| 	RowStatus *RowStatus `json:"rowStatus"` |  | ||||||
|  |  | ||||||
| 	// Domain specific fields |  | ||||||
| 	Email       *string `json:"email"` |  | ||||||
| 	DisplayName *string `json:"displayName"` |  | ||||||
| 	OpenID      *string `json:"openId"` |  | ||||||
| 	Role        *Role   `json:"-"` |  | ||||||
| } | } | ||||||
|  |  | ||||||
| type UserDelete struct { | type UserDelete struct { | ||||||
| 	ID int | 	ID int | ||||||
| } | } | ||||||
|  |  | ||||||
| func (*APIV1Service) RegisterUserRoutes(g *echo.Group) { | func (s *APIV1Service) registerUserRoutes(g *echo.Group) { | ||||||
| 	g.GET("/user", func(c echo.Context) error { | 	g.GET("/user", func(c echo.Context) error { | ||||||
| 		return c.String(200, "GET /user") | 		return c.String(200, "GET /user") | ||||||
| 	}) | 	}) | ||||||
|  |  | ||||||
|  | 	// GET /api/user/me is used to check if the user is logged in. | ||||||
|  | 	g.GET("/user/me", func(c echo.Context) error { | ||||||
|  | 		ctx := c.Request().Context() | ||||||
|  | 		userID, ok := c.Get(getUserIDContextKey()).(int) | ||||||
|  | 		if !ok { | ||||||
|  | 			return echo.NewHTTPError(http.StatusUnauthorized, "Missing auth session") | ||||||
|  | 		} | ||||||
|  |  | ||||||
|  | 		user, err := s.Store.GetUserV1(ctx, &store.FindUser{ | ||||||
|  | 			ID: &userID, | ||||||
|  | 		}) | ||||||
|  | 		if err != nil { | ||||||
|  | 			return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find user").SetInternal(err) | ||||||
|  | 		} | ||||||
|  |  | ||||||
|  | 		return c.JSON(http.StatusOK, user) | ||||||
|  | 	}) | ||||||
| } | } | ||||||
|  |  | ||||||
| // validateEmail validates the email. | // validateEmail validates the email. | ||||||
|   | |||||||
| @@ -18,6 +18,7 @@ func NewAPIV1Service(profile *profile.Profile, store *store.Store) *APIV1Service | |||||||
| 	} | 	} | ||||||
| } | } | ||||||
|  |  | ||||||
| func (s *APIV1Service) Start(apiV1Group *echo.Group) { | func (s *APIV1Service) Start(apiV1Group *echo.Group, secret string) { | ||||||
| 	s.RegisterUserRoutes(apiV1Group) | 	s.registerAuthRoutes(apiV1Group, secret) | ||||||
|  | 	s.registerUserRoutes(apiV1Group) | ||||||
| } | } | ||||||
|   | |||||||
							
								
								
									
										2
									
								
								go.mod
									
									
									
									
									
								
							
							
						
						
									
										2
									
								
								go.mod
									
									
									
									
									
								
							| @@ -64,6 +64,8 @@ require ( | |||||||
| ) | ) | ||||||
|  |  | ||||||
| require ( | require ( | ||||||
|  | 	github.com/golang-jwt/jwt/v4 v4.5.0 | ||||||
|  | 	github.com/pkg/errors v0.9.1 | ||||||
| 	golang.org/x/mod v0.8.0 | 	golang.org/x/mod v0.8.0 | ||||||
| 	modernc.org/sqlite v1.23.1 | 	modernc.org/sqlite v1.23.1 | ||||||
| ) | ) | ||||||
|   | |||||||
							
								
								
									
										3
									
								
								go.sum
									
									
									
									
									
								
							
							
						
						
									
										3
									
								
								go.sum
									
									
									
									
									
								
							| @@ -66,6 +66,8 @@ github.com/go-gl/glfw/v3.3/glfw v0.0.0-20191125211704-12ad95a8df72/go.mod h1:tQ2 | |||||||
| github.com/go-gl/glfw/v3.3/glfw v0.0.0-20200222043503-6f7a984d4dc4/go.mod h1:tQ2UAYgL5IevRw8kRxooKSPJfGvJ9fJQFa0TUsXzTg8= | github.com/go-gl/glfw/v3.3/glfw v0.0.0-20200222043503-6f7a984d4dc4/go.mod h1:tQ2UAYgL5IevRw8kRxooKSPJfGvJ9fJQFa0TUsXzTg8= | ||||||
| github.com/golang-jwt/jwt v3.2.2+incompatible h1:IfV12K8xAKAnZqdXVzCZ+TOjboZ2keLg81eXfW3O+oY= | github.com/golang-jwt/jwt v3.2.2+incompatible h1:IfV12K8xAKAnZqdXVzCZ+TOjboZ2keLg81eXfW3O+oY= | ||||||
| github.com/golang-jwt/jwt v3.2.2+incompatible/go.mod h1:8pz2t5EyA70fFQQSrl6XZXzqecmYZeUEB8OUGHkxJ+I= | github.com/golang-jwt/jwt v3.2.2+incompatible/go.mod h1:8pz2t5EyA70fFQQSrl6XZXzqecmYZeUEB8OUGHkxJ+I= | ||||||
|  | github.com/golang-jwt/jwt/v4 v4.5.0 h1:7cYmW1XlMY7h7ii7UhUyChSgS5wUJEnm9uZVTGqOWzg= | ||||||
|  | github.com/golang-jwt/jwt/v4 v4.5.0/go.mod h1:m21LjoU+eqJr34lmDMbreY2eSTRJ1cv77w39/MY0Ch0= | ||||||
| github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q= | github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q= | ||||||
| github.com/golang/groupcache v0.0.0-20190702054246-869f871628b6/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= | github.com/golang/groupcache v0.0.0-20190702054246-869f871628b6/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= | ||||||
| github.com/golang/groupcache v0.0.0-20191227052852-215e87163ea7/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= | github.com/golang/groupcache v0.0.0-20191227052852-215e87163ea7/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= | ||||||
| @@ -169,6 +171,7 @@ github.com/mitchellh/mapstructure v1.5.0 h1:jeMsZIYE/09sWLaz43PL7Gy6RuMjD2eJVyua | |||||||
| github.com/mitchellh/mapstructure v1.5.0/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo= | github.com/mitchellh/mapstructure v1.5.0/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo= | ||||||
| github.com/pelletier/go-toml/v2 v2.0.8 h1:0ctb6s9mE31h0/lhu+J6OPmVeDxJn+kYnJc2jZR9tGQ= | github.com/pelletier/go-toml/v2 v2.0.8 h1:0ctb6s9mE31h0/lhu+J6OPmVeDxJn+kYnJc2jZR9tGQ= | ||||||
| github.com/pelletier/go-toml/v2 v2.0.8/go.mod h1:vuYfssBdrU2XDZ9bYydBu6t+6a6PYNcZljzZR9VXg+4= | github.com/pelletier/go-toml/v2 v2.0.8/go.mod h1:vuYfssBdrU2XDZ9bYydBu6t+6a6PYNcZljzZR9VXg+4= | ||||||
|  | github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= | ||||||
| github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= | github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= | ||||||
| github.com/pkg/sftp v1.13.1/go.mod h1:3HaPG6Dq1ILlpPZRO0HVMrsydcdLt6HRDccSgb87qRg= | github.com/pkg/sftp v1.13.1/go.mod h1:3HaPG6Dq1ILlpPZRO0HVMrsydcdLt6HRDccSgb87qRg= | ||||||
| github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= | github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= | ||||||
|   | |||||||
| @@ -1,93 +0,0 @@ | |||||||
| package server |  | ||||||
|  |  | ||||||
| import ( |  | ||||||
| 	"fmt" |  | ||||||
| 	"net/http" |  | ||||||
| 	"strconv" |  | ||||||
|  |  | ||||||
| 	"github.com/boojack/shortify/api" |  | ||||||
|  |  | ||||||
| 	"github.com/gorilla/sessions" |  | ||||||
| 	"github.com/labstack/echo-contrib/session" |  | ||||||
| 	"github.com/labstack/echo/v4" |  | ||||||
| ) |  | ||||||
|  |  | ||||||
| var ( |  | ||||||
| 	userIDContextKey = "user-id" |  | ||||||
| 	sessionName      = "shortify-session" |  | ||||||
| ) |  | ||||||
|  |  | ||||||
| func getUserIDContextKey() string { |  | ||||||
| 	return userIDContextKey |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func setUserSession(ctx echo.Context, user *api.User) error { |  | ||||||
| 	sess, _ := session.Get(sessionName, ctx) |  | ||||||
| 	sess.Options = &sessions.Options{ |  | ||||||
| 		Path:     "/", |  | ||||||
| 		MaxAge:   1000 * 3600 * 24 * 30, |  | ||||||
| 		HttpOnly: true, |  | ||||||
| 	} |  | ||||||
| 	sess.Values[userIDContextKey] = user.ID |  | ||||||
| 	err := sess.Save(ctx.Request(), ctx.Response()) |  | ||||||
| 	if err != nil { |  | ||||||
| 		return fmt.Errorf("failed to set session, err: %w", err) |  | ||||||
| 	} |  | ||||||
| 	return nil |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func removeUserSession(ctx echo.Context) error { |  | ||||||
| 	sess, _ := session.Get(sessionName, ctx) |  | ||||||
| 	sess.Options = &sessions.Options{ |  | ||||||
| 		Path:     "/", |  | ||||||
| 		MaxAge:   0, |  | ||||||
| 		HttpOnly: true, |  | ||||||
| 	} |  | ||||||
| 	sess.Values[userIDContextKey] = nil |  | ||||||
| 	err := sess.Save(ctx.Request(), ctx.Response()) |  | ||||||
| 	if err != nil { |  | ||||||
| 		return fmt.Errorf("failed to set session, err: %w", err) |  | ||||||
| 	} |  | ||||||
| 	return nil |  | ||||||
| } |  | ||||||
|  |  | ||||||
| func aclMiddleware(s *Server, next echo.HandlerFunc) echo.HandlerFunc { |  | ||||||
| 	return func(c echo.Context) error { |  | ||||||
| 		ctx := c.Request().Context() |  | ||||||
| 		path := c.Path() |  | ||||||
|  |  | ||||||
| 		if s.defaultAuthSkipper(c) { |  | ||||||
| 			return next(c) |  | ||||||
| 		} |  | ||||||
|  |  | ||||||
| 		sess, _ := session.Get(sessionName, c) |  | ||||||
| 		userIDValue := sess.Values[userIDContextKey] |  | ||||||
| 		if userIDValue != nil { |  | ||||||
| 			userID, _ := strconv.Atoi(fmt.Sprintf("%v", userIDValue)) |  | ||||||
| 			userFind := &api.UserFind{ |  | ||||||
| 				ID: &userID, |  | ||||||
| 			} |  | ||||||
| 			user, err := s.Store.FindUser(ctx, userFind) |  | ||||||
| 			if err != nil { |  | ||||||
| 				return echo.NewHTTPError(http.StatusInternalServerError, fmt.Sprintf("Failed to find user by ID: %d", userID)).SetInternal(err) |  | ||||||
| 			} |  | ||||||
| 			if user != nil { |  | ||||||
| 				if user.RowStatus == api.Archived { |  | ||||||
| 					return echo.NewHTTPError(http.StatusForbidden, fmt.Sprintf("User has been archived with email %s", user.Email)) |  | ||||||
| 				} |  | ||||||
| 				c.Set(getUserIDContextKey(), userID) |  | ||||||
| 			} |  | ||||||
| 		} |  | ||||||
|  |  | ||||||
| 		if hasPrefixes(path, "/api/ping", "/api/status", "/api/workspace") && c.Request().Method == http.MethodGet { |  | ||||||
| 			return next(c) |  | ||||||
| 		} |  | ||||||
|  |  | ||||||
| 		userID := c.Get(getUserIDContextKey()) |  | ||||||
| 		if userID == nil { |  | ||||||
| 			return echo.NewHTTPError(http.StatusUnauthorized, "Missing user in session") |  | ||||||
| 		} |  | ||||||
|  |  | ||||||
| 		return next(c) |  | ||||||
| 	} |  | ||||||
| } |  | ||||||
							
								
								
									
										133
									
								
								server/auth/auth.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										133
									
								
								server/auth/auth.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,133 @@ | |||||||
|  | package auth | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"net/http" | ||||||
|  | 	"strconv" | ||||||
|  | 	"time" | ||||||
|  |  | ||||||
|  | 	"github.com/boojack/shortify/store" | ||||||
|  | 	"github.com/golang-jwt/jwt/v4" | ||||||
|  | 	"github.com/labstack/echo/v4" | ||||||
|  | 	"github.com/pkg/errors" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | const ( | ||||||
|  | 	issuer = "shortify" | ||||||
|  | 	// Signing key section. For now, this is only used for signing, not for verifying since we only | ||||||
|  | 	// have 1 version. But it will be used to maintain backward compatibility if we change the signing mechanism. | ||||||
|  | 	keyID = "v1" | ||||||
|  | 	// AccessTokenAudienceName is the audience name of the access token. | ||||||
|  | 	AccessTokenAudienceName = "user.access-token" | ||||||
|  | 	// RefreshTokenAudienceName is the audience name of the refresh token. | ||||||
|  | 	RefreshTokenAudienceName = "user.refresh-token" | ||||||
|  | 	apiTokenDuration         = 2 * time.Hour | ||||||
|  | 	accessTokenDuration      = 24 * time.Hour | ||||||
|  | 	refreshTokenDuration     = 7 * 24 * time.Hour | ||||||
|  | 	// RefreshThresholdDuration is the threshold duration for refreshing token. | ||||||
|  | 	RefreshThresholdDuration = 1 * time.Hour | ||||||
|  |  | ||||||
|  | 	// CookieExpDuration expires slightly earlier than the jwt expiration. Client would be logged out if the user | ||||||
|  | 	// cookie expires, thus the client would always logout first before attempting to make a request with the expired jwt. | ||||||
|  | 	// Suppose we have a valid refresh token, we will refresh the token in 2 cases: | ||||||
|  | 	// 1. The access token is about to expire in <<refreshThresholdDuration>> | ||||||
|  | 	// 2. The access token has already expired, we refresh the token so that the ongoing request can pass through. | ||||||
|  | 	CookieExpDuration = refreshTokenDuration - 1*time.Minute | ||||||
|  | 	// AccessTokenCookieName is the cookie name of access token. | ||||||
|  | 	AccessTokenCookieName = "access-token" | ||||||
|  | 	// RefreshTokenCookieName is the cookie name of refresh token. | ||||||
|  | 	RefreshTokenCookieName = "refresh-token" | ||||||
|  | 	// UserIDCookieName is the cookie name of user ID. | ||||||
|  | 	UserIDCookieName = "user" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | type claimsMessage struct { | ||||||
|  | 	Name string `json:"name"` | ||||||
|  | 	jwt.RegisteredClaims | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // GenerateAPIToken generates an API token. | ||||||
|  | func GenerateAPIToken(userName string, userID int, secret string) (string, error) { | ||||||
|  | 	expirationTime := time.Now().Add(apiTokenDuration) | ||||||
|  | 	return generateToken(userName, userID, AccessTokenAudienceName, expirationTime, []byte(secret)) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // GenerateAccessToken generates an access token for web. | ||||||
|  | func GenerateAccessToken(userName string, userID int, secret string) (string, error) { | ||||||
|  | 	expirationTime := time.Now().Add(accessTokenDuration) | ||||||
|  | 	return generateToken(userName, userID, AccessTokenAudienceName, expirationTime, []byte(secret)) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // GenerateRefreshToken generates a refresh token for web. | ||||||
|  | func GenerateRefreshToken(userName string, userID int, secret string) (string, error) { | ||||||
|  | 	expirationTime := time.Now().Add(refreshTokenDuration) | ||||||
|  | 	return generateToken(userName, userID, RefreshTokenAudienceName, expirationTime, []byte(secret)) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // GenerateTokensAndSetCookies generates jwt token and saves it to the http-only cookie. | ||||||
|  | func GenerateTokensAndSetCookies(c echo.Context, user *store.User, secret string) error { | ||||||
|  | 	accessToken, err := GenerateAccessToken(user.Username, user.ID, secret) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return errors.Wrap(err, "failed to generate access token") | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	cookieExp := time.Now().Add(CookieExpDuration) | ||||||
|  | 	setTokenCookie(c, AccessTokenCookieName, accessToken, cookieExp) | ||||||
|  |  | ||||||
|  | 	// We generate here a new refresh token and saving it to the cookie. | ||||||
|  | 	refreshToken, err := GenerateRefreshToken(user.Username, user.ID, secret) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return errors.Wrap(err, "failed to generate refresh token") | ||||||
|  | 	} | ||||||
|  | 	setTokenCookie(c, RefreshTokenCookieName, refreshToken, cookieExp) | ||||||
|  |  | ||||||
|  | 	return nil | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // RemoveTokensAndCookies removes the jwt token and refresh token from the cookies. | ||||||
|  | func RemoveTokensAndCookies(c echo.Context) { | ||||||
|  | 	// We set the expiration time to the past, so that the cookie will be removed. | ||||||
|  | 	cookieExp := time.Now().Add(-1 * time.Hour) | ||||||
|  | 	setTokenCookie(c, AccessTokenCookieName, "", cookieExp) | ||||||
|  | 	setTokenCookie(c, RefreshTokenCookieName, "", cookieExp) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // setTokenCookie sets the token to the cookie. | ||||||
|  | func setTokenCookie(c echo.Context, name, token string, expiration time.Time) { | ||||||
|  | 	cookie := new(http.Cookie) | ||||||
|  | 	cookie.Name = name | ||||||
|  | 	cookie.Value = token | ||||||
|  | 	cookie.Expires = expiration | ||||||
|  | 	cookie.Path = "/" | ||||||
|  | 	// Http-only helps mitigate the risk of client side script accessing the protected cookie. | ||||||
|  | 	cookie.HttpOnly = true | ||||||
|  | 	cookie.SameSite = http.SameSiteStrictMode | ||||||
|  | 	c.SetCookie(cookie) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // generateToken generates a jwt token. | ||||||
|  | func generateToken(username string, userID int, aud string, expirationTime time.Time, secret []byte) (string, error) { | ||||||
|  | 	// Create the JWT claims, which includes the username and expiry time. | ||||||
|  | 	claims := &claimsMessage{ | ||||||
|  | 		Name: username, | ||||||
|  | 		RegisteredClaims: jwt.RegisteredClaims{ | ||||||
|  | 			Audience: jwt.ClaimStrings{aud}, | ||||||
|  | 			// In JWT, the expiry time is expressed as unix milliseconds. | ||||||
|  | 			ExpiresAt: jwt.NewNumericDate(expirationTime), | ||||||
|  | 			IssuedAt:  jwt.NewNumericDate(time.Now()), | ||||||
|  | 			Issuer:    issuer, | ||||||
|  | 			Subject:   strconv.Itoa(userID), | ||||||
|  | 		}, | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	// Declare the token with the HS256 algorithm used for signing, and the claims. | ||||||
|  | 	token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) | ||||||
|  | 	token.Header["kid"] = keyID | ||||||
|  |  | ||||||
|  | 	// Create the JWT string. | ||||||
|  | 	tokenString, err := token.SignedString(secret) | ||||||
|  | 	if err != nil { | ||||||
|  | 		return "", err | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	return tokenString, nil | ||||||
|  | } | ||||||
| @@ -3,7 +3,6 @@ package server | |||||||
| import ( | import ( | ||||||
| 	"strings" | 	"strings" | ||||||
|  |  | ||||||
| 	"github.com/boojack/shortify/api" |  | ||||||
| 	"github.com/labstack/echo/v4" | 	"github.com/labstack/echo/v4" | ||||||
| ) | ) | ||||||
|  |  | ||||||
| @@ -29,34 +28,10 @@ func hasPrefixes(src string, prefixes ...string) bool { | |||||||
|  |  | ||||||
| func defaultAPIRequestSkipper(c echo.Context) bool { | func defaultAPIRequestSkipper(c echo.Context) bool { | ||||||
| 	path := c.Path() | 	path := c.Path() | ||||||
| 	return hasPrefixes(path, "/api", "/o") | 	return hasPrefixes(path, "/api") | ||||||
| } | } | ||||||
|  |  | ||||||
| func (server *Server) defaultAuthSkipper(c echo.Context) bool { | func (*Server) defaultAuthSkipper(c echo.Context) bool { | ||||||
| 	ctx := c.Request().Context() |  | ||||||
| 	path := c.Path() | 	path := c.Path() | ||||||
|  | 	return hasPrefixes(path, "/api/v1/auth") | ||||||
| 	// Skip auth. |  | ||||||
| 	if hasPrefixes(path, "/api/auth") { |  | ||||||
| 		return true |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	// If there is openId in query string and related user is found, then skip auth. |  | ||||||
| 	openID := c.QueryParam("openId") |  | ||||||
| 	if openID != "" { |  | ||||||
| 		userFind := &api.UserFind{ |  | ||||||
| 			OpenID: &openID, |  | ||||||
| 		} |  | ||||||
| 		user, err := server.Store.FindUser(ctx, userFind) |  | ||||||
| 		if err != nil { |  | ||||||
| 			return false |  | ||||||
| 		} |  | ||||||
| 		if user != nil { |  | ||||||
| 			// Stores userID into context. |  | ||||||
| 			c.Set(getUserIDContextKey(), user.ID) |  | ||||||
| 			return true |  | ||||||
| 		} |  | ||||||
| 	} |  | ||||||
|  |  | ||||||
| 	return false |  | ||||||
| } | } | ||||||
|   | |||||||
							
								
								
									
										201
									
								
								server/jwt.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										201
									
								
								server/jwt.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,201 @@ | |||||||
|  | package server | ||||||
|  |  | ||||||
|  | import ( | ||||||
|  | 	"fmt" | ||||||
|  | 	"net/http" | ||||||
|  | 	"strconv" | ||||||
|  | 	"strings" | ||||||
|  | 	"time" | ||||||
|  |  | ||||||
|  | 	"github.com/boojack/shortify/server/auth" | ||||||
|  | 	"github.com/boojack/shortify/store" | ||||||
|  | 	"github.com/golang-jwt/jwt/v4" | ||||||
|  | 	"github.com/labstack/echo/v4" | ||||||
|  | 	"github.com/pkg/errors" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | const ( | ||||||
|  | 	// Context section | ||||||
|  | 	// The key name used to store user id in the context | ||||||
|  | 	// user id is extracted from the jwt token subject field. | ||||||
|  | 	userIDContextKey = "user-id" | ||||||
|  | ) | ||||||
|  |  | ||||||
|  | func getUserIDContextKey() string { | ||||||
|  | 	return userIDContextKey | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // Claims creates a struct that will be encoded to a JWT. | ||||||
|  | // We add jwt.RegisteredClaims as an embedded type, to provide fields such as name. | ||||||
|  | type Claims struct { | ||||||
|  | 	Name string `json:"name"` | ||||||
|  | 	jwt.RegisteredClaims | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func extractTokenFromHeader(c echo.Context) (string, error) { | ||||||
|  | 	authHeader := c.Request().Header.Get("Authorization") | ||||||
|  | 	if authHeader == "" { | ||||||
|  | 		return "", nil | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	authHeaderParts := strings.Fields(authHeader) | ||||||
|  | 	if len(authHeaderParts) != 2 || strings.ToLower(authHeaderParts[0]) != "bearer" { | ||||||
|  | 		return "", errors.New("Authorization header format must be Bearer {token}") | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	return authHeaderParts[1], nil | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func findAccessToken(c echo.Context) string { | ||||||
|  | 	accessToken := "" | ||||||
|  | 	cookie, _ := c.Cookie(auth.AccessTokenCookieName) | ||||||
|  | 	if cookie != nil { | ||||||
|  | 		accessToken = cookie.Value | ||||||
|  | 	} | ||||||
|  | 	if accessToken == "" { | ||||||
|  | 		accessToken, _ = extractTokenFromHeader(c) | ||||||
|  | 	} | ||||||
|  |  | ||||||
|  | 	return accessToken | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func audienceContains(audience jwt.ClaimStrings, token string) bool { | ||||||
|  | 	for _, v := range audience { | ||||||
|  | 		if v == token { | ||||||
|  | 			return true | ||||||
|  | 		} | ||||||
|  | 	} | ||||||
|  | 	return false | ||||||
|  | } | ||||||
|  |  | ||||||
|  | // JWTMiddleware validates the access token. | ||||||
|  | // If the access token is about to expire or has expired and the request has a valid refresh token, it | ||||||
|  | // will try to generate new access token and refresh token. | ||||||
|  | func JWTMiddleware(server *Server, next echo.HandlerFunc, secret string) echo.HandlerFunc { | ||||||
|  | 	return func(c echo.Context) error { | ||||||
|  | 		path := c.Request().URL.Path | ||||||
|  | 		method := c.Request().Method | ||||||
|  |  | ||||||
|  | 		if server.defaultAuthSkipper(c) { | ||||||
|  | 			return next(c) | ||||||
|  | 		} | ||||||
|  |  | ||||||
|  | 		// Skip validation for server status endpoints. | ||||||
|  | 		if hasPrefixes(path, "/api/ping", "/api/v1/idp", "/api/user/:id") && method == http.MethodGet { | ||||||
|  | 			return next(c) | ||||||
|  | 		} | ||||||
|  |  | ||||||
|  | 		token := findAccessToken(c) | ||||||
|  | 		if token == "" { | ||||||
|  | 			// When the request is not authenticated, we allow the user to access the shortcut endpoints for those public shortcuts. | ||||||
|  | 			if hasPrefixes(path, "/api/status", "/api/shortcut") && method == http.MethodGet { | ||||||
|  | 				return next(c) | ||||||
|  | 			} | ||||||
|  | 			return echo.NewHTTPError(http.StatusUnauthorized, "Missing access token") | ||||||
|  | 		} | ||||||
|  |  | ||||||
|  | 		claims := &Claims{} | ||||||
|  | 		accessToken, err := jwt.ParseWithClaims(token, claims, func(t *jwt.Token) (any, error) { | ||||||
|  | 			if t.Method.Alg() != jwt.SigningMethodHS256.Name { | ||||||
|  | 				return nil, errors.Errorf("unexpected access token signing method=%v, expect %v", t.Header["alg"], jwt.SigningMethodHS256) | ||||||
|  | 			} | ||||||
|  | 			if kid, ok := t.Header["kid"].(string); ok { | ||||||
|  | 				if kid == "v1" { | ||||||
|  | 					return []byte(secret), nil | ||||||
|  | 				} | ||||||
|  | 			} | ||||||
|  | 			return nil, errors.Errorf("unexpected access token kid=%v", t.Header["kid"]) | ||||||
|  | 		}) | ||||||
|  | 		if !audienceContains(claims.Audience, auth.AccessTokenAudienceName) { | ||||||
|  | 			return echo.NewHTTPError(http.StatusUnauthorized, fmt.Sprintf("Invalid access token, audience mismatch, got %q, expected %q.", claims.Audience, auth.AccessTokenAudienceName)) | ||||||
|  | 		} | ||||||
|  | 		generateToken := time.Until(claims.ExpiresAt.Time) < auth.RefreshThresholdDuration | ||||||
|  | 		if err != nil { | ||||||
|  | 			var ve *jwt.ValidationError | ||||||
|  | 			if errors.As(err, &ve) { | ||||||
|  | 				// If expiration error is the only error, we will clear the err | ||||||
|  | 				// and generate new access token and refresh token | ||||||
|  | 				if ve.Errors == jwt.ValidationErrorExpired { | ||||||
|  | 					generateToken = true | ||||||
|  | 				} | ||||||
|  | 			} else { | ||||||
|  | 				return echo.NewHTTPError(http.StatusUnauthorized, errors.Wrap(err, "Invalid or expired access token")) | ||||||
|  | 			} | ||||||
|  | 		} | ||||||
|  |  | ||||||
|  | 		// We either have a valid access token or we will attempt to generate new access token and refresh token | ||||||
|  | 		ctx := c.Request().Context() | ||||||
|  | 		userID, err := strconv.Atoi(claims.Subject) | ||||||
|  | 		if err != nil { | ||||||
|  | 			return echo.NewHTTPError(http.StatusUnauthorized, "Malformed ID in the token.") | ||||||
|  | 		} | ||||||
|  |  | ||||||
|  | 		// Even if there is no error, we still need to make sure the user still exists. | ||||||
|  | 		user, err := server.Store.GetUserV1(ctx, &store.FindUser{ | ||||||
|  | 			ID: &userID, | ||||||
|  | 		}) | ||||||
|  | 		if err != nil { | ||||||
|  | 			return echo.NewHTTPError(http.StatusInternalServerError, fmt.Sprintf("Server error to find user ID: %d", userID)).SetInternal(err) | ||||||
|  | 		} | ||||||
|  | 		if user == nil { | ||||||
|  | 			return echo.NewHTTPError(http.StatusUnauthorized, fmt.Sprintf("Failed to find user ID: %d", userID)) | ||||||
|  | 		} | ||||||
|  |  | ||||||
|  | 		if generateToken { | ||||||
|  | 			generateTokenFunc := func() error { | ||||||
|  | 				rc, err := c.Cookie(auth.RefreshTokenCookieName) | ||||||
|  | 				if err != nil { | ||||||
|  | 					return echo.NewHTTPError(http.StatusUnauthorized, "Failed to generate access token. Missing refresh token.") | ||||||
|  | 				} | ||||||
|  |  | ||||||
|  | 				// Parses token and checks if it's valid. | ||||||
|  | 				refreshTokenClaims := &Claims{} | ||||||
|  | 				refreshToken, err := jwt.ParseWithClaims(rc.Value, refreshTokenClaims, func(t *jwt.Token) (any, error) { | ||||||
|  | 					if t.Method.Alg() != jwt.SigningMethodHS256.Name { | ||||||
|  | 						return nil, errors.Errorf("unexpected refresh token signing method=%v, expected %v", t.Header["alg"], jwt.SigningMethodHS256) | ||||||
|  | 					} | ||||||
|  |  | ||||||
|  | 					if kid, ok := t.Header["kid"].(string); ok { | ||||||
|  | 						if kid == "v1" { | ||||||
|  | 							return []byte(secret), nil | ||||||
|  | 						} | ||||||
|  | 					} | ||||||
|  | 					return nil, errors.Errorf("unexpected refresh token kid=%v", t.Header["kid"]) | ||||||
|  | 				}) | ||||||
|  | 				if err != nil { | ||||||
|  | 					if err == jwt.ErrSignatureInvalid { | ||||||
|  | 						return echo.NewHTTPError(http.StatusUnauthorized, "Failed to generate access token. Invalid refresh token signature.") | ||||||
|  | 					} | ||||||
|  | 					return echo.NewHTTPError(http.StatusInternalServerError, fmt.Sprintf("Server error to refresh expired token. User Id %d", userID)).SetInternal(err) | ||||||
|  | 				} | ||||||
|  |  | ||||||
|  | 				if !audienceContains(refreshTokenClaims.Audience, auth.RefreshTokenAudienceName) { | ||||||
|  | 					return echo.NewHTTPError(http.StatusUnauthorized, | ||||||
|  | 						fmt.Sprintf("Invalid refresh token, audience mismatch, got %q, expected %q. you may send request to the wrong environment", | ||||||
|  | 							refreshTokenClaims.Audience, | ||||||
|  | 							auth.RefreshTokenAudienceName, | ||||||
|  | 						)) | ||||||
|  | 				} | ||||||
|  |  | ||||||
|  | 				// If we have a valid refresh token, we will generate new access token and refresh token | ||||||
|  | 				if refreshToken != nil && refreshToken.Valid { | ||||||
|  | 					if err := auth.GenerateTokensAndSetCookies(c, user, secret); err != nil { | ||||||
|  | 						return echo.NewHTTPError(http.StatusInternalServerError, fmt.Sprintf("Server error to refresh expired token. User Id %d", userID)).SetInternal(err) | ||||||
|  | 					} | ||||||
|  | 				} | ||||||
|  |  | ||||||
|  | 				return nil | ||||||
|  | 			} | ||||||
|  |  | ||||||
|  | 			// It may happen that we still have a valid access token, but we encounter issue when trying to generate new token | ||||||
|  | 			// In such case, we won't return the error. | ||||||
|  | 			if err := generateTokenFunc(); err != nil && !accessToken.Valid { | ||||||
|  | 				return err | ||||||
|  | 			} | ||||||
|  | 		} | ||||||
|  |  | ||||||
|  | 		// Stores userID into context. | ||||||
|  | 		c.Set(getUserIDContextKey(), userID) | ||||||
|  | 		return next(c) | ||||||
|  | 	} | ||||||
|  | } | ||||||
| @@ -1,51 +0,0 @@ | |||||||
| package server |  | ||||||
|  |  | ||||||
| import ( |  | ||||||
| 	"net/http" |  | ||||||
|  |  | ||||||
| 	"github.com/boojack/shortify/api" |  | ||||||
| 	"github.com/labstack/echo/v4" |  | ||||||
| ) |  | ||||||
|  |  | ||||||
| func (s *Server) registerRedirectRoutes(g *echo.Group) { |  | ||||||
| 	g.GET("/:shortcutName", func(c echo.Context) error { |  | ||||||
| 		ctx := c.Request().Context() |  | ||||||
| 		userID, ok := c.Get(getUserIDContextKey()).(int) |  | ||||||
| 		if !ok { |  | ||||||
| 			return echo.NewHTTPError(http.StatusUnauthorized, "Missing user in session") |  | ||||||
| 		} |  | ||||||
| 		shortcutName := c.Param("shortcutName") |  | ||||||
| 		if shortcutName == "" { |  | ||||||
| 			return echo.NewHTTPError(http.StatusBadRequest, "Missing shortcut name") |  | ||||||
| 		} |  | ||||||
|  |  | ||||||
| 		list := []*api.Shortcut{} |  | ||||||
|  |  | ||||||
| 		shortcutFind := &api.ShortcutFind{ |  | ||||||
| 			Name:     &shortcutName, |  | ||||||
| 			MemberID: &userID, |  | ||||||
| 		} |  | ||||||
| 		shortcutFind.VisibilityList = []api.Visibility{api.VisibilityWorkspace, api.VisibilityPublic} |  | ||||||
| 		visibleShortcutList, err := s.Store.FindShortcutList(ctx, shortcutFind) |  | ||||||
| 		if err != nil { |  | ||||||
| 			return echo.NewHTTPError(http.StatusInternalServerError, "Failed to fetch shortcut list").SetInternal(err) |  | ||||||
| 		} |  | ||||||
| 		list = append(list, visibleShortcutList...) |  | ||||||
|  |  | ||||||
| 		shortcutFind.VisibilityList = []api.Visibility{api.VisibilityPrivite} |  | ||||||
| 		shortcutFind.CreatorID = &userID |  | ||||||
| 		privateShortcutList, err := s.Store.FindShortcutList(ctx, shortcutFind) |  | ||||||
| 		if err != nil { |  | ||||||
| 			return echo.NewHTTPError(http.StatusInternalServerError, "Failed to fetch private shortcut list").SetInternal(err) |  | ||||||
| 		} |  | ||||||
| 		list = append(list, privateShortcutList...) |  | ||||||
|  |  | ||||||
| 		if len(list) == 0 { |  | ||||||
| 			return echo.NewHTTPError(http.StatusNotFound, "Not found shortcut").SetInternal(err) |  | ||||||
| 		} |  | ||||||
|  |  | ||||||
| 		// TODO(steven): improve the matched result later |  | ||||||
| 		matchedShortcut := list[0] |  | ||||||
| 		return c.Redirect(http.StatusPermanentRedirect, matchedShortcut.Link) |  | ||||||
| 	}) |  | ||||||
| } |  | ||||||
| @@ -54,24 +54,17 @@ func NewServer(profile *profile.Profile, store *store.Store) (*Server, error) { | |||||||
| 	embedFrontend(e) | 	embedFrontend(e) | ||||||
|  |  | ||||||
| 	// In dev mode, set the const secret key to make signin session persistence. | 	// In dev mode, set the const secret key to make signin session persistence. | ||||||
| 	secret := []byte("iamshortify") | 	secret := "iamshortify" | ||||||
| 	if profile.Mode == "prod" { | 	if profile.Mode == "prod" { | ||||||
| 		secret = securecookie.GenerateRandomKey(16) | 		secret = string(securecookie.GenerateRandomKey(16)) | ||||||
| 	} | 	} | ||||||
| 	e.Use(session.Middleware(sessions.NewCookieStore(secret))) | 	e.Use(session.Middleware(sessions.NewCookieStore([]byte(secret)))) | ||||||
|  |  | ||||||
| 	redirectGroup := e.Group("/o") |  | ||||||
| 	redirectGroup.Use(func(next echo.HandlerFunc) echo.HandlerFunc { |  | ||||||
| 		return aclMiddleware(s, next) |  | ||||||
| 	}) |  | ||||||
| 	s.registerRedirectRoutes(redirectGroup) |  | ||||||
|  |  | ||||||
| 	apiGroup := e.Group("/api") | 	apiGroup := e.Group("/api") | ||||||
| 	apiGroup.Use(func(next echo.HandlerFunc) echo.HandlerFunc { | 	apiGroup.Use(func(next echo.HandlerFunc) echo.HandlerFunc { | ||||||
| 		return aclMiddleware(s, next) | 		return JWTMiddleware(s, next, string(secret)) | ||||||
| 	}) | 	}) | ||||||
| 	s.registerSystemRoutes(apiGroup) | 	s.registerSystemRoutes(apiGroup) | ||||||
| 	s.registerAuthRoutes(apiGroup) |  | ||||||
| 	s.registerUserRoutes(apiGroup) | 	s.registerUserRoutes(apiGroup) | ||||||
| 	s.registerWorkspaceRoutes(apiGroup) | 	s.registerWorkspaceRoutes(apiGroup) | ||||||
| 	s.registerWorkspaceUserRoutes(apiGroup) | 	s.registerWorkspaceUserRoutes(apiGroup) | ||||||
| @@ -80,7 +73,10 @@ func NewServer(profile *profile.Profile, store *store.Store) (*Server, error) { | |||||||
| 	// Register API v1 routes. | 	// Register API v1 routes. | ||||||
| 	apiV1Service := apiv1.NewAPIV1Service(profile, store) | 	apiV1Service := apiv1.NewAPIV1Service(profile, store) | ||||||
| 	apiV1Group := apiGroup.Group("/api/v1") | 	apiV1Group := apiGroup.Group("/api/v1") | ||||||
| 	apiV1Service.RegisterUserRoutes(apiV1Group) | 	apiV1Group.Use(func(next echo.HandlerFunc) echo.HandlerFunc { | ||||||
|  | 		return JWTMiddleware(s, next, string(secret)) | ||||||
|  | 	}) | ||||||
|  | 	apiV1Service.Start(apiV1Group, secret) | ||||||
|  |  | ||||||
| 	return s, nil | 	return s, nil | ||||||
| } | } | ||||||
|   | |||||||
| @@ -8,6 +8,7 @@ import ( | |||||||
| 	"strconv" | 	"strconv" | ||||||
|  |  | ||||||
| 	"github.com/boojack/shortify/api" | 	"github.com/boojack/shortify/api" | ||||||
|  | 	"github.com/google/uuid" | ||||||
|  |  | ||||||
| 	"github.com/labstack/echo/v4" | 	"github.com/labstack/echo/v4" | ||||||
| 	"golang.org/x/crypto/bcrypt" | 	"golang.org/x/crypto/bcrypt" | ||||||
| @@ -155,3 +156,7 @@ func validateEmail(email string) bool { | |||||||
| 	} | 	} | ||||||
| 	return true | 	return true | ||||||
| } | } | ||||||
|  |  | ||||||
|  | func genUUID() string { | ||||||
|  | 	return uuid.New().String() | ||||||
|  | } | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user
	 Steven
					Steven