Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
The table of contents is too big for display.
Diff view
Diff view
  •  
  •  
  •  
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@

import io.airlift.log.Logger;
import io.trino.Session;
import io.trino.sql.planner.TypeProvider;
import io.trino.sql.planner.iterative.GroupReference;
import io.trino.sql.planner.iterative.Memo;
import io.trino.sql.planner.plan.PlanNode;
Expand All @@ -38,22 +37,20 @@ public class CachingCostProvider
private final StatsProvider statsProvider;
private final Optional<Memo> memo;
private final Session session;
private final TypeProvider types;

private final Map<PlanNode, PlanCostEstimate> cache = new IdentityHashMap<>();

public CachingCostProvider(CostCalculator costCalculator, StatsProvider statsProvider, Session session, TypeProvider types)
public CachingCostProvider(CostCalculator costCalculator, StatsProvider statsProvider, Session session)
{
this(costCalculator, statsProvider, Optional.empty(), session, types);
this(costCalculator, statsProvider, Optional.empty(), session);
}

public CachingCostProvider(CostCalculator costCalculator, StatsProvider statsProvider, Optional<Memo> memo, Session session, TypeProvider types)
public CachingCostProvider(CostCalculator costCalculator, StatsProvider statsProvider, Optional<Memo> memo, Session session)
{
this.costCalculator = requireNonNull(costCalculator, "costCalculator is null");
this.statsProvider = requireNonNull(statsProvider, "statsProvider is null");
this.memo = requireNonNull(memo, "memo is null");
this.session = requireNonNull(session, "session is null");
this.types = requireNonNull(types, "types is null");
}

@Override
Expand Down Expand Up @@ -106,6 +103,6 @@ private PlanCostEstimate getGroupCost(GroupReference groupReference)

private PlanCostEstimate calculateCost(PlanNode node)
{
return costCalculator.calculateCost(node, statsProvider, this, session, types);
return costCalculator.calculateCost(node, statsProvider, this, session);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@

import io.airlift.log.Logger;
import io.trino.Session;
import io.trino.sql.planner.TypeProvider;
import io.trino.sql.planner.iterative.GroupReference;
import io.trino.sql.planner.iterative.Lookup;
import io.trino.sql.planner.iterative.Memo;
Expand All @@ -40,31 +39,28 @@ public final class CachingStatsProvider
private final Optional<Memo> memo;
private final Lookup lookup;
private final Session session;
private final TypeProvider types;
private final TableStatsProvider tableStatsProvider;
private final RuntimeInfoProvider runtimeInfoProvider;

private final Map<PlanNode, PlanNodeStatsEstimate> cache = new IdentityHashMap<>();

public CachingStatsProvider(StatsCalculator statsCalculator, Session session, TypeProvider types, TableStatsProvider tableStatsProvider)
public CachingStatsProvider(StatsCalculator statsCalculator, Session session, TableStatsProvider tableStatsProvider)
{
this(statsCalculator, Optional.empty(), noLookup(), session, types, tableStatsProvider, RuntimeInfoProvider.noImplementation());
this(statsCalculator, Optional.empty(), noLookup(), session, tableStatsProvider, RuntimeInfoProvider.noImplementation());
}

public CachingStatsProvider(
StatsCalculator statsCalculator,
Optional<Memo> memo,
Lookup lookup,
Session session,
TypeProvider types,
TableStatsProvider tableStatsProvider,
RuntimeInfoProvider runtimeInfoProvider)
{
this.statsCalculator = requireNonNull(statsCalculator, "statsCalculator is null");
this.memo = requireNonNull(memo, "memo is null");
this.lookup = requireNonNull(lookup, "lookup is null");
this.session = requireNonNull(session, "session is null");
this.types = requireNonNull(types, "types is null");
this.tableStatsProvider = requireNonNull(tableStatsProvider, "tableStatsProvider is null");
this.runtimeInfoProvider = requireNonNull(runtimeInfoProvider, "runtimeInfoProvider is null");
}
Expand All @@ -88,7 +84,7 @@ public PlanNodeStatsEstimate getStats(PlanNode node)
return stats;
}

stats = statsCalculator.calculateStats(node, new StatsCalculator.Context(this, lookup, session, types, tableStatsProvider, runtimeInfoProvider));
stats = statsCalculator.calculateStats(node, new StatsCalculator.Context(this, lookup, session, tableStatsProvider, runtimeInfoProvider));
verify(cache.put(node, stats) == null, "Stats already set");
return stats;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
import com.google.errorprone.annotations.ThreadSafe;
import com.google.inject.BindingAnnotation;
import io.trino.Session;
import io.trino.sql.planner.TypeProvider;
import io.trino.sql.planner.plan.PlanNode;

