feat: validate access token

This commit is contained in:
Steven
2023-08-06 14:16:23 +08:00
parent d8903875d3
commit 84ddafeb84
5 changed files with 133 additions and 52 deletions

View File

@ -7,6 +7,7 @@ import (
"github.com/boojack/slash/api/auth"
"github.com/boojack/slash/internal/util"
storepb "github.com/boojack/slash/proto/gen/store"
"github.com/boojack/slash/store"
"github.com/golang-jwt/jwt/v4"
"github.com/pkg/errors"
@ -44,14 +45,14 @@ type claimsMessage struct {
// GRPCAuthInterceptor is the auth interceptor for gRPC server.
type GRPCAuthInterceptor struct {
store *store.Store
Store *store.Store
secret string
}
// NewGRPCAuthInterceptor returns a new API auth interceptor.
func NewGRPCAuthInterceptor(store *store.Store, secret string) *GRPCAuthInterceptor {
return &GRPCAuthInterceptor{
store: store,
Store: store,
secret: secret,
}
}
@ -62,12 +63,12 @@ func (in *GRPCAuthInterceptor) AuthenticationInterceptor(ctx context.Context, re
if !ok {
return nil, status.Errorf(codes.Unauthenticated, "failed to parse metadata from incoming context")
}
accessTokenStr, err := getTokenFromMetadata(md)
accessToken, err := getTokenFromMetadata(md)
if err != nil {
return nil, status.Errorf(codes.Unauthenticated, err.Error())
}
userID, err := in.authenticate(ctx, accessTokenStr)
userID, err := in.authenticate(ctx, accessToken)
if err != nil {
if IsAuthenticationAllowed(serverInfo.FullMethod) {
return handler(ctx, request)
@ -75,6 +76,14 @@ func (in *GRPCAuthInterceptor) AuthenticationInterceptor(ctx context.Context, re
return nil, err
}
accessTokens, err := in.Store.GetUserAccessTokens(ctx, userID)
if err != nil {
return nil, errors.Wrap(err, "failed to get user access tokens")
}
if !validateAccessToken(accessToken, accessTokens) {
return nil, status.Errorf(codes.Unauthenticated, "invalid access token")
}
// Stores userID into context.
childCtx := context.WithValue(ctx, UserIDContextKey, userID)
return handler(childCtx, request)
@ -111,7 +120,7 @@ func (in *GRPCAuthInterceptor) authenticate(ctx context.Context, accessTokenStr
if err != nil {
return 0, status.Errorf(codes.Unauthenticated, "malformed ID %q in the access token", claims.Subject)
}
user, err := in.store.GetUser(ctx, &store.FindUser{
user, err := in.Store.GetUser(ctx, &store.FindUser{
ID: &userID,
})
if err != nil {
@ -157,3 +166,12 @@ func audienceContains(audience jwt.ClaimStrings, token string) bool {
}
return false
}
func validateAccessToken(accessTokenString string, userAccessTokens []*storepb.AccessTokensUserSetting_AccessToken) bool {
for _, userAccessToken := range userAccessTokens {
if accessTokenString == userAccessToken.AccessToken && !userAccessToken.Revoked {
return true
}
}
return false
}