diff --git a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/RemoteTableNameCacheKey.java b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/RemoteTableNameCacheKey.java index 9676255df20f..e8f298be158c 100644 --- a/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/RemoteTableNameCacheKey.java +++ b/plugin/trino-base-jdbc/src/main/java/io/trino/plugin/jdbc/RemoteTableNameCacheKey.java @@ -18,12 +18,12 @@ import static com.google.common.base.MoreObjects.toStringHelper; import static java.util.Objects.requireNonNull; -final class RemoteTableNameCacheKey +public final class RemoteTableNameCacheKey { private final JdbcIdentity identity; private final String schema; - RemoteTableNameCacheKey(JdbcIdentity identity, String schema) + public RemoteTableNameCacheKey(JdbcIdentity identity, String schema) { this.identity = requireNonNull(identity, "identity is null"); this.schema = requireNonNull(schema, "schema is null"); diff --git a/plugin/trino-druid/src/main/java/io/trino/plugin/druid/DruidJdbcClient.java b/plugin/trino-druid/src/main/java/io/trino/plugin/druid/DruidJdbcClient.java index 6f8f4f5ad90a..b32585f26cd9 100644 --- a/plugin/trino-druid/src/main/java/io/trino/plugin/druid/DruidJdbcClient.java +++ b/plugin/trino-druid/src/main/java/io/trino/plugin/druid/DruidJdbcClient.java @@ -13,12 +13,14 @@ */ package io.trino.plugin.druid; +import com.google.common.base.CharMatcher; import com.google.common.collect.ImmutableList; import io.trino.plugin.jdbc.BaseJdbcClient; import io.trino.plugin.jdbc.BaseJdbcConfig; import io.trino.plugin.jdbc.ColumnMapping; import io.trino.plugin.jdbc.ConnectionFactory; import io.trino.plugin.jdbc.JdbcColumnHandle; +import io.trino.plugin.jdbc.JdbcIdentity; import io.trino.plugin.jdbc.JdbcNamedRelationHandle; import io.trino.plugin.jdbc.JdbcOutputTableHandle; import io.trino.plugin.jdbc.JdbcSplit; @@ -26,6 +28,7 @@ import io.trino.plugin.jdbc.JdbcTypeHandle; import io.trino.plugin.jdbc.PreparedQuery; import 
io.trino.plugin.jdbc.RemoteTableName; +import io.trino.plugin.jdbc.RemoteTableNameCacheKey; import io.trino.plugin.jdbc.WriteFunction; import io.trino.plugin.jdbc.WriteMapping; import io.trino.spi.TrinoException; @@ -51,12 +54,15 @@ import java.util.function.BiFunction; import java.util.stream.Collectors; +import static com.google.common.base.MoreObjects.firstNonNull; import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.base.Verify.verify; import static com.google.common.collect.Iterables.getOnlyElement; import static io.trino.plugin.jdbc.JdbcErrorCode.JDBC_ERROR; import static io.trino.plugin.jdbc.StandardColumnMappings.defaultVarcharColumnMapping; import static io.trino.plugin.jdbc.StandardColumnMappings.varcharColumnMapping; import static io.trino.spi.type.VarcharType.createUnboundedVarcharType; +import static java.util.Objects.requireNonNull; public class DruidJdbcClient extends BaseJdbcClient @@ -85,9 +91,10 @@ protected Collection listSchemas(Connection connection) public Optional getTableHandle(ConnectorSession session, SchemaTableName schemaTableName) { try (Connection connection = connectionFactory.openConnection(session)) { - String jdbcSchemaName = schemaTableName.getSchemaName(); - String jdbcTableName = schemaTableName.getTableName(); - try (ResultSet resultSet = getTables(connection, Optional.of(jdbcSchemaName), Optional.of(jdbcTableName))) { + JdbcIdentity identity = JdbcIdentity.from(session); + String remoteSchema = toRemoteSchemaName(identity, connection, schemaTableName.getSchemaName()); + String remoteTable = toRemoteTableName(identity, connection, remoteSchema, schemaTableName.getTableName()); + try (ResultSet resultSet = getTables(connection, Optional.of(remoteSchema), Optional.of(remoteTable))) { List tableHandles = new ArrayList<>(); while (resultSet.next()) { tableHandles.add(new JdbcTableHandle( @@ -99,14 +106,15 @@ public Optional getTableHandle(ConnectorSession session, Schema if 
(tableHandles.isEmpty()) { return Optional.empty(); } + return Optional.of( getOnlyElement( tableHandles .stream() .filter( jdbcTableHandle -> - Objects.equals(jdbcTableHandle.getSchemaName(), schemaTableName.getSchemaName()) - && Objects.equals(jdbcTableHandle.getTableName(), schemaTableName.getTableName())) + Objects.equals(jdbcTableHandle.getSchemaName(), remoteSchema) + && Objects.equals(jdbcTableHandle.getTableName(), remoteTable)) .collect(Collectors.toList()))); } } @@ -136,6 +144,68 @@ protected ResultSet getTables(Connection connection, Optional schemaName null); } + @Override + protected String toRemoteSchemaName(JdbcIdentity identity, Connection connection, String schemaName) + { + requireNonNull(schemaName, "schemaName is null"); + verify(CharMatcher.forPredicate(Character::isUpperCase).matchesNoneOf(schemaName), "Expected schema name from internal metadata to be lowercase: %s", schemaName); + + if (caseInsensitiveNameMatching) { + try { + Map mapping = remoteSchemaNames.getIfPresent(identity); + if (mapping != null && !mapping.containsKey(schemaName)) { + // This might be a schema that has just been created. Force reload. 
+ mapping = null; + } + if (mapping == null) { + mapping = listSchemasByLowerCase(connection); + remoteSchemaNames.put(identity, mapping); + } + String remoteSchema = mapping.get(schemaName); + if (remoteSchema != null) { + return remoteSchema; + } + } + catch (RuntimeException e) { + throw new TrinoException(JDBC_ERROR, "Failed to find remote schema name: " + firstNonNull(e.getMessage(), e), e); + } + } + + return schemaName; + } + + @Override + protected String toRemoteTableName(JdbcIdentity identity, Connection connection, String remoteSchema, String tableName) + { + requireNonNull(remoteSchema, "remoteSchema is null"); + requireNonNull(tableName, "tableName is null"); + verify(CharMatcher.forPredicate(Character::isUpperCase).matchesNoneOf(tableName), "Expected table name from internal metadata to be lowercase: %s", tableName); + + if (caseInsensitiveNameMatching) { + try { + RemoteTableNameCacheKey cacheKey = new RemoteTableNameCacheKey(identity, remoteSchema); + Map mapping = remoteTableNames.getIfPresent(cacheKey); + if (mapping != null && !mapping.containsKey(tableName)) { + // This might be a table that has just been created. Force reload. 
+ mapping = null; + } + if (mapping == null) { + mapping = listTablesByLowerCase(connection, remoteSchema); + remoteTableNames.put(cacheKey, mapping); + } + String remoteTable = mapping.get(tableName); + if (remoteTable != null) { + return remoteTable; + } + } + catch (RuntimeException e) { + throw new TrinoException(JDBC_ERROR, "Failed to find remote table name: " + firstNonNull(e.getMessage(), e), e); + } + } + + return tableName; + } + @Override public Optional toColumnMapping(ConnectorSession session, Connection connection, JdbcTypeHandle typeHandle) { diff --git a/plugin/trino-druid/src/test/java/io/trino/plugin/druid/DruidQueryRunner.java b/plugin/trino-druid/src/test/java/io/trino/plugin/druid/DruidQueryRunner.java index 961538d8f4fe..58ce4722d996 100644 --- a/plugin/trino-druid/src/test/java/io/trino/plugin/druid/DruidQueryRunner.java +++ b/plugin/trino-druid/src/test/java/io/trino/plugin/druid/DruidQueryRunner.java @@ -39,7 +39,10 @@ public class DruidQueryRunner { private DruidQueryRunner() {} - public static DistributedQueryRunner createDruidQueryRunnerTpch(TestingDruidServer testingDruidServer, Map extraProperties) + public static DistributedQueryRunner createDruidQueryRunnerTpch( + TestingDruidServer testingDruidServer, + Map extraProperties, + Map connectorProperties) throws Exception { DistributedQueryRunner queryRunner = null; @@ -50,7 +53,7 @@ public static DistributedQueryRunner createDruidQueryRunnerTpch(TestingDruidServ queryRunner.installPlugin(new TpchPlugin()); queryRunner.createCatalog("tpch", "tpch"); - Map connectorProperties = new HashMap<>(); + connectorProperties = new HashMap<>(ImmutableMap.copyOf(connectorProperties)); connectorProperties.putIfAbsent("connection-url", testingDruidServer.getJdbcUrl()); queryRunner.installPlugin(new DruidJdbcPlugin()); queryRunner.createCatalog("druid", "druid", connectorProperties); @@ -62,6 +65,15 @@ public static DistributedQueryRunner createDruidQueryRunnerTpch(TestingDruidServ } } + public static 
void copyAndIngestTpchData(MaterializedResult rows, TestingDruidServer testingDruidServer, + String sourceDatasource, String targetDatasource) + throws IOException, InterruptedException + { + String tsvFileLocation = format("%s/%s.tsv", testingDruidServer.getHostWorkingDirectory(), targetDatasource); + writeDataAsTsv(rows, tsvFileLocation); + testingDruidServer.ingestData(targetDatasource, getIngestionSpecFileName(sourceDatasource), tsvFileLocation); + } + public static void copyAndIngestTpchData(MaterializedResult rows, TestingDruidServer testingDruidServer, String druidDatasource) throws IOException, InterruptedException { @@ -109,7 +121,7 @@ public static void main(String[] args) DistributedQueryRunner queryRunner = createDruidQueryRunnerTpch( new TestingDruidServer(), - ImmutableMap.of("http-server.http.port", "8080")); + ImmutableMap.of("http-server.http.port", "8080"), ImmutableMap.of()); Logger log = Logger.get(DruidQueryRunner.class); log.info("======== SERVER STARTED ========"); diff --git a/plugin/trino-druid/src/test/java/io/trino/plugin/druid/TestDruidCaseInsensitiveMatch.java b/plugin/trino-druid/src/test/java/io/trino/plugin/druid/TestDruidCaseInsensitiveMatch.java new file mode 100644 index 000000000000..15cb43c3e15d --- /dev/null +++ b/plugin/trino-druid/src/test/java/io/trino/plugin/druid/TestDruidCaseInsensitiveMatch.java @@ -0,0 +1,90 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.trino.plugin.druid; + +import com.google.common.collect.ImmutableMap; +import io.trino.testing.AbstractTestQueryFramework; +import io.trino.testing.DistributedQueryRunner; +import io.trino.testing.MaterializedResult; +import io.trino.testing.QueryRunner; +import io.trino.testing.assertions.Assert; +import org.testng.annotations.Test; + +import java.io.IOException; + +import static io.trino.plugin.druid.BaseDruidIntegrationSmokeTest.SELECT_FROM_ORDERS; +import static io.trino.plugin.druid.BaseDruidIntegrationSmokeTest.SELECT_FROM_REGION; +import static io.trino.plugin.druid.DruidQueryRunner.copyAndIngestTpchData; +import static io.trino.spi.type.VarcharType.VARCHAR; +import static org.assertj.core.api.Assertions.assertThat; + +@Test(singleThreaded = true) +public class TestDruidCaseInsensitiveMatch + extends AbstractTestQueryFramework +{ + private TestingDruidServer druidServer; + + @Override + protected QueryRunner createQueryRunner() + throws Exception + { + druidServer = new TestingDruidServer(); + closeAfterClass(() -> { + druidServer.close(); + druidServer = null; + }); + DistributedQueryRunner queryRunner = DruidQueryRunner.createDruidQueryRunnerTpch( + druidServer, ImmutableMap.of(), ImmutableMap.of("case-insensitive-name-matching", "true")); + copyAndIngestTpchData(queryRunner.execute(SELECT_FROM_ORDERS + " LIMIT 10"), this.druidServer, "orders", "CamelCase"); + return queryRunner; + } + + @Test + public void testNonLowerCaseTableName() + { + MaterializedResult expectedColumns = MaterializedResult.resultBuilder(getQueryRunner().getDefaultSession(), VARCHAR, VARCHAR, VARCHAR, VARCHAR) + .row("__time", "timestamp(3)", "", "") + .row("clerk", "varchar", "", "") // String columns are reported only as varchar + .row("comment", "varchar", "", "") + .row("custkey", "bigint", "", "") // Long columns are reported as bigint + .row("orderdate", "varchar", "", "") + .row("orderkey", "bigint", "", "") + .row("orderpriority", "varchar", "", "") + 
.row("orderstatus", "varchar", "", "") + .row("shippriority", "bigint", "", "") // Druid doesn't support int type + .row("totalprice", "double", "", "") + .build(); + MaterializedResult actualColumns = computeActual("DESCRIBE " + "CamelCase"); + Assert.assertEquals(actualColumns, expectedColumns); + MaterializedResult materializedRows = computeActual("SELECT * FROM druid.druid.CAMELCASE"); + Assert.assertEquals(materializedRows.getRowCount(), 10); + MaterializedResult materializedRows1 = computeActual("SELECT * FROM druid.CamelCase"); + MaterializedResult materializedRows2 = computeActual("SELECT * FROM druid.camelcase"); + assertThat(materializedRows.equals(materializedRows1)).isTrue(); + assertThat(materializedRows.equals(materializedRows2)).isTrue(); + } + + @Test + public void testTableNameClash() + throws IOException, InterruptedException + { + try { + //ingesting data with already existing table name in lowercase which should fail + copyAndIngestTpchData(getQueryRunner().execute(SELECT_FROM_REGION + " LIMIT 10"), this.druidServer, "region", "camelcase"); + } + catch (AssertionError e) { + Assert.assertEquals(e.getMessage(), "Datasource camelcase not loaded expected [true] but found [false]"); + } + } +} diff --git a/plugin/trino-druid/src/test/java/io/trino/plugin/druid/TestDruidIntegrationSmokeTest.java b/plugin/trino-druid/src/test/java/io/trino/plugin/druid/TestDruidIntegrationSmokeTest.java index 81c5a23d3c00..5986ca70e4e2 100644 --- a/plugin/trino-druid/src/test/java/io/trino/plugin/druid/TestDruidIntegrationSmokeTest.java +++ b/plugin/trino-druid/src/test/java/io/trino/plugin/druid/TestDruidIntegrationSmokeTest.java @@ -32,7 +32,7 @@ protected QueryRunner createQueryRunner() throws Exception { this.druidServer = new TestingDruidServer(); - QueryRunner runner = DruidQueryRunner.createDruidQueryRunnerTpch(druidServer, ImmutableMap.of()); + QueryRunner runner = DruidQueryRunner.createDruidQueryRunnerTpch(druidServer, ImmutableMap.of(), ImmutableMap.of()); 
copyAndIngestTpchData(runner.execute(SELECT_FROM_ORDERS), this.druidServer, ORDERS.getTableName()); copyAndIngestTpchData(runner.execute(SELECT_FROM_LINEITEM), this.druidServer, LINE_ITEM.getTableName()); copyAndIngestTpchData(runner.execute(SELECT_FROM_NATION), this.druidServer, NATION.getTableName()); diff --git a/plugin/trino-druid/src/test/java/io/trino/plugin/druid/TestDruidIntegrationSmokeTestLatest.java b/plugin/trino-druid/src/test/java/io/trino/plugin/druid/TestDruidIntegrationSmokeTestLatest.java index 8459a01193d8..a721c02e352a 100644 --- a/plugin/trino-druid/src/test/java/io/trino/plugin/druid/TestDruidIntegrationSmokeTestLatest.java +++ b/plugin/trino-druid/src/test/java/io/trino/plugin/druid/TestDruidIntegrationSmokeTestLatest.java @@ -34,7 +34,7 @@ protected QueryRunner createQueryRunner() throws Exception { this.druidServer = new TestingDruidServer(LATEST_DRUID_DOCKER_IMAGE); - QueryRunner runner = DruidQueryRunner.createDruidQueryRunnerTpch(druidServer, ImmutableMap.of()); + QueryRunner runner = DruidQueryRunner.createDruidQueryRunnerTpch(druidServer, ImmutableMap.of(), ImmutableMap.of()); copyAndIngestTpchData(runner.execute(SELECT_FROM_ORDERS), this.druidServer, ORDERS.getTableName()); copyAndIngestTpchData(runner.execute(SELECT_FROM_LINEITEM), this.druidServer, LINE_ITEM.getTableName()); copyAndIngestTpchData(runner.execute(SELECT_FROM_NATION), this.druidServer, NATION.getTableName()); diff --git a/plugin/trino-druid/src/test/java/io/trino/plugin/druid/TestingDruidServer.java b/plugin/trino-druid/src/test/java/io/trino/plugin/druid/TestingDruidServer.java index 84df7f7b7028..7e6431f8c452 100644 --- a/plugin/trino-druid/src/test/java/io/trino/plugin/druid/TestingDruidServer.java +++ b/plugin/trino-druid/src/test/java/io/trino/plugin/druid/TestingDruidServer.java @@ -223,7 +223,7 @@ void ingestData(String datasource, String indexTaskFile, String dataFilePath) middleManager.withCopyFileToContainer(forHostPath(dataFilePath), 
getMiddleManagerContainerPathForDataFile(dataFilePath)); String indexTask = Resources.toString(getResource(indexTaskFile), Charset.defaultCharset()); - + indexTask = getReplacedIndexTask(datasource, indexTask); Request.Builder requestBuilder = new Request.Builder(); requestBuilder.addHeader("content-type", "application/json;charset=utf-8") .url("http://localhost:" + getCoordinatorOverlordPort() + "/druid/indexer/v1/task") @@ -234,6 +234,13 @@ void ingestData(String datasource, String indexTaskFile, String dataFilePath) } } + private String getReplacedIndexTask(String targetDataSource, String indexTask) + { + indexTask = indexTask.replaceAll("dataSource\":.*,", "dataSource\": \"" + targetDataSource + "\","); + indexTask = indexTask.replaceAll("filter\":.*", "filter\": \"" + targetDataSource + ".tsv\""); + return indexTask; + } + private boolean checkDatasourceAvailable(String datasource) throws IOException, InterruptedException {