Add basic migration test
All checks were successful
Build / build (push) Successful in 2m12s

This commit is contained in:
Sebastiaan de Schaetzen 2025-02-19 05:33:47 +01:00
parent 43d6265c4c
commit 96b27ff99d
4 changed files with 109 additions and 85 deletions

View File

@ -1,87 +1,86 @@
package mysqlite package mysqlite
//import ( import (
// "database/sql" "fmt"
// "embed" "io/fs"
// "fmt" "log"
// "io/fs" "path"
// "log" "strconv"
// "strconv" "strings"
// "strings" )
// "zombiezen.com/go/sqlite"
//) type ReadDirFileFS interface {
// fs.ReadDirFS
//type ReadDirFileFS interface { fs.ReadFileFS
// fs.ReadDirFS }
// fs.ReadFileFS
//} func (d *Db) MigrateDb(filesystem ReadDirFileFS, directory string) error {
// // Read all migrations
//func (db *Db) MigrateDb(migrations ReadDirFileFS) error { migrationFiles, err := filesystem.ReadDir(directory)
// // Read all migrations if err != nil {
// migrationFiles, err := migrations.ReadDir("") return fmt.Errorf("error reading migration files: %v", err)
// if err != nil { }
// log.Fatalf("error reading migration files: %v", err) var migrationsByVersion = make(map[int]string)
// } latestVersion := 0
// var migrationsByVersion = make(map[int]string) for _, f := range migrationFiles {
// latestVersion := 0 versionStr := f.Name()
// for _, f := range migrationFiles { version, err := strconv.Atoi(strings.SplitN(versionStr, "_", 2)[0])
// versionStr := f.Name() if err != nil {
// version, err := strconv.Atoi(strings.SplitN(versionStr, "_", 2)[0]) return fmt.Errorf("invalid version number for migration script: %v", err)
// if err != nil { }
// log.Fatalf("invalid version number for migration script: %v", err) migrationsByVersion[version] = versionStr
// } latestVersion = max(latestVersion, version)
// migrationsByVersion[version] = versionStr }
// latestVersion = max(latestVersion, version)
// } // Get current migration version from user_version
// var currentVersion int
// // Get current migration version from user_version err = d.Query("PRAGMA user_version").ScanSingle(&currentVersion)
// var currentVersion int if err != nil {
// err = d.QuerySingle("PRAGMA user_version", &currentVersion) return fmt.Errorf("error getting current version: %v", err)
// if err != nil { }
// log.Fatalf("error getting current version: %v", err) log.Printf("Current version is %d, max migration version is %d", currentVersion, latestVersion)
// }
// log.Printf("Current database migration version is %d, latest version is %d", currentVersion, latestVersion) // If we are no up-to-date, bring the db up-to-date
// for currentVersion != latestVersion {
// // If we are no up-to-date, bring the db up-to-date targetVersion := currentVersion + 1
// for currentVersion != latestVersion { migrationFile := migrationsByVersion[targetVersion]
// targetVersion := currentVersion + 1 log.Printf("migrating to version %s", migrationFile)
// migrationFile := migrationsByVersion[targetVersion] migrationScript, err := filesystem.ReadFile(path.Join(directory, migrationFile))
// log.Printf("migration to version %s", migrationFile) if err != nil {
// migrationScript, err := migrations.ReadFile(migrationFile) return fmt.Errorf("error opening migration script %s: %v", migrationScript, err)
// if err != nil { }
// log.Fatalf("error opening migration script %s: %v", migrationScript, err)
// } err = performSingleMigration(err, d, migrationScript, targetVersion)
// if err != nil {
// tx, err := db.Begin() return err
// if err != nil { }
// log.Fatalf("error beginning transaction: %v", err) currentVersion = targetVersion
// } }
// defer tx.MustRollback()
// log.Println("Database is up-to-date")
// err = tx.QuerySingle(string(migrationScript)) return nil
// if err != nil { }
// log.Fatalf("error performing migration: %v", err)
// } func performSingleMigration(err error, d *Db, migrationScript []byte, targetVersion int) error {
// tx, err := d.Begin()
// err = tx.QuerySingle(fmt.Sprintf("PRAGMA user_version = %d", targetVersion)) if err != nil {
// if err != nil { return fmt.Errorf("error beginning transaction: %v", err)
// log.Fatalf("error updating version: %v", err) }
// } defer tx.MustRollback()
//
// err = tx.Commit() err = tx.Query(string(migrationScript)).Exec()
// if err != nil { if err != nil {
// log.Fatalf("error commiting transaction: %v", err) return fmt.Errorf("error performing migration: %v", err)
// } }
// currentVersion = targetVersion
// } err = tx.Query(fmt.Sprintf("PRAGMA user_version = %d", targetVersion)).Exec()
// if err != nil {
// log.Println("All migrations applied") return fmt.Errorf("error updating version: %v", err)
// return nil }
//}
// err = tx.Commit()
//func rollbackIgnoringErrors(tx *sql.Tx) { if err != nil {
// err := tx.Rollback() return fmt.Errorf("error commiting transaction: %v", err)
// if err != nil { }
// log.Printf("error rolling back: %v", err) return nil
// } }
//}

20
migrator_test.go Normal file
View File

@ -0,0 +1,20 @@
package mysqlite
import (
"embed"
"github.com/stretchr/testify/require"
"testing"
)
//go:embed testMigrations/*.sql
var migrations embed.FS
func TestDb_MigrateDb(t *testing.T) {
db := openEmptyTestDb(t)
err := db.MigrateDb(migrations, "testMigrations")
require.NoError(t, err)
var count int
db.Query("select count(*) from mydata").MustScanSingle(&count)
require.Equal(t, 1, count, "incorrect number of rows in database")
}

View File

@ -0,0 +1,3 @@
create table mydata (
value text
)

View File

@ -0,0 +1,2 @@
insert into mydata (value)
values ('hello')