Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.facebook.presto.sql.planner.iterative.rule;

import com.facebook.presto.Session;
import com.facebook.presto.sql.planner.DeterminismEvaluator;
import com.facebook.presto.sql.planner.TypeProvider;
import com.facebook.presto.sql.planner.optimizations.ExpressionEquivalence;
import com.facebook.presto.sql.relational.RowExpression;
import com.facebook.presto.sql.tree.CoalesceExpression;
import com.facebook.presto.sql.tree.Expression;
import com.facebook.presto.sql.tree.ExpressionRewriter;
import com.facebook.presto.sql.tree.ExpressionTreeRewriter;
import com.google.common.collect.ImmutableList;

import java.util.HashMap;
import java.util.Map;

import static java.util.Objects.requireNonNull;

public class CoalesceExpressionRewriter
{
public static Expression simplifyCoalesceExpression(Session session, ExpressionEquivalence expressionEquivalence, Expression expression, TypeProvider typeProvider)
{
return ExpressionTreeRewriter.rewriteWith(new Visitor(session, expressionEquivalence, typeProvider), expression);
}

private CoalesceExpressionRewriter() {}

private static class Visitor
extends ExpressionRewriter<Void>
{
private final ExpressionEquivalence expressionEquivalence;
private final Session session;
private final TypeProvider typeProvider;

private Visitor(Session session, ExpressionEquivalence expressionEquivalence, TypeProvider typeProvider)
{
this.session = requireNonNull(session, "Session is null");
this.expressionEquivalence = requireNonNull(expressionEquivalence, "ExpressionEquivalence is null");
this.typeProvider = requireNonNull(typeProvider, "TypeProvider is null");
}

@Override
public Expression rewriteCoalesceExpression(CoalesceExpression node, Void context, ExpressionTreeRewriter<Void> treeRewriter)
{
ImmutableList.Builder<Expression> operandsBuilder = ImmutableList.builder();
Map<RowExpression, Expression> rowExpressionMap = new HashMap<>();
for (Expression operand : node.getOperands()) {
if (DeterminismEvaluator.isDeterministic(operand)) {
RowExpression rowExpression = expressionEquivalence.toCanonicalizedRowExpression(session, operand, typeProvider);
rowExpressionMap.putIfAbsent(rowExpression, operand);
operandsBuilder.add(rowExpressionMap.get(rowExpression));
}
else {
operandsBuilder.add(operand);
}
}
return new CoalesceExpression(operandsBuilder.build());
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import com.facebook.presto.sql.planner.NoOpSymbolResolver;
import com.facebook.presto.sql.planner.SymbolAllocator;
import com.facebook.presto.sql.planner.iterative.Rule;
import com.facebook.presto.sql.planner.optimizations.ExpressionEquivalence;
import com.facebook.presto.sql.tree.Expression;
import com.facebook.presto.sql.tree.NodeRef;
import com.facebook.presto.sql.tree.SymbolReference;
Expand All @@ -33,6 +34,7 @@
import java.util.Set;

import static com.facebook.presto.sql.analyzer.ExpressionAnalyzer.getExpressionTypes;
import static com.facebook.presto.sql.planner.iterative.rule.CoalesceExpressionRewriter.simplifyCoalesceExpression;
import static com.facebook.presto.sql.planner.iterative.rule.ExtractCommonPredicatesExpressionRewriter.extractCommonPredicates;
import static com.facebook.presto.sql.planner.iterative.rule.PushDownNegationsExpressionRewriter.pushDownNegations;
import static java.util.Collections.emptyList;
Expand All @@ -49,6 +51,7 @@ static Expression rewrite(Expression expression, Session session, SymbolAllocato
if (expression instanceof SymbolReference) {
return expression;
}
expression = simplifyCoalesceExpression(session, new ExpressionEquivalence(metadata, sqlParser), expression, symbolAllocator.getTypes());
expression = pushDownNegations(expression);
expression = extractCommonPredicates(expression);
Map<NodeRef<Expression>, Type> expressionTypes = getExpressionTypes(session, metadata, sqlParser, symbolAllocator.getTypes(), expression, emptyList(), WarningCollector.NOOP);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,21 @@ public boolean areExpressionsEquivalent(Session session, Expression leftExpressi
return canonicalizedLeft.equals(canonicalizedRight);
}

public RowExpression toCanonicalizedRowExpression(Session session, Expression expression, TypeProvider types)
{
Map<Symbol, Integer> symbolInput = new HashMap<>();
Map<Integer, Type> inputTypes = new HashMap<>();
int inputId = 0;
for (Entry<Symbol, Type> entry : types.allTypes().entrySet()) {
symbolInput.put(entry.getKey(), inputId);
inputTypes.put(inputId, entry.getValue());
inputId++;
}
RowExpression rowExpression = toRowExpression(session, expression, symbolInput, inputTypes);

return rowExpression.accept(CANONICALIZATION_VISITOR, null);
}

private RowExpression toRowExpression(Session session, Expression expression, Map<Symbol, Integer> symbolInput, Map<Integer, Type> inputTypes)
{
// replace qualified names with input references since row expressions do not support these
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@

import static com.facebook.presto.SessionTestUtils.TEST_SESSION;
import static com.facebook.presto.metadata.MetadataManager.createTestMetadataManager;
import static com.facebook.presto.spi.type.BigintType.BIGINT;
import static com.facebook.presto.spi.type.BooleanType.BOOLEAN;
import static com.facebook.presto.sql.ExpressionUtils.binaryExpression;
import static com.facebook.presto.sql.ExpressionUtils.extractPredicates;
Expand Down Expand Up @@ -114,6 +115,32 @@ public void testExtractCommonPredicates()
" OR (A51 AND A52) OR (A53 AND A54) OR (A55 AND A56) OR (A57 AND A58) OR (A59 AND A60)");
}

@Test
public void testSimplifyCoalesceExpression()
{
Expression actualExpression = rewriteIdentifiersToSymbolReferences(SQL_PARSER.createExpression("coalesce(unbound_long > 6, 6 < unbound_long)"));
Expression expectedExpression = rewriteIdentifiersToSymbolReferences(SQL_PARSER.createExpression("coalesce(unbound_long > 6, unbound_long > 6)"));
SymbolAllocator symbolAllocator = new SymbolAllocator();
symbolAllocator.newSymbol("unbound_long", BIGINT);
Expression rewritten = rewrite(actualExpression, TEST_SESSION, symbolAllocator, METADATA, LITERAL_ENCODER, SQL_PARSER);
assertEquals(
normalize(rewritten),
normalize(expectedExpression));
}

@Test
public void testSimplifyCoalesceExpressionForNonDeterministic()
{
Expression actualExpression = rewriteIdentifiersToSymbolReferences(SQL_PARSER.createExpression("coalesce(random() > DOUBLE '0.6', DOUBLE '0.6' < random())"));
Expression expectedExpression = rewriteIdentifiersToSymbolReferences(SQL_PARSER.createExpression("coalesce(random() > 6E-1, 6E-1 < random())"));
SymbolAllocator symbolAllocator = new SymbolAllocator();
symbolAllocator.newSymbol("unbound_long", BIGINT);
Expression rewritten = rewrite(actualExpression, TEST_SESSION, symbolAllocator, METADATA, LITERAL_ENCODER, SQL_PARSER);
assertEquals(
normalize(rewritten),
normalize(expectedExpression));
}

private static void assertSimplifies(String expression, String expected)
{
Expression actualExpression = rewriteIdentifiersToSymbolReferences(SQL_PARSER.createExpression(expression));
Expand Down