feat: validate access token

This commit is contained in:
Steven 2023-08-06 14:16:23 +08:00
parent d8903875d3
commit 84ddafeb84
5 changed files with 133 additions and 52 deletions

View File

@ -16,8 +16,6 @@ const (
// CookieExpDuration expires slightly earlier than the jwt expiration. Client would be logged out if the user // CookieExpDuration expires slightly earlier than the jwt expiration. Client would be logged out if the user
// cookie expires, thus the client would always logout first before attempting to make a request with the expired jwt. // cookie expires, thus the client would always logout first before attempting to make a request with the expired jwt.
// Suppose we have a valid refresh token, we will refresh the token in cases:
// 1. The access token has already expired, we refresh the token so that the ongoing request can pass through.
CookieExpDuration = AccessTokenDuration - 1*time.Minute CookieExpDuration = AccessTokenDuration - 1*time.Minute
// AccessTokenCookieName is the cookie name of access token. // AccessTokenCookieName is the cookie name of access token.
AccessTokenCookieName = "slash.access-token" AccessTokenCookieName = "slash.access-token"

View File

@ -1,13 +1,19 @@
package v1 package v1
import ( import (
"context"
"encoding/json" "encoding/json"
"fmt" "fmt"
"net/http" "net/http"
"time"
"github.com/boojack/slash/api/auth"
storepb "github.com/boojack/slash/proto/gen/store"
"github.com/boojack/slash/store" "github.com/boojack/slash/store"
"github.com/labstack/echo/v4" "github.com/labstack/echo/v4"
"github.com/pkg/errors"
"golang.org/x/crypto/bcrypt" "golang.org/x/crypto/bcrypt"
"google.golang.org/protobuf/types/known/timestamppb"
) )
type SignInRequest struct { type SignInRequest struct {
@ -46,9 +52,16 @@ func (s *APIV1Service) registerAuthRoutes(g *echo.Group, secret string) {
return echo.NewHTTPError(http.StatusUnauthorized, "unmatched email and password") return echo.NewHTTPError(http.StatusUnauthorized, "unmatched email and password")
} }
if err := GenerateTokensAndSetCookies(c, user, secret); err != nil { accessToken, err := GenerateAccessToken(user.Email, user.ID, secret)
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, fmt.Sprintf("failed to generate tokens, err: %s", err)).SetInternal(err) return echo.NewHTTPError(http.StatusInternalServerError, fmt.Sprintf("failed to generate tokens, err: %s", err)).SetInternal(err)
} }
if err := s.UpsertAccessTokenToStore(ctx, user, accessToken); err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, fmt.Sprintf("failed to upsert access token, err: %s", err)).SetInternal(err)
}
cookieExp := time.Now().Add(auth.CookieExpDuration)
setTokenCookie(c, auth.AccessTokenCookieName, accessToken, cookieExp)
return c.JSON(http.StatusOK, convertUserFromStore(user)) return c.JSON(http.StatusOK, convertUserFromStore(user))
}) })
@ -95,10 +108,16 @@ func (s *APIV1Service) registerAuthRoutes(g *echo.Group, secret string) {
return echo.NewHTTPError(http.StatusInternalServerError, fmt.Sprintf("failed to create user, err: %s", err)).SetInternal(err) return echo.NewHTTPError(http.StatusInternalServerError, fmt.Sprintf("failed to create user, err: %s", err)).SetInternal(err)
} }
if err := GenerateTokensAndSetCookies(c, user, secret); err != nil { accessToken, err := GenerateAccessToken(user.Email, user.ID, secret)
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, fmt.Sprintf("failed to generate tokens, err: %s", err)).SetInternal(err) return echo.NewHTTPError(http.StatusInternalServerError, fmt.Sprintf("failed to generate tokens, err: %s", err)).SetInternal(err)
} }
if err := s.UpsertAccessTokenToStore(ctx, user, accessToken); err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, fmt.Sprintf("failed to upsert access token, err: %s", err)).SetInternal(err)
}
cookieExp := time.Now().Add(auth.CookieExpDuration)
setTokenCookie(c, auth.AccessTokenCookieName, accessToken, cookieExp)
return c.JSON(http.StatusOK, convertUserFromStore(user)) return c.JSON(http.StatusOK, convertUserFromStore(user))
}) })
@ -108,3 +127,54 @@ func (s *APIV1Service) registerAuthRoutes(g *echo.Group, secret string) {
return nil return nil
}) })
} }
func (s *APIV1Service) UpsertAccessTokenToStore(ctx context.Context, user *store.User, accessToken string) error {
userAccessTokens, err := s.Store.GetUserAccessTokens(ctx, user.ID)
if err != nil {
return errors.Wrap(err, "failed to get user access tokens")
}
userAccessToken := storepb.AccessTokensUserSetting_AccessToken{
AccessToken: accessToken,
Description: "user sign in",
CreatedTime: timestamppb.Now(),
ExpiresTime: timestamppb.New(time.Now().Add(auth.AccessTokenDuration)),
}
userAccessTokens = append(userAccessTokens, &userAccessToken)
if _, err := s.Store.UpsertUserSetting(ctx, &storepb.UserSetting{
UserId: user.ID,
Key: storepb.UserSettingKey_USER_SETTING_ACCESS_TOKENS,
Value: &storepb.UserSetting_AccessTokensUserSetting{
AccessTokensUserSetting: &storepb.AccessTokensUserSetting{
AccessTokens: userAccessTokens,
},
},
}); err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, fmt.Sprintf("failed to upsert user setting, err: %s", err)).SetInternal(err)
}
return nil
}
// GenerateAccessToken generates an access token for web.
func GenerateAccessToken(username string, userID int32, secret string) (string, error) {
expirationTime := time.Now().Add(auth.AccessTokenDuration)
return generateToken(username, userID, auth.AccessTokenAudienceName, expirationTime, []byte(secret))
}
// RemoveTokensAndCookies removes the jwt token from the cookies.
func RemoveTokensAndCookies(c echo.Context) {
cookieExp := time.Now().Add(-1 * time.Hour)
setTokenCookie(c, auth.AccessTokenCookieName, "", cookieExp)
}
// setTokenCookie sets the token to the cookie.
func setTokenCookie(c echo.Context, name, token string, expiration time.Time) {
cookie := new(http.Cookie)
cookie.Name = name
cookie.Value = token
cookie.Expires = expiration
cookie.Path = "/"
// Http-only helps mitigate the risk of client side script accessing the protected cookie.
cookie.HttpOnly = true
cookie.SameSite = http.SameSiteStrictMode
c.SetCookie(cookie)
}

