From 9d4bf1e6e26604417a6ef9fa95c72778cdc0aad0 Mon Sep 17 00:00:00 2001
From: Sebastiaan de Schaetzen <sebastiaan.de.schaetzen@gmail.com>
Date: Sat, 17 May 2025 17:42:23 +0200
Subject: [PATCH] Add support for float64 and errors when missing binds

---
 errors.go     |  1 +
 query.go      | 22 ++++++++++++++++++----
 query_test.go | 16 ++++++++++++++++
 3 files changed, 35 insertions(+), 4 deletions(-)

diff --git a/errors.go b/errors.go
index d691fe5..fa4ce90 100644
--- a/errors.go
+++ b/errors.go
@@ -3,3 +3,4 @@ package mysqlite
 import "errors"
 
 var ErrNoRows = errors.New("mysqlite: no rows returned")
+var ErrMissingBind = errors.New("mysqlite: missing bind value")
diff --git a/query.go b/query.go
index 6b7d626..e674b4b 100644
--- a/query.go
+++ b/query.go
@@ -11,8 +11,10 @@ type Query struct {
 	stmt *sqlite.Stmt
 	// Reference to the database. If set, it is assumed that a lock was taken
 	// by the query that should be freed by the query.
-	db  *Db
-	err error
+	db *Db
+	// The number of bound arguments
+	binds int
+	err   error
 }
 
 func (d *Db) Query(query string) *Query {
@@ -46,24 +48,30 @@ func (q *Query) bindInto(into *int, args ...any) *Query {
 		*into++
 		if arg == nil {
 			q.stmt.BindNull(*into)
+			q.binds++
 			continue
 		}
 		v := reflect.ValueOf(arg)
 		if v.Kind() == reflect.Ptr {
 			if v.IsNil() {
 				q.stmt.BindNull(*into)
+				q.binds++
 				continue
 			}
 			arg = v.Elem().Interface()
 		}
 		if asString, ok := arg.(string); ok {
 			q.stmt.BindText(*into, asString)
+			q.binds++
 		} else if asInt, ok := arg.(int); ok {
 			q.stmt.BindInt64(*into, int64(asInt))
-		} else if asInt, ok := arg.(int64); ok {
-			q.stmt.BindInt64(*into, asInt)
+			q.binds++
+		} else if asFloat, ok := arg.(float64); ok {
+			q.stmt.BindFloat(*into, asFloat)
+			q.binds++
 		} else if asBool, ok := arg.(bool); ok {
 			q.stmt.BindBool(*into, asBool)
+			q.binds++
 		} else {
 			// Check if the argument is a slice or array of any type
 			v = reflect.ValueOf(arg)
@@ -166,6 +174,10 @@ type Rows struct {
 }
 
 func (q *Query) ScanMulti() (*Rows, error) {
+	if q.binds != q.stmt.BindParamCount() {
+		return nil, ErrMissingBind
+	}
+
 	if q.err != nil {
 		return nil, q.err
 	}
@@ -219,6 +231,8 @@ func (r *Rows) scanArgument(i int, arg any) error {
 		*asInt = r.query.stmt.ColumnInt(i)
 	} else if asInt, ok := arg.(*int64); ok {
 		*asInt = r.query.stmt.ColumnInt64(i)
+	} else if asFloat, ok := arg.(*float64); ok {
+		*asFloat = r.query.stmt.ColumnFloat(i)
 	} else if asBool, ok := arg.(*bool); ok {
 		*asBool = r.query.stmt.ColumnBool(i)
 	} else if reflect.TypeOf(arg).Kind() == reflect.Ptr && reflect.TypeOf(arg).Elem().Kind() == reflect.Ptr {
diff --git a/query_test.go b/query_test.go
index 66397fa..4c8eb89 100644
--- a/query_test.go
+++ b/query_test.go
@@ -176,6 +176,22 @@ func TestQueryWithInt64Scan(t *testing.T) {
 	require.Equal(t, int64(2), result)
 }
 
+func TestQueryWithFloat64Scan(t *testing.T) {
+	db := openTestDb(t)
+	var result float64
+	err := db.Query("select 2.5").ScanSingle(&result)
+	require.NoError(t, err)
+	require.NotNil(t, result)
+	require.InDelta(t, 2.5, result, 0.001)
+}
+
+func TestQueryWithMissingBinds(t *testing.T) {
+	db := openTestDb(t)
+	var result float64
+	err := db.Query("select ?").ScanSingle(&result)
+	require.ErrorIs(t, err, ErrMissingBind)
+}
+
 func TestQueryWithPointerStringArgumentsCanSetToNull(t *testing.T) {
 	db := openTestDb(t)
 	db.Query("update mytable set value=null where key = 'foo'").MustExec()
-- 
2.47.2