diff --git a/backend/api_test.go b/backend/api_test.go index 3ba6b31..dec630e 100644 --- a/backend/api_test.go +++ b/backend/api_test.go @@ -782,30 +782,6 @@ func TestAddAllowanceSimple(t *testing.T) { history.Value(0).Object().Value("description").String().IsEqual("Added to allowance 1") } -func TestAddAllowanceIdZero(t *testing.T) { - e := startServer(t) - - createTestAllowance(e, "Test Allowance 1", 1000, 1) - - request := map[string]interface{}{ - "amount": 10, - "description": "Added to allowance 1", - } - e.POST("/user/1/allowance/0/add").WithJSON(request).Expect().Status(200) - - // Verify the allowance is updated - allowances := e.GET("/user/1/allowance").Expect().Status(200).JSON().Array() - allowances.Value(0).Object().Value("id").Number().IsEqual(0) - allowances.Value(0).Object().Value("progress").Number().InDelta(10.0, 0.01) - - // Verify the history is updated - history := e.GET("/user/1/history").Expect().Status(200).JSON().Array() - history.Length().IsEqual(1) - history.Value(0).Object().Value("allowance").Number().InDelta(10.0, 0.01) - history.Value(0).Object().Value("timestamp").String().AsDateTime().InRange(getDelta(time.Now(), 2.0)) - history.Value(0).Object().Value("description").String().IsEqual("Added to allowance 1") -} - func TestAddAllowanceWithSpillage(t *testing.T) { e := startServer(t) @@ -838,6 +814,89 @@ func TestAddAllowanceWithSpillage(t *testing.T) { history.Value(0).Object().Value("description").String().IsEqual("Added to allowance 1") } +func TestAddAllowanceIdZero(t *testing.T) { + e := startServer(t) + + createTestAllowance(e, "Test Allowance 1", 1000, 1) + + request := map[string]interface{}{ + "amount": 10, + "description": "Added to allowance 1", + } + e.POST("/user/1/allowance/0/add").WithJSON(request).Expect().Status(200) + + // Verify the allowance is updated + allowances := e.GET("/user/1/allowance").Expect().Status(200).JSON().Array() + allowances.Value(0).Object().Value("id").Number().IsEqual(0) + allowances.Value(0).Object().Value("progress").Number().InDelta(10.0, 0.01) + + // Verify the history is updated + history := e.GET("/user/1/history").Expect().Status(200).JSON().Array() + history.Length().IsEqual(1) + history.Value(0).Object().Value("allowance").Number().InDelta(10.0, 0.01) + history.Value(0).Object().Value("timestamp").String().AsDateTime().InRange(getDelta(time.Now(), 2.0)) + history.Value(0).Object().Value("description").String().IsEqual("Added to allowance 1") +} + +func TestSubtractAllowanceSimple(t *testing.T) { + e := startServer(t) + + createTestAllowance(e, "Test Allowance 1", 1000, 1) + + request := map[string]interface{}{ + "amount": 10, + "description": "Added to allowance 1", + } + e.POST("/user/1/allowance/1/add").WithJSON(request).Expect().Status(200) + request["amount"] = -2.5 + e.POST("/user/1/allowance/1/add").WithJSON(request).Expect().Status(200) + + // Verify the allowance is updated + allowances := e.GET("/user/1/allowance").Expect().Status(200).JSON().Array() + allowances.Value(1).Object().Value("id").Number().IsEqual(1) + allowances.Value(1).Object().Value("progress").Number().InDelta(7.5, 0.01) + + // Verify the history is updated + history := e.GET("/user/1/history").Expect().Status(200).JSON().Array() + history.Length().IsEqual(2) + history.Value(0).Object().Value("allowance").Number().InDelta(10.0, 0.01) + history.Value(0).Object().Value("timestamp").String().AsDateTime().InRange(getDelta(time.Now(), 2.0)) + history.Value(0).Object().Value("description").String().IsEqual("Added to allowance 1") + + history.Value(1).Object().Value("allowance").Number().InDelta(-2.5, 0.01) + history.Value(1).Object().Value("timestamp").String().AsDateTime().InRange(getDelta(time.Now(), 2.0)) + history.Value(1).Object().Value("description").String().IsEqual("Added to allowance 1") +} + +func TestSubtractllowanceIdZero(t *testing.T) { + e := startServer(t) + + createTestAllowance(e, "Test Allowance 1", 1000, 1) + + request := map[string]interface{}{ + "amount": 10, + "description": "Added to allowance 1", + } + e.POST("/user/1/allowance/0/add").WithJSON(request).Expect().Status(200) + request["amount"] = -2.5 + e.POST("/user/1/allowance/0/add").WithJSON(request).Expect().Status(200) + + // Verify the allowance is updated + allowances := e.GET("/user/1/allowance").Expect().Status(200).JSON().Array() + allowances.Value(0).Object().Value("id").Number().IsEqual(0) + allowances.Value(0).Object().Value("progress").Number().InDelta(7.5, 0.01) + + // Verify the history is updated + history := e.GET("/user/1/history").Expect().Status(200).JSON().Array() + history.Length().IsEqual(2) + history.Value(0).Object().Value("allowance").Number().InDelta(10.0, 0.01) + history.Value(0).Object().Value("timestamp").String().AsDateTime().InRange(getDelta(time.Now(), 2.0)) + history.Value(0).Object().Value("description").String().IsEqual("Added to allowance 1") + + history.Value(1).Object().Value("allowance").Number().InDelta(-2.5, 0.01) + history.Value(1).Object().Value("description").String().IsEqual("Added to allowance 1") +} + func getDelta(base time.Time, delta float64) (time.Time, time.Time) { start := base.Add(-time.Duration(delta) * time.Second) end := base.Add(time.Duration(delta) * time.Second) diff --git a/backend/db.go b/backend/db.go index 37d6fbb..4c0aaf3 100644 --- a/backend/db.go +++ b/backend/db.go @@ -562,11 +562,39 @@ func (db *Db) AddAllowanceAmount(userId int, allowanceId int, request AddAllowan } if allowanceId == 0 { + if remainingAmount < 0 { + var userBalance int + err = tx.Query("select balance from users where id = ?"). + Bind(userId).ScanSingle(&userBalance) + if err != nil { + return err + } + if remainingAmount > userBalance { + return fmt.Errorf("cannot remove more than the current balance: %d", userBalance) + } + } err = tx.Query("update users set balance = balance + ? where id = ?"). Bind(remainingAmount, userId).Exec() if err != nil { return err } + } else if remainingAmount < 0 { + var progress int + err = tx.Query("select balance from allowances where id = ? and user_id = ?"). + Bind(allowanceId, userId).ScanSingle(&progress) + if err != nil { + return err + } + + if remainingAmount > progress { + return fmt.Errorf("cannot remove more than the current allowance balance: %d", progress) + } + + err = tx.Query("update allowances set balance = balance + ? where id = ? and user_id = ?"). + Bind(remainingAmount, allowanceId, userId).Exec() + if err != nil { + return err + } } else { // Fetch the target and progress of the specified allowance var target, progress int