Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 23 additions & 14 deletions big/int.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ import (
// BigIntMaxSerializedLen is the max length of a byte slice representing a CBOR serialized big.
const BigIntMaxSerializedLen = 128

var zeroInternal = big.NewInt(0)

type Int struct {
*big.Int
}
Expand Down Expand Up @@ -71,19 +73,19 @@ func Product(ints ...Int) Int {
}

func Mul(a, b Int) Int {
return Int{big.NewInt(0).Mul(a.Int, b.Int)}
return Int{big.NewInt(0).Mul(a.i(), b.i())}
}

func Div(a, b Int) Int {
return Int{big.NewInt(0).Div(a.Int, b.Int)}
return Int{big.NewInt(0).Div(a.i(), b.i())}
}

func Mod(a, b Int) Int {
return Int{big.NewInt(0).Mod(a.Int, b.Int)}
return Int{big.NewInt(0).Mod(a.i(), b.i())}
}

func Add(a, b Int) Int {
return Int{big.NewInt(0).Add(a.Int, b.Int)}
return Int{big.NewInt(0).Add(a.i(), b.i())}
}

func Sum(ints ...Int) Int {
Expand All @@ -103,32 +105,32 @@ func Subtract(num1 Int, ints ...Int) Int {
}

func Sub(a, b Int) Int {
return Int{big.NewInt(0).Sub(a.Int, b.Int)}
return Int{big.NewInt(0).Sub(a.i(), b.i())}
}

// Returns a**e unless e <= 0 (in which case returns 1).
func Exp(a Int, e Int) Int {
return Int{big.NewInt(0).Exp(a.Int, e.Int, nil)}
return Int{big.NewInt(0).Exp(a.i(), e.i(), nil)}
}

// Returns x << n
func Lsh(a Int, n uint) Int {
return Int{big.NewInt(0).Lsh(a.Int, n)}
return Int{big.NewInt(0).Lsh(a.i(), n)}
}

// Returns x >> n
func Rsh(a Int, n uint) Int {
return Int{big.NewInt(0).Rsh(a.Int, n)}
return Int{big.NewInt(0).Rsh(a.i(), n)}
}

func BitLen(a Int) uint {
return uint(a.Int.BitLen())
return uint(a.i().BitLen())
}

func Max(x, y Int) Int {
// taken from max.Max()
if x.Equals(Zero()) && x.Equals(y) {
if x.Sign() != 0 {
if x.i().Sign() != 0 {
return y
}
return x
Expand All @@ -142,7 +144,7 @@ func Max(x, y Int) Int {
func Min(x, y Int) Int {
// taken from max.Min()
if x.Equals(Zero()) && x.Equals(y) {
if x.Sign() != 0 {
if x.i().Sign() != 0 {
return x
}
return y
Expand All @@ -154,7 +156,14 @@ func Min(x, y Int) Int {
}

func Cmp(a, b Int) int {
return a.Int.Cmp(b.Int)
return a.i().Cmp(b.i())
}

func (bi Int) i() *big.Int {
if bi.Int != nil {
return bi.Int
}
return zeroInternal
}

// LessThan returns true if bi < o
Expand All @@ -179,7 +188,7 @@ func (bi Int) GreaterThanEqual(o Int) bool {

// Neg returns the negative of bi.
func (bi Int) Neg() Int {
return Int{big.NewInt(0).Neg(bi.Int)}
return Int{big.NewInt(0).Neg(bi.i())}
}

// Abs returns the absolute value of bi.
Expand Down Expand Up @@ -338,7 +347,7 @@ func (bi *Int) UnmarshalCBOR(br io.Reader) error {
}

func (bi *Int) IsZero() bool {
return bi.Int.Sign() == 0
return bi.NilOrZero()
}

func (bi *Int) Nil() bool {
Expand Down