package main import ( "database/sql" "embed" "fmt" _ "github.com/mattn/go-sqlite3" "log" "os" "strconv" "strings" ) //go:embed migrations/*.sql var embeddedMigrations embed.FS func openDatabase() *sql.DB { // Get database file databaseSource := os.Getenv("VIVAPLUS_DATABASE") if databaseSource == "" { databaseSource = "videos.db3" } return openDatabaseSource(databaseSource) } func openDatabaseSource(databaseSource string) *sql.DB { // Initialize the database connection db, err := sql.Open("sqlite3", databaseSource) if err != nil { log.Fatalf("error opening database: %v", err) } // Read all migrations migrationFiles, err := embeddedMigrations.ReadDir("migrations") if err != nil { log.Fatalf("error reading migration files: %v", err) } var migrations = make(map[int]string) latestVersion := 0 for _, f := range migrationFiles { versionStr := f.Name() version, err := strconv.Atoi(strings.SplitN(versionStr, "_", 2)[0]) if err != nil { log.Fatalf("invalid version number for migration script: %v", err) } migrations[version] = versionStr latestVersion = max(latestVersion, version) } // Get current migration version from user_version var currentVersion int err = db.QueryRow("PRAGMA user_version").Scan(¤tVersion) if err != nil { log.Fatalf("error getting current version: %v", err) } log.Printf("Current database migration version is %d, latest version is %d", currentVersion, latestVersion) // If we are no up-to-date, bring the db up-to-date for currentVersion != latestVersion { targetVersion := currentVersion + 1 migrationFile := migrations[targetVersion] log.Printf("migration to version %s", migrationFile) migrationScript, err := embeddedMigrations.ReadFile("migrations/" + migrationFile) if err != nil { log.Fatalf("error opening migration script %s: %v", migrationScript, err) } tx, err := db.Begin() if err != nil { log.Fatalf("error beginning transaction: %v", err) } _, err = tx.Exec(string(migrationScript)) if err != nil { rollbackIgnoringErrors(tx) log.Fatalf("error performing migration: %v", err) } _, err = tx.Exec(fmt.Sprintf("PRAGMA user_version = %d", targetVersion)) if err != nil { rollbackIgnoringErrors(tx) log.Fatalf("error updating version: %v", err) } err = tx.Commit() if err != nil { log.Fatalf("error commiting transaction: %v", err) } currentVersion = targetVersion } log.Println("All migrations applied") return db } func rollbackIgnoringErrors(tx *sql.Tx) { err := tx.Rollback() if err != nil { log.Printf("error rolling back: %v", err) } }