llamachat/db.go

84 lines
1.7 KiB
Go

package main
import (
"database/sql"
"embed"
"errors"
"fmt"
"github.com/golang-migrate/migrate/v4"
"github.com/golang-migrate/migrate/v4/database/sqlite3"
"github.com/golang-migrate/migrate/v4/source/iofs"
"log"
)
// fs holds the SQL files matched by migrations/*.sql, embedded into the binary
// at build time; OpenDb hands them to golang-migrate as its migration source.
//go:embed migrations/*.sql
var fs embed.FS
// db is the package-wide database handle. It is assigned by OpenDb and used by
// GetMesages and SaveMessage, so OpenDb must run before either of them.
var db *sql.DB
// OpenDb opens the SQLite database file "<dbName>.chat" (with the
// _foreign_keys=on DSN option), stores the handle in the package-level db
// variable, and applies every embedded migration that has not run yet.
// Any failure terminates the program via log.Fatalf.
func OpenDb(dbName string) {
	var err error
	// Assign with = (not :=) so the package-level db is set rather than a
	// shadowing local.
	db, err = sql.Open("sqlite3", fmt.Sprintf("file:%s.chat?_foreign_keys=on", dbName))
	if err != nil {
		log.Fatalf("opening database %q: %v", dbName, err)
	}

	driver, err := sqlite3.WithInstance(db, &sqlite3.Config{})
	if err != nil {
		log.Fatalf("creating migration driver: %v", err)
	}

	source, err := iofs.New(fs, "migrations")
	if err != nil {
		log.Fatalf("loading embedded migrations: %v", err)
	}

	m, err := migrate.NewWithInstance("iofs", source, "sqlite3", driver)
	if err != nil {
		log.Fatalf("initialising migrator: %v", err)
	}
	// NOTE(review): m is deliberately not closed. Closing it presumably also
	// closes the wrapped db through the driver; verify before adding m.Close().

	// ErrNoChange only means the schema is already up to date.
	if err := m.Up(); err != nil && !errors.Is(err, migrate.ErrNoChange) {
		log.Fatalf("applying migrations: %v", err)
	}
}
// GetMesages returns every row of the messages table, ordered by its id
// column, as Message values (Type from the "type" column, Content from the
// "message" column). It returns nil when the table is empty. If the query, a
// row scan, or row iteration fails, it terminates the program via log.Fatal.
//
// NOTE(review): the name is misspelled ("Mesages"). It is kept because
// callers elsewhere depend on it.
func GetMesages() []Message {
	rows, err := db.Query("select type, message from messages order by id")
	if err != nil {
		log.Fatal(err)
	}
	defer rows.Close()

	var messages []Message
	for rows.Next() {
		var message Message
		if err := rows.Scan(&message.Type, &message.Content); err != nil {
			log.Fatal(err)
		}
		messages = append(messages, message)
	}
	// rows.Next reports false both at normal exhaustion and on an error.
	// Without this check, a mid-iteration failure would silently truncate
	// the returned history.
	if err := rows.Err(); err != nil {
		log.Fatal(err)
	}
	return messages
}
// attempRollback rolls tx back on a best-effort basis. It is intended for
// error paths where the caller is already returning a different error, so a
// failed rollback is logged instead of being returned.
func attempRollback(tx *sql.Tx) {
	if rbErr := tx.Rollback(); rbErr != nil {
		log.Printf("Failed to perform rollback: %v\n", rbErr)
	}
}
// SaveMessage inserts one row into the messages table with the given role
// (the "type" column) and content (the "message" column), inside its own
// transaction. Errors from beginning the transaction, running the insert, or
// committing are returned unwrapped. If the insert fails, the transaction is
// rolled back first via attempRollback, which only logs rollback failures.
func SaveMessage(role MessageType, content string) error {
	const insertStmt = "INSERT INTO messages (type, message) VALUES (?, ?)"

	tx, beginErr := db.Begin()
	if beginErr != nil {
		return beginErr
	}

	if _, execErr := tx.Exec(insertStmt, role, content); execErr != nil {
		attempRollback(tx)
		return execErr
	}

	return tx.Commit()
}