chore: extract jwt from server to api

This commit is contained in:
Steven 2023-06-24 13:58:38 +08:00
parent c4388fb211
commit 0a811d2568
8 changed files with 41 additions and 48 deletions

View File

@ -12,14 +12,6 @@ import (
"golang.org/x/crypto/bcrypt" "golang.org/x/crypto/bcrypt"
) )
var (
userIDContextKey = "user-id"
)
func getUserIDContextKey() string {
return userIDContextKey
}
type SignInRequest struct { type SignInRequest struct {
Email string `json:"email"` Email string `json:"email"`
Password string `json:"password"` Password string `json:"password"`

View File

@ -1,4 +1,4 @@
package server package v1
import ( import (
"fmt" "fmt"
@ -7,6 +7,7 @@ import (
"strings" "strings"
"time" "time"
"github.com/boojack/shortify/internal/util"
"github.com/boojack/shortify/server/auth" "github.com/boojack/shortify/server/auth"
"github.com/boojack/shortify/store" "github.com/boojack/shortify/store"
"github.com/golang-jwt/jwt/v4" "github.com/golang-jwt/jwt/v4"
@ -71,19 +72,19 @@ func audienceContains(audience jwt.ClaimStrings, token string) bool {
// JWTMiddleware validates the access token. // JWTMiddleware validates the access token.
// If the access token is about to expire or has expired and the request has a valid refresh token, it // If the access token is about to expire or has expired and the request has a valid refresh token, it
// will try to generate new access token and refresh token. // will try to generate new access token and refresh token.
func JWTMiddleware(server *Server, next echo.HandlerFunc, secret string) echo.HandlerFunc { func JWTMiddleware(server *APIV1Service, next echo.HandlerFunc, secret string) echo.HandlerFunc {
return func(c echo.Context) error { return func(c echo.Context) error {
path := c.Path() path := c.Path()
method := c.Request().Method method := c.Request().Method
if server.defaultAuthSkipper(c) { if defaultAuthSkipper(c) {
return next(c) return next(c)
} }
token := findAccessToken(c) token := findAccessToken(c)
if token == "" { if token == "" {
// When the request is not authenticated, we allow the user to access the shortcut endpoints for those public shortcuts. // When the request is not authenticated, we allow the user to access the shortcut endpoints for those public shortcuts.
if hasPrefixes(path, "/api/v1/status", "/o/*") && method == http.MethodGet { if util.HasPrefixes(path, "/api/v1/status", "/o/*") && method == http.MethodGet {
return next(c) return next(c)
} }
return echo.NewHTTPError(http.StatusUnauthorized, "Missing access token") return echo.NewHTTPError(http.StatusUnauthorized, "Missing access token")
@ -194,3 +195,8 @@ func JWTMiddleware(server *Server, next echo.HandlerFunc, secret string) echo.Ha
return next(c) return next(c)
} }
} }
func defaultAuthSkipper(c echo.Context) bool {
path := c.Path()
return util.HasPrefixes(path, "/api/v1/auth")
}

View File

@ -21,6 +21,9 @@ func NewAPIV1Service(profile *profile.Profile, store *store.Store) *APIV1Service
func (s *APIV1Service) Start(apiGroup *echo.Group, secret string) { func (s *APIV1Service) Start(apiGroup *echo.Group, secret string) {
apiV1Group := apiGroup.Group("/api/v1") apiV1Group := apiGroup.Group("/api/v1")
apiV1Group.Use(func(next echo.HandlerFunc) echo.HandlerFunc {
return JWTMiddleware(s, next, string(secret))
})
s.registerSystemRoutes(apiV1Group) s.registerSystemRoutes(apiV1Group)
s.registerWorkspaceSettingRoutes(apiV1Group) s.registerWorkspaceSettingRoutes(apiV1Group)
s.registerAuthRoutes(apiV1Group, secret) s.registerAuthRoutes(apiV1Group, secret)
@ -28,5 +31,8 @@ func (s *APIV1Service) Start(apiGroup *echo.Group, secret string) {
s.registerShortcutRoutes(apiV1Group) s.registerShortcutRoutes(apiV1Group)
redirectorGroup := apiGroup.Group("/o") redirectorGroup := apiGroup.Group("/o")
redirectorGroup.Use(func(next echo.HandlerFunc) echo.HandlerFunc {
return JWTMiddleware(s, next, string(secret))
})
s.registerRedirectorRoutes(redirectorGroup) s.registerRedirectorRoutes(redirectorGroup)
} }

13
internal/util/util.go Normal file
View File

@ -0,0 +1,13 @@
package util
import "strings"
// HasPrefixes returns true if the string s has any of the given prefixes.
func HasPrefixes(src string, prefixes ...string) bool {
for _, prefix := range prefixes {
if strings.HasPrefix(src, prefix) {
return true
}
}
return false
}

View File

@ -1,27 +0,0 @@
package server
import (
"strings"
"github.com/labstack/echo/v4"
)
// hasPrefixes returns true if the string s has any of the given prefixes.
func hasPrefixes(src string, prefixes ...string) bool {
for _, prefix := range prefixes {
if strings.HasPrefix(src, prefix) {
return true
}
}
return false
}
func defaultAPIRequestSkipper(c echo.Context) bool {
path := c.Path()
return hasPrefixes(path, "/api")
}
func (*Server) defaultAuthSkipper(c echo.Context) bool {
path := c.Path()
return hasPrefixes(path, "/api/v1/auth")
}

View File

@ -5,7 +5,7 @@
<meta charset="UTF-8" /> <meta charset="UTF-8" />
<meta http-equiv="X-UA-Compatible" content="IE=edge" /> <meta http-equiv="X-UA-Compatible" content="IE=edge" />
<meta name="viewport" content="width=device-width, initial-scale=1.0" /> <meta name="viewport" content="width=device-width, initial-scale=1.0" />
<title>Corgi</title> <title>Shortify</title>
</head> </head>
<body> <body>
<p>No frontend embeded.</p> <p>No frontend embeded.</p>

View File

@ -5,6 +5,7 @@ import (
"io/fs" "io/fs"
"net/http" "net/http"
"github.com/boojack/shortify/internal/util"
"github.com/labstack/echo/v4" "github.com/labstack/echo/v4"
"github.com/labstack/echo/v4/middleware" "github.com/labstack/echo/v4/middleware"
) )
@ -21,6 +22,11 @@ func getFileSystem(path string) http.FileSystem {
return http.FS(fs) return http.FS(fs)
} }
func defaultAPIRequestSkipper(c echo.Context) bool {
path := c.Path()
return util.HasPrefixes(path, "/api")
}
func embedFrontend(e *echo.Echo) { func embedFrontend(e *echo.Echo) {
// Use echo static middleware to serve the built dist folder // Use echo static middleware to serve the built dist folder
// refer: https://github.com/labstack/echo/blob/master/middleware/static.go // refer: https://github.com/labstack/echo/blob/master/middleware/static.go
@ -30,14 +36,14 @@ func embedFrontend(e *echo.Echo) {
Filesystem: getFileSystem("dist"), Filesystem: getFileSystem("dist"),
})) }))
g := e.Group("assets") assetsGroup := e.Group("assets")
g.Use(func(next echo.HandlerFunc) echo.HandlerFunc { assetsGroup.Use(func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error { return func(c echo.Context) error {
c.Response().Header().Set(echo.HeaderCacheControl, "max-age=31536000, immutable") c.Response().Header().Set(echo.HeaderCacheControl, "max-age=31536000, immutable")
return next(c) return next(c)
} }
}) })
g.Use(middleware.StaticWithConfig(middleware.StaticConfig{ assetsGroup.Use(middleware.StaticWithConfig(middleware.StaticConfig{
Skipper: defaultAPIRequestSkipper, Skipper: defaultAPIRequestSkipper,
HTML5: true, HTML5: true,
Filesystem: getFileSystem("dist/assets"), Filesystem: getFileSystem("dist/assets"),

View File

@ -63,9 +63,6 @@ func NewServer(profile *profile.Profile, store *store.Store) (*Server, error) {
apiGroup := e.Group("") apiGroup := e.Group("")
// Register API v1 routes. // Register API v1 routes.
apiV1Service := apiv1.NewAPIV1Service(profile, store) apiV1Service := apiv1.NewAPIV1Service(profile, store)
apiGroup.Use(func(next echo.HandlerFunc) echo.HandlerFunc {
return JWTMiddleware(s, next, string(secret))
})
apiV1Service.Start(apiGroup, secret) apiV1Service.Start(apiGroup, secret)
return s, nil return s, nil
@ -79,12 +76,12 @@ func (s *Server) Shutdown(ctx context.Context) {
ctx, cancel := context.WithTimeout(ctx, 10*time.Second) ctx, cancel := context.WithTimeout(ctx, 10*time.Second)
defer cancel() defer cancel()
// Shutdown echo server // Shutdown echo server.
if err := s.e.Shutdown(ctx); err != nil { if err := s.e.Shutdown(ctx); err != nil {
fmt.Printf("failed to shutdown server, error: %v\n", err) fmt.Printf("failed to shutdown server, error: %v\n", err)
} }
// Close database connection // Close database connection.
if err := s.Store.Close(); err != nil { if err := s.Store.Close(); err != nil {
fmt.Printf("failed to close database, error: %v\n", err) fmt.Printf("failed to close database, error: %v\n", err)
} }