mirror of https://github.com/synctv-org/synctv
You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
175 lines
5.0 KiB
Go
175 lines
5.0 KiB
Go
package bootstrap
|
|
|
|
import (
|
|
"context"
|
|
"database/sql"
|
|
"fmt"
|
|
"strings"
|
|
"time"
|
|
|
|
log "github.com/sirupsen/logrus"
|
|
"github.com/synctv-org/synctv/cmd/flags"
|
|
"github.com/synctv-org/synctv/internal/conf"
|
|
"github.com/synctv-org/synctv/internal/db"
|
|
sysnotify "github.com/synctv-org/synctv/internal/sysnotify"
|
|
"github.com/synctv-org/synctv/internal/version"
|
|
"github.com/synctv-org/synctv/utils"
|
|
"gorm.io/driver/mysql"
|
|
"gorm.io/driver/postgres"
|
|
"gorm.io/gorm"
|
|
"gorm.io/gorm/logger"
|
|
)
|
|
|
|
func InitDatabase(ctx context.Context) (err error) {
|
|
dialector, err := createDialector(conf.Conf.Database)
|
|
if err != nil {
|
|
log.Fatalf("failed to create dialector: %s", err.Error())
|
|
}
|
|
|
|
var opts []gorm.Option
|
|
opts = append(opts, &gorm.Config{
|
|
TranslateError: true,
|
|
Logger: newDBLogger(),
|
|
PrepareStmt: true,
|
|
DisableForeignKeyConstraintWhenMigrating: false,
|
|
IgnoreRelationshipsWhenMigrating: false,
|
|
})
|
|
d, err := gorm.Open(dialector, opts...)
|
|
if err != nil {
|
|
log.Fatalf("failed to connect database: %s", err.Error())
|
|
}
|
|
sqlDB, err := d.DB()
|
|
if err != nil {
|
|
log.Fatalf("failed to get sqlDB: %s", err.Error())
|
|
}
|
|
err = sysnotify.RegisterSysNotifyTask(0, sysnotify.NewSysNotifyTask("database", sysnotify.NotifyTypeEXIT, func() error {
|
|
return sqlDB.Close()
|
|
}))
|
|
if err != nil {
|
|
log.Fatalf("failed to register sysnotify task: %s", err.Error())
|
|
}
|
|
if conf.Conf.Database.Type != conf.DatabaseTypeSqlite3 {
|
|
initRawDB(sqlDB)
|
|
}
|
|
return db.Init(d, conf.Conf.Database.Type)
|
|
}
|
|
|
|
func createDialector(dbConf conf.DatabaseConfig) (dialector gorm.Dialector, err error) {
|
|
var dsn string
|
|
switch dbConf.Type {
|
|
case conf.DatabaseTypeMysql:
|
|
if dbConf.CustomDSN != "" {
|
|
dsn = dbConf.CustomDSN
|
|
} else if dbConf.Port == 0 {
|
|
dsn = fmt.Sprintf("%s:%s@unix(%s)/%s?charset=utf8mb4&parseTime=True&loc=Local&interpolateParams=true&tls=%s",
|
|
dbConf.User,
|
|
dbConf.Password,
|
|
dbConf.Host,
|
|
dbConf.Name,
|
|
dbConf.SslMode,
|
|
)
|
|
log.Infof("mysql database: %s", dbConf.Host)
|
|
} else {
|
|
dsn = fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?charset=utf8mb4&parseTime=True&loc=Local&interpolateParams=true&tls=%s",
|
|
dbConf.User,
|
|
dbConf.Password,
|
|
dbConf.Host,
|
|
dbConf.Port,
|
|
dbConf.Name,
|
|
dbConf.SslMode,
|
|
)
|
|
log.Infof("mysql database tcp: %s:%d", dbConf.Host, dbConf.Port)
|
|
}
|
|
dialector = mysql.New(mysql.Config{
|
|
DSN: dsn,
|
|
DefaultStringSize: 256,
|
|
DisableDatetimePrecision: true,
|
|
DontSupportRenameIndex: true,
|
|
DontSupportRenameColumn: true,
|
|
SkipInitializeWithVersion: false,
|
|
})
|
|
case conf.DatabaseTypeSqlite3:
|
|
if dbConf.CustomDSN != "" {
|
|
dsn = dbConf.CustomDSN
|
|
} else if dbConf.Name == "memory" || strings.HasPrefix(dbConf.Name, ":memory:") {
|
|
dsn = "file::memory:?cache=shared&_journal_mode=WAL&_vacuum=incremental&_pragma=foreign_keys(1)"
|
|
log.Infof("sqlite3 database memory")
|
|
} else {
|
|
if !strings.HasSuffix(dbConf.Name, ".db") {
|
|
dbConf.Name += ".db"
|
|
}
|
|
dbConf.Name, err = utils.OptFilePath(dbConf.Name)
|
|
if err != nil {
|
|
log.Fatalf("sqlite3 database file path error: %v", err)
|
|
}
|
|
dsn = dbConf.Name + "?_journal_mode=WAL&_vacuum=incremental&_pragma=foreign_keys(1)"
|
|
log.Infof("sqlite3 database file: %s", dbConf.Name)
|
|
}
|
|
dialector = openSqlite(dsn)
|
|
case conf.DatabaseTypePostgres:
|
|
if dbConf.CustomDSN != "" {
|
|
dsn = dbConf.CustomDSN
|
|
} else if dbConf.Port == 0 {
|
|
dsn = fmt.Sprintf("host=%s user=%s password=%s dbname=%s sslmode=%s",
|
|
dbConf.Host,
|
|
dbConf.User,
|
|
dbConf.Password,
|
|
dbConf.Name,
|
|
dbConf.SslMode,
|
|
)
|
|
log.Infof("postgres database: %s", dbConf.Host)
|
|
} else {
|
|
dsn = fmt.Sprintf("host=%s port=%d user=%s password=%s dbname=%s sslmode=%s",
|
|
dbConf.Host,
|
|
dbConf.Port,
|
|
dbConf.User,
|
|
dbConf.Password,
|
|
dbConf.Name,
|
|
dbConf.SslMode,
|
|
)
|
|
log.Infof("postgres database tcp: %s:%d", dbConf.Host, dbConf.Port)
|
|
}
|
|
dialector = postgres.New(postgres.Config{
|
|
DSN: dsn,
|
|
PreferSimpleProtocol: true,
|
|
})
|
|
default:
|
|
log.Fatalf("unknown database type: %s", dbConf.Type)
|
|
}
|
|
return
|
|
}
|
|
|
|
func newDBLogger() logger.Interface {
|
|
var logLevel logger.LogLevel
|
|
if flags.Global.Dev {
|
|
logLevel = logger.Info
|
|
} else {
|
|
logLevel = logger.Warn
|
|
}
|
|
return logger.New(
|
|
log.StandardLogger(),
|
|
logger.Config{
|
|
SlowThreshold: time.Second,
|
|
LogLevel: logLevel,
|
|
IgnoreRecordNotFoundError: true,
|
|
ParameterizedQueries: !flags.Global.Dev && version.Version != "dev",
|
|
Colorful: utils.ForceColor(),
|
|
},
|
|
)
|
|
}
|
|
|
|
func initRawDB(db *sql.DB) {
|
|
db.SetMaxOpenConns(conf.Conf.Database.MaxOpenConns)
|
|
db.SetMaxIdleConns(conf.Conf.Database.MaxIdleConns)
|
|
d, err := time.ParseDuration(conf.Conf.Database.ConnMaxLifetime)
|
|
if err != nil {
|
|
log.Fatalf("failed to parse conn_max_lifetime: %s", err.Error())
|
|
}
|
|
db.SetConnMaxLifetime(d)
|
|
d, err = time.ParseDuration(conf.Conf.Database.ConnMaxIdleTime)
|
|
if err != nil {
|
|
log.Fatalf("failed to parse conn_max_idle_time: %s", err.Error())
|
|
}
|
|
db.SetConnMaxIdleTime(d)
|
|
}
|