[RISCV] Convert vector.reduce.or + cttz.elts to vfirst #175952
base: main
Conversation
@llvm/pr-subscribers-backend-risc-v

Author: Shih-Po Hung (arcbbb)

Changes

Add a pattern match in RISCVCodeGenPrepare to recognize the combination of the vector.reduce.or and cttz.elts intrinsics and replace it with the native riscv.vfirst.mask intrinsic. The vfirst instruction returns the index of the first set element (or -1 if none), which natively provides both:
- the boolean result of reduce.or (vfirst >= 0 means an element was found), and
- the index computed by cttz.elts.

This eliminates redundant scalar-vector-scalar conversions during instruction selection.
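For illustration only (not part of the patch): a minimal IR sketch of the equivalence the pass exploits. The function and value names here are made up; the intrinsic signature mirrors the call emitted in the test below.

declare i64 @llvm.riscv.vfirst.mask.nxv16i1.i64(<vscale x 16 x i1>, <vscale x 16 x i1>, i64)

define i64 @vfirst_sketch(<vscale x 16 x i1> %m, <vscale x 16 x i1> %gate, i64 %evl) {
  %vfirst = call i64 @llvm.riscv.vfirst.mask.nxv16i1.i64(<vscale x 16 x i1> %m, <vscale x 16 x i1> %gate, i64 %evl)
  ; reduce.or equivalent: vfirst returns -1 when no active element is set
  %found = icmp sge i64 %vfirst, 0
  ; cttz.elts equivalent: when %found is true, %vfirst is already the index
  ; of the first set element, so no separate scan is needed
  %idx = select i1 %found, i64 %vfirst, i64 -1
  ret i64 %idx
}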
This depends on #151300.

Full diff: https://github.com/llvm/llvm-project/pull/175952.diff

2 Files Affected:
diff --git a/llvm/lib/Target/RISCV/RISCVCodeGenPrepare.cpp b/llvm/lib/Target/RISCV/RISCVCodeGenPrepare.cpp
index 1ee4c66a5bde5..9163339ce03f6 100644
--- a/llvm/lib/Target/RISCV/RISCVCodeGenPrepare.cpp
+++ b/llvm/lib/Target/RISCV/RISCVCodeGenPrepare.cpp
@@ -22,6 +22,7 @@
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/InstVisitor.h"
#include "llvm/IR/Intrinsics.h"
+#include "llvm/IR/IntrinsicsRISCV.h"
#include "llvm/IR/PatternMatch.h"
#include "llvm/InitializePasses.h"
#include "llvm/Pass.h"
@@ -48,6 +49,7 @@ class RISCVCodeGenPrepare : public InstVisitor<RISCVCodeGenPrepare, bool> {
bool visitAnd(BinaryOperator &BO);
bool visitIntrinsicInst(IntrinsicInst &I);
bool expandVPStrideLoad(IntrinsicInst &I);
+ bool convertVFirstPattern(IntrinsicInst &I);
bool widenVPMerge(IntrinsicInst &I);
};
} // namespace
@@ -213,6 +215,9 @@ bool RISCVCodeGenPrepare::widenVPMerge(IntrinsicInst &II) {
// Which eliminates the scalar -> vector -> scalar crossing during instruction
// selection.
bool RISCVCodeGenPrepare::visitIntrinsicInst(IntrinsicInst &I) {
+ if (convertVFirstPattern(I))
+ return true;
+
if (expandVPStrideLoad(I))
return true;
@@ -281,6 +286,99 @@ bool RISCVCodeGenPrepare::expandVPStrideLoad(IntrinsicInst &II) {
return true;
}
+// Convert vector.reduce.or + cttz.elts into riscv.vfirst.
+//
+// The RISC-V vfirst instruction natively provides the functionality of both
+// vector.reduce.or (checking if any element is set) and cttz.elts (finding
+// the first set element). This function matches the following pattern and
+// replaces it with a single vfirst intrinsic:
+//
+// Before:
+// block1:
+// %ffload = call {<vTy>, i32} @llvm.vp.load.ff(ptr, <mask>, i32)
+// %evl = extractvalue %ffload, 1
+// %alm = call @llvm.get.active.lane.mask(0, %evl)
+// %cond = ...
+// %select = select %alm, %cond, zeroinitializer
+// %reduce = call @llvm.vector.reduce.or(%select)
+// br i1 %reduce, label %early.exit, label %continue
+// early.exit:
+// %idx = call @llvm.experimental.cttz.elts(%cond)
+// ...
+//
+// After:
+// block1:
+// %vfirst = call @llvm.riscv.vfirst.mask(%cond, %mask, %evl)
+// %found = icmp sge %vfirst, 0
+// br i1 %found, label %early.exit, label %continue
+// early.exit:
+// ; uses of cttz.elts replaced with %vfirst
+// ...
+bool RISCVCodeGenPrepare::convertVFirstPattern(IntrinsicInst &II) {
+ using namespace PatternMatch;
+ Value *Select, *ALM, *Cond, *EVL, *FFLoad, *Mask;
+
+ // Match the reduce.or pattern with freeze, select, active lane mask,
+ // and vp.load.ff.
+ bool MatchReduceOr =
+ match(&II, m_Intrinsic<Intrinsic::vector_reduce_or>(
+ m_Freeze(m_Value(Select)))) &&
+ match(Select, m_Select(m_Value(ALM), m_Value(Cond), m_Zero())) &&
+ Cond->getNumUses() == 2 &&
+ match(ALM, m_Intrinsic<Intrinsic::get_active_lane_mask>(
+ m_Zero(), m_ZExtOrSelf(m_Value(EVL)))) &&
+ match(EVL, m_ExtractValue<1>(m_Value(FFLoad))) &&
+ match(FFLoad, m_Intrinsic<Intrinsic::vp_load_ff>(m_Value(), m_Value(Mask),
+ m_Value()));
+ if (!MatchReduceOr)
+ return false;
+
+ // Find the cttz.elts user of Cond.
+ IntrinsicInst *CttzElts = nullptr;
+ for (User *U : Cond->users()) {
+ if (auto *Intr = dyn_cast<IntrinsicInst>(U)) {
+ if (Intr->getIntrinsicID() == Intrinsic::experimental_cttz_elts) {
+ CttzElts = Intr;
+ break;
+ }
+ }
+ }
+ if (!CttzElts)
+ return false;
+
+ // Verify that cttz.elts is in a block whose single predecessor branches
+ // on the reduce.or result.
+ BasicBlock *CttzBB = CttzElts->getParent();
+ BasicBlock *PredBB = CttzBB->getSinglePredecessor();
+ if (!PredBB)
+ return false;
+ auto *BI = dyn_cast<BranchInst>(PredBB->getTerminator());
+ if (!BI || !BI->isConditional() || BI->getCondition() != &II)
+ return false;
+
+ // Generate the vfirst intrinsic and replacement instructions.
+ IRBuilder<> Builder(&II);
+ Type *XLenTy = IntegerType::get(II.getContext(), ST->getXLen());
+ if (EVL->getType() != XLenTy)
+ EVL = Builder.CreateZExt(EVL, XLenTy);
+
+ Value *VFirst = Builder.CreateIntrinsic(
+ Intrinsic::riscv_vfirst_mask, {Cond->getType(), XLenTy}, {Cond, Mask, EVL});
+
+ // Replace reduce.or with (icmp sge (vfirst), 0)
+ // vfirst returns -1 if no element is set.
+ Value *Found = Builder.CreateICmpSGE(VFirst, ConstantInt::get(XLenTy, 0));
+ II.replaceAllUsesWith(Found);
+
+ // Replace cttz.elts with the vfirst result (with appropriate type adjustment).
+ Value *VFirstCasted = Builder.CreateZExtOrTrunc(VFirst, CttzElts->getType());
+ CttzElts->replaceAllUsesWith(VFirstCasted);
+
+ II.eraseFromParent();
+ CttzElts->eraseFromParent();
+ return true;
+}
+
bool RISCVCodeGenPrepare::run() {
bool MadeChange = false;
for (auto &BB : F)
diff --git a/llvm/test/CodeGen/RISCV/rvv/riscv-codegenprepare.ll b/llvm/test/CodeGen/RISCV/rvv/riscv-codegenprepare.ll
index ffbcb65c40c33..722711117c8d8 100644
--- a/llvm/test/CodeGen/RISCV/rvv/riscv-codegenprepare.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/riscv-codegenprepare.ll
@@ -3,7 +3,7 @@
define float @reduce_fadd(ptr %f) {
; CHECK-LABEL: define float @reduce_fadd(
-; CHECK-SAME: ptr [[F:%.*]]) #[[ATTR2:[0-9]+]] {
+; CHECK-SAME: ptr [[F:%.*]]) #[[ATTR0:[0-9]+]] {
; CHECK-NEXT: entry:
; CHECK-NEXT: [[VSCALE:%.*]] = tail call i64 @llvm.vscale.i64()
; CHECK-NEXT: [[VECSIZE:%.*]] = shl nuw nsw i64 [[VSCALE]], 2
@@ -44,7 +44,7 @@ exit:
define i32 @vp_reduce_add(ptr %a) {
; CHECK-LABEL: define i32 @vp_reduce_add(
-; CHECK-SAME: ptr [[A:%.*]]) #[[ATTR2]] {
+; CHECK-SAME: ptr [[A:%.*]]) #[[ATTR0]] {
; CHECK-NEXT: entry:
; CHECK-NEXT: br label [[VECTOR_BODY:%.*]]
; CHECK: vector.body:
@@ -88,7 +88,7 @@ for.cond.cleanup: ; preds = %vector.body
define i32 @vp_reduce_and(ptr %a) {
; CHECK-LABEL: define i32 @vp_reduce_and(
-; CHECK-SAME: ptr [[A:%.*]]) #[[ATTR2]] {
+; CHECK-SAME: ptr [[A:%.*]]) #[[ATTR0]] {
; CHECK-NEXT: entry:
; CHECK-NEXT: br label [[VECTOR_BODY:%.*]]
; CHECK: vector.body:
@@ -132,7 +132,7 @@ for.cond.cleanup: ; preds = %vector.body
define i32 @vp_reduce_or(ptr %a) {
; CHECK-LABEL: define i32 @vp_reduce_or(
-; CHECK-SAME: ptr [[A:%.*]]) #[[ATTR2]] {
+; CHECK-SAME: ptr [[A:%.*]]) #[[ATTR0]] {
; CHECK-NEXT: entry:
; CHECK-NEXT: br label [[VECTOR_BODY:%.*]]
; CHECK: vector.body:
@@ -176,7 +176,7 @@ for.cond.cleanup: ; preds = %vector.body
define i32 @vp_reduce_xor(ptr %a) {
; CHECK-LABEL: define i32 @vp_reduce_xor(
-; CHECK-SAME: ptr [[A:%.*]]) #[[ATTR2]] {
+; CHECK-SAME: ptr [[A:%.*]]) #[[ATTR0]] {
; CHECK-NEXT: entry:
; CHECK-NEXT: br label [[VECTOR_BODY:%.*]]
; CHECK: vector.body:
@@ -220,7 +220,7 @@ for.cond.cleanup: ; preds = %vector.body
define i32 @vp_reduce_smax(ptr %a) {
; CHECK-LABEL: define i32 @vp_reduce_smax(
-; CHECK-SAME: ptr [[A:%.*]]) #[[ATTR2]] {
+; CHECK-SAME: ptr [[A:%.*]]) #[[ATTR0]] {
; CHECK-NEXT: entry:
; CHECK-NEXT: br label [[VECTOR_BODY:%.*]]
; CHECK: vector.body:
@@ -264,7 +264,7 @@ for.cond.cleanup: ; preds = %vector.body
define i32 @vp_reduce_smin(ptr %a) {
; CHECK-LABEL: define i32 @vp_reduce_smin(
-; CHECK-SAME: ptr [[A:%.*]]) #[[ATTR2]] {
+; CHECK-SAME: ptr [[A:%.*]]) #[[ATTR0]] {
; CHECK-NEXT: entry:
; CHECK-NEXT: br label [[VECTOR_BODY:%.*]]
; CHECK: vector.body:
@@ -308,7 +308,7 @@ for.cond.cleanup: ; preds = %vector.body
define i32 @vp_reduce_umax(ptr %a) {
; CHECK-LABEL: define i32 @vp_reduce_umax(
-; CHECK-SAME: ptr [[A:%.*]]) #[[ATTR2]] {
+; CHECK-SAME: ptr [[A:%.*]]) #[[ATTR0]] {
; CHECK-NEXT: entry:
; CHECK-NEXT: br label [[VECTOR_BODY:%.*]]
; CHECK: vector.body:
@@ -352,7 +352,7 @@ for.cond.cleanup: ; preds = %vector.body
define i32 @vp_reduce_umin(ptr %a) {
; CHECK-LABEL: define i32 @vp_reduce_umin(
-; CHECK-SAME: ptr [[A:%.*]]) #[[ATTR2]] {
+; CHECK-SAME: ptr [[A:%.*]]) #[[ATTR0]] {
; CHECK-NEXT: entry:
; CHECK-NEXT: br label [[VECTOR_BODY:%.*]]
; CHECK: vector.body:
@@ -396,7 +396,7 @@ for.cond.cleanup: ; preds = %vector.body
define float @vp_reduce_fadd(ptr %a) {
; CHECK-LABEL: define float @vp_reduce_fadd(
-; CHECK-SAME: ptr [[A:%.*]]) #[[ATTR2]] {
+; CHECK-SAME: ptr [[A:%.*]]) #[[ATTR0]] {
; CHECK-NEXT: entry:
; CHECK-NEXT: br label [[VECTOR_BODY:%.*]]
; CHECK: vector.body:
@@ -440,7 +440,7 @@ for.cond.cleanup: ; preds = %vector.body
define float @vp_reduce_fmax(ptr %a) {
; CHECK-LABEL: define float @vp_reduce_fmax(
-; CHECK-SAME: ptr [[A:%.*]]) #[[ATTR2]] {
+; CHECK-SAME: ptr [[A:%.*]]) #[[ATTR0]] {
; CHECK-NEXT: entry:
; CHECK-NEXT: br label [[VECTOR_BODY:%.*]]
; CHECK: vector.body:
@@ -484,7 +484,7 @@ for.cond.cleanup: ; preds = %vector.body
define float @vp_reduce_fmin(ptr %a) {
; CHECK-LABEL: define float @vp_reduce_fmin(
-; CHECK-SAME: ptr [[A:%.*]]) #[[ATTR2]] {
+; CHECK-SAME: ptr [[A:%.*]]) #[[ATTR0]] {
; CHECK-NEXT: entry:
; CHECK-NEXT: br label [[VECTOR_BODY:%.*]]
; CHECK: vector.body:
@@ -525,3 +525,57 @@ vector.body: ; preds = %vector.body, %entry
for.cond.cleanup: ; preds = %vector.body
ret float %red
}
+
+define i64 @vfirst_use1(ptr %src, <vscale x 16 x i1> %mask, i32 %avl, i8 %value, i64 %index) {
+; CHECK-LABEL: define i64 @vfirst_use1(
+; CHECK-SAME: ptr [[SRC:%.*]], <vscale x 16 x i1> [[MASK:%.*]], i32 [[AVL:%.*]], i8 [[VALUE:%.*]], i64 [[INDEX:%.*]]) #[[ATTR0]] {
+; CHECK-NEXT: entry:
+; CHECK-NEXT: [[BROADCAST_SPLATINSERT:%.*]] = insertelement <vscale x 16 x i8> poison, i8 [[VALUE]], i64 0
+; CHECK-NEXT: [[BROADCAST_SPLAT:%.*]] = shufflevector <vscale x 16 x i8> [[BROADCAST_SPLATINSERT]], <vscale x 16 x i8> poison, <vscale x 16 x i32> zeroinitializer
+; CHECK-NEXT: [[FFLOAD:%.*]] = call { <vscale x 16 x i8>, i32 } @llvm.vp.load.ff.nxv16i8.p0(ptr [[SRC]], <vscale x 16 x i1> [[MASK]], i32 [[AVL]])
+; CHECK-NEXT: [[EVL:%.*]] = extractvalue { <vscale x 16 x i8>, i32 } [[FFLOAD]], 1
+; CHECK-NEXT: [[EVL64:%.*]] = zext i32 [[EVL]] to i64
+; CHECK-NEXT: [[ALM:%.*]] = call <vscale x 16 x i1> @llvm.get.active.lane.mask.nxv16i1.i64(i64 0, i64 [[EVL64]])
+; CHECK-NEXT: [[DATA:%.*]] = extractvalue { <vscale x 16 x i8>, i32 } [[FFLOAD]], 0
+; CHECK-NEXT: [[EEMASK:%.*]] = icmp eq <vscale x 16 x i8> [[DATA]], [[BROADCAST_SPLAT]]
+; CHECK-NEXT: [[SEL:%.*]] = select <vscale x 16 x i1> [[ALM]], <vscale x 16 x i1> [[EEMASK]], <vscale x 16 x i1> zeroinitializer
+; CHECK-NEXT: [[SEL2:%.*]] = freeze <vscale x 16 x i1> [[SEL]]
+; CHECK-NEXT: [[TMP0:%.*]] = zext i32 [[EVL]] to i64
+; CHECK-NEXT: [[TMP1:%.*]] = call i64 @llvm.riscv.vfirst.mask.nxv16i1.i64(<vscale x 16 x i1> [[EEMASK]], <vscale x 16 x i1> [[MASK]], i64 [[TMP0]])
+; CHECK-NEXT: [[TMP2:%.*]] = icmp sge i64 [[TMP1]], 0
+; CHECK-NEXT: br i1 [[TMP2]], label [[VECTOR_EARLY_EXIT:%.*]], label [[VECTOR_BODY_INTERIM:%.*]]
+; CHECK: vector.body.interim:
+; CHECK-NEXT: br label [[EXIT:%.*]]
+; CHECK: vector.early.exit:
+; CHECK-NEXT: [[TMP4:%.*]] = add i64 [[INDEX]], [[TMP1]]
+; CHECK-NEXT: br label [[EXIT]]
+; CHECK: exit:
+; CHECK-NEXT: [[RETVAL:%.*]] = phi i64 [ 0, [[VECTOR_BODY_INTERIM]] ], [ [[TMP4]], [[VECTOR_EARLY_EXIT]] ]
+; CHECK-NEXT: ret i64 [[RETVAL]]
+;
+entry:
+ %broadcast.splatinsert = insertelement <vscale x 16 x i8> poison, i8 %value, i64 0
+ %broadcast.splat = shufflevector <vscale x 16 x i8> %broadcast.splatinsert, <vscale x 16 x i8> poison, <vscale x 16 x i32> zeroinitializer
+ %ffload = call { <vscale x 16 x i8>, i32 } @llvm.vp.load.ff.nxv16i8.p0(ptr %src, <vscale x 16 x i1> %mask, i32 %avl)
+ %EVL = extractvalue { <vscale x 16 x i8>, i32 } %ffload, 1
+ %EVL64 = zext i32 %EVL to i64
+ %ALM = call <vscale x 16 x i1> @llvm.get.active.lane.mask.nxv16i1.i64(i64 0, i64 %EVL64)
+ %data = extractvalue { <vscale x 16 x i8>, i32 } %ffload, 0
+ %eemask = icmp eq <vscale x 16 x i8> %data, %broadcast.splat
+ %sel = select <vscale x 16 x i1> %ALM, <vscale x 16 x i1> %eemask, <vscale x 16 x i1> zeroinitializer
+ %sel2 = freeze <vscale x 16 x i1> %sel
+ %early.exit = call i1 @llvm.vector.reduce.or.nxv16i1(<vscale x 16 x i1> %sel2)
+ br i1 %early.exit, label %vector.early.exit, label %vector.body.interim
+
+vector.body.interim: ; preds = %vector.body
+ br label %exit
+
+vector.early.exit: ; preds = %vector.body
+ %19 = call i64 @llvm.experimental.cttz.elts.i64.nxv16i1(<vscale x 16 x i1> %eemask, i1 false)
+ %20 = add i64 %index, %19
+ br label %exit
+
+exit: ; preds = %vector.early.exit, %middle.block, %for.inc, %for.body
+ %retval = phi i64 [ 0, %vector.body.interim ], [ %20, %vector.early.exit ]
+ ret i64 %retval
+}
Force-pushed from db4c8a5 to d947310.
  // and vp.load.ff.
  bool MatchReduceOr =
      match(&II, m_Intrinsic<Intrinsic::vector_reduce_or>(
                     m_Freeze(m_Value(Select)))) &&
There's no freeze in the IR in the comment. Can you add it?
  %data = extractvalue { <vscale x 16 x i8>, i32 } %ffload, 0
  %eemask = icmp eq <vscale x 16 x i8> %data, %broadcast.splat
  %sel = select <vscale x 16 x i1> %ALM, <vscale x 16 x i1> %eemask, <vscale x 16 x i1> zeroinitializer
  %sel2 = freeze <vscale x 16 x i1> %sel
Why did the freeze exist before, and why don't we need it in the final output?
Freeze is always generated to freeze the input of vector.reduce.or in #154156.
For this case with vp.load.ff, the freeze is redundant since the active-lane-mask already gates the input.
Similarly, the riscv.vfirst replacement doesn't need freeze because %evl masks out inactive lanes.
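A minimal sketch of that gating (value names invented for illustration; intrinsic signatures taken from the test in this PR): inactive lanes of the select take the zeroinitializer arm, so any poison produced by the faulting load for lanes at or beyond the EVL never reaches the reduction.

declare <vscale x 16 x i1> @llvm.get.active.lane.mask.nxv16i1.i64(i64, i64)
declare i1 @llvm.vector.reduce.or.nxv16i1(<vscale x 16 x i1>)

define i1 @gated_reduce_sketch(<vscale x 16 x i1> %cmp, i32 %evl) {
  %evl64 = zext i32 %evl to i64
  ; active lanes are exactly [0, %evl)
  %alm = call <vscale x 16 x i1> @llvm.get.active.lane.mask.nxv16i1.i64(i64 0, i64 %evl64)
  ; per-lane select: inactive lanes pick zero, so poison in %cmp beyond %evl
  ; does not propagate into the reduction even without a freeze
  %sel = select <vscale x 16 x i1> %alm, <vscale x 16 x i1> %cmp, <vscale x 16 x i1> zeroinitializer
  %any = call i1 @llvm.vector.reduce.or.nxv16i1(<vscale x 16 x i1> %sel)
  ret i1 %any
}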
It looks like I still need the freeze for consistency.
Why does the vectorizer use get.active.lane.mask instead of using vp.reduce.or and llvm.vp.cttz.elts?
Good point. I was trying to avoid VP intrinsics for non-tail-folded cases, but …