Skip to content

Rewrite AGG(IF()) to AGG() FILTER()#16534

Merged
highker merged 4 commits intoprestodb:masterfrom
yuanzhanhku:master
Aug 3, 2021
Merged

Rewrite AGG(IF()) to AGG() FILTER()#16534
highker merged 4 commits intoprestodb:masterfrom
yuanzhanhku:master

Conversation

@yuanzhanhku
Copy link
Contributor

@yuanzhanhku yuanzhanhku commented Jul 29, 2021

Add an rule to rewrite

  • AGG(IF(condition, expr))
    to
  • AGG(expr) FILTER (WHERE condition).

The latter plan is more efficient because:

  • the filter can be pushed down to the scan node
  • the rows not matching the condition are not aggregated
  • the IF() expression wrapper is removed.

Test plan
Added unit tests for the rule in TestRewriteAggIfToAggFilter.java.
Added query plan tests in TestFilteredAggregations.java and covered all existing test cases for AGG() FILTER.

== RELEASE NOTES ==

General Changes
* Introduce a new config ``optimizer.aggregation-if-to-filter-rewrite-enabled`` and its corresponding session property ``aggregation_if_to_filter_rewrite_enabled`` to enable or disable an optimizer rule to improve the query performance of ``IF`` expressions inside aggregation functions.

@yuanzhanhku yuanzhanhku requested a review from highker July 29, 2021 01:40
@yuanzhanhku yuanzhanhku force-pushed the master branch 2 times, most recently from 03ac1d7 to b2a2c2a Compare July 29, 2021 01:50
@yuanzhanhku yuanzhanhku linked an issue Jul 29, 2021 that may be closed by this pull request
@highker highker requested review from kaikalur and shixuan-fan July 29, 2021 17:09
The filter for the aggregation with mask was incorrect.
Copy link

@highker highker left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

still reviewing; but quick nit: let's change all "agg" into "aggregation" in this PR...

Copy link

@highker highker left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

more comments on logic

@yuanzhanhku
Copy link
Contributor Author

The order of the aggregations in the Aggregation node seems not deterministic, so I made a few changes to generate the new expressions in the order of the VariableReferenceExpression names.

@highker
Copy link

highker commented Aug 2, 2021

Quick nit: could you add the "release note" section to this github page as well? Check existing merged PRs with release note as examples. This should be "general changes" with something like "introduce a new config and session property to blah blah blah". Note that a release note should be extremely user-facing, mean that we could use sentences like "optimize if expressions in aggregation functions to improve performance".

@highker highker self-assigned this Aug 2, 2021
Add an rule to rewrite
 - AGG(IF(condition, expr))
to
 - AGG(expr) FILTER (WHERE condition).

The latter plan is more efficient because
 - the filter can be pushed down to the scan node
 - the rows not matching the condition are not aggregated
 - the IF() expression wrapper is removed.
The rule rewriting AGG(IF()) to AGG() FILTER is enabled by default.
To disable the rule,
 SET SESSION agg_if_to_filter_rewrite_enabled=false;
or set the config optimizer.aggregation-if-to-filter-rewrite-enabled to
false.
@yuanzhanhku yuanzhanhku requested a review from highker August 2, 2021 16:35
@highker highker merged commit 39cc942 into prestodb:master Aug 3, 2021
Copy link
Contributor

@kaikalur kaikalur left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So SET_AGG works differently for these two cases:


presto:di> select set_agg(if(x=1,y)) from (select 1 x, 2 y union all select null x, 20 y union all select 1 x, null y) group by y;
 _col0
--------
 [2]
 [null]
 [null]
(3 rows)

Query 20210805_003013_03168_fbgd3, FINISHED, 195 nodes
Splits: 3,139 total, 3,139 done (100.00%)
0:02 [0 rows, 0B] [0 rows/s, 0B/s]

presto:di> select set_agg( y) filter (where x=1) from (select 1 x, 2 y union all select null x, 20 y union all select 1 x, null y) group by y;
 _col0
--------
 NULL
 [null]
 [2]

So you may want to see if there are aggregation properties needed for this to work properly.

}
SpecialFormExpression expression = (SpecialFormExpression) sourceExpression;
// Only rewrite the aggregation if the else branch is not present.
return expression.getForm() == IF && Expressions.isNull(expression.getArguments().get(2));
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You could also add when the else part is the null literal like IF(x, y, null)

@yuanzhanhku
Copy link
Contributor Author

So SET_AGG works differently for these two cases:


presto:di> select set_agg(if(x=1,y)) from (select 1 x, 2 y union all select null x, 20 y union all select 1 x, null y) group by y;
 _col0
--------
 [2]
 [null]
 [null]
(3 rows)

Query 20210805_003013_03168_fbgd3, FINISHED, 195 nodes
Splits: 3,139 total, 3,139 done (100.00%)
0:02 [0 rows, 0B] [0 rows/s, 0B/s]

presto:di> select set_agg( y) filter (where x=1) from (select 1 x, 2 y union all select null x, 20 y union all select 1 x, null y) group by y;
 _col0
--------
 NULL
 [null]
 [2]

So you may want to see if there are aggregation properties needed for this to work properly.

Thanks for catching this. I will make a fix and add the tests.

