Compare commits
	
		
			11 Commits
		
	
	
		
			850e4a27d8
			...
			master
		
	
	| Author | SHA1 | Date | |
|---|---|---|---|
| 7e44426452 | |||
| 33f1a94fb2 | |||
| 6880fb5af4 | |||
| e9b12dc3f6 | |||
| 7daf1915a5 | |||
| 029cf6ce01 | |||
| 278d7ed497 | |||
| dd6be6b9b6 | |||
| 12a87a8762 | |||
| 87f10c73d6 | |||
| 58d63b6cf3 | 
							
								
								
									
										7
									
								
								errors.go
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										7
									
								
								errors.go
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,7 @@ | ||||
| package mysqlite | ||||
|  | ||||
| import "errors" | ||||
|  | ||||
| var ErrNoRows = errors.New("mysqlite: no rows returned") | ||||
| var ErrMissingBind = errors.New("mysqlite: missing bind value") | ||||
| var ErrMissingScan = errors.New("mysqlite: missing scan value") | ||||
| @@ -42,7 +42,7 @@ func (d *Db) MigrateDb(filesystem ReadDirFileFS, directory string) error { | ||||
| 	log.Printf("Current version is %d, max migration version is %d", currentVersion, latestVersion) | ||||
|  | ||||
| 	// Create a backup if we're not on the latest version | ||||
| 	if currentVersion != latestVersion && d.source != ":memory:" { | ||||
| 	if currentVersion != 0 && currentVersion != latestVersion && d.source != ":memory:" { | ||||
| 		target := d.source + ".backup." + strconv.Itoa(currentVersion) | ||||
| 		log.Printf("Creating backup of database to %s", target) | ||||
| 		data, err := d.Db.Serialize("main") | ||||
|   | ||||
							
								
								
									
										90
									
								
								query.go
									
									
									
									
									
								
							
							
						
						
									
										90
									
								
								query.go
									
									
									
									
									
								
							| @@ -11,8 +11,10 @@ type Query struct { | ||||
| 	stmt *sqlite.Stmt | ||||
| 	// Reference to the database. If set, it is assumed that a lock was taken | ||||
| 	// by the query that should be freed by the query. | ||||
| 	db  *Db | ||||
| 	err error | ||||
| 	db *Db | ||||
| 	// The number of bound arguments | ||||
| 	binds int | ||||
| 	err   error | ||||
| } | ||||
|  | ||||
| func (d *Db) Query(query string) *Query { | ||||
| @@ -34,21 +36,58 @@ func (d *Db) query(query string) *Query { | ||||
| } | ||||
|  | ||||
| func (q *Query) Bind(args ...any) *Query { | ||||
| 	into := 0 | ||||
| 	return q.bindInto(&into, args...) | ||||
| } | ||||
|  | ||||
| func (q *Query) bindInto(into *int, args ...any) *Query { | ||||
| 	if q.err != nil || q.stmt == nil { | ||||
| 		return q | ||||
| 	} | ||||
| 	for i, arg := range args { | ||||
| 		*into++ | ||||
| 		if arg == nil { | ||||
| 			q.stmt.BindNull(*into) | ||||
| 			q.binds++ | ||||
| 			continue | ||||
| 		} | ||||
| 		v := reflect.ValueOf(arg) | ||||
| 		if v.Kind() == reflect.Ptr { | ||||
| 			if v.IsNil() { | ||||
| 				q.stmt.BindNull(*into) | ||||
| 				q.binds++ | ||||
| 				continue | ||||
| 			} | ||||
| 			arg = v.Elem().Interface() | ||||
| 		} | ||||
| 		if asString, ok := arg.(string); ok { | ||||
| 			q.stmt.BindText(i+1, asString) | ||||
| 			q.stmt.BindText(*into, asString) | ||||
| 			q.binds++ | ||||
| 		} else if asInt, ok := arg.(int); ok { | ||||
| 			q.stmt.BindInt64(i+1, int64(asInt)) | ||||
| 		} else if asInt, ok := arg.(int64); ok { | ||||
| 			q.stmt.BindInt64(i+1, asInt) | ||||
| 			q.stmt.BindInt64(*into, int64(asInt)) | ||||
| 			q.binds++ | ||||
| 		} else if asInt64, ok := arg.(int64); ok { | ||||
| 			q.stmt.BindInt64(*into, asInt64) | ||||
| 			q.binds++ | ||||
| 		} else if asFloat, ok := arg.(float64); ok { | ||||
| 			q.stmt.BindFloat(*into, asFloat) | ||||
| 			q.binds++ | ||||
| 		} else if asBool, ok := arg.(bool); ok { | ||||
| 			q.stmt.BindBool(i+1, asBool) | ||||
| 			q.stmt.BindBool(*into, asBool) | ||||
| 			q.binds++ | ||||
| 		} else { | ||||
| 			q.err = fmt.Errorf("unsupported column type %s at index %d", reflect.TypeOf(arg).Name(), i) | ||||
| 			return q | ||||
| 			// Check if the argument is a slice or array of any type | ||||
| 			v = reflect.ValueOf(arg) | ||||
| 			if v.Kind() == reflect.Slice || v.Kind() == reflect.Array { | ||||
| 				*into-- | ||||
| 				for i := 0; i < v.Len(); i++ { | ||||
| 					q.bindInto(into, v.Index(i).Interface()) | ||||
| 				} | ||||
| 			} else { | ||||
| 				*into-- | ||||
| 				q.err = fmt.Errorf("unsupported column type %s at index %d", reflect.TypeOf(arg).Name(), i) | ||||
| 				return q | ||||
| 			} | ||||
| 		} | ||||
| 	} | ||||
| 	return q | ||||
| @@ -80,6 +119,23 @@ func (q *Query) MustExec() { | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func (q *Query) ScanColumns(results *[]string) *Query { | ||||
| 	if q.err != nil { | ||||
| 		return q | ||||
| 	} | ||||
|  | ||||
| 	// Ensure the number of results matches the number of columns | ||||
| 	if q.stmt.ColumnCount() != len(*results) { | ||||
| 		*results = make([]string, q.stmt.ColumnCount()) | ||||
| 	} | ||||
|  | ||||
| 	// Fetch the column names | ||||
| 	for i := 0; i < q.stmt.ColumnCount(); i++ { | ||||
| 		(*results)[i] = q.stmt.ColumnName(i) | ||||
| 	} | ||||
| 	return q | ||||
| } | ||||
|  | ||||
| func (q *Query) ScanSingle(results ...any) (rerr error) { | ||||
| 	defer q.unlock() | ||||
| 	// Scan rows | ||||
| @@ -100,7 +156,7 @@ func (q *Query) ScanSingle(results ...any) (rerr error) { | ||||
| 		return err | ||||
| 	} | ||||
| 	if !hasResult { | ||||
| 		return fmt.Errorf("did not return any rows") | ||||
| 		return ErrNoRows | ||||
| 	} | ||||
|  | ||||
| 	// Scan its columns | ||||
| @@ -138,6 +194,13 @@ type Rows struct { | ||||
| } | ||||
|  | ||||
| func (q *Query) ScanMulti() (*Rows, error) { | ||||
| 	if q.binds != q.stmt.BindParamCount() { | ||||
| 		return nil, ErrMissingBind | ||||
| 	} | ||||
|  | ||||
| 	if q.err != nil { | ||||
| 		return nil, q.err | ||||
| 	} | ||||
| 	return &Rows{ | ||||
| 		query: q, | ||||
| 	}, nil | ||||
| @@ -172,6 +235,9 @@ func (r *Rows) MustNext() bool { | ||||
| } | ||||
|  | ||||
| func (r *Rows) Scan(results ...any) error { | ||||
| 	if r.query.stmt.ColumnCount() != len(results) { | ||||
| 		return ErrMissingScan | ||||
| 	} | ||||
| 	for i, arg := range results { | ||||
| 		err := r.scanArgument(i, arg) | ||||
| 		if err != nil { | ||||
| @@ -186,6 +252,10 @@ func (r *Rows) scanArgument(i int, arg any) error { | ||||
| 		*asString = r.query.stmt.ColumnText(i) | ||||
| 	} else if asInt, ok := arg.(*int); ok { | ||||
| 		*asInt = r.query.stmt.ColumnInt(i) | ||||
| 	} else if asInt, ok := arg.(*int64); ok { | ||||
| 		*asInt = r.query.stmt.ColumnInt64(i) | ||||
| 	} else if asFloat, ok := arg.(*float64); ok { | ||||
| 		*asFloat = r.query.stmt.ColumnFloat(i) | ||||
| 	} else if asBool, ok := arg.(*bool); ok { | ||||
| 		*asBool = r.query.stmt.ColumnBool(i) | ||||
| 	} else if reflect.TypeOf(arg).Kind() == reflect.Ptr && reflect.TypeOf(arg).Elem().Kind() == reflect.Ptr { | ||||
|   | ||||
							
								
								
									
										154
									
								
								query_test.go
									
									
									
									
									
								
							
							
						
						
									
										154
									
								
								query_test.go
									
									
									
									
									
								
							| @@ -1,6 +1,7 @@ | ||||
| package mysqlite | ||||
|  | ||||
| import ( | ||||
| 	"errors" | ||||
| 	"github.com/stretchr/testify/require" | ||||
| 	"testing" | ||||
| ) | ||||
| @@ -19,6 +20,44 @@ func TestSimpleQuery(t *testing.T) { | ||||
| 	require.Equal(t, 1, count, "expected empty count") | ||||
| } | ||||
|  | ||||
| func TestSimpleQueryWithNoResults(t *testing.T) { | ||||
| 	db := openTestDb(t) | ||||
| 	var count int | ||||
| 	err := db.Query("select 1 from mytable where key=999").ScanSingle(&count) | ||||
| 	require.Equal(t, ErrNoRows, err) | ||||
| 	require.True(t, errors.Is(err, ErrNoRows)) | ||||
| } | ||||
|  | ||||
| func TestScanWithMissingValues(t *testing.T) { | ||||
| 	db := openTestDb(t) | ||||
| 	var count int | ||||
| 	err := db.Query("select 1, 2").ScanSingle(&count) | ||||
| 	require.Equal(t, ErrMissingScan, err) | ||||
| } | ||||
|  | ||||
| func TestScanColumns(t *testing.T) { | ||||
| 	db := openTestDb(t) | ||||
| 	var columns []string | ||||
| 	var key string | ||||
| 	err := db.Query("select `key` from mytable"). | ||||
| 		ScanColumns(&columns). | ||||
| 		ScanSingle(&key) | ||||
| 	require.NoError(t, err, "expected no error scanning columns") | ||||
| 	require.Equal(t, 1, len(columns), "expected one column") | ||||
| 	require.Equal(t, "key", columns[0], "expected column name 'key'") | ||||
| 	require.Equal(t, "foo", key, "expected key to be 'foo'") | ||||
| } | ||||
|  | ||||
| func TestBindInt64(t *testing.T) { | ||||
| 	db := openTestDb(t) | ||||
| 	var value int64 | ||||
| 	var result int64 | ||||
| 	value = 5 | ||||
| 	err := db.Query("select ?").Bind(&value).ScanSingle(&result) | ||||
| 	require.NoError(t, err) | ||||
| 	require.Equal(t, int64(5), result) | ||||
| } | ||||
|  | ||||
| func TestSimpleQueryWithArgs(t *testing.T) { | ||||
| 	db := openTestDb(t) | ||||
| 	var value string | ||||
| @@ -85,12 +124,70 @@ func TestUpdateQuery(t *testing.T) { | ||||
| } | ||||
|  | ||||
| func TestUpdateQueryWithWrongArguments(t *testing.T) { | ||||
| 	type S struct { | ||||
| 		Field string | ||||
| 	} | ||||
| 	db := openTestDb(t) | ||||
| 	value := "ipsum" | ||||
| 	err := db.Query("insert into mytable(key, value) values ('lorem', ?)").Bind(&value).Exec() | ||||
| 	abc := S{ | ||||
| 		Field: "ipsum", | ||||
| 	} | ||||
| 	err := db.Query("insert into mytable(key, value) values ('lorem', ?)").Bind(abc).Exec() | ||||
| 	require.Error(t, err) | ||||
| } | ||||
|  | ||||
| func TestUpdateQueryWithPointerValue(t *testing.T) { | ||||
| 	db := openTestDb(t) | ||||
| 	func() { | ||||
| 		tx := db.MustBegin() | ||||
| 		defer tx.MustRollback() | ||||
| 		tx.Query("insert into mytable(key, value) values ('lorem', 'bar')").MustExec() | ||||
| 		value := "ipsum" | ||||
| 		key := "lorem" | ||||
| 		tx.Query("update mytable set value = ? where key = ?").Bind(&value, key).MustExec() | ||||
| 		tx.MustCommit() | ||||
| 	}() | ||||
|  | ||||
| 	var value string | ||||
| 	db.Query("select value from mytable where key = 'lorem'").MustScanSingle(&value) | ||||
| 	require.Equal(t, "ipsum", value) | ||||
| } | ||||
|  | ||||
| func TestUpdateQueryWithSetPointerValue(t *testing.T) { | ||||
| 	type S struct { | ||||
| 		value *string | ||||
| 	} | ||||
| 	db := openTestDb(t) | ||||
| 	func() { | ||||
| 		tx := db.MustBegin() | ||||
| 		defer tx.MustRollback() | ||||
| 		tx.Query("insert into mytable(key, value) values ('lorem', 'bar')").MustExec() | ||||
| 		s := S{nil} | ||||
| 		key := "lorem" | ||||
| 		tx.Query("update mytable set value = ? where key = ?").Bind(s.value, key).MustExec() | ||||
| 		tx.MustCommit() | ||||
| 	}() | ||||
|  | ||||
| 	var value *string | ||||
| 	db.Query("select value from mytable where key = 'lorem'").MustScanSingle(&value) | ||||
| 	require.Equal(t, (*string)(nil), value) | ||||
| } | ||||
|  | ||||
| func TestUpdateQueryWithNullValue(t *testing.T) { | ||||
| 	db := openTestDb(t) | ||||
| 	func() { | ||||
| 		tx := db.MustBegin() | ||||
| 		defer tx.MustRollback() | ||||
| 		tx.Query("insert into mytable(key, value) values ('lorem', 'bar')").MustExec() | ||||
| 		key := "lorem" | ||||
| 		tx.Query("update mytable set value = ? where key = ?").Bind(nil, key).MustExec() | ||||
| 		tx.MustCommit() | ||||
| 	}() | ||||
|  | ||||
| 	var value *string | ||||
| 	db.Query("select value from mytable where key = 'lorem'").MustScanSingle(&value) | ||||
| 	require.Nil(t, value) | ||||
| } | ||||
|  | ||||
| func TestQueryWithPointerStringArguments(t *testing.T) { | ||||
| 	db := openTestDb(t) | ||||
| 	var result *string | ||||
| @@ -100,9 +197,34 @@ func TestQueryWithPointerStringArguments(t *testing.T) { | ||||
| 	require.Equal(t, "bar", *result) | ||||
| } | ||||
|  | ||||
| func TestQueryWithInt64Scan(t *testing.T) { | ||||
| 	db := openTestDb(t) | ||||
| 	var result int64 | ||||
| 	err := db.Query("select 2").ScanSingle(&result) | ||||
| 	require.NoError(t, err) | ||||
| 	require.NotNil(t, result) | ||||
| 	require.Equal(t, int64(2), result) | ||||
| } | ||||
|  | ||||
| func TestQueryWithFloat64Scan(t *testing.T) { | ||||
| 	db := openTestDb(t) | ||||
| 	var result float64 | ||||
| 	err := db.Query("select 2.5").ScanSingle(&result) | ||||
| 	require.NoError(t, err) | ||||
| 	require.NotNil(t, result) | ||||
| 	require.InDelta(t, 2.5, result, 0.001) | ||||
| } | ||||
|  | ||||
| func TestQueryWithMissingBinds(t *testing.T) { | ||||
| 	db := openTestDb(t) | ||||
| 	var result float64 | ||||
| 	err := db.Query("select ?").ScanSingle(&result) | ||||
| 	require.ErrorIs(t, err, ErrMissingBind) | ||||
| } | ||||
|  | ||||
| func TestQueryWithPointerStringArgumentsCanSetToNull(t *testing.T) { | ||||
| 	db := openTestDb(t) | ||||
| 	db.Query("update mytable set value=NULL where key = 'foo'").MustExec() | ||||
| 	db.Query("update mytable set value=null where key = 'foo'").MustExec() | ||||
| 	myString := "some string" | ||||
| 	var result *string | ||||
| 	result = &myString | ||||
| @@ -133,3 +255,29 @@ func TestTransactionRollback(t *testing.T) { | ||||
| 	db.Query("select value from mytable where key = 'foo'").MustScanSingle(&value) | ||||
| 	require.Equal(t, "bar", value, "expected original value after rollback") | ||||
| } | ||||
|  | ||||
| func TestQueryWithInClause(t *testing.T) { | ||||
| 	db := openTestDb(t) | ||||
| 	// Insert additional test rows | ||||
| 	db.Query("insert into mytable(key, value) values ('key1', 'value1')").MustExec() | ||||
| 	db.Query("insert into mytable(key, value) values ('key2', 'value2')").MustExec() | ||||
|  | ||||
| 	// Execute query with IN clause | ||||
| 	args := []string{"foo", "key2"} | ||||
| 	rows, err := db.Query("select key, value from mytable where key in (?, ?)").Bind(args).ScanMulti() | ||||
| 	require.NoError(t, err) | ||||
| 	defer rows.MustFinish() | ||||
|  | ||||
| 	// Check results | ||||
| 	results := make(map[string]string) | ||||
| 	for rows.MustNext() { | ||||
| 		var key, value string | ||||
| 		rows.MustScan(&key, &value) | ||||
| 		results[key] = value | ||||
| 	} | ||||
|  | ||||
| 	// Verify we got exactly the expected results | ||||
| 	require.Equal(t, 2, len(results), "expected 2 matching rows") | ||||
| 	require.Equal(t, "bar", results["foo"], "unexpected value for key 'foo'") | ||||
| 	require.Equal(t, "value2", results["key2"], "unexpected value for key 'key2'") | ||||
| } | ||||
|   | ||||
		Reference in New Issue
	
	Block a user