diff --git a/server/auth.go b/api/v1/auth.go similarity index 60% rename from server/auth.go rename to api/v1/auth.go index 280db86..f8ede98 100644 --- a/server/auth.go +++ b/api/v1/auth.go @@ -1,4 +1,4 @@ -package server +package v1 import ( "encoding/json" @@ -6,31 +6,48 @@ import ( "net/http" "github.com/boojack/shortify/api" - "github.com/google/uuid" - + "github.com/boojack/shortify/server/auth" + "github.com/boojack/shortify/store" "github.com/labstack/echo/v4" "golang.org/x/crypto/bcrypt" ) -func (s *Server) registerAuthRoutes(g *echo.Group) { +var ( + userIDContextKey = "user-id" +) + +func getUserIDContextKey() string { + return userIDContextKey +} + +type SignUpRequest struct { + Username string `json:"username"` + Password string `json:"password"` +} + +type SignInRequest struct { + Username string `json:"username"` + Password string `json:"password"` +} + +func (s *APIV1Service) registerAuthRoutes(g *echo.Group, secret string) { g.POST("/auth/signin", func(c echo.Context) error { ctx := c.Request().Context() - signin := &api.Signin{} + signin := &SignInRequest{} if err := json.NewDecoder(c.Request().Body).Decode(signin); err != nil { return echo.NewHTTPError(http.StatusBadRequest, "Malformatted signin request").SetInternal(err) } - userFind := &api.UserFind{ - Email: &signin.Email, - } - user, err := s.Store.FindUser(ctx, userFind) + user, err := s.Store.GetUserV1(ctx, &store.FindUser{ + Username: &signin.Username, + }) if err != nil { - return echo.NewHTTPError(http.StatusInternalServerError, fmt.Sprintf("Failed to find user by email %s", signin.Email)).SetInternal(err) + return echo.NewHTTPError(http.StatusInternalServerError, fmt.Sprintf("Failed to find user by username %s", signin.Username)).SetInternal(err) } if user == nil { - return echo.NewHTTPError(http.StatusUnauthorized, fmt.Sprintf("User not found with email %s", signin.Email)) - } else if user.RowStatus == api.Archived { - return echo.NewHTTPError(http.StatusForbidden, fmt.Sprintf("User has been archived with email %s", signin.Email)) + return echo.NewHTTPError(http.StatusUnauthorized, fmt.Sprintf("User not found with username %s", signin.Username)) + } else if user.RowStatus == store.Archived { + return echo.NewHTTPError(http.StatusForbidden, fmt.Sprintf("User has been archived with username %s", signin.Username)) } // Compare the stored hashed password, with the hashed version of the password that was received. @@ -39,34 +56,28 @@ func (s *Server) registerAuthRoutes(g *echo.Group) { return echo.NewHTTPError(http.StatusUnauthorized, "Incorrect password").SetInternal(err) } - if err = setUserSession(c, user); err != nil { - return echo.NewHTTPError(http.StatusInternalServerError, "Failed to set signin session").SetInternal(err) + if err := auth.GenerateTokensAndSetCookies(c, user, secret); err != nil { + return echo.NewHTTPError(http.StatusInternalServerError, "Failed to generate tokens").SetInternal(err) } - return c.JSON(http.StatusOK, composeResponse(user)) + return c.JSON(http.StatusOK, user) }) g.POST("/auth/signup", func(c echo.Context) error { ctx := c.Request().Context() - signup := &api.Signup{} + signup := &SignUpRequest{} if err := json.NewDecoder(c.Request().Body).Decode(signup); err != nil { return echo.NewHTTPError(http.StatusBadRequest, "Malformatted signup request").SetInternal(err) } - userCreate := &api.UserCreate{ - Email: signup.Email, - DisplayName: signup.DisplayName, - Password: signup.Password, - OpenID: genUUID(), + user := &store.User{ + Username: signup.Username, + Nickname: signup.Username, } - if err := userCreate.Validate(); err != nil { - return echo.NewHTTPError(http.StatusBadRequest, "Invalid user create format.").SetInternal(err) - } - passwordHash, err := bcrypt.GenerateFromPassword([]byte(signup.Password), bcrypt.DefaultCost) if err != nil { return echo.NewHTTPError(http.StatusInternalServerError, "Failed to generate password hash").SetInternal(err) } - userCreate.PasswordHash = string(passwordHash) + user.PasswordHash = string(passwordHash) existingUsers, err := s.Store.FindUserList(ctx, &api.UserFind{}) if err != nil { @@ -74,35 +85,26 @@ func (s *Server) registerAuthRoutes(g *echo.Group) { } // The first user to sign up is an admin by default. if len(existingUsers) == 0 { - userCreate.Role = api.RoleAdmin + user.Role = store.RoleAdmin } else { - userCreate.Role = api.RoleUser + user.Role = store.RoleUser } - user, err := s.Store.CreateUser(ctx, userCreate) + user, err = s.Store.CreateUserV1(ctx, user) if err != nil { return echo.NewHTTPError(http.StatusInternalServerError, "Failed to create user").SetInternal(err) } - err = setUserSession(c, user) - if err != nil { - return echo.NewHTTPError(http.StatusInternalServerError, "Failed to set signup session").SetInternal(err) + if err := auth.GenerateTokensAndSetCookies(c, user, secret); err != nil { + return echo.NewHTTPError(http.StatusInternalServerError, "Failed to generate tokens").SetInternal(err) } - return c.JSON(http.StatusOK, composeResponse(user)) + return c.JSON(http.StatusOK, user) }) g.POST("/auth/logout", func(c echo.Context) error { - err := removeUserSession(c) - if err != nil { - return echo.NewHTTPError(http.StatusInternalServerError, "Failed to set logout session").SetInternal(err) - } - + auth.RemoveTokensAndCookies(c) c.Response().WriteHeader(http.StatusOK) return nil }) } - -func genUUID() string { - return uuid.New().String() -} diff --git a/api/v1/common.go b/api/v1/common.go index 850164a..2fc252f 100644 --- a/api/v1/common.go +++ b/api/v1/common.go @@ -10,8 +10,8 @@ const ( Archived RowStatus = "ARCHIVED" ) -func (e RowStatus) String() string { - switch e { +func (status RowStatus) String() string { + switch status { case Normal: return "NORMAL" case Archived: diff --git a/api/v1/user.go b/api/v1/user.go index baaf989..9a2fafb 100644 --- a/api/v1/user.go +++ b/api/v1/user.go @@ -2,8 +2,10 @@ package v1 import ( "fmt" + "net/http" "net/mail" + "github.com/boojack/shortify/store" "github.com/labstack/echo/v4" ) @@ -17,8 +19,8 @@ const ( RoleUser Role = "USER" ) -func (e Role) String() string { - switch e { +func (r Role) String() string { + switch r { case RoleAdmin: return "ADMIN" case RoleUser: @@ -39,7 +41,6 @@ type User struct { Email string `json:"email"` DisplayName string `json:"displayName"` PasswordHash string `json:"-"` - OpenID string `json:"openId"` Role Role `json:"role"` UserSettingList []*UserSetting `json:"userSettingList"` } @@ -49,7 +50,6 @@ type UserCreate struct { DisplayName string `json:"displayName"` Password string `json:"password"` PasswordHash string `json:"-"` - OpenID string `json:"-"` Role Role `json:"-"` } @@ -77,32 +77,35 @@ type UserPatch struct { Email *string `json:"email"` DisplayName *string `json:"displayName"` Password *string `json:"password"` - ResetOpenID *bool `json:"resetOpenId"` PasswordHash *string `json:"-"` - OpenID *string `json:"-"` -} - -type UserFind struct { - ID *int `json:"id"` - - // Standard fields - RowStatus *RowStatus `json:"rowStatus"` - - // Domain specific fields - Email *string `json:"email"` - DisplayName *string `json:"displayName"` - OpenID *string `json:"openId"` - Role *Role `json:"-"` } type UserDelete struct { ID int } -func (*APIV1Service) RegisterUserRoutes(g *echo.Group) { +func (s *APIV1Service) registerUserRoutes(g *echo.Group) { g.GET("/user", func(c echo.Context) error { return c.String(200, "GET /user") }) + + // GET /api/user/me is used to check if the user is logged in. + g.GET("/user/me", func(c echo.Context) error { + ctx := c.Request().Context() + userID, ok := c.Get(getUserIDContextKey()).(int) + if !ok { + return echo.NewHTTPError(http.StatusUnauthorized, "Missing auth session") + } + + user, err := s.Store.GetUserV1(ctx, &store.FindUser{ + ID: &userID, + }) + if err != nil { + return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find user").SetInternal(err) + } + + return c.JSON(http.StatusOK, user) + }) } // validateEmail validates the email. diff --git a/api/v1/v1.go b/api/v1/v1.go index a9b5222..a376ada 100644 --- a/api/v1/v1.go +++ b/api/v1/v1.go @@ -18,6 +18,7 @@ func NewAPIV1Service(profile *profile.Profile, store *store.Store) *APIV1Service } } -func (s *APIV1Service) Start(apiV1Group *echo.Group) { - s.RegisterUserRoutes(apiV1Group) +func (s *APIV1Service) Start(apiV1Group *echo.Group, secret string) { + s.registerAuthRoutes(apiV1Group, secret) + s.registerUserRoutes(apiV1Group) } diff --git a/go.mod b/go.mod index 5834d8b..6c5cf43 100644 --- a/go.mod +++ b/go.mod @@ -64,6 +64,8 @@ require ( ) require ( + github.com/golang-jwt/jwt/v4 v4.5.0 + github.com/pkg/errors v0.9.1 golang.org/x/mod v0.8.0 modernc.org/sqlite v1.23.1 ) diff --git a/go.sum b/go.sum index 27c6836..6bd5dbd 100644 --- a/go.sum +++ b/go.sum @@ -66,6 +66,8 @@ github.com/go-gl/glfw/v3.3/glfw v0.0.0-20191125211704-12ad95a8df72/go.mod h1:tQ2 github.com/go-gl/glfw/v3.3/glfw v0.0.0-20200222043503-6f7a984d4dc4/go.mod h1:tQ2UAYgL5IevRw8kRxooKSPJfGvJ9fJQFa0TUsXzTg8= github.com/golang-jwt/jwt v3.2.2+incompatible h1:IfV12K8xAKAnZqdXVzCZ+TOjboZ2keLg81eXfW3O+oY= github.com/golang-jwt/jwt v3.2.2+incompatible/go.mod h1:8pz2t5EyA70fFQQSrl6XZXzqecmYZeUEB8OUGHkxJ+I= +github.com/golang-jwt/jwt/v4 v4.5.0 h1:7cYmW1XlMY7h7ii7UhUyChSgS5wUJEnm9uZVTGqOWzg= +github.com/golang-jwt/jwt/v4 v4.5.0/go.mod h1:m21LjoU+eqJr34lmDMbreY2eSTRJ1cv77w39/MY0Ch0= github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q= github.com/golang/groupcache v0.0.0-20190702054246-869f871628b6/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= github.com/golang/groupcache v0.0.0-20191227052852-215e87163ea7/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= @@ -169,6 +171,7 @@ github.com/mitchellh/mapstructure v1.5.0 h1:jeMsZIYE/09sWLaz43PL7Gy6RuMjD2eJVyua github.com/mitchellh/mapstructure v1.5.0/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo= github.com/pelletier/go-toml/v2 v2.0.8 h1:0ctb6s9mE31h0/lhu+J6OPmVeDxJn+kYnJc2jZR9tGQ= github.com/pelletier/go-toml/v2 v2.0.8/go.mod h1:vuYfssBdrU2XDZ9bYydBu6t+6a6PYNcZljzZR9VXg+4= +github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pkg/sftp v1.13.1/go.mod h1:3HaPG6Dq1ILlpPZRO0HVMrsydcdLt6HRDccSgb87qRg= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= diff --git a/server/acl.go b/server/acl.go deleted file mode 100644 index 30d6a58..0000000 --- a/server/acl.go +++ /dev/null @@ -1,93 +0,0 @@ -package server - -import ( - "fmt" - "net/http" - "strconv" - - "github.com/boojack/shortify/api" - - "github.com/gorilla/sessions" - "github.com/labstack/echo-contrib/session" - "github.com/labstack/echo/v4" -) - -var ( - userIDContextKey = "user-id" - sessionName = "shortify-session" -) - -func getUserIDContextKey() string { - return userIDContextKey -} - -func setUserSession(ctx echo.Context, user *api.User) error { - sess, _ := session.Get(sessionName, ctx) - sess.Options = &sessions.Options{ - Path: "/", - MaxAge: 1000 * 3600 * 24 * 30, - HttpOnly: true, - } - sess.Values[userIDContextKey] = user.ID - err := sess.Save(ctx.Request(), ctx.Response()) - if err != nil { - return fmt.Errorf("failed to set session, err: %w", err) - } - return nil -} - -func removeUserSession(ctx echo.Context) error { - sess, _ := session.Get(sessionName, ctx) - sess.Options = &sessions.Options{ - Path: "/", - MaxAge: 0, - HttpOnly: true, - } - sess.Values[userIDContextKey] = nil - err := sess.Save(ctx.Request(), ctx.Response()) - if err != nil { - return fmt.Errorf("failed to set session, err: %w", err) - } - return nil -} - -func aclMiddleware(s *Server, next echo.HandlerFunc) echo.HandlerFunc { - return func(c echo.Context) error { - ctx := c.Request().Context() - path := c.Path() - - if s.defaultAuthSkipper(c) { - return next(c) - } - - sess, _ := session.Get(sessionName, c) - userIDValue := sess.Values[userIDContextKey] - if userIDValue != nil { - userID, _ := strconv.Atoi(fmt.Sprintf("%v", userIDValue)) - userFind := &api.UserFind{ - ID: &userID, - } - user, err := s.Store.FindUser(ctx, userFind) - if err != nil { - return echo.NewHTTPError(http.StatusInternalServerError, fmt.Sprintf("Failed to find user by ID: %d", userID)).SetInternal(err) - } - if user != nil { - if user.RowStatus == api.Archived { - return echo.NewHTTPError(http.StatusForbidden, fmt.Sprintf("User has been archived with email %s", user.Email)) - } - c.Set(getUserIDContextKey(), userID) - } - } - - if hasPrefixes(path, "/api/ping", "/api/status", "/api/workspace") && c.Request().Method == http.MethodGet { - return next(c) - } - - userID := c.Get(getUserIDContextKey()) - if userID == nil { - return echo.NewHTTPError(http.StatusUnauthorized, "Missing user in session") - } - - return next(c) - } -} diff --git a/server/auth/auth.go b/server/auth/auth.go new file mode 100644 index 0000000..ed786cf --- /dev/null +++ b/server/auth/auth.go @@ -0,0 +1,133 @@ +package auth + +import ( + "net/http" + "strconv" + "time" + + "github.com/boojack/shortify/store" + "github.com/golang-jwt/jwt/v4" + "github.com/labstack/echo/v4" + "github.com/pkg/errors" +) + +const ( + issuer = "shortify" + // Signing key section. For now, this is only used for signing, not for verifying since we only + // have 1 version. But it will be used to maintain backward compatibility if we change the signing mechanism. + keyID = "v1" + // AccessTokenAudienceName is the audience name of the access token. + AccessTokenAudienceName = "user.access-token" + // RefreshTokenAudienceName is the audience name of the refresh token. + RefreshTokenAudienceName = "user.refresh-token" + apiTokenDuration = 2 * time.Hour + accessTokenDuration = 24 * time.Hour + refreshTokenDuration = 7 * 24 * time.Hour + // RefreshThresholdDuration is the threshold duration for refreshing token. + RefreshThresholdDuration = 1 * time.Hour + + // CookieExpDuration expires slightly earlier than the jwt expiration. Client would be logged out if the user + // cookie expires, thus the client would always logout first before attempting to make a request with the expired jwt. + // Suppose we have a valid refresh token, we will refresh the token in 2 cases: + // 1. The access token is about to expire in <> + // 2. The access token has already expired, we refresh the token so that the ongoing request can pass through. + CookieExpDuration = refreshTokenDuration - 1*time.Minute + // AccessTokenCookieName is the cookie name of access token. + AccessTokenCookieName = "access-token" + // RefreshTokenCookieName is the cookie name of refresh token. + RefreshTokenCookieName = "refresh-token" + // UserIDCookieName is the cookie name of user ID. + UserIDCookieName = "user" +) + +type claimsMessage struct { + Name string `json:"name"` + jwt.RegisteredClaims +} + +// GenerateAPIToken generates an API token. +func GenerateAPIToken(userName string, userID int, secret string) (string, error) { + expirationTime := time.Now().Add(apiTokenDuration) + return generateToken(userName, userID, AccessTokenAudienceName, expirationTime, []byte(secret)) +} + +// GenerateAccessToken generates an access token for web. +func GenerateAccessToken(userName string, userID int, secret string) (string, error) { + expirationTime := time.Now().Add(accessTokenDuration) + return generateToken(userName, userID, AccessTokenAudienceName, expirationTime, []byte(secret)) +} + +// GenerateRefreshToken generates a refresh token for web. +func GenerateRefreshToken(userName string, userID int, secret string) (string, error) { + expirationTime := time.Now().Add(refreshTokenDuration) + return generateToken(userName, userID, RefreshTokenAudienceName, expirationTime, []byte(secret)) +} + +// GenerateTokensAndSetCookies generates jwt token and saves it to the http-only cookie. +func GenerateTokensAndSetCookies(c echo.Context, user *store.User, secret string) error { + accessToken, err := GenerateAccessToken(user.Username, user.ID, secret) + if err != nil { + return errors.Wrap(err, "failed to generate access token") + } + + cookieExp := time.Now().Add(CookieExpDuration) + setTokenCookie(c, AccessTokenCookieName, accessToken, cookieExp) + + // We generate here a new refresh token and saving it to the cookie. + refreshToken, err := GenerateRefreshToken(user.Username, user.ID, secret) + if err != nil { + return errors.Wrap(err, "failed to generate refresh token") + } + setTokenCookie(c, RefreshTokenCookieName, refreshToken, cookieExp) + + return nil +} + +// RemoveTokensAndCookies removes the jwt token and refresh token from the cookies. +func RemoveTokensAndCookies(c echo.Context) { + // We set the expiration time to the past, so that the cookie will be removed. + cookieExp := time.Now().Add(-1 * time.Hour) + setTokenCookie(c, AccessTokenCookieName, "", cookieExp) + setTokenCookie(c, RefreshTokenCookieName, "", cookieExp) +} + +// setTokenCookie sets the token to the cookie. +func setTokenCookie(c echo.Context, name, token string, expiration time.Time) { + cookie := new(http.Cookie) + cookie.Name = name + cookie.Value = token + cookie.Expires = expiration + cookie.Path = "/" + // Http-only helps mitigate the risk of client side script accessing the protected cookie. + cookie.HttpOnly = true + cookie.SameSite = http.SameSiteStrictMode + c.SetCookie(cookie) +} + +// generateToken generates a jwt token. +func generateToken(username string, userID int, aud string, expirationTime time.Time, secret []byte) (string, error) { + // Create the JWT claims, which includes the username and expiry time. + claims := &claimsMessage{ + Name: username, + RegisteredClaims: jwt.RegisteredClaims{ + Audience: jwt.ClaimStrings{aud}, + // In JWT, the expiry time is expressed as unix milliseconds. + ExpiresAt: jwt.NewNumericDate(expirationTime), + IssuedAt: jwt.NewNumericDate(time.Now()), + Issuer: issuer, + Subject: strconv.Itoa(userID), + }, + } + + // Declare the token with the HS256 algorithm used for signing, and the claims. + token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) + token.Header["kid"] = keyID + + // Create the JWT string. + tokenString, err := token.SignedString(secret) + if err != nil { + return "", err + } + + return tokenString, nil +} diff --git a/server/common.go b/server/common.go index daf8c62..cc41560 100644 --- a/server/common.go +++ b/server/common.go @@ -3,7 +3,6 @@ package server import ( "strings" - "github.com/boojack/shortify/api" "github.com/labstack/echo/v4" ) @@ -29,34 +28,10 @@ func hasPrefixes(src string, prefixes ...string) bool { func defaultAPIRequestSkipper(c echo.Context) bool { path := c.Path() - return hasPrefixes(path, "/api", "/o") + return hasPrefixes(path, "/api") } -func (server *Server) defaultAuthSkipper(c echo.Context) bool { - ctx := c.Request().Context() +func (*Server) defaultAuthSkipper(c echo.Context) bool { path := c.Path() - - // Skip auth. - if hasPrefixes(path, "/api/auth") { - return true - } - - // If there is openId in query string and related user is found, then skip auth. - openID := c.QueryParam("openId") - if openID != "" { - userFind := &api.UserFind{ - OpenID: &openID, - } - user, err := server.Store.FindUser(ctx, userFind) - if err != nil { - return false - } - if user != nil { - // Stores userID into context. - c.Set(getUserIDContextKey(), user.ID) - return true - } - } - - return false + return hasPrefixes(path, "/api/v1/auth") } diff --git a/server/jwt.go b/server/jwt.go new file mode 100644 index 0000000..ac816a3 --- /dev/null +++ b/server/jwt.go @@ -0,0 +1,201 @@ +package server + +import ( + "fmt" + "net/http" + "strconv" + "strings" + "time" + + "github.com/boojack/shortify/server/auth" + "github.com/boojack/shortify/store" + "github.com/golang-jwt/jwt/v4" + "github.com/labstack/echo/v4" + "github.com/pkg/errors" +) + +const ( + // Context section + // The key name used to store user id in the context + // user id is extracted from the jwt token subject field. + userIDContextKey = "user-id" +) + +func getUserIDContextKey() string { + return userIDContextKey +} + +// Claims creates a struct that will be encoded to a JWT. +// We add jwt.RegisteredClaims as an embedded type, to provide fields such as name. +type Claims struct { + Name string `json:"name"` + jwt.RegisteredClaims +} + +func extractTokenFromHeader(c echo.Context) (string, error) { + authHeader := c.Request().Header.Get("Authorization") + if authHeader == "" { + return "", nil + } + + authHeaderParts := strings.Fields(authHeader) + if len(authHeaderParts) != 2 || strings.ToLower(authHeaderParts[0]) != "bearer" { + return "", errors.New("Authorization header format must be Bearer {token}") + } + + return authHeaderParts[1], nil +} + +func findAccessToken(c echo.Context) string { + accessToken := "" + cookie, _ := c.Cookie(auth.AccessTokenCookieName) + if cookie != nil { + accessToken = cookie.Value + } + if accessToken == "" { + accessToken, _ = extractTokenFromHeader(c) + } + + return accessToken +} + +func audienceContains(audience jwt.ClaimStrings, token string) bool { + for _, v := range audience { + if v == token { + return true + } + } + return false +} + +// JWTMiddleware validates the access token. +// If the access token is about to expire or has expired and the request has a valid refresh token, it +// will try to generate new access token and refresh token. +func JWTMiddleware(server *Server, next echo.HandlerFunc, secret string) echo.HandlerFunc { + return func(c echo.Context) error { + path := c.Request().URL.Path + method := c.Request().Method + + if server.defaultAuthSkipper(c) { + return next(c) + } + + // Skip validation for server status endpoints. + if hasPrefixes(path, "/api/ping", "/api/v1/idp", "/api/user/:id") && method == http.MethodGet { + return next(c) + } + + token := findAccessToken(c) + if token == "" { + // When the request is not authenticated, we allow the user to access the shortcut endpoints for those public shortcuts. + if hasPrefixes(path, "/api/status", "/api/shortcut") && method == http.MethodGet { + return next(c) + } + return echo.NewHTTPError(http.StatusUnauthorized, "Missing access token") + } + + claims := &Claims{} + accessToken, err := jwt.ParseWithClaims(token, claims, func(t *jwt.Token) (any, error) { + if t.Method.Alg() != jwt.SigningMethodHS256.Name { + return nil, errors.Errorf("unexpected access token signing method=%v, expect %v", t.Header["alg"], jwt.SigningMethodHS256) + } + if kid, ok := t.Header["kid"].(string); ok { + if kid == "v1" { + return []byte(secret), nil + } + } + return nil, errors.Errorf("unexpected access token kid=%v", t.Header["kid"]) + }) + if !audienceContains(claims.Audience, auth.AccessTokenAudienceName) { + return echo.NewHTTPError(http.StatusUnauthorized, fmt.Sprintf("Invalid access token, audience mismatch, got %q, expected %q.", claims.Audience, auth.AccessTokenAudienceName)) + } + generateToken := time.Until(claims.ExpiresAt.Time) < auth.RefreshThresholdDuration + if err != nil { + var ve *jwt.ValidationError + if errors.As(err, &ve) { + // If expiration error is the only error, we will clear the err + // and generate new access token and refresh token + if ve.Errors == jwt.ValidationErrorExpired { + generateToken = true + } + } else { + return echo.NewHTTPError(http.StatusUnauthorized, errors.Wrap(err, "Invalid or expired access token")) + } + } + + // We either have a valid access token or we will attempt to generate new access token and refresh token + ctx := c.Request().Context() + userID, err := strconv.Atoi(claims.Subject) + if err != nil { + return echo.NewHTTPError(http.StatusUnauthorized, "Malformed ID in the token.") + } + + // Even if there is no error, we still need to make sure the user still exists. + user, err := server.Store.GetUserV1(ctx, &store.FindUser{ + ID: &userID, + }) + if err != nil { + return echo.NewHTTPError(http.StatusInternalServerError, fmt.Sprintf("Server error to find user ID: %d", userID)).SetInternal(err) + } + if user == nil { + return echo.NewHTTPError(http.StatusUnauthorized, fmt.Sprintf("Failed to find user ID: %d", userID)) + } + + if generateToken { + generateTokenFunc := func() error { + rc, err := c.Cookie(auth.RefreshTokenCookieName) + if err != nil { + return echo.NewHTTPError(http.StatusUnauthorized, "Failed to generate access token. Missing refresh token.") + } + + // Parses token and checks if it's valid. + refreshTokenClaims := &Claims{} + refreshToken, err := jwt.ParseWithClaims(rc.Value, refreshTokenClaims, func(t *jwt.Token) (any, error) { + if t.Method.Alg() != jwt.SigningMethodHS256.Name { + return nil, errors.Errorf("unexpected refresh token signing method=%v, expected %v", t.Header["alg"], jwt.SigningMethodHS256) + } + + if kid, ok := t.Header["kid"].(string); ok { + if kid == "v1" { + return []byte(secret), nil + } + } + return nil, errors.Errorf("unexpected refresh token kid=%v", t.Header["kid"]) + }) + if err != nil { + if err == jwt.ErrSignatureInvalid { + return echo.NewHTTPError(http.StatusUnauthorized, "Failed to generate access token. Invalid refresh token signature.") + } + return echo.NewHTTPError(http.StatusInternalServerError, fmt.Sprintf("Server error to refresh expired token. User Id %d", userID)).SetInternal(err) + } + + if !audienceContains(refreshTokenClaims.Audience, auth.RefreshTokenAudienceName) { + return echo.NewHTTPError(http.StatusUnauthorized, + fmt.Sprintf("Invalid refresh token, audience mismatch, got %q, expected %q. you may send request to the wrong environment", + refreshTokenClaims.Audience, + auth.RefreshTokenAudienceName, + )) + } + + // If we have a valid refresh token, we will generate new access token and refresh token + if refreshToken != nil && refreshToken.Valid { + if err := auth.GenerateTokensAndSetCookies(c, user, secret); err != nil { + return echo.NewHTTPError(http.StatusInternalServerError, fmt.Sprintf("Server error to refresh expired token. User Id %d", userID)).SetInternal(err) + } + } + + return nil + } + + // It may happen that we still have a valid access token, but we encounter issue when trying to generate new token + // In such case, we won't return the error. + if err := generateTokenFunc(); err != nil && !accessToken.Valid { + return err + } + } + + // Stores userID into context. + c.Set(getUserIDContextKey(), userID) + return next(c) + } +} diff --git a/server/redirect.go b/server/redirect.go deleted file mode 100644 index dcc555c..0000000 --- a/server/redirect.go +++ /dev/null @@ -1,51 +0,0 @@ -package server - -import ( - "net/http" - - "github.com/boojack/shortify/api" - "github.com/labstack/echo/v4" -) - -func (s *Server) registerRedirectRoutes(g *echo.Group) { - g.GET("/:shortcutName", func(c echo.Context) error { - ctx := c.Request().Context() - userID, ok := c.Get(getUserIDContextKey()).(int) - if !ok { - return echo.NewHTTPError(http.StatusUnauthorized, "Missing user in session") - } - shortcutName := c.Param("shortcutName") - if shortcutName == "" { - return echo.NewHTTPError(http.StatusBadRequest, "Missing shortcut name") - } - - list := []*api.Shortcut{} - - shortcutFind := &api.ShortcutFind{ - Name: &shortcutName, - MemberID: &userID, - } - shortcutFind.VisibilityList = []api.Visibility{api.VisibilityWorkspace, api.VisibilityPublic} - visibleShortcutList, err := s.Store.FindShortcutList(ctx, shortcutFind) - if err != nil { - return echo.NewHTTPError(http.StatusInternalServerError, "Failed to fetch shortcut list").SetInternal(err) - } - list = append(list, visibleShortcutList...) - - shortcutFind.VisibilityList = []api.Visibility{api.VisibilityPrivite} - shortcutFind.CreatorID = &userID - privateShortcutList, err := s.Store.FindShortcutList(ctx, shortcutFind) - if err != nil { - return echo.NewHTTPError(http.StatusInternalServerError, "Failed to fetch private shortcut list").SetInternal(err) - } - list = append(list, privateShortcutList...) - - if len(list) == 0 { - return echo.NewHTTPError(http.StatusNotFound, "Not found shortcut").SetInternal(err) - } - - // TODO(steven): improve the matched result later - matchedShortcut := list[0] - return c.Redirect(http.StatusPermanentRedirect, matchedShortcut.Link) - }) -} diff --git a/server/server.go b/server/server.go index af204b7..4acac71 100644 --- a/server/server.go +++ b/server/server.go @@ -54,24 +54,17 @@ func NewServer(profile *profile.Profile, store *store.Store) (*Server, error) { embedFrontend(e) // In dev mode, set the const secret key to make signin session persistence. - secret := []byte("iamshortify") + secret := "iamshortify" if profile.Mode == "prod" { - secret = securecookie.GenerateRandomKey(16) + secret = string(securecookie.GenerateRandomKey(16)) } - e.Use(session.Middleware(sessions.NewCookieStore(secret))) - - redirectGroup := e.Group("/o") - redirectGroup.Use(func(next echo.HandlerFunc) echo.HandlerFunc { - return aclMiddleware(s, next) - }) - s.registerRedirectRoutes(redirectGroup) + e.Use(session.Middleware(sessions.NewCookieStore([]byte(secret)))) apiGroup := e.Group("/api") apiGroup.Use(func(next echo.HandlerFunc) echo.HandlerFunc { - return aclMiddleware(s, next) + return JWTMiddleware(s, next, string(secret)) }) s.registerSystemRoutes(apiGroup) - s.registerAuthRoutes(apiGroup) s.registerUserRoutes(apiGroup) s.registerWorkspaceRoutes(apiGroup) s.registerWorkspaceUserRoutes(apiGroup) @@ -80,7 +73,10 @@ func NewServer(profile *profile.Profile, store *store.Store) (*Server, error) { // Register API v1 routes. apiV1Service := apiv1.NewAPIV1Service(profile, store) apiV1Group := apiGroup.Group("/api/v1") - apiV1Service.RegisterUserRoutes(apiV1Group) + apiV1Group.Use(func(next echo.HandlerFunc) echo.HandlerFunc { + return JWTMiddleware(s, next, string(secret)) + }) + apiV1Service.Start(apiV1Group, secret) return s, nil } diff --git a/server/user.go b/server/user.go index e513cd4..3494c36 100644 --- a/server/user.go +++ b/server/user.go @@ -8,6 +8,7 @@ import ( "strconv" "github.com/boojack/shortify/api" + "github.com/google/uuid" "github.com/labstack/echo/v4" "golang.org/x/crypto/bcrypt" @@ -155,3 +156,7 @@ func validateEmail(email string) bool { } return true } + +func genUUID() string { + return uuid.New().String() +}