diff --git a/api/auth/auth.go b/api/auth/auth.go index b7bbfff..79c417d 100644 --- a/api/auth/auth.go +++ b/api/auth/auth.go @@ -37,21 +37,21 @@ func GenerateAccessToken(username string, userID int32, expirationTime time.Time // generateToken generates a jwt token. func generateToken(username string, userID int32, audience string, expirationTime time.Time, secret []byte) (string, error) { - // Create the JWT claims, which includes the username and expiry time. - claims := &ClaimsMessage{ - Name: username, - RegisteredClaims: jwt.RegisteredClaims{ - Issuer: Issuer, - Audience: jwt.ClaimStrings{audience}, - // In JWT, the expiry time is expressed as unix milliseconds. - ExpiresAt: jwt.NewNumericDate(expirationTime), - IssuedAt: jwt.NewNumericDate(time.Now()), - Subject: fmt.Sprint(userID), - }, + registeredClaims := jwt.RegisteredClaims{ + Issuer: Issuer, + Audience: jwt.ClaimStrings{audience}, + IssuedAt: jwt.NewNumericDate(time.Now()), + Subject: fmt.Sprint(userID), + } + if expirationTime.After(time.Now()) { + registeredClaims.ExpiresAt = jwt.NewNumericDate(expirationTime) } // Declare the token with the HS256 algorithm used for signing, and the claims. - token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) + token := jwt.NewWithClaims(jwt.SigningMethodHS256, &ClaimsMessage{ + Name: username, + RegisteredClaims: registeredClaims, + }) token.Header["kid"] = KeyID // Create the JWT string. diff --git a/api/v2/acl.go b/api/v2/acl.go index d6e3ccd..06fedd7 100644 --- a/api/v2/acl.go +++ b/api/v2/acl.go @@ -17,9 +17,7 @@ import ( "google.golang.org/grpc/status" ) -var authenticationAllowlistMethods = map[string]bool{ - "/memos.api.v2.UserService/GetUser": true, -} +var authenticationAllowlistMethods = map[string]bool{} // IsAuthenticationAllowed returns whether the method is exempted from authentication. func IsAuthenticationAllowed(fullMethodName string) bool { @@ -60,7 +58,7 @@ func (in *GRPCAuthInterceptor) AuthenticationInterceptor(ctx context.Context, re } accessToken, err := getTokenFromMetadata(md) if err != nil { - return nil, status.Errorf(codes.Unauthenticated, err.Error()) + return nil, status.Errorf(codes.Unauthenticated, "failed to get access token from metadata: %v", err) } userID, err := in.authenticate(ctx, accessToken) @@ -71,11 +69,11 @@ func (in *GRPCAuthInterceptor) AuthenticationInterceptor(ctx context.Context, re return nil, err } - accessTokens, err := in.Store.GetUserAccessTokens(ctx, userID) + userAccessTokens, err := in.Store.GetUserAccessTokens(ctx, userID) if err != nil { return nil, errors.Wrap(err, "failed to get user access tokens") } - if !validateAccessToken(accessToken, accessTokens) { + if !validateAccessToken(accessToken, userAccessTokens) { return nil, status.Errorf(codes.Unauthenticated, "invalid access token") } @@ -132,15 +130,16 @@ func (in *GRPCAuthInterceptor) authenticate(ctx context.Context, accessTokenStr } func getTokenFromMetadata(md metadata.MD) (string, error) { + // Try to get the token from the authorization header first. authorizationHeaders := md.Get("Authorization") - if len(md.Get("Authorization")) > 0 { + if len(authorizationHeaders) > 0 { authHeaderParts := strings.Fields(authorizationHeaders[0]) if len(authHeaderParts) != 2 || strings.ToLower(authHeaderParts[0]) != "bearer" { return "", errors.Errorf("authorization header format must be Bearer {token}") } return authHeaderParts[1], nil } - // check the HTTP cookie + // Try to get the token from the cookie header. var accessToken string for _, t := range append(md.Get("grpcgateway-cookie"), md.Get("cookie")...) { header := http.Header{} diff --git a/api/v2/user_service.go b/api/v2/user_service.go index 8307645..1aa8576 100644 --- a/api/v2/user_service.go +++ b/api/v2/user_service.go @@ -9,6 +9,7 @@ import ( "github.com/boojack/slash/store" "github.com/golang-jwt/jwt/v4" "github.com/pkg/errors" + "golang.org/x/exp/slices" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" "google.golang.org/protobuf/types/known/timestamppb" @@ -77,14 +78,21 @@ func (s *UserService) ListUserAccessTokens(ctx context.Context, request *apiv2pb continue } - accessTokens = append(accessTokens, &apiv2pb.UserAccessToken{ + userAccessToken := &apiv2pb.UserAccessToken{ AccessToken: userAccessToken.AccessToken, Description: userAccessToken.Description, IssuedAt: timestamppb.New(claims.IssuedAt.Time), - ExpiresAt: timestamppb.New(claims.ExpiresAt.Time), - }) + } + if claims.ExpiresAt != nil { + userAccessToken.ExpiresAt = timestamppb.New(claims.ExpiresAt.Time) + } + accessTokens = append(accessTokens, userAccessToken) } + // Sort by issued time in descending order. + slices.SortFunc(accessTokens, func(i, j *apiv2pb.UserAccessToken) bool { + return i.IssuedAt.Seconds > j.IssuedAt.Seconds + }) response := &apiv2pb.ListUserAccessTokensResponse{ AccessTokens: accessTokens, } @@ -133,13 +141,16 @@ func (s *UserService) CreateUserAccessToken(ctx context.Context, request *apiv2p return nil, status.Errorf(codes.Internal, "failed to upsert access token to store: %v", err) } + userAccessToken := &apiv2pb.UserAccessToken{ + AccessToken: accessToken, + Description: request.UserAccessToken.Description, + IssuedAt: timestamppb.New(claims.IssuedAt.Time), + } + if claims.ExpiresAt != nil { + userAccessToken.ExpiresAt = timestamppb.New(claims.ExpiresAt.Time) + } response := &apiv2pb.CreateUserAccessTokenResponse{ - AccessToken: &apiv2pb.UserAccessToken{ - AccessToken: accessToken, - Description: request.UserAccessToken.Description, - IssuedAt: timestamppb.New(claims.IssuedAt.Time), - ExpiresAt: timestamppb.New(claims.ExpiresAt.Time), - }, + AccessToken: userAccessToken, } return response, nil } diff --git a/web/src/components/CreateAccessTokenDialog.tsx b/web/src/components/CreateAccessTokenDialog.tsx index 8e6ef2c..a728080 100644 --- a/web/src/components/CreateAccessTokenDialog.tsx +++ b/web/src/components/CreateAccessTokenDialog.tsx @@ -25,8 +25,8 @@ const expirationOptions = [ value: 3600 * 24 * 7, }, { - label: "Life time", - value: 3600 * 24 * 365 * 100, + label: "Never", + value: 0, }, ]; diff --git a/web/src/components/setting/AccessTokenSection.tsx b/web/src/components/setting/AccessTokenSection.tsx index e06354c..57429e8 100644 --- a/web/src/components/setting/AccessTokenSection.tsx +++ b/web/src/components/setting/AccessTokenSection.tsx @@ -95,7 +95,9 @@ const AccessTokenSection = () => { {getFormatedAccessToken(userAccessToken.accessToken)}