Compare commits

..

No commits in common. "master" and "v0.3.0" have entirely different histories.

11 changed files with 35 additions and 466 deletions

View File

@ -7,10 +7,5 @@ jobs:
- name: Checkout - name: Checkout
uses: actions/checkout@v4 uses: actions/checkout@v4
- name: Setup Go
uses: actions/setup-go@v5
with:
go-version: '>=1.24'
- name: Test - name: Test
run: go test . -v run: go test . -v

View File

@ -1,92 +0,0 @@
# MySQLite
A Go library that provides a convenient wrapper around SQLite with additional functionality for database management, migrations, and transactions.
## Features
- Simple and intuitive SQLite database connection management
- Thread-safe database operations with built-in locking mechanism
- Support for database migrations
- Transaction management
- Built on top of [zombiezen.com/go/sqlite](https://pkg.go.dev/zombiezen.com/go/sqlite)
## Installation
```bash
go get gitea.seeseepuff.be/seeseemelk/mysqlite
```
## Usage
### Opening a Database Connection
```go
import "gitea.seeseepuff.be/seeseemelk/mysqlite"
// Open an in-memory database
db, err := mysqlite.OpenDb(":memory:")
if err != nil {
// Handle error
}
defer db.Close()
// Open a file-based database
db, err := mysqlite.OpenDb("path/to/database.db")
if err != nil {
// Handle error
}
defer db.Close()
```
### Executing Queries
The library provides methods for executing SQL queries and managing transactions:
```go
// Execute a simple query
err := db.Query("CREATE TABLE users (id INTEGER PRIMARY KEY, name TEXT)").Exec()
// Use transactions
tx, err := db.BeginTransaction()
if err != nil {
// Handle error
}
// Perform operations within transaction
// ...
// Commit or rollback
err = tx.Commit() // or tx.Rollback()
```
### Database Migrations
The library includes support for SQL-based migrations. Migrations are SQL files stored in a directory and are executed in order based on their filename prefix:
1. Create a directory for your migrations (e.g., `migrations/`)
2. Add numbered SQL migration files:
```
migrations/
├── 1_initial.sql
├── 2_add_users.sql
├── 3_add_posts.sql
```
3. Embed the migrations in your Go code:
```go
import "embed"
//go:embed migrations/*.sql
var migrations embed.FS
// Apply migrations
err := db.MigrateDb(migrations, "migrations")
if err != nil {
// Handle error
}
```
Each migration file should contain valid SQL statements. The migrations are executed in order and are tracked internally to ensure they only run once.
## Requirements
- Go 1.24 or higher

View File

@ -2,15 +2,12 @@ package mysqlite
import ( import (
"fmt" "fmt"
"sync"
"zombiezen.com/go/sqlite" "zombiezen.com/go/sqlite"
) )
// Db holds a connection to a SQLite database. // Db holds a connection to a SQLite database.
type Db struct { type Db struct {
Db *sqlite.Conn Db *sqlite.Conn
source string
lock sync.Mutex
} }
// OpenDb opens a new connection to a SQLite database. // OpenDb opens a new connection to a SQLite database.
@ -22,7 +19,7 @@ func OpenDb(databaseSource string) (*Db, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
return &Db{Db: conn, source: databaseSource}, nil return &Db{Db: conn}, nil
} }
// Close closes the database. // Close closes the database.
@ -38,11 +35,3 @@ func (d *Db) MustClose() {
panic(fmt.Sprintf("error closing db: %v", err)) panic(fmt.Sprintf("error closing db: %v", err))
} }
} }
func (d *Db) Lock() {
d.lock.Lock()
}
func (d *Db) Unlock() {
d.lock.Unlock()
}

View File

@ -1,5 +0,0 @@
package mysqlite
import "errors"
var ErrNoRows = errors.New("mysqlite: no rows returned")

View File

