@@ -3476,29 +3476,21 @@ tryToMatchAndCreateExtendedReduction(VPReductionRecipe *Red, VPCostContext &Ctx,
34763476 InstructionCost ExtCost =
34773477 cast<VPWidenCastRecipe>(VecOp)->computeCost (VF, Ctx);
34783478 InstructionCost RedCost = Red->computeCost (VF, Ctx);
3479- InstructionCost BaseCost = ExtCost + RedCost;
34803479
34813480 if (isa<VPPartialReductionRecipe>(Red)) {
34823481 TargetTransformInfo::PartialReductionExtendKind ExtKind =
34833482 TargetTransformInfo::getPartialReductionExtendKind (ExtOpc);
3484- // The VF ranges have already been clamped for a partial reduction
3485- // and its existence confirms that it's valid, so we don't need to
3486- // perform any cost checks or more clamping. Just assert that the
3487- // partial reduction is still profitable.
34883483 // FIXME: Move partial reduction creation, costing and clamping
3489- // here.
3490- InstructionCost Cost = Ctx.TTI .getPartialReductionCost (
3484+ // here from LoopVectorize.cpp.
3485+ ExtRedCost = Ctx.TTI .getPartialReductionCost (
34913486 Opcode, SrcTy, nullptr , RedTy, VF, ExtKind,
34923487 llvm::TargetTransformInfo::PR_None, std::nullopt , Ctx.CostKind );
3493- assert (Cost <= BaseCost &&
3494- " Cost of the partial reduction is more than the base cost" );
3495- return true ;
34963488 } else {
34973489 ExtRedCost = Ctx.TTI .getExtendedReductionCost (
34983490 Opcode, ExtOpc == Instruction::CastOps::ZExt, RedTy, SrcVecTy,
34993491 Red->getFastMathFlags (), CostKind);
35003492 }
3501- return ExtRedCost.isValid () && ExtRedCost < BaseCost ;
3493+ return ExtRedCost.isValid () && ExtRedCost < ExtCost + RedCost ;
35023494 },
35033495 Range);
35043496 };
@@ -3540,41 +3532,17 @@ tryToMatchAndCreateMulAccumulateReduction(VPReductionRecipe *Red,
35403532 VPWidenCastRecipe *OuterExt) -> bool {
35413533 return LoopVectorizationPlanner::getDecisionAndClampRange (
35423534 [&](ElementCount VF) {
3543- // Only partial reductions support mixed extends at the moment.
3544- if (!IsPartialReduction && Ext0 && Ext1 &&
3545- Ext0->getOpcode () != Ext1->getOpcode ())
3546- return false ;
3547-
3548- bool IsZExt =
3549- !Ext0 || Ext0->getOpcode () == Instruction::CastOps::ZExt;
35503535 TTI::TargetCostKind CostKind = TTI::TCK_RecipThroughput;
35513536 Type *SrcTy =
35523537 Ext0 ? Ctx.Types .inferScalarType (Ext0->getOperand (0 )) : RedTy;
3553- auto *SrcVecTy = cast<VectorType>(toVectorTy (SrcTy, VF));
3554- InstructionCost MulAccCost = Ctx.TTI .getMulAccReductionCost (
3555- IsZExt, Opcode, RedTy, SrcVecTy, CostKind);
3556- InstructionCost MulCost = Mul->computeCost (VF, Ctx);
3557- InstructionCost RedCost = Red->computeCost (VF, Ctx);
3558- InstructionCost ExtCost = 0 ;
3559- if (Ext0)
3560- ExtCost += Ext0->computeCost (VF, Ctx);
3561- if (Ext1)
3562- ExtCost += Ext1->computeCost (VF, Ctx);
3563- if (OuterExt)
3564- ExtCost += OuterExt->computeCost (VF, Ctx);
3565-
3566- InstructionCost BaseCost = ExtCost + MulCost + RedCost;
3538+ InstructionCost MulAccCost;
35673539
35683540 if (IsPartialReduction) {
35693541 Type *SrcTy2 =
35703542 Ext1 ? Ctx.Types .inferScalarType (Ext1->getOperand (0 )) : nullptr ;
3571- // The VF ranges have already been clamped for a partial reduction
3572- // and its existence confirms that it's valid, so we don't need to
3573- // perform any cost checks or more clamping. Just assert that the
3574- // partial reduction is still profitable.
35753543 // FIXME: Move partial reduction creation, costing and clamping
3576- // here.
3577- InstructionCost Cost = Ctx.TTI .getPartialReductionCost (
3544+ // here from LoopVectorize.cpp.
3545+ MulAccCost = Ctx.TTI .getPartialReductionCost (
35783546 Opcode, SrcTy, SrcTy2, RedTy, VF,
35793547 Ext0 ? TargetTransformInfo::getPartialReductionExtendKind (
35803548 Ext0->getOpcode ())
@@ -3583,12 +3551,30 @@ tryToMatchAndCreateMulAccumulateReduction(VPReductionRecipe *Red,
35833551 Ext1->getOpcode ())
35843552 : TargetTransformInfo::PR_None,
35853553 Mul->getOpcode (), CostKind);
3586- assert (Cost <= BaseCost &&
3587- " Cost of the partial reduction is more than the base cost" );
3588- return true ;
3554+ } else {
3555+ // Only partial reductions support mixed extends at the moment.
3556+ if (Ext0 && Ext1 && Ext0->getOpcode () != Ext1->getOpcode ())
3557+ return false ;
3558+
3559+ bool IsZExt =
3560+ !Ext0 || Ext0->getOpcode () == Instruction::CastOps::ZExt;
3561+ auto *SrcVecTy = cast<VectorType>(toVectorTy (SrcTy, VF));
3562+ MulAccCost = Ctx.TTI .getMulAccReductionCost (IsZExt, Opcode, RedTy,
3563+ SrcVecTy, CostKind);
35893564 }
35903565
3591- return MulAccCost.isValid () && MulAccCost < BaseCost;
3566+ InstructionCost MulCost = Mul->computeCost (VF, Ctx);
3567+ InstructionCost RedCost = Red->computeCost (VF, Ctx);
3568+ InstructionCost ExtCost = 0 ;
3569+ if (Ext0)
3570+ ExtCost += Ext0->computeCost (VF, Ctx);
3571+ if (Ext1)
3572+ ExtCost += Ext1->computeCost (VF, Ctx);
3573+ if (OuterExt)
3574+ ExtCost += OuterExt->computeCost (VF, Ctx);
3575+
3576+ return MulAccCost.isValid () &&
3577+ MulAccCost < ExtCost + MulCost + RedCost;
35923578 },
35933579 Range);
35943580 };
0 commit comments