diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/TransformCorrelatedGlobalAggregationWithProjection.java b/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/TransformCorrelatedGlobalAggregationWithProjection.java index f1695304d009..c315c6a00003 100644 --- a/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/TransformCorrelatedGlobalAggregationWithProjection.java +++ b/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/TransformCorrelatedGlobalAggregationWithProjection.java @@ -65,7 +65,8 @@ /** * This rule decorrelates a correlated subquery of LEFT or INNER correlated join with: * - single global aggregation, or - * - global aggregation over distinct operator (grouped aggregation with no aggregation assignments) + * - global aggregation over distinct operator (grouped aggregation with no aggregation assignments), + * in case when the distinct operator cannot be de-correlated by PlanNodeDecorrelator *

* In the case of single aggregation, it transforms: *

@@ -152,17 +153,23 @@ public Result apply(CorrelatedJoinNode correlatedJoinNode, Captures captures, Co
 
         // if there is another aggregation below the AggregationNode, handle both
         PlanNode source = captures.get(SOURCE);
+
+        // if we fail to decorrelate the nested plan, and it contains a distinct operator, we can extract and special-handle the distinct operator
         AggregationNode distinct = null;
-        if (isDistinctOperator(source)) {
-            distinct = (AggregationNode) source;
-            source = distinct.getSource();
-        }
 
         // decorrelate nested plan
         PlanNodeDecorrelator decorrelator = new PlanNodeDecorrelator(metadata, context.getSymbolAllocator(), context.getLookup());
         Optional decorrelatedSource = decorrelator.decorrelateFilters(source, correlatedJoinNode.getCorrelation());
         if (decorrelatedSource.isEmpty()) {
-            return Result.empty();
+            // we failed to decorrelate the nested plan, so check if we can extract a distinct operator from the nested plan
+            if (isDistinctOperator(source)) {
+                distinct = (AggregationNode) source;
+                source = distinct.getSource();
+                decorrelatedSource = decorrelator.decorrelateFilters(source, correlatedJoinNode.getCorrelation());
+            }
+            if (decorrelatedSource.isEmpty()) {
+                return Result.empty();
+            }
         }
 
         source = decorrelatedSource.get().getNode();
diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/TransformCorrelatedGlobalAggregationWithoutProjection.java b/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/TransformCorrelatedGlobalAggregationWithoutProjection.java
index edc205c7c33d..1a4b94786f86 100644
--- a/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/TransformCorrelatedGlobalAggregationWithoutProjection.java
+++ b/presto-main/src/main/java/io/prestosql/sql/planner/iterative/rule/TransformCorrelatedGlobalAggregationWithoutProjection.java
@@ -62,7 +62,8 @@
 /**
  * This rule decorrelates a correlated subquery with:
  * - single global aggregation, or
- * - global aggregation over distinct operator (grouped aggregation with no aggregation assignments)
+ * - global aggregation over distinct operator (grouped aggregation with no aggregation assignments),
+ * in case when the distinct operator cannot be de-correlated by PlanNodeDecorrelator
  * It is similar to TransformCorrelatedGlobalAggregationWithProjection rule, but does not support projection over aggregation in the subquery
  * 

* In the case of single aggregation, it transforms: @@ -143,19 +144,24 @@ public Result apply(CorrelatedJoinNode correlatedJoinNode, Captures captures, Co { checkArgument(correlatedJoinNode.getType() == INNER || correlatedJoinNode.getType() == LEFT, "unexpected correlated join type: " + correlatedJoinNode.getType()); - // if there is another aggregation below the AggregationNode, handle both PlanNode source = captures.get(SOURCE); + + // if we fail to decorrelate the nested plan, and it contains a distinct operator, we can extract and special-handle the distinct operator AggregationNode distinct = null; - if (isDistinctOperator(source)) { - distinct = (AggregationNode) source; - source = distinct.getSource(); - } // decorrelate nested plan PlanNodeDecorrelator decorrelator = new PlanNodeDecorrelator(metadata, context.getSymbolAllocator(), context.getLookup()); Optional decorrelatedSource = decorrelator.decorrelateFilters(source, correlatedJoinNode.getCorrelation()); if (decorrelatedSource.isEmpty()) { - return Result.empty(); + // we failed to decorrelate the nested plan, so check if we can extract a distinct operator from the nested plan + if (isDistinctOperator(source)) { + distinct = (AggregationNode) source; + source = distinct.getSource(); + decorrelatedSource = decorrelator.decorrelateFilters(source, correlatedJoinNode.getCorrelation()); + } + if (decorrelatedSource.isEmpty()) { + return Result.empty(); + } } source = decorrelatedSource.get().getNode(); diff --git a/presto-main/src/test/java/io/prestosql/sql/planner/TestLogicalPlanner.java b/presto-main/src/test/java/io/prestosql/sql/planner/TestLogicalPlanner.java index d17964527f21..a373fb735497 100644 --- a/presto-main/src/test/java/io/prestosql/sql/planner/TestLogicalPlanner.java +++ b/presto-main/src/test/java/io/prestosql/sql/planner/TestLogicalPlanner.java @@ -927,6 +927,36 @@ public void testCorrelatedScalarAggregationRewriteToLeftOuterJoin() node(ValuesNode.class)))))))); } + @Test + public void testCorrelatedDistinctAggregationRewriteToLeftOuterJoin() + { + assertPlan( + "SELECT (SELECT count(DISTINCT o.orderkey) FROM orders o WHERE c.custkey = o.custkey), c.custkey FROM customer c", + output( + project(join( + INNER, + ImmutableList.of(), + join( + LEFT, + ImmutableList.of(equiJoinClause("c_custkey", "o_custkey")), + anyTree(tableScan("customer", ImmutableMap.of("c_custkey", "custkey"))), + anyTree(aggregation( + singleGroupingSet("o_custkey"), + ImmutableMap.of(Optional.of("count"), functionCall("count", ImmutableList.of("o_orderkey"))), + ImmutableList.of(), + ImmutableList.of("non_null"), + Optional.empty(), + SINGLE, + project(ImmutableMap.of("non_null", expression("true")), + aggregation( + singleGroupingSet("o_orderkey", "o_custkey"), + ImmutableMap.of(), + Optional.empty(), + FINAL, + anyTree(tableScan("orders", ImmutableMap.of("o_orderkey", "orderkey", "o_custkey", "custkey")))))))), + anyTree(node(ValuesNode.class)))))); + } + @Test public void testRemovesTrivialFilters() { diff --git a/presto-main/src/test/java/io/prestosql/sql/planner/iterative/rule/TestTransformCorrelatedGlobalAggregationWithProjection.java b/presto-main/src/test/java/io/prestosql/sql/planner/iterative/rule/TestTransformCorrelatedGlobalAggregationWithProjection.java index e48347316e04..de224c5d73b0 100644 --- a/presto-main/src/test/java/io/prestosql/sql/planner/iterative/rule/TestTransformCorrelatedGlobalAggregationWithProjection.java +++ b/presto-main/src/test/java/io/prestosql/sql/planner/iterative/rule/TestTransformCorrelatedGlobalAggregationWithProjection.java @@ -184,6 +184,54 @@ public void rewritesOnSubqueryWithDistinct() values("a", "b")))))))); } + @Test + public void rewritesOnSubqueryWithDecorrelatableDistinct() + { + // distinct aggregation can be decorrelated in the subquery by PlanNodeDecorrelator + // because the correlated predicate is equality comparison + tester().assertThat(new TransformCorrelatedGlobalAggregationWithProjection(tester().getMetadata())) + .on(p -> p.correlatedJoin( + ImmutableList.of(p.symbol("corr")), + p.values(p.symbol("corr")), + p.project( + Assignments.of(p.symbol("expr_sum"), PlanBuilder.expression("sum + 1"), p.symbol("expr_count"), PlanBuilder.expression("count - 1")), + p.aggregation(outerBuilder -> outerBuilder + .addAggregation(p.symbol("sum"), PlanBuilder.expression("sum(a)"), ImmutableList.of(BIGINT)) + .addAggregation(p.symbol("count"), PlanBuilder.expression("count()"), ImmutableList.of()) + .globalGrouping() + .source(p.aggregation(innerBuilder -> innerBuilder + .singleGroupingSet(p.symbol("a")) + .source(p.filter( + PlanBuilder.expression("b = corr"), + p.values(p.symbol("a"), p.symbol("b")))))))))) + .matches( + project(ImmutableMap.of("corr", expression("corr"), "expr_sum", expression("(sum_agg + 1)"), "expr_count", expression("count_agg - 1")), + aggregation( + singleGroupingSet("corr", "unique"), + ImmutableMap.of(Optional.of("sum_agg"), functionCall("sum", ImmutableList.of("a")), Optional.of("count_agg"), functionCall("count", ImmutableList.of())), + ImmutableList.of(), + ImmutableList.of("non_null"), + Optional.empty(), + SINGLE, + join( + LEFT, + ImmutableList.of(), + Optional.of("b = corr"), + assignUniqueId( + "unique", + values("corr")), + project( + ImmutableMap.of("non_null", expression("true")), + aggregation( + singleGroupingSet("a", "b"), + ImmutableMap.of(), + Optional.empty(), + SINGLE, + filter( + "true", + values("a", "b")))))))); + } + @Test public void testWithPreexistingMask() { diff --git a/presto-main/src/test/java/io/prestosql/sql/planner/iterative/rule/TestTransformCorrelatedGlobalAggregationWithoutProjection.java b/presto-main/src/test/java/io/prestosql/sql/planner/iterative/rule/TestTransformCorrelatedGlobalAggregationWithoutProjection.java index 0701772af169..201374ba39cf 100644 --- a/presto-main/src/test/java/io/prestosql/sql/planner/iterative/rule/TestTransformCorrelatedGlobalAggregationWithoutProjection.java +++ b/presto-main/src/test/java/io/prestosql/sql/planner/iterative/rule/TestTransformCorrelatedGlobalAggregationWithoutProjection.java @@ -207,6 +207,52 @@ public void rewritesOnSubqueryWithDistinct() values("a", "b")))))))); } + @Test + public void rewritesOnSubqueryWithDecorrelatableDistinct() + { + // distinct aggregation can be decorrelated in the subquery by PlanNodeDecorrelator + // because the correlated predicate is equality comparison + tester().assertThat(new TransformCorrelatedGlobalAggregationWithoutProjection(tester().getMetadata())) + .on(p -> p.correlatedJoin( + ImmutableList.of(p.symbol("corr")), + p.values(p.symbol("corr")), + p.aggregation(outerBuilder -> outerBuilder + .addAggregation(p.symbol("sum"), PlanBuilder.expression("sum(a)"), ImmutableList.of(BIGINT)) + .addAggregation(p.symbol("count"), PlanBuilder.expression("count()"), ImmutableList.of()) + .globalGrouping() + .source(p.aggregation(innerBuilder -> innerBuilder + .singleGroupingSet(p.symbol("a")) + .source(p.filter( + PlanBuilder.expression("b = corr"), + p.values(p.symbol("a"), p.symbol("b"))))))))) + .matches( + project(ImmutableMap.of("corr", expression("corr"), "sum_agg", expression("sum_agg"), "count_agg", expression("count_agg")), + aggregation( + singleGroupingSet("corr", "unique"), + ImmutableMap.of(Optional.of("sum_agg"), functionCall("sum", ImmutableList.of("a")), Optional.of("count_agg"), functionCall("count", ImmutableList.of())), + ImmutableList.of(), + ImmutableList.of("non_null"), + Optional.empty(), + SINGLE, + join( + LEFT, + ImmutableList.of(), + Optional.of("b = corr"), + assignUniqueId( + "unique", + values("corr")), + project( + ImmutableMap.of("non_null", expression("true")), + aggregation( + singleGroupingSet("a", "b"), + ImmutableMap.of(), + Optional.empty(), + SINGLE, + filter( + "true", + values("a", "b")))))))); + } + @Test public void testWithPreexistingMask() {