refactor: schema migrator

pull/4772/head^2
Johnny 4 months ago
parent d386b83b7b
commit 3fd29f6493

@ -1,6 +1,7 @@
package mysql
import (
"context"
"database/sql"
"github.com/go-sql-driver/mysql"
@ -47,6 +48,15 @@ func (d *DB) Close() error {
return d.db.Close()
}
func (d *DB) IsInitialized(ctx context.Context) (bool, error) {
var exists bool
err := d.db.QueryRow("SELECT EXISTS (SELECT 1 FROM information_schema.tables WHERE TABLE_NAME = 'memo' AND TABLE_TYPE = 'BASE TABLE')").Scan(&exists)
if err != nil {
return false, errors.Wrap(err, "failed to check if database is initialized")
}
return exists, nil
}
func mergeDSN(baseDSN string) (string, error) {
config, err := mysql.ParseDSN(baseDSN)
if err != nil {

@ -1,6 +1,7 @@
package postgres
import (
"context"
"database/sql"
"log"
@ -15,7 +16,6 @@ import (
type DB struct {
db *sql.DB
profile *profile.Profile
// Add any other fields as needed
}
func NewDB(profile *profile.Profile) (store.Driver, error) {
@ -46,3 +46,12 @@ func (d *DB) GetDB() *sql.DB {
func (d *DB) Close() error {
return d.db.Close()
}
func (d *DB) IsInitialized(ctx context.Context) (bool, error) {
var exists bool
err := d.db.QueryRow("SELECT EXISTS (SELECT 1 FROM information_schema.tables WHERE table_name = 'memo' AND table_type = 'BASE TABLE')").Scan(&exists)
if err != nil {
return false, errors.Wrap(err, "failed to check if database is initialized")
}
return exists, nil
}

@ -1,6 +1,7 @@
package sqlite
import (
"context"
"database/sql"
"github.com/pkg/errors"
@ -57,3 +58,13 @@ func (d *DB) GetDB() *sql.DB {
func (d *DB) Close() error {
return d.db.Close()
}
func (d *DB) IsInitialized(ctx context.Context) (bool, error) {
// Check if the database is initialized by checking if the memo table exists.
var exists bool
err := d.db.QueryRowContext(ctx, "SELECT EXISTS(SELECT 1 FROM sqlite_master WHERE type='table' AND name='memo')").Scan(&exists)
if err != nil {
return false, errors.Wrap(err, "failed to check if database is initialized")
}
return exists, nil
}

@ -15,6 +15,8 @@ type Driver interface {
GetDB() *sql.DB
Close() error
IsInitialized(ctx context.Context) (bool, error)
// MigrationHistory model related methods.
FindMigrationHistoryList(ctx context.Context, find *FindMigrationHistory) ([]*MigrationHistory, error)
UpsertMigrationHistory(ctx context.Context, upsert *UpsertMigrationHistory) (*MigrationHistory, error)

@ -40,26 +40,22 @@ func (s *Store) Migrate(ctx context.Context) error {
}
if s.profile.Mode == "prod" {
migrationHistoryList, err := s.driver.FindMigrationHistoryList(ctx, &FindMigrationHistory{})
workspaceBasicSetting, err := s.GetWorkspaceBasicSetting(ctx)
if err != nil {
return errors.Wrap(err, "failed to find migration history")
return errors.Wrap(err, "failed to get workspace basic setting")
}
if len(migrationHistoryList) == 0 {
return errors.Errorf("no migration history found")
}
migrationHistoryVersions := []string{}
for _, migrationHistory := range migrationHistoryList {
migrationHistoryVersions = append(migrationHistoryVersions, migrationHistory.Version)
}
sort.Sort(version.SortVersion(migrationHistoryVersions))
latestMigrationHistoryVersion := migrationHistoryVersions[len(migrationHistoryVersions)-1]
schemaVersion, err := s.GetCurrentSchemaVersion()
currentSchemaVersion, err := s.GetCurrentSchemaVersion()
if err != nil {
return errors.Wrap(err, "failed to get current schema version")
}
if version.IsVersionGreaterThan(schemaVersion, latestMigrationHistoryVersion) {
if version.IsVersionGreaterThan(workspaceBasicSetting.SchemaVersion, currentSchemaVersion) {
slog.Error("cannot downgrade schema version",
slog.String("databaseVersion", workspaceBasicSetting.SchemaVersion),
slog.String("currentVersion", currentSchemaVersion),
)
return errors.Errorf("cannot downgrade schema version from %s to %s", workspaceBasicSetting.SchemaVersion, currentSchemaVersion)
}
if version.IsVersionGreaterThan(currentSchemaVersion, workspaceBasicSetting.SchemaVersion) {
filePaths, err := fs.Glob(migrationFS, fmt.Sprintf("%s*/*.sql", s.getMigrationBasePath()))
if err != nil {
return errors.Wrap(err, "failed to read migration files")
@ -73,13 +69,13 @@ func (s *Store) Migrate(ctx context.Context) error {
}
defer tx.Rollback()
slog.Info("start migration", slog.String("currentSchemaVersion", latestMigrationHistoryVersion), slog.String("targetSchemaVersion", schemaVersion))
slog.Info("start migration", slog.String("currentSchemaVersion", workspaceBasicSetting.SchemaVersion), slog.String("targetSchemaVersion", currentSchemaVersion))
for _, filePath := range filePaths {
fileSchemaVersion, err := s.getSchemaVersionOfMigrateScript(filePath)
if err != nil {
return errors.Wrap(err, "failed to get schema version of migrate script")
}
if version.IsVersionGreaterThan(fileSchemaVersion, latestMigrationHistoryVersion) && version.IsVersionGreaterOrEqualThan(schemaVersion, fileSchemaVersion) {
if version.IsVersionGreaterThan(fileSchemaVersion, workspaceBasicSetting.SchemaVersion) && version.IsVersionGreaterOrEqualThan(currentSchemaVersion, fileSchemaVersion) {
bytes, err := migrationFS.ReadFile(filePath)
if err != nil {
return errors.Wrapf(err, "failed to read minor version migration file: %s", filePath)
@ -90,20 +86,11 @@ func (s *Store) Migrate(ctx context.Context) error {
}
}
}
if err := tx.Commit(); err != nil {
return errors.Wrap(err, "failed to commit transaction")
}
slog.Info("end migrate")
// Upsert the current schema version to migration_history.
// TODO: retire using migration history later.
if _, err = s.driver.UpsertMigrationHistory(ctx, &UpsertMigrationHistory{
Version: schemaVersion,
}); err != nil {
return errors.Wrapf(err, "failed to upsert migration history with version: %s", schemaVersion)
}
if err := s.updateCurrentSchemaVersion(ctx, schemaVersion); err != nil {
if err := s.updateCurrentSchemaVersion(ctx, currentSchemaVersion); err != nil {
return errors.Wrap(err, "failed to update current schema version")
}
}
@ -117,23 +104,17 @@ func (s *Store) Migrate(ctx context.Context) error {
}
func (s *Store) preMigrate(ctx context.Context) error {
// TODO: using schema version in basic setting instead of migration history.
migrationHistoryList, err := s.driver.FindMigrationHistoryList(ctx, &FindMigrationHistory{})
// If any error occurs or no migration history found, apply the latest schema.
if err != nil || len(migrationHistoryList) == 0 {
if err != nil {
slog.Warn("failed to find migration history in pre-migrate", slog.String("error", err.Error()))
}
initialized, err := s.driver.IsInitialized(ctx)
if err != nil {
return errors.Wrap(err, "failed to check if database is initialized")
}
if !initialized {
filePath := s.getMigrationBasePath() + LatestSchemaFileName
bytes, err := migrationFS.ReadFile(filePath)
if err != nil {
return errors.Errorf("failed to read latest schema file: %s", err)
}
schemaVersion, err := s.GetCurrentSchemaVersion()
if err != nil {
return errors.Wrap(err, "failed to get current schema version")
}
// Start a transaction to apply the latest schema.
tx, err := s.driver.GetDB().Begin()
if err != nil {
@ -147,20 +128,23 @@ func (s *Store) preMigrate(ctx context.Context) error {
return errors.Wrap(err, "failed to commit transaction")
}
// TODO: using schema version in basic setting instead of migration history.
if _, err := s.driver.UpsertMigrationHistory(ctx, &UpsertMigrationHistory{
Version: schemaVersion,
}); err != nil {
return errors.Wrap(err, "failed to upsert migration history")
// Upsert current schema version to database.
schemaVersion, err := s.GetCurrentSchemaVersion()
if err != nil {
return errors.Wrap(err, "failed to get current schema version")
}
if err := s.updateCurrentSchemaVersion(ctx, schemaVersion); err != nil {
return errors.Wrap(err, "failed to update current schema version")
}
}
if s.profile.Mode == "prod" {
if err := s.normalizedMigrationHistoryList(ctx); err != nil {
if err := s.normalizeMigrationHistoryList(ctx); err != nil {
return errors.Wrap(err, "failed to normalize migration history list")
}
if err := s.migrateSchemaVersionToSetting(ctx); err != nil {
return errors.Wrap(err, "failed to migrate schema version to setting")
}
}
return nil
}
@ -249,7 +233,22 @@ func (*Store) execute(ctx context.Context, tx *sql.Tx, stmt string) error {
return nil
}
func (s *Store) normalizedMigrationHistoryList(ctx context.Context) error {
func (s *Store) updateCurrentSchemaVersion(ctx context.Context, schemaVersion string) error {
workspaceBasicSetting, err := s.GetWorkspaceBasicSetting(ctx)
if err != nil {
return errors.Wrap(err, "failed to get workspace basic setting")
}
workspaceBasicSetting.SchemaVersion = schemaVersion
if _, err := s.UpsertWorkspaceSetting(ctx, &storepb.WorkspaceSetting{
Key: storepb.WorkspaceSettingKey_BASIC,
Value: &storepb.WorkspaceSetting_BasicSetting{BasicSetting: workspaceBasicSetting},
}); err != nil {
return errors.Wrap(err, "failed to upsert workspace setting")
}
return nil
}
func (s *Store) normalizeMigrationHistoryList(ctx context.Context) error {
migrationHistoryList, err := s.driver.FindMigrationHistoryList(ctx, &FindMigrationHistory{})
if err != nil {
return errors.Wrap(err, "failed to find migration history")
@ -258,6 +257,9 @@ func (s *Store) normalizedMigrationHistoryList(ctx context.Context) error {
for _, migrationHistory := range migrationHistoryList {
versions = append(versions, migrationHistory.Version)
}
if len(versions) == 0 {
return errors.Errorf("no migration history found")
}
sort.Sort(version.SortVersion(versions))
latestVersion := versions[len(versions)-1]
latestMinorVersion := version.GetMinorVersion(latestVersion)
@ -289,30 +291,37 @@ func (s *Store) normalizedMigrationHistoryList(ctx context.Context) error {
if version.IsVersionGreaterOrEqualThan(latestVersion, latestSchemaVersion) {
return nil
}
if _, err := s.driver.UpsertMigrationHistory(ctx, &UpsertMigrationHistory{
Version: latestSchemaVersion,
}); err != nil {
return errors.Wrap(err, "failed to upsert latest migration history")
}
return nil
}
// Start a transaction to insert the latest schema version to migration_history.
tx, err := s.driver.GetDB().Begin()
func (s *Store) migrateSchemaVersionToSetting(ctx context.Context) error {
migrationHistoryList, err := s.driver.FindMigrationHistoryList(ctx, &FindMigrationHistory{})
if err != nil {
return errors.Wrap(err, "failed to start transaction")
return errors.Wrap(err, "failed to find migration history")
}
defer tx.Rollback()
if err := s.execute(ctx, tx, fmt.Sprintf("INSERT INTO migration_history (version) VALUES ('%s')", latestSchemaVersion)); err != nil {
return errors.Wrap(err, "failed to insert migration history")
versions := []string{}
for _, migrationHistory := range migrationHistoryList {
versions = append(versions, migrationHistory.Version)
}
return tx.Commit()
}
if len(versions) == 0 {
return errors.Errorf("no migration history found")
}
sort.Sort(version.SortVersion(versions))
latestVersion := versions[len(versions)-1]
func (s *Store) updateCurrentSchemaVersion(ctx context.Context, schemaVersion string) error {
workspaceBasicSetting, err := s.GetWorkspaceBasicSetting(ctx)
if err != nil {
return errors.Wrap(err, "failed to get workspace basic setting")
}
workspaceBasicSetting.SchemaVersion = schemaVersion
if _, err := s.UpsertWorkspaceSetting(ctx, &storepb.WorkspaceSetting{
Key: storepb.WorkspaceSettingKey_BASIC,
Value: &storepb.WorkspaceSetting_BasicSetting{BasicSetting: workspaceBasicSetting},
}); err != nil {
return errors.Wrap(err, "failed to upsert workspace setting")
if version.IsVersionGreaterOrEqualThan(workspaceBasicSetting.SchemaVersion, latestVersion) {
if err := s.updateCurrentSchemaVersion(ctx, latestVersion); err != nil {
return errors.Wrap(err, "failed to update current schema version")
}
}
return nil
}

Loading…
Cancel
Save