From 258dcc7180b5eee3518ec758c97fca98aa18b445 Mon Sep 17 00:00:00 2001 From: Sebastiaan de Schaetzen Date: Thu, 20 Feb 2025 12:02:16 +0100 Subject: [PATCH] Add reading of multiple rows with optional iterator --- query.go | 123 ++++++++++++++++++++++++++++++++++++++++++++++---- query_test.go | 41 +++++++++++++++++ 2 files changed, 156 insertions(+), 8 deletions(-) diff --git a/query.go b/query.go index ddf426c..5a3f95a 100644 --- a/query.go +++ b/query.go @@ -2,6 +2,7 @@ package mysqlite import ( "fmt" + "iter" "reflect" "zombiezen.com/go/sqlite" ) @@ -68,27 +69,96 @@ func (q *Query) MustExec() { } func (q *Query) ScanSingle(results ...any) (rerr error) { + // Scan rows if q.stmt != nil { defer func() { rerr = q.stmt.Finalize() }() } if q.err != nil { return q.err } - rowReturned, err := q.stmt.Step() + rows, err := q.ScanMulti() if err != nil { return err } - if !rowReturned { + + // Fetch the first row + hasResult, err := rows.Next() + if err != nil { + return err + } + if !hasResult { return fmt.Errorf("did not return any rows") } + // Scan its columns + err = rows.Scan(results...) + if err != nil { + return err + } + + // Ensure there are no more rows + hasResult, err = rows.Next() + if err != nil { + return err + } + if hasResult { + return fmt.Errorf("returned more than one row") + } + return nil +} + +func (q *Query) MustScanSingle(results ...any) { + err := q.ScanSingle(results...) + if err != nil { + panic(fmt.Sprintf("error getting results: %v", err)) + } +} + +type Rows struct { + query *Query +} + +func (q *Query) ScanMulti() (*Rows, error) { + return &Rows{ + query: q, + }, nil +} + +func (r *Rows) Finish() error { + return r.query.stmt.Finalize() +} + +func (r *Rows) MustFinish() { + err := r.Finish() + if err != nil { + panic(err) + } +} + +func (r *Rows) Next() (bool, error) { + gotRow, err := r.query.stmt.Step() + if err != nil { + return false, err + } + return gotRow, nil +} + +func (r *Rows) MustNext() bool { + gotRow, err := r.Next() + if err != nil { + panic(err) + } + return gotRow +} + +func (r *Rows) Scan(results ...any) error { for i, arg := range results { if asString, ok := arg.(*string); ok { - *asString = q.stmt.ColumnText(i) + *asString = r.query.stmt.ColumnText(i) } else if asInt, ok := arg.(*int); ok { - *asInt = q.stmt.ColumnInt(i) + *asInt = r.query.stmt.ColumnInt(i) } else if asBool, ok := arg.(*bool); ok { - *asBool = q.stmt.ColumnBool(i) + *asBool = r.query.stmt.ColumnBool(i) } else { if reflect.TypeOf(arg).Kind() != reflect.Ptr { return fmt.Errorf("unsupported column type %s at index %d (it should be a pointer)", reflect.TypeOf(arg).Name(), i) @@ -100,9 +170,46 @@ func (q *Query) ScanSingle(results ...any) (rerr error) { return nil } -func (q *Query) MustScanSingle(results ...any) { - err := q.ScanSingle(results...) +func (r *Rows) MustScan(results ...any) { + err := r.Scan(results...) if err != nil { - panic(fmt.Sprintf("error getting results: %v", err)) + panic(err) + } +} + +func (q *Query) Range(err *error) iter.Seq[*Rows] { + return func(yield func(*Rows) bool) { + // Start the scan + rows, terr := q.ScanMulti() + if terr != nil { + *err = terr + return + } + + // Ensure we close and return any errors + defer func() { + terr := rows.Finish() + if terr != nil { + *err = terr + } + }() + + // Loop over each record + for { + // Get the record + hasRow, terr := rows.Next() + if terr != nil { + *err = terr + return + } + if !hasRow { + return + } + + // Pass it to the range body + if !yield(&Rows{query: q}) { + return + } + } } } diff --git a/query_test.go b/query_test.go index c2071cf..171f11a 100644 --- a/query_test.go +++ b/query_test.go @@ -25,3 +25,44 @@ func TestSimpleQueryWithArgs(t *testing.T) { db.Query("select value from mytable where key = ?").Bind("foo").MustScanSingle(&value) require.Equal(t, "bar", value, "bad value returned") } + +func TestQueryWithTwoRows(t *testing.T) { + db := openTestDb(t) + db.Query("insert into mytable(key, value) values ('lorem', 'ipsum')").MustExec() + + rows, err := db.Query("select value from mytable").ScanMulti() + require.NoError(t, err) + defer rows.MustFinish() + + require.True(t, rows.MustNext(), "expected first row") + var value string + rows.MustScan(&value) + require.Equal(t, "bar", value, "bad value returned") + + require.True(t, rows.MustNext(), "expected second row") + rows.MustScan(&value) + require.Equal(t, "ipsum", value, "bad value returned") + + require.False(t, rows.MustNext(), "expected no more rows") +} + +func TestQueryWithRange(t *testing.T) { + db := openTestDb(t) + db.Query("insert into mytable(key, value) values ('lorem', 'ipsum')").MustExec() + + var err error + index := 0 + for row := range db.Query("select value from mytable").Range(&err) { + var value string + row.MustScan(&value) + if index == 0 { + require.Equal(t, "bar", value) + } else if index == 1 { + require.Equal(t, "ipsum", value) + } else { + require.FailNow(t, "more rows than expected") + } + index++ + } + require.NoError(t, err) +}