Compare commits

..

No commits in common. "master" and "v0.3.0" have entirely different histories.

11 changed files with 35 additions and 466 deletions

View File

@ -7,10 +7,5 @@ jobs:
- name: Checkout
uses: actions/checkout@v4
- name: Setup Go
uses: actions/setup-go@v5
with:
go-version: '>=1.24'
- name: Test
run: go test . -v

View File

@ -1,92 +0,0 @@
# MySQLite
A Go library that provides a convenient wrapper around SQLite with additional functionality for database management, migrations, and transactions.
## Features
- Simple and intuitive SQLite database connection management
- Thread-safe database operations with built-in locking mechanism
- Support for database migrations
- Transaction management
- Built on top of [zombiezen.com/go/sqlite](https://pkg.go.dev/zombiezen.com/go/sqlite)
## Installation
```bash
go get gitea.seeseepuff.be/seeseemelk/mysqlite
```
## Usage
### Opening a Database Connection
```go
import "gitea.seeseepuff.be/seeseemelk/mysqlite"
// Open an in-memory database
db, err := mysqlite.OpenDb(":memory:")
if err != nil {
// Handle error
}
defer db.Close()
// Open a file-based database
db, err := mysqlite.OpenDb("path/to/database.db")
if err != nil {
// Handle error
}
defer db.Close()
```
### Executing Queries
The library provides methods for executing SQL queries and managing transactions:
```go
// Execute a simple query
err := db.Query("CREATE TABLE users (id INTEGER PRIMARY KEY, name TEXT)").Exec()
// Use transactions
tx, err := db.BeginTransaction()
if err != nil {
// Handle error
}
// Perform operations within transaction
// ...
// Commit or rollback
err = tx.Commit() // or tx.Rollback()
```
### Database Migrations
The library includes support for SQL-based migrations. Migrations are SQL files stored in a directory and are executed in order based on their filename prefix:
1. Create a directory for your migrations (e.g., `migrations/`)
2. Add numbered SQL migration files:
```
migrations/
├── 1_initial.sql
├── 2_add_users.sql
├── 3_add_posts.sql
```
3. Embed the migrations in your Go code:
```go
import "embed"
//go:embed migrations/*.sql
var migrations embed.FS
// Apply migrations
err := db.MigrateDb(migrations, "migrations")
if err != nil {
// Handle error
}
```
Each migration file should contain valid SQL statements. The migrations are executed in order and are tracked internally to ensure they only run once.
## Requirements
- Go 1.24 or higher

View File

@ -2,15 +2,12 @@ package mysqlite
import (
"fmt"
"sync"
"zombiezen.com/go/sqlite"
)
// Db holds a connection to a SQLite database.
type Db struct {
Db *sqlite.Conn
source string
lock sync.Mutex
}
// OpenDb opens a new connection to a SQLite database.
@ -22,7 +19,7 @@ func OpenDb(databaseSource string) (*Db, error) {
if err != nil {
return nil, err
}
return &Db{Db: conn, source: databaseSource}, nil
return &Db{Db: conn}, nil
}
// Close closes the database.
@ -38,11 +35,3 @@ func (d *Db) MustClose() {
panic(fmt.Sprintf("error closing db: %v", err))
}
}
func (d *Db) Lock() {
d.lock.Lock()
}
func (d *Db) Unlock() {
d.lock.Unlock()
}

View File

@ -1,5 +0,0 @@
package mysqlite
import "errors"
var ErrNoRows = errors.New("mysqlite: no rows returned")

View File

