mysqlite/query.go
Sebastiaan de Schaetzen 3f031f2fe2
All checks were successful
Build / build (push) Successful in 1m32s
Fix error in return values
2025-02-18 12:54:57 +01:00

109 lines
2.3 KiB
Go

package mysqlite
import (
"fmt"
"reflect"
"zombiezen.com/go/sqlite"
)
// Query wraps one prepared SQLite statement and supports a fluent,
// chainable API (Bind, then Exec or ScanSingle).
//
// Errors from preparation or binding are not returned immediately. They are
// stored in err so that calls can be chained, and Exec and ScanSingle check
// it before running the statement.
type Query struct {
	// stmt is the prepared statement. It is left nil when preparation
	// failed or the query was rejected.
	stmt *sqlite.Stmt
	// err is the first error recorded while building the query; once set,
	// Bind becomes a no-op.
	err error
}
// Query prepares query as a single SQL statement and wraps it in a *Query.
//
// Preparation errors are not returned directly. They are recorded on the
// returned Query and surface from Exec or ScanSingle. The query must contain
// exactly one statement; if bytes remain after the first statement, the
// returned Query carries a "remaining bytes" error.
//
// The statement is prepared with PrepareTransient. Per the zombiezen sqlite
// docs, the caller owns such a statement and must Finalize it; Exec and
// ScanSingle do that.
func (d *Db) Query(query string) *Query {
	stmt, remaining, err := d.Db.PrepareTransient(query)
	if err != nil {
		return &Query{err: err}
	}
	if remaining != 0 {
		// The first statement was prepared but will never run, so release it
		// now instead of leaking it. Its Finalize error is deliberately
		// ignored: the trailing-bytes error is the one worth reporting.
		if stmt != nil {
			_ = stmt.Finalize()
		}
		return &Query{err: fmt.Errorf("remaining bytes: %d", remaining)}
	}
	return &Query{stmt: stmt}
}
// Bind binds args to the statement's positional parameters in order. args[0]
// is bound to parameter 1, args[1] to parameter 2, and so on.
//
// Supported argument types are string, int, int64, float64, bool, []byte,
// and an untyped nil, which is bound as SQL NULL.
//
// If the query already carries an error, or has no statement, Bind does
// nothing. On an unsupported argument type, Bind records the error on the
// query (Exec and ScanSingle check it before running) and stops binding the
// remaining arguments. Bind returns q so calls can be chained.
func (q *Query) Bind(args ...any) *Query {
	if q.err != nil || q.stmt == nil {
		return q
	}
	for i, arg := range args {
		param := i + 1 // SQLite parameter indices are 1-based.
		switch v := arg.(type) {
		case string:
			q.stmt.BindText(param, v)
		case int:
			q.stmt.BindInt64(param, int64(v))
		case int64:
			q.stmt.BindInt64(param, v)
		case float64:
			q.stmt.BindFloat(param, v)
		case bool:
			q.stmt.BindBool(param, v)
		case []byte:
			q.stmt.BindBytes(param, v)
		case nil:
			q.stmt.BindNull(param)
		default:
			// %T also works for unnamed types (e.g. *T, []T), where
			// reflect.Type.Name() would yield an empty string.
			q.err = fmt.Errorf("unsupported column type %T at index %d", arg, i)
			return q
		}
	}
	return q
}
// Exec runs a statement that is not expected to produce rows, then finalizes
// it.
//
// It returns the first of these that applies:
//   - an error recorded while preparing or binding the query;
//   - an error if there is no statement to run (zero-value Query, or a
//     Query that was already executed);
//   - the error from stepping the statement;
//   - an error if the statement unexpectedly produced a row;
//   - the error from Finalize.
//
// The statement is always finalized once Exec returns, so a Query can be run
// only once.
func (q *Query) Exec() (rerr error) {
	if q.stmt != nil {
		defer func() {
			// Finalize must always run to release the statement, but its
			// result must not overwrite an earlier, more specific error.
			if err := q.stmt.Finalize(); err != nil && rerr == nil {
				rerr = err
			}
			// Drop the handle so a second call cannot double-finalize.
			q.stmt = nil
		}()
	}
	if q.err != nil {
		return q.err
	}
	if q.stmt == nil {
		return fmt.Errorf("no statement to execute")
	}
	rowReturned, err := q.stmt.Step()
	if err != nil {
		return err
	}
	if rowReturned {
		return fmt.Errorf("row returned unexpectedly")
	}
	return nil
}
// MustExec behaves like Exec, except that it panics with Exec's error
// instead of returning it.
func (q *Query) MustExec() {
	if err := q.Exec(); err != nil {
		panic(err)
	}
}
// ScanSingle steps the statement once and copies the columns of the first
// result row into results. results[0] receives column 0, results[1] column 1,
// and so on. Rows after the first are ignored, and the statement is
// finalized before ScanSingle returns.
//
// Each destination must be a non-nil pointer of type *string, *int, *int64,
// *float64 or *bool.
//
// ScanSingle returns the first of these that applies:
//   - an error recorded while preparing or binding the query;
//   - an error if there is no statement to run;
//   - the error from stepping the statement;
//   - an error if no row was produced;
//   - an error if more destinations were passed than the row has columns;
//   - an error for an unsupported or nil destination;
//   - the error from Finalize.
//
// When a destination is rejected, the destinations before it have already
// been written.
func (q *Query) ScanSingle(results ...any) (rerr error) {
	if q.stmt != nil {
		defer func() {
			// Always release the statement, but never let a successful (or
			// redundant) Finalize result mask the real error.
			if err := q.stmt.Finalize(); err != nil && rerr == nil {
				rerr = err
			}
			// Drop the handle so a second call cannot double-finalize.
			q.stmt = nil
		}()
	}
	if q.err != nil {
		return q.err
	}
	if q.stmt == nil {
		return fmt.Errorf("no statement to execute")
	}
	rowReturned, err := q.stmt.Step()
	if err != nil {
		return err
	}
	if !rowReturned {
		return fmt.Errorf("did not return any rows")
	}
	if cols := q.stmt.ColumnCount(); len(results) > cols {
		return fmt.Errorf("%d result destinations given but the row has only %d columns", len(results), cols)
	}
	for i, dest := range results {
		// Reject non-pointers, a nil interface, and typed nil pointers before
		// dereferencing anything.
		rv := reflect.ValueOf(dest)
		if rv.Kind() != reflect.Ptr {
			return fmt.Errorf("unsupported column type %T at index %d (it should be a pointer)", dest, i)
		}
		if rv.IsNil() {
			return fmt.Errorf("nil pointer of type %T at index %d", dest, i)
		}
		switch d := dest.(type) {
		case *string:
			*d = q.stmt.ColumnText(i)
		case *int:
			*d = q.stmt.ColumnInt(i)
		case *int64:
			*d = q.stmt.ColumnInt64(i)
		case *float64:
			*d = q.stmt.ColumnFloat(i)
		case *bool:
			*d = q.stmt.ColumnBool(i)
		default:
			return fmt.Errorf("unsupported column type %T at index %d", dest, i)
		}
	}
	return nil
}
// MustScanSingle behaves like ScanSingle but panics if scanning fails.
//
// The panic value is an error that wraps ScanSingle's error. A recover
// handler can therefore still inspect the cause with errors.Is or errors.As,
// which a formatted string would not allow.
func (q *Query) MustScanSingle(results ...any) {
	if err := q.ScanSingle(results...); err != nil {
		panic(fmt.Errorf("error getting results: %w", err))
	}
}