chore: extract jwt from server to api

This commit is contained in:
Steven 2023-06-24 13:58:38 +08:00
parent c4388fb211
commit 0a811d2568
8 changed files with 41 additions and 48 deletions

View File

@ -12,14 +12,6 @@ import (
"golang.org/x/crypto/bcrypt"
)
var (
userIDContextKey = "user-id"
)
func getUserIDContextKey() string {
return userIDContextKey
}
type SignInRequest struct {
Email string `json:"email"`
Password string `json:"password"`

View File

@ -1,4 +1,4 @@
package server
package v1
import (
"fmt"
@ -7,6 +7,7 @@ import (
"strings"
"time"
"github.com/boojack/shortify/internal/util"
"github.com/boojack/shortify/server/auth"
"github.com/boojack/shortify/store"
"github.com/golang-jwt/jwt/v4"
@ -71,19 +72,19 @@ func audienceContains(audience jwt.ClaimStrings, token string) bool {
// JWTMiddleware validates the access token.
// If the access token is about to expire or has expired and the request has a valid refresh token, it
// will try to generate new access token and refresh token.
func JWTMiddleware(server *Server, next echo.HandlerFunc, secret string) echo.HandlerFunc {
func JWTMiddleware(server *APIV1Service, next echo.HandlerFunc, secret string) echo.HandlerFunc {
return func(c echo.Context) error {
path := c.Path()
method := c.Request().Method
if server.defaultAuthSkipper(c) {
if defaultAuthSkipper(c) {
return next(c)
}
token := findAccessToken(c)
if token == "" {
// When the request is not authenticated, we allow the user to access the shortcut endpoints for those public shortcuts.
if hasPrefixes(path, "/api/v1/status", "/o/*") && method == http.MethodGet {
if util.HasPrefixes(path, "/api/v1/status", "/o/*") && method == http.MethodGet {
return next(c)
}
return echo.NewHTTPError(http.StatusUnauthorized, "Missing access token")
@ -194,3 +195,8 @@ func JWTMiddleware(server *Server, next echo.HandlerFunc, secret string) echo.Ha
return next(c)
}
}
func defaultAuthSkipper(c echo.Context) bool {
path := c.Path()
return util.HasPrefixes(path, "/api/v1/auth")
}

View File

@ -21,6 +21,9 @@ func NewAPIV1Service(profile *profile.Profile, store *store.Store) *APIV1Service
func (s *APIV1Service) Start(apiGroup *echo.Group, secret string) {
apiV1Group := apiGroup.Group("/api/v1")
apiV1Group.Use(func(next echo.HandlerFunc) echo.HandlerFunc {
return JWTMiddleware(s, next, string(secret))
})
s.registerSystemRoutes(apiV1Group)
s.registerWorkspaceSettingRoutes(apiV1Group)
s.registerAuthRoutes(apiV1Group, secret)
@ -28,5 +31,8 @@ func (s *APIV1Service) Start(apiGroup *echo.Group, secret string) {
s.registerShortcutRoutes(apiV1Group)
redirectorGroup := apiGroup.Group("/o")
redirectorGroup.Use(func(next echo.HandlerFunc) echo.HandlerFunc {
return JWTMiddleware(s, next, string(secret))
})
s.registerRedirectorRoutes(redirectorGroup)
}

13
internal/util/util.go Normal file
View File

@ -0,0 +1,13 @@
package util
import "strings"
// HasPrefixes returns true if the string s has any of the given prefixes.
func HasPrefixes(src string, prefixes ...string) bool {
for _, prefix := range prefixes {
if strings.HasPrefix(src, prefix) {
return true
}
}
return false
}

View File

@ -1,27 +0,0 @@
package server
import (
"strings"
"github.com/labstack/echo/v4"
)
// hasPrefixes returns true if the string s has any of the given prefixes.
func hasPrefixes(src string, prefixes ...string) bool {
for _, prefix := range prefixes {
if strings.HasPrefix(src, prefix) {
return true
}
}
return false
}
func defaultAPIRequestSkipper(c echo.Context) bool {
path := c.Path()
return hasPrefixes(path, "/api")
}
func (*Server) defaultAuthSkipper(c echo.Context) bool {
path := c.Path()
return hasPrefixes(path, "/api/v1/auth")
}

View File

@ -5,7 +5,7 @@
<meta charset="UTF-8" />
<meta http-equiv="X-UA-Compatible" content="IE=edge" />
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
<title>Corgi</title>
<title>Shortify</title>
</head>
<body>
<p>No frontend embeded.</p>

View File

@ -5,6 +5,7 @@ import (
"io/fs"
"net/http"
"github.com/boojack/shortify/internal/util"
"github.com/labstack/echo/v4"
"github.com/labstack/echo/v4/middleware"
)
@ -21,6 +22,11 @@ func getFileSystem(path string) http.FileSystem {
return http.FS(fs)
}
func defaultAPIRequestSkipper(c echo.Context) bool {
path := c.Path()
return util.HasPrefixes(path, "/api")
}
func embedFrontend(e *echo.Echo) {
// Use echo static middleware to serve the built dist folder
// refer: https://github.com/labstack/echo/blob/master/middleware/static.go
@ -30,14 +36,14 @@ func embedFrontend(e *echo.Echo) {
Filesystem: getFileSystem("dist"),
}))
g := e.Group("assets")
g.Use(func(next echo.HandlerFunc) echo.HandlerFunc {
assetsGroup := e.Group("assets")
assetsGroup.Use(func(next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error {
c.Response().Header().Set(echo.HeaderCacheControl, "max-age=31536000, immutable")
return next(c)
}
})
g.Use(middleware.StaticWithConfig(middleware.StaticConfig{
assetsGroup.Use(middleware.StaticWithConfig(middleware.StaticConfig{
Skipper: defaultAPIRequestSkipper,
HTML5: true,
Filesystem: getFileSystem("dist/assets"),

View File

@ -63,9 +63,6 @@ func NewServer(profile *profile.Profile, store *store.Store) (*Server, error) {
apiGroup := e.Group("")
// Register API v1 routes.
apiV1Service := apiv1.NewAPIV1Service(profile, store)
apiGroup.Use(func(next echo.HandlerFunc) echo.HandlerFunc {
return JWTMiddleware(s, next, string(secret))
})
apiV1Service.Start(apiGroup, secret)
return s, nil
@ -79,12 +76,12 @@ func (s *Server) Shutdown(ctx context.Context) {
ctx, cancel := context.WithTimeout(ctx, 10*time.Second)
defer cancel()
// Shutdown echo server
// Shutdown echo server.
if err := s.e.Shutdown(ctx); err != nil {
fmt.Printf("failed to shutdown server, error: %v\n", err)
}
// Close database connection
// Close database connection.
if err := s.Store.Close(); err != nil {
fmt.Printf("failed to close database, error: %v\n", err)
}