diff --git a/backend/api_test.go b/backend/api_test.go index c862102..46d7151 100644 --- a/backend/api_test.go +++ b/backend/api_test.go @@ -758,6 +758,64 @@ func TestPutBulkAllowance(t *testing.T) { allowances.Value(2).Object().Value("weight").Number().IsEqual(10) } +func TestAddAllowanceSimple(t *testing.T) { + e := startServer(t) + + createTestAllowance(e, "Test Allowance 1", 1000, 1) + + request := map[string]interface{}{ + "id": 1, + "amount": 10, + "description": "Added to allowance 1", + } + e.POST("/user/1/allowance/1/add").WithJSON(request).Expect().Status(200) + + // Verify the allowance is updated + allowances := e.GET("/user/1/allowance").Expect().Status(200).JSON().Array() + allowances.Value(1).Object().Value("id").Number().IsEqual(1) + allowances.Value(1).Object().Value("progress").Number().InDelta(10.0, 0.01) + + // Verify the history is updated + history := e.GET("/user/1/history").Expect().Status(200).JSON().Array() + history.Length().IsEqual(1) + history.Value(0).Object().Value("allowance").Number().InDelta(10.0, 0.01) + history.Value(0).Object().Value("timestamp").String().AsDateTime().InRange(getDelta(time.Now(), 2.0)) + history.Value(0).Object().Value("description").String().IsEqual("Added to allowance 1") +} + +func TestAddAllowanceWithSpillage(t *testing.T) { + e := startServer(t) + + createTestAllowance(e, "Test Allowance 1", 5, 1) + createTestAllowance(e, "Test Allowance 2", 5, 1) + e.PUT("/user/1/allowance/0").WithJSON(UpdateAllowanceRequest{Weight: 1}).Expect().Status(200) + + request := map[string]interface{}{ + "id": 1, + "amount": 10, + "description": "Added to allowance 1", + } + e.POST("/user/1/allowance/1/add").WithJSON(request).Expect().Status(200) + + // Verify the allowance is updated + allowances := e.GET("/user/1/allowance").Expect().Status(200).JSON().Array() + allowances.Value(1).Object().Value("id").Number().IsEqual(1) + allowances.Value(1).Object().Value("progress").Number().InDelta(5.0, 0.01) + + allowances.Value(2).Object().Value("id").Number().IsEqual(2) + allowances.Value(2).Object().Value("progress").Number().InDelta(2.5, 0.01) + + allowances.Value(0).Object().Value("id").Number().IsEqual(0) + allowances.Value(0).Object().Value("progress").Number().InDelta(2.5, 0.01) + + // Verify the history is updated + history := e.GET("/user/1/history").Expect().Status(200).JSON().Array() + history.Length().IsEqual(1) + history.Value(0).Object().Value("allowance").Number().InDelta(10.0, 0.01) + history.Value(0).Object().Value("timestamp").String().AsDateTime().InRange(getDelta(time.Now(), 2.0)) + history.Value(0).Object().Value("description").String().IsEqual("Added to allowance 1") +} + func getDelta(base time.Time, delta float64) (time.Time, time.Time) { start := base.Add(-time.Duration(delta) * time.Second) end := base.Add(time.Duration(delta) * time.Second) diff --git a/backend/db.go b/backend/db.go index 6f31ea6..c6775c6 100644 --- a/backend/db.go +++ b/backend/db.go @@ -428,10 +428,9 @@ func (db *Db) CompleteTask(taskId int) error { return err } - for userRow := range tx.Query("select id, weight from users").Range(&err) { + for userRow := range tx.Query("select id from users").Range(&err) { var userId int - var userWeight float64 - err = userRow.Scan(&userId, &userWeight) + err = userRow.Scan(&userId) if err != nil { return err } @@ -444,42 +443,7 @@ func (db *Db) CompleteTask(taskId int) error { return err } - // Calculate the sums of all weights - var sumOfWeights float64 - err = tx.Query("select sum(weight) from allowances where user_id = ? and weight > 0").Bind(userId).ScanSingle(&sumOfWeights) - sumOfWeights += userWeight - - remainingReward := reward - - if sumOfWeights > 0 { - // Distribute the reward to the allowances - for allowanceRow := range tx.Query("select id, weight, target, balance from allowances where user_id = ? and weight > 0 order by (target - balance) asc").Bind(userId).Range(&err) { - var allowanceId, allowanceTarget, allowanceBalance int - var allowanceWeight float64 - err = allowanceRow.Scan(&allowanceId, &allowanceWeight, &allowanceTarget, &allowanceBalance) - if err != nil { - return err - } - - // Calculate the amount to add to the allowance - amount := int((allowanceWeight / sumOfWeights) * float64(remainingReward)) - if allowanceBalance+amount > allowanceTarget { - // If the amount reaches past the target, set it to the target - amount = allowanceTarget - allowanceBalance - } - sumOfWeights -= allowanceWeight - err = tx.Query("update allowances set balance = balance + ? where id = ? and user_id = ?"). - Bind(amount, allowanceId, userId).Exec() - if err != nil { - return err - } - remainingReward -= amount - } - } - - // Add the remaining reward to the user - err = tx.Query("update users set balance = balance + ? where id = ?"). - Bind(remainingReward, userId).Exec() + err := db.addDistributedReward(tx, userId, reward) if err != nil { return err } @@ -494,6 +458,52 @@ func (db *Db) CompleteTask(taskId int) error { return tx.Commit() } +func (db *Db) addDistributedReward(tx *mysqlite.Tx, userId int, reward int) error { + var userWeight float64 + err := tx.Query("select weight from users where id = ?").Bind(userId).ScanSingle(&userWeight) + if err != nil { + return err + } + + // Calculate the sums of all weights + var sumOfWeights float64 + err = tx.Query("select sum(weight) from allowances where user_id = ? and weight > 0").Bind(userId).ScanSingle(&sumOfWeights) + sumOfWeights += userWeight + + remainingReward := reward + + if sumOfWeights > 0 { + // Distribute the reward to the allowances + for allowanceRow := range tx.Query("select id, weight, target, balance from allowances where user_id = ? and weight > 0 order by (target - balance) asc").Bind(userId).Range(&err) { + var allowanceId, allowanceTarget, allowanceBalance int + var allowanceWeight float64 + err = allowanceRow.Scan(&allowanceId, &allowanceWeight, &allowanceTarget, &allowanceBalance) + if err != nil { + return err + } + + // Calculate the amount to add to the allowance + amount := int((allowanceWeight / sumOfWeights) * float64(remainingReward)) + if allowanceBalance+amount > allowanceTarget { + // If the amount reaches past the target, set it to the target + amount = allowanceTarget - allowanceBalance + } + sumOfWeights -= allowanceWeight + err = tx.Query("update allowances set balance = balance + ? where id = ? and user_id = ?"). + Bind(amount, allowanceId, userId).Exec() + if err != nil { + return err + } + remainingReward -= amount + } + } + + // Add the remaining reward to the user + err = tx.Query("update users set balance = balance + ? where id = ?"). + Bind(remainingReward, userId).Exec() + return err +} + func (db *Db) AddHistory(userId int, allowance *PostHistory) error { tx, err := db.db.Begin() if err != nil { @@ -532,3 +542,56 @@ func (db *Db) GetHistory(userId int) ([]History, error) { } return history, nil } + +func (db *Db) AddAllowanceAmount(userId int, allowanceId int, request AddAllowanceAmountRequest) error { + tx, err := db.db.Begin() + if err != nil { + return err + } + defer tx.MustRollback() + + // Convert amount to integer (cents) + remainingAmount := int(math.Round(request.Amount * 100)) + + // Insert history entry + err = tx.Query("insert into history (user_id, timestamp, amount, description) values (?, ?, ?, ?)"). + Bind(userId, time.Now().Unix(), remainingAmount, request.Description). + Exec() + if err != nil { + return err + } + + // Fetch the target and progress of the specified allowance + var target, progress int + err = tx.Query("select target, balance from allowances where id = ? and user_id = ?"). + Bind(allowanceId, userId).ScanSingle(&target, &progress) + if err != nil { + return err + } + + // Calculate the amount to add to the current allowance + toAdd := remainingAmount + if progress+toAdd > target { + toAdd = target - progress + } + remainingAmount -= toAdd + + // Update the current allowance + if toAdd > 0 { + err = tx.Query("update allowances set balance = balance + ? where id = ? and user_id = ?"). + Bind(toAdd, allowanceId, userId).Exec() + if err != nil { + return err + } + } + + // If there's remaining amount, distribute it to the user's allowances + if remainingAmount > 0 { + err = db.addDistributedReward(tx, userId, remainingAmount) + if err != nil { + return err + } + } + + return tx.Commit() +} diff --git a/backend/dto.go b/backend/dto.go index a79035c..d51831c 100644 --- a/backend/dto.go +++ b/backend/dto.go @@ -73,3 +73,8 @@ type CreateTaskRequest struct { type CreateTaskResponse struct { ID int `json:"id"` } + +type AddAllowanceAmountRequest struct { + Amount float64 `json:"amount"` + Description string `json:"description"` +} diff --git a/backend/main.go b/backend/main.go index e9aa104..d18c841 100644 --- a/backend/main.go +++ b/backend/main.go @@ -368,6 +368,56 @@ func completeAllowance(c *gin.Context) { c.IndentedJSON(http.StatusOK, gin.H{"message": "Allowance completed successfully"}) } +func addToAllowance(c *gin.Context) { + userIdStr := c.Param("userId") + allowanceIdStr := c.Param("allowanceId") + + userId, err := strconv.Atoi(userIdStr) + if err != nil { + log.Printf(ErrInvalidUserID+": %v", err) + c.JSON(http.StatusBadRequest, gin.H{"error": ErrInvalidUserID}) + return + } + + allowanceId, err := strconv.Atoi(allowanceIdStr) + if err != nil { + log.Printf("Invalid allowance ID: %v", err) + c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid allowance ID"}) + return + } + + exists, err := db.UserExists(userId) + if err != nil { + log.Printf(ErrCheckingUserExist, err) + c.JSON(http.StatusInternalServerError, gin.H{"error": ErrInternalServerError}) + return + } + if !exists { + c.JSON(http.StatusNotFound, gin.H{"error": ErrUserNotFound}) + return + } + + var allowanceRequest AddAllowanceAmountRequest + if err := c.ShouldBindJSON(&allowanceRequest); err != nil { + log.Printf("Error parsing request body: %v", err) + c.JSON(http.StatusBadRequest, gin.H{"error": "Invalid request body"}) + return + } + + err = db.AddAllowanceAmount(userId, allowanceId, allowanceRequest) + if errors.Is(err, mysqlite.ErrNoRows) { + c.JSON(http.StatusNotFound, gin.H{"error": "Allowance not found"}) + return + } + if err != nil { + log.Printf("Error completing allowance: %v", err) + c.JSON(http.StatusInternalServerError, gin.H{"error": ErrInternalServerError}) + return + } + + c.IndentedJSON(http.StatusOK, gin.H{"message": "Allowance completed successfully"}) +} + func createTask(c *gin.Context) { var taskRequest CreateTaskRequest if err := c.ShouldBindJSON(&taskRequest); err != nil { @@ -611,6 +661,7 @@ func start(ctx context.Context, config *ServerConfig) { router.DELETE("/api/user/:userId/allowance/:allowanceId", deleteUserAllowance) router.PUT("/api/user/:userId/allowance/:allowanceId", putUserAllowance) router.POST("/api/user/:userId/allowance/:allowanceId/complete", completeAllowance) + router.POST("/api/user/:userId/allowance/:allowanceId/add", addToAllowance) router.POST("/api/tasks", createTask) router.GET("/api/tasks", getTasks) router.GET("/api/task/:taskId", getTask)