Compare commits
	
		
			2 Commits
		
	
	
		
	
	| Author | SHA1 | Date | |
|---|---|---|---|
| 33f1a94fb2 | |||
| 6880fb5af4 | 
| @@ -3,3 +3,5 @@ package mysqlite | ||||
| import "errors" | ||||
|  | ||||
| var ErrNoRows = errors.New("mysqlite: no rows returned") | ||||
| var ErrMissingBind = errors.New("mysqlite: missing bind value") | ||||
| var ErrMissingScan = errors.New("mysqlite: missing scan value") | ||||
|   | ||||
							
								
								
									
										28
									
								
								query.go
									
									
									
									
									
								
							
							
						
						
									
										28
									
								
								query.go
									
									
									
									
									
								
							| @@ -11,8 +11,10 @@ type Query struct { | ||||
| 	stmt *sqlite.Stmt | ||||
| 	// Reference to the database. If set, it is assumed that a lock was taken | ||||
| 	// by the query that should be freed by the query. | ||||
| 	db  *Db | ||||
| 	err error | ||||
| 	db *Db | ||||
| 	// The number of bound arguments | ||||
| 	binds int | ||||
| 	err   error | ||||
| } | ||||
|  | ||||
| func (d *Db) Query(query string) *Query { | ||||
| @@ -46,24 +48,33 @@ func (q *Query) bindInto(into *int, args ...any) *Query { | ||||
| 		*into++ | ||||
| 		if arg == nil { | ||||
| 			q.stmt.BindNull(*into) | ||||
| 			q.binds++ | ||||
| 			continue | ||||
| 		} | ||||
| 		v := reflect.ValueOf(arg) | ||||
| 		if v.Kind() == reflect.Ptr { | ||||
| 			if v.IsNil() { | ||||
| 				q.stmt.BindNull(*into) | ||||
| 				q.binds++ | ||||
| 				continue | ||||
| 			} | ||||
| 			arg = v.Elem().Interface() | ||||
| 		} | ||||
| 		if asString, ok := arg.(string); ok { | ||||
| 			q.stmt.BindText(*into, asString) | ||||
| 			q.binds++ | ||||
| 		} else if asInt, ok := arg.(int); ok { | ||||
| 			q.stmt.BindInt64(*into, int64(asInt)) | ||||
| 		} else if asInt, ok := arg.(int64); ok { | ||||
| 			q.stmt.BindInt64(*into, asInt) | ||||
| 			q.binds++ | ||||
| 		} else if asInt64, ok := arg.(int64); ok { | ||||
| 			q.stmt.BindInt64(*into, asInt64) | ||||
| 			q.binds++ | ||||
| 		} else if asFloat, ok := arg.(float64); ok { | ||||
| 			q.stmt.BindFloat(*into, asFloat) | ||||
| 			q.binds++ | ||||
| 		} else if asBool, ok := arg.(bool); ok { | ||||
| 			q.stmt.BindBool(*into, asBool) | ||||
| 			q.binds++ | ||||
| 		} else { | ||||
| 			// Check if the argument is a slice or array of any type | ||||
| 			v = reflect.ValueOf(arg) | ||||
| @@ -166,6 +177,10 @@ type Rows struct { | ||||
| } | ||||
|  | ||||
| func (q *Query) ScanMulti() (*Rows, error) { | ||||
| 	if q.binds != q.stmt.BindParamCount() { | ||||
| 		return nil, ErrMissingBind | ||||
| 	} | ||||
|  | ||||
| 	if q.err != nil { | ||||
| 		return nil, q.err | ||||
| 	} | ||||
| @@ -203,6 +218,9 @@ func (r *Rows) MustNext() bool { | ||||
| } | ||||
|  | ||||
| func (r *Rows) Scan(results ...any) error { | ||||
| 	if r.query.stmt.ColumnCount() != len(results) { | ||||
| 		return ErrMissingScan | ||||
| 	} | ||||
| 	for i, arg := range results { | ||||
| 		err := r.scanArgument(i, arg) | ||||
| 		if err != nil { | ||||
| @@ -219,6 +237,8 @@ func (r *Rows) scanArgument(i int, arg any) error { | ||||
| 		*asInt = r.query.stmt.ColumnInt(i) | ||||
| 	} else if asInt, ok := arg.(*int64); ok { | ||||
| 		*asInt = r.query.stmt.ColumnInt64(i) | ||||
| 	} else if asFloat, ok := arg.(*float64); ok { | ||||
| 		*asFloat = r.query.stmt.ColumnFloat(i) | ||||
| 	} else if asBool, ok := arg.(*bool); ok { | ||||
| 		*asBool = r.query.stmt.ColumnBool(i) | ||||
| 	} else if reflect.TypeOf(arg).Kind() == reflect.Ptr && reflect.TypeOf(arg).Elem().Kind() == reflect.Ptr { | ||||
|   | ||||
| @@ -28,6 +28,23 @@ func TestSimpleQueryWithNoResults(t *testing.T) { | ||||
| 	require.True(t, errors.Is(err, ErrNoRows)) | ||||
| } | ||||
|  | ||||
| func TestScanWithMissingValues(t *testing.T) { | ||||
| 	db := openTestDb(t) | ||||
| 	var count int | ||||
| 	err := db.Query("select 1, 2").ScanSingle(&count) | ||||
| 	require.Equal(t, ErrMissingScan, err) | ||||
| } | ||||
|  | ||||
| func TestBindInt64(t *testing.T) { | ||||
| 	db := openTestDb(t) | ||||
| 	var value int64 | ||||
| 	var result int64 | ||||
| 	value = 5 | ||||
| 	err := db.Query("select ?").Bind(&value).ScanSingle(&result) | ||||
| 	require.NoError(t, err) | ||||
| 	require.Equal(t, int64(5), result) | ||||
| } | ||||
|  | ||||
| func TestSimpleQueryWithArgs(t *testing.T) { | ||||
| 	db := openTestDb(t) | ||||
| 	var value string | ||||
| @@ -176,6 +193,22 @@ func TestQueryWithInt64Scan(t *testing.T) { | ||||
| 	require.Equal(t, int64(2), result) | ||||
| } | ||||
|  | ||||
| func TestQueryWithFloat64Scan(t *testing.T) { | ||||
| 	db := openTestDb(t) | ||||
| 	var result float64 | ||||
| 	err := db.Query("select 2.5").ScanSingle(&result) | ||||
| 	require.NoError(t, err) | ||||
| 	require.NotNil(t, result) | ||||
| 	require.InDelta(t, 2.5, result, 0.001) | ||||
| } | ||||
|  | ||||
| func TestQueryWithMissingBinds(t *testing.T) { | ||||
| 	db := openTestDb(t) | ||||
| 	var result float64 | ||||
| 	err := db.Query("select ?").ScanSingle(&result) | ||||
| 	require.ErrorIs(t, err, ErrMissingBind) | ||||
| } | ||||
|  | ||||
| func TestQueryWithPointerStringArgumentsCanSetToNull(t *testing.T) { | ||||
| 	db := openTestDb(t) | ||||
| 	db.Query("update mytable set value=null where key = 'foo'").MustExec() | ||||
|   | ||||
		Reference in New Issue
	
	Block a user