diff --git a/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp b/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp index 0f948b22759fe..cfec46d23d65b 100644 --- a/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp +++ b/llvm/lib/Target/RISCV/RISCVISelDAGToDAG.cpp @@ -3058,17 +3058,28 @@ bool RISCVDAGToDAGISel::SelectAddrRegRegScale(SDValue Addr, }; if (auto *C1 = dyn_cast(RHS)) { + // (add (add (shl A C2) B) C1) -> (add (add B C1) (shl A C2)) if (LHS.getOpcode() == ISD::ADD && - SelectShl(LHS.getOperand(0), Index, Scale) && !isa(LHS.getOperand(1)) && isInt<12>(C1->getSExtValue())) { - // (add (add (shl A C2) B) C1) -> (add (add B C1) (shl A C2)) - SDValue C1Val = CurDAG->getTargetConstant(*C1->getConstantIntValue(), - SDLoc(Addr), VT); - Base = SDValue(CurDAG->getMachineNode(RISCV::ADDI, SDLoc(Addr), VT, - LHS.getOperand(1), C1Val), - 0); - return true; + if (SelectShl(LHS.getOperand(1), Index, Scale)) { + SDValue C1Val = CurDAG->getTargetConstant(*C1->getConstantIntValue(), + SDLoc(Addr), VT); + Base = SDValue(CurDAG->getMachineNode(RISCV::ADDI, SDLoc(Addr), VT, + LHS.getOperand(0), C1Val), + 0); + return true; + } + + // Add is commutative so we need to check both operands. + if (SelectShl(LHS.getOperand(0), Index, Scale)) { + SDValue C1Val = CurDAG->getTargetConstant(*C1->getConstantIntValue(), + SDLoc(Addr), VT); + Base = SDValue(CurDAG->getMachineNode(RISCV::ADDI, SDLoc(Addr), VT, + LHS.getOperand(1), C1Val), + 0); + return true; + } } // Don't match add with constants. diff --git a/llvm/test/CodeGen/RISCV/xqcisls.ll b/llvm/test/CodeGen/RISCV/xqcisls.ll index 828a0760044aa..709dc4ce074dc 100644 --- a/llvm/test/CodeGen/RISCV/xqcisls.ll +++ b/llvm/test/CodeGen/RISCV/xqcisls.ll @@ -309,8 +309,8 @@ define i64 @lrd(ptr %a, i32 %b) { ; RV32IZBAXQCISLS-LABEL: lrd: ; RV32IZBAXQCISLS: # %bb.0: ; RV32IZBAXQCISLS-NEXT: qc.lrw a2, a0, a1, 3 -; RV32IZBAXQCISLS-NEXT: sh3add a0, a1, a0 -; RV32IZBAXQCISLS-NEXT: lw a1, 4(a0) +; RV32IZBAXQCISLS-NEXT: addi a0, a0, 4 +; RV32IZBAXQCISLS-NEXT: qc.lrw a1, a0, a1, 3 ; RV32IZBAXQCISLS-NEXT: add a0, a2, a2 ; RV32IZBAXQCISLS-NEXT: sltu a2, a0, a2 ; RV32IZBAXQCISLS-NEXT: add a1, a1, a1 @@ -473,10 +473,10 @@ define void @srd(ptr %a, i32 %b, i64 %c) { ; RV32IZBAXQCISLS-NEXT: add a4, a2, a2 ; RV32IZBAXQCISLS-NEXT: add a3, a3, a3 ; RV32IZBAXQCISLS-NEXT: sltu a2, a4, a2 -; RV32IZBAXQCISLS-NEXT: add a2, a3, a2 -; RV32IZBAXQCISLS-NEXT: sh3add a3, a1, a0 ; RV32IZBAXQCISLS-NEXT: qc.srw a4, a0, a1, 3 -; RV32IZBAXQCISLS-NEXT: sw a2, 4(a3) +; RV32IZBAXQCISLS-NEXT: add a2, a3, a2 +; RV32IZBAXQCISLS-NEXT: addi a0, a0, 4 +; RV32IZBAXQCISLS-NEXT: qc.srw a2, a0, a1, 3 ; RV32IZBAXQCISLS-NEXT: ret %1 = add i64 %c, %c %2 = getelementptr i64, ptr %a, i32 %b diff --git a/llvm/test/CodeGen/RISCV/xtheadmemidx.ll b/llvm/test/CodeGen/RISCV/xtheadmemidx.ll index 578f51a957a75..fc20fcb371179 100644 --- a/llvm/test/CodeGen/RISCV/xtheadmemidx.ll +++ b/llvm/test/CodeGen/RISCV/xtheadmemidx.ll @@ -858,14 +858,13 @@ define i64 @lurwu(ptr %a, i32 %b) { define i64 @lrd(ptr %a, i64 %b) { ; RV32XTHEADMEMIDX-LABEL: lrd: ; RV32XTHEADMEMIDX: # %bb.0: -; RV32XTHEADMEMIDX-NEXT: slli a2, a1, 3 +; RV32XTHEADMEMIDX-NEXT: th.lrw a2, a0, a1, 3 +; RV32XTHEADMEMIDX-NEXT: addi a0, a0, 4 ; RV32XTHEADMEMIDX-NEXT: th.lrw a1, a0, a1, 3 -; RV32XTHEADMEMIDX-NEXT: add a0, a0, a2 -; RV32XTHEADMEMIDX-NEXT: lw a2, 4(a0) -; RV32XTHEADMEMIDX-NEXT: add a0, a1, a1 -; RV32XTHEADMEMIDX-NEXT: sltu a1, a0, a1 -; RV32XTHEADMEMIDX-NEXT: add a2, a2, a2 -; RV32XTHEADMEMIDX-NEXT: add a1, a2, a1 +; RV32XTHEADMEMIDX-NEXT: add a0, a2, a2 +; RV32XTHEADMEMIDX-NEXT: sltu a2, a0, a2 +; RV32XTHEADMEMIDX-NEXT: add a1, a1, a1 +; RV32XTHEADMEMIDX-NEXT: add a1, a1, a2 ; RV32XTHEADMEMIDX-NEXT: ret ; ; RV64XTHEADMEMIDX-LABEL: lrd: @@ -908,14 +907,13 @@ define i64 @lrd_2(ptr %a, i64 %b) { define i64 @lurd(ptr %a, i32 %b) { ; RV32XTHEADMEMIDX-LABEL: lurd: ; RV32XTHEADMEMIDX: # %bb.0: -; RV32XTHEADMEMIDX-NEXT: slli a2, a1, 3 +; RV32XTHEADMEMIDX-NEXT: th.lrw a2, a0, a1, 3 +; RV32XTHEADMEMIDX-NEXT: addi a0, a0, 4 ; RV32XTHEADMEMIDX-NEXT: th.lrw a1, a0, a1, 3 -; RV32XTHEADMEMIDX-NEXT: add a0, a0, a2 -; RV32XTHEADMEMIDX-NEXT: lw a2, 4(a0) -; RV32XTHEADMEMIDX-NEXT: add a0, a1, a1 -; RV32XTHEADMEMIDX-NEXT: sltu a1, a0, a1 -; RV32XTHEADMEMIDX-NEXT: add a2, a2, a2 -; RV32XTHEADMEMIDX-NEXT: add a1, a2, a1 +; RV32XTHEADMEMIDX-NEXT: add a0, a2, a2 +; RV32XTHEADMEMIDX-NEXT: sltu a2, a0, a2 +; RV32XTHEADMEMIDX-NEXT: add a1, a1, a1 +; RV32XTHEADMEMIDX-NEXT: add a1, a1, a2 ; RV32XTHEADMEMIDX-NEXT: ret ; ; RV64XTHEADMEMIDX-LABEL: lurd: @@ -1047,11 +1045,10 @@ define void @srd(ptr %a, i64 %b, i64 %c) { ; RV32XTHEADMEMIDX-NEXT: add a2, a3, a3 ; RV32XTHEADMEMIDX-NEXT: add a4, a4, a4 ; RV32XTHEADMEMIDX-NEXT: sltu a3, a2, a3 -; RV32XTHEADMEMIDX-NEXT: add a3, a4, a3 -; RV32XTHEADMEMIDX-NEXT: slli a4, a1, 3 -; RV32XTHEADMEMIDX-NEXT: add a4, a0, a4 ; RV32XTHEADMEMIDX-NEXT: th.srw a2, a0, a1, 3 -; RV32XTHEADMEMIDX-NEXT: sw a3, 4(a4) +; RV32XTHEADMEMIDX-NEXT: add a3, a4, a3 +; RV32XTHEADMEMIDX-NEXT: addi a0, a0, 4 +; RV32XTHEADMEMIDX-NEXT: th.srw a3, a0, a1, 3 ; RV32XTHEADMEMIDX-NEXT: ret ; ; RV64XTHEADMEMIDX-LABEL: srd: @@ -1071,11 +1068,10 @@ define void @surd(ptr %a, i32 %b, i64 %c) { ; RV32XTHEADMEMIDX-NEXT: add a4, a2, a2 ; RV32XTHEADMEMIDX-NEXT: add a3, a3, a3 ; RV32XTHEADMEMIDX-NEXT: sltu a2, a4, a2 -; RV32XTHEADMEMIDX-NEXT: add a2, a3, a2 -; RV32XTHEADMEMIDX-NEXT: slli a3, a1, 3 -; RV32XTHEADMEMIDX-NEXT: add a3, a0, a3 ; RV32XTHEADMEMIDX-NEXT: th.srw a4, a0, a1, 3 -; RV32XTHEADMEMIDX-NEXT: sw a2, 4(a3) +; RV32XTHEADMEMIDX-NEXT: add a2, a3, a2 +; RV32XTHEADMEMIDX-NEXT: addi a0, a0, 4 +; RV32XTHEADMEMIDX-NEXT: th.srw a2, a0, a1, 3 ; RV32XTHEADMEMIDX-NEXT: ret ; ; RV64XTHEADMEMIDX-LABEL: surd: