From 87b8e6f68b1875353f39569050cc3989a838b63c Mon Sep 17 00:00:00 2001 From: Dain Sundstrom Date: Sat, 10 Dec 2022 16:14:48 -0800 Subject: [PATCH 01/24] Add Redshift tests --- plugin/trino-redshift/README.md | 20 ++ plugin/trino-redshift/pom.xml | 96 +++++++ .../plugin/redshift/RedshiftClientModule.java | 12 +- .../plugin/redshift/RedshiftQueryRunner.java | 264 ++++++++++++++++++ .../redshift/TestRedshiftConnectorTest.java | 193 +++++++++++++ 5 files changed, 584 insertions(+), 1 deletion(-) create mode 100644 plugin/trino-redshift/README.md create mode 100644 plugin/trino-redshift/src/test/java/io/trino/plugin/redshift/RedshiftQueryRunner.java create mode 100644 plugin/trino-redshift/src/test/java/io/trino/plugin/redshift/TestRedshiftConnectorTest.java diff --git a/plugin/trino-redshift/README.md b/plugin/trino-redshift/README.md new file mode 100644 index 000000000000..16229b145da1 --- /dev/null +++ b/plugin/trino-redshift/README.md @@ -0,0 +1,20 @@ +# Redshift Connector + +To run the Redshift tests you will need to provision a Redshift cluster. The +tests are designed to run on the smallest possible Redshift cluster containing +is a single dc2.large instance. Additionally, you will need a S3 bucket +containing TPCH tiny data in Parquet format. The files should be named: + +``` +s3:///tpch/tiny/.parquet +``` + +To run the tests set the following system properties: + +``` +test.redshift.jdbc.endpoint=..redshift.amazonaws.com:5439/ +test.redshift.jdbc.user= +test.redshift.jdbc.password= +test.redshift.s3.tpch.tables.root= +test.redshift.iam.role= +``` diff --git a/plugin/trino-redshift/pom.xml b/plugin/trino-redshift/pom.xml index a6270049b7c2..30fea843c921 100644 --- a/plugin/trino-redshift/pom.xml +++ b/plugin/trino-redshift/pom.xml @@ -40,12 +40,36 @@ + + io.airlift + log + runtime + + + + io.airlift + log-manager + runtime + + com.google.guava guava runtime + + net.jodah + failsafe + runtime + + + + org.jdbi + jdbi3-core + runtime + + io.trino @@ -72,9 +96,59 @@ + + io.trino + trino-base-jdbc + test-jar + test + + + + io.trino + trino-main + test + + io.trino trino-main + test-jar + test + + + + io.trino + trino-testing + test + + + + io.trino + trino-testing-services + test + + + + io.trino + trino-tpch + test + + + + io.trino.tpch + tpch + test + + + + io.airlift + testing + test + + + + org.assertj + assertj-core test @@ -84,4 +158,26 @@ test + + + + default + + true + + + + + org.apache.maven.plugins + maven-surefire-plugin + + + **/TestRedshiftConnectorTest.java + + + + + + + diff --git a/plugin/trino-redshift/src/main/java/io/trino/plugin/redshift/RedshiftClientModule.java b/plugin/trino-redshift/src/main/java/io/trino/plugin/redshift/RedshiftClientModule.java index 53e1aee6ac29..dccc6fabff94 100644 --- a/plugin/trino-redshift/src/main/java/io/trino/plugin/redshift/RedshiftClientModule.java +++ b/plugin/trino-redshift/src/main/java/io/trino/plugin/redshift/RedshiftClientModule.java @@ -28,6 +28,8 @@ import io.trino.plugin.jdbc.ptf.Query; import io.trino.spi.ptf.ConnectorTableFunction; +import java.util.Properties; + import static com.google.inject.multibindings.Multibinder.newSetBinder; public class RedshiftClientModule @@ -45,6 +47,14 @@ public void configure(Binder binder) @ForBaseJdbc public static ConnectionFactory getConnectionFactory(BaseJdbcConfig config, CredentialProvider credentialProvider) { - return new DriverConnectionFactory(new Driver(), config, credentialProvider); + return new DriverConnectionFactory(new Driver(), config.getConnectionUrl(), getDriverProperties(), credentialProvider); + } + + private static Properties getDriverProperties() + { + Properties properties = new Properties(); + properties.put("reWriteBatchedInserts", "true"); + properties.put("reWriteBatchedInsertsSize", "512"); + return properties; } } diff --git a/plugin/trino-redshift/src/test/java/io/trino/plugin/redshift/RedshiftQueryRunner.java b/plugin/trino-redshift/src/test/java/io/trino/plugin/redshift/RedshiftQueryRunner.java new file mode 100644 index 000000000000..61859dd9b52c --- /dev/null +++ b/plugin/trino-redshift/src/test/java/io/trino/plugin/redshift/RedshiftQueryRunner.java @@ -0,0 +1,264 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.plugin.redshift; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.Streams; +import io.airlift.log.Logger; +import io.airlift.log.Logging; +import io.trino.Session; +import io.trino.metadata.QualifiedObjectName; +import io.trino.plugin.tpch.TpchPlugin; +import io.trino.spi.security.Identity; +import io.trino.testing.DistributedQueryRunner; +import io.trino.testing.MaterializedResult; +import io.trino.testing.QueryRunner; +import io.trino.tpch.TpchTable; +import net.jodah.failsafe.Failsafe; +import net.jodah.failsafe.RetryPolicy; +import org.jdbi.v3.core.HandleConsumer; +import org.jdbi.v3.core.Jdbi; + +import java.time.Duration; +import java.util.HashMap; +import java.util.Map; +import java.util.Set; + +import static io.airlift.testing.Closeables.closeAllSuppress; +import static io.trino.plugin.tpch.TpchMetadata.TINY_SCHEMA_NAME; +import static io.trino.testing.QueryAssertions.copyTable; +import static io.trino.testing.TestingSession.testSessionBuilder; +import static io.trino.testing.assertions.Assert.assertEquals; +import static java.lang.String.format; +import static java.util.Locale.ENGLISH; +import static java.util.Objects.requireNonNull; +import static java.util.stream.Collectors.toUnmodifiableSet; + +public final class RedshiftQueryRunner +{ + private static final Logger log = Logger.get(RedshiftQueryRunner.class); + private static final String JDBC_ENDPOINT = requireSystemProperty("test.redshift.jdbc.endpoint"); + private static final String JDBC_USER = requireSystemProperty("test.redshift.jdbc.user"); + private static final String JDBC_PASSWORD = requireSystemProperty("test.redshift.jdbc.password"); + private static final String S3_TPCH_TABLES_ROOT = requireSystemProperty("test.redshift.s3.tpch.tables.root"); + private static final String IAM_ROLE = requireSystemProperty("test.redshift.iam.role"); + + private static final String TEST_DATABASE = "testdb"; + private static final String TEST_CATALOG = "redshift"; + static final String TEST_SCHEMA = "test_schema"; + + private static final String JDBC_URL = "jdbc:redshift://" + JDBC_ENDPOINT + TEST_DATABASE; + + private static final String CONNECTOR_NAME = "redshift"; + private static final String TPCH_CATALOG = "tpch"; + + private static final String GRANTED_USER = "alice"; + private static final String NON_GRANTED_USER = "bob"; + + private RedshiftQueryRunner() {} + + public static DistributedQueryRunner createRedshiftQueryRunner( + Map extraProperties, + Map connectorProperties, + Iterable> tables) + throws Exception + { + return createRedshiftQueryRunner( + createSession(), + extraProperties, + connectorProperties, + tables); + } + + public static DistributedQueryRunner createRedshiftQueryRunner( + Session session, + Map extraProperties, + Map connectorProperties, + Iterable> tables) + throws Exception + { + DistributedQueryRunner.Builder builder = DistributedQueryRunner.builder(session); + extraProperties.forEach(builder::addExtraProperty); + DistributedQueryRunner runner = builder.build(); + try { + runner.installPlugin(new TpchPlugin()); + runner.createCatalog(TPCH_CATALOG, "tpch", Map.of()); + + Map properties = new HashMap<>(connectorProperties); + properties.putIfAbsent("connection-url", JDBC_URL); + properties.putIfAbsent("connection-user", JDBC_USER); + properties.putIfAbsent("connection-password", JDBC_PASSWORD); + + runner.installPlugin(new RedshiftPlugin()); + runner.createCatalog(TEST_CATALOG, CONNECTOR_NAME, properties); + + executeInRedshift("CREATE SCHEMA IF NOT EXISTS " + TEST_SCHEMA); + createUserIfNotExists(NON_GRANTED_USER, JDBC_PASSWORD); + createUserIfNotExists(GRANTED_USER, JDBC_PASSWORD); + + executeInRedshiftWithRetry(format("GRANT ALL PRIVILEGES ON DATABASE %s TO %s", TEST_DATABASE, GRANTED_USER)); + executeInRedshiftWithRetry(format("GRANT ALL PRIVILEGES ON SCHEMA %s TO %s", TEST_SCHEMA, GRANTED_USER)); + + provisionTables(session, runner, tables); + + // This step is necessary for product tests + executeInRedshiftWithRetry(format("GRANT ALL PRIVILEGES ON ALL TABLES IN SCHEMA %s TO %s", TEST_SCHEMA, GRANTED_USER)); + } + catch (Throwable e) { + closeAllSuppress(e, runner); + throw e; + } + return runner; + } + + private static Session createSession() + { + return createSession(GRANTED_USER); + } + + private static Session createSession(String user) + { + return testSessionBuilder() + .setCatalog(TEST_CATALOG) + .setSchema(TEST_SCHEMA) + .setIdentity(Identity.ofUser(user)) + .build(); + } + + private static void createUserIfNotExists(String user, String password) + { + try { + executeInRedshift("CREATE USER " + user + " PASSWORD " + "'" + password + "'"); + } + catch (Exception e) { + // if user already exists, swallow the exception + if (!e.getMessage().matches(".*user \"" + user + "\" already exists.*")) { + throw e; + } + } + } + + private static void executeInRedshiftWithRetry(String sql) + { + Failsafe.with(new RetryPolicy<>() + .handleIf(e -> e.getMessage().matches(".* concurrent transaction .*")) + .withDelay(Duration.ofSeconds(10)) + .withMaxRetries(3)) + .run(() -> executeInRedshift(sql)); + } + + public static void executeInRedshift(String sql, Object... parameters) + { + executeInRedshift(handle -> handle.execute(sql, parameters)); + } + + private static void executeInRedshift(HandleConsumer consumer) + throws E + { + Jdbi.create(JDBC_URL, JDBC_USER, JDBC_PASSWORD).withHandle(consumer.asCallback()); + } + + private static synchronized void provisionTables(Session session, QueryRunner queryRunner, Iterable> tables) + { + Set existingTables = queryRunner.listTables(session, session.getCatalog().orElseThrow(), session.getSchema().orElseThrow()) + .stream() + .map(QualifiedObjectName::getObjectName) + .collect(toUnmodifiableSet()); + + Streams.stream(tables) + .map(table -> table.getTableName().toLowerCase(ENGLISH)) + .filter(name -> !existingTables.contains(name)) + .forEach(name -> copyFromS3(queryRunner, session, name)); + + for (TpchTable tpchTable : tables) { + verifyLoadedDataHasSameSchema(session, queryRunner, tpchTable); + } + } + + private static void copyFromS3(QueryRunner queryRunner, Session session, String name) + { + String s3Path = format("%s/%s/%s.parquet", S3_TPCH_TABLES_ROOT, TPCH_CATALOG, name); + log.info("Creating table %s in Redshift copying from %s", name, s3Path); + + // Create table in ephemeral Redshift cluster with no data + String createTableSql = format("CREATE TABLE %s.%s.%s AS ", session.getCatalog().orElseThrow(), session.getSchema().orElseThrow(), name) + + format("SELECT * FROM %s.%s.%s WITH NO DATA", TPCH_CATALOG, TINY_SCHEMA_NAME, name); + queryRunner.execute(session, createTableSql); + + // Copy data from S3 bucket to ephemeral Redshift + String copySql = "COPY " + TEST_SCHEMA + "." + name + + " FROM '" + s3Path + "'" + + " IAM_ROLE '" + IAM_ROLE + "'" + + " FORMAT PARQUET"; + executeInRedshiftWithRetry(copySql); + } + + private static void copyFromTpchCatalog(QueryRunner queryRunner, Session session, String name) + { + // This function exists in case we need to copy data from the TPCH catalog rather than S3, + // such as moving to a new AWS account or if the schema changes. We can swap this method out for + // copyFromS3 in provisionTables and then export the data again to S3. + copyTable(queryRunner, TPCH_CATALOG, TINY_SCHEMA_NAME, name, session); + } + + private static void verifyLoadedDataHasSameSchema(Session session, QueryRunner queryRunner, TpchTable tpchTable) + { + // We want to verify that the loaded data has the same schema as if we created a fresh table from the TPC-H catalog + // If this assertion fails, we may need to recreate the Redshift tables from the TPC-H catalog and unload the data to S3 + try { + long expectedCount = (long) queryRunner.execute("SELECT count(*) FROM " + format("%s.%s.%s", TPCH_CATALOG, TINY_SCHEMA_NAME, tpchTable.getTableName())).getOnlyValue(); + long actualCount = (long) queryRunner.execute( + "SELECT count(*) FROM " + format( + "%s.%s.%s", + session.getCatalog().orElseThrow(), + session.getSchema().orElseThrow(), + tpchTable.getTableName())).getOnlyValue(); + + if (expectedCount != actualCount) { + throw new RuntimeException(format("Table %s is not loaded correctly. Expected %s rows got %s", tpchTable.getTableName(), expectedCount, actualCount)); + } + + log.info("Checking column types on table %s", tpchTable.getTableName()); + MaterializedResult expectedColumns = queryRunner.execute(format("DESCRIBE %s.%s.%s", TPCH_CATALOG, TINY_SCHEMA_NAME, tpchTable.getTableName())); + MaterializedResult actualColumns = queryRunner.execute("DESCRIBE " + tpchTable.getTableName()); + assertEquals(actualColumns, expectedColumns); + } + catch (Exception e) { + throw new RuntimeException("Failed to assert columns for TPC-H table " + tpchTable.getTableName(), e); + } + } + + /** + * Get the named system property, throwing an exception if it is not set. + */ + private static String requireSystemProperty(String property) + { + return requireNonNull(System.getProperty(property), property + " is not set"); + } + + public static void main(String[] args) + throws Exception + { + Logging.initialize(); + + DistributedQueryRunner queryRunner = createRedshiftQueryRunner( + ImmutableMap.of("http-server.http.port", "8080"), + ImmutableMap.of(), + ImmutableList.of()); + + log.info("======== SERVER STARTED ========"); + log.info("\n====\n%s\n====", queryRunner.getCoordinator().getBaseUrl()); + } +} diff --git a/plugin/trino-redshift/src/test/java/io/trino/plugin/redshift/TestRedshiftConnectorTest.java b/plugin/trino-redshift/src/test/java/io/trino/plugin/redshift/TestRedshiftConnectorTest.java new file mode 100644 index 000000000000..14bf231c048c --- /dev/null +++ b/plugin/trino-redshift/src/test/java/io/trino/plugin/redshift/TestRedshiftConnectorTest.java @@ -0,0 +1,193 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.plugin.redshift; + +import com.google.common.collect.ImmutableMap; +import io.trino.plugin.jdbc.BaseJdbcConnectorTest; +import io.trino.testing.QueryRunner; +import io.trino.testing.TestingConnectorBehavior; +import io.trino.testing.sql.SqlExecutor; +import io.trino.testing.sql.TestTable; +import io.trino.tpch.TpchTable; +import org.testng.SkipException; +import org.testng.annotations.Test; + +import java.util.Optional; + +import static io.trino.plugin.redshift.RedshiftQueryRunner.TEST_SCHEMA; +import static io.trino.plugin.redshift.RedshiftQueryRunner.createRedshiftQueryRunner; +import static java.lang.String.format; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +public class TestRedshiftConnectorTest + extends BaseJdbcConnectorTest +{ + @Override + protected QueryRunner createQueryRunner() + throws Exception + { + return createRedshiftQueryRunner( + ImmutableMap.of(), + ImmutableMap.of(), + // NOTE this can cause tests to time-out if larger tables like + // lineitem and orders need to be re-created. + TpchTable.getTables()); + } + + @Override + @SuppressWarnings("DuplicateBranchesInSwitch") // options here are grouped per-feature + protected boolean hasBehavior(TestingConnectorBehavior connectorBehavior) + { + switch (connectorBehavior) { + case SUPPORTS_DELETE: + case SUPPORTS_AGGREGATION_PUSHDOWN: + case SUPPORTS_JOIN_PUSHDOWN: + case SUPPORTS_TOPN_PUSHDOWN: + case SUPPORTS_PREDICATE_PUSHDOWN_WITH_VARCHAR_INEQUALITY: + return false; + + case SUPPORTS_COMMENT_ON_TABLE: + case SUPPORTS_ADD_COLUMN_WITH_COMMENT: + case SUPPORTS_CREATE_TABLE_WITH_TABLE_COMMENT: + case SUPPORTS_CREATE_TABLE_WITH_COLUMN_COMMENT: + return false; + + case SUPPORTS_ARRAY: + case SUPPORTS_ROW_TYPE: + return false; + + case SUPPORTS_RENAME_TABLE_ACROSS_SCHEMAS: + return false; + + default: + return super.hasBehavior(connectorBehavior); + } + } + + @Override + protected TestTable createTableWithDefaultColumns() + { + return new TestTable( + onRemoteDatabase(), + format("%s.test_table_with_default_columns", TEST_SCHEMA), + "(col_required BIGINT NOT NULL," + + "col_nullable BIGINT," + + "col_default BIGINT DEFAULT 43," + + "col_nonnull_default BIGINT NOT NULL DEFAULT 42," + + "col_required2 BIGINT NOT NULL)"); + } + + @Override + protected Optional filterDataMappingSmokeTestData(DataMappingTestSetup dataMappingTestSetup) + { + String typeName = dataMappingTestSetup.getTrinoTypeName(); + if ("date".equals(typeName)) { + if (dataMappingTestSetup.getSampleValueLiteral().equals("DATE '1582-10-05'")) { + return Optional.empty(); + } + } + if ("tinyint".equals(typeName) || typeName.startsWith("time") || "varbinary".equals(typeName)) { + return Optional.empty(); + } + return Optional.of(dataMappingTestSetup); + } + + /** + * Overridden due to Redshift not supporting non-ASCII characters in CHAR. + */ + @Override + public void testCreateTableAsSelectWithUnicode() + { + assertThatThrownBy(super::testCreateTableAsSelectWithUnicode) + .hasStackTraceContaining("Value too long for character type"); + // NOTE we add a copy of the above using VARCHAR which supports non-ASCII characters + assertCreateTableAsSelect( + "SELECT CAST('\u2603' AS VARCHAR) unicode", + "SELECT 1"); + } + + @Override + @Test + public void testReadMetadataWithRelationsConcurrentModifications() + { + throw new SkipException("Test fails with a timeout sometimes and is flaky"); + } + + @Override + public void testInsertRowConcurrently() + { + throw new SkipException("Test fails with a timeout sometimes and is flaky"); + } + + @Override + protected String errorMessageForInsertIntoNotNullColumn(String columnName) + { + return format("(?s).*Cannot insert a NULL value into column %s.*", columnName); + } + + @Override + public void testCreateSchemaWithLongName() + { + throw new SkipException("Long name checks not implemented"); + } + + @Override + public void testRenameSchemaToLongName() + { + throw new SkipException("Long name checks not implemented"); + } + + @Override + public void testCreateTableWithLongTableName() + { + throw new SkipException("Long name checks not implemented"); + } + + @Override + public void testRenameTableToLongTableName() + { + throw new SkipException("Long name checks not implemented"); + } + + @Override + public void testCreateTableWithLongColumnName() + { + throw new SkipException("Long name checks not implemented"); + } + + @Override + public void testAlterTableAddLongColumnName() + { + throw new SkipException("Long name checks not implemented"); + } + + @Override + public void testAlterTableRenameColumnToLongName() + { + throw new SkipException("Long name checks not implemented"); + } + + @Override + protected SqlExecutor onRemoteDatabase() + { + return RedshiftQueryRunner::executeInRedshift; + } + + @Test + @Override + public void testAddNotNullColumnToNonEmptyTable() + { + throw new SkipException("Redshift ALTER TABLE ADD COLUMN defined as NOT NULL must have a non-null default expression"); + } +} From 0312463d78c43621fb31526f0d472cc12560afa3 Mon Sep 17 00:00:00 2001 From: Dain Sundstrom Date: Sat, 10 Dec 2022 16:57:55 -0800 Subject: [PATCH 02/24] Add Redshift schema, table, and column length checks --- .../trino/plugin/redshift/RedshiftClient.java | 31 ++++++++++++++++++ .../redshift/TestRedshiftConnectorTest.java | 32 ++++++++----------- 2 files changed, 45 insertions(+), 18 deletions(-) diff --git a/plugin/trino-redshift/src/main/java/io/trino/plugin/redshift/RedshiftClient.java b/plugin/trino-redshift/src/main/java/io/trino/plugin/redshift/RedshiftClient.java index af0078309e9f..57f43820c501 100644 --- a/plugin/trino-redshift/src/main/java/io/trino/plugin/redshift/RedshiftClient.java +++ b/plugin/trino-redshift/src/main/java/io/trino/plugin/redshift/RedshiftClient.java @@ -35,6 +35,7 @@ import javax.inject.Inject; import java.sql.Connection; +import java.sql.DatabaseMetaData; import java.sql.PreparedStatement; import java.sql.ResultSet; import java.sql.SQLException; @@ -130,6 +131,36 @@ public PreparedStatement getPreparedStatement(Connection connection, String sql) return statement; } + @Override + protected void verifySchemaName(DatabaseMetaData databaseMetadata, String schemaName) + throws SQLException + { + // Redshift truncates schema name to 127 chars silently + if (schemaName.length() > databaseMetadata.getMaxSchemaNameLength()) { + throw new TrinoException(NOT_SUPPORTED, "Schema name must be shorter than or equal to '%d' characters but got '%d'".formatted(databaseMetadata.getMaxSchemaNameLength(), schemaName.length())); + } + } + + @Override + protected void verifyTableName(DatabaseMetaData databaseMetadata, String tableName) + throws SQLException + { + // Redshift truncates table name to 127 chars silently + if (tableName.length() > databaseMetadata.getMaxTableNameLength()) { + throw new TrinoException(NOT_SUPPORTED, "Table name must be shorter than or equal to '%d' characters but got '%d'".formatted(databaseMetadata.getMaxTableNameLength(), tableName.length())); + } + } + + @Override + protected void verifyColumnName(DatabaseMetaData databaseMetadata, String columnName) + throws SQLException + { + // Redshift truncates table name to 127 chars silently + if (columnName.length() > databaseMetadata.getMaxColumnNameLength()) { + throw new TrinoException(NOT_SUPPORTED, "Column name must be shorter than or equal to '%d' characters but got '%d'".formatted(databaseMetadata.getMaxColumnNameLength(), columnName.length())); + } + } + @Override public Optional toColumnMapping(ConnectorSession session, Connection connection, JdbcTypeHandle typeHandle) { diff --git a/plugin/trino-redshift/src/test/java/io/trino/plugin/redshift/TestRedshiftConnectorTest.java b/plugin/trino-redshift/src/test/java/io/trino/plugin/redshift/TestRedshiftConnectorTest.java index 14bf231c048c..0bd3f862e91e 100644 --- a/plugin/trino-redshift/src/test/java/io/trino/plugin/redshift/TestRedshiftConnectorTest.java +++ b/plugin/trino-redshift/src/test/java/io/trino/plugin/redshift/TestRedshiftConnectorTest.java @@ -24,10 +24,12 @@ import org.testng.annotations.Test; import java.util.Optional; +import java.util.OptionalInt; import static io.trino.plugin.redshift.RedshiftQueryRunner.TEST_SCHEMA; import static io.trino.plugin.redshift.RedshiftQueryRunner.createRedshiftQueryRunner; import static java.lang.String.format; +import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; public class TestRedshiftConnectorTest @@ -137,45 +139,39 @@ protected String errorMessageForInsertIntoNotNullColumn(String columnName) } @Override - public void testCreateSchemaWithLongName() + protected OptionalInt maxSchemaNameLength() { - throw new SkipException("Long name checks not implemented"); + return OptionalInt.of(127); } @Override - public void testRenameSchemaToLongName() + protected void verifySchemaNameLengthFailurePermissible(Throwable e) { - throw new SkipException("Long name checks not implemented"); + assertThat(e).hasMessage("Schema name must be shorter than or equal to '127' characters but got '128'"); } @Override - public void testCreateTableWithLongTableName() + protected OptionalInt maxTableNameLength() { - throw new SkipException("Long name checks not implemented"); + return OptionalInt.of(127); } @Override - public void testRenameTableToLongTableName() + protected void verifyTableNameLengthFailurePermissible(Throwable e) { - throw new SkipException("Long name checks not implemented"); + assertThat(e).hasMessage("Table name must be shorter than or equal to '127' characters but got '128'"); } @Override - public void testCreateTableWithLongColumnName() + protected OptionalInt maxColumnNameLength() { - throw new SkipException("Long name checks not implemented"); + return OptionalInt.of(127); } @Override - public void testAlterTableAddLongColumnName() + protected void verifyColumnNameLengthFailurePermissible(Throwable e) { - throw new SkipException("Long name checks not implemented"); - } - - @Override - public void testAlterTableRenameColumnToLongName() - { - throw new SkipException("Long name checks not implemented"); + assertThat(e).hasMessage("Column name must be shorter than or equal to '127' characters but got '128'"); } @Override From 7eab69a7df84ba2195229eb175f76fcd16a0d899 Mon Sep 17 00:00:00 2001 From: Dain Sundstrom Date: Sat, 10 Dec 2022 17:51:24 -0800 Subject: [PATCH 03/24] Implement proper type mapping for Redshift --- plugin/trino-redshift/pom.xml | 17 +- .../trino/plugin/redshift/RedshiftClient.java | 446 +++++++- .../plugin/redshift/RedshiftClientModule.java | 15 +- .../plugin/redshift/RedshiftQueryRunner.java | 17 +- .../redshift/TestRedshiftConnectorTest.java | 63 +- .../redshift/TestRedshiftTypeMapping.java | 994 ++++++++++++++++++ 6 files changed, 1528 insertions(+), 24 deletions(-) create mode 100644 plugin/trino-redshift/src/test/java/io/trino/plugin/redshift/TestRedshiftTypeMapping.java diff --git a/plugin/trino-redshift/pom.xml b/plugin/trino-redshift/pom.xml index 30fea843c921..ebcd96279cda 100644 --- a/plugin/trino-redshift/pom.xml +++ b/plugin/trino-redshift/pom.xml @@ -23,12 +23,22 @@ trino-base-jdbc + + io.airlift + configuration + + com.amazon.redshift redshift-jdbc42 2.1.0.9 + + com.google.guava + guava + + com.google.inject guice @@ -52,12 +62,6 @@ runtime - - com.google.guava - guava - runtime - - net.jodah failsafe @@ -173,6 +177,7 @@ **/TestRedshiftConnectorTest.java + **/TestRedshiftTypeMapping.java diff --git a/plugin/trino-redshift/src/main/java/io/trino/plugin/redshift/RedshiftClient.java b/plugin/trino-redshift/src/main/java/io/trino/plugin/redshift/RedshiftClient.java index 57f43820c501..05d27ab5a0f0 100644 --- a/plugin/trino-redshift/src/main/java/io/trino/plugin/redshift/RedshiftClient.java +++ b/plugin/trino-redshift/src/main/java/io/trino/plugin/redshift/RedshiftClient.java @@ -13,6 +13,10 @@ */ package io.trino.plugin.redshift; +import com.amazon.redshift.jdbc.RedshiftPreparedStatement; +import com.amazon.redshift.util.RedshiftObject; +import com.google.common.base.CharMatcher; +import io.airlift.slice.Slice; import io.trino.plugin.jdbc.BaseJdbcClient; import io.trino.plugin.jdbc.BaseJdbcConfig; import io.trino.plugin.jdbc.ColumnMapping; @@ -20,33 +24,58 @@ import io.trino.plugin.jdbc.JdbcColumnHandle; import io.trino.plugin.jdbc.JdbcTableHandle; import io.trino.plugin.jdbc.JdbcTypeHandle; +import io.trino.plugin.jdbc.LongWriteFunction; +import io.trino.plugin.jdbc.ObjectReadFunction; +import io.trino.plugin.jdbc.ObjectWriteFunction; import io.trino.plugin.jdbc.QueryBuilder; +import io.trino.plugin.jdbc.SliceWriteFunction; +import io.trino.plugin.jdbc.StandardColumnMappings; import io.trino.plugin.jdbc.WriteMapping; import io.trino.plugin.jdbc.logging.RemoteQueryModifier; import io.trino.plugin.jdbc.mapping.IdentifierMapping; import io.trino.spi.TrinoException; import io.trino.spi.connector.ConnectorSession; import io.trino.spi.type.CharType; +import io.trino.spi.type.Chars; import io.trino.spi.type.DecimalType; import io.trino.spi.type.Decimals; +import io.trino.spi.type.Int128; +import io.trino.spi.type.LongTimestamp; +import io.trino.spi.type.LongTimestampWithTimeZone; +import io.trino.spi.type.TimeType; +import io.trino.spi.type.TimestampType; +import io.trino.spi.type.TimestampWithTimeZoneType; import io.trino.spi.type.Type; import io.trino.spi.type.VarcharType; import javax.inject.Inject; +import java.math.BigDecimal; +import java.math.BigInteger; +import java.math.MathContext; import java.sql.Connection; import java.sql.DatabaseMetaData; import java.sql.PreparedStatement; import java.sql.ResultSet; import java.sql.SQLException; import java.sql.Types; +import java.time.Instant; +import java.time.LocalDate; +import java.time.LocalDateTime; +import java.time.LocalTime; +import java.time.OffsetDateTime; +import java.time.ZoneOffset; +import java.time.format.DateTimeFormatter; +import java.time.format.DateTimeFormatterBuilder; import java.util.Optional; import java.util.function.BiFunction; +import static io.trino.plugin.jdbc.JdbcErrorCode.JDBC_NON_TRANSIENT_ERROR; import static io.trino.plugin.jdbc.StandardColumnMappings.bigintColumnMapping; import static io.trino.plugin.jdbc.StandardColumnMappings.bigintWriteFunction; import static io.trino.plugin.jdbc.StandardColumnMappings.booleanColumnMapping; import static io.trino.plugin.jdbc.StandardColumnMappings.booleanWriteFunction; +import static io.trino.plugin.jdbc.StandardColumnMappings.charReadFunction; import static io.trino.plugin.jdbc.StandardColumnMappings.charWriteFunction; import static io.trino.plugin.jdbc.StandardColumnMappings.dateColumnMappingUsingSqlDate; import static io.trino.plugin.jdbc.StandardColumnMappings.dateWriteFunctionUsingSqlDate; @@ -57,6 +86,7 @@ import static io.trino.plugin.jdbc.StandardColumnMappings.doubleWriteFunction; import static io.trino.plugin.jdbc.StandardColumnMappings.integerColumnMapping; import static io.trino.plugin.jdbc.StandardColumnMappings.integerWriteFunction; +import static io.trino.plugin.jdbc.StandardColumnMappings.longDecimalReadFunction; import static io.trino.plugin.jdbc.StandardColumnMappings.longDecimalWriteFunction; import static io.trino.plugin.jdbc.StandardColumnMappings.realColumnMapping; import static io.trino.plugin.jdbc.StandardColumnMappings.realWriteFunction; @@ -68,29 +98,97 @@ import static io.trino.plugin.jdbc.StandardColumnMappings.tinyintColumnMapping; import static io.trino.plugin.jdbc.StandardColumnMappings.tinyintWriteFunction; import static io.trino.plugin.jdbc.StandardColumnMappings.varbinaryColumnMapping; -import static io.trino.plugin.jdbc.StandardColumnMappings.varbinaryWriteFunction; +import static io.trino.plugin.jdbc.StandardColumnMappings.varbinaryReadFunction; +import static io.trino.plugin.jdbc.StandardColumnMappings.varcharColumnMapping; import static io.trino.plugin.jdbc.StandardColumnMappings.varcharWriteFunction; import static io.trino.plugin.jdbc.TypeHandlingJdbcSessionProperties.getUnsupportedTypeHandling; import static io.trino.plugin.jdbc.UnsupportedTypeHandling.CONVERT_TO_VARCHAR; +import static io.trino.spi.StandardErrorCode.INVALID_ARGUMENTS; import static io.trino.spi.StandardErrorCode.NOT_SUPPORTED; import static io.trino.spi.type.BigintType.BIGINT; import static io.trino.spi.type.BooleanType.BOOLEAN; +import static io.trino.spi.type.CharType.createCharType; +import static io.trino.spi.type.DateTimeEncoding.unpackMillisUtc; import static io.trino.spi.type.DateType.DATE; import static io.trino.spi.type.DecimalType.createDecimalType; import static io.trino.spi.type.DoubleType.DOUBLE; import static io.trino.spi.type.IntegerType.INTEGER; +import static io.trino.spi.type.LongTimestampWithTimeZone.fromEpochSecondsAndFraction; import static io.trino.spi.type.RealType.REAL; import static io.trino.spi.type.SmallintType.SMALLINT; +import static io.trino.spi.type.TimeType.TIME_MICROS; +import static io.trino.spi.type.TimeZoneKey.UTC_KEY; +import static io.trino.spi.type.TimestampType.TIMESTAMP_MICROS; import static io.trino.spi.type.TimestampType.TIMESTAMP_MILLIS; +import static io.trino.spi.type.TimestampWithTimeZoneType.TIMESTAMP_TZ_MICROS; +import static io.trino.spi.type.Timestamps.MICROSECONDS_PER_DAY; +import static io.trino.spi.type.Timestamps.MILLISECONDS_PER_SECOND; +import static io.trino.spi.type.Timestamps.NANOSECONDS_PER_MICROSECOND; +import static io.trino.spi.type.Timestamps.NANOSECONDS_PER_MILLISECOND; +import static io.trino.spi.type.Timestamps.PICOSECONDS_PER_MICROSECOND; +import static io.trino.spi.type.Timestamps.PICOSECONDS_PER_NANOSECOND; +import static io.trino.spi.type.Timestamps.roundDiv; import static io.trino.spi.type.TinyintType.TINYINT; import static io.trino.spi.type.VarbinaryType.VARBINARY; +import static io.trino.spi.type.VarcharType.createUnboundedVarcharType; +import static io.trino.spi.type.VarcharType.createVarcharType; +import static java.lang.Math.floorDiv; +import static java.lang.Math.floorMod; import static java.lang.Math.max; +import static java.lang.Math.min; import static java.lang.String.format; +import static java.math.RoundingMode.UNNECESSARY; +import static java.time.temporal.ChronoField.NANO_OF_SECOND; import static java.util.Objects.requireNonNull; public class RedshiftClient extends BaseJdbcClient { + /** + * Redshift does not handle values larger than 64 bits for + * {@code DECIMAL(19, s)}. It supports the full range of values for all + * other precisions. + * + * @see + * Redshift documentation + */ + private static final int REDSHIFT_DECIMAL_CUTOFF_PRECISION = 19; + + static final int REDSHIFT_MAX_DECIMAL_PRECISION = 38; + + /** + * Maximum size of a {@link BigInteger} storing a Redshift {@code DECIMAL} + * with precision {@link #REDSHIFT_DECIMAL_CUTOFF_PRECISION}. + */ + // actual value is 63 + private static final int REDSHIFT_DECIMAL_CUTOFF_BITS = BigInteger.valueOf(Long.MAX_VALUE).bitLength(); + + /** + * Maximum size of a Redshift CHAR column. + * + * @see + * Redshift documentation + */ + private static final int REDSHIFT_MAX_CHAR = 4096; + + /** + * Maximum size of a Redshift VARCHAR column. + * + * @see + * Redshift documentation + */ + static final int REDSHIFT_MAX_VARCHAR = 65535; + + private static final DateTimeFormatter DATE_FORMATTER = DateTimeFormatter.ofPattern("yyy-MM-dd[ G]"); + private static final DateTimeFormatter DATE_TIME_FORMATTER = new DateTimeFormatterBuilder() + .appendPattern("yyy-MM-dd HH:mm:ss") + .optionalStart() + .appendFraction(NANO_OF_SECOND, 0, 6, true) + .optionalEnd() + .appendPattern("[ G]") + .toFormatter(); + private static final OffsetDateTime REDSHIFT_MIN_SUPPORTED_TIMESTAMP_TZ = OffsetDateTime.of(-4712, 1, 1, 0, 0, 0, 0, ZoneOffset.UTC); + @Inject public RedshiftClient(BaseJdbcConfig config, ConnectionFactory connectionFactory, QueryBuilder queryBuilder, IdentifierMapping identifierMapping, RemoteQueryModifier queryModifier) { @@ -162,7 +260,96 @@ protected void verifyColumnName(DatabaseMetaData databaseMetadata, String column } @Override - public Optional toColumnMapping(ConnectorSession session, Connection connection, JdbcTypeHandle typeHandle) + public Optional toColumnMapping(ConnectorSession session, Connection connection, JdbcTypeHandle type) + { + Optional mapping = getForcedMappingToVarchar(type); + if (mapping.isPresent()) { + return mapping; + } + + if ("time".equals(type.getJdbcTypeName().orElse(""))) { + return Optional.of(ColumnMapping.longMapping( + TIME_MICROS, + RedshiftClient::readTime, + RedshiftClient::writeTime)); + } + + switch (type.getJdbcType()) { + case Types.BIT: // Redshift uses this for booleans + return Optional.of(booleanColumnMapping()); + + // case Types.TINYINT: -- Redshift doesn't support tinyint + case Types.SMALLINT: + return Optional.of(smallintColumnMapping()); + case Types.INTEGER: + return Optional.of(integerColumnMapping()); + case Types.BIGINT: + return Optional.of(bigintColumnMapping()); + + case Types.REAL: + return Optional.of(realColumnMapping()); + case Types.DOUBLE: + return Optional.of(doubleColumnMapping()); + + case Types.NUMERIC: { + int precision = type.getRequiredColumnSize(); + int scale = type.getRequiredDecimalDigits(); + DecimalType decimalType = createDecimalType(precision, scale); + if (precision == REDSHIFT_DECIMAL_CUTOFF_PRECISION) { + return Optional.of(ColumnMapping.objectMapping( + decimalType, + longDecimalReadFunction(decimalType), + writeDecimalAtRedshiftCutoff(scale))); + } + return Optional.of(decimalColumnMapping(decimalType, UNNECESSARY)); + } + + case Types.CHAR: + CharType charType = createCharType(type.getRequiredColumnSize()); + return Optional.of(ColumnMapping.sliceMapping( + charType, + charReadFunction(charType), + RedshiftClient::writeChar)); + + case Types.VARCHAR: { + int length = type.getRequiredColumnSize(); + return Optional.of(varcharColumnMapping( + length < VarcharType.MAX_LENGTH + ? createVarcharType(length) + : createUnboundedVarcharType(), + true)); + } + + case Types.LONGVARBINARY: + return Optional.of(ColumnMapping.sliceMapping( + VARBINARY, + varbinaryReadFunction(), + varbinaryWriteFunction())); + + case Types.DATE: + return Optional.of(ColumnMapping.longMapping( + DATE, + RedshiftClient::readDate, + RedshiftClient::writeDate)); + + case Types.TIMESTAMP: + return Optional.of(ColumnMapping.longMapping( + TIMESTAMP_MICROS, + RedshiftClient::readTimestamp, + RedshiftClient::writeShortTimestamp)); + + case Types.TIMESTAMP_WITH_TIMEZONE: + return Optional.of(ColumnMapping.objectMapping( + TIMESTAMP_TZ_MICROS, + longTimestampWithTimeZoneReadFunction(), + longTimestampWithTimeZoneWriteFunction())); + } + + // Fall back to default behavior + return legacyToColumnMapping(session, type); + } + + private Optional legacyToColumnMapping(ConnectorSession session, JdbcTypeHandle typeHandle) { Optional mapping = getForcedMappingToVarchar(typeHandle); if (mapping.isPresent()) { @@ -181,6 +368,99 @@ public Optional toColumnMapping(ConnectorSession session, Connect @Override public WriteMapping toWriteMapping(ConnectorSession session, Type type) { + if (BOOLEAN.equals(type)) { + return WriteMapping.booleanMapping("boolean", booleanWriteFunction()); + } + if (TINYINT.equals(type)) { + // Redshift doesn't have tinyint + return WriteMapping.longMapping("smallint", tinyintWriteFunction()); + } + if (SMALLINT.equals(type)) { + return WriteMapping.longMapping("smallint", smallintWriteFunction()); + } + if (INTEGER.equals(type)) { + return WriteMapping.longMapping("integer", integerWriteFunction()); + } + if (BIGINT.equals(type)) { + return WriteMapping.longMapping("bigint", bigintWriteFunction()); + } + if (REAL.equals(type)) { + return WriteMapping.longMapping("real", realWriteFunction()); + } + if (DOUBLE.equals(type)) { + return WriteMapping.doubleMapping("double precision", doubleWriteFunction()); + } + + if (type instanceof DecimalType decimal) { + if (decimal.getPrecision() == REDSHIFT_DECIMAL_CUTOFF_PRECISION) { + // See doc for REDSHIFT_DECIMAL_CUTOFF_PRECISION + return WriteMapping.objectMapping( + format("decimal(%s, %s)", decimal.getPrecision(), decimal.getScale()), + writeDecimalAtRedshiftCutoff(decimal.getScale())); + } + String name = format("decimal(%s, %s)", decimal.getPrecision(), decimal.getScale()); + return decimal.isShort() + ? WriteMapping.longMapping(name, shortDecimalWriteFunction(decimal)) + : WriteMapping.objectMapping(name, longDecimalWriteFunction(decimal)); + } + + if (type instanceof CharType) { + // Redshift has no unbounded text/binary types, so if a CHAR is too + // large for Redshift, we write as VARCHAR. If too large for that, + // we use the largest VARCHAR Redshift supports. + int size = ((CharType) type).getLength(); + if (size <= REDSHIFT_MAX_CHAR) { + return WriteMapping.sliceMapping( + format("char(%d)", size), + RedshiftClient::writeChar); + } + int redshiftVarcharWidth = min(size, REDSHIFT_MAX_VARCHAR); + return WriteMapping.sliceMapping( + format("varchar(%d)", redshiftVarcharWidth), + (statement, index, value) -> writeCharAsVarchar(statement, index, value, redshiftVarcharWidth)); + } + + if (type instanceof VarcharType) { + // Redshift has no unbounded text/binary types, so if a VARCHAR is + // larger than Redshift's limit, we make it that big instead. + int size = ((VarcharType) type).getLength() + .filter(n -> n <= REDSHIFT_MAX_VARCHAR) + .orElse(REDSHIFT_MAX_VARCHAR); + return WriteMapping.sliceMapping(format("varchar(%d)", size), varcharWriteFunction()); + } + + if (VARBINARY.equals(type)) { + return WriteMapping.sliceMapping("varbyte", varbinaryWriteFunction()); + } + + if (DATE.equals(type)) { + return WriteMapping.longMapping("date", RedshiftClient::writeDate); + } + + if (type instanceof TimeType) { + return WriteMapping.longMapping("time", RedshiftClient::writeTime); + } + + if (type instanceof TimestampType) { + if (((TimestampType) type).isShort()) { + return WriteMapping.longMapping( + "timestamp", + RedshiftClient::writeShortTimestamp); + } + return WriteMapping.objectMapping( + "timestamp", + LongTimestamp.class, + RedshiftClient::writeLongTimestamp); + } + + if (type instanceof TimestampWithTimeZoneType timestampWithTimeZoneType) { + if (timestampWithTimeZoneType.getPrecision() <= TimestampWithTimeZoneType.MAX_SHORT_PRECISION) { + return WriteMapping.longMapping("timestamptz", shortTimestampWithTimeZoneWriteFunction()); + } + return WriteMapping.objectMapping("timestamptz", longTimestampWithTimeZoneWriteFunction()); + } + + // Fall back to legacy behavior return legacyToWriteMapping(type); } @@ -214,9 +494,168 @@ private static String redshiftVarcharLiteral(String value) return "'" + value.replace("'", "''").replace("\\", "\\\\") + "'"; } + private static ObjectReadFunction longTimestampWithTimeZoneReadFunction() + { + return ObjectReadFunction.of( + LongTimestampWithTimeZone.class, + (resultSet, columnIndex) -> { + // Redshift does not store zone information in "timestamp with time zone" data type + OffsetDateTime offsetDateTime = resultSet.getObject(columnIndex, OffsetDateTime.class); + return fromEpochSecondsAndFraction( + offsetDateTime.toEpochSecond(), + (long) offsetDateTime.getNano() * PICOSECONDS_PER_NANOSECOND, + UTC_KEY); + }); + } + + private static LongWriteFunction shortTimestampWithTimeZoneWriteFunction() + { + return (statement, index, value) -> { + // Redshift does not store zone information in "timestamp with time zone" data type + long millisUtc = unpackMillisUtc(value); + long epochSeconds = floorDiv(millisUtc, MILLISECONDS_PER_SECOND); + int nanosOfSecond = floorMod(millisUtc, MILLISECONDS_PER_SECOND) * NANOSECONDS_PER_MILLISECOND; + OffsetDateTime offsetDateTime = OffsetDateTime.ofInstant(Instant.ofEpochSecond(epochSeconds, nanosOfSecond), UTC_KEY.getZoneId()); + verifySupportedTimestampWithTimeZone(offsetDateTime); + statement.setObject(index, offsetDateTime); + }; + } + + private static ObjectWriteFunction longTimestampWithTimeZoneWriteFunction() + { + return ObjectWriteFunction.of( + LongTimestampWithTimeZone.class, + (statement, index, value) -> { + // Redshift does not store zone information in "timestamp with time zone" data type + long epochSeconds = floorDiv(value.getEpochMillis(), MILLISECONDS_PER_SECOND); + long nanosOfSecond = ((long) floorMod(value.getEpochMillis(), MILLISECONDS_PER_SECOND) * NANOSECONDS_PER_MILLISECOND) + + (value.getPicosOfMilli() / PICOSECONDS_PER_NANOSECOND); + OffsetDateTime offsetDateTime = OffsetDateTime.ofInstant(Instant.ofEpochSecond(epochSeconds, nanosOfSecond), UTC_KEY.getZoneId()); + verifySupportedTimestampWithTimeZone(offsetDateTime); + statement.setObject(index, offsetDateTime); + }); + } + + private static void verifySupportedTimestampWithTimeZone(OffsetDateTime value) + { + if (value.isBefore(REDSHIFT_MIN_SUPPORTED_TIMESTAMP_TZ)) { + DateTimeFormatter format = DateTimeFormatter.ofPattern("uuuu-MM-dd HH:mm:ss.SSSSSS"); + throw new TrinoException( + INVALID_ARGUMENTS, + format("Minimum timestamp with time zone in Redshift is %s: %s", REDSHIFT_MIN_SUPPORTED_TIMESTAMP_TZ.format(format), value.format(format))); + } + } + + /** + * Decimal write function for precision {@link #REDSHIFT_DECIMAL_CUTOFF_PRECISION}. + * Ensures that values fit in 8 bytes. + */ + private static ObjectWriteFunction writeDecimalAtRedshiftCutoff(int scale) + { + return ObjectWriteFunction.of( + Int128.class, + (statement, index, decimal) -> { + BigInteger unscaled = decimal.toBigInteger(); + if (unscaled.bitLength() > REDSHIFT_DECIMAL_CUTOFF_BITS) { + throw new TrinoException(JDBC_NON_TRANSIENT_ERROR, format( + "Value out of range for Redshift DECIMAL(%d, %d)", + REDSHIFT_DECIMAL_CUTOFF_PRECISION, + scale)); + } + MathContext precision = new MathContext(REDSHIFT_DECIMAL_CUTOFF_PRECISION); + statement.setBigDecimal(index, new BigDecimal(unscaled, scale, precision)); + }); + } + + /** + * Like {@link StandardColumnMappings#charWriteFunction}, but restrict to + * ASCII because Redshift only allows ASCII in {@code CHAR} values. + */ + private static void writeChar(PreparedStatement statement, int index, Slice slice) + throws SQLException + { + String value = slice.toStringUtf8(); + if (!CharMatcher.ascii().matchesAllOf(value)) { + throw new TrinoException( + JDBC_NON_TRANSIENT_ERROR, + format("Value for Redshift CHAR must be ASCII, but found '%s'", value)); + } + statement.setString(index, slice.toStringAscii()); + } + + /** + * Like {@link StandardColumnMappings#charWriteFunction}, but pads + * the value with spaces to simulate {@code CHAR} semantics. + */ + private static void writeCharAsVarchar(PreparedStatement statement, int index, Slice slice, int columnLength) + throws SQLException + { + // Redshift counts varchar size limits in UTF-8 bytes, so this may make the string longer than + // the limit, but Redshift also truncates extra trailing spaces, so that doesn't cause any problems. + statement.setString(index, Chars.padSpaces(slice, columnLength).toStringUtf8()); + } + + private static void writeDate(PreparedStatement statement, int index, long day) + throws SQLException + { + statement.setObject(index, new RedshiftObject("date", DATE_FORMATTER.format(LocalDate.ofEpochDay(day)))); + } + + private static long readDate(ResultSet results, int index) + throws SQLException + { + // Reading date as string to workaround issues around julian->gregorian calendar switch + return LocalDate.parse(results.getString(index), DATE_FORMATTER).toEpochDay(); + } + + /** + * Write time with microsecond precision + */ + private static void writeTime(PreparedStatement statement, int index, long picos) + throws SQLException + { + statement.setObject(index, LocalTime.ofNanoOfDay((roundDiv(picos, PICOSECONDS_PER_MICROSECOND) % MICROSECONDS_PER_DAY) * NANOSECONDS_PER_MICROSECOND)); + } + + /** + * Read a time value with microsecond precision + */ + private static long readTime(ResultSet results, int index) + throws SQLException + { + return results.getObject(index, LocalTime.class).toNanoOfDay() * PICOSECONDS_PER_NANOSECOND; + } + + private static void writeShortTimestamp(PreparedStatement statement, int index, long epochMicros) + throws SQLException + { + statement.setObject(index, new RedshiftObject("timestamp", DATE_TIME_FORMATTER.format(StandardColumnMappings.fromTrinoTimestamp(epochMicros)))); + } + + private static void writeLongTimestamp(PreparedStatement statement, int index, Object value) + throws SQLException + { + LongTimestamp timestamp = (LongTimestamp) value; + long epochMicros = timestamp.getEpochMicros(); + if (timestamp.getPicosOfMicro() >= PICOSECONDS_PER_MICROSECOND / 2) { + epochMicros += 1; // Add one micro if picos round up + } + statement.setObject(index, new RedshiftObject("timestamp", DATE_TIME_FORMATTER.format(StandardColumnMappings.fromTrinoTimestamp(epochMicros)))); + } + + private static long readTimestamp(ResultSet results, int index) + throws SQLException + { + return StandardColumnMappings.toTrinoTimestamp(TIMESTAMP_MICROS, results.getObject(index, LocalDateTime.class)); + } + + private static SliceWriteFunction varbinaryWriteFunction() + { + return (statement, index, value) -> statement.unwrap(RedshiftPreparedStatement.class).setVarbyte(index, value.getBytes()); + } + private static Optional legacyDefaultColumnMapping(JdbcTypeHandle typeHandle) { - // TODO (https://github.com/trinodb/trino/issues/497) Implement proper type mapping and add test // This method is copied from deprecated StandardColumnMappings.legacyDefaultColumnMapping() switch (typeHandle.getJdbcType()) { case Types.BIT: @@ -282,7 +721,6 @@ private static Optional legacyDefaultColumnMapping(JdbcTypeHandle private static WriteMapping legacyToWriteMapping(Type type) { - // TODO (https://github.com/trinodb/trino/issues/497) Implement proper type mapping and add test // This method is copied from deprecated BaseJdbcClient.legacyToWriteMapping() if (type == BOOLEAN) { return WriteMapping.booleanMapping("boolean", booleanWriteFunction()); diff --git a/plugin/trino-redshift/src/main/java/io/trino/plugin/redshift/RedshiftClientModule.java b/plugin/trino-redshift/src/main/java/io/trino/plugin/redshift/RedshiftClientModule.java index dccc6fabff94..13635c88f69b 100644 --- a/plugin/trino-redshift/src/main/java/io/trino/plugin/redshift/RedshiftClientModule.java +++ b/plugin/trino-redshift/src/main/java/io/trino/plugin/redshift/RedshiftClientModule.java @@ -15,12 +15,12 @@ import com.amazon.redshift.Driver; import com.google.inject.Binder; -import com.google.inject.Module; import com.google.inject.Provides; -import com.google.inject.Scopes; import com.google.inject.Singleton; +import io.airlift.configuration.AbstractConfigurationAwareModule; import io.trino.plugin.jdbc.BaseJdbcConfig; import io.trino.plugin.jdbc.ConnectionFactory; +import io.trino.plugin.jdbc.DecimalModule; import io.trino.plugin.jdbc.DriverConnectionFactory; import io.trino.plugin.jdbc.ForBaseJdbc; import io.trino.plugin.jdbc.JdbcClient; @@ -30,16 +30,19 @@ import java.util.Properties; +import static com.google.inject.Scopes.SINGLETON; import static com.google.inject.multibindings.Multibinder.newSetBinder; public class RedshiftClientModule - implements Module + extends AbstractConfigurationAwareModule { @Override - public void configure(Binder binder) + public void setup(Binder binder) { - binder.bind(JdbcClient.class).annotatedWith(ForBaseJdbc.class).to(RedshiftClient.class).in(Scopes.SINGLETON); - newSetBinder(binder, ConnectorTableFunction.class).addBinding().toProvider(Query.class).in(Scopes.SINGLETON); + binder.bind(JdbcClient.class).annotatedWith(ForBaseJdbc.class).to(RedshiftClient.class).in(SINGLETON); + newSetBinder(binder, ConnectorTableFunction.class).addBinding().toProvider(Query.class).in(SINGLETON); + + install(new DecimalModule()); } @Singleton diff --git a/plugin/trino-redshift/src/test/java/io/trino/plugin/redshift/RedshiftQueryRunner.java b/plugin/trino-redshift/src/test/java/io/trino/plugin/redshift/RedshiftQueryRunner.java index 61859dd9b52c..3e96738e7ba1 100644 --- a/plugin/trino-redshift/src/test/java/io/trino/plugin/redshift/RedshiftQueryRunner.java +++ b/plugin/trino-redshift/src/test/java/io/trino/plugin/redshift/RedshiftQueryRunner.java @@ -28,6 +28,7 @@ import io.trino.tpch.TpchTable; import net.jodah.failsafe.Failsafe; import net.jodah.failsafe.RetryPolicy; +import org.jdbi.v3.core.HandleCallback; import org.jdbi.v3.core.HandleConsumer; import org.jdbi.v3.core.Jdbi; @@ -50,8 +51,8 @@ public final class RedshiftQueryRunner { private static final Logger log = Logger.get(RedshiftQueryRunner.class); private static final String JDBC_ENDPOINT = requireSystemProperty("test.redshift.jdbc.endpoint"); - private static final String JDBC_USER = requireSystemProperty("test.redshift.jdbc.user"); - private static final String JDBC_PASSWORD = requireSystemProperty("test.redshift.jdbc.password"); + static final String JDBC_USER = requireSystemProperty("test.redshift.jdbc.user"); + static final String JDBC_PASSWORD = requireSystemProperty("test.redshift.jdbc.password"); private static final String S3_TPCH_TABLES_ROOT = requireSystemProperty("test.redshift.s3.tpch.tables.root"); private static final String IAM_ROLE = requireSystemProperty("test.redshift.iam.role"); @@ -59,7 +60,7 @@ public final class RedshiftQueryRunner private static final String TEST_CATALOG = "redshift"; static final String TEST_SCHEMA = "test_schema"; - private static final String JDBC_URL = "jdbc:redshift://" + JDBC_ENDPOINT + TEST_DATABASE; + static final String JDBC_URL = "jdbc:redshift://" + JDBC_ENDPOINT + TEST_DATABASE; private static final String CONNECTOR_NAME = "redshift"; private static final String TPCH_CATALOG = "tpch"; @@ -164,10 +165,16 @@ public static void executeInRedshift(String sql, Object... parameters) executeInRedshift(handle -> handle.execute(sql, parameters)); } - private static void executeInRedshift(HandleConsumer consumer) + public static void executeInRedshift(HandleConsumer consumer) throws E { - Jdbi.create(JDBC_URL, JDBC_USER, JDBC_PASSWORD).withHandle(consumer.asCallback()); + executeWithRedshift(consumer.asCallback()); + } + + public static T executeWithRedshift(HandleCallback callback) + throws E + { + return Jdbi.create(JDBC_URL, JDBC_USER, JDBC_PASSWORD).withHandle(callback); } private static synchronized void provisionTables(Session session, QueryRunner queryRunner, Iterable> tables) diff --git a/plugin/trino-redshift/src/test/java/io/trino/plugin/redshift/TestRedshiftConnectorTest.java b/plugin/trino-redshift/src/test/java/io/trino/plugin/redshift/TestRedshiftConnectorTest.java index 0bd3f862e91e..0d46c93a6fc9 100644 --- a/plugin/trino-redshift/src/test/java/io/trino/plugin/redshift/TestRedshiftConnectorTest.java +++ b/plugin/trino-redshift/src/test/java/io/trino/plugin/redshift/TestRedshiftConnectorTest.java @@ -21,6 +21,7 @@ import io.trino.testing.sql.TestTable; import io.trino.tpch.TpchTable; import org.testng.SkipException; +import org.testng.annotations.DataProvider; import org.testng.annotations.Test; import java.util.Optional; @@ -28,6 +29,7 @@ import static io.trino.plugin.redshift.RedshiftQueryRunner.TEST_SCHEMA; import static io.trino.plugin.redshift.RedshiftQueryRunner.createRedshiftQueryRunner; +import static io.trino.testing.TestingNames.randomNameSuffix; import static java.lang.String.format; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; @@ -99,9 +101,6 @@ protected Optional filterDataMappingSmokeTestData(DataMapp return Optional.empty(); } } - if ("tinyint".equals(typeName) || typeName.startsWith("time") || "varbinary".equals(typeName)) { - return Optional.empty(); - } return Optional.of(dataMappingTestSetup); } @@ -119,6 +118,39 @@ public void testCreateTableAsSelectWithUnicode() "SELECT 1"); } + @Test(dataProvider = "redshiftTypeToTrinoTypes") + public void testReadFromLateBindingView(String redshiftType, String trinoType) + { + try (TestView view = new TestView(onRemoteDatabase(), TEST_SCHEMA + ".late_schema_binding", "SELECT CAST(NULL AS %s) AS value WITH NO SCHEMA BINDING".formatted(redshiftType))) { + assertThat(query("SELECT value, true FROM %s WHERE value IS NULL".formatted(view.getName()))) + .projected(1) + .containsAll("VALUES (true)"); + + assertThat(query("SHOW COLUMNS FROM %s LIKE 'value'".formatted(view.getName()))) + .projected(1) + .skippingTypesCheck() + .containsAll("VALUES ('%s')".formatted(trinoType)); + } + } + + @DataProvider + public Object[][] redshiftTypeToTrinoTypes() + { + return new Object[][] { + {"SMALLINT", "smallint"}, + {"INTEGER", "integer"}, + {"BIGINT", "bigint"}, + {"DECIMAL", "decimal(18,0)"}, + {"REAL", "real"}, + {"DOUBLE PRECISION", "double"}, + {"BOOLEAN", "boolean"}, + {"CHAR(1)", "char(1)"}, + {"VARCHAR(1)", "varchar(1)"}, + {"TIME", "time(6)"}, + {"TIMESTAMP", "timestamp(6)"}, + {"TIMESTAMPTZ", "timestamp(6) with time zone"}}; + } + @Override @Test public void testReadMetadataWithRelationsConcurrentModifications() @@ -186,4 +218,29 @@ public void testAddNotNullColumnToNonEmptyTable() { throw new SkipException("Redshift ALTER TABLE ADD COLUMN defined as NOT NULL must have a non-null default expression"); } + + private static class TestView + implements AutoCloseable + { + private final String name; + private final SqlExecutor executor; + + public TestView(SqlExecutor executor, String namePrefix, String viewDefinition) + { + this.executor = executor; + this.name = namePrefix + "_" + randomNameSuffix(); + executor.execute("CREATE OR REPLACE VIEW " + name + " AS " + viewDefinition); + } + + @Override + public void close() + { + executor.execute("DROP VIEW " + name); + } + + public String getName() + { + return name; + } + } } diff --git a/plugin/trino-redshift/src/test/java/io/trino/plugin/redshift/TestRedshiftTypeMapping.java b/plugin/trino-redshift/src/test/java/io/trino/plugin/redshift/TestRedshiftTypeMapping.java new file mode 100644 index 000000000000..26938c3b6532 --- /dev/null +++ b/plugin/trino-redshift/src/test/java/io/trino/plugin/redshift/TestRedshiftTypeMapping.java @@ -0,0 +1,994 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.plugin.redshift; + +import com.google.common.base.Utf8; +import com.google.common.collect.ImmutableList; +import io.trino.Session; +import io.trino.testing.AbstractTestQueryFramework; +import io.trino.testing.QueryRunner; +import io.trino.testing.TestingSession; +import io.trino.testing.datatype.CreateAndInsertDataSetup; +import io.trino.testing.datatype.CreateAsSelectDataSetup; +import io.trino.testing.datatype.DataSetup; +import io.trino.testing.datatype.SqlDataTypeTest; +import io.trino.testing.sql.JdbcSqlExecutor; +import io.trino.testing.sql.SqlExecutor; +import io.trino.testing.sql.TestTable; +import io.trino.testing.sql.TrinoSqlExecutor; +import org.testng.annotations.BeforeClass; +import org.testng.annotations.DataProvider; +import org.testng.annotations.Test; + +import java.sql.SQLException; +import java.time.LocalDate; +import java.time.LocalDateTime; +import java.time.LocalTime; +import java.time.ZoneId; +import java.util.Arrays; +import java.util.List; +import java.util.Map; +import java.util.Map.Entry; +import java.util.Properties; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.function.Function; + +import static com.google.common.base.Verify.verify; +import static com.google.common.io.BaseEncoding.base16; +import static io.trino.plugin.redshift.RedshiftClient.REDSHIFT_MAX_VARCHAR; +import static io.trino.plugin.redshift.RedshiftQueryRunner.JDBC_PASSWORD; +import static io.trino.plugin.redshift.RedshiftQueryRunner.JDBC_URL; +import static io.trino.plugin.redshift.RedshiftQueryRunner.JDBC_USER; +import static io.trino.plugin.redshift.RedshiftQueryRunner.TEST_SCHEMA; +import static io.trino.plugin.redshift.RedshiftQueryRunner.createRedshiftQueryRunner; +import static io.trino.plugin.redshift.RedshiftQueryRunner.executeInRedshift; +import static io.trino.spi.type.BigintType.BIGINT; +import static io.trino.spi.type.BooleanType.BOOLEAN; +import static io.trino.spi.type.CharType.createCharType; +import static io.trino.spi.type.DateType.DATE; +import static io.trino.spi.type.DecimalType.createDecimalType; +import static io.trino.spi.type.DoubleType.DOUBLE; +import static io.trino.spi.type.IntegerType.INTEGER; +import static io.trino.spi.type.RealType.REAL; +import static io.trino.spi.type.SmallintType.SMALLINT; +import static io.trino.spi.type.TimeType.createTimeType; +import static io.trino.spi.type.TimeZoneKey.getTimeZoneKey; +import static io.trino.spi.type.TimestampType.createTimestampType; +import static io.trino.spi.type.TimestampWithTimeZoneType.TIMESTAMP_TZ_MICROS; +import static io.trino.spi.type.VarbinaryType.VARBINARY; +import static io.trino.spi.type.VarcharType.createUnboundedVarcharType; +import static io.trino.spi.type.VarcharType.createVarcharType; +import static io.trino.testing.TestingNames.randomNameSuffix; +import static java.lang.String.format; +import static java.nio.charset.StandardCharsets.UTF_8; +import static java.time.ZoneOffset.UTC; +import static java.util.Objects.requireNonNull; +import static java.util.stream.Collectors.groupingBy; +import static java.util.stream.Collectors.joining; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +public class TestRedshiftTypeMapping + extends AbstractTestQueryFramework +{ + private static final ZoneId testZone = TestingSession.DEFAULT_TIME_ZONE_KEY.getZoneId(); + + private final ZoneId jvmZone = ZoneId.systemDefault(); + private final LocalDateTime timeGapInJvmZone = LocalDate.EPOCH.atStartOfDay(); + private final LocalDateTime timeDoubledInJvmZone = LocalDateTime.of(2018, 10, 28, 1, 33, 17, 456_789_000); + + // using two non-JVM zones so that we don't need to worry what the backend's system zone is + + // no DST in 1970, but has DST in later years (e.g. 2018) + private final ZoneId vilnius = ZoneId.of("Europe/Vilnius"); + private final LocalDateTime timeGapInVilnius = LocalDateTime.of(2018, 3, 25, 3, 17, 17); + private final LocalDateTime timeDoubledInVilnius = LocalDateTime.of(2018, 10, 28, 3, 33, 33, 333_333_000); + + // Size of offset changed since 1970-01-01, no DST + private final ZoneId kathmandu = ZoneId.of("Asia/Kathmandu"); + private final LocalDateTime timeGapInKathmandu = LocalDateTime.of(1986, 1, 1, 0, 13, 7); + + private final LocalDate dayOfMidnightGapInJvmZone = LocalDate.EPOCH; + private final LocalDate dayOfMidnightGapInVilnius = LocalDate.of(1983, 4, 1); + private final LocalDate dayAfterMidnightSetBackInVilnius = LocalDate.of(1983, 10, 1); + + @BeforeClass + public void checkRanges() + { + // Timestamps + checkIsGap(jvmZone, timeGapInJvmZone); + checkIsDoubled(jvmZone, timeDoubledInJvmZone); + checkIsGap(vilnius, timeGapInVilnius); + checkIsDoubled(vilnius, timeDoubledInVilnius); + checkIsGap(kathmandu, timeGapInKathmandu); + + // Times + checkIsGap(jvmZone, LocalTime.of(0, 0, 0).atDate(LocalDate.EPOCH)); + + // Dates + checkIsGap(jvmZone, dayOfMidnightGapInJvmZone.atStartOfDay()); + checkIsGap(vilnius, dayOfMidnightGapInVilnius.atStartOfDay()); + checkIsDoubled(vilnius, dayAfterMidnightSetBackInVilnius.atStartOfDay().minusNanos(1)); + } + + @Override + protected QueryRunner createQueryRunner() + throws Exception + { + return createRedshiftQueryRunner(Map.of(), Map.of(), List.of()); + } + + @Test + public void testBasicTypes() + { + // Assume that if these types work at all, they have standard semantics. + SqlDataTypeTest.create() + .addRoundTrip("boolean", "true", BOOLEAN, "true") + .addRoundTrip("boolean", "false", BOOLEAN, "false") + .addRoundTrip("bigint", "123456789012", BIGINT, "123456789012") + .addRoundTrip("integer", "1234567890", INTEGER, "1234567890") + .addRoundTrip("smallint", "32456", SMALLINT, "SMALLINT '32456'") + .addRoundTrip("double", "123.45", DOUBLE, "DOUBLE '123.45'") + .addRoundTrip("real", "123.45", REAL, "REAL '123.45'") + // If we map tinyint to smallint: + .addRoundTrip("tinyint", "5", SMALLINT, "SMALLINT '5'") + .execute(getQueryRunner(), trinoCreateAsSelect("test_basic_types")); + } + + @Test + public void testVarchar() + { + SqlDataTypeTest.create() + .addRoundTrip("varchar(65535)", "'varchar max'", createVarcharType(65535), "CAST('varchar max' AS varchar(65535))") + .addRoundTrip("varchar(40)", "'攻殻機動隊'", createVarcharType(40), "CAST('攻殻機動隊' AS varchar(40))") + .addRoundTrip("varchar(8)", "'隊'", createVarcharType(8), "CAST('隊' AS varchar(8))") + .addRoundTrip("varchar(16)", "'😂'", createVarcharType(16), "CAST('😂' AS varchar(16))") + .addRoundTrip("varchar(88)", "'Ну, погоди!'", createVarcharType(88), "CAST('Ну, погоди!' AS varchar(88))") + .addRoundTrip("varchar(10)", "'text_a'", createVarcharType(10), "CAST('text_a' AS varchar(10))") + .addRoundTrip("varchar(255)", "'text_b'", createVarcharType(255), "CAST('text_b' AS varchar(255))") + .addRoundTrip("varchar(4096)", "'char max'", createVarcharType(4096), "CAST('char max' AS varchar(4096))") + .execute(getQueryRunner(), trinoCreateAsSelect("trino_test_varchar")) + .execute(getQueryRunner(), redshiftCreateAndInsert("jdbc_test_varchar")); + } + + @Test + public void testChar() + { + SqlDataTypeTest.create() + .addRoundTrip("char(10)", "'text_a'", createCharType(10), "CAST('text_a' AS char(10))") + .addRoundTrip("char(255)", "'text_b'", createCharType(255), "CAST('text_b' AS char(255))") + .addRoundTrip("char(4096)", "'char max'", createCharType(4096), "CAST('char max' AS char(4096))") + .execute(getQueryRunner(), trinoCreateAsSelect("trino_test_char")) + .execute(getQueryRunner(), redshiftCreateAndInsert("jdbc_test_char")); + + // Test with types larger than Redshift's char(max) + SqlDataTypeTest.create() + .addRoundTrip("char(65535)", "'varchar max'", createVarcharType(65535), format("CAST('varchar max%s' AS varchar(65535))", " ".repeat(65535 - "varchar max".length()))) + .addRoundTrip("char(4136)", "'攻殻機動隊'", createVarcharType(4136), format("CAST('%s' AS varchar(4136))", padVarchar(4136).apply("攻殻機動隊"))) + .addRoundTrip("char(4104)", "'隊'", createVarcharType(4104), format("CAST('%s' AS varchar(4104))", padVarchar(4104).apply("隊"))) + .addRoundTrip("char(4112)", "'😂'", createVarcharType(4112), format("CAST('%s' AS varchar(4112))", padVarchar(4112).apply("😂"))) + .addRoundTrip("varchar(88)", "'Ну, погоди!'", createVarcharType(88), "CAST('Ну, погоди!' AS varchar(88))") + .addRoundTrip("char(4106)", "'text_a'", createVarcharType(4106), format("CAST('%s' AS varchar(4106))", padVarchar(4106).apply("text_a"))) + .addRoundTrip("char(4351)", "'text_b'", createVarcharType(4351), format("CAST('%s' AS varchar(4351))", padVarchar(4351).apply("text_b"))) + .addRoundTrip("char(8192)", "'char max'", createVarcharType(8192), format("CAST('%s' AS varchar(8192))", padVarchar(8192).apply("char max"))) + .execute(getQueryRunner(), trinoCreateAsSelect("trino_test_large_char")); + } + + /** + * Test handling of data outside Redshift's normal bounds. + * + *

Redshift sometimes returns unbounded {@code VARCHAR} data, apparently + * when it returns directly from a Postgres function. + */ + @Test + public void testPostgresText() + { + try (TestView view1 = new TestView("postgres_text_view", "SELECT lpad('x', 1)"); + TestView view2 = new TestView("pg_catalog_view", "SELECT relname FROM pg_class")) { + // Test data and type from a function + assertThat(query(format("SELECT * FROM %s", view1.name))) + .matches("VALUES CAST('x' AS varchar)"); + + // Test the type of an internal table + assertThat(query(format("SELECT * FROM %s LIMIT 1", view2.name))) + .hasOutputTypes(List.of(createUnboundedVarcharType())); + } + } + + // Make sure that Redshift still maps NCHAR and NVARCHAR to CHAR and VARCHAR. + @Test + public void checkNCharAndNVarchar() + { + SqlDataTypeTest.create() + .addRoundTrip("nvarchar(65535)", "'varchar max'", createVarcharType(65535), "CAST('varchar max' AS varchar(65535))") + .addRoundTrip("nvarchar(40)", "'攻殻機動隊'", createVarcharType(40), "CAST('攻殻機動隊' AS varchar(40))") + .addRoundTrip("nvarchar(8)", "'隊'", createVarcharType(8), "CAST('隊' AS varchar(8))") + .addRoundTrip("nvarchar(16)", "'😂'", createVarcharType(16), "CAST('😂' AS varchar(16))") + .addRoundTrip("nvarchar(88)", "'Ну, погоди!'", createVarcharType(88), "CAST('Ну, погоди!' AS varchar(88))") + .addRoundTrip("nvarchar(10)", "'text_a'", createVarcharType(10), "CAST('text_a' AS varchar(10))") + .addRoundTrip("nvarchar(255)", "'text_b'", createVarcharType(255), "CAST('text_b' AS varchar(255))") + .addRoundTrip("nvarchar(4096)", "'char max'", createVarcharType(4096), "CAST('char max' AS varchar(4096))") + .execute(getQueryRunner(), redshiftCreateAndInsert("jdbc_test_nvarchar")); + + SqlDataTypeTest.create() + .addRoundTrip("nchar(10)", "'text_a'", createCharType(10), "CAST('text_a' AS char(10))") + .addRoundTrip("nchar(255)", "'text_b'", createCharType(255), "CAST('text_b' AS char(255))") + .addRoundTrip("nchar(4096)", "'char max'", createCharType(4096), "CAST('char max' AS char(4096))") + .execute(getQueryRunner(), redshiftCreateAndInsert("jdbc_test_nchar")); + } + + @Test + public void testUnicodeChar() // Redshift doesn't allow multibyte chars in CHAR + { + try (TestTable table = testTable("test_multibyte_char", "(c char(32))")) { + assertQueryFails( + format("INSERT INTO %s VALUES ('\u968A')", table.getName()), + "^Value for Redshift CHAR must be ASCII, but found '\u968A'$"); + } + + assertCreateFails( + "test_multibyte_char_ctas", + "AS SELECT CAST('\u968A' AS char(32)) c", + "^Value for Redshift CHAR must be ASCII, but found '\u968A'$"); + } + + // Make sure Redshift really doesn't allow multibyte characters in CHAR + @Test + public void checkUnicodeCharInRedshift() + { + try (TestTable table = testTable("check_multibyte_char", "(c char(32))")) { + assertThatThrownBy(() -> getRedshiftExecutor() + .execute(format("INSERT INTO %s VALUES ('\u968a')", table.getName()))) + .getCause() + .isInstanceOf(SQLException.class) + .hasMessageContaining("CHAR string contains invalid ASCII character"); + } + } + + @Test + public void testOversizedCharacterTypes() + { + // Test that character types too large for Redshift map to the maximum size + SqlDataTypeTest.create() + .addRoundTrip("varchar", "'unbounded'", createVarcharType(65535), "CAST('unbounded' AS varchar(65535))") + .addRoundTrip(format("varchar(%s)", REDSHIFT_MAX_VARCHAR + 1), "'oversized varchar'", createVarcharType(65535), "CAST('oversized varchar' AS varchar(65535))") + .addRoundTrip(format("char(%s)", REDSHIFT_MAX_VARCHAR + 1), "'oversized char'", createVarcharType(65535), format("CAST('%s' AS varchar(65535))", padVarchar(65535).apply("oversized char"))) + .execute(getQueryRunner(), trinoCreateAsSelect("oversized_character_types")); + } + + @Test + public void testVarbinary() + { + // Redshift's VARBYTE is mapped to Trino VARBINARY. Redshift does not have VARBINARY type. + SqlDataTypeTest.create() + // varbyte + .addRoundTrip("varbyte", "NULL", VARBINARY, "CAST(NULL AS varbinary)") + .addRoundTrip("varbyte", "to_varbyte('', 'hex')", VARBINARY, "X''") + .addRoundTrip("varbyte", utf8VarbyteLiteral("hello"), VARBINARY, "to_utf8('hello')") + .addRoundTrip("varbyte", utf8VarbyteLiteral("Piękna łąka w 東京都"), VARBINARY, "to_utf8('Piękna łąka w 東京都')") + .addRoundTrip("varbyte", utf8VarbyteLiteral("Bag full of 💰"), VARBINARY, "to_utf8('Bag full of 💰')") + .addRoundTrip("varbyte", "to_varbyte('0001020304050607080DF9367AA7000000', 'hex')", VARBINARY, "X'0001020304050607080DF9367AA7000000'") // non-text + .addRoundTrip("varbyte", "to_varbyte('000000000000', 'hex')", VARBINARY, "X'000000000000'") + .addRoundTrip("varbyte(1)", "to_varbyte('00', 'hex')", VARBINARY, "X'00'") // minimum length + .addRoundTrip("varbyte(1024000)", "to_varbyte('00', 'hex')", VARBINARY, "X'00'") // maximum length + // varbinary + .addRoundTrip("varbinary", "NULL", VARBINARY, "CAST(NULL AS varbinary)") + .addRoundTrip("varbinary", utf8VarbyteLiteral("Bag full of 💰"), VARBINARY, "to_utf8('Bag full of 💰')") + .addRoundTrip("varbinary", "to_varbyte('0001020304050607080DF9367AA7000000', 'hex')", VARBINARY, "X'0001020304050607080DF9367AA7000000'") // non-text + .addRoundTrip("varbinary(1)", "to_varbyte('00', 'hex')", VARBINARY, "X'00'") // minimum length + .addRoundTrip("varbinary(1024000)", "to_varbyte('00', 'hex')", VARBINARY, "X'00'") // maximum length + // binary varying + .addRoundTrip("binary varying", "NULL", VARBINARY, "CAST(NULL AS varbinary)") + .addRoundTrip("binary varying", utf8VarbyteLiteral("Bag full of 💰"), VARBINARY, "to_utf8('Bag full of 💰')") + .addRoundTrip("binary varying", "to_varbyte('0001020304050607080DF9367AA7000000', 'hex')", VARBINARY, "X'0001020304050607080DF9367AA7000000'") // non-text + .addRoundTrip("binary varying(1)", "to_varbyte('00', 'hex')", VARBINARY, "X'00'") // minimum length + .addRoundTrip("binary varying(1024000)", "to_varbyte('00', 'hex')", VARBINARY, "X'00'") // maximum length + .execute(getQueryRunner(), redshiftCreateAndInsert("test_varbinary")); + + SqlDataTypeTest.create() + .addRoundTrip("varbinary", "NULL", VARBINARY, "CAST(NULL AS varbinary)") + .addRoundTrip("varbinary", "X''", VARBINARY, "X''") + .addRoundTrip("varbinary", "X'68656C6C6F'", VARBINARY, "to_utf8('hello')") + .addRoundTrip("varbinary", "X'5069C4996B6E6120C582C4856B61207720E69DB1E4BAACE983BD'", VARBINARY, "to_utf8('Piękna łąka w 東京都')") + .addRoundTrip("varbinary", "X'4261672066756C6C206F6620F09F92B0'", VARBINARY, "to_utf8('Bag full of 💰')") + .addRoundTrip("varbinary", "X'0001020304050607080DF9367AA7000000'", VARBINARY, "X'0001020304050607080DF9367AA7000000'") // non-text + .addRoundTrip("varbinary", "X'000000000000'", VARBINARY, "X'000000000000'") + .execute(getQueryRunner(), trinoCreateAsSelect("test_varbinary")); + } + + private static String utf8VarbyteLiteral(String string) + { + return format("to_varbyte('%s', 'hex')", base16().encode(string.getBytes(UTF_8))); + } + + @Test + public void testDecimal() + { + SqlDataTypeTest.create() + .addRoundTrip("decimal(3, 0)", "CAST('193' AS decimal(3, 0))", createDecimalType(3, 0), "CAST('193' AS decimal(3, 0))") + .addRoundTrip("decimal(3, 0)", "CAST('19' AS decimal(3, 0))", createDecimalType(3, 0), "CAST('19' AS decimal(3, 0))") + .addRoundTrip("decimal(3, 0)", "CAST('-193' AS decimal(3, 0))", createDecimalType(3, 0), "CAST('-193' AS decimal(3, 0))") + .addRoundTrip("decimal(3, 1)", "CAST('10.0' AS decimal(3, 1))", createDecimalType(3, 1), "CAST('10.0' AS decimal(3, 1))") + .addRoundTrip("decimal(3, 1)", "CAST('10.1' AS decimal(3, 1))", createDecimalType(3, 1), "CAST('10.1' AS decimal(3, 1))") + .addRoundTrip("decimal(3, 1)", "CAST('-10.1' AS decimal(3, 1))", createDecimalType(3, 1), "CAST('-10.1' AS decimal(3, 1))") + .addRoundTrip("decimal(4, 2)", "CAST('2' AS decimal(4, 2))", createDecimalType(4, 2), "CAST('2' AS decimal(4, 2))") + .addRoundTrip("decimal(4, 2)", "CAST('2.3' AS decimal(4, 2))", createDecimalType(4, 2), "CAST('2.3' AS decimal(4, 2))") + .addRoundTrip("decimal(24, 2)", "CAST('2' AS decimal(24, 2))", createDecimalType(24, 2), "CAST('2' AS decimal(24, 2))") + .addRoundTrip("decimal(24, 2)", "CAST('2.3' AS decimal(24, 2))", createDecimalType(24, 2), "CAST('2.3' AS decimal(24, 2))") + .addRoundTrip("decimal(24, 2)", "CAST('123456789.3' AS decimal(24, 2))", createDecimalType(24, 2), "CAST('123456789.3' AS decimal(24, 2))") + .addRoundTrip("decimal(24, 4)", "CAST('12345678901234567890.31' AS decimal(24, 4))", createDecimalType(24, 4), "CAST('12345678901234567890.31' AS decimal(24, 4))") + .addRoundTrip("decimal(30, 5)", "CAST('3141592653589793238462643.38327' AS decimal(30, 5))", createDecimalType(30, 5), "CAST('3141592653589793238462643.38327' AS decimal(30, 5))") + .addRoundTrip("decimal(30, 5)", "CAST('-3141592653589793238462643.38327' AS decimal(30, 5))", createDecimalType(30, 5), "CAST('-3141592653589793238462643.38327' AS decimal(30, 5))") + .addRoundTrip("decimal(31, 0)", "CAST('2718281828459045235360287471352' AS decimal(31, 0))", createDecimalType(31, 0), "CAST('2718281828459045235360287471352' AS decimal(31, 0))") + .addRoundTrip("decimal(31, 0)", "CAST('-2718281828459045235360287471352' AS decimal(31, 0))", createDecimalType(31, 0), "CAST('-2718281828459045235360287471352' AS decimal(31, 0))") + .addRoundTrip("decimal(3, 0)", "NULL", createDecimalType(3, 0), "CAST(NULL AS decimal(3, 0))") + .addRoundTrip("decimal(31, 0)", "NULL", createDecimalType(31, 0), "CAST(NULL AS decimal(31, 0))") + .execute(getQueryRunner(), redshiftCreateAndInsert("test_decimal")) + .execute(getQueryRunner(), trinoCreateAsSelect("test_decimal")); + } + + @Test + public void testRedshiftDecimalCutoff() + { + String columns = "(d19 decimal(19, 0), d18 decimal(19, 18), d0 decimal(19, 19))"; + try (TestTable table = testTable("test_decimal_range", columns)) { + assertQueryFails( + format("INSERT INTO %s (d19) VALUES (DECIMAL'9991999999999999999')", table.getName()), + "^Value out of range for Redshift DECIMAL\\(19, 0\\)$"); + assertQueryFails( + format("INSERT INTO %s (d18) VALUES (DECIMAL'9.991999999999999999')", table.getName()), + "^Value out of range for Redshift DECIMAL\\(19, 18\\)$"); + assertQueryFails( + format("INSERT INTO %s (d0) VALUES (DECIMAL'.9991999999999999999')", table.getName()), + "^Value out of range for Redshift DECIMAL\\(19, 19\\)$"); + } + } + + @Test + public void testRedshiftDecimalScaleLimit() + { + assertCreateFails( + "test_overlarge_decimal_scale", + "(d DECIMAL(38, 38))", + "^ERROR: DECIMAL scale 38 must be between 0 and 37$"); + } + + @Test + public void testUnsupportedTrinoDataTypes() + { + assertCreateFails( + "test_unsupported_type", + "(col json)", + "Unsupported column type: json"); + } + + @Test(dataProvider = "datetime_test_parameters") + public void testDate(ZoneId sessionZone) + { + Session session = Session.builder(getSession()) + .setTimeZoneKey(getTimeZoneKey(sessionZone.getId())) + .build(); + SqlDataTypeTest.create() + .addRoundTrip("date", "DATE '0001-01-01'", DATE, "DATE '0001-01-01'") // first day of AD + .addRoundTrip("date", "DATE '1500-01-01'", DATE, "DATE '1500-01-01'") // sometime before julian->gregorian switch + .addRoundTrip("date", "DATE '1600-01-01'", DATE, "DATE '1600-01-01'") // long ago but after julian->gregorian switch + .addRoundTrip("date", "DATE '1952-04-03'", DATE, "DATE '1952-04-03'") // before epoch + .addRoundTrip("date", "DATE '1970-01-01'", DATE, "DATE '1970-01-01'") + .addRoundTrip("date", "DATE '1970-02-03'", DATE, "DATE '1970-02-03'") // after epoch + .addRoundTrip("date", "DATE '2017-07-01'", DATE, "DATE '2017-07-01'") // summer in northern hemisphere (possible DST) + .addRoundTrip("date", "DATE '2017-01-01'", DATE, "DATE '2017-01-01'") // winter in northern hemisphere (possible DST in southern hemisphere) + .addRoundTrip("date", "DATE '1970-01-01'", DATE, "DATE '1970-01-01'") // day of midnight gap in JVM + .addRoundTrip("date", "DATE '1983-04-01'", DATE, "DATE '1983-04-01'") // day of midnight gap in Vilnius + .addRoundTrip("date", "DATE '1983-10-01'", DATE, "DATE '1983-10-01'") // day after midnight setback in Vilnius + .execute(getQueryRunner(), session, trinoCreateAsSelect(session, "test_date")) + .execute(getQueryRunner(), session, redshiftCreateAndInsert("test_date")); + + // some time BC + SqlDataTypeTest.create() + .addRoundTrip("date", "DATE '-0100-01-01'", DATE, "DATE '-0100-01-01'") + .execute(getQueryRunner(), session, trinoCreateAsSelect(session, "test_date")); + SqlDataTypeTest.create() + .addRoundTrip("date", "DATE '0101-01-01 BC'", DATE, "DATE '-0100-01-01'") + .execute(getQueryRunner(), session, redshiftCreateAndInsert("test_date")); + } + + @Test(dataProvider = "datetime_test_parameters") + public void testTime(ZoneId sessionZone) + { + // Redshift gets bizarre errors if you try to insert after + // specifying precision for a time column. + Session session = Session.builder(getSession()) + .setTimeZoneKey(getTimeZoneKey(sessionZone.getId())) + .build(); + timeTypeTests("time(6)") + .execute(getQueryRunner(), session, trinoCreateAsSelect(session, "time_from_trino")); + timeTypeTests("time") + .execute(getQueryRunner(), session, redshiftCreateAndInsert("time_from_jdbc")); + } + + private static SqlDataTypeTest timeTypeTests(String inputType) + { + return SqlDataTypeTest.create() + .addRoundTrip(inputType, "TIME '00:00:00.000000'", createTimeType(6), "TIME '00:00:00.000000'") // gap in JVM zone on Epoch day + .addRoundTrip(inputType, "TIME '00:13:42.000000'", createTimeType(6), "TIME '00:13:42.000000'") // gap in JVM zone on Epoch day + .addRoundTrip(inputType, "TIME '01:33:17.000000'", createTimeType(6), "TIME '01:33:17.000000'") + .addRoundTrip(inputType, "TIME '03:17:17.000000'", createTimeType(6), "TIME '03:17:17.000000'") + .addRoundTrip(inputType, "TIME '10:01:17.100000'", createTimeType(6), "TIME '10:01:17.100000'") + .addRoundTrip(inputType, "TIME '13:18:03.000000'", createTimeType(6), "TIME '13:18:03.000000'") + .addRoundTrip(inputType, "TIME '14:18:03.000000'", createTimeType(6), "TIME '14:18:03.000000'") + .addRoundTrip(inputType, "TIME '15:18:03.000000'", createTimeType(6), "TIME '15:18:03.000000'") + .addRoundTrip(inputType, "TIME '16:18:03.123456'", createTimeType(6), "TIME '16:18:03.123456'") + .addRoundTrip(inputType, "TIME '19:01:17.000000'", createTimeType(6), "TIME '19:01:17.000000'") + .addRoundTrip(inputType, "TIME '20:01:17.000000'", createTimeType(6), "TIME '20:01:17.000000'") + .addRoundTrip(inputType, "TIME '21:01:17.000001'", createTimeType(6), "TIME '21:01:17.000001'") + .addRoundTrip(inputType, "TIME '22:59:59.000000'", createTimeType(6), "TIME '22:59:59.000000'") + .addRoundTrip(inputType, "TIME '23:59:59.000000'", createTimeType(6), "TIME '23:59:59.000000'") + .addRoundTrip(inputType, "TIME '23:59:59.999999'", createTimeType(6), "TIME '23:59:59.999999'"); + } + + @Test(dataProvider = "datetime_test_parameters") + public void testTimestamp(ZoneId sessionZone) + { + Session session = Session.builder(getSession()) + .setTimeZoneKey(getTimeZoneKey(sessionZone.getId())) + .build(); + // Redshift doesn't allow timestamp precision to be specified + timestampTypeTests("timestamp(6)") + .execute(getQueryRunner(), session, trinoCreateAsSelect(session, "timestamp_from_trino")); + timestampTypeTests("timestamp") + .execute(getQueryRunner(), session, redshiftCreateAndInsert("timestamp_from_jdbc")); + + // some time BC + SqlDataTypeTest.create() + .addRoundTrip("timestamp(6)", "TIMESTAMP '-0100-01-01 00:00:00.000000'", createTimestampType(6), "TIMESTAMP '-0100-01-01 00:00:00.000000'") + .execute(getQueryRunner(), session, trinoCreateAsSelect(session, "test_date")); + SqlDataTypeTest.create() + .addRoundTrip("timestamp", "TIMESTAMP '0101-01-01 00:00:00 BC'", createTimestampType(6), "TIMESTAMP '-0100-01-01 00:00:00.000000'") + .execute(getQueryRunner(), session, redshiftCreateAndInsert("test_date")); + } + + private static SqlDataTypeTest timestampTypeTests(String inputType) + { + return SqlDataTypeTest.create() + .addRoundTrip(inputType, "TIMESTAMP '0001-01-01 00:00:00.000000'", createTimestampType(6), "TIMESTAMP '0001-01-01 00:00:00.000000'") // first day of AD + .addRoundTrip(inputType, "TIMESTAMP '1500-01-01 00:00:00.000000'", createTimestampType(6), "TIMESTAMP '1500-01-01 00:00:00.000000'") // sometime before julian->gregorian switch + .addRoundTrip(inputType, "TIMESTAMP '1600-01-01 00:00:00.000000'", createTimestampType(6), "TIMESTAMP '1600-01-01 00:00:00.000000'") // long ago but after julian->gregorian switch + .addRoundTrip(inputType, "TIMESTAMP '1958-01-01 13:18:03.123456'", createTimestampType(6), "TIMESTAMP '1958-01-01 13:18:03.123456'") // before epoch + .addRoundTrip(inputType, "TIMESTAMP '2019-03-18 10:09:17.987654'", createTimestampType(6), "TIMESTAMP '2019-03-18 10:09:17.987654'") // after epoch + .addRoundTrip(inputType, "TIMESTAMP '2018-10-28 01:33:17.456789'", createTimestampType(6), "TIMESTAMP '2018-10-28 01:33:17.456789'") // time doubled in JVM + .addRoundTrip(inputType, "TIMESTAMP '2018-10-28 03:33:33.333333'", createTimestampType(6), "TIMESTAMP '2018-10-28 03:33:33.333333'") // time doubled in Vilnius + .addRoundTrip(inputType, "TIMESTAMP '1970-01-01 00:00:00.000000'", createTimestampType(6), "TIMESTAMP '1970-01-01 00:00:00.000000'") // time gap in JVM + .addRoundTrip(inputType, "TIMESTAMP '2018-03-25 03:17:17.000000'", createTimestampType(6), "TIMESTAMP '2018-03-25 03:17:17.000000'") // time gap in Vilnius + .addRoundTrip(inputType, "TIMESTAMP '1986-01-01 00:13:07.000000'", createTimestampType(6), "TIMESTAMP '1986-01-01 00:13:07.000000'") // time gap in Kathmandu + // Full time precision + .addRoundTrip(inputType, "TIMESTAMP '1969-12-31 23:59:59.999999'", createTimestampType(6), "TIMESTAMP '1969-12-31 23:59:59.999999'") + .addRoundTrip(inputType, "TIMESTAMP '1970-01-01 00:00:00.999999'", createTimestampType(6), "TIMESTAMP '1970-01-01 00:00:00.999999'"); + } + + @Test(dataProvider = "datetime_test_parameters") + public void testTimestampWithTimeZone(ZoneId sessionZone) + { + Session session = Session.builder(getSession()) + .setTimeZoneKey(getTimeZoneKey(sessionZone.getId())) + .build(); + + SqlDataTypeTest.create() + // test arbitrary time for all supported precisions + .addRoundTrip("timestamp(0) with time zone", "TIMESTAMP '2022-09-27 12:34:56 UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '2022-09-27 12:34:56.000000 UTC'") + .addRoundTrip("timestamp(1) with time zone", "TIMESTAMP '2022-09-27 12:34:56.1 UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '2022-09-27 12:34:56.100000 UTC'") + .addRoundTrip("timestamp(2) with time zone", "TIMESTAMP '2022-09-27 12:34:56.12 UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '2022-09-27 12:34:56.120000 UTC'") + .addRoundTrip("timestamp(3) with time zone", "TIMESTAMP '2022-09-27 12:34:56.123 UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '2022-09-27 12:34:56.123000 UTC'") + .addRoundTrip("timestamp(4) with time zone", "TIMESTAMP '2022-09-27 12:34:56.1234 UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '2022-09-27 12:34:56.123400 UTC'") + .addRoundTrip("timestamp(5) with time zone", "TIMESTAMP '2022-09-27 12:34:56.12345 UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '2022-09-27 12:34:56.123450 UTC'") + .addRoundTrip("timestamp(6) with time zone", "TIMESTAMP '2022-09-27 12:34:56.123456 UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '2022-09-27 12:34:56.123456 UTC'") + + // short timestamp with time zone + // .addRoundTrip("timestamp(3) with time zone", "TIMESTAMP '-4712-01-01 00:00:00 UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '-4712-01-01 00:00:00.000000 UTC'") // min value in Redshift + // .addRoundTrip("timestamp(3) with time zone", "TIMESTAMP '0001-01-01 00:00:00 UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '0001-01-01 00:00:00.000000 UTC'") // first day of AD + .addRoundTrip("timestamp(3) with time zone", "TIMESTAMP '1582-10-04 23:59:59.999 UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '1582-10-04 23:59:59.999000 UTC'") + .addRoundTrip("timestamp(3) with time zone", "TIMESTAMP '1582-10-05 00:00:00.000 UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '1582-10-05 00:00:00.000000 UTC'") // begin julian->gregorian switch + .addRoundTrip("timestamp(3) with time zone", "TIMESTAMP '1582-10-14 23:59:59.999 UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '1582-10-14 23:59:59.999000 UTC'") // end julian->gregorian switch + .addRoundTrip("timestamp(3) with time zone", "TIMESTAMP '1582-10-15 00:00:00.000 UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '1582-10-15 00:00:00.000000 UTC'") + .addRoundTrip("timestamp(3) with time zone", "TIMESTAMP '1970-01-01 00:00:00 UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '1970-01-01 00:00:00.000000 UTC'") + .addRoundTrip("timestamp(3) with time zone", "TIMESTAMP '1970-01-01 00:00:00.1 UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '1970-01-01 00:00:00.100000 UTC'") + .addRoundTrip("timestamp(3) with time zone", "TIMESTAMP '1970-01-01 00:00:00.9 UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '1970-01-01 00:00:00.900000 UTC'") + .addRoundTrip("timestamp(3) with time zone", "TIMESTAMP '1970-01-01 00:00:00.123 UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '1970-01-01 00:00:00.123000 UTC'") + .addRoundTrip("timestamp(3) with time zone", "TIMESTAMP '1970-01-01 00:00:00.999 UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '1970-01-01 00:00:00.999000 UTC'") + .addRoundTrip("timestamp(3) with time zone", "TIMESTAMP '1970-01-01 00:00:00.123 UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '1970-01-01 00:00:00.123000 UTC'") + .addRoundTrip("timestamp(3) with time zone", "TIMESTAMP '1986-01-01 00:13:07 UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '1986-01-01 00:13:07.000000 UTC'") // time gap in Kathmandu + .addRoundTrip("timestamp(3) with time zone", "TIMESTAMP '2018-10-28 01:33:17.456 UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '2018-10-28 01:33:17.456000 UTC'") // time doubled in JVM + .addRoundTrip("timestamp(3) with time zone", "TIMESTAMP '2018-10-28 03:33:33.333 UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '2018-10-28 03:33:33.333000 UTC'") // time doubled in Vilnius + .addRoundTrip("timestamp(3) with time zone", "TIMESTAMP '2018-03-25 03:17:17.000 UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '2018-03-25 03:17:17.000000 UTC'") // time gap in Vilnius + .addRoundTrip("timestamp(3) with time zone", "TIMESTAMP '2020-09-27 12:34:56.1 UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '2020-09-27 12:34:56.100000 UTC'") + .addRoundTrip("timestamp(3) with time zone", "TIMESTAMP '2020-09-27 12:34:56.9 UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '2020-09-27 12:34:56.900000 UTC'") + .addRoundTrip("timestamp(3) with time zone", "TIMESTAMP '2020-09-27 12:34:56.123 UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '2020-09-27 12:34:56.123000 UTC'") + .addRoundTrip("timestamp(3) with time zone", "TIMESTAMP '2020-09-27 12:34:56.999 UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '2020-09-27 12:34:56.999000 UTC'") + .addRoundTrip("timestamp(3) with time zone", "TIMESTAMP '2020-09-27 12:34:56.123 UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '2020-09-27 12:34:56.123000 UTC'") + .addRoundTrip("timestamp(3) with time zone", "TIMESTAMP '73326-09-11 20:14:45.247 UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '73326-09-11 20:14:45.247000 UTC'") // max value in Trino + .addRoundTrip("timestamp(3) with time zone", "NULL", TIMESTAMP_TZ_MICROS, "CAST(NULL AS timestamp(6) with time zone)") + + // long timestamp with time zone + // .addRoundTrip("timestamp(6) with time zone", "TIMESTAMP '0001-01-01 00:00:00 UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '0001-01-01 00:00:00.000000 UTC'") // first day of AD + .addRoundTrip("timestamp(6) with time zone", "TIMESTAMP '1582-10-04 23:59:59.999999 UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '1582-10-04 23:59:59.999999 UTC'") + .addRoundTrip("timestamp(6) with time zone", "TIMESTAMP '1582-10-05 00:00:00.000000 UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '1582-10-05 00:00:00.000000 UTC'") // begin julian->gregorian switch + .addRoundTrip("timestamp(6) with time zone", "TIMESTAMP '1582-10-14 23:59:59.999999 UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '1582-10-14 23:59:59.999999 UTC'") // end julian->gregorian switch + .addRoundTrip("timestamp(6) with time zone", "TIMESTAMP '1582-10-15 00:00:00.000000 UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '1582-10-15 00:00:00.000000 UTC'") + .addRoundTrip("timestamp(6) with time zone", "TIMESTAMP '1970-01-01 00:00:00 UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '1970-01-01 00:00:00.000000 UTC'") + .addRoundTrip("timestamp(6) with time zone", "TIMESTAMP '1970-01-01 00:00:00.1 UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '1970-01-01 00:00:00.100000 UTC'") + .addRoundTrip("timestamp(6) with time zone", "TIMESTAMP '1970-01-01 00:00:00.9 UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '1970-01-01 00:00:00.900000 UTC'") + .addRoundTrip("timestamp(6) with time zone", "TIMESTAMP '1970-01-01 00:00:00.123 UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '1970-01-01 00:00:00.123000 UTC'") + .addRoundTrip("timestamp(6) with time zone", "TIMESTAMP '1970-01-01 00:00:00.123000 UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '1970-01-01 00:00:00.123000 UTC'") + .addRoundTrip("timestamp(6) with time zone", "TIMESTAMP '1970-01-01 00:00:00.999 UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '1970-01-01 00:00:00.999000 UTC'") + .addRoundTrip("timestamp(6) with time zone", "TIMESTAMP '1970-01-01 00:00:00.123456 UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '1970-01-01 00:00:00.123456 UTC'") + .addRoundTrip("timestamp(6) with time zone", "TIMESTAMP '1986-01-01 00:13:07.000000 UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '1986-01-01 00:13:07.000000 UTC'") // time gap in Kathmandu (long timestamp_tz) + .addRoundTrip("timestamp(6) with time zone", "TIMESTAMP '2018-10-28 01:33:17.456789 UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '2018-10-28 01:33:17.456789 UTC'") // time doubled in JVM + .addRoundTrip("timestamp(6) with time zone", "TIMESTAMP '2018-10-28 03:33:33.333333 UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '2018-10-28 03:33:33.333333 UTC'") // time doubled in Vilnius + .addRoundTrip("timestamp(6) with time zone", "TIMESTAMP '2018-03-25 03:17:17.000000 UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '2018-03-25 03:17:17.000000 UTC'") // time gap in Vilnius + .addRoundTrip("timestamp(6) with time zone", "TIMESTAMP '2020-09-27 12:34:56.1 UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '2020-09-27 12:34:56.100000 UTC'") + .addRoundTrip("timestamp(6) with time zone", "TIMESTAMP '2020-09-27 12:34:56.9 UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '2020-09-27 12:34:56.900000 UTC'") + .addRoundTrip("timestamp(6) with time zone", "TIMESTAMP '2020-09-27 12:34:56.123 UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '2020-09-27 12:34:56.123000 UTC'") + .addRoundTrip("timestamp(6) with time zone", "TIMESTAMP '2020-09-27 12:34:56.123000 UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '2020-09-27 12:34:56.123000 UTC'") + .addRoundTrip("timestamp(6) with time zone", "TIMESTAMP '2020-09-27 12:34:56.999 UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '2020-09-27 12:34:56.999000 UTC'") + .addRoundTrip("timestamp(6) with time zone", "TIMESTAMP '2020-09-27 12:34:56.123456 UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '2020-09-27 12:34:56.123456 UTC'") + .addRoundTrip("timestamp(6) with time zone", "TIMESTAMP '73326-09-11 20:14:45.247999 UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '73326-09-11 20:14:45.247999 UTC'") // max value in Trino + .addRoundTrip("timestamp(6) with time zone", "NULL", TIMESTAMP_TZ_MICROS, "CAST(NULL AS timestamp(6) with time zone)") + .execute(getQueryRunner(), session, trinoCreateAsSelect(getSession(), "test_timestamp_tz")); + + redshiftTimestampWithTimeZoneTests("timestamptz") + .execute(getQueryRunner(), session, redshiftCreateAndInsert("test_timestamp_tz")); + redshiftTimestampWithTimeZoneTests("timestamp with time zone") + .execute(getQueryRunner(), session, redshiftCreateAndInsert("test_timestamp_tz")); + } + + private static SqlDataTypeTest redshiftTimestampWithTimeZoneTests(String inputType) + { + return SqlDataTypeTest.create() + // .addRoundTrip(inputType, "TIMESTAMP '4713-01-01 00:00:00 BC' AT TIME ZONE 'UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '-4712-01-01 00:00:00.000000 UTC'") // min value in Redshift + // .addRoundTrip(inputType, "TIMESTAMP '0001-01-01 00:00:00' AT TIME ZONE 'UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '0001-01-01 00:00:00.000000 UTC'") // first day of AD + .addRoundTrip(inputType, "TIMESTAMP '1582-10-04 23:59:59.999999' AT TIME ZONE 'UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '1582-10-04 23:59:59.999999 UTC'") + .addRoundTrip(inputType, "TIMESTAMP '1582-10-05 00:00:00.000000' AT TIME ZONE 'UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '1582-10-05 00:00:00.000000 UTC'") // begin julian->gregorian switch + .addRoundTrip(inputType, "TIMESTAMP '1582-10-14 23:59:59.999999' AT TIME ZONE 'UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '1582-10-14 23:59:59.999999 UTC'") // end julian->gregorian switch + .addRoundTrip(inputType, "TIMESTAMP '1582-10-15 00:00:00.000000' AT TIME ZONE 'UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '1582-10-15 00:00:00.000000 UTC'") + .addRoundTrip(inputType, "TIMESTAMP '1970-01-01 00:00:00' AT TIME ZONE 'UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '1970-01-01 00:00:00.000000 UTC'") + .addRoundTrip(inputType, "TIMESTAMP '1970-01-01 00:00:00.1' AT TIME ZONE 'UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '1970-01-01 00:00:00.100000 UTC'") + .addRoundTrip(inputType, "TIMESTAMP '1970-01-01 00:00:00.9' AT TIME ZONE 'UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '1970-01-01 00:00:00.900000 UTC'") + .addRoundTrip(inputType, "TIMESTAMP '1970-01-01 00:00:00.123' AT TIME ZONE 'UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '1970-01-01 00:00:00.123000 UTC'") + .addRoundTrip(inputType, "TIMESTAMP '1970-01-01 00:00:00.123000' AT TIME ZONE 'UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '1970-01-01 00:00:00.123000 UTC'") + .addRoundTrip(inputType, "TIMESTAMP '1970-01-01 00:00:00.999' AT TIME ZONE 'UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '1970-01-01 00:00:00.999000 UTC'") + .addRoundTrip(inputType, "TIMESTAMP '1970-01-01 00:00:00.123456' AT TIME ZONE 'UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '1970-01-01 00:00:00.123456 UTC'") + .addRoundTrip(inputType, "TIMESTAMP '1986-01-01 00:13:07.000000' AT TIME ZONE 'UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '1986-01-01 00:13:07.000000 UTC'") // time gap in Kathmandu + .addRoundTrip(inputType, "TIMESTAMP '2018-10-28 01:33:17.456789' AT TIME ZONE 'UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '2018-10-28 01:33:17.456789 UTC'") // time doubled in JVM + .addRoundTrip(inputType, "TIMESTAMP '2018-10-28 03:33:33.333333' AT TIME ZONE 'UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '2018-10-28 03:33:33.333333 UTC'") // time doubled in Vilnius + .addRoundTrip(inputType, "TIMESTAMP '2018-03-25 03:17:17.000000' AT TIME ZONE 'UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '2018-03-25 03:17:17.000000 UTC'") // time gap in Vilnius + .addRoundTrip(inputType, "TIMESTAMP '2020-09-27 12:34:56.1' AT TIME ZONE 'UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '2020-09-27 12:34:56.100000 UTC'") + .addRoundTrip(inputType, "TIMESTAMP '2020-09-27 12:34:56.9' AT TIME ZONE 'UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '2020-09-27 12:34:56.900000 UTC'") + .addRoundTrip(inputType, "TIMESTAMP '2020-09-27 12:34:56.123' AT TIME ZONE 'UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '2020-09-27 12:34:56.123000 UTC'") + .addRoundTrip(inputType, "TIMESTAMP '2020-09-27 12:34:56.123000' AT TIME ZONE 'UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '2020-09-27 12:34:56.123000 UTC'") + .addRoundTrip(inputType, "TIMESTAMP '2020-09-27 12:34:56.999' AT TIME ZONE 'UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '2020-09-27 12:34:56.999000 UTC'") + .addRoundTrip(inputType, "TIMESTAMP '2020-09-27 12:34:56.123456' AT TIME ZONE 'UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '2020-09-27 12:34:56.123456 UTC'") + .addRoundTrip(inputType, "TIMESTAMP '73326-09-11 20:14:45.247999' AT TIME ZONE 'UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '73326-09-11 20:14:45.247999 UTC'"); // max value in Trino + } + + @Test + public void testTimestampWithTimeZoneCoercion() + { + SqlDataTypeTest.create() + // short timestamp with time zone + .addRoundTrip("timestamp(3) with time zone", "TIMESTAMP '1970-01-01 00:00:00.12341 UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '1970-01-01 00:00:00.123000 UTC'") // round down + .addRoundTrip("timestamp(3) with time zone", "TIMESTAMP '1970-01-01 00:00:00.123499 UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '1970-01-01 00:00:00.123000 UTC'") // round up, end result rounds down + .addRoundTrip("timestamp(3) with time zone", "TIMESTAMP '1970-01-01 00:00:00.1235 UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '1970-01-01 00:00:00.124000 UTC'") // round up + .addRoundTrip("timestamp(3) with time zone", "TIMESTAMP '1970-01-01 00:00:00.111222333444 UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '1970-01-01 00:00:00.111000 UTC'") // max precision + .addRoundTrip("timestamp(3) with time zone", "TIMESTAMP '1970-01-01 00:00:00.9999995 UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '1970-01-01 00:00:01.000000 UTC'") // round up to next second + .addRoundTrip("timestamp(3) with time zone", "TIMESTAMP '1970-01-01 23:59:59.9999995 UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '1970-01-02 00:00:00.000000 UTC'") // round up to next day + .addRoundTrip("timestamp(3) with time zone", "TIMESTAMP '1969-12-31 23:59:59.9999995 UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '1970-01-01 00:00:00.000000 UTC'") // negative epoch + .addRoundTrip("timestamp(3) with time zone", "TIMESTAMP '1969-12-31 23:59:59.999499999999 UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '1969-12-31 23:59:59.999000 UTC'") // negative epoch + .addRoundTrip("timestamp(3) with time zone", "TIMESTAMP '1969-12-31 23:59:59.9994 UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '1969-12-31 23:59:59.999000 UTC'") // negative epoch + + // long timestamp with time zone + .addRoundTrip("timestamp(6) with time zone", "TIMESTAMP '1970-01-01 00:00:00.1234561 UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '1970-01-01 00:00:00.123456 UTC'") // round down + .addRoundTrip("timestamp(6) with time zone", "TIMESTAMP '1970-01-01 00:00:00.123456499 UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '1970-01-01 00:00:00.123456 UTC'") // nanoc round up, end result rounds down + .addRoundTrip("timestamp(6) with time zone", "TIMESTAMP '1970-01-01 00:00:00.1234565 UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '1970-01-01 00:00:00.123457 UTC'") // round up + .addRoundTrip("timestamp(6) with time zone", "TIMESTAMP '1970-01-01 00:00:00.111222333444 UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '1970-01-01 00:00:00.111222 UTC'") // max precision + .addRoundTrip("timestamp(6) with time zone", "TIMESTAMP '1970-01-01 00:00:00.9999995 UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '1970-01-01 00:00:01.000000 UTC'") // round up to next second + .addRoundTrip("timestamp(6) with time zone", "TIMESTAMP '1970-01-01 23:59:59.9999995 UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '1970-01-02 00:00:00.000000 UTC'") // round up to next day + .addRoundTrip("timestamp(6) with time zone", "TIMESTAMP '1969-12-31 23:59:59.9999995 UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '1970-01-01 00:00:00.000000 UTC'") // negative epoch + .addRoundTrip("timestamp(6) with time zone", "TIMESTAMP '1969-12-31 23:59:59.999999499999 UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '1969-12-31 23:59:59.999999 UTC'") // negative epoch + .addRoundTrip("timestamp(6) with time zone", "TIMESTAMP '1969-12-31 23:59:59.9999994 UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '1969-12-31 23:59:59.999999 UTC'") // negative epoch + .execute(getQueryRunner(), trinoCreateAsSelect(getSession(), "test_timestamp_tz")); + } + + @Test + public void testTimestampWithTimeZoneOverflow() + { + // The min timestamp with time zone value in Trino is smaller than Redshift + try (TestTable table = new TestTable(getTrinoExecutor(), "timestamp_tz_min", "(ts timestamp(3) with time zone)")) { + assertQueryFails( + format("INSERT INTO %s VALUES (TIMESTAMP '-69387-04-22 03:45:14.752 UTC')", table.getName()), + "\\QMinimum timestamp with time zone in Redshift is -4712-01-01 00:00:00.000000: -69387-04-22 03:45:14.752000"); + } + try (TestTable table = new TestTable(getTrinoExecutor(), "timestamp_tz_min", "(ts timestamp(6) with time zone)")) { + assertQueryFails( + format("INSERT INTO %s VALUES (TIMESTAMP '-69387-04-22 03:45:14.752000 UTC')", table.getName()), + "\\QMinimum timestamp with time zone in Redshift is -4712-01-01 00:00:00.000000: -69387-04-22 03:45:14.752000"); + } + + // The max timestamp with time zone value in Redshift is larger than Trino + try (TestTable table = new TestTable(getRedshiftExecutor(), TEST_SCHEMA + ".timestamp_tz_max", "(ts timestamptz)", ImmutableList.of("TIMESTAMP '294276-12-31 23:59:59' AT TIME ZONE 'UTC'"))) { + assertThatThrownBy(() -> query("SELECT * FROM " + table.getName())) + .hasMessage("Millis overflow: 9224318015999000"); + } + } + + @DataProvider(name = "datetime_test_parameters") + public Object[][] dataProviderForDatetimeTests() + { + return new Object[][] { + {UTC}, + {jvmZone}, + {vilnius}, + {kathmandu}, + {testZone}, + }; + } + + @Test + public void testUnsupportedDateTimeTypes() + { + assertCreateFails( + "test_time_with_time_zone", + "(value TIME WITH TIME ZONE)", + "Unsupported column type: (?i)time.* with time zone"); + } + + @Test + public void testDateLimits() + { + // We can't test the exact date limits because Redshift doesn't say + // what they are, so we test one date on either side. + try (TestTable table = testTable("test_date_limits", "(d date)")) { + // First day of smallest year that Redshift supports (based on its documentation) + assertUpdate(format("INSERT INTO %s VALUES (DATE '-4712-01-01')", table.getName()), 1); + // Small date observed to not work + assertThatThrownBy(() -> computeActual(format("INSERT INTO %s VALUES (DATE '-4713-06-01')", table.getName()))) + .hasStackTraceContaining("ERROR: date out of range: \"4714-06-01 BC\""); + + // Last day of the largest year that Redshift supports (based on in its documentation) + assertUpdate(format("INSERT INTO %s VALUES (DATE '294275-12-31')", table.getName()), 1); + // Large date observed to not work + assertThatThrownBy(() -> computeActual(format("INSERT INTO %s VALUES (DATE '5875000-01-01')", table.getName()))) + .hasStackTraceContaining("ERROR: date out of range: \"5875000-01-01 AD\""); + } + } + + @Test + public void testLimitedTimePrecision() + { + Map> testCasesByPrecision = groupTestCasesByInput( + "TIME '\\d{2}:\\d{2}:\\d{2}(\\.\\d{1,12})?'", + input -> input.length() - "TIME '00:00:00'".length() - (input.contains(".") ? 1 : 0), + List.of( + // No rounding + new TestCase("TIME '00:00:00'", "TIME '00:00:00'"), + new TestCase("TIME '00:00:00.000000'", "TIME '00:00:00.000000'"), + new TestCase("TIME '00:00:00.123456'", "TIME '00:00:00.123456'"), + new TestCase("TIME '12:34:56'", "TIME '12:34:56'"), + new TestCase("TIME '12:34:56.123456'", "TIME '12:34:56.123456'"), + new TestCase("TIME '23:59:59'", "TIME '23:59:59'"), + new TestCase("TIME '23:59:59.9'", "TIME '23:59:59.9'"), + new TestCase("TIME '23:59:59.999'", "TIME '23:59:59.999'"), + new TestCase("TIME '23:59:59.999999'", "TIME '23:59:59.999999'"), + // round down + new TestCase("TIME '00:00:00.0000001'", "TIME '00:00:00.000000'"), + new TestCase("TIME '00:00:00.000000000001'", "TIME '00:00:00.000000'"), + new TestCase("TIME '12:34:56.1234561'", "TIME '12:34:56.123456'"), + // round down, maximal value + new TestCase("TIME '00:00:00.0000004'", "TIME '00:00:00.000000'"), + new TestCase("TIME '00:00:00.000000449'", "TIME '00:00:00.000000'"), + new TestCase("TIME '00:00:00.000000444449'", "TIME '00:00:00.000000'"), + // round up, minimal value + new TestCase("TIME '00:00:00.0000005'", "TIME '00:00:00.000001'"), + new TestCase("TIME '00:00:00.000000500'", "TIME '00:00:00.000001'"), + new TestCase("TIME '00:00:00.000000500000'", "TIME '00:00:00.000001'"), + // round up, maximal value + new TestCase("TIME '00:00:00.0000009'", "TIME '00:00:00.000001'"), + new TestCase("TIME '00:00:00.000000999'", "TIME '00:00:00.000001'"), + new TestCase("TIME '00:00:00.000000999999'", "TIME '00:00:00.000001'"), + // round up to next day, minimal value + new TestCase("TIME '23:59:59.9999995'", "TIME '00:00:00.000000'"), + new TestCase("TIME '23:59:59.999999500'", "TIME '00:00:00.000000'"), + new TestCase("TIME '23:59:59.999999500000'", "TIME '00:00:00.000000'"), + // round up to next day, maximal value + new TestCase("TIME '23:59:59.9999999'", "TIME '00:00:00.000000'"), + new TestCase("TIME '23:59:59.999999999'", "TIME '00:00:00.000000'"), + new TestCase("TIME '23:59:59.999999999999'", "TIME '00:00:00.000000'"), + // don't round to next day (round down near upper bound) + new TestCase("TIME '23:59:59.9999994'", "TIME '23:59:59.999999'"), + new TestCase("TIME '23:59:59.999999499'", "TIME '23:59:59.999999'"), + new TestCase("TIME '23:59:59.999999499999'", "TIME '23:59:59.999999'"))); + + for (Entry> entry : testCasesByPrecision.entrySet()) { + String tableName = format("test_time_precision_%d_%s", entry.getKey(), randomNameSuffix()); + runTestCases(tableName, entry.getValue()); + } + } + + @Test + public void testLimitedTimestampPrecision() + { + Map> testCasesByPrecision = groupTestCasesByInput( + "TIMESTAMP '\\d{4}-\\d{2}-\\d{2} \\d{2}:\\d{2}:\\d{2}(\\.\\d{1,12})?'", + input -> input.length() - "TIMESTAMP '0000-00-00 00:00:00'".length() - (input.contains(".") ? 1 : 0), + // No rounding + new TestCase("TIMESTAMP '1970-01-01 00:00:00'", "TIMESTAMP '1970-01-01 00:00:00'"), + new TestCase("TIMESTAMP '2020-11-03 12:34:56'", "TIMESTAMP '2020-11-03 12:34:56'"), + new TestCase("TIMESTAMP '1969-12-31 00:00:00.000000'", "TIMESTAMP '1969-12-31 00:00:00.000000'"), + + new TestCase("TIMESTAMP '1970-01-01 00:00:00.123456'", "TIMESTAMP '1970-01-01 00:00:00.123456'"), + new TestCase("TIMESTAMP '2020-11-03 12:34:56.123456'", "TIMESTAMP '2020-11-03 12:34:56.123456'"), + new TestCase("TIMESTAMP '1969-12-31 23:59:59'", "TIMESTAMP '1969-12-31 23:59:59'"), + + new TestCase("TIMESTAMP '1970-01-01 23:59:59.9'", "TIMESTAMP '1970-01-01 23:59:59.9'"), + new TestCase("TIMESTAMP '2020-11-03 23:59:59.999'", "TIMESTAMP '2020-11-03 23:59:59.999'"), + new TestCase("TIMESTAMP '1969-12-31 23:59:59.999999'", "TIMESTAMP '1969-12-31 23:59:59.999999'"), + // round down + new TestCase("TIMESTAMP '1969-12-31 00:00:00.0000001'", "TIMESTAMP '1969-12-31 00:00:00.000000'"), + new TestCase("TIMESTAMP '1970-01-01 00:00:00.000000000001'", "TIMESTAMP '1970-01-01 00:00:00.000000'"), + new TestCase("TIMESTAMP '2020-11-03 12:34:56.1234561'", "TIMESTAMP '2020-11-03 12:34:56.123456'"), + // round down, maximal value + new TestCase("TIMESTAMP '2020-11-03 00:00:00.0000004'", "TIMESTAMP '2020-11-03 00:00:00.000000'"), + new TestCase("TIMESTAMP '1969-12-31 00:00:00.000000449'", "TIMESTAMP '1969-12-31 00:00:00.000000'"), + new TestCase("TIMESTAMP '1970-01-01 00:00:00.000000444449'", "TIMESTAMP '1970-01-01 00:00:00.000000'"), + // round up, minimal value + new TestCase("TIMESTAMP '1970-01-01 00:00:00.0000005'", "TIMESTAMP '1970-01-01 00:00:00.000001'"), + new TestCase("TIMESTAMP '2020-11-03 00:00:00.000000500'", "TIMESTAMP '2020-11-03 00:00:00.000001'"), + new TestCase("TIMESTAMP '1969-12-31 00:00:00.000000500000'", "TIMESTAMP '1969-12-31 00:00:00.000001'"), + // round up, maximal value + new TestCase("TIMESTAMP '1969-12-31 00:00:00.0000009'", "TIMESTAMP '1969-12-31 00:00:00.000001'"), + new TestCase("TIMESTAMP '1970-01-01 00:00:00.000000999'", "TIMESTAMP '1970-01-01 00:00:00.000001'"), + new TestCase("TIMESTAMP '2020-11-03 00:00:00.000000999999'", "TIMESTAMP '2020-11-03 00:00:00.000001'"), + // round up to next year, minimal value + new TestCase("TIMESTAMP '2020-12-31 23:59:59.9999995'", "TIMESTAMP '2021-01-01 00:00:00.000000'"), + new TestCase("TIMESTAMP '1969-12-31 23:59:59.999999500'", "TIMESTAMP '1970-01-01 00:00:00.000000'"), + new TestCase("TIMESTAMP '1970-01-01 23:59:59.999999500000'", "TIMESTAMP '1970-01-02 00:00:00.000000'"), + // round up to next day/year, maximal value + new TestCase("TIMESTAMP '1970-01-01 23:59:59.9999999'", "TIMESTAMP '1970-01-02 00:00:00.000000'"), + new TestCase("TIMESTAMP '2020-12-31 23:59:59.999999999'", "TIMESTAMP '2021-01-01 00:00:00.000000'"), + new TestCase("TIMESTAMP '1969-12-31 23:59:59.999999999999'", "TIMESTAMP '1970-01-01 00:00:00.000000'"), + // don't round to next year (round down near upper bound) + new TestCase("TIMESTAMP '1969-12-31 23:59:59.9999994'", "TIMESTAMP '1969-12-31 23:59:59.999999'"), + new TestCase("TIMESTAMP '1970-01-01 23:59:59.999999499'", "TIMESTAMP '1970-01-01 23:59:59.999999'"), + new TestCase("TIMESTAMP '2020-12-31 23:59:59.999999499999'", "TIMESTAMP '2020-12-31 23:59:59.999999'")); + + for (Entry> entry : testCasesByPrecision.entrySet()) { + String tableName = format("test_timestamp_precision_%d_%s", entry.getKey(), randomNameSuffix()); + runTestCases(tableName, entry.getValue()); + } + } + + private static Map> groupTestCasesByInput(String inputRegex, Function classifier, TestCase... testCases) + { + return groupTestCasesByInput(inputRegex, classifier, Arrays.asList(testCases)); + } + + private static Map> groupTestCasesByInput(String inputRegex, Function classifier, List testCases) + { + return testCases.stream() + .peek(test -> { + if (!test.input().matches(inputRegex)) { + throw new RuntimeException("Bad test case input format: " + test.input()); + } + }) + .collect(groupingBy(classifier.compose(TestCase::input))); + } + + private void runTestCases(String tableName, List testCases) + { + // Must use CTAS instead of TestTable because if the table is created before the insert, + // the type mapping will treat it as TIME(6) no matter what it was created as. + getTrinoExecutor().execute(format( + "CREATE TABLE %s AS SELECT * FROM (VALUES %s) AS t (id, value)", + tableName, + testCases.stream() + .map(testCase -> format("(%d, %s)", testCase.id(), testCase.input())) + .collect(joining("), (", "(", ")")))); + try { + assertQuery( + format("SELECT value FROM %s ORDER BY id", tableName), + testCases.stream() + .map(TestCase::expected) + .collect(joining("), (", "VALUES (", ")"))); + } + finally { + getTrinoExecutor().execute("DROP TABLE " + tableName); + } + } + + @Test + public static void checkIllegalRedshiftTimePrecision() + { + assertRedshiftCreateFails( + "check_redshift_time_precision_error", + "(t TIME(6))", + "ERROR: time column does not support precision."); + } + + @Test + public static void checkIllegalRedshiftTimestampPrecision() + { + assertRedshiftCreateFails( + "check_redshift_timestamp_precision_error", + "(t TIMESTAMP(6))", + "ERROR: timestamp column does not support precision."); + } + + /** + * Assert that a {@code CREATE TABLE} statement made from Redshift fails, + * and drop the table if it doesn't fail. + */ + private static void assertRedshiftCreateFails(String tableNamePrefix, String tableBody, String message) + { + String tableName = tableNamePrefix + "_" + randomNameSuffix(); + try { + assertThatThrownBy(() -> getRedshiftExecutor() + .execute(format("CREATE TABLE %s %s", tableName, tableBody))) + .getCause() + .as("Redshift create fails for %s %s", tableName, tableBody) + .isInstanceOf(SQLException.class) + .hasMessage(message); + } + catch (AssertionError failure) { + // If the table was created, clean it up because the tests run on a shared Redshift instance + try { + getRedshiftExecutor().execute("DROP TABLE IF EXISTS " + tableName); + } + catch (Throwable e) { + failure.addSuppressed(e); + } + throw failure; + } + } + + /** + * Assert that a {@code CREATE TABLE} statement fails, and drop the table + * if it doesn't fail. + */ + private void assertCreateFails(String tableNamePrefix, String tableBody, String expectedMessageRegExp) + { + String tableName = tableNamePrefix + "_" + randomNameSuffix(); + try { + assertQueryFails(format("CREATE TABLE %s %s", tableName, tableBody), expectedMessageRegExp); + } + catch (AssertionError failure) { + // If the table was created, clean it up because the tests run on a shared Redshift instance + try { + getRedshiftExecutor().execute("DROP TABLE " + tableName); + } + catch (Throwable e) { + failure.addSuppressed(e); + } + throw failure; + } + } + + private DataSetup trinoCreateAsSelect(String tableNamePrefix) + { + return trinoCreateAsSelect(getQueryRunner().getDefaultSession(), tableNamePrefix); + } + + private DataSetup trinoCreateAsSelect(Session session, String tableNamePrefix) + { + return new CreateAsSelectDataSetup(new TrinoSqlExecutor(getQueryRunner(), session), tableNamePrefix); + } + + private static DataSetup redshiftCreateAndInsert(String tableNamePrefix) + { + return new CreateAndInsertDataSetup(getRedshiftExecutor(), TEST_SCHEMA + "." + tableNamePrefix); + } + + /** + * Create a table in the test schema using the JDBC. + * + *

Creating a test table normally doesn't use the correct schema. + */ + private static TestTable testTable(String namePrefix, String body) + { + return new TestTable(getRedshiftExecutor(), TEST_SCHEMA + "." + namePrefix, body); + } + + private SqlExecutor getTrinoExecutor() + { + return new TrinoSqlExecutor(getQueryRunner()); + } + + private static SqlExecutor getRedshiftExecutor() + { + Properties properties = new Properties(); + properties.setProperty("user", JDBC_USER); + properties.setProperty("password", JDBC_PASSWORD); + return new JdbcSqlExecutor(JDBC_URL, properties); + } + + private static void checkIsGap(ZoneId zone, LocalDateTime dateTime) + { + verify( + zone.getRules().getValidOffsets(dateTime).isEmpty(), + "Expected %s to be a gap in %s", dateTime, zone); + } + + private static void checkIsDoubled(ZoneId zone, LocalDateTime dateTime) + { + verify( + zone.getRules().getValidOffsets(dateTime).size() == 2, + "Expected %s to be doubled in %s", dateTime, zone); + } + + private static Function padVarchar(int length) + { + // Add the same padding as RedshiftClient.writeCharAsVarchar, but start from String, not Slice + return (input) -> input + " ".repeat(length - Utf8.encodedLength(input)); + } + + /** + * A pair of input and expected output from a test. + * Each instance has a unique ID. + */ + private static class TestCase + { + private static final AtomicInteger LAST_ID = new AtomicInteger(); + + private final int id; + private final String input; + private final String expected; + + private TestCase(String input, String expected) + { + this.id = LAST_ID.incrementAndGet(); + this.input = input; + this.expected = expected; + } + + public int id() + { + return this.id; + } + + public String input() + { + return this.input; + } + + public String expected() + { + return this.expected; + } + } + + private static class TestView + implements AutoCloseable + { + final String name; + + TestView(String namePrefix, String definition) + { + name = requireNonNull(namePrefix) + "_" + randomNameSuffix(); + executeInRedshift(format("CREATE VIEW %s.%s AS %s", TEST_SCHEMA, name, definition)); + } + + @Override + public void close() + { + executeInRedshift(format("DROP VIEW IF EXISTS %s.%s", TEST_SCHEMA, name)); + } + } +} From 1e8887e3082ae969eb8bf81901b2c7b56ad38e3f Mon Sep 17 00:00:00 2001 From: Dain Sundstrom Date: Sat, 10 Dec 2022 18:27:03 -0800 Subject: [PATCH 04/24] Implement Redshift DELETE --- .../trino/plugin/redshift/RedshiftClient.java | 26 ++++++++++++++ .../redshift/TestRedshiftConnectorTest.java | 35 ++++++++++++++++++- 2 files changed, 60 insertions(+), 1 deletion(-) diff --git a/plugin/trino-redshift/src/main/java/io/trino/plugin/redshift/RedshiftClient.java b/plugin/trino-redshift/src/main/java/io/trino/plugin/redshift/RedshiftClient.java index 05d27ab5a0f0..1983260e43c9 100644 --- a/plugin/trino-redshift/src/main/java/io/trino/plugin/redshift/RedshiftClient.java +++ b/plugin/trino-redshift/src/main/java/io/trino/plugin/redshift/RedshiftClient.java @@ -27,6 +27,7 @@ import io.trino.plugin.jdbc.LongWriteFunction; import io.trino.plugin.jdbc.ObjectReadFunction; import io.trino.plugin.jdbc.ObjectWriteFunction; +import io.trino.plugin.jdbc.PreparedQuery; import io.trino.plugin.jdbc.QueryBuilder; import io.trino.plugin.jdbc.SliceWriteFunction; import io.trino.plugin.jdbc.StandardColumnMappings; @@ -68,8 +69,12 @@ import java.time.format.DateTimeFormatter; import java.time.format.DateTimeFormatterBuilder; import java.util.Optional; +import java.util.OptionalLong; import java.util.function.BiFunction; +import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.base.Verify.verify; +import static io.trino.plugin.jdbc.JdbcErrorCode.JDBC_ERROR; import static io.trino.plugin.jdbc.JdbcErrorCode.JDBC_NON_TRANSIENT_ERROR; import static io.trino.plugin.jdbc.StandardColumnMappings.bigintColumnMapping; import static io.trino.plugin.jdbc.StandardColumnMappings.bigintWriteFunction; @@ -229,6 +234,27 @@ public PreparedStatement getPreparedStatement(Connection connection, String sql) return statement; } + @Override + public OptionalLong delete(ConnectorSession session, JdbcTableHandle handle) + { + checkArgument(handle.isNamedRelation(), "Unable to delete from synthetic table: %s", handle); + checkArgument(handle.getLimit().isEmpty(), "Unable to delete when limit is set: %s", handle); + checkArgument(handle.getSortOrder().isEmpty(), "Unable to delete when sort order is set: %s", handle); + try (Connection connection = connectionFactory.openConnection(session)) { + verify(connection.getAutoCommit()); + PreparedQuery preparedQuery = queryBuilder.prepareDeleteQuery(this, session, connection, handle.getRequiredNamedRelation(), handle.getConstraint(), Optional.empty()); + try (PreparedStatement preparedStatement = queryBuilder.prepareStatement(this, session, connection, preparedQuery)) { + int affectedRowsCount = preparedStatement.executeUpdate(); + // connection.getAutoCommit() == true is not enough to make DELETE effective and explicit commit is required + connection.commit(); + return OptionalLong.of(affectedRowsCount); + } + } + catch (SQLException e) { + throw new TrinoException(JDBC_ERROR, e); + } + } + @Override protected void verifySchemaName(DatabaseMetaData databaseMetadata, String schemaName) throws SQLException diff --git a/plugin/trino-redshift/src/test/java/io/trino/plugin/redshift/TestRedshiftConnectorTest.java b/plugin/trino-redshift/src/test/java/io/trino/plugin/redshift/TestRedshiftConnectorTest.java index 0d46c93a6fc9..7e2b12fe85f1 100644 --- a/plugin/trino-redshift/src/test/java/io/trino/plugin/redshift/TestRedshiftConnectorTest.java +++ b/plugin/trino-redshift/src/test/java/io/trino/plugin/redshift/TestRedshiftConnectorTest.java @@ -54,7 +54,6 @@ protected QueryRunner createQueryRunner() protected boolean hasBehavior(TestingConnectorBehavior connectorBehavior) { switch (connectorBehavior) { - case SUPPORTS_DELETE: case SUPPORTS_AGGREGATION_PUSHDOWN: case SUPPORTS_JOIN_PUSHDOWN: case SUPPORTS_TOPN_PUSHDOWN: @@ -151,6 +150,33 @@ public Object[][] redshiftTypeToTrinoTypes() {"TIMESTAMPTZ", "timestamp(6) with time zone"}}; } + @Override + public void testDelete() + { + // The base tests is very slow because Redshift CTAS is really slow, so use a smaller test + try (TestTable table = new TestTable(getQueryRunner()::execute, "test_delete_", "AS SELECT * FROM nation")) { + // delete without matching any rows + assertUpdate("DELETE FROM " + table.getName() + " WHERE nationkey < 0", 0); + + // delete with a predicate that optimizes to false + assertUpdate("DELETE FROM " + table.getName() + " WHERE nationkey > 5 AND nationkey < 4", 0); + + // delete successive parts of the table + assertUpdate("DELETE FROM " + table.getName() + " WHERE nationkey <= 5", "SELECT count(*) FROM nation WHERE nationkey <= 5"); + assertQuery("SELECT * FROM " + table.getName(), "SELECT * FROM nation WHERE nationkey > 5"); + + assertUpdate("DELETE FROM " + table.getName() + " WHERE nationkey <= 10", "SELECT count(*) FROM nation WHERE nationkey > 5 AND nationkey <= 10"); + assertQuery("SELECT * FROM " + table.getName(), "SELECT * FROM nation WHERE nationkey > 10"); + + assertUpdate("DELETE FROM " + table.getName() + " WHERE nationkey <= 15", "SELECT count(*) FROM nation WHERE nationkey > 10 AND nationkey <= 15"); + assertQuery("SELECT * FROM " + table.getName(), "SELECT * FROM nation WHERE nationkey > 15"); + + // delete remaining + assertUpdate("DELETE FROM " + table.getName(), "SELECT count(*) FROM nation WHERE nationkey > 15"); + assertQuery("SELECT * FROM " + table.getName(), "SELECT * FROM nation WHERE false"); + } + } + @Override @Test public void testReadMetadataWithRelationsConcurrentModifications() @@ -212,6 +238,13 @@ protected SqlExecutor onRemoteDatabase() return RedshiftQueryRunner::executeInRedshift; } + @Override + public void testDeleteWithLike() + { + assertThatThrownBy(super::testDeleteWithLike) + .hasStackTraceContaining("TrinoException: This connector does not support modifying table rows"); + } + @Test @Override public void testAddNotNullColumnToNonEmptyTable() From e65585de567e7383a967bf3a86075c1508a9c8d6 Mon Sep 17 00:00:00 2001 From: Dain Sundstrom Date: Sat, 10 Dec 2022 18:43:41 -0800 Subject: [PATCH 05/24] Add Redshift statistics --- plugin/trino-redshift/pom.xml | 12 +- .../trino/plugin/redshift/RedshiftClient.java | 36 +- .../plugin/redshift/RedshiftClientModule.java | 3 + .../RedshiftTableStatisticsReader.java | 176 +++++++++ .../redshift/TestRedshiftConnectorTest.java | 73 ++++ .../TestRedshiftTableStatisticsReader.java | 349 ++++++++++++++++++ 6 files changed, 642 insertions(+), 7 deletions(-) create mode 100644 plugin/trino-redshift/src/main/java/io/trino/plugin/redshift/RedshiftTableStatisticsReader.java create mode 100644 plugin/trino-redshift/src/test/java/io/trino/plugin/redshift/TestRedshiftTableStatisticsReader.java diff --git a/plugin/trino-redshift/pom.xml b/plugin/trino-redshift/pom.xml index ebcd96279cda..2370ecf35c3e 100644 --- a/plugin/trino-redshift/pom.xml +++ b/plugin/trino-redshift/pom.xml @@ -49,6 +49,11 @@ javax.inject + + org.jdbi + jdbi3-core + + io.airlift @@ -68,12 +73,6 @@ runtime - - org.jdbi - jdbi3-core - runtime - - io.trino @@ -177,6 +176,7 @@ **/TestRedshiftConnectorTest.java + **/TestRedshiftTableStatisticsReader.java **/TestRedshiftTypeMapping.java diff --git a/plugin/trino-redshift/src/main/java/io/trino/plugin/redshift/RedshiftClient.java b/plugin/trino-redshift/src/main/java/io/trino/plugin/redshift/RedshiftClient.java index 1983260e43c9..2183e01bc4d7 100644 --- a/plugin/trino-redshift/src/main/java/io/trino/plugin/redshift/RedshiftClient.java +++ b/plugin/trino-redshift/src/main/java/io/trino/plugin/redshift/RedshiftClient.java @@ -22,6 +22,7 @@ import io.trino.plugin.jdbc.ColumnMapping; import io.trino.plugin.jdbc.ConnectionFactory; import io.trino.plugin.jdbc.JdbcColumnHandle; +import io.trino.plugin.jdbc.JdbcStatisticsConfig; import io.trino.plugin.jdbc.JdbcTableHandle; import io.trino.plugin.jdbc.JdbcTypeHandle; import io.trino.plugin.jdbc.LongWriteFunction; @@ -35,7 +36,10 @@ import io.trino.plugin.jdbc.logging.RemoteQueryModifier; import io.trino.plugin.jdbc.mapping.IdentifierMapping; import io.trino.spi.TrinoException; +import io.trino.spi.connector.ColumnHandle; import io.trino.spi.connector.ConnectorSession; +import io.trino.spi.predicate.TupleDomain; +import io.trino.spi.statistics.TableStatistics; import io.trino.spi.type.CharType; import io.trino.spi.type.Chars; import io.trino.spi.type.DecimalType; @@ -73,6 +77,7 @@ import java.util.function.BiFunction; import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.base.Throwables.throwIfInstanceOf; import static com.google.common.base.Verify.verify; import static io.trino.plugin.jdbc.JdbcErrorCode.JDBC_ERROR; import static io.trino.plugin.jdbc.JdbcErrorCode.JDBC_NON_TRANSIENT_ERROR; @@ -194,10 +199,21 @@ public class RedshiftClient .toFormatter(); private static final OffsetDateTime REDSHIFT_MIN_SUPPORTED_TIMESTAMP_TZ = OffsetDateTime.of(-4712, 1, 1, 0, 0, 0, 0, ZoneOffset.UTC); + private final boolean statisticsEnabled; + private final RedshiftTableStatisticsReader statisticsReader; + @Inject - public RedshiftClient(BaseJdbcConfig config, ConnectionFactory connectionFactory, QueryBuilder queryBuilder, IdentifierMapping identifierMapping, RemoteQueryModifier queryModifier) + public RedshiftClient( + BaseJdbcConfig config, + ConnectionFactory connectionFactory, + JdbcStatisticsConfig statisticsConfig, + QueryBuilder queryBuilder, + IdentifierMapping identifierMapping, + RemoteQueryModifier queryModifier) { super(config, "\"", connectionFactory, queryBuilder, identifierMapping, queryModifier); + this.statisticsEnabled = requireNonNull(statisticsConfig, "statisticsConfig is null").isEnabled(); + this.statisticsReader = new RedshiftTableStatisticsReader(connectionFactory); } @Override @@ -207,6 +223,24 @@ public Optional getTableComment(ResultSet resultSet) return Optional.empty(); } + @Override + public TableStatistics getTableStatistics(ConnectorSession session, JdbcTableHandle handle, TupleDomain tupleDomain) + { + if (!statisticsEnabled) { + return TableStatistics.empty(); + } + if (!handle.isNamedRelation()) { + return TableStatistics.empty(); + } + try { + return statisticsReader.readTableStatistics(session, handle, () -> this.getColumns(session, handle)); + } + catch (SQLException | RuntimeException e) { + throwIfInstanceOf(e, TrinoException.class); + throw new TrinoException(JDBC_ERROR, "Failed fetching statistics for table: " + handle, e); + } + } + @Override protected void renameTable(ConnectorSession session, Connection connection, String catalogName, String remoteSchemaName, String remoteTableName, String newRemoteSchemaName, String newRemoteTableName) throws SQLException diff --git a/plugin/trino-redshift/src/main/java/io/trino/plugin/redshift/RedshiftClientModule.java b/plugin/trino-redshift/src/main/java/io/trino/plugin/redshift/RedshiftClientModule.java index 13635c88f69b..ef4153ee45ef 100644 --- a/plugin/trino-redshift/src/main/java/io/trino/plugin/redshift/RedshiftClientModule.java +++ b/plugin/trino-redshift/src/main/java/io/trino/plugin/redshift/RedshiftClientModule.java @@ -24,6 +24,7 @@ import io.trino.plugin.jdbc.DriverConnectionFactory; import io.trino.plugin.jdbc.ForBaseJdbc; import io.trino.plugin.jdbc.JdbcClient; +import io.trino.plugin.jdbc.JdbcStatisticsConfig; import io.trino.plugin.jdbc.credential.CredentialProvider; import io.trino.plugin.jdbc.ptf.Query; import io.trino.spi.ptf.ConnectorTableFunction; @@ -32,6 +33,7 @@ import static com.google.inject.Scopes.SINGLETON; import static com.google.inject.multibindings.Multibinder.newSetBinder; +import static io.airlift.configuration.ConfigBinder.configBinder; public class RedshiftClientModule extends AbstractConfigurationAwareModule @@ -41,6 +43,7 @@ public void setup(Binder binder) { binder.bind(JdbcClient.class).annotatedWith(ForBaseJdbc.class).to(RedshiftClient.class).in(SINGLETON); newSetBinder(binder, ConnectorTableFunction.class).addBinding().toProvider(Query.class).in(SINGLETON); + configBinder(binder).bindConfig(JdbcStatisticsConfig.class); install(new DecimalModule()); } diff --git a/plugin/trino-redshift/src/main/java/io/trino/plugin/redshift/RedshiftTableStatisticsReader.java b/plugin/trino-redshift/src/main/java/io/trino/plugin/redshift/RedshiftTableStatisticsReader.java new file mode 100644 index 000000000000..c576abdd109d --- /dev/null +++ b/plugin/trino-redshift/src/main/java/io/trino/plugin/redshift/RedshiftTableStatisticsReader.java @@ -0,0 +1,176 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.plugin.redshift; + +import io.trino.plugin.jdbc.ConnectionFactory; +import io.trino.plugin.jdbc.JdbcColumnHandle; +import io.trino.plugin.jdbc.JdbcTableHandle; +import io.trino.plugin.jdbc.RemoteTableName; +import io.trino.spi.connector.ConnectorSession; +import io.trino.spi.statistics.ColumnStatistics; +import io.trino.spi.statistics.Estimate; +import io.trino.spi.statistics.TableStatistics; +import org.jdbi.v3.core.Handle; +import org.jdbi.v3.core.Jdbi; + +import java.sql.Connection; +import java.sql.SQLException; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.function.Supplier; + +import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.collect.ImmutableMap.toImmutableMap; +import static java.util.Objects.requireNonNull; +import static java.util.function.Function.identity; + +public class RedshiftTableStatisticsReader +{ + private final ConnectionFactory connectionFactory; + + public RedshiftTableStatisticsReader(ConnectionFactory connectionFactory) + { + this.connectionFactory = requireNonNull(connectionFactory, "connectionFactory is null"); + } + + public TableStatistics readTableStatistics(ConnectorSession session, JdbcTableHandle table, Supplier> columnSupplier) + throws SQLException + { + checkArgument(table.isNamedRelation(), "Relation is not a table: %s", table); + + try (Connection connection = connectionFactory.openConnection(session); + Handle handle = Jdbi.open(connection)) { + StatisticsDao statisticsDao = new StatisticsDao(handle); + + RemoteTableName remoteTableName = table.getRequiredNamedRelation().getRemoteTableName(); + Optional optionalRowCount = readRowCountTableStat(statisticsDao, table); + if (optionalRowCount.isEmpty()) { + // Table not found + return TableStatistics.empty(); + } + long rowCount = optionalRowCount.get(); + + TableStatistics.Builder tableStatistics = TableStatistics.builder() + .setRowCount(Estimate.of(rowCount)); + + if (rowCount == 0) { + return tableStatistics.build(); + } + + Map columnStatistics = statisticsDao.getColumnStatistics(remoteTableName.getSchemaName().orElse(null), remoteTableName.getTableName()).stream() + .collect(toImmutableMap(ColumnStatisticsResult::columnName, identity())); + + for (JdbcColumnHandle column : columnSupplier.get()) { + ColumnStatisticsResult result = columnStatistics.get(column.getColumnName()); + if (result == null) { + continue; + } + + ColumnStatistics statistics = ColumnStatistics.builder() + .setNullsFraction(result.nullsFraction() + .map(Estimate::of) + .orElseGet(Estimate::unknown)) + .setDistinctValuesCount(result.distinctValuesIndicator() + .map(distinctValuesIndicator -> { + // If the distinct value count is an estimate Redshift uses "the negative of the number of distinct values divided by the number of rows + // For example, -1 indicates a unique column in which the number of distinct values is the same as the number of rows." + // https://www.postgresql.org/docs/9.3/view-pg-stats.html + if (distinctValuesIndicator < 0.0) { + return Math.min(-distinctValuesIndicator * rowCount, rowCount); + } + return distinctValuesIndicator; + }) + .map(Estimate::of) + .orElseGet(Estimate::unknown)) + .setDataSize(result.averageColumnLength() + .flatMap(averageColumnLength -> + result.nullsFraction() + .map(nullsFraction -> 1.0 * averageColumnLength * rowCount * (1 - nullsFraction)) + .map(Estimate::of)) + .orElseGet(Estimate::unknown)) + .build(); + + tableStatistics.setColumnStatistics(column, statistics); + } + + return tableStatistics.build(); + } + } + + private static Optional readRowCountTableStat(StatisticsDao statisticsDao, JdbcTableHandle table) + { + RemoteTableName remoteTableName = table.getRequiredNamedRelation().getRemoteTableName(); + Optional rowCount = statisticsDao.getRowCountFromPgClass(remoteTableName.getSchemaName().orElse(null), remoteTableName.getTableName()); + if (rowCount.isEmpty()) { + // Table not found + return Optional.empty(); + } + + if (rowCount.get() == 0) { + // `pg_class.reltuples = 0` may mean an empty table or a recently populated table (CTAS, LOAD or INSERT) + // The `pg_stat_all_tables` view can be way off, so we use it only as a fallback + rowCount = statisticsDao.getRowCountFromPgStat(remoteTableName.getSchemaName().orElse(null), remoteTableName.getTableName()); + } + + return rowCount; + } + + private static class StatisticsDao + { + private final Handle handle; + + public StatisticsDao(Handle handle) + { + this.handle = requireNonNull(handle, "handle is null"); + } + + Optional getRowCountFromPgClass(String schema, String tableName) + { + return handle.createQuery("SELECT reltuples FROM pg_class WHERE relnamespace = (SELECT oid FROM pg_namespace WHERE nspname = :schema) AND relname = :table_name") + .bind("schema", schema) + .bind("table_name", tableName) + .mapTo(Long.class) + .findOne(); + } + + Optional getRowCountFromPgStat(String schema, String tableName) + { + // Redshift does not have the Postgres `n_live_tup`, so estimate from `inserts - deletes` + return handle.createQuery("SELECT n_tup_ins - n_tup_del FROM pg_stat_all_tables WHERE schemaname = :schema AND relname = :table_name") + .bind("schema", schema) + .bind("table_name", tableName) + .mapTo(Long.class) + .findOne(); + } + + List getColumnStatistics(String schema, String tableName) + { + return handle.createQuery("SELECT attname, null_frac, n_distinct, avg_width FROM pg_stats WHERE schemaname = :schema AND tablename = :table_name") + .bind("schema", schema) + .bind("table_name", tableName) + .map((rs, ctx) -> + new ColumnStatisticsResult( + requireNonNull(rs.getString("attname"), "attname is null"), + Optional.of(rs.getFloat("null_frac")), + Optional.of(rs.getFloat("n_distinct")), + Optional.of(rs.getInt("avg_width")))) + .list(); + } + } + + // TODO remove when error prone is updated for Java 17 records + @SuppressWarnings("unused") + private record ColumnStatisticsResult(String columnName, Optional nullsFraction, Optional distinctValuesIndicator, Optional averageColumnLength) {} +} diff --git a/plugin/trino-redshift/src/test/java/io/trino/plugin/redshift/TestRedshiftConnectorTest.java b/plugin/trino-redshift/src/test/java/io/trino/plugin/redshift/TestRedshiftConnectorTest.java index 7e2b12fe85f1..863f308b9554 100644 --- a/plugin/trino-redshift/src/test/java/io/trino/plugin/redshift/TestRedshiftConnectorTest.java +++ b/plugin/trino-redshift/src/test/java/io/trino/plugin/redshift/TestRedshiftConnectorTest.java @@ -29,8 +29,10 @@ import static io.trino.plugin.redshift.RedshiftQueryRunner.TEST_SCHEMA; import static io.trino.plugin.redshift.RedshiftQueryRunner.createRedshiftQueryRunner; +import static io.trino.plugin.redshift.RedshiftQueryRunner.executeInRedshift; import static io.trino.testing.TestingNames.randomNameSuffix; import static java.lang.String.format; +import static java.util.Locale.ENGLISH; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; @@ -177,6 +179,77 @@ public void testDelete() } } + @Test(dataProvider = "testCaseColumnNamesDataProvider") + public void testCaseColumnNames(String tableName) + { + try { + assertUpdate( + "CREATE TABLE " + TEST_SCHEMA + "." + tableName + + " AS SELECT " + + " custkey AS CASE_UNQUOTED_UPPER, " + + " name AS case_unquoted_lower, " + + " address AS cASe_uNQuoTeD_miXED, " + + " nationkey AS \"CASE_QUOTED_UPPER\", " + + " phone AS \"case_quoted_lower\"," + + " acctbal AS \"CasE_QuoTeD_miXED\" " + + "FROM customer", + 1500); + gatherStats(tableName); + assertQuery( + "SHOW STATS FOR " + TEST_SCHEMA + "." + tableName, + "VALUES " + + "('case_unquoted_upper', NULL, 1485, 0, null, null, null)," + + "('case_unquoted_lower', 33000, 1470, 0, null, null, null)," + + "('case_unquoted_mixed', 42000, 1500, 0, null, null, null)," + + "('case_quoted_upper', NULL, 25, 0, null, null, null)," + + "('case_quoted_lower', 28500, 1483, 0, null, null, null)," + + "('case_quoted_mixed', NULL, 1483, 0, null, null, null)," + + "(null, null, null, null, 1500, null, null)"); + } + finally { + assertUpdate("DROP TABLE IF EXISTS " + tableName); + } + } + + private static void gatherStats(String tableName) + { + executeInRedshift(handle -> { + handle.execute("ANALYZE VERBOSE " + TEST_SCHEMA + "." + tableName); + for (int i = 0; i < 5; i++) { + long actualCount = handle.createQuery("SELECT count(*) FROM " + TEST_SCHEMA + "." + tableName) + .mapTo(Long.class) + .one(); + long estimatedCount = handle.createQuery(""" + SELECT reltuples FROM pg_class + WHERE relnamespace = (SELECT oid FROM pg_namespace WHERE nspname = :schema) + AND relname = :table_name + """) + .bind("schema", TEST_SCHEMA) + .bind("table_name", tableName.toLowerCase(ENGLISH).replace("\"", "")) + .mapTo(Long.class) + .one(); + if (actualCount == estimatedCount) { + return; + } + handle.execute("ANALYZE VERBOSE " + TEST_SCHEMA + "." + tableName); + } + throw new IllegalStateException("Stats not gathered"); // for small test tables reltuples should be exact + }); + } + + @DataProvider + public Object[][] testCaseColumnNamesDataProvider() + { + return new Object[][] { + {"TEST_STATS_MIXED_UNQUOTED_UPPER_" + randomNameSuffix()}, + {"test_stats_mixed_unquoted_lower_" + randomNameSuffix()}, + {"test_stats_mixed_uNQuoTeD_miXED_" + randomNameSuffix()}, + {"\"TEST_STATS_MIXED_QUOTED_UPPER_" + randomNameSuffix() + "\""}, + {"\"test_stats_mixed_quoted_lower_" + randomNameSuffix() + "\""}, + {"\"test_stats_mixed_QuoTeD_miXED_" + randomNameSuffix() + "\""} + }; + } + @Override @Test public void testReadMetadataWithRelationsConcurrentModifications() diff --git a/plugin/trino-redshift/src/test/java/io/trino/plugin/redshift/TestRedshiftTableStatisticsReader.java b/plugin/trino-redshift/src/test/java/io/trino/plugin/redshift/TestRedshiftTableStatisticsReader.java new file mode 100644 index 000000000000..ff713337ea53 --- /dev/null +++ b/plugin/trino-redshift/src/test/java/io/trino/plugin/redshift/TestRedshiftTableStatisticsReader.java @@ -0,0 +1,349 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.plugin.redshift; + +import com.amazon.redshift.Driver; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import io.trino.plugin.jdbc.BaseJdbcConfig; +import io.trino.plugin.jdbc.DriverConnectionFactory; +import io.trino.plugin.jdbc.JdbcColumnHandle; +import io.trino.plugin.jdbc.JdbcTableHandle; +import io.trino.plugin.jdbc.JdbcTypeHandle; +import io.trino.plugin.jdbc.RemoteTableName; +import io.trino.plugin.jdbc.credential.StaticCredentialProvider; +import io.trino.spi.connector.ColumnHandle; +import io.trino.spi.connector.SchemaTableName; +import io.trino.spi.statistics.ColumnStatistics; +import io.trino.spi.statistics.Estimate; +import io.trino.spi.statistics.TableStatistics; +import io.trino.spi.type.VarcharType; +import io.trino.testing.AbstractTestQueryFramework; +import io.trino.testing.QueryRunner; +import io.trino.testing.sql.TestTable; +import org.assertj.core.api.InstanceOfAssertFactories; +import org.assertj.core.api.SoftAssertions; +import org.testng.annotations.BeforeClass; +import org.testng.annotations.Test; + +import java.sql.Types; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.function.Consumer; + +import static io.trino.plugin.redshift.RedshiftQueryRunner.JDBC_PASSWORD; +import static io.trino.plugin.redshift.RedshiftQueryRunner.JDBC_URL; +import static io.trino.plugin.redshift.RedshiftQueryRunner.JDBC_USER; +import static io.trino.plugin.redshift.RedshiftQueryRunner.TEST_SCHEMA; +import static io.trino.plugin.redshift.RedshiftQueryRunner.createRedshiftQueryRunner; +import static io.trino.plugin.redshift.RedshiftQueryRunner.executeInRedshift; +import static io.trino.spi.type.BigintType.BIGINT; +import static io.trino.spi.type.DoubleType.DOUBLE; +import static io.trino.testing.TestingConnectorSession.SESSION; +import static io.trino.testing.TestingNames.randomNameSuffix; +import static io.trino.testing.sql.TestTable.fromColumns; +import static io.trino.tpch.TpchTable.CUSTOMER; +import static java.util.Collections.emptyMap; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.from; +import static org.assertj.core.api.Assertions.withinPercentage; +import static org.testng.Assert.assertEquals; +import static org.testng.Assert.assertNull; + +public class TestRedshiftTableStatisticsReader + extends AbstractTestQueryFramework +{ + private static final JdbcTypeHandle BIGINT_TYPE_HANDLE = new JdbcTypeHandle(Types.BIGINT, Optional.of("int8"), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty()); + private static final JdbcTypeHandle DOUBLE_TYPE_HANDLE = new JdbcTypeHandle(Types.DOUBLE, Optional.of("double"), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty()); + + private static final List CUSTOMER_COLUMNS = ImmutableList.of( + new JdbcColumnHandle("custkey", BIGINT_TYPE_HANDLE, BIGINT), + createVarcharJdbcColumnHandle("name", 25), + createVarcharJdbcColumnHandle("address", 48), + new JdbcColumnHandle("nationkey", BIGINT_TYPE_HANDLE, BIGINT), + createVarcharJdbcColumnHandle("phone", 15), + new JdbcColumnHandle("acctbal", DOUBLE_TYPE_HANDLE, DOUBLE), + createVarcharJdbcColumnHandle("mktsegment", 10), + createVarcharJdbcColumnHandle("comment", 117)); + + private RedshiftTableStatisticsReader statsReader; + + @BeforeClass + public void setup() + { + DriverConnectionFactory connectionFactory = new DriverConnectionFactory( + new Driver(), + new BaseJdbcConfig().setConnectionUrl(JDBC_URL), + new StaticCredentialProvider(Optional.of(JDBC_USER), Optional.of(JDBC_PASSWORD))); + statsReader = new RedshiftTableStatisticsReader(connectionFactory); + } + + @Override + protected QueryRunner createQueryRunner() + throws Exception + { + return createRedshiftQueryRunner(Map.of(), Map.of(), ImmutableList.of(CUSTOMER)); + } + + @Test + public void testCustomerTable() + throws Exception + { + assertThat(collectStats("SELECT * FROM " + TEST_SCHEMA + ".customer", CUSTOMER_COLUMNS)) + .returns(Estimate.of(1500), from(TableStatistics::getRowCount)) + .extracting(TableStatistics::getColumnStatistics, InstanceOfAssertFactories.map(ColumnHandle.class, ColumnStatistics.class)) + .hasEntrySatisfying(CUSTOMER_COLUMNS.get(0), statsCloseTo(1500.0, 0.0, 8.0 * 1500)) + .hasEntrySatisfying(CUSTOMER_COLUMNS.get(1), statsCloseTo(1500.0, 0.0, 33000.0)) + .hasEntrySatisfying(CUSTOMER_COLUMNS.get(3), statsCloseTo(25.000, 0.0, 8.0 * 1500)) + .hasEntrySatisfying(CUSTOMER_COLUMNS.get(5), statsCloseTo(1499.0, 0.0, 8.0 * 1500)); + } + + @Test + public void testEmptyTable() + throws Exception + { + TableStatistics tableStatistics = collectStats("SELECT * FROM " + TEST_SCHEMA + ".customer WHERE false", CUSTOMER_COLUMNS); + assertThat(tableStatistics) + .returns(Estimate.of(0.0), from(TableStatistics::getRowCount)) + .returns(emptyMap(), from(TableStatistics::getColumnStatistics)); + } + + @Test + public void testAllNulls() + throws Exception + { + String tableName = "testallnulls_" + randomNameSuffix(); + String schemaAndTable = TEST_SCHEMA + "." + tableName; + try { + executeInRedshift("CREATE TABLE " + schemaAndTable + " (i BIGINT)"); + executeInRedshift("INSERT INTO " + schemaAndTable + " (i) VALUES (NULL)"); + executeInRedshift("ANALYZE VERBOSE " + schemaAndTable); + + TableStatistics stats = statsReader.readTableStatistics( + SESSION, + new JdbcTableHandle( + new SchemaTableName(TEST_SCHEMA, tableName), + new RemoteTableName(Optional.empty(), Optional.of(TEST_SCHEMA), tableName), + Optional.empty()), + () -> ImmutableList.of(new JdbcColumnHandle("i", BIGINT_TYPE_HANDLE, BIGINT))); + assertThat(stats) + .returns(Estimate.of(1.0), from(TableStatistics::getRowCount)) + .returns(emptyMap(), from(TableStatistics::getColumnStatistics)); + } + finally { + executeInRedshift("DROP TABLE IF EXISTS " + schemaAndTable); + } + } + + @Test + public void testNullsFraction() + throws Exception + { + JdbcColumnHandle custkeyColumnHandle = CUSTOMER_COLUMNS.get(0); + TableStatistics stats = collectStats( + "SELECT CASE custkey % 3 WHEN 0 THEN NULL ELSE custkey END FROM " + TEST_SCHEMA + ".customer", + ImmutableList.of(custkeyColumnHandle)); + assertEquals(stats.getRowCount(), Estimate.of(1500)); + + ColumnStatistics columnStatistics = stats.getColumnStatistics().get(custkeyColumnHandle); + assertThat(columnStatistics.getNullsFraction().getValue()).isCloseTo(1.0 / 3, withinPercentage(1)); + } + + @Test + public void testAverageColumnLength() + throws Exception + { + List columns = ImmutableList.of( + new JdbcColumnHandle("custkey", BIGINT_TYPE_HANDLE, BIGINT), + createVarcharJdbcColumnHandle("v3_in_3", 3), + createVarcharJdbcColumnHandle("v3_in_42", 42), + createVarcharJdbcColumnHandle("single_10v_value", 10), + createVarcharJdbcColumnHandle("half_10v_value", 10), + createVarcharJdbcColumnHandle("half_distinct_20v_value", 20), + createVarcharJdbcColumnHandle("all_nulls", 10)); + + assertThat( + collectStats( + "SELECT " + + " custkey, " + + " 'abc' v3_in_3, " + + " CAST('abc' AS varchar(42)) v3_in_42, " + + " CASE custkey WHEN 1 THEN '0123456789' ELSE NULL END single_10v_value, " + + " CASE custkey % 2 WHEN 0 THEN '0123456789' ELSE NULL END half_10v_value, " + + " CASE custkey % 2 WHEN 0 THEN CAST((1000000 - custkey) * (1000000 - custkey) AS varchar(20)) ELSE NULL END half_distinct_20v_value, " + // 12 chars each + " CAST(NULL AS varchar(10)) all_nulls " + + "FROM " + TEST_SCHEMA + ".customer " + + "ORDER BY custkey LIMIT 100", + columns)) + .returns(Estimate.of(100), from(TableStatistics::getRowCount)) + .extracting(TableStatistics::getColumnStatistics, InstanceOfAssertFactories.map(ColumnHandle.class, ColumnStatistics.class)) + .hasEntrySatisfying(columns.get(0), statsCloseTo(100.0, 0.0, 800)) + .hasEntrySatisfying(columns.get(1), statsCloseTo(1.0, 0.0, 700.0)) + .hasEntrySatisfying(columns.get(2), statsCloseTo(1.0, 0.0, 700)) + .hasEntrySatisfying(columns.get(3), statsCloseTo(1.0, 0.99, 14)) + .hasEntrySatisfying(columns.get(4), statsCloseTo(1.0, 0.5, 700)) + .hasEntrySatisfying(columns.get(5), statsCloseTo(51, 0.5, 800)) + .satisfies(stats -> assertNull(stats.get(columns.get(6)))); + } + + @Test + public void testView() + throws Exception + { + String tableName = "test_stats_view_" + randomNameSuffix(); + String schemaAndTable = TEST_SCHEMA + "." + tableName; + List columns = ImmutableList.of( + new JdbcColumnHandle("custkey", BIGINT_TYPE_HANDLE, BIGINT), + createVarcharJdbcColumnHandle("mktsegment", 10), + createVarcharJdbcColumnHandle("comment", 117)); + + try { + executeInRedshift("CREATE OR REPLACE VIEW " + schemaAndTable + " AS SELECT custkey, mktsegment, comment FROM " + TEST_SCHEMA + ".customer"); + TableStatistics tableStatistics = statsReader.readTableStatistics( + SESSION, + new JdbcTableHandle( + new SchemaTableName(TEST_SCHEMA, tableName), + new RemoteTableName(Optional.empty(), Optional.of(TEST_SCHEMA), tableName), + Optional.empty()), + () -> columns); + assertThat(tableStatistics).isEqualTo(TableStatistics.empty()); + } + finally { + executeInRedshift("DROP VIEW IF EXISTS " + schemaAndTable); + } + } + + @Test + public void testMaterializedView() + throws Exception + { + String tableName = "test_stats_materialized_view_" + randomNameSuffix(); + String schemaAndTable = TEST_SCHEMA + "." + tableName; + List columns = ImmutableList.of( + new JdbcColumnHandle("custkey", BIGINT_TYPE_HANDLE, BIGINT), + createVarcharJdbcColumnHandle("mktsegment", 10), + createVarcharJdbcColumnHandle("comment", 117)); + + try { + executeInRedshift("CREATE MATERIALIZED VIEW " + schemaAndTable + + " AS SELECT custkey, mktsegment, comment FROM " + TEST_SCHEMA + ".customer"); + executeInRedshift("REFRESH MATERIALIZED VIEW " + schemaAndTable); + executeInRedshift("ANALYZE VERBOSE " + schemaAndTable); + TableStatistics tableStatistics = statsReader.readTableStatistics( + SESSION, + new JdbcTableHandle( + new SchemaTableName(TEST_SCHEMA, tableName), + new RemoteTableName(Optional.empty(), Optional.of(TEST_SCHEMA), tableName), + Optional.empty()), + () -> columns); + assertThat(tableStatistics).isEqualTo(TableStatistics.empty()); + } + finally { + executeInRedshift("DROP MATERIALIZED VIEW " + schemaAndTable); + } + } + + @Test + public void testNumericCornerCases() + { + try (TestTable table = fromColumns( + getQueryRunner()::execute, + "test_numeric_corner_cases_", + ImmutableMap.>builder() + .put("only_negative_infinity double", List.of("-infinity()", "-infinity()", "-infinity()", "-infinity()")) + .put("only_positive_infinity double", List.of("infinity()", "infinity()", "infinity()", "infinity()")) + .put("mixed_infinities double", List.of("-infinity()", "infinity()", "-infinity()", "infinity()")) + .put("mixed_infinities_and_numbers double", List.of("-infinity()", "infinity()", "-5.0", "7.0")) + .put("nans_only double", List.of("nan()", "nan()")) + .put("nans_and_numbers double", List.of("nan()", "nan()", "-5.0", "7.0")) + .put("large_doubles double", List.of("CAST(-50371909150609548946090.0 AS DOUBLE)", "CAST(50371909150609548946090.0 AS DOUBLE)")) // 2^77 DIV 3 + .put("short_decimals_big_fraction decimal(16,15)", List.of("-1.234567890123456", "1.234567890123456")) + .put("short_decimals_big_integral decimal(16,1)", List.of("-123456789012345.6", "123456789012345.6")) + .put("long_decimals_big_fraction decimal(38,37)", List.of("-1.2345678901234567890123456789012345678", "1.2345678901234567890123456789012345678")) + .put("long_decimals_middle decimal(38,16)", List.of("-1234567890123456.7890123456789012345678", "1234567890123456.7890123456789012345678")) + .put("long_decimals_big_integral decimal(38,1)", List.of("-1234567890123456789012345678901234567.8", "1234567890123456789012345678901234567.8")) + .buildOrThrow(), + "null")) { + executeInRedshift("ANALYZE VERBOSE " + TEST_SCHEMA + "." + table.getName()); + assertQuery( + "SHOW STATS FOR " + table.getName(), + "VALUES " + + "('only_negative_infinity', null, 1, 0, null, null, null)," + + "('only_positive_infinity', null, 1, 0, null, null, null)," + + "('mixed_infinities', null, 2, 0, null, null, null)," + + "('mixed_infinities_and_numbers', null, 4.0, 0.0, null, null, null)," + + "('nans_only', null, 1.0, 0.5, null, null, null)," + + "('nans_and_numbers', null, 3.0, 0.0, null, null, null)," + + "('large_doubles', null, 2.0, 0.5, null, null, null)," + + "('short_decimals_big_fraction', null, 2.0, 0.5, null, null, null)," + + "('short_decimals_big_integral', null, 2.0, 0.5, null, null, null)," + + "('long_decimals_big_fraction', null, 2.0, 0.5, null, null, null)," + + "('long_decimals_middle', null, 2.0, 0.5, null, null, null)," + + "('long_decimals_big_integral', null, 2.0, 0.5, null, null, null)," + + "(null, null, null, null, 4, null, null)"); + } + } + + /** + * Assert that the given column is within 5% of each statistic in the parameters, and that it has no range + */ + private static Consumer statsCloseTo(double distinctValues, double nullsFraction, double dataSize) + { + return stats -> { + SoftAssertions softly = new SoftAssertions(); + + softly.assertThat(stats.getDistinctValuesCount().getValue()) + .isCloseTo(distinctValues, withinPercentage(5.0)); + + softly.assertThat(stats.getNullsFraction().getValue()) + .isCloseTo(nullsFraction, withinPercentage(5.0)); + + softly.assertThat(stats.getDataSize().getValue()) + .isCloseTo(dataSize, withinPercentage(5.0)); + + softly.assertThat(stats.getRange()).isEmpty(); + softly.assertAll(); + }; + } + + private TableStatistics collectStats(String values, List columnHandles) + throws Exception + { + String tableName = "testredshiftstatisticsreader_" + randomNameSuffix(); + String schemaAndTable = TEST_SCHEMA + "." + tableName; + try { + executeInRedshift("CREATE TABLE " + schemaAndTable + " AS " + values); + executeInRedshift("ANALYZE VERBOSE " + schemaAndTable); + return statsReader.readTableStatistics( + SESSION, + new JdbcTableHandle( + new SchemaTableName(TEST_SCHEMA, tableName), + new RemoteTableName(Optional.empty(), Optional.of(TEST_SCHEMA), tableName), + Optional.empty()), + () -> columnHandles); + } + finally { + executeInRedshift("DROP TABLE IF EXISTS " + schemaAndTable); + } + } + + private static JdbcColumnHandle createVarcharJdbcColumnHandle(String name, int length) + { + return new JdbcColumnHandle( + name, + new JdbcTypeHandle(Types.VARCHAR, Optional.of("varchar"), Optional.of(length), Optional.empty(), Optional.empty(), Optional.empty()), + VarcharType.createVarcharType(length)); + } +} From c225b8f46851ac4ea5191783c320d8e30757ad67 Mon Sep 17 00:00:00 2001 From: Dain Sundstrom Date: Fri, 9 Dec 2022 16:04:47 -0800 Subject: [PATCH 06/24] Add Redshift pushdown --- plugin/trino-redshift/pom.xml | 11 + .../redshift/ImplementRedshiftAvgBigint.java | 26 ++ .../redshift/ImplementRedshiftAvgDecimal.java | 75 +++++ .../trino/plugin/redshift/RedshiftClient.java | 145 +++++++++ .../plugin/redshift/RedshiftClientModule.java | 2 + .../TestRedshiftAutomaticJoinPushdown.java | 72 +++++ .../redshift/TestRedshiftConnectorTest.java | 299 +++++++++++++++++- 7 files changed, 624 insertions(+), 6 deletions(-) create mode 100644 plugin/trino-redshift/src/main/java/io/trino/plugin/redshift/ImplementRedshiftAvgBigint.java create mode 100644 plugin/trino-redshift/src/main/java/io/trino/plugin/redshift/ImplementRedshiftAvgDecimal.java create mode 100644 plugin/trino-redshift/src/test/java/io/trino/plugin/redshift/TestRedshiftAutomaticJoinPushdown.java diff --git a/plugin/trino-redshift/pom.xml b/plugin/trino-redshift/pom.xml index 2370ecf35c3e..367ae423d4a4 100644 --- a/plugin/trino-redshift/pom.xml +++ b/plugin/trino-redshift/pom.xml @@ -23,6 +23,16 @@ trino-base-jdbc + + io.trino + trino-matching + + + + io.trino + trino-plugin-toolkit + + io.airlift configuration @@ -175,6 +185,7 @@ maven-surefire-plugin + **/TestRedshiftAutomaticJoinPushdown.java **/TestRedshiftConnectorTest.java **/TestRedshiftTableStatisticsReader.java **/TestRedshiftTypeMapping.java diff --git a/plugin/trino-redshift/src/main/java/io/trino/plugin/redshift/ImplementRedshiftAvgBigint.java b/plugin/trino-redshift/src/main/java/io/trino/plugin/redshift/ImplementRedshiftAvgBigint.java new file mode 100644 index 000000000000..f9c546105546 --- /dev/null +++ b/plugin/trino-redshift/src/main/java/io/trino/plugin/redshift/ImplementRedshiftAvgBigint.java @@ -0,0 +1,26 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.plugin.redshift; + +import io.trino.plugin.jdbc.aggregation.BaseImplementAvgBigint; + +public class ImplementRedshiftAvgBigint + extends BaseImplementAvgBigint +{ + @Override + protected String getRewriteFormatExpression() + { + return "avg(CAST(%s AS double precision))"; + } +} diff --git a/plugin/trino-redshift/src/main/java/io/trino/plugin/redshift/ImplementRedshiftAvgDecimal.java b/plugin/trino-redshift/src/main/java/io/trino/plugin/redshift/ImplementRedshiftAvgDecimal.java new file mode 100644 index 000000000000..103258db12b7 --- /dev/null +++ b/plugin/trino-redshift/src/main/java/io/trino/plugin/redshift/ImplementRedshiftAvgDecimal.java @@ -0,0 +1,75 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.plugin.redshift; + +import io.trino.matching.Capture; +import io.trino.matching.Captures; +import io.trino.matching.Pattern; +import io.trino.plugin.base.aggregation.AggregateFunctionRule; +import io.trino.plugin.jdbc.JdbcColumnHandle; +import io.trino.plugin.jdbc.JdbcExpression; +import io.trino.spi.connector.AggregateFunction; +import io.trino.spi.expression.Variable; +import io.trino.spi.type.DecimalType; + +import java.util.Optional; + +import static com.google.common.base.Verify.verify; +import static io.trino.matching.Capture.newCapture; +import static io.trino.plugin.base.aggregation.AggregateFunctionPatterns.basicAggregation; +import static io.trino.plugin.base.aggregation.AggregateFunctionPatterns.functionName; +import static io.trino.plugin.base.aggregation.AggregateFunctionPatterns.singleArgument; +import static io.trino.plugin.base.expression.ConnectorExpressionPatterns.type; +import static io.trino.plugin.base.expression.ConnectorExpressionPatterns.variable; +import static io.trino.plugin.redshift.RedshiftClient.REDSHIFT_MAX_DECIMAL_PRECISION; +import static java.lang.String.format; + +public class ImplementRedshiftAvgDecimal + implements AggregateFunctionRule +{ + private static final Capture INPUT = newCapture(); + + @Override + public Pattern getPattern() + { + return basicAggregation() + .with(functionName().equalTo("avg")) + .with(singleArgument().matching( + variable() + .with(type().matching(DecimalType.class::isInstance)) + .capturedAs(INPUT))); + } + + @Override + public Optional rewrite(AggregateFunction aggregateFunction, Captures captures, RewriteContext context) + { + Variable input = captures.get(INPUT); + JdbcColumnHandle columnHandle = (JdbcColumnHandle) context.getAssignment(input.getName()); + DecimalType type = (DecimalType) columnHandle.getColumnType(); + verify(aggregateFunction.getOutputType().equals(type)); + + // When decimal type has maximum precision we can get result that is not matching Presto avg semantics. + if (type.getPrecision() == REDSHIFT_MAX_DECIMAL_PRECISION) { + return Optional.of(new JdbcExpression( + format("avg(CAST(%s AS decimal(%s, %s)))", context.rewriteExpression(input).orElseThrow(), type.getPrecision(), type.getScale()), + columnHandle.getJdbcTypeHandle())); + } + + // Redshift avg function rounds down resulting decimal. + // To match Presto avg semantics, we extend scale by 1 and round result to target scale. + return Optional.of(new JdbcExpression( + format("round(avg(CAST(%s AS decimal(%s, %s))), %s)", context.rewriteExpression(input).orElseThrow(), type.getPrecision() + 1, type.getScale() + 1, type.getScale()), + columnHandle.getJdbcTypeHandle())); + } +} diff --git a/plugin/trino-redshift/src/main/java/io/trino/plugin/redshift/RedshiftClient.java b/plugin/trino-redshift/src/main/java/io/trino/plugin/redshift/RedshiftClient.java index 2183e01bc4d7..21ee2eabcf02 100644 --- a/plugin/trino-redshift/src/main/java/io/trino/plugin/redshift/RedshiftClient.java +++ b/plugin/trino-redshift/src/main/java/io/trino/plugin/redshift/RedshiftClient.java @@ -16,12 +16,20 @@ import com.amazon.redshift.jdbc.RedshiftPreparedStatement; import com.amazon.redshift.util.RedshiftObject; import com.google.common.base.CharMatcher; +import com.google.common.collect.ImmutableSet; import io.airlift.slice.Slice; +import io.trino.plugin.base.aggregation.AggregateFunctionRewriter; +import io.trino.plugin.base.aggregation.AggregateFunctionRule; +import io.trino.plugin.base.expression.ConnectorExpressionRewriter; import io.trino.plugin.jdbc.BaseJdbcClient; import io.trino.plugin.jdbc.BaseJdbcConfig; import io.trino.plugin.jdbc.ColumnMapping; import io.trino.plugin.jdbc.ConnectionFactory; import io.trino.plugin.jdbc.JdbcColumnHandle; +import io.trino.plugin.jdbc.JdbcExpression; +import io.trino.plugin.jdbc.JdbcJoinCondition; +import io.trino.plugin.jdbc.JdbcSortItem; +import io.trino.plugin.jdbc.JdbcSplit; import io.trino.plugin.jdbc.JdbcStatisticsConfig; import io.trino.plugin.jdbc.JdbcTableHandle; import io.trino.plugin.jdbc.JdbcTypeHandle; @@ -33,11 +41,26 @@ import io.trino.plugin.jdbc.SliceWriteFunction; import io.trino.plugin.jdbc.StandardColumnMappings; import io.trino.plugin.jdbc.WriteMapping; +import io.trino.plugin.jdbc.aggregation.ImplementAvgFloatingPoint; +import io.trino.plugin.jdbc.aggregation.ImplementCount; +import io.trino.plugin.jdbc.aggregation.ImplementCountAll; +import io.trino.plugin.jdbc.aggregation.ImplementCountDistinct; +import io.trino.plugin.jdbc.aggregation.ImplementMinMax; +import io.trino.plugin.jdbc.aggregation.ImplementStddevPop; +import io.trino.plugin.jdbc.aggregation.ImplementStddevSamp; +import io.trino.plugin.jdbc.aggregation.ImplementSum; +import io.trino.plugin.jdbc.aggregation.ImplementVariancePop; +import io.trino.plugin.jdbc.aggregation.ImplementVarianceSamp; +import io.trino.plugin.jdbc.expression.JdbcConnectorExpressionRewriterBuilder; import io.trino.plugin.jdbc.logging.RemoteQueryModifier; import io.trino.plugin.jdbc.mapping.IdentifierMapping; import io.trino.spi.TrinoException; +import io.trino.spi.connector.AggregateFunction; import io.trino.spi.connector.ColumnHandle; import io.trino.spi.connector.ConnectorSession; +import io.trino.spi.connector.JoinCondition; +import io.trino.spi.connector.JoinStatistics; +import io.trino.spi.connector.JoinType; import io.trino.spi.predicate.TupleDomain; import io.trino.spi.statistics.TableStatistics; import io.trino.spi.type.CharType; @@ -72,6 +95,8 @@ import java.time.ZoneOffset; import java.time.format.DateTimeFormatter; import java.time.format.DateTimeFormatterBuilder; +import java.util.List; +import java.util.Map; import java.util.Optional; import java.util.OptionalLong; import java.util.function.BiFunction; @@ -81,6 +106,7 @@ import static com.google.common.base.Verify.verify; import static io.trino.plugin.jdbc.JdbcErrorCode.JDBC_ERROR; import static io.trino.plugin.jdbc.JdbcErrorCode.JDBC_NON_TRANSIENT_ERROR; +import static io.trino.plugin.jdbc.JdbcJoinPushdownUtil.implementJoinCostAware; import static io.trino.plugin.jdbc.StandardColumnMappings.bigintColumnMapping; import static io.trino.plugin.jdbc.StandardColumnMappings.bigintWriteFunction; import static io.trino.plugin.jdbc.StandardColumnMappings.booleanColumnMapping; @@ -150,6 +176,7 @@ import static java.math.RoundingMode.UNNECESSARY; import static java.time.temporal.ChronoField.NANO_OF_SECOND; import static java.util.Objects.requireNonNull; +import static java.util.stream.Collectors.joining; public class RedshiftClient extends BaseJdbcClient @@ -199,6 +226,7 @@ public class RedshiftClient .toFormatter(); private static final OffsetDateTime REDSHIFT_MIN_SUPPORTED_TIMESTAMP_TZ = OffsetDateTime.of(-4712, 1, 1, 0, 0, 0, 0, ZoneOffset.UTC); + private final AggregateFunctionRewriter aggregateFunctionRewriter; private final boolean statisticsEnabled; private final RedshiftTableStatisticsReader statisticsReader; @@ -212,10 +240,64 @@ public RedshiftClient( RemoteQueryModifier queryModifier) { super(config, "\"", connectionFactory, queryBuilder, identifierMapping, queryModifier); + ConnectorExpressionRewriter connectorExpressionRewriter = JdbcConnectorExpressionRewriterBuilder.newBuilder() + .addStandardRules(this::quoted) + .build(); + + JdbcTypeHandle bigintTypeHandle = new JdbcTypeHandle(Types.BIGINT, Optional.of("bigint"), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty()); + + aggregateFunctionRewriter = new AggregateFunctionRewriter<>( + connectorExpressionRewriter, + ImmutableSet.>builder() + .add(new ImplementCountAll(bigintTypeHandle)) + .add(new ImplementCount(bigintTypeHandle)) + .add(new ImplementCountDistinct(bigintTypeHandle, true)) + .add(new ImplementMinMax(true)) + .add(new ImplementSum(RedshiftClient::toTypeHandle)) + .add(new ImplementAvgFloatingPoint()) + .add(new ImplementRedshiftAvgDecimal()) + .add(new ImplementRedshiftAvgBigint()) + .add(new ImplementStddevSamp()) + .add(new ImplementStddevPop()) + .add(new ImplementVarianceSamp()) + .add(new ImplementVariancePop()) + .build()); + this.statisticsEnabled = requireNonNull(statisticsConfig, "statisticsConfig is null").isEnabled(); this.statisticsReader = new RedshiftTableStatisticsReader(connectionFactory); } + private static Optional toTypeHandle(DecimalType decimalType) + { + return Optional.of( + new JdbcTypeHandle( + Types.NUMERIC, + Optional.of("decimal"), + Optional.of(decimalType.getPrecision()), + Optional.of(decimalType.getScale()), + Optional.empty(), + Optional.empty())); + } + + @Override + public Connection getConnection(ConnectorSession session, JdbcSplit split, JdbcTableHandle tableHandle) + throws SQLException + { + Connection connection = super.getConnection(session, split, tableHandle); + try { + // super.getConnection sets read-only, since the connection is going to be used only for reads. + // However, for a complex query, Redshift may decide to create some temporary tables behind + // the scenes, and this requires the connection not to be read-only, otherwise Redshift + // may fail with "ERROR: transaction is read-only". + connection.setReadOnly(false); + } + catch (SQLException e) { + connection.close(); + throw e; + } + return connection; + } + @Override public Optional getTableComment(ResultSet resultSet) { @@ -223,6 +305,12 @@ public Optional getTableComment(ResultSet resultSet) return Optional.empty(); } + @Override + public Optional implementAggregation(ConnectorSession session, AggregateFunction aggregate, Map assignments) + { + return aggregateFunctionRewriter.rewrite(session, aggregate, assignments); + } + @Override public TableStatistics getTableStatistics(ConnectorSession session, JdbcTableHandle handle, TupleDomain tupleDomain) { @@ -241,6 +329,63 @@ public TableStatistics getTableStatistics(ConnectorSession session, JdbcTableHan } } + @Override + public boolean supportsTopN(ConnectorSession session, JdbcTableHandle handle, List sortOrder) + { + return true; + } + + @Override + protected Optional topNFunction() + { + return Optional.of((query, sortItems, limit) -> { + String orderBy = sortItems.stream() + .map(sortItem -> { + String ordering = sortItem.getSortOrder().isAscending() ? "ASC" : "DESC"; + String nullsHandling = sortItem.getSortOrder().isNullsFirst() ? "NULLS FIRST" : "NULLS LAST"; + return format("%s %s %s", quoted(sortItem.getColumn().getColumnName()), ordering, nullsHandling); + }) + .collect(joining(", ")); + + return format("%s ORDER BY %s LIMIT %d", query, orderBy, limit); + }); + } + + @Override + public boolean isTopNGuaranteed(ConnectorSession session) + { + return true; + } + + @Override + protected boolean isSupportedJoinCondition(ConnectorSession session, JdbcJoinCondition joinCondition) + { + return joinCondition.getOperator() != JoinCondition.Operator.IS_DISTINCT_FROM; + } + + @Override + public Optional implementJoin(ConnectorSession session, + JoinType joinType, + PreparedQuery leftSource, + PreparedQuery rightSource, + List joinConditions, + Map rightAssignments, + Map leftAssignments, + JoinStatistics statistics) + { + if (joinType == JoinType.FULL_OUTER) { + // FULL JOIN is only supported with merge-joinable or hash-joinable join conditions + return Optional.empty(); + } + return implementJoinCostAware( + session, + joinType, + leftSource, + rightSource, + statistics, + () -> super.implementJoin(session, joinType, leftSource, rightSource, joinConditions, rightAssignments, leftAssignments, statistics)); + } + @Override protected void renameTable(ConnectorSession session, Connection connection, String catalogName, String remoteSchemaName, String remoteTableName, String newRemoteSchemaName, String newRemoteTableName) throws SQLException diff --git a/plugin/trino-redshift/src/main/java/io/trino/plugin/redshift/RedshiftClientModule.java b/plugin/trino-redshift/src/main/java/io/trino/plugin/redshift/RedshiftClientModule.java index ef4153ee45ef..aeffaac16ff7 100644 --- a/plugin/trino-redshift/src/main/java/io/trino/plugin/redshift/RedshiftClientModule.java +++ b/plugin/trino-redshift/src/main/java/io/trino/plugin/redshift/RedshiftClientModule.java @@ -24,6 +24,7 @@ import io.trino.plugin.jdbc.DriverConnectionFactory; import io.trino.plugin.jdbc.ForBaseJdbc; import io.trino.plugin.jdbc.JdbcClient; +import io.trino.plugin.jdbc.JdbcJoinPushdownSupportModule; import io.trino.plugin.jdbc.JdbcStatisticsConfig; import io.trino.plugin.jdbc.credential.CredentialProvider; import io.trino.plugin.jdbc.ptf.Query; @@ -46,6 +47,7 @@ public void setup(Binder binder) configBinder(binder).bindConfig(JdbcStatisticsConfig.class); install(new DecimalModule()); + install(new JdbcJoinPushdownSupportModule()); } @Singleton diff --git a/plugin/trino-redshift/src/test/java/io/trino/plugin/redshift/TestRedshiftAutomaticJoinPushdown.java b/plugin/trino-redshift/src/test/java/io/trino/plugin/redshift/TestRedshiftAutomaticJoinPushdown.java new file mode 100644 index 000000000000..3509f8dd8b9c --- /dev/null +++ b/plugin/trino-redshift/src/test/java/io/trino/plugin/redshift/TestRedshiftAutomaticJoinPushdown.java @@ -0,0 +1,72 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.plugin.redshift; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import io.trino.plugin.jdbc.BaseAutomaticJoinPushdownTest; +import io.trino.testing.QueryRunner; +import org.testng.SkipException; + +import static io.trino.plugin.redshift.RedshiftQueryRunner.TEST_SCHEMA; +import static io.trino.plugin.redshift.RedshiftQueryRunner.createRedshiftQueryRunner; +import static io.trino.plugin.redshift.RedshiftQueryRunner.executeInRedshift; +import static java.lang.String.format; +import static java.util.Locale.ENGLISH; + +public class TestRedshiftAutomaticJoinPushdown + extends BaseAutomaticJoinPushdownTest +{ + @Override + protected QueryRunner createQueryRunner() + throws Exception + { + return createRedshiftQueryRunner( + ImmutableMap.of(), + ImmutableMap.of(), + ImmutableList.of()); + } + + @Override + public void testJoinPushdownWithEmptyStatsInitially() + { + throw new SkipException("Redshift table statistics are automatically populated"); + } + + @Override + protected void gatherStats(String tableName) + { + executeInRedshift(handle -> { + handle.execute(format("ANALYZE VERBOSE %s.%s", TEST_SCHEMA, tableName)); + for (int i = 0; i < 5; i++) { + long actualCount = handle.createQuery(format("SELECT count(*) FROM %s.%s", TEST_SCHEMA, tableName)) + .mapTo(Long.class) + .one(); + long estimatedCount = handle.createQuery( + "SELECT reltuples FROM pg_class " + + "WHERE relnamespace = (SELECT oid FROM pg_namespace WHERE nspname = :schema) " + + "AND relname = :table_name") + .bind("schema", TEST_SCHEMA) + .bind("table_name", tableName.toLowerCase(ENGLISH).replace("\"", "")) + .mapTo(Long.class) + .one(); + if (actualCount == estimatedCount) { + return; + } + handle.execute(format("ANALYZE VERBOSE %s.%s", TEST_SCHEMA, tableName)); + } + throw new IllegalStateException("Stats not gathered"); + }); + } +} diff --git a/plugin/trino-redshift/src/test/java/io/trino/plugin/redshift/TestRedshiftConnectorTest.java b/plugin/trino-redshift/src/test/java/io/trino/plugin/redshift/TestRedshiftConnectorTest.java index 863f308b9554..1b16e335bd82 100644 --- a/plugin/trino-redshift/src/test/java/io/trino/plugin/redshift/TestRedshiftConnectorTest.java +++ b/plugin/trino-redshift/src/test/java/io/trino/plugin/redshift/TestRedshiftConnectorTest.java @@ -13,7 +13,9 @@ */ package io.trino.plugin.redshift; +import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; +import io.trino.Session; import io.trino.plugin.jdbc.BaseJdbcConnectorTest; import io.trino.testing.QueryRunner; import io.trino.testing.TestingConnectorBehavior; @@ -24,12 +26,18 @@ import org.testng.annotations.DataProvider; import org.testng.annotations.Test; +import java.util.List; import java.util.Optional; import java.util.OptionalInt; +import java.util.stream.Collectors; +import java.util.stream.IntStream; +import java.util.stream.Stream; +import static com.google.common.collect.ImmutableList.toImmutableList; import static io.trino.plugin.redshift.RedshiftQueryRunner.TEST_SCHEMA; import static io.trino.plugin.redshift.RedshiftQueryRunner.createRedshiftQueryRunner; import static io.trino.plugin.redshift.RedshiftQueryRunner.executeInRedshift; +import static io.trino.plugin.redshift.RedshiftQueryRunner.executeWithRedshift; import static io.trino.testing.TestingNames.randomNameSuffix; import static java.lang.String.format; import static java.util.Locale.ENGLISH; @@ -56,12 +64,6 @@ protected QueryRunner createQueryRunner() protected boolean hasBehavior(TestingConnectorBehavior connectorBehavior) { switch (connectorBehavior) { - case SUPPORTS_AGGREGATION_PUSHDOWN: - case SUPPORTS_JOIN_PUSHDOWN: - case SUPPORTS_TOPN_PUSHDOWN: - case SUPPORTS_PREDICATE_PUSHDOWN_WITH_VARCHAR_INEQUALITY: - return false; - case SUPPORTS_COMMENT_ON_TABLE: case SUPPORTS_ADD_COLUMN_WITH_COMMENT: case SUPPORTS_CREATE_TABLE_WITH_TABLE_COMMENT: @@ -75,6 +77,18 @@ protected boolean hasBehavior(TestingConnectorBehavior connectorBehavior) case SUPPORTS_RENAME_TABLE_ACROSS_SCHEMAS: return false; + case SUPPORTS_AGGREGATION_PUSHDOWN_STDDEV: + case SUPPORTS_AGGREGATION_PUSHDOWN_VARIANCE: + case SUPPORTS_AGGREGATION_PUSHDOWN_COUNT_DISTINCT: + return true; + + case SUPPORTS_JOIN_PUSHDOWN: + case SUPPORTS_JOIN_PUSHDOWN_WITH_VARCHAR_EQUALITY: + return true; + case SUPPORTS_JOIN_PUSHDOWN_WITH_DISTINCT_FROM: + case SUPPORTS_JOIN_PUSHDOWN_WITH_FULL_JOIN: + return false; + default: return super.hasBehavior(connectorBehavior); } @@ -211,6 +225,35 @@ public void testCaseColumnNames(String tableName) } } + /** + * Tries to create situation where Redshift would decide to materialize a temporary table for query sent to it by us. + * Such temporary table requires that our Connection is not read-only. + */ + @Test + public void testComplexPushdownThatMayElicitTemporaryTable() + { + int subqueries = 10; + String subquery = "SELECT custkey, count(*) c FROM orders GROUP BY custkey"; + StringBuilder sql = new StringBuilder(); + sql.append(format( + "SELECT t0.custkey, %s c_sum ", + IntStream.range(0, subqueries) + .mapToObj(i -> format("t%s.c", i)) + .collect(Collectors.joining("+")))); + sql.append(format("FROM (%s) t0 ", subquery)); + for (int i = 1; i < subqueries; i++) { + sql.append(format("JOIN (%s) t%s ON t0.custkey = t%s.custkey ", subquery, i, i)); + } + sql.append("WHERE t0.custkey = 1045 OR rand() = 42"); + + Session forceJoinPushdown = Session.builder(getSession()) + .setCatalogSessionProperty(getSession().getCatalog().orElseThrow(), "join_pushdown_strategy", "EAGER") + .build(); + + assertThat(query(forceJoinPushdown, sql.toString())) + .matches(format("SELECT max(custkey), count(*) * %s FROM tpch.tiny.orders WHERE custkey = 1045", subqueries)); + } + private static void gatherStats(String tableName) { executeInRedshift(handle -> { @@ -250,6 +293,241 @@ public Object[][] testCaseColumnNamesDataProvider() }; } + @Override + public void testCountDistinctWithStringTypes() + { + // cannot test using generic method as Redshift does not allow non-ASCII characters in CHAR values. + assertThatThrownBy(super::testCountDistinctWithStringTypes).hasMessageContaining("Value for Redshift CHAR must be ASCII, but found 'ą'"); + + List rows = Stream.of("a", "b", "A", "B", " a ", "a", "b", " b ") + .map(value -> format("'%1$s', '%1$s'", value)) + .collect(toImmutableList()); + String tableName = "distinct_strings" + randomNameSuffix(); + + try (TestTable testTable = new TestTable(getQueryRunner()::execute, tableName, "(t_char CHAR(5), t_varchar VARCHAR(5))", rows)) { + // Single count(DISTINCT ...) can be pushed even down even if SUPPORTS_AGGREGATION_PUSHDOWN_COUNT_DISTINCT == false as GROUP BY + assertThat(query("SELECT count(DISTINCT t_varchar) FROM " + testTable.getName())) + .matches("VALUES BIGINT '6'") + .isFullyPushedDown(); + + // Single count(DISTINCT ...) can be pushed down even if SUPPORTS_AGGREGATION_PUSHDOWN_COUNT_DISTINCT == false as GROUP BY + assertThat(query("SELECT count(DISTINCT t_char) FROM " + testTable.getName())) + .matches("VALUES BIGINT '6'") + .isFullyPushedDown(); + + assertThat(query("SELECT count(DISTINCT t_char), count(DISTINCT t_varchar) FROM " + testTable.getName())) + .matches("VALUES (BIGINT '6', BIGINT '6')") + .isFullyPushedDown(); + } + } + + @Override + public void testAggregationPushdown() + { + throw new SkipException("tested in testAggregationPushdown(String)"); + } + + @Test(dataProvider = "testAggregationPushdownDistStylesDataProvider") + public void testAggregationPushdown(String distStyle) + { + String nation = format("%s.nation_%s_%s", TEST_SCHEMA, distStyle, randomNameSuffix()); + String customer = format("%s.customer_%s_%s", TEST_SCHEMA, distStyle, randomNameSuffix()); + try { + copyWithDistStyle(TEST_SCHEMA + ".nation", nation, distStyle, Optional.of("regionkey")); + copyWithDistStyle(TEST_SCHEMA + ".customer", customer, distStyle, Optional.of("nationkey")); + + // TODO support aggregation pushdown with GROUPING SETS + // TODO support aggregation over expressions + + // count() + assertThat(query("SELECT count(*) FROM " + nation)).isFullyPushedDown(); + assertThat(query("SELECT count(nationkey) FROM " + nation)).isFullyPushedDown(); + assertThat(query("SELECT count(1) FROM " + nation)).isFullyPushedDown(); + assertThat(query("SELECT count() FROM " + nation)).isFullyPushedDown(); + assertThat(query("SELECT regionkey, count(1) FROM " + nation + " GROUP BY regionkey")).isFullyPushedDown(); + try (TestTable emptyTable = createAggregationTestTable(getSession().getSchema().orElseThrow() + ".empty_table", ImmutableList.of())) { + String emptyTableName = emptyTable.getName() + "_" + distStyle; + copyWithDistStyle(emptyTable.getName(), emptyTableName, distStyle, Optional.of("a_bigint")); + + assertThat(query("SELECT count(*) FROM " + emptyTableName)).isFullyPushedDown(); + assertThat(query("SELECT count(a_bigint) FROM " + emptyTableName)).isFullyPushedDown(); + assertThat(query("SELECT count(1) FROM " + emptyTableName)).isFullyPushedDown(); + assertThat(query("SELECT count() FROM " + emptyTableName)).isFullyPushedDown(); + assertThat(query("SELECT a_bigint, count(1) FROM " + emptyTableName + " GROUP BY a_bigint")).isFullyPushedDown(); + } + + // GROUP BY + assertThat(query("SELECT regionkey, min(nationkey) FROM " + nation + " GROUP BY regionkey")).isFullyPushedDown(); + assertThat(query("SELECT regionkey, max(nationkey) FROM " + nation + " GROUP BY regionkey")).isFullyPushedDown(); + assertThat(query("SELECT regionkey, sum(nationkey) FROM " + nation + " GROUP BY regionkey")).isFullyPushedDown(); + assertThat(query("SELECT regionkey, avg(nationkey) FROM " + nation + " GROUP BY regionkey")).isFullyPushedDown(); + try (TestTable emptyTable = createAggregationTestTable(getSession().getSchema().orElseThrow() + ".empty_table", ImmutableList.of())) { + String emptyTableName = emptyTable.getName() + "_" + distStyle; + copyWithDistStyle(emptyTable.getName(), emptyTableName, distStyle, Optional.of("a_bigint")); + + assertThat(query("SELECT t_double, min(a_bigint) FROM " + emptyTableName + " GROUP BY t_double")).isFullyPushedDown(); + assertThat(query("SELECT t_double, max(a_bigint) FROM " + emptyTableName + " GROUP BY t_double")).isFullyPushedDown(); + assertThat(query("SELECT t_double, sum(a_bigint) FROM " + emptyTableName + " GROUP BY t_double")).isFullyPushedDown(); + assertThat(query("SELECT t_double, avg(a_bigint) FROM " + emptyTableName + " GROUP BY t_double")).isFullyPushedDown(); + } + + // GROUP BY and WHERE on bigint column + // GROUP BY and WHERE on aggregation key + assertThat(query("SELECT regionkey, sum(nationkey) FROM " + nation + " WHERE regionkey < 4 GROUP BY regionkey")).isFullyPushedDown(); + + // GROUP BY and WHERE on varchar column + // GROUP BY and WHERE on "other" (not aggregation key, not aggregation input) + assertThat(query("SELECT regionkey, sum(nationkey) FROM " + nation + " WHERE regionkey < 4 AND name > 'AAA' GROUP BY regionkey")).isFullyPushedDown(); + // GROUP BY above WHERE and LIMIT + assertThat(query("SELECT regionkey, sum(nationkey) FROM (SELECT * FROM " + nation + " WHERE regionkey < 2 LIMIT 11) GROUP BY regionkey")).isFullyPushedDown(); + // GROUP BY above TopN + assertThat(query("SELECT regionkey, sum(nationkey) FROM (SELECT regionkey, nationkey FROM " + nation + " ORDER BY nationkey ASC LIMIT 10) GROUP BY regionkey")).isFullyPushedDown(); + // GROUP BY with JOIN + assertThat(query( + joinPushdownEnabled(getSession()), + "SELECT n.regionkey, sum(c.acctbal) acctbals FROM " + nation + " n LEFT JOIN " + customer + " c USING (nationkey) GROUP BY 1")) + .isFullyPushedDown(); + // GROUP BY with WHERE on neither grouping nor aggregation column + assertThat(query("SELECT nationkey, min(regionkey) FROM " + nation + " WHERE name = 'ARGENTINA' GROUP BY nationkey")).isFullyPushedDown(); + // aggregation on varchar column + assertThat(query("SELECT count(name) FROM " + nation)).isFullyPushedDown(); + // aggregation on varchar column with GROUPING + assertThat(query("SELECT nationkey, count(name) FROM " + nation + " GROUP BY nationkey")).isFullyPushedDown(); + // aggregation on varchar column with WHERE + assertThat(query("SELECT count(name) FROM " + nation + " WHERE name = 'ARGENTINA'")).isFullyPushedDown(); + } + finally { + executeInRedshift("DROP TABLE IF EXISTS " + nation); + executeInRedshift("DROP TABLE IF EXISTS " + customer); + } + } + + @Override + public void testNumericAggregationPushdown() + { + throw new SkipException("tested in testNumericAggregationPushdown(String)"); + } + + @Test(dataProvider = "testAggregationPushdownDistStylesDataProvider") + public void testNumericAggregationPushdown(String distStyle) + { + String schemaName = getSession().getSchema().orElseThrow(); + // empty table + try (TestTable emptyTable = createAggregationTestTable(schemaName + ".test_aggregation_pushdown", ImmutableList.of())) { + String emptyTableName = emptyTable.getName() + "_" + distStyle; + copyWithDistStyle(emptyTable.getName(), emptyTableName, distStyle, Optional.of("a_bigint")); + + assertThat(query("SELECT min(short_decimal), min(long_decimal), min(a_bigint), min(t_double) FROM " + emptyTableName)).isFullyPushedDown(); + assertThat(query("SELECT max(short_decimal), max(long_decimal), max(a_bigint), max(t_double) FROM " + emptyTableName)).isFullyPushedDown(); + assertThat(query("SELECT sum(short_decimal), sum(long_decimal), sum(a_bigint), sum(t_double) FROM " + emptyTableName)).isFullyPushedDown(); + assertThat(query("SELECT avg(short_decimal), avg(long_decimal), avg(a_bigint), avg(t_double) FROM " + emptyTableName)).isFullyPushedDown(); + } + + try (TestTable testTable = createAggregationTestTable(schemaName + ".test_aggregation_pushdown", + ImmutableList.of("100.000, 100000000.000000000, 100.000, 100000000", "123.321, 123456789.987654321, 123.321, 123456789"))) { + String testTableName = testTable.getName() + "_" + distStyle; + copyWithDistStyle(testTable.getName(), testTableName, distStyle, Optional.of("a_bigint")); + + assertThat(query("SELECT min(short_decimal), min(long_decimal), min(a_bigint), min(t_double) FROM " + testTableName)).isFullyPushedDown(); + assertThat(query("SELECT max(short_decimal), max(long_decimal), max(a_bigint), max(t_double) FROM " + testTableName)).isFullyPushedDown(); + assertThat(query("SELECT sum(short_decimal), sum(long_decimal), sum(a_bigint), sum(t_double) FROM " + testTableName)).isFullyPushedDown(); + assertThat(query("SELECT avg(short_decimal), avg(long_decimal), avg(a_bigint), avg(t_double) FROM " + testTableName)).isFullyPushedDown(); + + // smoke testing of more complex cases + // WHERE on aggregation column + assertThat(query("SELECT min(short_decimal), min(long_decimal) FROM " + testTableName + " WHERE short_decimal < 110 AND long_decimal < 124")).isFullyPushedDown(); + // WHERE on non-aggregation column + assertThat(query("SELECT min(long_decimal) FROM " + testTableName + " WHERE short_decimal < 110")).isFullyPushedDown(); + // GROUP BY + assertThat(query("SELECT short_decimal, min(long_decimal) FROM " + testTableName + " GROUP BY short_decimal")).isFullyPushedDown(); + // GROUP BY with WHERE on both grouping and aggregation column + assertThat(query("SELECT short_decimal, min(long_decimal) FROM " + testTableName + " WHERE short_decimal < 110 AND long_decimal < 124 GROUP BY short_decimal")).isFullyPushedDown(); + // GROUP BY with WHERE on grouping column + assertThat(query("SELECT short_decimal, min(long_decimal) FROM " + testTableName + " WHERE short_decimal < 110 GROUP BY short_decimal")).isFullyPushedDown(); + // GROUP BY with WHERE on aggregation column + assertThat(query("SELECT short_decimal, min(long_decimal) FROM " + testTableName + " WHERE long_decimal < 124 GROUP BY short_decimal")).isFullyPushedDown(); + } + } + + private static void copyWithDistStyle(String sourceTableName, String destTableName, String distStyle, Optional distKey) + { + if (distStyle.equals("AUTO")) { + // NOTE: Redshift doesn't support setting diststyle AUTO in CTAS statements + executeInRedshift("CREATE TABLE " + destTableName + " AS SELECT * FROM " + sourceTableName); + // Redshift doesn't allow ALTER DISTSTYLE if original and new style are same, so we need to check current diststyle of table + boolean isDistStyleAuto = executeWithRedshift(handle -> { + Optional currentDistStyle = handle.createQuery("" + + "SELECT releffectivediststyle " + + "FROM pg_class_info AS a LEFT JOIN pg_namespace AS b ON a.relnamespace = b.oid " + + "WHERE lower(nspname) = lower(:schema_name) AND lower(relname) = lower(:table_name)") + .bind("schema_name", TEST_SCHEMA) + // destTableName = TEST_SCHEMA + "." + tableName + .bind("table_name", destTableName.substring(destTableName.indexOf(".") + 1)) + .mapTo(Long.class) + .findOne(); + + // 10 means AUTO(ALL) and 11 means AUTO(EVEN). See https://docs.aws.amazon.com/redshift/latest/dg/r_PG_CLASS_INFO.html. + return currentDistStyle.isPresent() && (currentDistStyle.get() == 10 || currentDistStyle.get() == 11); + }); + if (!isDistStyleAuto) { + executeInRedshift("ALTER TABLE " + destTableName + " ALTER DISTSTYLE " + distStyle); + } + } + else { + String copyWithDistStyleSql = "CREATE TABLE " + destTableName + " DISTSTYLE " + distStyle; + if (distStyle.equals("KEY")) { + copyWithDistStyleSql += format(" DISTKEY(%s)", distKey.orElseThrow()); + } + copyWithDistStyleSql += " AS SELECT * FROM " + sourceTableName; + executeInRedshift(copyWithDistStyleSql); + } + } + + @DataProvider + public Object[][] testAggregationPushdownDistStylesDataProvider() + { + return new Object[][] { + {"EVEN"}, + {"KEY"}, + {"ALL"}, + {"AUTO"}, + }; + } + + @Test + public void testDecimalAvgPushdownForMaximumDecimalScale() + { + List rows = ImmutableList.of( + "12345789.9876543210", + format("%s.%s", "1".repeat(28), "9".repeat(10))); + + try (TestTable testTable = new TestTable(getQueryRunner()::execute, TEST_SCHEMA + ".test_agg_pushdown_avg_max_decimal", + "(t_decimal DECIMAL(38, 10))", rows)) { + // Redshift avg rounds down decimal result which doesn't match Presto semantics + assertThatThrownBy(() -> assertThat(query("SELECT avg(t_decimal) FROM " + testTable.getName())).isFullyPushedDown()) + .isInstanceOf(AssertionError.class) + .hasMessageContaining(""" + elements not found: + <(555555555555555555561728450.9938271605)> + and elements not expected: + <(555555555555555555561728450.9938271604)> + """); + } + } + + @Test + public void testDecimalAvgPushdownFoShortDecimalScale() + { + List rows = ImmutableList.of( + "0.987654321234567890", + format("0.%s", "1".repeat(18))); + + try (TestTable testTable = new TestTable(getQueryRunner()::execute, TEST_SCHEMA + ".test_agg_pushdown_avg_max_decimal", + "(t_decimal DECIMAL(18, 18))", rows)) { + assertThat(query("SELECT avg(t_decimal) FROM " + testTable.getName())).isFullyPushedDown(); + } + } + @Override @Test public void testReadMetadataWithRelationsConcurrentModifications() @@ -263,6 +541,15 @@ public void testInsertRowConcurrently() throw new SkipException("Test fails with a timeout sometimes and is flaky"); } + @Override + protected Session joinPushdownEnabled(Session session) + { + return Session.builder(super.joinPushdownEnabled(session)) + // strategy is AUTOMATIC by default and would not work for certain test cases (even if statistics are collected) + .setCatalogSessionProperty(session.getCatalog().orElseThrow(), "join_pushdown_strategy", "EAGER") + .build(); + } + @Override protected String errorMessageForInsertIntoNotNullColumn(String columnName) { From be62233a6035925ff77f1a4b767837b5f8956020 Mon Sep 17 00:00:00 2001 From: Mateusz Gajewski Date: Wed, 7 Dec 2022 12:56:35 +0100 Subject: [PATCH 07/24] Test SET PATH support by clients --- .../io/trino/client/StatementClientFactory.java | 10 +++++++++- .../java/io/trino/client/StatementClientV1.java | 8 ++++++-- .../src/main/java/io/trino/Session.java | 1 + .../java/io/trino/testing/TestingSession.java | 6 ++++++ .../testing/AbstractTestingTrinoClient.java | 2 +- .../src/test/java/io/trino/tests/TestServer.java | 16 ++++++++++++++++ 6 files changed, 39 insertions(+), 4 deletions(-) diff --git a/client/trino-client/src/main/java/io/trino/client/StatementClientFactory.java b/client/trino-client/src/main/java/io/trino/client/StatementClientFactory.java index 4aa2eda71495..6e5004994c95 100644 --- a/client/trino-client/src/main/java/io/trino/client/StatementClientFactory.java +++ b/client/trino-client/src/main/java/io/trino/client/StatementClientFactory.java @@ -15,12 +15,20 @@ import okhttp3.OkHttpClient; +import java.util.Optional; +import java.util.Set; + public final class StatementClientFactory { private StatementClientFactory() {} public static StatementClient newStatementClient(OkHttpClient httpClient, ClientSession session, String query) { - return new StatementClientV1(httpClient, session, query); + return new StatementClientV1(httpClient, session, query, Optional.empty()); + } + + public static StatementClient newStatementClient(OkHttpClient httpClient, ClientSession session, String query, Optional> clientCapabilities) + { + return new StatementClientV1(httpClient, session, query, clientCapabilities); } } diff --git a/client/trino-client/src/main/java/io/trino/client/StatementClientV1.java b/client/trino-client/src/main/java/io/trino/client/StatementClientV1.java index c2acbaa3c7ff..c231679c32fc 100644 --- a/client/trino-client/src/main/java/io/trino/client/StatementClientV1.java +++ b/client/trino-client/src/main/java/io/trino/client/StatementClientV1.java @@ -48,6 +48,7 @@ import static com.google.common.base.MoreObjects.firstNonNull; import static com.google.common.base.Preconditions.checkState; +import static com.google.common.collect.ImmutableSet.toImmutableSet; import static com.google.common.net.HttpHeaders.ACCEPT_ENCODING; import static com.google.common.net.HttpHeaders.USER_AGENT; import static io.trino.client.JsonCodec.jsonCodec; @@ -56,6 +57,7 @@ import static java.net.HttpURLConnection.HTTP_OK; import static java.net.HttpURLConnection.HTTP_UNAUTHORIZED; import static java.net.HttpURLConnection.HTTP_UNAVAILABLE; +import static java.util.Arrays.stream; import static java.util.Objects.requireNonNull; import static java.util.concurrent.TimeUnit.MILLISECONDS; @@ -93,7 +95,7 @@ class StatementClientV1 private final AtomicReference state = new AtomicReference<>(State.RUNNING); - public StatementClientV1(OkHttpClient httpClient, ClientSession session, String query) + public StatementClientV1(OkHttpClient httpClient, ClientSession session, String query, Optional> clientCapabilities) { requireNonNull(httpClient, "httpClient is null"); requireNonNull(session, "session is null"); @@ -107,7 +109,9 @@ public StatementClientV1(OkHttpClient httpClient, ClientSession session, String .filter(Optional::isPresent) .map(Optional::get) .findFirst(); - this.clientCapabilities = Joiner.on(",").join(ClientCapabilities.values()); + this.clientCapabilities = Joiner.on(",").join(clientCapabilities.orElseGet(() -> stream(ClientCapabilities.values()) + .map(Enum::name) + .collect(toImmutableSet()))); this.compressionDisabled = session.isCompressionDisabled(); Request request = buildQueryRequest(session, query); diff --git a/core/trino-main/src/main/java/io/trino/Session.java b/core/trino-main/src/main/java/io/trino/Session.java index 2b48d5d224ea..b30ec04bdf93 100644 --- a/core/trino-main/src/main/java/io/trino/Session.java +++ b/core/trino-main/src/main/java/io/trino/Session.java @@ -607,6 +607,7 @@ private SessionBuilder(Session session) this.remoteUserAddress = session.remoteUserAddress.orElse(null); this.userAgent = session.userAgent.orElse(null); this.clientInfo = session.clientInfo.orElse(null); + this.clientCapabilities = ImmutableSet.copyOf(session.clientCapabilities); this.clientTags = ImmutableSet.copyOf(session.clientTags); this.start = session.start; this.systemProperties.putAll(session.systemProperties); diff --git a/core/trino-main/src/main/java/io/trino/testing/TestingSession.java b/core/trino-main/src/main/java/io/trino/testing/TestingSession.java index b02d27bd5907..9d4f2637a140 100644 --- a/core/trino-main/src/main/java/io/trino/testing/TestingSession.java +++ b/core/trino-main/src/main/java/io/trino/testing/TestingSession.java @@ -15,11 +15,15 @@ import io.trino.Session; import io.trino.Session.SessionBuilder; +import io.trino.client.ClientCapabilities; import io.trino.execution.QueryIdGenerator; import io.trino.metadata.SessionPropertyManager; import io.trino.spi.security.Identity; import io.trino.spi.type.TimeZoneKey; +import java.util.Arrays; + +import static com.google.common.collect.ImmutableSet.toImmutableSet; import static java.util.Locale.ENGLISH; public final class TestingSession @@ -54,6 +58,8 @@ public static SessionBuilder testSessionBuilder(SessionPropertyManager sessionPr .setSchema("schema") .setTimeZoneKey(DEFAULT_TIME_ZONE_KEY) .setLocale(ENGLISH) + .setClientCapabilities(Arrays.stream(ClientCapabilities.values()).map(Enum::name) + .collect(toImmutableSet())) .setRemoteUserAddress("address") .setUserAgent("agent"); } diff --git a/testing/trino-testing/src/main/java/io/trino/testing/AbstractTestingTrinoClient.java b/testing/trino-testing/src/main/java/io/trino/testing/AbstractTestingTrinoClient.java index 0224103e7530..2679256700cf 100644 --- a/testing/trino-testing/src/main/java/io/trino/testing/AbstractTestingTrinoClient.java +++ b/testing/trino-testing/src/main/java/io/trino/testing/AbstractTestingTrinoClient.java @@ -92,7 +92,7 @@ public ResultWithQueryId execute(Session session, @Language("SQL") String sql ClientSession clientSession = toClientSession(session, trinoServer.getBaseUrl(), new Duration(2, TimeUnit.MINUTES)); - try (StatementClient client = newStatementClient(httpClient, clientSession, sql)) { + try (StatementClient client = newStatementClient(httpClient, clientSession, sql, Optional.of(session.getClientCapabilities()))) { while (client.isRunning()) { resultsSession.addResults(client.currentStatusInfo(), client.currentData()); client.advance(); diff --git a/testing/trino-tests/src/test/java/io/trino/tests/TestServer.java b/testing/trino-tests/src/test/java/io/trino/tests/TestServer.java index 561c18e474bb..99b162657aae 100644 --- a/testing/trino-tests/src/test/java/io/trino/tests/TestServer.java +++ b/testing/trino-tests/src/test/java/io/trino/tests/TestServer.java @@ -42,6 +42,7 @@ import java.util.Collections; import java.util.List; import java.util.Optional; +import java.util.Set; import java.util.function.Function; import java.util.stream.Collector; import java.util.stream.Collectors; @@ -66,6 +67,7 @@ import static io.trino.SystemSessionProperties.HASH_PARTITION_COUNT; import static io.trino.SystemSessionProperties.JOIN_DISTRIBUTION_TYPE; import static io.trino.SystemSessionProperties.QUERY_MAX_MEMORY; +import static io.trino.client.ClientCapabilities.PATH; import static io.trino.client.ProtocolHeaders.TRINO_HEADERS; import static io.trino.spi.StandardErrorCode.INCOMPATIBLE_CLIENT; import static io.trino.testing.TestingSession.testSessionBuilder; @@ -77,6 +79,7 @@ import static javax.ws.rs.core.Response.Status.OK; import static javax.ws.rs.core.Response.Status.SEE_OTHER; import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; import static org.testng.Assert.assertEquals; import static org.testng.Assert.assertNotNull; import static org.testng.Assert.assertNull; @@ -275,6 +278,19 @@ public void testVersionOnCompilerFailedError() } } + @Test + public void testSetPathSupportByClient() + { + try (TestingTrinoClient testingClient = new TestingTrinoClient(server, testSessionBuilder().setClientCapabilities(Set.of()).build())) { + assertThatThrownBy(() -> testingClient.execute("SET PATH foo")) + .hasMessage("SET PATH not supported by client"); + } + + try (TestingTrinoClient testingClient = new TestingTrinoClient(server, testSessionBuilder().setClientCapabilities(Set.of(PATH.name())).build())) { + testingClient.execute("SET PATH foo"); + } + } + private void checkVersionOnError(String query, @Language("RegExp") String proofOfOrigin) { QueryResults queryResults = postQuery(request -> request From dfe33c790e47d39bee3f03357d61b2fcc5f7e578 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Hubert=20=C5=81ojek?= Date: Wed, 7 Dec 2022 18:32:33 +0100 Subject: [PATCH 08/24] Fix HTTP_Status on OAuth2 refresh token redirect When refresh token is retrieved for UI, currently we were sending HTTP Status 303, assuming that all the request will just repeat the call on the Location header. When this works for GET/PUT verbs, it does not for non-idempotent ones like POST, as every js http client should do a GET on LOCATION after 303 on POST. Due to that I change it to 307, that should force every client to repeat exactly the same request, no matter the verb. Co-authored-by: s2lomon --- .../ui/OAuth2WebUiAuthenticationFilter.java | 2 +- .../java/io/trino/server/ui/TestWebUi.java | 203 +++++++++++++++--- 2 files changed, 171 insertions(+), 34 deletions(-) diff --git a/core/trino-main/src/main/java/io/trino/server/ui/OAuth2WebUiAuthenticationFilter.java b/core/trino-main/src/main/java/io/trino/server/ui/OAuth2WebUiAuthenticationFilter.java index 253b456be457..24b66c563f24 100644 --- a/core/trino-main/src/main/java/io/trino/server/ui/OAuth2WebUiAuthenticationFilter.java +++ b/core/trino-main/src/main/java/io/trino/server/ui/OAuth2WebUiAuthenticationFilter.java @@ -164,7 +164,7 @@ private void redirectForNewToken(ContainerRequestContext request, String refresh { OAuth2Client.Response response = client.refreshTokens(refreshToken); String serializedToken = tokenPairSerializer.serialize(TokenPair.fromOAuth2Response(response)); - request.abortWith(Response.seeOther(request.getUriInfo().getRequestUri()) + request.abortWith(Response.temporaryRedirect(request.getUriInfo().getRequestUri()) .cookie(OAuthWebUiCookie.create(serializedToken, tokenExpiration.map(expiration -> Instant.now().plus(expiration)).orElse(response.getExpiration()))) .build()); } diff --git a/core/trino-main/src/test/java/io/trino/server/ui/TestWebUi.java b/core/trino-main/src/test/java/io/trino/server/ui/TestWebUi.java index cf56a8fbeaa6..d164a6c64245 100644 --- a/core/trino-main/src/test/java/io/trino/server/ui/TestWebUi.java +++ b/core/trino-main/src/test/java/io/trino/server/ui/TestWebUi.java @@ -35,6 +35,8 @@ import io.trino.server.security.ResourceSecurity; import io.trino.server.security.oauth2.ChallengeFailedException; import io.trino.server.security.oauth2.OAuth2Client; +import io.trino.server.security.oauth2.TokenPairSerializer; +import io.trino.server.security.oauth2.TokenPairSerializer.TokenPair; import io.trino.server.testing.TestingTrinoServer; import io.trino.spi.security.AccessDeniedException; import io.trino.spi.security.BasicPrincipal; @@ -104,12 +106,14 @@ import static io.trino.server.ui.FormWebUiAuthenticationFilter.UI_LOGIN; import static io.trino.server.ui.FormWebUiAuthenticationFilter.UI_LOGOUT; import static io.trino.testing.assertions.Assert.assertEquals; +import static io.trino.testing.assertions.Assert.assertEventually; import static java.nio.charset.StandardCharsets.UTF_8; -import static java.util.concurrent.TimeUnit.MINUTES; +import static java.util.Objects.requireNonNull; import static java.util.function.Predicate.not; import static javax.servlet.http.HttpServletResponse.SC_NOT_FOUND; import static javax.servlet.http.HttpServletResponse.SC_OK; import static javax.servlet.http.HttpServletResponse.SC_SEE_OTHER; +import static javax.servlet.http.HttpServletResponse.SC_TEMPORARY_REDIRECT; import static javax.servlet.http.HttpServletResponse.SC_UNAUTHORIZED; import static javax.ws.rs.core.Response.Status.UNAUTHORIZED; import static org.assertj.core.api.Assertions.assertThat; @@ -148,6 +152,8 @@ public class TestWebUi private static final String TEST_PASSWORD2 = "test-password2"; private static final String HMAC_KEY = Resources.getResource("hmac_key.txt").getPath(); private static final PrivateKey JWK_PRIVATE_KEY; + private static final String REFRESH_TOKEN = "REFRESH_TOKEN"; + private static final Duration REFRESH_TOKEN_TIMEOUT = Duration.ofMinutes(1); static { try { @@ -652,8 +658,7 @@ public void testOAuth2Authenticator() .setBinding() .toInstance(oauthClient)) .build()) { - HttpServerInfo httpServerInfo = server.getInstance(Key.get(HttpServerInfo.class)); - assertAuth2Authentication(httpServerInfo, oauthClient.getAccessToken()); + assertAuth2Authentication(server, oauthClient.getAccessToken(), false); } finally { jwkServer.stop(); @@ -664,7 +669,7 @@ public void testOAuth2Authenticator() public void testOAuth2AuthenticatorWithoutOpenIdScope() throws Exception { - OAuth2ClientStub oauthClient = new OAuth2ClientStub(false); + OAuth2ClientStub oauthClient = new OAuth2ClientStub(false, Duration.ofSeconds(5)); TestingHttpServer jwkServer = createTestingJwkServer(); jwkServer.start(); try (TestingTrinoServer server = TestingTrinoServer.builder() @@ -677,8 +682,116 @@ public void testOAuth2AuthenticatorWithoutOpenIdScope() .setBinding() .toInstance(oauthClient)) .build()) { + assertAuth2Authentication(server, oauthClient.getAccessToken(), false); + } + finally { + jwkServer.stop(); + } + } + + @Test + public void testOAuth2AuthenticatorWithRefreshToken() + throws Exception + { + OAuth2ClientStub oauthClient = new OAuth2ClientStub(false, Duration.ofSeconds(5)); + TestingHttpServer jwkServer = createTestingJwkServer(); + jwkServer.start(); + try (TestingTrinoServer server = TestingTrinoServer.builder() + .setProperties(ImmutableMap.builder() + .putAll(OAUTH2_PROPERTIES) + .put("http-server.authentication.oauth2.jwks-url", jwkServer.getBaseUrl().toString()) + .put("http-server.authentication.oauth2.refresh-tokens", "true") + .put("http-server.authentication.oauth2.refresh-tokens.issued-token.timeout", REFRESH_TOKEN_TIMEOUT.getSeconds() + "s") + .buildOrThrow()) + .setAdditionalModule(binder -> newOptionalBinder(binder, OAuth2Client.class) + .setBinding() + .toInstance(oauthClient)) + .build()) { + assertAuth2Authentication(server, oauthClient.getAccessToken(), true); + } + finally { + jwkServer.stop(); + } + } + + @Test + public void testOAuth2AuthenticatorRedirectAfterAuthTokenRefresh() + throws Exception + { + // the first issued authorization token will be expired + OAuth2ClientStub oauthClient = new OAuth2ClientStub(false, Duration.ZERO); + TestingHttpServer jwkServer = createTestingJwkServer(); + jwkServer.start(); + try (TestingTrinoServer server = TestingTrinoServer.builder() + .setProperties(ImmutableMap.builder() + .putAll(OAUTH2_PROPERTIES) + .put("http-server.authentication.oauth2.jwks-url", jwkServer.getBaseUrl().toString()) + .put("http-server.authentication.oauth2.refresh-tokens", "true") + .put("http-server.authentication.oauth2.refresh-tokens.issued-token.timeout", REFRESH_TOKEN_TIMEOUT.getSeconds() + "s") + .buildOrThrow()) + .setAdditionalModule(binder -> newOptionalBinder(binder, OAuth2Client.class) + .setBinding() + .toInstance(oauthClient)) + .build()) { + CookieManager cookieManager = new CookieManager(); + OkHttpClient client = this.client.newBuilder() + .cookieJar(new JavaNetCookieJar(cookieManager)) + .build(); + + HttpServerInfo httpServerInfo = server.getInstance(Key.get(HttpServerInfo.class)); + URI baseUri = httpServerInfo.getHttpsUri(); + + loginWithCallbackEndpoint(client, baseUri); + HttpCookie cookie = getOnlyElement(cookieManager.getCookieStore().getCookies()); + assertCookieWithRefreshToken(server, cookie, oauthClient.getAccessToken()); + + assertResponseCode(client, getValidApiLocation(baseUri), SC_TEMPORARY_REDIRECT); + assertOk(client, getValidApiLocation(baseUri)); + } + finally { + jwkServer.stop(); + } + } + + @Test + public void testOAuth2AuthenticatorRefreshTokenExpiration() + throws Exception + { + OAuth2ClientStub oauthClient = new OAuth2ClientStub(false, Duration.ofSeconds(5)); + TestingHttpServer jwkServer = createTestingJwkServer(); + jwkServer.start(); + try (TestingTrinoServer server = TestingTrinoServer.builder() + .setProperties(ImmutableMap.builder() + .putAll(OAUTH2_PROPERTIES) + .put("http-server.authentication.oauth2.jwks-url", jwkServer.getBaseUrl().toString()) + .put("http-server.authentication.oauth2.refresh-tokens", "true") + .put("http-server.authentication.oauth2.refresh-tokens.issued-token.timeout", "10s") + .buildOrThrow()) + .setAdditionalModule(binder -> newOptionalBinder(binder, OAuth2Client.class) + .setBinding() + .toInstance(oauthClient)) + .build()) { + CookieManager cookieManager = new CookieManager(); + OkHttpClient client = this.client.newBuilder() + .cookieJar(new JavaNetCookieJar(cookieManager)) + .build(); + HttpServerInfo httpServerInfo = server.getInstance(Key.get(HttpServerInfo.class)); - assertAuth2Authentication(httpServerInfo, oauthClient.getAccessToken()); + URI baseUri = httpServerInfo.getHttpsUri(); + + loginWithCallbackEndpoint(client, baseUri); + HttpCookie cookie = getOnlyElement(cookieManager.getCookieStore().getCookies()); + assertOk(client, getValidApiLocation(baseUri)); + + // wait for the cookie to expire + assertEventually(() -> assertThat(cookieManager.getCookieStore().getCookies()).isEmpty()); + assertResponseCode(client, getValidApiLocation(baseUri), UNAUTHORIZED.getStatusCode()); + + // create fake cookie with previous cookie value to check token validity + HttpCookie biscuit = new HttpCookie(cookie.getName(), cookie.getValue()); + biscuit.setPath(cookie.getPath()); + cookieManager.getCookieStore().add(baseUri, biscuit); + assertResponseCode(client, getValidApiLocation(baseUri), UNAUTHORIZED.getStatusCode()); } finally { jwkServer.stop(); @@ -694,6 +807,7 @@ public void testCustomPrincipalField() .put(SUBJECT, "unknown") .put("preferred_username", "test-user@email.com") .buildOrThrow(), + Duration.ofSeconds(5), true); TestingHttpServer jwkServer = createTestingJwkServer(); jwkServer.start(); @@ -711,8 +825,7 @@ public void testCustomPrincipalField() jaxrsBinder(binder).bind(AuthenticatedIdentityCapturingFilter.class); }) .build()) { - HttpServerInfo httpServerInfo = server.getInstance(Key.get(HttpServerInfo.class)); - assertAuth2Authentication(httpServerInfo, oauthClient.getAccessToken()); + assertAuth2Authentication(server, oauthClient.getAccessToken(), false); Identity identity = server.getInstance(Key.get(AuthenticatedIdentityCapturingFilter.class)).getAuthenticatedIdentity(); assertThat(identity.getUser()).isEqualTo("test-user"); assertThat(identity.getPrincipal()).isEqualTo(Optional.of(new BasicPrincipal("test-user@email.com"))); @@ -722,20 +835,15 @@ public void testCustomPrincipalField() } } - private void assertAuth2Authentication(HttpServerInfo httpServerInfo, String accessToken) + private void assertAuth2Authentication(TestingTrinoServer server, String accessToken, boolean refreshTokensEnabled) throws Exception { - String state = newJwtBuilder() - .signWith(hmacShaKeyFor(Hashing.sha256().hashString(STATE_KEY, UTF_8).asBytes())) - .setAudience("trino_oauth_ui") - .setExpiration(Date.from(ZonedDateTime.now().plusMinutes(10).toInstant())) - .compact(); - CookieManager cookieManager = new CookieManager(); OkHttpClient client = this.client.newBuilder() .cookieJar(new JavaNetCookieJar(cookieManager)) .build(); + HttpServerInfo httpServerInfo = server.getInstance(Key.get(HttpServerInfo.class)); // HTTP is not allowed for OAuth testDisabled(httpServerInfo.getHttpUri()); @@ -747,21 +855,17 @@ private void assertAuth2Authentication(HttpServerInfo httpServerInfo, String acc assertRedirect(client, getLocation(baseUri, "/ui/unknown"), "http://example.com/authorize", false); assertResponseCode(client, getLocation(baseUri, "/ui/api/unknown"), UNAUTHORIZED.getStatusCode()); - // login with the callback endpoint - assertRedirect( - client, - uriBuilderFrom(baseUri) - .replacePath(CALLBACK_ENDPOINT) - .addParameter("code", "TEST_CODE") - .addParameter("state", state) - .toString(), - getUiLocation(baseUri), - false); + loginWithCallbackEndpoint(client, baseUri); HttpCookie cookie = getOnlyElement(cookieManager.getCookieStore().getCookies()); - assertEquals(cookie.getValue(), accessToken); + if (refreshTokensEnabled) { + assertCookieWithRefreshToken(server, cookie, accessToken); + } + else { + assertEquals(cookie.getValue(), accessToken); + assertThat(cookie.getMaxAge()).isGreaterThan(0).isLessThan(30); + } assertEquals(cookie.getPath(), "/ui/"); assertEquals(cookie.getDomain(), baseUri.getHost()); - assertTrue(cookie.getMaxAge() > 0 && cookie.getMaxAge() < MINUTES.toSeconds(5)); assertTrue(cookie.isHttpOnly()); // authentication cookie is now set, so UI should work @@ -778,6 +882,34 @@ private void assertAuth2Authentication(HttpServerInfo httpServerInfo, String acc assertRedirect(client, getUiLocation(baseUri), "http://example.com/authorize", false); } + private static void loginWithCallbackEndpoint(OkHttpClient client, URI baseUri) + throws IOException + { + String state = newJwtBuilder() + .signWith(hmacShaKeyFor(Hashing.sha256().hashString(STATE_KEY, UTF_8).asBytes())) + .setAudience("trino_oauth_ui") + .setExpiration(Date.from(ZonedDateTime.now().plusMinutes(10).toInstant())) + .compact(); + assertRedirect( + client, + uriBuilderFrom(baseUri) + .replacePath(CALLBACK_ENDPOINT) + .addParameter("code", "TEST_CODE") + .addParameter("state", state) + .toString(), + getUiLocation(baseUri), + false); + } + + private static void assertCookieWithRefreshToken(TestingTrinoServer server, HttpCookie authCookie, String accessToken) + { + TokenPairSerializer tokenPairSerializer = server.getInstance(Key.get(TokenPairSerializer.class)); + TokenPair deserialize = tokenPairSerializer.deserialize(authCookie.getValue()); + assertEquals(deserialize.getAccessToken(), accessToken); + assertEquals(deserialize.getRefreshToken(), Optional.of(REFRESH_TOKEN)); + assertThat(authCookie.getMaxAge()).isGreaterThan(0).isLessThan(REFRESH_TOKEN_TIMEOUT.getSeconds()); + } + private static void testAlwaysAuthorized(URI baseUri, OkHttpClient authorizedClient, String nodeId) throws IOException { @@ -1078,23 +1210,25 @@ private static class OAuth2ClientStub private static final SecureRandom secureRandom = new SecureRandom(); private final Claims claims; private final String accessToken; + private final Duration accessTokenValidity; private final Optional nonce; private final Optional idToken; public OAuth2ClientStub() { - this(true); + this(true, Duration.ofSeconds(5)); } - public OAuth2ClientStub(boolean issueIdToken) + public OAuth2ClientStub(boolean issueIdToken, Duration accessTokenValidity) { - this(ImmutableMap.of(), issueIdToken); + this(ImmutableMap.of(), accessTokenValidity, issueIdToken); } - public OAuth2ClientStub(Map additionalClaims, boolean issueIdToken) + public OAuth2ClientStub(Map additionalClaims, Duration accessTokenValidity, boolean issueIdToken) { claims = new DefaultClaims(createClaims()); - claims.putAll(additionalClaims); + claims.putAll(requireNonNull(additionalClaims, "additionalClaims is null")); + this.accessTokenValidity = requireNonNull(accessTokenValidity, "accessTokenValidity is null"); accessToken = issueToken(claims); if (issueIdToken) { nonce = Optional.of(randomNonce()); @@ -1127,7 +1261,7 @@ public Response getOAuth2Response(String code, URI callbackUri, Optional if (!"TEST_CODE".equals(code)) { throw new IllegalArgumentException("Expected TEST_CODE"); } - return new Response(accessToken, Instant.now().plusSeconds(5), idToken, Optional.empty()); + return new Response(accessToken, Instant.now().plus(accessTokenValidity), idToken, Optional.of(REFRESH_TOKEN)); } @Override @@ -1140,7 +1274,10 @@ public Optional> getClaims(String accessToken) public Response refreshTokens(String refreshToken) throws ChallengeFailedException { - throw new UnsupportedOperationException("Refresh tokens are not supported"); + if (refreshToken.equals(REFRESH_TOKEN)) { + return new Response(issueToken(claims), Instant.now().plusSeconds(30), idToken, Optional.of(REFRESH_TOKEN)); + } + throw new ChallengeFailedException("invalid refresh token"); } public String getAccessToken() From c056bc7cd00d641ce4c4be6632a63508330394c7 Mon Sep 17 00:00:00 2001 From: Raunaq Morarka Date: Sat, 10 Dec 2022 17:56:14 +0530 Subject: [PATCH 09/24] Fix recording of projection metrics Actual work is done in `pageProjectWork.process()` call while `projection.project` only performs setup of projection. So both `expressionProfiler` and `metrics.recordProjectionTime` needed to be around that method. --- .../io/trino/operator/project/PageProcessor.java | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/core/trino-main/src/main/java/io/trino/operator/project/PageProcessor.java b/core/trino-main/src/main/java/io/trino/operator/project/PageProcessor.java index a08faec243ae..d21267eec5cc 100644 --- a/core/trino-main/src/main/java/io/trino/operator/project/PageProcessor.java +++ b/core/trino-main/src/main/java/io/trino/operator/project/PageProcessor.java @@ -342,13 +342,15 @@ private ProcessBatchResult processBatch(int batchSize) } else { if (pageProjectWork == null) { - Page inputPage = projection.getInputChannels().getInputChannels(page); - expressionProfiler.start(); - pageProjectWork = projection.project(session, yieldSignal, inputPage, positionsBatch); - long projectionTimeNanos = expressionProfiler.stop(positionsBatch.size()); - metrics.recordProjectionTime(projectionTimeNanos); + pageProjectWork = projection.project(session, yieldSignal, projection.getInputChannels().getInputChannels(page), positionsBatch); } - if (!pageProjectWork.process()) { + + expressionProfiler.start(); + boolean finished = pageProjectWork.process(); + long projectionTimeNanos = expressionProfiler.stop(positionsBatch.size()); + metrics.recordProjectionTime(projectionTimeNanos); + + if (!finished) { return ProcessBatchResult.processBatchYield(); } previouslyComputedResults[i] = pageProjectWork.getResult(); From a46a5100c8e4cee2853019243e9c2213d87ec3c9 Mon Sep 17 00:00:00 2001 From: leetcode-1533 <1275963@gmail.com> Date: Fri, 2 Dec 2022 00:55:10 -0800 Subject: [PATCH 10/24] Fix dereference operations for union type in Hive Connector --- .../java/io/trino/plugin/hive/HiveType.java | 72 ++++++++++--- .../plugin/hive/util/HiveTypeTranslator.java | 8 +- .../tests/product/hive/TestReadUniontype.java | 101 ++++++++++++++++++ 3 files changed, 164 insertions(+), 17 deletions(-) diff --git a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/HiveType.java b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/HiveType.java index 9230d9e26013..45afb7bb299a 100644 --- a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/HiveType.java +++ b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/HiveType.java @@ -32,10 +32,14 @@ import java.util.Optional; import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.base.Strings.lenientFormat; import static com.google.common.collect.ImmutableList.toImmutableList; import static io.trino.plugin.hive.HiveStorageFormat.AVRO; import static io.trino.plugin.hive.HiveStorageFormat.ORC; import static io.trino.plugin.hive.HiveTimestampPrecision.DEFAULT_PRECISION; +import static io.trino.plugin.hive.util.HiveTypeTranslator.UNION_FIELD_FIELD_PREFIX; +import static io.trino.plugin.hive.util.HiveTypeTranslator.UNION_FIELD_TAG_NAME; +import static io.trino.plugin.hive.util.HiveTypeTranslator.UNION_FIELD_TAG_TYPE; import static io.trino.plugin.hive.util.HiveTypeTranslator.fromPrimitiveType; import static io.trino.plugin.hive.util.HiveTypeTranslator.toTypeInfo; import static io.trino.plugin.hive.util.HiveTypeTranslator.toTypeSignature; @@ -219,13 +223,32 @@ public Optional getHiveTypeForDereferences(List dereferences) { TypeInfo typeInfo = getTypeInfo(); for (int fieldIndex : dereferences) { - checkArgument(typeInfo instanceof StructTypeInfo, "typeInfo should be struct type", typeInfo); - StructTypeInfo structTypeInfo = (StructTypeInfo) typeInfo; - try { - typeInfo = structTypeInfo.getAllStructFieldTypeInfos().get(fieldIndex); + if (typeInfo instanceof StructTypeInfo structTypeInfo) { + try { + typeInfo = structTypeInfo.getAllStructFieldTypeInfos().get(fieldIndex); + } + catch (RuntimeException e) { + // return empty when failed to dereference, this could happen when partition and table schema mismatch + return Optional.empty(); + } } - catch (RuntimeException e) { - return Optional.empty(); + else if (typeInfo instanceof UnionTypeInfo unionTypeInfo) { + try { + if (fieldIndex == 0) { + // union's tag field, defined in {@link io.trino.plugin.hive.util.HiveTypeTranslator#toTypeSignature} + return Optional.of(HiveType.toHiveType(UNION_FIELD_TAG_TYPE)); + } + else { + typeInfo = unionTypeInfo.getAllUnionObjectTypeInfos().get(fieldIndex - 1); + } + } + catch (RuntimeException e) { + // return empty when failed to dereference, this could happen when partition and table schema mismatch + return Optional.empty(); + } + } + else { + throw new IllegalArgumentException(lenientFormat("typeInfo: %s should be struct or union type", typeInfo)); } } return Optional.of(toHiveType(typeInfo)); @@ -235,16 +258,35 @@ public List getHiveDereferenceNames(List dereferences) { ImmutableList.Builder dereferenceNames = ImmutableList.builder(); TypeInfo typeInfo = getTypeInfo(); - for (int fieldIndex : dereferences) { - checkArgument(typeInfo instanceof StructTypeInfo, "typeInfo should be struct type", typeInfo); - StructTypeInfo structTypeInfo = (StructTypeInfo) typeInfo; - + for (int i = 0; i < dereferences.size(); i++) { + int fieldIndex = dereferences.get(i); checkArgument(fieldIndex >= 0, "fieldIndex cannot be negative"); - checkArgument(fieldIndex < structTypeInfo.getAllStructFieldNames().size(), - "fieldIndex should be less than the number of fields in the struct"); - String fieldName = structTypeInfo.getAllStructFieldNames().get(fieldIndex); - dereferenceNames.add(fieldName); - typeInfo = structTypeInfo.getAllStructFieldTypeInfos().get(fieldIndex); + + if (typeInfo instanceof StructTypeInfo structTypeInfo) { + checkArgument(fieldIndex < structTypeInfo.getAllStructFieldNames().size(), + "fieldIndex should be less than the number of fields in the struct"); + + String fieldName = structTypeInfo.getAllStructFieldNames().get(fieldIndex); + dereferenceNames.add(fieldName); + typeInfo = structTypeInfo.getAllStructFieldTypeInfos().get(fieldIndex); + } + else if (typeInfo instanceof UnionTypeInfo unionTypeInfo) { + checkArgument((fieldIndex - 1) < unionTypeInfo.getAllUnionObjectTypeInfos().size(), + "fieldIndex should be less than the number of fields in the union plus tag field"); + + if (fieldIndex == 0) { + checkArgument(i == (dereferences.size() - 1), "Union's tag field should not have more subfields"); + dereferenceNames.add(UNION_FIELD_TAG_NAME); + break; + } + else { + typeInfo = unionTypeInfo.getAllUnionObjectTypeInfos().get(fieldIndex - 1); + dereferenceNames.add(UNION_FIELD_FIELD_PREFIX + (fieldIndex - 1)); + } + } + else { + throw new IllegalArgumentException(lenientFormat("typeInfo: %s should be struct or union type", typeInfo)); + } } return dereferenceNames.build(); diff --git a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/util/HiveTypeTranslator.java b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/util/HiveTypeTranslator.java index 22d1e5f4ce2d..4d165511130a 100644 --- a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/util/HiveTypeTranslator.java +++ b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/util/HiveTypeTranslator.java @@ -91,6 +91,10 @@ public final class HiveTypeTranslator { private HiveTypeTranslator() {} + public static final String UNION_FIELD_TAG_NAME = "tag"; + public static final String UNION_FIELD_FIELD_PREFIX = "field"; + public static final Type UNION_FIELD_TAG_TYPE = TINYINT; + public static TypeInfo toTypeInfo(Type type) { requireNonNull(type, "type is null"); @@ -213,10 +217,10 @@ public static TypeSignature toTypeSignature(TypeInfo typeInfo, HiveTimestampPrec UnionTypeInfo unionTypeInfo = (UnionTypeInfo) typeInfo; List unionObjectTypes = unionTypeInfo.getAllUnionObjectTypeInfos(); ImmutableList.Builder typeSignatures = ImmutableList.builder(); - typeSignatures.add(namedField("tag", TINYINT.getTypeSignature())); + typeSignatures.add(namedField(UNION_FIELD_TAG_NAME, UNION_FIELD_TAG_TYPE.getTypeSignature())); for (int i = 0; i < unionObjectTypes.size(); i++) { TypeInfo unionObjectType = unionObjectTypes.get(i); - typeSignatures.add(namedField("field" + i, toTypeSignature(unionObjectType, timestampPrecision))); + typeSignatures.add(namedField(UNION_FIELD_FIELD_PREFIX + i, toTypeSignature(unionObjectType, timestampPrecision))); } return rowType(typeSignatures.build()); } diff --git a/testing/trino-product-tests/src/main/java/io/trino/tests/product/hive/TestReadUniontype.java b/testing/trino-product-tests/src/main/java/io/trino/tests/product/hive/TestReadUniontype.java index a7ec94a51eb3..062ae91d9c3b 100644 --- a/testing/trino-product-tests/src/main/java/io/trino/tests/product/hive/TestReadUniontype.java +++ b/testing/trino-product-tests/src/main/java/io/trino/tests/product/hive/TestReadUniontype.java @@ -24,6 +24,7 @@ import java.util.Arrays; import java.util.List; +import static io.trino.testing.TestingNames.randomNameSuffix; import static io.trino.tests.product.TestGroups.SMOKE; import static io.trino.tests.product.utils.QueryExecutors.onHive; import static io.trino.tests.product.utils.QueryExecutors.onTrino; @@ -51,6 +52,87 @@ public static Object[][] storageFormats() return new String[][] {{"ORC"}, {"AVRO"}}; } + @DataProvider(name = "union_dereference_test_cases") + public static Object[][] unionDereferenceTestCases() + { + String tableUnionDereference = "test_union_dereference" + randomNameSuffix(); + // Hive insertion for union type in AVRO format has bugs, so we test on different table schemas for AVRO than ORC. + return new Object[][] {{ + format( + "CREATE TABLE %s (unionLevel0 UNIONTYPE<" + + "INT, STRING>)" + + "STORED AS %s", + tableUnionDereference, + "AVRO"), + format( + "INSERT INTO TABLE %s " + + "SELECT create_union(0, 321, 'row1') " + + "UNION ALL " + + "SELECT create_union(1, 55, 'row2') ", + tableUnionDereference), + format("SELECT unionLevel0.field0 FROM %s WHERE unionLevel0.field0 IS NOT NULL", tableUnionDereference), + Arrays.asList(321), + format("SELECT unionLevel0.tag FROM %s", tableUnionDereference), + Arrays.asList((byte) 0, (byte) 1), + "DROP TABLE IF EXISTS " + tableUnionDereference}, + // there is an internal issue in Hive 1.2: + // unionLevel1 is declared as unionType, but has to be inserted by create_union(tagId, Int, String) + { + format( + "CREATE TABLE %s (unionLevel0 UNIONTYPE>>, intLevel0 INT )" + + "STORED AS %s", + tableUnionDereference, + "AVRO"), + format( + "INSERT INTO TABLE %s " + + "SELECT create_union(2, 321, 'row1', named_struct('intLevel1', 1, 'stringLevel1', 'structval', 'unionLevel1', create_union(0, 5, 'testString'))), 8 " + + "UNION ALL " + + "SELECT create_union(2, 321, 'row1', named_struct('intLevel1', 1, 'stringLevel1', 'structval', 'unionLevel1', create_union(1, 5, 'testString'))), 8 ", + tableUnionDereference), + format("SELECT unionLevel0.field2.unionLevel1.field1 FROM %s WHERE unionLevel0.field2.unionLevel1.field1 IS NOT NULL", tableUnionDereference), + Arrays.asList(5), + format("SELECT unionLevel0.field2.unionLevel1.tag FROM %s", tableUnionDereference), + Arrays.asList((byte) 0, (byte) 1), + "DROP TABLE IF EXISTS " + tableUnionDereference}, + { + format( + "CREATE TABLE %s (unionLevel0 UNIONTYPE<" + + "STRUCT>>)" + + "STORED AS %s", + tableUnionDereference, + "ORC"), + format( + "INSERT INTO TABLE %s " + + "SELECT create_union(0, named_struct('unionLevel1', create_union(0, 'testString1', 23))) " + + "UNION ALL " + + "SELECT create_union(0, named_struct('unionLevel1', create_union(1, 'testString2', 45))) ", + tableUnionDereference), + format("SELECT unionLevel0.field0.unionLevel1.field0 FROM %s WHERE unionLevel0.field0.unionLevel1.field0 IS NOT NULL", tableUnionDereference), + Arrays.asList("testString1"), + format("SELECT unionLevel0.field0.unionLevel1.tag FROM %s", tableUnionDereference), + Arrays.asList((byte) 0, (byte) 1), + "DROP TABLE IF EXISTS " + tableUnionDereference}, + { + format( + "CREATE TABLE %s (unionLevel0 UNIONTYPE>>, intLevel0 INT )" + + "STORED AS %s", + tableUnionDereference, + "ORC"), + format( + "INSERT INTO TABLE %s " + + "SELECT create_union(2, 321, 'row1', named_struct('intLevel1', 1, 'stringLevel1', 'structval', 'unionLevel1', create_union(0, 'testString', 5))), 8 " + + "UNION ALL " + + "SELECT create_union(2, 321, 'row1', named_struct('intLevel1', 1, 'stringLevel1', 'structval', 'unionLevel1', create_union(1, 'testString', 5))), 8 ", + tableUnionDereference), + format("SELECT unionLevel0.field2.unionLevel1.field0 FROM %s WHERE unionLevel0.field2.unionLevel1.field0 IS NOT NULL", tableUnionDereference), + Arrays.asList("testString"), + format("SELECT unionLevel0.field2.unionLevel1.tag FROM %s", tableUnionDereference), + Arrays.asList((byte) 0, (byte) 1), + "DROP TABLE IF EXISTS " + tableUnionDereference}}; + } + @Test(dataProvider = "storage_formats", groups = SMOKE) public void testReadUniontype(String storageFormat) { @@ -137,6 +219,25 @@ public void testReadUniontype(String storageFormat) } } + @Test(dataProvider = "union_dereference_test_cases", groups = SMOKE) + public void testReadUniontypeWithDereference(String createTableSql, String insertSql, String selectSql, List expectedResult, String selectTagSql, List expectedTagResult, String dropTableSql) + { + // According to testing results, the Hive INSERT queries here only work in Hive 1.2 + if (getHiveVersionMajor() != 1 || getHiveVersionMinor() != 2) { + throw new SkipException("This test can only be run with Hive 1.2 (default config)"); + } + + onHive().executeQuery(createTableSql); + onHive().executeQuery(insertSql); + + QueryResult result = onTrino().executeQuery(selectSql); + assertThat(result.column(1)).containsExactlyInAnyOrderElementsOf(expectedResult); + result = onTrino().executeQuery(selectTagSql); + assertThat(result.column(1)).containsExactlyInAnyOrderElementsOf(expectedTagResult); + + onTrino().executeQuery(dropTableSql); + } + @Test(dataProvider = "storage_formats", groups = SMOKE) public void testUnionTypeSchemaEvolution(String storageFormat) { From 0459735c25df57335d221056f2f084b580cb2c23 Mon Sep 17 00:00:00 2001 From: James Petty Date: Fri, 9 Dec 2022 14:08:34 -0500 Subject: [PATCH 11/24] Cleanup PartitioningExchanger Removes outdated comments and unnecessary methods in local exchange PartitioningExchanger since the operator is no longer implemented in a way that attempts to be thread-safe. --- .../exchange/PartitioningExchanger.java | 37 ++++++------------- 1 file changed, 11 insertions(+), 26 deletions(-) diff --git a/core/trino-main/src/main/java/io/trino/operator/exchange/PartitioningExchanger.java b/core/trino-main/src/main/java/io/trino/operator/exchange/PartitioningExchanger.java index 62cee860defe..808d875ea431 100644 --- a/core/trino-main/src/main/java/io/trino/operator/exchange/PartitioningExchanger.java +++ b/core/trino-main/src/main/java/io/trino/operator/exchange/PartitioningExchanger.java @@ -19,7 +19,6 @@ import io.trino.spi.Page; import it.unimi.dsi.fastutil.ints.IntArrayList; -import javax.annotation.Nullable; import javax.annotation.concurrent.NotThreadSafe; import java.util.List; @@ -58,18 +57,7 @@ public PartitioningExchanger( @Override public void accept(Page page) { - Consumer wholePagePartition = partitionPageOrFindWholePagePartition(page, partitionedPagePreparer.apply(page)); - if (wholePagePartition != null) { - // whole input page will go to this partition, compact the input page avoid over-retaining memory and to - // match the behavior of sub-partitioned pages that copy positions out - page.compact(); - sendPageToPartition(wholePagePartition, page); - } - } - - @Nullable - private Consumer partitionPageOrFindWholePagePartition(Page page, Page partitionPage) - { + Page partitionPage = partitionedPagePreparer.apply(page); // assign each row to a partition. The assignments lists are all expected to cleared by the previous iterations for (int position = 0; position < partitionPage.getPositionCount(); position++) { int partition = partitionFunction.getPartition(partitionPage, position); @@ -89,22 +77,19 @@ private Consumer partitionPageOrFindWholePagePartition(Page page, Page par int[] positions = positionsList.elements(); positionsList.clear(); + Page pageSplit; if (partitionSize == page.getPositionCount()) { - // entire page will be sent to this partition, compact and send the page after releasing the lock - return buffers.get(partition); + // whole input page will go to this partition, compact the input page avoid over-retaining memory and to + // match the behavior of sub-partitioned pages that copy positions out + page.compact(); + pageSplit = page; } - Page pageSplit = page.copyPositions(positions, 0, partitionSize); - sendPageToPartition(buffers.get(partition), pageSplit); + else { + pageSplit = page.copyPositions(positions, 0, partitionSize); + } + memoryManager.updateMemoryUsage(pageSplit.getRetainedSizeInBytes()); + buffers.get(partition).accept(pageSplit); } - // No single partition receives the entire input page - return null; - } - - // This is safe to call without synchronizing because the partition buffers are themselves threadsafe - private void sendPageToPartition(Consumer buffer, Page pageSplit) - { - memoryManager.updateMemoryUsage(pageSplit.getRetainedSizeInBytes()); - buffer.accept(pageSplit); } @Override From 1788829b30f141856979a27982a13fb13dddf19e Mon Sep 17 00:00:00 2001 From: Vikash Kumar Date: Mon, 12 Dec 2022 17:38:12 +0530 Subject: [PATCH 12/24] Fix table name in TestDropTableTask --- .../src/test/java/io/trino/execution/TestDropTableTask.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/trino-main/src/test/java/io/trino/execution/TestDropTableTask.java b/core/trino-main/src/test/java/io/trino/execution/TestDropTableTask.java index 8d22ffd8d491..1a13e6add44c 100644 --- a/core/trino-main/src/test/java/io/trino/execution/TestDropTableTask.java +++ b/core/trino-main/src/test/java/io/trino/execution/TestDropTableTask.java @@ -36,7 +36,7 @@ public class TestDropTableTask @Test public void testDropExistingTable() { - QualifiedObjectName tableName = qualifiedObjectName("not_existing_table"); + QualifiedObjectName tableName = qualifiedObjectName("existing_table"); metadata.createTable(testSession, TEST_CATALOG_NAME, someTable(tableName), false); assertThat(metadata.getTableHandle(testSession, tableName)).isPresent(); From 3c1c1fa10234ed1f5e1496c094d2b16c6eacaa8d Mon Sep 17 00:00:00 2001 From: rigogsilva Date: Mon, 12 Dec 2022 10:08:41 -0600 Subject: [PATCH 13/24] Document examples for datetime functions --- docs/src/main/sphinx/functions/datetime.rst | 26 +++++++++++++++++---- 1 file changed, 22 insertions(+), 4 deletions(-) diff --git a/docs/src/main/sphinx/functions/datetime.rst b/docs/src/main/sphinx/functions/datetime.rst index 27eb678a83fc..7f16a2465141 100644 --- a/docs/src/main/sphinx/functions/datetime.rst +++ b/docs/src/main/sphinx/functions/datetime.rst @@ -228,7 +228,16 @@ The above examples use the timestamp ``2001-08-22 03:04:05.321`` as the input. .. function:: date_trunc(unit, x) -> [same as input] - Returns ``x`` truncated to ``unit``. + Returns ``x`` truncated to ``unit``:: + + SELECT date_trunc('day' , TIMESTAMP '2022-10-20 05:10:00'); + -- 2022-10-20 00:00:00.000 + + SELECT date_trunc('month' , TIMESTAMP '2022-10-20 05:10:00'); + -- 2022-10-01 00:00:00.000 + + SELECT date_trunc('year', TIMESTAMP '2022-10-20 05:10:00'); + -- 2022-01-01 00:00:00.000 .. _datetime-interval-functions: @@ -383,11 +392,17 @@ Specifier Description .. function:: date_format(timestamp, format) -> varchar - Formats ``timestamp`` as a string using ``format``. + Formats ``timestamp`` as a string using ``format``:: + + SELECT date_format(TIMESTAMP '2022-10-20 05:10:00', '%m-%d-%Y %H'); + -- 10-20-2022 05 .. function:: date_parse(string, format) -> timestamp(3) - Parses ``string`` into a timestamp using ``format``. + Parses ``string`` into a timestamp using ``format``:: + + SELECT date_parse('2022/10/20/05', '%Y/%m/%d/%H'); + -- 2022-10-20 05:00:00.000 Java date functions ------------------- @@ -437,7 +452,10 @@ field to be extracted. Most fields support all date and time types. .. function:: extract(field FROM x) -> bigint - Returns ``field`` from ``x``. + Returns ``field`` from ``x``:: + + SELECT extract(YEAR FROM TIMESTAMP '2022-10-20 05:10:00'); + -- 2022 .. note:: This SQL-standard function uses special syntax for specifying the arguments. From 58bd1598c3694bdb5dc490188288992215605e9b Mon Sep 17 00:00:00 2001 From: Jonas Irgens Kylling Date: Fri, 25 Nov 2022 20:50:34 +0100 Subject: [PATCH 14/24] Decode path as URI in Delta Lake connector --- .../plugin/deltalake/DeltaLakeMetadata.java | 2 +- .../deltalake/DeltaLakeSplitManager.java | 8 +- .../BaseDeltaLakeConnectorSmokeTest.java | 13 +++ .../TestDeltaLakeAdlsConnectorSmokeTest.java | 9 ++ .../test/resources/databricks/uri/README.md | 12 +++ .../uri/_delta_log/00000000000000000000.json | 3 + .../uri/_delta_log/00000000000000000001.json | 2 + .../uri/_delta_log/00000000000000000002.json | 2 + .../uri/_delta_log/00000000000000000003.json | 2 + ...4de3-916a-ee0d89139446.c000.snappy.parquet | Bin 0 -> 475 bytes ...436d-bb00-245452904a81.c000.snappy.parquet | Bin 0 -> 475 bytes ...4180-91f1-4fd762dbc279.c000.snappy.parquet | Bin 0 -> 475 bytes .../TestDeltaLakeSelectCompatibility.java | 84 ++++++++++++++++++ 13 files changed, 132 insertions(+), 5 deletions(-) create mode 100644 plugin/trino-delta-lake/src/test/resources/databricks/uri/README.md create mode 100644 plugin/trino-delta-lake/src/test/resources/databricks/uri/_delta_log/00000000000000000000.json create mode 100644 plugin/trino-delta-lake/src/test/resources/databricks/uri/_delta_log/00000000000000000001.json create mode 100644 plugin/trino-delta-lake/src/test/resources/databricks/uri/_delta_log/00000000000000000002.json create mode 100644 plugin/trino-delta-lake/src/test/resources/databricks/uri/_delta_log/00000000000000000003.json create mode 100644 plugin/trino-delta-lake/src/test/resources/databricks/uri/part=a%3Acolon/part-00000-ab613d48-052b-4de3-916a-ee0d89139446.c000.snappy.parquet create mode 100644 plugin/trino-delta-lake/src/test/resources/databricks/uri/part=a%3Dequal/part-00000-d6a0dd7d-8416-436d-bb00-245452904a81.c000.snappy.parquet create mode 100644 plugin/trino-delta-lake/src/test/resources/databricks/uri/part=a+plus/part-00000-4aecc0d9-dfbf-4180-91f1-4fd762dbc279.c000.snappy.parquet create mode 100644 testing/trino-product-tests/src/main/java/io/trino/tests/product/deltalake/TestDeltaLakeSelectCompatibility.java diff --git a/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/DeltaLakeMetadata.java b/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/DeltaLakeMetadata.java index 081651d19be2..e8aec83e3400 100644 --- a/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/DeltaLakeMetadata.java +++ b/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/DeltaLakeMetadata.java @@ -1255,7 +1255,7 @@ private static void appendAddFileEntries(TransactionLogWriter transactionLogWrit transactionLogWriter.appendAddFileEntry( new AddFileEntry( - toUriFormat(info.getPath()), // Databricks and OSS Delta Lake both expect path to be url-encoded, even though the procotol specification doesn't mention that + toUriFormat(info.getPath()), // Paths are RFC 2396 URI encoded https://github.com/delta-io/delta/blob/master/PROTOCOL.md#add-file-and-remove-file partitionValues, info.getSize(), info.getCreationTime(), diff --git a/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/DeltaLakeSplitManager.java b/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/DeltaLakeSplitManager.java index 7ceac4d7909f..ea296a05a8f2 100644 --- a/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/DeltaLakeSplitManager.java +++ b/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/DeltaLakeSplitManager.java @@ -36,7 +36,7 @@ import javax.inject.Inject; -import java.net.URLDecoder; +import java.net.URI; import java.time.Instant; import java.util.List; import java.util.Map; @@ -58,7 +58,6 @@ import static io.trino.plugin.deltalake.DeltaLakeSessionProperties.getMaxSplitSize; import static io.trino.plugin.deltalake.transactionlog.DeltaLakeSchemaSupport.extractSchema; import static io.trino.plugin.deltalake.transactionlog.TransactionLogParser.deserializePartitionValue; -import static java.nio.charset.StandardCharsets.UTF_8; import static java.util.Locale.ENGLISH; import static java.util.Objects.requireNonNull; import static java.util.function.Function.identity; @@ -290,8 +289,9 @@ private List splitsForFile( private static String buildSplitPath(String tableLocation, AddFileEntry addAction) { - // paths are relative to the table location and URL encoded - String path = URLDecoder.decode(addAction.getPath(), UTF_8); + // paths are relative to the table location and are RFC 2396 URIs + // https://github.com/delta-io/delta/blob/master/PROTOCOL.md#add-file-and-remove-file + String path = URI.create(addAction.getPath()).getPath(); if (tableLocation.endsWith("/")) { return tableLocation + path; } diff --git a/plugin/trino-delta-lake/src/test/java/io/trino/plugin/deltalake/BaseDeltaLakeConnectorSmokeTest.java b/plugin/trino-delta-lake/src/test/java/io/trino/plugin/deltalake/BaseDeltaLakeConnectorSmokeTest.java index 94290ae288cd..a7175dce7cf7 100644 --- a/plugin/trino-delta-lake/src/test/java/io/trino/plugin/deltalake/BaseDeltaLakeConnectorSmokeTest.java +++ b/plugin/trino-delta-lake/src/test/java/io/trino/plugin/deltalake/BaseDeltaLakeConnectorSmokeTest.java @@ -282,6 +282,19 @@ public void testCreatePartitionedTable() assertUpdate("DROP TABLE " + tableName); } + @Test + public void testPathUriDecoding() + { + String tableName = "test_uri_table_" + randomNameSuffix(); + registerTableFromResources(tableName, "databricks/uri", getQueryRunner()); + + assertQuery("SELECT * FROM " + tableName, "VALUES ('a=equal', 1), ('a:colon', 2), ('a+plus', 3)"); + String firstFilePath = (String) computeScalar("SELECT \"$path\" FROM " + tableName + " WHERE y = 1"); + assertQuery("SELECT * FROM " + tableName + " WHERE \"$path\" = '" + firstFilePath + "'", "VALUES ('a=equal', 1)"); + + assertUpdate("DROP TABLE " + tableName); + } + @Test public void testCreateTablePartitionValidation() { diff --git a/plugin/trino-delta-lake/src/test/java/io/trino/plugin/deltalake/TestDeltaLakeAdlsConnectorSmokeTest.java b/plugin/trino-delta-lake/src/test/java/io/trino/plugin/deltalake/TestDeltaLakeAdlsConnectorSmokeTest.java index 2fe5b4aef91e..82358243fe27 100644 --- a/plugin/trino-delta-lake/src/test/java/io/trino/plugin/deltalake/TestDeltaLakeAdlsConnectorSmokeTest.java +++ b/plugin/trino-delta-lake/src/test/java/io/trino/plugin/deltalake/TestDeltaLakeAdlsConnectorSmokeTest.java @@ -50,6 +50,7 @@ import static java.util.Objects.requireNonNull; import static java.util.regex.Matcher.quoteReplacement; import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; public class TestDeltaLakeAdlsConnectorSmokeTest extends BaseDeltaLakeConnectorSmokeTest @@ -113,6 +114,14 @@ public void removeTestData() assertThat(azureContainerClient.listBlobsByHierarchy(bucketName + "/").stream()).hasSize(0); } + @Override + public void testPathUriDecoding() + { + // TODO https://github.com/trinodb/trino/issues/15376 AzureBlobFileSystem doesn't expect URI as the path argument + assertThatThrownBy(super::testPathUriDecoding) + .hasStackTraceContaining("The specified path does not exist"); + } + @Override protected void registerTableFromResources(String table, String resourcePath, QueryRunner queryRunner) { diff --git a/plugin/trino-delta-lake/src/test/resources/databricks/uri/README.md b/plugin/trino-delta-lake/src/test/resources/databricks/uri/README.md new file mode 100644 index 000000000000..52a0fe3e758f --- /dev/null +++ b/plugin/trino-delta-lake/src/test/resources/databricks/uri/README.md @@ -0,0 +1,12 @@ +Data generated using OSS DELTA 2.0.0: + +```sql +CREATE TABLE default.uri_test (part string, y long) +USING delta +PARTITIONED BY (part) +LOCATION '/home/username/trino/plugin/trino-delta-lake/src/test/resources/databricks/uri'; + +INSERT INTO default.uri_test VALUES ('a=equal', 1); +INSERT INTO default.uri_test VALUES ('a:colon', 2); +INSERT INTO default.uri_test VALUES ('a+plus', 3); +``` diff --git a/plugin/trino-delta-lake/src/test/resources/databricks/uri/_delta_log/00000000000000000000.json b/plugin/trino-delta-lake/src/test/resources/databricks/uri/_delta_log/00000000000000000000.json new file mode 100644 index 000000000000..6ce3c0ab0934 --- /dev/null +++ b/plugin/trino-delta-lake/src/test/resources/databricks/uri/_delta_log/00000000000000000000.json @@ -0,0 +1,3 @@ +{"protocol":{"minReaderVersion":1,"minWriterVersion":2}} +{"metaData":{"id":"ae596d2a-868c-4480-9dba-ecb1366eef15","format":{"provider":"parquet","options":{}},"schemaString":"{\"type\":\"struct\",\"fields\":[{\"name\":\"part\",\"type\":\"string\",\"nullable\":true,\"metadata\":{}},{\"name\":\"y\",\"type\":\"long\",\"nullable\":true,\"metadata\":{}}]}","partitionColumns":["part"],"configuration":{},"createdTime":1670672197310}} +{"commitInfo":{"timestamp":1670672197479,"operation":"CREATE TABLE","operationParameters":{"isManaged":"false","description":null,"partitionBy":"[\"part\"]","properties":"{}"},"isolationLevel":"Serializable","isBlindAppend":true,"operationMetrics":{},"engineInfo":"Apache-Spark/3.2.2 Delta-Lake/2.0.0","txnId":"ec6fd978-2b74-4482-a908-02528f96cff5"}} diff --git a/plugin/trino-delta-lake/src/test/resources/databricks/uri/_delta_log/00000000000000000001.json b/plugin/trino-delta-lake/src/test/resources/databricks/uri/_delta_log/00000000000000000001.json new file mode 100644 index 000000000000..2a4c5a4b763f --- /dev/null +++ b/plugin/trino-delta-lake/src/test/resources/databricks/uri/_delta_log/00000000000000000001.json @@ -0,0 +1,2 @@ +{"add":{"path":"part=a%253Dequal/part-00000-d6a0dd7d-8416-436d-bb00-245452904a81.c000.snappy.parquet","partitionValues":{"part":"a=equal"},"size":475,"modificationTime":1670672201990,"dataChange":true,"stats":"{\"numRecords\":1,\"minValues\":{\"y\":1},\"maxValues\":{\"y\":1},\"nullCount\":{\"y\":0}}"}} +{"commitInfo":{"timestamp":1670672202024,"operation":"WRITE","operationParameters":{"mode":"Append","partitionBy":"[]"},"readVersion":0,"isolationLevel":"Serializable","isBlindAppend":true,"operationMetrics":{"numFiles":"1","numOutputRows":"1","numOutputBytes":"475"},"engineInfo":"Apache-Spark/3.2.2 Delta-Lake/2.0.0","txnId":"bd99e504-0c28-4661-af9c-854016b09a7e"}} diff --git a/plugin/trino-delta-lake/src/test/resources/databricks/uri/_delta_log/00000000000000000002.json b/plugin/trino-delta-lake/src/test/resources/databricks/uri/_delta_log/00000000000000000002.json new file mode 100644 index 000000000000..f939248ad264 --- /dev/null +++ b/plugin/trino-delta-lake/src/test/resources/databricks/uri/_delta_log/00000000000000000002.json @@ -0,0 +1,2 @@ +{"add":{"path":"part=a%253Acolon/part-00000-ab613d48-052b-4de3-916a-ee0d89139446.c000.snappy.parquet","partitionValues":{"part":"a:colon"},"size":475,"modificationTime":1670672202850,"dataChange":true,"stats":"{\"numRecords\":1,\"minValues\":{\"y\":2},\"maxValues\":{\"y\":2},\"nullCount\":{\"y\":0}}"}} +{"commitInfo":{"timestamp":1670672202858,"operation":"WRITE","operationParameters":{"mode":"Append","partitionBy":"[]"},"readVersion":1,"isolationLevel":"Serializable","isBlindAppend":true,"operationMetrics":{"numFiles":"1","numOutputRows":"1","numOutputBytes":"475"},"engineInfo":"Apache-Spark/3.2.2 Delta-Lake/2.0.0","txnId":"fd9e716f-3072-4ecf-ac2b-e52d3a66e3b8"}} diff --git a/plugin/trino-delta-lake/src/test/resources/databricks/uri/_delta_log/00000000000000000003.json b/plugin/trino-delta-lake/src/test/resources/databricks/uri/_delta_log/00000000000000000003.json new file mode 100644 index 000000000000..e0b9d17bd36e --- /dev/null +++ b/plugin/trino-delta-lake/src/test/resources/databricks/uri/_delta_log/00000000000000000003.json @@ -0,0 +1,2 @@ +{"add":{"path":"part=a+plus/part-00000-4aecc0d9-dfbf-4180-91f1-4fd762dbc279.c000.snappy.parquet","partitionValues":{"part":"a+plus"},"size":475,"modificationTime":1670672203670,"dataChange":true,"stats":"{\"numRecords\":1,\"minValues\":{\"y\":3},\"maxValues\":{\"y\":3},\"nullCount\":{\"y\":0}}"}} +{"commitInfo":{"timestamp":1670672203680,"operation":"WRITE","operationParameters":{"mode":"Append","partitionBy":"[]"},"readVersion":2,"isolationLevel":"Serializable","isBlindAppend":true,"operationMetrics":{"numFiles":"1","numOutputRows":"1","numOutputBytes":"475"},"engineInfo":"Apache-Spark/3.2.2 Delta-Lake/2.0.0","txnId":"3adbdcc8-1a1d-4457-b168-0292d9d34dac"}} diff --git a/plugin/trino-delta-lake/src/test/resources/databricks/uri/part=a%3Acolon/part-00000-ab613d48-052b-4de3-916a-ee0d89139446.c000.snappy.parquet b/plugin/trino-delta-lake/src/test/resources/databricks/uri/part=a%3Acolon/part-00000-ab613d48-052b-4de3-916a-ee0d89139446.c000.snappy.parquet new file mode 100644 index 0000000000000000000000000000000000000000..88a74a3a51a244e6deb548e480d9a01fa543a08a GIT binary patch literal 475 zcmZXR-AV#M6vvN~%3eeeIKwXN#WIj6*x>H?O$6PA7a~F3M8vp`CatdSM<}siNn{*niS1Z%_0VsgKd^YLnK8r0QR1vC3Z(I0k@E{!3HJdshMUNO@&%2WDvhN;zdXCEL@lc zDLmnUS^d`$uL6POt5hc<3SY%sCu*z`W!`RhvR-Q5<8czDe!^>fGSk60=Dqt3U#R)6 zw8redD+m`;REsnS)F{9rPU#x|+sAvG?e<{{W`FN?K98I|p5qD^@tn#iodsj3-PW@1 qcj{d!J1y-uJzx40)$(Oy)YcOzJy&|3-)&Sa+s5aA%K;4WTYms`17yem literal 0 HcmV?d00001 diff --git a/plugin/trino-delta-lake/src/test/resources/databricks/uri/part=a%3Dequal/part-00000-d6a0dd7d-8416-436d-bb00-245452904a81.c000.snappy.parquet b/plugin/trino-delta-lake/src/test/resources/databricks/uri/part=a%3Dequal/part-00000-d6a0dd7d-8416-436d-bb00-245452904a81.c000.snappy.parquet new file mode 100644 index 0000000000000000000000000000000000000000..3c3dab40dd86907a18c7f09975092443cf40634f GIT binary patch literal 475 zcmZXR%}T>S5XU!bEdfOkx=R8%gau0racQ#sMg(utLn(q@L}Z(8YcU^fK9o|5FX6NJ z3?6+VaVoVrxQAu_Gqdymu`{~5aS0+8vB~%6*T<1XFvv2|5jtsNgwTPe!M#1^DK-s# zuZoahlS0{|IvGNMKs(pR=|U#YkE#EpLP;pHBzB>(j3@Cc4^<{7GLy#bnq{knVttloN$zLNVbi$^CMoORuURC+ zZD|GBT~`oCGf+_;1acfe5nd^W{mo-a-RUUN q@Vkwk5Z$)&Tb?icscieAIqs;b5S}YM&+j$sre(qRzrl(|u+|?ev}G&+ literal 0 HcmV?d00001 diff --git a/plugin/trino-delta-lake/src/test/resources/databricks/uri/part=a+plus/part-00000-4aecc0d9-dfbf-4180-91f1-4fd762dbc279.c000.snappy.parquet b/plugin/trino-delta-lake/src/test/resources/databricks/uri/part=a+plus/part-00000-4aecc0d9-dfbf-4180-91f1-4fd762dbc279.c000.snappy.parquet new file mode 100644 index 0000000000000000000000000000000000000000..2291dfae5a98673b36f891324764e4ef997d6dbe GIT binary patch literal 475 zcmZXRO-lkn7{{L-P4*B$;0(L4hh-p9ut9fxCxUL_K_sY4M2zcb((bCeYat?C`Urlt zzDvg*!`wr27@q(0_Wv<6yt%7WpbmBE=hycx=TL&OKuv(N1_6Mak;8j`>~msU=22xK zB{svd!%RBB08m$NPm_g;U!N2IMTd$Kp!o}88Eas~c5J6)Iy7%LO(@uyq=XUHXM>V? zm2J%;2IZq|mPJD(MKb`lGiZ@L&}+agq-3%|5qYYXHIJu4Efg|{UmWqK<776EtP?3b z;el28*O4p(f#l1TPDMO_7okqoSf$Fk-|%$3(7eZ!G+y~Buky)E2jhhI9ydHv;l8xS z?4c`|M^jYMDhSjlz#>lRD*xNZdz expectedRows = ImmutableList.of( + row(1, "spark=equal"), + row(2, "spark+plus"), + row(3, "spark space"), + row(10, "trino=equal"), + row(20, "trino+plus"), + row(30, "trino space")); + + assertThat(onDelta().executeQuery("SELECT * FROM default." + tableName)) + .containsOnly(expectedRows); + assertThat(onTrino().executeQuery("SELECT * FROM delta.default." + tableName)) + .containsOnly(expectedRows); + + String deltaFilePath = (String) onDelta().executeQuery("SELECT input_file_name() FROM default." + tableName + " WHERE a_number = 1").getOnlyValue(); + String trinoFilePath = (String) onTrino().executeQuery("SELECT \"$path\" FROM delta.default." + tableName + " WHERE a_number = 1").getOnlyValue(); + // File paths returned by the input_file_name function are URI encoded https://github.com/delta-io/delta/issues/1517 while the $path of Trino is not + assertNotEquals(deltaFilePath, trinoFilePath); + assertEquals(format("s3://%s%s", bucketName, URI.create(deltaFilePath).getPath()), trinoFilePath); + + assertThat(onTrino().executeQuery("SELECT * FROM delta.default." + tableName + " WHERE \"$path\" = '" + trinoFilePath + "'")) + .containsOnly(row(1, "spark=equal")); + } + finally { + onDelta().executeQuery("DROP TABLE default." + tableName); + } + } +} From 35da905ba489f618863a0213a08b1bfa378d9d05 Mon Sep 17 00:00:00 2001 From: Yuya Ebihara Date: Tue, 13 Dec 2022 11:37:50 +0900 Subject: [PATCH 15/24] Remove unused method from AstBuilder --- .../main/java/io/trino/sql/parser/AstBuilder.java | 14 -------------- 1 file changed, 14 deletions(-) diff --git a/core/trino-parser/src/main/java/io/trino/sql/parser/AstBuilder.java b/core/trino-parser/src/main/java/io/trino/sql/parser/AstBuilder.java index 0723b8fc4af7..35947ac4a3ea 100644 --- a/core/trino-parser/src/main/java/io/trino/sql/parser/AstBuilder.java +++ b/core/trino-parser/src/main/java/io/trino/sql/parser/AstBuilder.java @@ -3745,18 +3745,4 @@ private static QueryPeriod.RangeType getRangeType(Token token) } throw new IllegalArgumentException("Unsupported query period range type: " + token.getText()); } - - private static Trim.Specification toTrimSpecification(String functionName) - { - requireNonNull(functionName, "functionName is null"); - switch (functionName) { - case "trim": - return Trim.Specification.BOTH; - case "ltrim": - return Trim.Specification.LEADING; - case "rtrim": - return Trim.Specification.TRAILING; - } - throw new IllegalArgumentException("Unsupported trim specification: " + functionName); - } } From ce7b57f952ac560ac07192819fc71d7fdeaf4f12 Mon Sep 17 00:00:00 2001 From: Yuya Ebihara Date: Sat, 19 Nov 2022 06:53:59 +0900 Subject: [PATCH 16/24] Refactor BigQuery connector - Change ColumnHandle to BigQueryColumnHandle in BigQueryTableHandle - Extract buildColumnHandles in BigQueryClient --- .../io/trino/plugin/bigquery/BigQueryClient.java | 15 +++++++++------ .../trino/plugin/bigquery/BigQueryMetadata.java | 9 +++++---- .../io/trino/plugin/bigquery/BigQuerySplit.java | 13 ++++++------- .../plugin/bigquery/BigQuerySplitManager.java | 10 +++++----- .../plugin/bigquery/BigQueryTableHandle.java | 8 ++++---- .../java/io/trino/plugin/bigquery/ptf/Query.java | 4 +--- 6 files changed, 30 insertions(+), 29 deletions(-) diff --git a/plugin/trino-bigquery/src/main/java/io/trino/plugin/bigquery/BigQueryClient.java b/plugin/trino-bigquery/src/main/java/io/trino/plugin/bigquery/BigQueryClient.java index ef1dfef9faf3..d009b88f3505 100644 --- a/plugin/trino-bigquery/src/main/java/io/trino/plugin/bigquery/BigQueryClient.java +++ b/plugin/trino-bigquery/src/main/java/io/trino/plugin/bigquery/BigQueryClient.java @@ -40,6 +40,7 @@ import io.airlift.units.Duration; import io.trino.collect.cache.EvictableCacheBuilder; import io.trino.spi.TrinoException; +import io.trino.spi.connector.SchemaTableName; import io.trino.spi.connector.TableNotFoundException; import java.util.Collections; @@ -348,19 +349,21 @@ private static String fullTableName(TableId remoteTableId) public List getColumns(BigQueryTableHandle tableHandle) { if (tableHandle.getProjectedColumns().isPresent()) { - return tableHandle.getProjectedColumns().get().stream() - .map(column -> (BigQueryColumnHandle) column) - .collect(toImmutableList()); + return tableHandle.getProjectedColumns().get(); } checkArgument(tableHandle.isNamedRelation(), "Cannot get columns for %s", tableHandle); TableInfo tableInfo = getTable(tableHandle.asPlainTable().getRemoteTableName().toTableId()) .orElseThrow(() -> new TableNotFoundException(tableHandle.asPlainTable().getSchemaTableName())); + return buildColumnHandles(tableInfo); + } + + public static List buildColumnHandles(TableInfo tableInfo) + { Schema schema = tableInfo.getDefinition().getSchema(); if (schema == null) { - throw new TableNotFoundException( - tableHandle.asPlainTable().getSchemaTableName(), - format("Table '%s' has no schema", tableHandle.asPlainTable().getSchemaTableName())); + SchemaTableName schemaTableName = new SchemaTableName(tableInfo.getTableId().getDataset(), tableInfo.getTableId().getTable()); + throw new TableNotFoundException(schemaTableName, format("Table '%s' has no schema", schemaTableName)); } return schema.getFields() .stream() diff --git a/plugin/trino-bigquery/src/main/java/io/trino/plugin/bigquery/BigQueryMetadata.java b/plugin/trino-bigquery/src/main/java/io/trino/plugin/bigquery/BigQueryMetadata.java index 5708ad5e3632..0228112e25e0 100644 --- a/plugin/trino-bigquery/src/main/java/io/trino/plugin/bigquery/BigQueryMetadata.java +++ b/plugin/trino-bigquery/src/main/java/io/trino/plugin/bigquery/BigQueryMetadata.java @@ -326,7 +326,7 @@ public Map getColumnHandles(ConnectorSession session, Conn BigQueryTableHandle table = (BigQueryTableHandle) tableHandle; if (table.getProjectedColumns().isPresent()) { return table.getProjectedColumns().get().stream() - .collect(toImmutableMap(columnHandle -> ((BigQueryColumnHandle) columnHandle).getName(), identity())); + .collect(toImmutableMap(BigQueryColumnHandle::getName, identity())); } checkArgument(table.isNamedRelation(), "Cannot get columns for %s", tableHandle); @@ -567,11 +567,12 @@ public Optional> applyProjecti return Optional.empty(); } - ImmutableList.Builder projectedColumns = ImmutableList.builder(); + ImmutableList.Builder projectedColumns = ImmutableList.builder(); ImmutableList.Builder assignmentList = ImmutableList.builder(); assignments.forEach((name, column) -> { - projectedColumns.add(column); - assignmentList.add(new Assignment(name, column, ((BigQueryColumnHandle) column).getTrinoType())); + BigQueryColumnHandle columnHandle = (BigQueryColumnHandle) column; + projectedColumns.add(columnHandle); + assignmentList.add(new Assignment(name, column, columnHandle.getTrinoType())); }); bigQueryTableHandle = bigQueryTableHandle.withProjectedColumns(projectedColumns.build()); diff --git a/plugin/trino-bigquery/src/main/java/io/trino/plugin/bigquery/BigQuerySplit.java b/plugin/trino-bigquery/src/main/java/io/trino/plugin/bigquery/BigQuerySplit.java index 6ba9daddae17..7b53e6971ce2 100644 --- a/plugin/trino-bigquery/src/main/java/io/trino/plugin/bigquery/BigQuerySplit.java +++ b/plugin/trino-bigquery/src/main/java/io/trino/plugin/bigquery/BigQuerySplit.java @@ -17,7 +17,6 @@ import com.fasterxml.jackson.annotation.JsonProperty; import com.google.common.collect.ImmutableList; import io.trino.spi.HostAddress; -import io.trino.spi.connector.ColumnHandle; import io.trino.spi.connector.ConnectorSplit; import org.openjdk.jol.info.ClassLayout; @@ -43,7 +42,7 @@ public class BigQuerySplit private final Mode mode; private final String streamName; private final String avroSchema; - private final List columns; + private final List columns; private final long emptyRowsToGenerate; private final Optional filter; private final OptionalInt dataSize; @@ -54,7 +53,7 @@ public BigQuerySplit( @JsonProperty("mode") Mode mode, @JsonProperty("streamName") String streamName, @JsonProperty("avroSchema") String avroSchema, - @JsonProperty("columns") List columns, + @JsonProperty("columns") List columns, @JsonProperty("emptyRowsToGenerate") long emptyRowsToGenerate, @JsonProperty("filter") Optional filter, @JsonProperty("dataSize") OptionalInt dataSize) @@ -68,12 +67,12 @@ public BigQuerySplit( this.dataSize = requireNonNull(dataSize, "dataSize is null"); } - static BigQuerySplit forStream(String streamName, String avroSchema, List columns, OptionalInt dataSize) + static BigQuerySplit forStream(String streamName, String avroSchema, List columns, OptionalInt dataSize) { return new BigQuerySplit(STORAGE, streamName, avroSchema, columns, NO_ROWS_TO_GENERATE, Optional.empty(), dataSize); } - static BigQuerySplit forViewStream(List columns, Optional filter) + static BigQuerySplit forViewStream(List columns, Optional filter) { return new BigQuerySplit(QUERY, "", "", columns, NO_ROWS_TO_GENERATE, filter, OptionalInt.empty()); } @@ -102,7 +101,7 @@ public String getAvroSchema() } @JsonProperty - public List getColumns() + public List getColumns() { return columns; } @@ -149,7 +148,7 @@ public long getRetainedSizeInBytes() return INSTANCE_SIZE + estimatedSizeOf(streamName) + estimatedSizeOf(avroSchema) - + estimatedSizeOf(columns, column -> ((BigQueryColumnHandle) column).getRetainedSizeInBytes()); + + estimatedSizeOf(columns, BigQueryColumnHandle::getRetainedSizeInBytes); } @Override diff --git a/plugin/trino-bigquery/src/main/java/io/trino/plugin/bigquery/BigQuerySplitManager.java b/plugin/trino-bigquery/src/main/java/io/trino/plugin/bigquery/BigQuerySplitManager.java index dbd10de79003..f15ddc9d9854 100644 --- a/plugin/trino-bigquery/src/main/java/io/trino/plugin/bigquery/BigQuerySplitManager.java +++ b/plugin/trino-bigquery/src/main/java/io/trino/plugin/bigquery/BigQuerySplitManager.java @@ -103,7 +103,7 @@ public ConnectorSplitSource getSplits( Optional filter = BigQueryFilterQueryBuilder.buildFilter(tableConstraint); if (!bigQueryTableHandle.isNamedRelation()) { - List columns = bigQueryTableHandle.getProjectedColumns().orElse(ImmutableList.of()); + List columns = bigQueryTableHandle.getProjectedColumns().orElse(ImmutableList.of()); return new FixedSplitSource(ImmutableList.of(BigQuerySplit.forViewStream(columns, filter))); } @@ -114,17 +114,17 @@ public ConnectorSplitSource getSplits( return new FixedSplitSource(splits); } - private static boolean emptyProjectionIsRequired(Optional> projectedColumns) + private static boolean emptyProjectionIsRequired(Optional> projectedColumns) { return projectedColumns.isPresent() && projectedColumns.get().isEmpty(); } - private List readFromBigQuery(ConnectorSession session, TableDefinition.Type type, TableId remoteTableId, Optional> projectedColumns, int actualParallelism, Optional filter) + private List readFromBigQuery(ConnectorSession session, TableDefinition.Type type, TableId remoteTableId, Optional> projectedColumns, int actualParallelism, Optional filter) { log.debug("readFromBigQuery(tableId=%s, projectedColumns=%s, actualParallelism=%s, filter=[%s])", remoteTableId, projectedColumns, actualParallelism, filter); - List columns = projectedColumns.orElse(ImmutableList.of()); + List columns = projectedColumns.orElse(ImmutableList.of()); List projectedColumnsNames = columns.stream() - .map(column -> ((BigQueryColumnHandle) column).getName()) + .map(BigQueryColumnHandle::getName) .collect(toImmutableList()); if (isWildcardTable(type, remoteTableId.getTable())) { diff --git a/plugin/trino-bigquery/src/main/java/io/trino/plugin/bigquery/BigQueryTableHandle.java b/plugin/trino-bigquery/src/main/java/io/trino/plugin/bigquery/BigQueryTableHandle.java index cdd3f1988780..6618f6b4d14b 100644 --- a/plugin/trino-bigquery/src/main/java/io/trino/plugin/bigquery/BigQueryTableHandle.java +++ b/plugin/trino-bigquery/src/main/java/io/trino/plugin/bigquery/BigQueryTableHandle.java @@ -39,13 +39,13 @@ public class BigQueryTableHandle { private final BigQueryRelationHandle relationHandle; private final TupleDomain constraint; - private final Optional> projectedColumns; + private final Optional> projectedColumns; @JsonCreator public BigQueryTableHandle( @JsonProperty("relationHandle") BigQueryRelationHandle relationHandle, @JsonProperty("constraint") TupleDomain constraint, - @JsonProperty("projectedColumns") Optional> projectedColumns) + @JsonProperty("projectedColumns") Optional> projectedColumns) { this.relationHandle = requireNonNull(relationHandle, "relationHandle is null"); this.constraint = requireNonNull(constraint, "constraint is null"); @@ -79,7 +79,7 @@ public TupleDomain getConstraint() } @JsonProperty - public Optional> getProjectedColumns() + public Optional> getProjectedColumns() { return projectedColumns; } @@ -145,7 +145,7 @@ BigQueryTableHandle withConstraint(TupleDomain newConstraint) return new BigQueryTableHandle(relationHandle, newConstraint, projectedColumns); } - public BigQueryTableHandle withProjectedColumns(List newProjectedColumns) + public BigQueryTableHandle withProjectedColumns(List newProjectedColumns) { return new BigQueryTableHandle(relationHandle, constraint, Optional.of(newProjectedColumns)); } diff --git a/plugin/trino-bigquery/src/main/java/io/trino/plugin/bigquery/ptf/Query.java b/plugin/trino-bigquery/src/main/java/io/trino/plugin/bigquery/ptf/Query.java index 6b9f8c03aaad..bfd8ef34e5aa 100644 --- a/plugin/trino-bigquery/src/main/java/io/trino/plugin/bigquery/ptf/Query.java +++ b/plugin/trino-bigquery/src/main/java/io/trino/plugin/bigquery/ptf/Query.java @@ -24,7 +24,6 @@ import io.trino.plugin.bigquery.BigQueryColumnHandle; import io.trino.plugin.bigquery.BigQueryQueryRelationHandle; import io.trino.plugin.bigquery.BigQueryTableHandle; -import io.trino.spi.connector.ColumnHandle; import io.trino.spi.connector.ConnectorSession; import io.trino.spi.connector.ConnectorTableHandle; import io.trino.spi.connector.ConnectorTransactionHandle; @@ -110,13 +109,12 @@ public TableFunctionAnalysis analyze(ConnectorSession session, ConnectorTransact } columnsBuilder.add(toColumnHandle(field)); } - List columns = columnsBuilder.build(); Descriptor returnedType = new Descriptor(columnsBuilder.build().stream() .map(column -> new Field(column.getName(), Optional.of(column.getTrinoType()))) .collect(toList())); - QueryHandle handle = new QueryHandle(tableHandle.withProjectedColumns(columns.stream().map(column -> (ColumnHandle) column).collect(toList()))); + QueryHandle handle = new QueryHandle(tableHandle.withProjectedColumns(columnsBuilder.build())); return TableFunctionAnalysis.builder() .returnedType(returnedType) From a660133428f1668f24b01a9f3f7f13b52a4929f2 Mon Sep 17 00:00:00 2001 From: Yuya Ebihara Date: Sat, 19 Nov 2022 09:06:22 +0900 Subject: [PATCH 17/24] Fix projection pushdown when unsupported column exists in BigQuery --- .../plugin/bigquery/BigQueryMetadata.java | 29 ++++++++++--------- .../plugin/bigquery/BigQuerySplitManager.java | 5 +++- 2 files changed, 20 insertions(+), 14 deletions(-) diff --git a/plugin/trino-bigquery/src/main/java/io/trino/plugin/bigquery/BigQueryMetadata.java b/plugin/trino-bigquery/src/main/java/io/trino/plugin/bigquery/BigQueryMetadata.java index 0228112e25e0..5a3e273abf38 100644 --- a/plugin/trino-bigquery/src/main/java/io/trino/plugin/bigquery/BigQueryMetadata.java +++ b/plugin/trino-bigquery/src/main/java/io/trino/plugin/bigquery/BigQueryMetadata.java @@ -33,6 +33,7 @@ import io.airlift.log.Logger; import io.airlift.slice.Slice; import io.trino.plugin.bigquery.BigQueryClient.RemoteDatabaseObject; +import io.trino.plugin.bigquery.BigQueryTableHandle.BigQueryPartitionType; import io.trino.plugin.bigquery.ptf.Query.QueryHandle; import io.trino.spi.TrinoException; import io.trino.spi.connector.Assignment; @@ -87,6 +88,7 @@ import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.collect.ImmutableList.toImmutableList; import static com.google.common.collect.ImmutableMap.toImmutableMap; +import static io.trino.plugin.bigquery.BigQueryClient.buildColumnHandles; import static io.trino.plugin.bigquery.BigQueryErrorCode.BIGQUERY_LISTING_DATASET_ERROR; import static io.trino.plugin.bigquery.BigQueryErrorCode.BIGQUERY_UNSUPPORTED_OPERATION; import static io.trino.plugin.bigquery.BigQueryPseudoColumn.PARTITION_DATE; @@ -230,12 +232,20 @@ public ConnectorTableHandle getTableHandle(ConnectorSession session, SchemaTable return null; } + ImmutableList.Builder columns = ImmutableList.builder(); + columns.addAll(buildColumnHandles(tableInfo.get())); + Optional partitionType = getPartitionType(tableInfo.get().getDefinition()); + if (partitionType.isPresent() && partitionType.get() == INGESTION) { + columns.add(PARTITION_DATE.getColumnHandle()); + columns.add(PARTITION_TIME.getColumnHandle()); + } return new BigQueryTableHandle(new BigQueryNamedRelationHandle( schemaTableName, new RemoteTableName(tableInfo.get().getTableId()), tableInfo.get().getDefinition().getType().toString(), - getPartitionType(tableInfo.get().getDefinition()), - Optional.ofNullable(tableInfo.get().getDescription()))); + partitionType, + Optional.ofNullable(tableInfo.get().getDescription()))) + .withProjectedColumns(columns.build()); } private ConnectorTableHandle getTableHandleIgnoringConflicts(ConnectorSession session, SchemaTableName schemaTableName) @@ -269,17 +279,10 @@ public ConnectorTableMetadata getTableMetadata(ConnectorSession session, Connect log.debug("getTableMetadata(session=%s, tableHandle=%s)", session, tableHandle); BigQueryTableHandle handle = ((BigQueryTableHandle) tableHandle); - ImmutableList.Builder columnMetadata = ImmutableList.builder(); - for (BigQueryColumnHandle column : client.getColumns(handle)) { - columnMetadata.add(column.getColumnMetadata()); - } - if (handle.isNamedRelation()) { - if (handle.asPlainTable().getPartitionType().isPresent() && handle.asPlainTable().getPartitionType().get() == INGESTION) { - columnMetadata.add(PARTITION_DATE.getColumnMetadata()); - columnMetadata.add(PARTITION_TIME.getColumnMetadata()); - } - } - return new ConnectorTableMetadata(getSchemaTableName(handle), columnMetadata.build(), ImmutableMap.of(), getTableComment(handle)); + List columns = client.getColumns(handle).stream() + .map(BigQueryColumnHandle::getColumnMetadata) + .collect(toImmutableList()); + return new ConnectorTableMetadata(getSchemaTableName(handle), columns, ImmutableMap.of(), getTableComment(handle)); } @Override diff --git a/plugin/trino-bigquery/src/main/java/io/trino/plugin/bigquery/BigQuerySplitManager.java b/plugin/trino-bigquery/src/main/java/io/trino/plugin/bigquery/BigQuerySplitManager.java index f15ddc9d9854..9b78d0727cf1 100644 --- a/plugin/trino-bigquery/src/main/java/io/trino/plugin/bigquery/BigQuerySplitManager.java +++ b/plugin/trino-bigquery/src/main/java/io/trino/plugin/bigquery/BigQuerySplitManager.java @@ -47,6 +47,7 @@ import static com.google.cloud.bigquery.TableDefinition.Type.MATERIALIZED_VIEW; import static com.google.cloud.bigquery.TableDefinition.Type.TABLE; import static com.google.cloud.bigquery.TableDefinition.Type.VIEW; +import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.collect.ImmutableList.toImmutableList; import static io.trino.plugin.bigquery.BigQueryErrorCode.BIGQUERY_FAILED_TO_EXECUTE_QUERY; import static io.trino.plugin.bigquery.BigQuerySessionProperties.createDisposition; @@ -121,8 +122,10 @@ private static boolean emptyProjectionIsRequired(Optional readFromBigQuery(ConnectorSession session, TableDefinition.Type type, TableId remoteTableId, Optional> projectedColumns, int actualParallelism, Optional filter) { + checkArgument(projectedColumns.isPresent() && projectedColumns.get().size() > 0, "Projected column is empty"); + log.debug("readFromBigQuery(tableId=%s, projectedColumns=%s, actualParallelism=%s, filter=[%s])", remoteTableId, projectedColumns, actualParallelism, filter); - List columns = projectedColumns.orElse(ImmutableList.of()); + List columns = projectedColumns.get(); List projectedColumnsNames = columns.stream() .map(BigQueryColumnHandle::getName) .collect(toImmutableList()); From cd8eac64e3e427be45fc5df5b5a61110245479ec Mon Sep 17 00:00:00 2001 From: Fokko Driesprong Date: Thu, 17 Nov 2022 13:21:55 +0100 Subject: [PATCH 18/24] Update Iceberg to 1.1.0 --- plugin/trino-iceberg/pom.xml | 2 + .../plugin/iceberg/IcebergAvroPageSource.java | 2 + .../plugin/iceberg/IcebergSplitSource.java | 2 + .../TestIcebergMetadataFileOperations.java | 76 +++++++++---------- .../trino/plugin/iceberg/TestIcebergV2.java | 7 +- pom.xml | 2 +- testing/trino-faulttolerant-tests/pom.xml | 2 + testing/trino-tests/pom.xml | 2 + 8 files changed, 54 insertions(+), 41 deletions(-) diff --git a/plugin/trino-iceberg/pom.xml b/plugin/trino-iceberg/pom.xml index 825187bc5c81..48db13121363 100644 --- a/plugin/trino-iceberg/pom.xml +++ b/plugin/trino-iceberg/pom.xml @@ -406,6 +406,8 @@ about.html iceberg-build.properties mozilla/public-suffix-list.txt + + google/protobuf/.*\.proto$ diff --git a/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/IcebergAvroPageSource.java b/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/IcebergAvroPageSource.java index 3add0c0ceef6..32f721035158 100644 --- a/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/IcebergAvroPageSource.java +++ b/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/IcebergAvroPageSource.java @@ -91,6 +91,8 @@ public IcebergAvroPageSource( .collect(toImmutableMap(Types.NestedField::name, Types.NestedField::type)); pageBuilder = new PageBuilder(columnTypes); recordIterator = avroReader.iterator(); + // TODO: Remove when NPE check has been released: https://github.com/trinodb/trino/issues/15372 + isFinished(); } private boolean isIndexColumn(int column) diff --git a/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/IcebergSplitSource.java b/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/IcebergSplitSource.java index 028278d43152..5302bf325e67 100644 --- a/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/IcebergSplitSource.java +++ b/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/IcebergSplitSource.java @@ -194,6 +194,8 @@ public CompletableFuture getNextBatch(int maxSize) closer.register(fileScanTaskIterable); this.fileScanTaskIterator = fileScanTaskIterable.iterator(); closer.register(fileScanTaskIterator); + // TODO: Remove when NPE check has been released: https://github.com/trinodb/trino/issues/15372 + isFinished(); } TupleDomain dynamicFilterPredicate = dynamicFilter.getCurrentPredicate() diff --git a/plugin/trino-iceberg/src/test/java/io/trino/plugin/iceberg/TestIcebergMetadataFileOperations.java b/plugin/trino-iceberg/src/test/java/io/trino/plugin/iceberg/TestIcebergMetadataFileOperations.java index baa84cde6423..ed5d84fe5cdd 100644 --- a/plugin/trino-iceberg/src/test/java/io/trino/plugin/iceberg/TestIcebergMetadataFileOperations.java +++ b/plugin/trino-iceberg/src/test/java/io/trino/plugin/iceberg/TestIcebergMetadataFileOperations.java @@ -133,8 +133,8 @@ public void testSelect() assertUpdate("CREATE TABLE test_select AS SELECT 1 col_name", 1); assertFileSystemAccesses("SELECT * FROM test_select", ImmutableMultiset.builder() - .addCopies(new FileOperation(MANIFEST, INPUT_FILE_GET_LENGTH), 2) - .addCopies(new FileOperation(MANIFEST, INPUT_FILE_NEW_STREAM), 2) + .addCopies(new FileOperation(MANIFEST, INPUT_FILE_GET_LENGTH), 1) + .addCopies(new FileOperation(MANIFEST, INPUT_FILE_NEW_STREAM), 1) .addCopies(new FileOperation(METADATA_JSON, INPUT_FILE_NEW_STREAM), 1) .addCopies(new FileOperation(SNAPSHOT, INPUT_FILE_GET_LENGTH), 1) .addCopies(new FileOperation(SNAPSHOT, INPUT_FILE_NEW_STREAM), 1) @@ -160,26 +160,26 @@ public void testSelectFromVersionedTable() assertFileSystemAccesses("SELECT * FROM " + tableName + " FOR VERSION AS OF " + v2SnapshotId, ImmutableMultiset.builder() .addCopies(new FileOperation(SNAPSHOT, INPUT_FILE_GET_LENGTH), 1) - .addCopies(new FileOperation(MANIFEST, INPUT_FILE_GET_LENGTH), 2) + .addCopies(new FileOperation(MANIFEST, INPUT_FILE_GET_LENGTH), 1) .addCopies(new FileOperation(METADATA_JSON, INPUT_FILE_NEW_STREAM), 1) .addCopies(new FileOperation(SNAPSHOT, INPUT_FILE_NEW_STREAM), 1) - .addCopies(new FileOperation(MANIFEST, INPUT_FILE_NEW_STREAM), 2) + .addCopies(new FileOperation(MANIFEST, INPUT_FILE_NEW_STREAM), 1) .build()); assertFileSystemAccesses("SELECT * FROM " + tableName + " FOR VERSION AS OF " + v3SnapshotId, ImmutableMultiset.builder() .addCopies(new FileOperation(SNAPSHOT, INPUT_FILE_GET_LENGTH), 1) - .addCopies(new FileOperation(MANIFEST, INPUT_FILE_GET_LENGTH), 4) + .addCopies(new FileOperation(MANIFEST, INPUT_FILE_GET_LENGTH), 2) .addCopies(new FileOperation(METADATA_JSON, INPUT_FILE_NEW_STREAM), 1) .addCopies(new FileOperation(SNAPSHOT, INPUT_FILE_NEW_STREAM), 1) - .addCopies(new FileOperation(MANIFEST, INPUT_FILE_NEW_STREAM), 4) + .addCopies(new FileOperation(MANIFEST, INPUT_FILE_NEW_STREAM), 2) .build()); assertFileSystemAccesses("SELECT * FROM " + tableName, ImmutableMultiset.builder() .addCopies(new FileOperation(SNAPSHOT, INPUT_FILE_GET_LENGTH), 1) - .addCopies(new FileOperation(MANIFEST, INPUT_FILE_GET_LENGTH), 4) + .addCopies(new FileOperation(MANIFEST, INPUT_FILE_GET_LENGTH), 2) .addCopies(new FileOperation(METADATA_JSON, INPUT_FILE_NEW_STREAM), 1) .addCopies(new FileOperation(SNAPSHOT, INPUT_FILE_NEW_STREAM), 1) - .addCopies(new FileOperation(MANIFEST, INPUT_FILE_NEW_STREAM), 4) + .addCopies(new FileOperation(MANIFEST, INPUT_FILE_NEW_STREAM), 2) .build()); } @@ -203,26 +203,26 @@ public void testSelectFromVersionedTableWithSchemaEvolution() assertFileSystemAccesses("SELECT * FROM " + tableName + " FOR VERSION AS OF " + v2SnapshotId, ImmutableMultiset.builder() .addCopies(new FileOperation(SNAPSHOT, INPUT_FILE_GET_LENGTH), 1) - .addCopies(new FileOperation(MANIFEST, INPUT_FILE_GET_LENGTH), 2) + .addCopies(new FileOperation(MANIFEST, INPUT_FILE_GET_LENGTH), 1) .addCopies(new FileOperation(METADATA_JSON, INPUT_FILE_NEW_STREAM), 1) .addCopies(new FileOperation(SNAPSHOT, INPUT_FILE_NEW_STREAM), 1) - .addCopies(new FileOperation(MANIFEST, INPUT_FILE_NEW_STREAM), 2) + .addCopies(new FileOperation(MANIFEST, INPUT_FILE_NEW_STREAM), 1) .build()); assertFileSystemAccesses("SELECT * FROM " + tableName + " FOR VERSION AS OF " + v3SnapshotId, ImmutableMultiset.builder() .addCopies(new FileOperation(SNAPSHOT, INPUT_FILE_GET_LENGTH), 1) - .addCopies(new FileOperation(MANIFEST, INPUT_FILE_GET_LENGTH), 4) + .addCopies(new FileOperation(MANIFEST, INPUT_FILE_GET_LENGTH), 2) .addCopies(new FileOperation(METADATA_JSON, INPUT_FILE_NEW_STREAM), 1) .addCopies(new FileOperation(SNAPSHOT, INPUT_FILE_NEW_STREAM), 1) - .addCopies(new FileOperation(MANIFEST, INPUT_FILE_NEW_STREAM), 4) + .addCopies(new FileOperation(MANIFEST, INPUT_FILE_NEW_STREAM), 2) .build()); assertFileSystemAccesses("SELECT * FROM " + tableName, ImmutableMultiset.builder() .addCopies(new FileOperation(SNAPSHOT, INPUT_FILE_GET_LENGTH), 1) - .addCopies(new FileOperation(MANIFEST, INPUT_FILE_GET_LENGTH), 4) + .addCopies(new FileOperation(MANIFEST, INPUT_FILE_GET_LENGTH), 2) .addCopies(new FileOperation(METADATA_JSON, INPUT_FILE_NEW_STREAM), 1) .addCopies(new FileOperation(SNAPSHOT, INPUT_FILE_NEW_STREAM), 1) - .addCopies(new FileOperation(MANIFEST, INPUT_FILE_NEW_STREAM), 4) + .addCopies(new FileOperation(MANIFEST, INPUT_FILE_NEW_STREAM), 2) .build()); } @@ -232,8 +232,8 @@ public void testSelectWithFilter() assertUpdate("CREATE TABLE test_select_with_filter AS SELECT 1 col_name", 1); assertFileSystemAccesses("SELECT * FROM test_select_with_filter WHERE col_name = 1", ImmutableMultiset.builder() - .addCopies(new FileOperation(MANIFEST, INPUT_FILE_GET_LENGTH), 2) - .addCopies(new FileOperation(MANIFEST, INPUT_FILE_NEW_STREAM), 2) + .addCopies(new FileOperation(MANIFEST, INPUT_FILE_GET_LENGTH), 1) + .addCopies(new FileOperation(MANIFEST, INPUT_FILE_NEW_STREAM), 1) .addCopies(new FileOperation(METADATA_JSON, INPUT_FILE_NEW_STREAM), 1) .addCopies(new FileOperation(SNAPSHOT, INPUT_FILE_GET_LENGTH), 1) .addCopies(new FileOperation(SNAPSHOT, INPUT_FILE_NEW_STREAM), 1) @@ -248,8 +248,8 @@ public void testJoin() assertFileSystemAccesses("SELECT name, age FROM test_join_t1 JOIN test_join_t2 ON test_join_t2.id = test_join_t1.id", ImmutableMultiset.builder() - .addCopies(new FileOperation(MANIFEST, INPUT_FILE_GET_LENGTH), 8) - .addCopies(new FileOperation(MANIFEST, INPUT_FILE_NEW_STREAM), 8) + .addCopies(new FileOperation(MANIFEST, INPUT_FILE_GET_LENGTH), 4) + .addCopies(new FileOperation(MANIFEST, INPUT_FILE_NEW_STREAM), 4) .addCopies(new FileOperation(METADATA_JSON, INPUT_FILE_NEW_STREAM), 2) .addCopies(new FileOperation(SNAPSHOT, INPUT_FILE_GET_LENGTH), 2) .addCopies(new FileOperation(SNAPSHOT, INPUT_FILE_NEW_STREAM), 2) @@ -266,8 +266,8 @@ public void testJoinWithPartitionedTable() assertFileSystemAccesses("SELECT count(*) FROM test_join_partitioned_t1 t1 join test_join_partitioned_t2 t2 on t1.a = t2.foo", ImmutableMultiset.builder() - .addCopies(new FileOperation(MANIFEST, INPUT_FILE_GET_LENGTH), 8) - .addCopies(new FileOperation(MANIFEST, INPUT_FILE_NEW_STREAM), 8) + .addCopies(new FileOperation(MANIFEST, INPUT_FILE_GET_LENGTH), 4) + .addCopies(new FileOperation(MANIFEST, INPUT_FILE_NEW_STREAM), 4) .addCopies(new FileOperation(METADATA_JSON, INPUT_FILE_NEW_STREAM), 2) .addCopies(new FileOperation(SNAPSHOT, INPUT_FILE_GET_LENGTH), 2) .addCopies(new FileOperation(SNAPSHOT, INPUT_FILE_NEW_STREAM), 2) @@ -281,8 +281,8 @@ public void testExplainSelect() assertFileSystemAccesses("EXPLAIN SELECT * FROM test_explain", ImmutableMultiset.builder() - .addCopies(new FileOperation(MANIFEST, INPUT_FILE_GET_LENGTH), 2) - .addCopies(new FileOperation(MANIFEST, INPUT_FILE_NEW_STREAM), 2) + .addCopies(new FileOperation(MANIFEST, INPUT_FILE_GET_LENGTH), 1) + .addCopies(new FileOperation(MANIFEST, INPUT_FILE_NEW_STREAM), 1) .addCopies(new FileOperation(METADATA_JSON, INPUT_FILE_NEW_STREAM), 1) .addCopies(new FileOperation(SNAPSHOT, INPUT_FILE_GET_LENGTH), 1) .addCopies(new FileOperation(SNAPSHOT, INPUT_FILE_NEW_STREAM), 1) @@ -296,8 +296,8 @@ public void testShowStatsForTable() assertFileSystemAccesses("SHOW STATS FOR test_show_stats", ImmutableMultiset.builder() - .addCopies(new FileOperation(MANIFEST, INPUT_FILE_GET_LENGTH), 2) - .addCopies(new FileOperation(MANIFEST, INPUT_FILE_NEW_STREAM), 2) + .addCopies(new FileOperation(MANIFEST, INPUT_FILE_GET_LENGTH), 1) + .addCopies(new FileOperation(MANIFEST, INPUT_FILE_NEW_STREAM), 1) .addCopies(new FileOperation(METADATA_JSON, INPUT_FILE_NEW_STREAM), 1) .addCopies(new FileOperation(SNAPSHOT, INPUT_FILE_GET_LENGTH), 1) .addCopies(new FileOperation(SNAPSHOT, INPUT_FILE_NEW_STREAM), 1) @@ -313,8 +313,8 @@ public void testShowStatsForPartitionedTable() assertFileSystemAccesses("SHOW STATS FOR test_show_stats_partitioned", ImmutableMultiset.builder() - .addCopies(new FileOperation(MANIFEST, INPUT_FILE_GET_LENGTH), 2) - .addCopies(new FileOperation(MANIFEST, INPUT_FILE_NEW_STREAM), 2) + .addCopies(new FileOperation(MANIFEST, INPUT_FILE_GET_LENGTH), 1) + .addCopies(new FileOperation(MANIFEST, INPUT_FILE_NEW_STREAM), 1) .addCopies(new FileOperation(METADATA_JSON, INPUT_FILE_NEW_STREAM), 1) .addCopies(new FileOperation(SNAPSHOT, INPUT_FILE_GET_LENGTH), 1) .addCopies(new FileOperation(SNAPSHOT, INPUT_FILE_NEW_STREAM), 1) @@ -328,8 +328,8 @@ public void testShowStatsForTableWithFilter() assertFileSystemAccesses("SHOW STATS FOR (SELECT * FROM test_show_stats_with_filter WHERE age >= 2)", ImmutableMultiset.builder() - .addCopies(new FileOperation(MANIFEST, INPUT_FILE_GET_LENGTH), 2) - .addCopies(new FileOperation(MANIFEST, INPUT_FILE_NEW_STREAM), 2) + .addCopies(new FileOperation(MANIFEST, INPUT_FILE_GET_LENGTH), 1) + .addCopies(new FileOperation(MANIFEST, INPUT_FILE_NEW_STREAM), 1) .addCopies(new FileOperation(METADATA_JSON, INPUT_FILE_NEW_STREAM), 1) .addCopies(new FileOperation(SNAPSHOT, INPUT_FILE_GET_LENGTH), 1) .addCopies(new FileOperation(SNAPSHOT, INPUT_FILE_NEW_STREAM), 1) @@ -345,8 +345,8 @@ public void testPredicateWithVarcharCastToDate() assertFileSystemAccesses("SELECT * FROM test_varchar_as_date_predicate", ImmutableMultiset.builder() - .addCopies(new FileOperation(MANIFEST, INPUT_FILE_GET_LENGTH), 4) - .addCopies(new FileOperation(MANIFEST, INPUT_FILE_NEW_STREAM), 4) + .addCopies(new FileOperation(MANIFEST, INPUT_FILE_GET_LENGTH), 2) + .addCopies(new FileOperation(MANIFEST, INPUT_FILE_NEW_STREAM), 2) .addCopies(new FileOperation(METADATA_JSON, INPUT_FILE_NEW_STREAM), 1) .addCopies(new FileOperation(SNAPSHOT, INPUT_FILE_GET_LENGTH), 1) .addCopies(new FileOperation(SNAPSHOT, INPUT_FILE_NEW_STREAM), 1) @@ -355,8 +355,8 @@ public void testPredicateWithVarcharCastToDate() // CAST to date and comparison assertFileSystemAccesses("SELECT * FROM test_varchar_as_date_predicate WHERE CAST(a AS date) >= DATE '2005-01-01'", ImmutableMultiset.builder() - .addCopies(new FileOperation(MANIFEST, INPUT_FILE_GET_LENGTH), 2) // fewer than without filter - .addCopies(new FileOperation(MANIFEST, INPUT_FILE_NEW_STREAM), 2) // fewer than without filter + .addCopies(new FileOperation(MANIFEST, INPUT_FILE_GET_LENGTH), 1) // fewer than without filter + .addCopies(new FileOperation(MANIFEST, INPUT_FILE_NEW_STREAM), 1) // fewer than without filter .addCopies(new FileOperation(METADATA_JSON, INPUT_FILE_NEW_STREAM), 1) .addCopies(new FileOperation(SNAPSHOT, INPUT_FILE_GET_LENGTH), 1) .addCopies(new FileOperation(SNAPSHOT, INPUT_FILE_NEW_STREAM), 1) @@ -365,8 +365,8 @@ public void testPredicateWithVarcharCastToDate() // CAST to date and BETWEEN assertFileSystemAccesses("SELECT * FROM test_varchar_as_date_predicate WHERE CAST(a AS date) BETWEEN DATE '2005-01-01' AND DATE '2005-12-31'", ImmutableMultiset.builder() - .addCopies(new FileOperation(MANIFEST, INPUT_FILE_GET_LENGTH), 2) // fewer than without filter - .addCopies(new FileOperation(MANIFEST, INPUT_FILE_NEW_STREAM), 2) // fewer than without filter + .addCopies(new FileOperation(MANIFEST, INPUT_FILE_GET_LENGTH), 1) // fewer than without filter + .addCopies(new FileOperation(MANIFEST, INPUT_FILE_NEW_STREAM), 1) // fewer than without filter .addCopies(new FileOperation(METADATA_JSON, INPUT_FILE_NEW_STREAM), 1) .addCopies(new FileOperation(SNAPSHOT, INPUT_FILE_GET_LENGTH), 1) .addCopies(new FileOperation(SNAPSHOT, INPUT_FILE_NEW_STREAM), 1) @@ -375,8 +375,8 @@ public void testPredicateWithVarcharCastToDate() // conversion to date as a date function assertFileSystemAccesses("SELECT * FROM test_varchar_as_date_predicate WHERE date(a) >= DATE '2005-01-01'", ImmutableMultiset.builder() - .addCopies(new FileOperation(MANIFEST, INPUT_FILE_GET_LENGTH), 2) // fewer than without filter - .addCopies(new FileOperation(MANIFEST, INPUT_FILE_NEW_STREAM), 2) // fewer than without filter + .addCopies(new FileOperation(MANIFEST, INPUT_FILE_GET_LENGTH), 1) // fewer than without filter + .addCopies(new FileOperation(MANIFEST, INPUT_FILE_NEW_STREAM), 1) // fewer than without filter .addCopies(new FileOperation(METADATA_JSON, INPUT_FILE_NEW_STREAM), 1) .addCopies(new FileOperation(SNAPSHOT, INPUT_FILE_GET_LENGTH), 1) .addCopies(new FileOperation(SNAPSHOT, INPUT_FILE_NEW_STREAM), 1) @@ -404,8 +404,8 @@ public void testRemoveOrphanFiles() .add(new FileOperation(METADATA_JSON, INPUT_FILE_NEW_STREAM)) .addCopies(new FileOperation(SNAPSHOT, INPUT_FILE_GET_LENGTH), 4) .addCopies(new FileOperation(SNAPSHOT, INPUT_FILE_NEW_STREAM), 4) - .addCopies(new FileOperation(MANIFEST, INPUT_FILE_GET_LENGTH), 6) - .addCopies(new FileOperation(MANIFEST, INPUT_FILE_NEW_STREAM), 6) + .addCopies(new FileOperation(MANIFEST, INPUT_FILE_GET_LENGTH), 5) + .addCopies(new FileOperation(MANIFEST, INPUT_FILE_NEW_STREAM), 5) .build()); assertUpdate("DROP TABLE " + tableName); diff --git a/plugin/trino-iceberg/src/test/java/io/trino/plugin/iceberg/TestIcebergV2.java b/plugin/trino-iceberg/src/test/java/io/trino/plugin/iceberg/TestIcebergV2.java index 64f74e0053d0..074ccd439107 100644 --- a/plugin/trino-iceberg/src/test/java/io/trino/plugin/iceberg/TestIcebergV2.java +++ b/plugin/trino-iceberg/src/test/java/io/trino/plugin/iceberg/TestIcebergV2.java @@ -51,6 +51,7 @@ import org.apache.iceberg.data.Record; import org.apache.iceberg.data.parquet.GenericParquetWriter; import org.apache.iceberg.deletes.EqualityDeleteWriter; +import org.apache.iceberg.deletes.PositionDelete; import org.apache.iceberg.deletes.PositionDeleteWriter; import org.apache.iceberg.hadoop.HadoopOutputFile; import org.apache.iceberg.parquet.Parquet; @@ -171,8 +172,10 @@ public void testV2TableWithPositionDelete() .withSpec(PartitionSpec.unpartitioned()) .buildPositionWriter(); + PositionDelete positionDelete = PositionDelete.create(); + PositionDelete record = positionDelete.set(dataFilePath, 0, GenericRecord.create(icebergTable.schema())); try (Closeable ignored = writer) { - writer.delete(dataFilePath, 0, GenericRecord.create(icebergTable.schema())); + writer.write(record); } icebergTable.newRowDelta().addDeletes(writer.toDeleteFile()).commit(); @@ -521,7 +524,7 @@ private void writeEqualityDeleteToNationTable(Table icebergTable, Optional5.5.2 4.14.0 7.1.4 - 1.0.0 + 1.1.0 4.7.2 3.21.6 3.2.2 diff --git a/testing/trino-faulttolerant-tests/pom.xml b/testing/trino-faulttolerant-tests/pom.xml index fcba727b5010..95e8c9167785 100644 --- a/testing/trino-faulttolerant-tests/pom.xml +++ b/testing/trino-faulttolerant-tests/pom.xml @@ -465,6 +465,8 @@ about.html iceberg-build.properties mozilla/public-suffix-list.txt + + google/protobuf/.*\.proto$ diff --git a/testing/trino-tests/pom.xml b/testing/trino-tests/pom.xml index 02d6ab48ffaf..5eef67d95db0 100644 --- a/testing/trino-tests/pom.xml +++ b/testing/trino-tests/pom.xml @@ -391,6 +391,8 @@ about.html iceberg-build.properties mozilla/public-suffix-list.txt + + google/protobuf/.*\.proto$ From b9d26a7da490d580390b4ac5cbeee62280232c3e Mon Sep 17 00:00:00 2001 From: Marius Grama Date: Mon, 5 Jul 2021 12:55:22 +0200 Subject: [PATCH 19/24] Document Top-N pushdown --- docs/src/main/sphinx/optimizer/pushdown.rst | 87 +++++++++++++++++++++ 1 file changed, 87 insertions(+) diff --git a/docs/src/main/sphinx/optimizer/pushdown.rst b/docs/src/main/sphinx/optimizer/pushdown.rst index 447281295a3d..4d40b7dca267 100644 --- a/docs/src/main/sphinx/optimizer/pushdown.rst +++ b/docs/src/main/sphinx/optimizer/pushdown.rst @@ -271,3 +271,90 @@ FETCH FIRST N ROWS``. Implementation and support is connector-specific since different data sources support different SQL syntax and processing. + +For example, you can find two queries to learn how to identify Top-N pushdown behavior in the following section. + +First, a concrete example of a Top-N pushdown query on top of a PostgreSQL database:: + + SELECT id, name + FROM postgresql.public.company + ORDER BY id + LIMIT 5; + +You can get the explain plan by prepending the above query with ``EXPLAIN``:: + + EXPLAIN SELECT id, name + FROM postgresql.public.company + ORDER BY id + LIMIT 5; + +.. code-block:: text + + Fragment 0 [SINGLE] + Output layout: [id, name] + Output partitioning: SINGLE [] + Stage Execution Strategy: UNGROUPED_EXECUTION + Output[id, name] + │ Layout: [id:integer, name:varchar] + │ Estimates: {rows: ? (?), cpu: ?, memory: 0B, network: ?} + └─ RemoteSource[1] + Layout: [id:integer, name:varchar] + + Fragment 1 [SOURCE] + Output layout: [id, name] + Output partitioning: SINGLE [] + Stage Execution Strategy: UNGROUPED_EXECUTION + TableScan[postgresql:public.company public.company sortOrder=[id:integer:int4 ASC NULLS LAST] limit=5, grouped = false] + Layout: [id:integer, name:varchar] + Estimates: {rows: ? (?), cpu: ?, memory: 0B, network: 0B} + name := name:varchar:text + id := id:integer:int4 + +Second, an example of a Top-N query on the ``tpch`` connector which does not support +Top-N pushdown functionality:: + + SELECT custkey, name + FROM tpch.sf1.customer + ORDER BY custkey + LIMIT 5; + +The related query plan: + +.. code-block:: text + + Fragment 0 [SINGLE] + Output layout: [custkey, name] + Output partitioning: SINGLE [] + Stage Execution Strategy: UNGROUPED_EXECUTION + Output[custkey, name] + │ Layout: [custkey:bigint, name:varchar(25)] + │ Estimates: {rows: ? (?), cpu: ?, memory: ?, network: ?} + └─ TopN[5 by (custkey ASC NULLS LAST)] + │ Layout: [custkey:bigint, name:varchar(25)] + └─ LocalExchange[SINGLE] () + │ Layout: [custkey:bigint, name:varchar(25)] + │ Estimates: {rows: ? (?), cpu: ?, memory: ?, network: ?} + └─ RemoteSource[1] + Layout: [custkey:bigint, name:varchar(25)] + + Fragment 1 [SOURCE] + Output layout: [custkey, name] + Output partitioning: SINGLE [] + Stage Execution Strategy: UNGROUPED_EXECUTION + TopNPartial[5 by (custkey ASC NULLS LAST)] + │ Layout: [custkey:bigint, name:varchar(25)] + └─ TableScan[tpch:customer:sf1.0, grouped = false] + Layout: [custkey:bigint, name:varchar(25)] + Estimates: {rows: 150000 (4.58MB), cpu: 4.58M, memory: 0B, network: 0B} + custkey := tpch:custkey + name := tpch:name + +In the preceding query plan, the Top-N operation ``TopN[5 by (custkey ASC NULLS LAST)]`` +is being applied in the ``Fragment 0`` by Trino and not by the source database. + +Note that, compared to the query executed on top of the ``tpch`` connector, +the explain plan of the query applied on top of the ``postgresql`` connector +is missing the reference to the operation ``TopN[5 by (id ASC NULLS LAST)]`` +in the ``Fragment 0``. +The absence of the ``TopN`` Trino operator in the ``Fragment 0`` from the query plan +demonstrates that the query benefits of the Top-N pushdown optimization. From f5b5fb8a4a7756566ac5d16aeb8f2b7120dd411a Mon Sep 17 00:00:00 2001 From: Marius Grama Date: Wed, 14 Dec 2022 06:24:30 +0100 Subject: [PATCH 20/24] Remove unnecessary override for `getTableProperties` method --- .../io/trino/connector/system/SystemTablesMetadata.java | 7 ------- .../src/main/java/io/trino/testing/TestingMetadata.java | 7 ------- .../java/io/trino/plugin/accumulo/AccumuloMetadata.java | 7 ------- .../src/main/java/io/trino/plugin/atop/AtopMetadata.java | 7 ------- .../java/io/trino/plugin/jdbc/DefaultJdbcMetadata.java | 7 ------- .../java/io/trino/plugin/bigquery/BigQueryMetadata.java | 8 -------- .../java/io/trino/plugin/blackhole/BlackHoleMetadata.java | 7 ------- .../java/io/trino/plugin/cassandra/CassandraMetadata.java | 7 ------- .../java/io/trino/plugin/example/ExampleMetadata.java | 7 ------- .../io/trino/plugin/google/sheets/SheetsMetadata.java | 7 ------- .../src/main/java/io/trino/plugin/jmx/JmxMetadata.java | 7 ------- .../main/java/io/trino/plugin/kafka/KafkaMetadata.java | 7 ------- .../java/io/trino/plugin/kinesis/KinesisMetadata.java | 7 ------- .../java/io/trino/plugin/localfile/LocalFileMetadata.java | 7 ------- .../main/java/io/trino/plugin/memory/MemoryMetadata.java | 7 ------- .../main/java/io/trino/plugin/pinot/PinotMetadata.java | 7 ------- .../io/trino/plugin/prometheus/PrometheusMetadata.java | 7 ------- .../main/java/io/trino/plugin/redis/RedisMetadata.java | 7 ------- .../main/java/io/trino/plugin/thrift/ThriftMetadata.java | 7 ------- .../main/java/io/trino/plugin/tpcds/TpcdsMetadata.java | 7 ------- 20 files changed, 141 deletions(-) diff --git a/core/trino-main/src/main/java/io/trino/connector/system/SystemTablesMetadata.java b/core/trino-main/src/main/java/io/trino/connector/system/SystemTablesMetadata.java index 21b75218f657..98228ac46f32 100644 --- a/core/trino-main/src/main/java/io/trino/connector/system/SystemTablesMetadata.java +++ b/core/trino-main/src/main/java/io/trino/connector/system/SystemTablesMetadata.java @@ -23,7 +23,6 @@ import io.trino.spi.connector.ConnectorSession; import io.trino.spi.connector.ConnectorTableHandle; import io.trino.spi.connector.ConnectorTableMetadata; -import io.trino.spi.connector.ConnectorTableProperties; import io.trino.spi.connector.Constraint; import io.trino.spi.connector.ConstraintApplicationResult; import io.trino.spi.connector.SchemaTableName; @@ -139,12 +138,6 @@ public Map> listTableColumns(ConnectorSess return builder.buildOrThrow(); } - @Override - public ConnectorTableProperties getTableProperties(ConnectorSession session, ConnectorTableHandle table) - { - return new ConnectorTableProperties(); - } - @Override public Optional> applyFilter(ConnectorSession session, ConnectorTableHandle handle, Constraint constraint) { diff --git a/core/trino-main/src/main/java/io/trino/testing/TestingMetadata.java b/core/trino-main/src/main/java/io/trino/testing/TestingMetadata.java index 4a1d42115cec..587c6203bf2b 100644 --- a/core/trino-main/src/main/java/io/trino/testing/TestingMetadata.java +++ b/core/trino-main/src/main/java/io/trino/testing/TestingMetadata.java @@ -32,7 +32,6 @@ import io.trino.spi.connector.ConnectorTableHandle; import io.trino.spi.connector.ConnectorTableLayout; import io.trino.spi.connector.ConnectorTableMetadata; -import io.trino.spi.connector.ConnectorTableProperties; import io.trino.spi.connector.ConnectorViewDefinition; import io.trino.spi.connector.MaterializedViewFreshness; import io.trino.spi.connector.MaterializedViewNotFoundException; @@ -331,12 +330,6 @@ public void grantTablePrivileges(ConnectorSession session, SchemaTableName table @Override public void revokeTablePrivileges(ConnectorSession session, SchemaTableName tableName, Set privileges, TrinoPrincipal grantee, boolean grantOption) {} - @Override - public ConnectorTableProperties getTableProperties(ConnectorSession session, ConnectorTableHandle table) - { - return new ConnectorTableProperties(); - } - public void clear() { views.clear(); diff --git a/plugin/trino-accumulo/src/main/java/io/trino/plugin/accumulo/AccumuloMetadata.java b/plugin/trino-accumulo/src/main/java/io/trino/plugin/accumulo/AccumuloMetadata.java index 15e91744961d..003f8a0c755c 100644 --- a/plugin/trino-accumulo/src/main/java/io/trino/plugin/accumulo/AccumuloMetadata.java +++ b/plugin/trino-accumulo/src/main/java/io/trino/plugin/accumulo/AccumuloMetadata.java @@ -34,7 +34,6 @@ import io.trino.spi.connector.ConnectorTableHandle; import io.trino.spi.connector.ConnectorTableLayout; import io.trino.spi.connector.ConnectorTableMetadata; -import io.trino.spi.connector.ConnectorTableProperties; import io.trino.spi.connector.ConnectorViewDefinition; import io.trino.spi.connector.Constraint; import io.trino.spi.connector.ConstraintApplicationResult; @@ -389,12 +388,6 @@ public Optional> applyFilter(C return Optional.of(new ConstraintApplicationResult<>(handle, constraint.getSummary(), false)); } - @Override - public ConnectorTableProperties getTableProperties(ConnectorSession session, ConnectorTableHandle handle) - { - return new ConnectorTableProperties(); - } - private void checkNoRollback() { checkState(rollbackAction.get() == null, "Cannot begin a new write while in an existing one"); diff --git a/plugin/trino-atop/src/main/java/io/trino/plugin/atop/AtopMetadata.java b/plugin/trino-atop/src/main/java/io/trino/plugin/atop/AtopMetadata.java index 9ea2908f7b05..1a91bcc162c1 100644 --- a/plugin/trino-atop/src/main/java/io/trino/plugin/atop/AtopMetadata.java +++ b/plugin/trino-atop/src/main/java/io/trino/plugin/atop/AtopMetadata.java @@ -23,7 +23,6 @@ import io.trino.spi.connector.ConnectorSession; import io.trino.spi.connector.ConnectorTableHandle; import io.trino.spi.connector.ConnectorTableMetadata; -import io.trino.spi.connector.ConnectorTableProperties; import io.trino.spi.connector.Constraint; import io.trino.spi.connector.ConstraintApplicationResult; import io.trino.spi.connector.SchemaTableName; @@ -149,12 +148,6 @@ public ColumnMetadata getColumnMetadata(ConnectorSession session, ConnectorTable throw new ColumnNotFoundException(tableName, columnName); } - @Override - public ConnectorTableProperties getTableProperties(ConnectorSession session, ConnectorTableHandle table) - { - return new ConnectorTableProperties(); - } - @Override public Optional> applyFilter(ConnectorSession session, ConnectorTableHandle table, Constraint constraint) { diff --git a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/DefaultJdbcMetadata.java b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/DefaultJdbcMetadata.java index 677edecb4cea..7f7cd97ccc3e 100644 --- a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/DefaultJdbcMetadata.java +++ b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/DefaultJdbcMetadata.java @@ -33,7 +33,6 @@ import io.trino.spi.connector.ConnectorTableHandle; import io.trino.spi.connector.ConnectorTableLayout; import io.trino.spi.connector.ConnectorTableMetadata; -import io.trino.spi.connector.ConnectorTableProperties; import io.trino.spi.connector.ConnectorTableSchema; import io.trino.spi.connector.Constraint; import io.trino.spi.connector.ConstraintApplicationResult; @@ -625,12 +624,6 @@ public Optional applyTableScanRedirect(Conne return jdbcClient.getTableScanRedirection(session, tableHandle); } - @Override - public ConnectorTableProperties getTableProperties(ConnectorSession session, ConnectorTableHandle table) - { - return new ConnectorTableProperties(); - } - @Override public ConnectorTableSchema getTableSchema(ConnectorSession session, ConnectorTableHandle table) { diff --git a/plugin/trino-bigquery/src/main/java/io/trino/plugin/bigquery/BigQueryMetadata.java b/plugin/trino-bigquery/src/main/java/io/trino/plugin/bigquery/BigQueryMetadata.java index 5a3e273abf38..6d9251c49751 100644 --- a/plugin/trino-bigquery/src/main/java/io/trino/plugin/bigquery/BigQueryMetadata.java +++ b/plugin/trino-bigquery/src/main/java/io/trino/plugin/bigquery/BigQueryMetadata.java @@ -48,7 +48,6 @@ import io.trino.spi.connector.ConnectorTableHandle; import io.trino.spi.connector.ConnectorTableLayout; import io.trino.spi.connector.ConnectorTableMetadata; -import io.trino.spi.connector.ConnectorTableProperties; import io.trino.spi.connector.ConnectorTableSchema; import io.trino.spi.connector.ConnectorTransactionHandle; import io.trino.spi.connector.Constraint; @@ -374,13 +373,6 @@ public Map> listTableColumns(ConnectorSess return columns.buildOrThrow(); } - @Override - public ConnectorTableProperties getTableProperties(ConnectorSession session, ConnectorTableHandle table) - { - log.debug("getTableProperties(session=%s, prefix=%s)", session, table); - return new ConnectorTableProperties(); - } - @Override public void createSchema(ConnectorSession session, String schemaName, Map properties, TrinoPrincipal owner) { diff --git a/plugin/trino-blackhole/src/main/java/io/trino/plugin/blackhole/BlackHoleMetadata.java b/plugin/trino-blackhole/src/main/java/io/trino/plugin/blackhole/BlackHoleMetadata.java index 2c63744aab9a..6d611dac602a 100644 --- a/plugin/trino-blackhole/src/main/java/io/trino/plugin/blackhole/BlackHoleMetadata.java +++ b/plugin/trino-blackhole/src/main/java/io/trino/plugin/blackhole/BlackHoleMetadata.java @@ -32,7 +32,6 @@ import io.trino.spi.connector.ConnectorTableHandle; import io.trino.spi.connector.ConnectorTableLayout; import io.trino.spi.connector.ConnectorTableMetadata; -import io.trino.spi.connector.ConnectorTableProperties; import io.trino.spi.connector.ConnectorViewDefinition; import io.trino.spi.connector.RetryMode; import io.trino.spi.connector.RowChangeParadigm; @@ -387,12 +386,6 @@ public Optional getView(ConnectorSession session, Schem return Optional.ofNullable(views.get(viewName)); } - @Override - public ConnectorTableProperties getTableProperties(ConnectorSession session, ConnectorTableHandle table) - { - return new ConnectorTableProperties(); - } - private void checkSchemaExists(String schemaName) { if (!schemas.contains(schemaName)) { diff --git a/plugin/trino-cassandra/src/main/java/io/trino/plugin/cassandra/CassandraMetadata.java b/plugin/trino-cassandra/src/main/java/io/trino/plugin/cassandra/CassandraMetadata.java index 9b40fd8fbced..913040043100 100644 --- a/plugin/trino-cassandra/src/main/java/io/trino/plugin/cassandra/CassandraMetadata.java +++ b/plugin/trino-cassandra/src/main/java/io/trino/plugin/cassandra/CassandraMetadata.java @@ -30,7 +30,6 @@ import io.trino.spi.connector.ConnectorTableHandle; import io.trino.spi.connector.ConnectorTableLayout; import io.trino.spi.connector.ConnectorTableMetadata; -import io.trino.spi.connector.ConnectorTableProperties; import io.trino.spi.connector.Constraint; import io.trino.spi.connector.ConstraintApplicationResult; import io.trino.spi.connector.NotFoundException; @@ -251,12 +250,6 @@ public Optional> applyFilter(C false)); } - @Override - public ConnectorTableProperties getTableProperties(ConnectorSession session, ConnectorTableHandle table) - { - return new ConnectorTableProperties(); - } - @Override public void createTable(ConnectorSession session, ConnectorTableMetadata tableMetadata, boolean ignoreExisting) { diff --git a/plugin/trino-example-http/src/main/java/io/trino/plugin/example/ExampleMetadata.java b/plugin/trino-example-http/src/main/java/io/trino/plugin/example/ExampleMetadata.java index f1d705957d94..b964390904d2 100644 --- a/plugin/trino-example-http/src/main/java/io/trino/plugin/example/ExampleMetadata.java +++ b/plugin/trino-example-http/src/main/java/io/trino/plugin/example/ExampleMetadata.java @@ -22,7 +22,6 @@ import io.trino.spi.connector.ConnectorSession; import io.trino.spi.connector.ConnectorTableHandle; import io.trino.spi.connector.ConnectorTableMetadata; -import io.trino.spi.connector.ConnectorTableProperties; import io.trino.spi.connector.SchemaTableName; import io.trino.spi.connector.SchemaTablePrefix; import io.trino.spi.connector.TableNotFoundException; @@ -155,10 +154,4 @@ public ColumnMetadata getColumnMetadata(ConnectorSession session, ConnectorTable { return ((ExampleColumnHandle) columnHandle).getColumnMetadata(); } - - @Override - public ConnectorTableProperties getTableProperties(ConnectorSession session, ConnectorTableHandle table) - { - return new ConnectorTableProperties(); - } } diff --git a/plugin/trino-google-sheets/src/main/java/io/trino/plugin/google/sheets/SheetsMetadata.java b/plugin/trino-google-sheets/src/main/java/io/trino/plugin/google/sheets/SheetsMetadata.java index efac03397d77..763070d3a053 100644 --- a/plugin/trino-google-sheets/src/main/java/io/trino/plugin/google/sheets/SheetsMetadata.java +++ b/plugin/trino-google-sheets/src/main/java/io/trino/plugin/google/sheets/SheetsMetadata.java @@ -22,7 +22,6 @@ import io.trino.spi.connector.ConnectorSession; import io.trino.spi.connector.ConnectorTableHandle; import io.trino.spi.connector.ConnectorTableMetadata; -import io.trino.spi.connector.ConnectorTableProperties; import io.trino.spi.connector.SchemaTableName; import io.trino.spi.connector.SchemaTablePrefix; import io.trino.spi.connector.TableNotFoundException; @@ -146,10 +145,4 @@ public ColumnMetadata getColumnMetadata(ConnectorSession session, ConnectorTable { return ((SheetsColumnHandle) columnHandle).getColumnMetadata(); } - - @Override - public ConnectorTableProperties getTableProperties(ConnectorSession session, ConnectorTableHandle table) - { - return new ConnectorTableProperties(); - } } diff --git a/plugin/trino-jmx/src/main/java/io/trino/plugin/jmx/JmxMetadata.java b/plugin/trino-jmx/src/main/java/io/trino/plugin/jmx/JmxMetadata.java index 9029216786c1..3c2c3950ec38 100644 --- a/plugin/trino-jmx/src/main/java/io/trino/plugin/jmx/JmxMetadata.java +++ b/plugin/trino-jmx/src/main/java/io/trino/plugin/jmx/JmxMetadata.java @@ -27,7 +27,6 @@ import io.trino.spi.connector.ConnectorSession; import io.trino.spi.connector.ConnectorTableHandle; import io.trino.spi.connector.ConnectorTableMetadata; -import io.trino.spi.connector.ConnectorTableProperties; import io.trino.spi.connector.Constraint; import io.trino.spi.connector.ConstraintApplicationResult; import io.trino.spi.connector.SchemaTableName; @@ -102,12 +101,6 @@ public JmxTableHandle getTableHandle(ConnectorSession session, SchemaTableName t return getTableHandle(tableName); } - @Override - public ConnectorTableProperties getTableProperties(ConnectorSession session, ConnectorTableHandle table) - { - return new ConnectorTableProperties(); - } - public JmxTableHandle getTableHandle(SchemaTableName tableName) { requireNonNull(tableName, "tableName is null"); diff --git a/plugin/trino-kafka/src/main/java/io/trino/plugin/kafka/KafkaMetadata.java b/plugin/trino-kafka/src/main/java/io/trino/plugin/kafka/KafkaMetadata.java index f0a6a91e3344..dc2a5834c02d 100644 --- a/plugin/trino-kafka/src/main/java/io/trino/plugin/kafka/KafkaMetadata.java +++ b/plugin/trino-kafka/src/main/java/io/trino/plugin/kafka/KafkaMetadata.java @@ -27,7 +27,6 @@ import io.trino.spi.connector.ConnectorSession; import io.trino.spi.connector.ConnectorTableHandle; import io.trino.spi.connector.ConnectorTableMetadata; -import io.trino.spi.connector.ConnectorTableProperties; import io.trino.spi.connector.Constraint; import io.trino.spi.connector.ConstraintApplicationResult; import io.trino.spi.connector.RetryMode; @@ -231,12 +230,6 @@ private ConnectorTableMetadata getTableMetadata(ConnectorSession session, Schema return new ConnectorTableMetadata(schemaTableName, builder.build()); } - @Override - public ConnectorTableProperties getTableProperties(ConnectorSession session, ConnectorTableHandle table) - { - return new ConnectorTableProperties(); - } - @Override public Optional> applyFilter(ConnectorSession session, ConnectorTableHandle table, Constraint constraint) { diff --git a/plugin/trino-kinesis/src/main/java/io/trino/plugin/kinesis/KinesisMetadata.java b/plugin/trino-kinesis/src/main/java/io/trino/plugin/kinesis/KinesisMetadata.java index f4bdef956de0..6f29ad35d03d 100644 --- a/plugin/trino-kinesis/src/main/java/io/trino/plugin/kinesis/KinesisMetadata.java +++ b/plugin/trino-kinesis/src/main/java/io/trino/plugin/kinesis/KinesisMetadata.java @@ -23,7 +23,6 @@ import io.trino.spi.connector.ConnectorSession; import io.trino.spi.connector.ConnectorTableHandle; import io.trino.spi.connector.ConnectorTableMetadata; -import io.trino.spi.connector.ConnectorTableProperties; import io.trino.spi.connector.SchemaTableName; import io.trino.spi.connector.SchemaTablePrefix; import io.trino.spi.connector.TableNotFoundException; @@ -85,12 +84,6 @@ public ConnectorTableMetadata getTableMetadata(ConnectorSession connectorSession return getTableMetadata(((KinesisTableHandle) tableHandle).toSchemaTableName()); } - @Override - public ConnectorTableProperties getTableProperties(ConnectorSession session, ConnectorTableHandle table) - { - return new ConnectorTableProperties(); - } - @Override public List listTables(ConnectorSession session, Optional schemaName) { diff --git a/plugin/trino-local-file/src/main/java/io/trino/plugin/localfile/LocalFileMetadata.java b/plugin/trino-local-file/src/main/java/io/trino/plugin/localfile/LocalFileMetadata.java index 7d757e3eb2c2..032fbeae7621 100644 --- a/plugin/trino-local-file/src/main/java/io/trino/plugin/localfile/LocalFileMetadata.java +++ b/plugin/trino-local-file/src/main/java/io/trino/plugin/localfile/LocalFileMetadata.java @@ -21,7 +21,6 @@ import io.trino.spi.connector.ConnectorSession; import io.trino.spi.connector.ConnectorTableHandle; import io.trino.spi.connector.ConnectorTableMetadata; -import io.trino.spi.connector.ConnectorTableProperties; import io.trino.spi.connector.Constraint; import io.trino.spi.connector.ConstraintApplicationResult; import io.trino.spi.connector.SchemaTableName; @@ -133,12 +132,6 @@ private List listTables(ConnectorSession session, SchemaTablePr return ImmutableList.of(prefix.toSchemaTableName()); } - @Override - public ConnectorTableProperties getTableProperties(ConnectorSession session, ConnectorTableHandle tableHandle) - { - return new ConnectorTableProperties(); - } - @Override public Optional> applyFilter(ConnectorSession session, ConnectorTableHandle table, Constraint constraint) { diff --git a/plugin/trino-memory/src/main/java/io/trino/plugin/memory/MemoryMetadata.java b/plugin/trino-memory/src/main/java/io/trino/plugin/memory/MemoryMetadata.java index ab19f25f75f5..0e167cac5494 100644 --- a/plugin/trino-memory/src/main/java/io/trino/plugin/memory/MemoryMetadata.java +++ b/plugin/trino-memory/src/main/java/io/trino/plugin/memory/MemoryMetadata.java @@ -32,7 +32,6 @@ import io.trino.spi.connector.ConnectorTableHandle; import io.trino.spi.connector.ConnectorTableLayout; import io.trino.spi.connector.ConnectorTableMetadata; -import io.trino.spi.connector.ConnectorTableProperties; import io.trino.spi.connector.ConnectorViewDefinition; import io.trino.spi.connector.LimitApplicationResult; import io.trino.spi.connector.RetryMode; @@ -421,12 +420,6 @@ private void updateRowsOnHosts(long tableId, Collection fragments) tables.put(tableId, new TableInfo(tableId, info.getSchemaName(), info.getTableName(), info.getColumns(), dataFragments, info.getComment())); } - @Override - public ConnectorTableProperties getTableProperties(ConnectorSession session, ConnectorTableHandle table) - { - return new ConnectorTableProperties(); - } - public List getDataFragments(long tableId) { return ImmutableList.copyOf(tables.get(tableId).getDataFragments().values()); diff --git a/plugin/trino-pinot/src/main/java/io/trino/plugin/pinot/PinotMetadata.java b/plugin/trino-pinot/src/main/java/io/trino/plugin/pinot/PinotMetadata.java index 0e4a574aeb87..6710fbfca8e1 100755 --- a/plugin/trino-pinot/src/main/java/io/trino/plugin/pinot/PinotMetadata.java +++ b/plugin/trino-pinot/src/main/java/io/trino/plugin/pinot/PinotMetadata.java @@ -43,7 +43,6 @@ import io.trino.spi.connector.ConnectorSession; import io.trino.spi.connector.ConnectorTableHandle; import io.trino.spi.connector.ConnectorTableMetadata; -import io.trino.spi.connector.ConnectorTableProperties; import io.trino.spi.connector.Constraint; import io.trino.spi.connector.ConstraintApplicationResult; import io.trino.spi.connector.LimitApplicationResult; @@ -234,12 +233,6 @@ public Optional getInfo(ConnectorTableHandle table) return Optional.empty(); } - @Override - public ConnectorTableProperties getTableProperties(ConnectorSession session, ConnectorTableHandle table) - { - return new ConnectorTableProperties(); - } - @Override public Optional> applyLimit(ConnectorSession session, ConnectorTableHandle table, long limit) { diff --git a/plugin/trino-prometheus/src/main/java/io/trino/plugin/prometheus/PrometheusMetadata.java b/plugin/trino-prometheus/src/main/java/io/trino/plugin/prometheus/PrometheusMetadata.java index 46cafc5c5d8c..4e5e79f4703e 100644 --- a/plugin/trino-prometheus/src/main/java/io/trino/plugin/prometheus/PrometheusMetadata.java +++ b/plugin/trino-prometheus/src/main/java/io/trino/plugin/prometheus/PrometheusMetadata.java @@ -22,7 +22,6 @@ import io.trino.spi.connector.ConnectorSession; import io.trino.spi.connector.ConnectorTableHandle; import io.trino.spi.connector.ConnectorTableMetadata; -import io.trino.spi.connector.ConnectorTableProperties; import io.trino.spi.connector.Constraint; import io.trino.spi.connector.ConstraintApplicationResult; import io.trino.spi.connector.SchemaTableName; @@ -155,12 +154,6 @@ public ColumnMetadata getColumnMetadata(ConnectorSession session, ConnectorTable return ((PrometheusColumnHandle) columnHandle).getColumnMetadata(); } - @Override - public ConnectorTableProperties getTableProperties(ConnectorSession session, ConnectorTableHandle table) - { - return new ConnectorTableProperties(); - } - @Override public Optional> applyFilter(ConnectorSession session, ConnectorTableHandle handle, Constraint constraint) { diff --git a/plugin/trino-redis/src/main/java/io/trino/plugin/redis/RedisMetadata.java b/plugin/trino-redis/src/main/java/io/trino/plugin/redis/RedisMetadata.java index cff52758efa4..e2505e399d5e 100644 --- a/plugin/trino-redis/src/main/java/io/trino/plugin/redis/RedisMetadata.java +++ b/plugin/trino-redis/src/main/java/io/trino/plugin/redis/RedisMetadata.java @@ -25,7 +25,6 @@ import io.trino.spi.connector.ConnectorSession; import io.trino.spi.connector.ConnectorTableHandle; import io.trino.spi.connector.ConnectorTableMetadata; -import io.trino.spi.connector.ConnectorTableProperties; import io.trino.spi.connector.Constraint; import io.trino.spi.connector.ConstraintApplicationResult; import io.trino.spi.connector.SchemaTableName; @@ -276,12 +275,6 @@ public ColumnMetadata getColumnMetadata(ConnectorSession session, ConnectorTable return ((RedisColumnHandle) columnHandle).getColumnMetadata(); } - @Override - public ConnectorTableProperties getTableProperties(ConnectorSession session, ConnectorTableHandle tableHandle) - { - return new ConnectorTableProperties(); - } - @VisibleForTesting Map getDefinedTables() { diff --git a/plugin/trino-thrift/src/main/java/io/trino/plugin/thrift/ThriftMetadata.java b/plugin/trino-thrift/src/main/java/io/trino/plugin/thrift/ThriftMetadata.java index 56f96a5e9cd5..1ab22188272e 100644 --- a/plugin/trino-thrift/src/main/java/io/trino/plugin/thrift/ThriftMetadata.java +++ b/plugin/trino-thrift/src/main/java/io/trino/plugin/thrift/ThriftMetadata.java @@ -36,7 +36,6 @@ import io.trino.spi.connector.ConnectorSession; import io.trino.spi.connector.ConnectorTableHandle; import io.trino.spi.connector.ConnectorTableMetadata; -import io.trino.spi.connector.ConnectorTableProperties; import io.trino.spi.connector.Constraint; import io.trino.spi.connector.ConstraintApplicationResult; import io.trino.spi.connector.ProjectionApplicationResult; @@ -164,12 +163,6 @@ public Optional resolveIndex(ConnectorSession session, C return Optional.empty(); } - @Override - public ConnectorTableProperties getTableProperties(ConnectorSession session, ConnectorTableHandle table) - { - return new ConnectorTableProperties(); - } - @Override public Optional> applyFilter(ConnectorSession session, ConnectorTableHandle table, Constraint constraint) { diff --git a/plugin/trino-tpcds/src/main/java/io/trino/plugin/tpcds/TpcdsMetadata.java b/plugin/trino-tpcds/src/main/java/io/trino/plugin/tpcds/TpcdsMetadata.java index a917bad9b8ff..51c5fb6ca083 100644 --- a/plugin/trino-tpcds/src/main/java/io/trino/plugin/tpcds/TpcdsMetadata.java +++ b/plugin/trino-tpcds/src/main/java/io/trino/plugin/tpcds/TpcdsMetadata.java @@ -23,7 +23,6 @@ import io.trino.spi.connector.ConnectorSession; import io.trino.spi.connector.ConnectorTableHandle; import io.trino.spi.connector.ConnectorTableMetadata; -import io.trino.spi.connector.ConnectorTableProperties; import io.trino.spi.connector.SchemaTableName; import io.trino.spi.connector.SchemaTablePrefix; import io.trino.spi.statistics.TableStatistics; @@ -98,12 +97,6 @@ public ConnectorTableHandle getTableHandle(ConnectorSession session, SchemaTable return new TpcdsTableHandle(tableName.getTableName(), scaleFactor); } - @Override - public ConnectorTableProperties getTableProperties(ConnectorSession session, ConnectorTableHandle table) - { - return new ConnectorTableProperties(); - } - @Override public ConnectorTableMetadata getTableMetadata(ConnectorSession session, ConnectorTableHandle tableHandle) { From 1a79defa298246bd69fe41a0a4c53e6ddedf7a87 Mon Sep 17 00:00:00 2001 From: Maxim Lukyanenko Date: Mon, 5 Dec 2022 14:18:51 +0200 Subject: [PATCH 21/24] Reproduce kafka protobuf schema parsing stack overflow error Kafka protobuf registry schema cannot be translated to plain Trino SQL structure in common cases when included data types use referencing to same objects recursively. --- plugin/trino-kafka/pom.xml | 39 +++++++++++++++++++ ...ithSchemaRegistryMinimalFunctionality.java | 28 +++++++++++++ .../protobuf-sources/unsupported_nested.proto | 24 ++++++++++++ 3 files changed, 91 insertions(+) create mode 100644 plugin/trino-kafka/src/test/resources/protobuf-sources/unsupported_nested.proto diff --git a/plugin/trino-kafka/pom.xml b/plugin/trino-kafka/pom.xml index 62d3900dc92f..9506552bd9c9 100644 --- a/plugin/trino-kafka/pom.xml +++ b/plugin/trino-kafka/pom.xml @@ -339,6 +339,45 @@ + + com.github.os72 + protoc-jar-maven-plugin + + + generate-test-sources + generate-test-sources + + run + + + ${dep.protobuf.version} + none + + src/test/resources/protobuf-sources + + target/generated-test-sources/ + + + + + + org.codehaus.mojo + build-helper-maven-plugin + + + add-test-sources + generate-test-sources + + add-test-source + + + + ${basedir}/target/generated-test-sources + + + + + diff --git a/plugin/trino-kafka/src/test/java/io/trino/plugin/kafka/protobuf/TestKafkaProtobufWithSchemaRegistryMinimalFunctionality.java b/plugin/trino-kafka/src/test/java/io/trino/plugin/kafka/protobuf/TestKafkaProtobufWithSchemaRegistryMinimalFunctionality.java index 9b444098e29e..6278769c6cb7 100644 --- a/plugin/trino-kafka/src/test/java/io/trino/plugin/kafka/protobuf/TestKafkaProtobufWithSchemaRegistryMinimalFunctionality.java +++ b/plugin/trino-kafka/src/test/java/io/trino/plugin/kafka/protobuf/TestKafkaProtobufWithSchemaRegistryMinimalFunctionality.java @@ -140,6 +140,34 @@ public void testBasicTopicForInsert() "Insert is not supported for schema registry based tables"); } + @Test + public void testUnsupportedNestedDataTypes() + throws Exception + { + String topic = "topic-unsupported-nested"; + assertNotExists(topic); + + UnsupportedNestedTypes.schema message = UnsupportedNestedTypes.schema.newBuilder() + .setNestedValueOne(UnsupportedNestedTypes.NestedValue.newBuilder().setStringValue("Value1").build()) + .build(); + + ImmutableList.Builder> producerRecordBuilder = ImmutableList.builder(); + producerRecordBuilder.add(new ProducerRecord<>(topic, createKeySchema(0, getKeySchema()), message)); + List> messages = producerRecordBuilder.build(); + testingKafka.sendMessages( + messages.stream(), + ImmutableMap.of( + SCHEMA_REGISTRY_URL_CONFIG, testingKafka.getSchemaRegistryConnectString(), + KEY_SERIALIZER_CLASS_CONFIG, KafkaProtobufSerializer.class.getName(), + VALUE_SERIALIZER_CLASS_CONFIG, KafkaProtobufSerializer.class.getName())); + + // any call to kafka topic will trigger schema parsing what's get exceptional failure, + // so here is waiting for some time period and invoke query + Thread.sleep(2000); + assertQueryFails("SELECT * FROM " + toDoubleQuoted(topic), + "statement is too large \\(stack overflow during analysis\\)"); + } + private Map producerProperties() { return ImmutableMap.of( diff --git a/plugin/trino-kafka/src/test/resources/protobuf-sources/unsupported_nested.proto b/plugin/trino-kafka/src/test/resources/protobuf-sources/unsupported_nested.proto new file mode 100644 index 000000000000..c6c0dd3b856a --- /dev/null +++ b/plugin/trino-kafka/src/test/resources/protobuf-sources/unsupported_nested.proto @@ -0,0 +1,24 @@ +syntax = "proto3"; + +package io.trino.protobuf; + +option java_package = "io.trino.plugin.kafka.protobuf"; +option java_outer_classname = "UnsupportedNestedTypes"; + +message schema { + NestedValue nested_value_one = 1; +} + +message NestedStruct { + map fields = 1; +} + +message NestedValue { + string string_value = 1; + NestedStruct struct_value = 2; + NestedListValue list_value = 3; +} + +message NestedListValue { + repeated NestedValue values = 1; +} From e669359fe90bde7bdaf08015cad80cc7ea6cf2ab Mon Sep 17 00:00:00 2001 From: Maxim Lukyanenko Date: Wed, 7 Dec 2022 18:41:22 +0200 Subject: [PATCH 22/24] Reproduce NPE on Kafka sending when proto file has `import` Message sending to Kafka thows NPE if proto file has `import` derective. --- .../trino/decoder/protobuf/ProtobufUtils.java | 10 ++- ...ithSchemaRegistryMinimalFunctionality.java | 79 +++++++++++++++++++ .../protobuf/structural_datatypes.proto | 31 ++++++++ 3 files changed, 118 insertions(+), 2 deletions(-) create mode 100644 plugin/trino-kafka/src/test/resources/protobuf/structural_datatypes.proto diff --git a/lib/trino-record-decoder/src/main/java/io/trino/decoder/protobuf/ProtobufUtils.java b/lib/trino-record-decoder/src/main/java/io/trino/decoder/protobuf/ProtobufUtils.java index 02082c37c612..bfdaaff59aee 100644 --- a/lib/trino-record-decoder/src/main/java/io/trino/decoder/protobuf/ProtobufUtils.java +++ b/lib/trino-record-decoder/src/main/java/io/trino/decoder/protobuf/ProtobufUtils.java @@ -71,9 +71,15 @@ private ProtobufUtils() public static FileDescriptor getFileDescriptor(String protoFile) throws DescriptorValidationException + { + return getFileDescriptor(Optional.empty(), protoFile); + } + + public static FileDescriptor getFileDescriptor(Optional fileName, String protoFile) + throws DescriptorValidationException { ProtoFileElement protoFileElement = ProtoParser.Companion.parse(Location.get(""), protoFile); - return getFileDescriptor(Optional.empty(), protoFileElement); + return getFileDescriptor(fileName, protoFileElement); } public static FileDescriptor getFileDescriptor(Optional fileName, ProtoFileElement protoFileElement) @@ -84,7 +90,7 @@ public static FileDescriptor getFileDescriptor(Optional fileName, ProtoF int index = 0; for (String importStatement : protoFileElement.getImports()) { try { - FileDescriptor fileDescriptor = getFileDescriptor(getProtoFile(importStatement)); + FileDescriptor fileDescriptor = getFileDescriptor(Optional.of(importStatement), getProtoFile(importStatement)); fileDescriptor.getMessageTypes().stream() .map(Descriptor::getFullName) .forEach(definedMessages::add); diff --git a/plugin/trino-kafka/src/test/java/io/trino/plugin/kafka/protobuf/TestKafkaProtobufWithSchemaRegistryMinimalFunctionality.java b/plugin/trino-kafka/src/test/java/io/trino/plugin/kafka/protobuf/TestKafkaProtobufWithSchemaRegistryMinimalFunctionality.java index 6278769c6cb7..5cb9d5f11e72 100644 --- a/plugin/trino-kafka/src/test/java/io/trino/plugin/kafka/protobuf/TestKafkaProtobufWithSchemaRegistryMinimalFunctionality.java +++ b/plugin/trino-kafka/src/test/java/io/trino/plugin/kafka/protobuf/TestKafkaProtobufWithSchemaRegistryMinimalFunctionality.java @@ -18,10 +18,12 @@ import com.google.protobuf.Descriptors.Descriptor; import com.google.protobuf.Descriptors.FieldDescriptor; import com.google.protobuf.DynamicMessage; +import com.google.protobuf.Timestamp; import io.confluent.kafka.serializers.protobuf.KafkaProtobufSerializer; import io.confluent.kafka.serializers.subject.RecordNameStrategy; import io.confluent.kafka.serializers.subject.TopicRecordNameStrategy; import io.trino.plugin.kafka.schema.confluent.KafkaWithConfluentSchemaRegistryQueryRunner; +import io.trino.spi.type.SqlTimestamp; import io.trino.testing.AbstractTestQueryFramework; import io.trino.testing.QueryRunner; import io.trino.testing.kafka.TestingKafka; @@ -31,6 +33,7 @@ import org.testng.annotations.Test; import java.time.Duration; +import java.time.LocalDateTime; import java.util.List; import java.util.Map; @@ -41,12 +44,20 @@ import static io.trino.decoder.protobuf.ProtobufRowDecoderFactory.DEFAULT_MESSAGE; import static io.trino.decoder.protobuf.ProtobufUtils.getFileDescriptor; import static io.trino.decoder.protobuf.ProtobufUtils.getProtoFile; +import static io.trino.spi.type.Timestamps.MICROSECONDS_PER_SECOND; +import static io.trino.spi.type.Timestamps.NANOSECONDS_PER_MICROSECOND; +import static io.trino.testing.DateTimeTestingUtils.sqlTimestampOf; +import static java.lang.Math.PI; +import static java.lang.Math.floorDiv; import static java.lang.Math.multiplyExact; +import static java.lang.StrictMath.floorMod; import static java.lang.String.format; +import static java.nio.charset.StandardCharsets.UTF_8; import static java.util.Locale.ENGLISH; import static java.util.Objects.requireNonNull; import static org.apache.kafka.clients.producer.ProducerConfig.KEY_SERIALIZER_CLASS_CONFIG; import static org.apache.kafka.clients.producer.ProducerConfig.VALUE_SERIALIZER_CLASS_CONFIG; +import static org.assertj.core.api.AssertionsForClassTypes.assertThatThrownBy; import static org.testng.Assert.assertTrue; @Test(singleThreaded = true) @@ -168,6 +179,74 @@ public void testUnsupportedNestedDataTypes() "statement is too large \\(stack overflow during analysis\\)"); } + @Test + public void testStructuralDataTypes() + throws Exception + { + String topic = "topic-structural"; + assertNotExists(topic); + + Descriptor descriptor = getDescriptor("structural_datatypes.proto"); + + Timestamp timestamp = getTimestamp(sqlTimestampOf(3, LocalDateTime.parse("2020-12-12T15:35:45.923"))); + DynamicMessage message = buildDynamicMessage( + descriptor, + ImmutableMap.builder() + .put("list", ImmutableList.of("Search")) + .put("map", ImmutableList.of(buildDynamicMessage( + descriptor.findFieldByName("map").getMessageType(), + ImmutableMap.of("key", "Key1", "value", "Value1")))) + .put("row", ImmutableMap.builder() + .put("string_column", "Trino") + .put("integer_column", 1) + .put("long_column", 493857959588286460L) + .put("double_column", PI) + .put("float_column", 3.14f) + .put("boolean_column", true) + .put("number_column", descriptor.findEnumTypeByName("Number").findValueByName("ONE")) + .put("timestamp_column", timestamp) + .put("bytes_column", "Trino".getBytes(UTF_8)) + .buildOrThrow()) + .buildOrThrow()); + + ImmutableList.Builder> producerRecordBuilder = ImmutableList.builder(); + producerRecordBuilder.add(new ProducerRecord<>(topic, createKeySchema(0, getKeySchema()), message)); + List> messages = producerRecordBuilder.build(); + assertThatThrownBy(() -> { + testingKafka.sendMessages( + messages.stream(), + ImmutableMap.of( + SCHEMA_REGISTRY_URL_CONFIG, testingKafka.getSchemaRegistryConnectString(), + KEY_SERIALIZER_CLASS_CONFIG, KafkaProtobufSerializer.class.getName(), + VALUE_SERIALIZER_CLASS_CONFIG, KafkaProtobufSerializer.class.getName())); + }).isInstanceOf(NullPointerException.class) + .hasMessage("Cannot invoke \"com.squareup.wire.schema.internal.parser.ProtoFileElement.getImports()\" because \"protoFileElement\" is null"); + } + + private DynamicMessage buildDynamicMessage(Descriptor descriptor, Map data) + { + DynamicMessage.Builder builder = DynamicMessage.newBuilder(descriptor); + for (Map.Entry entry : data.entrySet()) { + FieldDescriptor fieldDescriptor = descriptor.findFieldByName(entry.getKey()); + if (entry.getValue() instanceof Map) { + builder.setField(fieldDescriptor, buildDynamicMessage(fieldDescriptor.getMessageType(), (Map) entry.getValue())); + } + else { + builder.setField(fieldDescriptor, entry.getValue()); + } + } + + return builder.build(); + } + + protected static Timestamp getTimestamp(SqlTimestamp sqlTimestamp) + { + return Timestamp.newBuilder() + .setSeconds(floorDiv(sqlTimestamp.getEpochMicros(), MICROSECONDS_PER_SECOND)) + .setNanos(floorMod(sqlTimestamp.getEpochMicros(), MICROSECONDS_PER_SECOND) * NANOSECONDS_PER_MICROSECOND) + .build(); + } + private Map producerProperties() { return ImmutableMap.of( diff --git a/plugin/trino-kafka/src/test/resources/protobuf/structural_datatypes.proto b/plugin/trino-kafka/src/test/resources/protobuf/structural_datatypes.proto new file mode 100644 index 000000000000..854bd674dec5 --- /dev/null +++ b/plugin/trino-kafka/src/test/resources/protobuf/structural_datatypes.proto @@ -0,0 +1,31 @@ +syntax = "proto3"; + +import "google/protobuf/timestamp.proto"; + +message schema { + repeated string list = 1; + map map = 2; + enum Number { + ZERO = 0; + ONE = 1; + TWO = 2; + }; + message Row { + string string_column = 1; + uint32 integer_column = 2; + uint64 long_column = 3; + double double_column = 4; + float float_column = 5; + bool boolean_column = 6; + Number number_column = 7; + google.protobuf.Timestamp timestamp_column = 8; + bytes bytes_column = 9; + }; + Row row = 3; + message NestedRow { + repeated Row nested_list = 1; + map nested_map = 2; + Row row = 3; + }; + NestedRow nested_row = 4; +} From e5b3590ebe7f627db4fae32af6ad353a6d3bfbc5 Mon Sep 17 00:00:00 2001 From: Maxim Lukyanenko Date: Wed, 7 Dec 2022 18:44:30 +0200 Subject: [PATCH 23/24] Use declared variable for descriptor --- .../plugin/kafka/encoder/protobuf/ProtobufSchemaParser.java | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/plugin/trino-kafka/src/main/java/io/trino/plugin/kafka/encoder/protobuf/ProtobufSchemaParser.java b/plugin/trino-kafka/src/main/java/io/trino/plugin/kafka/encoder/protobuf/ProtobufSchemaParser.java index d9b3e808c5a9..6187962ef51e 100644 --- a/plugin/trino-kafka/src/main/java/io/trino/plugin/kafka/encoder/protobuf/ProtobufSchemaParser.java +++ b/plugin/trino-kafka/src/main/java/io/trino/plugin/kafka/encoder/protobuf/ProtobufSchemaParser.java @@ -98,7 +98,7 @@ private Type getType(FieldDescriptor fieldDescriptor) private Type getTypeForMessage(FieldDescriptor fieldDescriptor) { Descriptor descriptor = fieldDescriptor.getMessageType(); - if (fieldDescriptor.getMessageType().getFullName().equals(TIMESTAMP_TYPE_NAME)) { + if (descriptor.getFullName().equals(TIMESTAMP_TYPE_NAME)) { return createTimestampType(6); } if (fieldDescriptor.isMapField()) { @@ -108,7 +108,7 @@ private Type getTypeForMessage(FieldDescriptor fieldDescriptor) typeManager.getTypeOperators()); } return RowType.from( - fieldDescriptor.getMessageType().getFields().stream() + descriptor.getFields().stream() .map(field -> RowType.field(field.getName(), getType(field))) .collect(toImmutableList())); } From 4cad36489a4a8fd071eadba2490665600aa9eae0 Mon Sep 17 00:00:00 2001 From: Maxim Lukyanenko Date: Mon, 5 Dec 2022 23:09:05 +0200 Subject: [PATCH 24/24] Prevent protobuf schema message parsing dead loop recursion Kafka protobuf schema registry parsing falls to dead loop if has included references to the same object. It's need to abort the recursion loop with appropriated parsing error. --- plugin/trino-kafka/pom.xml | 5 ++++ .../protobuf/ProtobufSchemaParser.java | 29 ++++++++++++++----- ...ithSchemaRegistryMinimalFunctionality.java | 10 ++++--- 3 files changed, 32 insertions(+), 12 deletions(-) diff --git a/plugin/trino-kafka/pom.xml b/plugin/trino-kafka/pom.xml index 9506552bd9c9..8646112fb62a 100644 --- a/plugin/trino-kafka/pom.xml +++ b/plugin/trino-kafka/pom.xml @@ -138,6 +138,11 @@ kafka-clients + + org.pcollections + pcollections + + io.airlift diff --git a/plugin/trino-kafka/src/main/java/io/trino/plugin/kafka/encoder/protobuf/ProtobufSchemaParser.java b/plugin/trino-kafka/src/main/java/io/trino/plugin/kafka/encoder/protobuf/ProtobufSchemaParser.java index 6187962ef51e..8fed11c24eea 100644 --- a/plugin/trino-kafka/src/main/java/io/trino/plugin/kafka/encoder/protobuf/ProtobufSchemaParser.java +++ b/plugin/trino-kafka/src/main/java/io/trino/plugin/kafka/encoder/protobuf/ProtobufSchemaParser.java @@ -21,18 +21,22 @@ import io.trino.plugin.kafka.KafkaTopicFieldDescription; import io.trino.plugin.kafka.KafkaTopicFieldGroup; import io.trino.plugin.kafka.schema.confluent.SchemaParser; +import io.trino.spi.TrinoException; import io.trino.spi.connector.ConnectorSession; import io.trino.spi.type.ArrayType; import io.trino.spi.type.MapType; import io.trino.spi.type.RowType; import io.trino.spi.type.Type; import io.trino.spi.type.TypeManager; +import org.pcollections.Empty; +import org.pcollections.PSet; import javax.inject.Inject; import java.util.Optional; import static com.google.common.collect.ImmutableList.toImmutableList; +import static io.trino.spi.StandardErrorCode.NOT_SUPPORTED; import static io.trino.spi.type.BigintType.BIGINT; import static io.trino.spi.type.BooleanType.BOOLEAN; import static io.trino.spi.type.DoubleType.DOUBLE; @@ -41,6 +45,7 @@ import static io.trino.spi.type.TimestampType.createTimestampType; import static io.trino.spi.type.VarbinaryType.VARBINARY; import static io.trino.spi.type.VarcharType.createUnboundedVarcharType; +import static java.lang.String.join; import static java.util.Objects.requireNonNull; public class ProtobufSchemaParser @@ -66,7 +71,7 @@ public KafkaTopicFieldGroup parse(ConnectorSession session, String subject, Pars protobufSchema.toDescriptor().getFields().stream() .map(field -> new KafkaTopicFieldDescription( field.getName(), - getType(field), + getType(field, Empty.orderedSet()), field.getName(), null, null, @@ -75,7 +80,7 @@ public KafkaTopicFieldGroup parse(ConnectorSession session, String subject, Pars .collect(toImmutableList())); } - private Type getType(FieldDescriptor fieldDescriptor) + private Type getType(FieldDescriptor fieldDescriptor, PSet processedMessages) { Type baseType = switch (fieldDescriptor.getJavaType()) { case BOOLEAN -> BOOLEAN; @@ -85,31 +90,39 @@ private Type getType(FieldDescriptor fieldDescriptor) case DOUBLE -> DOUBLE; case BYTE_STRING -> VARBINARY; case STRING, ENUM -> createUnboundedVarcharType(); - case MESSAGE -> getTypeForMessage(fieldDescriptor); + case MESSAGE -> getTypeForMessage(fieldDescriptor, processedMessages); }; - // Protobuf does not support adding repeated label for map type but schema registry incorrecty adds it + // Protobuf does not support adding repeated label for map type but schema registry incorrectly adds it if (fieldDescriptor.isRepeated() && !fieldDescriptor.isMapField()) { return new ArrayType(baseType); } return baseType; } - private Type getTypeForMessage(FieldDescriptor fieldDescriptor) + private Type getTypeForMessage(FieldDescriptor fieldDescriptor, PSet processedMessages) { Descriptor descriptor = fieldDescriptor.getMessageType(); if (descriptor.getFullName().equals(TIMESTAMP_TYPE_NAME)) { return createTimestampType(6); } + + if (processedMessages.contains(descriptor.getFullName())) { + throw new TrinoException(NOT_SUPPORTED, "Cannot parse registry schema for nested object with the same object reference: %s > %s" + .formatted(join(" > ", processedMessages), + descriptor.getFullName())); + } + PSet newProcessedMessages = processedMessages.plus(descriptor.getFullName()); + if (fieldDescriptor.isMapField()) { return new MapType( - getType(descriptor.findFieldByNumber(1)), - getType(descriptor.findFieldByNumber(2)), + getType(descriptor.findFieldByNumber(1), newProcessedMessages), + getType(descriptor.findFieldByNumber(2), newProcessedMessages), typeManager.getTypeOperators()); } return RowType.from( descriptor.getFields().stream() - .map(field -> RowType.field(field.getName(), getType(field))) + .map(field -> RowType.field(field.getName(), getType(field, newProcessedMessages))) .collect(toImmutableList())); } } diff --git a/plugin/trino-kafka/src/test/java/io/trino/plugin/kafka/protobuf/TestKafkaProtobufWithSchemaRegistryMinimalFunctionality.java b/plugin/trino-kafka/src/test/java/io/trino/plugin/kafka/protobuf/TestKafkaProtobufWithSchemaRegistryMinimalFunctionality.java index 5cb9d5f11e72..7e8ce07c60f2 100644 --- a/plugin/trino-kafka/src/test/java/io/trino/plugin/kafka/protobuf/TestKafkaProtobufWithSchemaRegistryMinimalFunctionality.java +++ b/plugin/trino-kafka/src/test/java/io/trino/plugin/kafka/protobuf/TestKafkaProtobufWithSchemaRegistryMinimalFunctionality.java @@ -172,11 +172,13 @@ public void testUnsupportedNestedDataTypes() KEY_SERIALIZER_CLASS_CONFIG, KafkaProtobufSerializer.class.getName(), VALUE_SERIALIZER_CLASS_CONFIG, KafkaProtobufSerializer.class.getName())); - // any call to kafka topic will trigger schema parsing what's get exceptional failure, - // so here is waiting for some time period and invoke query - Thread.sleep(2000); + waitUntilTableExists(topic); assertQueryFails("SELECT * FROM " + toDoubleQuoted(topic), - "statement is too large \\(stack overflow during analysis\\)"); + "Cannot parse registry schema for nested object with the same object reference: " + + "io.trino.protobuf.NestedValue > " + + "io.trino.protobuf.NestedStruct > " + + "io.trino.protobuf.NestedStruct.FieldsEntry > " + + "io.trino.protobuf.NestedValue"); } @Test