[AArch64] Prevent changing shl-by-2 to AND when the result is a load address
Currently, `DAGCombiner` replaces shift pairs of the form
`(shl (srl x, c1), c2)` with an `And`.

However, in a case like `(shl (srl x, c1), 2)` there is no need to
transform to `AND` when the result is used as a `Load` address.

Consider the following case:
```
        lsr x8, x8, #56
        and x8, x8, #0xfc
        ldr w0, [x2, x8]
        ret
```

In this case, we can remove the `AND` by changing the addressing mode of the
`LDR` to `[X2, X8, LSL #2]` and changing the right-shift amount from 56 to 58.

After the change:
```
        lsr x8, x8, #58
        ldr w0, [x2, x8, lsl #2]
        ret
```

This patch checks whether the `(shl (srl x, c1), 2)` feeding a `load`
address can be kept from being transformed into an `And`.
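
As a rough illustration, a source pattern along the following lines (a
hypothetical example, not taken from the patch) lowers to this DAG: the array
index is the top bits of a 64-bit value, and the address computation scales it
by 4, i.e. `(shl (srl x, 58), 2)`.
```
#include <stdint.h>

/* Hypothetical example, not part of the patch: index an i32 array with the
 * top six bits of a 64-bit key.  The address computation scales the index by
 * 4, which shows up in the DAG as (shl (srl key, 58), 2). */
int load_top_bits(const int *table, uint64_t key) {
  return table[key >> 58];
}
```
Without the patch, `DAGCombiner` turns the shift pair into the `lsr #56` plus
`and #0xfc` sequence shown above; with the patch, the `lsl #2` survives and is
folded into the `ldr` addressing mode.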
ParkHanbum committed Apr 24, 2024
1 parent b5b34db commit 1343794
108 changes: 102 additions & 6 deletions llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -563,11 +563,11 @@ namespace {
SDValue visitFMULForFMADistributiveCombine(SDNode *N);

SDValue XformToShuffleWithZero(SDNode *N);
bool isAddressingModePattern(unsigned Opc, const SDLoc &DL, SDNode *N,
SDValue N0, SDValue N1);
bool reassociationCanBreakAddressingModePattern(unsigned Opc,
const SDLoc &DL,
SDNode *N,
SDValue N0,
SDValue N1);
const SDLoc &DL, SDNode *N,
SDValue N0, SDValue N1);
SDValue reassociateOpsCommutative(unsigned Opc, const SDLoc &DL, SDValue N0,
SDValue N1, SDNodeFlags Flags);
SDValue reassociateOps(unsigned Opc, const SDLoc &DL, SDValue N0,
@@ -1068,6 +1068,93 @@ static bool canSplitIdx(LoadSDNode *LD) {
!cast<ConstantSDNode>(LD->getOperand(2))->isOpaque());
}

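// Returns true if (shl (srl x, c1), 2) feeds an ADD whose user is a load and
// the target reports the corresponding base+offset addressing mode as legal,
// i.e. the shift should be kept so it can be folded into the load address.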
static bool isAddressingModePatternSHL(unsigned Opc, const SDLoc &DL, SDNode *N,
SDValue Op0, SDValue Op1,
const TargetLowering &TLI,
const SelectionDAG &DAG) {
// Handle (shl (srl x, c1), 2).
if (!N->hasOneUse())
return false;

APInt SrlAmt;
if (sd_match(N, m_Shl(m_Srl(m_Value(), m_ConstInt(SrlAmt)), m_SpecificInt(2)))) {
// Srl knownbits
SDValue ShlV = SDValue(N, 0);
unsigned RegSize = ShlV.getValueType().getScalarSizeInBits();
KnownBits Known = DAG.computeKnownBits(ShlV);

LLVM_DEBUG(dbgs() << "RegSize: " << RegSize
<< " known bitwidth: " << Known.getBitWidth()
<< " srl amount: " << SrlAmt << " known bits: " << Known
<< " max: " << Known.getMaxValue() << "\n");

if (Known.getBitWidth() != RegSize)
return false;

// check load (ldr x, (add x, (shl (srl x, c1) 2)))
SDNode *User = N->use_begin().getUse().getUser();
LLVM_DEBUG(dbgs() << "N : "; N->dump(); User->dump());
if (!User || User->getOpcode() != ISD::ADD)
return false;

SDNode *Load = User->use_begin().getUse().getUser();
LLVM_DEBUG(dbgs() << "LOAD : "; Load->dump(); );
if (!Load || Load->getOpcode() != ISD::LOAD)
return false;

auto *LoadM = dyn_cast<MemSDNode>(Load);
if (!LoadM)
return false;

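// Ask the target whether [base reg + immediate] addressing, using the largest
// value the shifted index can take as the immediate, is legal for this load.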
TargetLoweringBase::AddrMode AM;
AM.HasBaseReg = true;
AM.BaseOffs = Known.getMaxValue().getZExtValue();
LLVM_DEBUG(dbgs() << "BaseMax : " << AM.BaseOffs << "\n");
EVT VT = LoadM->getMemoryVT();
unsigned AS = LoadM->getAddressSpace();
Type *AccessTy = VT.getTypeForEVT(*DAG.getContext());
if (!TLI.isLegalAddressingMode(DAG.getDataLayout(), AM, AccessTy, AS))
return false;

LLVM_DEBUG(dbgs() << "Success\n");
return true;
}


// const APInt &C2APIntVal = Op1C->getAPIntValue();
// for (SDNode *Node : N->uses()) {
// if (auto *LoadStore = dyn_cast<MemSDNode>(Node)) {
// TargetLoweringBase::AddrMode AM;
// AM.HasBaseReg = true;
// AM.BaseOffs = C2APIntVal.getSExtValue();
// EVT VT = LoadStore->getMemoryVT();
// unsigned AS = LoadStore->getAddressSpace();
// Type *AccessTy = VT.getTypeForEVT(*DAG.getContext());
// if (!TLI.isLegalAddressingMode(DAG.getDataLayout(), AM, AccessTy, AS))
// continue;

// // Would x[offset1+offset2] still be a legal addressing mode?
// if (!TLI.isLegalAddressingMode(DAG.getDataLayout(), AM, AccessTy, AS))
// return true;
// }
// }

return false;
}

bool DAGCombiner::isAddressingModePattern(unsigned Opc, const SDLoc &DL,
SDNode *N, SDValue Op0, SDValue Op1) {
switch (Opc) {
default:
return false;
case ISD::SHL:
return isAddressingModePatternSHL(Opc, DL, N, Op0, Op1, TLI, DAG);
}

return false;
}

bool DAGCombiner::reassociationCanBreakAddressingModePattern(unsigned Opc,
const SDLoc &DL,
SDNode *N,
@@ -1085,6 +1172,8 @@ bool DAGCombiner::reassociationCanBreakAddressingModePattern(unsigned Opc,
if (Opc != ISD::ADD || N0.getOpcode() != ISD::ADD)
return false;

LLVM_DEBUG(dbgs() << "ADD\t"; N->dump(); N0.dump(); N1.dump());
LLVM_DEBUG(dbgs() << "=================================\n");
auto *C2 = dyn_cast<ConstantSDNode>(N1);
if (!C2)
return false;
@@ -2610,6 +2699,8 @@ SDValue DAGCombiner::visitADDLike(SDNode *N) {
EVT VT = N0.getValueType();
SDLoc DL(N);

LLVM_DEBUG(dbgs() << "visitADDLike "; N->dump());
// fold (add x, undef) -> undef
if (N0.isUndef())
return N0;
@@ -2687,6 +2778,7 @@ SDValue DAGCombiner::visitADDLike(SDNode *N) {

// reassociate add
if (!reassociationCanBreakAddressingModePattern(ISD::ADD, DL, N, N0, N1)) {
LLVM_DEBUG(dbgs() << "reassociation \n");
if (SDValue RADD = reassociateOps(ISD::ADD, DL, N0, N1, N->getFlags()))
return RADD;

@@ -2840,7 +2932,7 @@ SDValue DAGCombiner::visitADDLike(SDNode *N) {

if (SDValue Combined = visitADDLikeCommutative(N1, N0, N))
return Combined;

return SDValue();
}

@@ -2872,6 +2964,8 @@ SDValue DAGCombiner::visitADD(SDNode *N) {
EVT VT = N0.getValueType();
SDLoc DL(N);

LLVM_DEBUG(dbgs() << "visitADD "; N->dump(); N0->dump(); N1->dump());
if (SDValue Combined = visitADDLike(N))
return Combined;

@@ -9887,13 +9981,15 @@ SDValue DAGCombiner::visitSHL(SDNode *N) {
}
}

LLVM_DEBUG(dbgs() << "visitSHL: trying (shl (srl x, c1), c2) fold\n");
// fold (shl (srl x, c1), c2) -> (and (shl x, (sub c2, c1), MASK) or
// (and (srl x, (sub c1, c2), MASK)
// Only fold this if the inner shift has no other uses -- if it does,
// folding this will increase the total number of instructions.
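// Also skip the fold when the shifted value feeds a load address that the
// target can fold into its addressing mode (see isAddressingModePattern).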
if (N0.getOpcode() == ISD::SRL &&
(N0.getOperand(1) == N1 || N0.hasOneUse()) &&
TLI.shouldFoldConstantShiftPairToMask(N, Level)) {
TLI.shouldFoldConstantShiftPairToMask(N, Level) &&
!isAddressingModePattern(N->getOpcode(), DL, N, N0, N1)) {
if (ISD::matchBinaryPredicate(N1, N0.getOperand(1), MatchShiftAmount,
/*AllowUndefs*/ false,
/*AllowTypeMismatch*/ true)) {
