Skip to content

Commit

Permalink
Improve Random{NumberGenerator}.GetItems/String for non-power of 2 ch…
Browse files Browse the repository at this point in the history
…oices

In .NET 9, we added an optimization to Random.GetItems and RandomNumberGenerator.GetItems/GetString that special-cases a power-of-2 number of choices that's <= 256. In such a case, we can avoid many trips to the RNG by requesting bytes in bulk, rather than requesting an Int32 per element. Each byte is masked to produce the index into the choices.

This PR extends that optimization to also cover non-power-of-2 choices. It can't just mask off the bits as in the power-of-2 case, but it can mask off some bits and then do rejection sampling, which on average still yields big wins.
  • Loading branch information
stephentoub committed Sep 18, 2024
1 parent 24e7d1b commit 547df92
Show file tree
Hide file tree
Showing 2 changed files with 140 additions and 47 deletions.
95 changes: 71 additions & 24 deletions src/libraries/System.Private.CoreLib/src/System/Random.cs
Original file line number Diff line number Diff line change
Expand Up @@ -197,41 +197,88 @@ public void GetItems<T>(ReadOnlySpan<T> choices, Span<T> destination)
throw new ArgumentException(SR.Arg_EmptySpan, nameof(choices));
}

// The most expensive part of this operation is the call to get random data. We can
// do so potentially many fewer times if:
// - the number of choices is <= 256. This let's us get a single byte per choice.
// - the number of choices is a power of two. This let's us use a byte and simply mask off
// unnecessary bits cheaply rather than needing to use rejection sampling.
// In such a case, we can grab a bunch of random bytes in one call.
if (BitOperations.IsPow2(choices.Length) && choices.Length <= 256)
// The most expensive part of this operation is the call to get random data. If the number of
// choices is <= 256 (which is the majority use case), we can use a single byte per element,
// which means we can ammortize the cost of getting random data by getting random bytes in bulk.
if (choices.Length <= 256)
{
Span<byte> randomBytes = stackalloc byte[512]; // arbitrary size, a balance between stack consumed and number of random calls required
while (!destination.IsEmpty)
// Get stack space to store random bytes. This size was chosen to balance between
// stack consumed and number of random calls required.
Span<byte> randomBytes = stackalloc byte[512];

if (BitOperations.IsPow2(choices.Length))
{
if (destination.Length < randomBytes.Length)
// To avoid bias, we can't just % all bytes to get them into range; that would cause
// the lower values to be more likely than the higher values. If the number of choices
// is a power of 2, though, we can just mask off the extraneous bits.

int mask = choices.Length - 1;

while (!destination.IsEmpty)
{
randomBytes = randomBytes.Slice(0, destination.Length);
// If this will be the last iteration, avoid over-requesting randomness.
if (destination.Length < randomBytes.Length)
{
randomBytes = randomBytes.Slice(0, destination.Length);
}

NextBytes(randomBytes);

for (int i = 0; i < randomBytes.Length; i++)
{
destination[i] = choices[randomBytes[i] & mask];
}

destination = destination.Slice(randomBytes.Length);
}
}
else
{
// As the length isn't a power of two, we can't just mask off all extraneous bits, and
// instead need to do rejection sampling. However, we can mask off the irrelevant bits, which
// then reduces the chances of needing to reject a value.

NextBytes(randomBytes);
int mask = (int)BitOperations.RoundUpToPowerOf2((uint)choices.Length) - 1;

int mask = choices.Length - 1;
for (int i = 0; i < randomBytes.Length; i++)
while (!destination.IsEmpty)
{
destination[i] = choices[randomBytes[i] & mask];
// Unlike in the IsPow2 case, where every byte will be used, some bytes here may
// be rejected. On average, half the bytes may be rejected, so we heuristically
// choose to shrink to twice the destination length.
if (destination.Length * 2 < randomBytes.Length)
{
randomBytes = randomBytes.Slice(0, destination.Length * 2);
}

NextBytes(randomBytes);

int i = 0;
foreach (byte b in randomBytes)
{
if ((uint)i >= (uint)destination.Length)
{
break;
}

byte masked = (byte)(b & mask);
if (masked < (uint)choices.Length)
{
destination[i++] = choices[masked];
}
}

destination = destination.Slice(i);
}

destination = destination.Slice(randomBytes.Length);
}

return;
}

