Rewrite AGG(IF()) to AGG() FILTER()#16534
Conversation
03ac1d7 to
b2a2c2a
Compare
The filter for the aggregation with mask was incorrect.
highker
left a comment
There was a problem hiding this comment.
still reviewing; but quick nit: let's change all "agg" into "aggregation" in this PR...
presto-main/src/main/java/com/facebook/presto/SystemSessionProperties.java
Outdated
Show resolved
Hide resolved
presto-main/src/main/java/com/facebook/presto/SystemSessionProperties.java
Outdated
Show resolved
Hide resolved
presto-main/src/main/java/com/facebook/presto/sql/analyzer/FeaturesConfig.java
Outdated
Show resolved
Hide resolved
presto-main/src/main/java/com/facebook/presto/sql/analyzer/FeaturesConfig.java
Outdated
Show resolved
Hide resolved
...in/src/main/java/com/facebook/presto/sql/planner/iterative/rule/RewriteAggIfToAggFilter.java
Outdated
Show resolved
Hide resolved
...in/src/main/java/com/facebook/presto/sql/planner/iterative/rule/RewriteAggIfToAggFilter.java
Outdated
Show resolved
Hide resolved
...in/src/main/java/com/facebook/presto/sql/planner/iterative/rule/RewriteAggIfToAggFilter.java
Outdated
Show resolved
Hide resolved
...in/src/main/java/com/facebook/presto/sql/planner/iterative/rule/RewriteAggIfToAggFilter.java
Outdated
Show resolved
Hide resolved
...in/src/main/java/com/facebook/presto/sql/planner/iterative/rule/RewriteAggIfToAggFilter.java
Outdated
Show resolved
Hide resolved
...in/src/main/java/com/facebook/presto/sql/planner/iterative/rule/RewriteAggIfToAggFilter.java
Outdated
Show resolved
Hide resolved
...in/src/main/java/com/facebook/presto/sql/planner/iterative/rule/RewriteAggIfToAggFilter.java
Outdated
Show resolved
Hide resolved
...in/src/main/java/com/facebook/presto/sql/planner/iterative/rule/RewriteAggIfToAggFilter.java
Outdated
Show resolved
Hide resolved
...in/src/main/java/com/facebook/presto/sql/planner/iterative/rule/RewriteAggIfToAggFilter.java
Outdated
Show resolved
Hide resolved
...in/src/main/java/com/facebook/presto/sql/planner/iterative/rule/RewriteAggIfToAggFilter.java
Outdated
Show resolved
Hide resolved
|
The order of the aggregations in the Aggregation node seems not deterministic, so I made a few changes to generate the new expressions in the order of the VariableReferenceExpression names. |
|
Quick nit: could you add the "release note" section to this github page as well? Check existing merged PRs with release note as examples. This should be "general changes" with something like "introduce a new config and session property to blah blah blah". Note that a release note should be extremely user-facing, mean that we could use sentences like "optimize if expressions in aggregation functions to improve performance". |
presto-expressions/src/main/java/com/facebook/presto/expressions/LogicalRowExpressions.java
Outdated
Show resolved
Hide resolved
presto-main/src/test/java/com/facebook/presto/sql/query/TestFilteredAggregations.java
Outdated
Show resolved
Hide resolved
...c/main/java/com/facebook/presto/sql/planner/iterative/rule/RewriteAggregationIfToFilter.java
Outdated
Show resolved
Hide resolved
...c/main/java/com/facebook/presto/sql/planner/iterative/rule/RewriteAggregationIfToFilter.java
Outdated
Show resolved
Hide resolved
Add an rule to rewrite - AGG(IF(condition, expr)) to - AGG(expr) FILTER (WHERE condition). The latter plan is more efficient because - the filter can be pushed down to the scan node - the rows not matching the condition are not aggregated - the IF() expression wrapper is removed.
The rule rewriting AGG(IF()) to AGG() FILTER is enabled by default. To disable the rule, SET SESSION agg_if_to_filter_rewrite_enabled=false; or set the config optimizer.aggregation-if-to-filter-rewrite-enabled to false.
kaikalur
left a comment
There was a problem hiding this comment.
So SET_AGG works differently for these two cases:
presto:di> select set_agg(if(x=1,y)) from (select 1 x, 2 y union all select null x, 20 y union all select 1 x, null y) group by y;
_col0
--------
[2]
[null]
[null]
(3 rows)
Query 20210805_003013_03168_fbgd3, FINISHED, 195 nodes
Splits: 3,139 total, 3,139 done (100.00%)
0:02 [0 rows, 0B] [0 rows/s, 0B/s]
presto:di> select set_agg( y) filter (where x=1) from (select 1 x, 2 y union all select null x, 20 y union all select 1 x, null y) group by y;
_col0
--------
NULL
[null]
[2]
So you may want to see if there are aggregation properties needed for this to work properly.
| } | ||
| SpecialFormExpression expression = (SpecialFormExpression) sourceExpression; | ||
| // Only rewrite the aggregation if the else branch is not present. | ||
| return expression.getForm() == IF && Expressions.isNull(expression.getArguments().get(2)); |
There was a problem hiding this comment.
You could also add when the else part is the null literal like IF(x, y, null)
Thanks for catching this. I will make a fix and add the tests. |
|
In fact, I think we should do this only for numerical aggs (or an allowlist to start with SUM/COUNT/MIN/MAX) |
Right. I am planning to only do this rewrite for numerical aggs. In addition to SUM/COUNT/MIN/MAX, I see there are many use cases for approx_distinct/ variance/approx_percentile as well. It would be good to enable the rewrite for these too. |
|
Actually, there is more:
Also, when you are at it, keep applying the simplification logic iteratively so things like IF(p1, IF(p2, x)) are also handled (I have seen those in tool generated code). Use the RowExpressionInterpreter to simplify the IF expression before checking if the else part is null. It does some more const expr eval so good to be comprehensive. |
|
Just to be clear, if you do those things, there should be no need to special csase. |
| RowExpression predicate = TRUE_CONSTANT; | ||
| if (!aggregationNode.hasNonEmptyGroupingSet() && aggregationsToRewrite.size() == aggregationNode.getAggregations().size()) { | ||
| // All aggregations are rewritten by this rule. We can add a filter with all the masks to make the query more efficient. | ||
| predicate = or(masks.build()); |
There was a problem hiding this comment.
This or might actually cause more slowdown than help because all masks have to be evaluated for every row anyway. I don't think this helps.
There was a problem hiding this comment.
I think it helps because
- The filter could be pushed down to the scan node which might be able to evaluated very efficiently, e.g., if it is on some partition columns, the partition columns are using dictionary encoding. This only need to evaluated once per dictionary item.
- The filter can be used to prune the partitions/splits based on column stats.
Besides, the AGG() FILTER implementation also adds this predicate. It is better to keep the same behavior:
Great point on checking isCalledOnNullInput(). It is indeed the root cause of the behavior change as this rewrite filters out the NULL values.
Could you please elaborate why non-deterministic functions matter in this case? Do you have some examples when this could cause issues?
Yes, this is a nice additional optimization we can do. Basically, we can rewrite AGG(IF(p1, IF(p2, x))) to AGG(x) FILTER(WHERE p1 AND p2).
The SimplifyRowExpressions rules simplifies all expressions. Might be better to avoid doing the duplicate optimization here? |
It's one of those things for example random() is called once per row I think but if you put in a then part, that might work differently. Not sure. So let's be conservative (for corner cases).
I thought about it. Sure you can do that in a separate PR
That's fine. It's good to make sure it's simpified to get max benefit. |
|
Thanks Sreeni for all the suggestions. Made the changes in #16566. |
Add an rule to rewrite
to
The latter plan is more efficient because:
Test plan
Added unit tests for the rule in TestRewriteAggIfToAggFilter.java.
Added query plan tests in TestFilteredAggregations.java and covered all existing test cases for AGG() FILTER.