diff --git a/api/auth/auth.go b/api/auth/auth.go index 7517ad4..708582d 100644 --- a/api/auth/auth.go +++ b/api/auth/auth.go @@ -16,8 +16,6 @@ const ( // CookieExpDuration expires slightly earlier than the jwt expiration. Client would be logged out if the user // cookie expires, thus the client would always logout first before attempting to make a request with the expired jwt. - // Suppose we have a valid refresh token, we will refresh the token in cases: - // 1. The access token has already expired, we refresh the token so that the ongoing request can pass through. CookieExpDuration = AccessTokenDuration - 1*time.Minute // AccessTokenCookieName is the cookie name of access token. AccessTokenCookieName = "slash.access-token" diff --git a/api/v1/auth.go b/api/v1/auth.go index 078428f..a53379e 100644 --- a/api/v1/auth.go +++ b/api/v1/auth.go @@ -1,13 +1,19 @@ package v1 import ( + "context" "encoding/json" "fmt" "net/http" + "time" + "github.com/boojack/slash/api/auth" + storepb "github.com/boojack/slash/proto/gen/store" "github.com/boojack/slash/store" "github.com/labstack/echo/v4" + "github.com/pkg/errors" "golang.org/x/crypto/bcrypt" + "google.golang.org/protobuf/types/known/timestamppb" ) type SignInRequest struct { @@ -46,9 +52,16 @@ func (s *APIV1Service) registerAuthRoutes(g *echo.Group, secret string) { return echo.NewHTTPError(http.StatusUnauthorized, "unmatched email and password") } - if err := GenerateTokensAndSetCookies(c, user, secret); err != nil { + accessToken, err := GenerateAccessToken(user.Email, user.ID, secret) + if err != nil { return echo.NewHTTPError(http.StatusInternalServerError, fmt.Sprintf("failed to generate tokens, err: %s", err)).SetInternal(err) } + if err := s.UpsertAccessTokenToStore(ctx, user, accessToken); err != nil { + return echo.NewHTTPError(http.StatusInternalServerError, fmt.Sprintf("failed to upsert access token, err: %s", err)).SetInternal(err) + } + + cookieExp := time.Now().Add(auth.CookieExpDuration) + setTokenCookie(c, auth.AccessTokenCookieName, accessToken, cookieExp) return c.JSON(http.StatusOK, convertUserFromStore(user)) }) @@ -95,10 +108,16 @@ func (s *APIV1Service) registerAuthRoutes(g *echo.Group, secret string) { return echo.NewHTTPError(http.StatusInternalServerError, fmt.Sprintf("failed to create user, err: %s", err)).SetInternal(err) } - if err := GenerateTokensAndSetCookies(c, user, secret); err != nil { + accessToken, err := GenerateAccessToken(user.Email, user.ID, secret) + if err != nil { return echo.NewHTTPError(http.StatusInternalServerError, fmt.Sprintf("failed to generate tokens, err: %s", err)).SetInternal(err) } + if err := s.UpsertAccessTokenToStore(ctx, user, accessToken); err != nil { + return echo.NewHTTPError(http.StatusInternalServerError, fmt.Sprintf("failed to upsert access token, err: %s", err)).SetInternal(err) + } + cookieExp := time.Now().Add(auth.CookieExpDuration) + setTokenCookie(c, auth.AccessTokenCookieName, accessToken, cookieExp) return c.JSON(http.StatusOK, convertUserFromStore(user)) }) @@ -108,3 +127,54 @@ func (s *APIV1Service) registerAuthRoutes(g *echo.Group, secret string) { return nil }) } + +func (s *APIV1Service) UpsertAccessTokenToStore(ctx context.Context, user *store.User, accessToken string) error { + userAccessTokens, err := s.Store.GetUserAccessTokens(ctx, user.ID) + if err != nil { + return errors.Wrap(err, "failed to get user access tokens") + } + userAccessToken := storepb.AccessTokensUserSetting_AccessToken{ + AccessToken: accessToken, + Description: "user sign in", + CreatedTime: timestamppb.Now(), + ExpiresTime: timestamppb.New(time.Now().Add(auth.AccessTokenDuration)), + } + userAccessTokens = append(userAccessTokens, &userAccessToken) + if _, err := s.Store.UpsertUserSetting(ctx, &storepb.UserSetting{ + UserId: user.ID, + Key: storepb.UserSettingKey_USER_SETTING_ACCESS_TOKENS, + Value: &storepb.UserSetting_AccessTokensUserSetting{ + AccessTokensUserSetting: &storepb.AccessTokensUserSetting{ + AccessTokens: userAccessTokens, + }, + }, + }); err != nil { + return echo.NewHTTPError(http.StatusInternalServerError, fmt.Sprintf("failed to upsert user setting, err: %s", err)).SetInternal(err) + } + return nil +} + +// GenerateAccessToken generates an access token for web. +func GenerateAccessToken(username string, userID int32, secret string) (string, error) { + expirationTime := time.Now().Add(auth.AccessTokenDuration) + return generateToken(username, userID, auth.AccessTokenAudienceName, expirationTime, []byte(secret)) +} + +// RemoveTokensAndCookies removes the jwt token from the cookies. +func RemoveTokensAndCookies(c echo.Context) { + cookieExp := time.Now().Add(-1 * time.Hour) + setTokenCookie(c, auth.AccessTokenCookieName, "", cookieExp) +} + +// setTokenCookie sets the token to the cookie. +func setTokenCookie(c echo.Context, name, token string, expiration time.Time) { + cookie := new(http.Cookie) + cookie.Name = name + cookie.Value = token + cookie.Expires = expiration + cookie.Path = "/" + // Http-only helps mitigate the risk of client side script accessing the protected cookie. + cookie.HttpOnly = true + cookie.SameSite = http.SameSiteStrictMode + c.SetCookie(cookie) +} diff --git a/api/v1/jwt.go b/api/v1/jwt.go index 1e8a7ac..53012d3 100644 --- a/api/v1/jwt.go +++ b/api/v1/jwt.go @@ -8,6 +8,7 @@ import ( "github.com/boojack/slash/api/auth" "github.com/boojack/slash/internal/util" + storepb "github.com/boojack/slash/proto/gen/store" "github.com/boojack/slash/store" "github.com/golang-jwt/jwt/v4" "github.com/labstack/echo/v4" @@ -25,43 +26,6 @@ type claimsMessage struct { jwt.RegisteredClaims } -// GenerateAccessToken generates an access token for web. -func GenerateAccessToken(username string, userID int32, secret string) (string, error) { - expirationTime := time.Now().Add(auth.AccessTokenDuration) - return generateToken(username, userID, auth.AccessTokenAudienceName, expirationTime, []byte(secret)) -} - -// GenerateTokensAndSetCookies generates jwt token and saves it to the http-only cookie. -func GenerateTokensAndSetCookies(c echo.Context, user *store.User, secret string) error { - accessToken, err := GenerateAccessToken(user.Email, user.ID, secret) - if err != nil { - return errors.Wrap(err, "failed to generate access token") - } - - cookieExp := time.Now().Add(auth.CookieExpDuration) - setTokenCookie(c, auth.AccessTokenCookieName, accessToken, cookieExp) - return nil -} - -// RemoveTokensAndCookies removes the jwt token and refresh token from the cookies. -func RemoveTokensAndCookies(c echo.Context) { - cookieExp := time.Now().Add(-1 * time.Hour) - setTokenCookie(c, auth.AccessTokenCookieName, "", cookieExp) -} - -// setTokenCookie sets the token to the cookie. -func setTokenCookie(c echo.Context, name, token string, expiration time.Time) { - cookie := new(http.Cookie) - cookie.Name = name - cookie.Value = token - cookie.Expires = expiration - cookie.Path = "/" - // Http-only helps mitigate the risk of client side script accessing the protected cookie. - cookie.HttpOnly = true - cookie.SameSite = http.SameSiteStrictMode - c.SetCookie(cookie) -} - // generateToken generates a jwt token. func generateToken(username string, userID int32, aud string, expirationTime time.Time, secret []byte) (string, error) { // Create the JWT claims, which includes the username and expiry time. @@ -127,9 +91,7 @@ func audienceContains(audience jwt.ClaimStrings, token string) bool { } // JWTMiddleware validates the access token. -// If the access token is about to expire or has expired and the request has a valid refresh token, it -// will try to generate new access token and refresh token. -func JWTMiddleware(server *APIV1Service, next echo.HandlerFunc, secret string) echo.HandlerFunc { +func JWTMiddleware(s *APIV1Service, next echo.HandlerFunc, secret string) echo.HandlerFunc { return func(c echo.Context) error { ctx := c.Request().Context() path := c.Request().URL.Path @@ -163,21 +125,28 @@ func JWTMiddleware(server *APIV1Service, next echo.HandlerFunc, secret string) e }) if err != nil { - RemoveTokensAndCookies(c) return echo.NewHTTPError(http.StatusUnauthorized, errors.Wrap(err, "Invalid or expired access token")) } if !audienceContains(claims.Audience, auth.AccessTokenAudienceName) { return echo.NewHTTPError(http.StatusUnauthorized, fmt.Sprintf("Invalid access token, audience mismatch, got %q, expected %q.", claims.Audience, auth.AccessTokenAudienceName)) } - // We either have a valid access token or we will attempt to generate new access token and refresh token + // We either have a valid access token or we will attempt to generate new access token. userID, err := util.ConvertStringToInt32(claims.Subject) if err != nil { return echo.NewHTTPError(http.StatusUnauthorized, "Malformed ID in the token.").WithInternal(err) } + accessTokens, err := s.Store.GetUserAccessTokens(ctx, userID) + if err != nil { + return echo.NewHTTPError(http.StatusInternalServerError, "Failed to get user access tokens.").WithInternal(err) + } + if !validateAccessToken(token, accessTokens) { + return echo.NewHTTPError(http.StatusUnauthorized, "Invalid access token.") + } + // Even if there is no error, we still need to make sure the user still exists. - user, err := server.Store.GetUser(ctx, &store.FindUser{ + user, err := s.Store.GetUser(ctx, &store.FindUser{ ID: &userID, }) if err != nil { @@ -192,3 +161,12 @@ func JWTMiddleware(server *APIV1Service, next echo.HandlerFunc, secret string) e return next(c) } } + +func validateAccessToken(accessTokenString string, userAccessTokens []*storepb.AccessTokensUserSetting_AccessToken) bool { + for _, userAccessToken := range userAccessTokens { + if accessTokenString == userAccessToken.AccessToken && !userAccessToken.Revoked { + return true + } + } + return false +} diff --git a/api/v2/acl.go b/api/v2/acl.go index b5db14b..f7c61c6 100644 --- a/api/v2/acl.go +++ b/api/v2/acl.go @@ -7,6 +7,7 @@ import ( "github.com/boojack/slash/api/auth" "github.com/boojack/slash/internal/util" + storepb "github.com/boojack/slash/proto/gen/store" "github.com/boojack/slash/store" "github.com/golang-jwt/jwt/v4" "github.com/pkg/errors" @@ -44,14 +45,14 @@ type claimsMessage struct { // GRPCAuthInterceptor is the auth interceptor for gRPC server. type GRPCAuthInterceptor struct { - store *store.Store + Store *store.Store secret string } // NewGRPCAuthInterceptor returns a new API auth interceptor. func NewGRPCAuthInterceptor(store *store.Store, secret string) *GRPCAuthInterceptor { return &GRPCAuthInterceptor{ - store: store, + Store: store, secret: secret, } } @@ -62,12 +63,12 @@ func (in *GRPCAuthInterceptor) AuthenticationInterceptor(ctx context.Context, re if !ok { return nil, status.Errorf(codes.Unauthenticated, "failed to parse metadata from incoming context") } - accessTokenStr, err := getTokenFromMetadata(md) + accessToken, err := getTokenFromMetadata(md) if err != nil { return nil, status.Errorf(codes.Unauthenticated, err.Error()) } - userID, err := in.authenticate(ctx, accessTokenStr) + userID, err := in.authenticate(ctx, accessToken) if err != nil { if IsAuthenticationAllowed(serverInfo.FullMethod) { return handler(ctx, request) @@ -75,6 +76,14 @@ func (in *GRPCAuthInterceptor) AuthenticationInterceptor(ctx context.Context, re return nil, err } + accessTokens, err := in.Store.GetUserAccessTokens(ctx, userID) + if err != nil { + return nil, errors.Wrap(err, "failed to get user access tokens") + } + if !validateAccessToken(accessToken, accessTokens) { + return nil, status.Errorf(codes.Unauthenticated, "invalid access token") + } + // Stores userID into context. childCtx := context.WithValue(ctx, UserIDContextKey, userID) return handler(childCtx, request) @@ -111,7 +120,7 @@ func (in *GRPCAuthInterceptor) authenticate(ctx context.Context, accessTokenStr if err != nil { return 0, status.Errorf(codes.Unauthenticated, "malformed ID %q in the access token", claims.Subject) } - user, err := in.store.GetUser(ctx, &store.FindUser{ + user, err := in.Store.GetUser(ctx, &store.FindUser{ ID: &userID, }) if err != nil { @@ -157,3 +166,12 @@ func audienceContains(audience jwt.ClaimStrings, token string) bool { } return false } + +func validateAccessToken(accessTokenString string, userAccessTokens []*storepb.AccessTokensUserSetting_AccessToken) bool { + for _, userAccessToken := range userAccessTokens { + if accessTokenString == userAccessToken.AccessToken && !userAccessToken.Revoked { + return true + } + } + return false +} diff --git a/store/user_setting.go b/store/user_setting.go index 3f921e9..e34d163 100644 --- a/store/user_setting.go +++ b/store/user_setting.go @@ -140,3 +140,20 @@ func vacuumUserSetting(ctx context.Context, tx *sql.Tx) error { return nil } + +// GetUserAccessTokens returns the access tokens of the user. +func (s *Store) GetUserAccessTokens(ctx context.Context, userID int32) ([]*storepb.AccessTokensUserSetting_AccessToken, error) { + userSetting, err := s.GetUserSetting(ctx, &FindUserSetting{ + UserID: &userID, + Key: storepb.UserSettingKey_USER_SETTING_ACCESS_TOKENS, + }) + if err != nil { + return nil, err + } + if userSetting == nil { + return []*storepb.AccessTokensUserSetting_AccessToken{}, nil + } + + accessTokensUserSetting := userSetting.GetAccessTokensUserSetting() + return accessTokensUserSetting.AccessTokens, nil +}