From 6aa20fe982b514156855ff7e9c37413783832be0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?G=C3=BCnther=20Foidl?= Date: Tue, 19 Apr 2022 11:27:39 +0200 Subject: [PATCH 1/4] Vectorized MemoryExtensions.CommonPrefixLength --- .../src/System/MemoryExtensions.cs | 12 +++ .../src/System/SpanHelpers.Byte.cs | 89 +++++++++++++++++++ 2 files changed, 101 insertions(+) diff --git a/src/libraries/System.Private.CoreLib/src/System/MemoryExtensions.cs b/src/libraries/System.Private.CoreLib/src/System/MemoryExtensions.cs index 560925d532fdf..6c1593556971d 100644 --- a/src/libraries/System.Private.CoreLib/src/System/MemoryExtensions.cs +++ b/src/libraries/System.Private.CoreLib/src/System/MemoryExtensions.cs @@ -2033,6 +2033,18 @@ public static int CommonPrefixLength(this Span span, ReadOnlySpan other /// The length of the common prefix shared by the two spans. If there's no shared prefix, 0 is returned. public static int CommonPrefixLength(this ReadOnlySpan span, ReadOnlySpan other) { + if (RuntimeHelpers.IsBitwiseEquatable()) + { + nuint length = Math.Min((nuint)(uint)span.Length, (nuint)(uint)other.Length); + nuint size = (uint)Unsafe.SizeOf(); + nuint index = SpanHelpers.CommonPrefixLength( + ref Unsafe.As(ref MemoryMarshal.GetReference(span)), + ref Unsafe.As(ref MemoryMarshal.GetReference(other)), + length * size); + + return (int)(index / size); + } + // Shrink one of the spans if necessary to ensure they're both the same length. We can then iterate until // the Length of one of them and at least have bounds checks removed from that one. 
SliceLongerSpanToMatchShorterLength(ref span, ref other); diff --git a/src/libraries/System.Private.CoreLib/src/System/SpanHelpers.Byte.cs b/src/libraries/System.Private.CoreLib/src/System/SpanHelpers.Byte.cs index 5bd50caa89f5e..de15c9ad8bd48 100644 --- a/src/libraries/System.Private.CoreLib/src/System/SpanHelpers.Byte.cs +++ b/src/libraries/System.Private.CoreLib/src/System/SpanHelpers.Byte.cs @@ -2111,6 +2111,95 @@ public static unsafe int SequenceCompareTo(ref byte first, int firstLength, ref return firstLength - secondLength; } + public static nuint CommonPrefixLength(ref byte first, ref byte second, nuint length) + { + nuint i; + + if (!Vector128.IsHardwareAccelerated || length < (nuint)Vector128.Count) + { + // To have kind of fast path for small inputs, we handle as much elements needed + // so that either we are done or can use the unrolled loop below. + i = length % 4; + + if (i > 0) + { + if (first != second) + { + return 0; + } + + if (i > 1) + { + if (Unsafe.Add(ref first, 1) != Unsafe.Add(ref second, 1)) + { + return 1; + } + + if (i > 2 && Unsafe.Add(ref first, 2) != Unsafe.Add(ref second, 2)) + { + return 2; + } + } + } + + for (; (nint)i <= (nint)length - 4; i += 4) + { + if (Unsafe.Add(ref first, i + 0) != Unsafe.Add(ref second, i + 0)) return i + 0; + if (Unsafe.Add(ref first, i + 1) != Unsafe.Add(ref second, i + 1)) return i + 1; + if (Unsafe.Add(ref first, i + 2) != Unsafe.Add(ref second, i + 2)) return i + 2; + if (Unsafe.Add(ref first, i + 3) != Unsafe.Add(ref second, i + 3)) return i + 3; + } + + return length; + } + + Debug.Assert(length >= (uint)Vector128.Count); + + int mask; + nuint lengthToExamine = length - (nuint)Vector128.Count; + + Vector128 firstVec; + Vector128 secondVec; + Vector128 maskVec; + i = 0; + + if (lengthToExamine != 0) + { + do + { + firstVec = Vector128.LoadUnsafe(ref first, i); + secondVec = Vector128.LoadUnsafe(ref second, i); + maskVec = Vector128.Equals(firstVec, secondVec); + mask = 
(int)maskVec.ExtractMostSignificantBits(); + + if (mask != 0xFFFF) + { + goto Found; + } + + i += (nuint)Vector128.Count; + } while (i < lengthToExamine); + } + + // Do final compare as Vector128.Count from end rather than start + i = lengthToExamine; + firstVec = Vector128.LoadUnsafe(ref first, i); + secondVec = Vector128.LoadUnsafe(ref second, i); + maskVec = Vector128.Equals(firstVec, secondVec); + mask = (int)maskVec.ExtractMostSignificantBits(); + + if (mask != 0xFFFF) + { + goto Found; + } + + return length; + + Found: + mask = ~mask; + return i + (uint)BitOperations.TrailingZeroCount(mask); + } + // Vector sub-search adapted from https://github.com/aspnet/KestrelHttpServer/pull/1138 [MethodImpl(MethodImplOptions.AggressiveInlining)] private static int LocateLastFoundByte(Vector match) From 4d4b1062fe72dd5df0b432d923773318882c9ac1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?G=C3=BCnther=20Foidl?= Date: Tue, 19 Apr 2022 11:52:14 +0200 Subject: [PATCH 2/4] Tests --- .../tests/Span/CommonPrefixLength.T.cs | 47 ++++++++++++++++++- 1 file changed, 46 insertions(+), 1 deletion(-) diff --git a/src/libraries/System.Memory/tests/Span/CommonPrefixLength.T.cs b/src/libraries/System.Memory/tests/Span/CommonPrefixLength.T.cs index c8ddc58d924d9..5092510d795e0 100644 --- a/src/libraries/System.Memory/tests/Span/CommonPrefixLength.T.cs +++ b/src/libraries/System.Memory/tests/Span/CommonPrefixLength.T.cs @@ -62,7 +62,7 @@ private static void ValidateWithDefaultValues(int length1, int length2, IEqua } [Fact] - public static void PartialEquals_ReturnsPrefixLength_ValueType() + public static void PartialEquals_ReturnsPrefixLength_Byte() { byte[] arr1 = new byte[] { 1, 2, 3, 4, 5 }; byte[] arr2 = new byte[] { 1, 2, 3, 6, 7 }; @@ -76,6 +76,51 @@ public static void PartialEquals_ReturnsPrefixLength_ValueType() Assert.Equal(3, MemoryExtensions.CommonPrefixLength((Span)arr1, arr2, null)); Assert.Equal(3, MemoryExtensions.CommonPrefixLength((Span)arr1, arr2, EqualityComparer.Default)); 
Assert.Equal(3, MemoryExtensions.CommonPrefixLength((Span)arr1, arr2, NonDefaultEqualityComparer.Instance)); + + // Vectorized code path + arr1 = new byte[] { 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17 }; + arr2 = new byte[] { 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 42, 15, 16, 17 }; + + Assert.Equal(13, MemoryExtensions.CommonPrefixLength((ReadOnlySpan)arr1, arr2)); + Assert.Equal(13, MemoryExtensions.CommonPrefixLength((ReadOnlySpan)arr1, arr2, null)); + Assert.Equal(13, MemoryExtensions.CommonPrefixLength((ReadOnlySpan)arr1, arr2, EqualityComparer.Default)); + Assert.Equal(13, MemoryExtensions.CommonPrefixLength((ReadOnlySpan)arr1, arr2, NonDefaultEqualityComparer.Instance)); + + Assert.Equal(13, MemoryExtensions.CommonPrefixLength((Span)arr1, arr2)); + Assert.Equal(13, MemoryExtensions.CommonPrefixLength((Span)arr1, arr2, null)); + Assert.Equal(13, MemoryExtensions.CommonPrefixLength((Span)arr1, arr2, EqualityComparer.Default)); + Assert.Equal(13, MemoryExtensions.CommonPrefixLength((Span)arr1, arr2, NonDefaultEqualityComparer.Instance)); + } + + [Fact] + public static void PartialEquals_ReturnsPrefixLength_ValueType() + { + int[] arr1 = new int[] { 1, 2, 3 }; + int[] arr2 = new int[] { 1, 2, 6 }; + + Assert.Equal(2, MemoryExtensions.CommonPrefixLength((ReadOnlySpan)arr1, arr2)); + Assert.Equal(2, MemoryExtensions.CommonPrefixLength((ReadOnlySpan)arr1, arr2, null)); + Assert.Equal(2, MemoryExtensions.CommonPrefixLength((ReadOnlySpan)arr1, arr2, EqualityComparer.Default)); + Assert.Equal(2, MemoryExtensions.CommonPrefixLength((ReadOnlySpan)arr1, arr2, NonDefaultEqualityComparer.Instance)); + + Assert.Equal(2, MemoryExtensions.CommonPrefixLength((Span)arr1, arr2)); + Assert.Equal(2, MemoryExtensions.CommonPrefixLength((Span)arr1, arr2, null)); + Assert.Equal(2, MemoryExtensions.CommonPrefixLength((Span)arr1, arr2, EqualityComparer.Default)); + Assert.Equal(2, MemoryExtensions.CommonPrefixLength((Span)arr1, arr2, 
NonDefaultEqualityComparer.Instance)); + + // Vectorized code path + arr1 = new int[] { 1, 2, 3, 4, 5 }; + arr2 = new int[] { 1, 2, 3, 6, 7 }; + + Assert.Equal(3, MemoryExtensions.CommonPrefixLength((ReadOnlySpan)arr1, arr2)); + Assert.Equal(3, MemoryExtensions.CommonPrefixLength((ReadOnlySpan)arr1, arr2, null)); + Assert.Equal(3, MemoryExtensions.CommonPrefixLength((ReadOnlySpan)arr1, arr2, EqualityComparer.Default)); + Assert.Equal(3, MemoryExtensions.CommonPrefixLength((ReadOnlySpan)arr1, arr2, NonDefaultEqualityComparer.Instance)); + + Assert.Equal(3, MemoryExtensions.CommonPrefixLength((Span)arr1, arr2)); + Assert.Equal(3, MemoryExtensions.CommonPrefixLength((Span)arr1, arr2, null)); + Assert.Equal(3, MemoryExtensions.CommonPrefixLength((Span)arr1, arr2, EqualityComparer.Default)); + Assert.Equal(3, MemoryExtensions.CommonPrefixLength((Span)arr1, arr2, NonDefaultEqualityComparer.Instance)); } [Fact] From a1a723d9fabe7930350ed3bfe271211cc5f2ee14 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?G=C3=BCnther=20Foidl?= Date: Thu, 21 Apr 2022 10:54:50 +0200 Subject: [PATCH 3/4] Addressed feedback --- .../src/System/MemoryExtensions.cs | 11 ++++++ .../src/System/SpanHelpers.Byte.cs | 37 ++++++++----------- 2 files changed, 27 insertions(+), 21 deletions(-) diff --git a/src/libraries/System.Private.CoreLib/src/System/MemoryExtensions.cs b/src/libraries/System.Private.CoreLib/src/System/MemoryExtensions.cs index 6c1593556971d..ada738f12b653 100644 --- a/src/libraries/System.Private.CoreLib/src/System/MemoryExtensions.cs +++ b/src/libraries/System.Private.CoreLib/src/System/MemoryExtensions.cs @@ -2042,6 +2042,17 @@ ref Unsafe.As(ref MemoryMarshal.GetReference(span)), ref Unsafe.As(ref MemoryMarshal.GetReference(other)), length * size); + // A byte-wise comparison in CommonPrefixLength can be used for multi-byte types, + // that are bitwise-equatable, too. 
In order to get the correct index in terms of type T + // of the first mismatch, integer division by the size of T is used. + // + // Example for short: + // index (byte-based): b-1, b, b+1, b+2, b+3 + // index (short-based): s-1, s, s+1 + // byte sequence 1: { ..., [0x42, 0x43], [0x37, 0x38], ... } + // byte sequence 2: { ..., [0x42, 0x43], [0x37, 0xAB], ... } + // So the mismatch is at byte-index b+3, which, integer-divided by the size of short, gives + // 3 / 2 = 1, i.e. the expected short-based index. return (int)(index / size); } diff --git a/src/libraries/System.Private.CoreLib/src/System/SpanHelpers.Byte.cs b/src/libraries/System.Private.CoreLib/src/System/SpanHelpers.Byte.cs index de15c9ad8bd48..fc2bf0794f12e 100644 --- a/src/libraries/System.Private.CoreLib/src/System/SpanHelpers.Byte.cs +++ b/src/libraries/System.Private.CoreLib/src/System/SpanHelpers.Byte.cs @@ -2155,39 +2155,34 @@ public static nuint CommonPrefixLength(ref byte first, ref byte second, nuint le Debug.Assert(length >= (uint)Vector128.Count); - int mask; + uint mask; nuint lengthToExamine = length - (nuint)Vector128.Count; - Vector128 firstVec; - Vector128 secondVec; Vector128 maskVec; i = 0; - if (lengthToExamine != 0) + while (i < lengthToExamine) { - do - { - firstVec = Vector128.LoadUnsafe(ref first, i); - secondVec = Vector128.LoadUnsafe(ref second, i); - maskVec = Vector128.Equals(firstVec, secondVec); - mask = (int)maskVec.ExtractMostSignificantBits(); + maskVec = Vector128.Equals( + Vector128.LoadUnsafe(ref first, i), + Vector128.LoadUnsafe(ref second, i)); - if (mask != 0xFFFF) - { - goto Found; - } + mask = maskVec.ExtractMostSignificantBits(); + if (mask != 0xFFFF) + { + goto Found; + } - i += (nuint)Vector128.Count; - } while (i < lengthToExamine); + i += (nuint)Vector128.Count; } // Do final compare as Vector128.Count from end rather than start i = lengthToExamine; - firstVec = Vector128.LoadUnsafe(ref first, i); - secondVec = Vector128.LoadUnsafe(ref second, i); - maskVec =
Vector128.Equals(firstVec, secondVec); - mask = (int)maskVec.ExtractMostSignificantBits(); + maskVec = Vector128.Equals( + Vector128.LoadUnsafe(ref first, i), + Vector128.LoadUnsafe(ref second, i)); + mask = maskVec.ExtractMostSignificantBits(); if (mask != 0xFFFF) { goto Found; @@ -2197,7 +2192,7 @@ public static nuint CommonPrefixLength(ref byte first, ref byte second, nuint le Found: mask = ~mask; - return i + (uint)BitOperations.TrailingZeroCount(mask); + return i + uint.TrailingZeroCount(mask); } // Vector sub-search adapted from https://github.com/aspnet/KestrelHttpServer/pull/1138 From a6c603aa605142f1f61859031c7d8e1cf077e93f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?G=C3=BCnther=20Foidl?= Date: Thu, 21 Apr 2022 17:54:13 +0200 Subject: [PATCH 4/4] Update src/libraries/System.Private.CoreLib/src/System/SpanHelpers.Byte.cs Co-authored-by: Tanner Gooding --- .../System.Private.CoreLib/src/System/SpanHelpers.Byte.cs | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/libraries/System.Private.CoreLib/src/System/SpanHelpers.Byte.cs b/src/libraries/System.Private.CoreLib/src/System/SpanHelpers.Byte.cs index fc2bf0794f12e..7282f73c7d8fa 100644 --- a/src/libraries/System.Private.CoreLib/src/System/SpanHelpers.Byte.cs +++ b/src/libraries/System.Private.CoreLib/src/System/SpanHelpers.Byte.cs @@ -2115,6 +2115,8 @@ public static nuint CommonPrefixLength(ref byte first, ref byte second, nuint le { nuint i; + // It is ordered this way to match the default branch predictor rules, so that short + // input lengths do not incur too much overhead. if (!Vector128.IsHardwareAccelerated || length < (nuint)Vector128.Count) { // To have kind of fast path for small inputs, we handle as much elements needed