@kaikalur
Copy link
Contributor

kaikalur commented Aug 5, 2021

In fact, I think we should do this only for numerical aggs (or an allowlist to start with SUM/COUNT/MIN/MAX)

@yuanzhanhku
Copy link
Contributor Author

yuanzhanhku commented Aug 5, 2021

In fact, I think we should do this only for numerical aggs (or an allowlist to start with SUM/COUNT/MIN/MAX)

Right. I am planning to only do this rewrite for numerical aggs. In addition to SUM/COUNT/MIN/MAX, I see there are many use cases for approx_distinct/ variance/approx_percentile as well. It would be good to enable the rewrite for these too.

@kaikalur
Copy link
Contributor

kaikalur commented Aug 5, 2021

Actually, there is more:

  • Make sure the NULL behavior of the agg is "not called on null"
  • Make sure the condition, then part of the if expression are both deterministic

Also, when you are at it, keep applying the simplification logic iteratively so things like IF(p1, IF(p2, x)) are also handled (I have seen those in tool generated code).

Use the RowExpressionInterpreter to simplify the IF expression before checking if the else part is null. It does some more const expr eval so good to be comprehensive.

@kaikalur
Copy link
Contributor

kaikalur commented Aug 5, 2021

Just to be clear, if you do those things, there should be no need to special csase.

RowExpression predicate = TRUE_CONSTANT;
if (!aggregationNode.hasNonEmptyGroupingSet() && aggregationsToRewrite.size() == aggregationNode.getAggregations().size()) {
// All aggregations are rewritten by this rule. We can add a filter with all the masks to make the query more efficient.
predicate = or(masks.build());
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This or might actually cause more slowdown than help because all masks have to be evaluated for every row anyway. I don't think this helps.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it helps because

  • The filter could be pushed down to the scan node which might be able to evaluated very efficiently, e.g., if it is on some partition columns, the partition columns are using dictionary encoding. This only need to evaluated once per dictionary item.
  • The filter can be used to prune the partitions/splits based on column stats.

Besides, the AGG() FILTER implementation also adds this predicate. It is better to keep the same behavior:

predicate = combineDisjunctsWithDefault(maskSymbols.build(), TRUE_LITERAL);

@yuanzhanhku
Copy link
Contributor Author

yuanzhanhku commented Aug 5, 2021

Actually, there is more:

  • Make sure the NULL behavior of the agg is "not called on null"

Great point on checking isCalledOnNullInput(). It is indeed the root cause of the behavior change as this rewrite filters out the NULL values.

  • Make sure the condition, then part of the if expression are both deterministic

Could you please elaborate why non-deterministic functions matter in this case? Do you have some examples when this could cause issues?

Also, when you are at it, keep applying the simplification logic iteratively so things like IF(p1, IF(p2, x)) are also handled (I have seen those in tool generated code).

Yes, this is a nice additional optimization we can do. Basically, we can rewrite AGG(IF(p1, IF(p2, x))) to AGG(x) FILTER(WHERE p1 AND p2).
Actually, it might be better to have a separate rewrite to inline the IF expressions in this case. i.e., rewrite IF(p1, IF(p2, x)) to IF(p1 AND p2, x). This rewrite can be applied for non-agg functions as well. It might improve performance because it simplifies the function?

Use the RowExpressionInterpreter to simplify the IF expression before checking if the else part is null. It does some more const expr eval so good to be comprehensive.

The SimplifyRowExpressions rules simplifies all expressions. Might be better to avoid doing the duplicate optimization here?

@kaikalur
Copy link
Contributor

kaikalur commented Aug 5, 2021

Actually, there is more:

  • Make sure the NULL behavior of the agg is "not called on null"

Great point on checking isCalledOnNullInput(). It is indeed the root cause of the behavior change as this rewrite filters out the NULL values.

  • Make sure the condition, then part of the if expression are both deterministic

Could you please elaborate why non-deterministic functions matter in this case? Do you have some examples when this could cause issues?

It's one of those things for example random() is called once per row I think but if you put in a then part, that might work differently. Not sure. So let's be conservative (for corner cases).

Also, when you are at it, keep applying the simplification logic iteratively so things like IF(p1, IF(p2, x)) are also handled (I have seen those in tool generated code).

Yes, this is a nice additional optimization we can do. Basically, we can rewrite AGG(IF(p1, IF(p2, x))) to AGG(x) FILTER(WHERE p1 AND p2).
Actually, it might be better to have a separate rewrite to inline the IF expressions in this case. i.e., rewrite IF(p1, IF(p2, x)) to IF(p1 AND p2, x). This rewrite can be applied for non-agg functions as well. It might improve performance because it simplifies the function?

I thought about it. Sure you can do that in a separate PR

Use the RowExpressionInterpreter to simplify the IF expression before checking if the else part is null. It does some more const expr eval so good to be comprehensive.

The SimplifyRowExpressions rules simplifies all expressions. Might be better to avoid doing the duplicate optimization here?

That's fine. It's good to make sure it's simpified to get max benefit.

@yuanzhanhku
Copy link
Contributor Author

Thanks Sreeni for all the suggestions. Made the changes in #16566.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Add a rule to rewrite AGG(IF()) to AGG() WITH FILTER

3 participants