diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/TransformCorrelatedGlobalAggregationWithProjection.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/TransformCorrelatedGlobalAggregationWithProjection.java
index e6b3acce8bc9..6966153344fb 100644
--- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/TransformCorrelatedGlobalAggregationWithProjection.java
+++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/TransformCorrelatedGlobalAggregationWithProjection.java
@@ -66,7 +66,8 @@
/**
* This rule decorrelates a correlated subquery of LEFT or INNER correlated join with:
* - single global aggregation, or
- * - global aggregation over distinct operator (grouped aggregation with no aggregation assignments)
+ * - global aggregation over distinct operator (grouped aggregation with no aggregation assignments),
+ * when the distinct operator cannot be decorrelated by PlanNodeDecorrelator
*
* In the case of single aggregation, it transforms:
*
@@ -153,17 +154,23 @@ public Result apply(CorrelatedJoinNode correlatedJoinNode, Captures captures, Co
// if there is another aggregation below the AggregationNode, handle both
PlanNode source = captures.get(SOURCE);
+
+ // if we fail to decorrelate the nested plan, and it contains a distinct operator, we can extract and special-handle the distinct operator
AggregationNode distinct = null;
- if (isDistinctOperator(source)) {
- distinct = (AggregationNode) source;
- source = distinct.getSource();
- }
// decorrelate nested plan
PlanNodeDecorrelator decorrelator = new PlanNodeDecorrelator(plannerContext, context.getSymbolAllocator(), context.getLookup());
Optional decorrelatedSource = decorrelator.decorrelateFilters(source, correlatedJoinNode.getCorrelation());
if (decorrelatedSource.isEmpty()) {
- return Result.empty();
+ // we failed to decorrelate the nested plan, so check if we can extract a distinct operator from the nested plan
+ if (isDistinctOperator(source)) {
+ distinct = (AggregationNode) source;
+ source = distinct.getSource();
+ decorrelatedSource = decorrelator.decorrelateFilters(source, correlatedJoinNode.getCorrelation());
+ }
+ if (decorrelatedSource.isEmpty()) {
+ return Result.empty();
+ }
}
source = decorrelatedSource.get().getNode();
diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/TransformCorrelatedGlobalAggregationWithoutProjection.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/TransformCorrelatedGlobalAggregationWithoutProjection.java
index 4cb37fdb4685..b8420729e128 100644
--- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/TransformCorrelatedGlobalAggregationWithoutProjection.java
+++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/TransformCorrelatedGlobalAggregationWithoutProjection.java
@@ -63,7 +63,8 @@
/**
* This rule decorrelates a correlated subquery with:
* - single global aggregation, or
- * - global aggregation over distinct operator (grouped aggregation with no aggregation assignments)
+ * - global aggregation over distinct operator (grouped aggregation with no aggregation assignments),
+ * when the distinct operator cannot be decorrelated by PlanNodeDecorrelator
* It is similar to TransformCorrelatedGlobalAggregationWithProjection rule, but does not support projection over aggregation in the subquery
*
* In the case of single aggregation, it transforms:
@@ -144,19 +145,24 @@ public Result apply(CorrelatedJoinNode correlatedJoinNode, Captures captures, Co
{
checkArgument(correlatedJoinNode.getType() == INNER || correlatedJoinNode.getType() == LEFT, "unexpected correlated join type: " + correlatedJoinNode.getType());
- // if there is another aggregation below the AggregationNode, handle both
PlanNode source = captures.get(SOURCE);
+
+ // if we fail to decorrelate the nested plan, and it contains a distinct operator, we can extract and special-handle the distinct operator
AggregationNode distinct = null;
- if (isDistinctOperator(source)) {
- distinct = (AggregationNode) source;
- source = distinct.getSource();
- }
// decorrelate nested plan
PlanNodeDecorrelator decorrelator = new PlanNodeDecorrelator(plannerContext, context.getSymbolAllocator(), context.getLookup());
Optional decorrelatedSource = decorrelator.decorrelateFilters(source, correlatedJoinNode.getCorrelation());
if (decorrelatedSource.isEmpty()) {
- return Result.empty();
+ // we failed to decorrelate the nested plan, so check if we can extract a distinct operator from the nested plan
+ if (isDistinctOperator(source)) {
+ distinct = (AggregationNode) source;
+ source = distinct.getSource();
+ decorrelatedSource = decorrelator.decorrelateFilters(source, correlatedJoinNode.getCorrelation());
+ }
+ if (decorrelatedSource.isEmpty()) {
+ return Result.empty();
+ }
}
source = decorrelatedSource.get().getNode();
diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/TestLogicalPlanner.java b/core/trino-main/src/test/java/io/trino/sql/planner/TestLogicalPlanner.java
index b523aa84e1ab..0766a04d4280 100644
--- a/core/trino-main/src/test/java/io/trino/sql/planner/TestLogicalPlanner.java
+++ b/core/trino-main/src/test/java/io/trino/sql/planner/TestLogicalPlanner.java
@@ -965,6 +965,36 @@ public void testCorrelatedScalarAggregationRewriteToLeftOuterJoin()
node(ValuesNode.class)))))));
}
+ @Test
+ public void testCorrelatedDistinctAggregationRewriteToLeftOuterJoin()
+ { // asserts the plan shape for a correlated count(DISTINCT ...) scalar subquery whose correlation is an equality predicate
+ assertPlan(
+ "SELECT (SELECT count(DISTINCT o.orderkey) FROM orders o WHERE c.custkey = o.custkey), c.custkey FROM customer c",
+ output(
+ project(join(
+ INNER,
+ ImmutableList.of(),
+ join(
+ LEFT,
+ ImmutableList.of(equiJoinClause("c_custkey", "o_custkey")), // expected: the correlated predicate c.custkey = o.custkey shows up as an equi-join clause
+ anyTree(tableScan("customer", ImmutableMap.of("c_custkey", "custkey"))),
+ anyTree(aggregation(
+ singleGroupingSet("o_custkey"), // expected: the subquery's count is grouped by the correlation key
+ ImmutableMap.of(Optional.of("count"), functionCall("count", ImmutableList.of("o_orderkey"))),
+ ImmutableList.of(),
+ ImmutableList.of("non_null"),
+ Optional.empty(),
+ SINGLE,
+ project(ImmutableMap.of("non_null", expression("true")),
+ aggregation(
+ singleGroupingSet("o_orderkey", "o_custkey"), // expected distinct operator: grouping-only aggregation over the distinct column plus the correlation key
+ ImmutableMap.of(),
+ Optional.empty(),
+ FINAL,
+ anyTree(tableScan("orders", ImmutableMap.of("o_orderkey", "orderkey", "o_custkey", "custkey")))))))),
+ anyTree(node(ValuesNode.class))))));
+ }
+
@Test
public void testRemovesTrivialFilters()
{
diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestTransformCorrelatedGlobalAggregationWithProjection.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestTransformCorrelatedGlobalAggregationWithProjection.java
index 66187e6bb0e0..6463ec8aefae 100644
--- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestTransformCorrelatedGlobalAggregationWithProjection.java
+++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestTransformCorrelatedGlobalAggregationWithProjection.java
@@ -184,6 +184,54 @@ public void rewritesOnSubqueryWithDistinct()
values("a", "b"))))))));
}
+ @Test
+ public void rewritesOnSubqueryWithDecorrelatableDistinct()
+ {
+ // The distinct operator (inner aggregation grouped by "a") sits over the correlated filter "b = corr".
+ // PlanNodeDecorrelator can decorrelate it because the correlated predicate is an equality comparison; the expected distinct then groups by both "a" and "b".
+ tester().assertThat(new TransformCorrelatedGlobalAggregationWithProjection(tester().getPlannerContext()))
+ .on(p -> p.correlatedJoin(
+ ImmutableList.of(p.symbol("corr")),
+ p.values(p.symbol("corr")),
+ p.project(
+ Assignments.of(p.symbol("expr_sum"), PlanBuilder.expression("sum + 1"), p.symbol("expr_count"), PlanBuilder.expression("count - 1")),
+ p.aggregation(outerBuilder -> outerBuilder
+ .addAggregation(p.symbol("sum"), PlanBuilder.expression("sum(a)"), ImmutableList.of(BIGINT))
+ .addAggregation(p.symbol("count"), PlanBuilder.expression("count()"), ImmutableList.of())
+ .globalGrouping()
+ .source(p.aggregation(innerBuilder -> innerBuilder
+ .singleGroupingSet(p.symbol("a"))
+ .source(p.filter(
+ PlanBuilder.expression("b = corr"),
+ p.values(p.symbol("a"), p.symbol("b"))))))))))
+ .matches(
+ project(ImmutableMap.of("corr", expression("corr"), "expr_sum", expression("(sum_agg + 1)"), "expr_count", expression("count_agg - 1")),
+ aggregation(
+ singleGroupingSet("corr", "unique"),
+ ImmutableMap.of(Optional.of("sum_agg"), functionCall("sum", ImmutableList.of("a")), Optional.of("count_agg"), functionCall("count", ImmutableList.of())),
+ ImmutableList.of(),
+ ImmutableList.of("non_null"),
+ Optional.empty(),
+ SINGLE,
+ join(
+ LEFT,
+ ImmutableList.of(),
+ Optional.of("b = corr"), // expected: the correlated predicate is now the join filter
+ assignUniqueId(
+ "unique",
+ values("corr")),
+ project(
+ ImmutableMap.of("non_null", expression("true")),
+ aggregation(
+ singleGroupingSet("a", "b"), // expected: the distinct operator stays in the subquery, with "b" added to its grouping
+ ImmutableMap.of(),
+ Optional.empty(),
+ SINGLE,
+ filter(
+ "true", // expected: the correlated predicate no longer appears in the subquery filter
+ values("a", "b"))))))));
+ }
+
@Test
public void testWithPreexistingMask()
{
diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestTransformCorrelatedGlobalAggregationWithoutProjection.java b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestTransformCorrelatedGlobalAggregationWithoutProjection.java
index 6306a38e5749..3e34517365e1 100644
--- a/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestTransformCorrelatedGlobalAggregationWithoutProjection.java
+++ b/core/trino-main/src/test/java/io/trino/sql/planner/iterative/rule/TestTransformCorrelatedGlobalAggregationWithoutProjection.java
@@ -207,6 +207,52 @@ public void rewritesOnSubqueryWithDistinct()
values("a", "b"))))))));
}
+ @Test
+ public void rewritesOnSubqueryWithDecorrelatableDistinct()
+ {
+ // The distinct operator (inner aggregation grouped by "a") sits over the correlated filter "b = corr".
+ // PlanNodeDecorrelator can decorrelate it because the correlated predicate is an equality comparison; the expected distinct then groups by both "a" and "b".
+ tester().assertThat(new TransformCorrelatedGlobalAggregationWithoutProjection(tester().getPlannerContext()))
+ .on(p -> p.correlatedJoin(
+ ImmutableList.of(p.symbol("corr")),
+ p.values(p.symbol("corr")),
+ p.aggregation(outerBuilder -> outerBuilder
+ .addAggregation(p.symbol("sum"), PlanBuilder.expression("sum(a)"), ImmutableList.of(BIGINT))
+ .addAggregation(p.symbol("count"), PlanBuilder.expression("count()"), ImmutableList.of())
+ .globalGrouping()
+ .source(p.aggregation(innerBuilder -> innerBuilder
+ .singleGroupingSet(p.symbol("a"))
+ .source(p.filter(
+ PlanBuilder.expression("b = corr"),
+ p.values(p.symbol("a"), p.symbol("b")))))))))
+ .matches(
+ project(ImmutableMap.of("corr", expression("corr"), "sum_agg", expression("sum_agg"), "count_agg", expression("count_agg")),
+ aggregation(
+ singleGroupingSet("corr", "unique"),
+ ImmutableMap.of(Optional.of("sum_agg"), functionCall("sum", ImmutableList.of("a")), Optional.of("count_agg"), functionCall("count", ImmutableList.of())),
+ ImmutableList.of(),
+ ImmutableList.of("non_null"),
+ Optional.empty(),
+ SINGLE,
+ join(
+ LEFT,
+ ImmutableList.of(),
+ Optional.of("b = corr"), // expected: the correlated predicate is now the join filter
+ assignUniqueId(
+ "unique",
+ values("corr")),
+ project(
+ ImmutableMap.of("non_null", expression("true")),
+ aggregation(
+ singleGroupingSet("a", "b"), // expected: the distinct operator stays in the subquery, with "b" added to its grouping
+ ImmutableMap.of(),
+ Optional.empty(),
+ SINGLE,
+ filter(
+ "true", // expected: the correlated predicate no longer appears in the subquery filter
+ values("a", "b"))))))));
+ }
+
@Test
public void testWithPreexistingMask()
{