diff --git a/migrator.go b/migrator.go index 47b8ff6..ce4af77 100644 --- a/migrator.go +++ b/migrator.go @@ -50,7 +50,7 @@ func (d *Db) MigrateDb(filesystem ReadDirFileFS, directory string) error { return fmt.Errorf("error opening migration script %s: %v", migrationScript, err) } - err = performSingleMigration(err, d, migrationScript, targetVersion) + err = performSingleMigration(d, migrationScript, targetVersion) if err != nil { return err } @@ -61,21 +61,29 @@ func (d *Db) MigrateDb(filesystem ReadDirFileFS, directory string) error { return nil } -func performSingleMigration(err error, d *Db, migrationScript []byte, targetVersion int) error { +func performSingleMigration(d *Db, migrationScript []byte, targetVersion int) error { + script := string(migrationScript) + // Split script based on semicolon + statements := strings.Split(script, ";") + tx, err := d.Begin() if err != nil { return fmt.Errorf("error beginning transaction: %v", err) } defer tx.MustRollback() - err = tx.Query(string(migrationScript)).Exec() - if err != nil { - return fmt.Errorf("error performing migration: %v", err) - } + for _, statement := range statements { + statement = strings.TrimSpace(statement) + err = tx.Query(statement).Exec() + if err != nil { + return fmt.Errorf("error performing migration: %v", err) + } + + err = tx.Query(fmt.Sprintf("PRAGMA user_version = %d", targetVersion)).Exec() + if err != nil { + return fmt.Errorf("error updating version: %v", err) + } - err = tx.Query(fmt.Sprintf("PRAGMA user_version = %d", targetVersion)).Exec() - if err != nil { - return fmt.Errorf("error updating version: %v", err) } err = tx.Commit() diff --git a/testMigrations/3_multicomment.sql b/testMigrations/3_multicomment.sql new file mode 100644 index 0000000..e3ad9c2 --- /dev/null +++ b/testMigrations/3_multicomment.sql @@ -0,0 +1,3 @@ +create table testTable(value text); + +insert into testTable(value) values ('testValue');