Compare commits
	
		
			5 Commits
		
	
	
		
			v0.6.0
			...
			87f10c73d6
		
	
	| Author | SHA1 | Date | |
|---|---|---|---|
| 87f10c73d6 | |||
| 58d63b6cf3 | |||
| 850e4a27d8 | |||
| 3e7455ef31 | |||
| 94ae36305a | 
| @@ -8,8 +8,9 @@ import ( | ||||
|  | ||||
| // Db holds a connection to a SQLite database. | ||||
| type Db struct { | ||||
| 	Db   *sqlite.Conn | ||||
| 	lock sync.Mutex | ||||
| 	Db     *sqlite.Conn | ||||
| 	source string | ||||
| 	lock   sync.Mutex | ||||
| } | ||||
|  | ||||
| // OpenDb opens a new connection to a SQLite database. | ||||
| @@ -21,7 +22,7 @@ func OpenDb(databaseSource string) (*Db, error) { | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	return &Db{Db: conn}, nil | ||||
| 	return &Db{Db: conn, source: databaseSource}, nil | ||||
| } | ||||
|  | ||||
| // Close closes the database. | ||||
|   | ||||
							
								
								
									
										15
									
								
								migrator.go
									
									
									
									
									
								
							
							
						
						
									
										15
									
								
								migrator.go
									
									
									
									
									
								
							| @@ -4,6 +4,7 @@ import ( | ||||
| 	"fmt" | ||||
| 	"io/fs" | ||||
| 	"log" | ||||
| 	"os" | ||||
| 	"path" | ||||
| 	"strconv" | ||||
| 	"strings" | ||||
| @@ -40,6 +41,20 @@ func (d *Db) MigrateDb(filesystem ReadDirFileFS, directory string) error { | ||||
| 	} | ||||
| 	log.Printf("Current version is %d, max migration version is %d", currentVersion, latestVersion) | ||||
|  | ||||
| 	// Create a backup if we're not on the latest version | ||||
| 	if currentVersion != latestVersion && d.source != ":memory:" { | ||||
| 		target := d.source + ".backup." + strconv.Itoa(currentVersion) | ||||
| 		log.Printf("Creating backup of database to %s", target) | ||||
| 		data, err := d.Db.Serialize("main") | ||||
| 		if err != nil { | ||||
| 			return fmt.Errorf("error serializing database: %v", err) | ||||
| 		} | ||||
| 		err = os.WriteFile(target, data, 0644) | ||||
| 		if err != nil { | ||||
| 			return fmt.Errorf("error writing backup: %v", err) | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	// If we are no up-to-date, bring the db up-to-date | ||||
| 	for currentVersion != latestVersion { | ||||
| 		targetVersion := currentVersion + 1 | ||||
|   | ||||
							
								
								
									
										79
									
								
								query.go
									
									
									
									
									
								
							
							
						
						
									
										79
									
								
								query.go
									
									
									
									
									
								
							| @@ -34,21 +34,37 @@ func (d *Db) query(query string) *Query { | ||||
| } | ||||
|  | ||||
| func (q *Query) Bind(args ...any) *Query { | ||||
| 	into := 0 | ||||
| 	return q.bindInto(&into, args...) | ||||
| } | ||||
|  | ||||
| func (q *Query) bindInto(into *int, args ...any) *Query { | ||||
| 	if q.err != nil || q.stmt == nil { | ||||
| 		return q | ||||
| 	} | ||||
| 	for i, arg := range args { | ||||
| 		*into++ | ||||
| 		if asString, ok := arg.(string); ok { | ||||
| 			q.stmt.BindText(i+1, asString) | ||||
| 			q.stmt.BindText(*into, asString) | ||||
| 		} else if asInt, ok := arg.(int); ok { | ||||
| 			q.stmt.BindInt64(i+1, int64(asInt)) | ||||
| 			q.stmt.BindInt64(*into, int64(asInt)) | ||||
| 		} else if asInt, ok := arg.(int64); ok { | ||||
| 			q.stmt.BindInt64(i+1, asInt) | ||||
| 			q.stmt.BindInt64(*into, asInt) | ||||
| 		} else if asBool, ok := arg.(bool); ok { | ||||
| 			q.stmt.BindBool(i+1, asBool) | ||||
| 			q.stmt.BindBool(*into, asBool) | ||||
| 		} else { | ||||
| 			q.err = fmt.Errorf("unsupported column type %s at index %d", reflect.TypeOf(arg).Name(), i) | ||||
| 			return q | ||||
| 			// Check if the argument is a slice or array of any type | ||||
| 			v := reflect.ValueOf(arg) | ||||
| 			if v.Kind() == reflect.Slice || v.Kind() == reflect.Array { | ||||
| 				*into-- | ||||
| 				for i := 0; i < v.Len(); i++ { | ||||
| 					q.bindInto(into, v.Index(i).Interface()) | ||||
| 				} | ||||
| 			} else { | ||||
| 				*into-- | ||||
| 				q.err = fmt.Errorf("unsupported column type %s at index %d", reflect.TypeOf(arg).Name(), i) | ||||
| 				return q | ||||
| 			} | ||||
| 		} | ||||
| 	} | ||||
| 	return q | ||||
| @@ -138,6 +154,9 @@ type Rows struct { | ||||
| } | ||||
|  | ||||
| func (q *Query) ScanMulti() (*Rows, error) { | ||||
| 	if q.err != nil { | ||||
| 		return nil, q.err | ||||
| 	} | ||||
| 	return &Rows{ | ||||
| 		query: q, | ||||
| 	}, nil | ||||
| @@ -173,23 +192,47 @@ func (r *Rows) MustNext() bool { | ||||
|  | ||||
| func (r *Rows) Scan(results ...any) error { | ||||
| 	for i, arg := range results { | ||||
| 		if asString, ok := arg.(*string); ok { | ||||
| 			*asString = r.query.stmt.ColumnText(i) | ||||
| 		} else if asInt, ok := arg.(*int); ok { | ||||
| 			*asInt = r.query.stmt.ColumnInt(i) | ||||
| 		} else if asBool, ok := arg.(*bool); ok { | ||||
| 			*asBool = r.query.stmt.ColumnBool(i) | ||||
| 		} else { | ||||
| 			if reflect.TypeOf(arg).Kind() != reflect.Ptr { | ||||
| 				return fmt.Errorf("unsupported column type %s at index %d (it should be a pointer)", reflect.TypeOf(arg).Name(), i) | ||||
| 			} | ||||
| 			name := reflect.Indirect(reflect.ValueOf(arg)).Type().Name() | ||||
| 			return fmt.Errorf("unsupported column type *%s at index %d", name, i) | ||||
| 		err := r.scanArgument(i, arg) | ||||
| 		if err != nil { | ||||
| 			return err | ||||
| 		} | ||||
| 	} | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func (r *Rows) scanArgument(i int, arg any) error { | ||||
| 	if asString, ok := arg.(*string); ok { | ||||
| 		*asString = r.query.stmt.ColumnText(i) | ||||
| 	} else if asInt, ok := arg.(*int); ok { | ||||
| 		*asInt = r.query.stmt.ColumnInt(i) | ||||
| 	} else if asBool, ok := arg.(*bool); ok { | ||||
| 		*asBool = r.query.stmt.ColumnBool(i) | ||||
| 	} else if reflect.TypeOf(arg).Kind() == reflect.Ptr && reflect.TypeOf(arg).Elem().Kind() == reflect.Ptr { | ||||
| 		return r.handleNullableType(i, arg) | ||||
| 	} else { | ||||
| 		if reflect.TypeOf(arg).Kind() != reflect.Ptr { | ||||
| 			return fmt.Errorf("unsupported column type %s at index %d (it should be a pointer)", reflect.TypeOf(arg).Name(), i) | ||||
| 		} | ||||
| 		name := reflect.Indirect(reflect.ValueOf(arg)).Type().Name() | ||||
| 		return fmt.Errorf("unsupported column type *%s at index %d", name, i) | ||||
| 	} | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func (r *Rows) handleNullableType(i int, asPtr any) error { | ||||
| 	if r.query.stmt.ColumnIsNull(i) { | ||||
| 		reflect.ValueOf(asPtr).Elem().Set(reflect.Zero(reflect.TypeOf(asPtr).Elem())) | ||||
| 	} else { | ||||
| 		value := reflect.New(reflect.TypeOf(asPtr).Elem().Elem()).Interface() | ||||
| 		err := r.scanArgument(i, value) | ||||
| 		if err != nil { | ||||
| 			return err | ||||
| 		} | ||||
| 		reflect.ValueOf(asPtr).Elem().Set(reflect.ValueOf(value)) | ||||
| 	} | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func (r *Rows) MustScan(results ...any) { | ||||
| 	err := r.Scan(results...) | ||||
| 	if err != nil { | ||||
|   | ||||
| @@ -90,3 +90,72 @@ func TestUpdateQueryWithWrongArguments(t *testing.T) { | ||||
| 	err := db.Query("insert into mytable(key, value) values ('lorem', ?)").Bind(&value).Exec() | ||||
| 	require.Error(t, err) | ||||
| } | ||||
|  | ||||
| func TestQueryWithPointerStringArguments(t *testing.T) { | ||||
| 	db := openTestDb(t) | ||||
| 	var result *string | ||||
| 	err := db.Query("select value from mytable where key = 'foo'").ScanSingle(&result) | ||||
| 	require.NoError(t, err) | ||||
| 	require.NotNil(t, result) | ||||
| 	require.Equal(t, "bar", *result) | ||||
| } | ||||
|  | ||||
| func TestQueryWithPointerStringArgumentsCanSetToNull(t *testing.T) { | ||||
| 	db := openTestDb(t) | ||||
| 	db.Query("update mytable set value=NULL where key = 'foo'").MustExec() | ||||
| 	myString := "some string" | ||||
| 	var result *string | ||||
| 	result = &myString | ||||
| 	err := db.Query("select value from mytable where key = 'foo'").ScanSingle(&result) | ||||
| 	require.NoError(t, err) | ||||
| 	require.Nil(t, result) | ||||
| } | ||||
|  | ||||
| func TestDeleteQuery(t *testing.T) { | ||||
| 	db := openTestDb(t) | ||||
| 	db.Query("delete from mytable where key = 'foo'").MustExec() | ||||
|  | ||||
| 	var count int | ||||
| 	db.Query("select count(*) from mytable where key = 'foo'").MustScanSingle(&count) | ||||
| 	require.Equal(t, 0, count, "expected row to be deleted") | ||||
| } | ||||
|  | ||||
| func TestTransactionRollback(t *testing.T) { | ||||
| 	db := openTestDb(t) | ||||
| 	func() { | ||||
| 		tx := db.MustBegin() | ||||
| 		defer tx.MustRollback() | ||||
| 		tx.Query("update mytable set value = 'ipsum' where key = 'foo'").MustExec() | ||||
| 		// Intentionally not committing the transaction | ||||
| 	}() | ||||
|  | ||||
| 	var value string | ||||
| 	db.Query("select value from mytable where key = 'foo'").MustScanSingle(&value) | ||||
| 	require.Equal(t, "bar", value, "expected original value after rollback") | ||||
| } | ||||
|  | ||||
| func TestQueryWithInClause(t *testing.T) { | ||||
| 	db := openTestDb(t) | ||||
| 	// Insert additional test rows | ||||
| 	db.Query("insert into mytable(key, value) values ('key1', 'value1')").MustExec() | ||||
| 	db.Query("insert into mytable(key, value) values ('key2', 'value2')").MustExec() | ||||
|  | ||||
| 	// Execute query with IN clause | ||||
| 	args := []string{"foo", "key2"} | ||||
| 	rows, err := db.Query("select key, value from mytable where key in (?, ?)").Bind(args).ScanMulti() | ||||
| 	require.NoError(t, err) | ||||
| 	defer rows.MustFinish() | ||||
|  | ||||
| 	// Check results | ||||
| 	results := make(map[string]string) | ||||
| 	for rows.MustNext() { | ||||
| 		var key, value string | ||||
| 		rows.MustScan(&key, &value) | ||||
| 		results[key] = value | ||||
| 	} | ||||
|  | ||||
| 	// Verify we got exactly the expected results | ||||
| 	require.Equal(t, 2, len(results), "expected 2 matching rows") | ||||
| 	require.Equal(t, "bar", results["foo"], "unexpected value for key 'foo'") | ||||
| 	require.Equal(t, "value2", results["key2"], "unexpected value for key 'key2'") | ||||
| } | ||||
|   | ||||
		Reference in New Issue
	
	Block a user