Add support for proper mutexes
All checks were successful
Build / build (push) Successful in 1m41s

This commit is contained in:
Sebastiaan de Schaetzen 2025-03-06 10:25:10 +01:00
parent 258dcc7180
commit 9d5c0bcbb1
3 changed files with 52 additions and 5 deletions

View File

@ -2,12 +2,14 @@ package mysqlite
import ( import (
"fmt" "fmt"
"sync"
"zombiezen.com/go/sqlite" "zombiezen.com/go/sqlite"
) )
// Db holds a connection to a SQLite database. // Db holds a connection to a SQLite database.
type Db struct { type Db struct {
Db *sqlite.Conn Db *sqlite.Conn
lock sync.Mutex
} }
// OpenDb opens a new connection to a SQLite database. // OpenDb opens a new connection to a SQLite database.
@ -35,3 +37,11 @@ func (d *Db) MustClose() {
panic(fmt.Sprintf("error closing db: %v", err)) panic(fmt.Sprintf("error closing db: %v", err))
} }
} }
func (d *Db) Lock() {
d.lock.Lock()
}
func (d *Db) Unlock() {
d.lock.Unlock()
}

View File

@ -9,10 +9,20 @@ import (
type Query struct { type Query struct {
stmt *sqlite.Stmt stmt *sqlite.Stmt
err error // Reference to the database. If set, it is assumed that a lock was taken
// by the query that should be freed by the query.
db *Db
err error
} }
func (d *Db) Query(query string) *Query { func (d *Db) Query(query string) *Query {
d.Lock()
q := d.query(query)
q.db = d
return q
}
func (d *Db) query(query string) *Query {
stmt, remaining, err := d.Db.PrepareTransient(query) stmt, remaining, err := d.Db.PrepareTransient(query)
if err != nil { if err != nil {
return &Query{err: err} return &Query{err: err}
@ -45,6 +55,8 @@ func (q *Query) Bind(args ...any) *Query {
} }
func (q *Query) Exec() (rerr error) { func (q *Query) Exec() (rerr error) {
defer q.unlock()
if q.stmt != nil { if q.stmt != nil {
defer func() { rerr = q.stmt.Finalize() }() defer func() { rerr = q.stmt.Finalize() }()
} }
@ -69,6 +81,7 @@ func (q *Query) MustExec() {
} }
func (q *Query) ScanSingle(results ...any) (rerr error) { func (q *Query) ScanSingle(results ...any) (rerr error) {
defer q.unlock()
// Scan rows // Scan rows
if q.stmt != nil { if q.stmt != nil {
defer func() { rerr = q.stmt.Finalize() }() defer func() { rerr = q.stmt.Finalize() }()
@ -114,6 +127,12 @@ func (q *Query) MustScanSingle(results ...any) {
} }
} }
func (q *Query) unlock() {
if q.db != nil {
q.db.Unlock()
}
}
type Rows struct { type Rows struct {
query *Query query *Query
} }
@ -125,6 +144,7 @@ func (q *Query) ScanMulti() (*Rows, error) {
} }
func (r *Rows) Finish() error { func (r *Rows) Finish() error {
defer r.query.unlock()
return r.query.stmt.Finalize() return r.query.stmt.Finalize()
} }

View File

@ -7,7 +7,8 @@ type Tx struct {
} }
func (d *Db) Begin() (*Tx, error) { func (d *Db) Begin() (*Tx, error) {
err := d.Query("BEGIN").Exec() d.Lock()
err := d.query("BEGIN").Exec()
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -15,10 +16,16 @@ func (d *Db) Begin() (*Tx, error) {
} }
func (tx *Tx) Commit() error { func (tx *Tx) Commit() error {
defer tx.unlock()
return tx.Query("COMMIT").Exec() return tx.Query("COMMIT").Exec()
} }
func (tx *Tx) Rollback() error { func (tx *Tx) Rollback() error {
if tx.db == nil {
// The transaction was already commited
return nil
}
defer tx.unlock()
return tx.Query("ROLLBACK").Exec() return tx.Query("ROLLBACK").Exec()
} }
@ -29,6 +36,16 @@ func (tx *Tx) MustRollback() {
} }
} }
func (tx *Tx) Query(query string) *Query { func (tx *Tx) unlock() {
return tx.db.Query(query) if tx.db != nil {
tx.db.Unlock()
tx.db = nil
}
}
func (tx *Tx) Query(query string) *Query {
if tx.db == nil {
panic("query was performed on a transaction after Commit or Rollback")
}
return tx.db.query(query)
} }