diff --git a/store/db/mysql/mysql.go b/store/db/mysql/mysql.go index 02a6523bd..989d33e89 100644 --- a/store/db/mysql/mysql.go +++ b/store/db/mysql/mysql.go @@ -1,6 +1,7 @@ package mysql import ( + "context" "database/sql" "github.com/go-sql-driver/mysql" @@ -47,6 +48,15 @@ func (d *DB) Close() error { return d.db.Close() } +func (d *DB) IsInitialized(ctx context.Context) (bool, error) { + var exists bool + err := d.db.QueryRow("SELECT EXISTS (SELECT 1 FROM information_schema.tables WHERE TABLE_NAME = 'memo' AND TABLE_TYPE = 'BASE TABLE')").Scan(&exists) + if err != nil { + return false, errors.Wrap(err, "failed to check if database is initialized") + } + return exists, nil +} + func mergeDSN(baseDSN string) (string, error) { config, err := mysql.ParseDSN(baseDSN) if err != nil { diff --git a/store/db/postgres/postgres.go b/store/db/postgres/postgres.go index 34495e730..4c8b907b2 100644 --- a/store/db/postgres/postgres.go +++ b/store/db/postgres/postgres.go @@ -1,6 +1,7 @@ package postgres import ( + "context" "database/sql" "log" @@ -15,7 +16,6 @@ import ( type DB struct { db *sql.DB profile *profile.Profile - // Add any other fields as needed } func NewDB(profile *profile.Profile) (store.Driver, error) { @@ -46,3 +46,12 @@ func (d *DB) GetDB() *sql.DB { func (d *DB) Close() error { return d.db.Close() } + +func (d *DB) IsInitialized(ctx context.Context) (bool, error) { + var exists bool + err := d.db.QueryRow("SELECT EXISTS (SELECT 1 FROM information_schema.tables WHERE table_name = 'memo' AND table_type = 'BASE TABLE')").Scan(&exists) + if err != nil { + return false, errors.Wrap(err, "failed to check if database is initialized") + } + return exists, nil +} diff --git a/store/db/sqlite/sqlite.go b/store/db/sqlite/sqlite.go index 345a4783b..3b4a30f8d 100644 --- a/store/db/sqlite/sqlite.go +++ b/store/db/sqlite/sqlite.go @@ -1,6 +1,7 @@ package sqlite import ( + "context" "database/sql" "github.com/pkg/errors" @@ -57,3 +58,13 @@ func (d *DB) GetDB() *sql.DB { func (d *DB) Close() error { return d.db.Close() } + +func (d *DB) IsInitialized(ctx context.Context) (bool, error) { + // Check if the database is initialized by checking if the memo table exists. + var exists bool + err := d.db.QueryRowContext(ctx, "SELECT EXISTS(SELECT 1 FROM sqlite_master WHERE type='table' AND name='memo')").Scan(&exists) + if err != nil { + return false, errors.Wrap(err, "failed to check if database is initialized") + } + return exists, nil +} diff --git a/store/driver.go b/store/driver.go index 603ab1d42..27f110034 100644 --- a/store/driver.go +++ b/store/driver.go @@ -15,6 +15,8 @@ type Driver interface { GetDB() *sql.DB Close() error + IsInitialized(ctx context.Context) (bool, error) + // MigrationHistory model related methods. FindMigrationHistoryList(ctx context.Context, find *FindMigrationHistory) ([]*MigrationHistory, error) UpsertMigrationHistory(ctx context.Context, upsert *UpsertMigrationHistory) (*MigrationHistory, error) diff --git a/store/migrator.go b/store/migrator.go index 61eaed019..52ef597bf 100644 --- a/store/migrator.go +++ b/store/migrator.go @@ -40,26 +40,22 @@ func (s *Store) Migrate(ctx context.Context) error { } if s.profile.Mode == "prod" { - migrationHistoryList, err := s.driver.FindMigrationHistoryList(ctx, &FindMigrationHistory{}) + workspaceBasicSetting, err := s.GetWorkspaceBasicSetting(ctx) if err != nil { - return errors.Wrap(err, "failed to find migration history") + return errors.Wrap(err, "failed to get workspace basic setting") } - if len(migrationHistoryList) == 0 { - return errors.Errorf("no migration history found") - } - - migrationHistoryVersions := []string{} - for _, migrationHistory := range migrationHistoryList { - migrationHistoryVersions = append(migrationHistoryVersions, migrationHistory.Version) - } - sort.Sort(version.SortVersion(migrationHistoryVersions)) - latestMigrationHistoryVersion := migrationHistoryVersions[len(migrationHistoryVersions)-1] - schemaVersion, err := s.GetCurrentSchemaVersion() + currentSchemaVersion, err := s.GetCurrentSchemaVersion() if err != nil { return errors.Wrap(err, "failed to get current schema version") } - - if version.IsVersionGreaterThan(schemaVersion, latestMigrationHistoryVersion) { + if version.IsVersionGreaterThan(workspaceBasicSetting.SchemaVersion, currentSchemaVersion) { + slog.Error("cannot downgrade schema version", + slog.String("databaseVersion", workspaceBasicSetting.SchemaVersion), + slog.String("currentVersion", currentSchemaVersion), + ) + return errors.Errorf("cannot downgrade schema version from %s to %s", workspaceBasicSetting.SchemaVersion, currentSchemaVersion) + } + if version.IsVersionGreaterThan(currentSchemaVersion, workspaceBasicSetting.SchemaVersion) { filePaths, err := fs.Glob(migrationFS, fmt.Sprintf("%s*/*.sql", s.getMigrationBasePath())) if err != nil { return errors.Wrap(err, "failed to read migration files") @@ -73,13 +69,13 @@ func (s *Store) Migrate(ctx context.Context) error { } defer tx.Rollback() - slog.Info("start migration", slog.String("currentSchemaVersion", latestMigrationHistoryVersion), slog.String("targetSchemaVersion", schemaVersion)) + slog.Info("start migration", slog.String("currentSchemaVersion", workspaceBasicSetting.SchemaVersion), slog.String("targetSchemaVersion", currentSchemaVersion)) for _, filePath := range filePaths { fileSchemaVersion, err := s.getSchemaVersionOfMigrateScript(filePath) if err != nil { return errors.Wrap(err, "failed to get schema version of migrate script") } - if version.IsVersionGreaterThan(fileSchemaVersion, latestMigrationHistoryVersion) && version.IsVersionGreaterOrEqualThan(schemaVersion, fileSchemaVersion) { + if version.IsVersionGreaterThan(fileSchemaVersion, workspaceBasicSetting.SchemaVersion) && version.IsVersionGreaterOrEqualThan(currentSchemaVersion, fileSchemaVersion) { bytes, err := migrationFS.ReadFile(filePath) if err != nil { return errors.Wrapf(err, "failed to read minor version migration file: %s", filePath) @@ -90,20 +86,11 @@ func (s *Store) Migrate(ctx context.Context) error { } } } - if err := tx.Commit(); err != nil { return errors.Wrap(err, "failed to commit transaction") } slog.Info("end migrate") - - // Upsert the current schema version to migration_history. - // TODO: retire using migration history later. - if _, err = s.driver.UpsertMigrationHistory(ctx, &UpsertMigrationHistory{ - Version: schemaVersion, - }); err != nil { - return errors.Wrapf(err, "failed to upsert migration history with version: %s", schemaVersion) - } - if err := s.updateCurrentSchemaVersion(ctx, schemaVersion); err != nil { + if err := s.updateCurrentSchemaVersion(ctx, currentSchemaVersion); err != nil { return errors.Wrap(err, "failed to update current schema version") } } @@ -117,23 +104,17 @@ func (s *Store) Migrate(ctx context.Context) error { } func (s *Store) preMigrate(ctx context.Context) error { - // TODO: using schema version in basic setting instead of migration history. - migrationHistoryList, err := s.driver.FindMigrationHistoryList(ctx, &FindMigrationHistory{}) - // If any error occurs or no migration history found, apply the latest schema. - if err != nil || len(migrationHistoryList) == 0 { - if err != nil { - slog.Warn("failed to find migration history in pre-migrate", slog.String("error", err.Error())) - } + initialized, err := s.driver.IsInitialized(ctx) + if err != nil { + return errors.Wrap(err, "failed to check if database is initialized") + } + + if !initialized { filePath := s.getMigrationBasePath() + LatestSchemaFileName bytes, err := migrationFS.ReadFile(filePath) if err != nil { return errors.Errorf("failed to read latest schema file: %s", err) } - schemaVersion, err := s.GetCurrentSchemaVersion() - if err != nil { - return errors.Wrap(err, "failed to get current schema version") - } - // Start a transaction to apply the latest schema. tx, err := s.driver.GetDB().Begin() if err != nil { @@ -147,20 +128,23 @@ func (s *Store) preMigrate(ctx context.Context) error { return errors.Wrap(err, "failed to commit transaction") } - // TODO: using schema version in basic setting instead of migration history. - if _, err := s.driver.UpsertMigrationHistory(ctx, &UpsertMigrationHistory{ - Version: schemaVersion, - }); err != nil { - return errors.Wrap(err, "failed to upsert migration history") + // Upsert current schema version to database. + schemaVersion, err := s.GetCurrentSchemaVersion() + if err != nil { + return errors.Wrap(err, "failed to get current schema version") } if err := s.updateCurrentSchemaVersion(ctx, schemaVersion); err != nil { return errors.Wrap(err, "failed to update current schema version") } } + if s.profile.Mode == "prod" { - if err := s.normalizedMigrationHistoryList(ctx); err != nil { + if err := s.normalizeMigrationHistoryList(ctx); err != nil { return errors.Wrap(err, "failed to normalize migration history list") } + if err := s.migrateSchemaVersionToSetting(ctx); err != nil { + return errors.Wrap(err, "failed to migrate schema version to setting") + } } return nil } @@ -249,7 +233,22 @@ func (*Store) execute(ctx context.Context, tx *sql.Tx, stmt string) error { return nil } -func (s *Store) normalizedMigrationHistoryList(ctx context.Context) error { +func (s *Store) updateCurrentSchemaVersion(ctx context.Context, schemaVersion string) error { + workspaceBasicSetting, err := s.GetWorkspaceBasicSetting(ctx) + if err != nil { + return errors.Wrap(err, "failed to get workspace basic setting") + } + workspaceBasicSetting.SchemaVersion = schemaVersion + if _, err := s.UpsertWorkspaceSetting(ctx, &storepb.WorkspaceSetting{ + Key: storepb.WorkspaceSettingKey_BASIC, + Value: &storepb.WorkspaceSetting_BasicSetting{BasicSetting: workspaceBasicSetting}, + }); err != nil { + return errors.Wrap(err, "failed to upsert workspace setting") + } + return nil +} + +func (s *Store) normalizeMigrationHistoryList(ctx context.Context) error { migrationHistoryList, err := s.driver.FindMigrationHistoryList(ctx, &FindMigrationHistory{}) if err != nil { return errors.Wrap(err, "failed to find migration history") @@ -258,6 +257,9 @@ func (s *Store) normalizedMigrationHistoryList(ctx context.Context) error { for _, migrationHistory := range migrationHistoryList { versions = append(versions, migrationHistory.Version) } + if len(versions) == 0 { + return errors.Errorf("no migration history found") + } sort.Sort(version.SortVersion(versions)) latestVersion := versions[len(versions)-1] latestMinorVersion := version.GetMinorVersion(latestVersion) @@ -289,30 +291,37 @@ func (s *Store) normalizedMigrationHistoryList(ctx context.Context) error { if version.IsVersionGreaterOrEqualThan(latestVersion, latestSchemaVersion) { return nil } + if _, err := s.driver.UpsertMigrationHistory(ctx, &UpsertMigrationHistory{ + Version: latestSchemaVersion, + }); err != nil { + return errors.Wrap(err, "failed to upsert latest migration history") + } + return nil +} - // Start a transaction to insert the latest schema version to migration_history. - tx, err := s.driver.GetDB().Begin() +func (s *Store) migrateSchemaVersionToSetting(ctx context.Context) error { + migrationHistoryList, err := s.driver.FindMigrationHistoryList(ctx, &FindMigrationHistory{}) if err != nil { - return errors.Wrap(err, "failed to start transaction") + return errors.Wrap(err, "failed to find migration history") } - defer tx.Rollback() - if err := s.execute(ctx, tx, fmt.Sprintf("INSERT INTO migration_history (version) VALUES ('%s')", latestSchemaVersion)); err != nil { - return errors.Wrap(err, "failed to insert migration history") + versions := []string{} + for _, migrationHistory := range migrationHistoryList { + versions = append(versions, migrationHistory.Version) } - return tx.Commit() -} + if len(versions) == 0 { + return errors.Errorf("no migration history found") + } + sort.Sort(version.SortVersion(versions)) + latestVersion := versions[len(versions)-1] -func (s *Store) updateCurrentSchemaVersion(ctx context.Context, schemaVersion string) error { workspaceBasicSetting, err := s.GetWorkspaceBasicSetting(ctx) if err != nil { return errors.Wrap(err, "failed to get workspace basic setting") } - workspaceBasicSetting.SchemaVersion = schemaVersion - if _, err := s.UpsertWorkspaceSetting(ctx, &storepb.WorkspaceSetting{ - Key: storepb.WorkspaceSettingKey_BASIC, - Value: &storepb.WorkspaceSetting_BasicSetting{BasicSetting: workspaceBasicSetting}, - }); err != nil { - return errors.Wrap(err, "failed to upsert workspace setting") + if version.IsVersionGreaterOrEqualThan(workspaceBasicSetting.SchemaVersion, latestVersion) { + if err := s.updateCurrentSchemaVersion(ctx, latestVersion); err != nil { + return errors.Wrap(err, "failed to update current schema version") + } } return nil }