From 94ae36305aa7b9d0e6519e2c659d978d82d9a42d Mon Sep 17 00:00:00 2001 From: Sebastiaan de Schaetzen Date: Sun, 16 Mar 2025 18:07:58 +0100 Subject: [PATCH] Add support for pointers to pointers of arguments --- query.go | 48 ++++++++++++++++++++++++++++++++++++------------ query_test.go | 20 ++++++++++++++++++++ 2 files changed, 56 insertions(+), 12 deletions(-) diff --git a/query.go b/query.go index 7e7aad4..2c351b9 100644 --- a/query.go +++ b/query.go @@ -173,23 +173,47 @@ func (r *Rows) MustNext() bool { func (r *Rows) Scan(results ...any) error { for i, arg := range results { - if asString, ok := arg.(*string); ok { - *asString = r.query.stmt.ColumnText(i) - } else if asInt, ok := arg.(*int); ok { - *asInt = r.query.stmt.ColumnInt(i) - } else if asBool, ok := arg.(*bool); ok { - *asBool = r.query.stmt.ColumnBool(i) - } else { - if reflect.TypeOf(arg).Kind() != reflect.Ptr { - return fmt.Errorf("unsupported column type %s at index %d (it should be a pointer)", reflect.TypeOf(arg).Name(), i) - } - name := reflect.Indirect(reflect.ValueOf(arg)).Type().Name() - return fmt.Errorf("unsupported column type *%s at index %d", name, i) + err := r.scanArgument(i, arg) + if err != nil { + return err } } return nil } +func (r *Rows) scanArgument(i int, arg any) error { + if asString, ok := arg.(*string); ok { + *asString = r.query.stmt.ColumnText(i) + } else if asInt, ok := arg.(*int); ok { + *asInt = r.query.stmt.ColumnInt(i) + } else if asBool, ok := arg.(*bool); ok { + *asBool = r.query.stmt.ColumnBool(i) + } else if reflect.TypeOf(arg).Kind() == reflect.Ptr && reflect.TypeOf(arg).Elem().Kind() == reflect.Ptr { + return r.handleNullableType(i, arg) + } else { + if reflect.TypeOf(arg).Kind() != reflect.Ptr { + return fmt.Errorf("unsupported column type %s at index %d (it should be a pointer)", reflect.TypeOf(arg).Name(), i) + } + name := reflect.Indirect(reflect.ValueOf(arg)).Type().Name() + return fmt.Errorf("unsupported column type *%s at index %d", name, i) + } + return nil +} + +func (r *Rows) handleNullableType(i int, asPtr any) error { + if r.query.stmt.ColumnIsNull(i) { + reflect.ValueOf(asPtr).Elem().Set(reflect.Zero(reflect.TypeOf(asPtr).Elem())) + } else { + value := reflect.New(reflect.TypeOf(asPtr).Elem().Elem()).Interface() + err := r.scanArgument(i, value) + if err != nil { + return err + } + reflect.ValueOf(asPtr).Elem().Set(reflect.ValueOf(value)) + } + return nil +} + func (r *Rows) MustScan(results ...any) { err := r.Scan(results...) if err != nil { diff --git a/query_test.go b/query_test.go index 0ea4ba0..fe070f8 100644 --- a/query_test.go +++ b/query_test.go @@ -90,3 +90,23 @@ func TestUpdateQueryWithWrongArguments(t *testing.T) { err := db.Query("insert into mytable(key, value) values ('lorem', ?)").Bind(&value).Exec() require.Error(t, err) } + +func TestQueryWithPointerStringArguments(t *testing.T) { + db := openTestDb(t) + var result *string + err := db.Query("select value from mytable where key = 'foo'").ScanSingle(&result) + require.NoError(t, err) + require.NotNil(t, result) + require.Equal(t, "bar", *result) +} + +func TestQueryWithPointerStringArgumentsCanSetToNull(t *testing.T) { + db := openTestDb(t) + db.Query("update mytable set value=NULL where key = 'foo'").MustExec() + myString := "some string" + var result *string + result = &myString + err := db.Query("select value from mytable where key = 'foo'").ScanSingle(&result) + require.NoError(t, err) + require.Nil(t, result) +}