From d48a828d2a44dfb03abeb8b33294feb0eccd4639 Mon Sep 17 00:00:00 2001 From: Sebastiaan de Schaetzen Date: Tue, 18 Feb 2025 12:53:51 +0100 Subject: [PATCH] Possible first version --- .gitea/workflows/build.yml | 11 +++ database.go | 33 +------ database_test.go | 42 +++++++++ go.mod | 9 +- go.sum | 34 ++++++++ migrator.go | 170 ++++++++++++++++++------------------- query.go | 46 ++++++++-- query_test.go | 27 ++++++ transaction.go | 4 +- 9 files changed, 253 insertions(+), 123 deletions(-) create mode 100644 .gitea/workflows/build.yml create mode 100644 database_test.go create mode 100644 query_test.go diff --git a/.gitea/workflows/build.yml b/.gitea/workflows/build.yml new file mode 100644 index 0000000..d370aa1 --- /dev/null +++ b/.gitea/workflows/build.yml @@ -0,0 +1,11 @@ +name: Build +on: [push] +jobs: + build: + runs-on: standard-latest + steps: + - name: Checkout + uses: actions/checkout@v4 + + - name: Test + run: go test . -v diff --git a/database.go b/database.go index 2fdf969..5629287 100644 --- a/database.go +++ b/database.go @@ -2,7 +2,6 @@ package mysqlite import ( "fmt" - "reflect" "zombiezen.com/go/sqlite" ) @@ -22,35 +21,9 @@ func (d *Db) Close() error { return d.Db.Close() } -func (d *Db) QuerySingle(query string, args ...any) error { - stmt, remaining, err := d.Db.PrepareTransient(query) +func (d *Db) MustClose() { + err := d.Close() if err != nil { - return err + panic(fmt.Sprintf("error closing db: %v", err)) } - defer stmt.Finalize() - if remaining != 0 { - return fmt.Errorf("remaining bytes: %s", remaining) - } - rowReturned, err := stmt.Step() - if err != nil { - return err - } - if !rowReturned { - return fmt.Errorf("did not return any rows") - } - if stmt.ColumnCount() != 1 { - return fmt.Errorf("query returned %d rows while only one was expected", stmt.ColumnCount()) - } - for i, arg := range args { - if asString, ok := arg.(*string); ok { - *asString = stmt.ColumnText(i) - } else if asInt, ok := arg.(*int); ok { - *asInt = stmt.ColumnInt(i) - } else if asBool, ok := arg.(*bool); ok { - *asBool = stmt.ColumnBool(i) - } else { - return fmt.Errorf("unsupported column type at index %d", i) - } - } - return nil } diff --git a/database_test.go b/database_test.go new file mode 100644 index 0000000..27088a2 --- /dev/null +++ b/database_test.go @@ -0,0 +1,42 @@ +package mysqlite + +import ( + "fmt" + "github.com/stretchr/testify/require" + "testing" +) + +func openEmptyTestDb(t *testing.T) *Db { + db, err := OpenDb(":memory:") + require.NoError(t, err, "error opening db") + require.NotNil(t, db, "db was nil") + return db +} + +func TestDb_Begin(t *testing.T) { + db := openEmptyTestDb(t) + err := db.Close() + require.NoError(t, err, "error closing db") +} + +func TestDb_GetUserVersion(t *testing.T) { + db := openEmptyTestDb(t) + defer db.MustClose() + userVersion := -1 + err := db.Query("pragma user_version").ScanSingle(&userVersion) + require.NoError(t, err) + require.Equal(t, 0, userVersion) +} + +func TestDb_SetUserVersion(t *testing.T) { + db := openEmptyTestDb(t) + defer db.MustClose() + + err := db.Query(fmt.Sprintf("pragma user_version = %d", 123)).Exec() + require.NoError(t, err) + + userVersion := -1 + err = db.Query("pragma user_version").ScanSingle(&userVersion) + require.NoError(t, err) + require.Equal(t, 123, userVersion) +} diff --git a/go.mod b/go.mod index 65b8d4d..4069648 100644 --- a/go.mod +++ b/go.mod @@ -3,15 +3,22 @@ module mysqlite go 1.24 require ( + github.com/stretchr/testify v1.10.0 + zombiezen.com/go/sqlite v1.4.0 +) + +require ( + github.com/davecgh/go-spew v1.1.1 // indirect github.com/dustin/go-humanize v1.0.1 // indirect github.com/google/uuid v1.6.0 // indirect github.com/mattn/go-isatty v0.0.20 // indirect github.com/ncruces/go-strftime v0.1.9 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect golang.org/x/sys v0.22.0 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect modernc.org/libc v1.55.3 // indirect modernc.org/mathutil v1.6.0 // indirect modernc.org/memory v1.8.0 // indirect modernc.org/sqlite v1.33.1 // indirect - zombiezen.com/go/sqlite v1.4.0 // indirect ) diff --git a/go.sum b/go.sum index 11433c6..1e382d6 100644 --- a/go.sum +++ b/go.sum @@ -1,23 +1,57 @@ +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/dustin/go-humanize v1.0.1 h1:GzkhY7T5VNhEkwH0PVJgjz+fX1rhBrR7pRT3mDkpeCY= github.com/dustin/go-humanize v1.0.1/go.mod h1:Mu1zIs6XwVuF/gI1OepvI0qD18qycQx+mFykh5fBlto= +github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38= +github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/mattn/go-isatty v0.0.20 h1:xfD0iDuEKnDkl03q4limB+vH+GxLEtL/jb4xVJSWWEY= github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/ncruces/go-strftime v0.1.9 h1:bY0MQC28UADQmHmaF5dgpLmImcShSi2kHU9XLdhx/f4= github.com/ncruces/go-strftime v0.1.9/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec h1:W09IVJc94icq4NjY3clb7Lk8O1qJ8BdBEF8z0ibU0rE= github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo= +github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= +github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= +golang.org/x/mod v0.16.0 h1:QX4fJ0Rr5cPQCF7O9lh9Se4pmwfwskqZfq5moyldzic= +golang.org/x/mod v0.16.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.22.0 h1:RI27ohtqKCnwULzJLqkv897zojh5/DwS/ENaMzUOaWI= golang.org/x/sys v0.22.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ= +golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= +golang.org/x/tools v0.19.0 h1:tfGCXNR1OsFG+sVdLAitlpjAvD/I6dHDKnYrpEZUHkw= +golang.org/x/tools v0.19.0/go.mod h1:qoJWxmGSIBmAeriMx19ogtrEPrGtDbPK634QFIcLAhc= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +modernc.org/cc/v4 v4.21.4 h1:3Be/Rdo1fpr8GrQ7IVw9OHtplU4gWbb+wNgeoBMmGLQ= +modernc.org/cc/v4 v4.21.4/go.mod h1:HM7VJTZbUCR3rV8EYBi9wxnJ0ZBRiGE5OeGXNA0IsLQ= +modernc.org/ccgo/v4 v4.19.2 h1:lwQZgvboKD0jBwdaeVCTouxhxAyN6iawF3STraAal8Y= +modernc.org/ccgo/v4 v4.19.2/go.mod h1:ysS3mxiMV38XGRTTcgo0DQTeTmAO4oCmJl1nX9VFI3s= +modernc.org/fileutil v1.3.0 h1:gQ5SIzK3H9kdfai/5x41oQiKValumqNTDXMvKo62HvE= +modernc.org/fileutil v1.3.0/go.mod h1:XatxS8fZi3pS8/hKG2GH/ArUogfxjpEKs3Ku3aK4JyQ= +modernc.org/gc/v2 v2.4.1 h1:9cNzOqPyMJBvrUipmynX0ZohMhcxPtMccYgGOJdOiBw= +modernc.org/gc/v2 v2.4.1/go.mod h1:wzN5dK1AzVGoH6XOzc3YZ+ey/jPgYHLuVckd62P0GYU= modernc.org/libc v1.55.3 h1:AzcW1mhlPNrRtjS5sS+eW2ISCgSOLLNyFzRh/V3Qj/U= modernc.org/libc v1.55.3/go.mod h1:qFXepLhz+JjFThQ4kzwzOjA/y/artDeg+pcYnY+Q83w= modernc.org/mathutil v1.6.0 h1:fRe9+AmYlaej+64JsEEhoWuAYBkOtQiMEU7n/XgfYi4= modernc.org/mathutil v1.6.0/go.mod h1:Ui5Q9q1TR2gFm0AQRqQUaBWFLAhQpCwNcuhBOSedWPo= modernc.org/memory v1.8.0 h1:IqGTL6eFMaDZZhEWwcREgeMXYwmW83LYW8cROZYkg+E= modernc.org/memory v1.8.0/go.mod h1:XPZ936zp5OMKGWPqbD3JShgd/ZoQ7899TUuQqxY+peU= +modernc.org/opt v0.1.3 h1:3XOZf2yznlhC+ibLltsDGzABUGVx8J6pnFMS3E4dcq4= +modernc.org/opt v0.1.3/go.mod h1:WdSiB5evDcignE70guQKxYUl14mgWtbClRi5wmkkTX0= +modernc.org/sortutil v1.2.0 h1:jQiD3PfS2REGJNzNCMMaLSp/wdMNieTbKX920Cqdgqc= +modernc.org/sortutil v1.2.0/go.mod h1:TKU2s7kJMf1AE84OoiGppNHJwvB753OYfNl2WRb++Ss= modernc.org/sqlite v1.33.1 h1:trb6Z3YYoeM9eDL1O8do81kP+0ejv+YzgyFo+Gwy0nM= modernc.org/sqlite v1.33.1/go.mod h1:pXV2xHxhzXZsgT/RtTFAPY6JJDEvOTcTdwADQCCWD4k= +modernc.org/strutil v1.2.0 h1:agBi9dp1I+eOnxXeiZawM8F4LawKv4NzGWSaLfyeNZA= +modernc.org/strutil v1.2.0/go.mod h1:/mdcBmfOibveCTBxUl5B5l6W+TTH1FXPLHZE6bTosX0= +modernc.org/token v1.1.0 h1:Xl7Ap9dKaEs5kLoOQeQmPWevfnk/DM5qcLcYlA8ys6Y= +modernc.org/token v1.1.0/go.mod h1:UGzOrNV1mAFSEB63lOFHIpNRUVMvYTc6yu1SMY/XTDM= zombiezen.com/go/sqlite v1.4.0 h1:N1s3RIljwtp4541Y8rM880qgGIgq3fTD2yks1xftnKU= zombiezen.com/go/sqlite v1.4.0/go.mod h1:0w9F1DN9IZj9AcLS9YDKMboubCACkwYCGkzoy3eG5ik= diff --git a/migrator.go b/migrator.go index b4e62d8..ebcf314 100644 --- a/migrator.go +++ b/migrator.go @@ -1,87 +1,87 @@ package mysqlite -import ( - "database/sql" - "embed" - "fmt" - "io/fs" - "log" - "strconv" - "strings" - "zombiezen.com/go/sqlite" -) - -type ReadDirFileFS interface { - fs.ReadDirFS - fs.ReadFileFS -} - -func (db *Db) MigrateDb(migrations ReadDirFileFS) error { - // Read all migrations - migrationFiles, err := migrations.ReadDir("") - if err != nil { - log.Fatalf("error reading migration files: %v", err) - } - var migrationsByVersion = make(map[int]string) - latestVersion := 0 - for _, f := range migrationFiles { - versionStr := f.Name() - version, err := strconv.Atoi(strings.SplitN(versionStr, "_", 2)[0]) - if err != nil { - log.Fatalf("invalid version number for migration script: %v", err) - } - migrationsByVersion[version] = versionStr - latestVersion = max(latestVersion, version) - } - - // Get current migration version from user_version - var currentVersion int - err = d.QuerySingle("PRAGMA user_version", ¤tVersion) - if err != nil { - log.Fatalf("error getting current version: %v", err) - } - log.Printf("Current database migration version is %d, latest version is %d", currentVersion, latestVersion) - - // If we are no up-to-date, bring the db up-to-date - for currentVersion != latestVersion { - targetVersion := currentVersion + 1 - migrationFile := migrationsByVersion[targetVersion] - log.Printf("migration to version %s", migrationFile) - migrationScript, err := migrations.ReadFile(migrationFile) - if err != nil { - log.Fatalf("error opening migration script %s: %v", migrationScript, err) - } - - tx, err := db.Begin() - if err != nil { - log.Fatalf("error beginning transaction: %v", err) - } - defer tx.MustRollback() - - err = tx.QuerySingle(string(migrationScript)) - if err != nil { - log.Fatalf("error performing migration: %v", err) - } - - err = tx.QuerySingle(fmt.Sprintf("PRAGMA user_version = %d", targetVersion)) - if err != nil { - log.Fatalf("error updating version: %v", err) - } - - err = tx.Commit() - if err != nil { - log.Fatalf("error commiting transaction: %v", err) - } - currentVersion = targetVersion - } - - log.Println("All migrations applied") - return nil -} - -func rollbackIgnoringErrors(tx *sql.Tx) { - err := tx.Rollback() - if err != nil { - log.Printf("error rolling back: %v", err) - } -} +//import ( +// "database/sql" +// "embed" +// "fmt" +// "io/fs" +// "log" +// "strconv" +// "strings" +// "zombiezen.com/go/sqlite" +//) +// +//type ReadDirFileFS interface { +// fs.ReadDirFS +// fs.ReadFileFS +//} +// +//func (db *Db) MigrateDb(migrations ReadDirFileFS) error { +// // Read all migrations +// migrationFiles, err := migrations.ReadDir("") +// if err != nil { +// log.Fatalf("error reading migration files: %v", err) +// } +// var migrationsByVersion = make(map[int]string) +// latestVersion := 0 +// for _, f := range migrationFiles { +// versionStr := f.Name() +// version, err := strconv.Atoi(strings.SplitN(versionStr, "_", 2)[0]) +// if err != nil { +// log.Fatalf("invalid version number for migration script: %v", err) +// } +// migrationsByVersion[version] = versionStr +// latestVersion = max(latestVersion, version) +// } +// +// // Get current migration version from user_version +// var currentVersion int +// err = d.QuerySingle("PRAGMA user_version", ¤tVersion) +// if err != nil { +// log.Fatalf("error getting current version: %v", err) +// } +// log.Printf("Current database migration version is %d, latest version is %d", currentVersion, latestVersion) +// +// // If we are no up-to-date, bring the db up-to-date +// for currentVersion != latestVersion { +// targetVersion := currentVersion + 1 +// migrationFile := migrationsByVersion[targetVersion] +// log.Printf("migration to version %s", migrationFile) +// migrationScript, err := migrations.ReadFile(migrationFile) +// if err != nil { +// log.Fatalf("error opening migration script %s: %v", migrationScript, err) +// } +// +// tx, err := db.Begin() +// if err != nil { +// log.Fatalf("error beginning transaction: %v", err) +// } +// defer tx.MustRollback() +// +// err = tx.QuerySingle(string(migrationScript)) +// if err != nil { +// log.Fatalf("error performing migration: %v", err) +// } +// +// err = tx.QuerySingle(fmt.Sprintf("PRAGMA user_version = %d", targetVersion)) +// if err != nil { +// log.Fatalf("error updating version: %v", err) +// } +// +// err = tx.Commit() +// if err != nil { +// log.Fatalf("error commiting transaction: %v", err) +// } +// currentVersion = targetVersion +// } +// +// log.Println("All migrations applied") +// return nil +//} +// +//func rollbackIgnoringErrors(tx *sql.Tx) { +// err := tx.Rollback() +// if err != nil { +// log.Printf("error rolling back: %v", err) +// } +//} diff --git a/query.go b/query.go index 8e114bc..7b4e575 100644 --- a/query.go +++ b/query.go @@ -2,6 +2,7 @@ package mysqlite import ( "fmt" + "reflect" "zombiezen.com/go/sqlite" ) @@ -16,14 +17,31 @@ func (d *Db) Query(query string) *Query { return &Query{err: err} } if remaining != 0 { - return &Query{err: fmt.Errorf("remaining bytes: %s", remaining)} + return &Query{err: fmt.Errorf("remaining bytes: %d", remaining)} } return &Query{stmt: stmt} } -//func (q *Query) Args(args ...any) *Query { -// return q -//} +func (q *Query) Bind(args ...any) *Query { + if q.err != nil || q.stmt == nil { + return q + } + for i, arg := range args { + if asString, ok := arg.(string); ok { + q.stmt.BindText(i+1, asString) + } else if asInt, ok := arg.(int); ok { + q.stmt.BindInt64(i+1, int64(asInt)) + } else if asInt, ok := arg.(int64); ok { + q.stmt.BindInt64(i+1, asInt) + } else if asBool, ok := arg.(bool); ok { + q.stmt.BindBool(i+1, asBool) + } else { + q.err = fmt.Errorf("unsupported column type %s at index %d", reflect.TypeOf(arg).Name(), i) + return q + } + } + return q +} func (q *Query) Exec() error { if q.stmt != nil { @@ -42,6 +60,13 @@ func (q *Query) Exec() error { return err } +func (q *Query) MustExec() { + err := q.Exec() + if err != nil { + panic(err) + } +} + func (q *Query) ScanSingle(results ...any) error { if q.stmt != nil { defer q.stmt.Finalize() @@ -65,8 +90,19 @@ func (q *Query) ScanSingle(results ...any) error { } else if asBool, ok := arg.(*bool); ok { *asBool = q.stmt.ColumnBool(i) } else { - return fmt.Errorf("unsupported column type at index %d", i) + if reflect.TypeOf(arg).Kind() != reflect.Ptr { + return fmt.Errorf("unsupported column type %s at index %d (it should be a pointer)", reflect.TypeOf(arg).Name(), i) + } + name := reflect.Indirect(reflect.ValueOf(arg)).Type().Name() + return fmt.Errorf("unsupported column type *%s at index %d", name, i) } } return nil } + +func (q *Query) MustScanSingle(results ...any) { + err := q.ScanSingle(results...) + if err != nil { + panic(fmt.Sprintf("error getting results: %v", err)) + } +} diff --git a/query_test.go b/query_test.go new file mode 100644 index 0000000..c2071cf --- /dev/null +++ b/query_test.go @@ -0,0 +1,27 @@ +package mysqlite + +import ( + "github.com/stretchr/testify/require" + "testing" +) + +func openTestDb(t *testing.T) *Db { + db := openEmptyTestDb(t) + db.Query("create table mytable (key text, value text)").MustExec() + db.Query("insert into mytable(key, value) values ('foo', 'bar')").MustExec() + return db +} + +func TestSimpleQuery(t *testing.T) { + db := openTestDb(t) + var count int + db.Query("select count(*) from mytable").MustScanSingle(&count) + require.Equal(t, 1, count, "expected empty count") +} + +func TestSimpleQueryWithArgs(t *testing.T) { + db := openTestDb(t) + var value string + db.Query("select value from mytable where key = ?").Bind("foo").MustScanSingle(&value) + require.Equal(t, "bar", value, "bad value returned") +} diff --git a/transaction.go b/transaction.go index 9bebf23..bafd180 100644 --- a/transaction.go +++ b/transaction.go @@ -7,7 +7,7 @@ type Tx struct { } func (d *Db) Begin() (*Tx, error) { - err := d.QuerySingle("BEGIN") + err := d.Query("BEGIN").Exec() if err != nil { return nil, err } @@ -25,7 +25,7 @@ func (tx *Tx) Rollback() error { func (tx *Tx) MustRollback() { err := tx.Rollback() if err != nil { - log.Panicf("error doing rollback: %w", err) + log.Panicf("error doing rollback: %v", err) } }