diff --git a/core/trino-main/src/main/java/io/trino/metadata/Metadata.java b/core/trino-main/src/main/java/io/trino/metadata/Metadata.java index 93012585cf60..ad5efc3f1e00 100644 --- a/core/trino-main/src/main/java/io/trino/metadata/Metadata.java +++ b/core/trino-main/src/main/java/io/trino/metadata/Metadata.java @@ -840,6 +840,15 @@ default boolean isMaterializedView(Session session, QualifiedObjectName viewName */ OptionalInt getMaxWriterTasks(Session session, String catalogName); + /** + * Workaround to lack of statistics about IO and CPU operations performed by the connector. + * In the long term, this should be replaced by improvements in the cost model. + * + * @return true if the cumulative cost of splitting a read of the specified tableHandle into multiple reads, + * each of which projects a subset of the required columns, is not significantly more than the cost of reading the specified tableHandle + */ + boolean allowSplittingReadIntoMultipleSubQueries(Session session, TableHandle tableHandle); + /** * Returns writer scaling options for the specified table. This method is called when table handle is not available during CTAS. 
*/ diff --git a/core/trino-main/src/main/java/io/trino/metadata/MetadataManager.java b/core/trino-main/src/main/java/io/trino/metadata/MetadataManager.java index 15053fd557e1..cfbc4cf532a6 100644 --- a/core/trino-main/src/main/java/io/trino/metadata/MetadataManager.java +++ b/core/trino-main/src/main/java/io/trino/metadata/MetadataManager.java @@ -2792,6 +2792,15 @@ public OptionalInt getMaxWriterTasks(Session session, String catalogName) return catalogMetadata.getMetadata(session).getMaxWriterTasks(session.toConnectorSession(catalogHandle)); } + @Override + public boolean allowSplittingReadIntoMultipleSubQueries(Session session, TableHandle tableHandle) + { + CatalogHandle catalogHandle = tableHandle.catalogHandle(); + CatalogMetadata catalogMetadata = getCatalogMetadata(session, catalogHandle); + ConnectorSession connectorSession = session.toConnectorSession(catalogHandle); + return catalogMetadata.getMetadata(session).allowSplittingReadIntoMultipleSubQueries(connectorSession, tableHandle.connectorHandle()); + } + @Override public WriterScalingOptions getNewTableWriterScalingOptions(Session session, QualifiedObjectName tableName, Map tableProperties) { diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/OptimizerConfig.java b/core/trino-main/src/main/java/io/trino/sql/planner/OptimizerConfig.java index a45a19ed28a8..fcd9f25353cc 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/OptimizerConfig.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/OptimizerConfig.java @@ -136,6 +136,7 @@ public enum DistinctAggregationsStrategy SINGLE_STEP, MARK_DISTINCT, PRE_AGGREGATE, + SPLIT_TO_SUBQUERIES, AUTOMATIC, } diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/PlanCopier.java b/core/trino-main/src/main/java/io/trino/sql/planner/PlanCopier.java index 5b315b12fe37..079f64482063 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/PlanCopier.java +++ 
b/core/trino-main/src/main/java/io/trino/sql/planner/PlanCopier.java @@ -13,6 +13,8 @@ */ package io.trino.sql.planner; +import io.trino.sql.planner.iterative.GroupReference; +import io.trino.sql.planner.iterative.Lookup; import io.trino.sql.planner.optimizations.UnaliasSymbolReferences; import io.trino.sql.planner.plan.AggregationNode; import io.trino.sql.planner.plan.ApplyNode; @@ -56,7 +58,12 @@ private PlanCopier() {} public static NodeAndMappings copyPlan(PlanNode plan, List fields, SymbolAllocator symbolAllocator, PlanNodeIdAllocator idAllocator) { - PlanNode copy = SimplePlanRewriter.rewriteWith(new Copier(idAllocator), plan, null); + return copyPlan(plan, fields, symbolAllocator, idAllocator, Lookup.noLookup()); + } + + public static NodeAndMappings copyPlan(PlanNode plan, List fields, SymbolAllocator symbolAllocator, PlanNodeIdAllocator idAllocator, Lookup lookup) + { + PlanNode copy = SimplePlanRewriter.rewriteWith(new Copier(idAllocator, lookup), plan, null); return new UnaliasSymbolReferences().reallocateSymbols(copy, fields, symbolAllocator); } @@ -64,10 +71,12 @@ private static class Copier extends SimplePlanRewriter { private final PlanNodeIdAllocator idAllocator; + private final Lookup lookup; - private Copier(PlanNodeIdAllocator idAllocator) + private Copier(PlanNodeIdAllocator idAllocator, Lookup lookup) { this.idAllocator = requireNonNull(idAllocator, "idAllocator is null"); + this.lookup = requireNonNull(lookup, "lookup is null"); } @Override @@ -76,6 +85,12 @@ protected PlanNode visitPlan(PlanNode node, RewriteContext context) throw new UnsupportedOperationException("plan copying not implemented for " + node.getClass().getSimpleName()); } + @Override + public PlanNode visitGroupReference(GroupReference node, RewriteContext context) + { + return context.rewrite(lookup.resolve(node)); + } + @Override public PlanNode visitAggregation(AggregationNode node, RewriteContext context) { diff --git 
a/core/trino-main/src/main/java/io/trino/sql/planner/PlanOptimizers.java b/core/trino-main/src/main/java/io/trino/sql/planner/PlanOptimizers.java index 6c1e514e6f05..a6232ca83919 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/PlanOptimizers.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/PlanOptimizers.java @@ -76,6 +76,7 @@ import io.trino.sql.planner.iterative.rule.MergeProjectWithValues; import io.trino.sql.planner.iterative.rule.MergeUnion; import io.trino.sql.planner.iterative.rule.MultipleDistinctAggregationToMarkDistinct; +import io.trino.sql.planner.iterative.rule.MultipleDistinctAggregationsToSubqueries; import io.trino.sql.planner.iterative.rule.OptimizeDuplicateInsensitiveJoins; import io.trino.sql.planner.iterative.rule.OptimizeMixedDistinctAggregations; import io.trino.sql.planner.iterative.rule.OptimizeRowPattern; @@ -682,10 +683,15 @@ public PlanOptimizers( new RemoveRedundantIdentityProjections(), new PushAggregationThroughOuterJoin(), new ReplaceRedundantJoinWithSource(), // Run this after PredicatePushDown optimizer as it inlines filter constants + // Run this after PredicatePushDown and PushProjectionIntoTableScan as it uses stats, and those two rules may reduce the number of partitions + // and columns we need stats for thus reducing the overhead of reading statistics from the metastore. 
+ new MultipleDistinctAggregationsToSubqueries(taskCountEstimator, metadata), + // Run SingleDistinctAggregationToGroupBy after MultipleDistinctAggregationsToSubqueries to ensure the single column distinct is optimized + new SingleDistinctAggregationToGroupBy(), new OptimizeMixedDistinctAggregations(plannerContext, taskCountEstimator), // Run this after aggregation pushdown so that multiple distinct aggregations can be pushed into a connector - // It also is run before MultipleDistinctAggregationToMarkDistinct to take precedence if enabled new ImplementFilteredAggregations(), // DistinctAggregationToGroupBy will add filters if fired - new MultipleDistinctAggregationToMarkDistinct(taskCountEstimator))), // Run this after aggregation pushdown so that multiple distinct aggregations can be pushed into a connector + new MultipleDistinctAggregationToMarkDistinct(taskCountEstimator, metadata))), inlineProjections, simplifyOptimizer, // Re-run the SimplifyExpressions to simplify any recomposed expressions from other optimizations pushProjectionIntoTableScanOptimizer, diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/DistinctAggregationStrategyChooser.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/DistinctAggregationStrategyChooser.java index 368dd6cdc40d..5f34790157b2 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/DistinctAggregationStrategyChooser.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/DistinctAggregationStrategyChooser.java @@ -13,13 +13,40 @@ */ package io.trino.sql.planner.iterative.rule; +import com.google.common.collect.ImmutableSet; +import com.google.common.collect.Sets; import io.trino.Session; import io.trino.cost.PlanNodeStatsEstimate;
import io.trino.cost.StatsProvider; import io.trino.cost.TaskCountEstimator; +import io.trino.metadata.Metadata; +import io.trino.sql.ir.Reference; +import io.trino.sql.planner.OptimizerConfig.DistinctAggregationsStrategy; +import io.trino.sql.planner.Symbol; +import io.trino.sql.planner.iterative.Lookup; import io.trino.sql.planner.plan.AggregationNode; +import io.trino.sql.planner.plan.FilterNode; +import io.trino.sql.planner.plan.PlanNode; +import io.trino.sql.planner.plan.ProjectNode; +import io.trino.sql.planner.plan.TableScanNode; +import io.trino.sql.planner.plan.UnionNode; +import java.util.List; +import java.util.Set; + +import static com.google.common.collect.ImmutableSet.toImmutableSet; +import static io.trino.SystemSessionProperties.distinctAggregationsStrategy; import static io.trino.SystemSessionProperties.getTaskConcurrency; +import static io.trino.sql.planner.OptimizerConfig.DistinctAggregationsStrategy.AUTOMATIC; +import static io.trino.sql.planner.OptimizerConfig.DistinctAggregationsStrategy.MARK_DISTINCT; +import static io.trino.sql.planner.OptimizerConfig.DistinctAggregationsStrategy.PRE_AGGREGATE; +import static io.trino.sql.planner.OptimizerConfig.DistinctAggregationsStrategy.SINGLE_STEP; +import static io.trino.sql.planner.OptimizerConfig.DistinctAggregationsStrategy.SPLIT_TO_SUBQUERIES; +import static io.trino.sql.planner.iterative.rule.MultipleDistinctAggregationToMarkDistinct.canUseMarkDistinct; +import static io.trino.sql.planner.iterative.rule.MultipleDistinctAggregationsToSubqueries.isAggregationCandidateForSplittingToSubqueries; +import static io.trino.sql.planner.iterative.rule.OptimizeMixedDistinctAggregations.canUsePreAggregate; +import static io.trino.sql.planner.iterative.rule.OptimizeMixedDistinctAggregations.distinctAggregationsUniqueArgumentCount; +import static io.trino.sql.planner.optimizations.PlanNodeSearcher.searchFrom; import static java.lang.Double.NaN; import static java.lang.Double.isNaN; import static 
java.util.Objects.requireNonNull; @@ -31,58 +58,86 @@ public class DistinctAggregationStrategyChooser { private static final int MARK_DISTINCT_MAX_OUTPUT_ROW_COUNT_MULTIPLIER = 8; private static final int PRE_AGGREGATE_MAX_OUTPUT_ROW_COUNT_MULTIPLIER = MARK_DISTINCT_MAX_OUTPUT_ROW_COUNT_MULTIPLIER * 8; + private static final double MAX_JOIN_GROUPING_KEYS_SIZE = 100 * 1024 * 1024; // 100 MB private final TaskCountEstimator taskCountEstimator; + private final Metadata metadata; - private DistinctAggregationStrategyChooser(TaskCountEstimator taskCountEstimator) + public DistinctAggregationStrategyChooser(TaskCountEstimator taskCountEstimator, Metadata metadata) { this.taskCountEstimator = requireNonNull(taskCountEstimator, "taskCountEstimator is null"); + this.metadata = requireNonNull(metadata, "metadata is null"); } - public static DistinctAggregationStrategyChooser createDistinctAggregationStrategyChooser(TaskCountEstimator taskCountEstimator) + public static DistinctAggregationStrategyChooser createDistinctAggregationStrategyChooser(TaskCountEstimator taskCountEstimator, Metadata metadata) { - return new DistinctAggregationStrategyChooser(taskCountEstimator); + return new DistinctAggregationStrategyChooser(taskCountEstimator, metadata); } - public boolean shouldAddMarkDistinct(AggregationNode aggregationNode, Session session, StatsProvider statsProvider) + public boolean shouldAddMarkDistinct(AggregationNode aggregationNode, Session session, StatsProvider statsProvider, Lookup lookup) { - return !canParallelizeSingleStepDistinctAggregation(aggregationNode, session, statsProvider, MARK_DISTINCT_MAX_OUTPUT_ROW_COUNT_MULTIPLIER); + return chooseMarkDistinctStrategy(aggregationNode, session, statsProvider, lookup) == MARK_DISTINCT; } - public boolean shouldUsePreAggregate(AggregationNode aggregationNode, Session session, StatsProvider statsProvider) + public boolean shouldUsePreAggregate(AggregationNode aggregationNode, Session session, StatsProvider statsProvider, 
Lookup lookup) { - if (canParallelizeSingleStepDistinctAggregation(aggregationNode, session, statsProvider, PRE_AGGREGATE_MAX_OUTPUT_ROW_COUNT_MULTIPLIER)) { - return false; - } + return chooseMarkDistinctStrategy(aggregationNode, session, statsProvider, lookup) == PRE_AGGREGATE; + } - // mark-distinct is better than pre-aggregate if the number of group-by keys is bigger than 2 - // because group-by keys are added to every grouping set and this makes partial aggregation behaves badly - return aggregationNode.getGroupingKeys().size() <= 2; + public boolean shouldSplitToSubqueries(AggregationNode aggregationNode, Session session, StatsProvider statsProvider, Lookup lookup) + { + return chooseMarkDistinctStrategy(aggregationNode, session, statsProvider, lookup) == SPLIT_TO_SUBQUERIES; } - private boolean canParallelizeSingleStepDistinctAggregation(AggregationNode aggregationNode, Session session, StatsProvider statsProvider, int maxOutputRowCountMultiplier) + private DistinctAggregationsStrategy chooseMarkDistinctStrategy(AggregationNode aggregationNode, Session session, StatsProvider statsProvider, Lookup lookup) { - if (aggregationNode.getGroupingKeys().isEmpty()) { - // global distinct aggregation is computed using a single thread. MarkDistinct will help parallelize the execution. 
- return false; + DistinctAggregationsStrategy distinctAggregationsStrategy = distinctAggregationsStrategy(session); + if (distinctAggregationsStrategy != AUTOMATIC) { + if (distinctAggregationsStrategy == MARK_DISTINCT && canUseMarkDistinct(aggregationNode)) { + return MARK_DISTINCT; + } + if (distinctAggregationsStrategy == PRE_AGGREGATE && canUsePreAggregate(aggregationNode)) { + return PRE_AGGREGATE; + } + if (distinctAggregationsStrategy == SPLIT_TO_SUBQUERIES && isAggregationCandidateForSplittingToSubqueries(aggregationNode) && isAggregationSourceSupportedForSubqueries(aggregationNode.getSource(), session, lookup)) { + return SPLIT_TO_SUBQUERIES; + } + // in case strategy is chosen by the session property, but we cannot use it, lets fallback to single-step + return SINGLE_STEP; } - double numberOfDistinctValues = getMinDistinctValueCountEstimate(aggregationNode, statsProvider); - if (isNaN(numberOfDistinctValues)) { - // if the estimate is unknown, use MarkDistinct to avoid query failure - return false; + int maxNumberOfConcurrentThreadsForAggregation = getMaxNumberOfConcurrentThreadsForAggregation(session); + + // use single_step if it can be parallelized + // small numberOfDistinctValues reduces the distinct aggregation parallelism, also because the partitioning may be skewed. + // this makes query to underutilize the cluster CPU but also to possibly concentrate memory on few nodes. + // single_step alternatives should increase the parallelism at a cost of CPU. + if (!aggregationNode.getGroupingKeys().isEmpty() && // global distinct aggregation is computed using a single thread. Strategies other than single_step will help parallelize the execution. 
+ !isNaN(numberOfDistinctValues) && // if the estimate is unknown, use alternatives to avoid query failure + (numberOfDistinctValues > PRE_AGGREGATE_MAX_OUTPUT_ROW_COUNT_MULTIPLIER * maxNumberOfConcurrentThreadsForAggregation || + (numberOfDistinctValues > MARK_DISTINCT_MAX_OUTPUT_ROW_COUNT_MULTIPLIER * maxNumberOfConcurrentThreadsForAggregation && + // if the NDV and the number of grouping keys is small, pre-aggregate is faster than single_step at a cost of CPU + aggregationNode.getGroupingKeys().size() > 2))) { + return SINGLE_STEP; } - int maxNumberOfConcurrentThreadsForAggregation = getMaxNumberOfConcurrentThreadsForAggregation(session); - if (numberOfDistinctValues <= maxOutputRowCountMultiplier * maxNumberOfConcurrentThreadsForAggregation) { - // small numberOfDistinctValues reduces the distinct aggregation parallelism, also because the partitioning may be skewed. - // This makes query to underutilize the cluster CPU but also to possibly concentrate memory on few nodes. - // MarkDistinct should increase the parallelism at a cost of CPU. 
- return false; + if (isAggregationCandidateForSplittingToSubqueries(aggregationNode) && shouldSplitAggregationToSubqueries(aggregationNode, session, statsProvider, lookup)) { + // for simple distinct aggregations on top of table scan it makes sense to split the aggregation into multiple subqueries, + // so they can be handled by the SingleDistinctAggregationToGroupBy and use other single column optimizations + return SPLIT_TO_SUBQUERIES; } - // can parallelize single-step, and single-step distinct is more efficient than alternatives - return true; + // mark-distinct is better than pre-aggregate if the number of group-by keys is bigger than 2 + // because group-by keys are added to every grouping set and this makes partial aggregation behaves badly + if (canUsePreAggregate(aggregationNode) && aggregationNode.getGroupingKeys().size() <= 2) { + return PRE_AGGREGATE; + } + else if (canUseMarkDistinct(aggregationNode)) { + return MARK_DISTINCT; + } + + // if no strategy found, use single_step by default + return SINGLE_STEP; } private int getMaxNumberOfConcurrentThreadsForAggregation(Session session) @@ -102,4 +157,103 @@ private double getMinDistinctValueCountEstimate(AggregationNode aggregationNode, .map(symbol -> sourceStats.getSymbolStatistics(symbol).getDistinctValuesCount()) .max(Double::compareTo).orElse(NaN); } + + // Since, to avoid degradation caused by multiple table scans, we want to split to sub-queries only if we are confident + // it brings big benefits, we are fairly conservative in the decision below. 
+ private boolean shouldSplitAggregationToSubqueries(AggregationNode aggregationNode, Session session, StatsProvider statsProvider, Lookup lookup) + { + if (!isAggregationSourceSupportedForSubqueries(aggregationNode.getSource(), session, lookup)) { + // only table scan, union, filter and project are supported + return false; + } + + if (searchFrom(aggregationNode.getSource(), lookup).whereIsInstanceOfAny(UnionNode.class).findFirst().isPresent()) { + // supporting union with auto decision is complex + return false; + } + + // skip if the source has a filter with low selectivity, as the scan and filter can + // be the main bottleneck in this case, and we want to avoid duplicating this effort. + if (searchFrom(aggregationNode.getSource(), lookup) + .where(node -> node instanceof FilterNode filterNode && isSelective(filterNode, statsProvider)) + .matches()) { + return false; + } + + if (isAdditionalReadOverheadTooExpensive(aggregationNode, statsProvider, lookup)) { + return false; + } + + if (aggregationNode.hasSingleGlobalAggregation()) { + return true; + } + + PlanNodeStatsEstimate stats = statsProvider.getStats(aggregationNode); + double groupingKeysSizeInBytes = stats.getOutputSizeInBytes(aggregationNode.getGroupingKeys()); + + // estimated group by result size is big so that both calculating aggregation multiple times and join would be inefficient + return !(isNaN(groupingKeysSizeInBytes) || groupingKeysSizeInBytes > MAX_JOIN_GROUPING_KEYS_SIZE); + } + + private static boolean isAdditionalReadOverheadTooExpensive(AggregationNode aggregationNode, StatsProvider statsProvider, Lookup lookup) + { + Set distinctInputs = aggregationNode.getAggregations() + .values().stream() + .filter(AggregationNode.Aggregation::isDistinct) + .flatMap(aggregation -> aggregation.getArguments().stream()) + .filter(Reference.class::isInstance) + .map(Symbol::from) + .collect(toImmutableSet()); + + TableScanNode tableScanNode = (TableScanNode) searchFrom(aggregationNode.getSource(), 
lookup).whereIsInstanceOfAny(TableScanNode.class).findOnlyElement(); + Set additionalColumns = Sets.difference(ImmutableSet.copyOf(tableScanNode.getOutputSymbols()), distinctInputs); + + // Group by columns need to read N times, where N is number of sub-queries. + // Distinct columns are read once. + double singleTableScanDataSize = statsProvider.getStats(tableScanNode).getOutputSizeInBytes(tableScanNode.getOutputSymbols()); + double additionalColumnsDataSize = statsProvider.getStats(tableScanNode).getOutputSizeInBytes(additionalColumns); + long subqueryCount = distinctAggregationsUniqueArgumentCount(aggregationNode); + double distinctInputDataSize = singleTableScanDataSize - additionalColumnsDataSize; + double subqueriesTotalDataSize = additionalColumnsDataSize * subqueryCount + distinctInputDataSize; + + return isNaN(subqueriesTotalDataSize) || + isNaN(singleTableScanDataSize) || + // we would read more than 50% more data + subqueriesTotalDataSize / singleTableScanDataSize > 1.5; + } + + private static boolean isSelective(FilterNode filterNode, StatsProvider statsProvider) + { + double filterOutputRowCount = statsProvider.getStats(filterNode).getOutputRowCount(); + double filterSourceRowCount = statsProvider.getStats(filterNode.getSource()).getOutputRowCount(); + return filterOutputRowCount / filterSourceRowCount < 0.5; + } + + // Only table scan, union, filter and project are supported. + // PlanCopier.copyPlan must support all supported nodes here. + // Additionally, we should split the table scan only if reading single columns is efficient in the given connector. 
+ private boolean isAggregationSourceSupportedForSubqueries(PlanNode source, Session session, Lookup lookup) + { + if (searchFrom(source, lookup) + .where(node -> !(node instanceof TableScanNode + || node instanceof FilterNode + || node instanceof ProjectNode + || node instanceof UnionNode)) + .findFirst() + .isPresent()) { + return false; + } + + List tableScans = searchFrom(source, lookup) + .whereIsInstanceOfAny(TableScanNode.class) + .findAll(); + + if (tableScans.isEmpty()) { + // at least one table scan is expected + return false; + } + + return tableScans.stream() + .allMatch(tableScanNode -> metadata.allowSplittingReadIntoMultipleSubQueries(session, ((TableScanNode) tableScanNode).getTable())); + } } diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/MultipleDistinctAggregationToMarkDistinct.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/MultipleDistinctAggregationToMarkDistinct.java index 4c9fb61c99c2..bc41ef0a012b 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/MultipleDistinctAggregationToMarkDistinct.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/MultipleDistinctAggregationToMarkDistinct.java @@ -13,13 +13,13 @@ */ package io.trino.sql.planner.iterative.rule; -import com.google.common.base.Predicates; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableSet; import com.google.common.collect.Iterables; import io.trino.cost.TaskCountEstimator; import io.trino.matching.Captures; import io.trino.matching.Pattern; +import io.trino.metadata.Metadata; import io.trino.sql.planner.OptimizerConfig.DistinctAggregationsStrategy; import io.trino.sql.planner.Symbol; import io.trino.sql.planner.iterative.Rule; @@ -68,12 +68,13 @@ public class MultipleDistinctAggregationToMarkDistinct implements Rule { private static final Pattern PATTERN = aggregation() - .matching( - Predicates.and( - 
MultipleDistinctAggregationToMarkDistinct::hasNoDistinctWithFilterOrMask, - Predicates.or( - MultipleDistinctAggregationToMarkDistinct::hasMultipleDistincts, - MultipleDistinctAggregationToMarkDistinct::hasMixedDistinctAndNonDistincts))); + .matching(MultipleDistinctAggregationToMarkDistinct::canUseMarkDistinct); + + public static boolean canUseMarkDistinct(AggregationNode aggregationNode) + { + return hasNoDistinctWithFilterOrMask(aggregationNode) && + (hasMultipleDistincts(aggregationNode) || hasMixedDistinctAndNonDistincts(aggregationNode)); + } private static boolean hasNoDistinctWithFilterOrMask(AggregationNode aggregationNode) { @@ -105,9 +106,9 @@ private static boolean hasMixedDistinctAndNonDistincts(AggregationNode aggregati private final DistinctAggregationStrategyChooser distinctAggregationStrategyChooser; - public MultipleDistinctAggregationToMarkDistinct(TaskCountEstimator taskCountEstimator) + public MultipleDistinctAggregationToMarkDistinct(TaskCountEstimator taskCountEstimator, Metadata metadata) { - this.distinctAggregationStrategyChooser = createDistinctAggregationStrategyChooser(taskCountEstimator); + this.distinctAggregationStrategyChooser = createDistinctAggregationStrategyChooser(taskCountEstimator, metadata); } @Override @@ -121,7 +122,7 @@ public Result apply(AggregationNode parent, Captures captures, Context context) { DistinctAggregationsStrategy distinctAggregationsStrategy = distinctAggregationsStrategy(context.getSession()); if (!(distinctAggregationsStrategy.equals(MARK_DISTINCT) || - (distinctAggregationsStrategy.equals(AUTOMATIC) && distinctAggregationStrategyChooser.shouldAddMarkDistinct(parent, context.getSession(), context.getStatsProvider())))) { + (distinctAggregationsStrategy.equals(AUTOMATIC) && distinctAggregationStrategyChooser.shouldAddMarkDistinct(parent, context.getSession(), context.getStatsProvider(), context.getLookup())))) { return Result.empty(); } diff --git 
a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/MultipleDistinctAggregationsToSubqueries.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/MultipleDistinctAggregationsToSubqueries.java new file mode 100644 index 000000000000..4e2ae1ed7f23 --- /dev/null +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/MultipleDistinctAggregationsToSubqueries.java @@ -0,0 +1,208 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.sql.planner.iterative.rule; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableSet; +import io.trino.cost.TaskCountEstimator; +import io.trino.matching.Captures; +import io.trino.matching.Pattern; +import io.trino.metadata.Metadata; +import io.trino.sql.ir.Expression; +import io.trino.sql.planner.NodeAndMappings; +import io.trino.sql.planner.PlanCopier; +import io.trino.sql.planner.Symbol; +import io.trino.sql.planner.iterative.Rule; +import io.trino.sql.planner.plan.AggregationNode; +import io.trino.sql.planner.plan.AggregationNode.Aggregation; +import io.trino.sql.planner.plan.Assignments; +import io.trino.sql.planner.plan.JoinNode; +import io.trino.sql.planner.plan.JoinNode.EquiJoinClause; +import io.trino.sql.planner.plan.PlanNode; +import io.trino.sql.planner.plan.ProjectNode; + +import java.util.Comparator; +import java.util.HashMap; +import java.util.LinkedHashMap; 
+import java.util.List; +import java.util.Map; +import java.util.Map.Entry; +import java.util.Optional; +import java.util.Set; +import java.util.stream.IntStream; + +import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.collect.ImmutableList.toImmutableList; +import static io.trino.sql.planner.iterative.rule.DistinctAggregationStrategyChooser.createDistinctAggregationStrategyChooser; +import static io.trino.sql.planner.plan.JoinType.INNER; +import static io.trino.sql.planner.plan.Patterns.aggregation; + +/** + * Transforms plans of the following shape: + *
+ * - Aggregation
+ *        GROUP BY (k)
+ *        F1(DISTINCT a0, a1, ...)
+ *        F2(DISTINCT b0, b1, ...)
+ *        F3(DISTINCT c0, c1, ...)
+ *     - X
+ * 
+ * into + *
+ * - Join
+ *     on left.k = right.k
+ *     - Aggregation
+ *         GROUP BY (k)
+ *         F1(DISTINCT a0, a1, ...)
+ *         F2(DISTINCT b0, b1, ...)
+ *       - X
+ *     - Aggregation
+ *         GROUP BY (k)
+ *         F3(DISTINCT c0, c1, ...)
+ *       - X
+ * 
+ *

+ * This improves plan parallelism and allows {@link SingleDistinctAggregationToGroupBy} to optimize the single input distinct aggregation further. + * The cost is we calculate X and GROUP BY (k) multiple times, so this rule is only beneficial if the calculations are cheap compared to + * other distinct aggregation strategies. + */ +public class MultipleDistinctAggregationsToSubqueries + implements Rule +{ + private static final Pattern PATTERN = aggregation() + .matching(MultipleDistinctAggregationsToSubqueries::isAggregationCandidateForSplittingToSubqueries); + + // In addition to this check, DistinctAggregationController.isAggregationSourceSupportedForSubqueries, that accesses Metadata, + // needs also pass, for the plan to be applicable for this rule, + public static boolean isAggregationCandidateForSplittingToSubqueries(AggregationNode aggregationNode) + { + // TODO: we could support non-distinct aggregations if SingleDistinctAggregationToGroupBy supports it + return SingleDistinctAggregationToGroupBy.allDistinctAggregates(aggregationNode) && + OptimizeMixedDistinctAggregations.hasMultipleDistincts(aggregationNode) && + // if we have more than one grouping set, we can have duplicated grouping sets and handling this is complex + aggregationNode.getGroupingSetCount() == 1 && + // hash symbol is added late in the planning, and handling it here would increase complexity + aggregationNode.getHashSymbol().isEmpty(); + } + + private final DistinctAggregationStrategyChooser distinctAggregationStrategyChooser; + + public MultipleDistinctAggregationsToSubqueries(TaskCountEstimator taskCountEstimator, Metadata metadata) + { + this.distinctAggregationStrategyChooser = createDistinctAggregationStrategyChooser(taskCountEstimator, metadata); + } + + @Override + public Pattern getPattern() + { + return PATTERN; + } + + @Override + public Result apply(AggregationNode aggregationNode, Captures captures, Context context) + { + if 
(!distinctAggregationStrategyChooser.shouldSplitToSubqueries(aggregationNode, context.getSession(), context.getStatsProvider(), context.getLookup())) { + return Result.empty(); + } + // group aggregations by arguments + Map, Map> aggregationsByArguments = new LinkedHashMap<>(aggregationNode.getAggregations().size()); + // sort the aggregation by output symbol to have consistent join layout + List> sortedAggregations = aggregationNode.getAggregations().entrySet() + .stream() + .sorted(Comparator.comparing(entry -> entry.getKey().name())) + .collect(toImmutableList()); + for (Entry entry : sortedAggregations) { + aggregationsByArguments.compute(ImmutableSet.copyOf(entry.getValue().getArguments()), (_, current) -> { + if (current == null) { + current = new HashMap<>(); + } + current.put(entry.getKey(), entry.getValue()); + return current; + }); + } + + PlanNode right = null; + List rightJoinSymbols = null; + Assignments.Builder assignments = Assignments.builder(); + List> aggregationsByArgumentsList = ImmutableList.copyOf(aggregationsByArguments.values()); + for (int i = aggregationsByArgumentsList.size() - 1; i > 0; i--) { + // go from right to left and build the right side of the join + Map aggregations = aggregationsByArgumentsList.get(i); + AggregationNode subAggregationNode = buildSubAggregation(aggregationNode, aggregations, assignments, context); + + if (right == null) { + right = subAggregationNode; + rightJoinSymbols = subAggregationNode.getGroupingKeys(); + } + else { + right = buildJoin(subAggregationNode, subAggregationNode.getGroupingKeys(), right, rightJoinSymbols, context); + } + } + + // the first aggregation is the left side of the top join + AggregationNode left = buildSubAggregation(aggregationNode, aggregationsByArgumentsList.getFirst(), assignments, context); + + for (int i = 0; i < left.getGroupingKeys().size(); i++) { + assignments.put(aggregationNode.getGroupingKeys().get(i), left.getGroupingKeys().get(i).toSymbolReference()); + } + JoinNode 
topJoin = buildJoin(left, left.getGroupingKeys(), right, rightJoinSymbols, context); + ProjectNode result = new ProjectNode(aggregationNode.getId(), topJoin, assignments.build()); + return Result.ofPlanNode(result); + } + + private AggregationNode buildSubAggregation(AggregationNode aggregationNode, Map aggregations, Assignments.Builder assignments, Context context) + { + List originalAggregationOutputSymbols = ImmutableList.copyOf(aggregations.keySet()); + // copy the plan so that both plan node ids and symbols are not duplicated between sub aggregations + NodeAndMappings copied = PlanCopier.copyPlan( + AggregationNode.builderFrom(aggregationNode).setAggregations(aggregations).build(), + originalAggregationOutputSymbols, + context.getSymbolAllocator(), + context.getIdAllocator(), + context.getLookup()); + AggregationNode subAggregationNode = (AggregationNode) copied.getNode(); + // add the mapping from the new output symbols to original ones + for (int i = 0; i < originalAggregationOutputSymbols.size(); i++) { + assignments.put(originalAggregationOutputSymbols.get(i), copied.getFields().get(i).toSymbolReference()); + } + return subAggregationNode; + } + + private JoinNode buildJoin(PlanNode left, List leftJoinSymbols, PlanNode right, List rightJoinSymbols, Context context) + { + checkArgument(leftJoinSymbols.size() == rightJoinSymbols.size()); + List criteria = IntStream.range(0, leftJoinSymbols.size()) + .mapToObj(i -> new EquiJoinClause(leftJoinSymbols.get(i), rightJoinSymbols.get(i))) + .collect(toImmutableList()); + + // TODO: we dont need dynamic filters for this join at all. 
We could add skipDf field to the JoinNode and make use of it in PredicatePushDown + return new JoinNode( + context.getIdAllocator().getNextId(), + INNER, + left, + right, + criteria, + left.getOutputSymbols(), + right.getOutputSymbols(), + false, // since we only work on global aggregation or grouped rows, there are no duplicates, so we don't have to skip it + Optional.empty(), + Optional.empty(), + Optional.empty(), + Optional.empty(), + Optional.empty(), + ImmutableMap.of(), + Optional.empty()); + } +} diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/OptimizeMixedDistinctAggregations.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/OptimizeMixedDistinctAggregations.java index 32ee069f5556..f9f9fb427429 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/OptimizeMixedDistinctAggregations.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/OptimizeMixedDistinctAggregations.java @@ -13,7 +13,6 @@ */ package io.trino.sql.planner.iterative.rule; -import com.google.common.base.Predicates; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.common.collect.ImmutableSet; @@ -83,18 +82,25 @@ public class OptimizeMixedDistinctAggregations private static final CatalogSchemaFunctionName APPROX_DISTINCT_NAME = builtinFunctionName("approx_distinct"); private static final Pattern PATTERN = aggregation() - .matching(Predicates.and( - Predicates.or( - // single distinct can be supported in this rule, but it is already supported by SingleDistinctAggregationToGroupBy, which produces simpler plans (without group-id) - OptimizeMixedDistinctAggregations::hasMultipleDistincts, - OptimizeMixedDistinctAggregations::hasMixedDistinctAndNonDistincts), - OptimizeMixedDistinctAggregations::allDistinctAggregationsHaveSingleArgument, - OptimizeMixedDistinctAggregations::noFilters, - OptimizeMixedDistinctAggregations::noMasks, - 
aggregation -> !aggregation.hasOrderings(), - aggregation -> aggregation.getStep().equals(SINGLE))); - - private static boolean hasMultipleDistincts(AggregationNode aggregationNode) + .matching(OptimizeMixedDistinctAggregations::canUsePreAggregate); + + public static boolean canUsePreAggregate(AggregationNode aggregationNode) + { + // single distinct can be supported in this rule, but it is already supported by SingleDistinctAggregationToGroupBy, which produces simpler plans (without group-id) + return (hasMultipleDistincts(aggregationNode) || hasMixedDistinctAndNonDistincts(aggregationNode)) && + allDistinctAggregationsHaveSingleArgument(aggregationNode) && + noFilters(aggregationNode) && + noMasks(aggregationNode) && + !aggregationNode.hasOrderings() && + aggregationNode.getStep().equals(SINGLE); + } + + public static boolean hasMultipleDistincts(AggregationNode aggregationNode) + { + return distinctAggregationsUniqueArgumentCount(aggregationNode) > 1; + } + + public static long distinctAggregationsUniqueArgumentCount(AggregationNode aggregationNode) { return aggregationNode.getAggregations() .values().stream() @@ -102,7 +108,7 @@ private static boolean hasMultipleDistincts(AggregationNode aggregationNode) .map(Aggregation::getArguments) .map(HashSet::new) .distinct() - .count() > 1; + .count(); } private static boolean hasMixedDistinctAndNonDistincts(AggregationNode aggregationNode) @@ -143,7 +149,7 @@ private static boolean noMasks(AggregationNode aggregationNode) public OptimizeMixedDistinctAggregations(PlannerContext plannerContext, TaskCountEstimator taskCountEstimator) { this.functionResolver = plannerContext.getFunctionResolver(); - this.distinctAggregationStrategyChooser = createDistinctAggregationStrategyChooser(taskCountEstimator); + this.distinctAggregationStrategyChooser = createDistinctAggregationStrategyChooser(taskCountEstimator, plannerContext.getMetadata()); } @Override @@ -158,7 +164,7 @@ public Result apply(AggregationNode node, Captures 
captures, Context context) DistinctAggregationsStrategy distinctAggregationsStrategy = distinctAggregationsStrategy(context.getSession()); if (!(distinctAggregationsStrategy.equals(PRE_AGGREGATE) || - (distinctAggregationsStrategy.equals(AUTOMATIC) && distinctAggregationStrategyChooser.shouldUsePreAggregate(node, context.getSession(), context.getStatsProvider())))) { + (distinctAggregationsStrategy.equals(AUTOMATIC) && distinctAggregationStrategyChooser.shouldUsePreAggregate(node, context.getSession(), context.getStatsProvider(), context.getLookup())))) { return Result.empty(); } @@ -209,9 +215,9 @@ public Result apply(AggregationNode node, Captures captures, Context context) Aggregation originalAggregation = entry.getValue(); if (originalAggregation.isDistinct()) { // for the outer aggregation node, replace distinct aggregation with non-distinct aggregation with FILTER (WHERE group_id=X) - Symbol aggregationInput = Symbol.from(originalAggregation.getArguments().get(0)); + Symbol aggregationInput = Symbol.from(originalAggregation.getArguments().getFirst()); Integer groupId = distinctAggregationArgumentToGroupIdMap.get(aggregationInput); - Symbol groupIdFilterSymbol = groupIdFilterSymbolByGroupId.computeIfAbsent(groupId, id -> { + Symbol groupIdFilterSymbol = groupIdFilterSymbolByGroupId.computeIfAbsent(groupId, _ -> { Symbol filterSymbol = symbolAllocator.newSymbol("gid-filter-" + groupId, BOOLEAN); groupIdFilters.put(filterSymbol, new Comparison( EQUAL, diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/SingleDistinctAggregationToGroupBy.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/SingleDistinctAggregationToGroupBy.java index c44b7d3094e2..08470087525c 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/SingleDistinctAggregationToGroupBy.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/SingleDistinctAggregationToGroupBy.java @@ -74,7 +74,7 @@ private 
static boolean hasSingleDistinctInput(AggregationNode aggregationNode) .count() == 1; } - private static boolean allDistinctAggregates(AggregationNode aggregationNode) + public static boolean allDistinctAggregates(AggregationNode aggregationNode) { return aggregationNode.getAggregations() .values().stream() diff --git a/core/trino-main/src/main/java/io/trino/tracing/TracingConnectorMetadata.java b/core/trino-main/src/main/java/io/trino/tracing/TracingConnectorMetadata.java index 1115dd9403e3..eb03924196fb 100644 --- a/core/trino-main/src/main/java/io/trino/tracing/TracingConnectorMetadata.java +++ b/core/trino-main/src/main/java/io/trino/tracing/TracingConnectorMetadata.java @@ -1414,6 +1414,15 @@ public OptionalInt getMaxWriterTasks(ConnectorSession session) } } + @Override + public boolean allowSplittingReadIntoMultipleSubQueries(ConnectorSession session, ConnectorTableHandle tableHandle) + { + Span span = startSpan("allowSplittingReadIntoMultipleSubQueries"); + try (var ignored = scopedSpan(span)) { + return delegate.allowSplittingReadIntoMultipleSubQueries(session, tableHandle); + } + } + @Override public WriterScalingOptions getNewTableWriterScalingOptions(ConnectorSession session, SchemaTableName tableName, Map tableProperties) { diff --git a/core/trino-main/src/main/java/io/trino/tracing/TracingMetadata.java b/core/trino-main/src/main/java/io/trino/tracing/TracingMetadata.java index cee335201efe..2441cee43568 100644 --- a/core/trino-main/src/main/java/io/trino/tracing/TracingMetadata.java +++ b/core/trino-main/src/main/java/io/trino/tracing/TracingMetadata.java @@ -1522,6 +1522,15 @@ public OptionalInt getMaxWriterTasks(Session session, String catalogName) } } + @Override + public boolean allowSplittingReadIntoMultipleSubQueries(Session session, TableHandle tableHandle) + { + Span span = startSpan("allowSplittingReadIntoMultipleSubQueries", tableHandle); + try (var ignored = scopedSpan(span)) { + return 
delegate.allowSplittingReadIntoMultipleSubQueries(session, tableHandle); + } + } + @Override public WriterScalingOptions getNewTableWriterScalingOptions(Session session, QualifiedObjectName tableName, Map tableProperties) { diff --git a/core/trino-main/src/test/java/io/trino/connector/MockConnector.java b/core/trino-main/src/test/java/io/trino/connector/MockConnector.java index 08dd9abf0df8..ad14fb506016 100644 --- a/core/trino-main/src/test/java/io/trino/connector/MockConnector.java +++ b/core/trino-main/src/test/java/io/trino/connector/MockConnector.java @@ -195,6 +195,7 @@ public class MockConnector private final BiFunction> getLayoutForTableExecute; private final WriterScalingOptions writerScalingOptions; private final Supplier> capabilities; + private final boolean allowSplittingReadIntoMultipleSubQueries; MockConnector( Function metadataWrapper, @@ -247,7 +248,8 @@ public class MockConnector OptionalInt maxWriterTasks, BiFunction> getLayoutForTableExecute, WriterScalingOptions writerScalingOptions, - Supplier> capabilities) + Supplier> capabilities, + boolean allowSplittingReadIntoMultipleSubQueries) { this.metadataWrapper = requireNonNull(metadataWrapper, "metadataWrapper is null"); this.sessionProperties = ImmutableList.copyOf(requireNonNull(sessionProperties, "sessionProperties is null")); @@ -300,6 +302,7 @@ public class MockConnector this.getLayoutForTableExecute = requireNonNull(getLayoutForTableExecute, "getLayoutForTableExecute is null"); this.writerScalingOptions = requireNonNull(writerScalingOptions, "writerScalingOptions is null"); this.capabilities = requireNonNull(capabilities, "capabilities is null"); + this.allowSplittingReadIntoMultipleSubQueries = allowSplittingReadIntoMultipleSubQueries; } @Override @@ -317,7 +320,7 @@ public ConnectorTransactionHandle beginTransaction(IsolationLevel isolationLevel @Override public ConnectorMetadata getMetadata(ConnectorSession session, ConnectorTransactionHandle transaction) { - return 
metadataWrapper.apply(new MockConnectorMetadata()); + return metadataWrapper.apply(new MockConnectorMetadata(allowSplittingReadIntoMultipleSubQueries)); } @Override @@ -444,6 +447,13 @@ public Set getCapabilities() private class MockConnectorMetadata implements ConnectorMetadata { + private final boolean allowSplittingReadIntoMultipleSubQueries; + + public MockConnectorMetadata(boolean allowSplittingReadIntoMultipleSubQueries) + { + this.allowSplittingReadIntoMultipleSubQueries = allowSplittingReadIntoMultipleSubQueries; + } + @Override public boolean schemaExists(ConnectorSession session, String schemaName) { @@ -980,6 +990,12 @@ public OptionalInt getMaxWriterTasks(ConnectorSession session) return maxWriterTasks; } + @Override + public boolean allowSplittingReadIntoMultipleSubQueries(ConnectorSession session, ConnectorTableHandle tableHandle) + { + return allowSplittingReadIntoMultipleSubQueries; + } + @Override public BeginTableExecuteResult beginTableExecute(ConnectorSession session, ConnectorTableExecuteHandle tableExecuteHandle, ConnectorTableHandle updatedSourceTableHandle) { diff --git a/core/trino-main/src/test/java/io/trino/connector/MockConnectorFactory.java b/core/trino-main/src/test/java/io/trino/connector/MockConnectorFactory.java index 666e91ca019a..7a0c8e090360 100644 --- a/core/trino-main/src/test/java/io/trino/connector/MockConnectorFactory.java +++ b/core/trino-main/src/test/java/io/trino/connector/MockConnectorFactory.java @@ -149,6 +149,7 @@ public class MockConnectorFactory private final WriterScalingOptions writerScalingOptions; private final Supplier> capabilities; + private final boolean allowSplittingReadIntoMultipleSubQueries; private MockConnectorFactory( String name, @@ -157,7 +158,7 @@ private MockConnectorFactory( Function> listSchemaNames, BiFunction> listTables, Optional>> streamTableColumns, - Optional streamRelationColumns, + Optional streamRelationColumns, BiFunction> getViews, Supplier>> getViewProperties, Supplier>> 
getMaterializedViewProperties, @@ -202,7 +203,8 @@ private MockConnectorFactory( OptionalInt maxWriterTasks, BiFunction> getLayoutForTableExecute, WriterScalingOptions writerScalingOptions, - Supplier> capabilities) + Supplier> capabilities, + boolean allowSplittingReadIntoMultipleSubQueries) { this.name = requireNonNull(name, "name is null"); this.sessionProperty = ImmutableList.copyOf(requireNonNull(sessionProperty, "sessionProperty is null")); @@ -256,6 +258,7 @@ private MockConnectorFactory( this.getLayoutForTableExecute = requireNonNull(getLayoutForTableExecute, "getLayoutForTableExecute is null"); this.writerScalingOptions = requireNonNull(writerScalingOptions, "writerScalingOptions is null"); this.capabilities = requireNonNull(capabilities, "capabilities is null"); + this.allowSplittingReadIntoMultipleSubQueries = allowSplittingReadIntoMultipleSubQueries; } @Override @@ -318,7 +321,8 @@ public Connector create(String catalogName, Map config, Connecto maxWriterTasks, getLayoutForTableExecute, writerScalingOptions, - capabilities); + capabilities, + allowSplittingReadIntoMultipleSubQueries); } public static MockConnectorFactory create() @@ -474,6 +478,7 @@ public static final class Builder private BiFunction> getLayoutForTableExecute = (session, handle) -> Optional.empty(); private WriterScalingOptions writerScalingOptions = WriterScalingOptions.DISABLED; private Supplier> capabilities = ImmutableSet::of; + private boolean allowSplittingReadIntoMultipleSubQueries; private Builder() {} @@ -833,6 +838,12 @@ public Builder withCapabilities(Supplier> capabilitie return this; } + public Builder withAllowSplittingReadIntoMultipleSubQueries(boolean allowSplittingReadIntoMultipleSubQueries) + { + this.allowSplittingReadIntoMultipleSubQueries = allowSplittingReadIntoMultipleSubQueries; + return this; + } + public MockConnectorFactory build() { Optional accessControl = Optional.empty(); @@ -891,7 +902,8 @@ public MockConnectorFactory build() maxWriterTasks, 
getLayoutForTableExecute, writerScalingOptions, - capabilities); + capabilities, + allowSplittingReadIntoMultipleSubQueries); } public static Function> defaultListSchemaNames() diff --git a/core/trino-main/src/test/java/io/trino/metadata/AbstractMockMetadata.java b/core/trino-main/src/test/java/io/trino/metadata/AbstractMockMetadata.java index 34039068107d..6955c5dfccbf 100644 --- a/core/trino-main/src/test/java/io/trino/metadata/AbstractMockMetadata.java +++ b/core/trino-main/src/test/java/io/trino/metadata/AbstractMockMetadata.java @@ -1019,6 +1019,12 @@ public OptionalInt getMaxWriterTasks(Session session, String catalogName) throw new UnsupportedOperationException(); } + @Override + public boolean allowSplittingReadIntoMultipleSubQueries(Session session, TableHandle tableHandle) + { + throw new UnsupportedOperationException(); + } + @Override public WriterScalingOptions getNewTableWriterScalingOptions(Session session, QualifiedObjectName tableName, Map tableProperties) { diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/TestLogicalPlanner.java b/core/trino-main/src/test/java/io/trino/sql/planner/TestLogicalPlanner.java index 4006c1bca139..6029a9a7ac1b 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/TestLogicalPlanner.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/TestLogicalPlanner.java @@ -30,6 +30,7 @@ import io.trino.spi.predicate.TupleDomain; import io.trino.spi.predicate.ValueSet; import io.trino.spi.type.RowType; +import io.trino.sql.ir.Between; import io.trino.sql.ir.Call; import io.trino.sql.ir.Cast; import io.trino.sql.ir.Coalesce; @@ -94,6 +95,7 @@ import static com.google.common.collect.ImmutableList.toImmutableList; import static com.google.common.collect.MoreCollectors.toOptional; import static io.airlift.slice.Slices.utf8Slice; +import static io.trino.SystemSessionProperties.COST_ESTIMATION_WORKER_COUNT; import static io.trino.SystemSessionProperties.DISTINCT_AGGREGATIONS_STRATEGY; import static 
io.trino.SystemSessionProperties.DISTRIBUTED_SORT; import static io.trino.SystemSessionProperties.FILTERING_SEMI_JOIN_TO_INNER; @@ -142,6 +144,7 @@ import static io.trino.sql.planner.assertions.PlanMatchPattern.exchange; import static io.trino.sql.planner.assertions.PlanMatchPattern.expression; import static io.trino.sql.planner.assertions.PlanMatchPattern.filter; +import static io.trino.sql.planner.assertions.PlanMatchPattern.groupId; import static io.trino.sql.planner.assertions.PlanMatchPattern.identityProject; import static io.trino.sql.planner.assertions.PlanMatchPattern.join; import static io.trino.sql.planner.assertions.PlanMatchPattern.limit; @@ -159,6 +162,7 @@ import static io.trino.sql.planner.assertions.PlanMatchPattern.strictConstrainedTableScan; import static io.trino.sql.planner.assertions.PlanMatchPattern.strictProject; import static io.trino.sql.planner.assertions.PlanMatchPattern.strictTableScan; +import static io.trino.sql.planner.assertions.PlanMatchPattern.symbol; import static io.trino.sql.planner.assertions.PlanMatchPattern.tableScan; import static io.trino.sql.planner.assertions.PlanMatchPattern.topN; import static io.trino.sql.planner.assertions.PlanMatchPattern.topNRanking; @@ -205,6 +209,7 @@ public class TestLogicalPlanner private static final ResolvedFunction LOWER = FUNCTIONS.resolveFunction("lower", fromTypes(VARCHAR)); private static final ResolvedFunction COMBINE_HASH = FUNCTIONS.resolveFunction("combine_hash", fromTypes(BIGINT, BIGINT)); private static final ResolvedFunction HASH_CODE = createTestMetadataManager().resolveOperator(OperatorType.HASH_CODE, ImmutableList.of(BIGINT)); + private static final ResolvedFunction CONCAT = FUNCTIONS.resolveFunction("concat", fromTypes(VARCHAR, VARCHAR)); private static final WindowNode.Frame ROWS_FROM_CURRENT = new WindowNode.Frame( ROWS, @@ -427,6 +432,102 @@ public void testDistinctOverConstants() tableScan("orders", ImmutableMap.of("orderstatus", "orderstatus")))))); } + @Test + public void 
testSingleDistinct() + { + assertPlan("SELECT custkey, orderstatus, COUNT(DISTINCT orderkey) FROM orders GROUP BY custkey, orderstatus", + anyTree( + aggregation( + singleGroupingSet("custkey", "orderstatus"), + ImmutableMap.of("count", aggregationFunction("count", ImmutableList.of("orderkey"))), + aggregation( + singleGroupingSet("custkey", "orderstatus", "orderkey"), + ImmutableMap.of(), + Optional.empty(), + FINAL, + exchange(aggregation( + singleGroupingSet("custkey", "orderstatus", "orderkey"), + ImmutableMap.of(), + Optional.empty(), + PARTIAL, + tableScan( + "orders", + ImmutableMap.of("orderstatus", "orderstatus", "custkey", "custkey", "orderkey", "orderkey")))))))); + } + + @Test + public void testPreAggregateDistinct() + { + assertPlan("SELECT COUNT(DISTINCT orderkey), COUNT(DISTINCT custkey) FROM orders", + anyTree( + aggregation( + singleGroupingSet(), + ImmutableMap.of(Optional.of("count1"), aggregationFunction("count", false, ImmutableList.of(symbol("orderkey"))), + Optional.of("count2"), aggregationFunction("count", false, ImmutableList.of(symbol("custkey")))), + ImmutableList.of(), + ImmutableList.of("gid-filter-0", "gid-filter-1"), + Optional.empty(), + SINGLE, + project( + ImmutableMap.of( + "gid-filter-0", expression(new Comparison(EQUAL, new Reference(BIGINT, "groupId"), new Constant(BIGINT, 0L))), + "gid-filter-1", expression(new Comparison(EQUAL, new Reference(BIGINT, "groupId"), new Constant(BIGINT, 1L)))), + aggregation( + singleGroupingSet("custkey", "orderkey", "groupId"), + ImmutableMap.of(), + Optional.empty(), + FINAL, + exchange(aggregation( + singleGroupingSet("orderkey", "custkey", "groupId"), + ImmutableMap.of(), + Optional.empty(), + PARTIAL, + filter( + new Between(new Reference(BIGINT, "groupId"), new Constant(BIGINT, 0L), new Constant(BIGINT, 1L)), + groupId( + ImmutableList.of(ImmutableList.of("orderkey"), ImmutableList.of("custkey")), + "groupId", + tableScan( + "orders", + ImmutableMap.of("custkey", "custkey", "orderkey", 
"orderkey"))))))))))); + } + + @Test + public void testMultipleDistinctUsingMarkDistinct() + { + assertPlan("SELECT orderstatus, orderstatus || '1', orderstatus || '2', COUNT(DISTINCT orderkey), COUNT(DISTINCT custkey) FROM orders GROUP BY 1, 2, 3", + Session.builder(getPlanTester().getDefaultSession()) + .setSystemProperty(COST_ESTIMATION_WORKER_COUNT, "6") + .build(), + anyTree( + aggregation( + singleGroupingSet("orderstatus", "orderstatus1", "orderstatus2"), + ImmutableMap.of(Optional.of("count1"), aggregationFunction("count", false, ImmutableList.of(symbol("custkey"))), + Optional.of("count2"), aggregationFunction("count", false, ImmutableList.of(symbol("orderkey")))), + ImmutableList.of(), + ImmutableList.of("custkey_mask", "orderkey_mask"), + Optional.empty(), + SINGLE, + markDistinct( + "custkey_mask", + ImmutableList.of("orderstatus", "orderstatus1", "orderstatus2", "custkey"), + markDistinct( + "orderkey_mask", + ImmutableList.of("orderstatus", "orderstatus1", "orderstatus2", "orderkey"), + exchange( + project( + ImmutableMap.of( + "orderstatus1", expression(new Call(CONCAT, ImmutableList.of( + new Cast(new Reference(createVarcharType(1), "orderstatus"), VARCHAR), + new Constant(VARCHAR, utf8Slice("1"))))), + "orderstatus2", expression(new Call(CONCAT, ImmutableList.of( + new Cast(new Reference(createVarcharType(1), "orderstatus"), VARCHAR), + new Constant(VARCHAR, utf8Slice("2")))))), + tableScan( + "orders", + ImmutableMap.of("custkey", "custkey", "orderkey", "orderkey", "orderstatus", "orderstatus"))))))))); + } + @Test public void testInnerInequalityJoinNoEquiJoinConjuncts() { diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestDistinctAggregationStrategyChooser.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestDistinctAggregationStrategyChooser.java index 4b2af632574e..2a984d45f9ab 100644 --- 
a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestDistinctAggregationStrategyChooser.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestDistinctAggregationStrategyChooser.java @@ -22,29 +22,49 @@ import io.trino.cost.SymbolStatsEstimate; import io.trino.cost.TaskCountEstimator; import io.trino.execution.warnings.WarningCollector; +import io.trino.metadata.AbstractMockMetadata; +import io.trino.metadata.Metadata; +import io.trino.metadata.TableHandle; +import io.trino.metadata.TestingFunctionResolution; +import io.trino.security.AllowAllAccessControl; +import io.trino.spi.predicate.TupleDomain; import io.trino.sql.planner.PlanNodeIdAllocator; import io.trino.sql.planner.Symbol; import io.trino.sql.planner.SymbolAllocator; import io.trino.sql.planner.iterative.Lookup; -import io.trino.sql.planner.iterative.Rule; +import io.trino.sql.planner.iterative.Rule.Context; import io.trino.sql.planner.plan.AggregationNode; +import io.trino.sql.planner.plan.AggregationNode.Aggregation; import io.trino.sql.planner.plan.PlanNode; import io.trino.sql.planner.plan.PlanNodeId; -import io.trino.sql.planner.plan.ValuesNode; +import io.trino.sql.planner.plan.TableScanNode; +import io.trino.transaction.TestingTransactionManager; +import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.TestInstance; import org.junit.jupiter.api.parallel.Execution; import java.util.List; import java.util.Map; +import java.util.Optional; import java.util.function.Function; import static com.google.common.collect.ImmutableMap.toImmutableMap; import static io.trino.SessionTestUtils.TEST_SESSION; +import static io.trino.SystemSessionProperties.DISTINCT_AGGREGATIONS_STRATEGY; +import static io.trino.SystemSessionProperties.getTaskConcurrency; import static io.trino.spi.type.BigintType.BIGINT; +import static io.trino.sql.analyzer.TypeSignatureProvider.fromTypes; +import static 
io.trino.sql.planner.OptimizerConfig.DistinctAggregationsStrategy.MARK_DISTINCT; +import static io.trino.sql.planner.OptimizerConfig.DistinctAggregationsStrategy.PRE_AGGREGATE; +import static io.trino.sql.planner.OptimizerConfig.DistinctAggregationsStrategy.SINGLE_STEP; +import static io.trino.sql.planner.OptimizerConfig.DistinctAggregationsStrategy.SPLIT_TO_SUBQUERIES; import static io.trino.sql.planner.iterative.rule.DistinctAggregationStrategyChooser.createDistinctAggregationStrategyChooser; import static io.trino.sql.planner.plan.AggregationNode.singleAggregation; import static io.trino.sql.planner.plan.AggregationNode.singleGroupingSet; +import static io.trino.testing.TestingHandles.TEST_TABLE_HANDLE; +import static io.trino.testing.TestingSession.testSessionBuilder; +import static io.trino.testing.TransactionBuilder.transaction; import static org.assertj.core.api.Assertions.assertThat; import static org.junit.jupiter.api.TestInstance.Lifecycle.PER_CLASS; import static org.junit.jupiter.api.parallel.ExecutionMode.CONCURRENT; @@ -55,83 +75,162 @@ public class TestDistinctAggregationStrategyChooser { private static final int NODE_COUNT = 6; private static final TaskCountEstimator TASK_COUNT_ESTIMATOR = new TaskCountEstimator(() -> NODE_COUNT); + private static final TestingFunctionResolution functionResolution = new TestingFunctionResolution(); + private TestingTransactionManager transactionManager; + private Metadata metadata; + + @BeforeAll + public final void setUp() + { + this.transactionManager = new TestingTransactionManager(); + this.metadata = new AbstractMockMetadata() + { + @Override + public boolean allowSplittingReadIntoMultipleSubQueries(Session session, TableHandle tableHandle) + { + return true; + } + }; + } @Test public void testSingleStepPreferredForHighCardinalitySingleGroupByKey() { - DistinctAggregationStrategyChooser aggregationStrategyChooser = createDistinctAggregationStrategyChooser(TASK_COUNT_ESTIMATOR); + 
DistinctAggregationStrategyChooser aggregationStrategyChooser = createDistinctAggregationStrategyChooser(TASK_COUNT_ESTIMATOR, metadata); SymbolAllocator symbolAllocator = new SymbolAllocator(); Symbol groupingKey = symbolAllocator.newSymbol("groupingKey", BIGINT); - ValuesNode source = new ValuesNode(new PlanNodeId("source"), 1_000_000); - AggregationNode aggregationNode = singleAggregation( - new PlanNodeId("aggregation"), - source, - ImmutableMap.of(), - singleGroupingSet(ImmutableList.of(groupingKey))); - Rule.Context context = context( + PlanNode source = tableScan(); + AggregationNode aggregationNode = aggregationWithTwoDistinctAggregations(ImmutableList.of(groupingKey), source, symbolAllocator); + Context context = context( ImmutableMap.of(source, new PlanNodeStatsEstimate(1_000_000, ImmutableMap.of( groupingKey, SymbolStatsEstimate.builder().setDistinctValuesCount(1_000_000).build()))), symbolAllocator); - assertThat(aggregationStrategyChooser.shouldAddMarkDistinct(aggregationNode, context.getSession(), context.getStatsProvider())).isFalse(); + assertShouldUseSingleStep(aggregationStrategyChooser, aggregationNode, context.getSession(), context.getStatsProvider(), context.getLookup()); } @Test public void testSingleStepPreferredForHighCardinalityMultipleGroupByKeys() { - DistinctAggregationStrategyChooser aggregationStrategyChooser = createDistinctAggregationStrategyChooser(TASK_COUNT_ESTIMATOR); + DistinctAggregationStrategyChooser aggregationStrategyChooser = createDistinctAggregationStrategyChooser(TASK_COUNT_ESTIMATOR, metadata); SymbolAllocator symbolAllocator = new SymbolAllocator(); Symbol lowCardinalityGroupingKey = symbolAllocator.newSymbol("lowCardinalityGroupingKey", BIGINT); Symbol highCardinalityGroupingKey = symbolAllocator.newSymbol("highCardinalityGroupingKey", BIGINT); - ValuesNode source = new ValuesNode(new PlanNodeId("source"), 1_000_000); - AggregationNode aggregationNode = singleAggregation( - new PlanNodeId("aggregation"), - source, - 
ImmutableMap.of(), - singleGroupingSet(ImmutableList.of(lowCardinalityGroupingKey, highCardinalityGroupingKey))); - Rule.Context context = context( + PlanNode source = tableScan(); + AggregationNode aggregationNode = aggregationWithTwoDistinctAggregations(ImmutableList.of(lowCardinalityGroupingKey, highCardinalityGroupingKey), source, symbolAllocator); + Context context = context( ImmutableMap.of(source, new PlanNodeStatsEstimate(1_000_000, ImmutableMap.of( lowCardinalityGroupingKey, SymbolStatsEstimate.builder().setDistinctValuesCount(10).build(), highCardinalityGroupingKey, SymbolStatsEstimate.builder().setDistinctValuesCount(1_000_000).build()))), symbolAllocator); - assertThat(aggregationStrategyChooser.shouldAddMarkDistinct(aggregationNode, context.getSession(), context.getStatsProvider())).isFalse(); + assertShouldUseSingleStep(aggregationStrategyChooser, aggregationNode, context.getSession(), context.getStatsProvider(), context.getLookup()); } @Test public void testPreAggregatePreferredForLowCardinality2GroupByKeys() { - DistinctAggregationStrategyChooser aggregationStrategyChooser = createDistinctAggregationStrategyChooser(TASK_COUNT_ESTIMATOR); + DistinctAggregationStrategyChooser aggregationStrategyChooser = createDistinctAggregationStrategyChooser(TASK_COUNT_ESTIMATOR, metadata); SymbolAllocator symbolAllocator = new SymbolAllocator(); List groupingKeys = ImmutableList.of( symbolAllocator.newSymbol("key1", BIGINT), symbolAllocator.newSymbol("key2", BIGINT)); - ValuesNode source = new ValuesNode(new PlanNodeId("source"), 1_000_000); - AggregationNode aggregationNode = singleAggregation( - new PlanNodeId("aggregation"), - source, - ImmutableMap.of(), - singleGroupingSet(groupingKeys)); - Rule.Context context = context( + PlanNode source = tableScan(); + AggregationNode aggregationNode = aggregationWithTwoDistinctAggregations(groupingKeys, source, symbolAllocator); + Context context = context( ImmutableMap.of(source, new PlanNodeStatsEstimate( 1_000_000, 
groupingKeys.stream().collect(toImmutableMap( Function.identity(), _ -> SymbolStatsEstimate.builder().setDistinctValuesCount(10).build())))), new SymbolAllocator()); - assertThat(aggregationStrategyChooser.shouldUsePreAggregate(aggregationNode, context.getSession(), context.getStatsProvider())).isTrue(); - assertThat(aggregationStrategyChooser.shouldAddMarkDistinct(aggregationNode, context.getSession(), context.getStatsProvider())).isTrue(); + + assertThat(aggregationStrategyChooser.shouldUsePreAggregate(aggregationNode, context.getSession(), context.getStatsProvider(), context.getLookup())).isTrue(); + assertThat(aggregationStrategyChooser.shouldAddMarkDistinct(aggregationNode, context.getSession(), context.getStatsProvider(), context.getLookup())).isFalse(); + } + + @Test + public void testPreAggregatePreferredForUnknownStatisticsAnd2GroupByKeys() + { + DistinctAggregationStrategyChooser aggregationStrategyChooser = createDistinctAggregationStrategyChooser(TASK_COUNT_ESTIMATOR, metadata); + SymbolAllocator symbolAllocator = new SymbolAllocator(); + + List groupingKeys = ImmutableList.of( + symbolAllocator.newSymbol("key1", BIGINT), + symbolAllocator.newSymbol("key2", BIGINT)); + PlanNode source = tableScan(); + AggregationNode aggregationNode = aggregationWithTwoDistinctAggregations(groupingKeys, source, symbolAllocator); + Context context = context(ImmutableMap.of(), new SymbolAllocator()); + assertThat(aggregationStrategyChooser.shouldUsePreAggregate(aggregationNode, context.getSession(), context.getStatsProvider(), context.getLookup())).isTrue(); + assertThat(aggregationStrategyChooser.shouldAddMarkDistinct(aggregationNode, context.getSession(), context.getStatsProvider(), context.getLookup())).isFalse(); + } + + @Test + public void testPreAggregatePreferredForMediumCardinalitySingleGroupByKey() + { + DistinctAggregationStrategyChooser aggregationStrategyChooser = createDistinctAggregationStrategyChooser(TASK_COUNT_ESTIMATOR, metadata); + SymbolAllocator 
symbolAllocator = new SymbolAllocator(); + Symbol groupingKey = symbolAllocator.newSymbol("groupingKey", BIGINT); + + PlanNode source = tableScan(); + AggregationNode aggregationNode = aggregationWithTwoDistinctAggregations(ImmutableList.of(groupingKey), source, symbolAllocator); + Context context = context( + ImmutableMap.of(source, new PlanNodeStatsEstimate(NODE_COUNT * getTaskConcurrency(TEST_SESSION) * 10, ImmutableMap.of( + groupingKey, SymbolStatsEstimate.builder().setDistinctValuesCount(NODE_COUNT * getTaskConcurrency(TEST_SESSION) * 10).build()))), + symbolAllocator); + + assertThat(aggregationStrategyChooser.shouldUsePreAggregate(aggregationNode, context.getSession(), context.getStatsProvider(), context.getLookup())).isTrue(); + } + + @Test + public void testSingleStepPreferredForMediumCardinality3GroupByKeys() + { + DistinctAggregationStrategyChooser aggregationStrategyChooser = createDistinctAggregationStrategyChooser(TASK_COUNT_ESTIMATOR, metadata); + SymbolAllocator symbolAllocator = new SymbolAllocator(); + List groupingKeys = ImmutableList.of( + symbolAllocator.newSymbol("key1", BIGINT), + symbolAllocator.newSymbol("key2", BIGINT), + symbolAllocator.newSymbol("key3", BIGINT)); + + PlanNode source = tableScan(); + AggregationNode aggregationNode = aggregationWithTwoDistinctAggregations(groupingKeys, source, symbolAllocator); + Context context = context( + ImmutableMap.of(source, new PlanNodeStatsEstimate(NODE_COUNT * getTaskConcurrency(TEST_SESSION) * 10, + groupingKeys.stream().collect(toImmutableMap( + Function.identity(), + _ -> SymbolStatsEstimate.builder().setDistinctValuesCount(NODE_COUNT * getTaskConcurrency(TEST_SESSION) * 10).build())))), + symbolAllocator); + + assertShouldUseSingleStep(aggregationStrategyChooser, aggregationNode, context.getSession(), context.getStatsProvider(), context.getLookup()); + } + + @Test + public void testSplitToSubqueriesPreferredForGlobalAggregation() + { + DistinctAggregationStrategyChooser 
aggregationStrategyChooser = createDistinctAggregationStrategyChooser(TASK_COUNT_ESTIMATOR, metadata); + SymbolAllocator symbolAllocator = new SymbolAllocator(); + + PlanNode source = tableScan(); + AggregationNode aggregationNode = aggregationWithTwoDistinctAggregations(ImmutableList.of(), source, symbolAllocator); + assertThat((boolean) inTransaction(session -> { + Context context = context( + ImmutableMap.of(source, new PlanNodeStatsEstimate(1_000_000, ImmutableMap.of())), + session, + symbolAllocator); + return aggregationStrategyChooser.shouldSplitToSubqueries(aggregationNode, context.getSession(), context.getStatsProvider(), context.getLookup()); + })) + .isTrue(); } @Test public void testMarkDistinctPreferredForLowCardinality3GroupByKeys() { - DistinctAggregationStrategyChooser aggregationStrategyChooser = createDistinctAggregationStrategyChooser(TASK_COUNT_ESTIMATOR); + DistinctAggregationStrategyChooser aggregationStrategyChooser = createDistinctAggregationStrategyChooser(TASK_COUNT_ESTIMATOR, metadata); SymbolAllocator symbolAllocator = new SymbolAllocator(); List groupingKeys = ImmutableList.of( @@ -139,26 +238,162 @@ public void testMarkDistinctPreferredForLowCardinality3GroupByKeys() symbolAllocator.newSymbol("key2", BIGINT), symbolAllocator.newSymbol("key3", BIGINT)); - ValuesNode source = new ValuesNode(new PlanNodeId("source"), 1_000_000); - AggregationNode aggregationNode = singleAggregation( - new PlanNodeId("aggregation"), - source, - ImmutableMap.of(), - singleGroupingSet(groupingKeys)); - Rule.Context context = context( + PlanNode source = tableScan(); + AggregationNode aggregationNode = aggregationWithTwoDistinctAggregations(groupingKeys, source, symbolAllocator); + Context context = context( ImmutableMap.of(source, new PlanNodeStatsEstimate( 1_000_000, groupingKeys.stream().collect(toImmutableMap( Function.identity(), _ -> SymbolStatsEstimate.builder().setDistinctValuesCount(10).build())))), new SymbolAllocator()); - 
assertThat(aggregationStrategyChooser.shouldAddMarkDistinct(aggregationNode, context.getSession(), context.getStatsProvider())).isTrue(); + assertThat(aggregationStrategyChooser.shouldAddMarkDistinct(aggregationNode, context.getSession(), context.getStatsProvider(), context.getLookup())).isTrue(); + } + + @Test + public void testMarkDistinctPreferredForUnknownStatisticsAnd3GroupByKeys() + { + DistinctAggregationStrategyChooser aggregationStrategyChooser = createDistinctAggregationStrategyChooser(TASK_COUNT_ESTIMATOR, metadata); + SymbolAllocator symbolAllocator = new SymbolAllocator(); + + List groupingKeys = ImmutableList.of( + symbolAllocator.newSymbol("key1", BIGINT), + symbolAllocator.newSymbol("key2", BIGINT), + symbolAllocator.newSymbol("key3", BIGINT)); + PlanNode source = tableScan(); + AggregationNode aggregationNode = aggregationWithTwoDistinctAggregations(groupingKeys, source, symbolAllocator); + assertThat((boolean) inTransaction(session -> { + Context context = context(ImmutableMap.of(), session, symbolAllocator); + return aggregationStrategyChooser.shouldAddMarkDistinct(aggregationNode, context.getSession(), context.getStatsProvider(), context.getLookup()); + })) + .isTrue(); + } + + @Test + public void testChoiceForcedByTheSessionProperty() + { + int clusterThreadCount = NODE_COUNT * getTaskConcurrency(TEST_SESSION); + DistinctAggregationStrategyChooser aggregationStrategyChooser = createDistinctAggregationStrategyChooser(TASK_COUNT_ESTIMATOR, metadata); + SymbolAllocator symbolAllocator = new SymbolAllocator(); + Symbol groupingKey = symbolAllocator.newSymbol("groupingKey", BIGINT); + + TableScanNode source = new TableScanNode( + new PlanNodeId("source"), + TEST_TABLE_HANDLE, + ImmutableList.of(), + ImmutableMap.of(), + TupleDomain.all(), + Optional.empty(), + false, + Optional.empty()); + AggregationNode aggregationNode = aggregationWithTwoDistinctAggregations(ImmutableList.of(groupingKey), source, symbolAllocator); + + // big NDV, 
distinct_aggregations_strategy = mark_distinct + assertThat((boolean) inTransaction( + testSessionBuilder().setSystemProperty(DISTINCT_AGGREGATIONS_STRATEGY, MARK_DISTINCT.name()).build(), + session -> { + Context context = context( + ImmutableMap.of(source, new PlanNodeStatsEstimate(1000 * clusterThreadCount, ImmutableMap.of( + groupingKey, SymbolStatsEstimate.builder().setDistinctValuesCount(1000 * clusterThreadCount).build()))), + session, + symbolAllocator); + return aggregationStrategyChooser.shouldAddMarkDistinct(aggregationNode, context.getSession(), context.getStatsProvider(), context.getLookup()); + })) + .isTrue(); + + // big NDV, distinct_aggregations_strategy = pre-aggregate + assertThat((boolean) inTransaction( + testSessionBuilder().setSystemProperty(DISTINCT_AGGREGATIONS_STRATEGY, PRE_AGGREGATE.name()).build(), + session -> { + Context context = context( + ImmutableMap.of(source, new PlanNodeStatsEstimate(1000 * clusterThreadCount, ImmutableMap.of( + groupingKey, SymbolStatsEstimate.builder().setDistinctValuesCount(1000 * clusterThreadCount).build()))), + session, + symbolAllocator); + return aggregationStrategyChooser.shouldUsePreAggregate(aggregationNode, context.getSession(), context.getStatsProvider(), context.getLookup()); + })) + .isTrue(); + + // small NDV, distinct_aggregations_strategy = single_step + Context smallNdvContext = context( + ImmutableMap.of(source, new PlanNodeStatsEstimate(1000 * clusterThreadCount, ImmutableMap.of( + groupingKey, SymbolStatsEstimate.builder().setDistinctValuesCount(1000 * clusterThreadCount).build()))), + testSessionBuilder().setSystemProperty(DISTINCT_AGGREGATIONS_STRATEGY, SINGLE_STEP.name()).build(), + symbolAllocator); + assertShouldUseSingleStep(aggregationStrategyChooser, aggregationNode, smallNdvContext.getSession(), smallNdvContext.getStatsProvider(), smallNdvContext.getLookup()); + + // big NDV, distinct_aggregations_strategy = split_to_subqueries + assertThat((boolean) inTransaction( + 
testSessionBuilder().setSystemProperty(DISTINCT_AGGREGATIONS_STRATEGY, SPLIT_TO_SUBQUERIES.name()).build(), + session -> { + Context context = context( + ImmutableMap.of(source, new PlanNodeStatsEstimate(1000 * clusterThreadCount, ImmutableMap.of( + groupingKey, SymbolStatsEstimate.builder().setDistinctValuesCount(1000 * clusterThreadCount).build()))), + session, + symbolAllocator); + return aggregationStrategyChooser.shouldSplitToSubqueries(aggregationNode, context.getSession(), context.getStatsProvider(), context.getLookup()); + })) + .isTrue(); + } + + private T inTransaction(Function callback) + { + return inTransaction(TEST_SESSION, callback); + } + + private T inTransaction(Session session, Function callback) + { + return transaction(transactionManager, metadata, new AllowAllAccessControl()) + .execute(session, callback); + } + + private static PlanNode tableScan() + { + return new TableScanNode(new PlanNodeId("source"), TEST_TABLE_HANDLE, ImmutableList.of(), ImmutableMap.of(), TupleDomain.all(), Optional.empty(), false, Optional.empty()); + } + + private static AggregationNode aggregationWithTwoDistinctAggregations(List groupingKeys, PlanNode source, SymbolAllocator symbolAllocator) + { + return singleAggregation( + new PlanNodeId("aggregation"), + source, + twoDistinctAggregations(symbolAllocator), + singleGroupingSet(groupingKeys)); + } + + private static Map twoDistinctAggregations(SymbolAllocator symbolAllocator) + { + return ImmutableMap.of(symbolAllocator.newSymbol("output1", BIGINT), new Aggregation( + functionResolution.resolveFunction("sum", fromTypes(BIGINT)), + ImmutableList.of(symbolAllocator.newSymbol("input1", BIGINT).toSymbolReference()), + true, + Optional.empty(), + Optional.empty(), + Optional.empty()), + symbolAllocator.newSymbol("output2", BIGINT), new Aggregation( + functionResolution.resolveFunction("sum", fromTypes(BIGINT)), + ImmutableList.of(symbolAllocator.newSymbol("input2", BIGINT).toSymbolReference()), + true, + Optional.empty(), 
+ Optional.empty(), + Optional.empty())); + } + + private static void assertShouldUseSingleStep(DistinctAggregationStrategyChooser aggregationStrategyChooser, AggregationNode aggregationNode, Session session, StatsProvider statsProvider, Lookup lookup) + { + assertThat(aggregationStrategyChooser.shouldAddMarkDistinct(aggregationNode, session, statsProvider, lookup)).isFalse(); + assertThat(aggregationStrategyChooser.shouldUsePreAggregate(aggregationNode, session, statsProvider, lookup)).isFalse(); + } + + private static Context context(Map stats, SymbolAllocator symbolAllocator) + { + return context(stats, TEST_SESSION, symbolAllocator); } - private static Rule.Context context(Map stats, final SymbolAllocator symbolAllocator) + private static Context context(Map stats, Session session, SymbolAllocator symbolAllocator) { PlanNodeIdAllocator planNodeIdAllocator = new PlanNodeIdAllocator(); - return new Rule.Context() + return new Context() { @Override public Lookup getLookup() @@ -181,13 +416,13 @@ public SymbolAllocator getSymbolAllocator() @Override public Session getSession() { - return TEST_SESSION; + return session; } @Override public StatsProvider getStatsProvider() { - return stats::get; + return node -> stats.getOrDefault(node, PlanNodeStatsEstimate.unknown()); } @Override diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestMultipleDistinctAggregationToMarkDistinct.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestMultipleDistinctAggregationToMarkDistinct.java index 3c6cba9ce277..5187354b2ee0 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestMultipleDistinctAggregationToMarkDistinct.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestMultipleDistinctAggregationToMarkDistinct.java @@ -15,26 +15,21 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; -import io.trino.cost.PlanNodeStatsEstimate; 
-import io.trino.cost.SymbolStatsEstimate; import io.trino.cost.TaskCountEstimator; +import io.trino.metadata.Metadata; import io.trino.sql.ir.Comparison; import io.trino.sql.ir.Constant; import io.trino.sql.ir.Reference; import io.trino.sql.planner.Symbol; -import io.trino.sql.planner.assertions.PlanMatchPattern; import io.trino.sql.planner.iterative.rule.test.BaseRuleTest; import io.trino.sql.planner.iterative.rule.test.PlanBuilder; import io.trino.sql.planner.plan.Assignments; -import io.trino.sql.planner.plan.PlanNode; -import io.trino.sql.planner.plan.PlanNodeId; import org.junit.jupiter.api.Test; import java.util.Optional; -import java.util.function.Function; import static io.trino.SystemSessionProperties.DISTINCT_AGGREGATIONS_STRATEGY; -import static io.trino.SystemSessionProperties.TASK_CONCURRENCY; +import static io.trino.metadata.MetadataManager.createTestMetadataManager; import static io.trino.spi.type.BigintType.BIGINT; import static io.trino.spi.type.BooleanType.BOOLEAN; import static io.trino.spi.type.IntegerType.INTEGER; @@ -43,7 +38,6 @@ import static io.trino.sql.planner.assertions.PlanMatchPattern.aggregationFunction; import static io.trino.sql.planner.assertions.PlanMatchPattern.globalAggregation; import static io.trino.sql.planner.assertions.PlanMatchPattern.markDistinct; -import static io.trino.sql.planner.assertions.PlanMatchPattern.singleGroupingSet; import static io.trino.sql.planner.assertions.PlanMatchPattern.values; import static io.trino.sql.planner.plan.AggregationNode.Step.SINGLE; import static io.trino.type.UnknownType.UNKNOWN; @@ -53,11 +47,13 @@ public class TestMultipleDistinctAggregationToMarkDistinct { private static final int NODES_COUNT = 4; private static final TaskCountEstimator TASK_COUNT_ESTIMATOR = new TaskCountEstimator(() -> NODES_COUNT); + private static final Metadata METADATA = createTestMetadataManager(); @Test public void testNoDistinct() { tester().assertThat(new SingleDistinctAggregationToGroupBy()) + 
.setSystemProperty(DISTINCT_AGGREGATIONS_STRATEGY, "mark_distinct") .on(p -> p.aggregation(builder -> builder .globalGrouping() .addAggregation(p.symbol("output1", BIGINT), PlanBuilder.aggregation("count", ImmutableList.of(new Reference(BIGINT, "input1"))), ImmutableList.of(BIGINT)) @@ -72,7 +68,8 @@ public void testNoDistinct() @Test public void testSingleDistinct() { - tester().assertThat(new MultipleDistinctAggregationToMarkDistinct(TASK_COUNT_ESTIMATOR)) + tester().assertThat(new MultipleDistinctAggregationToMarkDistinct(TASK_COUNT_ESTIMATOR, METADATA)) + .setSystemProperty(DISTINCT_AGGREGATIONS_STRATEGY, "mark_distinct") .on(p -> p.aggregation(builder -> builder .globalGrouping() .addAggregation(p.symbol("output1", BIGINT), PlanBuilder.aggregation("count", true, ImmutableList.of(new Reference(BIGINT, "input1"))), ImmutableList.of(BIGINT)) @@ -86,7 +83,8 @@ public void testSingleDistinct() @Test public void testMultipleAggregations() { - tester().assertThat(new MultipleDistinctAggregationToMarkDistinct(TASK_COUNT_ESTIMATOR)) + tester().assertThat(new MultipleDistinctAggregationToMarkDistinct(TASK_COUNT_ESTIMATOR, METADATA)) + .setSystemProperty(DISTINCT_AGGREGATIONS_STRATEGY, "mark_distinct") .on(p -> p.aggregation(builder -> builder .globalGrouping() .addAggregation(p.symbol("output1", BIGINT), PlanBuilder.aggregation("count", true, ImmutableList.of(new Reference(BIGINT, "input"))), ImmutableList.of(BIGINT)) @@ -99,7 +97,8 @@ public void testMultipleAggregations() @Test public void testDistinctWithFilter() { - tester().assertThat(new MultipleDistinctAggregationToMarkDistinct(TASK_COUNT_ESTIMATOR)) + tester().assertThat(new MultipleDistinctAggregationToMarkDistinct(TASK_COUNT_ESTIMATOR, METADATA)) + .setSystemProperty(DISTINCT_AGGREGATIONS_STRATEGY, "mark_distinct") .on(p -> p.aggregation(builder -> builder .globalGrouping() .addAggregation( @@ -121,7 +120,8 @@ public void testDistinctWithFilter() p.symbol("input2", BIGINT)))))) .doesNotFire(); - 
tester().assertThat(new MultipleDistinctAggregationToMarkDistinct(TASK_COUNT_ESTIMATOR)) + tester().assertThat(new MultipleDistinctAggregationToMarkDistinct(TASK_COUNT_ESTIMATOR, METADATA)) + .setSystemProperty(DISTINCT_AGGREGATIONS_STRATEGY, "mark_distinct") .on(p -> p.aggregation(builder -> builder .globalGrouping() .addAggregation(p.symbol("output1", BIGINT), PlanBuilder.aggregation("count", true, ImmutableList.of(new Reference(BIGINT, "input1")), new Symbol(UNKNOWN, "filter1")), ImmutableList.of(BIGINT)) @@ -143,7 +143,8 @@ public void testDistinctWithFilter() @Test public void testGlobalAggregation() { - tester().assertThat(new MultipleDistinctAggregationToMarkDistinct(TASK_COUNT_ESTIMATOR)) + tester().assertThat(new MultipleDistinctAggregationToMarkDistinct(TASK_COUNT_ESTIMATOR, METADATA)) + .setSystemProperty(DISTINCT_AGGREGATIONS_STRATEGY, "mark_distinct") .on(p -> p.aggregation(builder -> builder .globalGrouping() .addAggregation(p.symbol("output1", BIGINT), PlanBuilder.aggregation("count", true, ImmutableList.of(new Reference(BIGINT, "input1"))), ImmutableList.of(BIGINT)) @@ -167,117 +168,4 @@ public void testGlobalAggregation() ImmutableList.of("input1"), values(ImmutableMap.of("input1", 0, "input2", 1)))))); } - - @Test - public void testAggregationNDV() - { - PlanNodeId aggregationSourceId = new PlanNodeId("aggregationSourceId"); - Symbol key = new Symbol(BIGINT, "key"); - Function plan = p -> p.aggregation(builder -> builder - .singleGroupingSet(p.symbol(key.name(), BIGINT)) - .addAggregation(p.symbol("output1", BIGINT), PlanBuilder.aggregation("count", true, ImmutableList.of(new Reference(BIGINT, "input"))), ImmutableList.of(BIGINT)) - .addAggregation(p.symbol("output2", BIGINT), PlanBuilder.aggregation("sum", ImmutableList.of(new Reference(BIGINT, "input"))), ImmutableList.of(BIGINT)) - .source( - p.values(aggregationSourceId, p.symbol("input", BIGINT), p.symbol(key.name(), BIGINT)))); - PlanMatchPattern expectedMarkDistinct = aggregation( - 
singleGroupingSet("key"), - ImmutableMap.of( - Optional.of("output1"), aggregationFunction("count", ImmutableList.of("input")), - Optional.of("output2"), aggregationFunction("sum", ImmutableList.of("input"))), - ImmutableList.of(), - ImmutableList.of("mark_input"), - Optional.empty(), - SINGLE, - markDistinct( - "mark_input", - ImmutableList.of("input", "key"), - values(ImmutableMap.of("input", 0, "key", 1)))); - - int clusterThreadCount = NODES_COUNT * tester().getSession().getSystemProperty(TASK_CONCURRENCY, Integer.class); - - // small NDV - tester().assertThat(new MultipleDistinctAggregationToMarkDistinct(TASK_COUNT_ESTIMATOR)) - .overrideStats(aggregationSourceId.toString(), PlanNodeStatsEstimate.builder() - .addSymbolStatistics( - key, SymbolStatsEstimate.builder().setDistinctValuesCount(2 * clusterThreadCount).build()) - .build()) - .on(plan) - .matches(expectedMarkDistinct); - - // unknown estimate - tester().assertThat(new MultipleDistinctAggregationToMarkDistinct(TASK_COUNT_ESTIMATOR)) - .overrideStats(aggregationSourceId.toString(), PlanNodeStatsEstimate.builder() - .addSymbolStatistics( - key, SymbolStatsEstimate.builder().setDistinctValuesCount(Double.NaN).build()) - .build()) - .on(plan) - .matches(expectedMarkDistinct); - - // medium NDV, distinct_aggregations_strategy mark_distinct - tester().assertThat(new MultipleDistinctAggregationToMarkDistinct(TASK_COUNT_ESTIMATOR)) - .overrideStats(aggregationSourceId.toString(), PlanNodeStatsEstimate.builder() - .addSymbolStatistics( - key, SymbolStatsEstimate.builder().setDistinctValuesCount(50 * clusterThreadCount).build()) - .build()) - .setSystemProperty(DISTINCT_AGGREGATIONS_STRATEGY, "mark_distinct") - .on(plan) - .matches(expectedMarkDistinct); - - // medium NDV - tester().assertThat(new MultipleDistinctAggregationToMarkDistinct(TASK_COUNT_ESTIMATOR)) - .overrideStats(aggregationSourceId.toString(), PlanNodeStatsEstimate.builder() - .addSymbolStatistics( - key, 
SymbolStatsEstimate.builder().setDistinctValuesCount(50 * clusterThreadCount).build()) - .build()) - .on(plan) - .doesNotFire(); - - // medium NDV, distinct_aggregations_strategy pre_aggregate, but the plan has multiple distinct aggregations - tester().assertThat(new MultipleDistinctAggregationToMarkDistinct(TASK_COUNT_ESTIMATOR)) - .setSystemProperty(DISTINCT_AGGREGATIONS_STRATEGY, "pre_aggregate") - .overrideStats(aggregationSourceId.toString(), PlanNodeStatsEstimate.builder() - .addSymbolStatistics(key, SymbolStatsEstimate.builder().setDistinctValuesCount(50 * clusterThreadCount).build()).build()) - .on(p -> p.aggregation(builder -> builder - .singleGroupingSet(p.symbol(key.name(), BIGINT)) - .addAggregation(p.symbol("output1", BIGINT), PlanBuilder.aggregation("count", true, ImmutableList.of(new Reference(BIGINT, "input1"))), ImmutableList.of(BIGINT)) - .addAggregation(p.symbol("output2", BIGINT), PlanBuilder.aggregation("count", true, ImmutableList.of(new Reference(BIGINT, "input2"))), ImmutableList.of(BIGINT)) - .source( - p.values(aggregationSourceId, p.symbol("input1", BIGINT), p.symbol("input2", BIGINT), p.symbol("key", BIGINT))))) - .doesNotFire(); - - // big NDV - tester().assertThat(new MultipleDistinctAggregationToMarkDistinct(TASK_COUNT_ESTIMATOR)) - .overrideStats(aggregationSourceId.toString(), PlanNodeStatsEstimate.builder() - .addSymbolStatistics(key, SymbolStatsEstimate.builder().setDistinctValuesCount(1000 * clusterThreadCount).build()).build()).on(plan) - .doesNotFire(); - - // big NDV, distinct_aggregations_strategy = mark_distinct - tester().assertThat(new MultipleDistinctAggregationToMarkDistinct(TASK_COUNT_ESTIMATOR)) - .setSystemProperty(DISTINCT_AGGREGATIONS_STRATEGY, "mark_distinct") - .overrideStats(aggregationSourceId.toString(), PlanNodeStatsEstimate.builder().setOutputRowCount(1000 * clusterThreadCount).build()) - .on(plan) - .matches(expectedMarkDistinct); - // small NDV, distinct_aggregations_strategy != mark_distinct - 
tester().assertThat(new MultipleDistinctAggregationToMarkDistinct(TASK_COUNT_ESTIMATOR)) - .setSystemProperty(DISTINCT_AGGREGATIONS_STRATEGY, "single_step") - .overrideStats(aggregationSourceId.toString(), PlanNodeStatsEstimate.builder().setOutputRowCount(2 * clusterThreadCount).build()) - .on(plan) - .doesNotFire(); - - // big NDV but on multiple grouping keys - Symbol key1 = new Symbol(BIGINT, "key1"); - Symbol key2 = new Symbol(BIGINT, "key2"); - tester().assertThat(new MultipleDistinctAggregationToMarkDistinct(TASK_COUNT_ESTIMATOR)) - .overrideStats(aggregationSourceId.toString(), PlanNodeStatsEstimate.builder().setOutputRowCount(1000 * clusterThreadCount) - .addSymbolStatistics(key1, SymbolStatsEstimate.builder().setDistinctValuesCount(1000 * clusterThreadCount).build()) - .addSymbolStatistics(key2, SymbolStatsEstimate.builder().setDistinctValuesCount(10).build()) - .build()) - .on(p -> p.aggregation(builder -> builder - .singleGroupingSet(p.symbol(key1.name(), BIGINT), p.symbol(key2.name(), BIGINT)) - .addAggregation(p.symbol("output1", BIGINT), PlanBuilder.aggregation("count", true, ImmutableList.of(new Reference(BIGINT, "input"))), ImmutableList.of(BIGINT)) - .addAggregation(p.symbol("output2", BIGINT), PlanBuilder.aggregation("sum", ImmutableList.of(new Reference(BIGINT, "input"))), ImmutableList.of(BIGINT)) - .source( - p.values(aggregationSourceId, p.symbol("input", BIGINT), p.symbol("key1", BIGINT), p.symbol("key2", BIGINT))))) - .doesNotFire(); - } } diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestMultipleDistinctAggregationsToSubqueries.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestMultipleDistinctAggregationsToSubqueries.java new file mode 100644 index 000000000000..02939d015353 --- /dev/null +++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestMultipleDistinctAggregationsToSubqueries.java @@ -0,0 +1,1102 @@ +/* + * Licensed under the Apache License, Version 2.0 
(the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.sql.planner.iterative.rule; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableListMultimap; +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableSet; +import io.trino.Session; +import io.trino.connector.MockConnectorColumnHandle; +import io.trino.connector.MockConnectorFactory; +import io.trino.connector.MockConnectorTableHandle; +import io.trino.cost.PlanNodeStatsEstimate; +import io.trino.cost.SymbolStatsEstimate; +import io.trino.cost.TaskCountEstimator; +import io.trino.metadata.ResolvedFunction; +import io.trino.metadata.TableHandle; +import io.trino.metadata.TestingFunctionResolution; +import io.trino.spi.connector.ColumnHandle; +import io.trino.spi.connector.ColumnMetadata; +import io.trino.spi.connector.SchemaTableName; +import io.trino.spi.function.OperatorType; +import io.trino.spi.predicate.TupleDomain; +import io.trino.sql.ir.Call; +import io.trino.sql.ir.Cast; +import io.trino.sql.ir.Comparison; +import io.trino.sql.ir.Constant; +import io.trino.sql.ir.IsNull; +import io.trino.sql.ir.Reference; +import io.trino.sql.planner.Symbol; +import io.trino.sql.planner.assertions.PlanMatchPattern; +import io.trino.sql.planner.assertions.SetOperationOutputMatcher; +import io.trino.sql.planner.iterative.rule.test.BaseRuleTest; +import io.trino.sql.planner.iterative.rule.test.PlanBuilder; +import io.trino.sql.planner.iterative.rule.test.RuleTester; +import 
io.trino.sql.planner.plan.Assignments; +import io.trino.sql.planner.plan.PlanNodeId; +import io.trino.testing.PlanTester; +import io.trino.testing.TestingTransactionHandle; +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.Test; + +import java.util.List; +import java.util.Optional; +import java.util.stream.Stream; + +import static com.google.common.collect.ImmutableList.toImmutableList; +import static io.airlift.testing.Closeables.closeAllRuntimeException; +import static io.trino.SystemSessionProperties.DISTINCT_AGGREGATIONS_STRATEGY; +import static io.trino.spi.type.BigintType.BIGINT; +import static io.trino.spi.type.DateType.DATE; +import static io.trino.spi.type.VarcharType.VARCHAR; +import static io.trino.sql.ir.Booleans.TRUE; +import static io.trino.sql.ir.Comparison.Operator.GREATER_THAN; +import static io.trino.sql.ir.IrExpressions.not; +import static io.trino.sql.planner.assertions.PlanMatchPattern.aggregation; +import static io.trino.sql.planner.assertions.PlanMatchPattern.aggregationFunction; +import static io.trino.sql.planner.assertions.PlanMatchPattern.filter; +import static io.trino.sql.planner.assertions.PlanMatchPattern.join; +import static io.trino.sql.planner.assertions.PlanMatchPattern.project; +import static io.trino.sql.planner.assertions.PlanMatchPattern.singleGroupingSet; +import static io.trino.sql.planner.assertions.PlanMatchPattern.symbol; +import static io.trino.sql.planner.assertions.PlanMatchPattern.tableScan; +import static io.trino.sql.planner.assertions.PlanMatchPattern.union; +import static io.trino.sql.planner.plan.AggregationNode.Step.SINGLE; +import static io.trino.sql.planner.plan.AggregationNode.groupingSets; +import static io.trino.sql.planner.plan.JoinType.INNER; +import static io.trino.testing.TestingSession.testSessionBuilder; + +public class TestMultipleDistinctAggregationsToSubqueries + extends BaseRuleTest +{ + private static final String MOCK_CATALOG = "mock_catalog"; + private static final String 
TEST_SCHEMA = "test_schema"; + private static final String TEST_TABLE = "test_table"; + + private static final Session MOCK_SESSION = testSessionBuilder().setCatalog(MOCK_CATALOG).setSchema(TEST_SCHEMA).build(); + + private static final String COLUMN_1 = "orderkey"; + private static final ColumnHandle COLUMN_1_HANDLE = new MockConnectorColumnHandle(COLUMN_1, BIGINT); + private static final String COLUMN_2 = "partkey"; + private static final ColumnHandle COLUMN_2_HANDLE = new MockConnectorColumnHandle(COLUMN_2, BIGINT); + private static final String COLUMN_3 = "linenumber"; + private static final ColumnHandle COLUMN_3_HANDLE = new MockConnectorColumnHandle(COLUMN_3, BIGINT); + + private static final String COLUMN_4 = "shipdate"; + private static final ColumnHandle COLUMN_4_HANDLE = new MockConnectorColumnHandle(COLUMN_4, DATE); + private static final String GROUPING_KEY_COLUMN = "suppkey"; + private static final ColumnHandle GROUPING_KEY_COLUMN_HANDLE = new MockConnectorColumnHandle(GROUPING_KEY_COLUMN, BIGINT); + private static final String GROUPING_KEY2_COLUMN = "comment"; + private static final ColumnHandle GROUPING_KEY2_COLUMN_HANDLE = new MockConnectorColumnHandle(GROUPING_KEY2_COLUMN, VARCHAR); + + private static final SchemaTableName TABLE_SCHEMA = new SchemaTableName(TEST_SCHEMA, TEST_TABLE); + + private static final List ALL_COLUMNS = Stream.of(COLUMN_1_HANDLE, COLUMN_2_HANDLE, COLUMN_3_HANDLE, COLUMN_4_HANDLE, GROUPING_KEY_COLUMN_HANDLE, GROUPING_KEY2_COLUMN_HANDLE) + .map(columnHandle -> (MockConnectorColumnHandle) columnHandle) + .map(column -> new ColumnMetadata(column.getName(), column.getType())) + .collect(toImmutableList()); + + private static final TestingFunctionResolution FUNCTIONS = new TestingFunctionResolution(); + private static final ResolvedFunction ADD_BIGINT = FUNCTIONS.resolveOperator(OperatorType.ADD, ImmutableList.of(BIGINT, BIGINT)); + + private RuleTester ruleTester = tester(true); + + @AfterAll + public final void tearDownTester() + 
{ + closeAllRuntimeException(ruleTester); + ruleTester = null; + } + + @Test + public void testDoesNotFire() + { + // no distinct aggregation + ruleTester.assertThat(newMultipleDistinctAggregationsToSubqueries(ruleTester)) + .setSystemProperty(DISTINCT_AGGREGATIONS_STRATEGY, "split_to_subqueries") + .on(p -> { + Symbol inputSymbol = p.symbol("inputSymbol", BIGINT); + return p.aggregation(builder -> builder + .singleGroupingSet(inputSymbol) + .source( + p.tableScan( + testTableHandle(ruleTester), + ImmutableList.of(inputSymbol), + ImmutableMap.of(inputSymbol, COLUMN_1_HANDLE)))); + }) + .doesNotFire(); + + // single distinct + ruleTester.assertThat(newMultipleDistinctAggregationsToSubqueries(ruleTester)) + .setSystemProperty(DISTINCT_AGGREGATIONS_STRATEGY, "split_to_subqueries") + .on(p -> { + Symbol inputSymbol = p.symbol("inputSymbol", BIGINT); + return p.aggregation(builder -> builder + .globalGrouping() + .addAggregation(p.symbol("output1", BIGINT), PlanBuilder.aggregation("count", true, ImmutableList.of(new Reference(BIGINT, "inputSymbol"))), ImmutableList.of(BIGINT)) + .source( + p.tableScan( + testTableHandle(ruleTester), + ImmutableList.of(inputSymbol), + ImmutableMap.of(inputSymbol, COLUMN_1_HANDLE)))); + }) + .doesNotFire(); + + // two distinct on the same input + ruleTester.assertThat(newMultipleDistinctAggregationsToSubqueries(ruleTester)) + .setSystemProperty(DISTINCT_AGGREGATIONS_STRATEGY, "split_to_subqueries") + .on(p -> { + Symbol input1Symbol = p.symbol("input1Symbol", BIGINT); + return p.aggregation(builder -> builder + .globalGrouping() + .addAggregation(p.symbol("output1", BIGINT), PlanBuilder.aggregation("count", true, ImmutableList.of(new Reference(BIGINT, "input1Symbol"))), ImmutableList.of(BIGINT)) + .addAggregation(p.symbol("output2", BIGINT), PlanBuilder.aggregation("sum", true, ImmutableList.of(new Reference(BIGINT, "input1Symbol"))), ImmutableList.of(BIGINT)) + .source( + p.tableScan( + testTableHandle(ruleTester), + 
ImmutableList.of(input1Symbol), + ImmutableMap.of( + input1Symbol, COLUMN_1_HANDLE)))); + }) + .doesNotFire(); + + // hash symbol + ruleTester.assertThat(newMultipleDistinctAggregationsToSubqueries(ruleTester)) + .setSystemProperty(DISTINCT_AGGREGATIONS_STRATEGY, "split_to_subqueries") + .on(p -> { + Symbol input1Symbol = p.symbol("input1Symbol", BIGINT); + Symbol input2Symbol = p.symbol("input2Symbol", BIGINT); + return p.aggregation(builder -> builder + .globalGrouping() + .addAggregation(p.symbol("output1", BIGINT), PlanBuilder.aggregation("count", true, ImmutableList.of(new Reference(BIGINT, "input1Symbol"))), ImmutableList.of(BIGINT)) + .addAggregation(p.symbol("output2", BIGINT), PlanBuilder.aggregation("sum", true, ImmutableList.of(new Reference(BIGINT, "input2Symbol"))), ImmutableList.of(BIGINT)) + .hashSymbol(p.symbol("hashSymbol", BIGINT)) + .source( + p.tableScan( + testTableHandle(ruleTester), + ImmutableList.of(input1Symbol, input2Symbol), + ImmutableMap.of( + input1Symbol, COLUMN_1_HANDLE, + input2Symbol, COLUMN_2_HANDLE)))); + }) + .doesNotFire(); + + // non-distinct + ruleTester.assertThat(newMultipleDistinctAggregationsToSubqueries(ruleTester)) + .setSystemProperty(DISTINCT_AGGREGATIONS_STRATEGY, "split_to_subqueries") + .on(p -> { + Symbol input1Symbol = p.symbol("input1Symbol", BIGINT); + Symbol input2Symbol = p.symbol("input2Symbol", BIGINT); + return p.aggregation(builder -> builder + .globalGrouping() + .addAggregation(p.symbol("output1", BIGINT), PlanBuilder.aggregation("count", true, ImmutableList.of(new Reference(BIGINT, "input1Symbol"))), ImmutableList.of(BIGINT)) + .addAggregation(p.symbol("output2", BIGINT), PlanBuilder.aggregation("sum", true, ImmutableList.of(new Reference(BIGINT, "input2Symbol"))), ImmutableList.of(BIGINT)) + .addAggregation(p.symbol("output3", BIGINT), PlanBuilder.aggregation("count", ImmutableList.of(new Reference(BIGINT, "input1Symbol"))), ImmutableList.of(BIGINT)) + .source( + p.tableScan( + 
testTableHandle(ruleTester), + ImmutableList.of(input1Symbol, input2Symbol), + ImmutableMap.of( + input1Symbol, COLUMN_1_HANDLE, + input2Symbol, COLUMN_2_HANDLE)))); + }) + .doesNotFire(); + + // groupingSetCount > 1 + ruleTester.assertThat(newMultipleDistinctAggregationsToSubqueries(ruleTester)) + .setSystemProperty(DISTINCT_AGGREGATIONS_STRATEGY, "split_to_subqueries") + .on(p -> { + Symbol input1Symbol = p.symbol("input1Symbol", BIGINT); + Symbol input2Symbol = p.symbol("input2Symbol", BIGINT); + return p.aggregation(builder -> builder + .groupingSets(groupingSets(ImmutableList.of(), 2, ImmutableSet.of(0, 1))) + .addAggregation(p.symbol("output1", BIGINT), PlanBuilder.aggregation("count", true, ImmutableList.of(new Reference(BIGINT, "input1Symbol"))), ImmutableList.of(BIGINT)) + .addAggregation(p.symbol("output2", BIGINT), PlanBuilder.aggregation("sum", true, ImmutableList.of(new Reference(BIGINT, "input2Symbol"))), ImmutableList.of(BIGINT)) + .source( + p.tableScan( + testTableHandle(ruleTester), + ImmutableList.of(input1Symbol, input2Symbol), + ImmutableMap.of( + input1Symbol, COLUMN_1_HANDLE, + input2Symbol, COLUMN_2_HANDLE)))); + }) + .doesNotFire(); + + // complex subquery (join) + ruleTester.assertThat(newMultipleDistinctAggregationsToSubqueries(ruleTester)) + .setSystemProperty(DISTINCT_AGGREGATIONS_STRATEGY, "split_to_subqueries") + .on(p -> { + Symbol input1Symbol = p.symbol("input1Symbol", BIGINT); + Symbol input2Symbol = p.symbol("input2Symbol", BIGINT); + return p.aggregation(builder -> builder + .globalGrouping() + .addAggregation(p.symbol("output1", BIGINT), PlanBuilder.aggregation("count", true, ImmutableList.of(new Reference(BIGINT, "input1Symbol"))), ImmutableList.of(BIGINT)) + .addAggregation(p.symbol("output2", BIGINT), PlanBuilder.aggregation("sum", true, ImmutableList.of(new Reference(BIGINT, "input2Symbol"))), ImmutableList.of(BIGINT)) + .source( + p.join( + INNER, + p.tableScan( + testTableHandle(ruleTester), + ImmutableList.of(), + 
ImmutableMap.of()), + p.tableScan( + testTableHandle(ruleTester), + ImmutableList.of(input1Symbol, input2Symbol), + ImmutableMap.of( + input1Symbol, COLUMN_1_HANDLE, + input2Symbol, COLUMN_2_HANDLE))))); + }) + .doesNotFire(); + + // complex subquery (filter on top of join to test recursion) + ruleTester.assertThat(newMultipleDistinctAggregationsToSubqueries(ruleTester)) + .setSystemProperty(DISTINCT_AGGREGATIONS_STRATEGY, "split_to_subqueries") + .on(p -> { + Symbol input1Symbol = p.symbol("input1Symbol", BIGINT); + Symbol input2Symbol = p.symbol("input2Symbol", BIGINT); + return p.aggregation(builder -> builder + .globalGrouping() + .addAggregation(p.symbol("output1", BIGINT), PlanBuilder.aggregation("count", true, ImmutableList.of(new Reference(BIGINT, "input1Symbol"))), ImmutableList.of(BIGINT)) + .addAggregation(p.symbol("output2", BIGINT), PlanBuilder.aggregation("sum", true, ImmutableList.of(new Reference(BIGINT, "input2Symbol"))), ImmutableList.of(BIGINT)) + .source( + p.filter( + TRUE, + p.join( + INNER, + p.tableScan( + testTableHandle(ruleTester), + ImmutableList.of(), + ImmutableMap.of()), + p.tableScan( + testTableHandle(ruleTester), + ImmutableList.of(input1Symbol, input2Symbol), + ImmutableMap.of( + input1Symbol, COLUMN_1_HANDLE, + input2Symbol, COLUMN_2_HANDLE)))))); + }) + .doesNotFire(); + + // connector does not support efficient single column reads + RuleTester ruleTesterNotObjectStore = tester(false); + + ruleTesterNotObjectStore.assertThat(newMultipleDistinctAggregationsToSubqueries(ruleTesterNotObjectStore)) + .setSystemProperty(DISTINCT_AGGREGATIONS_STRATEGY, "split_to_subqueries") + .on(p -> { + Symbol input1Symbol = p.symbol("input1Symbol", BIGINT); + Symbol input2Symbol = p.symbol("input2Symbol", BIGINT); + return p.aggregation(builder -> builder + .globalGrouping() + .addAggregation(p.symbol("output1", BIGINT), PlanBuilder.aggregation("count", true, ImmutableList.of(new Reference(BIGINT, "input1Symbol"))), ImmutableList.of(BIGINT)) + 
.addAggregation(p.symbol("output2", BIGINT), PlanBuilder.aggregation("sum", true, ImmutableList.of(new Reference(BIGINT, "input2Symbol"))), ImmutableList.of(BIGINT)) + .source( + p.tableScan( + testTableHandle(ruleTesterNotObjectStore), + ImmutableList.of(input1Symbol, input2Symbol), + ImmutableMap.of( + input1Symbol, COLUMN_1_HANDLE, + input2Symbol, COLUMN_2_HANDLE)))); + }) + .doesNotFire(); + + // rule not enabled + ruleTester.assertThat(newMultipleDistinctAggregationsToSubqueries(ruleTester)) + .setSystemProperty(DISTINCT_AGGREGATIONS_STRATEGY, "single_step") + .on(p -> { + Symbol input1Symbol = p.symbol("input1Symbol", BIGINT); + Symbol input2Symbol = p.symbol("input2Symbol", BIGINT); + return p.aggregation(builder -> builder + .globalGrouping() + .addAggregation(p.symbol("output1", BIGINT), PlanBuilder.aggregation("count", true, ImmutableList.of(new Reference(BIGINT, "input1Symbol"))), ImmutableList.of(BIGINT)) + .addAggregation(p.symbol("output2", BIGINT), PlanBuilder.aggregation("sum", true, ImmutableList.of(new Reference(BIGINT, "input2Symbol"))), ImmutableList.of(BIGINT)) + .source( + p.tableScan( + testTableHandle(ruleTester), + ImmutableList.of(input1Symbol, input2Symbol), + ImmutableMap.of( + input1Symbol, COLUMN_1_HANDLE, + input2Symbol, COLUMN_2_HANDLE)))); + }) + .doesNotFire(); + + // automatic but single_step is preferred + String aggregationSourceId = "aggregationSourceId"; + ruleTester.assertThat(newMultipleDistinctAggregationsToSubqueries(ruleTester)) + .setSystemProperty(DISTINCT_AGGREGATIONS_STRATEGY, "automatic") + .overrideStats(aggregationSourceId, PlanNodeStatsEstimate.builder().addSymbolStatistics( + new Symbol(BIGINT, "groupingKey"), SymbolStatsEstimate.builder().setDistinctValuesCount(1_000_000).build()).build()) + .on(p -> { + Symbol input1Symbol = p.symbol("input1Symbol", BIGINT); + Symbol input2Symbol = p.symbol("input2Symbol", BIGINT); + return p.aggregation(builder -> builder + .singleGroupingSet(p.symbol("groupingKey", BIGINT)) + 
.addAggregation(p.symbol("output1", BIGINT), PlanBuilder.aggregation("count", true, ImmutableList.of(new Reference(BIGINT, "input1Symbol"))), ImmutableList.of(BIGINT)) + .addAggregation(p.symbol("output2", BIGINT), PlanBuilder.aggregation("sum", true, ImmutableList.of(new Reference(BIGINT, "input2Symbol"))), ImmutableList.of(BIGINT)) + .source( + p.tableScan(tableScan -> tableScan + .setNodeId(new PlanNodeId(aggregationSourceId)) + .setTableHandle(testTableHandle(ruleTester)) + .setSymbols(ImmutableList.of(input1Symbol, input2Symbol)) + .setAssignments(ImmutableMap.of( + input1Symbol, COLUMN_1_HANDLE, + input2Symbol, COLUMN_2_HANDLE))))); + }) + .doesNotFire(); + } + + @Test + public void testAutomaticDecisionForAggregationOnTableScan() + { + // automatic but single_step is preferred + String aggregationSourceId = "aggregationSourceId"; + ruleTester.assertThat(newMultipleDistinctAggregationsToSubqueries(ruleTester)) + .setSystemProperty(DISTINCT_AGGREGATIONS_STRATEGY, "automatic") + .overrideStats(aggregationSourceId, PlanNodeStatsEstimate.builder().addSymbolStatistics( + new Symbol(BIGINT, "groupingKey"), SymbolStatsEstimate.builder().setDistinctValuesCount(1_000_000).build()).build()) + .on(p -> { + Symbol input1Symbol = p.symbol("input1Symbol", BIGINT); + Symbol input2Symbol = p.symbol("input2Symbol", BIGINT); + return p.aggregation(builder -> builder + .singleGroupingSet(p.symbol("groupingKey", BIGINT)) + .addAggregation(p.symbol("output1", BIGINT), PlanBuilder.aggregation("count", true, ImmutableList.of(new Reference(BIGINT, "input1Symbol"))), ImmutableList.of(BIGINT)) + .addAggregation(p.symbol("output2", BIGINT), PlanBuilder.aggregation("sum", true, ImmutableList.of(new Reference(BIGINT, "input2Symbol"))), ImmutableList.of(BIGINT)) + .source( + p.tableScan(tableScan -> tableScan + .setNodeId(new PlanNodeId(aggregationSourceId)) + .setTableHandle(testTableHandle(ruleTester)) + .setSymbols(ImmutableList.of(input1Symbol, input2Symbol)) + 
.setAssignments(ImmutableMap.of( + input1Symbol, COLUMN_1_HANDLE, + input2Symbol, COLUMN_2_HANDLE))))); + }) + .doesNotFire(); + + // single_step is not preferred, the overhead of groupingKey is not big + ruleTester.assertThat(newMultipleDistinctAggregationsToSubqueries(ruleTester)) + .setSystemProperty(DISTINCT_AGGREGATIONS_STRATEGY, "automatic") + .overrideStats(aggregationSourceId, PlanNodeStatsEstimate.builder() + .setOutputRowCount(100) + .addSymbolStatistics(new Symbol(BIGINT, "groupingKey"), SymbolStatsEstimate.builder().setDistinctValuesCount(10).build()).build()) + .on(p -> { + Symbol input1Symbol = p.symbol("input1Symbol", BIGINT); + Symbol input2Symbol = p.symbol("input2Symbol", BIGINT); + Symbol groupingKey = p.symbol("groupingKey", BIGINT); + return p.aggregation(builder -> builder + .singleGroupingSet(groupingKey) + .addAggregation(p.symbol("output1", BIGINT), PlanBuilder.aggregation("count", true, ImmutableList.of(new Reference(BIGINT, "input1Symbol"))), ImmutableList.of(BIGINT)) + .addAggregation(p.symbol("output2", BIGINT), PlanBuilder.aggregation("sum", true, ImmutableList.of(new Reference(BIGINT, "input2Symbol"))), ImmutableList.of(BIGINT)) + .source( + p.tableScan(tableScan -> tableScan + .setNodeId(new PlanNodeId(aggregationSourceId)) + .setTableHandle(testTableHandle(ruleTester)) + .setSymbols(ImmutableList.of(input1Symbol, input2Symbol, groupingKey)) + .setAssignments(ImmutableMap.of( + input1Symbol, COLUMN_1_HANDLE, + input2Symbol, COLUMN_2_HANDLE, + groupingKey, GROUPING_KEY_COLUMN_HANDLE))))); + }) + .matches(project( + ImmutableMap.of( + "final_output1", PlanMatchPattern.expression(new Reference(BIGINT, "output1")), + "final_output2", PlanMatchPattern.expression(new Reference(BIGINT, "output2")), + "group_by_key", PlanMatchPattern.expression(new Reference(BIGINT, "left_groupingKey"))), + join( + INNER, + builder -> builder + .equiCriteria("left_groupingKey", "right_groupingKey") + .left(aggregation( + 
singleGroupingSet("left_groupingKey"), + ImmutableMap.of(Optional.of("output1"), aggregationFunction("count", true, ImmutableList.of(symbol("input1Symbol")))), + Optional.empty(), + SINGLE, + tableScan( + TABLE_SCHEMA.getTableName(), + ImmutableMap.of( + "input1Symbol", COLUMN_1, + "left_groupingKey", GROUPING_KEY_COLUMN)))) + .right(aggregation( + singleGroupingSet("right_groupingKey"), + ImmutableMap.of(Optional.of("output2"), aggregationFunction("sum", true, ImmutableList.of(symbol("input2Symbol")))), + Optional.empty(), + SINGLE, + tableScan( + TABLE_SCHEMA.getTableName(), + ImmutableMap.of( + "input2Symbol", COLUMN_2, + "right_groupingKey", GROUPING_KEY_COLUMN))))))); + + // single_step is not preferred, the overhead of groupingKeys is bigger than 50% + String aggregationId = "aggregationId"; + ruleTester.assertThat(newMultipleDistinctAggregationsToSubqueries(ruleTester)) + .setSystemProperty(DISTINCT_AGGREGATIONS_STRATEGY, "automatic") + .overrideStats(aggregationSourceId, PlanNodeStatsEstimate.builder() + .setOutputRowCount(100) + .addSymbolStatistics(new Symbol(BIGINT, "groupingKey"), SymbolStatsEstimate.builder().setDistinctValuesCount(10).build()) + .addSymbolStatistics(new Symbol(BIGINT, "groupingKey2"), SymbolStatsEstimate.builder().setAverageRowSize(1_000_000).build()) + .build()) + .overrideStats(aggregationId, PlanNodeStatsEstimate.builder().setOutputRowCount(10).build()) + .on(p -> { + Symbol input1Symbol = p.symbol("input1Symbol", BIGINT); + Symbol input2Symbol = p.symbol("input2Symbol", BIGINT); + Symbol groupingKey = p.symbol("groupingKey", BIGINT); + Symbol groupingKey2 = p.symbol("groupingKey2", VARCHAR); + return p.aggregation(builder -> builder + .nodeId(new PlanNodeId(aggregationId)) + .singleGroupingSet(groupingKey, groupingKey2) + .addAggregation(p.symbol("output1", BIGINT), PlanBuilder.aggregation("count", true, ImmutableList.of(new Reference(BIGINT, "input1Symbol"))), ImmutableList.of(BIGINT)) + .addAggregation(p.symbol("output2", 
BIGINT), PlanBuilder.aggregation("sum", true, ImmutableList.of(new Reference(BIGINT, "input2Symbol"))), ImmutableList.of(BIGINT)) + .source( + p.tableScan(tableScan -> tableScan + .setNodeId(new PlanNodeId(aggregationSourceId)) + .setTableHandle(testTableHandle(ruleTester)) + .setSymbols(ImmutableList.of(input1Symbol, input2Symbol, groupingKey, groupingKey2)) + .setAssignments(ImmutableMap.of( + input1Symbol, COLUMN_1_HANDLE, + input2Symbol, COLUMN_2_HANDLE, + groupingKey, GROUPING_KEY_COLUMN_HANDLE, + groupingKey2, GROUPING_KEY2_COLUMN_HANDLE))))); + }) + .doesNotFire(); + } + + @Test + public void testAutomaticDecisionForAggregationOnProjectedTableScan() + { + String aggregationSourceId = "aggregationSourceId"; + String aggregationId = "aggregationId"; + // the overhead of the projection is bigger than 50% + ruleTester.assertThat(newMultipleDistinctAggregationsToSubqueries(ruleTester)) + .setSystemProperty(DISTINCT_AGGREGATIONS_STRATEGY, "automatic") + .overrideStats(aggregationSourceId, PlanNodeStatsEstimate.builder() + .setOutputRowCount(100) + .addSymbolStatistics(new Symbol(BIGINT, "projectionInput1"), SymbolStatsEstimate.builder().setDistinctValuesCount(10).build()) + .addSymbolStatistics(new Symbol(BIGINT, "projectionInput2"), SymbolStatsEstimate.builder().setAverageRowSize(1_000_000).build()) + .build()) + .overrideStats(aggregationId, PlanNodeStatsEstimate.builder().setOutputRowCount(10).build()) + .on(p -> { + Symbol input1Symbol = p.symbol("input1Symbol", BIGINT); + Symbol input2Symbol = p.symbol("input2Symbol", BIGINT); + Symbol groupingKey = p.symbol("groupingKey", BIGINT); + Symbol projectionInput1 = p.symbol("projectionInput1", BIGINT); + Symbol projectionInput2 = p.symbol("projectionInput2", VARCHAR); + return p.aggregation(builder -> builder + .nodeId(new PlanNodeId(aggregationId)) + .singleGroupingSet(groupingKey) + .addAggregation(p.symbol("output1", BIGINT), PlanBuilder.aggregation("count", true, ImmutableList.of(new Reference(BIGINT, 
"input1Symbol"))), ImmutableList.of(BIGINT)) + .addAggregation(p.symbol("output2", BIGINT), PlanBuilder.aggregation("sum", true, ImmutableList.of(new Reference(BIGINT, "input2Symbol"))), ImmutableList.of(BIGINT)) + .source( + p.project( + Assignments.builder() + .putIdentity(input1Symbol) + .putIdentity(input2Symbol) + .put(groupingKey, new Call(ADD_BIGINT, ImmutableList.of(new Reference(BIGINT, "projectionInput1"), new Cast(new Reference(BIGINT, "projectionInput2"), BIGINT)))) + .build(), + p.tableScan(tableScan -> tableScan + .setNodeId(new PlanNodeId(aggregationSourceId)) + .setTableHandle(testTableHandle(ruleTester)) + .setSymbols(ImmutableList.of(input1Symbol, input2Symbol, projectionInput1, projectionInput2)) + .setAssignments(ImmutableMap.of( + input1Symbol, COLUMN_1_HANDLE, + input2Symbol, COLUMN_2_HANDLE, + projectionInput1, GROUPING_KEY_COLUMN_HANDLE, + projectionInput2, GROUPING_KEY2_COLUMN_HANDLE)))))); + }) + .doesNotFire(); + + // the big projection is used as distinct input. 
we could handle this case, but for simplicity sake, the rule won't fire here + ruleTester.assertThat(newMultipleDistinctAggregationsToSubqueries(ruleTester)) + .setSystemProperty(DISTINCT_AGGREGATIONS_STRATEGY, "automatic") + .overrideStats(aggregationSourceId, PlanNodeStatsEstimate.builder() + .setOutputRowCount(100) + .addSymbolStatistics(new Symbol(BIGINT, "projectionInput1"), SymbolStatsEstimate.builder().setDistinctValuesCount(10).build()) + .addSymbolStatistics(new Symbol(BIGINT, "projectionInput2"), SymbolStatsEstimate.builder().setAverageRowSize(1_000_000).build()) + .build()) + .overrideStats(aggregationId, PlanNodeStatsEstimate.builder().setOutputRowCount(10).build()) + .on(p -> { + Symbol input1Symbol = p.symbol("input1Symbol", BIGINT); + Symbol input2Symbol = p.symbol("input2Symbol", BIGINT); + Symbol groupingKey = p.symbol("groupingKey", BIGINT); + Symbol projectionInput1 = p.symbol("projectionInput1", BIGINT); + Symbol projectionInput2 = p.symbol("projectionInput2", VARCHAR); + return p.aggregation(builder -> builder + .nodeId(new PlanNodeId(aggregationId)) + .singleGroupingSet(groupingKey) + .addAggregation(p.symbol("output1", BIGINT), PlanBuilder.aggregation("count", true, ImmutableList.of(new Reference(BIGINT, "input1Symbol"))), ImmutableList.of(BIGINT)) + .addAggregation(p.symbol("output2", BIGINT), PlanBuilder.aggregation("sum", true, ImmutableList.of(new Reference(BIGINT, "input2Symbol"))), ImmutableList.of(BIGINT)) + .source( + p.project( + Assignments.builder() + .put(input1Symbol, new Call(ADD_BIGINT, ImmutableList.of(new Reference(BIGINT, "projectionInput1"), new Cast(new Reference(BIGINT, "projectionInput2"), BIGINT)))) + .putIdentity(input2Symbol) + .putIdentity(groupingKey) + .build(), + p.tableScan(tableScan -> tableScan + .setNodeId(new PlanNodeId(aggregationSourceId)) + .setTableHandle(testTableHandle(ruleTester)) + .setSymbols(ImmutableList.of(groupingKey, input2Symbol, projectionInput1, projectionInput2)) + 
.setAssignments(ImmutableMap.of( + groupingKey, COLUMN_1_HANDLE, + input2Symbol, COLUMN_2_HANDLE, + projectionInput1, GROUPING_KEY_COLUMN_HANDLE, + projectionInput2, GROUPING_KEY2_COLUMN_HANDLE)))))); + }) + .doesNotFire(); + } + + @Test + public void testAutomaticDecisionForAggregationOnFilteredTableScan() + { + String aggregationSourceId = "aggregationSourceId"; + String aggregationId = "aggregationId"; + String filterId = "filterId"; + // selective filter + ruleTester.assertThat(newMultipleDistinctAggregationsToSubqueries(ruleTester)) + .setSystemProperty(DISTINCT_AGGREGATIONS_STRATEGY, "automatic") + .overrideStats(aggregationSourceId, PlanNodeStatsEstimate.builder() + .setOutputRowCount(100) + .addSymbolStatistics(new Symbol(VARCHAR, "filterInput"), SymbolStatsEstimate.builder().setAverageRowSize(1).build()) + .build()) + .overrideStats(filterId, PlanNodeStatsEstimate.builder().setOutputRowCount(1).build()) + .overrideStats(aggregationId, PlanNodeStatsEstimate.builder().setOutputRowCount(1).build()) + .on(p -> { + Symbol input1Symbol = p.symbol("input1Symbol", BIGINT); + Symbol input2Symbol = p.symbol("input2Symbol", BIGINT); + Symbol groupingKey = p.symbol("groupingKey", BIGINT); + Symbol filterInput = p.symbol("filterInput", VARCHAR); + + return p.aggregation(builder -> builder + .nodeId(new PlanNodeId(aggregationId)) + .singleGroupingSet(groupingKey) + .addAggregation(p.symbol("output1", BIGINT), PlanBuilder.aggregation("count", true, ImmutableList.of(new Reference(BIGINT, "input1Symbol"))), ImmutableList.of(BIGINT)) + .addAggregation(p.symbol("output2", BIGINT), PlanBuilder.aggregation("sum", true, ImmutableList.of(new Reference(BIGINT, "input2Symbol"))), ImmutableList.of(BIGINT)) + .source( + p.filter( + new PlanNodeId(filterId), + not(ruleTester.getMetadata(), new IsNull(new Reference(VARCHAR, "filterInput"))), + p.tableScan(tableScan -> tableScan + .setNodeId(new PlanNodeId(aggregationSourceId)) + .setTableHandle(testTableHandle(ruleTester)) + 
.setSymbols(ImmutableList.of(input1Symbol, input2Symbol, groupingKey, filterInput)) + .setAssignments(ImmutableMap.of( + input1Symbol, COLUMN_1_HANDLE, + input2Symbol, COLUMN_2_HANDLE, + groupingKey, GROUPING_KEY_COLUMN_HANDLE, + filterInput, GROUPING_KEY2_COLUMN_HANDLE)))))); + }) + .doesNotFire(); + + // non-selective filter + ruleTester.assertThat(newMultipleDistinctAggregationsToSubqueries(ruleTester)) + .setSystemProperty(DISTINCT_AGGREGATIONS_STRATEGY, "automatic") + .overrideStats(aggregationSourceId, PlanNodeStatsEstimate.builder() + .setOutputRowCount(100) + .addSymbolStatistics(new Symbol(VARCHAR, "filterInput"), SymbolStatsEstimate.builder().setAverageRowSize(1).build()) + .build()) + .overrideStats(filterId, PlanNodeStatsEstimate.builder().setOutputRowCount(100).build()) + .overrideStats(aggregationId, PlanNodeStatsEstimate.builder().setOutputRowCount(100).build()) + .on(p -> { + Symbol input1Symbol = p.symbol("input1Symbol", BIGINT); + Symbol input2Symbol = p.symbol("input2Symbol", BIGINT); + Symbol groupingKey = p.symbol("groupingKey", BIGINT); + Symbol filterInput = p.symbol("filterInput", VARCHAR); + + return p.aggregation(builder -> builder + .nodeId(new PlanNodeId(aggregationId)) + .singleGroupingSet(groupingKey) + .addAggregation(p.symbol("output1", BIGINT), PlanBuilder.aggregation("count", true, ImmutableList.of(new Reference(BIGINT, "input1Symbol"))), ImmutableList.of(BIGINT)) + .addAggregation(p.symbol("output2", BIGINT), PlanBuilder.aggregation("sum", true, ImmutableList.of(new Reference(BIGINT, "input2Symbol"))), ImmutableList.of(BIGINT)) + .source( + p.filter( + new PlanNodeId(filterId), + not(ruleTester.getMetadata(), new IsNull(new Reference(VARCHAR, "filterInput"))), + p.tableScan(tableScan -> tableScan + .setNodeId(new PlanNodeId(aggregationSourceId)) + .setTableHandle(testTableHandle(ruleTester)) + .setSymbols(ImmutableList.of(input1Symbol, input2Symbol, groupingKey, filterInput)) + .setAssignments(ImmutableMap.of( + input1Symbol, 
COLUMN_1_HANDLE, + input2Symbol, COLUMN_2_HANDLE, + groupingKey, GROUPING_KEY_COLUMN_HANDLE, + filterInput, GROUPING_KEY2_COLUMN_HANDLE)))))); + }) + .matches(project( + ImmutableMap.of( + "final_output1", PlanMatchPattern.expression(new Reference(BIGINT, "output1")), + "final_output2", PlanMatchPattern.expression(new Reference(BIGINT, "output2")), + "group_by_key", PlanMatchPattern.expression(new Reference(BIGINT, "left_groupingKey"))), + join( + INNER, + builder -> builder + .equiCriteria("left_groupingKey", "right_groupingKey") + .left(aggregation( + singleGroupingSet("left_groupingKey"), + ImmutableMap.of(Optional.of("output1"), aggregationFunction("count", true, ImmutableList.of(symbol("input1Symbol")))), + Optional.empty(), + SINGLE, + filter( + not(ruleTester.getMetadata(), new IsNull(new Reference(BIGINT, "left_filterInput"))), + tableScan( + TABLE_SCHEMA.getTableName(), + ImmutableMap.of( + "input1Symbol", COLUMN_1, + "left_groupingKey", GROUPING_KEY_COLUMN, + "left_filterInput", GROUPING_KEY2_COLUMN))))) + .right(aggregation( + singleGroupingSet("right_groupingKey"), + ImmutableMap.of(Optional.of("output2"), aggregationFunction("sum", true, ImmutableList.of(symbol("input2Symbol")))), + Optional.empty(), + SINGLE, + filter( + not(ruleTester.getMetadata(), new IsNull(new Reference(BIGINT, "right_filterInput"))), + tableScan( + TABLE_SCHEMA.getTableName(), + ImmutableMap.of( + "input2Symbol", COLUMN_2, + "right_groupingKey", GROUPING_KEY_COLUMN, + "right_filterInput", GROUPING_KEY2_COLUMN)))))))); + } + + @Test + public void testAutomaticDecisionForAggregationOnFilteredUnion() + { + String aggregationSourceId = "aggregationSourceId"; + String aggregationId = "aggregationId"; + String filterId = "filterId"; + // union with additional columns to read + ruleTester.assertThat(newMultipleDistinctAggregationsToSubqueries(ruleTester)) + .setSystemProperty(DISTINCT_AGGREGATIONS_STRATEGY, "automatic") + .overrideStats(aggregationSourceId, 
PlanNodeStatsEstimate.builder() + .setOutputRowCount(100) + .addSymbolStatistics(new Symbol(VARCHAR, "filterInput"), SymbolStatsEstimate.builder().setAverageRowSize(1).build()) + .build()) + .overrideStats(filterId, PlanNodeStatsEstimate.builder().setOutputRowCount(100).build()) + .overrideStats(aggregationId, PlanNodeStatsEstimate.builder().setOutputRowCount(100).build()) + .on(p -> { + Symbol input1Symbol = p.symbol("input1Symbol", BIGINT); + Symbol input11Symbol = p.symbol("input1_1Symbol", BIGINT); + Symbol input12Symbol = p.symbol("input1_2Symbol", BIGINT); + Symbol input2Symbol = p.symbol("input2Symbol", BIGINT); + Symbol input21Symbol = p.symbol("input2_1Symbol", BIGINT); + Symbol input22Symbol = p.symbol("input2_2Symbol", BIGINT); + Symbol groupingKey = p.symbol("groupingKey", BIGINT); + Symbol groupingKey1 = p.symbol("groupingKey1", BIGINT); + Symbol groupingKey2 = p.symbol("groupingKey2", BIGINT); + + return p.aggregation(builder -> builder + .nodeId(new PlanNodeId(aggregationId)) + .singleGroupingSet(groupingKey) + .addAggregation(p.symbol("output1", BIGINT), PlanBuilder.aggregation("count", true, ImmutableList.of(new Reference(BIGINT, "input1Symbol"))), ImmutableList.of(BIGINT)) + .addAggregation(p.symbol("output2", BIGINT), PlanBuilder.aggregation("sum", true, ImmutableList.of(new Reference(BIGINT, "input2Symbol"))), ImmutableList.of(BIGINT)) + .source( + p.union( + ImmutableListMultimap.builder() + .put(input1Symbol, input11Symbol) + .put(input1Symbol, input12Symbol) + .put(input2Symbol, input21Symbol) + .put(input2Symbol, input22Symbol) + .put(groupingKey, groupingKey1) + .put(groupingKey, groupingKey2) + .build(), + ImmutableList.of( + p.filter( + new Comparison(GREATER_THAN, new Reference(BIGINT, "input1_1Symbol"), new Constant(BIGINT, 0L)), + p.tableScan( + testTableHandle(ruleTester), + ImmutableList.of(input11Symbol, input21Symbol, groupingKey1), + ImmutableMap.of( + input11Symbol, COLUMN_1_HANDLE, + input21Symbol, COLUMN_2_HANDLE, + 
groupingKey1, GROUPING_KEY_COLUMN_HANDLE))), + p.filter( + new Comparison(GREATER_THAN, new Reference(BIGINT, "input2_2Symbol"), new Constant(BIGINT, 2L)), + p.tableScan( + testTableHandle(ruleTester), + ImmutableList.of(input12Symbol, input22Symbol, groupingKey2), + ImmutableMap.of( + input12Symbol, COLUMN_1_HANDLE, + input22Symbol, COLUMN_2_HANDLE, + groupingKey2, GROUPING_KEY_COLUMN_HANDLE))))))); + }) + .doesNotFire(); + } + + @Test + public void testGlobalDistinctToSubqueries() + { + ruleTester.assertThat(newMultipleDistinctAggregationsToSubqueries(ruleTester)) + .setSystemProperty(DISTINCT_AGGREGATIONS_STRATEGY, "split_to_subqueries") + .on(p -> { + Symbol input1Symbol = p.symbol("input1Symbol", BIGINT); + Symbol input2Symbol = p.symbol("input2Symbol", BIGINT); + return p.aggregation(builder -> builder + .globalGrouping() + .addAggregation(p.symbol("output1", BIGINT), PlanBuilder.aggregation("count", true, ImmutableList.of(new Reference(BIGINT, "input1Symbol"))), ImmutableList.of(BIGINT)) + .addAggregation(p.symbol("output2", BIGINT), PlanBuilder.aggregation("sum", true, ImmutableList.of(new Reference(BIGINT, "input2Symbol"))), ImmutableList.of(BIGINT)) + .source( + p.tableScan( + testTableHandle(ruleTester), + ImmutableList.of(input1Symbol, input2Symbol), + ImmutableMap.of( + input1Symbol, COLUMN_1_HANDLE, + input2Symbol, COLUMN_2_HANDLE)))); + }) + .matches(project( + ImmutableMap.of( + "final_output1", PlanMatchPattern.expression(new Reference(BIGINT, "output1")), + "final_output2", PlanMatchPattern.expression(new Reference(BIGINT, "output2"))), + join( + INNER, + builder -> builder + .left(aggregation( + ImmutableMap.of("output1", aggregationFunction("count", true, ImmutableList.of(symbol("input1Symbol")))), + tableScan(TABLE_SCHEMA.getTableName(), ImmutableMap.of("input1Symbol", COLUMN_1)))) + .right(aggregation( + ImmutableMap.of("output2", aggregationFunction("sum", true, ImmutableList.of(symbol("input2Symbol")))), + 
tableScan(TABLE_SCHEMA.getTableName(), ImmutableMap.of("input2Symbol", COLUMN_2))))))); + } + + @Test + public void testGlobalWith3DistinctToSubqueries() + { + ruleTester.assertThat(newMultipleDistinctAggregationsToSubqueries(ruleTester)) + .setSystemProperty(DISTINCT_AGGREGATIONS_STRATEGY, "split_to_subqueries") + .on(p -> { + Symbol input1Symbol = p.symbol("input1Symbol", BIGINT); + Symbol input2Symbol = p.symbol("input2Symbol", BIGINT); + Symbol input3Symbol = p.symbol("input3Symbol", BIGINT); + return p.aggregation(builder -> builder + .globalGrouping() + .addAggregation(p.symbol("output1", BIGINT), PlanBuilder.aggregation("count", true, ImmutableList.of(new Reference(BIGINT, "input1Symbol"))), ImmutableList.of(BIGINT)) + .addAggregation(p.symbol("output2", BIGINT), PlanBuilder.aggregation("sum", true, ImmutableList.of(new Reference(BIGINT, "input2Symbol"))), ImmutableList.of(BIGINT)) + .addAggregation(p.symbol("output3", BIGINT), PlanBuilder.aggregation("count", true, ImmutableList.of(new Reference(BIGINT, "input3Symbol"))), ImmutableList.of(BIGINT)) + .source( + p.tableScan( + testTableHandle(ruleTester), + ImmutableList.of(input1Symbol, input2Symbol, input3Symbol), + ImmutableMap.of( + input1Symbol, COLUMN_1_HANDLE, + input2Symbol, COLUMN_2_HANDLE, + input3Symbol, COLUMN_3_HANDLE)))); + }) + .matches(project( + ImmutableMap.of( + "final_output1", PlanMatchPattern.expression(new Reference(BIGINT, "output1")), + "final_output2", PlanMatchPattern.expression(new Reference(BIGINT, "output2")), + "final_output3", PlanMatchPattern.expression(new Reference(BIGINT, "output3"))), + join( + INNER, + join -> join + .left(aggregation( + ImmutableMap.of("output1", aggregationFunction("count", true, ImmutableList.of(symbol("input1Symbol")))), + tableScan(TABLE_SCHEMA.getTableName(), ImmutableMap.of("input1Symbol", COLUMN_1)))) + .right(join( + INNER, + subJoin -> subJoin + .left(aggregation( + ImmutableMap.of("output2", aggregationFunction("sum", true, 
ImmutableList.of(symbol("input2Symbol")))), + tableScan(TABLE_SCHEMA.getTableName(), ImmutableMap.of("input2Symbol", COLUMN_2)))) + .right(aggregation( + ImmutableMap.of("output3", aggregationFunction("count", true, ImmutableList.of(symbol("input3Symbol")))), + tableScan(TABLE_SCHEMA.getTableName(), ImmutableMap.of("input3Symbol", COLUMN_3))))))))); + } + + // tests right deep join hierarchy + @Test + public void testGlobalWith4DistinctToSubqueries() + { + ruleTester.assertThat(newMultipleDistinctAggregationsToSubqueries(ruleTester)) + .setSystemProperty(DISTINCT_AGGREGATIONS_STRATEGY, "split_to_subqueries") + .on(p -> { + Symbol input1Symbol = p.symbol("input1Symbol", BIGINT); + Symbol input2Symbol = p.symbol("input2Symbol", BIGINT); + Symbol input3Symbol = p.symbol("input3Symbol", BIGINT); + Symbol input4Symbol = p.symbol("input4Symbol", BIGINT); + return p.aggregation(builder -> builder + .globalGrouping() + .addAggregation(p.symbol("output1", BIGINT), PlanBuilder.aggregation("count", true, ImmutableList.of(new Reference(BIGINT, "input1Symbol"))), ImmutableList.of(BIGINT)) + .addAggregation(p.symbol("output2", BIGINT), PlanBuilder.aggregation("count", true, ImmutableList.of(new Reference(BIGINT, "input2Symbol"))), ImmutableList.of(BIGINT)) + .addAggregation(p.symbol("output3", BIGINT), PlanBuilder.aggregation("count", true, ImmutableList.of(new Reference(BIGINT, "input3Symbol"))), ImmutableList.of(BIGINT)) + .addAggregation(p.symbol("output4", BIGINT), PlanBuilder.aggregation("count", true, ImmutableList.of(new Reference(BIGINT, "input4Symbol"))), ImmutableList.of(BIGINT)) + .source( + p.tableScan( + testTableHandle(ruleTester), + ImmutableList.of(input1Symbol, input2Symbol, input3Symbol, input4Symbol), + ImmutableMap.of( + input1Symbol, COLUMN_1_HANDLE, + input2Symbol, COLUMN_2_HANDLE, + input3Symbol, COLUMN_3_HANDLE, + input4Symbol, COLUMN_4_HANDLE)))); + }) + .matches(project( + ImmutableMap.of( + "final_output1", PlanMatchPattern.expression(new 
Reference(BIGINT, "output1")), + "final_output2", PlanMatchPattern.expression(new Reference(BIGINT, "output2")), + "final_output3", PlanMatchPattern.expression(new Reference(BIGINT, "output3")), + "final_output4", PlanMatchPattern.expression(new Reference(BIGINT, "output4"))), + join( + INNER, + join -> join + .left(aggregation( + ImmutableMap.of("output1", aggregationFunction("count", true, ImmutableList.of(symbol("input1Symbol")))), + tableScan(TABLE_SCHEMA.getTableName(), ImmutableMap.of("input1Symbol", COLUMN_1)))) + .right(join( + INNER, + subJoin -> subJoin + .left(aggregation( + ImmutableMap.of("output2", aggregationFunction("count", true, ImmutableList.of(symbol("input2Symbol")))), + tableScan(TABLE_SCHEMA.getTableName(), ImmutableMap.of("input2Symbol", COLUMN_2)))) + .right(join( + INNER, + subJoin2 -> subJoin2 + .left(aggregation( + ImmutableMap.of("output3", aggregationFunction("count", true, ImmutableList.of(symbol("input3Symbol")))), + tableScan(TABLE_SCHEMA.getTableName(), ImmutableMap.of("input3Symbol", COLUMN_3)))) + .right(aggregation( + ImmutableMap.of("output4", aggregationFunction("count", true, ImmutableList.of(symbol("input4Symbol")))), + tableScan(TABLE_SCHEMA.getTableName(), ImmutableMap.of("input4Symbol", COLUMN_4))))))))))); + } + + @Test + public void testGlobal2DistinctOnTheSameInputToSubqueries() + { + ruleTester.assertThat(newMultipleDistinctAggregationsToSubqueries(ruleTester)) + .setSystemProperty(DISTINCT_AGGREGATIONS_STRATEGY, "split_to_subqueries") + .on(p -> { + Symbol input1Symbol = p.symbol("input1Symbol", BIGINT); + Symbol input2Symbol = p.symbol("input2Symbol", BIGINT); + return p.aggregation(builder -> builder + .globalGrouping() + .addAggregation(p.symbol("output1", BIGINT), PlanBuilder.aggregation("count", true, ImmutableList.of(new Reference(BIGINT, "input1Symbol"))), ImmutableList.of(BIGINT)) + .addAggregation(p.symbol("output2", BIGINT), PlanBuilder.aggregation("sum", true, ImmutableList.of(new Reference(BIGINT, 
"input2Symbol"))), ImmutableList.of(BIGINT)) + .addAggregation(p.symbol("output3", BIGINT), PlanBuilder.aggregation("count", true, ImmutableList.of(new Reference(BIGINT, "input2Symbol"))), ImmutableList.of(BIGINT)) + .source( + p.tableScan( + testTableHandle(ruleTester), + ImmutableList.of(input1Symbol, input2Symbol), + ImmutableMap.of( + input1Symbol, COLUMN_1_HANDLE, + input2Symbol, COLUMN_2_HANDLE)))); + }) + .matches(project( + ImmutableMap.of( + "final_output1", PlanMatchPattern.expression(new Reference(BIGINT, "output1")), + "final_output2", PlanMatchPattern.expression(new Reference(BIGINT, "output2")), + "final_output3", PlanMatchPattern.expression(new Reference(BIGINT, "output3"))), + join( + INNER, + builder -> builder + .left(aggregation( + ImmutableMap.of("output1", aggregationFunction("count", true, ImmutableList.of(symbol("input1Symbol")))), + tableScan(TABLE_SCHEMA.getTableName(), ImmutableMap.of("input1Symbol", COLUMN_1)))) + .right(aggregation( + ImmutableMap.of( + "output2", aggregationFunction("sum", true, ImmutableList.of(symbol("input2Symbol"))), + "output3", aggregationFunction("count", true, ImmutableList.of(symbol("input2Symbol")))), + tableScan(TABLE_SCHEMA.getTableName(), ImmutableMap.of("input2Symbol", COLUMN_2))))))); + } + + @Test + public void testGroupByWithDistinctToSubqueries() + { + String aggregationNodeId = "aggregationNodeId"; + ruleTester.assertThat(newMultipleDistinctAggregationsToSubqueries(ruleTester)) + .setSystemProperty(DISTINCT_AGGREGATIONS_STRATEGY, "split_to_subqueries") + .overrideStats(aggregationNodeId, PlanNodeStatsEstimate.builder().setOutputRowCount(100_000).build()) + .on(p -> { + Symbol input1Symbol = p.symbol("input1Symbol", BIGINT); + Symbol input2Symbol = p.symbol("input2Symbol", BIGINT); + Symbol groupingKey = p.symbol("groupingKey", BIGINT); + return p.aggregation(builder -> builder + .nodeId(new PlanNodeId(aggregationNodeId)) + .singleGroupingSet(groupingKey) + .addAggregation(p.symbol("output1", BIGINT), 
PlanBuilder.aggregation("count", true, ImmutableList.of(new Reference(BIGINT, "input1Symbol"))), ImmutableList.of(BIGINT)) + .addAggregation(p.symbol("output2", BIGINT), PlanBuilder.aggregation("sum", true, ImmutableList.of(new Reference(BIGINT, "input2Symbol"))), ImmutableList.of(BIGINT)) + .source( + p.tableScan( + testTableHandle(ruleTester), + ImmutableList.of(input1Symbol, input2Symbol), + ImmutableMap.of( + input1Symbol, COLUMN_1_HANDLE, + input2Symbol, COLUMN_2_HANDLE, + groupingKey, GROUPING_KEY_COLUMN_HANDLE)))); + }) + .matches(project( + ImmutableMap.of( + "final_output1", PlanMatchPattern.expression(new Reference(BIGINT, "output1")), + "final_output2", PlanMatchPattern.expression(new Reference(BIGINT, "output2")), + "group_by_key", PlanMatchPattern.expression(new Reference(BIGINT, "left_groupingKey"))), + join( + INNER, + builder -> builder + .equiCriteria("left_groupingKey", "right_groupingKey") + .left(aggregation( + singleGroupingSet("left_groupingKey"), + ImmutableMap.of(Optional.of("output1"), aggregationFunction("count", true, ImmutableList.of(symbol("input1Symbol")))), + Optional.empty(), + SINGLE, + tableScan( + TABLE_SCHEMA.getTableName(), + ImmutableMap.of( + "input1Symbol", COLUMN_1, + "left_groupingKey", GROUPING_KEY_COLUMN)))) + .right(aggregation( + singleGroupingSet("right_groupingKey"), + ImmutableMap.of(Optional.of("output2"), aggregationFunction("sum", true, ImmutableList.of(symbol("input2Symbol")))), + Optional.empty(), + SINGLE, + tableScan( + TABLE_SCHEMA.getTableName(), + ImmutableMap.of( + "input2Symbol", COLUMN_2, + "right_groupingKey", GROUPING_KEY_COLUMN))))))); + } + + @Test + public void testGroupByWithDistinctOverUnionToSubqueries() + { + String aggregationNodeId = "aggregationNodeId"; + ruleTester.assertThat(newMultipleDistinctAggregationsToSubqueries(ruleTester)) + .setSystemProperty(DISTINCT_AGGREGATIONS_STRATEGY, "split_to_subqueries") + .overrideStats(aggregationNodeId, 
PlanNodeStatsEstimate.builder().setOutputRowCount(100_000).build()) + .on(p -> { + Symbol input1Symbol = p.symbol("input1Symbol", BIGINT); + Symbol input11Symbol = p.symbol("input1_1Symbol", BIGINT); + Symbol input12Symbol = p.symbol("input1_2Symbol", BIGINT); + Symbol input2Symbol = p.symbol("input2Symbol", BIGINT); + Symbol input21Symbol = p.symbol("input2_1Symbol", BIGINT); + Symbol input22Symbol = p.symbol("input2_2Symbol", BIGINT); + Symbol groupingKey = p.symbol("groupingKey", BIGINT); + Symbol groupingKey1 = p.symbol("groupingKey1", BIGINT); + Symbol groupingKey2 = p.symbol("groupingKey2", BIGINT); + + return p.aggregation(builder -> builder + .nodeId(new PlanNodeId(aggregationNodeId)) + .singleGroupingSet(groupingKey) + .addAggregation(p.symbol("output1", BIGINT), PlanBuilder.aggregation("count", true, ImmutableList.of(new Reference(BIGINT, "input1Symbol"))), ImmutableList.of(BIGINT)) + .addAggregation(p.symbol("output2", BIGINT), PlanBuilder.aggregation("sum", true, ImmutableList.of(new Reference(BIGINT, "input2Symbol"))), ImmutableList.of(BIGINT)) + .source( + p.union( + ImmutableListMultimap.builder() + .put(input1Symbol, input11Symbol) + .put(input1Symbol, input12Symbol) + .put(input2Symbol, input21Symbol) + .put(input2Symbol, input22Symbol) + .put(groupingKey, groupingKey1) + .put(groupingKey, groupingKey2) + .build(), + ImmutableList.of( + p.filter( + new Comparison(GREATER_THAN, new Reference(BIGINT, "input1_1Symbol"), new Constant(BIGINT, 0L)), + p.tableScan( + testTableHandle(ruleTester), + ImmutableList.of(input11Symbol, input21Symbol, groupingKey1), + ImmutableMap.of( + input11Symbol, COLUMN_1_HANDLE, + input21Symbol, COLUMN_2_HANDLE, + groupingKey1, GROUPING_KEY_COLUMN_HANDLE))), + p.filter( + new Comparison(GREATER_THAN, new Reference(BIGINT, "input2_2Symbol"), new Constant(BIGINT, 2L)), + p.tableScan( + testTableHandle(ruleTester), + ImmutableList.of(input12Symbol, input22Symbol, groupingKey2), + ImmutableMap.of( + input12Symbol, 
COLUMN_1_HANDLE, + input22Symbol, COLUMN_2_HANDLE, + groupingKey2, GROUPING_KEY_COLUMN_HANDLE))))))); + }) + .matches(project( + ImmutableMap.of( + "final_output1", PlanMatchPattern.expression(new Reference(BIGINT, "output1")), + "final_output2", PlanMatchPattern.expression(new Reference(BIGINT, "output2")), + "group_by_key", PlanMatchPattern.expression(new Reference(BIGINT, "left_groupingKey"))), + join( + INNER, + builder -> builder + .equiCriteria("left_groupingKey", "right_groupingKey") + .left(aggregation( + singleGroupingSet("left_groupingKey"), + ImmutableMap.of(Optional.of("output1"), aggregationFunction("count", true, ImmutableList.of(symbol("input1Symbol1")))), + Optional.empty(), + SINGLE, + union( + filter( + new Comparison(GREATER_THAN, new Reference(BIGINT, "input1_1_1Symbol"), new Constant(BIGINT, 0L)), + tableScan( + TABLE_SCHEMA.getTableName(), + ImmutableMap.of( + "input1_1_1Symbol", COLUMN_1, + "input2_1_1Symbol", COLUMN_2, + "left_groupingKey1", GROUPING_KEY_COLUMN))), + filter( + new Comparison(GREATER_THAN, new Reference(BIGINT, "input2_2_1Symbol"), new Constant(BIGINT, 2L)), + tableScan( + TABLE_SCHEMA.getTableName(), + ImmutableMap.of( + "input1_2_1Symbol", COLUMN_1, + "input2_2_1Symbol", COLUMN_2, + "left_groupingKey2", GROUPING_KEY_COLUMN)))) + .withAlias("input1Symbol1", new SetOperationOutputMatcher(0)) + .withAlias("input2Symbol1", new SetOperationOutputMatcher(1)) + .withAlias("left_groupingKey", new SetOperationOutputMatcher(2)))) + .right(aggregation( + singleGroupingSet("right_groupingKey"), + ImmutableMap.of(Optional.of("output2"), aggregationFunction("sum", true, ImmutableList.of(symbol("input2Symbol2")))), + Optional.empty(), + SINGLE, + union( + filter( + new Comparison(GREATER_THAN, new Reference(BIGINT, "input1_1_2Symbol"), new Constant(BIGINT, 0L)), + tableScan( + TABLE_SCHEMA.getTableName(), + ImmutableMap.of( + "input1_1_2Symbol", COLUMN_1, + "input2_1_2Symbol", COLUMN_2, + "right_groupingKey1", GROUPING_KEY_COLUMN))), + 
filter( + new Comparison(GREATER_THAN, new Reference(BIGINT, "input2_2_2Symbol"), new Constant(BIGINT, 2L)), + tableScan( + TABLE_SCHEMA.getTableName(), + ImmutableMap.of( + "input1_2_2Symbol", COLUMN_1, + "input2_2_2Symbol", COLUMN_2, + "right_groupingKey2", GROUPING_KEY_COLUMN)))) + .withAlias("input1Symbol2", new SetOperationOutputMatcher(0)) + .withAlias("input2Symbol2", new SetOperationOutputMatcher(1)) + .withAlias("right_groupingKey", new SetOperationOutputMatcher(2))))))); + } + + private static MultipleDistinctAggregationsToSubqueries newMultipleDistinctAggregationsToSubqueries(RuleTester ruleTester) + { + return new MultipleDistinctAggregationsToSubqueries(new TaskCountEstimator(() -> Integer.MAX_VALUE), ruleTester.getMetadata()); + } + + private static TableHandle testTableHandle(RuleTester ruleTester) + { + return new TableHandle(ruleTester.getCurrentCatalogHandle(), new MockConnectorTableHandle(TABLE_SCHEMA, TupleDomain.all(), Optional.empty()), TestingTransactionHandle.create()); + } + + private static RuleTester tester(boolean allowSplittingReadIntoMultipleSubQueries) + { + PlanTester planTester = PlanTester.create(MOCK_SESSION); + MockConnectorFactory.Builder builder = MockConnectorFactory.builder() + .withAllowSplittingReadIntoMultipleSubQueries(allowSplittingReadIntoMultipleSubQueries) + .withGetTableHandle((_, schemaTableName) -> new MockConnectorTableHandle(schemaTableName)) + .withGetColumns(_ -> ALL_COLUMNS); + planTester.createCatalog( + MOCK_CATALOG, + builder.build(), + ImmutableMap.of()); + return new RuleTester(planTester); + } +} diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/test/PlanBuilder.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/test/PlanBuilder.java index f1487e1283d2..aa01e3536ab2 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/test/PlanBuilder.java +++ 
b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/test/PlanBuilder.java @@ -620,6 +620,7 @@ public static class TableScanBuilder private Optional statistics = Optional.empty(); private boolean updateTarget; private Optional useConnectorNodePartitioning = Optional.empty(); + private Optional nodeId = Optional.empty(); private TableScanBuilder(PlanNodeIdAllocator idAllocator) { @@ -667,6 +668,12 @@ public TableScanBuilder setUpdateTarget(boolean updateTarget) return this; } + public TableScanBuilder setNodeId(PlanNodeId id) + { + this.nodeId = Optional.of(id); + return this; + } + public TableScanBuilder setUseConnectorNodePartitioning(Optional useConnectorNodePartitioning) { this.useConnectorNodePartitioning = useConnectorNodePartitioning; @@ -676,7 +683,7 @@ public TableScanBuilder setUseConnectorNodePartitioning(Optional useCon public TableScanNode build() { return new TableScanNode( - idAllocator.getNextId(), + nodeId.orElseGet(idAllocator::getNextId), tableHandle, symbols, assignments, diff --git a/core/trino-spi/src/main/java/io/trino/spi/connector/ConnectorMetadata.java b/core/trino-spi/src/main/java/io/trino/spi/connector/ConnectorMetadata.java index 2095fb167cd0..7235f6fd56a0 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/connector/ConnectorMetadata.java +++ b/core/trino-spi/src/main/java/io/trino/spi/connector/ConnectorMetadata.java @@ -1818,6 +1818,15 @@ default OptionalInt getMaxWriterTasks(ConnectorSession session) return OptionalInt.empty(); } + /** + * @return true if reading a subset of columns from a given table separately from reading a complement of the subset has similar or better + * performance as reading this table. 
+ */ + default boolean allowSplittingReadIntoMultipleSubQueries(ConnectorSession session, ConnectorTableHandle tableHandle) + { + return false; + } + default WriterScalingOptions getNewTableWriterScalingOptions(ConnectorSession session, SchemaTableName tableName, Map tableProperties) { return WriterScalingOptions.DISABLED; diff --git a/docs/src/main/sphinx/admin/properties-optimizer.md b/docs/src/main/sphinx/admin/properties-optimizer.md index ffe3bfb7e599..f7db86222c7e 100644 --- a/docs/src/main/sphinx/admin/properties-optimizer.md +++ b/docs/src/main/sphinx/admin/properties-optimizer.md @@ -43,18 +43,17 @@ create them. ## `optimizer.distinct-aggregations-strategy` - **Type:** {ref}`prop-type-string` -- **Allowed values:** `AUTOMATIC`, `MARK_DISTINCT`, `SINGLE_STEP`, `PRE_AGGREGATE` +- **Allowed values:** `AUTOMATIC`, `MARK_DISTINCT`, `SINGLE_STEP`, `PRE_AGGREGATE`, `SPLIT_TO_SUBQUERIES` - **Default value:** `AUTOMATIC` - **Session property:** `distinct_aggregations_strategy` The strategy to use for multiple distinct aggregations. -`SINGLE_STEP` Computes distinct aggregations in single-step without any pre-aggregations. +- `SINGLE_STEP` Computes distinct aggregations in single-step without any pre-aggregations. This strategy will perform poorly if the number of distinct grouping keys is small. -`MARK_DISTINCT` uses `MarkDistinct` for multiple distinct aggregations -or for mix of distinct and non-distinct aggregations. -`PRE_AGGREGATE` Computes distinct aggregations using a combination of aggregation -and pre-aggregation steps. -`AUTOMATIC` chooses the strategy automatically. +- `MARK_DISTINCT` uses `MarkDistinct` for multiple distinct aggregations or for mix of distinct and non-distinct aggregations. +- `PRE_AGGREGATE` Computes distinct aggregations using a combination of aggregation and pre-aggregation steps. 
+- `SPLIT_TO_SUBQUERIES` Splits the aggregation input into independent sub-queries, where each subquery computes a single distinct aggregation, thus improving parallelism. +- `AUTOMATIC` chooses the strategy automatically. Single-step strategy is preferred. However, for cases with limited concurrency due to a small number of distinct grouping keys, it will choose an alternative strategy diff --git a/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/classloader/ClassLoaderSafeConnectorMetadata.java b/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/classloader/ClassLoaderSafeConnectorMetadata.java index 4501e6e87bc2..6b1a3b6b40b1 100644 --- a/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/classloader/ClassLoaderSafeConnectorMetadata.java +++ b/lib/trino-plugin-toolkit/src/main/java/io/trino/plugin/base/classloader/ClassLoaderSafeConnectorMetadata.java @@ -1285,6 +1285,14 @@ public OptionalInt getMaxWriterTasks(ConnectorSession session) } } + @Override + public boolean allowSplittingReadIntoMultipleSubQueries(ConnectorSession session, ConnectorTableHandle tableHandle) + { + try (ThreadContextClassLoader ignored = new ThreadContextClassLoader(classLoader)) { + return delegate.allowSplittingReadIntoMultipleSubQueries(session, tableHandle); + } + } + @Override public WriterScalingOptions getNewTableWriterScalingOptions(ConnectorSession session, SchemaTableName tableName, Map tableProperties) { diff --git a/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/DeltaLakeMetadata.java b/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/DeltaLakeMetadata.java index 300f2cad7c0f..3cfa6cfbc1e2 100644 --- a/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/DeltaLakeMetadata.java +++ b/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/DeltaLakeMetadata.java @@ -3803,6 +3803,13 @@ private Optional getRawSystemTable(ConnectorSession session, Schema }; } + @Override + public boolean 
allowSplittingReadIntoMultipleSubQueries(ConnectorSession session, ConnectorTableHandle tableHandle) + { + // delta lake supports only a columnar (parquet) storage format + return true; + } + @Override public WriterScalingOptions getNewTableWriterScalingOptions(ConnectorSession session, SchemaTableName tableName, Map tableProperties) { diff --git a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/HiveMetadata.java b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/HiveMetadata.java index 99cb71580eb0..9b8ac1488af2 100644 --- a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/HiveMetadata.java +++ b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/HiveMetadata.java @@ -3981,6 +3981,24 @@ private static Optional redirectTableToHudi(Optional new TableNotFoundException(tableName)); + + try { + HiveStorageFormat hiveStorageFormat = extractHiveStorageFormat(table); + return hiveStorageFormat == HiveStorageFormat.ORC || hiveStorageFormat == HiveStorageFormat.PARQUET; + } + catch (TrinoException ignored) { + // unknown storage format + return false; + } + } + @Override public WriterScalingOptions getNewTableWriterScalingOptions(ConnectorSession session, SchemaTableName tableName, Map tableProperties) { diff --git a/plugin/trino-hudi/src/main/java/io/trino/plugin/hudi/HudiMetadata.java b/plugin/trino-hudi/src/main/java/io/trino/plugin/hudi/HudiMetadata.java index 84743c1bcde9..f86182de78ad 100644 --- a/plugin/trino-hudi/src/main/java/io/trino/plugin/hudi/HudiMetadata.java +++ b/plugin/trino-hudi/src/main/java/io/trino/plugin/hudi/HudiMetadata.java @@ -287,6 +287,13 @@ public void validateScan(ConnectorSession session, ConnectorTableHandle handle) } } + @Override + public boolean allowSplittingReadIntoMultipleSubQueries(ConnectorSession session, ConnectorTableHandle tableHandle) + { + // hudi supports only a columnar (parquet) storage format + return true; + } + HiveMetastore getMetastore() { return metastore; diff --git 
a/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/IcebergMetadata.java b/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/IcebergMetadata.java index aa7fda6fc924..4c77ee80684f 100644 --- a/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/IcebergMetadata.java +++ b/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/IcebergMetadata.java @@ -3273,6 +3273,15 @@ public Optional redirectTable(ConnectorSession session, return catalog.redirectTable(session, tableName, targetCatalogName.get()); } + @Override + public boolean allowSplittingReadIntoMultipleSubQueries(ConnectorSession session, ConnectorTableHandle connectorTableHandle) + { + IcebergTableHandle tableHandle = (IcebergTableHandle) connectorTableHandle; + IcebergFileFormat storageFormat = getFileFormat(tableHandle.getStorageProperties()); + + return storageFormat == IcebergFileFormat.ORC || storageFormat == IcebergFileFormat.PARQUET; + } + @Override public WriterScalingOptions getNewTableWriterScalingOptions(ConnectorSession session, SchemaTableName tableName, Map tableProperties) { diff --git a/plugin/trino-memory/src/main/java/io/trino/plugin/memory/MemoryMetadata.java b/plugin/trino-memory/src/main/java/io/trino/plugin/memory/MemoryMetadata.java index 164bdc7fd1e4..285bfaafcbce 100644 --- a/plugin/trino-memory/src/main/java/io/trino/plugin/memory/MemoryMetadata.java +++ b/plugin/trino-memory/src/main/java/io/trino/plugin/memory/MemoryMetadata.java @@ -587,6 +587,12 @@ public Optional> applySample(Conne true)); } + @Override + public boolean allowSplittingReadIntoMultipleSubQueries(ConnectorSession session, ConnectorTableHandle tableHandle) + { + return true; + } + @Override public synchronized void setTableComment(ConnectorSession session, ConnectorTableHandle tableHandle, Optional comment) { diff --git a/testing/trino-testing/src/main/java/io/trino/testing/AbstractTestAggregations.java b/testing/trino-testing/src/main/java/io/trino/testing/AbstractTestAggregations.java 
index b6584bb0c2a9..0de786afcfe9 100644 --- a/testing/trino-testing/src/main/java/io/trino/testing/AbstractTestAggregations.java +++ b/testing/trino-testing/src/main/java/io/trino/testing/AbstractTestAggregations.java @@ -334,6 +334,16 @@ public void testMultipleDifferentDistinct() assertQuery("SELECT COUNT(DISTINCT orderstatus), SUM(DISTINCT custkey) FROM orders"); } + @Test + public void testMultipleDifferentDistinctOverUnion() + { + assertQuery(""" + SELECT custkey, COUNT(DISTINCT orderkey), COUNT(DISTINCT orderstatus) + FROM (SELECT orderkey, orderstatus, custkey FROM orders WHERE orderstatus = 'O' + UNION ALL SELECT orderkey, orderstatus, custkey FROM orders WHERE orderstatus = 'F') + GROUP BY custkey"""); + } + @Test public void testMultipleDistinct() { diff --git a/testing/trino-tests/src/test/java/io/trino/tests/TestDistinctToSubqueriesAggregations.java b/testing/trino-tests/src/test/java/io/trino/tests/TestDistinctToSubqueriesAggregations.java new file mode 100644 index 000000000000..a75d5b2f071c --- /dev/null +++ b/testing/trino-tests/src/test/java/io/trino/tests/TestDistinctToSubqueriesAggregations.java @@ -0,0 +1,35 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.trino.tests; + +import com.google.common.collect.ImmutableMap; +import io.trino.plugin.memory.MemoryQueryRunner; +import io.trino.testing.AbstractTestAggregations; +import io.trino.testing.QueryRunner; +import io.trino.tpch.TpchTable; + +public class TestDistinctToSubqueriesAggregations + extends AbstractTestAggregations +{ + @Override + protected QueryRunner createQueryRunner() + throws Exception + { + // using memory connector, because it enables ConnectorMetadata#allowSplittingReadIntoMultipleSubQueries + return MemoryQueryRunner.builder() + .setInitialTables(TpchTable.getTables()) + .setCoordinatorProperties(ImmutableMap.of("optimizer.distinct-aggregations-strategy", "split_to_subqueries")) + .build(); + } +}