Skip to content

Commit 383f9af

Browse files
Ensure FMA optimizations kick in under embedded broadcast (#116891)
1 parent 5c95e6e commit 383f9af

File tree

3 files changed

+115
-79
lines changed

3 files changed

+115
-79
lines changed

src/coreclr/jit/gentree.cpp

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28308,6 +28308,22 @@ bool GenTree::OperIsVectorConditionalSelect() const
2830828308
return false;
2830928309
}
2831028310

28311+
//------------------------------------------------------------------------
28312+
// OperIsVectorFusedMultiplyOp: Is this a vector FusedMultiplyOp hwintrinsic
28313+
//
28314+
// Return Value:
28315+
// true if the node is a vector FusedMultiplyOp hwintrinsic
28316+
// otherwise; false
28317+
//
28318+
bool GenTree::OperIsVectorFusedMultiplyOp() const
28319+
{
28320+
if (OperIsHWIntrinsic())
28321+
{
28322+
return AsHWIntrinsic()->OperIsVectorFusedMultiplyOp();
28323+
}
28324+
return false;
28325+
}
28326+
2831128327
//------------------------------------------------------------------------
2831228328
// OperIsMemoryLoad: Does this HWI node have memory load semantics?
2831328329
//

src/coreclr/jit/gentree.h

Lines changed: 57 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1682,6 +1682,7 @@ struct GenTree
16821682
bool OperIsConvertMaskToVector() const;
16831683
bool OperIsConvertVectorToMask() const;
16841684
bool OperIsVectorConditionalSelect() const;
1685+
bool OperIsVectorFusedMultiplyOp() const;
16851686

16861687
// This is here for cleaner GT_LONG #ifdefs.
16871688
static bool OperIsLong(genTreeOps gtOper)
@@ -6468,14 +6469,63 @@ struct GenTreeHWIntrinsic : public GenTreeJitIntrinsic
64686469

64696470
bool OperIsVectorConditionalSelect() const
64706471
{
6472+
switch (GetHWIntrinsicId())
6473+
{
64716474
#if defined(TARGET_XARCH)
6472-
return OperIsHWIntrinsic(NI_Vector128_ConditionalSelect) || OperIsHWIntrinsic(NI_Vector256_ConditionalSelect) ||
6473-
OperIsHWIntrinsic(NI_Vector512_ConditionalSelect);
6474-
#elif defined(TARGET_ARM64)
6475-
return OperIsHWIntrinsic(NI_AdvSimd_BitwiseSelect) || OperIsHWIntrinsic(NI_Sve_ConditionalSelect);
6476-
#else
6477-
return false;
6478-
#endif
6475+
case NI_Vector128_ConditionalSelect:
6476+
case NI_Vector256_ConditionalSelect:
6477+
case NI_Vector512_ConditionalSelect:
6478+
{
6479+
return true;
6480+
}
6481+
#endif // TARGET_XARCH
6482+
6483+
#if defined(TARGET_ARM64)
6484+
case NI_AdvSimd_BitwiseSelect:
6485+
case NI_Sve_ConditionalSelect:
6486+
{
6487+
return true;
6488+
}
6489+
#endif // TARGET_ARM64
6490+
6491+
default:
6492+
{
6493+
return false;
6494+
}
6495+
}
6496+
}
6497+
6498+
bool OperIsVectorFusedMultiplyOp() const
6499+
{
6500+
switch (GetHWIntrinsicId())
6501+
{
6502+
#if defined(TARGET_XARCH)
6503+
case NI_AVX2_MultiplyAdd:
6504+
case NI_AVX2_MultiplyAddNegated:
6505+
case NI_AVX2_MultiplyAddNegatedScalar:
6506+
case NI_AVX2_MultiplyAddScalar:
6507+
case NI_AVX2_MultiplySubtract:
6508+
case NI_AVX2_MultiplySubtractNegated:
6509+
case NI_AVX2_MultiplySubtractNegatedScalar:
6510+
case NI_AVX2_MultiplySubtractScalar:
6511+
case NI_AVX512_FusedMultiplyAdd:
6512+
case NI_AVX512_FusedMultiplyAddNegated:
6513+
case NI_AVX512_FusedMultiplyAddNegatedScalar:
6514+
case NI_AVX512_FusedMultiplyAddScalar:
6515+
case NI_AVX512_FusedMultiplySubtract:
6516+
case NI_AVX512_FusedMultiplySubtractNegated:
6517+
case NI_AVX512_FusedMultiplySubtractNegatedScalar:
6518+
case NI_AVX512_FusedMultiplySubtractScalar:
6519+
{
6520+
return true;
6521+
}
6522+
#endif // TARGET_XARCH
6523+
6524+
default:
6525+
{
6526+
return false;
6527+
}
6528+
}
64796529
}
64806530

