package mysqlite

import (
	"database/sql"
	"embed"
	"fmt"
	"io/fs"
	"log"
	"strconv"
	"strings"

	"zombiezen.com/go/sqlite"
)

// ReadDirFileFS is a file system that can both list a directory and read a
// whole file by name. MigrateDb uses it to discover and load migration scripts.
type ReadDirFileFS interface {
	fs.ReadDirFS
	fs.ReadFileFS
}

// MigrateDb brings the database up to the latest migration found in migrations.
//
// Every entry in the root of migrations must be named "<version>_<anything>",
// where <version> is a decimal integer. The current version is read from
// "PRAGMA user_version"; scripts are then applied one version at a time, each
// in its own transaction together with the matching user_version update.
//
// Failures are returned to the caller (wrapped, so errors.Is/As still work)
// instead of terminating the process. An error is returned when the migration
// directory cannot be read, a file name has no numeric version prefix, a
// version between the current and the latest one has no script, the database
// is already newer than the newest known script, or any database step fails.
func (db *Db) MigrateDb(migrations ReadDirFileFS) error {
	// "." is the root of an fs.FS; the empty string is not a valid fs path
	// and is rejected by embed.FS, fs.Sub and os.DirFS.
	entries, err := migrations.ReadDir(".")
	if err != nil {
		return fmt.Errorf("error reading migration files: %w", err)
	}

	// Index the scripts by the numeric prefix that precedes the first "_".
	migrationsByVersion := make(map[int]string, len(entries))
	latestVersion := 0
	for _, entry := range entries {
		name := entry.Name()
		prefix, _, _ := strings.Cut(name, "_")
		version, err := strconv.Atoi(prefix)
		if err != nil {
			return fmt.Errorf("invalid version number for migration script: %w", err)
		}
		migrationsByVersion[version] = name
		latestVersion = max(latestVersion, version)
	}

	// Get the current migration version from user_version.
	var currentVersion int
	if err := db.QuerySingle("PRAGMA user_version", &currentVersion); err != nil {
		return fmt.Errorf("error getting current version: %w", err)
	}
	log.Printf("Current database migration version is %d, latest version is %d", currentVersion, latestVersion)

	// A database that is ahead of the known scripts cannot be migrated
	// forward; report it instead of searching for scripts that do not exist.
	if currentVersion > latestVersion {
		return fmt.Errorf("database version %d is newer than latest migration version %d", currentVersion, latestVersion)
	}

	// If we are not up to date, apply the missing versions in order.
	for currentVersion < latestVersion {
		targetVersion := currentVersion + 1
		migrationFile, ok := migrationsByVersion[targetVersion]
		if !ok {
			return fmt.Errorf("no migration script found for version %d", targetVersion)
		}
		log.Printf("migration to version %s", migrationFile)
		if err := db.applyMigration(migrations, migrationFile, targetVersion); err != nil {
			return err
		}
		currentVersion = targetVersion
	}

	log.Println("All migrations applied")
	return nil
}

// applyMigration runs a single migration script and records targetVersion in
// user_version, both inside one transaction. It is a separate function so the
// deferred rollback runs at the end of each migration rather than piling up
// until MigrateDb returns.
func (db *Db) applyMigration(migrations ReadDirFileFS, migrationFile string, targetVersion int) error {
	migrationScript, err := migrations.ReadFile(migrationFile)
	if err != nil {
		return fmt.Errorf("error opening migration script %s: %w", migrationFile, err)
	}

	tx, err := db.Begin()
	if err != nil {
		return fmt.Errorf("error beginning transaction: %w", err)
	}
	// NOTE(review): this also runs after a successful Commit, exactly as the
	// original deferred call did; presumably MustRollback tolerates an
	// already-committed transaction — confirm against its implementation.
	defer tx.MustRollback()

	if err := tx.QuerySingle(string(migrationScript)); err != nil {
		return fmt.Errorf("error performing migration: %w", err)
	}
	if err := tx.QuerySingle(fmt.Sprintf("PRAGMA user_version = %d", targetVersion)); err != nil {
		return fmt.Errorf("error updating version: %w", err)
	}
	if err := tx.Commit(); err != nil {
		return fmt.Errorf("error commiting transaction: %w", err)
	}
	return nil
}

// rollbackIgnoringErrors rolls tx back and logs, rather than returns, any
// error from the rollback.
func rollbackIgnoringErrors(tx *sql.Tx) {
	if err := tx.Rollback(); err != nil {
		log.Printf("error rolling back: %v", err)
	}
}