package main

import (
	"errors"
	"fmt"
	"log"
	"math"
	"time"

	"gitea.seeseepuff.be/seeseemelk/mysqlite"
)

// Db wraps a mysqlite database handle; the methods on *Db implement the
// application's queries over the users, allowances, tasks and history tables.
type Db struct {
	db *mysqlite.Db
}

// NewDb opens the database identified by datasource and applies the
// `migrations` set (declared elsewhere in this package) via MigrateDb.
// Failure of either step is fatal: the error is passed to log.Fatal, which
// terminates the process, so a returned *Db is always usable.
func NewDb(datasource string) *Db {
	conn, openErr := mysqlite.OpenDb(datasource)
	if openErr != nil {
		log.Fatal(openErr)
	}

	if migrateErr := conn.MigrateDb(migrations, "migrations"); migrateErr != nil {
		log.Fatal(migrateErr)
	}

	return &Db{db: conn}
}

// GetUsers returns the ID and name of every row in the users table.
// On success the slice is never nil (an empty table yields an empty slice);
// on any scan or iteration error it returns nil and that error.
func (db *Db) GetUsers() ([]User, error) {
	result := []User{}
	var iterErr error

	rows := db.db.Query("select id, name from users").Range(&iterErr)
	for row := range rows {
		var u User
		if iterErr = row.Scan(&u.ID, &u.Name); iterErr != nil {
			return nil, iterErr
		}
		result = append(result, u)
	}

	// Range reports iteration failures through the pointer it was given.
	if iterErr != nil {
		return nil, iterErr
	}
	return result, nil
}

// GetUser loads one user by ID together with their allowance. The allowance
// is the sum of all of the user's history amounts (0 when there is no
// history), stored as integer hundredths and returned as a float64.
// Any query/scan error is returned with a nil user (presumably including the
// "no such user" case — depends on ScanSingle semantics, TODO confirm).
func (db *Db) GetUser(id int) (*UserWithAllowance, error) {
	const query = "select u.id, u.name, (select ifnull(sum(h.amount), 0) from history h where h.user_id = u.id) from users u where u.id = ?"

	var hundredths int
	result := &UserWithAllowance{}
	if err := db.db.Query(query).Bind(id).ScanSingle(&result.ID, &result.Name, &hundredths); err != nil {
		return nil, err
	}

	result.Allowance = float64(hundredths) / 100.0
	return result, nil
}

// UserExists reports whether the users table contains a row with the given ID.
func (db *Db) UserExists(userId int) (bool, error) {
	var n int
	if err := db.db.Query("select count(*) from users where id = ?").Bind(userId).ScanSingle(&n); err != nil {
		return false, err
	}
	return n > 0, nil
}

// GetUserAllowances returns the allowance list for a user.
//
// The first entry always describes the user's general balance. It has ID 0
// and no name, target or colour, and its progress and weight come from the
// users row. The remaining entries are the user's allowances, each with ID,
// name, target, progress (balance), weight and colour. Colour is formatted
// exactly as GetUserAllowanceById does, so the list and single-item lookups
// agree.
//
// Monetary columns are stored as integer hundredths and converted to float64.
// On any error the result is nil.
func (db *Db) GetUserAllowances(userId int) ([]Allowance, error) {
	allowances := make([]Allowance, 0)

	// Entry 0: the user's general (unallocated) balance.
	var userBalance int64
	general := Allowance{}
	err := db.db.Query("select balance, weight from users where id = ?").
		Bind(userId).ScanSingle(&userBalance, &general.Weight)
	if err != nil {
		return nil, err
	}
	general.Progress = float64(userBalance) / 100.0
	allowances = append(allowances, general)

	// Remaining entries: the user's own allowances.
	for row := range db.db.Query("select id, name, target, balance, weight, colour from allowances where user_id = ?").
		Bind(userId).Range(&err) {
		allowance := Allowance{}
		var target, balance, colour int64
		err = row.Scan(&allowance.ID, &allowance.Name, &target, &balance, &allowance.Weight, &colour)
		if err != nil {
			return nil, err
		}
		allowance.Target = float64(target) / 100.0
		allowance.Progress = float64(balance) / 100.0
		allowance.Colour = fmt.Sprintf("#%06X", colour)
		allowances = append(allowances, allowance)
	}
	if err != nil {
		return nil, err
	}
	return allowances, nil
}

