diff --git a/core/trino-main/src/main/java/io/trino/cost/AggregationStatsRule.java b/core/trino-main/src/main/java/io/trino/cost/AggregationStatsRule.java index f387b2c7ed0a..f0c2c46c3e59 100644 --- a/core/trino-main/src/main/java/io/trino/cost/AggregationStatsRule.java +++ b/core/trino-main/src/main/java/io/trino/cost/AggregationStatsRule.java @@ -47,7 +47,7 @@ public Pattern getPattern() } @Override - protected Optional doCalculate(AggregationNode node, StatsProvider statsProvider, Lookup lookup, Session session, TypeProvider types) + protected Optional doCalculate(AggregationNode node, StatsProvider statsProvider, Lookup lookup, Session session, TypeProvider types, TableStatsProvider tableStatsProvider) { if (node.getGroupingSetCount() != 1) { return Optional.empty(); diff --git a/core/trino-main/src/main/java/io/trino/cost/AssignUniqueIdStatsRule.java b/core/trino-main/src/main/java/io/trino/cost/AssignUniqueIdStatsRule.java index 1c8352afb4f4..99b8dfd60c15 100644 --- a/core/trino-main/src/main/java/io/trino/cost/AssignUniqueIdStatsRule.java +++ b/core/trino-main/src/main/java/io/trino/cost/AssignUniqueIdStatsRule.java @@ -36,7 +36,7 @@ public Pattern getPattern() } @Override - public Optional calculate(AssignUniqueId assignUniqueId, StatsProvider statsProvider, Lookup lookup, Session session, TypeProvider types) + public Optional calculate(AssignUniqueId assignUniqueId, StatsProvider statsProvider, Lookup lookup, Session session, TypeProvider types, TableStatsProvider tableStatsProvider) { PlanNodeStatsEstimate sourceStats = statsProvider.getStats(assignUniqueId.getSource()); return Optional.of(PlanNodeStatsEstimate.buildFrom(sourceStats) diff --git a/core/trino-main/src/main/java/io/trino/cost/CachingStatsProvider.java b/core/trino-main/src/main/java/io/trino/cost/CachingStatsProvider.java index 703b0f652774..bce3a946d38e 100644 --- a/core/trino-main/src/main/java/io/trino/cost/CachingStatsProvider.java +++ b/core/trino-main/src/main/java/io/trino/cost/CachingStatsProvider.java @@ -41,21 +41,23 @@ public final class CachingStatsProvider private final Lookup lookup; private final Session session; private final TypeProvider types; + private final TableStatsProvider tableStatsProvider; private final Map cache = new IdentityHashMap<>(); - public CachingStatsProvider(StatsCalculator statsCalculator, Session session, TypeProvider types) + public CachingStatsProvider(StatsCalculator statsCalculator, Session session, TypeProvider types, TableStatsProvider tableStatsProvider) { - this(statsCalculator, Optional.empty(), noLookup(), session, types); + this(statsCalculator, Optional.empty(), noLookup(), session, types, tableStatsProvider); } - public CachingStatsProvider(StatsCalculator statsCalculator, Optional memo, Lookup lookup, Session session, TypeProvider types) + public CachingStatsProvider(StatsCalculator statsCalculator, Optional memo, Lookup lookup, Session session, TypeProvider types, TableStatsProvider tableStatsProvider) { this.statsCalculator = requireNonNull(statsCalculator, "statsCalculator is null"); this.memo = requireNonNull(memo, "memo is null"); this.lookup = requireNonNull(lookup, "lookup is null"); this.session = requireNonNull(session, "session is null"); this.types = requireNonNull(types, "types is null"); + this.tableStatsProvider = requireNonNull(tableStatsProvider, "tableStatsProvider is null"); } @Override @@ -77,7 +79,7 @@ public PlanNodeStatsEstimate getStats(PlanNode node) return stats; } - stats = statsCalculator.calculateStats(node, this, lookup, session, types); + stats = statsCalculator.calculateStats(node, this, lookup, session, types, tableStatsProvider); verify(cache.put(node, stats) == null, "Stats already set"); return stats; } diff --git a/core/trino-main/src/main/java/io/trino/cost/CachingTableStatsProvider.java b/core/trino-main/src/main/java/io/trino/cost/CachingTableStatsProvider.java new file mode 100644 index 000000000000..8049ee1759c1 --- /dev/null +++ b/core/trino-main/src/main/java/io/trino/cost/CachingTableStatsProvider.java @@ -0,0 +1,50 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.cost; + +import io.trino.Session; +import io.trino.metadata.Metadata; +import io.trino.metadata.TableHandle; +import io.trino.spi.statistics.TableStatistics; + +import java.util.Map; +import java.util.WeakHashMap; + +import static java.util.Objects.requireNonNull; + +public class CachingTableStatsProvider + implements TableStatsProvider +{ + private final Metadata metadata; + private final Session session; + + private final Map cache = new WeakHashMap<>(); + + public CachingTableStatsProvider(Metadata metadata, Session session) + { + this.metadata = requireNonNull(metadata, "metadata is null"); + this.session = requireNonNull(session, "session is null"); + } + + @Override + public TableStatistics getTableStatistics(TableHandle tableHandle) + { + TableStatistics stats = cache.get(tableHandle); + if (stats == null) { + stats = metadata.getTableStatistics(session, tableHandle); + cache.put(tableHandle, stats); + } + return stats; + } +} diff --git a/core/trino-main/src/main/java/io/trino/cost/ComposableStatsCalculator.java b/core/trino-main/src/main/java/io/trino/cost/ComposableStatsCalculator.java index ab63f4045c12..151cdc41494f 100644 --- a/core/trino-main/src/main/java/io/trino/cost/ComposableStatsCalculator.java +++ b/core/trino-main/src/main/java/io/trino/cost/ComposableStatsCalculator.java @@ -65,12 +65,12 @@ private Stream> getCandidates(PlanNode node) } @Override - public PlanNodeStatsEstimate calculateStats(PlanNode node, StatsProvider sourceStats, Lookup lookup, Session session, TypeProvider types) + public PlanNodeStatsEstimate calculateStats(PlanNode node, StatsProvider sourceStats, Lookup lookup, Session session, TypeProvider types, TableStatsProvider tableStatsProvider) { Iterator> ruleIterator = getCandidates(node).iterator(); while (ruleIterator.hasNext()) { Rule rule = ruleIterator.next(); - Optional calculatedStats = calculateStats(rule, node, sourceStats, lookup, session, types); + Optional calculatedStats = calculateStats(rule, node, sourceStats, lookup, session, types, tableStatsProvider); if (calculatedStats.isPresent()) { return calculatedStats.get(); } @@ -78,17 +78,17 @@ public PlanNodeStatsEstimate calculateStats(PlanNode node, StatsProvider sourceS return PlanNodeStatsEstimate.unknown(); } - private static Optional calculateStats(Rule rule, PlanNode node, StatsProvider sourceStats, Lookup lookup, Session session, TypeProvider types) + private static Optional calculateStats(Rule rule, PlanNode node, StatsProvider sourceStats, Lookup lookup, Session session, TypeProvider types, TableStatsProvider tableStatsProvider) { @SuppressWarnings("unchecked") T typedNode = (T) node; - return rule.calculate(typedNode, sourceStats, lookup, session, types); + return rule.calculate(typedNode, sourceStats, lookup, session, types, tableStatsProvider); } public interface Rule { Pattern getPattern(); - Optional calculate(T node, StatsProvider sourceStats, Lookup lookup, Session session, TypeProvider types); + Optional calculate(T node, StatsProvider sourceStats, Lookup lookup, Session session, TypeProvider types, TableStatsProvider tableStatsProvider); } } diff --git a/core/trino-main/src/main/java/io/trino/cost/DistinctLimitStatsRule.java b/core/trino-main/src/main/java/io/trino/cost/DistinctLimitStatsRule.java index 1c6a3eec65ab..14f5df5f22a1 100644 --- a/core/trino-main/src/main/java/io/trino/cost/DistinctLimitStatsRule.java +++ b/core/trino-main/src/main/java/io/trino/cost/DistinctLimitStatsRule.java @@ -42,7 +42,7 @@ public Pattern getPattern() } @Override - protected Optional doCalculate(DistinctLimitNode node, StatsProvider statsProvider, Lookup lookup, Session session, TypeProvider types) + protected Optional doCalculate(DistinctLimitNode node, StatsProvider statsProvider, Lookup lookup, Session session, TypeProvider types, TableStatsProvider tableStatsProvider) { if (node.isPartial()) { return Optional.empty(); diff --git a/core/trino-main/src/main/java/io/trino/cost/EnforceSingleRowStatsRule.java b/core/trino-main/src/main/java/io/trino/cost/EnforceSingleRowStatsRule.java index 3e461c76f283..7d23775c10ba 100644 --- a/core/trino-main/src/main/java/io/trino/cost/EnforceSingleRowStatsRule.java +++ b/core/trino-main/src/main/java/io/trino/cost/EnforceSingleRowStatsRule.java @@ -40,7 +40,7 @@ public Pattern getPattern() } @Override - protected Optional doCalculate(EnforceSingleRowNode node, StatsProvider sourceStats, Lookup lookup, Session session, TypeProvider types) + protected Optional doCalculate(EnforceSingleRowNode node, StatsProvider sourceStats, Lookup lookup, Session session, TypeProvider types, TableStatsProvider tableStatsProvider) { return Optional.of(PlanNodeStatsEstimate.buildFrom(sourceStats.getStats(node.getSource())) .setOutputRowCount(1) diff --git a/core/trino-main/src/main/java/io/trino/cost/ExchangeStatsRule.java b/core/trino-main/src/main/java/io/trino/cost/ExchangeStatsRule.java index 95bc658b0358..d423d48b307a 100644 --- a/core/trino-main/src/main/java/io/trino/cost/ExchangeStatsRule.java +++ b/core/trino-main/src/main/java/io/trino/cost/ExchangeStatsRule.java @@ -46,7 +46,7 @@ public Pattern getPattern() } @Override - protected Optional doCalculate(ExchangeNode node, StatsProvider statsProvider, Lookup lookup, Session session, TypeProvider types) + protected Optional doCalculate(ExchangeNode node, StatsProvider statsProvider, Lookup lookup, Session session, TypeProvider types, TableStatsProvider tableStatsProvider) { Optional estimate = Optional.empty(); for (int i = 0; i < node.getSources().size(); i++) { diff --git a/core/trino-main/src/main/java/io/trino/cost/FilterProjectAggregationStatsRule.java b/core/trino-main/src/main/java/io/trino/cost/FilterProjectAggregationStatsRule.java index e4e565407c44..15090721f42a 100644 --- a/core/trino-main/src/main/java/io/trino/cost/FilterProjectAggregationStatsRule.java +++ b/core/trino-main/src/main/java/io/trino/cost/FilterProjectAggregationStatsRule.java @@ -56,7 +56,7 @@ public Pattern getPattern() } @Override - protected Optional doCalculate(FilterNode node, StatsProvider sourceStats, Lookup lookup, Session session, TypeProvider types) + protected Optional doCalculate(FilterNode node, StatsProvider sourceStats, Lookup lookup, Session session, TypeProvider types, TableStatsProvider tableStatsProvider) { if (!isNonEstimatablePredicateApproximationEnabled(session)) { return Optional.empty(); diff --git a/core/trino-main/src/main/java/io/trino/cost/FilterStatsRule.java b/core/trino-main/src/main/java/io/trino/cost/FilterStatsRule.java index 17b2d7e14027..c12ac1878565 100644 --- a/core/trino-main/src/main/java/io/trino/cost/FilterStatsRule.java +++ b/core/trino-main/src/main/java/io/trino/cost/FilterStatsRule.java @@ -45,7 +45,7 @@ public Pattern getPattern() } @Override - public Optional doCalculate(FilterNode node, StatsProvider statsProvider, Lookup lookup, Session session, TypeProvider types) + public Optional doCalculate(FilterNode node, StatsProvider statsProvider, Lookup lookup, Session session, TypeProvider types, TableStatsProvider tableStatsProvider) { PlanNodeStatsEstimate sourceStats = statsProvider.getStats(node.getSource()); PlanNodeStatsEstimate estimate = filterStatsCalculator.filterStats(sourceStats, node.getPredicate(), session, types); diff --git a/core/trino-main/src/main/java/io/trino/cost/JoinStatsRule.java b/core/trino-main/src/main/java/io/trino/cost/JoinStatsRule.java index 19495492d9b8..70e182b3e453 100644 --- a/core/trino-main/src/main/java/io/trino/cost/JoinStatsRule.java +++ b/core/trino-main/src/main/java/io/trino/cost/JoinStatsRule.java @@ -78,7 +78,7 @@ public Pattern getPattern() } @Override - protected Optional doCalculate(JoinNode node, StatsProvider sourceStats, Lookup lookup, Session session, TypeProvider types) + protected Optional doCalculate(JoinNode node, StatsProvider sourceStats, Lookup lookup, Session session, TypeProvider types, TableStatsProvider tableStatsProvider) { PlanNodeStatsEstimate leftStats = sourceStats.getStats(node.getLeft()); PlanNodeStatsEstimate rightStats = sourceStats.getStats(node.getRight()); diff --git a/core/trino-main/src/main/java/io/trino/cost/LimitStatsRule.java b/core/trino-main/src/main/java/io/trino/cost/LimitStatsRule.java index 135a061a2712..fd3e46591293 100644 --- a/core/trino-main/src/main/java/io/trino/cost/LimitStatsRule.java +++ b/core/trino-main/src/main/java/io/trino/cost/LimitStatsRule.java @@ -40,7 +40,7 @@ public Pattern getPattern() } @Override - protected Optional doCalculate(LimitNode node, StatsProvider statsProvider, Lookup lookup, Session session, TypeProvider types) + protected Optional doCalculate(LimitNode node, StatsProvider statsProvider, Lookup lookup, Session session, TypeProvider types, TableStatsProvider tableStatsProvider) { PlanNodeStatsEstimate sourceStats = statsProvider.getStats(node.getSource()); if (sourceStats.getOutputRowCount() <= node.getCount()) { diff --git a/core/trino-main/src/main/java/io/trino/cost/OutputStatsRule.java b/core/trino-main/src/main/java/io/trino/cost/OutputStatsRule.java index 39473f04a304..7997910b95e2 100644 --- a/core/trino-main/src/main/java/io/trino/cost/OutputStatsRule.java +++ b/core/trino-main/src/main/java/io/trino/cost/OutputStatsRule.java @@ -36,7 +36,7 @@ public Pattern getPattern() } @Override - public Optional calculate(OutputNode node, StatsProvider sourceStats, Lookup lookup, Session session, TypeProvider types) + public Optional calculate(OutputNode node, StatsProvider sourceStats, Lookup lookup, Session session, TypeProvider types, TableStatsProvider tableStatsProvider) { return Optional.of(sourceStats.getStats(node.getSource())); } diff --git a/core/trino-main/src/main/java/io/trino/cost/ProjectStatsRule.java b/core/trino-main/src/main/java/io/trino/cost/ProjectStatsRule.java index 27177bfd78af..3433978da3ab 100644 --- a/core/trino-main/src/main/java/io/trino/cost/ProjectStatsRule.java +++ b/core/trino-main/src/main/java/io/trino/cost/ProjectStatsRule.java @@ -47,7 +47,7 @@ public Pattern getPattern() } @Override - protected Optional doCalculate(ProjectNode node, StatsProvider statsProvider, Lookup lookup, Session session, TypeProvider types) + protected Optional doCalculate(ProjectNode node, StatsProvider statsProvider, Lookup lookup, Session session, TypeProvider types, TableStatsProvider tableStatsProvider) { PlanNodeStatsEstimate sourceStats = statsProvider.getStats(node.getSource()); PlanNodeStatsEstimate.Builder calculatedStats = PlanNodeStatsEstimate.builder() diff --git a/core/trino-main/src/main/java/io/trino/cost/RowNumberStatsRule.java b/core/trino-main/src/main/java/io/trino/cost/RowNumberStatsRule.java index 5999a59213a6..e7207498d21a 100644 --- a/core/trino-main/src/main/java/io/trino/cost/RowNumberStatsRule.java +++ b/core/trino-main/src/main/java/io/trino/cost/RowNumberStatsRule.java @@ -44,7 +44,7 @@ public Pattern getPattern() } @Override - public Optional doCalculate(RowNumberNode node, StatsProvider statsProvider, Lookup lookup, Session session, TypeProvider types) + public Optional doCalculate(RowNumberNode node, StatsProvider statsProvider, Lookup lookup, Session session, TypeProvider types, TableStatsProvider tableStatsProvider) { PlanNodeStatsEstimate sourceStats = statsProvider.getStats(node.getSource()); if (sourceStats.isOutputRowCountUnknown()) { diff --git a/core/trino-main/src/main/java/io/trino/cost/SampleStatsRule.java b/core/trino-main/src/main/java/io/trino/cost/SampleStatsRule.java index 9c7a30467578..72c43d010f4c 100644 --- a/core/trino-main/src/main/java/io/trino/cost/SampleStatsRule.java +++ b/core/trino-main/src/main/java/io/trino/cost/SampleStatsRule.java @@ -40,7 +40,7 @@ public Pattern getPattern() } @Override - protected Optional doCalculate(SampleNode node, StatsProvider statsProvider, Lookup lookup, Session session, TypeProvider types) + protected Optional doCalculate(SampleNode node, StatsProvider statsProvider, Lookup lookup, Session session, TypeProvider types, TableStatsProvider tableStatsProvider) { PlanNodeStatsEstimate sourceStats = statsProvider.getStats(node.getSource()); PlanNodeStatsEstimate calculatedStats = sourceStats.mapOutputRowCount(outputRowCount -> outputRowCount * node.getSampleRatio()); diff --git a/core/trino-main/src/main/java/io/trino/cost/SemiJoinStatsRule.java b/core/trino-main/src/main/java/io/trino/cost/SemiJoinStatsRule.java index 606a5cf80a6d..18dcf8cec7a0 100644 --- a/core/trino-main/src/main/java/io/trino/cost/SemiJoinStatsRule.java +++ b/core/trino-main/src/main/java/io/trino/cost/SemiJoinStatsRule.java @@ -36,7 +36,7 @@ public Pattern getPattern() } @Override - public Optional calculate(SemiJoinNode node, StatsProvider statsProvider, Lookup lookup, Session session, TypeProvider types) + public Optional calculate(SemiJoinNode node, StatsProvider statsProvider, Lookup lookup, Session session, TypeProvider types, TableStatsProvider tableStatsProvider) { PlanNodeStatsEstimate sourceStats = statsProvider.getStats(node.getSource()); diff --git a/core/trino-main/src/main/java/io/trino/cost/SimpleFilterProjectSemiJoinStatsRule.java b/core/trino-main/src/main/java/io/trino/cost/SimpleFilterProjectSemiJoinStatsRule.java index e06a9f192046..dd48bf2e2b0f 100644 --- a/core/trino-main/src/main/java/io/trino/cost/SimpleFilterProjectSemiJoinStatsRule.java +++ b/core/trino-main/src/main/java/io/trino/cost/SimpleFilterProjectSemiJoinStatsRule.java @@ -66,7 +66,7 @@ public Pattern getPattern() } @Override - protected Optional doCalculate(FilterNode node, StatsProvider sourceStats, Lookup lookup, Session session, TypeProvider types) + protected Optional doCalculate(FilterNode node, StatsProvider sourceStats, Lookup lookup, Session session, TypeProvider types, TableStatsProvider tableStatsProvider) { PlanNode nodeSource = lookup.resolve(node.getSource()); SemiJoinNode semiJoinNode; diff --git a/core/trino-main/src/main/java/io/trino/cost/SimpleStatsRule.java b/core/trino-main/src/main/java/io/trino/cost/SimpleStatsRule.java index 2183e6a41453..d10ac0fb1890 100644 --- a/core/trino-main/src/main/java/io/trino/cost/SimpleStatsRule.java +++ b/core/trino-main/src/main/java/io/trino/cost/SimpleStatsRule.java @@ -34,11 +34,11 @@ protected SimpleStatsRule(StatsNormalizer normalizer) } @Override - public final Optional calculate(T node, StatsProvider sourceStats, Lookup lookup, Session session, TypeProvider types) + public final Optional calculate(T node, StatsProvider sourceStats, Lookup lookup, Session session, TypeProvider types, TableStatsProvider tableStatsProvider) { - return doCalculate(node, sourceStats, lookup, session, types) + return doCalculate(node, sourceStats, lookup, session, types, tableStatsProvider) .map(estimate -> normalizer.normalize(estimate, node.getOutputSymbols(), types)); } - protected abstract Optional doCalculate(T node, StatsProvider sourceStats, Lookup lookup, Session session, TypeProvider types); + protected abstract Optional doCalculate(T node, StatsProvider sourceStats, Lookup lookup, Session session, TypeProvider types, TableStatsProvider tableStatsProvider); } diff --git a/core/trino-main/src/main/java/io/trino/cost/SortStatsRule.java b/core/trino-main/src/main/java/io/trino/cost/SortStatsRule.java index d291c350b6bf..d379cc8f8ce6 100644 --- a/core/trino-main/src/main/java/io/trino/cost/SortStatsRule.java +++ b/core/trino-main/src/main/java/io/trino/cost/SortStatsRule.java @@ -36,7 +36,7 @@ public Pattern getPattern() } @Override - public Optional calculate(SortNode node, StatsProvider sourceStats, Lookup lookup, Session session, TypeProvider types) + public Optional calculate(SortNode node, StatsProvider sourceStats, Lookup lookup, Session session, TypeProvider types, TableStatsProvider tableStatsProvider) { return Optional.of(sourceStats.getStats(node.getSource())); } diff --git a/core/trino-main/src/main/java/io/trino/cost/SpatialJoinStatsRule.java b/core/trino-main/src/main/java/io/trino/cost/SpatialJoinStatsRule.java index 75d5897e53f7..ecbfd1e34ee6 100644 --- a/core/trino-main/src/main/java/io/trino/cost/SpatialJoinStatsRule.java +++ b/core/trino-main/src/main/java/io/trino/cost/SpatialJoinStatsRule.java @@ -38,7 +38,7 @@ public SpatialJoinStatsRule(FilterStatsCalculator statsCalculator, StatsNormaliz } @Override - protected Optional doCalculate(SpatialJoinNode node, StatsProvider sourceStats, Lookup lookup, Session session, TypeProvider types) + protected Optional doCalculate(SpatialJoinNode node, StatsProvider sourceStats, Lookup lookup, Session session, TypeProvider types, TableStatsProvider tableStatsProvider) { PlanNodeStatsEstimate leftStats = sourceStats.getStats(node.getLeft()); PlanNodeStatsEstimate rightStats = sourceStats.getStats(node.getRight()); diff --git a/core/trino-main/src/main/java/io/trino/cost/StatsCalculator.java b/core/trino-main/src/main/java/io/trino/cost/StatsCalculator.java index a5bbfd9c8add..4fb14f6fc205 100644 --- a/core/trino-main/src/main/java/io/trino/cost/StatsCalculator.java +++ b/core/trino-main/src/main/java/io/trino/cost/StatsCalculator.java @@ -28,16 +28,18 @@ public interface StatsCalculator * @param sourceStats The stats provider for any child nodes' stats, if needed to compute stats for the {@code node} * @param lookup Lookup to be used when resolving source nodes, allowing stats calculation to work within {@link IterativeOptimizer} * @param types The type provider for all symbols in the scope. + * @param tableStatsProvider The table stats provider. */ PlanNodeStatsEstimate calculateStats( PlanNode node, StatsProvider sourceStats, Lookup lookup, Session session, - TypeProvider types); + TypeProvider types, + TableStatsProvider tableStatsProvider); static StatsCalculator noopStatsCalculator() { - return (node, sourceStats, lookup, ignore, types) -> PlanNodeStatsEstimate.unknown(); + return (node, sourceStats, lookup, ignore, types, tableStatsProvider) -> PlanNodeStatsEstimate.unknown(); } } diff --git a/core/trino-main/src/main/java/io/trino/cost/StatsCalculatorModule.java b/core/trino-main/src/main/java/io/trino/cost/StatsCalculatorModule.java index bc67c68bbcc1..3440ee204902 100644 --- a/core/trino-main/src/main/java/io/trino/cost/StatsCalculatorModule.java +++ b/core/trino-main/src/main/java/io/trino/cost/StatsCalculatorModule.java @@ -65,7 +65,7 @@ public List> get() ImmutableList.Builder> rules = ImmutableList.builder(); rules.add(new OutputStatsRule()); - rules.add(new TableScanStatsRule(plannerContext.getMetadata(), normalizer)); + rules.add(new TableScanStatsRule(normalizer)); rules.add(new SimpleFilterProjectSemiJoinStatsRule(plannerContext.getMetadata(), normalizer, filterStatsCalculator)); // this must be before FilterStatsRule rules.add(new FilterProjectAggregationStatsRule(normalizer, filterStatsCalculator)); // this must be before FilterStatsRule rules.add(new FilterStatsRule(normalizer, filterStatsCalculator)); diff --git a/core/trino-main/src/main/java/io/trino/cost/TableScanStatsRule.java b/core/trino-main/src/main/java/io/trino/cost/TableScanStatsRule.java index 3601dd2ac21d..ce1191085e2f 100644 --- a/core/trino-main/src/main/java/io/trino/cost/TableScanStatsRule.java +++ b/core/trino-main/src/main/java/io/trino/cost/TableScanStatsRule.java @@ -15,7 +15,6 @@ import io.trino.Session; import io.trino.matching.Pattern; -import io.trino.metadata.Metadata; import io.trino.spi.connector.ColumnHandle; import io.trino.spi.statistics.ColumnStatistics; import io.trino.spi.statistics.TableStatistics; @@ -40,12 +39,9 @@ public class TableScanStatsRule { private static final Pattern PATTERN = tableScan(); - private final Metadata metadata; - - public TableScanStatsRule(Metadata metadata, StatsNormalizer normalizer) + public TableScanStatsRule(StatsNormalizer normalizer) { super(normalizer); // Use stats normalization since connector can return inconsistent stats values - this.metadata = requireNonNull(metadata, "metadata is null"); } @Override @@ -55,13 +51,13 @@ public Pattern getPattern() } @Override - protected Optional doCalculate(TableScanNode node, StatsProvider sourceStats, Lookup lookup, Session session, TypeProvider types) + protected Optional doCalculate(TableScanNode node, StatsProvider sourceStats, Lookup lookup, Session session, TypeProvider types, TableStatsProvider tableStatsProvider) { if (isStatisticsPrecalculationForPushdownEnabled(session) && node.getStatistics().isPresent()) { return node.getStatistics(); } - TableStatistics tableStatistics = metadata.getTableStatistics(session, node.getTable()); + TableStatistics tableStatistics = tableStatsProvider.getTableStatistics(node.getTable()); Map outputSymbolStats = new HashMap<>(); diff --git a/core/trino-main/src/main/java/io/trino/cost/TableStatsProvider.java b/core/trino-main/src/main/java/io/trino/cost/TableStatsProvider.java new file mode 100644 index 000000000000..13ee3b700e6c --- /dev/null +++ b/core/trino-main/src/main/java/io/trino/cost/TableStatsProvider.java @@ -0,0 +1,22 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.cost; + +import io.trino.metadata.TableHandle; +import io.trino.spi.statistics.TableStatistics; + +public interface TableStatsProvider +{ + TableStatistics getTableStatistics(TableHandle tableHandle); +} diff --git a/core/trino-main/src/main/java/io/trino/cost/TopNStatsRule.java b/core/trino-main/src/main/java/io/trino/cost/TopNStatsRule.java index 52280607e75a..487ba43fd920 100644 --- a/core/trino-main/src/main/java/io/trino/cost/TopNStatsRule.java +++ b/core/trino-main/src/main/java/io/trino/cost/TopNStatsRule.java @@ -42,7 +42,7 @@ public Pattern getPattern() } @Override - protected Optional doCalculate(TopNNode node, StatsProvider statsProvider, Lookup lookup, Session session, TypeProvider types) + protected Optional doCalculate(TopNNode node, StatsProvider statsProvider, Lookup lookup, Session session, TypeProvider types, TableStatsProvider tableStatsProvider) { PlanNodeStatsEstimate sourceStats = statsProvider.getStats(node.getSource()); double rowCount = sourceStats.getOutputRowCount(); diff --git a/core/trino-main/src/main/java/io/trino/cost/UnionStatsRule.java b/core/trino-main/src/main/java/io/trino/cost/UnionStatsRule.java index 8892ad2aa733..522efa23d3dd 100644 --- a/core/trino-main/src/main/java/io/trino/cost/UnionStatsRule.java +++ b/core/trino-main/src/main/java/io/trino/cost/UnionStatsRule.java @@ -46,7 +46,7 @@ public Pattern getPattern() } @Override - protected final Optional doCalculate(UnionNode node, StatsProvider statsProvider, Lookup lookup, Session session, TypeProvider types) + protected final Optional doCalculate(UnionNode node, StatsProvider statsProvider, Lookup lookup, Session session, TypeProvider types, TableStatsProvider tableStatsProvider) { checkArgument(!node.getSources().isEmpty(), "Empty Union is not supported"); diff --git a/core/trino-main/src/main/java/io/trino/cost/ValuesStatsRule.java b/core/trino-main/src/main/java/io/trino/cost/ValuesStatsRule.java index 513542754e15..65b024eefdce 100644 --- a/core/trino-main/src/main/java/io/trino/cost/ValuesStatsRule.java +++ b/core/trino-main/src/main/java/io/trino/cost/ValuesStatsRule.java @@ -63,7 +63,7 @@ public Pattern getPattern() } @Override - public Optional calculate(ValuesNode node, StatsProvider sourceStats, Lookup lookup, Session session, TypeProvider types) + public Optional calculate(ValuesNode node, StatsProvider sourceStats, Lookup lookup, Session session, TypeProvider types, TableStatsProvider tableStatsProvider) { PlanNodeStatsEstimate.Builder statsBuilder = PlanNodeStatsEstimate.builder(); statsBuilder.setOutputRowCount(node.getRowCount()); diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/LogicalPlanner.java b/core/trino-main/src/main/java/io/trino/sql/planner/LogicalPlanner.java index cdf7c463d851..22a8a05e8cc3 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/LogicalPlanner.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/LogicalPlanner.java @@ -19,11 +19,13 @@ import io.trino.Session; import io.trino.cost.CachingCostProvider; import io.trino.cost.CachingStatsProvider; +import io.trino.cost.CachingTableStatsProvider; import io.trino.cost.CostCalculator; import io.trino.cost.CostProvider; import io.trino.cost.StatsAndCosts; import io.trino.cost.StatsCalculator; import io.trino.cost.StatsProvider; +import io.trino.cost.TableStatsProvider; import io.trino.execution.warnings.WarningCollector; import io.trino.metadata.AnalyzeMetadata; import io.trino.metadata.Metadata; @@ -235,9 +237,11 @@ public Plan plan(Analysis analysis, Stage stage, boolean collectPlanStatistics) planSanityChecker.validateIntermediatePlan(root, session, plannerContext, typeAnalyzer, symbolAllocator.getTypes(), warningCollector); + TableStatsProvider tableStatsProvider = new CachingTableStatsProvider(metadata, session); + if (stage.ordinal() >= OPTIMIZED.ordinal()) { for (PlanOptimizer optimizer : planOptimizers) { - root = optimizer.optimize(root, session, symbolAllocator.getTypes(), symbolAllocator, idAllocator, warningCollector); + root = optimizer.optimize(root, session, symbolAllocator.getTypes(), symbolAllocator, idAllocator, warningCollector, tableStatsProvider); requireNonNull(root, format("%s returned a null plan", optimizer.getClass().getName())); if (LOG.isDebugEnabled()) { @@ -263,7 +267,7 @@ public Plan plan(Analysis analysis, Stage stage, boolean collectPlanStatistics) StatsAndCosts statsAndCosts = StatsAndCosts.empty(); if (collectPlanStatistics) { - StatsProvider statsProvider = new CachingStatsProvider(statsCalculator, session, types); + StatsProvider statsProvider = new CachingStatsProvider(statsCalculator, session, types, tableStatsProvider); CostProvider costProvider = new CachingCostProvider(costCalculator, statsProvider, Optional.empty(), session, types); statsAndCosts = StatsAndCosts.create(root, statsProvider, costProvider); } diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/IterativeOptimizer.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/IterativeOptimizer.java index 2d9bb1d0ba7e..c0fd2111506e 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/IterativeOptimizer.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/IterativeOptimizer.java @@ -25,6 +25,7 @@ import io.trino.cost.StatsAndCosts; import io.trino.cost.StatsCalculator; import io.trino.cost.StatsProvider; +import io.trino.cost.TableStatsProvider; import io.trino.execution.warnings.WarningCollector; import io.trino.matching.Capture; import io.trino.matching.Match; @@ -98,12 +99,12 @@ public IterativeOptimizer(PlannerContext plannerContext, RuleStatsRecorder stats } @Override - public PlanNode optimize(PlanNode plan, Session session, TypeProvider types, SymbolAllocator symbolAllocator, PlanNodeIdAllocator idAllocator, WarningCollector warningCollector) + public PlanNode optimize(PlanNode plan, Session session, TypeProvider types, SymbolAllocator symbolAllocator, PlanNodeIdAllocator idAllocator, WarningCollector warningCollector, TableStatsProvider tableStatsProvider) { // only disable new rules if we have legacy rules to fall back to if (useLegacyRules.test(session) && !legacyRules.isEmpty()) { for (PlanOptimizer optimizer : legacyRules) { - plan = optimizer.optimize(plan, session, symbolAllocator.getTypes(), symbolAllocator, idAllocator, warningCollector); + plan = optimizer.optimize(plan, session, symbolAllocator.getTypes(), symbolAllocator, idAllocator, warningCollector, tableStatsProvider); } return plan; @@ -113,7 +114,7 @@ public PlanNode optimize(PlanNode plan, Session session, TypeProvider types, Sym Lookup lookup = Lookup.from(planNode -> Stream.of(memo.resolve(planNode))); Duration timeout = SystemSessionProperties.getOptimizerTimeout(session); - Context context = new Context(memo, lookup, idAllocator, symbolAllocator, nanoTime(), timeout.toMillis(), session, warningCollector); + Context context = new Context(memo, lookup, idAllocator, symbolAllocator, nanoTime(), timeout.toMillis(), session, warningCollector, tableStatsProvider); exploreGroup(memo.getRootGroup(), context); return memo.extract(); @@ -258,7 +259,7 @@ private boolean exploreChildren(int group, Context context) private Rule.Context ruleContext(Context context) { - StatsProvider statsProvider = new CachingStatsProvider(statsCalculator, Optional.of(context.memo), context.lookup, context.session, context.symbolAllocator.getTypes()); + StatsProvider statsProvider = new CachingStatsProvider(statsCalculator, Optional.of(context.memo), context.lookup, context.session, context.symbolAllocator.getTypes(), context.tableStatsProvider); CostProvider costProvider = new CachingCostProvider(costCalculator, statsProvider, Optional.of(context.memo), context.session, context.symbolAllocator.getTypes()); return new Rule.Context() @@ -323,6 +324,7 @@ private static class Context private final long timeoutInMilliseconds; private final Session session; private final WarningCollector warningCollector; + private final TableStatsProvider tableStatsProvider; private final Map, RuleInvocationStats> ruleStats = new HashMap<>(); @@ -334,7 +336,8 @@ public Context( long startTimeInNanos, long timeoutInMilliseconds, Session session, - WarningCollector warningCollector) + WarningCollector warningCollector, + TableStatsProvider tableStatsProvider) { checkArgument(timeoutInMilliseconds >= 0, "Timeout has to be a non-negative number [milliseconds]"); @@ -346,6 +349,7 @@ public Context( this.timeoutInMilliseconds = timeoutInMilliseconds; this.session = session; this.warningCollector = warningCollector; + this.tableStatsProvider = tableStatsProvider; } public void checkTimeoutNotExhausted() diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/RemoveUnsupportedDynamicFilters.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/RemoveUnsupportedDynamicFilters.java index 060e813cfce4..3c020d6c512e 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/RemoveUnsupportedDynamicFilters.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/RemoveUnsupportedDynamicFilters.java @@ -17,6 +17,7 @@ import com.google.common.collect.ImmutableSet; import io.trino.Session; import io.trino.connector.CatalogServiceProvider; +import io.trino.cost.TableStatsProvider; import io.trino.execution.warnings.WarningCollector; import io.trino.metadata.AnalyzePropertyManager; import io.trino.metadata.OperatorNotFoundException; @@ -109,7 +110,7 @@ public RemoveUnsupportedDynamicFilters(PlannerContext plannerContext) } @Override - public PlanNode optimize(PlanNode plan, Session session, TypeProvider types, SymbolAllocator symbolAllocator, PlanNodeIdAllocator idAllocator, WarningCollector warningCollector) + public PlanNode optimize(PlanNode plan, Session session, TypeProvider types, SymbolAllocator symbolAllocator, PlanNodeIdAllocator idAllocator, WarningCollector warningCollector, TableStatsProvider tableStatsProvider) { PlanWithConsumedDynamicFilters result = plan.accept(new RemoveUnsupportedDynamicFilters.Rewriter(session, types), ImmutableSet.of()); return result.getNode(); diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/AddExchanges.java b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/AddExchanges.java index 0e796dd23fe6..aa4c8fff0e0a 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/AddExchanges.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/AddExchanges.java @@ -24,6 +24,7 @@ import io.trino.cost.CachingStatsProvider; import io.trino.cost.StatsCalculator; import io.trino.cost.StatsProvider; +import io.trino.cost.TableStatsProvider; import io.trino.execution.warnings.WarningCollector; import io.trino.spi.connector.GroupingProperty; import io.trino.spi.connector.LocalProperty; @@ -136,9 +137,9 @@ public AddExchanges(PlannerContext plannerContext, TypeAnalyzer typeAnalyzer, St } @Override - public PlanNode optimize(PlanNode plan, Session session, TypeProvider types, SymbolAllocator symbolAllocator, PlanNodeIdAllocator idAllocator, WarningCollector warningCollector) + public PlanNode optimize(PlanNode plan, Session session, TypeProvider types, SymbolAllocator symbolAllocator, PlanNodeIdAllocator idAllocator, WarningCollector warningCollector, TableStatsProvider tableStatsProvider) { - PlanWithProperties result = plan.accept(new Rewriter(idAllocator, symbolAllocator, session), PreferredProperties.any()); + PlanWithProperties result = plan.accept(new Rewriter(idAllocator, symbolAllocator, session, tableStatsProvider), PreferredProperties.any()); return result.getNode(); } @@ -156,12 +157,12 @@ private class Rewriter private final boolean redistributeWrites; private final boolean scaleWriters; - public Rewriter(PlanNodeIdAllocator idAllocator, SymbolAllocator symbolAllocator, Session session) + public Rewriter(PlanNodeIdAllocator idAllocator, SymbolAllocator symbolAllocator, Session session, TableStatsProvider tableStatsProvider) { this.idAllocator = idAllocator; this.symbolAllocator = symbolAllocator; this.types = symbolAllocator.getTypes(); - this.statsProvider = new CachingStatsProvider(statsCalculator, session, types); + this.statsProvider = new CachingStatsProvider(statsCalculator, session, types, tableStatsProvider); this.session = session; this.domainTranslator = new DomainTranslator(plannerContext); this.distributedIndexJoins = SystemSessionProperties.isDistributedIndexJoinEnabled(session); diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/AddLocalExchanges.java b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/AddLocalExchanges.java index 522efb8d28fd..45067db0775a 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/AddLocalExchanges.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/AddLocalExchanges.java @@ -17,6 +17,7 @@ import com.google.common.collect.ImmutableSet; import com.google.common.collect.Lists; import io.trino.Session; +import io.trino.cost.TableStatsProvider; import io.trino.execution.warnings.WarningCollector; import io.trino.spi.connector.ConstantProperty; import io.trino.spi.connector.GroupingProperty; @@ -113,7 +114,7 @@ public AddLocalExchanges(PlannerContext plannerContext, TypeAnalyzer typeAnalyze } @Override - public PlanNode optimize(PlanNode plan, Session session, TypeProvider types, SymbolAllocator symbolAllocator, PlanNodeIdAllocator idAllocator, WarningCollector warningCollector) + public PlanNode optimize(PlanNode plan, Session session, TypeProvider types, SymbolAllocator symbolAllocator, PlanNodeIdAllocator idAllocator, WarningCollector warningCollector, TableStatsProvider tableStatsProvider) { PlanWithProperties result = plan.accept(new Rewriter(symbolAllocator, idAllocator, session), any()); return result.getNode(); diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/BeginTableWrite.java b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/BeginTableWrite.java index ec9623510dc6..d4c74933ce1a 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/BeginTableWrite.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/BeginTableWrite.java @@ -15,6 +15,7 @@ import io.trino.Session; import io.trino.cost.StatsAndCosts; +import io.trino.cost.TableStatsProvider; import io.trino.execution.warnings.WarningCollector; import io.trino.metadata.FunctionManager; import io.trino.metadata.Metadata; @@ -83,7 +84,7 @@ public BeginTableWrite(Metadata metadata, FunctionManager functionManager) } @Override - public PlanNode optimize(PlanNode plan, Session session, TypeProvider types, SymbolAllocator symbolAllocator, PlanNodeIdAllocator idAllocator, WarningCollector warningCollector) + public PlanNode optimize(PlanNode plan, Session session, TypeProvider types, SymbolAllocator symbolAllocator, PlanNodeIdAllocator idAllocator, WarningCollector warningCollector, TableStatsProvider tableStatsProvider) { try { return SimplePlanRewriter.rewriteWith(new Rewriter(session), plan, Optional.empty()); diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/CheckSubqueryNodesAreRewritten.java b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/CheckSubqueryNodesAreRewritten.java index 59c4c1697e49..66f7d3afe080 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/CheckSubqueryNodesAreRewritten.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/CheckSubqueryNodesAreRewritten.java @@ -15,6 +15,7 @@ package io.trino.sql.planner.optimizations; import io.trino.Session; +import io.trino.cost.TableStatsProvider; import io.trino.execution.warnings.WarningCollector; import io.trino.spi.TrinoException; import io.trino.sql.planner.PlanNodeIdAllocator; @@ -37,7 +38,7 @@ public class CheckSubqueryNodesAreRewritten implements PlanOptimizer { @Override - public PlanNode optimize(PlanNode plan, Session session, TypeProvider types, SymbolAllocator symbolAllocator, PlanNodeIdAllocator idAllocator, WarningCollector warningCollector) + public PlanNode optimize(PlanNode plan, Session session, TypeProvider types, SymbolAllocator symbolAllocator, PlanNodeIdAllocator idAllocator, WarningCollector warningCollector, TableStatsProvider tableStatsProvider) { searchFrom(plan).where(ApplyNode.class::isInstance) .findFirst() diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/HashGenerationOptimizer.java b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/HashGenerationOptimizer.java index c9262e2958db..e86a201cb869 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/HashGenerationOptimizer.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/HashGenerationOptimizer.java @@ -25,6 +25,7 @@ import com.google.common.collect.Multimap; import io.trino.Session; import io.trino.SystemSessionProperties; +import io.trino.cost.TableStatsProvider; import io.trino.execution.warnings.WarningCollector; import io.trino.metadata.Metadata; import io.trino.spi.function.OperatorType; @@ -106,7 +107,7 @@ public HashGenerationOptimizer(Metadata metadata) } @Override - public PlanNode optimize(PlanNode plan, Session session, TypeProvider types, SymbolAllocator symbolAllocator, PlanNodeIdAllocator idAllocator, WarningCollector warningCollector) + public PlanNode optimize(PlanNode plan, Session session, TypeProvider types, SymbolAllocator symbolAllocator, PlanNodeIdAllocator idAllocator, WarningCollector warningCollector, TableStatsProvider tableStatsProvider) { requireNonNull(plan, "plan is null"); requireNonNull(session, "session is null"); diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/IndexJoinOptimizer.java b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/IndexJoinOptimizer.java index ec72bf630aa6..f0e5c981bc53 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/IndexJoinOptimizer.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/IndexJoinOptimizer.java @@ -21,6 +21,7 @@ import com.google.common.collect.Lists; import com.google.common.collect.Maps; import io.trino.Session; +import io.trino.cost.TableStatsProvider; import io.trino.execution.warnings.WarningCollector; import io.trino.metadata.BoundSignature; import io.trino.metadata.ResolvedFunction; @@ -80,7 +81,7 @@ public IndexJoinOptimizer(PlannerContext plannerContext) } @Override - public PlanNode optimize(PlanNode plan, Session session, TypeProvider type, SymbolAllocator symbolAllocator, PlanNodeIdAllocator idAllocator, WarningCollector warningCollector) + public PlanNode optimize(PlanNode plan, Session session, TypeProvider types, SymbolAllocator symbolAllocator, PlanNodeIdAllocator idAllocator, WarningCollector warningCollector, TableStatsProvider tableStatsProvider) { requireNonNull(plan, "plan is null"); requireNonNull(session, "session is null"); diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/LimitPushDown.java b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/LimitPushDown.java index 10da7a2f3ef9..deacac4239f9 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/LimitPushDown.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/LimitPushDown.java @@ -15,6 +15,7 @@ import com.google.common.collect.ImmutableList; import io.trino.Session; +import io.trino.cost.TableStatsProvider; import io.trino.execution.warnings.WarningCollector; import io.trino.sql.planner.PlanNodeIdAllocator; import io.trino.sql.planner.SymbolAllocator; @@ -43,7 +44,7 @@ public class LimitPushDown implements PlanOptimizer { @Override - public PlanNode optimize(PlanNode plan, Session session, TypeProvider types, SymbolAllocator symbolAllocator, PlanNodeIdAllocator idAllocator, WarningCollector warningCollector) + public PlanNode optimize(PlanNode plan, Session session, TypeProvider types, SymbolAllocator symbolAllocator, PlanNodeIdAllocator idAllocator, WarningCollector warningCollector, TableStatsProvider tableStatsProvider) { requireNonNull(plan, "plan is null"); requireNonNull(session, "session is null"); diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/MetadataQueryOptimizer.java b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/MetadataQueryOptimizer.java index ce64684bad76..7b82987c7b3e 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/MetadataQueryOptimizer.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/MetadataQueryOptimizer.java @@ -19,6 +19,7 @@ import com.google.common.collect.Iterables; import io.trino.Session; import io.trino.SystemSessionProperties; +import io.trino.cost.TableStatsProvider; import io.trino.execution.warnings.WarningCollector; import io.trino.metadata.TableProperties; import io.trino.spi.connector.ColumnHandle; @@ -73,7 +74,7 @@ public MetadataQueryOptimizer(PlannerContext plannerContext) } @Override - public PlanNode optimize(PlanNode plan, Session session, TypeProvider types, SymbolAllocator symbolAllocator, PlanNodeIdAllocator idAllocator, WarningCollector warningCollector) + public PlanNode optimize(PlanNode plan, Session session, TypeProvider types, SymbolAllocator symbolAllocator, PlanNodeIdAllocator idAllocator, WarningCollector warningCollector, TableStatsProvider tableStatsProvider) { if (!SystemSessionProperties.isOptimizeMetadataQueries(session)) { return plan; diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/OptimizeMixedDistinctAggregations.java b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/OptimizeMixedDistinctAggregations.java index 35a2a9ce67e4..35a0d08492c1 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/OptimizeMixedDistinctAggregations.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/OptimizeMixedDistinctAggregations.java @@ -18,6 +18,7 @@ import com.google.common.collect.ImmutableSet; import com.google.common.collect.Iterables; import io.trino.Session; +import io.trino.cost.TableStatsProvider; import io.trino.execution.warnings.WarningCollector; import io.trino.metadata.Metadata; import io.trino.spi.type.Type; @@ -83,7 +84,7 @@ public OptimizeMixedDistinctAggregations(Metadata metadata) } @Override - public PlanNode optimize(PlanNode plan, Session session, TypeProvider types, SymbolAllocator symbolAllocator, PlanNodeIdAllocator idAllocator, WarningCollector warningCollector) + public PlanNode optimize(PlanNode plan, Session session, TypeProvider types, SymbolAllocator symbolAllocator, PlanNodeIdAllocator idAllocator, WarningCollector warningCollector, TableStatsProvider tableStatsProvider) { if (isOptimizeDistinctAggregationEnabled(session)) { return SimplePlanRewriter.rewriteWith(new Optimizer(session, idAllocator, symbolAllocator, metadata), plan, Optional.empty()); diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/PlanOptimizer.java b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/PlanOptimizer.java index 1e018beb8866..a0a33292f0d3 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/PlanOptimizer.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/PlanOptimizer.java @@ -14,6 +14,7 @@ package io.trino.sql.planner.optimizations; import io.trino.Session; +import io.trino.cost.TableStatsProvider; import io.trino.execution.warnings.WarningCollector; import io.trino.sql.planner.PlanNodeIdAllocator; import io.trino.sql.planner.SymbolAllocator; @@ -28,5 +29,6 @@ PlanNode optimize( TypeProvider types, SymbolAllocator symbolAllocator, PlanNodeIdAllocator idAllocator, - WarningCollector warningCollector); + WarningCollector warningCollector, + TableStatsProvider tableStatsProvider); } diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/PredicatePushDown.java b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/PredicatePushDown.java index 4ab16f218e16..f6235918cc76 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/PredicatePushDown.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/PredicatePushDown.java @@ -21,6 +21,7 @@ import com.google.common.collect.Sets; import com.google.common.collect.Streams; import io.trino.Session; +import io.trino.cost.TableStatsProvider; import io.trino.execution.warnings.WarningCollector; import io.trino.metadata.Metadata; import io.trino.spi.type.Type; @@ -143,7 +144,7 @@ public PredicatePushDown( } @Override - public PlanNode optimize(PlanNode plan, Session session, TypeProvider types, SymbolAllocator symbolAllocator, PlanNodeIdAllocator idAllocator, WarningCollector warningCollector) + public PlanNode optimize(PlanNode plan, Session session, TypeProvider types, SymbolAllocator symbolAllocator, PlanNodeIdAllocator idAllocator, WarningCollector warningCollector, TableStatsProvider tableStatsProvider) { requireNonNull(plan, "plan is null"); requireNonNull(session, "session is null"); diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/ReplicateSemiJoinInDelete.java b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/ReplicateSemiJoinInDelete.java index ae9fe64fecc4..79a41f75ecc9 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/ReplicateSemiJoinInDelete.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/ReplicateSemiJoinInDelete.java @@ -14,6 +14,7 @@ package io.trino.sql.planner.optimizations; import io.trino.Session; +import io.trino.cost.TableStatsProvider; import io.trino.execution.warnings.WarningCollector; import io.trino.sql.planner.PlanNodeIdAllocator; import io.trino.sql.planner.SymbolAllocator; @@ -30,7 +31,7 @@ public class ReplicateSemiJoinInDelete implements PlanOptimizer { @Override - public PlanNode optimize(PlanNode plan, Session session, TypeProvider types, SymbolAllocator symbolAllocator, PlanNodeIdAllocator idAllocator, WarningCollector warningCollector) + public PlanNode optimize(PlanNode plan, Session session, TypeProvider types, SymbolAllocator symbolAllocator, PlanNodeIdAllocator idAllocator, WarningCollector warningCollector, TableStatsProvider tableStatsProvider) { requireNonNull(plan, "plan is null"); return SimplePlanRewriter.rewriteWith(new Rewriter(), plan); diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/StatsRecordingPlanOptimizer.java b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/StatsRecordingPlanOptimizer.java index 3cc822fa3b87..dcb4019c24e1 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/StatsRecordingPlanOptimizer.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/StatsRecordingPlanOptimizer.java @@ -14,6 +14,7 @@ package io.trino.sql.planner.optimizations; import io.trino.Session; +import io.trino.cost.TableStatsProvider; import io.trino.execution.warnings.WarningCollector; import io.trino.sql.planner.OptimizerStatsRecorder; import io.trino.sql.planner.PlanNodeIdAllocator; @@ -43,13 +44,14 @@ public PlanNode optimize( TypeProvider types, SymbolAllocator symbolAllocator, PlanNodeIdAllocator idAllocator, - WarningCollector warningCollector) + WarningCollector warningCollector, + TableStatsProvider tableStatsProvider) { PlanNode result; long duration; try { long start = System.nanoTime(); - result = delegate.optimize(plan, session, types, symbolAllocator, idAllocator, warningCollector); + result = delegate.optimize(plan, session, types, symbolAllocator, idAllocator, warningCollector, tableStatsProvider); duration = System.nanoTime() - start; } catch (RuntimeException e) { diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/TransformQuantifiedComparisonApplyToCorrelatedJoin.java b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/TransformQuantifiedComparisonApplyToCorrelatedJoin.java index f5d086a0d820..b13bfbcfe6e2 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/TransformQuantifiedComparisonApplyToCorrelatedJoin.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/TransformQuantifiedComparisonApplyToCorrelatedJoin.java @@ -16,6 +16,7 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import io.trino.Session; +import io.trino.cost.TableStatsProvider; import io.trino.execution.warnings.WarningCollector; import io.trino.metadata.Metadata; import io.trino.spi.type.BigintType; @@ -81,7 +82,7 @@ public TransformQuantifiedComparisonApplyToCorrelatedJoin(Metadata metadata) } @Override - public PlanNode optimize(PlanNode plan, Session session, TypeProvider types, SymbolAllocator symbolAllocator, PlanNodeIdAllocator idAllocator, WarningCollector warningCollector) + public PlanNode optimize(PlanNode plan, Session session, TypeProvider types, SymbolAllocator symbolAllocator, PlanNodeIdAllocator idAllocator, WarningCollector warningCollector, TableStatsProvider tableStatsProvider) { return rewriteWith(new Rewriter(idAllocator, types, symbolAllocator, metadata, session), plan, null); } diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/UnaliasSymbolReferences.java b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/UnaliasSymbolReferences.java index 13bc34bef4b1..f6c0880cc022 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/UnaliasSymbolReferences.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/UnaliasSymbolReferences.java @@ -21,6 +21,7 @@ import com.google.common.collect.Sets; import io.trino.Session; import io.trino.cost.PlanNodeStatsEstimate; +import io.trino.cost.TableStatsProvider; import io.trino.execution.warnings.WarningCollector; import io.trino.metadata.Metadata; import io.trino.spi.connector.ColumnHandle; @@ -135,7 +136,7 @@ public UnaliasSymbolReferences(Metadata metadata) } @Override - public PlanNode optimize(PlanNode plan, Session session, TypeProvider types, SymbolAllocator symbolAllocator, PlanNodeIdAllocator idAllocator, WarningCollector warningCollector) + public PlanNode optimize(PlanNode plan, Session session, TypeProvider types, SymbolAllocator symbolAllocator, PlanNodeIdAllocator idAllocator, WarningCollector warningCollector, TableStatsProvider tableStatsProvider) { requireNonNull(plan, "plan is null"); requireNonNull(session, "session is null"); diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/WindowFilterPushDown.java b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/WindowFilterPushDown.java index 43189b43b5a0..36465e6f50ed 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/WindowFilterPushDown.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/WindowFilterPushDown.java @@ -15,6 +15,7 @@ import com.google.common.collect.ImmutableList; import io.trino.Session; +import io.trino.cost.TableStatsProvider; import io.trino.execution.warnings.WarningCollector; import io.trino.metadata.FunctionId; import io.trino.spi.predicate.Domain; @@ -67,7 +68,7 @@ public WindowFilterPushDown(PlannerContext plannerContext) } @Override - public PlanNode optimize(PlanNode plan, Session session, TypeProvider types, SymbolAllocator symbolAllocator, PlanNodeIdAllocator idAllocator, WarningCollector warningCollector) + public PlanNode optimize(PlanNode plan, Session session, TypeProvider types, SymbolAllocator symbolAllocator, PlanNodeIdAllocator idAllocator, WarningCollector warningCollector, TableStatsProvider tableStatsProvider) { requireNonNull(plan, "plan is null"); requireNonNull(session, "session is null"); diff --git a/core/trino-main/src/main/java/io/trino/sql/rewrite/ShowStatsRewrite.java b/core/trino-main/src/main/java/io/trino/sql/rewrite/ShowStatsRewrite.java index 2d0b1525abca..f0183d8cf63c 100644 --- a/core/trino-main/src/main/java/io/trino/sql/rewrite/ShowStatsRewrite.java +++ b/core/trino-main/src/main/java/io/trino/sql/rewrite/ShowStatsRewrite.java @@ -16,10 +16,12 @@ import com.google.common.collect.ImmutableList; import io.trino.Session; import io.trino.cost.CachingStatsProvider; +import io.trino.cost.CachingTableStatsProvider; import io.trino.cost.PlanNodeStatsEstimate; import io.trino.cost.StatsCalculator; import io.trino.cost.SymbolStatsEstimate; import io.trino.execution.warnings.WarningCollector; +import io.trino.metadata.Metadata; import io.trino.operator.scalar.timestamp.TimestampToVarcharCast; import io.trino.operator.scalar.timestamptz.TimestampWithTimeZoneToVarcharCast; import io.trino.spi.type.BigintType; @@ -87,12 +89,14 @@ public class ShowStatsRewrite private static final Expression NULL_DOUBLE = new Cast(new NullLiteral(), toSqlType(DOUBLE)); private static final Expression NULL_VARCHAR = new Cast(new NullLiteral(), toSqlType(VARCHAR)); + private final Metadata metadata; private final QueryExplainerFactory queryExplainerFactory; private final StatsCalculator statsCalculator; @Inject - public ShowStatsRewrite(QueryExplainerFactory queryExplainerFactory, StatsCalculator statsCalculator) + public ShowStatsRewrite(Metadata metadata, QueryExplainerFactory queryExplainerFactory, StatsCalculator statsCalculator) { + this.metadata = requireNonNull(metadata, "metadata is null"); this.queryExplainerFactory = requireNonNull(queryExplainerFactory, "queryExplainerFactory is null"); this.statsCalculator = requireNonNull(statsCalculator, "statsCalculator is null"); } @@ -106,7 +110,7 @@ public Statement rewrite( Map, Expression> parameterLookup, WarningCollector warningCollector) { - return (Statement) new Visitor(session, parameters, queryExplainerFactory.createQueryExplainer(analyzerFactory), warningCollector, statsCalculator).process(node, null); + return (Statement) new Visitor(session, parameters, metadata, queryExplainerFactory.createQueryExplainer(analyzerFactory), warningCollector, statsCalculator).process(node, null); } private static class Visitor @@ -114,14 +118,16 @@ private static class Visitor { private final Session session; private final List parameters; + private final Metadata metadata; private final QueryExplainer queryExplainer; private final WarningCollector warningCollector; private final StatsCalculator statsCalculator; - private Visitor(Session session, List parameters, QueryExplainer queryExplainer, WarningCollector warningCollector, StatsCalculator statsCalculator) + private Visitor(Session session, List parameters, Metadata metadata, QueryExplainer queryExplainer, WarningCollector warningCollector, StatsCalculator statsCalculator) { this.session = requireNonNull(session, "session is null"); this.parameters = requireNonNull(parameters, "parameters is null"); + this.metadata = requireNonNull(metadata, "metadata is null"); this.queryExplainer = requireNonNull(queryExplainer, "queryExplainer is null"); this.warningCollector = requireNonNull(warningCollector, "warningCollector is null"); this.statsCalculator = requireNonNull(statsCalculator, "statsCalculator is null"); @@ -132,7 +138,7 @@ protected Node visitShowStats(ShowStats node, Void context) { Query query = getRelation(node); Plan plan = queryExplainer.getLogicalPlan(session, query, parameters, warningCollector); - CachingStatsProvider cachingStatsProvider = new CachingStatsProvider(statsCalculator, session, plan.getTypes()); + CachingStatsProvider cachingStatsProvider = new CachingStatsProvider(statsCalculator, session, plan.getTypes(), new CachingTableStatsProvider(metadata, session)); PlanNodeStatsEstimate stats = cachingStatsProvider.getStats(plan.getRoot()); return rewriteShowStats(plan, stats); } diff --git a/core/trino-main/src/main/java/io/trino/testing/LocalQueryRunner.java b/core/trino-main/src/main/java/io/trino/testing/LocalQueryRunner.java index 7a45c14a20ed..b51899fc0280 100644 --- a/core/trino-main/src/main/java/io/trino/testing/LocalQueryRunner.java +++ b/core/trino-main/src/main/java/io/trino/testing/LocalQueryRunner.java @@ -1119,7 +1119,7 @@ private AnalyzerFactory createAnalyzerFactory(QueryExplainerFactory queryExplain columnPropertyManager, tablePropertyManager, materializedViewPropertyManager), - new ShowStatsRewrite(queryExplainerFactory, statsCalculator), + new ShowStatsRewrite(plannerContext.getMetadata(), queryExplainerFactory, statsCalculator), new ExplainRewrite(queryExplainerFactory, new QueryPreparer(sqlParser))))); } diff --git a/core/trino-main/src/test/java/io/trino/cost/StatsCalculatorAssertion.java b/core/trino-main/src/test/java/io/trino/cost/StatsCalculatorAssertion.java index e87b7502126d..d04085ad0e71 100644 --- a/core/trino-main/src/test/java/io/trino/cost/StatsCalculatorAssertion.java +++ b/core/trino-main/src/test/java/io/trino/cost/StatsCalculatorAssertion.java @@ -15,6 +15,7 @@ import io.trino.Session; import io.trino.cost.ComposableStatsCalculator.Rule; +import io.trino.metadata.Metadata; import io.trino.security.AllowAllAccessControl; import io.trino.sql.planner.TypeProvider; import io.trino.sql.planner.iterative.Lookup; @@ -36,6 +37,7 @@ public class StatsCalculatorAssertion { + private final Metadata metadata; private final StatsCalculator statsCalculator; private final Session session; private final PlanNode planNode; @@ -43,8 +45,9 @@ public class StatsCalculatorAssertion private final Map sourcesStats; - public StatsCalculatorAssertion(StatsCalculator statsCalculator, Session session, PlanNode planNode, TypeProvider types) + public StatsCalculatorAssertion(Metadata metadata, StatsCalculator statsCalculator, Session session, PlanNode planNode, TypeProvider types) { + this.metadata = requireNonNull(metadata, "metadata cannot be null"); this.statsCalculator = requireNonNull(statsCalculator, "statsCalculator cannot be null"); this.session = requireNonNull(session, "session cannot be null"); this.planNode = requireNonNull(planNode, "planNode is null"); @@ -84,7 +87,7 @@ public StatsCalculatorAssertion check(Consumer statistic { PlanNodeStatsEstimate statsEstimate = transaction(new TestingTransactionManager(), new AllowAllAccessControl()) .execute(session, transactionSession -> { - return statsCalculator.calculateStats(planNode, this::getSourceStats, noLookup(), transactionSession, types); + return statsCalculator.calculateStats(planNode, this::getSourceStats, noLookup(), transactionSession, types, new CachingTableStatsProvider(metadata, session)); }); statisticsAssertionConsumer.accept(PlanNodeStatsAssertion.assertThat(statsEstimate)); return this; @@ -92,16 +95,16 @@ public StatsCalculatorAssertion check(Consumer statistic public StatsCalculatorAssertion check(Rule rule, Consumer statisticsAssertionConsumer) { - Optional statsEstimate = calculatedStats(rule, planNode, this::getSourceStats, noLookup(), session, types); + Optional statsEstimate = calculatedStats(rule, planNode, this::getSourceStats, noLookup(), session, types, new CachingTableStatsProvider(metadata, session)); checkState(statsEstimate.isPresent(), "Expected stats estimates to be present"); statisticsAssertionConsumer.accept(PlanNodeStatsAssertion.assertThat(statsEstimate.get())); return this; } @SuppressWarnings("unchecked") - private static Optional calculatedStats(Rule rule, PlanNode node, StatsProvider sourceStats, Lookup lookup, Session session, TypeProvider types) + private static Optional calculatedStats(Rule rule, PlanNode node, StatsProvider sourceStats, Lookup lookup, Session session, TypeProvider types, TableStatsProvider tableStatsProvider) { - return rule.calculate((T) node, sourceStats, lookup, session, types); + return rule.calculate((T) node, sourceStats, lookup, session, types, tableStatsProvider); } private PlanNodeStatsEstimate getSourceStats(PlanNode sourceNode) diff --git a/core/trino-main/src/test/java/io/trino/cost/StatsCalculatorTester.java b/core/trino-main/src/test/java/io/trino/cost/StatsCalculatorTester.java index 48cdcda1fdad..13fc22d32421 100644 --- a/core/trino-main/src/test/java/io/trino/cost/StatsCalculatorTester.java +++ b/core/trino-main/src/test/java/io/trino/cost/StatsCalculatorTester.java @@ -75,7 +75,7 @@ public StatsCalculatorAssertion assertStatsFor(Session session, Function new Symbol(entry.getKey()), Map.Entry::getValue))); - StatsProvider statsProvider = new CachingStatsProvider(statsCalculator(stats), session, typeProvider); + StatsProvider statsProvider = new CachingStatsProvider(statsCalculator(stats), session, typeProvider, new CachingTableStatsProvider(localQueryRunner.getMetadata(), session)); CostProvider costProvider = new TestingCostProvider(costs, costCalculatorUsingExchanges, statsProvider, session, typeProvider); SubPlan subPlan = fragment(new Plan(node, typeProvider, StatsAndCosts.create(node, statsProvider, costProvider))); return new CostAssertionBuilder(subPlan.getFragment().getStatsAndCosts().getCosts().getOrDefault(node.getId(), PlanCostEstimate.unknown())); @@ -681,7 +681,7 @@ private void assertFragmentedEqualsUnfragmented(PlanNode node, Map stats) { - return (node, sourceStats, lookup, session, types) -> + return (node, sourceStats, lookup, session, types, tableStatsProvider) -> requireNonNull(stats.get(node.getId().toString()), "no stats for node"); } @@ -705,7 +705,7 @@ private PlanCostEstimate calculateCost(PlanNode node, CostCalculator costCalcula { TypeProvider typeProvider = TypeProvider.copyOf(types.entrySet().stream() .collect(toImmutableMap(entry -> new Symbol(entry.getKey()), Map.Entry::getValue))); - StatsProvider statsProvider = new CachingStatsProvider(statsCalculator, session, typeProvider); + StatsProvider statsProvider = new CachingStatsProvider(statsCalculator, session, typeProvider, new CachingTableStatsProvider(localQueryRunner.getMetadata(), session)); CostProvider costProvider = new CachingCostProvider(costCalculator, statsProvider, Optional.empty(), session, typeProvider); return costProvider.getCost(node); } @@ -714,7 +714,7 @@ private PlanCostEstimate calculateCostFragmentedPlan(PlanNode node, StatsCalcula { TypeProvider typeProvider = TypeProvider.copyOf(types.entrySet().stream() .collect(toImmutableMap(entry -> new Symbol(entry.getKey()), Map.Entry::getValue))); - StatsProvider statsProvider = new CachingStatsProvider(statsCalculator, session, typeProvider); + StatsProvider statsProvider = new CachingStatsProvider(statsCalculator, session, typeProvider, new CachingTableStatsProvider(localQueryRunner.getMetadata(), session)); CostProvider costProvider = new CachingCostProvider(costCalculatorUsingExchanges, statsProvider, Optional.empty(), session, typeProvider); SubPlan subPlan = fragment(new Plan(node, typeProvider, StatsAndCosts.create(node, statsProvider, costProvider))); return subPlan.getFragment().getStatsAndCosts().getCosts().getOrDefault(node.getId(), PlanCostEstimate.unknown()); diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/assertions/PlanAssert.java b/core/trino-main/src/test/java/io/trino/sql/planner/assertions/PlanAssert.java index 2695b4008616..efc5f15f6d35 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/assertions/PlanAssert.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/assertions/PlanAssert.java @@ -15,6 +15,7 @@ import io.trino.Session; import io.trino.cost.CachingStatsProvider; +import io.trino.cost.CachingTableStatsProvider; import io.trino.cost.StatsAndCosts; import io.trino.cost.StatsCalculator; import io.trino.cost.StatsProvider; @@ -44,7 +45,7 @@ public static void assertPlan(Session session, Metadata metadata, FunctionManage public static void assertPlan(Session session, Metadata metadata, FunctionManager functionManager, StatsCalculator statsCalculator, Plan actual, Lookup lookup, PlanMatchPattern pattern) { - StatsProvider statsProvider = new CachingStatsProvider(statsCalculator, session, actual.getTypes()); + StatsProvider statsProvider = new CachingStatsProvider(statsCalculator, session, actual.getTypes(), new CachingTableStatsProvider(metadata, session)); assertPlan(session, metadata, functionManager, statsProvider, actual, lookup, pattern); } diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestJoinEnumerator.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestJoinEnumerator.java index 7a42a90b894d..23703efc7520 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestJoinEnumerator.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestJoinEnumerator.java @@ -19,6 +19,7 @@ import io.trino.Session; import io.trino.cost.CachingCostProvider; import io.trino.cost.CachingStatsProvider; +import io.trino.cost.CachingTableStatsProvider; import io.trino.cost.CostComparator; import io.trino.cost.CostProvider; import io.trino.cost.PlanCostEstimate; @@ -117,7 +118,8 @@ private Rule.Context createContext() Optional.empty(), noLookup(), queryRunner.getDefaultSession(), - symbolAllocator.getTypes()); + symbolAllocator.getTypes(), + new CachingTableStatsProvider(queryRunner.getMetadata(), queryRunner.getDefaultSession())); CachingCostProvider costProvider = new CachingCostProvider( queryRunner.getCostCalculator(), statsProvider, diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/test/RuleAssert.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/test/RuleAssert.java index 0614e668ee9e..0d9422eb5a8e 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/test/RuleAssert.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/test/RuleAssert.java @@ -17,12 +17,14 @@ import io.trino.Session; import io.trino.cost.CachingCostProvider; import io.trino.cost.CachingStatsProvider; +import io.trino.cost.CachingTableStatsProvider; import io.trino.cost.CostCalculator; import io.trino.cost.CostProvider; import io.trino.cost.PlanNodeStatsEstimate; import io.trino.cost.StatsAndCosts; import io.trino.cost.StatsCalculator; import io.trino.cost.StatsProvider; +import io.trino.cost.TableStatsProvider; import io.trino.execution.warnings.WarningCollector; import io.trino.matching.Capture; import io.trino.matching.Match; @@ -204,7 +206,7 @@ private static RuleApplication applyRule(Rule rule, PlanNode planNode, Ru private String formatPlan(PlanNode plan, TypeProvider types) { return inTransaction(session -> { - StatsProvider statsProvider = new CachingStatsProvider(statsCalculator, session, types); + StatsProvider statsProvider = new CachingStatsProvider(statsCalculator, session, types, new CachingTableStatsProvider(metadata, session)); CostProvider costProvider = new CachingCostProvider(costCalculator, statsProvider, session, types); return textLogicalPlan(plan, types, metadata, functionManager, StatsAndCosts.create(plan, statsProvider, costProvider), session, 2, false); }); @@ -223,7 +225,7 @@ private T inTransaction(Function transactionSessionConsumer) private Rule.Context ruleContext(StatsCalculator statsCalculator, CostCalculator costCalculator, SymbolAllocator symbolAllocator, Memo memo, Lookup lookup, Session session) { - StatsProvider statsProvider = new CachingStatsProvider(statsCalculator, Optional.of(memo), lookup, session, symbolAllocator.getTypes()); + StatsProvider statsProvider = new CachingStatsProvider(statsCalculator, Optional.of(memo), lookup, session, symbolAllocator.getTypes(), new CachingTableStatsProvider(metadata, session)); CostProvider costProvider = new CachingCostProvider(costCalculator, statsProvider, Optional.of(memo), session, symbolAllocator.getTypes()); return new Rule.Context() @@ -313,12 +315,12 @@ private static class TestingStatsCalculator } @Override - public PlanNodeStatsEstimate calculateStats(PlanNode node, StatsProvider sourceStats, Lookup lookup, Session session, TypeProvider types) + public PlanNodeStatsEstimate calculateStats(PlanNode node, StatsProvider sourceStats, Lookup lookup, Session session, TypeProvider types, TableStatsProvider tableStatsProvider) { if (stats.containsKey(node.getId())) { return stats.get(node.getId()); } - return delegate.calculateStats(node, sourceStats, lookup, session, types); + return delegate.calculateStats(node, sourceStats, lookup, session, types, tableStatsProvider); } public void setNodeStats(PlanNodeId nodeId, PlanNodeStatsEstimate nodeStats) diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/optimizations/TestBeginTableWrite.java b/core/trino-main/src/test/java/io/trino/sql/planner/optimizations/TestBeginTableWrite.java index da6a06f8ce36..771ab6ea4e60 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/optimizations/TestBeginTableWrite.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/optimizations/TestBeginTableWrite.java @@ -15,6 +15,7 @@ import com.google.common.collect.ImmutableList; import io.trino.Session; +import io.trino.cost.CachingTableStatsProvider; import io.trino.execution.warnings.WarningCollector; import io.trino.metadata.AbstractMockMetadata; import io.trino.metadata.Metadata; @@ -134,14 +135,16 @@ public void testUpdateWithInvalidNode() private void applyOptimization(Function planProvider) { Metadata metadata = new MockMetadata(); + Session session = testSessionBuilder().build(); new BeginTableWrite(metadata, createTestingFunctionManager()) .optimize( - planProvider.apply(new PlanBuilder(new PlanNodeIdAllocator(), metadata, testSessionBuilder().build())), - testSessionBuilder().build(), + planProvider.apply(new PlanBuilder(new PlanNodeIdAllocator(), metadata, session)), + session, empty(), new SymbolAllocator(), new PlanNodeIdAllocator(), - WarningCollector.NOOP); + WarningCollector.NOOP, + new CachingTableStatsProvider(metadata, session)); } private static class MockMetadata diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/optimizations/TestRemoveUnsupportedDynamicFilters.java b/core/trino-main/src/test/java/io/trino/sql/planner/optimizations/TestRemoveUnsupportedDynamicFilters.java index eda5b4cc2e0c..d27e10202799 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/optimizations/TestRemoveUnsupportedDynamicFilters.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/optimizations/TestRemoveUnsupportedDynamicFilters.java @@ -16,6 +16,7 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import io.trino.connector.CatalogName; +import io.trino.cost.CachingTableStatsProvider; import io.trino.cost.StatsAndCosts; import io.trino.execution.warnings.WarningCollector; import io.trino.metadata.Metadata; @@ -477,7 +478,13 @@ private PlanNode removeUnsupportedDynamicFilters(PlanNode root) return getQueryRunner().inTransaction(session -> { // metadata.getCatalogHandle() registers the catalog for the transaction session.getCatalog().ifPresent(catalog -> metadata.getCatalogHandle(session, catalog)); - PlanNode rewrittenPlan = new RemoveUnsupportedDynamicFilters(plannerContext).optimize(root, session, builder.getTypes(), new SymbolAllocator(), new PlanNodeIdAllocator(), WarningCollector.NOOP); + PlanNode rewrittenPlan = new RemoveUnsupportedDynamicFilters(plannerContext).optimize(root, + session, + builder.getTypes(), + new SymbolAllocator(), + new PlanNodeIdAllocator(), + WarningCollector.NOOP, + new CachingTableStatsProvider(metadata, session)); new DynamicFiltersChecker().validate(rewrittenPlan, session, plannerContext, createTestingTypeAnalyzer(plannerContext), diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/optimizations/TestUnaliasSymbolReferences.java b/core/trino-main/src/test/java/io/trino/sql/planner/optimizations/TestUnaliasSymbolReferences.java index 65f05d325480..fe757ebd5b91 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/optimizations/TestUnaliasSymbolReferences.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/optimizations/TestUnaliasSymbolReferences.java @@ -17,6 +17,7 @@ import com.google.common.collect.ImmutableMap; import io.trino.Session; import io.trino.connector.CatalogName; +import io.trino.cost.CachingTableStatsProvider; import io.trino.cost.StatsAndCosts; import io.trino.execution.warnings.WarningCollector; import io.trino.metadata.Metadata; @@ -131,7 +132,8 @@ private void assertOptimizedPlan(PlanOptimizer optimizer, PlanCreator planCreato planBuilder.getTypes(), symbolAllocator, idAllocator, - WarningCollector.NOOP); + WarningCollector.NOOP, + new CachingTableStatsProvider(metadata, session)); Plan actual = new Plan(optimized, planBuilder.getTypes(), StatsAndCosts.empty()); PlanAssert.assertPlan(session, queryRunner.getMetadata(), queryRunner.getFunctionManager(), queryRunner.getStatsCalculator(), actual, pattern); diff --git a/plugin/trino-iceberg/src/test/java/io/trino/plugin/iceberg/TestIcebergGetTableStatisticsOperations.java b/plugin/trino-iceberg/src/test/java/io/trino/plugin/iceberg/TestIcebergGetTableStatisticsOperations.java index 1689fba443e4..8b4437f718e8 100644 --- a/plugin/trino-iceberg/src/test/java/io/trino/plugin/iceberg/TestIcebergGetTableStatisticsOperations.java +++ b/plugin/trino-iceberg/src/test/java/io/trino/plugin/iceberg/TestIcebergGetTableStatisticsOperations.java @@ -107,7 +107,7 @@ public void testTwoWayJoin() "WHERE o.orderkey = l.orderkey"); assertThat(metadata.getMethodInvocations()).containsExactlyInAnyOrderElementsOf( ImmutableMultiset.builder() - .addCopies(CountingAccessMetadata.Methods.GET_TABLE_STATISTICS, 5) + .addCopies(CountingAccessMetadata.Methods.GET_TABLE_STATISTICS, 2) .build()); } @@ -119,7 +119,7 @@ public void testThreeWayJoin() "WHERE o.orderkey = l.orderkey AND c.custkey = o.custkey"); assertThat(metadata.getMethodInvocations()).containsExactlyInAnyOrderElementsOf( ImmutableMultiset.builder() - .addCopies(CountingAccessMetadata.Methods.GET_TABLE_STATISTICS, 9) + .addCopies(CountingAccessMetadata.Methods.GET_TABLE_STATISTICS, 3) .build()); } diff --git a/plugin/trino-iceberg/src/test/java/io/trino/plugin/iceberg/TestIcebergMetadataFileOperations.java b/plugin/trino-iceberg/src/test/java/io/trino/plugin/iceberg/TestIcebergMetadataFileOperations.java index ebc1e0db3c07..913a1eaeaf75 100644 --- a/plugin/trino-iceberg/src/test/java/io/trino/plugin/iceberg/TestIcebergMetadataFileOperations.java +++ b/plugin/trino-iceberg/src/test/java/io/trino/plugin/iceberg/TestIcebergMetadataFileOperations.java @@ -159,8 +159,8 @@ public void testJoin() assertFileSystemAccesses("SELECT name, age FROM test_join_t1 JOIN test_join_t2 ON test_join_t2.id = test_join_t1.id", ImmutableMultiset.builder() - .addCopies(new FileOperation(MANIFEST, INPUT_FILE_GET_LENGTH), 10) - .addCopies(new FileOperation(MANIFEST, INPUT_FILE_NEW_STREAM), 10) + .addCopies(new FileOperation(MANIFEST, INPUT_FILE_GET_LENGTH), 8) + .addCopies(new FileOperation(MANIFEST, INPUT_FILE_NEW_STREAM), 8) .addCopies(new FileOperation(METADATA_JSON, INPUT_FILE_NEW_STREAM), 2) .addCopies(new FileOperation(SNAPSHOT, INPUT_FILE_GET_LENGTH), 2) .addCopies(new FileOperation(SNAPSHOT, INPUT_FILE_NEW_STREAM), 2) @@ -177,8 +177,8 @@ public void testJoinWithPartitionedTable() assertFileSystemAccesses("SELECT count(*) FROM test_join_partitioned_t1 t1 join test_join_partitioned_t2 t2 on t1.a = t2.foo", ImmutableMultiset.builder() - .addCopies(new FileOperation(MANIFEST, INPUT_FILE_GET_LENGTH), 10) - .addCopies(new FileOperation(MANIFEST, INPUT_FILE_NEW_STREAM), 10) + .addCopies(new FileOperation(MANIFEST, INPUT_FILE_GET_LENGTH), 8) + .addCopies(new FileOperation(MANIFEST, INPUT_FILE_NEW_STREAM), 8) .addCopies(new FileOperation(METADATA_JSON, INPUT_FILE_NEW_STREAM), 2) .addCopies(new FileOperation(SNAPSHOT, INPUT_FILE_GET_LENGTH), 2) .addCopies(new FileOperation(SNAPSHOT, INPUT_FILE_NEW_STREAM), 2) diff --git a/testing/trino-tests/src/test/java/io/trino/tests/TestGetTableStatisticsOperations.java b/testing/trino-tests/src/test/java/io/trino/tests/TestGetTableStatisticsOperations.java index 7ba3ab29ccb9..52f75e043011 100644 --- a/testing/trino-tests/src/test/java/io/trino/tests/TestGetTableStatisticsOperations.java +++ b/testing/trino-tests/src/test/java/io/trino/tests/TestGetTableStatisticsOperations.java @@ -66,7 +66,7 @@ public void testTwoWayJoin() "WHERE o.orderkey = l.orderkey"); assertThat(metadata.getMethodInvocations()).containsExactlyInAnyOrderElementsOf( ImmutableMultiset.builder() - .addCopies(CountingAccessMetadata.Methods.GET_TABLE_STATISTICS, 3) + .addCopies(CountingAccessMetadata.Methods.GET_TABLE_STATISTICS, 2) .build()); } @@ -78,7 +78,7 @@ public void testThreeWayJoin() "WHERE o.orderkey = l.orderkey AND c.custkey = o.custkey"); assertThat(metadata.getMethodInvocations()).containsExactlyInAnyOrderElementsOf( ImmutableMultiset.builder() - .addCopies(CountingAccessMetadata.Methods.GET_TABLE_STATISTICS, 5) + .addCopies(CountingAccessMetadata.Methods.GET_TABLE_STATISTICS, 3) .build()); }