diff --git a/big/int.go b/big/int.go index 4b9cdb80..67302edb 100644 --- a/big/int.go +++ b/big/int.go @@ -12,6 +12,8 @@ import ( // BigIntMaxSerializedLen is the max length of a byte slice representing a CBOR serialized big. const BigIntMaxSerializedLen = 128 +var zeroInternal = big.NewInt(0) + type Int struct { *big.Int } @@ -71,19 +73,19 @@ func Product(ints ...Int) Int { } func Mul(a, b Int) Int { - return Int{big.NewInt(0).Mul(a.Int, b.Int)} + return Int{big.NewInt(0).Mul(a.i(), b.i())} } func Div(a, b Int) Int { - return Int{big.NewInt(0).Div(a.Int, b.Int)} + return Int{big.NewInt(0).Div(a.i(), b.i())} } func Mod(a, b Int) Int { - return Int{big.NewInt(0).Mod(a.Int, b.Int)} + return Int{big.NewInt(0).Mod(a.i(), b.i())} } func Add(a, b Int) Int { - return Int{big.NewInt(0).Add(a.Int, b.Int)} + return Int{big.NewInt(0).Add(a.i(), b.i())} } func Sum(ints ...Int) Int { @@ -103,32 +105,32 @@ func Subtract(num1 Int, ints ...Int) Int { } func Sub(a, b Int) Int { - return Int{big.NewInt(0).Sub(a.Int, b.Int)} + return Int{big.NewInt(0).Sub(a.i(), b.i())} } // Returns a**e unless e <= 0 (in which case returns 1). func Exp(a Int, e Int) Int { - return Int{big.NewInt(0).Exp(a.Int, e.Int, nil)} + return Int{big.NewInt(0).Exp(a.i(), e.i(), nil)} } // Returns x << n func Lsh(a Int, n uint) Int { - return Int{big.NewInt(0).Lsh(a.Int, n)} + return Int{big.NewInt(0).Lsh(a.i(), n)} } // Returns x >> n func Rsh(a Int, n uint) Int { - return Int{big.NewInt(0).Rsh(a.Int, n)} + return Int{big.NewInt(0).Rsh(a.i(), n)} } func BitLen(a Int) uint { - return uint(a.Int.BitLen()) + return uint(a.i().BitLen()) } func Max(x, y Int) Int { // taken from max.Max() if x.Equals(Zero()) && x.Equals(y) { - if x.Sign() != 0 { + if x.i().Sign() != 0 { return y } return x @@ -142,7 +144,7 @@ func Max(x, y Int) Int { func Min(x, y Int) Int { // taken from max.Min() if x.Equals(Zero()) && x.Equals(y) { - if x.Sign() != 0 { + if x.i().Sign() != 0 { return x } return y @@ -154,7 +156,14 @@ func Min(x, y Int) Int { } func Cmp(a, b Int) int { - return a.Int.Cmp(b.Int) + return a.i().Cmp(b.i()) +} + +func (bi Int) i() *big.Int { + if bi.Int != nil { + return bi.Int + } + return zeroInternal } // LessThan returns true if bi < o @@ -179,7 +188,7 @@ func (bi Int) GreaterThanEqual(o Int) bool { // Neg returns the negative of bi. func (bi Int) Neg() Int { - return Int{big.NewInt(0).Neg(bi.Int)} + return Int{big.NewInt(0).Neg(bi.i())} } // Abs returns the absolute value of bi. @@ -338,7 +347,7 @@ func (bi *Int) UnmarshalCBOR(br io.Reader) error { } func (bi *Int) IsZero() bool { - return bi.Int.Sign() == 0 + return bi.NilOrZero() } func (bi *Int) Nil() bool {