Limit the loop vectorizer's register-pressure filter to VFs that are only
reachable because MaxBandwidth is enabled.

  * LoopVectorizationCostModel::useMaxBandwidth(ElementCount) is replaced by
    shouldCalculateRegPressureForVF(ElementCount). It still requires
    useMaxBandwidth() for the VF's register kind, and additionally requires the
    VF to be known-greater than MaxPermissibleVFWithoutMaxBW (tracked
    separately for fixed and scalable VFs).
  * getMaximizedVFForTarget() records MaxPermissibleVFWithoutMaxBW from MaxVF
    before the useMaxBandwidth(RegKind) widening block runs.
  * VPRegisterUsage::exceedsMaxNumRegs() gains an optional OverrideMaxNumRegs
    parameter (default 0). When non-zero it is compared against every entry of
    MaxLocalUsers instead of TTI.getNumberOfRegisters(). Both call sites in
    LoopVectorize.cpp pass ForceTargetNumVectorRegs.
  * Adds AArch64/maxbandwidth-regpressure.ll, run with and without
    -force-target-num-vector-regs=1.

diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
index e7bae17dd2ceb..682b2f949bdd1 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -964,9 +964,8 @@ class LoopVectorizationCostModel {
   /// user options, for the given register kind.
   bool useMaxBandwidth(TargetTransformInfo::RegisterKind RegKind);
 
-  /// \return True if maximizing vector bandwidth is enabled by the target or
-  /// user options, for the given vector factor.
-  bool useMaxBandwidth(ElementCount VF);
+  /// \return True if register pressure should be calculated for the given VF.
+  bool shouldCalculateRegPressureForVF(ElementCount VF);
 
   /// \return The size (in bits) of the smallest and widest types in the code
   /// that needs to be vectorized. We ignore values that remain scalar such as
@@ -1753,6 +1752,9 @@ class LoopVectorizationCostModel {
   /// Whether this loop should be optimized for size based on function attribute
   /// or profile information.
   bool OptForSize;
+
+  /// The highest VF possible for this loop, without using MaxBandwidth.
+  FixedScalableVFPair MaxPermissibleVFWithoutMaxBW;
 };
 } // end namespace llvm
 
@@ -3943,10 +3945,16 @@ LoopVectorizationCostModel::computeMaxVF(ElementCount UserVF, unsigned UserIC) {
   return FixedScalableVFPair::getNone();
 }
 
