Compare commits
	
		
			3 Commits
		
	
	
		
			93844c04b2
			...
			v0.14.0
		
	
	| Author | SHA1 | Date | |
|---|---|---|---|
| 33f1a94fb2 | |||
| 6880fb5af4 | |||
| e9b12dc3f6 | 
| @@ -3,3 +3,5 @@ package mysqlite | |||||||
| import "errors" | import "errors" | ||||||
|  |  | ||||||
| var ErrNoRows = errors.New("mysqlite: no rows returned") | var ErrNoRows = errors.New("mysqlite: no rows returned") | ||||||
|  | var ErrMissingBind = errors.New("mysqlite: missing bind value") | ||||||
|  | var ErrMissingScan = errors.New("mysqlite: missing scan value") | ||||||
|   | |||||||
							
								
								
									
										30
									
								
								query.go
									
									
									
									
									
								
							
							
						
						
									
										30
									
								
								query.go
									
									
									
									
									
								
							| @@ -11,8 +11,10 @@ type Query struct { | |||||||
| 	stmt *sqlite.Stmt | 	stmt *sqlite.Stmt | ||||||
| 	// Reference to the database. If set, it is assumed that a lock was taken | 	// Reference to the database. If set, it is assumed that a lock was taken | ||||||
| 	// by the query that should be freed by the query. | 	// by the query that should be freed by the query. | ||||||
| 	db  *Db | 	db *Db | ||||||
| 	err error | 	// The number of bound arguments | ||||||
|  | 	binds int | ||||||
|  | 	err   error | ||||||
| } | } | ||||||
|  |  | ||||||
| func (d *Db) Query(query string) *Query { | func (d *Db) Query(query string) *Query { | ||||||
| @@ -46,24 +48,33 @@ func (q *Query) bindInto(into *int, args ...any) *Query { | |||||||
| 		*into++ | 		*into++ | ||||||
| 		if arg == nil { | 		if arg == nil { | ||||||
| 			q.stmt.BindNull(*into) | 			q.stmt.BindNull(*into) | ||||||
|  | 			q.binds++ | ||||||
| 			continue | 			continue | ||||||
| 		} | 		} | ||||||
| 		v := reflect.ValueOf(arg) | 		v := reflect.ValueOf(arg) | ||||||
| 		if v.Kind() == reflect.Ptr { | 		if v.Kind() == reflect.Ptr { | ||||||
| 			if v.IsNil() { | 			if v.IsNil() { | ||||||
| 				q.stmt.BindNull(*into) | 				q.stmt.BindNull(*into) | ||||||
|  | 				q.binds++ | ||||||
| 				continue | 				continue | ||||||
| 			} | 			} | ||||||
| 			arg = v.Elem().Interface() | 			arg = v.Elem().Interface() | ||||||
| 		} | 		} | ||||||
| 		if asString, ok := arg.(string); ok { | 		if asString, ok := arg.(string); ok { | ||||||
| 			q.stmt.BindText(*into, asString) | 			q.stmt.BindText(*into, asString) | ||||||
|  | 			q.binds++ | ||||||
| 		} else if asInt, ok := arg.(int); ok { | 		} else if asInt, ok := arg.(int); ok { | ||||||
| 			q.stmt.BindInt64(*into, int64(asInt)) | 			q.stmt.BindInt64(*into, int64(asInt)) | ||||||
| 		} else if asInt, ok := arg.(int64); ok { | 			q.binds++ | ||||||
| 			q.stmt.BindInt64(*into, asInt) | 		} else if asInt64, ok := arg.(int64); ok { | ||||||
|  | 			q.stmt.BindInt64(*into, asInt64) | ||||||
|  | 			q.binds++ | ||||||
|  | 		} else if asFloat, ok := arg.(float64); ok { | ||||||
|  | 			q.stmt.BindFloat(*into, asFloat) | ||||||
|  | 			q.binds++ | ||||||
| 		} else if asBool, ok := arg.(bool); ok { | 		} else if asBool, ok := arg.(bool); ok { | ||||||
| 			q.stmt.BindBool(*into, asBool) | 			q.stmt.BindBool(*into, asBool) | ||||||
|  | 			q.binds++ | ||||||
| 		} else { | 		} else { | ||||||
| 			// Check if the argument is a slice or array of any type | 			// Check if the argument is a slice or array of any type | ||||||
| 			v = reflect.ValueOf(arg) | 			v = reflect.ValueOf(arg) | ||||||
| @@ -166,6 +177,10 @@ type Rows struct { | |||||||
| } | } | ||||||
|  |  | ||||||
| func (q *Query) ScanMulti() (*Rows, error) { | func (q *Query) ScanMulti() (*Rows, error) { | ||||||
|  | 	if q.binds != q.stmt.BindParamCount() { | ||||||
|  | 		return nil, ErrMissingBind | ||||||
|  | 	} | ||||||
|  |  | ||||||
| 	if q.err != nil { | 	if q.err != nil { | ||||||
| 		return nil, q.err | 		return nil, q.err | ||||||
| 	} | 	} | ||||||
| @@ -203,6 +218,9 @@ func (r *Rows) MustNext() bool { | |||||||
| } | } | ||||||
|  |  | ||||||
| func (r *Rows) Scan(results ...any) error { | func (r *Rows) Scan(results ...any) error { | ||||||
|  | 	if r.query.stmt.ColumnCount() != len(results) { | ||||||
|  | 		return ErrMissingScan | ||||||
|  | 	} | ||||||
| 	for i, arg := range results { | 	for i, arg := range results { | ||||||
| 		err := r.scanArgument(i, arg) | 		err := r.scanArgument(i, arg) | ||||||
| 		if err != nil { | 		if err != nil { | ||||||
| @@ -217,6 +235,10 @@ func (r *Rows) scanArgument(i int, arg any) error { | |||||||
| 		*asString = r.query.stmt.ColumnText(i) | 		*asString = r.query.stmt.ColumnText(i) | ||||||
| 	} else if asInt, ok := arg.(*int); ok { | 	} else if asInt, ok := arg.(*int); ok { | ||||||
| 		*asInt = r.query.stmt.ColumnInt(i) | 		*asInt = r.query.stmt.ColumnInt(i) | ||||||
|  | 	} else if asInt, ok := arg.(*int64); ok { | ||||||
|  | 		*asInt = r.query.stmt.ColumnInt64(i) | ||||||
|  | 	} else if asFloat, ok := arg.(*float64); ok { | ||||||
|  | 		*asFloat = r.query.stmt.ColumnFloat(i) | ||||||
| 	} else if asBool, ok := arg.(*bool); ok { | 	} else if asBool, ok := arg.(*bool); ok { | ||||||
| 		*asBool = r.query.stmt.ColumnBool(i) | 		*asBool = r.query.stmt.ColumnBool(i) | ||||||
| 	} else if reflect.TypeOf(arg).Kind() == reflect.Ptr && reflect.TypeOf(arg).Elem().Kind() == reflect.Ptr { | 	} else if reflect.TypeOf(arg).Kind() == reflect.Ptr && reflect.TypeOf(arg).Elem().Kind() == reflect.Ptr { | ||||||
|   | |||||||
| @@ -28,6 +28,23 @@ func TestSimpleQueryWithNoResults(t *testing.T) { | |||||||
| 	require.True(t, errors.Is(err, ErrNoRows)) | 	require.True(t, errors.Is(err, ErrNoRows)) | ||||||
| } | } | ||||||
|  |  | ||||||
|  | func TestScanWithMissingValues(t *testing.T) { | ||||||
|  | 	db := openTestDb(t) | ||||||
|  | 	var count int | ||||||
|  | 	err := db.Query("select 1, 2").ScanSingle(&count) | ||||||
|  | 	require.Equal(t, ErrMissingScan, err) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func TestBindInt64(t *testing.T) { | ||||||
|  | 	db := openTestDb(t) | ||||||
|  | 	var value int64 | ||||||
|  | 	var result int64 | ||||||
|  | 	value = 5 | ||||||
|  | 	err := db.Query("select ?").Bind(&value).ScanSingle(&result) | ||||||
|  | 	require.NoError(t, err) | ||||||
|  | 	require.Equal(t, int64(5), result) | ||||||
|  | } | ||||||
|  |  | ||||||
| func TestSimpleQueryWithArgs(t *testing.T) { | func TestSimpleQueryWithArgs(t *testing.T) { | ||||||
| 	db := openTestDb(t) | 	db := openTestDb(t) | ||||||
| 	var value string | 	var value string | ||||||
| @@ -167,6 +184,31 @@ func TestQueryWithPointerStringArguments(t *testing.T) { | |||||||
| 	require.Equal(t, "bar", *result) | 	require.Equal(t, "bar", *result) | ||||||
| } | } | ||||||
|  |  | ||||||
|  | func TestQueryWithInt64Scan(t *testing.T) { | ||||||
|  | 	db := openTestDb(t) | ||||||
|  | 	var result int64 | ||||||
|  | 	err := db.Query("select 2").ScanSingle(&result) | ||||||
|  | 	require.NoError(t, err) | ||||||
|  | 	require.NotNil(t, result) | ||||||
|  | 	require.Equal(t, int64(2), result) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func TestQueryWithFloat64Scan(t *testing.T) { | ||||||
|  | 	db := openTestDb(t) | ||||||
|  | 	var result float64 | ||||||
|  | 	err := db.Query("select 2.5").ScanSingle(&result) | ||||||
|  | 	require.NoError(t, err) | ||||||
|  | 	require.NotNil(t, result) | ||||||
|  | 	require.InDelta(t, 2.5, result, 0.001) | ||||||
|  | } | ||||||
|  |  | ||||||
|  | func TestQueryWithMissingBinds(t *testing.T) { | ||||||
|  | 	db := openTestDb(t) | ||||||
|  | 	var result float64 | ||||||
|  | 	err := db.Query("select ?").ScanSingle(&result) | ||||||
|  | 	require.ErrorIs(t, err, ErrMissingBind) | ||||||
|  | } | ||||||
|  |  | ||||||
| func TestQueryWithPointerStringArgumentsCanSetToNull(t *testing.T) { | func TestQueryWithPointerStringArgumentsCanSetToNull(t *testing.T) { | ||||||
| 	db := openTestDb(t) | 	db := openTestDb(t) | ||||||
| 	db.Query("update mytable set value=null where key = 'foo'").MustExec() | 	db.Query("update mytable set value=null where key = 'foo'").MustExec() | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user