diff --git a/llvm/lib/Target/RISCV/RISCVCodeGenPrepare.cpp b/llvm/lib/Target/RISCV/RISCVCodeGenPrepare.cpp
index 1ee4c66a5bde5..0493bf0f408d7 100644
--- a/llvm/lib/Target/RISCV/RISCVCodeGenPrepare.cpp
+++ b/llvm/lib/Target/RISCV/RISCVCodeGenPrepare.cpp
@@ -22,6 +22,7 @@
 #include "llvm/IR/IRBuilder.h"
 #include "llvm/IR/InstVisitor.h"
 #include "llvm/IR/Intrinsics.h"
+#include "llvm/IR/IntrinsicsRISCV.h"
 #include "llvm/IR/PatternMatch.h"
 #include "llvm/InitializePasses.h"
 #include "llvm/Pass.h"
@@ -48,6 +49,7 @@ class RISCVCodeGenPrepare : public InstVisitor<RISCVCodeGenPrepare> {
   bool visitAnd(BinaryOperator &BO);
   bool visitIntrinsicInst(IntrinsicInst &I);
   bool expandVPStrideLoad(IntrinsicInst &I);
+  bool convertVFirstPattern(IntrinsicInst &I);
   bool widenVPMerge(IntrinsicInst &I);
 };
 } // namespace
