This commit is contained in:
parent
43d6265c4c
commit
96b27ff99d
169
migrator.go
169
migrator.go
@ -1,87 +1,86 @@
|
|||||||
package mysqlite
|
package mysqlite
|
||||||
|
|
||||||
//import (
|
import (
|
||||||
// "database/sql"
|
"fmt"
|
||||||
// "embed"
|
"io/fs"
|
||||||
// "fmt"
|
"log"
|
||||||
// "io/fs"
|
"path"
|
||||||
// "log"
|
"strconv"
|
||||||
// "strconv"
|
"strings"
|
||||||
// "strings"
|
)
|
||||||
// "zombiezen.com/go/sqlite"
|
|
||||||
//)
|
type ReadDirFileFS interface {
|
||||||
//
|
fs.ReadDirFS
|
||||||
//type ReadDirFileFS interface {
|
fs.ReadFileFS
|
||||||
// fs.ReadDirFS
|
}
|
||||||
// fs.ReadFileFS
|
|
||||||
//}
|
func (d *Db) MigrateDb(filesystem ReadDirFileFS, directory string) error {
|
||||||
//
|
// Read all migrations
|
||||||
//func (db *Db) MigrateDb(migrations ReadDirFileFS) error {
|
migrationFiles, err := filesystem.ReadDir(directory)
|
||||||
// // Read all migrations
|
if err != nil {
|
||||||
// migrationFiles, err := migrations.ReadDir("")
|
return fmt.Errorf("error reading migration files: %v", err)
|
||||||
// if err != nil {
|
}
|
||||||
// log.Fatalf("error reading migration files: %v", err)
|
var migrationsByVersion = make(map[int]string)
|
||||||
// }
|
latestVersion := 0
|
||||||
// var migrationsByVersion = make(map[int]string)
|
for _, f := range migrationFiles {
|
||||||
// latestVersion := 0
|
versionStr := f.Name()
|
||||||
// for _, f := range migrationFiles {
|
version, err := strconv.Atoi(strings.SplitN(versionStr, "_", 2)[0])
|
||||||
// versionStr := f.Name()
|
if err != nil {
|
||||||
// version, err := strconv.Atoi(strings.SplitN(versionStr, "_", 2)[0])
|
return fmt.Errorf("invalid version number for migration script: %v", err)
|
||||||
// if err != nil {
|
}
|
||||||
// log.Fatalf("invalid version number for migration script: %v", err)
|
migrationsByVersion[version] = versionStr
|
||||||
// }
|
latestVersion = max(latestVersion, version)
|
||||||
// migrationsByVersion[version] = versionStr
|
}
|
||||||
// latestVersion = max(latestVersion, version)
|
|
||||||
// }
|
// Get current migration version from user_version
|
||||||
//
|
var currentVersion int
|
||||||
// // Get current migration version from user_version
|
err = d.Query("PRAGMA user_version").ScanSingle(¤tVersion)
|
||||||
// var currentVersion int
|
if err != nil {
|
||||||
// err = d.QuerySingle("PRAGMA user_version", ¤tVersion)
|
return fmt.Errorf("error getting current version: %v", err)
|
||||||
// if err != nil {
|
}
|
||||||
// log.Fatalf("error getting current version: %v", err)
|
log.Printf("Current version is %d, max migration version is %d", currentVersion, latestVersion)
|
||||||
// }
|
|
||||||
// log.Printf("Current database migration version is %d, latest version is %d", currentVersion, latestVersion)
|
// If we are no up-to-date, bring the db up-to-date
|
||||||
//
|
for currentVersion != latestVersion {
|
||||||
// // If we are no up-to-date, bring the db up-to-date
|
targetVersion := currentVersion + 1
|
||||||
// for currentVersion != latestVersion {
|
migrationFile := migrationsByVersion[targetVersion]
|
||||||
// targetVersion := currentVersion + 1
|
log.Printf("migrating to version %s", migrationFile)
|
||||||
// migrationFile := migrationsByVersion[targetVersion]
|
migrationScript, err := filesystem.ReadFile(path.Join(directory, migrationFile))
|
||||||
// log.Printf("migration to version %s", migrationFile)
|
if err != nil {
|
||||||
// migrationScript, err := migrations.ReadFile(migrationFile)
|
return fmt.Errorf("error opening migration script %s: %v", migrationScript, err)
|
||||||
// if err != nil {
|
}
|
||||||
// log.Fatalf("error opening migration script %s: %v", migrationScript, err)
|
|
||||||
// }
|
err = performSingleMigration(err, d, migrationScript, targetVersion)
|
||||||
//
|
if err != nil {
|
||||||
// tx, err := db.Begin()
|
return err
|
||||||
// if err != nil {
|
}
|
||||||
// log.Fatalf("error beginning transaction: %v", err)
|
currentVersion = targetVersion
|
||||||
// }
|
}
|
||||||
// defer tx.MustRollback()
|
|
||||||
//
|
log.Println("Database is up-to-date")
|
||||||
// err = tx.QuerySingle(string(migrationScript))
|
return nil
|
||||||
// if err != nil {
|
}
|
||||||
// log.Fatalf("error performing migration: %v", err)
|
|
||||||
// }
|
func performSingleMigration(err error, d *Db, migrationScript []byte, targetVersion int) error {
|
||||||
//
|
tx, err := d.Begin()
|
||||||
// err = tx.QuerySingle(fmt.Sprintf("PRAGMA user_version = %d", targetVersion))
|
if err != nil {
|
||||||
// if err != nil {
|
return fmt.Errorf("error beginning transaction: %v", err)
|
||||||
// log.Fatalf("error updating version: %v", err)
|
}
|
||||||
// }
|
defer tx.MustRollback()
|
||||||
//
|
|
||||||
// err = tx.Commit()
|
err = tx.Query(string(migrationScript)).Exec()
|
||||||
// if err != nil {
|
if err != nil {
|
||||||
// log.Fatalf("error commiting transaction: %v", err)
|
return fmt.Errorf("error performing migration: %v", err)
|
||||||
// }
|
}
|
||||||
// currentVersion = targetVersion
|
|
||||||
// }
|
err = tx.Query(fmt.Sprintf("PRAGMA user_version = %d", targetVersion)).Exec()
|
||||||
//
|
if err != nil {
|
||||||
// log.Println("All migrations applied")
|
return fmt.Errorf("error updating version: %v", err)
|
||||||
// return nil
|
}
|
||||||
//}
|
|
||||||
//
|
err = tx.Commit()
|
||||||
//func rollbackIgnoringErrors(tx *sql.Tx) {
|
if err != nil {
|
||||||
// err := tx.Rollback()
|
return fmt.Errorf("error commiting transaction: %v", err)
|
||||||
// if err != nil {
|
}
|
||||||
// log.Printf("error rolling back: %v", err)
|
return nil
|
||||||
// }
|
}
|
||||||
//}
|
|
||||||
|
20
migrator_test.go
Normal file
20
migrator_test.go
Normal file
@ -0,0 +1,20 @@
|
|||||||
|
package mysqlite
|
||||||
|
|
||||||
|
import (
|
||||||
|
"embed"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
//go:embed testMigrations/*.sql
|
||||||
|
var migrations embed.FS
|
||||||
|
|
||||||
|
func TestDb_MigrateDb(t *testing.T) {
|
||||||
|
db := openEmptyTestDb(t)
|
||||||
|
err := db.MigrateDb(migrations, "testMigrations")
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
var count int
|
||||||
|
db.Query("select count(*) from mydata").MustScanSingle(&count)
|
||||||
|
require.Equal(t, 1, count, "incorrect number of rows in database")
|
||||||
|
}
|
3
testMigrations/1_initial.sql
Normal file
3
testMigrations/1_initial.sql
Normal file
@ -0,0 +1,3 @@
|
|||||||
|
create table mydata (
|
||||||
|
value text
|
||||||
|
)
|
2
testMigrations/2_addRow.sql
Normal file
2
testMigrations/2_addRow.sql
Normal file
@ -0,0 +1,2 @@
|
|||||||
|
insert into mydata (value)
|
||||||
|
values ('hello')
|
Loading…
x
Reference in New Issue
Block a user