From fdff01f3deb370bfff54eaa2aff71a333506ff6e Mon Sep 17 00:00:00 2001 From: Stephen Toub Date: Fri, 29 Sep 2023 16:25:31 -0400 Subject: [PATCH] Vectorize TensorPrimitives.ConvertToSingle (#92779) * Vectorize TensorPrimitives.ConvertToSingle * Address PR feedback --- .../Tensors/TensorPrimitives.netcore.cs | 250 +++++++++++++++++- .../System.Private.CoreLib/src/System/Half.cs | 4 +- 2 files changed, 250 insertions(+), 4 deletions(-) diff --git a/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/TensorPrimitives.netcore.cs b/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/TensorPrimitives.netcore.cs index f257b3d4853ad..4cc29c70ce0bd 100644 --- a/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/TensorPrimitives.netcore.cs +++ b/src/libraries/System.Numerics.Tensors/src/System/Numerics/Tensors/TensorPrimitives.netcore.cs @@ -353,10 +353,256 @@ public static void ConvertToSingle(ReadOnlySpan source, Span destin ThrowHelper.ThrowArgument_DestinationTooShort(); } - for (int i = 0; i < source.Length; i++) + ref short sourceRef = ref Unsafe.As(ref MemoryMarshal.GetReference(source)); + ref float destinationRef = ref MemoryMarshal.GetReference(destination); + int i = 0, oneVectorFromEnd; + +#if NET8_0_OR_GREATER + if (Vector512.IsHardwareAccelerated) + { + oneVectorFromEnd = source.Length - Vector512.Count; + if (i <= oneVectorFromEnd) + { + // Loop handling one input vector / two output vectors at a time. + do + { + (Vector512 lower, Vector512 upper) = Vector512.Widen(Vector512.LoadUnsafe(ref sourceRef, (uint)i)); + HalfAsWidenedUInt32ToSingle_Vector512(lower.AsUInt32()).StoreUnsafe(ref destinationRef, (uint)i); + HalfAsWidenedUInt32ToSingle_Vector512(upper.AsUInt32()).StoreUnsafe(ref destinationRef, (uint)(i + Vector512.Count)); + + i += Vector512.Count; + } + while (i <= oneVectorFromEnd); + + // Handle any remaining elements with a final input vector. + if (i != source.Length) + { + i = source.Length - Vector512.Count; + + (Vector512 lower, Vector512 upper) = Vector512.Widen(Vector512.LoadUnsafe(ref sourceRef, (uint)i)); + HalfAsWidenedUInt32ToSingle_Vector512(lower.AsUInt32()).StoreUnsafe(ref destinationRef, (uint)i); + HalfAsWidenedUInt32ToSingle_Vector512(upper.AsUInt32()).StoreUnsafe(ref destinationRef, (uint)(i + Vector512.Count)); + } + + return; + } + } +#endif + + if (Vector256.IsHardwareAccelerated) + { + oneVectorFromEnd = source.Length - Vector256.Count; + if (i <= oneVectorFromEnd) + { + // Loop handling one input vector / two output vectors at a time. + do + { + (Vector256 lower, Vector256 upper) = Vector256.Widen(Vector256.LoadUnsafe(ref sourceRef, (uint)i)); + HalfAsWidenedUInt32ToSingle_Vector256(lower.AsUInt32()).StoreUnsafe(ref destinationRef, (uint)i); + HalfAsWidenedUInt32ToSingle_Vector256(upper.AsUInt32()).StoreUnsafe(ref destinationRef, (uint)(i + Vector256.Count)); + + i += Vector256.Count; + } + while (i <= oneVectorFromEnd); + + // Handle any remaining elements with a final input vector. + if (i != source.Length) + { + i = source.Length - Vector256.Count; + + (Vector256 lower, Vector256 upper) = Vector256.Widen(Vector256.LoadUnsafe(ref sourceRef, (uint)i)); + HalfAsWidenedUInt32ToSingle_Vector256(lower.AsUInt32()).StoreUnsafe(ref destinationRef, (uint)i); + HalfAsWidenedUInt32ToSingle_Vector256(upper.AsUInt32()).StoreUnsafe(ref destinationRef, (uint)(i + Vector256.Count)); + } + + return; + } + } + + if (Vector128.IsHardwareAccelerated) { - destination[i] = (float)source[i]; + oneVectorFromEnd = source.Length - Vector128.Count; + if (i <= oneVectorFromEnd) + { + // Loop handling one input vector / two output vectors at a time. + do + { + (Vector128 lower, Vector128 upper) = Vector128.Widen(Vector128.LoadUnsafe(ref sourceRef, (uint)i)); + HalfAsWidenedUInt32ToSingle_Vector128(lower.AsUInt32()).StoreUnsafe(ref destinationRef, (uint)i); + HalfAsWidenedUInt32ToSingle_Vector128(upper.AsUInt32()).StoreUnsafe(ref destinationRef, (uint)(i + Vector128.Count)); + + i += Vector128.Count; + } + while (i <= oneVectorFromEnd); + + // Handle any remaining elements with a final input vector. + if (i != source.Length) + { + i = source.Length - Vector128.Count; + + (Vector128 lower, Vector128 upper) = Vector128.Widen(Vector128.LoadUnsafe(ref sourceRef, (uint)i)); + HalfAsWidenedUInt32ToSingle_Vector128(lower.AsUInt32()).StoreUnsafe(ref destinationRef, (uint)i); + HalfAsWidenedUInt32ToSingle_Vector128(upper.AsUInt32()).StoreUnsafe(ref destinationRef, (uint)(i + Vector128.Count)); + } + + return; + } } + + while (i < source.Length) + { + Unsafe.Add(ref destinationRef, i) = (float)Unsafe.As(ref Unsafe.Add(ref sourceRef, i)); + i++; + } + + // This implements a vectorized version of the `explicit operator float(Half value) operator`. + // See detailed description of the algorithm used here: + // https://github.com/dotnet/runtime/blob/3bf40a378f00cb5bf18ff62796bc7097719b974c/src/libraries/System.Private.CoreLib/src/System/Half.cs#L1010-L1040 + // The cast operator converts a Half represented as uint to a float. This does the same, with an input VectorXx and an output VectorXx. + // The VectorXx is created by reading a vector of Halfs as a VectorXx then widened to two VectorXxs and cast to VectorXxs. + // We loop handling one input vector at a time, producing two output float vectors. + + #pragma warning disable IDE0059 // https://github.com/dotnet/roslyn/issues/44948 + const uint ExponentLowerBound = 0x3880_0000u; // The smallest positive normal number in Half, converted to Single + const uint ExponentOffset = 0x3800_0000u; // BitConverter.SingleToUInt32Bits(1.0f) - ((uint)BitConverter.HalfToUInt16Bits((Half)1.0f) << 13) + const uint SingleSignMask = 0x8000_0000; // float.SignMask; // Mask for sign bit in Single + const uint HalfExponentMask = 0x7C00; // Mask for exponent bits in Half + const uint HalfToSingleBitsMask = 0x0FFF_E000; // Mask for bits in Single converted from Half + #pragma warning restore IDE0059 + + static Vector128 HalfAsWidenedUInt32ToSingle_Vector128(Vector128 value) + { + // Extract sign bit of value + Vector128 sign = value & Vector128.Create(SingleSignMask); + + // Copy sign bit to upper bits + Vector128 bitValueInProcess = value; + + // Extract exponent bits of value (BiasedExponent is not for here as it performs unnecessary shift) + Vector128 offsetExponent = bitValueInProcess & Vector128.Create(HalfExponentMask); + + // ~0u when value is subnormal, 0 otherwise + Vector128 subnormalMask = Vector128.Equals(offsetExponent, Vector128.Zero); + + // ~0u when value is either Infinity or NaN, 0 otherwise + Vector128 infinityOrNaNMask = Vector128.Equals(offsetExponent, Vector128.Create(HalfExponentMask)); + + // 0x3880_0000u if value is subnormal, 0 otherwise + Vector128 maskedExponentLowerBound = subnormalMask & Vector128.Create(ExponentLowerBound); + + // 0x3880_0000u if value is subnormal, 0x3800_0000u otherwise + Vector128 offsetMaskedExponentLowerBound = Vector128.Create(ExponentOffset) | maskedExponentLowerBound; + + // Match the position of the boundary of exponent bits and fraction bits with IEEE 754 Binary32(Single) + bitValueInProcess = Vector128.ShiftLeft(bitValueInProcess, 13); + + // Double the offsetMaskedExponentLowerBound if value is either Infinity or NaN + offsetMaskedExponentLowerBound = Vector128.ConditionalSelect(Vector128.Equals(infinityOrNaNMask, Vector128.Zero), + offsetMaskedExponentLowerBound, + Vector128.ShiftLeft(offsetMaskedExponentLowerBound, 1)); + + // Extract exponent bits and fraction bits of value + bitValueInProcess &= Vector128.Create(HalfToSingleBitsMask); + + // Adjust exponent to match the range of exponent + bitValueInProcess += offsetMaskedExponentLowerBound; + + // If value is subnormal, remove unnecessary 1 on top of fraction bits. + Vector128 absoluteValue = (bitValueInProcess.AsSingle() - maskedExponentLowerBound.AsSingle()).AsUInt32(); + + // Merge sign bit with rest + return (absoluteValue | sign).AsSingle(); + } + + static Vector256 HalfAsWidenedUInt32ToSingle_Vector256(Vector256 value) + { + // Extract sign bit of value + Vector256 sign = value & Vector256.Create(SingleSignMask); + + // Copy sign bit to upper bits + Vector256 bitValueInProcess = value; + + // Extract exponent bits of value (BiasedExponent is not for here as it performs unnecessary shift) + Vector256 offsetExponent = bitValueInProcess & Vector256.Create(HalfExponentMask); + + // ~0u when value is subnormal, 0 otherwise + Vector256 subnormalMask = Vector256.Equals(offsetExponent, Vector256.Zero); + + // ~0u when value is either Infinity or NaN, 0 otherwise + Vector256 infinityOrNaNMask = Vector256.Equals(offsetExponent, Vector256.Create(HalfExponentMask)); + + // 0x3880_0000u if value is subnormal, 0 otherwise + Vector256 maskedExponentLowerBound = subnormalMask & Vector256.Create(ExponentLowerBound); + + // 0x3880_0000u if value is subnormal, 0x3800_0000u otherwise + Vector256 offsetMaskedExponentLowerBound = Vector256.Create(ExponentOffset) | maskedExponentLowerBound; + + // Match the position of the boundary of exponent bits and fraction bits with IEEE 754 Binary32(Single) + bitValueInProcess = Vector256.ShiftLeft(bitValueInProcess, 13); + + // Double the offsetMaskedExponentLowerBound if value is either Infinity or NaN + offsetMaskedExponentLowerBound = Vector256.ConditionalSelect(Vector256.Equals(infinityOrNaNMask, Vector256.Zero), + offsetMaskedExponentLowerBound, + Vector256.ShiftLeft(offsetMaskedExponentLowerBound, 1)); + + // Extract exponent bits and fraction bits of value + bitValueInProcess &= Vector256.Create(HalfToSingleBitsMask); + + // Adjust exponent to match the range of exponent + bitValueInProcess += offsetMaskedExponentLowerBound; + + // If value is subnormal, remove unnecessary 1 on top of fraction bits. + Vector256 absoluteValue = (bitValueInProcess.AsSingle() - maskedExponentLowerBound.AsSingle()).AsUInt32(); + + // Merge sign bit with rest + return (absoluteValue | sign).AsSingle(); + } + +#if NET8_0_OR_GREATER + static Vector512 HalfAsWidenedUInt32ToSingle_Vector512(Vector512 value) + { + // Extract sign bit of value + Vector512 sign = value & Vector512.Create(SingleSignMask); + + // Copy sign bit to upper bits + Vector512 bitValueInProcess = value; + + // Extract exponent bits of value (BiasedExponent is not for here as it performs unnecessary shift) + Vector512 offsetExponent = bitValueInProcess & Vector512.Create(HalfExponentMask); + + // ~0u when value is subnormal, 0 otherwise + Vector512 subnormalMask = Vector512.Equals(offsetExponent, Vector512.Zero); + + // ~0u when value is either Infinity or NaN, 0 otherwise + Vector512 infinityOrNaNMask = Vector512.Equals(offsetExponent, Vector512.Create(HalfExponentMask)); + + // 0x3880_0000u if value is subnormal, 0 otherwise + Vector512 maskedExponentLowerBound = subnormalMask & Vector512.Create(ExponentLowerBound); + + // 0x3880_0000u if value is subnormal, 0x3800_0000u otherwise + Vector512 offsetMaskedExponentLowerBound = Vector512.Create(ExponentOffset) | maskedExponentLowerBound; + + // Match the position of the boundary of exponent bits and fraction bits with IEEE 754 Binary32(Single) + bitValueInProcess = Vector512.ShiftLeft(bitValueInProcess, 13); + + // Double the offsetMaskedExponentLowerBound if value is either Infinity or NaN + offsetMaskedExponentLowerBound = Vector512.ConditionalSelect(Vector512.Equals(infinityOrNaNMask, Vector512.Zero), + offsetMaskedExponentLowerBound, + Vector512.ShiftLeft(offsetMaskedExponentLowerBound, 1)); + + // Extract exponent bits and fraction bits of value + bitValueInProcess &= Vector512.Create(HalfToSingleBitsMask); + + // Adjust exponent to match the range of exponent + bitValueInProcess += offsetMaskedExponentLowerBound; + + // If value is subnormal, remove unnecessary 1 on top of fraction bits. + Vector512 absoluteValue = (bitValueInProcess.AsSingle() - maskedExponentLowerBound.AsSingle()).AsUInt32(); + + // Merge sign bit with rest + return (absoluteValue | sign).AsSingle(); + } +#endif } private static float CosineSimilarityCore(ReadOnlySpan x, ReadOnlySpan y) diff --git a/src/libraries/System.Private.CoreLib/src/System/Half.cs b/src/libraries/System.Private.CoreLib/src/System/Half.cs index 8daa37bbab576..cd3e6ab3ed73c 100644 --- a/src/libraries/System.Private.CoreLib/src/System/Half.cs +++ b/src/libraries/System.Private.CoreLib/src/System/Half.cs @@ -1044,7 +1044,7 @@ public static explicit operator float(Half value) // BitConverter.SingleToUInt32Bits(1.0f) - ((uint)BitConverter.HalfToUInt16Bits((Half)1.0f) << 13) const uint ExponentOffset = 0x3800_0000u; // Mask for sign bit in Single - const uint FloatSignMask = float.SignMask; + const uint SingleSignMask = float.SignMask; // Mask for exponent bits in Half const uint HalfExponentMask = BiasedExponentMask; // Mask for bits in Single converted from Half @@ -1052,7 +1052,7 @@ public static explicit operator float(Half value) // Extract the internal representation of value short valueInInt16Bits = BitConverter.HalfToInt16Bits(value); // Extract sign bit of value - uint sign = (uint)(int)valueInInt16Bits & FloatSignMask; + uint sign = (uint)(int)valueInInt16Bits & SingleSignMask; // Copy sign bit to upper bits uint bitValueInProcess = (uint)valueInInt16Bits; // Extract exponent bits of value (BiasedExponent is not for here as it performs unnecessary shift)