diff --git a/bin/slash/main.go b/bin/slash/main.go index a25e84b..cf094ee 100644 --- a/bin/slash/main.go +++ b/bin/slash/main.go @@ -30,20 +30,27 @@ var ( mode string port int data string + driver string + dsn string rootCmd = &cobra.Command{ Use: "slash", Short: `An open source, self-hosted bookmarks and link sharing platform.`, Run: func(_cmd *cobra.Command, _args []string) { ctx, cancel := context.WithCancel(context.Background()) - db := db.NewDB(serverProfile) - if err := db.Open(ctx); err != nil { + dbDriver, err := db.NewDBDriver(serverProfile) + if err != nil { cancel() - log.Error("failed to open database", zap.Error(err)) + log.Error("failed to create db driver", zap.Error(err)) + return + } + if err := dbDriver.Migrate(ctx); err != nil { + cancel() + log.Error("failed to migrate db", zap.Error(err)) return } - storeInstance := store.New(db.DBInstance, serverProfile) + storeInstance := store.New(dbDriver, serverProfile) s, err := server.NewServer(ctx, serverProfile, storeInstance) if err != nil { cancel() @@ -92,6 +99,8 @@ func init() { rootCmd.PersistentFlags().StringVarP(&mode, "mode", "m", "demo", `mode of server, can be "prod" or "dev" or "demo"`) rootCmd.PersistentFlags().IntVarP(&port, "port", "p", 8082, "port of server") rootCmd.PersistentFlags().StringVarP(&data, "data", "d", "", "data directory") + rootCmd.PersistentFlags().StringVarP(&driver, "driver", "", "", "database driver") + rootCmd.PersistentFlags().StringVarP(&dsn, "dsn", "", "", "database source name(aka. DSN)") err := viper.BindPFlag("mode", rootCmd.PersistentFlags().Lookup("mode")) if err != nil { @@ -105,9 +114,18 @@ func init() { if err != nil { panic(err) } + err = viper.BindPFlag("driver", rootCmd.PersistentFlags().Lookup("driver")) + if err != nil { + panic(err) + } + err = viper.BindPFlag("dsn", rootCmd.PersistentFlags().Lookup("dsn")) + if err != nil { + panic(err) + } viper.SetDefault("mode", "demo") viper.SetDefault("port", 8082) + viper.SetDefault("driver", "sqlite") viper.SetEnvPrefix("slash") } diff --git a/server/profile/profile.go b/server/profile/profile.go index 23ff1ec..7091304 100644 --- a/server/profile/profile.go +++ b/server/profile/profile.go @@ -23,6 +23,9 @@ type Profile struct { Data string `json:"-"` // DSN points to where slash stores its own data DSN string `json:"-"` + // Driver is the database driver + // sqlite, mysql + Driver string `json:"-"` // Version is the current version of server Version string `json:"version"` } diff --git a/store/activity.go b/store/activity.go index eee14be..d283c7d 100644 --- a/store/activity.go +++ b/store/activity.go @@ -2,7 +2,6 @@ package store import ( "context" - "strings" ) type ActivityType string @@ -63,82 +62,11 @@ type FindActivity struct { } func (s *Store) CreateActivity(ctx context.Context, create *Activity) (*Activity, error) { - stmt := ` - INSERT INTO activity ( - creator_id, - type, - level, - payload - ) - VALUES (?, ?, ?, ?) - RETURNING id, created_ts - ` - if err := s.db.QueryRowContext(ctx, stmt, - create.CreatorID, - create.Type.String(), - create.Level.String(), - create.Payload, - ).Scan( - &create.ID, - &create.CreatedTs, - ); err != nil { - return nil, err - } - - activity := create - return activity, nil + return s.driver.CreateActivity(ctx, create) } func (s *Store) ListActivities(ctx context.Context, find *FindActivity) ([]*Activity, error) { - where, args := []string{"1 = 1"}, []any{} - if find.Type != "" { - where, args = append(where, "type = ?"), append(args, find.Type.String()) - } - if find.Level != "" { - where, args = append(where, "level = ?"), append(args, find.Level.String()) - } - if find.Where != nil { - where = append(where, find.Where...) - } - - query := ` - SELECT - id, - creator_id, - created_ts, - type, - level, - payload - FROM activity - WHERE ` + strings.Join(where, " AND ") - rows, err := s.db.QueryContext(ctx, query, args...) - if err != nil { - return nil, err - } - defer rows.Close() - - list := []*Activity{} - for rows.Next() { - activity := &Activity{} - if err := rows.Scan( - &activity.ID, - &activity.CreatorID, - &activity.CreatedTs, - &activity.Type, - &activity.Level, - &activity.Payload, - ); err != nil { - return nil, err - } - - list = append(list, activity) - } - - if err := rows.Err(); err != nil { - return nil, err - } - - return list, nil + return s.driver.ListActivities(ctx, find) } func (s *Store) GetActivity(ctx context.Context, find *FindActivity) (*Activity, error) { diff --git a/store/collection.go b/store/collection.go index 4bac2da..ae60090 100644 --- a/store/collection.go +++ b/store/collection.go @@ -2,12 +2,7 @@ package store import ( "context" - "fmt" - "strings" - "github.com/pkg/errors" - - "github.com/yourselfhosted/slash/internal/util" storepb "github.com/yourselfhosted/slash/proto/gen/store" ) @@ -35,165 +30,15 @@ type DeleteCollection struct { } func (s *Store) CreateCollection(ctx context.Context, create *storepb.Collection) (*storepb.Collection, error) { - set := []string{"creator_id", "name", "title", "description", "shortcut_ids", "visibility"} - args := []any{create.CreatorId, create.Name, create.Title, create.Description, strings.Trim(strings.Join(strings.Fields(fmt.Sprint(create.ShortcutIds)), ","), "[]"), create.Visibility.String()} - placeholder := []string{"?", "?", "?", "?", "?", "?"} - - stmt := ` - INSERT INTO collection ( - ` + strings.Join(set, ", ") + ` - ) - VALUES (` + strings.Join(placeholder, ",") + `) - RETURNING id, created_ts, updated_ts - ` - if err := s.db.QueryRowContext(ctx, stmt, args...).Scan( - &create.Id, - &create.CreatedTs, - &create.UpdatedTs, - ); err != nil { - return nil, err - } - collection := create - return collection, nil + return s.driver.CreateCollection(ctx, create) } func (s *Store) UpdateCollection(ctx context.Context, update *UpdateCollection) (*storepb.Collection, error) { - set, args := []string{}, []any{} - if update.Name != nil { - set, args = append(set, "name = ?"), append(args, *update.Name) - } - if update.Title != nil { - set, args = append(set, "title = ?"), append(args, *update.Title) - } - if update.Description != nil { - set, args = append(set, "description = ?"), append(args, *update.Description) - } - if update.ShortcutIDs != nil { - set, args = append(set, "shortcut_ids = ?"), append(args, strings.Trim(strings.Join(strings.Fields(fmt.Sprint(update.ShortcutIDs)), ","), "[]")) - } - if update.Visibility != nil { - set, args = append(set, "visibility = ?"), append(args, update.Visibility.String()) - } - if len(set) == 0 { - return nil, errors.New("no update specified") - } - args = append(args, update.ID) - - stmt := ` - UPDATE collection - SET - ` + strings.Join(set, ", ") + ` - WHERE - id = ? - RETURNING id, creator_id, created_ts, updated_ts, name, title, description, shortcut_ids, visibility - ` - collection := &storepb.Collection{} - var shortcutIDs, visibility string - if err := s.db.QueryRowContext(ctx, stmt, args...).Scan( - &collection.Id, - &collection.CreatorId, - &collection.CreatedTs, - &collection.UpdatedTs, - &collection.Name, - &collection.Title, - &collection.Description, - &shortcutIDs, - &visibility, - ); err != nil { - return nil, err - } - - collection.ShortcutIds = []int32{} - if shortcutIDs != "" { - for _, idStr := range strings.Split(shortcutIDs, ",") { - shortcutID, err := util.ConvertStringToInt32(idStr) - if err != nil { - return nil, errors.Wrap(err, "failed to convert shortcut id") - } - collection.ShortcutIds = append(collection.ShortcutIds, shortcutID) - } - } - collection.Visibility = convertVisibilityStringToStorepb(visibility) - return collection, nil + return s.driver.UpdateCollection(ctx, update) } func (s *Store) ListCollections(ctx context.Context, find *FindCollection) ([]*storepb.Collection, error) { - where, args := []string{"1 = 1"}, []any{} - if v := find.ID; v != nil { - where, args = append(where, "id = ?"), append(args, *v) - } - if v := find.CreatorID; v != nil { - where, args = append(where, "creator_id = ?"), append(args, *v) - } - if v := find.Name; v != nil { - where, args = append(where, "name = ?"), append(args, *v) - } - if v := find.VisibilityList; len(v) != 0 { - list := []string{} - for _, visibility := range v { - list = append(list, fmt.Sprintf("$%d", len(args)+1)) - args = append(args, visibility) - } - where = append(where, fmt.Sprintf("visibility in (%s)", strings.Join(list, ","))) - } - - rows, err := s.db.QueryContext(ctx, ` - SELECT - id, - creator_id, - created_ts, - updated_ts, - name, - title, - description, - shortcut_ids, - visibility - FROM collection - WHERE `+strings.Join(where, " AND ")+` - ORDER BY created_ts DESC`, - args..., - ) - if err != nil { - return nil, err - } - defer rows.Close() - - list := make([]*storepb.Collection, 0) - for rows.Next() { - collection := &storepb.Collection{} - var shortcutIDs, visibility string - if err := rows.Scan( - &collection.Id, - &collection.CreatorId, - &collection.CreatedTs, - &collection.UpdatedTs, - &collection.Name, - &collection.Title, - &collection.Description, - &shortcutIDs, - &visibility, - ); err != nil { - return nil, err - } - - collection.ShortcutIds = []int32{} - if shortcutIDs != "" { - for _, idStr := range strings.Split(shortcutIDs, ",") { - shortcutID, err := util.ConvertStringToInt32(idStr) - if err != nil { - return nil, errors.Wrap(err, "failed to convert shortcut id") - } - collection.ShortcutIds = append(collection.ShortcutIds, shortcutID) - } - } - collection.Visibility = storepb.Visibility(storepb.Visibility_value[visibility]) - list = append(list, collection) - } - - if err := rows.Err(); err != nil { - return nil, err - } - return list, nil + return s.driver.ListCollections(ctx, find) } func (s *Store) GetCollection(ctx context.Context, find *FindCollection) (*storepb.Collection, error) { @@ -211,9 +56,5 @@ func (s *Store) GetCollection(ctx context.Context, find *FindCollection) (*store } func (s *Store) DeleteCollection(ctx context.Context, delete *DeleteCollection) error { - if _, err := s.db.ExecContext(ctx, `DELETE FROM collection WHERE id = ?`, delete.ID); err != nil { - return err - } - - return nil + return s.driver.DeleteCollection(ctx, delete) } diff --git a/store/common.go b/store/common.go index 515644a..12e04cf 100644 --- a/store/common.go +++ b/store/common.go @@ -1,9 +1,5 @@ package store -import ( - storepb "github.com/yourselfhosted/slash/proto/gen/store" -) - // RowStatus is the status for a row. type RowStatus string @@ -24,16 +20,6 @@ func (e RowStatus) String() string { return "" } -func convertRowStatusStringToStorepb(status string) storepb.RowStatus { - switch status { - case "NORMAL": - return storepb.RowStatus_NORMAL - case "ARCHIVED": - return storepb.RowStatus_ARCHIVED - } - return storepb.RowStatus_ROW_STATUS_UNSPECIFIED -} - // Visibility is the type of a visibility. type Visibility string diff --git a/store/db/db.go b/store/db/db.go index 6d9db47..4d6eac7 100644 --- a/store/db/db.go +++ b/store/db/db.go @@ -1,266 +1,26 @@ package db import ( - "context" - "database/sql" - "embed" - "fmt" - "io/fs" - "log/slog" - "os" - "regexp" - "sort" - "time" - "github.com/pkg/errors" "github.com/yourselfhosted/slash/server/profile" - "github.com/yourselfhosted/slash/server/version" + "github.com/yourselfhosted/slash/store" + "github.com/yourselfhosted/slash/store/db/sqlite" ) -//go:embed migration -var migrationFS embed.FS +// NewDBDriver creates new db driver based on profile. +func NewDBDriver(profile *profile.Profile) (store.Driver, error) { + var driver store.Driver + var err error -//go:embed seed -var seedFS embed.FS - -type DB struct { - // sqlite db connection instance - DBInstance *sql.DB - profile *profile.Profile -} - -// NewDB returns a new instance of DB associated with the given datasource name. -func NewDB(profile *profile.Profile) *DB { - db := &DB{ - profile: profile, + switch profile.Driver { + case "sqlite": + driver, err = sqlite.NewDB(profile) + default: + return nil, errors.New("unknown db driver") } - return db -} - -func (db *DB) Open(ctx context.Context) (err error) { - // Ensure a DSN is set before attempting to open the database. - if db.profile.DSN == "" { - return errors.New("dsn required") - } - - // Connect to the database with some sane settings: - // - No shared-cache: it's obsolete; WAL journal mode is a better solution. - // - No foreign key constraints: it's currently disabled by default, but it's a - // good practice to be explicit and prevent future surprises on SQLite upgrades. - // - Journal mode set to WAL: it's the recommended journal mode for most applications - // as it prevents locking issues. - // - // Notes: - // - When using the `modernc.org/sqlite` driver, each pragma must be prefixed with `_pragma=`. - // - // References: - // - https://pkg.go.dev/modernc.org/sqlite#Driver.Open - // - https://www.sqlite.org/sharedcache.html - // - https://www.sqlite.org/pragma.html - sqliteDB, err := sql.Open("sqlite", db.profile.DSN+"?_pragma=foreign_keys(0)&_pragma=busy_timeout(10000)&_pragma=journal_mode(WAL)") if err != nil { - return errors.Wrapf(err, "failed to open db with dsn: %s", db.profile.DSN) + return nil, errors.Wrap(err, "failed to create db driver") } - db.DBInstance = sqliteDB - currentVersion := version.GetCurrentVersion(db.profile.Mode) - - if db.profile.Mode == "prod" { - _, err := os.Stat(db.profile.DSN) - if err != nil { - if !errors.Is(err, os.ErrNotExist) { - return errors.Wrap(err, "failed to get db file stat") - } - - // If db file not exists, we should create a new one with latest schema. - err := db.applyLatestSchema(ctx) - if err != nil { - return errors.Wrap(err, "failed to apply latest schema") - } - _, err = db.UpsertMigrationHistory(ctx, &MigrationHistoryUpsert{ - Version: currentVersion, - }) - if err != nil { - return errors.Wrap(err, "failed to upsert migration history") - } - return nil - } - - // If db file exists, we should check if we need to migrate the database. - migrationHistoryList, err := db.FindMigrationHistoryList(ctx, &MigrationHistoryFind{}) - if err != nil { - return errors.Wrap(err, "failed to find migration history") - } - if len(migrationHistoryList) == 0 { - _, err := db.UpsertMigrationHistory(ctx, &MigrationHistoryUpsert{ - Version: currentVersion, - }) - if err != nil { - return errors.Wrap(err, "failed to upsert migration history") - } - return nil - } - - migrationHistoryVersionList := []string{} - for _, migrationHistory := range migrationHistoryList { - migrationHistoryVersionList = append(migrationHistoryVersionList, migrationHistory.Version) - } - sort.Sort(version.SortVersion(migrationHistoryVersionList)) - latestMigrationHistoryVersion := migrationHistoryVersionList[len(migrationHistoryVersionList)-1] - - if version.IsVersionGreaterThan(version.GetSchemaVersion(currentVersion), latestMigrationHistoryVersion) { - minorVersionList := getMinorVersionList() - - // backup the raw database file before migration - rawBytes, err := os.ReadFile(db.profile.DSN) - if err != nil { - return errors.Wrap(err, "failed to read raw database file") - } - backupDBFilePath := fmt.Sprintf("%s/slash_%s_%d_backup.db", db.profile.Data, db.profile.Version, time.Now().Unix()) - if err := os.WriteFile(backupDBFilePath, rawBytes, 0644); err != nil { - return errors.Wrap(err, "failed to write raw database file") - } - slog.Log(ctx, slog.LevelInfo, "succeed to copy a backup database file") - - slog.Log(ctx, slog.LevelInfo, "start migrate") - for _, minorVersion := range minorVersionList { - normalizedVersion := minorVersion + ".0" - if version.IsVersionGreaterThan(normalizedVersion, latestMigrationHistoryVersion) && version.IsVersionGreaterOrEqualThan(currentVersion, normalizedVersion) { - slog.Log(ctx, slog.LevelInfo, fmt.Sprintf("applying migration for %s", normalizedVersion)) - if err := db.applyMigrationForMinorVersion(ctx, minorVersion); err != nil { - return errors.Wrap(err, "failed to apply minor version migration") - } - } - } - slog.Log(ctx, slog.LevelInfo, "end migrate") - - // remove the created backup db file after migrate succeed - if err := os.Remove(backupDBFilePath); err != nil { - slog.Log(ctx, slog.LevelError, fmt.Sprintf("Failed to remove temp database file, err %v", err)) - } - } - } else { - // In non-prod mode, we should always migrate the database. - if _, err := os.Stat(db.profile.DSN); errors.Is(err, os.ErrNotExist) { - if err := db.applyLatestSchema(ctx); err != nil { - return errors.Wrap(err, "failed to apply latest schema") - } - // In demo mode, we should seed the database. - if db.profile.Mode == "demo" { - if err := db.seed(ctx); err != nil { - return errors.Wrap(err, "failed to seed") - } - } - } - } - - return nil -} - -const ( - latestSchemaFileName = "LATEST__SCHEMA.sql" -) - -func (db *DB) applyLatestSchema(ctx context.Context) error { - schemaMode := "dev" - if db.profile.Mode == "prod" { - schemaMode = "prod" - } - latestSchemaPath := fmt.Sprintf("migration/%s/%s", schemaMode, latestSchemaFileName) - buf, err := migrationFS.ReadFile(latestSchemaPath) - if err != nil { - return errors.Wrapf(err, "failed to read latest schema %q", latestSchemaPath) - } - stmt := string(buf) - if err := db.execute(ctx, stmt); err != nil { - return errors.Wrapf(err, "migrate error: statement %s", stmt) - } - return nil -} - -func (db *DB) applyMigrationForMinorVersion(ctx context.Context, minorVersion string) error { - filenames, err := fs.Glob(migrationFS, fmt.Sprintf("migration/prod/%s/*.sql", minorVersion)) - if err != nil { - return errors.Wrap(err, "failed to read migrate files") - } - - sort.Strings(filenames) - migrationStmt := "" - - // Loop over all migration files and execute them in order. - for _, filename := range filenames { - buf, err := migrationFS.ReadFile(filename) - if err != nil { - return errors.Wrapf(err, "failed to read minor version migration file, filename %s", filename) - } - stmt := string(buf) - migrationStmt += stmt - if err := db.execute(ctx, stmt); err != nil { - return errors.Wrapf(err, "migrate error: statement %s", stmt) - } - } - - // Upsert the newest version to migration_history. - version := minorVersion + ".0" - if _, err = db.UpsertMigrationHistory(ctx, &MigrationHistoryUpsert{ - Version: version, - }); err != nil { - return errors.Wrapf(err, "failed to upsert migration history with version %s", version) - } - - return nil -} - -func (db *DB) seed(ctx context.Context) error { - filenames, err := fs.Glob(seedFS, "seed/*.sql") - if err != nil { - return errors.Wrap(err, "failed to read seed files") - } - - sort.Strings(filenames) - // Loop over all seed files and execute them in order. - for _, filename := range filenames { - buf, err := seedFS.ReadFile(filename) - if err != nil { - return errors.Wrapf(err, "failed to read seed file, filename %s", filename) - } - stmt := string(buf) - if err := db.execute(ctx, stmt); err != nil { - return errors.Wrapf(err, "seed error: statement %s", stmt) - } - } - return nil -} - -// execute runs a single SQL statement within a transaction. -func (db *DB) execute(ctx context.Context, stmt string) error { - if _, err := db.DBInstance.ExecContext(ctx, stmt); err != nil { - return errors.Wrap(err, "failed to execute statement") - } - - return nil -} - -// minorDirRegexp is a regular expression for minor version directory. -var minorDirRegexp = regexp.MustCompile(`^migration/prod/[0-9]+\.[0-9]+$`) - -func getMinorVersionList() []string { - minorVersionList := []string{} - - if err := fs.WalkDir(migrationFS, "migration", func(path string, file fs.DirEntry, err error) error { - if err != nil { - return err - } - if file.IsDir() && minorDirRegexp.MatchString(path) { - minorVersionList = append(minorVersionList, file.Name()) - } - - return nil - }); err != nil { - panic(err) - } - - sort.Sort(version.SortVersion(minorVersionList)) - - return minorVersionList + return driver, nil } diff --git a/store/db/migration_history.go b/store/db/migration_history.go deleted file mode 100644 index 3aa83b3..0000000 --- a/store/db/migration_history.go +++ /dev/null @@ -1,82 +0,0 @@ -package db - -import ( - "context" - "strings" -) - -type MigrationHistory struct { - Version string - CreatedTs int64 -} - -type MigrationHistoryUpsert struct { - Version string -} - -type MigrationHistoryFind struct { - Version *string -} - -func (db *DB) FindMigrationHistoryList(ctx context.Context, find *MigrationHistoryFind) ([]*MigrationHistory, error) { - where, args := []string{"1 = 1"}, []any{} - - if v := find.Version; v != nil { - where, args = append(where, "version = ?"), append(args, *v) - } - - stmt := ` - SELECT - version, - created_ts - FROM - migration_history - WHERE ` + strings.Join(where, " AND ") + ` - ORDER BY created_ts DESC - ` - rows, err := db.DBInstance.QueryContext(ctx, stmt, args...) - if err != nil { - return nil, err - } - defer rows.Close() - - migrationHistoryList := make([]*MigrationHistory, 0) - for rows.Next() { - var migrationHistory MigrationHistory - if err := rows.Scan( - &migrationHistory.Version, - &migrationHistory.CreatedTs, - ); err != nil { - return nil, err - } - migrationHistoryList = append(migrationHistoryList, &migrationHistory) - } - - if err := rows.Err(); err != nil { - return nil, err - } - - return migrationHistoryList, nil -} - -func (db *DB) UpsertMigrationHistory(ctx context.Context, upsert *MigrationHistoryUpsert) (*MigrationHistory, error) { - query := ` - INSERT INTO migration_history ( - version - ) - VALUES (?) - ON CONFLICT(version) DO UPDATE - SET - version=EXCLUDED.version - RETURNING version, created_ts - ` - migrationHistory := &MigrationHistory{} - if err := db.DBInstance.QueryRowContext(ctx, query, upsert.Version).Scan( - &migrationHistory.Version, - &migrationHistory.CreatedTs, - ); err != nil { - return nil, err - } - - return migrationHistory, nil -} diff --git a/store/db/sqlite/activity.go b/store/db/sqlite/activity.go new file mode 100644 index 0000000..85ac3c9 --- /dev/null +++ b/store/db/sqlite/activity.go @@ -0,0 +1,87 @@ +package sqlite + +import ( + "context" + "strings" + + "github.com/yourselfhosted/slash/store" +) + +func (d *DB) CreateActivity(ctx context.Context, create *store.Activity) (*store.Activity, error) { + stmt := ` + INSERT INTO activity ( + creator_id, + type, + level, + payload + ) + VALUES (?, ?, ?, ?) + RETURNING id, created_ts + ` + if err := d.db.QueryRowContext(ctx, stmt, + create.CreatorID, + create.Type.String(), + create.Level.String(), + create.Payload, + ).Scan( + &create.ID, + &create.CreatedTs, + ); err != nil { + return nil, err + } + + activity := create + return activity, nil +} + +func (d *DB) ListActivities(ctx context.Context, find *store.FindActivity) ([]*store.Activity, error) { + where, args := []string{"1 = 1"}, []any{} + if find.Type != "" { + where, args = append(where, "type = ?"), append(args, find.Type.String()) + } + if find.Level != "" { + where, args = append(where, "level = ?"), append(args, find.Level.String()) + } + if find.Where != nil { + where = append(where, find.Where...) + } + + query := ` + SELECT + id, + creator_id, + created_ts, + type, + level, + payload + FROM activity + WHERE ` + strings.Join(where, " AND ") + rows, err := d.db.QueryContext(ctx, query, args...) + if err != nil { + return nil, err + } + defer rows.Close() + + list := []*store.Activity{} + for rows.Next() { + activity := &store.Activity{} + if err := rows.Scan( + &activity.ID, + &activity.CreatorID, + &activity.CreatedTs, + &activity.Type, + &activity.Level, + &activity.Payload, + ); err != nil { + return nil, err + } + + list = append(list, activity) + } + + if err := rows.Err(); err != nil { + return nil, err + } + + return list, nil +} diff --git a/store/db/sqlite/collection.go b/store/db/sqlite/collection.go new file mode 100644 index 0000000..3a12c9c --- /dev/null +++ b/store/db/sqlite/collection.go @@ -0,0 +1,183 @@ +package sqlite + +import ( + "context" + "fmt" + "strings" + + "github.com/pkg/errors" + + "github.com/yourselfhosted/slash/internal/util" + storepb "github.com/yourselfhosted/slash/proto/gen/store" + "github.com/yourselfhosted/slash/store" +) + +func (d *DB) CreateCollection(ctx context.Context, create *storepb.Collection) (*storepb.Collection, error) { + set := []string{"creator_id", "name", "title", "description", "shortcut_ids", "visibility"} + args := []any{create.CreatorId, create.Name, create.Title, create.Description, strings.Trim(strings.Join(strings.Fields(fmt.Sprint(create.ShortcutIds)), ","), "[]"), create.Visibility.String()} + placeholder := []string{"?", "?", "?", "?", "?", "?"} + + stmt := ` + INSERT INTO collection ( + ` + strings.Join(set, ", ") + ` + ) + VALUES (` + strings.Join(placeholder, ",") + `) + RETURNING id, created_ts, updated_ts + ` + if err := d.db.QueryRowContext(ctx, stmt, args...).Scan( + &create.Id, + &create.CreatedTs, + &create.UpdatedTs, + ); err != nil { + return nil, err + } + collection := create + return collection, nil +} + +func (d *DB) UpdateCollection(ctx context.Context, update *store.UpdateCollection) (*storepb.Collection, error) { + set, args := []string{}, []any{} + if update.Name != nil { + set, args = append(set, "name = ?"), append(args, *update.Name) + } + if update.Title != nil { + set, args = append(set, "title = ?"), append(args, *update.Title) + } + if update.Description != nil { + set, args = append(set, "description = ?"), append(args, *update.Description) + } + if update.ShortcutIDs != nil { + set, args = append(set, "shortcut_ids = ?"), append(args, strings.Trim(strings.Join(strings.Fields(fmt.Sprint(update.ShortcutIDs)), ","), "[]")) + } + if update.Visibility != nil { + set, args = append(set, "visibility = ?"), append(args, update.Visibility.String()) + } + if len(set) == 0 { + return nil, errors.New("no update specified") + } + args = append(args, update.ID) + + stmt := ` + UPDATE collection + SET + ` + strings.Join(set, ", ") + ` + WHERE + id = ? + RETURNING id, creator_id, created_ts, updated_ts, name, title, description, shortcut_ids, visibility + ` + collection := &storepb.Collection{} + var shortcutIDs, visibility string + if err := d.db.QueryRowContext(ctx, stmt, args...).Scan( + &collection.Id, + &collection.CreatorId, + &collection.CreatedTs, + &collection.UpdatedTs, + &collection.Name, + &collection.Title, + &collection.Description, + &shortcutIDs, + &visibility, + ); err != nil { + return nil, err + } + + collection.ShortcutIds = []int32{} + if shortcutIDs != "" { + for _, idStr := range strings.Split(shortcutIDs, ",") { + shortcutID, err := util.ConvertStringToInt32(idStr) + if err != nil { + return nil, errors.Wrap(err, "failed to convert shortcut id") + } + collection.ShortcutIds = append(collection.ShortcutIds, shortcutID) + } + } + collection.Visibility = convertVisibilityStringToStorepb(visibility) + return collection, nil +} + +func (d *DB) ListCollections(ctx context.Context, find *store.FindCollection) ([]*storepb.Collection, error) { + where, args := []string{"1 = 1"}, []any{} + if v := find.ID; v != nil { + where, args = append(where, "id = ?"), append(args, *v) + } + if v := find.CreatorID; v != nil { + where, args = append(where, "creator_id = ?"), append(args, *v) + } + if v := find.Name; v != nil { + where, args = append(where, "name = ?"), append(args, *v) + } + if v := find.VisibilityList; len(v) != 0 { + list := []string{} + for _, visibility := range v { + list = append(list, fmt.Sprintf("$%d", len(args)+1)) + args = append(args, visibility) + } + where = append(where, fmt.Sprintf("visibility in (%s)", strings.Join(list, ","))) + } + + rows, err := d.db.QueryContext(ctx, ` + SELECT + id, + creator_id, + created_ts, + updated_ts, + name, + title, + description, + shortcut_ids, + visibility + FROM collection + WHERE `+strings.Join(where, " AND ")+` + ORDER BY created_ts DESC`, + args..., + ) + if err != nil { + return nil, err + } + defer rows.Close() + + list := make([]*storepb.Collection, 0) + for rows.Next() { + collection := &storepb.Collection{} + var shortcutIDs, visibility string + if err := rows.Scan( + &collection.Id, + &collection.CreatorId, + &collection.CreatedTs, + &collection.UpdatedTs, + &collection.Name, + &collection.Title, + &collection.Description, + &shortcutIDs, + &visibility, + ); err != nil { + return nil, err + } + + collection.ShortcutIds = []int32{} + if shortcutIDs != "" { + for _, idStr := range strings.Split(shortcutIDs, ",") { + shortcutID, err := util.ConvertStringToInt32(idStr) + if err != nil { + return nil, errors.Wrap(err, "failed to convert shortcut id") + } + collection.ShortcutIds = append(collection.ShortcutIds, shortcutID) + } + } + collection.Visibility = storepb.Visibility(storepb.Visibility_value[visibility]) + list = append(list, collection) + } + + if err := rows.Err(); err != nil { + return nil, err + } + return list, nil +} + +func (d *DB) DeleteCollection(ctx context.Context, delete *store.DeleteCollection) error { + if _, err := d.db.ExecContext(ctx, `DELETE FROM collection WHERE id = ?`, delete.ID); err != nil { + return err + } + + return nil +} diff --git a/store/db/sqlite/common.go b/store/db/sqlite/common.go new file mode 100644 index 0000000..a805d1e --- /dev/null +++ b/store/db/sqlite/common.go @@ -0,0 +1,59 @@ +package sqlite + +import ( + storepb "github.com/yourselfhosted/slash/proto/gen/store" +) + +// RowStatus is the status for a row. +type RowStatus string + +const ( + // Normal is the status for a normal row. + Normal RowStatus = "NORMAL" + // Archived is the status for an archived row. + Archived RowStatus = "ARCHIVED" +) + +func (e RowStatus) String() string { + switch e { + case Normal: + return "NORMAL" + case Archived: + return "ARCHIVED" + } + return "" +} + +func convertRowStatusStringToStorepb(status string) storepb.RowStatus { + switch status { + case "NORMAL": + return storepb.RowStatus_NORMAL + case "ARCHIVED": + return storepb.RowStatus_ARCHIVED + } + return storepb.RowStatus_ROW_STATUS_UNSPECIFIED +} + +// Visibility is the type of a visibility. +type Visibility string + +const ( + // VisibilityPublic is the PUBLIC visibility. + VisibilityPublic Visibility = "PUBLIC" + // VisibilityWorkspace is the WORKSPACE visibility. + VisibilityWorkspace Visibility = "WORKSPACE" + // VisibilityPrivate is the PRIVATE visibility. + VisibilityPrivate Visibility = "PRIVATE" +) + +func (e Visibility) String() string { + switch e { + case VisibilityPublic: + return "PUBLIC" + case VisibilityWorkspace: + return "WORKSPACE" + case VisibilityPrivate: + return "PRIVATE" + } + return "PRIVATE" +} diff --git a/store/db/sqlite/memo.go b/store/db/sqlite/memo.go new file mode 100644 index 0000000..e45c52c --- /dev/null +++ b/store/db/sqlite/memo.go @@ -0,0 +1,202 @@ +package sqlite + +import ( + "context" + "database/sql" + "fmt" + "strings" + + "github.com/pkg/errors" + + storepb "github.com/yourselfhosted/slash/proto/gen/store" + "github.com/yourselfhosted/slash/store" +) + +func (d *DB) CreateMemo(ctx context.Context, create *storepb.Memo) (*storepb.Memo, error) { + set := []string{"creator_id", "name", "title", "content", "visibility", "tag"} + args := []any{create.CreatorId, create.Name, create.Title, create.Content, create.Visibility.String(), strings.Join(create.Tags, " ")} + + stmt := ` + INSERT INTO memo ( + ` + strings.Join(set, ", ") + ` + ) + VALUES (` + placeholders(len(args)) + `) + RETURNING id, created_ts, updated_ts, row_status + ` + + var rowStatus string + if err := d.db.QueryRowContext(ctx, stmt, args...).Scan( + &create.Id, + &create.CreatedTs, + &create.UpdatedTs, + &rowStatus, + ); err != nil { + return nil, err + } + create.RowStatus = convertRowStatusStringToStorepb(rowStatus) + memo := create + return memo, nil +} + +func (d *DB) UpdateMemo(ctx context.Context, update *store.UpdateMemo) (*storepb.Memo, error) { + set, args := []string{}, []any{} + if update.RowStatus != nil { + set, args = append(set, "row_status = ?"), append(args, update.RowStatus.String()) + } + if update.Name != nil { + set, args = append(set, "name = ?"), append(args, *update.Name) + } + if update.Title != nil { + set, args = append(set, "title = ?"), append(args, *update.Title) + } + if update.Content != nil { + set, args = append(set, "content = ?"), append(args, *update.Content) + } + if update.Visibility != nil { + set, args = append(set, "visibility = ?"), append(args, update.Visibility.String()) + } + if update.Tag != nil { + set, args = append(set, "tag = ?"), append(args, *update.Tag) + } + if len(set) == 0 { + return nil, errors.New("no update specified") + } + args = append(args, update.ID) + + stmt := ` + UPDATE memo + SET + ` + strings.Join(set, ", ") + ` + WHERE + id = ? + RETURNING id, creator_id, created_ts, updated_ts, row_status, name, title, content, visibility, tag + ` + memo := &storepb.Memo{} + var rowStatus, visibility, tags string + if err := d.db.QueryRowContext(ctx, stmt, args...).Scan( + &memo.Id, + &memo.CreatorId, + &memo.CreatedTs, + &memo.UpdatedTs, + &rowStatus, + &memo.Name, + &memo.Title, + &memo.Content, + &visibility, + &tags, + ); err != nil { + return nil, err + } + memo.RowStatus = convertRowStatusStringToStorepb(rowStatus) + memo.Visibility = convertVisibilityStringToStorepb(visibility) + memo.Tags = filterTags(strings.Split(tags, " ")) + return memo, nil +} + +func (d *DB) ListMemos(ctx context.Context, find *store.FindMemo) ([]*storepb.Memo, error) { + where, args := []string{"1 = 1"}, []any{} + if v := find.ID; v != nil { + where, args = append(where, "id = ?"), append(args, *v) + } + if v := find.CreatorID; v != nil { + where, args = append(where, "creator_id = ?"), append(args, *v) + } + if v := find.RowStatus; v != nil { + where, args = append(where, "row_status = ?"), append(args, *v) + } + if v := find.Name; v != nil { + where, args = append(where, "name = ?"), append(args, *v) + } + if v := find.VisibilityList; len(v) != 0 { + list := []string{} + for _, visibility := range v { + list = append(list, fmt.Sprintf("$%d", len(args)+1)) + args = append(args, visibility) + } + where = append(where, fmt.Sprintf("visibility in (%s)", strings.Join(list, ","))) + } + if v := find.Tag; v != nil { + where, args = append(where, "tag LIKE ?"), append(args, "%"+*v+"%") + } + + rows, err := d.db.QueryContext(ctx, ` + SELECT + id, + creator_id, + created_ts, + updated_ts, + row_status, + name, + title, + content, + visibility, + tag + FROM memo + WHERE `+strings.Join(where, " AND ")+` + ORDER BY created_ts DESC`, + args..., + ) + if err != nil { + return nil, err + } + defer rows.Close() + + list := make([]*storepb.Memo, 0) + for rows.Next() { + memo := &storepb.Memo{} + var rowStatus, visibility, tags string + if err := rows.Scan( + &memo.Id, + &memo.CreatorId, + &memo.CreatedTs, + &memo.UpdatedTs, + &rowStatus, + &memo.Name, + &memo.Title, + &memo.Content, + &visibility, + &tags, + ); err != nil { + return nil, err + } + memo.RowStatus = convertRowStatusStringToStorepb(rowStatus) + memo.Visibility = storepb.Visibility(storepb.Visibility_value[visibility]) + memo.Tags = filterTags(strings.Split(tags, " ")) + list = append(list, memo) + } + + if err := rows.Err(); err != nil { + return nil, err + } + return list, nil +} + +func (d *DB) DeleteMemo(ctx context.Context, delete *store.DeleteMemo) error { + if _, err := d.db.ExecContext(ctx, `DELETE FROM memo WHERE id = ?`, delete.ID); err != nil { + return err + } + return nil +} + +func vacuumMemo(ctx context.Context, tx *sql.Tx) error { + stmt := ` + DELETE FROM + memo + WHERE + creator_id NOT IN ( + SELECT + id + FROM + user + )` + _, err := tx.ExecContext(ctx, stmt) + if err != nil { + return err + } + + return nil +} + +func placeholders(n int) string { + return strings.Repeat("?,", n-1) + "?" +} diff --git a/store/db/migration/dev/LATEST__SCHEMA.sql b/store/db/sqlite/migration/dev/LATEST__SCHEMA.sql similarity index 100% rename from store/db/migration/dev/LATEST__SCHEMA.sql rename to store/db/sqlite/migration/dev/LATEST__SCHEMA.sql diff --git a/store/db/migration/prod/0.2/00__create_index.sql b/store/db/sqlite/migration/prod/0.2/00__create_index.sql similarity index 100% rename from store/db/migration/prod/0.2/00__create_index.sql rename to store/db/sqlite/migration/prod/0.2/00__create_index.sql diff --git a/store/db/migration/prod/0.3/00__add_og_metadata.sql b/store/db/sqlite/migration/prod/0.3/00__add_og_metadata.sql similarity index 100% rename from store/db/migration/prod/0.3/00__add_og_metadata.sql rename to store/db/sqlite/migration/prod/0.3/00__add_og_metadata.sql diff --git a/store/db/migration/prod/0.4/00__add_shortcut_title.sql b/store/db/sqlite/migration/prod/0.4/00__add_shortcut_title.sql similarity index 100% rename from store/db/migration/prod/0.4/00__add_shortcut_title.sql rename to store/db/sqlite/migration/prod/0.4/00__add_shortcut_title.sql diff --git a/store/db/migration/prod/0.5/00__drop_idp.sql b/store/db/sqlite/migration/prod/0.5/00__drop_idp.sql similarity index 100% rename from store/db/migration/prod/0.5/00__drop_idp.sql rename to store/db/sqlite/migration/prod/0.5/00__drop_idp.sql diff --git a/store/db/migration/prod/0.5/01__collection.sql b/store/db/sqlite/migration/prod/0.5/01__collection.sql similarity index 100% rename from store/db/migration/prod/0.5/01__collection.sql rename to store/db/sqlite/migration/prod/0.5/01__collection.sql diff --git a/store/db/migration/prod/LATEST__SCHEMA.sql b/store/db/sqlite/migration/prod/LATEST__SCHEMA.sql similarity index 100% rename from store/db/migration/prod/LATEST__SCHEMA.sql rename to store/db/sqlite/migration/prod/LATEST__SCHEMA.sql diff --git a/store/db/sqlite/migration_history.go b/store/db/sqlite/migration_history.go new file mode 100644 index 0000000..392c7fb --- /dev/null +++ b/store/db/sqlite/migration_history.go @@ -0,0 +1,57 @@ +package sqlite + +import ( + "context" + + "github.com/yourselfhosted/slash/store" +) + +func (d *DB) UpsertMigrationHistory(ctx context.Context, upsert *store.UpsertMigrationHistory) (*store.MigrationHistory, error) { + stmt := ` + INSERT INTO migration_history ( + version + ) + VALUES (?) + ON CONFLICT(version) DO UPDATE + SET + version=EXCLUDED.version + RETURNING version, created_ts + ` + var migrationHistory store.MigrationHistory + if err := d.db.QueryRowContext(ctx, stmt, upsert.Version).Scan( + &migrationHistory.Version, + &migrationHistory.CreatedTs, + ); err != nil { + return nil, err + } + + return &migrationHistory, nil +} + +func (d *DB) ListMigrationHistories(ctx context.Context, _ *store.FindMigrationHistory) ([]*store.MigrationHistory, error) { + query := "SELECT `version`, `created_ts` FROM `migration_history` ORDER BY `created_ts` DESC" + rows, err := d.db.QueryContext(ctx, query) + if err != nil { + return nil, err + } + defer rows.Close() + + list := make([]*store.MigrationHistory, 0) + for rows.Next() { + var migrationHistory store.MigrationHistory + if err := rows.Scan( + &migrationHistory.Version, + &migrationHistory.CreatedTs, + ); err != nil { + return nil, err + } + + list = append(list, &migrationHistory) + } + + if err := rows.Err(); err != nil { + return nil, err + } + + return list, nil +} diff --git a/store/db/sqlite/migrator.go b/store/db/sqlite/migrator.go new file mode 100644 index 0000000..7428759 --- /dev/null +++ b/store/db/sqlite/migrator.go @@ -0,0 +1,234 @@ +package sqlite + +import ( + "context" + "embed" + "fmt" + "io/fs" + "os" + "regexp" + "sort" + "time" + + "github.com/pkg/errors" + + "github.com/yourselfhosted/slash/server/version" + "github.com/yourselfhosted/slash/store" +) + +//go:embed migration +var migrationFS embed.FS + +//go:embed seed +var seedFS embed.FS + +// Migrate applies the latest schema to the database. +func (d *DB) Migrate(ctx context.Context) error { + currentVersion := version.GetCurrentVersion(d.profile.Mode) + if d.profile.Mode == "prod" { + _, err := os.Stat(d.profile.DSN) + if err != nil { + // If db file not exists, we should create a new one with latest schema. + if errors.Is(err, os.ErrNotExist) { + if err := d.applyLatestSchema(ctx); err != nil { + return errors.Wrap(err, "failed to apply latest schema") + } + // Upsert the newest version to migration_history. + if _, err := d.UpsertMigrationHistory(ctx, &store.UpsertMigrationHistory{ + Version: currentVersion, + }); err != nil { + return errors.Wrap(err, "failed to upsert migration history") + } + } else { + return errors.Wrap(err, "failed to get db file stat") + } + } else { + // If db file exists, we should check if we need to migrate the database. + migrationHistoryList, err := d.ListMigrationHistories(ctx, &store.FindMigrationHistory{}) + if err != nil { + return errors.Wrap(err, "failed to find migration history") + } + // If no migration history, we should apply the latest version migration and upsert the migration history. + if len(migrationHistoryList) == 0 { + minorVersion := version.GetMinorVersion(currentVersion) + if err := d.applyMigrationForMinorVersion(ctx, minorVersion); err != nil { + return errors.Wrapf(err, "failed to apply version %s migration", minorVersion) + } + _, err := d.UpsertMigrationHistory(ctx, &store.UpsertMigrationHistory{ + Version: currentVersion, + }) + if err != nil { + return errors.Wrap(err, "failed to upsert migration history") + } + return nil + } + + migrationHistoryVersionList := []string{} + for _, migrationHistory := range migrationHistoryList { + migrationHistoryVersionList = append(migrationHistoryVersionList, migrationHistory.Version) + } + sort.Sort(version.SortVersion(migrationHistoryVersionList)) + latestMigrationHistoryVersion := migrationHistoryVersionList[len(migrationHistoryVersionList)-1] + + if version.IsVersionGreaterThan(version.GetSchemaVersion(currentVersion), latestMigrationHistoryVersion) { + minorVersionList := getMinorVersionList() + // backup the raw database file before migration + rawBytes, err := os.ReadFile(d.profile.DSN) + if err != nil { + return errors.Wrap(err, "failed to read raw database file") + } + backupDBFilePath := fmt.Sprintf("%s/memos_%s_%d_backup.db", d.profile.Data, d.profile.Version, time.Now().Unix()) + if err := os.WriteFile(backupDBFilePath, rawBytes, 0644); err != nil { + return errors.Wrap(err, "failed to write raw database file") + } + println("succeed to copy a backup database file") + println("start migrate") + for _, minorVersion := range minorVersionList { + normalizedVersion := minorVersion + ".0" + if version.IsVersionGreaterThan(normalizedVersion, latestMigrationHistoryVersion) && version.IsVersionGreaterOrEqualThan(currentVersion, normalizedVersion) { + println("applying migration for", normalizedVersion) + if err := d.applyMigrationForMinorVersion(ctx, minorVersion); err != nil { + return errors.Wrap(err, "failed to apply minor version migration") + } + } + } + println("end migrate") + + // remove the created backup db file after migrate succeed + if err := os.Remove(backupDBFilePath); err != nil { + println(fmt.Sprintf("Failed to remove temp database file, err %v", err)) + } + } + } + } else { + // In non-prod mode, we should always migrate the database. + if _, err := os.Stat(d.profile.DSN); errors.Is(err, os.ErrNotExist) { + if err := d.applyLatestSchema(ctx); err != nil { + return errors.Wrap(err, "failed to apply latest schema") + } + // In demo mode, we should seed the database. + if d.profile.Mode == "demo" { + if err := d.seed(ctx); err != nil { + return errors.Wrap(err, "failed to seed") + } + } + } + } + + return nil +} + +const ( + latestSchemaFileName = "LATEST__SCHEMA.sql" +) + +func (d *DB) applyLatestSchema(ctx context.Context) error { + schemaMode := "dev" + if d.profile.Mode == "prod" { + schemaMode = "prod" + } + latestSchemaPath := fmt.Sprintf("migration/%s/%s", schemaMode, latestSchemaFileName) + buf, err := migrationFS.ReadFile(latestSchemaPath) + if err != nil { + return errors.Wrapf(err, "failed to read latest schema %q", latestSchemaPath) + } + stmt := string(buf) + if err := d.execute(ctx, stmt); err != nil { + return errors.Wrapf(err, "migrate error: %s", stmt) + } + return nil +} + +func (d *DB) applyMigrationForMinorVersion(ctx context.Context, minorVersion string) error { + filenames, err := fs.Glob(migrationFS, fmt.Sprintf("%s/%s/*.sql", "migration/prod", minorVersion)) + if err != nil { + return errors.Wrap(err, "failed to read ddl files") + } + + sort.Strings(filenames) + migrationStmt := "" + + // Loop over all migration files and execute them in order. + for _, filename := range filenames { + buf, err := migrationFS.ReadFile(filename) + if err != nil { + return errors.Wrapf(err, "failed to read minor version migration file, filename=%s", filename) + } + stmt := string(buf) + migrationStmt += stmt + if err := d.execute(ctx, stmt); err != nil { + return errors.Wrapf(err, "migrate error: %s", stmt) + } + } + + // Upsert the newest version to migration_history. + version := minorVersion + ".0" + if _, err = d.UpsertMigrationHistory(ctx, &store.UpsertMigrationHistory{ + Version: version, + }); err != nil { + return errors.Wrapf(err, "failed to upsert migration history with version: %s", version) + } + + return nil +} + +func (d *DB) seed(ctx context.Context) error { + filenames, err := fs.Glob(seedFS, fmt.Sprintf("%s/*.sql", "seed")) + if err != nil { + return errors.Wrap(err, "failed to read seed files") + } + + sort.Strings(filenames) + + // Loop over all seed files and execute them in order. + for _, filename := range filenames { + buf, err := seedFS.ReadFile(filename) + if err != nil { + return errors.Wrapf(err, "failed to read seed file, filename=%s", filename) + } + stmt := string(buf) + if err := d.execute(ctx, stmt); err != nil { + return errors.Wrapf(err, "seed error: %s", stmt) + } + } + return nil +} + +// execute runs a single SQL statement within a transaction. +func (d *DB) execute(ctx context.Context, stmt string) error { + tx, err := d.db.Begin() + if err != nil { + return err + } + defer tx.Rollback() + + if _, err := tx.ExecContext(ctx, stmt); err != nil { + return errors.Wrap(err, "failed to execute statement") + } + + return tx.Commit() +} + +// minorDirRegexp is a regular expression for minor version directory. +var minorDirRegexp = regexp.MustCompile(`^migration/prod/[0-9]+\.[0-9]+$`) + +func getMinorVersionList() []string { + minorVersionList := []string{} + + if err := fs.WalkDir(migrationFS, "migration", func(path string, file fs.DirEntry, err error) error { + if err != nil { + return err + } + if file.IsDir() && minorDirRegexp.MatchString(path) { + minorVersionList = append(minorVersionList, file.Name()) + } + + return nil + }); err != nil { + panic(err) + } + + sort.Sort(version.SortVersion(minorVersionList)) + + return minorVersionList +} diff --git a/store/db/seed/10000__reset.sql b/store/db/sqlite/seed/10000__reset.sql similarity index 100% rename from store/db/seed/10000__reset.sql rename to store/db/sqlite/seed/10000__reset.sql diff --git a/store/db/seed/10001__user.sql b/store/db/sqlite/seed/10001__user.sql similarity index 100% rename from store/db/seed/10001__user.sql rename to store/db/sqlite/seed/10001__user.sql diff --git a/store/db/seed/10002__shortcut.sql b/store/db/sqlite/seed/10002__shortcut.sql similarity index 100% rename from store/db/seed/10002__shortcut.sql rename to store/db/sqlite/seed/10002__shortcut.sql diff --git a/store/db/seed/10003__collection.sql b/store/db/sqlite/seed/10003__collection.sql similarity index 100% rename from store/db/seed/10003__collection.sql rename to store/db/sqlite/seed/10003__collection.sql diff --git a/store/db/sqlite/shortcut.go b/store/db/sqlite/shortcut.go new file mode 100644 index 0000000..8fa88c1 --- /dev/null +++ b/store/db/sqlite/shortcut.go @@ -0,0 +1,249 @@ +package sqlite + +import ( + "context" + "database/sql" + "fmt" + "strings" + + "github.com/pkg/errors" + "google.golang.org/protobuf/encoding/protojson" + + storepb "github.com/yourselfhosted/slash/proto/gen/store" + "github.com/yourselfhosted/slash/store" +) + +func (d *DB) CreateShortcut(ctx context.Context, create *storepb.Shortcut) (*storepb.Shortcut, error) { + set := []string{"creator_id", "name", "link", "title", "description", "visibility", "tag"} + args := []any{create.CreatorId, create.Name, create.Link, create.Title, create.Description, create.Visibility.String(), strings.Join(create.Tags, " ")} + placeholder := []string{"?", "?", "?", "?", "?", "?", "?"} + if create.OgMetadata != nil { + set = append(set, "og_metadata") + openGraphMetadataBytes, err := protojson.Marshal(create.OgMetadata) + if err != nil { + return nil, err + } + args = append(args, string(openGraphMetadataBytes)) + placeholder = append(placeholder, "?") + } + + stmt := ` + INSERT INTO shortcut ( + ` + strings.Join(set, ", ") + ` + ) + VALUES (` + strings.Join(placeholder, ",") + `) + RETURNING id, created_ts, updated_ts, row_status + ` + var rowStatus string + if err := d.db.QueryRowContext(ctx, stmt, args...).Scan( + &create.Id, + &create.CreatedTs, + &create.UpdatedTs, + &rowStatus, + ); err != nil { + return nil, err + } + create.RowStatus = convertRowStatusStringToStorepb(rowStatus) + shortcut := create + return shortcut, nil +} + +func (d *DB) UpdateShortcut(ctx context.Context, update *store.UpdateShortcut) (*storepb.Shortcut, error) { + set, args := []string{}, []any{} + if update.RowStatus != nil { + set, args = append(set, "row_status = ?"), append(args, update.RowStatus.String()) + } + if update.Name != nil { + set, args = append(set, "name = ?"), append(args, *update.Name) + } + if update.Link != nil { + set, args = append(set, "link = ?"), append(args, *update.Link) + } + if update.Title != nil { + set, args = append(set, "title = ?"), append(args, *update.Title) + } + if update.Description != nil { + set, args = append(set, "description = ?"), append(args, *update.Description) + } + if update.Visibility != nil { + set, args = append(set, "visibility = ?"), append(args, update.Visibility.String()) + } + if update.Tag != nil { + set, args = append(set, "tag = ?"), append(args, *update.Tag) + } + if update.OpenGraphMetadata != nil { + openGraphMetadataBytes, err := protojson.Marshal(update.OpenGraphMetadata) + if err != nil { + return nil, errors.Wrap(err, "Failed to marshal activity payload") + } + set, args = append(set, "og_metadata = ?"), append(args, string(openGraphMetadataBytes)) + } + if len(set) == 0 { + return nil, errors.New("no update specified") + } + args = append(args, update.ID) + + stmt := ` + UPDATE shortcut + SET + ` + strings.Join(set, ", ") + ` + WHERE + id = ? + RETURNING id, creator_id, created_ts, updated_ts, row_status, name, link, title, description, visibility, tag, og_metadata + ` + shortcut := &storepb.Shortcut{} + var rowStatus, visibility, tags, openGraphMetadataString string + if err := d.db.QueryRowContext(ctx, stmt, args...).Scan( + &shortcut.Id, + &shortcut.CreatorId, + &shortcut.CreatedTs, + &shortcut.UpdatedTs, + &rowStatus, + &shortcut.Name, + &shortcut.Link, + &shortcut.Title, + &shortcut.Description, + &visibility, + &tags, + &openGraphMetadataString, + ); err != nil { + return nil, err + } + shortcut.RowStatus = convertRowStatusStringToStorepb(rowStatus) + shortcut.Visibility = convertVisibilityStringToStorepb(visibility) + shortcut.Tags = filterTags(strings.Split(tags, " ")) + var ogMetadata storepb.OpenGraphMetadata + if err := protojson.Unmarshal([]byte(openGraphMetadataString), &ogMetadata); err != nil { + return nil, err + } + shortcut.OgMetadata = &ogMetadata + return shortcut, nil +} + +func (d *DB) ListShortcuts(ctx context.Context, find *store.FindShortcut) ([]*storepb.Shortcut, error) { + where, args := []string{"1 = 1"}, []any{} + if v := find.ID; v != nil { + where, args = append(where, "id = ?"), append(args, *v) + } + if v := find.CreatorID; v != nil { + where, args = append(where, "creator_id = ?"), append(args, *v) + } + if v := find.RowStatus; v != nil { + where, args = append(where, "row_status = ?"), append(args, *v) + } + if v := find.Name; v != nil { + where, args = append(where, "name = ?"), append(args, *v) + } + if v := find.VisibilityList; len(v) != 0 { + list := []string{} + for _, visibility := range v { + list = append(list, fmt.Sprintf("$%d", len(args)+1)) + args = append(args, visibility) + } + where = append(where, fmt.Sprintf("visibility in (%s)", strings.Join(list, ","))) + } + if v := find.Tag; v != nil { + where, args = append(where, "tag LIKE ?"), append(args, "%"+*v+"%") + } + + rows, err := d.db.QueryContext(ctx, ` + SELECT + id, + creator_id, + created_ts, + updated_ts, + row_status, + name, + link, + title, + description, + visibility, + tag, + og_metadata + FROM shortcut + WHERE `+strings.Join(where, " AND ")+` + ORDER BY created_ts DESC`, + args..., + ) + if err != nil { + return nil, err + } + defer rows.Close() + + list := make([]*storepb.Shortcut, 0) + for rows.Next() { + shortcut := &storepb.Shortcut{} + var rowStatus, visibility, tags, openGraphMetadataString string + if err := rows.Scan( + &shortcut.Id, + &shortcut.CreatorId, + &shortcut.CreatedTs, + &shortcut.UpdatedTs, + &rowStatus, + &shortcut.Name, + &shortcut.Link, + &shortcut.Title, + &shortcut.Description, + &visibility, + &tags, + &openGraphMetadataString, + ); err != nil { + return nil, err + } + shortcut.RowStatus = convertRowStatusStringToStorepb(rowStatus) + shortcut.Visibility = storepb.Visibility(storepb.Visibility_value[visibility]) + shortcut.Tags = filterTags(strings.Split(tags, " ")) + var ogMetadata storepb.OpenGraphMetadata + if err := protojson.Unmarshal([]byte(openGraphMetadataString), &ogMetadata); err != nil { + return nil, err + } + shortcut.OgMetadata = &ogMetadata + list = append(list, shortcut) + } + + if err := rows.Err(); err != nil { + return nil, err + } + return list, nil +} + +func (d *DB) DeleteShortcut(ctx context.Context, delete *store.DeleteShortcut) error { + if _, err := d.db.ExecContext(ctx, `DELETE FROM shortcut WHERE id = ?`, delete.ID); err != nil { + return err + } + + return nil +} + +func vacuumShortcut(ctx context.Context, tx *sql.Tx) error { + stmt := ` + DELETE FROM + shortcut + WHERE + creator_id NOT IN ( + SELECT + id + FROM + user + )` + _, err := tx.ExecContext(ctx, stmt) + if err != nil { + return err + } + + return nil +} + +func filterTags(tags []string) []string { + result := []string{} + for _, tag := range tags { + if tag != "" { + result = append(result, tag) + } + } + return result +} + +func convertVisibilityStringToStorepb(visibility string) storepb.Visibility { + return storepb.Visibility(storepb.Visibility_value[visibility]) +} diff --git a/store/db/sqlite/sqlite.go b/store/db/sqlite/sqlite.go new file mode 100644 index 0000000..09c3e29 --- /dev/null +++ b/store/db/sqlite/sqlite.go @@ -0,0 +1,56 @@ +package sqlite + +import ( + "database/sql" + + "github.com/pkg/errors" + + "github.com/yourselfhosted/slash/server/profile" + "github.com/yourselfhosted/slash/store" +) + +type DB struct { + db *sql.DB + profile *profile.Profile +} + +// NewDB opens a database specified by its database driver name and a +// driver-specific data source name, usually consisting of at least a +// database name and connection information. +func NewDB(profile *profile.Profile) (store.Driver, error) { + // Ensure a DSN is set before attempting to open the database. + if profile.DSN == "" { + return nil, errors.New("dsn required") + } + + // Connect to the database with some sane settings: + // - No shared-cache: it's obsolete; WAL journal mode is a better solution. + // - No foreign key constraints: it's currently disabled by default, but it's a + // good practice to be explicit and prevent future surprises on SQLite upgrades. + // - Journal mode set to WAL: it's the recommended journal mode for most applications + // as it prevents locking issues. + // + // Notes: + // - When using the `modernc.org/sqlite` driver, each pragma must be prefixed with `_pragma=`. + // + // References: + // - https://pkg.go.dev/modernc.org/sqlite#Driver.Open + // - https://www.sqlite.org/sharedcache.html + // - https://www.sqlite.org/pragma.html + sqliteDB, err := sql.Open("sqlite", profile.DSN+"?_pragma=foreign_keys(0)&_pragma=busy_timeout(10000)&_pragma=journal_mode(WAL)") + if err != nil { + return nil, errors.Wrapf(err, "failed to open db with dsn: %s", profile.DSN) + } + + driver := DB{db: sqliteDB, profile: profile} + + return &driver, nil +} + +func (d *DB) GetDB() *sql.DB { + return d.db +} + +func (d *DB) Close() error { + return d.db.Close() +} diff --git a/store/db/sqlite/user.go b/store/db/sqlite/user.go new file mode 100644 index 0000000..310682a --- /dev/null +++ b/store/db/sqlite/user.go @@ -0,0 +1,176 @@ +package sqlite + +import ( + "context" + "errors" + "strings" + + "github.com/yourselfhosted/slash/store" +) + +func (d *DB) CreateUser(ctx context.Context, create *store.User) (*store.User, error) { + stmt := ` + INSERT INTO user ( + email, + nickname, + password_hash, + role + ) + VALUES (?, ?, ?, ?) + RETURNING id, created_ts, updated_ts, row_status + ` + if err := d.db.QueryRowContext(ctx, stmt, + create.Email, + create.Nickname, + create.PasswordHash, + create.Role, + ).Scan( + &create.ID, + &create.CreatedTs, + &create.UpdatedTs, + &create.RowStatus, + ); err != nil { + return nil, err + } + + user := create + return user, nil +} + +func (d *DB) UpdateUser(ctx context.Context, update *store.UpdateUser) (*store.User, error) { + set, args := []string{}, []any{} + if v := update.RowStatus; v != nil { + set, args = append(set, "row_status = ?"), append(args, *v) + } + if v := update.Email; v != nil { + set, args = append(set, "email = ?"), append(args, *v) + } + if v := update.Nickname; v != nil { + set, args = append(set, "nickname = ?"), append(args, *v) + } + if v := update.PasswordHash; v != nil { + set, args = append(set, "password_hash = ?"), append(args, *v) + } + if v := update.Role; v != nil { + set, args = append(set, "role = ?"), append(args, *v) + } + + if len(set) == 0 { + return nil, errors.New("no fields to update") + } + + stmt := ` + UPDATE user + SET ` + strings.Join(set, ", ") + ` + WHERE id = ? + RETURNING id, created_ts, updated_ts, row_status, email, nickname, password_hash, role + ` + args = append(args, update.ID) + user := &store.User{} + if err := d.db.QueryRowContext(ctx, stmt, args...).Scan( + &user.ID, + &user.CreatedTs, + &user.UpdatedTs, + &user.RowStatus, + &user.Email, + &user.Nickname, + &user.PasswordHash, + &user.Role, + ); err != nil { + return nil, err + } + + return user, nil +} + +func (d *DB) ListUsers(ctx context.Context, find *store.FindUser) ([]*store.User, error) { + where, args := []string{"1 = 1"}, []any{} + + if v := find.ID; v != nil { + where, args = append(where, "id = ?"), append(args, *v) + } + if v := find.RowStatus; v != nil { + where, args = append(where, "row_status = ?"), append(args, v.String()) + } + if v := find.Email; v != nil { + where, args = append(where, "email = ?"), append(args, *v) + } + if v := find.Nickname; v != nil { + where, args = append(where, "nickname = ?"), append(args, *v) + } + if v := find.Role; v != nil { + where, args = append(where, "role = ?"), append(args, *v) + } + + query := ` + SELECT + id, + created_ts, + updated_ts, + row_status, + email, + nickname, + password_hash, + role + FROM user + WHERE ` + strings.Join(where, " AND ") + ` + ORDER BY updated_ts DESC, created_ts DESC + ` + rows, err := d.db.QueryContext(ctx, query, args...) + if err != nil { + return nil, err + } + defer rows.Close() + + list := make([]*store.User, 0) + for rows.Next() { + user := &store.User{} + if err := rows.Scan( + &user.ID, + &user.CreatedTs, + &user.UpdatedTs, + &user.RowStatus, + &user.Email, + &user.Nickname, + &user.PasswordHash, + &user.Role, + ); err != nil { + return nil, err + } + list = append(list, user) + } + + if err := rows.Err(); err != nil { + return nil, err + } + + return list, nil +} + +func (d *DB) DeleteUser(ctx context.Context, delete *store.DeleteUser) error { + tx, err := d.db.BeginTx(ctx, nil) + if err != nil { + return err + } + defer tx.Rollback() + + if _, err := tx.ExecContext(ctx, ` + DELETE FROM user WHERE id = ? + `, delete.ID); err != nil { + return err + } + + if err := vacuumUserSetting(ctx, tx); err != nil { + return err + } + + if err := vacuumShortcut(ctx, tx); err != nil { + return err + } + + if err := vacuumMemo(ctx, tx); err != nil { + return err + } + + return tx.Commit() +} diff --git a/store/db/sqlite/user_setting.go b/store/db/sqlite/user_setting.go new file mode 100644 index 0000000..853efc9 --- /dev/null +++ b/store/db/sqlite/user_setting.go @@ -0,0 +1,128 @@ +package sqlite + +import ( + "context" + "database/sql" + "errors" + "strings" + + "google.golang.org/protobuf/encoding/protojson" + + storepb "github.com/yourselfhosted/slash/proto/gen/store" + "github.com/yourselfhosted/slash/store" +) + +func (d *DB) UpsertUserSetting(ctx context.Context, upsert *storepb.UserSetting) (*storepb.UserSetting, error) { + stmt := ` + INSERT INTO user_setting ( + user_id, key, value + ) + VALUES (?, ?, ?) + ON CONFLICT(user_id, key) DO UPDATE + SET value = EXCLUDED.value + ` + var valueString string + if upsert.Key == storepb.UserSettingKey_USER_SETTING_ACCESS_TOKENS { + valueBytes, err := protojson.Marshal(upsert.GetAccessTokens()) + if err != nil { + return nil, err + } + valueString = string(valueBytes) + } else if upsert.Key == storepb.UserSettingKey_USER_SETTING_LOCALE { + valueString = upsert.GetLocale().String() + } else if upsert.Key == storepb.UserSettingKey_USER_SETTING_COLOR_THEME { + valueString = upsert.GetColorTheme().String() + } else { + return nil, errors.New("invalid user setting key") + } + + if _, err := d.db.ExecContext(ctx, stmt, upsert.UserId, upsert.Key.String(), valueString); err != nil { + return nil, err + } + + userSettingMessage := upsert + return userSettingMessage, nil +} + +func (d *DB) ListUserSettings(ctx context.Context, find *store.FindUserSetting) ([]*storepb.UserSetting, error) { + where, args := []string{"1 = 1"}, []any{} + + if v := find.Key; v != storepb.UserSettingKey_USER_SETTING_KEY_UNSPECIFIED { + where, args = append(where, "key = ?"), append(args, v.String()) + } + if v := find.UserID; v != nil { + where, args = append(where, "user_id = ?"), append(args, *find.UserID) + } + + query := ` + SELECT + user_id, + key, + value + FROM user_setting + WHERE ` + strings.Join(where, " AND ") + rows, err := d.db.QueryContext(ctx, query, args...) + if err != nil { + return nil, err + } + defer rows.Close() + + userSettingList := make([]*storepb.UserSetting, 0) + for rows.Next() { + userSetting := &storepb.UserSetting{} + var keyString, valueString string + if err := rows.Scan( + &userSetting.UserId, + &keyString, + &valueString, + ); err != nil { + return nil, err + } + userSetting.Key = storepb.UserSettingKey(storepb.UserSettingKey_value[keyString]) + if userSetting.Key == storepb.UserSettingKey_USER_SETTING_ACCESS_TOKENS { + accessTokensUserSetting := &storepb.AccessTokensUserSetting{} + if err := protojson.Unmarshal([]byte(valueString), accessTokensUserSetting); err != nil { + return nil, err + } + userSetting.Value = &storepb.UserSetting_AccessTokens{ + AccessTokens: accessTokensUserSetting, + } + } else if userSetting.Key == storepb.UserSettingKey_USER_SETTING_LOCALE { + userSetting.Value = &storepb.UserSetting_Locale{ + Locale: storepb.LocaleUserSetting(storepb.LocaleUserSetting_value[valueString]), + } + } else if userSetting.Key == storepb.UserSettingKey_USER_SETTING_COLOR_THEME { + userSetting.Value = &storepb.UserSetting_ColorTheme{ + ColorTheme: storepb.ColorThemeUserSetting(storepb.ColorThemeUserSetting_value[valueString]), + } + } else { + return nil, errors.New("invalid user setting key") + } + userSettingList = append(userSettingList, userSetting) + } + + if err := rows.Err(); err != nil { + return nil, err + } + + return userSettingList, nil +} + +func vacuumUserSetting(ctx context.Context, tx *sql.Tx) error { + stmt := ` + DELETE FROM + user_setting + WHERE + user_id NOT IN ( + SELECT + id + FROM + user + )` + _, err := tx.ExecContext(ctx, stmt) + if err != nil { + return err + } + + return nil +} diff --git a/store/db/sqlite/workspace_setting.go b/store/db/sqlite/workspace_setting.go new file mode 100644 index 0000000..9950047 --- /dev/null +++ b/store/db/sqlite/workspace_setting.go @@ -0,0 +1,116 @@ +package sqlite + +import ( + "context" + "errors" + "strconv" + "strings" + + "google.golang.org/protobuf/encoding/protojson" + + storepb "github.com/yourselfhosted/slash/proto/gen/store" + "github.com/yourselfhosted/slash/store" +) + +func (d *DB) UpsertWorkspaceSetting(ctx context.Context, upsert *storepb.WorkspaceSetting) (*storepb.WorkspaceSetting, error) { + stmt := ` + INSERT INTO workspace_setting ( + key, + value + ) + VALUES (?, ?) + ON CONFLICT(key) DO UPDATE + SET value = EXCLUDED.value + ` + var valueString string + if upsert.Key == storepb.WorkspaceSettingKey_WORKSPACE_SETTING_LICENSE_KEY { + valueString = upsert.GetLicenseKey() + } else if upsert.Key == storepb.WorkspaceSettingKey_WORKSPACE_SETTING_SECRET_SESSION { + valueString = upsert.GetSecretSession() + } else if upsert.Key == storepb.WorkspaceSettingKey_WORKSAPCE_SETTING_ENABLE_SIGNUP { + valueString = strconv.FormatBool(upsert.GetEnableSignup()) + } else if upsert.Key == storepb.WorkspaceSettingKey_WORKSPACE_SETTING_CUSTOM_STYLE { + valueString = upsert.GetCustomStyle() + } else if upsert.Key == storepb.WorkspaceSettingKey_WORKSPACE_SETTING_CUSTOM_SCRIPT { + valueString = upsert.GetCustomScript() + } else if upsert.Key == storepb.WorkspaceSettingKey_WORKSPACE_SETTING_AUTO_BACKUP { + valueBytes, err := protojson.Marshal(upsert.GetAutoBackup()) + if err != nil { + return nil, err + } + valueString = string(valueBytes) + } else { + return nil, errors.New("invalid workspace setting key") + } + + if _, err := d.db.ExecContext(ctx, stmt, upsert.Key.String(), valueString); err != nil { + return nil, err + } + + workspaceSetting := upsert + return workspaceSetting, nil +} + +func (d *DB) ListWorkspaceSettings(ctx context.Context, find *store.FindWorkspaceSetting) ([]*storepb.WorkspaceSetting, error) { + where, args := []string{"1 = 1"}, []any{} + + if find.Key != storepb.WorkspaceSettingKey_WORKSPACE_SETTING_KEY_UNSPECIFIED { + where, args = append(where, "key = ?"), append(args, find.Key.String()) + } + + query := ` + SELECT + key, + value + FROM workspace_setting + WHERE ` + strings.Join(where, " AND ") + rows, err := d.db.QueryContext(ctx, query, args...) + if err != nil { + return nil, err + } + + defer rows.Close() + + list := []*storepb.WorkspaceSetting{} + for rows.Next() { + workspaceSetting := &storepb.WorkspaceSetting{} + var keyString, valueString string + if err := rows.Scan( + &keyString, + &valueString, + ); err != nil { + return nil, err + } + workspaceSetting.Key = storepb.WorkspaceSettingKey(storepb.WorkspaceSettingKey_value[keyString]) + if workspaceSetting.Key == storepb.WorkspaceSettingKey_WORKSPACE_SETTING_LICENSE_KEY { + workspaceSetting.Value = &storepb.WorkspaceSetting_LicenseKey{LicenseKey: valueString} + } else if workspaceSetting.Key == storepb.WorkspaceSettingKey_WORKSPACE_SETTING_SECRET_SESSION { + workspaceSetting.Value = &storepb.WorkspaceSetting_SecretSession{SecretSession: valueString} + } else if workspaceSetting.Key == storepb.WorkspaceSettingKey_WORKSAPCE_SETTING_ENABLE_SIGNUP { + enableSignup, err := strconv.ParseBool(valueString) + if err != nil { + return nil, err + } + workspaceSetting.Value = &storepb.WorkspaceSetting_EnableSignup{EnableSignup: enableSignup} + } else if workspaceSetting.Key == storepb.WorkspaceSettingKey_WORKSPACE_SETTING_CUSTOM_STYLE { + workspaceSetting.Value = &storepb.WorkspaceSetting_CustomStyle{CustomStyle: valueString} + } else if workspaceSetting.Key == storepb.WorkspaceSettingKey_WORKSPACE_SETTING_CUSTOM_SCRIPT { + workspaceSetting.Value = &storepb.WorkspaceSetting_CustomScript{CustomScript: valueString} + } else if workspaceSetting.Key == storepb.WorkspaceSettingKey_WORKSPACE_SETTING_AUTO_BACKUP { + autoBackupSetting := &storepb.AutoBackupWorkspaceSetting{} + if err := protojson.Unmarshal([]byte(valueString), autoBackupSetting); err != nil { + return nil, err + } + workspaceSetting.Value = &storepb.WorkspaceSetting_AutoBackup{AutoBackup: autoBackupSetting} + } else { + continue + } + list = append(list, workspaceSetting) + } + + if err := rows.Err(); err != nil { + return nil, err + } + + return list, nil +} diff --git a/store/driver.go b/store/driver.go new file mode 100644 index 0000000..b29656e --- /dev/null +++ b/store/driver.go @@ -0,0 +1,57 @@ +package store + +import ( + "context" + "database/sql" + + storepb "github.com/yourselfhosted/slash/proto/gen/store" +) + +// Driver is an interface for store driver. +// It contains all methods that store database driver should implement. +type Driver interface { + GetDB() *sql.DB + Close() error + + Migrate(ctx context.Context) error + + // MigrationHistory model related methods. + UpsertMigrationHistory(ctx context.Context, upsert *UpsertMigrationHistory) (*MigrationHistory, error) + ListMigrationHistories(ctx context.Context, find *FindMigrationHistory) ([]*MigrationHistory, error) + + // Activity model related methods. + CreateActivity(ctx context.Context, create *Activity) (*Activity, error) + ListActivities(ctx context.Context, find *FindActivity) ([]*Activity, error) + + // Collection model related methods. + CreateCollection(ctx context.Context, create *storepb.Collection) (*storepb.Collection, error) + UpdateCollection(ctx context.Context, update *UpdateCollection) (*storepb.Collection, error) + ListCollections(ctx context.Context, find *FindCollection) ([]*storepb.Collection, error) + DeleteCollection(ctx context.Context, delete *DeleteCollection) error + + // Memo model related methods. + CreateMemo(ctx context.Context, create *storepb.Memo) (*storepb.Memo, error) + UpdateMemo(ctx context.Context, update *UpdateMemo) (*storepb.Memo, error) + ListMemos(ctx context.Context, find *FindMemo) ([]*storepb.Memo, error) + DeleteMemo(ctx context.Context, delete *DeleteMemo) error + + // Shortcut model related methods. + CreateShortcut(ctx context.Context, create *storepb.Shortcut) (*storepb.Shortcut, error) + UpdateShortcut(ctx context.Context, update *UpdateShortcut) (*storepb.Shortcut, error) + ListShortcuts(ctx context.Context, find *FindShortcut) ([]*storepb.Shortcut, error) + DeleteShortcut(ctx context.Context, delete *DeleteShortcut) error + + // User model related methods. + CreateUser(ctx context.Context, create *User) (*User, error) + UpdateUser(ctx context.Context, update *UpdateUser) (*User, error) + ListUsers(ctx context.Context, find *FindUser) ([]*User, error) + DeleteUser(ctx context.Context, delete *DeleteUser) error + + // UserSetting model related methods. + UpsertUserSetting(ctx context.Context, upsert *storepb.UserSetting) (*storepb.UserSetting, error) + ListUserSettings(ctx context.Context, find *FindUserSetting) ([]*storepb.UserSetting, error) + + // WorkspaceSetting model related methods. + UpsertWorkspaceSetting(ctx context.Context, upsert *storepb.WorkspaceSetting) (*storepb.WorkspaceSetting, error) + ListWorkspaceSettings(ctx context.Context, find *FindWorkspaceSetting) ([]*storepb.WorkspaceSetting, error) +} diff --git a/store/memo.go b/store/memo.go index 2d92e2d..4035f48 100644 --- a/store/memo.go +++ b/store/memo.go @@ -2,11 +2,6 @@ package store import ( "context" - "database/sql" - "fmt" - "strings" - - "github.com/pkg/errors" storepb "github.com/yourselfhosted/slash/proto/gen/store" ) @@ -35,162 +30,15 @@ type DeleteMemo struct { } func (s *Store) CreateMemo(ctx context.Context, create *storepb.Memo) (*storepb.Memo, error) { - set := []string{"creator_id", "name", "title", "content", "visibility", "tag"} - args := []any{create.CreatorId, create.Name, create.Title, create.Content, create.Visibility.String(), strings.Join(create.Tags, " ")} - - stmt := ` - INSERT INTO memo ( - ` + strings.Join(set, ", ") + ` - ) - VALUES (` + placeholders(len(args)) + `) - RETURNING id, created_ts, updated_ts, row_status - ` - - var rowStatus string - if err := s.db.QueryRowContext(ctx, stmt, args...).Scan( - &create.Id, - &create.CreatedTs, - &create.UpdatedTs, - &rowStatus, - ); err != nil { - return nil, err - } - create.RowStatus = convertRowStatusStringToStorepb(rowStatus) - memo := create - return memo, nil + return s.driver.CreateMemo(ctx, create) } func (s *Store) UpdateMemo(ctx context.Context, update *UpdateMemo) (*storepb.Memo, error) { - set, args := []string{}, []any{} - if update.RowStatus != nil { - set, args = append(set, "row_status = ?"), append(args, update.RowStatus.String()) - } - if update.Name != nil { - set, args = append(set, "name = ?"), append(args, *update.Name) - } - if update.Title != nil { - set, args = append(set, "title = ?"), append(args, *update.Title) - } - if update.Content != nil { - set, args = append(set, "content = ?"), append(args, *update.Content) - } - if update.Visibility != nil { - set, args = append(set, "visibility = ?"), append(args, update.Visibility.String()) - } - if update.Tag != nil { - set, args = append(set, "tag = ?"), append(args, *update.Tag) - } - if len(set) == 0 { - return nil, errors.New("no update specified") - } - args = append(args, update.ID) - - stmt := ` - UPDATE memo - SET - ` + strings.Join(set, ", ") + ` - WHERE - id = ? - RETURNING id, creator_id, created_ts, updated_ts, row_status, name, title, content, visibility, tag - ` - memo := &storepb.Memo{} - var rowStatus, visibility, tags string - if err := s.db.QueryRowContext(ctx, stmt, args...).Scan( - &memo.Id, - &memo.CreatorId, - &memo.CreatedTs, - &memo.UpdatedTs, - &rowStatus, - &memo.Name, - &memo.Title, - &memo.Content, - &visibility, - &tags, - ); err != nil { - return nil, err - } - memo.RowStatus = convertRowStatusStringToStorepb(rowStatus) - memo.Visibility = convertVisibilityStringToStorepb(visibility) - memo.Tags = filterTags(strings.Split(tags, " ")) - return memo, nil + return s.driver.UpdateMemo(ctx, update) } func (s *Store) ListMemos(ctx context.Context, find *FindMemo) ([]*storepb.Memo, error) { - where, args := []string{"1 = 1"}, []any{} - if v := find.ID; v != nil { - where, args = append(where, "id = ?"), append(args, *v) - } - if v := find.CreatorID; v != nil { - where, args = append(where, "creator_id = ?"), append(args, *v) - } - if v := find.RowStatus; v != nil { - where, args = append(where, "row_status = ?"), append(args, *v) - } - if v := find.Name; v != nil { - where, args = append(where, "name = ?"), append(args, *v) - } - if v := find.VisibilityList; len(v) != 0 { - list := []string{} - for _, visibility := range v { - list = append(list, fmt.Sprintf("$%d", len(args)+1)) - args = append(args, visibility) - } - where = append(where, fmt.Sprintf("visibility in (%s)", strings.Join(list, ","))) - } - if v := find.Tag; v != nil { - where, args = append(where, "tag LIKE ?"), append(args, "%"+*v+"%") - } - - rows, err := s.db.QueryContext(ctx, ` - SELECT - id, - creator_id, - created_ts, - updated_ts, - row_status, - name, - title, - content, - visibility, - tag - FROM memo - WHERE `+strings.Join(where, " AND ")+` - ORDER BY created_ts DESC`, - args..., - ) - if err != nil { - return nil, err - } - defer rows.Close() - - list := make([]*storepb.Memo, 0) - for rows.Next() { - memo := &storepb.Memo{} - var rowStatus, visibility, tags string - if err := rows.Scan( - &memo.Id, - &memo.CreatorId, - &memo.CreatedTs, - &memo.UpdatedTs, - &rowStatus, - &memo.Name, - &memo.Title, - &memo.Content, - &visibility, - &tags, - ); err != nil { - return nil, err - } - memo.RowStatus = convertRowStatusStringToStorepb(rowStatus) - memo.Visibility = storepb.Visibility(storepb.Visibility_value[visibility]) - memo.Tags = filterTags(strings.Split(tags, " ")) - list = append(list, memo) - } - - if err := rows.Err(); err != nil { - return nil, err - } - return list, nil + return s.driver.ListMemos(ctx, find) } func (s *Store) GetMemo(ctx context.Context, find *FindMemo) (*storepb.Memo, error) { @@ -208,31 +56,5 @@ func (s *Store) GetMemo(ctx context.Context, find *FindMemo) (*storepb.Memo, err } func (s *Store) DeleteMemo(ctx context.Context, delete *DeleteMemo) error { - if _, err := s.db.ExecContext(ctx, `DELETE FROM memo WHERE id = ?`, delete.ID); err != nil { - return err - } - return nil -} - -func vacuumMemo(ctx context.Context, tx *sql.Tx) error { - stmt := ` - DELETE FROM - memo - WHERE - creator_id NOT IN ( - SELECT - id - FROM - user - )` - _, err := tx.ExecContext(ctx, stmt) - if err != nil { - return err - } - - return nil -} - -func placeholders(n int) string { - return strings.Repeat("?,", n-1) + "?" + return s.driver.DeleteMemo(ctx, delete) } diff --git a/store/migration_history.go b/store/migration_history.go new file mode 100644 index 0000000..693663b --- /dev/null +++ b/store/migration_history.go @@ -0,0 +1,13 @@ +package store + +type MigrationHistory struct { + Version string + CreatedTs int64 +} + +type UpsertMigrationHistory struct { + Version string +} + +type FindMigrationHistory struct { +} diff --git a/store/shortcut.go b/store/shortcut.go index 13e65b9..2b43fc5 100644 --- a/store/shortcut.go +++ b/store/shortcut.go @@ -2,12 +2,6 @@ package store import ( "context" - "database/sql" - "fmt" - "strings" - - "github.com/pkg/errors" - "google.golang.org/protobuf/encoding/protojson" storepb "github.com/yourselfhosted/slash/proto/gen/store" ) @@ -39,198 +33,28 @@ type DeleteShortcut struct { } func (s *Store) CreateShortcut(ctx context.Context, create *storepb.Shortcut) (*storepb.Shortcut, error) { - set := []string{"creator_id", "name", "link", "title", "description", "visibility", "tag"} - args := []any{create.CreatorId, create.Name, create.Link, create.Title, create.Description, create.Visibility.String(), strings.Join(create.Tags, " ")} - placeholder := []string{"?", "?", "?", "?", "?", "?", "?"} - if create.OgMetadata != nil { - set = append(set, "og_metadata") - openGraphMetadataBytes, err := protojson.Marshal(create.OgMetadata) - if err != nil { - return nil, err - } - args = append(args, string(openGraphMetadataBytes)) - placeholder = append(placeholder, "?") - } - - stmt := ` - INSERT INTO shortcut ( - ` + strings.Join(set, ", ") + ` - ) - VALUES (` + strings.Join(placeholder, ",") + `) - RETURNING id, created_ts, updated_ts, row_status - ` - var rowStatus string - if err := s.db.QueryRowContext(ctx, stmt, args...).Scan( - &create.Id, - &create.CreatedTs, - &create.UpdatedTs, - &rowStatus, - ); err != nil { + shortcut, err := s.driver.CreateShortcut(ctx, create) + if err != nil { return nil, err } - create.RowStatus = convertRowStatusStringToStorepb(rowStatus) - shortcut := create s.shortcutCache.Store(shortcut.Id, shortcut) return shortcut, nil } func (s *Store) UpdateShortcut(ctx context.Context, update *UpdateShortcut) (*storepb.Shortcut, error) { - set, args := []string{}, []any{} - if update.RowStatus != nil { - set, args = append(set, "row_status = ?"), append(args, update.RowStatus.String()) - } - if update.Name != nil { - set, args = append(set, "name = ?"), append(args, *update.Name) - } - if update.Link != nil { - set, args = append(set, "link = ?"), append(args, *update.Link) - } - if update.Title != nil { - set, args = append(set, "title = ?"), append(args, *update.Title) - } - if update.Description != nil { - set, args = append(set, "description = ?"), append(args, *update.Description) - } - if update.Visibility != nil { - set, args = append(set, "visibility = ?"), append(args, update.Visibility.String()) - } - if update.Tag != nil { - set, args = append(set, "tag = ?"), append(args, *update.Tag) - } - if update.OpenGraphMetadata != nil { - openGraphMetadataBytes, err := protojson.Marshal(update.OpenGraphMetadata) - if err != nil { - return nil, errors.Wrap(err, "Failed to marshal activity payload") - } - set, args = append(set, "og_metadata = ?"), append(args, string(openGraphMetadataBytes)) - } - if len(set) == 0 { - return nil, errors.New("no update specified") - } - args = append(args, update.ID) - - stmt := ` - UPDATE shortcut - SET - ` + strings.Join(set, ", ") + ` - WHERE - id = ? - RETURNING id, creator_id, created_ts, updated_ts, row_status, name, link, title, description, visibility, tag, og_metadata - ` - shortcut := &storepb.Shortcut{} - var rowStatus, visibility, tags, openGraphMetadataString string - if err := s.db.QueryRowContext(ctx, stmt, args...).Scan( - &shortcut.Id, - &shortcut.CreatorId, - &shortcut.CreatedTs, - &shortcut.UpdatedTs, - &rowStatus, - &shortcut.Name, - &shortcut.Link, - &shortcut.Title, - &shortcut.Description, - &visibility, - &tags, - &openGraphMetadataString, - ); err != nil { + shortcut, err := s.driver.UpdateShortcut(ctx, update) + if err != nil { return nil, err } - shortcut.RowStatus = convertRowStatusStringToStorepb(rowStatus) - shortcut.Visibility = convertVisibilityStringToStorepb(visibility) - shortcut.Tags = filterTags(strings.Split(tags, " ")) - var ogMetadata storepb.OpenGraphMetadata - if err := protojson.Unmarshal([]byte(openGraphMetadataString), &ogMetadata); err != nil { - return nil, err - } - shortcut.OgMetadata = &ogMetadata s.shortcutCache.Store(shortcut.Id, shortcut) return shortcut, nil } func (s *Store) ListShortcuts(ctx context.Context, find *FindShortcut) ([]*storepb.Shortcut, error) { - where, args := []string{"1 = 1"}, []any{} - if v := find.ID; v != nil { - where, args = append(where, "id = ?"), append(args, *v) - } - if v := find.CreatorID; v != nil { - where, args = append(where, "creator_id = ?"), append(args, *v) - } - if v := find.RowStatus; v != nil { - where, args = append(where, "row_status = ?"), append(args, *v) - } - if v := find.Name; v != nil { - where, args = append(where, "name = ?"), append(args, *v) - } - if v := find.VisibilityList; len(v) != 0 { - list := []string{} - for _, visibility := range v { - list = append(list, fmt.Sprintf("$%d", len(args)+1)) - args = append(args, visibility) - } - where = append(where, fmt.Sprintf("visibility in (%s)", strings.Join(list, ","))) - } - if v := find.Tag; v != nil { - where, args = append(where, "tag LIKE ?"), append(args, "%"+*v+"%") - } - - rows, err := s.db.QueryContext(ctx, ` - SELECT - id, - creator_id, - created_ts, - updated_ts, - row_status, - name, - link, - title, - description, - visibility, - tag, - og_metadata - FROM shortcut - WHERE `+strings.Join(where, " AND ")+` - ORDER BY created_ts DESC`, - args..., - ) + list, err := s.driver.ListShortcuts(ctx, find) if err != nil { return nil, err } - defer rows.Close() - - list := make([]*storepb.Shortcut, 0) - for rows.Next() { - shortcut := &storepb.Shortcut{} - var rowStatus, visibility, tags, openGraphMetadataString string - if err := rows.Scan( - &shortcut.Id, - &shortcut.CreatorId, - &shortcut.CreatedTs, - &shortcut.UpdatedTs, - &rowStatus, - &shortcut.Name, - &shortcut.Link, - &shortcut.Title, - &shortcut.Description, - &visibility, - &tags, - &openGraphMetadataString, - ); err != nil { - return nil, err - } - shortcut.RowStatus = convertRowStatusStringToStorepb(rowStatus) - shortcut.Visibility = storepb.Visibility(storepb.Visibility_value[visibility]) - shortcut.Tags = filterTags(strings.Split(tags, " ")) - var ogMetadata storepb.OpenGraphMetadata - if err := protojson.Unmarshal([]byte(openGraphMetadataString), &ogMetadata); err != nil { - return nil, err - } - shortcut.OgMetadata = &ogMetadata - list = append(list, shortcut) - } - - if err := rows.Err(); err != nil { - return nil, err - } for _, shortcut := range list { s.shortcutCache.Store(shortcut.Id, shortcut) } @@ -259,44 +83,10 @@ func (s *Store) GetShortcut(ctx context.Context, find *FindShortcut) (*storepb.S } func (s *Store) DeleteShortcut(ctx context.Context, delete *DeleteShortcut) error { - if _, err := s.db.ExecContext(ctx, `DELETE FROM shortcut WHERE id = ?`, delete.ID); err != nil { + if err := s.driver.DeleteShortcut(ctx, delete); err != nil { return err } s.shortcutCache.Delete(delete.ID) - return nil } - -func vacuumShortcut(ctx context.Context, tx *sql.Tx) error { - stmt := ` - DELETE FROM - shortcut - WHERE - creator_id NOT IN ( - SELECT - id - FROM - user - )` - _, err := tx.ExecContext(ctx, stmt) - if err != nil { - return err - } - - return nil -} - -func filterTags(tags []string) []string { - result := []string{} - for _, tag := range tags { - if tag != "" { - result = append(result, tag) - } - } - return result -} - -func convertVisibilityStringToStorepb(visibility string) storepb.Visibility { - return storepb.Visibility(storepb.Visibility_value[visibility]) -} diff --git a/store/store.go b/store/store.go index 8d6c001..40f59ea 100644 --- a/store/store.go +++ b/store/store.go @@ -1,7 +1,6 @@ package store import ( - "database/sql" "sync" "github.com/yourselfhosted/slash/server/profile" @@ -9,8 +8,8 @@ import ( // Store provides database access to all raw objects. type Store struct { - db *sql.DB profile *profile.Profile + driver Driver workspaceSettingCache sync.Map // map[string]*WorkspaceSetting userCache sync.Map // map[int]*User @@ -19,14 +18,14 @@ type Store struct { } // New creates a new instance of Store. -func New(db *sql.DB, profile *profile.Profile) *Store { +func New(driver Driver, profile *profile.Profile) *Store { return &Store{ - db: db, + driver: driver, profile: profile, } } // Close closes the database connection. func (s *Store) Close() error { - return s.db.Close() + return s.driver.Close() } diff --git a/store/user.go b/store/user.go index ddcb77e..fa2442e 100644 --- a/store/user.go +++ b/store/user.go @@ -2,8 +2,6 @@ package store import ( "context" - "errors" - "strings" ) // Role is the type of a role. @@ -54,147 +52,31 @@ type DeleteUser struct { } func (s *Store) CreateUser(ctx context.Context, create *User) (*User, error) { - stmt := ` - INSERT INTO user ( - email, - nickname, - password_hash, - role - ) - VALUES (?, ?, ?, ?) - RETURNING id, created_ts, updated_ts, row_status - ` - if err := s.db.QueryRowContext(ctx, stmt, - create.Email, - create.Nickname, - create.PasswordHash, - create.Role, - ).Scan( - &create.ID, - &create.CreatedTs, - &create.UpdatedTs, - &create.RowStatus, - ); err != nil { + user, err := s.driver.CreateUser(ctx, create) + if err != nil { return nil, err } - - user := create s.userCache.Store(user.ID, user) return user, nil } func (s *Store) UpdateUser(ctx context.Context, update *UpdateUser) (*User, error) { - set, args := []string{}, []any{} - if v := update.RowStatus; v != nil { - set, args = append(set, "row_status = ?"), append(args, *v) - } - if v := update.Email; v != nil { - set, args = append(set, "email = ?"), append(args, *v) - } - if v := update.Nickname; v != nil { - set, args = append(set, "nickname = ?"), append(args, *v) - } - if v := update.PasswordHash; v != nil { - set, args = append(set, "password_hash = ?"), append(args, *v) - } - if v := update.Role; v != nil { - set, args = append(set, "role = ?"), append(args, *v) - } - - if len(set) == 0 { - return nil, errors.New("no fields to update") - } - - stmt := ` - UPDATE user - SET ` + strings.Join(set, ", ") + ` - WHERE id = ? - RETURNING id, created_ts, updated_ts, row_status, email, nickname, password_hash, role - ` - args = append(args, update.ID) - user := &User{} - if err := s.db.QueryRowContext(ctx, stmt, args...).Scan( - &user.ID, - &user.CreatedTs, - &user.UpdatedTs, - &user.RowStatus, - &user.Email, - &user.Nickname, - &user.PasswordHash, - &user.Role, - ); err != nil { + user, err := s.driver.UpdateUser(ctx, update) + if err != nil { return nil, err } - s.userCache.Store(user.ID, user) return user, nil } func (s *Store) ListUsers(ctx context.Context, find *FindUser) ([]*User, error) { - where, args := []string{"1 = 1"}, []any{} - - if v := find.ID; v != nil { - where, args = append(where, "id = ?"), append(args, *v) - } - if v := find.RowStatus; v != nil { - where, args = append(where, "row_status = ?"), append(args, v.String()) - } - if v := find.Email; v != nil { - where, args = append(where, "email = ?"), append(args, *v) - } - if v := find.Nickname; v != nil { - where, args = append(where, "nickname = ?"), append(args, *v) - } - if v := find.Role; v != nil { - where, args = append(where, "role = ?"), append(args, *v) - } - - query := ` - SELECT - id, - created_ts, - updated_ts, - row_status, - email, - nickname, - password_hash, - role - FROM user - WHERE ` + strings.Join(where, " AND ") + ` - ORDER BY updated_ts DESC, created_ts DESC - ` - rows, err := s.db.QueryContext(ctx, query, args...) + list, err := s.driver.ListUsers(ctx, find) if err != nil { return nil, err } - defer rows.Close() - - list := make([]*User, 0) - for rows.Next() { - user := &User{} - if err := rows.Scan( - &user.ID, - &user.CreatedTs, - &user.UpdatedTs, - &user.RowStatus, - &user.Email, - &user.Nickname, - &user.PasswordHash, - &user.Role, - ); err != nil { - return nil, err - } - list = append(list, user) - } - - if err := rows.Err(); err != nil { - return nil, err - } - for _, user := range list { s.userCache.Store(user.ID, user) } - return list, nil } @@ -218,35 +100,10 @@ func (s *Store) GetUser(ctx context.Context, find *FindUser) (*User, error) { } func (s *Store) DeleteUser(ctx context.Context, delete *DeleteUser) error { - tx, err := s.db.BeginTx(ctx, nil) - if err != nil { - return err - } - defer tx.Rollback() - - if _, err := tx.ExecContext(ctx, ` - DELETE FROM user WHERE id = ? - `, delete.ID); err != nil { - return err - } - - if err := vacuumUserSetting(ctx, tx); err != nil { - return err - } - - if err := vacuumShortcut(ctx, tx); err != nil { - return err - } - - if err := vacuumMemo(ctx, tx); err != nil { - return err - } - - if err := tx.Commit(); err != nil { + if err := s.driver.DeleteUser(ctx, delete); err != nil { return err } s.userCache.Delete(delete.ID) - return nil } diff --git a/store/user_setting.go b/store/user_setting.go index 01dc552..3bcd703 100644 --- a/store/user_setting.go +++ b/store/user_setting.go @@ -2,11 +2,6 @@ package store import ( "context" - "database/sql" - "errors" - "strings" - - "google.golang.org/protobuf/encoding/protojson" storepb "github.com/yourselfhosted/slash/proto/gen/store" ) @@ -17,96 +12,17 @@ type FindUserSetting struct { } func (s *Store) UpsertUserSetting(ctx context.Context, upsert *storepb.UserSetting) (*storepb.UserSetting, error) { - stmt := ` - INSERT INTO user_setting ( - user_id, key, value - ) - VALUES (?, ?, ?) - ON CONFLICT(user_id, key) DO UPDATE - SET value = EXCLUDED.value - ` - var valueString string - if upsert.Key == storepb.UserSettingKey_USER_SETTING_ACCESS_TOKENS { - valueBytes, err := protojson.Marshal(upsert.GetAccessTokens()) - if err != nil { - return nil, err - } - valueString = string(valueBytes) - } else if upsert.Key == storepb.UserSettingKey_USER_SETTING_LOCALE { - valueString = upsert.GetLocale().String() - } else if upsert.Key == storepb.UserSettingKey_USER_SETTING_COLOR_THEME { - valueString = upsert.GetColorTheme().String() - } else { - return nil, errors.New("invalid user setting key") - } - - if _, err := s.db.ExecContext(ctx, stmt, upsert.UserId, upsert.Key.String(), valueString); err != nil { - return nil, err - } - - userSettingMessage := upsert - s.userSettingCache.Store(getUserSettingCacheKey(userSettingMessage.UserId, userSettingMessage.Key.String()), userSettingMessage) - return userSettingMessage, nil -} - -func (s *Store) ListUserSettings(ctx context.Context, find *FindUserSetting) ([]*storepb.UserSetting, error) { - where, args := []string{"1 = 1"}, []any{} - - if v := find.Key; v != storepb.UserSettingKey_USER_SETTING_KEY_UNSPECIFIED { - where, args = append(where, "key = ?"), append(args, v.String()) - } - if v := find.UserID; v != nil { - where, args = append(where, "user_id = ?"), append(args, *find.UserID) - } - - query := ` - SELECT - user_id, - key, - value - FROM user_setting - WHERE ` + strings.Join(where, " AND ") - rows, err := s.db.QueryContext(ctx, query, args...) + userSetting, err := s.driver.UpsertUserSetting(ctx, upsert) if err != nil { return nil, err } - defer rows.Close() + s.userSettingCache.Store(getUserSettingCacheKey(userSetting.UserId, userSetting.Key.String()), userSetting) + return userSetting, nil +} - userSettingList := make([]*storepb.UserSetting, 0) - for rows.Next() { - userSetting := &storepb.UserSetting{} - var keyString, valueString string - if err := rows.Scan( - &userSetting.UserId, - &keyString, - &valueString, - ); err != nil { - return nil, err - } - userSetting.Key = storepb.UserSettingKey(storepb.UserSettingKey_value[keyString]) - if userSetting.Key == storepb.UserSettingKey_USER_SETTING_ACCESS_TOKENS { - accessTokensUserSetting := &storepb.AccessTokensUserSetting{} - if err := protojson.Unmarshal([]byte(valueString), accessTokensUserSetting); err != nil { - return nil, err - } - userSetting.Value = &storepb.UserSetting_AccessTokens{ - AccessTokens: accessTokensUserSetting, - } - } else if userSetting.Key == storepb.UserSettingKey_USER_SETTING_LOCALE { - userSetting.Value = &storepb.UserSetting_Locale{ - Locale: storepb.LocaleUserSetting(storepb.LocaleUserSetting_value[valueString]), - } - } else if userSetting.Key == storepb.UserSettingKey_USER_SETTING_COLOR_THEME { - userSetting.Value = &storepb.UserSetting_ColorTheme{ - ColorTheme: storepb.ColorThemeUserSetting(storepb.ColorThemeUserSetting_value[valueString]), - } - } else { - return nil, errors.New("invalid user setting key") - } - userSettingList = append(userSettingList, userSetting) - } - - if err := rows.Err(); err != nil { +func (s *Store) ListUserSettings(ctx context.Context, find *FindUserSetting) ([]*storepb.UserSetting, error) { + userSettingList, err := s.driver.ListUserSettings(ctx, find) + if err != nil { return nil, err } @@ -137,25 +53,6 @@ func (s *Store) GetUserSetting(ctx context.Context, find *FindUserSetting) (*sto return userSetting, nil } -func vacuumUserSetting(ctx context.Context, tx *sql.Tx) error { - stmt := ` - DELETE FROM - user_setting - WHERE - user_id NOT IN ( - SELECT - id - FROM - user - )` - _, err := tx.ExecContext(ctx, stmt) - if err != nil { - return err - } - - return nil -} - // GetUserAccessTokens returns the access tokens of the user. func (s *Store) GetUserAccessTokens(ctx context.Context, userID int32) ([]*storepb.AccessTokensUserSetting_AccessToken, error) { userSetting, err := s.GetUserSetting(ctx, &FindUserSetting{ diff --git a/store/workspace_setting.go b/store/workspace_setting.go index 401a35c..a4dee63 100644 --- a/store/workspace_setting.go +++ b/store/workspace_setting.go @@ -2,11 +2,6 @@ package store import ( "context" - "errors" - "strconv" - "strings" - - "google.golang.org/protobuf/encoding/protojson" storepb "github.com/yourselfhosted/slash/proto/gen/store" ) @@ -16,110 +11,22 @@ type FindWorkspaceSetting struct { } func (s *Store) UpsertWorkspaceSetting(ctx context.Context, upsert *storepb.WorkspaceSetting) (*storepb.WorkspaceSetting, error) { - stmt := ` - INSERT INTO workspace_setting ( - key, - value - ) - VALUES (?, ?) - ON CONFLICT(key) DO UPDATE - SET value = EXCLUDED.value - ` - var valueString string - if upsert.Key == storepb.WorkspaceSettingKey_WORKSPACE_SETTING_LICENSE_KEY { - valueString = upsert.GetLicenseKey() - } else if upsert.Key == storepb.WorkspaceSettingKey_WORKSPACE_SETTING_SECRET_SESSION { - valueString = upsert.GetSecretSession() - } else if upsert.Key == storepb.WorkspaceSettingKey_WORKSAPCE_SETTING_ENABLE_SIGNUP { - valueString = strconv.FormatBool(upsert.GetEnableSignup()) - } else if upsert.Key == storepb.WorkspaceSettingKey_WORKSPACE_SETTING_CUSTOM_STYLE { - valueString = upsert.GetCustomStyle() - } else if upsert.Key == storepb.WorkspaceSettingKey_WORKSPACE_SETTING_CUSTOM_SCRIPT { - valueString = upsert.GetCustomScript() - } else if upsert.Key == storepb.WorkspaceSettingKey_WORKSPACE_SETTING_AUTO_BACKUP { - valueBytes, err := protojson.Marshal(upsert.GetAutoBackup()) - if err != nil { - return nil, err - } - valueString = string(valueBytes) - } else { - return nil, errors.New("invalid workspace setting key") - } - - if _, err := s.db.ExecContext(ctx, stmt, upsert.Key.String(), valueString); err != nil { + workspaceSetting, err := s.driver.UpsertWorkspaceSetting(ctx, upsert) + if err != nil { return nil, err } - - workspaceSetting := upsert s.workspaceSettingCache.Store(workspaceSetting.Key, workspaceSetting) return workspaceSetting, nil } func (s *Store) ListWorkspaceSettings(ctx context.Context, find *FindWorkspaceSetting) ([]*storepb.WorkspaceSetting, error) { - where, args := []string{"1 = 1"}, []any{} - - if find.Key != storepb.WorkspaceSettingKey_WORKSPACE_SETTING_KEY_UNSPECIFIED { - where, args = append(where, "key = ?"), append(args, find.Key.String()) - } - - query := ` - SELECT - key, - value - FROM workspace_setting - WHERE ` + strings.Join(where, " AND ") - rows, err := s.db.QueryContext(ctx, query, args...) + list, err := s.driver.ListWorkspaceSettings(ctx, find) if err != nil { return nil, err } - - defer rows.Close() - - list := []*storepb.WorkspaceSetting{} - for rows.Next() { - workspaceSetting := &storepb.WorkspaceSetting{} - var keyString, valueString string - if err := rows.Scan( - &keyString, - &valueString, - ); err != nil { - return nil, err - } - workspaceSetting.Key = storepb.WorkspaceSettingKey(storepb.WorkspaceSettingKey_value[keyString]) - if workspaceSetting.Key == storepb.WorkspaceSettingKey_WORKSPACE_SETTING_LICENSE_KEY { - workspaceSetting.Value = &storepb.WorkspaceSetting_LicenseKey{LicenseKey: valueString} - } else if workspaceSetting.Key == storepb.WorkspaceSettingKey_WORKSPACE_SETTING_SECRET_SESSION { - workspaceSetting.Value = &storepb.WorkspaceSetting_SecretSession{SecretSession: valueString} - } else if workspaceSetting.Key == storepb.WorkspaceSettingKey_WORKSAPCE_SETTING_ENABLE_SIGNUP { - enableSignup, err := strconv.ParseBool(valueString) - if err != nil { - return nil, err - } - workspaceSetting.Value = &storepb.WorkspaceSetting_EnableSignup{EnableSignup: enableSignup} - } else if workspaceSetting.Key == storepb.WorkspaceSettingKey_WORKSPACE_SETTING_CUSTOM_STYLE { - workspaceSetting.Value = &storepb.WorkspaceSetting_CustomStyle{CustomStyle: valueString} - } else if workspaceSetting.Key == storepb.WorkspaceSettingKey_WORKSPACE_SETTING_CUSTOM_SCRIPT { - workspaceSetting.Value = &storepb.WorkspaceSetting_CustomScript{CustomScript: valueString} - } else if workspaceSetting.Key == storepb.WorkspaceSettingKey_WORKSPACE_SETTING_AUTO_BACKUP { - autoBackupSetting := &storepb.AutoBackupWorkspaceSetting{} - if err := protojson.Unmarshal([]byte(valueString), autoBackupSetting); err != nil { - return nil, err - } - workspaceSetting.Value = &storepb.WorkspaceSetting_AutoBackup{AutoBackup: autoBackupSetting} - } else { - continue - } - list = append(list, workspaceSetting) - } - - if err := rows.Err(); err != nil { - return nil, err - } - for _, workspaceSetting := range list { s.workspaceSettingCache.Store(workspaceSetting.Key, workspaceSetting) } - return list, nil } diff --git a/test/server/server.go b/test/server/server.go index 601763f..b553a0b 100644 --- a/test/server/server.go +++ b/test/server/server.go @@ -31,12 +31,17 @@ type TestingServer struct { func NewTestingServer(ctx context.Context, t *testing.T) (*TestingServer, error) { profile := test.GetTestingProfile(t) - db := db.NewDB(profile) - if err := db.Open(ctx); err != nil { - return nil, errors.Wrap(err, "failed to open db") + dbDriver, err := db.NewDBDriver(profile) + if err != nil { + fmt.Printf("failed to create db driver, error: %+v\n", err) + return nil, err + } + if err := dbDriver.Migrate(ctx); err != nil { + fmt.Printf("failed to migrate db, error: %+v\n", err) + return nil, err } - store := store.New(db.DBInstance, profile) + store := store.New(dbDriver, profile) server, err := server.NewServer(ctx, profile, store) if err != nil { return nil, errors.Wrap(err, "failed to create server") diff --git a/test/store/store.go b/test/store/store.go index cf8d606..4ba6b1c 100644 --- a/test/store/store.go +++ b/test/store/store.go @@ -15,11 +15,14 @@ import ( func NewTestingStore(ctx context.Context, t *testing.T) *store.Store { profile := test.GetTestingProfile(t) - db := db.NewDB(profile) - if err := db.Open(ctx); err != nil { - fmt.Printf("failed to open db, error: %+v\n", err) + dbDriver, err := db.NewDBDriver(profile) + if err != nil { + fmt.Printf("failed to create db driver, error: %+v\n", err) + } + if err := dbDriver.Migrate(ctx); err != nil { + fmt.Printf("failed to migrate db, error: %+v\n", err) } - store := store.New(db.DBInstance, profile) + store := store.New(dbDriver, profile) return store }