Skip to content

Commit

Permalink
Vectorize TensorPrimitives.ConvertToSingle (dotnet#92779)
Browse files Browse the repository at this point in the history
* Vectorize TensorPrimitives.ConvertToSingle

* Address PR feedback
  • Loading branch information
stephentoub authored and michaelgsharp committed Oct 20, 2023
1 parent 02416c2 commit fdff01f
Show file tree
Hide file tree
Showing 2 changed files with 250 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -353,10 +353,256 @@ public static void ConvertToSingle(ReadOnlySpan<Half> source, Span<float> destin
ThrowHelper.ThrowArgument_DestinationTooShort();
}

for (int i = 0; i < source.Length; i++)
ref short sourceRef = ref Unsafe.As<Half, short>(ref MemoryMarshal.GetReference(source));
ref float destinationRef = ref MemoryMarshal.GetReference(destination);
int i = 0, oneVectorFromEnd;

#if NET8_0_OR_GREATER
if (Vector512.IsHardwareAccelerated)
{
oneVectorFromEnd = source.Length - Vector512<short>.Count;
if (i <= oneVectorFromEnd)
{
// Loop handling one input vector / two output vectors at a time.
do
{
(Vector512<int> lower, Vector512<int> upper) = Vector512.Widen(Vector512.LoadUnsafe(ref sourceRef, (uint)i));
HalfAsWidenedUInt32ToSingle_Vector512(lower.AsUInt32()).StoreUnsafe(ref destinationRef, (uint)i);
HalfAsWidenedUInt32ToSingle_Vector512(upper.AsUInt32()).StoreUnsafe(ref destinationRef, (uint)(i + Vector512<float>.Count));

i += Vector512<short>.Count;
}
while (i <= oneVectorFromEnd);

// Handle any remaining elements with a final input vector.
if (i != source.Length)
{
i = source.Length - Vector512<short>.Count;

(Vector512<int> lower, Vector512<int> upper) = Vector512.Widen(Vector512.LoadUnsafe(ref sourceRef, (uint)i));
HalfAsWidenedUInt32ToSingle_Vector512(lower.AsUInt32()).StoreUnsafe(ref destinationRef, (uint)i);
HalfAsWidenedUInt32ToSingle_Vector512(upper.AsUInt32()).StoreUnsafe(ref destinationRef, (uint)(i + Vector512<float>.Count));
}

return;
}
}
#endif

if (Vector256.IsHardwareAccelerated)
{
oneVectorFromEnd = source.Length - Vector256<short>.Count;
if (i <= oneVectorFromEnd)
{
// Loop handling one input vector / two output vectors at a time.
do
{
(Vector256<int> lower, Vector256<int> upper) = Vector256.Widen(Vector256.LoadUnsafe(ref sourceRef, (uint)i));
HalfAsWidenedUInt32ToSingle_Vector256(lower.AsUInt32()).StoreUnsafe(ref destinationRef, (uint)i);
HalfAsWidenedUInt32ToSingle_Vector256(upper.AsUInt32()).StoreUnsafe(ref destinationRef, (uint)(i + Vector256<float>.Count));

i += Vector256<short>.Count;
}
while (i <= oneVectorFromEnd);

// Handle any remaining elements with a final input vector.
if (i != source.Length)
{
i = source.Length - Vector256<short>.Count;

(Vector256<int> lower, Vector256<int> upper) = Vector256.Widen(Vector256.LoadUnsafe(ref sourceRef, (uint)i));
HalfAsWidenedUInt32ToSingle_Vector256(lower.AsUInt32()).StoreUnsafe(ref destinationRef, (uint)i);
HalfAsWidenedUInt32ToSingle_Vector256(upper.AsUInt32()).StoreUnsafe(ref destinationRef, (uint)(i + Vector256<float>.Count));
}

return;
}
}

if (Vector128.IsHardwareAccelerated)
{
destination[i] = (float)source[i];
oneVectorFromEnd = source.Length - Vector128<short>.Count;
if (i <= oneVectorFromEnd)
{
// Loop handling one input vector / two output vectors at a time.
do
{
(Vector128<int> lower, Vector128<int> upper) = Vector128.Widen(Vector128.LoadUnsafe(ref sourceRef, (uint)i));
HalfAsWidenedUInt32ToSingle_Vector128(lower.AsUInt32()).StoreUnsafe(ref destinationRef, (uint)i);
HalfAsWidenedUInt32ToSingle_Vector128(upper.AsUInt32()).StoreUnsafe(ref destinationRef, (uint)(i + Vector128<float>.Count));

i += Vector128<short>.Count;
}
while (i <= oneVectorFromEnd);

// Handle any remaining elements with a final input vector.
if (i != source.Length)
{
i = source.Length - Vector128<short>.Count;

(Vector128<int> lower, Vector128<int> upper) = Vector128.Widen(Vector128.LoadUnsafe(ref sourceRef, (uint)i));
HalfAsWidenedUInt32ToSingle_Vector128(lower.AsUInt32()).StoreUnsafe(ref destinationRef, (uint)i);
HalfAsWidenedUInt32ToSingle_Vector128(upper.AsUInt32()).StoreUnsafe(ref destinationRef, (uint)(i + Vector128<float>.Count));
}

return;
}
}

while (i < source.Length)
{
Unsafe.Add(ref destinationRef, i) = (float)Unsafe.As<short, Half>(ref Unsafe.Add(ref sourceRef, i));
i++;
}

// This implements a vectorized version of the `explicit operator float(Half value) operator`.
// See detailed description of the algorithm used here:
// https://github.com/dotnet/runtime/blob/3bf40a378f00cb5bf18ff62796bc7097719b974c/src/libraries/System.Private.CoreLib/src/System/Half.cs#L1010-L1040
// The cast operator converts a Half represented as uint to a float. This does the same, with an input VectorXx<uint> and an output VectorXx<float>.
// The VectorXx<uint> is created by reading a vector of Halfs as a VectorXx<short> then widened to two VectorXx<int>s and cast to VectorXx<uint>s.
// We loop handling one input vector at a time, producing two output float vectors.

#pragma warning disable IDE0059 // https://github.com/dotnet/roslyn/issues/44948
const uint ExponentLowerBound = 0x3880_0000u; // The smallest positive normal number in Half, converted to Single
const uint ExponentOffset = 0x3800_0000u; // BitConverter.SingleToUInt32Bits(1.0f) - ((uint)BitConverter.HalfToUInt16Bits((Half)1.0f) << 13)
const uint SingleSignMask = 0x8000_0000; // float.SignMask; // Mask for sign bit in Single
const uint HalfExponentMask = 0x7C00; // Mask for exponent bits in Half
const uint HalfToSingleBitsMask = 0x0FFF_E000; // Mask for bits in Single converted from Half
#pragma warning restore IDE0059

static Vector128<float> HalfAsWidenedUInt32ToSingle_Vector128(Vector128<uint> value)
{
// Extract sign bit of value
Vector128<uint> sign = value & Vector128.Create(SingleSignMask);

// Copy sign bit to upper bits
Vector128<uint> bitValueInProcess = value;

// Extract exponent bits of value (BiasedExponent is not for here as it performs unnecessary shift)
Vector128<uint> offsetExponent = bitValueInProcess & Vector128.Create(HalfExponentMask);

// ~0u when value is subnormal, 0 otherwise
Vector128<uint> subnormalMask = Vector128.Equals(offsetExponent, Vector128<uint>.Zero);

// ~0u when value is either Infinity or NaN, 0 otherwise
Vector128<uint> infinityOrNaNMask = Vector128.Equals(offsetExponent, Vector128.Create(HalfExponentMask));

// 0x3880_0000u if value is subnormal, 0 otherwise
Vector128<uint> maskedExponentLowerBound = subnormalMask & Vector128.Create(ExponentLowerBound);

// 0x3880_0000u if value is subnormal, 0x3800_0000u otherwise
Vector128<uint> offsetMaskedExponentLowerBound = Vector128.Create(ExponentOffset) | maskedExponentLowerBound;

// Match the position of the boundary of exponent bits and fraction bits with IEEE 754 Binary32(Single)
bitValueInProcess = Vector128.ShiftLeft(bitValueInProcess, 13);

// Double the offsetMaskedExponentLowerBound if value is either Infinity or NaN
offsetMaskedExponentLowerBound = Vector128.ConditionalSelect(Vector128.Equals(infinityOrNaNMask, Vector128<uint>.Zero),
offsetMaskedExponentLowerBound,
Vector128.ShiftLeft(offsetMaskedExponentLowerBound, 1));

// Extract exponent bits and fraction bits of value
bitValueInProcess &= Vector128.Create(HalfToSingleBitsMask);

// Adjust exponent to match the range of exponent
bitValueInProcess += offsetMaskedExponentLowerBound;

// If value is subnormal, remove unnecessary 1 on top of fraction bits.
Vector128<uint> absoluteValue = (bitValueInProcess.AsSingle() - maskedExponentLowerBound.AsSingle()).AsUInt32();

// Merge sign bit with rest
return (absoluteValue | sign).AsSingle();
}

static Vector256<float> HalfAsWidenedUInt32ToSingle_Vector256(Vector256<uint> value)
{
// Extract sign bit of value
Vector256<uint> sign = value & Vector256.Create(SingleSignMask);

// Copy sign bit to upper bits
Vector256<uint> bitValueInProcess = value;

// Extract exponent bits of value (BiasedExponent is not for here as it performs unnecessary shift)
Vector256<uint> offsetExponent = bitValueInProcess & Vector256.Create(HalfExponentMask);

// ~0u when value is subnormal, 0 otherwise
Vector256<uint> subnormalMask = Vector256.Equals(offsetExponent, Vector256<uint>.Zero);

// ~0u when value is either Infinity or NaN, 0 otherwise
Vector256<uint> infinityOrNaNMask = Vector256.Equals(offsetExponent, Vector256.Create(HalfExponentMask));

// 0x3880_0000u if value is subnormal, 0 otherwise
Vector256<uint> maskedExponentLowerBound = subnormalMask & Vector256.Create(ExponentLowerBound);

// 0x3880_0000u if value is subnormal, 0x3800_0000u otherwise
Vector256<uint> offsetMaskedExponentLowerBound = Vector256.Create(ExponentOffset) | maskedExponentLowerBound;

// Match the position of the boundary of exponent bits and fraction bits with IEEE 754 Binary32(Single)
bitValueInProcess = Vector256.ShiftLeft(bitValueInProcess, 13);

// Double the offsetMaskedExponentLowerBound if value is either Infinity or NaN
offsetMaskedExponentLowerBound = Vector256.ConditionalSelect(Vector256.Equals(infinityOrNaNMask, Vector256<uint>.Zero),
offsetMaskedExponentLowerBound,
Vector256.ShiftLeft(offsetMaskedExponentLowerBound, 1));

// Extract exponent bits and fraction bits of value
bitValueInProcess &= Vector256.Create(HalfToSingleBitsMask);

// Adjust exponent to match the range of exponent
bitValueInProcess += offsetMaskedExponentLowerBound;

// If value is subnormal, remove unnecessary 1 on top of fraction bits.
Vector256<uint> absoluteValue = (bitValueInProcess.AsSingle() - maskedExponentLowerBound.AsSingle()).AsUInt32();

// Merge sign bit with rest
return (absoluteValue | sign).AsSingle();
}

#if NET8_0_OR_GREATER
static Vector512<float> HalfAsWidenedUInt32ToSingle_Vector512(Vector512<uint> value)
{
// Extract sign bit of value
Vector512<uint> sign = value & Vector512.Create(SingleSignMask);

// Copy sign bit to upper bits
Vector512<uint> bitValueInProcess = value;

// Extract exponent bits of value (BiasedExponent is not for here as it performs unnecessary shift)
Vector512<uint> offsetExponent = bitValueInProcess & Vector512.Create(HalfExponentMask);

// ~0u when value is subnormal, 0 otherwise
Vector512<uint> subnormalMask = Vector512.Equals(offsetExponent, Vector512<uint>.Zero);

// ~0u when value is either Infinity or NaN, 0 otherwise
Vector512<uint> infinityOrNaNMask = Vector512.Equals(offsetExponent, Vector512.Create(HalfExponentMask));

// 0x3880_0000u if value is subnormal, 0 otherwise
Vector512<uint> maskedExponentLowerBound = subnormalMask & Vector512.Create(ExponentLowerBound);

// 0x3880_0000u if value is subnormal, 0x3800_0000u otherwise
Vector512<uint> offsetMaskedExponentLowerBound = Vector512.Create(ExponentOffset) | maskedExponentLowerBound;

// Match the position of the boundary of exponent bits and fraction bits with IEEE 754 Binary32(Single)
bitValueInProcess = Vector512.ShiftLeft(bitValueInProcess, 13);

// Double the offsetMaskedExponentLowerBound if value is either Infinity or NaN
offsetMaskedExponentLowerBound = Vector512.ConditionalSelect(Vector512.Equals(infinityOrNaNMask, Vector512<uint>.Zero),
offsetMaskedExponentLowerBound,
Vector512.ShiftLeft(offsetMaskedExponentLowerBound, 1));

// Extract exponent bits and fraction bits of value
bitValueInProcess &= Vector512.Create(HalfToSingleBitsMask);

// Adjust exponent to match the range of exponent
bitValueInProcess += offsetMaskedExponentLowerBound;

// If value is subnormal, remove unnecessary 1 on top of fraction bits.
Vector512<uint> absoluteValue = (bitValueInProcess.AsSingle() - maskedExponentLowerBound.AsSingle()).AsUInt32();

// Merge sign bit with rest
return (absoluteValue | sign).AsSingle();
}
#endif
}