// GetUserAllowanceById returns a single allowance for a user.
//
// allowanceId 0 denotes the user's general balance: only progress (the
// users.balance column) and weight are filled in. Any other ID is looked up in
// the allowances table, restricted to rows owned by userId. It also carries
// ID, name, target and colour; colour is rendered with "#%06X", i.e. '#'
// followed by at least six upper-case hex digits.
//
// Monetary columns are integer hundredths in the database and float64 here.
// On error the returned allowance is nil.
func (db *Db) GetUserAllowanceById(userId int, allowanceId int) (*Allowance, error) {
	result := &Allowance{}

	if allowanceId == 0 {
		var balance int64
		if err := db.db.Query("select balance, weight from users where id = ?").
			Bind(userId).ScanSingle(&balance, &result.Weight); err != nil {
			return nil, err
		}
		result.Progress = float64(balance) / 100.0
		return result, nil
	}

	var target, balance, colour int64
	if err := db.db.Query("select id, name, target, balance, weight, colour from allowances where user_id = ? and id = ?").
		Bind(userId, allowanceId).
		ScanSingle(&result.ID, &result.Name, &target, &balance, &result.Weight, &colour); err != nil {
		return nil, err
	}
	result.Target = float64(target) / 100.0
	result.Progress = float64(balance) / 100.0
	result.Colour = fmt.Sprintf("#%06X", colour)
	return result, nil
}

// CreateAllowance inserts a new allowance for userId and returns its row ID.
//
// It fails with "user does not exist" when UserExists reports no such user.
// It also propagates any error from ConvertStringToColour or the database.
// The target is rounded to the nearest hundredth and stored as an integer.
// The insert and the last_insert_rowid lookup run in one transaction, which
// the deferred MustRollback discards on any early return.
func (db *Db) CreateAllowance(userId int, allowance *CreateAllowanceRequest) (int, error) {
	exists, err := db.UserExists(userId)
	switch {
	case err != nil:
		return 0, err
	case !exists:
		return 0, errors.New("user does not exist")
	}

	tx, err := db.db.Begin()
	if err != nil {
		return 0, err
	}
	defer tx.MustRollback()

	colour, err := ConvertStringToColour(allowance.Colour)
	if err != nil {
		return 0, err
	}
	targetHundredths := int(math.Round(allowance.Target * 100.0))

	if err = tx.Query("insert into allowances (user_id, name, target, weight, colour) values (?, ?, ?, ?, ?)").
		Bind(userId, allowance.Name, targetHundredths, allowance.Weight, colour).
		Exec(); err != nil {
		return 0, err
	}

	var newID int
	if err = tx.Query("select last_insert_rowid()").ScanSingle(&newID); err != nil {
		return 0, err
	}

	if err = tx.Commit(); err != nil {
		return 0, err
	}
	return newID, nil
}

// DeleteAllowance removes allowance allowanceId owned by userId. It returns
// "allowance not found" when no such allowance exists for that user. The
// existence check and the delete are separate statements, not one transaction.
func (db *Db) DeleteAllowance(userId int, allowanceId int) error {
	var matches int
	err := db.db.Query("select count(*) from allowances where id = ? and user_id = ?").
		Bind(allowanceId, userId).ScanSingle(&matches)
	switch {
	case err != nil:
		return err
	case matches == 0:
		return errors.New("allowance not found")
	}

	return db.db.Query("delete from allowances where id = ? and user_id = ?").
		Bind(allowanceId, userId).Exec()
}

// CompleteAllowance spends an allowance in a single transaction:
//  1. it reads the allowance's current balance;
//  2. it deletes the allowance;
//  3. it records a history entry of minus that balance for the user,
//     timestamped with the current Unix time.
// If any step fails the error is returned and the deferred MustRollback
// discards the transaction.
func (db *Db) CompleteAllowance(userId int, allowanceId int) error {
	tx, err := db.db.Begin()
	if err != nil {
		return err
	}
	defer tx.MustRollback()

	var spent int
	if err := tx.Query("select balance from allowances where id = ? and user_id = ?").
		Bind(allowanceId, userId).ScanSingle(&spent); err != nil {
		return err
	}

	if err := tx.Query("delete from allowances where id = ? and user_id = ?").
		Bind(allowanceId, userId).Exec(); err != nil {
		return err
	}

	if err := tx.Query("insert into history (user_id, timestamp, amount) values (?, ?, ?)").
		Bind(userId, time.Now().Unix(), -spent).Exec(); err != nil {
		return err
	}

	return tx.Commit()
}