View File

@ -8,6 +8,7 @@ import (
"github.com/boojack/slash/api/auth" "github.com/boojack/slash/api/auth"
"github.com/boojack/slash/internal/util" "github.com/boojack/slash/internal/util"
storepb "github.com/boojack/slash/proto/gen/store"
"github.com/boojack/slash/store" "github.com/boojack/slash/store"
"github.com/golang-jwt/jwt/v4" "github.com/golang-jwt/jwt/v4"
"github.com/labstack/echo/v4" "github.com/labstack/echo/v4"
@ -25,43 +26,6 @@ type claimsMessage struct {
jwt.RegisteredClaims jwt.RegisteredClaims
} }
// GenerateAccessToken generates an access token for web.
func GenerateAccessToken(username string, userID int32, secret string) (string, error) {
expirationTime := time.Now().Add(auth.AccessTokenDuration)
return generateToken(username, userID, auth.AccessTokenAudienceName, expirationTime, []byte(secret))
}
// GenerateTokensAndSetCookies generates jwt token and saves it to the http-only cookie.
func GenerateTokensAndSetCookies(c echo.Context, user *store.User, secret string) error {
accessToken, err := GenerateAccessToken(user.Email, user.ID, secret)
if err != nil {
return errors.Wrap(err, "failed to generate access token")
}
cookieExp := time.Now().Add(auth.CookieExpDuration)
setTokenCookie(c, auth.AccessTokenCookieName, accessToken, cookieExp)
return nil
}
// RemoveTokensAndCookies removes the jwt token and refresh token from the cookies.
func RemoveTokensAndCookies(c echo.Context) {
cookieExp := time.Now().Add(-1 * time.Hour)
setTokenCookie(c, auth.AccessTokenCookieName, "", cookieExp)
}
// setTokenCookie sets the token to the cookie.
func setTokenCookie(c echo.Context, name, token string, expiration time.Time) {
cookie := new(http.Cookie)
cookie.Name = name
cookie.Value = token
cookie.Expires = expiration
cookie.Path = "/"
// Http-only helps mitigate the risk of client side script accessing the protected cookie.
cookie.HttpOnly = true
cookie.SameSite = http.SameSiteStrictMode
c.SetCookie(cookie)
}
// generateToken generates a jwt token. // generateToken generates a jwt token.
func generateToken(username string, userID int32, aud string, expirationTime time.Time, secret []byte) (string, error) { func generateToken(username string, userID int32, aud string, expirationTime time.Time, secret []byte) (string, error) {
// Create the JWT claims, which includes the username and expiry time. // Create the JWT claims, which includes the username and expiry time.
@ -127,9 +91,7 @@ func audienceContains(audience jwt.ClaimStrings, token string) bool {
} }
// JWTMiddleware validates the access token. // JWTMiddleware validates the access token.
// If the access token is about to expire or has expired and the request has a valid refresh token, it func JWTMiddleware(s *APIV1Service, next echo.HandlerFunc, secret string) echo.HandlerFunc {
// will try to generate new access token and refresh token.
func JWTMiddleware(server *APIV1Service, next echo.HandlerFunc, secret string) echo.HandlerFunc {
return func(c echo.Context) error { return func(c echo.Context) error {
ctx := c.Request().Context() ctx := c.Request().Context()
path := c.Request().URL.Path path := c.Request().URL.Path
@ -163,21 +125,28 @@ func JWTMiddleware(server *APIV1Service, next echo.HandlerFunc, secret string) e
}) })
if err != nil { if err != nil {
RemoveTokensAndCookies(c)
return echo.NewHTTPError(http.StatusUnauthorized, errors.Wrap(err, "Invalid or expired access token")) return echo.NewHTTPError(http.StatusUnauthorized, errors.Wrap(err, "Invalid or expired access token"))
} }
if !audienceContains(claims.Audience, auth.AccessTokenAudienceName) { if !audienceContains(claims.Audience, auth.AccessTokenAudienceName) {
return echo.NewHTTPError(http.StatusUnauthorized, fmt.Sprintf("Invalid access token, audience mismatch, got %q, expected %q.", claims.Audience, auth.AccessTokenAudienceName)) return echo.NewHTTPError(http.StatusUnauthorized, fmt.Sprintf("Invalid access token, audience mismatch, got %q, expected %q.", claims.Audience, auth.AccessTokenAudienceName))
} }
// We either have a valid access token or we will attempt to generate new access token and refresh token // We either have a valid access token or we will attempt to generate new access token.
userID, err := util.ConvertStringToInt32(claims.Subject) userID, err := util.ConvertStringToInt32(claims.Subject)
if err != nil { if err != nil {
return echo.NewHTTPError(http.StatusUnauthorized, "Malformed ID in the token.").WithInternal(err) return echo.NewHTTPError(http.StatusUnauthorized, "Malformed ID in the token.").WithInternal(err)
} }
accessTokens, err := s.Store.GetUserAccessTokens(ctx, userID)
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to get user access tokens.").WithInternal(err)
}
if !validateAccessToken(token, accessTokens) {
return echo.NewHTTPError(http.StatusUnauthorized, "Invalid access token.")
}
// Even if there is no error, we still need to make sure the user still exists. // Even if there is no error, we still need to make sure the user still exists.
user, err := server.Store.GetUser(ctx, &store.FindUser{ user, err := s.Store.GetUser(ctx, &store.FindUser{
ID: &userID, ID: &userID,
}) })
if err != nil { if err != nil {
@ -192,3 +161,12 @@ func JWTMiddleware(server *APIV1Service, next echo.HandlerFunc, secret string) e
return next(c) return next(c)
} }
} }
func validateAccessToken(accessTokenString string, userAccessTokens []*storepb.AccessTokensUserSetting_AccessToken) bool {
for _, userAccessToken := range userAccessTokens {
if accessTokenString == userAccessToken.AccessToken && !userAccessToken.Revoked {
return true
}
}
return false
}

