refactor: update stores

This commit is contained in:
Steven 2023-06-20 15:54:56 +08:00
parent 8fb410ec3b
commit 805ab9996e
16 changed files with 1160 additions and 403 deletions

21
store/common.go Normal file
View File

@ -0,0 +1,21 @@
package store
// RowStatus is the status for a row.
type RowStatus string
const (
// Normal is the status for a normal row.
Normal RowStatus = "NORMAL"
// Archived is the status for an archived row.
Archived RowStatus = "ARCHIVED"
)
func (e RowStatus) String() string {
switch e {
case Normal:
return "NORMAL"
case Archived:
return "ARCHIVED"
}
return ""
}

View File

@ -23,12 +23,12 @@ var migrationFS embed.FS
var seedFS embed.FS var seedFS embed.FS
type DB struct { type DB struct {
profile *profile.Profile
// sqlite db connection instance // sqlite db connection instance
DBInstance *sql.DB DBInstance *sql.DB
profile *profile.Profile
} }
// NewDB returns a new instance of DB associated with the given datasource name. // NewDB returns a new instance of DB.
func NewDB(profile *profile.Profile) *DB { func NewDB(profile *profile.Profile) *DB {
db := &DB{ db := &DB{
profile: profile, profile: profile,
@ -50,64 +50,70 @@ func (db *DB) Open(ctx context.Context) (err error) {
db.DBInstance = sqliteDB db.DBInstance = sqliteDB
if db.profile.Mode == "prod" { if db.profile.Mode == "prod" {
// If db file not exists, we should migrate the database. _, err := os.Stat(db.profile.DSN)
if _, err := os.Stat(db.profile.DSN); errors.Is(err, os.ErrNotExist) {
if err := db.applyLatestSchema(ctx); err != nil {
return fmt.Errorf("failed to apply latest schema: %w", err)
}
}
currentVersion := version.GetCurrentVersion(db.profile.Mode)
migrationHistoryList, err := db.FindMigrationHistoryList(ctx, &MigrationHistoryFind{})
if err != nil { if err != nil {
return fmt.Errorf("failed to find migration history, err: %w", err) // If db file not exists, we should apply the latest schema.
} if errors.Is(err, os.ErrNotExist) {
if len(migrationHistoryList) == 0 { if err := db.applyLatestSchema(ctx); err != nil {
_, err := db.UpsertMigrationHistory(ctx, &MigrationHistoryUpsert{ return fmt.Errorf("failed to apply latest schema: %w", err)
Version: currentVersion, }
}) } else {
return fmt.Errorf("failed to check database file: %w", err)
}
} else {
// If db file exists, we should check the migration history and apply the migration if needed.
currentVersion := version.GetCurrentVersion(db.profile.Mode)
migrationHistoryList, err := db.FindMigrationHistoryList(ctx, &MigrationHistoryFind{})
if err != nil { if err != nil {
return fmt.Errorf("failed to upsert migration history, err: %w", err) return fmt.Errorf("failed to find migration history, err: %w", err)
} }
return nil if len(migrationHistoryList) == 0 {
} _, err := db.UpsertMigrationHistory(ctx, &MigrationHistoryUpsert{
Version: currentVersion,
migrationHistoryVersionList := []string{} })
for _, migrationHistory := range migrationHistoryList { if err != nil {
migrationHistoryVersionList = append(migrationHistoryVersionList, migrationHistory.Version) return fmt.Errorf("failed to upsert migration history, err: %w", err)
} }
sort.Sort(version.SortVersion(migrationHistoryVersionList)) return nil
latestMigrationHistoryVersion := migrationHistoryVersionList[len(migrationHistoryVersionList)-1]
if version.IsVersionGreaterThan(version.GetSchemaVersion(currentVersion), latestMigrationHistoryVersion) {
minorVersionList := getMinorVersionList()
// backup the raw database file before migration
rawBytes, err := os.ReadFile(db.profile.DSN)
if err != nil {
return fmt.Errorf("failed to read raw database file, err: %w", err)
} }
backupDBFilePath := fmt.Sprintf("%s/shortify_%s_%d_backup.db", db.profile.Data, db.profile.Version, time.Now().Unix())
if err := os.WriteFile(backupDBFilePath, rawBytes, 0644); err != nil {
return fmt.Errorf("failed to write raw database file, err: %w", err)
}
println("succeed to copy a backup database file")
println("start migrate") migrationHistoryVersionList := []string{}
for _, minorVersion := range minorVersionList { for _, migrationHistory := range migrationHistoryList {
normalizedVersion := minorVersion + ".0" migrationHistoryVersionList = append(migrationHistoryVersionList, migrationHistory.Version)
if version.IsVersionGreaterThan(normalizedVersion, latestMigrationHistoryVersion) && version.IsVersionGreaterOrEqualThan(currentVersion, normalizedVersion) { }
println("applying migration for", normalizedVersion) sort.Sort(version.SortVersion(migrationHistoryVersionList))
if err := db.applyMigrationForMinorVersion(ctx, minorVersion); err != nil { latestMigrationHistoryVersion := migrationHistoryVersionList[len(migrationHistoryVersionList)-1]
return fmt.Errorf("failed to apply minor version migration: %w", err)
if version.IsVersionGreaterThan(version.GetSchemaVersion(currentVersion), latestMigrationHistoryVersion) {
minorVersionList := getMinorVersionList()
// backup the raw database file before migration
rawBytes, err := os.ReadFile(db.profile.DSN)
if err != nil {
return fmt.Errorf("failed to read raw database file, err: %w", err)
}
backupDBFilePath := fmt.Sprintf("%s/shortify_%s_%d_backup.db", db.profile.Data, db.profile.Version, time.Now().Unix())
if err := os.WriteFile(backupDBFilePath, rawBytes, 0644); err != nil {
return fmt.Errorf("failed to write raw database file, err: %w", err)
}
println("succeed to copy a backup database file")
println("start migrate")
for _, minorVersion := range minorVersionList {
normalizedVersion := minorVersion + ".0"
if version.IsVersionGreaterThan(normalizedVersion, latestMigrationHistoryVersion) && version.IsVersionGreaterOrEqualThan(currentVersion, normalizedVersion) {
println("applying migration for", normalizedVersion)
if err := db.applyMigrationForMinorVersion(ctx, minorVersion); err != nil {
return fmt.Errorf("failed to apply minor version migration: %w", err)
}
} }
} }
} println("end migrate")
println("end migrate")
// remove the created backup db file after migrate succeed // remove the created backup db file after migrate succeed
if err := os.Remove(backupDBFilePath); err != nil { if err := os.Remove(backupDBFilePath); err != nil {
println(fmt.Sprintf("Failed to remove temp database file, err %v", err)) println(fmt.Sprintf("Failed to remove temp database file, err %v", err))
}
} }
} }
} else { } else {

View File

@ -4,23 +4,47 @@ CREATE TABLE migration_history (
created_ts BIGINT NOT NULL DEFAULT (strftime('%s', 'now')) created_ts BIGINT NOT NULL DEFAULT (strftime('%s', 'now'))
); );
-- workspace
CREATE TABLE workspace (
id INTEGER PRIMARY KEY AUTOINCREMENT,
created_ts BIGINT NOT NULL DEFAULT (strftime('%s', 'now')),
updated_ts BIGINT NOT NULL DEFAULT (strftime('%s', 'now')),
row_status TEXT NOT NULL CHECK (row_status IN ('NORMAL', 'ARCHIVED')) DEFAULT 'NORMAL',
resource_id TEXT NOT NULL UNIQUE,
title TEXT NOT NULL,
description TEXT NOT NULL DEFAULT ''
);
INSERT INTO
sqlite_sequence (name, seq)
VALUES
('workspace', 1);
-- workspace_setting
CREATE TABLE workspace_setting (
workspace_id INTEGER NOT NULL,
key TEXT NOT NULL,
value TEXT NOT NULL,
UNIQUE(workspace_id, key)
);
-- user -- user
CREATE TABLE user ( CREATE TABLE user (
id INTEGER PRIMARY KEY AUTOINCREMENT, id INTEGER PRIMARY KEY AUTOINCREMENT,
created_ts BIGINT NOT NULL DEFAULT (strftime('%s', 'now')), created_ts BIGINT NOT NULL DEFAULT (strftime('%s', 'now')),
updated_ts BIGINT NOT NULL DEFAULT (strftime('%s', 'now')), updated_ts BIGINT NOT NULL DEFAULT (strftime('%s', 'now')),
row_status TEXT NOT NULL CHECK (row_status IN ('NORMAL', 'ARCHIVED')) DEFAULT 'NORMAL', row_status TEXT NOT NULL CHECK (row_status IN ('NORMAL', 'ARCHIVED')) DEFAULT 'NORMAL',
email TEXT NOT NULL UNIQUE, username TEXT NOT NULL UNIQUE,
display_name TEXT NOT NULL, nickname TEXT NOT NULL,
email TEXT NOT NULL,
password_hash TEXT NOT NULL, password_hash TEXT NOT NULL,
open_id TEXT NOT NULL UNIQUE,
role TEXT NOT NULL CHECK (role IN ('ADMIN', 'USER')) DEFAULT 'USER' role TEXT NOT NULL CHECK (role IN ('ADMIN', 'USER')) DEFAULT 'USER'
); );
INSERT INTO INSERT INTO
sqlite_sequence (name, seq) sqlite_sequence (name, seq)
VALUES VALUES
('user', 100); ('user', 10);
-- user_setting -- user_setting
CREATE TABLE user_setting ( CREATE TABLE user_setting (
@ -30,31 +54,6 @@ CREATE TABLE user_setting (
UNIQUE(user_id, key) UNIQUE(user_id, key)
); );
-- workspace
CREATE TABLE workspace (
id INTEGER PRIMARY KEY AUTOINCREMENT,
creator_id INTEGER NOT NULL,
created_ts BIGINT NOT NULL DEFAULT (strftime('%s', 'now')),
updated_ts BIGINT NOT NULL DEFAULT (strftime('%s', 'now')),
row_status TEXT NOT NULL CHECK (row_status IN ('NORMAL', 'ARCHIVED')) DEFAULT 'NORMAL',
name TEXT NOT NULL UNIQUE,
title TEXT NOT NULL,
description TEXT NOT NULL DEFAULT ''
);
INSERT INTO
sqlite_sequence (name, seq)
VALUES
('workspace', 10);
-- workspace_setting
CREATE TABLE workspace_setting (
workspace_id INTEGER NOT NULL,
key TEXT NOT NULL,
value TEXT NOT NULL,
UNIQUE(workspace_id, key)
);
-- workspace_user -- workspace_user
CREATE TABLE workspace_user ( CREATE TABLE workspace_user (
workspace_id INTEGER NOT NULL, workspace_id INTEGER NOT NULL,
@ -80,4 +79,4 @@ CREATE TABLE shortcut (
INSERT INTO INSERT INTO
sqlite_sequence (name, seq) sqlite_sequence (name, seq)
VALUES VALUES
('shortcut', 1000); ('shortcut', 100);

View File

@ -4,23 +4,47 @@ CREATE TABLE migration_history (
created_ts BIGINT NOT NULL DEFAULT (strftime('%s', 'now')) created_ts BIGINT NOT NULL DEFAULT (strftime('%s', 'now'))
); );
-- workspace
CREATE TABLE workspace (
id INTEGER PRIMARY KEY AUTOINCREMENT,
created_ts BIGINT NOT NULL DEFAULT (strftime('%s', 'now')),
updated_ts BIGINT NOT NULL DEFAULT (strftime('%s', 'now')),
row_status TEXT NOT NULL CHECK (row_status IN ('NORMAL', 'ARCHIVED')) DEFAULT 'NORMAL',
resource_id TEXT NOT NULL UNIQUE,
title TEXT NOT NULL,
description TEXT NOT NULL DEFAULT ''
);
INSERT INTO
sqlite_sequence (name, seq)
VALUES
('workspace', 1);
-- workspace_setting
CREATE TABLE workspace_setting (
workspace_id INTEGER NOT NULL,
key TEXT NOT NULL,
value TEXT NOT NULL,
UNIQUE(workspace_id, key)
);
-- user -- user
CREATE TABLE user ( CREATE TABLE user (
id INTEGER PRIMARY KEY AUTOINCREMENT, id INTEGER PRIMARY KEY AUTOINCREMENT,
created_ts BIGINT NOT NULL DEFAULT (strftime('%s', 'now')), created_ts BIGINT NOT NULL DEFAULT (strftime('%s', 'now')),
updated_ts BIGINT NOT NULL DEFAULT (strftime('%s', 'now')), updated_ts BIGINT NOT NULL DEFAULT (strftime('%s', 'now')),
row_status TEXT NOT NULL CHECK (row_status IN ('NORMAL', 'ARCHIVED')) DEFAULT 'NORMAL', row_status TEXT NOT NULL CHECK (row_status IN ('NORMAL', 'ARCHIVED')) DEFAULT 'NORMAL',
email TEXT NOT NULL UNIQUE, username TEXT NOT NULL UNIQUE,
display_name TEXT NOT NULL, nickname TEXT NOT NULL,
email TEXT NOT NULL,
password_hash TEXT NOT NULL, password_hash TEXT NOT NULL,
open_id TEXT NOT NULL UNIQUE,
role TEXT NOT NULL CHECK (role IN ('ADMIN', 'USER')) DEFAULT 'USER' role TEXT NOT NULL CHECK (role IN ('ADMIN', 'USER')) DEFAULT 'USER'
); );
INSERT INTO INSERT INTO
sqlite_sequence (name, seq) sqlite_sequence (name, seq)
VALUES VALUES
('user', 100); ('user', 10);
-- user_setting -- user_setting
CREATE TABLE user_setting ( CREATE TABLE user_setting (
@ -30,31 +54,6 @@ CREATE TABLE user_setting (
UNIQUE(user_id, key) UNIQUE(user_id, key)
); );
-- workspace
CREATE TABLE workspace (
id INTEGER PRIMARY KEY AUTOINCREMENT,
creator_id INTEGER NOT NULL,
created_ts BIGINT NOT NULL DEFAULT (strftime('%s', 'now')),
updated_ts BIGINT NOT NULL DEFAULT (strftime('%s', 'now')),
row_status TEXT NOT NULL CHECK (row_status IN ('NORMAL', 'ARCHIVED')) DEFAULT 'NORMAL',
name TEXT NOT NULL UNIQUE,
title TEXT NOT NULL,
description TEXT NOT NULL DEFAULT ''
);
INSERT INTO
sqlite_sequence (name, seq)
VALUES
('workspace', 10);
-- workspace_setting
CREATE TABLE workspace_setting (
workspace_id INTEGER NOT NULL,
key TEXT NOT NULL,
value TEXT NOT NULL,
UNIQUE(workspace_id, key)
);
-- workspace_user -- workspace_user
CREATE TABLE workspace_user ( CREATE TABLE workspace_user (
workspace_id INTEGER NOT NULL, workspace_id INTEGER NOT NULL,
@ -80,4 +79,4 @@ CREATE TABLE shortcut (
INSERT INTO INSERT INTO
sqlite_sequence (name, seq) sqlite_sequence (name, seq)
VALUES VALUES
('shortcut', 1000); ('shortcut', 100);

View File

@ -67,7 +67,7 @@ func findMigrationHistoryList(ctx context.Context, tx *sql.Tx, find *MigrationHi
FROM FROM
migration_history migration_history
WHERE ` + strings.Join(where, " AND ") + ` WHERE ` + strings.Join(where, " AND ") + `
ORDER BY version DESC ORDER BY created_ts DESC
` `
rows, err := tx.QueryContext(ctx, query, args...) rows, err := tx.QueryContext(ctx, query, args...)
if err != nil { if err != nil {

View File

@ -1,17 +1,11 @@
DELETE FROM DELETE FROM shortcut;
shortcut;
DELETE FROM DELETE FROM workspace_user;
workspace_user;
DELETE FROM DELETE FROM user_setting;
user_setting;
DELETE FROM DELETE FROM user;
user;
DELETE FROM DELETE FROM workspace_setting;
workspace_setting;
DELETE FROM DELETE FROM workspace;
workspace;

View File

@ -1,35 +1,35 @@
INSERT INTO INSERT INTO
user ( user (
`id`, `id`,
`username`,
`nickname`,
`email`, `email`,
`display_name`, `password_hash`
`password_hash`,
`open_id`
) )
VALUES VALUES
( (
101, 11,
'frank@shortify.demo', 'frank',
'Frank', 'Frank',
'frank@shortify.demo',
-- raw password: secret -- raw password: secret
'$2a$14$ajq8Q7fbtFRQvXpdCq7Jcuy.Rx1h/L4J60Otx.gyNLbAYctGMJ9tK', '$2a$14$ajq8Q7fbtFRQvXpdCq7Jcuy.Rx1h/L4J60Otx.gyNLbAYctGMJ9tK'
'frank_open_id'
); );
INSERT INTO INSERT INTO
user ( user (
`id`, `id`,
`username`,
`nickname`,
`email`, `email`,
`display_name`, `password_hash`
`password_hash`,
`open_id`
) )
VALUES VALUES
( (
102, 12,
'bob@shortify.demo', 'bob',
'Bob', 'Bob',
'bob@shortify.demo',
-- raw password: secret -- raw password: secret
'$2a$14$ajq8Q7fbtFRQvXpdCq7Jcuy.Rx1h/L4J60Otx.gyNLbAYctGMJ9tK', '$2a$14$ajq8Q7fbtFRQvXpdCq7Jcuy.Rx1h/L4J60Otx.gyNLbAYctGMJ9tK'
'bob_open_id' );
);

View File

@ -1,33 +1,14 @@
INSERT INTO INSERT INTO
workspace ( workspace (
`id`, `id`,
`creator_id`,
`name`, `name`,
`title`, `title`,
`description` `description`
) )
VALUES VALUES
( (
11, 1,
101,
'minecraft', 'minecraft',
'minecraft', 'minecraft',
'' ''
); );
INSERT INTO
workspace (
`id`,
`creator_id`,
`name`,
`title`,
`description`
)
VALUES
(
12,
102,
'bob',
'bob-room',
''
);

View File

@ -6,8 +6,8 @@ INSERT INTO
) )
VALUES VALUES
( (
11, 1,
101, 11,
'ADMIN' 'ADMIN'
); );
@ -19,33 +19,7 @@ INSERT INTO
) )
VALUES VALUES
( (
11, 1,
102, 12,
'USER'
);
INSERT INTO
workspace_user (
`workspace_id`,
`user_id`,
`role`
)
VALUES
(
12,
102,
'ADMIN'
);
INSERT INTO
workspace_user (
`workspace_id`,
`user_id`,
`role`
)
VALUES
(
12,
101,
'USER' 'USER'
); );

View File

@ -9,8 +9,8 @@ INSERT INTO
) )
VALUES VALUES
( (
101, 11,
11, 1,
'baidu', 'baidu',
'https://baidu.com', 'https://baidu.com',
'百度搜索', '百度搜索',
@ -28,8 +28,8 @@ INSERT INTO
) )
VALUES VALUES
( (
102, 12,
11, 1,
'bl', 'bl',
'https://bilibili.com', 'https://bilibili.com',
'B站', 'B站',
@ -47,48 +47,10 @@ INSERT INTO
) )
VALUES VALUES
( (
101, 11,
11, 1,
'ph', 'ph',
'https://producthunt.com', 'https://producthunt.com',
'PH', 'PH',
'PRIVATE' 'PRIVATE'
); );
INSERT INTO
shortcut (
`creator_id`,
`workspace_id`,
`name`,
`link`,
`description`,
`visibility`
)
VALUES
(
101,
12,
'github',
'https://producthunt.com',
'GitHub',
'PRIVATE'
);
INSERT INTO
shortcut (
`creator_id`,
`workspace_id`,
`name`,
`link`,
`description`,
`visibility`
)
VALUES
(
102,
12,
'go',
'https://google.com',
'google',
'WORKSPACE'
);

View File

@ -1 +0,0 @@
package store

View File

@ -7,9 +7,306 @@ import (
"strings" "strings"
"github.com/boojack/shortify/api" "github.com/boojack/shortify/api"
"github.com/boojack/shortify/common" "github.com/boojack/shortify/internal/errorutil"
) )
// Visibility is the type of a visibility.
type Visibility string
const (
// VisibilityPublic is the PUBLIC visibility.
VisibilityPublic Visibility = "PUBLIC"
// VisibilityWorkspace is the WORKSPACE visibility.
VisibilityWorkspace Visibility = "WORKSPACE"
// VisibilityPrivite is the PRIVATE visibility.
VisibilityPrivite Visibility = "PRIVATE"
)
func (e Visibility) String() string {
switch e {
case VisibilityPublic:
return "PUBLIC"
case VisibilityWorkspace:
return "WORKSPACE"
case VisibilityPrivite:
return "PRIVATE"
}
return "PRIVATE"
}
type Shortcut struct {
ID int
// Standard fields
CreatorID int
CreatedTs int64
UpdatedTs int64
RowStatus RowStatus
// Domain specific fields
WorkspaceID int
Name string
Link string
Description string
Visibility Visibility
}
type UpdateShortcut struct {
ID int
RowStatus *RowStatus
Name *string
Link *string
Description *string
Visibility *Visibility
}
type FindShortcut struct {
ID *int
CreatorID *int
RowStatus *RowStatus
WorkspaceID *int
Name *string
VisibilityList []Visibility
}
type DeleteShortcut struct {
ID int
}
func (s *Store) CreateShortcutV1(ctx context.Context, create *Shortcut) (*Shortcut, error) {
tx, err := s.db.BeginTx(ctx, nil)
if err != nil {
return nil, err
}
defer tx.Rollback()
set := []string{"creator_id", "workspace_id", "name", "link", "description", "visibility"}
args := []any{create.CreatorID, create.WorkspaceID, create.Name, create.Link, create.Description, create.Visibility}
placeholder := []string{"?", "?", "?", "?", "?", "?"}
query := `
INSERT INTO shortcut (
` + strings.Join(set, ", ") + `
)
VALUES (` + strings.Join(placeholder, ",") + `)
RETURNING id, created_ts, updated_ts, row_status
`
if err := tx.QueryRowContext(ctx, query, args...).Scan(
&create.ID,
&create.CreatedTs,
&create.UpdatedTs,
&create.RowStatus,
); err != nil {
return nil, err
}
if err := tx.Commit(); err != nil {
return nil, err
}
return create, nil
}
func (s *Store) UpdateShortcutV1(ctx context.Context, update *UpdateShortcut) (*Shortcut, error) {
tx, err := s.db.BeginTx(ctx, nil)
if err != nil {
return nil, err
}
defer tx.Rollback()
set, args := []string{}, []any{}
if update.RowStatus != nil {
set, args = append(set, "row_status = ?"), append(args, *update.RowStatus)
}
if update.Name != nil {
set, args = append(set, "name = ?"), append(args, *update.Name)
}
if update.Link != nil {
set, args = append(set, "link = ?"), append(args, *update.Link)
}
if update.Description != nil {
set, args = append(set, "description = ?"), append(args, *update.Description)
}
if update.Visibility != nil {
set, args = append(set, "visibility = ?"), append(args, *update.Visibility)
}
if len(set) == 0 {
return nil, fmt.Errorf("no update specified")
}
args = append(args, update.ID)
query := `
UPDATE shortcut
SET
` + strings.Join(set, ", ") + `
WHERE
id = ?
RETURNING id, creator_id, created_ts, updated_ts, workspace_id, row_status, name, link, description, visibility
`
var shortcut Shortcut
if err := tx.QueryRowContext(ctx, query, args...).Scan(
&shortcut.ID,
&shortcut.CreatorID,
&shortcut.CreatedTs,
&shortcut.UpdatedTs,
&shortcut.WorkspaceID,
&shortcut.RowStatus,
&shortcut.Name,
&shortcut.Link,
&shortcut.Description,
&shortcut.Visibility,
); err != nil {
return nil, err
}
if err := tx.Commit(); err != nil {
return nil, err
}
return &shortcut, nil
}
func (s *Store) ListShortcuts(ctx context.Context, find *FindShortcut) ([]*Shortcut, error) {
tx, err := s.db.BeginTx(ctx, nil)
if err != nil {
return nil, err
}
defer tx.Rollback()
shortcuts, err := listShortcuts(ctx, tx, find)
if err != nil {
return nil, err
}
if err := tx.Commit(); err != nil {
return nil, err
}
return shortcuts, nil
}
func (s *Store) GetShortcut(ctx context.Context, find *FindShortcut) (*Shortcut, error) {
tx, err := s.db.BeginTx(ctx, nil)
if err != nil {
return nil, err
}
defer tx.Rollback()
shortcuts, err := listShortcuts(ctx, tx, find)
if err != nil {
return nil, err
}
if err := tx.Commit(); err != nil {
return nil, err
}
if len(shortcuts) == 0 {
return nil, nil
}
return shortcuts[0], nil
}
func (s *Store) DeleteShortcutV1(ctx context.Context, delete *DeleteShortcut) error {
tx, err := s.db.BeginTx(ctx, nil)
if err != nil {
return err
}
defer tx.Rollback()
if _, err := tx.ExecContext(ctx, `
DELETE FROM shortcut WHERE id = ?
`, delete.ID); err != nil {
return err
}
if err := tx.Commit(); err != nil {
// do nothing here to prevent linter warning.
return err
}
return nil
}
func listShortcuts(ctx context.Context, tx *sql.Tx, find *FindShortcut) ([]*Shortcut, error) {
where, args := []string{"1 = 1"}, []any{}
if v := find.ID; v != nil {
where, args = append(where, "id = ?"), append(args, *v)
}
if v := find.CreatorID; v != nil {
where, args = append(where, "creator_id = ?"), append(args, *v)
}
if v := find.RowStatus; v != nil {
where, args = append(where, "row_status = ?"), append(args, *v)
}
if v := find.WorkspaceID; v != nil {
where, args = append(where, "workspace_id = ?"), append(args, *v)
}
if v := find.Name; v != nil {
where, args = append(where, "name = ?"), append(args, *v)
}
if v := find.VisibilityList; len(v) != 0 {
list := []string{}
for _, visibility := range v {
list = append(list, fmt.Sprintf("$%d", len(args)+1))
args = append(args, visibility)
}
where = append(where, fmt.Sprintf("visibility in (%s)", strings.Join(list, ",")))
}
rows, err := tx.QueryContext(ctx, `
SELECT
id,
creator_id,
created_ts,
updated_ts,
row_status,
workspace_id,
name,
link,
description,
visibility
FROM shortcut
WHERE `+strings.Join(where, " AND ")+`
ORDER BY created_ts DESC`,
args...,
)
if err != nil {
return nil, err
}
defer rows.Close()
list := make([]*Shortcut, 0)
for rows.Next() {
var shortcut Shortcut
if err := rows.Scan(
&shortcut.ID,
&shortcut.CreatorID,
&shortcut.CreatedTs,
&shortcut.UpdatedTs,
&shortcut.WorkspaceID,
&shortcut.RowStatus,
&shortcut.Name,
&shortcut.Link,
&shortcut.Description,
&shortcut.Visibility,
); err != nil {
return nil, err
}
list = append(list, &shortcut)
}
if err := rows.Err(); err != nil {
return nil, err
}
return list, nil
}
// shortcutRaw is the store model for an Shortcut. // shortcutRaw is the store model for an Shortcut.
// Fields have exactly the same meanings as Shortcut. // Fields have exactly the same meanings as Shortcut.
type shortcutRaw struct { type shortcutRaw struct {
@ -54,7 +351,6 @@ func (s *Store) ComposeShortcut(ctx context.Context, shortcut *api.Shortcut) err
return err return err
} }
user.OpenID = "" user.OpenID = ""
user.UserSettingList = nil
shortcut.Creator = user shortcut.Creator = user
return nil return nil
@ -142,7 +438,7 @@ func (s *Store) FindShortcut(ctx context.Context, find *api.ShortcutFind) (*api.
} }
if len(list) == 0 { if len(list) == 0 {
return nil, &common.Error{Code: common.NotFound, Err: fmt.Errorf("not found")} return nil, &errorutil.Error{Code: errorutil.NotFound, Err: fmt.Errorf("not found")}
} }
shortcutRaw := list[0] shortcutRaw := list[0]
@ -353,7 +649,7 @@ func deleteShortcut(ctx context.Context, tx *sql.Tx, delete *api.ShortcutDelete)
rows, _ := result.RowsAffected() rows, _ := result.RowsAffected()
if rows == 0 { if rows == 0 {
return &common.Error{Code: common.NotFound, Err: fmt.Errorf("shortcut ID not found: %d", delete.ID)} return &errorutil.Error{Code: errorutil.NotFound, Err: fmt.Errorf("not found")}
} }
return nil return nil

View File

@ -7,9 +7,274 @@ import (
"strings" "strings"
"github.com/boojack/shortify/api" "github.com/boojack/shortify/api"
"github.com/boojack/shortify/common" "github.com/boojack/shortify/internal/errorutil"
) )
type User struct {
ID int
// Standard fields
CreatedTs int64
UpdatedTs int64
RowStatus RowStatus
// Domain specific fields
Username string
Nickname string
Email string
PasswordHash string
Role Role
}
type UpdateUser struct {
ID int
RowStatus *RowStatus
Username *string
Nickname *string
Email *string
PasswordHash *string
Role *Role
}
type FindUser struct {
ID *int
RowStatus *RowStatus
Username *string
Nickname *string
Email *string
Role *Role
}
type DeleteUser struct {
ID int
}
func (s *Store) CreateUserV1(ctx context.Context, create *User) (*User, error) {
tx, err := s.db.BeginTx(ctx, nil)
if err != nil {
return nil, err
}
defer tx.Rollback()
query := `
INSERT INTO user (
username,
nickname,
email,
password_hash,
role
)
VALUES (?, ?, ?, ?, ?)
RETURNING id, created_ts, updated_ts, row_status
`
if err := tx.QueryRowContext(ctx, query,
create.Username,
create.Nickname,
create.Email,
create.PasswordHash,
create.Role,
).Scan(
&create.ID,
&create.CreatedTs,
&create.UpdatedTs,
&create.RowStatus,
); err != nil {
return nil, err
}
if err := tx.Commit(); err != nil {
return nil, err
}
user := create
return user, nil
}
func (s *Store) UpdateUserV1(ctx context.Context, update *UpdateUser) (*User, error) {
tx, err := s.db.BeginTx(ctx, nil)
if err != nil {
return nil, err
}
defer tx.Rollback()
set, args := []string{}, []any{}
if v := update.RowStatus; v != nil {
set, args = append(set, "row_status = ?"), append(args, *v)
}
if v := update.Username; v != nil {
set, args = append(set, "username = ?"), append(args, *v)
}
if v := update.Nickname; v != nil {
set, args = append(set, "nickname = ?"), append(args, *v)
}
if v := update.Email; v != nil {
set, args = append(set, "email = ?"), append(args, *v)
}
if v := update.PasswordHash; v != nil {
set, args = append(set, "password_hash = ?"), append(args, *v)
}
if v := update.Role; v != nil {
set, args = append(set, "role = ?"), append(args, *v)
}
if len(set) == 0 {
return nil, fmt.Errorf("no fields to update")
}
query := `
UPDATE user
SET ` + strings.Join(set, ", ") + `
WHERE id = ?
RETURNING id, created_ts, updated_ts, row_status, username, nickname, email, password_hash, role
`
args = append(args, update.ID)
user := &User{}
if err := tx.QueryRowContext(ctx, query, args...).Scan(
&user.ID,
&user.CreatedTs,
&user.UpdatedTs,
&user.RowStatus,
&user.Username,
&user.Nickname,
&user.Email,
&user.PasswordHash,
&user.Role,
); err != nil {
return nil, err
}
if err := tx.Commit(); err != nil {
return nil, err
}
return user, nil
}
func (s *Store) ListUsers(ctx context.Context, find *FindUser) ([]*User, error) {
tx, err := s.db.BeginTx(ctx, nil)
if err != nil {
return nil, err
}
defer tx.Rollback()
list, err := listUsers(ctx, tx, find)
if err != nil {
return nil, err
}
return list, nil
}
func (s *Store) GetUserV1(ctx context.Context, find *FindUser) (*User, error) {
tx, err := s.db.BeginTx(ctx, nil)
if err != nil {
return nil, err
}
defer tx.Rollback()
list, err := listUsers(ctx, tx, find)
if err != nil {
return nil, err
}
if len(list) == 0 {
return nil, nil
}
return list[0], nil
}
func (s *Store) DeleteUserV1(ctx context.Context, delete *DeleteUser) error {
tx, err := s.db.BeginTx(ctx, nil)
if err != nil {
return err
}
defer tx.Rollback()
if _, err := tx.ExecContext(ctx, `
DELETE FROM user WHERE id = ?
`, delete.ID); err != nil {
return err
}
if err := tx.Commit(); err != nil {
// do nothing here to prevent linter warning.
return err
}
return nil
}
func listUsers(ctx context.Context, tx *sql.Tx, find *FindUser) ([]*User, error) {
where, args := []string{"1 = 1"}, []any{}
if v := find.ID; v != nil {
where, args = append(where, "id = ?"), append(args, *v)
}
if v := find.RowStatus; v != nil {
where, args = append(where, "row_status = ?"), append(args, v.String())
}
if v := find.Username; v != nil {
where, args = append(where, "username = ?"), append(args, *v)
}
if v := find.Nickname; v != nil {
where, args = append(where, "nickname = ?"), append(args, *v)
}
if v := find.Email; v != nil {
where, args = append(where, "email = ?"), append(args, *v)
}
if v := find.Role; v != nil {
where, args = append(where, "role = ?"), append(args, *v)
}
query := `
SELECT
id,
created_ts,
updated_ts,
row_status,
username,
nickname,
email,
password_hash,
role
FROM user
WHERE ` + strings.Join(where, " AND ") + `
ORDER BY updated_ts DESC, created_ts DESC
`
rows, err := tx.QueryContext(ctx, query, args...)
if err != nil {
return nil, err
}
defer rows.Close()
list := make([]*User, 0)
for rows.Next() {
user := User{}
if err := rows.Scan(
&user.ID,
&user.CreatedTs,
&user.UpdatedTs,
&user.RowStatus,
&user.Username,
&user.Nickname,
&user.Email,
&user.PasswordHash,
&user.Role,
); err != nil {
return nil, err
}
list = append(list, &user)
}
if err := rows.Err(); err != nil {
return nil, err
}
return list, nil
}
// userRaw is the store model for an User. // userRaw is the store model for an User.
// Fields have exactly the same meanings as User. // Fields have exactly the same meanings as User.
type userRaw struct { type userRaw struct {
@ -126,9 +391,9 @@ func (s *Store) FindUser(ctx context.Context, find *api.UserFind) (*api.User, er
} }
if len(list) == 0 { if len(list) == 0 {
return nil, &common.Error{Code: common.NotFound, Err: fmt.Errorf("not found user with filter %+v", find)} return nil, &errorutil.Error{Code: errorutil.NotFound, Err: fmt.Errorf("not found user with filter %+v", find)}
} else if len(list) > 1 { } else if len(list) > 1 {
return nil, &common.Error{Code: common.Conflict, Err: fmt.Errorf("found %d users with filter %+v, expect 1", len(list), find)} return nil, &errorutil.Error{Code: errorutil.Conflict, Err: fmt.Errorf("found %d users with filter %+v, expect 1", len(list), find)}
} }
userRaw := list[0] userRaw := list[0]
@ -249,7 +514,7 @@ func patchUser(ctx context.Context, tx *sql.Tx, patch *api.UserPatch) (*userRaw,
return &userRaw, nil return &userRaw, nil
} }
return nil, &common.Error{Code: common.NotFound, Err: fmt.Errorf("user ID not found: %d", patch.ID)} return nil, &errorutil.Error{Code: errorutil.NotFound, Err: fmt.Errorf("user ID not found: %d", patch.ID)}
} }
func findUserList(ctx context.Context, tx *sql.Tx, find *api.UserFind) ([]*userRaw, error) { func findUserList(ctx context.Context, tx *sql.Tx, find *api.UserFind) ([]*userRaw, error) {
@ -325,7 +590,7 @@ func deleteUser(ctx context.Context, tx *sql.Tx, delete *api.UserDelete) error {
rows, _ := result.RowsAffected() rows, _ := result.RowsAffected()
if rows == 0 { if rows == 0 {
return &common.Error{Code: common.NotFound, Err: fmt.Errorf("user ID not found: %d", delete.ID)} return &errorutil.Error{Code: errorutil.NotFound, Err: fmt.Errorf("user ID not found: %d", delete.ID)}
} }
return nil return nil

View File

@ -1,151 +1 @@
package store package store
import (
"context"
"database/sql"
"strings"
"github.com/boojack/shortify/api"
)
type userSettingRaw struct {
UserID int
Key api.UserSettingKey
Value string
}
func (raw *userSettingRaw) toUserSetting() *api.UserSetting {
return &api.UserSetting{
UserID: raw.UserID,
Key: raw.Key,
Value: raw.Value,
}
}
func (s *Store) UpsertUserSetting(ctx context.Context, upsert *api.UserSettingUpsert) (*api.UserSetting, error) {
tx, err := s.db.BeginTx(ctx, nil)
if err != nil {
return nil, err
}
defer tx.Rollback()
userSettingRaw, err := upsertUserSetting(ctx, tx, upsert)
if err != nil {
return nil, err
}
if err := tx.Commit(); err != nil {
return nil, err
}
userSetting := userSettingRaw.toUserSetting()
return userSetting, nil
}
func (s *Store) FindUserSettingList(ctx context.Context, find *api.UserSettingFind) ([]*api.UserSetting, error) {
tx, err := s.db.BeginTx(ctx, nil)
if err != nil {
return nil, err
}
defer tx.Rollback()
userSettingRawList, err := findUserSettingList(ctx, tx, find)
if err != nil {
return nil, err
}
list := []*api.UserSetting{}
for _, raw := range userSettingRawList {
list = append(list, raw.toUserSetting())
}
return list, nil
}
func (s *Store) FindUserSetting(ctx context.Context, find *api.UserSettingFind) (*api.UserSetting, error) {
tx, err := s.db.BeginTx(ctx, nil)
if err != nil {
return nil, err
}
defer tx.Rollback()
list, err := findUserSettingList(ctx, tx, find)
if err != nil {
return nil, err
}
if len(list) == 0 {
return nil, nil
}
userSetting := list[0].toUserSetting()
return userSetting, nil
}
func upsertUserSetting(ctx context.Context, tx *sql.Tx, upsert *api.UserSettingUpsert) (*userSettingRaw, error) {
query := `
INSERT INTO user_setting (
user_id, key, value
)
VALUES (?, ?, ?)
ON CONFLICT(user_id, key) DO UPDATE
SET
value = EXCLUDED.value
RETURNING user_id, key, value
`
var userSettingRaw userSettingRaw
if err := tx.QueryRowContext(ctx, query, upsert.UserID, upsert.Key, upsert.Value).Scan(
&userSettingRaw.UserID,
&userSettingRaw.Key,
&userSettingRaw.Value,
); err != nil {
return nil, err
}
return &userSettingRaw, nil
}
func findUserSettingList(ctx context.Context, tx *sql.Tx, find *api.UserSettingFind) ([]*userSettingRaw, error) {
where, args := []string{"1 = 1"}, []any{}
if v := find.Key; v != nil {
where, args = append(where, "key = ?"), append(args, v.String())
}
where, args = append(where, "user_id = ?"), append(args, find.UserID)
query := `
SELECT
user_id,
key,
value
FROM user_setting
WHERE ` + strings.Join(where, " AND ")
rows, err := tx.QueryContext(ctx, query, args...)
if err != nil {
return nil, err
}
defer rows.Close()
userSettingRawList := make([]*userSettingRaw, 0)
for rows.Next() {
var userSettingRaw userSettingRaw
if err := rows.Scan(
&userSettingRaw.UserID,
&userSettingRaw.Key,
&userSettingRaw.Value,
); err != nil {
return nil, err
}
userSettingRawList = append(userSettingRawList, &userSettingRaw)
}
if err := rows.Err(); err != nil {
return nil, err
}
return userSettingRawList, nil
}

View File

@ -3,13 +3,259 @@ package store
import ( import (
"context" "context"
"database/sql" "database/sql"
"errors"
"fmt" "fmt"
"strings" "strings"
"github.com/boojack/shortify/api" "github.com/boojack/shortify/api"
"github.com/boojack/shortify/common" "github.com/boojack/shortify/internal/errorutil"
) )
type Workspace struct {
ID int
// Standard fields
CreatedTs int64
UpdatedTs int64
RowStatus RowStatus
// Domain specific fields
ResourceID string
Title string
Description string
}
type UpdateWorkspace struct {
ID int
// Standard fields
RowStatus *RowStatus
// Domain specific fields
ResourceID *string
Title *string
Description *string
}
type FindWorkspace struct {
ID *int
RowStatus *RowStatus
ResourceID *string
}
type DeleteWorkspace struct {
ID int
}
func (s *Store) CreateWorkspaceV1(ctx context.Context, create *Workspace) (*Workspace, error) {
tx, err := s.db.BeginTx(ctx, nil)
if err != nil {
return nil, err
}
defer tx.Rollback()
query := `
INSERT INTO workspace (
resource_id,
title,
description
)
VALUES (?, ?, ?)
RETURNING id, created_ts, updated_ts, row_status
`
if err := tx.QueryRowContext(ctx, query,
create.ResourceID,
create.Title,
create.Description,
).Scan(
&create.ID,
&create.CreatedTs,
&create.UpdatedTs,
&create.RowStatus,
); err != nil {
return nil, err
}
if err := tx.Commit(); err != nil {
return nil, err
}
workspace := create
return workspace, nil
}
func (s *Store) UpdateWorkspace(ctx context.Context, update *UpdateWorkspace) (*Workspace, error) {
tx, err := s.db.BeginTx(ctx, nil)
if err != nil {
return nil, err
}
defer tx.Rollback()
set, args := []string{}, []any{}
if v := update.RowStatus; v != nil {
set, args = append(set, "row_status = ?"), append(args, *v)
}
if v := update.ResourceID; v != nil {
set, args = append(set, "resource_id = ?"), append(args, *v)
}
if v := update.Title; v != nil {
set, args = append(set, "title = ?"), append(args, *v)
}
if v := update.Description; v != nil {
set, args = append(set, "description = ?"), append(args, *v)
}
args = append(args, update.ID)
query := `
UPDATE workspace
SET ` + strings.Join(set, ", ") + `
WHERE id = ?
RETURNING id, created_ts, updated_ts, row_status, resource_id, title, description
`
row, err := tx.QueryContext(ctx, query, args...)
if err != nil {
return nil, err
}
defer row.Close()
if !row.Next() {
return nil, errors.New(fmt.Sprint("workspace ID not found: ", update.ID))
}
workspace := &Workspace{}
if err := row.Scan(
&workspace.ID,
&workspace.CreatedTs,
&workspace.UpdatedTs,
&workspace.RowStatus,
&workspace.ResourceID,
&workspace.Title,
&workspace.Description,
); err != nil {
return nil, err
}
if err := row.Err(); err != nil {
return nil, err
}
if err := tx.Commit(); err != nil {
return nil, err
}
return workspace, nil
}
func (s *Store) ListWorkspaces(ctx context.Context, find *FindWorkspace) ([]*Workspace, error) {
tx, err := s.db.BeginTx(ctx, nil)
if err != nil {
return nil, err
}
defer tx.Rollback()
list, err := listWorkspaces(ctx, tx, find)
if err != nil {
return nil, err
}
return list, nil
}
func (s *Store) GetWorkspace(ctx context.Context, find *FindWorkspace) (*Workspace, error) {
tx, err := s.db.BeginTx(ctx, nil)
if err != nil {
return nil, err
}
defer tx.Rollback()
list, err := listWorkspaces(ctx, tx, find)
if err != nil {
return nil, err
}
if len(list) == 0 {
return nil, nil
}
return list[0], nil
}
func (s *Store) DeleteWorkspaceV1(ctx context.Context, delete *DeleteWorkspace) error {
tx, err := s.db.BeginTx(ctx, nil)
if err != nil {
return err
}
defer tx.Rollback()
if _, err := tx.ExecContext(ctx, `
DELETE FROM workspace WHERE id = ?
`, delete.ID); err != nil {
return err
}
if err := tx.Commit(); err != nil {
// do nothing here to prevent linter warning.
return err
}
return nil
}
func listWorkspaces(ctx context.Context, tx *sql.Tx, find *FindWorkspace) ([]*Workspace, error) {
where, args := []string{"1 = 1"}, []any{}
if v := find.ID; v != nil {
where, args = append(where, "id = ?"), append(args, *v)
}
if v := find.RowStatus; v != nil {
where, args = append(where, "row_status = ?"), append(args, *v)
}
if v := find.ResourceID; v != nil {
where, args = append(where, "resource_id = ?"), append(args, *v)
}
query := `
SELECT
id,
created_ts,
updated_ts,
row_status,
resource_id,
title,
description
FROM workspace
WHERE ` + strings.Join(where, " AND ") + `
ORDER BY created_ts DESC, row_status DESC
`
rows, err := tx.QueryContext(ctx, query, args...)
if err != nil {
return nil, err
}
defer rows.Close()
list := make([]*Workspace, 0)
for rows.Next() {
var workspace Workspace
if err := rows.Scan(
&workspace.ID,
&workspace.CreatedTs,
&workspace.UpdatedTs,
&workspace.RowStatus,
&workspace.ResourceID,
&workspace.Title,
&workspace.Description,
); err != nil {
return nil, err
}
list = append(list, &workspace)
}
if err := rows.Err(); err != nil {
return nil, err
}
return list, nil
}
// workspaceRaw is the store model for Workspace. // workspaceRaw is the store model for Workspace.
type workspaceRaw struct { type workspaceRaw struct {
ID int ID int
@ -124,9 +370,9 @@ func (s *Store) FindWorkspace(ctx context.Context, find *api.WorkspaceFind) (*ap
} }
if len(list) == 0 { if len(list) == 0 {
return nil, &common.Error{Code: common.NotFound, Err: fmt.Errorf("not found workspace with filter %+v", find)} return nil, &errorutil.Error{Code: errorutil.NotFound, Err: fmt.Errorf("not found workspace with filter %+v", find)}
} else if len(list) > 1 { } else if len(list) > 1 {
return nil, &common.Error{Code: common.Conflict, Err: fmt.Errorf("found %d workspaces with filter %+v, expect 1", len(list), find)} return nil, &errorutil.Error{Code: errorutil.Conflict, Err: fmt.Errorf("found %d workspaces with filter %+v, expect 1", len(list), find)}
} }
workspaceRaw := list[0] workspaceRaw := list[0]
@ -240,7 +486,7 @@ func patchWorkspace(ctx context.Context, tx *sql.Tx, patch *api.WorkspacePatch)
return &workspaceRaw, nil return &workspaceRaw, nil
} }
return nil, &common.Error{Code: common.NotFound, Err: fmt.Errorf("workspace ID not found: %d", patch.ID)} return nil, &errorutil.Error{Code: errorutil.NotFound, Err: fmt.Errorf("workspace ID not found: %d", patch.ID)}
} }
func findWorkspaceList(ctx context.Context, tx *sql.Tx, find *api.WorkspaceFind) ([]*workspaceRaw, error) { func findWorkspaceList(ctx context.Context, tx *sql.Tx, find *api.WorkspaceFind) ([]*workspaceRaw, error) {
@ -315,7 +561,7 @@ func deleteWorkspace(ctx context.Context, tx *sql.Tx, delete *api.WorkspaceDelet
rows, _ := result.RowsAffected() rows, _ := result.RowsAffected()
if rows == 0 { if rows == 0 {
return &common.Error{Code: common.NotFound, Err: fmt.Errorf("workspace ID not found: %d", delete.ID)} return &errorutil.Error{Code: errorutil.NotFound, Err: fmt.Errorf("workspace ID not found: %d", delete.ID)}
} }
return nil return nil

View File

@ -7,9 +7,174 @@ import (
"strings" "strings"
"github.com/boojack/shortify/api" "github.com/boojack/shortify/api"
"github.com/boojack/shortify/common" "github.com/boojack/shortify/internal/errorutil"
) )
// Role is the type of a role.
type Role string
const (
// RoleAdmin is the ADMIN role.
RoleAdmin Role = "ADMIN"
// RoleUser is the USER role.
RoleUser Role = "USER"
)
type WorkspaceUser struct {
WorkspaceID int
UserID int
Role Role
}
type FindWorkspaceUser struct {
WorkspaceID *int
UserID *int
Role *Role
}
type DeleteWorkspaceUser struct {
WorkspaceID int
UserID int
}
func (s *Store) UpsertWorkspaceUserV1(ctx context.Context, upsert *WorkspaceUser) (*WorkspaceUser, error) {
tx, err := s.db.BeginTx(ctx, nil)
if err != nil {
return nil, err
}
defer tx.Rollback()
set := []string{"workspace_id", "user_id", "role"}
args := []any{upsert.WorkspaceID, upsert.UserID, upsert.Role}
placeholder := []string{"?", "?", "?"}
query := `
INSERT INTO workspace_user (
` + strings.Join(set, ", ") + `
)
VALUES (` + strings.Join(placeholder, ",") + `)
ON CONFLICT(workspace_id, user_id) DO UPDATE
SET
role = EXCLUDED.role
`
if _, err := tx.ExecContext(ctx, query, args...); err != nil {
return nil, err
}
if err := tx.Commit(); err != nil {
return nil, err
}
workspaceUser := upsert
return workspaceUser, nil
}
func (s *Store) ListWorkspaceUsers(ctx context.Context, find *FindWorkspaceUser) ([]*WorkspaceUser, error) {
tx, err := s.db.BeginTx(ctx, nil)
if err != nil {
return nil, err
}
defer tx.Rollback()
list, err := listWorkspaceUsers(ctx, tx, find)
if err != nil {
return nil, err
}
if err := tx.Commit(); err != nil {
return nil, err
}
return list, nil
}
func (s *Store) GetWorkspaceUser(ctx context.Context, find *FindWorkspaceUser) (*WorkspaceUser, error) {
tx, err := s.db.BeginTx(ctx, nil)
if err != nil {
return nil, err
}
defer tx.Rollback()
list, err := listWorkspaceUsers(ctx, tx, find)
if err != nil {
return nil, err
}
if len(list) == 0 {
return nil, nil
}
workspaceUser := list[0]
return workspaceUser, nil
}
func (s *Store) DeleteWorkspaceUserV1(ctx context.Context, delete *DeleteWorkspaceUser) error {
tx, err := s.db.BeginTx(ctx, nil)
if err != nil {
return err
}
defer tx.Rollback()
if _, err := tx.ExecContext(ctx, `
DELETE FROM workspace_user WHERE workspace_id = ? AND user_id = ?
`, delete.WorkspaceID, delete.UserID); err != nil {
return err
}
if err := tx.Commit(); err != nil {
// do nothing here to prevent linter warning.
return err
}
return nil
}
func listWorkspaceUsers(ctx context.Context, tx *sql.Tx, find *FindWorkspaceUser) ([]*WorkspaceUser, error) {
where, args := []string{"1 = 1"}, []any{}
if v := find.WorkspaceID; v != nil {
where, args = append(where, "workspace_id = ?"), append(args, *v)
}
if v := find.UserID; v != nil {
where, args = append(where, "user_id = ?"), append(args, *v)
}
if v := find.Role; v != nil {
where, args = append(where, "role = ?"), append(args, *v)
}
query := `
SELECT
workspace_id,
user_id,
role
FROM workspace_user
WHERE ` + strings.Join(where, " AND ")
rows, err := tx.QueryContext(ctx, query, args...)
if err != nil {
return nil, err
}
defer rows.Close()
list := make([]*WorkspaceUser, 0)
for rows.Next() {
var workspaceUser WorkspaceUser
if err := rows.Scan(
&workspaceUser.WorkspaceID,
&workspaceUser.UserID,
&workspaceUser.Role,
); err != nil {
return nil, err
}
list = append(list, &workspaceUser)
}
if err := rows.Err(); err != nil {
return nil, err
}
return list, nil
}
// workspaceUserRaw is the store model for WorkspaceUser. // workspaceUserRaw is the store model for WorkspaceUser.
type workspaceUserRaw struct { type workspaceUserRaw struct {
WorkspaceID int WorkspaceID int
@ -111,9 +276,9 @@ func (s *Store) FindWordspaceUser(ctx context.Context, find *api.WorkspaceUserFi
} }
if len(list) == 0 { if len(list) == 0 {
return nil, &common.Error{Code: common.NotFound, Err: fmt.Errorf("not found workspace user with filter %+v", find)} return nil, &errorutil.Error{Code: errorutil.NotFound, Err: fmt.Errorf("not found workspace user with filter %+v", find)}
} else if len(list) > 1 { } else if len(list) > 1 {
return nil, &common.Error{Code: common.Conflict, Err: fmt.Errorf("found %d workspaces user with filter %+v, expect 1", len(list), find)} return nil, &errorutil.Error{Code: errorutil.Conflict, Err: fmt.Errorf("found %d workspaces user with filter %+v, expect 1", len(list), find)}
} }
workspaceUser := list[0].toWorkspaceUser() workspaceUser := list[0].toWorkspaceUser()
@ -221,7 +386,7 @@ func deleteWorkspaceUser(ctx context.Context, tx *sql.Tx, delete *api.WorkspaceU
rows, _ := result.RowsAffected() rows, _ := result.RowsAffected()
if rows == 0 { if rows == 0 {
return &common.Error{Code: common.NotFound, Err: fmt.Errorf("workspace user not found")} return &errorutil.Error{Code: errorutil.NotFound, Err: fmt.Errorf("workspace user not found")}
} }
return nil return nil