Add reading of multiple rows with optional iterator
All checks were successful
Build / build (push) Successful in 1m18s
All checks were successful
Build / build (push) Successful in 1m18s
This commit is contained in:
parent
96b27ff99d
commit
258dcc7180
123
query.go
123
query.go
@ -2,6 +2,7 @@ package mysqlite
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"iter"
|
||||||
"reflect"
|
"reflect"
|
||||||
"zombiezen.com/go/sqlite"
|
"zombiezen.com/go/sqlite"
|
||||||
)
|
)
|
||||||
@ -68,27 +69,96 @@ func (q *Query) MustExec() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (q *Query) ScanSingle(results ...any) (rerr error) {
|
func (q *Query) ScanSingle(results ...any) (rerr error) {
|
||||||
|
// Scan rows
|
||||||
if q.stmt != nil {
|
if q.stmt != nil {
|
||||||
defer func() { rerr = q.stmt.Finalize() }()
|
defer func() { rerr = q.stmt.Finalize() }()
|
||||||
}
|
}
|
||||||
if q.err != nil {
|
if q.err != nil {
|
||||||
return q.err
|
return q.err
|
||||||
}
|
}
|
||||||
rowReturned, err := q.stmt.Step()
|
rows, err := q.ScanMulti()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if !rowReturned {
|
|
||||||
|
// Fetch the first row
|
||||||
|
hasResult, err := rows.Next()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if !hasResult {
|
||||||
return fmt.Errorf("did not return any rows")
|
return fmt.Errorf("did not return any rows")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Scan its columns
|
||||||
|
err = rows.Scan(results...)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Ensure there are no more rows
|
||||||
|
hasResult, err = rows.Next()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if hasResult {
|
||||||
|
return fmt.Errorf("returned more than one row")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (q *Query) MustScanSingle(results ...any) {
|
||||||
|
err := q.ScanSingle(results...)
|
||||||
|
if err != nil {
|
||||||
|
panic(fmt.Sprintf("error getting results: %v", err))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type Rows struct {
|
||||||
|
query *Query
|
||||||
|
}
|
||||||
|
|
||||||
|
func (q *Query) ScanMulti() (*Rows, error) {
|
||||||
|
return &Rows{
|
||||||
|
query: q,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *Rows) Finish() error {
|
||||||
|
return r.query.stmt.Finalize()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *Rows) MustFinish() {
|
||||||
|
err := r.Finish()
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *Rows) Next() (bool, error) {
|
||||||
|
gotRow, err := r.query.stmt.Step()
|
||||||
|
if err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
return gotRow, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *Rows) MustNext() bool {
|
||||||
|
gotRow, err := r.Next()
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
return gotRow
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *Rows) Scan(results ...any) error {
|
||||||
for i, arg := range results {
|
for i, arg := range results {
|
||||||
if asString, ok := arg.(*string); ok {
|
if asString, ok := arg.(*string); ok {
|
||||||
*asString = q.stmt.ColumnText(i)
|
*asString = r.query.stmt.ColumnText(i)
|
||||||
} else if asInt, ok := arg.(*int); ok {
|
} else if asInt, ok := arg.(*int); ok {
|
||||||
*asInt = q.stmt.ColumnInt(i)
|
*asInt = r.query.stmt.ColumnInt(i)
|
||||||
} else if asBool, ok := arg.(*bool); ok {
|
} else if asBool, ok := arg.(*bool); ok {
|
||||||
*asBool = q.stmt.ColumnBool(i)
|
*asBool = r.query.stmt.ColumnBool(i)
|
||||||
} else {
|
} else {
|
||||||
if reflect.TypeOf(arg).Kind() != reflect.Ptr {
|
if reflect.TypeOf(arg).Kind() != reflect.Ptr {
|
||||||
return fmt.Errorf("unsupported column type %s at index %d (it should be a pointer)", reflect.TypeOf(arg).Name(), i)
|
return fmt.Errorf("unsupported column type %s at index %d (it should be a pointer)", reflect.TypeOf(arg).Name(), i)
|
||||||
@ -100,9 +170,46 @@ func (q *Query) ScanSingle(results ...any) (rerr error) {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (q *Query) MustScanSingle(results ...any) {
|
func (r *Rows) MustScan(results ...any) {
|
||||||
err := q.ScanSingle(results...)
|
err := r.Scan(results...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
panic(fmt.Sprintf("error getting results: %v", err))
|
panic(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (q *Query) Range(err *error) iter.Seq[*Rows] {
|
||||||
|
return func(yield func(*Rows) bool) {
|
||||||
|
// Start the scan
|
||||||
|
rows, terr := q.ScanMulti()
|
||||||
|
if terr != nil {
|
||||||
|
*err = terr
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Ensure we close and return any errors
|
||||||
|
defer func() {
|
||||||
|
terr := rows.Finish()
|
||||||
|
if terr != nil {
|
||||||
|
*err = terr
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
// Loop over each record
|
||||||
|
for {
|
||||||
|
// Get the record
|
||||||
|
hasRow, terr := rows.Next()
|
||||||
|
if terr != nil {
|
||||||
|
*err = terr
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if !hasRow {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Pass it to the range body
|
||||||
|
if !yield(&Rows{query: q}) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -25,3 +25,44 @@ func TestSimpleQueryWithArgs(t *testing.T) {
|
|||||||
db.Query("select value from mytable where key = ?").Bind("foo").MustScanSingle(&value)
|
db.Query("select value from mytable where key = ?").Bind("foo").MustScanSingle(&value)
|
||||||
require.Equal(t, "bar", value, "bad value returned")
|
require.Equal(t, "bar", value, "bad value returned")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestQueryWithTwoRows(t *testing.T) {
|
||||||
|
db := openTestDb(t)
|
||||||
|
db.Query("insert into mytable(key, value) values ('lorem', 'ipsum')").MustExec()
|
||||||
|
|
||||||
|
rows, err := db.Query("select value from mytable").ScanMulti()
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer rows.MustFinish()
|
||||||
|
|
||||||
|
require.True(t, rows.MustNext(), "expected first row")
|
||||||
|
var value string
|
||||||
|
rows.MustScan(&value)
|
||||||
|
require.Equal(t, "bar", value, "bad value returned")
|
||||||
|
|
||||||
|
require.True(t, rows.MustNext(), "expected second row")
|
||||||
|
rows.MustScan(&value)
|
||||||
|
require.Equal(t, "ipsum", value, "bad value returned")
|
||||||
|
|
||||||
|
require.False(t, rows.MustNext(), "expected no more rows")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestQueryWithRange(t *testing.T) {
|
||||||
|
db := openTestDb(t)
|
||||||
|
db.Query("insert into mytable(key, value) values ('lorem', 'ipsum')").MustExec()
|
||||||
|
|
||||||
|
var err error
|
||||||
|
index := 0
|
||||||
|
for row := range db.Query("select value from mytable").Range(&err) {
|
||||||
|
var value string
|
||||||
|
row.MustScan(&value)
|
||||||
|
if index == 0 {
|
||||||
|
require.Equal(t, "bar", value)
|
||||||
|
} else if index == 1 {
|
||||||
|
require.Equal(t, "ipsum", value)
|
||||||
|
} else {
|
||||||
|
require.FailNow(t, "more rows than expected")
|
||||||
|
}
|
||||||
|
index++
|
||||||
|
}
|
||||||
|
require.NoError(t, err)
|
||||||
|
}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user