diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/ConnectorExpressionTranslator.java b/core/trino-main/src/main/java/io/trino/sql/planner/ConnectorExpressionTranslator.java index d53d3a135aea..b24d6d56366d 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/ConnectorExpressionTranslator.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/ConnectorExpressionTranslator.java @@ -39,6 +39,7 @@ import io.trino.sql.tree.ArithmeticBinaryExpression; import io.trino.sql.tree.ArithmeticUnaryExpression; import io.trino.sql.tree.AstVisitor; +import io.trino.sql.tree.BetweenPredicate; import io.trino.sql.tree.BinaryLiteral; import io.trino.sql.tree.BooleanLiteral; import io.trino.sql.tree.Cast; @@ -534,6 +535,23 @@ protected Optional visitArithmeticBinary(ArithmeticBinaryEx new Call(typeOf(node), functionNameForArithmeticBinaryOperator(node.getOperator()), ImmutableList.of(left, right)))); } + @Override + protected Optional visitBetweenPredicate(BetweenPredicate node, Void context) + { + if (!isComplexExpressionPushdown(session)) { + return Optional.empty(); + } + return process(node.getValue()).flatMap(value -> + process(node.getMin()).flatMap(min -> + process(node.getMax()).map(max -> + new Call( + BOOLEAN, + AND_FUNCTION_NAME, + ImmutableList.of( + new Call(BOOLEAN, GREATER_THAN_OR_EQUAL_OPERATOR_FUNCTION_NAME, ImmutableList.of(value, min)), + new Call(BOOLEAN, LESS_THAN_OR_EQUAL_OPERATOR_FUNCTION_NAME, ImmutableList.of(value, max))))))); + } + @Override protected Optional visitArithmeticUnary(ArithmeticUnaryExpression node, Void context) { diff --git a/core/trino-main/src/test/java/io/trino/sql/planner/TestConnectorExpressionTranslator.java b/core/trino-main/src/test/java/io/trino/sql/planner/TestConnectorExpressionTranslator.java index e839c3aba3c4..8d65ed12d75a 100644 --- a/core/trino-main/src/test/java/io/trino/sql/planner/TestConnectorExpressionTranslator.java +++ b/core/trino-main/src/test/java/io/trino/sql/planner/TestConnectorExpressionTranslator.java @@ -27,7 +27,9 @@ import io.trino.spi.type.Type; import io.trino.sql.tree.ArithmeticBinaryExpression; import io.trino.sql.tree.ArithmeticUnaryExpression; +import io.trino.sql.tree.BetweenPredicate; import io.trino.sql.tree.ComparisonExpression; +import io.trino.sql.tree.DoubleLiteral; import io.trino.sql.tree.Expression; import io.trino.sql.tree.IsNotNullPredicate; import io.trino.sql.tree.IsNullPredicate; @@ -52,7 +54,10 @@ import static com.google.common.collect.ImmutableMap.toImmutableMap; import static io.airlift.slice.Slices.utf8Slice; +import static io.trino.spi.expression.StandardFunctions.AND_FUNCTION_NAME; +import static io.trino.spi.expression.StandardFunctions.GREATER_THAN_OR_EQUAL_OPERATOR_FUNCTION_NAME; import static io.trino.spi.expression.StandardFunctions.IS_NULL_FUNCTION_NAME; +import static io.trino.spi.expression.StandardFunctions.LESS_THAN_OR_EQUAL_OPERATOR_FUNCTION_NAME; import static io.trino.spi.expression.StandardFunctions.LIKE_PATTERN_FUNCTION_NAME; import static io.trino.spi.expression.StandardFunctions.NEGATE_FUNCTION_NAME; import static io.trino.spi.expression.StandardFunctions.NOT_FUNCTION_NAME; @@ -223,6 +228,33 @@ public void testTranslateArithmeticUnaryPlus() new Variable("double_symbol_1", DOUBLE)); } + @Test + public void testTranslateBetween() + { + assertTranslationToConnectorExpression( + TEST_SESSION, + new BetweenPredicate( + new SymbolReference("double_symbol_1"), + new DoubleLiteral("1.2"), + new SymbolReference("double_symbol_2")), + new Call( + BOOLEAN, + AND_FUNCTION_NAME, + List.of( + new Call( + BOOLEAN, + GREATER_THAN_OR_EQUAL_OPERATOR_FUNCTION_NAME, + List.of( + new Variable("double_symbol_1", DOUBLE), + new Constant(1.2d, DOUBLE))), + new Call( + BOOLEAN, + LESS_THAN_OR_EQUAL_OPERATOR_FUNCTION_NAME, + List.of( + new Variable("double_symbol_1", DOUBLE), + new Variable("double_symbol_2", DOUBLE)))))); + } + @Test public void testTranslateLike() {