Compare commits
	
		
			7 Commits
		
	
	
		
	
	| Author | SHA1 | Date | |
|---|---|---|---|
| 33f1a94fb2 | |||
| 6880fb5af4 | |||
| e9b12dc3f6 | |||
| 7daf1915a5 | |||
| 029cf6ce01 | |||
| 278d7ed497 | |||
| dd6be6b9b6 | 
							
								
								
									
										7
									
								
								errors.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										7
									
								
								errors.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,7 @@ | ||||
| package mysqlite | ||||
|  | ||||
| import "errors" | ||||
|  | ||||
| var ErrNoRows = errors.New("mysqlite: no rows returned") | ||||
| var ErrMissingBind = errors.New("mysqlite: missing bind value") | ||||
| var ErrMissingScan = errors.New("mysqlite: missing scan value") | ||||
							
								
								
									
										46
									
								
								query.go
									
									
									
									
									
								
							
							
						
						
									
										46
									
								
								query.go
									
									
									
									
									
								
							| @@ -11,8 +11,10 @@ type Query struct { | ||||
| 	stmt *sqlite.Stmt | ||||
| 	// Reference to the database. If set, it is assumed that a lock was taken | ||||
| 	// by the query that should be freed by the query. | ||||
| 	db  *Db | ||||
| 	err error | ||||
| 	db *Db | ||||
| 	// The number of bound arguments | ||||
| 	binds int | ||||
| 	err   error | ||||
| } | ||||
|  | ||||
| func (d *Db) Query(query string) *Query { | ||||
| @@ -44,17 +46,38 @@ func (q *Query) bindInto(into *int, args ...any) *Query { | ||||
| 	} | ||||
| 	for i, arg := range args { | ||||
| 		*into++ | ||||
| 		if arg == nil { | ||||
| 			q.stmt.BindNull(*into) | ||||
| 			q.binds++ | ||||
| 			continue | ||||
| 		} | ||||
| 		v := reflect.ValueOf(arg) | ||||
| 		if v.Kind() == reflect.Ptr { | ||||
| 			if v.IsNil() { | ||||
| 				q.stmt.BindNull(*into) | ||||
| 				q.binds++ | ||||
| 				continue | ||||
| 			} | ||||
| 			arg = v.Elem().Interface() | ||||
| 		} | ||||
| 		if asString, ok := arg.(string); ok { | ||||
| 			q.stmt.BindText(*into, asString) | ||||
| 			q.binds++ | ||||
| 		} else if asInt, ok := arg.(int); ok { | ||||
| 			q.stmt.BindInt64(*into, int64(asInt)) | ||||
| 		} else if asInt, ok := arg.(int64); ok { | ||||
| 			q.stmt.BindInt64(*into, asInt) | ||||
| 			q.binds++ | ||||
| 		} else if asInt64, ok := arg.(int64); ok { | ||||
| 			q.stmt.BindInt64(*into, asInt64) | ||||
| 			q.binds++ | ||||
| 		} else if asFloat, ok := arg.(float64); ok { | ||||
| 			q.stmt.BindFloat(*into, asFloat) | ||||
| 			q.binds++ | ||||
| 		} else if asBool, ok := arg.(bool); ok { | ||||
| 			q.stmt.BindBool(*into, asBool) | ||||
| 			q.binds++ | ||||
| 		} else { | ||||
| 			// Check if the argument is a slice or array of any type | ||||
| 			v := reflect.ValueOf(arg) | ||||
| 			v = reflect.ValueOf(arg) | ||||
| 			if v.Kind() == reflect.Slice || v.Kind() == reflect.Array { | ||||
| 				*into-- | ||||
| 				for i := 0; i < v.Len(); i++ { | ||||
| @@ -116,7 +139,7 @@ func (q *Query) ScanSingle(results ...any) (rerr error) { | ||||
| 		return err | ||||
| 	} | ||||
| 	if !hasResult { | ||||
| 		return fmt.Errorf("did not return any rows") | ||||
| 		return ErrNoRows | ||||
| 	} | ||||
|  | ||||
| 	// Scan its columns | ||||
| @@ -154,6 +177,10 @@ type Rows struct { | ||||
| } | ||||
|  | ||||
| func (q *Query) ScanMulti() (*Rows, error) { | ||||
| 	if q.binds != q.stmt.BindParamCount() { | ||||
| 		return nil, ErrMissingBind | ||||
| 	} | ||||
|  | ||||
| 	if q.err != nil { | ||||
| 		return nil, q.err | ||||
| 	} | ||||
| @@ -191,6 +218,9 @@ func (r *Rows) MustNext() bool { | ||||
| } | ||||
|  | ||||
| func (r *Rows) Scan(results ...any) error { | ||||
| 	if r.query.stmt.ColumnCount() != len(results) { | ||||
| 		return ErrMissingScan | ||||
| 	} | ||||
| 	for i, arg := range results { | ||||
| 		err := r.scanArgument(i, arg) | ||||
| 		if err != nil { | ||||
| @@ -205,6 +235,10 @@ func (r *Rows) scanArgument(i int, arg any) error { | ||||
| 		*asString = r.query.stmt.ColumnText(i) | ||||
| 	} else if asInt, ok := arg.(*int); ok { | ||||
| 		*asInt = r.query.stmt.ColumnInt(i) | ||||
| 	} else if asInt, ok := arg.(*int64); ok { | ||||
| 		*asInt = r.query.stmt.ColumnInt64(i) | ||||
| 	} else if asFloat, ok := arg.(*float64); ok { | ||||
| 		*asFloat = r.query.stmt.ColumnFloat(i) | ||||
| 	} else if asBool, ok := arg.(*bool); ok { | ||||
| 		*asBool = r.query.stmt.ColumnBool(i) | ||||
| 	} else if reflect.TypeOf(arg).Kind() == reflect.Ptr && reflect.TypeOf(arg).Elem().Kind() == reflect.Ptr { | ||||
|   | ||||
							
								
								
									
										115
									
								
								query_test.go
									
									
									
									
									
								
							
							
						
						
									
										115
									
								
								query_test.go
									
									
									
									
									
								
							| @@ -1,6 +1,7 @@ | ||||
| package mysqlite | ||||
|  | ||||
| import ( | ||||
| 	"errors" | ||||
| 	"github.com/stretchr/testify/require" | ||||
| 	"testing" | ||||
| ) | ||||
| @@ -19,6 +20,31 @@ func TestSimpleQuery(t *testing.T) { | ||||
| 	require.Equal(t, 1, count, "expected empty count") | ||||
| } | ||||
|  | ||||
| func TestSimpleQueryWithNoResults(t *testing.T) { | ||||
| 	db := openTestDb(t) | ||||
| 	var count int | ||||
| 	err := db.Query("select 1 from mytable where key=999").ScanSingle(&count) | ||||
| 	require.Equal(t, ErrNoRows, err) | ||||
| 	require.True(t, errors.Is(err, ErrNoRows)) | ||||
| } | ||||
|  | ||||
| func TestScanWithMissingValues(t *testing.T) { | ||||
| 	db := openTestDb(t) | ||||
| 	var count int | ||||
| 	err := db.Query("select 1, 2").ScanSingle(&count) | ||||
| 	require.Equal(t, ErrMissingScan, err) | ||||
| } | ||||
|  | ||||
| func TestBindInt64(t *testing.T) { | ||||
| 	db := openTestDb(t) | ||||
| 	var value int64 | ||||
| 	var result int64 | ||||
| 	value = 5 | ||||
| 	err := db.Query("select ?").Bind(&value).ScanSingle(&result) | ||||
| 	require.NoError(t, err) | ||||
| 	require.Equal(t, int64(5), result) | ||||
| } | ||||
|  | ||||
| func TestSimpleQueryWithArgs(t *testing.T) { | ||||
| 	db := openTestDb(t) | ||||
| 	var value string | ||||
| @@ -85,12 +111,70 @@ func TestUpdateQuery(t *testing.T) { | ||||
| } | ||||
|  | ||||
| func TestUpdateQueryWithWrongArguments(t *testing.T) { | ||||
| 	type S struct { | ||||
| 		Field string | ||||
| 	} | ||||
| 	db := openTestDb(t) | ||||
| 	value := "ipsum" | ||||
| 	err := db.Query("insert into mytable(key, value) values ('lorem', ?)").Bind(&value).Exec() | ||||
| 	abc := S{ | ||||
| 		Field: "ipsum", | ||||
| 	} | ||||
| 	err := db.Query("insert into mytable(key, value) values ('lorem', ?)").Bind(abc).Exec() | ||||
| 	require.Error(t, err) | ||||
| } | ||||
|  | ||||
| func TestUpdateQueryWithPointerValue(t *testing.T) { | ||||
| 	db := openTestDb(t) | ||||
| 	func() { | ||||
| 		tx := db.MustBegin() | ||||
| 		defer tx.MustRollback() | ||||
| 		tx.Query("insert into mytable(key, value) values ('lorem', 'bar')").MustExec() | ||||
| 		value := "ipsum" | ||||
| 		key := "lorem" | ||||
| 		tx.Query("update mytable set value = ? where key = ?").Bind(&value, key).MustExec() | ||||
| 		tx.MustCommit() | ||||
| 	}() | ||||
|  | ||||
| 	var value string | ||||
| 	db.Query("select value from mytable where key = 'lorem'").MustScanSingle(&value) | ||||
| 	require.Equal(t, "ipsum", value) | ||||
| } | ||||
|  | ||||
| func TestUpdateQueryWithSetPointerValue(t *testing.T) { | ||||
| 	type S struct { | ||||
| 		value *string | ||||
| 	} | ||||
| 	db := openTestDb(t) | ||||
| 	func() { | ||||
| 		tx := db.MustBegin() | ||||
| 		defer tx.MustRollback() | ||||
| 		tx.Query("insert into mytable(key, value) values ('lorem', 'bar')").MustExec() | ||||
| 		s := S{nil} | ||||
| 		key := "lorem" | ||||
| 		tx.Query("update mytable set value = ? where key = ?").Bind(s.value, key).MustExec() | ||||
| 		tx.MustCommit() | ||||
| 	}() | ||||
|  | ||||
| 	var value *string | ||||
| 	db.Query("select value from mytable where key = 'lorem'").MustScanSingle(&value) | ||||
| 	require.Equal(t, (*string)(nil), value) | ||||
| } | ||||
|  | ||||
| func TestUpdateQueryWithNullValue(t *testing.T) { | ||||
| 	db := openTestDb(t) | ||||
| 	func() { | ||||
| 		tx := db.MustBegin() | ||||
| 		defer tx.MustRollback() | ||||
| 		tx.Query("insert into mytable(key, value) values ('lorem', 'bar')").MustExec() | ||||
| 		key := "lorem" | ||||
| 		tx.Query("update mytable set value = ? where key = ?").Bind(nil, key).MustExec() | ||||
| 		tx.MustCommit() | ||||
| 	}() | ||||
|  | ||||
| 	var value *string | ||||
| 	db.Query("select value from mytable where key = 'lorem'").MustScanSingle(&value) | ||||
| 	require.Nil(t, value) | ||||
| } | ||||
|  | ||||
| func TestQueryWithPointerStringArguments(t *testing.T) { | ||||
| 	db := openTestDb(t) | ||||
| 	var result *string | ||||
| @@ -100,9 +184,34 @@ func TestQueryWithPointerStringArguments(t *testing.T) { | ||||
| 	require.Equal(t, "bar", *result) | ||||
| } | ||||
|  | ||||
| func TestQueryWithInt64Scan(t *testing.T) { | ||||
| 	db := openTestDb(t) | ||||
| 	var result int64 | ||||
| 	err := db.Query("select 2").ScanSingle(&result) | ||||
| 	require.NoError(t, err) | ||||
| 	require.NotNil(t, result) | ||||
| 	require.Equal(t, int64(2), result) | ||||
| } | ||||
|  | ||||
| func TestQueryWithFloat64Scan(t *testing.T) { | ||||
| 	db := openTestDb(t) | ||||
| 	var result float64 | ||||
| 	err := db.Query("select 2.5").ScanSingle(&result) | ||||
| 	require.NoError(t, err) | ||||
| 	require.NotNil(t, result) | ||||
| 	require.InDelta(t, 2.5, result, 0.001) | ||||
| } | ||||
|  | ||||
| func TestQueryWithMissingBinds(t *testing.T) { | ||||
| 	db := openTestDb(t) | ||||
| 	var result float64 | ||||
| 	err := db.Query("select ?").ScanSingle(&result) | ||||
| 	require.ErrorIs(t, err, ErrMissingBind) | ||||
| } | ||||
|  | ||||
| func TestQueryWithPointerStringArgumentsCanSetToNull(t *testing.T) { | ||||
| 	db := openTestDb(t) | ||||
| 	db.Query("update mytable set value=NULL where key = 'foo'").MustExec() | ||||
| 	db.Query("update mytable set value=null where key = 'foo'").MustExec() | ||||
| 	myString := "some string" | ||||
| 	var result *string | ||||
| 	result = &myString | ||||
|   | ||||
		Reference in New Issue
	
	Block a user