Skip to content
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions llvm/include/llvm/Analysis/TargetTransformInfo.h
Original file line number Diff line number Diff line change
Expand Up @@ -1647,12 +1647,12 @@ class TargetTransformInfo {
TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput) const;

/// Calculate the cost of an extended reduction pattern, similar to
/// getArithmeticReductionCost of an Add reduction with multiply and optional
/// extensions. This is the cost of as:
/// ResTy vecreduce.add(mul (A, B)).
/// ResTy vecreduce.add(mul(ext(Ty A), ext(Ty B)).
/// getArithmeticReductionCost of an Add/Sub reduction with multiply and
/// optional extensions. This is the cost of:
/// ResTy vecreduce.add/sub(mul (A, B)).
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit:

Suggested change
/// ResTy vecreduce.add/sub(mul (A, B)).
/// ResTy vecreduce.add/sub(mul(A, B)).

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

/// ResTy vecreduce.add/sub(mul(ext(Ty A), ext(Ty B))).
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
/// ResTy vecreduce.add/sub(mul (A, B)).
/// ResTy vecreduce.add/sub(mul(ext(Ty A), ext(Ty B)).
/// * ResTy vecreduce.add/sub(mul (A, B)) or,
/// * ResTy vecreduce.add/sub(mul(ext(Ty A), ext(Ty B))).

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

LLVM_ABI InstructionCost getMulAccReductionCost(
bool IsUnsigned, Type *ResTy, VectorType *Ty,
bool IsUnsigned, unsigned RedOpcode, Type *ResTy, VectorType *Ty,
TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput) const;

/// Calculate the cost of an extended reduction pattern, similar to
Expand Down
4 changes: 2 additions & 2 deletions llvm/include/llvm/Analysis/TargetTransformInfoImpl.h
Original file line number Diff line number Diff line change
Expand Up @@ -971,8 +971,8 @@ class TargetTransformInfoImplBase {
}

