@@ -15510,6 +15510,79 @@ static const SCEV *getNextSCEVDivisibleByDivisor(const SCEV *Expr,
1551015510 return SE.getConstant(*ExprVal + DivisorVal - Rem);
1551115511}
1551215512
15513+ static bool collectDivisibilityInformation(
15514+ ICmpInst::Predicate Predicate, const SCEV *LHS, const SCEV *RHS,
15515+ DenseMap<const SCEV *, const SCEV *> &DivInfo,
15516+ DenseMap<const SCEV *, APInt> &Multiples, ScalarEvolution &SE) {
15517+ // If we have LHS == 0, check if LHS is computing a property of some unknown
15518+ // SCEV %v which we can rewrite %v to express explicitly.
15519+ if (Predicate != CmpInst::ICMP_EQ || !match(RHS, m_scev_Zero()))
15520+ return false;
15521+ // If LHS is A % B, i.e. A % B == 0, rewrite A to (A /u B) * B to
15522+ // explicitly express that.
15523+ const SCEVUnknown *URemLHS = nullptr;
15524+ const SCEV *URemRHS = nullptr;
15525+ if (!match(LHS, m_scev_URem(m_SCEVUnknown(URemLHS), m_SCEV(URemRHS), SE)))
15526+ return false;
15527+
15528+ const SCEV *Multiple =
15529+ SE.getMulExpr(SE.getUDivExpr(URemLHS, URemRHS), URemRHS);
15530+ DivInfo[URemLHS] = Multiple;
15531+ Multiples[URemLHS] = cast<SCEVConstant>(URemRHS)->getAPInt();
15532+ return true;
15533+ }
15534+
15535+ // Check if the condition is a divisibility guard (A % B == 0).
15536+ static bool isDivisibilityGuard(const SCEV *LHS, const SCEV *RHS,
15537+ ScalarEvolution &SE) {
15538+ const SCEV *X, *Y;
15539+ return match(LHS, m_scev_URem(m_SCEV(X), m_SCEV(Y), SE)) && RHS->isZero();
15540+ }
15541+
15542+ // Apply divisibility by \p Divisor on MinMaxExpr with constant values,
15543+ // recursively. This is done by aligning up/down the constant value to the
15544+ // Divisor.
15545+ static const SCEV *applyDivisibilityOnMinMaxExpr(const SCEV *MinMaxExpr,
15546+ const SCEV *Divisor,
15547+ ScalarEvolution &SE) {
15548+ // Return true if \p Expr is a MinMax SCEV expression with a non-negative
15549+ // constant operand. If so, return in \p SCTy the SCEV type and in \p RHS
15550+ // the non-constant operand and in \p LHS the constant operand.
15551+ auto IsMinMaxSCEVWithNonNegativeConstant =
15552+ [&](const SCEV *Expr, SCEVTypes &SCTy, const SCEV *&LHS,
15553+ const SCEV *&RHS) {
15554+ if (auto *MinMax = dyn_cast<SCEVMinMaxExpr>(Expr)) {
15555+ if (MinMax->getNumOperands() != 2)
15556+ return false;
15557+ if (auto *C = dyn_cast<SCEVConstant>(MinMax->getOperand(0))) {
15558+ if (C->getAPInt().isNegative())
15559+ return false;
15560+ SCTy = MinMax->getSCEVType();
15561+ LHS = MinMax->getOperand(0);
15562+ RHS = MinMax->getOperand(1);
15563+ return true;
15564+ }
15565+ }
15566+ return false;
15567+ };
15568+
15569+ const SCEV *MinMaxLHS = nullptr, *MinMaxRHS = nullptr;
15570+ SCEVTypes SCTy;
15571+ if (!IsMinMaxSCEVWithNonNegativeConstant(MinMaxExpr, SCTy, MinMaxLHS,
15572+ MinMaxRHS))
15573+ return MinMaxExpr;
15574+ auto IsMin = isa<SCEVSMinExpr>(MinMaxExpr) || isa<SCEVUMinExpr>(MinMaxExpr);
15575+ assert(SE.isKnownNonNegative(MinMaxLHS) && "Expected non-negative operand!");
15576+ auto *DivisibleExpr =
15577+ IsMin ? getPreviousSCEVDivisibleByDivisor(
15578+ MinMaxLHS, cast<SCEVConstant>(Divisor)->getAPInt(), SE)
15579+ : getNextSCEVDivisibleByDivisor(
15580+ MinMaxLHS, cast<SCEVConstant>(Divisor)->getAPInt(), SE);
15581+ SmallVector<const SCEV *> Ops = {
15582+ applyDivisibilityOnMinMaxExpr(MinMaxRHS, Divisor, SE), DivisibleExpr};
15583+ return SE.getMinMaxExpr(SCTy, Ops);
15584+ }
15585+
1551315586void ScalarEvolution::LoopGuards::collectFromBlock(
1551415587 ScalarEvolution &SE, ScalarEvolution::LoopGuards &Guards,
1551515588 const BasicBlock *Block, const BasicBlock *Pred,
@@ -15520,19 +15593,14 @@ void ScalarEvolution::LoopGuards::collectFromBlock(
1552015593 SmallVector<const SCEV *> ExprsToRewrite;
1552115594 auto CollectCondition = [&](ICmpInst::Predicate Predicate, const SCEV *LHS,
1552215595 const SCEV *RHS,
15523- DenseMap<const SCEV *, const SCEV *>
15524- &RewriteMap) {
15596+ DenseMap<const SCEV *, const SCEV *> &RewriteMap,
15597+ const DenseMap<const SCEV *, const SCEV *>
15598+ &DivInfo) {
1552515599 // WARNING: It is generally unsound to apply any wrap flags to the proposed
1552615600 // replacement SCEV which isn't directly implied by the structure of that
1552715601 // SCEV. In particular, using contextual facts to imply flags is *NOT*
1552815602 // legal. See the scoping rules for flags in the header to understand why.
1552915603
15530- // If LHS is a constant, apply information to the other expression.
15531- if (isa<SCEVConstant>(LHS)) {
15532- std::swap(LHS, RHS);
15533- Predicate = CmpInst::getSwappedPredicate(Predicate);
15534- }
15535-
1553615604 // Check for a condition of the form (-C1 + X < C2). InstCombine will
1553715605 // create this form when combining two checks of the form (X u< C2 + C1) and
1553815606 // (X >=u C1).
@@ -15565,67 +15633,6 @@ void ScalarEvolution::LoopGuards::collectFromBlock(
1556515633 if (MatchRangeCheckIdiom())
1556615634 return;
1556715635
15568- // Return true if \p Expr is a MinMax SCEV expression with a non-negative
15569- // constant operand. If so, return in \p SCTy the SCEV type and in \p RHS
15570- // the non-constant operand and in \p LHS the constant operand.
15571- auto IsMinMaxSCEVWithNonNegativeConstant =
15572- [&](const SCEV *Expr, SCEVTypes &SCTy, const SCEV *&LHS,
15573- const SCEV *&RHS) {
15574- const APInt *C;
15575- SCTy = Expr->getSCEVType();
15576- return match(Expr, m_scev_MinMax(m_SCEV(LHS), m_SCEV(RHS))) &&
15577- match(LHS, m_scev_APInt(C)) && C->isNonNegative();
15578- };
15579-
15580- // Apply divisibilty by \p Divisor on MinMaxExpr with constant values,
15581- // recursively. This is done by aligning up/down the constant value to the
15582- // Divisor.
15583- std::function<const SCEV *(const SCEV *, const SCEV *)>
15584- ApplyDivisibiltyOnMinMaxExpr = [&](const SCEV *MinMaxExpr,
15585- const SCEV *Divisor) {
15586- auto *ConstDivisor = dyn_cast<SCEVConstant>(Divisor);
15587- if (!ConstDivisor)
15588- return MinMaxExpr;
15589- const APInt &DivisorVal = ConstDivisor->getAPInt();
15590-
15591- const SCEV *MinMaxLHS = nullptr, *MinMaxRHS = nullptr;
15592- SCEVTypes SCTy;
15593- if (!IsMinMaxSCEVWithNonNegativeConstant(MinMaxExpr, SCTy, MinMaxLHS,
15594- MinMaxRHS))
15595- return MinMaxExpr;
15596- auto IsMin =
15597- isa<SCEVSMinExpr>(MinMaxExpr) || isa<SCEVUMinExpr>(MinMaxExpr);
15598- assert(SE.isKnownNonNegative(MinMaxLHS) &&
15599- "Expected non-negative operand!");
15600- auto *DivisibleExpr =
15601- IsMin
15602- ? getPreviousSCEVDivisibleByDivisor(MinMaxLHS, DivisorVal, SE)
15603- : getNextSCEVDivisibleByDivisor(MinMaxLHS, DivisorVal, SE);
15604- SmallVector<const SCEV *> Ops = {
15605- ApplyDivisibiltyOnMinMaxExpr(MinMaxRHS, Divisor), DivisibleExpr};
15606- return SE.getMinMaxExpr(SCTy, Ops);
15607- };
15608-
15609- // If we have LHS == 0, check if LHS is computing a property of some unknown
15610- // SCEV %v which we can rewrite %v to express explicitly.
15611- if (Predicate == CmpInst::ICMP_EQ && match(RHS, m_scev_Zero())) {
15612- // If LHS is A % B, i.e. A % B == 0, rewrite A to (A /u B) * B to
15613- // explicitly express that.
15614- const SCEVUnknown *URemLHS = nullptr;
15615- const SCEV *URemRHS = nullptr;
15616- if (match(LHS,
15617- m_scev_URem(m_SCEVUnknown(URemLHS), m_SCEV(URemRHS), SE))) {
15618- auto I = RewriteMap.find(URemLHS);
15619- const SCEV *RewrittenLHS = I != RewriteMap.end() ? I->second : URemLHS;
15620- RewrittenLHS = ApplyDivisibiltyOnMinMaxExpr(RewrittenLHS, URemRHS);
15621- const auto *Multiple =
15622- SE.getMulExpr(SE.getUDivExpr(RewrittenLHS, URemRHS), URemRHS);
15623- RewriteMap[URemLHS] = Multiple;
15624- ExprsToRewrite.push_back(URemLHS);
15625- return;
15626- }
15627- }
15628-
1562915636 // Do not apply information for constants or if RHS contains an AddRec.
1563015637 if (isa<SCEVConstant>(LHS) || SE.containsAddRecurrence(RHS))
1563115638 return;
@@ -15655,7 +15662,11 @@ void ScalarEvolution::LoopGuards::collectFromBlock(
1565515662 };
1565615663
1565715664 const SCEV *RewrittenLHS = GetMaybeRewritten(LHS);
15658- const APInt &DividesBy = SE.getConstantMultiple(RewrittenLHS);
15665+ // Apply divisibility information when computing the constant multiple.
15666+ LoopGuards DivGuards(SE);
15667+ DivGuards.RewriteMap = DivInfo;
15668+ const APInt &DividesBy =
15669+ SE.getConstantMultiple(DivGuards.rewrite(RewrittenLHS));
1565915670
1566015671 // Collect rewrites for LHS and its transitive operands based on the
1566115672 // condition.
@@ -15840,8 +15851,11 @@ void ScalarEvolution::LoopGuards::collectFromBlock(
1584015851
1584115852 // Now apply the information from the collected conditions to
1584215853 // Guards.RewriteMap. Conditions are processed in reverse order, so the
15843- // earliest conditions is processed first. This ensures the SCEVs with the
15854+ // earliest conditions is processed first, except guards with divisibility
15855+ // information, which are moved to the back. This ensures the SCEVs with the
1584415856 // shortest dependency chains are constructed first.
15857+ SmallVector<std::tuple<CmpInst::Predicate, const SCEV *, const SCEV *>>
15858+ GuardsToProcess;
1584515859 for (auto [Term, EnterIfTrue] : reverse(Terms)) {
1584615860 SmallVector<Value *, 8> Worklist;
1584715861 SmallPtrSet<Value *, 8> Visited;
@@ -15856,7 +15870,12 @@ void ScalarEvolution::LoopGuards::collectFromBlock(
1585615870 EnterIfTrue ? Cmp->getPredicate() : Cmp->getInversePredicate();
1585715871 const auto *LHS = SE.getSCEV(Cmp->getOperand(0));
1585815872 const auto *RHS = SE.getSCEV(Cmp->getOperand(1));
15859- CollectCondition(Predicate, LHS, RHS, Guards.RewriteMap);
15873+ // If LHS is a constant, apply information to the other expression.
15874+ if (isa<SCEVConstant>(LHS)) {
15875+ std::swap(LHS, RHS);
15876+ Predicate = CmpInst::getSwappedPredicate(Predicate);
15877+ }
15878+ GuardsToProcess.emplace_back(Predicate, LHS, RHS);
1586015879 continue;
1586115880 }
1586215881
@@ -15869,6 +15888,30 @@ void ScalarEvolution::LoopGuards::collectFromBlock(
1586915888 }
1587015889 }
1587115890
15891+ // Process divisibility guards in reverse order to populate DivInfo early.
15892+ DenseMap<const SCEV *, APInt> Multiples;
15893+ DenseMap<const SCEV *, const SCEV *> DivInfo;
15894+ for (const auto &[Predicate, LHS, RHS] : GuardsToProcess) {
15895+ if (!isDivisibilityGuard(LHS, RHS, SE))
15896+ continue;
15897+ collectDivisibilityInformation(Predicate, LHS, RHS, DivInfo, Multiples, SE);
15898+ }
15899+
15900+ for (const auto &[Predicate, LHS, RHS] : GuardsToProcess)
15901+ CollectCondition(Predicate, LHS, RHS, Guards.RewriteMap, DivInfo);
15902+
15903+ // Apply divisibility information last. This ensures it is applied to the
15904+ // outermost expression after other rewrites for the given value.
15905+ for (const auto &[K, V] : Multiples) {
15906+ const SCEV *DivisorSCEV = SE.getConstant(V);
15907+ Guards.RewriteMap[K] =
15908+ SE.getMulExpr(SE.getUDivExpr(applyDivisibilityOnMinMaxExpr(
15909+ Guards.rewrite(K), DivisorSCEV, SE),
15910+ DivisorSCEV),
15911+ DivisorSCEV);
15912+ ExprsToRewrite.push_back(K);
15913+ }
15914+
1587215915 // Let the rewriter preserve NUW/NSW flags if the unsigned/signed ranges of
1587315916 // the replacement expressions are contained in the ranges of the replaced
1587415917 // expressions.
0 commit comments