// Simple fallback: get each item individually, generating a new random Int32 for each
// item. This is slower than the above, but it works for all types and sizes of choices.
for (int i = 0; i < destination.Length; i++)
else
{
destination[i] = choices[Next(choices.Length)];
// Simple fallback: get each item individually, generating a new random Int32 for each
// item. This is slower than the above, but it works for all types and sizes of choices.
for (int i = 0; i < destination.Length; i++)
{
destination[i] = choices[Next(choices.Length)];
}
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -346,44 +346,90 @@ private static void GetHexStringCore(Span<char> destination, bool lowercase)

private static void GetItemsCore<T>(ReadOnlySpan<T> choices, Span<T> destination)
{
// The most expensive part of this operation is the call to get random data. We can
// do so potentially many fewer times if:
// - the number of choices is <= 256. This let's us get a single byte per choice.
// - the number of choices is a power of two. This let's us use a byte and simply mask off
// unnecessary bits cheaply rather than needing to use rejection sampling.
// In such a case, we can grab a bunch of random bytes in one call.
if (BitOperations.IsPow2(choices.Length) && choices.Length <= 256)
Debug.Assert(choices.Length > 0);

// The most expensive part of this operation is the call to get random data. If the number of
// choices is <= 256 (which is the majority use case), we can use a single byte per element,
// which means we can ammortize the cost of getting random data by getting random bytes in bulk.
if (choices.Length <= 256)
{
// Get stack space to store random bytes. This size was chosen to balance between
// stack consumed and number of random calls required.
Span<byte> randomBytes = stackalloc byte[512];

while (!destination.IsEmpty)
if (BitOperations.IsPow2(choices.Length))
{
if (destination.Length < randomBytes.Length)
// To avoid bias, we can't just % all bytes to get them into range; that would cause
// the lower values to be more likely than the higher values. If the number of choices
// is a power of 2, though, we can just mask off the extraneous bits.

int mask = choices.Length - 1;

while (!destination.IsEmpty)
{
randomBytes = randomBytes.Slice(0, destination.Length);
// If this will be the last iteration, avoid over-requesting randomness.
if (destination.Length < randomBytes.Length)
{
randomBytes = randomBytes.Slice(0, destination.Length);
}

RandomNumberGeneratorImplementation.FillSpan(randomBytes);

for (int i = 0; i < randomBytes.Length; i++)
{
destination[i] = choices[randomBytes[i] & mask];
}

destination = destination.Slice(randomBytes.Length);
}
}
else
{
// As the length isn't a power of two, we can't just mask off all extraneous bits, and
// instead need to do rejection sampling. However, we can mask off the irrelevant bits, which
// then reduces the chances of needing to reject a value.

RandomNumberGeneratorImplementation.FillSpan(randomBytes);
int mask = (int)BitOperations.RoundUpToPowerOf2((uint)choices.Length) - 1;

int mask = choices.Length - 1;
for (int i = 0; i < randomBytes.Length; i++)
while (!destination.IsEmpty)
{
destination[i] = choices[randomBytes[i] & mask];
// Unlike in the IsPow2 case, where every byte will be used, some bytes here may
// be rejected. On average, half the bytes may be rejected, so we heuristically
// choose to shrink to twice the destination length.
if (destination.Length * 2 < randomBytes.Length)
{
randomBytes = randomBytes.Slice(0, destination.Length * 2);
}

RandomNumberGeneratorImplementation.FillSpan(randomBytes);

int i = 0;
foreach (byte b in randomBytes)
{
if ((uint)i >= (uint)destination.Length)
{
break;
}

byte masked = (byte)(b & mask);
if (masked < (uint)choices.Length)
{
destination[i++] = choices[masked];
}
}

destination = destination.Slice(i);
}

destination = destination.Slice(randomBytes.Length);
}

return;
}

// Simple fallback: get each item individually, generating a new random Int32 for each
// item. This is slower than the above, but it works for all types and sizes of choices.
for (int i = 0; i < destination.Length; i++)
else
{
destination[i] = choices[GetInt32(choices.Length)];
// Simple fallback: get each item individually, generating a new random Int32 for each
// item. This is slower than the above, but it works for all types and sizes of choices.
for (int i = 0; i < destination.Length; i++)
{
destination[i] = choices[GetInt32(choices.Length)];
}
}
}

Expand Down

0 comments on commit 547df92

Please sign in to comment.