Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@

import static com.google.common.base.Verify.verify;
import static com.google.common.collect.ImmutableList.toImmutableList;
import static io.trino.cost.CostCalculatorWithEstimatedExchanges.adjustReplicatedJoinLocalExchangeCost;
import static io.trino.cost.CostCalculatorWithEstimatedExchanges.calculateJoinInputCost;
import static io.trino.cost.CostCalculatorWithEstimatedExchanges.calculateLocalRepartitionCost;
import static io.trino.cost.CostCalculatorWithEstimatedExchanges.calculateRemoteGatherCost;
Expand Down Expand Up @@ -192,15 +193,24 @@ public PlanCostEstimate visitJoin(JoinNode node, Void context)

private LocalCostEstimate calculateJoinCost(PlanNode join, PlanNode probe, PlanNode build, boolean replicated)
{
int estimatedSourceDistributedTaskCount = taskCountEstimator.estimateSourceDistributedTaskCount(session);
LocalCostEstimate joinInputCost = calculateJoinInputCost(
probe,
build,
stats,
types,
replicated,
taskCountEstimator.estimateSourceDistributedTaskCount(session));
estimatedSourceDistributedTaskCount);
// TODO: Use traits (https://github.com/trinodb/trino/issues/4763) instead, to correctly estimate
// local exchange cost for replicated join in CostCalculatorUsingExchanges#visitExchange
LocalCostEstimate adjustedLocalExchangeCost = adjustReplicatedJoinLocalExchangeCost(
build,
stats,
types,
replicated,
estimatedSourceDistributedTaskCount);
LocalCostEstimate joinOutputCost = calculateJoinOutputCost(join);
return addPartialComponents(joinInputCost, joinOutputCost);
return addPartialComponents(joinInputCost, adjustedLocalExchangeCost, joinOutputCost);
}

private LocalCostEstimate calculateJoinOutputCost(PlanNode join)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -214,14 +214,53 @@ public static LocalCostEstimate calculateJoinCostWithoutOutput(
types,
replicated,
estimatedSourceDistributedTaskCount);
// TODO: Remove once traits (https://github.com/trinodb/trino/issues/4763) are used to correctly estimate
// local exchange cost for replicated join in CostCalculatorUsingExchanges#visitExchange
LocalCostEstimate adjustedLocalExchangeCost = adjustReplicatedJoinLocalExchangeCost(
build,
stats,
types,
replicated,
estimatedSourceDistributedTaskCount);
LocalCostEstimate inputCost = calculateJoinInputCost(
probe,
build,
stats,
types,
replicated,
estimatedSourceDistributedTaskCount);
return addPartialComponents(exchangesCost, inputCost);
return addPartialComponents(exchangesCost, adjustedLocalExchangeCost, inputCost);
}

/**
 * Returns the extra local repartitioning cost of the build-side copies of a replicated join.
 *
 * <p>For a non-replicated join the result is {@link LocalCostEstimate#zero()}. For a replicated
 * join the result carries only a CPU component, equal to the build side output size in bytes
 * multiplied by {@code estimatedSourceDistributedTaskCount - 1}; the other two components are zero.
 *
 * @param build the build side of the join
 * @param stats provider of the stats estimate for {@code build}
 * @param types type provider used to compute the build side output size in bytes
 * @param replicated whether the join is replicated
 * @param estimatedSourceDistributedTaskCount estimated number of source distributed tasks
 * @return the additional local exchange cost to add to the join cost
 */
public static LocalCostEstimate adjustReplicatedJoinLocalExchangeCost(
        PlanNode build,
        StatsProvider stats,
        TypeProvider types,
        boolean replicated,
        int estimatedSourceDistributedTaskCount)
{
    if (replicated) {
        /*
         * HACK!
         *
         * The stats model does not multiply the number of rows by the number of tasks for a
         * replicated exchange, to avoid misestimation of the JOIN output.
         *
         * As a result, the cost of the operations that follow a replicated exchange is
         * underestimated, so the cost of the operations over the replicated copies has to be
         * added explicitly here.
         */

        // Account for the local repartitioning of the additional build side copies.
        // The repartitioning of one copy is already accounted for in
        // CostCalculatorWithEstimatedExchanges#calculateJoinExchangeCost or in CostCalculatorUsingExchanges#visitExchange
        double buildSizeInBytes = stats.getStats(build).getOutputSizeInBytes(build.getOutputSymbols(), types);
        int additionalCopies = estimatedSourceDistributedTaskCount - 1;
        return LocalCostEstimate.of(buildSizeInBytes * additionalCopies, 0, 0);
    }

    return LocalCostEstimate.zero();
}

private static LocalCostEstimate calculateJoinExchangeCost(
Expand All @@ -237,7 +276,7 @@ private static LocalCostEstimate calculateJoinExchangeCost(
if (replicated) {
// assuming the probe side of a replicated join is always source distributed
LocalCostEstimate replicateCost = calculateRemoteReplicateCost(buildSizeInBytes, estimatedSourceDistributedTaskCount);
// cost of the copies repartitioning is added in CostCalculatorUsingExchanges#calculateJoinCost
// cost of the copies repartitioning is added in CostCalculatorWithEstimatedExchanges#adjustReplicatedJoinLocalExchangeCost
LocalCostEstimate localRepartitionCost = calculateLocalRepartitionCost(buildSizeInBytes);
return addPartialComponents(replicateCost, localRepartitionCost);
}
Expand Down Expand Up @@ -266,23 +305,6 @@ public static LocalCostEstimate calculateJoinInputCost(
double probeSideSize = probeStats.getOutputSizeInBytes(probe.getOutputSymbols(), types);

double cpuCost = probeSideSize + buildSideSize * buildSizeMultiplier;

/*
* HACK!
*
* Stats model doesn't multiply the number of rows by the number of tasks for replicated
* exchange to avoid misestimation of the JOIN output.
*
* Thus the cost estimation for the operations that come after a replicated exchange is
* underestimated. And the cost of operations over the replicated copies must be explicitly
* added here.
*/
if (replicated) {
// add the cost of a local repartitioning of build side copies
// cost of the repartitioning of a single data copy has been already added in calculateExchangeCost
cpuCost += buildSideSize * (buildSizeMultiplier - 1);
}

double memoryCost = buildSideSize * buildSizeMultiplier;

return LocalCostEstimate.of(cpuCost, memoryCost, 0);
Expand Down