Skip to content

Commit 3a7afad

Browse files
committed
fix: Prevent arithmetic overflow and underflow
1 parent acb85a0 commit 3a7afad

File tree

5 files changed

+159
-20
lines changed

5 files changed

+159
-20
lines changed

app/src/main/java/to/bitkit/ext/Numbers.kt

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,3 @@ fun ULong.toActivityItemDate(): String {
99
fun ULong.toActivityItemTime(): String {
1010
return Instant.ofEpochSecond(this.toLong()).formatted(DatePattern.ACTIVITY_TIME)
1111
}
12-
13-
// TODO replace all usages of faulty `(ULong - ULong).coerceAtLeast(0u)`
14-
/**
15-
* Safely subtracts [other] from this ULong, returning 0 if the result would be negative,
16-
* to prevent ULong wraparound by checking before subtracting, same as `x.saturating_sub(y)` in Rust.
17-
*/
18-
infix fun ULong.minusOrZero(other: ULong): ULong = if (this >= other) this - other else 0uL
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
package to.bitkit.models
2+
3+
/**
4+
* A wrapper for [ULong] that provides saturating arithmetic operations.
5+
* All operations prevent overflow/underflow by clamping to valid range [0, [ULong.MAX_VALUE]].
6+
* Similar to Rust's saturating arithmetic (e.g., `x.saturating_sub(y)`).
7+
*/
8+
@JvmInline
9+
value class USat(val value: ULong) : Comparable<USat> {
10+
11+
override fun compareTo(other: USat): Int = value.compareTo(other.value)
12+
13+
/** Saturating subtraction: returns 0 if result would be negative. */
14+
operator fun minus(other: USat): ULong =
15+
if (value >= other.value) value - other.value else 0uL
16+
17+
/** Saturating addition: caps at ULong.MAX_VALUE if result would overflow. */
18+
operator fun plus(other: USat): ULong =
19+
if (value <= ULong.MAX_VALUE - other.value) value + other.value else ULong.MAX_VALUE
20+
}
21+
22+
/**
23+
* Wraps this ULong in a [USat] for saturating arithmetic operations.
24+
* Use this when performing arithmetic that could overflow/underflow.
25+
*
26+
* Example:
27+
* ```
28+
* val result = a.safe() - b.safe() // Returns 0 if a < b instead of wrapping
29+
* ```
30+
*/
31+
fun ULong.safe(): USat = USat(this)

app/src/main/java/to/bitkit/usecases/DeriveBalanceStateUseCase.kt

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,9 @@ import to.bitkit.data.SettingsStore
77
import to.bitkit.data.entities.TransferEntity
88
import to.bitkit.ext.amountSats
99
import to.bitkit.ext.channelId
10-
import to.bitkit.ext.minusOrZero
1110
import to.bitkit.ext.totalNextOutboundHtlcLimitSats
1211
import to.bitkit.models.BalanceState
12+
import to.bitkit.models.safe
1313
import to.bitkit.repositories.LightningRepo
1414
import to.bitkit.repositories.TransferRepo
1515
import to.bitkit.utils.Logger
@@ -32,12 +32,11 @@ class DeriveBalanceStateUseCase @Inject constructor(
3232
val pendingChannelsSats = getPendingChannelsSats(activeTransfers, channels, balanceDetails)
3333

3434
val toSavingsAmount = getTransferToSavingsSats(activeTransfers, channels, balanceDetails)
35-
val toSpendingAmount = paidOrdersSats + pendingChannelsSats
35+
val toSpendingAmount = paidOrdersSats.safe() + pendingChannelsSats.safe()
3636

3737
val totalOnchainSats = balanceDetails.totalOnchainBalanceSats
38-
val totalLightningSats = balanceDetails.totalLightningBalanceSats
39-
.minusOrZero(pendingChannelsSats)
40-
.minusOrZero(toSavingsAmount)
38+
val afterPendingChannels = balanceDetails.totalLightningBalanceSats.safe() - pendingChannelsSats.safe()
39+
val totalLightningSats = afterPendingChannels.safe() - toSavingsAmount.safe()
4140

4241
val balanceState = BalanceState(
4342
totalOnchainSats = totalOnchainSats,
@@ -113,7 +112,7 @@ class DeriveBalanceStateUseCase @Inject constructor(
113112
Logger.debug("Could not calculate max send amount, using fallback of: $fallback", context = TAG)
114113
}.getOrDefault(fallback)
115114

116-
return spendableOnchainSats.minusOrZero(fee)
115+
return spendableOnchainSats.safe() - fee.safe()
117116
}
118117

119118
companion object {

app/src/main/java/to/bitkit/viewmodels/TransferViewModel.kt

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ import to.bitkit.models.EUR_CURRENCY
3636
import to.bitkit.models.Toast
3737
import to.bitkit.models.TransactionSpeed
3838
import to.bitkit.models.TransferType
39+
import to.bitkit.models.safe
3940
import to.bitkit.repositories.BlocktankRepo
4041
import to.bitkit.repositories.CurrencyRepo
4142
import to.bitkit.repositories.LightningRepo
@@ -304,7 +305,7 @@ class TransferViewModel @Inject constructor(
304305
maxLspFee = estimate.feeSat
305306

306307
// Calculate the available balance to send after LSP fee
307-
val balanceAfterLspFee = availableAmount - maxLspFee
308+
val balanceAfterLspFee = availableAmount.safe() - maxLspFee.safe()
308309

309310
_spendingUiState.update {
310311
// Calculate the max available to send considering the current balance and LSP policy
@@ -380,11 +381,11 @@ class TransferViewModel @Inject constructor(
380381
val maxChannelSize1 = (maxChannelSizeSat.toDouble() * 0.98).roundToLong().toULong()
381382

382383
// The maximum channel size the user can open including existing channels
383-
val maxChannelSize2 = (maxChannelSize1 - channelsSize).coerceAtLeast(0u)
384+
val maxChannelSize2 = maxChannelSize1.safe() - channelsSize.safe()
384385
val maxChannelSizeAvailableToIncrease = min(maxChannelSize1, maxChannelSize2)
385386

386387
val minLspBalance = getMinLspBalance(clientBalanceSat, minChannelSizeSat)
387-
val maxLspBalance = (maxChannelSizeAvailableToIncrease - clientBalanceSat).coerceAtLeast(0u)
388+
val maxLspBalance = maxChannelSizeAvailableToIncrease.safe() - clientBalanceSat.safe()
388389
val defaultLspBalance = getDefaultLspBalance(clientBalanceSat, maxLspBalance)
389390
val maxClientBalance = getMaxClientBalance(maxChannelSizeAvailableToIncrease)
390391

@@ -436,11 +437,11 @@ class TransferViewModel @Inject constructor(
436437
}
437438

438439
val lspBalance = if (clientBalanceSat < threshold1) { // 0-225€: LSP balance = 450€ - client balance
439-
defaultLspBalanceSats - clientBalanceSat
440+
defaultLspBalanceSats.safe() - clientBalanceSat.safe()
440441
} else if (clientBalanceSat < threshold2) { // 225-495€: LSP balance = client balance
441442
clientBalanceSat
442443
} else if (clientBalanceSat < maxLspBalance) { // 495-950€: LSP balance = max - client balance
443-
maxLspBalance - clientBalanceSat
444+
maxLspBalance.safe() - clientBalanceSat.safe()
444445
} else {
445446
maxLspBalance
446447
}
@@ -452,7 +453,7 @@ class TransferViewModel @Inject constructor(
452453
// LSP balance must be at least 2.5% of the channel size for LDK to accept (reserve balance)
453454
val ldkMinimum = (clientBalance.toDouble() * 0.025).toULong()
454455
// Channel size must be at least minChannelSize
455-
val lspMinimum = if (minChannelSize > clientBalance) minChannelSize - clientBalance else 0u
456+
val lspMinimum = minChannelSize.safe() - clientBalance.safe()
456457

457458
return max(ldkMinimum, lspMinimum)
458459
}
@@ -461,7 +462,7 @@ class TransferViewModel @Inject constructor(
461462
// Remote balance must be at least 2.5% of the channel size for LDK to accept (reserve balance)
462463
val minRemoteBalance = (maxChannelSize.toDouble() * 0.025).toULong()
463464

464-
return maxChannelSize - minRemoteBalance
465+
return maxChannelSize.safe() - minRemoteBalance.safe()
465466
}
466467

467468
/** Calculates the total value of channels connected to Blocktank nodes */
Lines changed: 115 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,115 @@
1+
package to.bitkit.models
2+
3+
import org.junit.Test
4+
import kotlin.test.assertEquals
5+
import kotlin.test.assertFalse
6+
import kotlin.test.assertTrue
7+
8+
class USatTest {
9+
10+
// region Subtraction
11+
@Test
12+
fun `minus returns difference when a greater than b`() {
13+
val result = USat(10uL) - USat(5uL)
14+
assertEquals(5uL, result)
15+
}
16+
17+
@Test
18+
fun `minus returns zero when a equals b`() {
19+
val result = USat(5uL) - USat(5uL)
20+
assertEquals(0uL, result)
21+
}
22+
23+
@Test
24+
fun `minus returns zero when would underflow`() {
25+
val result = USat(5uL) - USat(10uL)
26+
assertEquals(0uL, result)
27+
}
28+
29+
@Test
30+
fun `minus handles max ULong values`() {
31+
val result = USat(0uL) - USat(ULong.MAX_VALUE)
32+
assertEquals(0uL, result)
33+
}
34+
35+
@Test
36+
fun `chained minus operations work correctly`() {
37+
val intermediate = 100uL.safe() - 30uL.safe()
38+
val result = intermediate.safe() - 20uL.safe()
39+
assertEquals(50uL, result)
40+
}
41+
42+
@Test
43+
fun `chained minus returns zero when intermediate would underflow`() {
44+
val intermediate = 10uL.safe() - 20uL.safe()
45+
val result = intermediate.safe() - 5uL.safe()
46+
assertEquals(0uL, result)
47+
}
48+
// endregion
49+
50+
// region Addition
51+
@Test
52+
fun `plus returns sum`() {
53+
val result = USat(10uL) + USat(5uL)
54+
assertEquals(15uL, result)
55+
}
56+
57+
@Test
58+
fun `plus saturates at max when would overflow`() {
59+
val result = USat(ULong.MAX_VALUE) + USat(1uL)
60+
assertEquals(ULong.MAX_VALUE, result)
61+
}
62+
63+
@Test
64+
fun `plus saturates when both values are large`() {
65+
val result = USat(ULong.MAX_VALUE - 10uL) + USat(20uL)
66+
assertEquals(ULong.MAX_VALUE, result)
67+
}
68+
69+
@Test
70+
fun `chained plus operations work correctly`() {
71+
val intermediate = 10uL.safe() + 20uL.safe()
72+
val result = intermediate.safe() + 30uL.safe()
73+
assertEquals(60uL, result)
74+
}
75+
// endregion
76+
77+
// region Comparisons
78+
@Test
79+
fun `compareTo returns negative when less than`() {
80+
assertTrue(USat(5uL) < USat(10uL))
81+
}
82+
83+
@Test
84+
fun `compareTo returns positive when greater than`() {
85+
assertTrue(USat(10uL) > USat(5uL))
86+
}
87+
88+
@Test
89+
fun `compareTo returns zero when equal`() {
90+
assertEquals(0, USat(10uL).compareTo(USat(10uL)))
91+
}
92+
93+
@Test
94+
fun `comparison operators work correctly`() {
95+
assertTrue(USat(5uL) <= USat(10uL))
96+
assertTrue(USat(10uL) >= USat(5uL))
97+
assertTrue(USat(10uL) <= USat(10uL))
98+
assertTrue(USat(10uL) >= USat(10uL))
99+
assertFalse(USat(10uL) < USat(10uL))
100+
assertFalse(USat(10uL) > USat(10uL))
101+
}
102+
// endregion
103+
104+
// region Realistic scenarios
105+
@Test
106+
fun `realistic bitcoin calculation`() {
107+
val channelSize = 10_000_000uL // 0.1 BTC in sats
108+
val balance = 1_000_000uL // 0.01 BTC in sats
109+
110+
val maxSend = channelSize.safe() - balance.safe()
111+
112+
assertEquals(9_000_000uL, maxSend)
113+
}
114+
// endregion
115+
}

0 commit comments

Comments
 (0)