diff --git a/evm/src/main/java/org/hyperledger/besu/evm/UInt256.java b/evm/src/main/java/org/hyperledger/besu/evm/UInt256.java index 290ef16cffb..b93388b7a3b 100644 --- a/evm/src/main/java/org/hyperledger/besu/evm/UInt256.java +++ b/evm/src/main/java/org/hyperledger/besu/evm/UInt256.java @@ -15,78 +15,56 @@ package org.hyperledger.besu.evm; import java.math.BigInteger; -import java.nio.ByteBuffer; -import java.nio.ByteOrder; - -import com.google.common.annotations.VisibleForTesting; +import java.util.Arrays; /** * 256-bits wide unsigned integer class. * *

This class is an optimised version of BigInteger for fixed width 256-bits integers. + * + * @param u3 4th digit + * @param u2 3rd digit + * @param u1 2nd digit + * @param u0 1st digit */ -public final class UInt256 { - // region Internals +public record UInt256(long u3, long u2, long u1, long u0) { + + // region Values + // -------------------------------------------------------------------------- + // UInt256 represents a big-endian 256-bits integer. + // As opposed to Java int, operations are by default unsigned, + // and signed version are interpreted in two-complements as usual. // -------------------------------------------------------------------------- - // UInt256 is a big-endian up to 256-bits integer. - // Internally, it is represented with fixed-size int/long limbs in little-endian order. - // Length is used to optimise algorithms, skipping leading zeroes. - // Nonetheless, 256bits are always allocated and initialised to zeroes. /** Fixed size in bytes. */ public static final int BYTESIZE = 32; - /** Fixed size in bits. */ - public static final int BITSIZE = 256; - - // Fixed number of limbs or digits - private static final int N_LIMBS = 8; - // Fixed number of bits per limb. - private static final int N_BITS_PER_LIMB = 32; - // Mask for long values - private static final long MASK_L = 0xFFFFFFFFL; + /** The constant 0. */ + public static final UInt256 ZERO = new UInt256(0, 0, 0, 0); - private final int[] limbs; - private final int length; + private static final byte[] ZERO_BYTES = new byte[BYTESIZE]; - @VisibleForTesting - int[] limbs() { - return limbs; - } + /** The constant All ones */ + public static final UInt256 MAX = + new UInt256( + 0xFFFFFFFFFFFFFFFFL, 0xFFFFFFFFFFFFFFFFL, 0xFFFFFFFFFFFFFFFFL, 0xFFFFFFFFFFFFFFFFL); // -------------------------------------------------------------------------- // endregion - /** The constant 0. */ - public static final UInt256 ZERO = new UInt256(new int[] {0, 0, 0, 0, 0, 0, 0, 0}, 0); + // region (private) Internal Values + // -------------------------------------------------------------------------- - /** The constant All ones */ - public static final UInt256 ALL_ONES = - new UInt256( - new int[] { - 0xFFFFFFFF, - 0xFFFFFFFF, - 0xFFFFFFFF, - 0xFFFFFFFF, - 0xFFFFFFFF, - 0xFFFFFFFF, - 0xFFFFFFFF, - 0xFFFFFFFF - }, - N_LIMBS); + // Fixed number of limbs or digits + // private static final int N_LIMBS = 4; + // Fixed number of bits per limb. + private static final int N_BITS_PER_LIMB = 64; - // region Constructors // -------------------------------------------------------------------------- + // endregion - UInt256(final int[] limbs, final int length) { - // Unchecked length: assumes limbs have length == N_LIMBS - this.limbs = limbs; - this.length = length; - } - - UInt256(final int[] limbs) { - this(limbs, N_LIMBS); - } + // region Alternative Constructors + // -------------------------------------------------------------------------- /** * Instantiates a new UInt256 from byte array. @@ -95,48 +73,25 @@ int[] limbs() { * @return Big-endian UInt256 represented by the bytes. */ public static UInt256 fromBytesBE(final byte[] bytes) { - int byteLen = bytes.length; - if (byteLen == 0) return ZERO; - - int[] limbs = new int[N_LIMBS]; - - // Fast path for exactly 32 bytes - if (byteLen == 32) { - limbs[7] = getIntBE(bytes, 0); - limbs[6] = getIntBE(bytes, 4); - limbs[5] = getIntBE(bytes, 8); - limbs[4] = getIntBE(bytes, 12); - limbs[3] = getIntBE(bytes, 16); - limbs[2] = getIntBE(bytes, 20); - limbs[1] = getIntBE(bytes, 24); - limbs[0] = getIntBE(bytes, 28); - return new UInt256(limbs, N_LIMBS); - } - - // General path for variable length - int limbIndex = 0; - int byteIndex = byteLen - 1; - - while (byteIndex >= 0 && limbIndex < N_LIMBS) { - int limb = 0; - int shift = 0; - - for (int j = 0; j < 4 && byteIndex >= 0; j++, byteIndex--, shift += 8) { - limb |= (bytes[byteIndex] & 0xFF) << shift; - } - - limbs[limbIndex++] = limb; + if (bytes.length == 0) return ZERO; + long u3 = 0; + long u2 = 0; + long u1 = 0; + long u0 = 0; + int b = bytes.length - 1; // Index in bytes array + for (int shift = 0; shift < 64 && b >= 0; b--, shift += 8) { + u0 |= ((bytes[b] & 0xFFL) << shift); } - - return new UInt256(limbs, limbIndex); - } - - // Helper method to read 4 bytes as big-endian int - private static int getIntBE(final byte[] bytes, final int offset) { - return ((bytes[offset] & 0xFF) << 24) - | ((bytes[offset + 1] & 0xFF) << 16) - | ((bytes[offset + 2] & 0xFF) << 8) - | (bytes[offset + 3] & 0xFF); + for (int shift = 0; shift < 64 && b >= 0; b--, shift += 8) { + u1 |= ((bytes[b] & 0xFFL) << shift); + } + for (int shift = 0; shift < 64 && b >= 0; b--, shift += 8) { + u2 |= ((bytes[b] & 0xFFL) << shift); + } + for (int shift = 0; shift < 64 && b >= 0; b--, shift += 8) { + u3 |= ((bytes[b] & 0xFFL) << shift); + } + return new UInt256(u3, u2, u1, u0); } /** @@ -146,10 +101,7 @@ private static int getIntBE(final byte[] bytes, final int offset) { * @return The UInt256 equivalent of value. */ public static UInt256 fromInt(final int value) { - if (value == 0) return ZERO; - int[] limbs = new int[N_LIMBS]; - limbs[0] = value; - return new UInt256(limbs, 1); + return new UInt256(0, 0, 0, value & 0xFFFFFFFFL); } /** @@ -159,27 +111,25 @@ public static UInt256 fromInt(final int value) { * @return The UInt256 equivalent of value. */ public static UInt256 fromLong(final long value) { - if (value == 0) return ZERO; - int[] limbs = new int[N_LIMBS]; - limbs[0] = (int) value; - limbs[1] = (int) (value >>> 32); - return new UInt256(limbs, 2); + return new UInt256(0, 0, 0, value); } /** - * Instantiates a new UInt256 from an int array. + * Instantiates a new UInt256 from an array. * - *

The array is interpreted in little-endian order. It is either padded with 0s or truncated if - * necessary. + *

Read digits from an array starting from the end. The array must have at least N_LIMBS + * elements. * - * @param arr int array of limbs. - * @return The UInt256 equivalent of value. + * @param limbs The array holding the digits. + * @return The UInt256 from the array */ - public static UInt256 fromArray(final int[] arr) { - int[] limbs = new int[N_LIMBS]; - int len = Math.min(N_LIMBS, arr.length); - System.arraycopy(arr, 0, limbs, 0, len); - return new UInt256(limbs, len); + public static UInt256 fromArray(final long[] limbs) { + int i = limbs.length; + long z0 = limbs[--i]; + long z1 = limbs[--i]; + long z2 = limbs[--i]; + long z3 = limbs[--i]; + return new UInt256(z3, z2, z1, z0); } // -------------------------------------------------------------------------- @@ -193,7 +143,7 @@ public static UInt256 fromArray(final int[] arr) { * @return Value truncated to an int, possibly lossy. */ public int intValue() { - return limbs[0]; + return (int) u0; } /** @@ -202,7 +152,7 @@ public int intValue() { * @return Value truncated to a long, possibly lossy. */ public long longValue() { - return (limbs[0] & MASK_L) | ((limbs[1] & MASK_L) << 32); + return u0; } /** @@ -211,11 +161,24 @@ public long longValue() { * @return Big-endian ordered bytes for this UInt256 value. */ public byte[] toBytesBE() { - ByteBuffer buf = ByteBuffer.allocate(BYTESIZE).order(ByteOrder.BIG_ENDIAN); - for (int i = N_LIMBS - 1; i >= 0; i--) { - buf.putInt(limbs[i]); - } - return buf.array(); + byte[] result = new byte[BYTESIZE]; + longIntoBytes(result, 0, u3); + longIntoBytes(result, 8, u2); + longIntoBytes(result, 16, u1); + longIntoBytes(result, 24, u0); + return result; + } + + // Helper method to write 8 bytes from big-endian int + private static void longIntoBytes(final byte[] bytes, final int offset, final long value) { + bytes[offset] = (byte) (value >>> 56); + bytes[offset + 1] = (byte) (value >>> 48); + bytes[offset + 2] = (byte) (value >>> 40); + bytes[offset + 3] = (byte) (value >>> 32); + bytes[offset + 4] = (byte) (value >>> 24); + bytes[offset + 5] = (byte) (value >>> 16); + bytes[offset + 6] = (byte) (value >>> 8); + bytes[offset + 7] = (byte) value; } /** @@ -227,8 +190,14 @@ public BigInteger toBigInteger() { return new BigInteger(1, toBytesBE()); } - @Override - public String toString() { + /** + * Convert to hexstring. + * + *

Convert this integer into big-endian hexstring representation. + * + * @return The hexstring representing the integer. + */ + public String toHexString() { StringBuilder sb = new StringBuilder("0x"); for (byte b : toBytesBE()) { sb.append(String.format("%02x", b)); @@ -236,6 +205,26 @@ public String toString() { return sb.toString(); } + private UInt320 UInt320Value() { + return new UInt320(0, u3, u2, u1, u0); + } + + private Modulus64 asModulus64() { + return new Modulus64(u0); + } + + private Modulus128 asModulus128() { + return new Modulus128(u1, u0); + } + + private Modulus192 asModulus192() { + return new Modulus192(u2, u1, u0); + } + + private Modulus256 asModulus256() { + return new Modulus256(u3, u2, u1, u0); + } + // -------------------------------------------------------------------------- // endregion @@ -248,8 +237,61 @@ public String toString() { * @return true if this UInt256 value is 0. */ public boolean isZero() { - return (limbs[0] | limbs[1] | limbs[2] | limbs[3] | limbs[4] | limbs[5] | limbs[6] | limbs[7]) - == 0; + return (u0 | u1 | u2 | u3) == 0; + } + + /** + * Is the value 1 ? + * + * @return true if this UInt256 value is 1. + */ + public boolean isOne() { + return ((u0 ^ 1L) | u1 | u2 | u3) == 0; + } + + /** + * Is the value 0 or 1 ? + * + * @return true if this UInt256 value is 1. + */ + public boolean isZeroOrOne() { + return ((u0 & -2L) | u1 | u2 | u3) == 0; + } + + /** + * Is the two complements signed representation of this integer negative. + * + * @return True if the two complements representation of this integer is negative. + */ + public boolean isNegative() { + return u3 < 0; + } + + /** + * Does the value fit a long. + * + * @return true if it has at most 1 effective digit. + */ + public boolean isUInt64() { + return (u1 | u2 | u3) == 0; + } + + /** + * Does the value fit 2 longs. + * + * @return true if it has at most 2 effective digits. + */ + public boolean isUInt128() { + return (u2 | u3) == 0; + } + + /** + * Does the value fit 3 longs. + * + * @return true if it has at most 3 effective digits. + */ + public boolean isUInt192() { + return u3 == 0; } /** @@ -260,39 +302,89 @@ public boolean isZero() { * @return 0 if a == b, negative if a < b and positive if a > b. */ public static int compare(final UInt256 a, final UInt256 b) { - int comp; - for (int i = N_LIMBS - 1; i >= 0; i--) { - comp = Integer.compareUnsigned(a.limbs[i], b.limbs[i]); - if (comp != 0) return comp; - } - return 0; + if (a.u3 != b.u3) return Long.compareUnsigned(a.u3, b.u3); + if (a.u2 != b.u2) return Long.compareUnsigned(a.u2, b.u2); + if (a.u1 != b.u1) return Long.compareUnsigned(a.u1, b.u1); + return Long.compareUnsigned(a.u0, b.u0); } - @Override - public boolean equals(final Object obj) { - if (this == obj) return true; - if (!(obj instanceof UInt256)) return false; - UInt256 other = (UInt256) obj; + // -------------------------------------------------------------------------- + // endregion + + // region Bitwise Operations + // -------------------------------------------------------------------------- + + /** + * Bitwise AND operation + * + * @param other The UInt256 to AND with this. + * @return The UInt256 result from the bitwise AND operation + */ + public UInt256 and(final UInt256 other) { + return new UInt256(u3 & other.u3, u2 & other.u2, u1 & other.u1, u0 & other.u0); + } - int xor = - (this.limbs[0] ^ other.limbs[0]) - | (this.limbs[1] ^ other.limbs[1]) - | (this.limbs[2] ^ other.limbs[2]) - | (this.limbs[3] ^ other.limbs[3]) - | (this.limbs[4] ^ other.limbs[4]) - | (this.limbs[5] ^ other.limbs[5]) - | (this.limbs[6] ^ other.limbs[6]) - | (this.limbs[7] ^ other.limbs[7]); - return xor == 0; + /** + * Bitwise XOR operation + * + * @param other The UInt256 to XOR with this. + * @return The UInt256 result from the bitwise XOR operation + */ + public UInt256 xor(final UInt256 other) { + return new UInt256(u3 ^ other.u3, u2 ^ other.u2, u1 ^ other.u1, u0 ^ other.u0); } - @Override - public int hashCode() { - int h = 1; - for (int i = 0; i < N_LIMBS; i++) { - h = 31 * h + limbs[i]; - } - return h; + /** + * Bitwise OR operation + * + * @param other The UInt256 to OR with this. + * @return The UInt256 result from the bitwise OR operation + */ + public UInt256 or(final UInt256 other) { + return new UInt256(u3 | other.u3, u2 | other.u2, u1 | other.u1, u0 | other.u0); + } + + /** + * Bitwise NOT operation + * + * @return The UInt256 result from the bitwise NOT operation + */ + public UInt256 not() { + return new UInt256(~u3, ~u2, ~u1, ~u0); + } + + /** + * Bitwise shift left. + * + * @param shift The number of bits to shift left (at most 64). + * @return The shifted UInt256. + */ + public UInt256 shiftLeft(final int shift) { + // Unchecked: 0 <= shift < 64 + if (shift == 0) return this; + int invShift = (N_BITS_PER_LIMB - shift); + long z0 = (u0 << shift); + long z1 = (u1 << shift) | u0 >>> invShift; + long z2 = (u2 << shift) | u1 >>> invShift; + long z3 = (u3 << shift) | u2 >>> invShift; + return new UInt256(z3, z2, z1, z0); + } + + /** + * Bitwise shift right. + * + * @param shift The number of bits to shift right (at most 64). + * @return The shifted UInt256. + */ + public UInt256 shiftRight(final int shift) { + // Unchecked: 0 <= shift < 64 + if (shift == 0) return this; + int invShift = (N_BITS_PER_LIMB - shift); + long z3 = (u3 >>> shift); + long z2 = (u2 >>> shift) | u3 << invShift; + long z1 = (u1 >>> shift) | u2 << invShift; + long z0 = (u0 >>> shift) | u1 << invShift; + return new UInt256(z3, z2, z1, z0); } // -------------------------------------------------------------------------- @@ -301,15 +393,78 @@ public int hashCode() { // region Arithmetic Operations // -------------------------------------------------------------------------- + /** + * Compute the two-complement negative representation of this integer. + * + * @return The negative of this integer. + */ + public UInt256 neg() { + long z0 = ~u0 + 1; + long carry = (z0 == 0) ? 1 : 0; + long z1 = ~u1 + carry; + carry = (z1 == 0 && carry == 1) ? 1 : 0; + long z2 = ~u2 + carry; + carry = (z2 == 0 && carry == 1) ? 1 : 0; + long z3 = ~u3 + carry; + return new UInt256(z3, z2, z1, z0); + } + + /** + * Compute the absolute value for a two-complement negative representation of this integer. + * + * @return The absolute value of this integer. + */ + public UInt256 abs() { + return isNegative() ? neg() : this; + } + + /** + * Addition + * + *

Compute the wrapping sum of 2 256-bits integers. + * + * @param other Integer to add to this integer. + * @return The sum. + */ + public UInt256 add(final UInt256 other) { + if (isZero()) return other; + if (other.isZero()) return this; + return adc(other).UInt256Value(); + } + + /** + * Multiplication + * + *

Compute the wrapping product of 2 256-bits integers. + * + * @param other Integer to multiply with this integer. + * @return The product. + */ + public UInt256 mul(final UInt256 other) { + if (isZero() || other.isZero()) return ZERO; + if (other.isOne()) return this; + if (this.isOne()) return other; + if (u3 != 0) return mul256(other).UInt256Value(); + if (u2 != 0) return mul192(other).UInt256Value(); + if (u1 != 0) return mul128(other); + return mul64(other); + } + /** * Unsigned modulo reduction. * - * @param modulus The modulus of the reduction - * @return The remainder modulo {@code modulus}. + *

Compute dividend (mod modulus) as unsigned big-endian integer. + * + * @param modulus The modulus of the reduction. + * @return The remainder. */ public UInt256 mod(final UInt256 modulus) { - if (this.isZero() || modulus.isZero()) return ZERO; - return new UInt256(knuthRemainder(this.limbs, modulus.limbs), modulus.length); + if (isZero()) return ZERO; + if (modulus.u3 != 0) return modulus.asModulus256().reduce(this); + if (modulus.u2 != 0) return modulus.asModulus192().reduce(this); + if (modulus.u1 != 0) return modulus.asModulus128().reduce(this); + if ((modulus.u0 == 0) || (modulus.u0 == 1)) return ZERO; + return modulus.asModulus64().reduce(this); } /** @@ -322,17 +477,12 @@ public UInt256 mod(final UInt256 modulus) { * @return The remainder modulo {@code modulus}. */ public UInt256 signedMod(final UInt256 modulus) { - if (this.isZero() || modulus.isZero()) return ZERO; - int[] x = new int[N_LIMBS]; - int[] y = new int[N_LIMBS]; - absInto(x, this.limbs, N_LIMBS); - absInto(y, modulus.limbs, N_LIMBS); - int[] r = knuthRemainder(x, y); - if (isNeg(this.limbs, N_LIMBS)) { - negate(r, N_LIMBS); - return new UInt256(r); - } - return new UInt256(r, modulus.length); + if (isZero() || modulus.isZeroOrOne() || modulus.equals(MAX)) return ZERO; + UInt256 a = abs(); + UInt256 m = modulus.abs(); + UInt256 r = a.mod(m); + if (isNegative()) r = r.neg(); + return r; } /** @@ -343,10 +493,13 @@ public UInt256 signedMod(final UInt256 modulus) { * @return This integer this + other (mod modulus). */ public UInt256 addMod(final UInt256 other, final UInt256 modulus) { - if (modulus.isZero()) return ZERO; - int[] sum = addWithCarry(this.limbs, this.length, other.limbs, other.length); - int[] rem = knuthRemainder(sum, modulus.limbs); - return new UInt256(rem, modulus.length); + if (isZero()) return other.mod(modulus); + if (other.isZero()) return this.mod(modulus); + if (modulus.isZeroOrOne()) return ZERO; + if (modulus.u3 != 0) return modulus.asModulus256().sum(this, other); + if (modulus.u2 != 0) return modulus.asModulus192().sum(this, other); + if (modulus.u1 != 0) return modulus.asModulus128().sum(this, other); + return modulus.asModulus64().sum(this, other); } /** @@ -357,365 +510,1216 @@ public UInt256 addMod(final UInt256 other, final UInt256 modulus) { * @return This integer this + other (mod modulus). */ public UInt256 mulMod(final UInt256 other, final UInt256 modulus) { - if (this.isZero() || other.isZero() || modulus.isZero()) return ZERO; - int[] result = addMul(this.limbs, this.length, other.limbs, other.length); - result = knuthRemainder(result, modulus.limbs); - return new UInt256(result, modulus.length); + if (this.isZero() || other.isZero() || modulus.isZeroOrOne()) return ZERO; + if (this.isOne()) return other.mod(modulus); + if (other.isOne()) return this.mod(modulus); + if (modulus.u3 != 0) return modulus.asModulus256().mul(this, other); + if (modulus.u2 != 0) return modulus.asModulus192().mul(this, other); + if (modulus.u1 != 0) return modulus.asModulus128().mul(this, other); + return modulus.asModulus64().mul(this, other); } + // -------------------------------------------------------------------------- + // endregion + + // region Bytes Arithmetic Operations + // + // Addition is faster when done straight in byte[] + // -------------------------------------------------------------------------- + /** - * Bitwise AND operation + * Addition in bytes: x + y. * - * @param other The UInt256 to AND with this. - * @return The UInt256 result from the bitwise AND operation + *

Compute the wrapping sum + * + * @param x The left value to add. + * @param y The right value to add. + * @return The sum x + y. */ - public UInt256 and(final UInt256 other) { - int[] result = new int[N_LIMBS]; - result[0] = this.limbs[0] & other.limbs[0]; - result[1] = this.limbs[1] & other.limbs[1]; - result[2] = this.limbs[2] & other.limbs[2]; - result[3] = this.limbs[3] & other.limbs[3]; - result[4] = this.limbs[4] & other.limbs[4]; - result[5] = this.limbs[5] & other.limbs[5]; - result[6] = this.limbs[6] & other.limbs[6]; - result[7] = this.limbs[7] & other.limbs[7]; - return new UInt256(result, N_LIMBS); + public static byte[] add(final byte[] x, final byte[] y) { + if (isZero(x)) return y; + if (isZero(y)) return x; + return adc(x, y); } /** - * Bitwise XOR operation + * Substraction in bytes: x - y. * - * @param other The UInt256 to XOR with this. - * @return The UInt256 result from the bitwise XOR operation + *

Compute the wrapping difference + * + * @param x The left value. + * @param y The right value to substract. + * @return The wrapping difference x - y. */ - public UInt256 xor(final UInt256 other) { - int[] result = new int[N_LIMBS]; - result[0] = this.limbs[0] ^ other.limbs[0]; - result[1] = this.limbs[1] ^ other.limbs[1]; - result[2] = this.limbs[2] ^ other.limbs[2]; - result[3] = this.limbs[3] ^ other.limbs[3]; - result[4] = this.limbs[4] ^ other.limbs[4]; - result[5] = this.limbs[5] ^ other.limbs[5]; - result[6] = this.limbs[6] ^ other.limbs[6]; - result[7] = this.limbs[7] ^ other.limbs[7]; - return new UInt256(result, N_LIMBS); + public static byte[] sub(final byte[] x, final byte[] y) { + if (isZero(y)) return x; + if (isZero(x)) return neg(y); + return sbb(x, y); } - /** - * Bitwise OR operation - * - * @param other The UInt256 to OR with this. - * @return The UInt256 result from the bitwise OR operation - */ - public UInt256 or(final UInt256 other) { - int[] result = new int[N_LIMBS]; - result[0] = this.limbs[0] | other.limbs[0]; - result[1] = this.limbs[1] | other.limbs[1]; - result[2] = this.limbs[2] | other.limbs[2]; - result[3] = this.limbs[3] | other.limbs[3]; - result[4] = this.limbs[4] | other.limbs[4]; - result[5] = this.limbs[5] | other.limbs[5]; - result[6] = this.limbs[6] | other.limbs[6]; - result[7] = this.limbs[7] | other.limbs[7]; - return new UInt256(result, N_LIMBS); + private static boolean isZero(final byte[] arr) { + int index = Arrays.mismatch(arr, ZERO_BYTES); + return (index == -1 || index >= arr.length); } - /** - * Bitwise NOT operation - * - * @return The UInt256 result from the bitwise NOT operation - */ - public UInt256 not() { - int[] result = new int[N_LIMBS]; - result[0] = ~this.limbs[0]; - result[1] = ~this.limbs[1]; - result[2] = ~this.limbs[2]; - result[3] = ~this.limbs[3]; - result[4] = ~this.limbs[4]; - result[5] = ~this.limbs[5]; - result[6] = ~this.limbs[6]; - result[7] = ~this.limbs[7]; - return new UInt256(result, N_LIMBS); + private static byte[] padLeft(final byte[] a) { + if (a.length == BYTESIZE) return a; + byte[] res = new byte[BYTESIZE]; + System.arraycopy(a, 0, res, BYTESIZE - a.length, a.length); + return res; + } + + private static byte[] adc(final byte[] a, final byte[] b) { + int res; + int carry = 0; + byte[] x = padLeft(a); + byte[] y = padLeft(b); + byte[] sum = new byte[BYTESIZE]; + for (int i = 31; i >= 0; i--) { + res = (x[i] & 0xFF) + (y[i] & 0xFF) + carry; + sum[i] = (byte) res; + carry = (res >> 8); + } + return sum; + } + + private static byte[] neg(final byte[] a) { + int res; + int carry = 1; + byte[] x = padLeft(a); + byte[] out = new byte[BYTESIZE]; + for (int i = 31; i >= 0; i--) { + res = (~x[i] & 0xFF) + carry; + out[i] = (byte) res; + carry = (res >> 8); + } + return out; + } + + private static byte[] sbb(final byte[] a, final byte[] b) { + int res; + int borrow = 0; + byte[] x = padLeft(a); + byte[] y = padLeft(b); + byte[] diff = new byte[BYTESIZE]; + for (int i = 31; i >= 0; i--) { + res = (x[i] & 0xFF) - (y[i] & 0xFF) - borrow; + diff[i] = (byte) res; + borrow = (res < 0) ? 1 : 0; + } + return diff; } // -------------------------------------------------------------------------- // endregion - /** - * Simple addition - * - * @param other The UInt256 to add to this. - * @return The UInt256 result from the addition - */ - public UInt256 add(final UInt256 other) { - return new UInt256( - addWithCarry(this.limbs, this.limbs.length, other.limbs, other.limbs.length)); + // region private basic operations + // + // adc (add and carry): carry, a <- a + b + // mac (multiply accumulate): a <- a + b * c + carryIn + // -------------------------------------------------------------------------- + + private UInt320 shiftLeftWide(final int shift) { + if (shift == 0) return UInt320Value(); + int invShift = (N_BITS_PER_LIMB - shift); + long z0 = (u0 << shift); + long z1 = (u1 << shift) | u0 >>> invShift; + long z2 = (u2 << shift) | u1 >>> invShift; + long z3 = (u3 << shift) | u2 >>> invShift; + long z4 = u3 >>> invShift; + return new UInt320(z4, z3, z2, z1, z0); + } + + private UInt256 shiftDigitsRight() { + return new UInt256(0, u3, u2, u1); + } + + private UInt257 adc(final UInt256 other) { + if (isZero()) return new UInt257(false, other); + if (other.isZero()) return new UInt257(false, this); + long z0 = u0 + other.u0; + long carry = Long.compareUnsigned(z0, u0) < 0 ? 1 : 0; + + long z1 = u1 + other.u1 + carry; + long overflow1 = Long.compareUnsigned(z1, u1) < 0 ? 1 : 0; + long overflow2 = Long.compareUnsigned(z1, u1) == 0 ? 1 : 0; + carry = overflow1 | (overflow2 & carry); + + long z2 = u2 + other.u2 + carry; + overflow1 = Long.compareUnsigned(z2, u2) < 0 ? 1 : 0; + overflow2 = Long.compareUnsigned(z2, u2) == 0 ? 1 : 0; + carry = overflow1 | (overflow2 & carry); + + long z3 = u3 + other.u3 + carry; + overflow1 = Long.compareUnsigned(z3, u3) < 0 ? 1 : 0; + overflow2 = Long.compareUnsigned(z3, u3) == 0 ? 1 : 0; + carry = overflow1 | (overflow2 & carry); + + return new UInt257(carry != 0, new UInt256(z3, z2, z1, z0)); } - // region Support (private) Algorithms + private UInt256 mac128(final long multiplier, final UInt256 carryIn) { + // Multiply accumulate for 128bits integer (this): + // = * multiplier + carryIn + if (multiplier == 0) return carryIn.shiftDigitsRight(); + + long p0 = u0 * multiplier; + long p1 = Math.unsignedMultiplyHigh(u0, multiplier); + long z0 = p0 + carryIn.u1; + long carry = p1 + ((Long.compareUnsigned(z0, p0) < 0) ? 1 : 0); + + p0 = u1 * multiplier; + p1 = Math.unsignedMultiplyHigh(u1, multiplier); + long res = carry + carryIn.u2; + long z1 = res + p0; + carry = (Long.compareUnsigned(res, carry) < 0) ? 1 : 0; + carry += p1 + ((Long.compareUnsigned(z1, res) < 0) ? 1 : 0); + + long z2 = carry; + + return new UInt256(0, z2, z1, z0); + } + + private UInt256 mac192(final long multiplier, final UInt256 carryIn) { + // Multiply accumulate for 192bits integer (this): + // Returns this * multiplier + (carryIn >>> 64) + // = * multiplier + carryIn + if (multiplier == 0) return carryIn.shiftDigitsRight(); + + long p0 = u0 * multiplier; + long p1 = Math.unsignedMultiplyHigh(u0, multiplier); + long z0 = p0 + carryIn.u1; + long carry = p1 + ((Long.compareUnsigned(z0, p0) < 0) ? 1 : 0); + + p0 = u1 * multiplier; + p1 = Math.unsignedMultiplyHigh(u1, multiplier); + long res = carry + carryIn.u2; + long z1 = res + p0; + carry = (Long.compareUnsigned(res, carry) < 0) ? 1 : 0; + carry += p1 + ((Long.compareUnsigned(z1, res) < 0) ? 1 : 0); + + p0 = u2 * multiplier; + p1 = Math.unsignedMultiplyHigh(u2, multiplier); + res = carry + carryIn.u3; + long z2 = res + p0; + carry = (Long.compareUnsigned(res, carry) < 0) ? 1 : 0; + carry += p1 + ((Long.compareUnsigned(z2, res) < 0) ? 1 : 0); + + long z3 = carry; + return new UInt256(z3, z2, z1, z0); + } + + private UInt320 mac256(final long multiplier, final UInt320 carryIn) { + // Multiply accumulate for 192bits integer (this): + // Returns this * multiplier + carryIn + // = * multiplier + carryIn + if (multiplier == 0) return carryIn.shiftDigitsRight(); + + long p0 = u0 * multiplier; + long p1 = Math.unsignedMultiplyHigh(u0, multiplier); + long z0 = p0 + carryIn.u1; + long carry = p1 + ((Long.compareUnsigned(z0, p0) < 0) ? 1 : 0); + + p0 = u1 * multiplier; + p1 = Math.unsignedMultiplyHigh(u1, multiplier); + long res = carry + carryIn.u2; + long z1 = res + p0; + carry = (Long.compareUnsigned(res, carry) < 0) ? 1 : 0; + carry += p1 + ((Long.compareUnsigned(z1, res) < 0) ? 1 : 0); + + p0 = u2 * multiplier; + p1 = Math.unsignedMultiplyHigh(u2, multiplier); + res = carry + carryIn.u3; + long z2 = res + p0; + carry = (Long.compareUnsigned(res, carry) < 0) ? 1 : 0; + carry += p1 + ((Long.compareUnsigned(z2, res) < 0) ? 1 : 0); + + p0 = u3 * multiplier; + p1 = Math.unsignedMultiplyHigh(u3, multiplier); + res = carry + carryIn.u4; + long z3 = res + p0; + carry = (Long.compareUnsigned(res, carry) < 0) ? 1 : 0; + carry += p1 + ((Long.compareUnsigned(z3, res) < 0) ? 1 : 0); + + long z4 = carry; + + return new UInt320(z4, z3, z2, z1, z0); + } + + // -------------------------------------------------------------------------- + // endregion + + // region private multiplication // -------------------------------------------------------------------------- - private static int nSetLimbs(final int[] x) { - int offset = x.length - 1; - while ((offset >= 0) && (x[offset] == 0)) offset--; - return offset + 1; + + private UInt256 mul64(final UInt256 v) { + // = * multiplier + long p0 = u0 * v.u0; + long p1 = Math.unsignedMultiplyHigh(u0, v.u0); + long z0 = p0; + long carry = p1; + + p0 = u0 * v.u1; + p1 = Math.unsignedMultiplyHigh(u0, v.u1); + long z1 = p0 + carry; + carry = p1 + ((Long.compareUnsigned(z1, p0) < 0) ? 1 : 0); + + p0 = u0 * v.u2; + p1 = Math.unsignedMultiplyHigh(u0, v.u2); + long z2 = p0 + carry; + carry = p1 + ((Long.compareUnsigned(z2, p0) < 0) ? 1 : 0); + + long z3 = u0 * v.u3 + carry; + + return new UInt256(z3, z2, z1, z0); } - private static int compareLimbs(final int[] a, final int aLen, final int[] b, final int bLen) { - int cmp; - if (aLen > bLen) { - for (int i = aLen - 1; i >= bLen; i--) { - cmp = Integer.compareUnsigned(a[i], 0); - if (cmp != 0) return cmp; - } - } else if (aLen < bLen) { - for (int i = bLen - 1; i >= aLen; i--) { - cmp = Integer.compareUnsigned(0, b[i]); - if (cmp != 0) return cmp; - } - } - for (int i = Math.min(aLen, bLen) - 1; i >= 0; i--) { - cmp = Integer.compareUnsigned(a[i], b[i]); - if (cmp != 0) return cmp; - } - return 0; + private UInt256 mul128(final UInt256 v) { + UInt256 res; + + res = mac128(v.u0, ZERO); + long z0 = res.u0; + res = mac128(v.u1, res); + long z1 = res.u0; + res = mac128(v.u2, res); + long z2 = res.u0; + res = mac128(v.u3, res); + + return new UInt256(res.u0, z2, z1, z0); } - private static boolean isNeg(final int[] x, final int xLen) { - return x[xLen - 1] < 0; + private UInt448 mul192(final UInt256 v) { + UInt256 res; + res = mac192(v.u0, ZERO); + long z0 = res.u0; + res = mac192(v.u1, res); + long z1 = res.u0; + res = mac192(v.u2, res); + long z2 = res.u0; + res = mac192(v.u3, res); + + return new UInt448(res.u3, res.u2, res.u1, res.u0, z2, z1, z0); } - private static void negate(final int[] x, final int xLen) { - int carry = 1; - for (int i = 0; i < xLen; i++) { - x[i] = ~x[i] + carry; - carry = (x[i] == 0 && carry == 1) ? 1 : 0; + private UInt512 mul256(final UInt256 v) { + UInt320 res; + res = mac256(v.u0, UInt320.ZERO); + long z0 = res.u0; + res = mac256(v.u1, res); + long z1 = res.u0; + res = mac256(v.u2, res); + long z2 = res.u0; + res = mac256(v.u3, res); + return new UInt512(res.u4, res.u3, res.u2, res.u1, res.u0, z2, z1, z0); + } + + // -------------------------------------------------------------------------- + // endregion + + // region private quotient estimation + // -------------------------------------------------------------------------- + + // Lookup table for $\floor{\frac{2^{19} -3 ⋅ 2^8}{d_9 - 256}}$ + private static final short[] LUT = + new short[] { + 2045, 2037, 2029, 2021, 2013, 2005, 1998, 1990, 1983, 1975, 1968, 1960, 1953, 1946, 1938, + 1931, 1924, 1917, 1910, 1903, 1896, 1889, 1883, 1876, 1869, 1863, 1856, 1849, 1843, 1836, + 1830, 1824, 1817, 1811, 1805, 1799, 1792, 1786, 1780, 1774, 1768, 1762, 1756, 1750, 1745, + 1739, 1733, 1727, 1722, 1716, 1710, 1705, 1699, 1694, 1688, 1683, 1677, 1672, 1667, 1661, + 1656, 1651, 1646, 1641, 1636, 1630, 1625, 1620, 1615, 1610, 1605, 1600, 1596, 1591, 1586, + 1581, 1576, 1572, 1567, 1562, 1558, 1553, 1548, 1544, 1539, 1535, 1530, 1526, 1521, 1517, + 1513, 1508, 1504, 1500, 1495, 1491, 1487, 1483, 1478, 1474, 1470, 1466, 1462, 1458, 1454, + 1450, 1446, 1442, 1438, 1434, 1430, 1426, 1422, 1418, 1414, 1411, 1407, 1403, 1399, 1396, + 1392, 1388, 1384, 1381, 1377, 1374, 1370, 1366, 1363, 1359, 1356, 1352, 1349, 1345, 1342, + 1338, 1335, 1332, 1328, 1325, 1322, 1318, 1315, 1312, 1308, 1305, 1302, 1299, 1295, 1292, + 1289, 1286, 1283, 1280, 1276, 1273, 1270, 1267, 1264, 1261, 1258, 1255, 1252, 1249, 1246, + 1243, 1240, 1237, 1234, 1231, 1228, 1226, 1223, 1220, 1217, 1214, 1211, 1209, 1206, 1203, + 1200, 1197, 1195, 1192, 1189, 1187, 1184, 1181, 1179, 1176, 1173, 1171, 1168, 1165, 1163, + 1160, 1158, 1155, 1153, 1150, 1148, 1145, 1143, 1140, 1138, 1135, 1133, 1130, 1128, 1125, + 1123, 1121, 1118, 1116, 1113, 1111, 1109, 1106, 1104, 1102, 1099, 1097, 1095, 1092, 1090, + 1088, 1086, 1083, 1081, 1079, 1077, 1074, 1072, 1070, 1068, 1066, 1064, 1061, 1059, 1057, + 1055, 1053, 1051, 1049, 1047, 1044, 1042, 1040, 1038, 1036, 1034, 1032, 1030, 1028, 1026, + 1024, + }; + + private static final long TWO_POW_SIXTY = 1L << 60; + + // Taken from https://gmplib.org/~tege/division-paper.pdf taken from III. algorithm 2 + private static long reciprocal(final long x) { + // Unchecked: x >= (1 << 63) + long x0 = x & 1L; + int x9 = (int) (x >>> 55); + long x40 = 1 + (x >>> 24); + long x63 = (x + 1) >>> 1; + long v0 = LUT[x9 - 256] & 0xFFFFL; + long v1 = (v0 << 11) - ((v0 * v0 * x40) >>> 40) - 1; + long v2 = (v1 << 13) + ((v1 * (TWO_POW_SIXTY - v1 * x40)) >>> 47); + long e = ((v2 >>> 1) & (-x0)) - v2 * x63; + long s = Math.unsignedMultiplyHigh(v2, e); + long v3 = (s >>> 1) + (v2 << 31); + long t0 = v3 * x; + long t1 = Math.unsignedMultiplyHigh(v3, x); + t0 += x; + t1 += Long.compareUnsigned(t0, x) < 0 ? 1 : 0; + t1 += x; + long v4 = v3 - t1; + return v4; + } + + private static DivEstimate div2by1(final long x1, final long x0, final long y, final long yInv) { + // wrapping umul z1 * yInv + long q0 = x1 * yInv; + long q1 = Math.unsignedMultiplyHigh(x1, yInv); + + // wrapping uadd + + <1, 0> + long sum = q0 + x0; + long carry = ((q0 & x0) | ((q0 | x0) & ~sum)) >>> 63; + q0 = sum; + q1 += x1 + carry + 1; + + long r = x0 - q1 * y; + + long adjust = Long.compareUnsigned(q0, r) < 0 ? 1 : 0; + q1 -= adjust; + r += adjust * y; + + adjust = Long.compareUnsigned(y, r) <= 0 ? 1 : 0; + q1 += adjust; + r -= y * adjust; + + return new DivEstimate(q1, r); + } + + private static long mod2by1(final long x1, final long x0, final long y, final long yInv) { + // wrapping umul z1 * yInv + long q0 = x1 * yInv; + long q1 = Math.unsignedMultiplyHigh(x1, yInv); + + // wrapping uadd + + <1, 0> + long sum = q0 + x0; + long carry = ((q0 & x0) | ((q0 | x0) & ~sum)) >>> 63; + q0 = sum; + q1 += x1 + carry + 1; + + long r = x0 - q1 * y; + + long adjust = Long.compareUnsigned(q0, r) < 0 ? 1 : 0; + r += y * adjust; + + adjust = Long.compareUnsigned(y, r) <= 0 ? 1 : 0; + r -= y * adjust; + + return r; + } + + // -------------------------------------------------------------------------- + // endregion + + // region Records + // -------------------------------------------------------------------------- + record UInt257(boolean carry, UInt256 u) { + boolean isUInt64() { + return !carry && u.isUInt64(); + } + + boolean isUInt256() { + return !carry; + } + + UInt256 UInt256Value() { + return u; + } + + UInt320 shiftLeftWide(final int shift) { + long u4 = (carry ? 1L : 0L); + if (shift == 0) return new UInt320(u4, u.u3, u.u2, u.u1, u.u0); + int invShift = (N_BITS_PER_LIMB - shift); + long z0 = (u.u0 << shift); + long z1 = (u.u1 << shift) | u.u0 >>> invShift; + long z2 = (u.u2 << shift) | u.u1 >>> invShift; + long z3 = (u.u3 << shift) | u.u2 >>> invShift; + long z4 = (u4 << shift) | u.u3 >>> invShift; + return new UInt320(z4, z3, z2, z1, z0); } } - private static void absInplace(final int[] x, final int xLen) { - if (isNeg(x, xLen)) negate(x, xLen); + record UInt128(long u1, long u0) {} + + record UInt192(long u2, long u1, long u0) {} + + record UInt320(long u4, long u3, long u2, long u1, long u0) { + static final UInt320 ZERO = new UInt320(0, 0, 0, 0, 0); + + UInt320 shiftDigitsRight() { + return new UInt320(0, u4, u3, u2, u1); + } } - private static void absInto(final int[] dst, final int[] src, final int srcLen) { - System.arraycopy(src, 0, dst, 0, srcLen); - absInplace(dst, dst.length); + record UInt448(long u6, long u5, long u4, long u3, long u2, long u1, long u0) { + UInt256 UInt256Value() { + return new UInt256(u3, u2, u1, u0); + } + + UInt512 shiftLeftWide(final int shift) { + if (shift == 0) return new UInt512(0, u6, u5, u4, u3, u2, u1, u0); + int invShift = (N_BITS_PER_LIMB - shift); + long z0 = (u0 << shift); + long z1 = (u1 << shift) | u0 >>> invShift; + long z2 = (u2 << shift) | u1 >>> invShift; + long z3 = (u3 << shift) | u2 >>> invShift; + long z4 = (u4 << shift) | u3 >>> invShift; + long z5 = (u5 << shift) | u4 >>> invShift; + long z6 = (u6 << shift) | u5 >>> invShift; + long z7 = u6 >>> invShift; + return new UInt512(z7, z6, z5, z4, z3, z2, z1, z0); + } } - private static int numberOfLeadingZeros(final int[] x, final int xLen) { - int leadingIndex = xLen - 1; - while ((leadingIndex >= 0) && (x[leadingIndex] == 0)) leadingIndex--; - return 32 * (xLen - leadingIndex - 1) + Integer.numberOfLeadingZeros(x[leadingIndex]); + record UInt512(long u7, long u6, long u5, long u4, long u3, long u2, long u1, long u0) { + UInt256 UInt256Value() { + return new UInt256(u3, u2, u1, u0); + } + + UInt576 shiftLeftWide(final int shift) { + if (shift == 0) return new UInt576(0, u7, u6, u5, u4, u3, u2, u1, u0); + int invShift = (N_BITS_PER_LIMB - shift); + long z0 = (u0 << shift); + long z1 = (u1 << shift) | u0 >>> invShift; + long z2 = (u2 << shift) | u1 >>> invShift; + long z3 = (u3 << shift) | u2 >>> invShift; + long z4 = (u4 << shift) | u3 >>> invShift; + long z5 = (u5 << shift) | u4 >>> invShift; + long z6 = (u6 << shift) | u5 >>> invShift; + long z7 = (u7 << shift) | u6 >>> invShift; + long z8 = u7 >>> invShift; + return new UInt576(z8, z7, z6, z5, z4, z3, z2, z1, z0); + } } - private static void shiftLeftInto( - final int[] result, final int[] x, final int xLen, final int shift) { - // Unchecked: result should be initialised with zeroes - // Unchecked: result length should be at least x.length + limbShift - int limbShift = shift / N_BITS_PER_LIMB; - int bitShift = shift % N_BITS_PER_LIMB; - if (bitShift == 0) { - System.arraycopy(x, 0, result, limbShift, xLen); - return; + record UInt576(long u8, long u7, long u6, long u5, long u4, long u3, long u2, long u1, long u0) {} + + private record DivEstimate(long q, long r) {} + + // -------------------------------------------------------------------------- + // endregion + + // region 64bits Modulus + // -------------------------------------------------------------------------- + record Modulus64(long u0) { + Modulus64 shiftLeft(final int shift) { + return (shift == 0) ? this : new Modulus64(u0 << shift); } - int j = limbShift; - int carry = 0; - for (int i = 0; i < xLen; ++i, ++j) { - result[j] = (x[i] << bitShift) | carry; - carry = x[i] >>> (32 - bitShift); + UInt256 reduce(final UInt256 that) { + if (that.isUInt64()) { + return UInt256.fromLong(Long.remainderUnsigned(that.u0, u0)); + } + int shift = Long.numberOfLeadingZeros(u0); + Modulus64 m = shiftLeft(shift); + long inv = reciprocal(m.u0); + return m.reduceNormalised(that, shift, inv); + } + + UInt256 sum(final UInt256 a, final UInt256 b) { + UInt257 sum = a.adc(b); + if (sum.isUInt64()) return UInt256.fromLong(Long.remainderUnsigned(sum.u().u0, u0)); + int shift = Long.numberOfLeadingZeros(u0); + Modulus64 m = shiftLeft(shift); + long inv = reciprocal(m.u0); + return m.reduceNormalised(sum, shift, inv); + } + + UInt256 mul(final UInt256 a, final UInt256 b) { + // multiply-reduce + if (a.isUInt64() && b.isUInt64()) { + UInt256 prod = a.mul64(b); + if (prod.isUInt64()) return UInt256.fromLong(Long.remainderUnsigned(prod.u0, u0)); + return reduce(prod); + } + // reduce-multiply-reduce + int shift = Long.numberOfLeadingZeros(u0); + Modulus64 m = shiftLeft(shift); + long inv = reciprocal(m.u0); + UInt256 x = (a.isUInt64()) ? a : m.reduceNormalised(a, shift, inv); + UInt256 y = (b.isUInt64()) ? b : m.reduceNormalised(b, shift, inv); + UInt256 prod = x.mul64(y); + return prod.isUInt64() + ? UInt256.fromLong(Long.remainderUnsigned(prod.u0, u0)) + : m.reduceNormalised(prod, shift, inv); + } + + private long reduceStep(final long v1, final long v0, final long inv) { + return mod2by1(v1, v0, u0, inv); + } + + private UInt256 reduceNormalised(final UInt256 that, final int shift, final long inv) { + UInt320 v = that.shiftLeftWide(shift); + if ((v.u4 | v.u3) == 0 + && Long.compareUnsigned(v.u3, u0) <= 0 + && Long.compareUnsigned(v.u2, u0) <= 0) { + long r; + if (v.u2 != 0 || Long.compareUnsigned(v.u1, u0) > 0) { + r = (Long.compareUnsigned(v.u2, u0) >= 0) ? v.u2 - u0 : v.u2; + r = reduceStep(r, v.u1, inv); + r = reduceStep(r, v.u0, inv); + } else { + r = (Long.compareUnsigned(v.u1, u0) >= 0) ? v.u1 - u0 : v.u1; + r = reduceStep(r, v.u0, inv); + } + return UInt256.fromLong(r >>> shift); + } + return reduceNormalisedSlowPathUInt256(v, shift, inv); + } + + private UInt256 reduceNormalisedSlowPathUInt256( + final UInt320 v, final int shift, final long inv) { + long r; + if (v.u4 != 0 || Long.compareUnsigned(v.u3, u0) > 0) { + r = (Long.compareUnsigned(v.u4, u0) >= 0) ? v.u4 - u0 : v.u4; + r = reduceStep(r, v.u3, inv); + r = reduceStep(r, v.u2, inv); + r = reduceStep(r, v.u1, inv); + r = reduceStep(r, v.u0, inv); + } else { + r = (Long.compareUnsigned(v.u3, u0) >= 0) ? v.u3 - u0 : v.u3; + r = reduceStep(r, v.u2, inv); + r = reduceStep(r, v.u1, inv); + r = reduceStep(r, v.u0, inv); + } + return UInt256.fromLong(r >>> shift); + } + + private UInt256 reduceNormalised(final UInt257 that, final int shift, final long inv) { + UInt320 v = that.shiftLeftWide(shift); + if ((v.u4 | v.u3) == 0 + && Long.compareUnsigned(v.u3, u0) <= 0 + && Long.compareUnsigned(v.u2, u0) <= 0) { + long r; + if (v.u2 != 0 || Long.compareUnsigned(v.u1, u0) > 0) { + r = reduceStep(v.u2, v.u1, inv); + r = reduceStep(r, v.u0, inv); + } else { + r = reduceStep(v.u1, v.u0, inv); + } + return UInt256.fromLong(r >>> shift); + } + return reduceNormalisedSlowPathUInt257(v, shift, inv); + } + + private UInt256 reduceNormalisedSlowPathUInt257( + final UInt320 v, final int shift, final long inv) { + long r; + if (v.u4 != 0 || Long.compareUnsigned(v.u3, u0) > 0) { + r = reduceStep(v.u4, v.u3, inv); + r = reduceStep(r, v.u2, inv); + r = reduceStep(r, v.u1, inv); + r = reduceStep(r, v.u0, inv); + } else { + r = reduceStep(v.u3, v.u2, inv); + r = reduceStep(r, v.u1, inv); + r = reduceStep(r, v.u0, inv); + } + return UInt256.fromLong(r >>> shift); } - if (carry != 0) result[j] = carry; // last carry } - private static void shiftRightInto( - final int[] result, final int[] x, final int xLen, final int shift) { - // Unchecked: result length should be at least x.length - limbShift - int limbShift = shift / 32; - int bitShift = shift % 32; - int nLimbs = xLen - limbShift; - if (nLimbs <= 0) return; + // -------------------------------------------------------------------------- + // endregion 64bits Modulus - if (bitShift == 0) { - System.arraycopy(x, limbShift, result, 0, nLimbs); - return; + // region 128bits Modulus + // -------------------------------------------------------------------------- + record Modulus128(long u1, long u0) { + Modulus128 shiftLeft(final int shift) { + if (shift == 0) return this; + int invShift = N_BITS_PER_LIMB - shift; + return new Modulus128((u1 << shift) | (u0 >>> invShift), u0 << shift); } - int carry = 0; - for (int i = nLimbs - 1 + limbShift, j = nLimbs - 1; j >= 0; i--, j--) { - int r = (x[i] >>> bitShift) | carry; - result[j] = r; - carry = x[i] << (32 - bitShift); - } - } - - private static int[] addWithCarry(final int[] x, final int xLen, final int[] y, final int yLen) { - // Step 1: Add with carry - int[] a; - int[] b; - int aLen; - int bLen; - if (xLen < yLen) { - a = y; - aLen = yLen; - b = x; - bLen = xLen; - } else { - a = x; - aLen = xLen; - b = y; - bLen = yLen; - } - int[] sum = new int[aLen + 1]; - long carry = 0; - for (int i = 0; i < bLen; i++) { - long ai = a[i] & MASK_L; - long bi = b[i] & MASK_L; - long s = ai + bi + carry; - sum[i] = (int) s; - carry = s >>> 32; - } - int icarry = (int) carry; - for (int i = bLen; i < aLen; i++) { - sum[i] = a[i] + icarry; - icarry = (a[i] != 0 && sum[i] == 0) ? 1 : 0; - } - sum[aLen] = icarry; - return sum; + int compareTo(final UInt256 v) { + if ((v.u3 | v.u2) != 0) return -1; + if (v.u1 != u1) return Long.compareUnsigned(u1, v.u1); + return Long.compareUnsigned(u0, v.u0); + } + + UInt256 reduce(final UInt256 that) { + int cmp = compareTo(that); + if (cmp == 0) return ZERO; + if (cmp > 0) return that; + int shift = Long.numberOfLeadingZeros(u1); + Modulus128 m = shiftLeft(shift); + long inv = reciprocal(m.u1); + return m.reduceNormalised(that, shift, inv); + } + + UInt256 sum(final UInt256 a, final UInt256 b) { + UInt257 sum = a.adc(b); + int cmp = sum.isUInt256() ? compareTo(sum.UInt256Value()) : -1; + if (cmp == 0) return ZERO; + if (cmp > 0) return sum.UInt256Value(); + int shift = Long.numberOfLeadingZeros(u1); + Modulus128 m = shiftLeft(shift); + long inv = reciprocal(m.u1); + return m.reduceNormalised(sum, shift, inv); + } + + UInt256 mul(final UInt256 a, final UInt256 b) { + // multiply-reduce + if (a.isUInt128() && b.isUInt128()) { + UInt256 prod = a.mul128(b); + int cmp = compareTo(prod); + if (cmp == 0) return ZERO; + if (cmp > 0) return prod; + return reduce(prod); + } + // reduce-multiply-reduce + int shift = Long.numberOfLeadingZeros(u1); + Modulus128 m = shiftLeft(shift); + long inv = reciprocal(m.u1); + UInt256 x = (a.isUInt128()) ? a : m.reduceNormalised(a, shift, inv); + UInt256 y = (b.isUInt128()) ? b : m.reduceNormalised(b, shift, inv); + UInt256 prod = x.mul128(y); + int cmp = compareTo(prod); + if (cmp == 0) return ZERO; + if (cmp > 0) return prod; + return m.reduceNormalised(prod, shift, inv); + } + + private UInt128 addBack(final long v1, final long v0) { + // Quotient estimate could be 0, +1, +2 of real quotient. + // Add back step in case estimate is off. + long z0 = v0 + u0; + long carry = (Long.compareUnsigned(z0, v0) < 0) ? 1 : 0; + + long z1 = v1 + u1 + carry; + long overflow1 = (Long.compareUnsigned(z1, v1) < 0) ? 1 : 0; + long overflow2 = (Long.compareUnsigned(z1, v1) == 0) ? 1 : 0; + carry = overflow1 | (overflow2 & carry); + + if (carry == 0) { // unlikely: add back again + // Proper quotient estimation guarantees recursion max-depth <= 2 + // Unbounded recursion only if there's a bug - fail fast is better than give wrong result + return addBack(z1, z0); + } + return new UInt128(z1, z0); + } + + private UInt128 mulSub(final long x1, final long x0, final long q) { + // Multiply-subtract: highest limb is already substracted + // = * q + long p0 = u0 * q; + long p1 = Math.unsignedMultiplyHigh(u0, q); + long z0 = x0 - p0; + long carry = p1 + ((Long.compareUnsigned(x0, z0) < 0) ? 1 : 0); + + // Propagate overflows (borrows) + long z1 = x1 - carry; + long borrow = (Long.compareUnsigned(x1, z1) < 0) ? 1 : 0; + + if (borrow != 0) return addBack(z1, z0); // less likely + return new UInt128(z1, z0); + } + + private UInt128 mulSubOverflow(final long v1, final long v0) { + // Overflow case: div2by1 quotient would be <1, 0>, but adjusts to <0, MAX> + // = -1 * u0 = + long z0 = v0 + u0; + long carry = u0 - 1 + ((Long.compareUnsigned(v0, z0) <= 0) ? 1 : 0); + + long z1 = v1 + u1 - carry; + return new UInt128(z1, z0); + } + + private UInt128 reduceStep(final long v2, final long v1, final long v0, final long inv) { + if (v2 == u1) return mulSubOverflow(v1, v0); + DivEstimate qr = div2by1(v2, v1, u1, inv); + if (qr.q != 0) return mulSub(qr.r, v0, qr.q); + return new UInt128(qr.r, v0); + } + + private UInt256 reduceNormalised(final UInt256 that, final int shift, final long inv) { + UInt320 v = that.shiftLeftWide(shift); + if (v.u4 == 0 && Long.compareUnsigned(v.u4, u1) < 0 && Long.compareUnsigned(v.u3, u1) < 0) { + UInt128 r; + if (v.u3 != 0 || Long.compareUnsigned(v.u2, u1) >= 0) { + r = reduceStep(v.u3, v.u2, v.u1, inv); + r = reduceStep(r.u1, r.u0, v.u0, inv); + } else { + r = reduceStep(v.u2, v.u1, v.u0, inv); + } + return new UInt256(0, 0, r.u1, r.u0).shiftRight(shift); + } + return reduceNormalisedSlowPath(v, shift, inv); + } + + private UInt256 reduceNormalised(final UInt257 that, final int shift, final long inv) { + UInt320 v = that.shiftLeftWide(shift); + if (v.u4 == 0 && Long.compareUnsigned(v.u4, u1) < 0 && Long.compareUnsigned(v.u3, u1) < 0) { + UInt128 r; + if (v.u3 != 0 || Long.compareUnsigned(v.u2, u1) >= 0) { + r = reduceStep(v.u3, v.u2, v.u1, inv); + r = reduceStep(r.u1, r.u0, v.u0, inv); + } else { + r = reduceStep(v.u2, v.u1, v.u0, inv); + } + return new UInt256(0, 0, r.u1, r.u0).shiftRight(shift); + } + return reduceNormalisedSlowPath(v, shift, inv); + } + + private UInt256 reduceNormalisedSlowPath(final UInt320 v, final int shift, final long inv) { + UInt128 r; + if (Long.compareUnsigned(v.u4, u1) >= 0) { + r = reduceStep(0, v.u4, v.u3, inv); + r = reduceStep(r.u1, r.u0, v.u2, inv); + r = reduceStep(r.u1, r.u0, v.u1, inv); + r = reduceStep(r.u1, r.u0, v.u0, inv); + } else { + r = reduceStep(v.u4, v.u3, v.u2, inv); + r = reduceStep(r.u1, r.u0, v.u1, inv); + r = reduceStep(r.u1, r.u0, v.u0, inv); + } + return new UInt256(0, 0, r.u1, r.u0).shiftRight(shift); + } } - private static int[] addMul(final int[] a, final int aLen, final int[] b, final int bLen) { - // Shortest in outer loop, swap if needed - int[] x; - int xLen; - int[] y; - int yLen; - if (a.length < b.length) { - x = b; - xLen = bLen; - y = a; - yLen = aLen; - } else { - x = a; - xLen = aLen; - y = b; - yLen = bLen; - } - int[] lhs = new int[xLen + yLen + 1]; - - // Main algo - for (int i = 0; i < yLen; i++) { - long carry = 0; - long yi = y[i] & MASK_L; - - int k = i; - for (int j = 0; j < xLen; j++, k++) { - long prod = yi * (x[j] & MASK_L); - long sum = (lhs[k] & MASK_L) + prod + carry; - lhs[k] = (int) sum; - carry = sum >>> 32; + // -------------------------------------------------------------------------- + // endregion 128bits Modulus + + // region 192bits Modulus + // -------------------------------------------------------------------------- + record Modulus192(long u2, long u1, long u0) { + Modulus192 shiftLeft(final int shift) { + if (shift == 0) return this; + int invShift = N_BITS_PER_LIMB - shift; + long z0 = u0 << shift; + long z1 = (u1 << shift) | (u0 >>> invShift); + long z2 = (u2 << shift) | (u1 >>> invShift); + return new Modulus192(z2, z1, z0); + } + + int compareTo(final UInt256 v) { + if (v.u3 != 0) return -1; + if (v.u2 != u2) return Long.compareUnsigned(u2, v.u2); + if (v.u1 != u1) return Long.compareUnsigned(u1, v.u1); + return Long.compareUnsigned(u0, v.u0); + } + + int compareTo(final UInt448 v) { + if ((v.u6 | v.u5 | v.u4 | v.u3) != 0) return -1; + if (v.u2 != u2) return Long.compareUnsigned(u2, v.u2); + if (v.u1 != u1) return Long.compareUnsigned(u1, v.u1); + return Long.compareUnsigned(u0, v.u0); + } + + UInt256 reduce(final UInt256 that) { + int cmp = compareTo(that); + if (cmp == 0) return ZERO; + if (cmp > 0) return that; + int shift = Long.numberOfLeadingZeros(u2); + Modulus192 m = shiftLeft(shift); + long inv = reciprocal(m.u2); + return m.reduceNormalised(that, shift, inv); + } + + UInt256 reduce(final UInt448 that) { + int cmp = compareTo(that); + if (cmp == 0) return ZERO; + if (cmp > 0) return that.UInt256Value(); + int shift = Long.numberOfLeadingZeros(u2); + Modulus192 m = shiftLeft(shift); + long inv = reciprocal(m.u2); + return m.reduceNormalised(that, shift, inv); + } + + UInt256 sum(final UInt256 a, final UInt256 b) { + UInt257 sum = a.adc(b); + if (!sum.carry()) { + int cmp = compareTo(sum.UInt256Value()); + if (cmp == 0) return ZERO; + if (cmp > 0) return sum.UInt256Value(); } + int shift = Long.numberOfLeadingZeros(u2); + Modulus192 m = shiftLeft(shift); + long inv = reciprocal(m.u2); + return m.reduceNormalised(sum, shift, inv); + } - // propagate leftover carry - while (carry != 0 && k < lhs.length) { - long sum = (lhs[k] & MASK_L) + carry; - lhs[k] = (int) sum; - carry = sum >>> 32; - k++; + UInt256 mul(final UInt256 a, final UInt256 b) { + // multiply-reduce + if (a.isUInt192() && b.isUInt192()) { + UInt448 prod = a.mul192(b); + int cmp = compareTo(prod); + if (cmp == 0) return ZERO; + if (cmp > 0) return prod.UInt256Value(); + return reduce(prod); } + // reduce-multiply-reduce + int shift = Long.numberOfLeadingZeros(u2); + Modulus192 m = shiftLeft(shift); + long inv = reciprocal(m.u2); + UInt256 x = (a.isUInt192()) ? a : m.reduceNormalised(a, shift, inv); + UInt256 y = (b.isUInt192()) ? b : m.reduceNormalised(b, shift, inv); + UInt448 prod = x.mul192(y); + int cmp = compareTo(prod); + if (cmp == 0) return ZERO; + if (cmp > 0) return prod.UInt256Value(); + return m.reduceNormalised(prod, shift, inv); } - return lhs; - } - - private static int[] knuthRemainder(final int[] dividend, final int[] modulus) { - int[] result = new int[N_LIMBS]; - int divLen = nSetLimbs(dividend); - int modLen = nSetLimbs(modulus); - int cmp = compareLimbs(dividend, divLen, modulus, modLen); - if (cmp < 0) { - System.arraycopy(dividend, 0, result, 0, divLen); - return result; - } else if (cmp == 0) { - return result; - } - - int shift = numberOfLeadingZeros(modulus, modLen); - int limbShift = shift / 32; - int n = modLen - limbShift; - if (n == 0) return result; - if (n == 1) { - if (divLen == 1) { - result[0] = Integer.remainderUnsigned(dividend[0], modulus[0]); - return result; + + private UInt192 addBack(final long v2, final long v1, final long v0) { + // Add back + long z0 = v0 + u0; + long carry = ((v0 & u0) | ((v0 | u0) & ~z0)) >>> 63; + + long z1 = v1 + u1 + carry; + long overflow1 = (Long.compareUnsigned(z1, v1) < 0) ? 1 : 0; + long overflow2 = (Long.compareUnsigned(z1, v1) == 0) ? 1 : 0; + carry = overflow1 | (overflow2 & carry); + + long z2 = v2 + u2 + carry; + overflow1 = (Long.compareUnsigned(z2, v2) < 0) ? 1 : 0; + overflow2 = (Long.compareUnsigned(z2, v2) == 0) ? 1 : 0; + carry = overflow1 | (overflow2 & carry); + + if (carry == 0) { // unlikely: add back again + // Proper quotient estimation guarantees recursion max-depth <= 2 + // Unbounded recursion only if there's a bug - fail fast is better than give wrong result + return addBack(z2, z1, z0); } - long d = modulus[0] & MASK_L; - long rem = 0; - // Process from most significant limb downwards - for (int i = divLen - 1; i >= 0; i--) { - long cur = (rem << 32) | (dividend[i] & MASK_L); - rem = Long.remainderUnsigned(cur, d); + return new UInt192(z2, z1, z0); + } + + private UInt192 mulSub(final long v2, final long v1, final long v0, final long q) { + // Multiply-subtract: already have highest 2 limbs + // = * q + long p0 = u0 * q; + long p1 = Math.unsignedMultiplyHigh(u0, q); + long z0 = v0 - p0; + long carry = p1 + (((Long.compareUnsigned(v0, z0) < 0) ? 1 : 0)); + + p0 = u1 * q; + p1 = Math.unsignedMultiplyHigh(u1, q); + long res = v1 - p0; + long z1 = res - carry; + long borrow = (Long.compareUnsigned(res, z1) < 0) ? 1 : 0; + carry = p1 + ((Long.compareUnsigned(v1, res) < 0) ? 1 : 0); + + // Propagate overflows (borrows) + long t2 = v2 - carry; + long z2 = t2 - borrow; + borrow = + ((Long.compareUnsigned(v2, t2) < 0) ? 1 : 0) + | ((Long.compareUnsigned(t2, z2) < 0) ? 1 : 0); + + if (borrow != 0) return addBack(z2, z1, z0); // unlikely + return new UInt192(z2, z1, z0); + } + + private UInt192 mulSubOverflow(final long v2, final long v1, final long v0) { + // Overflow case: div2by1 quotient would be <1, 0>, but adjusts to <0, -1> + // = -1 * u0 = + long z0 = v0 + u0; + long carry = u0 - 1 + ((Long.compareUnsigned(v0, z0) <= 0) ? 1 : 0); + + long res = v1 - carry; + long z1 = res + u1; + long borrow = (Long.compareUnsigned(res, z1) <= 0) ? 1 : 0; + carry = u1 - 1 + ((Long.compareUnsigned(v1, res) < 0) ? 1 : 0); + + long z2 = v2 - carry + u2 - borrow; + return new UInt192(z2, z1, z0); + } + + private UInt192 reduceStep( + final long v3, final long v2, final long v1, final long v0, final long inv) { + if (v3 == u2) return mulSubOverflow(v2, v1, v0); + DivEstimate qr = div2by1(v3, v2, u2, inv); + if (qr.q != 0) return mulSub(qr.r, v1, v0, qr.q); + return new UInt192(v2, v1, v0); + } + + private UInt256 reduceNormalised(final UInt256 that, final int shift, final long inv) { + UInt192 r; + UInt320 v = that.shiftLeftWide(shift); + if (v.u4 != 0 || Long.compareUnsigned(v.u3, u2) >= 0) { + r = reduceStep(v.u4, v.u3, v.u2, v.u1, inv); + r = reduceStep(r.u2, r.u1, r.u0, v.u0, inv); + } else { + r = reduceStep(v.u3, v.u2, v.u1, v.u0, inv); + } + return new UInt256(0, r.u2, r.u1, r.u0).shiftRight(shift); + } + + private UInt256 reduceNormalised(final UInt257 that, final int shift, final long inv) { + UInt192 r; + UInt320 v = that.shiftLeftWide(shift); + if (v.u4 != 0 || Long.compareUnsigned(v.u3, u2) >= 0) { + r = reduceStep(v.u4, v.u3, v.u2, v.u1, inv); + r = reduceStep(r.u2, r.u1, r.u0, v.u0, inv); + } else { + r = reduceStep(v.u3, v.u2, v.u1, v.u0, inv); } - result[0] = (int) rem; - result[1] = (int) (rem >>> 32); - return result; - } - // Normalize - int m = divLen - n; - int bitShift = shift % 32; - int[] vLimbs = new int[n]; - shiftLeftInto(vLimbs, modulus, modLen, shift); - int[] uLimbs = new int[divLen + 1]; - shiftLeftInto(uLimbs, dividend, divLen, bitShift); - - long[] vLimbsAsLong = new long[n]; - for (int i = 0; i < n; i++) { - vLimbsAsLong[i] = vLimbs[i] & MASK_L; - } - - // Main division loop - long vn1 = vLimbsAsLong[n - 1]; - long vn2 = vLimbsAsLong[n - 2]; - for (int j = m; j >= 0; j--) { - long ujn = (uLimbs[j + n] & MASK_L); - long ujn1 = (uLimbs[j + n - 1] & MASK_L); - long ujn2 = (uLimbs[j + n - 2] & MASK_L); - - long dividendPart = (ujn << 32) | ujn1; - // Check that no need for Unsigned version of divrem. - long qhat = Long.divideUnsigned(dividendPart, vn1); - long rhat = Long.remainderUnsigned(dividendPart, vn1); - - while (qhat == 0x1_0000_0000L || Long.compareUnsigned(qhat * vn2, (rhat << 32) | ujn2) > 0) { - qhat--; - rhat += vn1; - if (rhat >= 0x1_0000_0000L) break; + return new UInt256(0, r.u2, r.u1, r.u0).shiftRight(shift); + } + + private UInt256 reduceNormalised(final UInt448 that, final int shift, final long inv) { + UInt512 v = that.shiftLeftWide(shift); + if ((v.u7 | v.u6 | v.u5) == 0 + && Long.compareUnsigned(v.u6, u2) < 0 + && Long.compareUnsigned(v.u5, u2) < 0 + && Long.compareUnsigned(v.u4, u2) < 0) { + UInt192 r; + if (v.u4 != 0 || Long.compareUnsigned(v.u3, u2) >= 0) { + r = reduceStep(v.u4, v.u3, v.u2, v.u1, inv); + r = reduceStep(r.u2, r.u1, r.u0, v.u0, inv); + } else { + r = reduceStep(v.u3, v.u2, v.u1, v.u0, inv); + } + return new UInt256(0, r.u2, r.u1, r.u0).shiftRight(shift); } + return reduceNormalisedSlowPath(v, shift, inv); + } + + private UInt256 reduceNormalisedSlowPath(final UInt512 v, final int shift, final long inv) { + UInt192 r; + if (v.u7 != 0 || Long.compareUnsigned(v.u6, u2) >= 0) { + r = reduceStep(v.u7, v.u6, v.u5, v.u4, inv); + r = reduceStep(r.u2, r.u1, r.u0, v.u3, inv); + r = reduceStep(r.u2, r.u1, r.u0, v.u2, inv); + r = reduceStep(r.u2, r.u1, r.u0, v.u1, inv); + r = reduceStep(r.u2, r.u1, r.u0, v.u0, inv); + } else if (v.u6 != 0 || Long.compareUnsigned(v.u5, u2) >= 0) { + r = reduceStep(v.u6, v.u5, v.u4, v.u3, inv); + r = reduceStep(r.u2, r.u1, r.u0, v.u2, inv); + r = reduceStep(r.u2, r.u1, r.u0, v.u1, inv); + r = reduceStep(r.u2, r.u1, r.u0, v.u0, inv); + } else { + r = reduceStep(v.u5, v.u4, v.u3, v.u2, inv); + r = reduceStep(r.u2, r.u1, r.u0, v.u1, inv); + r = reduceStep(r.u2, r.u1, r.u0, v.u0, inv); + } + return new UInt256(0, r.u2, r.u1, r.u0).shiftRight(shift); + } + } + + // -------------------------------------------------------------------------- + // endregion 192bits Modulus + + // region 256bits Modulus + // -------------------------------------------------------------------------- + record Modulus256(long u3, long u2, long u1, long u0) { + Modulus256 shiftLeft(final int shift) { + if (shift == 0) return this; + int invShift = N_BITS_PER_LIMB - shift; + long z0 = u0 << shift; + long z1 = (u1 << shift) | (u0 >>> invShift); + long z2 = (u2 << shift) | (u1 >>> invShift); + long z3 = (u3 << shift) | (u2 >>> invShift); + return new Modulus256(z3, z2, z1, z0); + } - // Multiply-subtract qhat*v from u slice - long borrow = 0; - for (int i = 0; i < n; i++) { - long prod = vLimbsAsLong[i] * qhat; - long sub = (uLimbs[i + j] & MASK_L) - (prod & MASK_L) - borrow; - uLimbs[i + j] = (int) sub; - borrow = (prod >>> 32) - (sub >> 32); + int compareTo(final UInt256 v) { + if (v.u3 != u3) return Long.compareUnsigned(u3, v.u3); + if (v.u2 != u2) return Long.compareUnsigned(u2, v.u2); + if (v.u1 != u1) return Long.compareUnsigned(u1, v.u1); + return Long.compareUnsigned(u0, v.u0); + } + + int compareTo(final UInt512 v) { + if ((v.u7 | v.u6 | v.u5 | v.u4) != 0) return -1; + if (v.u3 != u3) return Long.compareUnsigned(u3, v.u3); + if (v.u2 != u2) return Long.compareUnsigned(u2, v.u2); + if (v.u1 != u1) return Long.compareUnsigned(u1, v.u1); + return Long.compareUnsigned(u0, v.u0); + } + + UInt256 reduce(final UInt256 that) { + int cmp = compareTo(that); + if (cmp == 0) return ZERO; + if (cmp > 0) return that; + int shift = Long.numberOfLeadingZeros(u3); + Modulus256 m = shiftLeft(shift); + long inv = reciprocal(m.u3); + return m.reduceNormalised(that, shift, inv); + } + + UInt256 reduce(final UInt512 that) { + int cmp = compareTo(that); + if (cmp == 0) return ZERO; + if (cmp > 0) return that.UInt256Value(); + int shift = Long.numberOfLeadingZeros(u3); + Modulus256 m = shiftLeft(shift); + long inv = reciprocal(m.u3); + return m.reduceNormalised(that, shift, inv); + } + + UInt256 sum(final UInt256 a, final UInt256 b) { + UInt257 sum = a.adc(b); + if (!sum.carry()) { + int cmp = compareTo(sum.UInt256Value()); + if (cmp == 0) return ZERO; + if (cmp > 0) return sum.UInt256Value(); } - long sub = (uLimbs[j + n] & MASK_L) - borrow; - uLimbs[j + n] = (int) sub; - - if (sub < 0) { - // Add back - long carry = 0; - for (int i = 0; i < n; i++) { - long sum = (uLimbs[i + j] & MASK_L) + vLimbsAsLong[i] + carry; - uLimbs[i + j] = (int) sum; - carry = sum >>> 32; + int shift = Long.numberOfLeadingZeros(u3); + Modulus256 m = shiftLeft(shift); + long inv = reciprocal(m.u3); + return m.reduceNormalised(sum, shift, inv); + } + + UInt256 mul(final UInt256 a, final UInt256 b) { + // multiply-reduce + UInt512 prod = a.mul256(b); + int cmp = compareTo(prod); + if (cmp == 0) return ZERO; + if (cmp > 0) return prod.UInt256Value(); + return reduce(prod); + } + + private UInt256 addBack(final long v3, final long v2, final long v1, final long v0) { + // Add back + long z0 = v0 + u0; + long carry = ((v0 & u0) | ((v0 | u0) & ~z0)) >>> 63; + + long z1 = v1 + u1 + carry; + long overflow1 = (Long.compareUnsigned(z1, v1) < 0) ? 1 : 0; + long overflow2 = (Long.compareUnsigned(z1, v1) == 0) ? 1 : 0; + carry = overflow1 | (overflow2 & carry); + + long z2 = v2 + u2 + carry; + overflow1 = (Long.compareUnsigned(z2, v2) < 0) ? 1 : 0; + overflow2 = (Long.compareUnsigned(z2, v2) == 0) ? 1 : 0; + carry = overflow1 | (overflow2 & carry); + + long z3 = v3 + u3 + carry; + overflow1 = (Long.compareUnsigned(z3, v3) < 0) ? 1 : 0; + overflow2 = (Long.compareUnsigned(z3, v3) == 0) ? 1 : 0; + carry = overflow1 | (overflow2 & carry); + + if (carry == 0) { // unlikely: add back again + // Proper quotient estimation guarantees recursion max-depth <= 2 + // Unbounded recursion only if there's a bug - fail fast is better than give wrong result + return addBack(z3, z2, z1, z0); + } + return new UInt256(z3, z2, z1, z0); + } + + private UInt256 mulSub( + final long v3, final long v2, final long v1, final long v0, final long q) { + // Multiply-subtract: already have highest 1 limbs + // = * q + long p0 = u0 * q; + long p1 = Math.unsignedMultiplyHigh(u0, q); + long z0 = v0 - p0; + long carry = p1 + (((Long.compareUnsigned(v0, z0) < 0) ? 1 : 0)); + + p0 = u1 * q; + p1 = Math.unsignedMultiplyHigh(u1, q); + long res = v1 - p0; + long z1 = res - carry; + long borrow = (Long.compareUnsigned(res, z1) < 0) ? 1 : 0; + carry = p1 + ((Long.compareUnsigned(v1, res) < 0) ? 1 : 0); + + p0 = u2 * q; + p1 = Math.unsignedMultiplyHigh(u2, q); + long t2 = v2 - p0; + res = t2 - borrow; + long z2 = res - carry; + borrow = (Long.compareUnsigned(res, z2) < 0) ? 1 : 0; + carry = + p1 + + ((Long.compareUnsigned(v2, t2) < 0) ? 1 : 0) + + ((Long.compareUnsigned(t2, res) < 0) ? 1 : 0); + + // Propagate overflows (borrows) + long t3 = v3 - carry; + long z3 = t3 - borrow; + borrow = + ((Long.compareUnsigned(v3, t3) < 0) ? 1 : 0) + | ((Long.compareUnsigned(t3, z3) < 0) ? 1 : 0); + + if (borrow != 0) return addBack(z3, z2, z1, z0); + return new UInt256(z3, z2, z1, z0); + } + + private UInt256 mulSubOverflow(final long v3, final long v2, final long v1, final long v0) { + // Overflow case: div2by1 quotient would be <1, 0>, but adjusts to <0, -1> + // = -1 * u0 = + long res, borrow; + + long z0 = v0 + u0; + long carry = u0 - 1 + ((Long.compareUnsigned(v0, z0) <= 0) ? 1 : 0); + + res = v1 - carry; + long z1 = res + u1; + borrow = (Long.compareUnsigned(res, z1) <= 0) ? 1 : 0; + carry = u1 - 1 + ((Long.compareUnsigned(v1, res) < 0) ? 1 : 0); + + res = v2 - carry - borrow; + long z2 = res + u2; + borrow = (Long.compareUnsigned(res, z2) <= 0) ? 1 : 0; + carry = u2 - 1 + ((Long.compareUnsigned(v2, res) < 0) ? 1 : 0); + + long z3 = v3 + u3 - carry - borrow; + return new UInt256(z3, z2, z1, z0); + } + + private UInt256 reduceStep( + final long v4, final long v3, final long v2, final long v1, final long v0, final long inv) { + if (v4 == u3) return mulSubOverflow(v3, v2, v1, v0); + DivEstimate qr = div2by1(v4, v3, u3, inv); + if (qr.q != 0) return mulSub(qr.r, v2, v1, v0, qr.q); + return new UInt256(v3, v2, v1, v0); + } + + private UInt256 reduceNormalised(final UInt256 that, final int shift, final long inv) { + UInt320 v = that.shiftLeftWide(shift); + return reduceStep(v.u4, v.u3, v.u2, v.u1, v.u0, inv).shiftRight(shift); + } + + private UInt256 reduceNormalised(final UInt257 that, final int shift, final long inv) { + UInt320 v = that.shiftLeftWide(shift); + return reduceStep(v.u4, v.u3, v.u2, v.u1, v.u0, inv).shiftRight(shift); + } + + private UInt256 reduceNormalised(final UInt512 that, final int shift, final long inv) { + UInt576 v = that.shiftLeftWide(shift); + if ((v.u8 | v.u7 | v.u6) == 0 + && Long.compareUnsigned(v.u7, u3) < 0 + && Long.compareUnsigned(v.u6, u3) < 0 + && Long.compareUnsigned(v.u5, u3) < 0) { + UInt256 r; + if (v.u5 != 0 || Long.compareUnsigned(v.u4, u3) >= 0) { + r = reduceStep(v.u5, v.u4, v.u3, v.u2, v.u1, inv); + r = reduceStep(r.u3, r.u2, r.u1, r.u0, v.u0, inv); + } else { + r = reduceStep(v.u4, v.u3, v.u2, v.u1, v.u0, inv); } - uLimbs[j + n] = (int) (uLimbs[j + n] + carry); + return r.shiftRight(shift); } + return reduceNormalisedSlowPath(v, shift, inv); + } + + private UInt256 reduceNormalisedSlowPath(final UInt576 v, final int shift, final long inv) { + UInt256 r; + if (v.u8 != 0 || Long.compareUnsigned(v.u7, u3) >= 0) { + r = reduceStep(v.u8, v.u7, v.u6, v.u5, v.u4, inv); + r = reduceStep(r.u3, r.u2, r.u1, r.u0, v.u3, inv); + r = reduceStep(r.u3, r.u2, r.u1, r.u0, v.u2, inv); + r = reduceStep(r.u3, r.u2, r.u1, r.u0, v.u1, inv); + r = reduceStep(r.u3, r.u2, r.u1, r.u0, v.u0, inv); + } else if (v.u7 != 0 || Long.compareUnsigned(v.u6, u3) >= 0) { + r = reduceStep(v.u7, v.u6, v.u5, v.u4, v.u3, inv); + r = reduceStep(r.u3, r.u2, r.u1, r.u0, v.u2, inv); + r = reduceStep(r.u3, r.u2, r.u1, r.u0, v.u1, inv); + r = reduceStep(r.u3, r.u2, r.u1, r.u0, v.u0, inv); + } else { + r = reduceStep(v.u6, v.u5, v.u4, v.u3, v.u2, inv); + r = reduceStep(r.u3, r.u2, r.u1, r.u0, v.u1, inv); + r = reduceStep(r.u3, r.u2, r.u1, r.u0, v.u0, inv); + } + return r.shiftRight(shift); } - // Unnormalize remainder - shiftRightInto(result, uLimbs, n, bitShift); - return result; } // -------------------------------------------------------------------------- - // endregion + // endregion 256bits Modulus } diff --git a/evm/src/main/java/org/hyperledger/besu/evm/operation/AddModOperationOptimized.java b/evm/src/main/java/org/hyperledger/besu/evm/operation/AddModOperationOptimized.java index 9cbe403359b..515bd4960ad 100644 --- a/evm/src/main/java/org/hyperledger/besu/evm/operation/AddModOperationOptimized.java +++ b/evm/src/main/java/org/hyperledger/besu/evm/operation/AddModOperationOptimized.java @@ -54,14 +54,10 @@ public static OperationResult staticOperation(final MessageFrame frame) { final Bytes value1 = frame.popStackItem(); final Bytes value2 = frame.popStackItem(); - if (value2.isZero()) { - resultBytes = Bytes.EMPTY; - } else { - UInt256 b0 = UInt256.fromBytesBE(value0.toArrayUnsafe()); - UInt256 b1 = UInt256.fromBytesBE(value1.toArrayUnsafe()); - UInt256 b2 = UInt256.fromBytesBE(value2.toArrayUnsafe()); - resultBytes = Bytes.wrap(b0.addMod(b1, b2).toBytesBE()); - } + UInt256 b0 = UInt256.fromBytesBE(value0.toArrayUnsafe()); + UInt256 b1 = UInt256.fromBytesBE(value1.toArrayUnsafe()); + UInt256 b2 = UInt256.fromBytesBE(value2.toArrayUnsafe()); + resultBytes = Bytes.wrap(b0.addMod(b1, b2).toBytesBE()); frame.pushStackItem(resultBytes); return addModSuccess; diff --git a/evm/src/main/java/org/hyperledger/besu/evm/operation/AddOperationOptimized.java b/evm/src/main/java/org/hyperledger/besu/evm/operation/AddOperationOptimized.java index dcccad8ca7a..03a68f5e4a8 100644 --- a/evm/src/main/java/org/hyperledger/besu/evm/operation/AddOperationOptimized.java +++ b/evm/src/main/java/org/hyperledger/besu/evm/operation/AddOperationOptimized.java @@ -53,11 +53,9 @@ public static OperationResult staticOperation(final MessageFrame frame) { final Bytes value0 = frame.popStackItem(); final Bytes value1 = frame.popStackItem(); - UInt256 b0 = UInt256.fromBytesBE(value0.toArrayUnsafe()); - UInt256 b1 = UInt256.fromBytesBE(value1.toArrayUnsafe()); - - UInt256 result = b0.add(b1); - byte[] resultArray = result.toBytesBE(); + byte[] b0 = value0.toArrayUnsafe(); + byte[] b1 = value1.toArrayUnsafe(); + byte[] resultArray = UInt256.add(b0, b1); frame.pushStackItem(Bytes.wrap(resultArray)); return addSuccess; diff --git a/evm/src/main/java/org/hyperledger/besu/evm/operation/ModOperationOptimized.java b/evm/src/main/java/org/hyperledger/besu/evm/operation/ModOperationOptimized.java index dfa4e9a0daa..59d0d49f2c5 100644 --- a/evm/src/main/java/org/hyperledger/besu/evm/operation/ModOperationOptimized.java +++ b/evm/src/main/java/org/hyperledger/besu/evm/operation/ModOperationOptimized.java @@ -20,7 +20,6 @@ import org.hyperledger.besu.evm.gascalculator.GasCalculator; import org.apache.tuweni.bytes.Bytes; -import org.apache.tuweni.bytes.Bytes32; /** The Mod operation. */ public class ModOperationOptimized extends AbstractFixedCostOperation { @@ -51,14 +50,11 @@ public Operation.OperationResult executeFixedCostOperation( public static OperationResult staticOperation(final MessageFrame frame) { final Bytes value0 = frame.popStackItem(); final Bytes value1 = frame.popStackItem(); - Bytes resultBytes; - if (value1.isZero()) { - resultBytes = (Bytes) Bytes32.ZERO; - } else { - UInt256 b0 = UInt256.fromBytesBE(value0.toArrayUnsafe()); - UInt256 b1 = UInt256.fromBytesBE(value1.toArrayUnsafe()); - resultBytes = Bytes.wrap(b0.mod(b1).toBytesBE()); - } + + UInt256 b0 = UInt256.fromBytesBE(value0.toArrayUnsafe()); + UInt256 b1 = UInt256.fromBytesBE(value1.toArrayUnsafe()); + Bytes resultBytes = Bytes.wrap(b0.mod(b1).toBytesBE()); + frame.pushStackItem(resultBytes); return modSuccess; } diff --git a/evm/src/main/java/org/hyperledger/besu/evm/operation/MulModOperationOptimized.java b/evm/src/main/java/org/hyperledger/besu/evm/operation/MulModOperationOptimized.java index 06bd1291505..23b6224f716 100644 --- a/evm/src/main/java/org/hyperledger/besu/evm/operation/MulModOperationOptimized.java +++ b/evm/src/main/java/org/hyperledger/besu/evm/operation/MulModOperationOptimized.java @@ -52,15 +52,10 @@ public static OperationResult staticOperation(final MessageFrame frame) { final Bytes value1 = frame.popStackItem(); final Bytes value2 = frame.popStackItem(); - Bytes resultBytes; - if (value2.isZero()) { - resultBytes = Bytes.EMPTY; - } else { - UInt256 b0 = UInt256.fromBytesBE(value0.toArrayUnsafe()); - UInt256 b1 = UInt256.fromBytesBE(value1.toArrayUnsafe()); - UInt256 b2 = UInt256.fromBytesBE(value2.toArrayUnsafe()); - resultBytes = Bytes.wrap(b0.mulMod(b1, b2).toBytesBE()); - } + UInt256 b0 = UInt256.fromBytesBE(value0.toArrayUnsafe()); + UInt256 b1 = UInt256.fromBytesBE(value1.toArrayUnsafe()); + UInt256 b2 = UInt256.fromBytesBE(value2.toArrayUnsafe()); + Bytes resultBytes = Bytes.wrap(b0.mulMod(b1, b2).toBytesBE()); frame.pushStackItem(resultBytes); return mulModSuccess; diff --git a/evm/src/main/java/org/hyperledger/besu/evm/operation/MulOperationOptimized.java b/evm/src/main/java/org/hyperledger/besu/evm/operation/MulOperationOptimized.java new file mode 100644 index 00000000000..32bf01d68c9 --- /dev/null +++ b/evm/src/main/java/org/hyperledger/besu/evm/operation/MulOperationOptimized.java @@ -0,0 +1,62 @@ +/** + * Copyright ConsenSys AG. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on + * an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the + * specific language governing permissions and limitations under the License. + * + * SPDX-License-Identifier: Apache-2.0 + */ +package org.hyperledger.besu.evm.operation; + +import org.hyperledger.besu.evm.EVM; +import org.hyperledger.besu.evm.UInt256; +import org.hyperledger.besu.evm.frame.MessageFrame; +import org.hyperledger.besu.evm.gascalculator.GasCalculator; + +import org.apache.tuweni.bytes.Bytes; + +/** The Mul operation. */ +public class MulOperationOptimized extends AbstractFixedCostOperation { + + /** The Mul operation success result. */ + static final OperationResult mulSuccess = new OperationResult(5, null); + + /** + * Instantiates a new Mul operation. + * + * @param gasCalculator the gas calculator + */ + public MulOperation(final GasCalculator gasCalculator) { + super(0x02, "MUL", 2, 1, gasCalculator, gasCalculator.getLowTierGasCost()); + } + + @Override + public Operation.OperationResult executeFixedCostOperation( + final MessageFrame frame, final EVM evm) { + return staticOperation(frame); + } + + /** + * Performs mul operation + * + * @param frame the frame + * @return the operation result + */ + public static OperationResult staticOperation(final MessageFrame frame) { + final Bytes value0 = frame.popStackItem(); + final Bytes value1 = frame.popStackItem(); + + UInt256 u0 = UInt256.fromBytesBE(value0.toArrayUnsafe()); + UInt256 u1 = UInt256.fromBytesBE(value1.toArrayUnsafe()); + Bytes resultBytes = Bytes.wrap(u0.mul(u1).toBytesBE()); + + frame.pushStackItem(resultBytes); + return mulModSuccess; + } +} diff --git a/evm/src/main/java/org/hyperledger/besu/evm/operation/SModOperationOptimized.java b/evm/src/main/java/org/hyperledger/besu/evm/operation/SModOperationOptimized.java index 8d98851e492..ec843a56e67 100644 --- a/evm/src/main/java/org/hyperledger/besu/evm/operation/SModOperationOptimized.java +++ b/evm/src/main/java/org/hyperledger/besu/evm/operation/SModOperationOptimized.java @@ -20,7 +20,6 @@ import org.hyperledger.besu.evm.gascalculator.GasCalculator; import org.apache.tuweni.bytes.Bytes; -import org.apache.tuweni.bytes.Bytes32; /** The SMod operation. */ public class SModOperationOptimized extends AbstractFixedCostOperation { @@ -52,16 +51,11 @@ public static OperationResult staticOperation(final MessageFrame frame) { final Bytes value0 = frame.popStackItem(); final Bytes value1 = frame.popStackItem(); - Bytes resultBytes; - if (value1.isZero()) { - resultBytes = (Bytes) Bytes32.ZERO; - } else { - UInt256 b0 = UInt256.fromBytesBE(value0.toArrayUnsafe()); - UInt256 b1 = UInt256.fromBytesBE(value1.toArrayUnsafe()); - resultBytes = Bytes.wrap(b0.signedMod(b1).toBytesBE()); - } - frame.pushStackItem(resultBytes); + UInt256 b0 = UInt256.fromBytesBE(value0.toArrayUnsafe()); + UInt256 b1 = UInt256.fromBytesBE(value1.toArrayUnsafe()); + Bytes resultBytes = Bytes.wrap(b0.signedMod(b1).toBytesBE()); + frame.pushStackItem(resultBytes); return smodSuccess; } } diff --git a/evm/src/main/java/org/hyperledger/besu/evm/operation/SubOperationOptimized.java b/evm/src/main/java/org/hyperledger/besu/evm/operation/SubOperationOptimized.java new file mode 100644 index 00000000000..ad29e8e5a87 --- /dev/null +++ b/evm/src/main/java/org/hyperledger/besu/evm/operation/SubOperationOptimized.java @@ -0,0 +1,62 @@ +/* + * Copyright contributors to Besu. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on + * an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the + * specific language governing permissions and limitations under the License. + * + * SPDX-License-Identifier: Apache-2.0 + */ +package org.hyperledger.besu.evm.operation; + +import org.hyperledger.besu.evm.EVM; +import org.hyperledger.besu.evm.UInt256; +import org.hyperledger.besu.evm.frame.MessageFrame; +import org.hyperledger.besu.evm.gascalculator.GasCalculator; + +import org.apache.tuweni.bytes.Bytes; + +/** The Sub (Subtract) operation. */ +public class SubOperationOptimized extends AbstractFixedCostOperation { + + /** The Sub operation success result. */ + static final OperationResult subSuccess = new OperationResult(3, null); + + /** + * Instantiates a new Sub operation. + * + * @param gasCalculator the gas calculator + */ + public SubOperationOptimized(final GasCalculator gasCalculator) { + super(0x03, "SUB", 2, 1, gasCalculator, gasCalculator.getVeryLowTierGasCost()); + } + + @Override + public Operation.OperationResult executeFixedCostOperation( + final MessageFrame frame, final EVM evm) { + return staticOperation(frame); + } + + /** + * Performs Sub operation. + * + * @param frame the frame + * @return the operation result + */ + public static OperationResult staticOperation(final MessageFrame frame) { + final Bytes value0 = frame.popStackItem(); + final Bytes value1 = frame.popStackItem(); + + byte[] b0 = value0.toArrayUnsafe(); + byte[] b1 = value1.toArrayUnsafe(); + Bytes resultBytes = Bytes.wrap(UInt256.sub(b0, b1)); + + frame.pushStackItem(resultBytes); + return subSuccess; + } +} diff --git a/evm/src/test/java/org/hyperledger/besu/evm/UInt256PropertyBasedTest.java b/evm/src/test/java/org/hyperledger/besu/evm/UInt256PropertyBasedTest.java index 97744492bbe..99cad626218 100644 --- a/evm/src/test/java/org/hyperledger/besu/evm/UInt256PropertyBasedTest.java +++ b/evm/src/test/java/org/hyperledger/besu/evm/UInt256PropertyBasedTest.java @@ -21,15 +21,19 @@ import net.jqwik.api.Arbitraries; import net.jqwik.api.Arbitrary; +import net.jqwik.api.Assume; import net.jqwik.api.ForAll; import net.jqwik.api.Property; import net.jqwik.api.Provide; +import net.jqwik.api.Tuple; import org.apache.tuweni.bytes.Bytes; import org.apache.tuweni.bytes.Bytes32; public class UInt256PropertyBasedTest { + private static final BigInteger TWO_256 = BigInteger.ONE.shiftLeft(256); // region Test Data Providers + // -------------------------------------------------------------------------- @Provide Arbitrary unsigned1to32() { @@ -63,9 +67,105 @@ Arbitrary shifts() { return Arbitraries.integers().between(-512, 512); } + // -------------------------------------------------------------------------- + // endregion + + // region Shaped Generators (from record-refactor regression suite) + // -------------------------------------------------------------------------- + + @Provide + Arbitrary bytes0to64_shaped() { + final Arbitrary requiredLengths = + Arbitraries.of(0, 1, 7, 8, 9, 15, 16, 17, 31, 32, 33, 64); + + final Arbitrary lengths = + Arbitraries.frequencyOf( + Tuple.of(8, requiredLengths), Tuple.of(2, Arbitraries.integers().between(0, 64))); + + final Arbitrary patterns = + Arbitraries.frequencyOf( + Tuple.of(4, Arbitraries.of(Pattern.ALL_ZERO)), + Tuple.of(4, Arbitraries.of(Pattern.ALL_FF)), + Tuple.of(4, Arbitraries.of(Pattern.LIMB_SIGN_BITS)), + Tuple.of(10, Arbitraries.of(Pattern.RANDOM))); + + return lengths.flatMap( + len -> + patterns.flatMap( + pat -> { + if (len == 0) { + return Arbitraries.of(new byte[0]); + } + return Arbitraries.bytes() + .array(byte[].class) + .ofSize(len) + .map(b -> applyPattern(b, pat)); + })); + } + + @Provide + Arbitrary bytes32_shaped() { + return bytes0to64_shaped().map(UInt256PropertyBasedTest::toBytes32Unsigned); + } + + @Provide + Arbitrary bytes33to64_shaped() { + return bytes0to64_shaped() + .filter(b -> b.length >= 33 && b.length <= 64) + .map(UInt256PropertyBasedTest::forceNonZeroHighBytes); + } + + @Provide + Arbitrary shifts_extreme() { + final Arbitrary edges = + Arbitraries.of( + -512, -257, -256, -129, -128, -65, -64, -1, 0, 1, 63, 64, 65, 127, 128, 129, 255, 256, + 257, 512); + return Arbitraries.frequencyOf( + Tuple.of(7, edges), Tuple.of(3, Arbitraries.integers().between(-512, 512))); + } + + private enum Pattern { + ALL_ZERO, + ALL_FF, + LIMB_SIGN_BITS, + RANDOM + } + + private static byte[] applyPattern(final byte[] bytes, final Pattern pat) { + switch (pat) { + case ALL_ZERO: + Arrays.fill(bytes, (byte) 0x00); + return bytes; + case ALL_FF: + Arrays.fill(bytes, (byte) 0xFF); + return bytes; + case LIMB_SIGN_BITS: + Arrays.fill(bytes, (byte) 0x00); + forceMsbAtIndexIfPresent(bytes, 0); + forceMsbAtIndexIfPresent(bytes, 8); + forceMsbAtIndexIfPresent(bytes, 16); + forceMsbAtIndexIfPresent(bytes, 24); + forceMsbAtIndexIfPresent(bytes, bytes.length - 1); + return bytes; + case RANDOM: + default: + return bytes; + } + } + + private static void forceMsbAtIndexIfPresent(final byte[] bytes, final int index) { + if (index < 0 || index >= bytes.length) { + return; + } + bytes[index] = (byte) (bytes[index] | 0x80); + } + + // -------------------------------------------------------------------------- // endregion // region Serialization Tests + // -------------------------------------------------------------------------- @Property void property_roundTripUnsigned_toFromBytesBE(@ForAll("unsigned0to64") final byte[] any) { @@ -82,9 +182,11 @@ void property_roundTripUnsigned_toFromBytesBE(@ForAll("unsigned0to64") final byt assertThat(back).containsExactly(expected); } + // -------------------------------------------------------------------------- // endregion // region Comparison Tests + // -------------------------------------------------------------------------- @Property void property_equals_compare_consistent( @@ -106,9 +208,11 @@ void property_equals_compare_consistent( assertThat(Integer.signum(cmp)).isEqualTo(Integer.signum(bc)); } + // -------------------------------------------------------------------------- // endregion // region Modulo Tests (MOD/SMOD/ADDMOD/MULMOD) + // -------------------------------------------------------------------------- @Property void property_mod_matchesBigInteger( @ForAll("unsigned1to32") final byte[] a, @ForAll("unsigned1to32") final byte[] m) { @@ -226,9 +330,11 @@ void property_divByZero_invariants() { assertThat(x.mulMod(x, zero).toBytesBE()).containsExactly(Bytes32.ZERO.toArrayUnsafe()); } + // -------------------------------------------------------------------------- // endregion // region AND Operation Tests + // -------------------------------------------------------------------------- @Property void property_and_matchesBytesAnd( @@ -411,9 +517,11 @@ void property_and_specific_patterns() { .containsExactly(bytes1.and(bytes2).toArrayUnsafe()); } + // -------------------------------------------------------------------------- // endregion // region XOR Operation Tests + // -------------------------------------------------------------------------- @Property void property_xor_matchesBytesXor( @@ -511,7 +619,7 @@ void property_xor_involutive( void property_xor_with_allOnes_is_complement(@ForAll("unsigned1to32") final byte[] a) { // Arrange final UInt256 ua = UInt256.fromBytesBE(a); - final UInt256 allOnes = UInt256.ALL_ONES; + final UInt256 allOnes = UInt256.MAX; // Act final UInt256 result = ua.xor(allOnes); @@ -574,7 +682,7 @@ void property_xor_specific_patterns() { (byte) 0x55, (byte) 0x55, (byte) 0x55, (byte) 0x55 }); // 01010101... - final UInt256 allOnes = UInt256.ALL_ONES; + final UInt256 allOnes = UInt256.MAX; // Act & Assert - 0xAA XOR 0x55 = 0xFF assertThat(pattern1.xor(pattern2)).isEqualTo(allOnes); @@ -600,9 +708,11 @@ void property_xor_reversible( assertThat(decrypted).isEqualTo(ua); } + // -------------------------------------------------------------------------- // endregion // region OR Operation Tests + // -------------------------------------------------------------------------- @Property void property_or_matchesBytesOr( @@ -687,7 +797,7 @@ void property_or_idempotent(@ForAll("unsigned1to32") final byte[] a) { void property_or_with_allOnes(@ForAll("unsigned1to32") final byte[] a) { // Arrange final UInt256 ua = UInt256.fromBytesBE(a); - final UInt256 allOnes = UInt256.ALL_ONES; + final UInt256 allOnes = UInt256.MAX; // Act & Assert - A | 0xFF...FF = 0xFF...FF (domination) assertThat(ua.or(allOnes)).isEqualTo(allOnes); assertThat(allOnes.or(ua)).isEqualTo(allOnes); @@ -742,7 +852,7 @@ void property_or_with_complement_is_allOnes(@ForAll("unsigned1to32") final byte[ complementBytes[i] = (byte) ~aBytes32[i]; } final UInt256 complement = UInt256.fromBytesBE(complementBytes); - final UInt256 allOnes = UInt256.ALL_ONES; + final UInt256 allOnes = UInt256.MAX; // Act & Assert - A | ~A = 0xFF...FF assertThat(ua.or(complement)).isEqualTo(allOnes); } @@ -774,7 +884,7 @@ void property_or_specific_patterns() { (byte) 0x55, (byte) 0x55, (byte) 0x55, (byte) 0x55, (byte) 0x55, (byte) 0x55, (byte) 0x55, (byte) 0x55 }); // 01010101... - final UInt256 allOnes = UInt256.ALL_ONES; + final UInt256 allOnes = UInt256.MAX; // Act & Assert - 0xAA OR 0x55 = 0xFF assertThat(pattern1.or(pattern2)).isEqualTo(allOnes); // Verify with Bytes implementation @@ -834,9 +944,11 @@ void property_or_relationship_with_and_xor( assertThat(left).isEqualTo(right); } + // -------------------------------------------------------------------------- // endregion // region NOT Operation Tests + // -------------------------------------------------------------------------- @Property void property_not_matchesBytesNot(@ForAll("unsigned1to32") final byte[] a) { @@ -896,7 +1008,7 @@ void property_not_involutive(@ForAll("unsigned1to32") final byte[] a) { void property_not_different_from_original(@ForAll("unsigned1to32") final byte[] a) { // Arrange final UInt256 ua = UInt256.fromBytesBE(a); - final UInt256 allOnes = UInt256.ALL_ONES; + final UInt256 allOnes = UInt256.MAX; // Act final UInt256 notA = ua.not(); @@ -922,7 +1034,7 @@ void property_not_with_or_is_allOnes(@ForAll("unsigned1to32") final byte[] a) { // Arrange final UInt256 ua = UInt256.fromBytesBE(a); // Act & Assert - A | ~A = 0xFF...FF - assertThat(ua.or(ua.not())).isEqualTo(UInt256.ALL_ONES); + assertThat(ua.or(ua.not())).isEqualTo(UInt256.MAX); } @Property @@ -930,7 +1042,7 @@ void property_not_with_xor_is_allOnes(@ForAll("unsigned1to32") final byte[] a) { // Arrange final UInt256 ua = UInt256.fromBytesBE(a); // Act & Assert - A ^ ~A = 0xFF...FF - assertThat(ua.xor(ua.not())).isEqualTo(UInt256.ALL_ONES); + assertThat(ua.xor(ua.not())).isEqualTo(UInt256.MAX); } @Property @@ -963,7 +1075,7 @@ void property_not_de_morgans_or( void property_not_zero_is_allOnes() { // Arrange final UInt256 zero = UInt256.ZERO; - final UInt256 allOnes = UInt256.ALL_ONES; + final UInt256 allOnes = UInt256.MAX; // Act & Assert - ~0 = 0xFF...FF assertThat(zero.not()).isEqualTo(allOnes); @@ -972,7 +1084,7 @@ void property_not_zero_is_allOnes() { @Property void property_not_allOnes_is_zero() { // Arrange - final UInt256 allOnes = UInt256.ALL_ONES; + final UInt256 allOnes = UInt256.MAX; final UInt256 zero = UInt256.ZERO; // Act & Assert - ~0xFF...FF = 0 @@ -1036,7 +1148,7 @@ void property_not_each_bit_flipped(@ForAll("unsigned1to32") final byte[] a) { void property_not_xor_equivalence(@ForAll("unsigned1to32") final byte[] a) { // Arrange final UInt256 ua = UInt256.fromBytesBE(a); - final UInt256 allOnes = UInt256.ALL_ONES; + final UInt256 allOnes = UInt256.MAX; // Act & Assert - ~A = A ^ 0xFF...FF assertThat(ua.not()).isEqualTo(ua.xor(allOnes)); @@ -1046,7 +1158,7 @@ void property_not_xor_equivalence(@ForAll("unsigned1to32") final byte[] a) { void property_not_sum_with_original(@ForAll("unsigned1to32") final byte[] a) { // Arrange final UInt256 ua = UInt256.fromBytesBE(a); - final UInt256 allOnes = UInt256.ALL_ONES; + final UInt256 allOnes = UInt256.MAX; // Act // When we add A + ~A (bitwise), we should get all 1s in each bit position @@ -1057,11 +1169,10 @@ void property_not_sum_with_original(@ForAll("unsigned1to32") final byte[] a) { assertThat(ua.xor(ua.not())).isEqualTo(allOnes); } + // -------------------------------------------------------------------------- // endregion - // endregion - - // Simple ADD tests + // region Simple ADD tests // -------------------------------------------------------------------------- @Property void property_add_matchesBigInteger( @@ -1291,7 +1402,490 @@ void property_add_near_max_boundary() { // -------------------------------------------------------------------------- // endregion + // region Byte-Array ADD Tests (static UInt256.add(byte[], byte[])) + // -------------------------------------------------------------------------- + + @Property + void property_addBytes_matchesBigInteger( + @ForAll("unsigned1to32") final byte[] a, @ForAll("unsigned1to32") final byte[] b) { + // Arrange + final byte[] a32 = toBytes32Unsigned(a); + final byte[] b32 = toBytes32Unsigned(b); + + // Act + final byte[] got = UInt256.add(a32, b32); + + // Assert + BigInteger A = toBigUnsigned(a); + BigInteger B = toBigUnsigned(b); + byte[] expected = bigUnsignedToBytes32(A.add(B)); + assertThat(got).containsExactly(expected); + } + + @Property + void property_addBytes_commutative( + @ForAll("unsigned1to32") final byte[] a, @ForAll("unsigned1to32") final byte[] b) { + // Arrange + final byte[] a32 = toBytes32Unsigned(a); + final byte[] b32 = toBytes32Unsigned(b); + + // Act & Assert + assertThat(UInt256.add(a32, b32)).containsExactly(UInt256.add(b32, a32)); + } + + @Property + void property_addBytes_identity(@ForAll("unsigned1to32") final byte[] a) { + // Arrange + final byte[] a32 = toBytes32Unsigned(a); + final byte[] zero = new byte[32]; + + // Act & Assert — x + 0 = x + assertThat(UInt256.add(a32, zero)).containsExactly(a32); + assertThat(UInt256.add(zero, a32)).containsExactly(a32); + } + + @Property + void property_addBytes_consistent_with_UInt256_add( + @ForAll("unsigned1to32") final byte[] a, @ForAll("unsigned1to32") final byte[] b) { + // Arrange + final byte[] a32 = toBytes32Unsigned(a); + final byte[] b32 = toBytes32Unsigned(b); + final UInt256 ua = UInt256.fromBytesBE(a); + final UInt256 ub = UInt256.fromBytesBE(b); + + // Act + final byte[] fromBytes = UInt256.add(a32, b32); + final byte[] fromUInt256 = ua.add(ub).toBytesBE(); + + // Assert — byte[] and UInt256 paths must agree + assertThat(fromBytes).containsExactly(fromUInt256); + } + + @Property + void property_addBytes_singleLimb_matchesBigInteger( + @ForAll("singleLimbUnsigned1to4") final byte[] a, + @ForAll("singleLimbUnsigned1to4") final byte[] b) { + // Arrange + final byte[] a32 = toBytes32Unsigned(a); + final byte[] b32 = toBytes32Unsigned(b); + + // Act + final byte[] got = UInt256.add(a32, b32); + + // Assert + BigInteger A = toBigUnsigned(a); + BigInteger B = toBigUnsigned(b); + byte[] expected = bigUnsignedToBytes32(A.add(B)); + assertThat(got).containsExactly(expected); + } + + // -------------------------------------------------------------------------- + // endregion + + // region Byte-Array SUB Tests (static UInt256.sub(byte[], byte[])) + // -------------------------------------------------------------------------- + + @Property + void property_subBytes_matchesBigInteger( + @ForAll("unsigned1to32") final byte[] a, @ForAll("unsigned1to32") final byte[] b) { + // Arrange + final byte[] a32 = toBytes32Unsigned(a); + final byte[] b32 = toBytes32Unsigned(b); + + // Act + final byte[] got = UInt256.sub(a32, b32); + + // Assert — wrapping subtraction mod 2^256 + BigInteger A = toBigUnsigned(a); + BigInteger B = toBigUnsigned(b); + BigInteger result = A.subtract(B).mod(TWO_256); + byte[] expected = bigUnsignedToBytes32(result); + assertThat(got).containsExactly(expected); + } + + @Property + void property_subBytes_identity(@ForAll("unsigned1to32") final byte[] a) { + // Arrange + final byte[] a32 = toBytes32Unsigned(a); + final byte[] zero = new byte[32]; + + // Act & Assert — x - 0 = x + assertThat(UInt256.sub(a32, zero)).containsExactly(a32); + } + + @Property + void property_subBytes_self_is_zero(@ForAll("unsigned1to32") final byte[] a) { + // Arrange + final byte[] a32 = toBytes32Unsigned(a); + + // Act & Assert — x - x = 0 + assertThat(UInt256.sub(a32, a32)).containsExactly(new byte[32]); + } + + @Property + void property_subBytes_add_inverse( + @ForAll("unsigned1to32") final byte[] a, @ForAll("unsigned1to32") final byte[] b) { + // Arrange + final byte[] a32 = toBytes32Unsigned(a); + final byte[] b32 = toBytes32Unsigned(b); + + // Act & Assert — (x + y) - y = x + byte[] sum = UInt256.add(a32, b32); + assertThat(UInt256.sub(sum, b32)).containsExactly(a32); + } + + @Property + void property_subBytes_anti_commutative( + @ForAll("unsigned1to32") final byte[] a, @ForAll("unsigned1to32") final byte[] b) { + // Arrange — fresh copies for each sub call (sub/neg may mutate input arrays) + final byte[] aMinusB = UInt256.sub(toBytes32Unsigned(a), toBytes32Unsigned(b)); + final byte[] bMinusA = UInt256.sub(toBytes32Unsigned(b), toBytes32Unsigned(a)); + + // Assert — (a - b) + (b - a) = 0 + assertThat(UInt256.add(aMinusB, bMinusA)).containsExactly(new byte[32]); + } + + @Property + void property_subBytes_consistent_with_add_neg( + @ForAll("unsigned1to32") final byte[] a, @ForAll("unsigned1to32") final byte[] b) { + // Arrange + final UInt256 ua = UInt256.fromBytesBE(a); + final UInt256 ub = UInt256.fromBytesBE(b); + final byte[] a32 = toBytes32Unsigned(a); + final byte[] b32 = toBytes32Unsigned(b); + + // Act — sub(a,b) should equal a + neg(b) via UInt256 + final byte[] fromSub = UInt256.sub(a32, b32); + final byte[] fromAddNeg = ua.add(ub.neg()).toBytesBE(); + + // Assert + assertThat(fromSub).containsExactly(fromAddNeg); + } + + @Property + void property_subBytes_singleLimb_matchesBigInteger( + @ForAll("singleLimbUnsigned1to4") final byte[] a, + @ForAll("singleLimbUnsigned1to4") final byte[] b) { + // Arrange + final byte[] a32 = toBytes32Unsigned(a); + final byte[] b32 = toBytes32Unsigned(b); + + // Act + final byte[] got = UInt256.sub(a32, b32); + + // Assert + BigInteger A = toBigUnsigned(a); + BigInteger B = toBigUnsigned(b); + BigInteger result = A.subtract(B).mod(TWO_256); + byte[] expected = bigUnsignedToBytes32(result); + assertThat(got).containsExactly(expected); + } + + // -------------------------------------------------------------------------- + // endregion + + // region MUL Tests (UInt256.mul) + // -------------------------------------------------------------------------- + + @Property + void property_mul_matchesBigInteger( + @ForAll("unsigned1to32") final byte[] a, @ForAll("unsigned1to32") final byte[] b) { + // Arrange + final UInt256 ua = UInt256.fromBytesBE(a); + final UInt256 ub = UInt256.fromBytesBE(b); + + // Act + final byte[] got = ua.mul(ub).toBytesBE(); + + // Assert — wrapping multiplication mod 2^256 + BigInteger A = toBigUnsigned(a); + BigInteger B = toBigUnsigned(b); + byte[] expected = bigUnsignedToBytes32(A.multiply(B)); + assertThat(got).containsExactly(expected); + } + + @Property + void property_mul_commutative( + @ForAll("unsigned1to32") final byte[] a, @ForAll("unsigned1to32") final byte[] b) { + // Arrange + final UInt256 ua = UInt256.fromBytesBE(a); + final UInt256 ub = UInt256.fromBytesBE(b); + + // Act & Assert — a * b = b * a + assertThat(ua.mul(ub)).isEqualTo(ub.mul(ua)); + } + + @Property + void property_mul_associative( + @ForAll("unsigned1to32") final byte[] a, + @ForAll("unsigned1to32") final byte[] b, + @ForAll("unsigned1to32") final byte[] c) { + // Arrange + final UInt256 ua = UInt256.fromBytesBE(a); + final UInt256 ub = UInt256.fromBytesBE(b); + final UInt256 uc = UInt256.fromBytesBE(c); + + // Act & Assert — (a * b) * c = a * (b * c) + assertThat(ua.mul(ub).mul(uc)).isEqualTo(ua.mul(ub.mul(uc))); + } + + @Property + void property_mul_identity(@ForAll("unsigned1to32") final byte[] a) { + // Arrange + final UInt256 ua = UInt256.fromBytesBE(a); + final UInt256 one = UInt256.fromBytesBE(new byte[] {1}); + + // Act & Assert — a * 1 = a + assertThat(ua.mul(one)).isEqualTo(ua); + assertThat(one.mul(ua)).isEqualTo(ua); + } + + @Property + void property_mul_zero(@ForAll("unsigned1to32") final byte[] a) { + // Arrange + final UInt256 ua = UInt256.fromBytesBE(a); + final UInt256 zero = UInt256.ZERO; + + // Act & Assert — a * 0 = 0 + assertThat(ua.mul(zero)).isEqualTo(zero); + assertThat(zero.mul(ua)).isEqualTo(zero); + } + + @Property + void property_mul_singleLimb_matchesBigInteger( + @ForAll("singleLimbUnsigned1to4") final byte[] a, + @ForAll("singleLimbUnsigned1to4") final byte[] b) { + // Arrange + final UInt256 ua = UInt256.fromBytesBE(a); + final UInt256 ub = UInt256.fromBytesBE(b); + + // Act + final byte[] got = ua.mul(ub).toBytesBE(); + + // Assert + BigInteger A = toBigUnsigned(a); + BigInteger B = toBigUnsigned(b); + byte[] expected = bigUnsignedToBytes32(A.multiply(B)); + assertThat(got).containsExactly(expected); + } + + @Property + void property_mul_distributive_over_add( + @ForAll("unsigned1to32") final byte[] a, + @ForAll("unsigned1to32") final byte[] b, + @ForAll("unsigned1to32") final byte[] c) { + // Arrange + final UInt256 ua = UInt256.fromBytesBE(a); + final UInt256 ub = UInt256.fromBytesBE(b); + final UInt256 uc = UInt256.fromBytesBE(c); + + // Act & Assert — a * (b + c) = a*b + a*c + UInt256 left = ua.mul(ub.add(uc)); + UInt256 right = ua.mul(ub).add(ua.mul(uc)); + assertThat(left).isEqualTo(right); + } + + @Property + void property_mul_by_two_equals_add_self(@ForAll("unsigned1to32") final byte[] a) { + // Arrange + final UInt256 ua = UInt256.fromBytesBE(a); + final UInt256 two = UInt256.fromBytesBE(new byte[] {2}); + + // Act & Assert — a * 2 = a + a + assertThat(ua.mul(two)).isEqualTo(ua.add(ua)); + } + + // -------------------------------------------------------------------------- + // endregion + + // region Record Refactor Regression Tests (shaped generators, fixed seeds) + // -------------------------------------------------------------------------- + + @Property(seed = "3735928559") + void property_fromBytesBE_toBytesBE_roundTrip_canonical32( + @ForAll("bytes0to64_shaped") final byte[] input) { + // Arrange. + final byte[] expected = canonicalUnsigned256ToBytes32(input); + + // Act. + final UInt256 u = UInt256.fromBytesBE(input); + final byte[] got = u.toBytesBE(); + + // Assert. + assertThat(got).hasSize(32); + assertThat(got).containsExactly(expected); + } + + @Property(seed = "2718281828") + void property_fromBytesBE_ignores_high_bytes_above_32( + @ForAll("bytes33to64_shaped") final byte[] input) { + // Arrange. + final byte[] low32 = Arrays.copyOfRange(input, input.length - 32, input.length); + + // Act. + final UInt256 a = UInt256.fromBytesBE(input); + final UInt256 b = UInt256.fromBytesBE(low32); + + // Assert. + assertThat(a).isEqualTo(b); + } + + @Property(seed = "3141592653") + void property_fromBytesBE_toBytesBE_is_identity_on_32_bytes( + @ForAll("bytes32_shaped") final byte[] be32) { + // Arrange. + final byte[] input = Arrays.copyOf(be32, be32.length); + + // Act. + final byte[] got = UInt256.fromBytesBE(input).toBytesBE(); + + // Assert. + assertThat(got).containsExactly(input); + } + + @Property(seed = "1618033988") + void property_compare_zero_iff_equals( + @ForAll("bytes32_shaped") final byte[] a, @ForAll("bytes32_shaped") final byte[] b) { + // Arrange. + final UInt256 ua = UInt256.fromBytesBE(a); + final UInt256 ub = UInt256.fromBytesBE(b); + + // Act. + final int cmp = UInt256.compare(ua, ub); + + // Assert. + assertThat(cmp == 0).isEqualTo(ua.equals(ub)); + } + + @Property(seed = "1414213562") + void property_compare_sign_matches_unsigned_big_integer( + @ForAll("bytes32_shaped") final byte[] a, @ForAll("bytes32_shaped") final byte[] b) { + // Arrange. + final UInt256 ua = UInt256.fromBytesBE(a); + final UInt256 ub = UInt256.fromBytesBE(b); + final BigInteger A = new BigInteger(1, a); + final BigInteger B = new BigInteger(1, b); + + // Act. + final int cmp = UInt256.compare(ua, ub); + + // Assert. + assertThat(Integer.signum(cmp)).isEqualTo(Integer.signum(A.compareTo(B))); + } + + @Property(seed = "1123581321") + void property_shiftLeft_matches_big_integer_mod_2_256( + @ForAll("bytes32_shaped") final byte[] a, @ForAll("shifts_extreme") final int shift) { + Assume.that(shift >= 0 && shift < 64); + // Arrange. + final BigInteger A = new BigInteger(1, a); + final byte[] expected = expectedShl(A, shift); + + // Act. + final byte[] got = UInt256.fromBytesBE(a).shiftLeft(shift).toBytesBE(); + + // Assert. + assertThat(got).containsExactly(expected); + } + + @Property(seed = "867530900") + void property_shiftRight_matches_big_integer_mod_2_256( + @ForAll("bytes32_shaped") final byte[] a, @ForAll("shifts_extreme") final int shift) { + Assume.that(shift >= 0 && shift < 64); + // Arrange. + final BigInteger A = new BigInteger(1, a); + final byte[] expected = expectedShr(A, shift); + + // Act. + final byte[] got = UInt256.fromBytesBE(a).shiftRight(shift).toBytesBE(); + + // Assert. + assertThat(got).containsExactly(expected); + } + + @Property(seed = "123456789") + void property_mod_matches_big_integer_unsigned( + @ForAll("bytes0to64_shaped") final byte[] a, @ForAll("bytes0to64_shaped") final byte[] m) { + // Arrange. + final UInt256 ua = UInt256.fromBytesBE(a); + final UInt256 um = UInt256.fromBytesBE(m); + final BigInteger A = toBigUnsignedMod256(a); + final BigInteger M = toBigUnsignedMod256(m); + final byte[] expected = expectedMod(A, M); + + // Act. + final byte[] got = ua.mod(um).toBytesBE(); + + // Assert. + assertThat(got).containsExactly(expected); + } + + @Property(seed = "987654321") + void property_signedMod_matches_evm_semantics( + @ForAll("bytes0to64_shaped") final byte[] a, @ForAll("bytes0to64_shaped") final byte[] m) { + // Arrange. + final byte[] a32 = toBytes32Unsigned(a); + final byte[] m32 = toBytes32Unsigned(m); + final UInt256 ua = UInt256.fromBytesBE(a32); + final UInt256 um = UInt256.fromBytesBE(m32); + final BigInteger A = new BigInteger(a32); + final BigInteger M = new BigInteger(m32); + final byte[] expected = expectedSignedMod(A, M); + + // Act. + final byte[] got = ua.signedMod(um).toBytesBE(); + + // Assert. + assertThat(got).containsExactly(expected); + } + + @Property(seed = "42424242") + void property_addMod_matches_big_integer_unsigned( + @ForAll("bytes0to64_shaped") final byte[] a, + @ForAll("bytes0to64_shaped") final byte[] b, + @ForAll("bytes0to64_shaped") final byte[] m) { + // Arrange. + final UInt256 ua = UInt256.fromBytesBE(a); + final UInt256 ub = UInt256.fromBytesBE(b); + final UInt256 um = UInt256.fromBytesBE(m); + final BigInteger A = toBigUnsignedMod256(a); + final BigInteger B = toBigUnsignedMod256(b); + final BigInteger M = toBigUnsignedMod256(m); + final byte[] expected = expectedAddMod(A, B, M); + + // Act. + final byte[] got = ua.addMod(ub, um).toBytesBE(); + + // Assert. + assertThat(got).containsExactly(expected); + } + + @Property(seed = "13371337") + void property_mulMod_matches_big_integer_unsigned( + @ForAll("bytes0to64_shaped") final byte[] a, + @ForAll("bytes0to64_shaped") final byte[] b, + @ForAll("bytes0to64_shaped") final byte[] m) { + // Arrange. + final UInt256 ua = UInt256.fromBytesBE(a); + final UInt256 ub = UInt256.fromBytesBE(b); + final UInt256 um = UInt256.fromBytesBE(m); + final BigInteger A = toBigUnsignedMod256(a); + final BigInteger B = toBigUnsignedMod256(b); + final BigInteger M = toBigUnsignedMod256(m); + final byte[] expected = expectedMulMod(A, B, M); + + // Act. + final byte[] got = ua.mulMod(ub, um).toBytesBE(); + + // Assert. + assertThat(got).containsExactly(expected); + } + + // -------------------------------------------------------------------------- + // endregion + // region Utility Methods + // -------------------------------------------------------------------------- private static byte[] clampUnsigned32(final byte[] any) { if (any.length == 0) { @@ -1308,7 +1902,7 @@ private static byte[] bigUnsignedToBytes32(final BigInteger x) { byte[] ba = y.toByteArray(); if (ba.length == 0) { - return new byte[] {0}; + return new byte[32]; } if (ba.length == 32) { @@ -1351,5 +1945,97 @@ private static byte[] padNegative(final BigInteger r) { System.arraycopy(rb, 0, padded, 32 - rb.length, rb.length); return padded; } + + private static byte[] canonicalUnsigned256ToBytes32(final byte[] be) { + if (be.length == 0) { + return new byte[32]; + } + final BigInteger x = new BigInteger(1, be).mod(TWO_256); + return bigUnsignedToBytes32(x); + } + + private static byte[] toBytes32Unsigned(final byte[] be) { + if (be.length == 32) { + return Arrays.copyOf(be, 32); + } + if (be.length == 0) { + return new byte[32]; + } + if (be.length < 32) { + final byte[] out = new byte[32]; + System.arraycopy(be, 0, out, 32 - be.length, be.length); + return out; + } + return Arrays.copyOfRange(be, be.length - 32, be.length); + } + + private static byte[] forceNonZeroHighBytes(final byte[] be) { + if (be.length <= 32) { + return be; + } + final int highLen = be.length - 32; + boolean anyNonZero = false; + for (int i = 0; i < highLen; i++) { + anyNonZero |= (be[i] != 0); + } + if (!anyNonZero) { + be[0] = 1; + } + return be; + } + + private static BigInteger toBigUnsignedMod256(final byte[] be) { + if (be.length == 0) { + return BigInteger.ZERO; + } + return new BigInteger(1, be).mod(TWO_256); + } + + private static byte[] expectedMod(final BigInteger A, final BigInteger M) { + if (M.signum() == 0) { + return new byte[32]; + } + return bigUnsignedToBytes32(A.mod(M)); + } + + private static byte[] expectedAddMod(final BigInteger A, final BigInteger B, final BigInteger M) { + if (M.signum() == 0) { + return new byte[32]; + } + return bigUnsignedToBytes32(A.add(B).mod(M)); + } + + private static byte[] expectedMulMod(final BigInteger A, final BigInteger B, final BigInteger M) { + if (M.signum() == 0) { + return new byte[32]; + } + return bigUnsignedToBytes32(A.multiply(B).mod(M)); + } + + private static byte[] expectedSignedMod(final BigInteger A, final BigInteger M) { + if (M.signum() == 0) { + return new byte[32]; + } + BigInteger r = A.abs().mod(M.abs()); + if (A.signum() < 0 && r.signum() != 0) { + return padNegative(r); + } + return bigUnsignedToBytes32(r); + } + + private static byte[] expectedShl(final BigInteger A, final int shift) { + if (shift < 0 || shift >= 256) { + return new byte[32]; + } + return bigUnsignedToBytes32(A.shiftLeft(shift)); + } + + private static byte[] expectedShr(final BigInteger A, final int shift) { + if (shift < 0 || shift >= 256) { + return new byte[32]; + } + return bigUnsignedToBytes32(A.shiftRight(shift)); + } + // endregion } diff --git a/evm/src/test/java/org/hyperledger/besu/evm/UInt256Test.java b/evm/src/test/java/org/hyperledger/besu/evm/UInt256Test.java index 080b684ac34..fa9db2b6b23 100644 --- a/evm/src/test/java/org/hyperledger/besu/evm/UInt256Test.java +++ b/evm/src/test/java/org/hyperledger/besu/evm/UInt256Test.java @@ -15,7 +15,6 @@ package org.hyperledger.besu.evm; import static org.assertj.core.api.Assertions.assertThat; -import static org.junit.jupiter.api.Assertions.assertArrayEquals; import java.math.BigInteger; import java.util.Arrays; @@ -24,24 +23,23 @@ import org.apache.tuweni.bytes.Bytes; import org.apache.tuweni.bytes.Bytes32; import org.junit.jupiter.api.Test; -import org.junit.jupiter.params.ParameterizedTest; -import org.junit.jupiter.params.provider.ValueSource; public class UInt256Test { - static final int SAMPLE_SIZE = 300; + static final int SAMPLE_SIZE = 3; - private Bytes32 bigIntTo32B(final BigInteger x) { - byte[] a = x.toByteArray(); + private Bytes32 bigIntTo32B(final BigInteger y) { + byte[] a = y.toByteArray(); if (a.length > 32) return Bytes32.wrap(a, a.length - 32); return Bytes32.leftPad(Bytes.wrap(a)); } - private Bytes32 bigIntToSigned32B(final BigInteger x) { - if (x.signum() >= 0) return bigIntTo32B(x); + private Bytes32 bigIntTo32B(final BigInteger x, final int sign) { + if (sign >= 0) return bigIntTo32B(x); byte[] a = new byte[32]; Arrays.fill(a, (byte) 0xFF); byte[] b = x.toByteArray(); System.arraycopy(b, 0, a, 32 - b.length, b.length); + if (a.length > 32) return Bytes32.wrap(a, a.length - 32); return Bytes32.leftPad(Bytes.wrap(a)); } @@ -63,22 +61,22 @@ public void fromInts() { public void fromBytesBE() { byte[] input; UInt256 result; - int[] expectedLimbs; + UInt256 expected; input = new byte[] {-128, 0, 0, 0}; result = UInt256.fromBytesBE(input); - expectedLimbs = new int[] {-2147483648, 0, 0, 0, 0, 0, 0, 0}; - assertThat(result.limbs()).as("4b-neg-limbs").isEqualTo(expectedLimbs); + expected = new UInt256(0, 0, 0, 2147483648L); + assertThat(result).as("4b-neg-limbs").isEqualTo(expected); input = new byte[] {0, 0, 1, 1, 1}; result = UInt256.fromBytesBE(input); - expectedLimbs = new int[] {1 + 256 + 65536, 0, 0, 0, 0, 0, 0, 0}; - assertThat(result.limbs()).as("3b-limbs").isEqualTo(expectedLimbs); + expected = new UInt256(0, 0, 0, 1 + 256 + 65536); + assertThat(result).as("3b-limbs").isEqualTo(expected); - input = new byte[] {1, 0, 0, 0, 0, 1, 1, 1}; + input = new byte[] {1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1}; result = UInt256.fromBytesBE(input); - expectedLimbs = new int[] {1 + 256 + 65536, 16777216, 0, 0, 0, 0, 0, 0}; - assertThat(result.limbs()).as("8b-limbs").isEqualTo(expectedLimbs); + expected = new UInt256(0, 0, 16777216, 1 + 256 + 65536); + assertThat(result).as("8b-limbs").isEqualTo(expected); input = new byte[] { @@ -86,8 +84,8 @@ public void fromBytesBE() { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 }; result = UInt256.fromBytesBE(input); - expectedLimbs = new int[] {0, 0, 0, 0, 0, 0, 0, 16777216}; - assertThat(result.limbs()).as("32b-limbs").isEqualTo(expectedLimbs); + expected = new UInt256(72057594037927936L, 0, 0, 0); + assertThat(result).as("32b-limbs").isEqualTo(expected); input = new byte[] { @@ -95,8 +93,15 @@ public void fromBytesBE() { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 }; result = UInt256.fromBytesBE(input); - expectedLimbs = new int[] {0, 0, 0, 0, 0, 0, 257, 0}; - assertThat(result.limbs()).as("32b-padded-limbs").isEqualTo(expectedLimbs); + expected = new UInt256(257, 0, 0, 0); + assertThat(result).as("32b-padded-limbs").isEqualTo(expected); + + Bytes inputBytes = + Bytes.fromHexString("0x000000000000000000000000ffffffffffffffffffffffffffffffffffffffff"); + input = inputBytes.toArrayUnsafe(); + result = UInt256.fromBytesBE(input); + expected = new UInt256(0, 4294967295L, -1L, -1L); + assertThat(result).as("32b-case2-limbs").isEqualTo(expected); } @Test @@ -208,6 +213,124 @@ public void modB() { assertThat(remainder).isEqualTo(expected); } + @Test + public void modC() { + BigInteger big_number = new BigInteger("1000000000000000000000000000000000000000000000000", 16); + BigInteger big_modulus = new BigInteger("ff00000000000000", 16); + UInt256 number = UInt256.fromBytesBE(big_number.toByteArray()); + UInt256 modulus = UInt256.fromBytesBE(big_modulus.toByteArray()); + Bytes32 remainder = Bytes32.leftPad(Bytes.wrap(number.mod(modulus).toBytesBE())); + Bytes32 expected = Bytes32.leftPad(Bytes.wrap(big_number.mod(big_modulus).toByteArray())); + assertThat(remainder).isEqualTo(expected); + } + + @Test + public void modD() { + BigInteger big_number = new BigInteger("ff00000000000000000000000000000000", 16); + BigInteger big_modulus = new BigInteger("100000000000000000000000000000000", 16); + UInt256 number = UInt256.fromBytesBE(big_number.toByteArray()); + UInt256 modulus = UInt256.fromBytesBE(big_modulus.toByteArray()); + Bytes32 remainder = Bytes32.leftPad(Bytes.wrap(number.mod(modulus).toBytesBE())); + Bytes32 expected = Bytes32.leftPad(Bytes.wrap(big_number.mod(big_modulus).toByteArray())); + assertThat(remainder).isEqualTo(expected); + } + + @Test + public void modE() { + BigInteger big_number = new BigInteger("ff00000000000000000000000000000000", 16); + BigInteger big_modulus = new BigInteger("100000000000000000000000000000001", 16); + UInt256 number = UInt256.fromBytesBE(big_number.toByteArray()); + UInt256 modulus = UInt256.fromBytesBE(big_modulus.toByteArray()); + Bytes32 remainder = Bytes32.leftPad(Bytes.wrap(number.mod(modulus).toBytesBE())); + Bytes32 expected = Bytes32.leftPad(Bytes.wrap(big_number.mod(big_modulus).toByteArray())); + assertThat(remainder).isEqualTo(expected); + } + + @Test + public void modF() { + BigInteger big_number = new BigInteger("1000000000000000000000000000000000000000000000000", 16); + BigInteger big_modulus = new BigInteger("ff000000000000000000000000000000", 16); + UInt256 number = UInt256.fromBytesBE(big_number.toByteArray()); + UInt256 modulus = UInt256.fromBytesBE(big_modulus.toByteArray()); + Bytes32 remainder = Bytes32.leftPad(Bytes.wrap(number.mod(modulus).toBytesBE())); + Bytes32 expected = Bytes32.leftPad(Bytes.wrap(big_number.mod(big_modulus).toByteArray())); + assertThat(remainder).isEqualTo(expected); + } + + @Test + public void modG() { + BigInteger big_number = new BigInteger("1000000000000000000000000000000000000000000000000", 16); + BigInteger big_modulus = new BigInteger("100000000000000000000000000000001", 16); + UInt256 number = UInt256.fromBytesBE(big_number.toByteArray()); + UInt256 modulus = UInt256.fromBytesBE(big_modulus.toByteArray()); + Bytes32 remainder = Bytes32.leftPad(Bytes.wrap(number.mod(modulus).toBytesBE())); + Bytes32 expected = Bytes32.leftPad(Bytes.wrap(big_number.mod(big_modulus).toByteArray())); + assertThat(remainder).isEqualTo(expected); + } + + @Test + public void modH() { + BigInteger big_number = + new BigInteger("000000000000000000ff00000000000000000000000000000000000000000000", 16); + BigInteger big_modulus = + new BigInteger("0000000000000000000000000000000000fe0000000000000000000000000001", 16); + UInt256 number = UInt256.fromBytesBE(big_number.toByteArray()); + UInt256 modulus = UInt256.fromBytesBE(big_modulus.toByteArray()); + Bytes32 remainder = Bytes32.leftPad(Bytes.wrap(number.mod(modulus).toBytesBE())); + Bytes32 expected = Bytes32.leftPad(Bytes.wrap(big_number.mod(big_modulus).toByteArray())); + assertThat(remainder).isEqualTo(expected); + } + + @Test + public void modI() { + // modulus 128 with overflow case + BigInteger big_number = new BigInteger("020000000000000000000000000000000000", 16); + BigInteger big_modulus = new BigInteger("02000000000000000000", 16); + UInt256 number = UInt256.fromBytesBE(big_number.toByteArray()); + UInt256 modulus = UInt256.fromBytesBE(big_modulus.toByteArray()); + Bytes32 remainder = Bytes32.leftPad(Bytes.wrap(number.mod(modulus).toBytesBE())); + Bytes32 expected = Bytes32.leftPad(Bytes.wrap(big_number.mod(big_modulus).toByteArray())); + assertThat(remainder).isEqualTo(expected); + } + + @Test + public void modJ() { + // modulus 128 with overflow case -> 2 add back in quotient estimate div2by1. + BigInteger big_number = new BigInteger("10000000000000000010000000000000000", 16); + BigInteger big_modulus = new BigInteger("200000000000000ff", 16); + UInt256 number = UInt256.fromBytesBE(big_number.toByteArray()); + UInt256 modulus = UInt256.fromBytesBE(big_modulus.toByteArray()); + Bytes32 remainder = Bytes32.leftPad(Bytes.wrap(number.mod(modulus).toBytesBE())); + Bytes32 expected = Bytes32.leftPad(Bytes.wrap(big_number.mod(big_modulus).toByteArray())); + assertThat(remainder).isEqualTo(expected); + } + + @Test + public void modK() { + // modulus 128 with overflow case -> 2 add back in quotient estimate div2by1. + BigInteger big_number = + new BigInteger("ff000000000000000000000000000000000000000000000000000000", 16); + BigInteger big_modulus = + new BigInteger("1000000000000000000000002000000000000000000000000", 16); + UInt256 number = UInt256.fromBytesBE(big_number.toByteArray()); + UInt256 modulus = UInt256.fromBytesBE(big_modulus.toByteArray()); + Bytes32 remainder = Bytes32.leftPad(Bytes.wrap(number.mod(modulus).toBytesBE())); + Bytes32 expected = Bytes32.leftPad(Bytes.wrap(big_number.mod(big_modulus).toByteArray())); + assertThat(remainder).isEqualTo(expected); + } + + @Test + public void modL() { + // modulus 128 with overflow case -> 2 add back in quotient estimate div2by1. + BigInteger big_number = new BigInteger("800000000000000080", 16); + BigInteger big_modulus = new BigInteger("80", 16); + UInt256 number = UInt256.fromBytesBE(big_number.toByteArray()); + UInt256 modulus = UInt256.fromBytesBE(big_modulus.toByteArray()); + Bytes32 remainder = Bytes32.leftPad(Bytes.wrap(number.mod(modulus).toBytesBE())); + Bytes32 expected = Bytes32.leftPad(Bytes.wrap(big_number.mod(big_modulus).toByteArray())); + assertThat(remainder).isEqualTo(expected); + } + @Test public void modGeneralState() { BigInteger big_number = new BigInteger("cea0c5cc171fa61277e5604a3bc8aef4de3d3882", 16); @@ -268,6 +391,30 @@ public void referenceTest459() { assertThat(remainder).isEqualTo(expected); } + @Test + public void ExecutionSpecStateTest_453() { + byte[] xArr = + new byte[] { + -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, + -1, -1, -1, -1, -1, -1, -1, -1, -1, -2 + }; + byte[] mArr = + new byte[] { + -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, + -1, -1, -1, -1, -1, -1, -1, -1, -1, -1 + }; + BigInteger xbig = new BigInteger(1, xArr); + BigInteger ybig = new BigInteger(1, xArr); + BigInteger mbig = new BigInteger(1, mArr); + UInt256 x = UInt256.fromBytesBE(xArr); + UInt256 y = UInt256.fromBytesBE(xArr); + UInt256 m = UInt256.fromBytesBE(mArr); + Bytes32 remainder = Bytes32.leftPad(Bytes.wrap(x.addMod(y, m).toBytesBE())); + Bytes32 expected = + BigInteger.ZERO.compareTo(mbig) == 0 ? Bytes32.ZERO : bigIntTo32B(xbig.add(ybig).mod(mbig)); + assertThat(remainder).isEqualTo(expected); + } + @Test public void addMod() { final Random random = new Random(42); @@ -296,6 +443,79 @@ public void addMod() { } } + @Test + public void mulMod_Modulus256_mulSubOverflow() { + Bytes modBytes = + Bytes.fromHexString("0x0000000000000001000000000000000000000000000000000000000000000001"); + Bytes aBytes = + Bytes.fromHexString("0x0000000000000001000000000000000000000000000000000000000000000000"); + Bytes bBytes = + Bytes.fromHexString("0x0000000000000001000000000000000000000000000000000000000000000000"); + BigInteger aInt = new BigInteger(1, aBytes.toArrayUnsafe()); + BigInteger bInt = new BigInteger(1, bBytes.toArrayUnsafe()); + BigInteger mInt = new BigInteger(1, modBytes.toArrayUnsafe()); + UInt256 a = UInt256.fromBytesBE(aBytes.toArrayUnsafe()); + UInt256 b = UInt256.fromBytesBE(bBytes.toArrayUnsafe()); + UInt256 m = UInt256.fromBytesBE(modBytes.toArrayUnsafe()); + Bytes32 remainder = Bytes32.leftPad(Bytes.wrap(a.mulMod(b, m).toBytesBE())); + Bytes32 expected = bigIntTo32B(aInt.multiply(bInt).mod(mInt)); + assertThat(remainder).isEqualTo(expected); + } + + @Test + public void mulMod_ExecutionSpecStateTest_104() { + Bytes value0 = + Bytes.fromHexString("0xfffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffe"); + Bytes value1 = + Bytes.fromHexString("0xfffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffe"); + Bytes value2 = + Bytes.fromHexString("0xfffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffe"); + BigInteger aInt = new BigInteger(1, value0.toArrayUnsafe()); + BigInteger bInt = new BigInteger(1, value1.toArrayUnsafe()); + BigInteger cInt = new BigInteger(1, value2.toArrayUnsafe()); + UInt256 a = UInt256.fromBytesBE(value0.toArrayUnsafe()); + UInt256 b = UInt256.fromBytesBE(value1.toArrayUnsafe()); + UInt256 c = UInt256.fromBytesBE(value2.toArrayUnsafe()); + Bytes32 remainder = Bytes32.leftPad(Bytes.wrap(a.mulMod(b, c).toBytesBE())); + Bytes32 expected = bigIntTo32B(aInt.multiply(bInt).mod(cInt)); + assertThat(remainder).isEqualTo(expected); + } + + @Test + public void mulMod_ExecutionSpecStateTest_457() { + Bytes value0 = + Bytes.fromHexString("0x000000000000000000000000ffffffffffffffffffffffffffffffffffffffff"); + Bytes value1 = + Bytes.fromHexString("0xfffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffe"); + Bytes value2 = + Bytes.fromHexString("0x000000000000000000000000ffffffffffffffffffffffffffffffffffffffff"); + BigInteger aInt = new BigInteger(1, value0.toArrayUnsafe()); + BigInteger bInt = new BigInteger(1, value1.toArrayUnsafe()); + BigInteger cInt = new BigInteger(1, value2.toArrayUnsafe()); + UInt256 a = UInt256.fromBytesBE(value0.toArrayUnsafe()); + UInt256 b = UInt256.fromBytesBE(value1.toArrayUnsafe()); + UInt256 c = UInt256.fromBytesBE(value2.toArrayUnsafe()); + Bytes32 remainder = Bytes32.leftPad(Bytes.wrap(a.mulMod(b, c).toBytesBE())); + Bytes32 expected = bigIntTo32B(aInt.multiply(bInt).mod(cInt)); + assertThat(remainder).isEqualTo(expected); + + value0 = + Bytes.fromHexString("0xfffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffe"); + value1 = + Bytes.fromHexString("0xffffffffffffffffffffffffb195148ca348dc57a7331852b390ccefa7b0c18b"); + value2 = + Bytes.fromHexString("0xfffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffe"); + aInt = new BigInteger(1, value0.toArrayUnsafe()); + bInt = new BigInteger(1, value1.toArrayUnsafe()); + cInt = new BigInteger(1, value2.toArrayUnsafe()); + a = UInt256.fromBytesBE(value0.toArrayUnsafe()); + b = UInt256.fromBytesBE(value1.toArrayUnsafe()); + c = UInt256.fromBytesBE(value2.toArrayUnsafe()); + remainder = Bytes32.leftPad(Bytes.wrap(a.mulMod(b, c).toBytesBE())); + expected = bigIntTo32B(aInt.multiply(bInt).mod(cInt)); + assertThat(remainder).isEqualTo(expected); + } + @Test public void mulMod() { final Random random = new Random(123); @@ -324,37 +544,29 @@ public void mulMod() { } } - @Test - public void signedMod_no_padding() { - Bytes aBytes = - Bytes.fromHexString("0xe8e8e8e2000100000009ea02000000000000ff3ffffff80000001000220000"); - Bytes bBytes = - Bytes.fromHexString("0x8000000000000000000000000000000000000000000000000000000000000000"); - Bytes32 expected = - Bytes32.leftPad( - Bytes.fromHexString( - "0x00e8e8e8e2000100000009ea02000000000000ff3ffffff80000001000220000")); - UInt256 a = UInt256.fromBytesBE(aBytes.toArrayUnsafe()); - UInt256 b = UInt256.fromBytesBE(bBytes.toArrayUnsafe()); - Bytes32 remainder = Bytes32.leftPad(Bytes.wrap(a.signedMod(b).toBytesBE())); - assertThat(remainder).isEqualTo(expected); - } - @Test public void signedMod() { final Random random = new Random(432); for (int i = 0; i < SAMPLE_SIZE; i++) { int aSize = random.nextInt(1, 33); int bSize = random.nextInt(1, 33); + boolean neg = random.nextBoolean(); byte[] aArray = new byte[aSize]; byte[] bArray = new byte[bSize]; random.nextBytes(aArray); random.nextBytes(bArray); + if ((aSize < 32) && (neg)) { + byte[] tmp = new byte[32]; + Arrays.fill(tmp, (byte) 0xFF); + System.arraycopy(aArray, 0, tmp, 32 - aArray.length, aArray.length); + aArray = tmp; + } UInt256 a = UInt256.fromBytesBE(aArray); UInt256 b = UInt256.fromBytesBE(bArray); - BigInteger aInt = aArray.length < 32 ? new BigInteger(1, aArray) : new BigInteger(aArray); - BigInteger bInt = bArray.length < 32 ? new BigInteger(1, bArray) : new BigInteger(bArray); - Bytes32 remainder = Bytes32.leftPad(Bytes.wrap(a.signedMod(b).toBytesBE())); + UInt256 r = a.signedMod(b); + BigInteger aInt = a.isNegative() ? new BigInteger(aArray) : new BigInteger(1, aArray); + BigInteger bInt = b.isNegative() ? new BigInteger(bArray) : new BigInteger(1, bArray); + Bytes32 remainder = Bytes32.leftPad(Bytes.wrap(r.toBytesBE())); Bytes32 expected; BigInteger rem = BigInteger.ZERO; if (BigInteger.ZERO.compareTo(bInt) == 0) expected = Bytes32.ZERO; @@ -362,285 +574,12 @@ public void signedMod() { rem = aInt.abs().mod(bInt.abs()); if ((aInt.compareTo(BigInteger.ZERO) < 0) && (rem.compareTo(BigInteger.ZERO) != 0)) { rem = rem.negate(); - expected = bigIntToSigned32B(rem); + expected = bigIntTo32B(rem, -1); } else { - expected = bigIntTo32B(rem); + expected = bigIntTo32B(rem, 1); } } assertThat(remainder).isEqualTo(expected); } } - - @Test - void testFromBytesBE_emptyArray() { - UInt256 result = UInt256.fromBytesBE(new byte[0]); - assertThat(result).isEqualTo(UInt256.ZERO); - assertThat(result.isZero()).isTrue(); - } - - @Test - void testFromBytesBE_singleZeroByte() { - UInt256 result = UInt256.fromBytesBE(new byte[] {0}); - assertThat(result).isEqualTo(UInt256.ZERO); - assertThat(result.intValue()).isEqualTo(0); - } - - @Test - void testFromBytesBE_singleByte() { - UInt256 result = UInt256.fromBytesBE(new byte[] {0x42}); - assertThat(result.intValue()).isEqualTo(0x42); - assertThat(result.longValue()).isEqualTo(0x42L); - } - - @Test - void testFromBytesBE_twoBytesFF() { - UInt256 result = UInt256.fromBytesBE(new byte[] {(byte) 0xFF, (byte) 0xFF}); - assertThat(result.intValue()).isEqualTo(0xFFFF); - assertThat(result.longValue()).isEqualTo(0xFFFFL); - } - - @Test - void testFromBytesBE_fourBytes() { - UInt256 result = UInt256.fromBytesBE(new byte[] {0x01, 0x02, 0x03, 0x04}); - assertThat(result.intValue()).isEqualTo(0x01020304); - } - - @Test - void testFromBytesBE_eightBytes() { - UInt256 result = - UInt256.fromBytesBE(new byte[] {0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08}); - assertThat(result.longValue()).isEqualTo(0x0102030405060708L); - } - - @Test - void testFromBytesBE_exactly32Bytes_allZeros() { - byte[] bytes = new byte[32]; // all zeros - UInt256 result = UInt256.fromBytesBE(bytes); - assertThat(result).isEqualTo(UInt256.ZERO); - assertThat(result.isZero()).isTrue(); - } - - @Test - void testFromBytesBE_exactly32Bytes_allOnes() { - byte[] bytes = new byte[32]; - for (int i = 0; i < 32; i++) { - bytes[i] = (byte) 0xFF; - } - UInt256 result = UInt256.fromBytesBE(bytes); - - // Should be MAX_UINT256 (2^256 - 1) - byte[] resultBytes = result.toBytesBE(); - assertArrayEquals(bytes, resultBytes); - } - - @Test - void testFromBytesBE_exactly32Bytes_one() { - byte[] bytes = new byte[32]; - bytes[31] = 0x01; // least significant byte - UInt256 result = UInt256.fromBytesBE(bytes); - - assertThat(result.intValue()).isEqualTo(1); - assertThat(result.longValue()).isEqualTo(1L); - } - - @Test - void testFromBytesBE_exactly32Bytes_pattern() { - byte[] bytes = new byte[32]; - // Create pattern: 0x0102030405060708...1F20 - for (int i = 0; i < 32; i++) { - bytes[i] = (byte) (i + 1); - } - UInt256 result = UInt256.fromBytesBE(bytes); - - // Verify round-trip - byte[] resultBytes = result.toBytesBE(); - assertArrayEquals(bytes, resultBytes); - } - - @Test - void testFromBytesBE_exactly32Bytes_highBitSet() { - byte[] bytes = new byte[32]; - bytes[0] = (byte) 0x80; // high bit set (but still unsigned) - UInt256 result = UInt256.fromBytesBE(bytes); - - // Verify it's treated as unsigned (not negative) - byte[] resultBytes = result.toBytesBE(); - assertArrayEquals(bytes, resultBytes); - } - - @Test - void testFromBytesBE_roundTrip_variousLengths() { - for (int len = 1; len <= 32; len++) { - byte[] original = new byte[len]; - for (int i = 0; i < len; i++) { - original[i] = (byte) (i + 1); - } - - UInt256 value = UInt256.fromBytesBE(original); - byte[] result = value.toBytesBE(); - - // Result is always 32 bytes, so compare with left-padded original - byte[] expected = new byte[32]; - System.arraycopy(original, 0, expected, 32 - len, len); - - assertArrayEquals(expected, result, "Failed for length " + len); - } - } - - @Test - void testFromBytesBE_leadingZeros() { - // Leading zeros should be handled correctly - byte[] bytes = new byte[] {0x00, 0x00, 0x00, 0x01, 0x02, 0x03}; - UInt256 result = UInt256.fromBytesBE(bytes); - - assertThat(result.intValue()).isEqualTo(0x010203); - } - - @Test - void testFromBytesBE_maxInt() { - byte[] bytes = new byte[] {0x7F, (byte) 0xFF, (byte) 0xFF, (byte) 0xFF}; - UInt256 result = UInt256.fromBytesBE(bytes); - - assertThat(result.intValue()).isEqualTo(Integer.MAX_VALUE); - } - - @Test - void testFromBytesBE_maxLong() { - byte[] bytes = - new byte[] { - 0x7F, - (byte) 0xFF, - (byte) 0xFF, - (byte) 0xFF, - (byte) 0xFF, - (byte) 0xFF, - (byte) 0xFF, - (byte) 0xFF - }; - UInt256 result = UInt256.fromBytesBE(bytes); - - assertThat(result.longValue()).isEqualTo(Long.MAX_VALUE); - } - - @Test - void testFromBytesBE_unsignedIntMax() { - // 0xFFFFFFFF as unsigned = 4294967295 - byte[] bytes = new byte[] {(byte) 0xFF, (byte) 0xFF, (byte) 0xFF, (byte) 0xFF}; - UInt256 result = UInt256.fromBytesBE(bytes); - - assertThat(result.longValue()).isEqualTo(0xFFFFFFFFL); - } - - @Test - void testFromBytesBE_unsignedLongMax() { - // 0xFFFFFFFFFFFFFFFF as unsigned - byte[] bytes = - new byte[] { - (byte) 0xFF, (byte) 0xFF, (byte) 0xFF, (byte) 0xFF, - (byte) 0xFF, (byte) 0xFF, (byte) 0xFF, (byte) 0xFF - }; - UInt256 result = UInt256.fromBytesBE(bytes); - - // When converted back to long, should get the bit pattern - assertThat(result.longValue()).isEqualTo(-1L); // all bits set - } - - @Test - void testFromBytesBE_boundaryValues() { - // Test 1, 2, 3, 4, 8, 16, 32 bytes - int[] lengths = {1, 2, 3, 4, 8, 16, 32}; - - for (int len : lengths) { - byte[] bytes = new byte[len]; - bytes[len - 1] = (byte) 0xFF; // set last byte - - UInt256 result = UInt256.fromBytesBE(bytes); - assertThat(result.intValue() & 0xFF).isEqualTo(0xFF); - } - } - - @Test - void testFromBytesBE_comparisonWithBigInteger() { - byte[] bytes = - new byte[] {0x12, 0x34, 0x56, 0x78, (byte) 0x9A, (byte) 0xBC, (byte) 0xDE, (byte) 0xF0}; - - UInt256 result = UInt256.fromBytesBE(bytes); - java.math.BigInteger expected = new java.math.BigInteger(1, bytes); - - assertThat(result.toBigInteger()).isEqualTo(expected); - } - - @ParameterizedTest - @ValueSource(ints = {0, 1, 127, 128, 255, 256, 65535, 65536, Integer.MAX_VALUE}) - void testFromBytesBE_knownIntegers(final int value) { - // Convert int to bytes (big-endian) - byte[] bytes = new byte[4]; - bytes[0] = (byte) (value >>> 24); - bytes[1] = (byte) (value >>> 16); - bytes[2] = (byte) (value >>> 8); - bytes[3] = (byte) value; - - UInt256 result = UInt256.fromBytesBE(bytes); - assertThat(result.intValue()).isEqualTo(value); - } - - @Test - void testFromBytesBE_powerOfTwo() { - // Test 2^8, 2^16, 2^32, 2^64, 2^128, 2^255 - - // 2^8 = 256 - byte[] bytes8 = new byte[] {0x01, 0x00}; - assertThat(UInt256.fromBytesBE(bytes8).intValue()).isEqualTo(256); - - // 2^16 = 65536 - byte[] bytes16 = new byte[] {0x01, 0x00, 0x00}; - assertThat(UInt256.fromBytesBE(bytes16).intValue()).isEqualTo(65536); - - // 2^32 - byte[] bytes32 = new byte[] {0x01, 0x00, 0x00, 0x00, 0x00}; - assertThat(UInt256.fromBytesBE(bytes32).longValue()).isEqualTo(0x100000000L); - } - - @Test - void testFromBytesBE_alternatingPattern() { - // 0xAA pattern - byte[] bytesAA = new byte[32]; - for (int i = 0; i < 32; i++) { - bytesAA[i] = (byte) 0xAA; - } - UInt256 resultAA = UInt256.fromBytesBE(bytesAA); - assertArrayEquals(bytesAA, resultAA.toBytesBE()); - - // 0x55 pattern - byte[] bytes55 = new byte[32]; - for (int i = 0; i < 32; i++) { - bytes55[i] = (byte) 0x55; - } - UInt256 result55 = UInt256.fromBytesBE(bytes55); - assertArrayEquals(bytes55, result55.toBytesBE()); - } - - @Test - void testFromBytesBE_consistency() { - // Verify same bytes always produce same result - byte[] bytes = new byte[] {0x01, 0x02, 0x03, 0x04, 0x05}; - - UInt256 result1 = UInt256.fromBytesBE(bytes); - UInt256 result2 = UInt256.fromBytesBE(bytes); - - assertThat(result1).isEqualTo(result2); - assertThat(result1.hashCode()).isEqualTo(result2.hashCode()); - } - - @Test - void testFromBytesBE_differentLengthsSameValue() { - // Leading zeros should not affect value - byte[] bytes1 = new byte[] {0x01, 0x02, 0x03}; - byte[] bytes2 = new byte[] {0x00, 0x00, 0x01, 0x02, 0x03}; - - UInt256 result1 = UInt256.fromBytesBE(bytes1); - UInt256 result2 = UInt256.fromBytesBE(bytes2); - - assertThat(result1).isEqualTo(result2); - } }