diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushInequalityFilterExpressionBelowJoinRuleSet.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushInequalityFilterExpressionBelowJoinRuleSet.java index 48aa5e681d7f..ee71080d7db4 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushInequalityFilterExpressionBelowJoinRuleSet.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushInequalityFilterExpressionBelowJoinRuleSet.java @@ -45,6 +45,7 @@ import static io.trino.sql.planner.SymbolsExtractor.extractUnique; import static io.trino.sql.planner.iterative.Rule.Context; import static io.trino.sql.planner.iterative.Rule.Result; +import static io.trino.sql.planner.plan.JoinNode.Type.INNER; import static io.trino.sql.planner.plan.Patterns.filter; import static io.trino.sql.planner.plan.Patterns.join; import static io.trino.sql.planner.plan.Patterns.source; @@ -118,7 +119,21 @@ public Rule pushJoinInequalityFilterExpressionBelowJoinRule() private Result pushInequalityFilterExpressionBelowJoin(Context context, JoinNode joinNode, Optional filterNode) { JoinNodeContext joinNodeContext = new JoinNodeContext(joinNode); - Map> parentFilterCandidates = extractPushDownCandidates(joinNodeContext, filterNode.map(FilterNode::getPredicate).orElse(TRUE_LITERAL)); + + Expression parentFilterPredicate = filterNode.map(FilterNode::getPredicate).orElse(TRUE_LITERAL); + Map> parentFilterCandidates; + if (joinNode.getType() == INNER) { + parentFilterCandidates = extractPushDownCandidates(joinNodeContext, parentFilterPredicate); + } + else { + // Do not push parent filter predicate for outer joins. Pushing it below join changes + // filter semantics because such filter depends on null output from outer join side + // (otherwise outer join would be converted to inner join by predicate push down). + parentFilterCandidates = ImmutableMap.of( + true, ImmutableList.of(), + false, extractConjuncts(parentFilterPredicate)); + } + Map> joinFilterCandidates = extractPushDownCandidates(joinNodeContext, joinNode.getFilter().orElse(TRUE_LITERAL)); if (parentFilterCandidates.get(true).isEmpty() && joinFilterCandidates.get(true).isEmpty()) { diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/TestLogicalPlanner.java b/core/trino-main/src/test/java/io/trino/sql/planner/TestLogicalPlanner.java index 521bfc5c0598..13c6c2334d9a 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/TestLogicalPlanner.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/TestLogicalPlanner.java @@ -451,6 +451,39 @@ public void testJoinWithOrderBySameKey() tableScan("lineitem", ImmutableMap.of("LINEITEM_OK", "orderkey")))))); } + @Test + public void testInequalityPredicatePushdownWithOuterJoin() + { + assertPlan("" + + "SELECT o.orderkey " + + "FROM orders o LEFT JOIN lineitem l " + + "ON o.orderkey = l.orderkey AND o.custkey + 42 < l.partkey + 42 " + + "WHERE o.custkey - 24 < COALESCE(l.partkey - 24, 0)", + anyTree( + // predicate above outer join is not pushed to build side + filter( + "O_CUSTKEY - BIGINT '24' < COALESCE(L_PARTKEY - BIGINT '24', BIGINT '0')", + join( + LEFT, + ImmutableList.of(equiJoinClause("O_ORDERKEY", "L_ORDERKEY")), + // part of inequality predicate within outer join is pushed down to build side + Optional.of("O_CUSTKEY + BIGINT '42' < EXPR"), + anyTree( + tableScan( + "orders", + ImmutableMap.of( + "O_ORDERKEY", "orderkey", + "O_CUSTKEY", "custkey"))), + anyTree( + project( + ImmutableMap.of("EXPR", expression("L_PARTKEY + BIGINT '42'")), + tableScan( + "lineitem", + ImmutableMap.of( + "L_ORDERKEY", "orderkey", + "L_PARTKEY", "partkey")))))))); + } + @Test public void testTopNPushdownToJoinSource() { @@ -1655,8 +1688,8 @@ public void testRedundantHashRemovalForMarkDistinct() node(MarkDistinctNode.class, anyTree( project(ImmutableMap.of( - "hash_1", expression("combine_hash(bigint '0', coalesce(\"$operator$hash_code\"(suppkey), 0))"), - "hash_2", expression("combine_hash(bigint '0', coalesce(\"$operator$hash_code\"(partkey), 0))")), + "hash_1", expression("combine_hash(bigint '0', coalesce(\"$operator$hash_code\"(suppkey), 0))"), + "hash_2", expression("combine_hash(bigint '0', coalesce(\"$operator$hash_code\"(partkey), 0))")), node(MarkDistinctNode.class, tableScan("lineitem", ImmutableMap.of("suppkey", "suppkey", "partkey", "partkey")))))))))); } diff --git a/core/trino-main/src/test/java/io/trino/sql/query/TestJoin.java b/core/trino-main/src/test/java/io/trino/sql/query/TestJoin.java index de5eaba2ecfa..ab7703176c0c 100644 --- a/core/trino-main/src/test/java/io/trino/sql/query/TestJoin.java +++ b/core/trino-main/src/test/java/io/trino/sql/query/TestJoin.java @@ -225,4 +225,24 @@ public void testOutputDuplicatesInsensitiveJoin() values()) .with(JoinNode.class, JoinNode::isMaySkipOutputDuplicates))))); } + + @Test + public void testPredicateOverOuterJoin() + { + assertThat(assertions.query( + "SELECT 5 " + + "FROM (VALUES (1,'foo')) l(l1, l2) " + + "LEFT JOIN (VALUES (2,'bar')) r(r1, r2) " + + "ON l2 = r2 " + + "WHERE l1 >= COALESCE(r1, 0)")) + .matches("VALUES 5"); + + assertThat(assertions.query( + "SELECT 5 " + + "FROM (VALUES (2,'foo')) l(l1, l2) " + + "RIGHT JOIN (VALUES (1,'bar')) r(r1, r2) " + + "ON l2 = r2 " + + "WHERE r1 >= COALESCE(l1, 0)")) + .matches("VALUES 5"); + } }