From df6794d01401a2b11f711ae0e47f786adadf6f60 Mon Sep 17 00:00:00 2001 From: Brendan Chou <3680392+BrendanChou@users.noreply.github.com> Date: Wed, 22 May 2024 11:44:58 -0400 Subject: [PATCH] Add additional `big_math` helper functions (#1563) --- protocol/lib/big_math.go | 34 ++++++ protocol/lib/big_math_test.go | 215 ++++++++++++++++++++++++++++++++++ 2 files changed, 249 insertions(+) diff --git a/protocol/lib/big_math.go b/protocol/lib/big_math.go index 11f84394b8..35ee84f6cf 100644 --- a/protocol/lib/big_math.go +++ b/protocol/lib/big_math.go @@ -6,6 +6,27 @@ import ( "math/big" ) +// BigU returns a new big.Int from the input unsigned integer. +func BigU[T uint | uint32 | uint64](u T) *big.Int { + return new(big.Int).SetUint64(uint64(u)) +} + +// BigI returns a new big.Int from the input signed integer. +func BigI[T int | int32 | int64](i T) *big.Int { + return big.NewInt(int64(i)) +} + +// BigMulPpm returns the result of `val * ppm / 1_000_000`, rounding in the direction indicated. +func BigMulPpm(val *big.Int, ppm *big.Int, roundUp bool) *big.Int { + result := new(big.Int).Mul(val, ppm) + oneMillion := BigIntOneMillion() + if roundUp { + return BigDivCeil(result, oneMillion) + } else { + return result.Div(result, oneMillion) + } +} + // BigMulPow10 returns the result of `val * 10^exponent`, in *big.Rat. func BigMulPow10( val *big.Int, @@ -137,6 +158,19 @@ func BigIntClamp(n *big.Int, lowerBound *big.Int, upperBound *big.Int) *big.Int return bigGenericClamp(n, lowerBound, upperBound) } +// BigDivCeil returns the ceiling of `a / b`. +func BigDivCeil(a *big.Int, b *big.Int) *big.Int { + result, remainder := new(big.Int).QuoRem(a, b, new(big.Int)) + + // If the value was rounded (i.e. there is a remainder), and the exact result would be positive, + // then add 1 to the result. + if remainder.Sign() != 0 && (a.Sign() == b.Sign()) { + result.Add(result, big.NewInt(1)) + } + + return result +} + // BigRatRound takes an input and a direction to round (true for up, false for down). // It returns the result rounded to a `*big.Int` in the specified direction. func BigRatRound(n *big.Rat, roundUp bool) *big.Int { diff --git a/protocol/lib/big_math_test.go b/protocol/lib/big_math_test.go index 101a89c832..553421d30c 100644 --- a/protocol/lib/big_math_test.go +++ b/protocol/lib/big_math_test.go @@ -12,6 +12,138 @@ import ( "github.com/stretchr/testify/require" ) +func BenchmarkBigI(b *testing.B) { + var result *big.Int + for i := 0; i < b.N; i++ { + result = lib.BigI(int64(i)) + } + require.Equal(b, result, result) +} + +func BenchmarkBigU(b *testing.B) { + var result *big.Int + for i := 0; i < b.N; i++ { + result = lib.BigU(uint32(i)) + } + require.Equal(b, result, result) +} + +func TestBigI(t *testing.T) { + require.Equal(t, big.NewInt(-123), lib.BigI(int(-123))) + require.Equal(t, big.NewInt(-123), lib.BigI(int32(-123))) + require.Equal(t, big.NewInt(-123), lib.BigI(int64(-123))) + require.Equal(t, big.NewInt(math.MaxInt64), lib.BigI(math.MaxInt64)) +} + +func TestBigU(t *testing.T) { + require.Equal(t, big.NewInt(123), lib.BigU(uint(123))) + require.Equal(t, big.NewInt(123), lib.BigU(uint32(123))) + require.Equal(t, big.NewInt(123), lib.BigU(uint64(123))) + require.Equal(t, new(big.Int).SetUint64(math.MaxUint64), lib.BigU(uint64(math.MaxUint64))) +} + +func BenchmarkBigMulPpm_RoundDown(b *testing.B) { + val := big.NewInt(543_211) + ppm := big.NewInt(876_543) + var result *big.Int + for i := 0; i < b.N; i++ { + result = lib.BigMulPpm(val, ppm, false) + } + require.Equal(b, big.NewInt(476147), result) +} + +func BenchmarkBigMulPpm_RoundUp(b *testing.B) { + val := big.NewInt(543_211) + ppm := big.NewInt(876_543) + var result *big.Int + for i := 0; i < b.N; i++ { + result = lib.BigMulPpm(val, ppm, true) + } + require.Equal(b, big.NewInt(476148), result) +} + +func TestBigMulPpm(t *testing.T) { + tests := map[string]struct { + val *big.Int + ppm *big.Int + roundUp bool + expectedResult *big.Int + }{ + "Positive round down": { + val: big.NewInt(543_211), + ppm: big.NewInt(876_543), + roundUp: false, + expectedResult: big.NewInt(476147), + }, + "Negative round down": { + val: big.NewInt(-543_211), + ppm: big.NewInt(876_543), + roundUp: false, + expectedResult: big.NewInt(-476148), + }, + "Positive round up": { + val: big.NewInt(543_211), + ppm: big.NewInt(876_543), + roundUp: true, + expectedResult: big.NewInt(476148), + }, + "Negative round up": { + val: big.NewInt(-543_211), + ppm: big.NewInt(876_543), + roundUp: true, + expectedResult: big.NewInt(-476147), + }, + "Zero val": { + val: big.NewInt(0), + ppm: big.NewInt(876_543), + roundUp: true, + expectedResult: big.NewInt(0), + }, + "Zero ppm": { + val: big.NewInt(543_211), + ppm: big.NewInt(0), + roundUp: true, + expectedResult: big.NewInt(0), + }, + "Zero val and ppm": { + val: big.NewInt(0), + ppm: big.NewInt(0), + roundUp: true, + expectedResult: big.NewInt(0), + }, + "Negative val": { + val: big.NewInt(-543_211), + ppm: big.NewInt(876_543), + roundUp: true, + expectedResult: big.NewInt(-476147), + }, + "Negative ppm": { + val: big.NewInt(543_211), + ppm: big.NewInt(-876_543), + roundUp: true, + expectedResult: big.NewInt(-476147), + }, + "Negative val and ppm": { + val: big.NewInt(-543_211), + ppm: big.NewInt(-876_543), + roundUp: true, + expectedResult: big.NewInt(476148), + }, + "Greater than max int64": { + val: big_testutil.MustFirst(new(big.Int).SetString("1000000000000000000000000", 10)), + ppm: big.NewInt(10_000), + roundUp: true, + expectedResult: big_testutil.MustFirst(new(big.Int).SetString("10000000000000000000000", 10)), + }, + } + for name, tc := range tests { + t.Run(name, func(t *testing.T) { + result := lib.BigMulPpm(tc.val, tc.ppm, tc.roundUp) + require.Equal(t, tc.expectedResult, result) + }) + } +} + func TestBigPow10(t *testing.T) { tests := map[string]struct { exponent uint64 @@ -533,6 +665,89 @@ func TestBigIntClamp(t *testing.T) { } } +func BenchmarkBigDivCeil(b *testing.B) { + numerator := big.NewInt(10) + denominator := big.NewInt(3) + var result *big.Int + for i := 0; i < b.N; i++ { + result = lib.BigDivCeil(numerator, denominator) + } + require.Equal(b, big.NewInt(4), result) +} + +func TestBigDivCeil(t *testing.T) { + tests := map[string]struct { + numerator *big.Int + denominator *big.Int + expectedResult *big.Int + }{ + "Divides evenly": { + numerator: big.NewInt(10), + denominator: big.NewInt(5), + expectedResult: big.NewInt(2), + }, + "Doesn't divide evenly": { + numerator: big.NewInt(10), + denominator: big.NewInt(3), + expectedResult: big.NewInt(4), + }, + "Negative numerator": { + numerator: big.NewInt(-10), + denominator: big.NewInt(3), + expectedResult: big.NewInt(-3), + }, + "Negative numerator 2": { + numerator: big.NewInt(-1), + denominator: big.NewInt(2), + expectedResult: big.NewInt(0), + }, + "Negative denominator": { + numerator: big.NewInt(10), + denominator: big.NewInt(-3), + expectedResult: big.NewInt(-3), + }, + "Negative denominator 2": { + numerator: big.NewInt(1), + denominator: big.NewInt(-2), + expectedResult: big.NewInt(0), + }, + "Negative numerator and denominator": { + numerator: big.NewInt(-10), + denominator: big.NewInt(-3), + expectedResult: big.NewInt(4), + }, + "Negative numerator and denominator 2": { + numerator: big.NewInt(-1), + denominator: big.NewInt(-2), + expectedResult: big.NewInt(1), + }, + "Zero numerator": { + numerator: big.NewInt(0), + denominator: big.NewInt(3), + expectedResult: big.NewInt(0), + }, + "Zero denominator": { + numerator: big.NewInt(10), + denominator: big.NewInt(0), + expectedResult: nil, + }, + } + for name, tc := range tests { + t.Run(name, func(t *testing.T) { + // Panics if the expected result is nil + if tc.expectedResult == nil { + require.Panics(t, func() { + lib.BigDivCeil(tc.numerator, tc.denominator) + }) + return + } + // Otherwise test the result + result := lib.BigDivCeil(tc.numerator, tc.denominator) + require.Equal(t, tc.expectedResult, result) + }) + } +} + func TestBigRatRound(t *testing.T) { tests := map[string]struct { input *big.Rat