From a451e68bb15302f81a90546d84ef1b87dbc30449 Mon Sep 17 00:00:00 2001 From: Tanner Gooding Date: Thu, 12 Oct 2023 10:37:21 -0700 Subject: [PATCH] Improve a vector implementation to support alignment and non-temporal tores (#93296) * Improve a vector implementation to support alignment and non-temporal stores * Fix a build error and mark a couple methods as AggressiveInlining * Fix the remaining block count computation * Ensure overlapping for small data on the V256/512 is handled * Ensure we only go down the vectorized path when supported for netstandard --- .../Tensors/TensorPrimitives.netcore.cs | 896 +++++++++++++++++- .../Tensors/TensorPrimitives.netstandard.cs | 279 +++++- .../tests/TensorPrimitivesTests.cs | 2 +- 3 files changed, 1102 insertions(+), 75 deletions(-) diff --git a/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/TensorPrimitives.netcore.cs b/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/TensorPrimitives.netcore.cs index 869efd2714e03..8166797f34896 100644 --- a/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/TensorPrimitives.netcore.cs +++ b/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/TensorPrimitives.netcore.cs @@ -13,6 +13,8 @@ namespace System.Numerics.Tensors { public static partial class TensorPrimitives { + private const nuint NonTemporalByteThreshold = 256 * 1024; + /// /// Copies to , converting each /// value to its nearest representable half-precision floating-point value. @@ -1224,102 +1226,892 @@ private static void InvokeSpanIntoSpan( ValidateInputOutputSpanNonOverlapping(x, destination); + // Since every branch has a cost and since that cost is + // essentially lost for larger inputs, we do branches + // in a way that allows us to have the minimum possible + // for small sizes + ref float xRef = ref MemoryMarshal.GetReference(x); ref float dRef = ref MemoryMarshal.GetReference(destination); - int i = 0, oneVectorFromEnd; + + nuint remainder = (uint)(x.Length); #if NET8_0_OR_GREATER if (Vector512.IsHardwareAccelerated) { - oneVectorFromEnd = x.Length - Vector512.Count; - if (i <= oneVectorFromEnd) + if (remainder >= (uint)(Vector512.Count)) { - // Loop handling one vector at a time. - do + Vectorized512(ref xRef, ref dRef, remainder); + } + else + { + // We have less than a vector and so we can only handle this as scalar. To do this + // efficiently, we simply have a small jump table and fallthrough. So we get a simple + // length check, single jump, and then linear execution. + + Vectorized512Small(ref xRef, ref dRef, remainder); + } + + return; + } +#endif + + if (Vector256.IsHardwareAccelerated) + { + if (remainder >= (uint)(Vector256.Count)) + { + Vectorized256(ref xRef, ref dRef, remainder); + } + else + { + // We have less than a vector and so we can only handle this as scalar. To do this + // efficiently, we simply have a small jump table and fallthrough. So we get a simple + // length check, single jump, and then linear execution. + + Vectorized256Small(ref xRef, ref dRef, remainder); + } + + return; + } + + if (Vector128.IsHardwareAccelerated) + { + if (remainder >= (uint)(Vector128.Count)) + { + Vectorized128(ref xRef, ref dRef, remainder); + } + else + { + // We have less than a vector and so we can only handle this as scalar. To do this + // efficiently, we simply have a small jump table and fallthrough. So we get a simple + // length check, single jump, and then linear execution. 
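As a minimal sketch of the "jump table and fallthrough" idea described in the comment above (illustrative names, with a plain delegate standing in for the library's generic TUnaryOperator plumbing): a single switch on the length, with goto case chaining each size into the next smaller one, gives one length check, one jump, and then straight-line execution for 1..3 elements.

    using System;

    internal static class SmallPathSketch
    {
        // 'op' stands in for TUnaryOperator.Invoke; the switch falls through from
        // the largest handled size down to 1, so a 3-element input costs a single
        // branch rather than a loop with a per-iteration check.
        public static void Apply(ReadOnlySpan<float> x, Span<float> d, Func<float, float> op)
        {
            switch (x.Length)
            {
                case 3:
                    d[2] = op(x[2]);
                    goto case 2;
                case 2:
                    d[1] = op(x[1]);
                    goto case 1;
                case 1:
                    d[0] = op(x[0]);
                    break;
                default:
                    break; // length 0: nothing to do
            }
        }
    }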
+ + Vectorized128Small(ref xRef, ref dRef, remainder); + } + + return; + } + + // This is the software fallback when no acceleration is available + // It requires no branches to hit + + SoftwareFallback(ref xRef, ref dRef, remainder); + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + static void SoftwareFallback(ref float xRef, ref float dRef, nuint length) + { + for (nuint i = 0; i < length; i++) + { + Unsafe.Add(ref dRef, i) = TUnaryOperator.Invoke(Unsafe.Add(ref xRef, i)); + } + } + + static void Vectorized128(ref float xRef, ref float dRef, nuint remainder) + { + ref float dRefBeg = ref dRef; + + // Preload the beginning and end so that overlapping accesses don't negatively impact the data + + Vector128 beg = TUnaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef)); + Vector128 end = TUnaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count))); + + if (remainder > (uint)(Vector128.Count * 8)) + { + // Pinning is cheap and will be short lived for small inputs and unlikely to be impactful + // for large inputs (> 85KB) which are on the LOH and unlikely to be compacted. + + fixed (float* px = &xRef) + fixed (float* pd = &dRef) { - TUnaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, (uint)i)).StoreUnsafe(ref dRef, (uint)i); + float* xPtr = px; + float* dPtr = pd; - i += Vector512.Count; + // We need to the ensure the underlying data can be aligned and only align + // it if it can. It is possible we have an unaligned ref, in which case we + // can never achieve the required SIMD alignment. + + bool canAlign = ((nuint)(dPtr) % sizeof(float)) == 0; + + if (canAlign) + { + // Compute by how many elements we're misaligned and adjust the pointers accordingly + // + // Noting that we are only actually aligning dPtr. THis is because unaligned stores + // are more expensive than unaligned loads and aligning both is significantly more + // complex. + + nuint misalignment = ((uint)(sizeof(Vector128)) - ((nuint)(dPtr) % (uint)(sizeof(Vector128)))) / sizeof(float); + + xPtr += misalignment; + dPtr += misalignment; + + Debug.Assert(((nuint)(dPtr) % (uint)(sizeof(Vector128))) == 0); + + remainder -= misalignment; + } + + Vector128 vector1; + Vector128 vector2; + Vector128 vector3; + Vector128 vector4; + + if (canAlign && (remainder > (NonTemporalByteThreshold / sizeof(float)))) + { + // This loop stores the data non-temporally, which benefits us when there + // is a large amount of data involved as it avoids polluting the cache. 
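A worked example of the two guards that lead into this loop, using illustrative numbers rather than anything from the patch: with a destination address of 0x1004, which is 4-byte aligned but not 16-byte aligned, the misalignment formula gives (16 - 0x1004 % 16) / 4 = 3 elements, so advancing both pointers by 3 floats (12 bytes) lands dPtr on a 16-byte boundary; the destination elements skipped by that adjustment are still written, because the preloaded beg vector is stored over the original start in the default switch case. The 256 * 1024 byte threshold corresponds to 65,536 floats, so smaller inputs stay on the ordinary aligned-store loop instead of the non-temporal one.

    using System;

    internal static class AlignmentSketch
    {
        // Worked example of the misalignment computation above, using an assumed
        // destination address; vectorBytes models sizeof(Vector128<float>).
        public static void Main()
        {
            ulong dAddr = 0x1004;
            const int vectorBytes = 16;
            const int elementBytes = sizeof(float);

            ulong misalignment = ((ulong)vectorBytes - dAddr % (ulong)vectorBytes) / elementBytes;
            Console.WriteLine(misalignment);                               // 3 (elements, i.e. 12 bytes)
            Console.WriteLine((dAddr + misalignment * elementBytes) % 16); // 0: dPtr is now aligned

            // Element count needed before the non-temporal path engages:
            Console.WriteLine(256 * 1024 / sizeof(float));                 // 65536
        }
    }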
+ + while (remainder >= (uint)(Vector128.Count * 8)) + { + // We load, process, and store the first four vectors + + vector1 = TUnaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 0))); + vector2 = TUnaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 1))); + vector3 = TUnaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 2))); + vector4 = TUnaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 3))); + + vector1.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 0)); + vector2.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 1)); + vector3.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 2)); + vector4.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 3)); + + // We load, process, and store the next four vectors + + vector1 = TUnaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 4))); + vector2 = TUnaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 5))); + vector3 = TUnaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 6))); + vector4 = TUnaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 7))); + + vector1.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 4)); + vector2.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 5)); + vector3.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 6)); + vector4.StoreAlignedNonTemporal(dPtr + (uint)(Vector128.Count * 7)); + + // We adjust the source and destination references, then update + // the count of remaining elements to process. + + xPtr += (uint)(Vector128.Count * 8); + dPtr += (uint)(Vector128.Count * 8); + + remainder -= (uint)(Vector128.Count * 8); + } + } + else + { + while (remainder >= (uint)(Vector128.Count * 8)) + { + // We load, process, and store the first four vectors + + vector1 = TUnaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 0))); + vector2 = TUnaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 1))); + vector3 = TUnaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 2))); + vector4 = TUnaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 3))); + + vector1.Store(dPtr + (uint)(Vector128.Count * 0)); + vector2.Store(dPtr + (uint)(Vector128.Count * 1)); + vector3.Store(dPtr + (uint)(Vector128.Count * 2)); + vector4.Store(dPtr + (uint)(Vector128.Count * 3)); + + // We load, process, and store the next four vectors + + vector1 = TUnaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 4))); + vector2 = TUnaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 5))); + vector3 = TUnaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 6))); + vector4 = TUnaryOperator.Invoke(Vector128.Load(xPtr + (uint)(Vector128.Count * 7))); + + vector1.Store(dPtr + (uint)(Vector128.Count * 4)); + vector2.Store(dPtr + (uint)(Vector128.Count * 5)); + vector3.Store(dPtr + (uint)(Vector128.Count * 6)); + vector4.Store(dPtr + (uint)(Vector128.Count * 7)); + + // We adjust the source and destination references, then update + // the count of remaining elements to process. + + xPtr += (uint)(Vector128.Count * 8); + dPtr += (uint)(Vector128.Count * 8); + + remainder -= (uint)(Vector128.Count * 8); + } + } + + // Adjusting the refs here allows us to avoid pinning for very small inputs + + xRef = ref *xPtr; + dRef = ref *dPtr; } - while (i <= oneVectorFromEnd); + } - // Handle any remaining elements with a final vector. 
- if (i != x.Length) + // Process the remaining [Count, Count * 8] elements via a jump table + // + // Unless the original length was an exact multiple of Count, then we'll + // end up reprocessing a couple elements in case 1 for end. We'll also + // potentially reprocess a few elements in case 0 for beg, to handle any + // data before the first aligned address. + + nuint endIndex = remainder; + remainder = (remainder + (uint)(Vector128.Count - 1)) & (nuint)(-Vector128.Count); + + switch (remainder / (uint)(Vector128.Count)) + { + case 8: { - uint lastVectorIndex = (uint)(x.Length - Vector512.Count); - Vector512.ConditionalSelect( - Vector512.Equals(CreateRemainderMaskSingleVector512(x.Length - i), Vector512.Zero), - Vector512.LoadUnsafe(ref dRef, lastVectorIndex), - TUnaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, lastVectorIndex))).StoreUnsafe(ref dRef, lastVectorIndex); + Vector128 vector = TUnaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 8))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 8)); + goto case 7; } - return; + case 7: + { + Vector128 vector = TUnaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 7))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 7)); + goto case 6; + } + + case 6: + { + Vector128 vector = TUnaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 6))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 6)); + goto case 5; + } + + case 5: + { + Vector128 vector = TUnaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 5))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 5)); + goto case 4; + } + + case 4: + { + Vector128 vector = TUnaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 4))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 4)); + goto case 3; + } + + case 3: + { + Vector128 vector = TUnaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 3))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 3)); + goto case 2; + } + + case 2: + { + Vector128 vector = TUnaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count * 2))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count * 2)); + goto case 1; + } + + case 1: + { + // Store the last block, which includes any elements that wouldn't fill a full vector + end.StoreUnsafe(ref dRef, endIndex - (uint)Vector128.Count); + goto default; + } + + default: + { + // Store the first block, which includes any elements preceding the first aligned block + beg.StoreUnsafe(ref dRefBeg); + break; + } } } -#endif - if (Vector256.IsHardwareAccelerated) + [MethodImpl(MethodImplOptions.AggressiveInlining)] + static void Vectorized128Small(ref float xRef, ref float dRef, nuint remainder) { - oneVectorFromEnd = x.Length - Vector256.Count; - if (i <= oneVectorFromEnd) + switch (remainder) { - // Loop handling one vector at a time. - do + case 3: { - TUnaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, (uint)i)).StoreUnsafe(ref dRef, (uint)i); + Unsafe.Add(ref dRef, 2) = TUnaryOperator.Invoke(Unsafe.Add(ref xRef, 2)); + goto case 2; + } - i += Vector256.Count; + case 2: + { + Unsafe.Add(ref dRef, 1) = TUnaryOperator.Invoke(Unsafe.Add(ref xRef, 1)); + goto case 1; } - while (i <= oneVectorFromEnd); - // Handle any remaining elements with a final vector. 
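A worked example of the [Count, Count * 8] remainder jump table shown above for Vector128 (illustrative numbers: Count = 4, a 13-element tail): endIndex is 13, remainder rounds up to (13 + 3) & -4 = 16, and 16 / 4 = 4 selects case 4. Cases 4, 3, and 2 store full vectors at element offsets 0, 4, and 8; case 1 stores the preloaded end vector at 13 - 4 = 9; default stores beg. Every element 0..12 is written, and the overlapped elements are simply written twice with identical results, which is harmless for an element-wise operator.

    using System;

    internal static class TailSketch
    {
        // Recomputes the offsets from the example above for Count = 4 and a
        // 13-element tail (illustrative numbers only).
        public static void Main()
        {
            const int count = 4;
            int endIndex = 13;
            int remainder = (endIndex + (count - 1)) & -count;

            Console.WriteLine(remainder / count);     // 4: the switch enters at case 4
            Console.WriteLine(remainder - count * 4); // 0
            Console.WriteLine(remainder - count * 3); // 4
            Console.WriteLine(remainder - count * 2); // 8
            Console.WriteLine(endIndex - count);      // 9: offset of the 'end' store
        }
    }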
- if (i != x.Length) + case 1: { - uint lastVectorIndex = (uint)(x.Length - Vector256.Count); - Vector256.ConditionalSelect( - Vector256.Equals(CreateRemainderMaskSingleVector256(x.Length - i), Vector256.Zero), - Vector256.LoadUnsafe(ref dRef, lastVectorIndex), - TUnaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, lastVectorIndex))).StoreUnsafe(ref dRef, lastVectorIndex); + dRef = TUnaryOperator.Invoke(xRef); + break; } - return; + default: + { + Debug.Assert(remainder == 0); + break; + } } } - if (Vector128.IsHardwareAccelerated) + static void Vectorized256(ref float xRef, ref float dRef, nuint remainder) { - oneVectorFromEnd = x.Length - Vector128.Count; - if (i <= oneVectorFromEnd) + ref float dRefBeg = ref dRef; + + // Preload the beginning and end so that overlapping accesses don't negatively impact the data + + Vector256 beg = TUnaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef)); + Vector256 end = TUnaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count))); + + if (remainder > (uint)(Vector256.Count * 8)) { - // Loop handling one vector at a time. - do + // Pinning is cheap and will be short lived for small inputs and unlikely to be impactful + // for large inputs (> 85KB) which are on the LOH and unlikely to be compacted. + + fixed (float* px = &xRef) + fixed (float* pd = &dRef) { - TUnaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, (uint)i)).StoreUnsafe(ref dRef, (uint)i); + float* xPtr = px; + float* dPtr = pd; - i += Vector128.Count; + // We need to the ensure the underlying data can be aligned and only align + // it if it can. It is possible we have an unaligned ref, in which case we + // can never achieve the required SIMD alignment. + + bool canAlign = ((nuint)(dPtr) % sizeof(float)) == 0; + + if (canAlign) + { + // Compute by how many elements we're misaligned and adjust the pointers accordingly + // + // Noting that we are only actually aligning dPtr. THis is because unaligned stores + // are more expensive than unaligned loads and aligning both is significantly more + // complex. + + nuint misalignment = ((uint)(sizeof(Vector256)) - ((nuint)(dPtr) % (uint)(sizeof(Vector256)))) / sizeof(float); + + xPtr += misalignment; + dPtr += misalignment; + + Debug.Assert(((nuint)(dPtr) % (uint)(sizeof(Vector256))) == 0); + + remainder -= misalignment; + } + + Vector256 vector1; + Vector256 vector2; + Vector256 vector3; + Vector256 vector4; + + if (canAlign && (remainder > (NonTemporalByteThreshold / sizeof(float)))) + { + // This loop stores the data non-temporally, which benefits us when there + // is a large amount of data involved as it avoids polluting the cache. 
+ + while (remainder >= (uint)(Vector256.Count * 8)) + { + // We load, process, and store the first four vectors + + vector1 = TUnaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 0))); + vector2 = TUnaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 1))); + vector3 = TUnaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 2))); + vector4 = TUnaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 3))); + + vector1.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 0)); + vector2.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 1)); + vector3.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 2)); + vector4.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 3)); + + // We load, process, and store the next four vectors + + vector1 = TUnaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 4))); + vector2 = TUnaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 5))); + vector3 = TUnaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 6))); + vector4 = TUnaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 7))); + + vector1.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 4)); + vector2.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 5)); + vector3.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 6)); + vector4.StoreAlignedNonTemporal(dPtr + (uint)(Vector256.Count * 7)); + + // We adjust the source and destination references, then update + // the count of remaining elements to process. + + xPtr += (uint)(Vector256.Count * 8); + dPtr += (uint)(Vector256.Count * 8); + + remainder -= (uint)(Vector256.Count * 8); + } + } + else + { + while (remainder >= (uint)(Vector256.Count * 8)) + { + // We load, process, and store the first four vectors + + vector1 = TUnaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 0))); + vector2 = TUnaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 1))); + vector3 = TUnaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 2))); + vector4 = TUnaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 3))); + + vector1.Store(dPtr + (uint)(Vector256.Count * 0)); + vector2.Store(dPtr + (uint)(Vector256.Count * 1)); + vector3.Store(dPtr + (uint)(Vector256.Count * 2)); + vector4.Store(dPtr + (uint)(Vector256.Count * 3)); + + // We load, process, and store the next four vectors + + vector1 = TUnaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 4))); + vector2 = TUnaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 5))); + vector3 = TUnaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 6))); + vector4 = TUnaryOperator.Invoke(Vector256.Load(xPtr + (uint)(Vector256.Count * 7))); + + vector1.Store(dPtr + (uint)(Vector256.Count * 4)); + vector2.Store(dPtr + (uint)(Vector256.Count * 5)); + vector3.Store(dPtr + (uint)(Vector256.Count * 6)); + vector4.Store(dPtr + (uint)(Vector256.Count * 7)); + + // We adjust the source and destination references, then update + // the count of remaining elements to process. + + xPtr += (uint)(Vector256.Count * 8); + dPtr += (uint)(Vector256.Count * 8); + + remainder -= (uint)(Vector256.Count * 8); + } + } + + // Adjusting the refs here allows us to avoid pinning for very small inputs + + xRef = ref *xPtr; + dRef = ref *dPtr; } - while (i <= oneVectorFromEnd); + } - // Handle any remaining elements with a final vector. 
- if (i != x.Length) + // Process the remaining [Count, Count * 8] elements via a jump table + // + // Unless the original length was an exact multiple of Count, then we'll + // end up reprocessing a couple elements in case 1 for end. We'll also + // potentially reprocess a few elements in case 0 for beg, to handle any + // data before the first aligned address. + + nuint endIndex = remainder; + remainder = (remainder + (uint)(Vector256.Count - 1)) & (nuint)(-Vector256.Count); + + switch (remainder / (uint)(Vector256.Count)) + { + case 8: { - uint lastVectorIndex = (uint)(x.Length - Vector128.Count); - Vector128.ConditionalSelect( - Vector128.Equals(CreateRemainderMaskSingleVector128(x.Length - i), Vector128.Zero), - Vector128.LoadUnsafe(ref dRef, lastVectorIndex), - TUnaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, lastVectorIndex))).StoreUnsafe(ref dRef, lastVectorIndex); + Vector256 vector = TUnaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 8))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 8)); + goto case 7; } - return; + case 7: + { + Vector256 vector = TUnaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 7))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 7)); + goto case 6; + } + + case 6: + { + Vector256 vector = TUnaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 6))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 6)); + goto case 5; + } + + case 5: + { + Vector256 vector = TUnaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 5))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 5)); + goto case 4; + } + + case 4: + { + Vector256 vector = TUnaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 4))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 4)); + goto case 3; + } + + case 3: + { + Vector256 vector = TUnaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 3))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 3)); + goto case 2; + } + + case 2: + { + Vector256 vector = TUnaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count * 2))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count * 2)); + goto case 1; + } + + case 1: + { + // Store the last block, which includes any elements that wouldn't fill a full vector + end.StoreUnsafe(ref dRef, endIndex - (uint)Vector256.Count); + goto default; + } + + default: + { + // Store the first block, which includes any elements preceding the first aligned block + beg.StoreUnsafe(ref dRefBeg); + break; + } } } - while (i < x.Length) + [MethodImpl(MethodImplOptions.AggressiveInlining)] + static void Vectorized256Small(ref float xRef, ref float dRef, nuint remainder) { - Unsafe.Add(ref dRef, i) = TUnaryOperator.Invoke(Unsafe.Add(ref xRef, i)); + switch (remainder) + { + case 7: + case 6: + case 5: + { + Debug.Assert(Vector128.IsHardwareAccelerated); - i++; + Vector128 beg = TUnaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef)); + Vector128 end = TUnaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count))); + + beg.StoreUnsafe(ref dRef); + end.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count)); + + break; + } + + case 4: + { + Debug.Assert(Vector128.IsHardwareAccelerated); + + Vector128 beg = 
TUnaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef)); + beg.StoreUnsafe(ref dRef); + + break; + } + + case 3: + { + Unsafe.Add(ref dRef, 2) = TUnaryOperator.Invoke(Unsafe.Add(ref xRef, 2)); + goto case 2; + } + + case 2: + { + Unsafe.Add(ref dRef, 1) = TUnaryOperator.Invoke(Unsafe.Add(ref xRef, 1)); + goto case 1; + } + + case 1: + { + dRef = TUnaryOperator.Invoke(xRef); + break; + } + + default: + { + Debug.Assert(remainder == 0); + break; + } + } + } + +#if NET8_0_OR_GREATER + static void Vectorized512(ref float xRef, ref float dRef, nuint remainder) + { + ref float dRefBeg = ref dRef; + + // Preload the beginning and end so that overlapping accesses don't negatively impact the data + + Vector512 beg = TUnaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef)); + Vector512 end = TUnaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count))); + + if (remainder > (uint)(Vector512.Count * 8)) + { + // Pinning is cheap and will be short lived for small inputs and unlikely to be impactful + // for large inputs (> 85KB) which are on the LOH and unlikely to be compacted. + + fixed (float* px = &xRef) + fixed (float* pd = &dRef) + { + float* xPtr = px; + float* dPtr = pd; + + // We need to the ensure the underlying data can be aligned and only align + // it if it can. It is possible we have an unaligned ref, in which case we + // can never achieve the required SIMD alignment. + + bool canAlign = ((nuint)(dPtr) % sizeof(float)) == 0; + + if (canAlign) + { + // Compute by how many elements we're misaligned and adjust the pointers accordingly + // + // Noting that we are only actually aligning dPtr. THis is because unaligned stores + // are more expensive than unaligned loads and aligning both is significantly more + // complex. + + nuint misalignment = ((uint)(sizeof(Vector512)) - ((nuint)(dPtr) % (uint)(sizeof(Vector512)))) / sizeof(float); + + xPtr += misalignment; + dPtr += misalignment; + + Debug.Assert(((nuint)(dPtr) % (uint)(sizeof(Vector512))) == 0); + + remainder -= misalignment; + } + + Vector512 vector1; + Vector512 vector2; + Vector512 vector3; + Vector512 vector4; + + if (canAlign && (remainder > (NonTemporalByteThreshold / sizeof(float)))) + { + // This loop stores the data non-temporally, which benefits us when there + // is a large amount of data involved as it avoids polluting the cache. 
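One note on the VectorizedNNNSmall helpers above before the Vector512 loops: in-between lengths are handled with two overlapping vector operations rather than a scalar loop. For example (illustrative length), Vectorized256Small with 6 elements processes a Vector128 at offset 0 and another at offset 6 - 4 = 2; the two stores overlap on elements 2..3 but write identical values. A minimal standalone sketch of that idea, using negation as a stand-in for the patch's generic operator:

    using System;
    using System.Runtime.Intrinsics;

    internal static class OverlapSketch
    {
        // Handles 5..8 floats with two possibly overlapping Vector128 operations,
        // mirroring the beg/end pattern used by Vectorized256Small above.
        public static void Negate(ReadOnlySpan<float> x, Span<float> d)
        {
            if (x.Length < 5 || x.Length > 8 || d.Length < x.Length)
                throw new ArgumentException("sketch expects 5..8 elements");

            int last = x.Length - Vector128<float>.Count;

            Vector128<float> beg = -Vector128.Create(x);             // elements 0..3
            Vector128<float> end = -Vector128.Create(x.Slice(last)); // last four elements

            beg.CopyTo(d);
            end.CopyTo(d.Slice(last)); // the overlap rewrites identical values
        }
    }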
+ + while (remainder >= (uint)(Vector512.Count * 8)) + { + // We load, process, and store the first four vectors + + vector1 = TUnaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 0))); + vector2 = TUnaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 1))); + vector3 = TUnaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 2))); + vector4 = TUnaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 3))); + + vector1.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 0)); + vector2.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 1)); + vector3.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 2)); + vector4.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 3)); + + // We load, process, and store the next four vectors + + vector1 = TUnaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 4))); + vector2 = TUnaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 5))); + vector3 = TUnaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 6))); + vector4 = TUnaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 7))); + + vector1.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 4)); + vector2.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 5)); + vector3.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 6)); + vector4.StoreAlignedNonTemporal(dPtr + (uint)(Vector512.Count * 7)); + + // We adjust the source and destination references, then update + // the count of remaining elements to process. + + xPtr += (uint)(Vector512.Count * 8); + dPtr += (uint)(Vector512.Count * 8); + + remainder -= (uint)(Vector512.Count * 8); + } + } + else + { + while (remainder >= (uint)(Vector512.Count * 8)) + { + // We load, process, and store the first four vectors + + vector1 = TUnaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 0))); + vector2 = TUnaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 1))); + vector3 = TUnaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 2))); + vector4 = TUnaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 3))); + + vector1.Store(dPtr + (uint)(Vector512.Count * 0)); + vector2.Store(dPtr + (uint)(Vector512.Count * 1)); + vector3.Store(dPtr + (uint)(Vector512.Count * 2)); + vector4.Store(dPtr + (uint)(Vector512.Count * 3)); + + // We load, process, and store the next four vectors + + vector1 = TUnaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 4))); + vector2 = TUnaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 5))); + vector3 = TUnaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 6))); + vector4 = TUnaryOperator.Invoke(Vector512.Load(xPtr + (uint)(Vector512.Count * 7))); + + vector1.Store(dPtr + (uint)(Vector512.Count * 4)); + vector2.Store(dPtr + (uint)(Vector512.Count * 5)); + vector3.Store(dPtr + (uint)(Vector512.Count * 6)); + vector4.Store(dPtr + (uint)(Vector512.Count * 7)); + + // We adjust the source and destination references, then update + // the count of remaining elements to process. 
+ + xPtr += (uint)(Vector512.Count * 8); + dPtr += (uint)(Vector512.Count * 8); + + remainder -= (uint)(Vector512.Count * 8); + } + } + + // Adjusting the refs here allows us to avoid pinning for very small inputs + + xRef = ref *xPtr; + dRef = ref *dPtr; + } + } + + // Process the remaining [Count, Count * 8] elements via a jump table + // + // Unless the original length was an exact multiple of Count, then we'll + // end up reprocessing a couple elements in case 1 for end. We'll also + // potentially reprocess a few elements in case 0 for beg, to handle any + // data before the first aligned address. + + nuint endIndex = remainder; + remainder = (remainder + (uint)(Vector512.Count - 1)) & (nuint)(-Vector512.Count); + + switch (remainder / (uint)(Vector512.Count)) + { + case 8: + { + Vector512 vector = TUnaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 8))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512.Count * 8)); + goto case 7; + } + + case 7: + { + Vector512 vector = TUnaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 7))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512.Count * 7)); + goto case 6; + } + + case 6: + { + Vector512 vector = TUnaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 6))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512.Count * 6)); + goto case 5; + } + + case 5: + { + Vector512 vector = TUnaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 5))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512.Count * 5)); + goto case 4; + } + + case 4: + { + Vector512 vector = TUnaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 4))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512.Count * 4)); + goto case 3; + } + + case 3: + { + Vector512 vector = TUnaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 3))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512.Count * 3)); + goto case 2; + } + + case 2: + { + Vector512 vector = TUnaryOperator.Invoke(Vector512.LoadUnsafe(ref xRef, remainder - (uint)(Vector512.Count * 2))); + vector.StoreUnsafe(ref dRef, remainder - (uint)(Vector512.Count * 2)); + goto case 1; + } + + case 1: + { + // Store the last block, which includes any elements that wouldn't fill a full vector + end.StoreUnsafe(ref dRef, endIndex - (uint)Vector512.Count); + goto default; + } + + default: + { + // Store the first block, which includes any elements preceding the first aligned block + beg.StoreUnsafe(ref dRefBeg); + break; + } + } + } + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + static void Vectorized512Small(ref float xRef, ref float dRef, nuint remainder) + { + switch (remainder) + { + case 15: + case 14: + case 13: + case 12: + case 11: + case 10: + case 9: + { + Debug.Assert(Vector256.IsHardwareAccelerated); + + Vector256 beg = TUnaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef)); + Vector256 end = TUnaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef, remainder - (uint)(Vector256.Count))); + + beg.StoreUnsafe(ref dRef); + end.StoreUnsafe(ref dRef, remainder - (uint)(Vector256.Count)); + + break; + } + + case 8: + { + Debug.Assert(Vector256.IsHardwareAccelerated); + + Vector256 beg = TUnaryOperator.Invoke(Vector256.LoadUnsafe(ref xRef)); + beg.StoreUnsafe(ref dRef); + + break; + } + + case 7: + case 6: + case 5: + { + 
Debug.Assert(Vector128.IsHardwareAccelerated); + + Vector128 beg = TUnaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef)); + Vector128 end = TUnaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef, remainder - (uint)(Vector128.Count))); + + beg.StoreUnsafe(ref dRef); + end.StoreUnsafe(ref dRef, remainder - (uint)(Vector128.Count)); + + break; + } + + case 4: + { + Debug.Assert(Vector128.IsHardwareAccelerated); + + Vector128 beg = TUnaryOperator.Invoke(Vector128.LoadUnsafe(ref xRef)); + beg.StoreUnsafe(ref dRef); + + break; + } + + case 3: + { + Unsafe.Add(ref dRef, 2) = TUnaryOperator.Invoke(Unsafe.Add(ref xRef, 2)); + goto case 2; + } + + case 2: + { + Unsafe.Add(ref dRef, 1) = TUnaryOperator.Invoke(Unsafe.Add(ref xRef, 1)); + goto case 1; + } + + case 1: + { + dRef = TUnaryOperator.Invoke(xRef); + break; + } + + default: + { + Debug.Assert(remainder == 0); + break; + } + } } +#endif } /// diff --git a/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/TensorPrimitives.netstandard.cs b/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/TensorPrimitives.netstandard.cs index 43cb3d5770ead..cde38a70fbbef 100644 --- a/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/TensorPrimitives.netstandard.cs +++ b/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/TensorPrimitives.netstandard.cs @@ -313,7 +313,7 @@ private static float MinMaxCore(ReadOnlySpan x, TMinMaxO /// Performs an element-wise operation on and writes the results to . /// Specifies the operation to perform on each element loaded from . - private static void InvokeSpanIntoSpan( + private static unsafe void InvokeSpanIntoSpan( ReadOnlySpan x, Span destination, TUnaryOperator op = default) where TUnaryOperator : struct, IUnaryOperator { @@ -324,45 +324,269 @@ private static void InvokeSpanIntoSpan( ValidateInputOutputSpanNonOverlapping(x, destination); + // Since every branch has a cost and since that cost is + // essentially lost for larger inputs, we do branches + // in a way that allows us to have the minimum possible + // for small sizes + ref float xRef = ref MemoryMarshal.GetReference(x); ref float dRef = ref MemoryMarshal.GetReference(destination); - int i = 0, oneVectorFromEnd; + + nuint remainder = (uint)(x.Length); if (Vector.IsHardwareAccelerated && op.CanVectorize) { - oneVectorFromEnd = x.Length - Vector.Count; - if (oneVectorFromEnd >= 0) + if (remainder >= (uint)(Vector.Count)) { - // Loop handling one vector at a time. - do + Vectorized(ref xRef, ref dRef, remainder); + } + else + { + // We have less than a vector and so we can only handle this as scalar. To do this + // efficiently, we simply have a small jump table and fallthrough. So we get a simple + // length check, single jump, and then linear execution. 
+ + VectorizedSmall(ref xRef, ref dRef, remainder); + } + + return; + } + + // This is the software fallback when no acceleration is available + // It requires no branches to hit + + SoftwareFallback(ref xRef, ref dRef, remainder); + + [MethodImpl(MethodImplOptions.AggressiveInlining)] + static void SoftwareFallback(ref float xRef, ref float dRef, nuint length, TUnaryOperator op = default) + { + for (nuint i = 0; i < length; i++) + { + Unsafe.Add(ref dRef, (nint)(i)) = op.Invoke(Unsafe.Add(ref xRef, (nint)(i))); + } + } + + static void Vectorized(ref float xRef, ref float dRef, nuint remainder, TUnaryOperator op = default) + { + ref float dRefBeg = ref dRef; + + // Preload the beginning and end so that overlapping accesses don't negatively impact the data + + Vector beg = op.Invoke(AsVector(ref xRef)); + Vector end = op.Invoke(AsVector(ref xRef, remainder - (uint)(Vector.Count))); + + if (remainder > (uint)(Vector.Count * 8)) + { + // Pinning is cheap and will be short lived for small inputs and unlikely to be impactful + // for large inputs (> 85KB) which are on the LOH and unlikely to be compacted. + + fixed (float* px = &xRef) + fixed (float* pd = &dRef) { - AsVector(ref dRef, i) = op.Invoke(AsVector(ref xRef, i)); + float* xPtr = px; + float* dPtr = pd; - i += Vector.Count; + // We need to the ensure the underlying data can be aligned and only align + // it if it can. It is possible we have an unaligned ref, in which case we + // can never achieve the required SIMD alignment. + + bool canAlign = ((nuint)(dPtr) % sizeof(float)) == 0; + + if (canAlign) + { + // Compute by how many elements we're misaligned and adjust the pointers accordingly + // + // Noting that we are only actually aligning dPtr. THis is because unaligned stores + // are more expensive than unaligned loads and aligning both is significantly more + // complex. + + nuint misalignment = ((uint)(sizeof(Vector)) - ((nuint)(dPtr) % (uint)(sizeof(Vector)))) / sizeof(float); + + xPtr += misalignment; + dPtr += misalignment; + + Debug.Assert(((nuint)(dPtr) % (uint)(sizeof(Vector))) == 0); + + remainder -= misalignment; + } + + Vector vector1; + Vector vector2; + Vector vector3; + Vector vector4; + + while (remainder >= (uint)(Vector.Count * 8)) + { + // We load, process, and store the first four vectors + + vector1 = op.Invoke(*(Vector*)(xPtr + (uint)(Vector.Count * 0))); + vector2 = op.Invoke(*(Vector*)(xPtr + (uint)(Vector.Count * 1))); + vector3 = op.Invoke(*(Vector*)(xPtr + (uint)(Vector.Count * 2))); + vector4 = op.Invoke(*(Vector*)(xPtr + (uint)(Vector.Count * 3))); + + *(Vector*)(dPtr + (uint)(Vector.Count * 0)) = vector1; + *(Vector*)(dPtr + (uint)(Vector.Count * 1)) = vector2; + *(Vector*)(dPtr + (uint)(Vector.Count * 2)) = vector3; + *(Vector*)(dPtr + (uint)(Vector.Count * 3)) = vector4; + + // We load, process, and store the next four vectors + + vector1 = op.Invoke(*(Vector*)(xPtr + (uint)(Vector.Count * 4))); + vector2 = op.Invoke(*(Vector*)(xPtr + (uint)(Vector.Count * 5))); + vector3 = op.Invoke(*(Vector*)(xPtr + (uint)(Vector.Count * 6))); + vector4 = op.Invoke(*(Vector*)(xPtr + (uint)(Vector.Count * 7))); + + *(Vector*)(dPtr + (uint)(Vector.Count * 4)) = vector1; + *(Vector*)(dPtr + (uint)(Vector.Count * 5)) = vector2; + *(Vector*)(dPtr + (uint)(Vector.Count * 6)) = vector3; + *(Vector*)(dPtr + (uint)(Vector.Count * 7)) = vector4; + + // We adjust the source and destination references, then update + // the count of remaining elements to process. 
+ + xPtr += (uint)(Vector.Count * 8); + dPtr += (uint)(Vector.Count * 8); + + remainder -= (uint)(Vector.Count * 8); + } + + // Adjusting the refs here allows us to avoid pinning for very small inputs + + xRef = ref *xPtr; + dRef = ref *dPtr; } - while (i <= oneVectorFromEnd); + } - // Handle any remaining elements with a final vector. - if (i != x.Length) + // Process the remaining [Count, Count * 8] elements via a jump table + // + // Unless the original length was an exact multiple of Count, then we'll + // end up reprocessing a couple elements in case 1 for end. We'll also + // potentially reprocess a few elements in case 0 for beg, to handle any + // data before the first aligned address. + + nuint endIndex = remainder; + remainder = (remainder + (uint)(Vector.Count - 1)) & (nuint)(-Vector.Count); + + switch (remainder / (uint)(Vector.Count)) + { + case 8: { - int lastVectorIndex = x.Length - Vector.Count; - ref Vector dest = ref AsVector(ref dRef, lastVectorIndex); - dest = Vector.ConditionalSelect( - Vector.Equals(CreateRemainderMaskSingleVector(x.Length - i), Vector.Zero), - dest, - op.Invoke(AsVector(ref xRef, lastVectorIndex))); + Vector vector = op.Invoke(AsVector(ref xRef, remainder - (uint)(Vector.Count * 8))); + AsVector(ref dRef, remainder - (uint)(Vector.Count * 8)) = vector; + goto case 7; } - return; + case 7: + { + Vector vector = op.Invoke(AsVector(ref xRef, remainder - (uint)(Vector.Count * 7))); + AsVector(ref dRef, remainder - (uint)(Vector.Count * 7)) = vector; + goto case 6; + } + + case 6: + { + Vector vector = op.Invoke(AsVector(ref xRef, remainder - (uint)(Vector.Count * 6))); + AsVector(ref dRef, remainder - (uint)(Vector.Count * 6)) = vector; + goto case 5; + } + + case 5: + { + Vector vector = op.Invoke(AsVector(ref xRef, remainder - (uint)(Vector.Count * 5))); + AsVector(ref dRef, remainder - (uint)(Vector.Count * 5)) = vector; + goto case 4; + } + + case 4: + { + Vector vector = op.Invoke(AsVector(ref xRef, remainder - (uint)(Vector.Count * 4))); + AsVector(ref dRef, remainder - (uint)(Vector.Count * 4)) = vector; + goto case 3; + } + + case 3: + { + Vector vector = op.Invoke(AsVector(ref xRef, remainder - (uint)(Vector.Count * 3))); + AsVector(ref dRef, remainder - (uint)(Vector.Count * 3)) = vector; + goto case 2; + } + + case 2: + { + Vector vector = op.Invoke(AsVector(ref xRef, remainder - (uint)(Vector.Count * 2))); + AsVector(ref dRef, remainder - (uint)(Vector.Count * 2)) = vector; + goto case 1; + } + + case 1: + { + // Store the last block, which includes any elements that wouldn't fill a full vector + AsVector(ref dRef, endIndex - (uint)Vector.Count) = end; + goto default; + } + + default: + { + // Store the first block, which includes any elements preceding the first aligned block + AsVector(ref dRefBeg) = beg; + break; + } } } - // Loop handling one element at a time. 
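One difference on this netstandard path compared to the netcore one above: Vector<T> exposes no LoadUnsafe, StoreUnsafe, or StoreAlignedNonTemporal, so the bulk loop goes through raw pointer casts, the remainder handling goes through the AsVector ref-reinterpretation helpers added later in this patch, and there is no separate non-temporal variant. A minimal usage sketch of that reinterpretation idea (the DoubleInPlace name and the doubling operation are illustrative):

    using System.Numerics;
    using System.Runtime.CompilerServices;

    internal static class AsVectorSketch
    {
        // Reinterprets a float ref at an element offset as a Vector<float> ref,
        // the same idea as the AsVector helpers used on the netstandard path.
        private static ref Vector<float> AsVector(ref float start, int offset) =>
            ref Unsafe.As<float, Vector<float>>(ref Unsafe.Add(ref start, offset));

        public static void DoubleInPlace(float[] values)
        {
            int i = 0;
            for (; i <= values.Length - Vector<float>.Count; i += Vector<float>.Count)
            {
                AsVector(ref values[0], i) *= 2f; // one full vector at a time
            }
            for (; i < values.Length; i++)
            {
                values[i] *= 2f;                  // scalar tail
            }
        }
    }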
- while (i < x.Length) + [MethodImpl(MethodImplOptions.AggressiveInlining)] + static void VectorizedSmall(ref float xRef, ref float dRef, nuint remainder, TUnaryOperator op = default) { - Unsafe.Add(ref dRef, i) = op.Invoke(Unsafe.Add(ref xRef, i)); + switch (remainder) + { + case 7: + { + Unsafe.Add(ref dRef, 6) = op.Invoke(Unsafe.Add(ref xRef, 6)); + goto case 6; + } - i++; + case 6: + { + Unsafe.Add(ref dRef, 5) = op.Invoke(Unsafe.Add(ref xRef, 5)); + goto case 5; + } + + case 5: + { + Unsafe.Add(ref dRef, 4) = op.Invoke(Unsafe.Add(ref xRef, 4)); + goto case 4; + } + + case 4: + { + Unsafe.Add(ref dRef, 3) = op.Invoke(Unsafe.Add(ref xRef, 3)); + goto case 3; + } + + case 3: + { + Unsafe.Add(ref dRef, 2) = op.Invoke(Unsafe.Add(ref xRef, 2)); + goto case 2; + } + + case 2: + { + Unsafe.Add(ref dRef, 1) = op.Invoke(Unsafe.Add(ref xRef, 1)); + goto case 1; + } + + case 1: + { + dRef = op.Invoke(xRef); + break; + } + + default: + { + Debug.Assert(remainder == 0); + break; + } + } } } @@ -743,12 +967,23 @@ private static void InvokeSpanScalarSpanIntoSpan( } } + /// Loads a from . + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private static ref Vector AsVector(ref float start) => + ref Unsafe.As>(ref start); + /// Loads a that begins at the specified from . [MethodImpl(MethodImplOptions.AggressiveInlining)] private static ref Vector AsVector(ref float start, int offset) => ref Unsafe.As>( ref Unsafe.Add(ref start, offset)); + /// Loads a that begins at the specified from . + [MethodImpl(MethodImplOptions.AggressiveInlining)] + private static ref Vector AsVector(ref float start, nuint offset) => + ref Unsafe.As>( + ref Unsafe.Add(ref start, (nint)(offset))); + /// Gets whether the specified is negative. private static unsafe bool IsNegative(float f) => *(int*)&f < 0; diff --git a/src/libraries/System.Numerics.Tensors/tests/TensorPrimitivesTests.cs b/src/libraries/System.Numerics.Tensors/tests/TensorPrimitivesTests.cs index 608d41dea5c5b..ed5317420ae87 100644 --- a/src/libraries/System.Numerics.Tensors/tests/TensorPrimitivesTests.cs +++ b/src/libraries/System.Numerics.Tensors/tests/TensorPrimitivesTests.cs @@ -19,7 +19,7 @@ public static partial class TensorPrimitivesTests TensorLengths.Concat(new object[][] { [0] }); public static IEnumerable TensorLengths => - from length in Enumerable.Range(1, 128) + from length in Enumerable.Range(1, 256) select new object[] { length }; public static IEnumerable VectorLengthAndIteratedRange(float min, float max, float increment)
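A plausible reading of the test change above (my inference; the patch does not state it): the 8-vector unrolled bulk loops only engage once an input exceeds Count * 8 elements, which for Vector512<float> is 128, so the previous upper bound of Range(1, 128) never reached that path; Range(1, 256) exercises lengths on both sides of every threshold.

    using System;
    using System.Runtime.Intrinsics;

    internal static class CoverageSketch
    {
        // The bulk loops require remainder > Count * 8; print the smallest length
        // that reaches them for each vector width (illustrative check only).
        public static void Main()
        {
            Console.WriteLine(Vector128<float>.Count * 8 + 1); // 33
            Console.WriteLine(Vector256<float>.Count * 8 + 1); // 65
            Console.WriteLine(Vector512<float>.Count * 8 + 1); // 129
        }
    }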