import java.lang.annotation.Retention;
Expand All @@ -39,8 +38,7 @@ PlanCostEstimate calculateCost(
PlanNode node,
StatsProvider stats,
CostProvider sourcesCosts,
Session session,
TypeProvider types);
Session session);

@BindingAnnotation
@Target(PARAMETER)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
import com.google.inject.Inject;
import io.trino.Session;
import io.trino.sql.planner.Symbol;
import io.trino.sql.planner.TypeProvider;
import io.trino.sql.planner.iterative.GroupReference;
import io.trino.sql.planner.plan.AggregationNode;
import io.trino.sql.planner.plan.AssignUniqueId;
Expand Down Expand Up @@ -74,9 +73,9 @@ public CostCalculatorUsingExchanges(TaskCountEstimator taskCountEstimator)
}

@Override
public PlanCostEstimate calculateCost(PlanNode node, StatsProvider stats, CostProvider sourcesCosts, Session session, TypeProvider types)
public PlanCostEstimate calculateCost(PlanNode node, StatsProvider stats, CostProvider sourcesCosts, Session session)
{
CostEstimator costEstimator = new CostEstimator(stats, sourcesCosts, types, taskCountEstimator, session);
CostEstimator costEstimator = new CostEstimator(stats, sourcesCosts, taskCountEstimator, session);
return node.accept(costEstimator, null);
}

Expand All @@ -85,15 +84,13 @@ private static class CostEstimator
{
private final StatsProvider stats;
private final CostProvider sourcesCosts;
private final TypeProvider types;
private final TaskCountEstimator taskCountEstimator;
private final Session session;

CostEstimator(StatsProvider stats, CostProvider sourcesCosts, TypeProvider types, TaskCountEstimator taskCountEstimator, Session session)
CostEstimator(StatsProvider stats, CostProvider sourcesCosts, TaskCountEstimator taskCountEstimator, Session session)
{
this.stats = requireNonNull(stats, "stats is null");
this.sourcesCosts = requireNonNull(sourcesCosts, "sourcesCosts is null");
this.types = requireNonNull(types, "types is null");
this.taskCountEstimator = requireNonNull(taskCountEstimator, "taskCountEstimator is null");
this.session = requireNonNull(session, "session is null");
}
Expand All @@ -114,7 +111,7 @@ public PlanCostEstimate visitGroupReference(GroupReference node, Void context)
@Override
public PlanCostEstimate visitAssignUniqueId(AssignUniqueId node, Void context)
{
LocalCostEstimate localCost = LocalCostEstimate.ofCpu(getStats(node).getOutputSizeInBytes(ImmutableList.of(node.getIdColumn()), types));
LocalCostEstimate localCost = LocalCostEstimate.ofCpu(getStats(node).getOutputSizeInBytes(ImmutableList.of(node.getIdColumn())));
return costForStreaming(node, localCost);
}

Expand All @@ -131,8 +128,8 @@ public PlanCostEstimate visitRowNumber(RowNumberNode node, Void context)
.build();
}
PlanNodeStatsEstimate stats = getStats(node);
double cpuCost = stats.getOutputSizeInBytes(symbols, types);
double memoryCost = node.getPartitionBy().isEmpty() ? 0 : stats.getOutputSizeInBytes(node.getSource().getOutputSymbols(), types);
double cpuCost = stats.getOutputSizeInBytes(symbols);
double memoryCost = node.getPartitionBy().isEmpty() ? 0 : stats.getOutputSizeInBytes(node.getSource().getOutputSymbols());
LocalCostEstimate localCost = LocalCostEstimate.of(cpuCost, memoryCost, 0);
return costForStreaming(node, localCost);
}
Expand All @@ -147,21 +144,21 @@ public PlanCostEstimate visitOutput(OutputNode node, Void context)
public PlanCostEstimate visitTableScan(TableScanNode node, Void context)
{
// TODO: add network cost, based on input size in bytes? Or let connector provide this cost?
LocalCostEstimate localCost = LocalCostEstimate.ofCpu(getStats(node).getOutputSizeInBytes(node.getOutputSymbols(), types));
LocalCostEstimate localCost = LocalCostEstimate.ofCpu(getStats(node).getOutputSizeInBytes(node.getOutputSymbols()));
return costForSource(node, localCost);
}