-bool LoopVectorizationCostModel::useMaxBandwidth(ElementCount VF) {
-  return useMaxBandwidth(VF.isScalable()
-                             ? TargetTransformInfo::RGK_ScalableVector
-                             : TargetTransformInfo::RGK_FixedWidthVector);
+bool LoopVectorizationCostModel::shouldCalculateRegPressureForVF(
+    ElementCount VF) {
+  if (!useMaxBandwidth(VF.isScalable()
+                           ? TargetTransformInfo::RGK_ScalableVector
+                           : TargetTransformInfo::RGK_FixedWidthVector))
+    return false;
+  // Only calculate register pressure for VFs enabled by MaxBandwidth.
+  return ElementCount::isKnownGT(
+      VF, VF.isScalable() ? MaxPermissibleVFWithoutMaxBW.ScalableVF
+                          : MaxPermissibleVFWithoutMaxBW.FixedVF);
 }
 
 bool LoopVectorizationCostModel::useMaxBandwidth(
@@ -4022,6 +4030,12 @@ ElementCount LoopVectorizationCostModel::getMaximizedVFForTarget(
       ComputeScalableMaxVF ? TargetTransformInfo::RGK_ScalableVector
                            : TargetTransformInfo::RGK_FixedWidthVector;
   ElementCount MaxVF = MaxVectorElementCount;
+
+  if (MaxVF.isScalable())
+    MaxPermissibleVFWithoutMaxBW.ScalableVF = MaxVF;
+  else
+    MaxPermissibleVFWithoutMaxBW.FixedVF = MaxVF;
+
   if (useMaxBandwidth(RegKind)) {
     auto MaxVectorElementCountMaxBW = ElementCount::get(
         llvm::bit_floor(WidestRegister.getKnownMinValue() / SmallestType),
@@ -4375,9 +4389,10 @@ VectorizationFactor LoopVectorizationPlanner::selectVectorizationFactor() {
       if (VF.isScalar())
         continue;
 
-      /// Don't consider the VF if it exceeds the number of registers for the
-      /// target.
-      if (CM.useMaxBandwidth(VF) && RUs[I].exceedsMaxNumRegs(TTI))
+      /// If the VF was proposed due to MaxBandwidth, don't consider the VF if
+      /// it exceeds the number of registers for the target.
+      if (CM.shouldCalculateRegPressureForVF(VF) &&
+          RUs[I].exceedsMaxNumRegs(TTI, ForceTargetNumVectorRegs))
         continue;
 
       InstructionCost C = CM.expectedCost(VF);
@@ -7155,7 +7170,8 @@ VectorizationFactor LoopVectorizationPlanner::computeBestVF() {
       InstructionCost Cost = cost(*P, VF);
       VectorizationFactor CurrentFactor(VF, Cost, ScalarCost);
 
-      if (CM.useMaxBandwidth(VF) && RUs[I].exceedsMaxNumRegs(TTI)) {
+      if (CM.shouldCalculateRegPressureForVF(VF) &&
+          RUs[I].exceedsMaxNumRegs(TTI, ForceTargetNumVectorRegs)) {
         LLVM_DEBUG(dbgs() << "LV(REG): Not considering vector loop of width "
                           << VF << " because it uses too many registers\n");
         continue;
diff --git a/llvm/lib/Transforms/Vectorize/VPlanAnalysis.cpp b/llvm/lib/Transforms/Vectorize/VPlanAnalysis.cpp
index 92db9674ef42b..3ab18b2fe6438 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanAnalysis.cpp
+++ b/llvm/lib/Transforms/Vectorize/VPlanAnalysis.cpp
@@ -405,9 +405,12 @@ static unsigned getVFScaleFactor(VPRecipeBase *R) {
   return 1;
 }
 
-bool VPRegisterUsage::exceedsMaxNumRegs(const TargetTransformInfo &TTI) const {
-  return any_of(MaxLocalUsers, [&TTI](auto &LU) {
-    return LU.second > TTI.getNumberOfRegisters(LU.first);
+bool VPRegisterUsage::exceedsMaxNumRegs(const TargetTransformInfo &TTI,
+                                        unsigned OverrideMaxNumRegs) const {
+  return any_of(MaxLocalUsers, [&TTI, &OverrideMaxNumRegs](auto &LU) {
+    return LU.second > (OverrideMaxNumRegs > 0
+                            ? OverrideMaxNumRegs
+                            : TTI.getNumberOfRegisters(LU.first));
   });
 }
 
diff --git a/llvm/lib/Transforms/Vectorize/VPlanAnalysis.h b/llvm/lib/Transforms/Vectorize/VPlanAnalysis.h
index 7bcf9dba8c311..cd86d27cf9122 100644
--- a/llvm/lib/Transforms/Vectorize/VPlanAnalysis.h
+++ b/llvm/lib/Transforms/Vectorize/VPlanAnalysis.h
@@ -85,8 +85,10 @@ struct VPRegisterUsage {
   SmallMapVector<unsigned, unsigned, 4> MaxLocalUsers;
 
   /// Check if any of the tracked live intervals exceeds the number of
-  /// available registers for the target.
-  bool exceedsMaxNumRegs(const TargetTransformInfo &TTI) const;
+  /// available registers for the target. If non-zero, OverrideMaxNumRegs
+  /// is used in place of the target's number of registers.
+  bool exceedsMaxNumRegs(const TargetTransformInfo &TTI,
+                         unsigned OverrideMaxNumRegs = 0) const;
 };
 
 /// Estimate the register usage for \p Plan and vectorization factors in \p VFs
diff --git a/llvm/test/Transforms/LoopVectorize/AArch64/maxbandwidth-regpressure.ll b/llvm/test/Transforms/LoopVectorize/AArch64/maxbandwidth-regpressure.ll
new file mode 100644
index 0000000000000..ce639f9150078
--- /dev/null
+++ b/llvm/test/Transforms/LoopVectorize/AArch64/maxbandwidth-regpressure.ll
@@ -0,0 +1,37 @@
+; RUN: opt -passes=loop-vectorize -vectorizer-maximize-bandwidth -debug-only=loop-vectorize -disable-output -force-vector-interleave=1 -enable-epilogue-vectorization=false -S < %s 2>&1 | FileCheck %s --check-prefixes=CHECK-REGS-VP
+; RUN: opt -passes=loop-vectorize -vectorizer-maximize-bandwidth -debug-only=loop-vectorize -disable-output -force-target-num-vector-regs=1 -force-vector-interleave=1 -enable-epilogue-vectorization=false -S < %s 2>&1 | FileCheck %s --check-prefixes=CHECK-NOREGS-VP
+
+target datalayout = "e-m:e-i8:8:32-i16:16:32-i64:64-i128:128-n32:64-S128"
+target triple = "aarch64-none-unknown-elf"
+
+define i32 @dotp(ptr %a, ptr %b) #0 {
+; CHECK-REGS-VP-NOT: LV(REG): Not considering vector loop of width vscale x 16 because it uses too many registers
+; CHECK-REGS-VP: LV: Selecting VF: vscale x 8.
+;
+; CHECK-NOREGS-VP: LV(REG): Not considering vector loop of width vscale x 8 because it uses too many registers
+; CHECK-NOREGS-VP: LV(REG): Not considering vector loop of width vscale x 16 because it uses too many registers
+; CHECK-NOREGS-VP: LV: Selecting VF: vscale x 4.
+entry:
+  br label %for.body
+
+for.body:                                         ; preds = %for.body, %entry
+  %iv = phi i64 [ 0, %entry ], [ %iv.next, %for.body ]
+  %accum = phi i32 [ 0, %entry ], [ %add, %for.body ]
+  %gep.a = getelementptr i8, ptr %a, i64 %iv
+  %load.a = load i8, ptr %gep.a, align 1
+  %ext.a = zext i8 %load.a to i32
+  %gep.b = getelementptr i8, ptr %b, i64 %iv
+  %load.b = load i8, ptr %gep.b, align 1
+  %ext.b = zext i8 %load.b to i32
+  %mul = mul i32 %ext.b, %ext.a
+  %sub = sub i32 0, %mul
+  %add = add i32 %accum, %sub
+  %iv.next = add i64 %iv, 1
+  %exitcond.not = icmp eq i64 %iv.next, 1024
+  br i1 %exitcond.not, label %for.exit, label %for.body
+
+for.exit:                                         ; preds = %for.body
+  ret i32 %add
+}
+
+attributes #0 = { vscale_range(1,16) "target-features"="+sve" }