From 87b8e6f68b1875353f39569050cc3989a838b63c Mon Sep 17 00:00:00 2001 From: Dain Sundstrom Date: Sat, 10 Dec 2022 16:14:48 -0800 Subject: [PATCH 01/24] Add Redshift tests --- plugin/trino-redshift/README.md | 20 ++ plugin/trino-redshift/pom.xml | 96 +++++++ .../plugin/redshift/RedshiftClientModule.java | 12 +- .../plugin/redshift/RedshiftQueryRunner.java | 264 ++++++++++++++++++ .../redshift/TestRedshiftConnectorTest.java | 193 +++++++++++++ 5 files changed, 584 insertions(+), 1 deletion(-) create mode 100644 plugin/trino-redshift/README.md create mode 100644 plugin/trino-redshift/src/test/java/io/trino/plugin/redshift/RedshiftQueryRunner.java create mode 100644 plugin/trino-redshift/src/test/java/io/trino/plugin/redshift/TestRedshiftConnectorTest.java diff --git a/plugin/trino-redshift/README.md b/plugin/trino-redshift/README.md new file mode 100644 index 000000000000..16229b145da1 --- /dev/null +++ b/plugin/trino-redshift/README.md @@ -0,0 +1,20 @@ +# Redshift Connector + +To run the Redshift tests you will need to provision a Redshift cluster. The +tests are designed to run on the smallest possible Redshift cluster containing +a single dc2.large instance. Additionally, you will need an S3 bucket +containing TPCH tiny data in Parquet format. 
The files should be named: + +``` +s3:///tpch/tiny/.parquet +``` + +To run the tests set the following system properties: + +``` +test.redshift.jdbc.endpoint=..redshift.amazonaws.com:5439/ +test.redshift.jdbc.user= +test.redshift.jdbc.password= +test.redshift.s3.tpch.tables.root= +test.redshift.iam.role= +``` diff --git a/plugin/trino-redshift/pom.xml b/plugin/trino-redshift/pom.xml index a6270049b7c2..30fea843c921 100644 --- a/plugin/trino-redshift/pom.xml +++ b/plugin/trino-redshift/pom.xml @@ -40,12 +40,36 @@ + + io.airlift + log + runtime + + + + io.airlift + log-manager + runtime + + com.google.guava guava runtime + + net.jodah + failsafe + runtime + + + + org.jdbi + jdbi3-core + runtime + + io.trino @@ -72,9 +96,59 @@ + + io.trino + trino-base-jdbc + test-jar + test + + + + io.trino + trino-main + test + + io.trino trino-main + test-jar + test + + + + io.trino + trino-testing + test + + + + io.trino + trino-testing-services + test + + + + io.trino + trino-tpch + test + + + + io.trino.tpch + tpch + test + + + + io.airlift + testing + test + + + + org.assertj + assertj-core test @@ -84,4 +158,26 @@ test + + + + default + + true + + + + + org.apache.maven.plugins + maven-surefire-plugin + + + **/TestRedshiftConnectorTest.java + + + + + + + diff --git a/plugin/trino-redshift/src/main/java/io/trino/plugin/redshift/RedshiftClientModule.java b/plugin/trino-redshift/src/main/java/io/trino/plugin/redshift/RedshiftClientModule.java index 53e1aee6ac29..dccc6fabff94 100644 --- a/plugin/trino-redshift/src/main/java/io/trino/plugin/redshift/RedshiftClientModule.java +++ b/plugin/trino-redshift/src/main/java/io/trino/plugin/redshift/RedshiftClientModule.java @@ -28,6 +28,8 @@ import io.trino.plugin.jdbc.ptf.Query; import io.trino.spi.ptf.ConnectorTableFunction; +import java.util.Properties; + import static com.google.inject.multibindings.Multibinder.newSetBinder; public class RedshiftClientModule @@ -45,6 +47,14 @@ public void configure(Binder binder) @ForBaseJdbc public 
static ConnectionFactory getConnectionFactory(BaseJdbcConfig config, CredentialProvider credentialProvider) { - return new DriverConnectionFactory(new Driver(), config, credentialProvider); + return new DriverConnectionFactory(new Driver(), config.getConnectionUrl(), getDriverProperties(), credentialProvider); + } + + private static Properties getDriverProperties() + { + Properties properties = new Properties(); + properties.put("reWriteBatchedInserts", "true"); + properties.put("reWriteBatchedInsertsSize", "512"); + return properties; } } diff --git a/plugin/trino-redshift/src/test/java/io/trino/plugin/redshift/RedshiftQueryRunner.java b/plugin/trino-redshift/src/test/java/io/trino/plugin/redshift/RedshiftQueryRunner.java new file mode 100644 index 000000000000..61859dd9b52c --- /dev/null +++ b/plugin/trino-redshift/src/test/java/io/trino/plugin/redshift/RedshiftQueryRunner.java @@ -0,0 +1,264 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.trino.plugin.redshift; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.Streams; +import io.airlift.log.Logger; +import io.airlift.log.Logging; +import io.trino.Session; +import io.trino.metadata.QualifiedObjectName; +import io.trino.plugin.tpch.TpchPlugin; +import io.trino.spi.security.Identity; +import io.trino.testing.DistributedQueryRunner; +import io.trino.testing.MaterializedResult; +import io.trino.testing.QueryRunner; +import io.trino.tpch.TpchTable; +import net.jodah.failsafe.Failsafe; +import net.jodah.failsafe.RetryPolicy; +import org.jdbi.v3.core.HandleConsumer; +import org.jdbi.v3.core.Jdbi; + +import java.time.Duration; +import java.util.HashMap; +import java.util.Map; +import java.util.Set; + +import static io.airlift.testing.Closeables.closeAllSuppress; +import static io.trino.plugin.tpch.TpchMetadata.TINY_SCHEMA_NAME; +import static io.trino.testing.QueryAssertions.copyTable; +import static io.trino.testing.TestingSession.testSessionBuilder; +import static io.trino.testing.assertions.Assert.assertEquals; +import static java.lang.String.format; +import static java.util.Locale.ENGLISH; +import static java.util.Objects.requireNonNull; +import static java.util.stream.Collectors.toUnmodifiableSet; + +public final class RedshiftQueryRunner +{ + private static final Logger log = Logger.get(RedshiftQueryRunner.class); + private static final String JDBC_ENDPOINT = requireSystemProperty("test.redshift.jdbc.endpoint"); + private static final String JDBC_USER = requireSystemProperty("test.redshift.jdbc.user"); + private static final String JDBC_PASSWORD = requireSystemProperty("test.redshift.jdbc.password"); + private static final String S3_TPCH_TABLES_ROOT = requireSystemProperty("test.redshift.s3.tpch.tables.root"); + private static final String IAM_ROLE = requireSystemProperty("test.redshift.iam.role"); + + private static final String TEST_DATABASE = 
"testdb"; + private static final String TEST_CATALOG = "redshift"; + static final String TEST_SCHEMA = "test_schema"; + + private static final String JDBC_URL = "jdbc:redshift://" + JDBC_ENDPOINT + TEST_DATABASE; + + private static final String CONNECTOR_NAME = "redshift"; + private static final String TPCH_CATALOG = "tpch"; + + private static final String GRANTED_USER = "alice"; + private static final String NON_GRANTED_USER = "bob"; + + private RedshiftQueryRunner() {} + + public static DistributedQueryRunner createRedshiftQueryRunner( + Map extraProperties, + Map connectorProperties, + Iterable> tables) + throws Exception + { + return createRedshiftQueryRunner( + createSession(), + extraProperties, + connectorProperties, + tables); + } + + public static DistributedQueryRunner createRedshiftQueryRunner( + Session session, + Map extraProperties, + Map connectorProperties, + Iterable> tables) + throws Exception + { + DistributedQueryRunner.Builder builder = DistributedQueryRunner.builder(session); + extraProperties.forEach(builder::addExtraProperty); + DistributedQueryRunner runner = builder.build(); + try { + runner.installPlugin(new TpchPlugin()); + runner.createCatalog(TPCH_CATALOG, "tpch", Map.of()); + + Map properties = new HashMap<>(connectorProperties); + properties.putIfAbsent("connection-url", JDBC_URL); + properties.putIfAbsent("connection-user", JDBC_USER); + properties.putIfAbsent("connection-password", JDBC_PASSWORD); + + runner.installPlugin(new RedshiftPlugin()); + runner.createCatalog(TEST_CATALOG, CONNECTOR_NAME, properties); + + executeInRedshift("CREATE SCHEMA IF NOT EXISTS " + TEST_SCHEMA); + createUserIfNotExists(NON_GRANTED_USER, JDBC_PASSWORD); + createUserIfNotExists(GRANTED_USER, JDBC_PASSWORD); + + executeInRedshiftWithRetry(format("GRANT ALL PRIVILEGES ON DATABASE %s TO %s", TEST_DATABASE, GRANTED_USER)); + executeInRedshiftWithRetry(format("GRANT ALL PRIVILEGES ON SCHEMA %s TO %s", TEST_SCHEMA, GRANTED_USER)); + + 
provisionTables(session, runner, tables); + + // This step is necessary for product tests + executeInRedshiftWithRetry(format("GRANT ALL PRIVILEGES ON ALL TABLES IN SCHEMA %s TO %s", TEST_SCHEMA, GRANTED_USER)); + } + catch (Throwable e) { + closeAllSuppress(e, runner); + throw e; + } + return runner; + } + + private static Session createSession() + { + return createSession(GRANTED_USER); + } + + private static Session createSession(String user) + { + return testSessionBuilder() + .setCatalog(TEST_CATALOG) + .setSchema(TEST_SCHEMA) + .setIdentity(Identity.ofUser(user)) + .build(); + } + + private static void createUserIfNotExists(String user, String password) + { + try { + executeInRedshift("CREATE USER " + user + " PASSWORD " + "'" + password + "'"); + } + catch (Exception e) { + // if user already exists, swallow the exception + if (!e.getMessage().matches(".*user \"" + user + "\" already exists.*")) { + throw e; + } + } + } + + private static void executeInRedshiftWithRetry(String sql) + { + Failsafe.with(new RetryPolicy<>() + .handleIf(e -> e.getMessage().matches(".* concurrent transaction .*")) + .withDelay(Duration.ofSeconds(10)) + .withMaxRetries(3)) + .run(() -> executeInRedshift(sql)); + } + + public static void executeInRedshift(String sql, Object... 
parameters) + { + executeInRedshift(handle -> handle.execute(sql, parameters)); + } + + private static void executeInRedshift(HandleConsumer consumer) + throws E + { + Jdbi.create(JDBC_URL, JDBC_USER, JDBC_PASSWORD).withHandle(consumer.asCallback()); + } + + private static synchronized void provisionTables(Session session, QueryRunner queryRunner, Iterable> tables) + { + Set existingTables = queryRunner.listTables(session, session.getCatalog().orElseThrow(), session.getSchema().orElseThrow()) + .stream() + .map(QualifiedObjectName::getObjectName) + .collect(toUnmodifiableSet()); + + Streams.stream(tables) + .map(table -> table.getTableName().toLowerCase(ENGLISH)) + .filter(name -> !existingTables.contains(name)) + .forEach(name -> copyFromS3(queryRunner, session, name)); + + for (TpchTable tpchTable : tables) { + verifyLoadedDataHasSameSchema(session, queryRunner, tpchTable); + } + } + + private static void copyFromS3(QueryRunner queryRunner, Session session, String name) + { + String s3Path = format("%s/%s/%s.parquet", S3_TPCH_TABLES_ROOT, TPCH_CATALOG, name); + log.info("Creating table %s in Redshift copying from %s", name, s3Path); + + // Create table in ephemeral Redshift cluster with no data + String createTableSql = format("CREATE TABLE %s.%s.%s AS ", session.getCatalog().orElseThrow(), session.getSchema().orElseThrow(), name) + + format("SELECT * FROM %s.%s.%s WITH NO DATA", TPCH_CATALOG, TINY_SCHEMA_NAME, name); + queryRunner.execute(session, createTableSql); + + // Copy data from S3 bucket to ephemeral Redshift + String copySql = "COPY " + TEST_SCHEMA + "." + name + + " FROM '" + s3Path + "'" + + " IAM_ROLE '" + IAM_ROLE + "'" + + " FORMAT PARQUET"; + executeInRedshiftWithRetry(copySql); + } + + private static void copyFromTpchCatalog(QueryRunner queryRunner, Session session, String name) + { + // This function exists in case we need to copy data from the TPCH catalog rather than S3, + // such as moving to a new AWS account or if the schema changes. 
We can swap this method out for + // copyFromS3 in provisionTables and then export the data again to S3. + copyTable(queryRunner, TPCH_CATALOG, TINY_SCHEMA_NAME, name, session); + } + + private static void verifyLoadedDataHasSameSchema(Session session, QueryRunner queryRunner, TpchTable tpchTable) + { + // We want to verify that the loaded data has the same schema as if we created a fresh table from the TPC-H catalog + // If this assertion fails, we may need to recreate the Redshift tables from the TPC-H catalog and unload the data to S3 + try { + long expectedCount = (long) queryRunner.execute("SELECT count(*) FROM " + format("%s.%s.%s", TPCH_CATALOG, TINY_SCHEMA_NAME, tpchTable.getTableName())).getOnlyValue(); + long actualCount = (long) queryRunner.execute( + "SELECT count(*) FROM " + format( + "%s.%s.%s", + session.getCatalog().orElseThrow(), + session.getSchema().orElseThrow(), + tpchTable.getTableName())).getOnlyValue(); + + if (expectedCount != actualCount) { + throw new RuntimeException(format("Table %s is not loaded correctly. Expected %s rows got %s", tpchTable.getTableName(), expectedCount, actualCount)); + } + + log.info("Checking column types on table %s", tpchTable.getTableName()); + MaterializedResult expectedColumns = queryRunner.execute(format("DESCRIBE %s.%s.%s", TPCH_CATALOG, TINY_SCHEMA_NAME, tpchTable.getTableName())); + MaterializedResult actualColumns = queryRunner.execute("DESCRIBE " + tpchTable.getTableName()); + assertEquals(actualColumns, expectedColumns); + } + catch (Exception e) { + throw new RuntimeException("Failed to assert columns for TPC-H table " + tpchTable.getTableName(), e); + } + } + + /** + * Get the named system property, throwing an exception if it is not set. 
+ */ + private static String requireSystemProperty(String property) + { + return requireNonNull(System.getProperty(property), property + " is not set"); + } + + public static void main(String[] args) + throws Exception + { + Logging.initialize(); + + DistributedQueryRunner queryRunner = createRedshiftQueryRunner( + ImmutableMap.of("http-server.http.port", "8080"), + ImmutableMap.of(), + ImmutableList.of()); + + log.info("======== SERVER STARTED ========"); + log.info("\n====\n%s\n====", queryRunner.getCoordinator().getBaseUrl()); + } +} diff --git a/plugin/trino-redshift/src/test/java/io/trino/plugin/redshift/TestRedshiftConnectorTest.java b/plugin/trino-redshift/src/test/java/io/trino/plugin/redshift/TestRedshiftConnectorTest.java new file mode 100644 index 000000000000..14bf231c048c --- /dev/null +++ b/plugin/trino-redshift/src/test/java/io/trino/plugin/redshift/TestRedshiftConnectorTest.java @@ -0,0 +1,193 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.trino.plugin.redshift; + +import com.google.common.collect.ImmutableMap; +import io.trino.plugin.jdbc.BaseJdbcConnectorTest; +import io.trino.testing.QueryRunner; +import io.trino.testing.TestingConnectorBehavior; +import io.trino.testing.sql.SqlExecutor; +import io.trino.testing.sql.TestTable; +import io.trino.tpch.TpchTable; +import org.testng.SkipException; +import org.testng.annotations.Test; + +import java.util.Optional; + +import static io.trino.plugin.redshift.RedshiftQueryRunner.TEST_SCHEMA; +import static io.trino.plugin.redshift.RedshiftQueryRunner.createRedshiftQueryRunner; +import static java.lang.String.format; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +public class TestRedshiftConnectorTest + extends BaseJdbcConnectorTest +{ + @Override + protected QueryRunner createQueryRunner() + throws Exception + { + return createRedshiftQueryRunner( + ImmutableMap.of(), + ImmutableMap.of(), + // NOTE this can cause tests to time-out if larger tables like + // lineitem and orders need to be re-created. 
+ TpchTable.getTables()); + } + + @Override + @SuppressWarnings("DuplicateBranchesInSwitch") // options here are grouped per-feature + protected boolean hasBehavior(TestingConnectorBehavior connectorBehavior) + { + switch (connectorBehavior) { + case SUPPORTS_DELETE: + case SUPPORTS_AGGREGATION_PUSHDOWN: + case SUPPORTS_JOIN_PUSHDOWN: + case SUPPORTS_TOPN_PUSHDOWN: + case SUPPORTS_PREDICATE_PUSHDOWN_WITH_VARCHAR_INEQUALITY: + return false; + + case SUPPORTS_COMMENT_ON_TABLE: + case SUPPORTS_ADD_COLUMN_WITH_COMMENT: + case SUPPORTS_CREATE_TABLE_WITH_TABLE_COMMENT: + case SUPPORTS_CREATE_TABLE_WITH_COLUMN_COMMENT: + return false; + + case SUPPORTS_ARRAY: + case SUPPORTS_ROW_TYPE: + return false; + + case SUPPORTS_RENAME_TABLE_ACROSS_SCHEMAS: + return false; + + default: + return super.hasBehavior(connectorBehavior); + } + } + + @Override + protected TestTable createTableWithDefaultColumns() + { + return new TestTable( + onRemoteDatabase(), + format("%s.test_table_with_default_columns", TEST_SCHEMA), + "(col_required BIGINT NOT NULL," + + "col_nullable BIGINT," + + "col_default BIGINT DEFAULT 43," + + "col_nonnull_default BIGINT NOT NULL DEFAULT 42," + + "col_required2 BIGINT NOT NULL)"); + } + + @Override + protected Optional filterDataMappingSmokeTestData(DataMappingTestSetup dataMappingTestSetup) + { + String typeName = dataMappingTestSetup.getTrinoTypeName(); + if ("date".equals(typeName)) { + if (dataMappingTestSetup.getSampleValueLiteral().equals("DATE '1582-10-05'")) { + return Optional.empty(); + } + } + if ("tinyint".equals(typeName) || typeName.startsWith("time") || "varbinary".equals(typeName)) { + return Optional.empty(); + } + return Optional.of(dataMappingTestSetup); + } + + /** + * Overridden due to Redshift not supporting non-ASCII characters in CHAR. 
+ */ + @Override + public void testCreateTableAsSelectWithUnicode() + { + assertThatThrownBy(super::testCreateTableAsSelectWithUnicode) + .hasStackTraceContaining("Value too long for character type"); + // NOTE we add a copy of the above using VARCHAR which supports non-ASCII characters + assertCreateTableAsSelect( + "SELECT CAST('\u2603' AS VARCHAR) unicode", + "SELECT 1"); + } + + @Override + @Test + public void testReadMetadataWithRelationsConcurrentModifications() + { + throw new SkipException("Test fails with a timeout sometimes and is flaky"); + } + + @Override + public void testInsertRowConcurrently() + { + throw new SkipException("Test fails with a timeout sometimes and is flaky"); + } + + @Override + protected String errorMessageForInsertIntoNotNullColumn(String columnName) + { + return format("(?s).*Cannot insert a NULL value into column %s.*", columnName); + } + + @Override + public void testCreateSchemaWithLongName() + { + throw new SkipException("Long name checks not implemented"); + } + + @Override + public void testRenameSchemaToLongName() + { + throw new SkipException("Long name checks not implemented"); + } + + @Override + public void testCreateTableWithLongTableName() + { + throw new SkipException("Long name checks not implemented"); + } + + @Override + public void testRenameTableToLongTableName() + { + throw new SkipException("Long name checks not implemented"); + } + + @Override + public void testCreateTableWithLongColumnName() + { + throw new SkipException("Long name checks not implemented"); + } + + @Override + public void testAlterTableAddLongColumnName() + { + throw new SkipException("Long name checks not implemented"); + } + + @Override + public void testAlterTableRenameColumnToLongName() + { + throw new SkipException("Long name checks not implemented"); + } + + @Override + protected SqlExecutor onRemoteDatabase() + { + return RedshiftQueryRunner::executeInRedshift; + } + + @Test + @Override + public void 
testAddNotNullColumnToNonEmptyTable() + { + throw new SkipException("Redshift ALTER TABLE ADD COLUMN defined as NOT NULL must have a non-null default expression"); + } +} From 0312463d78c43621fb31526f0d472cc12560afa3 Mon Sep 17 00:00:00 2001 From: Dain Sundstrom Date: Sat, 10 Dec 2022 16:57:55 -0800 Subject: [PATCH 02/24] Add Redshift schema, table, and column length checks --- .../trino/plugin/redshift/RedshiftClient.java | 31 ++++++++++++++++++ .../redshift/TestRedshiftConnectorTest.java | 32 ++++++++----------- 2 files changed, 45 insertions(+), 18 deletions(-) diff --git a/plugin/trino-redshift/src/main/java/io/trino/plugin/redshift/RedshiftClient.java b/plugin/trino-redshift/src/main/java/io/trino/plugin/redshift/RedshiftClient.java index af0078309e9f..57f43820c501 100644 --- a/plugin/trino-redshift/src/main/java/io/trino/plugin/redshift/RedshiftClient.java +++ b/plugin/trino-redshift/src/main/java/io/trino/plugin/redshift/RedshiftClient.java @@ -35,6 +35,7 @@ import javax.inject.Inject; import java.sql.Connection; +import java.sql.DatabaseMetaData; import java.sql.PreparedStatement; import java.sql.ResultSet; import java.sql.SQLException; @@ -130,6 +131,36 @@ public PreparedStatement getPreparedStatement(Connection connection, String sql) return statement; } + @Override + protected void verifySchemaName(DatabaseMetaData databaseMetadata, String schemaName) + throws SQLException + { + // Redshift truncates schema name to 127 chars silently + if (schemaName.length() > databaseMetadata.getMaxSchemaNameLength()) { + throw new TrinoException(NOT_SUPPORTED, "Schema name must be shorter than or equal to '%d' characters but got '%d'".formatted(databaseMetadata.getMaxSchemaNameLength(), schemaName.length())); + } + } + + @Override + protected void verifyTableName(DatabaseMetaData databaseMetadata, String tableName) + throws SQLException + { + // Redshift truncates table name to 127 chars silently + if (tableName.length() > databaseMetadata.getMaxTableNameLength()) { + 
throw new TrinoException(NOT_SUPPORTED, "Table name must be shorter than or equal to '%d' characters but got '%d'".formatted(databaseMetadata.getMaxTableNameLength(), tableName.length())); + } + } + + @Override + protected void verifyColumnName(DatabaseMetaData databaseMetadata, String columnName) + throws SQLException + { + // Redshift truncates table name to 127 chars silently + if (columnName.length() > databaseMetadata.getMaxColumnNameLength()) { + throw new TrinoException(NOT_SUPPORTED, "Column name must be shorter than or equal to '%d' characters but got '%d'".formatted(databaseMetadata.getMaxColumnNameLength(), columnName.length())); + } + } + @Override public Optional toColumnMapping(ConnectorSession session, Connection connection, JdbcTypeHandle typeHandle) { diff --git a/plugin/trino-redshift/src/test/java/io/trino/plugin/redshift/TestRedshiftConnectorTest.java b/plugin/trino-redshift/src/test/java/io/trino/plugin/redshift/TestRedshiftConnectorTest.java index 14bf231c048c..0bd3f862e91e 100644 --- a/plugin/trino-redshift/src/test/java/io/trino/plugin/redshift/TestRedshiftConnectorTest.java +++ b/plugin/trino-redshift/src/test/java/io/trino/plugin/redshift/TestRedshiftConnectorTest.java @@ -24,10 +24,12 @@ import org.testng.annotations.Test; import java.util.Optional; +import java.util.OptionalInt; import static io.trino.plugin.redshift.RedshiftQueryRunner.TEST_SCHEMA; import static io.trino.plugin.redshift.RedshiftQueryRunner.createRedshiftQueryRunner; import static java.lang.String.format; +import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; public class TestRedshiftConnectorTest @@ -137,45 +139,39 @@ protected String errorMessageForInsertIntoNotNullColumn(String columnName) } @Override - public void testCreateSchemaWithLongName() + protected OptionalInt maxSchemaNameLength() { - throw new SkipException("Long name checks not implemented"); + return OptionalInt.of(127); } @Override - 
public void testRenameSchemaToLongName() + protected void verifySchemaNameLengthFailurePermissible(Throwable e) { - throw new SkipException("Long name checks not implemented"); + assertThat(e).hasMessage("Schema name must be shorter than or equal to '127' characters but got '128'"); } @Override - public void testCreateTableWithLongTableName() + protected OptionalInt maxTableNameLength() { - throw new SkipException("Long name checks not implemented"); + return OptionalInt.of(127); } @Override - public void testRenameTableToLongTableName() + protected void verifyTableNameLengthFailurePermissible(Throwable e) { - throw new SkipException("Long name checks not implemented"); + assertThat(e).hasMessage("Table name must be shorter than or equal to '127' characters but got '128'"); } @Override - public void testCreateTableWithLongColumnName() + protected OptionalInt maxColumnNameLength() { - throw new SkipException("Long name checks not implemented"); + return OptionalInt.of(127); } @Override - public void testAlterTableAddLongColumnName() + protected void verifyColumnNameLengthFailurePermissible(Throwable e) { - throw new SkipException("Long name checks not implemented"); - } - - @Override - public void testAlterTableRenameColumnToLongName() - { - throw new SkipException("Long name checks not implemented"); + assertThat(e).hasMessage("Column name must be shorter than or equal to '127' characters but got '128'"); } @Override From 7eab69a7df84ba2195229eb175f76fcd16a0d899 Mon Sep 17 00:00:00 2001 From: Dain Sundstrom Date: Sat, 10 Dec 2022 17:51:24 -0800 Subject: [PATCH 03/24] Implement proper type mapping for Redshift --- plugin/trino-redshift/pom.xml | 17 +- .../trino/plugin/redshift/RedshiftClient.java | 446 +++++++- .../plugin/redshift/RedshiftClientModule.java | 15 +- .../plugin/redshift/RedshiftQueryRunner.java | 17 +- .../redshift/TestRedshiftConnectorTest.java | 63 +- .../redshift/TestRedshiftTypeMapping.java | 994 ++++++++++++++++++ 6 files changed, 1528 
insertions(+), 24 deletions(-) create mode 100644 plugin/trino-redshift/src/test/java/io/trino/plugin/redshift/TestRedshiftTypeMapping.java diff --git a/plugin/trino-redshift/pom.xml b/plugin/trino-redshift/pom.xml index 30fea843c921..ebcd96279cda 100644 --- a/plugin/trino-redshift/pom.xml +++ b/plugin/trino-redshift/pom.xml @@ -23,12 +23,22 @@ trino-base-jdbc + + io.airlift + configuration + + com.amazon.redshift redshift-jdbc42 2.1.0.9 + + com.google.guava + guava + + com.google.inject guice @@ -52,12 +62,6 @@ runtime - - com.google.guava - guava - runtime - - net.jodah failsafe @@ -173,6 +177,7 @@ **/TestRedshiftConnectorTest.java + **/TestRedshiftTypeMapping.java diff --git a/plugin/trino-redshift/src/main/java/io/trino/plugin/redshift/RedshiftClient.java b/plugin/trino-redshift/src/main/java/io/trino/plugin/redshift/RedshiftClient.java index 57f43820c501..05d27ab5a0f0 100644 --- a/plugin/trino-redshift/src/main/java/io/trino/plugin/redshift/RedshiftClient.java +++ b/plugin/trino-redshift/src/main/java/io/trino/plugin/redshift/RedshiftClient.java @@ -13,6 +13,10 @@ */ package io.trino.plugin.redshift; +import com.amazon.redshift.jdbc.RedshiftPreparedStatement; +import com.amazon.redshift.util.RedshiftObject; +import com.google.common.base.CharMatcher; +import io.airlift.slice.Slice; import io.trino.plugin.jdbc.BaseJdbcClient; import io.trino.plugin.jdbc.BaseJdbcConfig; import io.trino.plugin.jdbc.ColumnMapping; @@ -20,33 +24,58 @@ import io.trino.plugin.jdbc.JdbcColumnHandle; import io.trino.plugin.jdbc.JdbcTableHandle; import io.trino.plugin.jdbc.JdbcTypeHandle; +import io.trino.plugin.jdbc.LongWriteFunction; +import io.trino.plugin.jdbc.ObjectReadFunction; +import io.trino.plugin.jdbc.ObjectWriteFunction; import io.trino.plugin.jdbc.QueryBuilder; +import io.trino.plugin.jdbc.SliceWriteFunction; +import io.trino.plugin.jdbc.StandardColumnMappings; import io.trino.plugin.jdbc.WriteMapping; import io.trino.plugin.jdbc.logging.RemoteQueryModifier; import 
io.trino.plugin.jdbc.mapping.IdentifierMapping; import io.trino.spi.TrinoException; import io.trino.spi.connector.ConnectorSession; import io.trino.spi.type.CharType; +import io.trino.spi.type.Chars; import io.trino.spi.type.DecimalType; import io.trino.spi.type.Decimals; +import io.trino.spi.type.Int128; +import io.trino.spi.type.LongTimestamp; +import io.trino.spi.type.LongTimestampWithTimeZone; +import io.trino.spi.type.TimeType; +import io.trino.spi.type.TimestampType; +import io.trino.spi.type.TimestampWithTimeZoneType; import io.trino.spi.type.Type; import io.trino.spi.type.VarcharType; import javax.inject.Inject; +import java.math.BigDecimal; +import java.math.BigInteger; +import java.math.MathContext; import java.sql.Connection; import java.sql.DatabaseMetaData; import java.sql.PreparedStatement; import java.sql.ResultSet; import java.sql.SQLException; import java.sql.Types; +import java.time.Instant; +import java.time.LocalDate; +import java.time.LocalDateTime; +import java.time.LocalTime; +import java.time.OffsetDateTime; +import java.time.ZoneOffset; +import java.time.format.DateTimeFormatter; +import java.time.format.DateTimeFormatterBuilder; import java.util.Optional; import java.util.function.BiFunction; +import static io.trino.plugin.jdbc.JdbcErrorCode.JDBC_NON_TRANSIENT_ERROR; import static io.trino.plugin.jdbc.StandardColumnMappings.bigintColumnMapping; import static io.trino.plugin.jdbc.StandardColumnMappings.bigintWriteFunction; import static io.trino.plugin.jdbc.StandardColumnMappings.booleanColumnMapping; import static io.trino.plugin.jdbc.StandardColumnMappings.booleanWriteFunction; +import static io.trino.plugin.jdbc.StandardColumnMappings.charReadFunction; import static io.trino.plugin.jdbc.StandardColumnMappings.charWriteFunction; import static io.trino.plugin.jdbc.StandardColumnMappings.dateColumnMappingUsingSqlDate; import static io.trino.plugin.jdbc.StandardColumnMappings.dateWriteFunctionUsingSqlDate; @@ -57,6 +86,7 @@ import static 
io.trino.plugin.jdbc.StandardColumnMappings.doubleWriteFunction; import static io.trino.plugin.jdbc.StandardColumnMappings.integerColumnMapping; import static io.trino.plugin.jdbc.StandardColumnMappings.integerWriteFunction; +import static io.trino.plugin.jdbc.StandardColumnMappings.longDecimalReadFunction; import static io.trino.plugin.jdbc.StandardColumnMappings.longDecimalWriteFunction; import static io.trino.plugin.jdbc.StandardColumnMappings.realColumnMapping; import static io.trino.plugin.jdbc.StandardColumnMappings.realWriteFunction; @@ -68,29 +98,97 @@ import static io.trino.plugin.jdbc.StandardColumnMappings.tinyintColumnMapping; import static io.trino.plugin.jdbc.StandardColumnMappings.tinyintWriteFunction; import static io.trino.plugin.jdbc.StandardColumnMappings.varbinaryColumnMapping; -import static io.trino.plugin.jdbc.StandardColumnMappings.varbinaryWriteFunction; +import static io.trino.plugin.jdbc.StandardColumnMappings.varbinaryReadFunction; +import static io.trino.plugin.jdbc.StandardColumnMappings.varcharColumnMapping; import static io.trino.plugin.jdbc.StandardColumnMappings.varcharWriteFunction; import static io.trino.plugin.jdbc.TypeHandlingJdbcSessionProperties.getUnsupportedTypeHandling; import static io.trino.plugin.jdbc.UnsupportedTypeHandling.CONVERT_TO_VARCHAR; +import static io.trino.spi.StandardErrorCode.INVALID_ARGUMENTS; import static io.trino.spi.StandardErrorCode.NOT_SUPPORTED; import static io.trino.spi.type.BigintType.BIGINT; import static io.trino.spi.type.BooleanType.BOOLEAN; +import static io.trino.spi.type.CharType.createCharType; +import static io.trino.spi.type.DateTimeEncoding.unpackMillisUtc; import static io.trino.spi.type.DateType.DATE; import static io.trino.spi.type.DecimalType.createDecimalType; import static io.trino.spi.type.DoubleType.DOUBLE; import static io.trino.spi.type.IntegerType.INTEGER; +import static io.trino.spi.type.LongTimestampWithTimeZone.fromEpochSecondsAndFraction; import static 
io.trino.spi.type.RealType.REAL; import static io.trino.spi.type.SmallintType.SMALLINT; +import static io.trino.spi.type.TimeType.TIME_MICROS; +import static io.trino.spi.type.TimeZoneKey.UTC_KEY; +import static io.trino.spi.type.TimestampType.TIMESTAMP_MICROS; import static io.trino.spi.type.TimestampType.TIMESTAMP_MILLIS; +import static io.trino.spi.type.TimestampWithTimeZoneType.TIMESTAMP_TZ_MICROS; +import static io.trino.spi.type.Timestamps.MICROSECONDS_PER_DAY; +import static io.trino.spi.type.Timestamps.MILLISECONDS_PER_SECOND; +import static io.trino.spi.type.Timestamps.NANOSECONDS_PER_MICROSECOND; +import static io.trino.spi.type.Timestamps.NANOSECONDS_PER_MILLISECOND; +import static io.trino.spi.type.Timestamps.PICOSECONDS_PER_MICROSECOND; +import static io.trino.spi.type.Timestamps.PICOSECONDS_PER_NANOSECOND; +import static io.trino.spi.type.Timestamps.roundDiv; import static io.trino.spi.type.TinyintType.TINYINT; import static io.trino.spi.type.VarbinaryType.VARBINARY; +import static io.trino.spi.type.VarcharType.createUnboundedVarcharType; +import static io.trino.spi.type.VarcharType.createVarcharType; +import static java.lang.Math.floorDiv; +import static java.lang.Math.floorMod; import static java.lang.Math.max; +import static java.lang.Math.min; import static java.lang.String.format; +import static java.math.RoundingMode.UNNECESSARY; +import static java.time.temporal.ChronoField.NANO_OF_SECOND; import static java.util.Objects.requireNonNull; public class RedshiftClient extends BaseJdbcClient { + /** + * Redshift does not handle values larger than 64 bits for + * {@code DECIMAL(19, s)}. It supports the full range of values for all + * other precisions. 
+ * + * @see + * Redshift documentation + */ + private static final int REDSHIFT_DECIMAL_CUTOFF_PRECISION = 19; + + static final int REDSHIFT_MAX_DECIMAL_PRECISION = 38; + + /** + * Maximum size of a {@link BigInteger} storing a Redshift {@code DECIMAL} + * with precision {@link #REDSHIFT_DECIMAL_CUTOFF_PRECISION}. + */ + // actual value is 63 + private static final int REDSHIFT_DECIMAL_CUTOFF_BITS = BigInteger.valueOf(Long.MAX_VALUE).bitLength(); + + /** + * Maximum size of a Redshift CHAR column. + * + * @see + * Redshift documentation + */ + private static final int REDSHIFT_MAX_CHAR = 4096; + + /** + * Maximum size of a Redshift VARCHAR column. + * + * @see + * Redshift documentation + */ + static final int REDSHIFT_MAX_VARCHAR = 65535; + + private static final DateTimeFormatter DATE_FORMATTER = DateTimeFormatter.ofPattern("yyy-MM-dd[ G]"); + private static final DateTimeFormatter DATE_TIME_FORMATTER = new DateTimeFormatterBuilder() + .appendPattern("yyy-MM-dd HH:mm:ss") + .optionalStart() + .appendFraction(NANO_OF_SECOND, 0, 6, true) + .optionalEnd() + .appendPattern("[ G]") + .toFormatter(); + private static final OffsetDateTime REDSHIFT_MIN_SUPPORTED_TIMESTAMP_TZ = OffsetDateTime.of(-4712, 1, 1, 0, 0, 0, 0, ZoneOffset.UTC); + @Inject public RedshiftClient(BaseJdbcConfig config, ConnectionFactory connectionFactory, QueryBuilder queryBuilder, IdentifierMapping identifierMapping, RemoteQueryModifier queryModifier) { @@ -162,7 +260,96 @@ protected void verifyColumnName(DatabaseMetaData databaseMetadata, String column } @Override - public Optional toColumnMapping(ConnectorSession session, Connection connection, JdbcTypeHandle typeHandle) + public Optional toColumnMapping(ConnectorSession session, Connection connection, JdbcTypeHandle type) + { + Optional mapping = getForcedMappingToVarchar(type); + if (mapping.isPresent()) { + return mapping; + } + + if ("time".equals(type.getJdbcTypeName().orElse(""))) { + return Optional.of(ColumnMapping.longMapping( + 
TIME_MICROS, + RedshiftClient::readTime, + RedshiftClient::writeTime)); + } + + switch (type.getJdbcType()) { + case Types.BIT: // Redshift uses this for booleans + return Optional.of(booleanColumnMapping()); + + // case Types.TINYINT: -- Redshift doesn't support tinyint + case Types.SMALLINT: + return Optional.of(smallintColumnMapping()); + case Types.INTEGER: + return Optional.of(integerColumnMapping()); + case Types.BIGINT: + return Optional.of(bigintColumnMapping()); + + case Types.REAL: + return Optional.of(realColumnMapping()); + case Types.DOUBLE: + return Optional.of(doubleColumnMapping()); + + case Types.NUMERIC: { + int precision = type.getRequiredColumnSize(); + int scale = type.getRequiredDecimalDigits(); + DecimalType decimalType = createDecimalType(precision, scale); + if (precision == REDSHIFT_DECIMAL_CUTOFF_PRECISION) { + return Optional.of(ColumnMapping.objectMapping( + decimalType, + longDecimalReadFunction(decimalType), + writeDecimalAtRedshiftCutoff(scale))); + } + return Optional.of(decimalColumnMapping(decimalType, UNNECESSARY)); + } + + case Types.CHAR: + CharType charType = createCharType(type.getRequiredColumnSize()); + return Optional.of(ColumnMapping.sliceMapping( + charType, + charReadFunction(charType), + RedshiftClient::writeChar)); + + case Types.VARCHAR: { + int length = type.getRequiredColumnSize(); + return Optional.of(varcharColumnMapping( + length < VarcharType.MAX_LENGTH + ? 
createVarcharType(length) + : createUnboundedVarcharType(), + true)); + } + + case Types.LONGVARBINARY: + return Optional.of(ColumnMapping.sliceMapping( + VARBINARY, + varbinaryReadFunction(), + varbinaryWriteFunction())); + + case Types.DATE: + return Optional.of(ColumnMapping.longMapping( + DATE, + RedshiftClient::readDate, + RedshiftClient::writeDate)); + + case Types.TIMESTAMP: + return Optional.of(ColumnMapping.longMapping( + TIMESTAMP_MICROS, + RedshiftClient::readTimestamp, + RedshiftClient::writeShortTimestamp)); + + case Types.TIMESTAMP_WITH_TIMEZONE: + return Optional.of(ColumnMapping.objectMapping( + TIMESTAMP_TZ_MICROS, + longTimestampWithTimeZoneReadFunction(), + longTimestampWithTimeZoneWriteFunction())); + } + + // Fall back to default behavior + return legacyToColumnMapping(session, type); + } + + private Optional legacyToColumnMapping(ConnectorSession session, JdbcTypeHandle typeHandle) { Optional mapping = getForcedMappingToVarchar(typeHandle); if (mapping.isPresent()) { @@ -181,6 +368,99 @@ public Optional toColumnMapping(ConnectorSession session, Connect @Override public WriteMapping toWriteMapping(ConnectorSession session, Type type) { + if (BOOLEAN.equals(type)) { + return WriteMapping.booleanMapping("boolean", booleanWriteFunction()); + } + if (TINYINT.equals(type)) { + // Redshift doesn't have tinyint + return WriteMapping.longMapping("smallint", tinyintWriteFunction()); + } + if (SMALLINT.equals(type)) { + return WriteMapping.longMapping("smallint", smallintWriteFunction()); + } + if (INTEGER.equals(type)) { + return WriteMapping.longMapping("integer", integerWriteFunction()); + } + if (BIGINT.equals(type)) { + return WriteMapping.longMapping("bigint", bigintWriteFunction()); + } + if (REAL.equals(type)) { + return WriteMapping.longMapping("real", realWriteFunction()); + } + if (DOUBLE.equals(type)) { + return WriteMapping.doubleMapping("double precision", doubleWriteFunction()); + } + + if (type instanceof DecimalType decimal) { + if 
(decimal.getPrecision() == REDSHIFT_DECIMAL_CUTOFF_PRECISION) { + // See doc for REDSHIFT_DECIMAL_CUTOFF_PRECISION + return WriteMapping.objectMapping( + format("decimal(%s, %s)", decimal.getPrecision(), decimal.getScale()), + writeDecimalAtRedshiftCutoff(decimal.getScale())); + } + String name = format("decimal(%s, %s)", decimal.getPrecision(), decimal.getScale()); + return decimal.isShort() + ? WriteMapping.longMapping(name, shortDecimalWriteFunction(decimal)) + : WriteMapping.objectMapping(name, longDecimalWriteFunction(decimal)); + } + + if (type instanceof CharType) { + // Redshift has no unbounded text/binary types, so if a CHAR is too + // large for Redshift, we write as VARCHAR. If too large for that, + // we use the largest VARCHAR Redshift supports. + int size = ((CharType) type).getLength(); + if (size <= REDSHIFT_MAX_CHAR) { + return WriteMapping.sliceMapping( + format("char(%d)", size), + RedshiftClient::writeChar); + } + int redshiftVarcharWidth = min(size, REDSHIFT_MAX_VARCHAR); + return WriteMapping.sliceMapping( + format("varchar(%d)", redshiftVarcharWidth), + (statement, index, value) -> writeCharAsVarchar(statement, index, value, redshiftVarcharWidth)); + } + + if (type instanceof VarcharType) { + // Redshift has no unbounded text/binary types, so if a VARCHAR is + // larger than Redshift's limit, we make it that big instead. 
+ int size = ((VarcharType) type).getLength() + .filter(n -> n <= REDSHIFT_MAX_VARCHAR) + .orElse(REDSHIFT_MAX_VARCHAR); + return WriteMapping.sliceMapping(format("varchar(%d)", size), varcharWriteFunction()); + } + + if (VARBINARY.equals(type)) { + return WriteMapping.sliceMapping("varbyte", varbinaryWriteFunction()); + } + + if (DATE.equals(type)) { + return WriteMapping.longMapping("date", RedshiftClient::writeDate); + } + + if (type instanceof TimeType) { + return WriteMapping.longMapping("time", RedshiftClient::writeTime); + } + + if (type instanceof TimestampType) { + if (((TimestampType) type).isShort()) { + return WriteMapping.longMapping( + "timestamp", + RedshiftClient::writeShortTimestamp); + } + return WriteMapping.objectMapping( + "timestamp", + LongTimestamp.class, + RedshiftClient::writeLongTimestamp); + } + + if (type instanceof TimestampWithTimeZoneType timestampWithTimeZoneType) { + if (timestampWithTimeZoneType.getPrecision() <= TimestampWithTimeZoneType.MAX_SHORT_PRECISION) { + return WriteMapping.longMapping("timestamptz", shortTimestampWithTimeZoneWriteFunction()); + } + return WriteMapping.objectMapping("timestamptz", longTimestampWithTimeZoneWriteFunction()); + } + + // Fall back to legacy behavior return legacyToWriteMapping(type); } @@ -214,9 +494,168 @@ private static String redshiftVarcharLiteral(String value) return "'" + value.replace("'", "''").replace("\\", "\\\\") + "'"; } + private static ObjectReadFunction longTimestampWithTimeZoneReadFunction() + { + return ObjectReadFunction.of( + LongTimestampWithTimeZone.class, + (resultSet, columnIndex) -> { + // Redshift does not store zone information in "timestamp with time zone" data type + OffsetDateTime offsetDateTime = resultSet.getObject(columnIndex, OffsetDateTime.class); + return fromEpochSecondsAndFraction( + offsetDateTime.toEpochSecond(), + (long) offsetDateTime.getNano() * PICOSECONDS_PER_NANOSECOND, + UTC_KEY); + }); + } + + private static LongWriteFunction 
shortTimestampWithTimeZoneWriteFunction() + { + return (statement, index, value) -> { + // Redshift does not store zone information in "timestamp with time zone" data type + long millisUtc = unpackMillisUtc(value); + long epochSeconds = floorDiv(millisUtc, MILLISECONDS_PER_SECOND); + int nanosOfSecond = floorMod(millisUtc, MILLISECONDS_PER_SECOND) * NANOSECONDS_PER_MILLISECOND; + OffsetDateTime offsetDateTime = OffsetDateTime.ofInstant(Instant.ofEpochSecond(epochSeconds, nanosOfSecond), UTC_KEY.getZoneId()); + verifySupportedTimestampWithTimeZone(offsetDateTime); + statement.setObject(index, offsetDateTime); + }; + } + + private static ObjectWriteFunction longTimestampWithTimeZoneWriteFunction() + { + return ObjectWriteFunction.of( + LongTimestampWithTimeZone.class, + (statement, index, value) -> { + // Redshift does not store zone information in "timestamp with time zone" data type + long epochSeconds = floorDiv(value.getEpochMillis(), MILLISECONDS_PER_SECOND); + long nanosOfSecond = ((long) floorMod(value.getEpochMillis(), MILLISECONDS_PER_SECOND) * NANOSECONDS_PER_MILLISECOND) + + (value.getPicosOfMilli() / PICOSECONDS_PER_NANOSECOND); + OffsetDateTime offsetDateTime = OffsetDateTime.ofInstant(Instant.ofEpochSecond(epochSeconds, nanosOfSecond), UTC_KEY.getZoneId()); + verifySupportedTimestampWithTimeZone(offsetDateTime); + statement.setObject(index, offsetDateTime); + }); + } + + private static void verifySupportedTimestampWithTimeZone(OffsetDateTime value) + { + if (value.isBefore(REDSHIFT_MIN_SUPPORTED_TIMESTAMP_TZ)) { + DateTimeFormatter format = DateTimeFormatter.ofPattern("uuuu-MM-dd HH:mm:ss.SSSSSS"); + throw new TrinoException( + INVALID_ARGUMENTS, + format("Minimum timestamp with time zone in Redshift is %s: %s", REDSHIFT_MIN_SUPPORTED_TIMESTAMP_TZ.format(format), value.format(format))); + } + } + + /** + * Decimal write function for precision {@link #REDSHIFT_DECIMAL_CUTOFF_PRECISION}. + * Ensures that values fit in 8 bytes. 
+ */ + private static ObjectWriteFunction writeDecimalAtRedshiftCutoff(int scale) + { + return ObjectWriteFunction.of( + Int128.class, + (statement, index, decimal) -> { + BigInteger unscaled = decimal.toBigInteger(); + if (unscaled.bitLength() > REDSHIFT_DECIMAL_CUTOFF_BITS) { + throw new TrinoException(JDBC_NON_TRANSIENT_ERROR, format( + "Value out of range for Redshift DECIMAL(%d, %d)", + REDSHIFT_DECIMAL_CUTOFF_PRECISION, + scale)); + } + MathContext precision = new MathContext(REDSHIFT_DECIMAL_CUTOFF_PRECISION); + statement.setBigDecimal(index, new BigDecimal(unscaled, scale, precision)); + }); + } + + /** + * Like {@link StandardColumnMappings#charWriteFunction}, but restrict to + * ASCII because Redshift only allows ASCII in {@code CHAR} values. + */ + private static void writeChar(PreparedStatement statement, int index, Slice slice) + throws SQLException + { + String value = slice.toStringUtf8(); + if (!CharMatcher.ascii().matchesAllOf(value)) { + throw new TrinoException( + JDBC_NON_TRANSIENT_ERROR, + format("Value for Redshift CHAR must be ASCII, but found '%s'", value)); + } + statement.setString(index, slice.toStringAscii()); + } + + /** + * Like {@link StandardColumnMappings#charWriteFunction}, but pads + * the value with spaces to simulate {@code CHAR} semantics. + */ + private static void writeCharAsVarchar(PreparedStatement statement, int index, Slice slice, int columnLength) + throws SQLException + { + // Redshift counts varchar size limits in UTF-8 bytes, so this may make the string longer than + // the limit, but Redshift also truncates extra trailing spaces, so that doesn't cause any problems. 
+ statement.setString(index, Chars.padSpaces(slice, columnLength).toStringUtf8()); + } + + private static void writeDate(PreparedStatement statement, int index, long day) + throws SQLException + { + statement.setObject(index, new RedshiftObject("date", DATE_FORMATTER.format(LocalDate.ofEpochDay(day)))); + } + + private static long readDate(ResultSet results, int index) + throws SQLException + { + // Reading date as string to workaround issues around julian->gregorian calendar switch + return LocalDate.parse(results.getString(index), DATE_FORMATTER).toEpochDay(); + } + + /** + * Write time with microsecond precision + */ + private static void writeTime(PreparedStatement statement, int index, long picos) + throws SQLException + { + statement.setObject(index, LocalTime.ofNanoOfDay((roundDiv(picos, PICOSECONDS_PER_MICROSECOND) % MICROSECONDS_PER_DAY) * NANOSECONDS_PER_MICROSECOND)); + } + + /** + * Read a time value with microsecond precision + */ + private static long readTime(ResultSet results, int index) + throws SQLException + { + return results.getObject(index, LocalTime.class).toNanoOfDay() * PICOSECONDS_PER_NANOSECOND; + } + + private static void writeShortTimestamp(PreparedStatement statement, int index, long epochMicros) + throws SQLException + { + statement.setObject(index, new RedshiftObject("timestamp", DATE_TIME_FORMATTER.format(StandardColumnMappings.fromTrinoTimestamp(epochMicros)))); + } + + private static void writeLongTimestamp(PreparedStatement statement, int index, Object value) + throws SQLException + { + LongTimestamp timestamp = (LongTimestamp) value; + long epochMicros = timestamp.getEpochMicros(); + if (timestamp.getPicosOfMicro() >= PICOSECONDS_PER_MICROSECOND / 2) { + epochMicros += 1; // Add one micro if picos round up + } + statement.setObject(index, new RedshiftObject("timestamp", DATE_TIME_FORMATTER.format(StandardColumnMappings.fromTrinoTimestamp(epochMicros)))); + } + + private static long readTimestamp(ResultSet results, int index) + 
throws SQLException + { + return StandardColumnMappings.toTrinoTimestamp(TIMESTAMP_MICROS, results.getObject(index, LocalDateTime.class)); + } + + private static SliceWriteFunction varbinaryWriteFunction() + { + return (statement, index, value) -> statement.unwrap(RedshiftPreparedStatement.class).setVarbyte(index, value.getBytes()); + } + private static Optional legacyDefaultColumnMapping(JdbcTypeHandle typeHandle) { - // TODO (https://github.com/trinodb/trino/issues/497) Implement proper type mapping and add test // This method is copied from deprecated StandardColumnMappings.legacyDefaultColumnMapping() switch (typeHandle.getJdbcType()) { case Types.BIT: @@ -282,7 +721,6 @@ private static Optional legacyDefaultColumnMapping(JdbcTypeHandle private static WriteMapping legacyToWriteMapping(Type type) { - // TODO (https://github.com/trinodb/trino/issues/497) Implement proper type mapping and add test // This method is copied from deprecated BaseJdbcClient.legacyToWriteMapping() if (type == BOOLEAN) { return WriteMapping.booleanMapping("boolean", booleanWriteFunction()); diff --git a/plugin/trino-redshift/src/main/java/io/trino/plugin/redshift/RedshiftClientModule.java b/plugin/trino-redshift/src/main/java/io/trino/plugin/redshift/RedshiftClientModule.java index dccc6fabff94..13635c88f69b 100644 --- a/plugin/trino-redshift/src/main/java/io/trino/plugin/redshift/RedshiftClientModule.java +++ b/plugin/trino-redshift/src/main/java/io/trino/plugin/redshift/RedshiftClientModule.java @@ -15,12 +15,12 @@ import com.amazon.redshift.Driver; import com.google.inject.Binder; -import com.google.inject.Module; import com.google.inject.Provides; -import com.google.inject.Scopes; import com.google.inject.Singleton; +import io.airlift.configuration.AbstractConfigurationAwareModule; import io.trino.plugin.jdbc.BaseJdbcConfig; import io.trino.plugin.jdbc.ConnectionFactory; +import io.trino.plugin.jdbc.DecimalModule; import io.trino.plugin.jdbc.DriverConnectionFactory; import 
io.trino.plugin.jdbc.ForBaseJdbc; import io.trino.plugin.jdbc.JdbcClient; @@ -30,16 +30,19 @@ import java.util.Properties; +import static com.google.inject.Scopes.SINGLETON; import static com.google.inject.multibindings.Multibinder.newSetBinder; public class RedshiftClientModule - implements Module + extends AbstractConfigurationAwareModule { @Override - public void configure(Binder binder) + public void setup(Binder binder) { - binder.bind(JdbcClient.class).annotatedWith(ForBaseJdbc.class).to(RedshiftClient.class).in(Scopes.SINGLETON); - newSetBinder(binder, ConnectorTableFunction.class).addBinding().toProvider(Query.class).in(Scopes.SINGLETON); + binder.bind(JdbcClient.class).annotatedWith(ForBaseJdbc.class).to(RedshiftClient.class).in(SINGLETON); + newSetBinder(binder, ConnectorTableFunction.class).addBinding().toProvider(Query.class).in(SINGLETON); + + install(new DecimalModule()); } @Singleton diff --git a/plugin/trino-redshift/src/test/java/io/trino/plugin/redshift/RedshiftQueryRunner.java b/plugin/trino-redshift/src/test/java/io/trino/plugin/redshift/RedshiftQueryRunner.java index 61859dd9b52c..3e96738e7ba1 100644 --- a/plugin/trino-redshift/src/test/java/io/trino/plugin/redshift/RedshiftQueryRunner.java +++ b/plugin/trino-redshift/src/test/java/io/trino/plugin/redshift/RedshiftQueryRunner.java @@ -28,6 +28,7 @@ import io.trino.tpch.TpchTable; import net.jodah.failsafe.Failsafe; import net.jodah.failsafe.RetryPolicy; +import org.jdbi.v3.core.HandleCallback; import org.jdbi.v3.core.HandleConsumer; import org.jdbi.v3.core.Jdbi; @@ -50,8 +51,8 @@ public final class RedshiftQueryRunner { private static final Logger log = Logger.get(RedshiftQueryRunner.class); private static final String JDBC_ENDPOINT = requireSystemProperty("test.redshift.jdbc.endpoint"); - private static final String JDBC_USER = requireSystemProperty("test.redshift.jdbc.user"); - private static final String JDBC_PASSWORD = requireSystemProperty("test.redshift.jdbc.password"); + static final 
String JDBC_USER = requireSystemProperty("test.redshift.jdbc.user"); + static final String JDBC_PASSWORD = requireSystemProperty("test.redshift.jdbc.password"); private static final String S3_TPCH_TABLES_ROOT = requireSystemProperty("test.redshift.s3.tpch.tables.root"); private static final String IAM_ROLE = requireSystemProperty("test.redshift.iam.role"); @@ -59,7 +60,7 @@ public final class RedshiftQueryRunner private static final String TEST_CATALOG = "redshift"; static final String TEST_SCHEMA = "test_schema"; - private static final String JDBC_URL = "jdbc:redshift://" + JDBC_ENDPOINT + TEST_DATABASE; + static final String JDBC_URL = "jdbc:redshift://" + JDBC_ENDPOINT + TEST_DATABASE; private static final String CONNECTOR_NAME = "redshift"; private static final String TPCH_CATALOG = "tpch"; @@ -164,10 +165,16 @@ public static void executeInRedshift(String sql, Object... parameters) executeInRedshift(handle -> handle.execute(sql, parameters)); } - private static void executeInRedshift(HandleConsumer consumer) + public static void executeInRedshift(HandleConsumer consumer) throws E { - Jdbi.create(JDBC_URL, JDBC_USER, JDBC_PASSWORD).withHandle(consumer.asCallback()); + executeWithRedshift(consumer.asCallback()); + } + + public static T executeWithRedshift(HandleCallback callback) + throws E + { + return Jdbi.create(JDBC_URL, JDBC_USER, JDBC_PASSWORD).withHandle(callback); } private static synchronized void provisionTables(Session session, QueryRunner queryRunner, Iterable> tables) diff --git a/plugin/trino-redshift/src/test/java/io/trino/plugin/redshift/TestRedshiftConnectorTest.java b/plugin/trino-redshift/src/test/java/io/trino/plugin/redshift/TestRedshiftConnectorTest.java index 0bd3f862e91e..0d46c93a6fc9 100644 --- a/plugin/trino-redshift/src/test/java/io/trino/plugin/redshift/TestRedshiftConnectorTest.java +++ b/plugin/trino-redshift/src/test/java/io/trino/plugin/redshift/TestRedshiftConnectorTest.java @@ -21,6 +21,7 @@ import io.trino.testing.sql.TestTable; 
import io.trino.tpch.TpchTable; import org.testng.SkipException; +import org.testng.annotations.DataProvider; import org.testng.annotations.Test; import java.util.Optional; @@ -28,6 +29,7 @@ import static io.trino.plugin.redshift.RedshiftQueryRunner.TEST_SCHEMA; import static io.trino.plugin.redshift.RedshiftQueryRunner.createRedshiftQueryRunner; +import static io.trino.testing.TestingNames.randomNameSuffix; import static java.lang.String.format; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; @@ -99,9 +101,6 @@ protected Optional filterDataMappingSmokeTestData(DataMapp return Optional.empty(); } } - if ("tinyint".equals(typeName) || typeName.startsWith("time") || "varbinary".equals(typeName)) { - return Optional.empty(); - } return Optional.of(dataMappingTestSetup); } @@ -119,6 +118,39 @@ public void testCreateTableAsSelectWithUnicode() "SELECT 1"); } + @Test(dataProvider = "redshiftTypeToTrinoTypes") + public void testReadFromLateBindingView(String redshiftType, String trinoType) + { + try (TestView view = new TestView(onRemoteDatabase(), TEST_SCHEMA + ".late_schema_binding", "SELECT CAST(NULL AS %s) AS value WITH NO SCHEMA BINDING".formatted(redshiftType))) { + assertThat(query("SELECT value, true FROM %s WHERE value IS NULL".formatted(view.getName()))) + .projected(1) + .containsAll("VALUES (true)"); + + assertThat(query("SHOW COLUMNS FROM %s LIKE 'value'".formatted(view.getName()))) + .projected(1) + .skippingTypesCheck() + .containsAll("VALUES ('%s')".formatted(trinoType)); + } + } + + @DataProvider + public Object[][] redshiftTypeToTrinoTypes() + { + return new Object[][] { + {"SMALLINT", "smallint"}, + {"INTEGER", "integer"}, + {"BIGINT", "bigint"}, + {"DECIMAL", "decimal(18,0)"}, + {"REAL", "real"}, + {"DOUBLE PRECISION", "double"}, + {"BOOLEAN", "boolean"}, + {"CHAR(1)", "char(1)"}, + {"VARCHAR(1)", "varchar(1)"}, + {"TIME", "time(6)"}, + {"TIMESTAMP", "timestamp(6)"}, + 
{"TIMESTAMPTZ", "timestamp(6) with time zone"}}; + } + @Override @Test public void testReadMetadataWithRelationsConcurrentModifications() @@ -186,4 +218,29 @@ public void testAddNotNullColumnToNonEmptyTable() { throw new SkipException("Redshift ALTER TABLE ADD COLUMN defined as NOT NULL must have a non-null default expression"); } + + private static class TestView + implements AutoCloseable + { + private final String name; + private final SqlExecutor executor; + + public TestView(SqlExecutor executor, String namePrefix, String viewDefinition) + { + this.executor = executor; + this.name = namePrefix + "_" + randomNameSuffix(); + executor.execute("CREATE OR REPLACE VIEW " + name + " AS " + viewDefinition); + } + + @Override + public void close() + { + executor.execute("DROP VIEW " + name); + } + + public String getName() + { + return name; + } + } } diff --git a/plugin/trino-redshift/src/test/java/io/trino/plugin/redshift/TestRedshiftTypeMapping.java b/plugin/trino-redshift/src/test/java/io/trino/plugin/redshift/TestRedshiftTypeMapping.java new file mode 100644 index 000000000000..26938c3b6532 --- /dev/null +++ b/plugin/trino-redshift/src/test/java/io/trino/plugin/redshift/TestRedshiftTypeMapping.java @@ -0,0 +1,994 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.trino.plugin.redshift; + +import com.google.common.base.Utf8; +import com.google.common.collect.ImmutableList; +import io.trino.Session; +import io.trino.testing.AbstractTestQueryFramework; +import io.trino.testing.QueryRunner; +import io.trino.testing.TestingSession; +import io.trino.testing.datatype.CreateAndInsertDataSetup; +import io.trino.testing.datatype.CreateAsSelectDataSetup; +import io.trino.testing.datatype.DataSetup; +import io.trino.testing.datatype.SqlDataTypeTest; +import io.trino.testing.sql.JdbcSqlExecutor; +import io.trino.testing.sql.SqlExecutor; +import io.trino.testing.sql.TestTable; +import io.trino.testing.sql.TrinoSqlExecutor; +import org.testng.annotations.BeforeClass; +import org.testng.annotations.DataProvider; +import org.testng.annotations.Test; + +import java.sql.SQLException; +import java.time.LocalDate; +import java.time.LocalDateTime; +import java.time.LocalTime; +import java.time.ZoneId; +import java.util.Arrays; +import java.util.List; +import java.util.Map; +import java.util.Map.Entry; +import java.util.Properties; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.function.Function; + +import static com.google.common.base.Verify.verify; +import static com.google.common.io.BaseEncoding.base16; +import static io.trino.plugin.redshift.RedshiftClient.REDSHIFT_MAX_VARCHAR; +import static io.trino.plugin.redshift.RedshiftQueryRunner.JDBC_PASSWORD; +import static io.trino.plugin.redshift.RedshiftQueryRunner.JDBC_URL; +import static io.trino.plugin.redshift.RedshiftQueryRunner.JDBC_USER; +import static io.trino.plugin.redshift.RedshiftQueryRunner.TEST_SCHEMA; +import static io.trino.plugin.redshift.RedshiftQueryRunner.createRedshiftQueryRunner; +import static io.trino.plugin.redshift.RedshiftQueryRunner.executeInRedshift; +import static io.trino.spi.type.BigintType.BIGINT; +import static io.trino.spi.type.BooleanType.BOOLEAN; +import static io.trino.spi.type.CharType.createCharType; +import static 
io.trino.spi.type.DateType.DATE; +import static io.trino.spi.type.DecimalType.createDecimalType; +import static io.trino.spi.type.DoubleType.DOUBLE; +import static io.trino.spi.type.IntegerType.INTEGER; +import static io.trino.spi.type.RealType.REAL; +import static io.trino.spi.type.SmallintType.SMALLINT; +import static io.trino.spi.type.TimeType.createTimeType; +import static io.trino.spi.type.TimeZoneKey.getTimeZoneKey; +import static io.trino.spi.type.TimestampType.createTimestampType; +import static io.trino.spi.type.TimestampWithTimeZoneType.TIMESTAMP_TZ_MICROS; +import static io.trino.spi.type.VarbinaryType.VARBINARY; +import static io.trino.spi.type.VarcharType.createUnboundedVarcharType; +import static io.trino.spi.type.VarcharType.createVarcharType; +import static io.trino.testing.TestingNames.randomNameSuffix; +import static java.lang.String.format; +import static java.nio.charset.StandardCharsets.UTF_8; +import static java.time.ZoneOffset.UTC; +import static java.util.Objects.requireNonNull; +import static java.util.stream.Collectors.groupingBy; +import static java.util.stream.Collectors.joining; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +public class TestRedshiftTypeMapping + extends AbstractTestQueryFramework +{ + private static final ZoneId testZone = TestingSession.DEFAULT_TIME_ZONE_KEY.getZoneId(); + + private final ZoneId jvmZone = ZoneId.systemDefault(); + private final LocalDateTime timeGapInJvmZone = LocalDate.EPOCH.atStartOfDay(); + private final LocalDateTime timeDoubledInJvmZone = LocalDateTime.of(2018, 10, 28, 1, 33, 17, 456_789_000); + + // using two non-JVM zones so that we don't need to worry what the backend's system zone is + + // no DST in 1970, but has DST in later years (e.g. 
2018) + private final ZoneId vilnius = ZoneId.of("Europe/Vilnius"); + private final LocalDateTime timeGapInVilnius = LocalDateTime.of(2018, 3, 25, 3, 17, 17); + private final LocalDateTime timeDoubledInVilnius = LocalDateTime.of(2018, 10, 28, 3, 33, 33, 333_333_000); + + // Size of offset changed since 1970-01-01, no DST + private final ZoneId kathmandu = ZoneId.of("Asia/Kathmandu"); + private final LocalDateTime timeGapInKathmandu = LocalDateTime.of(1986, 1, 1, 0, 13, 7); + + private final LocalDate dayOfMidnightGapInJvmZone = LocalDate.EPOCH; + private final LocalDate dayOfMidnightGapInVilnius = LocalDate.of(1983, 4, 1); + private final LocalDate dayAfterMidnightSetBackInVilnius = LocalDate.of(1983, 10, 1); + + @BeforeClass + public void checkRanges() + { + // Timestamps + checkIsGap(jvmZone, timeGapInJvmZone); + checkIsDoubled(jvmZone, timeDoubledInJvmZone); + checkIsGap(vilnius, timeGapInVilnius); + checkIsDoubled(vilnius, timeDoubledInVilnius); + checkIsGap(kathmandu, timeGapInKathmandu); + + // Times + checkIsGap(jvmZone, LocalTime.of(0, 0, 0).atDate(LocalDate.EPOCH)); + + // Dates + checkIsGap(jvmZone, dayOfMidnightGapInJvmZone.atStartOfDay()); + checkIsGap(vilnius, dayOfMidnightGapInVilnius.atStartOfDay()); + checkIsDoubled(vilnius, dayAfterMidnightSetBackInVilnius.atStartOfDay().minusNanos(1)); + } + + @Override + protected QueryRunner createQueryRunner() + throws Exception + { + return createRedshiftQueryRunner(Map.of(), Map.of(), List.of()); + } + + @Test + public void testBasicTypes() + { + // Assume that if these types work at all, they have standard semantics. 
SqlDataTypeTest.create()
        .addRoundTrip("boolean", "true", BOOLEAN, "true")
        .addRoundTrip("boolean", "false", BOOLEAN, "false")
        .addRoundTrip("bigint", "123456789012", BIGINT, "123456789012")
        .addRoundTrip("integer", "1234567890", INTEGER, "1234567890")
        .addRoundTrip("smallint", "32456", SMALLINT, "SMALLINT '32456'")
        .addRoundTrip("double", "123.45", DOUBLE, "DOUBLE '123.45'")
        .addRoundTrip("real", "123.45", REAL, "REAL '123.45'")
        // If we map tinyint to smallint:
        .addRoundTrip("tinyint", "5", SMALLINT, "SMALLINT '5'")
        .execute(getQueryRunner(), trinoCreateAsSelect("test_basic_types"));
}

// Round-trip VARCHAR values through both engines (Trino CTAS and direct Redshift
// INSERT), covering the maximum length, multi-byte CJK text, a supplementary-plane
// emoji, and Cyrillic text.
@Test
public void testVarchar()
{
    SqlDataTypeTest.create()
            .addRoundTrip("varchar(65535)", "'varchar max'", createVarcharType(65535), "CAST('varchar max' AS varchar(65535))")
            .addRoundTrip("varchar(40)", "'攻殻機動隊'", createVarcharType(40), "CAST('攻殻機動隊' AS varchar(40))")
            .addRoundTrip("varchar(8)", "'隊'", createVarcharType(8), "CAST('隊' AS varchar(8))")
            .addRoundTrip("varchar(16)", "'😂'", createVarcharType(16), "CAST('😂' AS varchar(16))")
            .addRoundTrip("varchar(88)", "'Ну, погоди!'", createVarcharType(88), "CAST('Ну, погоди!' AS varchar(88))")
            .addRoundTrip("varchar(10)", "'text_a'", createVarcharType(10), "CAST('text_a' AS varchar(10))")
            .addRoundTrip("varchar(255)", "'text_b'", createVarcharType(255), "CAST('text_b' AS varchar(255))")
            .addRoundTrip("varchar(4096)", "'char max'", createVarcharType(4096), "CAST('char max' AS varchar(4096))")
            .execute(getQueryRunner(), trinoCreateAsSelect("trino_test_varchar"))
            .execute(getQueryRunner(), redshiftCreateAndInsert("jdbc_test_varchar"));
}

// Round-trip CHAR values. Sizes within Redshift's CHAR limit keep the CHAR type;
// sizes beyond it come back as VARCHAR (space-padded via padVarchar), and those
// oversized cases are only exercised from the Trino side.
@Test
public void testChar()
{
    SqlDataTypeTest.create()
            .addRoundTrip("char(10)", "'text_a'", createCharType(10), "CAST('text_a' AS char(10))")
            .addRoundTrip("char(255)", "'text_b'", createCharType(255), "CAST('text_b' AS char(255))")
            .addRoundTrip("char(4096)", "'char max'", createCharType(4096), "CAST('char max' AS char(4096))")
            .execute(getQueryRunner(), trinoCreateAsSelect("trino_test_char"))
            .execute(getQueryRunner(), redshiftCreateAndInsert("jdbc_test_char"));

    // Test with types larger than Redshift's char(max)
    SqlDataTypeTest.create()
            .addRoundTrip("char(65535)", "'varchar max'", createVarcharType(65535), format("CAST('varchar max%s' AS varchar(65535))", " ".repeat(65535 - "varchar max".length())))
            .addRoundTrip("char(4136)", "'攻殻機動隊'", createVarcharType(4136), format("CAST('%s' AS varchar(4136))", padVarchar(4136).apply("攻殻機動隊")))
            .addRoundTrip("char(4104)", "'隊'", createVarcharType(4104), format("CAST('%s' AS varchar(4104))", padVarchar(4104).apply("隊")))
            .addRoundTrip("char(4112)", "'😂'", createVarcharType(4112), format("CAST('%s' AS varchar(4112))", padVarchar(4112).apply("😂")))
            // NOTE(review): this row repeats the varchar(88) case from testVarchar instead of
            // an oversized char(...) with padVarchar like its neighbors — verify intended.
            .addRoundTrip("varchar(88)", "'Ну, погоди!'", createVarcharType(88), "CAST('Ну, погоди!' AS varchar(88))")
            .addRoundTrip("char(4106)", "'text_a'", createVarcharType(4106), format("CAST('%s' AS varchar(4106))", padVarchar(4106).apply("text_a")))
            .addRoundTrip("char(4351)", "'text_b'", createVarcharType(4351), format("CAST('%s' AS varchar(4351))", padVarchar(4351).apply("text_b")))
            .addRoundTrip("char(8192)", "'char max'", createVarcharType(8192), format("CAST('%s' AS varchar(8192))", padVarchar(8192).apply("char max")))
            .execute(getQueryRunner(), trinoCreateAsSelect("trino_test_large_char"));
}

/**
 * Test handling of data outside Redshift's normal bounds.
 *
 *

Redshift sometimes returns unbounded {@code VARCHAR} data, apparently + * when it returns directly from a Postgres function. + */ + @Test + public void testPostgresText() + { + try (TestView view1 = new TestView("postgres_text_view", "SELECT lpad('x', 1)"); + TestView view2 = new TestView("pg_catalog_view", "SELECT relname FROM pg_class")) { + // Test data and type from a function + assertThat(query(format("SELECT * FROM %s", view1.name))) + .matches("VALUES CAST('x' AS varchar)"); + + // Test the type of an internal table + assertThat(query(format("SELECT * FROM %s LIMIT 1", view2.name))) + .hasOutputTypes(List.of(createUnboundedVarcharType())); + } + } + + // Make sure that Redshift still maps NCHAR and NVARCHAR to CHAR and VARCHAR. + @Test + public void checkNCharAndNVarchar() + { + SqlDataTypeTest.create() + .addRoundTrip("nvarchar(65535)", "'varchar max'", createVarcharType(65535), "CAST('varchar max' AS varchar(65535))") + .addRoundTrip("nvarchar(40)", "'攻殻機動隊'", createVarcharType(40), "CAST('攻殻機動隊' AS varchar(40))") + .addRoundTrip("nvarchar(8)", "'隊'", createVarcharType(8), "CAST('隊' AS varchar(8))") + .addRoundTrip("nvarchar(16)", "'😂'", createVarcharType(16), "CAST('😂' AS varchar(16))") + .addRoundTrip("nvarchar(88)", "'Ну, погоди!'", createVarcharType(88), "CAST('Ну, погоди!' 
AS varchar(88))") + .addRoundTrip("nvarchar(10)", "'text_a'", createVarcharType(10), "CAST('text_a' AS varchar(10))") + .addRoundTrip("nvarchar(255)", "'text_b'", createVarcharType(255), "CAST('text_b' AS varchar(255))") + .addRoundTrip("nvarchar(4096)", "'char max'", createVarcharType(4096), "CAST('char max' AS varchar(4096))") + .execute(getQueryRunner(), redshiftCreateAndInsert("jdbc_test_nvarchar")); + + SqlDataTypeTest.create() + .addRoundTrip("nchar(10)", "'text_a'", createCharType(10), "CAST('text_a' AS char(10))") + .addRoundTrip("nchar(255)", "'text_b'", createCharType(255), "CAST('text_b' AS char(255))") + .addRoundTrip("nchar(4096)", "'char max'", createCharType(4096), "CAST('char max' AS char(4096))") + .execute(getQueryRunner(), redshiftCreateAndInsert("jdbc_test_nchar")); + } + + @Test + public void testUnicodeChar() // Redshift doesn't allow multibyte chars in CHAR + { + try (TestTable table = testTable("test_multibyte_char", "(c char(32))")) { + assertQueryFails( + format("INSERT INTO %s VALUES ('\u968A')", table.getName()), + "^Value for Redshift CHAR must be ASCII, but found '\u968A'$"); + } + + assertCreateFails( + "test_multibyte_char_ctas", + "AS SELECT CAST('\u968A' AS char(32)) c", + "^Value for Redshift CHAR must be ASCII, but found '\u968A'$"); + } + + // Make sure Redshift really doesn't allow multibyte characters in CHAR + @Test + public void checkUnicodeCharInRedshift() + { + try (TestTable table = testTable("check_multibyte_char", "(c char(32))")) { + assertThatThrownBy(() -> getRedshiftExecutor() + .execute(format("INSERT INTO %s VALUES ('\u968a')", table.getName()))) + .getCause() + .isInstanceOf(SQLException.class) + .hasMessageContaining("CHAR string contains invalid ASCII character"); + } + } + + @Test + public void testOversizedCharacterTypes() + { + // Test that character types too large for Redshift map to the maximum size + SqlDataTypeTest.create() + .addRoundTrip("varchar", "'unbounded'", createVarcharType(65535), 
"CAST('unbounded' AS varchar(65535))") + .addRoundTrip(format("varchar(%s)", REDSHIFT_MAX_VARCHAR + 1), "'oversized varchar'", createVarcharType(65535), "CAST('oversized varchar' AS varchar(65535))") + .addRoundTrip(format("char(%s)", REDSHIFT_MAX_VARCHAR + 1), "'oversized char'", createVarcharType(65535), format("CAST('%s' AS varchar(65535))", padVarchar(65535).apply("oversized char"))) + .execute(getQueryRunner(), trinoCreateAsSelect("oversized_character_types")); + } + + @Test + public void testVarbinary() + { + // Redshift's VARBYTE is mapped to Trino VARBINARY. Redshift does not have VARBINARY type. + SqlDataTypeTest.create() + // varbyte + .addRoundTrip("varbyte", "NULL", VARBINARY, "CAST(NULL AS varbinary)") + .addRoundTrip("varbyte", "to_varbyte('', 'hex')", VARBINARY, "X''") + .addRoundTrip("varbyte", utf8VarbyteLiteral("hello"), VARBINARY, "to_utf8('hello')") + .addRoundTrip("varbyte", utf8VarbyteLiteral("Piękna łąka w 東京都"), VARBINARY, "to_utf8('Piękna łąka w 東京都')") + .addRoundTrip("varbyte", utf8VarbyteLiteral("Bag full of 💰"), VARBINARY, "to_utf8('Bag full of 💰')") + .addRoundTrip("varbyte", "to_varbyte('0001020304050607080DF9367AA7000000', 'hex')", VARBINARY, "X'0001020304050607080DF9367AA7000000'") // non-text + .addRoundTrip("varbyte", "to_varbyte('000000000000', 'hex')", VARBINARY, "X'000000000000'") + .addRoundTrip("varbyte(1)", "to_varbyte('00', 'hex')", VARBINARY, "X'00'") // minimum length + .addRoundTrip("varbyte(1024000)", "to_varbyte('00', 'hex')", VARBINARY, "X'00'") // maximum length + // varbinary + .addRoundTrip("varbinary", "NULL", VARBINARY, "CAST(NULL AS varbinary)") + .addRoundTrip("varbinary", utf8VarbyteLiteral("Bag full of 💰"), VARBINARY, "to_utf8('Bag full of 💰')") + .addRoundTrip("varbinary", "to_varbyte('0001020304050607080DF9367AA7000000', 'hex')", VARBINARY, "X'0001020304050607080DF9367AA7000000'") // non-text + .addRoundTrip("varbinary(1)", "to_varbyte('00', 'hex')", VARBINARY, "X'00'") // minimum length + 
.addRoundTrip("varbinary(1024000)", "to_varbyte('00', 'hex')", VARBINARY, "X'00'") // maximum length + // binary varying + .addRoundTrip("binary varying", "NULL", VARBINARY, "CAST(NULL AS varbinary)") + .addRoundTrip("binary varying", utf8VarbyteLiteral("Bag full of 💰"), VARBINARY, "to_utf8('Bag full of 💰')") + .addRoundTrip("binary varying", "to_varbyte('0001020304050607080DF9367AA7000000', 'hex')", VARBINARY, "X'0001020304050607080DF9367AA7000000'") // non-text + .addRoundTrip("binary varying(1)", "to_varbyte('00', 'hex')", VARBINARY, "X'00'") // minimum length + .addRoundTrip("binary varying(1024000)", "to_varbyte('00', 'hex')", VARBINARY, "X'00'") // maximum length + .execute(getQueryRunner(), redshiftCreateAndInsert("test_varbinary")); + + SqlDataTypeTest.create() + .addRoundTrip("varbinary", "NULL", VARBINARY, "CAST(NULL AS varbinary)") + .addRoundTrip("varbinary", "X''", VARBINARY, "X''") + .addRoundTrip("varbinary", "X'68656C6C6F'", VARBINARY, "to_utf8('hello')") + .addRoundTrip("varbinary", "X'5069C4996B6E6120C582C4856B61207720E69DB1E4BAACE983BD'", VARBINARY, "to_utf8('Piękna łąka w 東京都')") + .addRoundTrip("varbinary", "X'4261672066756C6C206F6620F09F92B0'", VARBINARY, "to_utf8('Bag full of 💰')") + .addRoundTrip("varbinary", "X'0001020304050607080DF9367AA7000000'", VARBINARY, "X'0001020304050607080DF9367AA7000000'") // non-text + .addRoundTrip("varbinary", "X'000000000000'", VARBINARY, "X'000000000000'") + .execute(getQueryRunner(), trinoCreateAsSelect("test_varbinary")); + } + + private static String utf8VarbyteLiteral(String string) + { + return format("to_varbyte('%s', 'hex')", base16().encode(string.getBytes(UTF_8))); + } + + @Test + public void testDecimal() + { + SqlDataTypeTest.create() + .addRoundTrip("decimal(3, 0)", "CAST('193' AS decimal(3, 0))", createDecimalType(3, 0), "CAST('193' AS decimal(3, 0))") + .addRoundTrip("decimal(3, 0)", "CAST('19' AS decimal(3, 0))", createDecimalType(3, 0), "CAST('19' AS decimal(3, 0))") + 
.addRoundTrip("decimal(3, 0)", "CAST('-193' AS decimal(3, 0))", createDecimalType(3, 0), "CAST('-193' AS decimal(3, 0))") + .addRoundTrip("decimal(3, 1)", "CAST('10.0' AS decimal(3, 1))", createDecimalType(3, 1), "CAST('10.0' AS decimal(3, 1))") + .addRoundTrip("decimal(3, 1)", "CAST('10.1' AS decimal(3, 1))", createDecimalType(3, 1), "CAST('10.1' AS decimal(3, 1))") + .addRoundTrip("decimal(3, 1)", "CAST('-10.1' AS decimal(3, 1))", createDecimalType(3, 1), "CAST('-10.1' AS decimal(3, 1))") + .addRoundTrip("decimal(4, 2)", "CAST('2' AS decimal(4, 2))", createDecimalType(4, 2), "CAST('2' AS decimal(4, 2))") + .addRoundTrip("decimal(4, 2)", "CAST('2.3' AS decimal(4, 2))", createDecimalType(4, 2), "CAST('2.3' AS decimal(4, 2))") + .addRoundTrip("decimal(24, 2)", "CAST('2' AS decimal(24, 2))", createDecimalType(24, 2), "CAST('2' AS decimal(24, 2))") + .addRoundTrip("decimal(24, 2)", "CAST('2.3' AS decimal(24, 2))", createDecimalType(24, 2), "CAST('2.3' AS decimal(24, 2))") + .addRoundTrip("decimal(24, 2)", "CAST('123456789.3' AS decimal(24, 2))", createDecimalType(24, 2), "CAST('123456789.3' AS decimal(24, 2))") + .addRoundTrip("decimal(24, 4)", "CAST('12345678901234567890.31' AS decimal(24, 4))", createDecimalType(24, 4), "CAST('12345678901234567890.31' AS decimal(24, 4))") + .addRoundTrip("decimal(30, 5)", "CAST('3141592653589793238462643.38327' AS decimal(30, 5))", createDecimalType(30, 5), "CAST('3141592653589793238462643.38327' AS decimal(30, 5))") + .addRoundTrip("decimal(30, 5)", "CAST('-3141592653589793238462643.38327' AS decimal(30, 5))", createDecimalType(30, 5), "CAST('-3141592653589793238462643.38327' AS decimal(30, 5))") + .addRoundTrip("decimal(31, 0)", "CAST('2718281828459045235360287471352' AS decimal(31, 0))", createDecimalType(31, 0), "CAST('2718281828459045235360287471352' AS decimal(31, 0))") + .addRoundTrip("decimal(31, 0)", "CAST('-2718281828459045235360287471352' AS decimal(31, 0))", createDecimalType(31, 0), 
"CAST('-2718281828459045235360287471352' AS decimal(31, 0))")
            .addRoundTrip("decimal(3, 0)", "NULL", createDecimalType(3, 0), "CAST(NULL AS decimal(3, 0))")
            .addRoundTrip("decimal(31, 0)", "NULL", createDecimalType(31, 0), "CAST(NULL AS decimal(31, 0))")
            .execute(getQueryRunner(), redshiftCreateAndInsert("test_decimal"))
            .execute(getQueryRunner(), trinoCreateAsSelect("test_decimal"));
}

// Inserting values that need more digits than Redshift's DECIMAL(19, s) can hold
// must fail with a clear "out of range" error, at scale 0, 18 and 19.
@Test
public void testRedshiftDecimalCutoff()
{
    String columns = "(d19 decimal(19, 0), d18 decimal(19, 18), d0 decimal(19, 19))";
    try (TestTable table = testTable("test_decimal_range", columns)) {
        assertQueryFails(
                format("INSERT INTO %s (d19) VALUES (DECIMAL'9991999999999999999')", table.getName()),
                "^Value out of range for Redshift DECIMAL\\(19, 0\\)$");
        assertQueryFails(
                format("INSERT INTO %s (d18) VALUES (DECIMAL'9.991999999999999999')", table.getName()),
                "^Value out of range for Redshift DECIMAL\\(19, 18\\)$");
        assertQueryFails(
                format("INSERT INTO %s (d0) VALUES (DECIMAL'.9991999999999999999')", table.getName()),
                "^Value out of range for Redshift DECIMAL\\(19, 19\\)$");
    }
}

// Redshift rejects CREATE TABLE with DECIMAL scale above 37; surface its error.
@Test
public void testRedshiftDecimalScaleLimit()
{
    assertCreateFails(
            "test_overlarge_decimal_scale",
            "(d DECIMAL(38, 38))",
            "^ERROR: DECIMAL scale 38 must be between 0 and 37$");
}

// Trino types with no Redshift mapping (e.g. json) must be rejected at CREATE time.
@Test
public void testUnsupportedTrinoDataTypes()
{
    assertCreateFails(
            "test_unsupported_type",
            "(col json)",
            "Unsupported column type: json");
}

// Round-trip DATE values in several session zones, covering pre-Gregorian dates,
// the epoch, DST-affected days, and BC dates (Redshift spells them with a BC suffix).
@Test(dataProvider = "datetime_test_parameters")
public void testDate(ZoneId sessionZone)
{
    Session session = Session.builder(getSession())
            .setTimeZoneKey(getTimeZoneKey(sessionZone.getId()))
            .build();
    SqlDataTypeTest.create()
            .addRoundTrip("date", "DATE '0001-01-01'", DATE, "DATE '0001-01-01'") // first day of AD
            .addRoundTrip("date", "DATE '1500-01-01'", DATE, "DATE '1500-01-01'") // sometime before julian->gregorian switch
            .addRoundTrip("date", "DATE '1600-01-01'", DATE, "DATE '1600-01-01'") // long ago but after julian->gregorian switch
            .addRoundTrip("date", "DATE '1952-04-03'", DATE, "DATE '1952-04-03'") // before epoch
            .addRoundTrip("date", "DATE '1970-01-01'", DATE, "DATE '1970-01-01'")
            .addRoundTrip("date", "DATE '1970-02-03'", DATE, "DATE '1970-02-03'") // after epoch
            .addRoundTrip("date", "DATE '2017-07-01'", DATE, "DATE '2017-07-01'") // summer in northern hemisphere (possible DST)
            .addRoundTrip("date", "DATE '2017-01-01'", DATE, "DATE '2017-01-01'") // winter in northern hemisphere (possible DST in southern hemisphere)
            .addRoundTrip("date", "DATE '1970-01-01'", DATE, "DATE '1970-01-01'") // day of midnight gap in JVM
            .addRoundTrip("date", "DATE '1983-04-01'", DATE, "DATE '1983-04-01'") // day of midnight gap in Vilnius
            .addRoundTrip("date", "DATE '1983-10-01'", DATE, "DATE '1983-10-01'") // day after midnight setback in Vilnius
            .execute(getQueryRunner(), session, trinoCreateAsSelect(session, "test_date"))
            .execute(getQueryRunner(), session, redshiftCreateAndInsert("test_date"));

    // some time BC
    SqlDataTypeTest.create()
            .addRoundTrip("date", "DATE '-0100-01-01'", DATE, "DATE '-0100-01-01'")
            .execute(getQueryRunner(), session, trinoCreateAsSelect(session, "test_date"));
    SqlDataTypeTest.create()
            .addRoundTrip("date", "DATE '0101-01-01 BC'", DATE, "DATE '-0100-01-01'")
            .execute(getQueryRunner(), session, redshiftCreateAndInsert("test_date"));
}

// Round-trip TIME values in several session zones.
@Test(dataProvider = "datetime_test_parameters")
public void testTime(ZoneId sessionZone)
{
    // Redshift errors out on INSERT into a TIME column that was declared with an
    // explicit precision, so the Redshift-side table uses plain "time" while the
    // Trino-side CTAS uses "time(6)".
    Session session = Session.builder(getSession())
            .setTimeZoneKey(getTimeZoneKey(sessionZone.getId()))
            .build();
    timeTypeTests("time(6)")
            .execute(getQueryRunner(), session, trinoCreateAsSelect(session, "time_from_trino"));
    timeTypeTests("time")
            .execute(getQueryRunner(), session, redshiftCreateAndInsert("time_from_jdbc"));
}

// Shared TIME cases: midnight, DST-gap-adjacent times, fractional seconds down to
// a single microsecond, and the maximum value 23:59:59.999999.
private static SqlDataTypeTest timeTypeTests(String inputType)
{
    return SqlDataTypeTest.create()
            .addRoundTrip(inputType, "TIME '00:00:00.000000'", createTimeType(6), "TIME '00:00:00.000000'") // gap in JVM zone on Epoch day
            .addRoundTrip(inputType, "TIME '00:13:42.000000'", createTimeType(6), "TIME '00:13:42.000000'") // gap in JVM zone on Epoch day
            .addRoundTrip(inputType, "TIME '01:33:17.000000'", createTimeType(6), "TIME '01:33:17.000000'")
            .addRoundTrip(inputType, "TIME '03:17:17.000000'", createTimeType(6), "TIME '03:17:17.000000'")
            .addRoundTrip(inputType, "TIME '10:01:17.100000'", createTimeType(6), "TIME '10:01:17.100000'")
            .addRoundTrip(inputType, "TIME '13:18:03.000000'", createTimeType(6), "TIME '13:18:03.000000'")
            .addRoundTrip(inputType, "TIME '14:18:03.000000'", createTimeType(6), "TIME '14:18:03.000000'")
            .addRoundTrip(inputType, "TIME '15:18:03.000000'", createTimeType(6), "TIME '15:18:03.000000'")
            .addRoundTrip(inputType, "TIME '16:18:03.123456'", createTimeType(6), "TIME '16:18:03.123456'")
            .addRoundTrip(inputType, "TIME '19:01:17.000000'", createTimeType(6), "TIME '19:01:17.000000'")
            .addRoundTrip(inputType, "TIME '20:01:17.000000'", createTimeType(6), "TIME '20:01:17.000000'")
            .addRoundTrip(inputType, "TIME '21:01:17.000001'", createTimeType(6), "TIME '21:01:17.000001'")
            .addRoundTrip(inputType, "TIME '22:59:59.000000'", createTimeType(6), "TIME '22:59:59.000000'")
            .addRoundTrip(inputType, "TIME '23:59:59.000000'", createTimeType(6), "TIME '23:59:59.000000'")
            .addRoundTrip(inputType, "TIME '23:59:59.999999'", createTimeType(6), "TIME '23:59:59.999999'");
}

@Test(dataProvider = "datetime_test_parameters") + public void testTimestamp(ZoneId sessionZone) + { + Session session = Session.builder(getSession()) + .setTimeZoneKey(getTimeZoneKey(sessionZone.getId())) + .build(); + // Redshift doesn't allow timestamp precision to be specified + timestampTypeTests("timestamp(6)") + .execute(getQueryRunner(), session, trinoCreateAsSelect(session, "timestamp_from_trino")); + timestampTypeTests("timestamp") + .execute(getQueryRunner(), session, redshiftCreateAndInsert("timestamp_from_jdbc")); + + // some time BC + SqlDataTypeTest.create() + .addRoundTrip("timestamp(6)", "TIMESTAMP '-0100-01-01 00:00:00.000000'", createTimestampType(6), "TIMESTAMP '-0100-01-01 00:00:00.000000'") + .execute(getQueryRunner(), session, trinoCreateAsSelect(session, "test_date")); + SqlDataTypeTest.create() + .addRoundTrip("timestamp", "TIMESTAMP '0101-01-01 00:00:00 BC'", createTimestampType(6), "TIMESTAMP '-0100-01-01 00:00:00.000000'") + .execute(getQueryRunner(), session, redshiftCreateAndInsert("test_date")); + } + + private static SqlDataTypeTest timestampTypeTests(String inputType) + { + return SqlDataTypeTest.create() + .addRoundTrip(inputType, "TIMESTAMP '0001-01-01 00:00:00.000000'", createTimestampType(6), "TIMESTAMP '0001-01-01 00:00:00.000000'") // first day of AD + .addRoundTrip(inputType, "TIMESTAMP '1500-01-01 00:00:00.000000'", createTimestampType(6), "TIMESTAMP '1500-01-01 00:00:00.000000'") // sometime before julian->gregorian switch + .addRoundTrip(inputType, "TIMESTAMP '1600-01-01 00:00:00.000000'", createTimestampType(6), "TIMESTAMP '1600-01-01 00:00:00.000000'") // long ago but after julian->gregorian switch + .addRoundTrip(inputType, "TIMESTAMP '1958-01-01 13:18:03.123456'", createTimestampType(6), "TIMESTAMP '1958-01-01 13:18:03.123456'") // before epoch + .addRoundTrip(inputType, "TIMESTAMP '2019-03-18 10:09:17.987654'", createTimestampType(6), "TIMESTAMP '2019-03-18 10:09:17.987654'") // after epoch + .addRoundTrip(inputType, 
"TIMESTAMP '2018-10-28 01:33:17.456789'", createTimestampType(6), "TIMESTAMP '2018-10-28 01:33:17.456789'") // time doubled in JVM + .addRoundTrip(inputType, "TIMESTAMP '2018-10-28 03:33:33.333333'", createTimestampType(6), "TIMESTAMP '2018-10-28 03:33:33.333333'") // time doubled in Vilnius + .addRoundTrip(inputType, "TIMESTAMP '1970-01-01 00:00:00.000000'", createTimestampType(6), "TIMESTAMP '1970-01-01 00:00:00.000000'") // time gap in JVM + .addRoundTrip(inputType, "TIMESTAMP '2018-03-25 03:17:17.000000'", createTimestampType(6), "TIMESTAMP '2018-03-25 03:17:17.000000'") // time gap in Vilnius + .addRoundTrip(inputType, "TIMESTAMP '1986-01-01 00:13:07.000000'", createTimestampType(6), "TIMESTAMP '1986-01-01 00:13:07.000000'") // time gap in Kathmandu + // Full time precision + .addRoundTrip(inputType, "TIMESTAMP '1969-12-31 23:59:59.999999'", createTimestampType(6), "TIMESTAMP '1969-12-31 23:59:59.999999'") + .addRoundTrip(inputType, "TIMESTAMP '1970-01-01 00:00:00.999999'", createTimestampType(6), "TIMESTAMP '1970-01-01 00:00:00.999999'"); + } + + @Test(dataProvider = "datetime_test_parameters") + public void testTimestampWithTimeZone(ZoneId sessionZone) + { + Session session = Session.builder(getSession()) + .setTimeZoneKey(getTimeZoneKey(sessionZone.getId())) + .build(); + + SqlDataTypeTest.create() + // test arbitrary time for all supported precisions + .addRoundTrip("timestamp(0) with time zone", "TIMESTAMP '2022-09-27 12:34:56 UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '2022-09-27 12:34:56.000000 UTC'") + .addRoundTrip("timestamp(1) with time zone", "TIMESTAMP '2022-09-27 12:34:56.1 UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '2022-09-27 12:34:56.100000 UTC'") + .addRoundTrip("timestamp(2) with time zone", "TIMESTAMP '2022-09-27 12:34:56.12 UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '2022-09-27 12:34:56.120000 UTC'") + .addRoundTrip("timestamp(3) with time zone", "TIMESTAMP '2022-09-27 12:34:56.123 UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '2022-09-27 12:34:56.123000 
UTC'") + .addRoundTrip("timestamp(4) with time zone", "TIMESTAMP '2022-09-27 12:34:56.1234 UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '2022-09-27 12:34:56.123400 UTC'") + .addRoundTrip("timestamp(5) with time zone", "TIMESTAMP '2022-09-27 12:34:56.12345 UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '2022-09-27 12:34:56.123450 UTC'") + .addRoundTrip("timestamp(6) with time zone", "TIMESTAMP '2022-09-27 12:34:56.123456 UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '2022-09-27 12:34:56.123456 UTC'") + + // short timestamp with time zone + // .addRoundTrip("timestamp(3) with time zone", "TIMESTAMP '-4712-01-01 00:00:00 UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '-4712-01-01 00:00:00.000000 UTC'") // min value in Redshift + // .addRoundTrip("timestamp(3) with time zone", "TIMESTAMP '0001-01-01 00:00:00 UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '0001-01-01 00:00:00.000000 UTC'") // first day of AD + .addRoundTrip("timestamp(3) with time zone", "TIMESTAMP '1582-10-04 23:59:59.999 UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '1582-10-04 23:59:59.999000 UTC'") + .addRoundTrip("timestamp(3) with time zone", "TIMESTAMP '1582-10-05 00:00:00.000 UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '1582-10-05 00:00:00.000000 UTC'") // begin julian->gregorian switch + .addRoundTrip("timestamp(3) with time zone", "TIMESTAMP '1582-10-14 23:59:59.999 UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '1582-10-14 23:59:59.999000 UTC'") // end julian->gregorian switch + .addRoundTrip("timestamp(3) with time zone", "TIMESTAMP '1582-10-15 00:00:00.000 UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '1582-10-15 00:00:00.000000 UTC'") + .addRoundTrip("timestamp(3) with time zone", "TIMESTAMP '1970-01-01 00:00:00 UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '1970-01-01 00:00:00.000000 UTC'") + .addRoundTrip("timestamp(3) with time zone", "TIMESTAMP '1970-01-01 00:00:00.1 UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '1970-01-01 00:00:00.100000 UTC'") + .addRoundTrip("timestamp(3) with time zone", "TIMESTAMP '1970-01-01 00:00:00.9 UTC'", TIMESTAMP_TZ_MICROS, 
"TIMESTAMP '1970-01-01 00:00:00.900000 UTC'") + .addRoundTrip("timestamp(3) with time zone", "TIMESTAMP '1970-01-01 00:00:00.123 UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '1970-01-01 00:00:00.123000 UTC'") + .addRoundTrip("timestamp(3) with time zone", "TIMESTAMP '1970-01-01 00:00:00.999 UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '1970-01-01 00:00:00.999000 UTC'") + .addRoundTrip("timestamp(3) with time zone", "TIMESTAMP '1970-01-01 00:00:00.123 UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '1970-01-01 00:00:00.123000 UTC'") + .addRoundTrip("timestamp(3) with time zone", "TIMESTAMP '1986-01-01 00:13:07 UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '1986-01-01 00:13:07.000000 UTC'") // time gap in Kathmandu + .addRoundTrip("timestamp(3) with time zone", "TIMESTAMP '2018-10-28 01:33:17.456 UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '2018-10-28 01:33:17.456000 UTC'") // time doubled in JVM + .addRoundTrip("timestamp(3) with time zone", "TIMESTAMP '2018-10-28 03:33:33.333 UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '2018-10-28 03:33:33.333000 UTC'") // time doubled in Vilnius + .addRoundTrip("timestamp(3) with time zone", "TIMESTAMP '2018-03-25 03:17:17.000 UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '2018-03-25 03:17:17.000000 UTC'") // time gap in Vilnius + .addRoundTrip("timestamp(3) with time zone", "TIMESTAMP '2020-09-27 12:34:56.1 UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '2020-09-27 12:34:56.100000 UTC'") + .addRoundTrip("timestamp(3) with time zone", "TIMESTAMP '2020-09-27 12:34:56.9 UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '2020-09-27 12:34:56.900000 UTC'") + .addRoundTrip("timestamp(3) with time zone", "TIMESTAMP '2020-09-27 12:34:56.123 UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '2020-09-27 12:34:56.123000 UTC'") + .addRoundTrip("timestamp(3) with time zone", "TIMESTAMP '2020-09-27 12:34:56.999 UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '2020-09-27 12:34:56.999000 UTC'") + .addRoundTrip("timestamp(3) with time zone", "TIMESTAMP '2020-09-27 12:34:56.123 UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '2020-09-27 
12:34:56.123000 UTC'") + .addRoundTrip("timestamp(3) with time zone", "TIMESTAMP '73326-09-11 20:14:45.247 UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '73326-09-11 20:14:45.247000 UTC'") // max value in Trino + .addRoundTrip("timestamp(3) with time zone", "NULL", TIMESTAMP_TZ_MICROS, "CAST(NULL AS timestamp(6) with time zone)") + + // long timestamp with time zone + // .addRoundTrip("timestamp(6) with time zone", "TIMESTAMP '0001-01-01 00:00:00 UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '0001-01-01 00:00:00.000000 UTC'") // first day of AD + .addRoundTrip("timestamp(6) with time zone", "TIMESTAMP '1582-10-04 23:59:59.999999 UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '1582-10-04 23:59:59.999999 UTC'") + .addRoundTrip("timestamp(6) with time zone", "TIMESTAMP '1582-10-05 00:00:00.000000 UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '1582-10-05 00:00:00.000000 UTC'") // begin julian->gregorian switch + .addRoundTrip("timestamp(6) with time zone", "TIMESTAMP '1582-10-14 23:59:59.999999 UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '1582-10-14 23:59:59.999999 UTC'") // end julian->gregorian switch + .addRoundTrip("timestamp(6) with time zone", "TIMESTAMP '1582-10-15 00:00:00.000000 UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '1582-10-15 00:00:00.000000 UTC'") + .addRoundTrip("timestamp(6) with time zone", "TIMESTAMP '1970-01-01 00:00:00 UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '1970-01-01 00:00:00.000000 UTC'") + .addRoundTrip("timestamp(6) with time zone", "TIMESTAMP '1970-01-01 00:00:00.1 UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '1970-01-01 00:00:00.100000 UTC'") + .addRoundTrip("timestamp(6) with time zone", "TIMESTAMP '1970-01-01 00:00:00.9 UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '1970-01-01 00:00:00.900000 UTC'") + .addRoundTrip("timestamp(6) with time zone", "TIMESTAMP '1970-01-01 00:00:00.123 UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '1970-01-01 00:00:00.123000 UTC'") + .addRoundTrip("timestamp(6) with time zone", "TIMESTAMP '1970-01-01 00:00:00.123000 UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '1970-01-01 
00:00:00.123000 UTC'") + .addRoundTrip("timestamp(6) with time zone", "TIMESTAMP '1970-01-01 00:00:00.999 UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '1970-01-01 00:00:00.999000 UTC'") + .addRoundTrip("timestamp(6) with time zone", "TIMESTAMP '1970-01-01 00:00:00.123456 UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '1970-01-01 00:00:00.123456 UTC'") + .addRoundTrip("timestamp(6) with time zone", "TIMESTAMP '1986-01-01 00:13:07.000000 UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '1986-01-01 00:13:07.000000 UTC'") // time gap in Kathmandu (long timestamp_tz) + .addRoundTrip("timestamp(6) with time zone", "TIMESTAMP '2018-10-28 01:33:17.456789 UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '2018-10-28 01:33:17.456789 UTC'") // time doubled in JVM + .addRoundTrip("timestamp(6) with time zone", "TIMESTAMP '2018-10-28 03:33:33.333333 UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '2018-10-28 03:33:33.333333 UTC'") // time doubled in Vilnius + .addRoundTrip("timestamp(6) with time zone", "TIMESTAMP '2018-03-25 03:17:17.000000 UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '2018-03-25 03:17:17.000000 UTC'") // time gap in Vilnius + .addRoundTrip("timestamp(6) with time zone", "TIMESTAMP '2020-09-27 12:34:56.1 UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '2020-09-27 12:34:56.100000 UTC'") + .addRoundTrip("timestamp(6) with time zone", "TIMESTAMP '2020-09-27 12:34:56.9 UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '2020-09-27 12:34:56.900000 UTC'") + .addRoundTrip("timestamp(6) with time zone", "TIMESTAMP '2020-09-27 12:34:56.123 UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '2020-09-27 12:34:56.123000 UTC'") + .addRoundTrip("timestamp(6) with time zone", "TIMESTAMP '2020-09-27 12:34:56.123000 UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '2020-09-27 12:34:56.123000 UTC'") + .addRoundTrip("timestamp(6) with time zone", "TIMESTAMP '2020-09-27 12:34:56.999 UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '2020-09-27 12:34:56.999000 UTC'") + .addRoundTrip("timestamp(6) with time zone", "TIMESTAMP '2020-09-27 12:34:56.123456 UTC'", TIMESTAMP_TZ_MICROS, 
"TIMESTAMP '2020-09-27 12:34:56.123456 UTC'") + .addRoundTrip("timestamp(6) with time zone", "TIMESTAMP '73326-09-11 20:14:45.247999 UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '73326-09-11 20:14:45.247999 UTC'") // max value in Trino + .addRoundTrip("timestamp(6) with time zone", "NULL", TIMESTAMP_TZ_MICROS, "CAST(NULL AS timestamp(6) with time zone)") + .execute(getQueryRunner(), session, trinoCreateAsSelect(getSession(), "test_timestamp_tz")); + + redshiftTimestampWithTimeZoneTests("timestamptz") + .execute(getQueryRunner(), session, redshiftCreateAndInsert("test_timestamp_tz")); + redshiftTimestampWithTimeZoneTests("timestamp with time zone") + .execute(getQueryRunner(), session, redshiftCreateAndInsert("test_timestamp_tz")); + } + + private static SqlDataTypeTest redshiftTimestampWithTimeZoneTests(String inputType) + { + return SqlDataTypeTest.create() + // .addRoundTrip(inputType, "TIMESTAMP '4713-01-01 00:00:00 BC' AT TIME ZONE 'UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '-4712-01-01 00:00:00.000000 UTC'") // min value in Redshift + // .addRoundTrip(inputType, "TIMESTAMP '0001-01-01 00:00:00' AT TIME ZONE 'UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '0001-01-01 00:00:00.000000 UTC'") // first day of AD + .addRoundTrip(inputType, "TIMESTAMP '1582-10-04 23:59:59.999999' AT TIME ZONE 'UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '1582-10-04 23:59:59.999999 UTC'") + .addRoundTrip(inputType, "TIMESTAMP '1582-10-05 00:00:00.000000' AT TIME ZONE 'UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '1582-10-05 00:00:00.000000 UTC'") // begin julian->gregorian switch + .addRoundTrip(inputType, "TIMESTAMP '1582-10-14 23:59:59.999999' AT TIME ZONE 'UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '1582-10-14 23:59:59.999999 UTC'") // end julian->gregorian switch + .addRoundTrip(inputType, "TIMESTAMP '1582-10-15 00:00:00.000000' AT TIME ZONE 'UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '1582-10-15 00:00:00.000000 UTC'") + .addRoundTrip(inputType, "TIMESTAMP '1970-01-01 00:00:00' AT TIME ZONE 'UTC'", 
TIMESTAMP_TZ_MICROS, "TIMESTAMP '1970-01-01 00:00:00.000000 UTC'") + .addRoundTrip(inputType, "TIMESTAMP '1970-01-01 00:00:00.1' AT TIME ZONE 'UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '1970-01-01 00:00:00.100000 UTC'") + .addRoundTrip(inputType, "TIMESTAMP '1970-01-01 00:00:00.9' AT TIME ZONE 'UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '1970-01-01 00:00:00.900000 UTC'") + .addRoundTrip(inputType, "TIMESTAMP '1970-01-01 00:00:00.123' AT TIME ZONE 'UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '1970-01-01 00:00:00.123000 UTC'") + .addRoundTrip(inputType, "TIMESTAMP '1970-01-01 00:00:00.123000' AT TIME ZONE 'UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '1970-01-01 00:00:00.123000 UTC'") + .addRoundTrip(inputType, "TIMESTAMP '1970-01-01 00:00:00.999' AT TIME ZONE 'UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '1970-01-01 00:00:00.999000 UTC'") + .addRoundTrip(inputType, "TIMESTAMP '1970-01-01 00:00:00.123456' AT TIME ZONE 'UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '1970-01-01 00:00:00.123456 UTC'") + .addRoundTrip(inputType, "TIMESTAMP '1986-01-01 00:13:07.000000' AT TIME ZONE 'UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '1986-01-01 00:13:07.000000 UTC'") // time gap in Kathmandu + .addRoundTrip(inputType, "TIMESTAMP '2018-10-28 01:33:17.456789' AT TIME ZONE 'UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '2018-10-28 01:33:17.456789 UTC'") // time doubled in JVM + .addRoundTrip(inputType, "TIMESTAMP '2018-10-28 03:33:33.333333' AT TIME ZONE 'UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '2018-10-28 03:33:33.333333 UTC'") // time doubled in Vilnius + .addRoundTrip(inputType, "TIMESTAMP '2018-03-25 03:17:17.000000' AT TIME ZONE 'UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '2018-03-25 03:17:17.000000 UTC'") // time gap in Vilnius + .addRoundTrip(inputType, "TIMESTAMP '2020-09-27 12:34:56.1' AT TIME ZONE 'UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '2020-09-27 12:34:56.100000 UTC'") + .addRoundTrip(inputType, "TIMESTAMP '2020-09-27 12:34:56.9' AT TIME ZONE 'UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '2020-09-27 12:34:56.900000 UTC'") 
+ .addRoundTrip(inputType, "TIMESTAMP '2020-09-27 12:34:56.123' AT TIME ZONE 'UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '2020-09-27 12:34:56.123000 UTC'") + .addRoundTrip(inputType, "TIMESTAMP '2020-09-27 12:34:56.123000' AT TIME ZONE 'UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '2020-09-27 12:34:56.123000 UTC'") + .addRoundTrip(inputType, "TIMESTAMP '2020-09-27 12:34:56.999' AT TIME ZONE 'UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '2020-09-27 12:34:56.999000 UTC'") + .addRoundTrip(inputType, "TIMESTAMP '2020-09-27 12:34:56.123456' AT TIME ZONE 'UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '2020-09-27 12:34:56.123456 UTC'") + .addRoundTrip(inputType, "TIMESTAMP '73326-09-11 20:14:45.247999' AT TIME ZONE 'UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '73326-09-11 20:14:45.247999 UTC'"); // max value in Trino + } + + @Test + public void testTimestampWithTimeZoneCoercion() + { + SqlDataTypeTest.create() + // short timestamp with time zone + .addRoundTrip("timestamp(3) with time zone", "TIMESTAMP '1970-01-01 00:00:00.12341 UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '1970-01-01 00:00:00.123000 UTC'") // round down + .addRoundTrip("timestamp(3) with time zone", "TIMESTAMP '1970-01-01 00:00:00.123499 UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '1970-01-01 00:00:00.123000 UTC'") // round up, end result rounds down + .addRoundTrip("timestamp(3) with time zone", "TIMESTAMP '1970-01-01 00:00:00.1235 UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '1970-01-01 00:00:00.124000 UTC'") // round up + .addRoundTrip("timestamp(3) with time zone", "TIMESTAMP '1970-01-01 00:00:00.111222333444 UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '1970-01-01 00:00:00.111000 UTC'") // max precision + .addRoundTrip("timestamp(3) with time zone", "TIMESTAMP '1970-01-01 00:00:00.9999995 UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '1970-01-01 00:00:01.000000 UTC'") // round up to next second + .addRoundTrip("timestamp(3) with time zone", "TIMESTAMP '1970-01-01 23:59:59.9999995 UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '1970-01-02 00:00:00.000000 UTC'") // 
round up to next day + .addRoundTrip("timestamp(3) with time zone", "TIMESTAMP '1969-12-31 23:59:59.9999995 UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '1970-01-01 00:00:00.000000 UTC'") // negative epoch + .addRoundTrip("timestamp(3) with time zone", "TIMESTAMP '1969-12-31 23:59:59.999499999999 UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '1969-12-31 23:59:59.999000 UTC'") // negative epoch + .addRoundTrip("timestamp(3) with time zone", "TIMESTAMP '1969-12-31 23:59:59.9994 UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '1969-12-31 23:59:59.999000 UTC'") // negative epoch + + // long timestamp with time zone + .addRoundTrip("timestamp(6) with time zone", "TIMESTAMP '1970-01-01 00:00:00.1234561 UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '1970-01-01 00:00:00.123456 UTC'") // round down + .addRoundTrip("timestamp(6) with time zone", "TIMESTAMP '1970-01-01 00:00:00.123456499 UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '1970-01-01 00:00:00.123456 UTC'") // nanos round up, end result rounds down + .addRoundTrip("timestamp(6) with time zone", "TIMESTAMP '1970-01-01 00:00:00.1234565 UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '1970-01-01 00:00:00.123457 UTC'") // round up + .addRoundTrip("timestamp(6) with time zone", "TIMESTAMP '1970-01-01 00:00:00.111222333444 UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '1970-01-01 00:00:00.111222 UTC'") // max precision + .addRoundTrip("timestamp(6) with time zone", "TIMESTAMP '1970-01-01 00:00:00.9999995 UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '1970-01-01 00:00:01.000000 UTC'") // round up to next second + .addRoundTrip("timestamp(6) with time zone", "TIMESTAMP '1970-01-01 23:59:59.9999995 UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '1970-01-02 00:00:00.000000 UTC'") // round up to next day + .addRoundTrip("timestamp(6) with time zone", "TIMESTAMP '1969-12-31 23:59:59.9999995 UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '1970-01-01 00:00:00.000000 UTC'") // negative epoch + .addRoundTrip("timestamp(6) with time zone", "TIMESTAMP '1969-12-31 23:59:59.999999499999 UTC'", 
TIMESTAMP_TZ_MICROS, "TIMESTAMP '1969-12-31 23:59:59.999999 UTC'") // negative epoch + .addRoundTrip("timestamp(6) with time zone", "TIMESTAMP '1969-12-31 23:59:59.9999994 UTC'", TIMESTAMP_TZ_MICROS, "TIMESTAMP '1969-12-31 23:59:59.999999 UTC'") // negative epoch + .execute(getQueryRunner(), trinoCreateAsSelect(getSession(), "test_timestamp_tz")); + } + + @Test + public void testTimestampWithTimeZoneOverflow() + { + // The min timestamp with time zone value in Trino is smaller than Redshift + try (TestTable table = new TestTable(getTrinoExecutor(), "timestamp_tz_min", "(ts timestamp(3) with time zone)")) { + assertQueryFails( + format("INSERT INTO %s VALUES (TIMESTAMP '-69387-04-22 03:45:14.752 UTC')", table.getName()), + "\\QMinimum timestamp with time zone in Redshift is -4712-01-01 00:00:00.000000: -69387-04-22 03:45:14.752000"); + } + try (TestTable table = new TestTable(getTrinoExecutor(), "timestamp_tz_min", "(ts timestamp(6) with time zone)")) { + assertQueryFails( + format("INSERT INTO %s VALUES (TIMESTAMP '-69387-04-22 03:45:14.752000 UTC')", table.getName()), + "\\QMinimum timestamp with time zone in Redshift is -4712-01-01 00:00:00.000000: -69387-04-22 03:45:14.752000"); + } + + // The max timestamp with time zone value in Redshift is larger than Trino + try (TestTable table = new TestTable(getRedshiftExecutor(), TEST_SCHEMA + ".timestamp_tz_max", "(ts timestamptz)", ImmutableList.of("TIMESTAMP '294276-12-31 23:59:59' AT TIME ZONE 'UTC'"))) { + assertThatThrownBy(() -> query("SELECT * FROM " + table.getName())) + .hasMessage("Millis overflow: 9224318015999000"); + } + } + + @DataProvider(name = "datetime_test_parameters") + public Object[][] dataProviderForDatetimeTests() + { + return new Object[][] { + {UTC}, + {jvmZone}, + {vilnius}, + {kathmandu}, + {testZone}, + }; + } + + @Test + public void testUnsupportedDateTimeTypes() + { + assertCreateFails( + "test_time_with_time_zone", + "(value TIME WITH TIME ZONE)", + "Unsupported column type: (?i)time.* 
with time zone"); + } + + @Test + public void testDateLimits() + { + // We can't test the exact date limits because Redshift doesn't say + // what they are, so we test one date on either side. + try (TestTable table = testTable("test_date_limits", "(d date)")) { + // First day of smallest year that Redshift supports (based on its documentation) + assertUpdate(format("INSERT INTO %s VALUES (DATE '-4712-01-01')", table.getName()), 1); + // Small date observed to not work + assertThatThrownBy(() -> computeActual(format("INSERT INTO %s VALUES (DATE '-4713-06-01')", table.getName()))) + .hasStackTraceContaining("ERROR: date out of range: \"4714-06-01 BC\""); + + // Last day of the largest year that Redshift supports (based on its documentation) + assertUpdate(format("INSERT INTO %s VALUES (DATE '294275-12-31')", table.getName()), 1); + // Large date observed to not work + assertThatThrownBy(() -> computeActual(format("INSERT INTO %s VALUES (DATE '5875000-01-01')", table.getName()))) + .hasStackTraceContaining("ERROR: date out of range: \"5875000-01-01 AD\""); + } + } + + @Test + public void testLimitedTimePrecision() + { + Map> testCasesByPrecision = groupTestCasesByInput( + "TIME '\\d{2}:\\d{2}:\\d{2}(\\.\\d{1,12})?'", + input -> input.length() - "TIME '00:00:00'".length() - (input.contains(".") ? 
1 : 0), + List.of( + // No rounding + new TestCase("TIME '00:00:00'", "TIME '00:00:00'"), + new TestCase("TIME '00:00:00.000000'", "TIME '00:00:00.000000'"), + new TestCase("TIME '00:00:00.123456'", "TIME '00:00:00.123456'"), + new TestCase("TIME '12:34:56'", "TIME '12:34:56'"), + new TestCase("TIME '12:34:56.123456'", "TIME '12:34:56.123456'"), + new TestCase("TIME '23:59:59'", "TIME '23:59:59'"), + new TestCase("TIME '23:59:59.9'", "TIME '23:59:59.9'"), + new TestCase("TIME '23:59:59.999'", "TIME '23:59:59.999'"), + new TestCase("TIME '23:59:59.999999'", "TIME '23:59:59.999999'"), + // round down + new TestCase("TIME '00:00:00.0000001'", "TIME '00:00:00.000000'"), + new TestCase("TIME '00:00:00.000000000001'", "TIME '00:00:00.000000'"), + new TestCase("TIME '12:34:56.1234561'", "TIME '12:34:56.123456'"), + // round down, maximal value + new TestCase("TIME '00:00:00.0000004'", "TIME '00:00:00.000000'"), + new TestCase("TIME '00:00:00.000000449'", "TIME '00:00:00.000000'"), + new TestCase("TIME '00:00:00.000000444449'", "TIME '00:00:00.000000'"), + // round up, minimal value + new TestCase("TIME '00:00:00.0000005'", "TIME '00:00:00.000001'"), + new TestCase("TIME '00:00:00.000000500'", "TIME '00:00:00.000001'"), + new TestCase("TIME '00:00:00.000000500000'", "TIME '00:00:00.000001'"), + // round up, maximal value + new TestCase("TIME '00:00:00.0000009'", "TIME '00:00:00.000001'"), + new TestCase("TIME '00:00:00.000000999'", "TIME '00:00:00.000001'"), + new TestCase("TIME '00:00:00.000000999999'", "TIME '00:00:00.000001'"), + // round up to next day, minimal value + new TestCase("TIME '23:59:59.9999995'", "TIME '00:00:00.000000'"), + new TestCase("TIME '23:59:59.999999500'", "TIME '00:00:00.000000'"), + new TestCase("TIME '23:59:59.999999500000'", "TIME '00:00:00.000000'"), + // round up to next day, maximal value + new TestCase("TIME '23:59:59.9999999'", "TIME '00:00:00.000000'"), + new TestCase("TIME '23:59:59.999999999'", "TIME '00:00:00.000000'"), + new 
TestCase("TIME '23:59:59.999999999999'", "TIME '00:00:00.000000'"), + // don't round to next day (round down near upper bound) + new TestCase("TIME '23:59:59.9999994'", "TIME '23:59:59.999999'"), + new TestCase("TIME '23:59:59.999999499'", "TIME '23:59:59.999999'"), + new TestCase("TIME '23:59:59.999999499999'", "TIME '23:59:59.999999'"))); + + for (Entry> entry : testCasesByPrecision.entrySet()) { + String tableName = format("test_time_precision_%d_%s", entry.getKey(), randomNameSuffix()); + runTestCases(tableName, entry.getValue()); + } + } + + @Test + public void testLimitedTimestampPrecision() + { + Map> testCasesByPrecision = groupTestCasesByInput( + "TIMESTAMP '\\d{4}-\\d{2}-\\d{2} \\d{2}:\\d{2}:\\d{2}(\\.\\d{1,12})?'", + input -> input.length() - "TIMESTAMP '0000-00-00 00:00:00'".length() - (input.contains(".") ? 1 : 0), + // No rounding + new TestCase("TIMESTAMP '1970-01-01 00:00:00'", "TIMESTAMP '1970-01-01 00:00:00'"), + new TestCase("TIMESTAMP '2020-11-03 12:34:56'", "TIMESTAMP '2020-11-03 12:34:56'"), + new TestCase("TIMESTAMP '1969-12-31 00:00:00.000000'", "TIMESTAMP '1969-12-31 00:00:00.000000'"), + + new TestCase("TIMESTAMP '1970-01-01 00:00:00.123456'", "TIMESTAMP '1970-01-01 00:00:00.123456'"), + new TestCase("TIMESTAMP '2020-11-03 12:34:56.123456'", "TIMESTAMP '2020-11-03 12:34:56.123456'"), + new TestCase("TIMESTAMP '1969-12-31 23:59:59'", "TIMESTAMP '1969-12-31 23:59:59'"), + + new TestCase("TIMESTAMP '1970-01-01 23:59:59.9'", "TIMESTAMP '1970-01-01 23:59:59.9'"), + new TestCase("TIMESTAMP '2020-11-03 23:59:59.999'", "TIMESTAMP '2020-11-03 23:59:59.999'"), + new TestCase("TIMESTAMP '1969-12-31 23:59:59.999999'", "TIMESTAMP '1969-12-31 23:59:59.999999'"), + // round down + new TestCase("TIMESTAMP '1969-12-31 00:00:00.0000001'", "TIMESTAMP '1969-12-31 00:00:00.000000'"), + new TestCase("TIMESTAMP '1970-01-01 00:00:00.000000000001'", "TIMESTAMP '1970-01-01 00:00:00.000000'"), + new TestCase("TIMESTAMP '2020-11-03 12:34:56.1234561'", "TIMESTAMP 
'2020-11-03 12:34:56.123456'"), + // round down, maximal value + new TestCase("TIMESTAMP '2020-11-03 00:00:00.0000004'", "TIMESTAMP '2020-11-03 00:00:00.000000'"), + new TestCase("TIMESTAMP '1969-12-31 00:00:00.000000449'", "TIMESTAMP '1969-12-31 00:00:00.000000'"), + new TestCase("TIMESTAMP '1970-01-01 00:00:00.000000444449'", "TIMESTAMP '1970-01-01 00:00:00.000000'"), + // round up, minimal value + new TestCase("TIMESTAMP '1970-01-01 00:00:00.0000005'", "TIMESTAMP '1970-01-01 00:00:00.000001'"), + new TestCase("TIMESTAMP '2020-11-03 00:00:00.000000500'", "TIMESTAMP '2020-11-03 00:00:00.000001'"), + new TestCase("TIMESTAMP '1969-12-31 00:00:00.000000500000'", "TIMESTAMP '1969-12-31 00:00:00.000001'"), + // round up, maximal value + new TestCase("TIMESTAMP '1969-12-31 00:00:00.0000009'", "TIMESTAMP '1969-12-31 00:00:00.000001'"), + new TestCase("TIMESTAMP '1970-01-01 00:00:00.000000999'", "TIMESTAMP '1970-01-01 00:00:00.000001'"), + new TestCase("TIMESTAMP '2020-11-03 00:00:00.000000999999'", "TIMESTAMP '2020-11-03 00:00:00.000001'"), + // round up to next year, minimal value + new TestCase("TIMESTAMP '2020-12-31 23:59:59.9999995'", "TIMESTAMP '2021-01-01 00:00:00.000000'"), + new TestCase("TIMESTAMP '1969-12-31 23:59:59.999999500'", "TIMESTAMP '1970-01-01 00:00:00.000000'"), + new TestCase("TIMESTAMP '1970-01-01 23:59:59.999999500000'", "TIMESTAMP '1970-01-02 00:00:00.000000'"), + // round up to next day/year, maximal value + new TestCase("TIMESTAMP '1970-01-01 23:59:59.9999999'", "TIMESTAMP '1970-01-02 00:00:00.000000'"), + new TestCase("TIMESTAMP '2020-12-31 23:59:59.999999999'", "TIMESTAMP '2021-01-01 00:00:00.000000'"), + new TestCase("TIMESTAMP '1969-12-31 23:59:59.999999999999'", "TIMESTAMP '1970-01-01 00:00:00.000000'"), + // don't round to next year (round down near upper bound) + new TestCase("TIMESTAMP '1969-12-31 23:59:59.9999994'", "TIMESTAMP '1969-12-31 23:59:59.999999'"), + new TestCase("TIMESTAMP '1970-01-01 23:59:59.999999499'", "TIMESTAMP 
'1970-01-01 23:59:59.999999'"), + new TestCase("TIMESTAMP '2020-12-31 23:59:59.999999499999'", "TIMESTAMP '2020-12-31 23:59:59.999999'")); + + for (Entry> entry : testCasesByPrecision.entrySet()) { + String tableName = format("test_timestamp_precision_%d_%s", entry.getKey(), randomNameSuffix()); + runTestCases(tableName, entry.getValue()); + } + } + + private static Map> groupTestCasesByInput(String inputRegex, Function classifier, TestCase... testCases) + { + return groupTestCasesByInput(inputRegex, classifier, Arrays.asList(testCases)); + } + + private static Map> groupTestCasesByInput(String inputRegex, Function classifier, List testCases) + { + return testCases.stream() + .peek(test -> { + if (!test.input().matches(inputRegex)) { + throw new RuntimeException("Bad test case input format: " + test.input()); + } + }) + .collect(groupingBy(classifier.compose(TestCase::input))); + } + + private void runTestCases(String tableName, List testCases) + { + // Must use CTAS instead of TestTable because if the table is created before the insert, + // the type mapping will treat it as TIME(6) no matter what it was created as. 
+ getTrinoExecutor().execute(format( + "CREATE TABLE %s AS SELECT * FROM (VALUES %s) AS t (id, value)", + tableName, + testCases.stream() + .map(testCase -> format("(%d, %s)", testCase.id(), testCase.input())) + .collect(joining("), (", "(", ")")))); + try { + assertQuery( + format("SELECT value FROM %s ORDER BY id", tableName), + testCases.stream() + .map(TestCase::expected) + .collect(joining("), (", "VALUES (", ")"))); + } + finally { + getTrinoExecutor().execute("DROP TABLE " + tableName); + } + } + + @Test + public static void checkIllegalRedshiftTimePrecision() + { + assertRedshiftCreateFails( + "check_redshift_time_precision_error", + "(t TIME(6))", + "ERROR: time column does not support precision."); + } + + @Test + public static void checkIllegalRedshiftTimestampPrecision() + { + assertRedshiftCreateFails( + "check_redshift_timestamp_precision_error", + "(t TIMESTAMP(6))", + "ERROR: timestamp column does not support precision."); + } + + /** + * Assert that a {@code CREATE TABLE} statement made from Redshift fails, + * and drop the table if it doesn't fail. + */ + private static void assertRedshiftCreateFails(String tableNamePrefix, String tableBody, String message) + { + String tableName = tableNamePrefix + "_" + randomNameSuffix(); + try { + assertThatThrownBy(() -> getRedshiftExecutor() + .execute(format("CREATE TABLE %s %s", tableName, tableBody))) + .getCause() + .as("Redshift create fails for %s %s", tableName, tableBody) + .isInstanceOf(SQLException.class) + .hasMessage(message); + } + catch (AssertionError failure) { + // If the table was created, clean it up because the tests run on a shared Redshift instance + try { + getRedshiftExecutor().execute("DROP TABLE IF EXISTS " + tableName); + } + catch (Throwable e) { + failure.addSuppressed(e); + } + throw failure; + } + } + + /** + * Assert that a {@code CREATE TABLE} statement fails, and drop the table + * if it doesn't fail. 
+ */ + private void assertCreateFails(String tableNamePrefix, String tableBody, String expectedMessageRegExp) + { + String tableName = tableNamePrefix + "_" + randomNameSuffix(); + try { + assertQueryFails(format("CREATE TABLE %s %s", tableName, tableBody), expectedMessageRegExp); + } + catch (AssertionError failure) { + // If the table was created, clean it up because the tests run on a shared Redshift instance + try { + getRedshiftExecutor().execute("DROP TABLE " + tableName); + } + catch (Throwable e) { + failure.addSuppressed(e); + } + throw failure; + } + } + + private DataSetup trinoCreateAsSelect(String tableNamePrefix) + { + return trinoCreateAsSelect(getQueryRunner().getDefaultSession(), tableNamePrefix); + } + + private DataSetup trinoCreateAsSelect(Session session, String tableNamePrefix) + { + return new CreateAsSelectDataSetup(new TrinoSqlExecutor(getQueryRunner(), session), tableNamePrefix); + } + + private static DataSetup redshiftCreateAndInsert(String tableNamePrefix) + { + return new CreateAndInsertDataSetup(getRedshiftExecutor(), TEST_SCHEMA + "." + tableNamePrefix); + } + + /** + * Create a table in the test schema using the JDBC. + * + *

Creating a test table normally doesn't use the correct schema. + */ + private static TestTable testTable(String namePrefix, String body) + { + return new TestTable(getRedshiftExecutor(), TEST_SCHEMA + "." + namePrefix, body); + } + + private SqlExecutor getTrinoExecutor() + { + return new TrinoSqlExecutor(getQueryRunner()); + } + + private static SqlExecutor getRedshiftExecutor() + { + Properties properties = new Properties(); + properties.setProperty("user", JDBC_USER); + properties.setProperty("password", JDBC_PASSWORD); + return new JdbcSqlExecutor(JDBC_URL, properties); + } + + private static void checkIsGap(ZoneId zone, LocalDateTime dateTime) + { + verify( + zone.getRules().getValidOffsets(dateTime).isEmpty(), + "Expected %s to be a gap in %s", dateTime, zone); + } + + private static void checkIsDoubled(ZoneId zone, LocalDateTime dateTime) + { + verify( + zone.getRules().getValidOffsets(dateTime).size() == 2, + "Expected %s to be doubled in %s", dateTime, zone); + } + + private static Function padVarchar(int length) + { + // Add the same padding as RedshiftClient.writeCharAsVarchar, but start from String, not Slice + return (input) -> input + " ".repeat(length - Utf8.encodedLength(input)); + } + + /** + * A pair of input and expected output from a test. + * Each instance has a unique ID. 
+ */ + private static class TestCase + { + private static final AtomicInteger LAST_ID = new AtomicInteger(); + + private final int id; + private final String input; + private final String expected; + + private TestCase(String input, String expected) + { + this.id = LAST_ID.incrementAndGet(); + this.input = input; + this.expected = expected; + } + + public int id() + { + return this.id; + } + + public String input() + { + return this.input; + } + + public String expected() + { + return this.expected; + } + } + + private static class TestView + implements AutoCloseable + { + final String name; + + TestView(String namePrefix, String definition) + { + name = requireNonNull(namePrefix) + "_" + randomNameSuffix(); + executeInRedshift(format("CREATE VIEW %s.%s AS %s", TEST_SCHEMA, name, definition)); + } + + @Override + public void close() + { + executeInRedshift(format("DROP VIEW IF EXISTS %s.%s", TEST_SCHEMA, name)); + } + } +} From 1e8887e3082ae969eb8bf81901b2c7b56ad38e3f Mon Sep 17 00:00:00 2001 From: Dain Sundstrom Date: Sat, 10 Dec 2022 18:27:03 -0800 Subject: [PATCH 04/24] Implement Redshift DELETE --- .../trino/plugin/redshift/RedshiftClient.java | 26 ++++++++++++++ .../redshift/TestRedshiftConnectorTest.java | 35 ++++++++++++++++++- 2 files changed, 60 insertions(+), 1 deletion(-) diff --git a/plugin/trino-redshift/src/main/java/io/trino/plugin/redshift/RedshiftClient.java b/plugin/trino-redshift/src/main/java/io/trino/plugin/redshift/RedshiftClient.java index 05d27ab5a0f0..1983260e43c9 100644 --- a/plugin/trino-redshift/src/main/java/io/trino/plugin/redshift/RedshiftClient.java +++ b/plugin/trino-redshift/src/main/java/io/trino/plugin/redshift/RedshiftClient.java @@ -27,6 +27,7 @@ import io.trino.plugin.jdbc.LongWriteFunction; import io.trino.plugin.jdbc.ObjectReadFunction; import io.trino.plugin.jdbc.ObjectWriteFunction; +import io.trino.plugin.jdbc.PreparedQuery; import io.trino.plugin.jdbc.QueryBuilder; import io.trino.plugin.jdbc.SliceWriteFunction; import 
io.trino.plugin.jdbc.StandardColumnMappings; @@ -68,8 +69,12 @@ import java.time.format.DateTimeFormatter; import java.time.format.DateTimeFormatterBuilder; import java.util.Optional; +import java.util.OptionalLong; import java.util.function.BiFunction; +import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.base.Verify.verify; +import static io.trino.plugin.jdbc.JdbcErrorCode.JDBC_ERROR; import static io.trino.plugin.jdbc.JdbcErrorCode.JDBC_NON_TRANSIENT_ERROR; import static io.trino.plugin.jdbc.StandardColumnMappings.bigintColumnMapping; import static io.trino.plugin.jdbc.StandardColumnMappings.bigintWriteFunction; @@ -229,6 +234,27 @@ public PreparedStatement getPreparedStatement(Connection connection, String sql) return statement; } + @Override + public OptionalLong delete(ConnectorSession session, JdbcTableHandle handle) + { + checkArgument(handle.isNamedRelation(), "Unable to delete from synthetic table: %s", handle); + checkArgument(handle.getLimit().isEmpty(), "Unable to delete when limit is set: %s", handle); + checkArgument(handle.getSortOrder().isEmpty(), "Unable to delete when sort order is set: %s", handle); + try (Connection connection = connectionFactory.openConnection(session)) { + verify(connection.getAutoCommit()); + PreparedQuery preparedQuery = queryBuilder.prepareDeleteQuery(this, session, connection, handle.getRequiredNamedRelation(), handle.getConstraint(), Optional.empty()); + try (PreparedStatement preparedStatement = queryBuilder.prepareStatement(this, session, connection, preparedQuery)) { + int affectedRowsCount = preparedStatement.executeUpdate(); + // connection.getAutoCommit() == true is not enough to make DELETE effective and explicit commit is required + connection.commit(); + return OptionalLong.of(affectedRowsCount); + } + } + catch (SQLException e) { + throw new TrinoException(JDBC_ERROR, e); + } + } + @Override protected void verifySchemaName(DatabaseMetaData databaseMetadata, String 
schemaName) throws SQLException diff --git a/plugin/trino-redshift/src/test/java/io/trino/plugin/redshift/TestRedshiftConnectorTest.java b/plugin/trino-redshift/src/test/java/io/trino/plugin/redshift/TestRedshiftConnectorTest.java index 0d46c93a6fc9..7e2b12fe85f1 100644 --- a/plugin/trino-redshift/src/test/java/io/trino/plugin/redshift/TestRedshiftConnectorTest.java +++ b/plugin/trino-redshift/src/test/java/io/trino/plugin/redshift/TestRedshiftConnectorTest.java @@ -54,7 +54,6 @@ protected QueryRunner createQueryRunner() protected boolean hasBehavior(TestingConnectorBehavior connectorBehavior) { switch (connectorBehavior) { - case SUPPORTS_DELETE: case SUPPORTS_AGGREGATION_PUSHDOWN: case SUPPORTS_JOIN_PUSHDOWN: case SUPPORTS_TOPN_PUSHDOWN: @@ -151,6 +150,33 @@ public Object[][] redshiftTypeToTrinoTypes() {"TIMESTAMPTZ", "timestamp(6) with time zone"}}; } + @Override + public void testDelete() + { + // The base tests is very slow because Redshift CTAS is really slow, so use a smaller test + try (TestTable table = new TestTable(getQueryRunner()::execute, "test_delete_", "AS SELECT * FROM nation")) { + // delete without matching any rows + assertUpdate("DELETE FROM " + table.getName() + " WHERE nationkey < 0", 0); + + // delete with a predicate that optimizes to false + assertUpdate("DELETE FROM " + table.getName() + " WHERE nationkey > 5 AND nationkey < 4", 0); + + // delete successive parts of the table + assertUpdate("DELETE FROM " + table.getName() + " WHERE nationkey <= 5", "SELECT count(*) FROM nation WHERE nationkey <= 5"); + assertQuery("SELECT * FROM " + table.getName(), "SELECT * FROM nation WHERE nationkey > 5"); + + assertUpdate("DELETE FROM " + table.getName() + " WHERE nationkey <= 10", "SELECT count(*) FROM nation WHERE nationkey > 5 AND nationkey <= 10"); + assertQuery("SELECT * FROM " + table.getName(), "SELECT * FROM nation WHERE nationkey > 10"); + + assertUpdate("DELETE FROM " + table.getName() + " WHERE nationkey <= 15", "SELECT count(*) FROM 
nation WHERE nationkey > 10 AND nationkey <= 15"); + assertQuery("SELECT * FROM " + table.getName(), "SELECT * FROM nation WHERE nationkey > 15"); + + // delete remaining + assertUpdate("DELETE FROM " + table.getName(), "SELECT count(*) FROM nation WHERE nationkey > 15"); + assertQuery("SELECT * FROM " + table.getName(), "SELECT * FROM nation WHERE false"); + } + } + @Override @Test public void testReadMetadataWithRelationsConcurrentModifications() @@ -212,6 +238,13 @@ protected SqlExecutor onRemoteDatabase() return RedshiftQueryRunner::executeInRedshift; } + @Override + public void testDeleteWithLike() + { + assertThatThrownBy(super::testDeleteWithLike) + .hasStackTraceContaining("TrinoException: This connector does not support modifying table rows"); + } + @Test @Override public void testAddNotNullColumnToNonEmptyTable() From e65585de567e7383a967bf3a86075c1508a9c8d6 Mon Sep 17 00:00:00 2001 From: Dain Sundstrom Date: Sat, 10 Dec 2022 18:43:41 -0800 Subject: [PATCH 05/24] Add Redshift statistics --- plugin/trino-redshift/pom.xml | 12 +- .../trino/plugin/redshift/RedshiftClient.java | 36 +- .../plugin/redshift/RedshiftClientModule.java | 3 + .../RedshiftTableStatisticsReader.java | 176 +++++++++ .../redshift/TestRedshiftConnectorTest.java | 73 ++++ .../TestRedshiftTableStatisticsReader.java | 349 ++++++++++++++++++ 6 files changed, 642 insertions(+), 7 deletions(-) create mode 100644 plugin/trino-redshift/src/main/java/io/trino/plugin/redshift/RedshiftTableStatisticsReader.java create mode 100644 plugin/trino-redshift/src/test/java/io/trino/plugin/redshift/TestRedshiftTableStatisticsReader.java diff --git a/plugin/trino-redshift/pom.xml b/plugin/trino-redshift/pom.xml index ebcd96279cda..2370ecf35c3e 100644 --- a/plugin/trino-redshift/pom.xml +++ b/plugin/trino-redshift/pom.xml @@ -49,6 +49,11 @@ javax.inject + + org.jdbi + jdbi3-core + + io.airlift @@ -68,12 +73,6 @@ runtime - - org.jdbi - jdbi3-core - runtime - - io.trino @@ -177,6 +176,7 @@ 
**/TestRedshiftConnectorTest.java + **/TestRedshiftTableStatisticsReader.java **/TestRedshiftTypeMapping.java diff --git a/plugin/trino-redshift/src/main/java/io/trino/plugin/redshift/RedshiftClient.java b/plugin/trino-redshift/src/main/java/io/trino/plugin/redshift/RedshiftClient.java index 1983260e43c9..2183e01bc4d7 100644 --- a/plugin/trino-redshift/src/main/java/io/trino/plugin/redshift/RedshiftClient.java +++ b/plugin/trino-redshift/src/main/java/io/trino/plugin/redshift/RedshiftClient.java @@ -22,6 +22,7 @@ import io.trino.plugin.jdbc.ColumnMapping; import io.trino.plugin.jdbc.ConnectionFactory; import io.trino.plugin.jdbc.JdbcColumnHandle; +import io.trino.plugin.jdbc.JdbcStatisticsConfig; import io.trino.plugin.jdbc.JdbcTableHandle; import io.trino.plugin.jdbc.JdbcTypeHandle; import io.trino.plugin.jdbc.LongWriteFunction; @@ -35,7 +36,10 @@ import io.trino.plugin.jdbc.logging.RemoteQueryModifier; import io.trino.plugin.jdbc.mapping.IdentifierMapping; import io.trino.spi.TrinoException; +import io.trino.spi.connector.ColumnHandle; import io.trino.spi.connector.ConnectorSession; +import io.trino.spi.predicate.TupleDomain; +import io.trino.spi.statistics.TableStatistics; import io.trino.spi.type.CharType; import io.trino.spi.type.Chars; import io.trino.spi.type.DecimalType; @@ -73,6 +77,7 @@ import java.util.function.BiFunction; import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.base.Throwables.throwIfInstanceOf; import static com.google.common.base.Verify.verify; import static io.trino.plugin.jdbc.JdbcErrorCode.JDBC_ERROR; import static io.trino.plugin.jdbc.JdbcErrorCode.JDBC_NON_TRANSIENT_ERROR; @@ -194,10 +199,21 @@ public class RedshiftClient .toFormatter(); private static final OffsetDateTime REDSHIFT_MIN_SUPPORTED_TIMESTAMP_TZ = OffsetDateTime.of(-4712, 1, 1, 0, 0, 0, 0, ZoneOffset.UTC); + private final boolean statisticsEnabled; + private final RedshiftTableStatisticsReader statisticsReader; + @Inject - 
public RedshiftClient(BaseJdbcConfig config, ConnectionFactory connectionFactory, QueryBuilder queryBuilder, IdentifierMapping identifierMapping, RemoteQueryModifier queryModifier) + public RedshiftClient( + BaseJdbcConfig config, + ConnectionFactory connectionFactory, + JdbcStatisticsConfig statisticsConfig, + QueryBuilder queryBuilder, + IdentifierMapping identifierMapping, + RemoteQueryModifier queryModifier) { super(config, "\"", connectionFactory, queryBuilder, identifierMapping, queryModifier); + this.statisticsEnabled = requireNonNull(statisticsConfig, "statisticsConfig is null").isEnabled(); + this.statisticsReader = new RedshiftTableStatisticsReader(connectionFactory); } @Override @@ -207,6 +223,24 @@ public Optional getTableComment(ResultSet resultSet) return Optional.empty(); } + @Override + public TableStatistics getTableStatistics(ConnectorSession session, JdbcTableHandle handle, TupleDomain tupleDomain) + { + if (!statisticsEnabled) { + return TableStatistics.empty(); + } + if (!handle.isNamedRelation()) { + return TableStatistics.empty(); + } + try { + return statisticsReader.readTableStatistics(session, handle, () -> this.getColumns(session, handle)); + } + catch (SQLException | RuntimeException e) { + throwIfInstanceOf(e, TrinoException.class); + throw new TrinoException(JDBC_ERROR, "Failed fetching statistics for table: " + handle, e); + } + } + @Override protected void renameTable(ConnectorSession session, Connection connection, String catalogName, String remoteSchemaName, String remoteTableName, String newRemoteSchemaName, String newRemoteTableName) throws SQLException diff --git a/plugin/trino-redshift/src/main/java/io/trino/plugin/redshift/RedshiftClientModule.java b/plugin/trino-redshift/src/main/java/io/trino/plugin/redshift/RedshiftClientModule.java index 13635c88f69b..ef4153ee45ef 100644 --- a/plugin/trino-redshift/src/main/java/io/trino/plugin/redshift/RedshiftClientModule.java +++ 
b/plugin/trino-redshift/src/main/java/io/trino/plugin/redshift/RedshiftClientModule.java @@ -24,6 +24,7 @@ import io.trino.plugin.jdbc.DriverConnectionFactory; import io.trino.plugin.jdbc.ForBaseJdbc; import io.trino.plugin.jdbc.JdbcClient; +import io.trino.plugin.jdbc.JdbcStatisticsConfig; import io.trino.plugin.jdbc.credential.CredentialProvider; import io.trino.plugin.jdbc.ptf.Query; import io.trino.spi.ptf.ConnectorTableFunction; @@ -32,6 +33,7 @@ import static com.google.inject.Scopes.SINGLETON; import static com.google.inject.multibindings.Multibinder.newSetBinder; +import static io.airlift.configuration.ConfigBinder.configBinder; public class RedshiftClientModule extends AbstractConfigurationAwareModule @@ -41,6 +43,7 @@ public void setup(Binder binder) { binder.bind(JdbcClient.class).annotatedWith(ForBaseJdbc.class).to(RedshiftClient.class).in(SINGLETON); newSetBinder(binder, ConnectorTableFunction.class).addBinding().toProvider(Query.class).in(SINGLETON); + configBinder(binder).bindConfig(JdbcStatisticsConfig.class); install(new DecimalModule()); } diff --git a/plugin/trino-redshift/src/main/java/io/trino/plugin/redshift/RedshiftTableStatisticsReader.java b/plugin/trino-redshift/src/main/java/io/trino/plugin/redshift/RedshiftTableStatisticsReader.java new file mode 100644 index 000000000000..c576abdd109d --- /dev/null +++ b/plugin/trino-redshift/src/main/java/io/trino/plugin/redshift/RedshiftTableStatisticsReader.java @@ -0,0 +1,176 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.plugin.redshift; + +import io.trino.plugin.jdbc.ConnectionFactory; +import io.trino.plugin.jdbc.JdbcColumnHandle; +import io.trino.plugin.jdbc.JdbcTableHandle; +import io.trino.plugin.jdbc.RemoteTableName; +import io.trino.spi.connector.ConnectorSession; +import io.trino.spi.statistics.ColumnStatistics; +import io.trino.spi.statistics.Estimate; +import io.trino.spi.statistics.TableStatistics; +import org.jdbi.v3.core.Handle; +import org.jdbi.v3.core.Jdbi; + +import java.sql.Connection; +import java.sql.SQLException; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.function.Supplier; + +import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.collect.ImmutableMap.toImmutableMap; +import static java.util.Objects.requireNonNull; +import static java.util.function.Function.identity; + +public class RedshiftTableStatisticsReader +{ + private final ConnectionFactory connectionFactory; + + public RedshiftTableStatisticsReader(ConnectionFactory connectionFactory) + { + this.connectionFactory = requireNonNull(connectionFactory, "connectionFactory is null"); + } + + public TableStatistics readTableStatistics(ConnectorSession session, JdbcTableHandle table, Supplier> columnSupplier) + throws SQLException + { + checkArgument(table.isNamedRelation(), "Relation is not a table: %s", table); + + try (Connection connection = connectionFactory.openConnection(session); + Handle handle = Jdbi.open(connection)) { + StatisticsDao statisticsDao = new StatisticsDao(handle); + + RemoteTableName remoteTableName = table.getRequiredNamedRelation().getRemoteTableName(); + Optional optionalRowCount = readRowCountTableStat(statisticsDao, table); + if (optionalRowCount.isEmpty()) { + // Table not found + return TableStatistics.empty(); + } + long rowCount = optionalRowCount.get(); + + 
TableStatistics.Builder tableStatistics = TableStatistics.builder() + .setRowCount(Estimate.of(rowCount)); + + if (rowCount == 0) { + return tableStatistics.build(); + } + + Map columnStatistics = statisticsDao.getColumnStatistics(remoteTableName.getSchemaName().orElse(null), remoteTableName.getTableName()).stream() + .collect(toImmutableMap(ColumnStatisticsResult::columnName, identity())); + + for (JdbcColumnHandle column : columnSupplier.get()) { + ColumnStatisticsResult result = columnStatistics.get(column.getColumnName()); + if (result == null) { + continue; + } + + ColumnStatistics statistics = ColumnStatistics.builder() + .setNullsFraction(result.nullsFraction() + .map(Estimate::of) + .orElseGet(Estimate::unknown)) + .setDistinctValuesCount(result.distinctValuesIndicator() + .map(distinctValuesIndicator -> { + // If the distinct value count is an estimate Redshift uses "the negative of the number of distinct values divided by the number of rows + // For example, -1 indicates a unique column in which the number of distinct values is the same as the number of rows." 
+ // https://www.postgresql.org/docs/9.3/view-pg-stats.html + if (distinctValuesIndicator < 0.0) { + return Math.min(-distinctValuesIndicator * rowCount, rowCount); + } + return distinctValuesIndicator; + }) + .map(Estimate::of) + .orElseGet(Estimate::unknown)) + .setDataSize(result.averageColumnLength() + .flatMap(averageColumnLength -> + result.nullsFraction() + .map(nullsFraction -> 1.0 * averageColumnLength * rowCount * (1 - nullsFraction)) + .map(Estimate::of)) + .orElseGet(Estimate::unknown)) + .build(); + + tableStatistics.setColumnStatistics(column, statistics); + } + + return tableStatistics.build(); + } + } + + private static Optional readRowCountTableStat(StatisticsDao statisticsDao, JdbcTableHandle table) + { + RemoteTableName remoteTableName = table.getRequiredNamedRelation().getRemoteTableName(); + Optional rowCount = statisticsDao.getRowCountFromPgClass(remoteTableName.getSchemaName().orElse(null), remoteTableName.getTableName()); + if (rowCount.isEmpty()) { + // Table not found + return Optional.empty(); + } + + if (rowCount.get() == 0) { + // `pg_class.reltuples = 0` may mean an empty table or a recently populated table (CTAS, LOAD or INSERT) + // The `pg_stat_all_tables` view can be way off, so we use it only as a fallback + rowCount = statisticsDao.getRowCountFromPgStat(remoteTableName.getSchemaName().orElse(null), remoteTableName.getTableName()); + } + + return rowCount; + } + + private static class StatisticsDao + { + private final Handle handle; + + public StatisticsDao(Handle handle) + { + this.handle = requireNonNull(handle, "handle is null"); + } + + Optional getRowCountFromPgClass(String schema, String tableName) + { + return handle.createQuery("SELECT reltuples FROM pg_class WHERE relnamespace = (SELECT oid FROM pg_namespace WHERE nspname = :schema) AND relname = :table_name") + .bind("schema", schema) + .bind("table_name", tableName) + .mapTo(Long.class) + .findOne(); + } + + Optional getRowCountFromPgStat(String schema, String 
tableName) + { + // Redshift does not have the Postgres `n_live_tup`, so estimate from `inserts - deletes` + return handle.createQuery("SELECT n_tup_ins - n_tup_del FROM pg_stat_all_tables WHERE schemaname = :schema AND relname = :table_name") + .bind("schema", schema) + .bind("table_name", tableName) + .mapTo(Long.class) + .findOne(); + } + + List getColumnStatistics(String schema, String tableName) + { + return handle.createQuery("SELECT attname, null_frac, n_distinct, avg_width FROM pg_stats WHERE schemaname = :schema AND tablename = :table_name") + .bind("schema", schema) + .bind("table_name", tableName) + .map((rs, ctx) -> + new ColumnStatisticsResult( + requireNonNull(rs.getString("attname"), "attname is null"), + Optional.of(rs.getFloat("null_frac")), + Optional.of(rs.getFloat("n_distinct")), + Optional.of(rs.getInt("avg_width")))) + .list(); + } + } + + // TODO remove when error prone is updated for Java 17 records + @SuppressWarnings("unused") + private record ColumnStatisticsResult(String columnName, Optional nullsFraction, Optional distinctValuesIndicator, Optional averageColumnLength) {} +} diff --git a/plugin/trino-redshift/src/test/java/io/trino/plugin/redshift/TestRedshiftConnectorTest.java b/plugin/trino-redshift/src/test/java/io/trino/plugin/redshift/TestRedshiftConnectorTest.java index 7e2b12fe85f1..863f308b9554 100644 --- a/plugin/trino-redshift/src/test/java/io/trino/plugin/redshift/TestRedshiftConnectorTest.java +++ b/plugin/trino-redshift/src/test/java/io/trino/plugin/redshift/TestRedshiftConnectorTest.java @@ -29,8 +29,10 @@ import static io.trino.plugin.redshift.RedshiftQueryRunner.TEST_SCHEMA; import static io.trino.plugin.redshift.RedshiftQueryRunner.createRedshiftQueryRunner; +import static io.trino.plugin.redshift.RedshiftQueryRunner.executeInRedshift; import static io.trino.testing.TestingNames.randomNameSuffix; import static java.lang.String.format; +import static java.util.Locale.ENGLISH; import static 
org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatThrownBy; @@ -177,6 +179,77 @@ public void testDelete() } } + @Test(dataProvider = "testCaseColumnNamesDataProvider") + public void testCaseColumnNames(String tableName) + { + try { + assertUpdate( + "CREATE TABLE " + TEST_SCHEMA + "." + tableName + + " AS SELECT " + + " custkey AS CASE_UNQUOTED_UPPER, " + + " name AS case_unquoted_lower, " + + " address AS cASe_uNQuoTeD_miXED, " + + " nationkey AS \"CASE_QUOTED_UPPER\", " + + " phone AS \"case_quoted_lower\"," + + " acctbal AS \"CasE_QuoTeD_miXED\" " + + "FROM customer", + 1500); + gatherStats(tableName); + assertQuery( + "SHOW STATS FOR " + TEST_SCHEMA + "." + tableName, + "VALUES " + + "('case_unquoted_upper', NULL, 1485, 0, null, null, null)," + + "('case_unquoted_lower', 33000, 1470, 0, null, null, null)," + + "('case_unquoted_mixed', 42000, 1500, 0, null, null, null)," + + "('case_quoted_upper', NULL, 25, 0, null, null, null)," + + "('case_quoted_lower', 28500, 1483, 0, null, null, null)," + + "('case_quoted_mixed', NULL, 1483, 0, null, null, null)," + + "(null, null, null, null, 1500, null, null)"); + } + finally { + assertUpdate("DROP TABLE IF EXISTS " + tableName); + } + } + + private static void gatherStats(String tableName) + { + executeInRedshift(handle -> { + handle.execute("ANALYZE VERBOSE " + TEST_SCHEMA + "." + tableName); + for (int i = 0; i < 5; i++) { + long actualCount = handle.createQuery("SELECT count(*) FROM " + TEST_SCHEMA + "." 
+ tableName) + .mapTo(Long.class) + .one(); + long estimatedCount = handle.createQuery(""" + SELECT reltuples FROM pg_class + WHERE relnamespace = (SELECT oid FROM pg_namespace WHERE nspname = :schema) + AND relname = :table_name + """) + .bind("schema", TEST_SCHEMA) + .bind("table_name", tableName.toLowerCase(ENGLISH).replace("\"", "")) + .mapTo(Long.class) + .one(); + if (actualCount == estimatedCount) { + return; + } + handle.execute("ANALYZE VERBOSE " + TEST_SCHEMA + "." + tableName); + } + throw new IllegalStateException("Stats not gathered"); // for small test tables reltuples should be exact + }); + } + + @DataProvider + public Object[][] testCaseColumnNamesDataProvider() + { + return new Object[][] { + {"TEST_STATS_MIXED_UNQUOTED_UPPER_" + randomNameSuffix()}, + {"test_stats_mixed_unquoted_lower_" + randomNameSuffix()}, + {"test_stats_mixed_uNQuoTeD_miXED_" + randomNameSuffix()}, + {"\"TEST_STATS_MIXED_QUOTED_UPPER_" + randomNameSuffix() + "\""}, + {"\"test_stats_mixed_quoted_lower_" + randomNameSuffix() + "\""}, + {"\"test_stats_mixed_QuoTeD_miXED_" + randomNameSuffix() + "\""} + }; + } + @Override @Test public void testReadMetadataWithRelationsConcurrentModifications() diff --git a/plugin/trino-redshift/src/test/java/io/trino/plugin/redshift/TestRedshiftTableStatisticsReader.java b/plugin/trino-redshift/src/test/java/io/trino/plugin/redshift/TestRedshiftTableStatisticsReader.java new file mode 100644 index 000000000000..ff713337ea53 --- /dev/null +++ b/plugin/trino-redshift/src/test/java/io/trino/plugin/redshift/TestRedshiftTableStatisticsReader.java @@ -0,0 +1,349 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.plugin.redshift; + +import com.amazon.redshift.Driver; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import io.trino.plugin.jdbc.BaseJdbcConfig; +import io.trino.plugin.jdbc.DriverConnectionFactory; +import io.trino.plugin.jdbc.JdbcColumnHandle; +import io.trino.plugin.jdbc.JdbcTableHandle; +import io.trino.plugin.jdbc.JdbcTypeHandle; +import io.trino.plugin.jdbc.RemoteTableName; +import io.trino.plugin.jdbc.credential.StaticCredentialProvider; +import io.trino.spi.connector.ColumnHandle; +import io.trino.spi.connector.SchemaTableName; +import io.trino.spi.statistics.ColumnStatistics; +import io.trino.spi.statistics.Estimate; +import io.trino.spi.statistics.TableStatistics; +import io.trino.spi.type.VarcharType; +import io.trino.testing.AbstractTestQueryFramework; +import io.trino.testing.QueryRunner; +import io.trino.testing.sql.TestTable; +import org.assertj.core.api.InstanceOfAssertFactories; +import org.assertj.core.api.SoftAssertions; +import org.testng.annotations.BeforeClass; +import org.testng.annotations.Test; + +import java.sql.Types; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.function.Consumer; + +import static io.trino.plugin.redshift.RedshiftQueryRunner.JDBC_PASSWORD; +import static io.trino.plugin.redshift.RedshiftQueryRunner.JDBC_URL; +import static io.trino.plugin.redshift.RedshiftQueryRunner.JDBC_USER; +import static io.trino.plugin.redshift.RedshiftQueryRunner.TEST_SCHEMA; +import static 
io.trino.plugin.redshift.RedshiftQueryRunner.createRedshiftQueryRunner; +import static io.trino.plugin.redshift.RedshiftQueryRunner.executeInRedshift; +import static io.trino.spi.type.BigintType.BIGINT; +import static io.trino.spi.type.DoubleType.DOUBLE; +import static io.trino.testing.TestingConnectorSession.SESSION; +import static io.trino.testing.TestingNames.randomNameSuffix; +import static io.trino.testing.sql.TestTable.fromColumns; +import static io.trino.tpch.TpchTable.CUSTOMER; +import static java.util.Collections.emptyMap; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.from; +import static org.assertj.core.api.Assertions.withinPercentage; +import static org.testng.Assert.assertEquals; +import static org.testng.Assert.assertNull; + +public class TestRedshiftTableStatisticsReader + extends AbstractTestQueryFramework +{ + private static final JdbcTypeHandle BIGINT_TYPE_HANDLE = new JdbcTypeHandle(Types.BIGINT, Optional.of("int8"), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty()); + private static final JdbcTypeHandle DOUBLE_TYPE_HANDLE = new JdbcTypeHandle(Types.DOUBLE, Optional.of("double"), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty()); + + private static final List CUSTOMER_COLUMNS = ImmutableList.of( + new JdbcColumnHandle("custkey", BIGINT_TYPE_HANDLE, BIGINT), + createVarcharJdbcColumnHandle("name", 25), + createVarcharJdbcColumnHandle("address", 48), + new JdbcColumnHandle("nationkey", BIGINT_TYPE_HANDLE, BIGINT), + createVarcharJdbcColumnHandle("phone", 15), + new JdbcColumnHandle("acctbal", DOUBLE_TYPE_HANDLE, DOUBLE), + createVarcharJdbcColumnHandle("mktsegment", 10), + createVarcharJdbcColumnHandle("comment", 117)); + + private RedshiftTableStatisticsReader statsReader; + + @BeforeClass + public void setup() + { + DriverConnectionFactory connectionFactory = new DriverConnectionFactory( + new Driver(), + new 
BaseJdbcConfig().setConnectionUrl(JDBC_URL), + new StaticCredentialProvider(Optional.of(JDBC_USER), Optional.of(JDBC_PASSWORD))); + statsReader = new RedshiftTableStatisticsReader(connectionFactory); + } + + @Override + protected QueryRunner createQueryRunner() + throws Exception + { + return createRedshiftQueryRunner(Map.of(), Map.of(), ImmutableList.of(CUSTOMER)); + } + + @Test + public void testCustomerTable() + throws Exception + { + assertThat(collectStats("SELECT * FROM " + TEST_SCHEMA + ".customer", CUSTOMER_COLUMNS)) + .returns(Estimate.of(1500), from(TableStatistics::getRowCount)) + .extracting(TableStatistics::getColumnStatistics, InstanceOfAssertFactories.map(ColumnHandle.class, ColumnStatistics.class)) + .hasEntrySatisfying(CUSTOMER_COLUMNS.get(0), statsCloseTo(1500.0, 0.0, 8.0 * 1500)) + .hasEntrySatisfying(CUSTOMER_COLUMNS.get(1), statsCloseTo(1500.0, 0.0, 33000.0)) + .hasEntrySatisfying(CUSTOMER_COLUMNS.get(3), statsCloseTo(25.000, 0.0, 8.0 * 1500)) + .hasEntrySatisfying(CUSTOMER_COLUMNS.get(5), statsCloseTo(1499.0, 0.0, 8.0 * 1500)); + } + + @Test + public void testEmptyTable() + throws Exception + { + TableStatistics tableStatistics = collectStats("SELECT * FROM " + TEST_SCHEMA + ".customer WHERE false", CUSTOMER_COLUMNS); + assertThat(tableStatistics) + .returns(Estimate.of(0.0), from(TableStatistics::getRowCount)) + .returns(emptyMap(), from(TableStatistics::getColumnStatistics)); + } + + @Test + public void testAllNulls() + throws Exception + { + String tableName = "testallnulls_" + randomNameSuffix(); + String schemaAndTable = TEST_SCHEMA + "." 
+ tableName; + try { + executeInRedshift("CREATE TABLE " + schemaAndTable + " (i BIGINT)"); + executeInRedshift("INSERT INTO " + schemaAndTable + " (i) VALUES (NULL)"); + executeInRedshift("ANALYZE VERBOSE " + schemaAndTable); + + TableStatistics stats = statsReader.readTableStatistics( + SESSION, + new JdbcTableHandle( + new SchemaTableName(TEST_SCHEMA, tableName), + new RemoteTableName(Optional.empty(), Optional.of(TEST_SCHEMA), tableName), + Optional.empty()), + () -> ImmutableList.of(new JdbcColumnHandle("i", BIGINT_TYPE_HANDLE, BIGINT))); + assertThat(stats) + .returns(Estimate.of(1.0), from(TableStatistics::getRowCount)) + .returns(emptyMap(), from(TableStatistics::getColumnStatistics)); + } + finally { + executeInRedshift("DROP TABLE IF EXISTS " + schemaAndTable); + } + } + + @Test + public void testNullsFraction() + throws Exception + { + JdbcColumnHandle custkeyColumnHandle = CUSTOMER_COLUMNS.get(0); + TableStatistics stats = collectStats( + "SELECT CASE custkey % 3 WHEN 0 THEN NULL ELSE custkey END FROM " + TEST_SCHEMA + ".customer", + ImmutableList.of(custkeyColumnHandle)); + assertEquals(stats.getRowCount(), Estimate.of(1500)); + + ColumnStatistics columnStatistics = stats.getColumnStatistics().get(custkeyColumnHandle); + assertThat(columnStatistics.getNullsFraction().getValue()).isCloseTo(1.0 / 3, withinPercentage(1)); + } + + @Test + public void testAverageColumnLength() + throws Exception + { + List columns = ImmutableList.of( + new JdbcColumnHandle("custkey", BIGINT_TYPE_HANDLE, BIGINT), + createVarcharJdbcColumnHandle("v3_in_3", 3), + createVarcharJdbcColumnHandle("v3_in_42", 42), + createVarcharJdbcColumnHandle("single_10v_value", 10), + createVarcharJdbcColumnHandle("half_10v_value", 10), + createVarcharJdbcColumnHandle("half_distinct_20v_value", 20), + createVarcharJdbcColumnHandle("all_nulls", 10)); + + assertThat( + collectStats( + "SELECT " + + " custkey, " + + " 'abc' v3_in_3, " + + " CAST('abc' AS varchar(42)) v3_in_42, " + + " CASE custkey 
WHEN 1 THEN '0123456789' ELSE NULL END single_10v_value, " + + " CASE custkey % 2 WHEN 0 THEN '0123456789' ELSE NULL END half_10v_value, " + + " CASE custkey % 2 WHEN 0 THEN CAST((1000000 - custkey) * (1000000 - custkey) AS varchar(20)) ELSE NULL END half_distinct_20v_value, " + // 12 chars each + " CAST(NULL AS varchar(10)) all_nulls " + + "FROM " + TEST_SCHEMA + ".customer " + + "ORDER BY custkey LIMIT 100", + columns)) + .returns(Estimate.of(100), from(TableStatistics::getRowCount)) + .extracting(TableStatistics::getColumnStatistics, InstanceOfAssertFactories.map(ColumnHandle.class, ColumnStatistics.class)) + .hasEntrySatisfying(columns.get(0), statsCloseTo(100.0, 0.0, 800)) + .hasEntrySatisfying(columns.get(1), statsCloseTo(1.0, 0.0, 700.0)) + .hasEntrySatisfying(columns.get(2), statsCloseTo(1.0, 0.0, 700)) + .hasEntrySatisfying(columns.get(3), statsCloseTo(1.0, 0.99, 14)) + .hasEntrySatisfying(columns.get(4), statsCloseTo(1.0, 0.5, 700)) + .hasEntrySatisfying(columns.get(5), statsCloseTo(51, 0.5, 800)) + .satisfies(stats -> assertNull(stats.get(columns.get(6)))); + } + + @Test + public void testView() + throws Exception + { + String tableName = "test_stats_view_" + randomNameSuffix(); + String schemaAndTable = TEST_SCHEMA + "." 
+ tableName; + List columns = ImmutableList.of( + new JdbcColumnHandle("custkey", BIGINT_TYPE_HANDLE, BIGINT), + createVarcharJdbcColumnHandle("mktsegment", 10), + createVarcharJdbcColumnHandle("comment", 117)); + + try { + executeInRedshift("CREATE OR REPLACE VIEW " + schemaAndTable + " AS SELECT custkey, mktsegment, comment FROM " + TEST_SCHEMA + ".customer"); + TableStatistics tableStatistics = statsReader.readTableStatistics( + SESSION, + new JdbcTableHandle( + new SchemaTableName(TEST_SCHEMA, tableName), + new RemoteTableName(Optional.empty(), Optional.of(TEST_SCHEMA), tableName), + Optional.empty()), + () -> columns); + assertThat(tableStatistics).isEqualTo(TableStatistics.empty()); + } + finally { + executeInRedshift("DROP VIEW IF EXISTS " + schemaAndTable); + } + } + + @Test + public void testMaterializedView() + throws Exception + { + String tableName = "test_stats_materialized_view_" + randomNameSuffix(); + String schemaAndTable = TEST_SCHEMA + "." + tableName; + List columns = ImmutableList.of( + new JdbcColumnHandle("custkey", BIGINT_TYPE_HANDLE, BIGINT), + createVarcharJdbcColumnHandle("mktsegment", 10), + createVarcharJdbcColumnHandle("comment", 117)); + + try { + executeInRedshift("CREATE MATERIALIZED VIEW " + schemaAndTable + + " AS SELECT custkey, mktsegment, comment FROM " + TEST_SCHEMA + ".customer"); + executeInRedshift("REFRESH MATERIALIZED VIEW " + schemaAndTable); + executeInRedshift("ANALYZE VERBOSE " + schemaAndTable); + TableStatistics tableStatistics = statsReader.readTableStatistics( + SESSION, + new JdbcTableHandle( + new SchemaTableName(TEST_SCHEMA, tableName), + new RemoteTableName(Optional.empty(), Optional.of(TEST_SCHEMA), tableName), + Optional.empty()), + () -> columns); + assertThat(tableStatistics).isEqualTo(TableStatistics.empty()); + } + finally { + executeInRedshift("DROP MATERIALIZED VIEW " + schemaAndTable); + } + } + + @Test + public void testNumericCornerCases() + { + try (TestTable table = fromColumns( + 
getQueryRunner()::execute, + "test_numeric_corner_cases_", + ImmutableMap.>builder() + .put("only_negative_infinity double", List.of("-infinity()", "-infinity()", "-infinity()", "-infinity()")) + .put("only_positive_infinity double", List.of("infinity()", "infinity()", "infinity()", "infinity()")) + .put("mixed_infinities double", List.of("-infinity()", "infinity()", "-infinity()", "infinity()")) + .put("mixed_infinities_and_numbers double", List.of("-infinity()", "infinity()", "-5.0", "7.0")) + .put("nans_only double", List.of("nan()", "nan()")) + .put("nans_and_numbers double", List.of("nan()", "nan()", "-5.0", "7.0")) + .put("large_doubles double", List.of("CAST(-50371909150609548946090.0 AS DOUBLE)", "CAST(50371909150609548946090.0 AS DOUBLE)")) // 2^77 DIV 3 + .put("short_decimals_big_fraction decimal(16,15)", List.of("-1.234567890123456", "1.234567890123456")) + .put("short_decimals_big_integral decimal(16,1)", List.of("-123456789012345.6", "123456789012345.6")) + .put("long_decimals_big_fraction decimal(38,37)", List.of("-1.2345678901234567890123456789012345678", "1.2345678901234567890123456789012345678")) + .put("long_decimals_middle decimal(38,16)", List.of("-1234567890123456.7890123456789012345678", "1234567890123456.7890123456789012345678")) + .put("long_decimals_big_integral decimal(38,1)", List.of("-1234567890123456789012345678901234567.8", "1234567890123456789012345678901234567.8")) + .buildOrThrow(), + "null")) { + executeInRedshift("ANALYZE VERBOSE " + TEST_SCHEMA + "." 
+ table.getName()); + assertQuery( + "SHOW STATS FOR " + table.getName(), + "VALUES " + + "('only_negative_infinity', null, 1, 0, null, null, null)," + + "('only_positive_infinity', null, 1, 0, null, null, null)," + + "('mixed_infinities', null, 2, 0, null, null, null)," + + "('mixed_infinities_and_numbers', null, 4.0, 0.0, null, null, null)," + + "('nans_only', null, 1.0, 0.5, null, null, null)," + + "('nans_and_numbers', null, 3.0, 0.0, null, null, null)," + + "('large_doubles', null, 2.0, 0.5, null, null, null)," + + "('short_decimals_big_fraction', null, 2.0, 0.5, null, null, null)," + + "('short_decimals_big_integral', null, 2.0, 0.5, null, null, null)," + + "('long_decimals_big_fraction', null, 2.0, 0.5, null, null, null)," + + "('long_decimals_middle', null, 2.0, 0.5, null, null, null)," + + "('long_decimals_big_integral', null, 2.0, 0.5, null, null, null)," + + "(null, null, null, null, 4, null, null)"); + } + } + + /** + * Assert that the given column is within 5% of each statistic in the parameters, and that it has no range + */ + private static Consumer statsCloseTo(double distinctValues, double nullsFraction, double dataSize) + { + return stats -> { + SoftAssertions softly = new SoftAssertions(); + + softly.assertThat(stats.getDistinctValuesCount().getValue()) + .isCloseTo(distinctValues, withinPercentage(5.0)); + + softly.assertThat(stats.getNullsFraction().getValue()) + .isCloseTo(nullsFraction, withinPercentage(5.0)); + + softly.assertThat(stats.getDataSize().getValue()) + .isCloseTo(dataSize, withinPercentage(5.0)); + + softly.assertThat(stats.getRange()).isEmpty(); + softly.assertAll(); + }; + } + + private TableStatistics collectStats(String values, List columnHandles) + throws Exception + { + String tableName = "testredshiftstatisticsreader_" + randomNameSuffix(); + String schemaAndTable = TEST_SCHEMA + "." 
+ tableName; + try { + executeInRedshift("CREATE TABLE " + schemaAndTable + " AS " + values); + executeInRedshift("ANALYZE VERBOSE " + schemaAndTable); + return statsReader.readTableStatistics( + SESSION, + new JdbcTableHandle( + new SchemaTableName(TEST_SCHEMA, tableName), + new RemoteTableName(Optional.empty(), Optional.of(TEST_SCHEMA), tableName), + Optional.empty()), + () -> columnHandles); + } + finally { + executeInRedshift("DROP TABLE IF EXISTS " + schemaAndTable); + } + } + + private static JdbcColumnHandle createVarcharJdbcColumnHandle(String name, int length) + { + return new JdbcColumnHandle( + name, + new JdbcTypeHandle(Types.VARCHAR, Optional.of("varchar"), Optional.of(length), Optional.empty(), Optional.empty(), Optional.empty()), + VarcharType.createVarcharType(length)); + } +} From c225b8f46851ac4ea5191783c320d8e30757ad67 Mon Sep 17 00:00:00 2001 From: Dain Sundstrom Date: Fri, 9 Dec 2022 16:04:47 -0800 Subject: [PATCH 06/24] Add Redshift pushdown --- plugin/trino-redshift/pom.xml | 11 + .../redshift/ImplementRedshiftAvgBigint.java | 26 ++ .../redshift/ImplementRedshiftAvgDecimal.java | 75 +++++ .../trino/plugin/redshift/RedshiftClient.java | 145 +++++++++ .../plugin/redshift/RedshiftClientModule.java | 2 + .../TestRedshiftAutomaticJoinPushdown.java | 72 +++++ .../redshift/TestRedshiftConnectorTest.java | 299 +++++++++++++++++- 7 files changed, 624 insertions(+), 6 deletions(-) create mode 100644 plugin/trino-redshift/src/main/java/io/trino/plugin/redshift/ImplementRedshiftAvgBigint.java create mode 100644 plugin/trino-redshift/src/main/java/io/trino/plugin/redshift/ImplementRedshiftAvgDecimal.java create mode 100644 plugin/trino-redshift/src/test/java/io/trino/plugin/redshift/TestRedshiftAutomaticJoinPushdown.java diff --git a/plugin/trino-redshift/pom.xml b/plugin/trino-redshift/pom.xml index 2370ecf35c3e..367ae423d4a4 100644 --- a/plugin/trino-redshift/pom.xml +++ b/plugin/trino-redshift/pom.xml @@ -23,6 +23,16 @@ trino-base-jdbc + + io.trino + 
trino-matching + + + + io.trino + trino-plugin-toolkit + + io.airlift configuration @@ -175,6 +185,7 @@ maven-surefire-plugin + **/TestRedshiftAutomaticJoinPushdown.java **/TestRedshiftConnectorTest.java **/TestRedshiftTableStatisticsReader.java **/TestRedshiftTypeMapping.java diff --git a/plugin/trino-redshift/src/main/java/io/trino/plugin/redshift/ImplementRedshiftAvgBigint.java b/plugin/trino-redshift/src/main/java/io/trino/plugin/redshift/ImplementRedshiftAvgBigint.java new file mode 100644 index 000000000000..f9c546105546 --- /dev/null +++ b/plugin/trino-redshift/src/main/java/io/trino/plugin/redshift/ImplementRedshiftAvgBigint.java @@ -0,0 +1,26 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.trino.plugin.redshift; + +import io.trino.plugin.jdbc.aggregation.BaseImplementAvgBigint; + +public class ImplementRedshiftAvgBigint + extends BaseImplementAvgBigint +{ + @Override + protected String getRewriteFormatExpression() + { + return "avg(CAST(%s AS double precision))"; + } +} diff --git a/plugin/trino-redshift/src/main/java/io/trino/plugin/redshift/ImplementRedshiftAvgDecimal.java b/plugin/trino-redshift/src/main/java/io/trino/plugin/redshift/ImplementRedshiftAvgDecimal.java new file mode 100644 index 000000000000..103258db12b7 --- /dev/null +++ b/plugin/trino-redshift/src/main/java/io/trino/plugin/redshift/ImplementRedshiftAvgDecimal.java @@ -0,0 +1,75 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.trino.plugin.redshift; + +import io.trino.matching.Capture; +import io.trino.matching.Captures; +import io.trino.matching.Pattern; +import io.trino.plugin.base.aggregation.AggregateFunctionRule; +import io.trino.plugin.jdbc.JdbcColumnHandle; +import io.trino.plugin.jdbc.JdbcExpression; +import io.trino.spi.connector.AggregateFunction; +import io.trino.spi.expression.Variable; +import io.trino.spi.type.DecimalType; + +import java.util.Optional; + +import static com.google.common.base.Verify.verify; +import static io.trino.matching.Capture.newCapture; +import static io.trino.plugin.base.aggregation.AggregateFunctionPatterns.basicAggregation; +import static io.trino.plugin.base.aggregation.AggregateFunctionPatterns.functionName; +import static io.trino.plugin.base.aggregation.AggregateFunctionPatterns.singleArgument; +import static io.trino.plugin.base.expression.ConnectorExpressionPatterns.type; +import static io.trino.plugin.base.expression.ConnectorExpressionPatterns.variable; +import static io.trino.plugin.redshift.RedshiftClient.REDSHIFT_MAX_DECIMAL_PRECISION; +import static java.lang.String.format; + +public class ImplementRedshiftAvgDecimal + implements AggregateFunctionRule +{ + private static final Capture INPUT = newCapture(); + + @Override + public Pattern getPattern() + { + return basicAggregation() + .with(functionName().equalTo("avg")) + .with(singleArgument().matching( + variable() + .with(type().matching(DecimalType.class::isInstance)) + .capturedAs(INPUT))); + } + + @Override + public Optional rewrite(AggregateFunction aggregateFunction, Captures captures, RewriteContext context) + { + Variable input = captures.get(INPUT); + JdbcColumnHandle columnHandle = (JdbcColumnHandle) context.getAssignment(input.getName()); + DecimalType type = (DecimalType) columnHandle.getColumnType(); + verify(aggregateFunction.getOutputType().equals(type)); + + // When decimal type has maximum precision we can get result that is not matching Presto avg 
semantics. + if (type.getPrecision() == REDSHIFT_MAX_DECIMAL_PRECISION) { + return Optional.of(new JdbcExpression( + format("avg(CAST(%s AS decimal(%s, %s)))", context.rewriteExpression(input).orElseThrow(), type.getPrecision(), type.getScale()), + columnHandle.getJdbcTypeHandle())); + } + + // Redshift avg function rounds down resulting decimal. + // To match Presto avg semantics, we extend scale by 1 and round result to target scale. + return Optional.of(new JdbcExpression( + format("round(avg(CAST(%s AS decimal(%s, %s))), %s)", context.rewriteExpression(input).orElseThrow(), type.getPrecision() + 1, type.getScale() + 1, type.getScale()), + columnHandle.getJdbcTypeHandle())); + } +} diff --git a/plugin/trino-redshift/src/main/java/io/trino/plugin/redshift/RedshiftClient.java b/plugin/trino-redshift/src/main/java/io/trino/plugin/redshift/RedshiftClient.java index 2183e01bc4d7..21ee2eabcf02 100644 --- a/plugin/trino-redshift/src/main/java/io/trino/plugin/redshift/RedshiftClient.java +++ b/plugin/trino-redshift/src/main/java/io/trino/plugin/redshift/RedshiftClient.java @@ -16,12 +16,20 @@ import com.amazon.redshift.jdbc.RedshiftPreparedStatement; import com.amazon.redshift.util.RedshiftObject; import com.google.common.base.CharMatcher; +import com.google.common.collect.ImmutableSet; import io.airlift.slice.Slice; +import io.trino.plugin.base.aggregation.AggregateFunctionRewriter; +import io.trino.plugin.base.aggregation.AggregateFunctionRule; +import io.trino.plugin.base.expression.ConnectorExpressionRewriter; import io.trino.plugin.jdbc.BaseJdbcClient; import io.trino.plugin.jdbc.BaseJdbcConfig; import io.trino.plugin.jdbc.ColumnMapping; import io.trino.plugin.jdbc.ConnectionFactory; import io.trino.plugin.jdbc.JdbcColumnHandle; +import io.trino.plugin.jdbc.JdbcExpression; +import io.trino.plugin.jdbc.JdbcJoinCondition; +import io.trino.plugin.jdbc.JdbcSortItem; +import io.trino.plugin.jdbc.JdbcSplit; import io.trino.plugin.jdbc.JdbcStatisticsConfig; import 
io.trino.plugin.jdbc.JdbcTableHandle; import io.trino.plugin.jdbc.JdbcTypeHandle; @@ -33,11 +41,26 @@ import io.trino.plugin.jdbc.SliceWriteFunction; import io.trino.plugin.jdbc.StandardColumnMappings; import io.trino.plugin.jdbc.WriteMapping; +import io.trino.plugin.jdbc.aggregation.ImplementAvgFloatingPoint; +import io.trino.plugin.jdbc.aggregation.ImplementCount; +import io.trino.plugin.jdbc.aggregation.ImplementCountAll; +import io.trino.plugin.jdbc.aggregation.ImplementCountDistinct; +import io.trino.plugin.jdbc.aggregation.ImplementMinMax; +import io.trino.plugin.jdbc.aggregation.ImplementStddevPop; +import io.trino.plugin.jdbc.aggregation.ImplementStddevSamp; +import io.trino.plugin.jdbc.aggregation.ImplementSum; +import io.trino.plugin.jdbc.aggregation.ImplementVariancePop; +import io.trino.plugin.jdbc.aggregation.ImplementVarianceSamp; +import io.trino.plugin.jdbc.expression.JdbcConnectorExpressionRewriterBuilder; import io.trino.plugin.jdbc.logging.RemoteQueryModifier; import io.trino.plugin.jdbc.mapping.IdentifierMapping; import io.trino.spi.TrinoException; +import io.trino.spi.connector.AggregateFunction; import io.trino.spi.connector.ColumnHandle; import io.trino.spi.connector.ConnectorSession; +import io.trino.spi.connector.JoinCondition; +import io.trino.spi.connector.JoinStatistics; +import io.trino.spi.connector.JoinType; import io.trino.spi.predicate.TupleDomain; import io.trino.spi.statistics.TableStatistics; import io.trino.spi.type.CharType; @@ -72,6 +95,8 @@ import java.time.ZoneOffset; import java.time.format.DateTimeFormatter; import java.time.format.DateTimeFormatterBuilder; +import java.util.List; +import java.util.Map; import java.util.Optional; import java.util.OptionalLong; import java.util.function.BiFunction; @@ -81,6 +106,7 @@ import static com.google.common.base.Verify.verify; import static io.trino.plugin.jdbc.JdbcErrorCode.JDBC_ERROR; import static io.trino.plugin.jdbc.JdbcErrorCode.JDBC_NON_TRANSIENT_ERROR; +import static 
io.trino.plugin.jdbc.JdbcJoinPushdownUtil.implementJoinCostAware; import static io.trino.plugin.jdbc.StandardColumnMappings.bigintColumnMapping; import static io.trino.plugin.jdbc.StandardColumnMappings.bigintWriteFunction; import static io.trino.plugin.jdbc.StandardColumnMappings.booleanColumnMapping; @@ -150,6 +176,7 @@ import static java.math.RoundingMode.UNNECESSARY; import static java.time.temporal.ChronoField.NANO_OF_SECOND; import static java.util.Objects.requireNonNull; +import static java.util.stream.Collectors.joining; public class RedshiftClient extends BaseJdbcClient @@ -199,6 +226,7 @@ public class RedshiftClient .toFormatter(); private static final OffsetDateTime REDSHIFT_MIN_SUPPORTED_TIMESTAMP_TZ = OffsetDateTime.of(-4712, 1, 1, 0, 0, 0, 0, ZoneOffset.UTC); + private final AggregateFunctionRewriter aggregateFunctionRewriter; private final boolean statisticsEnabled; private final RedshiftTableStatisticsReader statisticsReader; @@ -212,10 +240,64 @@ public RedshiftClient( RemoteQueryModifier queryModifier) { super(config, "\"", connectionFactory, queryBuilder, identifierMapping, queryModifier); + ConnectorExpressionRewriter connectorExpressionRewriter = JdbcConnectorExpressionRewriterBuilder.newBuilder() + .addStandardRules(this::quoted) + .build(); + + JdbcTypeHandle bigintTypeHandle = new JdbcTypeHandle(Types.BIGINT, Optional.of("bigint"), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty()); + + aggregateFunctionRewriter = new AggregateFunctionRewriter<>( + connectorExpressionRewriter, + ImmutableSet.>builder() + .add(new ImplementCountAll(bigintTypeHandle)) + .add(new ImplementCount(bigintTypeHandle)) + .add(new ImplementCountDistinct(bigintTypeHandle, true)) + .add(new ImplementMinMax(true)) + .add(new ImplementSum(RedshiftClient::toTypeHandle)) + .add(new ImplementAvgFloatingPoint()) + .add(new ImplementRedshiftAvgDecimal()) + .add(new ImplementRedshiftAvgBigint()) + .add(new ImplementStddevSamp()) + .add(new 
ImplementStddevPop()) + .add(new ImplementVarianceSamp()) + .add(new ImplementVariancePop()) + .build()); + this.statisticsEnabled = requireNonNull(statisticsConfig, "statisticsConfig is null").isEnabled(); this.statisticsReader = new RedshiftTableStatisticsReader(connectionFactory); } + private static Optional toTypeHandle(DecimalType decimalType) + { + return Optional.of( + new JdbcTypeHandle( + Types.NUMERIC, + Optional.of("decimal"), + Optional.of(decimalType.getPrecision()), + Optional.of(decimalType.getScale()), + Optional.empty(), + Optional.empty())); + } + + @Override + public Connection getConnection(ConnectorSession session, JdbcSplit split, JdbcTableHandle tableHandle) + throws SQLException + { + Connection connection = super.getConnection(session, split, tableHandle); + try { + // super.getConnection sets read-only, since the connection is going to be used only for reads. + // However, for a complex query, Redshift may decide to create some temporary tables behind + // the scenes, and this requires the connection not to be read-only, otherwise Redshift + // may fail with "ERROR: transaction is read-only". 
+ connection.setReadOnly(false); + } + catch (SQLException e) { + connection.close(); + throw e; + } + return connection; + } + @Override public Optional getTableComment(ResultSet resultSet) { @@ -223,6 +305,12 @@ public Optional getTableComment(ResultSet resultSet) return Optional.empty(); } + @Override + public Optional implementAggregation(ConnectorSession session, AggregateFunction aggregate, Map assignments) + { + return aggregateFunctionRewriter.rewrite(session, aggregate, assignments); + } + @Override public TableStatistics getTableStatistics(ConnectorSession session, JdbcTableHandle handle, TupleDomain tupleDomain) { @@ -241,6 +329,63 @@ public TableStatistics getTableStatistics(ConnectorSession session, JdbcTableHan } } + @Override + public boolean supportsTopN(ConnectorSession session, JdbcTableHandle handle, List sortOrder) + { + return true; + } + + @Override + protected Optional topNFunction() + { + return Optional.of((query, sortItems, limit) -> { + String orderBy = sortItems.stream() + .map(sortItem -> { + String ordering = sortItem.getSortOrder().isAscending() ? "ASC" : "DESC"; + String nullsHandling = sortItem.getSortOrder().isNullsFirst() ? 
"NULLS FIRST" : "NULLS LAST"; + return format("%s %s %s", quoted(sortItem.getColumn().getColumnName()), ordering, nullsHandling); + }) + .collect(joining(", ")); + + return format("%s ORDER BY %s LIMIT %d", query, orderBy, limit); + }); + } + + @Override + public boolean isTopNGuaranteed(ConnectorSession session) + { + return true; + } + + @Override + protected boolean isSupportedJoinCondition(ConnectorSession session, JdbcJoinCondition joinCondition) + { + return joinCondition.getOperator() != JoinCondition.Operator.IS_DISTINCT_FROM; + } + + @Override + public Optional implementJoin(ConnectorSession session, + JoinType joinType, + PreparedQuery leftSource, + PreparedQuery rightSource, + List joinConditions, + Map rightAssignments, + Map leftAssignments, + JoinStatistics statistics) + { + if (joinType == JoinType.FULL_OUTER) { + // FULL JOIN is only supported with merge-joinable or hash-joinable join conditions + return Optional.empty(); + } + return implementJoinCostAware( + session, + joinType, + leftSource, + rightSource, + statistics, + () -> super.implementJoin(session, joinType, leftSource, rightSource, joinConditions, rightAssignments, leftAssignments, statistics)); + } + @Override protected void renameTable(ConnectorSession session, Connection connection, String catalogName, String remoteSchemaName, String remoteTableName, String newRemoteSchemaName, String newRemoteTableName) throws SQLException diff --git a/plugin/trino-redshift/src/main/java/io/trino/plugin/redshift/RedshiftClientModule.java b/plugin/trino-redshift/src/main/java/io/trino/plugin/redshift/RedshiftClientModule.java index ef4153ee45ef..aeffaac16ff7 100644 --- a/plugin/trino-redshift/src/main/java/io/trino/plugin/redshift/RedshiftClientModule.java +++ b/plugin/trino-redshift/src/main/java/io/trino/plugin/redshift/RedshiftClientModule.java @@ -24,6 +24,7 @@ import io.trino.plugin.jdbc.DriverConnectionFactory; import io.trino.plugin.jdbc.ForBaseJdbc; import io.trino.plugin.jdbc.JdbcClient; 
+import io.trino.plugin.jdbc.JdbcJoinPushdownSupportModule; import io.trino.plugin.jdbc.JdbcStatisticsConfig; import io.trino.plugin.jdbc.credential.CredentialProvider; import io.trino.plugin.jdbc.ptf.Query; @@ -46,6 +47,7 @@ public void setup(Binder binder) configBinder(binder).bindConfig(JdbcStatisticsConfig.class); install(new DecimalModule()); + install(new JdbcJoinPushdownSupportModule()); } @Singleton diff --git a/plugin/trino-redshift/src/test/java/io/trino/plugin/redshift/TestRedshiftAutomaticJoinPushdown.java b/plugin/trino-redshift/src/test/java/io/trino/plugin/redshift/TestRedshiftAutomaticJoinPushdown.java new file mode 100644 index 000000000000..3509f8dd8b9c --- /dev/null +++ b/plugin/trino-redshift/src/test/java/io/trino/plugin/redshift/TestRedshiftAutomaticJoinPushdown.java @@ -0,0 +1,72 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.trino.plugin.redshift; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import io.trino.plugin.jdbc.BaseAutomaticJoinPushdownTest; +import io.trino.testing.QueryRunner; +import org.testng.SkipException; + +import static io.trino.plugin.redshift.RedshiftQueryRunner.TEST_SCHEMA; +import static io.trino.plugin.redshift.RedshiftQueryRunner.createRedshiftQueryRunner; +import static io.trino.plugin.redshift.RedshiftQueryRunner.executeInRedshift; +import static java.lang.String.format; +import static java.util.Locale.ENGLISH; + +public class TestRedshiftAutomaticJoinPushdown + extends BaseAutomaticJoinPushdownTest +{ + @Override + protected QueryRunner createQueryRunner() + throws Exception + { + return createRedshiftQueryRunner( + ImmutableMap.of(), + ImmutableMap.of(), + ImmutableList.of()); + } + + @Override + public void testJoinPushdownWithEmptyStatsInitially() + { + throw new SkipException("Redshift table statistics are automatically populated"); + } + + @Override + protected void gatherStats(String tableName) + { + executeInRedshift(handle -> { + handle.execute(format("ANALYZE VERBOSE %s.%s", TEST_SCHEMA, tableName)); + for (int i = 0; i < 5; i++) { + long actualCount = handle.createQuery(format("SELECT count(*) FROM %s.%s", TEST_SCHEMA, tableName)) + .mapTo(Long.class) + .one(); + long estimatedCount = handle.createQuery( + "SELECT reltuples FROM pg_class " + + "WHERE relnamespace = (SELECT oid FROM pg_namespace WHERE nspname = :schema) " + + "AND relname = :table_name") + .bind("schema", TEST_SCHEMA) + .bind("table_name", tableName.toLowerCase(ENGLISH).replace("\"", "")) + .mapTo(Long.class) + .one(); + if (actualCount == estimatedCount) { + return; + } + handle.execute(format("ANALYZE VERBOSE %s.%s", TEST_SCHEMA, tableName)); + } + throw new IllegalStateException("Stats not gathered"); + }); + } +} diff --git 
a/plugin/trino-redshift/src/test/java/io/trino/plugin/redshift/TestRedshiftConnectorTest.java b/plugin/trino-redshift/src/test/java/io/trino/plugin/redshift/TestRedshiftConnectorTest.java index 863f308b9554..1b16e335bd82 100644 --- a/plugin/trino-redshift/src/test/java/io/trino/plugin/redshift/TestRedshiftConnectorTest.java +++ b/plugin/trino-redshift/src/test/java/io/trino/plugin/redshift/TestRedshiftConnectorTest.java @@ -13,7 +13,9 @@ */ package io.trino.plugin.redshift; +import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; +import io.trino.Session; import io.trino.plugin.jdbc.BaseJdbcConnectorTest; import io.trino.testing.QueryRunner; import io.trino.testing.TestingConnectorBehavior; @@ -24,12 +26,18 @@ import org.testng.annotations.DataProvider; import org.testng.annotations.Test; +import java.util.List; import java.util.Optional; import java.util.OptionalInt; +import java.util.stream.Collectors; +import java.util.stream.IntStream; +import java.util.stream.Stream; +import static com.google.common.collect.ImmutableList.toImmutableList; import static io.trino.plugin.redshift.RedshiftQueryRunner.TEST_SCHEMA; import static io.trino.plugin.redshift.RedshiftQueryRunner.createRedshiftQueryRunner; import static io.trino.plugin.redshift.RedshiftQueryRunner.executeInRedshift; +import static io.trino.plugin.redshift.RedshiftQueryRunner.executeWithRedshift; import static io.trino.testing.TestingNames.randomNameSuffix; import static java.lang.String.format; import static java.util.Locale.ENGLISH; @@ -56,12 +64,6 @@ protected QueryRunner createQueryRunner() protected boolean hasBehavior(TestingConnectorBehavior connectorBehavior) { switch (connectorBehavior) { - case SUPPORTS_AGGREGATION_PUSHDOWN: - case SUPPORTS_JOIN_PUSHDOWN: - case SUPPORTS_TOPN_PUSHDOWN: - case SUPPORTS_PREDICATE_PUSHDOWN_WITH_VARCHAR_INEQUALITY: - return false; - case SUPPORTS_COMMENT_ON_TABLE: case SUPPORTS_ADD_COLUMN_WITH_COMMENT: case 
SUPPORTS_CREATE_TABLE_WITH_TABLE_COMMENT: @@ -75,6 +77,18 @@ protected boolean hasBehavior(TestingConnectorBehavior connectorBehavior) case SUPPORTS_RENAME_TABLE_ACROSS_SCHEMAS: return false; + case SUPPORTS_AGGREGATION_PUSHDOWN_STDDEV: + case SUPPORTS_AGGREGATION_PUSHDOWN_VARIANCE: + case SUPPORTS_AGGREGATION_PUSHDOWN_COUNT_DISTINCT: + return true; + + case SUPPORTS_JOIN_PUSHDOWN: + case SUPPORTS_JOIN_PUSHDOWN_WITH_VARCHAR_EQUALITY: + return true; + case SUPPORTS_JOIN_PUSHDOWN_WITH_DISTINCT_FROM: + case SUPPORTS_JOIN_PUSHDOWN_WITH_FULL_JOIN: + return false; + default: return super.hasBehavior(connectorBehavior); } @@ -211,6 +225,35 @@ public void testCaseColumnNames(String tableName) } } + /** + * Tries to create situation where Redshift would decide to materialize a temporary table for query sent to it by us. + * Such temporary table requires that our Connection is not read-only. + */ + @Test + public void testComplexPushdownThatMayElicitTemporaryTable() + { + int subqueries = 10; + String subquery = "SELECT custkey, count(*) c FROM orders GROUP BY custkey"; + StringBuilder sql = new StringBuilder(); + sql.append(format( + "SELECT t0.custkey, %s c_sum ", + IntStream.range(0, subqueries) + .mapToObj(i -> format("t%s.c", i)) + .collect(Collectors.joining("+")))); + sql.append(format("FROM (%s) t0 ", subquery)); + for (int i = 1; i < subqueries; i++) { + sql.append(format("JOIN (%s) t%s ON t0.custkey = t%s.custkey ", subquery, i, i)); + } + sql.append("WHERE t0.custkey = 1045 OR rand() = 42"); + + Session forceJoinPushdown = Session.builder(getSession()) + .setCatalogSessionProperty(getSession().getCatalog().orElseThrow(), "join_pushdown_strategy", "EAGER") + .build(); + + assertThat(query(forceJoinPushdown, sql.toString())) + .matches(format("SELECT max(custkey), count(*) * %s FROM tpch.tiny.orders WHERE custkey = 1045", subqueries)); + } + private static void gatherStats(String tableName) { executeInRedshift(handle -> { @@ -250,6 +293,241 @@ public Object[][] 
testCaseColumnNamesDataProvider() }; } + @Override + public void testCountDistinctWithStringTypes() + { + // cannot test using generic method as Redshift does not allow non-ASCII characters in CHAR values. + assertThatThrownBy(super::testCountDistinctWithStringTypes).hasMessageContaining("Value for Redshift CHAR must be ASCII, but found 'ą'"); + + List rows = Stream.of("a", "b", "A", "B", " a ", "a", "b", " b ") + .map(value -> format("'%1$s', '%1$s'", value)) + .collect(toImmutableList()); + String tableName = "distinct_strings" + randomNameSuffix(); + + try (TestTable testTable = new TestTable(getQueryRunner()::execute, tableName, "(t_char CHAR(5), t_varchar VARCHAR(5))", rows)) { + // Single count(DISTINCT ...) can be pushed down even if SUPPORTS_AGGREGATION_PUSHDOWN_COUNT_DISTINCT == false as GROUP BY + assertThat(query("SELECT count(DISTINCT t_varchar) FROM " + testTable.getName())) + .matches("VALUES BIGINT '6'") + .isFullyPushedDown(); + + // Single count(DISTINCT ...) can be pushed down even if SUPPORTS_AGGREGATION_PUSHDOWN_COUNT_DISTINCT == false as GROUP BY + assertThat(query("SELECT count(DISTINCT t_char) FROM " + testTable.getName())) + .matches("VALUES BIGINT '6'") + .isFullyPushedDown(); + + assertThat(query("SELECT count(DISTINCT t_char), count(DISTINCT t_varchar) FROM " + testTable.getName())) + .matches("VALUES (BIGINT '6', BIGINT '6')") + .isFullyPushedDown(); + } + } + + @Override + public void testAggregationPushdown() + { + throw new SkipException("tested in testAggregationPushdown(String)"); + } + + @Test(dataProvider = "testAggregationPushdownDistStylesDataProvider") + public void testAggregationPushdown(String distStyle) + { + String nation = format("%s.nation_%s_%s", TEST_SCHEMA, distStyle, randomNameSuffix()); + String customer = format("%s.customer_%s_%s", TEST_SCHEMA, distStyle, randomNameSuffix()); + try { + copyWithDistStyle(TEST_SCHEMA + ".nation", nation, distStyle, Optional.of("regionkey")); + copyWithDistStyle(TEST_SCHEMA + 
".customer", customer, distStyle, Optional.of("nationkey")); + + // TODO support aggregation pushdown with GROUPING SETS + // TODO support aggregation over expressions + + // count() + assertThat(query("SELECT count(*) FROM " + nation)).isFullyPushedDown(); + assertThat(query("SELECT count(nationkey) FROM " + nation)).isFullyPushedDown(); + assertThat(query("SELECT count(1) FROM " + nation)).isFullyPushedDown(); + assertThat(query("SELECT count() FROM " + nation)).isFullyPushedDown(); + assertThat(query("SELECT regionkey, count(1) FROM " + nation + " GROUP BY regionkey")).isFullyPushedDown(); + try (TestTable emptyTable = createAggregationTestTable(getSession().getSchema().orElseThrow() + ".empty_table", ImmutableList.of())) { + String emptyTableName = emptyTable.getName() + "_" + distStyle; + copyWithDistStyle(emptyTable.getName(), emptyTableName, distStyle, Optional.of("a_bigint")); + + assertThat(query("SELECT count(*) FROM " + emptyTableName)).isFullyPushedDown(); + assertThat(query("SELECT count(a_bigint) FROM " + emptyTableName)).isFullyPushedDown(); + assertThat(query("SELECT count(1) FROM " + emptyTableName)).isFullyPushedDown(); + assertThat(query("SELECT count() FROM " + emptyTableName)).isFullyPushedDown(); + assertThat(query("SELECT a_bigint, count(1) FROM " + emptyTableName + " GROUP BY a_bigint")).isFullyPushedDown(); + } + + // GROUP BY + assertThat(query("SELECT regionkey, min(nationkey) FROM " + nation + " GROUP BY regionkey")).isFullyPushedDown(); + assertThat(query("SELECT regionkey, max(nationkey) FROM " + nation + " GROUP BY regionkey")).isFullyPushedDown(); + assertThat(query("SELECT regionkey, sum(nationkey) FROM " + nation + " GROUP BY regionkey")).isFullyPushedDown(); + assertThat(query("SELECT regionkey, avg(nationkey) FROM " + nation + " GROUP BY regionkey")).isFullyPushedDown(); + try (TestTable emptyTable = createAggregationTestTable(getSession().getSchema().orElseThrow() + ".empty_table", ImmutableList.of())) { + String emptyTableName 
= emptyTable.getName() + "_" + distStyle; + copyWithDistStyle(emptyTable.getName(), emptyTableName, distStyle, Optional.of("a_bigint")); + + assertThat(query("SELECT t_double, min(a_bigint) FROM " + emptyTableName + " GROUP BY t_double")).isFullyPushedDown(); + assertThat(query("SELECT t_double, max(a_bigint) FROM " + emptyTableName + " GROUP BY t_double")).isFullyPushedDown(); + assertThat(query("SELECT t_double, sum(a_bigint) FROM " + emptyTableName + " GROUP BY t_double")).isFullyPushedDown(); + assertThat(query("SELECT t_double, avg(a_bigint) FROM " + emptyTableName + " GROUP BY t_double")).isFullyPushedDown(); + } + + // GROUP BY and WHERE on bigint column + // GROUP BY and WHERE on aggregation key + assertThat(query("SELECT regionkey, sum(nationkey) FROM " + nation + " WHERE regionkey < 4 GROUP BY regionkey")).isFullyPushedDown(); + + // GROUP BY and WHERE on varchar column + // GROUP BY and WHERE on "other" (not aggregation key, not aggregation input) + assertThat(query("SELECT regionkey, sum(nationkey) FROM " + nation + " WHERE regionkey < 4 AND name > 'AAA' GROUP BY regionkey")).isFullyPushedDown(); + // GROUP BY above WHERE and LIMIT + assertThat(query("SELECT regionkey, sum(nationkey) FROM (SELECT * FROM " + nation + " WHERE regionkey < 2 LIMIT 11) GROUP BY regionkey")).isFullyPushedDown(); + // GROUP BY above TopN + assertThat(query("SELECT regionkey, sum(nationkey) FROM (SELECT regionkey, nationkey FROM " + nation + " ORDER BY nationkey ASC LIMIT 10) GROUP BY regionkey")).isFullyPushedDown(); + // GROUP BY with JOIN + assertThat(query( + joinPushdownEnabled(getSession()), + "SELECT n.regionkey, sum(c.acctbal) acctbals FROM " + nation + " n LEFT JOIN " + customer + " c USING (nationkey) GROUP BY 1")) + .isFullyPushedDown(); + // GROUP BY with WHERE on neither grouping nor aggregation column + assertThat(query("SELECT nationkey, min(regionkey) FROM " + nation + " WHERE name = 'ARGENTINA' GROUP BY nationkey")).isFullyPushedDown(); + // aggregation on 
varchar column + assertThat(query("SELECT count(name) FROM " + nation)).isFullyPushedDown(); + // aggregation on varchar column with GROUPING + assertThat(query("SELECT nationkey, count(name) FROM " + nation + " GROUP BY nationkey")).isFullyPushedDown(); + // aggregation on varchar column with WHERE + assertThat(query("SELECT count(name) FROM " + nation + " WHERE name = 'ARGENTINA'")).isFullyPushedDown(); + } + finally { + executeInRedshift("DROP TABLE IF EXISTS " + nation); + executeInRedshift("DROP TABLE IF EXISTS " + customer); + } + } + + @Override + public void testNumericAggregationPushdown() + { + throw new SkipException("tested in testNumericAggregationPushdown(String)"); + } + + @Test(dataProvider = "testAggregationPushdownDistStylesDataProvider") + public void testNumericAggregationPushdown(String distStyle) + { + String schemaName = getSession().getSchema().orElseThrow(); + // empty table + try (TestTable emptyTable = createAggregationTestTable(schemaName + ".test_aggregation_pushdown", ImmutableList.of())) { + String emptyTableName = emptyTable.getName() + "_" + distStyle; + copyWithDistStyle(emptyTable.getName(), emptyTableName, distStyle, Optional.of("a_bigint")); + + assertThat(query("SELECT min(short_decimal), min(long_decimal), min(a_bigint), min(t_double) FROM " + emptyTableName)).isFullyPushedDown(); + assertThat(query("SELECT max(short_decimal), max(long_decimal), max(a_bigint), max(t_double) FROM " + emptyTableName)).isFullyPushedDown(); + assertThat(query("SELECT sum(short_decimal), sum(long_decimal), sum(a_bigint), sum(t_double) FROM " + emptyTableName)).isFullyPushedDown(); + assertThat(query("SELECT avg(short_decimal), avg(long_decimal), avg(a_bigint), avg(t_double) FROM " + emptyTableName)).isFullyPushedDown(); + } + + try (TestTable testTable = createAggregationTestTable(schemaName + ".test_aggregation_pushdown", + ImmutableList.of("100.000, 100000000.000000000, 100.000, 100000000", "123.321, 123456789.987654321, 123.321, 123456789"))) { 
+ String testTableName = testTable.getName() + "_" + distStyle; + copyWithDistStyle(testTable.getName(), testTableName, distStyle, Optional.of("a_bigint")); + + assertThat(query("SELECT min(short_decimal), min(long_decimal), min(a_bigint), min(t_double) FROM " + testTableName)).isFullyPushedDown(); + assertThat(query("SELECT max(short_decimal), max(long_decimal), max(a_bigint), max(t_double) FROM " + testTableName)).isFullyPushedDown(); + assertThat(query("SELECT sum(short_decimal), sum(long_decimal), sum(a_bigint), sum(t_double) FROM " + testTableName)).isFullyPushedDown(); + assertThat(query("SELECT avg(short_decimal), avg(long_decimal), avg(a_bigint), avg(t_double) FROM " + testTableName)).isFullyPushedDown(); + + // smoke testing of more complex cases + // WHERE on aggregation column + assertThat(query("SELECT min(short_decimal), min(long_decimal) FROM " + testTableName + " WHERE short_decimal < 110 AND long_decimal < 124")).isFullyPushedDown(); + // WHERE on non-aggregation column + assertThat(query("SELECT min(long_decimal) FROM " + testTableName + " WHERE short_decimal < 110")).isFullyPushedDown(); + // GROUP BY + assertThat(query("SELECT short_decimal, min(long_decimal) FROM " + testTableName + " GROUP BY short_decimal")).isFullyPushedDown(); + // GROUP BY with WHERE on both grouping and aggregation column + assertThat(query("SELECT short_decimal, min(long_decimal) FROM " + testTableName + " WHERE short_decimal < 110 AND long_decimal < 124 GROUP BY short_decimal")).isFullyPushedDown(); + // GROUP BY with WHERE on grouping column + assertThat(query("SELECT short_decimal, min(long_decimal) FROM " + testTableName + " WHERE short_decimal < 110 GROUP BY short_decimal")).isFullyPushedDown(); + // GROUP BY with WHERE on aggregation column + assertThat(query("SELECT short_decimal, min(long_decimal) FROM " + testTableName + " WHERE long_decimal < 124 GROUP BY short_decimal")).isFullyPushedDown(); + } + } + + private static void copyWithDistStyle(String 
sourceTableName, String destTableName, String distStyle, Optional distKey) + { + if (distStyle.equals("AUTO")) { + // NOTE: Redshift doesn't support setting diststyle AUTO in CTAS statements + executeInRedshift("CREATE TABLE " + destTableName + " AS SELECT * FROM " + sourceTableName); + // Redshift doesn't allow ALTER DISTSTYLE if original and new style are same, so we need to check current diststyle of table + boolean isDistStyleAuto = executeWithRedshift(handle -> { + Optional currentDistStyle = handle.createQuery("" + + "SELECT releffectivediststyle " + + "FROM pg_class_info AS a LEFT JOIN pg_namespace AS b ON a.relnamespace = b.oid " + + "WHERE lower(nspname) = lower(:schema_name) AND lower(relname) = lower(:table_name)") + .bind("schema_name", TEST_SCHEMA) + // destTableName = TEST_SCHEMA + "." + tableName + .bind("table_name", destTableName.substring(destTableName.indexOf(".") + 1)) + .mapTo(Long.class) + .findOne(); + + // 10 means AUTO(ALL) and 11 means AUTO(EVEN). See https://docs.aws.amazon.com/redshift/latest/dg/r_PG_CLASS_INFO.html. 
+ return currentDistStyle.isPresent() && (currentDistStyle.get() == 10 || currentDistStyle.get() == 11); + }); + if (!isDistStyleAuto) { + executeInRedshift("ALTER TABLE " + destTableName + " ALTER DISTSTYLE " + distStyle); + } + } + else { + String copyWithDistStyleSql = "CREATE TABLE " + destTableName + " DISTSTYLE " + distStyle; + if (distStyle.equals("KEY")) { + copyWithDistStyleSql += format(" DISTKEY(%s)", distKey.orElseThrow()); + } + copyWithDistStyleSql += " AS SELECT * FROM " + sourceTableName; + executeInRedshift(copyWithDistStyleSql); + } + } + + @DataProvider + public Object[][] testAggregationPushdownDistStylesDataProvider() + { + return new Object[][] { + {"EVEN"}, + {"KEY"}, + {"ALL"}, + {"AUTO"}, + }; + } + + @Test + public void testDecimalAvgPushdownForMaximumDecimalScale() + { + List rows = ImmutableList.of( + "12345789.9876543210", + format("%s.%s", "1".repeat(28), "9".repeat(10))); + + try (TestTable testTable = new TestTable(getQueryRunner()::execute, TEST_SCHEMA + ".test_agg_pushdown_avg_max_decimal", + "(t_decimal DECIMAL(38, 10))", rows)) { + // Redshift avg rounds down decimal result which doesn't match Presto semantics + assertThatThrownBy(() -> assertThat(query("SELECT avg(t_decimal) FROM " + testTable.getName())).isFullyPushedDown()) + .isInstanceOf(AssertionError.class) + .hasMessageContaining(""" + elements not found: + <(555555555555555555561728450.9938271605)> + and elements not expected: + <(555555555555555555561728450.9938271604)> + """); + } + } + + @Test + public void testDecimalAvgPushdownFoShortDecimalScale() + { + List rows = ImmutableList.of( + "0.987654321234567890", + format("0.%s", "1".repeat(18))); + + try (TestTable testTable = new TestTable(getQueryRunner()::execute, TEST_SCHEMA + ".test_agg_pushdown_avg_max_decimal", + "(t_decimal DECIMAL(18, 18))", rows)) { + assertThat(query("SELECT avg(t_decimal) FROM " + testTable.getName())).isFullyPushedDown(); + } + } + @Override @Test public void 
testReadMetadataWithRelationsConcurrentModifications() @@ -263,6 +541,15 @@ public void testInsertRowConcurrently() throw new SkipException("Test fails with a timeout sometimes and is flaky"); } + @Override + protected Session joinPushdownEnabled(Session session) + { + return Session.builder(super.joinPushdownEnabled(session)) + // strategy is AUTOMATIC by default and would not work for certain test cases (even if statistics are collected) + .setCatalogSessionProperty(session.getCatalog().orElseThrow(), "join_pushdown_strategy", "EAGER") + .build(); + } + @Override protected String errorMessageForInsertIntoNotNullColumn(String columnName) { From be62233a6035925ff77f1a4b767837b5f8956020 Mon Sep 17 00:00:00 2001 From: Mateusz Gajewski Date: Wed, 7 Dec 2022 12:56:35 +0100 Subject: [PATCH 07/24] Test SET PATH support by clients --- .../io/trino/client/StatementClientFactory.java | 10 +++++++++- .../java/io/trino/client/StatementClientV1.java | 8 ++++++-- .../src/main/java/io/trino/Session.java | 1 + .../java/io/trino/testing/TestingSession.java | 6 ++++++ .../testing/AbstractTestingTrinoClient.java | 2 +- .../src/test/java/io/trino/tests/TestServer.java | 16 ++++++++++++++++ 6 files changed, 39 insertions(+), 4 deletions(-) diff --git a/client/trino-client/src/main/java/io/trino/client/StatementClientFactory.java b/client/trino-client/src/main/java/io/trino/client/StatementClientFactory.java index 4aa2eda71495..6e5004994c95 100644 --- a/client/trino-client/src/main/java/io/trino/client/StatementClientFactory.java +++ b/client/trino-client/src/main/java/io/trino/client/StatementClientFactory.java @@ -15,12 +15,20 @@ import okhttp3.OkHttpClient; +import java.util.Optional; +import java.util.Set; + public final class StatementClientFactory { private StatementClientFactory() {} public static StatementClient newStatementClient(OkHttpClient httpClient, ClientSession session, String query) { - return new StatementClientV1(httpClient, session, query); + return new 
StatementClientV1(httpClient, session, query, Optional.empty()); + } + + public static StatementClient newStatementClient(OkHttpClient httpClient, ClientSession session, String query, Optional> clientCapabilities) + { + return new StatementClientV1(httpClient, session, query, clientCapabilities); } } diff --git a/client/trino-client/src/main/java/io/trino/client/StatementClientV1.java b/client/trino-client/src/main/java/io/trino/client/StatementClientV1.java index c2acbaa3c7ff..c231679c32fc 100644 --- a/client/trino-client/src/main/java/io/trino/client/StatementClientV1.java +++ b/client/trino-client/src/main/java/io/trino/client/StatementClientV1.java @@ -48,6 +48,7 @@ import static com.google.common.base.MoreObjects.firstNonNull; import static com.google.common.base.Preconditions.checkState; +import static com.google.common.collect.ImmutableSet.toImmutableSet; import static com.google.common.net.HttpHeaders.ACCEPT_ENCODING; import static com.google.common.net.HttpHeaders.USER_AGENT; import static io.trino.client.JsonCodec.jsonCodec; @@ -56,6 +57,7 @@ import static java.net.HttpURLConnection.HTTP_OK; import static java.net.HttpURLConnection.HTTP_UNAUTHORIZED; import static java.net.HttpURLConnection.HTTP_UNAVAILABLE; +import static java.util.Arrays.stream; import static java.util.Objects.requireNonNull; import static java.util.concurrent.TimeUnit.MILLISECONDS; @@ -93,7 +95,7 @@ class StatementClientV1 private final AtomicReference state = new AtomicReference<>(State.RUNNING); - public StatementClientV1(OkHttpClient httpClient, ClientSession session, String query) + public StatementClientV1(OkHttpClient httpClient, ClientSession session, String query, Optional> clientCapabilities) { requireNonNull(httpClient, "httpClient is null"); requireNonNull(session, "session is null"); @@ -107,7 +109,9 @@ public StatementClientV1(OkHttpClient httpClient, ClientSession session, String .filter(Optional::isPresent) .map(Optional::get) .findFirst(); - this.clientCapabilities = 
Joiner.on(",").join(ClientCapabilities.values()); + this.clientCapabilities = Joiner.on(",").join(clientCapabilities.orElseGet(() -> stream(ClientCapabilities.values()) + .map(Enum::name) + .collect(toImmutableSet()))); this.compressionDisabled = session.isCompressionDisabled(); Request request = buildQueryRequest(session, query); diff --git a/core/trino-main/src/main/java/io/trino/Session.java b/core/trino-main/src/main/java/io/trino/Session.java index 2b48d5d224ea..b30ec04bdf93 100644 --- a/core/trino-main/src/main/java/io/trino/Session.java +++ b/core/trino-main/src/main/java/io/trino/Session.java @@ -607,6 +607,7 @@ private SessionBuilder(Session session) this.remoteUserAddress = session.remoteUserAddress.orElse(null); this.userAgent = session.userAgent.orElse(null); this.clientInfo = session.clientInfo.orElse(null); + this.clientCapabilities = ImmutableSet.copyOf(session.clientCapabilities); this.clientTags = ImmutableSet.copyOf(session.clientTags); this.start = session.start; this.systemProperties.putAll(session.systemProperties); diff --git a/core/trino-main/src/main/java/io/trino/testing/TestingSession.java b/core/trino-main/src/main/java/io/trino/testing/TestingSession.java index b02d27bd5907..9d4f2637a140 100644 --- a/core/trino-main/src/main/java/io/trino/testing/TestingSession.java +++ b/core/trino-main/src/main/java/io/trino/testing/TestingSession.java @@ -15,11 +15,15 @@ import io.trino.Session; import io.trino.Session.SessionBuilder; +import io.trino.client.ClientCapabilities; import io.trino.execution.QueryIdGenerator; import io.trino.metadata.SessionPropertyManager; import io.trino.spi.security.Identity; import io.trino.spi.type.TimeZoneKey; +import java.util.Arrays; + +import static com.google.common.collect.ImmutableSet.toImmutableSet; import static java.util.Locale.ENGLISH; public final class TestingSession @@ -54,6 +58,8 @@ public static SessionBuilder testSessionBuilder(SessionPropertyManager sessionPr .setSchema("schema") 
.setTimeZoneKey(DEFAULT_TIME_ZONE_KEY) .setLocale(ENGLISH) + .setClientCapabilities(Arrays.stream(ClientCapabilities.values()).map(Enum::name) + .collect(toImmutableSet())) .setRemoteUserAddress("address") .setUserAgent("agent"); } diff --git a/testing/trino-testing/src/main/java/io/trino/testing/AbstractTestingTrinoClient.java b/testing/trino-testing/src/main/java/io/trino/testing/AbstractTestingTrinoClient.java index 0224103e7530..2679256700cf 100644 --- a/testing/trino-testing/src/main/java/io/trino/testing/AbstractTestingTrinoClient.java +++ b/testing/trino-testing/src/main/java/io/trino/testing/AbstractTestingTrinoClient.java @@ -92,7 +92,7 @@ public ResultWithQueryId execute(Session session, @Language("SQL") String sql ClientSession clientSession = toClientSession(session, trinoServer.getBaseUrl(), new Duration(2, TimeUnit.MINUTES)); - try (StatementClient client = newStatementClient(httpClient, clientSession, sql)) { + try (StatementClient client = newStatementClient(httpClient, clientSession, sql, Optional.of(session.getClientCapabilities()))) { while (client.isRunning()) { resultsSession.addResults(client.currentStatusInfo(), client.currentData()); client.advance(); diff --git a/testing/trino-tests/src/test/java/io/trino/tests/TestServer.java b/testing/trino-tests/src/test/java/io/trino/tests/TestServer.java index 561c18e474bb..99b162657aae 100644 --- a/testing/trino-tests/src/test/java/io/trino/tests/TestServer.java +++ b/testing/trino-tests/src/test/java/io/trino/tests/TestServer.java @@ -42,6 +42,7 @@ import java.util.Collections; import java.util.List; import java.util.Optional; +import java.util.Set; import java.util.function.Function; import java.util.stream.Collector; import java.util.stream.Collectors; @@ -66,6 +67,7 @@ import static io.trino.SystemSessionProperties.HASH_PARTITION_COUNT; import static io.trino.SystemSessionProperties.JOIN_DISTRIBUTION_TYPE; import static io.trino.SystemSessionProperties.QUERY_MAX_MEMORY; +import static 
io.trino.client.ClientCapabilities.PATH; import static io.trino.client.ProtocolHeaders.TRINO_HEADERS; import static io.trino.spi.StandardErrorCode.INCOMPATIBLE_CLIENT; import static io.trino.testing.TestingSession.testSessionBuilder; @@ -77,6 +79,7 @@ import static javax.ws.rs.core.Response.Status.OK; import static javax.ws.rs.core.Response.Status.SEE_OTHER; import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; import static org.testng.Assert.assertEquals; import static org.testng.Assert.assertNotNull; import static org.testng.Assert.assertNull; @@ -275,6 +278,19 @@ public void testVersionOnCompilerFailedError() } } + @Test + public void testSetPathSupportByClient() + { + try (TestingTrinoClient testingClient = new TestingTrinoClient(server, testSessionBuilder().setClientCapabilities(Set.of()).build())) { + assertThatThrownBy(() -> testingClient.execute("SET PATH foo")) + .hasMessage("SET PATH not supported by client"); + } + + try (TestingTrinoClient testingClient = new TestingTrinoClient(server, testSessionBuilder().setClientCapabilities(Set.of(PATH.name())).build())) { + testingClient.execute("SET PATH foo"); + } + } + private void checkVersionOnError(String query, @Language("RegExp") String proofOfOrigin) { QueryResults queryResults = postQuery(request -> request From dfe33c790e47d39bee3f03357d61b2fcc5f7e578 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Hubert=20=C5=81ojek?= Date: Wed, 7 Dec 2022 18:32:33 +0100 Subject: [PATCH 08/24] Fix HTTP_Status on OAuth2 refresh token redirect When refresh token is retrieved for UI, currently we were sending HTTP Status 303, assuming that all the request will just repeat the call on the Location header. When this works for GET/PUT verbs, it does not for non-idempotent ones like POST, as every js http client should do a GET on LOCATION after 303 on POST. 
Due to that I change it to 307, that should force every client to repeat exactly the same request, no matter the verb. Co-authored-by: s2lomon --- .../ui/OAuth2WebUiAuthenticationFilter.java | 2 +- .../java/io/trino/server/ui/TestWebUi.java | 203 +++++++++++++++--- 2 files changed, 171 insertions(+), 34 deletions(-) diff --git a/core/trino-main/src/main/java/io/trino/server/ui/OAuth2WebUiAuthenticationFilter.java b/core/trino-main/src/main/java/io/trino/server/ui/OAuth2WebUiAuthenticationFilter.java index 253b456be457..24b66c563f24 100644 --- a/core/trino-main/src/main/java/io/trino/server/ui/OAuth2WebUiAuthenticationFilter.java +++ b/core/trino-main/src/main/java/io/trino/server/ui/OAuth2WebUiAuthenticationFilter.java @@ -164,7 +164,7 @@ private void redirectForNewToken(ContainerRequestContext request, String refresh { OAuth2Client.Response response = client.refreshTokens(refreshToken); String serializedToken = tokenPairSerializer.serialize(TokenPair.fromOAuth2Response(response)); - request.abortWith(Response.seeOther(request.getUriInfo().getRequestUri()) + request.abortWith(Response.temporaryRedirect(request.getUriInfo().getRequestUri()) .cookie(OAuthWebUiCookie.create(serializedToken, tokenExpiration.map(expiration -> Instant.now().plus(expiration)).orElse(response.getExpiration()))) .build()); } diff --git a/core/trino-main/src/test/java/io/trino/server/ui/TestWebUi.java b/core/trino-main/src/test/java/io/trino/server/ui/TestWebUi.java index cf56a8fbeaa6..d164a6c64245 100644 --- a/core/trino-main/src/test/java/io/trino/server/ui/TestWebUi.java +++ b/core/trino-main/src/test/java/io/trino/server/ui/TestWebUi.java @@ -35,6 +35,8 @@ import io.trino.server.security.ResourceSecurity; import io.trino.server.security.oauth2.ChallengeFailedException; import io.trino.server.security.oauth2.OAuth2Client; +import io.trino.server.security.oauth2.TokenPairSerializer; +import io.trino.server.security.oauth2.TokenPairSerializer.TokenPair; import 
io.trino.server.testing.TestingTrinoServer; import io.trino.spi.security.AccessDeniedException; import io.trino.spi.security.BasicPrincipal; @@ -104,12 +106,14 @@ import static io.trino.server.ui.FormWebUiAuthenticationFilter.UI_LOGIN; import static io.trino.server.ui.FormWebUiAuthenticationFilter.UI_LOGOUT; import static io.trino.testing.assertions.Assert.assertEquals; +import static io.trino.testing.assertions.Assert.assertEventually; import static java.nio.charset.StandardCharsets.UTF_8; -import static java.util.concurrent.TimeUnit.MINUTES; +import static java.util.Objects.requireNonNull; import static java.util.function.Predicate.not; import static javax.servlet.http.HttpServletResponse.SC_NOT_FOUND; import static javax.servlet.http.HttpServletResponse.SC_OK; import static javax.servlet.http.HttpServletResponse.SC_SEE_OTHER; +import static javax.servlet.http.HttpServletResponse.SC_TEMPORARY_REDIRECT; import static javax.servlet.http.HttpServletResponse.SC_UNAUTHORIZED; import static javax.ws.rs.core.Response.Status.UNAUTHORIZED; import static org.assertj.core.api.Assertions.assertThat; @@ -148,6 +152,8 @@ public class TestWebUi private static final String TEST_PASSWORD2 = "test-password2"; private static final String HMAC_KEY = Resources.getResource("hmac_key.txt").getPath(); private static final PrivateKey JWK_PRIVATE_KEY; + private static final String REFRESH_TOKEN = "REFRESH_TOKEN"; + private static final Duration REFRESH_TOKEN_TIMEOUT = Duration.ofMinutes(1); static { try { @@ -652,8 +658,7 @@ public void testOAuth2Authenticator() .setBinding() .toInstance(oauthClient)) .build()) { - HttpServerInfo httpServerInfo = server.getInstance(Key.get(HttpServerInfo.class)); - assertAuth2Authentication(httpServerInfo, oauthClient.getAccessToken()); + assertAuth2Authentication(server, oauthClient.getAccessToken(), false); } finally { jwkServer.stop(); @@ -664,7 +669,7 @@ public void testOAuth2Authenticator() public void testOAuth2AuthenticatorWithoutOpenIdScope() 
throws Exception { - OAuth2ClientStub oauthClient = new OAuth2ClientStub(false); + OAuth2ClientStub oauthClient = new OAuth2ClientStub(false, Duration.ofSeconds(5)); TestingHttpServer jwkServer = createTestingJwkServer(); jwkServer.start(); try (TestingTrinoServer server = TestingTrinoServer.builder() @@ -677,8 +682,116 @@ public void testOAuth2AuthenticatorWithoutOpenIdScope() .setBinding() .toInstance(oauthClient)) .build()) { + assertAuth2Authentication(server, oauthClient.getAccessToken(), false); + } + finally { + jwkServer.stop(); + } + } + + @Test + public void testOAuth2AuthenticatorWithRefreshToken() + throws Exception + { + OAuth2ClientStub oauthClient = new OAuth2ClientStub(false, Duration.ofSeconds(5)); + TestingHttpServer jwkServer = createTestingJwkServer(); + jwkServer.start(); + try (TestingTrinoServer server = TestingTrinoServer.builder() + .setProperties(ImmutableMap.builder() + .putAll(OAUTH2_PROPERTIES) + .put("http-server.authentication.oauth2.jwks-url", jwkServer.getBaseUrl().toString()) + .put("http-server.authentication.oauth2.refresh-tokens", "true") + .put("http-server.authentication.oauth2.refresh-tokens.issued-token.timeout", REFRESH_TOKEN_TIMEOUT.getSeconds() + "s") + .buildOrThrow()) + .setAdditionalModule(binder -> newOptionalBinder(binder, OAuth2Client.class) + .setBinding() + .toInstance(oauthClient)) + .build()) { + assertAuth2Authentication(server, oauthClient.getAccessToken(), true); + } + finally { + jwkServer.stop(); + } + } + + @Test + public void testOAuth2AuthenticatorRedirectAfterAuthTokenRefresh() + throws Exception + { + // the first issued authorization token will be expired + OAuth2ClientStub oauthClient = new OAuth2ClientStub(false, Duration.ZERO); + TestingHttpServer jwkServer = createTestingJwkServer(); + jwkServer.start(); + try (TestingTrinoServer server = TestingTrinoServer.builder() + .setProperties(ImmutableMap.builder() + .putAll(OAUTH2_PROPERTIES) + .put("http-server.authentication.oauth2.jwks-url", 
jwkServer.getBaseUrl().toString()) + .put("http-server.authentication.oauth2.refresh-tokens", "true") + .put("http-server.authentication.oauth2.refresh-tokens.issued-token.timeout", REFRESH_TOKEN_TIMEOUT.getSeconds() + "s") + .buildOrThrow()) + .setAdditionalModule(binder -> newOptionalBinder(binder, OAuth2Client.class) + .setBinding() + .toInstance(oauthClient)) + .build()) { + CookieManager cookieManager = new CookieManager(); + OkHttpClient client = this.client.newBuilder() + .cookieJar(new JavaNetCookieJar(cookieManager)) + .build(); + + HttpServerInfo httpServerInfo = server.getInstance(Key.get(HttpServerInfo.class)); + URI baseUri = httpServerInfo.getHttpsUri(); + + loginWithCallbackEndpoint(client, baseUri); + HttpCookie cookie = getOnlyElement(cookieManager.getCookieStore().getCookies()); + assertCookieWithRefreshToken(server, cookie, oauthClient.getAccessToken()); + + assertResponseCode(client, getValidApiLocation(baseUri), SC_TEMPORARY_REDIRECT); + assertOk(client, getValidApiLocation(baseUri)); + } + finally { + jwkServer.stop(); + } + } + + @Test + public void testOAuth2AuthenticatorRefreshTokenExpiration() + throws Exception + { + OAuth2ClientStub oauthClient = new OAuth2ClientStub(false, Duration.ofSeconds(5)); + TestingHttpServer jwkServer = createTestingJwkServer(); + jwkServer.start(); + try (TestingTrinoServer server = TestingTrinoServer.builder() + .setProperties(ImmutableMap.builder() + .putAll(OAUTH2_PROPERTIES) + .put("http-server.authentication.oauth2.jwks-url", jwkServer.getBaseUrl().toString()) + .put("http-server.authentication.oauth2.refresh-tokens", "true") + .put("http-server.authentication.oauth2.refresh-tokens.issued-token.timeout", "10s") + .buildOrThrow()) + .setAdditionalModule(binder -> newOptionalBinder(binder, OAuth2Client.class) + .setBinding() + .toInstance(oauthClient)) + .build()) { + CookieManager cookieManager = new CookieManager(); + OkHttpClient client = this.client.newBuilder() + .cookieJar(new 
JavaNetCookieJar(cookieManager)) + .build(); + HttpServerInfo httpServerInfo = server.getInstance(Key.get(HttpServerInfo.class)); - assertAuth2Authentication(httpServerInfo, oauthClient.getAccessToken()); + URI baseUri = httpServerInfo.getHttpsUri(); + + loginWithCallbackEndpoint(client, baseUri); + HttpCookie cookie = getOnlyElement(cookieManager.getCookieStore().getCookies()); + assertOk(client, getValidApiLocation(baseUri)); + + // wait for the cookie to expire + assertEventually(() -> assertThat(cookieManager.getCookieStore().getCookies()).isEmpty()); + assertResponseCode(client, getValidApiLocation(baseUri), UNAUTHORIZED.getStatusCode()); + + // create fake cookie with previous cookie value to check token validity + HttpCookie biscuit = new HttpCookie(cookie.getName(), cookie.getValue()); + biscuit.setPath(cookie.getPath()); + cookieManager.getCookieStore().add(baseUri, biscuit); + assertResponseCode(client, getValidApiLocation(baseUri), UNAUTHORIZED.getStatusCode()); } finally { jwkServer.stop(); @@ -694,6 +807,7 @@ public void testCustomPrincipalField() .put(SUBJECT, "unknown") .put("preferred_username", "test-user@email.com") .buildOrThrow(), + Duration.ofSeconds(5), true); TestingHttpServer jwkServer = createTestingJwkServer(); jwkServer.start(); @@ -711,8 +825,7 @@ public void testCustomPrincipalField() jaxrsBinder(binder).bind(AuthenticatedIdentityCapturingFilter.class); }) .build()) { - HttpServerInfo httpServerInfo = server.getInstance(Key.get(HttpServerInfo.class)); - assertAuth2Authentication(httpServerInfo, oauthClient.getAccessToken()); + assertAuth2Authentication(server, oauthClient.getAccessToken(), false); Identity identity = server.getInstance(Key.get(AuthenticatedIdentityCapturingFilter.class)).getAuthenticatedIdentity(); assertThat(identity.getUser()).isEqualTo("test-user"); assertThat(identity.getPrincipal()).isEqualTo(Optional.of(new BasicPrincipal("test-user@email.com"))); @@ -722,20 +835,15 @@ public void testCustomPrincipalField() } } - 
private void assertAuth2Authentication(HttpServerInfo httpServerInfo, String accessToken) + private void assertAuth2Authentication(TestingTrinoServer server, String accessToken, boolean refreshTokensEnabled) throws Exception { - String state = newJwtBuilder() - .signWith(hmacShaKeyFor(Hashing.sha256().hashString(STATE_KEY, UTF_8).asBytes())) - .setAudience("trino_oauth_ui") - .setExpiration(Date.from(ZonedDateTime.now().plusMinutes(10).toInstant())) - .compact(); - CookieManager cookieManager = new CookieManager(); OkHttpClient client = this.client.newBuilder() .cookieJar(new JavaNetCookieJar(cookieManager)) .build(); + HttpServerInfo httpServerInfo = server.getInstance(Key.get(HttpServerInfo.class)); // HTTP is not allowed for OAuth testDisabled(httpServerInfo.getHttpUri()); @@ -747,21 +855,17 @@ private void assertAuth2Authentication(HttpServerInfo httpServerInfo, String acc assertRedirect(client, getLocation(baseUri, "/ui/unknown"), "http://example.com/authorize", false); assertResponseCode(client, getLocation(baseUri, "/ui/api/unknown"), UNAUTHORIZED.getStatusCode()); - // login with the callback endpoint - assertRedirect( - client, - uriBuilderFrom(baseUri) - .replacePath(CALLBACK_ENDPOINT) - .addParameter("code", "TEST_CODE") - .addParameter("state", state) - .toString(), - getUiLocation(baseUri), - false); + loginWithCallbackEndpoint(client, baseUri); HttpCookie cookie = getOnlyElement(cookieManager.getCookieStore().getCookies()); - assertEquals(cookie.getValue(), accessToken); + if (refreshTokensEnabled) { + assertCookieWithRefreshToken(server, cookie, accessToken); + } + else { + assertEquals(cookie.getValue(), accessToken); + assertThat(cookie.getMaxAge()).isGreaterThan(0).isLessThan(30); + } assertEquals(cookie.getPath(), "/ui/"); assertEquals(cookie.getDomain(), baseUri.getHost()); - assertTrue(cookie.getMaxAge() > 0 && cookie.getMaxAge() < MINUTES.toSeconds(5)); assertTrue(cookie.isHttpOnly()); // authentication cookie is now set, so UI should work @@ 
-778,6 +882,34 @@ private void assertAuth2Authentication(HttpServerInfo httpServerInfo, String acc assertRedirect(client, getUiLocation(baseUri), "http://example.com/authorize", false); } + private static void loginWithCallbackEndpoint(OkHttpClient client, URI baseUri) + throws IOException + { + String state = newJwtBuilder() + .signWith(hmacShaKeyFor(Hashing.sha256().hashString(STATE_KEY, UTF_8).asBytes())) + .setAudience("trino_oauth_ui") + .setExpiration(Date.from(ZonedDateTime.now().plusMinutes(10).toInstant())) + .compact(); + assertRedirect( + client, + uriBuilderFrom(baseUri) + .replacePath(CALLBACK_ENDPOINT) + .addParameter("code", "TEST_CODE") + .addParameter("state", state) + .toString(), + getUiLocation(baseUri), + false); + } + + private static void assertCookieWithRefreshToken(TestingTrinoServer server, HttpCookie authCookie, String accessToken) + { + TokenPairSerializer tokenPairSerializer = server.getInstance(Key.get(TokenPairSerializer.class)); + TokenPair deserialize = tokenPairSerializer.deserialize(authCookie.getValue()); + assertEquals(deserialize.getAccessToken(), accessToken); + assertEquals(deserialize.getRefreshToken(), Optional.of(REFRESH_TOKEN)); + assertThat(authCookie.getMaxAge()).isGreaterThan(0).isLessThan(REFRESH_TOKEN_TIMEOUT.getSeconds()); + } + private static void testAlwaysAuthorized(URI baseUri, OkHttpClient authorizedClient, String nodeId) throws IOException { @@ -1078,23 +1210,25 @@ private static class OAuth2ClientStub private static final SecureRandom secureRandom = new SecureRandom(); private final Claims claims; private final String accessToken; + private final Duration accessTokenValidity; private final Optional nonce; private final Optional idToken; public OAuth2ClientStub() { - this(true); + this(true, Duration.ofSeconds(5)); } - public OAuth2ClientStub(boolean issueIdToken) + public OAuth2ClientStub(boolean issueIdToken, Duration accessTokenValidity) { - this(ImmutableMap.of(), issueIdToken); + this(ImmutableMap.of(), 
accessTokenValidity, issueIdToken); } - public OAuth2ClientStub(Map additionalClaims, boolean issueIdToken) + public OAuth2ClientStub(Map additionalClaims, Duration accessTokenValidity, boolean issueIdToken) { claims = new DefaultClaims(createClaims()); - claims.putAll(additionalClaims); + claims.putAll(requireNonNull(additionalClaims, "additionalClaims is null")); + this.accessTokenValidity = requireNonNull(accessTokenValidity, "accessTokenValidity is null"); accessToken = issueToken(claims); if (issueIdToken) { nonce = Optional.of(randomNonce()); @@ -1127,7 +1261,7 @@ public Response getOAuth2Response(String code, URI callbackUri, Optional if (!"TEST_CODE".equals(code)) { throw new IllegalArgumentException("Expected TEST_CODE"); } - return new Response(accessToken, Instant.now().plusSeconds(5), idToken, Optional.empty()); + return new Response(accessToken, Instant.now().plus(accessTokenValidity), idToken, Optional.of(REFRESH_TOKEN)); } @Override @@ -1140,7 +1274,10 @@ public Optional> getClaims(String accessToken) public Response refreshTokens(String refreshToken) throws ChallengeFailedException { - throw new UnsupportedOperationException("Refresh tokens are not supported"); + if (refreshToken.equals(REFRESH_TOKEN)) { + return new Response(issueToken(claims), Instant.now().plusSeconds(30), idToken, Optional.of(REFRESH_TOKEN)); + } + throw new ChallengeFailedException("invalid refresh token"); } public String getAccessToken() From c056bc7cd00d641ce4c4be6632a63508330394c7 Mon Sep 17 00:00:00 2001 From: Raunaq Morarka Date: Sat, 10 Dec 2022 17:56:14 +0530 Subject: [PATCH 09/24] Fix recording of projection metrics Actual work is done in `pageProjectWork.process()` call while `projection.project` only performs setup of projection. So both `expressionProfiler` and `metrics.recordProjectionTime` needed to be around that method. 
--- .../io/trino/operator/project/PageProcessor.java | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/core/trino-main/src/main/java/io/trino/operator/project/PageProcessor.java b/core/trino-main/src/main/java/io/trino/operator/project/PageProcessor.java index a08faec243ae..d21267eec5cc 100644 --- a/core/trino-main/src/main/java/io/trino/operator/project/PageProcessor.java +++ b/core/trino-main/src/main/java/io/trino/operator/project/PageProcessor.java @@ -342,13 +342,15 @@ private ProcessBatchResult processBatch(int batchSize) } else { if (pageProjectWork == null) { - Page inputPage = projection.getInputChannels().getInputChannels(page); - expressionProfiler.start(); - pageProjectWork = projection.project(session, yieldSignal, inputPage, positionsBatch); - long projectionTimeNanos = expressionProfiler.stop(positionsBatch.size()); - metrics.recordProjectionTime(projectionTimeNanos); + pageProjectWork = projection.project(session, yieldSignal, projection.getInputChannels().getInputChannels(page), positionsBatch); } - if (!pageProjectWork.process()) { + + expressionProfiler.start(); + boolean finished = pageProjectWork.process(); + long projectionTimeNanos = expressionProfiler.stop(positionsBatch.size()); + metrics.recordProjectionTime(projectionTimeNanos); + + if (!finished) { return ProcessBatchResult.processBatchYield(); } previouslyComputedResults[i] = pageProjectWork.getResult(); From a46a5100c8e4cee2853019243e9c2213d87ec3c9 Mon Sep 17 00:00:00 2001 From: leetcode-1533 <1275963@gmail.com> Date: Fri, 2 Dec 2022 00:55:10 -0800 Subject: [PATCH 10/24] Fix dereference operations for union type in Hive Connector --- .../java/io/trino/plugin/hive/HiveType.java | 72 ++++++++++--- .../plugin/hive/util/HiveTypeTranslator.java | 8 +- .../tests/product/hive/TestReadUniontype.java | 101 ++++++++++++++++++ 3 files changed, 164 insertions(+), 17 deletions(-) diff --git a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/HiveType.java 
b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/HiveType.java index 9230d9e26013..45afb7bb299a 100644 --- a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/HiveType.java +++ b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/HiveType.java @@ -32,10 +32,14 @@ import java.util.Optional; import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.base.Strings.lenientFormat; import static com.google.common.collect.ImmutableList.toImmutableList; import static io.trino.plugin.hive.HiveStorageFormat.AVRO; import static io.trino.plugin.hive.HiveStorageFormat.ORC; import static io.trino.plugin.hive.HiveTimestampPrecision.DEFAULT_PRECISION; +import static io.trino.plugin.hive.util.HiveTypeTranslator.UNION_FIELD_FIELD_PREFIX; +import static io.trino.plugin.hive.util.HiveTypeTranslator.UNION_FIELD_TAG_NAME; +import static io.trino.plugin.hive.util.HiveTypeTranslator.UNION_FIELD_TAG_TYPE; import static io.trino.plugin.hive.util.HiveTypeTranslator.fromPrimitiveType; import static io.trino.plugin.hive.util.HiveTypeTranslator.toTypeInfo; import static io.trino.plugin.hive.util.HiveTypeTranslator.toTypeSignature; @@ -219,13 +223,32 @@ public Optional getHiveTypeForDereferences(List dereferences) { TypeInfo typeInfo = getTypeInfo(); for (int fieldIndex : dereferences) { - checkArgument(typeInfo instanceof StructTypeInfo, "typeInfo should be struct type", typeInfo); - StructTypeInfo structTypeInfo = (StructTypeInfo) typeInfo; - try { - typeInfo = structTypeInfo.getAllStructFieldTypeInfos().get(fieldIndex); + if (typeInfo instanceof StructTypeInfo structTypeInfo) { + try { + typeInfo = structTypeInfo.getAllStructFieldTypeInfos().get(fieldIndex); + } + catch (RuntimeException e) { + // return empty when failed to dereference, this could happen when partition and table schema mismatch + return Optional.empty(); + } } - catch (RuntimeException e) { - return Optional.empty(); + else if (typeInfo instanceof UnionTypeInfo unionTypeInfo) 
{ + try { + if (fieldIndex == 0) { + // union's tag field, defined in {@link io.trino.plugin.hive.util.HiveTypeTranslator#toTypeSignature} + return Optional.of(HiveType.toHiveType(UNION_FIELD_TAG_TYPE)); + } + else { + typeInfo = unionTypeInfo.getAllUnionObjectTypeInfos().get(fieldIndex - 1); + } + } + catch (RuntimeException e) { + // return empty when failed to dereference, this could happen when partition and table schema mismatch + return Optional.empty(); + } + } + else { + throw new IllegalArgumentException(lenientFormat("typeInfo: %s should be struct or union type", typeInfo)); } } return Optional.of(toHiveType(typeInfo)); @@ -235,16 +258,35 @@ public List getHiveDereferenceNames(List dereferences) { ImmutableList.Builder dereferenceNames = ImmutableList.builder(); TypeInfo typeInfo = getTypeInfo(); - for (int fieldIndex : dereferences) { - checkArgument(typeInfo instanceof StructTypeInfo, "typeInfo should be struct type", typeInfo); - StructTypeInfo structTypeInfo = (StructTypeInfo) typeInfo; - + for (int i = 0; i < dereferences.size(); i++) { + int fieldIndex = dereferences.get(i); checkArgument(fieldIndex >= 0, "fieldIndex cannot be negative"); - checkArgument(fieldIndex < structTypeInfo.getAllStructFieldNames().size(), - "fieldIndex should be less than the number of fields in the struct"); - String fieldName = structTypeInfo.getAllStructFieldNames().get(fieldIndex); - dereferenceNames.add(fieldName); - typeInfo = structTypeInfo.getAllStructFieldTypeInfos().get(fieldIndex); + + if (typeInfo instanceof StructTypeInfo structTypeInfo) { + checkArgument(fieldIndex < structTypeInfo.getAllStructFieldNames().size(), + "fieldIndex should be less than the number of fields in the struct"); + + String fieldName = structTypeInfo.getAllStructFieldNames().get(fieldIndex); + dereferenceNames.add(fieldName); + typeInfo = structTypeInfo.getAllStructFieldTypeInfos().get(fieldIndex); + } + else if (typeInfo instanceof UnionTypeInfo unionTypeInfo) { + 
checkArgument((fieldIndex - 1) < unionTypeInfo.getAllUnionObjectTypeInfos().size(), + "fieldIndex should be less than the number of fields in the union plus tag field"); + + if (fieldIndex == 0) { + checkArgument(i == (dereferences.size() - 1), "Union's tag field should not have more subfields"); + dereferenceNames.add(UNION_FIELD_TAG_NAME); + break; + } + else { + typeInfo = unionTypeInfo.getAllUnionObjectTypeInfos().get(fieldIndex - 1); + dereferenceNames.add(UNION_FIELD_FIELD_PREFIX + (fieldIndex - 1)); + } + } + else { + throw new IllegalArgumentException(lenientFormat("typeInfo: %s should be struct or union type", typeInfo)); + } } return dereferenceNames.build(); diff --git a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/util/HiveTypeTranslator.java b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/util/HiveTypeTranslator.java index 22d1e5f4ce2d..4d165511130a 100644 --- a/plugin/trino-hive/src/main/java/io/trino/plugin/hive/util/HiveTypeTranslator.java +++ b/plugin/trino-hive/src/main/java/io/trino/plugin/hive/util/HiveTypeTranslator.java @@ -91,6 +91,10 @@ public final class HiveTypeTranslator { private HiveTypeTranslator() {} + public static final String UNION_FIELD_TAG_NAME = "tag"; + public static final String UNION_FIELD_FIELD_PREFIX = "field"; + public static final Type UNION_FIELD_TAG_TYPE = TINYINT; + public static TypeInfo toTypeInfo(Type type) { requireNonNull(type, "type is null"); @@ -213,10 +217,10 @@ public static TypeSignature toTypeSignature(TypeInfo typeInfo, HiveTimestampPrec UnionTypeInfo unionTypeInfo = (UnionTypeInfo) typeInfo; List unionObjectTypes = unionTypeInfo.getAllUnionObjectTypeInfos(); ImmutableList.Builder typeSignatures = ImmutableList.builder(); - typeSignatures.add(namedField("tag", TINYINT.getTypeSignature())); + typeSignatures.add(namedField(UNION_FIELD_TAG_NAME, UNION_FIELD_TAG_TYPE.getTypeSignature())); for (int i = 0; i < unionObjectTypes.size(); i++) { TypeInfo unionObjectType = unionObjectTypes.get(i); - 
typeSignatures.add(namedField("field" + i, toTypeSignature(unionObjectType, timestampPrecision))); + typeSignatures.add(namedField(UNION_FIELD_FIELD_PREFIX + i, toTypeSignature(unionObjectType, timestampPrecision))); } return rowType(typeSignatures.build()); } diff --git a/testing/trino-product-tests/src/main/java/io/trino/tests/product/hive/TestReadUniontype.java b/testing/trino-product-tests/src/main/java/io/trino/tests/product/hive/TestReadUniontype.java index a7ec94a51eb3..062ae91d9c3b 100644 --- a/testing/trino-product-tests/src/main/java/io/trino/tests/product/hive/TestReadUniontype.java +++ b/testing/trino-product-tests/src/main/java/io/trino/tests/product/hive/TestReadUniontype.java @@ -24,6 +24,7 @@ import java.util.Arrays; import java.util.List; +import static io.trino.testing.TestingNames.randomNameSuffix; import static io.trino.tests.product.TestGroups.SMOKE; import static io.trino.tests.product.utils.QueryExecutors.onHive; import static io.trino.tests.product.utils.QueryExecutors.onTrino; @@ -51,6 +52,87 @@ public static Object[][] storageFormats() return new String[][] {{"ORC"}, {"AVRO"}}; } + @DataProvider(name = "union_dereference_test_cases") + public static Object[][] unionDereferenceTestCases() + { + String tableUnionDereference = "test_union_dereference" + randomNameSuffix(); + // Hive insertion for union type in AVRO format has bugs, so we test on different table schemas for AVRO than ORC. 
+ return new Object[][] {{ + format( + "CREATE TABLE %s (unionLevel0 UNIONTYPE<" + + "INT, STRING>)" + + "STORED AS %s", + tableUnionDereference, + "AVRO"), + format( + "INSERT INTO TABLE %s " + + "SELECT create_union(0, 321, 'row1') " + + "UNION ALL " + + "SELECT create_union(1, 55, 'row2') ", + tableUnionDereference), + format("SELECT unionLevel0.field0 FROM %s WHERE unionLevel0.field0 IS NOT NULL", tableUnionDereference), + Arrays.asList(321), + format("SELECT unionLevel0.tag FROM %s", tableUnionDereference), + Arrays.asList((byte) 0, (byte) 1), + "DROP TABLE IF EXISTS " + tableUnionDereference}, + // there is an internal issue in Hive 1.2: + // unionLevel1 is declared as unionType, but has to be inserted by create_union(tagId, Int, String) + { + format( + "CREATE TABLE %s (unionLevel0 UNIONTYPE>>, intLevel0 INT )" + + "STORED AS %s", + tableUnionDereference, + "AVRO"), + format( + "INSERT INTO TABLE %s " + + "SELECT create_union(2, 321, 'row1', named_struct('intLevel1', 1, 'stringLevel1', 'structval', 'unionLevel1', create_union(0, 5, 'testString'))), 8 " + + "UNION ALL " + + "SELECT create_union(2, 321, 'row1', named_struct('intLevel1', 1, 'stringLevel1', 'structval', 'unionLevel1', create_union(1, 5, 'testString'))), 8 ", + tableUnionDereference), + format("SELECT unionLevel0.field2.unionLevel1.field1 FROM %s WHERE unionLevel0.field2.unionLevel1.field1 IS NOT NULL", tableUnionDereference), + Arrays.asList(5), + format("SELECT unionLevel0.field2.unionLevel1.tag FROM %s", tableUnionDereference), + Arrays.asList((byte) 0, (byte) 1), + "DROP TABLE IF EXISTS " + tableUnionDereference}, + { + format( + "CREATE TABLE %s (unionLevel0 UNIONTYPE<" + + "STRUCT>>)" + + "STORED AS %s", + tableUnionDereference, + "ORC"), + format( + "INSERT INTO TABLE %s " + + "SELECT create_union(0, named_struct('unionLevel1', create_union(0, 'testString1', 23))) " + + "UNION ALL " + + "SELECT create_union(0, named_struct('unionLevel1', create_union(1, 'testString2', 45))) ", + 
tableUnionDereference), + format("SELECT unionLevel0.field0.unionLevel1.field0 FROM %s WHERE unionLevel0.field0.unionLevel1.field0 IS NOT NULL", tableUnionDereference), + Arrays.asList("testString1"), + format("SELECT unionLevel0.field0.unionLevel1.tag FROM %s", tableUnionDereference), + Arrays.asList((byte) 0, (byte) 1), + "DROP TABLE IF EXISTS " + tableUnionDereference}, + { + format( + "CREATE TABLE %s (unionLevel0 UNIONTYPE>>, intLevel0 INT )" + + "STORED AS %s", + tableUnionDereference, + "ORC"), + format( + "INSERT INTO TABLE %s " + + "SELECT create_union(2, 321, 'row1', named_struct('intLevel1', 1, 'stringLevel1', 'structval', 'unionLevel1', create_union(0, 'testString', 5))), 8 " + + "UNION ALL " + + "SELECT create_union(2, 321, 'row1', named_struct('intLevel1', 1, 'stringLevel1', 'structval', 'unionLevel1', create_union(1, 'testString', 5))), 8 ", + tableUnionDereference), + format("SELECT unionLevel0.field2.unionLevel1.field0 FROM %s WHERE unionLevel0.field2.unionLevel1.field0 IS NOT NULL", tableUnionDereference), + Arrays.asList("testString"), + format("SELECT unionLevel0.field2.unionLevel1.tag FROM %s", tableUnionDereference), + Arrays.asList((byte) 0, (byte) 1), + "DROP TABLE IF EXISTS " + tableUnionDereference}}; + } + @Test(dataProvider = "storage_formats", groups = SMOKE) public void testReadUniontype(String storageFormat) { @@ -137,6 +219,25 @@ public void testReadUniontype(String storageFormat) } } + @Test(dataProvider = "union_dereference_test_cases", groups = SMOKE) + public void testReadUniontypeWithDereference(String createTableSql, String insertSql, String selectSql, List expectedResult, String selectTagSql, List expectedTagResult, String dropTableSql) + { + // According to testing results, the Hive INSERT queries here only work in Hive 1.2 + if (getHiveVersionMajor() != 1 || getHiveVersionMinor() != 2) { + throw new SkipException("This test can only be run with Hive 1.2 (default config)"); + } + + onHive().executeQuery(createTableSql); + 
onHive().executeQuery(insertSql); + + QueryResult result = onTrino().executeQuery(selectSql); + assertThat(result.column(1)).containsExactlyInAnyOrderElementsOf(expectedResult); + result = onTrino().executeQuery(selectTagSql); + assertThat(result.column(1)).containsExactlyInAnyOrderElementsOf(expectedTagResult); + + onTrino().executeQuery(dropTableSql); + } + @Test(dataProvider = "storage_formats", groups = SMOKE) public void testUnionTypeSchemaEvolution(String storageFormat) { From 0459735c25df57335d221056f2f084b580cb2c23 Mon Sep 17 00:00:00 2001 From: James Petty Date: Fri, 9 Dec 2022 14:08:34 -0500 Subject: [PATCH 11/24] Cleanup PartitioningExchanger Removes outdated comments and unnecessary methods in local exchange PartitioningExchanger since the operator is no longer implemented in a way that attempts to be thread-safe. --- .../exchange/PartitioningExchanger.java | 37 ++++++------------- 1 file changed, 11 insertions(+), 26 deletions(-) diff --git a/core/trino-main/src/main/java/io/trino/operator/exchange/PartitioningExchanger.java b/core/trino-main/src/main/java/io/trino/operator/exchange/PartitioningExchanger.java index 62cee860defe..808d875ea431 100644 --- a/core/trino-main/src/main/java/io/trino/operator/exchange/PartitioningExchanger.java +++ b/core/trino-main/src/main/java/io/trino/operator/exchange/PartitioningExchanger.java @@ -19,7 +19,6 @@ import io.trino.spi.Page; import it.unimi.dsi.fastutil.ints.IntArrayList; -import javax.annotation.Nullable; import javax.annotation.concurrent.NotThreadSafe; import java.util.List; @@ -58,18 +57,7 @@ public PartitioningExchanger( @Override public void accept(Page page) { - Consumer wholePagePartition = partitionPageOrFindWholePagePartition(page, partitionedPagePreparer.apply(page)); - if (wholePagePartition != null) { - // whole input page will go to this partition, compact the input page avoid over-retaining memory and to - // match the behavior of sub-partitioned pages that copy positions out - page.compact(); 
- sendPageToPartition(wholePagePartition, page); - } - } - - @Nullable - private Consumer partitionPageOrFindWholePagePartition(Page page, Page partitionPage) - { + Page partitionPage = partitionedPagePreparer.apply(page); // assign each row to a partition. The assignments lists are all expected to cleared by the previous iterations for (int position = 0; position < partitionPage.getPositionCount(); position++) { int partition = partitionFunction.getPartition(partitionPage, position); @@ -89,22 +77,19 @@ private Consumer partitionPageOrFindWholePagePartition(Page page, Page par int[] positions = positionsList.elements(); positionsList.clear(); + Page pageSplit; if (partitionSize == page.getPositionCount()) { - // entire page will be sent to this partition, compact and send the page after releasing the lock - return buffers.get(partition); + // whole input page will go to this partition, compact the input page avoid over-retaining memory and to + // match the behavior of sub-partitioned pages that copy positions out + page.compact(); + pageSplit = page; } - Page pageSplit = page.copyPositions(positions, 0, partitionSize); - sendPageToPartition(buffers.get(partition), pageSplit); + else { + pageSplit = page.copyPositions(positions, 0, partitionSize); + } + memoryManager.updateMemoryUsage(pageSplit.getRetainedSizeInBytes()); + buffers.get(partition).accept(pageSplit); } - // No single partition receives the entire input page - return null; - } - - // This is safe to call without synchronizing because the partition buffers are themselves threadsafe - private void sendPageToPartition(Consumer buffer, Page pageSplit) - { - memoryManager.updateMemoryUsage(pageSplit.getRetainedSizeInBytes()); - buffer.accept(pageSplit); } @Override From 1788829b30f141856979a27982a13fb13dddf19e Mon Sep 17 00:00:00 2001 From: Vikash Kumar Date: Mon, 12 Dec 2022 17:38:12 +0530 Subject: [PATCH 12/24] Fix table name in TestDropTableTask --- 
.../src/test/java/io/trino/execution/TestDropTableTask.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/trino-main/src/test/java/io/trino/execution/TestDropTableTask.java b/core/trino-main/src/test/java/io/trino/execution/TestDropTableTask.java index 8d22ffd8d491..1a13e6add44c 100644 --- a/core/trino-main/src/test/java/io/trino/execution/TestDropTableTask.java +++ b/core/trino-main/src/test/java/io/trino/execution/TestDropTableTask.java @@ -36,7 +36,7 @@ public class TestDropTableTask @Test public void testDropExistingTable() { - QualifiedObjectName tableName = qualifiedObjectName("not_existing_table"); + QualifiedObjectName tableName = qualifiedObjectName("existing_table"); metadata.createTable(testSession, TEST_CATALOG_NAME, someTable(tableName), false); assertThat(metadata.getTableHandle(testSession, tableName)).isPresent(); From 3c1c1fa10234ed1f5e1496c094d2b16c6eacaa8d Mon Sep 17 00:00:00 2001 From: rigogsilva Date: Mon, 12 Dec 2022 10:08:41 -0600 Subject: [PATCH 13/24] Document examples for datetime functions --- docs/src/main/sphinx/functions/datetime.rst | 26 +++++++++++++++++---- 1 file changed, 22 insertions(+), 4 deletions(-) diff --git a/docs/src/main/sphinx/functions/datetime.rst b/docs/src/main/sphinx/functions/datetime.rst index 27eb678a83fc..7f16a2465141 100644 --- a/docs/src/main/sphinx/functions/datetime.rst +++ b/docs/src/main/sphinx/functions/datetime.rst @@ -228,7 +228,16 @@ The above examples use the timestamp ``2001-08-22 03:04:05.321`` as the input. .. function:: date_trunc(unit, x) -> [same as input] - Returns ``x`` truncated to ``unit``. + Returns ``x`` truncated to ``unit``:: + + SELECT date_trunc('day' , TIMESTAMP '2022-10-20 05:10:00'); + -- 2022-10-20 00:00:00.000 + + SELECT date_trunc('month' , TIMESTAMP '2022-10-20 05:10:00'); + -- 2022-10-01 00:00:00.000 + + SELECT date_trunc('year', TIMESTAMP '2022-10-20 05:10:00'); + -- 2022-01-01 00:00:00.000 .. 
_datetime-interval-functions: @@ -383,11 +392,17 @@ Specifier Description .. function:: date_format(timestamp, format) -> varchar - Formats ``timestamp`` as a string using ``format``. + Formats ``timestamp`` as a string using ``format``:: + + SELECT date_format(TIMESTAMP '2022-10-20 05:10:00', '%m-%d-%Y %H'); + -- 10-20-2022 05 .. function:: date_parse(string, format) -> timestamp(3) - Parses ``string`` into a timestamp using ``format``. + Parses ``string`` into a timestamp using ``format``:: + + SELECT date_parse('2022/10/20/05', '%Y/%m/%d/%H'); + -- 2022-10-20 05:00:00.000 Java date functions ------------------- @@ -437,7 +452,10 @@ field to be extracted. Most fields support all date and time types. .. function:: extract(field FROM x) -> bigint - Returns ``field`` from ``x``. + Returns ``field`` from ``x``:: + + SELECT extract(YEAR FROM TIMESTAMP '2022-10-20 05:10:00'); + -- 2022 .. note:: This SQL-standard function uses special syntax for specifying the arguments. From 58bd1598c3694bdb5dc490188288992215605e9b Mon Sep 17 00:00:00 2001 From: Jonas Irgens Kylling Date: Fri, 25 Nov 2022 20:50:34 +0100 Subject: [PATCH 14/24] Decode path as URI in Delta Lake connector --- .../plugin/deltalake/DeltaLakeMetadata.java | 2 +- .../deltalake/DeltaLakeSplitManager.java | 8 +- .../BaseDeltaLakeConnectorSmokeTest.java | 13 +++ .../TestDeltaLakeAdlsConnectorSmokeTest.java | 9 ++ .../test/resources/databricks/uri/README.md | 12 +++ .../uri/_delta_log/00000000000000000000.json | 3 + .../uri/_delta_log/00000000000000000001.json | 2 + .../uri/_delta_log/00000000000000000002.json | 2 + .../uri/_delta_log/00000000000000000003.json | 2 + ...4de3-916a-ee0d89139446.c000.snappy.parquet | Bin 0 -> 475 bytes ...436d-bb00-245452904a81.c000.snappy.parquet | Bin 0 -> 475 bytes ...4180-91f1-4fd762dbc279.c000.snappy.parquet | Bin 0 -> 475 bytes .../TestDeltaLakeSelectCompatibility.java | 84 ++++++++++++++++++ 13 files changed, 132 insertions(+), 5 deletions(-) create mode 100644 
plugin/trino-delta-lake/src/test/resources/databricks/uri/README.md create mode 100644 plugin/trino-delta-lake/src/test/resources/databricks/uri/_delta_log/00000000000000000000.json create mode 100644 plugin/trino-delta-lake/src/test/resources/databricks/uri/_delta_log/00000000000000000001.json create mode 100644 plugin/trino-delta-lake/src/test/resources/databricks/uri/_delta_log/00000000000000000002.json create mode 100644 plugin/trino-delta-lake/src/test/resources/databricks/uri/_delta_log/00000000000000000003.json create mode 100644 plugin/trino-delta-lake/src/test/resources/databricks/uri/part=a%3Acolon/part-00000-ab613d48-052b-4de3-916a-ee0d89139446.c000.snappy.parquet create mode 100644 plugin/trino-delta-lake/src/test/resources/databricks/uri/part=a%3Dequal/part-00000-d6a0dd7d-8416-436d-bb00-245452904a81.c000.snappy.parquet create mode 100644 plugin/trino-delta-lake/src/test/resources/databricks/uri/part=a+plus/part-00000-4aecc0d9-dfbf-4180-91f1-4fd762dbc279.c000.snappy.parquet create mode 100644 testing/trino-product-tests/src/main/java/io/trino/tests/product/deltalake/TestDeltaLakeSelectCompatibility.java diff --git a/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/DeltaLakeMetadata.java b/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/DeltaLakeMetadata.java index 081651d19be2..e8aec83e3400 100644 --- a/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/DeltaLakeMetadata.java +++ b/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/DeltaLakeMetadata.java @@ -1255,7 +1255,7 @@ private static void appendAddFileEntries(TransactionLogWriter transactionLogWrit transactionLogWriter.appendAddFileEntry( new AddFileEntry( - toUriFormat(info.getPath()), // Databricks and OSS Delta Lake both expect path to be url-encoded, even though the procotol specification doesn't mention that + toUriFormat(info.getPath()), // Paths are RFC 2396 URI encoded 
https://github.com/delta-io/delta/blob/master/PROTOCOL.md#add-file-and-remove-file partitionValues, info.getSize(), info.getCreationTime(), diff --git a/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/DeltaLakeSplitManager.java b/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/DeltaLakeSplitManager.java index 7ceac4d7909f..ea296a05a8f2 100644 --- a/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/DeltaLakeSplitManager.java +++ b/plugin/trino-delta-lake/src/main/java/io/trino/plugin/deltalake/DeltaLakeSplitManager.java @@ -36,7 +36,7 @@ import javax.inject.Inject; -import java.net.URLDecoder; +import java.net.URI; import java.time.Instant; import java.util.List; import java.util.Map; @@ -58,7 +58,6 @@ import static io.trino.plugin.deltalake.DeltaLakeSessionProperties.getMaxSplitSize; import static io.trino.plugin.deltalake.transactionlog.DeltaLakeSchemaSupport.extractSchema; import static io.trino.plugin.deltalake.transactionlog.TransactionLogParser.deserializePartitionValue; -import static java.nio.charset.StandardCharsets.UTF_8; import static java.util.Locale.ENGLISH; import static java.util.Objects.requireNonNull; import static java.util.function.Function.identity; @@ -290,8 +289,9 @@ private List splitsForFile( private static String buildSplitPath(String tableLocation, AddFileEntry addAction) { - // paths are relative to the table location and URL encoded - String path = URLDecoder.decode(addAction.getPath(), UTF_8); + // paths are relative to the table location and are RFC 2396 URIs + // https://github.com/delta-io/delta/blob/master/PROTOCOL.md#add-file-and-remove-file + String path = URI.create(addAction.getPath()).getPath(); if (tableLocation.endsWith("/")) { return tableLocation + path; } diff --git a/plugin/trino-delta-lake/src/test/java/io/trino/plugin/deltalake/BaseDeltaLakeConnectorSmokeTest.java b/plugin/trino-delta-lake/src/test/java/io/trino/plugin/deltalake/BaseDeltaLakeConnectorSmokeTest.java index 
94290ae288cd..a7175dce7cf7 100644 --- a/plugin/trino-delta-lake/src/test/java/io/trino/plugin/deltalake/BaseDeltaLakeConnectorSmokeTest.java +++ b/plugin/trino-delta-lake/src/test/java/io/trino/plugin/deltalake/BaseDeltaLakeConnectorSmokeTest.java @@ -282,6 +282,19 @@ public void testCreatePartitionedTable() assertUpdate("DROP TABLE " + tableName); } + @Test + public void testPathUriDecoding() + { + String tableName = "test_uri_table_" + randomNameSuffix(); + registerTableFromResources(tableName, "databricks/uri", getQueryRunner()); + + assertQuery("SELECT * FROM " + tableName, "VALUES ('a=equal', 1), ('a:colon', 2), ('a+plus', 3)"); + String firstFilePath = (String) computeScalar("SELECT \"$path\" FROM " + tableName + " WHERE y = 1"); + assertQuery("SELECT * FROM " + tableName + " WHERE \"$path\" = '" + firstFilePath + "'", "VALUES ('a=equal', 1)"); + + assertUpdate("DROP TABLE " + tableName); + } + @Test public void testCreateTablePartitionValidation() { diff --git a/plugin/trino-delta-lake/src/test/java/io/trino/plugin/deltalake/TestDeltaLakeAdlsConnectorSmokeTest.java b/plugin/trino-delta-lake/src/test/java/io/trino/plugin/deltalake/TestDeltaLakeAdlsConnectorSmokeTest.java index 2fe5b4aef91e..82358243fe27 100644 --- a/plugin/trino-delta-lake/src/test/java/io/trino/plugin/deltalake/TestDeltaLakeAdlsConnectorSmokeTest.java +++ b/plugin/trino-delta-lake/src/test/java/io/trino/plugin/deltalake/TestDeltaLakeAdlsConnectorSmokeTest.java @@ -50,6 +50,7 @@ import static java.util.Objects.requireNonNull; import static java.util.regex.Matcher.quoteReplacement; import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; public class TestDeltaLakeAdlsConnectorSmokeTest extends BaseDeltaLakeConnectorSmokeTest @@ -113,6 +114,14 @@ public void removeTestData() assertThat(azureContainerClient.listBlobsByHierarchy(bucketName + "/").stream()).hasSize(0); } + @Override + public void testPathUriDecoding() + { + // 
TODO https://github.com/trinodb/trino/issues/15376 AzureBlobFileSystem doesn't expect URI as the path argument + assertThatThrownBy(super::testPathUriDecoding) + .hasStackTraceContaining("The specified path does not exist"); + } + @Override protected void registerTableFromResources(String table, String resourcePath, QueryRunner queryRunner) { diff --git a/plugin/trino-delta-lake/src/test/resources/databricks/uri/README.md b/plugin/trino-delta-lake/src/test/resources/databricks/uri/README.md new file mode 100644 index 000000000000..52a0fe3e758f --- /dev/null +++ b/plugin/trino-delta-lake/src/test/resources/databricks/uri/README.md @@ -0,0 +1,12 @@ +Data generated using OSS DELTA 2.0.0: + +```sql +CREATE TABLE default.uri_test (part string, y long) +USING delta +PARTITIONED BY (part) +LOCATION '/home/username/trino/plugin/trino-delta-lake/src/test/resources/databricks/uri'; + +INSERT INTO default.uri_test VALUES ('a=equal', 1); +INSERT INTO default.uri_test VALUES ('a:colon', 2); +INSERT INTO default.uri_test VALUES ('a+plus', 3); +``` diff --git a/plugin/trino-delta-lake/src/test/resources/databricks/uri/_delta_log/00000000000000000000.json b/plugin/trino-delta-lake/src/test/resources/databricks/uri/_delta_log/00000000000000000000.json new file mode 100644 index 000000000000..6ce3c0ab0934 --- /dev/null +++ b/plugin/trino-delta-lake/src/test/resources/databricks/uri/_delta_log/00000000000000000000.json @@ -0,0 +1,3 @@ +{"protocol":{"minReaderVersion":1,"minWriterVersion":2}} +{"metaData":{"id":"ae596d2a-868c-4480-9dba-ecb1366eef15","format":{"provider":"parquet","options":{}},"schemaString":"{\"type\":\"struct\",\"fields\":[{\"name\":\"part\",\"type\":\"string\",\"nullable\":true,\"metadata\":{}},{\"name\":\"y\",\"type\":\"long\",\"nullable\":true,\"metadata\":{}}]}","partitionColumns":["part"],"configuration":{},"createdTime":1670672197310}} +{"commitInfo":{"timestamp":1670672197479,"operation":"CREATE 
TABLE","operationParameters":{"isManaged":"false","description":null,"partitionBy":"[\"part\"]","properties":"{}"},"isolationLevel":"Serializable","isBlindAppend":true,"operationMetrics":{},"engineInfo":"Apache-Spark/3.2.2 Delta-Lake/2.0.0","txnId":"ec6fd978-2b74-4482-a908-02528f96cff5"}} diff --git a/plugin/trino-delta-lake/src/test/resources/databricks/uri/_delta_log/00000000000000000001.json b/plugin/trino-delta-lake/src/test/resources/databricks/uri/_delta_log/00000000000000000001.json new file mode 100644 index 000000000000..2a4c5a4b763f --- /dev/null +++ b/plugin/trino-delta-lake/src/test/resources/databricks/uri/_delta_log/00000000000000000001.json @@ -0,0 +1,2 @@ +{"add":{"path":"part=a%253Dequal/part-00000-d6a0dd7d-8416-436d-bb00-245452904a81.c000.snappy.parquet","partitionValues":{"part":"a=equal"},"size":475,"modificationTime":1670672201990,"dataChange":true,"stats":"{\"numRecords\":1,\"minValues\":{\"y\":1},\"maxValues\":{\"y\":1},\"nullCount\":{\"y\":0}}"}} +{"commitInfo":{"timestamp":1670672202024,"operation":"WRITE","operationParameters":{"mode":"Append","partitionBy":"[]"},"readVersion":0,"isolationLevel":"Serializable","isBlindAppend":true,"operationMetrics":{"numFiles":"1","numOutputRows":"1","numOutputBytes":"475"},"engineInfo":"Apache-Spark/3.2.2 Delta-Lake/2.0.0","txnId":"bd99e504-0c28-4661-af9c-854016b09a7e"}} diff --git a/plugin/trino-delta-lake/src/test/resources/databricks/uri/_delta_log/00000000000000000002.json b/plugin/trino-delta-lake/src/test/resources/databricks/uri/_delta_log/00000000000000000002.json new file mode 100644 index 000000000000..f939248ad264 --- /dev/null +++ b/plugin/trino-delta-lake/src/test/resources/databricks/uri/_delta_log/00000000000000000002.json @@ -0,0 +1,2 @@ 
+{"add":{"path":"part=a%253Acolon/part-00000-ab613d48-052b-4de3-916a-ee0d89139446.c000.snappy.parquet","partitionValues":{"part":"a:colon"},"size":475,"modificationTime":1670672202850,"dataChange":true,"stats":"{\"numRecords\":1,\"minValues\":{\"y\":2},\"maxValues\":{\"y\":2},\"nullCount\":{\"y\":0}}"}} +{"commitInfo":{"timestamp":1670672202858,"operation":"WRITE","operationParameters":{"mode":"Append","partitionBy":"[]"},"readVersion":1,"isolationLevel":"Serializable","isBlindAppend":true,"operationMetrics":{"numFiles":"1","numOutputRows":"1","numOutputBytes":"475"},"engineInfo":"Apache-Spark/3.2.2 Delta-Lake/2.0.0","txnId":"fd9e716f-3072-4ecf-ac2b-e52d3a66e3b8"}} diff --git a/plugin/trino-delta-lake/src/test/resources/databricks/uri/_delta_log/00000000000000000003.json b/plugin/trino-delta-lake/src/test/resources/databricks/uri/_delta_log/00000000000000000003.json new file mode 100644 index 000000000000..e0b9d17bd36e --- /dev/null +++ b/plugin/trino-delta-lake/src/test/resources/databricks/uri/_delta_log/00000000000000000003.json @@ -0,0 +1,2 @@ +{"add":{"path":"part=a+plus/part-00000-4aecc0d9-dfbf-4180-91f1-4fd762dbc279.c000.snappy.parquet","partitionValues":{"part":"a+plus"},"size":475,"modificationTime":1670672203670,"dataChange":true,"stats":"{\"numRecords\":1,\"minValues\":{\"y\":3},\"maxValues\":{\"y\":3},\"nullCount\":{\"y\":0}}"}} +{"commitInfo":{"timestamp":1670672203680,"operation":"WRITE","operationParameters":{"mode":"Append","partitionBy":"[]"},"readVersion":2,"isolationLevel":"Serializable","isBlindAppend":true,"operationMetrics":{"numFiles":"1","numOutputRows":"1","numOutputBytes":"475"},"engineInfo":"Apache-Spark/3.2.2 Delta-Lake/2.0.0","txnId":"3adbdcc8-1a1d-4457-b168-0292d9d34dac"}} diff --git a/plugin/trino-delta-lake/src/test/resources/databricks/uri/part=a%3Acolon/part-00000-ab613d48-052b-4de3-916a-ee0d89139446.c000.snappy.parquet 
b/plugin/trino-delta-lake/src/test/resources/databricks/uri/part=a%3Acolon/part-00000-ab613d48-052b-4de3-916a-ee0d89139446.c000.snappy.parquet new file mode 100644 index 0000000000000000000000000000000000000000..88a74a3a51a244e6deb548e480d9a01fa543a08a GIT binary patch literal 475 zcmZXR-AV#M6vvN~%3eeeIKwXN#WIj6*x>H?O$6PA7a~F3M8vp`CatdSM<}siNn{*niS1Z%_0VsgKd^YLnK8r0QR1vC3Z(I0k@E{!3HJdshMUNO@&%2WDvhN;zdXCEL@lc zDLmnUS^d`$uL6POt5hc<3SY%sCu*z`W!`RhvR-Q5<8czDe!^>fGSk60=Dqt3U#R)6 zw8redD+m`;REsnS)F{9rPU#x|+sAvG?e<{{W`FN?K98I|p5qD^@tn#iodsj3-PW@1 qcj{d!J1y-uJzx40)$(Oy)YcOzJy&|3-)&Sa+s5aA%K;4WTYms`17yem literal 0 HcmV?d00001 diff --git a/plugin/trino-delta-lake/src/test/resources/databricks/uri/part=a%3Dequal/part-00000-d6a0dd7d-8416-436d-bb00-245452904a81.c000.snappy.parquet b/plugin/trino-delta-lake/src/test/resources/databricks/uri/part=a%3Dequal/part-00000-d6a0dd7d-8416-436d-bb00-245452904a81.c000.snappy.parquet new file mode 100644 index 0000000000000000000000000000000000000000..3c3dab40dd86907a18c7f09975092443cf40634f GIT binary patch literal 475 zcmZXR%}T>S5XU!bEdfOkx=R8%gau0racQ#sMg(utLn(q@L}Z(8YcU^fK9o|5FX6NJ z3?6+VaVoVrxQAu_Gqdymu`{~5aS0+8vB~%6*T<1XFvv2|5jtsNgwTPe!M#1^DK-s# zuZoahlS0{|IvGNMKs(pR=|U#YkE#EpLP;pHBzB>(j3@Cc4^<{7GLy#bnq{knVttloN$zLNVbi$^CMoORuURC+ zZD|GBT~`oCGf+_;1acfe5nd^W{mo-a-RUUN q@Vkwk5Z$)&Tb?icscieAIqs;b5S}YM&+j$sre(qRzrl(|u+|?ev}G&+ literal 0 HcmV?d00001 diff --git a/plugin/trino-delta-lake/src/test/resources/databricks/uri/part=a+plus/part-00000-4aecc0d9-dfbf-4180-91f1-4fd762dbc279.c000.snappy.parquet b/plugin/trino-delta-lake/src/test/resources/databricks/uri/part=a+plus/part-00000-4aecc0d9-dfbf-4180-91f1-4fd762dbc279.c000.snappy.parquet new file mode 100644 index 0000000000000000000000000000000000000000..2291dfae5a98673b36f891324764e4ef997d6dbe GIT binary patch literal 475 zcmZXRO-lkn7{{L-P4*B$;0(L4hh-p9ut9fxCxUL_K_sY4M2zcb((bCeYat?C`Urlt zzDvg*!`wr27@q(0_Wv<6yt%7WpbmBE=hycx=TL&OKuv(N1_6Mak;8j`>~msU=22xK 
zB{svd!%RBB08m$NPm_g;U!N2IMTd$Kp!o}88Eas~c5J6)Iy7%LO(@uyq=XUHXM>V? zm2J%;2IZq|mPJD(MKb`lGiZ@L&}+agq-3%|5qYYXHIJu4Efg|{UmWqK<776EtP?3b z;el28*O4p(f#l1TPDMO_7okqoSf$Fk-|%$3(7eZ!G+y~Buky)E2jhhI9ydHv;l8xS z?4c`|M^jYMDhSjlz#>lRD*xNZdz expectedRows = ImmutableList.of( + row(1, "spark=equal"), + row(2, "spark+plus"), + row(3, "spark space"), + row(10, "trino=equal"), + row(20, "trino+plus"), + row(30, "trino space")); + + assertThat(onDelta().executeQuery("SELECT * FROM default." + tableName)) + .containsOnly(expectedRows); + assertThat(onTrino().executeQuery("SELECT * FROM delta.default." + tableName)) + .containsOnly(expectedRows); + + String deltaFilePath = (String) onDelta().executeQuery("SELECT input_file_name() FROM default." + tableName + " WHERE a_number = 1").getOnlyValue(); + String trinoFilePath = (String) onTrino().executeQuery("SELECT \"$path\" FROM delta.default." + tableName + " WHERE a_number = 1").getOnlyValue(); + // File paths returned by the input_file_name function are URI encoded https://github.com/delta-io/delta/issues/1517 while the $path of Trino is not + assertNotEquals(deltaFilePath, trinoFilePath); + assertEquals(format("s3://%s%s", bucketName, URI.create(deltaFilePath).getPath()), trinoFilePath); + + assertThat(onTrino().executeQuery("SELECT * FROM delta.default." + tableName + " WHERE \"$path\" = '" + trinoFilePath + "'")) + .containsOnly(row(1, "spark=equal")); + } + finally { + onDelta().executeQuery("DROP TABLE default." 
+ tableName); + } + } +} From 35da905ba489f618863a0213a08b1bfa378d9d05 Mon Sep 17 00:00:00 2001 From: Yuya Ebihara Date: Tue, 13 Dec 2022 11:37:50 +0900 Subject: [PATCH 15/24] Remove unused method from AstBuilder --- .../main/java/io/trino/sql/parser/AstBuilder.java | 14 -------------- 1 file changed, 14 deletions(-) diff --git a/core/trino-parser/src/main/java/io/trino/sql/parser/AstBuilder.java b/core/trino-parser/src/main/java/io/trino/sql/parser/AstBuilder.java index 0723b8fc4af7..35947ac4a3ea 100644 --- a/core/trino-parser/src/main/java/io/trino/sql/parser/AstBuilder.java +++ b/core/trino-parser/src/main/java/io/trino/sql/parser/AstBuilder.java @@ -3745,18 +3745,4 @@ private static QueryPeriod.RangeType getRangeType(Token token) } throw new IllegalArgumentException("Unsupported query period range type: " + token.getText()); } - - private static Trim.Specification toTrimSpecification(String functionName) - { - requireNonNull(functionName, "functionName is null"); - switch (functionName) { - case "trim": - return Trim.Specification.BOTH; - case "ltrim": - return Trim.Specification.LEADING; - case "rtrim": - return Trim.Specification.TRAILING; - } - throw new IllegalArgumentException("Unsupported trim specification: " + functionName); - } } From ce7b57f952ac560ac07192819fc71d7fdeaf4f12 Mon Sep 17 00:00:00 2001 From: Yuya Ebihara Date: Sat, 19 Nov 2022 06:53:59 +0900 Subject: [PATCH 16/24] Refactor BigQuery connector - Change ColumnHandle to BigQueryColumnHandle in BigQueryTableHandle - Extract buildColumnHandles in BigQueryClient --- .../io/trino/plugin/bigquery/BigQueryClient.java | 15 +++++++++------ .../trino/plugin/bigquery/BigQueryMetadata.java | 9 +++++---- .../io/trino/plugin/bigquery/BigQuerySplit.java | 13 ++++++------- .../plugin/bigquery/BigQuerySplitManager.java | 10 +++++----- .../plugin/bigquery/BigQueryTableHandle.java | 8 ++++---- .../java/io/trino/plugin/bigquery/ptf/Query.java | 4 +--- 6 files changed, 30 insertions(+), 29 deletions(-) diff 
--git a/plugin/trino-bigquery/src/main/java/io/trino/plugin/bigquery/BigQueryClient.java b/plugin/trino-bigquery/src/main/java/io/trino/plugin/bigquery/BigQueryClient.java index ef1dfef9faf3..d009b88f3505 100644 --- a/plugin/trino-bigquery/src/main/java/io/trino/plugin/bigquery/BigQueryClient.java +++ b/plugin/trino-bigquery/src/main/java/io/trino/plugin/bigquery/BigQueryClient.java @@ -40,6 +40,7 @@ import io.airlift.units.Duration; import io.trino.collect.cache.EvictableCacheBuilder; import io.trino.spi.TrinoException; +import io.trino.spi.connector.SchemaTableName; import io.trino.spi.connector.TableNotFoundException; import java.util.Collections; @@ -348,19 +349,21 @@ private static String fullTableName(TableId remoteTableId) public List getColumns(BigQueryTableHandle tableHandle) { if (tableHandle.getProjectedColumns().isPresent()) { - return tableHandle.getProjectedColumns().get().stream() - .map(column -> (BigQueryColumnHandle) column) - .collect(toImmutableList()); + return tableHandle.getProjectedColumns().get(); } checkArgument(tableHandle.isNamedRelation(), "Cannot get columns for %s", tableHandle); TableInfo tableInfo = getTable(tableHandle.asPlainTable().getRemoteTableName().toTableId()) .orElseThrow(() -> new TableNotFoundException(tableHandle.asPlainTable().getSchemaTableName())); + return buildColumnHandles(tableInfo); + } + + public static List buildColumnHandles(TableInfo tableInfo) + { Schema schema = tableInfo.getDefinition().getSchema(); if (schema == null) { - throw new TableNotFoundException( - tableHandle.asPlainTable().getSchemaTableName(), - format("Table '%s' has no schema", tableHandle.asPlainTable().getSchemaTableName())); + SchemaTableName schemaTableName = new SchemaTableName(tableInfo.getTableId().getDataset(), tableInfo.getTableId().getTable()); + throw new TableNotFoundException(schemaTableName, format("Table '%s' has no schema", schemaTableName)); } return schema.getFields() .stream() diff --git 
a/plugin/trino-bigquery/src/main/java/io/trino/plugin/bigquery/BigQueryMetadata.java b/plugin/trino-bigquery/src/main/java/io/trino/plugin/bigquery/BigQueryMetadata.java index 5708ad5e3632..0228112e25e0 100644 --- a/plugin/trino-bigquery/src/main/java/io/trino/plugin/bigquery/BigQueryMetadata.java +++ b/plugin/trino-bigquery/src/main/java/io/trino/plugin/bigquery/BigQueryMetadata.java @@ -326,7 +326,7 @@ public Map getColumnHandles(ConnectorSession session, Conn BigQueryTableHandle table = (BigQueryTableHandle) tableHandle; if (table.getProjectedColumns().isPresent()) { return table.getProjectedColumns().get().stream() - .collect(toImmutableMap(columnHandle -> ((BigQueryColumnHandle) columnHandle).getName(), identity())); + .collect(toImmutableMap(BigQueryColumnHandle::getName, identity())); } checkArgument(table.isNamedRelation(), "Cannot get columns for %s", tableHandle); @@ -567,11 +567,12 @@ public Optional> applyProjecti return Optional.empty(); } - ImmutableList.Builder projectedColumns = ImmutableList.builder(); + ImmutableList.Builder projectedColumns = ImmutableList.builder(); ImmutableList.Builder assignmentList = ImmutableList.builder(); assignments.forEach((name, column) -> { - projectedColumns.add(column); - assignmentList.add(new Assignment(name, column, ((BigQueryColumnHandle) column).getTrinoType())); + BigQueryColumnHandle columnHandle = (BigQueryColumnHandle) column; + projectedColumns.add(columnHandle); + assignmentList.add(new Assignment(name, column, columnHandle.getTrinoType())); }); bigQueryTableHandle = bigQueryTableHandle.withProjectedColumns(projectedColumns.build()); diff --git a/plugin/trino-bigquery/src/main/java/io/trino/plugin/bigquery/BigQuerySplit.java b/plugin/trino-bigquery/src/main/java/io/trino/plugin/bigquery/BigQuerySplit.java index 6ba9daddae17..7b53e6971ce2 100644 --- a/plugin/trino-bigquery/src/main/java/io/trino/plugin/bigquery/BigQuerySplit.java +++ 
b/plugin/trino-bigquery/src/main/java/io/trino/plugin/bigquery/BigQuerySplit.java @@ -17,7 +17,6 @@ import com.fasterxml.jackson.annotation.JsonProperty; import com.google.common.collect.ImmutableList; import io.trino.spi.HostAddress; -import io.trino.spi.connector.ColumnHandle; import io.trino.spi.connector.ConnectorSplit; import org.openjdk.jol.info.ClassLayout; @@ -43,7 +42,7 @@ public class BigQuerySplit private final Mode mode; private final String streamName; private final String avroSchema; - private final List columns; + private final List columns; private final long emptyRowsToGenerate; private final Optional filter; private final OptionalInt dataSize; @@ -54,7 +53,7 @@ public BigQuerySplit( @JsonProperty("mode") Mode mode, @JsonProperty("streamName") String streamName, @JsonProperty("avroSchema") String avroSchema, - @JsonProperty("columns") List columns, + @JsonProperty("columns") List columns, @JsonProperty("emptyRowsToGenerate") long emptyRowsToGenerate, @JsonProperty("filter") Optional filter, @JsonProperty("dataSize") OptionalInt dataSize) @@ -68,12 +67,12 @@ public BigQuerySplit( this.dataSize = requireNonNull(dataSize, "dataSize is null"); } - static BigQuerySplit forStream(String streamName, String avroSchema, List columns, OptionalInt dataSize) + static BigQuerySplit forStream(String streamName, String avroSchema, List columns, OptionalInt dataSize) { return new BigQuerySplit(STORAGE, streamName, avroSchema, columns, NO_ROWS_TO_GENERATE, Optional.empty(), dataSize); } - static BigQuerySplit forViewStream(List columns, Optional filter) + static BigQuerySplit forViewStream(List columns, Optional filter) { return new BigQuerySplit(QUERY, "", "", columns, NO_ROWS_TO_GENERATE, filter, OptionalInt.empty()); } @@ -102,7 +101,7 @@ public String getAvroSchema() } @JsonProperty - public List getColumns() + public List getColumns() { return columns; } @@ -149,7 +148,7 @@ public long getRetainedSizeInBytes() return INSTANCE_SIZE + estimatedSizeOf(streamName) 
+ estimatedSizeOf(avroSchema) - + estimatedSizeOf(columns, column -> ((BigQueryColumnHandle) column).getRetainedSizeInBytes()); + + estimatedSizeOf(columns, BigQueryColumnHandle::getRetainedSizeInBytes); } @Override diff --git a/plugin/trino-bigquery/src/main/java/io/trino/plugin/bigquery/BigQuerySplitManager.java b/plugin/trino-bigquery/src/main/java/io/trino/plugin/bigquery/BigQuerySplitManager.java index dbd10de79003..f15ddc9d9854 100644 --- a/plugin/trino-bigquery/src/main/java/io/trino/plugin/bigquery/BigQuerySplitManager.java +++ b/plugin/trino-bigquery/src/main/java/io/trino/plugin/bigquery/BigQuerySplitManager.java @@ -103,7 +103,7 @@ public ConnectorSplitSource getSplits( Optional filter = BigQueryFilterQueryBuilder.buildFilter(tableConstraint); if (!bigQueryTableHandle.isNamedRelation()) { - List columns = bigQueryTableHandle.getProjectedColumns().orElse(ImmutableList.of()); + List columns = bigQueryTableHandle.getProjectedColumns().orElse(ImmutableList.of()); return new FixedSplitSource(ImmutableList.of(BigQuerySplit.forViewStream(columns, filter))); } @@ -114,17 +114,17 @@ public ConnectorSplitSource getSplits( return new FixedSplitSource(splits); } - private static boolean emptyProjectionIsRequired(Optional> projectedColumns) + private static boolean emptyProjectionIsRequired(Optional> projectedColumns) { return projectedColumns.isPresent() && projectedColumns.get().isEmpty(); } - private List readFromBigQuery(ConnectorSession session, TableDefinition.Type type, TableId remoteTableId, Optional> projectedColumns, int actualParallelism, Optional filter) + private List readFromBigQuery(ConnectorSession session, TableDefinition.Type type, TableId remoteTableId, Optional> projectedColumns, int actualParallelism, Optional filter) { log.debug("readFromBigQuery(tableId=%s, projectedColumns=%s, actualParallelism=%s, filter=[%s])", remoteTableId, projectedColumns, actualParallelism, filter); - List columns = projectedColumns.orElse(ImmutableList.of()); + List 
columns = projectedColumns.orElse(ImmutableList.of()); List projectedColumnsNames = columns.stream() - .map(column -> ((BigQueryColumnHandle) column).getName()) + .map(BigQueryColumnHandle::getName) .collect(toImmutableList()); if (isWildcardTable(type, remoteTableId.getTable())) { diff --git a/plugin/trino-bigquery/src/main/java/io/trino/plugin/bigquery/BigQueryTableHandle.java b/plugin/trino-bigquery/src/main/java/io/trino/plugin/bigquery/BigQueryTableHandle.java index cdd3f1988780..6618f6b4d14b 100644 --- a/plugin/trino-bigquery/src/main/java/io/trino/plugin/bigquery/BigQueryTableHandle.java +++ b/plugin/trino-bigquery/src/main/java/io/trino/plugin/bigquery/BigQueryTableHandle.java @@ -39,13 +39,13 @@ public class BigQueryTableHandle { private final BigQueryRelationHandle relationHandle; private final TupleDomain constraint; - private final Optional> projectedColumns; + private final Optional> projectedColumns; @JsonCreator public BigQueryTableHandle( @JsonProperty("relationHandle") BigQueryRelationHandle relationHandle, @JsonProperty("constraint") TupleDomain constraint, - @JsonProperty("projectedColumns") Optional> projectedColumns) + @JsonProperty("projectedColumns") Optional> projectedColumns) { this.relationHandle = requireNonNull(relationHandle, "relationHandle is null"); this.constraint = requireNonNull(constraint, "constraint is null"); @@ -79,7 +79,7 @@ public TupleDomain getConstraint() } @JsonProperty - public Optional> getProjectedColumns() + public Optional> getProjectedColumns() { return projectedColumns; } @@ -145,7 +145,7 @@ BigQueryTableHandle withConstraint(TupleDomain newConstraint) return new BigQueryTableHandle(relationHandle, newConstraint, projectedColumns); } - public BigQueryTableHandle withProjectedColumns(List newProjectedColumns) + public BigQueryTableHandle withProjectedColumns(List newProjectedColumns) { return new BigQueryTableHandle(relationHandle, constraint, Optional.of(newProjectedColumns)); } diff --git 
a/plugin/trino-bigquery/src/main/java/io/trino/plugin/bigquery/ptf/Query.java b/plugin/trino-bigquery/src/main/java/io/trino/plugin/bigquery/ptf/Query.java index 6b9f8c03aaad..bfd8ef34e5aa 100644 --- a/plugin/trino-bigquery/src/main/java/io/trino/plugin/bigquery/ptf/Query.java +++ b/plugin/trino-bigquery/src/main/java/io/trino/plugin/bigquery/ptf/Query.java @@ -24,7 +24,6 @@ import io.trino.plugin.bigquery.BigQueryColumnHandle; import io.trino.plugin.bigquery.BigQueryQueryRelationHandle; import io.trino.plugin.bigquery.BigQueryTableHandle; -import io.trino.spi.connector.ColumnHandle; import io.trino.spi.connector.ConnectorSession; import io.trino.spi.connector.ConnectorTableHandle; import io.trino.spi.connector.ConnectorTransactionHandle; @@ -110,13 +109,12 @@ public TableFunctionAnalysis analyze(ConnectorSession session, ConnectorTransact } columnsBuilder.add(toColumnHandle(field)); } - List columns = columnsBuilder.build(); Descriptor returnedType = new Descriptor(columnsBuilder.build().stream() .map(column -> new Field(column.getName(), Optional.of(column.getTrinoType()))) .collect(toList())); - QueryHandle handle = new QueryHandle(tableHandle.withProjectedColumns(columns.stream().map(column -> (ColumnHandle) column).collect(toList()))); + QueryHandle handle = new QueryHandle(tableHandle.withProjectedColumns(columnsBuilder.build())); return TableFunctionAnalysis.builder() .returnedType(returnedType) From a660133428f1668f24b01a9f3f7f13b52a4929f2 Mon Sep 17 00:00:00 2001 From: Yuya Ebihara Date: Sat, 19 Nov 2022 09:06:22 +0900 Subject: [PATCH 17/24] Fix projection pushdown when unsupported column exists in BigQuery --- .../plugin/bigquery/BigQueryMetadata.java | 29 ++++++++++--------- .../plugin/bigquery/BigQuerySplitManager.java | 5 +++- 2 files changed, 20 insertions(+), 14 deletions(-) diff --git a/plugin/trino-bigquery/src/main/java/io/trino/plugin/bigquery/BigQueryMetadata.java 
b/plugin/trino-bigquery/src/main/java/io/trino/plugin/bigquery/BigQueryMetadata.java index 0228112e25e0..5a3e273abf38 100644 --- a/plugin/trino-bigquery/src/main/java/io/trino/plugin/bigquery/BigQueryMetadata.java +++ b/plugin/trino-bigquery/src/main/java/io/trino/plugin/bigquery/BigQueryMetadata.java @@ -33,6 +33,7 @@ import io.airlift.log.Logger; import io.airlift.slice.Slice; import io.trino.plugin.bigquery.BigQueryClient.RemoteDatabaseObject; +import io.trino.plugin.bigquery.BigQueryTableHandle.BigQueryPartitionType; import io.trino.plugin.bigquery.ptf.Query.QueryHandle; import io.trino.spi.TrinoException; import io.trino.spi.connector.Assignment; @@ -87,6 +88,7 @@ import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.collect.ImmutableList.toImmutableList; import static com.google.common.collect.ImmutableMap.toImmutableMap; +import static io.trino.plugin.bigquery.BigQueryClient.buildColumnHandles; import static io.trino.plugin.bigquery.BigQueryErrorCode.BIGQUERY_LISTING_DATASET_ERROR; import static io.trino.plugin.bigquery.BigQueryErrorCode.BIGQUERY_UNSUPPORTED_OPERATION; import static io.trino.plugin.bigquery.BigQueryPseudoColumn.PARTITION_DATE; @@ -230,12 +232,20 @@ public ConnectorTableHandle getTableHandle(ConnectorSession session, SchemaTable return null; } + ImmutableList.Builder columns = ImmutableList.builder(); + columns.addAll(buildColumnHandles(tableInfo.get())); + Optional partitionType = getPartitionType(tableInfo.get().getDefinition()); + if (partitionType.isPresent() && partitionType.get() == INGESTION) { + columns.add(PARTITION_DATE.getColumnHandle()); + columns.add(PARTITION_TIME.getColumnHandle()); + } return new BigQueryTableHandle(new BigQueryNamedRelationHandle( schemaTableName, new RemoteTableName(tableInfo.get().getTableId()), tableInfo.get().getDefinition().getType().toString(), - getPartitionType(tableInfo.get().getDefinition()), - Optional.ofNullable(tableInfo.get().getDescription()))); + 
partitionType, + Optional.ofNullable(tableInfo.get().getDescription()))) + .withProjectedColumns(columns.build()); } private ConnectorTableHandle getTableHandleIgnoringConflicts(ConnectorSession session, SchemaTableName schemaTableName) @@ -269,17 +279,10 @@ public ConnectorTableMetadata getTableMetadata(ConnectorSession session, Connect log.debug("getTableMetadata(session=%s, tableHandle=%s)", session, tableHandle); BigQueryTableHandle handle = ((BigQueryTableHandle) tableHandle); - ImmutableList.Builder columnMetadata = ImmutableList.builder(); - for (BigQueryColumnHandle column : client.getColumns(handle)) { - columnMetadata.add(column.getColumnMetadata()); - } - if (handle.isNamedRelation()) { - if (handle.asPlainTable().getPartitionType().isPresent() && handle.asPlainTable().getPartitionType().get() == INGESTION) { - columnMetadata.add(PARTITION_DATE.getColumnMetadata()); - columnMetadata.add(PARTITION_TIME.getColumnMetadata()); - } - } - return new ConnectorTableMetadata(getSchemaTableName(handle), columnMetadata.build(), ImmutableMap.of(), getTableComment(handle)); + List columns = client.getColumns(handle).stream() + .map(BigQueryColumnHandle::getColumnMetadata) + .collect(toImmutableList()); + return new ConnectorTableMetadata(getSchemaTableName(handle), columns, ImmutableMap.of(), getTableComment(handle)); } @Override diff --git a/plugin/trino-bigquery/src/main/java/io/trino/plugin/bigquery/BigQuerySplitManager.java b/plugin/trino-bigquery/src/main/java/io/trino/plugin/bigquery/BigQuerySplitManager.java index f15ddc9d9854..9b78d0727cf1 100644 --- a/plugin/trino-bigquery/src/main/java/io/trino/plugin/bigquery/BigQuerySplitManager.java +++ b/plugin/trino-bigquery/src/main/java/io/trino/plugin/bigquery/BigQuerySplitManager.java @@ -47,6 +47,7 @@ import static com.google.cloud.bigquery.TableDefinition.Type.MATERIALIZED_VIEW; import static com.google.cloud.bigquery.TableDefinition.Type.TABLE; import static com.google.cloud.bigquery.TableDefinition.Type.VIEW; 
+import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.collect.ImmutableList.toImmutableList; import static io.trino.plugin.bigquery.BigQueryErrorCode.BIGQUERY_FAILED_TO_EXECUTE_QUERY; import static io.trino.plugin.bigquery.BigQuerySessionProperties.createDisposition; @@ -121,8 +122,10 @@ private static boolean emptyProjectionIsRequired(Optional readFromBigQuery(ConnectorSession session, TableDefinition.Type type, TableId remoteTableId, Optional> projectedColumns, int actualParallelism, Optional filter) { + checkArgument(projectedColumns.isPresent() && projectedColumns.get().size() > 0, "Projected column is empty"); + log.debug("readFromBigQuery(tableId=%s, projectedColumns=%s, actualParallelism=%s, filter=[%s])", remoteTableId, projectedColumns, actualParallelism, filter); - List columns = projectedColumns.orElse(ImmutableList.of()); + List columns = projectedColumns.get(); List projectedColumnsNames = columns.stream() .map(BigQueryColumnHandle::getName) .collect(toImmutableList()); From cd8eac64e3e427be45fc5df5b5a61110245479ec Mon Sep 17 00:00:00 2001 From: Fokko Driesprong Date: Thu, 17 Nov 2022 13:21:55 +0100 Subject: [PATCH 18/24] Update Iceberg to 1.1.0 --- plugin/trino-iceberg/pom.xml | 2 + .../plugin/iceberg/IcebergAvroPageSource.java | 2 + .../plugin/iceberg/IcebergSplitSource.java | 2 + .../TestIcebergMetadataFileOperations.java | 76 +++++++++---------- .../trino/plugin/iceberg/TestIcebergV2.java | 7 +- pom.xml | 2 +- testing/trino-faulttolerant-tests/pom.xml | 2 + testing/trino-tests/pom.xml | 2 + 8 files changed, 54 insertions(+), 41 deletions(-) diff --git a/plugin/trino-iceberg/pom.xml b/plugin/trino-iceberg/pom.xml index 825187bc5c81..48db13121363 100644 --- a/plugin/trino-iceberg/pom.xml +++ b/plugin/trino-iceberg/pom.xml @@ -406,6 +406,8 @@ about.html iceberg-build.properties mozilla/public-suffix-list.txt + + google/protobuf/.*\.proto$ diff --git 
a/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/IcebergAvroPageSource.java b/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/IcebergAvroPageSource.java index 3add0c0ceef6..32f721035158 100644 --- a/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/IcebergAvroPageSource.java +++ b/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/IcebergAvroPageSource.java @@ -91,6 +91,8 @@ public IcebergAvroPageSource( .collect(toImmutableMap(Types.NestedField::name, Types.NestedField::type)); pageBuilder = new PageBuilder(columnTypes); recordIterator = avroReader.iterator(); + // TODO: Remove when NPE check has been released: https://github.com/trinodb/trino/issues/15372 + isFinished(); } private boolean isIndexColumn(int column) diff --git a/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/IcebergSplitSource.java b/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/IcebergSplitSource.java index 028278d43152..5302bf325e67 100644 --- a/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/IcebergSplitSource.java +++ b/plugin/trino-iceberg/src/main/java/io/trino/plugin/iceberg/IcebergSplitSource.java @@ -194,6 +194,8 @@ public CompletableFuture getNextBatch(int maxSize) closer.register(fileScanTaskIterable); this.fileScanTaskIterator = fileScanTaskIterable.iterator(); closer.register(fileScanTaskIterator); + // TODO: Remove when NPE check has been released: https://github.com/trinodb/trino/issues/15372 + isFinished(); } TupleDomain dynamicFilterPredicate = dynamicFilter.getCurrentPredicate() diff --git a/plugin/trino-iceberg/src/test/java/io/trino/plugin/iceberg/TestIcebergMetadataFileOperations.java b/plugin/trino-iceberg/src/test/java/io/trino/plugin/iceberg/TestIcebergMetadataFileOperations.java index baa84cde6423..ed5d84fe5cdd 100644 --- a/plugin/trino-iceberg/src/test/java/io/trino/plugin/iceberg/TestIcebergMetadataFileOperations.java +++ 
b/plugin/trino-iceberg/src/test/java/io/trino/plugin/iceberg/TestIcebergMetadataFileOperations.java @@ -133,8 +133,8 @@ public void testSelect() assertUpdate("CREATE TABLE test_select AS SELECT 1 col_name", 1); assertFileSystemAccesses("SELECT * FROM test_select", ImmutableMultiset.builder() - .addCopies(new FileOperation(MANIFEST, INPUT_FILE_GET_LENGTH), 2) - .addCopies(new FileOperation(MANIFEST, INPUT_FILE_NEW_STREAM), 2) + .addCopies(new FileOperation(MANIFEST, INPUT_FILE_GET_LENGTH), 1) + .addCopies(new FileOperation(MANIFEST, INPUT_FILE_NEW_STREAM), 1) .addCopies(new FileOperation(METADATA_JSON, INPUT_FILE_NEW_STREAM), 1) .addCopies(new FileOperation(SNAPSHOT, INPUT_FILE_GET_LENGTH), 1) .addCopies(new FileOperation(SNAPSHOT, INPUT_FILE_NEW_STREAM), 1) @@ -160,26 +160,26 @@ public void testSelectFromVersionedTable() assertFileSystemAccesses("SELECT * FROM " + tableName + " FOR VERSION AS OF " + v2SnapshotId, ImmutableMultiset.builder() .addCopies(new FileOperation(SNAPSHOT, INPUT_FILE_GET_LENGTH), 1) - .addCopies(new FileOperation(MANIFEST, INPUT_FILE_GET_LENGTH), 2) + .addCopies(new FileOperation(MANIFEST, INPUT_FILE_GET_LENGTH), 1) .addCopies(new FileOperation(METADATA_JSON, INPUT_FILE_NEW_STREAM), 1) .addCopies(new FileOperation(SNAPSHOT, INPUT_FILE_NEW_STREAM), 1) - .addCopies(new FileOperation(MANIFEST, INPUT_FILE_NEW_STREAM), 2) + .addCopies(new FileOperation(MANIFEST, INPUT_FILE_NEW_STREAM), 1) .build()); assertFileSystemAccesses("SELECT * FROM " + tableName + " FOR VERSION AS OF " + v3SnapshotId, ImmutableMultiset.builder() .addCopies(new FileOperation(SNAPSHOT, INPUT_FILE_GET_LENGTH), 1) - .addCopies(new FileOperation(MANIFEST, INPUT_FILE_GET_LENGTH), 4) + .addCopies(new FileOperation(MANIFEST, INPUT_FILE_GET_LENGTH), 2) .addCopies(new FileOperation(METADATA_JSON, INPUT_FILE_NEW_STREAM), 1) .addCopies(new FileOperation(SNAPSHOT, INPUT_FILE_NEW_STREAM), 1) - .addCopies(new FileOperation(MANIFEST, INPUT_FILE_NEW_STREAM), 4) + .addCopies(new 
FileOperation(MANIFEST, INPUT_FILE_NEW_STREAM), 2) .build()); assertFileSystemAccesses("SELECT * FROM " + tableName, ImmutableMultiset.builder() .addCopies(new FileOperation(SNAPSHOT, INPUT_FILE_GET_LENGTH), 1) - .addCopies(new FileOperation(MANIFEST, INPUT_FILE_GET_LENGTH), 4) + .addCopies(new FileOperation(MANIFEST, INPUT_FILE_GET_LENGTH), 2) .addCopies(new FileOperation(METADATA_JSON, INPUT_FILE_NEW_STREAM), 1) .addCopies(new FileOperation(SNAPSHOT, INPUT_FILE_NEW_STREAM), 1) - .addCopies(new FileOperation(MANIFEST, INPUT_FILE_NEW_STREAM), 4) + .addCopies(new FileOperation(MANIFEST, INPUT_FILE_NEW_STREAM), 2) .build()); } @@ -203,26 +203,26 @@ public void testSelectFromVersionedTableWithSchemaEvolution() assertFileSystemAccesses("SELECT * FROM " + tableName + " FOR VERSION AS OF " + v2SnapshotId, ImmutableMultiset.builder() .addCopies(new FileOperation(SNAPSHOT, INPUT_FILE_GET_LENGTH), 1) - .addCopies(new FileOperation(MANIFEST, INPUT_FILE_GET_LENGTH), 2) + .addCopies(new FileOperation(MANIFEST, INPUT_FILE_GET_LENGTH), 1) .addCopies(new FileOperation(METADATA_JSON, INPUT_FILE_NEW_STREAM), 1) .addCopies(new FileOperation(SNAPSHOT, INPUT_FILE_NEW_STREAM), 1) - .addCopies(new FileOperation(MANIFEST, INPUT_FILE_NEW_STREAM), 2) + .addCopies(new FileOperation(MANIFEST, INPUT_FILE_NEW_STREAM), 1) .build()); assertFileSystemAccesses("SELECT * FROM " + tableName + " FOR VERSION AS OF " + v3SnapshotId, ImmutableMultiset.builder() .addCopies(new FileOperation(SNAPSHOT, INPUT_FILE_GET_LENGTH), 1) - .addCopies(new FileOperation(MANIFEST, INPUT_FILE_GET_LENGTH), 4) + .addCopies(new FileOperation(MANIFEST, INPUT_FILE_GET_LENGTH), 2) .addCopies(new FileOperation(METADATA_JSON, INPUT_FILE_NEW_STREAM), 1) .addCopies(new FileOperation(SNAPSHOT, INPUT_FILE_NEW_STREAM), 1) - .addCopies(new FileOperation(MANIFEST, INPUT_FILE_NEW_STREAM), 4) + .addCopies(new FileOperation(MANIFEST, INPUT_FILE_NEW_STREAM), 2) .build()); assertFileSystemAccesses("SELECT * FROM " + tableName, 
ImmutableMultiset.builder() .addCopies(new FileOperation(SNAPSHOT, INPUT_FILE_GET_LENGTH), 1) - .addCopies(new FileOperation(MANIFEST, INPUT_FILE_GET_LENGTH), 4) + .addCopies(new FileOperation(MANIFEST, INPUT_FILE_GET_LENGTH), 2) .addCopies(new FileOperation(METADATA_JSON, INPUT_FILE_NEW_STREAM), 1) .addCopies(new FileOperation(SNAPSHOT, INPUT_FILE_NEW_STREAM), 1) - .addCopies(new FileOperation(MANIFEST, INPUT_FILE_NEW_STREAM), 4) + .addCopies(new FileOperation(MANIFEST, INPUT_FILE_NEW_STREAM), 2) .build()); } @@ -232,8 +232,8 @@ public void testSelectWithFilter() assertUpdate("CREATE TABLE test_select_with_filter AS SELECT 1 col_name", 1); assertFileSystemAccesses("SELECT * FROM test_select_with_filter WHERE col_name = 1", ImmutableMultiset.builder() - .addCopies(new FileOperation(MANIFEST, INPUT_FILE_GET_LENGTH), 2) - .addCopies(new FileOperation(MANIFEST, INPUT_FILE_NEW_STREAM), 2) + .addCopies(new FileOperation(MANIFEST, INPUT_FILE_GET_LENGTH), 1) + .addCopies(new FileOperation(MANIFEST, INPUT_FILE_NEW_STREAM), 1) .addCopies(new FileOperation(METADATA_JSON, INPUT_FILE_NEW_STREAM), 1) .addCopies(new FileOperation(SNAPSHOT, INPUT_FILE_GET_LENGTH), 1) .addCopies(new FileOperation(SNAPSHOT, INPUT_FILE_NEW_STREAM), 1) @@ -248,8 +248,8 @@ public void testJoin() assertFileSystemAccesses("SELECT name, age FROM test_join_t1 JOIN test_join_t2 ON test_join_t2.id = test_join_t1.id", ImmutableMultiset.builder() - .addCopies(new FileOperation(MANIFEST, INPUT_FILE_GET_LENGTH), 8) - .addCopies(new FileOperation(MANIFEST, INPUT_FILE_NEW_STREAM), 8) + .addCopies(new FileOperation(MANIFEST, INPUT_FILE_GET_LENGTH), 4) + .addCopies(new FileOperation(MANIFEST, INPUT_FILE_NEW_STREAM), 4) .addCopies(new FileOperation(METADATA_JSON, INPUT_FILE_NEW_STREAM), 2) .addCopies(new FileOperation(SNAPSHOT, INPUT_FILE_GET_LENGTH), 2) .addCopies(new FileOperation(SNAPSHOT, INPUT_FILE_NEW_STREAM), 2) @@ -266,8 +266,8 @@ public void testJoinWithPartitionedTable() assertFileSystemAccesses("SELECT 
count(*) FROM test_join_partitioned_t1 t1 join test_join_partitioned_t2 t2 on t1.a = t2.foo", ImmutableMultiset.builder() - .addCopies(new FileOperation(MANIFEST, INPUT_FILE_GET_LENGTH), 8) - .addCopies(new FileOperation(MANIFEST, INPUT_FILE_NEW_STREAM), 8) + .addCopies(new FileOperation(MANIFEST, INPUT_FILE_GET_LENGTH), 4) + .addCopies(new FileOperation(MANIFEST, INPUT_FILE_NEW_STREAM), 4) .addCopies(new FileOperation(METADATA_JSON, INPUT_FILE_NEW_STREAM), 2) .addCopies(new FileOperation(SNAPSHOT, INPUT_FILE_GET_LENGTH), 2) .addCopies(new FileOperation(SNAPSHOT, INPUT_FILE_NEW_STREAM), 2) @@ -281,8 +281,8 @@ public void testExplainSelect() assertFileSystemAccesses("EXPLAIN SELECT * FROM test_explain", ImmutableMultiset.builder() - .addCopies(new FileOperation(MANIFEST, INPUT_FILE_GET_LENGTH), 2) - .addCopies(new FileOperation(MANIFEST, INPUT_FILE_NEW_STREAM), 2) + .addCopies(new FileOperation(MANIFEST, INPUT_FILE_GET_LENGTH), 1) + .addCopies(new FileOperation(MANIFEST, INPUT_FILE_NEW_STREAM), 1) .addCopies(new FileOperation(METADATA_JSON, INPUT_FILE_NEW_STREAM), 1) .addCopies(new FileOperation(SNAPSHOT, INPUT_FILE_GET_LENGTH), 1) .addCopies(new FileOperation(SNAPSHOT, INPUT_FILE_NEW_STREAM), 1) @@ -296,8 +296,8 @@ public void testShowStatsForTable() assertFileSystemAccesses("SHOW STATS FOR test_show_stats", ImmutableMultiset.builder() - .addCopies(new FileOperation(MANIFEST, INPUT_FILE_GET_LENGTH), 2) - .addCopies(new FileOperation(MANIFEST, INPUT_FILE_NEW_STREAM), 2) + .addCopies(new FileOperation(MANIFEST, INPUT_FILE_GET_LENGTH), 1) + .addCopies(new FileOperation(MANIFEST, INPUT_FILE_NEW_STREAM), 1) .addCopies(new FileOperation(METADATA_JSON, INPUT_FILE_NEW_STREAM), 1) .addCopies(new FileOperation(SNAPSHOT, INPUT_FILE_GET_LENGTH), 1) .addCopies(new FileOperation(SNAPSHOT, INPUT_FILE_NEW_STREAM), 1) @@ -313,8 +313,8 @@ public void testShowStatsForPartitionedTable() assertFileSystemAccesses("SHOW STATS FOR test_show_stats_partitioned", ImmutableMultiset.builder() 
- .addCopies(new FileOperation(MANIFEST, INPUT_FILE_GET_LENGTH), 2) - .addCopies(new FileOperation(MANIFEST, INPUT_FILE_NEW_STREAM), 2) + .addCopies(new FileOperation(MANIFEST, INPUT_FILE_GET_LENGTH), 1) + .addCopies(new FileOperation(MANIFEST, INPUT_FILE_NEW_STREAM), 1) .addCopies(new FileOperation(METADATA_JSON, INPUT_FILE_NEW_STREAM), 1) .addCopies(new FileOperation(SNAPSHOT, INPUT_FILE_GET_LENGTH), 1) .addCopies(new FileOperation(SNAPSHOT, INPUT_FILE_NEW_STREAM), 1) @@ -328,8 +328,8 @@ public void testShowStatsForTableWithFilter() assertFileSystemAccesses("SHOW STATS FOR (SELECT * FROM test_show_stats_with_filter WHERE age >= 2)", ImmutableMultiset.builder() - .addCopies(new FileOperation(MANIFEST, INPUT_FILE_GET_LENGTH), 2) - .addCopies(new FileOperation(MANIFEST, INPUT_FILE_NEW_STREAM), 2) + .addCopies(new FileOperation(MANIFEST, INPUT_FILE_GET_LENGTH), 1) + .addCopies(new FileOperation(MANIFEST, INPUT_FILE_NEW_STREAM), 1) .addCopies(new FileOperation(METADATA_JSON, INPUT_FILE_NEW_STREAM), 1) .addCopies(new FileOperation(SNAPSHOT, INPUT_FILE_GET_LENGTH), 1) .addCopies(new FileOperation(SNAPSHOT, INPUT_FILE_NEW_STREAM), 1) @@ -345,8 +345,8 @@ public void testPredicateWithVarcharCastToDate() assertFileSystemAccesses("SELECT * FROM test_varchar_as_date_predicate", ImmutableMultiset.builder() - .addCopies(new FileOperation(MANIFEST, INPUT_FILE_GET_LENGTH), 4) - .addCopies(new FileOperation(MANIFEST, INPUT_FILE_NEW_STREAM), 4) + .addCopies(new FileOperation(MANIFEST, INPUT_FILE_GET_LENGTH), 2) + .addCopies(new FileOperation(MANIFEST, INPUT_FILE_NEW_STREAM), 2) .addCopies(new FileOperation(METADATA_JSON, INPUT_FILE_NEW_STREAM), 1) .addCopies(new FileOperation(SNAPSHOT, INPUT_FILE_GET_LENGTH), 1) .addCopies(new FileOperation(SNAPSHOT, INPUT_FILE_NEW_STREAM), 1) @@ -355,8 +355,8 @@ public void testPredicateWithVarcharCastToDate() // CAST to date and comparison assertFileSystemAccesses("SELECT * FROM test_varchar_as_date_predicate WHERE CAST(a AS date) >= DATE 
'2005-01-01'", ImmutableMultiset.builder() - .addCopies(new FileOperation(MANIFEST, INPUT_FILE_GET_LENGTH), 2) // fewer than without filter - .addCopies(new FileOperation(MANIFEST, INPUT_FILE_NEW_STREAM), 2) // fewer than without filter + .addCopies(new FileOperation(MANIFEST, INPUT_FILE_GET_LENGTH), 1) // fewer than without filter + .addCopies(new FileOperation(MANIFEST, INPUT_FILE_NEW_STREAM), 1) // fewer than without filter .addCopies(new FileOperation(METADATA_JSON, INPUT_FILE_NEW_STREAM), 1) .addCopies(new FileOperation(SNAPSHOT, INPUT_FILE_GET_LENGTH), 1) .addCopies(new FileOperation(SNAPSHOT, INPUT_FILE_NEW_STREAM), 1) @@ -365,8 +365,8 @@ public void testPredicateWithVarcharCastToDate() // CAST to date and BETWEEN assertFileSystemAccesses("SELECT * FROM test_varchar_as_date_predicate WHERE CAST(a AS date) BETWEEN DATE '2005-01-01' AND DATE '2005-12-31'", ImmutableMultiset.builder() - .addCopies(new FileOperation(MANIFEST, INPUT_FILE_GET_LENGTH), 2) // fewer than without filter - .addCopies(new FileOperation(MANIFEST, INPUT_FILE_NEW_STREAM), 2) // fewer than without filter + .addCopies(new FileOperation(MANIFEST, INPUT_FILE_GET_LENGTH), 1) // fewer than without filter + .addCopies(new FileOperation(MANIFEST, INPUT_FILE_NEW_STREAM), 1) // fewer than without filter .addCopies(new FileOperation(METADATA_JSON, INPUT_FILE_NEW_STREAM), 1) .addCopies(new FileOperation(SNAPSHOT, INPUT_FILE_GET_LENGTH), 1) .addCopies(new FileOperation(SNAPSHOT, INPUT_FILE_NEW_STREAM), 1) @@ -375,8 +375,8 @@ public void testPredicateWithVarcharCastToDate() // conversion to date as a date function assertFileSystemAccesses("SELECT * FROM test_varchar_as_date_predicate WHERE date(a) >= DATE '2005-01-01'", ImmutableMultiset.builder() - .addCopies(new FileOperation(MANIFEST, INPUT_FILE_GET_LENGTH), 2) // fewer than without filter - .addCopies(new FileOperation(MANIFEST, INPUT_FILE_NEW_STREAM), 2) // fewer than without filter + .addCopies(new FileOperation(MANIFEST, INPUT_FILE_GET_LENGTH), 
1) // fewer than without filter + .addCopies(new FileOperation(MANIFEST, INPUT_FILE_NEW_STREAM), 1) // fewer than without filter .addCopies(new FileOperation(METADATA_JSON, INPUT_FILE_NEW_STREAM), 1) .addCopies(new FileOperation(SNAPSHOT, INPUT_FILE_GET_LENGTH), 1) .addCopies(new FileOperation(SNAPSHOT, INPUT_FILE_NEW_STREAM), 1) @@ -404,8 +404,8 @@ public void testRemoveOrphanFiles() .add(new FileOperation(METADATA_JSON, INPUT_FILE_NEW_STREAM)) .addCopies(new FileOperation(SNAPSHOT, INPUT_FILE_GET_LENGTH), 4) .addCopies(new FileOperation(SNAPSHOT, INPUT_FILE_NEW_STREAM), 4) - .addCopies(new FileOperation(MANIFEST, INPUT_FILE_GET_LENGTH), 6) - .addCopies(new FileOperation(MANIFEST, INPUT_FILE_NEW_STREAM), 6) + .addCopies(new FileOperation(MANIFEST, INPUT_FILE_GET_LENGTH), 5) + .addCopies(new FileOperation(MANIFEST, INPUT_FILE_NEW_STREAM), 5) .build()); assertUpdate("DROP TABLE " + tableName); diff --git a/plugin/trino-iceberg/src/test/java/io/trino/plugin/iceberg/TestIcebergV2.java b/plugin/trino-iceberg/src/test/java/io/trino/plugin/iceberg/TestIcebergV2.java index 64f74e0053d0..074ccd439107 100644 --- a/plugin/trino-iceberg/src/test/java/io/trino/plugin/iceberg/TestIcebergV2.java +++ b/plugin/trino-iceberg/src/test/java/io/trino/plugin/iceberg/TestIcebergV2.java @@ -51,6 +51,7 @@ import org.apache.iceberg.data.Record; import org.apache.iceberg.data.parquet.GenericParquetWriter; import org.apache.iceberg.deletes.EqualityDeleteWriter; +import org.apache.iceberg.deletes.PositionDelete; import org.apache.iceberg.deletes.PositionDeleteWriter; import org.apache.iceberg.hadoop.HadoopOutputFile; import org.apache.iceberg.parquet.Parquet; @@ -171,8 +172,10 @@ public void testV2TableWithPositionDelete() .withSpec(PartitionSpec.unpartitioned()) .buildPositionWriter(); + PositionDelete positionDelete = PositionDelete.create(); + PositionDelete record = positionDelete.set(dataFilePath, 0, GenericRecord.create(icebergTable.schema())); try (Closeable ignored = writer) { - 
writer.delete(dataFilePath, 0, GenericRecord.create(icebergTable.schema())); + writer.write(record); } icebergTable.newRowDelta().addDeletes(writer.toDeleteFile()).commit(); @@ -521,7 +524,7 @@ private void writeEqualityDeleteToNationTable(Table icebergTable, Optional5.5.2 4.14.0 7.1.4 - 1.0.0 + 1.1.0 4.7.2 3.21.6 3.2.2 diff --git a/testing/trino-faulttolerant-tests/pom.xml b/testing/trino-faulttolerant-tests/pom.xml index fcba727b5010..95e8c9167785 100644 --- a/testing/trino-faulttolerant-tests/pom.xml +++ b/testing/trino-faulttolerant-tests/pom.xml @@ -465,6 +465,8 @@ about.html iceberg-build.properties mozilla/public-suffix-list.txt + + google/protobuf/.*\.proto$ diff --git a/testing/trino-tests/pom.xml b/testing/trino-tests/pom.xml index 02d6ab48ffaf..5eef67d95db0 100644 --- a/testing/trino-tests/pom.xml +++ b/testing/trino-tests/pom.xml @@ -391,6 +391,8 @@ about.html iceberg-build.properties mozilla/public-suffix-list.txt + + google/protobuf/.*\.proto$ From b9d26a7da490d580390b4ac5cbeee62280232c3e Mon Sep 17 00:00:00 2001 From: Marius Grama Date: Mon, 5 Jul 2021 12:55:22 +0200 Subject: [PATCH 19/24] Document Top-N pushdown --- docs/src/main/sphinx/optimizer/pushdown.rst | 87 +++++++++++++++++++++ 1 file changed, 87 insertions(+) diff --git a/docs/src/main/sphinx/optimizer/pushdown.rst b/docs/src/main/sphinx/optimizer/pushdown.rst index 447281295a3d..4d40b7dca267 100644 --- a/docs/src/main/sphinx/optimizer/pushdown.rst +++ b/docs/src/main/sphinx/optimizer/pushdown.rst @@ -271,3 +271,90 @@ FETCH FIRST N ROWS``. Implementation and support is connector-specific since different data sources support different SQL syntax and processing. + +For example, you can find two queries to learn how to identify Top-N pushdown behavior in the following section. 
+ +First, a concrete example of a Top-N pushdown query on top of a PostgreSQL database:: + + SELECT id, name + FROM postgresql.public.company + ORDER BY id + LIMIT 5; + +You can get the explain plan by prepending the above query with ``EXPLAIN``:: + + EXPLAIN SELECT id, name + FROM postgresql.public.company + ORDER BY id + LIMIT 5; + +.. code-block:: text + + Fragment 0 [SINGLE] + Output layout: [id, name] + Output partitioning: SINGLE [] + Stage Execution Strategy: UNGROUPED_EXECUTION + Output[id, name] + │ Layout: [id:integer, name:varchar] + │ Estimates: {rows: ? (?), cpu: ?, memory: 0B, network: ?} + └─ RemoteSource[1] + Layout: [id:integer, name:varchar] + + Fragment 1 [SOURCE] + Output layout: [id, name] + Output partitioning: SINGLE [] + Stage Execution Strategy: UNGROUPED_EXECUTION + TableScan[postgresql:public.company public.company sortOrder=[id:integer:int4 ASC NULLS LAST] limit=5, grouped = false] + Layout: [id:integer, name:varchar] + Estimates: {rows: ? (?), cpu: ?, memory: 0B, network: 0B} + name := name:varchar:text + id := id:integer:int4 + +Second, an example of a Top-N query on the ``tpch`` connector which does not support +Top-N pushdown functionality:: + + SELECT custkey, name + FROM tpch.sf1.customer + ORDER BY custkey + LIMIT 5; + +The related query plan: + +.. code-block:: text + + Fragment 0 [SINGLE] + Output layout: [custkey, name] + Output partitioning: SINGLE [] + Stage Execution Strategy: UNGROUPED_EXECUTION + Output[custkey, name] + │ Layout: [custkey:bigint, name:varchar(25)] + │ Estimates: {rows: ? (?), cpu: ?, memory: ?, network: ?} + └─ TopN[5 by (custkey ASC NULLS LAST)] + │ Layout: [custkey:bigint, name:varchar(25)] + └─ LocalExchange[SINGLE] () + │ Layout: [custkey:bigint, name:varchar(25)] + │ Estimates: {rows: ? 
(?), cpu: ?, memory: ?, network: ?} + └─ RemoteSource[1] + Layout: [custkey:bigint, name:varchar(25)] + + Fragment 1 [SOURCE] + Output layout: [custkey, name] + Output partitioning: SINGLE [] + Stage Execution Strategy: UNGROUPED_EXECUTION + TopNPartial[5 by (custkey ASC NULLS LAST)] + │ Layout: [custkey:bigint, name:varchar(25)] + └─ TableScan[tpch:customer:sf1.0, grouped = false] + Layout: [custkey:bigint, name:varchar(25)] + Estimates: {rows: 150000 (4.58MB), cpu: 4.58M, memory: 0B, network: 0B} + custkey := tpch:custkey + name := tpch:name + +In the preceding query plan, the Top-N operation ``TopN[5 by (custkey ASC NULLS LAST)]`` +is being applied in the ``Fragment 0`` by Trino and not by the source database. + +Note that, compared to the query executed on top of the ``tpch`` connector, +the explain plan of the query applied on top of the ``postgresql`` connector +is missing the reference to the operation ``TopN[5 by (id ASC NULLS LAST)]`` +in the ``Fragment 0``. +The absence of the ``TopN`` Trino operator in the ``Fragment 0`` from the query plan +demonstrates that the query benefits from the Top-N pushdown optimization. 
From f5b5fb8a4a7756566ac5d16aeb8f2b7120dd411a Mon Sep 17 00:00:00 2001 From: Marius Grama Date: Wed, 14 Dec 2022 06:24:30 +0100 Subject: [PATCH 20/24] Remove unnecessary override for `getTableProperties` method --- .../io/trino/connector/system/SystemTablesMetadata.java | 7 ------- .../src/main/java/io/trino/testing/TestingMetadata.java | 7 ------- .../java/io/trino/plugin/accumulo/AccumuloMetadata.java | 7 ------- .../src/main/java/io/trino/plugin/atop/AtopMetadata.java | 7 ------- .../java/io/trino/plugin/jdbc/DefaultJdbcMetadata.java | 7 ------- .../java/io/trino/plugin/bigquery/BigQueryMetadata.java | 8 -------- .../java/io/trino/plugin/blackhole/BlackHoleMetadata.java | 7 ------- .../java/io/trino/plugin/cassandra/CassandraMetadata.java | 7 ------- .../java/io/trino/plugin/example/ExampleMetadata.java | 7 ------- .../io/trino/plugin/google/sheets/SheetsMetadata.java | 7 ------- .../src/main/java/io/trino/plugin/jmx/JmxMetadata.java | 7 ------- .../main/java/io/trino/plugin/kafka/KafkaMetadata.java | 7 ------- .../java/io/trino/plugin/kinesis/KinesisMetadata.java | 7 ------- .../java/io/trino/plugin/localfile/LocalFileMetadata.java | 7 ------- .../main/java/io/trino/plugin/memory/MemoryMetadata.java | 7 ------- .../main/java/io/trino/plugin/pinot/PinotMetadata.java | 7 ------- .../io/trino/plugin/prometheus/PrometheusMetadata.java | 7 ------- .../main/java/io/trino/plugin/redis/RedisMetadata.java | 7 ------- .../main/java/io/trino/plugin/thrift/ThriftMetadata.java | 7 ------- .../main/java/io/trino/plugin/tpcds/TpcdsMetadata.java | 7 ------- 20 files changed, 141 deletions(-) diff --git a/core/trino-main/src/main/java/io/trino/connector/system/SystemTablesMetadata.java b/core/trino-main/src/main/java/io/trino/connector/system/SystemTablesMetadata.java index 21b75218f657..98228ac46f32 100644 --- a/core/trino-main/src/main/java/io/trino/connector/system/SystemTablesMetadata.java +++ 
b/core/trino-main/src/main/java/io/trino/connector/system/SystemTablesMetadata.java @@ -23,7 +23,6 @@ import io.trino.spi.connector.ConnectorSession; import io.trino.spi.connector.ConnectorTableHandle; import io.trino.spi.connector.ConnectorTableMetadata; -import io.trino.spi.connector.ConnectorTableProperties; import io.trino.spi.connector.Constraint; import io.trino.spi.connector.ConstraintApplicationResult; import io.trino.spi.connector.SchemaTableName; @@ -139,12 +138,6 @@ public Map> listTableColumns(ConnectorSess return builder.buildOrThrow(); } - @Override - public ConnectorTableProperties getTableProperties(ConnectorSession session, ConnectorTableHandle table) - { - return new ConnectorTableProperties(); - } - @Override public Optional> applyFilter(ConnectorSession session, ConnectorTableHandle handle, Constraint constraint) { diff --git a/core/trino-main/src/main/java/io/trino/testing/TestingMetadata.java b/core/trino-main/src/main/java/io/trino/testing/TestingMetadata.java index 4a1d42115cec..587c6203bf2b 100644 --- a/core/trino-main/src/main/java/io/trino/testing/TestingMetadata.java +++ b/core/trino-main/src/main/java/io/trino/testing/TestingMetadata.java @@ -32,7 +32,6 @@ import io.trino.spi.connector.ConnectorTableHandle; import io.trino.spi.connector.ConnectorTableLayout; import io.trino.spi.connector.ConnectorTableMetadata; -import io.trino.spi.connector.ConnectorTableProperties; import io.trino.spi.connector.ConnectorViewDefinition; import io.trino.spi.connector.MaterializedViewFreshness; import io.trino.spi.connector.MaterializedViewNotFoundException; @@ -331,12 +330,6 @@ public void grantTablePrivileges(ConnectorSession session, SchemaTableName table @Override public void revokeTablePrivileges(ConnectorSession session, SchemaTableName tableName, Set privileges, TrinoPrincipal grantee, boolean grantOption) {} - @Override - public ConnectorTableProperties getTableProperties(ConnectorSession session, ConnectorTableHandle table) - { - return new 
ConnectorTableProperties(); - } - public void clear() { views.clear(); diff --git a/plugin/trino-accumulo/src/main/java/io/trino/plugin/accumulo/AccumuloMetadata.java b/plugin/trino-accumulo/src/main/java/io/trino/plugin/accumulo/AccumuloMetadata.java index 15e91744961d..003f8a0c755c 100644 --- a/plugin/trino-accumulo/src/main/java/io/trino/plugin/accumulo/AccumuloMetadata.java +++ b/plugin/trino-accumulo/src/main/java/io/trino/plugin/accumulo/AccumuloMetadata.java @@ -34,7 +34,6 @@ import io.trino.spi.connector.ConnectorTableHandle; import io.trino.spi.connector.ConnectorTableLayout; import io.trino.spi.connector.ConnectorTableMetadata; -import io.trino.spi.connector.ConnectorTableProperties; import io.trino.spi.connector.ConnectorViewDefinition; import io.trino.spi.connector.Constraint; import io.trino.spi.connector.ConstraintApplicationResult; @@ -389,12 +388,6 @@ public Optional> applyFilter(C return Optional.of(new ConstraintApplicationResult<>(handle, constraint.getSummary(), false)); } - @Override - public ConnectorTableProperties getTableProperties(ConnectorSession session, ConnectorTableHandle handle) - { - return new ConnectorTableProperties(); - } - private void checkNoRollback() { checkState(rollbackAction.get() == null, "Cannot begin a new write while in an existing one"); diff --git a/plugin/trino-atop/src/main/java/io/trino/plugin/atop/AtopMetadata.java b/plugin/trino-atop/src/main/java/io/trino/plugin/atop/AtopMetadata.java index 9ea2908f7b05..1a91bcc162c1 100644 --- a/plugin/trino-atop/src/main/java/io/trino/plugin/atop/AtopMetadata.java +++ b/plugin/trino-atop/src/main/java/io/trino/plugin/atop/AtopMetadata.java @@ -23,7 +23,6 @@ import io.trino.spi.connector.ConnectorSession; import io.trino.spi.connector.ConnectorTableHandle; import io.trino.spi.connector.ConnectorTableMetadata; -import io.trino.spi.connector.ConnectorTableProperties; import io.trino.spi.connector.Constraint; import io.trino.spi.connector.ConstraintApplicationResult; import 
io.trino.spi.connector.SchemaTableName; @@ -149,12 +148,6 @@ public ColumnMetadata getColumnMetadata(ConnectorSession session, ConnectorTable throw new ColumnNotFoundException(tableName, columnName); } - @Override - public ConnectorTableProperties getTableProperties(ConnectorSession session, ConnectorTableHandle table) - { - return new ConnectorTableProperties(); - } - @Override public Optional> applyFilter(ConnectorSession session, ConnectorTableHandle table, Constraint constraint) { diff --git a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/DefaultJdbcMetadata.java b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/DefaultJdbcMetadata.java index 677edecb4cea..7f7cd97ccc3e 100644 --- a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/DefaultJdbcMetadata.java +++ b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/DefaultJdbcMetadata.java @@ -33,7 +33,6 @@ import io.trino.spi.connector.ConnectorTableHandle; import io.trino.spi.connector.ConnectorTableLayout; import io.trino.spi.connector.ConnectorTableMetadata; -import io.trino.spi.connector.ConnectorTableProperties; import io.trino.spi.connector.ConnectorTableSchema; import io.trino.spi.connector.Constraint; import io.trino.spi.connector.ConstraintApplicationResult; @@ -625,12 +624,6 @@ public Optional applyTableScanRedirect(Conne return jdbcClient.getTableScanRedirection(session, tableHandle); } - @Override - public ConnectorTableProperties getTableProperties(ConnectorSession session, ConnectorTableHandle table) - { - return new ConnectorTableProperties(); - } - @Override public ConnectorTableSchema getTableSchema(ConnectorSession session, ConnectorTableHandle table) { diff --git a/plugin/trino-bigquery/src/main/java/io/trino/plugin/bigquery/BigQueryMetadata.java b/plugin/trino-bigquery/src/main/java/io/trino/plugin/bigquery/BigQueryMetadata.java index 5a3e273abf38..6d9251c49751 100644 --- a/plugin/trino-bigquery/src/main/java/io/trino/plugin/bigquery/BigQueryMetadata.java 
+++ b/plugin/trino-bigquery/src/main/java/io/trino/plugin/bigquery/BigQueryMetadata.java @@ -48,7 +48,6 @@ import io.trino.spi.connector.ConnectorTableHandle; import io.trino.spi.connector.ConnectorTableLayout; import io.trino.spi.connector.ConnectorTableMetadata; -import io.trino.spi.connector.ConnectorTableProperties; import io.trino.spi.connector.ConnectorTableSchema; import io.trino.spi.connector.ConnectorTransactionHandle; import io.trino.spi.connector.Constraint; @@ -374,13 +373,6 @@ public Map> listTableColumns(ConnectorSess return columns.buildOrThrow(); } - @Override - public ConnectorTableProperties getTableProperties(ConnectorSession session, ConnectorTableHandle table) - { - log.debug("getTableProperties(session=%s, prefix=%s)", session, table); - return new ConnectorTableProperties(); - } - @Override public void createSchema(ConnectorSession session, String schemaName, Map properties, TrinoPrincipal owner) { diff --git a/plugin/trino-blackhole/src/main/java/io/trino/plugin/blackhole/BlackHoleMetadata.java b/plugin/trino-blackhole/src/main/java/io/trino/plugin/blackhole/BlackHoleMetadata.java index 2c63744aab9a..6d611dac602a 100644 --- a/plugin/trino-blackhole/src/main/java/io/trino/plugin/blackhole/BlackHoleMetadata.java +++ b/plugin/trino-blackhole/src/main/java/io/trino/plugin/blackhole/BlackHoleMetadata.java @@ -32,7 +32,6 @@ import io.trino.spi.connector.ConnectorTableHandle; import io.trino.spi.connector.ConnectorTableLayout; import io.trino.spi.connector.ConnectorTableMetadata; -import io.trino.spi.connector.ConnectorTableProperties; import io.trino.spi.connector.ConnectorViewDefinition; import io.trino.spi.connector.RetryMode; import io.trino.spi.connector.RowChangeParadigm; @@ -387,12 +386,6 @@ public Optional getView(ConnectorSession session, Schem return Optional.ofNullable(views.get(viewName)); } - @Override - public ConnectorTableProperties getTableProperties(ConnectorSession session, ConnectorTableHandle table) - { - return new 
ConnectorTableProperties(); - } - private void checkSchemaExists(String schemaName) { if (!schemas.contains(schemaName)) { diff --git a/plugin/trino-cassandra/src/main/java/io/trino/plugin/cassandra/CassandraMetadata.java b/plugin/trino-cassandra/src/main/java/io/trino/plugin/cassandra/CassandraMetadata.java index 9b40fd8fbced..913040043100 100644 --- a/plugin/trino-cassandra/src/main/java/io/trino/plugin/cassandra/CassandraMetadata.java +++ b/plugin/trino-cassandra/src/main/java/io/trino/plugin/cassandra/CassandraMetadata.java @@ -30,7 +30,6 @@ import io.trino.spi.connector.ConnectorTableHandle; import io.trino.spi.connector.ConnectorTableLayout; import io.trino.spi.connector.ConnectorTableMetadata; -import io.trino.spi.connector.ConnectorTableProperties; import io.trino.spi.connector.Constraint; import io.trino.spi.connector.ConstraintApplicationResult; import io.trino.spi.connector.NotFoundException; @@ -251,12 +250,6 @@ public Optional> applyFilter(C false)); } - @Override - public ConnectorTableProperties getTableProperties(ConnectorSession session, ConnectorTableHandle table) - { - return new ConnectorTableProperties(); - } - @Override public void createTable(ConnectorSession session, ConnectorTableMetadata tableMetadata, boolean ignoreExisting) { diff --git a/plugin/trino-example-http/src/main/java/io/trino/plugin/example/ExampleMetadata.java b/plugin/trino-example-http/src/main/java/io/trino/plugin/example/ExampleMetadata.java index f1d705957d94..b964390904d2 100644 --- a/plugin/trino-example-http/src/main/java/io/trino/plugin/example/ExampleMetadata.java +++ b/plugin/trino-example-http/src/main/java/io/trino/plugin/example/ExampleMetadata.java @@ -22,7 +22,6 @@ import io.trino.spi.connector.ConnectorSession; import io.trino.spi.connector.ConnectorTableHandle; import io.trino.spi.connector.ConnectorTableMetadata; -import io.trino.spi.connector.ConnectorTableProperties; import io.trino.spi.connector.SchemaTableName; import 
io.trino.spi.connector.SchemaTablePrefix; import io.trino.spi.connector.TableNotFoundException; @@ -155,10 +154,4 @@ public ColumnMetadata getColumnMetadata(ConnectorSession session, ConnectorTable { return ((ExampleColumnHandle) columnHandle).getColumnMetadata(); } - - @Override - public ConnectorTableProperties getTableProperties(ConnectorSession session, ConnectorTableHandle table) - { - return new ConnectorTableProperties(); - } } diff --git a/plugin/trino-google-sheets/src/main/java/io/trino/plugin/google/sheets/SheetsMetadata.java b/plugin/trino-google-sheets/src/main/java/io/trino/plugin/google/sheets/SheetsMetadata.java index efac03397d77..763070d3a053 100644 --- a/plugin/trino-google-sheets/src/main/java/io/trino/plugin/google/sheets/SheetsMetadata.java +++ b/plugin/trino-google-sheets/src/main/java/io/trino/plugin/google/sheets/SheetsMetadata.java @@ -22,7 +22,6 @@ import io.trino.spi.connector.ConnectorSession; import io.trino.spi.connector.ConnectorTableHandle; import io.trino.spi.connector.ConnectorTableMetadata; -import io.trino.spi.connector.ConnectorTableProperties; import io.trino.spi.connector.SchemaTableName; import io.trino.spi.connector.SchemaTablePrefix; import io.trino.spi.connector.TableNotFoundException; @@ -146,10 +145,4 @@ public ColumnMetadata getColumnMetadata(ConnectorSession session, ConnectorTable { return ((SheetsColumnHandle) columnHandle).getColumnMetadata(); } - - @Override - public ConnectorTableProperties getTableProperties(ConnectorSession session, ConnectorTableHandle table) - { - return new ConnectorTableProperties(); - } } diff --git a/plugin/trino-jmx/src/main/java/io/trino/plugin/jmx/JmxMetadata.java b/plugin/trino-jmx/src/main/java/io/trino/plugin/jmx/JmxMetadata.java index 9029216786c1..3c2c3950ec38 100644 --- a/plugin/trino-jmx/src/main/java/io/trino/plugin/jmx/JmxMetadata.java +++ b/plugin/trino-jmx/src/main/java/io/trino/plugin/jmx/JmxMetadata.java @@ -27,7 +27,6 @@ import io.trino.spi.connector.ConnectorSession; 
import io.trino.spi.connector.ConnectorTableHandle; import io.trino.spi.connector.ConnectorTableMetadata; -import io.trino.spi.connector.ConnectorTableProperties; import io.trino.spi.connector.Constraint; import io.trino.spi.connector.ConstraintApplicationResult; import io.trino.spi.connector.SchemaTableName; @@ -102,12 +101,6 @@ public JmxTableHandle getTableHandle(ConnectorSession session, SchemaTableName t return getTableHandle(tableName); } - @Override - public ConnectorTableProperties getTableProperties(ConnectorSession session, ConnectorTableHandle table) - { - return new ConnectorTableProperties(); - } - public JmxTableHandle getTableHandle(SchemaTableName tableName) { requireNonNull(tableName, "tableName is null"); diff --git a/plugin/trino-kafka/src/main/java/io/trino/plugin/kafka/KafkaMetadata.java b/plugin/trino-kafka/src/main/java/io/trino/plugin/kafka/KafkaMetadata.java index f0a6a91e3344..dc2a5834c02d 100644 --- a/plugin/trino-kafka/src/main/java/io/trino/plugin/kafka/KafkaMetadata.java +++ b/plugin/trino-kafka/src/main/java/io/trino/plugin/kafka/KafkaMetadata.java @@ -27,7 +27,6 @@ import io.trino.spi.connector.ConnectorSession; import io.trino.spi.connector.ConnectorTableHandle; import io.trino.spi.connector.ConnectorTableMetadata; -import io.trino.spi.connector.ConnectorTableProperties; import io.trino.spi.connector.Constraint; import io.trino.spi.connector.ConstraintApplicationResult; import io.trino.spi.connector.RetryMode; @@ -231,12 +230,6 @@ private ConnectorTableMetadata getTableMetadata(ConnectorSession session, Schema return new ConnectorTableMetadata(schemaTableName, builder.build()); } - @Override - public ConnectorTableProperties getTableProperties(ConnectorSession session, ConnectorTableHandle table) - { - return new ConnectorTableProperties(); - } - @Override public Optional> applyFilter(ConnectorSession session, ConnectorTableHandle table, Constraint constraint) { diff --git 
a/plugin/trino-kinesis/src/main/java/io/trino/plugin/kinesis/KinesisMetadata.java b/plugin/trino-kinesis/src/main/java/io/trino/plugin/kinesis/KinesisMetadata.java index f4bdef956de0..6f29ad35d03d 100644 --- a/plugin/trino-kinesis/src/main/java/io/trino/plugin/kinesis/KinesisMetadata.java +++ b/plugin/trino-kinesis/src/main/java/io/trino/plugin/kinesis/KinesisMetadata.java @@ -23,7 +23,6 @@ import io.trino.spi.connector.ConnectorSession; import io.trino.spi.connector.ConnectorTableHandle; import io.trino.spi.connector.ConnectorTableMetadata; -import io.trino.spi.connector.ConnectorTableProperties; import io.trino.spi.connector.SchemaTableName; import io.trino.spi.connector.SchemaTablePrefix; import io.trino.spi.connector.TableNotFoundException; @@ -85,12 +84,6 @@ public ConnectorTableMetadata getTableMetadata(ConnectorSession connectorSession return getTableMetadata(((KinesisTableHandle) tableHandle).toSchemaTableName()); } - @Override - public ConnectorTableProperties getTableProperties(ConnectorSession session, ConnectorTableHandle table) - { - return new ConnectorTableProperties(); - } - @Override public List listTables(ConnectorSession session, Optional schemaName) { diff --git a/plugin/trino-local-file/src/main/java/io/trino/plugin/localfile/LocalFileMetadata.java b/plugin/trino-local-file/src/main/java/io/trino/plugin/localfile/LocalFileMetadata.java index 7d757e3eb2c2..032fbeae7621 100644 --- a/plugin/trino-local-file/src/main/java/io/trino/plugin/localfile/LocalFileMetadata.java +++ b/plugin/trino-local-file/src/main/java/io/trino/plugin/localfile/LocalFileMetadata.java @@ -21,7 +21,6 @@ import io.trino.spi.connector.ConnectorSession; import io.trino.spi.connector.ConnectorTableHandle; import io.trino.spi.connector.ConnectorTableMetadata; -import io.trino.spi.connector.ConnectorTableProperties; import io.trino.spi.connector.Constraint; import io.trino.spi.connector.ConstraintApplicationResult; import io.trino.spi.connector.SchemaTableName; @@ -133,12 +132,6 
@@ private List listTables(ConnectorSession session, SchemaTablePr return ImmutableList.of(prefix.toSchemaTableName()); } - @Override - public ConnectorTableProperties getTableProperties(ConnectorSession session, ConnectorTableHandle tableHandle) - { - return new ConnectorTableProperties(); - } - @Override public Optional> applyFilter(ConnectorSession session, ConnectorTableHandle table, Constraint constraint) { diff --git a/plugin/trino-memory/src/main/java/io/trino/plugin/memory/MemoryMetadata.java b/plugin/trino-memory/src/main/java/io/trino/plugin/memory/MemoryMetadata.java index ab19f25f75f5..0e167cac5494 100644 --- a/plugin/trino-memory/src/main/java/io/trino/plugin/memory/MemoryMetadata.java +++ b/plugin/trino-memory/src/main/java/io/trino/plugin/memory/MemoryMetadata.java @@ -32,7 +32,6 @@ import io.trino.spi.connector.ConnectorTableHandle; import io.trino.spi.connector.ConnectorTableLayout; import io.trino.spi.connector.ConnectorTableMetadata; -import io.trino.spi.connector.ConnectorTableProperties; import io.trino.spi.connector.ConnectorViewDefinition; import io.trino.spi.connector.LimitApplicationResult; import io.trino.spi.connector.RetryMode; @@ -421,12 +420,6 @@ private void updateRowsOnHosts(long tableId, Collection fragments) tables.put(tableId, new TableInfo(tableId, info.getSchemaName(), info.getTableName(), info.getColumns(), dataFragments, info.getComment())); } - @Override - public ConnectorTableProperties getTableProperties(ConnectorSession session, ConnectorTableHandle table) - { - return new ConnectorTableProperties(); - } - public List getDataFragments(long tableId) { return ImmutableList.copyOf(tables.get(tableId).getDataFragments().values()); diff --git a/plugin/trino-pinot/src/main/java/io/trino/plugin/pinot/PinotMetadata.java b/plugin/trino-pinot/src/main/java/io/trino/plugin/pinot/PinotMetadata.java index 0e4a574aeb87..6710fbfca8e1 100755 --- a/plugin/trino-pinot/src/main/java/io/trino/plugin/pinot/PinotMetadata.java +++ 
b/plugin/trino-pinot/src/main/java/io/trino/plugin/pinot/PinotMetadata.java @@ -43,7 +43,6 @@ import io.trino.spi.connector.ConnectorSession; import io.trino.spi.connector.ConnectorTableHandle; import io.trino.spi.connector.ConnectorTableMetadata; -import io.trino.spi.connector.ConnectorTableProperties; import io.trino.spi.connector.Constraint; import io.trino.spi.connector.ConstraintApplicationResult; import io.trino.spi.connector.LimitApplicationResult; @@ -234,12 +233,6 @@ public Optional getInfo(ConnectorTableHandle table) return Optional.empty(); } - @Override - public ConnectorTableProperties getTableProperties(ConnectorSession session, ConnectorTableHandle table) - { - return new ConnectorTableProperties(); - } - @Override public Optional> applyLimit(ConnectorSession session, ConnectorTableHandle table, long limit) { diff --git a/plugin/trino-prometheus/src/main/java/io/trino/plugin/prometheus/PrometheusMetadata.java b/plugin/trino-prometheus/src/main/java/io/trino/plugin/prometheus/PrometheusMetadata.java index 46cafc5c5d8c..4e5e79f4703e 100644 --- a/plugin/trino-prometheus/src/main/java/io/trino/plugin/prometheus/PrometheusMetadata.java +++ b/plugin/trino-prometheus/src/main/java/io/trino/plugin/prometheus/PrometheusMetadata.java @@ -22,7 +22,6 @@ import io.trino.spi.connector.ConnectorSession; import io.trino.spi.connector.ConnectorTableHandle; import io.trino.spi.connector.ConnectorTableMetadata; -import io.trino.spi.connector.ConnectorTableProperties; import io.trino.spi.connector.Constraint; import io.trino.spi.connector.ConstraintApplicationResult; import io.trino.spi.connector.SchemaTableName; @@ -155,12 +154,6 @@ public ColumnMetadata getColumnMetadata(ConnectorSession session, ConnectorTable return ((PrometheusColumnHandle) columnHandle).getColumnMetadata(); } - @Override - public ConnectorTableProperties getTableProperties(ConnectorSession session, ConnectorTableHandle table) - { - return new ConnectorTableProperties(); - } - @Override public 
Optional> applyFilter(ConnectorSession session, ConnectorTableHandle handle, Constraint constraint) { diff --git a/plugin/trino-redis/src/main/java/io/trino/plugin/redis/RedisMetadata.java b/plugin/trino-redis/src/main/java/io/trino/plugin/redis/RedisMetadata.java index cff52758efa4..e2505e399d5e 100644 --- a/plugin/trino-redis/src/main/java/io/trino/plugin/redis/RedisMetadata.java +++ b/plugin/trino-redis/src/main/java/io/trino/plugin/redis/RedisMetadata.java @@ -25,7 +25,6 @@ import io.trino.spi.connector.ConnectorSession; import io.trino.spi.connector.ConnectorTableHandle; import io.trino.spi.connector.ConnectorTableMetadata; -import io.trino.spi.connector.ConnectorTableProperties; import io.trino.spi.connector.Constraint; import io.trino.spi.connector.ConstraintApplicationResult; import io.trino.spi.connector.SchemaTableName; @@ -276,12 +275,6 @@ public ColumnMetadata getColumnMetadata(ConnectorSession session, ConnectorTable return ((RedisColumnHandle) columnHandle).getColumnMetadata(); } - @Override - public ConnectorTableProperties getTableProperties(ConnectorSession session, ConnectorTableHandle tableHandle) - { - return new ConnectorTableProperties(); - } - @VisibleForTesting Map getDefinedTables() { diff --git a/plugin/trino-thrift/src/main/java/io/trino/plugin/thrift/ThriftMetadata.java b/plugin/trino-thrift/src/main/java/io/trino/plugin/thrift/ThriftMetadata.java index 56f96a5e9cd5..1ab22188272e 100644 --- a/plugin/trino-thrift/src/main/java/io/trino/plugin/thrift/ThriftMetadata.java +++ b/plugin/trino-thrift/src/main/java/io/trino/plugin/thrift/ThriftMetadata.java @@ -36,7 +36,6 @@ import io.trino.spi.connector.ConnectorSession; import io.trino.spi.connector.ConnectorTableHandle; import io.trino.spi.connector.ConnectorTableMetadata; -import io.trino.spi.connector.ConnectorTableProperties; import io.trino.spi.connector.Constraint; import io.trino.spi.connector.ConstraintApplicationResult; import io.trino.spi.connector.ProjectionApplicationResult; @@ 
-164,12 +163,6 @@ public Optional resolveIndex(ConnectorSession session, C return Optional.empty(); } - @Override - public ConnectorTableProperties getTableProperties(ConnectorSession session, ConnectorTableHandle table) - { - return new ConnectorTableProperties(); - } - @Override public Optional> applyFilter(ConnectorSession session, ConnectorTableHandle table, Constraint constraint) { diff --git a/plugin/trino-tpcds/src/main/java/io/trino/plugin/tpcds/TpcdsMetadata.java b/plugin/trino-tpcds/src/main/java/io/trino/plugin/tpcds/TpcdsMetadata.java index a917bad9b8ff..51c5fb6ca083 100644 --- a/plugin/trino-tpcds/src/main/java/io/trino/plugin/tpcds/TpcdsMetadata.java +++ b/plugin/trino-tpcds/src/main/java/io/trino/plugin/tpcds/TpcdsMetadata.java @@ -23,7 +23,6 @@ import io.trino.spi.connector.ConnectorSession; import io.trino.spi.connector.ConnectorTableHandle; import io.trino.spi.connector.ConnectorTableMetadata; -import io.trino.spi.connector.ConnectorTableProperties; import io.trino.spi.connector.SchemaTableName; import io.trino.spi.connector.SchemaTablePrefix; import io.trino.spi.statistics.TableStatistics; @@ -98,12 +97,6 @@ public ConnectorTableHandle getTableHandle(ConnectorSession session, SchemaTable return new TpcdsTableHandle(tableName.getTableName(), scaleFactor); } - @Override - public ConnectorTableProperties getTableProperties(ConnectorSession session, ConnectorTableHandle table) - { - return new ConnectorTableProperties(); - } - @Override public ConnectorTableMetadata getTableMetadata(ConnectorSession session, ConnectorTableHandle tableHandle) { From ec8a8fdecab99fc0a9f83756cd88f095689f9b89 Mon Sep 17 00:00:00 2001 From: kasiafi <30203062+kasiafi@users.noreply.github.com> Date: Thu, 8 Dec 2022 18:53:34 +0100 Subject: [PATCH 21/24] Fix formatting --- .../src/main/java/io/trino/sql/planner/RelationPlanner.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/RelationPlanner.java 
b/core/trino-main/src/main/java/io/trino/sql/planner/RelationPlanner.java index f66a8b3f96cf..e2c31555153a 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/RelationPlanner.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/RelationPlanner.java @@ -347,7 +347,7 @@ protected RelationPlan visitTableFunctionInvocation(TableFunctionInvocation node outputSymbols.addAll(properOutputs); - // process sources in order of argument declarations + // process sources in order of argument declarations for (TableArgumentAnalysis tableArgument : functionAnalysis.getTableArgumentAnalyses()) { RelationPlan sourcePlan = process(tableArgument.getRelation(), context); PlanBuilder sourcePlanBuilder = newPlanBuilder(sourcePlan, analysis, lambdaDeclarationToSymbolMap, session, plannerContext); From 70264322eed19807a6c12f8b1452e2711a22ce77 Mon Sep 17 00:00:00 2001 From: kasiafi <30203062+kasiafi@users.noreply.github.com> Date: Wed, 30 Nov 2022 10:59:54 +0100 Subject: [PATCH 22/24] Add requiredColumns field to TableFunctionAnalysis The new field allows the table function to declare during analysis which columns from the input tables are necessary to execute the function. The required columns can be then validated by the analyzer. This declaration can be also used by the optimizer to prune any input columns that are not used by the table function. 
--- .../trino/spi/ptf/ConnectorTableFunction.java | 2 +- .../trino/spi/ptf/TableFunctionAnalysis.java | 28 +++++++++++++++++-- .../main/sphinx/develop/table-functions.rst | 5 ++-- 3 files changed, 30 insertions(+), 5 deletions(-) diff --git a/core/trino-spi/src/main/java/io/trino/spi/ptf/ConnectorTableFunction.java b/core/trino-spi/src/main/java/io/trino/spi/ptf/ConnectorTableFunction.java index dddd662f86ef..a5a0d5af1946 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/ptf/ConnectorTableFunction.java +++ b/core/trino-spi/src/main/java/io/trino/spi/ptf/ConnectorTableFunction.java @@ -34,7 +34,7 @@ public interface ConnectorTableFunction /** * This method is called by the Analyzer. Its main purposes are to: * 1. Determine the resulting relation type of the Table Function in case when the declared return type is GENERIC_TABLE. - * 2. Declare the dependencies between the input descriptors and the input tables. + * 2. Declare the required columns from the input tables. * 3. Perform function-specific validation and pre-processing of the input arguments. 
* As part of function-specific validation, the Table Function's author might want to: * - check if the descriptors which reference input tables contain a correct number of column references diff --git a/core/trino-spi/src/main/java/io/trino/spi/ptf/TableFunctionAnalysis.java b/core/trino-spi/src/main/java/io/trino/spi/ptf/TableFunctionAnalysis.java index a415c54d61f6..7c6709d70b08 100644 --- a/core/trino-spi/src/main/java/io/trino/spi/ptf/TableFunctionAnalysis.java +++ b/core/trino-spi/src/main/java/io/trino/spi/ptf/TableFunctionAnalysis.java @@ -15,10 +15,14 @@ import io.trino.spi.Experimental; +import java.util.HashMap; +import java.util.List; +import java.util.Map; import java.util.Optional; import static io.trino.spi.ptf.Preconditions.checkArgument; import static java.util.Objects.requireNonNull; +import static java.util.stream.Collectors.toMap; /** * An object of this class is produced by the `analyze()` method of a `ConnectorTableFunction` @@ -28,6 +32,9 @@ * Function, that is, the columns produced by the function, as opposed to the columns passed from the * input tables. The `returnedType` should only be set if the declared returned type is GENERIC_TABLE. *

+ * The `requiredColumns` field is used to inform the Analyzer of the columns from the table arguments + * that are necessary to execute the table function. + *

* The `handle` field can be used to carry all information necessary to execute the table function, * gathered at analysis time. Typically, these are the values of the constant arguments, and results * of pre-processing arguments. @@ -36,12 +43,17 @@ public final class TableFunctionAnalysis { private final Optional returnedType; + + // a map from table argument name to list of column indexes for all columns required from the table argument + private final Map> requiredColumns; private final ConnectorTableFunctionHandle handle; - private TableFunctionAnalysis(Optional returnedType, ConnectorTableFunctionHandle handle) + private TableFunctionAnalysis(Optional returnedType, Map> requiredColumns, ConnectorTableFunctionHandle handle) { this.returnedType = requireNonNull(returnedType, "returnedType is null"); returnedType.ifPresent(descriptor -> checkArgument(descriptor.isTyped(), "field types not specified")); + this.requiredColumns = Map.copyOf(requiredColumns.entrySet().stream() + .collect(toMap(Map.Entry::getKey, entry -> List.copyOf(entry.getValue())))); this.handle = requireNonNull(handle, "handle is null"); } @@ -50,6 +62,11 @@ public Optional getReturnedType() return returnedType; } + public Map> getRequiredColumns() + { + return requiredColumns; + } + public ConnectorTableFunctionHandle getHandle() { return handle; @@ -63,6 +80,7 @@ public static Builder builder() public static final class Builder { private Descriptor returnedType; + private final Map> requiredColumns = new HashMap<>(); private ConnectorTableFunctionHandle handle = new ConnectorTableFunctionHandle() {}; private Builder() {} @@ -73,6 +91,12 @@ public Builder returnedType(Descriptor returnedType) return this; } + public Builder requiredColumns(String tableArgument, List columns) + { + this.requiredColumns.put(tableArgument, columns); + return this; + } + public Builder handle(ConnectorTableFunctionHandle handle) { this.handle = handle; @@ -81,7 +105,7 @@ public Builder 
handle(ConnectorTableFunctionHandle handle) public TableFunctionAnalysis build() { - return new TableFunctionAnalysis(Optional.ofNullable(returnedType), handle); + return new TableFunctionAnalysis(Optional.ofNullable(returnedType), requiredColumns, handle); } } } diff --git a/docs/src/main/sphinx/develop/table-functions.rst b/docs/src/main/sphinx/develop/table-functions.rst index 1fe1eaf9b13a..817e0aa3aaeb 100644 --- a/docs/src/main/sphinx/develop/table-functions.rst +++ b/docs/src/main/sphinx/develop/table-functions.rst @@ -1,3 +1,4 @@ + =============== Table functions =============== @@ -136,8 +137,8 @@ execute the table function invocation: - The returned row type, specified as an optional ``Descriptor``. It should be passed if and only if the table function is declared with the ``GENERIC_TABLE`` returned type. -- Dependencies between descriptor arguments and table arguments. It defaults to - ``EMPTY_MAPPING``. +- Required columns from the table arguments, specified as a map of table + argument names to lists of column indexes. - Any information gathered during analysis that is useful during planning or execution, in the form of a ``ConnectorTableFunctionHandle``. 
``ConnectorTableFunctionHandle`` is a marker interface intended to carry From c3ee15fffaa2e1d53cd20d02a6e934faa7b728c7 Mon Sep 17 00:00:00 2001 From: kasiafi <30203062+kasiafi@users.noreply.github.com> Date: Wed, 30 Nov 2022 13:03:21 +0100 Subject: [PATCH 23/24] Analyze table function's required input columns --- .../java/io/trino/sql/analyzer/Analysis.java | 9 +++ .../trino/sql/analyzer/StatementAnalyzer.java | 33 ++++++++++- .../io/trino/sql/planner/RelationPlanner.java | 18 ++---- .../UnaliasSymbolReferences.java | 6 +- .../sql/planner/plan/TableFunctionNode.java | 20 +++---- .../sql/planner/planprinter/PlanPrinter.java | 3 + .../sanity/ValidateDependenciesChecker.java | 6 +- .../connector/TestingTableFunctions.java | 58 +++++++++++++++++-- .../io/trino/sql/analyzer/TestAnalyzer.java | 33 ++++++++++- 9 files changed, 145 insertions(+), 41 deletions(-) diff --git a/core/trino-main/src/main/java/io/trino/sql/analyzer/Analysis.java b/core/trino-main/src/main/java/io/trino/sql/analyzer/Analysis.java index d5f2b90e1a43..2b83247f7ad8 100644 --- a/core/trino-main/src/main/java/io/trino/sql/analyzer/Analysis.java +++ b/core/trino-main/src/main/java/io/trino/sql/analyzer/Analysis.java @@ -2222,6 +2222,7 @@ public static class TableFunctionInvocationAnalysis private final String functionName; private final Map arguments; private final List tableArgumentAnalyses; + private final Map> requiredColumns; private final List> copartitioningLists; private final int properColumnsCount; private final ConnectorTableFunctionHandle connectorTableFunctionHandle; @@ -2232,6 +2233,7 @@ public TableFunctionInvocationAnalysis( String functionName, Map arguments, List tableArgumentAnalyses, + Map> requiredColumns, List> copartitioningLists, int properColumnsCount, ConnectorTableFunctionHandle connectorTableFunctionHandle, @@ -2241,6 +2243,8 @@ public TableFunctionInvocationAnalysis( this.functionName = requireNonNull(functionName, "functionName is null"); this.arguments = 
ImmutableMap.copyOf(arguments); this.tableArgumentAnalyses = ImmutableList.copyOf(tableArgumentAnalyses); + this.requiredColumns = requiredColumns.entrySet().stream() + .collect(toImmutableMap(Map.Entry::getKey, entry -> ImmutableList.copyOf(entry.getValue()))); this.copartitioningLists = ImmutableList.copyOf(copartitioningLists); this.properColumnsCount = properColumnsCount; this.connectorTableFunctionHandle = requireNonNull(connectorTableFunctionHandle, "connectorTableFunctionHandle is null"); @@ -2267,6 +2271,11 @@ public List getTableArgumentAnalyses() return tableArgumentAnalyses; } + public Map> getRequiredColumns() + { + return requiredColumns; + } + public List> getCopartitioningLists() { return copartitioningLists; diff --git a/core/trino-main/src/main/java/io/trino/sql/analyzer/StatementAnalyzer.java b/core/trino-main/src/main/java/io/trino/sql/analyzer/StatementAnalyzer.java index d76da4c290bb..44cd0919a94d 100644 --- a/core/trino-main/src/main/java/io/trino/sql/analyzer/StatementAnalyzer.java +++ b/core/trino-main/src/main/java/io/trino/sql/analyzer/StatementAnalyzer.java @@ -292,6 +292,7 @@ import static io.trino.spi.StandardErrorCode.DUPLICATE_WINDOW_NAME; import static io.trino.spi.StandardErrorCode.EXPRESSION_NOT_CONSTANT; import static io.trino.spi.StandardErrorCode.EXPRESSION_NOT_IN_DISTINCT; +import static io.trino.spi.StandardErrorCode.FUNCTION_IMPLEMENTATION_ERROR; import static io.trino.spi.StandardErrorCode.FUNCTION_NOT_FOUND; import static io.trino.spi.StandardErrorCode.FUNCTION_NOT_WINDOW; import static io.trino.spi.StandardErrorCode.INVALID_ARGUMENTS; @@ -1570,6 +1571,35 @@ else if (returnTypeSpecification == GENERIC_TABLE) { properColumnsDescriptor = ((DescribedTable) returnTypeSpecification).getDescriptor(); } + // validate the required input columns + Map> requiredColumns = functionAnalysis.getRequiredColumns(); + Map tableArgumentsByName = argumentsAnalysis.getTableArgumentAnalyses().stream() + 
.collect(toImmutableMap(TableArgumentAnalysis::getArgumentName, Function.identity())); + Set allInputs = ImmutableSet.copyOf(tableArgumentsByName.keySet()); + requiredColumns.forEach((name, columns) -> { + if (!allInputs.contains(name)) { + throw new TrinoException(FUNCTION_IMPLEMENTATION_ERROR, format("Table function %s specifies required columns from table argument %s which cannot be found", node.getName(), name)); + } + if (columns.isEmpty()) { + throw new TrinoException(FUNCTION_IMPLEMENTATION_ERROR, format("Table function %s specifies empty list of required columns from table argument %s", node.getName(), name)); + } + // the scope is recorded, because table arguments are already analyzed + Scope inputScope = analysis.getScope(tableArgumentsByName.get(name).getRelation()); + columns.stream() + .filter(column -> column < 0 || column >= inputScope.getRelationType().getAllFieldCount()) // hidden columns can be required as well as visible columns + .findFirst() + .ifPresent(column -> { + throw new TrinoException(FUNCTION_IMPLEMENTATION_ERROR, format("Invalid index: %s of required column from table argument %s", column, name)); + }); + }); + Set requiredInputs = ImmutableSet.copyOf(requiredColumns.keySet()); + allInputs.stream() + .filter(input -> !requiredInputs.contains(input)) + .findFirst() + .ifPresent(input -> { + throw new TrinoException(FUNCTION_IMPLEMENTATION_ERROR, format("Table function %s does not specify required input columns from table argument %s", node.getName(), input)); + }); + // The result relation type of a table function consists of: // 1. columns created by the table function, called the proper columns. // 2. 
passed columns from input tables: @@ -1590,8 +1620,6 @@ else if (returnTypeSpecification == GENERIC_TABLE) { .filter(argumentSpecification -> argumentSpecification instanceof TableArgumentSpecification) .map(ArgumentSpecification::getName) .collect(toImmutableList()); - Map tableArgumentsByName = argumentsAnalysis.getTableArgumentAnalyses().stream() - .collect(toImmutableMap(TableArgumentAnalysis::getArgumentName, Function.identity())); // table arguments in order of argument declarations ImmutableList.Builder orderedTableArguments = ImmutableList.builder(); @@ -1616,6 +1644,7 @@ else if (argument.getPartitionBy().isPresent()) { function.getName(), argumentsAnalysis.getPassedArguments(), orderedTableArguments.build(), + functionAnalysis.getRequiredColumns(), copartitioningLists, properColumnsDescriptor == null ? 0 : properColumnsDescriptor.getFields().size(), functionAnalysis.getHandle(), diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/RelationPlanner.java b/core/trino-main/src/main/java/io/trino/sql/planner/RelationPlanner.java index e2c31555153a..2b0e420e27d8 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/RelationPlanner.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/RelationPlanner.java @@ -16,7 +16,6 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableListMultimap; import com.google.common.collect.ImmutableMap; -import com.google.common.collect.ImmutableMultimap; import com.google.common.collect.ImmutableSet; import com.google.common.collect.ListMultimap; import io.trino.Session; @@ -352,18 +351,9 @@ protected RelationPlan visitTableFunctionInvocation(TableFunctionInvocation node RelationPlan sourcePlan = process(tableArgument.getRelation(), context); PlanBuilder sourcePlanBuilder = newPlanBuilder(sourcePlan, analysis, lambdaDeclarationToSymbolMap, session, plannerContext); - // map column names to symbols - // note: hidden columns are included in the mapping. 
They are present both in sourceDescriptor.allFields, and in sourcePlan.fieldMappings - // note: for an aliased relation or a CTE, the field names in the relation type are in the same case as specified in the alias. - // quotes and canonicalization rules are not applied. - ImmutableMultimap.Builder columnMapping = ImmutableMultimap.builder(); - RelationType sourceDescriptor = sourcePlan.getDescriptor(); - for (int i = 0; i < sourceDescriptor.getAllFieldCount(); i++) { - Optional name = sourceDescriptor.getFieldByIndex(i).getName(); - if (name.isPresent()) { - columnMapping.put(name.get(), sourcePlan.getSymbol(i)); - } - } + List requiredColumns = functionAnalysis.getRequiredColumns().get(tableArgument.getArgumentName()).stream() + .map(sourcePlan::getSymbol) + .collect(toImmutableList()); Optional specification = Optional.empty(); @@ -394,10 +384,10 @@ protected RelationPlan visitTableFunctionInvocation(TableFunctionInvocation node sources.add(sourcePlanBuilder.getRoot()); sourceProperties.add(new TableArgumentProperties( tableArgument.getArgumentName(), - columnMapping.build(), tableArgument.isRowSemantics(), tableArgument.isPruneWhenEmpty(), tableArgument.isPassThroughColumns(), + requiredColumns, specification)); // add output symbols passed from the table argument diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/UnaliasSymbolReferences.java b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/UnaliasSymbolReferences.java index ac302ca75719..d08a671b2f46 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/UnaliasSymbolReferences.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/optimizations/UnaliasSymbolReferences.java @@ -16,7 +16,6 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableListMultimap; import com.google.common.collect.ImmutableMap; -import com.google.common.collect.ImmutableMultimap; import com.google.common.collect.ImmutableSet; 
import com.google.common.collect.ListMultimap; import com.google.common.collect.Sets; @@ -338,16 +337,13 @@ public PlanAndMappings visitTableFunction(TableFunctionNode node, UnaliasContext SymbolMapper inputMapper = symbolMapper(new HashMap<>(newSource.getMappings())); TableArgumentProperties properties = node.getTableArgumentProperties().get(i); - ImmutableMultimap.Builder newColumnMapping = ImmutableMultimap.builder(); - properties.getColumnMapping().entries().stream() - .forEach(entry -> newColumnMapping.put(entry.getKey(), inputMapper.map(entry.getValue()))); Optional newSpecification = properties.getSpecification().map(inputMapper::mapAndDistinct); newTableArgumentProperties.add(new TableArgumentProperties( properties.getArgumentName(), - newColumnMapping.build(), properties.isRowSemantics(), properties.isPruneWhenEmpty(), properties.isPassThroughColumns(), + inputMapper.map(properties.getRequiredColumns()), newSpecification)); } diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/plan/TableFunctionNode.java b/core/trino-main/src/main/java/io/trino/sql/planner/plan/TableFunctionNode.java index 924d88960693..3386a2280334 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/plan/TableFunctionNode.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/plan/TableFunctionNode.java @@ -17,8 +17,6 @@ import com.fasterxml.jackson.annotation.JsonProperty; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; -import com.google.common.collect.ImmutableMultimap; -import com.google.common.collect.Multimap; import io.trino.metadata.TableFunctionHandle; import io.trino.spi.ptf.Argument; import io.trino.sql.planner.Symbol; @@ -149,26 +147,26 @@ public PlanNode replaceChildren(List newSources) public static class TableArgumentProperties { private final String argumentName; - private final Multimap columnMapping; private final boolean rowSemantics; private final boolean pruneWhenEmpty; private final boolean 
passThroughColumns; + private final List requiredColumns; private final Optional specification; @JsonCreator public TableArgumentProperties( @JsonProperty("argumentName") String argumentName, - @JsonProperty("columnMapping") Multimap columnMapping, @JsonProperty("rowSemantics") boolean rowSemantics, @JsonProperty("pruneWhenEmpty") boolean pruneWhenEmpty, @JsonProperty("passThroughColumns") boolean passThroughColumns, + @JsonProperty("requiredColumns") List requiredColumns, @JsonProperty("specification") Optional specification) { this.argumentName = requireNonNull(argumentName, "argumentName is null"); - this.columnMapping = ImmutableMultimap.copyOf(columnMapping); this.rowSemantics = rowSemantics; this.pruneWhenEmpty = pruneWhenEmpty; this.passThroughColumns = passThroughColumns; + this.requiredColumns = ImmutableList.copyOf(requiredColumns); this.specification = requireNonNull(specification, "specification is null"); } @@ -178,12 +176,6 @@ public String getArgumentName() return argumentName; } - @JsonProperty - public Multimap getColumnMapping() - { - return columnMapping; - } - @JsonProperty public boolean isRowSemantics() { @@ -202,6 +194,12 @@ public boolean isPassThroughColumns() return passThroughColumns; } + @JsonProperty + public List getRequiredColumns() + { + return requiredColumns; + } + @JsonProperty public Optional getSpecification() { diff --git a/core/trino-main/src/main/java/io/trino/sql/planner/planprinter/PlanPrinter.java b/core/trino-main/src/main/java/io/trino/sql/planner/planprinter/PlanPrinter.java index e03e392f882a..040067cc8588 100644 --- a/core/trino-main/src/main/java/io/trino/sql/planner/planprinter/PlanPrinter.java +++ b/core/trino-main/src/main/java/io/trino/sql/planner/planprinter/PlanPrinter.java @@ -1821,6 +1821,9 @@ private String formatArgument(String argumentName, Argument argument, Map boundSymbols) checkDependencies( inputs, - argumentProperties.getColumnMapping().values(), - "Invalid node. 
Input symbols from source %s (%s) not in source plan output (%s)", + argumentProperties.getRequiredColumns(), + "Invalid node. Required input symbols from source %s (%s) not in source plan output (%s)", argumentProperties.getArgumentName(), - argumentProperties.getColumnMapping().values(), + argumentProperties.getRequiredColumns(), source.getOutputSymbols()); argumentProperties.getSpecification().ifPresent(specification -> { checkDependencies( diff --git a/core/trino-main/src/test/java/io/trino/connector/TestingTableFunctions.java b/core/trino-main/src/test/java/io/trino/connector/TestingTableFunctions.java index b5ba197c029c..4febbe36409a 100644 --- a/core/trino-main/src/test/java/io/trino/connector/TestingTableFunctions.java +++ b/core/trino-main/src/test/java/io/trino/connector/TestingTableFunctions.java @@ -54,6 +54,7 @@ public class TestingTableFunctions .build(); private static final TableFunctionAnalysis NO_DESCRIPTOR_ANALYSIS = TableFunctionAnalysis.builder() .handle(HANDLE) + .requiredColumns("INPUT", ImmutableList.of(0)) .build(); /** @@ -164,7 +165,11 @@ public TableArgumentFunction() @Override public TableFunctionAnalysis analyze(ConnectorSession session, ConnectorTransactionHandle transaction, Map arguments) { - return ANALYSIS; + return TableFunctionAnalysis.builder() + .handle(HANDLE) + .returnedType(new Descriptor(ImmutableList.of(new Descriptor.Field(COLUMN_NAME, Optional.of(BOOLEAN))))) + .requiredColumns("INPUT", ImmutableList.of(0)) + .build(); } } @@ -187,7 +192,11 @@ public TableArgumentRowSemanticsFunction() @Override public TableFunctionAnalysis analyze(ConnectorSession session, ConnectorTransactionHandle transaction, Map arguments) { - return ANALYSIS; + return TableFunctionAnalysis.builder() + .handle(HANDLE) + .returnedType(new Descriptor(ImmutableList.of(new Descriptor.Field(COLUMN_NAME, Optional.of(BOOLEAN))))) + .requiredColumns("INPUT", ImmutableList.of(0)) + .build(); } } @@ -235,7 +244,12 @@ public TwoTableArgumentsFunction() 
@Override public TableFunctionAnalysis analyze(ConnectorSession session, ConnectorTransactionHandle transaction, Map arguments) { - return ANALYSIS; + return TableFunctionAnalysis.builder() + .handle(HANDLE) + .returnedType(new Descriptor(ImmutableList.of(new Descriptor.Field(COLUMN_NAME, Optional.of(BOOLEAN))))) + .requiredColumns("INPUT1", ImmutableList.of(0)) + .requiredColumns("INPUT2", ImmutableList.of(0)) + .build(); } } @@ -278,7 +292,9 @@ public MonomorphicStaticReturnTypeFunction() @Override public TableFunctionAnalysis analyze(ConnectorSession session, ConnectorTransactionHandle transaction, Map arguments) { - return NO_DESCRIPTOR_ANALYSIS; + return TableFunctionAnalysis.builder() + .handle(HANDLE) + .build(); } } @@ -364,7 +380,39 @@ public DifferentArgumentTypesFunction() @Override public TableFunctionAnalysis analyze(ConnectorSession session, ConnectorTransactionHandle transaction, Map arguments) { - return ANALYSIS; + return TableFunctionAnalysis.builder() + .handle(HANDLE) + .returnedType(new Descriptor(ImmutableList.of(new Descriptor.Field(COLUMN_NAME, Optional.of(BOOLEAN))))) + .requiredColumns("INPUT_1", ImmutableList.of(0)) + .requiredColumns("INPUT_2", ImmutableList.of(0)) + .requiredColumns("INPUT_3", ImmutableList.of(0)) + .build(); + } + } + + public static class RequiredColumnsFunction + extends AbstractConnectorTableFunction + { + public RequiredColumnsFunction() + { + super( + SCHEMA_NAME, + "required_columns_function", + ImmutableList.of( + TableArgumentSpecification.builder() + .name("INPUT") + .build()), + GENERIC_TABLE); + } + + @Override + public TableFunctionAnalysis analyze(ConnectorSession session, ConnectorTransactionHandle transaction, Map arguments) + { + return TableFunctionAnalysis.builder() + .handle(HANDLE) + .returnedType(new Descriptor(ImmutableList.of(new Descriptor.Field("column", Optional.of(BOOLEAN))))) + .requiredColumns("INPUT", ImmutableList.of(0, 1)) + .build(); } } diff --git 
a/core/trino-main/src/test/java/io/trino/sql/analyzer/TestAnalyzer.java b/core/trino-main/src/test/java/io/trino/sql/analyzer/TestAnalyzer.java index 40e0b1725512..86f95fc1a775 100644 --- a/core/trino-main/src/test/java/io/trino/sql/analyzer/TestAnalyzer.java +++ b/core/trino-main/src/test/java/io/trino/sql/analyzer/TestAnalyzer.java @@ -30,6 +30,7 @@ import io.trino.connector.TestingTableFunctions.OnlyPassThroughFunction; import io.trino.connector.TestingTableFunctions.PassThroughFunction; import io.trino.connector.TestingTableFunctions.PolymorphicStaticReturnTypeFunction; +import io.trino.connector.TestingTableFunctions.RequiredColumnsFunction; import io.trino.connector.TestingTableFunctions.TableArgumentFunction; import io.trino.connector.TestingTableFunctions.TableArgumentRowSemanticsFunction; import io.trino.connector.TestingTableFunctions.TwoScalarArgumentsFunction; @@ -120,6 +121,7 @@ import static io.trino.spi.StandardErrorCode.EXPRESSION_NOT_CONSTANT; import static io.trino.spi.StandardErrorCode.EXPRESSION_NOT_IN_DISTINCT; import static io.trino.spi.StandardErrorCode.EXPRESSION_NOT_SCALAR; +import static io.trino.spi.StandardErrorCode.FUNCTION_IMPLEMENTATION_ERROR; import static io.trino.spi.StandardErrorCode.FUNCTION_NOT_AGGREGATE; import static io.trino.spi.StandardErrorCode.FUNCTION_NOT_FOUND; import static io.trino.spi.StandardErrorCode.INVALID_ARGUMENTS; @@ -6622,6 +6624,34 @@ public void testTableFunctionAliasing() .hasMessage("line 1:23: Column 'table_alias.a' cannot be resolved"); } + @Test + public void testTableFunctionRequiredColumns() + { + // the function required_column_function specifies columns 0 and 1 from table argument "INPUT" as required. 
+ analyze(""" + SELECT * FROM TABLE(system.required_columns_function( + input => TABLE(t1))) + """); + + analyze(""" + SELECT * FROM TABLE(system.required_columns_function( + input => TABLE(SELECT 1, 2, 3))) + """); + + assertFails(""" + SELECT * FROM TABLE(system.required_columns_function( + input => TABLE(SELECT 1))) + """) + .hasErrorCode(FUNCTION_IMPLEMENTATION_ERROR) + .hasMessage("Invalid index: 1 of required column from table argument INPUT"); + + // table s1.t5 has two columns. The second column is hidden. Table function can require a hidden column. + analyze(""" + SELECT * FROM TABLE(system.required_columns_function( + input => TABLE(s1.t5))) + """); + } + @BeforeClass public void setup() { @@ -7011,7 +7041,8 @@ public ConnectorTransactionHandle getConnectorTransaction(TransactionId transact new OnlyPassThroughFunction(), new MonomorphicStaticReturnTypeFunction(), new PolymorphicStaticReturnTypeFunction(), - new PassThroughFunction()))), + new PassThroughFunction(), + new RequiredColumnsFunction()))), new SessionPropertyManager(), tablePropertyManager, analyzePropertyManager, From bcef8854f0e4eebc80efa4c60037f2e120626d3c Mon Sep 17 00:00:00 2001 From: Daniel Zhi Date: Thu, 14 Oct 2021 14:28:22 -0700 Subject: [PATCH 24/24] Coordinator-driven graceful decommission/recommission of worker nodes Support logic in coordinator to graceful decommission and recommission of workers. 1. Add endpoint to list node and their rich states (to be used by autoscaler). 2. Add endpoint to refresh node with list of nodes to exluced/decommission. 3. Add logic in coordinator to track and exclude nodes to decommission. 4. Add DECOMMISSIONING and DECOMMISSIONED in NodeState 5. Add handling of decommission and recommission request in workers. 
reference: https://github.com/trinodb/trino/issues/9976 --- .../main/java/io/trino/metadata/AllNodes.java | 43 +- .../trino/metadata/DiscoveryNodeManager.java | 132 +++--- .../trino/metadata/InMemoryNodeManager.java | 8 +- .../java/io/trino/metadata/NodeState.java | 2 + .../io/trino/server/CoordinatorModule.java | 9 + .../main/java/io/trino/server/ForNodes.java | 31 ++ .../main/java/io/trino/server/NodeStatus.java | 11 +- .../java/io/trino/server/NodesResource.java | 389 ++++++++++++++++++ .../java/io/trino/server/QueryResource.java | 28 ++ .../io/trino/server/RemoteNodeStatus.java | 134 ++++++ .../src/main/java/io/trino/server/Server.java | 2 +- .../io/trino/server/ServerInfoResource.java | 44 +- .../java/io/trino/server/StatusResource.java | 4 +- ...ndler.java => UpdateNodeStateHandler.java} | 167 ++++++-- ...Module.java => UpdateNodeStateModule.java} | 4 +- .../server/testing/TestingTrinoServer.java | 12 +- .../trino/server/ui/ClusterStatsResource.java | 22 + .../io/trino/tests/TestGracefulShutdown.java | 4 +- 18 files changed, 916 insertions(+), 130 deletions(-) create mode 100644 core/trino-main/src/main/java/io/trino/server/ForNodes.java create mode 100644 core/trino-main/src/main/java/io/trino/server/NodesResource.java create mode 100644 core/trino-main/src/main/java/io/trino/server/RemoteNodeStatus.java rename core/trino-main/src/main/java/io/trino/server/{GracefulShutdownHandler.java => UpdateNodeStateHandler.java} (54%) rename core/trino-main/src/main/java/io/trino/server/{GracefulShutdownModule.java => UpdateNodeStateModule.java} (89%) diff --git a/core/trino-main/src/main/java/io/trino/metadata/AllNodes.java b/core/trino-main/src/main/java/io/trino/metadata/AllNodes.java index cecf16b9c9e7..532af85693dc 100644 --- a/core/trino-main/src/main/java/io/trino/metadata/AllNodes.java +++ b/core/trino-main/src/main/java/io/trino/metadata/AllNodes.java @@ -23,16 +23,32 @@ public class AllNodes { private final Set activeNodes; + private final Set 
decommissionedNodes; + private final Set decommissioningNodes; private final Set inactiveNodes; private final Set shuttingDownNodes; private final Set activeCoordinators; + private final Set aliveNodes; - public AllNodes(Set activeNodes, Set inactiveNodes, Set shuttingDownNodes, Set activeCoordinators) + public AllNodes(Set activeNodes, + Set decommissionedNodes, + Set decommissioningNodes, + Set inactiveNodes, + Set shuttingDownNodes, + Set activeCoordinators) { this.activeNodes = ImmutableSet.copyOf(requireNonNull(activeNodes, "activeNodes is null")); + this.decommissionedNodes = ImmutableSet.copyOf(requireNonNull(decommissionedNodes, "decommissionedNodes is null")); + this.decommissioningNodes = ImmutableSet.copyOf(requireNonNull(decommissioningNodes, "decommissioningNodes is null")); this.inactiveNodes = ImmutableSet.copyOf(requireNonNull(inactiveNodes, "inactiveNodes is null")); this.shuttingDownNodes = ImmutableSet.copyOf(requireNonNull(shuttingDownNodes, "shuttingDownNodes is null")); this.activeCoordinators = ImmutableSet.copyOf(requireNonNull(activeCoordinators, "activeCoordinators is null")); + this.aliveNodes = ImmutableSet.builder() + .addAll(activeNodes) + .addAll(decommissionedNodes) + .addAll(decommissioningNodes) + .addAll(shuttingDownNodes) + .build(); } public Set getActiveNodes() @@ -40,6 +56,16 @@ public Set getActiveNodes() return activeNodes; } + public Set getDecommissionedNodes() + { + return decommissionedNodes; + } + + public Set getDecommissioningNodes() + { + return decommissioningNodes; + } + public Set getInactiveNodes() { return inactiveNodes; @@ -50,6 +76,11 @@ public Set getShuttingDownNodes() return shuttingDownNodes; } + public Set getAliveNodes() + { + return aliveNodes; + } + public Set getActiveCoordinators() { return activeCoordinators; @@ -66,6 +97,8 @@ public boolean equals(Object o) } AllNodes allNodes = (AllNodes) o; return Objects.equals(activeNodes, allNodes.activeNodes) && + Objects.equals(decommissionedNodes, 
allNodes.decommissionedNodes) && + Objects.equals(decommissioningNodes, allNodes.decommissioningNodes) && Objects.equals(inactiveNodes, allNodes.inactiveNodes) && Objects.equals(shuttingDownNodes, allNodes.shuttingDownNodes) && Objects.equals(activeCoordinators, allNodes.activeCoordinators); @@ -74,6 +107,12 @@ public boolean equals(Object o) @Override public int hashCode() { - return Objects.hash(activeNodes, inactiveNodes, shuttingDownNodes, activeCoordinators); + return Objects.hash( + activeNodes, + inactiveNodes, + shuttingDownNodes, + decommissioningNodes, + decommissionedNodes, + activeCoordinators); } } diff --git a/core/trino-main/src/main/java/io/trino/metadata/DiscoveryNodeManager.java b/core/trino-main/src/main/java/io/trino/metadata/DiscoveryNodeManager.java index 394bdb8a6bc2..2a106e28bd9e 100644 --- a/core/trino-main/src/main/java/io/trino/metadata/DiscoveryNodeManager.java +++ b/core/trino-main/src/main/java/io/trino/metadata/DiscoveryNodeManager.java @@ -18,7 +18,6 @@ import com.google.common.collect.ImmutableSet; import com.google.common.collect.ImmutableSetMultimap; import com.google.common.collect.SetMultimap; -import com.google.common.collect.Sets; import com.google.common.collect.Sets.SetView; import io.airlift.discovery.client.ServiceDescriptor; import io.airlift.discovery.client.ServiceSelector; @@ -43,6 +42,7 @@ import java.net.URI; import java.net.URISyntaxException; import java.util.ArrayList; +import java.util.HashSet; import java.util.List; import java.util.Optional; import java.util.Set; @@ -59,6 +59,8 @@ import static io.airlift.http.client.HttpUriBuilder.uriBuilderFrom; import static io.trino.connector.system.GlobalSystemConnector.CATALOG_HANDLE; import static io.trino.metadata.NodeState.ACTIVE; +import static io.trino.metadata.NodeState.DECOMMISSIONED; +import static io.trino.metadata.NodeState.DECOMMISSIONING; import static io.trino.metadata.NodeState.INACTIVE; import static io.trino.metadata.NodeState.SHUTTING_DOWN; import static 
java.util.Locale.ENGLISH; @@ -93,6 +95,9 @@ public final class DiscoveryNodeManager @GuardedBy("this") private Set coordinators; + @GuardedBy("this") + private Set nodesToBeDecommissioned = new HashSet<>(); + @GuardedBy("this") private final List> listeners = new ArrayList<>(); @@ -169,10 +174,7 @@ public void destroy() private void pollWorkers() { AllNodes allNodes = getAllNodes(); - Set aliveNodes = ImmutableSet.builder() - .addAll(allNodes.getActiveNodes()) - .addAll(allNodes.getShuttingDownNodes()) - .build(); + Set aliveNodes = allNodes.getAliveNodes(); ImmutableSet aliveNodeIds = aliveNodes.stream() .map(InternalNode::getNodeIdentifier) @@ -215,9 +217,7 @@ private synchronized void refreshNodesInternal() .filter(service -> !failureDetector.getFailed().contains(service)) .collect(toImmutableSet()); - ImmutableSet.Builder activeNodesBuilder = ImmutableSet.builder(); - ImmutableSet.Builder inactiveNodesBuilder = ImmutableSet.builder(); - ImmutableSet.Builder shuttingDownNodesBuilder = ImmutableSet.builder(); + ImmutableSetMultimap.Builder nodeStateMapBuilder = ImmutableSetMultimap.builder(); ImmutableSet.Builder coordinatorsBuilder = ImmutableSet.builder(); ImmutableSetMultimap.Builder byCatalogHandleBuilder = ImmutableSetMultimap.builder(); @@ -229,55 +229,65 @@ private synchronized void refreshNodesInternal() InternalNode node = new InternalNode(service.getNodeId(), uri, nodeVersion, coordinator); NodeState nodeState = getNodeState(node); - switch (nodeState) { - case ACTIVE: - activeNodesBuilder.add(node); - if (coordinator) { - coordinatorsBuilder.add(node); - } + // nodesToBeDecommissioned is the authoritative list of node to be decommissioned + // from coordinator perspective. Once a worker appears in the list, + // its state become DECOMMISSIONING even if worker has yet confirmed such + // so that no new tasks will be scheduled on it. 
+ if (!coordinator && nodesToBeDecommissioned.contains(node.getNodeIdentifier()) + && nodeState == ACTIVE) { + log.debug("Treat " + node.getNodeIdentifier() + " as DECOMMISSIONING"); + nodeState = DECOMMISSIONING; + } - // record available active nodes organized by catalog handle - String catalogHandleIds = service.getProperties().get("catalogHandleIds"); - if (catalogHandleIds != null) { - catalogHandleIds = catalogHandleIds.toLowerCase(ENGLISH); - for (String catalogHandleId : CATALOG_HANDLE_ID_SPLITTER.split(catalogHandleIds)) { - byCatalogHandleBuilder.put(CatalogHandle.fromId(catalogHandleId), node); - } + // Add node to node state map. + nodeStateMapBuilder.put(nodeState, node); + + if (nodeState == ACTIVE) { + if (coordinator) { + coordinatorsBuilder.add(node); + } + + // record available active nodes organized by catalog handle + String catalogHandleIds = service.getProperties().get("catalogHandleIds"); + if (catalogHandleIds != null) { + catalogHandleIds = catalogHandleIds.toLowerCase(ENGLISH); + for (String catalogHandleId : CATALOG_HANDLE_ID_SPLITTER.split(catalogHandleIds)) { + byCatalogHandleBuilder.put(CatalogHandle.fromId(catalogHandleId), node); } + } - // always add system connector - byCatalogHandleBuilder.put(CATALOG_HANDLE, node); - break; - case INACTIVE: - inactiveNodesBuilder.add(node); - break; - case SHUTTING_DOWN: - shuttingDownNodesBuilder.add(node); - break; - default: - log.error("Unknown state %s for node %s", nodeState, node); + // always add system connector + byCatalogHandleBuilder.put(CATALOG_HANDLE, node); } } } - if (allNodes != null) { + // nodes by catalog handle changes anytime a node adds or removes a catalog (note: this is not part of the listener system) + if (!allCatalogsOnAllNodes) { + activeNodesByCatalogHandle = Optional.of(byCatalogHandleBuilder.build()); + } + + SetMultimap nodeStateMap = nodeStateMapBuilder.build(); + AllNodes currAllNodes = new AllNodes( + nodeStateMap.get(ACTIVE), + 
nodeStateMap.get(DECOMMISSIONED), + nodeStateMap.get(DECOMMISSIONING), + nodeStateMap.get(INACTIVE), + nodeStateMap.get(SHUTTING_DOWN), + coordinatorsBuilder.build()); + + if (this.allNodes != null) { // log node that are no longer active (but not shutting down) - SetView missingNodes = difference(allNodes.getActiveNodes(), Sets.union(activeNodesBuilder.build(), shuttingDownNodesBuilder.build())); + SetView missingNodes = difference(allNodes.getActiveNodes(), currAllNodes.getAliveNodes()); for (InternalNode missingNode : missingNodes) { log.info("Previously active node is missing: %s (last seen at %s)", missingNode.getNodeIdentifier(), missingNode.getHost()); } } - // nodes by catalog handle changes anytime a node adds or removes a catalog (note: this is not part of the listener system) - if (!allCatalogsOnAllNodes) { - activeNodesByCatalogHandle = Optional.of(byCatalogHandleBuilder.build()); - } - - AllNodes allNodes = new AllNodes(activeNodesBuilder.build(), inactiveNodesBuilder.build(), shuttingDownNodesBuilder.build(), coordinatorsBuilder.build()); // only update if all nodes actually changed (note: this does not include the connectors registered with the nodes) - if (!allNodes.equals(this.allNodes)) { + if (!currAllNodes.equals(this.allNodes)) { // assign allNodes to a local variable for use in the callback below - this.allNodes = allNodes; + this.allNodes = currAllNodes; coordinators = coordinatorsBuilder.build(); // notify listeners @@ -289,22 +299,19 @@ private synchronized void refreshNodesInternal() private NodeState getNodeState(InternalNode node) { if (expectedNodeVersion.equals(node.getNodeVersion())) { - if (isNodeShuttingDown(node.getNodeIdentifier())) { - return SHUTTING_DOWN; + String nodeId = node.getNodeIdentifier(); + Optional remoteNodeState = nodeStates.containsKey(nodeId) + ? 
nodeStates.get(nodeId).getNodeState() + : Optional.empty(); + if (remoteNodeState.isPresent()) { + return remoteNodeState.get(); } + // no remote node state return ACTIVE; } return INACTIVE; } - private boolean isNodeShuttingDown(String nodeId) - { - Optional remoteNodeState = nodeStates.containsKey(nodeId) - ? nodeStates.get(nodeId).getNodeState() - : Optional.empty(); - return remoteNodeState.isPresent() && remoteNodeState.get() == SHUTTING_DOWN; - } - @Override public synchronized AllNodes getAllNodes() { @@ -317,6 +324,18 @@ public int getActiveNodeCount() return getAllNodes().getActiveNodes().size(); } + @Managed + public int getDecommissionedNodeCount() + { + return getAllNodes().getDecommissionedNodes().size(); + } + + @Managed + public int getDecommissioningNodeCount() + { + return getAllNodes().getDecommissioningNodes().size(); + } + @Managed public int getInactiveNodeCount() { @@ -335,6 +354,10 @@ public Set getNodes(NodeState state) switch (state) { case ACTIVE: return getAllNodes().getActiveNodes(); + case DECOMMISSIONED: + return getAllNodes().getDecommissionedNodes(); + case DECOMMISSIONING: + return getAllNodes().getDecommissioningNodes(); case INACTIVE: return getAllNodes().getInactiveNodes(); case SHUTTING_DOWN: @@ -407,4 +430,9 @@ private static boolean isCoordinator(ServiceDescriptor service) { return Boolean.parseBoolean(service.getProperties().get("coordinator")); } + + public synchronized void setNodesToExclude(Set nodesToExclude) + { + this.nodesToBeDecommissioned = nodesToExclude; + } } diff --git a/core/trino-main/src/main/java/io/trino/metadata/InMemoryNodeManager.java b/core/trino-main/src/main/java/io/trino/metadata/InMemoryNodeManager.java index a2e17c960d30..751ea9e91134 100644 --- a/core/trino-main/src/main/java/io/trino/metadata/InMemoryNodeManager.java +++ b/core/trino-main/src/main/java/io/trino/metadata/InMemoryNodeManager.java @@ -59,7 +59,11 @@ public Set getNodes(NodeState state) { switch (state) { case ACTIVE: - return 
allNodes; + return getAllNodes().getActiveNodes(); + case DECOMMISSIONED: + return getAllNodes().getDecommissionedNodes(); + case DECOMMISSIONING: + return getAllNodes().getDecommissioningNodes(); case INACTIVE: case SHUTTING_DOWN: return ImmutableSet.of(); @@ -86,6 +90,8 @@ public AllNodes getAllNodes() allNodes, ImmutableSet.of(), ImmutableSet.of(), + ImmutableSet.of(), + ImmutableSet.of(), ImmutableSet.of(CURRENT_NODE)); } diff --git a/core/trino-main/src/main/java/io/trino/metadata/NodeState.java b/core/trino-main/src/main/java/io/trino/metadata/NodeState.java index 4bf511968a31..286ad269268e 100644 --- a/core/trino-main/src/main/java/io/trino/metadata/NodeState.java +++ b/core/trino-main/src/main/java/io/trino/metadata/NodeState.java @@ -16,6 +16,8 @@ public enum NodeState { ACTIVE, + DECOMMISSIONED, + DECOMMISSIONING, INACTIVE, SHUTTING_DOWN } diff --git a/core/trino-main/src/main/java/io/trino/server/CoordinatorModule.java b/core/trino-main/src/main/java/io/trino/server/CoordinatorModule.java index 610d3b09a11c..da50f559de08 100644 --- a/core/trino-main/src/main/java/io/trino/server/CoordinatorModule.java +++ b/core/trino-main/src/main/java/io/trino/server/CoordinatorModule.java @@ -341,6 +341,15 @@ protected void setup(Binder binder) install(new QueryExecutionFactoryModule()); + // nodes and queries for monitoring and auto-scaling + jaxrsBinder(binder).bind(NodesResource.class); + httpClientBinder(binder).bindHttpClient("nodes", ForNodes.class) + .withTracing() + .withConfigDefaults(config -> { + config.setIdleTimeout(new Duration(30, SECONDS)); + config.setRequestTimeout(new Duration(10, SECONDS)); + }); + // cleanup binder.bind(ExecutorCleanup.class).asEagerSingleton(); } diff --git a/core/trino-main/src/main/java/io/trino/server/ForNodes.java b/core/trino-main/src/main/java/io/trino/server/ForNodes.java new file mode 100644 index 000000000000..c6eb6c3a05c2 --- /dev/null +++ b/core/trino-main/src/main/java/io/trino/server/ForNodes.java @@ -0,0 +1,31 @@ 
+/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.trino.server; + +import javax.inject.Qualifier; + +import java.lang.annotation.Retention; +import java.lang.annotation.Target; + +import static java.lang.annotation.ElementType.FIELD; +import static java.lang.annotation.ElementType.METHOD; +import static java.lang.annotation.ElementType.PARAMETER; +import static java.lang.annotation.RetentionPolicy.RUNTIME; + +@Retention(RUNTIME) +@Target({FIELD, PARAMETER, METHOD}) +@Qualifier +public @interface ForNodes +{ +} diff --git a/core/trino-main/src/main/java/io/trino/server/NodeStatus.java b/core/trino-main/src/main/java/io/trino/server/NodeStatus.java index cce607262301..c331ed918e03 100644 --- a/core/trino-main/src/main/java/io/trino/server/NodeStatus.java +++ b/core/trino-main/src/main/java/io/trino/server/NodeStatus.java @@ -37,6 +37,7 @@ public class NodeStatus private final long heapUsed; private final long heapAvailable; private final long nonHeapUsed; + private final long startTimeEpoch; @JsonCreator public NodeStatus( @@ -53,7 +54,8 @@ public NodeStatus( @JsonProperty("systemCpuLoad") double systemCpuLoad, @JsonProperty("heapUsed") long heapUsed, @JsonProperty("heapAvailable") long heapAvailable, - @JsonProperty("nonHeapUsed") long nonHeapUsed) + @JsonProperty("nonHeapUsed") long nonHeapUsed, + @JsonProperty("startTimeEpoch") long startTimeEpoch) { this.nodeId = requireNonNull(nodeId, "nodeId is null"); this.nodeVersion = 
requireNonNull(nodeVersion, "nodeVersion is null"); @@ -69,6 +71,7 @@ public NodeStatus( this.heapUsed = heapUsed; this.heapAvailable = heapAvailable; this.nonHeapUsed = nonHeapUsed; + this.startTimeEpoch = startTimeEpoch; } @JsonProperty @@ -154,4 +157,10 @@ public long getNonHeapUsed() { return nonHeapUsed; } + + @JsonProperty + public long getStartTimeEpoch() + { + return startTimeEpoch; + } } diff --git a/core/trino-main/src/main/java/io/trino/server/NodesResource.java b/core/trino-main/src/main/java/io/trino/server/NodesResource.java new file mode 100644 index 000000000000..bcfe63bcdc49 --- /dev/null +++ b/core/trino-main/src/main/java/io/trino/server/NodesResource.java @@ -0,0 +1,389 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.trino.server; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; +import com.google.common.base.Joiner; +import com.google.common.collect.ImmutableSet; +import com.google.common.util.concurrent.FutureCallback; +import com.google.common.util.concurrent.Futures; +import io.airlift.http.client.BodyGenerator; +import io.airlift.http.client.HttpClient; +import io.airlift.http.client.HttpClient.HttpResponseFuture; +import io.airlift.http.client.Request; +import io.airlift.http.client.StaticBodyGenerator; +import io.airlift.http.client.StatusResponseHandler; +import io.airlift.http.client.StatusResponseHandler.StatusResponse; +import io.airlift.log.Logger; +import io.trino.metadata.AllNodes; +import io.trino.metadata.DiscoveryNodeManager; +import io.trino.metadata.InternalNode; +import io.trino.metadata.NodeState; +import io.trino.server.security.ResourceSecurity; + +import javax.annotation.Nullable; +import javax.annotation.PostConstruct; +import javax.annotation.PreDestroy; +import javax.inject.Inject; +import javax.ws.rs.Consumes; +import javax.ws.rs.GET; +import javax.ws.rs.PUT; +import javax.ws.rs.Path; +import javax.ws.rs.Produces; +import javax.ws.rs.core.Response; + +import java.net.URI; +import java.net.URISyntaxException; +import java.nio.charset.Charset; +import java.util.ArrayList; +import java.util.HashSet; +import java.util.List; +import java.util.Set; +import java.util.TreeMap; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ScheduledExecutorService; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicInteger; + +import static com.google.common.collect.ImmutableSet.toImmutableSet; +import static com.google.common.collect.Sets.difference; +import static com.google.common.util.concurrent.MoreExecutors.directExecutor; +import static io.airlift.concurrent.Threads.threadsNamed; +import static 
io.airlift.http.client.HttpUriBuilder.uriBuilderFrom; +import static io.airlift.http.client.Request.Builder.preparePut; +import static io.trino.metadata.NodeState.ACTIVE; +import static io.trino.metadata.NodeState.DECOMMISSIONED; +import static io.trino.metadata.NodeState.DECOMMISSIONING; +import static io.trino.metadata.NodeState.INACTIVE; +import static io.trino.metadata.NodeState.SHUTTING_DOWN; +import static io.trino.server.security.ResourceSecurity.AccessType.PUBLIC; +import static java.util.Objects.requireNonNull; +import static java.util.concurrent.Executors.newSingleThreadScheduledExecutor; +import static javax.ws.rs.core.HttpHeaders.CONTENT_TYPE; +import static javax.ws.rs.core.MediaType.APPLICATION_JSON; +import static javax.ws.rs.core.MediaType.TEXT_PLAIN; + +// NodesResource expose coordinator endpoints to facilitate the auto-scaling of cluster. +// These endpoints include: +// 1. /v1/nodes --- list of all alive nodes with NodeState and NodeStatus; +// 2. /v1/nodes/refreshnodes --- refresh with list of nodes to exclude (decommission); +@Path("/v1/nodes") +public class NodesResource +{ + private static Logger log = Logger.get(NodesResource.class); + + private final DiscoveryNodeManager nodeManager; + private final HttpClient httpClient; + + // Set of worker nodes to exclude (decommission). + Set nodesToExclude = new HashSet<>(); + + // Executor to periodically and asynchronously poll NodeStatus of all workers. + private final ScheduledExecutorService nodeStatusExecutor; + + // Poll worker status once every 15 seconds, a balance between freshness and cost. + private static final int POLL_NODESTATUS_SEC = 15; + + // Map from NodeId to RemoteNodeStatus. 
+ private final ConcurrentHashMap nodeStatuses = new ConcurrentHashMap<>(); + + private AtomicInteger numRefreshNodes = new AtomicInteger(); + private AtomicInteger numUpdateStateOk = new AtomicInteger(); + private AtomicInteger numUpdateStateFailed = new AtomicInteger(); + + @Inject + public NodesResource(DiscoveryNodeManager nodeManager, @ForNodes HttpClient httpClient) + { + log.info("Construct NodesResource"); + this.nodeManager = requireNonNull(nodeManager, "nodeManager is null"); + this.httpClient = httpClient; + this.nodeStatusExecutor = newSingleThreadScheduledExecutor(threadsNamed("autoscale-executor-%s")); + } + + @PostConstruct + public void startPollingNodeStatus() + { + nodeStatusExecutor.scheduleWithFixedDelay(() -> { + try { + pollWorkers(); + } + catch (Exception e) { + log.error(e, "Error polling state of nodes"); + } + }, 5, POLL_NODESTATUS_SEC, TimeUnit.SECONDS); + pollWorkers(); + } + + // Poll /v1/status of all alive workers. + private void pollWorkers() + { + AllNodes allNodes = nodeManager.getAllNodes(); + + Set aliveNodes = allNodes.getAliveNodes(); + + ImmutableSet aliveNodeIds = aliveNodes.stream() + .map(InternalNode::getNodeIdentifier) + .collect(toImmutableSet()); + + Set deadNodes = difference(nodeStatuses.keySet(), aliveNodeIds).immutableCopy(); + nodeStatuses.keySet().removeAll(deadNodes); + + // Add new nodes + for (InternalNode node : aliveNodes) { + URI statusUri = uriBuilderFrom(node.getInternalUri()).appendPath("/v1/status").build(); + nodeStatuses.putIfAbsent(node.getNodeIdentifier(), new RemoteNodeStatus(httpClient, statusUri)); + } + + // Schedule refresh + nodeStatuses.values().forEach(RemoteNodeStatus::asyncRefresh); + } + + @PreDestroy + public void stop() + { + nodeStatusExecutor.shutdownNow(); + } + + // Gets list of all nodes where each is modeled as NodeInfo. 
+ @ResourceSecurity(PUBLIC) + @GET + @Produces(APPLICATION_JSON) + public List getNodes() + { + final AllNodes nodes = nodeManager.getAllNodes(); + TreeMap asmp = new TreeMap<>(); + addToNodeMap(nodes.getActiveNodes(), ACTIVE, asmp); + addToNodeMap(nodes.getShuttingDownNodes(), SHUTTING_DOWN, asmp); + addToNodeMap(nodes.getInactiveNodes(), INACTIVE, asmp); + addToNodeMap(nodes.getDecommissioningNodes(), DECOMMISSIONING, asmp); + addToNodeMap(nodes.getDecommissionedNodes(), DECOMMISSIONED, asmp); + return new ArrayList<>(asmp.values()); + } + + // NodeInfo is a bundle of (InternalNode, NodeState, NodeStatus) + // where NodeStatus is polled from v1/status of the node. + public static class NodeInfo + { + private final String nodeId; + private final String uri; + private final boolean coordinator; + private final NodeState state; + private final NodeStatus status; + private final long statusTime; + + @JsonCreator + public NodeInfo( + @JsonProperty("nodeId") String nodeId, + @JsonProperty("uri") String uri, + @JsonProperty("coordinator") boolean coordinator, + @JsonProperty("state") NodeState state, + @JsonProperty("status") NodeStatus status, + @JsonProperty("statusTime") long statusTime) + { + this.nodeId = requireNonNull(nodeId, "nodeId is null"); + this.uri = uri; + this.coordinator = coordinator; + this.state = state; + this.status = status; + this.statusTime = statusTime; + } + + @JsonProperty + public String getNodeId() + { + return nodeId; + } + + @JsonProperty + public String getUri() + { + return uri; + } + + @JsonProperty + public boolean getCoordinator() + { + return coordinator; + } + + @JsonProperty + public NodeState getState() + { + return state; + } + + @JsonProperty + public NodeStatus getStatus() + { + return status; + } + + @JsonProperty + public long getStatusTime() + { + return statusTime; + } + } + + // Given an absolute list of nodes to exclude (a.k.a. decommission), which means: + // 1. 
The desired state for worker appear in the exclude list is DECOMMISSIONED; + // 2. The desired state for worker that does not appear in the exclude list is ACTIVE; + // Initiate decommission/recommission actions as appropriate to have all worker nodes + // move toward their desired state. Specifically: + // 1. A worker to exclude will be honored within seconds with no new task dispatch. + // asyncUpdateState will be called for the worker to wait for pending tasks and + // later report as DECOMMISSIONED upon completion. + // 2. A worker that was previously excluded but no longer will qualify within seconds + // for new task dispatch. asyncUpdateState will be called for the worker to be + // back to ACTIVE. + @ResourceSecurity(PUBLIC) + @PUT + @Path("refreshnodes") + @Consumes(APPLICATION_JSON) + @Produces(TEXT_PLAIN) + public Response refreshNodes(List exclude) + { + numRefreshNodes.incrementAndGet(); + log.info(numRefreshNodes.get() + " refreshNodes " + Joiner.on(',').join(exclude)); + + TreeMap asnm = getId2NodeInfoMap(); + // Assume nodesToExclude are comma separated list of nodeIds + Set nodesToExclude = parseNodesToExclude(exclude, asnm.keySet()); + if (!nodesToExclude.equals(this.nodesToExclude)) { + this.nodesToExclude = nodesToExclude; + nodeManager.setNodesToExclude(nodesToExclude); + } + + for (NodeInfo node : asnm.values()) { + if (node.coordinator) { + continue; + } + // Decommission ACTIVE nodes that appear in nodesToExclude. + // Note that for now we update state during each refresh: + // 1. worker handle decommission efficiently if it is already DN or DD state. + // 2. we didn't track whether the previous update was successful + // 3. ensure DECOMMISSIONING state on worker just in case. 
+ if (nodesToExclude.contains(node.nodeId)) { + asyncUpdateState(node, DECOMMISSIONING); + } + + // Recommission DN/DD nodes that do not appear in nodesToExclude + if ((node.state == DECOMMISSIONING || node.state == DECOMMISSIONED) + && !nodesToExclude.contains(node.nodeId)) { + asyncUpdateState(node, ACTIVE); + } + } + + return Response.ok().type(TEXT_PLAIN) + .entity(String.format("refreshNodes [%s] OK", Joiner.on(',').join(exclude))) + .build(); + } + + // Parse given list of node to exclude into a set and log unknown ones. + private static Set parseNodesToExclude(List exclude, Set nodes) + { + ImmutableSet.Builder nodesToExclude = ImmutableSet.builder(); + for (String node : exclude) { + if (!nodes.contains(node)) { + log.info("parseNodesToExclude unknown node " + node); + } + nodesToExclude.add(node); + } + return nodesToExclude.build(); + } + + // Get map from nodeId to AutoScaleNode for all nodes. + private TreeMap getId2NodeInfoMap() + { + final AllNodes nodes = nodeManager.getAllNodes(); + TreeMap nodeMap = new TreeMap<>(); + addToNodeMap(nodes.getActiveNodes(), ACTIVE, nodeMap); + addToNodeMap(nodes.getShuttingDownNodes(), SHUTTING_DOWN, nodeMap); + addToNodeMap(nodes.getInactiveNodes(), INACTIVE, nodeMap); + addToNodeMap(nodes.getDecommissioningNodes(), DECOMMISSIONING, nodeMap); + addToNodeMap(nodes.getDecommissionedNodes(), DECOMMISSIONED, nodeMap); + return nodeMap; + } + + // Add all nodes with a specific NodeState into nmap. + private void addToNodeMap( + Set nodes, NodeState state, TreeMap nmap) + { + for (InternalNode node : nodes) { + String nodeId = node.getNodeIdentifier(); + String uri = node.getInternalUri().toString(); + RemoteNodeStatus rns = nodeStatuses.get(nodeId); + NodeStatus status = rns != null && rns.getNodeStatus().isPresent() + ? rns.getNodeStatus().get() : null; + nmap.put(nodeId, new NodeInfo( + nodeId, uri, node.isCoordinator(), state, status, + rns == null ? 
0 : rns.getLastUpdateTime())); + } + } + + // Asynchronously update state of a specific worker, basically execute HTTP put + // request against /v1/info/state endpoint on the remote worker. + private synchronized void asyncUpdateState(NodeInfo node, NodeState state) + { + log.info(String.format("asyncUpdateState %s %s", node.nodeId, state)); + Request request = getUpdateStateRequest(node, state); + HttpResponseFuture responseFuture = httpClient.executeAsync( + request, StatusResponseHandler.createStatusResponseHandler()); + + Futures.addCallback(responseFuture, new FutureCallback() + { + @Override + public void onSuccess(@Nullable StatusResponse result) + { + numUpdateStateOk.incrementAndGet(); + log.info(String.format("OK async updated %s %s", request.getUri(), state)); + } + + @Override + public void onFailure(Throwable t) + { + numUpdateStateFailed.incrementAndGet(); + log.info(String.format("Error async updated %s %s %s", + request.getUri(), state, t.getMessage())); + } + }, directExecutor()); + } + + private synchronized Request getUpdateStateRequest(NodeInfo node, NodeState state) + { + // http://10.43.31.106:8081 -> http://10.43.31.106:8081/v1/info/state + URI infoStateUri = uriBuilderFrom(getUri(node.getUri())).appendPath("/v1/info/state").build(); + + // Note that the quote in "" is needed as otherwise + // Unrecognized token 'DECOMMISSION': was expecting ('true', 'false' or 'null') + BodyGenerator bodyGenerator = StaticBodyGenerator.createStaticBodyGenerator( + "\"" + state + "\"", Charset.defaultCharset()); + return preparePut() + .setUri(infoStateUri) + .setHeader(CONTENT_TYPE, "application/json") + .setBodyGenerator(bodyGenerator) + .build(); + } + + private static URI getUri(String uri) + { + try { + return new URI(uri); + } + catch (URISyntaxException e) { + throw new RuntimeException(e.getMessage()); + } + } +} diff --git a/core/trino-main/src/main/java/io/trino/server/QueryResource.java 
b/core/trino-main/src/main/java/io/trino/server/QueryResource.java index 77073dda812f..fdaa52526ef1 100644 --- a/core/trino-main/src/main/java/io/trino/server/QueryResource.java +++ b/core/trino-main/src/main/java/io/trino/server/QueryResource.java @@ -31,6 +31,7 @@ import javax.ws.rs.PUT; import javax.ws.rs.Path; import javax.ws.rs.PathParam; +import javax.ws.rs.Produces; import javax.ws.rs.QueryParam; import javax.ws.rs.core.Context; import javax.ws.rs.core.HttpHeaders; @@ -42,13 +43,16 @@ import java.util.NoSuchElementException; import java.util.Optional; +import static com.google.common.collect.ImmutableList.toImmutableList; import static io.trino.connector.system.KillQueryProcedure.createKillQueryException; import static io.trino.connector.system.KillQueryProcedure.createPreemptQueryException; import static io.trino.security.AccessControlUtil.checkCanKillQueryOwnedBy; import static io.trino.security.AccessControlUtil.checkCanViewQueryOwnedBy; import static io.trino.security.AccessControlUtil.filterQueries; import static io.trino.server.security.ResourceSecurity.AccessType.AUTHENTICATED_USER; +import static io.trino.server.security.ResourceSecurity.AccessType.MANAGEMENT_READ; import static java.util.Objects.requireNonNull; +import static javax.ws.rs.core.MediaType.APPLICATION_JSON; /** * Manage queries scheduled on this node @@ -168,4 +172,28 @@ private Response failQuery(QueryId queryId, TrinoException queryException, HttpS return Response.status(Status.GONE).build(); } } + + // Get BasicQueryInfo of all pending and recently ended queries. + // Here recently is defined as ended on or after maxEndAgeSec (default value 0) seconds ago. + @ResourceSecurity(MANAGEMENT_READ) + @GET + @Path("all") + @Produces(APPLICATION_JSON) + public List getAllQueryInfos(@QueryParam("maxEndAgeSec") int maxEndAgeSec) + { + // If maxEndAgeSec is negative, return all queries cached. + // Otherwise, return queries not ended or ended within maxEndAgeSec seconds. 
+ // Specifically if maxEndAgeSec is 0, return all queries not ended. + if (maxEndAgeSec < 0) { + return dispatchManager.getQueries(); + } + else { + long endCutoff = System.currentTimeMillis() - 1000L * maxEndAgeSec; + return dispatchManager.getQueries().stream() + .filter(v -> v.getQueryStats() == null + || v.getQueryStats().getEndTime() == null + || v.getQueryStats().getEndTime().getMillis() >= endCutoff) + .collect(toImmutableList()); + } + } } diff --git a/core/trino-main/src/main/java/io/trino/server/RemoteNodeStatus.java b/core/trino-main/src/main/java/io/trino/server/RemoteNodeStatus.java new file mode 100644 index 000000000000..c2cab4d529d2 --- /dev/null +++ b/core/trino-main/src/main/java/io/trino/server/RemoteNodeStatus.java @@ -0,0 +1,134 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.trino.server; + +import com.google.common.util.concurrent.FutureCallback; +import com.google.common.util.concurrent.Futures; +import io.airlift.http.client.FullJsonResponseHandler.JsonResponse; +import io.airlift.http.client.HttpClient; +import io.airlift.http.client.HttpClient.HttpResponseFuture; +import io.airlift.http.client.Request; +import io.airlift.json.JsonCodec; +import io.airlift.log.Logger; +import io.airlift.units.Duration; + +import javax.annotation.Nullable; +import javax.annotation.concurrent.ThreadSafe; + +import java.net.URI; +import java.util.Optional; +import java.util.concurrent.Future; +import java.util.concurrent.atomic.AtomicLong; +import java.util.concurrent.atomic.AtomicReference; + +import static com.google.common.net.MediaType.JSON_UTF_8; +import static com.google.common.util.concurrent.MoreExecutors.directExecutor; +import static io.airlift.http.client.FullJsonResponseHandler.createFullJsonResponseHandler; +import static io.airlift.http.client.HttpStatus.OK; +import static io.airlift.http.client.Request.Builder.prepareGet; +import static io.airlift.json.JsonCodec.jsonCodec; +import static io.airlift.units.Duration.nanosSince; +import static java.util.Objects.requireNonNull; +import static java.util.concurrent.TimeUnit.SECONDS; +import static javax.ws.rs.core.HttpHeaders.CONTENT_TYPE; + +@ThreadSafe +public class RemoteNodeStatus +{ + private static final Logger log = Logger.get(RemoteNodeStatus.class); + private static final JsonCodec NODE_STATUS_CODEC = jsonCodec(NodeStatus.class); + + private final HttpClient httpClient; + private final URI statusUri; + private final AtomicReference> nodeStatus = new AtomicReference<>(Optional.empty()); + private final AtomicReference> future = new AtomicReference<>(); + private final AtomicLong lastUpdateNanos = new AtomicLong(); + private final AtomicLong lastWarningLogged = new AtomicLong(); + private boolean lastUpdateSuccess = true; + // Last time in epoch the remote status was 
successful obtained + private long lastUpdateTime; + + public RemoteNodeStatus(HttpClient httpClient, URI statusUri) + { + this.httpClient = requireNonNull(httpClient, "httpClient is null"); + this.statusUri = requireNonNull(statusUri, "statusUri is null"); + } + + public Optional getNodeStatus() + { + return nodeStatus.get(); + } + + // Whether the latest refresh was successful. + public boolean isLastUpdateSuccess() + { + return lastUpdateSuccess; + } + + // Gets the last time in epoch the remote status was successfully obtained. + public long getLastUpdateTime() + { + return lastUpdateTime; + } + + public synchronized void asyncRefresh() + { + Duration sinceUpdate = nanosSince(lastUpdateNanos.get()); + if (nanosSince(lastWarningLogged.get()).toMillis() > 1_000 && + sinceUpdate.toMillis() > 10_000 && + future.get() != null) { + log.warn("NodeStatus request to %s has not returned in %s", + statusUri, sinceUpdate.toString(SECONDS)); + lastWarningLogged.set(System.nanoTime()); + } + if (sinceUpdate.toMillis() > 5_000 && future.get() == null) { + Request request = prepareGet() + .setUri(statusUri) + .setHeader(CONTENT_TYPE, JSON_UTF_8.toString()) + .build(); + HttpResponseFuture> responseFuture = httpClient.executeAsync( + request, createFullJsonResponseHandler(NODE_STATUS_CODEC)); + future.compareAndSet(null, responseFuture); + + Futures.addCallback(responseFuture, new FutureCallback<>() + { + @Override + public void onSuccess(@Nullable JsonResponse result) + { + lastUpdateTime = System.currentTimeMillis(); + lastUpdateNanos.set(System.nanoTime()); + future.compareAndSet(responseFuture, null); + if (result != null) { + if (result.hasValue()) { + nodeStatus.set(Optional.ofNullable(result.getValue())); + } + if (result.getStatusCode() != OK.code()) { + log.warn("Error fetching node status from %s returned status %d", + statusUri, result.getStatusCode()); + return; + } + } + } + + @Override + public void onFailure(Throwable t) + { + log.warn("Error fetching node status 
from %s: %s", statusUri, t.getMessage()); + lastUpdateNanos.set(System.nanoTime()); + future.compareAndSet(responseFuture, null); + } + }, directExecutor()); + } + } +} diff --git a/core/trino-main/src/main/java/io/trino/server/Server.java b/core/trino-main/src/main/java/io/trino/server/Server.java index faa9fae6527f..5fee8a12845e 100644 --- a/core/trino-main/src/main/java/io/trino/server/Server.java +++ b/core/trino-main/src/main/java/io/trino/server/Server.java @@ -119,7 +119,7 @@ private void doStart(String trinoVersion) new CatalogManagerModule(), new TransactionManagerModule(), new ServerMainModule(trinoVersion), - new GracefulShutdownModule(), + new UpdateNodeStateModule(), new WarningCollectorModule()); modules.addAll(getAdditionalModules()); diff --git a/core/trino-main/src/main/java/io/trino/server/ServerInfoResource.java b/core/trino-main/src/main/java/io/trino/server/ServerInfoResource.java index 525cf192bf4f..82a939570912 100644 --- a/core/trino-main/src/main/java/io/trino/server/ServerInfoResource.java +++ b/core/trino-main/src/main/java/io/trino/server/ServerInfoResource.java @@ -26,21 +26,18 @@ import javax.ws.rs.Path; import javax.ws.rs.Produces; import javax.ws.rs.WebApplicationException; -import javax.ws.rs.core.MediaType; import javax.ws.rs.core.Response; import java.util.Optional; +import java.util.concurrent.atomic.AtomicBoolean; +import static com.google.common.base.Preconditions.checkState; import static io.airlift.units.Duration.nanosSince; -import static io.trino.metadata.NodeState.ACTIVE; -import static io.trino.metadata.NodeState.SHUTTING_DOWN; import static io.trino.server.security.ResourceSecurity.AccessType.MANAGEMENT_WRITE; import static io.trino.server.security.ResourceSecurity.AccessType.PUBLIC; -import static java.lang.String.format; import static java.util.Objects.requireNonNull; import static javax.ws.rs.core.MediaType.APPLICATION_JSON; import static javax.ws.rs.core.MediaType.TEXT_PLAIN; -import static 
javax.ws.rs.core.Response.Status.BAD_REQUEST; @Path("/v1/info") public class ServerInfoResource @@ -48,17 +45,20 @@ public class ServerInfoResource private final NodeVersion version; private final String environment; private final boolean coordinator; - private final GracefulShutdownHandler shutdownHandler; + private final UpdateNodeStateHandler updateNodeStateHandler; private final StartupStatus startupStatus; private final long startTime = System.nanoTime(); + private final AtomicBoolean startupComplete = new AtomicBoolean(); @Inject - public ServerInfoResource(NodeVersion nodeVersion, NodeInfo nodeInfo, ServerConfig serverConfig, GracefulShutdownHandler shutdownHandler, StartupStatus startupStatus) + public ServerInfoResource(NodeVersion nodeVersion, NodeInfo nodeInfo, + ServerConfig serverConfig, UpdateNodeStateHandler updateNodeStateHandler, + StartupStatus startupStatus) { this.version = requireNonNull(nodeVersion, "nodeVersion is null"); this.environment = nodeInfo.getEnvironment(); this.coordinator = serverConfig.isCoordinator(); - this.shutdownHandler = requireNonNull(shutdownHandler, "shutdownHandler is null"); + this.updateNodeStateHandler = requireNonNull(updateNodeStateHandler, "updateNodeStateHandler is null"); this.startupStatus = requireNonNull(startupStatus, "startupStatus is null"); } @@ -80,23 +80,7 @@ public Response updateState(NodeState state) throws WebApplicationException { requireNonNull(state, "state is null"); - switch (state) { - case SHUTTING_DOWN: - shutdownHandler.requestShutdown(); - return Response.ok().build(); - case ACTIVE: - case INACTIVE: - throw new WebApplicationException(Response - .status(BAD_REQUEST) - .type(MediaType.TEXT_PLAIN) - .entity(format("Invalid state transition to %s", state)) - .build()); - default: - return Response.status(BAD_REQUEST) - .type(TEXT_PLAIN) - .entity(format("Invalid state %s", state)) - .build(); - } + return updateNodeStateHandler.updateState(state); } @ResourceSecurity(PUBLIC) @@ -105,10 +89,7 
@@ public Response updateState(NodeState state) @Produces(APPLICATION_JSON) public NodeState getServerState() { - if (shutdownHandler.isShutdownRequested()) { - return SHUTTING_DOWN; - } - return ACTIVE; + return updateNodeStateHandler.getServerState(); } @ResourceSecurity(PUBLIC) @@ -123,4 +104,9 @@ public Response getServerCoordinator() // return 404 to allow load balancers to only send traffic to the coordinator return Response.status(Response.Status.NOT_FOUND).build(); } + + public void startupComplete() + { + checkState(startupComplete.compareAndSet(false, true), "Server startup already marked as complete"); + } } diff --git a/core/trino-main/src/main/java/io/trino/server/StatusResource.java b/core/trino-main/src/main/java/io/trino/server/StatusResource.java index ea2d39d0b1aa..6cf70877f212 100644 --- a/core/trino-main/src/main/java/io/trino/server/StatusResource.java +++ b/core/trino-main/src/main/java/io/trino/server/StatusResource.java @@ -45,6 +45,7 @@ public class StatusResource private final int logicalCores; private final LocalMemoryManager memoryManager; private final MemoryMXBean memoryMXBean; + private final long startTimeEpoch = System.currentTimeMillis(); private OperatingSystemMXBean operatingSystemMXBean; @@ -92,6 +93,7 @@ public NodeStatus getStatus() operatingSystemMXBean == null ? 
0 : operatingSystemMXBean.getSystemCpuLoad(), memoryMXBean.getHeapMemoryUsage().getUsed(), memoryMXBean.getHeapMemoryUsage().getMax(), - memoryMXBean.getNonHeapMemoryUsage().getUsed()); + memoryMXBean.getNonHeapMemoryUsage().getUsed(), + startTimeEpoch); } } diff --git a/core/trino-main/src/main/java/io/trino/server/GracefulShutdownHandler.java b/core/trino-main/src/main/java/io/trino/server/UpdateNodeStateHandler.java similarity index 54% rename from core/trino-main/src/main/java/io/trino/server/GracefulShutdownHandler.java rename to core/trino-main/src/main/java/io/trino/server/UpdateNodeStateHandler.java index 1d23f9cb2b21..f64754d755e1 100644 --- a/core/trino-main/src/main/java/io/trino/server/GracefulShutdownHandler.java +++ b/core/trino-main/src/main/java/io/trino/server/UpdateNodeStateHandler.java @@ -18,9 +18,13 @@ import io.airlift.units.Duration; import io.trino.execution.SqlTaskManager; import io.trino.execution.TaskInfo; +import io.trino.metadata.NodeState; import javax.annotation.concurrent.GuardedBy; import javax.inject.Inject; +import javax.ws.rs.WebApplicationException; +import javax.ws.rs.core.MediaType; +import javax.ws.rs.core.Response; import java.util.List; import java.util.concurrent.CountDownLatch; @@ -33,16 +37,23 @@ import static com.google.common.collect.ImmutableList.toImmutableList; import static com.google.common.util.concurrent.Uninterruptibles.sleepUninterruptibly; import static io.airlift.concurrent.Threads.threadsNamed; +import static io.trino.metadata.NodeState.ACTIVE; +import static io.trino.metadata.NodeState.DECOMMISSIONED; +import static io.trino.metadata.NodeState.DECOMMISSIONING; +import static io.trino.metadata.NodeState.SHUTTING_DOWN; +import static java.lang.String.format; import static java.lang.Thread.currentThread; import static java.util.Objects.requireNonNull; import static java.util.concurrent.Executors.newSingleThreadExecutor; import static java.util.concurrent.Executors.newSingleThreadScheduledExecutor; import 
static java.util.concurrent.TimeUnit.MILLISECONDS; import static java.util.concurrent.TimeUnit.SECONDS; +import static javax.ws.rs.core.MediaType.TEXT_PLAIN; +import static javax.ws.rs.core.Response.Status.BAD_REQUEST; -public class GracefulShutdownHandler +public class UpdateNodeStateHandler { - private static final Logger log = Logger.get(GracefulShutdownHandler.class); + private static final Logger log = Logger.get(UpdateNodeStateHandler.class); private static final Duration LIFECYCLE_STOP_TIMEOUT = new Duration(30, SECONDS); private final ScheduledExecutorService shutdownHandler = newSingleThreadScheduledExecutor(threadsNamed("shutdown-handler-%s")); @@ -52,12 +63,15 @@ public class GracefulShutdownHandler private final boolean isCoordinator; private final ShutdownAction shutdownAction; private final Duration gracePeriod; + private final ScheduledExecutorService executor = newSingleThreadScheduledExecutor( + threadsNamed("decommission-handler-%s")); + private NodeState currState = NodeState.ACTIVE; @GuardedBy("this") private boolean shutdownRequested; @Inject - public GracefulShutdownHandler( + public UpdateNodeStateHandler( SqlTaskManager sqlTaskManager, ServerConfig serverConfig, ShutdownAction shutdownAction, @@ -70,6 +84,62 @@ public GracefulShutdownHandler( this.gracePeriod = serverConfig.getGracePeriod(); } + public NodeState getServerState() + { + return currState; + } + + public synchronized Response updateState(NodeState state) + throws WebApplicationException + { + requireNonNull(state, "state is null"); + log.info(String.format("Entre updateState %s -> %s", currState, state)); + + // Supported state transitions: + // 1. ? -> ? + // 2. * -> SHUTTING_DOWN + // 3. ACTIVE -> DECOMMISSIONING + // 4. 
DECOMMISSIONING, DECOMMISSIONED -> ACTIVE + + if (currState == state || (state == DECOMMISSIONING && currState == DECOMMISSIONED)) { + return Response.ok().build(); + } + + // Prefer using a switch instead of a chained if-else for enums + switch (state) { + case SHUTTING_DOWN: + requestShutdown(); + currState = SHUTTING_DOWN; + return Response.ok().build(); + case DECOMMISSIONING: + if (currState == ACTIVE) { + requestDecommission(); + currState = DECOMMISSIONING; + return Response.ok().build(); + } + break; + case ACTIVE: + if (currState == DECOMMISSIONING || currState == DECOMMISSIONED) { + currState = ACTIVE; + return Response.ok().build(); + } + break; + case INACTIVE: + break; + default: + return Response.status(BAD_REQUEST).type(TEXT_PLAIN) + .entity(format("Invalid state %s", state)) + .build(); + } + + // Bad request once here. + throw new WebApplicationException(Response + .status(BAD_REQUEST) + .type(MediaType.TEXT_PLAIN) + .entity(format("Invalid state transition from %s to %s", currState, state)) + .build()); + } + public synchronized void requestShutdown() { log.info("Shutdown requested"); @@ -89,33 +159,7 @@ public synchronized void requestShutdown() private void shutdown() { - List activeTasks = getActiveTasks(); - - // At this point no new tasks should be scheduled by coordinator on this worker node. - // Wait for all remaining tasks to finish. 
- while (activeTasks.size() > 0) { - CountDownLatch countDownLatch = new CountDownLatch(activeTasks.size()); - - for (TaskInfo taskInfo : activeTasks) { - sqlTaskManager.addStateChangeListener(taskInfo.getTaskStatus().getTaskId(), newState -> { - if (newState.isDone()) { - countDownLatch.countDown(); - } - }); - } - - log.info("Waiting for all tasks to finish"); - - try { - countDownLatch.await(); - } - catch (InterruptedException e) { - log.warn("Interrupted while waiting for all tasks to finish"); - currentThread().interrupt(); - } - - activeTasks = getActiveTasks(); - } + waitActiveTasksToFinish(sqlTaskManager); // wait for another grace period for all task states to be observed by the coordinator sleepUninterruptibly(gracePeriod.toMillis(), MILLISECONDS); @@ -143,7 +187,37 @@ private void shutdown() shutdownAction.onShutdown(); } - private List getActiveTasks() + static void waitActiveTasksToFinish(SqlTaskManager sqlTaskManager) + { + // At this point no new tasks should be scheduled by coordinator on this worker node. + // Wait for all remaining tasks to finish. 
+ while (true) { + List activeTasks = getActiveTasks(sqlTaskManager); + log.info("Waiting for " + activeTasks.size() + " active tasks to finish"); + if (activeTasks.isEmpty()) { + break; + } + CountDownLatch countDownLatch = new CountDownLatch(activeTasks.size()); + + for (TaskInfo taskInfo : activeTasks) { + sqlTaskManager.addStateChangeListener(taskInfo.getTaskStatus().getTaskId(), newState -> { + if (newState.isDone()) { + countDownLatch.countDown(); + } + }); + } + + try { + countDownLatch.await(); + } + catch (InterruptedException e) { + log.warn("Interrupted while waiting for all tasks to finish"); + currentThread().interrupt(); + } + } + } + + private static List getActiveTasks(SqlTaskManager sqlTaskManager) { return sqlTaskManager.getAllTaskInfo() .stream() @@ -151,8 +225,35 @@ private List getActiveTasks() .collect(toImmutableList()); } - public synchronized boolean isShutdownRequested() + public synchronized void requestDecommission() { - return shutdownRequested; + log.info("enter requestDecommission " + getServerState()); + if (isCoordinator) { + throw new UnsupportedOperationException("Cannot decommission coordinator"); + } + + // The decommission is normally initiated by the coordinator. + // Here we wait a short grace period of 10 seconds for coordinator to no longer + // assign new tasks to this worker node, before wait active tasks to finish. + executor.schedule(new Runnable() { + @Override + public void run() + { + waitActiveTasksToFinish(sqlTaskManager); + log.info("complete waitActiveTasksToFinish " + getServerState()); + NodeState state = onDecommissioned(); + log.info("onDecommissioned " + state); + } + }, 10000, MILLISECONDS); + } + + // callback used by decommissionHandler + NodeState onDecommissioned() + { + log.info("onDecommissioned " + (currState == null ? 
"null" : currState)); + if (currState == NodeState.DECOMMISSIONING) { + currState = NodeState.DECOMMISSIONED; + } + return currState; } } diff --git a/core/trino-main/src/main/java/io/trino/server/GracefulShutdownModule.java b/core/trino-main/src/main/java/io/trino/server/UpdateNodeStateModule.java similarity index 89% rename from core/trino-main/src/main/java/io/trino/server/GracefulShutdownModule.java rename to core/trino-main/src/main/java/io/trino/server/UpdateNodeStateModule.java index cdb83a90e1f7..8e924c8c2a8d 100644 --- a/core/trino-main/src/main/java/io/trino/server/GracefulShutdownModule.java +++ b/core/trino-main/src/main/java/io/trino/server/UpdateNodeStateModule.java @@ -17,13 +17,13 @@ import com.google.inject.Scopes; import io.airlift.configuration.AbstractConfigurationAwareModule; -public class GracefulShutdownModule +public class UpdateNodeStateModule extends AbstractConfigurationAwareModule { @Override protected void setup(Binder binder) { binder.bind(ShutdownAction.class).to(DefaultShutdownAction.class).in(Scopes.SINGLETON); - binder.bind(GracefulShutdownHandler.class).in(Scopes.SINGLETON); + binder.bind(UpdateNodeStateHandler.class).in(Scopes.SINGLETON); } } diff --git a/core/trino-main/src/main/java/io/trino/server/testing/TestingTrinoServer.java b/core/trino-main/src/main/java/io/trino/server/testing/TestingTrinoServer.java index f33bf471986d..42c25a181a64 100644 --- a/core/trino-main/src/main/java/io/trino/server/testing/TestingTrinoServer.java +++ b/core/trino-main/src/main/java/io/trino/server/testing/TestingTrinoServer.java @@ -65,12 +65,12 @@ import io.trino.security.AccessControlConfig; import io.trino.security.AccessControlManager; import io.trino.security.GroupProviderManager; -import io.trino.server.GracefulShutdownHandler; import io.trino.server.PluginManager; import io.trino.server.Server; import io.trino.server.ServerMainModule; import io.trino.server.SessionPropertyDefaults; import io.trino.server.ShutdownAction; +import 
io.trino.server.UpdateNodeStateHandler; import io.trino.server.security.CertificateAuthenticatorManager; import io.trino.server.security.ServerSecurityModule; import io.trino.spi.ErrorType; @@ -168,7 +168,7 @@ public static Builder builder() private final DispatchManager dispatchManager; private final SqlQueryManager queryManager; private final SqlTaskManager taskManager; - private final GracefulShutdownHandler gracefulShutdownHandler; + private final UpdateNodeStateHandler updateNodeStateHandler; private final ShutdownAction shutdownAction; private final MBeanServer mBeanServer; private final boolean coordinator; @@ -273,7 +273,7 @@ private TestingTrinoServer( binder.bind(GroupProvider.class).to(TestingGroupProvider.class).in(Scopes.SINGLETON); binder.bind(AccessControl.class).to(AccessControlManager.class).in(Scopes.SINGLETON); binder.bind(ShutdownAction.class).to(TestShutdownAction.class).in(Scopes.SINGLETON); - binder.bind(GracefulShutdownHandler.class).in(Scopes.SINGLETON); + binder.bind(UpdateNodeStateHandler.class).in(Scopes.SINGLETON); binder.bind(ProcedureTester.class).in(Scopes.SINGLETON); binder.bind(ExchangeManagerRegistry.class).in(Scopes.SINGLETON); }); @@ -352,7 +352,7 @@ private TestingTrinoServer( localMemoryManager = injector.getInstance(LocalMemoryManager.class); nodeManager = injector.getInstance(InternalNodeManager.class); serviceSelectorManager = injector.getInstance(ServiceSelectorManager.class); - gracefulShutdownHandler = injector.getInstance(GracefulShutdownHandler.class); + updateNodeStateHandler = injector.getInstance(UpdateNodeStateHandler.class); taskManager = injector.getInstance(SqlTaskManager.class); shutdownAction = injector.getInstance(ShutdownAction.class); mBeanServer = injector.getInstance(MBeanServer.class); @@ -589,9 +589,9 @@ public MBeanServer getMbeanServer() return mBeanServer; } - public GracefulShutdownHandler getGracefulShutdownHandler() + public UpdateNodeStateHandler getUpdateNodeStateHandler() { - return 
gracefulShutdownHandler; + return updateNodeStateHandler; } public SqlTaskManager getTaskManager() diff --git a/core/trino-main/src/main/java/io/trino/server/ui/ClusterStatsResource.java b/core/trino-main/src/main/java/io/trino/server/ui/ClusterStatsResource.java index 83b0c2434b75..eb65224055c0 100644 --- a/core/trino-main/src/main/java/io/trino/server/ui/ClusterStatsResource.java +++ b/core/trino-main/src/main/java/io/trino/server/ui/ClusterStatsResource.java @@ -63,6 +63,8 @@ public ClusterStats getClusterStats() long activeNodes = nodeManager.getNodes(NodeState.ACTIVE).stream() .filter(node -> isIncludeCoordinator || !node.isCoordinator()) .count(); + long decommissioningNodes = nodeManager.getNodes(NodeState.DECOMMISSIONING).stream().count(); + long decommissionedNodes = nodeManager.getNodes(NodeState.DECOMMISSIONED).stream().count(); long activeCoordinators = nodeManager.getNodes(NodeState.ACTIVE).stream() .filter(InternalNode::isCoordinator) @@ -105,6 +107,8 @@ else if (query.getState() == QueryState.RUNNING) { queuedQueries, activeCoordinators, activeNodes, + decommissioningNodes, + decommissionedNodes, runningDrivers, totalAvailableProcessors, memoryReservation, @@ -121,6 +125,8 @@ public static class ClusterStats private final long activeCoordinators; private final long activeWorkers; + private final long decommissioningWorkers; + private final long decommissionedWorkers; private final long runningDrivers; private final long totalAvailableProcessors; @@ -138,6 +144,8 @@ public ClusterStats( @JsonProperty("queuedQueries") long queuedQueries, @JsonProperty("activeCoordinators") long activeCoordinators, @JsonProperty("activeWorkers") long activeWorkers, + @JsonProperty("decommissioningWorkers") long decommissioningWorkers, + @JsonProperty("decommissionedWorkers") long decommissionedWorkers, @JsonProperty("runningDrivers") long runningDrivers, @JsonProperty("totalAvailableProcessors") long totalAvailableProcessors, @JsonProperty("reservedMemory") double 
reservedMemory, @@ -150,6 +158,8 @@ public ClusterStats( this.queuedQueries = queuedQueries; this.activeCoordinators = activeCoordinators; this.activeWorkers = activeWorkers; + this.decommissioningWorkers = decommissioningWorkers; + this.decommissionedWorkers = decommissionedWorkers; this.runningDrivers = runningDrivers; this.totalAvailableProcessors = totalAvailableProcessors; this.reservedMemory = reservedMemory; @@ -188,6 +198,18 @@ public long getActiveWorkers() return activeWorkers; } + @JsonProperty + public long getDecommissioningWorkers() + { + return decommissioningWorkers; + } + + @JsonProperty + public long getDecommissionedWorkers() + { + return decommissionedWorkers; + } + @JsonProperty public long getRunningDrivers() { diff --git a/testing/trino-tests/src/test/java/io/trino/tests/TestGracefulShutdown.java b/testing/trino-tests/src/test/java/io/trino/tests/TestGracefulShutdown.java index 9a949f3e7c8e..3867ada07f5e 100644 --- a/testing/trino-tests/src/test/java/io/trino/tests/TestGracefulShutdown.java +++ b/testing/trino-tests/src/test/java/io/trino/tests/TestGracefulShutdown.java @@ -98,7 +98,7 @@ public void testShutdown() MILLISECONDS.sleep(500); } - worker.getGracefulShutdownHandler().requestShutdown(); + worker.getUpdateNodeStateHandler().requestShutdown(); Futures.allAsList(queryFutures).get(); @@ -124,7 +124,7 @@ public void testCoordinatorShutdown() .filter(TestingTrinoServer::isCoordinator) .collect(onlyElement()); - assertThatThrownBy(coordinator.getGracefulShutdownHandler()::requestShutdown) + assertThatThrownBy(coordinator.getUpdateNodeStateHandler()::requestShutdown) .isInstanceOf(UnsupportedOperationException.class) .hasMessage("Cannot shutdown coordinator"); }