From 637797116dfc4254ba16b99f33d486e6552cfb72 Mon Sep 17 00:00:00 2001 From: Sebastiaan de Schaetzen Date: Thu, 8 May 2025 19:42:57 +0200 Subject: [PATCH] Add support for pointer arguments --- query.go | 10 +++++++++- query_test.go | 33 +++++++++++++++++++++++++++++++++ 2 files changed, 42 insertions(+), 1 deletion(-) diff --git a/query.go b/query.go index 362d0c9..992f935 100644 --- a/query.go +++ b/query.go @@ -44,6 +44,14 @@ func (q *Query) bindInto(into *int, args ...any) *Query { } for i, arg := range args { *into++ + v := reflect.ValueOf(arg) + if v.Kind() == reflect.Ptr { + if v.IsNil() { + q.stmt.BindNull(*into) + continue + } + arg = v.Elem().Interface() // Dereference the pointer + } if asString, ok := arg.(string); ok { q.stmt.BindText(*into, asString) } else if asInt, ok := arg.(int); ok { @@ -54,7 +62,7 @@ func (q *Query) bindInto(into *int, args ...any) *Query { q.stmt.BindBool(*into, asBool) } else { // Check if the argument is a slice or array of any type - v := reflect.ValueOf(arg) + v = reflect.ValueOf(arg) if v.Kind() == reflect.Slice || v.Kind() == reflect.Array { *into-- for i := 0; i < v.Len(); i++ { diff --git a/query_test.go b/query_test.go index 14144f9..4c07ecb 100644 --- a/query_test.go +++ b/query_test.go @@ -100,6 +100,39 @@ func TestUpdateQueryWithWrongArguments(t *testing.T) { require.Error(t, err) } +func TestUpdateQueryWithPointerValue(t *testing.T) { + db := openTestDb(t) + func() { + tx := db.MustBegin() + defer tx.MustRollback() + tx.Query("insert into mytable(key, value) values ('lorem', 'bar')").MustExec() + value := "ipsum" + key := "lorem" + tx.Query("update mytable set value = ? where key = ?").Bind(&value, key).MustExec() + tx.MustCommit() + }() + + var value string + db.Query("select value from mytable where key = 'lorem'").MustScanSingle(&value) + require.Equal(t, "ipsum", value) +} + +func TestUpdateQueryWithNullValue(t *testing.T) { + db := openTestDb(t) + func() { + tx := db.MustBegin() + defer tx.MustRollback() + tx.Query("insert into mytable(key, value) values ('lorem', 'bar')").MustExec() + key := "lorem" + tx.Query("update mytable set value = ? where key = ?").Bind(nil, key).MustExec() + tx.MustCommit() + }() + + var value string + db.Query("select value from mytable where key = 'lorem'").MustScanSingle(&value) + require.Nil(t, value) +} + func TestQueryWithPointerStringArguments(t *testing.T) { db := openTestDb(t) var result *string