package mysqlite

import (
	"fmt"
	"io/fs"
	"log"
	"path"
	"strconv"
	"strings"
)

// ReadDirFileFS is a filesystem that can both list a directory and read a
// whole file, such as an embed.FS holding the migration scripts.
type ReadDirFileFS interface {
	fs.ReadDirFS
	fs.ReadFileFS
}

// MigrateDb brings the database up to date by applying the SQL migration
// scripts stored in directory of filesystem.
//
// Each file name must start with an integer version, optionally followed by
// "_" and a description (e.g. "3_add_users.sql"). The current version is read
// from SQLite's PRAGMA user_version and every script with a higher version is
// applied in ascending order, each in its own transaction that also updates
// user_version. It returns an error if a file name has no valid version, if
// two files share a version, if a pending version has no script, or if the
// database is already newer than the newest script.
func (d *Db) MigrateDb(filesystem ReadDirFileFS, directory string) error {
	migrationsByVersion, latestVersion, err := migrationFilesByVersion(filesystem, directory)
	if err != nil {
		return err
	}

	// Get current migration version from user_version
	var currentVersion int
	if err := d.Query("PRAGMA user_version").ScanSingle(&currentVersion); err != nil {
		return fmt.Errorf("error getting current version: %w", err)
	}

	log.Printf("Current version is %d, max migration version is %d", currentVersion, latestVersion)

	if currentVersion > latestVersion {
		return fmt.Errorf("database version %d is newer than latest migration version %d", currentVersion, latestVersion)
	}

	// Check that every pending version has a script before touching the
	// database, so a gap cannot leave it half-migrated.
	for version := currentVersion + 1; version <= latestVersion; version++ {
		if _, ok := migrationsByVersion[version]; !ok {
			return fmt.Errorf("missing migration script for version %d", version)
		}
	}

	// If we are not up-to-date, bring the db up-to-date one version at a time.
	for targetVersion := currentVersion + 1; targetVersion <= latestVersion; targetVersion++ {
		migrationFile := migrationsByVersion[targetVersion]
		log.Printf("migrating to version %s", migrationFile)

		migrationScript, err := filesystem.ReadFile(path.Join(directory, migrationFile))
		if err != nil {
			return fmt.Errorf("error opening migration script %s: %w", migrationFile, err)
		}

		if err := performSingleMigration(d, migrationScript, targetVersion); err != nil {
			return err
		}
	}

	log.Println("Database is up-to-date")
	return nil
}

// migrationFilesByVersion lists directory and indexes the migration file names
// it contains by their integer version prefix. It also returns the highest
// version seen (0 if there are no files). Subdirectories are ignored.
func migrationFilesByVersion(filesystem ReadDirFileFS, directory string) (map[int]string, int, error) {
	entries, err := filesystem.ReadDir(directory)
	if err != nil {
		return nil, 0, fmt.Errorf("error reading migration files: %w", err)
	}

	byVersion := make(map[int]string, len(entries))
	latest := 0
	for _, entry := range entries {
		if entry.IsDir() {
			// A directory cannot be executed as a script.
			continue
		}
		name := entry.Name()
		prefix, _, _ := strings.Cut(name, "_")
		version, err := strconv.Atoi(prefix)
		if err != nil {
			return nil, 0, fmt.Errorf("invalid version number for migration script: %w", err)
		}
		if existing, dup := byVersion[version]; dup {
			return nil, 0, fmt.Errorf("duplicate migration version %d: %s and %s", version, existing, name)
		}
		byVersion[version] = name
		latest = max(latest, version)
	}
	return byVersion, latest, nil
}

// performSingleMigration executes every statement of migrationScript and then
// sets PRAGMA user_version to targetVersion, all inside one transaction that
// is committed only if every step succeeds.
//
// The script is split into statements on ";". NOTE: this naive split does not
// understand ";" inside string literals or trigger bodies
// (CREATE TRIGGER ... BEGIN ...; END), so such scripts are not supported.
func performSingleMigration(d *Db, migrationScript []byte, targetVersion int) error {
	tx, err := d.Begin()
	if err != nil {
		return fmt.Errorf("error beginning transaction: %w", err)
	}
	// NOTE(review): assumes MustRollback tolerates being called after a
	// successful Commit — confirm against the Tx implementation.
	defer tx.MustRollback()

	for _, statement := range strings.Split(string(migrationScript), ";") {
		statement = strings.TrimSpace(statement)
		if statement == "" {
			// Text after the final ";" (or between ";;") yields empty pieces.
			continue
		}
		if err := tx.Query(statement).Exec(); err != nil {
			return fmt.Errorf("error performing migration: %w", err)
		}
	}

	// Record the new version once, after all statements succeeded, in the
	// same transaction as the schema changes.
	if err := tx.Query(fmt.Sprintf("PRAGMA user_version = %d", targetVersion)).Exec(); err != nil {
		return fmt.Errorf("error updating version: %w", err)
	}

	if err := tx.Commit(); err != nil {
		return fmt.Errorf("error commiting transaction: %w", err)
	}
	return nil
}