Compare commits
4 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 0a177e0b46 | |||
| 82c7f57078 | |||
| 9d5c0bcbb1 | |||
| 258dcc7180 |
10
database.go
10
database.go
@@ -2,12 +2,14 @@ package mysqlite
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sync"
|
||||
"zombiezen.com/go/sqlite"
|
||||
)
|
||||
|
||||
// Db holds a connection to a SQLite database.
|
||||
type Db struct {
|
||||
Db *sqlite.Conn
|
||||
lock sync.Mutex
|
||||
}
|
||||
|
||||
// OpenDb opens a new connection to a SQLite database.
|
||||
@@ -35,3 +37,11 @@ func (d *Db) MustClose() {
|
||||
panic(fmt.Sprintf("error closing db: %v", err))
|
||||
}
|
||||
}
|
||||
|
||||
func (d *Db) Lock() {
|
||||
d.lock.Lock()
|
||||
}
|
||||
|
||||
func (d *Db) Unlock() {
|
||||
d.lock.Unlock()
|
||||
}
|
||||
|
||||
14
migrator.go
14
migrator.go
@@ -50,7 +50,7 @@ func (d *Db) MigrateDb(filesystem ReadDirFileFS, directory string) error {
|
||||
return fmt.Errorf("error opening migration script %s: %v", migrationScript, err)
|
||||
}
|
||||
|
||||
err = performSingleMigration(err, d, migrationScript, targetVersion)
|
||||
err = performSingleMigration(d, migrationScript, targetVersion)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -61,14 +61,20 @@ func (d *Db) MigrateDb(filesystem ReadDirFileFS, directory string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func performSingleMigration(err error, d *Db, migrationScript []byte, targetVersion int) error {
|
||||
func performSingleMigration(d *Db, migrationScript []byte, targetVersion int) error {
|
||||
script := string(migrationScript)
|
||||
// Split script based on semicolon
|
||||
statements := strings.Split(script, ";")
|
||||
|
||||
tx, err := d.Begin()
|
||||
if err != nil {
|
||||
return fmt.Errorf("error beginning transaction: %v", err)
|
||||
}
|
||||
defer tx.MustRollback()
|
||||
|
||||
err = tx.Query(string(migrationScript)).Exec()
|
||||
for _, statement := range statements {
|
||||
statement = strings.TrimSpace(statement)
|
||||
err = tx.Query(statement).Exec()
|
||||
if err != nil {
|
||||
return fmt.Errorf("error performing migration: %v", err)
|
||||
}
|
||||
@@ -78,6 +84,8 @@ func performSingleMigration(err error, d *Db, migrationScript []byte, targetVers
|
||||
return fmt.Errorf("error updating version: %v", err)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
err = tx.Commit()
|
||||
if err != nil {
|
||||
return fmt.Errorf("error commiting transaction: %v", err)
|
||||
|
||||
@@ -17,4 +17,8 @@ func TestDb_MigrateDb(t *testing.T) {
|
||||
var count int
|
||||
db.Query("select count(*) from mydata").MustScanSingle(&count)
|
||||
require.Equal(t, 1, count, "incorrect number of rows in database")
|
||||
|
||||
count = 0
|
||||
db.Query("select count(*) from multiTable").MustScanSingle(&count)
|
||||
require.Equal(t, 1, count, "incorrect number of rows in database")
|
||||
}
|
||||
|
||||
143
query.go
143
query.go
@@ -2,16 +2,27 @@ package mysqlite
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"iter"
|
||||
"reflect"
|
||||
"zombiezen.com/go/sqlite"
|
||||
)
|
||||
|
||||
type Query struct {
|
||||
stmt *sqlite.Stmt
|
||||
// Reference to the database. If set, it is assumed that a lock was taken
|
||||
// by the query that should be freed by the query.
|
||||
db *Db
|
||||
err error
|
||||
}
|
||||
|
||||
func (d *Db) Query(query string) *Query {
|
||||
d.Lock()
|
||||
q := d.query(query)
|
||||
q.db = d
|
||||
return q
|
||||
}
|
||||
|
||||
func (d *Db) query(query string) *Query {
|
||||
stmt, remaining, err := d.Db.PrepareTransient(query)
|
||||
if err != nil {
|
||||
return &Query{err: err}
|
||||
@@ -44,6 +55,8 @@ func (q *Query) Bind(args ...any) *Query {
|
||||
}
|
||||
|
||||
func (q *Query) Exec() (rerr error) {
|
||||
defer q.unlock()
|
||||
|
||||
if q.stmt != nil {
|
||||
defer func() { rerr = q.stmt.Finalize() }()
|
||||
}
|
||||
@@ -68,27 +81,104 @@ func (q *Query) MustExec() {
|
||||
}
|
||||
|
||||
func (q *Query) ScanSingle(results ...any) (rerr error) {
|
||||
defer q.unlock()
|
||||
// Scan rows
|
||||
if q.stmt != nil {
|
||||
defer func() { rerr = q.stmt.Finalize() }()
|
||||
}
|
||||
if q.err != nil {
|
||||
return q.err
|
||||
}
|
||||
rowReturned, err := q.stmt.Step()
|
||||
rows, err := q.ScanMulti()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !rowReturned {
|
||||
|
||||
// Fetch the first row
|
||||
hasResult, err := rows.Next()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !hasResult {
|
||||
return fmt.Errorf("did not return any rows")
|
||||
}
|
||||
|
||||
// Scan its columns
|
||||
err = rows.Scan(results...)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Ensure there are no more rows
|
||||
hasResult, err = rows.Next()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if hasResult {
|
||||
return fmt.Errorf("returned more than one row")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (q *Query) MustScanSingle(results ...any) {
|
||||
err := q.ScanSingle(results...)
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("error getting results: %v", err))
|
||||
}
|
||||
}
|
||||
|
||||
func (q *Query) unlock() {
|
||||
if q.db != nil {
|
||||
q.db.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
type Rows struct {
|
||||
query *Query
|
||||
}
|
||||
|
||||
func (q *Query) ScanMulti() (*Rows, error) {
|
||||
return &Rows{
|
||||
query: q,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (r *Rows) Finish() error {
|
||||
defer r.query.unlock()
|
||||
return r.query.stmt.Finalize()
|
||||
}
|
||||
|
||||
func (r *Rows) MustFinish() {
|
||||
err := r.Finish()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
|
||||
func (r *Rows) Next() (bool, error) {
|
||||
gotRow, err := r.query.stmt.Step()
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return gotRow, nil
|
||||
}
|
||||
|
||||
func (r *Rows) MustNext() bool {
|
||||
gotRow, err := r.Next()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return gotRow
|
||||
}
|
||||
|
||||
func (r *Rows) Scan(results ...any) error {
|
||||
for i, arg := range results {
|
||||
if asString, ok := arg.(*string); ok {
|
||||
*asString = q.stmt.ColumnText(i)
|
||||
*asString = r.query.stmt.ColumnText(i)
|
||||
} else if asInt, ok := arg.(*int); ok {
|
||||
*asInt = q.stmt.ColumnInt(i)
|
||||
*asInt = r.query.stmt.ColumnInt(i)
|
||||
} else if asBool, ok := arg.(*bool); ok {
|
||||
*asBool = q.stmt.ColumnBool(i)
|
||||
*asBool = r.query.stmt.ColumnBool(i)
|
||||
} else {
|
||||
if reflect.TypeOf(arg).Kind() != reflect.Ptr {
|
||||
return fmt.Errorf("unsupported column type %s at index %d (it should be a pointer)", reflect.TypeOf(arg).Name(), i)
|
||||
@@ -100,9 +190,46 @@ func (q *Query) ScanSingle(results ...any) (rerr error) {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (q *Query) MustScanSingle(results ...any) {
|
||||
err := q.ScanSingle(results...)
|
||||
func (r *Rows) MustScan(results ...any) {
|
||||
err := r.Scan(results...)
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("error getting results: %v", err))
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
|
||||
func (q *Query) Range(err *error) iter.Seq[*Rows] {
|
||||
return func(yield func(*Rows) bool) {
|
||||
// Start the scan
|
||||
rows, terr := q.ScanMulti()
|
||||
if terr != nil {
|
||||
*err = terr
|
||||
return
|
||||
}
|
||||
|
||||
// Ensure we close and return any errors
|
||||
defer func() {
|
||||
terr := rows.Finish()
|
||||
if terr != nil {
|
||||
*err = terr
|
||||
}
|
||||
}()
|
||||
|
||||
// Loop over each record
|
||||
for {
|
||||
// Get the record
|
||||
hasRow, terr := rows.Next()
|
||||
if terr != nil {
|
||||
*err = terr
|
||||
return
|
||||
}
|
||||
if !hasRow {
|
||||
return
|
||||
}
|
||||
|
||||
// Pass it to the range body
|
||||
if !yield(&Rows{query: q}) {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -25,3 +25,44 @@ func TestSimpleQueryWithArgs(t *testing.T) {
|
||||
db.Query("select value from mytable where key = ?").Bind("foo").MustScanSingle(&value)
|
||||
require.Equal(t, "bar", value, "bad value returned")
|
||||
}
|
||||
|
||||
func TestQueryWithTwoRows(t *testing.T) {
|
||||
db := openTestDb(t)
|
||||
db.Query("insert into mytable(key, value) values ('lorem', 'ipsum')").MustExec()
|
||||
|
||||
rows, err := db.Query("select value from mytable").ScanMulti()
|
||||
require.NoError(t, err)
|
||||
defer rows.MustFinish()
|
||||
|
||||
require.True(t, rows.MustNext(), "expected first row")
|
||||
var value string
|
||||
rows.MustScan(&value)
|
||||
require.Equal(t, "bar", value, "bad value returned")
|
||||
|
||||
require.True(t, rows.MustNext(), "expected second row")
|
||||
rows.MustScan(&value)
|
||||
require.Equal(t, "ipsum", value, "bad value returned")
|
||||
|
||||
require.False(t, rows.MustNext(), "expected no more rows")
|
||||
}
|
||||
|
||||
func TestQueryWithRange(t *testing.T) {
|
||||
db := openTestDb(t)
|
||||
db.Query("insert into mytable(key, value) values ('lorem', 'ipsum')").MustExec()
|
||||
|
||||
var err error
|
||||
index := 0
|
||||
for row := range db.Query("select value from mytable").Range(&err) {
|
||||
var value string
|
||||
row.MustScan(&value)
|
||||
if index == 0 {
|
||||
require.Equal(t, "bar", value)
|
||||
} else if index == 1 {
|
||||
require.Equal(t, "ipsum", value)
|
||||
} else {
|
||||
require.FailNow(t, "more rows than expected")
|
||||
}
|
||||
index++
|
||||
}
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
3
testMigrations/3_multicomment.sql
Normal file
3
testMigrations/3_multicomment.sql
Normal file
@@ -0,0 +1,3 @@
|
||||
create table multiTable(value text);
|
||||
|
||||
insert into multiTable(value) values ('testValue');
|
||||
@@ -7,7 +7,8 @@ type Tx struct {
|
||||
}
|
||||
|
||||
func (d *Db) Begin() (*Tx, error) {
|
||||
err := d.Query("BEGIN").Exec()
|
||||
d.Lock()
|
||||
err := d.query("BEGIN").Exec()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -15,10 +16,16 @@ func (d *Db) Begin() (*Tx, error) {
|
||||
}
|
||||
|
||||
func (tx *Tx) Commit() error {
|
||||
defer tx.unlock()
|
||||
return tx.Query("COMMIT").Exec()
|
||||
}
|
||||
|
||||
func (tx *Tx) Rollback() error {
|
||||
if tx.db == nil {
|
||||
// The transaction was already commited
|
||||
return nil
|
||||
}
|
||||
defer tx.unlock()
|
||||
return tx.Query("ROLLBACK").Exec()
|
||||
}
|
||||
|
||||
@@ -29,6 +36,16 @@ func (tx *Tx) MustRollback() {
|
||||
}
|
||||
}
|
||||
|
||||
func (tx *Tx) Query(query string) *Query {
|
||||
return tx.db.Query(query)
|
||||
func (tx *Tx) unlock() {
|
||||
if tx.db != nil {
|
||||
tx.db.Unlock()
|
||||
tx.db = nil
|
||||
}
|
||||
}
|
||||
|
||||
func (tx *Tx) Query(query string) *Query {
|
||||
if tx.db == nil {
|
||||
panic("query was performed on a transaction after Commit or Rollback")
|
||||
}
|
||||
return tx.db.query(query)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user