Skip to content

Commit

Permalink
Expose the FusedMultiplyAdd and MultiplyAddEstimate APIs on relevant …
Browse files Browse the repository at this point in the history
…vector and scalar types (dotnet#102181)

* Expose FusedMultiplyAdd and MultiplyAddEstimate on the scalar and vector types

* Adding tests covering FusedMultiplyAdd and MultiplyAddEstimate for the vector types

* Ensure TensorPrimitives uses the xplat APIs on .NET 9+

* Apply formatting patch

* Fix an accidental change to GenericVectorTests

* Ensure Arm64 passes fma operands in the correct order

* Apply formatting patch

* Ensure all the Arm64 code paths are spilling and swapping operands

* Apply formatting patch

* Don't pop the stack value twice
  • Loading branch information
tannergooding authored and Ruihan-Yin committed May 30, 2024
1 parent fa3a3cf commit 19ab1d6
Show file tree
Hide file tree
Showing 58 changed files with 1,424 additions and 32 deletions.
7 changes: 7 additions & 0 deletions src/coreclr/jit/compiler.h
Original file line number Diff line number Diff line change
Expand Up @@ -3226,6 +3226,13 @@ class Compiler
GenTree* gtNewSimdFloorNode(
var_types type, GenTree* op1, CorInfoType simdBaseJitType, unsigned simdSize);

GenTree* gtNewSimdFmaNode(var_types type,
GenTree* op1,
GenTree* op2,
GenTree* op3,
CorInfoType simdBaseJitType,
unsigned simdSize);

GenTree* gtNewSimdGetElementNode(var_types type,
GenTree* op1,
GenTree* op2,
Expand Down
55 changes: 55 additions & 0 deletions src/coreclr/jit/gentree.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23390,6 +23390,61 @@ GenTree* Compiler::gtNewSimdFloorNode(var_types type, GenTree* op1, CorInfoType
return gtNewSimdHWIntrinsicNode(type, op1, intrinsic, simdBaseJitType, simdSize);
}

GenTree* Compiler::gtNewSimdFmaNode(
var_types type, GenTree* op1, GenTree* op2, GenTree* op3, CorInfoType simdBaseJitType, unsigned simdSize)
{
assert(varTypeIsSIMD(type));
assert(getSIMDTypeForSize(simdSize) == type);

assert(op1 != nullptr);
assert(op1->TypeIs(type));

assert(op2 != nullptr);
assert(op2->TypeIs(type));

assert(op3 != nullptr);
assert(op3->TypeIs(type));

var_types simdBaseType = JitType2PreciseVarType(simdBaseJitType);
assert(varTypeIsFloating(simdBaseType));

NamedIntrinsic intrinsic = NI_Illegal;

#if defined(TARGET_XARCH)
if (simdSize == 64)
{
assert(compIsaSupportedDebugOnly(InstructionSet_AVX512F));
intrinsic = NI_AVX512F_FusedMultiplyAdd;
}
else
{
assert(compIsaSupportedDebugOnly(InstructionSet_FMA));
intrinsic = NI_FMA_MultiplyAdd;
}
#elif defined(TARGET_ARM64)
assert(IsBaselineSimdIsaSupportedDebugOnly());

if (simdBaseType == TYP_DOUBLE)
{
intrinsic = (simdSize == 8) ? NI_AdvSimd_FusedMultiplyAddScalar : NI_AdvSimd_Arm64_FusedMultiplyAdd;
}
else
{
intrinsic = NI_AdvSimd_FusedMultiplyAdd;
}

// AdvSimd.FusedMultiplyAdd expects (addend, left, right), while the APIs take (left, right, addend)
// We expect op1 and op2 to have already been spilled

std::swap(op1, op3);
#else
#error Unsupported platform
#endif // !TARGET_XARCH && !TARGET_ARM64

assert(intrinsic != NI_Illegal);
return gtNewSimdHWIntrinsicNode(type, op1, op2, op3, intrinsic, simdBaseJitType, simdSize);
}

GenTree* Compiler::gtNewSimdGetElementNode(
var_types type, GenTree* op1, GenTree* op2, CorInfoType simdBaseJitType, unsigned simdSize)
{
Expand Down
45 changes: 45 additions & 0 deletions src/coreclr/jit/hwintrinsicarm64.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1358,6 +1358,26 @@ GenTree* Compiler::impSpecialIntrinsic(NamedIntrinsic intrinsic,
break;
}

case NI_Vector64_FusedMultiplyAdd:
case NI_Vector128_FusedMultiplyAdd:
{
assert(sig->numArgs == 3);
assert(varTypeIsFloating(simdBaseType));

impSpillSideEffect(true, verCurrentState.esStackDepth -
3 DEBUGARG("Spilling op1 side effects for FusedMultiplyAdd"));

impSpillSideEffect(true, verCurrentState.esStackDepth -
2 DEBUGARG("Spilling op2 side effects for FusedMultiplyAdd"));

op3 = impSIMDPopStack();
op2 = impSIMDPopStack();
op1 = impSIMDPopStack();

retNode = gtNewSimdFmaNode(retType, op1, op2, op3, simdBaseJitType, simdSize);
break;
}

case NI_Vector64_get_AllBitsSet:
case NI_Vector128_get_AllBitsSet:
{
Expand Down Expand Up @@ -1702,6 +1722,31 @@ GenTree* Compiler::impSpecialIntrinsic(NamedIntrinsic intrinsic,
break;
}

case NI_Vector64_MultiplyAddEstimate:
case NI_Vector128_MultiplyAddEstimate:
{
assert(sig->numArgs == 3);
assert(varTypeIsFloating(simdBaseType));

if (BlockNonDeterministicIntrinsics(mustExpand))
{
break;
}

impSpillSideEffect(true, verCurrentState.esStackDepth -
3 DEBUGARG("Spilling op1 side effects for MultiplyAddEstimate"));

impSpillSideEffect(true, verCurrentState.esStackDepth -
2 DEBUGARG("Spilling op2 side effects for MultiplyAddEstimate"));

op3 = impSIMDPopStack();
op2 = impSIMDPopStack();
op1 = impSIMDPopStack();

retNode = gtNewSimdFmaNode(retType, op1, op2, op3, simdBaseJitType, simdSize);
break;
}

case NI_Vector64_Narrow:
case NI_Vector128_Narrow:
{
Expand Down
4 changes: 4 additions & 0 deletions src/coreclr/jit/hwintrinsiclistarm64.h
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ HARDWARE_INTRINSIC(Vector64, EqualsAll,
HARDWARE_INTRINSIC(Vector64, EqualsAny, 8, 2, false, {INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid}, HW_Category_Helper, HW_Flag_InvalidNodeId|HW_Flag_BaseTypeFromFirstArg)
HARDWARE_INTRINSIC(Vector64, ExtractMostSignificantBits, 8, 1, false, {INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid}, HW_Category_Helper, HW_Flag_InvalidNodeId|HW_Flag_BaseTypeFromFirstArg)
HARDWARE_INTRINSIC(Vector64, Floor, 8, 1, false, {INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid}, HW_Category_Helper, HW_Flag_InvalidNodeId)
HARDWARE_INTRINSIC(Vector64, FusedMultiplyAdd, 8, 3, false, {INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid}, HW_Category_Helper, HW_Flag_InvalidNodeId)
HARDWARE_INTRINSIC(Vector64, get_AllBitsSet, 8, 0, false, {INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid}, HW_Category_Helper, HW_Flag_InvalidNodeId)
HARDWARE_INTRINSIC(Vector64, get_Indices, 8, 0, false, {INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid}, HW_Category_Helper, HW_Flag_InvalidNodeId)
HARDWARE_INTRINSIC(Vector64, get_One, 8, 0, false, {INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid}, HW_Category_Helper, HW_Flag_InvalidNodeId)
Expand All @@ -80,6 +81,7 @@ HARDWARE_INTRINSIC(Vector64, LoadUnsafe,
HARDWARE_INTRINSIC(Vector64, Max, 8, 2, false, {INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid}, HW_Category_Helper, HW_Flag_InvalidNodeId)
HARDWARE_INTRINSIC(Vector64, Min, 8, 2, false, {INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid}, HW_Category_Helper, HW_Flag_InvalidNodeId)
HARDWARE_INTRINSIC(Vector64, Multiply, 8, 2, false, {INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid}, HW_Category_Helper, HW_Flag_InvalidNodeId)
HARDWARE_INTRINSIC(Vector64, MultiplyAddEstimate, 8, 3, false, {INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid}, HW_Category_Helper, HW_Flag_InvalidNodeId)
HARDWARE_INTRINSIC(Vector64, Narrow, 8, 2, false, {INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid}, HW_Category_Helper, HW_Flag_InvalidNodeId)
HARDWARE_INTRINSIC(Vector64, Negate, 8, 1, false, {INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid}, HW_Category_Helper, HW_Flag_InvalidNodeId)
HARDWARE_INTRINSIC(Vector64, OnesComplement, 8, 1, false, {INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid}, HW_Category_Helper, HW_Flag_InvalidNodeId)
Expand Down Expand Up @@ -169,6 +171,7 @@ HARDWARE_INTRINSIC(Vector128, EqualsAll,
HARDWARE_INTRINSIC(Vector128, EqualsAny, 16, 2, false, {INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid}, HW_Category_Helper, HW_Flag_InvalidNodeId|HW_Flag_BaseTypeFromFirstArg)
HARDWARE_INTRINSIC(Vector128, ExtractMostSignificantBits, 16, 1, false, {INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid}, HW_Category_Helper, HW_Flag_InvalidNodeId|HW_Flag_BaseTypeFromFirstArg)
HARDWARE_INTRINSIC(Vector128, Floor, 16, 1, false, {INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid}, HW_Category_Helper, HW_Flag_InvalidNodeId)
HARDWARE_INTRINSIC(Vector128, FusedMultiplyAdd, 16, 3, false, {INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid}, HW_Category_Helper, HW_Flag_InvalidNodeId)
HARDWARE_INTRINSIC(Vector128, get_AllBitsSet, 16, 0, false, {INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid}, HW_Category_Helper, HW_Flag_InvalidNodeId)
HARDWARE_INTRINSIC(Vector128, get_Indices, 16, 0, false, {INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid}, HW_Category_Helper, HW_Flag_InvalidNodeId)
HARDWARE_INTRINSIC(Vector128, get_One, 16, 0, false, {INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid}, HW_Category_Helper, HW_Flag_InvalidNodeId)
Expand All @@ -195,6 +198,7 @@ HARDWARE_INTRINSIC(Vector128, LoadUnsafe,
HARDWARE_INTRINSIC(Vector128, Max, 16, 2, false, {INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid}, HW_Category_Helper, HW_Flag_InvalidNodeId)
HARDWARE_INTRINSIC(Vector128, Min, 16, 2, false, {INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid}, HW_Category_Helper, HW_Flag_InvalidNodeId)
HARDWARE_INTRINSIC(Vector128, Multiply, 16, 2, false, {INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid}, HW_Category_Helper, HW_Flag_InvalidNodeId)
HARDWARE_INTRINSIC(Vector128, MultiplyAddEstimate, 16, 3, false, {INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid}, HW_Category_Helper, HW_Flag_InvalidNodeId)
HARDWARE_INTRINSIC(Vector128, Narrow, 16, 2, false, {INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid}, HW_Category_Helper, HW_Flag_InvalidNodeId)
HARDWARE_INTRINSIC(Vector128, Negate, 16, 1, false, {INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid}, HW_Category_Helper, HW_Flag_InvalidNodeId)
HARDWARE_INTRINSIC(Vector128, OnesComplement, 16, 1, false, {INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid}, HW_Category_Helper, HW_Flag_InvalidNodeId)
Expand Down
Loading

0 comments on commit 19ab1d6

Please sign in to comment.