mirror of
https://github.com/aykhans/slash-e.git
synced 2025-04-20 14:01:24 +00:00
chore: fix postgres driver
This commit is contained in:
parent
a7d48e8059
commit
4c66edc170
@ -2,7 +2,6 @@ package postgres
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/yourselfhosted/slash/store"
|
||||
@ -38,10 +37,10 @@ func (d *DB) CreateActivity(ctx context.Context, create *store.Activity) (*store
|
||||
func (d *DB) ListActivities(ctx context.Context, find *store.FindActivity) ([]*store.Activity, error) {
|
||||
where, args := []string{"1 = 1"}, []any{}
|
||||
if find.Type != "" {
|
||||
where, args = append(where, "type = $"+fmt.Sprint(len(args)+1)), append(args, find.Type.String())
|
||||
where, args = append(where, "type = "+placeholder(len(args)+1)), append(args, find.Type.String())
|
||||
}
|
||||
if find.Level != "" {
|
||||
where, args = append(where, "level = $"+fmt.Sprint(len(args)+1)), append(args, find.Level.String())
|
||||
where, args = append(where, "level = "+placeholder(len(args)+1)), append(args, find.Level.String())
|
||||
}
|
||||
if find.Where != nil {
|
||||
where = append(where, find.Where...)
|
||||
|
@ -6,23 +6,20 @@ import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/lib/pq"
|
||||
"github.com/pkg/errors"
|
||||
|
||||
"github.com/yourselfhosted/slash/internal/util"
|
||||
storepb "github.com/yourselfhosted/slash/proto/gen/store"
|
||||
"github.com/yourselfhosted/slash/store"
|
||||
)
|
||||
|
||||
func (d *DB) CreateCollection(ctx context.Context, create *storepb.Collection) (*storepb.Collection, error) {
|
||||
set := []string{"creator_id", "name", "title", "description", "shortcut_ids", "visibility"}
|
||||
args := []any{create.CreatorId, create.Name, create.Title, create.Description, strings.Trim(strings.Join(strings.Fields(fmt.Sprint(create.ShortcutIds)), ","), "[]"), create.Visibility.String()}
|
||||
placeholder := []string{"$1", "$2", "$3", "$4", "$5", "$6"}
|
||||
args := []any{create.CreatorId, create.Name, create.Title, create.Description, pq.Array(create.ShortcutIds), create.Visibility.String()}
|
||||
|
||||
stmt := `
|
||||
INSERT INTO collection (
|
||||
` + strings.Join(set, ", ") + `
|
||||
)
|
||||
VALUES (` + strings.Join(placeholder, ",") + `)
|
||||
INSERT INTO collection (` + strings.Join(set, ", ") + `)
|
||||
VALUES (` + placeholders(len(args)) + `)
|
||||
RETURNING id, created_ts, updated_ts
|
||||
`
|
||||
if err := d.db.QueryRowContext(ctx, stmt, args...).Scan(
|
||||
@ -39,35 +36,34 @@ func (d *DB) CreateCollection(ctx context.Context, create *storepb.Collection) (
|
||||
func (d *DB) UpdateCollection(ctx context.Context, update *store.UpdateCollection) (*storepb.Collection, error) {
|
||||
set, args := []string{}, []any{}
|
||||
if update.Name != nil {
|
||||
set, args = append(set, "name = $1"), append(args, *update.Name)
|
||||
set, args = append(set, "name = "+placeholder(len(args)+1)), append(args, *update.Name)
|
||||
}
|
||||
if update.Title != nil {
|
||||
set, args = append(set, "title = $2"), append(args, *update.Title)
|
||||
set, args = append(set, "title = "+placeholder(len(args)+1)), append(args, *update.Title)
|
||||
}
|
||||
if update.Description != nil {
|
||||
set, args = append(set, "description = $3"), append(args, *update.Description)
|
||||
set, args = append(set, "description = "+placeholder(len(args)+1)), append(args, *update.Description)
|
||||
}
|
||||
if update.ShortcutIDs != nil {
|
||||
set, args = append(set, "shortcut_ids = $4"), append(args, strings.Trim(strings.Join(strings.Fields(fmt.Sprint(update.ShortcutIDs)), ","), "[]"))
|
||||
set, args = append(set, "shortcut_ids = "+placeholder(len(args)+1)), append(args, pq.Array(update.ShortcutIDs))
|
||||
}
|
||||
if update.Visibility != nil {
|
||||
set, args = append(set, "visibility = $5"), append(args, update.Visibility.String())
|
||||
set, args = append(set, "visibility = "+placeholder(len(args)+1)), append(args, update.Visibility.String())
|
||||
}
|
||||
if len(set) == 0 {
|
||||
return nil, errors.New("no update specified")
|
||||
}
|
||||
args = append(args, update.ID)
|
||||
|
||||
stmt := `
|
||||
UPDATE collection
|
||||
SET
|
||||
` + strings.Join(set, ", ") + `
|
||||
WHERE
|
||||
id = $6
|
||||
SET ` + strings.Join(set, ", ") + `
|
||||
WHERE id = ` + placeholder(len(args)+1) + `
|
||||
RETURNING id, creator_id, created_ts, updated_ts, name, title, description, shortcut_ids, visibility
|
||||
`
|
||||
args = append(args, update.ID)
|
||||
collection := &storepb.Collection{}
|
||||
var shortcutIDs, visibility string
|
||||
var shortcutIDs []sql.NullInt32
|
||||
var visibility string
|
||||
if err := d.db.QueryRowContext(ctx, stmt, args...).Scan(
|
||||
&collection.Id,
|
||||
&collection.CreatorId,
|
||||
@ -76,20 +72,16 @@ func (d *DB) UpdateCollection(ctx context.Context, update *store.UpdateCollectio
|
||||
&collection.Name,
|
||||
&collection.Title,
|
||||
&collection.Description,
|
||||
&shortcutIDs,
|
||||
pq.Array(&shortcutIDs),
|
||||
&visibility,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
collection.ShortcutIds = []int32{}
|
||||
if shortcutIDs != "" {
|
||||
for _, idStr := range strings.Split(shortcutIDs, ",") {
|
||||
shortcutID, err := util.ConvertStringToInt32(idStr)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to convert shortcut id")
|
||||
}
|
||||
collection.ShortcutIds = append(collection.ShortcutIds, shortcutID)
|
||||
for _, id := range shortcutIDs {
|
||||
if id.Valid {
|
||||
collection.ShortcutIds = append(collection.ShortcutIds, id.Int32)
|
||||
}
|
||||
}
|
||||
collection.Visibility = convertVisibilityStringToStorepb(visibility)
|
||||
@ -99,19 +91,18 @@ func (d *DB) UpdateCollection(ctx context.Context, update *store.UpdateCollectio
|
||||
func (d *DB) ListCollections(ctx context.Context, find *store.FindCollection) ([]*storepb.Collection, error) {
|
||||
where, args := []string{"1 = 1"}, []any{}
|
||||
if v := find.ID; v != nil {
|
||||
where, args = append(where, "id = $1"), append(args, *v)
|
||||
where, args = append(where, "id = "+placeholder(len(args)+1)), append(args, *v)
|
||||
}
|
||||
if v := find.CreatorID; v != nil {
|
||||
where, args = append(where, "creator_id = $2"), append(args, *v)
|
||||
where, args = append(where, "creator_id = "+placeholder(len(args)+1)), append(args, *v)
|
||||
}
|
||||
if v := find.Name; v != nil {
|
||||
where, args = append(where, "name = $3"), append(args, *v)
|
||||
where, args = append(where, "name = "+placeholder(len(args)+1)), append(args, *v)
|
||||
}
|
||||
if v := find.VisibilityList; len(v) != 0 {
|
||||
list := []string{}
|
||||
for i, visibility := range v {
|
||||
list = append(list, fmt.Sprintf("$%d", len(args)+i+1))
|
||||
args = append(args, visibility)
|
||||
for _, visibility := range v {
|
||||
list, args = append(list, placeholder(len(args)+1)), append(args, visibility)
|
||||
}
|
||||
where = append(where, fmt.Sprintf("visibility IN (%s)", strings.Join(list, ",")))
|
||||
}
|
||||
@ -140,7 +131,8 @@ func (d *DB) ListCollections(ctx context.Context, find *store.FindCollection) ([
|
||||
list := make([]*storepb.Collection, 0)
|
||||
for rows.Next() {
|
||||
collection := &storepb.Collection{}
|
||||
var shortcutIDs, visibility string
|
||||
var shortcutIDs []sql.NullInt32
|
||||
var visibility string
|
||||
if err := rows.Scan(
|
||||
&collection.Id,
|
||||
&collection.CreatorId,
|
||||
@ -149,20 +141,16 @@ func (d *DB) ListCollections(ctx context.Context, find *store.FindCollection) ([
|
||||
&collection.Name,
|
||||
&collection.Title,
|
||||
&collection.Description,
|
||||
&shortcutIDs,
|
||||
pq.Array(&shortcutIDs),
|
||||
&visibility,
|
||||
); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
collection.ShortcutIds = []int32{}
|
||||
if shortcutIDs != "" {
|
||||
for _, idStr := range strings.Split(shortcutIDs, ",") {
|
||||
shortcutID, err := util.ConvertStringToInt32(idStr)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to convert shortcut id")
|
||||
}
|
||||
collection.ShortcutIds = append(collection.ShortcutIds, shortcutID)
|
||||
for _, id := range shortcutIDs {
|
||||
if id.Valid {
|
||||
collection.ShortcutIds = append(collection.ShortcutIds, id.Int32)
|
||||
}
|
||||
}
|
||||
collection.Visibility = storepb.Visibility(storepb.Visibility_value[visibility])
|
||||
@ -182,13 +170,3 @@ func (d *DB) DeleteCollection(ctx context.Context, delete *store.DeleteCollectio
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func vacuumCollection(ctx context.Context, tx *sql.Tx) error {
|
||||
stmt := `DELETE FROM collection WHERE creator_id NOT IN (SELECT id FROM user)`
|
||||
_, err := tx.ExecContext(ctx, stmt)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
@ -2,7 +2,6 @@ package postgres
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
@ -17,9 +16,7 @@ func (d *DB) CreateMemo(ctx context.Context, create *storepb.Memo) (*storepb.Mem
|
||||
args := []any{create.CreatorId, create.Name, create.Title, create.Content, create.Visibility.String(), strings.Join(create.Tags, " ")}
|
||||
|
||||
stmt := `
|
||||
INSERT INTO memo (
|
||||
` + strings.Join(set, ", ") + `
|
||||
)
|
||||
INSERT INTO memo (` + strings.Join(set, ", ") + `)
|
||||
VALUES (` + placeholders(len(args)) + `)
|
||||
RETURNING id, created_ts, updated_ts, row_status
|
||||
`
|
||||
@ -41,43 +38,34 @@ func (d *DB) CreateMemo(ctx context.Context, create *storepb.Memo) (*storepb.Mem
|
||||
func (d *DB) UpdateMemo(ctx context.Context, update *store.UpdateMemo) (*storepb.Memo, error) {
|
||||
set, args := []string{}, []any{}
|
||||
if update.RowStatus != nil {
|
||||
set = append(set, fmt.Sprintf("row_status = $%d", len(set)+1))
|
||||
args = append(args, update.RowStatus.String())
|
||||
set, args = append(set, "row_status = "+placeholder(len(args)+1)), append(args, update.RowStatus.String())
|
||||
}
|
||||
if update.Name != nil {
|
||||
set = append(set, fmt.Sprintf("name = $%d", len(set)+1))
|
||||
args = append(args, *update.Name)
|
||||
set, args = append(set, "name = "+placeholder(len(args)+1)), append(args, *update.Name)
|
||||
}
|
||||
if update.Title != nil {
|
||||
set = append(set, fmt.Sprintf("title = $%d", len(set)+1))
|
||||
args = append(args, *update.Title)
|
||||
set, args = append(set, "title = "+placeholder(len(args)+1)), append(args, *update.Title)
|
||||
}
|
||||
if update.Content != nil {
|
||||
set = append(set, fmt.Sprintf("content = $%d", len(set)+1))
|
||||
args = append(args, *update.Content)
|
||||
set, args = append(set, "content = "+placeholder(len(args)+1)), append(args, *update.Content)
|
||||
}
|
||||
if update.Visibility != nil {
|
||||
set = append(set, fmt.Sprintf("visibility = $%d", len(set)+1))
|
||||
args = append(args, update.Visibility.String())
|
||||
set, args = append(set, "visibility = "+placeholder(len(args)+1)), append(args, update.Visibility.String())
|
||||
}
|
||||
if update.Tag != nil {
|
||||
set = append(set, fmt.Sprintf("tag = $%d", len(set)+1))
|
||||
args = append(args, *update.Tag)
|
||||
set, args = append(set, "tag = "+placeholder(len(args)+1)), append(args, *update.Tag)
|
||||
}
|
||||
if len(set) == 0 {
|
||||
return nil, errors.New("no update specified")
|
||||
}
|
||||
args = append(args, update.ID)
|
||||
|
||||
stmt := `
|
||||
UPDATE memo
|
||||
SET
|
||||
` + strings.Join(set, ", ") + `
|
||||
WHERE
|
||||
id = $` + fmt.Sprint(len(set)+1) + `
|
||||
SET ` + strings.Join(set, ", ") + `
|
||||
WHERE id = ` + placeholder(len(args)+1) + `
|
||||
RETURNING id, creator_id, created_ts, updated_ts, row_status, name, title, content, visibility, tag
|
||||
`
|
||||
|
||||
args = append(args, update.ID)
|
||||
memo := &storepb.Memo{}
|
||||
var rowStatus, visibility, tags string
|
||||
if err := d.db.QueryRowContext(ctx, stmt, args...).Scan(
|
||||
@ -103,27 +91,26 @@ func (d *DB) UpdateMemo(ctx context.Context, update *store.UpdateMemo) (*storepb
|
||||
func (d *DB) ListMemos(ctx context.Context, find *store.FindMemo) ([]*storepb.Memo, error) {
|
||||
where, args := []string{"1 = 1"}, []any{}
|
||||
if v := find.ID; v != nil {
|
||||
where, args = append(where, "id = $1"), append(args, *v)
|
||||
where, args = append(where, "id = "+placeholder(len(args)+1)), append(args, *v)
|
||||
}
|
||||
if v := find.CreatorID; v != nil {
|
||||
where, args = append(where, "creator_id = $2"), append(args, *v)
|
||||
where, args = append(where, "creator_id = "+placeholder(len(args)+1)), append(args, *v)
|
||||
}
|
||||
if v := find.RowStatus; v != nil {
|
||||
where, args = append(where, "row_status = $3"), append(args, *v)
|
||||
where, args = append(where, "row_status = "+placeholder(len(args)+1)), append(args, *v)
|
||||
}
|
||||
if v := find.Name; v != nil {
|
||||
where, args = append(where, "name = $4"), append(args, *v)
|
||||
where, args = append(where, "name = "+placeholder(len(args)+1)), append(args, *v)
|
||||
}
|
||||
if v := find.VisibilityList; len(v) != 0 {
|
||||
list := []string{}
|
||||
for i, visibility := range v {
|
||||
list = append(list, fmt.Sprintf("$%d", len(args)+i+1))
|
||||
args = append(args, visibility)
|
||||
for _, visibility := range v {
|
||||
list, args = append(list, placeholder(len(args)+1)), append(args, visibility)
|
||||
}
|
||||
where = append(where, fmt.Sprintf("visibility IN (%s)", strings.Join(list, ",")))
|
||||
}
|
||||
if v := find.Tag; v != nil {
|
||||
where, args = append(where, "tag LIKE $"+fmt.Sprint(len(args)+1)), append(args, "%"+*v+"%")
|
||||
where, args = append(where, "tag LIKE "+placeholder(len(args)+1)), append(args, "%"+*v+"%")
|
||||
}
|
||||
|
||||
rows, err := d.db.QueryContext(ctx, `
|
||||
@ -185,24 +172,10 @@ func (d *DB) DeleteMemo(ctx context.Context, delete *store.DeleteMemo) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func vacuumMemo(ctx context.Context, tx *sql.Tx) error {
|
||||
stmt := `DELETE FROM memo WHERE creator_id NOT IN (SELECT id FROM user)`
|
||||
_, err := tx.ExecContext(ctx, stmt)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func placeholders(n int) string {
|
||||
placeholder := ""
|
||||
list := []string{}
|
||||
for i := 0; i < n; i++ {
|
||||
if i == 0 {
|
||||
placeholder = fmt.Sprintf("$%d", i+1)
|
||||
} else {
|
||||
placeholder = fmt.Sprintf("%s, $%d", placeholder, i+1)
|
||||
list = append(list, fmt.Sprintf("$%d", i+1))
|
||||
}
|
||||
}
|
||||
return placeholder
|
||||
return strings.Join(list, ", ")
|
||||
}
|
||||
|
@ -1,3 +1,13 @@
|
||||
-- drop all tables first (PostgreSQL style)
|
||||
DROP TABLE IF EXISTS migration_history CASCADE;
|
||||
DROP TABLE IF EXISTS workspace_setting CASCADE;
|
||||
DROP TABLE IF EXISTS "user" CASCADE;
|
||||
DROP TABLE IF EXISTS user_setting CASCADE;
|
||||
DROP TABLE IF EXISTS shortcut CASCADE;
|
||||
DROP TABLE IF EXISTS activity CASCADE;
|
||||
DROP TABLE IF EXISTS collection CASCADE;
|
||||
DROP TABLE IF EXISTS memo CASCADE;
|
||||
|
||||
-- migration_history
|
||||
CREATE TABLE migration_history (
|
||||
version TEXT NOT NULL PRIMARY KEY,
|
||||
@ -11,7 +21,7 @@ CREATE TABLE workspace_setting (
|
||||
);
|
||||
|
||||
-- user
|
||||
CREATE TABLE user (
|
||||
CREATE TABLE "user" (
|
||||
id SERIAL PRIMARY KEY,
|
||||
created_ts BIGINT NOT NULL DEFAULT EXTRACT(EPOCH FROM NOW()),
|
||||
updated_ts BIGINT NOT NULL DEFAULT EXTRACT(EPOCH FROM NOW()),
|
||||
@ -22,11 +32,11 @@ CREATE TABLE user (
|
||||
role TEXT NOT NULL CHECK (role IN ('ADMIN', 'USER')) DEFAULT 'USER'
|
||||
);
|
||||
|
||||
CREATE INDEX idx_user_email ON user(email);
|
||||
CREATE INDEX idx_user_email ON "user"(email);
|
||||
|
||||
-- user_setting
|
||||
CREATE TABLE user_setting (
|
||||
user_id INTEGER REFERENCES user(id) NOT NULL,
|
||||
user_id INTEGER REFERENCES "user"(id) NOT NULL,
|
||||
key TEXT NOT NULL,
|
||||
value TEXT NOT NULL,
|
||||
PRIMARY KEY (user_id, key)
|
||||
@ -35,7 +45,7 @@ CREATE TABLE user_setting (
|
||||
-- shortcut
|
||||
CREATE TABLE shortcut (
|
||||
id SERIAL PRIMARY KEY,
|
||||
creator_id INTEGER REFERENCES user(id) NOT NULL,
|
||||
creator_id INTEGER REFERENCES "user"(id) NOT NULL,
|
||||
created_ts BIGINT NOT NULL DEFAULT EXTRACT(EPOCH FROM NOW()),
|
||||
updated_ts BIGINT NOT NULL DEFAULT EXTRACT(EPOCH FROM NOW()),
|
||||
row_status TEXT NOT NULL CHECK (row_status IN ('NORMAL', 'ARCHIVED')) DEFAULT 'NORMAL',
|
||||
@ -53,7 +63,7 @@ CREATE INDEX idx_shortcut_name ON shortcut(name);
|
||||
-- activity
|
||||
CREATE TABLE activity (
|
||||
id SERIAL PRIMARY KEY,
|
||||
creator_id INTEGER REFERENCES user(id) NOT NULL,
|
||||
creator_id INTEGER REFERENCES "user"(id) NOT NULL,
|
||||
created_ts BIGINT NOT NULL DEFAULT EXTRACT(EPOCH FROM NOW()),
|
||||
type TEXT NOT NULL DEFAULT '',
|
||||
level TEXT NOT NULL CHECK (level IN ('INFO', 'WARN', 'ERROR')) DEFAULT 'INFO',
|
||||
@ -63,7 +73,7 @@ CREATE TABLE activity (
|
||||
-- collection
|
||||
CREATE TABLE collection (
|
||||
id SERIAL PRIMARY KEY,
|
||||
creator_id INTEGER REFERENCES user(id) NOT NULL,
|
||||
creator_id INTEGER REFERENCES "user"(id) NOT NULL,
|
||||
created_ts BIGINT NOT NULL DEFAULT EXTRACT(EPOCH FROM NOW()),
|
||||
updated_ts BIGINT NOT NULL DEFAULT EXTRACT(EPOCH FROM NOW()),
|
||||
name TEXT NOT NULL UNIQUE,
|
||||
@ -78,7 +88,7 @@ CREATE INDEX idx_collection_name ON collection(name);
|
||||
-- memo
|
||||
CREATE TABLE memo (
|
||||
id SERIAL PRIMARY KEY,
|
||||
creator_id INTEGER REFERENCES user(id) NOT NULL,
|
||||
creator_id INTEGER REFERENCES "user"(id) NOT NULL,
|
||||
created_ts BIGINT NOT NULL DEFAULT EXTRACT(EPOCH FROM NOW()),
|
||||
updated_ts BIGINT NOT NULL DEFAULT EXTRACT(EPOCH FROM NOW()),
|
||||
row_status TEXT NOT NULL CHECK (row_status IN ('NORMAL', 'ARCHIVED')) DEFAULT 'NORMAL',
|
@ -11,7 +11,7 @@ CREATE TABLE workspace_setting (
|
||||
);
|
||||
|
||||
-- user
|
||||
CREATE TABLE user (
|
||||
CREATE TABLE "user" (
|
||||
id SERIAL PRIMARY KEY,
|
||||
created_ts BIGINT NOT NULL DEFAULT EXTRACT(EPOCH FROM NOW()),
|
||||
updated_ts BIGINT NOT NULL DEFAULT EXTRACT(EPOCH FROM NOW()),
|
||||
@ -22,11 +22,11 @@ CREATE TABLE user (
|
||||
role TEXT NOT NULL CHECK (role IN ('ADMIN', 'USER')) DEFAULT 'USER'
|
||||
);
|
||||
|
||||
CREATE INDEX idx_user_email ON user(email);
|
||||
CREATE INDEX idx_user_email ON "user"(email);
|
||||
|
||||
-- user_setting
|
||||
CREATE TABLE user_setting (
|
||||
user_id INTEGER REFERENCES user(id) NOT NULL,
|
||||
user_id INTEGER REFERENCES "user"(id) NOT NULL,
|
||||
key TEXT NOT NULL,
|
||||
value TEXT NOT NULL,
|
||||
PRIMARY KEY (user_id, key)
|
||||
@ -35,7 +35,7 @@ CREATE TABLE user_setting (
|
||||
-- shortcut
|
||||
CREATE TABLE shortcut (
|
||||
id SERIAL PRIMARY KEY,
|
||||
creator_id INTEGER REFERENCES user(id) NOT NULL,
|
||||
creator_id INTEGER REFERENCES "user"(id) NOT NULL,
|
||||
created_ts BIGINT NOT NULL DEFAULT EXTRACT(EPOCH FROM NOW()),
|
||||
updated_ts BIGINT NOT NULL DEFAULT EXTRACT(EPOCH FROM NOW()),
|
||||
row_status TEXT NOT NULL CHECK (row_status IN ('NORMAL', 'ARCHIVED')) DEFAULT 'NORMAL',
|
||||
@ -53,7 +53,7 @@ CREATE INDEX idx_shortcut_name ON shortcut(name);
|
||||
-- activity
|
||||
CREATE TABLE activity (
|
||||
id SERIAL PRIMARY KEY,
|
||||
creator_id INTEGER REFERENCES user(id) NOT NULL,
|
||||
creator_id INTEGER REFERENCES "user"(id) NOT NULL,
|
||||
created_ts BIGINT NOT NULL DEFAULT EXTRACT(EPOCH FROM NOW()),
|
||||
type TEXT NOT NULL DEFAULT '',
|
||||
level TEXT NOT NULL CHECK (level IN ('INFO', 'WARN', 'ERROR')) DEFAULT 'INFO',
|
||||
@ -63,7 +63,7 @@ CREATE TABLE activity (
|
||||
-- collection
|
||||
CREATE TABLE collection (
|
||||
id SERIAL PRIMARY KEY,
|
||||
creator_id INTEGER REFERENCES user(id) NOT NULL,
|
||||
creator_id INTEGER REFERENCES "user"(id) NOT NULL,
|
||||
created_ts BIGINT NOT NULL DEFAULT EXTRACT(EPOCH FROM NOW()),
|
||||
updated_ts BIGINT NOT NULL DEFAULT EXTRACT(EPOCH FROM NOW()),
|
||||
name TEXT NOT NULL UNIQUE,
|
||||
@ -78,7 +78,7 @@ CREATE INDEX idx_collection_name ON collection(name);
|
||||
-- memo
|
||||
CREATE TABLE memo (
|
||||
id SERIAL PRIMARY KEY,
|
||||
creator_id INTEGER REFERENCES user(id) NOT NULL,
|
||||
creator_id INTEGER REFERENCES "user"(id) NOT NULL,
|
||||
created_ts BIGINT NOT NULL DEFAULT EXTRACT(EPOCH FROM NOW()),
|
||||
updated_ts BIGINT NOT NULL DEFAULT EXTRACT(EPOCH FROM NOW()),
|
||||
row_status TEXT NOT NULL CHECK (row_status IN ('NORMAL', 'ARCHIVED')) DEFAULT 'NORMAL',
|
@ -108,7 +108,7 @@ func (d *DB) Migrate(ctx context.Context) error {
|
||||
}
|
||||
// In demo mode, we should seed the database.
|
||||
if d.profile.Mode == "demo" {
|
||||
if err := d.seed(ctx); err != nil {
|
||||
if err := d.Seed(ctx); err != nil {
|
||||
return errors.Wrap(err, "failed to seed")
|
||||
}
|
||||
}
|
||||
@ -119,7 +119,7 @@ func (d *DB) Migrate(ctx context.Context) error {
|
||||
}
|
||||
|
||||
const (
|
||||
latestSchemaFileName = "LATEST__SCHEMA.sql"
|
||||
latestSchemaFileName = "LATEST.sql"
|
||||
)
|
||||
|
||||
func (d *DB) applyLatestSchema(ctx context.Context) error {
|
||||
@ -172,7 +172,7 @@ func (d *DB) applyMigrationForMinorVersion(ctx context.Context, minorVersion str
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d *DB) seed(ctx context.Context) error {
|
||||
func (d *DB) Seed(ctx context.Context) error {
|
||||
filenames, err := fs.Glob(seedFS, fmt.Sprintf("%s/*.sql", "seed"))
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "failed to read seed files")
|
||||
|
@ -1,9 +0,0 @@
|
||||
DELETE FROM activity;
|
||||
|
||||
DELETE FROM shortcut;
|
||||
|
||||
DELETE FROM user_setting;
|
||||
|
||||
DELETE FROM user;
|
||||
|
||||
DELETE FROM workspace_setting;
|
@ -2,7 +2,6 @@ package postgres
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
@ -207,12 +206,6 @@ func (d *DB) DeleteShortcut(ctx context.Context, delete *store.DeleteShortcut) e
|
||||
return err
|
||||
}
|
||||
|
||||
func vacuumShortcut(ctx context.Context, tx *sql.Tx) error {
|
||||
stmt := `DELETE FROM shortcut WHERE creator_id NOT IN (SELECT id FROM "user")`
|
||||
_, err := tx.ExecContext(ctx, stmt)
|
||||
return err
|
||||
}
|
||||
|
||||
func filterTags(tags []string) []string {
|
||||
result := []string{}
|
||||
for _, tag := range tags {
|
||||
|
@ -41,21 +41,20 @@ func (d *DB) CreateUser(ctx context.Context, create *store.User) (*store.User, e
|
||||
func (d *DB) UpdateUser(ctx context.Context, update *store.UpdateUser) (*store.User, error) {
|
||||
set, args := []string{}, []any{}
|
||||
if v := update.RowStatus; v != nil {
|
||||
set, args = append(set, "row_status = $"+placeholder(len(args)+1)), append(args, *v)
|
||||
set, args = append(set, "row_status = "+placeholder(len(args)+1)), append(args, *v)
|
||||
}
|
||||
if v := update.Email; v != nil {
|
||||
set, args = append(set, "email = $"+placeholder(len(args)+1)), append(args, *v)
|
||||
set, args = append(set, "email = "+placeholder(len(args)+1)), append(args, *v)
|
||||
}
|
||||
if v := update.Nickname; v != nil {
|
||||
set, args = append(set, "nickname = $"+placeholder(len(args)+1)), append(args, *v)
|
||||
set, args = append(set, "nickname = "+placeholder(len(args)+1)), append(args, *v)
|
||||
}
|
||||
if v := update.PasswordHash; v != nil {
|
||||
set, args = append(set, "password_hash = $"+placeholder(len(args)+1)), append(args, *v)
|
||||
set, args = append(set, "password_hash = "+placeholder(len(args)+1)), append(args, *v)
|
||||
}
|
||||
if v := update.Role; v != nil {
|
||||
set, args = append(set, "role = $"+placeholder(len(args)+1)), append(args, *v)
|
||||
set, args = append(set, "role = "+placeholder(len(args)+1)), append(args, *v)
|
||||
}
|
||||
|
||||
if len(set) == 0 {
|
||||
return nil, errors.New("no fields to update")
|
||||
}
|
||||
@ -63,7 +62,7 @@ func (d *DB) UpdateUser(ctx context.Context, update *store.UpdateUser) (*store.U
|
||||
stmt := `
|
||||
UPDATE "user"
|
||||
SET ` + strings.Join(set, ", ") + `
|
||||
WHERE id = $` + placeholder(len(args)+1) + `
|
||||
WHERE id = ` + placeholder(len(args)+1) + `
|
||||
RETURNING id, created_ts, updated_ts, row_status, email, nickname, password_hash, role
|
||||
`
|
||||
args = append(args, update.ID)
|
||||
@ -88,19 +87,19 @@ func (d *DB) ListUsers(ctx context.Context, find *store.FindUser) ([]*store.User
|
||||
where, args := []string{"1 = 1"}, []any{}
|
||||
|
||||
if v := find.ID; v != nil {
|
||||
where, args = append(where, "id = $"+placeholder(len(args)+1)), append(args, *v)
|
||||
where, args = append(where, "id = "+placeholder(len(args)+1)), append(args, *v)
|
||||
}
|
||||
if v := find.RowStatus; v != nil {
|
||||
where, args = append(where, "row_status = $"+placeholder(len(args)+1)), append(args, v.String())
|
||||
where, args = append(where, "row_status = "+placeholder(len(args)+1)), append(args, v.String())
|
||||
}
|
||||
if v := find.Email; v != nil {
|
||||
where, args = append(where, "email = $"+placeholder(len(args)+1)), append(args, *v)
|
||||
where, args = append(where, "email = "+placeholder(len(args)+1)), append(args, *v)
|
||||
}
|
||||
if v := find.Nickname; v != nil {
|
||||
where, args = append(where, "nickname = $"+placeholder(len(args)+1)), append(args, *v)
|
||||
where, args = append(where, "nickname = "+placeholder(len(args)+1)), append(args, *v)
|
||||
}
|
||||
if v := find.Role; v != nil {
|
||||
where, args = append(where, "role = $"+placeholder(len(args)+1)), append(args, *v)
|
||||
where, args = append(where, "role = "+placeholder(len(args)+1)), append(args, *v)
|
||||
}
|
||||
|
||||
query := `
|
||||
@ -149,32 +148,10 @@ func (d *DB) ListUsers(ctx context.Context, find *store.FindUser) ([]*store.User
|
||||
}
|
||||
|
||||
func (d *DB) DeleteUser(ctx context.Context, delete *store.DeleteUser) error {
|
||||
tx, err := d.db.BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
if _, err := d.db.ExecContext(ctx, `DELETE FROM "user" WHERE id = $1`, delete.ID); err != nil {
|
||||
return err
|
||||
}
|
||||
defer tx.Rollback()
|
||||
|
||||
if _, err := tx.ExecContext(ctx, `
|
||||
DELETE FROM "user" WHERE id = $1
|
||||
`, delete.ID); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := vacuumUserSetting(ctx, tx); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := vacuumShortcut(ctx, tx); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := vacuumMemo(ctx, tx); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := vacuumCollection(ctx, tx); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return tx.Commit()
|
||||
return nil
|
||||
}
|
||||
|
||||
func placeholder(n int) string {
|
||||
|
@ -2,9 +2,7 @@ package postgres
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"google.golang.org/protobuf/encoding/protojson"
|
||||
@ -51,10 +49,10 @@ func (d *DB) ListUserSettings(ctx context.Context, find *store.FindUserSetting)
|
||||
where, args := []string{"1 = 1"}, []any{}
|
||||
|
||||
if v := find.Key; v != storepb.UserSettingKey_USER_SETTING_KEY_UNSPECIFIED {
|
||||
where, args = append(where, fmt.Sprintf("key = $%d", len(args)+1)), append(args, v.String())
|
||||
where, args = append(where, "key = "+placeholder(len(args)+1)), append(args, v.String())
|
||||
}
|
||||
if v := find.UserID; v != nil {
|
||||
where, args = append(where, fmt.Sprintf("user_id = $%d", len(args)+1)), append(args, *find.UserID)
|
||||
where, args = append(where, "user_id = "+placeholder(len(args)+1)), append(args, *find.UserID)
|
||||
}
|
||||
|
||||
query := `
|
||||
@ -110,13 +108,3 @@ func (d *DB) ListUserSettings(ctx context.Context, find *store.FindUserSetting)
|
||||
|
||||
return userSettingList, nil
|
||||
}
|
||||
|
||||
func vacuumUserSetting(ctx context.Context, tx *sql.Tx) error {
|
||||
stmt := `DELETE FROM user_setting WHERE user_id NOT IN (SELECT id FROM "user")`
|
||||
_, err := tx.ExecContext(ctx, stmt)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
@ -55,7 +55,7 @@ func (d *DB) ListWorkspaceSettings(ctx context.Context, find *store.FindWorkspac
|
||||
where, args := []string{"1 = 1"}, []interface{}{}
|
||||
|
||||
if find.Key != storepb.WorkspaceSettingKey_WORKSPACE_SETTING_KEY_UNSPECIFIED {
|
||||
where, args = append(where, "key = $"+placeholder(len(args)+1)), append(args, find.Key.String())
|
||||
where, args = append(where, "key = "+placeholder(len(args)+1)), append(args, find.Key.String())
|
||||
}
|
||||
|
||||
query := `
|
||||
|
@ -108,7 +108,7 @@ func (d *DB) Migrate(ctx context.Context) error {
|
||||
}
|
||||
// In demo mode, we should seed the database.
|
||||
if d.profile.Mode == "demo" {
|
||||
if err := d.seed(ctx); err != nil {
|
||||
if err := d.Seed(ctx); err != nil {
|
||||
return errors.Wrap(err, "failed to seed")
|
||||
}
|
||||
}
|
||||
@ -119,7 +119,7 @@ func (d *DB) Migrate(ctx context.Context) error {
|
||||
}
|
||||
|
||||
const (
|
||||
latestSchemaFileName = "LATEST__SCHEMA.sql"
|
||||
latestSchemaFileName = "LATEST.sql"
|
||||
)
|
||||
|
||||
func (d *DB) applyLatestSchema(ctx context.Context) error {
|
||||
@ -172,7 +172,7 @@ func (d *DB) applyMigrationForMinorVersion(ctx context.Context, minorVersion str
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d *DB) seed(ctx context.Context) error {
|
||||
func (d *DB) Seed(ctx context.Context) error {
|
||||
filenames, err := fs.Glob(seedFS, fmt.Sprintf("%s/*.sql", "seed"))
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "failed to read seed files")
|
||||
|
@ -14,6 +14,7 @@ type Driver interface {
|
||||
Close() error
|
||||
|
||||
Migrate(ctx context.Context) error
|
||||
Seed(ctx context.Context) error
|
||||
|
||||
// MigrationHistory model related methods.
|
||||
UpsertMigrationHistory(ctx context.Context, upsert *UpsertMigrationHistory) (*MigrationHistory, error)
|
||||
|
@ -12,11 +12,13 @@ import (
|
||||
func TestActivityStore(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
ts := NewTestingStore(ctx, t)
|
||||
user, err := createTestingAdminUser(ctx, ts)
|
||||
require.NoError(t, err)
|
||||
list, err := ts.ListActivities(ctx, &store.FindActivity{})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 0, len(list))
|
||||
activity, err := ts.CreateActivity(ctx, &store.Activity{
|
||||
CreatorID: -1,
|
||||
CreatorID: user.ID,
|
||||
Type: store.ActivityShortcutCreate,
|
||||
Level: store.ActivityInfo,
|
||||
Payload: "",
|
||||
|
@ -22,6 +22,9 @@ func NewTestingStore(ctx context.Context, t *testing.T) *store.Store {
|
||||
if err := dbDriver.Migrate(ctx); err != nil {
|
||||
fmt.Printf("failed to migrate db, error: %+v\n", err)
|
||||
}
|
||||
if err := dbDriver.Seed(ctx); err != nil {
|
||||
fmt.Printf("failed to seed db, error: %+v\n", err)
|
||||
}
|
||||
|
||||
store := store.New(dbDriver, profile)
|
||||
return store
|
||||
|
@ -7,7 +7,6 @@ import (
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
|
||||
storepb "github.com/yourselfhosted/slash/proto/gen/store"
|
||||
"github.com/yourselfhosted/slash/store"
|
||||
)
|
||||
|
||||
@ -27,13 +26,6 @@ func TestUserStore(t *testing.T) {
|
||||
Nickname: &userPatchNickname,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
_, err = ts.CreateShortcut(ctx, &storepb.Shortcut{
|
||||
CreatorId: user.ID,
|
||||
Name: "test_shortcut",
|
||||
Link: "https://www.google.com",
|
||||
Visibility: storepb.Visibility_PUBLIC,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, userPatchNickname, user.Nickname)
|
||||
err = ts.DeleteUser(ctx, &store.DeleteUser{
|
||||
ID: user.ID,
|
||||
@ -42,9 +34,6 @@ func TestUserStore(t *testing.T) {
|
||||
users, err = ts.ListUsers(ctx, &store.FindUser{})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 0, len(users))
|
||||
shortcuts, err := ts.ListShortcuts(ctx, &store.FindShortcut{})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, 0, len(shortcuts))
|
||||
}
|
||||
|
||||
// createTestingAdminUser creates a testing admin user.
|
||||
|
Loading…
x
Reference in New Issue
Block a user