Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,8 @@
/**
* This rule decorrelates a correlated subquery of LEFT or INNER correlated join with:
* - single global aggregation, or
* - global aggregation over distinct operator (grouped aggregation with no aggregation assignments)
* - global aggregation over distinct operator (grouped aggregation with no aggregation assignments),
* in the case where the distinct operator cannot be decorrelated by PlanNodeDecorrelator
* <p>
* In the case of single aggregation, it transforms:
* <pre>
Expand Down Expand Up @@ -152,17 +153,23 @@ public Result apply(CorrelatedJoinNode correlatedJoinNode, Captures captures, Co

// if there is another aggregation below the AggregationNode, handle both
PlanNode source = captures.get(SOURCE);

// if we fail to decorrelate the nested plan, and it contains a distinct operator, we can extract and special-handle the distinct operator
AggregationNode distinct = null;
if (isDistinctOperator(source)) {
distinct = (AggregationNode) source;
source = distinct.getSource();
}

// decorrelate nested plan
PlanNodeDecorrelator decorrelator = new PlanNodeDecorrelator(metadata, context.getSymbolAllocator(), context.getLookup());
Optional<PlanNodeDecorrelator.DecorrelatedNode> decorrelatedSource = decorrelator.decorrelateFilters(source, correlatedJoinNode.getCorrelation());
if (decorrelatedSource.isEmpty()) {
return Result.empty();
// we failed to decorrelate the nested plan, so check if we can extract a distinct operator from the nested plan
if (isDistinctOperator(source)) {
distinct = (AggregationNode) source;
source = distinct.getSource();
decorrelatedSource = decorrelator.decorrelateFilters(source, correlatedJoinNode.getCorrelation());
}
if (decorrelatedSource.isEmpty()) {
return Result.empty();
}
}

source = decorrelatedSource.get().getNode();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,8 @@
/**
* This rule decorrelates a correlated subquery with:
* - single global aggregation, or
* - global aggregation over distinct operator (grouped aggregation with no aggregation assignments)
* - global aggregation over distinct operator (grouped aggregation with no aggregation assignments),
* in the case where the distinct operator cannot be decorrelated by PlanNodeDecorrelator
* It is similar to TransformCorrelatedGlobalAggregationWithProjection rule, but does not support projection over aggregation in the subquery
* <p>
* In the case of single aggregation, it transforms:
Expand Down Expand Up @@ -143,19 +144,24 @@ public Result apply(CorrelatedJoinNode correlatedJoinNode, Captures captures, Co
{
checkArgument(correlatedJoinNode.getType() == INNER || correlatedJoinNode.getType() == LEFT, "unexpected correlated join type: " + correlatedJoinNode.getType());

// if there is another aggregation below the AggregationNode, handle both
PlanNode source = captures.get(SOURCE);

// if we fail to decorrelate the nested plan, and it contains a distinct operator, we can extract and special-handle the distinct operator
AggregationNode distinct = null;
if (isDistinctOperator(source)) {
distinct = (AggregationNode) source;
source = distinct.getSource();
}

// decorrelate nested plan
PlanNodeDecorrelator decorrelator = new PlanNodeDecorrelator(metadata, context.getSymbolAllocator(), context.getLookup());
Optional<PlanNodeDecorrelator.DecorrelatedNode> decorrelatedSource = decorrelator.decorrelateFilters(source, correlatedJoinNode.getCorrelation());
if (decorrelatedSource.isEmpty()) {
return Result.empty();
// we failed to decorrelate the nested plan, so check if we can extract a distinct operator from the nested plan
if (isDistinctOperator(source)) {
distinct = (AggregationNode) source;
source = distinct.getSource();
decorrelatedSource = decorrelator.decorrelateFilters(source, correlatedJoinNode.getCorrelation());
}
if (decorrelatedSource.isEmpty()) {
return Result.empty();
}
}

source = decorrelatedSource.get().getNode();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -927,6 +927,36 @@ public void testCorrelatedScalarAggregationRewriteToLeftOuterJoin()
node(ValuesNode.class))))))));
}

/**
 * Plans a scalar subquery that computes {@code count(DISTINCT ...)} and is correlated to the
 * outer query through an equality predicate, and checks that the resulting plan contains a
 * LEFT join on the correlation key whose right side still carries the distinct step as a
 * grouped aggregation without aggregate functions.
 */
