[RISCV] Combine vp_strided_load with zero stride to scalar load + splat #97798

Closed
lukel97 wants to merge 2 commits

Conversation

lukel97 (Contributor) commented Jul 5, 2024

This is another version of #97394, but it performs the transform as a DAGCombine instead of during lowering, so that we have a better chance of detecting non-zero EVLs before they are legalized.

The riscv_masked_strided_load combine already does this, but this combine additionally checks that the vector element type is legal. Currently a riscv_masked_strided_load of nxv1i64 with a zero stride will crash on rv32, but I'm hoping we can remove the masked_strided intrinsics and replace them with their VP counterparts.

RISCVISelDAGToDAG will lower splats of scalar loads back to zero-strided loads anyway, so the test changes show how combining to a scalar load can lead to some .vx patterns being matched.
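
For illustration, here's a rough IR-level before/after of what this combine does. It's only a sketch (the combine actually operates on SelectionDAG nodes, and the function names here are made up):

declare <vscale x 1 x i64> @llvm.experimental.vp.strided.load.nxv1i64.p0.i32(ptr, i32, <vscale x 1 x i1>, i32)

; Before: a VP strided load with a zero stride, an all-ones mask, and an EVL
; that is known to be non-zero (here the constant 1). Every active lane reads
; from the same address.
define <vscale x 1 x i64> @before(ptr %p) {
  %v = call <vscale x 1 x i64> @llvm.experimental.vp.strided.load.nxv1i64.p0.i32(ptr %p, i32 0, <vscale x 1 x i1> splat (i1 true), i32 1)
  ret <vscale x 1 x i64> %v
}

; After (conceptually): a single scalar load splatted to all lanes, which later
; lets ISel match .vx/.vf patterns or fall back to a zero-strided vlse.
define <vscale x 1 x i64> @after(ptr %p) {
  %s    = load i64, ptr %p
  %head = insertelement <vscale x 1 x i64> poison, i64 %s, i64 0
  %v    = shufflevector <vscale x 1 x i64> %head, <vscale x 1 x i64> poison, <vscale x 1 x i32> zeroinitializer
  ret <vscale x 1 x i64> %v
}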

llvmbot (Collaborator) commented Jul 5, 2024

@llvm/pr-subscribers-backend-risc-v

Author: Luke Lau (lukel97)

Full diff: https://github.com/llvm/llvm-project/pull/97798.diff

2 Files Affected:

  • (modified) llvm/lib/Target/RISCV/RISCVISelLowering.cpp (+31-1)
  • (modified) llvm/test/CodeGen/RISCV/rvv/strided-vpload.ll (+66)
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 022b8bcedda4d..24e384fa64f1a 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -1502,7 +1502,8 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
                          ISD::FP_TO_SINT_SAT, ISD::FP_TO_UINT_SAT});
   if (Subtarget.hasVInstructions())
     setTargetDAGCombine({ISD::FCOPYSIGN, ISD::MGATHER, ISD::MSCATTER,
-                         ISD::VP_GATHER, ISD::VP_SCATTER, ISD::SRA, ISD::SRL,
+                         ISD::VP_GATHER, ISD::VP_SCATTER,
+                         ISD::EXPERIMENTAL_VP_STRIDED_LOAD, ISD::SRA, ISD::SRL,
                          ISD::SHL, ISD::STORE, ISD::SPLAT_VECTOR,
                          ISD::BUILD_VECTOR, ISD::CONCAT_VECTORS,
                          ISD::EXPERIMENTAL_VP_REVERSE, ISD::MUL,
@@ -17108,6 +17109,35 @@ SDValue RISCVTargetLowering::PerformDAGCombine(SDNode *N,
                               VPSN->getMemOperand(), IndexType);
     break;
   }
+  case ISD::EXPERIMENTAL_VP_STRIDED_LOAD: {
+    if (DCI.isBeforeLegalize())
+      break;
+    auto *Load = cast<VPStridedLoadSDNode>(N);
+    MVT VT = N->getSimpleValueType(0);
+
+    // Combine a zero strided load -> scalar load + splat
+    // The mask must be all ones and the EVL must be known to not be zero
+    if (!DAG.isKnownNeverZero(Load->getVectorLength()) ||
+        !Load->getOffset().isUndef() || !Load->isSimple() ||
+        !ISD::isConstantSplatVectorAllOnes(Load->getMask().getNode()) ||
+        !isNullConstant(Load->getStride()) ||
+        !isTypeLegal(VT.getVectorElementType()))
+      break;
+
+    SDValue ScalarLoad;
+    if (VT.isInteger())
+      ScalarLoad = DAG.getExtLoad(ISD::EXTLOAD, DL, XLenVT, Load->getChain(),
+                                  Load->getBasePtr(), VT.getVectorElementType(),
+                                  Load->getMemOperand());
+    else
+      ScalarLoad = DAG.getLoad(VT.getVectorElementType(), DL, Load->getChain(),
+                               Load->getBasePtr(), Load->getMemOperand());
+    SDValue Splat = VT.isFixedLengthVector()
+                        ? DAG.getSplatBuildVector(VT, DL, ScalarLoad)
+                        : DAG.getSplatVector(VT, DL, ScalarLoad);
+    return DAG.getMergeValues({Splat, SDValue(ScalarLoad.getNode(), 1)}, DL);
+    break;
+  }
   case RISCVISD::SHL_VL:
     if (SDValue V = combineBinOp_VLToVWBinOp_VL(N, DCI, Subtarget))
       return V;
diff --git a/llvm/test/CodeGen/RISCV/rvv/strided-vpload.ll b/llvm/test/CodeGen/RISCV/rvv/strided-vpload.ll
index 4d3bced0bcb50..c19ecbb75d818 100644
--- a/llvm/test/CodeGen/RISCV/rvv/strided-vpload.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/strided-vpload.ll
@@ -780,3 +780,69 @@ define <vscale x 16 x double> @strided_load_nxv17f64(ptr %ptr, i64 %stride, <vsc
 declare <vscale x 17 x double> @llvm.experimental.vp.strided.load.nxv17f64.p0.i64(ptr, i64, <vscale x 17 x i1>, i32)
 declare <vscale x 1 x double> @llvm.experimental.vector.extract.nxv1f64(<vscale x 17 x double> %vec, i64 %idx)
 declare <vscale x 16 x double> @llvm.experimental.vector.extract.nxv16f64(<vscale x 17 x double> %vec, i64 %idx)
+
+define <vscale x 1 x i64> @zero_strided_zero_evl(ptr %ptr, <vscale x 1 x i64> %v) {
+; CHECK-LABEL: zero_strided_zero_evl:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    ret
+  %load = call <vscale x 1 x i64> @llvm.experimental.vp.strided.load.nxv1i64.p0.i32(ptr %ptr, i32 0, <vscale x 1 x i1> splat (i1 true), i32 0)
+  %res = add <vscale x 1 x i64> %v, %load
+  ret <vscale x 1 x i64> %res
+}
+
+define <vscale x 1 x i64> @zero_strided_not_known_notzero_evl(ptr %ptr, <vscale x 1 x i64> %v, i32 zeroext %evl) {
+; CHECK-LABEL: zero_strided_not_known_notzero_evl:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    vsetvli zero, a1, e64, m1, ta, ma
+; CHECK-NEXT:    vlse64.v v9, (a0), zero
+; CHECK-NEXT:    vsetvli a0, zero, e64, m1, ta, ma
+; CHECK-NEXT:    vadd.vv v8, v8, v9
+; CHECK-NEXT:    ret
+  %load = call <vscale x 1 x i64> @llvm.experimental.vp.strided.load.nxv1i64.p0.i32(ptr %ptr, i32 0, <vscale x 1 x i1> splat (i1 true), i32 %evl)
+  %res = add <vscale x 1 x i64> %v, %load
+  ret <vscale x 1 x i64> %res
+}
+
+define <vscale x 1 x i64> @zero_strided_known_notzero_avl(ptr %ptr, <vscale x 1 x i64> %v) {
+; CHECK-RV32-LABEL: zero_strided_known_notzero_avl:
+; CHECK-RV32:       # %bb.0:
+; CHECK-RV32-NEXT:    vsetivli zero, 1, e64, m1, ta, ma
+; CHECK-RV32-NEXT:    vlse64.v v9, (a0), zero
+; CHECK-RV32-NEXT:    vsetvli a0, zero, e64, m1, ta, ma
+; CHECK-RV32-NEXT:    vadd.vv v8, v8, v9
+; CHECK-RV32-NEXT:    ret
+;
+; CHECK-RV64-LABEL: zero_strided_known_notzero_avl:
+; CHECK-RV64:       # %bb.0:
+; CHECK-RV64-NEXT:    ld a0, 0(a0)
+; CHECK-RV64-NEXT:    vsetvli a1, zero, e64, m1, ta, ma
+; CHECK-RV64-NEXT:    vadd.vx v8, v8, a0
+; CHECK-RV64-NEXT:    ret
+  %load = call <vscale x 1 x i64> @llvm.experimental.vp.strided.load.nxv1i64.p0.i32(ptr %ptr, i32 0, <vscale x 1 x i1> splat (i1 true), i32 1)
+  %res = add <vscale x 1 x i64> %v, %load
+  ret <vscale x 1 x i64> %res
+}
+
+define <vscale x 2 x i64> @zero_strided_vec_length_avl(ptr %ptr, <vscale x 2 x i64> %v) vscale_range(2, 1024) {
+; CHECK-RV32-LABEL: zero_strided_vec_length_avl:
+; CHECK-RV32:       # %bb.0:
+; CHECK-RV32-NEXT:    csrr a1, vlenb
+; CHECK-RV32-NEXT:    srli a1, a1, 2
+; CHECK-RV32-NEXT:    vsetvli zero, a1, e64, m2, ta, ma
+; CHECK-RV32-NEXT:    vlse64.v v10, (a0), zero
+; CHECK-RV32-NEXT:    vsetvli a0, zero, e64, m2, ta, ma
+; CHECK-RV32-NEXT:    vadd.vv v8, v8, v10
+; CHECK-RV32-NEXT:    ret
+;
+; CHECK-RV64-LABEL: zero_strided_vec_length_avl:
+; CHECK-RV64:       # %bb.0:
+; CHECK-RV64-NEXT:    ld a0, 0(a0)
+; CHECK-RV64-NEXT:    vsetvli a1, zero, e64, m2, ta, ma
+; CHECK-RV64-NEXT:    vadd.vx v8, v8, a0
+; CHECK-RV64-NEXT:    ret
+  %vscale = call i32 @llvm.vscale()
+  %veclen = mul i32 %vscale, 2
+  %load = call <vscale x 2 x i64> @llvm.experimental.vp.strided.load.nxv2i64.p0.i32(ptr %ptr, i32 0, <vscale x 2 x i1> splat (i1 true), i32 %veclen)
+  %res = add <vscale x 2 x i64> %v, %load
+  ret <vscale x 2 x i64> %res
+}

github-actions bot commented Jul 5, 2024

⚠️ C/C++ code formatter, clang-format found issues in your code. ⚠️

You can test this locally with the following command:
git-clang-format --diff e4b28420f677207cbb81683396d1aba00fb9ab80 af35a30f9bc0f938782c1b430b86db2747f09871 -- llvm/lib/Target/RISCV/RISCVISelLowering.cpp
View the diff from clang-format here.
diff --git a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
index 24e384fa64..207bb71f36 100644
--- a/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
+++ b/llvm/lib/Target/RISCV/RISCVISelLowering.cpp
@@ -1501,14 +1501,27 @@ RISCVTargetLowering::RISCVTargetLowering(const TargetMachine &TM,
     setTargetDAGCombine({ISD::ZERO_EXTEND, ISD::FP_TO_SINT, ISD::FP_TO_UINT,
                          ISD::FP_TO_SINT_SAT, ISD::FP_TO_UINT_SAT});
   if (Subtarget.hasVInstructions())
-    setTargetDAGCombine({ISD::FCOPYSIGN, ISD::MGATHER, ISD::MSCATTER,
-                         ISD::VP_GATHER, ISD::VP_SCATTER,
-                         ISD::EXPERIMENTAL_VP_STRIDED_LOAD, ISD::SRA, ISD::SRL,
-                         ISD::SHL, ISD::STORE, ISD::SPLAT_VECTOR,
-                         ISD::BUILD_VECTOR, ISD::CONCAT_VECTORS,
-                         ISD::EXPERIMENTAL_VP_REVERSE, ISD::MUL,
-                         ISD::SDIV, ISD::UDIV, ISD::SREM, ISD::UREM,
-                         ISD::INSERT_VECTOR_ELT, ISD::ABS});
+    setTargetDAGCombine({ISD::FCOPYSIGN,
+                         ISD::MGATHER,
+                         ISD::MSCATTER,
+                         ISD::VP_GATHER,
+                         ISD::VP_SCATTER,
+                         ISD::EXPERIMENTAL_VP_STRIDED_LOAD,
+                         ISD::SRA,
+                         ISD::SRL,
+                         ISD::SHL,
+                         ISD::STORE,
+                         ISD::SPLAT_VECTOR,
+                         ISD::BUILD_VECTOR,
+                         ISD::CONCAT_VECTORS,
+                         ISD::EXPERIMENTAL_VP_REVERSE,
+                         ISD::MUL,
+                         ISD::SDIV,
+                         ISD::UDIV,
+                         ISD::SREM,
+                         ISD::UREM,
+                         ISD::INSERT_VECTOR_ELT,
+                         ISD::ABS});
   if (Subtarget.hasVendorXTHeadMemPair())
     setTargetDAGCombine({ISD::LOAD, ISD::STORE});
   if (Subtarget.useRVVForFixedLengthVectors())

wangpc-pp (Contributor) left a comment

LGTM in general, but I think we should wait for your PR for removing riscv_strided_load/store?

lukel97 (Contributor, Author) commented Jul 5, 2024

LGTM in general, but I think we should wait for your PR for removing riscv_strided_load/store?

This is a prerequisite for removing riscv_strided_load/store, since we get some small regressions if we switch to vp.strided.load/store without handling this case.


// Combine a zero strided load -> scalar load + splat
// The mask must be all ones and the EVL must be known to not be zero
if (!DAG.isKnownNeverZero(Load->getVectorLength()) ||
A Collaborator left a review comment on the code above:

Do the cheaper checks before calling DAG.isKnownNeverZero since that's recursive.

topperc (Collaborator) commented Jul 5, 2024

I think @yetingk is looking at doing this in RISCVCodeGenPrepare where we have access to IR's isKnownNonZero. We need that to handle the loop vectorized case where we only know the EVL is non-zero because of a branch in the loop preheader.
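
A rough sketch of the kind of loop in question (illustrative only; it assumes the EVL comes from llvm.experimental.get.vector.length). The EVL feeding the zero-strided load is only known to be non-zero because of the trip-count guard in the preheader, which IR-level isKnownNonZero can see through the dominating branch but the DAG cannot:

declare i32 @llvm.experimental.get.vector.length.i64(i64, i32, i1)
declare <vscale x 2 x i64> @llvm.experimental.vp.strided.load.nxv2i64.p0.i32(ptr, i32, <vscale x 2 x i1>, i32)

define void @evl_loop(ptr %p, i64 %n) {
entry:
  ; This guard is the only reason %evl below is known to be non-zero.
  %guard = icmp eq i64 %n, 0
  br i1 %guard, label %exit, label %loop

loop:
  %iv = phi i64 [ 0, %entry ], [ %iv.next, %loop ]
  %remaining = sub i64 %n, %iv
  %evl = call i32 @llvm.experimental.get.vector.length.i64(i64 %remaining, i32 2, i1 true)
  %bcast = call <vscale x 2 x i64> @llvm.experimental.vp.strided.load.nxv2i64.p0.i32(ptr %p, i32 0, <vscale x 2 x i1> splat (i1 true), i32 %evl)
  ; ... loop body uses %bcast ...
  %evl.zext = zext i32 %evl to i64
  %iv.next = add i64 %iv, %evl.zext
  %done = icmp uge i64 %iv.next, %n
  br i1 %done, label %exit, label %loop

exit:
  ret void
}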

lukel97 (Contributor, Author) commented Jul 6, 2024

I think @BeMg is looking at doing this in RISCVCodeGenPrepare where we have access to IR's isKnownNonZero. We need that to handle the loop vectorized case where we only know the EVL is non-zero because of a branch in the loop preheader.

That seems like a better place. And it looks like RISCVCodeGenPrepare runs after RISCVGatherScatterLowering, so I think it would still work if we emitted experimental.vp.strided.load from the latter. I'll leave this PR open for now, then.

lukel97 added a commit to lukel97/llvm-project that referenced this pull request Jul 9, 2024
RISCVGatherScatterLowering is the main user of riscv_masked_strided_{load,store}, which we can remove if we replace them with their VP equivalents.

Submitting early as a draft to show the regressions in the test diff that llvm#97800 and llvm#97798 (or the CGP version) are needed to fix.
lukel97 added a commit to lukel97/llvm-project that referenced this pull request Jul 9, 2024
This combine is a duplication of the transform in RISCVGatherScatterLowering but at the SelectionDAG level, so similarly to llvm#98111 we can replace the use of riscv_masked_strided_load with a VP strided load.

Unlike llvm#98111 we don't require llvm#97800 or llvm#97798 since it only operates on fixed vectors with a non-zero stride.
lukel97 added a commit that referenced this pull request Jul 9, 2024
This combine is a duplication of the transform in
RISCVGatherScatterLowering but at the SelectionDAG level, so similarly
to #98111 we can replace the use of riscv_masked_strided_load with a VP
strided load.

Unlike #98111 we don't require #97800 or #97798 since it only operates
on fixed vectors with a non-zero stride.
yetingk (Contributor) commented Jul 9, 2024

I think @yetingk is looking at doing this in RISCVCodeGenPrepare where we have access to IR's isKnownNonZero. We need that to handle the loop vectorized case where we only know the EVL is non-zero because of a branch in the loop preheader.

I created a draft PR, #98140. It's still a draft because it's hard to create a splat with a specific VL at the IR level. I used the RISC-V intrinsics vmv_v_x/vmv_v_f in that PR, but with that implementation fixed vectors don't benefit from this optimization.
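
For reference, a minimal sketch of that approach (assuming the vmv.v.x intrinsic takes a passthru, a scalar, and an XLen-typed VL operand; the fixed-vector limitation comes from the RVV intrinsics only being defined for scalable types):

declare <vscale x 1 x i64> @llvm.riscv.vmv.v.x.nxv1i64.i64(<vscale x 1 x i64>, i64, i64)

define <vscale x 1 x i64> @splat_with_vl(ptr %p, i64 %vl) {
  %s = load i64, ptr %p
  ; Broadcast %s into the first %vl lanes; unlike a plain IR splat, the VL is
  ; carried on the intrinsic itself.
  %splat = call <vscale x 1 x i64> @llvm.riscv.vmv.v.x.nxv1i64.i64(<vscale x 1 x i64> poison, i64 %s, i64 %vl)
  ret <vscale x 1 x i64> %splat
}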

lukel97 (Contributor, Author) commented Jul 10, 2024

Superseded by #98140

lukel97 closed this Jul 10, 2024
yetingk added a commit that referenced this pull request Jul 11, 2024
It's a similar patch to a214c52, but for
vp.strided.load. Some targets prefer the pattern (vmv.v.x (load)) over a
vlse with zero stride.

It's the IR version of #97798.