From 2eacf6fbc446620fecb30d58cac637d7826afcf8 Mon Sep 17 00:00:00 2001 From: Sebastiaan de Schaetzen Date: Sun, 16 Mar 2025 11:38:31 +0100 Subject: [PATCH] Improve migration handling and error forwarding in deferred statements --- migrator.go | 3 +++ query.go | 4 ++-- query_test.go | 13 +++++++++++-- util.go | 7 +++++++ 4 files changed, 23 insertions(+), 4 deletions(-) create mode 100644 util.go diff --git a/migrator.go b/migrator.go index ce4af77..68bc011 100644 --- a/migrator.go +++ b/migrator.go @@ -74,6 +74,9 @@ func performSingleMigration(d *Db, migrationScript []byte, targetVersion int) er for _, statement := range statements { statement = strings.TrimSpace(statement) + if statement == "" { + continue + } err = tx.Query(statement).Exec() if err != nil { return fmt.Errorf("error performing migration: %v", err) diff --git a/query.go b/query.go index 829b530..7e7aad4 100644 --- a/query.go +++ b/query.go @@ -58,7 +58,7 @@ func (q *Query) Exec() (rerr error) { defer q.unlock() if q.stmt != nil { - defer func() { rerr = q.stmt.Finalize() }() + defer func() { forwardError(q.stmt.Finalize(), &rerr) }() } if q.err != nil { return q.err @@ -84,7 +84,7 @@ func (q *Query) ScanSingle(results ...any) (rerr error) { defer q.unlock() // Scan rows if q.stmt != nil { - defer func() { rerr = q.stmt.Finalize() }() + defer func() { forwardError(q.stmt.Finalize(), &rerr) }() } if q.err != nil { return q.err diff --git a/query_test.go b/query_test.go index 355125e..0ea4ba0 100644 --- a/query_test.go +++ b/query_test.go @@ -73,11 +73,20 @@ func TestUpdateQuery(t *testing.T) { tx := db.MustBegin() defer tx.MustRollback() tx.Query("insert into mytable(key, value) values ('lorem', 'bar')").MustExec() - tx.Query("update mytable set value = 'ipsum' where key = 'lorem'").MustExec() + value := "ipsum" + key := "lorem" + tx.Query("update mytable set value = ? where key = ?").Bind(value, key).MustExec() tx.MustCommit() }() var value string - db.Query("select value from mytable where value = 'ipsum'").MustScanSingle(&value) + db.Query("select value from mytable where key = 'lorem'").MustScanSingle(&value) require.Equal(t, "ipsum", value) } + +func TestUpdateQueryWithWrongArguments(t *testing.T) { + db := openTestDb(t) + value := "ipsum" + err := db.Query("insert into mytable(key, value) values ('lorem', ?)").Bind(&value).Exec() + require.Error(t, err) +} diff --git a/util.go b/util.go new file mode 100644 index 0000000..148c965 --- /dev/null +++ b/util.go @@ -0,0 +1,7 @@ +package mysqlite + +func forwardError(from error, to *error) { + if from != nil { + *to = from + } +}