chore: fix postgres driver

This commit is contained in:
Steven 2023-12-17 20:07:25 +08:00
parent a7d48e8059
commit 4c66edc170
18 changed files with 104 additions and 200 deletions

View File

@ -2,7 +2,6 @@ package postgres
import ( import (
"context" "context"
"fmt"
"strings" "strings"
"github.com/yourselfhosted/slash/store" "github.com/yourselfhosted/slash/store"
@ -38,10 +37,10 @@ func (d *DB) CreateActivity(ctx context.Context, create *store.Activity) (*store
func (d *DB) ListActivities(ctx context.Context, find *store.FindActivity) ([]*store.Activity, error) { func (d *DB) ListActivities(ctx context.Context, find *store.FindActivity) ([]*store.Activity, error) {
where, args := []string{"1 = 1"}, []any{} where, args := []string{"1 = 1"}, []any{}
if find.Type != "" { if find.Type != "" {
where, args = append(where, "type = $"+fmt.Sprint(len(args)+1)), append(args, find.Type.String()) where, args = append(where, "type = "+placeholder(len(args)+1)), append(args, find.Type.String())
} }
if find.Level != "" { if find.Level != "" {
where, args = append(where, "level = $"+fmt.Sprint(len(args)+1)), append(args, find.Level.String()) where, args = append(where, "level = "+placeholder(len(args)+1)), append(args, find.Level.String())
} }
if find.Where != nil { if find.Where != nil {
where = append(where, find.Where...) where = append(where, find.Where...)

View File

@ -6,23 +6,20 @@ import (
"fmt" "fmt"
"strings" "strings"
"github.com/lib/pq"
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/yourselfhosted/slash/internal/util"
storepb "github.com/yourselfhosted/slash/proto/gen/store" storepb "github.com/yourselfhosted/slash/proto/gen/store"
"github.com/yourselfhosted/slash/store" "github.com/yourselfhosted/slash/store"
) )
func (d *DB) CreateCollection(ctx context.Context, create *storepb.Collection) (*storepb.Collection, error) { func (d *DB) CreateCollection(ctx context.Context, create *storepb.Collection) (*storepb.Collection, error) {
set := []string{"creator_id", "name", "title", "description", "shortcut_ids", "visibility"} set := []string{"creator_id", "name", "title", "description", "shortcut_ids", "visibility"}
args := []any{create.CreatorId, create.Name, create.Title, create.Description, strings.Trim(strings.Join(strings.Fields(fmt.Sprint(create.ShortcutIds)), ","), "[]"), create.Visibility.String()} args := []any{create.CreatorId, create.Name, create.Title, create.Description, pq.Array(create.ShortcutIds), create.Visibility.String()}
placeholder := []string{"$1", "$2", "$3", "$4", "$5", "$6"}
stmt := ` stmt := `
INSERT INTO collection ( INSERT INTO collection (` + strings.Join(set, ", ") + `)
` + strings.Join(set, ", ") + ` VALUES (` + placeholders(len(args)) + `)
)
VALUES (` + strings.Join(placeholder, ",") + `)
RETURNING id, created_ts, updated_ts RETURNING id, created_ts, updated_ts
` `
if err := d.db.QueryRowContext(ctx, stmt, args...).Scan( if err := d.db.QueryRowContext(ctx, stmt, args...).Scan(
@ -39,35 +36,34 @@ func (d *DB) CreateCollection(ctx context.Context, create *storepb.Collection) (
func (d *DB) UpdateCollection(ctx context.Context, update *store.UpdateCollection) (*storepb.Collection, error) { func (d *DB) UpdateCollection(ctx context.Context, update *store.UpdateCollection) (*storepb.Collection, error) {
set, args := []string{}, []any{} set, args := []string{}, []any{}
if update.Name != nil { if update.Name != nil {
set, args = append(set, "name = $1"), append(args, *update.Name) set, args = append(set, "name = "+placeholder(len(args)+1)), append(args, *update.Name)
} }
if update.Title != nil { if update.Title != nil {
set, args = append(set, "title = $2"), append(args, *update.Title) set, args = append(set, "title = "+placeholder(len(args)+1)), append(args, *update.Title)
} }
if update.Description != nil { if update.Description != nil {
set, args = append(set, "description = $3"), append(args, *update.Description) set, args = append(set, "description = "+placeholder(len(args)+1)), append(args, *update.Description)
} }
if update.ShortcutIDs != nil { if update.ShortcutIDs != nil {
set, args = append(set, "shortcut_ids = $4"), append(args, strings.Trim(strings.Join(strings.Fields(fmt.Sprint(update.ShortcutIDs)), ","), "[]")) set, args = append(set, "shortcut_ids = "+placeholder(len(args)+1)), append(args, pq.Array(update.ShortcutIDs))
} }
if update.Visibility != nil { if update.Visibility != nil {
set, args = append(set, "visibility = $5"), append(args, update.Visibility.String()) set, args = append(set, "visibility = "+placeholder(len(args)+1)), append(args, update.Visibility.String())
} }
if len(set) == 0 { if len(set) == 0 {
return nil, errors.New("no update specified") return nil, errors.New("no update specified")
} }
args = append(args, update.ID)
stmt := ` stmt := `
UPDATE collection UPDATE collection
SET SET ` + strings.Join(set, ", ") + `
` + strings.Join(set, ", ") + ` WHERE id = ` + placeholder(len(args)+1) + `
WHERE
id = $6
RETURNING id, creator_id, created_ts, updated_ts, name, title, description, shortcut_ids, visibility RETURNING id, creator_id, created_ts, updated_ts, name, title, description, shortcut_ids, visibility
` `
args = append(args, update.ID)
collection := &storepb.Collection{} collection := &storepb.Collection{}
var shortcutIDs, visibility string var shortcutIDs []sql.NullInt32
var visibility string
if err := d.db.QueryRowContext(ctx, stmt, args...).Scan( if err := d.db.QueryRowContext(ctx, stmt, args...).Scan(
&collection.Id, &collection.Id,
&collection.CreatorId, &collection.CreatorId,
@ -76,20 +72,16 @@ func (d *DB) UpdateCollection(ctx context.Context, update *store.UpdateCollectio
&collection.Name, &collection.Name,
&collection.Title, &collection.Title,
&collection.Description, &collection.Description,
&shortcutIDs, pq.Array(&shortcutIDs),
&visibility, &visibility,
); err != nil { ); err != nil {
return nil, err return nil, err
} }
collection.ShortcutIds = []int32{} collection.ShortcutIds = []int32{}
if shortcutIDs != "" { for _, id := range shortcutIDs {
for _, idStr := range strings.Split(shortcutIDs, ",") { if id.Valid {
shortcutID, err := util.ConvertStringToInt32(idStr) collection.ShortcutIds = append(collection.ShortcutIds, id.Int32)
if err != nil {
return nil, errors.Wrap(err, "failed to convert shortcut id")
}
collection.ShortcutIds = append(collection.ShortcutIds, shortcutID)
} }
} }
collection.Visibility = convertVisibilityStringToStorepb(visibility) collection.Visibility = convertVisibilityStringToStorepb(visibility)
@ -99,19 +91,18 @@ func (d *DB) UpdateCollection(ctx context.Context, update *store.UpdateCollectio
func (d *DB) ListCollections(ctx context.Context, find *store.FindCollection) ([]*storepb.Collection, error) { func (d *DB) ListCollections(ctx context.Context, find *store.FindCollection) ([]*storepb.Collection, error) {
where, args := []string{"1 = 1"}, []any{} where, args := []string{"1 = 1"}, []any{}
if v := find.ID; v != nil { if v := find.ID; v != nil {
where, args = append(where, "id = $1"), append(args, *v) where, args = append(where, "id = "+placeholder(len(args)+1)), append(args, *v)
} }
if v := find.CreatorID; v != nil { if v := find.CreatorID; v != nil {
where, args = append(where, "creator_id = $2"), append(args, *v) where, args = append(where, "creator_id = "+placeholder(len(args)+1)), append(args, *v)
} }
if v := find.Name; v != nil { if v := find.Name; v != nil {
where, args = append(where, "name = $3"), append(args, *v) where, args = append(where, "name = "+placeholder(len(args)+1)), append(args, *v)
} }
if v := find.VisibilityList; len(v) != 0 { if v := find.VisibilityList; len(v) != 0 {
list := []string{} list := []string{}
for i, visibility := range v { for _, visibility := range v {
list = append(list, fmt.Sprintf("$%d", len(args)+i+1)) list, args = append(list, placeholder(len(args)+1)), append(args, visibility)
args = append(args, visibility)
} }
where = append(where, fmt.Sprintf("visibility IN (%s)", strings.Join(list, ","))) where = append(where, fmt.Sprintf("visibility IN (%s)", strings.Join(list, ",")))
} }
@ -140,7 +131,8 @@ func (d *DB) ListCollections(ctx context.Context, find *store.FindCollection) ([
list := make([]*storepb.Collection, 0) list := make([]*storepb.Collection, 0)
for rows.Next() { for rows.Next() {
collection := &storepb.Collection{} collection := &storepb.Collection{}
var shortcutIDs, visibility string var shortcutIDs []sql.NullInt32
var visibility string
if err := rows.Scan( if err := rows.Scan(
&collection.Id, &collection.Id,
&collection.CreatorId, &collection.CreatorId,
@ -149,20 +141,16 @@ func (d *DB) ListCollections(ctx context.Context, find *store.FindCollection) ([
&collection.Name, &collection.Name,
&collection.Title, &collection.Title,
&collection.Description, &collection.Description,
&shortcutIDs, pq.Array(&shortcutIDs),
&visibility, &visibility,
); err != nil { ); err != nil {
return nil, err return nil, err
} }
collection.ShortcutIds = []int32{} collection.ShortcutIds = []int32{}
if shortcutIDs != "" { for _, id := range shortcutIDs {
for _, idStr := range strings.Split(shortcutIDs, ",") { if id.Valid {
shortcutID, err := util.ConvertStringToInt32(idStr) collection.ShortcutIds = append(collection.ShortcutIds, id.Int32)
if err != nil {
return nil, errors.Wrap(err, "failed to convert shortcut id")
}
collection.ShortcutIds = append(collection.ShortcutIds, shortcutID)
} }
} }
collection.Visibility = storepb.Visibility(storepb.Visibility_value[visibility]) collection.Visibility = storepb.Visibility(storepb.Visibility_value[visibility])
@ -182,13 +170,3 @@ func (d *DB) DeleteCollection(ctx context.Context, delete *store.DeleteCollectio
return nil return nil
} }
func vacuumCollection(ctx context.Context, tx *sql.Tx) error {
stmt := `DELETE FROM collection WHERE creator_id NOT IN (SELECT id FROM user)`
_, err := tx.ExecContext(ctx, stmt)
if err != nil {
return err
}
return nil
}

View File

@ -2,7 +2,6 @@ package postgres
import ( import (
"context" "context"
"database/sql"
"fmt" "fmt"
"strings" "strings"
@ -17,9 +16,7 @@ func (d *DB) CreateMemo(ctx context.Context, create *storepb.Memo) (*storepb.Mem
args := []any{create.CreatorId, create.Name, create.Title, create.Content, create.Visibility.String(), strings.Join(create.Tags, " ")} args := []any{create.CreatorId, create.Name, create.Title, create.Content, create.Visibility.String(), strings.Join(create.Tags, " ")}
stmt := ` stmt := `
INSERT INTO memo ( INSERT INTO memo (` + strings.Join(set, ", ") + `)
` + strings.Join(set, ", ") + `
)
VALUES (` + placeholders(len(args)) + `) VALUES (` + placeholders(len(args)) + `)
RETURNING id, created_ts, updated_ts, row_status RETURNING id, created_ts, updated_ts, row_status
` `
@ -41,43 +38,34 @@ func (d *DB) CreateMemo(ctx context.Context, create *storepb.Memo) (*storepb.Mem
func (d *DB) UpdateMemo(ctx context.Context, update *store.UpdateMemo) (*storepb.Memo, error) { func (d *DB) UpdateMemo(ctx context.Context, update *store.UpdateMemo) (*storepb.Memo, error) {
set, args := []string{}, []any{} set, args := []string{}, []any{}
if update.RowStatus != nil { if update.RowStatus != nil {
set = append(set, fmt.Sprintf("row_status = $%d", len(set)+1)) set, args = append(set, "row_status = "+placeholder(len(args)+1)), append(args, update.RowStatus.String())
args = append(args, update.RowStatus.String())
} }
if update.Name != nil { if update.Name != nil {
set = append(set, fmt.Sprintf("name = $%d", len(set)+1)) set, args = append(set, "name = "+placeholder(len(args)+1)), append(args, *update.Name)
args = append(args, *update.Name)
} }
if update.Title != nil { if update.Title != nil {
set = append(set, fmt.Sprintf("title = $%d", len(set)+1)) set, args = append(set, "title = "+placeholder(len(args)+1)), append(args, *update.Title)
args = append(args, *update.Title)
} }
if update.Content != nil { if update.Content != nil {
set = append(set, fmt.Sprintf("content = $%d", len(set)+1)) set, args = append(set, "content = "+placeholder(len(args)+1)), append(args, *update.Content)
args = append(args, *update.Content)
} }
if update.Visibility != nil { if update.Visibility != nil {
set = append(set, fmt.Sprintf("visibility = $%d", len(set)+1)) set, args = append(set, "visibility = "+placeholder(len(args)+1)), append(args, update.Visibility.String())
args = append(args, update.Visibility.String())
} }
if update.Tag != nil { if update.Tag != nil {
set = append(set, fmt.Sprintf("tag = $%d", len(set)+1)) set, args = append(set, "tag = "+placeholder(len(args)+1)), append(args, *update.Tag)
args = append(args, *update.Tag)
} }
if len(set) == 0 { if len(set) == 0 {
return nil, errors.New("no update specified") return nil, errors.New("no update specified")
} }
args = append(args, update.ID)
stmt := ` stmt := `
UPDATE memo UPDATE memo
SET SET ` + strings.Join(set, ", ") + `
` + strings.Join(set, ", ") + ` WHERE id = ` + placeholder(len(args)+1) + `
WHERE
id = $` + fmt.Sprint(len(set)+1) + `
RETURNING id, creator_id, created_ts, updated_ts, row_status, name, title, content, visibility, tag RETURNING id, creator_id, created_ts, updated_ts, row_status, name, title, content, visibility, tag
` `
args = append(args, update.ID)
memo := &storepb.Memo{} memo := &storepb.Memo{}
var rowStatus, visibility, tags string var rowStatus, visibility, tags string
if err := d.db.QueryRowContext(ctx, stmt, args...).Scan( if err := d.db.QueryRowContext(ctx, stmt, args...).Scan(
@ -103,27 +91,26 @@ func (d *DB) UpdateMemo(ctx context.Context, update *store.UpdateMemo) (*storepb
func (d *DB) ListMemos(ctx context.Context, find *store.FindMemo) ([]*storepb.Memo, error) { func (d *DB) ListMemos(ctx context.Context, find *store.FindMemo) ([]*storepb.Memo, error) {
where, args := []string{"1 = 1"}, []any{} where, args := []string{"1 = 1"}, []any{}
if v := find.ID; v != nil { if v := find.ID; v != nil {
where, args = append(where, "id = $1"), append(args, *v) where, args = append(where, "id = "+placeholder(len(args)+1)), append(args, *v)
} }
if v := find.CreatorID; v != nil { if v := find.CreatorID; v != nil {
where, args = append(where, "creator_id = $2"), append(args, *v) where, args = append(where, "creator_id = "+placeholder(len(args)+1)), append(args, *v)
} }
if v := find.RowStatus; v != nil { if v := find.RowStatus; v != nil {
where, args = append(where, "row_status = $3"), append(args, *v) where, args = append(where, "row_status = "+placeholder(len(args)+1)), append(args, *v)
} }
if v := find.Name; v != nil { if v := find.Name; v != nil {
where, args = append(where, "name = $4"), append(args, *v) where, args = append(where, "name = "+placeholder(len(args)+1)), append(args, *v)
} }
if v := find.VisibilityList; len(v) != 0 { if v := find.VisibilityList; len(v) != 0 {
list := []string{} list := []string{}
for i, visibility := range v { for _, visibility := range v {
list = append(list, fmt.Sprintf("$%d", len(args)+i+1)) list, args = append(list, placeholder(len(args)+1)), append(args, visibility)
args = append(args, visibility)
} }
where = append(where, fmt.Sprintf("visibility IN (%s)", strings.Join(list, ","))) where = append(where, fmt.Sprintf("visibility IN (%s)", strings.Join(list, ",")))
} }
if v := find.Tag; v != nil { if v := find.Tag; v != nil {
where, args = append(where, "tag LIKE $"+fmt.Sprint(len(args)+1)), append(args, "%"+*v+"%") where, args = append(where, "tag LIKE "+placeholder(len(args)+1)), append(args, "%"+*v+"%")
} }
rows, err := d.db.QueryContext(ctx, ` rows, err := d.db.QueryContext(ctx, `
@ -185,24 +172,10 @@ func (d *DB) DeleteMemo(ctx context.Context, delete *store.DeleteMemo) error {
return nil return nil
} }
func vacuumMemo(ctx context.Context, tx *sql.Tx) error {
stmt := `DELETE FROM memo WHERE creator_id NOT IN (SELECT id FROM user)`
_, err := tx.ExecContext(ctx, stmt)
if err != nil {
return err
}
return nil
}
func placeholders(n int) string { func placeholders(n int) string {
placeholder := "" list := []string{}
for i := 0; i < n; i++ { for i := 0; i < n; i++ {
if i == 0 { list = append(list, fmt.Sprintf("$%d", i+1))
placeholder = fmt.Sprintf("$%d", i+1)
} else {
placeholder = fmt.Sprintf("%s, $%d", placeholder, i+1)
} }
} return strings.Join(list, ", ")
return placeholder
} }

View File

@ -1,3 +1,13 @@
-- drop all tables first (PostgreSQL style)
DROP TABLE IF EXISTS migration_history CASCADE;
DROP TABLE IF EXISTS workspace_setting CASCADE;
DROP TABLE IF EXISTS "user" CASCADE;
DROP TABLE IF EXISTS user_setting CASCADE;
DROP TABLE IF EXISTS shortcut CASCADE;
DROP TABLE IF EXISTS activity CASCADE;
DROP TABLE IF EXISTS collection CASCADE;
DROP TABLE IF EXISTS memo CASCADE;
-- migration_history -- migration_history
CREATE TABLE migration_history ( CREATE TABLE migration_history (
version TEXT NOT NULL PRIMARY KEY, version TEXT NOT NULL PRIMARY KEY,
@ -11,7 +21,7 @@ CREATE TABLE workspace_setting (
); );
-- user -- user
CREATE TABLE user ( CREATE TABLE "user" (
id SERIAL PRIMARY KEY, id SERIAL PRIMARY KEY,
created_ts BIGINT NOT NULL DEFAULT EXTRACT(EPOCH FROM NOW()), created_ts BIGINT NOT NULL DEFAULT EXTRACT(EPOCH FROM NOW()),
updated_ts BIGINT NOT NULL DEFAULT EXTRACT(EPOCH FROM NOW()), updated_ts BIGINT NOT NULL DEFAULT EXTRACT(EPOCH FROM NOW()),
@ -22,11 +32,11 @@ CREATE TABLE user (
role TEXT NOT NULL CHECK (role IN ('ADMIN', 'USER')) DEFAULT 'USER' role TEXT NOT NULL CHECK (role IN ('ADMIN', 'USER')) DEFAULT 'USER'
); );
CREATE INDEX idx_user_email ON user(email); CREATE INDEX idx_user_email ON "user"(email);
-- user_setting -- user_setting
CREATE TABLE user_setting ( CREATE TABLE user_setting (
user_id INTEGER REFERENCES user(id) NOT NULL, user_id INTEGER REFERENCES "user"(id) NOT NULL,
key TEXT NOT NULL, key TEXT NOT NULL,
value TEXT NOT NULL, value TEXT NOT NULL,
PRIMARY KEY (user_id, key) PRIMARY KEY (user_id, key)
@ -35,7 +45,7 @@ CREATE TABLE user_setting (
-- shortcut -- shortcut
CREATE TABLE shortcut ( CREATE TABLE shortcut (
id SERIAL PRIMARY KEY, id SERIAL PRIMARY KEY,
creator_id INTEGER REFERENCES user(id) NOT NULL, creator_id INTEGER REFERENCES "user"(id) NOT NULL,
created_ts BIGINT NOT NULL DEFAULT EXTRACT(EPOCH FROM NOW()), created_ts BIGINT NOT NULL DEFAULT EXTRACT(EPOCH FROM NOW()),
updated_ts BIGINT NOT NULL DEFAULT EXTRACT(EPOCH FROM NOW()), updated_ts BIGINT NOT NULL DEFAULT EXTRACT(EPOCH FROM NOW()),
row_status TEXT NOT NULL CHECK (row_status IN ('NORMAL', 'ARCHIVED')) DEFAULT 'NORMAL', row_status TEXT NOT NULL CHECK (row_status IN ('NORMAL', 'ARCHIVED')) DEFAULT 'NORMAL',
@ -53,7 +63,7 @@ CREATE INDEX idx_shortcut_name ON shortcut(name);
-- activity -- activity
CREATE TABLE activity ( CREATE TABLE activity (
id SERIAL PRIMARY KEY, id SERIAL PRIMARY KEY,
creator_id INTEGER REFERENCES user(id) NOT NULL, creator_id INTEGER REFERENCES "user"(id) NOT NULL,
created_ts BIGINT NOT NULL DEFAULT EXTRACT(EPOCH FROM NOW()), created_ts BIGINT NOT NULL DEFAULT EXTRACT(EPOCH FROM NOW()),
type TEXT NOT NULL DEFAULT '', type TEXT NOT NULL DEFAULT '',
level TEXT NOT NULL CHECK (level IN ('INFO', 'WARN', 'ERROR')) DEFAULT 'INFO', level TEXT NOT NULL CHECK (level IN ('INFO', 'WARN', 'ERROR')) DEFAULT 'INFO',
@ -63,7 +73,7 @@ CREATE TABLE activity (
-- collection -- collection
CREATE TABLE collection ( CREATE TABLE collection (
id SERIAL PRIMARY KEY, id SERIAL PRIMARY KEY,
creator_id INTEGER REFERENCES user(id) NOT NULL, creator_id INTEGER REFERENCES "user"(id) NOT NULL,
created_ts BIGINT NOT NULL DEFAULT EXTRACT(EPOCH FROM NOW()), created_ts BIGINT NOT NULL DEFAULT EXTRACT(EPOCH FROM NOW()),
updated_ts BIGINT NOT NULL DEFAULT EXTRACT(EPOCH FROM NOW()), updated_ts BIGINT NOT NULL DEFAULT EXTRACT(EPOCH FROM NOW()),
name TEXT NOT NULL UNIQUE, name TEXT NOT NULL UNIQUE,
@ -78,7 +88,7 @@ CREATE INDEX idx_collection_name ON collection(name);
-- memo -- memo
CREATE TABLE memo ( CREATE TABLE memo (
id SERIAL PRIMARY KEY, id SERIAL PRIMARY KEY,
creator_id INTEGER REFERENCES user(id) NOT NULL, creator_id INTEGER REFERENCES "user"(id) NOT NULL,
created_ts BIGINT NOT NULL DEFAULT EXTRACT(EPOCH FROM NOW()), created_ts BIGINT NOT NULL DEFAULT EXTRACT(EPOCH FROM NOW()),
updated_ts BIGINT NOT NULL DEFAULT EXTRACT(EPOCH FROM NOW()), updated_ts BIGINT NOT NULL DEFAULT EXTRACT(EPOCH FROM NOW()),
row_status TEXT NOT NULL CHECK (row_status IN ('NORMAL', 'ARCHIVED')) DEFAULT 'NORMAL', row_status TEXT NOT NULL CHECK (row_status IN ('NORMAL', 'ARCHIVED')) DEFAULT 'NORMAL',

View File

@ -11,7 +11,7 @@ CREATE TABLE workspace_setting (
); );
-- user -- user
CREATE TABLE user ( CREATE TABLE "user" (
id SERIAL PRIMARY KEY, id SERIAL PRIMARY KEY,
created_ts BIGINT NOT NULL DEFAULT EXTRACT(EPOCH FROM NOW()), created_ts BIGINT NOT NULL DEFAULT EXTRACT(EPOCH FROM NOW()),
updated_ts BIGINT NOT NULL DEFAULT EXTRACT(EPOCH FROM NOW()), updated_ts BIGINT NOT NULL DEFAULT EXTRACT(EPOCH FROM NOW()),
@ -22,11 +22,11 @@ CREATE TABLE user (
role TEXT NOT NULL CHECK (role IN ('ADMIN', 'USER')) DEFAULT 'USER' role TEXT NOT NULL CHECK (role IN ('ADMIN', 'USER')) DEFAULT 'USER'
); );
CREATE INDEX idx_user_email ON user(email); CREATE INDEX idx_user_email ON "user"(email);
-- user_setting -- user_setting
CREATE TABLE user_setting ( CREATE TABLE user_setting (
user_id INTEGER REFERENCES user(id) NOT NULL, user_id INTEGER REFERENCES "user"(id) NOT NULL,
key TEXT NOT NULL, key TEXT NOT NULL,
value TEXT NOT NULL, value TEXT NOT NULL,
PRIMARY KEY (user_id, key) PRIMARY KEY (user_id, key)
@ -35,7 +35,7 @@ CREATE TABLE user_setting (
-- shortcut -- shortcut
CREATE TABLE shortcut ( CREATE TABLE shortcut (
id SERIAL PRIMARY KEY, id SERIAL PRIMARY KEY,
creator_id INTEGER REFERENCES user(id) NOT NULL, creator_id INTEGER REFERENCES "user"(id) NOT NULL,
created_ts BIGINT NOT NULL DEFAULT EXTRACT(EPOCH FROM NOW()), created_ts BIGINT NOT NULL DEFAULT EXTRACT(EPOCH FROM NOW()),
updated_ts BIGINT NOT NULL DEFAULT EXTRACT(EPOCH FROM NOW()), updated_ts BIGINT NOT NULL DEFAULT EXTRACT(EPOCH FROM NOW()),
row_status TEXT NOT NULL CHECK (row_status IN ('NORMAL', 'ARCHIVED')) DEFAULT 'NORMAL', row_status TEXT NOT NULL CHECK (row_status IN ('NORMAL', 'ARCHIVED')) DEFAULT 'NORMAL',
@ -53,7 +53,7 @@ CREATE INDEX idx_shortcut_name ON shortcut(name);
-- activity -- activity
CREATE TABLE activity ( CREATE TABLE activity (
id SERIAL PRIMARY KEY, id SERIAL PRIMARY KEY,
creator_id INTEGER REFERENCES user(id) NOT NULL, creator_id INTEGER REFERENCES "user"(id) NOT NULL,
created_ts BIGINT NOT NULL DEFAULT EXTRACT(EPOCH FROM NOW()), created_ts BIGINT NOT NULL DEFAULT EXTRACT(EPOCH FROM NOW()),
type TEXT NOT NULL DEFAULT '', type TEXT NOT NULL DEFAULT '',
level TEXT NOT NULL CHECK (level IN ('INFO', 'WARN', 'ERROR')) DEFAULT 'INFO', level TEXT NOT NULL CHECK (level IN ('INFO', 'WARN', 'ERROR')) DEFAULT 'INFO',
@ -63,7 +63,7 @@ CREATE TABLE activity (
-- collection -- collection
CREATE TABLE collection ( CREATE TABLE collection (
id SERIAL PRIMARY KEY, id SERIAL PRIMARY KEY,
creator_id INTEGER REFERENCES user(id) NOT NULL, creator_id INTEGER REFERENCES "user"(id) NOT NULL,
created_ts BIGINT NOT NULL DEFAULT EXTRACT(EPOCH FROM NOW()), created_ts BIGINT NOT NULL DEFAULT EXTRACT(EPOCH FROM NOW()),
updated_ts BIGINT NOT NULL DEFAULT EXTRACT(EPOCH FROM NOW()), updated_ts BIGINT NOT NULL DEFAULT EXTRACT(EPOCH FROM NOW()),
name TEXT NOT NULL UNIQUE, name TEXT NOT NULL UNIQUE,
@ -78,7 +78,7 @@ CREATE INDEX idx_collection_name ON collection(name);
-- memo -- memo
CREATE TABLE memo ( CREATE TABLE memo (
id SERIAL PRIMARY KEY, id SERIAL PRIMARY KEY,
creator_id INTEGER REFERENCES user(id) NOT NULL, creator_id INTEGER REFERENCES "user"(id) NOT NULL,
created_ts BIGINT NOT NULL DEFAULT EXTRACT(EPOCH FROM NOW()), created_ts BIGINT NOT NULL DEFAULT EXTRACT(EPOCH FROM NOW()),
updated_ts BIGINT NOT NULL DEFAULT EXTRACT(EPOCH FROM NOW()), updated_ts BIGINT NOT NULL DEFAULT EXTRACT(EPOCH FROM NOW()),
row_status TEXT NOT NULL CHECK (row_status IN ('NORMAL', 'ARCHIVED')) DEFAULT 'NORMAL', row_status TEXT NOT NULL CHECK (row_status IN ('NORMAL', 'ARCHIVED')) DEFAULT 'NORMAL',

View File

@ -108,7 +108,7 @@ func (d *DB) Migrate(ctx context.Context) error {
} }
// In demo mode, we should seed the database. // In demo mode, we should seed the database.
if d.profile.Mode == "demo" { if d.profile.Mode == "demo" {
if err := d.seed(ctx); err != nil { if err := d.Seed(ctx); err != nil {
return errors.Wrap(err, "failed to seed") return errors.Wrap(err, "failed to seed")
} }
} }
@ -119,7 +119,7 @@ func (d *DB) Migrate(ctx context.Context) error {
} }
const ( const (
latestSchemaFileName = "LATEST__SCHEMA.sql" latestSchemaFileName = "LATEST.sql"
) )
func (d *DB) applyLatestSchema(ctx context.Context) error { func (d *DB) applyLatestSchema(ctx context.Context) error {
@ -172,7 +172,7 @@ func (d *DB) applyMigrationForMinorVersion(ctx context.Context, minorVersion str
return nil return nil
} }
func (d *DB) seed(ctx context.Context) error { func (d *DB) Seed(ctx context.Context) error {
filenames, err := fs.Glob(seedFS, fmt.Sprintf("%s/*.sql", "seed")) filenames, err := fs.Glob(seedFS, fmt.Sprintf("%s/*.sql", "seed"))
if err != nil { if err != nil {
return errors.Wrap(err, "failed to read seed files") return errors.Wrap(err, "failed to read seed files")

View File

@ -1,9 +0,0 @@
DELETE FROM activity;
DELETE FROM shortcut;
DELETE FROM user_setting;
DELETE FROM user;
DELETE FROM workspace_setting;

View File

@ -2,7 +2,6 @@ package postgres
import ( import (
"context" "context"
"database/sql"
"fmt" "fmt"
"strings" "strings"
@ -207,12 +206,6 @@ func (d *DB) DeleteShortcut(ctx context.Context, delete *store.DeleteShortcut) e
return err return err
} }
func vacuumShortcut(ctx context.Context, tx *sql.Tx) error {
stmt := `DELETE FROM shortcut WHERE creator_id NOT IN (SELECT id FROM "user")`
_, err := tx.ExecContext(ctx, stmt)
return err
}
func filterTags(tags []string) []string { func filterTags(tags []string) []string {
result := []string{} result := []string{}
for _, tag := range tags { for _, tag := range tags {

View File

@ -41,21 +41,20 @@ func (d *DB) CreateUser(ctx context.Context, create *store.User) (*store.User, e
func (d *DB) UpdateUser(ctx context.Context, update *store.UpdateUser) (*store.User, error) { func (d *DB) UpdateUser(ctx context.Context, update *store.UpdateUser) (*store.User, error) {
set, args := []string{}, []any{} set, args := []string{}, []any{}
if v := update.RowStatus; v != nil { if v := update.RowStatus; v != nil {
set, args = append(set, "row_status = $"+placeholder(len(args)+1)), append(args, *v) set, args = append(set, "row_status = "+placeholder(len(args)+1)), append(args, *v)
} }
if v := update.Email; v != nil { if v := update.Email; v != nil {
set, args = append(set, "email = $"+placeholder(len(args)+1)), append(args, *v) set, args = append(set, "email = "+placeholder(len(args)+1)), append(args, *v)
} }
if v := update.Nickname; v != nil { if v := update.Nickname; v != nil {
set, args = append(set, "nickname = $"+placeholder(len(args)+1)), append(args, *v) set, args = append(set, "nickname = "+placeholder(len(args)+1)), append(args, *v)
} }
if v := update.PasswordHash; v != nil { if v := update.PasswordHash; v != nil {
set, args = append(set, "password_hash = $"+placeholder(len(args)+1)), append(args, *v) set, args = append(set, "password_hash = "+placeholder(len(args)+1)), append(args, *v)
} }
if v := update.Role; v != nil { if v := update.Role; v != nil {
set, args = append(set, "role = $"+placeholder(len(args)+1)), append(args, *v) set, args = append(set, "role = "+placeholder(len(args)+1)), append(args, *v)
} }
if len(set) == 0 { if len(set) == 0 {
return nil, errors.New("no fields to update") return nil, errors.New("no fields to update")
} }
@ -63,7 +62,7 @@ func (d *DB) UpdateUser(ctx context.Context, update *store.UpdateUser) (*store.U
stmt := ` stmt := `
UPDATE "user" UPDATE "user"
SET ` + strings.Join(set, ", ") + ` SET ` + strings.Join(set, ", ") + `
WHERE id = $` + placeholder(len(args)+1) + ` WHERE id = ` + placeholder(len(args)+1) + `
RETURNING id, created_ts, updated_ts, row_status, email, nickname, password_hash, role RETURNING id, created_ts, updated_ts, row_status, email, nickname, password_hash, role
` `
args = append(args, update.ID) args = append(args, update.ID)
@ -88,19 +87,19 @@ func (d *DB) ListUsers(ctx context.Context, find *store.FindUser) ([]*store.User
where, args := []string{"1 = 1"}, []any{} where, args := []string{"1 = 1"}, []any{}
if v := find.ID; v != nil { if v := find.ID; v != nil {
where, args = append(where, "id = $"+placeholder(len(args)+1)), append(args, *v) where, args = append(where, "id = "+placeholder(len(args)+1)), append(args, *v)
} }
if v := find.RowStatus; v != nil { if v := find.RowStatus; v != nil {
where, args = append(where, "row_status = $"+placeholder(len(args)+1)), append(args, v.String()) where, args = append(where, "row_status = "+placeholder(len(args)+1)), append(args, v.String())
} }
if v := find.Email; v != nil { if v := find.Email; v != nil {
where, args = append(where, "email = $"+placeholder(len(args)+1)), append(args, *v) where, args = append(where, "email = "+placeholder(len(args)+1)), append(args, *v)
} }
if v := find.Nickname; v != nil { if v := find.Nickname; v != nil {
where, args = append(where, "nickname = $"+placeholder(len(args)+1)), append(args, *v) where, args = append(where, "nickname = "+placeholder(len(args)+1)), append(args, *v)
} }
if v := find.Role; v != nil { if v := find.Role; v != nil {
where, args = append(where, "role = $"+placeholder(len(args)+1)), append(args, *v) where, args = append(where, "role = "+placeholder(len(args)+1)), append(args, *v)
} }
query := ` query := `
@ -149,32 +148,10 @@ func (d *DB) ListUsers(ctx context.Context, find *store.FindUser) ([]*store.User
} }
func (d *DB) DeleteUser(ctx context.Context, delete *store.DeleteUser) error { func (d *DB) DeleteUser(ctx context.Context, delete *store.DeleteUser) error {
tx, err := d.db.BeginTx(ctx, nil) if _, err := d.db.ExecContext(ctx, `DELETE FROM "user" WHERE id = $1`, delete.ID); err != nil {
if err != nil {
return err return err
} }
defer tx.Rollback() return nil
if _, err := tx.ExecContext(ctx, `
DELETE FROM "user" WHERE id = $1
`, delete.ID); err != nil {
return err
}
if err := vacuumUserSetting(ctx, tx); err != nil {
return err
}
if err := vacuumShortcut(ctx, tx); err != nil {
return err
}
if err := vacuumMemo(ctx, tx); err != nil {
return err
}
if err := vacuumCollection(ctx, tx); err != nil {
return err
}
return tx.Commit()
} }
func placeholder(n int) string { func placeholder(n int) string {

View File

@ -2,9 +2,7 @@ package postgres
import ( import (
"context" "context"
"database/sql"
"errors" "errors"
"fmt"
"strings" "strings"
"google.golang.org/protobuf/encoding/protojson" "google.golang.org/protobuf/encoding/protojson"
@ -51,10 +49,10 @@ func (d *DB) ListUserSettings(ctx context.Context, find *store.FindUserSetting)
where, args := []string{"1 = 1"}, []any{} where, args := []string{"1 = 1"}, []any{}
if v := find.Key; v != storepb.UserSettingKey_USER_SETTING_KEY_UNSPECIFIED { if v := find.Key; v != storepb.UserSettingKey_USER_SETTING_KEY_UNSPECIFIED {
where, args = append(where, fmt.Sprintf("key = $%d", len(args)+1)), append(args, v.String()) where, args = append(where, "key = "+placeholder(len(args)+1)), append(args, v.String())
} }
if v := find.UserID; v != nil { if v := find.UserID; v != nil {
where, args = append(where, fmt.Sprintf("user_id = $%d", len(args)+1)), append(args, *find.UserID) where, args = append(where, "user_id = "+placeholder(len(args)+1)), append(args, *find.UserID)
} }
query := ` query := `
@ -110,13 +108,3 @@ func (d *DB) ListUserSettings(ctx context.Context, find *store.FindUserSetting)
return userSettingList, nil return userSettingList, nil
} }
func vacuumUserSetting(ctx context.Context, tx *sql.Tx) error {
stmt := `DELETE FROM user_setting WHERE user_id NOT IN (SELECT id FROM "user")`
_, err := tx.ExecContext(ctx, stmt)
if err != nil {
return err
}
return nil
}

View File

@ -55,7 +55,7 @@ func (d *DB) ListWorkspaceSettings(ctx context.Context, find *store.FindWorkspac
where, args := []string{"1 = 1"}, []interface{}{} where, args := []string{"1 = 1"}, []interface{}{}
if find.Key != storepb.WorkspaceSettingKey_WORKSPACE_SETTING_KEY_UNSPECIFIED { if find.Key != storepb.WorkspaceSettingKey_WORKSPACE_SETTING_KEY_UNSPECIFIED {
where, args = append(where, "key = $"+placeholder(len(args)+1)), append(args, find.Key.String()) where, args = append(where, "key = "+placeholder(len(args)+1)), append(args, find.Key.String())
} }
query := ` query := `

View File

@ -108,7 +108,7 @@ func (d *DB) Migrate(ctx context.Context) error {
} }
// In demo mode, we should seed the database. // In demo mode, we should seed the database.
if d.profile.Mode == "demo" { if d.profile.Mode == "demo" {
if err := d.seed(ctx); err != nil { if err := d.Seed(ctx); err != nil {
return errors.Wrap(err, "failed to seed") return errors.Wrap(err, "failed to seed")
} }
} }
@ -119,7 +119,7 @@ func (d *DB) Migrate(ctx context.Context) error {
} }
const ( const (
latestSchemaFileName = "LATEST__SCHEMA.sql" latestSchemaFileName = "LATEST.sql"
) )
func (d *DB) applyLatestSchema(ctx context.Context) error { func (d *DB) applyLatestSchema(ctx context.Context) error {
@ -172,7 +172,7 @@ func (d *DB) applyMigrationForMinorVersion(ctx context.Context, minorVersion str
return nil return nil
} }
func (d *DB) seed(ctx context.Context) error { func (d *DB) Seed(ctx context.Context) error {
filenames, err := fs.Glob(seedFS, fmt.Sprintf("%s/*.sql", "seed")) filenames, err := fs.Glob(seedFS, fmt.Sprintf("%s/*.sql", "seed"))
if err != nil { if err != nil {
return errors.Wrap(err, "failed to read seed files") return errors.Wrap(err, "failed to read seed files")

View File

@ -14,6 +14,7 @@ type Driver interface {
Close() error Close() error
Migrate(ctx context.Context) error Migrate(ctx context.Context) error
Seed(ctx context.Context) error
// MigrationHistory model related methods. // MigrationHistory model related methods.
UpsertMigrationHistory(ctx context.Context, upsert *UpsertMigrationHistory) (*MigrationHistory, error) UpsertMigrationHistory(ctx context.Context, upsert *UpsertMigrationHistory) (*MigrationHistory, error)

View File

@ -12,11 +12,13 @@ import (
func TestActivityStore(t *testing.T) { func TestActivityStore(t *testing.T) {
ctx := context.Background() ctx := context.Background()
ts := NewTestingStore(ctx, t) ts := NewTestingStore(ctx, t)
user, err := createTestingAdminUser(ctx, ts)
require.NoError(t, err)
list, err := ts.ListActivities(ctx, &store.FindActivity{}) list, err := ts.ListActivities(ctx, &store.FindActivity{})
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, 0, len(list)) require.Equal(t, 0, len(list))
activity, err := ts.CreateActivity(ctx, &store.Activity{ activity, err := ts.CreateActivity(ctx, &store.Activity{
CreatorID: -1, CreatorID: user.ID,
Type: store.ActivityShortcutCreate, Type: store.ActivityShortcutCreate,
Level: store.ActivityInfo, Level: store.ActivityInfo,
Payload: "", Payload: "",

View File

@ -22,6 +22,9 @@ func NewTestingStore(ctx context.Context, t *testing.T) *store.Store {
if err := dbDriver.Migrate(ctx); err != nil { if err := dbDriver.Migrate(ctx); err != nil {
fmt.Printf("failed to migrate db, error: %+v\n", err) fmt.Printf("failed to migrate db, error: %+v\n", err)
} }
if err := dbDriver.Seed(ctx); err != nil {
fmt.Printf("failed to seed db, error: %+v\n", err)
}
store := store.New(dbDriver, profile) store := store.New(dbDriver, profile)
return store return store

View File

@ -7,7 +7,6 @@ import (
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"golang.org/x/crypto/bcrypt" "golang.org/x/crypto/bcrypt"
storepb "github.com/yourselfhosted/slash/proto/gen/store"
"github.com/yourselfhosted/slash/store" "github.com/yourselfhosted/slash/store"
) )
@ -27,13 +26,6 @@ func TestUserStore(t *testing.T) {
Nickname: &userPatchNickname, Nickname: &userPatchNickname,
}) })
require.NoError(t, err) require.NoError(t, err)
_, err = ts.CreateShortcut(ctx, &storepb.Shortcut{
CreatorId: user.ID,
Name: "test_shortcut",
Link: "https://www.google.com",
Visibility: storepb.Visibility_PUBLIC,
})
require.NoError(t, err)
require.Equal(t, userPatchNickname, user.Nickname) require.Equal(t, userPatchNickname, user.Nickname)
err = ts.DeleteUser(ctx, &store.DeleteUser{ err = ts.DeleteUser(ctx, &store.DeleteUser{
ID: user.ID, ID: user.ID,
@ -42,9 +34,6 @@ func TestUserStore(t *testing.T) {
users, err = ts.ListUsers(ctx, &store.FindUser{}) users, err = ts.ListUsers(ctx, &store.FindUser{})
require.NoError(t, err) require.NoError(t, err)
require.Equal(t, 0, len(users)) require.Equal(t, 0, len(users))
shortcuts, err := ts.ListShortcuts(ctx, &store.FindShortcut{})
require.NoError(t, err)
require.Equal(t, 0, len(shortcuts))
} }
// createTestingAdminUser creates a testing admin user. // createTestingAdminUser creates a testing admin user.