diff --git a/bin/slash/main.go b/bin/slash/main.go index 3cb41f5..2bfe8a9 100644 --- a/bin/slash/main.go +++ b/bin/slash/main.go @@ -47,13 +47,13 @@ var ( slog.Error("failed to create db driver", "error", err) return } - if err := dbDriver.Migrate(ctx); err != nil { + + storeInstance := store.New(dbDriver, serverProfile) + if err := storeInstance.Migrate(ctx); err != nil { cancel() slog.Error("failed to migrate db", "error", err) return } - - storeInstance := store.New(dbDriver, serverProfile) if err := storeInstance.MigrateWorkspaceSettings(ctx); err != nil { cancel() slog.Error("failed to migrate workspace settings", "error", err) diff --git a/store/db/postgres/migrator.go b/store/db/postgres/migrator.go deleted file mode 100644 index 5ff429e..0000000 --- a/store/db/postgres/migrator.go +++ /dev/null @@ -1,171 +0,0 @@ -package postgres - -import ( - "context" - "embed" - "fmt" - "io/fs" - "regexp" - "sort" - "strings" - - "github.com/pkg/errors" - - "github.com/yourselfhosted/slash/server/common" - "github.com/yourselfhosted/slash/store" -) - -const ( - latestSchemaFileName = "LATEST.sql" -) - -//go:embed migration -var migrationFS embed.FS - -func (d *DB) Migrate(ctx context.Context) error { - if d.profile.IsDev() { - return d.nonProdMigrate(ctx) - } - - return d.prodMigrate(ctx) -} - -func (d *DB) nonProdMigrate(ctx context.Context) error { - rows, err := d.db.QueryContext(ctx, "SELECT tablename FROM pg_catalog.pg_tables WHERE schemaname != 'pg_catalog' AND schemaname != 'information_schema';") - if err != nil { - return errors.Errorf("failed to query database tables: %s", err) - } - if rows.Err() != nil { - return errors.Errorf("failed to query database tables: %s", err) - } - defer rows.Close() - - var tables []string - for rows.Next() { - var table string - err := rows.Scan(&table) - if err != nil { - return errors.Errorf("failed to scan table name: %s", err) - } - tables = append(tables, table) - } - - if len(tables) != 0 { - return nil - } - - buf, err := migrationFS.ReadFile("migration/dev/" + latestSchemaFileName) - if err != nil { - return errors.Errorf("failed to read latest schema file: %s", err) - } - - stmt := string(buf) - if _, err := d.db.ExecContext(ctx, stmt); err != nil { - return errors.Errorf("failed to exec SQL %s: %s", stmt, err) - } - - return nil -} - -func (d *DB) prodMigrate(ctx context.Context) error { - currentVersion := common.GetCurrentVersion(d.profile.Mode) - migrationHistoryList, err := d.ListMigrationHistories(ctx, &store.FindMigrationHistory{}) - // If there is no migration history, we should apply the latest schema. - if err != nil || len(migrationHistoryList) == 0 { - latestSchemaBytes, err := migrationFS.ReadFile("migration/prod/" + latestSchemaFileName) - if err != nil { - return errors.Errorf("failed to read latest schema file: %s", err) - } - - latestSchema := string(latestSchemaBytes) - if _, err := d.db.ExecContext(ctx, latestSchema); err != nil { - return errors.Errorf("failed to exec SQL %s: %s", latestSchema, err) - } - // After applying the latest schema, we should insert the latest version to migration_history. - if _, err := d.UpsertMigrationHistory(ctx, &store.UpsertMigrationHistory{ - Version: currentVersion, - }); err != nil { - return errors.Wrap(err, "failed to upsert migration history") - } - return nil - } - - migrationHistoryVersionList := []string{} - for _, migrationHistory := range migrationHistoryList { - migrationHistoryVersionList = append(migrationHistoryVersionList, migrationHistory.Version) - } - sort.Sort(common.SortVersion(migrationHistoryVersionList)) - latestMigrationHistoryVersion := migrationHistoryVersionList[len(migrationHistoryVersionList)-1] - // If the latest migration history version is greater than or equal to the current version, we will not apply any migration. - if !common.IsVersionGreaterThan(common.GetSchemaVersion(currentVersion), latestMigrationHistoryVersion) { - return nil - } - - println("start migrate") - for _, minorVersion := range getMinorVersionList() { - normalizedVersion := minorVersion + ".0" - if common.IsVersionGreaterThan(normalizedVersion, latestMigrationHistoryVersion) && common.IsVersionGreaterOrEqualThan(currentVersion, normalizedVersion) { - println("applying migration for", normalizedVersion) - if err := d.applyMigrationForMinorVersion(ctx, minorVersion); err != nil { - return errors.Wrap(err, "failed to apply minor version migration") - } - } - } - println("end migrate") - return nil -} - -func (d *DB) applyMigrationForMinorVersion(ctx context.Context, minorVersion string) error { - filenames, err := fs.Glob(migrationFS, fmt.Sprintf("migration/prod/%s/*.sql", minorVersion)) - if err != nil { - return errors.Wrap(err, "failed to read ddl files") - } - - sort.Strings(filenames) - // Loop over all migration files and execute them in order. - for _, filename := range filenames { - buf, err := migrationFS.ReadFile(filename) - if err != nil { - return errors.Wrapf(err, "failed to read minor version migration file, filename=%s", filename) - } - for _, stmt := range strings.Split(string(buf), ";") { - if strings.TrimSpace(stmt) == "" { - continue - } - if _, err := d.db.ExecContext(ctx, stmt); err != nil { - return errors.Wrapf(err, "migrate error: %s", stmt) - } - } - } - - // Upsert the newest version to migration_history. - version := minorVersion + ".0" - if _, err = d.UpsertMigrationHistory(ctx, &store.UpsertMigrationHistory{Version: version}); err != nil { - return errors.Wrapf(err, "failed to upsert migration history with version: %s", version) - } - - return nil -} - -// minorDirRegexp is a regular expression for minor version directory. -var minorDirRegexp = regexp.MustCompile(`^migration/prod/[0-9]+\.[0-9]+$`) - -func getMinorVersionList() []string { - minorVersionList := []string{} - - if err := fs.WalkDir(migrationFS, "migration", func(path string, file fs.DirEntry, err error) error { - if err != nil { - return err - } - if file.IsDir() && minorDirRegexp.MatchString(path) { - minorVersionList = append(minorVersionList, file.Name()) - } - - return nil - }); err != nil { - panic(err) - } - - sort.Sort(common.SortVersion(minorVersionList)) - return minorVersionList -} diff --git a/store/db/sqlite/migrator.go b/store/db/sqlite/migrator.go deleted file mode 100644 index 893a446..0000000 --- a/store/db/sqlite/migrator.go +++ /dev/null @@ -1,233 +0,0 @@ -package sqlite - -import ( - "context" - "embed" - "fmt" - "io/fs" - "os" - "regexp" - "sort" - "time" - - "github.com/pkg/errors" - - "github.com/yourselfhosted/slash/server/common" - "github.com/yourselfhosted/slash/store" -) - -//go:embed migration -var migrationFS embed.FS - -//go:embed seed -var seedFS embed.FS - -// Migrate applies the latest schema to the database. -func (d *DB) Migrate(ctx context.Context) error { - currentVersion := common.GetCurrentVersion(d.profile.Mode) - if d.profile.Mode == "prod" { - _, err := os.Stat(d.profile.DSN) - if err != nil { - // If db file not exists, we should create a new one with latest schema. - if errors.Is(err, os.ErrNotExist) { - if err := d.applyLatestSchema(ctx); err != nil { - return errors.Wrap(err, "failed to apply latest schema") - } - // Upsert the newest version to migration_history. - if _, err := d.UpsertMigrationHistory(ctx, &store.UpsertMigrationHistory{ - Version: currentVersion, - }); err != nil { - return errors.Wrap(err, "failed to upsert migration history") - } - } else { - return errors.Wrap(err, "failed to get db file stat") - } - } else { - // If db file exists, we should check if we need to migrate the database. - migrationHistoryList, err := d.ListMigrationHistories(ctx, &store.FindMigrationHistory{}) - if err != nil { - return errors.Wrap(err, "failed to find migration history") - } - // If no migration history, we should apply the latest version migration and upsert the migration history. - if len(migrationHistoryList) == 0 { - minorVersion := common.GetMinorVersion(currentVersion) - if err := d.applyMigrationForMinorVersion(ctx, minorVersion); err != nil { - return errors.Wrapf(err, "failed to apply version %s migration", minorVersion) - } - _, err := d.UpsertMigrationHistory(ctx, &store.UpsertMigrationHistory{ - Version: currentVersion, - }) - if err != nil { - return errors.Wrap(err, "failed to upsert migration history") - } - return nil - } - - migrationHistoryVersionList := []string{} - for _, migrationHistory := range migrationHistoryList { - migrationHistoryVersionList = append(migrationHistoryVersionList, migrationHistory.Version) - } - sort.Sort(common.SortVersion(migrationHistoryVersionList)) - latestMigrationHistoryVersion := migrationHistoryVersionList[len(migrationHistoryVersionList)-1] - - if common.IsVersionGreaterThan(common.GetSchemaVersion(currentVersion), latestMigrationHistoryVersion) { - minorVersionList := getMinorVersionList() - // backup the raw database file before migration - rawBytes, err := os.ReadFile(d.profile.DSN) - if err != nil { - return errors.Wrap(err, "failed to read raw database file") - } - backupDBFilePath := fmt.Sprintf("%s/slash_%s_%d_backup.db", d.profile.Data, d.profile.Version, time.Now().Unix()) - if err := os.WriteFile(backupDBFilePath, rawBytes, 0644); err != nil { - return errors.Wrap(err, "failed to write raw database file") - } - println("succeed to copy a backup database file") - println("start migrate") - for _, minorVersion := range minorVersionList { - normalizedVersion := minorVersion + ".0" - if common.IsVersionGreaterThan(normalizedVersion, latestMigrationHistoryVersion) && common.IsVersionGreaterOrEqualThan(currentVersion, normalizedVersion) { - println("applying migration for", normalizedVersion) - if err := d.applyMigrationForMinorVersion(ctx, minorVersion); err != nil { - return errors.Wrap(err, "failed to apply minor version migration") - } - } - } - println("end migrate") - - // remove the created backup db file after migrate succeed - if err := os.Remove(backupDBFilePath); err != nil { - println(fmt.Sprintf("Failed to remove temp database file, err %v", err)) - } - } - } - } else { - // In non-prod mode, we should always migrate the database. - if _, err := os.Stat(d.profile.DSN); errors.Is(err, os.ErrNotExist) { - if err := d.applyLatestSchema(ctx); err != nil { - return errors.Wrap(err, "failed to apply latest schema") - } - // In demo mode, we should seed the database. - if d.profile.Mode == "demo" { - if err := d.Seed(ctx); err != nil { - return errors.Wrap(err, "failed to seed") - } - } - } - } - - return nil -} - -const ( - latestSchemaFileName = "LATEST.sql" -) - -func (d *DB) applyLatestSchema(ctx context.Context) error { - schemaMode := "dev" - if d.profile.Mode == "prod" { - schemaMode = "prod" - } - latestSchemaPath := fmt.Sprintf("migration/%s/%s", schemaMode, latestSchemaFileName) - buf, err := migrationFS.ReadFile(latestSchemaPath) - if err != nil { - return errors.Wrapf(err, "failed to read latest schema %q", latestSchemaPath) - } - stmt := string(buf) - if err := d.execute(ctx, stmt); err != nil { - return errors.Wrapf(err, "migrate error: %s", stmt) - } - return nil -} - -func (d *DB) applyMigrationForMinorVersion(ctx context.Context, minorVersion string) error { - filenames, err := fs.Glob(migrationFS, fmt.Sprintf("%s/%s/*.sql", "migration/prod", minorVersion)) - if err != nil { - return errors.Wrap(err, "failed to read ddl files") - } - - sort.Strings(filenames) - migrationStmt := "" - - // Loop over all migration files and execute them in order. - for _, filename := range filenames { - buf, err := migrationFS.ReadFile(filename) - if err != nil { - return errors.Wrapf(err, "failed to read minor version migration file, filename=%s", filename) - } - stmt := string(buf) - migrationStmt += stmt - if err := d.execute(ctx, stmt); err != nil { - return errors.Wrapf(err, "migrate error: %s", stmt) - } - } - - // Upsert the newest version to migration_history. - version := minorVersion + ".0" - if _, err = d.UpsertMigrationHistory(ctx, &store.UpsertMigrationHistory{ - Version: version, - }); err != nil { - return errors.Wrapf(err, "failed to upsert migration history with version: %s", version) - } - - return nil -} - -func (d *DB) Seed(ctx context.Context) error { - filenames, err := fs.Glob(seedFS, fmt.Sprintf("%s/*.sql", "seed")) - if err != nil { - return errors.Wrap(err, "failed to read seed files") - } - - sort.Strings(filenames) - - // Loop over all seed files and execute them in order. - for _, filename := range filenames { - buf, err := seedFS.ReadFile(filename) - if err != nil { - return errors.Wrapf(err, "failed to read seed file, filename=%s", filename) - } - stmt := string(buf) - if err := d.execute(ctx, stmt); err != nil { - return errors.Wrapf(err, "seed error: %s", stmt) - } - } - return nil -} - -// execute runs a single SQL statement within a transaction. -func (d *DB) execute(ctx context.Context, stmt string) error { - tx, err := d.db.Begin() - if err != nil { - return err - } - defer tx.Rollback() - - if _, err := tx.ExecContext(ctx, stmt); err != nil { - return errors.Wrap(err, "failed to execute statement") - } - - return tx.Commit() -} - -// minorDirRegexp is a regular expression for minor version directory. -var minorDirRegexp = regexp.MustCompile(`^migration/prod/[0-9]+\.[0-9]+$`) - -func getMinorVersionList() []string { - minorVersionList := []string{} - - if err := fs.WalkDir(migrationFS, "migration", func(path string, file fs.DirEntry, err error) error { - if err != nil { - return err - } - if file.IsDir() && minorDirRegexp.MatchString(path) { - minorVersionList = append(minorVersionList, file.Name()) - } - - return nil - }); err != nil { - panic(err) - } - - sort.Sort(common.SortVersion(minorVersionList)) - return minorVersionList -} diff --git a/store/driver.go b/store/driver.go index 97ea360..ebaeae5 100644 --- a/store/driver.go +++ b/store/driver.go @@ -13,8 +13,6 @@ type Driver interface { GetDB() *sql.DB Close() error - Migrate(ctx context.Context) error - // MigrationHistory model related methods. UpsertMigrationHistory(ctx context.Context, upsert *UpsertMigrationHistory) (*MigrationHistory, error) ListMigrationHistories(ctx context.Context, find *FindMigrationHistory) ([]*MigrationHistory, error) diff --git a/store/db/postgres/migration/dev/LATEST.sql b/store/migration/postgres/dev/LATEST_SCHEMA.sql similarity index 100% rename from store/db/postgres/migration/dev/LATEST.sql rename to store/migration/postgres/dev/LATEST_SCHEMA.sql diff --git a/store/db/postgres/migration/prod/LATEST.sql b/store/migration/postgres/prod/LATEST_SCHEMA.sql similarity index 100% rename from store/db/postgres/migration/prod/LATEST.sql rename to store/migration/postgres/prod/LATEST_SCHEMA.sql diff --git a/store/db/sqlite/migration/dev/LATEST.sql b/store/migration/sqlite/dev/LATEST_SCHEMA.sql similarity index 100% rename from store/db/sqlite/migration/dev/LATEST.sql rename to store/migration/sqlite/dev/LATEST_SCHEMA.sql diff --git a/store/db/sqlite/migration/prod/0.2/00__create_index.sql b/store/migration/sqlite/prod/0.2/00__create_index.sql similarity index 100% rename from store/db/sqlite/migration/prod/0.2/00__create_index.sql rename to store/migration/sqlite/prod/0.2/00__create_index.sql diff --git a/store/db/sqlite/migration/prod/0.3/00__add_og_metadata.sql b/store/migration/sqlite/prod/0.3/00__add_og_metadata.sql similarity index 100% rename from store/db/sqlite/migration/prod/0.3/00__add_og_metadata.sql rename to store/migration/sqlite/prod/0.3/00__add_og_metadata.sql diff --git a/store/db/sqlite/migration/prod/0.4/00__add_shortcut_title.sql b/store/migration/sqlite/prod/0.4/00__add_shortcut_title.sql similarity index 100% rename from store/db/sqlite/migration/prod/0.4/00__add_shortcut_title.sql rename to store/migration/sqlite/prod/0.4/00__add_shortcut_title.sql diff --git a/store/db/sqlite/migration/prod/0.5/00__drop_idp.sql b/store/migration/sqlite/prod/0.5/00__drop_idp.sql similarity index 100% rename from store/db/sqlite/migration/prod/0.5/00__drop_idp.sql rename to store/migration/sqlite/prod/0.5/00__drop_idp.sql diff --git a/store/db/sqlite/migration/prod/0.5/01__collection.sql b/store/migration/sqlite/prod/0.5/01__collection.sql similarity index 100% rename from store/db/sqlite/migration/prod/0.5/01__collection.sql rename to store/migration/sqlite/prod/0.5/01__collection.sql diff --git a/store/db/sqlite/migration/prod/LATEST.sql b/store/migration/sqlite/prod/LATEST_SCHEMA.sql similarity index 100% rename from store/db/sqlite/migration/prod/LATEST.sql rename to store/migration/sqlite/prod/LATEST_SCHEMA.sql diff --git a/store/migrator.go b/store/migrator.go index f7be6a5..b5698ae 100644 --- a/store/migrator.go +++ b/store/migrator.go @@ -3,9 +3,294 @@ package store import ( "context" + "database/sql" + "embed" + "fmt" + "io/fs" + "log/slog" + "path/filepath" + "sort" + "strconv" + "strings" + + "github.com/pkg/errors" storepb "github.com/yourselfhosted/slash/proto/gen/store" + "github.com/yourselfhosted/slash/server/common" ) +//go:embed migration +var migrationFS embed.FS + +//go:embed seed +var seedFS embed.FS + +const ( + // MigrateFileNameSplit is the split character between the patch version and the description in the migration file name. + // For example, "1__create_table.sql". + MigrateFileNameSplit = "__" + // LatestSchemaFileName is the name of the latest schema file. + // This file is used to apply the latest schema when no migration history is found. + LatestSchemaFileName = "LATEST_SCHEMA.sql" +) + +// Migrate applies the latest schema to the database. +func (s *Store) Migrate(ctx context.Context) error { + if err := s.preMigrate(ctx); err != nil { + return errors.Wrap(err, "failed to pre-migrate") + } + + if s.profile.Mode == "prod" { + migrationHistoryList, err := s.driver.ListMigrationHistories(ctx, &FindMigrationHistory{}) + if err != nil { + return errors.Wrap(err, "failed to find migration history") + } + if len(migrationHistoryList) == 0 { + return errors.Errorf("no migration history found") + } + + migrationHistoryVersions := []string{} + for _, migrationHistory := range migrationHistoryList { + migrationHistoryVersions = append(migrationHistoryVersions, migrationHistory.Version) + } + sort.Sort(common.SortVersion(migrationHistoryVersions)) + latestMigrationHistoryVersion := migrationHistoryVersions[len(migrationHistoryVersions)-1] + schemaVersion, err := s.GetCurrentSchemaVersion() + if err != nil { + return errors.Wrap(err, "failed to get current schema version") + } + + if common.IsVersionGreaterThan(schemaVersion, latestMigrationHistoryVersion) { + filePaths, err := fs.Glob(migrationFS, fmt.Sprintf("%s/*/*.sql", s.getMigrationBasePath())) + if err != nil { + return errors.Wrap(err, "failed to read migration files") + } + sort.Strings(filePaths) + + // Start a transaction to apply the latest schema. + tx, err := s.driver.GetDB().Begin() + if err != nil { + return errors.Wrap(err, "failed to start transaction") + } + defer tx.Rollback() + + slog.Info("start migration", slog.String("currentSchemaVersion", latestMigrationHistoryVersion), slog.String("targetSchemaVersion", schemaVersion)) + for _, filePath := range filePaths { + fileSchemaVersion, err := s.getSchemaVersionOfMigrateScript(filePath) + if err != nil { + return errors.Wrap(err, "failed to get schema version of migrate script") + } + if common.IsVersionGreaterThan(fileSchemaVersion, latestMigrationHistoryVersion) && common.IsVersionGreaterOrEqualThan(schemaVersion, fileSchemaVersion) { + bytes, err := migrationFS.ReadFile(filePath) + if err != nil { + return errors.Wrapf(err, "failed to read minor version migration file: %s", filePath) + } + stmt := string(bytes) + if err := s.execute(ctx, tx, stmt); err != nil { + return errors.Wrapf(err, "migrate error: %s", stmt) + } + } + } + + if err := tx.Commit(); err != nil { + return errors.Wrap(err, "failed to commit transaction") + } + slog.Info("end migrate") + + // Upsert the current schema version to migration_history. + if _, err = s.driver.UpsertMigrationHistory(ctx, &UpsertMigrationHistory{ + Version: schemaVersion, + }); err != nil { + return errors.Wrapf(err, "failed to upsert migration history with version: %s", schemaVersion) + } + } + } else if s.profile.Mode == "demo" { + // In demo mode, we should seed the database. + if err := s.seed(ctx); err != nil { + return errors.Wrap(err, "failed to seed") + } + } + return nil +} + +func (s *Store) preMigrate(ctx context.Context) error { + migrationHistoryList, err := s.driver.ListMigrationHistories(ctx, &FindMigrationHistory{}) + // If any error occurs or no migration history found, apply the latest schema. + if err != nil || len(migrationHistoryList) == 0 { + if err != nil { + slog.Warn("failed to find migration history in pre-migrate", slog.String("error", err.Error())) + } + filePath := s.getMigrationBasePath() + LatestSchemaFileName + bytes, err := migrationFS.ReadFile(filePath) + if err != nil { + return errors.Errorf("failed to read latest schema file: %s", err) + } + schemaVersion, err := s.GetCurrentSchemaVersion() + if err != nil { + return errors.Wrap(err, "failed to get current schema version") + } + + // Start a transaction to apply the latest schema. + tx, err := s.driver.GetDB().Begin() + if err != nil { + return errors.Wrap(err, "failed to start transaction") + } + defer tx.Rollback() + if err := s.execute(ctx, tx, string(bytes)); err != nil { + return errors.Errorf("failed to execute SQL file %s, err %s", filePath, err) + } + if err := tx.Commit(); err != nil { + return errors.Wrap(err, "failed to commit transaction") + } + + if _, err := s.driver.UpsertMigrationHistory(ctx, &UpsertMigrationHistory{ + Version: schemaVersion, + }); err != nil { + return errors.Wrap(err, "failed to upsert migration history") + } + } + if s.profile.Mode == "prod" { + if err := s.normalizedMigrationHistoryList(ctx); err != nil { + return errors.Wrap(err, "failed to normalize migration history list") + } + } + return nil +} + +func (s *Store) getMigrationBasePath() string { + mode := "dev" + if s.profile.Mode == "prod" { + mode = "prod" + } + return fmt.Sprintf("migration/%s/%s/", s.profile.Driver, mode) +} + +func (s *Store) getSeedBasePath() string { + return fmt.Sprintf("seed/%s/", s.profile.Driver) +} + +func (s *Store) seed(ctx context.Context) error { + // Only seed for SQLite. + if s.profile.Driver != "sqlite" { + slog.Warn("seed is only supported for SQLite") + return nil + } + + filenames, err := fs.Glob(seedFS, fmt.Sprintf("%s*.sql", s.getSeedBasePath())) + if err != nil { + return errors.Wrap(err, "failed to read seed files") + } + + // Sort seed files by name. This is important to ensure that seed files are applied in order. + sort.Strings(filenames) + // Start a transaction to apply the seed files. + tx, err := s.driver.GetDB().Begin() + if err != nil { + return errors.Wrap(err, "failed to start transaction") + } + defer tx.Rollback() + // Loop over all seed files and execute them in order. + for _, filename := range filenames { + bytes, err := seedFS.ReadFile(filename) + if err != nil { + return errors.Wrapf(err, "failed to read seed file, filename=%s", filename) + } + if err := s.execute(ctx, tx, string(bytes)); err != nil { + return errors.Wrapf(err, "seed error: %s", filename) + } + } + return tx.Commit() +} + +func (s *Store) GetCurrentSchemaVersion() (string, error) { + currentVersion := common.GetCurrentVersion(s.profile.Mode) + minorVersion := common.GetMinorVersion(currentVersion) + filePaths, err := fs.Glob(migrationFS, fmt.Sprintf("%s%s/*.sql", s.getMigrationBasePath(), minorVersion)) + if err != nil { + return "", errors.Wrap(err, "failed to read migration files") + } + + sort.Strings(filePaths) + if len(filePaths) == 0 { + return fmt.Sprintf("%s.0", minorVersion), nil + } + return s.getSchemaVersionOfMigrateScript(filePaths[len(filePaths)-1]) +} + +func (s *Store) getSchemaVersionOfMigrateScript(filePath string) (string, error) { + // If the file is the latest schema file, return the current schema common. + if strings.HasSuffix(filePath, LatestSchemaFileName) { + return s.GetCurrentSchemaVersion() + } + + normalizedPath := filepath.ToSlash(filePath) + elements := strings.Split(normalizedPath, "/") + if len(elements) < 2 { + return "", errors.Errorf("invalid file path: %s", filePath) + } + minorVersion := elements[len(elements)-2] + rawPatchVersion := strings.Split(elements[len(elements)-1], MigrateFileNameSplit)[0] + patchVersion, err := strconv.Atoi(rawPatchVersion) + if err != nil { + return "", errors.Wrapf(err, "failed to convert patch version to int: %s", rawPatchVersion) + } + return fmt.Sprintf("%s.%d", minorVersion, patchVersion+1), nil +} + +// execute runs a single SQL statement within a transaction. +func (*Store) execute(ctx context.Context, tx *sql.Tx, stmt string) error { + if _, err := tx.ExecContext(ctx, stmt); err != nil { + return errors.Wrap(err, "failed to execute statement") + } + return nil +} + +func (s *Store) normalizedMigrationHistoryList(ctx context.Context) error { + migrationHistoryList, err := s.driver.ListMigrationHistories(ctx, &FindMigrationHistory{}) + if err != nil { + return errors.Wrap(err, "failed to find migration history") + } + versions := []string{} + for _, migrationHistory := range migrationHistoryList { + versions = append(versions, migrationHistory.Version) + } + sort.Sort(common.SortVersion(versions)) + latestVersion := versions[len(versions)-1] + latestMinorVersion := common.GetMinorVersion(latestVersion) + + schemaVersionMap := map[string]string{} + filePaths, err := fs.Glob(migrationFS, fmt.Sprintf("%s/*/*.sql", s.getMigrationBasePath())) + if err != nil { + return errors.Wrap(err, "failed to read migration files") + } + sort.Strings(filePaths) + for _, filePath := range filePaths { + fileSchemaVersion, err := s.getSchemaVersionOfMigrateScript(filePath) + if err != nil { + return errors.Wrap(err, "failed to get schema version of migrate script") + } + schemaVersionMap[common.GetMinorVersion(fileSchemaVersion)] = fileSchemaVersion + } + + latestSchemaVersion := schemaVersionMap[latestMinorVersion] + if latestSchemaVersion == "" { + return errors.Errorf("latest schema version not found") + } + if common.IsVersionGreaterOrEqualThan(latestVersion, latestSchemaVersion) { + return nil + } + + // Start a transaction to insert the latest schema version to migration_history. + tx, err := s.driver.GetDB().Begin() + if err != nil { + return errors.Wrap(err, "failed to start transaction") + } + defer tx.Rollback() + if err := s.execute(ctx, tx, fmt.Sprintf("INSERT INTO migration_history (version) VALUES ('%s')", latestSchemaVersion)); err != nil { + return errors.Wrap(err, "failed to insert migration history") + } + return tx.Commit() +} + func (s *Store) MigrateWorkspaceSettings(ctx context.Context) error { workspaceSettings, err := s.driver.ListWorkspaceSettings(ctx, &FindWorkspaceSetting{}) if err != nil { diff --git a/store/db/sqlite/seed/10000__user.sql b/store/seed/sqlite/10000__user.sql similarity index 100% rename from store/db/sqlite/seed/10000__user.sql rename to store/seed/sqlite/10000__user.sql diff --git a/store/db/sqlite/seed/10001__shortcut.sql b/store/seed/sqlite/10001__shortcut.sql similarity index 100% rename from store/db/sqlite/seed/10001__shortcut.sql rename to store/seed/sqlite/10001__shortcut.sql diff --git a/store/db/sqlite/seed/10002__collection.sql b/store/seed/sqlite/10002__collection.sql similarity index 100% rename from store/db/sqlite/seed/10002__collection.sql rename to store/seed/sqlite/10002__collection.sql diff --git a/test/store/migrator_test.go b/test/store/migrator_test.go new file mode 100644 index 0000000..046ca6c --- /dev/null +++ b/test/store/migrator_test.go @@ -0,0 +1,17 @@ +package teststore + +import ( + "context" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestGetCurrentSchemaVersion(t *testing.T) { + ctx := context.Background() + ts := NewTestingStore(ctx, t) + + currentSchemaVersion, err := ts.GetCurrentSchemaVersion() + require.NoError(t, err) + require.Equal(t, "1.0.0", currentSchemaVersion) +} diff --git a/test/store/store.go b/test/store/store.go index 3c5b570..c2ae0cc 100644 --- a/test/store/store.go +++ b/test/store/store.go @@ -18,11 +18,10 @@ func NewTestingStore(ctx context.Context, t *testing.T) *store.Store { fmt.Printf("failed to create db driver, error: %+v\n", err) } resetTestingDB(ctx, profile, dbDriver) - if err := dbDriver.Migrate(ctx); err != nil { + store := store.New(dbDriver, profile) + if err := store.Migrate(ctx); err != nil { fmt.Printf("failed to migrate db, error: %+v\n", err) } - - store := store.New(dbDriver, profile) return store }