Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
92 changes: 76 additions & 16 deletions llvm/lib/Target/X86/X86ISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -57397,23 +57397,27 @@ static SDValue combineX86GatherScatter(SDNode *N, SelectionDAG &DAG,
return SDValue();
}

static SDValue rebuildGatherScatter(MaskedGatherScatterSDNode *GorS,
static SDValue rebuildGatherScatter(SelectionDAG &DAG,
MaskedGatherScatterSDNode *GorS,
SDValue Index, SDValue Base, SDValue Scale,
SelectionDAG &DAG) {
SDValue Mask = SDValue()) {
SDLoc DL(GorS);

if (!Mask.getNode())
Mask = GorS->getMask();

if (auto *Gather = dyn_cast<MaskedGatherSDNode>(GorS)) {
SDValue Ops[] = { Gather->getChain(), Gather->getPassThru(),
Gather->getMask(), Base, Index, Scale } ;
SDValue Ops[] = {
Gather->getChain(), Gather->getPassThru(), Mask, Base, Index, Scale};
return DAG.getMaskedGather(Gather->getVTList(),
Gather->getMemoryVT(), DL, Ops,
Gather->getMemOperand(),
Gather->getIndexType(),
Gather->getExtensionType());
}
auto *Scatter = cast<MaskedScatterSDNode>(GorS);
SDValue Ops[] = { Scatter->getChain(), Scatter->getValue(),
Scatter->getMask(), Base, Index, Scale };
SDValue Ops[] = {
Scatter->getChain(), Scatter->getValue(), Mask, Base, Index, Scale};
return DAG.getMaskedScatter(Scatter->getVTList(),
Scatter->getMemoryVT(), DL,
Ops, Scatter->getMemOperand(),
Expand All @@ -57422,7 +57426,8 @@ static SDValue rebuildGatherScatter(MaskedGatherScatterSDNode *GorS,
}

static SDValue combineGatherScatter(SDNode *N, SelectionDAG &DAG,
TargetLowering::DAGCombinerInfo &DCI) {
TargetLowering::DAGCombinerInfo &DCI,
const X86Subtarget &Subtarget) {
SDLoc DL(N);
auto *GorS = cast<MaskedGatherScatterSDNode>(N);
SDValue Index = GorS->getIndex();
Expand Down Expand Up @@ -57460,7 +57465,7 @@ static SDValue combineGatherScatter(SDNode *N, SelectionDAG &DAG,
Index.getOperand(0), NewShAmt);
SDValue NewScale =
DAG.getConstant(ScaleAmt * 2, DL, Scale.getValueType());
return rebuildGatherScatter(GorS, NewIndex, Base, NewScale, DAG);
return rebuildGatherScatter(DAG, GorS, NewIndex, Base, NewScale);
}
}
}
Expand All @@ -57478,7 +57483,7 @@ static SDValue combineGatherScatter(SDNode *N, SelectionDAG &DAG,
// a split.
if (SDValue TruncIndex =
DAG.FoldConstantArithmetic(ISD::TRUNCATE, DL, NewVT, Index))
return rebuildGatherScatter(GorS, TruncIndex, Base, Scale, DAG);
return rebuildGatherScatter(DAG, GorS, TruncIndex, Base, Scale);

// Shrink any sign/zero extends from 32 or smaller to larger than 32 if
// there are sufficient sign bits. Only do this before legalize types to
Expand All @@ -57487,13 +57492,13 @@ static SDValue combineGatherScatter(SDNode *N, SelectionDAG &DAG,
Index.getOpcode() == ISD::ZERO_EXTEND) &&
Index.getOperand(0).getScalarValueSizeInBits() <= 32) {
Index = DAG.getNode(ISD::TRUNCATE, DL, NewVT, Index);
return rebuildGatherScatter(GorS, Index, Base, Scale, DAG);
return rebuildGatherScatter(DAG, GorS, Index, Base, Scale);
}

// Shrink if we remove an illegal type.
if (!TLI.isTypeLegal(Index.getValueType()) && TLI.isTypeLegal(NewVT)) {
Index = DAG.getNode(ISD::TRUNCATE, DL, NewVT, Index);
return rebuildGatherScatter(GorS, Index, Base, Scale, DAG);
return rebuildGatherScatter(DAG, GorS, Index, Base, Scale);
}
}
}
Expand All @@ -57518,13 +57523,13 @@ static SDValue combineGatherScatter(SDNode *N, SelectionDAG &DAG,
SDValue NewBase = DAG.getNode(ISD::ADD, DL, PtrVT, Base,
DAG.getConstant(Adder, DL, PtrVT));
SDValue NewIndex = Index.getOperand(1 - I);
return rebuildGatherScatter(GorS, NewIndex, NewBase, Scale, DAG);
return rebuildGatherScatter(DAG, GorS, NewIndex, NewBase, Scale);
}
// For non-constant cases, limit this to non-scaled cases.
if (ScaleAmt == 1) {
SDValue NewBase = DAG.getNode(ISD::ADD, DL, PtrVT, Base, Splat);
SDValue NewIndex = Index.getOperand(1 - I);
return rebuildGatherScatter(GorS, NewIndex, NewBase, Scale, DAG);
return rebuildGatherScatter(DAG, GorS, NewIndex, NewBase, Scale);
}
}
}
Expand All @@ -57539,7 +57544,7 @@ static SDValue combineGatherScatter(SDNode *N, SelectionDAG &DAG,
SDValue NewIndex = DAG.getNode(ISD::ADD, DL, IndexVT,
Index.getOperand(1 - I), Splat);
SDValue NewBase = DAG.getConstant(0, DL, PtrVT);
return rebuildGatherScatter(GorS, NewIndex, NewBase, Scale, DAG);
return rebuildGatherScatter(DAG, GorS, NewIndex, NewBase, Scale);
}
}
}
Expand All @@ -57550,12 +57555,67 @@ static SDValue combineGatherScatter(SDNode *N, SelectionDAG &DAG,
MVT EltVT = IndexWidth > 32 ? MVT::i64 : MVT::i32;
IndexVT = IndexVT.changeVectorElementType(*DAG.getContext(), EltVT);
Index = DAG.getSExtOrTrunc(Index, DL, IndexVT);
return rebuildGatherScatter(GorS, Index, Base, Scale, DAG);
return rebuildGatherScatter(DAG, GorS, Index, Base, Scale);
}
}

