diff --git a/presto-docs/src/main/sphinx/presto_cpp/properties-session.rst b/presto-docs/src/main/sphinx/presto_cpp/properties-session.rst index cf47613188040..6394964ecfaaf 100644 --- a/presto-docs/src/main/sphinx/presto_cpp/properties-session.rst +++ b/presto-docs/src/main/sphinx/presto_cpp/properties-session.rst @@ -588,3 +588,17 @@ with StringView type during global aggregation. Native Execution only. Ratio of unused (evicted) bytes to total bytes that triggers compaction. The value is in the range of [0, 1). Currently only applies to approx_most_frequent aggregate with StringView type during global aggregation. + +``optimizer.optimize_top_n_rank`` +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +* **Type:** ``boolean`` +* **Default value:** ``false`` + +If this is true, then filter and limit queries for ``n`` rows of +``rank()`` and ``dense_rank()`` window function values are executed +with a special TopNRowNumber operator instead of the +WindowFunction operator. + +The TopNRowNumber operator is more efficient than window as +it has a streaming behavior and does not need to buffer all input rows. diff --git a/presto-main-base/src/main/java/com/facebook/presto/SystemSessionProperties.java b/presto-main-base/src/main/java/com/facebook/presto/SystemSessionProperties.java index f10e74930b2e3..5cda9d4108f2d 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/SystemSessionProperties.java +++ b/presto-main-base/src/main/java/com/facebook/presto/SystemSessionProperties.java @@ -187,6 +187,7 @@ public final class SystemSessionProperties public static final String ADAPTIVE_PARTIAL_AGGREGATION = "adaptive_partial_aggregation"; public static final String ADAPTIVE_PARTIAL_AGGREGATION_ROWS_REDUCTION_RATIO_THRESHOLD = "adaptive_partial_aggregation_unique_rows_ratio_threshold"; public static final String OPTIMIZE_TOP_N_ROW_NUMBER = "optimize_top_n_row_number"; + public static final String OPTIMIZE_TOP_N_RANK = "optimize_top_n_rank"; public static final String OPTIMIZE_CASE_EXPRESSION_PREDICATE = "optimize_case_expression_predicate"; public static final String MAX_GROUPING_SETS = "max_grouping_sets"; public static final String LEGACY_UNNEST = "legacy_unnest"; @@ -1004,6 +1005,11 @@ public SystemSessionProperties( "Use top N row number optimization", featuresConfig.isOptimizeTopNRowNumber(), false), + booleanProperty( + OPTIMIZE_TOP_N_RANK, + "Use top N rank and dense_rank optimization", + featuresConfig.isOptimizeTopNRank(), + false), booleanProperty( OPTIMIZE_CASE_EXPRESSION_PREDICATE, "Optimize case expression predicates", @@ -2694,6 +2700,11 @@ public static boolean isOptimizeTopNRowNumber(Session session) return session.getSystemProperty(OPTIMIZE_TOP_N_ROW_NUMBER, Boolean.class); } + public static boolean isOptimizeTopNRank(Session session) + { + return session.getSystemProperty(OPTIMIZE_TOP_N_RANK, Boolean.class); + } + public static boolean isOptimizeCaseExpressionPredicate(Session session) { return session.getSystemProperty(OPTIMIZE_CASE_EXPRESSION_PREDICATE, Boolean.class); diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/analyzer/FeaturesConfig.java b/presto-main-base/src/main/java/com/facebook/presto/sql/analyzer/FeaturesConfig.java index e4c73d4b16a14..97aeb7e08c210 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/analyzer/FeaturesConfig.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/analyzer/FeaturesConfig.java @@ -166,6 +166,8 @@ public class FeaturesConfig private boolean adaptivePartialAggregationEnabled; private double adaptivePartialAggregationRowsReductionRatioThreshold = 0.8; private boolean optimizeTopNRowNumber = true; + + private boolean optimizeTopNRank; private boolean pushLimitThroughOuterJoin = true; private boolean optimizeConstantGroupingKeys = true; @@ -1195,6 +1197,11 @@ public boolean isOptimizeTopNRowNumber() return optimizeTopNRowNumber; } + public boolean isOptimizeTopNRank() + { + return optimizeTopNRank; + } + @Config("optimizer.optimize-top-n-row-number") public FeaturesConfig setOptimizeTopNRowNumber(boolean optimizeTopNRowNumber) { @@ -1202,6 +1209,13 @@ public FeaturesConfig setOptimizeTopNRowNumber(boolean optimizeTopNRowNumber) return this; } + @Config("optimizer.optimize-top-n-rank") + public FeaturesConfig setOptimizeTopNRank(boolean optimizeTopNRank) + { + this.optimizeTopNRank = optimizeTopNRank; + return this; + } + public boolean isOptimizeCaseExpressionPredicate() { return optimizeCaseExpressionPredicate; diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/CanonicalPlanGenerator.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/CanonicalPlanGenerator.java index 6f5274eb36925..98f8d4f696599 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/CanonicalPlanGenerator.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/CanonicalPlanGenerator.java @@ -699,6 +699,7 @@ public Optional visitTopNRowNumber(TopNRowNumberNode node, Context con new DataOrganizationSpecification( partitionBy, node.getSpecification().getOrderingScheme().map(scheme -> getCanonicalOrderingScheme(scheme, context.getExpressions()))), + node.getRankingFunction(), rowNumberVariable, node.getMaxRowCountPerPartition(), node.isPartial(), diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/iterative/rule/MinMaxByToWindowFunction.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/iterative/rule/MinMaxByToWindowFunction.java index 93447d712844d..7a74946525608 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/iterative/rule/MinMaxByToWindowFunction.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/iterative/rule/MinMaxByToWindowFunction.java @@ -147,6 +147,7 @@ else if (!maxByAggregations.isEmpty() && minByAggregations.isEmpty()) { node.getStatsEquivalentPlanNode(), node.getSource(), dataOrganizationSpecification, + TopNRowNumberNode.RankingFunction.ROW_NUMBER, rowNumberVariable, 1, false, diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/AddExchanges.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/AddExchanges.java index 083b88040db61..1a788cf6b3e2b 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/AddExchanges.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/AddExchanges.java @@ -608,6 +608,7 @@ public PlanWithProperties visitTopNRowNumber(TopNRowNumberNode node, PreferredPr idAllocator.getNextId(), child.getNode(), node.getSpecification(), + node.getRankingFunction(), node.getRowNumberVariable(), node.getMaxRowCountPerPartition(), true, diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/HashGenerationOptimizer.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/HashGenerationOptimizer.java index 84f42ef283d6b..58850d0195326 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/HashGenerationOptimizer.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/HashGenerationOptimizer.java @@ -330,6 +330,7 @@ public PlanWithProperties visitTopNRowNumber(TopNRowNumberNode node, HashComputa node.getId(), child.getNode(), node.getSpecification(), + node.getRankingFunction(), node.getRowNumberVariable(), node.getMaxRowCountPerPartition(), node.isPartial(), diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/PlanNodeDecorrelator.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/PlanNodeDecorrelator.java index 9915c495762be..97bb384a588ce 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/PlanNodeDecorrelator.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/PlanNodeDecorrelator.java @@ -310,6 +310,7 @@ public Optional visitTopN(TopNNode node, Void context) new DataOrganizationSpecification( ImmutableList.copyOf(childDecorrelationResult.variablesToPropagate), Optional.of(orderingScheme)), + TopNRowNumberNode.RankingFunction.ROW_NUMBER, variableAllocator.newVariable("row_number", BIGINT), toIntExact(node.getCount()), false, diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/PruneUnreferencedOutputs.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/PruneUnreferencedOutputs.java index 201ec219823af..6ea24ac7b046e 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/PruneUnreferencedOutputs.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/PruneUnreferencedOutputs.java @@ -785,6 +785,7 @@ public PlanNode visitTopNRowNumber(TopNRowNumberNode node, RewriteContext node.getId(), context.rewrite(node.getSource()), canonicalizeAndDistinct(node.getSpecification()), + node.getRankingFunction(), canonicalize(node.getRowNumberVariable()), node.getMaxRowCountPerPartition(), node.isPartial(), diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/WindowFilterPushDown.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/WindowFilterPushDown.java index 337e9eb39df03..3164106d1c3f3 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/WindowFilterPushDown.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/optimizations/WindowFilterPushDown.java @@ -41,11 +41,14 @@ import com.facebook.presto.sql.relational.RowExpressionDeterminismEvaluator; import com.facebook.presto.sql.relational.RowExpressionDomainTranslator; import com.google.common.collect.ImmutableList; +import com.google.common.collect.Iterables; import java.util.Map; import java.util.Optional; import java.util.OptionalInt; +import static com.facebook.presto.SystemSessionProperties.isNativeExecutionEnabled; +import static com.facebook.presto.SystemSessionProperties.isOptimizeTopNRank; import static com.facebook.presto.SystemSessionProperties.isOptimizeTopNRowNumber; import static com.facebook.presto.common.predicate.Marker.Bound.BELOW; import static com.facebook.presto.common.type.BigintType.BIGINT; @@ -134,6 +137,12 @@ public PlanNode visitWindow(WindowNode node, RewriteContext context) return replaceChildren(node, ImmutableList.of(rewrittenSource)); } + private boolean canReplaceWithTopNRowNumber(WindowNode node) + { + return (canOptimizeRowNumberFunction(node, metadata.getFunctionAndTypeManager()) && isOptimizeTopNRowNumber(session)) || + (isNativeExecutionEnabled(session) && canOptimizeRankFunction(node, metadata.getFunctionAndTypeManager()) && isOptimizeTopNRank(session)); + } + @Override public PlanNode visitLimit(LimitNode node, RewriteContext context) { @@ -152,16 +161,22 @@ public PlanNode visitLimit(LimitNode node, RewriteContext context) planChanged = true; source = rowNumberNode; } - else if (source instanceof WindowNode && canOptimizeWindowFunction((WindowNode) source, metadata.getFunctionAndTypeManager()) && isOptimizeTopNRowNumber(session)) { + else if (source instanceof WindowNode) { WindowNode windowNode = (WindowNode) source; - // verify that unordered row_number window functions are replaced by RowNumberNode - verify(windowNode.getOrderingScheme().isPresent()); - TopNRowNumberNode topNRowNumberNode = convertToTopNRowNumber(windowNode, limit); - if (windowNode.getPartitionBy().isEmpty()) { - return topNRowNumberNode; + if (canReplaceWithTopNRowNumber(windowNode)) { + // Unordered row_number window functions are replaced by RowNumberNode and + // only rank/dense_rank with ordering schema are optimized. + verify(windowNode.getOrderingScheme().isPresent()); + + TopNRowNumberNode topNRowNumberNode = convertToTopNRowNumber(windowNode, limit); + planChanged = true; + // Limit can be entirely skipped for row_number without partitioning (not for rank/dense_rank). + if (windowNode.getPartitionBy().isEmpty() && + canOptimizeRowNumberFunction(windowNode, metadata.getFunctionAndTypeManager())) { + return topNRowNumberNode; + } + source = topNRowNumberNode; } - planChanged = true; - source = topNRowNumberNode; } return replaceChildren(node, ImmutableList.of(source)); } @@ -183,15 +198,17 @@ public PlanNode visitFilter(FilterNode node, RewriteContext context) return rewriteFilterSource(node, source, rowNumberVariable, upperBound.getAsInt()); } } - else if (source instanceof WindowNode && canOptimizeWindowFunction((WindowNode) source, metadata.getFunctionAndTypeManager()) && isOptimizeTopNRowNumber(session)) { + else if (source instanceof WindowNode) { WindowNode windowNode = (WindowNode) source; - VariableReferenceExpression rowNumberVariable = getOnlyElement(windowNode.getCreatedVariable()); - OptionalInt upperBound = extractUpperBound(tupleDomain, rowNumberVariable); - - if (upperBound.isPresent()) { - source = convertToTopNRowNumber(windowNode, upperBound.getAsInt()); - planChanged = true; - return rewriteFilterSource(node, source, rowNumberVariable, upperBound.getAsInt()); + if (canReplaceWithTopNRowNumber(windowNode)) { + VariableReferenceExpression rowNumberVariable = getOnlyElement(windowNode.getCreatedVariable()); + OptionalInt upperBound = extractUpperBound(tupleDomain, rowNumberVariable); + + if (upperBound.isPresent()) { + source = convertToTopNRowNumber(windowNode, upperBound.getAsInt()); + planChanged = true; + return rewriteFilterSource(node, source, rowNumberVariable, upperBound.getAsInt()); + } } } return replaceChildren(node, ImmutableList.of(source)); @@ -275,11 +292,30 @@ private static RowNumberNode mergeLimit(RowNumberNode node, int newRowCountPerPa private TopNRowNumberNode convertToTopNRowNumber(WindowNode windowNode, int limit) { + String windowFunction = Iterables.getOnlyElement(windowNode.getWindowFunctions().values()).getFunctionCall().getFunctionHandle().getName(); + String[] parts = windowFunction.split("\\."); + String windowFunctionName = parts[parts.length - 1]; + TopNRowNumberNode.RankingFunction rankingFunction; + switch (windowFunctionName) { + case "row_number": + rankingFunction = TopNRowNumberNode.RankingFunction.ROW_NUMBER; + break; + case "rank": + rankingFunction = TopNRowNumberNode.RankingFunction.RANK; + break; + case "dense_rank": + rankingFunction = TopNRowNumberNode.RankingFunction.DENSE_RANK; + break; + default: + throw new IllegalArgumentException("Unsupported window function for TopNRowNumberNode: " + windowFunctionName); + } + return new TopNRowNumberNode( windowNode.getSourceLocation(), idAllocator.getNextId(), windowNode.getSource(), windowNode.getSpecification(), + rankingFunction, getOnlyElement(windowNode.getCreatedVariable()), limit, false, @@ -288,16 +324,29 @@ private TopNRowNumberNode convertToTopNRowNumber(WindowNode windowNode, int limi private static boolean canReplaceWithRowNumber(WindowNode node, FunctionAndTypeManager functionAndTypeManager) { - return canOptimizeWindowFunction(node, functionAndTypeManager) && !node.getOrderingScheme().isPresent(); + return canOptimizeRowNumberFunction(node, functionAndTypeManager) && !node.getOrderingScheme().isPresent(); } - private static boolean canOptimizeWindowFunction(WindowNode node, FunctionAndTypeManager functionAndTypeManager) + private static boolean canOptimizeRowNumberFunction(WindowNode node, FunctionAndTypeManager functionAndTypeManager) { if (node.getWindowFunctions().size() != 1) { return false; } - VariableReferenceExpression rowNumberVariable = getOnlyElement(node.getWindowFunctions().keySet()); - return isRowNumberMetadata(functionAndTypeManager, functionAndTypeManager.getFunctionMetadata(node.getWindowFunctions().get(rowNumberVariable).getFunctionHandle())); + return isRowNumberMetadata(functionAndTypeManager, functionAndTypeManager.getFunctionMetadata(getOnlyElement(node.getWindowFunctions().values()).getFunctionHandle())); + } + + private static boolean canOptimizeRankFunction(WindowNode node, FunctionAndTypeManager functionAndTypeManager) + { + if (node.getWindowFunctions().size() != 1) { + return false; + } + + // This optimization requires an ordering scheme for the rank functions. + if (!node.getOrderingScheme().isPresent()) { + return false; + } + + return isRankMetadata(functionAndTypeManager, functionAndTypeManager.getFunctionMetadata(getOnlyElement(node.getWindowFunctions().values()).getFunctionHandle())); } private static boolean isRowNumberMetadata(FunctionAndTypeManager functionAndTypeManager, FunctionMetadata functionMetadata) @@ -305,5 +354,13 @@ private static boolean isRowNumberMetadata(FunctionAndTypeManager functionAndTyp FunctionHandle rowNumberFunction = functionAndTypeManager.lookupFunction("row_number", ImmutableList.of()); return functionMetadata.equals(functionAndTypeManager.getFunctionMetadata(rowNumberFunction)); } + + private static boolean isRankMetadata(FunctionAndTypeManager functionAndTypeManager, FunctionMetadata functionMetadata) + { + FunctionHandle rankFunction = functionAndTypeManager.lookupFunction("rank", ImmutableList.of()); + FunctionHandle denseRankFunction = functionAndTypeManager.lookupFunction("dense_rank", ImmutableList.of()); + return functionMetadata.equals(functionAndTypeManager.getFunctionMetadata(rankFunction)) || + functionMetadata.equals(functionAndTypeManager.getFunctionMetadata(denseRankFunction)); + } } } diff --git a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/plan/TopNRowNumberNode.java b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/plan/TopNRowNumberNode.java index f5fc3bde590a6..9969ad1a950ef 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/sql/planner/plan/TopNRowNumberNode.java +++ b/presto-main-base/src/main/java/com/facebook/presto/sql/planner/plan/TopNRowNumberNode.java @@ -35,8 +35,16 @@ public final class TopNRowNumberNode extends InternalPlanNode { + public enum RankingFunction + { + ROW_NUMBER, + RANK, + DENSE_RANK + } + private final PlanNode source; private final DataOrganizationSpecification specification; + private final RankingFunction rankingFunction; private final VariableReferenceExpression rowNumberVariable; private final int maxRowCountPerPartition; private final boolean partial; @@ -48,12 +56,13 @@ public TopNRowNumberNode( @JsonProperty("id") PlanNodeId id, @JsonProperty("source") PlanNode source, @JsonProperty("specification") DataOrganizationSpecification specification, + @JsonProperty("rankingType") RankingFunction rankingFunction, @JsonProperty("rowNumberVariable") VariableReferenceExpression rowNumberVariable, @JsonProperty("maxRowCountPerPartition") int maxRowCountPerPartition, @JsonProperty("partial") boolean partial, @JsonProperty("hashVariable") Optional hashVariable) { - this(sourceLocation, id, Optional.empty(), source, specification, rowNumberVariable, maxRowCountPerPartition, partial, hashVariable); + this(sourceLocation, id, Optional.empty(), source, specification, rankingFunction, rowNumberVariable, maxRowCountPerPartition, partial, hashVariable); } public TopNRowNumberNode( @@ -62,6 +71,7 @@ public TopNRowNumberNode( Optional statsEquivalentPlanNode, PlanNode source, DataOrganizationSpecification specification, + RankingFunction rankingFunction, VariableReferenceExpression rowNumberVariable, int maxRowCountPerPartition, boolean partial, @@ -75,9 +85,11 @@ public TopNRowNumberNode( requireNonNull(rowNumberVariable, "rowNumberVariable is null"); checkArgument(maxRowCountPerPartition > 0, "maxRowCountPerPartition must be > 0"); requireNonNull(hashVariable, "hashVariable is null"); + requireNonNull(rankingFunction, "rankingFunction is null"); this.source = source; this.specification = specification; + this.rankingFunction = rankingFunction; this.rowNumberVariable = rowNumberVariable; this.maxRowCountPerPartition = maxRowCountPerPartition; this.partial = partial; @@ -113,6 +125,12 @@ public DataOrganizationSpecification getSpecification() return specification; } + @JsonProperty + public RankingFunction getRankingFunction() + { + return rankingFunction; + } + public List getPartitionBy() { return specification.getPartitionBy(); @@ -156,12 +174,12 @@ public R accept(InternalPlanVisitor visitor, C context) @Override public PlanNode replaceChildren(List newChildren) { - return new TopNRowNumberNode(getSourceLocation(), getId(), getStatsEquivalentPlanNode(), Iterables.getOnlyElement(newChildren), specification, rowNumberVariable, maxRowCountPerPartition, partial, hashVariable); + return new TopNRowNumberNode(getSourceLocation(), getId(), getStatsEquivalentPlanNode(), Iterables.getOnlyElement(newChildren), specification, rankingFunction, rowNumberVariable, maxRowCountPerPartition, partial, hashVariable); } @Override public PlanNode assignStatsEquivalentPlanNode(Optional statsEquivalentPlanNode) { - return new TopNRowNumberNode(getSourceLocation(), getId(), statsEquivalentPlanNode, source, specification, rowNumberVariable, maxRowCountPerPartition, partial, hashVariable); + return new TopNRowNumberNode(getSourceLocation(), getId(), statsEquivalentPlanNode, source, specification, rankingFunction, rowNumberVariable, maxRowCountPerPartition, partial, hashVariable); } } diff --git a/presto-main-base/src/main/java/com/facebook/presto/util/GraphvizPrinter.java b/presto-main-base/src/main/java/com/facebook/presto/util/GraphvizPrinter.java index 5368f400f4646..1f2a47c8d4083 100644 --- a/presto-main-base/src/main/java/com/facebook/presto/util/GraphvizPrinter.java +++ b/presto-main-base/src/main/java/com/facebook/presto/util/GraphvizPrinter.java @@ -448,9 +448,11 @@ public Void visitTopNRowNumber(TopNRowNumberNode node, Void context) { printNode(node, "TopNRowNumber", - format("partition by = %s|order by = %s|n = %s", + format("function = %s|partition by = %s|order by = %s|n = %s", + node.getRankingFunction(), Joiner.on(", ").join(node.getPartitionBy()), - Joiner.on(", ").join(node.getOrderingScheme().getOrderByVariables()), node.getMaxRowCountPerPartition()), + Joiner.on(", ").join(node.getOrderingScheme().getOrderByVariables()), + node.getMaxRowCountPerPartition()), NODE_COLORS.get(NodeType.WINDOW)); return node.getSource().accept(this, context); } diff --git a/presto-main-base/src/test/java/com/facebook/presto/sql/analyzer/TestFeaturesConfig.java b/presto-main-base/src/test/java/com/facebook/presto/sql/analyzer/TestFeaturesConfig.java index 6114e53c95d8f..c07757589c89d 100644 --- a/presto-main-base/src/test/java/com/facebook/presto/sql/analyzer/TestFeaturesConfig.java +++ b/presto-main-base/src/test/java/com/facebook/presto/sql/analyzer/TestFeaturesConfig.java @@ -149,6 +149,7 @@ public void testDefaults() .setAdaptivePartialAggregationRowsReductionRatioThreshold(0.8) .setLocalExchangeParentPreferenceStrategy(LocalExchangeParentPreferenceStrategy.ALWAYS) .setOptimizeTopNRowNumber(true) + .setOptimizeTopNRank(false) .setOptimizeCaseExpressionPredicate(false) .setDistributedSortEnabled(true) .setMaxGroupingSets(2048) @@ -384,6 +385,7 @@ public void testExplicitPropertyMappings() .put("experimental.adaptive-partial-aggregation", "true") .put("experimental.adaptive-partial-aggregation-rows-reduction-ratio-threshold", "0.9") .put("optimizer.optimize-top-n-row-number", "false") + .put("optimizer.optimize-top-n-rank", "true") .put("optimizer.optimize-case-expression-predicate", "true") .put("distributed-sort", "false") .put("analyzer.max-grouping-sets", "2047") @@ -614,6 +616,7 @@ public void testExplicitPropertyMappings() .setAdaptivePartialAggregationRowsReductionRatioThreshold(0.9) .setLocalExchangeParentPreferenceStrategy(LocalExchangeParentPreferenceStrategy.AUTOMATIC) .setOptimizeTopNRowNumber(false) + .setOptimizeTopNRank(true) .setOptimizeCaseExpressionPredicate(true) .setDistributedSortEnabled(false) .setMaxGroupingSets(2047) diff --git a/presto-main-base/src/test/java/com/facebook/presto/sql/planner/optimizations/TestWindowFilterPushDown.java b/presto-main-base/src/test/java/com/facebook/presto/sql/planner/optimizations/TestWindowFilterPushDown.java index 3d0b32fceeb34..fefc10a4d78a9 100644 --- a/presto-main-base/src/test/java/com/facebook/presto/sql/planner/optimizations/TestWindowFilterPushDown.java +++ b/presto-main-base/src/test/java/com/facebook/presto/sql/planner/optimizations/TestWindowFilterPushDown.java @@ -21,6 +21,8 @@ import org.intellij.lang.annotations.Language; import org.testng.annotations.Test; +import static com.facebook.presto.SystemSessionProperties.NATIVE_EXECUTION_ENABLED; +import static com.facebook.presto.SystemSessionProperties.OPTIMIZE_TOP_N_RANK; import static com.facebook.presto.SystemSessionProperties.OPTIMIZE_TOP_N_ROW_NUMBER; import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.anyNot; import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.anyTree; @@ -31,15 +33,11 @@ public class TestWindowFilterPushDown extends BasePlanTest { - @Test - public void testLimitAboveWindow() + private void testLimitSql(String sql, boolean rowNumber) { - @Language("SQL") String sql = "SELECT " + - "row_number() OVER (PARTITION BY suppkey ORDER BY orderkey) partition_row_number FROM lineitem LIMIT 10"; - assertPlanWithSession( sql, - optimizeTopNRowNumber(true), + rowNumber ? optimizeTopNRowNumber(true) : optimizeTopNRank(true), true, anyTree( limit(10, anyTree( @@ -49,25 +47,47 @@ public void testLimitAboveWindow() assertPlanWithSession( sql, - optimizeTopNRowNumber(false), + rowNumber ? optimizeTopNRowNumber(false) : optimizeTopNRank(false), true, anyTree( limit(10, anyTree( node(WindowNode.class, anyTree( tableScan("lineitem"))))))); - } + if (!rowNumber) { + assertPlanWithSession( + sql, + optimizeTopNRankWithoutNative(true), + true, + anyTree( + limit(10, anyTree( + node(WindowNode.class, + anyTree( + tableScan("lineitem"))))))); + } + } @Test - public void testFilterAboveWindow() + public void testLimitAboveWindow() { - @Language("SQL") String sql = "SELECT * FROM " + - "(SELECT row_number() OVER (PARTITION BY suppkey ORDER BY orderkey) partition_row_number FROM lineitem) " + - "WHERE partition_row_number < 10"; + @Language("SQL") String sql = "SELECT " + + "row_number() OVER (PARTITION BY suppkey ORDER BY orderkey) partition_row_number FROM lineitem LIMIT 10"; + testLimitSql(sql, true); + + sql = "SELECT " + + "rank() OVER (PARTITION BY suppkey ORDER BY orderkey) partition_row_number FROM lineitem LIMIT 10"; + testLimitSql(sql, false); + sql = "SELECT " + + "dense_rank() OVER (PARTITION BY suppkey ORDER BY orderkey) partition_row_number FROM lineitem LIMIT 10"; + testLimitSql(sql, false); + } + + private void testFilterSql(String sql, boolean rowNumber) + { assertPlanWithSession( sql, - optimizeTopNRowNumber(true), + rowNumber ? optimizeTopNRowNumber(true) : optimizeTopNRank(true), true, anyTree( anyNot(FilterNode.class, @@ -77,7 +97,7 @@ public void testFilterAboveWindow() assertPlanWithSession( sql, - optimizeTopNRowNumber(false), + rowNumber ? optimizeTopNRowNumber(false) : optimizeTopNRank(false), true, anyTree( node(FilterNode.class, @@ -85,6 +105,38 @@ public void testFilterAboveWindow() node(WindowNode.class, anyTree( tableScan("lineitem"))))))); + + if (!rowNumber) { + assertPlanWithSession( + sql, + optimizeTopNRankWithoutNative(true), + true, + anyTree( + node(FilterNode.class, + anyTree( + node(WindowNode.class, + anyTree( + tableScan("lineitem"))))))); + } + } + @Test + public void testFilterAboveWindow() + { + @Language("SQL") String sql = "SELECT * FROM " + + "(SELECT row_number() OVER (PARTITION BY suppkey ORDER BY orderkey) partition_row_number FROM lineitem) " + + "WHERE partition_row_number < 10"; + + testFilterSql(sql, true); + + sql = "SELECT * FROM " + + "(SELECT rank() OVER (PARTITION BY suppkey ORDER BY orderkey) partition_rank FROM lineitem) " + + "WHERE partition_rank < 10"; + testFilterSql(sql, false); + + sql = "SELECT * FROM " + + "(SELECT dense_rank() OVER (PARTITION BY suppkey ORDER BY orderkey) partition_dense_rank FROM lineitem) " + + "WHERE partition_dense_rank < 10"; + testFilterSql(sql, false); } private Session optimizeTopNRowNumber(boolean enabled) @@ -93,4 +145,20 @@ private Session optimizeTopNRowNumber(boolean enabled) .setSystemProperty(OPTIMIZE_TOP_N_ROW_NUMBER, Boolean.toString(enabled)) .build(); } + + private Session optimizeTopNRank(boolean enabled) + { + return Session.builder(this.getQueryRunner().getDefaultSession()) + .setSystemProperty(NATIVE_EXECUTION_ENABLED, Boolean.toString(enabled)) + .setSystemProperty(OPTIMIZE_TOP_N_RANK, Boolean.toString(enabled)) + .build(); + } + + private Session optimizeTopNRankWithoutNative(boolean enabled) + { + return Session.builder(this.getQueryRunner().getDefaultSession()) + .setSystemProperty(NATIVE_EXECUTION_ENABLED, Boolean.toString(false)) + .setSystemProperty(OPTIMIZE_TOP_N_RANK, Boolean.toString(enabled)) + .build(); + } } diff --git a/presto-native-execution/presto_cpp/main/types/PrestoToVeloxQueryPlan.cpp b/presto-native-execution/presto_cpp/main/types/PrestoToVeloxQueryPlan.cpp index 519e8743cd3d1..0b4dcd31cffb2 100644 --- a/presto-native-execution/presto_cpp/main/types/PrestoToVeloxQueryPlan.cpp +++ b/presto-native-execution/presto_cpp/main/types/PrestoToVeloxQueryPlan.cpp @@ -1836,6 +1836,22 @@ VeloxQueryPlanConverterBase::toVeloxQueryPlan( toVeloxQueryPlan(node->source, tableWriteInfo, taskId)); } +namespace { +core::TopNRowNumberNode::RankFunction prestoToVeloxRankFunction( + protocol::RankingFunction rankingFunction) { + switch (rankingFunction) { + case protocol::RankingFunction::ROW_NUMBER: + return core::TopNRowNumberNode::RankFunction::kRowNumber; + case protocol::RankingFunction::RANK: + return core::TopNRowNumberNode::RankFunction::kRank; + case protocol::RankingFunction::DENSE_RANK: + return core::TopNRowNumberNode::RankFunction::kDenseRank; + default: + VELOX_UNREACHABLE(); + } +} +}; // namespace + std::shared_ptr VeloxQueryPlanConverterBase::toVeloxQueryPlan( const std::shared_ptr& node, @@ -1871,7 +1887,7 @@ VeloxQueryPlanConverterBase::toVeloxQueryPlan( return std::make_shared( node->id, - core::TopNRowNumberNode::RankFunction::kRowNumber, + prestoToVeloxRankFunction(node->rankingType), partitionFields, sortFields, sortOrders, diff --git a/presto-native-execution/presto_cpp/presto_protocol/core/presto_protocol_core.cpp b/presto-native-execution/presto_cpp/presto_protocol/core/presto_protocol_core.cpp index 2444c3617a002..6df1d0c478daf 100644 --- a/presto-native-execution/presto_cpp/presto_protocol/core/presto_protocol_core.cpp +++ b/presto-native-execution/presto_cpp/presto_protocol/core/presto_protocol_core.cpp @@ -12177,6 +12177,44 @@ void from_json(const json& j, TopNNode& p) { } } // namespace facebook::presto::protocol namespace facebook::presto::protocol { +// Loosly copied this here from NLOHMANN_JSON_SERIALIZE_ENUM() + +// NOLINTNEXTLINE: cppcoreguidelines-avoid-c-arrays +static const std::pair RankingFunction_enum_table[] = + { // NOLINT: cert-err58-cpp + {RankingFunction::ROW_NUMBER, "ROW_NUMBER"}, + {RankingFunction::RANK, "RANK"}, + {RankingFunction::DENSE_RANK, "DENSE_RANK"}}; +void to_json(json& j, const RankingFunction& e) { + static_assert( + std::is_enum::value, "RankingFunction must be an enum!"); + const auto* it = std::find_if( + std::begin(RankingFunction_enum_table), + std::end(RankingFunction_enum_table), + [e](const std::pair& ej_pair) -> bool { + return ej_pair.first == e; + }); + j = ((it != std::end(RankingFunction_enum_table)) + ? it + : std::begin(RankingFunction_enum_table)) + ->second; +} +void from_json(const json& j, RankingFunction& e) { + static_assert( + std::is_enum::value, "RankingFunction must be an enum!"); + const auto* it = std::find_if( + std::begin(RankingFunction_enum_table), + std::end(RankingFunction_enum_table), + [&j](const std::pair& ej_pair) -> bool { + return ej_pair.second == j; + }); + e = ((it != std::end(RankingFunction_enum_table)) + ? it + : std::begin(RankingFunction_enum_table)) + ->first; +} +} // namespace facebook::presto::protocol +namespace facebook::presto::protocol { TopNRowNumberNode::TopNRowNumberNode() noexcept { _type = "com.facebook.presto.sql.planner.plan.TopNRowNumberNode"; } @@ -12193,6 +12231,13 @@ void to_json(json& j, const TopNRowNumberNode& p) { "TopNRowNumberNode", "DataOrganizationSpecification", "specification"); + to_json_key( + j, + "rankingType", + p.rankingType, + "TopNRowNumberNode", + "RankingFunction", + "rankingType"); to_json_key( j, "rowNumberVariable", @@ -12229,6 +12274,13 @@ void from_json(const json& j, TopNRowNumberNode& p) { "TopNRowNumberNode", "DataOrganizationSpecification", "specification"); + from_json_key( + j, + "rankingType", + p.rankingType, + "TopNRowNumberNode", + "RankingFunction", + "rankingType"); from_json_key( j, "rowNumberVariable", diff --git a/presto-native-execution/presto_cpp/presto_protocol/core/presto_protocol_core.h b/presto-native-execution/presto_cpp/presto_protocol/core/presto_protocol_core.h index 07ccf8bdd09a7..82e2a33963827 100644 --- a/presto-native-execution/presto_cpp/presto_protocol/core/presto_protocol_core.h +++ b/presto-native-execution/presto_cpp/presto_protocol/core/presto_protocol_core.h @@ -2627,9 +2627,15 @@ void to_json(json& j, const TopNNode& p); void from_json(const json& j, TopNNode& p); } // namespace facebook::presto::protocol namespace facebook::presto::protocol { +enum class RankingFunction { ROW_NUMBER, RANK, DENSE_RANK }; +extern void to_json(json& j, const RankingFunction& e); +extern void from_json(const json& j, RankingFunction& e); +} // namespace facebook::presto::protocol +namespace facebook::presto::protocol { struct TopNRowNumberNode : public PlanNode { std::shared_ptr source = {}; DataOrganizationSpecification specification = {}; + RankingFunction rankingType = {}; VariableReferenceExpression rowNumberVariable = {}; int maxRowCountPerPartition = {}; bool partial = {}; diff --git a/presto-native-execution/src/test/java/com/facebook/presto/nativeworker/AbstractTestNativeWindowQueries.java b/presto-native-execution/src/test/java/com/facebook/presto/nativeworker/AbstractTestNativeWindowQueries.java index 781b9088c519c..b3e45f8c2a16a 100644 --- a/presto-native-execution/src/test/java/com/facebook/presto/nativeworker/AbstractTestNativeWindowQueries.java +++ b/presto-native-execution/src/test/java/com/facebook/presto/nativeworker/AbstractTestNativeWindowQueries.java @@ -13,6 +13,10 @@ */ package com.facebook.presto.nativeworker; +import com.facebook.presto.SystemSessionProperties; +import com.facebook.presto.spi.plan.FilterNode; +import com.facebook.presto.sql.planner.assertions.PlanMatchPattern; +import com.facebook.presto.sql.planner.plan.TopNRowNumberNode; import com.facebook.presto.testing.QueryRunner; import com.facebook.presto.tests.AbstractTestQueryFramework; import com.google.common.collect.ImmutableList; @@ -24,6 +28,11 @@ import static com.facebook.presto.nativeworker.NativeQueryRunnerUtils.createLineitem; import static com.facebook.presto.nativeworker.NativeQueryRunnerUtils.createOrders; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.anyNot; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.anyTree; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.limit; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.node; +import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.tableScan; public abstract class AbstractTestNativeWindowQueries extends AbstractTestQueryFramework @@ -179,6 +188,61 @@ public void testRowNumberWithFilter_2() assertQuery("SELECT * FROM (SELECT row_number() over(partition by orderstatus order by orderkey) rn, * from orders) WHERE rn = 1"); } + private static final PlanMatchPattern topNForFilter = anyTree( + anyNot(FilterNode.class, + node(TopNRowNumberNode.class, + anyTree( + tableScan("orders"))))); + + private static final PlanMatchPattern topNForLimit = anyTree( + limit(10, + anyTree( + node(TopNRowNumberNode.class, + anyTree( + tableScan("orders")))))); + @Test + public void testTopNRowNumber() + { + String sql = "SELECT sum(rn) FROM (SELECT row_number() over(PARTITION BY orderdate ORDER BY totalprice) rn, * from orders) WHERE rn <= 10"; + assertQuery(sql); + assertPlan(sql, topNForFilter); + + // Cannot test results for this query as they are not guaranteed to be the same due to lack of ORDER BY in LIMIT. + // But adding an ORDER BY would prevent the TopNRowNumber optimization from being applied. + sql = "SELECT sum(rn) FROM (SELECT row_number() over(PARTITION BY orderdate ORDER BY totalprice) rn, * from orders limit 10)"; + assertPlan(sql, topNForLimit); + } + + @Test + public void testTopNRank() + { + String sql = "SELECT sum(rn) FROM (SELECT rank() over(PARTITION BY orderdate ORDER BY totalprice) rn, * from orders) WHERE rn <= 10"; + assertQuery(sql); + + if (SystemSessionProperties.isOptimizeTopNRank(getSession())) { + assertPlan(sql, topNForFilter); + // Cannot test results for this query as they are not guaranteed to be the same due to lack of ORDER BY in LIMIT. + // But adding an ORDER BY would prevent the TopNRowNumber optimization from being applied. + sql = "SELECT sum(rn) FROM (SELECT rank() over(PARTITION BY orderdate ORDER BY totalprice) rn, * from orders limit 10)"; + assertPlan(sql, topNForLimit); + } + } + + @Test + public void testTopNDenseRank() + { + String sql = "SELECT sum(rn) FROM (SELECT dense_rank() over(PARTITION BY orderdate ORDER BY totalprice) rn, * from orders) WHERE rn <= 10"; + assertQuery(sql); + if (SystemSessionProperties.isOptimizeTopNRank(getSession())) { + assertPlan(sql, topNForFilter); + + // Cannot test results for this query as they are not guaranteed to be the same due to lack of ORDER BY in LIMIT. + // But adding an ORDER BY would prevent the TopNRowNumber optimization from being applied. + sql = "SELECT dense_rank() over(PARTITION BY orderdate ORDER BY totalprice) rn, * from orders limit 10"; + assertPlan(sql, topNForLimit); + } + } + @Test public void testFirstValueOrderKey() { diff --git a/presto-native-execution/src/test/java/com/facebook/presto/nativeworker/TestPrestoNativeTpcdsQueriesOrcUsingThrift.java b/presto-native-execution/src/test/java/com/facebook/presto/nativeworker/TestPrestoNativeTpcdsQueriesOrcUsingThrift.java index dfa8a91db721b..25dd8616d3ce9 100644 --- a/presto-native-execution/src/test/java/com/facebook/presto/nativeworker/TestPrestoNativeTpcdsQueriesOrcUsingThrift.java +++ b/presto-native-execution/src/test/java/com/facebook/presto/nativeworker/TestPrestoNativeTpcdsQueriesOrcUsingThrift.java @@ -15,6 +15,7 @@ import com.facebook.presto.testing.ExpectedQueryRunner; import com.facebook.presto.testing.QueryRunner; +import com.google.common.collect.ImmutableMap; import org.testng.annotations.Test; @Test(groups = {"orc"}) @@ -29,6 +30,7 @@ protected QueryRunner createQueryRunner() .setStorageFormat("ORC") .setAddStorageFormatToPath(true) .setUseThrift(true) + .setExtraCoordinatorProperties(ImmutableMap.of("optimizer.optimize-top-n-rank", "true")) .build(); } diff --git a/presto-native-execution/src/test/java/com/facebook/presto/nativeworker/TestPrestoNativeWindowQueries.java b/presto-native-execution/src/test/java/com/facebook/presto/nativeworker/TestPrestoNativeWindowQueries.java index 908fe3ea0c605..91018ebd46ca4 100644 --- a/presto-native-execution/src/test/java/com/facebook/presto/nativeworker/TestPrestoNativeWindowQueries.java +++ b/presto-native-execution/src/test/java/com/facebook/presto/nativeworker/TestPrestoNativeWindowQueries.java @@ -15,6 +15,7 @@ import com.facebook.presto.testing.ExpectedQueryRunner; import com.facebook.presto.testing.QueryRunner; +import com.google.common.collect.ImmutableMap; public class TestPrestoNativeWindowQueries extends AbstractTestNativeWindowQueries @@ -25,6 +26,7 @@ protected QueryRunner createQueryRunner() throws Exception return PrestoNativeQueryRunnerUtils.nativeHiveQueryRunnerBuilder() .setAddStorageFormatToPath(true) .setUseThrift(true) + .setExtraCoordinatorProperties(ImmutableMap.of("optimizer.optimize-top-n-rank", "true")) .build(); } diff --git a/presto-native-tests/src/test/java/com/facebook/presto/nativetests/TestTextReaderWithTpcdsQueriesUsingThrift.java b/presto-native-tests/src/test/java/com/facebook/presto/nativetests/TestTextReaderWithTpcdsQueriesUsingThrift.java index 25686e175581c..b317387ffdd94 100644 --- a/presto-native-tests/src/test/java/com/facebook/presto/nativetests/TestTextReaderWithTpcdsQueriesUsingThrift.java +++ b/presto-native-tests/src/test/java/com/facebook/presto/nativetests/TestTextReaderWithTpcdsQueriesUsingThrift.java @@ -17,6 +17,7 @@ import com.facebook.presto.nativeworker.PrestoNativeQueryRunnerUtils; import com.facebook.presto.testing.ExpectedQueryRunner; import com.facebook.presto.testing.QueryRunner; +import com.google.common.collect.ImmutableMap; public class TestTextReaderWithTpcdsQueriesUsingThrift extends AbstractTestNativeTpcdsQueries @@ -31,6 +32,7 @@ protected QueryRunner createQueryRunner() .setStorageFormat(TEXTFILE) .setAddStorageFormatToPath(true) .setUseThrift(true) + .setExtraCoordinatorProperties(ImmutableMap.of("optimizer.optimize-top-n-rank", "true")) .build(); }