All checks were successful
Build / build (push) Successful in 1m55s
102 lines
2.6 KiB
Go
102 lines
2.6 KiB
Go
package main
|
|
|
|
import (
|
|
"database/sql"
|
|
"embed"
|
|
"fmt"
|
|
_ "github.com/mattn/go-sqlite3"
|
|
"log"
|
|
"os"
|
|
"strconv"
|
|
"strings"
|
|
)
|
|
|
|
// embeddedMigrations holds the SQL migration scripts matching the
// migrations/*.sql pattern, embedded into the binary at build time.
// openDatabaseSource lists them from the "migrations" directory and parses
// the leading "<version>_" part of each file name as its schema version.
//
//go:embed migrations/*.sql
var embeddedMigrations embed.FS
|
|
|
|
func openDatabase() *sql.DB {
|
|
// Get database file
|
|
databaseSource := os.Getenv("VIVAPLUS_DATABASE")
|
|
if databaseSource == "" {
|
|
databaseSource = "videos.db3"
|
|
}
|
|
return openDatabaseSource(databaseSource)
|
|
}
|
|
|
|
func openDatabaseSource(databaseSource string) *sql.DB {
|
|
// Initialize the database connection
|
|
db, err := sql.Open("sqlite3", databaseSource)
|
|
if err != nil {
|
|
log.Fatalf("error opening database: %v", err)
|
|
}
|
|
|
|
// Read all migrations
|
|
migrationFiles, err := embeddedMigrations.ReadDir("migrations")
|
|
if err != nil {
|
|
log.Fatalf("error reading migration files: %v", err)
|
|
}
|
|
var migrations = make(map[int]string)
|
|
latestVersion := 0
|
|
for _, f := range migrationFiles {
|
|
versionStr := f.Name()
|
|
version, err := strconv.Atoi(strings.SplitN(versionStr, "_", 2)[0])
|
|
if err != nil {
|
|
log.Fatalf("invalid version number for migration script: %v", err)
|
|
}
|
|
migrations[version] = versionStr
|
|
latestVersion = max(latestVersion, version)
|
|
}
|
|
|
|
// Get current migration version from user_version
|
|
var currentVersion int
|
|
err = db.QueryRow("PRAGMA user_version").Scan(¤tVersion)
|
|
if err != nil {
|
|
log.Fatalf("error getting current version: %v", err)
|
|
}
|
|
log.Printf("Current database migration version is %d, latest version is %d", currentVersion, latestVersion)
|
|
|
|
// If we are no up-to-date, bring the db up-to-date
|
|
for currentVersion != latestVersion {
|
|
targetVersion := currentVersion + 1
|
|
migrationFile := migrations[targetVersion]
|
|
log.Printf("migration to version %s", migrationFile)
|
|
migrationScript, err := embeddedMigrations.ReadFile("migrations/" + migrationFile)
|
|
if err != nil {
|
|
log.Fatalf("error opening migration script %s: %v", migrationScript, err)
|
|
}
|
|
|
|
tx, err := db.Begin()
|
|
if err != nil {
|
|
log.Fatalf("error beginning transaction: %v", err)
|
|
}
|
|
|
|
_, err = tx.Exec(string(migrationScript))
|
|
if err != nil {
|
|
rollbackIgnoringErrors(tx)
|
|
log.Fatalf("error performing migration: %v", err)
|
|
}
|
|
|
|
_, err = tx.Exec(fmt.Sprintf("PRAGMA user_version = %d", targetVersion))
|
|
if err != nil {
|
|
rollbackIgnoringErrors(tx)
|
|
log.Fatalf("error updating version: %v", err)
|
|
}
|
|
|
|
err = tx.Commit()
|
|
if err != nil {
|
|
log.Fatalf("error commiting transaction: %v", err)
|
|
}
|
|
currentVersion = targetVersion
|
|
}
|
|
|
|
log.Println("All migrations applied")
|
|
return db
|
|
}
|
|
|
|
// rollbackIgnoringErrors rolls back tx on a best-effort basis. A failed
// rollback is only logged, never returned, so callers already handling a
// primary error can proceed with reporting that one.
func rollbackIgnoringErrors(tx *sql.Tx) {
	if rbErr := tx.Rollback(); rbErr != nil {
		log.Printf("error rolling back: %v", rbErr)
	}
}
|