Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ public Pattern<AggregationNode> getPattern()
}

@Override
protected Optional<PlanNodeStatsEstimate> doCalculate(AggregationNode node, StatsProvider statsProvider, Lookup lookup, Session session, TypeProvider types)
protected Optional<PlanNodeStatsEstimate> doCalculate(AggregationNode node, StatsProvider statsProvider, Lookup lookup, Session session, TypeProvider types, TableStatsProvider tableStatsProvider)
{
if (node.getGroupingSetCount() != 1) {
return Optional.empty();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ public Pattern<AssignUniqueId> getPattern()
}

@Override
public Optional<PlanNodeStatsEstimate> calculate(AssignUniqueId assignUniqueId, StatsProvider statsProvider, Lookup lookup, Session session, TypeProvider types)
public Optional<PlanNodeStatsEstimate> calculate(AssignUniqueId assignUniqueId, StatsProvider statsProvider, Lookup lookup, Session session, TypeProvider types, TableStatsProvider tableStatsProvider)
{
PlanNodeStatsEstimate sourceStats = statsProvider.getStats(assignUniqueId.getSource());
return Optional.of(PlanNodeStatsEstimate.buildFrom(sourceStats)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,21 +41,23 @@ public final class CachingStatsProvider
private final Lookup lookup;
private final Session session;
private final TypeProvider types;
private final TableStatsProvider tableStatsProvider;

private final Map<PlanNode, PlanNodeStatsEstimate> cache = new IdentityHashMap<>();

public CachingStatsProvider(StatsCalculator statsCalculator, Session session, TypeProvider types)
public CachingStatsProvider(StatsCalculator statsCalculator, Session session, TypeProvider types, TableStatsProvider tableStatsProvider)
{
this(statsCalculator, Optional.empty(), noLookup(), session, types);
this(statsCalculator, Optional.empty(), noLookup(), session, types, tableStatsProvider);
}

public CachingStatsProvider(StatsCalculator statsCalculator, Optional<Memo> memo, Lookup lookup, Session session, TypeProvider types)
public CachingStatsProvider(StatsCalculator statsCalculator, Optional<Memo> memo, Lookup lookup, Session session, TypeProvider types, TableStatsProvider tableStatsProvider)
{
this.statsCalculator = requireNonNull(statsCalculator, "statsCalculator is null");
this.memo = requireNonNull(memo, "memo is null");
this.lookup = requireNonNull(lookup, "lookup is null");
this.session = requireNonNull(session, "session is null");
this.types = requireNonNull(types, "types is null");
this.tableStatsProvider = requireNonNull(tableStatsProvider, "tableStatsProvider is null");
}

@Override
Expand All @@ -77,7 +79,7 @@ public PlanNodeStatsEstimate getStats(PlanNode node)
return stats;
}

stats = statsCalculator.calculateStats(node, this, lookup, session, types);
stats = statsCalculator.calculateStats(node, this, lookup, session, types, tableStatsProvider);
verify(cache.put(node, stats) == null, "Stats already set");
return stats;
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.trino.cost;

import io.trino.Session;
import io.trino.metadata.Metadata;
import io.trino.metadata.TableHandle;
import io.trino.spi.statistics.TableStatistics;

import java.util.HashMap;
import java.util.Map;

import static java.util.Objects.requireNonNull;

public class CachingTableStatsProvider
implements TableStatsProvider
{
private final Metadata metadata;
private final Map<TableHandle, TableStatistics> cache = new HashMap<>();
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A connector creates new TableHandles during various ConnectorMetadata.apply* calls.
A table handle may become "seen" in the plan and then be discarded, making it obsolete.

I think we should use weak keys here. Otherwise we need to size this as a regular cache.
(WeakHashMap provides weak keys with equality-based lookup, so I'd recommend that.)

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't fully understand this comment. My understanding is that CachingTableStatsProvider is created per query and its lifecycle is within query planning, so the cache here can be GCed after query planning. What issue did you see with it?


public CachingTableStatsProvider(Metadata metadata)
{
this.metadata = requireNonNull(metadata, "metadata is null");
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Squash "Address comments" with respective commits.

(You did rebase anyway, and split the commit into two in this PR, so keeping some changes as a fixup commit doesn't make review easier)

}

/**
 * Returns the statistics for {@code tableHandle}, serving them from the local cache when available.
 * <p>
 * On a cache miss the statistics are requested from {@code metadata} and stored under the handle,
 * so later calls with an equal handle do not call {@code metadata} again. A {@code null} result
 * from {@code metadata} is stored as well, but is indistinguishable from a miss on the next call.
 *
 * @param session the session passed through to {@code metadata} on a cache miss
 * @param tableHandle the table whose statistics are requested; also the cache key
 * @return the cached or freshly fetched statistics
 */
@Override
public TableStatistics getTableStatistics(Session session, TableHandle tableHandle)
{
    TableStatistics cached = cache.get(tableHandle);
    if (cached != null) {
        return cached;
    }

    // Miss: fetch once from metadata and remember the answer for this handle.
    TableStatistics fetched = metadata.getTableStatistics(session, tableHandle);
    cache.put(tableHandle, fetched);
    return fetched;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -65,30 +65,30 @@ private Stream<Rule<?>> getCandidates(PlanNode node)
}

@Override
public PlanNodeStatsEstimate calculateStats(PlanNode node, StatsProvider sourceStats, Lookup lookup, Session session, TypeProvider types)
public PlanNodeStatsEstimate calculateStats(PlanNode node, StatsProvider sourceStats, Lookup lookup, Session session, TypeProvider types, TableStatsProvider tableStatsProvider)
{
Iterator<Rule<?>> ruleIterator = getCandidates(node).iterator();
while (ruleIterator.hasNext()) {
Rule<?> rule = ruleIterator.next();
Optional<PlanNodeStatsEstimate> calculatedStats = calculateStats(rule, node, sourceStats, lookup, session, types);
Optional<PlanNodeStatsEstimate> calculatedStats = calculateStats(rule, node, sourceStats, lookup, session, types, tableStatsProvider);
if (calculatedStats.isPresent()) {
return calculatedStats.get();
}
}
return PlanNodeStatsEstimate.unknown();
}

private static <T extends PlanNode> Optional<PlanNodeStatsEstimate> calculateStats(Rule<T> rule, PlanNode node, StatsProvider sourceStats, Lookup lookup, Session session, TypeProvider types)
private static <T extends PlanNode> Optional<PlanNodeStatsEstimate> calculateStats(Rule<T> rule, PlanNode node, StatsProvider sourceStats, Lookup lookup, Session session, TypeProvider types, TableStatsProvider tableStatsProvider)
{
@SuppressWarnings("unchecked")
T typedNode = (T) node;
return rule.calculate(typedNode, sourceStats, lookup, session, types);
return rule.calculate(typedNode, sourceStats, lookup, session, types, tableStatsProvider);
}

public interface Rule<T extends PlanNode>
{
Pattern<T> getPattern();

Optional<PlanNodeStatsEstimate> calculate(T node, StatsProvider sourceStats, Lookup lookup, Session session, TypeProvider types);
Optional<PlanNodeStatsEstimate> calculate(T node, StatsProvider sourceStats, Lookup lookup, Session session, TypeProvider types, TableStatsProvider tableStatsProvider);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ public Pattern<DistinctLimitNode> getPattern()
}

@Override
protected Optional<PlanNodeStatsEstimate> doCalculate(DistinctLimitNode node, StatsProvider statsProvider, Lookup lookup, Session session, TypeProvider types)
protected Optional<PlanNodeStatsEstimate> doCalculate(DistinctLimitNode node, StatsProvider statsProvider, Lookup lookup, Session session, TypeProvider types, TableStatsProvider tableStatsProvider)
{
if (node.isPartial()) {
return Optional.empty();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ public Pattern<EnforceSingleRowNode> getPattern()
}

@Override
protected Optional<PlanNodeStatsEstimate> doCalculate(EnforceSingleRowNode node, StatsProvider sourceStats, Lookup lookup, Session session, TypeProvider types)
protected Optional<PlanNodeStatsEstimate> doCalculate(EnforceSingleRowNode node, StatsProvider sourceStats, Lookup lookup, Session session, TypeProvider types, TableStatsProvider tableStatsProvider)
{
return Optional.of(PlanNodeStatsEstimate.buildFrom(sourceStats.getStats(node.getSource()))
.setOutputRowCount(1)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ public Pattern<ExchangeNode> getPattern()
}

@Override
protected Optional<PlanNodeStatsEstimate> doCalculate(ExchangeNode node, StatsProvider statsProvider, Lookup lookup, Session session, TypeProvider types)
protected Optional<PlanNodeStatsEstimate> doCalculate(ExchangeNode node, StatsProvider statsProvider, Lookup lookup, Session session, TypeProvider types, TableStatsProvider tableStatsProvider)
{
Optional<PlanNodeStatsEstimate> estimate = Optional.empty();
for (int i = 0; i < node.getSources().size(); i++) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ public Pattern<FilterNode> getPattern()
}

@Override
protected Optional<PlanNodeStatsEstimate> doCalculate(FilterNode node, StatsProvider sourceStats, Lookup lookup, Session session, TypeProvider types)
protected Optional<PlanNodeStatsEstimate> doCalculate(FilterNode node, StatsProvider sourceStats, Lookup lookup, Session session, TypeProvider types, TableStatsProvider tableStatsProvider)
{
if (!isNonEstimatablePredicateApproximationEnabled(session)) {
return Optional.empty();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ public Pattern<FilterNode> getPattern()
}

@Override
public Optional<PlanNodeStatsEstimate> doCalculate(FilterNode node, StatsProvider statsProvider, Lookup lookup, Session session, TypeProvider types)
public Optional<PlanNodeStatsEstimate> doCalculate(FilterNode node, StatsProvider statsProvider, Lookup lookup, Session session, TypeProvider types, TableStatsProvider tableStatsProvider)
{
PlanNodeStatsEstimate sourceStats = statsProvider.getStats(node.getSource());
PlanNodeStatsEstimate estimate = filterStatsCalculator.filterStats(sourceStats, node.getPredicate(), session, types);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ public Pattern<JoinNode> getPattern()
}

@Override
protected Optional<PlanNodeStatsEstimate> doCalculate(JoinNode node, StatsProvider sourceStats, Lookup lookup, Session session, TypeProvider types)
protected Optional<PlanNodeStatsEstimate> doCalculate(JoinNode node, StatsProvider sourceStats, Lookup lookup, Session session, TypeProvider types, TableStatsProvider tableStatsProvider)
{
PlanNodeStatsEstimate leftStats = sourceStats.getStats(node.getLeft());
PlanNodeStatsEstimate rightStats = sourceStats.getStats(node.getRight());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ public Pattern<LimitNode> getPattern()
}

@Override
protected Optional<PlanNodeStatsEstimate> doCalculate(LimitNode node, StatsProvider statsProvider, Lookup lookup, Session session, TypeProvider types)
protected Optional<PlanNodeStatsEstimate> doCalculate(LimitNode node, StatsProvider statsProvider, Lookup lookup, Session session, TypeProvider types, TableStatsProvider tableStatsProvider)
{
PlanNodeStatsEstimate sourceStats = statsProvider.getStats(node.getSource());
if (sourceStats.getOutputRowCount() <= node.getCount()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ public Pattern<OutputNode> getPattern()
}

@Override
public Optional<PlanNodeStatsEstimate> calculate(OutputNode node, StatsProvider sourceStats, Lookup lookup, Session session, TypeProvider types)
public Optional<PlanNodeStatsEstimate> calculate(OutputNode node, StatsProvider sourceStats, Lookup lookup, Session session, TypeProvider types, TableStatsProvider tableStatsProvider)
{
return Optional.of(sourceStats.getStats(node.getSource()));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ public Pattern<ProjectNode> getPattern()
}

@Override
protected Optional<PlanNodeStatsEstimate> doCalculate(ProjectNode node, StatsProvider statsProvider, Lookup lookup, Session session, TypeProvider types)
protected Optional<PlanNodeStatsEstimate> doCalculate(ProjectNode node, StatsProvider statsProvider, Lookup lookup, Session session, TypeProvider types, TableStatsProvider tableStatsProvider)
{
PlanNodeStatsEstimate sourceStats = statsProvider.getStats(node.getSource());
PlanNodeStatsEstimate.Builder calculatedStats = PlanNodeStatsEstimate.builder()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ public Pattern<RowNumberNode> getPattern()
}

@Override
public Optional<PlanNodeStatsEstimate> doCalculate(RowNumberNode node, StatsProvider statsProvider, Lookup lookup, Session session, TypeProvider types)
public Optional<PlanNodeStatsEstimate> doCalculate(RowNumberNode node, StatsProvider statsProvider, Lookup lookup, Session session, TypeProvider types, TableStatsProvider tableStatsProvider)
{
PlanNodeStatsEstimate sourceStats = statsProvider.getStats(node.getSource());
if (sourceStats.isOutputRowCountUnknown()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ public Pattern<SampleNode> getPattern()
}

@Override
protected Optional<PlanNodeStatsEstimate> doCalculate(SampleNode node, StatsProvider statsProvider, Lookup lookup, Session session, TypeProvider types)
protected Optional<PlanNodeStatsEstimate> doCalculate(SampleNode node, StatsProvider statsProvider, Lookup lookup, Session session, TypeProvider types, TableStatsProvider tableStatsProvider)
{
PlanNodeStatsEstimate sourceStats = statsProvider.getStats(node.getSource());
PlanNodeStatsEstimate calculatedStats = sourceStats.mapOutputRowCount(outputRowCount -> outputRowCount * node.getSampleRatio());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ public Pattern<SemiJoinNode> getPattern()
}

@Override
public Optional<PlanNodeStatsEstimate> calculate(SemiJoinNode node, StatsProvider statsProvider, Lookup lookup, Session session, TypeProvider types)
public Optional<PlanNodeStatsEstimate> calculate(SemiJoinNode node, StatsProvider statsProvider, Lookup lookup, Session session, TypeProvider types, TableStatsProvider tableStatsProvider)
{
PlanNodeStatsEstimate sourceStats = statsProvider.getStats(node.getSource());

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ public Pattern<FilterNode> getPattern()
}

@Override
protected Optional<PlanNodeStatsEstimate> doCalculate(FilterNode node, StatsProvider sourceStats, Lookup lookup, Session session, TypeProvider types)
protected Optional<PlanNodeStatsEstimate> doCalculate(FilterNode node, StatsProvider sourceStats, Lookup lookup, Session session, TypeProvider types, TableStatsProvider tableStatsProvider)
{
PlanNode nodeSource = lookup.resolve(node.getSource());
SemiJoinNode semiJoinNode;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,11 +34,11 @@ protected SimpleStatsRule(StatsNormalizer normalizer)
}

@Override
public final Optional<PlanNodeStatsEstimate> calculate(T node, StatsProvider sourceStats, Lookup lookup, Session session, TypeProvider types)
public final Optional<PlanNodeStatsEstimate> calculate(T node, StatsProvider sourceStats, Lookup lookup, Session session, TypeProvider types, TableStatsProvider tableStatsProvider)
{
return doCalculate(node, sourceStats, lookup, session, types)
return doCalculate(node, sourceStats, lookup, session, types, tableStatsProvider)
.map(estimate -> normalizer.normalize(estimate, node.getOutputSymbols(), types));
}

protected abstract Optional<PlanNodeStatsEstimate> doCalculate(T node, StatsProvider sourceStats, Lookup lookup, Session session, TypeProvider types);
protected abstract Optional<PlanNodeStatsEstimate> doCalculate(T node, StatsProvider sourceStats, Lookup lookup, Session session, TypeProvider types, TableStatsProvider tableStatsProvider);
}
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ public Pattern<SortNode> getPattern()
}

@Override
public Optional<PlanNodeStatsEstimate> calculate(SortNode node, StatsProvider sourceStats, Lookup lookup, Session session, TypeProvider types)
public Optional<PlanNodeStatsEstimate> calculate(SortNode node, StatsProvider sourceStats, Lookup lookup, Session session, TypeProvider types, TableStatsProvider tableStatsProvider)
{
return Optional.of(sourceStats.getStats(node.getSource()));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ public SpatialJoinStatsRule(FilterStatsCalculator statsCalculator, StatsNormaliz
}

@Override
protected Optional<PlanNodeStatsEstimate> doCalculate(SpatialJoinNode node, StatsProvider sourceStats, Lookup lookup, Session session, TypeProvider types)
protected Optional<PlanNodeStatsEstimate> doCalculate(SpatialJoinNode node, StatsProvider sourceStats, Lookup lookup, Session session, TypeProvider types, TableStatsProvider tableStatsProvider)
{
PlanNodeStatsEstimate leftStats = sourceStats.getStats(node.getLeft());
PlanNodeStatsEstimate rightStats = sourceStats.getStats(node.getRight());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,16 +28,18 @@ public interface StatsCalculator
* @param sourceStats The stats provider for any child nodes' stats, if needed to compute stats for the {@code node}
* @param lookup Lookup to be used when resolving source nodes, allowing stats calculation to work within {@link IterativeOptimizer}
* @param types The type provider for all symbols in the scope.
* @param tableStatsProvider The table stats provider.
*/
PlanNodeStatsEstimate calculateStats(
PlanNode node,
StatsProvider sourceStats,
Lookup lookup,
Session session,
TypeProvider types);
TypeProvider types,
TableStatsProvider tableStatsProvider);
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do we need a new entity here, the tableStatsProvider?
Could StatsProvider sourceStats do the job?

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

StatsProvider sourceStats is called to get the child plan nodes' stats.

TableStatsProvider tableStatsProvider is used when calculating the "ultimate" source stats: table scan stats. In fact, all this PR does is wire in a CachingTableStatsProvider for TableScanStatsRule to use.

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You can easily imagine a CachingStatsProvider that caches stats on a per-PlanNode basis. This would be a generally useful implementation, potentially allowing us to cut down some stats calculation cost.

The currently existing CachingStatsProvider caches stats in the Memo's group. This is useful as well -- the group contains alternative plans (currently always exactly one at a time) that produce the same relation. The same relation implies that once the stats are calculated, they are applicable to all group members.

Now, these two concepts are different, but not exclusionary:

  • we could have "L2" cache that's on per-PlanNode basis
  • then "L1" cache that's based on Memo group

Now the question is whether the per-PlanNode basis and the per-TableHandle basis are significantly different.
I think they are not. I don't think we ever have a case where two TableScanNodes have the same TableHandle.

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Now the question is, whether per-PlanNode basis and per-TableHandle basis are significantly different.

Actually they would be different.

Plan nodes have no equality; they compare by identity. Upon each exit from IterativeOptimizer (exit from Memo), a fully new plan structure is created. This can be improved a bit (e.g., by producing a new plan only if anything changed), but it would still break identity-based caching, except for root nodes (table scans). Thus, the solution would look more generic, but would not actually be.


static StatsCalculator noopStatsCalculator()
{
return (node, sourceStats, lookup, ignore, types) -> PlanNodeStatsEstimate.unknown();
return (node, sourceStats, lookup, ignore, types, tableStatsProvider) -> PlanNodeStatsEstimate.unknown();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ public List<ComposableStatsCalculator.Rule<?>> get()
ImmutableList.Builder<ComposableStatsCalculator.Rule<?>> rules = ImmutableList.builder();

rules.add(new OutputStatsRule());
rules.add(new TableScanStatsRule(plannerContext.getMetadata(), normalizer));
rules.add(new TableScanStatsRule(normalizer));
rules.add(new SimpleFilterProjectSemiJoinStatsRule(plannerContext.getMetadata(), normalizer, filterStatsCalculator)); // this must be before FilterStatsRule
rules.add(new FilterProjectAggregationStatsRule(normalizer, filterStatsCalculator)); // this must be before FilterStatsRule
rules.add(new FilterStatsRule(normalizer, filterStatsCalculator));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@

import io.trino.Session;
import io.trino.matching.Pattern;
import io.trino.metadata.Metadata;
import io.trino.spi.connector.ColumnHandle;
import io.trino.spi.statistics.ColumnStatistics;
import io.trino.spi.statistics.TableStatistics;
Expand All @@ -40,12 +39,9 @@ public class TableScanStatsRule
{
private static final Pattern<TableScanNode> PATTERN = tableScan();

private final Metadata metadata;

public TableScanStatsRule(Metadata metadata, StatsNormalizer normalizer)
public TableScanStatsRule(StatsNormalizer normalizer)
{
super(normalizer); // Use stats normalization since connector can return inconsistent stats values
this.metadata = requireNonNull(metadata, "metadata is null");
}

@Override
Expand All @@ -55,13 +51,13 @@ public Pattern<TableScanNode> getPattern()
}

@Override
protected Optional<PlanNodeStatsEstimate> doCalculate(TableScanNode node, StatsProvider sourceStats, Lookup lookup, Session session, TypeProvider types)
protected Optional<PlanNodeStatsEstimate> doCalculate(TableScanNode node, StatsProvider sourceStats, Lookup lookup, Session session, TypeProvider types, TableStatsProvider tableStatsProvider)
{
if (isStatisticsPrecalculationForPushdownEnabled(session) && node.getStatistics().isPresent()) {
return node.getStatistics();
}

TableStatistics tableStatistics = metadata.getTableStatistics(session, node.getTable());
TableStatistics tableStatistics = tableStatsProvider.getTableStatistics(session, node.getTable());

Map<Symbol, SymbolStatsEstimate> outputSymbolStats = new HashMap<>();

Expand Down
Loading