Compare commits
3 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 0a177e0b46 | |||
| 82c7f57078 | |||
| 9d5c0bcbb1 |
12
database.go
12
database.go
@@ -2,12 +2,14 @@ package mysqlite
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"sync"
|
||||||
"zombiezen.com/go/sqlite"
|
"zombiezen.com/go/sqlite"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Db holds a connection to a SQLite database.
|
// Db holds a connection to a SQLite database.
|
||||||
type Db struct {
|
type Db struct {
|
||||||
Db *sqlite.Conn
|
Db *sqlite.Conn
|
||||||
|
lock sync.Mutex
|
||||||
}
|
}
|
||||||
|
|
||||||
// OpenDb opens a new connection to a SQLite database.
|
// OpenDb opens a new connection to a SQLite database.
|
||||||
@@ -35,3 +37,11 @@ func (d *Db) MustClose() {
|
|||||||
panic(fmt.Sprintf("error closing db: %v", err))
|
panic(fmt.Sprintf("error closing db: %v", err))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (d *Db) Lock() {
|
||||||
|
d.lock.Lock()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *Db) Unlock() {
|
||||||
|
d.lock.Unlock()
|
||||||
|
}
|
||||||
|
|||||||
26
migrator.go
26
migrator.go
@@ -50,7 +50,7 @@ func (d *Db) MigrateDb(filesystem ReadDirFileFS, directory string) error {
|
|||||||
return fmt.Errorf("error opening migration script %s: %v", migrationScript, err)
|
return fmt.Errorf("error opening migration script %s: %v", migrationScript, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
err = performSingleMigration(err, d, migrationScript, targetVersion)
|
err = performSingleMigration(d, migrationScript, targetVersion)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -61,21 +61,29 @@ func (d *Db) MigrateDb(filesystem ReadDirFileFS, directory string) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func performSingleMigration(err error, d *Db, migrationScript []byte, targetVersion int) error {
|
func performSingleMigration(d *Db, migrationScript []byte, targetVersion int) error {
|
||||||
|
script := string(migrationScript)
|
||||||
|
// Split script based on semicolon
|
||||||
|
statements := strings.Split(script, ";")
|
||||||
|
|
||||||
tx, err := d.Begin()
|
tx, err := d.Begin()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("error beginning transaction: %v", err)
|
return fmt.Errorf("error beginning transaction: %v", err)
|
||||||
}
|
}
|
||||||
defer tx.MustRollback()
|
defer tx.MustRollback()
|
||||||
|
|
||||||
err = tx.Query(string(migrationScript)).Exec()
|
for _, statement := range statements {
|
||||||
if err != nil {
|
statement = strings.TrimSpace(statement)
|
||||||
return fmt.Errorf("error performing migration: %v", err)
|
err = tx.Query(statement).Exec()
|
||||||
}
|
if err != nil {
|
||||||
|
return fmt.Errorf("error performing migration: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = tx.Query(fmt.Sprintf("PRAGMA user_version = %d", targetVersion)).Exec()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("error updating version: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
err = tx.Query(fmt.Sprintf("PRAGMA user_version = %d", targetVersion)).Exec()
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("error updating version: %v", err)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
err = tx.Commit()
|
err = tx.Commit()
|
||||||
|
|||||||
@@ -17,4 +17,8 @@ func TestDb_MigrateDb(t *testing.T) {
|
|||||||
var count int
|
var count int
|
||||||
db.Query("select count(*) from mydata").MustScanSingle(&count)
|
db.Query("select count(*) from mydata").MustScanSingle(&count)
|
||||||
require.Equal(t, 1, count, "incorrect number of rows in database")
|
require.Equal(t, 1, count, "incorrect number of rows in database")
|
||||||
|
|
||||||
|
count = 0
|
||||||
|
db.Query("select count(*) from multiTable").MustScanSingle(&count)
|
||||||
|
require.Equal(t, 1, count, "incorrect number of rows in database")
|
||||||
}
|
}
|
||||||
|
|||||||
22
query.go
22
query.go
@@ -9,10 +9,20 @@ import (
|
|||||||
|
|
||||||
type Query struct {
|
type Query struct {
|
||||||
stmt *sqlite.Stmt
|
stmt *sqlite.Stmt
|
||||||
err error
|
// Reference to the database. If set, it is assumed that a lock was taken
|
||||||
|
// by the query that should be freed by the query.
|
||||||
|
db *Db
|
||||||
|
err error
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *Db) Query(query string) *Query {
|
func (d *Db) Query(query string) *Query {
|
||||||
|
d.Lock()
|
||||||
|
q := d.query(query)
|
||||||
|
q.db = d
|
||||||
|
return q
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *Db) query(query string) *Query {
|
||||||
stmt, remaining, err := d.Db.PrepareTransient(query)
|
stmt, remaining, err := d.Db.PrepareTransient(query)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return &Query{err: err}
|
return &Query{err: err}
|
||||||
@@ -45,6 +55,8 @@ func (q *Query) Bind(args ...any) *Query {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (q *Query) Exec() (rerr error) {
|
func (q *Query) Exec() (rerr error) {
|
||||||
|
defer q.unlock()
|
||||||
|
|
||||||
if q.stmt != nil {
|
if q.stmt != nil {
|
||||||
defer func() { rerr = q.stmt.Finalize() }()
|
defer func() { rerr = q.stmt.Finalize() }()
|
||||||
}
|
}
|
||||||
@@ -69,6 +81,7 @@ func (q *Query) MustExec() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (q *Query) ScanSingle(results ...any) (rerr error) {
|
func (q *Query) ScanSingle(results ...any) (rerr error) {
|
||||||
|
defer q.unlock()
|
||||||
// Scan rows
|
// Scan rows
|
||||||
if q.stmt != nil {
|
if q.stmt != nil {
|
||||||
defer func() { rerr = q.stmt.Finalize() }()
|
defer func() { rerr = q.stmt.Finalize() }()
|
||||||
@@ -114,6 +127,12 @@ func (q *Query) MustScanSingle(results ...any) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (q *Query) unlock() {
|
||||||
|
if q.db != nil {
|
||||||
|
q.db.Unlock()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
type Rows struct {
|
type Rows struct {
|
||||||
query *Query
|
query *Query
|
||||||
}
|
}
|
||||||
@@ -125,6 +144,7 @@ func (q *Query) ScanMulti() (*Rows, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (r *Rows) Finish() error {
|
func (r *Rows) Finish() error {
|
||||||
|
defer r.query.unlock()
|
||||||
return r.query.stmt.Finalize()
|
return r.query.stmt.Finalize()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
3
testMigrations/3_multicomment.sql
Normal file
3
testMigrations/3_multicomment.sql
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
create table multiTable(value text);
|
||||||
|
|
||||||
|
insert into multiTable(value) values ('testValue');
|
||||||
@@ -7,7 +7,8 @@ type Tx struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (d *Db) Begin() (*Tx, error) {
|
func (d *Db) Begin() (*Tx, error) {
|
||||||
err := d.Query("BEGIN").Exec()
|
d.Lock()
|
||||||
|
err := d.query("BEGIN").Exec()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -15,10 +16,16 @@ func (d *Db) Begin() (*Tx, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (tx *Tx) Commit() error {
|
func (tx *Tx) Commit() error {
|
||||||
|
defer tx.unlock()
|
||||||
return tx.Query("COMMIT").Exec()
|
return tx.Query("COMMIT").Exec()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (tx *Tx) Rollback() error {
|
func (tx *Tx) Rollback() error {
|
||||||
|
if tx.db == nil {
|
||||||
|
// The transaction was already commited
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
defer tx.unlock()
|
||||||
return tx.Query("ROLLBACK").Exec()
|
return tx.Query("ROLLBACK").Exec()
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -29,6 +36,16 @@ func (tx *Tx) MustRollback() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (tx *Tx) Query(query string) *Query {
|
func (tx *Tx) unlock() {
|
||||||
return tx.db.Query(query)
|
if tx.db != nil {
|
||||||
|
tx.db.Unlock()
|
||||||
|
tx.db = nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (tx *Tx) Query(query string) *Query {
|
||||||
|
if tx.db == nil {
|
||||||
|
panic("query was performed on a transaction after Commit or Rollback")
|
||||||
|
}
|
||||||
|
return tx.db.query(query)
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user