From a1b633e4db9aa0639e030a31cd6c2a488c6d583b Mon Sep 17 00:00:00 2001 From: Steven Date: Wed, 19 Jul 2023 22:33:30 +0800 Subject: [PATCH] chore: update store statement execution --- store/activity.go | 67 ++++---------- store/db/db.go | 20 +---- store/db/migration_history.go | 44 +--------- store/shortcut.go | 130 +++++++++------------------ store/user.go | 159 +++++++++++++--------------------- store/user_setting.go | 99 ++++++++------------- store/workspace_setting.go | 89 +++++++------------ 7 files changed, 185 insertions(+), 423 deletions(-) diff --git a/store/activity.go b/store/activity.go index 9406ccc..c9af102 100644 --- a/store/activity.go +++ b/store/activity.go @@ -2,7 +2,6 @@ package store import ( "context" - "database/sql" "strings" ) @@ -64,13 +63,7 @@ type FindActivity struct { } func (s *Store) CreateActivity(ctx context.Context, create *Activity) (*Activity, error) { - tx, err := s.db.BeginTx(ctx, nil) - if err != nil { - return nil, err - } - defer tx.Rollback() - - query := ` + stmt := ` INSERT INTO activity ( creator_id, type, @@ -80,7 +73,7 @@ func (s *Store) CreateActivity(ctx context.Context, create *Activity) (*Activity VALUES (?, ?, ?, ?) RETURNING id, created_ts ` - if err := tx.QueryRowContext(ctx, query, + if err := s.db.QueryRowContext(ctx, stmt, create.CreatorID, create.Type.String(), create.Level.String(), @@ -92,50 +85,11 @@ func (s *Store) CreateActivity(ctx context.Context, create *Activity) (*Activity return nil, err } - if err := tx.Commit(); err != nil { - return nil, err - } - activity := create return activity, nil } func (s *Store) ListActivities(ctx context.Context, find *FindActivity) ([]*Activity, error) { - tx, err := s.db.BeginTx(ctx, nil) - if err != nil { - return nil, err - } - defer tx.Rollback() - - list, err := listActivities(ctx, tx, find) - if err != nil { - return nil, err - } - - return list, nil -} - -func (s *Store) GetActivity(ctx context.Context, find *FindActivity) (*Activity, error) { - tx, err := s.db.BeginTx(ctx, nil) - if err != nil { - return nil, err - } - defer tx.Rollback() - - list, err := listActivities(ctx, tx, find) - if err != nil { - return nil, err - } - - if len(list) == 0 { - return nil, nil - } - - activity := list[0] - return activity, nil -} - -func listActivities(ctx context.Context, tx *sql.Tx, find *FindActivity) ([]*Activity, error) { where, args := []string{"1 = 1"}, []any{} if find.Type != "" { where, args = append(where, "type = ?"), append(args, find.Type.String()) @@ -157,11 +111,10 @@ func listActivities(ctx context.Context, tx *sql.Tx, find *FindActivity) ([]*Act payload FROM activity WHERE ` + strings.Join(where, " AND ") - rows, err := tx.QueryContext(ctx, query, args...) + rows, err := s.db.QueryContext(ctx, query, args...) if err != nil { return nil, err } - defer rows.Close() list := []*Activity{} @@ -187,3 +140,17 @@ func listActivities(ctx context.Context, tx *sql.Tx, find *FindActivity) ([]*Act return list, nil } + +func (s *Store) GetActivity(ctx context.Context, find *FindActivity) (*Activity, error) { + list, err := s.ListActivities(ctx, find) + if err != nil { + return nil, err + } + + if len(list) == 0 { + return nil, nil + } + + activity := list[0] + return activity, nil +} diff --git a/store/db/db.go b/store/db/db.go index 9803dec..14d2353 100644 --- a/store/db/db.go +++ b/store/db/db.go @@ -177,21 +177,15 @@ func (db *DB) applyMigrationForMinorVersion(ctx context.Context, minorVersion st } } - tx, err := db.DBInstance.Begin() - if err != nil { - return err - } - defer tx.Rollback() - // upsert the newest version to migration_history version := minorVersion + ".0" - if _, err = upsertMigrationHistory(ctx, tx, &MigrationHistoryUpsert{ + if _, err = db.UpsertMigrationHistory(ctx, &MigrationHistoryUpsert{ Version: version, }); err != nil { return fmt.Errorf("failed to upsert migration history with version: %s, err: %w", version, err) } - return tx.Commit() + return nil } func (db *DB) seed(ctx context.Context) error { @@ -218,17 +212,11 @@ func (db *DB) seed(ctx context.Context) error { // execute runs a single SQL statement within a transaction. func (db *DB) execute(ctx context.Context, stmt string) error { - tx, err := db.DBInstance.Begin() - if err != nil { - return err - } - defer tx.Rollback() - - if _, err := tx.ExecContext(ctx, stmt); err != nil { + if _, err := db.DBInstance.ExecContext(ctx, stmt); err != nil { return fmt.Errorf("failed to execute statement, err: %w", err) } - return tx.Commit() + return nil } // minorDirRegexp is a regular expression for minor version directory. diff --git a/store/db/migration_history.go b/store/db/migration_history.go index 0c0e4db..3aa83b3 100644 --- a/store/db/migration_history.go +++ b/store/db/migration_history.go @@ -2,7 +2,6 @@ package db import ( "context" - "database/sql" "strings" ) @@ -20,47 +19,13 @@ type MigrationHistoryFind struct { } func (db *DB) FindMigrationHistoryList(ctx context.Context, find *MigrationHistoryFind) ([]*MigrationHistory, error) { - tx, err := db.DBInstance.BeginTx(ctx, nil) - if err != nil { - return nil, err - } - defer tx.Rollback() - - list, err := findMigrationHistoryList(ctx, tx, find) - if err != nil { - return nil, err - } - - return list, nil -} - -func (db *DB) UpsertMigrationHistory(ctx context.Context, upsert *MigrationHistoryUpsert) (*MigrationHistory, error) { - tx, err := db.DBInstance.BeginTx(ctx, nil) - if err != nil { - return nil, err - } - defer tx.Rollback() - - migrationHistory, err := upsertMigrationHistory(ctx, tx, upsert) - if err != nil { - return nil, err - } - - if err := tx.Commit(); err != nil { - return nil, err - } - - return migrationHistory, nil -} - -func findMigrationHistoryList(ctx context.Context, tx *sql.Tx, find *MigrationHistoryFind) ([]*MigrationHistory, error) { where, args := []string{"1 = 1"}, []any{} if v := find.Version; v != nil { where, args = append(where, "version = ?"), append(args, *v) } - query := ` + stmt := ` SELECT version, created_ts @@ -69,7 +34,7 @@ func findMigrationHistoryList(ctx context.Context, tx *sql.Tx, find *MigrationHi WHERE ` + strings.Join(where, " AND ") + ` ORDER BY created_ts DESC ` - rows, err := tx.QueryContext(ctx, query, args...) + rows, err := db.DBInstance.QueryContext(ctx, stmt, args...) if err != nil { return nil, err } @@ -84,7 +49,6 @@ func findMigrationHistoryList(ctx context.Context, tx *sql.Tx, find *MigrationHi ); err != nil { return nil, err } - migrationHistoryList = append(migrationHistoryList, &migrationHistory) } @@ -95,7 +59,7 @@ func findMigrationHistoryList(ctx context.Context, tx *sql.Tx, find *MigrationHi return migrationHistoryList, nil } -func upsertMigrationHistory(ctx context.Context, tx *sql.Tx, upsert *MigrationHistoryUpsert) (*MigrationHistory, error) { +func (db *DB) UpsertMigrationHistory(ctx context.Context, upsert *MigrationHistoryUpsert) (*MigrationHistory, error) { query := ` INSERT INTO migration_history ( version @@ -107,7 +71,7 @@ func upsertMigrationHistory(ctx context.Context, tx *sql.Tx, upsert *MigrationHi RETURNING version, created_ts ` migrationHistory := &MigrationHistory{} - if err := tx.QueryRowContext(ctx, query, upsert.Version).Scan( + if err := db.DBInstance.QueryRowContext(ctx, query, upsert.Version).Scan( &migrationHistory.Version, &migrationHistory.CreatedTs, ); err != nil { diff --git a/store/shortcut.go b/store/shortcut.go index b8e8914..333adfb 100644 --- a/store/shortcut.go +++ b/store/shortcut.go @@ -73,24 +73,18 @@ type DeleteShortcut struct { } func (s *Store) CreateShortcut(ctx context.Context, create *Shortcut) (*Shortcut, error) { - tx, err := s.db.BeginTx(ctx, nil) - if err != nil { - return nil, err - } - defer tx.Rollback() - set := []string{"creator_id", "name", "link", "description", "visibility", "tag"} args := []any{create.CreatorID, create.Name, create.Link, create.Description, create.Visibility, create.Tag} placeholder := []string{"?", "?", "?", "?", "?", "?"} - query := ` + stmt := ` INSERT INTO shortcut ( ` + strings.Join(set, ", ") + ` ) VALUES (` + strings.Join(placeholder, ",") + `) RETURNING id, created_ts, updated_ts, row_status ` - if err := tx.QueryRowContext(ctx, query, args...).Scan( + if err := s.db.QueryRowContext(ctx, stmt, args...).Scan( &create.ID, &create.CreatedTs, &create.UpdatedTs, @@ -99,20 +93,10 @@ func (s *Store) CreateShortcut(ctx context.Context, create *Shortcut) (*Shortcut return nil, err } - if err := tx.Commit(); err != nil { - return nil, err - } - return create, nil } func (s *Store) UpdateShortcut(ctx context.Context, update *UpdateShortcut) (*Shortcut, error) { - tx, err := s.db.BeginTx(ctx, nil) - if err != nil { - return nil, err - } - defer tx.Rollback() - set, args := []string{}, []any{} if update.RowStatus != nil { set, args = append(set, "row_status = ?"), append(args, update.RowStatus.String()) @@ -137,7 +121,7 @@ func (s *Store) UpdateShortcut(ctx context.Context, update *UpdateShortcut) (*Sh } args = append(args, update.ID) - query := ` + stmt := ` UPDATE shortcut SET ` + strings.Join(set, ", ") + ` @@ -146,7 +130,7 @@ func (s *Store) UpdateShortcut(ctx context.Context, update *UpdateShortcut) (*Sh RETURNING id, creator_id, created_ts, updated_ts, row_status, name, link, description, visibility, tag ` shortcut := &Shortcut{} - if err := tx.QueryRowContext(ctx, query, args...).Scan( + if err := s.db.QueryRowContext(ctx, stmt, args...).Scan( &shortcut.ID, &shortcut.CreatorID, &shortcut.CreatedTs, @@ -161,82 +145,12 @@ func (s *Store) UpdateShortcut(ctx context.Context, update *UpdateShortcut) (*Sh return nil, err } - if err := tx.Commit(); err != nil { - return nil, err - } - s.shortcutCache.Store(shortcut.ID, shortcut) return shortcut, nil } func (s *Store) ListShortcuts(ctx context.Context, find *FindShortcut) ([]*Shortcut, error) { - tx, err := s.db.BeginTx(ctx, nil) - if err != nil { - return nil, err - } - defer tx.Rollback() - - list, err := listShortcuts(ctx, tx, find) - if err != nil { - return nil, err - } - - for _, shortcut := range list { - s.shortcutCache.Store(shortcut.ID, shortcut) - } - return list, nil -} - -func (s *Store) GetShortcut(ctx context.Context, find *FindShortcut) (*Shortcut, error) { - if find.ID != nil { - if cache, ok := s.shortcutCache.Load(*find.ID); ok { - return cache.(*Shortcut), nil - } - } - - tx, err := s.db.BeginTx(ctx, nil) - if err != nil { - return nil, err - } - defer tx.Rollback() - - shortcuts, err := listShortcuts(ctx, tx, find) - if err != nil { - return nil, err - } - - if len(shortcuts) == 0 { - return nil, nil - } - - shortcut := shortcuts[0] - s.shortcutCache.Store(shortcut.ID, shortcut) - return shortcut, nil -} - -func (s *Store) DeleteShortcut(ctx context.Context, delete *DeleteShortcut) error { - tx, err := s.db.BeginTx(ctx, nil) - if err != nil { - return err - } - defer tx.Rollback() - - if _, err := tx.ExecContext(ctx, `DELETE FROM shortcut WHERE id = ?`, delete.ID); err != nil { - return err - } - - if err := tx.Commit(); err != nil { - // do nothing here to prevent linter warning. - return err - } - - s.shortcutCache.Delete(delete.ID) - return nil -} - -func listShortcuts(ctx context.Context, tx *sql.Tx, find *FindShortcut) ([]*Shortcut, error) { where, args := []string{"1 = 1"}, []any{} - if v := find.ID; v != nil { where, args = append(where, "id = ?"), append(args, *v) } @@ -261,7 +175,7 @@ func listShortcuts(ctx context.Context, tx *sql.Tx, find *FindShortcut) ([]*Shor where, args = append(where, "tag LIKE ?"), append(args, "%"+*v+"%") } - rows, err := tx.QueryContext(ctx, ` + rows, err := s.db.QueryContext(ctx, ` SELECT id, creator_id, @@ -307,9 +221,43 @@ func listShortcuts(ctx context.Context, tx *sql.Tx, find *FindShortcut) ([]*Shor return nil, err } + for _, shortcut := range list { + s.shortcutCache.Store(shortcut.ID, shortcut) + } return list, nil } +func (s *Store) GetShortcut(ctx context.Context, find *FindShortcut) (*Shortcut, error) { + if find.ID != nil { + if cache, ok := s.shortcutCache.Load(*find.ID); ok { + return cache.(*Shortcut), nil + } + } + + shortcuts, err := s.ListShortcuts(ctx, find) + if err != nil { + return nil, err + } + + if len(shortcuts) == 0 { + return nil, nil + } + + shortcut := shortcuts[0] + s.shortcutCache.Store(shortcut.ID, shortcut) + return shortcut, nil +} + +func (s *Store) DeleteShortcut(ctx context.Context, delete *DeleteShortcut) error { + if _, err := s.db.ExecContext(ctx, `DELETE FROM shortcut WHERE id = ?`, delete.ID); err != nil { + return err + } + + s.shortcutCache.Delete(delete.ID) + + return nil +} + func vacuumShortcut(ctx context.Context, tx *sql.Tx) error { stmt := ` DELETE FROM diff --git a/store/user.go b/store/user.go index 49b8406..02c0f65 100644 --- a/store/user.go +++ b/store/user.go @@ -2,7 +2,6 @@ package store import ( "context" - "database/sql" "fmt" "strings" ) @@ -55,13 +54,7 @@ type DeleteUser struct { } func (s *Store) CreateUser(ctx context.Context, create *User) (*User, error) { - tx, err := s.db.BeginTx(ctx, nil) - if err != nil { - return nil, err - } - defer tx.Rollback() - - query := ` + stmt := ` INSERT INTO user ( email, nickname, @@ -71,7 +64,7 @@ func (s *Store) CreateUser(ctx context.Context, create *User) (*User, error) { VALUES (?, ?, ?, ?) RETURNING id, created_ts, updated_ts, row_status ` - if err := tx.QueryRowContext(ctx, query, + if err := s.db.QueryRowContext(ctx, stmt, create.Email, create.Nickname, create.PasswordHash, @@ -85,22 +78,12 @@ func (s *Store) CreateUser(ctx context.Context, create *User) (*User, error) { return nil, err } - if err := tx.Commit(); err != nil { - return nil, err - } - user := create s.userCache.Store(user.ID, user) return user, nil } func (s *Store) UpdateUser(ctx context.Context, update *UpdateUser) (*User, error) { - tx, err := s.db.BeginTx(ctx, nil) - if err != nil { - return nil, err - } - defer tx.Rollback() - set, args := []string{}, []any{} if v := update.RowStatus; v != nil { set, args = append(set, "row_status = ?"), append(args, *v) @@ -122,7 +105,7 @@ func (s *Store) UpdateUser(ctx context.Context, update *UpdateUser) (*User, erro return nil, fmt.Errorf("no fields to update") } - query := ` + stmt := ` UPDATE user SET ` + strings.Join(set, ", ") + ` WHERE id = ? @@ -130,7 +113,7 @@ func (s *Store) UpdateUser(ctx context.Context, update *UpdateUser) (*User, erro ` args = append(args, update.ID) user := &User{} - if err := tx.QueryRowContext(ctx, query, args...).Scan( + if err := s.db.QueryRowContext(ctx, stmt, args...).Scan( &user.ID, &user.CreatedTs, &user.UpdatedTs, @@ -143,23 +126,68 @@ func (s *Store) UpdateUser(ctx context.Context, update *UpdateUser) (*User, erro return nil, err } - if err := tx.Commit(); err != nil { - return nil, err - } - s.userCache.Store(user.ID, user) return user, nil } func (s *Store) ListUsers(ctx context.Context, find *FindUser) ([]*User, error) { - tx, err := s.db.BeginTx(ctx, nil) + where, args := []string{"1 = 1"}, []any{} + + if v := find.ID; v != nil { + where, args = append(where, "id = ?"), append(args, *v) + } + if v := find.RowStatus; v != nil { + where, args = append(where, "row_status = ?"), append(args, v.String()) + } + if v := find.Email; v != nil { + where, args = append(where, "email = ?"), append(args, *v) + } + if v := find.Nickname; v != nil { + where, args = append(where, "nickname = ?"), append(args, *v) + } + if v := find.Role; v != nil { + where, args = append(where, "role = ?"), append(args, *v) + } + + query := ` + SELECT + id, + created_ts, + updated_ts, + row_status, + email, + nickname, + password_hash, + role + FROM user + WHERE ` + strings.Join(where, " AND ") + ` + ORDER BY updated_ts DESC, created_ts DESC + ` + rows, err := s.db.QueryContext(ctx, query, args...) if err != nil { return nil, err } - defer tx.Rollback() + defer rows.Close() - list, err := listUsers(ctx, tx, find) - if err != nil { + list := make([]*User, 0) + for rows.Next() { + user := &User{} + if err := rows.Scan( + &user.ID, + &user.CreatedTs, + &user.UpdatedTs, + &user.RowStatus, + &user.Email, + &user.Nickname, + &user.PasswordHash, + &user.Role, + ); err != nil { + return nil, err + } + list = append(list, user) + } + + if err := rows.Err(); err != nil { return nil, err } @@ -177,13 +205,7 @@ func (s *Store) GetUser(ctx context.Context, find *FindUser) (*User, error) { } } - tx, err := s.db.BeginTx(ctx, nil) - if err != nil { - return nil, err - } - defer tx.Rollback() - - list, err := listUsers(ctx, tx, find) + list, err := s.ListUsers(ctx, find) if err != nil { return nil, err } @@ -217,7 +239,6 @@ func (s *Store) DeleteUser(ctx context.Context, delete *DeleteUser) error { } if err := tx.Commit(); err != nil { - // do nothing here to prevent linter warning. return err } @@ -225,67 +246,3 @@ func (s *Store) DeleteUser(ctx context.Context, delete *DeleteUser) error { return nil } - -func listUsers(ctx context.Context, tx *sql.Tx, find *FindUser) ([]*User, error) { - where, args := []string{"1 = 1"}, []any{} - - if v := find.ID; v != nil { - where, args = append(where, "id = ?"), append(args, *v) - } - if v := find.RowStatus; v != nil { - where, args = append(where, "row_status = ?"), append(args, v.String()) - } - if v := find.Email; v != nil { - where, args = append(where, "email = ?"), append(args, *v) - } - if v := find.Nickname; v != nil { - where, args = append(where, "nickname = ?"), append(args, *v) - } - if v := find.Role; v != nil { - where, args = append(where, "role = ?"), append(args, *v) - } - - query := ` - SELECT - id, - created_ts, - updated_ts, - row_status, - email, - nickname, - password_hash, - role - FROM user - WHERE ` + strings.Join(where, " AND ") + ` - ORDER BY updated_ts DESC, created_ts DESC - ` - rows, err := tx.QueryContext(ctx, query, args...) - if err != nil { - return nil, err - } - defer rows.Close() - - list := make([]*User, 0) - for rows.Next() { - user := &User{} - if err := rows.Scan( - &user.ID, - &user.CreatedTs, - &user.UpdatedTs, - &user.RowStatus, - &user.Email, - &user.Nickname, - &user.PasswordHash, - &user.Role, - ); err != nil { - return nil, err - } - list = append(list, user) - } - - if err := rows.Err(); err != nil { - return nil, err - } - - return list, nil -} diff --git a/store/user_setting.go b/store/user_setting.go index 58a6fd9..2324042 100644 --- a/store/user_setting.go +++ b/store/user_setting.go @@ -18,13 +18,7 @@ type FindUserSetting struct { } func (s *Store) UpsertUserSetting(ctx context.Context, upsert *UserSetting) (*UserSetting, error) { - tx, err := s.db.BeginTx(ctx, nil) - if err != nil { - return nil, err - } - defer tx.Rollback() - - query := ` + stmt := ` INSERT INTO user_setting ( user_id, key, value ) @@ -32,11 +26,7 @@ func (s *Store) UpsertUserSetting(ctx context.Context, upsert *UserSetting) (*Us ON CONFLICT(user_id, key) DO UPDATE SET value = EXCLUDED.value ` - if _, err := tx.ExecContext(ctx, query, upsert.UserID, upsert.Key, upsert.Value); err != nil { - return nil, err - } - - if err := tx.Commit(); err != nil { + if _, err := s.db.ExecContext(ctx, stmt, upsert.UserID, upsert.Key, upsert.Value); err != nil { return nil, err } @@ -46,51 +36,6 @@ func (s *Store) UpsertUserSetting(ctx context.Context, upsert *UserSetting) (*Us } func (s *Store) ListUserSettings(ctx context.Context, find *FindUserSetting) ([]*UserSetting, error) { - tx, err := s.db.BeginTx(ctx, nil) - if err != nil { - return nil, err - } - defer tx.Rollback() - - userSettingList, err := listUserSettings(ctx, tx, find) - if err != nil { - return nil, err - } - - for _, userSetting := range userSettingList { - s.userSettingCache.Store(getUserSettingCacheKey(userSetting.UserID, userSetting.Key), userSetting) - } - return userSettingList, nil -} - -func (s *Store) GetUserSetting(ctx context.Context, find *FindUserSetting) (*UserSetting, error) { - if find.UserID != nil && find.Key != "" { - if cache, ok := s.userSettingCache.Load(getUserSettingCacheKey(*find.UserID, find.Key)); ok { - return cache.(*UserSetting), nil - } - } - - tx, err := s.db.BeginTx(ctx, nil) - if err != nil { - return nil, err - } - defer tx.Rollback() - - list, err := listUserSettings(ctx, tx, find) - if err != nil { - return nil, err - } - - if len(list) == 0 { - return nil, nil - } - - userSettingMessage := list[0] - s.userSettingCache.Store(getUserSettingCacheKey(userSettingMessage.UserID, userSettingMessage.Key), userSettingMessage) - return userSettingMessage, nil -} - -func listUserSettings(ctx context.Context, tx *sql.Tx, find *FindUserSetting) ([]*UserSetting, error) { where, args := []string{"1 = 1"}, []any{} if v := find.Key; v != "" { @@ -107,30 +52,54 @@ func listUserSettings(ctx context.Context, tx *sql.Tx, find *FindUserSetting) ([ value FROM user_setting WHERE ` + strings.Join(where, " AND ") - rows, err := tx.QueryContext(ctx, query, args...) + rows, err := s.db.QueryContext(ctx, query, args...) if err != nil { return nil, err } defer rows.Close() - userSettingMessageList := make([]*UserSetting, 0) + userSettingList := make([]*UserSetting, 0) for rows.Next() { - userSettingMessage := &UserSetting{} + userSetting := &UserSetting{} if err := rows.Scan( - &userSettingMessage.UserID, - &userSettingMessage.Key, - &userSettingMessage.Value, + &userSetting.UserID, + &userSetting.Key, + &userSetting.Value, ); err != nil { return nil, err } - userSettingMessageList = append(userSettingMessageList, userSettingMessage) + userSettingList = append(userSettingList, userSetting) } if err := rows.Err(); err != nil { return nil, err } - return userSettingMessageList, nil + for _, userSetting := range userSettingList { + s.userSettingCache.Store(getUserSettingCacheKey(userSetting.UserID, userSetting.Key), userSetting) + } + return userSettingList, nil +} + +func (s *Store) GetUserSetting(ctx context.Context, find *FindUserSetting) (*UserSetting, error) { + if find.UserID != nil && find.Key != "" { + if cache, ok := s.userSettingCache.Load(getUserSettingCacheKey(*find.UserID, find.Key)); ok { + return cache.(*UserSetting), nil + } + } + + list, err := s.ListUserSettings(ctx, find) + if err != nil { + return nil, err + } + + if len(list) == 0 { + return nil, nil + } + + userSettingMessage := list[0] + s.userSettingCache.Store(getUserSettingCacheKey(userSettingMessage.UserID, userSettingMessage.Key), userSettingMessage) + return userSettingMessage, nil } func vacuumUserSetting(ctx context.Context, tx *sql.Tx) error { diff --git a/store/workspace_setting.go b/store/workspace_setting.go index a61c08a..dddcec1 100644 --- a/store/workspace_setting.go +++ b/store/workspace_setting.go @@ -2,7 +2,6 @@ package store import ( "context" - "database/sql" "strings" ) @@ -30,13 +29,7 @@ type FindWorkspaceSetting struct { } func (s *Store) UpsertWorkspaceSetting(ctx context.Context, upsert *WorkspaceSetting) (*WorkspaceSetting, error) { - tx, err := s.db.BeginTx(ctx, nil) - if err != nil { - return nil, err - } - defer tx.Rollback() - - query := ` + stmt := ` INSERT INTO workspace_setting ( key, value @@ -45,11 +38,7 @@ func (s *Store) UpsertWorkspaceSetting(ctx context.Context, upsert *WorkspaceSet ON CONFLICT(key) DO UPDATE SET value = EXCLUDED.value ` - if _, err := tx.ExecContext(ctx, query, upsert.Key, upsert.Value); err != nil { - return nil, err - } - - if err := tx.Commit(); err != nil { + if _, err := s.db.ExecContext(ctx, stmt, upsert.Key, upsert.Value); err != nil { return nil, err } @@ -59,53 +48,8 @@ func (s *Store) UpsertWorkspaceSetting(ctx context.Context, upsert *WorkspaceSet } func (s *Store) ListWorkspaceSettings(ctx context.Context, find *FindWorkspaceSetting) ([]*WorkspaceSetting, error) { - tx, err := s.db.BeginTx(ctx, nil) - if err != nil { - return nil, err - } - defer tx.Rollback() - - list, err := listWorkspaceSettings(ctx, tx, find) - if err != nil { - return nil, err - } - - for _, workspaceSetting := range list { - s.workspaceSettingCache.Store(workspaceSetting.Key, workspaceSetting) - } - - return list, nil -} - -func (s *Store) GetWorkspaceSetting(ctx context.Context, find *FindWorkspaceSetting) (*WorkspaceSetting, error) { - if find.Key != "" { - if cache, ok := s.workspaceSettingCache.Load(find.Key); ok { - return cache.(*WorkspaceSetting), nil - } - } - - tx, err := s.db.BeginTx(ctx, nil) - if err != nil { - return nil, err - } - defer tx.Rollback() - - list, err := listWorkspaceSettings(ctx, tx, find) - if err != nil { - return nil, err - } - - if len(list) == 0 { - return nil, nil - } - - workspaceSetting := list[0] - s.workspaceSettingCache.Store(workspaceSetting.Key, workspaceSetting) - return workspaceSetting, nil -} - -func listWorkspaceSettings(ctx context.Context, tx *sql.Tx, find *FindWorkspaceSetting) ([]*WorkspaceSetting, error) { where, args := []string{"1 = 1"}, []any{} + if find.Key != "" { where, args = append(where, "key = ?"), append(args, find.Key) } @@ -116,7 +60,7 @@ func listWorkspaceSettings(ctx context.Context, tx *sql.Tx, find *FindWorkspaceS value FROM workspace_setting WHERE ` + strings.Join(where, " AND ") - rows, err := tx.QueryContext(ctx, query, args...) + rows, err := s.db.QueryContext(ctx, query, args...) if err != nil { return nil, err } @@ -140,5 +84,30 @@ func listWorkspaceSettings(ctx context.Context, tx *sql.Tx, find *FindWorkspaceS return nil, err } + for _, workspaceSetting := range list { + s.workspaceSettingCache.Store(workspaceSetting.Key, workspaceSetting) + } + return list, nil } + +func (s *Store) GetWorkspaceSetting(ctx context.Context, find *FindWorkspaceSetting) (*WorkspaceSetting, error) { + if find.Key != "" { + if cache, ok := s.workspaceSettingCache.Load(find.Key); ok { + return cache.(*WorkspaceSetting), nil + } + } + + list, err := s.ListWorkspaceSettings(ctx, find) + if err != nil { + return nil, err + } + + if len(list) == 0 { + return nil, nil + } + + workspaceSetting := list[0] + s.workspaceSettingCache.Store(workspaceSetting.Key, workspaceSetting) + return workspaceSetting, nil +}