diff --git a/llvm/include/llvm/Transforms/Vectorize/LoopVectorizationLegality.h b/llvm/include/llvm/Transforms/Vectorize/LoopVectorizationLegality.h index 72fda911962ad..3c5cf1ebe6ba2 100644 --- a/llvm/include/llvm/Transforms/Vectorize/LoopVectorizationLegality.h +++ b/llvm/include/llvm/Transforms/Vectorize/LoopVectorizationLegality.h @@ -389,27 +389,21 @@ class LoopVectorizationLegality { return LAI->getDepChecker().getMaxSafeVectorWidthInBits(); } - /// Returns true if the loop has an uncountable early exit, i.e. an + /// Returns true if the loop has exactly one uncountable early exit, i.e. an /// uncountable exit that isn't the latch block. - bool hasUncountableEarlyExit() const { return HasUncountableEarlyExit; } + bool hasUncountableEarlyExit() const { + return getUncountableEdge().has_value(); + } - /// Returns the uncountable early exiting block. + /// Returns the uncountable early exiting block, if there is exactly one. BasicBlock *getUncountableEarlyExitingBlock() const { - if (!HasUncountableEarlyExit) { - assert(getUncountableExitingBlocks().empty() && - "Expected no uncountable exiting blocks"); - return nullptr; - } - assert(getUncountableExitingBlocks().size() == 1 && - "Expected only a single uncountable exiting block"); - return getUncountableExitingBlocks()[0]; + return hasUncountableEarlyExit() ? getUncountableEdge()->first : nullptr; } - /// Returns the destination of an uncountable early exiting block. + /// Returns the destination of the uncountable early exiting block, if there + /// is exactly one. BasicBlock *getUncountableEarlyExitBlock() const { - assert(getUncountableExitBlocks().size() == 1 && - "Expected only a single uncountable exit block"); - return getUncountableExitBlocks()[0]; + return hasUncountableEarlyExit() ? getUncountableEdge()->second : nullptr; } /// Returns true if vector representation of the instruction \p I @@ -463,14 +457,11 @@ class LoopVectorizationLegality { return CountableExitingBlocks; } - /// Returns all the exiting blocks with an uncountable exit. - const SmallVector &getUncountableExitingBlocks() const { - return UncountableExitingBlocks; - } - - /// Returns all the exit blocks from uncountable exiting blocks. - SmallVector getUncountableExitBlocks() const { - return UncountableExitBlocks; + /// Returns the loop edge to an uncountable exit, or std::nullopt if there + /// isn't a single such edge. + std::optional> + getUncountableEdge() const { + return UncountableEdge; } private: @@ -654,18 +645,13 @@ class LoopVectorizationLegality { /// supported. bool StructVecCallFound = false; - /// Indicates whether this loop has an uncountable early exit, i.e. an - /// uncountable exiting block that is not the latch. - bool HasUncountableEarlyExit = false; - /// Keep track of all the countable and uncountable exiting blocks if /// the exact backedge taken count is not computable. SmallVector CountableExitingBlocks; - SmallVector UncountableExitingBlocks; - /// Keep track of the destinations of all uncountable exits if the - /// exact backedge taken count is not computable. - SmallVector UncountableExitBlocks; + /// Keep track of the loop edge to an uncountable exit, comprising a pair + /// of (Exiting, Exit) blocks, if there is exactly one early exit. + std::optional> UncountableEdge; }; } // namespace llvm diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorizationLegality.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorizationLegality.cpp index 406864a6793dc..e3599315e224f 100644 --- a/llvm/lib/Transforms/Vectorize/LoopVectorizationLegality.cpp +++ b/llvm/lib/Transforms/Vectorize/LoopVectorizationLegality.cpp @@ -1631,12 +1631,11 @@ bool LoopVectorizationLegality::isVectorizableEarlyExitLoop() { // Keep a record of all the exiting blocks. SmallVector Predicates; + std::optional> SingleUncountableEdge; for (BasicBlock *BB : ExitingBlocks) { const SCEV *EC = PSE.getSE()->getPredicatedExitCount(TheLoop, BB, &Predicates); if (isa(EC)) { - UncountableExitingBlocks.push_back(BB); - SmallVector Succs(successors(BB)); if (Succs.size() != 2) { reportVectorizationFailure( @@ -1653,7 +1652,16 @@ bool LoopVectorizationLegality::isVectorizableEarlyExitLoop() { assert(!TheLoop->contains(Succs[1])); ExitBlock = Succs[1]; } - UncountableExitBlocks.push_back(ExitBlock); + + if (SingleUncountableEdge) { + reportVectorizationFailure( + "Loop has too many uncountable exits", + "Cannot vectorize early exit loop with more than one early exit", + "TooManyUncountableEarlyExits", ORE, TheLoop); + return false; + } + + SingleUncountableEdge = {BB, ExitBlock}; } else CountableExitingBlocks.push_back(BB); } @@ -1663,19 +1671,15 @@ bool LoopVectorizationLegality::isVectorizableEarlyExitLoop() { // PSE.getSymbolicMaxBackedgeTakenCount() below. Predicates.clear(); - // We only support one uncountable early exit. - if (getUncountableExitingBlocks().size() != 1) { - reportVectorizationFailure( - "Loop has too many uncountable exits", - "Cannot vectorize early exit loop with more than one early exit", - "TooManyUncountableEarlyExits", ORE, TheLoop); + if (!SingleUncountableEdge) { + LLVM_DEBUG(dbgs() << "LV: Cound not find any uncountable exits"); return false; } // The only supported early exit loops so far are ones where the early // exiting block is a unique predecessor of the latch block. BasicBlock *LatchPredBB = LatchBB->getUniquePredecessor(); - if (LatchPredBB != getUncountableEarlyExitingBlock()) { + if (LatchPredBB != SingleUncountableEdge->first) { reportVectorizationFailure("Early exit is not the latch predecessor", "Cannot vectorize early exit loop", "EarlyExitNotLatchPredecessor", ORE, TheLoop); @@ -1728,7 +1732,7 @@ bool LoopVectorizationLegality::isVectorizableEarlyExitLoop() { } // The vectoriser cannot handle loads that occur after the early exit block. - assert(LatchBB->getUniquePredecessor() == getUncountableEarlyExitingBlock() && + assert(LatchBB->getUniquePredecessor() == SingleUncountableEdge->first && "Expected latch predecessor to be the early exiting block"); // TODO: Handle loops that may fault. @@ -1751,6 +1755,7 @@ bool LoopVectorizationLegality::isVectorizableEarlyExitLoop() { LLVM_DEBUG(dbgs() << "LV: Found an early exit loop with symbolic max " "backedge taken count: " << *SymbolicMaxBTC << '\n'); + UncountableEdge = SingleUncountableEdge; return true; } @@ -1812,7 +1817,6 @@ bool LoopVectorizationLegality::canVectorize(bool UseVPlanNativePath) { return false; } - HasUncountableEarlyExit = false; if (isa(PSE.getBackedgeTakenCount())) { if (TheLoop->getExitingBlock()) { reportVectorizationFailure("Cannot vectorize uncountable loop", @@ -1822,10 +1826,8 @@ bool LoopVectorizationLegality::canVectorize(bool UseVPlanNativePath) { else return false; } else { - HasUncountableEarlyExit = true; if (!isVectorizableEarlyExitLoop()) { - UncountableExitingBlocks.clear(); - HasUncountableEarlyExit = false; + UncountableEdge = std::nullopt; if (DoExtraAnalysis) Result = false; else