diff --git a/backend/api_test.go b/backend/api_test.go index 46d7151..3ba6b31 100644 --- a/backend/api_test.go +++ b/backend/api_test.go @@ -764,7 +764,6 @@ func TestAddAllowanceSimple(t *testing.T) { createTestAllowance(e, "Test Allowance 1", 1000, 1) request := map[string]interface{}{ - "id": 1, "amount": 10, "description": "Added to allowance 1", } @@ -783,6 +782,30 @@ func TestAddAllowanceSimple(t *testing.T) { history.Value(0).Object().Value("description").String().IsEqual("Added to allowance 1") } +func TestAddAllowanceIdZero(t *testing.T) { + e := startServer(t) + + createTestAllowance(e, "Test Allowance 1", 1000, 1) + + request := map[string]interface{}{ + "amount": 10, + "description": "Added to allowance 1", + } + e.POST("/user/1/allowance/0/add").WithJSON(request).Expect().Status(200) + + // Verify the allowance is updated + allowances := e.GET("/user/1/allowance").Expect().Status(200).JSON().Array() + allowances.Value(0).Object().Value("id").Number().IsEqual(0) + allowances.Value(0).Object().Value("progress").Number().InDelta(10.0, 0.01) + + // Verify the history is updated + history := e.GET("/user/1/history").Expect().Status(200).JSON().Array() + history.Length().IsEqual(1) + history.Value(0).Object().Value("allowance").Number().InDelta(10.0, 0.01) + history.Value(0).Object().Value("timestamp").String().AsDateTime().InRange(getDelta(time.Now(), 2.0)) + history.Value(0).Object().Value("description").String().IsEqual("Added to allowance 1") +} + func TestAddAllowanceWithSpillage(t *testing.T) { e := startServer(t) @@ -791,7 +814,6 @@ func TestAddAllowanceWithSpillage(t *testing.T) { e.PUT("/user/1/allowance/0").WithJSON(UpdateAllowanceRequest{Weight: 1}).Expect().Status(200) request := map[string]interface{}{ - "id": 1, "amount": 10, "description": "Added to allowance 1", } diff --git a/backend/db.go b/backend/db.go index c6775c6..37d6fbb 100644 --- a/backend/db.go +++ b/backend/db.go @@ -561,36 +561,44 @@ func (db *Db) AddAllowanceAmount(userId int, allowanceId int, request AddAllowan return err } - // Fetch the target and progress of the specified allowance - var target, progress int - err = tx.Query("select target, balance from allowances where id = ? and user_id = ?"). - Bind(allowanceId, userId).ScanSingle(&target, &progress) - if err != nil { - return err - } - - // Calculate the amount to add to the current allowance - toAdd := remainingAmount - if progress+toAdd > target { - toAdd = target - progress - } - remainingAmount -= toAdd - - // Update the current allowance - if toAdd > 0 { - err = tx.Query("update allowances set balance = balance + ? where id = ? and user_id = ?"). - Bind(toAdd, allowanceId, userId).Exec() + if allowanceId == 0 { + err = tx.Query("update users set balance = balance + ? where id = ?"). + Bind(remainingAmount, userId).Exec() if err != nil { return err } - } - - // If there's remaining amount, distribute it to the user's allowances - if remainingAmount > 0 { - err = db.addDistributedReward(tx, userId, remainingAmount) + } else { + // Fetch the target and progress of the specified allowance + var target, progress int + err = tx.Query("select target, balance from allowances where id = ? and user_id = ?"). + Bind(allowanceId, userId).ScanSingle(&target, &progress) if err != nil { return err } + + // Calculate the amount to add to the current allowance + toAdd := remainingAmount + if progress+toAdd > target { + toAdd = target - progress + } + remainingAmount -= toAdd + + // Update the current allowance + if toAdd > 0 { + err = tx.Query("update allowances set balance = balance + ? where id = ? and user_id = ?"). + Bind(toAdd, allowanceId, userId).Exec() + if err != nil { + return err + } + } + + // If there's remaining amount, distribute it to the user's allowances + if remainingAmount > 0 { + err = db.addDistributedReward(tx, userId, remainingAmount) + if err != nil { + return err + } + } } return tx.Commit()