273 lines
5.7 KiB
Go
273 lines
5.7 KiB
Go
package main
|
|
|
|
import (
|
|
"errors"
|
|
"log"
|
|
"time"
|
|
|
|
"gitea.seeseepuff.be/seeseemelk/mysqlite"
|
|
)
|
|
|
|
type Db struct {
|
|
db *mysqlite.Db
|
|
}
|
|
|
|
func NewDb(datasource string) *Db {
|
|
// Open a file-based database
|
|
db, err := mysqlite.OpenDb(datasource)
|
|
if err != nil {
|
|
log.Fatal(err)
|
|
}
|
|
|
|
// Apply migrations
|
|
err = db.MigrateDb(migrations, "migrations")
|
|
if err != nil {
|
|
log.Fatal(err)
|
|
}
|
|
|
|
return &Db{db: db}
|
|
}
|
|
|
|
func (db *Db) GetUsers() ([]User, error) {
|
|
var err error
|
|
users := make([]User, 0)
|
|
|
|
for row := range db.db.Query("select id, name from users").Range(&err) {
|
|
user := User{}
|
|
err = row.Scan(&user.ID, &user.Name)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
users = append(users, user)
|
|
}
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return users, nil
|
|
}
|
|
|
|
func (db *Db) GetUser(id int) (*UserWithAllowance, error) {
|
|
user := &UserWithAllowance{}
|
|
|
|
err := db.db.Query("select u.id, u.name, (select ifnull(sum(h.amount), 0) from history h where h.user_id = u.id) from users u where u.id = ?").
|
|
Bind(id).ScanSingle(&user.ID, &user.Name, &user.Allowance)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return user, nil
|
|
}
|
|
|
|
func (db *Db) UserExists(userId int) (bool, error) {
|
|
count := 0
|
|
err := db.db.Query("select count(*) from users where id = ?").
|
|
Bind(userId).ScanSingle(&count)
|
|
if err != nil {
|
|
return false, err
|
|
}
|
|
return count > 0, nil
|
|
}
|
|
|
|
func (db *Db) GetUserGoals(userId int) ([]Goal, error) {
|
|
goals := make([]Goal, 0)
|
|
var err error
|
|
|
|
for row := range db.db.Query("select id, name, target, progress, weight from goals where user_id = ?").
|
|
Bind(userId).Range(&err) {
|
|
goal := Goal{}
|
|
err = row.Scan(&goal.ID, &goal.Name, &goal.Target, &goal.Progress, &goal.Weight)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
goals = append(goals, goal)
|
|
}
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return goals, nil
|
|
}
|
|
|
|
func (db *Db) CreateGoal(userId int, goal *CreateGoalRequest) (int, error) {
|
|
// Check if user exists before attempting to create a goal
|
|
exists, err := db.UserExists(userId)
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
if !exists {
|
|
return 0, errors.New("user does not exist")
|
|
}
|
|
|
|
tx, err := db.db.Begin()
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
defer tx.MustRollback()
|
|
|
|
// Insert the new goal
|
|
err = tx.Query("insert into goals (user_id, name, target, progress, weight) values (?, ?, ?, 0, ?)").
|
|
Bind(userId, goal.Name, goal.Target, goal.Weight).
|
|
Exec()
|
|
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
|
|
// Get the last inserted ID
|
|
var lastId int
|
|
err = tx.Query("select last_insert_rowid()").ScanSingle(&lastId)
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
|
|
// Commit the transaction
|
|
err = tx.Commit()
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
|
|
return lastId, nil
|
|
}
|
|
|
|
func (db *Db) DeleteGoal(userId int, goalId int) error {
|
|
// Check if the goal exists for the user
|
|
count := 0
|
|
err := db.db.Query("select count(*) from goals where id = ? and user_id = ?").
|
|
Bind(goalId, userId).ScanSingle(&count)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if count == 0 {
|
|
return errors.New("goal not found")
|
|
}
|
|
|
|
// Delete the goal
|
|
err = db.db.Query("delete from goals where id = ? and user_id = ?").
|
|
Bind(goalId, userId).Exec()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (db *Db) CreateTask(task *CreateTaskRequest) (int, error) {
|
|
tx, err := db.db.Begin()
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
defer tx.MustRollback()
|
|
|
|
// Insert the new task
|
|
err = tx.Query("insert into tasks (name, reward, assigned) values (?, ?, ?)").
|
|
Bind(task.Name, task.Reward, task.Assigned).
|
|
Exec()
|
|
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
|
|
// Get the last inserted ID
|
|
var lastId int
|
|
err = tx.Query("select last_insert_rowid()").ScanSingle(&lastId)
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
|
|
// Commit the transaction
|
|
err = tx.Commit()
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
|
|
return lastId, nil
|
|
}
|
|
|
|
func (db *Db) GetTasks() ([]Task, error) {
|
|
tasks := make([]Task, 0)
|
|
var err error
|
|
|
|
for row := range db.db.Query("select id, name, reward, assigned from tasks").Range(&err) {
|
|
task := Task{}
|
|
err = row.Scan(&task.ID, &task.Name, &task.Reward, &task.Assigned)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
tasks = append(tasks, task)
|
|
}
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return tasks, nil
|
|
}
|
|
|
|
func (db *Db) GetTask(id int) (Task, error) {
|
|
task := Task{}
|
|
|
|
err := db.db.Query("select id, name, reward, assigned from tasks where id = ?").
|
|
Bind(id).ScanSingle(&task.ID, &task.Name, &task.Reward, &task.Assigned)
|
|
if err != nil {
|
|
return Task{}, err
|
|
}
|
|
return task, nil
|
|
}
|
|
|
|
func (db *Db) HasTask(id int) (bool, error) {
|
|
count := 0
|
|
err := db.db.Query("select count(*) from tasks where id = ?").
|
|
Bind(id).ScanSingle(&count)
|
|
if err != nil {
|
|
return false, err
|
|
}
|
|
return count > 0, nil
|
|
}
|
|
|
|
func (db *Db) UpdateTask(id int, task *CreateTaskRequest) error {
|
|
tx, err := db.db.Begin()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer tx.MustRollback()
|
|
|
|
err = tx.Query("update tasks set name=?, reward=?, assigned=? where id = ?").
|
|
Bind(task.Name, task.Reward, task.Assigned, id).
|
|
Exec()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
return tx.Commit()
|
|
}
|
|
|
|
func (db *Db) AddAllowance(userId int, allowance *PostAllowance) error {
|
|
tx, err := db.db.Begin()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer tx.MustRollback()
|
|
|
|
err = tx.Query("insert into history (user_id, timestamp, amount) values (?, ?, ?)").
|
|
Bind(userId, time.Now().Unix(), allowance.Allowance).
|
|
Exec()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
return tx.Commit()
|
|
}
|
|
|
|
func (db *Db) GetHistory(userId int) ([]Allowance, error) {
|
|
history := make([]Allowance, 0)
|
|
var err error
|
|
|
|
for row := range db.db.Query("select amount from history where user_id = ? order by `timestamp` desc").
|
|
Bind(userId).Range(&err) {
|
|
allowance := Allowance{}
|
|
err = row.Scan(&allowance.Allowance)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
history = append(history, allowance)
|
|
}
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return history, nil
|
|
}
|