Compare commits
	
		
			1 Commits
		
	
	
		
	
	| Author | SHA1 | Date | |
|---|---|---|---|
| 258dcc7180 | 
							
								
								
									
										123
									
								
								query.go
									
									
									
									
									
								
							
							
						
						
									
										123
									
								
								query.go
									
									
									
									
									
								
							@@ -2,6 +2,7 @@ package mysqlite
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
import (
 | 
					import (
 | 
				
			||||||
	"fmt"
 | 
						"fmt"
 | 
				
			||||||
 | 
						"iter"
 | 
				
			||||||
	"reflect"
 | 
						"reflect"
 | 
				
			||||||
	"zombiezen.com/go/sqlite"
 | 
						"zombiezen.com/go/sqlite"
 | 
				
			||||||
)
 | 
					)
 | 
				
			||||||
@@ -68,27 +69,96 @@ func (q *Query) MustExec() {
 | 
				
			|||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (q *Query) ScanSingle(results ...any) (rerr error) {
 | 
					func (q *Query) ScanSingle(results ...any) (rerr error) {
 | 
				
			||||||
 | 
						// Scan rows
 | 
				
			||||||
	if q.stmt != nil {
 | 
						if q.stmt != nil {
 | 
				
			||||||
		defer func() { rerr = q.stmt.Finalize() }()
 | 
							defer func() { rerr = q.stmt.Finalize() }()
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	if q.err != nil {
 | 
						if q.err != nil {
 | 
				
			||||||
		return q.err
 | 
							return q.err
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	rowReturned, err := q.stmt.Step()
 | 
						rows, err := q.ScanMulti()
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		return err
 | 
							return err
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
	if !rowReturned {
 | 
					
 | 
				
			||||||
 | 
						// Fetch the first row
 | 
				
			||||||
 | 
						hasResult, err := rows.Next()
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return err
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						if !hasResult {
 | 
				
			||||||
		return fmt.Errorf("did not return any rows")
 | 
							return fmt.Errorf("did not return any rows")
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// Scan its columns
 | 
				
			||||||
 | 
						err = rows.Scan(results...)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return err
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						// Ensure there are no more rows
 | 
				
			||||||
 | 
						hasResult, err = rows.Next()
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return err
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						if hasResult {
 | 
				
			||||||
 | 
							return fmt.Errorf("returned more than one row")
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return nil
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (q *Query) MustScanSingle(results ...any) {
 | 
				
			||||||
 | 
						err := q.ScanSingle(results...)
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							panic(fmt.Sprintf("error getting results: %v", err))
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type Rows struct {
 | 
				
			||||||
 | 
						query *Query
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (q *Query) ScanMulti() (*Rows, error) {
 | 
				
			||||||
 | 
						return &Rows{
 | 
				
			||||||
 | 
							query: q,
 | 
				
			||||||
 | 
						}, nil
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (r *Rows) Finish() error {
 | 
				
			||||||
 | 
						return r.query.stmt.Finalize()
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (r *Rows) MustFinish() {
 | 
				
			||||||
 | 
						err := r.Finish()
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							panic(err)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (r *Rows) Next() (bool, error) {
 | 
				
			||||||
 | 
						gotRow, err := r.query.stmt.Step()
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							return false, err
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return gotRow, nil
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (r *Rows) MustNext() bool {
 | 
				
			||||||
 | 
						gotRow, err := r.Next()
 | 
				
			||||||
 | 
						if err != nil {
 | 
				
			||||||
 | 
							panic(err)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						return gotRow
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (r *Rows) Scan(results ...any) error {
 | 
				
			||||||
	for i, arg := range results {
 | 
						for i, arg := range results {
 | 
				
			||||||
		if asString, ok := arg.(*string); ok {
 | 
							if asString, ok := arg.(*string); ok {
 | 
				
			||||||
			*asString = q.stmt.ColumnText(i)
 | 
								*asString = r.query.stmt.ColumnText(i)
 | 
				
			||||||
		} else if asInt, ok := arg.(*int); ok {
 | 
							} else if asInt, ok := arg.(*int); ok {
 | 
				
			||||||
			*asInt = q.stmt.ColumnInt(i)
 | 
								*asInt = r.query.stmt.ColumnInt(i)
 | 
				
			||||||
		} else if asBool, ok := arg.(*bool); ok {
 | 
							} else if asBool, ok := arg.(*bool); ok {
 | 
				
			||||||
			*asBool = q.stmt.ColumnBool(i)
 | 
								*asBool = r.query.stmt.ColumnBool(i)
 | 
				
			||||||
		} else {
 | 
							} else {
 | 
				
			||||||
			if reflect.TypeOf(arg).Kind() != reflect.Ptr {
 | 
								if reflect.TypeOf(arg).Kind() != reflect.Ptr {
 | 
				
			||||||
				return fmt.Errorf("unsupported column type %s at index %d (it should be a pointer)", reflect.TypeOf(arg).Name(), i)
 | 
									return fmt.Errorf("unsupported column type %s at index %d (it should be a pointer)", reflect.TypeOf(arg).Name(), i)
 | 
				
			||||||
@@ -100,9 +170,46 @@ func (q *Query) ScanSingle(results ...any) (rerr error) {
 | 
				
			|||||||
	return nil
 | 
						return nil
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
func (q *Query) MustScanSingle(results ...any) {
 | 
					func (r *Rows) MustScan(results ...any) {
 | 
				
			||||||
	err := q.ScanSingle(results...)
 | 
						err := r.Scan(results...)
 | 
				
			||||||
	if err != nil {
 | 
						if err != nil {
 | 
				
			||||||
		panic(fmt.Sprintf("error getting results: %v", err))
 | 
							panic(err)
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func (q *Query) Range(err *error) iter.Seq[*Rows] {
 | 
				
			||||||
 | 
						return func(yield func(*Rows) bool) {
 | 
				
			||||||
 | 
							// Start the scan
 | 
				
			||||||
 | 
							rows, terr := q.ScanMulti()
 | 
				
			||||||
 | 
							if terr != nil {
 | 
				
			||||||
 | 
								*err = terr
 | 
				
			||||||
 | 
								return
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							// Ensure we close and return any errors
 | 
				
			||||||
 | 
							defer func() {
 | 
				
			||||||
 | 
								terr := rows.Finish()
 | 
				
			||||||
 | 
								if terr != nil {
 | 
				
			||||||
 | 
									*err = terr
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
							}()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
							// Loop over each record
 | 
				
			||||||
 | 
							for {
 | 
				
			||||||
 | 
								// Get the record
 | 
				
			||||||
 | 
								hasRow, terr := rows.Next()
 | 
				
			||||||
 | 
								if terr != nil {
 | 
				
			||||||
 | 
									*err = terr
 | 
				
			||||||
 | 
									return
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
								if !hasRow {
 | 
				
			||||||
 | 
									return
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
								// Pass it to the range body
 | 
				
			||||||
 | 
								if !yield(&Rows{query: q}) {
 | 
				
			||||||
 | 
									return
 | 
				
			||||||
 | 
								}
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
	}
 | 
						}
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -25,3 +25,44 @@ func TestSimpleQueryWithArgs(t *testing.T) {
 | 
				
			|||||||
	db.Query("select value from mytable where key = ?").Bind("foo").MustScanSingle(&value)
 | 
						db.Query("select value from mytable where key = ?").Bind("foo").MustScanSingle(&value)
 | 
				
			||||||
	require.Equal(t, "bar", value, "bad value returned")
 | 
						require.Equal(t, "bar", value, "bad value returned")
 | 
				
			||||||
}
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func TestQueryWithTwoRows(t *testing.T) {
 | 
				
			||||||
 | 
						db := openTestDb(t)
 | 
				
			||||||
 | 
						db.Query("insert into mytable(key, value) values ('lorem', 'ipsum')").MustExec()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						rows, err := db.Query("select value from mytable").ScanMulti()
 | 
				
			||||||
 | 
						require.NoError(t, err)
 | 
				
			||||||
 | 
						defer rows.MustFinish()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						require.True(t, rows.MustNext(), "expected first row")
 | 
				
			||||||
 | 
						var value string
 | 
				
			||||||
 | 
						rows.MustScan(&value)
 | 
				
			||||||
 | 
						require.Equal(t, "bar", value, "bad value returned")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						require.True(t, rows.MustNext(), "expected second row")
 | 
				
			||||||
 | 
						rows.MustScan(&value)
 | 
				
			||||||
 | 
						require.Equal(t, "ipsum", value, "bad value returned")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						require.False(t, rows.MustNext(), "expected no more rows")
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func TestQueryWithRange(t *testing.T) {
 | 
				
			||||||
 | 
						db := openTestDb(t)
 | 
				
			||||||
 | 
						db.Query("insert into mytable(key, value) values ('lorem', 'ipsum')").MustExec()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
						var err error
 | 
				
			||||||
 | 
						index := 0
 | 
				
			||||||
 | 
						for row := range db.Query("select value from mytable").Range(&err) {
 | 
				
			||||||
 | 
							var value string
 | 
				
			||||||
 | 
							row.MustScan(&value)
 | 
				
			||||||
 | 
							if index == 0 {
 | 
				
			||||||
 | 
								require.Equal(t, "bar", value)
 | 
				
			||||||
 | 
							} else if index == 1 {
 | 
				
			||||||
 | 
								require.Equal(t, "ipsum", value)
 | 
				
			||||||
 | 
							} else {
 | 
				
			||||||
 | 
								require.FailNow(t, "more rows than expected")
 | 
				
			||||||
 | 
							}
 | 
				
			||||||
 | 
							index++
 | 
				
			||||||
 | 
						}
 | 
				
			||||||
 | 
						require.NoError(t, err)
 | 
				
			||||||
 | 
					}
 | 
				
			||||||
 
 | 
				
			|||||||
		Reference in New Issue
	
	Block a user