diff --git a/.gitignore b/.gitignore index c38fa4e..53a45f4 100644 --- a/.gitignore +++ b/.gitignore @@ -1,2 +1,3 @@ .idea *.iml +*.db3 diff --git a/database.go b/database.go index e00b92a..17b7722 100644 --- a/database.go +++ b/database.go @@ -4,71 +4,88 @@ import ( "database/sql" "embed" "fmt" - "github.com/mattn/go-sqlite3" + _ "github.com/mattn/go-sqlite3" "log" - "os" + "strconv" + "strings" ) //go:embed migrations/*.sql -var migrations embed.FS +var embeddedMigrations embed.FS func openDatabase() *sql.DB { // Initialize the database connection - db, err := sqlite3.Open("videos.db") + db, err := sql.Open("sqlite3", "videos.db3") if err != nil { - log.Fatalf("Error opening database: %v\n", err) + log.Fatalf("error opening database: %v", err) + } + + // Read all migrations + migrationFiles, err := embeddedMigrations.ReadDir("migrations") + if err != nil { + log.Fatalf("error reading migration files: %v", err) + } + var migrations = make(map[int]string) + latestVersion := 0 + for _, f := range migrationFiles { + versionStr := f.Name() + version, err := strconv.Atoi(strings.SplitN(versionStr, "_", 2)[0]) + if err != nil { + log.Fatalf("invalid version number for migration script: %v", err) + } + migrations[version] = versionStr + latestVersion = max(latestVersion, version) } // Get current migration version from user_version var currentVersion int err = db.QueryRow("PRAGMA user_version").Scan(¤tVersion) if err != nil { - fmt.Printf("Error getting current version: %v\n", err) - os.Exit(1) + log.Fatalf("error getting current version: %v", err) } + log.Printf("Current database migration version is %d, latest version is %d", currentVersion, latestVersion) - // Read all migrations - migrationFiles, err := migrations.ReadDir(migrationDir) - if err != nil { - fmt.Printf("Error reading migration files: %v\n", err) - os.Exit(1) - } - - // Sort and process each migration file (assuming filenames are versioned like 1.sql, 2.sql, etc.) - for _, f := range migrationFiles { - if !f.IsDir() { - versionStr := f.Name() - version, err := extractVersion(versionStr) - if err != nil { - fmt.Printf("Error extracting version from %s: %v\n", versionStr, err) - continue - } - - if version > currentVersion { - // Apply the migration - sqlContent, err := readMigrationFile(f.Name()) - if err != nil { - fmt.Printf("Error reading migration file %s: %v\n", f.Name(), err) - os.Exit(1) - } - - _, err = db.Exec(sqlContent) - if err != nil { - fmt.Printf("Error applying migration %d: %v\n", version, err) - os.Exit(1) - } - - // Update the user_version - _, err = db.Exec(fmt.Sprintf("PRAGMA user_version=%d", version)) - if err != nil { - fmt.Printf("Error updating user_version to %d: %v\n", version, err) - os.Exit(1) - } - - fmt.Printf("Applied migration %d successfully\n", version) - } + // If we are no up-to-date, bring the db up-to-date + for currentVersion != latestVersion { + targetVersion := currentVersion + 1 + migrationFile := migrations[targetVersion] + log.Printf("migration to version %s", migrationFile) + migrationScript, err := embeddedMigrations.ReadFile("migrations/" + migrationFile) + if err != nil { + log.Fatalf("error opening migration script %s: %v", migrationScript, err) } + + tx, err := db.Begin() + if err != nil { + log.Fatalf("error beginning transaction: %v", err) + } + + _, err = tx.Exec(string(migrationScript)) + if err != nil { + rollbackIgnoringErrors(tx) + log.Fatalf("error performing migration: %v", err) + } + + _, err = tx.Exec(fmt.Sprintf("PRAGMA user_version = %d", targetVersion)) + if err != nil { + rollbackIgnoringErrors(tx) + log.Fatalf("error updating version: %v", err) + } + + err = tx.Commit() + if err != nil { + log.Fatalf("error commiting transaction: %v", err) + } + currentVersion = targetVersion } - fmt.Println("All migrations applied") + log.Println("All migrations applied") + return db +} + +func rollbackIgnoringErrors(tx *sql.Tx) { + err := tx.Rollback() + if err != nil { + log.Printf("error rolling back: %v", err) + } } diff --git a/main.go b/main.go index ee6e1bd..360342d 100644 --- a/main.go +++ b/main.go @@ -1,6 +1,35 @@ package main +import ( + "encoding/base64" + "log" + "net/http" + "net/url" + "os" + "strings" +) + func main() { db := openDatabase() defer db.Close() + + username := os.Getenv("VIVAPLUS_USER") + password, err := base64.StdEncoding.DecodeString(os.Getenv("VIVAPLUS_PASS")) + if err != nil { + log.Fatalf("error decoding password: %v", err) + } + + form := url.Values{} + form.Set("email", username) + form.Set("password", string(password)) + + // First fetch csrf token by doing a get. It is found in a meta tag with name="csrf-token" + + resp, err := http.Post("https://vivaplus.tv/supporters/sign_in", "application/x-www-form-urlencoded;charset=UTF-8", strings.NewReader(form.Encode())) + if err != nil { + log.Fatalf("error logging in: %v", err) + } + + log.Printf("Status code: %d", resp.StatusCode) + //println(resp) }