// With vector masks we only demand the upper bit of the mask.
SDValue Mask = GorS->getMask();

// When the target does not have avx512 (which has special mask registers),
// replace a mask that looks like:
//
// t9: v4i1 = bitcast t8
//
// With one that looks like:
//
// t25: i32 = zero_extend t8
// t26: v4i32 = X86ISD::VBROADCAST t25
// t32: v4i32 = and t26, t31
// t33: v4i32 = X86ISD::PCMPEQ t32, t31
//
// The t31 vector has the values 1 << 0, 1 << 1, 1 << 2, etc.
//
// The default expansion from an integer to a mask vector generates a lot more
// instructions.
if (DCI.isBeforeLegalize() && !Subtarget.hasAVX512()) {
EVT MaskVT = Mask.getValueType();

if (MaskVT.isVector() && MaskVT.getVectorElementType() == MVT::i1 &&
Mask.getOpcode() == ISD::BITCAST) {

SDValue Bits = Mask.getOperand(0);
if (Bits.getValueType().isScalarInteger()) {
unsigned NumElts = MaskVT.getVectorNumElements();
if (NumElts == 4 || NumElts == 8) {

EVT ValueVT = N->getValueType(0);
EVT IntMaskVT = ValueVT.changeVectorElementTypeToInteger();

MVT MaskVecVT = IntMaskVT.getSimpleVT();
MVT MaskEltVT = MaskVecVT.getVectorElementType();

SDValue BitsElt = DAG.getZExtOrTrunc(Bits, DL, MaskEltVT);
SDValue Bc = DAG.getNode(X86ISD::VBROADCAST, DL, MaskVecVT, BitsElt);

SmallVector<SDValue, 8> Lanes;
Lanes.reserve(NumElts);
for (unsigned i = 0; i < NumElts; ++i) {
uint64_t Bit = 1ull << i;
Lanes.push_back(DAG.getConstant(Bit, DL, MaskEltVT));
}

SDValue LaneBits = DAG.getBuildVector(MaskVecVT, DL, Lanes);
SDValue And = DAG.getNode(ISD::AND, DL, MaskVecVT, Bc, LaneBits);
SDValue NewMask =
DAG.getNode(X86ISD::PCMPEQ, DL, MaskVecVT, And, LaneBits);

return rebuildGatherScatter(DAG, GorS, Index, Base, Scale, NewMask);
}
}
}
}

if (Mask.getScalarValueSizeInBits() != 1) {
APInt DemandedMask(APInt::getSignMask(Mask.getScalarValueSizeInBits()));
if (TLI.SimplifyDemandedBits(Mask, DemandedMask, DCI)) {
Expand Down Expand Up @@ -61700,7 +61760,7 @@ SDValue X86TargetLowering::PerformDAGCombine(SDNode *N,
case X86ISD::MGATHER:
case X86ISD::MSCATTER: return combineX86GatherScatter(N, DAG, DCI);
case ISD::MGATHER:
case ISD::MSCATTER: return combineGatherScatter(N, DAG, DCI);
case ISD::MSCATTER: return combineGatherScatter(N, DAG, DCI, Subtarget);
case X86ISD::PCMPEQ:
case X86ISD::PCMPGT: return combineVectorCompare(N, DAG, Subtarget);
case X86ISD::PMULDQ:
Expand Down
Loading
Loading