Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ public static void Multiply(ReadOnlySpan<uint> left, uint right, Span<uint> bits
int i = 0;
ulong carry = 0UL;

for ( ; i < left.Length; i++)
for (; i < left.Length; i++)
{
ulong digits = (ulong)left[i] * right + carry;
bits[i] = unchecked((uint)digits);
Expand All @@ -151,9 +151,9 @@ public static void Multiply(ReadOnlySpan<uint> left, uint right, Span<uint> bits

#if DEBUG
// Mutable for unit testing...
private static
internal static
#else
private const
internal const
#endif
int MultiplyThreshold = 32;

Expand All @@ -162,6 +162,216 @@ public static void Multiply(ReadOnlySpan<uint> left, ReadOnlySpan<uint> right, S
Debug.Assert(left.Length >= right.Length);
Debug.Assert(bits.Length == left.Length + right.Length);

if (left.Length - right.Length < 3)
{
MultiplyNearLength(left, right, bits);
}
else
{
MultiplyFarLength(left, right, bits);
}
}

private static void MultiplyFarLength(ReadOnlySpan<uint> left, ReadOnlySpan<uint> right, Span<uint> bits)
{
Debug.Assert(left.Length - right.Length >= 3);
Debug.Assert(bits.Length == left.Length + right.Length);

// Executes different algorithms for computing z = a * b
// based on the actual length of b. If b is "small" enough
// we stick to the classic "grammar-school" method; for the
// rest we switch to implementations with less complexity
// albeit more overhead (which needs to pay off!).

// NOTE: useful thresholds needs some "empirical" testing,
// which are smaller in DEBUG mode for testing purpose.

if (right.Length < MultiplyThreshold)
{
// Switching to managed references helps eliminating
// index bounds check...
ref uint resultPtr = ref MemoryMarshal.GetReference(bits);

// Multiplies the bits using the "grammar-school" method.
// Envisioning the "rhombus" of a pen-and-paper calculation
// should help getting the idea of these two loops...
// The inner multiplication operations are safe, because
// z_i+j + a_j * b_i + c <= 2(2^32 - 1) + (2^32 - 1)^2 =
// = 2^64 - 1 (which perfectly matches with ulong!).

for (int i = 0; i < right.Length; i++)
{
ulong carry = 0UL;
for (int j = 0; j < left.Length; j++)
{
ref uint elementPtr = ref Unsafe.Add(ref resultPtr, i + j);
ulong digits = elementPtr + carry + (ulong)left[j] * right[i];
Comment on lines +204 to +208
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: we could most likely get a minor improvement by hoisting right[i] out of the inner loop (though I am not 100% sure the JIT does not already perform this optimization)

Suggested change
ulong carry = 0UL;
for (int j = 0; j < left.Length; j++)
{
ref uint elementPtr = ref Unsafe.Add(ref resultPtr, i + j);
ulong digits = elementPtr + carry + (ulong)left[j] * right[i];
ulong carry = 0UL;
uint right_i = right[i];
for (int j = 0; j < left.Length; j++)
{
ref uint elementPtr = ref Unsafe.Add(ref resultPtr, i + j);
ulong digits = elementPtr + carry + (ulong)left[j] * right_i;

elementPtr = unchecked((uint)digits);
carry = digits >> 32;
}
Unsafe.Add(ref resultPtr, i + left.Length) = (uint)carry;
}
}
else
{
// Based on the Toom-Cook multiplication we split left/right
// into two smaller values, doing recursive multiplication.
// The special form of this multiplication, where we
// split both operands into two operands, is also known
// as the Karatsuba algorithm...

// https://en.wikipedia.org/wiki/Toom-Cook_multiplication
// https://en.wikipedia.org/wiki/Karatsuba_algorithm

// Say we want to compute z = a * b ...

// ... we need to determine our new length (just the half)
int n = left.Length >> 1;
if (right.Length <= n + 1)
{
// ... split left like a = (a_1 << n) + a_0
ReadOnlySpan<uint> leftLow = left.Slice(0, n);
ReadOnlySpan<uint> leftHigh = left.Slice(n);

// ... split right like b = (b_1 << n) + b_0
ReadOnlySpan<uint> rightLow;
uint rightHigh;
if (n < right.Length)
{
Debug.Assert(right.Length == n + 1);
rightLow = right.Slice(0, n);
rightHigh = right[n];
}
else
{
rightLow = right;
rightHigh = 0;
}

// ... prepare our result array (to reuse its memory)
Span<uint> bitsLow = bits.Slice(0, n + rightLow.Length);
Span<uint> bitsHigh = bits.Slice(n);

int carryLength = rightLow.Length;
uint[]? carryFromPool = null;
Span<uint> carry = ((uint)carryLength <= StackAllocThreshold ?
stackalloc uint[StackAllocThreshold]
: carryFromPool = ArrayPool<uint>.Shared.Rent(carryLength)).Slice(0, carryLength);

// ... compute low
Multiply(leftLow, rightLow, bitsLow);
Span<uint> carryOrig = bits.Slice(n, rightLow.Length);
carryOrig.CopyTo(carry);
carryOrig.Clear();

if (rightHigh != 0)
{
// ... compute high
MultiplyNearLength(leftHigh, rightLow, bitsHigh.Slice(0, leftHigh.Length + n));

int upperRightLength = left.Length + 1;
uint[]? upperRightFromPool = null;
Span<uint> upperRight = ((uint)upperRightLength <= StackAllocThreshold ?
stackalloc uint[StackAllocThreshold]
: upperRightFromPool = ArrayPool<uint>.Shared.Rent(upperRightLength)).Slice(0, upperRightLength);
upperRight.Clear();

Multiply(left, rightHigh, upperRight);
Comment on lines +277 to +279
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Multiply does not read upperRight as an input; it overwrites every element from index 0 through left.Length:

for ( ; i < left.Length; i++)
{
ulong digits = (ulong)left[i] * right + carry;
bits[i] = unchecked((uint)digits);
carry = digits >> 32;
}
bits[i] = (uint)carry;

So we can reduce the clear to only the last element (this span has left.Length + 1 elements).

Suggested change
upperRight.Clear();
Multiply(left, rightHigh, upperRight);
// Multiply has set 0..left.Length elements, the size is left.Length+1
// We need to zero the last element to make sure it does not contain any garbage.
Multiply(left, rightHigh, upperRight);
upperRight[^1] = 0;


AddSelf(bitsHigh, upperRight);

if (upperRightFromPool != null)
ArrayPool<uint>.Shared.Return(upperRightFromPool);
}
else
{
// ... compute high
Multiply(leftHigh, rightLow, bitsHigh);
}

AddSelf(bitsHigh, carry);

if (carryFromPool != null)
ArrayPool<uint>.Shared.Return(carryFromPool);
}
else
{
int n2 = n << 1;

Debug.Assert(left.Length > right.Length);

// ... split left like a = (a_1 << n) + a_0
ReadOnlySpan<uint> leftLow = left.Slice(0, n);
ReadOnlySpan<uint> leftHigh = left.Slice(n);

// ... split right like b = (b_1 << n) + b_0
ReadOnlySpan<uint> rightLow = right.Slice(0, n);
ReadOnlySpan<uint> rightHigh = right.Slice(n);

// ... prepare our result array (to reuse its memory)
Span<uint> bitsLow = bits.Slice(0, n2);
Span<uint> bitsHigh = bits.Slice(n2);

// ... compute z_0 = a_0 * b_0 (multiply again)
MultiplyNearLength(rightLow, leftLow, bitsLow);

// ... compute z_2 = a_1 * b_1 (multiply again)
MultiplyFarLength(leftHigh, rightHigh, bitsHigh);

int leftFoldLength = leftHigh.Length + 1;
uint[]? leftFoldFromPool = null;
Span<uint> leftFold = ((uint)leftFoldLength <= StackAllocThreshold ?
stackalloc uint[StackAllocThreshold]
: leftFoldFromPool = ArrayPool<uint>.Shared.Rent(leftFoldLength)).Slice(0, leftFoldLength);
leftFold.Clear();

int rightFoldLength = n + 1;
uint[]? rightFoldFromPool = null;
Span<uint> rightFold = ((uint)rightFoldLength <= StackAllocThreshold ?
stackalloc uint[StackAllocThreshold]
: rightFoldFromPool = ArrayPool<uint>.Shared.Rent(rightFoldLength)).Slice(0, rightFoldLength);
rightFold.Clear();

int coreLength = leftFoldLength + rightFoldLength;
uint[]? coreFromPool = null;
Span<uint> core = ((uint)coreLength <= StackAllocThreshold ?
stackalloc uint[StackAllocThreshold]
: coreFromPool = ArrayPool<uint>.Shared.Rent(coreLength)).Slice(0, coreLength);
core.Clear();

Debug.Assert(bits.Length - n >= core.Length);
Debug.Assert(rightLow.Length >= rightHigh.Length);

// ... compute z_a = a_1 + a_0 (call it fold...)
Add(leftHigh, leftLow, leftFold);

// ... compute z_b = b_1 + b_0 (call it fold...)
Add(rightLow, rightHigh, rightFold);

// ... compute z_1 = z_a * z_b - z_0 - z_2
MultiplyNearLength(leftFold, rightFold, core);

if (leftFoldFromPool != null)
ArrayPool<uint>.Shared.Return(leftFoldFromPool);

if (rightFoldFromPool != null)
ArrayPool<uint>.Shared.Return(rightFoldFromPool);

SubtractCore(bitsLow, bitsHigh, core);

// ... and finally merge the result! :-)
AddSelf(bits.Slice(n), core);

if (coreFromPool != null)
ArrayPool<uint>.Shared.Return(coreFromPool);
}
}
}
private static void MultiplyNearLength(ReadOnlySpan<uint> left, ReadOnlySpan<uint> right, Span<uint> bits)
{
Debug.Assert(left.Length - right.Length < 3);
Debug.Assert(bits.Length == left.Length + right.Length);

// Executes different algorithms for computing z = a * b
// based on the actual length of b. If b is "small" enough
// we stick to the classic "grammar-school" method; for the
Expand Down Expand Up @@ -227,10 +437,10 @@ public static void Multiply(ReadOnlySpan<uint> left, ReadOnlySpan<uint> right, S
Span<uint> bitsHigh = bits.Slice(n2);

// ... compute z_0 = a_0 * b_0 (multiply again)
Multiply(leftLow, rightLow, bitsLow);
MultiplyNearLength(leftLow, rightLow, bitsLow);

// ... compute z_2 = a_1 * b_1 (multiply again)
Multiply(leftHigh, rightHigh, bitsHigh);
MultiplyNearLength(leftHigh, rightHigh, bitsHigh);

int leftFoldLength = leftHigh.Length + 1;
uint[]? leftFoldFromPool = null;
Expand Down Expand Up @@ -260,7 +470,7 @@ stackalloc uint[StackAllocThreshold]
Add(rightHigh, rightLow, rightFold);

// ... compute z_1 = z_a * z_b - z_0 - z_2
Multiply(leftFold, rightFold, core);
MultiplyNearLength(leftFold, rightFold, core);

if (leftFoldFromPool != null)
ArrayPool<uint>.Shared.Return(leftFoldFromPool);
Expand Down Expand Up @@ -298,21 +508,21 @@ private static void SubtractCore(ReadOnlySpan<uint> left, ReadOnlySpan<uint> rig
ref uint leftPtr = ref MemoryMarshal.GetReference(left);
ref uint corePtr = ref MemoryMarshal.GetReference(core);

for ( ; i < right.Length; i++)
for (; i < right.Length; i++)
{
long digit = (Unsafe.Add(ref corePtr, i) + carry) - Unsafe.Add(ref leftPtr, i) - right[i];
Unsafe.Add(ref corePtr, i) = unchecked((uint)digit);
carry = digit >> 32;
}

for ( ; i < left.Length; i++)
for (; i < left.Length; i++)
{
long digit = (Unsafe.Add(ref corePtr, i) + carry) - left[i];
Unsafe.Add(ref corePtr, i) = unchecked((uint)digit);
carry = digit >> 32;
}

for ( ; carry != 0 && i < core.Length; i++)
for (; carry != 0 && i < core.Length; i++)
{
long digit = core[i] + carry;
core[i] = (uint)digit;
Expand Down
22 changes: 22 additions & 0 deletions src/libraries/System.Runtime.Numerics/tests/BigInteger/multiply.cs
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,28 @@ public static void RunMultiply_Boundary()
VerifyMultiplyString(Math.Pow(2, 33) + " 2 bMultiply");
}

[Fact]
public static void RunMultiplyKaratsubaBoundary()
{
    // Generator seeded from s_seed; every operand below is drawn from it in order.
    Random rng = new Random(s_seed);

    // Multiply Method - operand sizes clustered around the multiplication threshold:
    // the first operand's size argument stays within 2 of MultiplyThreshold, the
    // second within 4 of (MultiplyThreshold + 1) * 2.
    for (int sample = 0; sample < s_samples; sample++)
    {
        for (int shortDelta = -2; shortDelta <= 2; shortDelta++)
        {
            // One first operand is shared by all second operands of the inner loop.
            byte[] shortOperand = GetRandomByteArray(rng, BigIntegerCalculator.MultiplyThreshold + shortDelta);

            for (int longDelta = -4; longDelta <= 4; longDelta++)
            {
                byte[] longOperand = GetRandomByteArray(rng, (BigIntegerCalculator.MultiplyThreshold + 1) * 2 + longDelta);
                VerifyMultiplyString(Print(shortOperand) + Print(longOperand) + "bMultiply");
            }
        }
    }
}

[Fact]
public static void RunMultiply_OnePositiveOneNegative()
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,27 @@ public static void RunMultiplyBoundary()
VerifyMultiplyString(Math.Pow(2, 33) + " 2 b*");
}

[Fact]
public static void RunMultiplyKaratsubaBoundary()
{
    // Multiply Method - operand sizes clustered around the multiplication threshold:
    // the first operand's size argument stays within 2 of MultiplyThreshold, the
    // second within 4 of (MultiplyThreshold + 1) * 2. All operands are drawn from
    // the shared s_random generator in order.
    for (int sample = 0; sample < s_samples; sample++)
    {
        for (int shortDelta = -2; shortDelta <= 2; shortDelta++)
        {
            // One first operand is shared by all second operands of the inner loop.
            byte[] shortOperand = GetRandomByteArray(s_random, BigIntegerCalculator.MultiplyThreshold + shortDelta);

            for (int longDelta = -4; longDelta <= 4; longDelta++)
            {
                byte[] longOperand = GetRandomByteArray(s_random, (BigIntegerCalculator.MultiplyThreshold + 1) * 2 + longDelta);
                VerifyMultiplyString(Print(shortOperand) + Print(longOperand) + "b*");
            }
        }
    }
}

[Fact]
public static void RunMultiplyTests()
{
Expand Down