refactor: schema migrator

pull/4772/head^2
Johnny 4 months ago
parent d386b83b7b
commit 3fd29f6493

@ -1,6 +1,7 @@
package mysql package mysql
import ( import (
"context"
"database/sql" "database/sql"
"github.com/go-sql-driver/mysql" "github.com/go-sql-driver/mysql"
@ -47,6 +48,15 @@ func (d *DB) Close() error {
return d.db.Close() return d.db.Close()
} }
func (d *DB) IsInitialized(ctx context.Context) (bool, error) {
var exists bool
err := d.db.QueryRow("SELECT EXISTS (SELECT 1 FROM information_schema.tables WHERE TABLE_NAME = 'memo' AND TABLE_TYPE = 'BASE TABLE')").Scan(&exists)
if err != nil {
return false, errors.Wrap(err, "failed to check if database is initialized")
}
return exists, nil
}
func mergeDSN(baseDSN string) (string, error) { func mergeDSN(baseDSN string) (string, error) {
config, err := mysql.ParseDSN(baseDSN) config, err := mysql.ParseDSN(baseDSN)
if err != nil { if err != nil {

@ -1,6 +1,7 @@
package postgres package postgres
import ( import (
"context"
"database/sql" "database/sql"
"log" "log"
@ -15,7 +16,6 @@ import (
type DB struct { type DB struct {
db *sql.DB db *sql.DB
profile *profile.Profile profile *profile.Profile
// Add any other fields as needed
} }
func NewDB(profile *profile.Profile) (store.Driver, error) { func NewDB(profile *profile.Profile) (store.Driver, error) {
@ -46,3 +46,12 @@ func (d *DB) GetDB() *sql.DB {
func (d *DB) Close() error { func (d *DB) Close() error {
return d.db.Close() return d.db.Close()
} }
func (d *DB) IsInitialized(ctx context.Context) (bool, error) {
var exists bool
err := d.db.QueryRow("SELECT EXISTS (SELECT 1 FROM information_schema.tables WHERE table_name = 'memo' AND table_type = 'BASE TABLE')").Scan(&exists)
if err != nil {
return false, errors.Wrap(err, "failed to check if database is initialized")
}
return exists, nil
}

@ -1,6 +1,7 @@
package sqlite package sqlite
import ( import (
"context"
"database/sql" "database/sql"
"github.com/pkg/errors" "github.com/pkg/errors"
@ -57,3 +58,13 @@ func (d *DB) GetDB() *sql.DB {
func (d *DB) Close() error { func (d *DB) Close() error {
return d.db.Close() return d.db.Close()
} }
func (d *DB) IsInitialized(ctx context.Context) (bool, error) {
// Check if the database is initialized by checking if the memo table exists.
var exists bool
err := d.db.QueryRowContext(ctx, "SELECT EXISTS(SELECT 1 FROM sqlite_master WHERE type='table' AND name='memo')").Scan(&exists)
if err != nil {
return false, errors.Wrap(err, "failed to check if database is initialized")
}
return exists, nil
}

@ -15,6 +15,8 @@ type Driver interface {
GetDB() *sql.DB GetDB() *sql.DB
Close() error Close() error
IsInitialized(ctx context.Context) (bool, error)
// MigrationHistory model related methods. // MigrationHistory model related methods.
FindMigrationHistoryList(ctx context.Context, find *FindMigrationHistory) ([]*MigrationHistory, error) FindMigrationHistoryList(ctx context.Context, find *FindMigrationHistory) ([]*MigrationHistory, error)
UpsertMigrationHistory(ctx context.Context, upsert *UpsertMigrationHistory) (*MigrationHistory, error) UpsertMigrationHistory(ctx context.Context, upsert *UpsertMigrationHistory) (*MigrationHistory, error)

@ -40,26 +40,22 @@ func (s *Store) Migrate(ctx context.Context) error {
} }
if s.profile.Mode == "prod" { if s.profile.Mode == "prod" {
migrationHistoryList, err := s.driver.FindMigrationHistoryList(ctx, &FindMigrationHistory{}) workspaceBasicSetting, err := s.GetWorkspaceBasicSetting(ctx)
if err != nil { if err != nil {
return errors.Wrap(err, "failed to find migration history") return errors.Wrap(err, "failed to get workspace basic setting")
} }
if len(migrationHistoryList) == 0 { currentSchemaVersion, err := s.GetCurrentSchemaVersion()
return errors.Errorf("no migration history found")
}
migrationHistoryVersions := []string{}
for _, migrationHistory := range migrationHistoryList {
migrationHistoryVersions = append(migrationHistoryVersions, migrationHistory.Version)
}
sort.Sort(version.SortVersion(migrationHistoryVersions))
latestMigrationHistoryVersion := migrationHistoryVersions[len(migrationHistoryVersions)-1]
schemaVersion, err := s.GetCurrentSchemaVersion()
if err != nil { if err != nil {
return errors.Wrap(err, "failed to get current schema version") return errors.Wrap(err, "failed to get current schema version")
} }
if version.IsVersionGreaterThan(workspaceBasicSetting.SchemaVersion, currentSchemaVersion) {
if version.IsVersionGreaterThan(schemaVersion, latestMigrationHistoryVersion) { slog.Error("cannot downgrade schema version",
slog.String("databaseVersion", workspaceBasicSetting.SchemaVersion),
slog.String("currentVersion", currentSchemaVersion),
)
return errors.Errorf("cannot downgrade schema version from %s to %s", workspaceBasicSetting.SchemaVersion, currentSchemaVersion)
}
if version.IsVersionGreaterThan(currentSchemaVersion, workspaceBasicSetting.SchemaVersion) {
filePaths, err := fs.Glob(migrationFS, fmt.Sprintf("%s*/*.sql", s.getMigrationBasePath())) filePaths, err := fs.Glob(migrationFS, fmt.Sprintf("%s*/*.sql", s.getMigrationBasePath()))
if err != nil { if err != nil {
return errors.Wrap(err, "failed to read migration files") return errors.Wrap(err, "failed to read migration files")
@ -73,13 +69,13 @@ func (s *Store) Migrate(ctx context.Context) error {
} }
defer tx.Rollback() defer tx.Rollback()
slog.Info("start migration", slog.String("currentSchemaVersion", latestMigrationHistoryVersion), slog.String("targetSchemaVersion", schemaVersion)) slog.Info("start migration", slog.String("currentSchemaVersion", workspaceBasicSetting.SchemaVersion), slog.String("targetSchemaVersion", currentSchemaVersion))
for _, filePath := range filePaths { for _, filePath := range filePaths {
fileSchemaVersion, err := s.getSchemaVersionOfMigrateScript(filePath) fileSchemaVersion, err := s.getSchemaVersionOfMigrateScript(filePath)
if err != nil { if err != nil {
return errors.Wrap(err, "failed to get schema version of migrate script") return errors.Wrap(err, "failed to get schema version of migrate script")
} }
if version.IsVersionGreaterThan(fileSchemaVersion, latestMigrationHistoryVersion) && version.IsVersionGreaterOrEqualThan(schemaVersion, fileSchemaVersion) { if version.IsVersionGreaterThan(fileSchemaVersion, workspaceBasicSetting.SchemaVersion) && version.IsVersionGreaterOrEqualThan(currentSchemaVersion, fileSchemaVersion) {
bytes, err := migrationFS.ReadFile(filePath) bytes, err := migrationFS.ReadFile(filePath)
if err != nil { if err != nil {
return errors.Wrapf(err, "failed to read minor version migration file: %s", filePath) return errors.Wrapf(err, "failed to read minor version migration file: %s", filePath)
@ -90,20 +86,11 @@ func (s *Store) Migrate(ctx context.Context) error {
} }
} }
} }
if err := tx.Commit(); err != nil { if err := tx.Commit(); err != nil {
return errors.Wrap(err, "failed to commit transaction") return errors.Wrap(err, "failed to commit transaction")
} }
slog.Info("end migrate") slog.Info("end migrate")
if err := s.updateCurrentSchemaVersion(ctx, currentSchemaVersion); err != nil {
// Upsert the current schema version to migration_history.
// TODO: retire using migration history later.
if _, err = s.driver.UpsertMigrationHistory(ctx, &UpsertMigrationHistory{
Version: schemaVersion,
}); err != nil {
return errors.Wrapf(err, "failed to upsert migration history with version: %s", schemaVersion)
}
if err := s.updateCurrentSchemaVersion(ctx, schemaVersion); err != nil {
return errors.Wrap(err, "failed to update current schema version") return errors.Wrap(err, "failed to update current schema version")
} }
} }
@ -117,23 +104,17 @@ func (s *Store) Migrate(ctx context.Context) error {
} }
func (s *Store) preMigrate(ctx context.Context) error { func (s *Store) preMigrate(ctx context.Context) error {
// TODO: using schema version in basic setting instead of migration history. initialized, err := s.driver.IsInitialized(ctx)
migrationHistoryList, err := s.driver.FindMigrationHistoryList(ctx, &FindMigrationHistory{}) if err != nil {
// If any error occurs or no migration history found, apply the latest schema. return errors.Wrap(err, "failed to check if database is initialized")
if err != nil || len(migrationHistoryList) == 0 { }
if err != nil {
slog.Warn("failed to find migration history in pre-migrate", slog.String("error", err.Error())) if !initialized {
}
filePath := s.getMigrationBasePath() + LatestSchemaFileName filePath := s.getMigrationBasePath() + LatestSchemaFileName
bytes, err := migrationFS.ReadFile(filePath) bytes, err := migrationFS.ReadFile(filePath)
if err != nil { if err != nil {
return errors.Errorf("failed to read latest schema file: %s", err) return errors.Errorf("failed to read latest schema file: %s", err)
} }
schemaVersion, err := s.GetCurrentSchemaVersion()
if err != nil {
return errors.Wrap(err, "failed to get current schema version")
}
// Start a transaction to apply the latest schema. // Start a transaction to apply the latest schema.
tx, err := s.driver.GetDB().Begin() tx, err := s.driver.GetDB().Begin()
if err != nil { if err != nil {
@ -147,20 +128,23 @@ func (s *Store) preMigrate(ctx context.Context) error {
return errors.Wrap(err, "failed to commit transaction") return errors.Wrap(err, "failed to commit transaction")
} }
// TODO: using schema version in basic setting instead of migration history. // Upsert current schema version to database.
if _, err := s.driver.UpsertMigrationHistory(ctx, &UpsertMigrationHistory{ schemaVersion, err := s.GetCurrentSchemaVersion()
Version: schemaVersion, if err != nil {
}); err != nil { return errors.Wrap(err, "failed to get current schema version")
return errors.Wrap(err, "failed to upsert migration history")
} }
if err := s.updateCurrentSchemaVersion(ctx, schemaVersion); err != nil { if err := s.updateCurrentSchemaVersion(ctx, schemaVersion); err != nil {
return errors.Wrap(err, "failed to update current schema version") return errors.Wrap(err, "failed to update current schema version")
} }
} }
if s.profile.Mode == "prod" { if s.profile.Mode == "prod" {
if err := s.normalizedMigrationHistoryList(ctx); err != nil { if err := s.normalizeMigrationHistoryList(ctx); err != nil {
return errors.Wrap(err, "failed to normalize migration history list") return errors.Wrap(err, "failed to normalize migration history list")
} }
if err := s.migrateSchemaVersionToSetting(ctx); err != nil {
return errors.Wrap(err, "failed to migrate schema version to setting")
}
} }
return nil return nil
} }
@ -249,7 +233,22 @@ func (*Store) execute(ctx context.Context, tx *sql.Tx, stmt string) error {
return nil return nil
} }
func (s *Store) normalizedMigrationHistoryList(ctx context.Context) error { func (s *Store) updateCurrentSchemaVersion(ctx context.Context, schemaVersion string) error {
workspaceBasicSetting, err := s.GetWorkspaceBasicSetting(ctx)
if err != nil {
return errors.Wrap(err, "failed to get workspace basic setting")
}
workspaceBasicSetting.SchemaVersion = schemaVersion
if _, err := s.UpsertWorkspaceSetting(ctx, &storepb.WorkspaceSetting{
Key: storepb.WorkspaceSettingKey_BASIC,
Value: &storepb.WorkspaceSetting_BasicSetting{BasicSetting: workspaceBasicSetting},
}); err != nil {
return errors.Wrap(err, "failed to upsert workspace setting")
}
return nil
}
func (s *Store) normalizeMigrationHistoryList(ctx context.Context) error {
migrationHistoryList, err := s.driver.FindMigrationHistoryList(ctx, &FindMigrationHistory{}) migrationHistoryList, err := s.driver.FindMigrationHistoryList(ctx, &FindMigrationHistory{})
if err != nil { if err != nil {
return errors.Wrap(err, "failed to find migration history") return errors.Wrap(err, "failed to find migration history")
@ -258,6 +257,9 @@ func (s *Store) normalizedMigrationHistoryList(ctx context.Context) error {
for _, migrationHistory := range migrationHistoryList { for _, migrationHistory := range migrationHistoryList {
versions = append(versions, migrationHistory.Version) versions = append(versions, migrationHistory.Version)
} }
if len(versions) == 0 {
return errors.Errorf("no migration history found")
}
sort.Sort(version.SortVersion(versions)) sort.Sort(version.SortVersion(versions))
latestVersion := versions[len(versions)-1] latestVersion := versions[len(versions)-1]
latestMinorVersion := version.GetMinorVersion(latestVersion) latestMinorVersion := version.GetMinorVersion(latestVersion)
@ -289,30 +291,37 @@ func (s *Store) normalizedMigrationHistoryList(ctx context.Context) error {
if version.IsVersionGreaterOrEqualThan(latestVersion, latestSchemaVersion) { if version.IsVersionGreaterOrEqualThan(latestVersion, latestSchemaVersion) {
return nil return nil
} }
if _, err := s.driver.UpsertMigrationHistory(ctx, &UpsertMigrationHistory{
Version: latestSchemaVersion,
}); err != nil {
return errors.Wrap(err, "failed to upsert latest migration history")
}
return nil
}
// Start a transaction to insert the latest schema version to migration_history. func (s *Store) migrateSchemaVersionToSetting(ctx context.Context) error {
tx, err := s.driver.GetDB().Begin() migrationHistoryList, err := s.driver.FindMigrationHistoryList(ctx, &FindMigrationHistory{})
if err != nil { if err != nil {
return errors.Wrap(err, "failed to start transaction") return errors.Wrap(err, "failed to find migration history")
} }
defer tx.Rollback() versions := []string{}
if err := s.execute(ctx, tx, fmt.Sprintf("INSERT INTO migration_history (version) VALUES ('%s')", latestSchemaVersion)); err != nil { for _, migrationHistory := range migrationHistoryList {
return errors.Wrap(err, "failed to insert migration history") versions = append(versions, migrationHistory.Version)
} }
return tx.Commit() if len(versions) == 0 {
} return errors.Errorf("no migration history found")
}
sort.Sort(version.SortVersion(versions))
latestVersion := versions[len(versions)-1]
func (s *Store) updateCurrentSchemaVersion(ctx context.Context, schemaVersion string) error {
workspaceBasicSetting, err := s.GetWorkspaceBasicSetting(ctx) workspaceBasicSetting, err := s.GetWorkspaceBasicSetting(ctx)
if err != nil { if err != nil {
return errors.Wrap(err, "failed to get workspace basic setting") return errors.Wrap(err, "failed to get workspace basic setting")
} }
workspaceBasicSetting.SchemaVersion = schemaVersion if version.IsVersionGreaterOrEqualThan(workspaceBasicSetting.SchemaVersion, latestVersion) {
if _, err := s.UpsertWorkspaceSetting(ctx, &storepb.WorkspaceSetting{ if err := s.updateCurrentSchemaVersion(ctx, latestVersion); err != nil {
Key: storepb.WorkspaceSettingKey_BASIC, return errors.Wrap(err, "failed to update current schema version")
Value: &storepb.WorkspaceSetting_BasicSetting{BasicSetting: workspaceBasicSetting}, }
}); err != nil {
return errors.Wrap(err, "failed to upsert workspace setting")
} }
return nil return nil
} }

Loading…
Cancel
Save