diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/CoalesceExpressionRewriter.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/CoalesceExpressionRewriter.java new file mode 100644 index 0000000000000..2ec9a59be95b7 --- /dev/null +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/CoalesceExpressionRewriter.java @@ -0,0 +1,73 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.sql.planner.iterative.rule; + +import com.facebook.presto.Session; +import com.facebook.presto.sql.planner.DeterminismEvaluator; +import com.facebook.presto.sql.planner.TypeProvider; +import com.facebook.presto.sql.planner.optimizations.ExpressionEquivalence; +import com.facebook.presto.sql.relational.RowExpression; +import com.facebook.presto.sql.tree.CoalesceExpression; +import com.facebook.presto.sql.tree.Expression; +import com.facebook.presto.sql.tree.ExpressionRewriter; +import com.facebook.presto.sql.tree.ExpressionTreeRewriter; +import com.google.common.collect.ImmutableList; + +import java.util.HashMap; +import java.util.Map; + +import static java.util.Objects.requireNonNull; + +public class CoalesceExpressionRewriter +{ + public static Expression simplifyCoalesceExpression(Session session, ExpressionEquivalence expressionEquivalence, Expression expression, TypeProvider typeProvider) + { + return ExpressionTreeRewriter.rewriteWith(new Visitor(session, expressionEquivalence, typeProvider), expression); + } + + private CoalesceExpressionRewriter() {} + + private static class Visitor + extends ExpressionRewriter + { + private final ExpressionEquivalence expressionEquivalence; + private final Session session; + private final TypeProvider typeProvider; + + private Visitor(Session session, ExpressionEquivalence expressionEquivalence, TypeProvider typeProvider) + { + this.session = requireNonNull(session, "Session is null"); + this.expressionEquivalence = requireNonNull(expressionEquivalence, "ExpressionEquivalence is null"); + this.typeProvider = requireNonNull(typeProvider, "TypeProvider is null"); + } + + @Override + public Expression rewriteCoalesceExpression(CoalesceExpression node, Void context, ExpressionTreeRewriter treeRewriter) + { + ImmutableList.Builder operandsBuilder = ImmutableList.builder(); + Map rowExpressionMap = new HashMap<>(); + for (Expression operand : node.getOperands()) { + if (DeterminismEvaluator.isDeterministic(operand)) { + RowExpression rowExpression = expressionEquivalence.toCanonicalizedRowExpression(session, operand, typeProvider); + rowExpressionMap.putIfAbsent(rowExpression, operand); + operandsBuilder.add(rowExpressionMap.get(rowExpression)); + } + else { + operandsBuilder.add(operand); + } + } + return new CoalesceExpression(operandsBuilder.build()); + } + } +} diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/SimplifyExpressions.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/SimplifyExpressions.java index a70b3a35dd4b3..dde895145cefb 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/SimplifyExpressions.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/iterative/rule/SimplifyExpressions.java @@ -23,6 +23,7 @@ import com.facebook.presto.sql.planner.NoOpSymbolResolver; import com.facebook.presto.sql.planner.SymbolAllocator; import com.facebook.presto.sql.planner.iterative.Rule; +import com.facebook.presto.sql.planner.optimizations.ExpressionEquivalence; import com.facebook.presto.sql.tree.Expression; import com.facebook.presto.sql.tree.NodeRef; import com.facebook.presto.sql.tree.SymbolReference; @@ -33,6 +34,7 @@ import java.util.Set; import static com.facebook.presto.sql.analyzer.ExpressionAnalyzer.getExpressionTypes; +import static com.facebook.presto.sql.planner.iterative.rule.CoalesceExpressionRewriter.simplifyCoalesceExpression; import static com.facebook.presto.sql.planner.iterative.rule.ExtractCommonPredicatesExpressionRewriter.extractCommonPredicates; import static com.facebook.presto.sql.planner.iterative.rule.PushDownNegationsExpressionRewriter.pushDownNegations; import static java.util.Collections.emptyList; @@ -49,6 +51,7 @@ static Expression rewrite(Expression expression, Session session, SymbolAllocato if (expression instanceof SymbolReference) { return expression; } + expression = simplifyCoalesceExpression(session, new ExpressionEquivalence(metadata, sqlParser), expression, symbolAllocator.getTypes()); expression = pushDownNegations(expression); expression = extractCommonPredicates(expression); Map, Type> expressionTypes = getExpressionTypes(session, metadata, sqlParser, symbolAllocator.getTypes(), expression, emptyList(), WarningCollector.NOOP); diff --git a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/ExpressionEquivalence.java b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/ExpressionEquivalence.java index 0e67c84a4a7f1..7747810964896 100644 --- a/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/ExpressionEquivalence.java +++ b/presto-main/src/main/java/com/facebook/presto/sql/planner/optimizations/ExpressionEquivalence.java @@ -96,6 +96,21 @@ public boolean areExpressionsEquivalent(Session session, Expression leftExpressi return canonicalizedLeft.equals(canonicalizedRight); } + public RowExpression toCanonicalizedRowExpression(Session session, Expression expression, TypeProvider types) + { + Map symbolInput = new HashMap<>(); + Map inputTypes = new HashMap<>(); + int inputId = 0; + for (Entry entry : types.allTypes().entrySet()) { + symbolInput.put(entry.getKey(), inputId); + inputTypes.put(inputId, entry.getValue()); + inputId++; + } + RowExpression rowExpression = toRowExpression(session, expression, symbolInput, inputTypes); + + return rowExpression.accept(CANONICALIZATION_VISITOR, null); + } + private RowExpression toRowExpression(Session session, Expression expression, Map symbolInput, Map inputTypes) { // replace qualified names with input references since row expressions do not support these diff --git a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestSimplifyExpressions.java b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestSimplifyExpressions.java index cca5a51a7af4f..b88c6161315d4 100644 --- a/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestSimplifyExpressions.java +++ b/presto-main/src/test/java/com/facebook/presto/sql/planner/iterative/rule/TestSimplifyExpressions.java @@ -33,6 +33,7 @@ import static com.facebook.presto.SessionTestUtils.TEST_SESSION; import static com.facebook.presto.metadata.MetadataManager.createTestMetadataManager; +import static com.facebook.presto.spi.type.BigintType.BIGINT; import static com.facebook.presto.spi.type.BooleanType.BOOLEAN; import static com.facebook.presto.sql.ExpressionUtils.binaryExpression; import static com.facebook.presto.sql.ExpressionUtils.extractPredicates; @@ -114,6 +115,32 @@ public void testExtractCommonPredicates() " OR (A51 AND A52) OR (A53 AND A54) OR (A55 AND A56) OR (A57 AND A58) OR (A59 AND A60)"); } + @Test + public void testSimplifyCoalesceExpression() + { + Expression actualExpression = rewriteIdentifiersToSymbolReferences(SQL_PARSER.createExpression("coalesce(unbound_long > 6, 6 < unbound_long)")); + Expression expectedExpression = rewriteIdentifiersToSymbolReferences(SQL_PARSER.createExpression("coalesce(unbound_long > 6, unbound_long > 6)")); + SymbolAllocator symbolAllocator = new SymbolAllocator(); + symbolAllocator.newSymbol("unbound_long", BIGINT); + Expression rewritten = rewrite(actualExpression, TEST_SESSION, symbolAllocator, METADATA, LITERAL_ENCODER, SQL_PARSER); + assertEquals( + normalize(rewritten), + normalize(expectedExpression)); + } + + @Test + public void testSimplifyCoalesceExpressionForNonDeterministic() + { + Expression actualExpression = rewriteIdentifiersToSymbolReferences(SQL_PARSER.createExpression("coalesce(random() > DOUBLE '0.6', DOUBLE '0.6' < random())")); + Expression expectedExpression = rewriteIdentifiersToSymbolReferences(SQL_PARSER.createExpression("coalesce(random() > 6E-1, 6E-1 < random())")); + SymbolAllocator symbolAllocator = new SymbolAllocator(); + symbolAllocator.newSymbol("unbound_long", BIGINT); + Expression rewritten = rewrite(actualExpression, TEST_SESSION, symbolAllocator, METADATA, LITERAL_ENCODER, SQL_PARSER); + assertEquals( + normalize(rewritten), + normalize(expectedExpression)); + } + private static void assertSimplifies(String expression, String expected) { Expression actualExpression = rewriteIdentifiersToSymbolReferences(SQL_PARSER.createExpression(expression));