64816531
bool OperRequiresAsgFlag() const;

src/coreclr/jit/lowerxarch.cpp

Lines changed: 42 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -1387,7 +1387,6 @@ void Lowering::LowerFusedMultiplyOp(GenTreeHWIntrinsic* node)
13871387
{
13881388
assert(node->GetOperandCount() == 3);
13891389

1390-
bool isAvx512 = false;
13911390
bool negated = false;
13921391
bool subtract = false;
13931392
bool isScalar = false;
@@ -1397,111 +1396,59 @@ void Lowering::LowerFusedMultiplyOp(GenTreeHWIntrinsic* node)
13971396
switch (intrinsic)
13981397
{
13991398
case NI_AVX2_MultiplyAdd:
1399+
case NI_AVX512_FusedMultiplyAdd:
14001400
{
14011401
break;
14021402
}
14031403

14041404
case NI_AVX2_MultiplyAddScalar:
1405+
case NI_AVX512_FusedMultiplyAddScalar:
14051406
{
14061407
isScalar = true;
14071408
break;
14081409
}
14091410

14101411
case NI_AVX2_MultiplyAddNegated:
1412+
case NI_AVX512_FusedMultiplyAddNegated:
14111413
{
14121414
negated = true;
14131415
break;
14141416
}
14151417

14161418
case NI_AVX2_MultiplyAddNegatedScalar:
1417-
{
1418-
negated = true;
1419-
isScalar = true;
1420-
break;
1421-
}
1422-
1423-
case NI_AVX2_MultiplySubtract:
1424-
{
1425-
subtract = true;
1426-
break;
1427-
}
1428-
1429-
case NI_AVX2_MultiplySubtractScalar:
1430-
{
1431-
subtract = true;
1432-
isScalar = true;
1433-
break;
1434-
}
1435-
1436-
case NI_AVX2_MultiplySubtractNegated:
1437-
{
1438-
subtract = true;
1439-
negated = true;
1440-
break;
1441-
}
1442-
1443-
case NI_AVX2_MultiplySubtractNegatedScalar:
1444-
{
1445-
subtract = true;
1446-
negated = true;
1447-
isScalar = true;
1448-
break;
1449-
}
1450-
1451-
case NI_AVX512_FusedMultiplyAdd:
1452-
{
1453-
isAvx512 = true;
1454-
break;
1455-
}
1456-
1457-
case NI_AVX512_FusedMultiplyAddScalar:
1458-
{
1459-
isAvx512 = true;
1460-
isScalar = true;
1461-
break;
1462-
}
1463-
1464-
case NI_AVX512_FusedMultiplyAddNegated:
1465-
{
1466-
isAvx512 = true;
1467-
negated = true;
1468-
break;
1469-
}
1470-
14711419
case NI_AVX512_FusedMultiplyAddNegatedScalar:
14721420
{
1473-
isAvx512 = true;
14741421
negated = true;
14751422
isScalar = true;
14761423
break;
14771424
}
14781425

1426+
case NI_AVX2_MultiplySubtract:
14791427
case NI_AVX512_FusedMultiplySubtract:
14801428
{
1481-
isAvx512 = true;
14821429
subtract = true;
14831430
break;
14841431
}
14851432

1433+
case NI_AVX2_MultiplySubtractScalar:
14861434
case NI_AVX512_FusedMultiplySubtractScalar:
14871435
{
1488-
isAvx512 = true;
14891436
subtract = true;
14901437
isScalar = true;
14911438
break;
14921439
}
14931440

1441+
case NI_AVX2_MultiplySubtractNegated:
14941442
case NI_AVX512_FusedMultiplySubtractNegated:
14951443
{
1496-
isAvx512 = true;
14971444
subtract = true;
14981445
negated = true;
14991446
break;
15001447
}
15011448

1449+
case NI_AVX2_MultiplySubtractNegatedScalar:
15021450
case NI_AVX512_FusedMultiplySubtractNegatedScalar:
15031451
{
1504-
isAvx512 = true;
15051452
subtract = true;
15061453
negated = true;
15071454
isScalar = true;
@@ -1543,10 +1490,7 @@ void Lowering::LowerFusedMultiplyOp(GenTreeHWIntrinsic* node)
15431490
argOp->ClearContained();
15441491
ContainCheckHWIntrinsic(arg->AsHWIntrinsic());
15451492

1546-
// We want to toggle the tracking and then check it again,
1547-
// which is the simplest way to handle cases like -CreateScalarUnsafe(-x)
15481493
negatedArgs[i - 1] ^= true;
1549-
i -= 1;
15501494
}
15511495

15521496
break;
@@ -1562,11 +1506,18 @@ void Lowering::LowerFusedMultiplyOp(GenTreeHWIntrinsic* node)
15621506
break;
15631507
}
15641508

1565-
GenTree* argOp = hwArg->Op(1);
1509+
GenTree* argOp = hwArg->Op(2);
1510+
1511+
if (!argOp->isContained())
1512+
{
1513+
// A constant should have already been contained
1514+
break;
1515+
}
1516+
1517+
// xor is bitwise and the actual xor node might be a different base type
1518+
// from the FMA node, so we check if its negative zero using the FMA base
1519+
// type since that's what the end negation would end up using
15661520

1567-
// xor is bitwise and the actual xor node might not be floating-point
1568-
// so we check if its negative zero using the FMA base type since that's
1569-
// what the end negation will end up using
15701521
if (argOp->IsVectorNegativeZero(node->GetSimdBaseType()))
15711522
{
15721523
BlockRange().Remove(hwArg);
@@ -1576,11 +1527,9 @@ void Lowering::LowerFusedMultiplyOp(GenTreeHWIntrinsic* node)
15761527
argOp->ClearContained();
15771528
node->Op(i) = argOp;
15781529

1579-
// We want to toggle the tracking and then check it again,
1580-
// which is the simplest way to handle cases like -CreateScalarUnsafe(-x)
15811530
negatedArgs[i - 1] ^= true;
1582-
i -= 1;
15831531
}
1532+
15841533
break;
15851534
}
15861535
}
@@ -1590,7 +1539,7 @@ void Lowering::LowerFusedMultiplyOp(GenTreeHWIntrinsic* node)
15901539
negated ^= negatedArgs[1];
15911540
subtract ^= negatedArgs[2];
15921541