private static float CosineSimilarityCore(ReadOnlySpan<float> x, ReadOnlySpan<float> y)
Expand Down
4 changes: 2 additions & 2 deletions src/libraries/System.Private.CoreLib/src/System/Half.cs
Original file line number Diff line number Diff line change
Expand Up @@ -1044,15 +1044,15 @@ public static explicit operator float(Half value)
// BitConverter.SingleToUInt32Bits(1.0f) - ((uint)BitConverter.HalfToUInt16Bits((Half)1.0f) << 13)
const uint ExponentOffset = 0x3800_0000u;
// Mask for sign bit in Single
const uint FloatSignMask = float.SignMask;
const uint SingleSignMask = float.SignMask;
// Mask for exponent bits in Half
const uint HalfExponentMask = BiasedExponentMask;
// Mask for bits in Single converted from Half
const int HalfToSingleBitsMask = 0x0FFF_E000;
// Extract the internal representation of value
short valueInInt16Bits = BitConverter.HalfToInt16Bits(value);
// Extract sign bit of value
uint sign = (uint)(int)valueInInt16Bits & FloatSignMask;
uint sign = (uint)(int)valueInInt16Bits & SingleSignMask;
// Copy sign bit to upper bits
uint bitValueInProcess = (uint)valueInInt16Bits;
// Extract exponent bits of value (BiasedExponent is not for here as it performs unnecessary shift)
Expand Down

0 comments on commit fdff01f

Please sign in to comment.