-
Notifications
You must be signed in to change notification settings - Fork 5.5k
Fix analyzer for lambda in aggregation #22539
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -40,6 +40,7 @@ | |
| import static com.facebook.presto.common.type.TypeUtils.isEnumType; | ||
| import static com.facebook.presto.sql.analyzer.ExpressionTreeUtils.createSymbolReference; | ||
| import static com.facebook.presto.sql.analyzer.ExpressionTreeUtils.getNodeLocation; | ||
| import static com.facebook.presto.sql.analyzer.ExpressionTreeUtils.isConstant; | ||
| import static com.facebook.presto.sql.analyzer.ExpressionTreeUtils.resolveEnumLiteral; | ||
| import static com.google.common.base.Preconditions.checkArgument; | ||
| import static com.google.common.base.Preconditions.checkState; | ||
|
|
@@ -119,19 +120,27 @@ public Expression rewrite(Expression expression) | |
| Expression mapped = translateNamesToSymbols(expression); | ||
|
|
||
| // then rewrite subexpressions in terms of the current mappings | ||
| return ExpressionTreeRewriter.rewriteWith(new ExpressionRewriter<Void>() | ||
| return ExpressionTreeRewriter.rewriteWith(new ExpressionRewriter<Boolean>() | ||
| { | ||
| @Override | ||
| public Expression rewriteExpression(Expression node, Void context, ExpressionTreeRewriter<Void> treeRewriter) | ||
| public Expression rewriteExpression(Expression node, Boolean context, ExpressionTreeRewriter<Boolean> treeRewriter) | ||
| { | ||
| if (expressionToVariables.containsKey(node)) { | ||
| // Do not rewrite if node is constant and within a lambda expression | ||
| if (expressionToVariables.containsKey(node) && !((context.equals(Boolean.TRUE) && isConstant(node)))) { | ||
| return new SymbolReference(expression.getLocation(), expressionToVariables.get(node).getName()); | ||
| } | ||
|
|
||
| Expression translated = expressionToExpressions.getOrDefault(node, node); | ||
| return treeRewriter.defaultRewrite(translated, context); | ||
| } | ||
| }, mapped); | ||
|
|
||
| @Override | ||
| public Expression rewriteLambdaExpression(LambdaExpression node, Boolean context, ExpressionTreeRewriter<Boolean> treeRewriter) | ||
| { | ||
| Expression result = super.rewriteLambdaExpression(node, true, treeRewriter); | ||
|
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. The context is set to true while rewriting inside a lambda expression. |
||
| return result; | ||
| } | ||
| }, mapped, false); | ||
| } | ||
|
|
||
| public void put(Expression expression, VariableReferenceExpression variable) | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -7514,4 +7514,14 @@ public void testGuardConstraintFramework() | |
| assertQuery("select orderkey from (select * from (select * from orders where 1=0)) group by rollup(orderkey)", | ||
| "values (null)"); | ||
| } | ||
|
|
||
| @Test | ||
| public void testLambdaInAggregation() | ||
|
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. These two queries fail without the fix in this PR. |
||
| { | ||
| assertQuery("SELECT id, reduce_agg(value, 0, (a, b) -> a + b+0, (a, b) -> a + b) FROM ( VALUES (1, 2), (1, 3), (1, 4), (2, 20), (2, 30), (2, 40) ) AS t(id, value) GROUP BY id", "values (1, 9), (2, 90)"); | ||
| assertQuery("SELECT id, reduce_agg(value, 's', (a, b) -> concat(a, b, 's'), (a, b) -> concat(a, b, 's')) FROM ( VALUES (1, '2'), (1, '3'), (1, '4'), (2, '20'), (2, '30'), (2, '40') ) AS t(id, value) GROUP BY id", | ||
| "values (1, 's2s3s4s'), (2, 's20s30s40s')"); | ||
| assertQueryFails("SELECT id, reduce_agg(value, array[id, value], (a, b) -> a || b, (a, b) -> a || b) FROM ( VALUES (1, 2), (1, 3), (1, 4), (2, 20), (2, 30), (2, 40) ) AS t(id, value) GROUP BY id", | ||
|
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This covers the change to the isConstant(Expression expression) function. Before the change, "array[id, value]" was considered constant and this query passed. Now it throws an exception.
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. @feilong-liu explained to me that the reason we want this to fail is because semantically it doesn't make sense for the initial state of the lambda function to depend on the value of a column (because which row of the column is it even talking about? It wouldn't even be consistent within a query because on any given worker it would depend on which row it happened to read first) |
||
| ".*REDUCE_AGG only supports non-NULL literal as the initial value.*"); | ||
| } | ||
| } | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The elements in array constructor should also be Literals
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you add a test for this?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Added one more test for this in AbstractTestQueries.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This util is used for the reduce_agg function, and it will change the behaviour for the reduce_agg function.
For example, currently the query
`SELECT id, reduce_agg(value, array[id, value], (a, b) -> a || b, (a, b) -> a || b) FROM ( VALUES (1, 2), (1, 3), (1, 4), (2, 20), (2, 30), (2, 40) ) AS t(id, value) GROUP BY id` will succeed, as it considers the expression `array[id, value]` to be constant. However, after this fix, it will fail with the error
`REDUCE_AGG only supports non-NULL literal as the initial value`, as `array[id, value]` is no longer considered constant. cc @kaikalur