Expose the FusedMultiplyAdd and MultiplyAddEstimate APIs on relevant vector and scalar types #102181

Merged: 10 commits, May 16, 2024
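For context: FusedMultiplyAdd computes (left * right) + addend with a single rounding step, while MultiplyAddEstimate is allowed to produce either the fused or the separately rounded result depending on hardware (which is why the importer below guards it with BlockNonDeterministicIntrinsics). Below is a minimal standalone C++ sketch of the numerical difference, using std::fmaf; the program and values are illustrative only, not part of this PR.

    #include <cmath>
    #include <cstdio>

    int main()
    {
        // Chosen so the exact product 1 - 2^-46 is not representable as a
        // float: it rounds to 1.0f when the multiply is rounded separately.
        float left   = 1.0f + 0x1.0p-23f;
        float right  = 1.0f - 0x1.0p-23f;
        float addend = -1.0f;

        float fused = std::fmaf(left, right, addend); // one rounding: exactly -2^-46

        volatile float product = left * right;        // force a separate rounding
        float          unfused = product + addend;    // second rounding: 0.0f

        std::printf("fused:   %g\n", fused);   // ~ -1.42109e-14
        std::printf("unfused: %g\n", unfused); // 0
        return 0;
    }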
7 changes: 7 additions & 0 deletions src/coreclr/jit/compiler.h
@@ -3225,6 +3225,13 @@ class Compiler
    GenTree* gtNewSimdFloorNode(
        var_types type, GenTree* op1, CorInfoType simdBaseJitType, unsigned simdSize);

    GenTree* gtNewSimdFmaNode(var_types type,
                              GenTree* op1,
                              GenTree* op2,
                              GenTree* op3,
                              CorInfoType simdBaseJitType,
                              unsigned simdSize);

    GenTree* gtNewSimdGetElementNode(var_types type,
                                     GenTree* op1,
                                     GenTree* op2,
55 changes: 55 additions & 0 deletions src/coreclr/jit/gentree.cpp
@@ -23391,6 +23391,61 @@ GenTree* Compiler::gtNewSimdFloorNode(var_types type, GenTree* op1, CorInfoType
    return gtNewSimdHWIntrinsicNode(type, op1, intrinsic, simdBaseJitType, simdSize);
}

GenTree* Compiler::gtNewSimdFmaNode(
    var_types type, GenTree* op1, GenTree* op2, GenTree* op3, CorInfoType simdBaseJitType, unsigned simdSize)
{
    assert(varTypeIsSIMD(type));
    assert(getSIMDTypeForSize(simdSize) == type);

    assert(op1 != nullptr);
    assert(op1->TypeIs(type));

    assert(op2 != nullptr);
    assert(op2->TypeIs(type));

    assert(op3 != nullptr);
    assert(op3->TypeIs(type));

    var_types simdBaseType = JitType2PreciseVarType(simdBaseJitType);
    assert(varTypeIsFloating(simdBaseType));

    NamedIntrinsic intrinsic = NI_Illegal;

#if defined(TARGET_XARCH)
    if (simdSize == 64)
    {
        assert(compIsaSupportedDebugOnly(InstructionSet_AVX512F));
        intrinsic = NI_AVX512F_FusedMultiplyAdd;
    }
    else
    {
        assert(compIsaSupportedDebugOnly(InstructionSet_FMA));
        intrinsic = NI_FMA_MultiplyAdd;
    }
#elif defined(TARGET_ARM64)
    assert(IsBaselineSimdIsaSupportedDebugOnly());

    if (simdBaseType == TYP_DOUBLE)
    {
        intrinsic = (simdSize == 8) ? NI_AdvSimd_FusedMultiplyAddScalar : NI_AdvSimd_Arm64_FusedMultiplyAdd;
    }
    else
    {
        intrinsic = NI_AdvSimd_FusedMultiplyAdd;
    }

    // AdvSimd.FusedMultiplyAdd expects (addend, left, right), while the APIs take (left, right, addend).
    // We expect op1 and op2 to have already been spilled.

    std::swap(op1, op3);
Member:
I am just curious - who's responsible to spill them?

Member:
Ah, nvm, I see.

Member Author:
Typically it's the importer code, prior to calling this method. Later phases that might call such an API (like morph) are responsible for ensuring the swap is safe in the face of potential side effects.

There's, unfortunately, not really a way for us to do such a swap safely in the API itself (at least that I know of), nor to know whether the caller has actually done it. So the typical approach has been to do the swap and comment that callers should be doing the validation (see the sketch after this function).

#else
#error Unsupported platform
#endif // !TARGET_XARCH && !TARGET_ARM64

    assert(intrinsic != NI_Illegal);
    return gtNewSimdHWIntrinsicNode(type, op1, op2, op3, intrinsic, simdBaseJitType, simdSize);
}
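To make the thread above concrete, here is a minimal standalone sketch (hypothetical example code, not JIT code) of why the swap is only safe once operand side effects have already been evaluated in their original order:

    #include <cstdio>
    #include <utility>

    static float eval(const char* name, float value)
    {
        std::printf("evaluating %s\n", name); // observable side effect
        return value;
    }

    int main()
    {
        // The managed API takes (left, right, addend)...
        float op1 = eval("op1 (left)", 2.0f);
        float op2 = eval("op2 (right)", 3.0f);
        float op3 = eval("op3 (addend)", 4.0f);

        // ...but AdvSimd.FusedMultiplyAdd wants (addend, left, right). The swap
        // is safe here only because op1..op3 were already fully evaluated above;
        // swapping unevaluated expression trees would reorder their side effects.
        std::swap(op1, op3);

        std::printf("%g\n", op1 + op3 * op2); // addend + left * right = 10
        return 0;
    }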

GenTree* Compiler::gtNewSimdGetElementNode(
    var_types type, GenTree* op1, GenTree* op2, CorInfoType simdBaseJitType, unsigned simdSize)
{
45 changes: 45 additions & 0 deletions src/coreclr/jit/hwintrinsicarm64.cpp
@@ -1347,6 +1347,26 @@ GenTree* Compiler::impSpecialIntrinsic(NamedIntrinsic intrinsic,
            break;
        }

        case NI_Vector64_FusedMultiplyAdd:
        case NI_Vector128_FusedMultiplyAdd:
        {
            assert(sig->numArgs == 3);
            assert(varTypeIsFloating(simdBaseType));

            impSpillSideEffect(true, verCurrentState.esStackDepth -
                                         3 DEBUGARG("Spilling op1 side effects for FusedMultiplyAdd"));

            impSpillSideEffect(true, verCurrentState.esStackDepth -
                                         2 DEBUGARG("Spilling op2 side effects for FusedMultiplyAdd"));

            op3 = impSIMDPopStack();
            op2 = impSIMDPopStack();
            op1 = impSIMDPopStack();

            retNode = gtNewSimdFmaNode(retType, op1, op2, op3, simdBaseJitType, simdSize);
            break;
        }

        case NI_Vector64_get_AllBitsSet:
        case NI_Vector128_get_AllBitsSet:
        {
@@ -1691,6 +1711,31 @@ GenTree* Compiler::impSpecialIntrinsic(NamedIntrinsic intrinsic,
            break;
        }

        case NI_Vector64_MultiplyAddEstimate:
        case NI_Vector128_MultiplyAddEstimate:
        {
            assert(sig->numArgs == 3);
            assert(varTypeIsFloating(simdBaseType));

            if (BlockNonDeterministicIntrinsics(mustExpand))
            {
                break;
            }

            impSpillSideEffect(true, verCurrentState.esStackDepth -
                                         3 DEBUGARG("Spilling op1 side effects for MultiplyAddEstimate"));

            impSpillSideEffect(true, verCurrentState.esStackDepth -
                                         2 DEBUGARG("Spilling op2 side effects for MultiplyAddEstimate"));

            op3 = impSIMDPopStack();
            op2 = impSIMDPopStack();
            op1 = impSIMDPopStack();

            retNode = gtNewSimdFmaNode(retType, op1, op2, op3, simdBaseJitType, simdSize);
            break;
        }

Member:
Isn't it better/simpler to use impSpillSideEffects(true, CHECK_SPILL_ALL DEBUGARG("spilling side-effects"))?

Member Author:
It's not something we've been doing in other scenarios.

AFAIR it comes down to not spilling values on the stack that aren't impacted. For example, we take 3 args, but the stack could have 4+ entries on it (spilling those extras is unnecessary since they are still processed in order with respect to our own op1/op2/op3; otherwise everyone who pops the stack would need to consider spilling such additional entries), and we never need to spill the stack top (op3).

So we're doing this to ensure only the minimum number of items that need to be spilled are spilled. The sketch below models the indexing.
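A small standalone model (illustrative only; the names and program are assumptions, not JIT code) of the stack-depth arithmetic described in the reply: with the three intrinsic arguments on top of the evaluation stack, only the two non-top slots get spilled, and deeper entries are untouched.

    #include <cstdio>
    #include <vector>

    int main()
    {
        // Evaluation stack: top of stack is the last element. There can be more
        // entries than the intrinsic's 3 args ("extra" here stays untouched).
        std::vector<const char*> stack = {"extra", "op1", "op2", "op3"};
        const auto esStackDepth = stack.size(); // 4

        // Mirrors the importer: depth - 3 is op1, depth - 2 is op2. The stack
        // top (depth - 1, op3) is popped first and never needs spilling.
        std::printf("spill: %s\n", stack[esStackDepth - 3]); // op1
        std::printf("spill: %s\n", stack[esStackDepth - 2]); // op2
        return 0;
    }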

        case NI_Vector64_Narrow:
        case NI_Vector128_Narrow:
        {
4 changes: 4 additions & 0 deletions src/coreclr/jit/hwintrinsiclistarm64.h
@@ -56,6 +56,7 @@ HARDWARE_INTRINSIC(Vector64, EqualsAll,
HARDWARE_INTRINSIC(Vector64, EqualsAny, 8, 2, false, {INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid}, HW_Category_Helper, HW_Flag_InvalidNodeId|HW_Flag_BaseTypeFromFirstArg)
HARDWARE_INTRINSIC(Vector64, ExtractMostSignificantBits, 8, 1, false, {INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid}, HW_Category_Helper, HW_Flag_InvalidNodeId|HW_Flag_BaseTypeFromFirstArg)
HARDWARE_INTRINSIC(Vector64, Floor, 8, 1, false, {INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid}, HW_Category_Helper, HW_Flag_InvalidNodeId)
HARDWARE_INTRINSIC(Vector64, FusedMultiplyAdd, 8, 3, false, {INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid}, HW_Category_Helper, HW_Flag_InvalidNodeId)
HARDWARE_INTRINSIC(Vector64, get_AllBitsSet, 8, 0, false, {INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid}, HW_Category_Helper, HW_Flag_InvalidNodeId)
HARDWARE_INTRINSIC(Vector64, get_Indices, 8, 0, false, {INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid}, HW_Category_Helper, HW_Flag_InvalidNodeId)
HARDWARE_INTRINSIC(Vector64, get_One, 8, 0, false, {INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid}, HW_Category_Helper, HW_Flag_InvalidNodeId)
@@ -80,6 +81,7 @@ HARDWARE_INTRINSIC(Vector64, LoadUnsafe,
HARDWARE_INTRINSIC(Vector64, Max, 8, 2, false, {INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid}, HW_Category_Helper, HW_Flag_InvalidNodeId)
HARDWARE_INTRINSIC(Vector64, Min, 8, 2, false, {INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid}, HW_Category_Helper, HW_Flag_InvalidNodeId)
HARDWARE_INTRINSIC(Vector64, Multiply, 8, 2, false, {INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid}, HW_Category_Helper, HW_Flag_InvalidNodeId)
HARDWARE_INTRINSIC(Vector64, MultiplyAddEstimate, 8, 3, false, {INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid}, HW_Category_Helper, HW_Flag_InvalidNodeId)
HARDWARE_INTRINSIC(Vector64, Narrow, 8, 2, false, {INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid}, HW_Category_Helper, HW_Flag_InvalidNodeId)
HARDWARE_INTRINSIC(Vector64, Negate, 8, 1, false, {INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid}, HW_Category_Helper, HW_Flag_InvalidNodeId)
HARDWARE_INTRINSIC(Vector64, OnesComplement, 8, 1, false, {INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid}, HW_Category_Helper, HW_Flag_InvalidNodeId)
@@ -169,6 +171,7 @@ HARDWARE_INTRINSIC(Vector128, EqualsAll,
HARDWARE_INTRINSIC(Vector128, EqualsAny, 16, 2, false, {INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid}, HW_Category_Helper, HW_Flag_InvalidNodeId|HW_Flag_BaseTypeFromFirstArg)
HARDWARE_INTRINSIC(Vector128, ExtractMostSignificantBits, 16, 1, false, {INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid}, HW_Category_Helper, HW_Flag_InvalidNodeId|HW_Flag_BaseTypeFromFirstArg)
HARDWARE_INTRINSIC(Vector128, Floor, 16, 1, false, {INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid}, HW_Category_Helper, HW_Flag_InvalidNodeId)
HARDWARE_INTRINSIC(Vector128, FusedMultiplyAdd, 16, 3, false, {INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid}, HW_Category_Helper, HW_Flag_InvalidNodeId)
HARDWARE_INTRINSIC(Vector128, get_AllBitsSet, 16, 0, false, {INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid}, HW_Category_Helper, HW_Flag_InvalidNodeId)
HARDWARE_INTRINSIC(Vector128, get_Indices, 16, 0, false, {INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid}, HW_Category_Helper, HW_Flag_InvalidNodeId)
HARDWARE_INTRINSIC(Vector128, get_One, 16, 0, false, {INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid}, HW_Category_Helper, HW_Flag_InvalidNodeId)
@@ -195,6 +198,7 @@ HARDWARE_INTRINSIC(Vector128, LoadUnsafe,
HARDWARE_INTRINSIC(Vector128, Max, 16, 2, false, {INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid}, HW_Category_Helper, HW_Flag_InvalidNodeId)
HARDWARE_INTRINSIC(Vector128, Min, 16, 2, false, {INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid}, HW_Category_Helper, HW_Flag_InvalidNodeId)
HARDWARE_INTRINSIC(Vector128, Multiply, 16, 2, false, {INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid}, HW_Category_Helper, HW_Flag_InvalidNodeId)
HARDWARE_INTRINSIC(Vector128, MultiplyAddEstimate, 16, 3, false, {INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid}, HW_Category_Helper, HW_Flag_InvalidNodeId)
HARDWARE_INTRINSIC(Vector128, Narrow, 16, 2, false, {INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid}, HW_Category_Helper, HW_Flag_InvalidNodeId)
HARDWARE_INTRINSIC(Vector128, Negate, 16, 1, false, {INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid}, HW_Category_Helper, HW_Flag_InvalidNodeId)
HARDWARE_INTRINSIC(Vector128, OnesComplement, 16, 1, false, {INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid, INS_invalid}, HW_Category_Helper, HW_Flag_InvalidNodeId)