Skip to content

Commit

Permalink
Vector512 Support for Enumerable<int>.Min/Max (#93369)
Browse files Browse the repository at this point in the history
* Vector512 Support for Enumerable<int>.Min/Max

* implement comparer in Min/Max

* remove trailing whitespace

* remove conditions to only target NET 8 as Linq Packages is versioned anyway

* minor codestyle fix

* increate Max_AllTypes_TestData to be at least one elemend wider than Vector512

* avoid generating testdata for sbyte which are outside the representable range

* dont overflow the datatype inside the unittest

* dont generate testdata which leads to overflows in the number type during test
  • Loading branch information
Spacefish authored Oct 16, 2023
1 parent 3eae196 commit cccdc1d
Show file tree
Hide file tree
Showing 5 changed files with 40 additions and 7 deletions.
1 change: 1 addition & 0 deletions src/libraries/System.Linq/src/System/Linq/Max.cs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ public static partial class Enumerable
public static bool Compare(T left, T right) => left > right;
public static Vector128<T> Compare(Vector128<T> left, Vector128<T> right) => Vector128.Max(left, right);
public static Vector256<T> Compare(Vector256<T> left, Vector256<T> right) => Vector256.Max(left, right);
public static Vector512<T> Compare(Vector512<T> left, Vector512<T> right) => Vector512.Max(left, right);
}

public static int? Max(this IEnumerable<int?> source) => MaxInteger(source);
Expand Down
27 changes: 26 additions & 1 deletion src/libraries/System.Linq/src/System/Linq/MaxMin.cs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ private interface IMinMaxCalc<T> where T : struct, IBinaryInteger<T>
public static abstract bool Compare(T left, T right);
public static abstract Vector128<T> Compare(Vector128<T> left, Vector128<T> right);
public static abstract Vector256<T> Compare(Vector256<T> left, Vector256<T> right);
public static abstract Vector512<T> Compare(Vector512<T> left, Vector512<T> right);
}

private static T MinMaxInteger<T, TMinMax>(this IEnumerable<T> source)
Expand Down Expand Up @@ -66,7 +67,7 @@ private static T MinMaxInteger<T, TMinMax>(this IEnumerable<T> source)
}
}
}
else
else if (!Vector512.IsHardwareAccelerated || span.Length < Vector512<T>.Count)
{
ref T current = ref MemoryMarshal.GetReference(span);
ref T lastVectorStart = ref Unsafe.Add(ref current, span.Length - Vector256<T>.Count);
Expand All @@ -90,6 +91,30 @@ private static T MinMaxInteger<T, TMinMax>(this IEnumerable<T> source)
}
}
}
else
{
ref T current = ref MemoryMarshal.GetReference(span);
ref T lastVectorStart = ref Unsafe.Add(ref current, span.Length - Vector512<T>.Count);

Vector512<T> best = Vector512.LoadUnsafe(ref current);
current = ref Unsafe.Add(ref current, Vector512<T>.Count);

while (Unsafe.IsAddressLessThan(ref current, ref lastVectorStart))
{
best = TMinMax.Compare(best, Vector512.LoadUnsafe(ref current));
current = ref Unsafe.Add(ref current, Vector512<T>.Count);
}
best = TMinMax.Compare(best, Vector512.LoadUnsafe(ref lastVectorStart));

value = best[0];
for (int i = 1; i < Vector512<T>.Count; i++)
{
if (TMinMax.Compare(best[i], value))
{
value = best[i];
}
}
}
}
else
{
Expand Down
1 change: 1 addition & 0 deletions src/libraries/System.Linq/src/System/Linq/Min.cs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ public static partial class Enumerable
public static bool Compare(T left, T right) => left < right;
public static Vector128<T> Compare(Vector128<T> left, Vector128<T> right) => Vector128.Min(left, right);
public static Vector256<T> Compare(Vector256<T> left, Vector256<T> right) => Vector256.Min(left, right);
public static Vector512<T> Compare(Vector512<T> left, Vector512<T> right) => Vector512.Min(left, right);
}

public static int? Min(this IEnumerable<int?> source) => MinInteger(source);
Expand Down
9 changes: 6 additions & 3 deletions src/libraries/System.Linq/tests/MaxTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,16 @@ public class MaxTests : EnumerableTests
{
public static IEnumerable<object[]> Max_AllTypes_TestData()
{
for (int length = 2; length < 33; length++)
for (int length = 2; length < 65; length++)
{
yield return new object[] { Shuffler.Shuffle(Enumerable.Range(length, length).Select(i => (byte)i)), (byte)(length + length - 1) };
yield return new object[] { Shuffler.Shuffle(Enumerable.Range(length, length).Select(i => (byte)i).ToArray()), (byte)(length + length - 1) };

yield return new object[] { Shuffler.Shuffle(Enumerable.Range(length, length).Select(i => (sbyte)i)), (sbyte)(length + length - 1) };
yield return new object[] { Shuffler.Shuffle(Enumerable.Range(length, length).Select(i => (sbyte)i).ToArray()), (sbyte)(length + length - 1) };
// Unit Tests does +T.One so we should generate data up to one value below sbyte.MaxValue
if ((length + length) < sbyte.MaxValue) {
yield return new object[] { Shuffler.Shuffle(Enumerable.Range(length, length).Select(i => (sbyte)i)), (sbyte)(length + length - 1) };
yield return new object[] { Shuffler.Shuffle(Enumerable.Range(length, length).Select(i => (sbyte)i).ToArray()), (sbyte)(length + length - 1) };
}

yield return new object[] { Shuffler.Shuffle(Enumerable.Range(length, length).Select(i => (ushort)i)), (ushort)(length + length - 1) };
yield return new object[] { Shuffler.Shuffle(Enumerable.Range(length, length).Select(i => (ushort)i).ToArray()), (ushort)(length + length - 1) };
Expand Down
9 changes: 6 additions & 3 deletions src/libraries/System.Linq/tests/MinTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,16 @@ public class MinTests : EnumerableTests
{
public static IEnumerable<object[]> Min_AllTypes_TestData()
{
for (int length = 2; length < 33; length++)
for (int length = 2; length < 65; length++)
{
yield return new object[] { Shuffler.Shuffle(Enumerable.Range(length, length).Select(i => (byte)i)), (byte)length };
yield return new object[] { Shuffler.Shuffle(Enumerable.Range(length, length).Select(i => (byte)i).ToArray()), (byte)length };

yield return new object[] { Shuffler.Shuffle(Enumerable.Range(length, length).Select(i => (sbyte)i)), (sbyte)length };
yield return new object[] { Shuffler.Shuffle(Enumerable.Range(length, length).Select(i => (sbyte)i).ToArray()), (sbyte)length };
// Unit Tests does +T.One so we should generate data up to one value below sbyte.MaxValue, otherwise the type overflows
if ((length + length) < sbyte.MaxValue) {
yield return new object[] { Shuffler.Shuffle(Enumerable.Range(length, length).Select(i => (sbyte)i)), (sbyte)length };
yield return new object[] { Shuffler.Shuffle(Enumerable.Range(length, length).Select(i => (sbyte)i).ToArray()), (sbyte)length };
}

yield return new object[] { Shuffler.Shuffle(Enumerable.Range(length, length).Select(i => (ushort)i)), (ushort)length };
yield return new object[] { Shuffler.Shuffle(Enumerable.Range(length, length).Select(i => (ushort)i).ToArray()), (ushort)length };
Expand Down

0 comments on commit cccdc1d

Please sign in to comment.