Compare commits
	
		
			3 Commits
		
	
	
		
			v0.6.0
			...
			850e4a27d8
		
	
	| Author | SHA1 | Date | |
|---|---|---|---|
| 850e4a27d8 | |||
| 3e7455ef31 | |||
| 94ae36305a | 
@@ -9,6 +9,7 @@ import (
 | 
				
			|||||||
// Db holds a connection to a SQLite database.
 | 
					// Db holds a connection to a SQLite database.
 | 
				
			||||||
type Db struct {
 | 
					type Db struct {
 | 
				
			||||||
	Db     *sqlite.Conn
 | 
						Db     *sqlite.Conn
 | 
				
			||||||
 | 
						source string
 | 
				
			||||||
	lock   sync.Mutex
 | 
						lock   sync.Mutex
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -21,7 +22,7 @@ func OpenDb(databaseSource string) (*Db, error) {
 | 
				
			|||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		return nil, err
 | 
							return nil, err
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	return &Db{Db: conn}, nil
 | 
						return &Db{Db: conn, source: databaseSource}, nil
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
// Close closes the database.
 | 
					// Close closes the database.
 | 
				
			||||||
 
 | 
				
			|||||||
							
								
								
									
										15
									
								
								migrator.go
									
									
									
									
									
								
							
							
						
						
									
										15
									
								
								migrator.go
									
									
									
									
									
								
							@@ -4,6 +4,7 @@ import (
 | 
				
			|||||||
	"fmt"
 | 
						"fmt"
 | 
				
			||||||
	"io/fs"
 | 
						"io/fs"
 | 
				
			||||||
	"log"
 | 
						"log"
 | 
				
			||||||
 | 
						"os"
 | 
				
			||||||
	"path"
 | 
						"path"
 | 
				
			||||||
	"strconv"
 | 
						"strconv"
 | 
				
			||||||
	"strings"
 | 
						"strings"
 | 
				
			||||||
@@ -40,6 +41,20 @@ func (d *Db) MigrateDb(filesystem ReadDirFileFS, directory string) error {
 | 
				
			|||||||
	}
 | 
						}
 | 
				
			||||||
	log.Printf("Current version is %d, max migration version is %d", currentVersion, latestVersion)
 | 
						log.Printf("Current version is %d, max migration version is %d", currentVersion, latestVersion)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// Create a backup if we're not on the latest version
 | 
				
			||||||
 | 
						if currentVersion != latestVersion && d.source != ":memory:" {
 | 
				
			||||||
 | 
							target := d.source + ".backup." + strconv.Itoa(currentVersion)
 | 
				
			||||||
 | 
							log.Printf("Creating backup of database to %s", target)
 | 
				
			||||||
 | 
							data, err := d.Db.Serialize("main")
 | 
				
			||||||
 | 
							if err != nil {
 | 
				
			||||||
 | 
								return fmt.Errorf("error serializing database: %v", err)
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							err = os.WriteFile(target, data, 0644)
 | 
				
			||||||
 | 
							if err != nil {
 | 
				
			||||||
 | 
								return fmt.Errorf("error writing backup: %v", err)
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
	// If we are no up-to-date, bring the db up-to-date
 | 
						// If we are no up-to-date, bring the db up-to-date
 | 
				
			||||||
	for currentVersion != latestVersion {
 | 
						for currentVersion != latestVersion {
 | 
				
			||||||
		targetVersion := currentVersion + 1
 | 
							targetVersion := currentVersion + 1
 | 
				
			||||||
 
 | 
				
			|||||||
							
								
								
									
										24
									
								
								query.go
									
									
									
									
									
								
							
							
						
						
									
										24
									
								
								query.go
									
									
									
									
									
								
							@@ -173,12 +173,23 @@ func (r *Rows) MustNext() bool {
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
func (r *Rows) Scan(results ...any) error {
 | 
					func (r *Rows) Scan(results ...any) error {
 | 
				
			||||||
	for i, arg := range results {
 | 
						for i, arg := range results {
 | 
				
			||||||
 | 
							err := r.scanArgument(i, arg)
 | 
				
			||||||
 | 
							if err != nil {
 | 
				
			||||||
 | 
								return err
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return nil
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (r *Rows) scanArgument(i int, arg any) error {
 | 
				
			||||||
	if asString, ok := arg.(*string); ok {
 | 
						if asString, ok := arg.(*string); ok {
 | 
				
			||||||
		*asString = r.query.stmt.ColumnText(i)
 | 
							*asString = r.query.stmt.ColumnText(i)
 | 
				
			||||||
	} else if asInt, ok := arg.(*int); ok {
 | 
						} else if asInt, ok := arg.(*int); ok {
 | 
				
			||||||
		*asInt = r.query.stmt.ColumnInt(i)
 | 
							*asInt = r.query.stmt.ColumnInt(i)
 | 
				
			||||||
	} else if asBool, ok := arg.(*bool); ok {
 | 
						} else if asBool, ok := arg.(*bool); ok {
 | 
				
			||||||
		*asBool = r.query.stmt.ColumnBool(i)
 | 
							*asBool = r.query.stmt.ColumnBool(i)
 | 
				
			||||||
 | 
						} else if reflect.TypeOf(arg).Kind() == reflect.Ptr && reflect.TypeOf(arg).Elem().Kind() == reflect.Ptr {
 | 
				
			||||||
 | 
							return r.handleNullableType(i, arg)
 | 
				
			||||||
	} else {
 | 
						} else {
 | 
				
			||||||
		if reflect.TypeOf(arg).Kind() != reflect.Ptr {
 | 
							if reflect.TypeOf(arg).Kind() != reflect.Ptr {
 | 
				
			||||||
			return fmt.Errorf("unsupported column type %s at index %d (it should be a pointer)", reflect.TypeOf(arg).Name(), i)
 | 
								return fmt.Errorf("unsupported column type %s at index %d (it should be a pointer)", reflect.TypeOf(arg).Name(), i)
 | 
				
			||||||
@@ -186,6 +197,19 @@ func (r *Rows) Scan(results ...any) error {
 | 
				
			|||||||
		name := reflect.Indirect(reflect.ValueOf(arg)).Type().Name()
 | 
							name := reflect.Indirect(reflect.ValueOf(arg)).Type().Name()
 | 
				
			||||||
		return fmt.Errorf("unsupported column type *%s at index %d", name, i)
 | 
							return fmt.Errorf("unsupported column type *%s at index %d", name, i)
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
						return nil
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (r *Rows) handleNullableType(i int, asPtr any) error {
 | 
				
			||||||
 | 
						if r.query.stmt.ColumnIsNull(i) {
 | 
				
			||||||
 | 
							reflect.ValueOf(asPtr).Elem().Set(reflect.Zero(reflect.TypeOf(asPtr).Elem()))
 | 
				
			||||||
 | 
						} else {
 | 
				
			||||||
 | 
							value := reflect.New(reflect.TypeOf(asPtr).Elem().Elem()).Interface()
 | 
				
			||||||
 | 
							err := r.scanArgument(i, value)
 | 
				
			||||||
 | 
							if err != nil {
 | 
				
			||||||
 | 
								return err
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							reflect.ValueOf(asPtr).Elem().Set(reflect.ValueOf(value))
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	return nil
 | 
						return nil
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -90,3 +90,46 @@ func TestUpdateQueryWithWrongArguments(t *testing.T) {
 | 
				
			|||||||
	err := db.Query("insert into mytable(key, value) values ('lorem', ?)").Bind(&value).Exec()
 | 
						err := db.Query("insert into mytable(key, value) values ('lorem', ?)").Bind(&value).Exec()
 | 
				
			||||||
	require.Error(t, err)
 | 
						require.Error(t, err)
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func TestQueryWithPointerStringArguments(t *testing.T) {
 | 
				
			||||||
 | 
						db := openTestDb(t)
 | 
				
			||||||
 | 
						var result *string
 | 
				
			||||||
 | 
						err := db.Query("select value from mytable where key = 'foo'").ScanSingle(&result)
 | 
				
			||||||
 | 
						require.NoError(t, err)
 | 
				
			||||||
 | 
						require.NotNil(t, result)
 | 
				
			||||||
 | 
						require.Equal(t, "bar", *result)
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func TestQueryWithPointerStringArgumentsCanSetToNull(t *testing.T) {
 | 
				
			||||||
 | 
						db := openTestDb(t)
 | 
				
			||||||
 | 
						db.Query("update mytable set value=NULL where key = 'foo'").MustExec()
 | 
				
			||||||
 | 
						myString := "some string"
 | 
				
			||||||
 | 
						var result *string
 | 
				
			||||||
 | 
						result = &myString
 | 
				
			||||||
 | 
						err := db.Query("select value from mytable where key = 'foo'").ScanSingle(&result)
 | 
				
			||||||
 | 
						require.NoError(t, err)
 | 
				
			||||||
 | 
						require.Nil(t, result)
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func TestDeleteQuery(t *testing.T) {
 | 
				
			||||||
 | 
						db := openTestDb(t)
 | 
				
			||||||
 | 
						db.Query("delete from mytable where key = 'foo'").MustExec()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						var count int
 | 
				
			||||||
 | 
						db.Query("select count(*) from mytable where key = 'foo'").MustScanSingle(&count)
 | 
				
			||||||
 | 
						require.Equal(t, 0, count, "expected row to be deleted")
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func TestTransactionRollback(t *testing.T) {
 | 
				
			||||||
 | 
						db := openTestDb(t)
 | 
				
			||||||
 | 
						func() {
 | 
				
			||||||
 | 
							tx := db.MustBegin()
 | 
				
			||||||
 | 
							defer tx.MustRollback()
 | 
				
			||||||
 | 
							tx.Query("update mytable set value = 'ipsum' where key = 'foo'").MustExec()
 | 
				
			||||||
 | 
							// Intentionally not committing the transaction
 | 
				
			||||||
 | 
						}()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						var value string
 | 
				
			||||||
 | 
						db.Query("select value from mytable where key = 'foo'").MustScanSingle(&value)
 | 
				
			||||||
 | 
						require.Equal(t, "bar", value, "expected original value after rollback")
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 
 | 
				
			|||||||
		Reference in New Issue
	
	Block a user