diff --git a/presto-main/src/main/java/io/prestosql/sql/ExpressionUtils.java b/presto-main/src/main/java/io/prestosql/sql/ExpressionUtils.java index 9fe7249993c8..1f82800f5406 100644 --- a/presto-main/src/main/java/io/prestosql/sql/ExpressionUtils.java +++ b/presto-main/src/main/java/io/prestosql/sql/ExpressionUtils.java @@ -19,7 +19,6 @@ import io.prestosql.sql.planner.DeterminismEvaluator; import io.prestosql.sql.planner.Symbol; import io.prestosql.sql.planner.SymbolsExtractor; -import io.prestosql.sql.tree.ComparisonExpression; import io.prestosql.sql.tree.Expression; import io.prestosql.sql.tree.ExpressionRewriter; import io.prestosql.sql.tree.ExpressionTreeRewriter; @@ -28,7 +27,6 @@ import io.prestosql.sql.tree.LambdaExpression; import io.prestosql.sql.tree.LogicalBinaryExpression; import io.prestosql.sql.tree.LogicalBinaryExpression.Operator; -import io.prestosql.sql.tree.NotExpression; import io.prestosql.sql.tree.SymbolReference; import java.util.ArrayDeque; @@ -45,7 +43,6 @@ import static com.google.common.collect.ImmutableList.toImmutableList; import static io.prestosql.sql.tree.BooleanLiteral.FALSE_LITERAL; import static io.prestosql.sql.tree.BooleanLiteral.TRUE_LITERAL; -import static io.prestosql.sql.tree.ComparisonExpression.Operator.IS_DISTINCT_FROM; import static java.util.Objects.requireNonNull; import static java.util.stream.Collectors.toList; @@ -330,21 +327,6 @@ else if (!seen.contains(expression)) { return result.build(); } - public static Expression normalize(Expression expression) - { - if (expression instanceof NotExpression) { - NotExpression not = (NotExpression) expression; - if (not.getValue() instanceof ComparisonExpression && ((ComparisonExpression) not.getValue()).getOperator() != IS_DISTINCT_FROM) { - ComparisonExpression comparison = (ComparisonExpression) not.getValue(); - return new ComparisonExpression(comparison.getOperator().negate(), comparison.getLeft(), comparison.getRight()); - } - if (not.getValue() instanceof NotExpression) { - return normalize(((NotExpression) not.getValue()).getValue()); - } - } - return expression; - } - public static Expression rewriteIdentifiersToSymbolReferences(Expression expression) { return ExpressionTreeRewriter.rewriteWith(new ExpressionRewriter() diff --git a/presto-main/src/main/java/io/prestosql/sql/planner/RelationPlanner.java b/presto-main/src/main/java/io/prestosql/sql/planner/RelationPlanner.java index 1fd08e6cf974..15a0acf10754 100644 --- a/presto-main/src/main/java/io/prestosql/sql/planner/RelationPlanner.java +++ b/presto-main/src/main/java/io/prestosql/sql/planner/RelationPlanner.java @@ -333,8 +333,6 @@ protected RelationPlan visitJoin(Join node, Void context) List joinConditionComparisonOperators = new ArrayList<>(); for (Expression conjunct : ExpressionUtils.extractConjuncts(criteria)) { - conjunct = ExpressionUtils.normalize(conjunct); - if (!isEqualComparisonExpression(conjunct) && node.getType() != INNER) { complexJoinExpressions.add(conjunct); continue; diff --git a/presto-main/src/test/java/io/prestosql/sql/TestExpressionUtils.java b/presto-main/src/test/java/io/prestosql/sql/TestExpressionUtils.java index b5b103a53104..8a6b48d220a8 100644 --- a/presto-main/src/test/java/io/prestosql/sql/TestExpressionUtils.java +++ b/presto-main/src/test/java/io/prestosql/sql/TestExpressionUtils.java @@ -14,24 +14,12 @@ package io.prestosql.sql; import io.prestosql.metadata.Metadata; -import io.prestosql.sql.tree.ComparisonExpression; import io.prestosql.sql.tree.Expression; import io.prestosql.sql.tree.Identifier; -import io.prestosql.sql.tree.IsNullPredicate; -import io.prestosql.sql.tree.LikePredicate; import io.prestosql.sql.tree.LogicalBinaryExpression; -import io.prestosql.sql.tree.LongLiteral; -import io.prestosql.sql.tree.NotExpression; -import io.prestosql.sql.tree.StringLiteral; import org.testng.annotations.Test; -import java.util.Optional; - import static io.prestosql.metadata.MetadataManager.createTestMetadataManager; -import static io.prestosql.sql.ExpressionUtils.normalize; -import static io.prestosql.sql.tree.ComparisonExpression.Operator.EQUAL; -import static io.prestosql.sql.tree.ComparisonExpression.Operator.IS_DISTINCT_FROM; -import static io.prestosql.sql.tree.ComparisonExpression.Operator.NOT_EQUAL; import static org.testng.Assert.assertEquals; public class TestExpressionUtils @@ -56,32 +44,6 @@ public void testAnd() and(and(and(a, b), and(c, d)), e)); } - @Test - public void testNormalize() - { - assertNormalize(new ComparisonExpression(EQUAL, name("a"), new LongLiteral("1"))); - assertNormalize(new IsNullPredicate(name("a"))); - assertNormalize(new NotExpression(new LikePredicate(name("a"), new StringLiteral("x%"), Optional.empty()))); - assertNormalize( - new NotExpression(new ComparisonExpression(EQUAL, name("a"), new LongLiteral("1"))), - new ComparisonExpression(NOT_EQUAL, name("a"), new LongLiteral("1"))); - assertNormalize( - new NotExpression(new ComparisonExpression(NOT_EQUAL, name("a"), new LongLiteral("1"))), - new ComparisonExpression(EQUAL, name("a"), new LongLiteral("1"))); - // Cannot normalize IS DISTINCT FROM yet - assertNormalize(new NotExpression(new ComparisonExpression(IS_DISTINCT_FROM, name("a"), new LongLiteral("1")))); - } - - private static void assertNormalize(Expression expression) - { - assertNormalize(expression, expression); - } - - private static void assertNormalize(Expression expression, Expression normalized) - { - assertEquals(normalize(expression), normalized); - } - private static Identifier name(String name) { return new Identifier(name); diff --git a/presto-main/src/test/java/io/prestosql/sql/query/TestJoin.java b/presto-main/src/test/java/io/prestosql/sql/query/TestJoin.java index 606c8d5b5857..d1c8ce7ce8e8 100644 --- a/presto-main/src/test/java/io/prestosql/sql/query/TestJoin.java +++ b/presto-main/src/test/java/io/prestosql/sql/query/TestJoin.java @@ -52,4 +52,13 @@ public void testCrossJoinEliminationWithOuterJoin() "JOIN d ON d.id = a.id")) .matches("VALUES 1"); } + + @Test + public void testJoinOnNan() + { + assertThat(assertions.query( + "WITH t(x) AS (VALUES if(rand() > 0, nan())) " + // TODO: remove if(rand() > 0, ...) once https://github.com/prestosql/presto/issues/4119 is fixed + "SELECT * FROM t t1 JOIN t t2 ON NOT t1.x < t2.x")) + .matches("VALUES (nan(), nan())"); + } }