package mysqlite import ( "fmt" "iter" "reflect" "zombiezen.com/go/sqlite" ) type Query struct { stmt *sqlite.Stmt // Reference to the database. If set, it is assumed that a lock was taken // by the query that should be freed by the query. db *Db err error } func (d *Db) Query(query string) *Query { d.Lock() q := d.query(query) q.db = d return q } func (d *Db) query(query string) *Query { stmt, remaining, err := d.Db.PrepareTransient(query) if err != nil { return &Query{err: err} } if remaining != 0 { return &Query{err: fmt.Errorf("remaining bytes: %d", remaining)} } return &Query{stmt: stmt} } func (q *Query) Bind(args ...any) *Query { if q.err != nil || q.stmt == nil { return q } for i, arg := range args { if asString, ok := arg.(string); ok { q.stmt.BindText(i+1, asString) } else if asInt, ok := arg.(int); ok { q.stmt.BindInt64(i+1, int64(asInt)) } else if asInt, ok := arg.(int64); ok { q.stmt.BindInt64(i+1, asInt) } else if asBool, ok := arg.(bool); ok { q.stmt.BindBool(i+1, asBool) } else { q.err = fmt.Errorf("unsupported column type %s at index %d", reflect.TypeOf(arg).Name(), i) return q } } return q } func (q *Query) Exec() (rerr error) { defer q.unlock() if q.stmt != nil { defer func() { forwardError(q.stmt.Finalize(), &rerr) }() } if q.err != nil { return q.err } rowReturned, err := q.stmt.Step() if err != nil { return err } if rowReturned { return fmt.Errorf("row returned unexpectedly") } return err } func (q *Query) MustExec() { err := q.Exec() if err != nil { panic(err) } } func (q *Query) ScanSingle(results ...any) (rerr error) { defer q.unlock() // Scan rows if q.stmt != nil { defer func() { forwardError(q.stmt.Finalize(), &rerr) }() } if q.err != nil { return q.err } rows, err := q.ScanMulti() if err != nil { return err } // Fetch the first row hasResult, err := rows.Next() if err != nil { return err } if !hasResult { return fmt.Errorf("did not return any rows") } // Scan its columns err = rows.Scan(results...) if err != nil { return err } // Ensure there are no more rows hasResult, err = rows.Next() if err != nil { return err } if hasResult { return fmt.Errorf("returned more than one row") } return nil } func (q *Query) MustScanSingle(results ...any) { err := q.ScanSingle(results...) if err != nil { panic(fmt.Sprintf("error getting results: %v", err)) } } func (q *Query) unlock() { if q.db != nil { q.db.Unlock() } } type Rows struct { query *Query } func (q *Query) ScanMulti() (*Rows, error) { return &Rows{ query: q, }, nil } func (r *Rows) Finish() error { defer r.query.unlock() return r.query.stmt.Finalize() } func (r *Rows) MustFinish() { err := r.Finish() if err != nil { panic(err) } } func (r *Rows) Next() (bool, error) { gotRow, err := r.query.stmt.Step() if err != nil { return false, err } return gotRow, nil } func (r *Rows) MustNext() bool { gotRow, err := r.Next() if err != nil { panic(err) } return gotRow } func (r *Rows) Scan(results ...any) error { for i, arg := range results { if asString, ok := arg.(*string); ok { *asString = r.query.stmt.ColumnText(i) } else if asInt, ok := arg.(*int); ok { *asInt = r.query.stmt.ColumnInt(i) } else if asBool, ok := arg.(*bool); ok { *asBool = r.query.stmt.ColumnBool(i) } else { if reflect.TypeOf(arg).Kind() != reflect.Ptr { return fmt.Errorf("unsupported column type %s at index %d (it should be a pointer)", reflect.TypeOf(arg).Name(), i) } name := reflect.Indirect(reflect.ValueOf(arg)).Type().Name() return fmt.Errorf("unsupported column type *%s at index %d", name, i) } } return nil } func (r *Rows) MustScan(results ...any) { err := r.Scan(results...) if err != nil { panic(err) } } func (q *Query) Range(err *error) iter.Seq[*Rows] { return func(yield func(*Rows) bool) { // Start the scan rows, terr := q.ScanMulti() if terr != nil { *err = terr return } // Ensure we close and return any errors defer func() { terr := rows.Finish() if terr != nil { *err = terr } }() // Loop over each record for { // Get the record hasRow, terr := rows.Next() if terr != nil { *err = terr return } if !hasRow { return } // Pass it to the range body if !yield(&Rows{query: q}) { return } } } }