@Override
public PlanCostEstimate visitFilter(FilterNode node, Void context)
{
LocalCostEstimate localCost = LocalCostEstimate.ofCpu(getStats(node.getSource()).getOutputSizeInBytes(node.getOutputSymbols(), types));
LocalCostEstimate localCost = LocalCostEstimate.ofCpu(getStats(node.getSource()).getOutputSizeInBytes(node.getOutputSymbols()));
return costForStreaming(node, localCost);
}

@Override
public PlanCostEstimate visitProject(ProjectNode node, Void context)
{
LocalCostEstimate localCost = LocalCostEstimate.ofCpu(getStats(node).getOutputSizeInBytes(node.getOutputSymbols(), types));
LocalCostEstimate localCost = LocalCostEstimate.ofCpu(getStats(node).getOutputSizeInBytes(node.getOutputSymbols()));
return costForStreaming(node, localCost);
}

Expand All @@ -173,8 +170,8 @@ public PlanCostEstimate visitAggregation(AggregationNode node, Void context)
}
PlanNodeStatsEstimate aggregationStats = getStats(node);
PlanNodeStatsEstimate sourceStats = getStats(node.getSource());
double cpuCost = sourceStats.getOutputSizeInBytes(node.getSource().getOutputSymbols(), types);
double memoryCost = aggregationStats.getOutputSizeInBytes(node.getOutputSymbols(), types);
double cpuCost = sourceStats.getOutputSizeInBytes(node.getSource().getOutputSymbols());
double memoryCost = aggregationStats.getOutputSizeInBytes(node.getOutputSymbols());
LocalCostEstimate localCost = LocalCostEstimate.of(cpuCost, memoryCost, 0);
return costForAccumulation(node, localCost);
}
Expand All @@ -197,15 +194,13 @@ private LocalCostEstimate calculateJoinCost(PlanNode join, PlanNode probe, PlanN
probe,
build,
stats,
types,
replicated,
estimatedSourceDistributedTaskCount);
// TODO: Use traits (https://github.com/trinodb/trino/issues/4763) instead, to correctly estimate
// local exchange cost for replicated join in CostCalculatorUsingExchanges#visitExchange
LocalCostEstimate adjustedLocalExchangeCost = adjustReplicatedJoinLocalExchangeCost(
build,
stats,
types,
replicated,
estimatedSourceDistributedTaskCount);
LocalCostEstimate joinOutputCost = calculateJoinOutputCost(join);
Expand All @@ -215,7 +210,7 @@ private LocalCostEstimate calculateJoinCost(PlanNode join, PlanNode probe, PlanN
private LocalCostEstimate calculateJoinOutputCost(PlanNode join)
{
PlanNodeStatsEstimate outputStats = getStats(join);
double joinOutputSize = outputStats.getOutputSizeInBytes(join.getOutputSymbols(), types);
double joinOutputSize = outputStats.getOutputSizeInBytes(join.getOutputSymbols());
return LocalCostEstimate.ofCpu(joinOutputSize);
}

Expand All @@ -227,7 +222,7 @@ public PlanCostEstimate visitExchange(ExchangeNode node, Void context)

private LocalCostEstimate calculateExchangeCost(ExchangeNode node)
{
double inputSizeInBytes = getStats(node).getOutputSizeInBytes(node.getOutputSymbols(), types);
double inputSizeInBytes = getStats(node).getOutputSizeInBytes(node.getOutputSymbols());
switch (node.getScope()) {
case LOCAL:
switch (node.getType()) {
Expand Down Expand Up @@ -297,7 +292,7 @@ public PlanCostEstimate visitLimit(LimitNode node, Void context)
// so proper cost estimation is not that important. Second, since LimitNode can lead to incomplete evaluation
// of the source, true cost estimation should be implemented as a "constraint" enforced on a sub-tree and
// evaluated in context of actual source node type (and their sources).
LocalCostEstimate localCost = LocalCostEstimate.ofCpu(getStats(node).getOutputSizeInBytes(node.getOutputSymbols(), types));
LocalCostEstimate localCost = LocalCostEstimate.ofCpu(getStats(node).getOutputSizeInBytes(node.getOutputSymbols()));
return costForStreaming(node, localCost);
}

Expand Down
Loading