mirror of
https://github.com/aykhans/slash-e.git
synced 2025-04-20 14:01:24 +00:00
feat: abstract database drivers
This commit is contained in:
parent
6350b19478
commit
9173c8f19a
@ -30,20 +30,27 @@ var (
|
||||
mode string
|
||||
port int
|
||||
data string
|
||||
driver string
|
||||
dsn string
|
||||
|
||||
rootCmd = &cobra.Command{
|
||||
Use: "slash",
|
||||
Short: `An open source, self-hosted bookmarks and link sharing platform.`,
|
||||
Run: func(_cmd *cobra.Command, _args []string) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
db := db.NewDB(serverProfile)
|
||||
if err := db.Open(ctx); err != nil {
|
||||
dbDriver, err := db.NewDBDriver(serverProfile)
|
||||
if err != nil {
|
||||
cancel()
|
||||
log.Error("failed to open database", zap.Error(err))
|
||||
log.Error("failed to create db driver", zap.Error(err))
|
||||
return
|
||||
}
|
||||
if err := dbDriver.Migrate(ctx); err != nil {
|
||||
cancel()
|
||||
log.Error("failed to migrate db", zap.Error(err))
|
||||
return
|
||||
}
|
||||
|
||||
storeInstance := store.New(db.DBInstance, serverProfile)
|
||||
storeInstance := store.New(dbDriver, serverProfile)
|
||||
s, err := server.NewServer(ctx, serverProfile, storeInstance)
|
||||
if err != nil {
|
||||
cancel()
|
||||
@ -92,6 +99,8 @@ func init() {
|
||||
rootCmd.PersistentFlags().StringVarP(&mode, "mode", "m", "demo", `mode of server, can be "prod" or "dev" or "demo"`)
|
||||
rootCmd.PersistentFlags().IntVarP(&port, "port", "p", 8082, "port of server")
|
||||
rootCmd.PersistentFlags().StringVarP(&data, "data", "d", "", "data directory")
|
||||
rootCmd.PersistentFlags().StringVarP(&driver, "driver", "", "", "database driver")
|
||||
rootCmd.PersistentFlags().StringVarP(&dsn, "dsn", "", "", "database source name(aka. DSN)")
|
||||
|
||||
err := viper.BindPFlag("mode", rootCmd.PersistentFlags().Lookup("mode"))
|
||||
if err != nil {
|
||||
@ -105,9 +114,18 @@ func init() {
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
err = viper.BindPFlag("driver", rootCmd.PersistentFlags().Lookup("driver"))
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
err = viper.BindPFlag("dsn", rootCmd.PersistentFlags().Lookup("dsn"))
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
viper.SetDefault("mode", "demo")
|
||||
viper.SetDefault("port", 8082)
|
||||
viper.SetDefault("driver", "sqlite")
|
||||
viper.SetEnvPrefix("slash")
|
||||
}
|
||||
|
||||
|
@ -23,6 +23,9 @@ type Profile struct {
|
||||
Data string `json:"-"`
|
||||
// DSN points to where slash stores its own data
|
||||
DSN string `json:"-"`
|
||||
// Driver is the database driver
|
||||
// sqlite, mysql
|
||||
Driver string `json:"-"`
|
||||
// Version is the current version of server
|
||||
Version string `json:"version"`
|
||||
}
|
||||
|
@ -2,7 +2,6 @@ package store
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type ActivityType string
|
||||
@ -63,82 +62,11 @@ type FindActivity struct {
|
||||
}
|
||||
|
||||
func (s *Store) CreateActivity(ctx context.Context, create *Activity) (*Activity, error) {
|
||||
stmt := `
|
||||
INSERT INTO activity (
|
||||
creator_id,
|
||||
type,
|
||||
level,
|
||||
payload
|
||||
)
|
||||
VALUES (?, ?, ?, ?)
|
||||
RETURNING id, created_ts
|
||||
`
|
||||
if err := s.db.QueryRowContext(ctx, stmt,
|
||||
create.CreatorID,
|
||||
create.Type.String(),
|
||||
create.Level.String(),
|
||||
create.Payload,
|
||||
).Scan(
|
||||
&create.ID,
|
||||
&create.CreatedTs,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
activity := create
|
||||
return activity, nil
|
||||
return s.driver.CreateActivity(ctx, create)
|
||||
}
|
||||
|
||||
func (s *Store) ListActivities(ctx context.Context, find *FindActivity) ([]*Activity, error) {
|
||||
where, args := []string{"1 = 1"}, []any{}
|
||||
if find.Type != "" {
|
||||
where, args = append(where, "type = ?"), append(args, find.Type.String())
|
||||
}
|
||||
if find.Level != "" {
|
||||
where, args = append(where, "level = ?"), append(args, find.Level.String())
|
||||
}
|
||||
if find.Where != nil {
|
||||
where = append(where, find.Where...)
|
||||
}
|
||||
|
||||
query := `
|
||||
SELECT
|
||||
id,
|
||||
creator_id,
|
||||
created_ts,
|
||||
type,
|
||||
level,
|
||||
payload
|
||||
FROM activity
|
||||
WHERE ` + strings.Join(where, " AND ")
|
||||
rows, err := s.db.QueryContext(ctx, query, args...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
list := []*Activity{}
|
||||
for rows.Next() {
|
||||
activity := &Activity{}
|
||||
if err := rows.Scan(
|
||||
&activity.ID,
|
||||
&activity.CreatorID,
|
||||
&activity.CreatedTs,
|
||||
&activity.Type,
|
||||
&activity.Level,
|
||||
&activity.Payload,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
list = append(list, activity)
|
||||
}
|
||||
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return list, nil
|
||||
return s.driver.ListActivities(ctx, find)
|
||||
}
|
||||
|
||||
func (s *Store) GetActivity(ctx context.Context, find *FindActivity) (*Activity, error) {
|
||||
|
@ -2,12 +2,7 @@ package store
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
|
||||
"github.com/yourselfhosted/slash/internal/util"
|
||||
storepb "github.com/yourselfhosted/slash/proto/gen/store"
|
||||
)
|
||||
|
||||
@ -35,165 +30,15 @@ type DeleteCollection struct {
|
||||
}
|
||||
|
||||
func (s *Store) CreateCollection(ctx context.Context, create *storepb.Collection) (*storepb.Collection, error) {
|
||||
set := []string{"creator_id", "name", "title", "description", "shortcut_ids", "visibility"}
|
||||
args := []any{create.CreatorId, create.Name, create.Title, create.Description, strings.Trim(strings.Join(strings.Fields(fmt.Sprint(create.ShortcutIds)), ","), "[]"), create.Visibility.String()}
|
||||
placeholder := []string{"?", "?", "?", "?", "?", "?"}
|
||||
|
||||
stmt := `
|
||||
INSERT INTO collection (
|
||||
` + strings.Join(set, ", ") + `
|
||||
)
|
||||
VALUES (` + strings.Join(placeholder, ",") + `)
|
||||
RETURNING id, created_ts, updated_ts
|
||||
`
|
||||
if err := s.db.QueryRowContext(ctx, stmt, args...).Scan(
|
||||
&create.Id,
|
||||
&create.CreatedTs,
|
||||
&create.UpdatedTs,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
collection := create
|
||||
return collection, nil
|
||||
return s.driver.CreateCollection(ctx, create)
|
||||
}
|
||||
|
||||
func (s *Store) UpdateCollection(ctx context.Context, update *UpdateCollection) (*storepb.Collection, error) {
|
||||
set, args := []string{}, []any{}
|
||||
if update.Name != nil {
|
||||
set, args = append(set, "name = ?"), append(args, *update.Name)
|
||||
}
|
||||
if update.Title != nil {
|
||||
set, args = append(set, "title = ?"), append(args, *update.Title)
|
||||
}
|
||||
if update.Description != nil {
|
||||
set, args = append(set, "description = ?"), append(args, *update.Description)
|
||||
}
|
||||
if update.ShortcutIDs != nil {
|
||||
set, args = append(set, "shortcut_ids = ?"), append(args, strings.Trim(strings.Join(strings.Fields(fmt.Sprint(update.ShortcutIDs)), ","), "[]"))
|
||||
}
|
||||
if update.Visibility != nil {
|
||||
set, args = append(set, "visibility = ?"), append(args, update.Visibility.String())
|
||||
}
|
||||
if len(set) == 0 {
|
||||
return nil, errors.New("no update specified")
|
||||
}
|
||||
args = append(args, update.ID)
|
||||
|
||||
stmt := `
|
||||
UPDATE collection
|
||||
SET
|
||||
` + strings.Join(set, ", ") + `
|
||||
WHERE
|
||||
id = ?
|
||||
RETURNING id, creator_id, created_ts, updated_ts, name, title, description, shortcut_ids, visibility
|
||||
`
|
||||
collection := &storepb.Collection{}
|
||||
var shortcutIDs, visibility string
|
||||
if err := s.db.QueryRowContext(ctx, stmt, args...).Scan(
|
||||
&collection.Id,
|
||||
&collection.CreatorId,
|
||||
&collection.CreatedTs,
|
||||
&collection.UpdatedTs,
|
||||
&collection.Name,
|
||||
&collection.Title,
|
||||
&collection.Description,
|
||||
&shortcutIDs,
|
||||
&visibility,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
collection.ShortcutIds = []int32{}
|
||||
if shortcutIDs != "" {
|
||||
for _, idStr := range strings.Split(shortcutIDs, ",") {
|
||||
shortcutID, err := util.ConvertStringToInt32(idStr)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to convert shortcut id")
|
||||
}
|
||||
collection.ShortcutIds = append(collection.ShortcutIds, shortcutID)
|
||||
}
|
||||
}
|
||||
collection.Visibility = convertVisibilityStringToStorepb(visibility)
|
||||
return collection, nil
|
||||
return s.driver.UpdateCollection(ctx, update)
|
||||
}
|
||||
|
||||
func (s *Store) ListCollections(ctx context.Context, find *FindCollection) ([]*storepb.Collection, error) {
|
||||
where, args := []string{"1 = 1"}, []any{}
|
||||
if v := find.ID; v != nil {
|
||||
where, args = append(where, "id = ?"), append(args, *v)
|
||||
}
|
||||
if v := find.CreatorID; v != nil {
|
||||
where, args = append(where, "creator_id = ?"), append(args, *v)
|
||||
}
|
||||
if v := find.Name; v != nil {
|
||||
where, args = append(where, "name = ?"), append(args, *v)
|
||||
}
|
||||
if v := find.VisibilityList; len(v) != 0 {
|
||||
list := []string{}
|
||||
for _, visibility := range v {
|
||||
list = append(list, fmt.Sprintf("$%d", len(args)+1))
|
||||
args = append(args, visibility)
|
||||
}
|
||||
where = append(where, fmt.Sprintf("visibility in (%s)", strings.Join(list, ",")))
|
||||
}
|
||||
|
||||
rows, err := s.db.QueryContext(ctx, `
|
||||
SELECT
|
||||
id,
|
||||
creator_id,
|
||||
created_ts,
|
||||
updated_ts,
|
||||
name,
|
||||
title,
|
||||
description,
|
||||
shortcut_ids,
|
||||
visibility
|
||||
FROM collection
|
||||
WHERE `+strings.Join(where, " AND ")+`
|
||||
ORDER BY created_ts DESC`,
|
||||
args...,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
list := make([]*storepb.Collection, 0)
|
||||
for rows.Next() {
|
||||
collection := &storepb.Collection{}
|
||||
var shortcutIDs, visibility string
|
||||
if err := rows.Scan(
|
||||
&collection.Id,
|
||||
&collection.CreatorId,
|
||||
&collection.CreatedTs,
|
||||
&collection.UpdatedTs,
|
||||
&collection.Name,
|
||||
&collection.Title,
|
||||
&collection.Description,
|
||||
&shortcutIDs,
|
||||
&visibility,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
collection.ShortcutIds = []int32{}
|
||||
if shortcutIDs != "" {
|
||||
for _, idStr := range strings.Split(shortcutIDs, ",") {
|
||||
shortcutID, err := util.ConvertStringToInt32(idStr)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to convert shortcut id")
|
||||
}
|
||||
collection.ShortcutIds = append(collection.ShortcutIds, shortcutID)
|
||||
}
|
||||
}
|
||||
collection.Visibility = storepb.Visibility(storepb.Visibility_value[visibility])
|
||||
list = append(list, collection)
|
||||
}
|
||||
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return list, nil
|
||||
return s.driver.ListCollections(ctx, find)
|
||||
}
|
||||
|
||||
func (s *Store) GetCollection(ctx context.Context, find *FindCollection) (*storepb.Collection, error) {
|
||||
@ -211,9 +56,5 @@ func (s *Store) GetCollection(ctx context.Context, find *FindCollection) (*store
|
||||
}
|
||||
|
||||
func (s *Store) DeleteCollection(ctx context.Context, delete *DeleteCollection) error {
|
||||
if _, err := s.db.ExecContext(ctx, `DELETE FROM collection WHERE id = ?`, delete.ID); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
return s.driver.DeleteCollection(ctx, delete)
|
||||
}
|
||||
|
@ -1,9 +1,5 @@
|
||||
package store
|
||||
|
||||
import (
|
||||
storepb "github.com/yourselfhosted/slash/proto/gen/store"
|
||||
)
|
||||
|
||||
// RowStatus is the status for a row.
|
||||
type RowStatus string
|
||||
|
||||
@ -24,16 +20,6 @@ func (e RowStatus) String() string {
|
||||
return ""
|
||||
}
|
||||
|
||||
func convertRowStatusStringToStorepb(status string) storepb.RowStatus {
|
||||
switch status {
|
||||
case "NORMAL":
|
||||
return storepb.RowStatus_NORMAL
|
||||
case "ARCHIVED":
|
||||
return storepb.RowStatus_ARCHIVED
|
||||
}
|
||||
return storepb.RowStatus_ROW_STATUS_UNSPECIFIED
|
||||
}
|
||||
|
||||
// Visibility is the type of a visibility.
|
||||
type Visibility string
|
||||
|
||||
|
266
store/db/db.go
266
store/db/db.go
@ -1,266 +1,26 @@
|
||||
package db
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"embed"
|
||||
"fmt"
|
||||
"io/fs"
|
||||
"log/slog"
|
||||
"os"
|
||||
"regexp"
|
||||
"sort"
|
||||
"time"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
|
||||
"github.com/yourselfhosted/slash/server/profile"
|
||||
"github.com/yourselfhosted/slash/server/version"
|
||||
"github.com/yourselfhosted/slash/store"
|
||||
"github.com/yourselfhosted/slash/store/db/sqlite"
|
||||
)
|
||||
|
||||
//go:embed migration
|
||||
var migrationFS embed.FS
|
||||
// NewDBDriver creates new db driver based on profile.
|
||||
func NewDBDriver(profile *profile.Profile) (store.Driver, error) {
|
||||
var driver store.Driver
|
||||
var err error
|
||||
|
||||
//go:embed seed
|
||||
var seedFS embed.FS
|
||||
|
||||
type DB struct {
|
||||
// sqlite db connection instance
|
||||
DBInstance *sql.DB
|
||||
profile *profile.Profile
|
||||
}
|
||||
|
||||
// NewDB returns a new instance of DB associated with the given datasource name.
|
||||
func NewDB(profile *profile.Profile) *DB {
|
||||
db := &DB{
|
||||
profile: profile,
|
||||
switch profile.Driver {
|
||||
case "sqlite":
|
||||
driver, err = sqlite.NewDB(profile)
|
||||
default:
|
||||
return nil, errors.New("unknown db driver")
|
||||
}
|
||||
return db
|
||||
}
|
||||
|
||||
func (db *DB) Open(ctx context.Context) (err error) {
|
||||
// Ensure a DSN is set before attempting to open the database.
|
||||
if db.profile.DSN == "" {
|
||||
return errors.New("dsn required")
|
||||
}
|
||||
|
||||
// Connect to the database with some sane settings:
|
||||
// - No shared-cache: it's obsolete; WAL journal mode is a better solution.
|
||||
// - No foreign key constraints: it's currently disabled by default, but it's a
|
||||
// good practice to be explicit and prevent future surprises on SQLite upgrades.
|
||||
// - Journal mode set to WAL: it's the recommended journal mode for most applications
|
||||
// as it prevents locking issues.
|
||||
//
|
||||
// Notes:
|
||||
// - When using the `modernc.org/sqlite` driver, each pragma must be prefixed with `_pragma=`.
|
||||
//
|
||||
// References:
|
||||
// - https://pkg.go.dev/modernc.org/sqlite#Driver.Open
|
||||
// - https://www.sqlite.org/sharedcache.html
|
||||
// - https://www.sqlite.org/pragma.html
|
||||
sqliteDB, err := sql.Open("sqlite", db.profile.DSN+"?_pragma=foreign_keys(0)&_pragma=busy_timeout(10000)&_pragma=journal_mode(WAL)")
|
||||
if err != nil {
|
||||
return errors.Wrapf(err, "failed to open db with dsn: %s", db.profile.DSN)
|
||||
return nil, errors.Wrap(err, "failed to create db driver")
|
||||
}
|
||||
db.DBInstance = sqliteDB
|
||||
currentVersion := version.GetCurrentVersion(db.profile.Mode)
|
||||
|
||||
if db.profile.Mode == "prod" {
|
||||
_, err := os.Stat(db.profile.DSN)
|
||||
if err != nil {
|
||||
if !errors.Is(err, os.ErrNotExist) {
|
||||
return errors.Wrap(err, "failed to get db file stat")
|
||||
}
|
||||
|
||||
// If db file not exists, we should create a new one with latest schema.
|
||||
err := db.applyLatestSchema(ctx)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "failed to apply latest schema")
|
||||
}
|
||||
_, err = db.UpsertMigrationHistory(ctx, &MigrationHistoryUpsert{
|
||||
Version: currentVersion,
|
||||
})
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "failed to upsert migration history")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// If db file exists, we should check if we need to migrate the database.
|
||||
migrationHistoryList, err := db.FindMigrationHistoryList(ctx, &MigrationHistoryFind{})
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "failed to find migration history")
|
||||
}
|
||||
if len(migrationHistoryList) == 0 {
|
||||
_, err := db.UpsertMigrationHistory(ctx, &MigrationHistoryUpsert{
|
||||
Version: currentVersion,
|
||||
})
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "failed to upsert migration history")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
migrationHistoryVersionList := []string{}
|
||||
for _, migrationHistory := range migrationHistoryList {
|
||||
migrationHistoryVersionList = append(migrationHistoryVersionList, migrationHistory.Version)
|
||||
}
|
||||
sort.Sort(version.SortVersion(migrationHistoryVersionList))
|
||||
latestMigrationHistoryVersion := migrationHistoryVersionList[len(migrationHistoryVersionList)-1]
|
||||
|
||||
if version.IsVersionGreaterThan(version.GetSchemaVersion(currentVersion), latestMigrationHistoryVersion) {
|
||||
minorVersionList := getMinorVersionList()
|
||||
|
||||
// backup the raw database file before migration
|
||||
rawBytes, err := os.ReadFile(db.profile.DSN)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "failed to read raw database file")
|
||||
}
|
||||
backupDBFilePath := fmt.Sprintf("%s/slash_%s_%d_backup.db", db.profile.Data, db.profile.Version, time.Now().Unix())
|
||||
if err := os.WriteFile(backupDBFilePath, rawBytes, 0644); err != nil {
|
||||
return errors.Wrap(err, "failed to write raw database file")
|
||||
}
|
||||
slog.Log(ctx, slog.LevelInfo, "succeed to copy a backup database file")
|
||||
|
||||
slog.Log(ctx, slog.LevelInfo, "start migrate")
|
||||
for _, minorVersion := range minorVersionList {
|
||||
normalizedVersion := minorVersion + ".0"
|
||||
if version.IsVersionGreaterThan(normalizedVersion, latestMigrationHistoryVersion) && version.IsVersionGreaterOrEqualThan(currentVersion, normalizedVersion) {
|
||||
slog.Log(ctx, slog.LevelInfo, fmt.Sprintf("applying migration for %s", normalizedVersion))
|
||||
if err := db.applyMigrationForMinorVersion(ctx, minorVersion); err != nil {
|
||||
return errors.Wrap(err, "failed to apply minor version migration")
|
||||
}
|
||||
}
|
||||
}
|
||||
slog.Log(ctx, slog.LevelInfo, "end migrate")
|
||||
|
||||
// remove the created backup db file after migrate succeed
|
||||
if err := os.Remove(backupDBFilePath); err != nil {
|
||||
slog.Log(ctx, slog.LevelError, fmt.Sprintf("Failed to remove temp database file, err %v", err))
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// In non-prod mode, we should always migrate the database.
|
||||
if _, err := os.Stat(db.profile.DSN); errors.Is(err, os.ErrNotExist) {
|
||||
if err := db.applyLatestSchema(ctx); err != nil {
|
||||
return errors.Wrap(err, "failed to apply latest schema")
|
||||
}
|
||||
// In demo mode, we should seed the database.
|
||||
if db.profile.Mode == "demo" {
|
||||
if err := db.seed(ctx); err != nil {
|
||||
return errors.Wrap(err, "failed to seed")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
const (
|
||||
latestSchemaFileName = "LATEST__SCHEMA.sql"
|
||||
)
|
||||
|
||||
func (db *DB) applyLatestSchema(ctx context.Context) error {
|
||||
schemaMode := "dev"
|
||||
if db.profile.Mode == "prod" {
|
||||
schemaMode = "prod"
|
||||
}
|
||||
latestSchemaPath := fmt.Sprintf("migration/%s/%s", schemaMode, latestSchemaFileName)
|
||||
buf, err := migrationFS.ReadFile(latestSchemaPath)
|
||||
if err != nil {
|
||||
return errors.Wrapf(err, "failed to read latest schema %q", latestSchemaPath)
|
||||
}
|
||||
stmt := string(buf)
|
||||
if err := db.execute(ctx, stmt); err != nil {
|
||||
return errors.Wrapf(err, "migrate error: statement %s", stmt)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (db *DB) applyMigrationForMinorVersion(ctx context.Context, minorVersion string) error {
|
||||
filenames, err := fs.Glob(migrationFS, fmt.Sprintf("migration/prod/%s/*.sql", minorVersion))
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "failed to read migrate files")
|
||||
}
|
||||
|
||||
sort.Strings(filenames)
|
||||
migrationStmt := ""
|
||||
|
||||
// Loop over all migration files and execute them in order.
|
||||
for _, filename := range filenames {
|
||||
buf, err := migrationFS.ReadFile(filename)
|
||||
if err != nil {
|
||||
return errors.Wrapf(err, "failed to read minor version migration file, filename %s", filename)
|
||||
}
|
||||
stmt := string(buf)
|
||||
migrationStmt += stmt
|
||||
if err := db.execute(ctx, stmt); err != nil {
|
||||
return errors.Wrapf(err, "migrate error: statement %s", stmt)
|
||||
}
|
||||
}
|
||||
|
||||
// Upsert the newest version to migration_history.
|
||||
version := minorVersion + ".0"
|
||||
if _, err = db.UpsertMigrationHistory(ctx, &MigrationHistoryUpsert{
|
||||
Version: version,
|
||||
}); err != nil {
|
||||
return errors.Wrapf(err, "failed to upsert migration history with version %s", version)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (db *DB) seed(ctx context.Context) error {
|
||||
filenames, err := fs.Glob(seedFS, "seed/*.sql")
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "failed to read seed files")
|
||||
}
|
||||
|
||||
sort.Strings(filenames)
|
||||
// Loop over all seed files and execute them in order.
|
||||
for _, filename := range filenames {
|
||||
buf, err := seedFS.ReadFile(filename)
|
||||
if err != nil {
|
||||
return errors.Wrapf(err, "failed to read seed file, filename %s", filename)
|
||||
}
|
||||
stmt := string(buf)
|
||||
if err := db.execute(ctx, stmt); err != nil {
|
||||
return errors.Wrapf(err, "seed error: statement %s", stmt)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// execute runs a single SQL statement within a transaction.
|
||||
func (db *DB) execute(ctx context.Context, stmt string) error {
|
||||
if _, err := db.DBInstance.ExecContext(ctx, stmt); err != nil {
|
||||
return errors.Wrap(err, "failed to execute statement")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// minorDirRegexp is a regular expression for minor version directory.
|
||||
var minorDirRegexp = regexp.MustCompile(`^migration/prod/[0-9]+\.[0-9]+$`)
|
||||
|
||||
func getMinorVersionList() []string {
|
||||
minorVersionList := []string{}
|
||||
|
||||
if err := fs.WalkDir(migrationFS, "migration", func(path string, file fs.DirEntry, err error) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if file.IsDir() && minorDirRegexp.MatchString(path) {
|
||||
minorVersionList = append(minorVersionList, file.Name())
|
||||
}
|
||||
|
||||
return nil
|
||||
}); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
sort.Sort(version.SortVersion(minorVersionList))
|
||||
|
||||
return minorVersionList
|
||||
return driver, nil
|
||||
}
|
||||
|
@ -1,82 +0,0 @@
|
||||
package db
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type MigrationHistory struct {
|
||||
Version string
|
||||
CreatedTs int64
|
||||
}
|
||||
|
||||
type MigrationHistoryUpsert struct {
|
||||
Version string
|
||||
}
|
||||
|
||||
type MigrationHistoryFind struct {
|
||||
Version *string
|
||||
}
|
||||
|
||||
func (db *DB) FindMigrationHistoryList(ctx context.Context, find *MigrationHistoryFind) ([]*MigrationHistory, error) {
|
||||
where, args := []string{"1 = 1"}, []any{}
|
||||
|
||||
if v := find.Version; v != nil {
|
||||
where, args = append(where, "version = ?"), append(args, *v)
|
||||
}
|
||||
|
||||
stmt := `
|
||||
SELECT
|
||||
version,
|
||||
created_ts
|
||||
FROM
|
||||
migration_history
|
||||
WHERE ` + strings.Join(where, " AND ") + `
|
||||
ORDER BY created_ts DESC
|
||||
`
|
||||
rows, err := db.DBInstance.QueryContext(ctx, stmt, args...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
migrationHistoryList := make([]*MigrationHistory, 0)
|
||||
for rows.Next() {
|
||||
var migrationHistory MigrationHistory
|
||||
if err := rows.Scan(
|
||||
&migrationHistory.Version,
|
||||
&migrationHistory.CreatedTs,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
migrationHistoryList = append(migrationHistoryList, &migrationHistory)
|
||||
}
|
||||
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return migrationHistoryList, nil
|
||||
}
|
||||
|
||||
func (db *DB) UpsertMigrationHistory(ctx context.Context, upsert *MigrationHistoryUpsert) (*MigrationHistory, error) {
|
||||
query := `
|
||||
INSERT INTO migration_history (
|
||||
version
|
||||
)
|
||||
VALUES (?)
|
||||
ON CONFLICT(version) DO UPDATE
|
||||
SET
|
||||
version=EXCLUDED.version
|
||||
RETURNING version, created_ts
|
||||
`
|
||||
migrationHistory := &MigrationHistory{}
|
||||
if err := db.DBInstance.QueryRowContext(ctx, query, upsert.Version).Scan(
|
||||
&migrationHistory.Version,
|
||||
&migrationHistory.CreatedTs,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return migrationHistory, nil
|
||||
}
|
87
store/db/sqlite/activity.go
Normal file
87
store/db/sqlite/activity.go
Normal file
@ -0,0 +1,87 @@
|
||||
package sqlite
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
|
||||
"github.com/yourselfhosted/slash/store"
|
||||
)
|
||||
|
||||
func (d *DB) CreateActivity(ctx context.Context, create *store.Activity) (*store.Activity, error) {
|
||||
stmt := `
|
||||
INSERT INTO activity (
|
||||
creator_id,
|
||||
type,
|
||||
level,
|
||||
payload
|
||||
)
|
||||
VALUES (?, ?, ?, ?)
|
||||
RETURNING id, created_ts
|
||||
`
|
||||
if err := d.db.QueryRowContext(ctx, stmt,
|
||||
create.CreatorID,
|
||||
create.Type.String(),
|
||||
create.Level.String(),
|
||||
create.Payload,
|
||||
).Scan(
|
||||
&create.ID,
|
||||
&create.CreatedTs,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
activity := create
|
||||
return activity, nil
|
||||
}
|
||||
|
||||
func (d *DB) ListActivities(ctx context.Context, find *store.FindActivity) ([]*store.Activity, error) {
|
||||
where, args := []string{"1 = 1"}, []any{}
|
||||
if find.Type != "" {
|
||||
where, args = append(where, "type = ?"), append(args, find.Type.String())
|
||||
}
|
||||
if find.Level != "" {
|
||||
where, args = append(where, "level = ?"), append(args, find.Level.String())
|
||||
}
|
||||
if find.Where != nil {
|
||||
where = append(where, find.Where...)
|
||||
}
|
||||
|
||||
query := `
|
||||
SELECT
|
||||
id,
|
||||
creator_id,
|
||||
created_ts,
|
||||
type,
|
||||
level,
|
||||
payload
|
||||
FROM activity
|
||||
WHERE ` + strings.Join(where, " AND ")
|
||||
rows, err := d.db.QueryContext(ctx, query, args...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
list := []*store.Activity{}
|
||||
for rows.Next() {
|
||||
activity := &store.Activity{}
|
||||
if err := rows.Scan(
|
||||
&activity.ID,
|
||||
&activity.CreatorID,
|
||||
&activity.CreatedTs,
|
||||
&activity.Type,
|
||||
&activity.Level,
|
||||
&activity.Payload,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
list = append(list, activity)
|
||||
}
|
||||
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return list, nil
|
||||
}
|
183
store/db/sqlite/collection.go
Normal file
183
store/db/sqlite/collection.go
Normal file
@ -0,0 +1,183 @@
|
||||
package sqlite
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
|
||||
"github.com/yourselfhosted/slash/internal/util"
|
||||
storepb "github.com/yourselfhosted/slash/proto/gen/store"
|
||||
"github.com/yourselfhosted/slash/store"
|
||||
)
|
||||
|
||||
func (d *DB) CreateCollection(ctx context.Context, create *storepb.Collection) (*storepb.Collection, error) {
|
||||
set := []string{"creator_id", "name", "title", "description", "shortcut_ids", "visibility"}
|
||||
args := []any{create.CreatorId, create.Name, create.Title, create.Description, strings.Trim(strings.Join(strings.Fields(fmt.Sprint(create.ShortcutIds)), ","), "[]"), create.Visibility.String()}
|
||||
placeholder := []string{"?", "?", "?", "?", "?", "?"}
|
||||
|
||||
stmt := `
|
||||
INSERT INTO collection (
|
||||
` + strings.Join(set, ", ") + `
|
||||
)
|
||||
VALUES (` + strings.Join(placeholder, ",") + `)
|
||||
RETURNING id, created_ts, updated_ts
|
||||
`
|
||||
if err := d.db.QueryRowContext(ctx, stmt, args...).Scan(
|
||||
&create.Id,
|
||||
&create.CreatedTs,
|
||||
&create.UpdatedTs,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
collection := create
|
||||
return collection, nil
|
||||
}
|
||||
|
||||
func (d *DB) UpdateCollection(ctx context.Context, update *store.UpdateCollection) (*storepb.Collection, error) {
|
||||
set, args := []string{}, []any{}
|
||||
if update.Name != nil {
|
||||
set, args = append(set, "name = ?"), append(args, *update.Name)
|
||||
}
|
||||
if update.Title != nil {
|
||||
set, args = append(set, "title = ?"), append(args, *update.Title)
|
||||
}
|
||||
if update.Description != nil {
|
||||
set, args = append(set, "description = ?"), append(args, *update.Description)
|
||||
}
|
||||
if update.ShortcutIDs != nil {
|
||||
set, args = append(set, "shortcut_ids = ?"), append(args, strings.Trim(strings.Join(strings.Fields(fmt.Sprint(update.ShortcutIDs)), ","), "[]"))
|
||||
}
|
||||
if update.Visibility != nil {
|
||||
set, args = append(set, "visibility = ?"), append(args, update.Visibility.String())
|
||||
}
|
||||
if len(set) == 0 {
|
||||
return nil, errors.New("no update specified")
|
||||
}
|
||||
args = append(args, update.ID)
|
||||
|
||||
stmt := `
|
||||
UPDATE collection
|
||||
SET
|
||||
` + strings.Join(set, ", ") + `
|
||||
WHERE
|
||||
id = ?
|
||||
RETURNING id, creator_id, created_ts, updated_ts, name, title, description, shortcut_ids, visibility
|
||||
`
|
||||
collection := &storepb.Collection{}
|
||||
var shortcutIDs, visibility string
|
||||
if err := d.db.QueryRowContext(ctx, stmt, args...).Scan(
|
||||
&collection.Id,
|
||||
&collection.CreatorId,
|
||||
&collection.CreatedTs,
|
||||
&collection.UpdatedTs,
|
||||
&collection.Name,
|
||||
&collection.Title,
|
||||
&collection.Description,
|
||||
&shortcutIDs,
|
||||
&visibility,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
collection.ShortcutIds = []int32{}
|
||||
if shortcutIDs != "" {
|
||||
for _, idStr := range strings.Split(shortcutIDs, ",") {
|
||||
shortcutID, err := util.ConvertStringToInt32(idStr)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to convert shortcut id")
|
||||
}
|
||||
collection.ShortcutIds = append(collection.ShortcutIds, shortcutID)
|
||||
}
|
||||
}
|
||||
collection.Visibility = convertVisibilityStringToStorepb(visibility)
|
||||
return collection, nil
|
||||
}
|
||||
|
||||
func (d *DB) ListCollections(ctx context.Context, find *store.FindCollection) ([]*storepb.Collection, error) {
|
||||
where, args := []string{"1 = 1"}, []any{}
|
||||
if v := find.ID; v != nil {
|
||||
where, args = append(where, "id = ?"), append(args, *v)
|
||||
}
|
||||
if v := find.CreatorID; v != nil {
|
||||
where, args = append(where, "creator_id = ?"), append(args, *v)
|
||||
}
|
||||
if v := find.Name; v != nil {
|
||||
where, args = append(where, "name = ?"), append(args, *v)
|
||||
}
|
||||
if v := find.VisibilityList; len(v) != 0 {
|
||||
list := []string{}
|
||||
for _, visibility := range v {
|
||||
list = append(list, fmt.Sprintf("$%d", len(args)+1))
|
||||
args = append(args, visibility)
|
||||
}
|
||||
where = append(where, fmt.Sprintf("visibility in (%s)", strings.Join(list, ",")))
|
||||
}
|
||||
|
||||
rows, err := d.db.QueryContext(ctx, `
|
||||
SELECT
|
||||
id,
|
||||
creator_id,
|
||||
created_ts,
|
||||
updated_ts,
|
||||
name,
|
||||
title,
|
||||
description,
|
||||
shortcut_ids,
|
||||
visibility
|
||||
FROM collection
|
||||
WHERE `+strings.Join(where, " AND ")+`
|
||||
ORDER BY created_ts DESC`,
|
||||
args...,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
list := make([]*storepb.Collection, 0)
|
||||
for rows.Next() {
|
||||
collection := &storepb.Collection{}
|
||||
var shortcutIDs, visibility string
|
||||
if err := rows.Scan(
|
||||
&collection.Id,
|
||||
&collection.CreatorId,
|
||||
&collection.CreatedTs,
|
||||
&collection.UpdatedTs,
|
||||
&collection.Name,
|
||||
&collection.Title,
|
||||
&collection.Description,
|
||||
&shortcutIDs,
|
||||
&visibility,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
collection.ShortcutIds = []int32{}
|
||||
if shortcutIDs != "" {
|
||||
for _, idStr := range strings.Split(shortcutIDs, ",") {
|
||||
shortcutID, err := util.ConvertStringToInt32(idStr)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to convert shortcut id")
|
||||
}
|
||||
collection.ShortcutIds = append(collection.ShortcutIds, shortcutID)
|
||||
}
|
||||
}
|
||||
collection.Visibility = storepb.Visibility(storepb.Visibility_value[visibility])
|
||||
list = append(list, collection)
|
||||
}
|
||||
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return list, nil
|
||||
}
|
||||
|
||||
func (d *DB) DeleteCollection(ctx context.Context, delete *store.DeleteCollection) error {
|
||||
if _, err := d.db.ExecContext(ctx, `DELETE FROM collection WHERE id = ?`, delete.ID); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
59
store/db/sqlite/common.go
Normal file
59
store/db/sqlite/common.go
Normal file
@ -0,0 +1,59 @@
|
||||
package sqlite
|
||||
|
||||
import (
|
||||
storepb "github.com/yourselfhosted/slash/proto/gen/store"
|
||||
)
|
||||
|
||||
// RowStatus is the status for a row.
|
||||
type RowStatus string
|
||||
|
||||
const (
|
||||
// Normal is the status for a normal row.
|
||||
Normal RowStatus = "NORMAL"
|
||||
// Archived is the status for an archived row.
|
||||
Archived RowStatus = "ARCHIVED"
|
||||
)
|
||||
|
||||
func (e RowStatus) String() string {
|
||||
switch e {
|
||||
case Normal:
|
||||
return "NORMAL"
|
||||
case Archived:
|
||||
return "ARCHIVED"
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func convertRowStatusStringToStorepb(status string) storepb.RowStatus {
|
||||
switch status {
|
||||
case "NORMAL":
|
||||
return storepb.RowStatus_NORMAL
|
||||
case "ARCHIVED":
|
||||
return storepb.RowStatus_ARCHIVED
|
||||
}
|
||||
return storepb.RowStatus_ROW_STATUS_UNSPECIFIED
|
||||
}
|
||||
|
||||
// Visibility is the type of a visibility.
|
||||
type Visibility string
|
||||
|
||||
const (
|
||||
// VisibilityPublic is the PUBLIC visibility.
|
||||
VisibilityPublic Visibility = "PUBLIC"
|
||||
// VisibilityWorkspace is the WORKSPACE visibility.
|
||||
VisibilityWorkspace Visibility = "WORKSPACE"
|
||||
// VisibilityPrivate is the PRIVATE visibility.
|
||||
VisibilityPrivate Visibility = "PRIVATE"
|
||||
)
|
||||
|
||||
func (e Visibility) String() string {
|
||||
switch e {
|
||||
case VisibilityPublic:
|
||||
return "PUBLIC"
|
||||
case VisibilityWorkspace:
|
||||
return "WORKSPACE"
|
||||
case VisibilityPrivate:
|
||||
return "PRIVATE"
|
||||
}
|
||||
return "PRIVATE"
|
||||
}
|
202
store/db/sqlite/memo.go
Normal file
202
store/db/sqlite/memo.go
Normal file
@ -0,0 +1,202 @@
|
||||
package sqlite
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
|
||||
storepb "github.com/yourselfhosted/slash/proto/gen/store"
|
||||
"github.com/yourselfhosted/slash/store"
|
||||
)
|
||||
|
||||
func (d *DB) CreateMemo(ctx context.Context, create *storepb.Memo) (*storepb.Memo, error) {
|
||||
set := []string{"creator_id", "name", "title", "content", "visibility", "tag"}
|
||||
args := []any{create.CreatorId, create.Name, create.Title, create.Content, create.Visibility.String(), strings.Join(create.Tags, " ")}
|
||||
|
||||
stmt := `
|
||||
INSERT INTO memo (
|
||||
` + strings.Join(set, ", ") + `
|
||||
)
|
||||
VALUES (` + placeholders(len(args)) + `)
|
||||
RETURNING id, created_ts, updated_ts, row_status
|
||||
`
|
||||
|
||||
var rowStatus string
|
||||
if err := d.db.QueryRowContext(ctx, stmt, args...).Scan(
|
||||
&create.Id,
|
||||
&create.CreatedTs,
|
||||
&create.UpdatedTs,
|
||||
&rowStatus,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
create.RowStatus = convertRowStatusStringToStorepb(rowStatus)
|
||||
memo := create
|
||||
return memo, nil
|
||||
}
|
||||
|
||||
func (d *DB) UpdateMemo(ctx context.Context, update *store.UpdateMemo) (*storepb.Memo, error) {
|
||||
set, args := []string{}, []any{}
|
||||
if update.RowStatus != nil {
|
||||
set, args = append(set, "row_status = ?"), append(args, update.RowStatus.String())
|
||||
}
|
||||
if update.Name != nil {
|
||||
set, args = append(set, "name = ?"), append(args, *update.Name)
|
||||
}
|
||||
if update.Title != nil {
|
||||
set, args = append(set, "title = ?"), append(args, *update.Title)
|
||||
}
|
||||
if update.Content != nil {
|
||||
set, args = append(set, "content = ?"), append(args, *update.Content)
|
||||
}
|
||||
if update.Visibility != nil {
|
||||
set, args = append(set, "visibility = ?"), append(args, update.Visibility.String())
|
||||
}
|
||||
if update.Tag != nil {
|
||||
set, args = append(set, "tag = ?"), append(args, *update.Tag)
|
||||
}
|
||||
if len(set) == 0 {
|
||||
return nil, errors.New("no update specified")
|
||||
}
|
||||
args = append(args, update.ID)
|
||||
|
||||
stmt := `
|
||||
UPDATE memo
|
||||
SET
|
||||
` + strings.Join(set, ", ") + `
|
||||
WHERE
|
||||
id = ?
|
||||
RETURNING id, creator_id, created_ts, updated_ts, row_status, name, title, content, visibility, tag
|
||||
`
|
||||
memo := &storepb.Memo{}
|
||||
var rowStatus, visibility, tags string
|
||||
if err := d.db.QueryRowContext(ctx, stmt, args...).Scan(
|
||||
&memo.Id,
|
||||
&memo.CreatorId,
|
||||
&memo.CreatedTs,
|
||||
&memo.UpdatedTs,
|
||||
&rowStatus,
|
||||
&memo.Name,
|
||||
&memo.Title,
|
||||
&memo.Content,
|
||||
&visibility,
|
||||
&tags,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
memo.RowStatus = convertRowStatusStringToStorepb(rowStatus)
|
||||
memo.Visibility = convertVisibilityStringToStorepb(visibility)
|
||||
memo.Tags = filterTags(strings.Split(tags, " "))
|
||||
return memo, nil
|
||||
}
|
||||
|
||||
func (d *DB) ListMemos(ctx context.Context, find *store.FindMemo) ([]*storepb.Memo, error) {
|
||||
where, args := []string{"1 = 1"}, []any{}
|
||||
if v := find.ID; v != nil {
|
||||
where, args = append(where, "id = ?"), append(args, *v)
|
||||
}
|
||||
if v := find.CreatorID; v != nil {
|
||||
where, args = append(where, "creator_id = ?"), append(args, *v)
|
||||
}
|
||||
if v := find.RowStatus; v != nil {
|
||||
where, args = append(where, "row_status = ?"), append(args, *v)
|
||||
}
|
||||
if v := find.Name; v != nil {
|
||||
where, args = append(where, "name = ?"), append(args, *v)
|
||||
}
|
||||
if v := find.VisibilityList; len(v) != 0 {
|
||||
list := []string{}
|
||||
for _, visibility := range v {
|
||||
list = append(list, fmt.Sprintf("$%d", len(args)+1))
|
||||
args = append(args, visibility)
|
||||
}
|
||||
where = append(where, fmt.Sprintf("visibility in (%s)", strings.Join(list, ",")))
|
||||
}
|
||||
if v := find.Tag; v != nil {
|
||||
where, args = append(where, "tag LIKE ?"), append(args, "%"+*v+"%")
|
||||
}
|
||||
|
||||
rows, err := d.db.QueryContext(ctx, `
|
||||
SELECT
|
||||
id,
|
||||
creator_id,
|
||||
created_ts,
|
||||
updated_ts,
|
||||
row_status,
|
||||
name,
|
||||
title,
|
||||
content,
|
||||
visibility,
|
||||
tag
|
||||
FROM memo
|
||||
WHERE `+strings.Join(where, " AND ")+`
|
||||
ORDER BY created_ts DESC`,
|
||||
args...,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
list := make([]*storepb.Memo, 0)
|
||||
for rows.Next() {
|
||||
memo := &storepb.Memo{}
|
||||
var rowStatus, visibility, tags string
|
||||
if err := rows.Scan(
|
||||
&memo.Id,
|
||||
&memo.CreatorId,
|
||||
&memo.CreatedTs,
|
||||
&memo.UpdatedTs,
|
||||
&rowStatus,
|
||||
&memo.Name,
|
||||
&memo.Title,
|
||||
&memo.Content,
|
||||
&visibility,
|
||||
&tags,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
memo.RowStatus = convertRowStatusStringToStorepb(rowStatus)
|
||||
memo.Visibility = storepb.Visibility(storepb.Visibility_value[visibility])
|
||||
memo.Tags = filterTags(strings.Split(tags, " "))
|
||||
list = append(list, memo)
|
||||
}
|
||||
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return list, nil
|
||||
}
|
||||
|
||||
func (d *DB) DeleteMemo(ctx context.Context, delete *store.DeleteMemo) error {
|
||||
if _, err := d.db.ExecContext(ctx, `DELETE FROM memo WHERE id = ?`, delete.ID); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func vacuumMemo(ctx context.Context, tx *sql.Tx) error {
|
||||
stmt := `
|
||||
DELETE FROM
|
||||
memo
|
||||
WHERE
|
||||
creator_id NOT IN (
|
||||
SELECT
|
||||
id
|
||||
FROM
|
||||
user
|
||||
)`
|
||||
_, err := tx.ExecContext(ctx, stmt)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func placeholders(n int) string {
|
||||
return strings.Repeat("?,", n-1) + "?"
|
||||
}
|
57
store/db/sqlite/migration_history.go
Normal file
57
store/db/sqlite/migration_history.go
Normal file
@ -0,0 +1,57 @@
|
||||
package sqlite
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/yourselfhosted/slash/store"
|
||||
)
|
||||
|
||||
func (d *DB) UpsertMigrationHistory(ctx context.Context, upsert *store.UpsertMigrationHistory) (*store.MigrationHistory, error) {
|
||||
stmt := `
|
||||
INSERT INTO migration_history (
|
||||
version
|
||||
)
|
||||
VALUES (?)
|
||||
ON CONFLICT(version) DO UPDATE
|
||||
SET
|
||||
version=EXCLUDED.version
|
||||
RETURNING version, created_ts
|
||||
`
|
||||
var migrationHistory store.MigrationHistory
|
||||
if err := d.db.QueryRowContext(ctx, stmt, upsert.Version).Scan(
|
||||
&migrationHistory.Version,
|
||||
&migrationHistory.CreatedTs,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &migrationHistory, nil
|
||||
}
|
||||
|
||||
func (d *DB) ListMigrationHistories(ctx context.Context, _ *store.FindMigrationHistory) ([]*store.MigrationHistory, error) {
|
||||
query := "SELECT `version`, `created_ts` FROM `migration_history` ORDER BY `created_ts` DESC"
|
||||
rows, err := d.db.QueryContext(ctx, query)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
list := make([]*store.MigrationHistory, 0)
|
||||
for rows.Next() {
|
||||
var migrationHistory store.MigrationHistory
|
||||
if err := rows.Scan(
|
||||
&migrationHistory.Version,
|
||||
&migrationHistory.CreatedTs,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
list = append(list, &migrationHistory)
|
||||
}
|
||||
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return list, nil
|
||||
}
|
234
store/db/sqlite/migrator.go
Normal file
234
store/db/sqlite/migrator.go
Normal file
@ -0,0 +1,234 @@
|
||||
package sqlite
|
||||
|
||||
import (
|
||||
"context"
|
||||
"embed"
|
||||
"fmt"
|
||||
"io/fs"
|
||||
"os"
|
||||
"regexp"
|
||||
"sort"
|
||||
"time"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
|
||||
"github.com/yourselfhosted/slash/server/version"
|
||||
"github.com/yourselfhosted/slash/store"
|
||||
)
|
||||
|
||||
//go:embed migration
|
||||
var migrationFS embed.FS
|
||||
|
||||
//go:embed seed
|
||||
var seedFS embed.FS
|
||||
|
||||
// Migrate applies the latest schema to the database.
|
||||
func (d *DB) Migrate(ctx context.Context) error {
|
||||
currentVersion := version.GetCurrentVersion(d.profile.Mode)
|
||||
if d.profile.Mode == "prod" {
|
||||
_, err := os.Stat(d.profile.DSN)
|
||||
if err != nil {
|
||||
// If db file not exists, we should create a new one with latest schema.
|
||||
if errors.Is(err, os.ErrNotExist) {
|
||||
if err := d.applyLatestSchema(ctx); err != nil {
|
||||
return errors.Wrap(err, "failed to apply latest schema")
|
||||
}
|
||||
// Upsert the newest version to migration_history.
|
||||
if _, err := d.UpsertMigrationHistory(ctx, &store.UpsertMigrationHistory{
|
||||
Version: currentVersion,
|
||||
}); err != nil {
|
||||
return errors.Wrap(err, "failed to upsert migration history")
|
||||
}
|
||||
} else {
|
||||
return errors.Wrap(err, "failed to get db file stat")
|
||||
}
|
||||
} else {
|
||||
// If db file exists, we should check if we need to migrate the database.
|
||||
migrationHistoryList, err := d.ListMigrationHistories(ctx, &store.FindMigrationHistory{})
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "failed to find migration history")
|
||||
}
|
||||
// If no migration history, we should apply the latest version migration and upsert the migration history.
|
||||
if len(migrationHistoryList) == 0 {
|
||||
minorVersion := version.GetMinorVersion(currentVersion)
|
||||
if err := d.applyMigrationForMinorVersion(ctx, minorVersion); err != nil {
|
||||
return errors.Wrapf(err, "failed to apply version %s migration", minorVersion)
|
||||
}
|
||||
_, err := d.UpsertMigrationHistory(ctx, &store.UpsertMigrationHistory{
|
||||
Version: currentVersion,
|
||||
})
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "failed to upsert migration history")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
migrationHistoryVersionList := []string{}
|
||||
for _, migrationHistory := range migrationHistoryList {
|
||||
migrationHistoryVersionList = append(migrationHistoryVersionList, migrationHistory.Version)
|
||||
}
|
||||
sort.Sort(version.SortVersion(migrationHistoryVersionList))
|
||||
latestMigrationHistoryVersion := migrationHistoryVersionList[len(migrationHistoryVersionList)-1]
|
||||
|
||||
if version.IsVersionGreaterThan(version.GetSchemaVersion(currentVersion), latestMigrationHistoryVersion) {
|
||||
minorVersionList := getMinorVersionList()
|
||||
// backup the raw database file before migration
|
||||
rawBytes, err := os.ReadFile(d.profile.DSN)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "failed to read raw database file")
|
||||
}
|
||||
backupDBFilePath := fmt.Sprintf("%s/memos_%s_%d_backup.db", d.profile.Data, d.profile.Version, time.Now().Unix())
|
||||
if err := os.WriteFile(backupDBFilePath, rawBytes, 0644); err != nil {
|
||||
return errors.Wrap(err, "failed to write raw database file")
|
||||
}
|
||||
println("succeed to copy a backup database file")
|
||||
println("start migrate")
|
||||
for _, minorVersion := range minorVersionList {
|
||||
normalizedVersion := minorVersion + ".0"
|
||||
if version.IsVersionGreaterThan(normalizedVersion, latestMigrationHistoryVersion) && version.IsVersionGreaterOrEqualThan(currentVersion, normalizedVersion) {
|
||||
println("applying migration for", normalizedVersion)
|
||||
if err := d.applyMigrationForMinorVersion(ctx, minorVersion); err != nil {
|
||||
return errors.Wrap(err, "failed to apply minor version migration")
|
||||
}
|
||||
}
|
||||
}
|
||||
println("end migrate")
|
||||
|
||||
// remove the created backup db file after migrate succeed
|
||||
if err := os.Remove(backupDBFilePath); err != nil {
|
||||
println(fmt.Sprintf("Failed to remove temp database file, err %v", err))
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// In non-prod mode, we should always migrate the database.
|
||||
if _, err := os.Stat(d.profile.DSN); errors.Is(err, os.ErrNotExist) {
|
||||
if err := d.applyLatestSchema(ctx); err != nil {
|
||||
return errors.Wrap(err, "failed to apply latest schema")
|
||||
}
|
||||
// In demo mode, we should seed the database.
|
||||
if d.profile.Mode == "demo" {
|
||||
if err := d.seed(ctx); err != nil {
|
||||
return errors.Wrap(err, "failed to seed")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
const (
|
||||
latestSchemaFileName = "LATEST__SCHEMA.sql"
|
||||
)
|
||||
|
||||
func (d *DB) applyLatestSchema(ctx context.Context) error {
|
||||
schemaMode := "dev"
|
||||
if d.profile.Mode == "prod" {
|
||||
schemaMode = "prod"
|
||||
}
|
||||
latestSchemaPath := fmt.Sprintf("migration/%s/%s", schemaMode, latestSchemaFileName)
|
||||
buf, err := migrationFS.ReadFile(latestSchemaPath)
|
||||
if err != nil {
|
||||
return errors.Wrapf(err, "failed to read latest schema %q", latestSchemaPath)
|
||||
}
|
||||
stmt := string(buf)
|
||||
if err := d.execute(ctx, stmt); err != nil {
|
||||
return errors.Wrapf(err, "migrate error: %s", stmt)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d *DB) applyMigrationForMinorVersion(ctx context.Context, minorVersion string) error {
|
||||
filenames, err := fs.Glob(migrationFS, fmt.Sprintf("%s/%s/*.sql", "migration/prod", minorVersion))
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "failed to read ddl files")
|
||||
}
|
||||
|
||||
sort.Strings(filenames)
|
||||
migrationStmt := ""
|
||||
|
||||
// Loop over all migration files and execute them in order.
|
||||
for _, filename := range filenames {
|
||||
buf, err := migrationFS.ReadFile(filename)
|
||||
if err != nil {
|
||||
return errors.Wrapf(err, "failed to read minor version migration file, filename=%s", filename)
|
||||
}
|
||||
stmt := string(buf)
|
||||
migrationStmt += stmt
|
||||
if err := d.execute(ctx, stmt); err != nil {
|
||||
return errors.Wrapf(err, "migrate error: %s", stmt)
|
||||
}
|
||||
}
|
||||
|
||||
// Upsert the newest version to migration_history.
|
||||
version := minorVersion + ".0"
|
||||
if _, err = d.UpsertMigrationHistory(ctx, &store.UpsertMigrationHistory{
|
||||
Version: version,
|
||||
}); err != nil {
|
||||
return errors.Wrapf(err, "failed to upsert migration history with version: %s", version)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d *DB) seed(ctx context.Context) error {
|
||||
filenames, err := fs.Glob(seedFS, fmt.Sprintf("%s/*.sql", "seed"))
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "failed to read seed files")
|
||||
}
|
||||
|
||||
sort.Strings(filenames)
|
||||
|
||||
// Loop over all seed files and execute them in order.
|
||||
for _, filename := range filenames {
|
||||
buf, err := seedFS.ReadFile(filename)
|
||||
if err != nil {
|
||||
return errors.Wrapf(err, "failed to read seed file, filename=%s", filename)
|
||||
}
|
||||
stmt := string(buf)
|
||||
if err := d.execute(ctx, stmt); err != nil {
|
||||
return errors.Wrapf(err, "seed error: %s", stmt)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// execute runs a single SQL statement within a transaction.
|
||||
func (d *DB) execute(ctx context.Context, stmt string) error {
|
||||
tx, err := d.db.Begin()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer tx.Rollback()
|
||||
|
||||
if _, err := tx.ExecContext(ctx, stmt); err != nil {
|
||||
return errors.Wrap(err, "failed to execute statement")
|
||||
}
|
||||
|
||||
return tx.Commit()
|
||||
}
|
||||
|
||||
// minorDirRegexp is a regular expression for minor version directory.
|
||||
var minorDirRegexp = regexp.MustCompile(`^migration/prod/[0-9]+\.[0-9]+$`)
|
||||
|
||||
func getMinorVersionList() []string {
|
||||
minorVersionList := []string{}
|
||||
|
||||
if err := fs.WalkDir(migrationFS, "migration", func(path string, file fs.DirEntry, err error) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if file.IsDir() && minorDirRegexp.MatchString(path) {
|
||||
minorVersionList = append(minorVersionList, file.Name())
|
||||
}
|
||||
|
||||
return nil
|
||||
}); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
sort.Sort(version.SortVersion(minorVersionList))
|
||||
|
||||
return minorVersionList
|
||||
}
|
249
store/db/sqlite/shortcut.go
Normal file
249
store/db/sqlite/shortcut.go
Normal file
@ -0,0 +1,249 @@
|
||||
package sqlite
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"google.golang.org/protobuf/encoding/protojson"
|
||||
|
||||
storepb "github.com/yourselfhosted/slash/proto/gen/store"
|
||||
"github.com/yourselfhosted/slash/store"
|
||||
)
|
||||
|
||||
func (d *DB) CreateShortcut(ctx context.Context, create *storepb.Shortcut) (*storepb.Shortcut, error) {
|
||||
set := []string{"creator_id", "name", "link", "title", "description", "visibility", "tag"}
|
||||
args := []any{create.CreatorId, create.Name, create.Link, create.Title, create.Description, create.Visibility.String(), strings.Join(create.Tags, " ")}
|
||||
placeholder := []string{"?", "?", "?", "?", "?", "?", "?"}
|
||||
if create.OgMetadata != nil {
|
||||
set = append(set, "og_metadata")
|
||||
openGraphMetadataBytes, err := protojson.Marshal(create.OgMetadata)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
args = append(args, string(openGraphMetadataBytes))
|
||||
placeholder = append(placeholder, "?")
|
||||
}
|
||||
|
||||
stmt := `
|
||||
INSERT INTO shortcut (
|
||||
` + strings.Join(set, ", ") + `
|
||||
)
|
||||
VALUES (` + strings.Join(placeholder, ",") + `)
|
||||
RETURNING id, created_ts, updated_ts, row_status
|
||||
`
|
||||
var rowStatus string
|
||||
if err := d.db.QueryRowContext(ctx, stmt, args...).Scan(
|
||||
&create.Id,
|
||||
&create.CreatedTs,
|
||||
&create.UpdatedTs,
|
||||
&rowStatus,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
create.RowStatus = convertRowStatusStringToStorepb(rowStatus)
|
||||
shortcut := create
|
||||
return shortcut, nil
|
||||
}
|
||||
|
||||
func (d *DB) UpdateShortcut(ctx context.Context, update *store.UpdateShortcut) (*storepb.Shortcut, error) {
|
||||
set, args := []string{}, []any{}
|
||||
if update.RowStatus != nil {
|
||||
set, args = append(set, "row_status = ?"), append(args, update.RowStatus.String())
|
||||
}
|
||||
if update.Name != nil {
|
||||
set, args = append(set, "name = ?"), append(args, *update.Name)
|
||||
}
|
||||
if update.Link != nil {
|
||||
set, args = append(set, "link = ?"), append(args, *update.Link)
|
||||
}
|
||||
if update.Title != nil {
|
||||
set, args = append(set, "title = ?"), append(args, *update.Title)
|
||||
}
|
||||
if update.Description != nil {
|
||||
set, args = append(set, "description = ?"), append(args, *update.Description)
|
||||
}
|
||||
if update.Visibility != nil {
|
||||
set, args = append(set, "visibility = ?"), append(args, update.Visibility.String())
|
||||
}
|
||||
if update.Tag != nil {
|
||||
set, args = append(set, "tag = ?"), append(args, *update.Tag)
|
||||
}
|
||||
if update.OpenGraphMetadata != nil {
|
||||
openGraphMetadataBytes, err := protojson.Marshal(update.OpenGraphMetadata)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "Failed to marshal activity payload")
|
||||
}
|
||||
set, args = append(set, "og_metadata = ?"), append(args, string(openGraphMetadataBytes))
|
||||
}
|
||||
if len(set) == 0 {
|
||||
return nil, errors.New("no update specified")
|
||||
}
|
||||
args = append(args, update.ID)
|
||||
|
||||
stmt := `
|
||||
UPDATE shortcut
|
||||
SET
|
||||
` + strings.Join(set, ", ") + `
|
||||
WHERE
|
||||
id = ?
|
||||
RETURNING id, creator_id, created_ts, updated_ts, row_status, name, link, title, description, visibility, tag, og_metadata
|
||||
`
|
||||
shortcut := &storepb.Shortcut{}
|
||||
var rowStatus, visibility, tags, openGraphMetadataString string
|
||||
if err := d.db.QueryRowContext(ctx, stmt, args...).Scan(
|
||||
&shortcut.Id,
|
||||
&shortcut.CreatorId,
|
||||
&shortcut.CreatedTs,
|
||||
&shortcut.UpdatedTs,
|
||||
&rowStatus,
|
||||
&shortcut.Name,
|
||||
&shortcut.Link,
|
||||
&shortcut.Title,
|
||||
&shortcut.Description,
|
||||
&visibility,
|
||||
&tags,
|
||||
&openGraphMetadataString,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
shortcut.RowStatus = convertRowStatusStringToStorepb(rowStatus)
|
||||
shortcut.Visibility = convertVisibilityStringToStorepb(visibility)
|
||||
shortcut.Tags = filterTags(strings.Split(tags, " "))
|
||||
var ogMetadata storepb.OpenGraphMetadata
|
||||
if err := protojson.Unmarshal([]byte(openGraphMetadataString), &ogMetadata); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
shortcut.OgMetadata = &ogMetadata
|
||||
return shortcut, nil
|
||||
}
|
||||
|
||||
func (d *DB) ListShortcuts(ctx context.Context, find *store.FindShortcut) ([]*storepb.Shortcut, error) {
|
||||
where, args := []string{"1 = 1"}, []any{}
|
||||
if v := find.ID; v != nil {
|
||||
where, args = append(where, "id = ?"), append(args, *v)
|
||||
}
|
||||
if v := find.CreatorID; v != nil {
|
||||
where, args = append(where, "creator_id = ?"), append(args, *v)
|
||||
}
|
||||
if v := find.RowStatus; v != nil {
|
||||
where, args = append(where, "row_status = ?"), append(args, *v)
|
||||
}
|
||||
if v := find.Name; v != nil {
|
||||
where, args = append(where, "name = ?"), append(args, *v)
|
||||
}
|
||||
if v := find.VisibilityList; len(v) != 0 {
|
||||
list := []string{}
|
||||
for _, visibility := range v {
|
||||
list = append(list, fmt.Sprintf("$%d", len(args)+1))
|
||||
args = append(args, visibility)
|
||||
}
|
||||
where = append(where, fmt.Sprintf("visibility in (%s)", strings.Join(list, ",")))
|
||||
}
|
||||
if v := find.Tag; v != nil {
|
||||
where, args = append(where, "tag LIKE ?"), append(args, "%"+*v+"%")
|
||||
}
|
||||
|
||||
rows, err := d.db.QueryContext(ctx, `
|
||||
SELECT
|
||||
id,
|
||||
creator_id,
|
||||
created_ts,
|
||||
updated_ts,
|
||||
row_status,
|
||||
name,
|
||||
link,
|
||||
title,
|
||||
description,
|
||||
visibility,
|
||||
tag,
|
||||
og_metadata
|
||||
FROM shortcut
|
||||
WHERE `+strings.Join(where, " AND ")+`
|
||||
ORDER BY created_ts DESC`,
|
||||
args...,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
list := make([]*storepb.Shortcut, 0)
|
||||
for rows.Next() {
|
||||
shortcut := &storepb.Shortcut{}
|
||||
var rowStatus, visibility, tags, openGraphMetadataString string
|
||||
if err := rows.Scan(
|
||||
&shortcut.Id,
|
||||
&shortcut.CreatorId,
|
||||
&shortcut.CreatedTs,
|
||||
&shortcut.UpdatedTs,
|
||||
&rowStatus,
|
||||
&shortcut.Name,
|
||||
&shortcut.Link,
|
||||
&shortcut.Title,
|
||||
&shortcut.Description,
|
||||
&visibility,
|
||||
&tags,
|
||||
&openGraphMetadataString,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
shortcut.RowStatus = convertRowStatusStringToStorepb(rowStatus)
|
||||
shortcut.Visibility = storepb.Visibility(storepb.Visibility_value[visibility])
|
||||
shortcut.Tags = filterTags(strings.Split(tags, " "))
|
||||
var ogMetadata storepb.OpenGraphMetadata
|
||||
if err := protojson.Unmarshal([]byte(openGraphMetadataString), &ogMetadata); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
shortcut.OgMetadata = &ogMetadata
|
||||
list = append(list, shortcut)
|
||||
}
|
||||
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return list, nil
|
||||
}
|
||||
|
||||
func (d *DB) DeleteShortcut(ctx context.Context, delete *store.DeleteShortcut) error {
|
||||
if _, err := d.db.ExecContext(ctx, `DELETE FROM shortcut WHERE id = ?`, delete.ID); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func vacuumShortcut(ctx context.Context, tx *sql.Tx) error {
|
||||
stmt := `
|
||||
DELETE FROM
|
||||
shortcut
|
||||
WHERE
|
||||
creator_id NOT IN (
|
||||
SELECT
|
||||
id
|
||||
FROM
|
||||
user
|
||||
)`
|
||||
_, err := tx.ExecContext(ctx, stmt)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func filterTags(tags []string) []string {
|
||||
result := []string{}
|
||||
for _, tag := range tags {
|
||||
if tag != "" {
|
||||
result = append(result, tag)
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func convertVisibilityStringToStorepb(visibility string) storepb.Visibility {
|
||||
return storepb.Visibility(storepb.Visibility_value[visibility])
|
||||
}
|
56
store/db/sqlite/sqlite.go
Normal file
56
store/db/sqlite/sqlite.go
Normal file
@ -0,0 +1,56 @@
|
||||
package sqlite
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
|
||||
"github.com/yourselfhosted/slash/server/profile"
|
||||
"github.com/yourselfhosted/slash/store"
|
||||
)
|
||||
|
||||
type DB struct {
|
||||
db *sql.DB
|
||||
profile *profile.Profile
|
||||
}
|
||||
|
||||
// NewDB opens a database specified by its database driver name and a
|
||||
// driver-specific data source name, usually consisting of at least a
|
||||
// database name and connection information.
|
||||
func NewDB(profile *profile.Profile) (store.Driver, error) {
|
||||
// Ensure a DSN is set before attempting to open the database.
|
||||
if profile.DSN == "" {
|
||||
return nil, errors.New("dsn required")
|
||||
}
|
||||
|
||||
// Connect to the database with some sane settings:
|
||||
// - No shared-cache: it's obsolete; WAL journal mode is a better solution.
|
||||
// - No foreign key constraints: it's currently disabled by default, but it's a
|
||||
// good practice to be explicit and prevent future surprises on SQLite upgrades.
|
||||
// - Journal mode set to WAL: it's the recommended journal mode for most applications
|
||||
// as it prevents locking issues.
|
||||
//
|
||||
// Notes:
|
||||
// - When using the `modernc.org/sqlite` driver, each pragma must be prefixed with `_pragma=`.
|
||||
//
|
||||
// References:
|
||||
// - https://pkg.go.dev/modernc.org/sqlite#Driver.Open
|
||||
// - https://www.sqlite.org/sharedcache.html
|
||||
// - https://www.sqlite.org/pragma.html
|
||||
sqliteDB, err := sql.Open("sqlite", profile.DSN+"?_pragma=foreign_keys(0)&_pragma=busy_timeout(10000)&_pragma=journal_mode(WAL)")
|
||||
if err != nil {
|
||||
return nil, errors.Wrapf(err, "failed to open db with dsn: %s", profile.DSN)
|
||||
}
|
||||
|
||||
driver := DB{db: sqliteDB, profile: profile}
|
||||
|
||||
return &driver, nil
|
||||
}
|
||||
|
||||
func (d *DB) GetDB() *sql.DB {
|
||||
return d.db
|
||||
}
|
||||
|
||||
func (d *DB) Close() error {
|
||||
return d.db.Close()
|
||||
}
|
176
store/db/sqlite/user.go
Normal file
176
store/db/sqlite/user.go
Normal file
@ -0,0 +1,176 @@
|
||||
package sqlite
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"strings"
|
||||
|
||||
"github.com/yourselfhosted/slash/store"
|
||||
)
|
||||
|
||||
func (d *DB) CreateUser(ctx context.Context, create *store.User) (*store.User, error) {
|
||||
stmt := `
|
||||
INSERT INTO user (
|
||||
email,
|
||||
nickname,
|
||||
password_hash,
|
||||
role
|
||||
)
|
||||
VALUES (?, ?, ?, ?)
|
||||
RETURNING id, created_ts, updated_ts, row_status
|
||||
`
|
||||
if err := d.db.QueryRowContext(ctx, stmt,
|
||||
create.Email,
|
||||
create.Nickname,
|
||||
create.PasswordHash,
|
||||
create.Role,
|
||||
).Scan(
|
||||
&create.ID,
|
||||
&create.CreatedTs,
|
||||
&create.UpdatedTs,
|
||||
&create.RowStatus,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
user := create
|
||||
return user, nil
|
||||
}
|
||||
|
||||
func (d *DB) UpdateUser(ctx context.Context, update *store.UpdateUser) (*store.User, error) {
|
||||
set, args := []string{}, []any{}
|
||||
if v := update.RowStatus; v != nil {
|
||||
set, args = append(set, "row_status = ?"), append(args, *v)
|
||||
}
|
||||
if v := update.Email; v != nil {
|
||||
set, args = append(set, "email = ?"), append(args, *v)
|
||||
}
|
||||
if v := update.Nickname; v != nil {
|
||||
set, args = append(set, "nickname = ?"), append(args, *v)
|
||||
}
|
||||
if v := update.PasswordHash; v != nil {
|
||||
set, args = append(set, "password_hash = ?"), append(args, *v)
|
||||
}
|
||||
if v := update.Role; v != nil {
|
||||
set, args = append(set, "role = ?"), append(args, *v)
|
||||
}
|
||||
|
||||
if len(set) == 0 {
|
||||
return nil, errors.New("no fields to update")
|
||||
}
|
||||
|
||||
stmt := `
|
||||
UPDATE user
|
||||
SET ` + strings.Join(set, ", ") + `
|
||||
WHERE id = ?
|
||||
RETURNING id, created_ts, updated_ts, row_status, email, nickname, password_hash, role
|
||||
`
|
||||
args = append(args, update.ID)
|
||||
user := &store.User{}
|
||||
if err := d.db.QueryRowContext(ctx, stmt, args...).Scan(
|
||||
&user.ID,
|
||||
&user.CreatedTs,
|
||||
&user.UpdatedTs,
|
||||
&user.RowStatus,
|
||||
&user.Email,
|
||||
&user.Nickname,
|
||||
&user.PasswordHash,
|
||||
&user.Role,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return user, nil
|
||||
}
|
||||
|
||||
func (d *DB) ListUsers(ctx context.Context, find *store.FindUser) ([]*store.User, error) {
|
||||
where, args := []string{"1 = 1"}, []any{}
|
||||
|
||||
if v := find.ID; v != nil {
|
||||
where, args = append(where, "id = ?"), append(args, *v)
|
||||
}
|
||||
if v := find.RowStatus; v != nil {
|
||||
where, args = append(where, "row_status = ?"), append(args, v.String())
|
||||
}
|
||||
if v := find.Email; v != nil {
|
||||
where, args = append(where, "email = ?"), append(args, *v)
|
||||
}
|
||||
if v := find.Nickname; v != nil {
|
||||
where, args = append(where, "nickname = ?"), append(args, *v)
|
||||
}
|
||||
if v := find.Role; v != nil {
|
||||
where, args = append(where, "role = ?"), append(args, *v)
|
||||
}
|
||||
|
||||
query := `
|
||||
SELECT
|
||||
id,
|
||||
created_ts,
|
||||
updated_ts,
|
||||
row_status,
|
||||
email,
|
||||
nickname,
|
||||
password_hash,
|
||||
role
|
||||
FROM user
|
||||
WHERE ` + strings.Join(where, " AND ") + `
|
||||
ORDER BY updated_ts DESC, created_ts DESC
|
||||
`
|
||||
rows, err := d.db.QueryContext(ctx, query, args...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
list := make([]*store.User, 0)
|
||||
for rows.Next() {
|
||||
user := &store.User{}
|
||||
if err := rows.Scan(
|
||||
&user.ID,
|
||||
&user.CreatedTs,
|
||||
&user.UpdatedTs,
|
||||
&user.RowStatus,
|
||||
&user.Email,
|
||||
&user.Nickname,
|
||||
&user.PasswordHash,
|
||||
&user.Role,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
list = append(list, user)
|
||||
}
|
||||
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return list, nil
|
||||
}
|
||||
|
||||
func (d *DB) DeleteUser(ctx context.Context, delete *store.DeleteUser) error {
|
||||
tx, err := d.db.BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer tx.Rollback()
|
||||
|
||||
if _, err := tx.ExecContext(ctx, `
|
||||
DELETE FROM user WHERE id = ?
|
||||
`, delete.ID); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := vacuumUserSetting(ctx, tx); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := vacuumShortcut(ctx, tx); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := vacuumMemo(ctx, tx); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return tx.Commit()
|
||||
}
|
128
store/db/sqlite/user_setting.go
Normal file
128
store/db/sqlite/user_setting.go
Normal file
@ -0,0 +1,128 @@
|
||||
package sqlite
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"strings"
|
||||
|
||||
"google.golang.org/protobuf/encoding/protojson"
|
||||
|
||||
storepb "github.com/yourselfhosted/slash/proto/gen/store"
|
||||
"github.com/yourselfhosted/slash/store"
|
||||
)
|
||||
|
||||
func (d *DB) UpsertUserSetting(ctx context.Context, upsert *storepb.UserSetting) (*storepb.UserSetting, error) {
|
||||
stmt := `
|
||||
INSERT INTO user_setting (
|
||||
user_id, key, value
|
||||
)
|
||||
VALUES (?, ?, ?)
|
||||
ON CONFLICT(user_id, key) DO UPDATE
|
||||
SET value = EXCLUDED.value
|
||||
`
|
||||
var valueString string
|
||||
if upsert.Key == storepb.UserSettingKey_USER_SETTING_ACCESS_TOKENS {
|
||||
valueBytes, err := protojson.Marshal(upsert.GetAccessTokens())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
valueString = string(valueBytes)
|
||||
} else if upsert.Key == storepb.UserSettingKey_USER_SETTING_LOCALE {
|
||||
valueString = upsert.GetLocale().String()
|
||||
} else if upsert.Key == storepb.UserSettingKey_USER_SETTING_COLOR_THEME {
|
||||
valueString = upsert.GetColorTheme().String()
|
||||
} else {
|
||||
return nil, errors.New("invalid user setting key")
|
||||
}
|
||||
|
||||
if _, err := d.db.ExecContext(ctx, stmt, upsert.UserId, upsert.Key.String(), valueString); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
userSettingMessage := upsert
|
||||
return userSettingMessage, nil
|
||||
}
|
||||
|
||||
func (d *DB) ListUserSettings(ctx context.Context, find *store.FindUserSetting) ([]*storepb.UserSetting, error) {
|
||||
where, args := []string{"1 = 1"}, []any{}
|
||||
|
||||
if v := find.Key; v != storepb.UserSettingKey_USER_SETTING_KEY_UNSPECIFIED {
|
||||
where, args = append(where, "key = ?"), append(args, v.String())
|
||||
}
|
||||
if v := find.UserID; v != nil {
|
||||
where, args = append(where, "user_id = ?"), append(args, *find.UserID)
|
||||
}
|
||||
|
||||
query := `
|
||||
SELECT
|
||||
user_id,
|
||||
key,
|
||||
value
|
||||
FROM user_setting
|
||||
WHERE ` + strings.Join(where, " AND ")
|
||||
rows, err := d.db.QueryContext(ctx, query, args...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
userSettingList := make([]*storepb.UserSetting, 0)
|
||||
for rows.Next() {
|
||||
userSetting := &storepb.UserSetting{}
|
||||
var keyString, valueString string
|
||||
if err := rows.Scan(
|
||||
&userSetting.UserId,
|
||||
&keyString,
|
||||
&valueString,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
userSetting.Key = storepb.UserSettingKey(storepb.UserSettingKey_value[keyString])
|
||||
if userSetting.Key == storepb.UserSettingKey_USER_SETTING_ACCESS_TOKENS {
|
||||
accessTokensUserSetting := &storepb.AccessTokensUserSetting{}
|
||||
if err := protojson.Unmarshal([]byte(valueString), accessTokensUserSetting); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
userSetting.Value = &storepb.UserSetting_AccessTokens{
|
||||
AccessTokens: accessTokensUserSetting,
|
||||
}
|
||||
} else if userSetting.Key == storepb.UserSettingKey_USER_SETTING_LOCALE {
|
||||
userSetting.Value = &storepb.UserSetting_Locale{
|
||||
Locale: storepb.LocaleUserSetting(storepb.LocaleUserSetting_value[valueString]),
|
||||
}
|
||||
} else if userSetting.Key == storepb.UserSettingKey_USER_SETTING_COLOR_THEME {
|
||||
userSetting.Value = &storepb.UserSetting_ColorTheme{
|
||||
ColorTheme: storepb.ColorThemeUserSetting(storepb.ColorThemeUserSetting_value[valueString]),
|
||||
}
|
||||
} else {
|
||||
return nil, errors.New("invalid user setting key")
|
||||
}
|
||||
userSettingList = append(userSettingList, userSetting)
|
||||
}
|
||||
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return userSettingList, nil
|
||||
}
|
||||
|
||||
func vacuumUserSetting(ctx context.Context, tx *sql.Tx) error {
|
||||
stmt := `
|
||||
DELETE FROM
|
||||
user_setting
|
||||
WHERE
|
||||
user_id NOT IN (
|
||||
SELECT
|
||||
id
|
||||
FROM
|
||||
user
|
||||
)`
|
||||
_, err := tx.ExecContext(ctx, stmt)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
116
store/db/sqlite/workspace_setting.go
Normal file
116
store/db/sqlite/workspace_setting.go
Normal file
@ -0,0 +1,116 @@
|
||||
package sqlite
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"google.golang.org/protobuf/encoding/protojson"
|
||||
|
||||
storepb "github.com/yourselfhosted/slash/proto/gen/store"
|
||||
"github.com/yourselfhosted/slash/store"
|
||||
)
|
||||
|
||||
func (d *DB) UpsertWorkspaceSetting(ctx context.Context, upsert *storepb.WorkspaceSetting) (*storepb.WorkspaceSetting, error) {
|
||||
stmt := `
|
||||
INSERT INTO workspace_setting (
|
||||
key,
|
||||
value
|
||||
)
|
||||
VALUES (?, ?)
|
||||
ON CONFLICT(key) DO UPDATE
|
||||
SET value = EXCLUDED.value
|
||||
`
|
||||
var valueString string
|
||||
if upsert.Key == storepb.WorkspaceSettingKey_WORKSPACE_SETTING_LICENSE_KEY {
|
||||
valueString = upsert.GetLicenseKey()
|
||||
} else if upsert.Key == storepb.WorkspaceSettingKey_WORKSPACE_SETTING_SECRET_SESSION {
|
||||
valueString = upsert.GetSecretSession()
|
||||
} else if upsert.Key == storepb.WorkspaceSettingKey_WORKSAPCE_SETTING_ENABLE_SIGNUP {
|
||||
valueString = strconv.FormatBool(upsert.GetEnableSignup())
|
||||
} else if upsert.Key == storepb.WorkspaceSettingKey_WORKSPACE_SETTING_CUSTOM_STYLE {
|
||||
valueString = upsert.GetCustomStyle()
|
||||
} else if upsert.Key == storepb.WorkspaceSettingKey_WORKSPACE_SETTING_CUSTOM_SCRIPT {
|
||||
valueString = upsert.GetCustomScript()
|
||||
} else if upsert.Key == storepb.WorkspaceSettingKey_WORKSPACE_SETTING_AUTO_BACKUP {
|
||||
valueBytes, err := protojson.Marshal(upsert.GetAutoBackup())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
valueString = string(valueBytes)
|
||||
} else {
|
||||
return nil, errors.New("invalid workspace setting key")
|
||||
}
|
||||
|
||||
if _, err := d.db.ExecContext(ctx, stmt, upsert.Key.String(), valueString); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
workspaceSetting := upsert
|
||||
return workspaceSetting, nil
|
||||
}
|
||||
|
||||
func (d *DB) ListWorkspaceSettings(ctx context.Context, find *store.FindWorkspaceSetting) ([]*storepb.WorkspaceSetting, error) {
|
||||
where, args := []string{"1 = 1"}, []any{}
|
||||
|
||||
if find.Key != storepb.WorkspaceSettingKey_WORKSPACE_SETTING_KEY_UNSPECIFIED {
|
||||
where, args = append(where, "key = ?"), append(args, find.Key.String())
|
||||
}
|
||||
|
||||
query := `
|
||||
SELECT
|
||||
key,
|
||||
value
|
||||
FROM workspace_setting
|
||||
WHERE ` + strings.Join(where, " AND ")
|
||||
rows, err := d.db.QueryContext(ctx, query, args...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
defer rows.Close()
|
||||
|
||||
list := []*storepb.WorkspaceSetting{}
|
||||
for rows.Next() {
|
||||
workspaceSetting := &storepb.WorkspaceSetting{}
|
||||
var keyString, valueString string
|
||||
if err := rows.Scan(
|
||||
&keyString,
|
||||
&valueString,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
workspaceSetting.Key = storepb.WorkspaceSettingKey(storepb.WorkspaceSettingKey_value[keyString])
|
||||
if workspaceSetting.Key == storepb.WorkspaceSettingKey_WORKSPACE_SETTING_LICENSE_KEY {
|
||||
workspaceSetting.Value = &storepb.WorkspaceSetting_LicenseKey{LicenseKey: valueString}
|
||||
} else if workspaceSetting.Key == storepb.WorkspaceSettingKey_WORKSPACE_SETTING_SECRET_SESSION {
|
||||
workspaceSetting.Value = &storepb.WorkspaceSetting_SecretSession{SecretSession: valueString}
|
||||
} else if workspaceSetting.Key == storepb.WorkspaceSettingKey_WORKSAPCE_SETTING_ENABLE_SIGNUP {
|
||||
enableSignup, err := strconv.ParseBool(valueString)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
workspaceSetting.Value = &storepb.WorkspaceSetting_EnableSignup{EnableSignup: enableSignup}
|
||||
} else if workspaceSetting.Key == storepb.WorkspaceSettingKey_WORKSPACE_SETTING_CUSTOM_STYLE {
|
||||
workspaceSetting.Value = &storepb.WorkspaceSetting_CustomStyle{CustomStyle: valueString}
|
||||
} else if workspaceSetting.Key == storepb.WorkspaceSettingKey_WORKSPACE_SETTING_CUSTOM_SCRIPT {
|
||||
workspaceSetting.Value = &storepb.WorkspaceSetting_CustomScript{CustomScript: valueString}
|
||||
} else if workspaceSetting.Key == storepb.WorkspaceSettingKey_WORKSPACE_SETTING_AUTO_BACKUP {
|
||||
autoBackupSetting := &storepb.AutoBackupWorkspaceSetting{}
|
||||
if err := protojson.Unmarshal([]byte(valueString), autoBackupSetting); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
workspaceSetting.Value = &storepb.WorkspaceSetting_AutoBackup{AutoBackup: autoBackupSetting}
|
||||
} else {
|
||||
continue
|
||||
}
|
||||
list = append(list, workspaceSetting)
|
||||
}
|
||||
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return list, nil
|
||||
}
|
57
store/driver.go
Normal file
57
store/driver.go
Normal file
@ -0,0 +1,57 @@
|
||||
package store
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
|
||||
storepb "github.com/yourselfhosted/slash/proto/gen/store"
|
||||
)
|
||||
|
||||
// Driver is an interface for store driver.
|
||||
// It contains all methods that store database driver should implement.
|
||||
type Driver interface {
|
||||
GetDB() *sql.DB
|
||||
Close() error
|
||||
|
||||
Migrate(ctx context.Context) error
|
||||
|
||||
// MigrationHistory model related methods.
|
||||
UpsertMigrationHistory(ctx context.Context, upsert *UpsertMigrationHistory) (*MigrationHistory, error)
|
||||
ListMigrationHistories(ctx context.Context, find *FindMigrationHistory) ([]*MigrationHistory, error)
|
||||
|
||||
// Activity model related methods.
|
||||
CreateActivity(ctx context.Context, create *Activity) (*Activity, error)
|
||||
ListActivities(ctx context.Context, find *FindActivity) ([]*Activity, error)
|
||||
|
||||
// Collection model related methods.
|
||||
CreateCollection(ctx context.Context, create *storepb.Collection) (*storepb.Collection, error)
|
||||
UpdateCollection(ctx context.Context, update *UpdateCollection) (*storepb.Collection, error)
|
||||
ListCollections(ctx context.Context, find *FindCollection) ([]*storepb.Collection, error)
|
||||
DeleteCollection(ctx context.Context, delete *DeleteCollection) error
|
||||
|
||||
// Memo model related methods.
|
||||
CreateMemo(ctx context.Context, create *storepb.Memo) (*storepb.Memo, error)
|
||||
UpdateMemo(ctx context.Context, update *UpdateMemo) (*storepb.Memo, error)
|
||||
ListMemos(ctx context.Context, find *FindMemo) ([]*storepb.Memo, error)
|
||||
DeleteMemo(ctx context.Context, delete *DeleteMemo) error
|
||||
|
||||
// Shortcut model related methods.
|
||||
CreateShortcut(ctx context.Context, create *storepb.Shortcut) (*storepb.Shortcut, error)
|
||||
UpdateShortcut(ctx context.Context, update *UpdateShortcut) (*storepb.Shortcut, error)
|
||||
ListShortcuts(ctx context.Context, find *FindShortcut) ([]*storepb.Shortcut, error)
|
||||
DeleteShortcut(ctx context.Context, delete *DeleteShortcut) error
|
||||
|
||||
// User model related methods.
|
||||
CreateUser(ctx context.Context, create *User) (*User, error)
|
||||
UpdateUser(ctx context.Context, update *UpdateUser) (*User, error)
|
||||
ListUsers(ctx context.Context, find *FindUser) ([]*User, error)
|
||||
DeleteUser(ctx context.Context, delete *DeleteUser) error
|
||||
|
||||
// UserSetting model related methods.
|
||||
UpsertUserSetting(ctx context.Context, upsert *storepb.UserSetting) (*storepb.UserSetting, error)
|
||||
ListUserSettings(ctx context.Context, find *FindUserSetting) ([]*storepb.UserSetting, error)
|
||||
|
||||
// WorkspaceSetting model related methods.
|
||||
UpsertWorkspaceSetting(ctx context.Context, upsert *storepb.WorkspaceSetting) (*storepb.WorkspaceSetting, error)
|
||||
ListWorkspaceSettings(ctx context.Context, find *FindWorkspaceSetting) ([]*storepb.WorkspaceSetting, error)
|
||||
}
|
186
store/memo.go
186
store/memo.go
@ -2,11 +2,6 @@ package store
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
|
||||
storepb "github.com/yourselfhosted/slash/proto/gen/store"
|
||||
)
|
||||
@ -35,162 +30,15 @@ type DeleteMemo struct {
|
||||
}
|
||||
|
||||
func (s *Store) CreateMemo(ctx context.Context, create *storepb.Memo) (*storepb.Memo, error) {
|
||||
set := []string{"creator_id", "name", "title", "content", "visibility", "tag"}
|
||||
args := []any{create.CreatorId, create.Name, create.Title, create.Content, create.Visibility.String(), strings.Join(create.Tags, " ")}
|
||||
|
||||
stmt := `
|
||||
INSERT INTO memo (
|
||||
` + strings.Join(set, ", ") + `
|
||||
)
|
||||
VALUES (` + placeholders(len(args)) + `)
|
||||
RETURNING id, created_ts, updated_ts, row_status
|
||||
`
|
||||
|
||||
var rowStatus string
|
||||
if err := s.db.QueryRowContext(ctx, stmt, args...).Scan(
|
||||
&create.Id,
|
||||
&create.CreatedTs,
|
||||
&create.UpdatedTs,
|
||||
&rowStatus,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
create.RowStatus = convertRowStatusStringToStorepb(rowStatus)
|
||||
memo := create
|
||||
return memo, nil
|
||||
return s.driver.CreateMemo(ctx, create)
|
||||
}
|
||||
|
||||
func (s *Store) UpdateMemo(ctx context.Context, update *UpdateMemo) (*storepb.Memo, error) {
|
||||
set, args := []string{}, []any{}
|
||||
if update.RowStatus != nil {
|
||||
set, args = append(set, "row_status = ?"), append(args, update.RowStatus.String())
|
||||
}
|
||||
if update.Name != nil {
|
||||
set, args = append(set, "name = ?"), append(args, *update.Name)
|
||||
}
|
||||
if update.Title != nil {
|
||||
set, args = append(set, "title = ?"), append(args, *update.Title)
|
||||
}
|
||||
if update.Content != nil {
|
||||
set, args = append(set, "content = ?"), append(args, *update.Content)
|
||||
}
|
||||
if update.Visibility != nil {
|
||||
set, args = append(set, "visibility = ?"), append(args, update.Visibility.String())
|
||||
}
|
||||
if update.Tag != nil {
|
||||
set, args = append(set, "tag = ?"), append(args, *update.Tag)
|
||||
}
|
||||
if len(set) == 0 {
|
||||
return nil, errors.New("no update specified")
|
||||
}
|
||||
args = append(args, update.ID)
|
||||
|
||||
stmt := `
|
||||
UPDATE memo
|
||||
SET
|
||||
` + strings.Join(set, ", ") + `
|
||||
WHERE
|
||||
id = ?
|
||||
RETURNING id, creator_id, created_ts, updated_ts, row_status, name, title, content, visibility, tag
|
||||
`
|
||||
memo := &storepb.Memo{}
|
||||
var rowStatus, visibility, tags string
|
||||
if err := s.db.QueryRowContext(ctx, stmt, args...).Scan(
|
||||
&memo.Id,
|
||||
&memo.CreatorId,
|
||||
&memo.CreatedTs,
|
||||
&memo.UpdatedTs,
|
||||
&rowStatus,
|
||||
&memo.Name,
|
||||
&memo.Title,
|
||||
&memo.Content,
|
||||
&visibility,
|
||||
&tags,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
memo.RowStatus = convertRowStatusStringToStorepb(rowStatus)
|
||||
memo.Visibility = convertVisibilityStringToStorepb(visibility)
|
||||
memo.Tags = filterTags(strings.Split(tags, " "))
|
||||
return memo, nil
|
||||
return s.driver.UpdateMemo(ctx, update)
|
||||
}
|
||||
|
||||
func (s *Store) ListMemos(ctx context.Context, find *FindMemo) ([]*storepb.Memo, error) {
|
||||
where, args := []string{"1 = 1"}, []any{}
|
||||
if v := find.ID; v != nil {
|
||||
where, args = append(where, "id = ?"), append(args, *v)
|
||||
}
|
||||
if v := find.CreatorID; v != nil {
|
||||
where, args = append(where, "creator_id = ?"), append(args, *v)
|
||||
}
|
||||
if v := find.RowStatus; v != nil {
|
||||
where, args = append(where, "row_status = ?"), append(args, *v)
|
||||
}
|
||||
if v := find.Name; v != nil {
|
||||
where, args = append(where, "name = ?"), append(args, *v)
|
||||
}
|
||||
if v := find.VisibilityList; len(v) != 0 {
|
||||
list := []string{}
|
||||
for _, visibility := range v {
|
||||
list = append(list, fmt.Sprintf("$%d", len(args)+1))
|
||||
args = append(args, visibility)
|
||||
}
|
||||
where = append(where, fmt.Sprintf("visibility in (%s)", strings.Join(list, ",")))
|
||||
}
|
||||
if v := find.Tag; v != nil {
|
||||
where, args = append(where, "tag LIKE ?"), append(args, "%"+*v+"%")
|
||||
}
|
||||
|
||||
rows, err := s.db.QueryContext(ctx, `
|
||||
SELECT
|
||||
id,
|
||||
creator_id,
|
||||
created_ts,
|
||||
updated_ts,
|
||||
row_status,
|
||||
name,
|
||||
title,
|
||||
content,
|
||||
visibility,
|
||||
tag
|
||||
FROM memo
|
||||
WHERE `+strings.Join(where, " AND ")+`
|
||||
ORDER BY created_ts DESC`,
|
||||
args...,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
list := make([]*storepb.Memo, 0)
|
||||
for rows.Next() {
|
||||
memo := &storepb.Memo{}
|
||||
var rowStatus, visibility, tags string
|
||||
if err := rows.Scan(
|
||||
&memo.Id,
|
||||
&memo.CreatorId,
|
||||
&memo.CreatedTs,
|
||||
&memo.UpdatedTs,
|
||||
&rowStatus,
|
||||
&memo.Name,
|
||||
&memo.Title,
|
||||
&memo.Content,
|
||||
&visibility,
|
||||
&tags,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
memo.RowStatus = convertRowStatusStringToStorepb(rowStatus)
|
||||
memo.Visibility = storepb.Visibility(storepb.Visibility_value[visibility])
|
||||
memo.Tags = filterTags(strings.Split(tags, " "))
|
||||
list = append(list, memo)
|
||||
}
|
||||
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return list, nil
|
||||
return s.driver.ListMemos(ctx, find)
|
||||
}
|
||||
|
||||
func (s *Store) GetMemo(ctx context.Context, find *FindMemo) (*storepb.Memo, error) {
|
||||
@ -208,31 +56,5 @@ func (s *Store) GetMemo(ctx context.Context, find *FindMemo) (*storepb.Memo, err
|
||||
}
|
||||
|
||||
func (s *Store) DeleteMemo(ctx context.Context, delete *DeleteMemo) error {
|
||||
if _, err := s.db.ExecContext(ctx, `DELETE FROM memo WHERE id = ?`, delete.ID); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func vacuumMemo(ctx context.Context, tx *sql.Tx) error {
|
||||
stmt := `
|
||||
DELETE FROM
|
||||
memo
|
||||
WHERE
|
||||
creator_id NOT IN (
|
||||
SELECT
|
||||
id
|
||||
FROM
|
||||
user
|
||||
)`
|
||||
_, err := tx.ExecContext(ctx, stmt)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func placeholders(n int) string {
|
||||
return strings.Repeat("?,", n-1) + "?"
|
||||
return s.driver.DeleteMemo(ctx, delete)
|
||||
}
|
||||
|
13
store/migration_history.go
Normal file
13
store/migration_history.go
Normal file
@ -0,0 +1,13 @@
|
||||
package store
|
||||
|
||||
type MigrationHistory struct {
|
||||
Version string
|
||||
CreatedTs int64
|
||||
}
|
||||
|
||||
type UpsertMigrationHistory struct {
|
||||
Version string
|
||||
}
|
||||
|
||||
type FindMigrationHistory struct {
|
||||
}
|
@ -2,12 +2,6 @@ package store
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"google.golang.org/protobuf/encoding/protojson"
|
||||
|
||||
storepb "github.com/yourselfhosted/slash/proto/gen/store"
|
||||
)
|
||||
@ -39,198 +33,28 @@ type DeleteShortcut struct {
|
||||
}
|
||||
|
||||
func (s *Store) CreateShortcut(ctx context.Context, create *storepb.Shortcut) (*storepb.Shortcut, error) {
|
||||
set := []string{"creator_id", "name", "link", "title", "description", "visibility", "tag"}
|
||||
args := []any{create.CreatorId, create.Name, create.Link, create.Title, create.Description, create.Visibility.String(), strings.Join(create.Tags, " ")}
|
||||
placeholder := []string{"?", "?", "?", "?", "?", "?", "?"}
|
||||
if create.OgMetadata != nil {
|
||||
set = append(set, "og_metadata")
|
||||
openGraphMetadataBytes, err := protojson.Marshal(create.OgMetadata)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
args = append(args, string(openGraphMetadataBytes))
|
||||
placeholder = append(placeholder, "?")
|
||||
}
|
||||
|
||||
stmt := `
|
||||
INSERT INTO shortcut (
|
||||
` + strings.Join(set, ", ") + `
|
||||
)
|
||||
VALUES (` + strings.Join(placeholder, ",") + `)
|
||||
RETURNING id, created_ts, updated_ts, row_status
|
||||
`
|
||||
var rowStatus string
|
||||
if err := s.db.QueryRowContext(ctx, stmt, args...).Scan(
|
||||
&create.Id,
|
||||
&create.CreatedTs,
|
||||
&create.UpdatedTs,
|
||||
&rowStatus,
|
||||
); err != nil {
|
||||
shortcut, err := s.driver.CreateShortcut(ctx, create)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
create.RowStatus = convertRowStatusStringToStorepb(rowStatus)
|
||||
shortcut := create
|
||||
s.shortcutCache.Store(shortcut.Id, shortcut)
|
||||
return shortcut, nil
|
||||
}
|
||||
|
||||
func (s *Store) UpdateShortcut(ctx context.Context, update *UpdateShortcut) (*storepb.Shortcut, error) {
|
||||
set, args := []string{}, []any{}
|
||||
if update.RowStatus != nil {
|
||||
set, args = append(set, "row_status = ?"), append(args, update.RowStatus.String())
|
||||
}
|
||||
if update.Name != nil {
|
||||
set, args = append(set, "name = ?"), append(args, *update.Name)
|
||||
}
|
||||
if update.Link != nil {
|
||||
set, args = append(set, "link = ?"), append(args, *update.Link)
|
||||
}
|
||||
if update.Title != nil {
|
||||
set, args = append(set, "title = ?"), append(args, *update.Title)
|
||||
}
|
||||
if update.Description != nil {
|
||||
set, args = append(set, "description = ?"), append(args, *update.Description)
|
||||
}
|
||||
if update.Visibility != nil {
|
||||
set, args = append(set, "visibility = ?"), append(args, update.Visibility.String())
|
||||
}
|
||||
if update.Tag != nil {
|
||||
set, args = append(set, "tag = ?"), append(args, *update.Tag)
|
||||
}
|
||||
if update.OpenGraphMetadata != nil {
|
||||
openGraphMetadataBytes, err := protojson.Marshal(update.OpenGraphMetadata)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "Failed to marshal activity payload")
|
||||
}
|
||||
set, args = append(set, "og_metadata = ?"), append(args, string(openGraphMetadataBytes))
|
||||
}
|
||||
if len(set) == 0 {
|
||||
return nil, errors.New("no update specified")
|
||||
}
|
||||
args = append(args, update.ID)
|
||||
|
||||
stmt := `
|
||||
UPDATE shortcut
|
||||
SET
|
||||
` + strings.Join(set, ", ") + `
|
||||
WHERE
|
||||
id = ?
|
||||
RETURNING id, creator_id, created_ts, updated_ts, row_status, name, link, title, description, visibility, tag, og_metadata
|
||||
`
|
||||
shortcut := &storepb.Shortcut{}
|
||||
var rowStatus, visibility, tags, openGraphMetadataString string
|
||||
if err := s.db.QueryRowContext(ctx, stmt, args...).Scan(
|
||||
&shortcut.Id,
|
||||
&shortcut.CreatorId,
|
||||
&shortcut.CreatedTs,
|
||||
&shortcut.UpdatedTs,
|
||||
&rowStatus,
|
||||
&shortcut.Name,
|
||||
&shortcut.Link,
|
||||
&shortcut.Title,
|
||||
&shortcut.Description,
|
||||
&visibility,
|
||||
&tags,
|
||||
&openGraphMetadataString,
|
||||
); err != nil {
|
||||
shortcut, err := s.driver.UpdateShortcut(ctx, update)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
shortcut.RowStatus = convertRowStatusStringToStorepb(rowStatus)
|
||||
shortcut.Visibility = convertVisibilityStringToStorepb(visibility)
|
||||
shortcut.Tags = filterTags(strings.Split(tags, " "))
|
||||
var ogMetadata storepb.OpenGraphMetadata
|
||||
if err := protojson.Unmarshal([]byte(openGraphMetadataString), &ogMetadata); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
shortcut.OgMetadata = &ogMetadata
|
||||
s.shortcutCache.Store(shortcut.Id, shortcut)
|
||||
return shortcut, nil
|
||||
}
|
||||
|
||||
func (s *Store) ListShortcuts(ctx context.Context, find *FindShortcut) ([]*storepb.Shortcut, error) {
|
||||
where, args := []string{"1 = 1"}, []any{}
|
||||
if v := find.ID; v != nil {
|
||||
where, args = append(where, "id = ?"), append(args, *v)
|
||||
}
|
||||
if v := find.CreatorID; v != nil {
|
||||
where, args = append(where, "creator_id = ?"), append(args, *v)
|
||||
}
|
||||
if v := find.RowStatus; v != nil {
|
||||
where, args = append(where, "row_status = ?"), append(args, *v)
|
||||
}
|
||||
if v := find.Name; v != nil {
|
||||
where, args = append(where, "name = ?"), append(args, *v)
|
||||
}
|
||||
if v := find.VisibilityList; len(v) != 0 {
|
||||
list := []string{}
|
||||
for _, visibility := range v {
|
||||
list = append(list, fmt.Sprintf("$%d", len(args)+1))
|
||||
args = append(args, visibility)
|
||||
}
|
||||
where = append(where, fmt.Sprintf("visibility in (%s)", strings.Join(list, ",")))
|
||||
}
|
||||
if v := find.Tag; v != nil {
|
||||
where, args = append(where, "tag LIKE ?"), append(args, "%"+*v+"%")
|
||||
}
|
||||
|
||||
rows, err := s.db.QueryContext(ctx, `
|
||||
SELECT
|
||||
id,
|
||||
creator_id,
|
||||
created_ts,
|
||||
updated_ts,
|
||||
row_status,
|
||||
name,
|
||||
link,
|
||||
title,
|
||||
description,
|
||||
visibility,
|
||||
tag,
|
||||
og_metadata
|
||||
FROM shortcut
|
||||
WHERE `+strings.Join(where, " AND ")+`
|
||||
ORDER BY created_ts DESC`,
|
||||
args...,
|
||||
)
|
||||
list, err := s.driver.ListShortcuts(ctx, find)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
list := make([]*storepb.Shortcut, 0)
|
||||
for rows.Next() {
|
||||
shortcut := &storepb.Shortcut{}
|
||||
var rowStatus, visibility, tags, openGraphMetadataString string
|
||||
if err := rows.Scan(
|
||||
&shortcut.Id,
|
||||
&shortcut.CreatorId,
|
||||
&shortcut.CreatedTs,
|
||||
&shortcut.UpdatedTs,
|
||||
&rowStatus,
|
||||
&shortcut.Name,
|
||||
&shortcut.Link,
|
||||
&shortcut.Title,
|
||||
&shortcut.Description,
|
||||
&visibility,
|
||||
&tags,
|
||||
&openGraphMetadataString,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
shortcut.RowStatus = convertRowStatusStringToStorepb(rowStatus)
|
||||
shortcut.Visibility = storepb.Visibility(storepb.Visibility_value[visibility])
|
||||
shortcut.Tags = filterTags(strings.Split(tags, " "))
|
||||
var ogMetadata storepb.OpenGraphMetadata
|
||||
if err := protojson.Unmarshal([]byte(openGraphMetadataString), &ogMetadata); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
shortcut.OgMetadata = &ogMetadata
|
||||
list = append(list, shortcut)
|
||||
}
|
||||
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for _, shortcut := range list {
|
||||
s.shortcutCache.Store(shortcut.Id, shortcut)
|
||||
}
|
||||
@ -259,44 +83,10 @@ func (s *Store) GetShortcut(ctx context.Context, find *FindShortcut) (*storepb.S
|
||||
}
|
||||
|
||||
func (s *Store) DeleteShortcut(ctx context.Context, delete *DeleteShortcut) error {
|
||||
if _, err := s.db.ExecContext(ctx, `DELETE FROM shortcut WHERE id = ?`, delete.ID); err != nil {
|
||||
if err := s.driver.DeleteShortcut(ctx, delete); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
s.shortcutCache.Delete(delete.ID)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func vacuumShortcut(ctx context.Context, tx *sql.Tx) error {
|
||||
stmt := `
|
||||
DELETE FROM
|
||||
shortcut
|
||||
WHERE
|
||||
creator_id NOT IN (
|
||||
SELECT
|
||||
id
|
||||
FROM
|
||||
user
|
||||
)`
|
||||
_, err := tx.ExecContext(ctx, stmt)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func filterTags(tags []string) []string {
|
||||
result := []string{}
|
||||
for _, tag := range tags {
|
||||
if tag != "" {
|
||||
result = append(result, tag)
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func convertVisibilityStringToStorepb(visibility string) storepb.Visibility {
|
||||
return storepb.Visibility(storepb.Visibility_value[visibility])
|
||||
}
|
||||
|
@ -1,7 +1,6 @@
|
||||
package store
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"sync"
|
||||
|
||||
"github.com/yourselfhosted/slash/server/profile"
|
||||
@ -9,8 +8,8 @@ import (
|
||||
|
||||
// Store provides database access to all raw objects.
|
||||
type Store struct {
|
||||
db *sql.DB
|
||||
profile *profile.Profile
|
||||
driver Driver
|
||||
|
||||
workspaceSettingCache sync.Map // map[string]*WorkspaceSetting
|
||||
userCache sync.Map // map[int]*User
|
||||
@ -19,14 +18,14 @@ type Store struct {
|
||||
}
|
||||
|
||||
// New creates a new instance of Store.
|
||||
func New(db *sql.DB, profile *profile.Profile) *Store {
|
||||
func New(driver Driver, profile *profile.Profile) *Store {
|
||||
return &Store{
|
||||
db: db,
|
||||
driver: driver,
|
||||
profile: profile,
|
||||
}
|
||||
}
|
||||
|
||||
// Close closes the database connection.
|
||||
func (s *Store) Close() error {
|
||||
return s.db.Close()
|
||||
return s.driver.Close()
|
||||
}
|
||||
|
155
store/user.go
155
store/user.go
@ -2,8 +2,6 @@ package store
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// Role is the type of a role.
|
||||
@ -54,147 +52,31 @@ type DeleteUser struct {
|
||||
}
|
||||
|
||||
func (s *Store) CreateUser(ctx context.Context, create *User) (*User, error) {
|
||||
stmt := `
|
||||
INSERT INTO user (
|
||||
email,
|
||||
nickname,
|
||||
password_hash,
|
||||
role
|
||||
)
|
||||
VALUES (?, ?, ?, ?)
|
||||
RETURNING id, created_ts, updated_ts, row_status
|
||||
`
|
||||
if err := s.db.QueryRowContext(ctx, stmt,
|
||||
create.Email,
|
||||
create.Nickname,
|
||||
create.PasswordHash,
|
||||
create.Role,
|
||||
).Scan(
|
||||
&create.ID,
|
||||
&create.CreatedTs,
|
||||
&create.UpdatedTs,
|
||||
&create.RowStatus,
|
||||
); err != nil {
|
||||
user, err := s.driver.CreateUser(ctx, create)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
user := create
|
||||
s.userCache.Store(user.ID, user)
|
||||
return user, nil
|
||||
}
|
||||
|
||||
func (s *Store) UpdateUser(ctx context.Context, update *UpdateUser) (*User, error) {
|
||||
set, args := []string{}, []any{}
|
||||
if v := update.RowStatus; v != nil {
|
||||
set, args = append(set, "row_status = ?"), append(args, *v)
|
||||
}
|
||||
if v := update.Email; v != nil {
|
||||
set, args = append(set, "email = ?"), append(args, *v)
|
||||
}
|
||||
if v := update.Nickname; v != nil {
|
||||
set, args = append(set, "nickname = ?"), append(args, *v)
|
||||
}
|
||||
if v := update.PasswordHash; v != nil {
|
||||
set, args = append(set, "password_hash = ?"), append(args, *v)
|
||||
}
|
||||
if v := update.Role; v != nil {
|
||||
set, args = append(set, "role = ?"), append(args, *v)
|
||||
}
|
||||
|
||||
if len(set) == 0 {
|
||||
return nil, errors.New("no fields to update")
|
||||
}
|
||||
|
||||
stmt := `
|
||||
UPDATE user
|
||||
SET ` + strings.Join(set, ", ") + `
|
||||
WHERE id = ?
|
||||
RETURNING id, created_ts, updated_ts, row_status, email, nickname, password_hash, role
|
||||
`
|
||||
args = append(args, update.ID)
|
||||
user := &User{}
|
||||
if err := s.db.QueryRowContext(ctx, stmt, args...).Scan(
|
||||
&user.ID,
|
||||
&user.CreatedTs,
|
||||
&user.UpdatedTs,
|
||||
&user.RowStatus,
|
||||
&user.Email,
|
||||
&user.Nickname,
|
||||
&user.PasswordHash,
|
||||
&user.Role,
|
||||
); err != nil {
|
||||
user, err := s.driver.UpdateUser(ctx, update)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
s.userCache.Store(user.ID, user)
|
||||
return user, nil
|
||||
}
|
||||
|
||||
func (s *Store) ListUsers(ctx context.Context, find *FindUser) ([]*User, error) {
|
||||
where, args := []string{"1 = 1"}, []any{}
|
||||
|
||||
if v := find.ID; v != nil {
|
||||
where, args = append(where, "id = ?"), append(args, *v)
|
||||
}
|
||||
if v := find.RowStatus; v != nil {
|
||||
where, args = append(where, "row_status = ?"), append(args, v.String())
|
||||
}
|
||||
if v := find.Email; v != nil {
|
||||
where, args = append(where, "email = ?"), append(args, *v)
|
||||
}
|
||||
if v := find.Nickname; v != nil {
|
||||
where, args = append(where, "nickname = ?"), append(args, *v)
|
||||
}
|
||||
if v := find.Role; v != nil {
|
||||
where, args = append(where, "role = ?"), append(args, *v)
|
||||
}
|
||||
|
||||
query := `
|
||||
SELECT
|
||||
id,
|
||||
created_ts,
|
||||
updated_ts,
|
||||
row_status,
|
||||
email,
|
||||
nickname,
|
||||
password_hash,
|
||||
role
|
||||
FROM user
|
||||
WHERE ` + strings.Join(where, " AND ") + `
|
||||
ORDER BY updated_ts DESC, created_ts DESC
|
||||
`
|
||||
rows, err := s.db.QueryContext(ctx, query, args...)
|
||||
list, err := s.driver.ListUsers(ctx, find)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
list := make([]*User, 0)
|
||||
for rows.Next() {
|
||||
user := &User{}
|
||||
if err := rows.Scan(
|
||||
&user.ID,
|
||||
&user.CreatedTs,
|
||||
&user.UpdatedTs,
|
||||
&user.RowStatus,
|
||||
&user.Email,
|
||||
&user.Nickname,
|
||||
&user.PasswordHash,
|
||||
&user.Role,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
list = append(list, user)
|
||||
}
|
||||
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for _, user := range list {
|
||||
s.userCache.Store(user.ID, user)
|
||||
}
|
||||
|
||||
return list, nil
|
||||
}
|
||||
|
||||
@ -218,35 +100,10 @@ func (s *Store) GetUser(ctx context.Context, find *FindUser) (*User, error) {
|
||||
}
|
||||
|
||||
func (s *Store) DeleteUser(ctx context.Context, delete *DeleteUser) error {
|
||||
tx, err := s.db.BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer tx.Rollback()
|
||||
|
||||
if _, err := tx.ExecContext(ctx, `
|
||||
DELETE FROM user WHERE id = ?
|
||||
`, delete.ID); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := vacuumUserSetting(ctx, tx); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := vacuumShortcut(ctx, tx); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := vacuumMemo(ctx, tx); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := tx.Commit(); err != nil {
|
||||
if err := s.driver.DeleteUser(ctx, delete); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
s.userCache.Delete(delete.ID)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
@ -2,11 +2,6 @@ package store
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"strings"
|
||||
|
||||
"google.golang.org/protobuf/encoding/protojson"
|
||||
|
||||
storepb "github.com/yourselfhosted/slash/proto/gen/store"
|
||||
)
|
||||
@ -17,96 +12,17 @@ type FindUserSetting struct {
|
||||
}
|
||||
|
||||
func (s *Store) UpsertUserSetting(ctx context.Context, upsert *storepb.UserSetting) (*storepb.UserSetting, error) {
|
||||
stmt := `
|
||||
INSERT INTO user_setting (
|
||||
user_id, key, value
|
||||
)
|
||||
VALUES (?, ?, ?)
|
||||
ON CONFLICT(user_id, key) DO UPDATE
|
||||
SET value = EXCLUDED.value
|
||||
`
|
||||
var valueString string
|
||||
if upsert.Key == storepb.UserSettingKey_USER_SETTING_ACCESS_TOKENS {
|
||||
valueBytes, err := protojson.Marshal(upsert.GetAccessTokens())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
valueString = string(valueBytes)
|
||||
} else if upsert.Key == storepb.UserSettingKey_USER_SETTING_LOCALE {
|
||||
valueString = upsert.GetLocale().String()
|
||||
} else if upsert.Key == storepb.UserSettingKey_USER_SETTING_COLOR_THEME {
|
||||
valueString = upsert.GetColorTheme().String()
|
||||
} else {
|
||||
return nil, errors.New("invalid user setting key")
|
||||
}
|
||||
|
||||
if _, err := s.db.ExecContext(ctx, stmt, upsert.UserId, upsert.Key.String(), valueString); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
userSettingMessage := upsert
|
||||
s.userSettingCache.Store(getUserSettingCacheKey(userSettingMessage.UserId, userSettingMessage.Key.String()), userSettingMessage)
|
||||
return userSettingMessage, nil
|
||||
}
|
||||
|
||||
func (s *Store) ListUserSettings(ctx context.Context, find *FindUserSetting) ([]*storepb.UserSetting, error) {
|
||||
where, args := []string{"1 = 1"}, []any{}
|
||||
|
||||
if v := find.Key; v != storepb.UserSettingKey_USER_SETTING_KEY_UNSPECIFIED {
|
||||
where, args = append(where, "key = ?"), append(args, v.String())
|
||||
}
|
||||
if v := find.UserID; v != nil {
|
||||
where, args = append(where, "user_id = ?"), append(args, *find.UserID)
|
||||
}
|
||||
|
||||
query := `
|
||||
SELECT
|
||||
user_id,
|
||||
key,
|
||||
value
|
||||
FROM user_setting
|
||||
WHERE ` + strings.Join(where, " AND ")
|
||||
rows, err := s.db.QueryContext(ctx, query, args...)
|
||||
userSetting, err := s.driver.UpsertUserSetting(ctx, upsert)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
s.userSettingCache.Store(getUserSettingCacheKey(userSetting.UserId, userSetting.Key.String()), userSetting)
|
||||
return userSetting, nil
|
||||
}
|
||||
|
||||
userSettingList := make([]*storepb.UserSetting, 0)
|
||||
for rows.Next() {
|
||||
userSetting := &storepb.UserSetting{}
|
||||
var keyString, valueString string
|
||||
if err := rows.Scan(
|
||||
&userSetting.UserId,
|
||||
&keyString,
|
||||
&valueString,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
userSetting.Key = storepb.UserSettingKey(storepb.UserSettingKey_value[keyString])
|
||||
if userSetting.Key == storepb.UserSettingKey_USER_SETTING_ACCESS_TOKENS {
|
||||
accessTokensUserSetting := &storepb.AccessTokensUserSetting{}
|
||||
if err := protojson.Unmarshal([]byte(valueString), accessTokensUserSetting); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
userSetting.Value = &storepb.UserSetting_AccessTokens{
|
||||
AccessTokens: accessTokensUserSetting,
|
||||
}
|
||||
} else if userSetting.Key == storepb.UserSettingKey_USER_SETTING_LOCALE {
|
||||
userSetting.Value = &storepb.UserSetting_Locale{
|
||||
Locale: storepb.LocaleUserSetting(storepb.LocaleUserSetting_value[valueString]),
|
||||
}
|
||||
} else if userSetting.Key == storepb.UserSettingKey_USER_SETTING_COLOR_THEME {
|
||||
userSetting.Value = &storepb.UserSetting_ColorTheme{
|
||||
ColorTheme: storepb.ColorThemeUserSetting(storepb.ColorThemeUserSetting_value[valueString]),
|
||||
}
|
||||
} else {
|
||||
return nil, errors.New("invalid user setting key")
|
||||
}
|
||||
userSettingList = append(userSettingList, userSetting)
|
||||
}
|
||||
|
||||
if err := rows.Err(); err != nil {
|
||||
func (s *Store) ListUserSettings(ctx context.Context, find *FindUserSetting) ([]*storepb.UserSetting, error) {
|
||||
userSettingList, err := s.driver.ListUserSettings(ctx, find)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@ -137,25 +53,6 @@ func (s *Store) GetUserSetting(ctx context.Context, find *FindUserSetting) (*sto
|
||||
return userSetting, nil
|
||||
}
|
||||
|
||||
func vacuumUserSetting(ctx context.Context, tx *sql.Tx) error {
|
||||
stmt := `
|
||||
DELETE FROM
|
||||
user_setting
|
||||
WHERE
|
||||
user_id NOT IN (
|
||||
SELECT
|
||||
id
|
||||
FROM
|
||||
user
|
||||
)`
|
||||
_, err := tx.ExecContext(ctx, stmt)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetUserAccessTokens returns the access tokens of the user.
|
||||
func (s *Store) GetUserAccessTokens(ctx context.Context, userID int32) ([]*storepb.AccessTokensUserSetting_AccessToken, error) {
|
||||
userSetting, err := s.GetUserSetting(ctx, &FindUserSetting{
|
||||
|
@ -2,11 +2,6 @@ package store
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"google.golang.org/protobuf/encoding/protojson"
|
||||
|
||||
storepb "github.com/yourselfhosted/slash/proto/gen/store"
|
||||
)
|
||||
@ -16,110 +11,22 @@ type FindWorkspaceSetting struct {
|
||||
}
|
||||
|
||||
func (s *Store) UpsertWorkspaceSetting(ctx context.Context, upsert *storepb.WorkspaceSetting) (*storepb.WorkspaceSetting, error) {
|
||||
stmt := `
|
||||
INSERT INTO workspace_setting (
|
||||
key,
|
||||
value
|
||||
)
|
||||
VALUES (?, ?)
|
||||
ON CONFLICT(key) DO UPDATE
|
||||
SET value = EXCLUDED.value
|
||||
`
|
||||
var valueString string
|
||||
if upsert.Key == storepb.WorkspaceSettingKey_WORKSPACE_SETTING_LICENSE_KEY {
|
||||
valueString = upsert.GetLicenseKey()
|
||||
} else if upsert.Key == storepb.WorkspaceSettingKey_WORKSPACE_SETTING_SECRET_SESSION {
|
||||
valueString = upsert.GetSecretSession()
|
||||
} else if upsert.Key == storepb.WorkspaceSettingKey_WORKSAPCE_SETTING_ENABLE_SIGNUP {
|
||||
valueString = strconv.FormatBool(upsert.GetEnableSignup())
|
||||
} else if upsert.Key == storepb.WorkspaceSettingKey_WORKSPACE_SETTING_CUSTOM_STYLE {
|
||||
valueString = upsert.GetCustomStyle()
|
||||
} else if upsert.Key == storepb.WorkspaceSettingKey_WORKSPACE_SETTING_CUSTOM_SCRIPT {
|
||||
valueString = upsert.GetCustomScript()
|
||||
} else if upsert.Key == storepb.WorkspaceSettingKey_WORKSPACE_SETTING_AUTO_BACKUP {
|
||||
valueBytes, err := protojson.Marshal(upsert.GetAutoBackup())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
valueString = string(valueBytes)
|
||||
} else {
|
||||
return nil, errors.New("invalid workspace setting key")
|
||||
}
|
||||
|
||||
if _, err := s.db.ExecContext(ctx, stmt, upsert.Key.String(), valueString); err != nil {
|
||||
workspaceSetting, err := s.driver.UpsertWorkspaceSetting(ctx, upsert)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
workspaceSetting := upsert
|
||||
s.workspaceSettingCache.Store(workspaceSetting.Key, workspaceSetting)
|
||||
return workspaceSetting, nil
|
||||
}
|
||||
|
||||
func (s *Store) ListWorkspaceSettings(ctx context.Context, find *FindWorkspaceSetting) ([]*storepb.WorkspaceSetting, error) {
|
||||
where, args := []string{"1 = 1"}, []any{}
|
||||
|
||||
if find.Key != storepb.WorkspaceSettingKey_WORKSPACE_SETTING_KEY_UNSPECIFIED {
|
||||
where, args = append(where, "key = ?"), append(args, find.Key.String())
|
||||
}
|
||||
|
||||
query := `
|
||||
SELECT
|
||||
key,
|
||||
value
|
||||
FROM workspace_setting
|
||||
WHERE ` + strings.Join(where, " AND ")
|
||||
rows, err := s.db.QueryContext(ctx, query, args...)
|
||||
list, err := s.driver.ListWorkspaceSettings(ctx, find)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
defer rows.Close()
|
||||
|
||||
list := []*storepb.WorkspaceSetting{}
|
||||
for rows.Next() {
|
||||
workspaceSetting := &storepb.WorkspaceSetting{}
|
||||
var keyString, valueString string
|
||||
if err := rows.Scan(
|
||||
&keyString,
|
||||
&valueString,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
workspaceSetting.Key = storepb.WorkspaceSettingKey(storepb.WorkspaceSettingKey_value[keyString])
|
||||
if workspaceSetting.Key == storepb.WorkspaceSettingKey_WORKSPACE_SETTING_LICENSE_KEY {
|
||||
workspaceSetting.Value = &storepb.WorkspaceSetting_LicenseKey{LicenseKey: valueString}
|
||||
} else if workspaceSetting.Key == storepb.WorkspaceSettingKey_WORKSPACE_SETTING_SECRET_SESSION {
|
||||
workspaceSetting.Value = &storepb.WorkspaceSetting_SecretSession{SecretSession: valueString}
|
||||
} else if workspaceSetting.Key == storepb.WorkspaceSettingKey_WORKSAPCE_SETTING_ENABLE_SIGNUP {
|
||||
enableSignup, err := strconv.ParseBool(valueString)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
workspaceSetting.Value = &storepb.WorkspaceSetting_EnableSignup{EnableSignup: enableSignup}
|
||||
} else if workspaceSetting.Key == storepb.WorkspaceSettingKey_WORKSPACE_SETTING_CUSTOM_STYLE {
|
||||
workspaceSetting.Value = &storepb.WorkspaceSetting_CustomStyle{CustomStyle: valueString}
|
||||
} else if workspaceSetting.Key == storepb.WorkspaceSettingKey_WORKSPACE_SETTING_CUSTOM_SCRIPT {
|
||||
workspaceSetting.Value = &storepb.WorkspaceSetting_CustomScript{CustomScript: valueString}
|
||||
} else if workspaceSetting.Key == storepb.WorkspaceSettingKey_WORKSPACE_SETTING_AUTO_BACKUP {
|
||||
autoBackupSetting := &storepb.AutoBackupWorkspaceSetting{}
|
||||
if err := protojson.Unmarshal([]byte(valueString), autoBackupSetting); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
workspaceSetting.Value = &storepb.WorkspaceSetting_AutoBackup{AutoBackup: autoBackupSetting}
|
||||
} else {
|
||||
continue
|
||||
}
|
||||
list = append(list, workspaceSetting)
|
||||
}
|
||||
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for _, workspaceSetting := range list {
|
||||
s.workspaceSettingCache.Store(workspaceSetting.Key, workspaceSetting)
|
||||
}
|
||||
|
||||
return list, nil
|
||||
}
|
||||
|
||||
|
@ -31,12 +31,17 @@ type TestingServer struct {
|
||||
|
||||
func NewTestingServer(ctx context.Context, t *testing.T) (*TestingServer, error) {
|
||||
profile := test.GetTestingProfile(t)
|
||||
db := db.NewDB(profile)
|
||||
if err := db.Open(ctx); err != nil {
|
||||
return nil, errors.Wrap(err, "failed to open db")
|
||||
dbDriver, err := db.NewDBDriver(profile)
|
||||
if err != nil {
|
||||
fmt.Printf("failed to create db driver, error: %+v\n", err)
|
||||
return nil, err
|
||||
}
|
||||
if err := dbDriver.Migrate(ctx); err != nil {
|
||||
fmt.Printf("failed to migrate db, error: %+v\n", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
store := store.New(db.DBInstance, profile)
|
||||
store := store.New(dbDriver, profile)
|
||||
server, err := server.NewServer(ctx, profile, store)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to create server")
|
||||
|
@ -15,11 +15,14 @@ import (
|
||||
|
||||
func NewTestingStore(ctx context.Context, t *testing.T) *store.Store {
|
||||
profile := test.GetTestingProfile(t)
|
||||
db := db.NewDB(profile)
|
||||
if err := db.Open(ctx); err != nil {
|
||||
fmt.Printf("failed to open db, error: %+v\n", err)
|
||||
dbDriver, err := db.NewDBDriver(profile)
|
||||
if err != nil {
|
||||
fmt.Printf("failed to create db driver, error: %+v\n", err)
|
||||
}
|
||||
if err := dbDriver.Migrate(ctx); err != nil {
|
||||
fmt.Printf("failed to migrate db, error: %+v\n", err)
|
||||
}
|
||||
|
||||
store := store.New(db.DBInstance, profile)
|
||||
store := store.New(dbDriver, profile)
|
||||
return store
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user