Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Vectorized MemoryExtensions.CommonPrefixLength #68210

Merged
merged 4 commits into from
Apr 27, 2022
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 46 additions & 1 deletion src/libraries/System.Memory/tests/Span/CommonPrefixLength.T.cs
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ private static void ValidateWithDefaultValues<T>(int length1, int length2, IEqua
}

[Fact]
public static void PartialEquals_ReturnsPrefixLength_ValueType()
public static void PartialEquals_ReturnsPrefixLength_Byte()
{
byte[] arr1 = new byte[] { 1, 2, 3, 4, 5 };
byte[] arr2 = new byte[] { 1, 2, 3, 6, 7 };
Expand All @@ -76,6 +76,51 @@ public static void PartialEquals_ReturnsPrefixLength_ValueType()
Assert.Equal(3, MemoryExtensions.CommonPrefixLength((Span<byte>)arr1, arr2, null));
Assert.Equal(3, MemoryExtensions.CommonPrefixLength((Span<byte>)arr1, arr2, EqualityComparer<byte>.Default));
Assert.Equal(3, MemoryExtensions.CommonPrefixLength((Span<byte>)arr1, arr2, NonDefaultEqualityComparer<byte>.Instance));

// Vectorized code path
arr1 = new byte[] { 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17 };
arr2 = new byte[] { 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 42, 15, 16, 17 };

Assert.Equal(13, MemoryExtensions.CommonPrefixLength((ReadOnlySpan<byte>)arr1, arr2));
Assert.Equal(13, MemoryExtensions.CommonPrefixLength((ReadOnlySpan<byte>)arr1, arr2, null));
Assert.Equal(13, MemoryExtensions.CommonPrefixLength((ReadOnlySpan<byte>)arr1, arr2, EqualityComparer<byte>.Default));
Assert.Equal(13, MemoryExtensions.CommonPrefixLength((ReadOnlySpan<byte>)arr1, arr2, NonDefaultEqualityComparer<byte>.Instance));

Assert.Equal(13, MemoryExtensions.CommonPrefixLength((Span<byte>)arr1, arr2));
Assert.Equal(13, MemoryExtensions.CommonPrefixLength((Span<byte>)arr1, arr2, null));
Assert.Equal(13, MemoryExtensions.CommonPrefixLength((Span<byte>)arr1, arr2, EqualityComparer<byte>.Default));
Assert.Equal(13, MemoryExtensions.CommonPrefixLength((Span<byte>)arr1, arr2, NonDefaultEqualityComparer<byte>.Instance));
}

[Fact]
public static void PartialEquals_ReturnsPrefixLength_ValueType()
{
int[] arr1 = new int[] { 1, 2, 3 };
int[] arr2 = new int[] { 1, 2, 6 };

Assert.Equal(2, MemoryExtensions.CommonPrefixLength((ReadOnlySpan<int>)arr1, arr2));
Assert.Equal(2, MemoryExtensions.CommonPrefixLength((ReadOnlySpan<int>)arr1, arr2, null));
Assert.Equal(2, MemoryExtensions.CommonPrefixLength((ReadOnlySpan<int>)arr1, arr2, EqualityComparer<int>.Default));
Assert.Equal(2, MemoryExtensions.CommonPrefixLength((ReadOnlySpan<int>)arr1, arr2, NonDefaultEqualityComparer<int>.Instance));

Assert.Equal(2, MemoryExtensions.CommonPrefixLength((Span<int>)arr1, arr2));
Assert.Equal(2, MemoryExtensions.CommonPrefixLength((Span<int>)arr1, arr2, null));
Assert.Equal(2, MemoryExtensions.CommonPrefixLength((Span<int>)arr1, arr2, EqualityComparer<int>.Default));
Assert.Equal(2, MemoryExtensions.CommonPrefixLength((Span<int>)arr1, arr2, NonDefaultEqualityComparer<int>.Instance));

// Vectorized code path
arr1 = new int[] { 1, 2, 3, 4, 5 };
arr2 = new int[] { 1, 2, 3, 6, 7 };

Assert.Equal(3, MemoryExtensions.CommonPrefixLength((ReadOnlySpan<int>)arr1, arr2));
Assert.Equal(3, MemoryExtensions.CommonPrefixLength((ReadOnlySpan<int>)arr1, arr2, null));
Assert.Equal(3, MemoryExtensions.CommonPrefixLength((ReadOnlySpan<int>)arr1, arr2, EqualityComparer<int>.Default));
Assert.Equal(3, MemoryExtensions.CommonPrefixLength((ReadOnlySpan<int>)arr1, arr2, NonDefaultEqualityComparer<int>.Instance));

Assert.Equal(3, MemoryExtensions.CommonPrefixLength((Span<int>)arr1, arr2));
Assert.Equal(3, MemoryExtensions.CommonPrefixLength((Span<int>)arr1, arr2, null));
Assert.Equal(3, MemoryExtensions.CommonPrefixLength((Span<int>)arr1, arr2, EqualityComparer<int>.Default));
Assert.Equal(3, MemoryExtensions.CommonPrefixLength((Span<int>)arr1, arr2, NonDefaultEqualityComparer<int>.Instance));
}