virtual InstructionCost
getMulAccReductionCost(bool IsUnsigned, Type *ResTy, VectorType *Ty,
TTI::TargetCostKind CostKind) const {
getMulAccReductionCost(bool IsUnsigned, unsigned RedOpcode, Type *ResTy,
VectorType *Ty, TTI::TargetCostKind CostKind) const {
return 1;
}

Expand Down
7 changes: 5 additions & 2 deletions llvm/include/llvm/CodeGen/BasicTTIImpl.h
Original file line number Diff line number Diff line change
Expand Up @@ -3260,14 +3260,17 @@ class BasicTTIImplBase : public TargetTransformInfoImplCRTPBase<T> {
}

InstructionCost
getMulAccReductionCost(bool IsUnsigned, Type *ResTy, VectorType *Ty,
getMulAccReductionCost(bool IsUnsigned, unsigned RedOpcode, Type *ResTy,
VectorType *Ty,
TTI::TargetCostKind CostKind) const override {
// Without any native support, this is equivalent to the cost of
// vecreduce.add/sub(mul(ext(Ty A), ext(Ty B))) or
// vecreduce.add/sub(mul(A, B)).
assert((RedOpcode == Instruction::Add || RedOpcode == Instruction::Sub) &&
"The reduction opcode is expected to be Add or Sub.");
VectorType *ExtTy = VectorType::get(ResTy, Ty);
InstructionCost RedCost = thisT()->getArithmeticReductionCost(
Instruction::Add, ExtTy, std::nullopt, CostKind);
RedOpcode, ExtTy, std::nullopt, CostKind);
InstructionCost ExtCost = thisT()->getCastInstrCost(
IsUnsigned ? Instruction::ZExt : Instruction::SExt, ExtTy, Ty,
TTI::CastContextHint::None, CostKind);
Expand Down
5 changes: 3 additions & 2 deletions llvm/lib/Analysis/TargetTransformInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1283,9 +1283,10 @@ InstructionCost TargetTransformInfo::getExtendedReductionCost(
}

InstructionCost TargetTransformInfo::getMulAccReductionCost(
bool IsUnsigned, Type *ResTy, VectorType *Ty,
bool IsUnsigned, unsigned RedOpcode, Type *ResTy, VectorType *Ty,
TTI::TargetCostKind CostKind) const {
return TTIImpl->getMulAccReductionCost(IsUnsigned, ResTy, Ty, CostKind);
return TTIImpl->getMulAccReductionCost(IsUnsigned, RedOpcode, ResTy, Ty,
CostKind);
}

InstructionCost
Expand Down
10 changes: 6 additions & 4 deletions llvm/lib/Target/AArch64/AArch64TargetTransformInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5486,13 +5486,14 @@ InstructionCost AArch64TTIImpl::getExtendedReductionCost(
}

InstructionCost
AArch64TTIImpl::getMulAccReductionCost(bool IsUnsigned, Type *ResTy,
VectorType *VecTy,
AArch64TTIImpl::getMulAccReductionCost(bool IsUnsigned, unsigned RedOpcode,
Type *ResTy, VectorType *VecTy,
TTI::TargetCostKind CostKind) const {
EVT VecVT = TLI->getValueType(DL, VecTy);
EVT ResVT = TLI->getValueType(DL, ResTy);

if (ST->hasDotProd() && VecVT.isSimple() && ResVT.isSimple()) {
if (ST->hasDotProd() && VecVT.isSimple() && ResVT.isSimple() &&
RedOpcode == Instruction::Add) {
std::pair<InstructionCost, MVT> LT = getTypeLegalizationCost(VecTy);

// The legal cases with dotprod are
Expand All @@ -5503,7 +5504,8 @@ AArch64TTIImpl::getMulAccReductionCost(bool IsUnsigned, Type *ResTy,
return LT.first + 2;
}

return BaseT::getMulAccReductionCost(IsUnsigned, ResTy, VecTy, CostKind);
return BaseT::getMulAccReductionCost(IsUnsigned, RedOpcode, ResTy, VecTy,
CostKind);
}

InstructionCost
Expand Down
2 changes: 1 addition & 1 deletion llvm/lib/Target/AArch64/AArch64TargetTransformInfo.h
Original file line number Diff line number Diff line change
Expand Up @@ -460,7 +460,7 @@ class AArch64TTIImpl final : public BasicTTIImplBase<AArch64TTIImpl> {
TTI::TargetCostKind CostKind) const override;

InstructionCost getMulAccReductionCost(
bool IsUnsigned, Type *ResTy, VectorType *Ty,
bool IsUnsigned, unsigned RedOpcode, Type *ResTy, VectorType *Ty,
TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput) const override;

InstructionCost
Expand Down
9 changes: 6 additions & 3 deletions llvm/lib/Target/ARM/ARMTargetTransformInfo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1916,9 +1916,11 @@ InstructionCost ARMTTIImpl::getExtendedReductionCost(
}

InstructionCost
ARMTTIImpl::getMulAccReductionCost(bool IsUnsigned, Type *ResTy,
VectorType *ValTy,
ARMTTIImpl::getMulAccReductionCost(bool IsUnsigned, unsigned RedOpcode,
Type *ResTy, VectorType *ValTy,
TTI::TargetCostKind CostKind) const {
if (RedOpcode != Instruction::Add)
return InstructionCost::getInvalid(CostKind);
EVT ValVT = TLI->getValueType(DL, ValTy);
EVT ResVT = TLI->getValueType(DL, ResTy);

Expand All @@ -1939,7 +1941,8 @@ ARMTTIImpl::getMulAccReductionCost(bool IsUnsigned, Type *ResTy,
return ST->getMVEVectorCostFactor(CostKind) * LT.first;
}

return BaseT::getMulAccReductionCost(IsUnsigned, ResTy, ValTy, CostKind);
return BaseT::getMulAccReductionCost(IsUnsigned, RedOpcode, ResTy, ValTy,
CostKind);
}

InstructionCost
Expand Down
3 changes: 2 additions & 1 deletion llvm/lib/Target/ARM/ARMTargetTransformInfo.h
Original file line number Diff line number Diff line change
Expand Up @@ -299,7 +299,8 @@ class ARMTTIImpl final : public BasicTTIImplBase<ARMTTIImpl> {
VectorType *ValTy, std::optional<FastMathFlags> FMF,
TTI::TargetCostKind CostKind) const override;
InstructionCost
getMulAccReductionCost(bool IsUnsigned, Type *ResTy, VectorType *ValTy,
getMulAccReductionCost(bool IsUnsigned, unsigned RedOpcode, Type *ResTy,
VectorType *ValTy,
TTI::TargetCostKind CostKind) const override;

InstructionCost
Expand Down
9 changes: 6 additions & 3 deletions llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5414,7 +5414,8 @@ LoopVectorizationCostModel::getReductionPatternCost(Instruction *I,
TTI::CastContextHint::None, CostKind, RedOp);

InstructionCost RedCost = TTI.getMulAccReductionCost(
IsUnsigned, RdxDesc.getRecurrenceType(), ExtType, CostKind);
IsUnsigned, RdxDesc.getOpcode(), RdxDesc.getRecurrenceType(), ExtType,
CostKind);

if (RedCost.isValid() &&
RedCost < ExtCost * 2 + MulCost + Ext2Cost + BaseCost)
Expand Down Expand Up @@ -5459,7 +5460,8 @@ LoopVectorizationCostModel::getReductionPatternCost(Instruction *I,
TTI.getArithmeticInstrCost(Instruction::Mul, VectorTy, CostKind);

InstructionCost RedCost = TTI.getMulAccReductionCost(
IsUnsigned, RdxDesc.getRecurrenceType(), ExtType, CostKind);
IsUnsigned, RdxDesc.getOpcode(), RdxDesc.getRecurrenceType(), ExtType,
CostKind);
InstructionCost ExtraExtCost = 0;
if (Op0Ty != LargestOpTy || Op1Ty != LargestOpTy) {
Instruction *ExtraExtOp = (Op0Ty != LargestOpTy) ? Op0 : Op1;
Expand All @@ -5478,7 +5480,8 @@ LoopVectorizationCostModel::getReductionPatternCost(Instruction *I,
TTI.getArithmeticInstrCost(Instruction::Mul, VectorTy, CostKind);

InstructionCost RedCost = TTI.getMulAccReductionCost(
true, RdxDesc.getRecurrenceType(), VectorTy, CostKind);
true, RdxDesc.getOpcode(), RdxDesc.getRecurrenceType(), VectorTy,
CostKind);

if (RedCost.isValid() && RedCost < MulCost + BaseCost)
return I == RetI ? RedCost : 0;
Expand Down
9 changes: 5 additions & 4 deletions llvm/lib/Transforms/Vectorize/VPlanRecipes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2803,24 +2803,25 @@ InstructionCost VPExpressionRecipe::computeCost(ElementCount VF,
toVectorTy(Ctx.Types.inferScalarType(getOperand(0)), VF));
assert(RedTy->isIntegerTy() &&
"VPExpressionRecipe only supports integer types currently.");
unsigned Opcode = RecurrenceDescriptor::getOpcode(
cast<VPReductionRecipe>(ExpressionRecipes.back())->getRecurrenceKind());
switch (ExpressionType) {
case ExpressionTypes::ExtendedReduction: {
unsigned Opcode = RecurrenceDescriptor::getOpcode(
cast<VPReductionRecipe>(ExpressionRecipes[1])->getRecurrenceKind());
return Ctx.TTI.getExtendedReductionCost(
Opcode,
cast<VPWidenCastRecipe>(ExpressionRecipes.front())->getOpcode() ==
Instruction::ZExt,
RedTy, SrcVecTy, std::nullopt, Ctx.CostKind);
}
case ExpressionTypes::MulAccReduction:
return Ctx.TTI.getMulAccReductionCost(false, RedTy, SrcVecTy, Ctx.CostKind);
return Ctx.TTI.getMulAccReductionCost(false, Opcode, RedTy, SrcVecTy,
Ctx.CostKind);

case ExpressionTypes::ExtMulAccReduction:
return Ctx.TTI.getMulAccReductionCost(
cast<VPWidenCastRecipe>(ExpressionRecipes.front())->getOpcode() ==
Instruction::ZExt,
RedTy, SrcVecTy, Ctx.CostKind);
Opcode, RedTy, SrcVecTy, Ctx.CostKind);
}
llvm_unreachable("Unknown VPExpressionRecipe::ExpressionTypes enum");
}
Expand Down
24 changes: 13 additions & 11 deletions llvm/lib/Transforms/Vectorize/VPlanTransforms.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3151,23 +3151,24 @@ static VPExpressionRecipe *
tryToMatchAndCreateMulAccumulateReduction(VPReductionRecipe *Red,
VPCostContext &Ctx, VFRange &Range) {
unsigned Opcode = RecurrenceDescriptor::getOpcode(Red->getRecurrenceKind());
if (Opcode != Instruction::Add)
if (Opcode != Instruction::Add && Opcode != Instruction::Sub)
return nullptr;

Type *RedTy = Ctx.Types.inferScalarType(Red);

// Clamp the range if using multiply-accumulate-reduction is profitable.
auto IsMulAccValidAndClampRange =
[&](bool isZExt, VPWidenRecipe *Mul, VPWidenCastRecipe *Ext0,
VPWidenCastRecipe *Ext1, VPWidenCastRecipe *OuterExt) -> bool {
[&](bool IsZExt, VPWidenRecipe *Mul, VPWidenCastRecipe *Ext0,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this rename is NFC, maybe remove it from this PR?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

VPWidenCastRecipe *Ext1, VPWidenCastRecipe *OuterExt,
unsigned Opcode) -> bool {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
VPWidenCastRecipe *Ext1, VPWidenCastRecipe *OuterExt,
unsigned Opcode) -> bool {
VPWidenCastRecipe *Ext1, VPWidenCastRecipe *OuterExt
) -> bool {

Can we just use the captured Opcode?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.

return LoopVectorizationPlanner::getDecisionAndClampRange(
[&](ElementCount VF) {
TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
Type *SrcTy =
Ext0 ? Ctx.Types.inferScalarType(Ext0->getOperand(0)) : RedTy;
auto *SrcVecTy = cast<VectorType>(toVectorTy(SrcTy, VF));
InstructionCost MulAccCost =
Ctx.TTI.getMulAccReductionCost(isZExt, RedTy, SrcVecTy, CostKind);
InstructionCost MulAccCost = Ctx.TTI.getMulAccReductionCost(
IsZExt, Opcode, RedTy, SrcVecTy, CostKind);
InstructionCost MulCost = Mul->computeCost(VF, Ctx);
InstructionCost RedCost = Red->computeCost(VF, Ctx);
InstructionCost ExtCost = 0;
Expand All @@ -3192,7 +3193,7 @@ tryToMatchAndCreateMulAccumulateReduction(VPReductionRecipe *Red,
dyn_cast_if_present<VPWidenCastRecipe>(A->getDefiningRecipe());
auto *RecipeB =
dyn_cast_if_present<VPWidenCastRecipe>(B->getDefiningRecipe());
auto *Mul = cast<VPWidenRecipe>(VecOp->getDefiningRecipe());
auto *MulR = cast<VPWidenRecipe>(VecOp->getDefiningRecipe());
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This rename is NFC, maybe remove it from this PR?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done.


// Match reduce.add(mul(ext, ext)).
if (RecipeA && RecipeB &&
Expand All @@ -3201,12 +3202,13 @@ tryToMatchAndCreateMulAccumulateReduction(VPReductionRecipe *Red,
match(RecipeB, m_ZExtOrSExt(m_VPValue())) &&
IsMulAccValidAndClampRange(RecipeA->getOpcode() ==
Instruction::CastOps::ZExt,
Mul, RecipeA, RecipeB, nullptr)) {
return new VPExpressionRecipe(RecipeA, RecipeB, Mul, Red);
MulR, RecipeA, RecipeB, nullptr, Opcode)) {
return new VPExpressionRecipe(RecipeA, RecipeB, MulR, Red);
}
// Match reduce.add(mul).
if (IsMulAccValidAndClampRange(true, Mul, nullptr, nullptr, nullptr))
return new VPExpressionRecipe(Mul, Red);
if (IsMulAccValidAndClampRange(true, MulR, nullptr, nullptr, nullptr,
Opcode))
return new VPExpressionRecipe(MulR, Red);
}
// Match reduce.add(ext(mul(ext(A), ext(B)))).
// All extend recipes must have same opcode or A == B
Expand All @@ -3223,7 +3225,7 @@ tryToMatchAndCreateMulAccumulateReduction(VPReductionRecipe *Red,
Ext0->getOpcode() == Ext1->getOpcode() &&
IsMulAccValidAndClampRange(Ext0->getOpcode() ==
Instruction::CastOps::ZExt,
Mul, Ext0, Ext1, Ext)) {
Mul, Ext0, Ext1, Ext, Opcode)) {
auto *NewExt0 = new VPWidenCastRecipe(
Ext0->getOpcode(), Ext0->getOperand(0), Ext->getResultType(), *Ext0,
Ext0->getDebugLoc());
Expand Down
4 changes: 2 additions & 2 deletions llvm/lib/Transforms/Vectorize/VectorCombine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1468,8 +1468,8 @@ static void analyzeCostOfVecReduction(const IntrinsicInst &II,
TTI::CastContextHint::None, CostKind, RedOp);

CostBeforeReduction = ExtCost * 2 + MulCost + Ext2Cost;
CostAfterReduction =
TTI.getMulAccReductionCost(IsUnsigned, II.getType(), ExtType, CostKind);
CostAfterReduction = TTI.getMulAccReductionCost(
IsUnsigned, ReductionOpc, II.getType(), ExtType, CostKind);
Comment on lines +1471 to +1472
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it would be nice to have a test for this, but not sure if that's possible.

Copy link
Collaborator Author

@SamTebbs33 SamTebbs33 Sep 1, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've been trying to write a test, but I don't think this code is ever reached. The `RedOp && match(RedOp, m_ZExtOrSExt(m_Value()))` check above fully (AFAIK) encompasses this check, so that code path is always taken instead. If I move this `if` block above that one, the compiler fails the assertion at Type.cpp:805. This happens on main as well.

return;
}
CostAfterReduction = TTI.getArithmeticReductionCost(ReductionOpc, VecRedTy,
Expand Down
Loading
Loading