Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,12 @@ public InternalAggregationFunction specialize(BoundVariables boundVariables, int
return generateAggregation(type, outputType);
}

/**
 * Always returns {@code true}.
 * NOTE(review): presumably this is what marks the aggregation as one that must also see NULL inputs;
 * confirm how the flag is surfaced through function metadata.
 */
@Override
public boolean isCalledOnNullInput()
{
return true;
Comment thread
yuanzhanhku marked this conversation as resolved.
Outdated
}

private static InternalAggregationFunction generateAggregation(Type type, ArrayType outputType)
{
DynamicClassLoader classLoader = new DynamicClassLoader(SetAggregationFunction.class.getClassLoader());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,12 @@ else if (state.get() == null) {
}
}

/**
 * Always returns {@code true}.
 * NOTE(review): presumably this is what marks the aggregation as one that must also see NULL inputs;
 * confirm how the flag is surfaced through function metadata.
 */
@Override
public boolean isCalledOnNullInput()
{
return true;
}

public static void output(SetAggregationState state, BlockBuilder out)
{
SetOfValues set = state.get();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -404,7 +404,7 @@ public PlanOptimizers(
ruleStats,
statsCalculator,
estimatedExchangesCostCalculator,
ImmutableSet.of(new RewriteAggregationIfToFilter())),
ImmutableSet.of(new RewriteAggregationIfToFilter(metadata.getFunctionAndTypeManager()))),
predicatePushDown,
new IterativeOptimizer(
ruleStats,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import com.facebook.presto.matching.Capture;
import com.facebook.presto.matching.Captures;
import com.facebook.presto.matching.Pattern;
import com.facebook.presto.metadata.FunctionAndTypeManager;
import com.facebook.presto.spi.plan.AggregationNode;
import com.facebook.presto.spi.plan.AggregationNode.Aggregation;
import com.facebook.presto.spi.plan.Assignments;
Expand All @@ -28,6 +29,7 @@
import com.facebook.presto.spi.relation.VariableReferenceExpression;
import com.facebook.presto.sql.planner.iterative.Rule;
import com.facebook.presto.sql.relational.Expressions;
import com.facebook.presto.sql.relational.RowExpressionDeterminismEvaluator;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSortedSet;
Expand All @@ -47,6 +49,7 @@
import static com.facebook.presto.sql.planner.plan.Patterns.source;
import static com.google.common.collect.ImmutableSet.toImmutableSet;
import static com.google.common.collect.ImmutableSortedMap.toImmutableSortedMap;
import static java.util.Objects.requireNonNull;
import static java.util.function.Function.identity;

/**
Expand All @@ -68,6 +71,15 @@ public class RewriteAggregationIfToFilter
private static final Pattern<AggregationNode> PATTERN = aggregation()
.with(source().matching(project().capturedAs(CHILD)));

private final FunctionAndTypeManager functionAndTypeManager;
private final RowExpressionDeterminismEvaluator rowExpressionDeterminismEvaluator;

/**
 * Creates the rule.
 *
 * @param functionAndTypeManager used to look up aggregation function metadata and to build the
 *        determinism evaluator for row expressions; must not be null
 * @throws NullPointerException if {@code functionAndTypeManager} is null
 */
public RewriteAggregationIfToFilter(FunctionAndTypeManager functionAndTypeManager)
{
    // The message names the actual parameter so a failure points at the right argument.
    this.functionAndTypeManager = requireNonNull(functionAndTypeManager, "functionAndTypeManager is null");
    this.rowExpressionDeterminismEvaluator = new RowExpressionDeterminismEvaluator(functionAndTypeManager);
}

@Override
public boolean isEnabled(Session session)
{
Expand Down Expand Up @@ -178,6 +190,10 @@ public Result apply(AggregationNode aggregationNode, Captures captures, Context

private boolean shouldRewriteAggregation(Aggregation aggregation, ProjectNode sourceProject)
{
if (functionAndTypeManager.getFunctionMetadata(aggregation.getFunctionHandle()).isCalledOnNullInput()) {
// This rewrite will filter out the null values. It could change the behavior if the aggregation is also applied on NULLs.
return false;
}
if (!(aggregation.getArguments().size() == 1 && aggregation.getArguments().get(0) instanceof VariableReferenceExpression)) {
// Currently we only handle aggregation with a single VariableReferenceExpression. The detailed expressions are in a project node below this aggregation.
return false;
Expand All @@ -187,11 +203,11 @@ private boolean shouldRewriteAggregation(Aggregation aggregation, ProjectNode so
return false;
}
RowExpression sourceExpression = sourceProject.getAssignments().get((VariableReferenceExpression) aggregation.getArguments().get(0));
if (!(sourceExpression instanceof SpecialFormExpression)) {
if (!(sourceExpression instanceof SpecialFormExpression) || !rowExpressionDeterminismEvaluator.isDeterministic(sourceExpression)) {
return false;
}
SpecialFormExpression expression = (SpecialFormExpression) sourceExpression;
// Only rewrite the aggregation if the else branch is not present.
// Only rewrite the aggregation if the else branch is not present or the else result is NULL.
return expression.getForm() == IF && Expressions.isNull(expression.getArguments().get(2));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import java.util.Optional;

import static com.facebook.presto.common.type.BigintType.BIGINT;
import static com.facebook.presto.common.type.DoubleType.DOUBLE;
import static com.facebook.presto.common.type.VarcharType.VARCHAR;
import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.aggregation;
import static com.facebook.presto.sql.planner.assertions.PlanMatchPattern.expression;
Expand All @@ -44,7 +45,7 @@ public class TestRewriteAggregationIfToFilter
public void testDoesNotFireForNonIf()
{
// The aggregation expression is not an if expression.
tester().assertThat(new RewriteAggregationIfToFilter())
tester().assertThat(new RewriteAggregationIfToFilter(getFunctionManager()))
.on(p -> {
VariableReferenceExpression a = p.variable("a", BooleanType.BOOLEAN);
VariableReferenceExpression ds = p.variable("ds", VARCHAR);
Expand All @@ -60,7 +61,7 @@ public void testDoesNotFireForNonIf()
public void testDoesNotFireForIfWithElse()
{
// The if expression has an else branch. We cannot rewrite it.
tester().assertThat(new RewriteAggregationIfToFilter())
tester().assertThat(new RewriteAggregationIfToFilter(getFunctionManager()))
.on(p -> {
VariableReferenceExpression a = p.variable("a");
VariableReferenceExpression ds = p.variable("ds", VARCHAR);
Expand All @@ -72,10 +73,35 @@ public void testDoesNotFireForIfWithElse()
}).doesNotFire();
}

/**
 * The rule must not fire when the IF expression feeding the aggregation calls a
 * non-deterministic function, either in its result branch or in its condition.
 */
@Test
public void testDoesNotFireForNonDeterministicFunction()
{
    // Case 1: random() is the value produced by the IF.
    String nonDeterministicResult = "IF(ds > '2021-07-01', random())";
    tester().assertThat(new RewriteAggregationIfToFilter(getFunctionManager()))
            .on(planBuilder -> {
                VariableReferenceExpression aggregationInput = planBuilder.variable("a", DOUBLE);
                VariableReferenceExpression partitionColumn = planBuilder.variable("ds", VARCHAR);
                return planBuilder.aggregation(aggregationBuilder -> aggregationBuilder
                        .globalGrouping()
                        .step(AggregationNode.Step.FINAL)
                        .addAggregation(planBuilder.variable("expr"), planBuilder.rowExpression("sum(a)"))
                        .source(planBuilder.project(
                                assignment(aggregationInput, planBuilder.rowExpression(nonDeterministicResult)),
                                planBuilder.values(partitionColumn))));
            })
            .doesNotFire();

    // Case 2: random() appears in the IF condition.
    String nonDeterministicCondition = "IF(random() > DOUBLE '0.1', 1)";
    tester().assertThat(new RewriteAggregationIfToFilter(getFunctionManager()))
            .on(planBuilder -> {
                VariableReferenceExpression aggregationInput = planBuilder.variable("a", BIGINT);
                VariableReferenceExpression partitionColumn = planBuilder.variable("ds", VARCHAR);
                return planBuilder.aggregation(aggregationBuilder -> aggregationBuilder
                        .globalGrouping()
                        .step(AggregationNode.Step.FINAL)
                        .addAggregation(planBuilder.variable("expr"), planBuilder.rowExpression("sum(a)"))
                        .source(planBuilder.project(
                                assignment(aggregationInput, planBuilder.rowExpression(nonDeterministicCondition)),
                                planBuilder.values(partitionColumn))));
            })
            .doesNotFire();
}

@Test
public void testFireOneAggregation()
{
tester().assertThat(new RewriteAggregationIfToFilter())
tester().assertThat(new RewriteAggregationIfToFilter(getFunctionManager()))
.on(p -> {
VariableReferenceExpression a = p.variable("a");
VariableReferenceExpression ds = p.variable("ds", VARCHAR);
Expand Down Expand Up @@ -104,7 +130,7 @@ public void testFireOneAggregation()
@Test
public void testFireTwoAggregations()
{
tester().assertThat(new RewriteAggregationIfToFilter())
tester().assertThat(new RewriteAggregationIfToFilter(getFunctionManager()))
.on(p -> {
VariableReferenceExpression a = p.variable("a");
VariableReferenceExpression b = p.variable("b");
Expand Down Expand Up @@ -145,7 +171,7 @@ public void testFireTwoAggregations()
@Test
public void testFireTwoAggregationsWithSharedInput()
{
tester().assertThat(new RewriteAggregationIfToFilter())
tester().assertThat(new RewriteAggregationIfToFilter(getFunctionManager()))
.on(p -> {
VariableReferenceExpression a = p.variable("a");
VariableReferenceExpression ds = p.variable("ds", VARCHAR);
Expand Down Expand Up @@ -181,7 +207,7 @@ public void testFireTwoAggregationsWithSharedInput()
@Test
public void testFireForOneOfTwoAggregations()
{
tester().assertThat(new RewriteAggregationIfToFilter())
tester().assertThat(new RewriteAggregationIfToFilter(getFunctionManager()))
.on(p -> {
VariableReferenceExpression a = p.variable("a");
VariableReferenceExpression b = p.variable("b");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,61 @@ public void testGroupAll()
"VALUES (BIGINT '3', BIGINT '6')");
}

// Compares set_agg over the same three-row input written two ways: with a FILTER clause
// and with an IF argument that has no else branch.
@Test
public void testSetAggWithNulls()
{
// FILTER form: the y = 20 group has no row with x = 1, so its expected result is a NULL array.
assertions.assertQuery(
"SELECT y, set_agg(y) FILTER (WHERE x = 1) FROM (SELECT 1 x, 2 y UNION ALL SELECT NULL x, 20 y UNION ALL SELECT 1 x, NULL y) GROUP BY y ORDER BY y",
"VALUES (INTEGER '2', ARRAY[INTEGER '2']), (INTEGER '20', CAST(NULL AS ARRAY<INTEGER>)), (CAST(NULL AS INTEGER), ARRAY[CAST(NULL AS INTEGER)])");
// IF form: the y = 20 group is expected to yield an array holding a single NULL element instead,
// i.e. the two spellings are not interchangeable for set_agg.
assertions.assertQuery(
"SELECT y, set_agg(IF(x = 1,y)) FROM (SELECT 1 x, 2 y UNION ALL SELECT NULL x, 20 y UNION ALL SELECT 1 x, NULL y) GROUP BY y ORDER BY y",
"VALUES (INTEGER '2', ARRAY[INTEGER '2']), (INTEGER '20', ARRAY[CAST(NULL AS INTEGER)]), (CAST(NULL AS INTEGER), ARRAY[CAST(NULL AS INTEGER)])");
Comment thread
yuanzhanhku marked this conversation as resolved.
Outdated
}

/**
 * approx_set written with a FILTER clause and with an IF argument (no else branch)
 * is expected to return the same rows: a NULL HyperLogLog for every group.
 */
@Test
public void testApproxSet()
{
    String expectedRows = "VALUES (INTEGER '20', CAST(NULL AS HyperLogLog)), (CAST(NULL AS INTEGER), CAST(NULL AS HyperLogLog))";

    // FILTER form.
    String filterQuery = "SELECT y, approx_set(y) FILTER (WHERE x = 1) FROM (SELECT NULL x, 20 y UNION ALL SELECT 1 x, NULL y) GROUP BY y ORDER BY y";
    assertions.assertQuery(filterQuery, expectedRows);

    // IF form.
    String ifQuery = "SELECT y, approx_set(IF(x = 1,y)) FROM (SELECT NULL x, 20 y UNION ALL SELECT 1 x, NULL y) GROUP BY y ORDER BY y";
    assertions.assertQuery(ifQuery, expectedRows);
}

/**
 * set_union written with a FILTER clause and with an IF argument (no else branch)
 * is expected to differ: the FILTER form yields a NULL array, the IF form an empty array.
 */
@Test
public void testSetUnion()
{
    // FILTER form: no row satisfies y > 1, expected result is NULL.
    String filterQuery = "SELECT set_union(x) FILTER (WHERE y > 1) FROM (SELECT ARRAY[1] x, 1 y UNION ALL SELECT NULL x, 1 y)";
    String filterExpected = "VALUES (CAST (NULL AS ARRAY<INTEGER>))";
    assertions.assertQuery(filterQuery, filterExpected);

    // IF form: every input becomes NULL, expected result is an empty array.
    String ifQuery = "SELECT set_union(IF(y > 1, x)) FROM (SELECT ARRAY[1] x, 1 y UNION ALL SELECT NULL x, 1 y)";
    String ifExpected = "VALUES (CAST (ARRAY[] AS ARRAY<INTEGER>))";
    assertions.assertQuery(ifQuery, ifExpected);
}

/**
 * map_union written with a FILTER clause and with an IF argument (no else branch)
 * is expected to return the same single NULL map.
 */
@Test
public void testMapUnion()
{
    String expectedRows = "VALUES (CAST (NULL AS MAP<INTEGER, INTEGER>))";

    // FILTER form.
    String filterQuery = "SELECT map_union(x) FILTER (WHERE y > 1) FROM (SELECT MAP(ARRAY[1], ARRAY[1]) x, 1 y UNION ALL SELECT NULL x, 1 y)";
    assertions.assertQuery(filterQuery, expectedRows);

    // IF form.
    String ifQuery = "SELECT map_union(IF(y > 1, x)) FROM (SELECT MAP(ARRAY[1], ARRAY[1]) x, 1 y UNION ALL SELECT NULL x, 1 y)";
    assertions.assertQuery(ifQuery, expectedRows);
}

/**
 * map_union_sum written with a FILTER clause and with an IF argument (no else branch)
 * is expected to return the same single NULL map.
 */
@Test
public void testMapUnionSum()
{
    String expectedRows = "VALUES (CAST (NULL AS MAP<INTEGER, INTEGER>))";

    // FILTER form.
    String filterQuery = "SELECT map_union_sum(x) FILTER (WHERE y > 1) FROM (SELECT MAP(ARRAY[1], ARRAY[1]) x, 1 y UNION ALL SELECT NULL x, 1 y)";
    assertions.assertQuery(filterQuery, expectedRows);

    // IF form.
    String ifQuery = "SELECT map_union_sum(IF(y > 1, x)) FROM (SELECT MAP(ARRAY[1], ARRAY[1]) x, 1 y UNION ALL SELECT NULL x, 1 y)";
    assertions.assertQuery(ifQuery, expectedRows);
}

@Test
public void testGroupingSets()
{
Expand Down