Compare commits
	
		
			5 Commits
		
	
	
		
	
	| Author | SHA1 | Date | |
|---|---|---|---|
| 6880fb5af4 | |||
| e9b12dc3f6 | |||
| 7daf1915a5 | |||
| 029cf6ce01 | |||
| 278d7ed497 | 
@@ -3,3 +3,4 @@ package mysqlite
 | 
				
			|||||||
import "errors"
 | 
					import "errors"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
var ErrNoRows = errors.New("mysqlite: no rows returned")
 | 
					var ErrNoRows = errors.New("mysqlite: no rows returned")
 | 
				
			||||||
 | 
					var ErrMissingBind = errors.New("mysqlite: missing bind value")
 | 
				
			||||||
 
 | 
				
			|||||||
							
								
								
									
										34
									
								
								query.go
									
									
									
									
									
								
							
							
						
						
									
										34
									
								
								query.go
									
									
									
									
									
								
							@@ -12,6 +12,8 @@ type Query struct {
 | 
				
			|||||||
	// Reference to the database. If set, it is assumed that a lock was taken
 | 
						// Reference to the database. If set, it is assumed that a lock was taken
 | 
				
			||||||
	// by the query that should be freed by the query.
 | 
						// by the query that should be freed by the query.
 | 
				
			||||||
	db *Db
 | 
						db *Db
 | 
				
			||||||
 | 
						// The number of bound arguments
 | 
				
			||||||
 | 
						binds int
 | 
				
			||||||
	err   error
 | 
						err   error
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -44,17 +46,35 @@ func (q *Query) bindInto(into *int, args ...any) *Query {
 | 
				
			|||||||
	}
 | 
						}
 | 
				
			||||||
	for i, arg := range args {
 | 
						for i, arg := range args {
 | 
				
			||||||
		*into++
 | 
							*into++
 | 
				
			||||||
 | 
							if arg == nil {
 | 
				
			||||||
 | 
								q.stmt.BindNull(*into)
 | 
				
			||||||
 | 
								q.binds++
 | 
				
			||||||
 | 
								continue
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							v := reflect.ValueOf(arg)
 | 
				
			||||||
 | 
							if v.Kind() == reflect.Ptr {
 | 
				
			||||||
 | 
								if v.IsNil() {
 | 
				
			||||||
 | 
									q.stmt.BindNull(*into)
 | 
				
			||||||
 | 
									q.binds++
 | 
				
			||||||
 | 
									continue
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
								arg = v.Elem().Interface()
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
		if asString, ok := arg.(string); ok {
 | 
							if asString, ok := arg.(string); ok {
 | 
				
			||||||
			q.stmt.BindText(*into, asString)
 | 
								q.stmt.BindText(*into, asString)
 | 
				
			||||||
 | 
								q.binds++
 | 
				
			||||||
		} else if asInt, ok := arg.(int); ok {
 | 
							} else if asInt, ok := arg.(int); ok {
 | 
				
			||||||
			q.stmt.BindInt64(*into, int64(asInt))
 | 
								q.stmt.BindInt64(*into, int64(asInt))
 | 
				
			||||||
		} else if asInt, ok := arg.(int64); ok {
 | 
								q.binds++
 | 
				
			||||||
			q.stmt.BindInt64(*into, asInt)
 | 
							} else if asFloat, ok := arg.(float64); ok {
 | 
				
			||||||
 | 
								q.stmt.BindFloat(*into, asFloat)
 | 
				
			||||||
 | 
								q.binds++
 | 
				
			||||||
		} else if asBool, ok := arg.(bool); ok {
 | 
							} else if asBool, ok := arg.(bool); ok {
 | 
				
			||||||
			q.stmt.BindBool(*into, asBool)
 | 
								q.stmt.BindBool(*into, asBool)
 | 
				
			||||||
 | 
								q.binds++
 | 
				
			||||||
		} else {
 | 
							} else {
 | 
				
			||||||
			// Check if the argument is a slice or array of any type
 | 
								// Check if the argument is a slice or array of any type
 | 
				
			||||||
			v := reflect.ValueOf(arg)
 | 
								v = reflect.ValueOf(arg)
 | 
				
			||||||
			if v.Kind() == reflect.Slice || v.Kind() == reflect.Array {
 | 
								if v.Kind() == reflect.Slice || v.Kind() == reflect.Array {
 | 
				
			||||||
				*into--
 | 
									*into--
 | 
				
			||||||
				for i := 0; i < v.Len(); i++ {
 | 
									for i := 0; i < v.Len(); i++ {
 | 
				
			||||||
@@ -154,6 +174,10 @@ type Rows struct {
 | 
				
			|||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (q *Query) ScanMulti() (*Rows, error) {
 | 
					func (q *Query) ScanMulti() (*Rows, error) {
 | 
				
			||||||
 | 
						if q.binds != q.stmt.BindParamCount() {
 | 
				
			||||||
 | 
							return nil, ErrMissingBind
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	if q.err != nil {
 | 
						if q.err != nil {
 | 
				
			||||||
		return nil, q.err
 | 
							return nil, q.err
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
@@ -205,6 +229,10 @@ func (r *Rows) scanArgument(i int, arg any) error {
 | 
				
			|||||||
		*asString = r.query.stmt.ColumnText(i)
 | 
							*asString = r.query.stmt.ColumnText(i)
 | 
				
			||||||
	} else if asInt, ok := arg.(*int); ok {
 | 
						} else if asInt, ok := arg.(*int); ok {
 | 
				
			||||||
		*asInt = r.query.stmt.ColumnInt(i)
 | 
							*asInt = r.query.stmt.ColumnInt(i)
 | 
				
			||||||
 | 
						} else if asInt, ok := arg.(*int64); ok {
 | 
				
			||||||
 | 
							*asInt = r.query.stmt.ColumnInt64(i)
 | 
				
			||||||
 | 
						} else if asFloat, ok := arg.(*float64); ok {
 | 
				
			||||||
 | 
							*asFloat = r.query.stmt.ColumnFloat(i)
 | 
				
			||||||
	} else if asBool, ok := arg.(*bool); ok {
 | 
						} else if asBool, ok := arg.(*bool); ok {
 | 
				
			||||||
		*asBool = r.query.stmt.ColumnBool(i)
 | 
							*asBool = r.query.stmt.ColumnBool(i)
 | 
				
			||||||
	} else if reflect.TypeOf(arg).Kind() == reflect.Ptr && reflect.TypeOf(arg).Elem().Kind() == reflect.Ptr {
 | 
						} else if reflect.TypeOf(arg).Kind() == reflect.Ptr && reflect.TypeOf(arg).Elem().Kind() == reflect.Ptr {
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -94,12 +94,70 @@ func TestUpdateQuery(t *testing.T) {
 | 
				
			|||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func TestUpdateQueryWithWrongArguments(t *testing.T) {
 | 
					func TestUpdateQueryWithWrongArguments(t *testing.T) {
 | 
				
			||||||
 | 
						type S struct {
 | 
				
			||||||
 | 
							Field string
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
	db := openTestDb(t)
 | 
						db := openTestDb(t)
 | 
				
			||||||
	value := "ipsum"
 | 
						abc := S{
 | 
				
			||||||
	err := db.Query("insert into mytable(key, value) values ('lorem', ?)").Bind(&value).Exec()
 | 
							Field: "ipsum",
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						err := db.Query("insert into mytable(key, value) values ('lorem', ?)").Bind(abc).Exec()
 | 
				
			||||||
	require.Error(t, err)
 | 
						require.Error(t, err)
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func TestUpdateQueryWithPointerValue(t *testing.T) {
 | 
				
			||||||
 | 
						db := openTestDb(t)
 | 
				
			||||||
 | 
						func() {
 | 
				
			||||||
 | 
							tx := db.MustBegin()
 | 
				
			||||||
 | 
							defer tx.MustRollback()
 | 
				
			||||||
 | 
							tx.Query("insert into mytable(key, value) values ('lorem', 'bar')").MustExec()
 | 
				
			||||||
 | 
							value := "ipsum"
 | 
				
			||||||
 | 
							key := "lorem"
 | 
				
			||||||
 | 
							tx.Query("update mytable set value = ? where key = ?").Bind(&value, key).MustExec()
 | 
				
			||||||
 | 
							tx.MustCommit()
 | 
				
			||||||
 | 
						}()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						var value string
 | 
				
			||||||
 | 
						db.Query("select value from mytable where key = 'lorem'").MustScanSingle(&value)
 | 
				
			||||||
 | 
						require.Equal(t, "ipsum", value)
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func TestUpdateQueryWithSetPointerValue(t *testing.T) {
 | 
				
			||||||
 | 
						type S struct {
 | 
				
			||||||
 | 
							value *string
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						db := openTestDb(t)
 | 
				
			||||||
 | 
						func() {
 | 
				
			||||||
 | 
							tx := db.MustBegin()
 | 
				
			||||||
 | 
							defer tx.MustRollback()
 | 
				
			||||||
 | 
							tx.Query("insert into mytable(key, value) values ('lorem', 'bar')").MustExec()
 | 
				
			||||||
 | 
							s := S{nil}
 | 
				
			||||||
 | 
							key := "lorem"
 | 
				
			||||||
 | 
							tx.Query("update mytable set value = ? where key = ?").Bind(s.value, key).MustExec()
 | 
				
			||||||
 | 
							tx.MustCommit()
 | 
				
			||||||
 | 
						}()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						var value *string
 | 
				
			||||||
 | 
						db.Query("select value from mytable where key = 'lorem'").MustScanSingle(&value)
 | 
				
			||||||
 | 
						require.Equal(t, (*string)(nil), value)
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func TestUpdateQueryWithNullValue(t *testing.T) {
 | 
				
			||||||
 | 
						db := openTestDb(t)
 | 
				
			||||||
 | 
						func() {
 | 
				
			||||||
 | 
							tx := db.MustBegin()
 | 
				
			||||||
 | 
							defer tx.MustRollback()
 | 
				
			||||||
 | 
							tx.Query("insert into mytable(key, value) values ('lorem', 'bar')").MustExec()
 | 
				
			||||||
 | 
							key := "lorem"
 | 
				
			||||||
 | 
							tx.Query("update mytable set value = ? where key = ?").Bind(nil, key).MustExec()
 | 
				
			||||||
 | 
							tx.MustCommit()
 | 
				
			||||||
 | 
						}()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						var value *string
 | 
				
			||||||
 | 
						db.Query("select value from mytable where key = 'lorem'").MustScanSingle(&value)
 | 
				
			||||||
 | 
						require.Nil(t, value)
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func TestQueryWithPointerStringArguments(t *testing.T) {
 | 
					func TestQueryWithPointerStringArguments(t *testing.T) {
 | 
				
			||||||
	db := openTestDb(t)
 | 
						db := openTestDb(t)
 | 
				
			||||||
	var result *string
 | 
						var result *string
 | 
				
			||||||
@@ -109,6 +167,31 @@ func TestQueryWithPointerStringArguments(t *testing.T) {
 | 
				
			|||||||
	require.Equal(t, "bar", *result)
 | 
						require.Equal(t, "bar", *result)
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func TestQueryWithInt64Scan(t *testing.T) {
 | 
				
			||||||
 | 
						db := openTestDb(t)
 | 
				
			||||||
 | 
						var result int64
 | 
				
			||||||
 | 
						err := db.Query("select 2").ScanSingle(&result)
 | 
				
			||||||
 | 
						require.NoError(t, err)
 | 
				
			||||||
 | 
						require.NotNil(t, result)
 | 
				
			||||||
 | 
						require.Equal(t, int64(2), result)
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func TestQueryWithFloat64Scan(t *testing.T) {
 | 
				
			||||||
 | 
						db := openTestDb(t)
 | 
				
			||||||
 | 
						var result float64
 | 
				
			||||||
 | 
						err := db.Query("select 2.5").ScanSingle(&result)
 | 
				
			||||||
 | 
						require.NoError(t, err)
 | 
				
			||||||
 | 
						require.NotNil(t, result)
 | 
				
			||||||
 | 
						require.InDelta(t, 2.5, result, 0.001)
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func TestQueryWithMissingBinds(t *testing.T) {
 | 
				
			||||||
 | 
						db := openTestDb(t)
 | 
				
			||||||
 | 
						var result float64
 | 
				
			||||||
 | 
						err := db.Query("select ?").ScanSingle(&result)
 | 
				
			||||||
 | 
						require.ErrorIs(t, err, ErrMissingBind)
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func TestQueryWithPointerStringArgumentsCanSetToNull(t *testing.T) {
 | 
					func TestQueryWithPointerStringArgumentsCanSetToNull(t *testing.T) {
 | 
				
			||||||
	db := openTestDb(t)
 | 
						db := openTestDb(t)
 | 
				
			||||||
	db.Query("update mytable set value=null where key = 'foo'").MustExec()
 | 
						db.Query("update mytable set value=null where key = 'foo'").MustExec()
 | 
				
			||||||
 
 | 
				
			|||||||
		Reference in New Issue
	
	Block a user