package mysqlite

import (
	"fmt"
	"io/fs"
	"log"
	"path"
	"strconv"
	"strings"
)

// ReadDirFileFS is a file system that can both list a directory and read a
// whole file in one call, which is what MigrateDb needs to discover and load
// migration scripts. Standard-library examples include embed.FS and
// testing/fstest.MapFS.
type ReadDirFileFS interface {
	fs.ReadFileFS
	fs.ReadDirFS
}

// MigrateDb brings the database up to date by applying the migration scripts
// stored in directory of filesystem.
//
// Every entry in directory must be named "<version>" or "<version>_<anything>",
// where <version> is a decimal integer. The database's current version is read
// from PRAGMA user_version. Each version from current+1 up to the highest one
// found is then applied in ascending order via performSingleMigration.
//
// An error is returned in any of these cases:
//   - an entry name has no numeric prefix;
//   - two entries share a version;
//   - a version between the current one and the latest has no script;
//   - the database is already newer than every migration;
//   - reading or applying a script fails.
func (d *Db) MigrateDb(filesystem ReadDirFileFS, directory string) error {
	migrationsByVersion, latestVersion, err := indexMigrationFiles(filesystem, directory)
	if err != nil {
		return err
	}

	// Get current migration version from user_version.
	var currentVersion int
	if err = d.Query("PRAGMA user_version").ScanSingle(&currentVersion); err != nil {
		return fmt.Errorf("error getting current version: %w", err)
	}
	log.Printf("Current version is %d, max migration version is %d", currentVersion, latestVersion)

	// Without this check the loop below would look for a script for
	// currentVersion+1, which cannot exist, and fail with a confusing error.
	if currentVersion > latestVersion {
		return fmt.Errorf("database version %d is newer than the latest migration version %d", currentVersion, latestVersion)
	}

	// If we are not up-to-date, apply each missing version in order.
	for currentVersion < latestVersion {
		targetVersion := currentVersion + 1
		migrationFile, ok := migrationsByVersion[targetVersion]
		if !ok {
			return fmt.Errorf("missing migration script for version %d", targetVersion)
		}
		log.Printf("migrating to version %s", migrationFile)
		migrationScript, err := filesystem.ReadFile(path.Join(directory, migrationFile))
		if err != nil {
			return fmt.Errorf("error opening migration script %s: %w", migrationFile, err)
		}

		if err = performSingleMigration(d, migrationScript, targetVersion); err != nil {
			return err
		}
		currentVersion = targetVersion
	}

	log.Println("Database is up-to-date")
	return nil
}

// indexMigrationFiles maps each migration version found in directory to its
// entry name and reports the highest version seen (0 when there are none).
func indexMigrationFiles(filesystem fs.ReadDirFS, directory string) (map[int]string, int, error) {
	entries, err := filesystem.ReadDir(directory)
	if err != nil {
		return nil, 0, fmt.Errorf("error reading migration files: %w", err)
	}

	migrationsByVersion := make(map[int]string, len(entries))
	latestVersion := 0
	for _, entry := range entries {
		name := entry.Name()
		// The version is everything before the first '_' (or the whole name).
		prefix, _, _ := strings.Cut(name, "_")
		version, err := strconv.Atoi(prefix)
		if err != nil {
			return nil, 0, fmt.Errorf("invalid version number for migration script %s: %w", name, err)
		}
		if existing, dup := migrationsByVersion[version]; dup {
			return nil, 0, fmt.Errorf("duplicate migration version %d: %s and %s", version, existing, name)
		}
		migrationsByVersion[version] = name
		latestVersion = max(latestVersion, version)
	}
	return migrationsByVersion, latestVersion, nil
}

// performSingleMigration runs every statement of migrationScript against d and
// then sets PRAGMA user_version to targetVersion. All of this happens in one
// transaction that is committed only if every step succeeds.
//
// The script is split into statements on each ';'. Blank statements, such as
// the one after a trailing ';', are skipped.
//
// NOTE(review): the split is naive. A ';' inside a string literal or a
// CREATE TRIGGER ... BEGIN ...; END body will break the statement apart, so
// such scripts are not supported.
func performSingleMigration(d *Db, migrationScript []byte, targetVersion int) error {
	tx, err := d.Begin()
	if err != nil {
		return fmt.Errorf("error beginning transaction: %w", err)
	}
	// Undo everything if any step below fails. NOTE(review): presumably
	// MustRollback is harmless after a successful Commit; confirm against Tx.
	defer tx.MustRollback()

	for i, statement := range splitMigrationStatements(string(migrationScript)) {
		if err = tx.Query(statement).Exec(); err != nil {
			return fmt.Errorf("error performing migration (statement %d): %w", i+1, err)
		}
	}

	// The version only needs to be written once, after all statements ran.
	// PRAGMA values cannot be bound as parameters; targetVersion is an int,
	// so formatting it into the SQL text is safe.
	if err = tx.Query(fmt.Sprintf("PRAGMA user_version = %d", targetVersion)).Exec(); err != nil {
		return fmt.Errorf("error updating version: %w", err)
	}

	if err = tx.Commit(); err != nil {
		return fmt.Errorf("error committing transaction: %w", err)
	}
	return nil
}

// splitMigrationStatements splits script on ';' and returns the trimmed,
// non-empty pieces in their original order.
func splitMigrationStatements(script string) []string {
	parts := strings.Split(script, ";")
	statements := make([]string, 0, len(parts))
	for _, part := range parts {
		if statement := strings.TrimSpace(part); statement != "" {
			statements = append(statements, statement)
		}
	}
	return statements
}