@ -4,7 +4,6 @@ import (
"fmt" "fmt"
"io/fs" "io/fs"
"log" "log"
"os"
"path" "path"
"strconv" "strconv"
"strings" "strings"
@ -41,20 +40,6 @@ func (d *Db) MigrateDb(filesystem ReadDirFileFS, directory string) error {
} }
log.Printf("Current version is %d, max migration version is %d", currentVersion, latestVersion) log.Printf("Current version is %d, max migration version is %d", currentVersion, latestVersion)
// Create a backup if we're not on the latest version
if currentVersion != 0 && currentVersion != latestVersion && d.source != ":memory:" {
target := d.source + ".backup." + strconv.Itoa(currentVersion)
log.Printf("Creating backup of database to %s", target)
data, err := d.Db.Serialize("main")
if err != nil {
return fmt.Errorf("error serializing database: %v", err)
}
err = os.WriteFile(target, data, 0644)
if err != nil {
return fmt.Errorf("error writing backup: %v", err)
}
}
// If we are no up-to-date, bring the db up-to-date // If we are no up-to-date, bring the db up-to-date
for currentVersion != latestVersion { for currentVersion != latestVersion {
targetVersion := currentVersion + 1 targetVersion := currentVersion + 1
@ -65,7 +50,7 @@ func (d *Db) MigrateDb(filesystem ReadDirFileFS, directory string) error {
return fmt.Errorf("error opening migration script %s: %v", migrationScript, err) return fmt.Errorf("error opening migration script %s: %v", migrationScript, err)
} }
err = performSingleMigration(d, migrationScript, targetVersion) err = performSingleMigration(err, d, migrationScript, targetVersion)
if err != nil { if err != nil {
return err return err
} }
@ -76,32 +61,21 @@ func (d *Db) MigrateDb(filesystem ReadDirFileFS, directory string) error {
return nil return nil
} }
func performSingleMigration(d *Db, migrationScript []byte, targetVersion int) error { func performSingleMigration(err error, d *Db, migrationScript []byte, targetVersion int) error {
script := string(migrationScript)
// Split script based on semicolon
statements := strings.Split(script, ";")
tx, err := d.Begin() tx, err := d.Begin()
if err != nil { if err != nil {
return fmt.Errorf("error beginning transaction: %v", err) return fmt.Errorf("error beginning transaction: %v", err)
} }
defer tx.MustRollback() defer tx.MustRollback()
for _, statement := range statements { err = tx.Query(string(migrationScript)).Exec()
statement = strings.TrimSpace(statement) if err != nil {
if statement == "" { return fmt.Errorf("error performing migration: %v", err)
continue }
}
err = tx.Query(statement).Exec()
if err != nil {
return fmt.Errorf("error performing migration: %v", err)
}
err = tx.Query(fmt.Sprintf("PRAGMA user_version = %d", targetVersion)).Exec()
if err != nil {
return fmt.Errorf("error updating version: %v", err)
}
err = tx.Query(fmt.Sprintf("PRAGMA user_version = %d", targetVersion)).Exec()
if err != nil {
return fmt.Errorf("error updating version: %v", err)
} }
err = tx.Commit() err = tx.Commit()

View File

@ -17,8 +17,4 @@ func TestDb_MigrateDb(t *testing.T) {
var count int var count int
db.Query("select count(*) from mydata").MustScanSingle(&count) db.Query("select count(*) from mydata").MustScanSingle(&count)
require.Equal(t, 1, count, "incorrect number of rows in database") require.Equal(t, 1, count, "incorrect number of rows in database")
count = 0
db.Query("select count(*) from multiTable").MustScanSingle(&count)
require.Equal(t, 1, count, "incorrect number of rows in database")
} }

121
query.go
View File

@ -9,20 +9,10 @@ import (
type Query struct { type Query struct {
stmt *sqlite.Stmt stmt *sqlite.Stmt
// Reference to the database. If set, it is assumed that a lock was taken err error
// by the query that should be freed by the query.
db *Db
err error
} }
func (d *Db) Query(query string) *Query { func (d *Db) Query(query string) *Query {
d.Lock()
q := d.query(query)
q.db = d
return q
}
func (d *Db) query(query string) *Query {
stmt, remaining, err := d.Db.PrepareTransient(query) stmt, remaining, err := d.Db.PrepareTransient(query)
if err != nil { if err != nil {
return &Query{err: err} return &Query{err: err}
@ -34,59 +24,29 @@ func (d *Db) query(query string) *Query {
} }
func (q *Query) Bind(args ...any) *Query { func (q *Query) Bind(args ...any) *Query {
into := 0
return q.bindInto(&into, args...)
}
func (q *Query) bindInto(into *int, args ...any) *Query {
if q.err != nil || q.stmt == nil { if q.err != nil || q.stmt == nil {
return q return q
} }
for i, arg := range args { for i, arg := range args {
*into++
if arg == nil {
q.stmt.BindNull(*into)
continue
}
v := reflect.ValueOf(arg)
if v.Kind() == reflect.Ptr {
if v.IsNil() {
q.stmt.BindNull(*into)
continue
}
arg = v.Elem().Interface()
}
if asString, ok := arg.(string); ok { if asString, ok := arg.(string); ok {
q.stmt.BindText(*into, asString) q.stmt.BindText(i+1, asString)
} else if asInt, ok := arg.(int); ok { } else if asInt, ok := arg.(int); ok {
q.stmt.BindInt64(*into, int64(asInt)) q.stmt.BindInt64(i+1, int64(asInt))
} else if asInt, ok := arg.(int64); ok { } else if asInt, ok := arg.(int64); ok {
q.stmt.BindInt64(*into, asInt) q.stmt.BindInt64(i+1, asInt)
} else if asBool, ok := arg.(bool); ok { } else if asBool, ok := arg.(bool); ok {
q.stmt.BindBool(*into, asBool) q.stmt.BindBool(i+1, asBool)
} else { } else {
// Check if the argument is a slice or array of any type q.err = fmt.Errorf("unsupported column type %s at index %d", reflect.TypeOf(arg).Name(), i)
v = reflect.ValueOf(arg) return q
if v.Kind() == reflect.Slice || v.Kind() == reflect.Array {
*into--
for i := 0; i < v.Len(); i++ {
q.bindInto(into, v.Index(i).Interface())
}
} else {
*into--
q.err = fmt.Errorf("unsupported column type %s at index %d", reflect.TypeOf(arg).Name(), i)
return q
}
} }
} }
return q return q
} }
func (q *Query) Exec() (rerr error) { func (q *Query) Exec() (rerr error) {
defer q.unlock()
if q.stmt != nil { if q.stmt != nil {
defer func() { forwardError(q.stmt.Finalize(), &rerr) }() defer func() { rerr = q.stmt.Finalize() }()
} }
if q.err != nil { if q.err != nil {
return q.err return q.err
@ -109,10 +69,9 @@ func (q *Query) MustExec() {
} }
func (q *Query) ScanSingle(results ...any) (rerr error) { func (q *Query) ScanSingle(results ...any) (rerr error) {
defer q.unlock()
// Scan rows // Scan rows
if q.stmt != nil { if q.stmt != nil {
defer func() { forwardError(q.stmt.Finalize(), &rerr) }() defer func() { rerr = q.stmt.Finalize() }()
} }
if q.err != nil { if q.err != nil {
return q.err return q.err
@ -128,7 +87,7 @@ func (q *Query) ScanSingle(results ...any) (rerr error) {
return err return err
} }
if !hasResult { if !hasResult {
return ErrNoRows return fmt.Errorf("did not return any rows")
} }
// Scan its columns // Scan its columns
@ -155,27 +114,17 @@ func (q *Query) MustScanSingle(results ...any) {
} }
} }
func (q *Query) unlock() {
if q.db != nil {
q.db.Unlock()
}
}
type Rows struct { type Rows struct {
query *Query query *Query
} }
func (q *Query) ScanMulti() (*Rows, error) { func (q *Query) ScanMulti() (*Rows, error) {
if q.err != nil {
return nil, q.err
}
return &Rows{ return &Rows{
query: q, query: q,
}, nil }, nil
} }
func (r *Rows) Finish() error { func (r *Rows) Finish() error {
defer r.query.unlock()
return r.query.stmt.Finalize() return r.query.stmt.Finalize()
} }
@ -204,49 +153,23 @@ func (r *Rows) MustNext() bool {
func (r *Rows) Scan(results ...any) error { func (r *Rows) Scan(results ...any) error {
for i, arg := range results { for i, arg := range results {
err := r.scanArgument(i, arg) if asString, ok := arg.(*string); ok {
if err != nil { *asString = r.query.stmt.ColumnText(i)
return err } else if asInt, ok := arg.(*int); ok {
*asInt = r.query.stmt.ColumnInt(i)
} else if asBool, ok := arg.(*bool); ok {
*asBool = r.query.stmt.ColumnBool(i)
} else {
if reflect.TypeOf(arg).Kind() != reflect.Ptr {
return fmt.Errorf("unsupported column type %s at index %d (it should be a pointer)", reflect.TypeOf(arg).Name(), i)
}
name := reflect.Indirect(reflect.ValueOf(arg)).Type().Name()
return fmt.Errorf("unsupported column type *%s at index %d", name, i)
} }
} }
return nil return nil
} }
func (r *Rows) scanArgument(i int, arg any) error {
if asString, ok := arg.(*string); ok {
*asString = r.query.stmt.ColumnText(i)
} else if asInt, ok := arg.(*int); ok {
*asInt = r.query.stmt.ColumnInt(i)
} else if asInt, ok := arg.(*int64); ok {
*asInt = r.query.stmt.ColumnInt64(i)
} else if asBool, ok := arg.(*bool); ok {
*asBool = r.query.stmt.ColumnBool(i)
} else if reflect.TypeOf(arg).Kind() == reflect.Ptr && reflect.TypeOf(arg).Elem().Kind() == reflect.Ptr {
return r.handleNullableType(i, arg)
} else {
if reflect.TypeOf(arg).Kind() != reflect.Ptr {
return fmt.Errorf("unsupported column type %s at index %d (it should be a pointer)", reflect.TypeOf(arg).Name(), i)
}
name := reflect.Indirect(reflect.ValueOf(arg)).Type().Name()
return fmt.Errorf("unsupported column type *%s at index %d", name, i)
}
return nil
}
func (r *Rows) handleNullableType(i int, asPtr any) error {
if r.query.stmt.ColumnIsNull(i) {
reflect.ValueOf(asPtr).Elem().Set(reflect.Zero(reflect.TypeOf(asPtr).Elem()))
} else {
value := reflect.New(reflect.TypeOf(asPtr).Elem().Elem()).Interface()
err := r.scanArgument(i, value)
if err != nil {
return err
}
reflect.ValueOf(asPtr).Elem().Set(reflect.ValueOf(value))
}
return nil
}
func (r *Rows) MustScan(results ...any) { func (r *Rows) MustScan(results ...any) {
err := r.Scan(results...) err := r.Scan(results...)
if err != nil { if err != nil {

View File

@ -1,7 +1,6 @@
package mysqlite package mysqlite
import ( import (
"errors"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"testing" "testing"
) )
@ -20,14 +19,6 @@ func TestSimpleQuery(t *testing.T) {
require.Equal(t, 1, count, "expected empty count") require.Equal(t, 1, count, "expected empty count")
} }
func TestSimpleQueryWithNoResults(t *testing.T) {
db := openTestDb(t)
var count int
err := db.Query("select 1 from mytable where key=999").ScanSingle(&count)
require.Equal(t, ErrNoRows, err)
require.True(t, errors.Is(err, ErrNoRows))
}
func TestSimpleQueryWithArgs(t *testing.T) { func TestSimpleQueryWithArgs(t *testing.T) {
db := openTestDb(t) db := openTestDb(t)
var value string var value string
@ -75,163 +66,3 @@ func TestQueryWithRange(t *testing.T) {
} }
require.NoError(t, err) require.NoError(t, err)
} }
func TestUpdateQuery(t *testing.T) {
db := openTestDb(t)
func() {
tx := db.MustBegin()
defer tx.MustRollback()
tx.Query("insert into mytable(key, value) values ('lorem', 'bar')").MustExec()
value := "ipsum"
key := "lorem"
tx.Query("update mytable set value = ? where key = ?").Bind(value, key).MustExec()
tx.MustCommit()
}()
var value string
db.Query("select value from mytable where key = 'lorem'").MustScanSingle(&value)
require.Equal(t, "ipsum", value)
}
func TestUpdateQueryWithWrongArguments(t *testing.T) {
type S struct {
Field string
}
db := openTestDb(t)
abc := S{
Field: "ipsum",
}
err := db.Query("insert into mytable(key, value) values ('lorem', ?)").Bind(abc).Exec()
require.Error(t, err)
}
func TestUpdateQueryWithPointerValue(t *testing.T) {
db := openTestDb(t)
func() {
tx := db.MustBegin()
defer tx.MustRollback()
tx.Query("insert into mytable(key, value) values ('lorem', 'bar')").MustExec()
value := "ipsum"
key := "lorem"
tx.Query("update mytable set value = ? where key = ?").Bind(&value, key).MustExec()
tx.MustCommit()
}()
var value string
db.Query("select value from mytable where key = 'lorem'").MustScanSingle(&value)
require.Equal(t, "ipsum", value)
}
func TestUpdateQueryWithSetPointerValue(t *testing.T) {
type S struct {
value *string
}
db := openTestDb(t)
func() {
tx := db.MustBegin()
defer tx.MustRollback()
tx.Query("insert into mytable(key, value) values ('lorem', 'bar')").MustExec()
s := S{nil}
key := "lorem"
tx.Query("update mytable set value = ? where key = ?").Bind(s.value, key).MustExec()
tx.MustCommit()
}()
var value *string
db.Query("select value from mytable where key = 'lorem'").MustScanSingle(&value)
require.Equal(t, (*string)(nil), value)
}
func TestUpdateQueryWithNullValue(t *testing.T) {
db := openTestDb(t)
func() {
tx := db.MustBegin()
defer tx.MustRollback()
tx.Query("insert into mytable(key, value) values ('lorem', 'bar')").MustExec()
key := "lorem"
tx.Query("update mytable set value = ? where key = ?").Bind(nil, key).MustExec()
tx.MustCommit()
}()
var value *string
db.Query("select value from mytable where key = 'lorem'").MustScanSingle(&value)
require.Nil(t, value)
}
func TestQueryWithPointerStringArguments(t *testing.T) {
db := openTestDb(t)
var result *string
err := db.Query("select value from mytable where key = 'foo'").ScanSingle(&result)
require.NoError(t, err)
require.NotNil(t, result)
require.Equal(t, "bar", *result)
}
func TestQueryWithInt64Scan(t *testing.T) {
db := openTestDb(t)
var result int64
err := db.Query("select 2").ScanSingle(&result)
require.NoError(t, err)
require.NotNil(t, result)
require.Equal(t, int64(2), result)
}
func TestQueryWithPointerStringArgumentsCanSetToNull(t *testing.T) {
db := openTestDb(t)
db.Query("update mytable set value=null where key = 'foo'").MustExec()
myString := "some string"
var result *string
result = &myString
err := db.Query("select value from mytable where key = 'foo'").ScanSingle(&result)
require.NoError(t, err)
require.Nil(t, result)
}
func TestDeleteQuery(t *testing.T) {
db := openTestDb(t)
db.Query("delete from mytable where key = 'foo'").MustExec()
var count int
db.Query("select count(*) from mytable where key = 'foo'").MustScanSingle(&count)
require.Equal(t, 0, count, "expected row to be deleted")
}
func TestTransactionRollback(t *testing.T) {
db := openTestDb(t)
func() {
tx := db.MustBegin()
defer tx.MustRollback()
tx.Query("update mytable set value = 'ipsum' where key = 'foo'").MustExec()
// Intentionally not committing the transaction
}()
var value string
db.Query("select value from mytable where key = 'foo'").MustScanSingle(&value)
require.Equal(t, "bar", value, "expected original value after rollback")
}
func TestQueryWithInClause(t *testing.T) {
db := openTestDb(t)
// Insert additional test rows
db.Query("insert into mytable(key, value) values ('key1', 'value1')").MustExec()
db.Query("insert into mytable(key, value) values ('key2', 'value2')").MustExec()
// Execute query with IN clause
args := []string{"foo", "key2"}
rows, err := db.Query("select key, value from mytable where key in (?, ?)").Bind(args).ScanMulti()
require.NoError(t, err)
defer rows.MustFinish()
// Check results
results := make(map[string]string)
for rows.MustNext() {
var key, value string
rows.MustScan(&key, &value)
results[key] = value
}
// Verify we got exactly the expected results
require.Equal(t, 2, len(results), "expected 2 matching rows")
require.Equal(t, "bar", results["foo"], "unexpected value for key 'foo'")
require.Equal(t, "value2", results["key2"], "unexpected value for key 'key2'")
}

View File

@ -1,3 +0,0 @@
create table multiTable(value text);
insert into multiTable(value) values ('testValue');

View File

@ -7,40 +7,18 @@ type Tx struct {
} }
func (d *Db) Begin() (*Tx, error) { func (d *Db) Begin() (*Tx, error) {
d.Lock() err := d.Query("BEGIN").Exec()
err := d.query("BEGIN").Exec()
if err != nil { if err != nil {
return nil, err return nil, err
} }
return &Tx{db: d}, nil return &Tx{db: d}, nil
} }
func (d *Db) MustBegin() *Tx {
tx, err := d.Begin()
if err != nil {
panic(err)
}
return tx
}
func (tx *Tx) Commit() error { func (tx *Tx) Commit() error {
defer tx.unlock()
return tx.Query("COMMIT").Exec() return tx.Query("COMMIT").Exec()
} }
func (tx *Tx) MustCommit() {
err := tx.Commit()
if err != nil {
panic(err)
}
}
func (tx *Tx) Rollback() error { func (tx *Tx) Rollback() error {
if tx.db == nil {
// The transaction was already commited
return nil
}
defer tx.unlock()
return tx.Query("ROLLBACK").Exec() return tx.Query("ROLLBACK").Exec()
} }
@ -51,16 +29,6 @@ func (tx *Tx) MustRollback() {
} }
} }
func (tx *Tx) unlock() {
if tx.db != nil {
tx.db.Unlock()
tx.db = nil
}
}
func (tx *Tx) Query(query string) *Query { func (tx *Tx) Query(query string) *Query {
if tx.db == nil { return tx.db.Query(query)
panic("query was performed on a transaction after Commit or Rollback")
}
return tx.db.query(query)
} }

View File

@ -1,7 +0,0 @@
package mysqlite
func forwardError(from error, to *error) {
if from != nil {
*to = from
}
}