1593-
if (isAvx512)
1542+
if (intrinsic >= FIRST_NI_AVX512)
15941543
{
15951544
if (negated)
15961545
{
@@ -9928,7 +9877,28 @@ void Lowering::ContainCheckHWIntrinsic(GenTreeHWIntrinsic* node)
99289877

99299878
if (containedOperand != nullptr)
99309879
{
9931-
if (containedOperand->IsCnsVec() && node->isEmbeddedBroadcastCompatibleHWIntrinsic(comp))
9880+
bool isEmbeddedBroadcastCompatible =
9881+
containedOperand->IsCnsVec() && node->isEmbeddedBroadcastCompatibleHWIntrinsic(comp);
9882+
9883+
bool isScalarArg = false;
9884+
genTreeOps oper = node->GetOperForHWIntrinsicId(&isScalarArg);
9885+
9886+
// We want to skip trying to make this an embedded broadcast in certain scenarios
9887+
// because it will prevent other transforms that will be better for codegen.
9888+
9889+
LIR::Use use;
9890+
9891+
if ((oper == GT_XOR) && BlockRange().TryGetUse(node, &use) &&
9892+
use.User()->OperIsVectorFusedMultiplyOp())
9893+
{
9894+
// xor is bitwise and the actual xor node might be a different base type
9895+
// from the FMA node, so we check if its negative zero using the FMA base
9896+
// type since that's what the end negation would end up using
9897+
var_types fmaSimdBaseType = use.User()->AsHWIntrinsic()->GetSimdBaseType();
9898+
isEmbeddedBroadcastCompatible = !containedOperand->IsVectorNegativeZero(fmaSimdBaseType);
9899+
}
9900+
9901+
if (isEmbeddedBroadcastCompatible)
99329902
{
99339903
TryFoldCnsVecForEmbeddedBroadcast(node, containedOperand->AsVecCon());
99349904
}

0 commit comments

Comments (0)