diff --git a/store/db/db.go b/store/db/db.go index 482d41d0..69bf3b6d 100644 --- a/store/db/db.go +++ b/store/db/db.go @@ -142,26 +142,23 @@ func (db *DB) compareMigrationHistory() error { return err } if table == nil { - createTable(db, ` + if err := createTable(db, ` CREATE TABLE migration_history ( version TEXT NOT NULL PRIMARY KEY, created_ts BIGINT NOT NULL DEFAULT (strftime('%s', 'now')) ); - `) + `); err != nil { + return err + } } - migrationHistoryList, err := findMigrationHistoryList(db) + currentVersion := common.Version + migrationHistory, err := upsertMigrationHistory(db.Db, currentVersion) if err != nil { return err } - - if len(migrationHistoryList) == 0 { - createMigrationHistory(db, common.Version) - } else { - migrationHistory := migrationHistoryList[0] - if migrationHistory.Version != common.Version { - createMigrationHistory(db, common.Version) - } + if migrationHistory == nil { + return fmt.Errorf("failed to upsert migration history") } return nil diff --git a/store/db/migration_history.go b/store/db/migration_history.go index b2c875fe..ed820c81 100644 --- a/store/db/migration_history.go +++ b/store/db/migration_history.go @@ -1,7 +1,7 @@ package db import ( - "fmt" + "database/sql" ) type MigrationHistory struct { @@ -9,53 +9,32 @@ type MigrationHistory struct { Version string } -func findMigrationHistoryList(db *DB) ([]*MigrationHistory, error) { - rows, err := db.Db.Query(` - SELECT - version, - created_ts - FROM - migration_history - ORDER BY created_ts DESC - `) - if err != nil { - return nil, err - } - defer rows.Close() - - migrationHistoryList := make([]*MigrationHistory, 0) - for rows.Next() { - var migrationHistory MigrationHistory - if err := rows.Scan( - &migrationHistory.Version, - &migrationHistory.CreatedTs, - ); err != nil { - return nil, err - } - - migrationHistoryList = append(migrationHistoryList, &migrationHistory) - } - - return migrationHistoryList, nil -} - -func createMigrationHistory(db *DB, version string) error { - result, err := db.Db.Exec(` +func upsertMigrationHistory(db *sql.DB, version string) (*MigrationHistory, error) { + row, err := db.Query(` INSERT INTO migration_history ( version ) VALUES (?) + ON CONFLICT(version) DO UPDATE + SET + version=EXCLUDED.version + RETURNING version, created_ts `, version, ) if err != nil { - return err + return nil, err } - - rows, _ := result.RowsAffected() - if rows == 0 { - return fmt.Errorf("failed to create migration history with %s", version) + defer row.Close() + + row.Next() + migrationHistory := MigrationHistory{} + if err := row.Scan( + &migrationHistory.Version, + &migrationHistory.CreatedTs, + ); err != nil { + return nil, err } - return nil + return &migrationHistory, nil }