package main

import (
	"database/sql"
	"embed"
	"fmt"
	_ "github.com/mattn/go-sqlite3"
	"log"
	"os"
	"strconv"
	"strings"
)

//go:embed migrations/*.sql
var embeddedMigrations embed.FS

// openDatabase opens the SQLite database and applies every pending embedded
// migration before returning the handle.
//
// The database path comes from the VIVAPLUS_DATABASE environment variable and
// defaults to "videos.db3". Migrations are the embedded migrations/*.sql
// files. Each script's version is the integer prefix of its file name up to
// the first underscore. The version already applied is tracked in SQLite's
// PRAGMA user_version. Pending versions are applied one at a time, in
// ascending order, each in its own transaction.
//
// Every failure is fatal: the function exits the process via log.Fatalf.
// This includes a database whose version is newer than the newest embedded
// migration, and a hole in the migration numbering.
func openDatabase() *sql.DB {
	// Resolve the database file path (not a directory).
	databaseSource := os.Getenv("VIVAPLUS_DATABASE")
	if databaseSource == "" {
		databaseSource = "videos.db3"
	}

	db, err := sql.Open("sqlite3", databaseSource)
	if err != nil {
		log.Fatalf("error opening database: %v", err)
	}

	migrations, latestVersion := readEmbeddedMigrations()

	var currentVersion int
	if err := db.QueryRow("PRAGMA user_version").Scan(&currentVersion); err != nil {
		log.Fatalf("error getting current version: %v", err)
	}
	log.Printf("Current database migration version is %d, latest version is %d", currentVersion, latestVersion)

	// A database already migrated past the newest script this binary knows
	// about cannot be handled. Without this check the loop would walk upward
	// past every known version.
	if currentVersion > latestVersion {
		log.Fatalf("database version %d is newer than the latest known migration version %d", currentVersion, latestVersion)
	}

	// If we are not up-to-date, bring the db up-to-date one version at a time.
	for targetVersion := currentVersion + 1; targetVersion <= latestVersion; targetVersion++ {
		migrationFile, ok := migrations[targetVersion]
		if !ok {
			log.Fatalf("missing migration script for version %d", targetVersion)
		}
		applyMigration(db, targetVersion, migrationFile)
	}

	log.Println("All migrations applied")
	return db
}

// readEmbeddedMigrations indexes the embedded migration scripts by version.
// It returns the map from version to file name together with the highest
// version found (0 when there are no scripts). It exits via log.Fatalf if the
// directory cannot be read, a file name lacks an integer prefix, or two
// scripts share a version.
func readEmbeddedMigrations() (map[int]string, int) {
	entries, err := embeddedMigrations.ReadDir("migrations")
	if err != nil {
		log.Fatalf("error reading migration files: %v", err)
	}

	migrations := make(map[int]string, len(entries))
	latestVersion := 0
	for _, entry := range entries {
		name := entry.Name()
		// The version is everything before the first "_" (or the whole name
		// when there is no "_", in which case Atoi will reject it).
		prefix, _, _ := strings.Cut(name, "_")
		version, err := strconv.Atoi(prefix)
		if err != nil {
			log.Fatalf("invalid version number for migration script: %v", err)
		}
		if previous, duplicate := migrations[version]; duplicate {
			log.Fatalf("migration scripts %s and %s share version %d", previous, name, version)
		}
		migrations[version] = name
		latestVersion = max(latestVersion, version)
	}
	return migrations, latestVersion
}

// applyMigration runs one embedded migration script and sets PRAGMA
// user_version to targetVersion within a single transaction. The script and
// the version bump are therefore committed together or not at all. It exits
// via log.Fatalf on any error.
func applyMigration(db *sql.DB, targetVersion int, migrationFile string) {
	log.Printf("migration to version %s", migrationFile)
	// embed.FS paths are always slash-separated, so plain concatenation is
	// correct here (filepath.Join would be wrong on Windows).
	migrationScript, err := embeddedMigrations.ReadFile("migrations/" + migrationFile)
	if err != nil {
		log.Fatalf("error opening migration script %s: %v", migrationFile, err)
	}

	tx, err := db.Begin()
	if err != nil {
		log.Fatalf("error beginning transaction: %v", err)
	}

	// log.Fatalf exits without running deferred calls, so each failure path
	// rolls back explicitly before terminating.
	if _, err := tx.Exec(string(migrationScript)); err != nil {
		rollbackIgnoringErrors(tx)
		log.Fatalf("error performing migration: %v", err)
	}

	// PRAGMA statements cannot take bound parameters; targetVersion is an int
	// we parsed ourselves, so formatting it into the statement is safe.
	if _, err := tx.Exec(fmt.Sprintf("PRAGMA user_version = %d", targetVersion)); err != nil {
		rollbackIgnoringErrors(tx)
		log.Fatalf("error updating version: %v", err)
	}

	if err := tx.Commit(); err != nil {
		log.Fatalf("error commiting transaction: %v", err)
	}
}

// rollbackIgnoringErrors rolls back tx on a best-effort basis. A failed
// rollback is logged but never returned, because it is used on error paths
// where the error that triggered the rollback is the one worth reporting.
func rollbackIgnoringErrors(tx *sql.Tx) {
	if rollbackErr := tx.Rollback(); rollbackErr != nil {
		log.Printf("error rolling back: %v", rollbackErr)
	}
}