@Test
public void testCorrelatedDistinctAggregationRewriteToLeftOuterJoin()
{
    String query = "SELECT (SELECT count(DISTINCT o.orderkey) FROM orders o WHERE c.custkey = o.custkey), c.custkey FROM customer c";

    assertPlan(
            query,
            output(
                    project(
                            join(
                                    INNER,
                                    ImmutableList.of(),
                                    // customer LEFT JOIN (aggregated orders) on c_custkey = o_custkey
                                    join(
                                            LEFT,
                                            ImmutableList.of(equiJoinClause("c_custkey", "o_custkey")),
                                            anyTree(
                                                    tableScan("customer", ImmutableMap.of("c_custkey", "custkey"))),
                                            anyTree(
                                                    // count over the distinct rows, grouped by the correlation key
                                                    aggregation(
                                                            singleGroupingSet("o_custkey"),
                                                            ImmutableMap.of(
                                                                    Optional.of("count"),
                                                                    functionCall("count", ImmutableList.of("o_orderkey"))),
                                                            ImmutableList.of(),
                                                            ImmutableList.of("non_null"),
                                                            Optional.empty(),
                                                            SINGLE,
                                                            project(
                                                                    ImmutableMap.of("non_null", expression("true")),
                                                                    // distinct step: grouping on (o_orderkey, o_custkey) with no aggregations
                                                                    aggregation(
                                                                            singleGroupingSet("o_orderkey", "o_custkey"),
                                                                            ImmutableMap.of(),
                                                                            Optional.empty(),
                                                                            FINAL,
                                                                            anyTree(
                                                                                    tableScan(
                                                                                            "orders",
                                                                                            ImmutableMap.of("o_orderkey", "orderkey", "o_custkey", "custkey")))))))),
                                    anyTree(node(ValuesNode.class))))));
}

@Test
public void testRemovesTrivialFilters()
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,54 @@ public void rewritesOnSubqueryWithDistinct()
values("a", "b"))))))));
}

/**
 * Applies TransformCorrelatedGlobalAggregationWithProjection to a correlated join whose subquery is
 * a projection over a global aggregation ({@code sum(a)}, {@code count()}) over a distinct operator
 * (aggregation grouped on {@code a} with no aggregation functions) over the filter {@code b = corr}.
 * <p>
 * The expected plan keeps the distinct aggregation below the join, with {@code b} added to its
 * grouping set, and carries {@code b = corr} as the filter of a LEFT join.
 */
@Test
public void rewritesOnSubqueryWithDecorrelatableDistinct()
{
// distinct aggregation can be decorrelated in the subquery by PlanNodeDecorrelator
// because the correlated predicate is equality comparison
tester().assertThat(new TransformCorrelatedGlobalAggregationWithProjection(tester().getMetadata()))
// input plan: correlated join on symbol "corr" between values(corr) and the subquery
.on(p -> p.correlatedJoin(
ImmutableList.of(p.symbol("corr")),
p.values(p.symbol("corr")),
// projection over the global aggregation: expr_sum := sum + 1, expr_count := count - 1
p.project(
Assignments.of(p.symbol("expr_sum"), PlanBuilder.expression("sum + 1"), p.symbol("expr_count"), PlanBuilder.expression("count - 1")),
p.aggregation(outerBuilder -> outerBuilder
.addAggregation(p.symbol("sum"), PlanBuilder.expression("sum(a)"), ImmutableList.of(BIGINT))
.addAggregation(p.symbol("count"), PlanBuilder.expression("count()"), ImmutableList.of())
.globalGrouping()
// distinct operator: grouped on "a", no aggregations added
.source(p.aggregation(innerBuilder -> innerBuilder
.singleGroupingSet(p.symbol("a"))
// correlated equality predicate referencing the outer symbol "corr"
.source(p.filter(
PlanBuilder.expression("b = corr"),
p.values(p.symbol("a"), p.symbol("b"))))))))))
// expected plan after the rule fires
.matches(
project(ImmutableMap.of("corr", expression("corr"), "expr_sum", expression("(sum_agg + 1)"), "expr_count", expression("count_agg - 1")),
// the former global aggregation is now grouped on (corr, unique)
aggregation(
singleGroupingSet("corr", "unique"),
ImmutableMap.of(Optional.of("sum_agg"), functionCall("sum", ImmutableList.of("a")), Optional.of("count_agg"), functionCall("count", ImmutableList.of())),
ImmutableList.of(),
ImmutableList.of("non_null"),
Optional.empty(),
SINGLE,
// LEFT join with no equi-clauses; the correlated predicate is the join filter
join(
LEFT,
ImmutableList.of(),
Optional.of("b = corr"),
assignUniqueId(
"unique",
values("corr")),
// join's right side: projects non_null := true over the distinct aggregation
project(
ImmutableMap.of("non_null", expression("true")),
// distinct aggregation is now grouped on (a, b) rather than on a alone
aggregation(
singleGroupingSet("a", "b"),
ImmutableMap.of(),
Optional.empty(),
SINGLE,
// the original filter remains with predicate "true"
filter(
"true",
values("a", "b"))))))));
}