// UpdateUserAllowance sets the weight of the user's general balance (the
// users.weight column) inside a transaction. Only allowance.Weight is read;
// the request's other fields are ignored.
func (db *Db) UpdateUserAllowance(userId int, allowance *UpdateAllowanceRequest) error {
	tx, err := db.db.Begin()
	if err != nil {
		return err
	}
	defer tx.MustRollback()

	if err = tx.Query("update users set weight=? where id = ?").Bind(allowance.Weight, userId).Exec(); err != nil {
		return err
	}
	return tx.Commit()
}

// UpdateAllowance overwrites the name, target, weight and colour of allowance
// allowanceId owned by userId. It returns "allowance not found" when the
// allowance does not exist for that user.
//
// The target is rounded to the nearest hundredth and stored as an integer.
// The colour string goes through ConvertStringToColour, and any conversion
// error is returned unchanged. The existence check runs before the update
// transaction is opened.
func (db *Db) UpdateAllowance(userId int, allowanceId int, allowance *UpdateAllowanceRequest) error {
	var found int
	err := db.db.Query("select count(*) from allowances where id = ? and user_id = ?").
		Bind(allowanceId, userId).ScanSingle(&found)
	switch {
	case err != nil:
		return err
	case found == 0:
		return errors.New("allowance not found")
	}

	tx, err := db.db.Begin()
	if err != nil {
		return err
	}
	defer tx.MustRollback()

	colour, err := ConvertStringToColour(allowance.Colour)
	if err != nil {
		return err
	}

	if err = tx.Query("update allowances set name=?, target=?, weight=?, colour=? where id = ? and user_id = ?").
		Bind(allowance.Name, int(math.Round(allowance.Target*100.0)), allowance.Weight, colour, allowanceId, userId).
		Exec(); err != nil {
		return err
	}
	return tx.Commit()
}

// BulkUpdateAllowance applies new weights to several allowances in a single
// transaction. An entry with ID 0 updates the user's general-balance weight in
// the users table. Any other ID updates that allowance, restricted to rows
// owned by userId; IDs matching no row are not reported. The first failing
// update aborts the whole batch.
func (db *Db) BulkUpdateAllowance(userId int, allowances []BulkUpdateAllowanceRequest) error {
	tx, err := db.db.Begin()
	if err != nil {
		return err
	}
	defer tx.MustRollback()

	for i := range allowances {
		entry := allowances[i]
		var stepErr error
		switch entry.ID {
		case 0:
			stepErr = tx.Query("update users set weight=? where id = ?").
				Bind(entry.Weight, userId).
				Exec()
		default:
			stepErr = tx.Query("update allowances set weight=? where id = ? and user_id = ?").
				Bind(entry.Weight, entry.ID, userId).
				Exec()
		}
		if stepErr != nil {
			return stepErr
		}
	}

	return tx.Commit()
}

// CreateTask inserts a task and returns its new row ID. The reward is rounded
// to the nearest hundredth and stored as an integer. The insert and the
// last_insert_rowid lookup share one transaction.
func (db *Db) CreateTask(task *CreateTaskRequest) (int, error) {
	tx, err := db.db.Begin()
	if err != nil {
		return 0, err
	}
	defer tx.MustRollback()

	rewardHundredths := int(math.Round(task.Reward * 100.0))
	if err = tx.Query("insert into tasks (name, reward, assigned) values (?, ?, ?)").
		Bind(task.Name, rewardHundredths, task.Assigned).
		Exec(); err != nil {
		return 0, err
	}

	var newID int
	if err = tx.Query("select last_insert_rowid()").ScanSingle(&newID); err != nil {
		return 0, err
	}

	if err = tx.Commit(); err != nil {
		return 0, err
	}
	return newID, nil
}

