diff --git a/errors.go b/errors.go index d691fe5..fa4ce90 100644 --- a/errors.go +++ b/errors.go @@ -3,3 +3,4 @@ package mysqlite import "errors" var ErrNoRows = errors.New("mysqlite: no rows returned") +var ErrMissingBind = errors.New("mysqlite: missing bind value") diff --git a/query.go b/query.go index 6b7d626..e674b4b 100644 --- a/query.go +++ b/query.go @@ -11,8 +11,10 @@ type Query struct { stmt *sqlite.Stmt // Reference to the database. If set, it is assumed that a lock was taken // by the query that should be freed by the query. - db *Db - err error + db *Db + // The number of bound arguments + binds int + err error } func (d *Db) Query(query string) *Query { @@ -46,24 +48,30 @@ func (q *Query) bindInto(into *int, args ...any) *Query { *into++ if arg == nil { q.stmt.BindNull(*into) + q.binds++ continue } v := reflect.ValueOf(arg) if v.Kind() == reflect.Ptr { if v.IsNil() { q.stmt.BindNull(*into) + q.binds++ continue } arg = v.Elem().Interface() } if asString, ok := arg.(string); ok { q.stmt.BindText(*into, asString) + q.binds++ } else if asInt, ok := arg.(int); ok { q.stmt.BindInt64(*into, int64(asInt)) - } else if asInt, ok := arg.(int64); ok { - q.stmt.BindInt64(*into, asInt) + q.binds++ + } else if asFloat, ok := arg.(float64); ok { + q.stmt.BindFloat(*into, asFloat) + q.binds++ } else if asBool, ok := arg.(bool); ok { q.stmt.BindBool(*into, asBool) + q.binds++ } else { // Check if the argument is a slice or array of any type v = reflect.ValueOf(arg) @@ -166,6 +174,10 @@ type Rows struct { } func (q *Query) ScanMulti() (*Rows, error) { + if q.binds != q.stmt.BindParamCount() { + return nil, ErrMissingBind + } + if q.err != nil { return nil, q.err } @@ -219,6 +231,8 @@ func (r *Rows) scanArgument(i int, arg any) error { *asInt = r.query.stmt.ColumnInt(i) } else if asInt, ok := arg.(*int64); ok { *asInt = r.query.stmt.ColumnInt64(i) + } else if asFloat, ok := arg.(*float64); ok { + *asFloat = r.query.stmt.ColumnFloat(i) } else if asBool, ok := arg.(*bool); ok { *asBool = r.query.stmt.ColumnBool(i) } else if reflect.TypeOf(arg).Kind() == reflect.Ptr && reflect.TypeOf(arg).Elem().Kind() == reflect.Ptr { diff --git a/query_test.go b/query_test.go index 66397fa..4c8eb89 100644 --- a/query_test.go +++ b/query_test.go @@ -176,6 +176,22 @@ func TestQueryWithInt64Scan(t *testing.T) { require.Equal(t, int64(2), result) } +func TestQueryWithFloat64Scan(t *testing.T) { + db := openTestDb(t) + var result float64 + err := db.Query("select 2.5").ScanSingle(&result) + require.NoError(t, err) + require.NotNil(t, result) + require.InDelta(t, 2.5, result, 0.001) +} + +func TestQueryWithMissingBinds(t *testing.T) { + db := openTestDb(t) + var result float64 + err := db.Query("select ?").ScanSingle(&result) + require.ErrorIs(t, err, ErrMissingBind) +} + func TestQueryWithPointerStringArgumentsCanSetToNull(t *testing.T) { db := openTestDb(t) db.Query("update mytable set value=null where key = 'foo'").MustExec()