diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 022b8bcedda4d2..24e384fa64f1a3 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -1502,7 +1502,8 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
                        ISD::FP_TO_SINT_SAT, ISD::FP_TO_UINT_SAT});
   if (Subtarget.hasVInstructions())
     setTargetDAGCombine({ISD::FCOPYSIGN, ISD::MGATHER, ISD::MSCATTER,
-                         ISD::VP_GATHER, ISD::VP_SCATTER, ISD::SRA, ISD::SRL,
+                         ISD::VP_GATHER, ISD::VP_SCATTER,
+                         ISD::EXPERIMENTAL_VP_STRIDED_LOAD, ISD::SRA, ISD::SRL,
                          ISD::SHL, ISD::STORE, ISD::SPLAT_VECTOR,
                          ISD::BUILD_VECTOR, ISD::CONCAT_VECTORS,
                          ISD::EXPERIMENTAL_VP_REVERSE, ISD::MUL,
@@ -17108,6 +17109,35 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
                               VPSN->getMemOperand(), IndexType);
     break;
   }
+  case ISD::EXPERIMENTAL_VP_STRIDED_LOAD: {
+    if (DCI.isBeforeLegalize())
+      break;
+    auto *Load = cast<VPStridedLoadSDNode>(N);
+    MVT VT = N->getSimpleValueType(0);
+
+    // Combine a zero strided load -> scalar load + splat
+    // The mask must be all ones and the EVL must be known to not be zero
+    if (!DAG.isKnownNeverZero(Load->getVectorLength()) ||
+        !Load->getOffset().isUndef() || !Load->isSimple() ||
+        !ISD::isConstantSplatVectorAllOnes(Load->getMask().getNode()) ||
+        !isNullConstant(Load->getStride()) ||
+        !isTypeLegal(VT.getVectorElementType()))
+      break;
+
+    SDValue ScalarLoad;
+    if (VT.isInteger())
+      ScalarLoad = DAG.getExtLoad(ISD::EXTLOAD, DL, XLenVT, Load->getChain(),
+                                  Load->getBasePtr(), VT.getVectorElementType(),
+                                  Load->getMemOperand());
+    else
+      ScalarLoad = DAG.getLoad(VT.getVectorElementType(), DL, Load->getChain(),
+                               Load->getBasePtr(), Load->getMemOperand());
+    SDValue Splat = VT.isFixedLengthVector()
+                        ? DAG.getSplatBuildVector(VT, DL, ScalarLoad)
+                        : DAG.getSplatVector(VT, DL, ScalarLoad);
+    return DAG.getMergeValues({Splat, SDValue(ScalarLoad.getNode(), 1)}, DL);
+    break;
+  }
   case RISCVISD::SHL_VL:
     if (SDValue V = combineBinOp_VLToVWBinOp_VL(N, DCI, Subtarget))
       return V;
diff --git a/llvm/test/CodeGen/RISCV/rvv/strided-vpload.ll b/llvm/test/CodeGen/RISCV/rvv/strided-vpload.ll
index 4d3bced0bcb50f..c19ecbb75d8189 100644
--- a/llvm/test/CodeGen/RISCV/rvv/strided-vpload.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/strided-vpload.ll
@@ -780,3 +780,69 @@ define <vscale x 16 x double> @strided_load_nxv17f64(ptr %ptr, i64 %stride, <vscale x 17 x i1> %mask, i32 zeroext %evl, ptr %hi_ptr)
 declare <vscale x 17 x double> @llvm.experimental.vp.strided.load.nxv17f64.p0.i64(ptr, i64, <vscale x 17 x i1>, i32)
 declare <vscale x 1 x double> @llvm.experimental.vector.extract.nxv1f64(<vscale x 17 x double> %vec, i64 %idx)
 declare <vscale x 16 x double> @llvm.experimental.vector.extract.nxv16f64(<vscale x 17 x double> %vec, i64 %idx)
+
+define <vscale x 1 x i64> @zero_strided_zero_evl(ptr %ptr, <vscale x 1 x i64> %v) {
+; CHECK-LABEL: zero_strided_zero_evl:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    ret
+  %load = call <vscale x 1 x i64> @llvm.experimental.vp.strided.load.nxv1i64.p0.i32(ptr %ptr, i32 0, <vscale x 1 x i1> splat (i1 true), i32 0)
+  %res = add <vscale x 1 x i64> %v, %load
+  ret <vscale x 1 x i64> %res
+}
+
+define <vscale x 1 x i64> @zero_strided_not_known_notzero_evl(ptr %ptr, <vscale x 1 x i64> %v, i32 zeroext %evl) {
+; CHECK-LABEL: zero_strided_not_known_notzero_evl:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vsetvli zero, a1, e64, m1, ta, ma
+; CHECK-NEXT:    vlse64.v v9, (a0), zero
+; CHECK-NEXT:    vsetvli a0, zero, e64, m1, ta, ma
+; CHECK-NEXT:    vadd.vv v8, v8, v9
+; CHECK-NEXT:    ret
+  %load = call <vscale x 1 x i64> @llvm.experimental.vp.strided.load.nxv1i64.p0.i32(ptr %ptr, i32 0, <vscale x 1 x i1> splat (i1 true), i32 %evl)
+  %res = add <vscale x 1 x i64> %v, %load
+  ret <vscale x 1 x i64> %res
+}
+
+define <vscale x 1 x i64> @zero_strided_known_notzero_avl(ptr %ptr, <vscale x 1 x i64> %v) {
+; CHECK-RV32-LABEL: zero_strided_known_notzero_avl:
+; CHECK-RV32:       # %bb.0:
+; CHECK-RV32-NEXT:    vsetivli zero, 1, e64, m1, ta, ma
+; CHECK-RV32-NEXT:    vlse64.v v9, (a0), zero
+; CHECK-RV32-NEXT:    vsetvli a0, zero, e64, m1, ta, ma
+; CHECK-RV32-NEXT:    vadd.vv v8, v8, v9
+; CHECK-RV32-NEXT:    ret
+;
+; CHECK-RV64-LABEL: zero_strided_known_notzero_avl:
+; CHECK-RV64:       # %bb.0:
+; CHECK-RV64-NEXT:    ld a0, 0(a0)
+; CHECK-RV64-NEXT:    vsetvli a1, zero, e64, m1, ta, ma
+; CHECK-RV64-NEXT:    vadd.vx v8, v8, a0
+; CHECK-RV64-NEXT:    ret
+  %load = call <vscale x 1 x i64> @llvm.experimental.vp.strided.load.nxv1i64.p0.i32(ptr %ptr, i32 0, <vscale x 1 x i1> splat (i1 true), i32 1)
+  %res = add <vscale x 1 x i64> %v, %load
+  ret <vscale x 1 x i64> %res
+}
+
+define <vscale x 2 x i64> @zero_strided_vec_length_avl(ptr %ptr, <vscale x 2 x i64> %v) vscale_range(2, 1024) {
+; CHECK-RV32-LABEL: zero_strided_vec_length_avl:
+; CHECK-RV32:       # %bb.0:
+; CHECK-RV32-NEXT:    csrr a1, vlenb
+; CHECK-RV32-NEXT:    srli a1, a1, 2
+; CHECK-RV32-NEXT:    vsetvli zero, a1, e64, m2, ta, ma
+; CHECK-RV32-NEXT:    vlse64.v v10, (a0), zero
+; CHECK-RV32-NEXT:    vsetvli a0, zero, e64, m2, ta, ma
+; CHECK-RV32-NEXT:    vadd.vv v8, v8, v10
+; CHECK-RV32-NEXT:    ret
+;
+; CHECK-RV64-LABEL: zero_strided_vec_length_avl:
+; CHECK-RV64:       # %bb.0:
+; CHECK-RV64-NEXT:    ld a0, 0(a0)
+; CHECK-RV64-NEXT:    vsetvli a1, zero, e64, m2, ta, ma
+; CHECK-RV64-NEXT:    vadd.vx v8, v8, a0
+; CHECK-RV64-NEXT:    ret
+  %vscale = call i32 @llvm.vscale()
+  %veclen = mul i32 %vscale, 2
+  %load = call <vscale x 2 x i64> @llvm.experimental.vp.strided.load.nxv2i64.p0.i32(ptr %ptr, i32 0, <vscale x 2 x i1> splat (i1 true), i32 %veclen)
+  %res = add <vscale x 2 x i64> %v, %load
+  ret <vscale x 2 x i64> %res
+}
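
Not part of the patch: a minimal standalone reproducer sketch for the new combine, assuming a build of llc with this change applied; the file name, function name, and the constant EVL of 4 below are illustrative only. Compiling it with llc -mtriple=riscv64 -mattr=+v should show the zero-strided vp.strided.load lowered to a scalar load plus a splat (e.g. vmv.v.x) instead of a zero-strided vlse64.v, while on riscv32 the i64 element type is not a legal scalar type, so the combine should not fire and the vlse64.v remains, mirroring the RV32/RV64 CHECK lines above.

; zero_stride_example.ll (hypothetical reproducer, not part of the test suite)
; The stride is the constant 0, the mask is all ones, and the EVL (4) is a
; known non-zero constant, so every guard in the new DAG combine is satisfied.
declare <vscale x 1 x i64> @llvm.experimental.vp.strided.load.nxv1i64.p0.i32(ptr, i32, <vscale x 1 x i1>, i32)

define <vscale x 1 x i64> @zero_stride_example(ptr %p) {
  %v = call <vscale x 1 x i64> @llvm.experimental.vp.strided.load.nxv1i64.p0.i32(ptr %p, i32 0, <vscale x 1 x i1> splat (i1 true), i32 4)
  ret <vscale x 1 x i64> %v
}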