feat: implement part of user service

This commit is contained in:
Steven
2023-08-02 07:35:36 +08:00
parent dfe47b9b7e
commit 59a75c89eb
9 changed files with 411 additions and 49 deletions

17
api/v2/common.go Normal file
View File

@ -0,0 +1,17 @@
package v2
import (
apiv2pb "github.com/boojack/slash/proto/gen/api/v2"
"github.com/boojack/slash/store"
)
func convertRowStatusFromStore(rowStatus store.RowStatus) apiv2pb.RowStatus {
switch rowStatus {
case store.Normal:
return apiv2pb.RowStatus_ACTIVE
case store.Archived:
return apiv2pb.RowStatus_ARCHIVED
default:
return apiv2pb.RowStatus_ROW_STATUS_UNSPECIFIED
}
}

193
api/v2/jwt.go Normal file
View File

@ -0,0 +1,193 @@
package v2
import (
"context"
"net/http"
"strconv"
"strings"
"time"
"github.com/golang-jwt/jwt/v4"
"github.com/pkg/errors"
"github.com/boojack/slash/api/auth"
"github.com/boojack/slash/store"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/status"
)
var authenticationAllowlistMethods = map[string]bool{
"/memos.api.v2.UserService/GetUser": true,
}
// IsAuthenticationAllowed returns whether the method is exempted from authentication.
func IsAuthenticationAllowed(fullMethodName string) bool {
if strings.HasPrefix(fullMethodName, "/grpc.reflection") {
return true
}
return authenticationAllowlistMethods[fullMethodName]
}
// ContextKey is the key type of context value.
type ContextKey int
const (
// The key name used to store user id in the context
// user id is extracted from the jwt token subject field.
UserIDContextKey ContextKey = iota
)
// GRPCAuthInterceptor is the auth interceptor for gRPC server.
type GRPCAuthInterceptor struct {
store *store.Store
secret string
}
// NewGRPCAuthInterceptor returns a new API auth interceptor.
func NewGRPCAuthInterceptor(store *store.Store, secret string) *GRPCAuthInterceptor {
return &GRPCAuthInterceptor{
store: store,
secret: secret,
}
}
// AuthenticationInterceptor is the unary interceptor for gRPC API.
func (in *GRPCAuthInterceptor) AuthenticationInterceptor(ctx context.Context, request any, serverInfo *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (any, error) {
md, ok := metadata.FromIncomingContext(ctx)
if !ok {
return nil, status.Errorf(codes.Unauthenticated, "failed to parse metadata from incoming context")
}
accessTokenStr, err := getTokenFromMetadata(md)
if err != nil {
return nil, status.Errorf(codes.Unauthenticated, err.Error())
}
userID, err := in.authenticate(ctx, accessTokenStr)
if err != nil {
if IsAuthenticationAllowed(serverInfo.FullMethod) {
return handler(ctx, request)
}
return nil, err
}
// Stores userID into context.
childCtx := context.WithValue(ctx, UserIDContextKey, userID)
return handler(childCtx, request)
}
func (in *GRPCAuthInterceptor) authenticate(ctx context.Context, accessTokenStr string) (int, error) {
if accessTokenStr == "" {
return 0, status.Errorf(codes.Unauthenticated, "access token not found")
}
claims := &claimsMessage{}
_, err := jwt.ParseWithClaims(accessTokenStr, claims, func(t *jwt.Token) (any, error) {
if t.Method.Alg() != jwt.SigningMethodHS256.Name {
return nil, status.Errorf(codes.Unauthenticated, "unexpected access token signing method=%v, expect %v", t.Header["alg"], jwt.SigningMethodHS256)
}
if kid, ok := t.Header["kid"].(string); ok {
if kid == "v1" {
return []byte(in.secret), nil
}
}
return nil, status.Errorf(codes.Unauthenticated, "unexpected access token kid=%v", t.Header["kid"])
})
if err != nil {
return 0, status.Errorf(codes.Unauthenticated, "Invalid or expired access token")
}
if !audienceContains(claims.Audience, auth.AccessTokenAudienceName) {
return 0, status.Errorf(codes.Unauthenticated,
"invalid access token, audience mismatch, got %q, expected %q. you may send request to the wrong environment",
claims.Audience,
auth.AccessTokenAudienceName,
)
}
userID, err := strconv.Atoi(claims.Subject)
if err != nil {
return 0, status.Errorf(codes.Unauthenticated, "malformed ID %q in the access token", claims.Subject)
}
user, err := in.store.GetUser(ctx, &store.FindUser{
ID: &userID,
})
if err != nil {
return 0, status.Errorf(codes.Unauthenticated, "failed to find user ID %q in the access token", userID)
}
if user == nil {
return 0, status.Errorf(codes.Unauthenticated, "user ID %q not exists in the access token", userID)
}
if user.RowStatus == store.Archived {
return 0, status.Errorf(codes.Unauthenticated, "user ID %q has been deactivated by administrators", userID)
}
return userID, nil
}
func getTokenFromMetadata(md metadata.MD) (string, error) {
authorizationHeaders := md.Get("Authorization")
if len(md.Get("Authorization")) > 0 {
authHeaderParts := strings.Fields(authorizationHeaders[0])
if len(authHeaderParts) != 2 || strings.ToLower(authHeaderParts[0]) != "bearer" {
return "", errors.Errorf("authorization header format must be Bearer {token}")
}
return authHeaderParts[1], nil
}
// check the HTTP cookie
var accessToken string
for _, t := range append(md.Get("grpcgateway-cookie"), md.Get("cookie")...) {
header := http.Header{}
header.Add("Cookie", t)
request := http.Request{Header: header}
if v, _ := request.Cookie(auth.AccessTokenCookieName); v != nil {
accessToken = v.Value
}
}
return accessToken, nil
}
func audienceContains(audience jwt.ClaimStrings, token string) bool {
for _, v := range audience {
if v == token {
return true
}
}
return false
}
type claimsMessage struct {
Name string `json:"name"`
jwt.RegisteredClaims
}
// GenerateAccessToken generates an access token for web.
func GenerateAccessToken(username string, userID int, secret string) (string, error) {
expirationTime := time.Now().Add(auth.AccessTokenDuration)
return generateToken(username, userID, auth.AccessTokenAudienceName, expirationTime, []byte(secret))
}
func generateToken(username string, userID int, aud string, expirationTime time.Time, secret []byte) (string, error) {
// Create the JWT claims, which includes the username and expiry time.
claims := &claimsMessage{
Name: username,
RegisteredClaims: jwt.RegisteredClaims{
Audience: jwt.ClaimStrings{aud},
// In JWT, the expiry time is expressed as unix milliseconds.
ExpiresAt: jwt.NewNumericDate(expirationTime),
IssuedAt: jwt.NewNumericDate(time.Now()),
Issuer: auth.Issuer,
Subject: strconv.Itoa(userID),
},
}
// Declare the token with the HS256 algorithm used for signing, and the claims.
token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
token.Header["kid"] = auth.KeyID
// Create the JWT string.
tokenString, err := token.SignedString(secret)
if err != nil {
return "", err
}
return tokenString, nil
}

65
api/v2/user_service.go Normal file
View File

@ -0,0 +1,65 @@
package v2
import (
"context"
apiv2pb "github.com/boojack/slash/proto/gen/api/v2"
"github.com/boojack/slash/store"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
)
type UserService struct {
apiv2pb.UnimplementedUserServiceServer
Store *store.Store
}
// NewUserService creates a new UserService.
func NewUserService(store *store.Store) *UserService {
return &UserService{
Store: store,
}
}
func (s *UserService) GetUser(ctx context.Context, request *apiv2pb.GetUserRequest) (*apiv2pb.GetUserResponse, error) {
id := int(request.Id)
user, err := s.Store.GetUser(ctx, &store.FindUser{
ID: &id,
})
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to list tags: %v", err)
}
if user == nil {
return nil, status.Errorf(codes.NotFound, "user not found")
}
userMessage := convertUserFromStore(user)
response := &apiv2pb.GetUserResponse{
User: userMessage,
}
return response, nil
}
func convertUserFromStore(user *store.User) *apiv2pb.User {
return &apiv2pb.User{
Id: int32(user.ID),
RowStatus: convertRowStatusFromStore(user.RowStatus),
CreatedTs: user.CreatedTs,
UpdatedTs: user.UpdatedTs,
Role: convertUserRoleFromStore(user.Role),
Email: user.Email,
Nickname: user.Nickname,
}
}
func convertUserRoleFromStore(role store.Role) apiv2pb.Role {
switch role {
case store.RoleAdmin:
return apiv2pb.Role_ADMIN
case store.RoleUser:
return apiv2pb.Role_USER
default:
return apiv2pb.Role_ROLE_UNSPECIFIED
}
}

67
api/v2/v2.go Normal file
View File

@ -0,0 +1,67 @@
package v2
import (
"context"
"fmt"
apiv2pb "github.com/boojack/slash/proto/gen/api/v2"
"github.com/boojack/slash/server/profile"
"github.com/boojack/slash/store"
grpcRuntime "github.com/grpc-ecosystem/grpc-gateway/v2/runtime"
"github.com/labstack/echo/v4"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials/insecure"
)
type APIV2Service struct {
Secret string
Profile *profile.Profile
Store *store.Store
grpcServer *grpc.Server
grpcServerPort int
}
func NewAPIV2Service(secret string, profile *profile.Profile, store *store.Store, grpcServerPort int) *APIV2Service {
authProvider := NewGRPCAuthInterceptor(store, secret)
grpcServer := grpc.NewServer(
grpc.ChainUnaryInterceptor(
authProvider.AuthenticationInterceptor,
),
)
apiv2pb.RegisterUserServiceServer(grpcServer, NewUserService(store))
return &APIV2Service{
Secret: secret,
Profile: profile,
Store: store,
grpcServer: grpcServer,
grpcServerPort: grpcServerPort,
}
}
func (s *APIV2Service) GetGRPCServer() *grpc.Server {
return s.grpcServer
}
// RegisterGateway registers the gRPC-Gateway with the given Echo instance.
func (s *APIV2Service) RegisterGateway(ctx context.Context, e *echo.Echo) error {
// Create a client connection to the gRPC Server we just started.
// This is where the gRPC-Gateway proxies the requests.
conn, err := grpc.DialContext(
ctx,
fmt.Sprintf(":%d", s.grpcServerPort),
grpc.WithTransportCredentials(insecure.NewCredentials()),
)
if err != nil {
return err
}
gwMux := grpcRuntime.NewServeMux()
if err := apiv2pb.RegisterUserServiceHandler(context.Background(), gwMux, conn); err != nil {
return err
}
e.Any("/api/v2/*", echo.WrapHandler(gwMux))
return nil
}