Compare commits
	
		
			14 Commits
		
	
	
		
			v0.1.0
			...
			87f10c73d6
		
	
	| Author | SHA1 | Date | |
|---|---|---|---|
| 87f10c73d6 | |||
| 58d63b6cf3 | |||
| 850e4a27d8 | |||
| 3e7455ef31 | |||
| 94ae36305a | |||
| 2eacf6fbc4 | |||
| 2ff3477812 | |||
| 68f8dc50e0 | |||
| 187ed5987d | |||
| a377448de3 | |||
| 0a177e0b46 | |||
| 82c7f57078 | |||
| 9d5c0bcbb1 | |||
| 258dcc7180 | 
| @@ -7,5 +7,10 @@ jobs: | ||||
|       - name: Checkout | ||||
|         uses: actions/checkout@v4 | ||||
|  | ||||
|       - name: Setup Go | ||||
|         uses: actions/setup-go@v5 | ||||
|         with: | ||||
|           go-version: '>=1.24' | ||||
|  | ||||
|       - name: Test | ||||
|         run: go test . -v | ||||
|   | ||||
							
								
								
									
										92
									
								
								README.md
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										92
									
								
								README.md
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,92 @@ | ||||
| # MySQLite | ||||
|  | ||||
| A Go library that provides a convenient wrapper around SQLite with additional functionality for database management, migrations, and transactions. | ||||
|  | ||||
| ## Features | ||||
|  | ||||
| - Simple and intuitive SQLite database connection management | ||||
| - Thread-safe database operations with built-in locking mechanism | ||||
| - Support for database migrations | ||||
| - Transaction management | ||||
| - Built on top of [zombiezen.com/go/sqlite](https://pkg.go.dev/zombiezen.com/go/sqlite) | ||||
|  | ||||
| ## Installation | ||||
|  | ||||
| ```bash | ||||
| go get gitea.seeseepuff.be/seeseemelk/mysqlite | ||||
| ``` | ||||
|  | ||||
| ## Usage | ||||
|  | ||||
| ### Opening a Database Connection | ||||
|  | ||||
| ```go | ||||
| import "gitea.seeseepuff.be/seeseemelk/mysqlite" | ||||
|  | ||||
| // Open an in-memory database | ||||
| db, err := mysqlite.OpenDb(":memory:") | ||||
| if err != nil { | ||||
|     // Handle error | ||||
| } | ||||
| defer db.Close() | ||||
|  | ||||
| // Open a file-based database | ||||
| db, err := mysqlite.OpenDb("path/to/database.db") | ||||
| if err != nil { | ||||
|     // Handle error | ||||
| } | ||||
| defer db.Close() | ||||
| ``` | ||||
|  | ||||
| ### Executing Queries | ||||
|  | ||||
| The library provides methods for executing SQL queries and managing transactions: | ||||
|  | ||||
| ```go | ||||
| // Execute a simple query | ||||
| err := db.Query("CREATE TABLE users (id INTEGER PRIMARY KEY, name TEXT)").Exec() | ||||
|  | ||||
| // Use transactions | ||||
| tx, err := db.BeginTransaction() | ||||
| if err != nil { | ||||
|     // Handle error | ||||
| } | ||||
|  | ||||
| // Perform operations within transaction | ||||
| // ... | ||||
|  | ||||
| // Commit or rollback | ||||
| err = tx.Commit() // or tx.Rollback() | ||||
| ``` | ||||
|  | ||||
| ### Database Migrations | ||||
|  | ||||
| The library includes support for SQL-based migrations. Migrations are SQL files stored in a directory and are executed in order based on their filename prefix: | ||||
|  | ||||
| 1. Create a directory for your migrations (e.g., `migrations/`) | ||||
| 2. Add numbered SQL migration files: | ||||
|    ``` | ||||
|    migrations/ | ||||
|    ├── 1_initial.sql | ||||
|    ├── 2_add_users.sql | ||||
|    ├── 3_add_posts.sql | ||||
|    ``` | ||||
| 3. Embed the migrations in your Go code: | ||||
|    ```go | ||||
|    import "embed" | ||||
|     | ||||
|    //go:embed migrations/*.sql | ||||
|    var migrations embed.FS | ||||
|     | ||||
|    // Apply migrations | ||||
|    err := db.MigrateDb(migrations, "migrations") | ||||
|    if err != nil { | ||||
|        // Handle error | ||||
|    } | ||||
|    ``` | ||||
|  | ||||
| Each migration file should contain valid SQL statements. The migrations are executed in order and are tracked internally to ensure they only run once. | ||||
|  | ||||
| ## Requirements | ||||
|  | ||||
| - Go 1.24 or higher | ||||
							
								
								
									
										15
									
								
								database.go
									
									
									
									
									
								
							
							
						
						
									
										15
									
								
								database.go
									
									
									
									
									
								
							| @@ -2,12 +2,15 @@ package mysqlite | ||||
|  | ||||
| import ( | ||||
| 	"fmt" | ||||
| 	"sync" | ||||
| 	"zombiezen.com/go/sqlite" | ||||
| ) | ||||
|  | ||||
| // Db holds a connection to a SQLite database. | ||||
| type Db struct { | ||||
| 	Db *sqlite.Conn | ||||
| 	Db     *sqlite.Conn | ||||
| 	source string | ||||
| 	lock   sync.Mutex | ||||
| } | ||||
|  | ||||
| // OpenDb opens a new connection to a SQLite database. | ||||
| @@ -19,7 +22,7 @@ func OpenDb(databaseSource string) (*Db, error) { | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	return &Db{Db: conn}, nil | ||||
| 	return &Db{Db: conn, source: databaseSource}, nil | ||||
| } | ||||
|  | ||||
| // Close closes the database. | ||||
| @@ -35,3 +38,11 @@ func (d *Db) MustClose() { | ||||
| 		panic(fmt.Sprintf("error closing db: %v", err)) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func (d *Db) Lock() { | ||||
| 	d.lock.Lock() | ||||
| } | ||||
|  | ||||
| func (d *Db) Unlock() { | ||||
| 	d.lock.Unlock() | ||||
| } | ||||
|   | ||||
							
								
								
									
										44
									
								
								migrator.go
									
									
									
									
									
								
							
							
						
						
									
										44
									
								
								migrator.go
									
									
									
									
									
								
							| @@ -4,6 +4,7 @@ import ( | ||||
| 	"fmt" | ||||
| 	"io/fs" | ||||
| 	"log" | ||||
| 	"os" | ||||
| 	"path" | ||||
| 	"strconv" | ||||
| 	"strings" | ||||
| @@ -40,6 +41,20 @@ func (d *Db) MigrateDb(filesystem ReadDirFileFS, directory string) error { | ||||
| 	} | ||||
| 	log.Printf("Current version is %d, max migration version is %d", currentVersion, latestVersion) | ||||
|  | ||||
| 	// Create a backup if we're not on the latest version | ||||
| 	if currentVersion != latestVersion && d.source != ":memory:" { | ||||
| 		target := d.source + ".backup." + strconv.Itoa(currentVersion) | ||||
| 		log.Printf("Creating backup of database to %s", target) | ||||
| 		data, err := d.Db.Serialize("main") | ||||
| 		if err != nil { | ||||
| 			return fmt.Errorf("error serializing database: %v", err) | ||||
| 		} | ||||
| 		err = os.WriteFile(target, data, 0644) | ||||
| 		if err != nil { | ||||
| 			return fmt.Errorf("error writing backup: %v", err) | ||||
| 		} | ||||
| 	} | ||||
|  | ||||
| 	// If we are no up-to-date, bring the db up-to-date | ||||
| 	for currentVersion != latestVersion { | ||||
| 		targetVersion := currentVersion + 1 | ||||
| @@ -50,7 +65,7 @@ func (d *Db) MigrateDb(filesystem ReadDirFileFS, directory string) error { | ||||
| 			return fmt.Errorf("error opening migration script %s: %v", migrationScript, err) | ||||
| 		} | ||||
|  | ||||
| 		err = performSingleMigration(err, d, migrationScript, targetVersion) | ||||
| 		err = performSingleMigration(d, migrationScript, targetVersion) | ||||
| 		if err != nil { | ||||
| 			return err | ||||
| 		} | ||||
| @@ -61,21 +76,32 @@ func (d *Db) MigrateDb(filesystem ReadDirFileFS, directory string) error { | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func performSingleMigration(err error, d *Db, migrationScript []byte, targetVersion int) error { | ||||
| func performSingleMigration(d *Db, migrationScript []byte, targetVersion int) error { | ||||
| 	script := string(migrationScript) | ||||
| 	// Split script based on semicolon | ||||
| 	statements := strings.Split(script, ";") | ||||
|  | ||||
| 	tx, err := d.Begin() | ||||
| 	if err != nil { | ||||
| 		return fmt.Errorf("error beginning transaction: %v", err) | ||||
| 	} | ||||
| 	defer tx.MustRollback() | ||||
|  | ||||
| 	err = tx.Query(string(migrationScript)).Exec() | ||||
| 	if err != nil { | ||||
| 		return fmt.Errorf("error performing migration: %v", err) | ||||
| 	} | ||||
| 	for _, statement := range statements { | ||||
| 		statement = strings.TrimSpace(statement) | ||||
| 		if statement == "" { | ||||
| 			continue | ||||
| 		} | ||||
| 		err = tx.Query(statement).Exec() | ||||
| 		if err != nil { | ||||
| 			return fmt.Errorf("error performing migration: %v", err) | ||||
| 		} | ||||
|  | ||||
| 		err = tx.Query(fmt.Sprintf("PRAGMA user_version = %d", targetVersion)).Exec() | ||||
| 		if err != nil { | ||||
| 			return fmt.Errorf("error updating version: %v", err) | ||||
| 		} | ||||
|  | ||||
| 	err = tx.Query(fmt.Sprintf("PRAGMA user_version = %d", targetVersion)).Exec() | ||||
| 	if err != nil { | ||||
| 		return fmt.Errorf("error updating version: %v", err) | ||||
| 	} | ||||
|  | ||||
| 	err = tx.Commit() | ||||
|   | ||||
| @@ -17,4 +17,8 @@ func TestDb_MigrateDb(t *testing.T) { | ||||
| 	var count int | ||||
| 	db.Query("select count(*) from mydata").MustScanSingle(&count) | ||||
| 	require.Equal(t, 1, count, "incorrect number of rows in database") | ||||
|  | ||||
| 	count = 0 | ||||
| 	db.Query("select count(*) from multiTable").MustScanSingle(&count) | ||||
| 	require.Equal(t, 1, count, "incorrect number of rows in database") | ||||
| } | ||||
|   | ||||
							
								
								
									
										220
									
								
								query.go
									
									
									
									
									
								
							
							
						
						
									
										220
									
								
								query.go
									
									
									
									
									
								
							| @@ -2,16 +2,27 @@ package mysqlite | ||||
|  | ||||
| import ( | ||||
| 	"fmt" | ||||
| 	"iter" | ||||
| 	"reflect" | ||||
| 	"zombiezen.com/go/sqlite" | ||||
| ) | ||||
|  | ||||
| type Query struct { | ||||
| 	stmt *sqlite.Stmt | ||||
| 	err  error | ||||
| 	// Reference to the database. If set, it is assumed that a lock was taken | ||||
| 	// by the query that should be freed by the query. | ||||
| 	db  *Db | ||||
| 	err error | ||||
| } | ||||
|  | ||||
| func (d *Db) Query(query string) *Query { | ||||
| 	d.Lock() | ||||
| 	q := d.query(query) | ||||
| 	q.db = d | ||||
| 	return q | ||||
| } | ||||
|  | ||||
| func (d *Db) query(query string) *Query { | ||||
| 	stmt, remaining, err := d.Db.PrepareTransient(query) | ||||
| 	if err != nil { | ||||
| 		return &Query{err: err} | ||||
| @@ -23,29 +34,47 @@ func (d *Db) Query(query string) *Query { | ||||
| } | ||||
|  | ||||
| func (q *Query) Bind(args ...any) *Query { | ||||
| 	into := 0 | ||||
| 	return q.bindInto(&into, args...) | ||||
| } | ||||
|  | ||||
| func (q *Query) bindInto(into *int, args ...any) *Query { | ||||
| 	if q.err != nil || q.stmt == nil { | ||||
| 		return q | ||||
| 	} | ||||
| 	for i, arg := range args { | ||||
| 		*into++ | ||||
| 		if asString, ok := arg.(string); ok { | ||||
| 			q.stmt.BindText(i+1, asString) | ||||
| 			q.stmt.BindText(*into, asString) | ||||
| 		} else if asInt, ok := arg.(int); ok { | ||||
| 			q.stmt.BindInt64(i+1, int64(asInt)) | ||||
| 			q.stmt.BindInt64(*into, int64(asInt)) | ||||
| 		} else if asInt, ok := arg.(int64); ok { | ||||
| 			q.stmt.BindInt64(i+1, asInt) | ||||
| 			q.stmt.BindInt64(*into, asInt) | ||||
| 		} else if asBool, ok := arg.(bool); ok { | ||||
| 			q.stmt.BindBool(i+1, asBool) | ||||
| 			q.stmt.BindBool(*into, asBool) | ||||
| 		} else { | ||||
| 			q.err = fmt.Errorf("unsupported column type %s at index %d", reflect.TypeOf(arg).Name(), i) | ||||
| 			return q | ||||
| 			// Check if the argument is a slice or array of any type | ||||
| 			v := reflect.ValueOf(arg) | ||||
| 			if v.Kind() == reflect.Slice || v.Kind() == reflect.Array { | ||||
| 				*into-- | ||||
| 				for i := 0; i < v.Len(); i++ { | ||||
| 					q.bindInto(into, v.Index(i).Interface()) | ||||
| 				} | ||||
| 			} else { | ||||
| 				*into-- | ||||
| 				q.err = fmt.Errorf("unsupported column type %s at index %d", reflect.TypeOf(arg).Name(), i) | ||||
| 				return q | ||||
| 			} | ||||
| 		} | ||||
| 	} | ||||
| 	return q | ||||
| } | ||||
|  | ||||
| func (q *Query) Exec() (rerr error) { | ||||
| 	defer q.unlock() | ||||
|  | ||||
| 	if q.stmt != nil { | ||||
| 		defer func() { rerr = q.stmt.Finalize() }() | ||||
| 		defer func() { forwardError(q.stmt.Finalize(), &rerr) }() | ||||
| 	} | ||||
| 	if q.err != nil { | ||||
| 		return q.err | ||||
| @@ -68,34 +97,41 @@ func (q *Query) MustExec() { | ||||
| } | ||||
|  | ||||
| func (q *Query) ScanSingle(results ...any) (rerr error) { | ||||
| 	defer q.unlock() | ||||
| 	// Scan rows | ||||
| 	if q.stmt != nil { | ||||
| 		defer func() { rerr = q.stmt.Finalize() }() | ||||
| 		defer func() { forwardError(q.stmt.Finalize(), &rerr) }() | ||||
| 	} | ||||
| 	if q.err != nil { | ||||
| 		return q.err | ||||
| 	} | ||||
| 	rowReturned, err := q.stmt.Step() | ||||
| 	rows, err := q.ScanMulti() | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 	if !rowReturned { | ||||
|  | ||||
| 	// Fetch the first row | ||||
| 	hasResult, err := rows.Next() | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 	if !hasResult { | ||||
| 		return fmt.Errorf("did not return any rows") | ||||
| 	} | ||||
|  | ||||
| 	for i, arg := range results { | ||||
| 		if asString, ok := arg.(*string); ok { | ||||
| 			*asString = q.stmt.ColumnText(i) | ||||
| 		} else if asInt, ok := arg.(*int); ok { | ||||
| 			*asInt = q.stmt.ColumnInt(i) | ||||
| 		} else if asBool, ok := arg.(*bool); ok { | ||||
| 			*asBool = q.stmt.ColumnBool(i) | ||||
| 		} else { | ||||
| 			if reflect.TypeOf(arg).Kind() != reflect.Ptr { | ||||
| 				return fmt.Errorf("unsupported column type %s at index %d (it should be a pointer)", reflect.TypeOf(arg).Name(), i) | ||||
| 			} | ||||
| 			name := reflect.Indirect(reflect.ValueOf(arg)).Type().Name() | ||||
| 			return fmt.Errorf("unsupported column type *%s at index %d", name, i) | ||||
| 		} | ||||
| 	// Scan its columns | ||||
| 	err = rows.Scan(results...) | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
|  | ||||
| 	// Ensure there are no more rows | ||||
| 	hasResult, err = rows.Next() | ||||
| 	if err != nil { | ||||
| 		return err | ||||
| 	} | ||||
| 	if hasResult { | ||||
| 		return fmt.Errorf("returned more than one row") | ||||
| 	} | ||||
| 	return nil | ||||
| } | ||||
| @@ -106,3 +142,137 @@ func (q *Query) MustScanSingle(results ...any) { | ||||
| 		panic(fmt.Sprintf("error getting results: %v", err)) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func (q *Query) unlock() { | ||||
| 	if q.db != nil { | ||||
| 		q.db.Unlock() | ||||
| 	} | ||||
| } | ||||
|  | ||||
| type Rows struct { | ||||
| 	query *Query | ||||
| } | ||||
|  | ||||
| func (q *Query) ScanMulti() (*Rows, error) { | ||||
| 	if q.err != nil { | ||||
| 		return nil, q.err | ||||
| 	} | ||||
| 	return &Rows{ | ||||
| 		query: q, | ||||
| 	}, nil | ||||
| } | ||||
|  | ||||
| func (r *Rows) Finish() error { | ||||
| 	defer r.query.unlock() | ||||
| 	return r.query.stmt.Finalize() | ||||
| } | ||||
|  | ||||
| func (r *Rows) MustFinish() { | ||||
| 	err := r.Finish() | ||||
| 	if err != nil { | ||||
| 		panic(err) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func (r *Rows) Next() (bool, error) { | ||||
| 	gotRow, err := r.query.stmt.Step() | ||||
| 	if err != nil { | ||||
| 		return false, err | ||||
| 	} | ||||
| 	return gotRow, nil | ||||
| } | ||||
|  | ||||
| func (r *Rows) MustNext() bool { | ||||
| 	gotRow, err := r.Next() | ||||
| 	if err != nil { | ||||
| 		panic(err) | ||||
| 	} | ||||
| 	return gotRow | ||||
| } | ||||
|  | ||||
| func (r *Rows) Scan(results ...any) error { | ||||
| 	for i, arg := range results { | ||||
| 		err := r.scanArgument(i, arg) | ||||
| 		if err != nil { | ||||
| 			return err | ||||
| 		} | ||||
| 	} | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func (r *Rows) scanArgument(i int, arg any) error { | ||||
| 	if asString, ok := arg.(*string); ok { | ||||
| 		*asString = r.query.stmt.ColumnText(i) | ||||
| 	} else if asInt, ok := arg.(*int); ok { | ||||
| 		*asInt = r.query.stmt.ColumnInt(i) | ||||
| 	} else if asBool, ok := arg.(*bool); ok { | ||||
| 		*asBool = r.query.stmt.ColumnBool(i) | ||||
| 	} else if reflect.TypeOf(arg).Kind() == reflect.Ptr && reflect.TypeOf(arg).Elem().Kind() == reflect.Ptr { | ||||
| 		return r.handleNullableType(i, arg) | ||||
| 	} else { | ||||
| 		if reflect.TypeOf(arg).Kind() != reflect.Ptr { | ||||
| 			return fmt.Errorf("unsupported column type %s at index %d (it should be a pointer)", reflect.TypeOf(arg).Name(), i) | ||||
| 		} | ||||
| 		name := reflect.Indirect(reflect.ValueOf(arg)).Type().Name() | ||||
| 		return fmt.Errorf("unsupported column type *%s at index %d", name, i) | ||||
| 	} | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func (r *Rows) handleNullableType(i int, asPtr any) error { | ||||
| 	if r.query.stmt.ColumnIsNull(i) { | ||||
| 		reflect.ValueOf(asPtr).Elem().Set(reflect.Zero(reflect.TypeOf(asPtr).Elem())) | ||||
| 	} else { | ||||
| 		value := reflect.New(reflect.TypeOf(asPtr).Elem().Elem()).Interface() | ||||
| 		err := r.scanArgument(i, value) | ||||
| 		if err != nil { | ||||
| 			return err | ||||
| 		} | ||||
| 		reflect.ValueOf(asPtr).Elem().Set(reflect.ValueOf(value)) | ||||
| 	} | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func (r *Rows) MustScan(results ...any) { | ||||
| 	err := r.Scan(results...) | ||||
| 	if err != nil { | ||||
| 		panic(err) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func (q *Query) Range(err *error) iter.Seq[*Rows] { | ||||
| 	return func(yield func(*Rows) bool) { | ||||
| 		// Start the scan | ||||
| 		rows, terr := q.ScanMulti() | ||||
| 		if terr != nil { | ||||
| 			*err = terr | ||||
| 			return | ||||
| 		} | ||||
|  | ||||
| 		// Ensure we close and return any errors | ||||
| 		defer func() { | ||||
| 			terr := rows.Finish() | ||||
| 			if terr != nil { | ||||
| 				*err = terr | ||||
| 			} | ||||
| 		}() | ||||
|  | ||||
| 		// Loop over each record | ||||
| 		for { | ||||
| 			// Get the record | ||||
| 			hasRow, terr := rows.Next() | ||||
| 			if terr != nil { | ||||
| 				*err = terr | ||||
| 				return | ||||
| 			} | ||||
| 			if !hasRow { | ||||
| 				return | ||||
| 			} | ||||
|  | ||||
| 			// Pass it to the range body | ||||
| 			if !yield(&Rows{query: q}) { | ||||
| 				return | ||||
| 			} | ||||
| 		} | ||||
| 	} | ||||
| } | ||||
|   | ||||
							
								
								
									
										134
									
								
								query_test.go
									
									
									
									
									
								
							
							
						
						
									
										134
									
								
								query_test.go
									
									
									
									
									
								
							| @@ -25,3 +25,137 @@ func TestSimpleQueryWithArgs(t *testing.T) { | ||||
| 	db.Query("select value from mytable where key = ?").Bind("foo").MustScanSingle(&value) | ||||
| 	require.Equal(t, "bar", value, "bad value returned") | ||||
| } | ||||
|  | ||||
| func TestQueryWithTwoRows(t *testing.T) { | ||||
| 	db := openTestDb(t) | ||||
| 	db.Query("insert into mytable(key, value) values ('lorem', 'ipsum')").MustExec() | ||||
|  | ||||
| 	rows, err := db.Query("select value from mytable").ScanMulti() | ||||
| 	require.NoError(t, err) | ||||
| 	defer rows.MustFinish() | ||||
|  | ||||
| 	require.True(t, rows.MustNext(), "expected first row") | ||||
| 	var value string | ||||
| 	rows.MustScan(&value) | ||||
| 	require.Equal(t, "bar", value, "bad value returned") | ||||
|  | ||||
| 	require.True(t, rows.MustNext(), "expected second row") | ||||
| 	rows.MustScan(&value) | ||||
| 	require.Equal(t, "ipsum", value, "bad value returned") | ||||
|  | ||||
| 	require.False(t, rows.MustNext(), "expected no more rows") | ||||
| } | ||||
|  | ||||
| func TestQueryWithRange(t *testing.T) { | ||||
| 	db := openTestDb(t) | ||||
| 	db.Query("insert into mytable(key, value) values ('lorem', 'ipsum')").MustExec() | ||||
|  | ||||
| 	var err error | ||||
| 	index := 0 | ||||
| 	for row := range db.Query("select value from mytable").Range(&err) { | ||||
| 		var value string | ||||
| 		row.MustScan(&value) | ||||
| 		if index == 0 { | ||||
| 			require.Equal(t, "bar", value) | ||||
| 		} else if index == 1 { | ||||
| 			require.Equal(t, "ipsum", value) | ||||
| 		} else { | ||||
| 			require.FailNow(t, "more rows than expected") | ||||
| 		} | ||||
| 		index++ | ||||
| 	} | ||||
| 	require.NoError(t, err) | ||||
| } | ||||
|  | ||||
| func TestUpdateQuery(t *testing.T) { | ||||
| 	db := openTestDb(t) | ||||
| 	func() { | ||||
| 		tx := db.MustBegin() | ||||
| 		defer tx.MustRollback() | ||||
| 		tx.Query("insert into mytable(key, value) values ('lorem', 'bar')").MustExec() | ||||
| 		value := "ipsum" | ||||
| 		key := "lorem" | ||||
| 		tx.Query("update mytable set value = ? where key = ?").Bind(value, key).MustExec() | ||||
| 		tx.MustCommit() | ||||
| 	}() | ||||
|  | ||||
| 	var value string | ||||
| 	db.Query("select value from mytable where key = 'lorem'").MustScanSingle(&value) | ||||
| 	require.Equal(t, "ipsum", value) | ||||
| } | ||||
|  | ||||
| func TestUpdateQueryWithWrongArguments(t *testing.T) { | ||||
| 	db := openTestDb(t) | ||||
| 	value := "ipsum" | ||||
| 	err := db.Query("insert into mytable(key, value) values ('lorem', ?)").Bind(&value).Exec() | ||||
| 	require.Error(t, err) | ||||
| } | ||||
|  | ||||
| func TestQueryWithPointerStringArguments(t *testing.T) { | ||||
| 	db := openTestDb(t) | ||||
| 	var result *string | ||||
| 	err := db.Query("select value from mytable where key = 'foo'").ScanSingle(&result) | ||||
| 	require.NoError(t, err) | ||||
| 	require.NotNil(t, result) | ||||
| 	require.Equal(t, "bar", *result) | ||||
| } | ||||
|  | ||||
| func TestQueryWithPointerStringArgumentsCanSetToNull(t *testing.T) { | ||||
| 	db := openTestDb(t) | ||||
| 	db.Query("update mytable set value=NULL where key = 'foo'").MustExec() | ||||
| 	myString := "some string" | ||||
| 	var result *string | ||||
| 	result = &myString | ||||
| 	err := db.Query("select value from mytable where key = 'foo'").ScanSingle(&result) | ||||
| 	require.NoError(t, err) | ||||
| 	require.Nil(t, result) | ||||
| } | ||||
|  | ||||
| func TestDeleteQuery(t *testing.T) { | ||||
| 	db := openTestDb(t) | ||||
| 	db.Query("delete from mytable where key = 'foo'").MustExec() | ||||
|  | ||||
| 	var count int | ||||
| 	db.Query("select count(*) from mytable where key = 'foo'").MustScanSingle(&count) | ||||
| 	require.Equal(t, 0, count, "expected row to be deleted") | ||||
| } | ||||
|  | ||||
| func TestTransactionRollback(t *testing.T) { | ||||
| 	db := openTestDb(t) | ||||
| 	func() { | ||||
| 		tx := db.MustBegin() | ||||
| 		defer tx.MustRollback() | ||||
| 		tx.Query("update mytable set value = 'ipsum' where key = 'foo'").MustExec() | ||||
| 		// Intentionally not committing the transaction | ||||
| 	}() | ||||
|  | ||||
| 	var value string | ||||
| 	db.Query("select value from mytable where key = 'foo'").MustScanSingle(&value) | ||||
| 	require.Equal(t, "bar", value, "expected original value after rollback") | ||||
| } | ||||
|  | ||||
| func TestQueryWithInClause(t *testing.T) { | ||||
| 	db := openTestDb(t) | ||||
| 	// Insert additional test rows | ||||
| 	db.Query("insert into mytable(key, value) values ('key1', 'value1')").MustExec() | ||||
| 	db.Query("insert into mytable(key, value) values ('key2', 'value2')").MustExec() | ||||
|  | ||||
| 	// Execute query with IN clause | ||||
| 	args := []string{"foo", "key2"} | ||||
| 	rows, err := db.Query("select key, value from mytable where key in (?, ?)").Bind(args).ScanMulti() | ||||
| 	require.NoError(t, err) | ||||
| 	defer rows.MustFinish() | ||||
|  | ||||
| 	// Check results | ||||
| 	results := make(map[string]string) | ||||
| 	for rows.MustNext() { | ||||
| 		var key, value string | ||||
| 		rows.MustScan(&key, &value) | ||||
| 		results[key] = value | ||||
| 	} | ||||
|  | ||||
| 	// Verify we got exactly the expected results | ||||
| 	require.Equal(t, 2, len(results), "expected 2 matching rows") | ||||
| 	require.Equal(t, "bar", results["foo"], "unexpected value for key 'foo'") | ||||
| 	require.Equal(t, "value2", results["key2"], "unexpected value for key 'key2'") | ||||
| } | ||||
|   | ||||
							
								
								
									
										3
									
								
								testMigrations/3_multicomment.sql
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										3
									
								
								testMigrations/3_multicomment.sql
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,3 @@ | ||||
| create table multiTable(value text); | ||||
|  | ||||
| insert into multiTable(value) values ('testValue'); | ||||
| @@ -7,18 +7,40 @@ type Tx struct { | ||||
| } | ||||
|  | ||||
| func (d *Db) Begin() (*Tx, error) { | ||||
| 	err := d.Query("BEGIN").Exec() | ||||
| 	d.Lock() | ||||
| 	err := d.query("BEGIN").Exec() | ||||
| 	if err != nil { | ||||
| 		return nil, err | ||||
| 	} | ||||
| 	return &Tx{db: d}, nil | ||||
| } | ||||
|  | ||||
| func (d *Db) MustBegin() *Tx { | ||||
| 	tx, err := d.Begin() | ||||
| 	if err != nil { | ||||
| 		panic(err) | ||||
| 	} | ||||
| 	return tx | ||||
| } | ||||
|  | ||||
| func (tx *Tx) Commit() error { | ||||
| 	defer tx.unlock() | ||||
| 	return tx.Query("COMMIT").Exec() | ||||
| } | ||||
|  | ||||
| func (tx *Tx) MustCommit() { | ||||
| 	err := tx.Commit() | ||||
| 	if err != nil { | ||||
| 		panic(err) | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func (tx *Tx) Rollback() error { | ||||
| 	if tx.db == nil { | ||||
| 		// The transaction was already commited | ||||
| 		return nil | ||||
| 	} | ||||
| 	defer tx.unlock() | ||||
| 	return tx.Query("ROLLBACK").Exec() | ||||
| } | ||||
|  | ||||
| @@ -29,6 +51,16 @@ func (tx *Tx) MustRollback() { | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func (tx *Tx) Query(query string) *Query { | ||||
| 	return tx.db.Query(query) | ||||
| func (tx *Tx) unlock() { | ||||
| 	if tx.db != nil { | ||||
| 		tx.db.Unlock() | ||||
| 		tx.db = nil | ||||
| 	} | ||||
| } | ||||
|  | ||||
| func (tx *Tx) Query(query string) *Query { | ||||
| 	if tx.db == nil { | ||||
| 		panic("query was performed on a transaction after Commit or Rollback") | ||||
| 	} | ||||
| 	return tx.db.query(query) | ||||
| } | ||||
|   | ||||
		Reference in New Issue
	
	Block a user