diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushInequalityFilterExpressionBelowJoinRuleSet.java b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushInequalityFilterExpressionBelowJoinRuleSet.java index 48aa5e681d7f..52bf25da4bf3 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushInequalityFilterExpressionBelowJoinRuleSet.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/iterative/rule/PushInequalityFilterExpressionBelowJoinRuleSet.java @@ -45,6 +45,7 @@ import static io.trino.sql.planner.SymbolsExtractor.extractUnique; import static io.trino.sql.planner.iterative.Rule.Context; import static io.trino.sql.planner.iterative.Rule.Result; +import static io.trino.sql.planner.plan.Patterns.Join.type; import static io.trino.sql.planner.plan.Patterns.filter; import static io.trino.sql.planner.plan.Patterns.join; import static io.trino.sql.planner.plan.Patterns.source; @@ -84,10 +85,10 @@ public class PushInequalityFilterExpressionBelowJoinRuleSet GREATER_THAN_OR_EQUAL, LESS_THAN, LESS_THAN_OR_EQUAL); - private static final Pattern JOIN_PATTERN = join(); + private static final Pattern JOIN_PATTERN = join().with(type().equalTo(JoinNode.Type.INNER)); private static final Capture JOIN_CAPTURE = newCapture(); private static final Pattern FILTER_PATTERN = filter().with(source().matching( - join().capturedAs(JOIN_CAPTURE))); + join().with(type().equalTo(JoinNode.Type.INNER)).capturedAs(JOIN_CAPTURE))); private final Metadata metadata; private final TypeAnalyzer typeAnalyzer; diff --git a/core/trino-main/src/test/java/io/trino/sql/query/TestJoin.java b/core/trino-main/src/test/java/io/trino/sql/query/TestJoin.java index de5eaba2ecfa..ab7703176c0c 100644 --- a/core/trino-main/src/test/java/io/trino/sql/query/TestJoin.java +++ b/core/trino-main/src/test/java/io/trino/sql/query/TestJoin.java @@ -225,4 +225,24 @@ public void testOutputDuplicatesInsensitiveJoin() values()) .with(JoinNode.class, JoinNode::isMaySkipOutputDuplicates))))); } + + @Test + public void testPredicateOverOuterJoin() + { + assertThat(assertions.query( + "SELECT 5 " + + "FROM (VALUES (1,'foo')) l(l1, l2) " + + "LEFT JOIN (VALUES (2,'bar')) r(r1, r2) " + + "ON l2 = r2 " + + "WHERE l1 >= COALESCE(r1, 0)")) + .matches("VALUES 5"); + + assertThat(assertions.query( + "SELECT 5 " + + "FROM (VALUES (2,'foo')) l(l1, l2) " + + "RIGHT JOIN (VALUES (1,'bar')) r(r1, r2) " + + "ON l2 = r2 " + + "WHERE r1 >= COALESCE(l1, 0)")) + .matches("VALUES 5"); + } }