feat: abstract database drivers

This commit is contained in:
Steven 2023-12-17 13:56:41 +08:00
parent 6350b19478
commit 9173c8f19a
39 changed files with 1707 additions and 1356 deletions

View File

@ -30,20 +30,27 @@ var (
mode string mode string
port int port int
data string data string
driver string
dsn string
rootCmd = &cobra.Command{ rootCmd = &cobra.Command{
Use: "slash", Use: "slash",
Short: `An open source, self-hosted bookmarks and link sharing platform.`, Short: `An open source, self-hosted bookmarks and link sharing platform.`,
Run: func(_cmd *cobra.Command, _args []string) { Run: func(_cmd *cobra.Command, _args []string) {
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
db := db.NewDB(serverProfile) dbDriver, err := db.NewDBDriver(serverProfile)
if err := db.Open(ctx); err != nil { if err != nil {
cancel() cancel()
log.Error("failed to open database", zap.Error(err)) log.Error("failed to create db driver", zap.Error(err))
return
}
if err := dbDriver.Migrate(ctx); err != nil {
cancel()
log.Error("failed to migrate db", zap.Error(err))
return return
} }
storeInstance := store.New(db.DBInstance, serverProfile) storeInstance := store.New(dbDriver, serverProfile)
s, err := server.NewServer(ctx, serverProfile, storeInstance) s, err := server.NewServer(ctx, serverProfile, storeInstance)
if err != nil { if err != nil {
cancel() cancel()
@ -92,6 +99,8 @@ func init() {
rootCmd.PersistentFlags().StringVarP(&mode, "mode", "m", "demo", `mode of server, can be "prod" or "dev" or "demo"`) rootCmd.PersistentFlags().StringVarP(&mode, "mode", "m", "demo", `mode of server, can be "prod" or "dev" or "demo"`)
rootCmd.PersistentFlags().IntVarP(&port, "port", "p", 8082, "port of server") rootCmd.PersistentFlags().IntVarP(&port, "port", "p", 8082, "port of server")
rootCmd.PersistentFlags().StringVarP(&data, "data", "d", "", "data directory") rootCmd.PersistentFlags().StringVarP(&data, "data", "d", "", "data directory")
rootCmd.PersistentFlags().StringVarP(&driver, "driver", "", "", "database driver")
rootCmd.PersistentFlags().StringVarP(&dsn, "dsn", "", "", "database source name(aka. DSN)")
err := viper.BindPFlag("mode", rootCmd.PersistentFlags().Lookup("mode")) err := viper.BindPFlag("mode", rootCmd.PersistentFlags().Lookup("mode"))
if err != nil { if err != nil {
@ -105,9 +114,18 @@ func init() {
if err != nil { if err != nil {
panic(err) panic(err)
} }
err = viper.BindPFlag("driver", rootCmd.PersistentFlags().Lookup("driver"))
if err != nil {
panic(err)
}
err = viper.BindPFlag("dsn", rootCmd.PersistentFlags().Lookup("dsn"))
if err != nil {
panic(err)
}
viper.SetDefault("mode", "demo") viper.SetDefault("mode", "demo")
viper.SetDefault("port", 8082) viper.SetDefault("port", 8082)
viper.SetDefault("driver", "sqlite")
viper.SetEnvPrefix("slash") viper.SetEnvPrefix("slash")
} }

View File

@ -23,6 +23,9 @@ type Profile struct {
Data string `json:"-"` Data string `json:"-"`
// DSN points to where slash stores its own data // DSN points to where slash stores its own data
DSN string `json:"-"` DSN string `json:"-"`
// Driver is the database driver
// sqlite, mysql
Driver string `json:"-"`
// Version is the current version of server // Version is the current version of server
Version string `json:"version"` Version string `json:"version"`
} }

View File

@ -2,7 +2,6 @@ package store
import ( import (
"context" "context"
"strings"
) )
type ActivityType string type ActivityType string
@ -63,82 +62,11 @@ type FindActivity struct {
} }
func (s *Store) CreateActivity(ctx context.Context, create *Activity) (*Activity, error) { func (s *Store) CreateActivity(ctx context.Context, create *Activity) (*Activity, error) {
stmt := ` return s.driver.CreateActivity(ctx, create)
INSERT INTO activity (
creator_id,
type,
level,
payload
)
VALUES (?, ?, ?, ?)
RETURNING id, created_ts
`
if err := s.db.QueryRowContext(ctx, stmt,
create.CreatorID,
create.Type.String(),
create.Level.String(),
create.Payload,
).Scan(
&create.ID,
&create.CreatedTs,
); err != nil {
return nil, err
}
activity := create
return activity, nil
} }
func (s *Store) ListActivities(ctx context.Context, find *FindActivity) ([]*Activity, error) { func (s *Store) ListActivities(ctx context.Context, find *FindActivity) ([]*Activity, error) {
where, args := []string{"1 = 1"}, []any{} return s.driver.ListActivities(ctx, find)
if find.Type != "" {
where, args = append(where, "type = ?"), append(args, find.Type.String())
}
if find.Level != "" {
where, args = append(where, "level = ?"), append(args, find.Level.String())
}
if find.Where != nil {
where = append(where, find.Where...)
}
query := `
SELECT
id,
creator_id,
created_ts,
type,
level,
payload
FROM activity
WHERE ` + strings.Join(where, " AND ")
rows, err := s.db.QueryContext(ctx, query, args...)
if err != nil {
return nil, err
}
defer rows.Close()
list := []*Activity{}
for rows.Next() {
activity := &Activity{}
if err := rows.Scan(
&activity.ID,
&activity.CreatorID,
&activity.CreatedTs,
&activity.Type,
&activity.Level,
&activity.Payload,
); err != nil {
return nil, err
}
list = append(list, activity)
}
if err := rows.Err(); err != nil {
return nil, err
}
return list, nil
} }
func (s *Store) GetActivity(ctx context.Context, find *FindActivity) (*Activity, error) { func (s *Store) GetActivity(ctx context.Context, find *FindActivity) (*Activity, error) {

View File

@ -2,12 +2,7 @@ package store
import ( import (
"context" "context"
"fmt"
"strings"
"github.com/pkg/errors"
"github.com/yourselfhosted/slash/internal/util"
storepb "github.com/yourselfhosted/slash/proto/gen/store" storepb "github.com/yourselfhosted/slash/proto/gen/store"
) )
@ -35,165 +30,15 @@ type DeleteCollection struct {
} }
func (s *Store) CreateCollection(ctx context.Context, create *storepb.Collection) (*storepb.Collection, error) { func (s *Store) CreateCollection(ctx context.Context, create *storepb.Collection) (*storepb.Collection, error) {
set := []string{"creator_id", "name", "title", "description", "shortcut_ids", "visibility"} return s.driver.CreateCollection(ctx, create)
args := []any{create.CreatorId, create.Name, create.Title, create.Description, strings.Trim(strings.Join(strings.Fields(fmt.Sprint(create.ShortcutIds)), ","), "[]"), create.Visibility.String()}
placeholder := []string{"?", "?", "?", "?", "?", "?"}
stmt := `
INSERT INTO collection (
` + strings.Join(set, ", ") + `
)
VALUES (` + strings.Join(placeholder, ",") + `)
RETURNING id, created_ts, updated_ts
`
if err := s.db.QueryRowContext(ctx, stmt, args...).Scan(
&create.Id,
&create.CreatedTs,
&create.UpdatedTs,
); err != nil {
return nil, err
}
collection := create
return collection, nil
} }
func (s *Store) UpdateCollection(ctx context.Context, update *UpdateCollection) (*storepb.Collection, error) { func (s *Store) UpdateCollection(ctx context.Context, update *UpdateCollection) (*storepb.Collection, error) {
set, args := []string{}, []any{} return s.driver.UpdateCollection(ctx, update)
if update.Name != nil {
set, args = append(set, "name = ?"), append(args, *update.Name)
}
if update.Title != nil {
set, args = append(set, "title = ?"), append(args, *update.Title)
}
if update.Description != nil {
set, args = append(set, "description = ?"), append(args, *update.Description)
}
if update.ShortcutIDs != nil {
set, args = append(set, "shortcut_ids = ?"), append(args, strings.Trim(strings.Join(strings.Fields(fmt.Sprint(update.ShortcutIDs)), ","), "[]"))
}
if update.Visibility != nil {
set, args = append(set, "visibility = ?"), append(args, update.Visibility.String())
}
if len(set) == 0 {
return nil, errors.New("no update specified")
}
args = append(args, update.ID)
stmt := `
UPDATE collection
SET
` + strings.Join(set, ", ") + `
WHERE
id = ?
RETURNING id, creator_id, created_ts, updated_ts, name, title, description, shortcut_ids, visibility
`
collection := &storepb.Collection{}
var shortcutIDs, visibility string
if err := s.db.QueryRowContext(ctx, stmt, args...).Scan(
&collection.Id,
&collection.CreatorId,
&collection.CreatedTs,
&collection.UpdatedTs,
&collection.Name,
&collection.Title,
&collection.Description,
&shortcutIDs,
&visibility,
); err != nil {
return nil, err
}
collection.ShortcutIds = []int32{}
if shortcutIDs != "" {
for _, idStr := range strings.Split(shortcutIDs, ",") {
shortcutID, err := util.ConvertStringToInt32(idStr)
if err != nil {
return nil, errors.Wrap(err, "failed to convert shortcut id")
}
collection.ShortcutIds = append(collection.ShortcutIds, shortcutID)
}
}
collection.Visibility = convertVisibilityStringToStorepb(visibility)
return collection, nil
} }
func (s *Store) ListCollections(ctx context.Context, find *FindCollection) ([]*storepb.Collection, error) { func (s *Store) ListCollections(ctx context.Context, find *FindCollection) ([]*storepb.Collection, error) {
where, args := []string{"1 = 1"}, []any{} return s.driver.ListCollections(ctx, find)
if v := find.ID; v != nil {
where, args = append(where, "id = ?"), append(args, *v)
}
if v := find.CreatorID; v != nil {
where, args = append(where, "creator_id = ?"), append(args, *v)
}
if v := find.Name; v != nil {
where, args = append(where, "name = ?"), append(args, *v)
}
if v := find.VisibilityList; len(v) != 0 {
list := []string{}
for _, visibility := range v {
list = append(list, fmt.Sprintf("$%d", len(args)+1))
args = append(args, visibility)
}
where = append(where, fmt.Sprintf("visibility in (%s)", strings.Join(list, ",")))
}
rows, err := s.db.QueryContext(ctx, `
SELECT
id,
creator_id,
created_ts,
updated_ts,
name,
title,
description,
shortcut_ids,
visibility
FROM collection
WHERE `+strings.Join(where, " AND ")+`
ORDER BY created_ts DESC`,
args...,
)
if err != nil {
return nil, err
}
defer rows.Close()
list := make([]*storepb.Collection, 0)
for rows.Next() {
collection := &storepb.Collection{}
var shortcutIDs, visibility string
if err := rows.Scan(
&collection.Id,
&collection.CreatorId,
&collection.CreatedTs,
&collection.UpdatedTs,
&collection.Name,
&collection.Title,
&collection.Description,
&shortcutIDs,
&visibility,
); err != nil {
return nil, err
}
collection.ShortcutIds = []int32{}
if shortcutIDs != "" {
for _, idStr := range strings.Split(shortcutIDs, ",") {
shortcutID, err := util.ConvertStringToInt32(idStr)
if err != nil {
return nil, errors.Wrap(err, "failed to convert shortcut id")
}
collection.ShortcutIds = append(collection.ShortcutIds, shortcutID)
}
}
collection.Visibility = storepb.Visibility(storepb.Visibility_value[visibility])
list = append(list, collection)
}
if err := rows.Err(); err != nil {
return nil, err
}
return list, nil
} }
func (s *Store) GetCollection(ctx context.Context, find *FindCollection) (*storepb.Collection, error) { func (s *Store) GetCollection(ctx context.Context, find *FindCollection) (*storepb.Collection, error) {
@ -211,9 +56,5 @@ func (s *Store) GetCollection(ctx context.Context, find *FindCollection) (*store
} }
func (s *Store) DeleteCollection(ctx context.Context, delete *DeleteCollection) error { func (s *Store) DeleteCollection(ctx context.Context, delete *DeleteCollection) error {
if _, err := s.db.ExecContext(ctx, `DELETE FROM collection WHERE id = ?`, delete.ID); err != nil { return s.driver.DeleteCollection(ctx, delete)
return err
}
return nil
} }

View File

@ -1,9 +1,5 @@
package store package store
import (
storepb "github.com/yourselfhosted/slash/proto/gen/store"
)
// RowStatus is the status for a row. // RowStatus is the status for a row.
type RowStatus string type RowStatus string
@ -24,16 +20,6 @@ func (e RowStatus) String() string {
return "" return ""
} }
func convertRowStatusStringToStorepb(status string) storepb.RowStatus {
switch status {
case "NORMAL":
return storepb.RowStatus_NORMAL
case "ARCHIVED":
return storepb.RowStatus_ARCHIVED
}
return storepb.RowStatus_ROW_STATUS_UNSPECIFIED
}
// Visibility is the type of a visibility. // Visibility is the type of a visibility.
type Visibility string type Visibility string

View File

@ -1,266 +1,26 @@
package db package db
import ( import (
"context"
"database/sql"
"embed"
"fmt"
"io/fs"
"log/slog"
"os"
"regexp"
"sort"
"time"
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/yourselfhosted/slash/server/profile" "github.com/yourselfhosted/slash/server/profile"
"github.com/yourselfhosted/slash/server/version" "github.com/yourselfhosted/slash/store"
"github.com/yourselfhosted/slash/store/db/sqlite"
) )
//go:embed migration // NewDBDriver creates new db driver based on profile.
var migrationFS embed.FS func NewDBDriver(profile *profile.Profile) (store.Driver, error) {
var driver store.Driver
var err error
//go:embed seed switch profile.Driver {
var seedFS embed.FS case "sqlite":
driver, err = sqlite.NewDB(profile)
type DB struct { default:
// sqlite db connection instance return nil, errors.New("unknown db driver")
DBInstance *sql.DB }
profile *profile.Profile if err != nil {
} return nil, errors.Wrap(err, "failed to create db driver")
}
// NewDB returns a new instance of DB associated with the given datasource name. return driver, nil
func NewDB(profile *profile.Profile) *DB {
db := &DB{
profile: profile,
}
return db
}
func (db *DB) Open(ctx context.Context) (err error) {
// Ensure a DSN is set before attempting to open the database.
if db.profile.DSN == "" {
return errors.New("dsn required")
}
// Connect to the database with some sane settings:
// - No shared-cache: it's obsolete; WAL journal mode is a better solution.
// - No foreign key constraints: it's currently disabled by default, but it's a
// good practice to be explicit and prevent future surprises on SQLite upgrades.
// - Journal mode set to WAL: it's the recommended journal mode for most applications
// as it prevents locking issues.
//
// Notes:
// - When using the `modernc.org/sqlite` driver, each pragma must be prefixed with `_pragma=`.
//
// References:
// - https://pkg.go.dev/modernc.org/sqlite#Driver.Open
// - https://www.sqlite.org/sharedcache.html
// - https://www.sqlite.org/pragma.html
sqliteDB, err := sql.Open("sqlite", db.profile.DSN+"?_pragma=foreign_keys(0)&_pragma=busy_timeout(10000)&_pragma=journal_mode(WAL)")
if err != nil {
return errors.Wrapf(err, "failed to open db with dsn: %s", db.profile.DSN)
}
db.DBInstance = sqliteDB
currentVersion := version.GetCurrentVersion(db.profile.Mode)
if db.profile.Mode == "prod" {
_, err := os.Stat(db.profile.DSN)
if err != nil {
if !errors.Is(err, os.ErrNotExist) {
return errors.Wrap(err, "failed to get db file stat")
}
// If db file not exists, we should create a new one with latest schema.
err := db.applyLatestSchema(ctx)
if err != nil {
return errors.Wrap(err, "failed to apply latest schema")
}
_, err = db.UpsertMigrationHistory(ctx, &MigrationHistoryUpsert{
Version: currentVersion,
})
if err != nil {
return errors.Wrap(err, "failed to upsert migration history")
}
return nil
}
// If db file exists, we should check if we need to migrate the database.
migrationHistoryList, err := db.FindMigrationHistoryList(ctx, &MigrationHistoryFind{})
if err != nil {
return errors.Wrap(err, "failed to find migration history")
}
if len(migrationHistoryList) == 0 {
_, err := db.UpsertMigrationHistory(ctx, &MigrationHistoryUpsert{
Version: currentVersion,
})
if err != nil {
return errors.Wrap(err, "failed to upsert migration history")
}
return nil
}
migrationHistoryVersionList := []string{}
for _, migrationHistory := range migrationHistoryList {
migrationHistoryVersionList = append(migrationHistoryVersionList, migrationHistory.Version)
}
sort.Sort(version.SortVersion(migrationHistoryVersionList))
latestMigrationHistoryVersion := migrationHistoryVersionList[len(migrationHistoryVersionList)-1]
if version.IsVersionGreaterThan(version.GetSchemaVersion(currentVersion), latestMigrationHistoryVersion) {
minorVersionList := getMinorVersionList()
// backup the raw database file before migration
rawBytes, err := os.ReadFile(db.profile.DSN)
if err != nil {
return errors.Wrap(err, "failed to read raw database file")
}
backupDBFilePath := fmt.Sprintf("%s/slash_%s_%d_backup.db", db.profile.Data, db.profile.Version, time.Now().Unix())
if err := os.WriteFile(backupDBFilePath, rawBytes, 0644); err != nil {
return errors.Wrap(err, "failed to write raw database file")
}
slog.Log(ctx, slog.LevelInfo, "succeed to copy a backup database file")
slog.Log(ctx, slog.LevelInfo, "start migrate")
for _, minorVersion := range minorVersionList {
normalizedVersion := minorVersion + ".0"
if version.IsVersionGreaterThan(normalizedVersion, latestMigrationHistoryVersion) && version.IsVersionGreaterOrEqualThan(currentVersion, normalizedVersion) {
slog.Log(ctx, slog.LevelInfo, fmt.Sprintf("applying migration for %s", normalizedVersion))
if err := db.applyMigrationForMinorVersion(ctx, minorVersion); err != nil {
return errors.Wrap(err, "failed to apply minor version migration")
}
}
}
slog.Log(ctx, slog.LevelInfo, "end migrate")
// remove the created backup db file after migrate succeed
if err := os.Remove(backupDBFilePath); err != nil {
slog.Log(ctx, slog.LevelError, fmt.Sprintf("Failed to remove temp database file, err %v", err))
}
}
} else {
// In non-prod mode, we should always migrate the database.
if _, err := os.Stat(db.profile.DSN); errors.Is(err, os.ErrNotExist) {
if err := db.applyLatestSchema(ctx); err != nil {
return errors.Wrap(err, "failed to apply latest schema")
}
// In demo mode, we should seed the database.
if db.profile.Mode == "demo" {
if err := db.seed(ctx); err != nil {
return errors.Wrap(err, "failed to seed")
}
}
}
}
return nil
}
const (
latestSchemaFileName = "LATEST__SCHEMA.sql"
)
func (db *DB) applyLatestSchema(ctx context.Context) error {
schemaMode := "dev"
if db.profile.Mode == "prod" {
schemaMode = "prod"
}
latestSchemaPath := fmt.Sprintf("migration/%s/%s", schemaMode, latestSchemaFileName)
buf, err := migrationFS.ReadFile(latestSchemaPath)
if err != nil {
return errors.Wrapf(err, "failed to read latest schema %q", latestSchemaPath)
}
stmt := string(buf)
if err := db.execute(ctx, stmt); err != nil {
return errors.Wrapf(err, "migrate error: statement %s", stmt)
}
return nil
}
func (db *DB) applyMigrationForMinorVersion(ctx context.Context, minorVersion string) error {
filenames, err := fs.Glob(migrationFS, fmt.Sprintf("migration/prod/%s/*.sql", minorVersion))
if err != nil {
return errors.Wrap(err, "failed to read migrate files")
}
sort.Strings(filenames)
migrationStmt := ""
// Loop over all migration files and execute them in order.
for _, filename := range filenames {
buf, err := migrationFS.ReadFile(filename)
if err != nil {
return errors.Wrapf(err, "failed to read minor version migration file, filename %s", filename)
}
stmt := string(buf)
migrationStmt += stmt
if err := db.execute(ctx, stmt); err != nil {
return errors.Wrapf(err, "migrate error: statement %s", stmt)
}
}
// Upsert the newest version to migration_history.
version := minorVersion + ".0"
if _, err = db.UpsertMigrationHistory(ctx, &MigrationHistoryUpsert{
Version: version,
}); err != nil {
return errors.Wrapf(err, "failed to upsert migration history with version %s", version)
}
return nil
}
func (db *DB) seed(ctx context.Context) error {
filenames, err := fs.Glob(seedFS, "seed/*.sql")
if err != nil {
return errors.Wrap(err, "failed to read seed files")
}
sort.Strings(filenames)
// Loop over all seed files and execute them in order.
for _, filename := range filenames {
buf, err := seedFS.ReadFile(filename)
if err != nil {
return errors.Wrapf(err, "failed to read seed file, filename %s", filename)
}
stmt := string(buf)
if err := db.execute(ctx, stmt); err != nil {
return errors.Wrapf(err, "seed error: statement %s", stmt)
}
}
return nil
}
// execute runs a single SQL statement within a transaction.
func (db *DB) execute(ctx context.Context, stmt string) error {
if _, err := db.DBInstance.ExecContext(ctx, stmt); err != nil {
return errors.Wrap(err, "failed to execute statement")
}
return nil
}
// minorDirRegexp is a regular expression for minor version directory.
var minorDirRegexp = regexp.MustCompile(`^migration/prod/[0-9]+\.[0-9]+$`)
func getMinorVersionList() []string {
minorVersionList := []string{}
if err := fs.WalkDir(migrationFS, "migration", func(path string, file fs.DirEntry, err error) error {
if err != nil {
return err
}
if file.IsDir() && minorDirRegexp.MatchString(path) {
minorVersionList = append(minorVersionList, file.Name())
}
return nil
}); err != nil {
panic(err)
}
sort.Sort(version.SortVersion(minorVersionList))
return minorVersionList
} }

View File

@ -1,82 +0,0 @@
package db
import (
"context"
"strings"
)
type MigrationHistory struct {
Version string
CreatedTs int64
}
type MigrationHistoryUpsert struct {
Version string
}
type MigrationHistoryFind struct {
Version *string
}
func (db *DB) FindMigrationHistoryList(ctx context.Context, find *MigrationHistoryFind) ([]*MigrationHistory, error) {
where, args := []string{"1 = 1"}, []any{}
if v := find.Version; v != nil {
where, args = append(where, "version = ?"), append(args, *v)
}
stmt := `
SELECT
version,
created_ts
FROM
migration_history
WHERE ` + strings.Join(where, " AND ") + `
ORDER BY created_ts DESC
`
rows, err := db.DBInstance.QueryContext(ctx, stmt, args...)
if err != nil {
return nil, err
}
defer rows.Close()
migrationHistoryList := make([]*MigrationHistory, 0)
for rows.Next() {
var migrationHistory MigrationHistory
if err := rows.Scan(
&migrationHistory.Version,
&migrationHistory.CreatedTs,
); err != nil {
return nil, err
}
migrationHistoryList = append(migrationHistoryList, &migrationHistory)
}
if err := rows.Err(); err != nil {
return nil, err
}
return migrationHistoryList, nil
}
func (db *DB) UpsertMigrationHistory(ctx context.Context, upsert *MigrationHistoryUpsert) (*MigrationHistory, error) {
query := `
INSERT INTO migration_history (
version
)
VALUES (?)
ON CONFLICT(version) DO UPDATE
SET
version=EXCLUDED.version
RETURNING version, created_ts
`
migrationHistory := &MigrationHistory{}
if err := db.DBInstance.QueryRowContext(ctx, query, upsert.Version).Scan(
&migrationHistory.Version,
&migrationHistory.CreatedTs,
); err != nil {
return nil, err
}
return migrationHistory, nil
}

View File

@ -0,0 +1,87 @@
package sqlite
import (
"context"
"strings"
"github.com/yourselfhosted/slash/store"
)
func (d *DB) CreateActivity(ctx context.Context, create *store.Activity) (*store.Activity, error) {
stmt := `
INSERT INTO activity (
creator_id,
type,
level,
payload
)
VALUES (?, ?, ?, ?)
RETURNING id, created_ts
`
if err := d.db.QueryRowContext(ctx, stmt,
create.CreatorID,
create.Type.String(),
create.Level.String(),
create.Payload,
).Scan(
&create.ID,
&create.CreatedTs,
); err != nil {
return nil, err
}
activity := create
return activity, nil
}
func (d *DB) ListActivities(ctx context.Context, find *store.FindActivity) ([]*store.Activity, error) {
where, args := []string{"1 = 1"}, []any{}
if find.Type != "" {
where, args = append(where, "type = ?"), append(args, find.Type.String())
}
if find.Level != "" {
where, args = append(where, "level = ?"), append(args, find.Level.String())
}
if find.Where != nil {
where = append(where, find.Where...)
}
query := `
SELECT
id,
creator_id,
created_ts,
type,
level,
payload
FROM activity
WHERE ` + strings.Join(where, " AND ")
rows, err := d.db.QueryContext(ctx, query, args...)
if err != nil {
return nil, err
}
defer rows.Close()
list := []*store.Activity{}
for rows.Next() {
activity := &store.Activity{}
if err := rows.Scan(
&activity.ID,
&activity.CreatorID,
&activity.CreatedTs,
&activity.Type,
&activity.Level,
&activity.Payload,
); err != nil {
return nil, err
}
list = append(list, activity)
}
if err := rows.Err(); err != nil {
return nil, err
}
return list, nil
}

View File

@ -0,0 +1,183 @@
package sqlite
import (
"context"
"fmt"
"strings"
"github.com/pkg/errors"
"github.com/yourselfhosted/slash/internal/util"
storepb "github.com/yourselfhosted/slash/proto/gen/store"
"github.com/yourselfhosted/slash/store"
)
func (d *DB) CreateCollection(ctx context.Context, create *storepb.Collection) (*storepb.Collection, error) {
set := []string{"creator_id", "name", "title", "description", "shortcut_ids", "visibility"}
args := []any{create.CreatorId, create.Name, create.Title, create.Description, strings.Trim(strings.Join(strings.Fields(fmt.Sprint(create.ShortcutIds)), ","), "[]"), create.Visibility.String()}
placeholder := []string{"?", "?", "?", "?", "?", "?"}
stmt := `
INSERT INTO collection (
` + strings.Join(set, ", ") + `
)
VALUES (` + strings.Join(placeholder, ",") + `)
RETURNING id, created_ts, updated_ts
`
if err := d.db.QueryRowContext(ctx, stmt, args...).Scan(
&create.Id,
&create.CreatedTs,
&create.UpdatedTs,
); err != nil {
return nil, err
}
collection := create
return collection, nil
}
func (d *DB) UpdateCollection(ctx context.Context, update *store.UpdateCollection) (*storepb.Collection, error) {
set, args := []string{}, []any{}
if update.Name != nil {
set, args = append(set, "name = ?"), append(args, *update.Name)
}
if update.Title != nil {
set, args = append(set, "title = ?"), append(args, *update.Title)
}
if update.Description != nil {
set, args = append(set, "description = ?"), append(args, *update.Description)
}
if update.ShortcutIDs != nil {
set, args = append(set, "shortcut_ids = ?"), append(args, strings.Trim(strings.Join(strings.Fields(fmt.Sprint(update.ShortcutIDs)), ","), "[]"))
}
if update.Visibility != nil {
set, args = append(set, "visibility = ?"), append(args, update.Visibility.String())
}
if len(set) == 0 {
return nil, errors.New("no update specified")
}
args = append(args, update.ID)
stmt := `
UPDATE collection
SET
` + strings.Join(set, ", ") + `
WHERE
id = ?
RETURNING id, creator_id, created_ts, updated_ts, name, title, description, shortcut_ids, visibility
`
collection := &storepb.Collection{}
var shortcutIDs, visibility string
if err := d.db.QueryRowContext(ctx, stmt, args...).Scan(
&collection.Id,
&collection.CreatorId,
&collection.CreatedTs,
&collection.UpdatedTs,
&collection.Name,
&collection.Title,
&collection.Description,
&shortcutIDs,
&visibility,
); err != nil {
return nil, err
}
collection.ShortcutIds = []int32{}
if shortcutIDs != "" {
for _, idStr := range strings.Split(shortcutIDs, ",") {
shortcutID, err := util.ConvertStringToInt32(idStr)
if err != nil {
return nil, errors.Wrap(err, "failed to convert shortcut id")
}
collection.ShortcutIds = append(collection.ShortcutIds, shortcutID)
}
}
collection.Visibility = convertVisibilityStringToStorepb(visibility)
return collection, nil
}
func (d *DB) ListCollections(ctx context.Context, find *store.FindCollection) ([]*storepb.Collection, error) {
where, args := []string{"1 = 1"}, []any{}
if v := find.ID; v != nil {
where, args = append(where, "id = ?"), append(args, *v)
}
if v := find.CreatorID; v != nil {
where, args = append(where, "creator_id = ?"), append(args, *v)
}
if v := find.Name; v != nil {
where, args = append(where, "name = ?"), append(args, *v)
}
if v := find.VisibilityList; len(v) != 0 {
list := []string{}
for _, visibility := range v {
list = append(list, fmt.Sprintf("$%d", len(args)+1))
args = append(args, visibility)
}
where = append(where, fmt.Sprintf("visibility in (%s)", strings.Join(list, ",")))
}
rows, err := d.db.QueryContext(ctx, `
SELECT
id,
creator_id,
created_ts,
updated_ts,
name,
title,
description,
shortcut_ids,
visibility
FROM collection
WHERE `+strings.Join(where, " AND ")+`
ORDER BY created_ts DESC`,
args...,
)
if err != nil {
return nil, err
}
defer rows.Close()
list := make([]*storepb.Collection, 0)
for rows.Next() {
collection := &storepb.Collection{}
var shortcutIDs, visibility string
if err := rows.Scan(
&collection.Id,
&collection.CreatorId,
&collection.CreatedTs,
&collection.UpdatedTs,
&collection.Name,
&collection.Title,
&collection.Description,
&shortcutIDs,
&visibility,
); err != nil {
return nil, err
}
collection.ShortcutIds = []int32{}
if shortcutIDs != "" {
for _, idStr := range strings.Split(shortcutIDs, ",") {
shortcutID, err := util.ConvertStringToInt32(idStr)
if err != nil {
return nil, errors.Wrap(err, "failed to convert shortcut id")
}
collection.ShortcutIds = append(collection.ShortcutIds, shortcutID)
}
}
collection.Visibility = storepb.Visibility(storepb.Visibility_value[visibility])
list = append(list, collection)
}
if err := rows.Err(); err != nil {
return nil, err
}
return list, nil
}
func (d *DB) DeleteCollection(ctx context.Context, delete *store.DeleteCollection) error {
if _, err := d.db.ExecContext(ctx, `DELETE FROM collection WHERE id = ?`, delete.ID); err != nil {
return err
}
return nil
}

59
store/db/sqlite/common.go Normal file
View File

@ -0,0 +1,59 @@
package sqlite
import (
storepb "github.com/yourselfhosted/slash/proto/gen/store"
)
// RowStatus is the status for a row.
type RowStatus string
const (
// Normal is the status for a normal row.
Normal RowStatus = "NORMAL"
// Archived is the status for an archived row.
Archived RowStatus = "ARCHIVED"
)
func (e RowStatus) String() string {
switch e {
case Normal:
return "NORMAL"
case Archived:
return "ARCHIVED"
}
return ""
}
func convertRowStatusStringToStorepb(status string) storepb.RowStatus {
switch status {
case "NORMAL":
return storepb.RowStatus_NORMAL
case "ARCHIVED":
return storepb.RowStatus_ARCHIVED
}
return storepb.RowStatus_ROW_STATUS_UNSPECIFIED
}
// Visibility is the type of a visibility.
type Visibility string
const (
// VisibilityPublic is the PUBLIC visibility.
VisibilityPublic Visibility = "PUBLIC"
// VisibilityWorkspace is the WORKSPACE visibility.
VisibilityWorkspace Visibility = "WORKSPACE"
// VisibilityPrivate is the PRIVATE visibility.
VisibilityPrivate Visibility = "PRIVATE"
)
func (e Visibility) String() string {
switch e {
case VisibilityPublic:
return "PUBLIC"
case VisibilityWorkspace:
return "WORKSPACE"
case VisibilityPrivate:
return "PRIVATE"
}
return "PRIVATE"
}

202
store/db/sqlite/memo.go Normal file
View File

@ -0,0 +1,202 @@
package sqlite
import (
"context"
"database/sql"
"fmt"
"strings"
"github.com/pkg/errors"
storepb "github.com/yourselfhosted/slash/proto/gen/store"
"github.com/yourselfhosted/slash/store"
)
func (d *DB) CreateMemo(ctx context.Context, create *storepb.Memo) (*storepb.Memo, error) {
set := []string{"creator_id", "name", "title", "content", "visibility", "tag"}
args := []any{create.CreatorId, create.Name, create.Title, create.Content, create.Visibility.String(), strings.Join(create.Tags, " ")}
stmt := `
INSERT INTO memo (
` + strings.Join(set, ", ") + `
)
VALUES (` + placeholders(len(args)) + `)
RETURNING id, created_ts, updated_ts, row_status
`
var rowStatus string
if err := d.db.QueryRowContext(ctx, stmt, args...).Scan(
&create.Id,
&create.CreatedTs,
&create.UpdatedTs,
&rowStatus,
); err != nil {
return nil, err
}
create.RowStatus = convertRowStatusStringToStorepb(rowStatus)
memo := create
return memo, nil
}
func (d *DB) UpdateMemo(ctx context.Context, update *store.UpdateMemo) (*storepb.Memo, error) {
set, args := []string{}, []any{}
if update.RowStatus != nil {
set, args = append(set, "row_status = ?"), append(args, update.RowStatus.String())
}
if update.Name != nil {
set, args = append(set, "name = ?"), append(args, *update.Name)
}
if update.Title != nil {
set, args = append(set, "title = ?"), append(args, *update.Title)
}
if update.Content != nil {
set, args = append(set, "content = ?"), append(args, *update.Content)
}
if update.Visibility != nil {
set, args = append(set, "visibility = ?"), append(args, update.Visibility.String())
}
if update.Tag != nil {
set, args = append(set, "tag = ?"), append(args, *update.Tag)
}
if len(set) == 0 {
return nil, errors.New("no update specified")
}
args = append(args, update.ID)
stmt := `
UPDATE memo
SET
` + strings.Join(set, ", ") + `
WHERE
id = ?
RETURNING id, creator_id, created_ts, updated_ts, row_status, name, title, content, visibility, tag
`
memo := &storepb.Memo{}
var rowStatus, visibility, tags string
if err := d.db.QueryRowContext(ctx, stmt, args...).Scan(
&memo.Id,
&memo.CreatorId,
&memo.CreatedTs,
&memo.UpdatedTs,
&rowStatus,
&memo.Name,
&memo.Title,
&memo.Content,
&visibility,
&tags,
); err != nil {
return nil, err
}
memo.RowStatus = convertRowStatusStringToStorepb(rowStatus)
memo.Visibility = convertVisibilityStringToStorepb(visibility)
memo.Tags = filterTags(strings.Split(tags, " "))
return memo, nil
}
func (d *DB) ListMemos(ctx context.Context, find *store.FindMemo) ([]*storepb.Memo, error) {
where, args := []string{"1 = 1"}, []any{}
if v := find.ID; v != nil {
where, args = append(where, "id = ?"), append(args, *v)
}
if v := find.CreatorID; v != nil {
where, args = append(where, "creator_id = ?"), append(args, *v)
}
if v := find.RowStatus; v != nil {
where, args = append(where, "row_status = ?"), append(args, *v)
}
if v := find.Name; v != nil {
where, args = append(where, "name = ?"), append(args, *v)
}
if v := find.VisibilityList; len(v) != 0 {
list := []string{}
for _, visibility := range v {
list = append(list, fmt.Sprintf("$%d", len(args)+1))
args = append(args, visibility)
}
where = append(where, fmt.Sprintf("visibility in (%s)", strings.Join(list, ",")))
}
if v := find.Tag; v != nil {
where, args = append(where, "tag LIKE ?"), append(args, "%"+*v+"%")
}
rows, err := d.db.QueryContext(ctx, `
SELECT
id,
creator_id,
created_ts,
updated_ts,
row_status,
name,
title,
content,
visibility,
tag
FROM memo
WHERE `+strings.Join(where, " AND ")+`
ORDER BY created_ts DESC`,
args...,
)
if err != nil {
return nil, err
}
defer rows.Close()
list := make([]*storepb.Memo, 0)
for rows.Next() {
memo := &storepb.Memo{}
var rowStatus, visibility, tags string
if err := rows.Scan(
&memo.Id,
&memo.CreatorId,
&memo.CreatedTs,
&memo.UpdatedTs,
&rowStatus,
&memo.Name,
&memo.Title,
&memo.Content,
&visibility,
&tags,
); err != nil {
return nil, err
}
memo.RowStatus = convertRowStatusStringToStorepb(rowStatus)
memo.Visibility = storepb.Visibility(storepb.Visibility_value[visibility])
memo.Tags = filterTags(strings.Split(tags, " "))
list = append(list, memo)
}
if err := rows.Err(); err != nil {
return nil, err
}
return list, nil
}
func (d *DB) DeleteMemo(ctx context.Context, delete *store.DeleteMemo) error {
if _, err := d.db.ExecContext(ctx, `DELETE FROM memo WHERE id = ?`, delete.ID); err != nil {
return err
}
return nil
}
func vacuumMemo(ctx context.Context, tx *sql.Tx) error {
stmt := `
DELETE FROM
memo
WHERE
creator_id NOT IN (
SELECT
id
FROM
user
)`
_, err := tx.ExecContext(ctx, stmt)
if err != nil {
return err
}
return nil
}
func placeholders(n int) string {
return strings.Repeat("?,", n-1) + "?"
}

View File

@ -0,0 +1,57 @@
package sqlite
import (
"context"
"github.com/yourselfhosted/slash/store"
)
func (d *DB) UpsertMigrationHistory(ctx context.Context, upsert *store.UpsertMigrationHistory) (*store.MigrationHistory, error) {
stmt := `
INSERT INTO migration_history (
version
)
VALUES (?)
ON CONFLICT(version) DO UPDATE
SET
version=EXCLUDED.version
RETURNING version, created_ts
`
var migrationHistory store.MigrationHistory
if err := d.db.QueryRowContext(ctx, stmt, upsert.Version).Scan(
&migrationHistory.Version,
&migrationHistory.CreatedTs,
); err != nil {
return nil, err
}
return &migrationHistory, nil
}
func (d *DB) ListMigrationHistories(ctx context.Context, _ *store.FindMigrationHistory) ([]*store.MigrationHistory, error) {
query := "SELECT `version`, `created_ts` FROM `migration_history` ORDER BY `created_ts` DESC"
rows, err := d.db.QueryContext(ctx, query)
if err != nil {
return nil, err
}
defer rows.Close()
list := make([]*store.MigrationHistory, 0)
for rows.Next() {
var migrationHistory store.MigrationHistory
if err := rows.Scan(
&migrationHistory.Version,
&migrationHistory.CreatedTs,
); err != nil {
return nil, err
}
list = append(list, &migrationHistory)
}
if err := rows.Err(); err != nil {
return nil, err
}
return list, nil
}

234
store/db/sqlite/migrator.go Normal file
View File

@ -0,0 +1,234 @@
package sqlite
import (
"context"
"embed"
"fmt"
"io/fs"
"os"
"regexp"
"sort"
"time"
"github.com/pkg/errors"
"github.com/yourselfhosted/slash/server/version"
"github.com/yourselfhosted/slash/store"
)
//go:embed migration
var migrationFS embed.FS
//go:embed seed
var seedFS embed.FS
// Migrate applies the latest schema to the database.
func (d *DB) Migrate(ctx context.Context) error {
currentVersion := version.GetCurrentVersion(d.profile.Mode)
if d.profile.Mode == "prod" {
_, err := os.Stat(d.profile.DSN)
if err != nil {
// If db file not exists, we should create a new one with latest schema.
if errors.Is(err, os.ErrNotExist) {
if err := d.applyLatestSchema(ctx); err != nil {
return errors.Wrap(err, "failed to apply latest schema")
}
// Upsert the newest version to migration_history.
if _, err := d.UpsertMigrationHistory(ctx, &store.UpsertMigrationHistory{
Version: currentVersion,
}); err != nil {
return errors.Wrap(err, "failed to upsert migration history")
}
} else {
return errors.Wrap(err, "failed to get db file stat")
}
} else {
// If db file exists, we should check if we need to migrate the database.
migrationHistoryList, err := d.ListMigrationHistories(ctx, &store.FindMigrationHistory{})
if err != nil {
return errors.Wrap(err, "failed to find migration history")
}
// If no migration history, we should apply the latest version migration and upsert the migration history.
if len(migrationHistoryList) == 0 {
minorVersion := version.GetMinorVersion(currentVersion)
if err := d.applyMigrationForMinorVersion(ctx, minorVersion); err != nil {
return errors.Wrapf(err, "failed to apply version %s migration", minorVersion)
}
_, err := d.UpsertMigrationHistory(ctx, &store.UpsertMigrationHistory{
Version: currentVersion,
})
if err != nil {
return errors.Wrap(err, "failed to upsert migration history")
}
return nil
}
migrationHistoryVersionList := []string{}
for _, migrationHistory := range migrationHistoryList {
migrationHistoryVersionList = append(migrationHistoryVersionList, migrationHistory.Version)
}
sort.Sort(version.SortVersion(migrationHistoryVersionList))
latestMigrationHistoryVersion := migrationHistoryVersionList[len(migrationHistoryVersionList)-1]
if version.IsVersionGreaterThan(version.GetSchemaVersion(currentVersion), latestMigrationHistoryVersion) {
minorVersionList := getMinorVersionList()
// backup the raw database file before migration
rawBytes, err := os.ReadFile(d.profile.DSN)
if err != nil {
return errors.Wrap(err, "failed to read raw database file")
}
backupDBFilePath := fmt.Sprintf("%s/memos_%s_%d_backup.db", d.profile.Data, d.profile.Version, time.Now().Unix())
if err := os.WriteFile(backupDBFilePath, rawBytes, 0644); err != nil {
return errors.Wrap(err, "failed to write raw database file")
}
println("succeed to copy a backup database file")
println("start migrate")
for _, minorVersion := range minorVersionList {
normalizedVersion := minorVersion + ".0"
if version.IsVersionGreaterThan(normalizedVersion, latestMigrationHistoryVersion) && version.IsVersionGreaterOrEqualThan(currentVersion, normalizedVersion) {
println("applying migration for", normalizedVersion)
if err := d.applyMigrationForMinorVersion(ctx, minorVersion); err != nil {
return errors.Wrap(err, "failed to apply minor version migration")
}
}
}
println("end migrate")
// remove the created backup db file after migrate succeed
if err := os.Remove(backupDBFilePath); err != nil {
println(fmt.Sprintf("Failed to remove temp database file, err %v", err))
}
}
}
} else {
// In non-prod mode, we should always migrate the database.
if _, err := os.Stat(d.profile.DSN); errors.Is(err, os.ErrNotExist) {
if err := d.applyLatestSchema(ctx); err != nil {
return errors.Wrap(err, "failed to apply latest schema")
}
// In demo mode, we should seed the database.
if d.profile.Mode == "demo" {
if err := d.seed(ctx); err != nil {
return errors.Wrap(err, "failed to seed")
}
}
}
}
return nil
}
const (
latestSchemaFileName = "LATEST__SCHEMA.sql"
)
func (d *DB) applyLatestSchema(ctx context.Context) error {
schemaMode := "dev"
if d.profile.Mode == "prod" {
schemaMode = "prod"
}
latestSchemaPath := fmt.Sprintf("migration/%s/%s", schemaMode, latestSchemaFileName)
buf, err := migrationFS.ReadFile(latestSchemaPath)
if err != nil {
return errors.Wrapf(err, "failed to read latest schema %q", latestSchemaPath)
}
stmt := string(buf)
if err := d.execute(ctx, stmt); err != nil {
return errors.Wrapf(err, "migrate error: %s", stmt)
}
return nil
}
func (d *DB) applyMigrationForMinorVersion(ctx context.Context, minorVersion string) error {
filenames, err := fs.Glob(migrationFS, fmt.Sprintf("%s/%s/*.sql", "migration/prod", minorVersion))
if err != nil {
return errors.Wrap(err, "failed to read ddl files")
}
sort.Strings(filenames)
migrationStmt := ""
// Loop over all migration files and execute them in order.
for _, filename := range filenames {
buf, err := migrationFS.ReadFile(filename)
if err != nil {
return errors.Wrapf(err, "failed to read minor version migration file, filename=%s", filename)
}
stmt := string(buf)
migrationStmt += stmt
if err := d.execute(ctx, stmt); err != nil {
return errors.Wrapf(err, "migrate error: %s", stmt)
}
}
// Upsert the newest version to migration_history.
version := minorVersion + ".0"
if _, err = d.UpsertMigrationHistory(ctx, &store.UpsertMigrationHistory{
Version: version,
}); err != nil {
return errors.Wrapf(err, "failed to upsert migration history with version: %s", version)
}
return nil
}
func (d *DB) seed(ctx context.Context) error {
filenames, err := fs.Glob(seedFS, fmt.Sprintf("%s/*.sql", "seed"))
if err != nil {
return errors.Wrap(err, "failed to read seed files")
}
sort.Strings(filenames)
// Loop over all seed files and execute them in order.
for _, filename := range filenames {
buf, err := seedFS.ReadFile(filename)
if err != nil {
return errors.Wrapf(err, "failed to read seed file, filename=%s", filename)
}
stmt := string(buf)
if err := d.execute(ctx, stmt); err != nil {
return errors.Wrapf(err, "seed error: %s", stmt)
}
}
return nil
}
// execute runs a single SQL statement within a transaction.
func (d *DB) execute(ctx context.Context, stmt string) error {
tx, err := d.db.Begin()
if err != nil {
return err
}
defer tx.Rollback()
if _, err := tx.ExecContext(ctx, stmt); err != nil {
return errors.Wrap(err, "failed to execute statement")
}
return tx.Commit()
}
// minorDirRegexp is a regular expression for minor version directory.
var minorDirRegexp = regexp.MustCompile(`^migration/prod/[0-9]+\.[0-9]+$`)
func getMinorVersionList() []string {
minorVersionList := []string{}
if err := fs.WalkDir(migrationFS, "migration", func(path string, file fs.DirEntry, err error) error {
if err != nil {
return err
}
if file.IsDir() && minorDirRegexp.MatchString(path) {
minorVersionList = append(minorVersionList, file.Name())
}
return nil
}); err != nil {
panic(err)
}
sort.Sort(version.SortVersion(minorVersionList))
return minorVersionList
}

249
store/db/sqlite/shortcut.go Normal file
View File

@ -0,0 +1,249 @@
package sqlite
import (
"context"
"database/sql"
"fmt"
"strings"
"github.com/pkg/errors"
"google.golang.org/protobuf/encoding/protojson"
storepb "github.com/yourselfhosted/slash/proto/gen/store"
"github.com/yourselfhosted/slash/store"
)
func (d *DB) CreateShortcut(ctx context.Context, create *storepb.Shortcut) (*storepb.Shortcut, error) {
set := []string{"creator_id", "name", "link", "title", "description", "visibility", "tag"}
args := []any{create.CreatorId, create.Name, create.Link, create.Title, create.Description, create.Visibility.String(), strings.Join(create.Tags, " ")}
placeholder := []string{"?", "?", "?", "?", "?", "?", "?"}
if create.OgMetadata != nil {
set = append(set, "og_metadata")
openGraphMetadataBytes, err := protojson.Marshal(create.OgMetadata)
if err != nil {
return nil, err
}
args = append(args, string(openGraphMetadataBytes))
placeholder = append(placeholder, "?")
}
stmt := `
INSERT INTO shortcut (
` + strings.Join(set, ", ") + `
)
VALUES (` + strings.Join(placeholder, ",") + `)
RETURNING id, created_ts, updated_ts, row_status
`
var rowStatus string
if err := d.db.QueryRowContext(ctx, stmt, args...).Scan(
&create.Id,
&create.CreatedTs,
&create.UpdatedTs,
&rowStatus,
); err != nil {
return nil, err
}
create.RowStatus = convertRowStatusStringToStorepb(rowStatus)
shortcut := create
return shortcut, nil
}
func (d *DB) UpdateShortcut(ctx context.Context, update *store.UpdateShortcut) (*storepb.Shortcut, error) {
set, args := []string{}, []any{}
if update.RowStatus != nil {
set, args = append(set, "row_status = ?"), append(args, update.RowStatus.String())
}
if update.Name != nil {
set, args = append(set, "name = ?"), append(args, *update.Name)
}
if update.Link != nil {
set, args = append(set, "link = ?"), append(args, *update.Link)
}
if update.Title != nil {
set, args = append(set, "title = ?"), append(args, *update.Title)
}
if update.Description != nil {
set, args = append(set, "description = ?"), append(args, *update.Description)
}
if update.Visibility != nil {
set, args = append(set, "visibility = ?"), append(args, update.Visibility.String())
}
if update.Tag != nil {
set, args = append(set, "tag = ?"), append(args, *update.Tag)
}
if update.OpenGraphMetadata != nil {
openGraphMetadataBytes, err := protojson.Marshal(update.OpenGraphMetadata)
if err != nil {
return nil, errors.Wrap(err, "Failed to marshal activity payload")
}
set, args = append(set, "og_metadata = ?"), append(args, string(openGraphMetadataBytes))
}
if len(set) == 0 {
return nil, errors.New("no update specified")
}
args = append(args, update.ID)
stmt := `
UPDATE shortcut
SET
` + strings.Join(set, ", ") + `
WHERE
id = ?
RETURNING id, creator_id, created_ts, updated_ts, row_status, name, link, title, description, visibility, tag, og_metadata
`
shortcut := &storepb.Shortcut{}
var rowStatus, visibility, tags, openGraphMetadataString string
if err := d.db.QueryRowContext(ctx, stmt, args...).Scan(
&shortcut.Id,
&shortcut.CreatorId,
&shortcut.CreatedTs,
&shortcut.UpdatedTs,
&rowStatus,
&shortcut.Name,
&shortcut.Link,
&shortcut.Title,
&shortcut.Description,
&visibility,
&tags,
&openGraphMetadataString,
); err != nil {
return nil, err
}
shortcut.RowStatus = convertRowStatusStringToStorepb(rowStatus)
shortcut.Visibility = convertVisibilityStringToStorepb(visibility)
shortcut.Tags = filterTags(strings.Split(tags, " "))
var ogMetadata storepb.OpenGraphMetadata
if err := protojson.Unmarshal([]byte(openGraphMetadataString), &ogMetadata); err != nil {
return nil, err
}
shortcut.OgMetadata = &ogMetadata
return shortcut, nil
}
func (d *DB) ListShortcuts(ctx context.Context, find *store.FindShortcut) ([]*storepb.Shortcut, error) {
where, args := []string{"1 = 1"}, []any{}
if v := find.ID; v != nil {
where, args = append(where, "id = ?"), append(args, *v)
}
if v := find.CreatorID; v != nil {
where, args = append(where, "creator_id = ?"), append(args, *v)
}
if v := find.RowStatus; v != nil {
where, args = append(where, "row_status = ?"), append(args, *v)
}
if v := find.Name; v != nil {
where, args = append(where, "name = ?"), append(args, *v)
}
if v := find.VisibilityList; len(v) != 0 {
list := []string{}
for _, visibility := range v {
list = append(list, fmt.Sprintf("$%d", len(args)+1))
args = append(args, visibility)
}
where = append(where, fmt.Sprintf("visibility in (%s)", strings.Join(list, ",")))
}
if v := find.Tag; v != nil {
where, args = append(where, "tag LIKE ?"), append(args, "%"+*v+"%")
}
rows, err := d.db.QueryContext(ctx, `
SELECT
id,
creator_id,
created_ts,
updated_ts,
row_status,
name,
link,
title,
description,
visibility,
tag,
og_metadata
FROM shortcut
WHERE `+strings.Join(where, " AND ")+`
ORDER BY created_ts DESC`,
args...,
)
if err != nil {
return nil, err
}
defer rows.Close()
list := make([]*storepb.Shortcut, 0)
for rows.Next() {
shortcut := &storepb.Shortcut{}
var rowStatus, visibility, tags, openGraphMetadataString string
if err := rows.Scan(
&shortcut.Id,
&shortcut.CreatorId,
&shortcut.CreatedTs,
&shortcut.UpdatedTs,
&rowStatus,
&shortcut.Name,
&shortcut.Link,
&shortcut.Title,
&shortcut.Description,
&visibility,
&tags,
&openGraphMetadataString,
); err != nil {
return nil, err
}
shortcut.RowStatus = convertRowStatusStringToStorepb(rowStatus)
shortcut.Visibility = storepb.Visibility(storepb.Visibility_value[visibility])
shortcut.Tags = filterTags(strings.Split(tags, " "))
var ogMetadata storepb.OpenGraphMetadata
if err := protojson.Unmarshal([]byte(openGraphMetadataString), &ogMetadata); err != nil {
return nil, err
}
shortcut.OgMetadata = &ogMetadata
list = append(list, shortcut)
}
if err := rows.Err(); err != nil {
return nil, err
}
return list, nil
}
func (d *DB) DeleteShortcut(ctx context.Context, delete *store.DeleteShortcut) error {
if _, err := d.db.ExecContext(ctx, `DELETE FROM shortcut WHERE id = ?`, delete.ID); err != nil {
return err
}
return nil
}
func vacuumShortcut(ctx context.Context, tx *sql.Tx) error {
stmt := `
DELETE FROM
shortcut
WHERE
creator_id NOT IN (
SELECT
id
FROM
user
)`
_, err := tx.ExecContext(ctx, stmt)
if err != nil {
return err
}
return nil
}
func filterTags(tags []string) []string {
result := []string{}
for _, tag := range tags {
if tag != "" {
result = append(result, tag)
}
}
return result
}
func convertVisibilityStringToStorepb(visibility string) storepb.Visibility {
return storepb.Visibility(storepb.Visibility_value[visibility])
}

56
store/db/sqlite/sqlite.go Normal file
View File

@ -0,0 +1,56 @@
package sqlite
import (
"database/sql"
"github.com/pkg/errors"
"github.com/yourselfhosted/slash/server/profile"
"github.com/yourselfhosted/slash/store"
)
type DB struct {
db *sql.DB
profile *profile.Profile
}
// NewDB opens a database specified by its database driver name and a
// driver-specific data source name, usually consisting of at least a
// database name and connection information.
func NewDB(profile *profile.Profile) (store.Driver, error) {
// Ensure a DSN is set before attempting to open the database.
if profile.DSN == "" {
return nil, errors.New("dsn required")
}
// Connect to the database with some sane settings:
// - No shared-cache: it's obsolete; WAL journal mode is a better solution.
// - No foreign key constraints: it's currently disabled by default, but it's a
// good practice to be explicit and prevent future surprises on SQLite upgrades.
// - Journal mode set to WAL: it's the recommended journal mode for most applications
// as it prevents locking issues.
//
// Notes:
// - When using the `modernc.org/sqlite` driver, each pragma must be prefixed with `_pragma=`.
//
// References:
// - https://pkg.go.dev/modernc.org/sqlite#Driver.Open
// - https://www.sqlite.org/sharedcache.html
// - https://www.sqlite.org/pragma.html
sqliteDB, err := sql.Open("sqlite", profile.DSN+"?_pragma=foreign_keys(0)&_pragma=busy_timeout(10000)&_pragma=journal_mode(WAL)")
if err != nil {
return nil, errors.Wrapf(err, "failed to open db with dsn: %s", profile.DSN)
}
driver := DB{db: sqliteDB, profile: profile}
return &driver, nil
}
func (d *DB) GetDB() *sql.DB {
return d.db
}
func (d *DB) Close() error {
return d.db.Close()
}

176
store/db/sqlite/user.go Normal file
View File

@ -0,0 +1,176 @@
package sqlite
import (
"context"
"errors"
"strings"
"github.com/yourselfhosted/slash/store"
)
func (d *DB) CreateUser(ctx context.Context, create *store.User) (*store.User, error) {
stmt := `
INSERT INTO user (
email,
nickname,
password_hash,
role
)
VALUES (?, ?, ?, ?)
RETURNING id, created_ts, updated_ts, row_status
`
if err := d.db.QueryRowContext(ctx, stmt,
create.Email,
create.Nickname,
create.PasswordHash,
create.Role,
).Scan(
&create.ID,
&create.CreatedTs,
&create.UpdatedTs,
&create.RowStatus,
); err != nil {
return nil, err
}
user := create
return user, nil
}
func (d *DB) UpdateUser(ctx context.Context, update *store.UpdateUser) (*store.User, error) {
set, args := []string{}, []any{}
if v := update.RowStatus; v != nil {
set, args = append(set, "row_status = ?"), append(args, *v)
}
if v := update.Email; v != nil {
set, args = append(set, "email = ?"), append(args, *v)
}
if v := update.Nickname; v != nil {
set, args = append(set, "nickname = ?"), append(args, *v)
}
if v := update.PasswordHash; v != nil {
set, args = append(set, "password_hash = ?"), append(args, *v)
}
if v := update.Role; v != nil {
set, args = append(set, "role = ?"), append(args, *v)
}
if len(set) == 0 {
return nil, errors.New("no fields to update")
}
stmt := `
UPDATE user
SET ` + strings.Join(set, ", ") + `
WHERE id = ?
RETURNING id, created_ts, updated_ts, row_status, email, nickname, password_hash, role
`
args = append(args, update.ID)
user := &store.User{}
if err := d.db.QueryRowContext(ctx, stmt, args...).Scan(
&user.ID,
&user.CreatedTs,
&user.UpdatedTs,
&user.RowStatus,
&user.Email,
&user.Nickname,
&user.PasswordHash,
&user.Role,
); err != nil {
return nil, err
}
return user, nil
}
func (d *DB) ListUsers(ctx context.Context, find *store.FindUser) ([]*store.User, error) {
where, args := []string{"1 = 1"}, []any{}
if v := find.ID; v != nil {
where, args = append(where, "id = ?"), append(args, *v)
}
if v := find.RowStatus; v != nil {
where, args = append(where, "row_status = ?"), append(args, v.String())
}
if v := find.Email; v != nil {
where, args = append(where, "email = ?"), append(args, *v)
}
if v := find.Nickname; v != nil {
where, args = append(where, "nickname = ?"), append(args, *v)
}
if v := find.Role; v != nil {
where, args = append(where, "role = ?"), append(args, *v)
}
query := `
SELECT
id,
created_ts,
updated_ts,
row_status,
email,
nickname,
password_hash,
role
FROM user
WHERE ` + strings.Join(where, " AND ") + `
ORDER BY updated_ts DESC, created_ts DESC
`
rows, err := d.db.QueryContext(ctx, query, args...)
if err != nil {
return nil, err
}
defer rows.Close()
list := make([]*store.User, 0)
for rows.Next() {
user := &store.User{}
if err := rows.Scan(
&user.ID,
&user.CreatedTs,
&user.UpdatedTs,
&user.RowStatus,
&user.Email,
&user.Nickname,
&user.PasswordHash,
&user.Role,
); err != nil {
return nil, err
}
list = append(list, user)
}
if err := rows.Err(); err != nil {
return nil, err
}
return list, nil
}
func (d *DB) DeleteUser(ctx context.Context, delete *store.DeleteUser) error {
tx, err := d.db.BeginTx(ctx, nil)
if err != nil {
return err
}
defer tx.Rollback()
if _, err := tx.ExecContext(ctx, `
DELETE FROM user WHERE id = ?
`, delete.ID); err != nil {
return err
}
if err := vacuumUserSetting(ctx, tx); err != nil {
return err
}
if err := vacuumShortcut(ctx, tx); err != nil {
return err
}
if err := vacuumMemo(ctx, tx); err != nil {
return err
}
return tx.Commit()
}

View File

@ -0,0 +1,128 @@
package sqlite
import (
"context"
"database/sql"
"errors"
"strings"
"google.golang.org/protobuf/encoding/protojson"
storepb "github.com/yourselfhosted/slash/proto/gen/store"
"github.com/yourselfhosted/slash/store"
)
func (d *DB) UpsertUserSetting(ctx context.Context, upsert *storepb.UserSetting) (*storepb.UserSetting, error) {
stmt := `
INSERT INTO user_setting (
user_id, key, value
)
VALUES (?, ?, ?)
ON CONFLICT(user_id, key) DO UPDATE
SET value = EXCLUDED.value
`
var valueString string
if upsert.Key == storepb.UserSettingKey_USER_SETTING_ACCESS_TOKENS {
valueBytes, err := protojson.Marshal(upsert.GetAccessTokens())
if err != nil {
return nil, err
}
valueString = string(valueBytes)
} else if upsert.Key == storepb.UserSettingKey_USER_SETTING_LOCALE {
valueString = upsert.GetLocale().String()
} else if upsert.Key == storepb.UserSettingKey_USER_SETTING_COLOR_THEME {
valueString = upsert.GetColorTheme().String()
} else {
return nil, errors.New("invalid user setting key")
}
if _, err := d.db.ExecContext(ctx, stmt, upsert.UserId, upsert.Key.String(), valueString); err != nil {
return nil, err
}
userSettingMessage := upsert
return userSettingMessage, nil
}
func (d *DB) ListUserSettings(ctx context.Context, find *store.FindUserSetting) ([]*storepb.UserSetting, error) {
where, args := []string{"1 = 1"}, []any{}
if v := find.Key; v != storepb.UserSettingKey_USER_SETTING_KEY_UNSPECIFIED {
where, args = append(where, "key = ?"), append(args, v.String())
}
if v := find.UserID; v != nil {
where, args = append(where, "user_id = ?"), append(args, *find.UserID)
}
query := `
SELECT
user_id,
key,
value
FROM user_setting
WHERE ` + strings.Join(where, " AND ")
rows, err := d.db.QueryContext(ctx, query, args...)
if err != nil {
return nil, err
}
defer rows.Close()
userSettingList := make([]*storepb.UserSetting, 0)
for rows.Next() {
userSetting := &storepb.UserSetting{}
var keyString, valueString string
if err := rows.Scan(
&userSetting.UserId,
&keyString,
&valueString,
); err != nil {
return nil, err
}
userSetting.Key = storepb.UserSettingKey(storepb.UserSettingKey_value[keyString])
if userSetting.Key == storepb.UserSettingKey_USER_SETTING_ACCESS_TOKENS {
accessTokensUserSetting := &storepb.AccessTokensUserSetting{}
if err := protojson.Unmarshal([]byte(valueString), accessTokensUserSetting); err != nil {
return nil, err
}
userSetting.Value = &storepb.UserSetting_AccessTokens{
AccessTokens: accessTokensUserSetting,
}
} else if userSetting.Key == storepb.UserSettingKey_USER_SETTING_LOCALE {
userSetting.Value = &storepb.UserSetting_Locale{
Locale: storepb.LocaleUserSetting(storepb.LocaleUserSetting_value[valueString]),
}
} else if userSetting.Key == storepb.UserSettingKey_USER_SETTING_COLOR_THEME {
userSetting.Value = &storepb.UserSetting_ColorTheme{
ColorTheme: storepb.ColorThemeUserSetting(storepb.ColorThemeUserSetting_value[valueString]),
}
} else {
return nil, errors.New("invalid user setting key")
}
userSettingList = append(userSettingList, userSetting)
}
if err := rows.Err(); err != nil {
return nil, err
}
return userSettingList, nil
}
func vacuumUserSetting(ctx context.Context, tx *sql.Tx) error {
stmt := `
DELETE FROM
user_setting
WHERE
user_id NOT IN (
SELECT
id
FROM
user
)`
_, err := tx.ExecContext(ctx, stmt)
if err != nil {
return err
}
return nil
}

View File

@ -0,0 +1,116 @@
package sqlite
import (
"context"
"errors"
"strconv"
"strings"
"google.golang.org/protobuf/encoding/protojson"
storepb "github.com/yourselfhosted/slash/proto/gen/store"
"github.com/yourselfhosted/slash/store"
)
func (d *DB) UpsertWorkspaceSetting(ctx context.Context, upsert *storepb.WorkspaceSetting) (*storepb.WorkspaceSetting, error) {
stmt := `
INSERT INTO workspace_setting (
key,
value
)
VALUES (?, ?)
ON CONFLICT(key) DO UPDATE
SET value = EXCLUDED.value
`
var valueString string
if upsert.Key == storepb.WorkspaceSettingKey_WORKSPACE_SETTING_LICENSE_KEY {
valueString = upsert.GetLicenseKey()
} else if upsert.Key == storepb.WorkspaceSettingKey_WORKSPACE_SETTING_SECRET_SESSION {
valueString = upsert.GetSecretSession()
} else if upsert.Key == storepb.WorkspaceSettingKey_WORKSAPCE_SETTING_ENABLE_SIGNUP {
valueString = strconv.FormatBool(upsert.GetEnableSignup())
} else if upsert.Key == storepb.WorkspaceSettingKey_WORKSPACE_SETTING_CUSTOM_STYLE {
valueString = upsert.GetCustomStyle()
} else if upsert.Key == storepb.WorkspaceSettingKey_WORKSPACE_SETTING_CUSTOM_SCRIPT {
valueString = upsert.GetCustomScript()
} else if upsert.Key == storepb.WorkspaceSettingKey_WORKSPACE_SETTING_AUTO_BACKUP {
valueBytes, err := protojson.Marshal(upsert.GetAutoBackup())
if err != nil {
return nil, err
}
valueString = string(valueBytes)
} else {
return nil, errors.New("invalid workspace setting key")
}
if _, err := d.db.ExecContext(ctx, stmt, upsert.Key.String(), valueString); err != nil {
return nil, err
}
workspaceSetting := upsert
return workspaceSetting, nil
}
func (d *DB) ListWorkspaceSettings(ctx context.Context, find *store.FindWorkspaceSetting) ([]*storepb.WorkspaceSetting, error) {
where, args := []string{"1 = 1"}, []any{}
if find.Key != storepb.WorkspaceSettingKey_WORKSPACE_SETTING_KEY_UNSPECIFIED {
where, args = append(where, "key = ?"), append(args, find.Key.String())
}
query := `
SELECT
key,
value
FROM workspace_setting
WHERE ` + strings.Join(where, " AND ")
rows, err := d.db.QueryContext(ctx, query, args...)
if err != nil {
return nil, err
}
defer rows.Close()
list := []*storepb.WorkspaceSetting{}
for rows.Next() {
workspaceSetting := &storepb.WorkspaceSetting{}
var keyString, valueString string
if err := rows.Scan(
&keyString,
&valueString,
); err != nil {
return nil, err
}
workspaceSetting.Key = storepb.WorkspaceSettingKey(storepb.WorkspaceSettingKey_value[keyString])
if workspaceSetting.Key == storepb.WorkspaceSettingKey_WORKSPACE_SETTING_LICENSE_KEY {
workspaceSetting.Value = &storepb.WorkspaceSetting_LicenseKey{LicenseKey: valueString}
} else if workspaceSetting.Key == storepb.WorkspaceSettingKey_WORKSPACE_SETTING_SECRET_SESSION {
workspaceSetting.Value = &storepb.WorkspaceSetting_SecretSession{SecretSession: valueString}
} else if workspaceSetting.Key == storepb.WorkspaceSettingKey_WORKSAPCE_SETTING_ENABLE_SIGNUP {
enableSignup, err := strconv.ParseBool(valueString)
if err != nil {
return nil, err
}
workspaceSetting.Value = &storepb.WorkspaceSetting_EnableSignup{EnableSignup: enableSignup}
} else if workspaceSetting.Key == storepb.WorkspaceSettingKey_WORKSPACE_SETTING_CUSTOM_STYLE {
workspaceSetting.Value = &storepb.WorkspaceSetting_CustomStyle{CustomStyle: valueString}
} else if workspaceSetting.Key == storepb.WorkspaceSettingKey_WORKSPACE_SETTING_CUSTOM_SCRIPT {
workspaceSetting.Value = &storepb.WorkspaceSetting_CustomScript{CustomScript: valueString}
} else if workspaceSetting.Key == storepb.WorkspaceSettingKey_WORKSPACE_SETTING_AUTO_BACKUP {
autoBackupSetting := &storepb.AutoBackupWorkspaceSetting{}
if err := protojson.Unmarshal([]byte(valueString), autoBackupSetting); err != nil {
return nil, err
}
workspaceSetting.Value = &storepb.WorkspaceSetting_AutoBackup{AutoBackup: autoBackupSetting}
} else {
continue
}
list = append(list, workspaceSetting)
}
if err := rows.Err(); err != nil {
return nil, err
}
return list, nil
}

57
store/driver.go Normal file
View File

@ -0,0 +1,57 @@
package store
import (
"context"
"database/sql"
storepb "github.com/yourselfhosted/slash/proto/gen/store"
)
// Driver is an interface for store driver.
// It contains all methods that store database driver should implement.
type Driver interface {
GetDB() *sql.DB
Close() error
Migrate(ctx context.Context) error
// MigrationHistory model related methods.
UpsertMigrationHistory(ctx context.Context, upsert *UpsertMigrationHistory) (*MigrationHistory, error)
ListMigrationHistories(ctx context.Context, find *FindMigrationHistory) ([]*MigrationHistory, error)
// Activity model related methods.
CreateActivity(ctx context.Context, create *Activity) (*Activity, error)
ListActivities(ctx context.Context, find *FindActivity) ([]*Activity, error)
// Collection model related methods.
CreateCollection(ctx context.Context, create *storepb.Collection) (*storepb.Collection, error)
UpdateCollection(ctx context.Context, update *UpdateCollection) (*storepb.Collection, error)
ListCollections(ctx context.Context, find *FindCollection) ([]*storepb.Collection, error)
DeleteCollection(ctx context.Context, delete *DeleteCollection) error
// Memo model related methods.
CreateMemo(ctx context.Context, create *storepb.Memo) (*storepb.Memo, error)
UpdateMemo(ctx context.Context, update *UpdateMemo) (*storepb.Memo, error)
ListMemos(ctx context.Context, find *FindMemo) ([]*storepb.Memo, error)
DeleteMemo(ctx context.Context, delete *DeleteMemo) error
// Shortcut model related methods.
CreateShortcut(ctx context.Context, create *storepb.Shortcut) (*storepb.Shortcut, error)
UpdateShortcut(ctx context.Context, update *UpdateShortcut) (*storepb.Shortcut, error)
ListShortcuts(ctx context.Context, find *FindShortcut) ([]*storepb.Shortcut, error)
DeleteShortcut(ctx context.Context, delete *DeleteShortcut) error
// User model related methods.
CreateUser(ctx context.Context, create *User) (*User, error)
UpdateUser(ctx context.Context, update *UpdateUser) (*User, error)
ListUsers(ctx context.Context, find *FindUser) ([]*User, error)
DeleteUser(ctx context.Context, delete *DeleteUser) error
// UserSetting model related methods.
UpsertUserSetting(ctx context.Context, upsert *storepb.UserSetting) (*storepb.UserSetting, error)
ListUserSettings(ctx context.Context, find *FindUserSetting) ([]*storepb.UserSetting, error)
// WorkspaceSetting model related methods.
UpsertWorkspaceSetting(ctx context.Context, upsert *storepb.WorkspaceSetting) (*storepb.WorkspaceSetting, error)
ListWorkspaceSettings(ctx context.Context, find *FindWorkspaceSetting) ([]*storepb.WorkspaceSetting, error)
}

View File

@ -2,11 +2,6 @@ package store
import ( import (
"context" "context"
"database/sql"
"fmt"
"strings"
"github.com/pkg/errors"
storepb "github.com/yourselfhosted/slash/proto/gen/store" storepb "github.com/yourselfhosted/slash/proto/gen/store"
) )
@ -35,162 +30,15 @@ type DeleteMemo struct {
} }
func (s *Store) CreateMemo(ctx context.Context, create *storepb.Memo) (*storepb.Memo, error) { func (s *Store) CreateMemo(ctx context.Context, create *storepb.Memo) (*storepb.Memo, error) {
set := []string{"creator_id", "name", "title", "content", "visibility", "tag"} return s.driver.CreateMemo(ctx, create)
args := []any{create.CreatorId, create.Name, create.Title, create.Content, create.Visibility.String(), strings.Join(create.Tags, " ")}
stmt := `
INSERT INTO memo (
` + strings.Join(set, ", ") + `
)
VALUES (` + placeholders(len(args)) + `)
RETURNING id, created_ts, updated_ts, row_status
`
var rowStatus string
if err := s.db.QueryRowContext(ctx, stmt, args...).Scan(
&create.Id,
&create.CreatedTs,
&create.UpdatedTs,
&rowStatus,
); err != nil {
return nil, err
}
create.RowStatus = convertRowStatusStringToStorepb(rowStatus)
memo := create
return memo, nil
} }
func (s *Store) UpdateMemo(ctx context.Context, update *UpdateMemo) (*storepb.Memo, error) { func (s *Store) UpdateMemo(ctx context.Context, update *UpdateMemo) (*storepb.Memo, error) {
set, args := []string{}, []any{} return s.driver.UpdateMemo(ctx, update)
if update.RowStatus != nil {
set, args = append(set, "row_status = ?"), append(args, update.RowStatus.String())
}
if update.Name != nil {
set, args = append(set, "name = ?"), append(args, *update.Name)
}
if update.Title != nil {
set, args = append(set, "title = ?"), append(args, *update.Title)
}
if update.Content != nil {
set, args = append(set, "content = ?"), append(args, *update.Content)
}
if update.Visibility != nil {
set, args = append(set, "visibility = ?"), append(args, update.Visibility.String())
}
if update.Tag != nil {
set, args = append(set, "tag = ?"), append(args, *update.Tag)
}
if len(set) == 0 {
return nil, errors.New("no update specified")
}
args = append(args, update.ID)
stmt := `
UPDATE memo
SET
` + strings.Join(set, ", ") + `
WHERE
id = ?
RETURNING id, creator_id, created_ts, updated_ts, row_status, name, title, content, visibility, tag
`
memo := &storepb.Memo{}
var rowStatus, visibility, tags string
if err := s.db.QueryRowContext(ctx, stmt, args...).Scan(
&memo.Id,
&memo.CreatorId,
&memo.CreatedTs,
&memo.UpdatedTs,
&rowStatus,
&memo.Name,
&memo.Title,
&memo.Content,
&visibility,
&tags,
); err != nil {
return nil, err
}
memo.RowStatus = convertRowStatusStringToStorepb(rowStatus)
memo.Visibility = convertVisibilityStringToStorepb(visibility)
memo.Tags = filterTags(strings.Split(tags, " "))
return memo, nil
} }
func (s *Store) ListMemos(ctx context.Context, find *FindMemo) ([]*storepb.Memo, error) { func (s *Store) ListMemos(ctx context.Context, find *FindMemo) ([]*storepb.Memo, error) {
where, args := []string{"1 = 1"}, []any{} return s.driver.ListMemos(ctx, find)
if v := find.ID; v != nil {
where, args = append(where, "id = ?"), append(args, *v)
}
if v := find.CreatorID; v != nil {
where, args = append(where, "creator_id = ?"), append(args, *v)
}
if v := find.RowStatus; v != nil {
where, args = append(where, "row_status = ?"), append(args, *v)
}
if v := find.Name; v != nil {
where, args = append(where, "name = ?"), append(args, *v)
}
if v := find.VisibilityList; len(v) != 0 {
list := []string{}
for _, visibility := range v {
list = append(list, fmt.Sprintf("$%d", len(args)+1))
args = append(args, visibility)
}
where = append(where, fmt.Sprintf("visibility in (%s)", strings.Join(list, ",")))
}
if v := find.Tag; v != nil {
where, args = append(where, "tag LIKE ?"), append(args, "%"+*v+"%")
}
rows, err := s.db.QueryContext(ctx, `
SELECT
id,
creator_id,
created_ts,
updated_ts,
row_status,
name,
title,
content,
visibility,
tag
FROM memo
WHERE `+strings.Join(where, " AND ")+`
ORDER BY created_ts DESC`,
args...,
)
if err != nil {
return nil, err
}
defer rows.Close()
list := make([]*storepb.Memo, 0)
for rows.Next() {
memo := &storepb.Memo{}
var rowStatus, visibility, tags string
if err := rows.Scan(
&memo.Id,
&memo.CreatorId,
&memo.CreatedTs,
&memo.UpdatedTs,
&rowStatus,
&memo.Name,
&memo.Title,
&memo.Content,
&visibility,
&tags,
); err != nil {
return nil, err
}
memo.RowStatus = convertRowStatusStringToStorepb(rowStatus)
memo.Visibility = storepb.Visibility(storepb.Visibility_value[visibility])
memo.Tags = filterTags(strings.Split(tags, " "))
list = append(list, memo)
}
if err := rows.Err(); err != nil {
return nil, err
}
return list, nil
} }
func (s *Store) GetMemo(ctx context.Context, find *FindMemo) (*storepb.Memo, error) { func (s *Store) GetMemo(ctx context.Context, find *FindMemo) (*storepb.Memo, error) {
@ -208,31 +56,5 @@ func (s *Store) GetMemo(ctx context.Context, find *FindMemo) (*storepb.Memo, err
} }
func (s *Store) DeleteMemo(ctx context.Context, delete *DeleteMemo) error { func (s *Store) DeleteMemo(ctx context.Context, delete *DeleteMemo) error {
if _, err := s.db.ExecContext(ctx, `DELETE FROM memo WHERE id = ?`, delete.ID); err != nil { return s.driver.DeleteMemo(ctx, delete)
return err
}
return nil
}
func vacuumMemo(ctx context.Context, tx *sql.Tx) error {
stmt := `
DELETE FROM
memo
WHERE
creator_id NOT IN (
SELECT
id
FROM
user
)`
_, err := tx.ExecContext(ctx, stmt)
if err != nil {
return err
}
return nil
}
func placeholders(n int) string {
return strings.Repeat("?,", n-1) + "?"
} }

View File

@ -0,0 +1,13 @@
package store
type MigrationHistory struct {
Version string
CreatedTs int64
}
type UpsertMigrationHistory struct {
Version string
}
type FindMigrationHistory struct {
}

View File

@ -2,12 +2,6 @@ package store
import ( import (
"context" "context"
"database/sql"
"fmt"
"strings"
"github.com/pkg/errors"
"google.golang.org/protobuf/encoding/protojson"
storepb "github.com/yourselfhosted/slash/proto/gen/store" storepb "github.com/yourselfhosted/slash/proto/gen/store"
) )
@ -39,198 +33,28 @@ type DeleteShortcut struct {
} }
func (s *Store) CreateShortcut(ctx context.Context, create *storepb.Shortcut) (*storepb.Shortcut, error) { func (s *Store) CreateShortcut(ctx context.Context, create *storepb.Shortcut) (*storepb.Shortcut, error) {
set := []string{"creator_id", "name", "link", "title", "description", "visibility", "tag"} shortcut, err := s.driver.CreateShortcut(ctx, create)
args := []any{create.CreatorId, create.Name, create.Link, create.Title, create.Description, create.Visibility.String(), strings.Join(create.Tags, " ")}
placeholder := []string{"?", "?", "?", "?", "?", "?", "?"}
if create.OgMetadata != nil {
set = append(set, "og_metadata")
openGraphMetadataBytes, err := protojson.Marshal(create.OgMetadata)
if err != nil { if err != nil {
return nil, err return nil, err
} }
args = append(args, string(openGraphMetadataBytes))
placeholder = append(placeholder, "?")
}
stmt := `
INSERT INTO shortcut (
` + strings.Join(set, ", ") + `
)
VALUES (` + strings.Join(placeholder, ",") + `)
RETURNING id, created_ts, updated_ts, row_status
`
var rowStatus string
if err := s.db.QueryRowContext(ctx, stmt, args...).Scan(
&create.Id,
&create.CreatedTs,
&create.UpdatedTs,
&rowStatus,
); err != nil {
return nil, err
}
create.RowStatus = convertRowStatusStringToStorepb(rowStatus)
shortcut := create
s.shortcutCache.Store(shortcut.Id, shortcut) s.shortcutCache.Store(shortcut.Id, shortcut)
return shortcut, nil return shortcut, nil
} }
func (s *Store) UpdateShortcut(ctx context.Context, update *UpdateShortcut) (*storepb.Shortcut, error) { func (s *Store) UpdateShortcut(ctx context.Context, update *UpdateShortcut) (*storepb.Shortcut, error) {
set, args := []string{}, []any{} shortcut, err := s.driver.UpdateShortcut(ctx, update)
if update.RowStatus != nil {
set, args = append(set, "row_status = ?"), append(args, update.RowStatus.String())
}
if update.Name != nil {
set, args = append(set, "name = ?"), append(args, *update.Name)
}
if update.Link != nil {
set, args = append(set, "link = ?"), append(args, *update.Link)
}
if update.Title != nil {
set, args = append(set, "title = ?"), append(args, *update.Title)
}
if update.Description != nil {
set, args = append(set, "description = ?"), append(args, *update.Description)
}
if update.Visibility != nil {
set, args = append(set, "visibility = ?"), append(args, update.Visibility.String())
}
if update.Tag != nil {
set, args = append(set, "tag = ?"), append(args, *update.Tag)
}
if update.OpenGraphMetadata != nil {
openGraphMetadataBytes, err := protojson.Marshal(update.OpenGraphMetadata)
if err != nil { if err != nil {
return nil, errors.Wrap(err, "Failed to marshal activity payload")
}
set, args = append(set, "og_metadata = ?"), append(args, string(openGraphMetadataBytes))
}
if len(set) == 0 {
return nil, errors.New("no update specified")
}
args = append(args, update.ID)
stmt := `
UPDATE shortcut
SET
` + strings.Join(set, ", ") + `
WHERE
id = ?
RETURNING id, creator_id, created_ts, updated_ts, row_status, name, link, title, description, visibility, tag, og_metadata
`
shortcut := &storepb.Shortcut{}
var rowStatus, visibility, tags, openGraphMetadataString string
if err := s.db.QueryRowContext(ctx, stmt, args...).Scan(
&shortcut.Id,
&shortcut.CreatorId,
&shortcut.CreatedTs,
&shortcut.UpdatedTs,
&rowStatus,
&shortcut.Name,
&shortcut.Link,
&shortcut.Title,
&shortcut.Description,
&visibility,
&tags,
&openGraphMetadataString,
); err != nil {
return nil, err return nil, err
} }
shortcut.RowStatus = convertRowStatusStringToStorepb(rowStatus)
shortcut.Visibility = convertVisibilityStringToStorepb(visibility)
shortcut.Tags = filterTags(strings.Split(tags, " "))
var ogMetadata storepb.OpenGraphMetadata
if err := protojson.Unmarshal([]byte(openGraphMetadataString), &ogMetadata); err != nil {
return nil, err
}
shortcut.OgMetadata = &ogMetadata
s.shortcutCache.Store(shortcut.Id, shortcut) s.shortcutCache.Store(shortcut.Id, shortcut)
return shortcut, nil return shortcut, nil
} }
func (s *Store) ListShortcuts(ctx context.Context, find *FindShortcut) ([]*storepb.Shortcut, error) { func (s *Store) ListShortcuts(ctx context.Context, find *FindShortcut) ([]*storepb.Shortcut, error) {
where, args := []string{"1 = 1"}, []any{} list, err := s.driver.ListShortcuts(ctx, find)
if v := find.ID; v != nil {
where, args = append(where, "id = ?"), append(args, *v)
}
if v := find.CreatorID; v != nil {
where, args = append(where, "creator_id = ?"), append(args, *v)
}
if v := find.RowStatus; v != nil {
where, args = append(where, "row_status = ?"), append(args, *v)
}
if v := find.Name; v != nil {
where, args = append(where, "name = ?"), append(args, *v)
}
if v := find.VisibilityList; len(v) != 0 {
list := []string{}
for _, visibility := range v {
list = append(list, fmt.Sprintf("$%d", len(args)+1))
args = append(args, visibility)
}
where = append(where, fmt.Sprintf("visibility in (%s)", strings.Join(list, ",")))
}
if v := find.Tag; v != nil {
where, args = append(where, "tag LIKE ?"), append(args, "%"+*v+"%")
}
rows, err := s.db.QueryContext(ctx, `
SELECT
id,
creator_id,
created_ts,
updated_ts,
row_status,
name,
link,
title,
description,
visibility,
tag,
og_metadata
FROM shortcut
WHERE `+strings.Join(where, " AND ")+`
ORDER BY created_ts DESC`,
args...,
)
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer rows.Close()
list := make([]*storepb.Shortcut, 0)
for rows.Next() {
shortcut := &storepb.Shortcut{}
var rowStatus, visibility, tags, openGraphMetadataString string
if err := rows.Scan(
&shortcut.Id,
&shortcut.CreatorId,
&shortcut.CreatedTs,
&shortcut.UpdatedTs,
&rowStatus,
&shortcut.Name,
&shortcut.Link,
&shortcut.Title,
&shortcut.Description,
&visibility,
&tags,
&openGraphMetadataString,
); err != nil {
return nil, err
}
shortcut.RowStatus = convertRowStatusStringToStorepb(rowStatus)
shortcut.Visibility = storepb.Visibility(storepb.Visibility_value[visibility])
shortcut.Tags = filterTags(strings.Split(tags, " "))
var ogMetadata storepb.OpenGraphMetadata
if err := protojson.Unmarshal([]byte(openGraphMetadataString), &ogMetadata); err != nil {
return nil, err
}
shortcut.OgMetadata = &ogMetadata
list = append(list, shortcut)
}
if err := rows.Err(); err != nil {
return nil, err
}
for _, shortcut := range list { for _, shortcut := range list {
s.shortcutCache.Store(shortcut.Id, shortcut) s.shortcutCache.Store(shortcut.Id, shortcut)
} }
@ -259,44 +83,10 @@ func (s *Store) GetShortcut(ctx context.Context, find *FindShortcut) (*storepb.S
} }
func (s *Store) DeleteShortcut(ctx context.Context, delete *DeleteShortcut) error { func (s *Store) DeleteShortcut(ctx context.Context, delete *DeleteShortcut) error {
if _, err := s.db.ExecContext(ctx, `DELETE FROM shortcut WHERE id = ?`, delete.ID); err != nil { if err := s.driver.DeleteShortcut(ctx, delete); err != nil {
return err return err
} }
s.shortcutCache.Delete(delete.ID) s.shortcutCache.Delete(delete.ID)
return nil return nil
} }
func vacuumShortcut(ctx context.Context, tx *sql.Tx) error {
stmt := `
DELETE FROM
shortcut
WHERE
creator_id NOT IN (
SELECT
id
FROM
user
)`
_, err := tx.ExecContext(ctx, stmt)
if err != nil {
return err
}
return nil
}
func filterTags(tags []string) []string {
result := []string{}
for _, tag := range tags {
if tag != "" {
result = append(result, tag)
}
}
return result
}
func convertVisibilityStringToStorepb(visibility string) storepb.Visibility {
return storepb.Visibility(storepb.Visibility_value[visibility])
}

View File

@ -1,7 +1,6 @@
package store package store
import ( import (
"database/sql"
"sync" "sync"
"github.com/yourselfhosted/slash/server/profile" "github.com/yourselfhosted/slash/server/profile"
@ -9,8 +8,8 @@ import (
// Store provides database access to all raw objects. // Store provides database access to all raw objects.
type Store struct { type Store struct {
db *sql.DB
profile *profile.Profile profile *profile.Profile
driver Driver
workspaceSettingCache sync.Map // map[string]*WorkspaceSetting workspaceSettingCache sync.Map // map[string]*WorkspaceSetting
userCache sync.Map // map[int]*User userCache sync.Map // map[int]*User
@ -19,14 +18,14 @@ type Store struct {
} }
// New creates a new instance of Store. // New creates a new instance of Store.
func New(db *sql.DB, profile *profile.Profile) *Store { func New(driver Driver, profile *profile.Profile) *Store {
return &Store{ return &Store{
db: db, driver: driver,
profile: profile, profile: profile,
} }
} }
// Close closes the database connection. // Close closes the database connection.
func (s *Store) Close() error { func (s *Store) Close() error {
return s.db.Close() return s.driver.Close()
} }

View File

@ -2,8 +2,6 @@ package store
import ( import (
"context" "context"
"errors"
"strings"
) )
// Role is the type of a role. // Role is the type of a role.
@ -54,147 +52,31 @@ type DeleteUser struct {
} }
func (s *Store) CreateUser(ctx context.Context, create *User) (*User, error) { func (s *Store) CreateUser(ctx context.Context, create *User) (*User, error) {
stmt := ` user, err := s.driver.CreateUser(ctx, create)
INSERT INTO user ( if err != nil {
email,
nickname,
password_hash,
role
)
VALUES (?, ?, ?, ?)
RETURNING id, created_ts, updated_ts, row_status
`
if err := s.db.QueryRowContext(ctx, stmt,
create.Email,
create.Nickname,
create.PasswordHash,
create.Role,
).Scan(
&create.ID,
&create.CreatedTs,
&create.UpdatedTs,
&create.RowStatus,
); err != nil {
return nil, err return nil, err
} }
user := create
s.userCache.Store(user.ID, user) s.userCache.Store(user.ID, user)
return user, nil return user, nil
} }
func (s *Store) UpdateUser(ctx context.Context, update *UpdateUser) (*User, error) { func (s *Store) UpdateUser(ctx context.Context, update *UpdateUser) (*User, error) {
set, args := []string{}, []any{} user, err := s.driver.UpdateUser(ctx, update)
if v := update.RowStatus; v != nil { if err != nil {
set, args = append(set, "row_status = ?"), append(args, *v)
}
if v := update.Email; v != nil {
set, args = append(set, "email = ?"), append(args, *v)
}
if v := update.Nickname; v != nil {
set, args = append(set, "nickname = ?"), append(args, *v)
}
if v := update.PasswordHash; v != nil {
set, args = append(set, "password_hash = ?"), append(args, *v)
}
if v := update.Role; v != nil {
set, args = append(set, "role = ?"), append(args, *v)
}
if len(set) == 0 {
return nil, errors.New("no fields to update")
}
stmt := `
UPDATE user
SET ` + strings.Join(set, ", ") + `
WHERE id = ?
RETURNING id, created_ts, updated_ts, row_status, email, nickname, password_hash, role
`
args = append(args, update.ID)
user := &User{}
if err := s.db.QueryRowContext(ctx, stmt, args...).Scan(
&user.ID,
&user.CreatedTs,
&user.UpdatedTs,
&user.RowStatus,
&user.Email,
&user.Nickname,
&user.PasswordHash,
&user.Role,
); err != nil {
return nil, err return nil, err
} }
s.userCache.Store(user.ID, user) s.userCache.Store(user.ID, user)
return user, nil return user, nil
} }
func (s *Store) ListUsers(ctx context.Context, find *FindUser) ([]*User, error) { func (s *Store) ListUsers(ctx context.Context, find *FindUser) ([]*User, error) {
where, args := []string{"1 = 1"}, []any{} list, err := s.driver.ListUsers(ctx, find)
if v := find.ID; v != nil {
where, args = append(where, "id = ?"), append(args, *v)
}
if v := find.RowStatus; v != nil {
where, args = append(where, "row_status = ?"), append(args, v.String())
}
if v := find.Email; v != nil {
where, args = append(where, "email = ?"), append(args, *v)
}
if v := find.Nickname; v != nil {
where, args = append(where, "nickname = ?"), append(args, *v)
}
if v := find.Role; v != nil {
where, args = append(where, "role = ?"), append(args, *v)
}
query := `
SELECT
id,
created_ts,
updated_ts,
row_status,
email,
nickname,
password_hash,
role
FROM user
WHERE ` + strings.Join(where, " AND ") + `
ORDER BY updated_ts DESC, created_ts DESC
`
rows, err := s.db.QueryContext(ctx, query, args...)
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer rows.Close()
list := make([]*User, 0)
for rows.Next() {
user := &User{}
if err := rows.Scan(
&user.ID,
&user.CreatedTs,
&user.UpdatedTs,
&user.RowStatus,
&user.Email,
&user.Nickname,
&user.PasswordHash,
&user.Role,
); err != nil {
return nil, err
}
list = append(list, user)
}
if err := rows.Err(); err != nil {
return nil, err
}
for _, user := range list { for _, user := range list {
s.userCache.Store(user.ID, user) s.userCache.Store(user.ID, user)
} }
return list, nil return list, nil
} }
@ -218,35 +100,10 @@ func (s *Store) GetUser(ctx context.Context, find *FindUser) (*User, error) {
} }
func (s *Store) DeleteUser(ctx context.Context, delete *DeleteUser) error { func (s *Store) DeleteUser(ctx context.Context, delete *DeleteUser) error {
tx, err := s.db.BeginTx(ctx, nil) if err := s.driver.DeleteUser(ctx, delete); err != nil {
if err != nil {
return err
}
defer tx.Rollback()
if _, err := tx.ExecContext(ctx, `
DELETE FROM user WHERE id = ?
`, delete.ID); err != nil {
return err
}
if err := vacuumUserSetting(ctx, tx); err != nil {
return err
}
if err := vacuumShortcut(ctx, tx); err != nil {
return err
}
if err := vacuumMemo(ctx, tx); err != nil {
return err
}
if err := tx.Commit(); err != nil {
return err return err
} }
s.userCache.Delete(delete.ID) s.userCache.Delete(delete.ID)
return nil return nil
} }

View File

@ -2,11 +2,6 @@ package store
import ( import (
"context" "context"
"database/sql"
"errors"
"strings"
"google.golang.org/protobuf/encoding/protojson"
storepb "github.com/yourselfhosted/slash/proto/gen/store" storepb "github.com/yourselfhosted/slash/proto/gen/store"
) )
@ -17,98 +12,19 @@ type FindUserSetting struct {
} }
func (s *Store) UpsertUserSetting(ctx context.Context, upsert *storepb.UserSetting) (*storepb.UserSetting, error) { func (s *Store) UpsertUserSetting(ctx context.Context, upsert *storepb.UserSetting) (*storepb.UserSetting, error) {
stmt := ` userSetting, err := s.driver.UpsertUserSetting(ctx, upsert)
INSERT INTO user_setting (
user_id, key, value
)
VALUES (?, ?, ?)
ON CONFLICT(user_id, key) DO UPDATE
SET value = EXCLUDED.value
`
var valueString string
if upsert.Key == storepb.UserSettingKey_USER_SETTING_ACCESS_TOKENS {
valueBytes, err := protojson.Marshal(upsert.GetAccessTokens())
if err != nil { if err != nil {
return nil, err return nil, err
} }
valueString = string(valueBytes) s.userSettingCache.Store(getUserSettingCacheKey(userSetting.UserId, userSetting.Key.String()), userSetting)
} else if upsert.Key == storepb.UserSettingKey_USER_SETTING_LOCALE { return userSetting, nil
valueString = upsert.GetLocale().String()
} else if upsert.Key == storepb.UserSettingKey_USER_SETTING_COLOR_THEME {
valueString = upsert.GetColorTheme().String()
} else {
return nil, errors.New("invalid user setting key")
}
if _, err := s.db.ExecContext(ctx, stmt, upsert.UserId, upsert.Key.String(), valueString); err != nil {
return nil, err
}
userSettingMessage := upsert
s.userSettingCache.Store(getUserSettingCacheKey(userSettingMessage.UserId, userSettingMessage.Key.String()), userSettingMessage)
return userSettingMessage, nil
} }
func (s *Store) ListUserSettings(ctx context.Context, find *FindUserSetting) ([]*storepb.UserSetting, error) { func (s *Store) ListUserSettings(ctx context.Context, find *FindUserSetting) ([]*storepb.UserSetting, error) {
where, args := []string{"1 = 1"}, []any{} userSettingList, err := s.driver.ListUserSettings(ctx, find)
if v := find.Key; v != storepb.UserSettingKey_USER_SETTING_KEY_UNSPECIFIED {
where, args = append(where, "key = ?"), append(args, v.String())
}
if v := find.UserID; v != nil {
where, args = append(where, "user_id = ?"), append(args, *find.UserID)
}
query := `
SELECT
user_id,
key,
value
FROM user_setting
WHERE ` + strings.Join(where, " AND ")
rows, err := s.db.QueryContext(ctx, query, args...)
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer rows.Close()
userSettingList := make([]*storepb.UserSetting, 0)
for rows.Next() {
userSetting := &storepb.UserSetting{}
var keyString, valueString string
if err := rows.Scan(
&userSetting.UserId,
&keyString,
&valueString,
); err != nil {
return nil, err
}
userSetting.Key = storepb.UserSettingKey(storepb.UserSettingKey_value[keyString])
if userSetting.Key == storepb.UserSettingKey_USER_SETTING_ACCESS_TOKENS {
accessTokensUserSetting := &storepb.AccessTokensUserSetting{}
if err := protojson.Unmarshal([]byte(valueString), accessTokensUserSetting); err != nil {
return nil, err
}
userSetting.Value = &storepb.UserSetting_AccessTokens{
AccessTokens: accessTokensUserSetting,
}
} else if userSetting.Key == storepb.UserSettingKey_USER_SETTING_LOCALE {
userSetting.Value = &storepb.UserSetting_Locale{
Locale: storepb.LocaleUserSetting(storepb.LocaleUserSetting_value[valueString]),
}
} else if userSetting.Key == storepb.UserSettingKey_USER_SETTING_COLOR_THEME {
userSetting.Value = &storepb.UserSetting_ColorTheme{
ColorTheme: storepb.ColorThemeUserSetting(storepb.ColorThemeUserSetting_value[valueString]),
}
} else {
return nil, errors.New("invalid user setting key")
}
userSettingList = append(userSettingList, userSetting)
}
if err := rows.Err(); err != nil {
return nil, err
}
for _, userSetting := range userSettingList { for _, userSetting := range userSettingList {
s.userSettingCache.Store(getUserSettingCacheKey(userSetting.UserId, userSetting.Key.String()), userSetting) s.userSettingCache.Store(getUserSettingCacheKey(userSetting.UserId, userSetting.Key.String()), userSetting)
@ -137,25 +53,6 @@ func (s *Store) GetUserSetting(ctx context.Context, find *FindUserSetting) (*sto
return userSetting, nil return userSetting, nil
} }
func vacuumUserSetting(ctx context.Context, tx *sql.Tx) error {
stmt := `
DELETE FROM
user_setting
WHERE
user_id NOT IN (
SELECT
id
FROM
user
)`
_, err := tx.ExecContext(ctx, stmt)
if err != nil {
return err
}
return nil
}
// GetUserAccessTokens returns the access tokens of the user. // GetUserAccessTokens returns the access tokens of the user.
func (s *Store) GetUserAccessTokens(ctx context.Context, userID int32) ([]*storepb.AccessTokensUserSetting_AccessToken, error) { func (s *Store) GetUserAccessTokens(ctx context.Context, userID int32) ([]*storepb.AccessTokensUserSetting_AccessToken, error) {
userSetting, err := s.GetUserSetting(ctx, &FindUserSetting{ userSetting, err := s.GetUserSetting(ctx, &FindUserSetting{

View File

@ -2,11 +2,6 @@ package store
import ( import (
"context" "context"
"errors"
"strconv"
"strings"
"google.golang.org/protobuf/encoding/protojson"
storepb "github.com/yourselfhosted/slash/proto/gen/store" storepb "github.com/yourselfhosted/slash/proto/gen/store"
) )
@ -16,110 +11,22 @@ type FindWorkspaceSetting struct {
} }
func (s *Store) UpsertWorkspaceSetting(ctx context.Context, upsert *storepb.WorkspaceSetting) (*storepb.WorkspaceSetting, error) { func (s *Store) UpsertWorkspaceSetting(ctx context.Context, upsert *storepb.WorkspaceSetting) (*storepb.WorkspaceSetting, error) {
stmt := ` workspaceSetting, err := s.driver.UpsertWorkspaceSetting(ctx, upsert)
INSERT INTO workspace_setting (
key,
value
)
VALUES (?, ?)
ON CONFLICT(key) DO UPDATE
SET value = EXCLUDED.value
`
var valueString string
if upsert.Key == storepb.WorkspaceSettingKey_WORKSPACE_SETTING_LICENSE_KEY {
valueString = upsert.GetLicenseKey()
} else if upsert.Key == storepb.WorkspaceSettingKey_WORKSPACE_SETTING_SECRET_SESSION {
valueString = upsert.GetSecretSession()
} else if upsert.Key == storepb.WorkspaceSettingKey_WORKSAPCE_SETTING_ENABLE_SIGNUP {
valueString = strconv.FormatBool(upsert.GetEnableSignup())
} else if upsert.Key == storepb.WorkspaceSettingKey_WORKSPACE_SETTING_CUSTOM_STYLE {
valueString = upsert.GetCustomStyle()
} else if upsert.Key == storepb.WorkspaceSettingKey_WORKSPACE_SETTING_CUSTOM_SCRIPT {
valueString = upsert.GetCustomScript()
} else if upsert.Key == storepb.WorkspaceSettingKey_WORKSPACE_SETTING_AUTO_BACKUP {
valueBytes, err := protojson.Marshal(upsert.GetAutoBackup())
if err != nil { if err != nil {
return nil, err return nil, err
} }
valueString = string(valueBytes)
} else {
return nil, errors.New("invalid workspace setting key")
}
if _, err := s.db.ExecContext(ctx, stmt, upsert.Key.String(), valueString); err != nil {
return nil, err
}
workspaceSetting := upsert
s.workspaceSettingCache.Store(workspaceSetting.Key, workspaceSetting) s.workspaceSettingCache.Store(workspaceSetting.Key, workspaceSetting)
return workspaceSetting, nil return workspaceSetting, nil
} }
func (s *Store) ListWorkspaceSettings(ctx context.Context, find *FindWorkspaceSetting) ([]*storepb.WorkspaceSetting, error) { func (s *Store) ListWorkspaceSettings(ctx context.Context, find *FindWorkspaceSetting) ([]*storepb.WorkspaceSetting, error) {
where, args := []string{"1 = 1"}, []any{} list, err := s.driver.ListWorkspaceSettings(ctx, find)
if find.Key != storepb.WorkspaceSettingKey_WORKSPACE_SETTING_KEY_UNSPECIFIED {
where, args = append(where, "key = ?"), append(args, find.Key.String())
}
query := `
SELECT
key,
value
FROM workspace_setting
WHERE ` + strings.Join(where, " AND ")
rows, err := s.db.QueryContext(ctx, query, args...)
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer rows.Close()
list := []*storepb.WorkspaceSetting{}
for rows.Next() {
workspaceSetting := &storepb.WorkspaceSetting{}
var keyString, valueString string
if err := rows.Scan(
&keyString,
&valueString,
); err != nil {
return nil, err
}
workspaceSetting.Key = storepb.WorkspaceSettingKey(storepb.WorkspaceSettingKey_value[keyString])
if workspaceSetting.Key == storepb.WorkspaceSettingKey_WORKSPACE_SETTING_LICENSE_KEY {
workspaceSetting.Value = &storepb.WorkspaceSetting_LicenseKey{LicenseKey: valueString}
} else if workspaceSetting.Key == storepb.WorkspaceSettingKey_WORKSPACE_SETTING_SECRET_SESSION {
workspaceSetting.Value = &storepb.WorkspaceSetting_SecretSession{SecretSession: valueString}
} else if workspaceSetting.Key == storepb.WorkspaceSettingKey_WORKSAPCE_SETTING_ENABLE_SIGNUP {
enableSignup, err := strconv.ParseBool(valueString)
if err != nil {
return nil, err
}
workspaceSetting.Value = &storepb.WorkspaceSetting_EnableSignup{EnableSignup: enableSignup}
} else if workspaceSetting.Key == storepb.WorkspaceSettingKey_WORKSPACE_SETTING_CUSTOM_STYLE {
workspaceSetting.Value = &storepb.WorkspaceSetting_CustomStyle{CustomStyle: valueString}
} else if workspaceSetting.Key == storepb.WorkspaceSettingKey_WORKSPACE_SETTING_CUSTOM_SCRIPT {
workspaceSetting.Value = &storepb.WorkspaceSetting_CustomScript{CustomScript: valueString}
} else if workspaceSetting.Key == storepb.WorkspaceSettingKey_WORKSPACE_SETTING_AUTO_BACKUP {
autoBackupSetting := &storepb.AutoBackupWorkspaceSetting{}
if err := protojson.Unmarshal([]byte(valueString), autoBackupSetting); err != nil {
return nil, err
}
workspaceSetting.Value = &storepb.WorkspaceSetting_AutoBackup{AutoBackup: autoBackupSetting}
} else {
continue
}
list = append(list, workspaceSetting)
}
if err := rows.Err(); err != nil {
return nil, err
}
for _, workspaceSetting := range list { for _, workspaceSetting := range list {
s.workspaceSettingCache.Store(workspaceSetting.Key, workspaceSetting) s.workspaceSettingCache.Store(workspaceSetting.Key, workspaceSetting)
} }
return list, nil return list, nil
} }

View File

@ -31,12 +31,17 @@ type TestingServer struct {
func NewTestingServer(ctx context.Context, t *testing.T) (*TestingServer, error) { func NewTestingServer(ctx context.Context, t *testing.T) (*TestingServer, error) {
profile := test.GetTestingProfile(t) profile := test.GetTestingProfile(t)
db := db.NewDB(profile) dbDriver, err := db.NewDBDriver(profile)
if err := db.Open(ctx); err != nil { if err != nil {
return nil, errors.Wrap(err, "failed to open db") fmt.Printf("failed to create db driver, error: %+v\n", err)
return nil, err
}
if err := dbDriver.Migrate(ctx); err != nil {
fmt.Printf("failed to migrate db, error: %+v\n", err)
return nil, err
} }
store := store.New(db.DBInstance, profile) store := store.New(dbDriver, profile)
server, err := server.NewServer(ctx, profile, store) server, err := server.NewServer(ctx, profile, store)
if err != nil { if err != nil {
return nil, errors.Wrap(err, "failed to create server") return nil, errors.Wrap(err, "failed to create server")

View File

@ -15,11 +15,14 @@ import (
func NewTestingStore(ctx context.Context, t *testing.T) *store.Store { func NewTestingStore(ctx context.Context, t *testing.T) *store.Store {
profile := test.GetTestingProfile(t) profile := test.GetTestingProfile(t)
db := db.NewDB(profile) dbDriver, err := db.NewDBDriver(profile)
if err := db.Open(ctx); err != nil { if err != nil {
fmt.Printf("failed to open db, error: %+v\n", err) fmt.Printf("failed to create db driver, error: %+v\n", err)
}
if err := dbDriver.Migrate(ctx); err != nil {
fmt.Printf("failed to migrate db, error: %+v\n", err)
} }
store := store.New(db.DBInstance, profile) store := store.New(dbDriver, profile)
return store return store
} }