View File

@ -7,6 +7,7 @@ import (
"github.com/boojack/slash/api/auth" "github.com/boojack/slash/api/auth"
"github.com/boojack/slash/internal/util" "github.com/boojack/slash/internal/util"
storepb "github.com/boojack/slash/proto/gen/store"
"github.com/boojack/slash/store" "github.com/boojack/slash/store"
"github.com/golang-jwt/jwt/v4" "github.com/golang-jwt/jwt/v4"
"github.com/pkg/errors" "github.com/pkg/errors"
@ -44,14 +45,14 @@ type claimsMessage struct {
// GRPCAuthInterceptor is the auth interceptor for gRPC server. // GRPCAuthInterceptor is the auth interceptor for gRPC server.
type GRPCAuthInterceptor struct { type GRPCAuthInterceptor struct {
store *store.Store Store *store.Store
secret string secret string
} }
// NewGRPCAuthInterceptor returns a new API auth interceptor. // NewGRPCAuthInterceptor returns a new API auth interceptor.
func NewGRPCAuthInterceptor(store *store.Store, secret string) *GRPCAuthInterceptor { func NewGRPCAuthInterceptor(store *store.Store, secret string) *GRPCAuthInterceptor {
return &GRPCAuthInterceptor{ return &GRPCAuthInterceptor{
store: store, Store: store,
secret: secret, secret: secret,
} }
} }
@ -62,12 +63,12 @@ func (in *GRPCAuthInterceptor) AuthenticationInterceptor(ctx context.Context, re
if !ok { if !ok {
return nil, status.Errorf(codes.Unauthenticated, "failed to parse metadata from incoming context") return nil, status.Errorf(codes.Unauthenticated, "failed to parse metadata from incoming context")
} }
accessTokenStr, err := getTokenFromMetadata(md) accessToken, err := getTokenFromMetadata(md)
if err != nil { if err != nil {
return nil, status.Errorf(codes.Unauthenticated, err.Error()) return nil, status.Errorf(codes.Unauthenticated, err.Error())
} }
userID, err := in.authenticate(ctx, accessTokenStr) userID, err := in.authenticate(ctx, accessToken)
if err != nil { if err != nil {
if IsAuthenticationAllowed(serverInfo.FullMethod) { if IsAuthenticationAllowed(serverInfo.FullMethod) {
return handler(ctx, request) return handler(ctx, request)
@ -75,6 +76,14 @@ func (in *GRPCAuthInterceptor) AuthenticationInterceptor(ctx context.Context, re
return nil, err return nil, err
} }
accessTokens, err := in.Store.GetUserAccessTokens(ctx, userID)
if err != nil {
return nil, errors.Wrap(err, "failed to get user access tokens")
}
if !validateAccessToken(accessToken, accessTokens) {
return nil, status.Errorf(codes.Unauthenticated, "invalid access token")
}
// Stores userID into context. // Stores userID into context.
childCtx := context.WithValue(ctx, UserIDContextKey, userID) childCtx := context.WithValue(ctx, UserIDContextKey, userID)
return handler(childCtx, request) return handler(childCtx, request)
@ -111,7 +120,7 @@ func (in *GRPCAuthInterceptor) authenticate(ctx context.Context, accessTokenStr
if err != nil { if err != nil {
return 0, status.Errorf(codes.Unauthenticated, "malformed ID %q in the access token", claims.Subject) return 0, status.Errorf(codes.Unauthenticated, "malformed ID %q in the access token", claims.Subject)
} }
user, err := in.store.GetUser(ctx, &store.FindUser{ user, err := in.Store.GetUser(ctx, &store.FindUser{
ID: &userID, ID: &userID,
}) })
if err != nil { if err != nil {
@ -157,3 +166,12 @@ func audienceContains(audience jwt.ClaimStrings, token string) bool {
} }
return false return false
} }
func validateAccessToken(accessTokenString string, userAccessTokens []*storepb.AccessTokensUserSetting_AccessToken) bool {
for _, userAccessToken := range userAccessTokens {
if accessTokenString == userAccessToken.AccessToken && !userAccessToken.Revoked {
return true
}
}
return false
}

View File

@ -140,3 +140,20 @@ func vacuumUserSetting(ctx context.Context, tx *sql.Tx) error {
return nil return nil
} }
// GetUserAccessTokens returns the access tokens of the user.
func (s *Store) GetUserAccessTokens(ctx context.Context, userID int32) ([]*storepb.AccessTokensUserSetting_AccessToken, error) {
userSetting, err := s.GetUserSetting(ctx, &FindUserSetting{
UserID: &userID,
Key: storepb.UserSettingKey_USER_SETTING_ACCESS_TOKENS,
})
if err != nil {
return nil, err
}
if userSetting == nil {
return []*storepb.AccessTokensUserSetting_AccessToken{}, nil
}
accessTokensUserSetting := userSetting.GetAccessTokensUserSetting()
return accessTokensUserSetting.AccessTokens, nil
}