feat: initial postgres driver

This commit is contained in:
Steven 2023-12-17 15:08:51 +08:00
parent 41cb597f03
commit a7d48e8059
17 changed files with 1674 additions and 0 deletions

View File

@ -65,6 +65,8 @@ linters-settings:
disabled: true
- name: early-return
disabled: true
- name: use-any
disabled: true
- name: exported
arguments:
- "disableStutteringCheck"

1
go.mod
View File

@ -74,6 +74,7 @@ require (
github.com/grpc-ecosystem/grpc-gateway/v2 v2.18.1
github.com/h2non/filetype v1.1.3
github.com/improbable-eng/grpc-web v0.15.0
github.com/lib/pq v1.10.9
github.com/mssola/useragent v1.0.0
github.com/patrickmn/go-cache v2.1.0+incompatible
github.com/pkg/errors v0.9.1

2
go.sum
View File

@ -292,6 +292,8 @@ github.com/labstack/echo/v4 v4.11.2/go.mod h1:UcGuQ8V6ZNRmSweBIJkPvGfwCMIlFmiqrP
github.com/labstack/gommon v0.4.0 h1:y7cvthEAEbU0yHOf4axH8ZG2NH8knB9iNSoTO8dyIk8=
github.com/labstack/gommon v0.4.0/go.mod h1:uW6kP17uPlLJsD3ijUYn3/M5bAxtlZhMI6m3MFxTMTM=
github.com/leodido/go-urn v1.2.0/go.mod h1:+8+nEpDfqqsY+g338gtMEUOtuK+4dEMhiQEgxpxOKII=
github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw=
github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o=
github.com/lightstep/lightstep-tracer-common/golang/gogo v0.0.0-20190605223551-bc2310a04743/go.mod h1:qklhhLq1aX+mtWk9cPHPzaBjWImj5ULL6C7HFJtXQMM=
github.com/lightstep/lightstep-tracer-go v0.18.1/go.mod h1:jlF1pusYV4pidLvZ+XD0UBX0ZE6WURAspgAczcDHrL4=
github.com/lyft/protoc-gen-validate v0.0.13/go.mod h1:XbGvPuh87YZc5TdIa2/I4pLk0QoUACkjt2znoq26NVQ=

View File

@ -5,6 +5,7 @@ import (
"github.com/yourselfhosted/slash/server/profile"
"github.com/yourselfhosted/slash/store"
"github.com/yourselfhosted/slash/store/db/postgres"
"github.com/yourselfhosted/slash/store/db/sqlite"
)
@ -16,6 +17,8 @@ func NewDBDriver(profile *profile.Profile) (store.Driver, error) {
switch profile.Driver {
case "sqlite":
driver, err = sqlite.NewDB(profile)
case "postgres":
driver, err = postgres.NewDB(profile)
default:
return nil, errors.New("unknown db driver")
}

View File

@ -0,0 +1,88 @@
package postgres
import (
"context"
"fmt"
"strings"
"github.com/yourselfhosted/slash/store"
)
func (d *DB) CreateActivity(ctx context.Context, create *store.Activity) (*store.Activity, error) {
stmt := `
INSERT INTO activity (
creator_id,
type,
level,
payload
)
VALUES ($1, $2, $3, $4)
RETURNING id, created_ts
`
if err := d.db.QueryRowContext(ctx, stmt,
create.CreatorID,
create.Type.String(),
create.Level.String(),
create.Payload,
).Scan(
&create.ID,
&create.CreatedTs,
); err != nil {
return nil, err
}
activity := create
return activity, nil
}
func (d *DB) ListActivities(ctx context.Context, find *store.FindActivity) ([]*store.Activity, error) {
where, args := []string{"1 = 1"}, []any{}
if find.Type != "" {
where, args = append(where, "type = $"+fmt.Sprint(len(args)+1)), append(args, find.Type.String())
}
if find.Level != "" {
where, args = append(where, "level = $"+fmt.Sprint(len(args)+1)), append(args, find.Level.String())
}
if find.Where != nil {
where = append(where, find.Where...)
}
query := `
SELECT
id,
creator_id,
created_ts,
type,
level,
payload
FROM activity
WHERE ` + strings.Join(where, " AND ")
rows, err := d.db.QueryContext(ctx, query, args...)
if err != nil {
return nil, err
}
defer rows.Close()
list := []*store.Activity{}
for rows.Next() {
activity := &store.Activity{}
if err := rows.Scan(
&activity.ID,
&activity.CreatorID,
&activity.CreatedTs,
&activity.Type,
&activity.Level,
&activity.Payload,
); err != nil {
return nil, err
}
list = append(list, activity)
}
if err := rows.Err(); err != nil {
return nil, err
}
return list, nil
}

View File

@ -0,0 +1,194 @@
package postgres
import (
"context"
"database/sql"
"fmt"
"strings"
"github.com/pkg/errors"
"github.com/yourselfhosted/slash/internal/util"
storepb "github.com/yourselfhosted/slash/proto/gen/store"
"github.com/yourselfhosted/slash/store"
)
func (d *DB) CreateCollection(ctx context.Context, create *storepb.Collection) (*storepb.Collection, error) {
set := []string{"creator_id", "name", "title", "description", "shortcut_ids", "visibility"}
args := []any{create.CreatorId, create.Name, create.Title, create.Description, strings.Trim(strings.Join(strings.Fields(fmt.Sprint(create.ShortcutIds)), ","), "[]"), create.Visibility.String()}
placeholder := []string{"$1", "$2", "$3", "$4", "$5", "$6"}
stmt := `
INSERT INTO collection (
` + strings.Join(set, ", ") + `
)
VALUES (` + strings.Join(placeholder, ",") + `)
RETURNING id, created_ts, updated_ts
`
if err := d.db.QueryRowContext(ctx, stmt, args...).Scan(
&create.Id,
&create.CreatedTs,
&create.UpdatedTs,
); err != nil {
return nil, err
}
collection := create
return collection, nil
}
func (d *DB) UpdateCollection(ctx context.Context, update *store.UpdateCollection) (*storepb.Collection, error) {
set, args := []string{}, []any{}
if update.Name != nil {
set, args = append(set, "name = $1"), append(args, *update.Name)
}
if update.Title != nil {
set, args = append(set, "title = $2"), append(args, *update.Title)
}
if update.Description != nil {
set, args = append(set, "description = $3"), append(args, *update.Description)
}
if update.ShortcutIDs != nil {
set, args = append(set, "shortcut_ids = $4"), append(args, strings.Trim(strings.Join(strings.Fields(fmt.Sprint(update.ShortcutIDs)), ","), "[]"))
}
if update.Visibility != nil {
set, args = append(set, "visibility = $5"), append(args, update.Visibility.String())
}
if len(set) == 0 {
return nil, errors.New("no update specified")
}
args = append(args, update.ID)
stmt := `
UPDATE collection
SET
` + strings.Join(set, ", ") + `
WHERE
id = $6
RETURNING id, creator_id, created_ts, updated_ts, name, title, description, shortcut_ids, visibility
`
collection := &storepb.Collection{}
var shortcutIDs, visibility string
if err := d.db.QueryRowContext(ctx, stmt, args...).Scan(
&collection.Id,
&collection.CreatorId,
&collection.CreatedTs,
&collection.UpdatedTs,
&collection.Name,
&collection.Title,
&collection.Description,
&shortcutIDs,
&visibility,
); err != nil {
return nil, err
}
collection.ShortcutIds = []int32{}
if shortcutIDs != "" {
for _, idStr := range strings.Split(shortcutIDs, ",") {
shortcutID, err := util.ConvertStringToInt32(idStr)
if err != nil {
return nil, errors.Wrap(err, "failed to convert shortcut id")
}
collection.ShortcutIds = append(collection.ShortcutIds, shortcutID)
}
}
collection.Visibility = convertVisibilityStringToStorepb(visibility)
return collection, nil
}
func (d *DB) ListCollections(ctx context.Context, find *store.FindCollection) ([]*storepb.Collection, error) {
where, args := []string{"1 = 1"}, []any{}
if v := find.ID; v != nil {
where, args = append(where, "id = $1"), append(args, *v)
}
if v := find.CreatorID; v != nil {
where, args = append(where, "creator_id = $2"), append(args, *v)
}
if v := find.Name; v != nil {
where, args = append(where, "name = $3"), append(args, *v)
}
if v := find.VisibilityList; len(v) != 0 {
list := []string{}
for i, visibility := range v {
list = append(list, fmt.Sprintf("$%d", len(args)+i+1))
args = append(args, visibility)
}
where = append(where, fmt.Sprintf("visibility IN (%s)", strings.Join(list, ",")))
}
rows, err := d.db.QueryContext(ctx, `
SELECT
id,
creator_id,
created_ts,
updated_ts,
name,
title,
description,
shortcut_ids,
visibility
FROM collection
WHERE `+strings.Join(where, " AND ")+`
ORDER BY created_ts DESC`,
args...,
)
if err != nil {
return nil, err
}
defer rows.Close()
list := make([]*storepb.Collection, 0)
for rows.Next() {
collection := &storepb.Collection{}
var shortcutIDs, visibility string
if err := rows.Scan(
&collection.Id,
&collection.CreatorId,
&collection.CreatedTs,
&collection.UpdatedTs,
&collection.Name,
&collection.Title,
&collection.Description,
&shortcutIDs,
&visibility,
); err != nil {
return nil, err
}
collection.ShortcutIds = []int32{}
if shortcutIDs != "" {
for _, idStr := range strings.Split(shortcutIDs, ",") {
shortcutID, err := util.ConvertStringToInt32(idStr)
if err != nil {
return nil, errors.Wrap(err, "failed to convert shortcut id")
}
collection.ShortcutIds = append(collection.ShortcutIds, shortcutID)
}
}
collection.Visibility = storepb.Visibility(storepb.Visibility_value[visibility])
list = append(list, collection)
}
if err := rows.Err(); err != nil {
return nil, err
}
return list, nil
}
func (d *DB) DeleteCollection(ctx context.Context, delete *store.DeleteCollection) error {
if _, err := d.db.ExecContext(ctx, `DELETE FROM collection WHERE id = $1`, delete.ID); err != nil {
return err
}
return nil
}
func vacuumCollection(ctx context.Context, tx *sql.Tx) error {
stmt := `DELETE FROM collection WHERE creator_id NOT IN (SELECT id FROM user)`
_, err := tx.ExecContext(ctx, stmt)
if err != nil {
return err
}
return nil
}

208
store/db/postgres/memo.go Normal file
View File

@ -0,0 +1,208 @@
package postgres
import (
"context"
"database/sql"
"fmt"
"strings"
"github.com/pkg/errors"
storepb "github.com/yourselfhosted/slash/proto/gen/store"
"github.com/yourselfhosted/slash/store"
)
func (d *DB) CreateMemo(ctx context.Context, create *storepb.Memo) (*storepb.Memo, error) {
set := []string{"creator_id", "name", "title", "content", "visibility", "tag"}
args := []any{create.CreatorId, create.Name, create.Title, create.Content, create.Visibility.String(), strings.Join(create.Tags, " ")}
stmt := `
INSERT INTO memo (
` + strings.Join(set, ", ") + `
)
VALUES (` + placeholders(len(args)) + `)
RETURNING id, created_ts, updated_ts, row_status
`
var rowStatus string
if err := d.db.QueryRowContext(ctx, stmt, args...).Scan(
&create.Id,
&create.CreatedTs,
&create.UpdatedTs,
&rowStatus,
); err != nil {
return nil, err
}
create.RowStatus = store.ConvertRowStatusStringToStorepb(rowStatus)
memo := create
return memo, nil
}
func (d *DB) UpdateMemo(ctx context.Context, update *store.UpdateMemo) (*storepb.Memo, error) {
set, args := []string{}, []any{}
if update.RowStatus != nil {
set = append(set, fmt.Sprintf("row_status = $%d", len(set)+1))
args = append(args, update.RowStatus.String())
}
if update.Name != nil {
set = append(set, fmt.Sprintf("name = $%d", len(set)+1))
args = append(args, *update.Name)
}
if update.Title != nil {
set = append(set, fmt.Sprintf("title = $%d", len(set)+1))
args = append(args, *update.Title)
}
if update.Content != nil {
set = append(set, fmt.Sprintf("content = $%d", len(set)+1))
args = append(args, *update.Content)
}
if update.Visibility != nil {
set = append(set, fmt.Sprintf("visibility = $%d", len(set)+1))
args = append(args, update.Visibility.String())
}
if update.Tag != nil {
set = append(set, fmt.Sprintf("tag = $%d", len(set)+1))
args = append(args, *update.Tag)
}
if len(set) == 0 {
return nil, errors.New("no update specified")
}
args = append(args, update.ID)
stmt := `
UPDATE memo
SET
` + strings.Join(set, ", ") + `
WHERE
id = $` + fmt.Sprint(len(set)+1) + `
RETURNING id, creator_id, created_ts, updated_ts, row_status, name, title, content, visibility, tag
`
memo := &storepb.Memo{}
var rowStatus, visibility, tags string
if err := d.db.QueryRowContext(ctx, stmt, args...).Scan(
&memo.Id,
&memo.CreatorId,
&memo.CreatedTs,
&memo.UpdatedTs,
&rowStatus,
&memo.Name,
&memo.Title,
&memo.Content,
&visibility,
&tags,
); err != nil {
return nil, err
}
memo.RowStatus = store.ConvertRowStatusStringToStorepb(rowStatus)
memo.Visibility = convertVisibilityStringToStorepb(visibility)
memo.Tags = filterTags(strings.Split(tags, " "))
return memo, nil
}
func (d *DB) ListMemos(ctx context.Context, find *store.FindMemo) ([]*storepb.Memo, error) {
where, args := []string{"1 = 1"}, []any{}
if v := find.ID; v != nil {
where, args = append(where, "id = $1"), append(args, *v)
}
if v := find.CreatorID; v != nil {
where, args = append(where, "creator_id = $2"), append(args, *v)
}
if v := find.RowStatus; v != nil {
where, args = append(where, "row_status = $3"), append(args, *v)
}
if v := find.Name; v != nil {
where, args = append(where, "name = $4"), append(args, *v)
}
if v := find.VisibilityList; len(v) != 0 {
list := []string{}
for i, visibility := range v {
list = append(list, fmt.Sprintf("$%d", len(args)+i+1))
args = append(args, visibility)
}
where = append(where, fmt.Sprintf("visibility IN (%s)", strings.Join(list, ",")))
}
if v := find.Tag; v != nil {
where, args = append(where, "tag LIKE $"+fmt.Sprint(len(args)+1)), append(args, "%"+*v+"%")
}
rows, err := d.db.QueryContext(ctx, `
SELECT
id,
creator_id,
created_ts,
updated_ts,
row_status,
name,
title,
content,
visibility,
tag
FROM memo
WHERE `+strings.Join(where, " AND ")+`
ORDER BY created_ts DESC`,
args...,
)
if err != nil {
return nil, err
}
defer rows.Close()
list := make([]*storepb.Memo, 0)
for rows.Next() {
memo := &storepb.Memo{}
var rowStatus, visibility, tags string
if err := rows.Scan(
&memo.Id,
&memo.CreatorId,
&memo.CreatedTs,
&memo.UpdatedTs,
&rowStatus,
&memo.Name,
&memo.Title,
&memo.Content,
&visibility,
&tags,
); err != nil {
return nil, err
}
memo.RowStatus = store.ConvertRowStatusStringToStorepb(rowStatus)
memo.Visibility = storepb.Visibility(storepb.Visibility_value[visibility])
memo.Tags = filterTags(strings.Split(tags, " "))
list = append(list, memo)
}
if err := rows.Err(); err != nil {
return nil, err
}
return list, nil
}
func (d *DB) DeleteMemo(ctx context.Context, delete *store.DeleteMemo) error {
if _, err := d.db.ExecContext(ctx, `DELETE FROM memo WHERE id = $1`, delete.ID); err != nil {
return err
}
return nil
}
func vacuumMemo(ctx context.Context, tx *sql.Tx) error {
stmt := `DELETE FROM memo WHERE creator_id NOT IN (SELECT id FROM user)`
_, err := tx.ExecContext(ctx, stmt)
if err != nil {
return err
}
return nil
}
func placeholders(n int) string {
placeholder := ""
for i := 0; i < n; i++ {
if i == 0 {
placeholder = fmt.Sprintf("$%d", i+1)
} else {
placeholder = fmt.Sprintf("%s, $%d", placeholder, i+1)
}
}
return placeholder
}

View File

@ -0,0 +1,92 @@
-- migration_history
CREATE TABLE migration_history (
version TEXT NOT NULL PRIMARY KEY,
created_ts BIGINT NOT NULL DEFAULT EXTRACT(EPOCH FROM NOW())
);
-- workspace_setting
CREATE TABLE workspace_setting (
key TEXT NOT NULL UNIQUE,
value TEXT NOT NULL
);
-- user
CREATE TABLE user (
id SERIAL PRIMARY KEY,
created_ts BIGINT NOT NULL DEFAULT EXTRACT(EPOCH FROM NOW()),
updated_ts BIGINT NOT NULL DEFAULT EXTRACT(EPOCH FROM NOW()),
row_status TEXT NOT NULL CHECK (row_status IN ('NORMAL', 'ARCHIVED')) DEFAULT 'NORMAL',
email TEXT NOT NULL UNIQUE,
nickname TEXT NOT NULL,
password_hash TEXT NOT NULL,
role TEXT NOT NULL CHECK (role IN ('ADMIN', 'USER')) DEFAULT 'USER'
);
CREATE INDEX idx_user_email ON user(email);
-- user_setting
CREATE TABLE user_setting (
user_id INTEGER REFERENCES user(id) NOT NULL,
key TEXT NOT NULL,
value TEXT NOT NULL,
PRIMARY KEY (user_id, key)
);
-- shortcut
CREATE TABLE shortcut (
id SERIAL PRIMARY KEY,
creator_id INTEGER REFERENCES user(id) NOT NULL,
created_ts BIGINT NOT NULL DEFAULT EXTRACT(EPOCH FROM NOW()),
updated_ts BIGINT NOT NULL DEFAULT EXTRACT(EPOCH FROM NOW()),
row_status TEXT NOT NULL CHECK (row_status IN ('NORMAL', 'ARCHIVED')) DEFAULT 'NORMAL',
name TEXT NOT NULL UNIQUE,
link TEXT NOT NULL,
title TEXT NOT NULL DEFAULT '',
description TEXT NOT NULL DEFAULT '',
visibility TEXT NOT NULL CHECK (visibility IN ('PRIVATE', 'WORKSPACE', 'PUBLIC')) DEFAULT 'PRIVATE',
tag TEXT NOT NULL DEFAULT '',
og_metadata TEXT NOT NULL DEFAULT '{}'
);
CREATE INDEX idx_shortcut_name ON shortcut(name);
-- activity
CREATE TABLE activity (
id SERIAL PRIMARY KEY,
creator_id INTEGER REFERENCES user(id) NOT NULL,
created_ts BIGINT NOT NULL DEFAULT EXTRACT(EPOCH FROM NOW()),
type TEXT NOT NULL DEFAULT '',
level TEXT NOT NULL CHECK (level IN ('INFO', 'WARN', 'ERROR')) DEFAULT 'INFO',
payload TEXT NOT NULL DEFAULT '{}'
);
-- collection
CREATE TABLE collection (
id SERIAL PRIMARY KEY,
creator_id INTEGER REFERENCES user(id) NOT NULL,
created_ts BIGINT NOT NULL DEFAULT EXTRACT(EPOCH FROM NOW()),
updated_ts BIGINT NOT NULL DEFAULT EXTRACT(EPOCH FROM NOW()),
name TEXT NOT NULL UNIQUE,
title TEXT NOT NULL DEFAULT '',
description TEXT NOT NULL DEFAULT '',
shortcut_ids INTEGER ARRAY NOT NULL,
visibility TEXT NOT NULL CHECK (visibility IN ('PRIVATE', 'WORKSPACE', 'PUBLIC')) DEFAULT 'PRIVATE'
);
CREATE INDEX idx_collection_name ON collection(name);
-- memo
CREATE TABLE memo (
id SERIAL PRIMARY KEY,
creator_id INTEGER REFERENCES user(id) NOT NULL,
created_ts BIGINT NOT NULL DEFAULT EXTRACT(EPOCH FROM NOW()),
updated_ts BIGINT NOT NULL DEFAULT EXTRACT(EPOCH FROM NOW()),
row_status TEXT NOT NULL CHECK (row_status IN ('NORMAL', 'ARCHIVED')) DEFAULT 'NORMAL',
name TEXT NOT NULL UNIQUE,
title TEXT NOT NULL DEFAULT '',
content TEXT NOT NULL DEFAULT '',
visibility TEXT NOT NULL CHECK (visibility IN ('PRIVATE', 'WORKSPACE', 'PUBLIC')) DEFAULT 'PRIVATE',
tag TEXT NOT NULL DEFAULT ''
);
CREATE INDEX idx_memo_name ON memo(name);

View File

@ -0,0 +1,92 @@
-- migration_history
CREATE TABLE migration_history (
version TEXT NOT NULL PRIMARY KEY,
created_ts BIGINT NOT NULL DEFAULT EXTRACT(EPOCH FROM NOW())
);
-- workspace_setting
CREATE TABLE workspace_setting (
key TEXT NOT NULL UNIQUE,
value TEXT NOT NULL
);
-- user
CREATE TABLE user (
id SERIAL PRIMARY KEY,
created_ts BIGINT NOT NULL DEFAULT EXTRACT(EPOCH FROM NOW()),
updated_ts BIGINT NOT NULL DEFAULT EXTRACT(EPOCH FROM NOW()),
row_status TEXT NOT NULL CHECK (row_status IN ('NORMAL', 'ARCHIVED')) DEFAULT 'NORMAL',
email TEXT NOT NULL UNIQUE,
nickname TEXT NOT NULL,
password_hash TEXT NOT NULL,
role TEXT NOT NULL CHECK (role IN ('ADMIN', 'USER')) DEFAULT 'USER'
);
CREATE INDEX idx_user_email ON user(email);
-- user_setting
CREATE TABLE user_setting (
user_id INTEGER REFERENCES user(id) NOT NULL,
key TEXT NOT NULL,
value TEXT NOT NULL,
PRIMARY KEY (user_id, key)
);
-- shortcut
CREATE TABLE shortcut (
id SERIAL PRIMARY KEY,
creator_id INTEGER REFERENCES user(id) NOT NULL,
created_ts BIGINT NOT NULL DEFAULT EXTRACT(EPOCH FROM NOW()),
updated_ts BIGINT NOT NULL DEFAULT EXTRACT(EPOCH FROM NOW()),
row_status TEXT NOT NULL CHECK (row_status IN ('NORMAL', 'ARCHIVED')) DEFAULT 'NORMAL',
name TEXT NOT NULL UNIQUE,
link TEXT NOT NULL,
title TEXT NOT NULL DEFAULT '',
description TEXT NOT NULL DEFAULT '',
visibility TEXT NOT NULL CHECK (visibility IN ('PRIVATE', 'WORKSPACE', 'PUBLIC')) DEFAULT 'PRIVATE',
tag TEXT NOT NULL DEFAULT '',
og_metadata TEXT NOT NULL DEFAULT '{}'
);
CREATE INDEX idx_shortcut_name ON shortcut(name);
-- activity
CREATE TABLE activity (
id SERIAL PRIMARY KEY,
creator_id INTEGER REFERENCES user(id) NOT NULL,
created_ts BIGINT NOT NULL DEFAULT EXTRACT(EPOCH FROM NOW()),
type TEXT NOT NULL DEFAULT '',
level TEXT NOT NULL CHECK (level IN ('INFO', 'WARN', 'ERROR')) DEFAULT 'INFO',
payload TEXT NOT NULL DEFAULT '{}'
);
-- collection
CREATE TABLE collection (
id SERIAL PRIMARY KEY,
creator_id INTEGER REFERENCES user(id) NOT NULL,
created_ts BIGINT NOT NULL DEFAULT EXTRACT(EPOCH FROM NOW()),
updated_ts BIGINT NOT NULL DEFAULT EXTRACT(EPOCH FROM NOW()),
name TEXT NOT NULL UNIQUE,
title TEXT NOT NULL DEFAULT '',
description TEXT NOT NULL DEFAULT '',
shortcut_ids INTEGER ARRAY NOT NULL,
visibility TEXT NOT NULL CHECK (visibility IN ('PRIVATE', 'WORKSPACE', 'PUBLIC')) DEFAULT 'PRIVATE'
);
CREATE INDEX idx_collection_name ON collection(name);
-- memo
CREATE TABLE memo (
id SERIAL PRIMARY KEY,
creator_id INTEGER REFERENCES user(id) NOT NULL,
created_ts BIGINT NOT NULL DEFAULT EXTRACT(EPOCH FROM NOW()),
updated_ts BIGINT NOT NULL DEFAULT EXTRACT(EPOCH FROM NOW()),
row_status TEXT NOT NULL CHECK (row_status IN ('NORMAL', 'ARCHIVED')) DEFAULT 'NORMAL',
name TEXT NOT NULL UNIQUE,
title TEXT NOT NULL DEFAULT '',
content TEXT NOT NULL DEFAULT '',
visibility TEXT NOT NULL CHECK (visibility IN ('PRIVATE', 'WORKSPACE', 'PUBLIC')) DEFAULT 'PRIVATE',
tag TEXT NOT NULL DEFAULT ''
);
CREATE INDEX idx_memo_name ON memo(name);

View File

@ -0,0 +1,57 @@
package postgres
import (
"context"
"github.com/yourselfhosted/slash/store"
)
func (d *DB) UpsertMigrationHistory(ctx context.Context, upsert *store.UpsertMigrationHistory) (*store.MigrationHistory, error) {
stmt := `
INSERT INTO migration_history (
version
)
VALUES ($1)
ON CONFLICT(version) DO UPDATE
SET
version=EXCLUDED.version
RETURNING version, created_ts
`
var migrationHistory store.MigrationHistory
if err := d.db.QueryRowContext(ctx, stmt, upsert.Version).Scan(
&migrationHistory.Version,
&migrationHistory.CreatedTs,
); err != nil {
return nil, err
}
return &migrationHistory, nil
}
func (d *DB) ListMigrationHistories(ctx context.Context, _ *store.FindMigrationHistory) ([]*store.MigrationHistory, error) {
query := "SELECT version, created_ts FROM migration_history ORDER BY created_ts DESC"
rows, err := d.db.QueryContext(ctx, query)
if err != nil {
return nil, err
}
defer rows.Close()
list := make([]*store.MigrationHistory, 0)
for rows.Next() {
var migrationHistory store.MigrationHistory
if err := rows.Scan(
&migrationHistory.Version,
&migrationHistory.CreatedTs,
); err != nil {
return nil, err
}
list = append(list, &migrationHistory)
}
if err := rows.Err(); err != nil {
return nil, err
}
return list, nil
}

View File

@ -0,0 +1,233 @@
package postgres
import (
"context"
"embed"
"fmt"
"io/fs"
"os"
"regexp"
"sort"
"time"
"github.com/pkg/errors"
"github.com/yourselfhosted/slash/server/version"
"github.com/yourselfhosted/slash/store"
)
//go:embed migration
var migrationFS embed.FS
//go:embed seed
var seedFS embed.FS
// Migrate applies the latest schema to the database.
func (d *DB) Migrate(ctx context.Context) error {
currentVersion := version.GetCurrentVersion(d.profile.Mode)
if d.profile.Mode == "prod" {
_, err := os.Stat(d.profile.DSN)
if err != nil {
// If db file not exists, we should create a new one with latest schema.
if errors.Is(err, os.ErrNotExist) {
if err := d.applyLatestSchema(ctx); err != nil {
return errors.Wrap(err, "failed to apply latest schema")
}
// Upsert the newest version to migration_history.
if _, err := d.UpsertMigrationHistory(ctx, &store.UpsertMigrationHistory{
Version: currentVersion,
}); err != nil {
return errors.Wrap(err, "failed to upsert migration history")
}
} else {
return errors.Wrap(err, "failed to get db file stat")
}
} else {
// If db file exists, we should check if we need to migrate the database.
migrationHistoryList, err := d.ListMigrationHistories(ctx, &store.FindMigrationHistory{})
if err != nil {
return errors.Wrap(err, "failed to find migration history")
}
// If no migration history, we should apply the latest version migration and upsert the migration history.
if len(migrationHistoryList) == 0 {
minorVersion := version.GetMinorVersion(currentVersion)
if err := d.applyMigrationForMinorVersion(ctx, minorVersion); err != nil {
return errors.Wrapf(err, "failed to apply version %s migration", minorVersion)
}
_, err := d.UpsertMigrationHistory(ctx, &store.UpsertMigrationHistory{
Version: currentVersion,
})
if err != nil {
return errors.Wrap(err, "failed to upsert migration history")
}
return nil
}
migrationHistoryVersionList := []string{}
for _, migrationHistory := range migrationHistoryList {
migrationHistoryVersionList = append(migrationHistoryVersionList, migrationHistory.Version)
}
sort.Sort(version.SortVersion(migrationHistoryVersionList))
latestMigrationHistoryVersion := migrationHistoryVersionList[len(migrationHistoryVersionList)-1]
if version.IsVersionGreaterThan(version.GetSchemaVersion(currentVersion), latestMigrationHistoryVersion) {
minorVersionList := getMinorVersionList()
// backup the raw database file before migration
rawBytes, err := os.ReadFile(d.profile.DSN)
if err != nil {
return errors.Wrap(err, "failed to read raw database file")
}
backupDBFilePath := fmt.Sprintf("%s/memos_%s_%d_backup.db", d.profile.Data, d.profile.Version, time.Now().Unix())
if err := os.WriteFile(backupDBFilePath, rawBytes, 0644); err != nil {
return errors.Wrap(err, "failed to write raw database file")
}
println("succeed to copy a backup database file")
println("start migrate")
for _, minorVersion := range minorVersionList {
normalizedVersion := minorVersion + ".0"
if version.IsVersionGreaterThan(normalizedVersion, latestMigrationHistoryVersion) && version.IsVersionGreaterOrEqualThan(currentVersion, normalizedVersion) {
println("applying migration for", normalizedVersion)
if err := d.applyMigrationForMinorVersion(ctx, minorVersion); err != nil {
return errors.Wrap(err, "failed to apply minor version migration")
}
}
}
println("end migrate")
// remove the created backup db file after migrate succeed
if err := os.Remove(backupDBFilePath); err != nil {
println(fmt.Sprintf("Failed to remove temp database file, err %v", err))
}
}
}
} else {
// In non-prod mode, we should always migrate the database.
if _, err := os.Stat(d.profile.DSN); errors.Is(err, os.ErrNotExist) {
if err := d.applyLatestSchema(ctx); err != nil {
return errors.Wrap(err, "failed to apply latest schema")
}
// In demo mode, we should seed the database.
if d.profile.Mode == "demo" {
if err := d.seed(ctx); err != nil {
return errors.Wrap(err, "failed to seed")
}
}
}
}
return nil
}
const (
latestSchemaFileName = "LATEST__SCHEMA.sql"
)
func (d *DB) applyLatestSchema(ctx context.Context) error {
schemaMode := "dev"
if d.profile.Mode == "prod" {
schemaMode = "prod"
}
latestSchemaPath := fmt.Sprintf("migration/%s/%s", schemaMode, latestSchemaFileName)
buf, err := migrationFS.ReadFile(latestSchemaPath)
if err != nil {
return errors.Wrapf(err, "failed to read latest schema %q", latestSchemaPath)
}
stmt := string(buf)
if err := d.execute(ctx, stmt); err != nil {
return errors.Wrapf(err, "migrate error: %s", stmt)
}
return nil
}
func (d *DB) applyMigrationForMinorVersion(ctx context.Context, minorVersion string) error {
filenames, err := fs.Glob(migrationFS, fmt.Sprintf("%s/%s/*.sql", "migration/prod", minorVersion))
if err != nil {
return errors.Wrap(err, "failed to read ddl files")
}
sort.Strings(filenames)
migrationStmt := ""
// Loop over all migration files and execute them in order.
for _, filename := range filenames {
buf, err := migrationFS.ReadFile(filename)
if err != nil {
return errors.Wrapf(err, "failed to read minor version migration file, filename=%s", filename)
}
stmt := string(buf)
migrationStmt += stmt
if err := d.execute(ctx, stmt); err != nil {
return errors.Wrapf(err, "migrate error: %s", stmt)
}
}
// Upsert the newest version to migration_history.
version := minorVersion + ".0"
if _, err = d.UpsertMigrationHistory(ctx, &store.UpsertMigrationHistory{
Version: version,
}); err != nil {
return errors.Wrapf(err, "failed to upsert migration history with version: %s", version)
}
return nil
}
func (d *DB) seed(ctx context.Context) error {
filenames, err := fs.Glob(seedFS, fmt.Sprintf("%s/*.sql", "seed"))
if err != nil {
return errors.Wrap(err, "failed to read seed files")
}
sort.Strings(filenames)
// Loop over all seed files and execute them in order.
for _, filename := range filenames {
buf, err := seedFS.ReadFile(filename)
if err != nil {
return errors.Wrapf(err, "failed to read seed file, filename=%s", filename)
}
stmt := string(buf)
if err := d.execute(ctx, stmt); err != nil {
return errors.Wrapf(err, "seed error: %s", stmt)
}
}
return nil
}
// execute runs a single SQL statement within a transaction.
func (d *DB) execute(ctx context.Context, stmt string) error {
tx, err := d.db.Begin()
if err != nil {
return err
}
defer tx.Rollback()
if _, err := tx.ExecContext(ctx, stmt); err != nil {
return errors.Wrap(err, "failed to execute statement")
}
return tx.Commit()
}
// minorDirRegexp is a regular expression for minor version directory.
var minorDirRegexp = regexp.MustCompile(`^migration/prod/[0-9]+\.[0-9]+$`)
func getMinorVersionList() []string {
minorVersionList := []string{}
if err := fs.WalkDir(migrationFS, "migration", func(path string, file fs.DirEntry, err error) error {
if err != nil {
return err
}
if file.IsDir() && minorDirRegexp.MatchString(path) {
minorVersionList = append(minorVersionList, file.Name())
}
return nil
}); err != nil {
panic(err)
}
sort.Sort(version.SortVersion(minorVersionList))
return minorVersionList
}

View File

@ -0,0 +1,45 @@
package postgres
import (
"database/sql"
"log"
// Import the PostgreSQL driver.
_ "github.com/lib/pq"
"github.com/pkg/errors"
"github.com/yourselfhosted/slash/server/profile"
"github.com/yourselfhosted/slash/store"
)
type DB struct {
db *sql.DB
profile *profile.Profile
}
func NewDB(profile *profile.Profile) (store.Driver, error) {
if profile == nil {
return nil, errors.New("profile is nil")
}
// Open the PostgreSQL connection
db, err := sql.Open("postgres", profile.DSN)
if err != nil {
log.Printf("Failed to open database: %s", err)
return nil, errors.Wrapf(err, "failed to open database: %s", profile.DSN)
}
var driver store.Driver = &DB{
db: db,
profile: profile,
}
return driver, nil
}
func (d *DB) GetDB() *sql.DB {
return d.db
}
func (d *DB) Close() error {
return d.db.Close()
}

View File

@ -0,0 +1,9 @@
DELETE FROM activity;
DELETE FROM shortcut;
DELETE FROM user_setting;
DELETE FROM user;
DELETE FROM workspace_setting;

View File

@ -0,0 +1,228 @@
package postgres
import (
"context"
"database/sql"
"fmt"
"strings"
"github.com/pkg/errors"
"google.golang.org/protobuf/encoding/protojson"
storepb "github.com/yourselfhosted/slash/proto/gen/store"
"github.com/yourselfhosted/slash/store"
)
func (d *DB) CreateShortcut(ctx context.Context, create *storepb.Shortcut) (*storepb.Shortcut, error) {
set := []string{"creator_id", "name", "link", "title", "description", "visibility", "tag"}
args := []any{create.CreatorId, create.Name, create.Link, create.Title, create.Description, create.Visibility.String(), strings.Join(create.Tags, " ")}
if create.OgMetadata != nil {
set = append(set, "og_metadata")
openGraphMetadataBytes, err := protojson.Marshal(create.OgMetadata)
if err != nil {
return nil, err
}
args = append(args, string(openGraphMetadataBytes))
}
stmt := fmt.Sprintf(`
INSERT INTO shortcut (%s)
VALUES (%s)
RETURNING id, created_ts, updated_ts, row_status
`, strings.Join(set, ","), placeholders(len(args)))
var rowStatus string
if err := d.db.QueryRowContext(ctx, stmt, args...).Scan(
&create.Id,
&create.CreatedTs,
&create.UpdatedTs,
&rowStatus,
); err != nil {
return nil, err
}
create.RowStatus = store.ConvertRowStatusStringToStorepb(rowStatus)
shortcut := create
return shortcut, nil
}
func (d *DB) UpdateShortcut(ctx context.Context, update *store.UpdateShortcut) (*storepb.Shortcut, error) {
set, args := []string{}, []any{}
if update.RowStatus != nil {
set, args = append(set, fmt.Sprintf("row_status = $%d", len(args)+1)), append(args, update.RowStatus.String())
}
if update.Name != nil {
set, args = append(set, fmt.Sprintf("name = $%d", len(args)+1)), append(args, *update.Name)
}
if update.Link != nil {
set, args = append(set, fmt.Sprintf("link = $%d", len(args)+1)), append(args, *update.Link)
}
if update.Title != nil {
set, args = append(set, fmt.Sprintf("title = $%d", len(args)+1)), append(args, *update.Title)
}
if update.Description != nil {
set, args = append(set, fmt.Sprintf("description = $%d", len(args)+1)), append(args, *update.Description)
}
if update.Visibility != nil {
set, args = append(set, fmt.Sprintf("visibility = $%d", len(args)+1)), append(args, update.Visibility.String())
}
if update.Tag != nil {
set, args = append(set, fmt.Sprintf("tag = $%d", len(args)+1)), append(args, *update.Tag)
}
if update.OpenGraphMetadata != nil {
openGraphMetadataBytes, err := protojson.Marshal(update.OpenGraphMetadata)
if err != nil {
return nil, errors.Wrap(err, "Failed to marshal activity payload")
}
set, args = append(set, fmt.Sprintf("og_metadata = $%d", len(args)+1)), append(args, string(openGraphMetadataBytes))
}
if len(set) == 0 {
return nil, errors.New("no update specified")
}
args = append(args, update.ID)
stmt := fmt.Sprintf(`
UPDATE shortcut
SET %s
WHERE id = $%d
RETURNING id, creator_id, created_ts, updated_ts, row_status, name, link, title, description, visibility, tag, og_metadata
`, strings.Join(set, ","), len(args))
shortcut := &storepb.Shortcut{}
var rowStatus, visibility, tags, openGraphMetadataString string
if err := d.db.QueryRowContext(ctx, stmt, args...).Scan(
&shortcut.Id,
&shortcut.CreatorId,
&shortcut.CreatedTs,
&shortcut.UpdatedTs,
&rowStatus,
&shortcut.Name,
&shortcut.Link,
&shortcut.Title,
&shortcut.Description,
&visibility,
&tags,
&openGraphMetadataString,
); err != nil {
return nil, err
}
shortcut.RowStatus = store.ConvertRowStatusStringToStorepb(rowStatus)
shortcut.Visibility = convertVisibilityStringToStorepb(visibility)
shortcut.Tags = filterTags(strings.Split(tags, " "))
var ogMetadata storepb.OpenGraphMetadata
if err := protojson.Unmarshal([]byte(openGraphMetadataString), &ogMetadata); err != nil {
return nil, err
}
shortcut.OgMetadata = &ogMetadata
return shortcut, nil
}
func (d *DB) ListShortcuts(ctx context.Context, find *store.FindShortcut) ([]*storepb.Shortcut, error) {
where, args := []string{"1 = 1"}, []any{}
if v := find.ID; v != nil {
where, args = append(where, fmt.Sprintf("id = $%d", len(args)+1)), append(args, *v)
}
if v := find.CreatorID; v != nil {
where, args = append(where, fmt.Sprintf("creator_id = $%d", len(args)+1)), append(args, *v)
}
if v := find.RowStatus; v != nil {
where, args = append(where, fmt.Sprintf("row_status = $%d", len(args)+1)), append(args, *v)
}
if v := find.Name; v != nil {
where, args = append(where, fmt.Sprintf("name = $%d", len(args)+1)), append(args, *v)
}
if v := find.VisibilityList; len(v) != 0 {
list := []string{}
for _, visibility := range v {
list = append(list, fmt.Sprintf("$%d", len(args)+1))
args = append(args, visibility)
}
where = append(where, fmt.Sprintf("visibility IN (%s)", strings.Join(list, ",")))
}
if v := find.Tag; v != nil {
where, args = append(where, fmt.Sprintf("tag LIKE $%d", len(args)+1)), append(args, "%"+*v+"%")
}
rows, err := d.db.QueryContext(ctx, fmt.Sprintf(`
SELECT
id,
creator_id,
created_ts,
updated_ts,
row_status,
name,
link,
title,
description,
visibility,
tag,
og_metadata
FROM shortcut
WHERE %s
ORDER BY created_ts DESC
`, strings.Join(where, " AND ")), args...)
if err != nil {
return nil, err
}
defer rows.Close()
list := make([]*storepb.Shortcut, 0)
for rows.Next() {
shortcut := &storepb.Shortcut{}
var rowStatus, visibility, tags, openGraphMetadataString string
if err := rows.Scan(
&shortcut.Id,
&shortcut.CreatorId,
&shortcut.CreatedTs,
&shortcut.UpdatedTs,
&rowStatus,
&shortcut.Name,
&shortcut.Link,
&shortcut.Title,
&shortcut.Description,
&visibility,
&tags,
&openGraphMetadataString,
); err != nil {
return nil, err
}
shortcut.RowStatus = store.ConvertRowStatusStringToStorepb(rowStatus)
shortcut.Visibility = storepb.Visibility(storepb.Visibility_value[visibility])
shortcut.Tags = filterTags(strings.Split(tags, " "))
var ogMetadata storepb.OpenGraphMetadata
if err := protojson.Unmarshal([]byte(openGraphMetadataString), &ogMetadata); err != nil {
return nil, err
}
shortcut.OgMetadata = &ogMetadata
list = append(list, shortcut)
}
if err := rows.Err(); err != nil {
return nil, err
}
return list, nil
}
func (d *DB) DeleteShortcut(ctx context.Context, delete *store.DeleteShortcut) error {
_, err := d.db.ExecContext(ctx, "DELETE FROM shortcut WHERE id = $1", delete.ID)
return err
}
func vacuumShortcut(ctx context.Context, tx *sql.Tx) error {
stmt := `DELETE FROM shortcut WHERE creator_id NOT IN (SELECT id FROM "user")`
_, err := tx.ExecContext(ctx, stmt)
return err
}
func filterTags(tags []string) []string {
result := []string{}
for _, tag := range tags {
if tag != "" {
result = append(result, tag)
}
}
return result
}
func convertVisibilityStringToStorepb(visibility string) storepb.Visibility {
return storepb.Visibility(storepb.Visibility_value[visibility])
}

182
store/db/postgres/user.go Normal file
View File

@ -0,0 +1,182 @@
package postgres
import (
"context"
"errors"
"fmt"
"strings"
"github.com/yourselfhosted/slash/store"
)
func (d *DB) CreateUser(ctx context.Context, create *store.User) (*store.User, error) {
stmt := `
INSERT INTO "user" (
email,
nickname,
password_hash,
role
)
VALUES ($1, $2, $3, $4)
RETURNING id, created_ts, updated_ts, row_status
`
if err := d.db.QueryRowContext(ctx, stmt,
create.Email,
create.Nickname,
create.PasswordHash,
create.Role,
).Scan(
&create.ID,
&create.CreatedTs,
&create.UpdatedTs,
&create.RowStatus,
); err != nil {
return nil, err
}
user := create
return user, nil
}
func (d *DB) UpdateUser(ctx context.Context, update *store.UpdateUser) (*store.User, error) {
set, args := []string{}, []any{}
if v := update.RowStatus; v != nil {
set, args = append(set, "row_status = $"+placeholder(len(args)+1)), append(args, *v)
}
if v := update.Email; v != nil {
set, args = append(set, "email = $"+placeholder(len(args)+1)), append(args, *v)
}
if v := update.Nickname; v != nil {
set, args = append(set, "nickname = $"+placeholder(len(args)+1)), append(args, *v)
}
if v := update.PasswordHash; v != nil {
set, args = append(set, "password_hash = $"+placeholder(len(args)+1)), append(args, *v)
}
if v := update.Role; v != nil {
set, args = append(set, "role = $"+placeholder(len(args)+1)), append(args, *v)
}
if len(set) == 0 {
return nil, errors.New("no fields to update")
}
stmt := `
UPDATE "user"
SET ` + strings.Join(set, ", ") + `
WHERE id = $` + placeholder(len(args)+1) + `
RETURNING id, created_ts, updated_ts, row_status, email, nickname, password_hash, role
`
args = append(args, update.ID)
user := &store.User{}
if err := d.db.QueryRowContext(ctx, stmt, args...).Scan(
&user.ID,
&user.CreatedTs,
&user.UpdatedTs,
&user.RowStatus,
&user.Email,
&user.Nickname,
&user.PasswordHash,
&user.Role,
); err != nil {
return nil, err
}
return user, nil
}
func (d *DB) ListUsers(ctx context.Context, find *store.FindUser) ([]*store.User, error) {
where, args := []string{"1 = 1"}, []any{}
if v := find.ID; v != nil {
where, args = append(where, "id = $"+placeholder(len(args)+1)), append(args, *v)
}
if v := find.RowStatus; v != nil {
where, args = append(where, "row_status = $"+placeholder(len(args)+1)), append(args, v.String())
}
if v := find.Email; v != nil {
where, args = append(where, "email = $"+placeholder(len(args)+1)), append(args, *v)
}
if v := find.Nickname; v != nil {
where, args = append(where, "nickname = $"+placeholder(len(args)+1)), append(args, *v)
}
if v := find.Role; v != nil {
where, args = append(where, "role = $"+placeholder(len(args)+1)), append(args, *v)
}
query := `
SELECT
id,
created_ts,
updated_ts,
row_status,
email,
nickname,
password_hash,
role
FROM "user"
WHERE ` + strings.Join(where, " AND ") + `
ORDER BY updated_ts DESC, created_ts DESC
`
rows, err := d.db.QueryContext(ctx, query, args...)
if err != nil {
return nil, err
}
defer rows.Close()
list := make([]*store.User, 0)
for rows.Next() {
user := &store.User{}
if err := rows.Scan(
&user.ID,
&user.CreatedTs,
&user.UpdatedTs,
&user.RowStatus,
&user.Email,
&user.Nickname,
&user.PasswordHash,
&user.Role,
); err != nil {
return nil, err
}
list = append(list, user)
}
if err := rows.Err(); err != nil {
return nil, err
}
return list, nil
}
func (d *DB) DeleteUser(ctx context.Context, delete *store.DeleteUser) error {
tx, err := d.db.BeginTx(ctx, nil)
if err != nil {
return err
}
defer tx.Rollback()
if _, err := tx.ExecContext(ctx, `
DELETE FROM "user" WHERE id = $1
`, delete.ID); err != nil {
return err
}
if err := vacuumUserSetting(ctx, tx); err != nil {
return err
}
if err := vacuumShortcut(ctx, tx); err != nil {
return err
}
if err := vacuumMemo(ctx, tx); err != nil {
return err
}
if err := vacuumCollection(ctx, tx); err != nil {
return err
}
return tx.Commit()
}
func placeholder(n int) string {
return "$" + fmt.Sprint(n)
}

View File

@ -0,0 +1,122 @@
package postgres
import (
"context"
"database/sql"
"errors"
"fmt"
"strings"
"google.golang.org/protobuf/encoding/protojson"
storepb "github.com/yourselfhosted/slash/proto/gen/store"
"github.com/yourselfhosted/slash/store"
)
func (d *DB) UpsertUserSetting(ctx context.Context, upsert *storepb.UserSetting) (*storepb.UserSetting, error) {
stmt := `
INSERT INTO user_setting (
user_id, key, value
)
VALUES ($1, $2, $3)
ON CONFLICT(user_id, key) DO UPDATE
SET value = EXCLUDED.value
RETURNING user_id, key, value
`
var valueString string
if upsert.Key == storepb.UserSettingKey_USER_SETTING_ACCESS_TOKENS {
valueBytes, err := protojson.Marshal(upsert.GetAccessTokens())
if err != nil {
return nil, err
}
valueString = string(valueBytes)
} else if upsert.Key == storepb.UserSettingKey_USER_SETTING_LOCALE {
valueString = upsert.GetLocale().String()
} else if upsert.Key == storepb.UserSettingKey_USER_SETTING_COLOR_THEME {
valueString = upsert.GetColorTheme().String()
} else {
return nil, errors.New("invalid user setting key")
}
if _, err := d.db.ExecContext(ctx, stmt, upsert.UserId, upsert.Key.String(), valueString); err != nil {
return nil, err
}
userSettingMessage := upsert
return userSettingMessage, nil
}
func (d *DB) ListUserSettings(ctx context.Context, find *store.FindUserSetting) ([]*storepb.UserSetting, error) {
where, args := []string{"1 = 1"}, []any{}
if v := find.Key; v != storepb.UserSettingKey_USER_SETTING_KEY_UNSPECIFIED {
where, args = append(where, fmt.Sprintf("key = $%d", len(args)+1)), append(args, v.String())
}
if v := find.UserID; v != nil {
where, args = append(where, fmt.Sprintf("user_id = $%d", len(args)+1)), append(args, *find.UserID)
}
query := `
SELECT
user_id,
key,
value
FROM user_setting
WHERE ` + strings.Join(where, " AND ")
rows, err := d.db.QueryContext(ctx, query, args...)
if err != nil {
return nil, err
}
defer rows.Close()
userSettingList := make([]*storepb.UserSetting, 0)
for rows.Next() {
userSetting := &storepb.UserSetting{}
var keyString, valueString string
if err := rows.Scan(
&userSetting.UserId,
&keyString,
&valueString,
); err != nil {
return nil, err
}
userSetting.Key = storepb.UserSettingKey(storepb.UserSettingKey_value[keyString])
if userSetting.Key == storepb.UserSettingKey_USER_SETTING_ACCESS_TOKENS {
accessTokensUserSetting := &storepb.AccessTokensUserSetting{}
if err := protojson.Unmarshal([]byte(valueString), accessTokensUserSetting); err != nil {
return nil, err
}
userSetting.Value = &storepb.UserSetting_AccessTokens{
AccessTokens: accessTokensUserSetting,
}
} else if userSetting.Key == storepb.UserSettingKey_USER_SETTING_LOCALE {
userSetting.Value = &storepb.UserSetting_Locale{
Locale: storepb.LocaleUserSetting(storepb.LocaleUserSetting_value[valueString]),
}
} else if userSetting.Key == storepb.UserSettingKey_USER_SETTING_COLOR_THEME {
userSetting.Value = &storepb.UserSetting_ColorTheme{
ColorTheme: storepb.ColorThemeUserSetting(storepb.ColorThemeUserSetting_value[valueString]),
}
} else {
return nil, errors.New("invalid user setting key")
}
userSettingList = append(userSettingList, userSetting)
}
if err := rows.Err(); err != nil {
return nil, err
}
return userSettingList, nil
}
func vacuumUserSetting(ctx context.Context, tx *sql.Tx) error {
stmt := `DELETE FROM user_setting WHERE user_id NOT IN (SELECT id FROM "user")`
_, err := tx.ExecContext(ctx, stmt)
if err != nil {
return err
}
return nil
}

View File

@ -0,0 +1,116 @@
package postgres
import (
"context"
"errors"
"strconv"
"strings"
"google.golang.org/protobuf/encoding/protojson"
storepb "github.com/yourselfhosted/slash/proto/gen/store"
"github.com/yourselfhosted/slash/store"
)
func (d *DB) UpsertWorkspaceSetting(ctx context.Context, upsert *storepb.WorkspaceSetting) (*storepb.WorkspaceSetting, error) {
stmt := `
INSERT INTO workspace_setting (
key,
value
)
VALUES ($1, $2)
ON CONFLICT(key) DO UPDATE
SET value = EXCLUDED.value
`
var valueString string
if upsert.Key == storepb.WorkspaceSettingKey_WORKSPACE_SETTING_LICENSE_KEY {
valueString = upsert.GetLicenseKey()
} else if upsert.Key == storepb.WorkspaceSettingKey_WORKSPACE_SETTING_SECRET_SESSION {
valueString = upsert.GetSecretSession()
} else if upsert.Key == storepb.WorkspaceSettingKey_WORKSAPCE_SETTING_ENABLE_SIGNUP {
valueString = strconv.FormatBool(upsert.GetEnableSignup())
} else if upsert.Key == storepb.WorkspaceSettingKey_WORKSPACE_SETTING_CUSTOM_STYLE {
valueString = upsert.GetCustomStyle()
} else if upsert.Key == storepb.WorkspaceSettingKey_WORKSPACE_SETTING_CUSTOM_SCRIPT {
valueString = upsert.GetCustomScript()
} else if upsert.Key == storepb.WorkspaceSettingKey_WORKSPACE_SETTING_AUTO_BACKUP {
valueBytes, err := protojson.Marshal(upsert.GetAutoBackup())
if err != nil {
return nil, err
}
valueString = string(valueBytes)
} else {
return nil, errors.New("invalid workspace setting key")
}
if _, err := d.db.ExecContext(ctx, stmt, upsert.Key.String(), valueString); err != nil {
return nil, err
}
workspaceSetting := upsert
return workspaceSetting, nil
}
func (d *DB) ListWorkspaceSettings(ctx context.Context, find *store.FindWorkspaceSetting) ([]*storepb.WorkspaceSetting, error) {
where, args := []string{"1 = 1"}, []interface{}{}
if find.Key != storepb.WorkspaceSettingKey_WORKSPACE_SETTING_KEY_UNSPECIFIED {
where, args = append(where, "key = $"+placeholder(len(args)+1)), append(args, find.Key.String())
}
query := `
SELECT
key,
value
FROM workspace_setting
WHERE ` + strings.Join(where, " AND ")
rows, err := d.db.QueryContext(ctx, query, args...)
if err != nil {
return nil, err
}
defer rows.Close()
list := []*storepb.WorkspaceSetting{}
for rows.Next() {
workspaceSetting := &storepb.WorkspaceSetting{}
var keyString, valueString string
if err := rows.Scan(
&keyString,
&valueString,
); err != nil {
return nil, err
}
workspaceSetting.Key = storepb.WorkspaceSettingKey(storepb.WorkspaceSettingKey_value[keyString])
if workspaceSetting.Key == storepb.WorkspaceSettingKey_WORKSPACE_SETTING_LICENSE_KEY {
workspaceSetting.Value = &storepb.WorkspaceSetting_LicenseKey{LicenseKey: valueString}
} else if workspaceSetting.Key == storepb.WorkspaceSettingKey_WORKSPACE_SETTING_SECRET_SESSION {
workspaceSetting.Value = &storepb.WorkspaceSetting_SecretSession{SecretSession: valueString}
} else if workspaceSetting.Key == storepb.WorkspaceSettingKey_WORKSAPCE_SETTING_ENABLE_SIGNUP {
enableSignup, err := strconv.ParseBool(valueString)
if err != nil {
return nil, err
}
workspaceSetting.Value = &storepb.WorkspaceSetting_EnableSignup{EnableSignup: enableSignup}
} else if workspaceSetting.Key == storepb.WorkspaceSettingKey_WORKSPACE_SETTING_CUSTOM_STYLE {
workspaceSetting.Value = &storepb.WorkspaceSetting_CustomStyle{CustomStyle: valueString}
} else if workspaceSetting.Key == storepb.WorkspaceSettingKey_WORKSPACE_SETTING_CUSTOM_SCRIPT {
workspaceSetting.Value = &storepb.WorkspaceSetting_CustomScript{CustomScript: valueString}
} else if workspaceSetting.Key == storepb.WorkspaceSettingKey_WORKSPACE_SETTING_AUTO_BACKUP {
autoBackupSetting := &storepb.AutoBackupWorkspaceSetting{}
if err := protojson.Unmarshal([]byte(valueString), autoBackupSetting); err != nil {
return nil, err
}
workspaceSetting.Value = &storepb.WorkspaceSetting_AutoBackup{AutoBackup: autoBackupSetting}
} else {
continue
}
list = append(list, workspaceSetting)
}
if err := rows.Err(); err != nil {
return nil, err
}
return list, nil
}