diff --git a/backend/api_test.go b/backend/api_test.go index 46d7151..dec630e 100644 --- a/backend/api_test.go +++ b/backend/api_test.go @@ -764,7 +764,6 @@ func TestAddAllowanceSimple(t *testing.T) { createTestAllowance(e, "Test Allowance 1", 1000, 1) request := map[string]interface{}{ - "id": 1, "amount": 10, "description": "Added to allowance 1", } @@ -791,7 +790,6 @@ func TestAddAllowanceWithSpillage(t *testing.T) { e.PUT("/user/1/allowance/0").WithJSON(UpdateAllowanceRequest{Weight: 1}).Expect().Status(200) request := map[string]interface{}{ - "id": 1, "amount": 10, "description": "Added to allowance 1", } @@ -816,6 +814,89 @@ func TestAddAllowanceWithSpillage(t *testing.T) { history.Value(0).Object().Value("description").String().IsEqual("Added to allowance 1") } +func TestAddAllowanceIdZero(t *testing.T) { + e := startServer(t) + + createTestAllowance(e, "Test Allowance 1", 1000, 1) + + request := map[string]interface{}{ + "amount": 10, + "description": "Added to allowance 1", + } + e.POST("/user/1/allowance/0/add").WithJSON(request).Expect().Status(200) + + // Verify the allowance is updated + allowances := e.GET("/user/1/allowance").Expect().Status(200).JSON().Array() + allowances.Value(0).Object().Value("id").Number().IsEqual(0) + allowances.Value(0).Object().Value("progress").Number().InDelta(10.0, 0.01) + + // Verify the history is updated + history := e.GET("/user/1/history").Expect().Status(200).JSON().Array() + history.Length().IsEqual(1) + history.Value(0).Object().Value("allowance").Number().InDelta(10.0, 0.01) + history.Value(0).Object().Value("timestamp").String().AsDateTime().InRange(getDelta(time.Now(), 2.0)) + history.Value(0).Object().Value("description").String().IsEqual("Added to allowance 1") +} + +func TestSubtractAllowanceSimple(t *testing.T) { + e := startServer(t) + + createTestAllowance(e, "Test Allowance 1", 1000, 1) + + request := map[string]interface{}{ + "amount": 10, + "description": "Added to allowance 1", + } + e.POST("/user/1/allowance/1/add").WithJSON(request).Expect().Status(200) + request["amount"] = -2.5 + e.POST("/user/1/allowance/1/add").WithJSON(request).Expect().Status(200) + + // Verify the allowance is updated + allowances := e.GET("/user/1/allowance").Expect().Status(200).JSON().Array() + allowances.Value(1).Object().Value("id").Number().IsEqual(1) + allowances.Value(1).Object().Value("progress").Number().InDelta(7.5, 0.01) + + // Verify the history is updated + history := e.GET("/user/1/history").Expect().Status(200).JSON().Array() + history.Length().IsEqual(2) + history.Value(0).Object().Value("allowance").Number().InDelta(10.0, 0.01) + history.Value(0).Object().Value("timestamp").String().AsDateTime().InRange(getDelta(time.Now(), 2.0)) + history.Value(0).Object().Value("description").String().IsEqual("Added to allowance 1") + + history.Value(1).Object().Value("allowance").Number().InDelta(-2.5, 0.01) + history.Value(1).Object().Value("timestamp").String().AsDateTime().InRange(getDelta(time.Now(), 2.0)) + history.Value(1).Object().Value("description").String().IsEqual("Added to allowance 1") +} + +func TestSubtractllowanceIdZero(t *testing.T) { + e := startServer(t) + + createTestAllowance(e, "Test Allowance 1", 1000, 1) + + request := map[string]interface{}{ + "amount": 10, + "description": "Added to allowance 1", + } + e.POST("/user/1/allowance/0/add").WithJSON(request).Expect().Status(200) + request["amount"] = -2.5 + e.POST("/user/1/allowance/0/add").WithJSON(request).Expect().Status(200) + + // Verify the allowance is updated + allowances := e.GET("/user/1/allowance").Expect().Status(200).JSON().Array() + allowances.Value(0).Object().Value("id").Number().IsEqual(0) + allowances.Value(0).Object().Value("progress").Number().InDelta(7.5, 0.01) + + // Verify the history is updated + history := e.GET("/user/1/history").Expect().Status(200).JSON().Array() + history.Length().IsEqual(2) + history.Value(0).Object().Value("allowance").Number().InDelta(10.0, 0.01) + history.Value(0).Object().Value("timestamp").String().AsDateTime().InRange(getDelta(time.Now(), 2.0)) + history.Value(0).Object().Value("description").String().IsEqual("Added to allowance 1") + + history.Value(1).Object().Value("allowance").Number().InDelta(-2.5, 0.01) + history.Value(1).Object().Value("description").String().IsEqual("Added to allowance 1") +} + func getDelta(base time.Time, delta float64) (time.Time, time.Time) { start := base.Add(-time.Duration(delta) * time.Second) end := base.Add(time.Duration(delta) * time.Second) diff --git a/backend/db.go b/backend/db.go index c6775c6..4c0aaf3 100644 --- a/backend/db.go +++ b/backend/db.go @@ -561,36 +561,72 @@ func (db *Db) AddAllowanceAmount(userId int, allowanceId int, request AddAllowan return err } - // Fetch the target and progress of the specified allowance - var target, progress int - err = tx.Query("select target, balance from allowances where id = ? and user_id = ?"). - Bind(allowanceId, userId).ScanSingle(&target, &progress) - if err != nil { - return err - } - - // Calculate the amount to add to the current allowance - toAdd := remainingAmount - if progress+toAdd > target { - toAdd = target - progress - } - remainingAmount -= toAdd - - // Update the current allowance - if toAdd > 0 { - err = tx.Query("update allowances set balance = balance + ? where id = ? and user_id = ?"). - Bind(toAdd, allowanceId, userId).Exec() + if allowanceId == 0 { + if remainingAmount < 0 { + var userBalance int + err = tx.Query("select balance from users where id = ?"). + Bind(userId).ScanSingle(&userBalance) + if err != nil { + return err + } + if remainingAmount > userBalance { + return fmt.Errorf("cannot remove more than the current balance: %d", userBalance) + } + } + err = tx.Query("update users set balance = balance + ? where id = ?"). + Bind(remainingAmount, userId).Exec() if err != nil { return err } - } - - // If there's remaining amount, distribute it to the user's allowances - if remainingAmount > 0 { - err = db.addDistributedReward(tx, userId, remainingAmount) + } else if remainingAmount < 0 { + var progress int + err = tx.Query("select balance from allowances where id = ? and user_id = ?"). + Bind(allowanceId, userId).ScanSingle(&progress) if err != nil { return err } + + if remainingAmount > progress { + return fmt.Errorf("cannot remove more than the current allowance balance: %d", progress) + } + + err = tx.Query("update allowances set balance = balance + ? where id = ? and user_id = ?"). + Bind(remainingAmount, allowanceId, userId).Exec() + if err != nil { + return err + } + } else { + // Fetch the target and progress of the specified allowance + var target, progress int + err = tx.Query("select target, balance from allowances where id = ? and user_id = ?"). + Bind(allowanceId, userId).ScanSingle(&target, &progress) + if err != nil { + return err + } + + // Calculate the amount to add to the current allowance + toAdd := remainingAmount + if progress+toAdd > target { + toAdd = target - progress + } + remainingAmount -= toAdd + + // Update the current allowance + if toAdd > 0 { + err = tx.Query("update allowances set balance = balance + ? where id = ? and user_id = ?"). + Bind(toAdd, allowanceId, userId).Exec() + if err != nil { + return err + } + } + + // If there's remaining amount, distribute it to the user's allowances + if remainingAmount > 0 { + err = db.addDistributedReward(tx, userId, remainingAmount) + if err != nil { + return err + } + } } return tx.Commit()