Skip to content

Commit 8422ef9

Browse files
committed
fix: Prevent arithmetic overflow in substractions
1 parent 3124ac0 commit 8422ef9

File tree

5 files changed

+66
-14
lines changed

5 files changed

+66
-14
lines changed

app/src/main/java/to/bitkit/ext/Numbers.kt

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,3 @@ fun ULong.toActivityItemDate(): String {
99
fun ULong.toActivityItemTime(): String {
1010
return Instant.ofEpochSecond(this.toLong()).formatted(DatePattern.ACTIVITY_TIME)
1111
}
12-
13-
// TODO replace all usages of faulty `(ULong - ULong).coerceAtLeast(0u)`
14-
/**
15-
* Safely subtracts [other] from this ULong, returning 0 if the result would be negative,
16-
* to prevent ULong wraparound by checking before subtracting, same as `x.saturating_sub(y)` in Rust.
17-
*/
18-
infix fun ULong.minusOrZero(other: ULong): ULong = if (this >= other) this - other else 0uL
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
package to.bitkit.models
2+
3+
/**
4+
* A wrapper for ULong that provides safe arithmetic operations.
5+
* The minus operator returns 0 instead of overflowing when the result would be negative.
6+
*/
7+
@JvmInline
8+
value class USat(val value: ULong) {
9+
/**
10+
* Safely subtracts [other] from this ULong, returning 0 if the result would be negative,
11+
* to prevent ULong wraparound by checking before subtracting, same as `x.saturating_sub(y)` in Rust.
12+
*/
13+
operator fun minus(other: USat): USat = USat(value minusOrZero other.value)
14+
15+
private infix fun ULong.minusOrZero(other: ULong): ULong = if (this >= other) this - other else 0uL
16+
}

app/src/main/java/to/bitkit/usecases/DeriveBalanceStateUseCase.kt

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,9 @@ import org.lightningdevkit.ldknode.BalanceDetails
55
import org.lightningdevkit.ldknode.ChannelDetails
66
import to.bitkit.data.SettingsStore
77
import to.bitkit.data.entities.TransferEntity
8+
import to.bitkit.models.USat
89
import to.bitkit.ext.amountSats
910
import to.bitkit.ext.channelId
10-
import to.bitkit.ext.minusOrZero
1111
import to.bitkit.ext.totalNextOutboundHtlcLimitSats
1212
import to.bitkit.models.BalanceState
1313
import to.bitkit.repositories.LightningRepo
@@ -35,9 +35,8 @@ class DeriveBalanceStateUseCase @Inject constructor(
3535
val toSpendingAmount = paidOrdersSats + pendingChannelsSats
3636

3737
val totalOnchainSats = balanceDetails.totalOnchainBalanceSats
38-
val totalLightningSats = balanceDetails.totalLightningBalanceSats
39-
.minusOrZero(pendingChannelsSats)
40-
.minusOrZero(toSavingsAmount)
38+
val totalLightningSats =
39+
(USat(balanceDetails.totalLightningBalanceSats) - USat(pendingChannelsSats) - USat(toSavingsAmount)).value
4140

4241
val balanceState = BalanceState(
4342
totalOnchainSats = totalOnchainSats,
@@ -113,7 +112,7 @@ class DeriveBalanceStateUseCase @Inject constructor(
113112
Logger.debug("Could not calculate max send amount, using fallback of: $fallback", context = TAG)
114113
}.getOrDefault(fallback)
115114

116-
return spendableOnchainSats.minusOrZero(fee)
115+
return (USat(spendableOnchainSats) - USat(fee)).value
117116
}
118117

119118
companion object {

app/src/main/java/to/bitkit/viewmodels/TransferViewModel.kt

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ import to.bitkit.R
3131
import to.bitkit.data.CacheStore
3232
import to.bitkit.data.SettingsStore
3333
import to.bitkit.env.Env
34+
import to.bitkit.models.USat
3435
import to.bitkit.ext.amountOnClose
3536
import to.bitkit.models.EUR_CURRENCY
3637
import to.bitkit.models.Toast
@@ -380,11 +381,11 @@ class TransferViewModel @Inject constructor(
380381
val maxChannelSize1 = (maxChannelSizeSat.toDouble() * 0.98).roundToLong().toULong()
381382

382383
// The maximum channel size the user can open including existing channels
383-
val maxChannelSize2 = (maxChannelSize1 - channelsSize).coerceAtLeast(0u)
384+
val maxChannelSize2 = (USat(maxChannelSize1) - USat(channelsSize)).value
384385
val maxChannelSizeAvailableToIncrease = min(maxChannelSize1, maxChannelSize2)
385386

386387
val minLspBalance = getMinLspBalance(clientBalanceSat, minChannelSizeSat)
387-
val maxLspBalance = (maxChannelSizeAvailableToIncrease - clientBalanceSat).coerceAtLeast(0u)
388+
val maxLspBalance = (USat(maxChannelSizeAvailableToIncrease) - USat(clientBalanceSat)).value
388389
val defaultLspBalance = getDefaultLspBalance(clientBalanceSat, maxLspBalance)
389390
val maxClientBalance = getMaxClientBalance(maxChannelSizeAvailableToIncrease)
390391

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
package to.bitkit.models
2+
3+
import org.junit.Test
4+
import kotlin.test.assertEquals
5+
6+
class USatTest {
7+
8+
@Test
9+
fun `minus returns difference when a greater than b`() {
10+
val result = USat(10uL) - USat(5uL)
11+
assertEquals(5uL, result.value)
12+
}
13+
14+
@Test
15+
fun `minus returns zero when a equals b`() {
16+
val result = USat(5uL) - USat(5uL)
17+
assertEquals(0uL, result.value)
18+
}
19+
20+
@Test
21+
fun `minus returns zero when would overflow`() {
22+
val result = USat(5uL) - USat(10uL)
23+
assertEquals(0uL, result.value)
24+
}
25+
26+
@Test
27+
fun `minus handles max ULong values`() {
28+
val result = USat(0uL) - USat(ULong.MAX_VALUE)
29+
assertEquals(0uL, result.value)
30+
}
31+
32+
@Test
33+
fun `chained minus operations work correctly`() {
34+
val result = USat(100uL) - USat(30uL) - USat(20uL)
35+
assertEquals(50uL, result.value)
36+
}
37+
38+
@Test
39+
fun `chained minus returns zero when intermediate would overflow`() {
40+
val result = USat(10uL) - USat(20uL) - USat(5uL)
41+
assertEquals(0uL, result.value)
42+
}
43+
}

0 commit comments

Comments
 (0)