Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ public class ArrowFlightConfig
private String flightClientSSLKey;
private boolean arrowFlightServerSslEnabled;
private Integer arrowFlightPort;
private boolean caseSensitiveNameMatchingEnabled;

public String getFlightServerName()
{
Expand Down Expand Up @@ -111,4 +112,18 @@ public ArrowFlightConfig setFlightClientSSLKey(String flightClientSSLKey)
this.flightClientSSLKey = flightClientSSLKey;
return this;
}

public boolean isCaseSensitiveNameMatching()
{
return caseSensitiveNameMatchingEnabled;
}

@Config("case-sensitive-name-matching")
@ConfigDescription("Enable case-sensitive matching of schema, table names across the connector. " +
"When disabled, names are matched case-insensitively using lowercase normalization.")
public ArrowFlightConfig setCaseSensitiveNameMatching(boolean caseSensitiveNameMatchingEnabled)
{
this.caseSensitiveNameMatchingEnabled = caseSensitiveNameMatchingEnabled;
return this;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -43,19 +43,22 @@

import static com.facebook.plugin.arrow.ArrowErrorCode.ARROW_FLIGHT_METADATA_ERROR;
import static com.google.common.base.Preconditions.checkArgument;
import static java.util.Locale.ROOT;
import static java.util.Objects.requireNonNull;

public class ArrowMetadata
implements ConnectorMetadata
{
private final BaseArrowFlightClientHandler clientHandler;
private final ArrowBlockBuilder arrowBlockBuilder;
private final ArrowFlightConfig arrowFlightConfig;

@Inject
public ArrowMetadata(BaseArrowFlightClientHandler clientHandler, ArrowBlockBuilder arrowBlockBuilder)
public ArrowMetadata(BaseArrowFlightClientHandler clientHandler, ArrowBlockBuilder arrowBlockBuilder, ArrowFlightConfig arrowFlightConfig)
{
this.clientHandler = requireNonNull(clientHandler, "clientHandler is null");
this.arrowBlockBuilder = requireNonNull(arrowBlockBuilder, "arrowBlockBuilder is null");
this.arrowFlightConfig = requireNonNull(arrowFlightConfig, "arrowFlightConfig is null");
}

@Override
Expand Down Expand Up @@ -192,6 +195,12 @@ public Map<SchemaTableName, List<ColumnMetadata>> listTableColumns(ConnectorSess
return columns.build();
}

@Override
public String normalizeIdentifier(ConnectorSession session, String identifier)
{
return arrowFlightConfig.isCaseSensitiveNameMatching() ? identifier : identifier.toLowerCase(ROOT);
}

private Type getPrestoTypeFromArrowField(Field field)
{
return arrowBlockBuilder.getPrestoTypeFromArrowField(field);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ public static void main(String[] args)

RootAllocator allocator = new RootAllocator(Long.MAX_VALUE);
Location serverLocation = Location.forGrpcTls("localhost", 9443);
FlightServer.Builder serverBuilder = FlightServer.builder(allocator, serverLocation, new TestingArrowProducer(allocator));
FlightServer.Builder serverBuilder = FlightServer.builder(allocator, serverLocation, new TestingArrowProducer(allocator, false));

File serverCert = new File("src/test/resources/certs/server.crt");
File serverKey = new File("src/test/resources/certs/server.key");
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,148 @@

/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.facebook.plugin.arrow;

import com.facebook.airlift.log.Logger;
import com.facebook.plugin.arrow.testingServer.TestingArrowProducer;
import com.facebook.presto.testing.MaterializedResult;
import com.facebook.presto.testing.QueryRunner;
import com.facebook.presto.tests.AbstractTestQueryFramework;
import com.facebook.presto.tests.DistributedQueryRunner;
import com.google.common.collect.ImmutableMap;
import org.apache.arrow.flight.FlightServer;
import org.apache.arrow.flight.Location;
import org.apache.arrow.memory.RootAllocator;
import org.testng.annotations.AfterClass;
import org.testng.annotations.BeforeClass;
import org.testng.annotations.Test;

import java.io.File;
import java.util.Map;
import java.util.Optional;

import static com.facebook.plugin.arrow.testingConnector.TestingArrowFlightPlugin.ARROW_FLIGHT_CONNECTOR;
import static com.facebook.presto.common.type.IntegerType.INTEGER;
import static com.facebook.presto.common.type.VarcharType.createVarcharType;
import static com.facebook.presto.testing.MaterializedResult.resultBuilder;
import static com.facebook.presto.tests.QueryAssertions.assertContains;
import static org.testng.Assert.assertTrue;

@Test
public class TestArrowFlightIntegrationMixedCase
extends AbstractTestQueryFramework
{
private static final Logger logger = Logger.get(TestArrowFlightIntegrationMixedCase.class);
private static final String ARROW_FLIGHT_MIXED_CATALOG = "arrow_mixed_catalog";
private int serverPort;
private RootAllocator allocator;
private FlightServer server;
private DistributedQueryRunner arrowFlightQueryRunner;

@BeforeClass
public void setup()
throws Exception
{
arrowFlightQueryRunner = getDistributedQueryRunner();
arrowFlightQueryRunner.createCatalog(ARROW_FLIGHT_MIXED_CATALOG, ARROW_FLIGHT_CONNECTOR, getCatalogProperties());
File certChainFile = new File("src/test/resources/certs/server.crt");
File privateKeyFile = new File("src/test/resources/certs/server.key");

allocator = new RootAllocator(Long.MAX_VALUE);
Location location = Location.forGrpcTls("localhost", serverPort);
server = FlightServer.builder(allocator, location, new TestingArrowProducer(allocator, true))
.useTls(certChainFile, privateKeyFile)
.build();

server.start();
logger.info("Server listening on port %s", server.getPort());
}

private Map<String, String> getCatalogProperties()
{
ImmutableMap.Builder<String, String> catalogProperties = ImmutableMap.<String, String>builder()
.put("arrow-flight.server.port", String.valueOf(serverPort))
.put("arrow-flight.server", "localhost")
.put("arrow-flight.server-ssl-enabled", "true")
.put("arrow-flight.server-ssl-certificate", "src/test/resources/certs/server.crt")
.put("case-sensitive-name-matching", "true");
return catalogProperties.build();
}

@AfterClass(alwaysRun = true)
public void close()
throws InterruptedException
{
arrowFlightQueryRunner.close();
server.close();
allocator.close();
}

@Override
protected QueryRunner createQueryRunner()
throws Exception
{
serverPort = ArrowFlightQueryRunner.findUnusedPort();
return ArrowFlightQueryRunner.createQueryRunner(serverPort, ImmutableMap.of(), ImmutableMap.of(), Optional.empty(), Optional.empty());
}

@Test
public void testShowSchemas()
{
MaterializedResult actualRow = computeActual("SHOW schemas FROM arrow_mixed_catalog");
MaterializedResult expectedRow = resultBuilder(getSession(), createVarcharType(50))
.row("Tpch_Mx")
.row("tpch_mx")
.build();

assertContains(actualRow, expectedRow);
}

@Test
public void testShowTables()
{
MaterializedResult actualRow = computeActual("SHOW TABLES FROM arrow_mixed_catalog.tpch_mx");
MaterializedResult expectedRow = resultBuilder(getSession(), createVarcharType(50))
.row("MXTEST")
.row("mxtest")
.build();

assertContains(actualRow, expectedRow);
}

@Test
public void testShowColumns()
{
MaterializedResult actualRow = computeActual("SHOW columns FROM arrow_mixed_catalog.tpch_mx.mxtest");
MaterializedResult expectedRow = resultBuilder(getSession(), createVarcharType(50))
.row("ID", "integer", "", "", Long.valueOf(10), null, null)
.row("NAME", "varchar(50)", "", "", null, null, Long.valueOf(50))
.row("name", "varchar(50)", "", "", null, null, Long.valueOf(50))
.row("Address", "varchar(50)", "", "", null, null, Long.valueOf(50))
.build();

assertContains(actualRow, expectedRow);
}

@Test
public void testSelect()
{
MaterializedResult actualRow = computeActual("SELECT * from arrow_mixed_catalog.tpch_mx.mxtest");
MaterializedResult expectedRow = resultBuilder(getSession(), INTEGER, createVarcharType(50), createVarcharType(50), createVarcharType(50))
.row(1, "TOM", "test", "kochi")
.row(2, "MARY", "test", "kochi")
.build();
assertTrue(actualRow.equals(expectedRow));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ public void setup()

allocator = new RootAllocator(Long.MAX_VALUE);
Location location = Location.forGrpcTls("127.0.0.1", serverPort);
server = FlightServer.builder(allocator, location, new TestingArrowProducer(allocator))
server = FlightServer.builder(allocator, location, new TestingArrowProducer(allocator, false))
.useTls(certChainFile, privateKeyFile)
.build();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ private void setup()
allocator = new RootAllocator(Long.MAX_VALUE);

Location location = Location.forGrpcTls("localhost", serverPort);
server = FlightServer.builder(allocator, location, new TestingArrowProducer(allocator))
server = FlightServer.builder(allocator, location, new TestingArrowProducer(allocator, false))
.useTls(certChainFile, privateKeyFile)
.useMTlsClientVerification(caCertFile)
.build();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ public void setup()
arrowFlightQueryRunner = getDistributedQueryRunner();
allocator = new RootAllocator(Long.MAX_VALUE);
Location location = Location.forGrpcTls("localhost", serverPort);
FlightServer.Builder serverBuilder = FlightServer.builder(allocator, location, new TestingArrowProducer(allocator));
FlightServer.Builder serverBuilder = FlightServer.builder(allocator, location, new TestingArrowProducer(allocator, false));

File serverCert = new File("src/test/resources/certs/server.crt");
File serverKey = new File("src/test/resources/certs/server.key");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ public void setup()

allocator = new RootAllocator(Long.MAX_VALUE);
Location location = Location.forGrpcTls("localhost", serverPort);
server = FlightServer.builder(allocator, location, new TestingArrowProducer(allocator))
server = FlightServer.builder(allocator, location, new TestingArrowProducer(allocator, false))
.useTls(certChainFile, privateKeyFile)
.build();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,14 +42,15 @@
import java.util.concurrent.TimeUnit;

import static com.facebook.presto.common.Utils.checkArgument;
import static java.util.Locale.ENGLISH;
import static java.util.Locale.ROOT;
import static java.util.Objects.requireNonNull;

public class TestingArrowFlightClientHandler
extends BaseArrowFlightClientHandler
{
private final JsonCodec<TestingArrowFlightRequest> requestCodec;
private final JsonCodec<TestingArrowFlightResponse> responseCodec;
private boolean caseSensitiveNameMatchingEnabled;

@Inject
public TestingArrowFlightClientHandler(
Expand All @@ -61,6 +62,7 @@ public TestingArrowFlightClientHandler(
super(allocator, config);
this.requestCodec = requireNonNull(requestCodec, "requestCodec is null");
this.responseCodec = requireNonNull(responseCodec, "responseCodec is null");
this.caseSensitiveNameMatchingEnabled = config.isCaseSensitiveNameMatching();
}

@Override
Expand Down Expand Up @@ -102,7 +104,7 @@ public List<String> listSchemaNames(ConnectorSession session)
List<String> listSchemas = res;
List<String> names = new ArrayList<>();
for (String value : listSchemas) {
names.add(value.toLowerCase(ENGLISH));
names.add(normalizeIdentifier(value));
}
return ImmutableList.copyOf(names);
}
Expand Down Expand Up @@ -131,7 +133,7 @@ public List<SchemaTableName> listTables(ConnectorSession session, Optional<Strin
List<String> listTables = res;
List<SchemaTableName> tables = new ArrayList<>();
for (String value : listTables) {
tables.add(new SchemaTableName(schemaValue.toLowerCase(ENGLISH), value.toLowerCase(ENGLISH)));
tables.add(new SchemaTableName(normalizeIdentifier(schemaValue), normalizeIdentifier(value)));
}

return tables;
Expand All @@ -149,4 +151,9 @@ public FlightDescriptor getFlightDescriptorForTableScan(ConnectorSession session
TestingArrowFlightRequest request = TestingArrowFlightRequest.createQueryRequest(tableHandle.getSchema(), tableHandle.getTable(), query);
return FlightDescriptor.command(requestCodec.toBytes(request));
}

private String normalizeIdentifier(String identifier)
{
return caseSensitiveNameMatchingEnabled ? identifier : identifier.toLowerCase(ROOT);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -73,15 +73,17 @@ public class TestingArrowProducer
private static final Logger logger = Logger.get(TestingArrowProducer.class);
private final JsonCodec<TestingArrowFlightRequest> requestCodec;
private final JsonCodec<TestingArrowFlightResponse> responseCodec;
private boolean caseSensitiveNameMatchingEnabled;

public TestingArrowProducer(BufferAllocator allocator) throws Exception
public TestingArrowProducer(BufferAllocator allocator, boolean caseSensitiveNameMatchingEnabled) throws Exception
{
this.allocator = allocator;
String h2JdbcUrl = "jdbc:h2:mem:testdb" + System.nanoTime() + "_" + ThreadLocalRandom.current().nextInt() + ";DB_CLOSE_DELAY=-1";
TestingH2DatabaseSetup.setup(h2JdbcUrl);
this.connection = DriverManager.getConnection(h2JdbcUrl, "sa", "");
this.requestCodec = jsonCodec(TestingArrowFlightRequest.class);
this.responseCodec = jsonCodec(TestingArrowFlightResponse.class);
this.caseSensitiveNameMatchingEnabled = caseSensitiveNameMatchingEnabled;
}

@Override
Expand All @@ -100,7 +102,7 @@ public void getStream(CallContext callContext, Ticket ticket, ServerStreamListen

logger.debug("Executing query: %s", query);

try (ResultSet resultSet = stmt.executeQuery(query.toUpperCase())) {
try (ResultSet resultSet = stmt.executeQuery(normalizeIdentifier(query))) {
JdbcToArrowConfig config = new JdbcToArrowConfigBuilder().setAllocator(allocator).setTargetBatchSize(2048)
.setCalendar(Calendar.getInstance(TimeZone.getDefault())).build();
Schema schema = jdbcToArrowSchema(resultSet.getMetaData(), config);
Expand Down Expand Up @@ -158,8 +160,8 @@ public FlightInfo getFlightInfo(CallContext callContext, FlightDescriptor flight
List<Field> fields = new ArrayList<>();
if (tableName.isPresent()) {
String query = "SELECT * FROM INFORMATION_SCHEMA.COLUMNS " +
"WHERE TABLE_SCHEMA='" + schemaName.toUpperCase() + "' " +
"AND TABLE_NAME='" + tableName.get().toUpperCase() + "'";
"WHERE TABLE_SCHEMA='" + normalizeIdentifier(schemaName) + "' " +
"AND TABLE_NAME='" + normalizeIdentifier(tableName.get()) + "'";

try (ResultSet rs = connection.createStatement().executeQuery(query)) {
while (rs.next()) {
Expand All @@ -182,7 +184,7 @@ public FlightInfo getFlightInfo(CallContext callContext, FlightDescriptor flight
}
}
else if (selectStatement != null) {
selectStatement = selectStatement.toUpperCase();
selectStatement = normalizeIdentifier(selectStatement);
logger.debug("Executing SELECT query: %s", selectStatement);
try (ResultSet rs = connection.createStatement().executeQuery(selectStatement)) {
ResultSetMetaData metaData = rs.getMetaData();
Expand Down Expand Up @@ -232,7 +234,7 @@ public void doAction(CallContext callContext, Action action, StreamListener<Resu
query = "SELECT SCHEMA_NAME FROM INFORMATION_SCHEMA.SCHEMATA";
}
else {
query = "SELECT TABLE_NAME FROM INFORMATION_SCHEMA.TABLES WHERE TABLE_SCHEMA='" + schemaName.get().toUpperCase() + "'";
query = "SELECT TABLE_NAME FROM INFORMATION_SCHEMA.TABLES WHERE TABLE_SCHEMA='" + normalizeIdentifier(schemaName.get()) + "'";
}
ResultSet rs = connection.createStatement().executeQuery(query);
List<String> names = new ArrayList<>();
Expand Down Expand Up @@ -306,4 +308,9 @@ private ArrowType convertSqlTypeToArrowType(String sqlType, int precision, int s
throw new IllegalArgumentException("Unsupported SQL type: " + sqlType);
}
}

private String normalizeIdentifier(String identifier)
{
return caseSensitiveNameMatchingEnabled ? identifier : identifier.toUpperCase();
}
}
Loading
Loading