Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
117 changes: 113 additions & 4 deletions llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24368,6 +24368,89 @@ static SDValue performZExtUZPCombine(SDNode *N, SelectionDAG &DAG) {
return DAG.getNode(ISD::AND, DL, VT, BC, DAG.getConstant(Mask, DL, VT));
}

// Convert (vXiY *ext(vXi1 bitcast(iX))) to extend_in_reg(broadcast(iX)).
static SDValue combineToExtendBoolVectorInReg(
unsigned Opcode, const SDLoc &DL, EVT VT, SDValue N0, SelectionDAG &DAG,
TargetLowering::DAGCombinerInfo &DCI, const AArch64Subtarget &Subtarget) {
if (Opcode != ISD::SIGN_EXTEND && Opcode != ISD::ZERO_EXTEND &&
Opcode != ISD::ANY_EXTEND)
return SDValue();
if (!DCI.isBeforeLegalizeOps())
return SDValue();
if (!Subtarget.hasNEON())
return SDValue();

EVT SVT = VT.getScalarType();
EVT InSVT = N0.getValueType().getScalarType();
unsigned EltSizeInBits = SVT.getSizeInBits();

// Input type must be extending a bool vector (bit-casted from a scalar
// integer) to legal integer types.
if (!VT.isVector())
return SDValue();
if (SVT != MVT::i64 && SVT != MVT::i32 && SVT != MVT::i16 && SVT != MVT::i8)
return SDValue();
if (InSVT != MVT::i1 || N0.getOpcode() != ISD::BITCAST)
return SDValue();

SDValue N00 = N0.getOperand(0);
EVT SclVT = N00.getValueType();
if (!SclVT.isScalarInteger())
return SDValue();

SDValue Vec;
SmallVector<int> ShuffleMask;
unsigned NumElts = VT.getVectorNumElements();
assert(NumElts == SclVT.getSizeInBits() && "Unexpected bool vector size");

// Broadcast the scalar integer to the vector elements.
bool IsBE = DAG.getDataLayout().isBigEndian();
if (NumElts > EltSizeInBits) {
// If the scalar integer is greater than the vector element size, then we
// must split it down into sub-sections for broadcasting. For example:
// i16 -> v16i8 (i16 -> v8i16 -> v16i8) with 2 sub-sections.
// i32 -> v32i8 (i32 -> v8i32 -> v32i8) with 4 sub-sections.
assert((NumElts % EltSizeInBits) == 0 && "Unexpected integer scale");
unsigned Scale = NumElts / EltSizeInBits;
EVT BroadcastVT = EVT::getVectorVT(*DAG.getContext(), SclVT, EltSizeInBits);
Vec = DAG.getNode(ISD::SCALAR_TO_VECTOR, DL, BroadcastVT, N00);
Vec = DAG.getBitcast(VT, Vec);

for (unsigned I = 0; I != Scale; ++I)
ShuffleMask.append(EltSizeInBits, (int)I);

Vec = DAG.getVectorShuffle(VT, DL, Vec, Vec, ShuffleMask);
} else {
// For smaller scalar integers, we can simply any-extend it to the vector
// element size (we don't care about the upper bits) and broadcast it to all
// elements.
Vec = DAG.getSplat(VT, DL, DAG.getAnyExtOrTrunc(N00, DL, SVT));
}

// Now, mask the relevant bit in each element.
SmallVector<SDValue, 32> Bits;
for (unsigned I = 0; I != NumElts; ++I) {
unsigned ScalarBit = IsBE ? (NumElts - 1 - I) : I;
int BitIdx = ScalarBit % EltSizeInBits;
APInt Bit = APInt::getBitsSet(EltSizeInBits, BitIdx, BitIdx + 1);
Bits.push_back(DAG.getConstant(Bit, DL, SVT));
}
SDValue BitMask = DAG.getBuildVector(VT, DL, Bits);
Vec = DAG.getNode(ISD::AND, DL, VT, Vec, BitMask);

// Compare against the bitmask and extend the result.
EVT CCVT = EVT::getVectorVT(*DAG.getContext(), MVT::i1, NumElts);
Vec = DAG.getSetCC(DL, CCVT, Vec, BitMask, ISD::SETEQ);
Vec = DAG.getSExtOrTrunc(Vec, DL, VT);

// For SEXT, this is now done, otherwise shift the result down for
// zero-extension.
if (Opcode == ISD::SIGN_EXTEND)
return Vec;
return DAG.getNode(ISD::SRL, DL, VT, Vec,
DAG.getConstant(EltSizeInBits - 1, DL, VT));
}

