diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
index 3cf5e5fb3f818..3259200011616 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -7424,6 +7424,8 @@ DenseMap<const SCEV *, Value *> LoopVectorizationPlanner::executePlan(
   VPlanTransforms::removeDeadRecipes(BestVPlan);
 
   VPlanTransforms::convertToConcreteRecipes(BestVPlan);
+  // Convert the exit condition to AVLNext == 0 for EVL tail folded loops.
+  VPlanTransforms::convertEVLExitCond(BestVPlan);
   // Regions are dissolved after optimizing for VF and UF, which completely
   // removes unneeded loop regions first.
   VPlanTransforms::dissolveLoopRegions(BestVPlan);
diff --git a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
index d1d5870d78a03..a05c70fb05ff7 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
@@ -3266,22 +3266,6 @@ void VPlanTransforms::canonicalizeEVLLoops(VPlan &Plan) {
 
   VPBasicBlock *HeaderVPBB = EVLPhi->getParent();
   VPValue *EVLIncrement = EVLPhi->getBackedgeValue();
-  VPValue *AVL;
-  [[maybe_unused]] bool FoundAVL =
-      match(EVLIncrement,
-            m_c_Add(m_ZExtOrSelf(m_EVL(m_VPValue(AVL))), m_Specific(EVLPhi)));
-  assert(FoundAVL && "Didn't find AVL?");
-
-  // The AVL may be capped to a safe distance.
-  VPValue *SafeAVL;
-  if (match(AVL, m_Select(m_VPValue(), m_VPValue(SafeAVL), m_VPValue())))
-    AVL = SafeAVL;
-
-  VPValue *AVLNext;
-  [[maybe_unused]] bool FoundAVLNext =
-      match(AVL, m_VPInstruction<Instruction::PHI>(
-                     m_Specific(Plan.getTripCount()), m_VPValue(AVLNext)));
-  assert(FoundAVLNext && "Didn't find AVL backedge?");
 
   // Convert EVLPhi to concrete recipe.
   auto *ScalarR =
@@ -3302,27 +3286,57 @@ void VPlanTransforms::canonicalizeEVLLoops(VPlan &Plan) {
   VPRecipeBase *CanonicalIVIncrement = Backedge->getDefiningRecipe();
   CanonicalIVIncrement->eraseFromParent();
   CanonicalIV->eraseFromParent();
+}
+
+void VPlanTransforms::convertEVLExitCond(VPlan &Plan) {
+  VPRegionBlock *LoopRegion = Plan.getVectorLoopRegion();
+  // The canonical IV may not exist at this stage.
+  if (!LoopRegion ||
+      !isa<VPCanonicalIVPHIRecipe>(LoopRegion->getEntryBasicBlock()->front()))
+    return;
+  VPCanonicalIVPHIRecipe *CanIV = LoopRegion->getCanonicalIV();
+  if (std::next(CanIV->getIterator()) == CanIV->getParent()->end())
+    return;
+  // The EVL IV is always immediately after the canonical IV.
+  auto *EVLPhi =
+      dyn_cast_or_null<VPEVLBasedIVPHIRecipe>(std::next(CanIV->getIterator()));
+  if (!EVLPhi)
+    return;
+
+  // Bail if not an EVL tail folded loop.
+  VPValue *AVL;
+  if (!match(EVLPhi->getBackedgeValue(),
+             m_c_Add(m_ZExtOrSelf(m_EVL(m_VPValue(AVL))), m_Specific(EVLPhi))))
+    return;
+
+  // The AVL may be capped to a safe distance.
+  VPValue *SafeAVL;
+  if (match(AVL, m_Select(m_VPValue(), m_VPValue(SafeAVL), m_VPValue())))
+    AVL = SafeAVL;
+
+  VPValue *AVLNext;
+  [[maybe_unused]] bool FoundAVLNext =
+      match(AVL, m_VPInstruction<Instruction::PHI>(
+                     m_Specific(Plan.getTripCount()), m_VPValue(AVLNext)));
+  assert(FoundAVLNext && "Didn't find AVL backedge?");
 
-  // Replace the use of VectorTripCount in the latch-exiting block.
-  // Before: (branch-on-cond (icmp eq EVLIVInc, VectorTripCount))
-  // After: (branch-on-cond icmp eq AVLNext, 0)
-  VPBasicBlock *LatchExiting =
-      HeaderVPBB->getPredecessors()[1]->getEntryBasicBlock();
-  auto *LatchExitingBr = cast<VPInstruction>(LatchExiting->getTerminator());
-  if (match(LatchExitingBr, m_BranchOnCond(m_True())))
+  VPBasicBlock *Latch = LoopRegion->getExitingBasicBlock();
+  auto *LatchBr = cast<VPInstruction>(Latch->getTerminator());
+  if (match(LatchBr, m_BranchOnCond(m_True())))
     return;
-  assert(match(LatchExitingBr, m_BranchOnCond(m_SpecificCmp(
-                                   CmpInst::ICMP_EQ, m_VPValue(EVLIncrement),
-                                   m_Specific(&Plan.getVectorTripCount())))) &&
-         "Expected BranchOnCond with ICmp comparing EVL increment with vector "
-         "trip count");
+  assert(
+      match(LatchBr,
+            m_BranchOnCond(m_SpecificCmp(
+                CmpInst::ICMP_EQ, m_Specific(CanIV->getIncomingValue(1)),
+                m_Specific(&Plan.getVectorTripCount())))) &&
+      "Expected BranchOnCond with ICmp comparing CanIV increment with vector "
+      "trip count");
 
   Type *AVLTy = VPTypeAnalysis(Plan).inferScalarType(AVLNext);
-  VPBuilder Builder(LatchExitingBr);
-  LatchExitingBr->setOperand(0,
-                             Builder.createICmp(CmpInst::ICMP_EQ, AVLNext,
-                                                Plan.getConstantInt(AVLTy, 0)));
+  VPBuilder Builder(LatchBr);
+  LatchBr->setOperand(0, Builder.createICmp(CmpInst::ICMP_EQ, AVLNext,
+                                            Plan.getConstantInt(AVLTy, 0)));
 }
 
 void VPlanTransforms::replaceSymbolicStrides(
diff --git a/llvm/lib/Transforms/Vectorize/VPlanTransforms.h b/llvm/lib/Transforms/Vectorize/VPlanTransforms.h
index e0d09a099647a..b0248faa20afe 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanTransforms.h
+++ b/llvm/lib/Transforms/Vectorize/VPlanTransforms.h
@@ -301,6 +301,12 @@ struct VPlanTransforms {
                                          VPlan &Plan, VPBasicBlock *HeaderVPBB,
                                          VPBasicBlock *LatchVPBB);
 
+  /// Replaces the exit condition from
+  /// (branch-on-cond eq CanonicalIVInc, VectorTripCount)
+  /// to
+  /// (branch-on-cond eq AVLNext, 0)
+  static void convertEVLExitCond(VPlan &Plan);
+
   /// Replace loop regions with explicit CFG.
   static void dissolveLoopRegions(VPlan &Plan);
 
@@ -315,10 +321,6 @@ struct VPlanTransforms {
   /// variable vector lengths instead of fixed lengths. This transformation:
   ///  * Makes EVL-Phi concrete.
   //  * Removes CanonicalIV and increment.
-  ///  * Replaces the exit condition from
-  ///    (branch-on-count CanonicalIVInc, VectorTripCount)
-  ///    to
-  ///    (branch-on-cond eq AVLNext, 0)
   static void canonicalizeEVLLoops(VPlan &Plan);
 
   /// Lower abstract recipes to concrete ones, that can be codegen'd.
diff --git a/llvm/lib/Transforms/Vectorize/VPlanVerifier.cpp b/llvm/lib/Transforms/Vectorize/VPlanVerifier.cpp
index 2d6809d6f344e..5a260cf27416f 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanVerifier.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanVerifier.cpp
@@ -103,6 +103,12 @@ bool VPlanVerifier::verifyPhiRecipes(const VPBasicBlock *VPBB) {
       return false;
     }
 
+    if (isa<VPEVLBasedIVPHIRecipe>(RecipeI) &&
+        !isa_and_nonnull<VPCanonicalIVPHIRecipe>(std::prev(RecipeI))) {
+      errs() << "EVL based IV is not immediately after canonical IV\n";
+      return false;
+    }
+
     // Check if the recipe operands match the number of predecessors.
     // TODO Extend to other phi-like recipes.
     if (auto *PhiIRI = dyn_cast<VPIRPhi>(&*RecipeI)) {