package main import ( "database/sql" "embed" "errors" "fmt" "github.com/golang-migrate/migrate/v4" "github.com/golang-migrate/migrate/v4/database/sqlite3" "github.com/golang-migrate/migrate/v4/source/iofs" "log" ) //go:embed migrations/*.sql var fs embed.FS var db *sql.DB func OpenDb(dbName string) { if dbName == "" || dbName == "-" { log.Println("Using in-memory database") dbName = ":memory:" } else { dbName = dbName + ".chat" log.Printf("Using database %s\n", dbName) } var err error db, err = sql.Open("sqlite3", fmt.Sprintf("file:%s?_foreign_keys=on", dbName)) if err != nil { log.Fatal(err) } driver, err := sqlite3.WithInstance(db, &sqlite3.Config{}) if err != nil { log.Fatal(err) } d, err := iofs.New(fs, "migrations") if err != nil { log.Fatal(err) } m, err := migrate.NewWithInstance("iofs", d, "sqlite3", driver) if err != nil { log.Fatal(err) } err = m.Up() if err != nil && !errors.Is(err, migrate.ErrNoChange) { log.Fatal(err) } } func GetMesages() []Message { var messages []Message rows, err := db.Query("select type, message from messages order by id") if err != nil { log.Fatal(err) } defer rows.Close() for rows.Next() { var message Message err = rows.Scan(&message.Type, &message.Content) if err != nil { log.Fatal(err) } messages = append(messages, message) } return messages } func attempRollback(tx *sql.Tx) { err := tx.Rollback() if err != nil { log.Printf("Failed to perform rollback: %v\n", err) } } func SaveMessage(role MessageType, content string) error { tx, err := db.Begin() if err != nil { return err } _, err = tx.Exec("INSERT INTO messages (type, message) VALUES (?, ?)", role, content) if err != nil { attempRollback(tx) return err } return tx.Commit() }