// Combine:
// ext(duplane(insert_subvector(undef, trunc(X), 0), idx))
// Into:
Expand Down Expand Up @@ -24432,7 +24515,8 @@ static SDValue performExtendDuplaneTruncCombine(SDNode *N, SelectionDAG &DAG) {

static SDValue performExtendCombine(SDNode *N,
TargetLowering::DAGCombinerInfo &DCI,
SelectionDAG &DAG) {
SelectionDAG &DAG,
const AArch64Subtarget *Subtarget) {
// If we see something like (zext (sabd (extract_high ...), (DUP ...))) then
// we can convert that DUP into another extract_high (of a bigger DUP), which
// helps the backend to decide that an sabdl2 would be useful, saving a real
Expand All @@ -24455,6 +24539,13 @@ static SDValue performExtendCombine(SDNode *N,
if (SDValue R = performZExtUZPCombine(N, DAG))
return R;

SDLoc dl(N);
SDValue N0 = N->getOperand(0);
EVT VT = N->getValueType(0);
if (SDValue V = combineToExtendBoolVectorInReg(N->getOpcode(), dl, VT, N0,
DAG, DCI, *Subtarget))
return V;

if (N->getValueType(0).isFixedLengthVector() &&
N->getOpcode() == ISD::SIGN_EXTEND &&
N->getOperand(0)->getOpcode() == ISD::SETCC)
Expand Down Expand Up @@ -27712,7 +27803,11 @@ static SDValue trySwapVSelectOperands(SDNode *N, SelectionDAG &DAG) {
// FIXME: Currently the type legalizer can't handle VSELECT having v1i1 as
// condition. If it can legalize "VSELECT v1i1" correctly, no need to combine
// such VSELECT.
static SDValue performVSelectCombine(SDNode *N, SelectionDAG &DAG) {
static SDValue performVSelectCombine(SDNode *N,
TargetLowering::DAGCombinerInfo &DCI,
const AArch64Subtarget *Subtarget) {
SelectionDAG &DAG = DCI.DAG;

if (auto SwapResult = trySwapVSelectOperands(N, DAG))
return SwapResult;

Expand Down Expand Up @@ -27776,6 +27871,20 @@ static SDValue performVSelectCombine(SDNode *N, SelectionDAG &DAG) {
}
}

// Attempt to convert a (vXi1 bitcast(iX N0)) selection mask before it might
// get split by legalization.
if (N0.getOpcode() == ISD::BITCAST && CCVT.isVector() &&
CCVT.getVectorElementType() == MVT::i1) {
SDLoc DL(N);
EVT ExtCondVT = ResVT.changeVectorElementTypeToInteger();

if (SDValue ExtCond = combineToExtendBoolVectorInReg(
ISD::SIGN_EXTEND, DL, ExtCondVT, N0, DAG, DCI, *Subtarget)) {
ExtCond = DAG.getNode(ISD::TRUNCATE, DL, CCVT, ExtCond);
return DAG.getSelect(DL, ResVT, ExtCond, IfTrue, IfFalse);
}
}

EVT CmpVT = N0.getOperand(0).getValueType();
if (N0.getOpcode() != ISD::SETCC ||
CCVT.getVectorElementCount() != ElementCount::getFixed(1) ||
Expand Down Expand Up @@ -29188,7 +29297,7 @@ SDValue AArch64TargetLowering::PerformDAGCombine(SDNode *N,
case ISD::ANY_EXTEND:
case ISD::ZERO_EXTEND:
case ISD::SIGN_EXTEND:
return performExtendCombine(N, DCI, DAG);
return performExtendCombine(N, DCI, DAG, Subtarget);
case ISD::SIGN_EXTEND_INREG:
return performSignExtendInRegCombine(N, DCI, DAG);
case ISD::CONCAT_VECTORS:
Expand All @@ -29200,7 +29309,7 @@ SDValue AArch64TargetLowering::PerformDAGCombine(SDNode *N,
case ISD::SELECT:
return performSelectCombine(N, DCI);
case ISD::VSELECT:
return performVSelectCombine(N, DCI.DAG);
return performVSelectCombine(N, DCI, Subtarget);
case ISD::SETCC:
return performSETCCCombine(N, DCI, DAG);
case ISD::LOAD:
Expand Down
Loading
Loading