@Test
public void testWithPreexistingMask()
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,52 @@ public void rewritesOnSubqueryWithDistinct()
values("a", "b"))))))));
}

/**
 * Applies TransformCorrelatedGlobalAggregationWithoutProjection to a correlated join whose subquery
 * is a global aggregation ({@code sum(a)}, {@code count()}) over a distinct operator (aggregation
 * grouped on {@code a} with no aggregation functions) over the filter {@code b = corr}.
 * <p>
 * The expected plan keeps the distinct aggregation below the join, with {@code b} added to its
 * grouping set, and carries {@code b = corr} as the filter of a LEFT join.
 */
@Test
public void rewritesOnSubqueryWithDecorrelatableDistinct()
{
// distinct aggregation can be decorrelated in the subquery by PlanNodeDecorrelator
// because the correlated predicate is equality comparison
tester().assertThat(new TransformCorrelatedGlobalAggregationWithoutProjection(tester().getMetadata()))
// input plan: correlated join on symbol "corr" between values(corr) and the subquery
.on(p -> p.correlatedJoin(
ImmutableList.of(p.symbol("corr")),
p.values(p.symbol("corr")),
// global aggregation sits directly under the correlated join (no projection in between)
p.aggregation(outerBuilder -> outerBuilder
.addAggregation(p.symbol("sum"), PlanBuilder.expression("sum(a)"), ImmutableList.of(BIGINT))
.addAggregation(p.symbol("count"), PlanBuilder.expression("count()"), ImmutableList.of())
.globalGrouping()
// distinct operator: grouped on "a", no aggregations added
.source(p.aggregation(innerBuilder -> innerBuilder
.singleGroupingSet(p.symbol("a"))
// correlated equality predicate referencing the outer symbol "corr"
.source(p.filter(
PlanBuilder.expression("b = corr"),
p.values(p.symbol("a"), p.symbol("b")))))))))
// expected plan after the rule fires
.matches(
project(ImmutableMap.of("corr", expression("corr"), "sum_agg", expression("sum_agg"), "count_agg", expression("count_agg")),
// the former global aggregation is now grouped on (corr, unique)
aggregation(
singleGroupingSet("corr", "unique"),
ImmutableMap.of(Optional.of("sum_agg"), functionCall("sum", ImmutableList.of("a")), Optional.of("count_agg"), functionCall("count", ImmutableList.of())),
ImmutableList.of(),
ImmutableList.of("non_null"),
Optional.empty(),
SINGLE,
// LEFT join with no equi-clauses; the correlated predicate is the join filter
join(
LEFT,
ImmutableList.of(),
Optional.of("b = corr"),
assignUniqueId(
"unique",
values("corr")),
// join's right side: projects non_null := true over the distinct aggregation
project(
ImmutableMap.of("non_null", expression("true")),
// distinct aggregation is now grouped on (a, b) rather than on a alone
aggregation(
singleGroupingSet("a", "b"),
ImmutableMap.of(),
Optional.empty(),
SINGLE,
// the original filter remains with predicate "true"
filter(
"true",
values("a", "b"))))))));
}

@Test
public void testWithPreexistingMask()
{
Expand Down