All checks were successful
Build / build (push) Successful in 1m35s
113 lines
3.0 KiB
Go
113 lines
3.0 KiB
Go
package mysqlite
|
|
|
|
import (
|
|
"fmt"
|
|
"io/fs"
|
|
"log"
|
|
"os"
|
|
"path"
|
|
"strconv"
|
|
"strings"
|
|
)
|
|
|
|
// ReadDirFileFS is a file system that can both list directories and read
// whole files. MigrateDb needs both: ReadDir to discover the migration
// scripts and ReadFile to load each one.
type ReadDirFileFS interface {
	fs.ReadFileFS
	fs.ReadDirFS
}
|
|
|
|
func (d *Db) MigrateDb(filesystem ReadDirFileFS, directory string) error {
|
|
// Read all migrations
|
|
migrationFiles, err := filesystem.ReadDir(directory)
|
|
if err != nil {
|
|
return fmt.Errorf("error reading migration files: %v", err)
|
|
}
|
|
var migrationsByVersion = make(map[int]string)
|
|
latestVersion := 0
|
|
for _, f := range migrationFiles {
|
|
versionStr := f.Name()
|
|
version, err := strconv.Atoi(strings.SplitN(versionStr, "_", 2)[0])
|
|
if err != nil {
|
|
return fmt.Errorf("invalid version number for migration script: %v", err)
|
|
}
|
|
migrationsByVersion[version] = versionStr
|
|
latestVersion = max(latestVersion, version)
|
|
}
|
|
|
|
// Get current migration version from user_version
|
|
var currentVersion int
|
|
err = d.Query("PRAGMA user_version").ScanSingle(¤tVersion)
|
|
if err != nil {
|
|
return fmt.Errorf("error getting current version: %v", err)
|
|
}
|
|
log.Printf("Current version is %d, max migration version is %d", currentVersion, latestVersion)
|
|
|
|
// Create a backup if we're not on the latest version
|
|
if currentVersion != latestVersion && d.source != ":memory:" {
|
|
target := d.source + ".backup." + strconv.Itoa(currentVersion)
|
|
log.Printf("Creating backup of database to %s", target)
|
|
data, err := d.Db.Serialize("main")
|
|
if err != nil {
|
|
return fmt.Errorf("error serializing database: %v", err)
|
|
}
|
|
err = os.WriteFile(target, data, 0644)
|
|
if err != nil {
|
|
return fmt.Errorf("error writing backup: %v", err)
|
|
}
|
|
}
|
|
|
|
// If we are no up-to-date, bring the db up-to-date
|
|
for currentVersion != latestVersion {
|
|
targetVersion := currentVersion + 1
|
|
migrationFile := migrationsByVersion[targetVersion]
|
|
log.Printf("migrating to version %s", migrationFile)
|
|
migrationScript, err := filesystem.ReadFile(path.Join(directory, migrationFile))
|
|
if err != nil {
|
|
return fmt.Errorf("error opening migration script %s: %v", migrationScript, err)
|
|
}
|
|
|
|
err = performSingleMigration(d, migrationScript, targetVersion)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
currentVersion = targetVersion
|
|
}
|
|
|
|
log.Println("Database is up-to-date")
|
|
return nil
|
|
}
|
|
|
|
func performSingleMigration(d *Db, migrationScript []byte, targetVersion int) error {
|
|
script := string(migrationScript)
|
|
// Split script based on semicolon
|
|
statements := strings.Split(script, ";")
|
|
|
|
tx, err := d.Begin()
|
|
if err != nil {
|
|
return fmt.Errorf("error beginning transaction: %v", err)
|
|
}
|
|
defer tx.MustRollback()
|
|
|
|
for _, statement := range statements {
|
|
statement = strings.TrimSpace(statement)
|
|
if statement == "" {
|
|
continue
|
|
}
|
|
err = tx.Query(statement).Exec()
|
|
if err != nil {
|
|
return fmt.Errorf("error performing migration: %v", err)
|
|
}
|
|
|
|
err = tx.Query(fmt.Sprintf("PRAGMA user_version = %d", targetVersion)).Exec()
|
|
if err != nil {
|
|
return fmt.Errorf("error updating version: %v", err)
|
|
}
|
|
|
|
}
|
|
|
|
err = tx.Commit()
|
|
if err != nil {
|
|
return fmt.Errorf("error commiting transaction: %v", err)
|
|
}
|
|
return nil
|
|
}
|