Compare commits
	
		
			2 Commits
		
	
	
		
	
	| Author | SHA1 | Date | |
|---|---|---|---|
| 94ae36305a | |||
| 2eacf6fbc4 | 
| @@ -74,6 +74,9 @@ func performSingleMigration(d *Db, migrationScript []byte, targetVersion int) er | ||||
|  | ||||
| 	for _, statement := range statements { | ||||
| 		statement = strings.TrimSpace(statement) | ||||
| 		if statement == "" { | ||||
| 			continue | ||||
| 		} | ||||
| 		err = tx.Query(statement).Exec() | ||||
| 		if err != nil { | ||||
| 			return fmt.Errorf("error performing migration: %v", err) | ||||
|   | ||||
							
								
								
									
										52
									
								
								query.go
									
									
									
									
									
								
							
							
						
						
									
										52
									
								
								query.go
									
									
									
									
									
								
							| @@ -58,7 +58,7 @@ func (q *Query) Exec() (rerr error) { | ||||
| 	defer q.unlock() | ||||
|  | ||||
| 	if q.stmt != nil { | ||||
| 		defer func() { rerr = q.stmt.Finalize() }() | ||||
| 		defer func() { forwardError(q.stmt.Finalize(), &rerr) }() | ||||
| 	} | ||||
| 	if q.err != nil { | ||||
| 		return q.err | ||||
| @@ -84,7 +84,7 @@ func (q *Query) ScanSingle(results ...any) (rerr error) { | ||||
| 	defer q.unlock() | ||||
| 	// Scan rows | ||||
| 	if q.stmt != nil { | ||||
| 		defer func() { rerr = q.stmt.Finalize() }() | ||||
| 		defer func() { forwardError(q.stmt.Finalize(), &rerr) }() | ||||
| 	} | ||||
| 	if q.err != nil { | ||||
| 		return q.err | ||||
| @@ -173,23 +173,47 @@ func (r *Rows) MustNext() bool { | ||||
|  | ||||
| func (r *Rows) Scan(results ...any) error { | ||||
| 	for i, arg := range results { | ||||
| 		if asString, ok := arg.(*string); ok { | ||||
| 			*asString = r.query.stmt.ColumnText(i) | ||||
| 		} else if asInt, ok := arg.(*int); ok { | ||||
| 			*asInt = r.query.stmt.ColumnInt(i) | ||||
| 		} else if asBool, ok := arg.(*bool); ok { | ||||
| 			*asBool = r.query.stmt.ColumnBool(i) | ||||
| 		} else { | ||||
| 			if reflect.TypeOf(arg).Kind() != reflect.Ptr { | ||||
| 				return fmt.Errorf("unsupported column type %s at index %d (it should be a pointer)", reflect.TypeOf(arg).Name(), i) | ||||
| 			} | ||||
| 			name := reflect.Indirect(reflect.ValueOf(arg)).Type().Name() | ||||
| 			return fmt.Errorf("unsupported column type *%s at index %d", name, i) | ||||
| 		err := r.scanArgument(i, arg) | ||||
| 		if err != nil { | ||||
| 			return err | ||||
| 		} | ||||
| 	} | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func (r *Rows) scanArgument(i int, arg any) error { | ||||
| 	if asString, ok := arg.(*string); ok { | ||||
| 		*asString = r.query.stmt.ColumnText(i) | ||||
| 	} else if asInt, ok := arg.(*int); ok { | ||||
| 		*asInt = r.query.stmt.ColumnInt(i) | ||||
| 	} else if asBool, ok := arg.(*bool); ok { | ||||
| 		*asBool = r.query.stmt.ColumnBool(i) | ||||
| 	} else if reflect.TypeOf(arg).Kind() == reflect.Ptr && reflect.TypeOf(arg).Elem().Kind() == reflect.Ptr { | ||||
| 		return r.handleNullableType(i, arg) | ||||
| 	} else { | ||||
| 		if reflect.TypeOf(arg).Kind() != reflect.Ptr { | ||||
| 			return fmt.Errorf("unsupported column type %s at index %d (it should be a pointer)", reflect.TypeOf(arg).Name(), i) | ||||
| 		} | ||||
| 		name := reflect.Indirect(reflect.ValueOf(arg)).Type().Name() | ||||
| 		return fmt.Errorf("unsupported column type *%s at index %d", name, i) | ||||
| 	} | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func (r *Rows) handleNullableType(i int, asPtr any) error { | ||||
| 	if r.query.stmt.ColumnIsNull(i) { | ||||
| 		reflect.ValueOf(asPtr).Elem().Set(reflect.Zero(reflect.TypeOf(asPtr).Elem())) | ||||
| 	} else { | ||||
| 		value := reflect.New(reflect.TypeOf(asPtr).Elem().Elem()).Interface() | ||||
| 		err := r.scanArgument(i, value) | ||||
| 		if err != nil { | ||||
| 			return err | ||||
| 		} | ||||
| 		reflect.ValueOf(asPtr).Elem().Set(reflect.ValueOf(value)) | ||||
| 	} | ||||
| 	return nil | ||||
| } | ||||
|  | ||||
| func (r *Rows) MustScan(results ...any) { | ||||
| 	err := r.Scan(results...) | ||||
| 	if err != nil { | ||||
|   | ||||
| @@ -73,11 +73,40 @@ func TestUpdateQuery(t *testing.T) { | ||||
| 		tx := db.MustBegin() | ||||
| 		defer tx.MustRollback() | ||||
| 		tx.Query("insert into mytable(key, value) values ('lorem', 'bar')").MustExec() | ||||
| 		tx.Query("update mytable set value = 'ipsum' where key = 'lorem'").MustExec() | ||||
| 		value := "ipsum" | ||||
| 		key := "lorem" | ||||
| 		tx.Query("update mytable set value = ? where key = ?").Bind(value, key).MustExec() | ||||
| 		tx.MustCommit() | ||||
| 	}() | ||||
|  | ||||
| 	var value string | ||||
| 	db.Query("select value from mytable where value = 'ipsum'").MustScanSingle(&value) | ||||
| 	db.Query("select value from mytable where key = 'lorem'").MustScanSingle(&value) | ||||
| 	require.Equal(t, "ipsum", value) | ||||
| } | ||||
|  | ||||
| func TestUpdateQueryWithWrongArguments(t *testing.T) { | ||||
| 	db := openTestDb(t) | ||||
| 	value := "ipsum" | ||||
| 	err := db.Query("insert into mytable(key, value) values ('lorem', ?)").Bind(&value).Exec() | ||||
| 	require.Error(t, err) | ||||
| } | ||||
|  | ||||
| func TestQueryWithPointerStringArguments(t *testing.T) { | ||||
| 	db := openTestDb(t) | ||||
| 	var result *string | ||||
| 	err := db.Query("select value from mytable where key = 'foo'").ScanSingle(&result) | ||||
| 	require.NoError(t, err) | ||||
| 	require.NotNil(t, result) | ||||
| 	require.Equal(t, "bar", *result) | ||||
| } | ||||
|  | ||||
| func TestQueryWithPointerStringArgumentsCanSetToNull(t *testing.T) { | ||||
| 	db := openTestDb(t) | ||||
| 	db.Query("update mytable set value=NULL where key = 'foo'").MustExec() | ||||
| 	myString := "some string" | ||||
| 	var result *string | ||||
| 	result = &myString | ||||
| 	err := db.Query("select value from mytable where key = 'foo'").ScanSingle(&result) | ||||
| 	require.NoError(t, err) | ||||
| 	require.Nil(t, result) | ||||
| } | ||||
|   | ||||
		Reference in New Issue
	
	Block a user