package mysqlite import ( "fmt" "iter" "reflect" "zombiezen.com/go/sqlite" ) type Query struct { stmt *sqlite.Stmt // Reference to the database. If set, it is assumed that a lock was taken // by the query that should be freed by the query. db *Db err error } func (d *Db) Query(query string) *Query { d.Lock() q := d.query(query) q.db = d return q } func (d *Db) query(query string) *Query { stmt, remaining, err := d.Db.PrepareTransient(query) if err != nil { return &Query{err: err} } if remaining != 0 { return &Query{err: fmt.Errorf("remaining bytes: %d", remaining)} } return &Query{stmt: stmt} } func (q *Query) Bind(args ...any) *Query { into := 0 return q.bindInto(&into, args...) } func (q *Query) bindInto(into *int, args ...any) *Query { if q.err != nil || q.stmt == nil { return q } for i, arg := range args { *into++ if arg == nil { q.stmt.BindNull(*into) continue } v := reflect.ValueOf(arg) if v.Kind() == reflect.Ptr { if v.IsNil() { q.stmt.BindNull(*into) continue } arg = v.Elem().Interface() } if asString, ok := arg.(string); ok { q.stmt.BindText(*into, asString) } else if asInt, ok := arg.(int); ok { q.stmt.BindInt64(*into, int64(asInt)) } else if asInt, ok := arg.(int64); ok { q.stmt.BindInt64(*into, asInt) } else if asBool, ok := arg.(bool); ok { q.stmt.BindBool(*into, asBool) } else { // Check if the argument is a slice or array of any type v = reflect.ValueOf(arg) if v.Kind() == reflect.Slice || v.Kind() == reflect.Array { *into-- for i := 0; i < v.Len(); i++ { q.bindInto(into, v.Index(i).Interface()) } } else { *into-- q.err = fmt.Errorf("unsupported column type %s at index %d", reflect.TypeOf(arg).Name(), i) return q } } } return q } func (q *Query) Exec() (rerr error) { defer q.unlock() if q.stmt != nil { defer func() { forwardError(q.stmt.Finalize(), &rerr) }() } if q.err != nil { return q.err } rowReturned, err := q.stmt.Step() if err != nil { return err } if rowReturned { return fmt.Errorf("row returned unexpectedly") } return err } func (q *Query) MustExec() { err := q.Exec() if err != nil { panic(err) } } func (q *Query) ScanSingle(results ...any) (rerr error) { defer q.unlock() // Scan rows if q.stmt != nil { defer func() { forwardError(q.stmt.Finalize(), &rerr) }() } if q.err != nil { return q.err } rows, err := q.ScanMulti() if err != nil { return err } // Fetch the first row hasResult, err := rows.Next() if err != nil { return err } if !hasResult { return ErrNoRows } // Scan its columns err = rows.Scan(results...) if err != nil { return err } // Ensure there are no more rows hasResult, err = rows.Next() if err != nil { return err } if hasResult { return fmt.Errorf("returned more than one row") } return nil } func (q *Query) MustScanSingle(results ...any) { err := q.ScanSingle(results...) if err != nil { panic(fmt.Sprintf("error getting results: %v", err)) } } func (q *Query) unlock() { if q.db != nil { q.db.Unlock() } } type Rows struct { query *Query } func (q *Query) ScanMulti() (*Rows, error) { if q.err != nil { return nil, q.err } return &Rows{ query: q, }, nil } func (r *Rows) Finish() error { defer r.query.unlock() return r.query.stmt.Finalize() } func (r *Rows) MustFinish() { err := r.Finish() if err != nil { panic(err) } } func (r *Rows) Next() (bool, error) { gotRow, err := r.query.stmt.Step() if err != nil { return false, err } return gotRow, nil } func (r *Rows) MustNext() bool { gotRow, err := r.Next() if err != nil { panic(err) } return gotRow } func (r *Rows) Scan(results ...any) error { for i, arg := range results { err := r.scanArgument(i, arg) if err != nil { return err } } return nil } func (r *Rows) scanArgument(i int, arg any) error { if asString, ok := arg.(*string); ok { *asString = r.query.stmt.ColumnText(i) } else if asInt, ok := arg.(*int); ok { *asInt = r.query.stmt.ColumnInt(i) } else if asBool, ok := arg.(*bool); ok { *asBool = r.query.stmt.ColumnBool(i) } else if reflect.TypeOf(arg).Kind() == reflect.Ptr && reflect.TypeOf(arg).Elem().Kind() == reflect.Ptr { return r.handleNullableType(i, arg) } else { if reflect.TypeOf(arg).Kind() != reflect.Ptr { return fmt.Errorf("unsupported column type %s at index %d (it should be a pointer)", reflect.TypeOf(arg).Name(), i) } name := reflect.Indirect(reflect.ValueOf(arg)).Type().Name() return fmt.Errorf("unsupported column type *%s at index %d", name, i) } return nil } func (r *Rows) handleNullableType(i int, asPtr any) error { if r.query.stmt.ColumnIsNull(i) { reflect.ValueOf(asPtr).Elem().Set(reflect.Zero(reflect.TypeOf(asPtr).Elem())) } else { value := reflect.New(reflect.TypeOf(asPtr).Elem().Elem()).Interface() err := r.scanArgument(i, value) if err != nil { return err } reflect.ValueOf(asPtr).Elem().Set(reflect.ValueOf(value)) } return nil } func (r *Rows) MustScan(results ...any) { err := r.Scan(results...) if err != nil { panic(err) } } func (q *Query) Range(err *error) iter.Seq[*Rows] { return func(yield func(*Rows) bool) { // Start the scan rows, terr := q.ScanMulti() if terr != nil { *err = terr return } // Ensure we close and return any errors defer func() { terr := rows.Finish() if terr != nil { *err = terr } }() // Loop over each record for { // Get the record hasRow, terr := rows.Next() if terr != nil { *err = terr return } if !hasRow { return } // Pass it to the range body if !yield(&Rows{query: q}) { return } } } }