From 5eb794a1daf43545c87bcdab7da46c5bc1f616b8 Mon Sep 17 00:00:00 2001 From: Naoki Takezoe Date: Sat, 23 Oct 2021 02:15:53 +0900 Subject: [PATCH] Support string pushdown with collation in PostgreSQL connector --- .../plugin/postgresql/PostgreSqlClient.java | 104 +++++++++++++++++- .../plugin/postgresql/PostgreSqlConfig.java | 13 +++ .../PostgreSqlSessionProperties.java | 12 ++ .../postgresql/PostgreSqlQueryRunner.java | 1 + .../postgresql/TestPostgreSqlConfig.java | 7 +- .../TestPostgreSqlConnectorTest.java | 61 ++++++++++ 6 files changed, 190 insertions(+), 8 deletions(-) diff --git a/plugin/trino-postgresql/src/main/java/io/trino/plugin/postgresql/PostgreSqlClient.java b/plugin/trino-postgresql/src/main/java/io/trino/plugin/postgresql/PostgreSqlClient.java index e6d575a85dc9..60926b077dfb 100644 --- a/plugin/trino-postgresql/src/main/java/io/trino/plugin/postgresql/PostgreSqlClient.java +++ b/plugin/trino-postgresql/src/main/java/io/trino/plugin/postgresql/PostgreSqlClient.java @@ -165,6 +165,7 @@ import static io.trino.plugin.postgresql.PostgreSqlConfig.ArrayMapping.AS_JSON; import static io.trino.plugin.postgresql.PostgreSqlConfig.ArrayMapping.DISABLED; import static io.trino.plugin.postgresql.PostgreSqlSessionProperties.getArrayMapping; +import static io.trino.plugin.postgresql.PostgreSqlSessionProperties.isEnableStringPushdownWithCollate; import static io.trino.plugin.postgresql.TypeUtils.arrayDepth; import static io.trino.plugin.postgresql.TypeUtils.getArrayElementPgTypeName; import static io.trino.plugin.postgresql.TypeUtils.getJdbcObjectArray; @@ -234,7 +235,7 @@ public class PostgreSqlClient private final List tableTypes; private final AggregateFunctionRewriter aggregateFunctionRewriter; - private static final PredicatePushdownController POSTGRESQL_CHARACTER_PUSHDOWN = (session, domain) -> { + private static final PredicatePushdownController POSTGRESQL_STRING_PUSHDOWN_WITHOUT_COLLATE = (session, domain) -> { checkArgument( domain.getType() instanceof VarcharType || domain.getType() instanceof CharType, "This PredicatePushdownController can be used only for chars and varchars"); @@ -506,13 +507,22 @@ public Optional toColumnMapping(ConnectorSession session, Connect } case Types.CHAR: + if (isEnableStringPushdownWithCollate(session)) { + return Optional.of(charColumnMappingWithCollate(typeHandle.getRequiredColumnSize())); + } return Optional.of(charColumnMapping(typeHandle.getRequiredColumnSize())); case Types.VARCHAR: if (!jdbcTypeName.equals("varchar")) { // This can be e.g. an ENUM + if (isCollatable(jdbcTypeName) && isEnableStringPushdownWithCollate(session)) { + return Optional.of(typedVarcharColumnMappingWithCollate(jdbcTypeName)); + } return Optional.of(typedVarcharColumnMapping(jdbcTypeName)); } + if (isCollatable(jdbcTypeName) && isEnableStringPushdownWithCollate(session)) { + return Optional.of(varcharColumnMappingWithCollate(typeHandle.getRequiredColumnSize())); + } return Optional.of(varcharColumnMapping(typeHandle.getRequiredColumnSize())); case Types.BINARY: @@ -749,14 +759,19 @@ private boolean isCollatable(JdbcColumnHandle column) if (column.getColumnType() instanceof CharType || column.getColumnType() instanceof VarcharType) { String jdbcTypeName = column.getJdbcTypeHandle().getJdbcTypeName() .orElseThrow(() -> new TrinoException(JDBC_ERROR, "Type name is missing: " + column.getJdbcTypeHandle())); - // Only char (internally named bpchar)/varchar/text are the built-in collatable types - return "bpchar".equals(jdbcTypeName) || "varchar".equals(jdbcTypeName) || "text".equals(jdbcTypeName); + return isCollatable(jdbcTypeName); } // non-textual types don't have the concept of collation return false; } + private boolean isCollatable(String jdbcTypeName) + { + // Only char (internally named bpchar)/varchar/text are the built-in collatable types + return "bpchar".equals(jdbcTypeName) || "varchar".equals(jdbcTypeName) || "text".equals(jdbcTypeName); + } + @Override public boolean isTopNGuaranteed(ConnectorSession session) { @@ -833,7 +848,20 @@ private static ColumnMapping charColumnMapping(int charLength) charType, charReadFunction(charType), charWriteFunction(), - POSTGRESQL_CHARACTER_PUSHDOWN); + POSTGRESQL_STRING_PUSHDOWN_WITHOUT_COLLATE); + } + + private static ColumnMapping charColumnMappingWithCollate(int charLength) + { + if (charLength > CharType.MAX_LENGTH) { + return varcharColumnMappingWithCollate(charLength); + } + CharType charType = createCharType(charLength); + return ColumnMapping.sliceMapping( + charType, + charReadFunction(charType), + stringWriteFunctionWithCollate(), + FULL_PUSHDOWN); } private static ColumnMapping varcharColumnMapping(int varcharLength) @@ -845,7 +873,38 @@ private static ColumnMapping varcharColumnMapping(int varcharLength) varcharType, varcharReadFunction(varcharType), varcharWriteFunction(), - POSTGRESQL_CHARACTER_PUSHDOWN); + POSTGRESQL_STRING_PUSHDOWN_WITHOUT_COLLATE); + } + + private static ColumnMapping varcharColumnMappingWithCollate(int varcharLength) + { + VarcharType varcharType = varcharLength <= VarcharType.MAX_LENGTH + ? createVarcharType(varcharLength) + : createUnboundedVarcharType(); + return ColumnMapping.sliceMapping( + varcharType, + varcharReadFunction(varcharType), + stringWriteFunctionWithCollate(), + FULL_PUSHDOWN); + } + + private static SliceWriteFunction stringWriteFunctionWithCollate() + { + return new SliceWriteFunction() + { + @Override + public String getBindExpression() + { + return "? COLLATE \"C\""; + } + + @Override + public void set(PreparedStatement statement, int index, Slice value) + throws SQLException + { + statement.setString(index, value.toStringUtf8()); + } + }; } private static ColumnMapping timeColumnMapping(int precision) @@ -1161,12 +1220,45 @@ private static ColumnMapping typedVarcharColumnMapping(String jdbcTypeName) return ColumnMapping.sliceMapping( VARCHAR, (resultSet, columnIndex) -> utf8Slice(resultSet.getString(columnIndex)), - typedVarcharWriteFunction(jdbcTypeName)); + typedVarcharWriteFunction(jdbcTypeName), + POSTGRESQL_STRING_PUSHDOWN_WITHOUT_COLLATE); + } + + private static ColumnMapping typedVarcharColumnMappingWithCollate(String jdbcTypeName) + { + return ColumnMapping.sliceMapping( + VARCHAR, + (resultSet, columnIndex) -> utf8Slice(resultSet.getString(columnIndex)), + typedVarcharWriteFunctionWithCollate(jdbcTypeName), + FULL_PUSHDOWN); } private static SliceWriteFunction typedVarcharWriteFunction(String jdbcTypeName) { String bindExpression = format("CAST(? AS %s)", requireNonNull(jdbcTypeName, "jdbcTypeName is null")); + + return new SliceWriteFunction() + { + @Override + public String getBindExpression() + { + return bindExpression; + } + + @Override + public void set(PreparedStatement statement, int index, Slice value) + throws SQLException + { + statement.setString(index, value.toStringUtf8()); + } + }; + } + + private static SliceWriteFunction typedVarcharWriteFunctionWithCollate(String jdbcTypeName) + { + String collation = "COLLATE \"C\""; + String bindExpression = format("CAST(? AS %s) %s", requireNonNull(jdbcTypeName, "jdbcTypeName is null"), collation); + return new SliceWriteFunction() { @Override diff --git a/plugin/trino-postgresql/src/main/java/io/trino/plugin/postgresql/PostgreSqlConfig.java b/plugin/trino-postgresql/src/main/java/io/trino/plugin/postgresql/PostgreSqlConfig.java index 51fdad95bf72..0c51d8eca592 100644 --- a/plugin/trino-postgresql/src/main/java/io/trino/plugin/postgresql/PostgreSqlConfig.java +++ b/plugin/trino-postgresql/src/main/java/io/trino/plugin/postgresql/PostgreSqlConfig.java @@ -22,6 +22,7 @@ public class PostgreSqlConfig { private ArrayMapping arrayMapping = ArrayMapping.DISABLED; private boolean includeSystemTables; + private boolean enableStringPushdownWithCollate; public enum ArrayMapping { @@ -55,4 +56,16 @@ public PostgreSqlConfig setIncludeSystemTables(boolean includeSystemTables) this.includeSystemTables = includeSystemTables; return this; } + + public boolean isEnableStringPushdownWithCollate() + { + return enableStringPushdownWithCollate; + } + + @Config("postgresql.experimental.enable-string-pushdown-with-collate") + public PostgreSqlConfig setEnableStringPushdownWithCollate(boolean enableStringPushdownWithCollate) + { + this.enableStringPushdownWithCollate = enableStringPushdownWithCollate; + return this; + } } diff --git a/plugin/trino-postgresql/src/main/java/io/trino/plugin/postgresql/PostgreSqlSessionProperties.java b/plugin/trino-postgresql/src/main/java/io/trino/plugin/postgresql/PostgreSqlSessionProperties.java index 138ea091a5eb..2f80fd11b265 100644 --- a/plugin/trino-postgresql/src/main/java/io/trino/plugin/postgresql/PostgreSqlSessionProperties.java +++ b/plugin/trino-postgresql/src/main/java/io/trino/plugin/postgresql/PostgreSqlSessionProperties.java @@ -23,12 +23,14 @@ import java.util.List; +import static io.trino.spi.session.PropertyMetadata.booleanProperty; import static io.trino.spi.session.PropertyMetadata.enumProperty; public final class PostgreSqlSessionProperties implements SessionPropertiesProvider { public static final String ARRAY_MAPPING = "array_mapping"; + public static final String ENABLE_STRING_PUSHDOWN_WITH_COLLATE = "enable_string_pushdown_with_collate"; private final List> sessionProperties; @@ -41,6 +43,11 @@ public PostgreSqlSessionProperties(PostgreSqlConfig postgreSqlConfig) "Handling of PostgreSql arrays", ArrayMapping.class, postgreSqlConfig.getArrayMapping(), + false), + booleanProperty( + ENABLE_STRING_PUSHDOWN_WITH_COLLATE, + "Enable string pushdown with collate (experimental)", + postgreSqlConfig.isEnableStringPushdownWithCollate(), false)); } @@ -54,4 +61,9 @@ public static ArrayMapping getArrayMapping(ConnectorSession session) { return session.getProperty(ARRAY_MAPPING, ArrayMapping.class); } + + public static boolean isEnableStringPushdownWithCollate(ConnectorSession session) + { + return session.getProperty(ENABLE_STRING_PUSHDOWN_WITH_COLLATE, Boolean.class); + } } diff --git a/plugin/trino-postgresql/src/test/java/io/trino/plugin/postgresql/PostgreSqlQueryRunner.java b/plugin/trino-postgresql/src/test/java/io/trino/plugin/postgresql/PostgreSqlQueryRunner.java index 63b18356b062..05fec7d84236 100644 --- a/plugin/trino-postgresql/src/test/java/io/trino/plugin/postgresql/PostgreSqlQueryRunner.java +++ b/plugin/trino-postgresql/src/test/java/io/trino/plugin/postgresql/PostgreSqlQueryRunner.java @@ -58,6 +58,7 @@ public static DistributedQueryRunner createPostgreSqlQueryRunner( connectorProperties.putIfAbsent("connection-password", server.getPassword()); connectorProperties.putIfAbsent("allow-drop-table", "true"); connectorProperties.putIfAbsent("postgresql.include-system-tables", "true"); + //connectorProperties.putIfAbsent("postgresql.experimental.enable-string-pushdown-with-collate", "true"); queryRunner.installPlugin(new PostgreSqlPlugin()); queryRunner.createCatalog("postgresql", "postgresql", connectorProperties); diff --git a/plugin/trino-postgresql/src/test/java/io/trino/plugin/postgresql/TestPostgreSqlConfig.java b/plugin/trino-postgresql/src/test/java/io/trino/plugin/postgresql/TestPostgreSqlConfig.java index b6fa250b835a..a5297222f177 100644 --- a/plugin/trino-postgresql/src/test/java/io/trino/plugin/postgresql/TestPostgreSqlConfig.java +++ b/plugin/trino-postgresql/src/test/java/io/trino/plugin/postgresql/TestPostgreSqlConfig.java @@ -29,7 +29,8 @@ public void testDefaults() { assertRecordedDefaults(recordDefaults(PostgreSqlConfig.class) .setArrayMapping(PostgreSqlConfig.ArrayMapping.DISABLED) - .setIncludeSystemTables(false)); + .setIncludeSystemTables(false) + .setEnableStringPushdownWithCollate(false)); } @Test @@ -38,11 +39,13 @@ public void testExplicitPropertyMappings() Map properties = new ImmutableMap.Builder() .put("postgresql.array-mapping", "AS_ARRAY") .put("postgresql.include-system-tables", "true") + .put("postgresql.experimental.enable-string-pushdown-with-collate", "true") .build(); PostgreSqlConfig expected = new PostgreSqlConfig() .setArrayMapping(PostgreSqlConfig.ArrayMapping.AS_ARRAY) - .setIncludeSystemTables(true); + .setIncludeSystemTables(true) + .setEnableStringPushdownWithCollate(true); assertFullMapping(properties, expected); } diff --git a/plugin/trino-postgresql/src/test/java/io/trino/plugin/postgresql/TestPostgreSqlConnectorTest.java b/plugin/trino-postgresql/src/test/java/io/trino/plugin/postgresql/TestPostgreSqlConnectorTest.java index 388e9237f8e0..a02105e4ad30 100644 --- a/plugin/trino-postgresql/src/test/java/io/trino/plugin/postgresql/TestPostgreSqlConnectorTest.java +++ b/plugin/trino-postgresql/src/test/java/io/trino/plugin/postgresql/TestPostgreSqlConnectorTest.java @@ -13,12 +13,16 @@ */ package io.trino.plugin.postgresql; +import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import io.airlift.units.Duration; import io.trino.Session; import io.trino.plugin.jdbc.BaseJdbcConnectorTest; +import io.trino.plugin.jdbc.JdbcColumnHandle; import io.trino.plugin.jdbc.JdbcTableHandle; import io.trino.plugin.jdbc.RemoteDatabaseEvent; +import io.trino.spi.connector.ColumnHandle; +import io.trino.spi.predicate.Range; import io.trino.spi.predicate.TupleDomain; import io.trino.sql.planner.plan.FilterNode; import io.trino.sql.planner.plan.JoinNode; @@ -42,7 +46,10 @@ import java.util.Map; import java.util.UUID; +import static com.google.common.collect.MoreCollectors.onlyElement; +import static io.airlift.slice.Slices.utf8Slice; import static io.trino.plugin.postgresql.PostgreSqlQueryRunner.createPostgreSqlQueryRunner; +import static io.trino.spi.type.VarcharType.createVarcharType; import static io.trino.sql.planner.assertions.PlanMatchPattern.anyTree; import static io.trino.sql.planner.assertions.PlanMatchPattern.node; import static io.trino.sql.planner.assertions.PlanMatchPattern.tableScan; @@ -82,6 +89,7 @@ protected boolean hasBehavior(TestingConnectorBehavior connectorBehavior) { switch (connectorBehavior) { case SUPPORTS_PREDICATE_PUSHDOWN_WITH_VARCHAR_INEQUALITY: + case SUPPORTS_JOIN_PUSHDOWN_WITH_VARCHAR_INEQUALITY: return false; case SUPPORTS_TOPN_PUSHDOWN: @@ -368,6 +376,59 @@ public void testPredicatePushdown() anyTree(node(TableScanNode.class)))); } + @Test + public void testStringPushdownWithCollate() + { + Session session = Session.builder(getSession()) + .setCatalogSessionProperty("postgresql", "enable_string_pushdown_with_collate", "true") + .build(); + + // varchar range + assertThat(query(session, "SELECT regionkey, nationkey, name FROM nation WHERE name BETWEEN 'POLAND' AND 'RPA'")) + .matches("VALUES (BIGINT '3', BIGINT '19', CAST('ROMANIA' AS varchar(25)))") + .isFullyPushedDown(); + + // varchar IN with small compaction threshold + assertThat(query( + Session.builder(session) + .setCatalogSessionProperty("postgresql", "domain_compaction_threshold", "1") + .build(), + "SELECT regionkey, nationkey, name FROM nation WHERE name IN ('POLAND', 'ROMANIA', 'VIETNAM')")) + .matches("VALUES " + + "(BIGINT '3', BIGINT '19', CAST('ROMANIA' AS varchar(25))), " + + "(BIGINT '2', BIGINT '21', CAST('VIETNAM' AS varchar(25)))") + // Verify that a FilterNode is retained and only a compacted domain is pushed down to connector as a range predicate + .isNotFullyPushedDown(node(FilterNode.class, tableScan( + tableHandle -> { + TupleDomain constraint = ((JdbcTableHandle) tableHandle).getConstraint(); + ColumnHandle nameColumn = constraint.getDomains().orElseThrow() + .keySet().stream() + .map(JdbcColumnHandle.class::cast) + .filter(column -> column.getColumnName().equals("name")) + .collect(onlyElement()); + return constraint.getDomains().get().get(nameColumn).getValues().getRanges().getOrderedRanges() + .equals(ImmutableList.of( + Range.range( + createVarcharType(25), + utf8Slice("POLAND"), true, + utf8Slice("VIETNAM"), true))); + }, + TupleDomain.all(), + ImmutableMap.of()))); + + // varchar predicate over join + Session joinPushdownEnabled = joinPushdownEnabled(session); + assertThat(query(joinPushdownEnabled, "SELECT c.name, n.name FROM customer c JOIN nation n ON c.custkey = n.nationkey WHERE address < 'TcGe5gaZNgVePxU5kRrvXBfkasDTea'")) + .isFullyPushedDown(); + + // join on varchar columns is not pushed down + assertThat(query(joinPushdownEnabled, "SELECT c.name, n.name FROM customer c JOIN nation n ON c.address = n.name")) + .isNotFullyPushedDown( + node(JoinNode.class, + anyTree(node(TableScanNode.class)), + anyTree(node(TableScanNode.class)))); + } + @Test public void testDecimalPredicatePushdown() throws Exception