From 96b27ff99dc37fa62eaa76079ac8ced4d0fb9651 Mon Sep 17 00:00:00 2001 From: Sebastiaan de Schaetzen Date: Wed, 19 Feb 2025 05:33:47 +0100 Subject: [PATCH] Add basic migration test --- migrator.go | 169 +++++++++++++++++------------------ migrator_test.go | 20 +++++ testMigrations/1_initial.sql | 3 + testMigrations/2_addRow.sql | 2 + 4 files changed, 109 insertions(+), 85 deletions(-) create mode 100644 migrator_test.go create mode 100644 testMigrations/1_initial.sql create mode 100644 testMigrations/2_addRow.sql diff --git a/migrator.go b/migrator.go index ebcf314..47b8ff6 100644 --- a/migrator.go +++ b/migrator.go @@ -1,87 +1,86 @@ package mysqlite -//import ( -// "database/sql" -// "embed" -// "fmt" -// "io/fs" -// "log" -// "strconv" -// "strings" -// "zombiezen.com/go/sqlite" -//) -// -//type ReadDirFileFS interface { -// fs.ReadDirFS -// fs.ReadFileFS -//} -// -//func (db *Db) MigrateDb(migrations ReadDirFileFS) error { -// // Read all migrations -// migrationFiles, err := migrations.ReadDir("") -// if err != nil { -// log.Fatalf("error reading migration files: %v", err) -// } -// var migrationsByVersion = make(map[int]string) -// latestVersion := 0 -// for _, f := range migrationFiles { -// versionStr := f.Name() -// version, err := strconv.Atoi(strings.SplitN(versionStr, "_", 2)[0]) -// if err != nil { -// log.Fatalf("invalid version number for migration script: %v", err) -// } -// migrationsByVersion[version] = versionStr -// latestVersion = max(latestVersion, version) -// } -// -// // Get current migration version from user_version -// var currentVersion int -// err = d.QuerySingle("PRAGMA user_version", ¤tVersion) -// if err != nil { -// log.Fatalf("error getting current version: %v", err) -// } -// log.Printf("Current database migration version is %d, latest version is %d", currentVersion, latestVersion) -// -// // If we are no up-to-date, bring the db up-to-date -// for currentVersion != latestVersion { -// targetVersion := currentVersion + 1 -// migrationFile := migrationsByVersion[targetVersion] -// log.Printf("migration to version %s", migrationFile) -// migrationScript, err := migrations.ReadFile(migrationFile) -// if err != nil { -// log.Fatalf("error opening migration script %s: %v", migrationScript, err) -// } -// -// tx, err := db.Begin() -// if err != nil { -// log.Fatalf("error beginning transaction: %v", err) -// } -// defer tx.MustRollback() -// -// err = tx.QuerySingle(string(migrationScript)) -// if err != nil { -// log.Fatalf("error performing migration: %v", err) -// } -// -// err = tx.QuerySingle(fmt.Sprintf("PRAGMA user_version = %d", targetVersion)) -// if err != nil { -// log.Fatalf("error updating version: %v", err) -// } -// -// err = tx.Commit() -// if err != nil { -// log.Fatalf("error commiting transaction: %v", err) -// } -// currentVersion = targetVersion -// } -// -// log.Println("All migrations applied") -// return nil -//} -// -//func rollbackIgnoringErrors(tx *sql.Tx) { -// err := tx.Rollback() -// if err != nil { -// log.Printf("error rolling back: %v", err) -// } -//} +import ( + "fmt" + "io/fs" + "log" + "path" + "strconv" + "strings" +) + +type ReadDirFileFS interface { + fs.ReadDirFS + fs.ReadFileFS +} + +func (d *Db) MigrateDb(filesystem ReadDirFileFS, directory string) error { + // Read all migrations + migrationFiles, err := filesystem.ReadDir(directory) + if err != nil { + return fmt.Errorf("error reading migration files: %v", err) + } + var migrationsByVersion = make(map[int]string) + latestVersion := 0 + for _, f := range migrationFiles { + versionStr := f.Name() + version, err := strconv.Atoi(strings.SplitN(versionStr, "_", 2)[0]) + if err != nil { + return fmt.Errorf("invalid version number for migration script: %v", err) + } + migrationsByVersion[version] = versionStr + latestVersion = max(latestVersion, version) + } + + // Get current migration version from user_version + var currentVersion int + err = d.Query("PRAGMA user_version").ScanSingle(¤tVersion) + if err != nil { + return fmt.Errorf("error getting current version: %v", err) + } + log.Printf("Current version is %d, max migration version is %d", currentVersion, latestVersion) + + // If we are no up-to-date, bring the db up-to-date + for currentVersion != latestVersion { + targetVersion := currentVersion + 1 + migrationFile := migrationsByVersion[targetVersion] + log.Printf("migrating to version %s", migrationFile) + migrationScript, err := filesystem.ReadFile(path.Join(directory, migrationFile)) + if err != nil { + return fmt.Errorf("error opening migration script %s: %v", migrationScript, err) + } + + err = performSingleMigration(err, d, migrationScript, targetVersion) + if err != nil { + return err + } + currentVersion = targetVersion + } + + log.Println("Database is up-to-date") + return nil +} + +func performSingleMigration(err error, d *Db, migrationScript []byte, targetVersion int) error { + tx, err := d.Begin() + if err != nil { + return fmt.Errorf("error beginning transaction: %v", err) + } + defer tx.MustRollback() + + err = tx.Query(string(migrationScript)).Exec() + if err != nil { + return fmt.Errorf("error performing migration: %v", err) + } + + err = tx.Query(fmt.Sprintf("PRAGMA user_version = %d", targetVersion)).Exec() + if err != nil { + return fmt.Errorf("error updating version: %v", err) + } + + err = tx.Commit() + if err != nil { + return fmt.Errorf("error commiting transaction: %v", err) + } + return nil +} diff --git a/migrator_test.go b/migrator_test.go new file mode 100644 index 0000000..e8ce92d --- /dev/null +++ b/migrator_test.go @@ -0,0 +1,20 @@ +package mysqlite + +import ( + "embed" + "github.com/stretchr/testify/require" + "testing" +) + +//go:embed testMigrations/*.sql +var migrations embed.FS + +func TestDb_MigrateDb(t *testing.T) { + db := openEmptyTestDb(t) + err := db.MigrateDb(migrations, "testMigrations") + require.NoError(t, err) + + var count int + db.Query("select count(*) from mydata").MustScanSingle(&count) + require.Equal(t, 1, count, "incorrect number of rows in database") +} diff --git a/testMigrations/1_initial.sql b/testMigrations/1_initial.sql new file mode 100644 index 0000000..b314f2d --- /dev/null +++ b/testMigrations/1_initial.sql @@ -0,0 +1,3 @@ +create table mydata ( + value text +) diff --git a/testMigrations/2_addRow.sql b/testMigrations/2_addRow.sql new file mode 100644 index 0000000..581ebc7 --- /dev/null +++ b/testMigrations/2_addRow.sql @@ -0,0 +1,2 @@ +insert into mydata (value) +values ('hello')