From cccdc1d7b61da57c677b20a0b92721be2ba0a7a5 Mon Sep 17 00:00:00 2001 From: Spacefish Date: Mon, 16 Oct 2023 17:55:14 +0200 Subject: [PATCH] Vector512 Support for Enumerable.Min/Max (#93369) * Vector512 Support for Enumerable.Min/Max * implement comparer in Min/Max * remove trailing whitespace * remove conditions to only target NET 8 as Linq Packages is versioned anyway * minor codestyle fix * increate Max_AllTypes_TestData to be at least one elemend wider than Vector512 * avoid generating testdata for sbyte which are outside the representable range * dont overflow the datatype inside the unittest * dont generate testdata which leads to overflows in the number type during test --- .../System.Linq/src/System/Linq/Max.cs | 1 + .../System.Linq/src/System/Linq/MaxMin.cs | 27 ++++++++++++++++++- .../System.Linq/src/System/Linq/Min.cs | 1 + src/libraries/System.Linq/tests/MaxTests.cs | 9 ++++--- src/libraries/System.Linq/tests/MinTests.cs | 9 ++++--- 5 files changed, 40 insertions(+), 7 deletions(-) diff --git a/src/libraries/System.Linq/src/System/Linq/Max.cs b/src/libraries/System.Linq/src/System/Linq/Max.cs index 543d300389904..bd08684270f6c 100644 --- a/src/libraries/System.Linq/src/System/Linq/Max.cs +++ b/src/libraries/System.Linq/src/System/Linq/Max.cs @@ -18,6 +18,7 @@ public static partial class Enumerable public static bool Compare(T left, T right) => left > right; public static Vector128 Compare(Vector128 left, Vector128 right) => Vector128.Max(left, right); public static Vector256 Compare(Vector256 left, Vector256 right) => Vector256.Max(left, right); + public static Vector512 Compare(Vector512 left, Vector512 right) => Vector512.Max(left, right); } public static int? Max(this IEnumerable source) => MaxInteger(source); diff --git a/src/libraries/System.Linq/src/System/Linq/MaxMin.cs b/src/libraries/System.Linq/src/System/Linq/MaxMin.cs index 5def04f6cedba..47974c48a1c6f 100644 --- a/src/libraries/System.Linq/src/System/Linq/MaxMin.cs +++ b/src/libraries/System.Linq/src/System/Linq/MaxMin.cs @@ -16,6 +16,7 @@ private interface IMinMaxCalc where T : struct, IBinaryInteger public static abstract bool Compare(T left, T right); public static abstract Vector128 Compare(Vector128 left, Vector128 right); public static abstract Vector256 Compare(Vector256 left, Vector256 right); + public static abstract Vector512 Compare(Vector512 left, Vector512 right); } private static T MinMaxInteger(this IEnumerable source) @@ -66,7 +67,7 @@ private static T MinMaxInteger(this IEnumerable source) } } } - else + else if (!Vector512.IsHardwareAccelerated || span.Length < Vector512.Count) { ref T current = ref MemoryMarshal.GetReference(span); ref T lastVectorStart = ref Unsafe.Add(ref current, span.Length - Vector256.Count); @@ -90,6 +91,30 @@ private static T MinMaxInteger(this IEnumerable source) } } } + else + { + ref T current = ref MemoryMarshal.GetReference(span); + ref T lastVectorStart = ref Unsafe.Add(ref current, span.Length - Vector512.Count); + + Vector512 best = Vector512.LoadUnsafe(ref current); + current = ref Unsafe.Add(ref current, Vector512.Count); + + while (Unsafe.IsAddressLessThan(ref current, ref lastVectorStart)) + { + best = TMinMax.Compare(best, Vector512.LoadUnsafe(ref current)); + current = ref Unsafe.Add(ref current, Vector512.Count); + } + best = TMinMax.Compare(best, Vector512.LoadUnsafe(ref lastVectorStart)); + + value = best[0]; + for (int i = 1; i < Vector512.Count; i++) + { + if (TMinMax.Compare(best[i], value)) + { + value = best[i]; + } + } + } } else { diff --git a/src/libraries/System.Linq/src/System/Linq/Min.cs b/src/libraries/System.Linq/src/System/Linq/Min.cs index a553f5ddcb00e..6c81274f3013c 100644 --- a/src/libraries/System.Linq/src/System/Linq/Min.cs +++ b/src/libraries/System.Linq/src/System/Linq/Min.cs @@ -18,6 +18,7 @@ public static partial class Enumerable public static bool Compare(T left, T right) => left < right; public static Vector128 Compare(Vector128 left, Vector128 right) => Vector128.Min(left, right); public static Vector256 Compare(Vector256 left, Vector256 right) => Vector256.Min(left, right); + public static Vector512 Compare(Vector512 left, Vector512 right) => Vector512.Min(left, right); } public static int? Min(this IEnumerable source) => MinInteger(source); diff --git a/src/libraries/System.Linq/tests/MaxTests.cs b/src/libraries/System.Linq/tests/MaxTests.cs index 42727d8362085..d9b020e7d4796 100644 --- a/src/libraries/System.Linq/tests/MaxTests.cs +++ b/src/libraries/System.Linq/tests/MaxTests.cs @@ -11,13 +11,16 @@ public class MaxTests : EnumerableTests { public static IEnumerable Max_AllTypes_TestData() { - for (int length = 2; length < 33; length++) + for (int length = 2; length < 65; length++) { yield return new object[] { Shuffler.Shuffle(Enumerable.Range(length, length).Select(i => (byte)i)), (byte)(length + length - 1) }; yield return new object[] { Shuffler.Shuffle(Enumerable.Range(length, length).Select(i => (byte)i).ToArray()), (byte)(length + length - 1) }; - yield return new object[] { Shuffler.Shuffle(Enumerable.Range(length, length).Select(i => (sbyte)i)), (sbyte)(length + length - 1) }; - yield return new object[] { Shuffler.Shuffle(Enumerable.Range(length, length).Select(i => (sbyte)i).ToArray()), (sbyte)(length + length - 1) }; + // Unit Tests does +T.One so we should generate data up to one value below sbyte.MaxValue + if ((length + length) < sbyte.MaxValue) { + yield return new object[] { Shuffler.Shuffle(Enumerable.Range(length, length).Select(i => (sbyte)i)), (sbyte)(length + length - 1) }; + yield return new object[] { Shuffler.Shuffle(Enumerable.Range(length, length).Select(i => (sbyte)i).ToArray()), (sbyte)(length + length - 1) }; + } yield return new object[] { Shuffler.Shuffle(Enumerable.Range(length, length).Select(i => (ushort)i)), (ushort)(length + length - 1) }; yield return new object[] { Shuffler.Shuffle(Enumerable.Range(length, length).Select(i => (ushort)i).ToArray()), (ushort)(length + length - 1) }; diff --git a/src/libraries/System.Linq/tests/MinTests.cs b/src/libraries/System.Linq/tests/MinTests.cs index e877dd5911a40..feca6994d066d 100644 --- a/src/libraries/System.Linq/tests/MinTests.cs +++ b/src/libraries/System.Linq/tests/MinTests.cs @@ -11,13 +11,16 @@ public class MinTests : EnumerableTests { public static IEnumerable Min_AllTypes_TestData() { - for (int length = 2; length < 33; length++) + for (int length = 2; length < 65; length++) { yield return new object[] { Shuffler.Shuffle(Enumerable.Range(length, length).Select(i => (byte)i)), (byte)length }; yield return new object[] { Shuffler.Shuffle(Enumerable.Range(length, length).Select(i => (byte)i).ToArray()), (byte)length }; - yield return new object[] { Shuffler.Shuffle(Enumerable.Range(length, length).Select(i => (sbyte)i)), (sbyte)length }; - yield return new object[] { Shuffler.Shuffle(Enumerable.Range(length, length).Select(i => (sbyte)i).ToArray()), (sbyte)length }; + // Unit Tests does +T.One so we should generate data up to one value below sbyte.MaxValue, otherwise the type overflows + if ((length + length) < sbyte.MaxValue) { + yield return new object[] { Shuffler.Shuffle(Enumerable.Range(length, length).Select(i => (sbyte)i)), (sbyte)length }; + yield return new object[] { Shuffler.Shuffle(Enumerable.Range(length, length).Select(i => (sbyte)i).ToArray()), (sbyte)length }; + } yield return new object[] { Shuffler.Shuffle(Enumerable.Range(length, length).Select(i => (ushort)i)), (ushort)length }; yield return new object[] { Shuffler.Shuffle(Enumerable.Range(length, length).Select(i => (ushort)i).ToArray()), (ushort)length };