@@ -213,6 +215,9 @@ bool RISCVCodeGenPrepare::widenVPMerge(IntrinsicInst &II) {
 // Which eliminates the scalar -> vector -> scalar crossing during instruction
 // selection.
 bool RISCVCodeGenPrepare::visitIntrinsicInst(IntrinsicInst &I) {
+  if (convertVFirstPattern(I))
+    return true;
+
   if (expandVPStrideLoad(I))
     return true;
 
@@ -281,6 +286,100 @@ bool RISCVCodeGenPrepare::expandVPStrideLoad(IntrinsicInst &II) {
   return true;
 }
 
+// Convert vector.reduce.or + cttz.elts into riscv.vfirst.
+//
+// The RISC-V vfirst instruction natively provides the functionality of both
+// vector.reduce.or (checking if any element is set) and cttz.elts (finding
+// the first set element). This function matches the following pattern and
+// replaces it with a single vfirst intrinsic:
+//
+// Before:
+//   block1:
+//     %ffload = call { <vscale x 16 x i8>, i32 } @llvm.vp.load.ff(ptr, <vscale x 16 x i1>, i32)
+//     %evl = extractvalue { <vscale x 16 x i8>, i32 } %ffload, 1
+//     %alm = call <vscale x 16 x i1> @llvm.get.active.lane.mask(0, %evl)
+//     %cond = ...
+//     %select = select <vscale x 16 x i1> %alm, <vscale x 16 x i1> %cond, <vscale x 16 x i1> zeroinitializer
+//     %reduce = call i1 @llvm.vector.reduce.or(<vscale x 16 x i1> %select)
+//     br i1 %reduce, label %early.exit, label %continue
+//   early.exit:
+//     %idx = call i64 @llvm.experimental.cttz.elts(<vscale x 16 x i1> %cond)
+//     ...
+//
+// After:
+//   block1:
+//     %vfirst = call i64 @llvm.riscv.vfirst.mask(%cond, %mask, %evl)
+//     %found = icmp sge i64 %vfirst, 0
+//     br i1 %found, label %early.exit, label %continue
+//   early.exit:
+//     ; uses of cttz.elts replaced with %vfirst
+//     ...
+bool RISCVCodeGenPrepare::convertVFirstPattern(IntrinsicInst &II) {
+  using namespace PatternMatch;
+  Value *Select, *ALM, *Cond, *EVL, *FFLoad, *Mask;
+
+  // Match the reduce.or pattern with freeze, select, active lane mask,
+  // and vp.load.ff.
+  bool MatchReduceOr =
+      match(&II, m_Intrinsic<Intrinsic::vector_reduce_or>(
+                     m_Freeze(m_Value(Select)))) &&
+      match(Select, m_Select(m_Value(ALM), m_Value(Cond), m_Zero())) &&
+      Cond->getNumUses() == 2 &&
+      match(ALM, m_Intrinsic<Intrinsic::get_active_lane_mask>(
+                     m_Zero(), m_ZExtOrSelf(m_Value(EVL)))) &&
+      match(EVL, m_ExtractValue<1>(m_Value(FFLoad))) &&
+      match(FFLoad, m_Intrinsic<Intrinsic::vp_load_ff>(m_Value(), m_Value(Mask),
+                                                       m_Value()));
+  if (!MatchReduceOr)
+    return false;
+
+  // Find the cttz.elts user of Cond.
+  IntrinsicInst *CttzElts = nullptr;
+  for (User *U : Cond->users()) {
+    if (auto *Intr = dyn_cast<IntrinsicInst>(U)) {
+      if (Intr->getIntrinsicID() == Intrinsic::experimental_cttz_elts) {
+        CttzElts = Intr;
+        break;
+      }
+    }
+  }
+  if (!CttzElts)
+    return false;
+
+  // Verify that cttz.elts is in a block whose single predecessor branches
+  // on the reduce.or result.
+  BasicBlock *CttzBB = CttzElts->getParent();
+  BasicBlock *PredBB = CttzBB->getSinglePredecessor();
+  if (!PredBB)
+    return false;
+  auto *BI = dyn_cast<BranchInst>(PredBB->getTerminator());
+  if (!BI || !BI->isConditional() || BI->getCondition() != &II)
+    return false;
+
+  // Generate the vfirst intrinsic and replacement instructions.
+ IRBuilder<> Builder(&II); + Type *XLenTy = IntegerType::get(II.getContext(), ST->getXLen()); + if (EVL->getType() != XLenTy) + EVL = Builder.CreateZExt(EVL, XLenTy); + + Value *VFirst = + Builder.CreateIntrinsic(Intrinsic::riscv_vfirst_mask, + {Cond->getType(), XLenTy}, {Cond, Mask, EVL}); + + // Replace reduce.or with (icmp sge (vfirst), 0) + // vfirst returns -1 if no element is set. + Value *Found = Builder.CreateICmpSGE(VFirst, ConstantInt::get(XLenTy, 0)); + II.replaceAllUsesWith(Found); + + // Replace cttz.elts with the vfirst. + Value *VFirstCasted = Builder.CreateZExtOrTrunc(VFirst, CttzElts->getType()); + CttzElts->replaceAllUsesWith(VFirstCasted); + + II.eraseFromParent(); + CttzElts->eraseFromParent(); + return true; +} + bool RISCVCodeGenPrepare::run() { bool MadeChange = false; for (auto &BB : F) diff --git a/llvm/test/CodeGen/RISCV/rvv/riscv-codegenprepare.ll b/llvm/test/CodeGen/RISCV/rvv/riscv-codegenprepare.ll index ffbcb65c40c33..722711117c8d8 100644 --- a/llvm/test/CodeGen/RISCV/rvv/riscv-codegenprepare.ll +++ b/llvm/test/CodeGen/RISCV/rvv/riscv-codegenprepare.ll @@ -3,7 +3,7 @@ define float @reduce_fadd(ptr %f) { ; CHECK-LABEL: define float @reduce_fadd( -; CHECK-SAME: ptr [[F:%.*]]) #[[ATTR2:[0-9]+]] { +; CHECK-SAME: ptr [[F:%.*]]) #[[ATTR0:[0-9]+]] { ; CHECK-NEXT: entry: ; CHECK-NEXT: [[VSCALE:%.*]] = tail call i64 @llvm.vscale.i64() ; CHECK-NEXT: [[VECSIZE:%.*]] = shl nuw nsw i64 [[VSCALE]], 2 @@ -44,7 +44,7 @@ exit: define i32 @vp_reduce_add(ptr %a) { ; CHECK-LABEL: define i32 @vp_reduce_add( -; CHECK-SAME: ptr [[A:%.*]]) #[[ATTR2]] { +; CHECK-SAME: ptr [[A:%.*]]) #[[ATTR0]] { ; CHECK-NEXT: entry: ; CHECK-NEXT: br label [[VECTOR_BODY:%.*]] ; CHECK: vector.body: @@ -88,7 +88,7 @@ for.cond.cleanup: ; preds = %vector.body define i32 @vp_reduce_and(ptr %a) { ; CHECK-LABEL: define i32 @vp_reduce_and( -; CHECK-SAME: ptr [[A:%.*]]) #[[ATTR2]] { +; CHECK-SAME: ptr [[A:%.*]]) #[[ATTR0]] { ; CHECK-NEXT: entry: ; CHECK-NEXT: br label 
[[VECTOR_BODY:%.*]] ; CHECK: vector.body: @@ -132,7 +132,7 @@ for.cond.cleanup: ; preds = %vector.body define i32 @vp_reduce_or(ptr %a) { ; CHECK-LABEL: define i32 @vp_reduce_or( -; CHECK-SAME: ptr [[A:%.*]]) #[[ATTR2]] { +; CHECK-SAME: ptr [[A:%.*]]) #[[ATTR0]] { ; CHECK-NEXT: entry: ; CHECK-NEXT: br label [[VECTOR_BODY:%.*]] ; CHECK: vector.body: @@ -176,7 +176,7 @@ for.cond.cleanup: ; preds = %vector.body define i32 @vp_reduce_xor(ptr %a) { ; CHECK-LABEL: define i32 @vp_reduce_xor( -; CHECK-SAME: ptr [[A:%.*]]) #[[ATTR2]] { +; CHECK-SAME: ptr [[A:%.*]]) #[[ATTR0]] { ; CHECK-NEXT: entry: ; CHECK-NEXT: br label [[VECTOR_BODY:%.*]] ; CHECK: vector.body: @@ -220,7 +220,7 @@ for.cond.cleanup: ; preds = %vector.body define i32 @vp_reduce_smax(ptr %a) { ; CHECK-LABEL: define i32 @vp_reduce_smax( -; CHECK-SAME: ptr [[A:%.*]]) #[[ATTR2]] { +; CHECK-SAME: ptr [[A:%.*]]) #[[ATTR0]] { ; CHECK-NEXT: entry: ; CHECK-NEXT: br label [[VECTOR_BODY:%.*]] ; CHECK: vector.body: @@ -264,7 +264,7 @@ for.cond.cleanup: ; preds = %vector.body define i32 @vp_reduce_smin(ptr %a) { ; CHECK-LABEL: define i32 @vp_reduce_smin( -; CHECK-SAME: ptr [[A:%.*]]) #[[ATTR2]] { +; CHECK-SAME: ptr [[A:%.*]]) #[[ATTR0]] { ; CHECK-NEXT: entry: ; CHECK-NEXT: br label [[VECTOR_BODY:%.*]] ; CHECK: vector.body: @@ -308,7 +308,7 @@ for.cond.cleanup: ; preds = %vector.body define i32 @vp_reduce_umax(ptr %a) { ; CHECK-LABEL: define i32 @vp_reduce_umax( -; CHECK-SAME: ptr [[A:%.*]]) #[[ATTR2]] { +; CHECK-SAME: ptr [[A:%.*]]) #[[ATTR0]] { ; CHECK-NEXT: entry: ; CHECK-NEXT: br label [[VECTOR_BODY:%.*]] ; CHECK: vector.body: @@ -352,7 +352,7 @@ for.cond.cleanup: ; preds = %vector.body define i32 @vp_reduce_umin(ptr %a) { ; CHECK-LABEL: define i32 @vp_reduce_umin( -; CHECK-SAME: ptr [[A:%.*]]) #[[ATTR2]] { +; CHECK-SAME: ptr [[A:%.*]]) #[[ATTR0]] { ; CHECK-NEXT: entry: ; CHECK-NEXT: br label [[VECTOR_BODY:%.*]] ; CHECK: vector.body: @@ -396,7 +396,7 @@ for.cond.cleanup: ; preds = %vector.body define float 
@vp_reduce_fadd(ptr %a) {
 ; CHECK-LABEL: define float @vp_reduce_fadd(
-; CHECK-SAME: ptr [[A:%.*]]) #[[ATTR2]] {
+; CHECK-SAME: ptr [[A:%.*]]) #[[ATTR0]] {
 ; CHECK-NEXT:  entry:
 ; CHECK-NEXT:    br label [[VECTOR_BODY:%.*]]
 ; CHECK:       vector.body:
@@ -440,7 +440,7 @@ for.cond.cleanup: ; preds = %vector.body
 
 define float @vp_reduce_fmax(ptr %a) {
 ; CHECK-LABEL: define float @vp_reduce_fmax(
-; CHECK-SAME: ptr [[A:%.*]]) #[[ATTR2]] {
+; CHECK-SAME: ptr [[A:%.*]]) #[[ATTR0]] {
 ; CHECK-NEXT:  entry:
 ; CHECK-NEXT:    br label [[VECTOR_BODY:%.*]]
 ; CHECK:       vector.body:
@@ -484,7 +484,7 @@ for.cond.cleanup: ; preds = %vector.body
 
 define float @vp_reduce_fmin(ptr %a) {
 ; CHECK-LABEL: define float @vp_reduce_fmin(
-; CHECK-SAME: ptr [[A:%.*]]) #[[ATTR2]] {
+; CHECK-SAME: ptr [[A:%.*]]) #[[ATTR0]] {
 ; CHECK-NEXT:  entry:
 ; CHECK-NEXT:    br label [[VECTOR_BODY:%.*]]
 ; CHECK:       vector.body:
@@ -525,3 +525,57 @@ vector.body: ; preds = %vector.body, %entry
 for.cond.cleanup: ; preds = %vector.body
   ret float %red
 }
+
+define i64 @vfirst_use1(ptr %src, <vscale x 16 x i1> %mask, i32 %avl, i8 %value, i64 %index) {
+; CHECK-LABEL: define i64 @vfirst_use1(
+; CHECK-SAME: ptr [[SRC:%.*]], <vscale x 16 x i1> [[MASK:%.*]], i32 [[AVL:%.*]], i8 [[VALUE:%.*]], i64 [[INDEX:%.*]]) #[[ATTR0]] {
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[BROADCAST_SPLATINSERT:%.*]] = insertelement <vscale x 16 x i8> poison, i8 [[VALUE]], i64 0
+; CHECK-NEXT:    [[BROADCAST_SPLAT:%.*]] = shufflevector <vscale x 16 x i8> [[BROADCAST_SPLATINSERT]], <vscale x 16 x i8> poison, <vscale x 16 x i32> zeroinitializer
+; CHECK-NEXT:    [[FFLOAD:%.*]] = call { <vscale x 16 x i8>, i32 } @llvm.vp.load.ff.nxv16i8.p0(ptr [[SRC]], <vscale x 16 x i1> [[MASK]], i32 [[AVL]])
+; CHECK-NEXT:    [[EVL:%.*]] = extractvalue { <vscale x 16 x i8>, i32 } [[FFLOAD]], 1
+; CHECK-NEXT:    [[EVL64:%.*]] = zext i32 [[EVL]] to i64
+; CHECK-NEXT:    [[ALM:%.*]] = call <vscale x 16 x i1> @llvm.get.active.lane.mask.nxv16i1.i64(i64 0, i64 [[EVL64]])
+; CHECK-NEXT:    [[DATA:%.*]] = extractvalue { <vscale x 16 x i8>, i32 } [[FFLOAD]], 0
+; CHECK-NEXT:    [[EEMASK:%.*]] = icmp eq <vscale x 16 x i8> [[DATA]], [[BROADCAST_SPLAT]]
+; CHECK-NEXT:    [[SEL:%.*]] = select <vscale x 16 x i1> [[ALM]], <vscale x 16 x i1> [[EEMASK]], <vscale x 16 x i1> zeroinitializer
+; 
CHECK-NEXT:    [[SEL2:%.*]] = freeze <vscale x 16 x i1> [[SEL]]
+; CHECK-NEXT:    [[TMP0:%.*]] = zext i32 [[EVL]] to i64
+; CHECK-NEXT:    [[TMP1:%.*]] = call i64 @llvm.riscv.vfirst.mask.nxv16i1.i64(<vscale x 16 x i1> [[EEMASK]], <vscale x 16 x i1> [[MASK]], i64 [[TMP0]])
+; CHECK-NEXT:    [[TMP2:%.*]] = icmp sge i64 [[TMP1]], 0
+; CHECK-NEXT:    br i1 [[TMP2]], label [[VECTOR_EARLY_EXIT:%.*]], label [[VECTOR_BODY_INTERIM:%.*]]
+; CHECK:       vector.body.interim:
+; CHECK-NEXT:    br label [[EXIT:%.*]]
+; CHECK:       vector.early.exit:
+; CHECK-NEXT:    [[TMP4:%.*]] = add i64 [[INDEX]], [[TMP1]]
+; CHECK-NEXT:    br label [[EXIT]]
+; CHECK:       exit:
+; CHECK-NEXT:    [[RETVAL:%.*]] = phi i64 [ 0, [[VECTOR_BODY_INTERIM]] ], [ [[TMP4]], [[VECTOR_EARLY_EXIT]] ]
+; CHECK-NEXT:    ret i64 [[RETVAL]]
+;
+entry:
+  %broadcast.splatinsert = insertelement <vscale x 16 x i8> poison, i8 %value, i64 0
+  %broadcast.splat = shufflevector <vscale x 16 x i8> %broadcast.splatinsert, <vscale x 16 x i8> poison, <vscale x 16 x i32> zeroinitializer
+  %ffload = call { <vscale x 16 x i8>, i32 } @llvm.vp.load.ff.nxv16i8.p0(ptr %src, <vscale x 16 x i1> %mask, i32 %avl)
+  %EVL = extractvalue { <vscale x 16 x i8>, i32 } %ffload, 1
+  %EVL64 = zext i32 %EVL to i64
+  %ALM = call <vscale x 16 x i1> @llvm.get.active.lane.mask.nxv16i1.i64(i64 0, i64 %EVL64)
+  %data = extractvalue { <vscale x 16 x i8>, i32 } %ffload, 0
+  %eemask = icmp eq <vscale x 16 x i8> %data, %broadcast.splat
+  %sel = select <vscale x 16 x i1> %ALM, <vscale x 16 x i1> %eemask, <vscale x 16 x i1> zeroinitializer
+  %sel2 = freeze <vscale x 16 x i1> %sel
+  %early.exit = call i1 @llvm.vector.reduce.or.nxv16i1(<vscale x 16 x i1> %sel2)
+  br i1 %early.exit, label %vector.early.exit, label %vector.body.interim
+
+vector.body.interim: ; preds = %vector.body
+  br label %exit
+
+vector.early.exit: ; preds = %vector.body
+  %19 = call i64 @llvm.experimental.cttz.elts.i64.nxv16i1(<vscale x 16 x i1> %eemask, i1 false)
+  %20 = add i64 %index, %19
+  br label %exit
+
+exit: ; preds = %vector.early.exit, %middle.block, %for.inc, %for.body
+  %retval = phi i64 [ 0, %vector.body.interim ], [ %20, %vector.early.exit ]
+  ret i64 %retval
+}