// GetTasks lists every task with its ID, name, reward and assignee. Rewards
// are stored as integer hundredths and returned as float64. On success the
// slice is never nil; on error it is nil.
func (db *Db) GetTasks() ([]Task, error) {
	var rangeErr error
	list := []Task{}

	for row := range db.db.Query("select id, name, reward, assigned from tasks").Range(&rangeErr) {
		var (
			t          Task
			hundredths int64
		)
		if rangeErr = row.Scan(&t.ID, &t.Name, &hundredths, &t.Assigned); rangeErr != nil {
			return nil, rangeErr
		}
		t.Reward = float64(hundredths) / 100.0
		list = append(list, t)
	}

	if rangeErr != nil {
		return nil, rangeErr
	}
	return list, nil
}

// GetTask loads the task with the given ID. The reward is stored as integer
// hundredths and returned as float64. On error the zero Task is returned.
func (db *Db) GetTask(id int) (Task, error) {
	var (
		t          Task
		hundredths int64
	)
	err := db.db.Query("select id, name, reward, assigned from tasks where id = ?").
		Bind(id).ScanSingle(&t.ID, &t.Name, &hundredths, &t.Assigned)
	if err != nil {
		return Task{}, err
	}
	t.Reward = float64(hundredths) / 100.0
	return t, nil
}

// DeleteTask removes the task with the given ID inside a transaction. No check
// is made that a row was actually deleted.
func (db *Db) DeleteTask(id int) error {
	tx, err := db.db.Begin()
	if err != nil {
		return err
	}
	defer tx.MustRollback()

	if err := tx.Query("delete from tasks where id = ?").Bind(id).Exec(); err != nil {
		return err
	}
	return tx.Commit()
}

// HasTask reports whether a task with the given ID exists.
func (db *Db) HasTask(id int) (bool, error) {
	var matches int
	err := db.db.Query("select count(*) from tasks where id = ?").Bind(id).ScanSingle(&matches)
	if err != nil {
		return false, err
	}
	found := matches > 0
	return found, nil
}

// UpdateTask overwrites the name, reward and assignee of task id inside a
// transaction. The reward is rounded to the nearest hundredth and stored as an
// integer. No check is made that the task exists.
func (db *Db) UpdateTask(id int, task *CreateTaskRequest) error {
	tx, err := db.db.Begin()
	if err != nil {
		return err
	}
	defer tx.MustRollback()

	if err := tx.Query("update tasks set name=?, reward=?, assigned=? where id = ?").
		Bind(task.Name, int(math.Round(task.Reward*100.0)), task.Assigned, id).
		Exec(); err != nil {
		return err
	}
	return tx.Commit()
}

