From 0142908ab7f6f80478fe605a3f2e60139f6fc9fd Mon Sep 17 00:00:00 2001 From: Yuya Ebihara Date: Thu, 17 Mar 2022 15:07:11 +0900 Subject: [PATCH] Implement COALESCE pushdown in PostgreSQL connector --- .../jdbc/expression/RewriteCoalesce.java | 74 +++++++++++++++++++ .../plugin/postgresql/PostgreSqlClient.java | 3 +- .../postgresql/TestPostgreSqlClient.java | 35 +++++++++ .../TestPostgreSqlConnectorTest.java | 28 +++++++ 4 files changed, 139 insertions(+), 1 deletion(-) create mode 100644 plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/expression/RewriteCoalesce.java diff --git a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/expression/RewriteCoalesce.java b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/expression/RewriteCoalesce.java new file mode 100644 index 000000000000..8508a2d9ccc5 --- /dev/null +++ b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/expression/RewriteCoalesce.java @@ -0,0 +1,74 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.plugin.jdbc.expression; + +import com.google.common.collect.ImmutableList; +import io.trino.matching.Capture; +import io.trino.matching.Captures; +import io.trino.matching.Pattern; +import io.trino.plugin.base.expression.ConnectorExpressionRule; +import io.trino.plugin.jdbc.QueryParameter; +import io.trino.spi.expression.Call; +import io.trino.spi.expression.ConnectorExpression; + +import java.util.List; +import java.util.Optional; + +import static com.google.common.base.Verify.verify; +import static io.trino.matching.Capture.newCapture; +import static io.trino.plugin.base.expression.ConnectorExpressionPatterns.call; +import static io.trino.plugin.base.expression.ConnectorExpressionPatterns.functionName; +import static io.trino.spi.expression.StandardFunctions.COALESCE_FUNCTION_NAME; +import static java.lang.String.join; + +public class RewriteCoalesce + implements ConnectorExpressionRule +{ + private static final Capture CALL = newCapture(); + private final Pattern pattern; + + public RewriteCoalesce() + { + this.pattern = call() + .with(functionName().matching(name -> name.equals(COALESCE_FUNCTION_NAME))) + .capturedAs(CALL); + } + + @Override + public Pattern getPattern() + { + return pattern; + } + + @Override + public Optional rewrite(Call call, Captures captures, RewriteContext context) + { + verify(call.getArguments().size() >= 2, "Function 'coalesce' expects more than or equals to two arguments"); + + ImmutableList.Builder rewrittenArguments = ImmutableList.builderWithExpectedSize(call.getArguments().size()); + ImmutableList.Builder parameters = ImmutableList.builder(); + for (ConnectorExpression expression : captures.get(CALL).getArguments()) { + Optional rewritten = context.defaultRewrite(expression); + if (rewritten.isEmpty()) { + return Optional.empty(); + } + rewrittenArguments.add(rewritten.get().expression()); + parameters.addAll(rewritten.get().parameters()); + } + + List arguments = rewrittenArguments.build(); + verify(arguments.size() >= 2, "Function 'coalesce' expects more than or equals to two arguments"); + return Optional.of(new ParameterizedExpression("COALESCE(%s)".formatted(join(",", arguments)), parameters.build())); + } +} diff --git a/plugin/trino-postgresql/src/main/java/io/trino/plugin/postgresql/PostgreSqlClient.java b/plugin/trino-postgresql/src/main/java/io/trino/plugin/postgresql/PostgreSqlClient.java index 15d37ba9762e..791d1628c554 100644 --- a/plugin/trino-postgresql/src/main/java/io/trino/plugin/postgresql/PostgreSqlClient.java +++ b/plugin/trino-postgresql/src/main/java/io/trino/plugin/postgresql/PostgreSqlClient.java @@ -74,6 +74,7 @@ import io.trino.plugin.jdbc.aggregation.ImplementVarianceSamp; import io.trino.plugin.jdbc.expression.JdbcConnectorExpressionRewriterBuilder; import io.trino.plugin.jdbc.expression.ParameterizedExpression; +import io.trino.plugin.jdbc.expression.RewriteCoalesce; import io.trino.plugin.jdbc.expression.RewriteIn; import io.trino.plugin.jdbc.logging.RemoteQueryModifier; import io.trino.plugin.postgresql.PostgreSqlConfig.ArrayMapping; @@ -162,7 +163,6 @@ import static io.trino.geospatial.serde.JtsGeometrySerde.serialize; import static io.trino.plugin.base.util.JsonTypeUtil.jsonParse; import static io.trino.plugin.base.util.JsonTypeUtil.toJsonValue; -import static io.trino.plugin.jdbc.DecimalConfig.DecimalMapping.ALLOW_OVERFLOW; import static io.trino.plugin.jdbc.DecimalSessionSessionProperties.getDecimalDefaultScale; import static io.trino.plugin.jdbc.DecimalSessionSessionProperties.getDecimalRounding; import static io.trino.plugin.jdbc.DecimalSessionSessionProperties.getDecimalRoundingMode; @@ -334,6 +334,7 @@ public PostgreSqlClient( this.connectorExpressionRewriter = JdbcConnectorExpressionRewriterBuilder.newBuilder() .addStandardRules(this::quoted) .add(new RewriteIn()) + .add(new RewriteCoalesce()) .withTypeClass("integer_type", ImmutableSet.of("tinyint", "smallint", "integer", "bigint")) .withTypeClass("numeric_type", ImmutableSet.of("tinyint", "smallint", "integer", "bigint", "decimal", "real", "double")) .map("$equal(left, right)").to("left = right") diff --git a/plugin/trino-postgresql/src/test/java/io/trino/plugin/postgresql/TestPostgreSqlClient.java b/plugin/trino-postgresql/src/test/java/io/trino/plugin/postgresql/TestPostgreSqlClient.java index b162e9b90bc2..2231df6dd4e3 100644 --- a/plugin/trino-postgresql/src/test/java/io/trino/plugin/postgresql/TestPostgreSqlClient.java +++ b/plugin/trino-postgresql/src/test/java/io/trino/plugin/postgresql/TestPostgreSqlClient.java @@ -39,6 +39,7 @@ import io.trino.spi.function.OperatorType; import io.trino.spi.session.PropertyMetadata; import io.trino.sql.ir.Call; +import io.trino.sql.ir.Coalesce; import io.trino.sql.ir.Comparison; import io.trino.sql.ir.Constant; import io.trino.sql.ir.Expression; @@ -85,6 +86,13 @@ public class TestPostgreSqlClient .setJdbcTypeHandle(new JdbcTypeHandle(Types.BIGINT, Optional.of("int8"), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty())) .build(); + private static final JdbcColumnHandle BIGINT_COLUMN2 = + JdbcColumnHandle.builder() + .setColumnName("c_bigint2") + .setColumnType(BIGINT) + .setJdbcTypeHandle(new JdbcTypeHandle(Types.BIGINT, Optional.of("int8"), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty())) + .build(); + private static final JdbcColumnHandle DOUBLE_COLUMN = JdbcColumnHandle.builder() .setColumnName("c_double") @@ -411,6 +419,33 @@ public void testConvertIn() new QueryParameter(createVarcharType(10), Optional.of(utf8Slice("value2"))))); } + @Test + public void testConvertCoalesce() + { + // COALESCE(varchar, varchar) + ParameterizedExpression converted = JDBC_CLIENT.convertPredicate(SESSION, + translateToConnectorExpression( + new Coalesce( + new Reference(VARCHAR, "c_varchar_symbol"), + new Reference(VARCHAR, "c_varchar_symbol_2"))), + Map.of("c_varchar_symbol", VARCHAR_COLUMN, "c_varchar_symbol_2", VARCHAR_COLUMN2)) + .orElseThrow(); + assertThat(converted.expression()).isEqualTo("COALESCE(\"c_varchar\",\"c_varchar2\")"); + assertThat(converted.parameters()).isEqualTo(List.of()); + + // COALESCE(bigint, bigint, bigint) + converted = JDBC_CLIENT.convertPredicate(SESSION, + translateToConnectorExpression( + new Coalesce( + new Reference(BIGINT, "c_bigint_symbol"), + new Reference(BIGINT, "c_bigint_symbol_2"), + new Constant(BIGINT, 123L))), + Map.of("c_bigint_symbol", BIGINT_COLUMN, "c_bigint_symbol_2", BIGINT_COLUMN2)) + .orElseThrow(); + assertThat(converted.expression()).isEqualTo("COALESCE(\"c_bigint\",\"c_bigint2\",?)"); + assertThat(converted.parameters()).isEqualTo(List.of(new QueryParameter(BIGINT, Optional.of(123L)))); + } + private ConnectorExpression translateToConnectorExpression(Expression expression) { return ConnectorExpressionTranslator.translate(TEST_SESSION, expression) diff --git a/plugin/trino-postgresql/src/test/java/io/trino/plugin/postgresql/TestPostgreSqlConnectorTest.java b/plugin/trino-postgresql/src/test/java/io/trino/plugin/postgresql/TestPostgreSqlConnectorTest.java index eb0c00f806b9..b8121e2230ab 100644 --- a/plugin/trino-postgresql/src/test/java/io/trino/plugin/postgresql/TestPostgreSqlConnectorTest.java +++ b/plugin/trino-postgresql/src/test/java/io/trino/plugin/postgresql/TestPostgreSqlConnectorTest.java @@ -807,6 +807,34 @@ public void testOrPredicatePushdown() assertThat(query("SELECT * FROM nation WHERE name = NULL OR regionkey = 4")).isFullyPushedDown(); } + @Test + public void testCoalescePredicatePushdown() + { + assertThat(query("SELECT * FROM nation WHERE COALESCE(nationkey, 1) = nationkey")) + .isFullyPushedDown(); + assertThat(query("SELECT * FROM nation WHERE COALESCE(nationkey, regionkey, 1) = nationkey")) + .isFullyPushedDown(); + + try (TestTable table = new TestTable( + getQueryRunner()::execute, + "test_coalesce_predicate_pushdown", + "(a_varchar varchar, b_varchar varchar, c_varchar varchar)", + List.of( + "NULL, NULL, 'third not null'", + "'1', '2', 'first and second not null'", + "NULL, '2', 'second not null'"))) { + assertThat(query("SELECT c_varchar FROM " + table.getName() + " WHERE COALESCE(a_varchar, b_varchar) = '1'")) + .matches("VALUES VARCHAR 'first and second not null'") + .isFullyPushedDown(); + assertThat(query("SELECT c_varchar FROM " + table.getName() + " WHERE COALESCE(a_varchar, b_varchar) = '2'")) + .matches("VALUES VARCHAR 'second not null'") + .isFullyPushedDown(); + assertThat(query("SELECT c_varchar FROM " + table.getName() + " WHERE COALESCE(a_varchar, b_varchar, c_varchar) = 'third not null'")) + .matches("VALUES VARCHAR 'third not null'") + .isFullyPushedDown(); + } + } + @Test public void testLikePredicatePushdown() {