diff --git a/presto-main/src/main/java/com/facebook/presto/sql/analyzer/ExpressionTreeUtils.java b/presto-main/src/main/java/com/facebook/presto/sql/analyzer/ExpressionTreeUtils.java index 7709c593e8fb0..c1c5b6765e34c 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/analyzer/ExpressionTreeUtils.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/analyzer/ExpressionTreeUtils.java @@ -199,10 +199,14 @@ public static boolean isConstant(Expression expression) tempExpression = ((Cast) tempExpression).getExpression(); } - if (tempExpression instanceof Literal || tempExpression instanceof ArrayConstructor) { + if (tempExpression instanceof Literal) { return true; } + if (tempExpression instanceof ArrayConstructor) { + return ((ArrayConstructor) tempExpression).getValues().stream().allMatch(ExpressionTreeUtils::isConstant); + } + // ROW an MAP are special so we explicitly do that here. if (tempExpression instanceof Row) { return (((Row) tempExpression).getItems().stream().allMatch(ExpressionTreeUtils::isConstant)); diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/TranslationMap.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/TranslationMap.java index 0c9bb875f66a0..e06cf4a67e31b 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/TranslationMap.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/TranslationMap.java @@ -40,6 +40,7 @@ import static com.facebook.presto.common.type.TypeUtils.isEnumType; import static com.facebook.presto.sql.analyzer.ExpressionTreeUtils.createSymbolReference; import static com.facebook.presto.sql.analyzer.ExpressionTreeUtils.getNodeLocation; +import static com.facebook.presto.sql.analyzer.ExpressionTreeUtils.isConstant; import static com.facebook.presto.sql.analyzer.ExpressionTreeUtils.resolveEnumLiteral; import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.base.Preconditions.checkState; @@ -119,19 +120,27 @@ public Expression rewrite(Expression expression) Expression mapped = translateNamesToSymbols(expression); // then rewrite subexpressions in terms of the current mappings - return ExpressionTreeRewriter.rewriteWith(new ExpressionRewriter() + return ExpressionTreeRewriter.rewriteWith(new ExpressionRewriter() { @Override - public Expression rewriteExpression(Expression node, Void context, ExpressionTreeRewriter treeRewriter) + public Expression rewriteExpression(Expression node, Boolean context, ExpressionTreeRewriter treeRewriter) { - if (expressionToVariables.containsKey(node)) { + // Do not rewrite if node is constant and within a lambda expression + if (expressionToVariables.containsKey(node) && !((context.equals(Boolean.TRUE) && isConstant(node)))) { return new SymbolReference(expression.getLocation(), expressionToVariables.get(node).getName()); } Expression translated = expressionToExpressions.getOrDefault(node, node); return treeRewriter.defaultRewrite(translated, context); } - }, mapped); + + @Override + public Expression rewriteLambdaExpression(LambdaExpression node, Boolean context, ExpressionTreeRewriter treeRewriter) + { + Expression result = super.rewriteLambdaExpression(node, true, treeRewriter); + return result; + } + }, mapped, false); } public void put(Expression expression, VariableReferenceExpression variable) diff --git a/presto-tests/src/main/java/com/facebook/presto/tests/AbstractTestQueries.java b/presto-tests/src/main/java/com/facebook/presto/tests/AbstractTestQueries.java index bc2377b1e3717..f23cc31cdd349 100644 --- a/presto-tests/src/main/java/com/facebook/presto/tests/AbstractTestQueries.java +++ b/presto-tests/src/main/java/com/facebook/presto/tests/AbstractTestQueries.java @@ -7514,4 +7514,14 @@ public void testGuardConstraintFramework() assertQuery("select orderkey from (select * from (select * from orders where 1=0)) group by rollup(orderkey)", "values (null)"); } + + @Test + public void testLambdaInAggregation() + { + assertQuery("SELECT id, reduce_agg(value, 0, (a, b) -> a + b+0, (a, b) -> a + b) FROM ( VALUES (1, 2), (1, 3), (1, 4), (2, 20), (2, 30), (2, 40) ) AS t(id, value) GROUP BY id", "values (1, 9), (2, 90)"); + assertQuery("SELECT id, reduce_agg(value, 's', (a, b) -> concat(a, b, 's'), (a, b) -> concat(a, b, 's')) FROM ( VALUES (1, '2'), (1, '3'), (1, '4'), (2, '20'), (2, '30'), (2, '40') ) AS t(id, value) GROUP BY id", + "values (1, 's2s3s4s'), (2, 's20s30s40s')"); + assertQueryFails("SELECT id, reduce_agg(value, array[id, value], (a, b) -> a || b, (a, b) -> a || b) FROM ( VALUES (1, 2), (1, 3), (1, 4), (2, 20), (2, 30), (2, 40) ) AS t(id, value) GROUP BY id", + ".*REDUCE_AGG only supports non-NULL literal as the initial value.*"); + } }