mysqlite/migrator.go
2025-03-16 11:38:40 +01:00

98 lines
2.5 KiB
Go

package mysqlite
import (
"fmt"
"io/fs"
"log"
"path"
"strconv"
"strings"
)
// ReadDirFileFS is the file system abstraction MigrateDb reads migration
// scripts from: it must be able to list the entries of a directory and read
// a whole file by name. embed.FS is one standard-library type that satisfies it.
type ReadDirFileFS interface {
	fs.ReadFileFS
	fs.ReadDirFS
}
// MigrateDb brings the database up to date by applying, in ascending order,
// every migration script in directory whose version is greater than the
// database's current PRAGMA user_version.
//
// Each regular file in directory must be named "<version>_<description>"
// (for example "3_add_users.sql"), where <version> is a decimal integer.
// Subdirectories are ignored. Migrations are applied one version at a time,
// starting at user_version+1, so versions must be contiguous. SQLite starts a
// fresh database at user_version 0, so numbering starts at 1, and a file with
// version 0 or lower is never applied.
//
// Each migration is applied by performSingleMigration in its own transaction.
// A failure therefore stops the process after the last successfully
// committed version.
//
// An error is returned if:
//   - the directory cannot be read;
//   - a file name does not start with a valid version number;
//   - two files share a version;
//   - a version between the current and the latest one is missing;
//   - the database is already newer than the latest migration;
//   - reading or applying a migration fails.
func (d *Db) MigrateDb(filesystem ReadDirFileFS, directory string) error {
	// Read all migrations.
	migrationFiles, err := filesystem.ReadDir(directory)
	if err != nil {
		return fmt.Errorf("error reading migration files: %w", err)
	}

	migrationsByVersion := make(map[int]string, len(migrationFiles))
	latestVersion := 0
	for _, f := range migrationFiles {
		if f.IsDir() {
			// Only regular files are migration scripts.
			continue
		}
		fileName := f.Name()
		versionStr, _, _ := strings.Cut(fileName, "_")
		version, err := strconv.Atoi(versionStr)
		if err != nil {
			return fmt.Errorf("invalid version number for migration script %q: %w", fileName, err)
		}
		// Two scripts with the same version would otherwise silently shadow
		// each other and one of them would never run.
		if existing, ok := migrationsByVersion[version]; ok {
			return fmt.Errorf("duplicate migration version %d: %q and %q", version, existing, fileName)
		}
		migrationsByVersion[version] = fileName
		latestVersion = max(latestVersion, version)
	}

	// Get the current migration version from user_version.
	var currentVersion int
	err = d.Query("PRAGMA user_version").ScanSingle(&currentVersion)
	if err != nil {
		return fmt.Errorf("error getting current version: %w", err)
	}
	log.Printf("Current version is %d, max migration version is %d", currentVersion, latestVersion)

	// A database newer than every known migration cannot be migrated forward;
	// report it explicitly instead of probing for non-existent scripts.
	if currentVersion > latestVersion {
		return fmt.Errorf("database version %d is newer than the latest migration version %d", currentVersion, latestVersion)
	}

	// If we are not up-to-date, bring the db up-to-date one version at a time.
	for currentVersion < latestVersion {
		targetVersion := currentVersion + 1
		migrationFile, ok := migrationsByVersion[targetVersion]
		if !ok {
			return fmt.Errorf("missing migration script for version %d", targetVersion)
		}
		log.Printf("migrating to version %s", migrationFile)
		// fs.FS paths are always slash-separated, hence path rather than filepath.
		migrationScript, err := filesystem.ReadFile(path.Join(directory, migrationFile))
		if err != nil {
			return fmt.Errorf("error opening migration script %s: %w", migrationFile, err)
		}
		if err := performSingleMigration(d, migrationScript, targetVersion); err != nil {
			return fmt.Errorf("applying migration %s: %w", migrationFile, err)
		}
		currentVersion = targetVersion
	}
	log.Println("Database is up-to-date")
	return nil
}
// performSingleMigration executes migrationScript inside one transaction.
// In that same transaction it sets PRAGMA user_version to targetVersion.
// The schema changes and the version bump are therefore committed together.
// On any error the function returns before Commit, and the deferred
// MustRollback is expected to undo the transaction.
//
// The script is split into statements on every ';'. This is a naive split: a
// ';' inside a string literal, a comment, or a CREATE TRIGGER ... BEGIN ... END
// body will cut that statement apart. Empty statements, such as the text
// after the final ';', are skipped.
//
// The version is recorded even when the script contains no statements, so an
// empty migration still advances user_version.
func performSingleMigration(d *Db, migrationScript []byte, targetVersion int) error {
	tx, err := d.Begin()
	if err != nil {
		return fmt.Errorf("error beginning transaction: %w", err)
	}
	// NOTE(review): this relies on MustRollback being harmless once Commit has
	// succeeded; confirm against the transaction type's implementation.
	defer tx.MustRollback()

	// Naive split on ';' (see the function comment for its limitations).
	for _, statement := range strings.Split(string(migrationScript), ";") {
		statement = strings.TrimSpace(statement)
		if statement == "" {
			continue
		}
		if err := tx.Query(statement).Exec(); err != nil {
			return fmt.Errorf("error performing migration to version %d: %w", targetVersion, err)
		}
	}

	// Record the new version once, after every statement has succeeded.
	// SQLite PRAGMA statements do not take bound parameters, hence Sprintf.
	// targetVersion is an int, so there is no injection risk.
	if err := tx.Query(fmt.Sprintf("PRAGMA user_version = %d", targetVersion)).Exec(); err != nil {
		return fmt.Errorf("error updating version: %w", err)
	}

	if err := tx.Commit(); err != nil {
		return fmt.Errorf("error committing transaction: %w", err)
	}
	return nil
}