diff --git a/api/v1/auth.go b/api/v1/auth.go index 01faac9..bf11e59 100644 --- a/api/v1/auth.go +++ b/api/v1/auth.go @@ -12,14 +12,6 @@ import ( "golang.org/x/crypto/bcrypt" ) -var ( - userIDContextKey = "user-id" -) - -func getUserIDContextKey() string { - return userIDContextKey -} - type SignInRequest struct { Email string `json:"email"` Password string `json:"password"` diff --git a/server/jwt.go b/api/v1/jwt.go similarity index 94% rename from server/jwt.go rename to api/v1/jwt.go index 2c87bae..eb11fac 100644 --- a/server/jwt.go +++ b/api/v1/jwt.go @@ -1,4 +1,4 @@ -package server +package v1 import ( "fmt" @@ -7,6 +7,7 @@ import ( "strings" "time" + "github.com/boojack/shortify/internal/util" "github.com/boojack/shortify/server/auth" "github.com/boojack/shortify/store" "github.com/golang-jwt/jwt/v4" @@ -71,19 +72,19 @@ func audienceContains(audience jwt.ClaimStrings, token string) bool { // JWTMiddleware validates the access token. // If the access token is about to expire or has expired and the request has a valid refresh token, it // will try to generate new access token and refresh token. -func JWTMiddleware(server *Server, next echo.HandlerFunc, secret string) echo.HandlerFunc { +func JWTMiddleware(server *APIV1Service, next echo.HandlerFunc, secret string) echo.HandlerFunc { return func(c echo.Context) error { path := c.Path() method := c.Request().Method - if server.defaultAuthSkipper(c) { + if defaultAuthSkipper(c) { return next(c) } token := findAccessToken(c) if token == "" { // When the request is not authenticated, we allow the user to access the shortcut endpoints for those public shortcuts. - if hasPrefixes(path, "/api/v1/status", "/o/*") && method == http.MethodGet { + if util.HasPrefixes(path, "/api/v1/status", "/o/*") && method == http.MethodGet { return next(c) } return echo.NewHTTPError(http.StatusUnauthorized, "Missing access token") @@ -194,3 +195,8 @@ func JWTMiddleware(server *Server, next echo.HandlerFunc, secret string) echo.Ha return next(c) } } + +func defaultAuthSkipper(c echo.Context) bool { + path := c.Path() + return util.HasPrefixes(path, "/api/v1/auth") +} diff --git a/api/v1/v1.go b/api/v1/v1.go index 150d892..b1f02fc 100644 --- a/api/v1/v1.go +++ b/api/v1/v1.go @@ -21,6 +21,9 @@ func NewAPIV1Service(profile *profile.Profile, store *store.Store) *APIV1Service func (s *APIV1Service) Start(apiGroup *echo.Group, secret string) { apiV1Group := apiGroup.Group("/api/v1") + apiV1Group.Use(func(next echo.HandlerFunc) echo.HandlerFunc { + return JWTMiddleware(s, next, string(secret)) + }) s.registerSystemRoutes(apiV1Group) s.registerWorkspaceSettingRoutes(apiV1Group) s.registerAuthRoutes(apiV1Group, secret) @@ -28,5 +31,8 @@ func (s *APIV1Service) Start(apiGroup *echo.Group, secret string) { s.registerShortcutRoutes(apiV1Group) redirectorGroup := apiGroup.Group("/o") + redirectorGroup.Use(func(next echo.HandlerFunc) echo.HandlerFunc { + return JWTMiddleware(s, next, string(secret)) + }) s.registerRedirectorRoutes(redirectorGroup) } diff --git a/internal/util/util.go b/internal/util/util.go new file mode 100644 index 0000000..1547b05 --- /dev/null +++ b/internal/util/util.go @@ -0,0 +1,13 @@ +package util + +import "strings" + +// HasPrefixes returns true if the string s has any of the given prefixes. +func HasPrefixes(src string, prefixes ...string) bool { + for _, prefix := range prefixes { + if strings.HasPrefix(src, prefix) { + return true + } + } + return false +} diff --git a/server/common.go b/server/common.go deleted file mode 100644 index 1ee22d1..0000000 --- a/server/common.go +++ /dev/null @@ -1,27 +0,0 @@ -package server - -import ( - "strings" - - "github.com/labstack/echo/v4" -) - -// hasPrefixes returns true if the string s has any of the given prefixes. -func hasPrefixes(src string, prefixes ...string) bool { - for _, prefix := range prefixes { - if strings.HasPrefix(src, prefix) { - return true - } - } - return false -} - -func defaultAPIRequestSkipper(c echo.Context) bool { - path := c.Path() - return hasPrefixes(path, "/api") -} - -func (*Server) defaultAuthSkipper(c echo.Context) bool { - path := c.Path() - return hasPrefixes(path, "/api/v1/auth") -} diff --git a/server/dist/index.html b/server/dist/index.html index 45533a7..7c1d8fb 100644 --- a/server/dist/index.html +++ b/server/dist/index.html @@ -5,7 +5,7 @@ - Corgi + Shortify

No frontend embeded.

diff --git a/server/embed_frontend.go b/server/embed_frontend.go index 505e9ae..33616d9 100644 --- a/server/embed_frontend.go +++ b/server/embed_frontend.go @@ -5,6 +5,7 @@ import ( "io/fs" "net/http" + "github.com/boojack/shortify/internal/util" "github.com/labstack/echo/v4" "github.com/labstack/echo/v4/middleware" ) @@ -21,6 +22,11 @@ func getFileSystem(path string) http.FileSystem { return http.FS(fs) } +func defaultAPIRequestSkipper(c echo.Context) bool { + path := c.Path() + return util.HasPrefixes(path, "/api") +} + func embedFrontend(e *echo.Echo) { // Use echo static middleware to serve the built dist folder // refer: https://github.com/labstack/echo/blob/master/middleware/static.go @@ -30,14 +36,14 @@ func embedFrontend(e *echo.Echo) { Filesystem: getFileSystem("dist"), })) - g := e.Group("assets") - g.Use(func(next echo.HandlerFunc) echo.HandlerFunc { + assetsGroup := e.Group("assets") + assetsGroup.Use(func(next echo.HandlerFunc) echo.HandlerFunc { return func(c echo.Context) error { c.Response().Header().Set(echo.HeaderCacheControl, "max-age=31536000, immutable") return next(c) } }) - g.Use(middleware.StaticWithConfig(middleware.StaticConfig{ + assetsGroup.Use(middleware.StaticWithConfig(middleware.StaticConfig{ Skipper: defaultAPIRequestSkipper, HTML5: true, Filesystem: getFileSystem("dist/assets"), diff --git a/server/server.go b/server/server.go index c3c504d..d6a99ce 100644 --- a/server/server.go +++ b/server/server.go @@ -63,9 +63,6 @@ func NewServer(profile *profile.Profile, store *store.Store) (*Server, error) { apiGroup := e.Group("") // Register API v1 routes. apiV1Service := apiv1.NewAPIV1Service(profile, store) - apiGroup.Use(func(next echo.HandlerFunc) echo.HandlerFunc { - return JWTMiddleware(s, next, string(secret)) - }) apiV1Service.Start(apiGroup, secret) return s, nil @@ -79,12 +76,12 @@ func (s *Server) Shutdown(ctx context.Context) { ctx, cancel := context.WithTimeout(ctx, 10*time.Second) defer cancel() - // Shutdown echo server + // Shutdown echo server. if err := s.e.Shutdown(ctx); err != nil { fmt.Printf("failed to shutdown server, error: %v\n", err) } - // Close database connection + // Close database connection. if err := s.Store.Close(); err != nil { fmt.Printf("failed to close database, error: %v\n", err) }