diff --git a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/expression/RewriteIn.java b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/expression/RewriteIn.java index e686e7685fd3..b780e8b4ff9c 100644 --- a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/expression/RewriteIn.java +++ b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/expression/RewriteIn.java @@ -14,6 +14,7 @@ package io.trino.plugin.jdbc.expression; import com.google.common.base.Joiner; +import com.google.common.base.Predicate; import com.google.common.collect.ImmutableList; import io.trino.matching.Capture; import io.trino.matching.Captures; @@ -22,6 +23,7 @@ import io.trino.plugin.jdbc.QueryParameter; import io.trino.spi.expression.Call; import io.trino.spi.expression.ConnectorExpression; +import io.trino.spi.type.Type; import java.util.List; import java.util.Optional; @@ -40,6 +42,7 @@ import static io.trino.spi.expression.StandardFunctions.IN_PREDICATE_FUNCTION_NAME; import static io.trino.spi.type.BooleanType.BOOLEAN; import static java.lang.String.format; +import static java.util.Objects.requireNonNull; public class RewriteIn implements ConnectorExpressionRule @@ -54,6 +57,18 @@ public class RewriteIn .with(argument(0).matching(expression().capturedAs(VALUE))) .with(argument(1).matching(call().with(functionName().equalTo(ARRAY_CONSTRUCTOR_FUNCTION_NAME)).with(arguments().capturedAs(EXPRESSIONS)))); + private final Predicate typePredicate; + + public RewriteIn() + { + this(_ -> true); + } + + public RewriteIn(Predicate typePredicate) + { + this.typePredicate = requireNonNull(typePredicate, "typePredicate is null"); + } + @Override public Pattern getPattern() { @@ -73,6 +88,9 @@ public Optional rewrite(Call call, Captures captures, R // We don't want to push down too long IN query text return Optional.empty(); } + if (!expressions.stream().map(ConnectorExpression::getType).allMatch(typePredicate)) { + return Optional.empty(); + } ImmutableList.Builder parameters = ImmutableList.builder(); parameters.addAll(value.get().parameters()); diff --git a/plugin/trino-clickhouse/src/main/java/io/trino/plugin/clickhouse/ClickHouseClient.java b/plugin/trino-clickhouse/src/main/java/io/trino/plugin/clickhouse/ClickHouseClient.java index e540a1764e4c..36bb1d876375 100644 --- a/plugin/trino-clickhouse/src/main/java/io/trino/plugin/clickhouse/ClickHouseClient.java +++ b/plugin/trino-clickhouse/src/main/java/io/trino/plugin/clickhouse/ClickHouseClient.java @@ -56,6 +56,7 @@ import io.trino.plugin.jdbc.aggregation.ImplementSum; import io.trino.plugin.jdbc.expression.JdbcConnectorExpressionRewriterBuilder; import io.trino.plugin.jdbc.expression.ParameterizedExpression; +import io.trino.plugin.jdbc.expression.RewriteIn; import io.trino.plugin.jdbc.logging.RemoteQueryModifier; import io.trino.spi.TrinoException; import io.trino.spi.connector.AggregateFunction; @@ -63,6 +64,7 @@ import io.trino.spi.connector.ColumnMetadata; import io.trino.spi.connector.ConnectorSession; import io.trino.spi.connector.ConnectorTableMetadata; +import io.trino.spi.expression.ConnectorExpression; import io.trino.spi.type.CharType; import io.trino.spi.type.DecimalType; import io.trino.spi.type.Decimals; @@ -224,6 +226,11 @@ public ClickHouseClient( JdbcTypeHandle bigintTypeHandle = new JdbcTypeHandle(Types.BIGINT, Optional.of("bigint"), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty()); this.connectorExpressionRewriter = JdbcConnectorExpressionRewriterBuilder.newBuilder() .addStandardRules(this::quoted) + .map("$equal(left: varchar, right: varchar)").to("left = right") + .add(new RewriteIn(type -> type instanceof VarcharType)) + .map("$like(value: varchar, pattern: varchar): boolean").to("value LIKE pattern") + .map("$is_null(value: varchar)").to("value IS NULL") + .map("$nullif(first: varchar, second: varchar)").to("NULLIF(first, second)") .build(); this.aggregateFunctionRewriter = new AggregateFunctionRewriter<>( this.connectorExpressionRewriter, @@ -248,6 +255,12 @@ public Optional implementAggregation(ConnectorSession session, A return aggregateFunctionRewriter.rewrite(session, aggregate, assignments); } + @Override + public Optional convertPredicate(ConnectorSession session, ConnectorExpression expression, Map assignments) + { + return connectorExpressionRewriter.rewrite(session, expression, assignments); + } + @Override public boolean supportsTopN(ConnectorSession session, JdbcTableHandle handle, List sortOrder) { diff --git a/plugin/trino-clickhouse/src/test/java/io/trino/plugin/clickhouse/TestClickHouseConnectorTest.java b/plugin/trino-clickhouse/src/test/java/io/trino/plugin/clickhouse/TestClickHouseConnectorTest.java index dffd29bda90b..5a56963a03cd 100644 --- a/plugin/trino-clickhouse/src/test/java/io/trino/plugin/clickhouse/TestClickHouseConnectorTest.java +++ b/plugin/trino-clickhouse/src/test/java/io/trino/plugin/clickhouse/TestClickHouseConnectorTest.java @@ -66,6 +66,8 @@ protected boolean hasBehavior(TestingConnectorBehavior connectorBehavior) SUPPORTS_AGGREGATION_PUSHDOWN_CORRELATION, SUPPORTS_PREDICATE_PUSHDOWN_WITH_VARCHAR_EQUALITY, SUPPORTS_TOPN_PUSHDOWN, + SUPPORTS_PREDICATE_EXPRESSION_PUSHDOWN, + SUPPORTS_PREDICATE_EXPRESSION_PUSHDOWN_WITH_LIKE, SUPPORTS_TRUNCATE -> true; case SUPPORTS_AGGREGATION_PUSHDOWN_REGRESSION, SUPPORTS_AGGREGATION_PUSHDOWN_STDDEV, @@ -74,6 +76,7 @@ protected boolean hasBehavior(TestingConnectorBehavior connectorBehavior) SUPPORTS_DELETE, SUPPORTS_DROP_NOT_NULL_CONSTRAINT, SUPPORTS_NEGATIVE_DATE, + SUPPORTS_PREDICATE_ARITHMETIC_EXPRESSION_PUSHDOWN, SUPPORTS_ROW_TYPE, SUPPORTS_SET_COLUMN_TYPE, SUPPORTS_UPDATE -> false; @@ -903,6 +906,104 @@ public void testTextualPredicatePushdown() .isFullyPushedDown(); } + @Test + public void testOrPredicatePushdown() + { + assertThat(query("SELECT * FROM nation WHERE name = 'ALGERIA' OR comment = 'comment'")).isFullyPushedDown(); + assertThat(query("SELECT * FROM nation WHERE name IS NULL OR comment IS NULL")).isFullyPushedDown(); + } + + @Test + public void testLikePredicatePushdown() + { + assertThat(query("SELECT nationkey FROM nation WHERE name LIKE '%A%'")) + .isFullyPushedDown(); + + try (TestTable table = new TestTable( + getQueryRunner()::execute, + "test_like_predicate_pushdown", + "(id integer, a_varchar varchar)", + List.of( + "1, 'A'", + "2, 'a'", + "3, 'B'", + "4, 'ą'", + "5, 'Ą'"))) { + assertThat(query("SELECT id FROM " + table.getName() + " WHERE a_varchar LIKE '%A%'")) + .isFullyPushedDown(); + assertThat(query("SELECT id FROM " + table.getName() + " WHERE a_varchar LIKE '%ą%'")) + .isFullyPushedDown(); + } + } + + @Test + public void testIsNullPredicatePushdown() + { + assertThat(query("SELECT nationkey FROM nation WHERE name IS NULL")).isFullyPushedDown(); + assertThat(query("SELECT nationkey FROM nation WHERE name IS NULL OR comment = 'comment'")).isFullyPushedDown(); + + try (TestTable table = new TestTable( + getQueryRunner()::execute, + "test_is_null_predicate_pushdown", + "(a_int integer, a_varchar varchar(1), a_varchar2 varchar(1))", + List.of( + "1, 'A', ''", + "2, 'B', NULL", + "1, NULL, 'C'", + "2, NULL, 'D'"))) { + assertThat(query("SELECT a_int FROM " + table.getName() + " WHERE a_varchar IS NULL")).isFullyPushedDown(); + assertThat(query("SELECT a_int FROM " + table.getName() + " WHERE a_varchar IS NULL OR a_varchar2 = 'D'")).isFullyPushedDown(); + } + } + + @Test + public void testNullIfPredicatePushdown() + { + assertThat(query("SELECT nationkey FROM nation WHERE NULLIF(name, 'ALGERIA') IS NULL")) + .matches("VALUES BIGINT '0'") + .isFullyPushedDown(); + + assertThat(query("SELECT nationkey FROM nation WHERE NULLIF(name, 'Algeria') IS NULL")) + .returnsEmptyResult() + .isFullyPushedDown(); + + // NULLIF returns the first argument because arguments aren't the same + assertThat(query("SELECT nationkey FROM nation WHERE NULLIF(name, 'Name not found') = name")) + .matches("SELECT nationkey FROM nation") + .isFullyPushedDown(); + } + + @Test + public void testInPredicatePushdown() + { + try (TestTable table = new TestTable( + getQueryRunner()::execute, + "test_in_predicate_pushdown", + "(id varchar(1), id2 varchar(1), id3 double)", + List.of( + "'a', 'b', 1", + "'b', 'c', 2", + "'c', 'c', 3", + "'d', 'd', 4", + "'a', 'f', 5"))) { + // IN pushdowns only varchar + assertThat(query( + "SELECT id3 FROM " + table.getName() + " WHERE id3 IN (1, 2, 3, 4, 5) or id IN ('a', 'B')")) + .isNotFullyPushedDown(FilterNode.class); + + // IN values cannot be represented as a domain + assertThat(query("SELECT id FROM " + table.getName() + " WHERE id IN ('a', 'b') OR id2 IN ('c', 'd')")) + .isFullyPushedDown(); + + assertThat(query("SELECT id FROM " + table.getName() + " WHERE id IN ('a', 'B') OR id2 IN ('c', 'D')")) + .isFullyPushedDown(); + + assertThat(query("SELECT id FROM " + table.getName() + " WHERE id IN ('a', 'B', NULL) OR id2 IN ('C', 'd')")) + // NULL constant value is currently not pushed down + .isNotFullyPushedDown(FilterNode.class); + } + } + @Test @Override // Override because ClickHouse doesn't follow SQL standard syntax public void testExecuteProcedure()