package main

import (
	"database/sql"
	"embed"
	"errors"
	"fmt"
	"github.com/golang-migrate/migrate/v4"
	"github.com/golang-migrate/migrate/v4/database/sqlite3"
	"github.com/golang-migrate/migrate/v4/source/iofs"
	"log"
)

// fs holds the *.sql files from the migrations directory, embedded into the
// binary at build time. OpenDb hands them to golang-migrate through an iofs
// source so the schema is brought up to date on every start.
//
//go:embed migrations/*.sql
var fs embed.FS

// db is the package-wide SQLite handle. It is assigned by OpenDb and read by
// GetMesages and SaveMessage; it stays nil until OpenDb has run.
var db *sql.DB

// OpenDb opens the SQLite database that stores the chat history, stores the
// handle in the package-level db, and applies every embedded migration.
//
// An empty dbName or "-" selects an in-memory database. Any other value is
// used as a file name with ".chat" appended. Any failure while opening or
// migrating the database is fatal. A migration run with nothing to apply
// (migrate.ErrNoChange) is not an error.
func OpenDb(dbName string) {
	inMemory := dbName == "" || dbName == "-"
	if inMemory {
		log.Println("Using in-memory database")
		dbName = ":memory:"
	} else {
		dbName = dbName + ".chat"
		log.Printf("Using database %s\n", dbName)
	}

	var err error
	db, err = sql.Open("sqlite3", fmt.Sprintf("file:%s?_foreign_keys=on", dbName))
	if err != nil {
		log.Fatal(err)
	}
	if inMemory {
		// In SQLite, every new connection to ":memory:" gets its own, empty
		// database. database/sql pools connections, so a second connection
		// opened later (for example under concurrent use, or after the idle
		// one is dropped) would not see the migrated schema or any saved
		// messages. Capping the pool at one connection keeps every query on
		// the same in-memory database.
		db.SetMaxOpenConns(1)
	}

	driver, err := sqlite3.WithInstance(db, &sqlite3.Config{})
	if err != nil {
		log.Fatal(err)
	}
	d, err := iofs.New(fs, "migrations")
	if err != nil {
		log.Fatal(err)
	}
	m, err := migrate.NewWithInstance("iofs", d, "sqlite3", driver)
	if err != nil {
		log.Fatal(err)
	}

	// m is deliberately not closed: closing it would also close the shared
	// db handle through the sqlite3 driver.
	err = m.Up()
	if err != nil && !errors.Is(err, migrate.ErrNoChange) {
		log.Fatal(err)
	}
}

// GetMesages returns every row of the messages table, ordered by id. It
// returns nil when the table is empty. Any query, scan, or iteration error
// is fatal.
func GetMesages() []Message {
	var messages []Message
	rows, err := db.Query("select type, message from messages order by id")
	if err != nil {
		log.Fatal(err)
	}
	defer rows.Close()
	for rows.Next() {
		var message Message
		if err := rows.Scan(&message.Type, &message.Content); err != nil {
			log.Fatal(err)
		}
		messages = append(messages, message)
	}
	// rows.Next reports false both at the end of the result set and when
	// reading fails. Without this check, a failed read would silently return
	// a truncated history.
	if err := rows.Err(); err != nil {
		log.Fatal(err)
	}
	return messages
}

// attempRollback rolls back tx on a best-effort basis. A rollback failure is
// only logged, not returned, because callers are already returning the error
// that made them abandon the transaction.
func attempRollback(tx *sql.Tx) {
	if rbErr := tx.Rollback(); rbErr != nil {
		log.Printf("Failed to perform rollback: %v\n", rbErr)
	}
}

// SaveMessage inserts one row with the given role and content into the
// messages table inside its own transaction. If the insert fails, the
// transaction is rolled back (best effort) and the insert error is returned.
// Otherwise the result of Commit is returned.
func SaveMessage(role MessageType, content string) error {
	const insertStmt = "INSERT INTO messages (type, message) VALUES (?, ?)"

	tx, err := db.Begin()
	if err != nil {
		return err
	}
	if _, execErr := tx.Exec(insertStmt, role, content); execErr != nil {
		attempRollback(tx)
		return execErr
	}
	return tx.Commit()
}