package mysqlite

import (
	"fmt"
	"reflect"

	"zombiezen.com/go/sqlite"
)

// errQueryDone is returned when Exec or ScanSingle is called on a Query whose
// statement has already been finalized (or was never prepared).
var errQueryDone = fmt.Errorf("query already executed or never prepared")

// Query is a single-use prepared statement together with the first error
// encountered while preparing or binding it. Errors are deferred so calls can
// be chained (db.Query(sql).Bind(args...).Exec()); Exec or ScanSingle reports
// the recorded error and finalizes the statement.
type Query struct {
	stmt *sqlite.Stmt
	err  error
}

// Query prepares a single SQL statement with PrepareTransient. A preparation
// error, or SQL text with bytes remaining after the first statement, is
// recorded on the returned Query and surfaced by Exec or ScanSingle.
func (d *Db) Query(query string) *Query {
	stmt, remaining, err := d.Db.PrepareTransient(query)
	if err != nil {
		return &Query{err: err}
	}
	if remaining != 0 {
		// The statement was compiled but will never run; finalize it so it is
		// not leaked. The trailing-bytes error is the one worth reporting, so
		// a Finalize error here is deliberately dropped.
		if stmt != nil {
			_ = stmt.Finalize()
		}
		return &Query{err: fmt.Errorf("remaining bytes: %d", remaining)}
	}
	return &Query{stmt: stmt}
}

// Bind binds args to the statement's parameters in order (args[0] binds
// parameter 1). Supported types are nil (NULL), string, int, int64, float64,
// bool and []byte. On an unsupported type the error is recorded on q and
// binding stops. If q already holds an error, Bind does nothing.
func (q *Query) Bind(args ...any) *Query {
	if q.err != nil || q.stmt == nil {
		return q
	}
	for i, arg := range args {
		param := i + 1 // SQLite parameter indices are 1-based.
		switch v := arg.(type) {
		case nil:
			q.stmt.BindNull(param)
		case string:
			q.stmt.BindText(param, v)
		case int:
			q.stmt.BindInt64(param, int64(v))
		case int64:
			q.stmt.BindInt64(param, v)
		case float64:
			q.stmt.BindFloat(param, v)
		case bool:
			q.stmt.BindBool(param, v)
		case []byte:
			q.stmt.BindBytes(param, v)
		default:
			// %T (unlike reflect.Type.Name) never panics and names pointer
			// and composite types too.
			q.err = fmt.Errorf("unsupported column type %T at index %d", arg, i)
			return q
		}
	}
	return q
}

// finish finalizes the statement, if any, and clears it so it is released
// exactly once. A Finalize error is stored in *rerr only when the caller is
// not already returning an error, so the original cause is never masked.
func (q *Query) finish(rerr *error) {
	if q.stmt == nil {
		return
	}
	ferr := q.stmt.Finalize()
	q.stmt = nil
	if *rerr == nil {
		*rerr = ferr
	}
}

// Exec runs a statement that is not expected to produce rows, then finalizes
// it. It returns the error recorded by Query/Bind, a stepping error, an error
// if a row is produced, or a Finalize error, in that order of precedence.
func (q *Query) Exec() (rerr error) {
	defer q.finish(&rerr)
	if q.err != nil {
		return q.err
	}
	if q.stmt == nil {
		return errQueryDone
	}
	rowReturned, err := q.stmt.Step()
	if err != nil {
		return err
	}
	if rowReturned {
		return fmt.Errorf("row returned unexpectedly")
	}
	return nil
}

// MustExec is like Exec but panics with the error on failure.
func (q *Query) MustExec() {
	err := q.Exec()
	if err != nil {
		panic(err)
	}
}

// ScanSingle steps the statement once and copies the columns of the first row
// into results, which must be non-nil pointers of type *string, *int, *int64,
// *float64 or *bool (results[i] receives column i). Any further rows are
// ignored. The statement is finalized before returning. It returns an error
// if no row is produced, a destination is unsupported, or any earlier step
// failed.
func (q *Query) ScanSingle(results ...any) (rerr error) {
	defer q.finish(&rerr)
	if q.err != nil {
		return q.err
	}
	if q.stmt == nil {
		return errQueryDone
	}
	rowReturned, err := q.stmt.Step()
	if err != nil {
		return err
	}
	if !rowReturned {
		return fmt.Errorf("did not return any rows")
	}
	for i, dest := range results {
		if err := scanColumn(q.stmt, i, dest); err != nil {
			return err
		}
	}
	return nil
}

// scanColumn copies column i of the current row of stmt into dest.
func scanColumn(stmt *sqlite.Stmt, i int, dest any) error {
	// reflect.ValueOf(nil) yields an invalid Value whose Kind is Invalid, so
	// a nil dest is rejected here instead of panicking.
	v := reflect.ValueOf(dest)
	if v.Kind() != reflect.Pointer {
		return fmt.Errorf("unsupported column type %T at index %d (it should be a pointer)", dest, i)
	}
	if v.IsNil() {
		return fmt.Errorf("nil destination %T at index %d", dest, i)
	}
	switch d := dest.(type) {
	case *string:
		*d = stmt.ColumnText(i)
	case *int:
		*d = stmt.ColumnInt(i)
	case *int64:
		*d = stmt.ColumnInt64(i)
	case *float64:
		*d = stmt.ColumnFloat(i)
	case *bool:
		*d = stmt.ColumnBool(i)
	default:
		return fmt.Errorf("unsupported column type %T at index %d", dest, i)
	}
	return nil
}

// MustScanSingle is like ScanSingle but panics on failure.
func (q *Query) MustScanSingle(results ...any) {
	err := q.ScanSingle(results...)
	if err != nil {
		panic(fmt.Sprintf("error getting results: %v", err))
	}
}