From 805ab9996e211a3ba47c6b1e4370a9736c55e944 Mon Sep 17 00:00:00 2001 From: Steven Date: Tue, 20 Jun 2023 15:54:56 +0800 Subject: [PATCH] refactor: update stores --- store/common.go | 21 ++ store/db/db.go | 110 ++++---- store/db/migration/dev/LATEST__SCHEMA.sql | 59 ++-- store/db/migration/prod/LATEST__SCHEMA.sql | 59 ++-- store/db/migration_history.go | 2 +- store/db/seed/10000__reset.sql | 18 +- store/db/seed/10001__user.sql | 30 +- store/db/seed/10002__workspace.sql | 21 +- store/db/seed/10003__workspace_user.sql | 34 +-- store/db/seed/10004__shortcut.sql | 50 +--- store/error.go | 1 - store/shortcut.go | 304 ++++++++++++++++++++- store/user.go | 275 ++++++++++++++++++- store/user_setting.go | 150 ---------- store/workspace.go | 256 ++++++++++++++++- store/workspace_user.go | 173 +++++++++++- 16 files changed, 1160 insertions(+), 403 deletions(-) create mode 100644 store/common.go delete mode 100644 store/error.go diff --git a/store/common.go b/store/common.go new file mode 100644 index 0000000..c6699de --- /dev/null +++ b/store/common.go @@ -0,0 +1,21 @@ +package store + +// RowStatus is the status for a row. +type RowStatus string + +const ( + // Normal is the status for a normal row. + Normal RowStatus = "NORMAL" + // Archived is the status for an archived row. + Archived RowStatus = "ARCHIVED" +) + +func (e RowStatus) String() string { + switch e { + case Normal: + return "NORMAL" + case Archived: + return "ARCHIVED" + } + return "" +} diff --git a/store/db/db.go b/store/db/db.go index acb7a2c..0d3c6ff 100644 --- a/store/db/db.go +++ b/store/db/db.go @@ -23,12 +23,12 @@ var migrationFS embed.FS var seedFS embed.FS type DB struct { + profile *profile.Profile // sqlite db connection instance DBInstance *sql.DB - profile *profile.Profile } -// NewDB returns a new instance of DB associated with the given datasource name. +// NewDB returns a new instance of DB. func NewDB(profile *profile.Profile) *DB { db := &DB{ profile: profile, @@ -50,64 +50,70 @@ func (db *DB) Open(ctx context.Context) (err error) { db.DBInstance = sqliteDB if db.profile.Mode == "prod" { - // If db file not exists, we should migrate the database. - if _, err := os.Stat(db.profile.DSN); errors.Is(err, os.ErrNotExist) { - if err := db.applyLatestSchema(ctx); err != nil { - return fmt.Errorf("failed to apply latest schema: %w", err) - } - } - - currentVersion := version.GetCurrentVersion(db.profile.Mode) - migrationHistoryList, err := db.FindMigrationHistoryList(ctx, &MigrationHistoryFind{}) + _, err := os.Stat(db.profile.DSN) if err != nil { - return fmt.Errorf("failed to find migration history, err: %w", err) - } - if len(migrationHistoryList) == 0 { - _, err := db.UpsertMigrationHistory(ctx, &MigrationHistoryUpsert{ - Version: currentVersion, - }) + // If db file not exists, we should apply the latest schema. + if errors.Is(err, os.ErrNotExist) { + if err := db.applyLatestSchema(ctx); err != nil { + return fmt.Errorf("failed to apply latest schema: %w", err) + } + } else { + return fmt.Errorf("failed to check database file: %w", err) + } + } else { + // If db file exists, we should check the migration history and apply the migration if needed. + currentVersion := version.GetCurrentVersion(db.profile.Mode) + migrationHistoryList, err := db.FindMigrationHistoryList(ctx, &MigrationHistoryFind{}) if err != nil { - return fmt.Errorf("failed to upsert migration history, err: %w", err) + return fmt.Errorf("failed to find migration history, err: %w", err) } - return nil - } - - migrationHistoryVersionList := []string{} - for _, migrationHistory := range migrationHistoryList { - migrationHistoryVersionList = append(migrationHistoryVersionList, migrationHistory.Version) - } - sort.Sort(version.SortVersion(migrationHistoryVersionList)) - latestMigrationHistoryVersion := migrationHistoryVersionList[len(migrationHistoryVersionList)-1] - - if version.IsVersionGreaterThan(version.GetSchemaVersion(currentVersion), latestMigrationHistoryVersion) { - minorVersionList := getMinorVersionList() - - // backup the raw database file before migration - rawBytes, err := os.ReadFile(db.profile.DSN) - if err != nil { - return fmt.Errorf("failed to read raw database file, err: %w", err) + if len(migrationHistoryList) == 0 { + _, err := db.UpsertMigrationHistory(ctx, &MigrationHistoryUpsert{ + Version: currentVersion, + }) + if err != nil { + return fmt.Errorf("failed to upsert migration history, err: %w", err) + } + return nil } - backupDBFilePath := fmt.Sprintf("%s/shortify_%s_%d_backup.db", db.profile.Data, db.profile.Version, time.Now().Unix()) - if err := os.WriteFile(backupDBFilePath, rawBytes, 0644); err != nil { - return fmt.Errorf("failed to write raw database file, err: %w", err) - } - println("succeed to copy a backup database file") - println("start migrate") - for _, minorVersion := range minorVersionList { - normalizedVersion := minorVersion + ".0" - if version.IsVersionGreaterThan(normalizedVersion, latestMigrationHistoryVersion) && version.IsVersionGreaterOrEqualThan(currentVersion, normalizedVersion) { - println("applying migration for", normalizedVersion) - if err := db.applyMigrationForMinorVersion(ctx, minorVersion); err != nil { - return fmt.Errorf("failed to apply minor version migration: %w", err) + migrationHistoryVersionList := []string{} + for _, migrationHistory := range migrationHistoryList { + migrationHistoryVersionList = append(migrationHistoryVersionList, migrationHistory.Version) + } + sort.Sort(version.SortVersion(migrationHistoryVersionList)) + latestMigrationHistoryVersion := migrationHistoryVersionList[len(migrationHistoryVersionList)-1] + + if version.IsVersionGreaterThan(version.GetSchemaVersion(currentVersion), latestMigrationHistoryVersion) { + minorVersionList := getMinorVersionList() + + // backup the raw database file before migration + rawBytes, err := os.ReadFile(db.profile.DSN) + if err != nil { + return fmt.Errorf("failed to read raw database file, err: %w", err) + } + backupDBFilePath := fmt.Sprintf("%s/shortify_%s_%d_backup.db", db.profile.Data, db.profile.Version, time.Now().Unix()) + if err := os.WriteFile(backupDBFilePath, rawBytes, 0644); err != nil { + return fmt.Errorf("failed to write raw database file, err: %w", err) + } + println("succeed to copy a backup database file") + + println("start migrate") + for _, minorVersion := range minorVersionList { + normalizedVersion := minorVersion + ".0" + if version.IsVersionGreaterThan(normalizedVersion, latestMigrationHistoryVersion) && version.IsVersionGreaterOrEqualThan(currentVersion, normalizedVersion) { + println("applying migration for", normalizedVersion) + if err := db.applyMigrationForMinorVersion(ctx, minorVersion); err != nil { + return fmt.Errorf("failed to apply minor version migration: %w", err) + } } } - } - println("end migrate") + println("end migrate") - // remove the created backup db file after migrate succeed - if err := os.Remove(backupDBFilePath); err != nil { - println(fmt.Sprintf("Failed to remove temp database file, err %v", err)) + // remove the created backup db file after migrate succeed + if err := os.Remove(backupDBFilePath); err != nil { + println(fmt.Sprintf("Failed to remove temp database file, err %v", err)) + } } } } else { diff --git a/store/db/migration/dev/LATEST__SCHEMA.sql b/store/db/migration/dev/LATEST__SCHEMA.sql index 56afd95..da0e1c5 100644 --- a/store/db/migration/dev/LATEST__SCHEMA.sql +++ b/store/db/migration/dev/LATEST__SCHEMA.sql @@ -4,23 +4,47 @@ CREATE TABLE migration_history ( created_ts BIGINT NOT NULL DEFAULT (strftime('%s', 'now')) ); +-- workspace +CREATE TABLE workspace ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + created_ts BIGINT NOT NULL DEFAULT (strftime('%s', 'now')), + updated_ts BIGINT NOT NULL DEFAULT (strftime('%s', 'now')), + row_status TEXT NOT NULL CHECK (row_status IN ('NORMAL', 'ARCHIVED')) DEFAULT 'NORMAL', + resource_id TEXT NOT NULL UNIQUE, + title TEXT NOT NULL, + description TEXT NOT NULL DEFAULT '' +); + +INSERT INTO + sqlite_sequence (name, seq) +VALUES + ('workspace', 1); + +-- workspace_setting +CREATE TABLE workspace_setting ( + workspace_id INTEGER NOT NULL, + key TEXT NOT NULL, + value TEXT NOT NULL, + UNIQUE(workspace_id, key) +); + -- user CREATE TABLE user ( id INTEGER PRIMARY KEY AUTOINCREMENT, created_ts BIGINT NOT NULL DEFAULT (strftime('%s', 'now')), updated_ts BIGINT NOT NULL DEFAULT (strftime('%s', 'now')), row_status TEXT NOT NULL CHECK (row_status IN ('NORMAL', 'ARCHIVED')) DEFAULT 'NORMAL', - email TEXT NOT NULL UNIQUE, - display_name TEXT NOT NULL, + username TEXT NOT NULL UNIQUE, + nickname TEXT NOT NULL, + email TEXT NOT NULL, password_hash TEXT NOT NULL, - open_id TEXT NOT NULL UNIQUE, role TEXT NOT NULL CHECK (role IN ('ADMIN', 'USER')) DEFAULT 'USER' ); INSERT INTO sqlite_sequence (name, seq) VALUES - ('user', 100); + ('user', 10); -- user_setting CREATE TABLE user_setting ( @@ -30,31 +54,6 @@ CREATE TABLE user_setting ( UNIQUE(user_id, key) ); --- workspace -CREATE TABLE workspace ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - creator_id INTEGER NOT NULL, - created_ts BIGINT NOT NULL DEFAULT (strftime('%s', 'now')), - updated_ts BIGINT NOT NULL DEFAULT (strftime('%s', 'now')), - row_status TEXT NOT NULL CHECK (row_status IN ('NORMAL', 'ARCHIVED')) DEFAULT 'NORMAL', - name TEXT NOT NULL UNIQUE, - title TEXT NOT NULL, - description TEXT NOT NULL DEFAULT '' -); - -INSERT INTO - sqlite_sequence (name, seq) -VALUES - ('workspace', 10); - --- workspace_setting -CREATE TABLE workspace_setting ( - workspace_id INTEGER NOT NULL, - key TEXT NOT NULL, - value TEXT NOT NULL, - UNIQUE(workspace_id, key) -); - -- workspace_user CREATE TABLE workspace_user ( workspace_id INTEGER NOT NULL, @@ -80,4 +79,4 @@ CREATE TABLE shortcut ( INSERT INTO sqlite_sequence (name, seq) VALUES - ('shortcut', 1000); \ No newline at end of file + ('shortcut', 100); diff --git a/store/db/migration/prod/LATEST__SCHEMA.sql b/store/db/migration/prod/LATEST__SCHEMA.sql index 56afd95..da0e1c5 100644 --- a/store/db/migration/prod/LATEST__SCHEMA.sql +++ b/store/db/migration/prod/LATEST__SCHEMA.sql @@ -4,23 +4,47 @@ CREATE TABLE migration_history ( created_ts BIGINT NOT NULL DEFAULT (strftime('%s', 'now')) ); +-- workspace +CREATE TABLE workspace ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + created_ts BIGINT NOT NULL DEFAULT (strftime('%s', 'now')), + updated_ts BIGINT NOT NULL DEFAULT (strftime('%s', 'now')), + row_status TEXT NOT NULL CHECK (row_status IN ('NORMAL', 'ARCHIVED')) DEFAULT 'NORMAL', + resource_id TEXT NOT NULL UNIQUE, + title TEXT NOT NULL, + description TEXT NOT NULL DEFAULT '' +); + +INSERT INTO + sqlite_sequence (name, seq) +VALUES + ('workspace', 1); + +-- workspace_setting +CREATE TABLE workspace_setting ( + workspace_id INTEGER NOT NULL, + key TEXT NOT NULL, + value TEXT NOT NULL, + UNIQUE(workspace_id, key) +); + -- user CREATE TABLE user ( id INTEGER PRIMARY KEY AUTOINCREMENT, created_ts BIGINT NOT NULL DEFAULT (strftime('%s', 'now')), updated_ts BIGINT NOT NULL DEFAULT (strftime('%s', 'now')), row_status TEXT NOT NULL CHECK (row_status IN ('NORMAL', 'ARCHIVED')) DEFAULT 'NORMAL', - email TEXT NOT NULL UNIQUE, - display_name TEXT NOT NULL, + username TEXT NOT NULL UNIQUE, + nickname TEXT NOT NULL, + email TEXT NOT NULL, password_hash TEXT NOT NULL, - open_id TEXT NOT NULL UNIQUE, role TEXT NOT NULL CHECK (role IN ('ADMIN', 'USER')) DEFAULT 'USER' ); INSERT INTO sqlite_sequence (name, seq) VALUES - ('user', 100); + ('user', 10); -- user_setting CREATE TABLE user_setting ( @@ -30,31 +54,6 @@ CREATE TABLE user_setting ( UNIQUE(user_id, key) ); --- workspace -CREATE TABLE workspace ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - creator_id INTEGER NOT NULL, - created_ts BIGINT NOT NULL DEFAULT (strftime('%s', 'now')), - updated_ts BIGINT NOT NULL DEFAULT (strftime('%s', 'now')), - row_status TEXT NOT NULL CHECK (row_status IN ('NORMAL', 'ARCHIVED')) DEFAULT 'NORMAL', - name TEXT NOT NULL UNIQUE, - title TEXT NOT NULL, - description TEXT NOT NULL DEFAULT '' -); - -INSERT INTO - sqlite_sequence (name, seq) -VALUES - ('workspace', 10); - --- workspace_setting -CREATE TABLE workspace_setting ( - workspace_id INTEGER NOT NULL, - key TEXT NOT NULL, - value TEXT NOT NULL, - UNIQUE(workspace_id, key) -); - -- workspace_user CREATE TABLE workspace_user ( workspace_id INTEGER NOT NULL, @@ -80,4 +79,4 @@ CREATE TABLE shortcut ( INSERT INTO sqlite_sequence (name, seq) VALUES - ('shortcut', 1000); \ No newline at end of file + ('shortcut', 100); diff --git a/store/db/migration_history.go b/store/db/migration_history.go index 472bfb5..cbda344 100644 --- a/store/db/migration_history.go +++ b/store/db/migration_history.go @@ -67,7 +67,7 @@ func findMigrationHistoryList(ctx context.Context, tx *sql.Tx, find *MigrationHi FROM migration_history WHERE ` + strings.Join(where, " AND ") + ` - ORDER BY version DESC + ORDER BY created_ts DESC ` rows, err := tx.QueryContext(ctx, query, args...) if err != nil { diff --git a/store/db/seed/10000__reset.sql b/store/db/seed/10000__reset.sql index 66c04fc..f99cf0d 100644 --- a/store/db/seed/10000__reset.sql +++ b/store/db/seed/10000__reset.sql @@ -1,17 +1,11 @@ -DELETE FROM - shortcut; +DELETE FROM shortcut; -DELETE FROM - workspace_user; +DELETE FROM workspace_user; -DELETE FROM - user_setting; +DELETE FROM user_setting; -DELETE FROM - user; +DELETE FROM user; -DELETE FROM - workspace_setting; +DELETE FROM workspace_setting; -DELETE FROM - workspace; \ No newline at end of file +DELETE FROM workspace; diff --git a/store/db/seed/10001__user.sql b/store/db/seed/10001__user.sql index 3c09db7..ec93eb9 100644 --- a/store/db/seed/10001__user.sql +++ b/store/db/seed/10001__user.sql @@ -1,35 +1,35 @@ INSERT INTO user ( `id`, + `username`, + `nickname`, `email`, - `display_name`, - `password_hash`, - `open_id` + `password_hash` ) VALUES ( - 101, - 'frank@shortify.demo', + 11, + 'frank', 'Frank', + 'frank@shortify.demo', -- raw password: secret - '$2a$14$ajq8Q7fbtFRQvXpdCq7Jcuy.Rx1h/L4J60Otx.gyNLbAYctGMJ9tK', - 'frank_open_id' + '$2a$14$ajq8Q7fbtFRQvXpdCq7Jcuy.Rx1h/L4J60Otx.gyNLbAYctGMJ9tK' ); INSERT INTO user ( `id`, + `username`, + `nickname`, `email`, - `display_name`, - `password_hash`, - `open_id` + `password_hash` ) VALUES ( - 102, - 'bob@shortify.demo', + 12, + 'bob', 'Bob', + 'bob@shortify.demo', -- raw password: secret - '$2a$14$ajq8Q7fbtFRQvXpdCq7Jcuy.Rx1h/L4J60Otx.gyNLbAYctGMJ9tK', - 'bob_open_id' - ); \ No newline at end of file + '$2a$14$ajq8Q7fbtFRQvXpdCq7Jcuy.Rx1h/L4J60Otx.gyNLbAYctGMJ9tK' + ); diff --git a/store/db/seed/10002__workspace.sql b/store/db/seed/10002__workspace.sql index 3583f0f..57147c9 100644 --- a/store/db/seed/10002__workspace.sql +++ b/store/db/seed/10002__workspace.sql @@ -1,33 +1,14 @@ INSERT INTO workspace ( `id`, - `creator_id`, `name`, `title`, `description` ) VALUES ( - 11, - 101, + 1, 'minecraft', 'minecraft', '' ); - -INSERT INTO - workspace ( - `id`, - `creator_id`, - `name`, - `title`, - `description` - ) -VALUES - ( - 12, - 102, - 'bob', - 'bob-room', - '' - ); \ No newline at end of file diff --git a/store/db/seed/10003__workspace_user.sql b/store/db/seed/10003__workspace_user.sql index 728ce23..2ee1a8e 100644 --- a/store/db/seed/10003__workspace_user.sql +++ b/store/db/seed/10003__workspace_user.sql @@ -6,8 +6,8 @@ INSERT INTO ) VALUES ( - 11, - 101, + 1, + 11, 'ADMIN' ); @@ -19,33 +19,7 @@ INSERT INTO ) VALUES ( - 11, - 102, - 'USER' - ); - -INSERT INTO - workspace_user ( - `workspace_id`, - `user_id`, - `role` - ) -VALUES - ( - 12, - 102, - 'ADMIN' - ); - -INSERT INTO - workspace_user ( - `workspace_id`, - `user_id`, - `role` - ) -VALUES - ( - 12, - 101, + 1, + 12, 'USER' ); diff --git a/store/db/seed/10004__shortcut.sql b/store/db/seed/10004__shortcut.sql index e08dfa1..45b422e 100644 --- a/store/db/seed/10004__shortcut.sql +++ b/store/db/seed/10004__shortcut.sql @@ -9,8 +9,8 @@ INSERT INTO ) VALUES ( - 101, - 11, + 11, + 1, 'baidu', 'https://baidu.com', '百度搜索', @@ -28,8 +28,8 @@ INSERT INTO ) VALUES ( - 102, - 11, + 12, + 1, 'bl', 'https://bilibili.com', 'B站', @@ -47,48 +47,10 @@ INSERT INTO ) VALUES ( - 101, - 11, + 11, + 1, 'ph', 'https://producthunt.com', 'PH', 'PRIVATE' ); - -INSERT INTO - shortcut ( - `creator_id`, - `workspace_id`, - `name`, - `link`, - `description`, - `visibility` - ) -VALUES - ( - 101, - 12, - 'github', - 'https://producthunt.com', - 'GitHub', - 'PRIVATE' - ); - -INSERT INTO - shortcut ( - `creator_id`, - `workspace_id`, - `name`, - `link`, - `description`, - `visibility` - ) -VALUES - ( - 102, - 12, - 'go', - 'https://google.com', - 'google', - 'WORKSPACE' - ); diff --git a/store/error.go b/store/error.go deleted file mode 100644 index 72440ea..0000000 --- a/store/error.go +++ /dev/null @@ -1 +0,0 @@ -package store diff --git a/store/shortcut.go b/store/shortcut.go index 7199fad..f8ceb5c 100644 --- a/store/shortcut.go +++ b/store/shortcut.go @@ -7,9 +7,306 @@ import ( "strings" "github.com/boojack/shortify/api" - "github.com/boojack/shortify/common" + "github.com/boojack/shortify/internal/errorutil" ) +// Visibility is the type of a visibility. +type Visibility string + +const ( + // VisibilityPublic is the PUBLIC visibility. + VisibilityPublic Visibility = "PUBLIC" + // VisibilityWorkspace is the WORKSPACE visibility. + VisibilityWorkspace Visibility = "WORKSPACE" + // VisibilityPrivite is the PRIVATE visibility. + VisibilityPrivite Visibility = "PRIVATE" +) + +func (e Visibility) String() string { + switch e { + case VisibilityPublic: + return "PUBLIC" + case VisibilityWorkspace: + return "WORKSPACE" + case VisibilityPrivite: + return "PRIVATE" + } + return "PRIVATE" +} + +type Shortcut struct { + ID int + + // Standard fields + CreatorID int + CreatedTs int64 + UpdatedTs int64 + RowStatus RowStatus + + // Domain specific fields + WorkspaceID int + Name string + Link string + Description string + Visibility Visibility +} + +type UpdateShortcut struct { + ID int + + RowStatus *RowStatus + Name *string + Link *string + Description *string + Visibility *Visibility +} + +type FindShortcut struct { + ID *int + CreatorID *int + RowStatus *RowStatus + WorkspaceID *int + Name *string + VisibilityList []Visibility +} + +type DeleteShortcut struct { + ID int +} + +func (s *Store) CreateShortcutV1(ctx context.Context, create *Shortcut) (*Shortcut, error) { + tx, err := s.db.BeginTx(ctx, nil) + if err != nil { + return nil, err + } + defer tx.Rollback() + + set := []string{"creator_id", "workspace_id", "name", "link", "description", "visibility"} + args := []any{create.CreatorID, create.WorkspaceID, create.Name, create.Link, create.Description, create.Visibility} + placeholder := []string{"?", "?", "?", "?", "?", "?"} + + query := ` + INSERT INTO shortcut ( + ` + strings.Join(set, ", ") + ` + ) + VALUES (` + strings.Join(placeholder, ",") + `) + RETURNING id, created_ts, updated_ts, row_status + ` + if err := tx.QueryRowContext(ctx, query, args...).Scan( + &create.ID, + &create.CreatedTs, + &create.UpdatedTs, + &create.RowStatus, + ); err != nil { + return nil, err + } + + if err := tx.Commit(); err != nil { + return nil, err + } + + return create, nil +} + +func (s *Store) UpdateShortcutV1(ctx context.Context, update *UpdateShortcut) (*Shortcut, error) { + tx, err := s.db.BeginTx(ctx, nil) + if err != nil { + return nil, err + } + defer tx.Rollback() + + set, args := []string{}, []any{} + + if update.RowStatus != nil { + set, args = append(set, "row_status = ?"), append(args, *update.RowStatus) + } + if update.Name != nil { + set, args = append(set, "name = ?"), append(args, *update.Name) + } + if update.Link != nil { + set, args = append(set, "link = ?"), append(args, *update.Link) + } + if update.Description != nil { + set, args = append(set, "description = ?"), append(args, *update.Description) + } + if update.Visibility != nil { + set, args = append(set, "visibility = ?"), append(args, *update.Visibility) + } + if len(set) == 0 { + return nil, fmt.Errorf("no update specified") + } + args = append(args, update.ID) + + query := ` + UPDATE shortcut + SET + ` + strings.Join(set, ", ") + ` + WHERE + id = ? + RETURNING id, creator_id, created_ts, updated_ts, workspace_id, row_status, name, link, description, visibility + ` + var shortcut Shortcut + if err := tx.QueryRowContext(ctx, query, args...).Scan( + &shortcut.ID, + &shortcut.CreatorID, + &shortcut.CreatedTs, + &shortcut.UpdatedTs, + &shortcut.WorkspaceID, + &shortcut.RowStatus, + &shortcut.Name, + &shortcut.Link, + &shortcut.Description, + &shortcut.Visibility, + ); err != nil { + return nil, err + } + + if err := tx.Commit(); err != nil { + return nil, err + } + + return &shortcut, nil +} + +func (s *Store) ListShortcuts(ctx context.Context, find *FindShortcut) ([]*Shortcut, error) { + tx, err := s.db.BeginTx(ctx, nil) + if err != nil { + return nil, err + } + defer tx.Rollback() + + shortcuts, err := listShortcuts(ctx, tx, find) + if err != nil { + return nil, err + } + + if err := tx.Commit(); err != nil { + return nil, err + } + + return shortcuts, nil +} + +func (s *Store) GetShortcut(ctx context.Context, find *FindShortcut) (*Shortcut, error) { + tx, err := s.db.BeginTx(ctx, nil) + if err != nil { + return nil, err + } + defer tx.Rollback() + + shortcuts, err := listShortcuts(ctx, tx, find) + if err != nil { + return nil, err + } + + if err := tx.Commit(); err != nil { + return nil, err + } + + if len(shortcuts) == 0 { + return nil, nil + } + return shortcuts[0], nil +} + +func (s *Store) DeleteShortcutV1(ctx context.Context, delete *DeleteShortcut) error { + tx, err := s.db.BeginTx(ctx, nil) + if err != nil { + return err + } + defer tx.Rollback() + + if _, err := tx.ExecContext(ctx, ` + DELETE FROM shortcut WHERE id = ? + `, delete.ID); err != nil { + return err + } + + if err := tx.Commit(); err != nil { + // do nothing here to prevent linter warning. + return err + } + + return nil +} + +func listShortcuts(ctx context.Context, tx *sql.Tx, find *FindShortcut) ([]*Shortcut, error) { + where, args := []string{"1 = 1"}, []any{} + + if v := find.ID; v != nil { + where, args = append(where, "id = ?"), append(args, *v) + } + if v := find.CreatorID; v != nil { + where, args = append(where, "creator_id = ?"), append(args, *v) + } + if v := find.RowStatus; v != nil { + where, args = append(where, "row_status = ?"), append(args, *v) + } + if v := find.WorkspaceID; v != nil { + where, args = append(where, "workspace_id = ?"), append(args, *v) + } + if v := find.Name; v != nil { + where, args = append(where, "name = ?"), append(args, *v) + } + if v := find.VisibilityList; len(v) != 0 { + list := []string{} + for _, visibility := range v { + list = append(list, fmt.Sprintf("$%d", len(args)+1)) + args = append(args, visibility) + } + where = append(where, fmt.Sprintf("visibility in (%s)", strings.Join(list, ","))) + } + + rows, err := tx.QueryContext(ctx, ` + SELECT + id, + creator_id, + created_ts, + updated_ts, + row_status, + workspace_id, + name, + link, + description, + visibility + FROM shortcut + WHERE `+strings.Join(where, " AND ")+` + ORDER BY created_ts DESC`, + args..., + ) + if err != nil { + return nil, err + } + defer rows.Close() + + list := make([]*Shortcut, 0) + for rows.Next() { + var shortcut Shortcut + if err := rows.Scan( + &shortcut.ID, + &shortcut.CreatorID, + &shortcut.CreatedTs, + &shortcut.UpdatedTs, + &shortcut.WorkspaceID, + &shortcut.RowStatus, + &shortcut.Name, + &shortcut.Link, + &shortcut.Description, + &shortcut.Visibility, + ); err != nil { + return nil, err + } + + list = append(list, &shortcut) + } + + if err := rows.Err(); err != nil { + return nil, err + } + + return list, nil +} + // shortcutRaw is the store model for an Shortcut. // Fields have exactly the same meanings as Shortcut. type shortcutRaw struct { @@ -54,7 +351,6 @@ func (s *Store) ComposeShortcut(ctx context.Context, shortcut *api.Shortcut) err return err } user.OpenID = "" - user.UserSettingList = nil shortcut.Creator = user return nil @@ -142,7 +438,7 @@ func (s *Store) FindShortcut(ctx context.Context, find *api.ShortcutFind) (*api. } if len(list) == 0 { - return nil, &common.Error{Code: common.NotFound, Err: fmt.Errorf("not found")} + return nil, &errorutil.Error{Code: errorutil.NotFound, Err: fmt.Errorf("not found")} } shortcutRaw := list[0] @@ -353,7 +649,7 @@ func deleteShortcut(ctx context.Context, tx *sql.Tx, delete *api.ShortcutDelete) rows, _ := result.RowsAffected() if rows == 0 { - return &common.Error{Code: common.NotFound, Err: fmt.Errorf("shortcut ID not found: %d", delete.ID)} + return &errorutil.Error{Code: errorutil.NotFound, Err: fmt.Errorf("not found")} } return nil diff --git a/store/user.go b/store/user.go index 1df2a05..bd7a4a0 100644 --- a/store/user.go +++ b/store/user.go @@ -7,9 +7,274 @@ import ( "strings" "github.com/boojack/shortify/api" - "github.com/boojack/shortify/common" + "github.com/boojack/shortify/internal/errorutil" ) +type User struct { + ID int + + // Standard fields + CreatedTs int64 + UpdatedTs int64 + RowStatus RowStatus + + // Domain specific fields + Username string + Nickname string + Email string + PasswordHash string + Role Role +} + +type UpdateUser struct { + ID int + + RowStatus *RowStatus + Username *string + Nickname *string + Email *string + PasswordHash *string + Role *Role +} + +type FindUser struct { + ID *int + RowStatus *RowStatus + Username *string + Nickname *string + Email *string + Role *Role +} + +type DeleteUser struct { + ID int +} + +func (s *Store) CreateUserV1(ctx context.Context, create *User) (*User, error) { + tx, err := s.db.BeginTx(ctx, nil) + if err != nil { + return nil, err + } + defer tx.Rollback() + + query := ` + INSERT INTO user ( + username, + nickname, + email, + password_hash, + role + ) + VALUES (?, ?, ?, ?, ?) + RETURNING id, created_ts, updated_ts, row_status + ` + if err := tx.QueryRowContext(ctx, query, + create.Username, + create.Nickname, + create.Email, + create.PasswordHash, + create.Role, + ).Scan( + &create.ID, + &create.CreatedTs, + &create.UpdatedTs, + &create.RowStatus, + ); err != nil { + return nil, err + } + + if err := tx.Commit(); err != nil { + return nil, err + } + + user := create + return user, nil +} + +func (s *Store) UpdateUserV1(ctx context.Context, update *UpdateUser) (*User, error) { + tx, err := s.db.BeginTx(ctx, nil) + if err != nil { + return nil, err + } + defer tx.Rollback() + + set, args := []string{}, []any{} + if v := update.RowStatus; v != nil { + set, args = append(set, "row_status = ?"), append(args, *v) + } + if v := update.Username; v != nil { + set, args = append(set, "username = ?"), append(args, *v) + } + if v := update.Nickname; v != nil { + set, args = append(set, "nickname = ?"), append(args, *v) + } + if v := update.Email; v != nil { + set, args = append(set, "email = ?"), append(args, *v) + } + if v := update.PasswordHash; v != nil { + set, args = append(set, "password_hash = ?"), append(args, *v) + } + if v := update.Role; v != nil { + set, args = append(set, "role = ?"), append(args, *v) + } + + if len(set) == 0 { + return nil, fmt.Errorf("no fields to update") + } + + query := ` + UPDATE user + SET ` + strings.Join(set, ", ") + ` + WHERE id = ? + RETURNING id, created_ts, updated_ts, row_status, username, nickname, email, password_hash, role + ` + args = append(args, update.ID) + user := &User{} + if err := tx.QueryRowContext(ctx, query, args...).Scan( + &user.ID, + &user.CreatedTs, + &user.UpdatedTs, + &user.RowStatus, + &user.Username, + &user.Nickname, + &user.Email, + &user.PasswordHash, + &user.Role, + ); err != nil { + return nil, err + } + + if err := tx.Commit(); err != nil { + return nil, err + } + + return user, nil +} + +func (s *Store) ListUsers(ctx context.Context, find *FindUser) ([]*User, error) { + tx, err := s.db.BeginTx(ctx, nil) + if err != nil { + return nil, err + } + defer tx.Rollback() + + list, err := listUsers(ctx, tx, find) + if err != nil { + return nil, err + } + + return list, nil +} + +func (s *Store) GetUserV1(ctx context.Context, find *FindUser) (*User, error) { + tx, err := s.db.BeginTx(ctx, nil) + if err != nil { + return nil, err + } + defer tx.Rollback() + + list, err := listUsers(ctx, tx, find) + if err != nil { + return nil, err + } + + if len(list) == 0 { + return nil, nil + } + + return list[0], nil +} + +func (s *Store) DeleteUserV1(ctx context.Context, delete *DeleteUser) error { + tx, err := s.db.BeginTx(ctx, nil) + if err != nil { + return err + } + defer tx.Rollback() + + if _, err := tx.ExecContext(ctx, ` + DELETE FROM user WHERE id = ? + `, delete.ID); err != nil { + return err + } + + if err := tx.Commit(); err != nil { + // do nothing here to prevent linter warning. + return err + } + + return nil +} + +func listUsers(ctx context.Context, tx *sql.Tx, find *FindUser) ([]*User, error) { + where, args := []string{"1 = 1"}, []any{} + + if v := find.ID; v != nil { + where, args = append(where, "id = ?"), append(args, *v) + } + if v := find.RowStatus; v != nil { + where, args = append(where, "row_status = ?"), append(args, v.String()) + } + if v := find.Username; v != nil { + where, args = append(where, "username = ?"), append(args, *v) + } + if v := find.Nickname; v != nil { + where, args = append(where, "nickname = ?"), append(args, *v) + } + if v := find.Email; v != nil { + where, args = append(where, "email = ?"), append(args, *v) + } + if v := find.Role; v != nil { + where, args = append(where, "role = ?"), append(args, *v) + } + + query := ` + SELECT + id, + created_ts, + updated_ts, + row_status, + username, + nickname, + email, + password_hash, + role + FROM user + WHERE ` + strings.Join(where, " AND ") + ` + ORDER BY updated_ts DESC, created_ts DESC + ` + rows, err := tx.QueryContext(ctx, query, args...) + if err != nil { + return nil, err + } + defer rows.Close() + + list := make([]*User, 0) + for rows.Next() { + user := User{} + if err := rows.Scan( + &user.ID, + &user.CreatedTs, + &user.UpdatedTs, + &user.RowStatus, + &user.Username, + &user.Nickname, + &user.Email, + &user.PasswordHash, + &user.Role, + ); err != nil { + return nil, err + } + list = append(list, &user) + } + + if err := rows.Err(); err != nil { + return nil, err + } + + return list, nil +} + // userRaw is the store model for an User. // Fields have exactly the same meanings as User. type userRaw struct { @@ -126,9 +391,9 @@ func (s *Store) FindUser(ctx context.Context, find *api.UserFind) (*api.User, er } if len(list) == 0 { - return nil, &common.Error{Code: common.NotFound, Err: fmt.Errorf("not found user with filter %+v", find)} + return nil, &errorutil.Error{Code: errorutil.NotFound, Err: fmt.Errorf("not found user with filter %+v", find)} } else if len(list) > 1 { - return nil, &common.Error{Code: common.Conflict, Err: fmt.Errorf("found %d users with filter %+v, expect 1", len(list), find)} + return nil, &errorutil.Error{Code: errorutil.Conflict, Err: fmt.Errorf("found %d users with filter %+v, expect 1", len(list), find)} } userRaw := list[0] @@ -249,7 +514,7 @@ func patchUser(ctx context.Context, tx *sql.Tx, patch *api.UserPatch) (*userRaw, return &userRaw, nil } - return nil, &common.Error{Code: common.NotFound, Err: fmt.Errorf("user ID not found: %d", patch.ID)} + return nil, &errorutil.Error{Code: errorutil.NotFound, Err: fmt.Errorf("user ID not found: %d", patch.ID)} } func findUserList(ctx context.Context, tx *sql.Tx, find *api.UserFind) ([]*userRaw, error) { @@ -325,7 +590,7 @@ func deleteUser(ctx context.Context, tx *sql.Tx, delete *api.UserDelete) error { rows, _ := result.RowsAffected() if rows == 0 { - return &common.Error{Code: common.NotFound, Err: fmt.Errorf("user ID not found: %d", delete.ID)} + return &errorutil.Error{Code: errorutil.NotFound, Err: fmt.Errorf("user ID not found: %d", delete.ID)} } return nil diff --git a/store/user_setting.go b/store/user_setting.go index efbd222..72440ea 100644 --- a/store/user_setting.go +++ b/store/user_setting.go @@ -1,151 +1 @@ package store - -import ( - "context" - "database/sql" - "strings" - - "github.com/boojack/shortify/api" -) - -type userSettingRaw struct { - UserID int - Key api.UserSettingKey - Value string -} - -func (raw *userSettingRaw) toUserSetting() *api.UserSetting { - return &api.UserSetting{ - UserID: raw.UserID, - Key: raw.Key, - Value: raw.Value, - } -} - -func (s *Store) UpsertUserSetting(ctx context.Context, upsert *api.UserSettingUpsert) (*api.UserSetting, error) { - tx, err := s.db.BeginTx(ctx, nil) - if err != nil { - return nil, err - } - defer tx.Rollback() - - userSettingRaw, err := upsertUserSetting(ctx, tx, upsert) - if err != nil { - return nil, err - } - - if err := tx.Commit(); err != nil { - return nil, err - } - - userSetting := userSettingRaw.toUserSetting() - - return userSetting, nil -} - -func (s *Store) FindUserSettingList(ctx context.Context, find *api.UserSettingFind) ([]*api.UserSetting, error) { - tx, err := s.db.BeginTx(ctx, nil) - if err != nil { - return nil, err - } - defer tx.Rollback() - - userSettingRawList, err := findUserSettingList(ctx, tx, find) - if err != nil { - return nil, err - } - - list := []*api.UserSetting{} - for _, raw := range userSettingRawList { - list = append(list, raw.toUserSetting()) - } - - return list, nil -} - -func (s *Store) FindUserSetting(ctx context.Context, find *api.UserSettingFind) (*api.UserSetting, error) { - tx, err := s.db.BeginTx(ctx, nil) - if err != nil { - return nil, err - } - defer tx.Rollback() - - list, err := findUserSettingList(ctx, tx, find) - if err != nil { - return nil, err - } - - if len(list) == 0 { - return nil, nil - } - - userSetting := list[0].toUserSetting() - - return userSetting, nil -} - -func upsertUserSetting(ctx context.Context, tx *sql.Tx, upsert *api.UserSettingUpsert) (*userSettingRaw, error) { - query := ` - INSERT INTO user_setting ( - user_id, key, value - ) - VALUES (?, ?, ?) - ON CONFLICT(user_id, key) DO UPDATE - SET - value = EXCLUDED.value - RETURNING user_id, key, value - ` - var userSettingRaw userSettingRaw - if err := tx.QueryRowContext(ctx, query, upsert.UserID, upsert.Key, upsert.Value).Scan( - &userSettingRaw.UserID, - &userSettingRaw.Key, - &userSettingRaw.Value, - ); err != nil { - return nil, err - } - - return &userSettingRaw, nil -} - -func findUserSettingList(ctx context.Context, tx *sql.Tx, find *api.UserSettingFind) ([]*userSettingRaw, error) { - where, args := []string{"1 = 1"}, []any{} - - if v := find.Key; v != nil { - where, args = append(where, "key = ?"), append(args, v.String()) - } - - where, args = append(where, "user_id = ?"), append(args, find.UserID) - - query := ` - SELECT - user_id, - key, - value - FROM user_setting - WHERE ` + strings.Join(where, " AND ") - rows, err := tx.QueryContext(ctx, query, args...) - if err != nil { - return nil, err - } - defer rows.Close() - - userSettingRawList := make([]*userSettingRaw, 0) - for rows.Next() { - var userSettingRaw userSettingRaw - if err := rows.Scan( - &userSettingRaw.UserID, - &userSettingRaw.Key, - &userSettingRaw.Value, - ); err != nil { - return nil, err - } - - userSettingRawList = append(userSettingRawList, &userSettingRaw) - } - - if err := rows.Err(); err != nil { - return nil, err - } - - return userSettingRawList, nil -} diff --git a/store/workspace.go b/store/workspace.go index 7c2e651..f3bd36d 100644 --- a/store/workspace.go +++ b/store/workspace.go @@ -3,13 +3,259 @@ package store import ( "context" "database/sql" + "errors" "fmt" "strings" "github.com/boojack/shortify/api" - "github.com/boojack/shortify/common" + "github.com/boojack/shortify/internal/errorutil" ) +type Workspace struct { + ID int + + // Standard fields + CreatedTs int64 + UpdatedTs int64 + RowStatus RowStatus + + // Domain specific fields + ResourceID string + Title string + Description string +} + +type UpdateWorkspace struct { + ID int + + // Standard fields + RowStatus *RowStatus + + // Domain specific fields + ResourceID *string + Title *string + Description *string +} + +type FindWorkspace struct { + ID *int + RowStatus *RowStatus + ResourceID *string +} + +type DeleteWorkspace struct { + ID int +} + +func (s *Store) CreateWorkspaceV1(ctx context.Context, create *Workspace) (*Workspace, error) { + tx, err := s.db.BeginTx(ctx, nil) + if err != nil { + return nil, err + } + defer tx.Rollback() + + query := ` + INSERT INTO workspace ( + resource_id, + title, + description + ) + VALUES (?, ?, ?) + RETURNING id, created_ts, updated_ts, row_status + ` + if err := tx.QueryRowContext(ctx, query, + create.ResourceID, + create.Title, + create.Description, + ).Scan( + &create.ID, + &create.CreatedTs, + &create.UpdatedTs, + &create.RowStatus, + ); err != nil { + return nil, err + } + + if err := tx.Commit(); err != nil { + return nil, err + } + + workspace := create + return workspace, nil +} + +func (s *Store) UpdateWorkspace(ctx context.Context, update *UpdateWorkspace) (*Workspace, error) { + tx, err := s.db.BeginTx(ctx, nil) + if err != nil { + return nil, err + } + defer tx.Rollback() + + set, args := []string{}, []any{} + + if v := update.RowStatus; v != nil { + set, args = append(set, "row_status = ?"), append(args, *v) + } + if v := update.ResourceID; v != nil { + set, args = append(set, "resource_id = ?"), append(args, *v) + } + if v := update.Title; v != nil { + set, args = append(set, "title = ?"), append(args, *v) + } + if v := update.Description; v != nil { + set, args = append(set, "description = ?"), append(args, *v) + } + args = append(args, update.ID) + + query := ` + UPDATE workspace + SET ` + strings.Join(set, ", ") + ` + WHERE id = ? + RETURNING id, created_ts, updated_ts, row_status, resource_id, title, description + ` + row, err := tx.QueryContext(ctx, query, args...) + if err != nil { + return nil, err + } + defer row.Close() + + if !row.Next() { + return nil, errors.New(fmt.Sprint("workspace ID not found: ", update.ID)) + } + workspace := &Workspace{} + if err := row.Scan( + &workspace.ID, + &workspace.CreatedTs, + &workspace.UpdatedTs, + &workspace.RowStatus, + &workspace.ResourceID, + &workspace.Title, + &workspace.Description, + ); err != nil { + return nil, err + } + + if err := row.Err(); err != nil { + return nil, err + } + if err := tx.Commit(); err != nil { + return nil, err + } + + return workspace, nil +} + +func (s *Store) ListWorkspaces(ctx context.Context, find *FindWorkspace) ([]*Workspace, error) { + tx, err := s.db.BeginTx(ctx, nil) + if err != nil { + return nil, err + } + defer tx.Rollback() + + list, err := listWorkspaces(ctx, tx, find) + if err != nil { + return nil, err + } + + return list, nil +} + +func (s *Store) GetWorkspace(ctx context.Context, find *FindWorkspace) (*Workspace, error) { + tx, err := s.db.BeginTx(ctx, nil) + if err != nil { + return nil, err + } + defer tx.Rollback() + + list, err := listWorkspaces(ctx, tx, find) + if err != nil { + return nil, err + } + + if len(list) == 0 { + return nil, nil + } + return list[0], nil +} + +func (s *Store) DeleteWorkspaceV1(ctx context.Context, delete *DeleteWorkspace) error { + tx, err := s.db.BeginTx(ctx, nil) + if err != nil { + return err + } + defer tx.Rollback() + + if _, err := tx.ExecContext(ctx, ` + DELETE FROM workspace WHERE id = ? + `, delete.ID); err != nil { + return err + } + + if err := tx.Commit(); err != nil { + // do nothing here to prevent linter warning. + return err + } + + return nil +} + +func listWorkspaces(ctx context.Context, tx *sql.Tx, find *FindWorkspace) ([]*Workspace, error) { + where, args := []string{"1 = 1"}, []any{} + + if v := find.ID; v != nil { + where, args = append(where, "id = ?"), append(args, *v) + } + if v := find.RowStatus; v != nil { + where, args = append(where, "row_status = ?"), append(args, *v) + } + if v := find.ResourceID; v != nil { + where, args = append(where, "resource_id = ?"), append(args, *v) + } + + query := ` + SELECT + id, + created_ts, + updated_ts, + row_status, + resource_id, + title, + description + FROM workspace + WHERE ` + strings.Join(where, " AND ") + ` + ORDER BY created_ts DESC, row_status DESC + ` + rows, err := tx.QueryContext(ctx, query, args...) + if err != nil { + return nil, err + } + defer rows.Close() + + list := make([]*Workspace, 0) + for rows.Next() { + var workspace Workspace + if err := rows.Scan( + &workspace.ID, + &workspace.CreatedTs, + &workspace.UpdatedTs, + &workspace.RowStatus, + &workspace.ResourceID, + &workspace.Title, + &workspace.Description, + ); err != nil { + return nil, err + } + + list = append(list, &workspace) + } + + if err := rows.Err(); err != nil { + return nil, err + } + + return list, nil +} + // workspaceRaw is the store model for Workspace. type workspaceRaw struct { ID int @@ -124,9 +370,9 @@ func (s *Store) FindWorkspace(ctx context.Context, find *api.WorkspaceFind) (*ap } if len(list) == 0 { - return nil, &common.Error{Code: common.NotFound, Err: fmt.Errorf("not found workspace with filter %+v", find)} + return nil, &errorutil.Error{Code: errorutil.NotFound, Err: fmt.Errorf("not found workspace with filter %+v", find)} } else if len(list) > 1 { - return nil, &common.Error{Code: common.Conflict, Err: fmt.Errorf("found %d workspaces with filter %+v, expect 1", len(list), find)} + return nil, &errorutil.Error{Code: errorutil.Conflict, Err: fmt.Errorf("found %d workspaces with filter %+v, expect 1", len(list), find)} } workspaceRaw := list[0] @@ -240,7 +486,7 @@ func patchWorkspace(ctx context.Context, tx *sql.Tx, patch *api.WorkspacePatch) return &workspaceRaw, nil } - return nil, &common.Error{Code: common.NotFound, Err: fmt.Errorf("workspace ID not found: %d", patch.ID)} + return nil, &errorutil.Error{Code: errorutil.NotFound, Err: fmt.Errorf("workspace ID not found: %d", patch.ID)} } func findWorkspaceList(ctx context.Context, tx *sql.Tx, find *api.WorkspaceFind) ([]*workspaceRaw, error) { @@ -315,7 +561,7 @@ func deleteWorkspace(ctx context.Context, tx *sql.Tx, delete *api.WorkspaceDelet rows, _ := result.RowsAffected() if rows == 0 { - return &common.Error{Code: common.NotFound, Err: fmt.Errorf("workspace ID not found: %d", delete.ID)} + return &errorutil.Error{Code: errorutil.NotFound, Err: fmt.Errorf("workspace ID not found: %d", delete.ID)} } return nil diff --git a/store/workspace_user.go b/store/workspace_user.go index 5f1b4b0..b84435c 100644 --- a/store/workspace_user.go +++ b/store/workspace_user.go @@ -7,9 +7,174 @@ import ( "strings" "github.com/boojack/shortify/api" - "github.com/boojack/shortify/common" + "github.com/boojack/shortify/internal/errorutil" ) +// Role is the type of a role. +type Role string + +const ( + // RoleAdmin is the ADMIN role. + RoleAdmin Role = "ADMIN" + // RoleUser is the USER role. + RoleUser Role = "USER" +) + +type WorkspaceUser struct { + WorkspaceID int + UserID int + Role Role +} + +type FindWorkspaceUser struct { + WorkspaceID *int + UserID *int + Role *Role +} + +type DeleteWorkspaceUser struct { + WorkspaceID int + UserID int +} + +func (s *Store) UpsertWorkspaceUserV1(ctx context.Context, upsert *WorkspaceUser) (*WorkspaceUser, error) { + tx, err := s.db.BeginTx(ctx, nil) + if err != nil { + return nil, err + } + defer tx.Rollback() + + set := []string{"workspace_id", "user_id", "role"} + args := []any{upsert.WorkspaceID, upsert.UserID, upsert.Role} + placeholder := []string{"?", "?", "?"} + + query := ` + INSERT INTO workspace_user ( + ` + strings.Join(set, ", ") + ` + ) + VALUES (` + strings.Join(placeholder, ",") + `) + ON CONFLICT(workspace_id, user_id) DO UPDATE + SET + role = EXCLUDED.role + ` + if _, err := tx.ExecContext(ctx, query, args...); err != nil { + return nil, err + } + + if err := tx.Commit(); err != nil { + return nil, err + } + + workspaceUser := upsert + return workspaceUser, nil +} + +func (s *Store) ListWorkspaceUsers(ctx context.Context, find *FindWorkspaceUser) ([]*WorkspaceUser, error) { + tx, err := s.db.BeginTx(ctx, nil) + if err != nil { + return nil, err + } + defer tx.Rollback() + + list, err := listWorkspaceUsers(ctx, tx, find) + if err != nil { + return nil, err + } + + if err := tx.Commit(); err != nil { + return nil, err + } + + return list, nil +} + +func (s *Store) GetWorkspaceUser(ctx context.Context, find *FindWorkspaceUser) (*WorkspaceUser, error) { + tx, err := s.db.BeginTx(ctx, nil) + if err != nil { + return nil, err + } + defer tx.Rollback() + + list, err := listWorkspaceUsers(ctx, tx, find) + if err != nil { + return nil, err + } + + if len(list) == 0 { + return nil, nil + } + workspaceUser := list[0] + return workspaceUser, nil +} + +func (s *Store) DeleteWorkspaceUserV1(ctx context.Context, delete *DeleteWorkspaceUser) error { + tx, err := s.db.BeginTx(ctx, nil) + if err != nil { + return err + } + defer tx.Rollback() + + if _, err := tx.ExecContext(ctx, ` + DELETE FROM workspace_user WHERE workspace_id = ? AND user_id = ? + `, delete.WorkspaceID, delete.UserID); err != nil { + return err + } + + if err := tx.Commit(); err != nil { + // do nothing here to prevent linter warning. + return err + } + + return nil +} + +func listWorkspaceUsers(ctx context.Context, tx *sql.Tx, find *FindWorkspaceUser) ([]*WorkspaceUser, error) { + where, args := []string{"1 = 1"}, []any{} + + if v := find.WorkspaceID; v != nil { + where, args = append(where, "workspace_id = ?"), append(args, *v) + } + if v := find.UserID; v != nil { + where, args = append(where, "user_id = ?"), append(args, *v) + } + if v := find.Role; v != nil { + where, args = append(where, "role = ?"), append(args, *v) + } + + query := ` + SELECT + workspace_id, + user_id, + role + FROM workspace_user + WHERE ` + strings.Join(where, " AND ") + rows, err := tx.QueryContext(ctx, query, args...) + if err != nil { + return nil, err + } + defer rows.Close() + + list := make([]*WorkspaceUser, 0) + for rows.Next() { + var workspaceUser WorkspaceUser + if err := rows.Scan( + &workspaceUser.WorkspaceID, + &workspaceUser.UserID, + &workspaceUser.Role, + ); err != nil { + return nil, err + } + + list = append(list, &workspaceUser) + } + + if err := rows.Err(); err != nil { + return nil, err + } + + return list, nil +} + // workspaceUserRaw is the store model for WorkspaceUser. type workspaceUserRaw struct { WorkspaceID int @@ -111,9 +276,9 @@ func (s *Store) FindWordspaceUser(ctx context.Context, find *api.WorkspaceUserFi } if len(list) == 0 { - return nil, &common.Error{Code: common.NotFound, Err: fmt.Errorf("not found workspace user with filter %+v", find)} + return nil, &errorutil.Error{Code: errorutil.NotFound, Err: fmt.Errorf("not found workspace user with filter %+v", find)} } else if len(list) > 1 { - return nil, &common.Error{Code: common.Conflict, Err: fmt.Errorf("found %d workspaces user with filter %+v, expect 1", len(list), find)} + return nil, &errorutil.Error{Code: errorutil.Conflict, Err: fmt.Errorf("found %d workspaces user with filter %+v, expect 1", len(list), find)} } workspaceUser := list[0].toWorkspaceUser() @@ -221,7 +386,7 @@ func deleteWorkspaceUser(ctx context.Context, tx *sql.Tx, delete *api.WorkspaceU rows, _ := result.RowsAffected() if rows == 0 { - return &common.Error{Code: common.NotFound, Err: fmt.Errorf("workspace user not found")} + return &errorutil.Error{Code: errorutil.NotFound, Err: fmt.Errorf("workspace user not found")} } return nil