From 9d5c0bcbb1533379f9ca0a94d366c859adac8991 Mon Sep 17 00:00:00 2001 From: Sebastiaan de Schaetzen Date: Thu, 6 Mar 2025 10:25:10 +0100 Subject: [PATCH] Add support for proper mutexes --- database.go | 12 +++++++++++- query.go | 22 +++++++++++++++++++++- transaction.go | 23 ++++++++++++++++++++--- 3 files changed, 52 insertions(+), 5 deletions(-) diff --git a/database.go b/database.go index 7ba2646..a075054 100644 --- a/database.go +++ b/database.go @@ -2,12 +2,14 @@ package mysqlite import ( "fmt" + "sync" "zombiezen.com/go/sqlite" ) // Db holds a connection to a SQLite database. type Db struct { - Db *sqlite.Conn + Db *sqlite.Conn + lock sync.Mutex } // OpenDb opens a new connection to a SQLite database. @@ -35,3 +37,11 @@ func (d *Db) MustClose() { panic(fmt.Sprintf("error closing db: %v", err)) } } + +func (d *Db) Lock() { + d.lock.Lock() +} + +func (d *Db) Unlock() { + d.lock.Unlock() +} diff --git a/query.go b/query.go index 5a3f95a..829b530 100644 --- a/query.go +++ b/query.go @@ -9,10 +9,20 @@ import ( type Query struct { stmt *sqlite.Stmt - err error + // Reference to the database. If set, it is assumed that a lock was taken + // by the query that should be freed by the query. + db *Db + err error } func (d *Db) Query(query string) *Query { + d.Lock() + q := d.query(query) + q.db = d + return q +} + +func (d *Db) query(query string) *Query { stmt, remaining, err := d.Db.PrepareTransient(query) if err != nil { return &Query{err: err} @@ -45,6 +55,8 @@ func (q *Query) Bind(args ...any) *Query { } func (q *Query) Exec() (rerr error) { + defer q.unlock() + if q.stmt != nil { defer func() { rerr = q.stmt.Finalize() }() } @@ -69,6 +81,7 @@ func (q *Query) MustExec() { } func (q *Query) ScanSingle(results ...any) (rerr error) { + defer q.unlock() // Scan rows if q.stmt != nil { defer func() { rerr = q.stmt.Finalize() }() @@ -114,6 +127,12 @@ func (q *Query) MustScanSingle(results ...any) { } } +func (q *Query) unlock() { + if q.db != nil { + q.db.Unlock() + } +} + type Rows struct { query *Query } @@ -125,6 +144,7 @@ func (q *Query) ScanMulti() (*Rows, error) { } func (r *Rows) Finish() error { + defer r.query.unlock() return r.query.stmt.Finalize() } diff --git a/transaction.go b/transaction.go index bafd180..7b409c0 100644 --- a/transaction.go +++ b/transaction.go @@ -7,7 +7,8 @@ type Tx struct { } func (d *Db) Begin() (*Tx, error) { - err := d.Query("BEGIN").Exec() + d.Lock() + err := d.query("BEGIN").Exec() if err != nil { return nil, err } @@ -15,10 +16,16 @@ func (d *Db) Begin() (*Tx, error) { } func (tx *Tx) Commit() error { + defer tx.unlock() return tx.Query("COMMIT").Exec() } func (tx *Tx) Rollback() error { + if tx.db == nil { + // The transaction was already commited + return nil + } + defer tx.unlock() return tx.Query("ROLLBACK").Exec() } @@ -29,6 +36,16 @@ func (tx *Tx) MustRollback() { } } -func (tx *Tx) Query(query string) *Query { - return tx.db.Query(query) +func (tx *Tx) unlock() { + if tx.db != nil { + tx.db.Unlock() + tx.db = nil + } +} + +func (tx *Tx) Query(query string) *Query { + if tx.db == nil { + panic("query was performed on a transaction after Commit or Rollback") + } + return tx.db.query(query) }