4 Commits

Author SHA1 Message Date
0a177e0b46 Allow multi-statement migration scripts
All checks were successful
Build / build (push) Successful in 1m7s
2025-03-12 08:24:27 +01:00
82c7f57078 Allow multi-statement migration scripts
Some checks failed
Build / build (push) Has been cancelled
2025-03-12 08:23:21 +01:00
9d5c0bcbb1 Add support for proper mutexes
All checks were successful
Build / build (push) Successful in 1m41s
2025-03-06 10:25:10 +01:00
258dcc7180 Add reading of multiple rows with optional iterator
All checks were successful
Build / build (push) Successful in 1m18s
2025-02-20 12:02:16 +01:00
7 changed files with 232 additions and 22 deletions

View File

@@ -2,12 +2,14 @@ package mysqlite
import ( import (
"fmt" "fmt"
"sync"
"zombiezen.com/go/sqlite" "zombiezen.com/go/sqlite"
) )
// Db holds a connection to a SQLite database. // Db holds a connection to a SQLite database.
type Db struct { type Db struct {
Db *sqlite.Conn Db *sqlite.Conn
lock sync.Mutex
} }
// OpenDb opens a new connection to a SQLite database. // OpenDb opens a new connection to a SQLite database.
@@ -35,3 +37,11 @@ func (d *Db) MustClose() {
panic(fmt.Sprintf("error closing db: %v", err)) panic(fmt.Sprintf("error closing db: %v", err))
} }
} }
func (d *Db) Lock() {
d.lock.Lock()
}
func (d *Db) Unlock() {
d.lock.Unlock()
}

View File