@ -4,7 +4,6 @@ import (
"fmt"
"io/fs"
"log"
"os"
"path"
"strconv"
"strings"
@ -41,20 +40,6 @@ func (d *Db) MigrateDb(filesystem ReadDirFileFS, directory string) error {
}
log.Printf("Current version is %d, max migration version is %d", currentVersion, latestVersion)
// Create a backup if we're not on the latest version
if currentVersion != 0 && currentVersion != latestVersion && d.source != ":memory:" {
target := d.source + ".backup." + strconv.Itoa(currentVersion)
log.Printf("Creating backup of database to %s", target)
data, err := d.Db.Serialize("main")
if err != nil {
return fmt.Errorf("error serializing database: %v", err)
}
err = os.WriteFile(target, data, 0644)
if err != nil {
return fmt.Errorf("error writing backup: %v", err)
}
}
// If we are no up-to-date, bring the db up-to-date
for currentVersion != latestVersion {
targetVersion := currentVersion + 1
@ -65,7 +50,7 @@ func (d *Db) MigrateDb(filesystem ReadDirFileFS, directory string) error {
return fmt.Errorf("error opening migration script %s: %v", migrationScript, err)
}
err = performSingleMigration(d, migrationScript, targetVersion)
err = performSingleMigration(err, d, migrationScript, targetVersion)
if err != nil {
return err
}
@ -76,23 +61,14 @@ func (d *Db) MigrateDb(filesystem ReadDirFileFS, directory string) error {
return nil
}
func performSingleMigration(d *Db, migrationScript []byte, targetVersion int) error {
script := string(migrationScript)
// Split script based on semicolon
statements := strings.Split(script, ";")
func performSingleMigration(err error, d *Db, migrationScript []byte, targetVersion int) error {
tx, err := d.Begin()
if err != nil {
return fmt.Errorf("error beginning transaction: %v", err)
}
defer tx.MustRollback()
for _, statement := range statements {
statement = strings.TrimSpace(statement)
if statement == "" {
continue
}
err = tx.Query(statement).Exec()
err = tx.Query(string(migrationScript)).Exec()
if err != nil {
return fmt.Errorf("error performing migration: %v", err)
}
@ -102,8 +78,6 @@ func performSingleMigration(d *Db, migrationScript []byte, targetVersion int) er
return fmt.Errorf("error updating version: %v", err)
}
}
err = tx.Commit()
if err != nil {
return fmt.Errorf("error commiting transaction: %v", err)

View File

@ -17,8 +17,4 @@ func TestDb_MigrateDb(t *testing.T) {
var count int
db.Query("select count(*) from mydata").MustScanSingle(&count)
require.Equal(t, 1, count, "incorrect number of rows in database")
count = 0
db.Query("select count(*) from multiTable").MustScanSingle(&count)
require.Equal(t, 1, count, "incorrect number of rows in database")
}

View File

@ -9,20 +9,10 @@ import (
type Query struct {
stmt *sqlite.Stmt
// Reference to the database. If set, it is assumed that a lock was taken
// by the query that should be freed by the query.
db *Db
err error
}
func (d *Db) Query(query string) *Query {
d.Lock()
q := d.query(query)
q.db = d
return q
}
func (d *Db) query(query string) *Query {
stmt, remaining, err := d.Db.PrepareTransient(query)
if err != nil {
return &Query{err: err}
@ -34,59 +24,29 @@ func (d *Db) query(query string) *Query {
}
func (q *Query) Bind(args ...any) *Query {
into := 0
return q.bindInto(&into, args...)
}
func (q *Query) bindInto(into *int, args ...any) *Query {
if q.err != nil || q.stmt == nil {
return q
}
for i, arg := range args {
*into++
if arg == nil {
q.stmt.BindNull(*into)
continue
}
v := reflect.ValueOf(arg)
if v.Kind() == reflect.Ptr {
if v.IsNil() {
q.stmt.BindNull(*into)
continue
}
arg = v.Elem().Interface()
}
if asString, ok := arg.(string); ok {
q.stmt.BindText(*into, asString)
q.stmt.BindText(i+1, asString)
} else if asInt, ok := arg.(int); ok {
q.stmt.BindInt64(*into, int64(asInt))
q.stmt.BindInt64(i+1, int64(asInt))
} else if asInt, ok := arg.(int64); ok {
q.stmt.BindInt64(*into, asInt)
q.stmt.BindInt64(i+1, asInt)
} else if asBool, ok := arg.(bool); ok {
q.stmt.BindBool(*into, asBool)
q.stmt.BindBool(i+1, asBool)
} else {
// Check if the argument is a slice or array of any type
v = reflect.ValueOf(arg)
if v.Kind() == reflect.Slice || v.Kind() == reflect.Array {
*into--
for i := 0; i < v.Len(); i++ {
q.bindInto(into, v.Index(i).Interface())
}
} else {
*into--
q.err = fmt.Errorf("unsupported column type %s at index %d", reflect.TypeOf(arg).Name(), i)
return q
}
}
}
return q
}
func (q *Query) Exec() (rerr error) {
defer q.unlock()
if q.stmt != nil {
defer func() { forwardError(q.stmt.Finalize(), &rerr) }()
defer func() { rerr = q.stmt.Finalize() }()
}
if q.err != nil {
return q.err
@ -109,10 +69,9 @@ func (q *Query) MustExec() {
}
func (q *Query) ScanSingle(results ...any) (rerr error) {
defer q.unlock()
// Scan rows
if q.stmt != nil {
defer func() { forwardError(q.stmt.Finalize(), &rerr) }()
defer func() { rerr = q.stmt.Finalize() }()
}
if q.err != nil {
return q.err
@ -128,7 +87,7 @@ func (q *Query) ScanSingle(results ...any) (rerr error) {
return err
}
if !hasResult {
return ErrNoRows
return fmt.Errorf("did not return any rows")
}
// Scan its columns
@ -155,27 +114,17 @@ func (q *Query) MustScanSingle(results ...any) {
}
}
func (q *Query) unlock() {
if q.db != nil {
q.db.Unlock()
}
}
type Rows struct {
query *Query
}
func (q *Query) ScanMulti() (*Rows, error) {
if q.err != nil {
return nil, q.err
}
return &Rows{
query: q,
}, nil
}
func (r *Rows) Finish() error {
defer r.query.unlock()
return r.query.stmt.Finalize()
}
@ -204,25 +153,12 @@ func (r *Rows) MustNext() bool {
func (r *Rows) Scan(results ...any) error {
for i, arg := range results {
err := r.scanArgument(i, arg)
if err != nil {
return err
}
}
return nil
}
func (r *Rows) scanArgument(i int, arg any) error {
if asString, ok := arg.(*string); ok {
*asString = r.query.stmt.ColumnText(i)
} else if asInt, ok := arg.(*int); ok {
*asInt = r.query.stmt.ColumnInt(i)
} else if asInt, ok := arg.(*int64); ok {
*asInt = r.query.stmt.ColumnInt64(i)
} else if asBool, ok := arg.(*bool); ok {
*asBool = r.query.stmt.ColumnBool(i)
} else if reflect.TypeOf(arg).Kind() == reflect.Ptr && reflect.TypeOf(arg).Elem().Kind() == reflect.Ptr {
return r.handleNullableType(i, arg)
} else {
if reflect.TypeOf(arg).Kind() != reflect.Ptr {
return fmt.Errorf("unsupported column type %s at index %d (it should be a pointer)", reflect.TypeOf(arg).Name(), i)
@ -230,19 +166,6 @@ func (r *Rows) scanArgument(i int, arg any) error {
name := reflect.Indirect(reflect.ValueOf(arg)).Type().Name()
return fmt.Errorf("unsupported column type *%s at index %d", name, i)
}
return nil
}
func (r *Rows) handleNullableType(i int, asPtr any) error {
if r.query.stmt.ColumnIsNull(i) {
reflect.ValueOf(asPtr).Elem().Set(reflect.Zero(reflect.TypeOf(asPtr).Elem()))
} else {
value := reflect.New(reflect.TypeOf(asPtr).Elem().Elem()).Interface()
err := r.scanArgument(i, value)
if err != nil {
return err
}
reflect.ValueOf(asPtr).Elem().Set(reflect.ValueOf(value))
}
return nil
}

View File

@ -1,7 +1,6 @@
package mysqlite
import (
"errors"
"github.com/stretchr/testify/require"
"testing"
)
@ -20,14 +19,6 @@ func TestSimpleQuery(t *testing.T) {
require.Equal(t, 1, count, "expected empty count")
}
func TestSimpleQueryWithNoResults(t *testing.T) {
db := openTestDb(t)
var count int
err := db.Query("select 1 from mytable where key=999").ScanSingle(&count)
require.Equal(t, ErrNoRows, err)
require.True(t, errors.Is(err, ErrNoRows))
}
func TestSimpleQueryWithArgs(t *testing.T) {
db := openTestDb(t)
var value string
@ -75,163 +66,3 @@ func TestQueryWithRange(t *testing.T) {
}
require.NoError(t, err)
}
func TestUpdateQuery(t *testing.T) {
db := openTestDb(t)
func() {
tx := db.MustBegin()
defer tx.MustRollback()
tx.Query("insert into mytable(key, value) values ('lorem', 'bar')").MustExec()
value := "ipsum"
key := "lorem"
tx.Query("update mytable set value = ? where key = ?").Bind(value, key).MustExec()
tx.MustCommit()
}()
var value string
db.Query("select value from mytable where key = 'lorem'").MustScanSingle(&value)
require.Equal(t, "ipsum", value)
}
func TestUpdateQueryWithWrongArguments(t *testing.T) {
type S struct {
Field string
}
db := openTestDb(t)
abc := S{
Field: "ipsum",
}
err := db.Query("insert into mytable(key, value) values ('lorem', ?)").Bind(abc).Exec()
require.Error(t, err)
}
func TestUpdateQueryWithPointerValue(t *testing.T) {
db := openTestDb(t)
func() {
tx := db.MustBegin()
defer tx.MustRollback()
tx.Query("insert into mytable(key, value) values ('lorem', 'bar')").MustExec()
value := "ipsum"
key := "lorem"
tx.Query("update mytable set value = ? where key = ?").Bind(&value, key).MustExec()
tx.MustCommit()
}()
var value string
db.Query("select value from mytable where key = 'lorem'").MustScanSingle(&value)
require.Equal(t, "ipsum", value)
}
func TestUpdateQueryWithSetPointerValue(t *testing.T) {
type S struct {
value *string
}
db := openTestDb(t)
func() {
tx := db.MustBegin()
defer tx.MustRollback()
tx.Query("insert into mytable(key, value) values ('lorem', 'bar')").MustExec()
s := S{nil}
key := "lorem"
tx.Query("update mytable set value = ? where key = ?").Bind(s.value, key).MustExec()
tx.MustCommit()
}()
var value *string
db.Query("select value from mytable where key = 'lorem'").MustScanSingle(&value)
require.Equal(t, (*string)(nil), value)
}
func TestUpdateQueryWithNullValue(t *testing.T) {
db := openTestDb(t)
func() {
tx := db.MustBegin()
defer tx.MustRollback()
tx.Query("insert into mytable(key, value) values ('lorem', 'bar')").MustExec()
key := "lorem"
tx.Query("update mytable set value = ? where key = ?").Bind(nil, key).MustExec()
tx.MustCommit()
}()
var value *string
db.Query("select value from mytable where key = 'lorem'").MustScanSingle(&value)
require.Nil(t, value)
}
func TestQueryWithPointerStringArguments(t *testing.T) {
db := openTestDb(t)
var result *string
err := db.Query("select value from mytable where key = 'foo'").ScanSingle(&result)
require.NoError(t, err)
require.NotNil(t, result)
require.Equal(t, "bar", *result)
}
func TestQueryWithInt64Scan(t *testing.T) {
db := openTestDb(t)
var result int64
err := db.Query("select 2").ScanSingle(&result)
require.NoError(t, err)
require.NotNil(t, result)
require.Equal(t, int64(2), result)
}
func TestQueryWithPointerStringArgumentsCanSetToNull(t *testing.T) {
db := openTestDb(t)
db.Query("update mytable set value=null where key = 'foo'").MustExec()
myString := "some string"
var result *string
result = &myString
err := db.Query("select value from mytable where key = 'foo'").ScanSingle(&result)
require.NoError(t, err)
require.Nil(t, result)
}
func TestDeleteQuery(t *testing.T) {
db := openTestDb(t)
db.Query("delete from mytable where key = 'foo'").MustExec()
var count int
db.Query("select count(*) from mytable where key = 'foo'").MustScanSingle(&count)
require.Equal(t, 0, count, "expected row to be deleted")
}
func TestTransactionRollback(t *testing.T) {
db := openTestDb(t)
func() {
tx := db.MustBegin()
defer tx.MustRollback()
tx.Query("update mytable set value = 'ipsum' where key = 'foo'").MustExec()
// Intentionally not committing the transaction
}()
var value string
db.Query("select value from mytable where key = 'foo'").MustScanSingle(&value)
require.Equal(t, "bar", value, "expected original value after rollback")
}
func TestQueryWithInClause(t *testing.T) {
db := openTestDb(t)
// Insert additional test rows
db.Query("insert into mytable(key, value) values ('key1', 'value1')").MustExec()
db.Query("insert into mytable(key, value) values ('key2', 'value2')").MustExec()
// Execute query with IN clause
args := []string{"foo", "key2"}
rows, err := db.Query("select key, value from mytable where key in (?, ?)").Bind(args).ScanMulti()
require.NoError(t, err)
defer rows.MustFinish()
// Check results
results := make(map[string]string)
for rows.MustNext() {
var key, value string
rows.MustScan(&key, &value)
results[key] = value
}
// Verify we got exactly the expected results
require.Equal(t, 2, len(results), "expected 2 matching rows")
require.Equal(t, "bar", results["foo"], "unexpected value for key 'foo'")
require.Equal(t, "value2", results["key2"], "unexpected value for key 'key2'")
}

View File

@ -1,3 +0,0 @@
create table multiTable(value text);
insert into multiTable(value) values ('testValue');

View File

@ -7,40 +7,18 @@ type Tx struct {
}
func (d *Db) Begin() (*Tx, error) {
d.Lock()
err := d.query("BEGIN").Exec()
err := d.Query("BEGIN").Exec()
if err != nil {
return nil, err
}
return &Tx{db: d}, nil
}
func (d *Db) MustBegin() *Tx {
tx, err := d.Begin()
if err != nil {
panic(err)
}
return tx
}
func (tx *Tx) Commit() error {
defer tx.unlock()
return tx.Query("COMMIT").Exec()
}
func (tx *Tx) MustCommit() {
err := tx.Commit()
if err != nil {
panic(err)
}
}
func (tx *Tx) Rollback() error {
if tx.db == nil {
// The transaction was already commited
return nil
}
defer tx.unlock()
return tx.Query("ROLLBACK").Exec()
}
@ -51,16 +29,6 @@ func (tx *Tx) MustRollback() {
}
}
func (tx *Tx) unlock() {
if tx.db != nil {
tx.db.Unlock()
tx.db = nil
}
}
func (tx *Tx) Query(query string) *Query {
if tx.db == nil {
panic("query was performed on a transaction after Commit or Rollback")
}
return tx.db.query(query)
return tx.db.Query(query)
}

View File

@ -1,7 +0,0 @@
package mysqlite
func forwardError(from error, to *error) {
if from != nil {
*to = from
}
}