feat: refactor code structure

This commit is contained in:
Steven
2023-02-23 08:22:06 +08:00
parent f77a84a649
commit 0fbbcae872
36 changed files with 768 additions and 936 deletions

View File

@ -15,6 +15,7 @@ import (
var (
userIDContextKey = "user-id"
sessionName = "corgi_session"
)
func getUserIDContextKey() string {
@ -22,7 +23,7 @@ func getUserIDContextKey() string {
}
func setUserSession(ctx echo.Context, user *api.User) error {
sess, _ := session.Get("session", ctx)
sess, _ := session.Get(sessionName, ctx)
sess.Options = &sessions.Options{
Path: "/",
MaxAge: 1000 * 3600 * 24 * 30,
@ -37,7 +38,7 @@ func setUserSession(ctx echo.Context, user *api.User) error {
}
func removeUserSession(ctx echo.Context) error {
sess, _ := session.Get("session", ctx)
sess, _ := session.Get(sessionName, ctx)
sess.Options = &sessions.Options{
Path: "/",
MaxAge: 0,
@ -55,56 +56,31 @@ func aclMiddleware(s *Server, next echo.HandlerFunc) echo.HandlerFunc {
return func(c echo.Context) error {
ctx := c.Request().Context()
path := c.Path()
// Skip auth.
if common.HasPrefixes(path, "/api/auth") {
if s.defaultAuthSkipper(c) {
return next(c)
}
if common.HasPrefixes(path, "/api/ping", "/api/status", "/api/user/:id") && c.Request().Method == http.MethodGet {
return next(c)
}
{
// If there is openId in query string and related user is found, then skip auth.
openID := c.QueryParam("openId")
if openID != "" {
userFind := &api.UserFind{
OpenID: &openID,
}
user, err := s.Store.FindUser(ctx, userFind)
if err != nil && common.ErrorCode(err) != common.NotFound {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find user by open_id").SetInternal(err)
}
if user != nil {
// Stores userID into context.
c.Set(getUserIDContextKey(), user.ID)
return next(c)
sess, _ := session.Get(sessionName, c)
userIDValue := sess.Values[userIDContextKey]
if userIDValue != nil {
userID, _ := strconv.Atoi(fmt.Sprintf("%v", userIDValue))
userFind := &api.UserFind{
ID: &userID,
}
user, err := s.Store.FindUser(ctx, userFind)
if err != nil && common.ErrorCode(err) != common.NotFound {
return echo.NewHTTPError(http.StatusInternalServerError, fmt.Sprintf("Failed to find user by ID: %d", userID)).SetInternal(err)
}
if user != nil {
if user.RowStatus == api.Archived {
return echo.NewHTTPError(http.StatusForbidden, fmt.Sprintf("User has been archived with email %s", user.Email))
}
c.Set(getUserIDContextKey(), userID)
}
}
{
sess, _ := session.Get("session", c)
userIDValue := sess.Values[userIDContextKey]
if userIDValue != nil {
userID, _ := strconv.Atoi(fmt.Sprintf("%v", userIDValue))
userFind := &api.UserFind{
ID: &userID,
}
user, err := s.Store.FindUser(ctx, userFind)
if err != nil && common.ErrorCode(err) != common.NotFound {
return echo.NewHTTPError(http.StatusInternalServerError, fmt.Sprintf("Failed to find user by ID: %d", userID)).SetInternal(err)
}
if user != nil {
if user.RowStatus == api.Archived {
return echo.NewHTTPError(http.StatusForbidden, fmt.Sprintf("User has been archived with email %s", user.Email))
}
c.Set(getUserIDContextKey(), userID)
}
}
}
if common.HasPrefixes(path, "/go/:shortcutName", "/api/workspace/:workspaceName/shortcut/:shortcutName") && c.Request().Method == http.MethodGet {
if common.HasPrefixes(path, "/api/ping", "/api/status") && c.Request().Method == http.MethodGet {
return next(c)
}

View File

@ -42,12 +42,7 @@ func (s *Server) registerAuthRoutes(g *echo.Group) {
if err = setUserSession(c, user); err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to set signin session").SetInternal(err)
}
c.Response().Header().Set(echo.HeaderContentType, echo.MIMEApplicationJSONCharsetUTF8)
if err := json.NewEncoder(c.Response().Writer).Encode(composeResponse(user)); err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to encode user response").SetInternal(err)
}
return nil
return c.JSON(http.StatusOK, composeResponse(user))
})
g.POST("/auth/signup", func(c echo.Context) error {
@ -58,10 +53,10 @@ func (s *Server) registerAuthRoutes(g *echo.Group) {
}
userCreate := &api.UserCreate{
Email: signup.Email,
Name: signup.Name,
Password: signup.Password,
OpenID: common.GenUUID(),
Email: signup.Email,
DisplayName: signup.DisplayName,
Password: signup.Password,
OpenID: common.GenUUID(),
}
if err := userCreate.Validate(); err != nil {
return echo.NewHTTPError(http.StatusBadRequest, "Invalid user create format.").SetInternal(err)

View File

@ -1,5 +1,11 @@
package server
import (
"github.com/boojack/corgi/api"
"github.com/boojack/corgi/common"
"github.com/labstack/echo/v4"
)
func composeResponse(data interface{}) interface{} {
type R struct {
Data interface{} `json:"data"`
@ -9,3 +15,37 @@ func composeResponse(data interface{}) interface{} {
Data: data,
}
}
func defaultAPIRequestSkipper(c echo.Context) bool {
path := c.Path()
return common.HasPrefixes(path, "/api", "/o")
}
func (server *Server) defaultAuthSkipper(c echo.Context) bool {
ctx := c.Request().Context()
path := c.Path()
// Skip auth.
if common.HasPrefixes(path, "/api/auth") {
return true
}
// If there is openId in query string and related user is found, then skip auth.
openID := c.QueryParam("openId")
if openID != "" {
userFind := &api.UserFind{
OpenID: &openID,
}
user, err := server.Store.FindUser(ctx, userFind)
if err != nil && common.ErrorCode(err) != common.NotFound {
return false
}
if user != nil {
// Stores userID into context.
c.Set(getUserIDContextKey(), user.ID)
return true
}
}
return false
}

View File

@ -5,7 +5,6 @@ import (
"io/fs"
"net/http"
"github.com/boojack/corgi/common"
"github.com/labstack/echo/v4"
"github.com/labstack/echo/v4/middleware"
)
@ -22,16 +21,11 @@ func getFileSystem(path string) http.FileSystem {
return http.FS(fs)
}
func skipper(c echo.Context) bool {
path := c.Path()
return common.HasPrefixes(path, "/api", "/o")
}
func embedFrontend(e *echo.Echo) {
// Use echo static middleware to serve the built dist folder
// refer: https://github.com/labstack/echo/blob/master/middleware/static.go
e.Use(middleware.StaticWithConfig(middleware.StaticConfig{
Skipper: skipper,
Skipper: defaultAPIRequestSkipper,
HTML5: true,
Filesystem: getFileSystem("dist"),
}))
@ -44,7 +38,7 @@ func embedFrontend(e *echo.Echo) {
}
})
g.Use(middleware.StaticWithConfig(middleware.StaticConfig{
Skipper: skipper,
Skipper: defaultAPIRequestSkipper,
HTML5: true,
Filesystem: getFileSystem("dist/assets"),
}))

View File

@ -1,13 +1,13 @@
package profile
import (
"flag"
"fmt"
"os"
"path/filepath"
"strings"
"github.com/boojack/corgi/server/version"
"github.com/spf13/viper"
)
// Profile is the configuration to start main server.
@ -47,10 +47,10 @@ func checkDSN(dataDir string) (string, error) {
// GetDevProfile will return a profile for dev or prod.
func GetProfile() (*Profile, error) {
profile := Profile{}
flag.StringVar(&profile.Mode, "mode", "dev", "mode of server")
flag.IntVar(&profile.Port, "port", 8081, "port of server")
flag.StringVar(&profile.Data, "data", "", "data directory")
flag.Parse()
err := viper.Unmarshal(&profile)
if err != nil {
return nil, err
}
if profile.Mode != "dev" && profile.Mode != "prod" {
profile.Mode = "dev"
@ -69,6 +69,5 @@ func GetProfile() (*Profile, error) {
profile.Data = dataDir
profile.DSN = fmt.Sprintf("%s/corgi_%s.db", dataDir, profile.Mode)
profile.Version = version.GetCurrentVersion(profile.Mode)
return &profile, nil
}

View File

@ -1,11 +1,15 @@
package server
import (
"context"
"database/sql"
"fmt"
"time"
"github.com/boojack/corgi/server/profile"
"github.com/boojack/corgi/store"
"github.com/boojack/corgi/store/db"
"github.com/pkg/errors"
"github.com/gorilla/securecookie"
"github.com/gorilla/sessions"
@ -15,25 +19,40 @@ import (
)
type Server struct {
e *echo.Echo
e *echo.Echo
db *sql.DB
Profile *profile.Profile
Store *store.Store
Store *store.Store
}
func NewServer(profile *profile.Profile) *Server {
func NewServer(ctx context.Context, profile *profile.Profile) (*Server, error) {
e := echo.New()
e.Debug = true
e.HideBanner = true
e.HidePort = true
db := db.NewDB(profile)
if err := db.Open(ctx); err != nil {
return nil, errors.Wrap(err, "cannot open db")
}
s := &Server{
e: e,
db: db.DBInstance,
Profile: profile,
}
storeInstance := store.New(db.DBInstance, profile)
s.Store = storeInstance
e.Use(middleware.LoggerWithConfig(middleware.LoggerConfig{
Format: `{"time":"${time_rfc3339}",` +
`"method":"${method}","uri":"${uri}",` +
`"status":${status},"error":"${error}"}` + "\n",
}))
e.Use(middleware.Gzip())
e.Use(middleware.CORS())
e.Use(middleware.TimeoutWithConfig(middleware.TimeoutConfig{
@ -51,11 +70,6 @@ func NewServer(profile *profile.Profile) *Server {
}
e.Use(session.Middleware(sessions.NewCookieStore(secret)))
s := &Server{
e: e,
Profile: profile,
}
redirectGroup := e.Group("/o")
redirectGroup.Use(func(next echo.HandlerFunc) echo.HandlerFunc {
return aclMiddleware(s, next)
@ -73,9 +87,26 @@ func NewServer(profile *profile.Profile) *Server {
s.registerWorkspaceUserRoutes(apiGroup)
s.registerShortcutRoutes(apiGroup)
return s
return s, nil
}
func (server *Server) Run() error {
return server.e.Start(fmt.Sprintf(":%d", server.Profile.Port))
func (s *Server) Start(_ context.Context) error {
return s.e.Start(fmt.Sprintf(":%d", s.Profile.Port))
}
func (s *Server) Shutdown(ctx context.Context) {
ctx, cancel := context.WithTimeout(ctx, 10*time.Second)
defer cancel()
// Shutdown echo server
if err := s.e.Shutdown(ctx); err != nil {
fmt.Printf("failed to shutdown server, error: %v\n", err)
}
// Close database connection
if err := s.db.Close(); err != nil {
fmt.Printf("failed to close database, error: %v\n", err)
}
fmt.Printf("server stopped properly\n")
}

View File

@ -1,8 +1,10 @@
package version
import (
"strconv"
"fmt"
"strings"
"golang.org/x/mod/semver"
)
// Version is the service current released version.
@ -29,39 +31,31 @@ func GetMinorVersion(version string) string {
func GetSchemaVersion(version string) string {
minorVersion := GetMinorVersion(version)
return minorVersion + ".0"
}
// convSemanticVersionToInt converts version string to int.
func convSemanticVersionToInt(version string) int {
versionList := strings.Split(version, ".")
if len(versionList) < 3 {
return 0
}
major, err := strconv.Atoi(versionList[0])
if err != nil {
return 0
}
minor, err := strconv.Atoi(versionList[1])
if err != nil {
return 0
}
patch, err := strconv.Atoi(versionList[2])
if err != nil {
return 0
}
return major*10000 + minor*100 + patch
}
// IsVersionGreaterThanOrEqualTo returns true if version is greater than or equal to target.
func IsVersionGreaterOrEqualThan(version, target string) bool {
return convSemanticVersionToInt(version) >= convSemanticVersionToInt(target)
return semver.Compare(fmt.Sprintf("v%s", version), fmt.Sprintf("v%s", target)) > -1
}
// IsVersionGreaterThan returns true if version is greater than target.
func IsVersionGreaterThan(version, target string) bool {
return convSemanticVersionToInt(version) > convSemanticVersionToInt(target)
return semver.Compare(fmt.Sprintf("v%s", version), fmt.Sprintf("v%s", target)) > 0
}
type SortVersion []string
func (s SortVersion) Len() int {
return len(s)
}
func (s SortVersion) Swap(i, j int) {
s[i], s[j] = s[j], s[i]
}
func (s SortVersion) Less(i, j int) bool {
v1 := fmt.Sprintf("v%s", s[i])
v2 := fmt.Sprintf("v%s", s[j])
return semver.Compare(v1, v2) == -1
}

View File

@ -0,0 +1,93 @@
package version
import (
"sort"
"testing"
"github.com/stretchr/testify/assert"
)
func TestIsVersionGreaterOrEqualThan(t *testing.T) {
tests := []struct {
version string
target string
want bool
}{
{
version: "0.9.1",
target: "0.9.1",
want: true,
},
{
version: "0.10.0",
target: "0.9.1",
want: true,
},
{
version: "0.9.0",
target: "0.9.1",
want: false,
},
}
for _, test := range tests {
result := IsVersionGreaterOrEqualThan(test.version, test.target)
if result != test.want {
t.Errorf("got result %v, want %v.", result, test.want)
}
}
}
func TestIsVersionGreaterThan(t *testing.T) {
tests := []struct {
version string
target string
want bool
}{
{
version: "0.9.1",
target: "0.9.1",
want: false,
},
{
version: "0.10.0",
target: "0.8.0",
want: true,
},
{
version: "0.8.0",
target: "0.10.0",
want: false,
},
{
version: "0.9.0",
target: "0.9.1",
want: false,
},
}
for _, test := range tests {
result := IsVersionGreaterThan(test.version, test.target)
if result != test.want {
t.Errorf("got result %v, want %v.", result, test.want)
}
}
}
func TestSortVersion(t *testing.T) {
tests := []struct {
versionList []string
want []string
}{
{
versionList: []string{"0.9.1", "0.10.0", "0.8.0"},
want: []string{"0.8.0", "0.9.1", "0.10.0"},
},
{
versionList: []string{"1.9.1", "0.9.1", "0.10.0", "0.8.0"},
want: []string{"0.8.0", "0.9.1", "0.10.0", "1.9.1"},
},
}
for _, test := range tests {
sort.Sort(SortVersion(test.versionList))
assert.Equal(t, test.versionList, test.want)
}
}