From dfe47b9b7e902144db82e9e0b3d6bbd91c8c8f46 Mon Sep 17 00:00:00 2001 From: Steven Date: Wed, 2 Aug 2023 07:29:58 +0800 Subject: [PATCH] feat: update jwt auth --- api/auth/auth.go | 27 ++++++++ api/v1/auth.go | 8 +-- api/v1/auth/auth.go | 128 ---------------------------------- api/v1/jwt.go | 162 ++++++++++++++++++++----------------------- api/v1/redirector.go | 3 +- api/v1/shortcut.go | 12 ++-- api/v1/user.go | 10 +-- api/v1/v1.go | 1 - api/v1/workspace.go | 5 +- 9 files changed, 121 insertions(+), 235 deletions(-) create mode 100644 api/auth/auth.go delete mode 100644 api/v1/auth/auth.go diff --git a/api/auth/auth.go b/api/auth/auth.go new file mode 100644 index 0000000..3330481 --- /dev/null +++ b/api/auth/auth.go @@ -0,0 +1,27 @@ +package auth + +import ( + "time" +) + +const ( + // The key name used to store user id in the context + // user id is extracted from the jwt token subject field. + UserIDContextKey = "user-id" + // issuer is the issuer of the jwt token. + Issuer = "slash" + // Signing key section. For now, this is only used for signing, not for verifying since we only + // have 1 version. But it will be used to maintain backward compatibility if we change the signing mechanism. + KeyID = "v1" + // AccessTokenAudienceName is the audience name of the access token. + AccessTokenAudienceName = "user.access-token" + AccessTokenDuration = 7 * 24 * time.Hour + + // CookieExpDuration expires slightly earlier than the jwt expiration. Client would be logged out if the user + // cookie expires, thus the client would always logout first before attempting to make a request with the expired jwt. + // Suppose we have a valid refresh token, we will refresh the token in cases: + // 1. The access token has already expired, we refresh the token so that the ongoing request can pass through. + CookieExpDuration = AccessTokenDuration - 1*time.Minute + // AccessTokenCookieName is the cookie name of access token. + AccessTokenCookieName = "slash.access-token" +) diff --git a/api/v1/auth.go b/api/v1/auth.go index 9fc249f..078428f 100644 --- a/api/v1/auth.go +++ b/api/v1/auth.go @@ -5,9 +5,7 @@ import ( "fmt" "net/http" - "github.com/boojack/slash/api/v1/auth" "github.com/boojack/slash/store" - "github.com/labstack/echo/v4" "golang.org/x/crypto/bcrypt" ) @@ -48,7 +46,7 @@ func (s *APIV1Service) registerAuthRoutes(g *echo.Group, secret string) { return echo.NewHTTPError(http.StatusUnauthorized, "unmatched email and password") } - if err := auth.GenerateTokensAndSetCookies(c, user, secret); err != nil { + if err := GenerateTokensAndSetCookies(c, user, secret); err != nil { return echo.NewHTTPError(http.StatusInternalServerError, fmt.Sprintf("failed to generate tokens, err: %s", err)).SetInternal(err) } return c.JSON(http.StatusOK, convertUserFromStore(user)) @@ -97,7 +95,7 @@ func (s *APIV1Service) registerAuthRoutes(g *echo.Group, secret string) { return echo.NewHTTPError(http.StatusInternalServerError, fmt.Sprintf("failed to create user, err: %s", err)).SetInternal(err) } - if err := auth.GenerateTokensAndSetCookies(c, user, secret); err != nil { + if err := GenerateTokensAndSetCookies(c, user, secret); err != nil { return echo.NewHTTPError(http.StatusInternalServerError, fmt.Sprintf("failed to generate tokens, err: %s", err)).SetInternal(err) } @@ -105,7 +103,7 @@ func (s *APIV1Service) registerAuthRoutes(g *echo.Group, secret string) { }) g.POST("/auth/logout", func(c echo.Context) error { - auth.RemoveTokensAndCookies(c) + RemoveTokensAndCookies(c) c.Response().WriteHeader(http.StatusOK) return nil }) diff --git a/api/v1/auth/auth.go b/api/v1/auth/auth.go deleted file mode 100644 index e0d45e2..0000000 --- a/api/v1/auth/auth.go +++ /dev/null @@ -1,128 +0,0 @@ -package auth - -import ( - "net/http" - "strconv" - "time" - - "github.com/boojack/slash/store" - "github.com/golang-jwt/jwt/v4" - "github.com/labstack/echo/v4" - "github.com/pkg/errors" -) - -const ( - issuer = "slash" - // Signing key section. For now, this is only used for signing, not for verifying since we only - // have 1 version. But it will be used to maintain backward compatibility if we change the signing mechanism. - keyID = "v1" - // AccessTokenAudienceName is the audience name of the access token. - AccessTokenAudienceName = "user.access-token" - // RefreshTokenAudienceName is the audience name of the refresh token. - RefreshTokenAudienceName = "user.refresh-token" - apiTokenDuration = 2 * time.Hour - accessTokenDuration = 24 * time.Hour - refreshTokenDuration = 7 * 24 * time.Hour - - // CookieExpDuration expires slightly earlier than the jwt expiration. Client would be logged out if the user - // cookie expires, thus the client would always logout first before attempting to make a request with the expired jwt. - // Suppose we have a valid refresh token, we will refresh the token in the following cases: - // 1. The access token has already expired, we refresh the token so that the ongoing request can pass through. - CookieExpDuration = refreshTokenDuration - 1*time.Minute - // AccessTokenCookieName is the cookie name of access token. - AccessTokenCookieName = "slash.access-token" - // RefreshTokenCookieName is the cookie name of refresh token. - RefreshTokenCookieName = "slash.refresh-token" -) - -type claimsMessage struct { - Name string `json:"name"` - jwt.RegisteredClaims -} - -// GenerateAPIToken generates an API token. -func GenerateAPIToken(username string, userID int, secret string) (string, error) { - expirationTime := time.Now().Add(apiTokenDuration) - return generateToken(username, userID, AccessTokenAudienceName, expirationTime, []byte(secret)) -} - -// GenerateAccessToken generates an access token for web. -func GenerateAccessToken(username string, userID int, secret string) (string, error) { - expirationTime := time.Now().Add(accessTokenDuration) - return generateToken(username, userID, AccessTokenAudienceName, expirationTime, []byte(secret)) -} - -// GenerateRefreshToken generates a refresh token for web. -func GenerateRefreshToken(username string, userID int, secret string) (string, error) { - expirationTime := time.Now().Add(refreshTokenDuration) - return generateToken(username, userID, RefreshTokenAudienceName, expirationTime, []byte(secret)) -} - -// GenerateTokensAndSetCookies generates jwt token and saves it to the http-only cookie. -func GenerateTokensAndSetCookies(c echo.Context, user *store.User, secret string) error { - accessToken, err := GenerateAccessToken(user.Email, user.ID, secret) - if err != nil { - return errors.Wrap(err, "failed to generate access token") - } - - cookieExp := time.Now().Add(CookieExpDuration) - setTokenCookie(c, AccessTokenCookieName, accessToken, cookieExp) - - // We generate here a new refresh token and saving it to the cookie. - refreshToken, err := GenerateRefreshToken(user.Email, user.ID, secret) - if err != nil { - return errors.Wrap(err, "failed to generate refresh token") - } - setTokenCookie(c, RefreshTokenCookieName, refreshToken, cookieExp) - - return nil -} - -// RemoveTokensAndCookies removes the jwt token and refresh token from the cookies. -func RemoveTokensAndCookies(c echo.Context) { - // We set the expiration time to the past, so that the cookie will be removed. - cookieExp := time.Now().Add(-1 * time.Hour) - setTokenCookie(c, AccessTokenCookieName, "", cookieExp) - setTokenCookie(c, RefreshTokenCookieName, "", cookieExp) -} - -// setTokenCookie sets the token to the cookie. -func setTokenCookie(c echo.Context, name, token string, expiration time.Time) { - cookie := new(http.Cookie) - cookie.Name = name - cookie.Value = token - cookie.Expires = expiration - cookie.Path = "/" - // Http-only helps mitigate the risk of client side script accessing the protected cookie. - cookie.HttpOnly = true - cookie.SameSite = http.SameSiteStrictMode - c.SetCookie(cookie) -} - -// generateToken generates a jwt token. -func generateToken(username string, userID int, aud string, expirationTime time.Time, secret []byte) (string, error) { - // Create the JWT claims, which includes the username and expiry time. - claims := &claimsMessage{ - Name: username, - RegisteredClaims: jwt.RegisteredClaims{ - Audience: jwt.ClaimStrings{aud}, - // In JWT, the expiry time is expressed as unix milliseconds. - ExpiresAt: jwt.NewNumericDate(expirationTime), - IssuedAt: jwt.NewNumericDate(time.Now()), - Issuer: issuer, - Subject: strconv.Itoa(userID), - }, - } - - // Declare the token with the HS256 algorithm used for signing, and the claims. - token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) - token.Header["kid"] = keyID - - // Create the JWT string. - tokenString, err := token.SignedString(secret) - if err != nil { - return "", err - } - - return tokenString, nil -} diff --git a/api/v1/jwt.go b/api/v1/jwt.go index 7ab11d5..c17ddc4 100644 --- a/api/v1/jwt.go +++ b/api/v1/jwt.go @@ -5,8 +5,9 @@ import ( "net/http" "strconv" "strings" + "time" - "github.com/boojack/slash/api/v1/auth" + "github.com/boojack/slash/api/auth" "github.com/boojack/slash/internal/util" "github.com/boojack/slash/store" "github.com/golang-jwt/jwt/v4" @@ -14,24 +15,76 @@ import ( "github.com/pkg/errors" ) -const ( - // Context section - // The key name used to store user id in the context - // user id is extracted from the jwt token subject field. - userIDContextKey = "user-id" -) - -func getUserIDContextKey() string { - return userIDContextKey -} - -// Claims creates a struct that will be encoded to a JWT. -// We add jwt.RegisteredClaims as an embedded type, to provide fields such as name. -type Claims struct { +type claimsMessage struct { Name string `json:"name"` jwt.RegisteredClaims } +// GenerateAccessToken generates an access token for web. +func GenerateAccessToken(username string, userID int, secret string) (string, error) { + expirationTime := time.Now().Add(auth.AccessTokenDuration) + return generateToken(username, userID, auth.AccessTokenAudienceName, expirationTime, []byte(secret)) +} + +// GenerateTokensAndSetCookies generates jwt token and saves it to the http-only cookie. +func GenerateTokensAndSetCookies(c echo.Context, user *store.User, secret string) error { + accessToken, err := GenerateAccessToken(user.Email, user.ID, secret) + if err != nil { + return errors.Wrap(err, "failed to generate access token") + } + + cookieExp := time.Now().Add(auth.CookieExpDuration) + setTokenCookie(c, auth.AccessTokenCookieName, accessToken, cookieExp) + return nil +} + +// RemoveTokensAndCookies removes the jwt token and refresh token from the cookies. +func RemoveTokensAndCookies(c echo.Context) { + cookieExp := time.Now().Add(-1 * time.Hour) + setTokenCookie(c, auth.AccessTokenCookieName, "", cookieExp) +} + +// setTokenCookie sets the token to the cookie. +func setTokenCookie(c echo.Context, name, token string, expiration time.Time) { + cookie := new(http.Cookie) + cookie.Name = name + cookie.Value = token + cookie.Expires = expiration + cookie.Path = "/" + // Http-only helps mitigate the risk of client side script accessing the protected cookie. + cookie.HttpOnly = true + cookie.SameSite = http.SameSiteStrictMode + c.SetCookie(cookie) +} + +// generateToken generates a jwt token. +func generateToken(username string, userID int, aud string, expirationTime time.Time, secret []byte) (string, error) { + // Create the JWT claims, which includes the username and expiry time. + claims := &claimsMessage{ + Name: username, + RegisteredClaims: jwt.RegisteredClaims{ + Audience: jwt.ClaimStrings{aud}, + // In JWT, the expiry time is expressed as unix milliseconds. + ExpiresAt: jwt.NewNumericDate(expirationTime), + IssuedAt: jwt.NewNumericDate(time.Now()), + Issuer: auth.Issuer, + Subject: strconv.Itoa(userID), + }, + } + + // Declare the token with the HS256 algorithm used for signing, and the claims. + token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) + token.Header["kid"] = auth.KeyID + + // Create the JWT string. + tokenString, err := token.SignedString(secret) + if err != nil { + return "", err + } + + return tokenString, nil +} + func extractTokenFromHeader(c echo.Context) (string, error) { authHeader := c.Request().Header.Get("Authorization") if authHeader == "" { @@ -73,7 +126,8 @@ func audienceContains(audience jwt.ClaimStrings, token string) bool { // will try to generate new access token and refresh token. func JWTMiddleware(server *APIV1Service, next echo.HandlerFunc, secret string) echo.HandlerFunc { return func(c echo.Context) error { - path := c.Path() + ctx := c.Request().Context() + path := c.Request().URL.Path method := c.Request().Method // Pass auth and profile endpoints. @@ -87,12 +141,11 @@ func JWTMiddleware(server *APIV1Service, next echo.HandlerFunc, secret string) e if util.HasPrefixes(path, "/s/*") && method == http.MethodGet { return next(c) } - auth.RemoveTokensAndCookies(c) return echo.NewHTTPError(http.StatusUnauthorized, "Missing access token") } - claims := &Claims{} - accessToken, err := jwt.ParseWithClaims(token, claims, func(t *jwt.Token) (any, error) { + claims := &claimsMessage{} + _, err := jwt.ParseWithClaims(token, claims, func(t *jwt.Token) (any, error) { if t.Method.Alg() != jwt.SigningMethodHS256.Name { return nil, errors.Errorf("unexpected access token signing method=%v, expect %v", t.Header["alg"], jwt.SigningMethodHS256) } @@ -104,27 +157,15 @@ func JWTMiddleware(server *APIV1Service, next echo.HandlerFunc, secret string) e return nil, errors.Errorf("unexpected access token kid=%v", t.Header["kid"]) }) - generateToken := false if err != nil { - var ve *jwt.ValidationError - if errors.As(err, &ve) { - // If expiration error is the only error, we will ignore the err - // and generate new access token and refresh token. - if ve.Errors == jwt.ValidationErrorExpired { - generateToken = true - } - } else { - auth.RemoveTokensAndCookies(c) - return echo.NewHTTPError(http.StatusUnauthorized, errors.Wrap(err, "Invalid or expired access token")) - } + RemoveTokensAndCookies(c) + return echo.NewHTTPError(http.StatusUnauthorized, errors.Wrap(err, "Invalid or expired access token")) } - if !audienceContains(claims.Audience, auth.AccessTokenAudienceName) { return echo.NewHTTPError(http.StatusUnauthorized, fmt.Sprintf("Invalid access token, audience mismatch, got %q, expected %q.", claims.Audience, auth.AccessTokenAudienceName)) } // We either have a valid access token or we will attempt to generate new access token and refresh token - ctx := c.Request().Context() userID, err := strconv.Atoi(claims.Subject) if err != nil { return echo.NewHTTPError(http.StatusUnauthorized, "Malformed ID in the token.") @@ -141,61 +182,8 @@ func JWTMiddleware(server *APIV1Service, next echo.HandlerFunc, secret string) e return echo.NewHTTPError(http.StatusUnauthorized, fmt.Sprintf("Failed to find user ID: %d", userID)) } - if generateToken { - generateTokenFunc := func() error { - rc, err := c.Cookie(auth.RefreshTokenCookieName) - if err != nil { - return echo.NewHTTPError(http.StatusUnauthorized, "Failed to generate access token. Missing refresh token.") - } - - // Parses token and checks if it's valid. - refreshTokenClaims := &Claims{} - refreshToken, err := jwt.ParseWithClaims(rc.Value, refreshTokenClaims, func(t *jwt.Token) (any, error) { - if t.Method.Alg() != jwt.SigningMethodHS256.Name { - return nil, errors.Errorf("unexpected refresh token signing method=%v, expected %v", t.Header["alg"], jwt.SigningMethodHS256) - } - - if kid, ok := t.Header["kid"].(string); ok { - if kid == "v1" { - return []byte(secret), nil - } - } - return nil, errors.Errorf("unexpected refresh token kid=%v", t.Header["kid"]) - }) - if err != nil { - if err == jwt.ErrSignatureInvalid { - return echo.NewHTTPError(http.StatusUnauthorized, "Failed to generate access token. Invalid refresh token signature.") - } - return echo.NewHTTPError(http.StatusInternalServerError, fmt.Sprintf("Server error to refresh expired token. User Id %d", userID)).SetInternal(err) - } - - if !audienceContains(refreshTokenClaims.Audience, auth.RefreshTokenAudienceName) { - return echo.NewHTTPError(http.StatusUnauthorized, - fmt.Sprintf("Invalid refresh token, audience mismatch, got %q, expected %q. you may send request to the wrong environment", - refreshTokenClaims.Audience, - auth.RefreshTokenAudienceName, - )) - } - - // If we have a valid refresh token, we will generate new access token and refresh token - if refreshToken != nil && refreshToken.Valid { - if err := auth.GenerateTokensAndSetCookies(c, user, secret); err != nil { - return echo.NewHTTPError(http.StatusInternalServerError, fmt.Sprintf("Server error to refresh expired token. User Id %d", userID)).SetInternal(err) - } - } - - return nil - } - - // It may happen that we still have a valid access token, but we encounter issue when trying to generate new token - // In such case, we won't return the error. - if err := generateTokenFunc(); err != nil && !accessToken.Valid { - return err - } - } - // Stores userID into context. - c.Set(getUserIDContextKey(), userID) + c.Set(auth.UserIDContextKey, userID) return next(c) } } diff --git a/api/v1/redirector.go b/api/v1/redirector.go index fda08e4..6ea4c00 100644 --- a/api/v1/redirector.go +++ b/api/v1/redirector.go @@ -8,6 +8,7 @@ import ( "net/url" "strings" + "github.com/boojack/slash/api/auth" "github.com/boojack/slash/store" "github.com/labstack/echo/v4" "github.com/pkg/errors" @@ -31,7 +32,7 @@ func (s *APIV1Service) registerRedirectorRoutes(g *echo.Group) { return echo.NewHTTPError(http.StatusNotFound, fmt.Sprintf("not found shortcut with name: %s", shortcutName)) } if shortcut.Visibility != store.VisibilityPublic { - userID, ok := c.Get(getUserIDContextKey()).(int) + userID, ok := c.Get(auth.UserIDContextKey).(int) if !ok { return echo.NewHTTPError(http.StatusUnauthorized, "Unauthorized") } diff --git a/api/v1/shortcut.go b/api/v1/shortcut.go index 8791681..793ba24 100644 --- a/api/v1/shortcut.go +++ b/api/v1/shortcut.go @@ -8,10 +8,10 @@ import ( "strconv" "strings" + "github.com/boojack/slash/api/auth" "github.com/boojack/slash/store" - "github.com/pkg/errors" - "github.com/labstack/echo/v4" + "github.com/pkg/errors" ) // Visibility is the type of a shortcut visibility. @@ -81,7 +81,7 @@ type PatchShortcutRequest struct { func (s *APIV1Service) registerShortcutRoutes(g *echo.Group) { g.POST("/shortcut", func(c echo.Context) error { ctx := c.Request().Context() - userID, ok := c.Get(getUserIDContextKey()).(int) + userID, ok := c.Get(auth.UserIDContextKey).(int) if !ok { return echo.NewHTTPError(http.StatusUnauthorized, "missing user in session") } @@ -125,7 +125,7 @@ func (s *APIV1Service) registerShortcutRoutes(g *echo.Group) { if err != nil { return echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("shortcut ID is not a number: %s", c.Param("shortcutId"))).SetInternal(err) } - userID, ok := c.Get(getUserIDContextKey()).(int) + userID, ok := c.Get(auth.UserIDContextKey).(int) if !ok { return echo.NewHTTPError(http.StatusUnauthorized, "missing user in session") } @@ -196,7 +196,7 @@ func (s *APIV1Service) registerShortcutRoutes(g *echo.Group) { g.GET("/shortcut", func(c echo.Context) error { ctx := c.Request().Context() - userID, ok := c.Get(getUserIDContextKey()).(int) + userID, ok := c.Get(auth.UserIDContextKey).(int) if !ok { return echo.NewHTTPError(http.StatusUnauthorized, "missing user in session") } @@ -263,7 +263,7 @@ func (s *APIV1Service) registerShortcutRoutes(g *echo.Group) { if err != nil { return echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("shortcut id is not a number: %s", c.Param("id"))).SetInternal(err) } - userID, ok := c.Get(getUserIDContextKey()).(int) + userID, ok := c.Get(auth.UserIDContextKey).(int) if !ok { return echo.NewHTTPError(http.StatusUnauthorized, "missing user in session") } diff --git a/api/v1/user.go b/api/v1/user.go index 1ceb682..16b98be 100644 --- a/api/v1/user.go +++ b/api/v1/user.go @@ -7,8 +7,8 @@ import ( "net/mail" "strconv" + "github.com/boojack/slash/api/auth" "github.com/boojack/slash/store" - "github.com/labstack/echo/v4" "golang.org/x/crypto/bcrypt" ) @@ -84,7 +84,7 @@ type PatchUserRequest struct { func (s *APIV1Service) registerUserRoutes(g *echo.Group) { g.POST("/user", func(c echo.Context) error { ctx := c.Request().Context() - userID, ok := c.Get(getUserIDContextKey()).(int) + userID, ok := c.Get(auth.UserIDContextKey).(int) if !ok { return echo.NewHTTPError(http.StatusUnauthorized, "Missing auth session") } @@ -145,7 +145,7 @@ func (s *APIV1Service) registerUserRoutes(g *echo.Group) { // GET /api/user/me is used to check if the user is logged in. g.GET("/user/me", func(c echo.Context) error { ctx := c.Request().Context() - userID, ok := c.Get(getUserIDContextKey()).(int) + userID, ok := c.Get(auth.UserIDContextKey).(int) if !ok { return echo.NewHTTPError(http.StatusUnauthorized, "missing auth session") } @@ -183,7 +183,7 @@ func (s *APIV1Service) registerUserRoutes(g *echo.Group) { if err != nil { return echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("user id is not a number: %s", c.Param("id"))).SetInternal(err) } - currentUserID, ok := c.Get(getUserIDContextKey()).(int) + currentUserID, ok := c.Get(auth.UserIDContextKey).(int) if !ok { return echo.NewHTTPError(http.StatusUnauthorized, "missing user in session") } @@ -255,7 +255,7 @@ func (s *APIV1Service) registerUserRoutes(g *echo.Group) { g.DELETE("/user/:id", func(c echo.Context) error { ctx := c.Request().Context() - currentUserID, ok := c.Get(getUserIDContextKey()).(int) + currentUserID, ok := c.Get(auth.UserIDContextKey).(int) if !ok { return echo.NewHTTPError(http.StatusUnauthorized, "missing user in session") } diff --git a/api/v1/v1.go b/api/v1/v1.go index ecc6080..d13aa6d 100644 --- a/api/v1/v1.go +++ b/api/v1/v1.go @@ -3,7 +3,6 @@ package v1 import ( "github.com/boojack/slash/server/profile" "github.com/boojack/slash/store" - "github.com/labstack/echo/v4" ) diff --git a/api/v1/workspace.go b/api/v1/workspace.go index 3635311..5c06df3 100644 --- a/api/v1/workspace.go +++ b/api/v1/workspace.go @@ -5,6 +5,7 @@ import ( "fmt" "net/http" + "github.com/boojack/slash/api/auth" "github.com/boojack/slash/server/profile" "github.com/boojack/slash/store" "github.com/labstack/echo/v4" @@ -62,7 +63,7 @@ func (s *APIV1Service) registerWorkspaceRoutes(g *echo.Group) { g.POST("/workspace/setting", func(c echo.Context) error { ctx := c.Request().Context() - userID, ok := c.Get(getUserIDContextKey()).(int) + userID, ok := c.Get(auth.UserIDContextKey).(int) if !ok { return echo.NewHTTPError(http.StatusUnauthorized, "missing user in session") } @@ -97,7 +98,7 @@ func (s *APIV1Service) registerWorkspaceRoutes(g *echo.Group) { g.GET("/workspace/setting", func(c echo.Context) error { ctx := c.Request().Context() - userID, ok := c.Get(getUserIDContextKey()).(int) + userID, ok := c.Get(auth.UserIDContextKey).(int) if !ok { return echo.NewHTTPError(http.StatusUnauthorized, "missing user in session") }