From 278d7ed49723837d20b4c279f39810584f46d59f Mon Sep 17 00:00:00 2001 From: Sebastiaan de Schaetzen Date: Thu, 8 May 2025 19:42:57 +0200 Subject: [PATCH] Add support for pointer arguments --- query.go | 10 +++++++++- query_test.go | 42 ++++++++++++++++++++++++++++++++++++++++-- 2 files changed, 49 insertions(+), 3 deletions(-) diff --git a/query.go b/query.go index 362d0c9..6d763a6 100644 --- a/query.go +++ b/query.go @@ -44,6 +44,14 @@ func (q *Query) bindInto(into *int, args ...any) *Query { } for i, arg := range args { *into++ + if arg == nil { + q.stmt.BindNull(*into) + continue + } + v := reflect.ValueOf(arg) + if v.Kind() == reflect.Ptr { + arg = v.Elem().Interface() + } if asString, ok := arg.(string); ok { q.stmt.BindText(*into, asString) } else if asInt, ok := arg.(int); ok { @@ -54,7 +62,7 @@ func (q *Query) bindInto(into *int, args ...any) *Query { q.stmt.BindBool(*into, asBool) } else { // Check if the argument is a slice or array of any type - v := reflect.ValueOf(arg) + v = reflect.ValueOf(arg) if v.Kind() == reflect.Slice || v.Kind() == reflect.Array { *into-- for i := 0; i < v.Len(); i++ { diff --git a/query_test.go b/query_test.go index 14144f9..71cd9b8 100644 --- a/query_test.go +++ b/query_test.go @@ -94,12 +94,50 @@ func TestUpdateQuery(t *testing.T) { } func TestUpdateQueryWithWrongArguments(t *testing.T) { + type S struct { + Field string + } db := openTestDb(t) - value := "ipsum" - err := db.Query("insert into mytable(key, value) values ('lorem', ?)").Bind(&value).Exec() + abc := S{ + Field: "ipsum", + } + err := db.Query("insert into mytable(key, value) values ('lorem', ?)").Bind(abc).Exec() require.Error(t, err) } +func TestUpdateQueryWithPointerValue(t *testing.T) { + db := openTestDb(t) + func() { + tx := db.MustBegin() + defer tx.MustRollback() + tx.Query("insert into mytable(key, value) values ('lorem', 'bar')").MustExec() + value := "ipsum" + key := "lorem" + tx.Query("update mytable set value = ? where key = ?").Bind(&value, key).MustExec() + tx.MustCommit() + }() + + var value string + db.Query("select value from mytable where key = 'lorem'").MustScanSingle(&value) + require.Equal(t, "ipsum", value) +} + +func TestUpdateQueryWithNullValue(t *testing.T) { + db := openTestDb(t) + func() { + tx := db.MustBegin() + defer tx.MustRollback() + tx.Query("insert into mytable(key, value) values ('lorem', 'bar')").MustExec() + key := "lorem" + tx.Query("update mytable set value = ? where key = ?").Bind(nil, key).MustExec() + tx.MustCommit() + }() + + var value *string + db.Query("select value from mytable where key = 'lorem'").MustScanSingle(&value) + require.Nil(t, value) +} + func TestQueryWithPointerStringArguments(t *testing.T) { db := openTestDb(t) var result *string -- 2.47.2