// CompleteTask pays out a task's reward and then deletes the task, all within
// one transaction. Nothing is committed unless every step succeeds.
//
// For every user it:
//  1. records a history entry of +reward (integer hundredths) at the current
//     Unix time;
//  2. splits the reward across the user's allowances that have weight > 0,
//     visiting them in ascending order of (target - balance). Each allowance
//     gets weight/sumOfWeights of what is still undistributed, where
//     sumOfWeights includes the user's own weight. An allowance is never
//     pushed past its target; its unused share is carried on to the
//     following allowances;
//  3. credits whatever the allowances did not absorb to the user's balance.
//
// NOTE(review): every user receives the reward regardless of the task's
// `assigned` column — confirm this is intended.
func (db *Db) CompleteTask(taskId int) error {
	tx, err := db.db.Begin()
	if err != nil {
		return err
	}
	defer tx.MustRollback()

	var reward int
	err = tx.Query("select reward from tasks where id = ?").Bind(taskId).ScanSingle(&reward)
	if err != nil {
		return err
	}

	// Snapshot the users before writing anything. SQLite leaves it undefined
	// whether a SELECT that is still being stepped sees changes made on the
	// same connection, and the loop below updates users.balance.
	type payee struct {
		id     int
		weight float64
	}
	var payees []payee
	for userRow := range tx.Query("select id, weight from users").Range(&err) {
		var p payee
		if err = userRow.Scan(&p.id, &p.weight); err != nil {
			return err
		}
		payees = append(payees, p)
	}
	if err != nil {
		return err
	}

	type bucket struct {
		id      int
		target  int
		balance int
		weight  float64
	}

	// One timestamp for every history entry produced by this completion.
	now := time.Now().Unix()

	for _, p := range payees {
		err = tx.Query("insert into history (user_id, timestamp, amount) values (?, ?, ?)").
			Bind(p.id, now, reward).
			Exec()
		if err != nil {
			return err
		}

		// Total weight the reward is split over: every positively weighted
		// allowance plus the user's own weight. sum() over zero rows is NULL,
		// so ifnull keeps users without such allowances error-free and any
		// error here is a genuine failure.
		var sumOfWeights float64
		err = tx.Query("select ifnull(sum(weight), 0.0) from allowances where user_id = ? and weight > 0").
			Bind(p.id).ScanSingle(&sumOfWeights)
		if err != nil {
			return err
		}
		sumOfWeights += p.weight

		remainingReward := reward

		if sumOfWeights > 0 {
			// Snapshot the allowances (closest to target first) before their
			// balances are updated, for the same reason as the users above.
			var buckets []bucket
			for allowanceRow := range tx.Query("select id, weight, target, balance from allowances where user_id = ? and weight > 0 order by (target - balance) asc").
				Bind(p.id).Range(&err) {
				var b bucket
				if err = allowanceRow.Scan(&b.id, &b.weight, &b.target, &b.balance); err != nil {
					return err
				}
				buckets = append(buckets, b)
			}
			if err != nil {
				return err
			}

			for _, b := range buckets {
				// Weighted share of what is still undistributed, capped at what
				// the allowance needs to reach its target.
				// NOTE(review): if balance already exceeds target (e.g. the
				// target was lowered), amount becomes negative and the excess
				// moves to the user's balance — confirm this is intended.
				amount := int((b.weight / sumOfWeights) * float64(remainingReward))
				if b.balance+amount > b.target {
					amount = b.target - b.balance
				}
				sumOfWeights -= b.weight

				err = tx.Query("update allowances set balance = balance + ? where id = ? and user_id = ?").
					Bind(amount, b.id, p.id).Exec()
				if err != nil {
					return err
				}
				remainingReward -= amount
			}
		}

		// Whatever the allowances did not absorb goes to the user's balance.
		err = tx.Query("update users set balance = balance + ? where id = ?").
			Bind(remainingReward, p.id).Exec()
		if err != nil {
			return err
		}
	}

	// Remove the task now that its reward has been paid out; a failure here
	// must abort the whole payout rather than be committed around.
	err = tx.Query("delete from tasks where id = ?").Bind(taskId).Exec()
	if err != nil {
		return err
	}

	return tx.Commit()
}

// AddHistory records a history entry for userId at the current Unix time,
// inside a transaction. The amount is allowance.Allowance rounded to the
// nearest hundredth and stored as an integer.
func (db *Db) AddHistory(userId int, allowance *PostHistory) error {
	tx, err := db.db.Begin()
	if err != nil {
		return err
	}
	defer tx.MustRollback()

	hundredths := int(math.Round(allowance.Allowance * 100.0))
	if err := tx.Query("insert into history (user_id, timestamp, amount) values (?, ?, ?)").
		Bind(userId, time.Now().Unix(), hundredths).Exec(); err != nil {
		return err
	}
	return tx.Commit()
}

// GetHistory returns the user's history entries, newest first.
//
// Amounts are stored as integer hundredths and returned as float64.
// Timestamps are Unix seconds converted with time.Unix, which yields local
// time. On success the slice is never nil; on error it is nil.
func (db *Db) GetHistory(userId int) ([]History, error) {
	var rangeErr error
	entries := []History{}

	query := db.db.Query("select amount, `timestamp` from history where user_id = ? order by `timestamp` desc").Bind(userId)
	for row := range query.Range(&rangeErr) {
		var amount, unixSeconds int64
		if rangeErr = row.Scan(&amount, &unixSeconds); rangeErr != nil {
			return nil, rangeErr
		}
		entries = append(entries, History{
			Allowance: float64(amount) / 100.0,
			Timestamp: time.Unix(unixSeconds, 0),
		})
	}

	if rangeErr != nil {
		return nil, rangeErr
	}
	return entries, nil
}