@@ -50,7 +50,7 @@ func (d *Db) MigrateDb(filesystem ReadDirFileFS, directory string) error {
return fmt.Errorf("error opening migration script %s: %v", migrationScript, err) return fmt.Errorf("error opening migration script %s: %v", migrationScript, err)
} }
err = performSingleMigration(err, d, migrationScript, targetVersion) err = performSingleMigration(d, migrationScript, targetVersion)
if err != nil { if err != nil {
return err return err
} }
@@ -61,21 +61,29 @@ func (d *Db) MigrateDb(filesystem ReadDirFileFS, directory string) error {
return nil return nil
} }
func performSingleMigration(err error, d *Db, migrationScript []byte, targetVersion int) error { func performSingleMigration(d *Db, migrationScript []byte, targetVersion int) error {
script := string(migrationScript)
// Split script based on semicolon
statements := strings.Split(script, ";")
tx, err := d.Begin() tx, err := d.Begin()
if err != nil { if err != nil {
return fmt.Errorf("error beginning transaction: %v", err) return fmt.Errorf("error beginning transaction: %v", err)
} }
defer tx.MustRollback() defer tx.MustRollback()
err = tx.Query(string(migrationScript)).Exec() for _, statement := range statements {
if err != nil { statement = strings.TrimSpace(statement)
return fmt.Errorf("error performing migration: %v", err) err = tx.Query(statement).Exec()
} if err != nil {
return fmt.Errorf("error performing migration: %v", err)
}
err = tx.Query(fmt.Sprintf("PRAGMA user_version = %d", targetVersion)).Exec()
if err != nil {
return fmt.Errorf("error updating version: %v", err)
}
err = tx.Query(fmt.Sprintf("PRAGMA user_version = %d", targetVersion)).Exec()
if err != nil {
return fmt.Errorf("error updating version: %v", err)
} }
err = tx.Commit() err = tx.Commit()

View File

@@ -17,4 +17,8 @@ func TestDb_MigrateDb(t *testing.T) {
var count int var count int
db.Query("select count(*) from mydata").MustScanSingle(&count) db.Query("select count(*) from mydata").MustScanSingle(&count)
require.Equal(t, 1, count, "incorrect number of rows in database") require.Equal(t, 1, count, "incorrect number of rows in database")
count = 0
db.Query("select count(*) from multiTable").MustScanSingle(&count)
require.Equal(t, 1, count, "incorrect number of rows in database")
} }

145
query.go
View File

@@ -2,16 +2,27 @@ package mysqlite
import ( import (
"fmt" "fmt"
"iter"
"reflect" "reflect"
"zombiezen.com/go/sqlite" "zombiezen.com/go/sqlite"
) )
type Query struct { type Query struct {
stmt *sqlite.Stmt stmt *sqlite.Stmt
err error // Reference to the database. If set, it is assumed that a lock was taken
// by the query that should be freed by the query.
db *Db
err error
} }
func (d *Db) Query(query string) *Query { func (d *Db) Query(query string) *Query {
d.Lock()
q := d.query(query)
q.db = d
return q
}
func (d *Db) query(query string) *Query {
stmt, remaining, err := d.Db.PrepareTransient(query) stmt, remaining, err := d.Db.PrepareTransient(query)
if err != nil { if err != nil {
return &Query{err: err} return &Query{err: err}
@@ -44,6 +55,8 @@ func (q *Query) Bind(args ...any) *Query {
} }
func (q *Query) Exec() (rerr error) { func (q *Query) Exec() (rerr error) {
defer q.unlock()
if q.stmt != nil { if q.stmt != nil {
defer func() { rerr = q.stmt.Finalize() }() defer func() { rerr = q.stmt.Finalize() }()
} }
@@ -68,27 +81,104 @@ func (q *Query) MustExec() {
} }
func (q *Query) ScanSingle(results ...any) (rerr error) { func (q *Query) ScanSingle(results ...any) (rerr error) {
defer q.unlock()
// Scan rows
if q.stmt != nil { if q.stmt != nil {
defer func() { rerr = q.stmt.Finalize() }() defer func() { rerr = q.stmt.Finalize() }()
} }
if q.err != nil { if q.err != nil {
return q.err return q.err
} }
rowReturned, err := q.stmt.Step() rows, err := q.ScanMulti()
if err != nil { if err != nil {
return err return err
} }
if !rowReturned {
// Fetch the first row
hasResult, err := rows.Next()
if err != nil {
return err
}
if !hasResult {
return fmt.Errorf("did not return any rows") return fmt.Errorf("did not return any rows")
} }
// Scan its columns
err = rows.Scan(results...)
if err != nil {
return err
}
// Ensure there are no more rows
hasResult, err = rows.Next()
if err != nil {
return err
}
if hasResult {
return fmt.Errorf("returned more than one row")
}
return nil
}
func (q *Query) MustScanSingle(results ...any) {
err := q.ScanSingle(results...)
if err != nil {
panic(fmt.Sprintf("error getting results: %v", err))
}
}
func (q *Query) unlock() {
if q.db != nil {
q.db.Unlock()
}
}
type Rows struct {
query *Query
}
func (q *Query) ScanMulti() (*Rows, error) {
return &Rows{
query: q,
}, nil
}
func (r *Rows) Finish() error {
defer r.query.unlock()
return r.query.stmt.Finalize()
}
func (r *Rows) MustFinish() {
err := r.Finish()
if err != nil {
panic(err)
}
}
func (r *Rows) Next() (bool, error) {
gotRow, err := r.query.stmt.Step()
if err != nil {
return false, err
}
return gotRow, nil
}
func (r *Rows) MustNext() bool {
gotRow, err := r.Next()
if err != nil {
panic(err)
}
return gotRow
}
func (r *Rows) Scan(results ...any) error {
for i, arg := range results { for i, arg := range results {
if asString, ok := arg.(*string); ok { if asString, ok := arg.(*string); ok {
*asString = q.stmt.ColumnText(i) *asString = r.query.stmt.ColumnText(i)
} else if asInt, ok := arg.(*int); ok { } else if asInt, ok := arg.(*int); ok {
*asInt = q.stmt.ColumnInt(i) *asInt = r.query.stmt.ColumnInt(i)
} else if asBool, ok := arg.(*bool); ok { } else if asBool, ok := arg.(*bool); ok {
*asBool = q.stmt.ColumnBool(i) *asBool = r.query.stmt.ColumnBool(i)
} else { } else {
if reflect.TypeOf(arg).Kind() != reflect.Ptr { if reflect.TypeOf(arg).Kind() != reflect.Ptr {
return fmt.Errorf("unsupported column type %s at index %d (it should be a pointer)", reflect.TypeOf(arg).Name(), i) return fmt.Errorf("unsupported column type %s at index %d (it should be a pointer)", reflect.TypeOf(arg).Name(), i)
@@ -100,9 +190,46 @@ func (q *Query) ScanSingle(results ...any) (rerr error) {
return nil return nil
} }
func (q *Query) MustScanSingle(results ...any) { func (r *Rows) MustScan(results ...any) {
err := q.ScanSingle(results...) err := r.Scan(results...)
if err != nil { if err != nil {
panic(fmt.Sprintf("error getting results: %v", err)) panic(err)
}
}
func (q *Query) Range(err *error) iter.Seq[*Rows] {
return func(yield func(*Rows) bool) {
// Start the scan
rows, terr := q.ScanMulti()
if terr != nil {
*err = terr
return
}
// Ensure we close and return any errors
defer func() {
terr := rows.Finish()
if terr != nil {
*err = terr
}
}()
// Loop over each record
for {
// Get the record
hasRow, terr := rows.Next()
if terr != nil {
*err = terr
return
}
if !hasRow {
return
}
// Pass it to the range body
if !yield(&Rows{query: q}) {
return
}
}
} }
} }

View File

@@ -25,3 +25,44 @@ func TestSimpleQueryWithArgs(t *testing.T) {
db.Query("select value from mytable where key = ?").Bind("foo").MustScanSingle(&value) db.Query("select value from mytable where key = ?").Bind("foo").MustScanSingle(&value)
require.Equal(t, "bar", value, "bad value returned") require.Equal(t, "bar", value, "bad value returned")
} }
func TestQueryWithTwoRows(t *testing.T) {
db := openTestDb(t)
db.Query("insert into mytable(key, value) values ('lorem', 'ipsum')").MustExec()
rows, err := db.Query("select value from mytable").ScanMulti()
require.NoError(t, err)
defer rows.MustFinish()
require.True(t, rows.MustNext(), "expected first row")
var value string
rows.MustScan(&value)
require.Equal(t, "bar", value, "bad value returned")
require.True(t, rows.MustNext(), "expected second row")
rows.MustScan(&value)
require.Equal(t, "ipsum", value, "bad value returned")
require.False(t, rows.MustNext(), "expected no more rows")
}
func TestQueryWithRange(t *testing.T) {
db := openTestDb(t)
db.Query("insert into mytable(key, value) values ('lorem', 'ipsum')").MustExec()
var err error
index := 0
for row := range db.Query("select value from mytable").Range(&err) {
var value string
row.MustScan(&value)
if index == 0 {
require.Equal(t, "bar", value)
} else if index == 1 {
require.Equal(t, "ipsum", value)
} else {
require.FailNow(t, "more rows than expected")
}
index++
}
require.NoError(t, err)
}

View File

@@ -0,0 +1,3 @@
create table multiTable(value text);
insert into multiTable(value) values ('testValue');

View File

@@ -7,7 +7,8 @@ type Tx struct {
} }
func (d *Db) Begin() (*Tx, error) { func (d *Db) Begin() (*Tx, error) {
err := d.Query("BEGIN").Exec() d.Lock()
err := d.query("BEGIN").Exec()
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -15,10 +16,16 @@ func (d *Db) Begin() (*Tx, error) {
} }
func (tx *Tx) Commit() error { func (tx *Tx) Commit() error {
defer tx.unlock()
return tx.Query("COMMIT").Exec() return tx.Query("COMMIT").Exec()
} }
func (tx *Tx) Rollback() error { func (tx *Tx) Rollback() error {
if tx.db == nil {
// The transaction was already commited
return nil
}
defer tx.unlock()
return tx.Query("ROLLBACK").Exec() return tx.Query("ROLLBACK").Exec()
} }
@@ -29,6 +36,16 @@ func (tx *Tx) MustRollback() {
} }
} }
func (tx *Tx) Query(query string) *Query { func (tx *Tx) unlock() {
return tx.db.Query(query) if tx.db != nil {
tx.db.Unlock()
tx.db = nil
}
}
func (tx *Tx) Query(query string) *Query {
if tx.db == nil {
panic("query was performed on a transaction after Commit or Rollback")
}
return tx.db.query(query)
} }