@@ -3002,13 +3002,6 @@ InstructionCost VPReplicateRecipe::computeCost(ElementCount VF,
30023002 // instruction cost.
30033003 return 0 ;
30043004 case Instruction::Call: {
3005- if (!isSingleScalar ()) {
3006- // TODO: Handle remaining call costs here as well.
3007- if (VF.isScalable ())
3008- return InstructionCost::getInvalid ();
3009- break ;
3010- }
3011-
30123005 auto *CalledFn =
30133006 cast<Function>(getOperand (getNumOperands () - 1 )->getLiveInIRValue ());
30143007 if (CalledFn->isIntrinsic ())
@@ -3017,8 +3010,43 @@ InstructionCost VPReplicateRecipe::computeCost(ElementCount VF,
30173010 SmallVector<Type *, 4 > Tys;
30183011 for (VPValue *ArgOp : drop_end (operands ()))
30193012 Tys.push_back (Ctx.Types .inferScalarType (ArgOp));
3013+
30203014 Type *ResultTy = Ctx.Types .inferScalarType (this );
3021- return Ctx.TTI .getCallInstrCost (CalledFn, ResultTy, Tys, Ctx.CostKind );
3015+ InstructionCost ScalarCallCost =
3016+ Ctx.TTI .getCallInstrCost (CalledFn, ResultTy, Tys, Ctx.CostKind );
3017+ if (isSingleScalar ())
3018+ return ScalarCallCost;
3019+
3020+ if (VF.isScalable ())
3021+ return InstructionCost::getInvalid ();
3022+
3023+ // Compute the cost of scalarizing the result and operands if needed.
3024+ InstructionCost ScalarizationCost = 0 ;
3025+ if (VF.isVector ()) {
3026+ if (!ResultTy->isVoidTy ()) {
3027+ for (Type *VectorTy : getContainedTypes (toVectorizedTy (ResultTy, VF))) {
3028+ ScalarizationCost += Ctx.TTI .getScalarizationOverhead (
3029+ cast<VectorType>(VectorTy), APInt::getAllOnes (VF.getFixedValue ()),
3030+ /* Insert=*/ true ,
3031+ /* Extract=*/ false , Ctx.CostKind );
3032+ }
3033+ }
3034+ // Skip operands that do not require extraction/scalarization and do not
3035+ // incur any overhead.
3036+ SmallVector<Type *> Tys;
3037+ SmallPtrSet<const VPValue *, 4 > UniqueOperands;
3038+ for (auto *Op : drop_end (operands ())) {
3039+ if (Op->isLiveIn () || isa<VPReplicateRecipe, VPPredInstPHIRecipe>(Op) ||
3040+ !UniqueOperands.insert (Op).second )
3041+ continue ;
3042+ Tys.push_back (toVectorizedTy (Ctx.Types .inferScalarType (Op), VF));
3043+ }
3044+ ScalarizationCost +=
3045+ Ctx.TTI .getOperandsScalarizationOverhead (Tys, Ctx.CostKind );
3046+ }
3047+
3048+ return ScalarCallCost * (isSingleScalar () ? 1 : VF.getFixedValue ()) +
3049+ ScalarizationCost;
30223050 }
30233051 case Instruction::Add:
30243052 case Instruction::Sub:
0 commit comments