Compare commits
	
		
			1 Commits
		
	
	
		
	
	| Author | SHA1 | Date | |
|---|---|---|---|
| 258dcc7180 | 
							
								
								
									
										123
									
								
								query.go
									
									
									
									
									
								
							
							
						
						
									
										123
									
								
								query.go
									
									
									
									
									
								
							@@ -2,6 +2,7 @@ package mysqlite
 | 
			
		||||
 | 
			
		||||
import (
 | 
			
		||||
	"fmt"
 | 
			
		||||
	"iter"
 | 
			
		||||
	"reflect"
 | 
			
		||||
	"zombiezen.com/go/sqlite"
 | 
			
		||||
)
 | 
			
		||||
@@ -68,27 +69,96 @@ func (q *Query) MustExec() {
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (q *Query) ScanSingle(results ...any) (rerr error) {
 | 
			
		||||
	// Scan rows
 | 
			
		||||
	if q.stmt != nil {
 | 
			
		||||
		defer func() { rerr = q.stmt.Finalize() }()
 | 
			
		||||
	}
 | 
			
		||||
	if q.err != nil {
 | 
			
		||||
		return q.err
 | 
			
		||||
	}
 | 
			
		||||
	rowReturned, err := q.stmt.Step()
 | 
			
		||||
	rows, err := q.ScanMulti()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
	if !rowReturned {
 | 
			
		||||
 | 
			
		||||
	// Fetch the first row
 | 
			
		||||
	hasResult, err := rows.Next()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
	if !hasResult {
 | 
			
		||||
		return fmt.Errorf("did not return any rows")
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Scan its columns
 | 
			
		||||
	err = rows.Scan(results...)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
 | 
			
		||||
	// Ensure there are no more rows
 | 
			
		||||
	hasResult, err = rows.Next()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return err
 | 
			
		||||
	}
 | 
			
		||||
	if hasResult {
 | 
			
		||||
		return fmt.Errorf("returned more than one row")
 | 
			
		||||
	}
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (q *Query) MustScanSingle(results ...any) {
 | 
			
		||||
	err := q.ScanSingle(results...)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		panic(fmt.Sprintf("error getting results: %v", err))
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
type Rows struct {
 | 
			
		||||
	query *Query
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (q *Query) ScanMulti() (*Rows, error) {
 | 
			
		||||
	return &Rows{
 | 
			
		||||
		query: q,
 | 
			
		||||
	}, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (r *Rows) Finish() error {
 | 
			
		||||
	return r.query.stmt.Finalize()
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (r *Rows) MustFinish() {
 | 
			
		||||
	err := r.Finish()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		panic(err)
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (r *Rows) Next() (bool, error) {
 | 
			
		||||
	gotRow, err := r.query.stmt.Step()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		return false, err
 | 
			
		||||
	}
 | 
			
		||||
	return gotRow, nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (r *Rows) MustNext() bool {
 | 
			
		||||
	gotRow, err := r.Next()
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		panic(err)
 | 
			
		||||
	}
 | 
			
		||||
	return gotRow
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (r *Rows) Scan(results ...any) error {
 | 
			
		||||
	for i, arg := range results {
 | 
			
		||||
		if asString, ok := arg.(*string); ok {
 | 
			
		||||
			*asString = q.stmt.ColumnText(i)
 | 
			
		||||
			*asString = r.query.stmt.ColumnText(i)
 | 
			
		||||
		} else if asInt, ok := arg.(*int); ok {
 | 
			
		||||
			*asInt = q.stmt.ColumnInt(i)
 | 
			
		||||
			*asInt = r.query.stmt.ColumnInt(i)
 | 
			
		||||
		} else if asBool, ok := arg.(*bool); ok {
 | 
			
		||||
			*asBool = q.stmt.ColumnBool(i)
 | 
			
		||||
			*asBool = r.query.stmt.ColumnBool(i)
 | 
			
		||||
		} else {
 | 
			
		||||
			if reflect.TypeOf(arg).Kind() != reflect.Ptr {
 | 
			
		||||
				return fmt.Errorf("unsupported column type %s at index %d (it should be a pointer)", reflect.TypeOf(arg).Name(), i)
 | 
			
		||||
@@ -100,9 +170,46 @@ func (q *Query) ScanSingle(results ...any) (rerr error) {
 | 
			
		||||
	return nil
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (q *Query) MustScanSingle(results ...any) {
 | 
			
		||||
	err := q.ScanSingle(results...)
 | 
			
		||||
func (r *Rows) MustScan(results ...any) {
 | 
			
		||||
	err := r.Scan(results...)
 | 
			
		||||
	if err != nil {
 | 
			
		||||
		panic(fmt.Sprintf("error getting results: %v", err))
 | 
			
		||||
		panic(err)
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func (q *Query) Range(err *error) iter.Seq[*Rows] {
 | 
			
		||||
	return func(yield func(*Rows) bool) {
 | 
			
		||||
		// Start the scan
 | 
			
		||||
		rows, terr := q.ScanMulti()
 | 
			
		||||
		if terr != nil {
 | 
			
		||||
			*err = terr
 | 
			
		||||
			return
 | 
			
		||||
		}
 | 
			
		||||
 | 
			
		||||
		// Ensure we close and return any errors
 | 
			
		||||
		defer func() {
 | 
			
		||||
			terr := rows.Finish()
 | 
			
		||||
			if terr != nil {
 | 
			
		||||
				*err = terr
 | 
			
		||||
			}
 | 
			
		||||
		}()
 | 
			
		||||
 | 
			
		||||
		// Loop over each record
 | 
			
		||||
		for {
 | 
			
		||||
			// Get the record
 | 
			
		||||
			hasRow, terr := rows.Next()
 | 
			
		||||
			if terr != nil {
 | 
			
		||||
				*err = terr
 | 
			
		||||
				return
 | 
			
		||||
			}
 | 
			
		||||
			if !hasRow {
 | 
			
		||||
				return
 | 
			
		||||
			}
 | 
			
		||||
 | 
			
		||||
			// Pass it to the range body
 | 
			
		||||
			if !yield(&Rows{query: q}) {
 | 
			
		||||
				return
 | 
			
		||||
			}
 | 
			
		||||
		}
 | 
			
		||||
	}
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -25,3 +25,44 @@ func TestSimpleQueryWithArgs(t *testing.T) {
 | 
			
		||||
	db.Query("select value from mytable where key = ?").Bind("foo").MustScanSingle(&value)
 | 
			
		||||
	require.Equal(t, "bar", value, "bad value returned")
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestQueryWithTwoRows(t *testing.T) {
 | 
			
		||||
	db := openTestDb(t)
 | 
			
		||||
	db.Query("insert into mytable(key, value) values ('lorem', 'ipsum')").MustExec()
 | 
			
		||||
 | 
			
		||||
	rows, err := db.Query("select value from mytable").ScanMulti()
 | 
			
		||||
	require.NoError(t, err)
 | 
			
		||||
	defer rows.MustFinish()
 | 
			
		||||
 | 
			
		||||
	require.True(t, rows.MustNext(), "expected first row")
 | 
			
		||||
	var value string
 | 
			
		||||
	rows.MustScan(&value)
 | 
			
		||||
	require.Equal(t, "bar", value, "bad value returned")
 | 
			
		||||
 | 
			
		||||
	require.True(t, rows.MustNext(), "expected second row")
 | 
			
		||||
	rows.MustScan(&value)
 | 
			
		||||
	require.Equal(t, "ipsum", value, "bad value returned")
 | 
			
		||||
 | 
			
		||||
	require.False(t, rows.MustNext(), "expected no more rows")
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
func TestQueryWithRange(t *testing.T) {
 | 
			
		||||
	db := openTestDb(t)
 | 
			
		||||
	db.Query("insert into mytable(key, value) values ('lorem', 'ipsum')").MustExec()
 | 
			
		||||
 | 
			
		||||
	var err error
 | 
			
		||||
	index := 0
 | 
			
		||||
	for row := range db.Query("select value from mytable").Range(&err) {
 | 
			
		||||
		var value string
 | 
			
		||||
		row.MustScan(&value)
 | 
			
		||||
		if index == 0 {
 | 
			
		||||
			require.Equal(t, "bar", value)
 | 
			
		||||
		} else if index == 1 {
 | 
			
		||||
			require.Equal(t, "ipsum", value)
 | 
			
		||||
		} else {
 | 
			
		||||
			require.FailNow(t, "more rows than expected")
 | 
			
		||||
		}
 | 
			
		||||
		index++
 | 
			
		||||
	}
 | 
			
		||||
	require.NoError(t, err)
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user