[Fact]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2033,6 +2033,18 @@ public static int CommonPrefixLength<T>(this Span<T> span, ReadOnlySpan<T> other
/// <returns>The length of the common prefix shared by the two spans. If there's no shared prefix, 0 is returned.</returns>
public static int CommonPrefixLength<T>(this ReadOnlySpan<T> span, ReadOnlySpan<T> other)
{
if (RuntimeHelpers.IsBitwiseEquatable<T>())
{
nuint length = Math.Min((nuint)(uint)span.Length, (nuint)(uint)other.Length);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was going to suggest nuint.Min but then I remembered that IntPtr and UIntPtr still have them explicitly implemented until the work around dotnet/csharplang#6031 goes in.

nuint size = (uint)Unsafe.SizeOf<T>();
nuint index = SpanHelpers.CommonPrefixLength(
ref Unsafe.As<T, byte>(ref MemoryMarshal.GetReference(span)),
ref Unsafe.As<T, byte>(ref MemoryMarshal.GetReference(other)),
length * size);

return (int)(index / size);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The use of byte and then dividing to get rid of partials is clever... but deserves a big comment :)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Too verbose now?

}

// Shrink one of the spans if necessary to ensure they're both the same length. We can then iterate until
// the Length of one of them and at least have bounds checks removed from that one.
SliceLongerSpanToMatchShorterLength(ref span, ref other);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2111,6 +2111,95 @@ public static unsafe int SequenceCompareTo(ref byte first, int firstLength, ref
return firstLength - secondLength;
}

public static nuint CommonPrefixLength(ref byte first, ref byte second, nuint length)
{
nuint i;

if (!Vector128.IsHardwareAccelerated || length < (nuint)Vector128<byte>.Count)
gfoidl marked this conversation as resolved.
Show resolved Hide resolved
{
// To have kind of fast path for small inputs, we handle as much elements needed
// so that either we are done or can use the unrolled loop below.
i = length % 4;

if (i > 0)
{
if (first != second)
{
return 0;
}

if (i > 1)
{
if (Unsafe.Add(ref first, 1) != Unsafe.Add(ref second, 1))
{
return 1;
}

if (i > 2 && Unsafe.Add(ref first, 2) != Unsafe.Add(ref second, 2))
{
return 2;
}
}
}

for (; (nint)i <= (nint)length - 4; i += 4)
{
if (Unsafe.Add(ref first, i + 0) != Unsafe.Add(ref second, i + 0)) return i + 0;
if (Unsafe.Add(ref first, i + 1) != Unsafe.Add(ref second, i + 1)) return i + 1;
if (Unsafe.Add(ref first, i + 2) != Unsafe.Add(ref second, i + 2)) return i + 2;
if (Unsafe.Add(ref first, i + 3) != Unsafe.Add(ref second, i + 3)) return i + 3;
}

return length;
}

Debug.Assert(length >= (uint)Vector128<byte>.Count);

int mask;
nuint lengthToExamine = length - (nuint)Vector128<byte>.Count;

Vector128<byte> firstVec;
Vector128<byte> secondVec;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: Remove this and do Vector128.Equals(firstVector, Vector128.LoadUnsafe(ref second, i)) as the JIT tends to optimize the load better in this scenario

-- It does still fold the load most of the time, but there are cases where storing to a local can make it "worse"

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

as the JIT tends to optimize the load better in this scenario

Good to know, thanks!

Checked codegen, no difference.

Vector128<byte> maskVec;
i = 0;

if (lengthToExamine != 0)
{
do
{
firstVec = Vector128.LoadUnsafe(ref first, i);
secondVec = Vector128.LoadUnsafe(ref second, i);
maskVec = Vector128.Equals(firstVec, secondVec);
mask = (int)maskVec.ExtractMostSignificantBits();
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why convert to int here and below? TrailingZeroCount has an overload that takes uint and i is uint (and its also available as uint.TrailingZeroCount now)

So keeping it as uint just simplifies things and reduces code complexity.

Copy link
Member Author

@gfoidl gfoidl Apr 21, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👍🏻

uint.TrailingZeroCount just delegates to BitOperations.TrailingZeroCount:

public static uint TrailingZeroCount(uint value) => (uint)BitOperations.TrailingZeroCount(value);

So it's a bit more work for the JIT. What is peferable now?
I like the generic math approach a bit more, but no strong reason for it, only feeling.
(Besides that uint.TrailingZeroCount returns uint and not int).

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is peferable now?

We don't have any official guidance right yet. My personal recommendation is that if you need a uint back out to just use the uint.TrailingZeroCount version.


if (mask != 0xFFFF)
{
goto Found;
}

i += (nuint)Vector128<byte>.Count;
} while (i < lengthToExamine);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: using while (i < lengthToExamine) { } would cover the if check and reduce nesting

}

// Do final compare as Vector128<byte>.Count from end rather than start
i = lengthToExamine;
firstVec = Vector128.LoadUnsafe(ref first, i);
secondVec = Vector128.LoadUnsafe(ref second, i);
maskVec = Vector128.Equals(firstVec, secondVec);
mask = (int)maskVec.ExtractMostSignificantBits();

if (mask != 0xFFFF)
{
goto Found;
}

return length;

Found:
mask = ~mask;
return i + (uint)BitOperations.TrailingZeroCount(mask);
}

// Vector sub-search adapted from https://github.com/aspnet/KestrelHttpServer/pull/1138
[MethodImpl(MethodImplOptions.AggressiveInlining)]
private static int LocateLastFoundByte(Vector<byte> match)
Expand Down