-
Notifications
You must be signed in to change notification settings - Fork 4.7k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Vectorized MemoryExtensions.CommonPrefixLength #68210
Vectorized MemoryExtensions.CommonPrefixLength #68210
Conversation
Tagging subscribers to this area: @dotnet/area-system-memory Issue DetailsFunctional implementation of the new API got added in #67929
Note: I tried different unroll-levels, 4 gave best results. Benchmarks aren't in https://github.com/dotnet/performance Benchmark codeusing System.Diagnostics;
using System.Numerics;
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;
using System.Runtime.Intrinsics;
using BenchmarkDotNet.Attributes;
namespace ConsoleApp3;
//[ShortRunJob]
//[DisassemblyDiagnoser]
public class CommonPrefixLengthBenchmark
{
public static void Run()
{
CommonPrefixLengthBenchmark bench = new();
bench.Setup();
Console.WriteLine(bench.Default());
Console.WriteLine(bench.PR_Unroll4());
}
public record Args(int Length, int IndexOfMatch)
{
public override string ToString() => $"({this.Length}, {this.IndexOfMatch})";
}
public static IEnumerable<Args> Arguments()
{
yield return new Args(3, 2);
yield return new Args(10, 5);
yield return new Args(10, 9);
yield return new Args(20, 13);
yield return new Args(100, 16);
yield return new Args(100, 99);
}
[ParamsSource(nameof(Arguments))]
public Args LengthAndIndexOfMatch { get; set; } = new Args(3, 2);
private byte[]? _arr0;
private byte[]? _arr1;
[GlobalSetup]
public void Setup()
{
_arr0 = new byte[this.LengthAndIndexOfMatch.Length];
Random.Shared.NextBytes(_arr0);
_arr1 = _arr0.AsSpan().ToArray();
_arr1[this.LengthAndIndexOfMatch.IndexOfMatch] = 0xff;
}
[Benchmark(Baseline = true)]
public int Default() => SpanHelpers.CommonPrefixLength<byte>(_arr0, _arr1);
[Benchmark]
public int PR_Unroll4() => SpanHelpers.CommonPrefixLength_PRUnroll4<byte>(_arr0, _arr1);
}
public static class SpanHelpers
{
public static int CommonPrefixLength<T>(this ReadOnlySpan<T> span, ReadOnlySpan<T> other)
{
// Shrink one of the spans if necessary to ensure they're both the same length. We can then iterate until
// the Length of one of them and at least have bounds checks removed from that one.
if (other.Length > span.Length)
{
other = other.Slice(0, span.Length);
}
else if (span.Length > other.Length)
{
span = span.Slice(0, other.Length);
}
Debug.Assert(span.Length == other.Length);
// Find the first element pairwise that is not equal, and return its index as the length
// of the sequence before it that matches.
for (int i = 0; i < span.Length; i++)
{
if (!EqualityComparer<T>.Default.Equals(span[i], other[i]))
{
return i;
}
}
return span.Length;
}
public static int CommonPrefixLength_PRUnroll4<T>(this ReadOnlySpan<T> span, ReadOnlySpan<T> other)
{
if (typeof(T) == typeof(byte))
{
nuint length = Math.Min((nuint)(uint)span.Length, (nuint)(uint)other.Length);
nuint size = (nuint)Unsafe.SizeOf<T>();
return (int)CommonPrefixLengthUnroll4(
ref Unsafe.As<T, byte>(ref MemoryMarshal.GetReference(span)),
ref Unsafe.As<T, byte>(ref MemoryMarshal.GetReference(other)),
length * size);
}
// Shrink one of the spans if necessary to ensure they're both the same length. We can then iterate until
// the Length of one of them and at least have bounds checks removed from that one.
if (other.Length > span.Length)
{
other = other.Slice(0, span.Length);
}
else if (span.Length > other.Length)
{
span = span.Slice(0, other.Length);
}
Debug.Assert(span.Length == other.Length);
// Find the first element pairwise that is not equal, and return its index as the length
// of the sequence before it that matches.
for (int i = 0; i < span.Length; i++)
{
if (!EqualityComparer<T>.Default.Equals(span[i], other[i]))
{
return i;
}
}
return span.Length;
}
[MethodImpl(MethodImplOptions.AggressiveInlining)]
private static nuint CommonPrefixLengthUnroll4(ref byte first, ref byte second, nuint length)
{
nuint i;
if (!Vector128.IsHardwareAccelerated || length < (nuint)Vector128<byte>.Count)
{
// To have kind of fast path for small inputs, we handle as much elements needed
// so that either we are done or can use the unrolled loop below.
i = length % 4;
if (i > 0)
{
if (first != second)
{
return 0;
}
if (i > 1)
{
if (Unsafe.Add(ref first, 1) != Unsafe.Add(ref second, 1))
{
return 1;
}
if (i > 2 && Unsafe.Add(ref first, 2) != Unsafe.Add(ref second, 2))
{
return 2;
}
}
}
for (; (nint)i <= (nint)length - 4; i += 4)
{
if (Unsafe.Add(ref first, i + 0) != Unsafe.Add(ref second, i + 0)) return i + 0;
if (Unsafe.Add(ref first, i + 1) != Unsafe.Add(ref second, i + 1)) return i + 1;
if (Unsafe.Add(ref first, i + 2) != Unsafe.Add(ref second, i + 2)) return i + 2;
if (Unsafe.Add(ref first, i + 3) != Unsafe.Add(ref second, i + 3)) return i + 3;
}
return length;
}
Debug.Assert(length >= (uint)Vector128<byte>.Count);
int mask;
nuint lengthToExamine = length - (nuint)Vector128<byte>.Count;
Vector128<byte> firstVec;
Vector128<byte> secondVec;
Vector128<byte> maskVec;
i = 0;
if (lengthToExamine != 0)
{
do
{
firstVec = Vector128.LoadUnsafe(ref first, i);
secondVec = Vector128.LoadUnsafe(ref second, i);
maskVec = Vector128.Equals(firstVec, secondVec);
mask = (int)maskVec.ExtractMostSignificantBits();
if (mask != 0xFFFF)
{
goto Found;
}
i += (nuint)Vector128<byte>.Count;
} while (i < lengthToExamine);
}
// Do final compare as Vector128<byte>.Count from end rather than start
i = lengthToExamine;
firstVec = Vector128.LoadUnsafe(ref first, i);
secondVec = Vector128.LoadUnsafe(ref second, i);
maskVec = Vector128.Equals(firstVec, secondVec);
mask = (int)maskVec.ExtractMostSignificantBits();
if (mask != 0xFFFF)
{
goto Found;
}
return length;
Found:
mask = ~mask;
return i + (uint)BitOperations.TrailingZeroCount(mask);
}
} I tried forwarding @stephentoub: in regards to #67929 (comment) I know you favored #67942, but the PR for the CommonPrefixLength is merged, and I had already similar code lying around.
|
ref Unsafe.As<T, byte>(ref MemoryMarshal.GetReference(other)), | ||
length * size); | ||
|
||
return (int)(index / size); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The use of byte and then dividing to get rid of partials is clever... but deserves a big comment :)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Too verbose now?
Thanks. This LGTM, but @tannergooding should review as well.
It's too new; last I checked, the performance repo hadn't ingested a runtime with these APIs. If it now has them, it'd be worth adding a few. |
@@ -2033,6 +2033,18 @@ public static void Sort<T>(this Span<T> span, Comparison<T> comparison) | |||
/// <returns>The length of the common prefix shared by the two spans. If there's no shared prefix, 0 is returned.</returns> | |||
public static int CommonPrefixLength<T>(this ReadOnlySpan<T> span, ReadOnlySpan<T> other) | |||
{ | |||
if (RuntimeHelpers.IsBitwiseEquatable<T>()) | |||
{ | |||
nuint length = Math.Min((nuint)(uint)span.Length, (nuint)(uint)other.Length); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I was going to suggest nuint.Min
but then I remembered that IntPtr
and UIntPtr
still have them explicitly implemented until the work around dotnet/csharplang#6031 goes in.
} | ||
|
||
i += (nuint)Vector128<byte>.Count; | ||
} while (i < lengthToExamine); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: using while (i < lengthToExamine) { }
would cover the if check and reduce nesting
nuint lengthToExamine = length - (nuint)Vector128<byte>.Count; | ||
|
||
Vector128<byte> firstVec; | ||
Vector128<byte> secondVec; |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: Remove this and do Vector128.Equals(firstVector, Vector128.LoadUnsafe(ref second, i))
as the JIT tends to optimize the load better in this scenario
-- It does still fold the load most of the time, but there are cases where storing to a local can make it "worse"
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
as the JIT tends to optimize the load better in this scenario
Good to know, thanks!
Checked codegen, no difference.
firstVec = Vector128.LoadUnsafe(ref first, i); | ||
secondVec = Vector128.LoadUnsafe(ref second, i); | ||
maskVec = Vector128.Equals(firstVec, secondVec); | ||
mask = (int)maskVec.ExtractMostSignificantBits(); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why convert to int
here and below? TrailingZeroCount
has an overload that takes uint
and i
is uint
(and its also available as uint.TrailingZeroCount
now)
So keeping it as uint
just simplifies things and reduces code complexity.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
👍🏻
uint.TrailingZeroCount
just delegates to BitOperations.TrailingZeroCount
:
public static uint TrailingZeroCount(uint value) => (uint)BitOperations.TrailingZeroCount(value); |
So it's a bit more work for the JIT. What is peferable now?
I like the generic math approach a bit more, but no strong reason for it, only feeling.
(Besides that uint.TrailingZeroCount
returns uint
and not int
).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What is peferable now?
We don't have any official guidance right yet. My personal recommendation is that if you need a uint
back out to just use the uint.TrailingZeroCount
version.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM. A couple styling/consistency/simplification nits, but nothing that should block this
…te.cs Co-authored-by: Tanner Gooding <[email protected]>
@tannergooding seems this can be merged? |
Functional implementation of the new API got added in #67929
This PR vectorizes the implementation.
Note: I tried different unroll-levels, 4 gave best results.
Benchmarks aren't in https://github.com/dotnet/performance
Should these be added? Or is it covered via usage in other parts, or too niche to have it covered?
Benchmark code
I tried forwarding
SequenceEqual
toCommonPrefixLength
, but this regressed as inSequenceEqual
some nice tricks are employed. These can't be used forCommonPrefixLength
, as here we need the actualindex
of the difference, and not onlytrue/false
.@stephentoub: in regards to #67929 (comment) I know you favored #67942, but the PR for the CommonPrefixLength is merged, and I had already similar code lying around.
And for my usecase (a kind of parser) I'd prefer to use the (optimized) implementation in .NET instead of having to use my own (which in this case is the same code anyway 😉).