From 4c66edc170506b81ebcfc461e710bd5c8a6cf5a3 Mon Sep 17 00:00:00 2001 From: Steven Date: Sun, 17 Dec 2023 20:07:25 +0800 Subject: [PATCH] chore: fix postgres driver --- store/db/postgres/activity.go | 5 +- store/db/postgres/collection.go | 80 +++++++------------ store/db/postgres/memo.go | 67 +++++----------- .../LATEST__SCHEMA.sql => dev/LATEST.sql} | 24 ++++-- .../LATEST__SCHEMA.sql => prod/LATEST.sql} | 14 ++-- store/db/postgres/migrator.go | 6 +- store/db/postgres/seed/10000__reset.sql | 9 --- store/db/postgres/shortcut.go | 7 -- store/db/postgres/user.go | 49 +++--------- store/db/postgres/user_setting.go | 16 +--- store/db/postgres/workspace_setting.go | 2 +- .../dev/{LATEST__SCHEMA.sql => LATEST.sql} | 0 .../prod/{LATEST__SCHEMA.sql => LATEST.sql} | 0 store/db/sqlite/migrator.go | 6 +- store/driver.go | 1 + test/store/activity_test.go | 4 +- test/store/store.go | 3 + test/store/user_test.go | 11 --- 18 files changed, 104 insertions(+), 200 deletions(-) rename store/db/postgres/migration/{prod/LATEST__SCHEMA.sql => dev/LATEST.sql} (79%) rename store/db/postgres/migration/{dev/LATEST__SCHEMA.sql => prod/LATEST.sql} (89%) rename store/db/sqlite/migration/dev/{LATEST__SCHEMA.sql => LATEST.sql} (100%) rename store/db/sqlite/migration/prod/{LATEST__SCHEMA.sql => LATEST.sql} (100%) diff --git a/store/db/postgres/activity.go b/store/db/postgres/activity.go index a16c2b0..8c735d6 100644 --- a/store/db/postgres/activity.go +++ b/store/db/postgres/activity.go @@ -2,7 +2,6 @@ package postgres import ( "context" - "fmt" "strings" "github.com/yourselfhosted/slash/store" @@ -38,10 +37,10 @@ func (d *DB) CreateActivity(ctx context.Context, create *store.Activity) (*store func (d *DB) ListActivities(ctx context.Context, find *store.FindActivity) ([]*store.Activity, error) { where, args := []string{"1 = 1"}, []any{} if find.Type != "" { - where, args = append(where, "type = $"+fmt.Sprint(len(args)+1)), append(args, find.Type.String()) + where, args = append(where, "type = "+placeholder(len(args)+1)), append(args, find.Type.String()) } if find.Level != "" { - where, args = append(where, "level = $"+fmt.Sprint(len(args)+1)), append(args, find.Level.String()) + where, args = append(where, "level = "+placeholder(len(args)+1)), append(args, find.Level.String()) } if find.Where != nil { where = append(where, find.Where...) diff --git a/store/db/postgres/collection.go b/store/db/postgres/collection.go index 665e17e..463689a 100644 --- a/store/db/postgres/collection.go +++ b/store/db/postgres/collection.go @@ -6,23 +6,20 @@ import ( "fmt" "strings" + "github.com/lib/pq" "github.com/pkg/errors" - "github.com/yourselfhosted/slash/internal/util" storepb "github.com/yourselfhosted/slash/proto/gen/store" "github.com/yourselfhosted/slash/store" ) func (d *DB) CreateCollection(ctx context.Context, create *storepb.Collection) (*storepb.Collection, error) { set := []string{"creator_id", "name", "title", "description", "shortcut_ids", "visibility"} - args := []any{create.CreatorId, create.Name, create.Title, create.Description, strings.Trim(strings.Join(strings.Fields(fmt.Sprint(create.ShortcutIds)), ","), "[]"), create.Visibility.String()} - placeholder := []string{"$1", "$2", "$3", "$4", "$5", "$6"} + args := []any{create.CreatorId, create.Name, create.Title, create.Description, pq.Array(create.ShortcutIds), create.Visibility.String()} stmt := ` - INSERT INTO collection ( - ` + strings.Join(set, ", ") + ` - ) - VALUES (` + strings.Join(placeholder, ",") + `) + INSERT INTO collection (` + strings.Join(set, ", ") + `) + VALUES (` + placeholders(len(args)) + `) RETURNING id, created_ts, updated_ts ` if err := d.db.QueryRowContext(ctx, stmt, args...).Scan( @@ -39,35 +36,34 @@ func (d *DB) CreateCollection(ctx context.Context, create *storepb.Collection) ( func (d *DB) UpdateCollection(ctx context.Context, update *store.UpdateCollection) (*storepb.Collection, error) { set, args := []string{}, []any{} if update.Name != nil { - set, args = append(set, "name = $1"), append(args, *update.Name) + set, args = append(set, "name = "+placeholder(len(args)+1)), append(args, *update.Name) } if update.Title != nil { - set, args = append(set, "title = $2"), append(args, *update.Title) + set, args = append(set, "title = "+placeholder(len(args)+1)), append(args, *update.Title) } if update.Description != nil { - set, args = append(set, "description = $3"), append(args, *update.Description) + set, args = append(set, "description = "+placeholder(len(args)+1)), append(args, *update.Description) } if update.ShortcutIDs != nil { - set, args = append(set, "shortcut_ids = $4"), append(args, strings.Trim(strings.Join(strings.Fields(fmt.Sprint(update.ShortcutIDs)), ","), "[]")) + set, args = append(set, "shortcut_ids = "+placeholder(len(args)+1)), append(args, pq.Array(update.ShortcutIDs)) } if update.Visibility != nil { - set, args = append(set, "visibility = $5"), append(args, update.Visibility.String()) + set, args = append(set, "visibility = "+placeholder(len(args)+1)), append(args, update.Visibility.String()) } if len(set) == 0 { return nil, errors.New("no update specified") } - args = append(args, update.ID) stmt := ` UPDATE collection - SET - ` + strings.Join(set, ", ") + ` - WHERE - id = $6 + SET ` + strings.Join(set, ", ") + ` + WHERE id = ` + placeholder(len(args)+1) + ` RETURNING id, creator_id, created_ts, updated_ts, name, title, description, shortcut_ids, visibility ` + args = append(args, update.ID) collection := &storepb.Collection{} - var shortcutIDs, visibility string + var shortcutIDs []sql.NullInt32 + var visibility string if err := d.db.QueryRowContext(ctx, stmt, args...).Scan( &collection.Id, &collection.CreatorId, @@ -76,20 +72,16 @@ func (d *DB) UpdateCollection(ctx context.Context, update *store.UpdateCollectio &collection.Name, &collection.Title, &collection.Description, - &shortcutIDs, + pq.Array(&shortcutIDs), &visibility, ); err != nil { return nil, err } collection.ShortcutIds = []int32{} - if shortcutIDs != "" { - for _, idStr := range strings.Split(shortcutIDs, ",") { - shortcutID, err := util.ConvertStringToInt32(idStr) - if err != nil { - return nil, errors.Wrap(err, "failed to convert shortcut id") - } - collection.ShortcutIds = append(collection.ShortcutIds, shortcutID) + for _, id := range shortcutIDs { + if id.Valid { + collection.ShortcutIds = append(collection.ShortcutIds, id.Int32) } } collection.Visibility = convertVisibilityStringToStorepb(visibility) @@ -99,19 +91,18 @@ func (d *DB) UpdateCollection(ctx context.Context, update *store.UpdateCollectio func (d *DB) ListCollections(ctx context.Context, find *store.FindCollection) ([]*storepb.Collection, error) { where, args := []string{"1 = 1"}, []any{} if v := find.ID; v != nil { - where, args = append(where, "id = $1"), append(args, *v) + where, args = append(where, "id = "+placeholder(len(args)+1)), append(args, *v) } if v := find.CreatorID; v != nil { - where, args = append(where, "creator_id = $2"), append(args, *v) + where, args = append(where, "creator_id = "+placeholder(len(args)+1)), append(args, *v) } if v := find.Name; v != nil { - where, args = append(where, "name = $3"), append(args, *v) + where, args = append(where, "name = "+placeholder(len(args)+1)), append(args, *v) } if v := find.VisibilityList; len(v) != 0 { list := []string{} - for i, visibility := range v { - list = append(list, fmt.Sprintf("$%d", len(args)+i+1)) - args = append(args, visibility) + for _, visibility := range v { + list, args = append(list, placeholder(len(args)+1)), append(args, visibility) } where = append(where, fmt.Sprintf("visibility IN (%s)", strings.Join(list, ","))) } @@ -140,7 +131,8 @@ func (d *DB) ListCollections(ctx context.Context, find *store.FindCollection) ([ list := make([]*storepb.Collection, 0) for rows.Next() { collection := &storepb.Collection{} - var shortcutIDs, visibility string + var shortcutIDs []sql.NullInt32 + var visibility string if err := rows.Scan( &collection.Id, &collection.CreatorId, @@ -149,20 +141,16 @@ func (d *DB) ListCollections(ctx context.Context, find *store.FindCollection) ([ &collection.Name, &collection.Title, &collection.Description, - &shortcutIDs, + pq.Array(&shortcutIDs), &visibility, ); err != nil { return nil, err } collection.ShortcutIds = []int32{} - if shortcutIDs != "" { - for _, idStr := range strings.Split(shortcutIDs, ",") { - shortcutID, err := util.ConvertStringToInt32(idStr) - if err != nil { - return nil, errors.Wrap(err, "failed to convert shortcut id") - } - collection.ShortcutIds = append(collection.ShortcutIds, shortcutID) + for _, id := range shortcutIDs { + if id.Valid { + collection.ShortcutIds = append(collection.ShortcutIds, id.Int32) } } collection.Visibility = storepb.Visibility(storepb.Visibility_value[visibility]) @@ -182,13 +170,3 @@ func (d *DB) DeleteCollection(ctx context.Context, delete *store.DeleteCollectio return nil } - -func vacuumCollection(ctx context.Context, tx *sql.Tx) error { - stmt := `DELETE FROM collection WHERE creator_id NOT IN (SELECT id FROM user)` - _, err := tx.ExecContext(ctx, stmt) - if err != nil { - return err - } - - return nil -} diff --git a/store/db/postgres/memo.go b/store/db/postgres/memo.go index aa0cf37..7bce42a 100644 --- a/store/db/postgres/memo.go +++ b/store/db/postgres/memo.go @@ -2,7 +2,6 @@ package postgres import ( "context" - "database/sql" "fmt" "strings" @@ -17,9 +16,7 @@ func (d *DB) CreateMemo(ctx context.Context, create *storepb.Memo) (*storepb.Mem args := []any{create.CreatorId, create.Name, create.Title, create.Content, create.Visibility.String(), strings.Join(create.Tags, " ")} stmt := ` - INSERT INTO memo ( - ` + strings.Join(set, ", ") + ` - ) + INSERT INTO memo (` + strings.Join(set, ", ") + `) VALUES (` + placeholders(len(args)) + `) RETURNING id, created_ts, updated_ts, row_status ` @@ -41,43 +38,34 @@ func (d *DB) CreateMemo(ctx context.Context, create *storepb.Memo) (*storepb.Mem func (d *DB) UpdateMemo(ctx context.Context, update *store.UpdateMemo) (*storepb.Memo, error) { set, args := []string{}, []any{} if update.RowStatus != nil { - set = append(set, fmt.Sprintf("row_status = $%d", len(set)+1)) - args = append(args, update.RowStatus.String()) + set, args = append(set, "row_status = "+placeholder(len(args)+1)), append(args, update.RowStatus.String()) } if update.Name != nil { - set = append(set, fmt.Sprintf("name = $%d", len(set)+1)) - args = append(args, *update.Name) + set, args = append(set, "name = "+placeholder(len(args)+1)), append(args, *update.Name) } if update.Title != nil { - set = append(set, fmt.Sprintf("title = $%d", len(set)+1)) - args = append(args, *update.Title) + set, args = append(set, "title = "+placeholder(len(args)+1)), append(args, *update.Title) } if update.Content != nil { - set = append(set, fmt.Sprintf("content = $%d", len(set)+1)) - args = append(args, *update.Content) + set, args = append(set, "content = "+placeholder(len(args)+1)), append(args, *update.Content) } if update.Visibility != nil { - set = append(set, fmt.Sprintf("visibility = $%d", len(set)+1)) - args = append(args, update.Visibility.String()) + set, args = append(set, "visibility = "+placeholder(len(args)+1)), append(args, update.Visibility.String()) } if update.Tag != nil { - set = append(set, fmt.Sprintf("tag = $%d", len(set)+1)) - args = append(args, *update.Tag) + set, args = append(set, "tag = "+placeholder(len(args)+1)), append(args, *update.Tag) } if len(set) == 0 { return nil, errors.New("no update specified") } - args = append(args, update.ID) stmt := ` UPDATE memo - SET - ` + strings.Join(set, ", ") + ` - WHERE - id = $` + fmt.Sprint(len(set)+1) + ` + SET ` + strings.Join(set, ", ") + ` + WHERE id = ` + placeholder(len(args)+1) + ` RETURNING id, creator_id, created_ts, updated_ts, row_status, name, title, content, visibility, tag ` - + args = append(args, update.ID) memo := &storepb.Memo{} var rowStatus, visibility, tags string if err := d.db.QueryRowContext(ctx, stmt, args...).Scan( @@ -103,27 +91,26 @@ func (d *DB) UpdateMemo(ctx context.Context, update *store.UpdateMemo) (*storepb func (d *DB) ListMemos(ctx context.Context, find *store.FindMemo) ([]*storepb.Memo, error) { where, args := []string{"1 = 1"}, []any{} if v := find.ID; v != nil { - where, args = append(where, "id = $1"), append(args, *v) + where, args = append(where, "id = "+placeholder(len(args)+1)), append(args, *v) } if v := find.CreatorID; v != nil { - where, args = append(where, "creator_id = $2"), append(args, *v) + where, args = append(where, "creator_id = "+placeholder(len(args)+1)), append(args, *v) } if v := find.RowStatus; v != nil { - where, args = append(where, "row_status = $3"), append(args, *v) + where, args = append(where, "row_status = "+placeholder(len(args)+1)), append(args, *v) } if v := find.Name; v != nil { - where, args = append(where, "name = $4"), append(args, *v) + where, args = append(where, "name = "+placeholder(len(args)+1)), append(args, *v) } if v := find.VisibilityList; len(v) != 0 { list := []string{} - for i, visibility := range v { - list = append(list, fmt.Sprintf("$%d", len(args)+i+1)) - args = append(args, visibility) + for _, visibility := range v { + list, args = append(list, placeholder(len(args)+1)), append(args, visibility) } where = append(where, fmt.Sprintf("visibility IN (%s)", strings.Join(list, ","))) } if v := find.Tag; v != nil { - where, args = append(where, "tag LIKE $"+fmt.Sprint(len(args)+1)), append(args, "%"+*v+"%") + where, args = append(where, "tag LIKE "+placeholder(len(args)+1)), append(args, "%"+*v+"%") } rows, err := d.db.QueryContext(ctx, ` @@ -185,24 +172,10 @@ func (d *DB) DeleteMemo(ctx context.Context, delete *store.DeleteMemo) error { return nil } -func vacuumMemo(ctx context.Context, tx *sql.Tx) error { - stmt := `DELETE FROM memo WHERE creator_id NOT IN (SELECT id FROM user)` - _, err := tx.ExecContext(ctx, stmt) - if err != nil { - return err - } - - return nil -} - func placeholders(n int) string { - placeholder := "" + list := []string{} for i := 0; i < n; i++ { - if i == 0 { - placeholder = fmt.Sprintf("$%d", i+1) - } else { - placeholder = fmt.Sprintf("%s, $%d", placeholder, i+1) - } + list = append(list, fmt.Sprintf("$%d", i+1)) } - return placeholder + return strings.Join(list, ", ") } diff --git a/store/db/postgres/migration/prod/LATEST__SCHEMA.sql b/store/db/postgres/migration/dev/LATEST.sql similarity index 79% rename from store/db/postgres/migration/prod/LATEST__SCHEMA.sql rename to store/db/postgres/migration/dev/LATEST.sql index b8a257a..a4d4cec 100644 --- a/store/db/postgres/migration/prod/LATEST__SCHEMA.sql +++ b/store/db/postgres/migration/dev/LATEST.sql @@ -1,3 +1,13 @@ +-- drop all tables first (PostgreSQL style) +DROP TABLE IF EXISTS migration_history CASCADE; +DROP TABLE IF EXISTS workspace_setting CASCADE; +DROP TABLE IF EXISTS "user" CASCADE; +DROP TABLE IF EXISTS user_setting CASCADE; +DROP TABLE IF EXISTS shortcut CASCADE; +DROP TABLE IF EXISTS activity CASCADE; +DROP TABLE IF EXISTS collection CASCADE; +DROP TABLE IF EXISTS memo CASCADE; + -- migration_history CREATE TABLE migration_history ( version TEXT NOT NULL PRIMARY KEY, @@ -11,7 +21,7 @@ CREATE TABLE workspace_setting ( ); -- user -CREATE TABLE user ( +CREATE TABLE "user" ( id SERIAL PRIMARY KEY, created_ts BIGINT NOT NULL DEFAULT EXTRACT(EPOCH FROM NOW()), updated_ts BIGINT NOT NULL DEFAULT EXTRACT(EPOCH FROM NOW()), @@ -22,11 +32,11 @@ CREATE TABLE user ( role TEXT NOT NULL CHECK (role IN ('ADMIN', 'USER')) DEFAULT 'USER' ); -CREATE INDEX idx_user_email ON user(email); +CREATE INDEX idx_user_email ON "user"(email); -- user_setting CREATE TABLE user_setting ( - user_id INTEGER REFERENCES user(id) NOT NULL, + user_id INTEGER REFERENCES "user"(id) NOT NULL, key TEXT NOT NULL, value TEXT NOT NULL, PRIMARY KEY (user_id, key) @@ -35,7 +45,7 @@ CREATE TABLE user_setting ( -- shortcut CREATE TABLE shortcut ( id SERIAL PRIMARY KEY, - creator_id INTEGER REFERENCES user(id) NOT NULL, + creator_id INTEGER REFERENCES "user"(id) NOT NULL, created_ts BIGINT NOT NULL DEFAULT EXTRACT(EPOCH FROM NOW()), updated_ts BIGINT NOT NULL DEFAULT EXTRACT(EPOCH FROM NOW()), row_status TEXT NOT NULL CHECK (row_status IN ('NORMAL', 'ARCHIVED')) DEFAULT 'NORMAL', @@ -53,7 +63,7 @@ CREATE INDEX idx_shortcut_name ON shortcut(name); -- activity CREATE TABLE activity ( id SERIAL PRIMARY KEY, - creator_id INTEGER REFERENCES user(id) NOT NULL, + creator_id INTEGER REFERENCES "user"(id) NOT NULL, created_ts BIGINT NOT NULL DEFAULT EXTRACT(EPOCH FROM NOW()), type TEXT NOT NULL DEFAULT '', level TEXT NOT NULL CHECK (level IN ('INFO', 'WARN', 'ERROR')) DEFAULT 'INFO', @@ -63,7 +73,7 @@ CREATE TABLE activity ( -- collection CREATE TABLE collection ( id SERIAL PRIMARY KEY, - creator_id INTEGER REFERENCES user(id) NOT NULL, + creator_id INTEGER REFERENCES "user"(id) NOT NULL, created_ts BIGINT NOT NULL DEFAULT EXTRACT(EPOCH FROM NOW()), updated_ts BIGINT NOT NULL DEFAULT EXTRACT(EPOCH FROM NOW()), name TEXT NOT NULL UNIQUE, @@ -78,7 +88,7 @@ CREATE INDEX idx_collection_name ON collection(name); -- memo CREATE TABLE memo ( id SERIAL PRIMARY KEY, - creator_id INTEGER REFERENCES user(id) NOT NULL, + creator_id INTEGER REFERENCES "user"(id) NOT NULL, created_ts BIGINT NOT NULL DEFAULT EXTRACT(EPOCH FROM NOW()), updated_ts BIGINT NOT NULL DEFAULT EXTRACT(EPOCH FROM NOW()), row_status TEXT NOT NULL CHECK (row_status IN ('NORMAL', 'ARCHIVED')) DEFAULT 'NORMAL', diff --git a/store/db/postgres/migration/dev/LATEST__SCHEMA.sql b/store/db/postgres/migration/prod/LATEST.sql similarity index 89% rename from store/db/postgres/migration/dev/LATEST__SCHEMA.sql rename to store/db/postgres/migration/prod/LATEST.sql index b8a257a..84977bc 100644 --- a/store/db/postgres/migration/dev/LATEST__SCHEMA.sql +++ b/store/db/postgres/migration/prod/LATEST.sql @@ -11,7 +11,7 @@ CREATE TABLE workspace_setting ( ); -- user -CREATE TABLE user ( +CREATE TABLE "user" ( id SERIAL PRIMARY KEY, created_ts BIGINT NOT NULL DEFAULT EXTRACT(EPOCH FROM NOW()), updated_ts BIGINT NOT NULL DEFAULT EXTRACT(EPOCH FROM NOW()), @@ -22,11 +22,11 @@ CREATE TABLE user ( role TEXT NOT NULL CHECK (role IN ('ADMIN', 'USER')) DEFAULT 'USER' ); -CREATE INDEX idx_user_email ON user(email); +CREATE INDEX idx_user_email ON "user"(email); -- user_setting CREATE TABLE user_setting ( - user_id INTEGER REFERENCES user(id) NOT NULL, + user_id INTEGER REFERENCES "user"(id) NOT NULL, key TEXT NOT NULL, value TEXT NOT NULL, PRIMARY KEY (user_id, key) @@ -35,7 +35,7 @@ CREATE TABLE user_setting ( -- shortcut CREATE TABLE shortcut ( id SERIAL PRIMARY KEY, - creator_id INTEGER REFERENCES user(id) NOT NULL, + creator_id INTEGER REFERENCES "user"(id) NOT NULL, created_ts BIGINT NOT NULL DEFAULT EXTRACT(EPOCH FROM NOW()), updated_ts BIGINT NOT NULL DEFAULT EXTRACT(EPOCH FROM NOW()), row_status TEXT NOT NULL CHECK (row_status IN ('NORMAL', 'ARCHIVED')) DEFAULT 'NORMAL', @@ -53,7 +53,7 @@ CREATE INDEX idx_shortcut_name ON shortcut(name); -- activity CREATE TABLE activity ( id SERIAL PRIMARY KEY, - creator_id INTEGER REFERENCES user(id) NOT NULL, + creator_id INTEGER REFERENCES "user"(id) NOT NULL, created_ts BIGINT NOT NULL DEFAULT EXTRACT(EPOCH FROM NOW()), type TEXT NOT NULL DEFAULT '', level TEXT NOT NULL CHECK (level IN ('INFO', 'WARN', 'ERROR')) DEFAULT 'INFO', @@ -63,7 +63,7 @@ CREATE TABLE activity ( -- collection CREATE TABLE collection ( id SERIAL PRIMARY KEY, - creator_id INTEGER REFERENCES user(id) NOT NULL, + creator_id INTEGER REFERENCES "user"(id) NOT NULL, created_ts BIGINT NOT NULL DEFAULT EXTRACT(EPOCH FROM NOW()), updated_ts BIGINT NOT NULL DEFAULT EXTRACT(EPOCH FROM NOW()), name TEXT NOT NULL UNIQUE, @@ -78,7 +78,7 @@ CREATE INDEX idx_collection_name ON collection(name); -- memo CREATE TABLE memo ( id SERIAL PRIMARY KEY, - creator_id INTEGER REFERENCES user(id) NOT NULL, + creator_id INTEGER REFERENCES "user"(id) NOT NULL, created_ts BIGINT NOT NULL DEFAULT EXTRACT(EPOCH FROM NOW()), updated_ts BIGINT NOT NULL DEFAULT EXTRACT(EPOCH FROM NOW()), row_status TEXT NOT NULL CHECK (row_status IN ('NORMAL', 'ARCHIVED')) DEFAULT 'NORMAL', diff --git a/store/db/postgres/migrator.go b/store/db/postgres/migrator.go index 724e1b1..d1399cb 100644 --- a/store/db/postgres/migrator.go +++ b/store/db/postgres/migrator.go @@ -108,7 +108,7 @@ func (d *DB) Migrate(ctx context.Context) error { } // In demo mode, we should seed the database. if d.profile.Mode == "demo" { - if err := d.seed(ctx); err != nil { + if err := d.Seed(ctx); err != nil { return errors.Wrap(err, "failed to seed") } } @@ -119,7 +119,7 @@ func (d *DB) Migrate(ctx context.Context) error { } const ( - latestSchemaFileName = "LATEST__SCHEMA.sql" + latestSchemaFileName = "LATEST.sql" ) func (d *DB) applyLatestSchema(ctx context.Context) error { @@ -172,7 +172,7 @@ func (d *DB) applyMigrationForMinorVersion(ctx context.Context, minorVersion str return nil } -func (d *DB) seed(ctx context.Context) error { +func (d *DB) Seed(ctx context.Context) error { filenames, err := fs.Glob(seedFS, fmt.Sprintf("%s/*.sql", "seed")) if err != nil { return errors.Wrap(err, "failed to read seed files") diff --git a/store/db/postgres/seed/10000__reset.sql b/store/db/postgres/seed/10000__reset.sql index 79be1d8..e69de29 100644 --- a/store/db/postgres/seed/10000__reset.sql +++ b/store/db/postgres/seed/10000__reset.sql @@ -1,9 +0,0 @@ -DELETE FROM activity; - -DELETE FROM shortcut; - -DELETE FROM user_setting; - -DELETE FROM user; - -DELETE FROM workspace_setting; diff --git a/store/db/postgres/shortcut.go b/store/db/postgres/shortcut.go index 821898d..a23cc2a 100644 --- a/store/db/postgres/shortcut.go +++ b/store/db/postgres/shortcut.go @@ -2,7 +2,6 @@ package postgres import ( "context" - "database/sql" "fmt" "strings" @@ -207,12 +206,6 @@ func (d *DB) DeleteShortcut(ctx context.Context, delete *store.DeleteShortcut) e return err } -func vacuumShortcut(ctx context.Context, tx *sql.Tx) error { - stmt := `DELETE FROM shortcut WHERE creator_id NOT IN (SELECT id FROM "user")` - _, err := tx.ExecContext(ctx, stmt) - return err -} - func filterTags(tags []string) []string { result := []string{} for _, tag := range tags { diff --git a/store/db/postgres/user.go b/store/db/postgres/user.go index 641e687..b595fbe 100644 --- a/store/db/postgres/user.go +++ b/store/db/postgres/user.go @@ -41,21 +41,20 @@ func (d *DB) CreateUser(ctx context.Context, create *store.User) (*store.User, e func (d *DB) UpdateUser(ctx context.Context, update *store.UpdateUser) (*store.User, error) { set, args := []string{}, []any{} if v := update.RowStatus; v != nil { - set, args = append(set, "row_status = $"+placeholder(len(args)+1)), append(args, *v) + set, args = append(set, "row_status = "+placeholder(len(args)+1)), append(args, *v) } if v := update.Email; v != nil { - set, args = append(set, "email = $"+placeholder(len(args)+1)), append(args, *v) + set, args = append(set, "email = "+placeholder(len(args)+1)), append(args, *v) } if v := update.Nickname; v != nil { - set, args = append(set, "nickname = $"+placeholder(len(args)+1)), append(args, *v) + set, args = append(set, "nickname = "+placeholder(len(args)+1)), append(args, *v) } if v := update.PasswordHash; v != nil { - set, args = append(set, "password_hash = $"+placeholder(len(args)+1)), append(args, *v) + set, args = append(set, "password_hash = "+placeholder(len(args)+1)), append(args, *v) } if v := update.Role; v != nil { - set, args = append(set, "role = $"+placeholder(len(args)+1)), append(args, *v) + set, args = append(set, "role = "+placeholder(len(args)+1)), append(args, *v) } - if len(set) == 0 { return nil, errors.New("no fields to update") } @@ -63,7 +62,7 @@ func (d *DB) UpdateUser(ctx context.Context, update *store.UpdateUser) (*store.U stmt := ` UPDATE "user" SET ` + strings.Join(set, ", ") + ` - WHERE id = $` + placeholder(len(args)+1) + ` + WHERE id = ` + placeholder(len(args)+1) + ` RETURNING id, created_ts, updated_ts, row_status, email, nickname, password_hash, role ` args = append(args, update.ID) @@ -88,19 +87,19 @@ func (d *DB) ListUsers(ctx context.Context, find *store.FindUser) ([]*store.User where, args := []string{"1 = 1"}, []any{} if v := find.ID; v != nil { - where, args = append(where, "id = $"+placeholder(len(args)+1)), append(args, *v) + where, args = append(where, "id = "+placeholder(len(args)+1)), append(args, *v) } if v := find.RowStatus; v != nil { - where, args = append(where, "row_status = $"+placeholder(len(args)+1)), append(args, v.String()) + where, args = append(where, "row_status = "+placeholder(len(args)+1)), append(args, v.String()) } if v := find.Email; v != nil { - where, args = append(where, "email = $"+placeholder(len(args)+1)), append(args, *v) + where, args = append(where, "email = "+placeholder(len(args)+1)), append(args, *v) } if v := find.Nickname; v != nil { - where, args = append(where, "nickname = $"+placeholder(len(args)+1)), append(args, *v) + where, args = append(where, "nickname = "+placeholder(len(args)+1)), append(args, *v) } if v := find.Role; v != nil { - where, args = append(where, "role = $"+placeholder(len(args)+1)), append(args, *v) + where, args = append(where, "role = "+placeholder(len(args)+1)), append(args, *v) } query := ` @@ -149,32 +148,10 @@ func (d *DB) ListUsers(ctx context.Context, find *store.FindUser) ([]*store.User } func (d *DB) DeleteUser(ctx context.Context, delete *store.DeleteUser) error { - tx, err := d.db.BeginTx(ctx, nil) - if err != nil { + if _, err := d.db.ExecContext(ctx, `DELETE FROM "user" WHERE id = $1`, delete.ID); err != nil { return err } - defer tx.Rollback() - - if _, err := tx.ExecContext(ctx, ` - DELETE FROM "user" WHERE id = $1 - `, delete.ID); err != nil { - return err - } - - if err := vacuumUserSetting(ctx, tx); err != nil { - return err - } - if err := vacuumShortcut(ctx, tx); err != nil { - return err - } - if err := vacuumMemo(ctx, tx); err != nil { - return err - } - if err := vacuumCollection(ctx, tx); err != nil { - return err - } - - return tx.Commit() + return nil } func placeholder(n int) string { diff --git a/store/db/postgres/user_setting.go b/store/db/postgres/user_setting.go index a797da1..9cd3188 100644 --- a/store/db/postgres/user_setting.go +++ b/store/db/postgres/user_setting.go @@ -2,9 +2,7 @@ package postgres import ( "context" - "database/sql" "errors" - "fmt" "strings" "google.golang.org/protobuf/encoding/protojson" @@ -51,10 +49,10 @@ func (d *DB) ListUserSettings(ctx context.Context, find *store.FindUserSetting) where, args := []string{"1 = 1"}, []any{} if v := find.Key; v != storepb.UserSettingKey_USER_SETTING_KEY_UNSPECIFIED { - where, args = append(where, fmt.Sprintf("key = $%d", len(args)+1)), append(args, v.String()) + where, args = append(where, "key = "+placeholder(len(args)+1)), append(args, v.String()) } if v := find.UserID; v != nil { - where, args = append(where, fmt.Sprintf("user_id = $%d", len(args)+1)), append(args, *find.UserID) + where, args = append(where, "user_id = "+placeholder(len(args)+1)), append(args, *find.UserID) } query := ` @@ -110,13 +108,3 @@ func (d *DB) ListUserSettings(ctx context.Context, find *store.FindUserSetting) return userSettingList, nil } - -func vacuumUserSetting(ctx context.Context, tx *sql.Tx) error { - stmt := `DELETE FROM user_setting WHERE user_id NOT IN (SELECT id FROM "user")` - _, err := tx.ExecContext(ctx, stmt) - if err != nil { - return err - } - - return nil -} diff --git a/store/db/postgres/workspace_setting.go b/store/db/postgres/workspace_setting.go index e651442..f96536d 100644 --- a/store/db/postgres/workspace_setting.go +++ b/store/db/postgres/workspace_setting.go @@ -55,7 +55,7 @@ func (d *DB) ListWorkspaceSettings(ctx context.Context, find *store.FindWorkspac where, args := []string{"1 = 1"}, []interface{}{} if find.Key != storepb.WorkspaceSettingKey_WORKSPACE_SETTING_KEY_UNSPECIFIED { - where, args = append(where, "key = $"+placeholder(len(args)+1)), append(args, find.Key.String()) + where, args = append(where, "key = "+placeholder(len(args)+1)), append(args, find.Key.String()) } query := ` diff --git a/store/db/sqlite/migration/dev/LATEST__SCHEMA.sql b/store/db/sqlite/migration/dev/LATEST.sql similarity index 100% rename from store/db/sqlite/migration/dev/LATEST__SCHEMA.sql rename to store/db/sqlite/migration/dev/LATEST.sql diff --git a/store/db/sqlite/migration/prod/LATEST__SCHEMA.sql b/store/db/sqlite/migration/prod/LATEST.sql similarity index 100% rename from store/db/sqlite/migration/prod/LATEST__SCHEMA.sql rename to store/db/sqlite/migration/prod/LATEST.sql diff --git a/store/db/sqlite/migrator.go b/store/db/sqlite/migrator.go index 41d1d88..779a150 100644 --- a/store/db/sqlite/migrator.go +++ b/store/db/sqlite/migrator.go @@ -108,7 +108,7 @@ func (d *DB) Migrate(ctx context.Context) error { } // In demo mode, we should seed the database. if d.profile.Mode == "demo" { - if err := d.seed(ctx); err != nil { + if err := d.Seed(ctx); err != nil { return errors.Wrap(err, "failed to seed") } } @@ -119,7 +119,7 @@ func (d *DB) Migrate(ctx context.Context) error { } const ( - latestSchemaFileName = "LATEST__SCHEMA.sql" + latestSchemaFileName = "LATEST.sql" ) func (d *DB) applyLatestSchema(ctx context.Context) error { @@ -172,7 +172,7 @@ func (d *DB) applyMigrationForMinorVersion(ctx context.Context, minorVersion str return nil } -func (d *DB) seed(ctx context.Context) error { +func (d *DB) Seed(ctx context.Context) error { filenames, err := fs.Glob(seedFS, fmt.Sprintf("%s/*.sql", "seed")) if err != nil { return errors.Wrap(err, "failed to read seed files") diff --git a/store/driver.go b/store/driver.go index b29656e..5099aa9 100644 --- a/store/driver.go +++ b/store/driver.go @@ -14,6 +14,7 @@ type Driver interface { Close() error Migrate(ctx context.Context) error + Seed(ctx context.Context) error // MigrationHistory model related methods. UpsertMigrationHistory(ctx context.Context, upsert *UpsertMigrationHistory) (*MigrationHistory, error) diff --git a/test/store/activity_test.go b/test/store/activity_test.go index 12d88fe..d259dc4 100644 --- a/test/store/activity_test.go +++ b/test/store/activity_test.go @@ -12,11 +12,13 @@ import ( func TestActivityStore(t *testing.T) { ctx := context.Background() ts := NewTestingStore(ctx, t) + user, err := createTestingAdminUser(ctx, ts) + require.NoError(t, err) list, err := ts.ListActivities(ctx, &store.FindActivity{}) require.NoError(t, err) require.Equal(t, 0, len(list)) activity, err := ts.CreateActivity(ctx, &store.Activity{ - CreatorID: -1, + CreatorID: user.ID, Type: store.ActivityShortcutCreate, Level: store.ActivityInfo, Payload: "", diff --git a/test/store/store.go b/test/store/store.go index 4ba6b1c..13b7a97 100644 --- a/test/store/store.go +++ b/test/store/store.go @@ -22,6 +22,9 @@ func NewTestingStore(ctx context.Context, t *testing.T) *store.Store { if err := dbDriver.Migrate(ctx); err != nil { fmt.Printf("failed to migrate db, error: %+v\n", err) } + if err := dbDriver.Seed(ctx); err != nil { + fmt.Printf("failed to seed db, error: %+v\n", err) + } store := store.New(dbDriver, profile) return store diff --git a/test/store/user_test.go b/test/store/user_test.go index 3c835c7..1164f1e 100644 --- a/test/store/user_test.go +++ b/test/store/user_test.go @@ -7,7 +7,6 @@ import ( "github.com/stretchr/testify/require" "golang.org/x/crypto/bcrypt" - storepb "github.com/yourselfhosted/slash/proto/gen/store" "github.com/yourselfhosted/slash/store" ) @@ -27,13 +26,6 @@ func TestUserStore(t *testing.T) { Nickname: &userPatchNickname, }) require.NoError(t, err) - _, err = ts.CreateShortcut(ctx, &storepb.Shortcut{ - CreatorId: user.ID, - Name: "test_shortcut", - Link: "https://www.google.com", - Visibility: storepb.Visibility_PUBLIC, - }) - require.NoError(t, err) require.Equal(t, userPatchNickname, user.Nickname) err = ts.DeleteUser(ctx, &store.DeleteUser{ ID: user.ID, @@ -42,9 +34,6 @@ func TestUserStore(t *testing.T) { users, err = ts.ListUsers(ctx, &store.FindUser{}) require.NoError(t, err) require.Equal(t, 0, len(users)) - shortcuts, err := ts.ListShortcuts(ctx, &store.FindShortcut{}) - require.NoError(t, err) - require.Equal(t, 0, len(shortcuts)) } // createTestingAdminUser creates a testing admin user.