diff --git a/.github/workflows/arrow-flight-tests.yml b/.github/workflows/arrow-flight-tests.yml
index c8dec648623d5..9d03b35ce8ba0 100644
--- a/.github/workflows/arrow-flight-tests.yml
+++ b/.github/workflows/arrow-flight-tests.yml
@@ -39,7 +39,9 @@ jobs:
matrix:
java: ['17']
modules:
- - :presto-base-arrow-flight # Only run tests for the `presto-base-arrow-flight` module
+ - :presto-common-arrow
+ - :presto-base-arrow-flight
+ - :presto-flight-shim
timeout-minutes: 80
concurrency:
diff --git a/pom.xml b/pom.xml
index ea19425f3a7ec..9bba335e86a7d 100644
--- a/pom.xml
+++ b/pom.xml
@@ -241,6 +241,7 @@
presto-sql-helpers/presto-sql-invoked-functions-plugin
presto-sql-helpers/presto-native-sql-invoked-functions-plugin
presto-lance
+ presto-flight-shim
diff --git a/presto-common-arrow/pom.xml b/presto-common-arrow/pom.xml
index c2507c9f136d4..c7cec4d0aa299 100644
--- a/presto-common-arrow/pom.xml
+++ b/presto-common-arrow/pom.xml
@@ -17,6 +17,17 @@
+
+ org.apache.arrow
+ arrow-memory-core
+
+
+ org.slf4j
+ slf4j-api
+
+
+
+
org.apache.arrow
arrow-vector
@@ -56,10 +67,62 @@
jakarta.inject
jakarta.inject-api
+
+
+
+ com.facebook.presto
+ presto-main-base
+ test
+
+
+
+ com.facebook.airlift
+ log
+ test
+
+
+
+ joda-time
+ joda-time
+ test
+
+
+
+ org.jdbi
+ jdbi3-core
+ test
+
+
+
+ org.testng
+ testng
+ test
+
+
+
+ org.apache.arrow
+ arrow-memory-netty
+ test
+
+
+ org.slf4j
+ slf4j-api
+
+
+
+
+ org.apache.maven.plugins
+ maven-dependency-plugin
+
+
+ org.apache.arrow:arrow-memory-core
+
+
+
org.basepom.maven
duplicate-finder-maven-plugin
@@ -84,4 +147,19 @@
+
+
+
+ java17
+
+ 17
+
+
+
+ -Xss10M
+ --add-opens=java.base/java.nio=org.apache.arrow.memory.core,ALL-UNNAMED
+
+
+
+
diff --git a/presto-common-arrow/src/main/java/com/facebook/plugin/arrow/ArrowBatchSource.java b/presto-common-arrow/src/main/java/com/facebook/plugin/arrow/ArrowBatchSource.java
new file mode 100644
index 0000000000000..8aec16ad66243
--- /dev/null
+++ b/presto-common-arrow/src/main/java/com/facebook/plugin/arrow/ArrowBatchSource.java
@@ -0,0 +1,118 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.facebook.plugin.arrow;
+
+import com.facebook.presto.common.Page;
+import com.facebook.presto.common.block.Block;
+import com.facebook.presto.spi.ColumnMetadata;
+import com.facebook.presto.spi.ConnectorPageSource;
+import com.google.common.collect.ImmutableList;
+import org.apache.arrow.memory.BufferAllocator;
+import org.apache.arrow.vector.AllocationHelper;
+import org.apache.arrow.vector.ValueVector;
+import org.apache.arrow.vector.VectorSchemaRoot;
+
+import java.io.Closeable;
+import java.io.IOException;
+import java.util.List;
+
+import static com.facebook.plugin.arrow.BlockArrowWriter.createArrowWriters;
+import static java.util.Objects.requireNonNull;
+
+public class ArrowBatchSource
+ implements Closeable
+{
+ private final VectorSchemaRoot root;
+ private final List writers;
+ private final int maxRowsPerBatch;
+ private final ConnectorPageSource pageSource;
+
+ private Page currentPage;
+ private int currentPosition;
+
+ public ArrowBatchSource(BufferAllocator allocator, List columns, ConnectorPageSource pageSource, int maxRowsPerBatch)
+ {
+ this.pageSource = requireNonNull(pageSource, "pageSource is null");
+ this.maxRowsPerBatch = maxRowsPerBatch;
+ ImmutableList.Builder writerBuilder = ImmutableList.builder();
+ this.root = createArrowWriters(allocator, columns, writerBuilder);
+ this.writers = writerBuilder.build();
+ }
+
+ public VectorSchemaRoot getVectorSchemaRoot()
+ {
+ return root;
+ }
+
+ /**
+ * Loads the next record batch from the source.
+ * Returns false if there are no more batches from the source.
+ */
+ public boolean nextBatch()
+ {
+ // Release previous buffers
+ root.clear();
+
+ // Reserve capacity for next batch
+ allocateVectorCapacity(root, maxRowsPerBatch);
+
+ int batchRowIndex = 0;
+ while (batchRowIndex < maxRowsPerBatch) {
+ if (currentPage == null || currentPosition >= currentPage.getPositionCount()) {
+ if (pageSource.isFinished()) {
+ break;
+ }
+
+ currentPage = pageSource.getNextPage();
+ currentPosition = 0;
+
+ if (currentPage == null || currentPage.getPositionCount() == 0) {
+ continue;
+ }
+ }
+
+ for (int column = 0; column < writers.size(); column++) {
+ Block block = currentPage.getBlock(column);
+ BlockArrowWriter.ArrowVectorWriter writer = writers.get(column);
+
+ if (block.isNull(currentPosition)) {
+ writer.writeNull(batchRowIndex);
+ }
+ else {
+ writer.writeBlock(batchRowIndex, block, currentPosition);
+ }
+ }
+ currentPosition++;
+ batchRowIndex++;
+ }
+ root.setRowCount(batchRowIndex);
+ return batchRowIndex > 0;
+ }
+
+ @Override
+ public void close()
+ throws IOException
+ {
+ root.close();
+ pageSource.close();
+ }
+
+ private static void allocateVectorCapacity(VectorSchemaRoot root, int capacity)
+ {
+ for (ValueVector vector : root.getFieldVectors()) {
+ vector.setInitialCapacity(capacity);
+ AllocationHelper.allocateNew(vector, capacity);
+ }
+ }
+}
diff --git a/presto-common-arrow/src/main/java/com/facebook/plugin/arrow/BlockArrowWriter.java b/presto-common-arrow/src/main/java/com/facebook/plugin/arrow/BlockArrowWriter.java
new file mode 100644
index 0000000000000..ae49b4a1aeecc
--- /dev/null
+++ b/presto-common-arrow/src/main/java/com/facebook/plugin/arrow/BlockArrowWriter.java
@@ -0,0 +1,893 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.facebook.plugin.arrow;
+
+import com.facebook.presto.common.block.Block;
+import com.facebook.presto.common.block.IntArrayBlock;
+import com.facebook.presto.common.type.ArrayType;
+import com.facebook.presto.common.type.BigintType;
+import com.facebook.presto.common.type.BooleanType;
+import com.facebook.presto.common.type.CharType;
+import com.facebook.presto.common.type.DateType;
+import com.facebook.presto.common.type.DecimalType;
+import com.facebook.presto.common.type.DoubleType;
+import com.facebook.presto.common.type.IntegerType;
+import com.facebook.presto.common.type.MapType;
+import com.facebook.presto.common.type.RealType;
+import com.facebook.presto.common.type.RowType;
+import com.facebook.presto.common.type.SmallintType;
+import com.facebook.presto.common.type.TimeType;
+import com.facebook.presto.common.type.TimestampType;
+import com.facebook.presto.common.type.TimestampWithTimeZoneType;
+import com.facebook.presto.common.type.TinyintType;
+import com.facebook.presto.common.type.Type;
+import com.facebook.presto.common.type.VarbinaryType;
+import com.facebook.presto.spi.ColumnMetadata;
+import com.google.common.collect.ImmutableList;
+import io.airlift.slice.Slice;
+import org.apache.arrow.memory.BufferAllocator;
+import org.apache.arrow.vector.BaseVariableWidthVector;
+import org.apache.arrow.vector.BigIntVector;
+import org.apache.arrow.vector.BitVector;
+import org.apache.arrow.vector.DateDayVector;
+import org.apache.arrow.vector.DecimalVector;
+import org.apache.arrow.vector.FieldVector;
+import org.apache.arrow.vector.Float4Vector;
+import org.apache.arrow.vector.Float8Vector;
+import org.apache.arrow.vector.IntVector;
+import org.apache.arrow.vector.SmallIntVector;
+import org.apache.arrow.vector.TimeMilliVector;
+import org.apache.arrow.vector.TimeStampVector;
+import org.apache.arrow.vector.TinyIntVector;
+import org.apache.arrow.vector.VectorSchemaRoot;
+import org.apache.arrow.vector.complex.BaseRepeatedValueVector;
+import org.apache.arrow.vector.complex.ListVector;
+import org.apache.arrow.vector.complex.MapVector;
+import org.apache.arrow.vector.complex.StructVector;
+import org.apache.arrow.vector.types.Types;
+import org.apache.arrow.vector.types.pojo.ArrowType;
+import org.apache.arrow.vector.types.pojo.Field;
+import org.apache.arrow.vector.types.pojo.FieldType;
+import org.apache.arrow.vector.types.pojo.Schema;
+
+import java.math.BigDecimal;
+import java.math.BigInteger;
+import java.math.MathContext;
+import java.util.List;
+import java.util.Map;
+import java.util.concurrent.atomic.AtomicInteger;
+import java.util.stream.Collectors;
+
+import static com.facebook.plugin.arrow.ArrowErrorCode.ARROW_FLIGHT_TYPE_ERROR;
+import static com.facebook.presto.common.type.Decimals.decodeUnscaledValue;
+import static com.facebook.presto.common.type.StandardTypes.ARRAY;
+import static com.facebook.presto.common.type.StandardTypes.MAP;
+import static com.facebook.presto.common.type.StandardTypes.ROW;
+import static com.facebook.presto.common.type.Varchars.isVarcharType;
+import static com.google.common.base.Preconditions.checkArgument;
+import static com.google.common.collect.ImmutableList.toImmutableList;
+import static com.google.common.collect.ImmutableMap.toImmutableMap;
+import static java.lang.Float.intBitsToFloat;
+import static java.lang.Math.toIntExact;
+import static java.lang.String.format;
+
+public class BlockArrowWriter
+{
+ private BlockArrowWriter()
+ {
+ }
+
+ public static Field prestoToArrowField(ColumnMetadata column)
+ {
+ Field field;
+ Map metadata = column.getProperties().entrySet().stream().collect(toImmutableMap(Map.Entry::getKey, Object::toString));
+ if (column.getType().getTypeSignature().getBase().equals(ARRAY)) {
+ ArrayType arrayType = (ArrayType) column.getType();
+ Field childField = prestoToArrowField(ColumnMetadata.builder().setName(BaseRepeatedValueVector.DATA_VECTOR_NAME).setType(arrayType.getElementType()).build());
+ field = new Field(column.getName(), new FieldType(column.isNullable(), ArrowType.List.INSTANCE, null, metadata), ImmutableList.of(childField));
+ }
+ else if (column.getType().getTypeSignature().getBase().equals(MAP)) {
+ MapType mapType = (MapType) column.getType();
+ // NOTE: Arrow key type must be non-nullable
+ Field keyField = prestoToArrowField(ColumnMetadata.builder().setName(MapVector.KEY_NAME).setType(mapType.getKeyType()).setNullable(false).build());
+ Field valueField = prestoToArrowField(ColumnMetadata.builder().setName(MapVector.VALUE_NAME).setType(mapType.getValueType()).build());
+ Field entriesField = new Field(MapVector.DATA_VECTOR_NAME, FieldType.notNullable(ArrowType.Struct.INSTANCE), ImmutableList.of(keyField, valueField));
+ field = new Field(column.getName(), new FieldType(column.isNullable(), new ArrowType.Map(false), null, metadata), ImmutableList.of(entriesField));
+ }
+ else if (column.getType().getTypeSignature().getBase().equals(ROW)) {
+ RowType rowType = (RowType) column.getType();
+ List rowFields = rowType.getFields();
+
+ AtomicInteger childCount = new AtomicInteger();
+ List childFields = rowFields.stream().map(f -> prestoToArrowField(
+ ColumnMetadata.builder().setName(f.getName().orElse(format("$child%s$", childCount.incrementAndGet()))).setType(f.getType()).build()))
+ .collect(toImmutableList());
+
+ field = new Field(column.getName(), new FieldType(column.isNullable(), ArrowType.Struct.INSTANCE, null, metadata), childFields);
+ }
+ else {
+ ArrowType arrowType = prestoToArrowType(column.getType());
+ field = new Field(column.getName(), new FieldType(column.isNullable(), arrowType, null, metadata), ImmutableList.of());
+ }
+
+ return field;
+ }
+
+ public static ArrowType prestoToArrowType(Type type)
+ {
+ if (type.equals(BooleanType.BOOLEAN)) {
+ return ArrowType.Bool.INSTANCE;
+ }
+ else if (type.equals(TinyintType.TINYINT)) {
+ return Types.MinorType.TINYINT.getType();
+ }
+ else if (type.equals(SmallintType.SMALLINT)) {
+ return Types.MinorType.SMALLINT.getType();
+ }
+ else if (type.equals(IntegerType.INTEGER)) {
+ return Types.MinorType.INT.getType();
+ }
+ else if (type.equals(BigintType.BIGINT)) {
+ return Types.MinorType.BIGINT.getType();
+ }
+ else if (type.equals(RealType.REAL)) {
+ return Types.MinorType.FLOAT4.getType();
+ }
+ else if (type.equals(DoubleType.DOUBLE)) {
+ return Types.MinorType.FLOAT8.getType();
+ }
+ else if (type instanceof DecimalType) {
+ DecimalType decimalType = (DecimalType) type;
+ return new ArrowType.Decimal(decimalType.getPrecision(), decimalType.getScale(), 128);
+ }
+ else if (type.equals(VarbinaryType.VARBINARY)) {
+ return ArrowType.Binary.INSTANCE;
+ }
+ else if (type instanceof CharType) {
+ return ArrowType.Utf8.INSTANCE;
+ }
+ else if (isVarcharType(type)) {
+ return ArrowType.Utf8.INSTANCE;
+ }
+ else if (type instanceof DateType) {
+ return Types.MinorType.DATEDAY.getType();
+ }
+ else if (type instanceof TimeType) {
+ return Types.MinorType.TIMEMILLI.getType();
+ }
+ else if (type instanceof TimestampType) {
+ return Types.MinorType.TIMESTAMPMILLI.getType();
+ }
+ else if (type instanceof TimestampWithTimeZoneType) {
+ // Read as plain timestamp, timezone not supplied with type
+ return Types.MinorType.TIMESTAMPMILLI.getType();
+ }
+ throw new ArrowException(ARROW_FLIGHT_TYPE_ERROR, "Unsupported type: " + type);
+ }
+
+ public static VectorSchemaRoot createArrowWriters(BufferAllocator allocator, List columns, ImmutableList.Builder writerBuilder)
+ {
+ List fields = columns.stream().map(BlockArrowWriter::prestoToArrowField).collect(Collectors.toList());
+ Schema schema = new Schema(fields);
+ VectorSchemaRoot root = VectorSchemaRoot.create(schema, allocator);
+
+ final List vectors = root.getFieldVectors();
+ checkArgument(vectors.size() == columns.size(), "Unexpected list of vectors: %s", schema);
+ for (int i = 0; i < vectors.size(); i++) {
+ ColumnMetadata columnMetadata = columns.get(i);
+ writerBuilder.add(createArrowWriter(vectors.get(i), columnMetadata.getType()));
+ }
+
+ return root;
+ }
+
+ private static ArrowVectorWriter createArrowWriter(FieldVector vector, Type type)
+ {
+ Class> javaType = type.getJavaType();
+
+ switch (vector.getMinorType()) {
+ case BIT:
+ checkArgument(javaType == boolean.class, "Unexpected type for BitVector: %s", type);
+ return new ArrowBitWriter((BitVector) vector, new BlockPrimitiveGetter(type));
+ case TINYINT:
+ checkArgument(javaType == long.class, "Unexpected type for TinyIntVector: %s", type);
+ return new ArrowTinyIntWriter((TinyIntVector) vector, new BlockPrimitiveGetter(type));
+ case SMALLINT:
+ checkArgument(javaType == long.class, "Unexpected type for SmallIntVector: %s", type);
+ return new ArrowSmallIntWriter((SmallIntVector) vector, new BlockPrimitiveGetter(type));
+ case INT:
+ checkArgument(javaType == long.class, "Unexpected type for IntVector: %s", type);
+ return new ArrowIntWriter((IntVector) vector, new BlockPrimitiveGetter(type));
+ case BIGINT:
+ checkArgument(javaType == long.class, "Unexpected type for BigIntVector: %s", type);
+ return new ArrowLongWriter((BigIntVector) vector, new BlockPrimitiveGetter(type));
+ case FLOAT4:
+ checkArgument(javaType == long.class, "Unexpected type for Float4Vector: %s", type);
+ return new ArrowRealWriter((Float4Vector) vector, new BlockPrimitiveGetter(type));
+ case FLOAT8:
+ checkArgument(javaType == double.class, "Unexpected type for Float8Vector: %s", type);
+ return new ArrowDoubleWriter((Float8Vector) vector, new BlockPrimitiveGetter(type));
+ case DECIMAL:
+ checkArgument(type instanceof DecimalType, "Expected DecimalType but got %", type);
+ return new ArrowDecimalWriter((DecimalVector) vector, createDecimalBlockGetter((DecimalType) type, javaType));
+ case VARBINARY:
+ case VARCHAR:
+ checkArgument(javaType == Slice.class, "Unexpected type for BaseVariableWidthVector: %s", type);
+ return new ArrowVariableWidthWriter((BaseVariableWidthVector) vector, new BlockSliceGetter(type));
+ case DATEDAY:
+ checkArgument(javaType == long.class, "Unexpected type for DateDayVector: %s", type);
+ return new ArrowDateWriter((DateDayVector) vector, new BlockPrimitiveGetter(type));
+ case TIMEMILLI:
+ checkArgument(javaType == long.class, "Unexpected type for TimeMilliVector: %s", type);
+ return new ArrowTimeWriter((TimeMilliVector) vector, new BlockPrimitiveGetter(type));
+ case TIMESTAMPMILLI:
+ checkArgument(javaType == long.class, "Unexpected type for TimeStampVector: %s", type);
+ return new ArrowTimeStampWriter((TimeStampVector) vector, new BlockPrimitiveGetter(type));
+ case LIST:
+ checkArgument(type instanceof ArrayType, "Unexpected type for ListVector: %s", type);
+ return new ArrowListWriter((ListVector) vector, new BlockArrayGetter((ArrayType) type));
+ case MAP:
+ checkArgument(type instanceof MapType, "Unexpected type for MapVector: %s", type);
+ return new ArrowMapWriter((MapVector) vector, new BlockMapGetter((MapType) type));
+ case STRUCT:
+ checkArgument(type instanceof RowType, "Unexpected type for StructVector: %s", type);
+ return new ArrowStructWriter((StructVector) vector, new BlockRowGetter((RowType) type));
+ default:
+ throw new ArrowException(ARROW_FLIGHT_TYPE_ERROR, "Unsupported Arrow type: " + vector.getMinorType().name());
+ }
+ }
+
+ private static BlockValueGetter createDecimalBlockGetter(DecimalType decimalType, Class> javaType)
+ {
+ if (javaType == long.class) {
+ return new BlockShortDecimalGetter(decimalType);
+ }
+ else if (javaType == Slice.class) {
+ return new BlockLongDecimalGetter(decimalType);
+ }
+ throw new ArrowException(ARROW_FLIGHT_TYPE_ERROR, "Unexpected type for DecimalVector: " + decimalType);
+ }
+
+ private abstract static class BlockValueGetter
+ {
+ public boolean getBoolean(Block block, int position)
+ {
+ throw new UnsupportedOperationException(getClass().getName());
+ }
+
+ public int getInt(Block block, int position)
+ {
+ throw new UnsupportedOperationException(getClass().getName());
+ }
+
+ public long getLong(Block block, int position)
+ {
+ throw new UnsupportedOperationException(getClass().getName());
+ }
+
+ public float getFloat(Block block, int position)
+ {
+ throw new UnsupportedOperationException(getClass().getName());
+ }
+
+ public double getDouble(Block block, int position)
+ {
+ throw new UnsupportedOperationException(getClass().getName());
+ }
+
+ public BigDecimal getDecimal(Block block, int position)
+ {
+ throw new UnsupportedOperationException(getClass().getName());
+ }
+
+ public byte[] getBytes(Block block, int position)
+ {
+ throw new UnsupportedOperationException(getClass().getName());
+ }
+
+ public Block getChildBlock(Block block, int position)
+ {
+ throw new UnsupportedOperationException(getClass().getName());
+ }
+ }
+
+ private static class BlockPrimitiveGetter
+ extends BlockValueGetter
+ {
+ protected Type type;
+
+ public BlockPrimitiveGetter(Type type)
+ {
+ this.type = type;
+ }
+
+ @Override
+ public boolean getBoolean(Block block, int position)
+ {
+ return type.getBoolean(block, position);
+ }
+
+ @Override
+ public int getInt(Block block, int position)
+ {
+ return (int) getLong(block, position);
+ }
+
+ @Override
+ public long getLong(Block block, int position)
+ {
+ if (block instanceof IntArrayBlock) {
+ return block.toLong(position);
+ }
+
+ return type.getLong(block, position);
+ }
+
+ @Override
+ public float getFloat(Block block, int position)
+ {
+ long value = getLong(block, position);
+ return intBitsToFloat(toIntExact(value));
+ }
+
+ @Override
+ public double getDouble(Block block, int position)
+ {
+ return type.getDouble(block, position);
+ }
+ }
+
+ private static class BlockShortDecimalGetter
+ extends BlockValueGetter
+ {
+ protected DecimalType type;
+
+ public BlockShortDecimalGetter(DecimalType type)
+ {
+ this.type = type;
+ }
+
+ @Override
+ public BigDecimal getDecimal(Block block, int position)
+ {
+ long value = type.getLong(block, position);
+ BigInteger unscaledValue = BigInteger.valueOf(value);
+ return new BigDecimal(unscaledValue, type.getScale(), new MathContext(type.getPrecision()));
+ }
+ }
+
+ private static class BlockLongDecimalGetter
+ extends BlockValueGetter
+ {
+ protected DecimalType type;
+
+ public BlockLongDecimalGetter(DecimalType type)
+ {
+ this.type = type;
+ }
+
+ @Override
+ public BigDecimal getDecimal(Block block, int position)
+ {
+ Slice value = type.getSlice(block, position);
+ BigInteger unscaledValue = decodeUnscaledValue(value);
+ return new BigDecimal(unscaledValue, type.getScale(), new MathContext(type.getPrecision()));
+ }
+ }
+
+ private static class BlockSliceGetter
+ extends BlockValueGetter
+ {
+ protected Type type;
+
+ public BlockSliceGetter(Type type)
+ {
+ this.type = type;
+ }
+
+ @Override
+ public byte[] getBytes(Block block, int position)
+ {
+ Slice value = type.getSlice(block, position);
+ return value.getBytes(0, value.length());
+ }
+ }
+
+ private static class BlockArrayGetter
+ extends BlockValueGetter
+ {
+ protected ArrayType type;
+
+ public BlockArrayGetter(ArrayType type)
+ {
+ this.type = type;
+ }
+
+ public ArrayType getType()
+ {
+ return type;
+ }
+
+ @Override
+ public Block getChildBlock(Block block, int position)
+ {
+ return type.getObject(block, position);
+ }
+ }
+
+ private static class BlockMapGetter
+ extends BlockValueGetter
+ {
+ protected MapType type;
+
+ public BlockMapGetter(MapType type)
+ {
+ this.type = type;
+ }
+
+ public MapType getType()
+ {
+ return type;
+ }
+
+ @Override
+ public Block getChildBlock(Block block, int position)
+ {
+ return type.getObject(block, position);
+ }
+ }
+
+ private static class BlockRowGetter
+ extends BlockValueGetter
+ {
+ protected RowType type;
+
+ public BlockRowGetter(RowType type)
+ {
+ this.type = type;
+ }
+
+ public RowType getType()
+ {
+ return type;
+ }
+
+ @Override
+ public Block getChildBlock(Block block, int position)
+ {
+ return type.getObject(block, position);
+ }
+ }
+
+ /**
+ * Writers for specific Arrow vector types to set values from a Block.
+ * Fixed-width vectors are pre-allocated, but setSafe should be used
+ * for all other vectors.
+ */
+ public abstract static class ArrowVectorWriter
+ {
+ public abstract void writeNull(int index);
+
+ public abstract void writeBlock(int index, Block block, int position);
+ }
+
+ private static class ArrowBitWriter
+ extends ArrowVectorWriter
+ {
+ private final BitVector vector;
+ protected final BlockValueGetter blockGetter;
+
+ public ArrowBitWriter(BitVector vector, BlockValueGetter blockGetter)
+ {
+ this.vector = vector;
+ this.blockGetter = blockGetter;
+ }
+
+ @Override
+ public void writeNull(int index)
+ {
+ vector.setNull(index);
+ }
+
+ @Override
+ public void writeBlock(int index, Block block, int position)
+ {
+ vector.set(index, blockGetter.getBoolean(block, position) ? 1 : 0);
+ }
+ }
+
+ private static class ArrowTinyIntWriter
+ extends ArrowVectorWriter
+ {
+ private final TinyIntVector vector;
+ private final BlockValueGetter blockGetter;
+
+ public ArrowTinyIntWriter(TinyIntVector vector, BlockValueGetter blockGetter)
+ {
+ this.vector = vector;
+ this.blockGetter = blockGetter;
+ }
+
+ @Override
+ public void writeNull(int index)
+ {
+ vector.setNull(index);
+ }
+
+ @Override
+ public void writeBlock(int index, Block block, int position)
+ {
+ vector.set(index, blockGetter.getInt(block, position));
+ }
+ }
+
+ private static class ArrowSmallIntWriter
+ extends ArrowVectorWriter
+ {
+ private final SmallIntVector vector;
+ private final BlockValueGetter blockGetter;
+
+ public ArrowSmallIntWriter(SmallIntVector vector, BlockValueGetter blockGetter)
+ {
+ this.vector = vector;
+ this.blockGetter = blockGetter;
+ }
+
+ @Override
+ public void writeNull(int index)
+ {
+ vector.setNull(index);
+ }
+
+ @Override
+ public void writeBlock(int index, Block block, int position)
+ {
+ vector.set(index, blockGetter.getInt(block, position));
+ }
+ }
+
+ private static class ArrowIntWriter
+ extends ArrowVectorWriter
+ {
+ private final IntVector vector;
+ private final BlockValueGetter blockGetter;
+
+ public ArrowIntWriter(IntVector vector, BlockValueGetter blockGetter)
+ {
+ this.vector = vector;
+ this.blockGetter = blockGetter;
+ }
+
+ @Override
+ public void writeNull(int index)
+ {
+ vector.setNull(index);
+ }
+
+ @Override
+ public void writeBlock(int index, Block block, int position)
+ {
+ vector.set(index, blockGetter.getInt(block, position));
+ }
+ }
+
+ private static class ArrowLongWriter
+ extends ArrowVectorWriter
+ {
+ private final BigIntVector vector;
+ private final BlockValueGetter blockGetter;
+
+ public ArrowLongWriter(BigIntVector vector, BlockValueGetter blockGetter)
+ {
+ this.vector = vector;
+ this.blockGetter = blockGetter;
+ }
+
+ @Override
+ public void writeNull(int index)
+ {
+ vector.setNull(index);
+ }
+
+ @Override
+ public void writeBlock(int index, Block block, int position)
+ {
+ vector.set(index, blockGetter.getLong(block, position));
+ }
+ }
+
+ private static class ArrowRealWriter
+ extends ArrowVectorWriter
+ {
+ private final Float4Vector vector;
+ private final BlockValueGetter blockGetter;
+
+ public ArrowRealWriter(Float4Vector vector, BlockValueGetter blockGetter)
+ {
+ this.vector = vector;
+ this.blockGetter = blockGetter;
+ }
+
+ @Override
+ public void writeNull(int index)
+ {
+ vector.setNull(index);
+ }
+
+ @Override
+ public void writeBlock(int index, Block block, int position)
+ {
+ vector.set(index, blockGetter.getFloat(block, position));
+ }
+ }
+
+ private static class ArrowDoubleWriter
+ extends ArrowVectorWriter
+ {
+ private final Float8Vector vector;
+ private final BlockValueGetter blockGetter;
+
+ public ArrowDoubleWriter(Float8Vector vector, BlockValueGetter blockGetter)
+ {
+ this.vector = vector;
+ this.blockGetter = blockGetter;
+ }
+
+ @Override
+ public void writeNull(int index)
+ {
+ vector.setNull(index);
+ }
+
+ @Override
+ public void writeBlock(int index, Block block, int position)
+ {
+ vector.set(index, blockGetter.getDouble(block, position));
+ }
+ }
+
+ private static class ArrowDecimalWriter
+ extends ArrowVectorWriter
+ {
+ private final DecimalVector vector;
+ private final BlockValueGetter blockGetter;
+
+ public ArrowDecimalWriter(DecimalVector vector, BlockValueGetter blockGetter)
+ {
+ this.vector = vector;
+ this.blockGetter = blockGetter;
+ }
+
+ @Override
+ public void writeNull(int index)
+ {
+ vector.setNull(index);
+ }
+
+ @Override
+ public void writeBlock(int index, Block block, int position)
+ {
+ vector.set(index, blockGetter.getDecimal(block, position));
+ }
+ }
+
+ private static class ArrowVariableWidthWriter
+ extends ArrowVectorWriter
+ {
+ private final BaseVariableWidthVector vector;
+ private final BlockValueGetter blockGetter;
+
+ public ArrowVariableWidthWriter(BaseVariableWidthVector vector, BlockValueGetter blockGetter)
+ {
+ this.vector = vector;
+ this.blockGetter = blockGetter;
+ }
+
+ @Override
+ public void writeNull(int index)
+ {
+ vector.setNull(index);
+ }
+
+ @Override
+ public void writeBlock(int index, Block block, int position)
+ {
+ vector.setSafe(index, blockGetter.getBytes(block, position));
+ }
+ }
+
+ private static class ArrowDateWriter
+ extends ArrowVectorWriter
+ {
+ private final DateDayVector vector;
+ private final BlockValueGetter blockGetter;
+
+ public ArrowDateWriter(DateDayVector vector, BlockValueGetter blockGetter)
+ {
+ this.vector = vector;
+ this.blockGetter = blockGetter;
+ }
+
+ @Override
+ public void writeNull(int index)
+ {
+ vector.setNull(index);
+ }
+
+ @Override
+ public void writeBlock(int index, Block block, int position)
+ {
+ // Value is supplied as days since epoch UTC
+ vector.set(index, blockGetter.getInt(block, position));
+ }
+ }
+
+ private static class ArrowTimeWriter
+ extends ArrowVectorWriter
+ {
+ private final TimeMilliVector vector;
+ private final BlockValueGetter blockGetter;
+
+ public ArrowTimeWriter(TimeMilliVector vector, BlockValueGetter blockGetter)
+ {
+ this.vector = vector;
+ this.blockGetter = blockGetter;
+ }
+
+ @Override
+ public void writeNull(int index)
+ {
+ vector.setNull(index);
+ }
+
+ @Override
+ public void writeBlock(int index, Block block, int position)
+ {
+ vector.set(index, blockGetter.getInt(block, position));
+ }
+ }
+
+ private static class ArrowTimeStampWriter
+ extends ArrowVectorWriter
+ {
+ private final TimeStampVector vector;
+ private final BlockValueGetter blockGetter;
+
+ public ArrowTimeStampWriter(TimeStampVector vector, BlockValueGetter blockGetter)
+ {
+ this.vector = vector;
+ this.blockGetter = blockGetter;
+ }
+
+ @Override
+ public void writeNull(int index)
+ {
+ vector.setNull(index);
+ }
+
+ @Override
+ public void writeBlock(int index, Block block, int position)
+ {
+ vector.set(index, blockGetter.getLong(block, position));
+ }
+ }
+
+ private static class ArrowListWriter
+ extends ArrowVectorWriter
+ {
+ private final ListVector vector;
+ private final BlockArrayGetter blockArrayGetter;
+ private final ArrowVectorWriter childWriter;
+
+ public ArrowListWriter(ListVector vector, BlockArrayGetter blockArrayGetter)
+ {
+ this.vector = vector;
+ this.blockArrayGetter = blockArrayGetter;
+ this.childWriter = createArrowWriter(vector.getDataVector(), blockArrayGetter.getType().getElementType());
+ }
+
+ @Override
+ public void writeNull(int index)
+ {
+ vector.setNull(index);
+ }
+
+ @Override
+ public void writeBlock(int index, Block block, int position)
+ {
+ int dataIndex = vector.startNewValue(index);
+ Block elementBlock = blockArrayGetter.getChildBlock(block, position);
+ for (int i = 0; i < elementBlock.getPositionCount(); ++i) {
+ childWriter.writeBlock(dataIndex + i, elementBlock, i);
+ }
+ vector.endValue(index, elementBlock.getPositionCount());
+ }
+ }
+
+ private static class ArrowMapWriter
+ extends ArrowVectorWriter
+ {
+ private final MapVector vector;
+ private final StructVector structVector;
+ private final BlockMapGetter blockMapGetter;
+ private final ArrowVectorWriter keyWriter;
+ private final ArrowVectorWriter valueWriter;
+
+ public ArrowMapWriter(MapVector vector, BlockMapGetter blockMapGetter)
+ {
+ this.vector = vector;
+ this.structVector = (StructVector) vector.getDataVector();
+ this.blockMapGetter = blockMapGetter;
+ this.keyWriter = createArrowWriter((FieldVector) structVector.getChildByOrdinal(0), blockMapGetter.getType().getKeyType());
+ this.valueWriter = createArrowWriter((FieldVector) structVector.getChildByOrdinal(1), blockMapGetter.getType().getValueType());
+ }
+
+ @Override
+ public void writeNull(int index)
+ {
+ vector.setNull(index);
+ }
+
+ @Override
+ public void writeBlock(int index, Block block, int position)
+ {
+ Block singleMapBlock = blockMapGetter.getChildBlock(block, position);
+ int dataIndex = vector.startNewValue(index);
+ int numPairs = singleMapBlock.getPositionCount() / 2;
+ for (int i = 0; i < numPairs; ++i) {
+ keyWriter.writeBlock(dataIndex + i, singleMapBlock, i * 2);
+ valueWriter.writeBlock(dataIndex + i, singleMapBlock, (i * 2) + 1);
+ structVector.setIndexDefined(dataIndex + i);
+ }
+ vector.endValue(index, numPairs);
+ }
+ }
+
+ private static class ArrowStructWriter
+ extends ArrowVectorWriter
+ {
+ private final StructVector vector;
+ private final BlockRowGetter blockRowGetter;
+ private final List childWriters;
+
+ public ArrowStructWriter(StructVector vector, BlockRowGetter blockRowGetter)
+ {
+ this.vector = vector;
+ this.blockRowGetter = blockRowGetter;
+ ImmutableList.Builder writerBuilder = ImmutableList.builder();
+ List fieldVectors = vector.getChildrenFromFields();
+ for (int i = 0; i < fieldVectors.size(); i++) {
+ Type childType = blockRowGetter.getType().getTypeParameters().get(i);
+ writerBuilder.add(createArrowWriter(fieldVectors.get(i), childType));
+ }
+ this.childWriters = writerBuilder.build();
+ }
+
+ @Override
+ public void writeNull(int index)
+ {
+ vector.setNull(index);
+ }
+
+ @Override
+ public void writeBlock(int index, Block block, int position)
+ {
+ Block singleRowBlock = blockRowGetter.getChildBlock(block, position);
+ for (int i = 0; i < childWriters.size(); ++i) {
+ childWriters.get(i).writeBlock(index, singleRowBlock, i);
+ }
+ vector.setIndexDefined(index);
+ }
+ }
+}
diff --git a/presto-common-arrow/src/test/java/com/facebook/plugin/arrow/TestArrowBatchSource.java b/presto-common-arrow/src/test/java/com/facebook/plugin/arrow/TestArrowBatchSource.java
new file mode 100644
index 0000000000000..a922c9f00c0a9
--- /dev/null
+++ b/presto-common-arrow/src/test/java/com/facebook/plugin/arrow/TestArrowBatchSource.java
@@ -0,0 +1,452 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.facebook.plugin.arrow;
+
+import com.facebook.presto.common.Page;
+import com.facebook.presto.common.block.Block;
+import com.facebook.presto.common.type.Type;
+import com.facebook.presto.common.type.TypeManager;
+import com.facebook.presto.spi.ColumnMetadata;
+import com.facebook.presto.spi.ConnectorPageSource;
+import com.google.common.collect.ImmutableList;
+import org.apache.arrow.memory.BufferAllocator;
+import org.apache.arrow.memory.RootAllocator;
+import org.apache.arrow.vector.BigIntVector;
+import org.apache.arrow.vector.BitVector;
+import org.apache.arrow.vector.DateDayVector;
+import org.apache.arrow.vector.FieldVector;
+import org.apache.arrow.vector.Float4Vector;
+import org.apache.arrow.vector.Float8Vector;
+import org.apache.arrow.vector.IntVector;
+import org.apache.arrow.vector.SmallIntVector;
+import org.apache.arrow.vector.TimeMilliVector;
+import org.apache.arrow.vector.TimeStampMilliVector;
+import org.apache.arrow.vector.TinyIntVector;
+import org.apache.arrow.vector.VarCharVector;
+import org.apache.arrow.vector.VectorSchemaRoot;
+import org.apache.arrow.vector.complex.ListVector;
+import org.apache.arrow.vector.complex.MapVector;
+import org.apache.arrow.vector.complex.StructVector;
+import org.apache.arrow.vector.complex.impl.UnionListWriter;
+import org.apache.arrow.vector.complex.impl.UnionMapWriter;
+import org.apache.arrow.vector.types.Types;
+import org.apache.arrow.vector.types.pojo.ArrowType;
+import org.apache.arrow.vector.types.pojo.Field;
+import org.apache.arrow.vector.types.pojo.FieldType;
+import org.apache.arrow.vector.util.Text;
+import org.testng.annotations.AfterClass;
+import org.testng.annotations.BeforeClass;
+import org.testng.annotations.Test;
+
+import java.io.IOException;
+import java.time.LocalDateTime;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+import java.util.concurrent.TimeUnit;
+
+import static com.facebook.presto.testing.TestingEnvironment.FUNCTION_AND_TYPE_MANAGER;
+import static com.facebook.presto.util.DateTimeUtils.parseTimestampWithoutTimeZone;
+import static org.testng.Assert.assertEquals;
+import static org.testng.Assert.assertFalse;
+import static org.testng.Assert.assertTrue;
+
+public class TestArrowBatchSource
+{
+ private static final int MAX_ROWS_PER_BATCH = 1000;
+ private BufferAllocator allocator;
+ private TestArrowBlockBuilder arrowBlockBuilder;
+
+ @BeforeClass
+ public void setUp()
+ {
+ // Initialize the Arrow allocator
+ allocator = new RootAllocator(Integer.MAX_VALUE);
+ arrowBlockBuilder = new TestArrowBlockBuilder(FUNCTION_AND_TYPE_MANAGER);
+ }
+
+ @AfterClass
+ public void tearDown()
+ {
+ allocator.close();
+ }
+
+ @Test
+ public void testPrimitiveTypes()
+ throws IOException
+ {
+ try (BitVector bitVector = new BitVector("bitVector", allocator);
+ TinyIntVector tinyIntVector = new TinyIntVector("tinyIntVector", allocator);
+ SmallIntVector smallIntVector = new SmallIntVector("smallIntVector", allocator);
+ IntVector intVector = new IntVector("intVector", allocator);
+ BigIntVector longVector = new BigIntVector("bigIntVector", allocator);
+ Float4Vector floatVector = new Float4Vector("floatVector", allocator);
+ Float8Vector doubleVector = new Float8Vector("doubleVector", allocator);
+ VectorSchemaRoot expectedRoot = new VectorSchemaRoot(Arrays.asList(bitVector, tinyIntVector, smallIntVector, intVector, longVector, floatVector, doubleVector))) {
+ final int numPages = 5;
+ final int numValues = 10;
+ intVector.allocateNew(numValues);
+
+ for (int i = 0; i < numValues; i++) {
+ bitVector.setSafe(i, i % 2);
+ tinyIntVector.setSafe(i, i);
+ smallIntVector.setSafe(i, i);
+ intVector.setSafe(i, i);
+ longVector.setSafe(i, i);
+ floatVector.setSafe(i, i * 1.1f);
+ doubleVector.setSafe(i, i * 1.1);
+ }
+
+ // Add null values
+ bitVector.setNull(2);
+ tinyIntVector.setNull(3);
+ smallIntVector.setNull(4);
+ intVector.setNull(5);
+ longVector.setNull(6);
+ floatVector.setNull(7);
+ doubleVector.setNull(8);
+
+ expectedRoot.setRowCount(numValues);
+
+ TestArrowPageSource pageSource = TestArrowPageSource.create(arrowBlockBuilder, expectedRoot, 5);
+
+ // Set for 1 batch per page
+ int batchCount = 0;
+ try (ArrowBatchSource arrowBatchSource = new ArrowBatchSource(allocator, pageSource.getColumns(), pageSource, numValues)) {
+ while (arrowBatchSource.nextBatch()) {
+ assertTrue(expectedRoot.equals(arrowBatchSource.getVectorSchemaRoot()));
+ ++batchCount;
+ }
+ }
+ assertEquals(batchCount, numPages);
+ }
+ }
+
+ @Test
+ public void testDateTimeTypes()
+ throws IOException
+ {
+ try (IntVector intVector = new IntVector("id", allocator);
+ DateDayVector dateVector = new DateDayVector("date", allocator);
+ TimeMilliVector timeVector = new TimeMilliVector("time", allocator);
+ TimeStampMilliVector timestampVector = new TimeStampMilliVector("timestamp", allocator);
+ VectorSchemaRoot expectedRoot = new VectorSchemaRoot(Arrays.asList(intVector, dateVector, timeVector, timestampVector))) {
+ List values = ImmutableList.of(
+ "1970-01-01T00:00:00",
+ "2024-01-01T01:01:01",
+ "2024-01-02T12:00:00",
+ "2112-12-31T23:58:00",
+ "1968-07-05T08:15:12.345");
+
+ for (int i = 0; i < values.size(); i++) {
+ intVector.setSafe(i, i);
+ LocalDateTime dateTime = LocalDateTime.parse(values.get(i));
+ // First vector value is explicitly set to 0 to ensure no issues with parsing
+ dateVector.setSafe(i, i == 0 ? 0 : (int) dateTime.toLocalDate().toEpochDay());
+ timeVector.setSafe(i, i == 0 ? 0 : (int) TimeUnit.NANOSECONDS.toMillis(dateTime.toLocalTime().toNanoOfDay()));
+ timestampVector.setSafe(i, i == 0 ? 0 : parseTimestampWithoutTimeZone(values.get(i).replace("T", " ")));
+ }
+
+ // Add null values at last row
+ dateVector.setNull(values.size());
+ timeVector.setNull(values.size());
+ timestampVector.setNull(values.size());
+
+ expectedRoot.setRowCount(values.size() + 1);
+
+ TestArrowPageSource pageSource = TestArrowPageSource.create(arrowBlockBuilder, expectedRoot, 1);
+
+ try (ArrowBatchSource arrowBatchSource = new ArrowBatchSource(allocator, pageSource.getColumns(), pageSource, MAX_ROWS_PER_BATCH)) {
+ assertTrue(arrowBatchSource.nextBatch());
+ assertTrue(expectedRoot.equals(arrowBatchSource.getVectorSchemaRoot()));
+ assertFalse(arrowBatchSource.nextBatch());
+ }
+ }
+ }
+
+ @Test
+ public void testVarCharType()
+ throws IOException
+ {
+ final int numValues = 100;
+ final int maxRowsPerBatch = 33;
+
+ try (IntVector intVector = new IntVector("id", allocator);
+ VarCharVector stringVector = new VarCharVector("varchar", allocator);
+ VectorSchemaRoot expectedRoot = new VectorSchemaRoot(Arrays.asList(intVector, stringVector))) {
+ final String stringData = "abcdefghijklmnopqrstuvwxyz";
+ for (int i = 0; i < numValues; i++) {
+ intVector.setSafe(i, i);
+ String value = stringData.substring(0, i % stringData.length());
+ stringVector.setSafe(i, new Text(value));
+ }
+
+ // Add null values
+ stringVector.setNull(2);
+
+ expectedRoot.setRowCount(numValues);
+
+ TestArrowPageSource pageSource = TestArrowPageSource.create(arrowBlockBuilder, expectedRoot, 1);
+
+ int actualRowCount = 0;
+ try (ArrowBatchSource arrowBatchSource = new ArrowBatchSource(allocator, pageSource.getColumns(), pageSource, maxRowsPerBatch)) {
+ while (arrowBatchSource.nextBatch()) {
+ try (VectorSchemaRoot expectedSlice = expectedRoot.slice(actualRowCount, Math.min(numValues - actualRowCount, maxRowsPerBatch))) {
+ assertTrue(expectedSlice.equals(arrowBatchSource.getVectorSchemaRoot()));
+ }
+ actualRowCount += arrowBatchSource.getVectorSchemaRoot().getRowCount();
+ }
+ assertEquals(actualRowCount, numValues);
+ }
+ }
+ }
+
+ @Test
+ public void testArrayType()
+ throws IOException
+ {
+ try (IntVector intVector = new IntVector("id", allocator);
+ ListVector listVectorInt = ListVector.empty("array-int", allocator);
+ ListVector listVectorVarchar = ListVector.empty("array-varchar", allocator)) {
+ // Add the element vectors
+ listVectorInt.addOrGetVector(FieldType.nullable(Types.MinorType.INT.getType()));
+ listVectorVarchar.addOrGetVector(FieldType.nullable(Types.MinorType.VARCHAR.getType()));
+ listVectorInt.allocateNew();
+ listVectorVarchar.allocateNew();
+
+ final int numValues = 10;
+ final String stringData = "abcdefghijklmnopqrstuvwxyz";
+ final UnionListWriter writerInt = listVectorInt.getWriter();
+ final UnionListWriter writerVarchar = listVectorVarchar.getWriter();
+ for (int i = 0; i < numValues; i++) {
+ intVector.setSafe(i, i);
+ writerInt.setPosition(i);
+ // Need to set nulls during write for lists
+ if (i == 5) {
+ writerInt.writeNull();
+ writerVarchar.writeNull();
+ }
+ else {
+ writerInt.startList();
+ writerVarchar.startList();
+ for (int j = 0; j < i + i; j++) {
+ writerInt.integer().writeInt(i * j);
+ String stringValue = stringData.substring(0, i % stringData.length());
+ writerVarchar.writeVarChar(new Text(stringValue));
+ }
+ }
+ writerInt.endList();
+ writerVarchar.endList();
+ }
+
+ try (VectorSchemaRoot expectedRoot = new VectorSchemaRoot(Arrays.asList(intVector, listVectorInt, listVectorVarchar))) {
+ expectedRoot.setRowCount(numValues);
+
+ TestArrowPageSource pageSource = TestArrowPageSource.create(arrowBlockBuilder, expectedRoot, 1);
+
+ try (ArrowBatchSource arrowBatchSource = new ArrowBatchSource(allocator, pageSource.getColumns(), pageSource, MAX_ROWS_PER_BATCH)) {
+ assertTrue(arrowBatchSource.nextBatch());
+ assertTrue(expectedRoot.equals(arrowBatchSource.getVectorSchemaRoot()));
+ assertFalse(arrowBatchSource.nextBatch());
+ }
+ }
+ }
+ }
+
+ @Test
+ void testMapType()
+ throws IOException
+ {
+ try (IntVector intVector = new IntVector("id", allocator);
+ MapVector mapLongVector = MapVector.empty("map-long-long", allocator, false);
+ MapVector mapVarcharVector = MapVector.empty("map-long-varchar", allocator, false)) {
+ UnionMapWriter mapLongWriter = mapLongVector.getWriter();
+ UnionMapWriter mapVarcharWriter = mapVarcharVector.getWriter();
+ mapLongWriter.allocate();
+ mapVarcharWriter.allocate();
+
+ final int numValues = 10;
+ final String stringData = "abcdefghijklmnopqrstuvwxyz";
+ for (int i = 0; i < numValues; i++) {
+ intVector.setSafe(i, i);
+ mapLongWriter.setPosition(i);
+ mapLongWriter.startMap();
+
+ mapVarcharWriter.setPosition(i);
+ mapVarcharWriter.startMap();
+
+ for (int j = 0; j < i; j++) {
+ mapLongWriter.startEntry();
+ mapLongWriter.key().bigInt().writeBigInt(j);
+ mapLongWriter.value().bigInt().writeBigInt(i * j);
+ mapLongWriter.endEntry();
+ mapVarcharWriter.startEntry();
+ mapVarcharWriter.key().bigInt().writeBigInt(j * j);
+ String stringValue = stringData.substring(0, i % stringData.length());
+ mapVarcharWriter.value().varChar().writeVarChar(new Text(stringValue));
+ mapVarcharWriter.endEntry();
+ }
+ mapLongWriter.endMap();
+ mapVarcharWriter.endMap();
+ }
+
+ try (VectorSchemaRoot expectedRoot = new VectorSchemaRoot(Arrays.asList(intVector, mapLongVector, mapVarcharVector))) {
+ expectedRoot.setRowCount(numValues);
+
+ TestArrowPageSource pageSource = TestArrowPageSource.create(arrowBlockBuilder, expectedRoot, 1);
+
+ try (ArrowBatchSource arrowBatchSource = new ArrowBatchSource(allocator, pageSource.getColumns(), pageSource, MAX_ROWS_PER_BATCH)) {
+ assertTrue(arrowBatchSource.nextBatch());
+ assertTrue(expectedRoot.equals(arrowBatchSource.getVectorSchemaRoot()));
+ assertFalse(arrowBatchSource.nextBatch());
+ }
+ }
+ }
+ }
+
+ @Test
+ void testRowType()
+ throws IOException
+ {
+ try (IntVector intVector = new IntVector("id", allocator);
+ StructVector structVector = StructVector.empty("struct", allocator)) {
+ final BigIntVector childLongVector
+ = structVector.addOrGet("long", FieldType.nullable(new ArrowType.Int(64, true)), BigIntVector.class);
+ final VarCharVector childVarcharVector
+ = structVector.addOrGet("varchar", FieldType.nullable(ArrowType.Utf8.INSTANCE), VarCharVector.class);
+ childLongVector.allocateNew();
+ childVarcharVector.allocateNew();
+
+ final int numValues = 10;
+ final String stringData = "abcdefghijklmnopqrstuvwxyz";
+ for (int i = 0; i < numValues; i++) {
+ intVector.setSafe(i, i);
+ childLongVector.setSafe(i, i * i);
+ String stringValue = stringData.substring(0, i % stringData.length());
+ childVarcharVector.setSafe(i, new Text(stringValue));
+ structVector.setIndexDefined(i);
+ }
+
+ try (VectorSchemaRoot expectedRoot = new VectorSchemaRoot(Arrays.asList(intVector, structVector))) {
+ expectedRoot.setRowCount(numValues);
+
+ TestArrowPageSource pageSource = TestArrowPageSource.create(arrowBlockBuilder, expectedRoot, 1);
+
+ try (ArrowBatchSource arrowBatchSource = new ArrowBatchSource(allocator, pageSource.getColumns(), pageSource, MAX_ROWS_PER_BATCH)) {
+ assertTrue(arrowBatchSource.nextBatch());
+ assertTrue(expectedRoot.equals(arrowBatchSource.getVectorSchemaRoot()));
+ assertFalse(arrowBatchSource.nextBatch());
+ }
+ }
+ }
+ }
+
+ private static class TestArrowBlockBuilder
+ extends ArrowBlockBuilder
+ {
+ public TestArrowBlockBuilder(TypeManager typeManager)
+ {
+ super(typeManager);
+ }
+
+ public Type getPrestoType(Field field)
+ {
+ return getPrestoTypeFromArrowField(field);
+ }
+ }
+
+ private static class TestArrowPageSource
+ implements ConnectorPageSource
+ {
+ private final int rowCount;
+ private final List blocks;
+ private final List columns;
+ private int pagesRemaining;
+
+ private TestArrowPageSource(int rowCount, List blocks, List columns, int numPages)
+ {
+ this.rowCount = rowCount;
+ this.blocks = blocks;
+ this.columns = columns;
+ this.pagesRemaining = numPages;
+ }
+
+ public static TestArrowPageSource create(TestArrowBlockBuilder arrowBlockBuilder, VectorSchemaRoot root, int numPages)
+ {
+ List blocks = new ArrayList<>();
+ List columns = new ArrayList<>();
+ for (FieldVector vector : root.getFieldVectors()) {
+ ColumnMetadata column = ColumnMetadata.builder()
+ .setName(vector.getName())
+ .setType(arrowBlockBuilder.getPrestoType(vector.getField()))
+ .setNullable(vector.getField().isNullable()).build();
+ Block block = arrowBlockBuilder.buildBlockFromFieldVector(vector, column.getType(), null);
+ blocks.add(block);
+ columns.add(column);
+ }
+
+ return new TestArrowPageSource(root.getRowCount(), blocks, columns, numPages);
+ }
+
+ public List getColumns()
+ {
+ return columns;
+ }
+
+ @Override
+ public long getCompletedBytes()
+ {
+ return 0;
+ }
+
+ @Override
+ public long getCompletedPositions()
+ {
+ return 0;
+ }
+
+ @Override
+ public long getReadTimeNanos()
+ {
+ return 0;
+ }
+
+ @Override
+ public boolean isFinished()
+ {
+ return pagesRemaining == 0;
+ }
+
+ @Override
+ public long getSystemMemoryUsage()
+ {
+ return 0;
+ }
+
+ @Override
+ public Page getNextPage()
+ {
+ if (pagesRemaining == 0) {
+ return null;
+ }
+ --pagesRemaining;
+ return new Page(rowCount, blocks.toArray(new Block[0]));
+ }
+
+ @Override
+ public void close()
+ {
+ }
+ }
+}
diff --git a/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestArrowBlockBuilder.java b/presto-common-arrow/src/test/java/com/facebook/plugin/arrow/TestArrowBlockBuilder.java
similarity index 100%
rename from presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestArrowBlockBuilder.java
rename to presto-common-arrow/src/test/java/com/facebook/plugin/arrow/TestArrowBlockBuilder.java
diff --git a/presto-flight-shim/etc/flightshim.properties b/presto-flight-shim/etc/flightshim.properties
new file mode 100644
index 0000000000000..a86cb845c4685
--- /dev/null
+++ b/presto-flight-shim/etc/flightshim.properties
@@ -0,0 +1,10 @@
+flight-shim.server=localhost
+flight-shim.server.port=9999
+flight-shim.server-ssl-certificate-file=src/test/resources/certs/server.crt
+flight-shim.server-ssl-key-file=src/test/resources/certs/server.key
+
+plugin.bundles=\
+ ../presto-oracle/pom.xml,\
+ ../presto-postgresql/pom.xml,\
+ ../presto-mysql/pom.xml,\
+ ../presto-singlestore/pom.xml
diff --git a/presto-flight-shim/etc/log.properties b/presto-flight-shim/etc/log.properties
new file mode 100644
index 0000000000000..9b50e185c74ee
--- /dev/null
+++ b/presto-flight-shim/etc/log.properties
@@ -0,0 +1,13 @@
+#
+# WARNING
+# ^^^^^^^
+# This configuration file is for development only and should NOT be used be
+# used in production. For example configuration, see the Presto documentation.
+#
+com.=INFO
+com.facebook.presto.execution=DEBUG
+com.facebook.presto=INFO
+com.sun.jersey.guice.spi.container.GuiceComponentProviderFactory=WARN
+com.ning.http.client=WARN
+com.facebook.presto.server.PluginManager=DEBUG
+com.facebook.presto.flightshim=DEBUG
diff --git a/presto-flight-shim/pom.xml b/presto-flight-shim/pom.xml
new file mode 100644
index 0000000000000..83cf59ec6704e
--- /dev/null
+++ b/presto-flight-shim/pom.xml
@@ -0,0 +1,425 @@
+
+
+ 4.0.0
+
+
+ com.facebook.presto
+ presto-root
+ 0.297-SNAPSHOT
+
+
+ presto-flight-shim
+ presto-flight-shim
+ Presto - Shim with Apache Arrow Flight server for Presto native federation
+
+
+ ${project.parent.basedir}
+ -Xss10M
+ true
+
+
+
+
+
+ javax.annotation
+ javax.annotation-api
+ 1.3.2
+
+
+
+
+
+
+ com.facebook.presto
+ presto-common-arrow
+
+
+
+ org.apache.arrow
+ arrow-memory-core
+
+
+ org.slf4j
+ slf4j-api
+
+
+
+
+
+ org.apache.arrow
+ arrow-memory-netty
+ runtime
+
+
+ org.slf4j
+ slf4j-api
+
+
+
+
+
+ org.apache.arrow
+ arrow-vector
+
+
+ org.slf4j
+ slf4j-api
+
+
+
+ com.fasterxml.jackson.datatype
+ jackson-datatype-jsr310
+
+
+
+
+
+ org.apache.arrow
+ flight-core
+
+
+ org.slf4j
+ slf4j-api
+
+
+
+
+
+ com.facebook.airlift
+ bootstrap
+
+
+
+ com.facebook.airlift
+ log
+
+
+
+ com.google.guava
+ guava
+
+
+
+ jakarta.inject
+ jakarta.inject-api
+
+
+
+ javax.inject
+ javax.inject
+
+
+
+ com.facebook.presto
+ presto-spi
+
+
+
+ com.facebook.airlift
+ json
+
+
+
+ com.fasterxml.jackson.core
+ jackson-annotations
+
+
+
+ com.facebook.presto
+ presto-common
+
+
+
+ com.fasterxml.jackson.core
+ jackson-core
+
+
+
+ com.fasterxml.jackson.core
+ jackson-databind
+
+
+
+ jakarta.annotation
+ jakarta.annotation-api
+
+
+
+ jakarta.validation
+ jakarta.validation-api
+
+
+
+ com.google.inject
+ guice
+
+
+
+ com.google.http-client
+ google-http-client
+ 1.47.1
+
+
+
+ com.facebook.airlift
+ configuration
+
+
+
+ com.facebook.airlift
+ concurrent
+
+
+
+ com.facebook.presto
+ presto-main-base
+
+
+
+ com.facebook.presto
+ presto-base-jdbc
+
+
+
+ com.facebook.presto
+ presto-analyzer
+
+
+
+ org.weakref
+ jmxutils
+
+
+
+ io.airlift.resolver
+ resolver
+
+
+
+
+
+ org.jdbi
+ jdbi3-core
+ test
+
+
+
+ org.testng
+ testng
+ test
+
+
+
+ io.airlift.tpch
+ tpch
+ test
+
+
+
+ com.facebook.presto
+ presto-tpch
+ test
+
+
+
+ com.facebook.presto
+ presto-testng-services
+ test
+
+
+
+ com.facebook.airlift
+ testing
+ test
+
+
+
+ com.facebook.presto
+ presto-tests
+ test
+
+
+
+ com.facebook.presto
+ presto-main-base
+ test-jar
+ test
+
+
+
+ com.facebook.presto
+ presto-base-arrow-flight
+ test
+
+
+
+ com.h2database
+ h2
+ test
+
+
+
+ org.testcontainers
+ testcontainers
+ test
+
+
+ org.slf4j
+ slf4j-api
+
+
+ org.jetbrains
+ annotations
+
+
+
+
+
+ org.testcontainers
+ testcontainers-postgresql
+ test
+
+
+
+ com.facebook.presto
+ presto-postgresql
+ ${project.version}
+ test
+
+
+
+ com.facebook.presto
+ presto-postgresql
+ ${project.version}
+ test-jar
+ test
+
+
+
+ org.testcontainers
+ testcontainers-mysql
+ test
+
+
+
+ com.facebook.presto
+ presto-mysql
+ ${project.version}
+ test
+
+
+
+ com.facebook.presto
+ presto-mysql
+ ${project.version}
+ test-jar
+ test
+
+
+
+ com.facebook.presto
+ presto-oracle
+ ${project.version}
+ test
+
+
+
+ com.facebook.presto
+ presto-oracle
+ ${project.version}
+ test-jar
+ test
+
+
+
+ org.testcontainers
+ oracle-xe
+ test
+ 1.14.3
+
+
+ org.slf4j
+ slf4j-api
+
+
+ org.jetbrains
+ annotations
+
+
+
+
+
+ com.facebook.presto
+ presto-singlestore
+ ${project.version}
+ test
+
+
+
+ com.facebook.presto
+ presto-singlestore
+ ${project.version}
+ test-jar
+ test
+
+
+
+
+
+
+ org.apache.maven.plugins
+ maven-enforcer-plugin
+
+
+ org.apache.maven.plugins
+ maven-dependency-plugin
+
+
+ com.fasterxml.jackson.core:jackson-databind
+ com.facebook.airlift:log-manager
+ javax.inject:javax.inject
+ com.facebook.presto:presto-base-jdbc
+ com.google.errorprone:error_prone_annotations
+
+
+
+
+ org.basepom.maven
+ duplicate-finder-maven-plugin
+ 1.2.1
+
+
+ module-info
+ META-INF.versions.9.module-info
+
+
+ arrow-git.properties
+ about.html
+
+
+
+
+
+ check
+
+
+
+
+
+
+
+
+
+ java17
+
+ 17
+
+
+
+ -Xss10M
+ --add-opens=java.base/java.nio=org.apache.arrow.memory.core,ALL-UNNAMED
+
+
+
+
+
diff --git a/presto-flight-shim/src/main/java/com/facebook/presto/flightshim/FlightShimConfig.java b/presto-flight-shim/src/main/java/com/facebook/presto/flightshim/FlightShimConfig.java
new file mode 100644
index 0000000000000..9aeff050e2077
--- /dev/null
+++ b/presto-flight-shim/src/main/java/com/facebook/presto/flightshim/FlightShimConfig.java
@@ -0,0 +1,150 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.facebook.presto.flightshim;
+
+import com.facebook.airlift.configuration.Config;
+import com.facebook.airlift.configuration.ConfigDescription;
+import jakarta.validation.constraints.Max;
+import jakarta.validation.constraints.Min;
+
+public class FlightShimConfig
+{
+ public static final String CONFIG_PREFIX = "flight-shim";
+ private static final int MAX_ROWS_PER_BATCH_DEFAULT = 10000;
+ private static final int READ_THREAD_POOL_SIZE = 16;
+ private String serverName;
+ private Integer serverPort;
+ private String serverSSLCertificateFile;
+ private String serverSSLKeyFile;
+ private String clientSSLCertificateFile;
+ private String clientSSLKeyFile;
+ private boolean serverSslEnabled = true;
+ private int maxRowsPerBatch = MAX_ROWS_PER_BATCH_DEFAULT;
+ private int readThreadPoolSize = READ_THREAD_POOL_SIZE;
+
+ public String getServerName()
+ {
+ return serverName;
+ }
+
+ @Config("server")
+ public FlightShimConfig setServerName(String serverName)
+ {
+ this.serverName = serverName;
+ return this;
+ }
+
+ public Integer getServerPort()
+ {
+ return serverPort;
+ }
+
+ @Config("server.port")
+ public FlightShimConfig setServerPort(Integer serverPort)
+ {
+ this.serverPort = serverPort;
+ return this;
+ }
+
+ public boolean getServerSslEnabled()
+ {
+ return serverSslEnabled;
+ }
+
+ @Config("server-ssl-enabled")
+ public FlightShimConfig setServerSslEnabled(boolean serverSslEnabled)
+ {
+ this.serverSslEnabled = serverSslEnabled;
+ return this;
+ }
+
+ public String getServerSSLCertificateFile()
+ {
+ return serverSSLCertificateFile;
+ }
+
+ @Config("server-ssl-certificate-file")
+ public FlightShimConfig setServerSSLCertificateFile(String serverSSLCertificateFile)
+ {
+ this.serverSSLCertificateFile = serverSSLCertificateFile;
+ return this;
+ }
+
+ public String getServerSSLKeyFile()
+ {
+ return serverSSLKeyFile;
+ }
+
+ @Config("server-ssl-key-file")
+ public FlightShimConfig setServerSSLKeyFile(String serverSSLKeyFile)
+ {
+ this.serverSSLKeyFile = serverSSLKeyFile;
+ return this;
+ }
+
+ public String getClientSSLCertificateFile()
+ {
+ return clientSSLCertificateFile;
+ }
+
+ @ConfigDescription("Path to the client SSL certificate used for mTLS authentication with the Flight server")
+ @Config("client-ssl-certificate-file")
+ public FlightShimConfig setClientSSLCertificateFile(String clientSSLCertificateFile)
+ {
+ this.clientSSLCertificateFile = clientSSLCertificateFile;
+ return this;
+ }
+
+ public String getClientSSLKeyFile()
+ {
+ return clientSSLKeyFile;
+ }
+
+ @ConfigDescription("Path to the client SSL key used for mTLS authentication with the Flight server")
+ @Config("client-ssl-key-file")
+ public FlightShimConfig setClientSSLKeyFile(String clientSSLKeyFile)
+ {
+ this.clientSSLKeyFile = clientSSLKeyFile;
+ return this;
+ }
+
+ public int getMaxRowsPerBatch()
+ {
+ return maxRowsPerBatch;
+ }
+
+ @Config("max-rows-per-batch")
+ @Min(1)
+ @Max(1000000)
+ @ConfigDescription("Sets the maximum number of rows an Arrow record batch will have before sending to the client")
+ public FlightShimConfig setMaxRowsPerBatch(int maxRowsPerBatch)
+ {
+ this.maxRowsPerBatch = maxRowsPerBatch;
+ return this;
+ }
+
+ @Config("thread-pool-size")
+ @Min(1)
+ @ConfigDescription("Size of thread pool to used to handle read requests")
+ public FlightShimConfig setReadThreadPoolSize(int readThreadPoolSize)
+ {
+ this.readThreadPoolSize = readThreadPoolSize;
+ return this;
+ }
+
+ public int getReadThreadPoolSize()
+ {
+ return this.readThreadPoolSize;
+ }
+}
diff --git a/presto-flight-shim/src/main/java/com/facebook/presto/flightshim/FlightShimModule.java b/presto-flight-shim/src/main/java/com/facebook/presto/flightshim/FlightShimModule.java
new file mode 100644
index 0000000000000..9d272d91e1d69
--- /dev/null
+++ b/presto-flight-shim/src/main/java/com/facebook/presto/flightshim/FlightShimModule.java
@@ -0,0 +1,190 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.facebook.presto.flightshim;
+
+import com.facebook.airlift.configuration.AbstractConfigurationAwareModule;
+import com.facebook.airlift.json.JsonObjectMapperProvider;
+import com.facebook.presto.SystemSessionProperties;
+import com.facebook.presto.block.BlockJsonSerde;
+import com.facebook.presto.common.block.Block;
+import com.facebook.presto.common.block.BlockEncoding;
+import com.facebook.presto.common.block.BlockEncodingManager;
+import com.facebook.presto.common.block.BlockEncodingSerde;
+import com.facebook.presto.common.type.Type;
+import com.facebook.presto.common.type.TypeManager;
+import com.facebook.presto.connector.ConnectorManager;
+import com.facebook.presto.cost.HistoryBasedOptimizationConfig;
+import com.facebook.presto.execution.QueryManagerConfig;
+import com.facebook.presto.execution.TaskManagerConfig;
+import com.facebook.presto.execution.scheduler.NodeSchedulerConfig;
+import com.facebook.presto.execution.warnings.WarningCollectorConfig;
+import com.facebook.presto.memory.MemoryManagerConfig;
+import com.facebook.presto.memory.NodeMemoryConfig;
+import com.facebook.presto.metadata.AnalyzePropertyManager;
+import com.facebook.presto.metadata.BuiltInProcedureRegistry;
+import com.facebook.presto.metadata.ColumnPropertyManager;
+import com.facebook.presto.metadata.FunctionAndTypeManager;
+import com.facebook.presto.metadata.HandleJsonModule;
+import com.facebook.presto.metadata.InMemoryNodeManager;
+import com.facebook.presto.metadata.InternalNodeManager;
+import com.facebook.presto.metadata.MaterializedViewPropertyManager;
+import com.facebook.presto.metadata.Metadata;
+import com.facebook.presto.metadata.MetadataManager;
+import com.facebook.presto.metadata.SchemaPropertyManager;
+import com.facebook.presto.metadata.SessionPropertyManager;
+import com.facebook.presto.metadata.SessionPropertyProviderConfig;
+import com.facebook.presto.metadata.StaticCatalogStoreConfig;
+import com.facebook.presto.metadata.TableFunctionRegistry;
+import com.facebook.presto.metadata.TablePropertyManager;
+import com.facebook.presto.nodeManager.PluginNodeManager;
+import com.facebook.presto.server.PluginManagerConfig;
+import com.facebook.presto.server.security.SecurityConfig;
+import com.facebook.presto.sessionpropertyproviders.NativeWorkerSessionPropertyProvider;
+import com.facebook.presto.spi.NodeManager;
+import com.facebook.presto.spi.analyzer.ViewDefinition;
+import com.facebook.presto.spi.procedure.ProcedureRegistry;
+import com.facebook.presto.spi.session.WorkerSessionPropertyProvider;
+import com.facebook.presto.spiller.NodeSpillConfig;
+import com.facebook.presto.sql.SqlEnvironmentConfig;
+import com.facebook.presto.sql.analyzer.FeaturesConfig;
+import com.facebook.presto.sql.analyzer.FunctionsConfig;
+import com.facebook.presto.sql.analyzer.JavaFeaturesConfig;
+import com.facebook.presto.sql.planner.CompilerConfig;
+import com.facebook.presto.tracing.TracingConfig;
+import com.facebook.presto.transaction.NoOpTransactionManager;
+import com.facebook.presto.transaction.TransactionManager;
+import com.facebook.presto.transaction.TransactionManagerConfig;
+import com.facebook.presto.type.TypeDeserializer;
+import com.fasterxml.jackson.databind.ObjectMapper;
+import com.google.inject.Binder;
+import com.google.inject.Provides;
+import com.google.inject.Scopes;
+import com.google.inject.multibindings.MapBinder;
+import jakarta.inject.Singleton;
+import org.apache.arrow.memory.BufferAllocator;
+import org.apache.arrow.memory.RootAllocator;
+
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.SynchronousQueue;
+import java.util.concurrent.ThreadPoolExecutor;
+import java.util.concurrent.TimeUnit;
+
+import static com.facebook.airlift.concurrent.Threads.threadsNamed;
+import static com.facebook.airlift.configuration.ConfigBinder.configBinder;
+import static com.facebook.airlift.json.JsonBinder.jsonBinder;
+import static com.facebook.airlift.json.JsonCodecBinder.jsonCodecBinder;
+import static com.google.inject.multibindings.MapBinder.newMapBinder;
+import static com.google.inject.multibindings.Multibinder.newSetBinder;
+import static org.weakref.jmx.guice.ExportBinder.newExporter;
+
+public class FlightShimModule
+ extends AbstractConfigurationAwareModule
+{
+ @Override
+ protected void setup(Binder binder)
+ {
+ binder.bind(ConnectorManager.class).toProvider(() -> null);
+ binder.bind(FlightShimPluginManager.class).in(Scopes.SINGLETON);
+ binder.bind(BufferAllocator.class).to(RootAllocator.class).in(Scopes.SINGLETON);
+ binder.bind(FlightShimProducer.class).in(Scopes.SINGLETON);
+
+ binder.bind(FlightShimServerExecutionMBean.class).in(Scopes.SINGLETON);
+ newExporter(binder).export(FlightShimServerExecutionMBean.class).withGeneratedName();
+
+ configBinder(binder).bindConfig(FlightShimConfig.class, FlightShimConfig.CONFIG_PREFIX);
+ configBinder(binder).bindConfig(PluginManagerConfig.class);
+ configBinder(binder).bindConfig(StaticCatalogStoreConfig.class);
+
+ // configs
+ configBinder(binder).bindConfig(QueryManagerConfig.class);
+ configBinder(binder).bindConfig(TaskManagerConfig.class);
+ configBinder(binder).bindConfig(NodeSchedulerConfig.class);
+ configBinder(binder).bindConfig(WarningCollectorConfig.class);
+ configBinder(binder).bindConfig(MemoryManagerConfig.class);
+ configBinder(binder).bindConfig(NodeMemoryConfig.class);
+ configBinder(binder).bindConfig(SessionPropertyProviderConfig.class);
+ configBinder(binder).bindConfig(SecurityConfig.class);
+ configBinder(binder).bindConfig(NodeSpillConfig.class);
+ configBinder(binder).bindConfig(SqlEnvironmentConfig.class);
+ configBinder(binder).bindConfig(CompilerConfig.class);
+ configBinder(binder).bindConfig(TracingConfig.class);
+
+ // json codecs
+ jsonCodecBinder(binder).bindJsonCodec(ViewDefinition.class);
+
+ // Worker session property providers
+ MapBinder mapBinder =
+ newMapBinder(binder, String.class, WorkerSessionPropertyProvider.class);
+ mapBinder.addBinding("native-worker").to(NativeWorkerSessionPropertyProvider.class).in(Scopes.SINGLETON);
+
+ // history statistics
+ configBinder(binder).bindConfig(HistoryBasedOptimizationConfig.class);
+
+ // property managers
+ binder.bind(SystemSessionProperties.class).in(Scopes.SINGLETON);
+ binder.bind(SessionPropertyManager.class).in(Scopes.SINGLETON);
+ binder.bind(SchemaPropertyManager.class).in(Scopes.SINGLETON);
+ binder.bind(TablePropertyManager.class).in(Scopes.SINGLETON);
+ binder.bind(MaterializedViewPropertyManager.class).in(Scopes.SINGLETON);
+ binder.bind(ColumnPropertyManager.class).in(Scopes.SINGLETON);
+ binder.bind(AnalyzePropertyManager.class).in(Scopes.SINGLETON);
+
+ // transaction manager
+ configBinder(binder).bindConfig(TransactionManagerConfig.class);
+ // Install no-op transaction manager on workers, since only coordinators manage transactions.
+ binder.bind(TransactionManager.class).to(NoOpTransactionManager.class).in(Scopes.SINGLETON);
+
+ // metadata
+ binder.bind(FunctionAndTypeManager.class).in(Scopes.SINGLETON);
+ binder.bind(TableFunctionRegistry.class).in(Scopes.SINGLETON);
+ binder.bind(MetadataManager.class).in(Scopes.SINGLETON);
+ binder.bind(Metadata.class).to(MetadataManager.class).in(Scopes.SINGLETON);
+ binder.bind(ProcedureRegistry.class).to(BuiltInProcedureRegistry.class).in(Scopes.SINGLETON);
+
+ // type
+ binder.bind(TypeManager.class).to(FunctionAndTypeManager.class).in(Scopes.SINGLETON);
+ jsonBinder(binder).addDeserializerBinding(Type.class).to(TypeDeserializer.class);
+ newSetBinder(binder, Type.class);
+ binder.bind(TypeDeserializer.class).in(Scopes.SINGLETON);
+
+ // block encodings
+ binder.bind(BlockEncodingManager.class).in(Scopes.SINGLETON);
+ binder.bind(BlockEncodingSerde.class).to(BlockEncodingManager.class).in(Scopes.SINGLETON);
+ newSetBinder(binder, BlockEncoding.class);
+ jsonBinder(binder).addSerializerBinding(Block.class).to(BlockJsonSerde.Serializer.class);
+ jsonBinder(binder).addDeserializerBinding(Block.class).to(BlockJsonSerde.Deserializer.class);
+
+ // handle resolver
+ binder.install(new HandleJsonModule());
+ binder.bind(ObjectMapper.class).toProvider(JsonObjectMapperProvider.class);
+
+ // features config
+ configBinder(binder).bindConfig(FeaturesConfig.class);
+ configBinder(binder).bindConfig(FunctionsConfig.class);
+ configBinder(binder).bindConfig(JavaFeaturesConfig.class);
+
+ // Node manager binding
+ binder.bind(InternalNodeManager.class).to(InMemoryNodeManager.class).in(Scopes.SINGLETON);
+ binder.bind(PluginNodeManager.class).in(Scopes.SINGLETON);
+ binder.bind(NodeManager.class).to(PluginNodeManager.class).in(Scopes.SINGLETON);
+ }
+
+ @Provides
+ @Singleton
+ @ForFlightShimServer
+ public static ExecutorService createFlightShimServerExecutor(FlightShimConfig config)
+ {
+ return new ThreadPoolExecutor(0, config.getReadThreadPoolSize(), 1L, TimeUnit.MINUTES, new SynchronousQueue<>(), threadsNamed("flight-shim-%s"), new ThreadPoolExecutor.CallerRunsPolicy());
+ }
+}
diff --git a/presto-flight-shim/src/main/java/com/facebook/presto/flightshim/FlightShimPluginManager.java b/presto-flight-shim/src/main/java/com/facebook/presto/flightshim/FlightShimPluginManager.java
new file mode 100644
index 0000000000000..0494deb853322
--- /dev/null
+++ b/presto-flight-shim/src/main/java/com/facebook/presto/flightshim/FlightShimPluginManager.java
@@ -0,0 +1,469 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.facebook.presto.flightshim;
+
+import com.facebook.airlift.json.JsonCodec;
+import com.facebook.airlift.json.JsonCodecFactory;
+import com.facebook.airlift.json.JsonObjectMapperProvider;
+import com.facebook.airlift.log.Logger;
+import com.facebook.presto.GroupByHashPageIndexerFactory;
+import com.facebook.presto.PagesIndexPageSorter;
+import com.facebook.presto.block.BlockJsonSerde;
+import com.facebook.presto.common.block.Block;
+import com.facebook.presto.common.block.BlockEncodingManager;
+import com.facebook.presto.common.type.Type;
+import com.facebook.presto.connector.ConnectorContextInstance;
+import com.facebook.presto.cost.ConnectorFilterStatsCalculatorService;
+import com.facebook.presto.cost.FilterStatsCalculator;
+import com.facebook.presto.cost.ScalarStatsCalculator;
+import com.facebook.presto.cost.StatsNormalizer;
+import com.facebook.presto.metadata.InMemoryNodeManager;
+import com.facebook.presto.metadata.Metadata;
+import com.facebook.presto.metadata.StaticCatalogStoreConfig;
+import com.facebook.presto.nodeManager.PluginNodeManager;
+import com.facebook.presto.operator.PagesIndex;
+import com.facebook.presto.server.PluginInstaller;
+import com.facebook.presto.server.PluginManagerConfig;
+import com.facebook.presto.server.PluginManagerUtil;
+import com.facebook.presto.spi.ColumnHandle;
+import com.facebook.presto.spi.ColumnMetadata;
+import com.facebook.presto.spi.ConnectorHandleResolver;
+import com.facebook.presto.spi.ConnectorSession;
+import com.facebook.presto.spi.ConnectorSplit;
+import com.facebook.presto.spi.ConnectorTableHandle;
+import com.facebook.presto.spi.ConnectorTableLayoutHandle;
+import com.facebook.presto.spi.CoordinatorPlugin;
+import com.facebook.presto.spi.Plugin;
+import com.facebook.presto.spi.PrestoException;
+import com.facebook.presto.spi.classloader.ThreadContextClassLoader;
+import com.facebook.presto.spi.connector.Connector;
+import com.facebook.presto.spi.connector.ConnectorContext;
+import com.facebook.presto.spi.connector.ConnectorFactory;
+import com.facebook.presto.spi.connector.ConnectorTransactionHandle;
+import com.facebook.presto.spi.procedure.ProcedureRegistry;
+import com.facebook.presto.spi.relation.DeterminismEvaluator;
+import com.facebook.presto.spi.relation.DomainTranslator;
+import com.facebook.presto.spi.relation.ExpressionOptimizer;
+import com.facebook.presto.spi.relation.ExpressionOptimizerProvider;
+import com.facebook.presto.spi.relation.PredicateCompiler;
+import com.facebook.presto.spi.relation.RowExpression;
+import com.facebook.presto.spi.relation.RowExpressionService;
+import com.facebook.presto.sql.gen.JoinCompiler;
+import com.facebook.presto.sql.gen.RowExpressionPredicateCompiler;
+import com.facebook.presto.sql.planner.planPrinter.RowExpressionFormatter;
+import com.facebook.presto.sql.relational.FunctionResolution;
+import com.facebook.presto.sql.relational.RowExpressionDeterminismEvaluator;
+import com.facebook.presto.sql.relational.RowExpressionDomainTranslator;
+import com.facebook.presto.sql.relational.RowExpressionOptimizer;
+import com.facebook.presto.type.TypeDeserializer;
+import com.fasterxml.jackson.core.JsonParser;
+import com.fasterxml.jackson.databind.DeserializationContext;
+import com.fasterxml.jackson.databind.JsonDeserializer;
+import com.google.common.annotations.VisibleForTesting;
+import com.google.common.collect.ImmutableList;
+import com.google.common.collect.ImmutableMap;
+import com.google.common.collect.ImmutableSet;
+import com.google.common.io.Files;
+import com.google.inject.Inject;
+import io.airlift.resolver.ArtifactResolver;
+import jakarta.annotation.PreDestroy;
+
+import java.io.File;
+import java.io.IOException;
+import java.lang.reflect.InvocationTargetException;
+import java.lang.reflect.Method;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.atomic.AtomicBoolean;
+
+import static com.facebook.presto.server.PluginManagerUtil.SPI_PACKAGES;
+import static com.facebook.presto.spi.StandardErrorCode.GENERIC_INTERNAL_ERROR;
+import static com.facebook.presto.util.PropertiesUtil.loadProperties;
+import static com.google.api.client.util.Preconditions.checkState;
+import static com.google.common.base.MoreObjects.firstNonNull;
+import static java.lang.String.format;
+import static java.util.Objects.requireNonNull;
+
+public class FlightShimPluginManager
+{
+ private static final Logger log = Logger.get(FlightShimPluginManager.class);
+ private static final String SERVICES_FILE = "META-INF/services/" + Plugin.class.getName();
+ private final Map connectorFactories = new ConcurrentHashMap<>();
+ private final Map connectors = new ConcurrentHashMap<>();
+ private final File installedPluginsDir;
+ private final List plugins;
+ private final ArtifactResolver resolver;
+ private final AtomicBoolean pluginsLoading = new AtomicBoolean();
+ private final AtomicBoolean pluginsLoaded = new AtomicBoolean();
+ private final PluginInstaller pluginInstaller;
+ private final File catalogConfigurationDir;
+ private final Set disabledCatalogs;
+ private final AtomicBoolean catalogsLoading = new AtomicBoolean();
+ private final AtomicBoolean catalogsLoaded = new AtomicBoolean();
+ private final Map catalogPropertiesMap = new ConcurrentHashMap<>();
+ private final AtomicBoolean stopped = new AtomicBoolean();
+ private final Metadata metadata;
+ private final TypeDeserializer typeDeserializer;
+ private final BlockEncodingManager blockEncodingManager;
+ private final ProcedureRegistry procedureRegistry;
+
+ @Inject
+ public FlightShimPluginManager(
+ PluginManagerConfig pluginManagerConfig,
+ StaticCatalogStoreConfig catalogStoreConfig,
+ Metadata metadata,
+ TypeDeserializer typeDeserializer,
+ BlockEncodingManager blockEncodingManager,
+ ProcedureRegistry procedureRegistry)
+ {
+ requireNonNull(pluginManagerConfig, "pluginManagerConfig is null");
+ requireNonNull(catalogStoreConfig, "catalogStoreConfig is null");
+ this.metadata = requireNonNull(metadata, "metadata is null");
+ this.typeDeserializer = requireNonNull(typeDeserializer, "typeDeserializer is null");
+ this.blockEncodingManager = requireNonNull(blockEncodingManager, "blockEncodingManager is null");
+ this.installedPluginsDir = pluginManagerConfig.getInstalledPluginsDir();
+ if (pluginManagerConfig.getPlugins() == null) {
+ this.plugins = ImmutableList.of();
+ }
+ else {
+ this.plugins = ImmutableList.copyOf(pluginManagerConfig.getPlugins());
+ }
+ this.resolver = new ArtifactResolver(pluginManagerConfig.getMavenLocalRepository(), pluginManagerConfig.getMavenRemoteRepository());
+ this.pluginInstaller = new FlightServerPluginInstaller();
+ this.catalogConfigurationDir = catalogStoreConfig.getCatalogConfigurationDir();
+ this.disabledCatalogs = ImmutableSet.copyOf(firstNonNull(catalogStoreConfig.getDisabledCatalogs(), ImmutableList.of()));
+ this.procedureRegistry = requireNonNull(procedureRegistry, "procedureRegistry is null");
+ }
+
+ @PreDestroy
+ public synchronized void stop()
+ {
+ if (stopped.getAndSet(true)) {
+ return;
+ }
+
+ for (Map.Entry entry : connectors.entrySet()) {
+ Connector connector = entry.getValue().getConnector();
+ try (ThreadContextClassLoader ignored = new ThreadContextClassLoader(connector.getClass().getClassLoader())) {
+ connector.shutdown();
+ }
+ catch (Throwable t) {
+ log.error(t, "Error shutting down connector: %s", entry.getKey());
+ }
+ }
+ }
+
+ public void loadPlugins()
+ throws Exception
+ {
+ PluginManagerUtil.loadPlugins(
+ pluginsLoading,
+ pluginsLoaded,
+ installedPluginsDir,
+ plugins,
+ null,
+ resolver,
+ SPI_PACKAGES,
+ null,
+ SERVICES_FILE,
+ pluginInstaller,
+ getClass().getClassLoader());
+ }
+
+ public void loadCatalogs()
+ throws Exception
+ {
+ if (!catalogsLoading.compareAndSet(false, true)) {
+ return;
+ }
+
+ for (File file : listFiles(catalogConfigurationDir)) {
+ if (file.isFile() && file.getName().endsWith(".properties")) {
+ loadCatalog(file);
+ }
+ }
+
+ catalogsLoaded.set(true);
+ }
+
+ private void loadCatalog(File file)
+ throws Exception
+ {
+ String catalogName = Files.getNameWithoutExtension(file.getName());
+
+ log.info("-- Loading catalog properties %s --", file);
+ Map properties = loadProperties(file);
+ checkState(properties.containsKey("connector.name"), "Catalog configuration %s does not contain connector.name", file.getAbsoluteFile());
+
+ loadCatalog(catalogName, properties);
+ }
+
+ private void loadCatalog(String catalogName, Map properties)
+ {
+ if (disabledCatalogs.contains(catalogName)) {
+ log.info("Skipping disabled catalog %s", catalogName);
+ return;
+ }
+
+ log.info("-- Loading catalog %s --", catalogName);
+
+ String connectorName = null;
+ ImmutableMap.Builder connectorProperties = ImmutableMap.builder();
+ for (Map.Entry entry : properties.entrySet()) {
+ if (entry.getKey().equals("connector.name")) {
+ connectorName = entry.getValue();
+ }
+ else {
+ connectorProperties.put(entry.getKey(), entry.getValue());
+ }
+ }
+
+ checkState(connectorName != null, "Configuration for catalog %s does not contain connector.name", catalogName);
+
+ catalogPropertiesMap.put(catalogName, new CatalogPropertiesHolder(connectorName, connectorProperties.build()));
+
+ log.info("-- Added catalog %s using connector %s --", catalogName, connectorName);
+ }
+
+ @VisibleForTesting
+ void setCatalogProperties(String catalogName, String connectorName, Map properties)
+ {
+ catalogPropertiesMap.put(catalogName, new CatalogPropertiesHolder(connectorName, ImmutableMap.copyOf(properties)));
+ }
+
+ private static List listFiles(File installedPluginsDir)
+ {
+ if (installedPluginsDir != null && installedPluginsDir.isDirectory()) {
+ File[] files = installedPluginsDir.listFiles();
+ if (files != null) {
+ return ImmutableList.copyOf(files);
+ }
+ }
+ return ImmutableList.of();
+ }
+
+ private void registerPlugin(Plugin plugin)
+ {
+ for (ConnectorFactory factory : plugin.getConnectorFactories()) {
+ log.info("Registering connector %s", factory.getName());
+ connectorFactories.put(factory.getName(), factory);
+ }
+ }
+
+ public ConnectorHolder getConnector(String connectorId)
+ {
+ log.debug("FlightShimPluginManager getting connector: %s", connectorId);
+ CatalogPropertiesHolder catalogPropertiesHolder = catalogPropertiesMap.get(connectorId);
+ if (catalogPropertiesHolder == null) {
+ throw new IllegalArgumentException("Properties not loaded for " + connectorId);
+ }
+ final ImmutableMap config = catalogPropertiesHolder.getCatalogProperties();
+
+ // Create connector instances from factories as needed
+ return connectors.computeIfAbsent(catalogPropertiesHolder.getConnectorName(), name -> {
+ log.debug("Loading connector: %s", connectorId);
+ ConnectorFactory factory = connectorFactories.get(name);
+ requireNonNull(factory, format("No connector factory for %s", connectorId));
+
+ final RowExpressionDomainTranslator domainTranslator = new RowExpressionDomainTranslator(metadata);
+ final PredicateCompiler predicateCompiler = new RowExpressionPredicateCompiler(metadata);
+ final DeterminismEvaluator determinismEvaluator = new RowExpressionDeterminismEvaluator(metadata.getFunctionAndTypeManager());
+ final RowExpressionService rowExpressionService = new RowExpressionService()
+ {
+ @Override
+ public DomainTranslator getDomainTranslator()
+ {
+ return domainTranslator;
+ }
+
+ @Override
+ public ExpressionOptimizer getExpressionOptimizer(ConnectorSession session)
+ {
+ return new RowExpressionOptimizer(metadata);
+ }
+
+ @Override
+ public PredicateCompiler getPredicateCompiler()
+ {
+ return predicateCompiler;
+ }
+
+ @Override
+ public DeterminismEvaluator getDeterminismEvaluator()
+ {
+ return determinismEvaluator;
+ }
+
+ @Override
+ public String formatRowExpression(ConnectorSession session, RowExpression expression)
+ {
+ return new RowExpressionFormatter(metadata.getFunctionAndTypeManager()).formatRowExpression(session, expression);
+ }
+ };
+
+ final ExpressionOptimizerProvider expressionOptimizerProvider = (ConnectorSession session) -> new RowExpressionOptimizer(metadata);
+
+ ConnectorContext context = new ConnectorContextInstance(
+ new PluginNodeManager(new InMemoryNodeManager(), "flightconnector"),
+ metadata.getFunctionAndTypeManager(),
+ procedureRegistry,
+ metadata.getFunctionAndTypeManager(),
+ new FunctionResolution(metadata.getFunctionAndTypeManager().getFunctionAndTypeResolver()),
+ new PagesIndexPageSorter(new PagesIndex.TestingFactory(false)),
+ new GroupByHashPageIndexerFactory(new JoinCompiler(metadata)),
+ rowExpressionService,
+ new ConnectorFilterStatsCalculatorService(new FilterStatsCalculator(metadata, new ScalarStatsCalculator(metadata, expressionOptimizerProvider), new StatsNormalizer())),
+ blockEncodingManager,
+ () -> false);
+
+ try (ThreadContextClassLoader ignored = new ThreadContextClassLoader(factory.getClass().getClassLoader())) {
+ ConnectorHolder holder = new ConnectorHolder(factory.create(name, config, context), factory.getHandleResolver(), typeDeserializer, blockEncodingManager);
+ log.debug("Finished loading connector: %s", connectorId);
+ return holder;
+ }
+ });
+ }
+
+ private class FlightServerPluginInstaller
+ implements PluginInstaller
+ {
+ @Override
+ public void installPlugin(Plugin plugin)
+ {
+ registerPlugin(plugin);
+ }
+
+ @Override
+ public void installCoordinatorPlugin(CoordinatorPlugin plugin) {}
+ }
+
+ private static class CatalogPropertiesHolder
+ {
+ private final String connectorName;
+ private final ImmutableMap catalogProperties;
+
+ CatalogPropertiesHolder(String connectorName, ImmutableMap catalogProperties)
+ {
+ this.connectorName = connectorName;
+ this.catalogProperties = catalogProperties;
+ }
+
+ String getConnectorName()
+ {
+ return connectorName;
+ }
+
+ ImmutableMap getCatalogProperties()
+ {
+ return catalogProperties;
+ }
+ }
+
+ public static class ConnectorHolder
+ {
+ private final Connector connector;
+ private final JsonCodec extends ConnectorSplit> codecSplit;
+ private final JsonCodec extends ColumnHandle> codecColumnHandle;
+ private final JsonCodec extends ConnectorTableHandle> codecTableHandle;
+ private final JsonCodec extends ConnectorTableLayoutHandle> codecTableLayoutHandle;
+ private final JsonCodec extends ConnectorTransactionHandle> codecTransactionHandle;
+ private final Method getColumnMetadataMethod;
+
+ ConnectorHolder(Connector connector, ConnectorHandleResolver resolver, TypeDeserializer typeDeserializer, BlockEncodingManager blockEncodingManager)
+ {
+ this.connector = connector;
+
+ JsonObjectMapperProvider provider = new JsonObjectMapperProvider();
+ JsonDeserializer> columnDeserializer = new JsonDeserializer()
+ {
+ @Override
+ public ColumnHandle deserialize(JsonParser p, DeserializationContext ctxt)
+ throws IOException
+ {
+ return p.readValueAs(resolver.getColumnHandleClass());
+ }
+ };
+ BlockJsonSerde.Deserializer blockDeserializer = new BlockJsonSerde.Deserializer(blockEncodingManager);
+ provider.setJsonDeserializers(ImmutableMap.of(
+ Type.class, typeDeserializer,
+ ColumnHandle.class, columnDeserializer,
+ Block.class, blockDeserializer));
+ JsonCodecFactory jsonCodecFactory = new JsonCodecFactory(provider);
+
+ this.codecSplit = jsonCodecFactory.jsonCodec(resolver.getSplitClass());
+ this.codecColumnHandle = jsonCodecFactory.jsonCodec(resolver.getColumnHandleClass());
+ this.codecTableHandle = jsonCodecFactory.jsonCodec(resolver.getTableHandleClass());
+ this.codecTableLayoutHandle = jsonCodecFactory.jsonCodec(resolver.getTableLayoutHandleClass());
+ this.codecTransactionHandle = jsonCodecFactory.jsonCodec(resolver.getTransactionHandleClass());
+ this.getColumnMetadataMethod = reflectGetColumnMetadata(resolver);
+ }
+
+ Connector getConnector()
+ {
+ return connector;
+ }
+
+ JsonCodec extends ConnectorSplit> getCodecSplit()
+ {
+ return codecSplit;
+ }
+
+ JsonCodec extends ColumnHandle> getCodecColumnHandle()
+ {
+ return codecColumnHandle;
+ }
+
+ JsonCodec extends ConnectorTableHandle> getCodecTableHandle()
+ {
+ return codecTableHandle;
+ }
+
+ JsonCodec extends ConnectorTableLayoutHandle> getCodecTableLayoutHandle()
+ {
+ return codecTableLayoutHandle;
+ }
+
+ JsonCodec extends ConnectorTransactionHandle> getCodecTransactionHandle()
+ {
+ return codecTransactionHandle;
+ }
+
+ ColumnMetadata getColumnMetadata(ColumnHandle handle)
+ {
+ try {
+ return (ColumnMetadata) getColumnMetadataMethod.invoke(handle);
+ }
+ catch (InvocationTargetException | IllegalAccessException e) {
+ throw new PrestoException(GENERIC_INTERNAL_ERROR, "Unable to invoke method for getColumnMetadata", e);
+ }
+ }
+
+ private static Method reflectGetColumnMetadata(ConnectorHandleResolver resolver)
+ {
+ try {
+ return resolver.getColumnHandleClass().getMethod("getColumnMetadata");
+ }
+ catch (NoSuchMethodException e) {
+ try {
+ return resolver.getColumnHandleClass().getMethod("toColumnMetadata");
+ }
+ catch (NoSuchMethodException e2) {
+ throw new PrestoException(GENERIC_INTERNAL_ERROR, "Unable to get column metadata from ColumnHandle", e);
+ }
+ }
+ }
+ }
+}
diff --git a/presto-flight-shim/src/main/java/com/facebook/presto/flightshim/FlightShimProducer.java b/presto-flight-shim/src/main/java/com/facebook/presto/flightshim/FlightShimProducer.java
new file mode 100644
index 0000000000000..79fefac418c96
--- /dev/null
+++ b/presto-flight-shim/src/main/java/com/facebook/presto/flightshim/FlightShimProducer.java
@@ -0,0 +1,203 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.facebook.presto.flightshim;
+
+import com.facebook.airlift.json.JsonCodec;
+import com.facebook.airlift.log.Logger;
+import com.facebook.plugin.arrow.ArrowBatchSource;
+import com.facebook.presto.Session;
+import com.facebook.presto.common.RuntimeStats;
+import com.facebook.presto.execution.QueryIdGenerator;
+import com.facebook.presto.spi.ColumnHandle;
+import com.facebook.presto.spi.ColumnMetadata;
+import com.facebook.presto.spi.ConnectorId;
+import com.facebook.presto.spi.ConnectorPageSource;
+import com.facebook.presto.spi.ConnectorSession;
+import com.facebook.presto.spi.ConnectorSplit;
+import com.facebook.presto.spi.ConnectorTableHandle;
+import com.facebook.presto.spi.ConnectorTableLayoutHandle;
+import com.facebook.presto.spi.SplitContext;
+import com.facebook.presto.spi.TableHandle;
+import com.facebook.presto.spi.connector.Connector;
+import com.facebook.presto.spi.connector.ConnectorPageSourceProvider;
+import com.facebook.presto.spi.connector.ConnectorRecordSetProvider;
+import com.facebook.presto.spi.connector.ConnectorTransactionHandle;
+import com.facebook.presto.spi.security.Identity;
+import com.facebook.presto.spi.transaction.IsolationLevel;
+import com.facebook.presto.split.RecordPageSourceProvider;
+import org.apache.arrow.flight.BackpressureStrategy;
+import org.apache.arrow.flight.CallStatus;
+import org.apache.arrow.flight.NoOpFlightProducer;
+import org.apache.arrow.flight.Ticket;
+import org.apache.arrow.memory.BufferAllocator;
+
+import javax.inject.Inject;
+
+import java.io.Closeable;
+import java.util.List;
+import java.util.Optional;
+import java.util.concurrent.ExecutorService;
+
+import static com.facebook.airlift.json.JsonCodec.jsonCodec;
+import static com.facebook.presto.metadata.SessionPropertyManager.createTestingSessionPropertyManager;
+import static com.facebook.presto.testing.TestingSession.DEFAULT_TIME_ZONE_KEY;
+import static com.google.common.base.Preconditions.checkState;
+import static com.google.common.collect.ImmutableList.toImmutableList;
+import static java.lang.String.format;
+import static java.util.Collections.unmodifiableList;
+import static java.util.Locale.ENGLISH;
+import static java.util.Objects.requireNonNull;
+
+public class FlightShimProducer
+ extends NoOpFlightProducer
+ implements Closeable
+{
+ private static final Logger log = Logger.get(FlightShimProducer.class);
+ private static final JsonCodec REQUEST_JSON_CODEC = jsonCodec(FlightShimRequest.class);
+ private static final int CLIENT_POLL_TIME = 5000; // Backpressure poll time ms
+ private final BufferAllocator allocator;
+ private final FlightShimPluginManager pluginManager;
+ private final FlightShimConfig config;
+ private final ExecutorService shimExecutor;
+
+ @Inject
+ public FlightShimProducer(BufferAllocator allocator, FlightShimPluginManager pluginManager, FlightShimConfig config, @ForFlightShimServer ExecutorService shimExecutor)
+ {
+ this.allocator = allocator.newChildAllocator("flight-shim", 0, Long.MAX_VALUE);
+ this.pluginManager = requireNonNull(pluginManager, "pluginManager is null");
+ this.config = requireNonNull(config, "config is null");
+ this.shimExecutor = requireNonNull(shimExecutor, "shimExecutor is null");
+ }
+
+ @Override
+ public void getStream(CallContext context, Ticket ticket, ServerStreamListener listener)
+ {
+ log.debug("Received GetStream request");
+ shimExecutor.submit(() -> runGetStreamAsync(context, ticket, listener));
+ }
+
+ private void runGetStreamAsync(CallContext context, Ticket ticket, ServerStreamListener listener)
+ {
+ log.debug("Starting GetStream processing");
+ int columnCount = 0;
+ int rowCount = 0;
+ int batchCount = 0;
+ try {
+ final BackpressureStrategy backpressureStrategy = new BackpressureStrategy.CallbackBackpressureStrategy();
+ backpressureStrategy.register(listener);
+
+ FlightShimRequest request = REQUEST_JSON_CODEC.fromJson(ticket.getBytes());
+ log.debug("Request for connector: %s", request.getConnectorId());
+
+ FlightShimPluginManager.ConnectorHolder connectorHolder = pluginManager.getConnector(request.getConnectorId());
+ requireNonNull(connectorHolder, format("Requested connector not loaded: %s", request.getConnectorId()));
+
+ Connector connector = connectorHolder.getConnector();
+ ConnectorSplit split = connectorHolder.getCodecSplit().fromJson(request.getSplitBytes());
+
+ List extends ColumnHandle> columnHandles = request.getColumnHandlesBytes().stream().map(
+ columnHandleBytes -> connectorHolder.getCodecColumnHandle().fromJson(columnHandleBytes)
+ ).collect(toImmutableList());
+
+ List columnsMetadata = columnHandles.stream()
+ .map(connectorHolder::getColumnMetadata).collect(toImmutableList());
+
+ ConnectorTableHandle connectorTableHandle = connectorHolder.getCodecTableHandle().fromJson(request.getTableHandleBytes());
+ ConnectorTransactionHandle transactionHandle = connectorHolder.getCodecTransactionHandle().fromJson(request.getTransactionHandleBytes());
+ Optional connectorTableLayoutHandle =
+ request.getTableLayoutHandleBytes().map( tableLayoutHandleBytes -> connectorHolder.getCodecTableLayoutHandle().fromJson(tableLayoutHandleBytes));
+ TableHandle tableHandle = new TableHandle(new ConnectorId(request.getConnectorId()), connectorTableHandle, transactionHandle, connectorTableLayoutHandle);
+
+ // Create a dummy session to load the connector
+ QueryIdGenerator queryIdGenerator = new QueryIdGenerator();
+ Session session = Session.builder(createTestingSessionPropertyManager())
+ .setQueryId(queryIdGenerator.createNextQueryId())
+ .setIdentity(new Identity("user", Optional.empty()))
+ .setTimeZoneKey(DEFAULT_TIME_ZONE_KEY)
+ .setLocale(ENGLISH).build();
+ ConnectorId connectorId = new ConnectorId(request.getConnectorId());
+ ConnectorSession connectorSession = session.toConnectorSession(connectorId);
+
+ ConnectorPageSourceProvider connectorPageSourceProvider = getConnectorPageSourceProvider(connector, connectorId);
+ ConnectorPageSource connectorPageSource = connectorPageSourceProvider.createPageSource(
+ transactionHandle,
+ connectorSession,
+ split,
+ null,
+ unmodifiableList(columnHandles),
+ new SplitContext(false),
+ new RuntimeStats());
+
+ try (ArrowBatchSource batchSource = new ArrowBatchSource(allocator, columnsMetadata, connectorPageSource, config.getMaxRowsPerBatch())) {
+ listener.setUseZeroCopy(true);
+ listener.start(batchSource.getVectorSchemaRoot());
+ columnCount = batchSource.getVectorSchemaRoot().getFieldVectors().size();
+ while (batchSource.nextBatch()) {
+ BackpressureStrategy.WaitResult waitResult;
+ while ((waitResult = backpressureStrategy.waitForListener(CLIENT_POLL_TIME)) == BackpressureStrategy.WaitResult.TIMEOUT) {
+ log.debug(format("Waiting for client to read from connector %s", request.getConnectorId()));
+ }
+ if (waitResult != BackpressureStrategy.WaitResult.READY) {
+ log.info(format("Read stopped from connector %s due to client wait result: %s", request.getConnectorId(), waitResult));
+ break;
+ }
+ rowCount += batchSource.getVectorSchemaRoot().getRowCount();
+ batchCount++;
+ listener.putNext();
+ }
+ listener.completed();
+ }
+ }
+ catch (Throwable t) {
+ final String message = "Error getting connector flight stream";
+ log.error(t, message);
+ listener.error(CallStatus.INTERNAL.withCause(t).withDescription(format("%s [%s]", message, t)).toRuntimeException());
+ }
+ finally {
+ log.debug(format("Processing GetStream completed [columns=%d, rows=%d, batches=%d]", columnCount, rowCount, batchCount));
+ }
+ }
+
+ private ConnectorPageSourceProvider getConnectorPageSourceProvider(Connector connector, ConnectorId connectorId)
+ {
+ ConnectorPageSourceProvider connectorPageSourceProvider = null;
+ try {
+ connectorPageSourceProvider = connector.getPageSourceProvider();
+ requireNonNull(connectorPageSourceProvider, format("Connector %s returned a null page source provider", connectorId));
+ }
+ catch (UnsupportedOperationException ignored) {
+ }
+
+ if (connectorPageSourceProvider == null) {
+ ConnectorRecordSetProvider connectorRecordSetProvider = null;
+ try {
+ connectorRecordSetProvider = connector.getRecordSetProvider();
+ requireNonNull(connectorRecordSetProvider, format("Connector %s returned a null record set provider", connectorId));
+ }
+ catch (UnsupportedOperationException ignored) {
+ }
+ checkState(connectorRecordSetProvider != null, "Connector %s has neither a PageSource or RecordSet provider", connectorId);
+ connectorPageSourceProvider = new RecordPageSourceProvider(connectorRecordSetProvider);
+ }
+ return connectorPageSourceProvider;
+ }
+
+ @Override
+ public void close()
+ {
+ shimExecutor.shutdownNow();
+ pluginManager.stop();
+ allocator.close();
+ }
+}
diff --git a/presto-flight-shim/src/main/java/com/facebook/presto/flightshim/FlightShimRequest.java b/presto-flight-shim/src/main/java/com/facebook/presto/flightshim/FlightShimRequest.java
new file mode 100644
index 0000000000000..33f6753515b9e
--- /dev/null
+++ b/presto-flight-shim/src/main/java/com/facebook/presto/flightshim/FlightShimRequest.java
@@ -0,0 +1,86 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.facebook.presto.flightshim;
+
+import com.fasterxml.jackson.annotation.JsonCreator;
+import com.fasterxml.jackson.annotation.JsonProperty;
+import com.google.common.collect.ImmutableList;
+
+import java.util.List;
+import java.util.Optional;
+
+import static java.util.Objects.requireNonNull;
+
+public class FlightShimRequest
+{
+ private final String connectorId;
+ private final byte[] splitBytes;
+ private final List columnHandlesBytes;
+ private final byte[] tableHandleBytes;
+ private final Optional tableLayoutHandleBytes;
+ private final byte[] transactionHandleBytes;
+
+ @JsonCreator
+ public FlightShimRequest(
+ @JsonProperty("connectorId") String connectorId,
+ @JsonProperty("splitBytes") byte[] splitBytes,
+ @JsonProperty("columnHandlesBytes") List columnHandlesBytes,
+ @JsonProperty("tableHandleBytes") byte[] tableHandleBytes,
+ @JsonProperty("tableLayoutHandleBytes") Optional tableLayoutHandleBytes,
+ @JsonProperty("transactionHandleBytes") byte[] transactionHandleBytes)
+ {
+ this.connectorId = requireNonNull(connectorId, "connectorId is null");
+ this.splitBytes = requireNonNull(splitBytes, "splitBytes is null");
+ this.columnHandlesBytes = ImmutableList.copyOf(requireNonNull(columnHandlesBytes, "columnHandlesBytes is null"));
+ this.tableHandleBytes = requireNonNull(tableHandleBytes, "tableHandleBytes is null");
+ this.tableLayoutHandleBytes = tableLayoutHandleBytes;
+ this.transactionHandleBytes = requireNonNull(transactionHandleBytes, "transactionHandleBytes is null");
+ }
+
+ @JsonProperty
+ public String getConnectorId()
+ {
+ return connectorId;
+ }
+
+ @JsonProperty
+ public byte[] getSplitBytes()
+ {
+ return splitBytes;
+ }
+
+ @JsonProperty
+ public List getColumnHandlesBytes()
+ {
+ return columnHandlesBytes;
+ }
+
+ @JsonProperty
+ public byte[] getTableHandleBytes()
+ {
+ return tableHandleBytes;
+ }
+
+ @JsonProperty
+ public Optional getTableLayoutHandleBytes()
+ {
+ return tableLayoutHandleBytes;
+ }
+
+ @JsonProperty
+ public byte[] getTransactionHandleBytes()
+ {
+ return transactionHandleBytes;
+ }
+}
diff --git a/presto-flight-shim/src/main/java/com/facebook/presto/flightshim/FlightShimServer.java b/presto-flight-shim/src/main/java/com/facebook/presto/flightshim/FlightShimServer.java
new file mode 100644
index 0000000000000..581c5fa3a9867
--- /dev/null
+++ b/presto-flight-shim/src/main/java/com/facebook/presto/flightshim/FlightShimServer.java
@@ -0,0 +1,133 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.facebook.presto.flightshim;
+
+import com.facebook.airlift.bootstrap.Bootstrap;
+import com.facebook.airlift.json.JsonModule;
+import com.facebook.airlift.log.Logger;
+import com.google.common.collect.ImmutableList;
+import com.google.common.collect.ImmutableMap;
+import com.google.inject.Injector;
+import com.google.inject.Key;
+import com.google.inject.Module;
+import org.apache.arrow.flight.FlightServer;
+import org.apache.arrow.flight.Location;
+import org.apache.arrow.flight.grpc.ContextPropagatingExecutorService;
+import org.apache.arrow.memory.BufferAllocator;
+
+import java.io.File;
+import java.util.Map;
+import java.util.concurrent.ExecutorService;
+
+import static java.lang.String.format;
+
+public class FlightShimServer
+{
+ private FlightShimServer()
+ {
+ }
+
+ public static Injector initialize(Map config, Module... extraModules)
+ {
+ Bootstrap app = new Bootstrap(ImmutableList.builder()
+ .add(new FlightShimModule())
+ .add(new JsonModule())
+ .add(extraModules)
+ .build());
+
+ if (config != null && !config.isEmpty()) {
+ // Required config was provided instead of vm option -Dconfig=
+ app.setRequiredConfigurationProperties(config);
+ }
+
+ return app.initialize();
+ }
+
+ public static FlightServer start(Injector injector, FlightServer.Builder builder)
+ throws Exception
+ {
+ FlightShimPluginManager pluginManager = injector.getInstance(FlightShimPluginManager.class);
+ pluginManager.loadPlugins();
+ pluginManager.loadCatalogs();
+
+ builder.allocator(injector.getInstance(BufferAllocator.class));
+ FlightShimConfig config = injector.getInstance(FlightShimConfig.class);
+
+ if (config.getServerName() == null || config.getServerPort() == null) {
+ throw new IllegalArgumentException("Required configuration 'flight-shim.server' and 'flight-shim.server.port' not set");
+ }
+
+ if (config.getServerSslEnabled()) {
+ builder.location(Location.forGrpcTls(config.getServerName(), config.getServerPort()));
+ if (config.getServerSSLCertificateFile() == null || config.getServerSSLKeyFile() == null) {
+ throw new IllegalArgumentException("'flight-shim.server-ssl-enabled' is enabled but 'flight-shim.server-ssl-certificate-file' or 'flight-shim.server-ssl-key-file' not set");
+ }
+ File certChainFile = new File(config.getServerSSLCertificateFile());
+ File privateKeyFile = new File(config.getServerSSLKeyFile());
+ builder.useTls(certChainFile, privateKeyFile);
+
+ // Check if client cert is provided for mTLS
+ if (config.getClientSSLCertificateFile() != null) {
+ File clientCertFile = new File(config.getClientSSLCertificateFile());
+ builder.useMTlsClientVerification(clientCertFile);
+ }
+ }
+ else {
+ builder.location(Location.forGrpcInsecure(config.getServerName(), config.getServerPort()));
+ }
+
+ ExecutorService executor = injector.getInstance(Key.get(ExecutorService.class, ForFlightShimServer.class));
+ builder.executor(new ContextPropagatingExecutorService(executor));
+
+ FlightShimProducer producer = injector.getInstance(FlightShimProducer.class);
+ builder.producer(producer);
+
+ FlightServer server = builder.build();
+ server.start();
+
+ return server;
+ }
+
+ public static void main(String[] args)
+ {
+ Logger log = Logger.get(FlightShimModule.class);
+
+ final Map config;
+ if (System.getProperty("config") == null) {
+ log.info("FlightShim server using default config, override with -Dconfig=");
+ ImmutableMap.Builder configBuilder = ImmutableMap.builder();
+ configBuilder.put("flight-shim.server", "localhost");
+ configBuilder.put("flight-shim.server.port", String.valueOf(9443));
+ configBuilder.put("flight-shim.server-ssl-certificate-file", "src/test/resources/certs/server.crt");
+ configBuilder.put("flight-shim.server-ssl-key-file", "src/test/resources/certs/server.key");
+ config = configBuilder.build();
+ }
+ else {
+ log.info("FlightShim server using config from: " + System.getProperty("config"));
+ config = ImmutableMap.of();
+ }
+ Injector injector = initialize(config);
+
+ log.info("FlightShim server initializing");
+ try (FlightServer server = start(injector, FlightServer.builder());
+ FlightShimProducer producer = injector.getInstance(FlightShimProducer.class)) {
+ log.info(format("======== FlightShim Server started on port: %s ========", server.getPort()));
+ server.awaitTermination();
+ }
+ catch (Throwable t) {
+ log.error(t);
+ System.exit(1);
+ }
+ }
+}
diff --git a/presto-flight-shim/src/main/java/com/facebook/presto/flightshim/FlightShimServerExecutionMBean.java b/presto-flight-shim/src/main/java/com/facebook/presto/flightshim/FlightShimServerExecutionMBean.java
new file mode 100644
index 0000000000000..4cc2cfafb10bd
--- /dev/null
+++ b/presto-flight-shim/src/main/java/com/facebook/presto/flightshim/FlightShimServerExecutionMBean.java
@@ -0,0 +1,41 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.facebook.presto.flightshim;
+
+import com.facebook.airlift.concurrent.ThreadPoolExecutorMBean;
+import org.weakref.jmx.Managed;
+import org.weakref.jmx.Nested;
+
+import javax.inject.Inject;
+
+import java.util.concurrent.ExecutorService;
+import java.util.concurrent.ThreadPoolExecutor;
+
+public class FlightShimServerExecutionMBean
+{
+ private final ThreadPoolExecutorMBean executorMBean;
+
+ @Inject
+ public FlightShimServerExecutionMBean(@ForFlightShimServer ExecutorService executor)
+ {
+ this.executorMBean = new ThreadPoolExecutorMBean((ThreadPoolExecutor) executor);
+ }
+
+ @Managed
+ @Nested
+ public ThreadPoolExecutorMBean getExecutor()
+ {
+ return executorMBean;
+ }
+}
diff --git a/presto-flight-shim/src/main/java/com/facebook/presto/flightshim/ForFlightShimServer.java b/presto-flight-shim/src/main/java/com/facebook/presto/flightshim/ForFlightShimServer.java
new file mode 100644
index 0000000000000..f100bc2359e4a
--- /dev/null
+++ b/presto-flight-shim/src/main/java/com/facebook/presto/flightshim/ForFlightShimServer.java
@@ -0,0 +1,31 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.facebook.presto.flightshim;
+
+import jakarta.inject.Qualifier;
+
+import java.lang.annotation.Retention;
+import java.lang.annotation.Target;
+
+import static java.lang.annotation.ElementType.FIELD;
+import static java.lang.annotation.ElementType.METHOD;
+import static java.lang.annotation.ElementType.PARAMETER;
+import static java.lang.annotation.RetentionPolicy.RUNTIME;
+
+@Retention(RUNTIME)
+@Target({FIELD, PARAMETER, METHOD})
+@Qualifier
+public @interface ForFlightShimServer
+{
+}
diff --git a/presto-flight-shim/src/test/java/com/facebook/presto/flightshim/AbstractTestFlightShimJdbcPlugins.java b/presto-flight-shim/src/test/java/com/facebook/presto/flightshim/AbstractTestFlightShimJdbcPlugins.java
new file mode 100644
index 0000000000000..621ca3fac0798
--- /dev/null
+++ b/presto-flight-shim/src/test/java/com/facebook/presto/flightshim/AbstractTestFlightShimJdbcPlugins.java
@@ -0,0 +1,346 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.facebook.presto.flightshim;
+
+import com.facebook.airlift.json.JsonCodec;
+import com.facebook.presto.common.type.ArrayType;
+import com.facebook.presto.common.type.CharType;
+import com.facebook.presto.common.type.DecimalType;
+import com.facebook.presto.common.type.Type;
+import com.facebook.presto.common.type.VarcharType;
+import com.facebook.presto.plugin.jdbc.JdbcColumnHandle;
+import com.facebook.presto.plugin.jdbc.JdbcTableHandle;
+import com.facebook.presto.plugin.jdbc.JdbcTransactionHandle;
+import com.facebook.presto.plugin.jdbc.JdbcTypeHandle;
+import com.facebook.presto.spi.SchemaTableName;
+import com.facebook.presto.tpch.TpchColumnHandle;
+import com.google.common.base.Preconditions;
+import com.google.common.collect.ImmutableList;
+import org.apache.arrow.flight.FlightClient;
+import org.apache.arrow.flight.FlightStream;
+import org.apache.arrow.flight.Ticket;
+import org.apache.arrow.memory.BufferAllocator;
+import org.testng.annotations.Test;
+
+import java.io.IOException;
+import java.nio.charset.StandardCharsets;
+import java.nio.file.Files;
+import java.sql.Types;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Optional;
+
+import static com.facebook.airlift.json.JsonCodec.jsonCodec;
+import static com.facebook.airlift.testing.Assertions.assertGreaterThan;
+import static com.facebook.presto.common.type.BigintType.BIGINT;
+import static com.facebook.presto.common.type.BooleanType.BOOLEAN;
+import static com.facebook.presto.common.type.Chars.isCharType;
+import static com.facebook.presto.common.type.DateType.DATE;
+import static com.facebook.presto.common.type.DoubleType.DOUBLE;
+import static com.facebook.presto.common.type.IntegerType.INTEGER;
+import static com.facebook.presto.common.type.RealType.REAL;
+import static com.facebook.presto.common.type.SmallintType.SMALLINT;
+import static com.facebook.presto.common.type.TimeType.TIME;
+import static com.facebook.presto.common.type.TimeWithTimeZoneType.TIME_WITH_TIME_ZONE;
+import static com.facebook.presto.common.type.TimestampType.TIMESTAMP;
+import static com.facebook.presto.common.type.TimestampWithTimeZoneType.TIMESTAMP_WITH_TIME_ZONE;
+import static com.facebook.presto.common.type.TinyintType.TINYINT;
+import static com.facebook.presto.common.type.VarbinaryType.VARBINARY;
+import static com.facebook.presto.common.type.Varchars.isVarcharType;
+import static com.facebook.presto.util.ResourceFileUtils.getResourceFile;
+import static java.lang.String.format;
+
+@Test(singleThreaded = true)
+public abstract class AbstractTestFlightShimJdbcPlugins
+ extends AbstractTestFlightShimPlugins
+{
+ public static final JsonCodec COLUMN_HANDLE_JSON_CODEC = jsonCodec(JdbcColumnHandle.class);
+ public static final JsonCodec TABLE_HANDLE_JSON_CODEC = jsonCodec(JdbcTableHandle.class);
+ public static final JsonCodec TRANSACTION_HANDLE_JSON_CODEC = jsonCodec(JdbcTransactionHandle.class);
+
+ protected abstract String getConnectionUrl();
+
+ @Override
+ protected Map getConnectorProperties()
+ {
+ Map connectorProperties = new HashMap<>();
+ connectorProperties.putIfAbsent("connection-url", getConnectionUrl());
+ return connectorProperties;
+ }
+
+ @Test
+ void testJdbcSplitWithTupleDomain() throws Exception
+ {
+ try (BufferAllocator bufferAllocator = allocator.newChildAllocator("connector-test-client", 0, Long.MAX_VALUE);
+ FlightClient client = createFlightClient(bufferAllocator, server.getPort())) {
+ Ticket ticket = new Ticket(REQUEST_JSON_CODEC.toJsonBytes(createTpchTableRequestWithTupleDomain()));
+
+ int rowCount = 0;
+ try (FlightStream stream = client.getStream(ticket, CALL_OPTIONS)) {
+ while (stream.next()) {
+ rowCount += stream.getRoot().getRowCount();
+ }
+ }
+
+ assertGreaterThan(rowCount, 0);
+ }
+ }
+
+ @Test
+ void testJdbcSplitWithAdditionalPredicate() throws Exception
+ {
+ try (BufferAllocator bufferAllocator = allocator.newChildAllocator("connector-test-client", 0, Long.MAX_VALUE);
+ FlightClient client = createFlightClient(bufferAllocator, server.getPort())) {
+ Ticket ticket = new Ticket(REQUEST_JSON_CODEC.toJsonBytes(createTpchTableRequestWithAdditionalPredicate()));
+
+ int rowCount = 0;
+ try (FlightStream stream = client.getStream(ticket, CALL_OPTIONS)) {
+ while (stream.next()) {
+ rowCount += stream.getRoot().getRowCount();
+ }
+ }
+
+ assertGreaterThan(rowCount, 0);
+ }
+ }
+
+ @Test
+ public void testSelectColumns()
+ {
+ assertSelectQueryFromColumns(ImmutableList.of(ORDERKEY_COLUMN, LINENUMBER_COLUMN));
+ }
+
+ @Test
+ public void testDateColumns()
+ {
+ assertSelectQueryFromColumns(ImmutableList.of(ORDERKEY_COLUMN, SHIPDATE_COLUMN));
+ }
+
+ @Test
+ public void testFloatingPointColumns()
+ {
+ assertSelectQueryFromColumns(ImmutableList.of(ORDERKEY_COLUMN, QUANTITY_COLUMN, EXTENDEDPRICE_COLUMN));
+ }
+
+ private JdbcColumnHandle convertToJdbcColumnHandle(TpchColumnHandle columnHandle)
+ {
+ Type type = columnHandle.getType();
+ return new JdbcColumnHandle(
+ getConnectorId(),
+ columnHandle.getColumnName(),
+ new JdbcTypeHandle(jdbcDataType(type), type.getDisplayName(), columnSize(type), 0),
+ type,
+ false,
+ Optional.empty());
+ }
+
+ @Override
+ protected FlightShimRequest createTpchTableRequest(int partNumber, int totalParts, List columnHandles)
+ {
+ Preconditions.checkArgument(totalParts == 1, "JDBC request must be of a single partition");
+ String split = createJdbcSplit(getConnectorId(), "tpch", TPCH_TABLE);
+ byte[] splitBytes = split.getBytes(StandardCharsets.UTF_8);
+
+ ImmutableList.Builder columnBuilder = ImmutableList.builder();
+ for (TpchColumnHandle columnHandle : columnHandles) {
+ columnBuilder.add(COLUMN_HANDLE_JSON_CODEC.toJsonBytes(convertToJdbcColumnHandle(columnHandle)));
+ }
+
+ JdbcTableHandle tableHandle = new JdbcTableHandle(getConnectorId(), new SchemaTableName("tpch", TPCH_TABLE), getConnectorId(), "tpch", TPCH_TABLE);
+ byte[] tableHandleBytes = TABLE_HANDLE_JSON_CODEC.toJsonBytes(tableHandle);
+ byte[] transactionHandleBytes = TRANSACTION_HANDLE_JSON_CODEC.toJsonBytes(new JdbcTransactionHandle());
+
+ return new FlightShimRequest(getConnectorId(), splitBytes, columnBuilder.build(), tableHandleBytes, Optional.empty(), transactionHandleBytes);
+ }
+
+ protected FlightShimRequest createTpchTableRequestWithTupleDomain() throws Exception
+ {
+ JdbcColumnHandle orderKeyHandle = convertToJdbcColumnHandle(getOrderKeyColumn());
+ byte[] splitBytes = Files.readAllBytes(getResourceFile("split_tuple_domain.json").toPath());
+
+ List columnHandles = ImmutableList.of(orderKeyHandle);
+ ImmutableList.Builder columnBuilder = ImmutableList.builder();
+ for (JdbcColumnHandle columnHandle : columnHandles) {
+ columnBuilder.add(COLUMN_HANDLE_JSON_CODEC.toJsonBytes(columnHandle));
+ }
+
+ JdbcTableHandle tableHandle = new JdbcTableHandle(getConnectorId(), new SchemaTableName("tpch", TPCH_TABLE), getConnectorId(), "tpch", TPCH_TABLE);
+ byte[] tableHandleBytes = TABLE_HANDLE_JSON_CODEC.toJsonBytes(tableHandle);
+ byte[] transactionHandleBytes = TRANSACTION_HANDLE_JSON_CODEC.toJsonBytes(new JdbcTransactionHandle());
+
+ return new FlightShimRequest(
+ getConnectorId(),
+ splitBytes,
+ columnBuilder.build(),
+ tableHandleBytes,
+ Optional.empty(),
+ transactionHandleBytes);
+ }
+
+ protected FlightShimRequest createTpchTableRequestWithAdditionalPredicate()
+ throws IOException
+ {
+ // Query is: "SELECT orderkey FROM orders WHERE orderkey IN (1, 2, 3)"
+ JdbcColumnHandle orderKeyHandle = convertToJdbcColumnHandle(getOrderKeyColumn());
+ byte[] splitBytes = Files.readAllBytes(getResourceFile("split_additional_predicate.json").toPath());
+
+ List columnHandles = ImmutableList.of(orderKeyHandle);
+ ImmutableList.Builder columnBuilder = ImmutableList.builder();
+ for (JdbcColumnHandle columnHandle : columnHandles) {
+ columnBuilder.add(COLUMN_HANDLE_JSON_CODEC.toJsonBytes(columnHandle));
+ }
+
+ JdbcTableHandle tableHandle = new JdbcTableHandle(getConnectorId(), new SchemaTableName("tpch", TPCH_TABLE), getConnectorId(), "tpch", TPCH_TABLE);
+ byte[] tableHandleBytes = TABLE_HANDLE_JSON_CODEC.toJsonBytes(tableHandle);
+ byte[] transactionHandleBytes = TRANSACTION_HANDLE_JSON_CODEC.toJsonBytes(new JdbcTransactionHandle());
+
+ return new FlightShimRequest(
+ getConnectorId(),
+ splitBytes,
+ columnBuilder.build(),
+ tableHandleBytes,
+ Optional.empty(),
+ transactionHandleBytes);
+ }
+
+ protected static String removeDatabaseFromJdbcUrl(String jdbcUrl)
+ {
+ return jdbcUrl.replaceFirst("/[^/?]+([?]|$)", "/$1");
+ }
+
+ protected static String addDatabaseCredentialsToJdbcUrl(String jdbcUrl, String username, String password)
+ {
+ return jdbcUrl + (jdbcUrl.contains("?") ? "&" : "?") +
+ "user=" + username + "&password=" + password;
+ }
+
+ protected static String createJdbcSplit(String connectorId, String schemaName, String tableName)
+ {
+ return format("{\n" +
+ " \"connectorId\" : \"%s\",\n" +
+ " \"schemaName\" : \"%s\",\n" +
+ " \"tableName\" : \"%s\",\n" +
+ " \"tupleDomain\" : {\n" +
+ " \"columnDomains\" : [ ]\n" +
+ " }\n" +
+ "}", connectorId, schemaName, tableName);
+ }
+
+ static int jdbcDataType(Type type)
+ {
+ if (type.equals(BOOLEAN)) {
+ return Types.BOOLEAN;
+ }
+ if (type.equals(BIGINT)) {
+ return Types.BIGINT;
+ }
+ if (type.equals(INTEGER)) {
+ return Types.INTEGER;
+ }
+ if (type.equals(SMALLINT)) {
+ return Types.SMALLINT;
+ }
+ if (type.equals(TINYINT)) {
+ return Types.TINYINT;
+ }
+ if (type.equals(REAL)) {
+ return Types.REAL;
+ }
+ if (type.equals(DOUBLE)) {
+ return Types.DOUBLE;
+ }
+ if (type instanceof DecimalType) {
+ return Types.DECIMAL;
+ }
+ if (isVarcharType(type)) {
+ return Types.VARCHAR;
+ }
+ if (isCharType(type)) {
+ return Types.CHAR;
+ }
+ if (type.equals(VARBINARY)) {
+ return Types.VARBINARY;
+ }
+ if (type.equals(TIME)) {
+ return Types.TIME;
+ }
+ if (type.equals(TIME_WITH_TIME_ZONE)) {
+ return Types.TIME_WITH_TIMEZONE;
+ }
+ if (type.equals(TIMESTAMP)) {
+ return Types.TIMESTAMP;
+ }
+ if (type.equals(TIMESTAMP_WITH_TIME_ZONE)) {
+ return Types.TIMESTAMP_WITH_TIMEZONE;
+ }
+ if (type.equals(DATE)) {
+ return Types.DATE;
+ }
+ if (type instanceof ArrayType) {
+ return Types.ARRAY;
+ }
+ return Types.JAVA_OBJECT;
+ }
+
+ static int columnSize(Type type)
+ {
+ if (type.equals(BIGINT)) {
+ return 19; // 2**63-1
+ }
+ if (type.equals(INTEGER)) {
+ return 10; // 2**31-1
+ }
+ if (type.equals(SMALLINT)) {
+ return 5; // 2**15-1
+ }
+ if (type.equals(TINYINT)) {
+ return 3; // 2**7-1
+ }
+ if (type instanceof DecimalType) {
+ return ((DecimalType) type).getPrecision();
+ }
+ if (type.equals(REAL)) {
+ return 24; // IEEE 754
+ }
+ if (type.equals(DOUBLE)) {
+ return 53; // IEEE 754
+ }
+ if (isVarcharType(type)) {
+ return ((VarcharType) type).getLength();
+ }
+ if (isCharType(type)) {
+ return ((CharType) type).getLength();
+ }
+ if (type.equals(VARBINARY)) {
+ return Integer.MAX_VALUE;
+ }
+ if (type.equals(TIME)) {
+ return 8; // 00:00:00
+ }
+ if (type.equals(TIME_WITH_TIME_ZONE)) {
+ return 8 + 6; // 00:00:00+00:00
+ }
+ if (type.equals(DATE)) {
+ return 14; // +5881580-07-11 (2**31-1 days)
+ }
+ if (type.equals(TIMESTAMP)) {
+ return 15 + 8;
+ }
+ if (type.equals(TIMESTAMP_WITH_TIME_ZONE)) {
+ return 15 + 8 + 6;
+ }
+ return 0;
+ }
+}
diff --git a/presto-flight-shim/src/test/java/com/facebook/presto/flightshim/AbstractTestFlightShimPlugins.java b/presto-flight-shim/src/test/java/com/facebook/presto/flightshim/AbstractTestFlightShimPlugins.java
new file mode 100644
index 0000000000000..2341ef131cb0c
--- /dev/null
+++ b/presto-flight-shim/src/test/java/com/facebook/presto/flightshim/AbstractTestFlightShimPlugins.java
@@ -0,0 +1,544 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.facebook.presto.flightshim;
+
+import com.facebook.airlift.json.JsonCodec;
+import com.facebook.plugin.arrow.ArrowBlockBuilder;
+import com.facebook.presto.Session;
+import com.facebook.presto.common.Page;
+import com.facebook.presto.common.QualifiedObjectName;
+import com.facebook.presto.common.block.Block;
+import com.facebook.presto.common.type.BigintType;
+import com.facebook.presto.common.type.DateType;
+import com.facebook.presto.common.type.DoubleType;
+import com.facebook.presto.common.type.IntegerType;
+import com.facebook.presto.common.type.Type;
+import com.facebook.presto.common.type.TypeManager;
+import com.facebook.presto.common.type.VarcharType;
+import com.facebook.presto.cost.StatsCalculator;
+import com.facebook.presto.metadata.Metadata;
+import com.facebook.presto.spi.Plugin;
+import com.facebook.presto.spi.eventlistener.EventListener;
+import com.facebook.presto.split.PageSourceManager;
+import com.facebook.presto.split.SplitManager;
+import com.facebook.presto.sql.expressions.ExpressionOptimizerManager;
+import com.facebook.presto.sql.planner.ConnectorPlanOptimizerManager;
+import com.facebook.presto.sql.planner.NodePartitioningManager;
+import com.facebook.presto.sql.planner.sanity.PlanCheckerProviderManager;
+import com.facebook.presto.testing.MaterializedResult;
+import com.facebook.presto.testing.QueryRunner;
+import com.facebook.presto.testing.TestingAccessControlManager;
+import com.facebook.presto.tests.AbstractTestQueryFramework;
+import com.facebook.presto.tpch.TpchColumnHandle;
+import com.facebook.presto.tpch.TpchTableHandle;
+import com.facebook.presto.tpch.TpchTableLayoutHandle;
+import com.facebook.presto.tpch.TpchTransactionHandle;
+import com.facebook.presto.transaction.TransactionManager;
+import com.google.common.collect.ImmutableList;
+import com.google.common.collect.ImmutableMap;
+import com.google.inject.Injector;
+import org.apache.arrow.flight.CallOption;
+import org.apache.arrow.flight.CallOptions;
+import org.apache.arrow.flight.FlightClient;
+import org.apache.arrow.flight.FlightServer;
+import org.apache.arrow.flight.FlightStream;
+import org.apache.arrow.flight.Location;
+import org.apache.arrow.flight.Ticket;
+import org.apache.arrow.memory.BufferAllocator;
+import org.apache.arrow.vector.FieldVector;
+import org.intellij.lang.annotations.Language;
+import org.testng.annotations.AfterClass;
+import org.testng.annotations.BeforeClass;
+import org.testng.annotations.Test;
+
+import java.io.ByteArrayInputStream;
+import java.io.IOException;
+import java.io.InputStream;
+import java.net.ServerSocket;
+import java.nio.charset.StandardCharsets;
+import java.nio.file.Files;
+import java.nio.file.Paths;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+import java.util.Locale;
+import java.util.Map;
+import java.util.Optional;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.locks.Lock;
+import java.util.stream.Collectors;
+
+import static com.facebook.airlift.json.JsonCodec.jsonCodec;
+import static com.facebook.presto.tpch.TpchMetadata.TINY_SCALE_FACTOR;
+import static java.lang.String.format;
+
+@Test(singleThreaded = true)
+public abstract class AbstractTestFlightShimPlugins
+ extends AbstractTestQueryFramework
+{
+ public static final JsonCodec REQUEST_JSON_CODEC = jsonCodec(FlightShimRequest.class);
+ public static final JsonCodec TPCH_COLUMN_JSON_CODEC = jsonCodec(TpchColumnHandle.class);
+ public static final JsonCodec TPCH_TABLE_HANDLE_JSON_CODEC = jsonCodec(TpchTableHandle.class);
+ public static final JsonCodec TPCH_TABLE_LAYOUT_HANDLE_JSON_CODEC = jsonCodec(TpchTableLayoutHandle.class);
+ public static final JsonCodec TPCH_TRANSACTION_HANDLE_JSON_CODEC = jsonCodec(TpchTransactionHandle.class);
+ public static final String TPCH_TABLE = "lineitem";
+ public static final String ORDERKEY_COLUMN = "orderkey";
+ public static final String LINENUMBER_COLUMN = "linenumber";
+ public static final String LINESTATUS_COLUMN = "linestatus";
+ public static final String EXTENDEDPRICE_COLUMN = "extendedprice";
+ public static final String QUANTITY_COLUMN = "quantity";
+ public static final String SHIPDATE_COLUMN = "shipdate";
+ public static final String SHIPINSTRUCT_COLUMN = "shipinstruct";
+ protected final List closables = new ArrayList<>();
+ protected static final CallOption CALL_OPTIONS = CallOptions.timeout(300, TimeUnit.SECONDS);
+ protected BufferAllocator allocator;
+ protected FlightServer server;
+ private ArrowBlockBuilder blockBuilder;
+
+ protected abstract String getConnectorId();
+
+ protected abstract String getPluginBundles();
+
+ @BeforeClass
+ public void setup()
+ throws Exception
+ {
+ ImmutableMap.Builder configBuilder = ImmutableMap.builder();
+ configBuilder.put("flight-shim.server", "localhost");
+ configBuilder.put("flight-shim.server.port", String.valueOf(findUnusedPort()));
+ configBuilder.put("flight-shim.server-ssl-certificate-file", "src/test/resources/certs/server.crt");
+ configBuilder.put("flight-shim.server-ssl-key-file", "src/test/resources/certs/server.key");
+ configBuilder.put("plugin.bundles", getPluginBundles());
+
+ // Allow for 3 batches using testing tpch db
+ configBuilder.put("flight-shim.max-rows-per-batch", String.valueOf(500));
+
+ Injector injector = FlightShimServer.initialize(configBuilder.build());
+
+ server = FlightShimServer.start(injector, FlightServer.builder());
+ allocator = injector.getInstance(BufferAllocator.class);
+
+ // Set test properties after catalogs have been loaded
+ FlightShimPluginManager pluginManager = injector.getInstance(FlightShimPluginManager.class);
+ Map connectorProperties = getConnectorProperties();
+ pluginManager.setCatalogProperties(getConnectorId(), getConnectorId(), connectorProperties);
+
+ TypeManager typeManager = injector.getInstance(TypeManager.class);
+ blockBuilder = new ArrowBlockBuilder(typeManager);
+
+ // Make sure these resources close properly
+ closables.add(server);
+ closables.add(injector.getInstance(FlightShimProducer.class));
+ closables.add(allocator);
+ }
+
+ @Override
+ @AfterClass(alwaysRun = true)
+ public void close()
+ throws Exception
+ {
+ super.close();
+ for (AutoCloseable closeable : closables) {
+ closeable.close();
+ }
+ closables.clear();
+ }
+
+ protected Map getConnectorProperties()
+ {
+ return ImmutableMap.of();
+ }
+
+ protected int getExpectedTotalParts()
+ {
+ return 1;
+ }
+
+ @Override
+ protected QueryRunner createQueryRunner()
+ {
+ return new FlightShimQueryRunner(getExpectedTotalParts());
+ }
+
+ /*
+ @Test
+ void testWithMtls() throws Exception
+ {
+ InputStream trustedCertificate = new ByteArrayInputStream(Files.readAllBytes(Paths.get("src/test/resources/certs/server.crt")));
+ InputStream clientCertificate = new ByteArrayInputStream(Files.readAllBytes(Paths.get("src/test/resources/certs/client.crt")));
+ InputStream clientKey = new ByteArrayInputStream(Files.readAllBytes(Paths.get("src/test/resources/certs/client.key")));
+
+ Location location = Location.forGrpcTls("localhost", server.getPort());
+
+ try (BufferAllocator bufferAllocator = allocator.newChildAllocator("connector-test-client", 0, Long.MAX_VALUE);
+ FlightClient client = FlightClient.builder(bufferAllocator, location).useTls().clientCertificate(clientCertificate, clientKey).trustedCertificates(trustedCertificate).build()) {
+ Ticket ticket = new Ticket(REQUEST_JSON_CODEC.toJsonBytes(createTpchTableRequest(ImmutableList.of(getOrderKeyColumn()))));
+
+ int rowCount = 0;
+ try (FlightStream stream = client.getStream(ticket, CALL_OPTIONS)) {
+ while (stream.next()) {
+ rowCount += stream.getRoot().getRowCount();
+ }
+ }
+
+ assertGreaterThan(rowCount, 0);
+ }
+ }
+
+ @Test
+ public void testSelectColumns()
+ {
+ assertSelectQueryFromColumns(ImmutableList.of(ORDERKEY_COLUMN, LINENUMBER_COLUMN));
+ }
+
+ @Test
+ public void testDateColumns()
+ {
+ assertSelectQueryFromColumns(ImmutableList.of(ORDERKEY_COLUMN, SHIPDATE_COLUMN));
+ }
+
+ @Test
+ public void testFloatingPointColumns()
+ {
+ assertSelectQueryFromColumns(ImmutableList.of(ORDERKEY_COLUMN, QUANTITY_COLUMN, EXTENDEDPRICE_COLUMN));
+ }
+
+ protected void assertSelectQueryFromColumns(List tpchColumnNames)
+ {
+ @Language("SQL") String query = format("SELECT %s FROM %s", String.join(",", tpchColumnNames), TPCH_TABLE);
+ assertQuery(query);
+ }*/
+
+ protected List getHandlesFromSelectQuery(String sql)
+ {
+ String sqlLower = sql.toLowerCase(Locale.ENGLISH);
+ int start = sqlLower.indexOf("select");
+ if (start < 0) {
+ throw new RuntimeException("Expected 'SELECT' in query: " + sql);
+ }
+ start += "select".length();
+ int stop = sqlLower.indexOf("from");
+ if (stop < start) {
+ throw new RuntimeException("Expected 'FROM' in query: " + sql);
+ }
+ String columnsString = sql.substring(start, stop);
+ List columns = Arrays.stream(columnsString.split(",")).map(String::trim).collect(Collectors.toList());
+
+ if (columns.isEmpty()) {
+ throw new RuntimeException("No columns found in query: " + sql);
+ }
+
+ ImmutableList.Builder columnHandlesBuilder = ImmutableList.builder();
+ for (String column : columns) {
+ switch (column) {
+ case ORDERKEY_COLUMN:
+ columnHandlesBuilder.add(getOrderKeyColumn());
+ break;
+ case LINENUMBER_COLUMN:
+ columnHandlesBuilder.add(getLineNumberColumn());
+ break;
+ case QUANTITY_COLUMN:
+ columnHandlesBuilder.add(getQuantityColumn());
+ break;
+ case EXTENDEDPRICE_COLUMN:
+ columnHandlesBuilder.add(getExtendedPriceColumn());
+ break;
+ case LINESTATUS_COLUMN:
+ columnHandlesBuilder.add(getLineStatusColumn());
+ break;
+ case SHIPDATE_COLUMN:
+ columnHandlesBuilder.add(getShipDateColumn());
+ break;
+ case SHIPINSTRUCT_COLUMN:
+ columnHandlesBuilder.add(getShipInstructColumn());
+ break;
+ default:
+ throw new RuntimeException("Unknown column handle for: " + column);
+ }
+ }
+
+ return columnHandlesBuilder.build();
+ }
+
+ protected void assertSelectQueryFromColumns(List tpchColumnNames)
+ {
+ @Language("SQL") String query = format("SELECT %s FROM %s", String.join(",", tpchColumnNames), TPCH_TABLE);
+ assertQuery(query);
+ }
+
+ protected class FlightShimQueryRunner
+ implements QueryRunner
+ {
+ private int totalParts;
+
+ public FlightShimQueryRunner(int totalParts)
+ {
+ this.totalParts = totalParts;
+ }
+
+ public FlightShimQueryRunner()
+ {
+ this(1);
+ }
+
+ @Override
+ public Session getDefaultSession()
+ {
+ return ((QueryRunner) getExpectedQueryRunner()).getDefaultSession();
+ }
+
+ protected int getTotalParts()
+ {
+ return totalParts;
+ }
+
+ protected FlightShimRequest getRequestForPart(int partNumber, List columnHandles)
+ {
+ return createTpchTableRequest(partNumber, getTotalParts(), columnHandles);
+ }
+
+ @Override
+ public MaterializedResult execute(String sql)
+ {
+ try (BufferAllocator bufferAllocator = allocator.newChildAllocator("connector-test-client", 0, Long.MAX_VALUE);
+ FlightClient client = createFlightClient(bufferAllocator, server.getPort())) {
+ List columnHandles = getHandlesFromSelectQuery(sql);
+ List pages = new ArrayList<>();
+
+ for (int i = 0; i < getTotalParts(); i++) {
+ Ticket ticket = new Ticket(REQUEST_JSON_CODEC.toJsonBytes(getRequestForPart(i, columnHandles)));
+ try (FlightStream stream = client.getStream(ticket, CALL_OPTIONS)) {
+ while (stream.next()) {
+ List blocks = new ArrayList<>();
+ for (TpchColumnHandle columnHandle : columnHandles) {
+ FieldVector vector = stream.getRoot().getVector(columnHandle.getColumnName());
+ Block block = blockBuilder.buildBlockFromFieldVector(vector, columnHandle.getType(), null);
+ blocks.add(block);
+ }
+ pages.add(new Page(stream.getRoot().getRowCount(), blocks.toArray(new Block[0])));
+ }
+ }
+ }
+
+ List types = columnHandles.stream().map(TpchColumnHandle::getType).collect(Collectors.toList());
+ MaterializedResult.Builder resultBuilder = MaterializedResult.resultBuilder(getSession(), types);
+ resultBuilder.pages(pages);
+ return resultBuilder.build();
+ }
+ catch (Exception e) {
+ throw new RuntimeException(e);
+ }
+ }
+
+ @Override
+ public int getNodeCount()
+ {
+ return 0;
+ }
+
+ @Override
+ public TransactionManager getTransactionManager()
+ {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public Metadata getMetadata()
+ {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public SplitManager getSplitManager()
+ {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public PageSourceManager getPageSourceManager()
+ {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public NodePartitioningManager getNodePartitioningManager()
+ {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public ConnectorPlanOptimizerManager getPlanOptimizerManager()
+ {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public PlanCheckerProviderManager getPlanCheckerProviderManager()
+ {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public StatsCalculator getStatsCalculator()
+ {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public List getEventListeners()
+ {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public TestingAccessControlManager getAccessControl()
+ {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public ExpressionOptimizerManager getExpressionManager()
+ {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public MaterializedResult execute(Session session, String sql)
+ {
+ return execute(sql);
+ }
+
+ @Override
+ public List listTables(Session session, String catalog, String schema)
+ {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ public boolean tableExists(Session session, String table)
+ {
+ return false;
+ }
+
+ @Override
+ public void installPlugin(Plugin plugin)
+ {
+ }
+
+ @Override
+ public void createCatalog(String catalogName, String connectorName, Map properties)
+ {
+ }
+
+ @Override
+ public void loadFunctionNamespaceManager(String functionNamespaceManagerName, String catalogName, Map properties)
+ {
+ }
+
+ @Override
+ public Lock getExclusiveLock()
+ {
+ return null;
+ }
+
+ @Override
+ public void close()
+ {
+ }
+ }
+
+ private int findUnusedPort()
+ throws IOException
+ {
+ try (ServerSocket socket = new ServerSocket(0)) {
+ return socket.getLocalPort();
+ }
+ }
+
+ protected TpchColumnHandle getOrderKeyColumn()
+ {
+ return new TpchColumnHandle(ORDERKEY_COLUMN, BigintType.BIGINT);
+ }
+
+ protected TpchColumnHandle getLineNumberColumn()
+ {
+ return new TpchColumnHandle(LINENUMBER_COLUMN, IntegerType.INTEGER);
+ }
+
+ protected TpchColumnHandle getLineStatusColumn()
+ {
+ return new TpchColumnHandle(LINESTATUS_COLUMN, VarcharType.createVarcharType(32));
+ }
+
+ protected TpchColumnHandle getQuantityColumn()
+ {
+ return new TpchColumnHandle(QUANTITY_COLUMN, DoubleType.DOUBLE);
+ }
+
+ protected TpchColumnHandle getExtendedPriceColumn()
+ {
+ return new TpchColumnHandle(EXTENDEDPRICE_COLUMN, DoubleType.DOUBLE);
+ }
+
+ protected TpchColumnHandle getShipDateColumn()
+ {
+ return new TpchColumnHandle(SHIPDATE_COLUMN, DateType.DATE);
+ }
+
+ protected TpchColumnHandle getShipInstructColumn()
+ {
+ return new TpchColumnHandle(SHIPINSTRUCT_COLUMN, VarcharType.createVarcharType(64));
+ }
+
+ protected FlightShimRequest createTpchTableRequest(int partNumber, int totalParts, List columnHandles)
+ {
+ String split = createTpchSplit(TPCH_TABLE, partNumber, totalParts);
+ byte[] splitBytes = split.getBytes(StandardCharsets.UTF_8);
+
+ ImmutableList.Builder columnBuilder = ImmutableList.builder();
+ for (TpchColumnHandle columnHandle : columnHandles) {
+ columnBuilder.add(TPCH_COLUMN_JSON_CODEC.toJsonBytes(columnHandle));
+ }
+
+ TpchTableHandle tableHandle = new TpchTableHandle(TPCH_TABLE, 1.0);
+ byte[] tableHandleBytes = TPCH_TABLE_HANDLE_JSON_CODEC.toJsonBytes(tableHandle);
+
+ byte[] transactionHandleBytes = TPCH_TRANSACTION_HANDLE_JSON_CODEC.toJsonBytes(TpchTransactionHandle.INSTANCE);
+
+ return new FlightShimRequest(getConnectorId(), splitBytes, columnBuilder.build(), tableHandleBytes, Optional.empty(), transactionHandleBytes);
+ }
+
+ protected static String createTpchSplit(String tableName, int partNumber, int totalParts)
+ {
+ return format("{\n" +
+ " \"tableHandle\" : {\n" +
+ " \"tableName\" : \"%s\",\n" +
+ " \"scaleFactor\" : %.2f\n" +
+ " },\n" +
+ " \"partNumber\" : %d,\n" +
+ " \"totalParts\" : %d,\n" +
+ " \"addresses\" : [ \"127.0.0.1:9999\" ],\n" +
+ " \"predicate\" : {\n" +
+ " \"columnDomains\" : [ ]\n" +
+ " }\n" +
+ "}", tableName, TINY_SCALE_FACTOR, partNumber, totalParts);
+ }
+
+ protected static FlightClient createFlightClient(BufferAllocator allocator, int serverPort) throws IOException
+ {
+ InputStream trustedCertificate = new ByteArrayInputStream(Files.readAllBytes(Paths.get("src/test/resources/certs/server.crt")));
+ Location location = Location.forGrpcTls("localhost", serverPort);
+ return FlightClient.builder(allocator, location).useTls().trustedCertificates(trustedCertificate).build();
+ }
+}
diff --git a/presto-flight-shim/src/test/java/com/facebook/presto/flightshim/TestFlightShimConfig.java b/presto-flight-shim/src/test/java/com/facebook/presto/flightshim/TestFlightShimConfig.java
new file mode 100644
index 0000000000000..631c3bf7c5f48
--- /dev/null
+++ b/presto-flight-shim/src/test/java/com/facebook/presto/flightshim/TestFlightShimConfig.java
@@ -0,0 +1,70 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.facebook.presto.flightshim;
+
+import com.google.common.collect.ImmutableMap;
+import org.testng.annotations.Test;
+
+import java.util.Map;
+
+import static com.facebook.airlift.configuration.testing.ConfigAssertions.assertFullMapping;
+import static com.facebook.airlift.configuration.testing.ConfigAssertions.assertRecordedDefaults;
+import static com.facebook.airlift.configuration.testing.ConfigAssertions.recordDefaults;
+
+public class TestFlightShimConfig
+{
+ @Test
+ public void testDefaults()
+ {
+ assertRecordedDefaults(recordDefaults(FlightShimConfig.class)
+ .setServerName(null)
+ .setServerPort(null)
+ .setServerSslEnabled(true)
+ .setServerSSLCertificateFile(null)
+ .setServerSSLKeyFile(null)
+ .setReadThreadPoolSize(16)
+ .setClientSSLCertificateFile(null)
+ .setClientSSLKeyFile(null)
+ .setMaxRowsPerBatch(10000));
+ }
+
+ @Test
+ public void testExplicitPropertyMappings()
+ {
+ Map properties = new ImmutableMap.Builder()
+ .put("server", "localhost")
+ .put("server.port", "9432")
+ .put("server-ssl-enabled", "false")
+ .put("server-ssl-certificate-file", "/some/path/server.cert")
+ .put("server-ssl-key-file", "/some/path/server.key")
+ .put("client-ssl-certificate-file", "/some/path/client.cert")
+ .put("client-ssl-key-file", "/some/path/client.key")
+ .put("thread-pool-size", "8")
+ .put("max-rows-per-batch", "1000")
+ .build();
+
+ FlightShimConfig expected = new FlightShimConfig()
+ .setServerName("localhost")
+ .setServerPort(9432)
+ .setServerSslEnabled(false)
+ .setServerSSLCertificateFile("/some/path/server.cert")
+ .setServerSSLKeyFile("/some/path/server.key")
+ .setClientSSLCertificateFile("/some/path/client.cert")
+ .setClientSSLKeyFile("/some/path/client.key")
+ .setReadThreadPoolSize(8)
+ .setMaxRowsPerBatch(1000);
+
+ assertFullMapping(properties, expected);
+ }
+}
diff --git a/presto-flight-shim/src/test/java/com/facebook/presto/flightshim/TestFlightShimMySqlPlugin.java b/presto-flight-shim/src/test/java/com/facebook/presto/flightshim/TestFlightShimMySqlPlugin.java
new file mode 100644
index 0000000000000..c6d9d9eedaefd
--- /dev/null
+++ b/presto-flight-shim/src/test/java/com/facebook/presto/flightshim/TestFlightShimMySqlPlugin.java
@@ -0,0 +1,69 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.facebook.presto.flightshim;
+
+import com.facebook.presto.testing.ExpectedQueryRunner;
+import com.google.common.collect.ImmutableMap;
+import io.airlift.tpch.TpchTable;
+import org.testcontainers.mysql.MySQLContainer;
+import org.testng.annotations.Test;
+
+import static com.facebook.presto.plugin.mysql.MySqlQueryRunner.createMySqlQueryRunner;
+
+@Test
+public class TestFlightShimMySqlPlugin
+ extends AbstractTestFlightShimJdbcPlugins
+{
+ private static final String DATABASE_USERNAME = "testuser";
+ private static final String DATABASE_PASSWORD = "testpass";
+ private final MySQLContainer mysqlContainer;
+
+ public TestFlightShimMySqlPlugin()
+ {
+ this.mysqlContainer = new MySQLContainer("mysql:8.0")
+ .withDatabaseName("tpch")
+ .withUsername(DATABASE_USERNAME)
+ .withPassword(DATABASE_PASSWORD);
+ mysqlContainer.start();
+ closables.add(mysqlContainer);
+ }
+
+ @Override
+ protected String getConnectorId()
+ {
+ return "mysql";
+ }
+
+ @Override
+ protected String getConnectionUrl()
+ {
+ return addDatabaseCredentialsToJdbcUrl(
+ removeDatabaseFromJdbcUrl(mysqlContainer.getJdbcUrl()),
+ DATABASE_USERNAME,
+ DATABASE_PASSWORD);
+ }
+
+ @Override
+ protected String getPluginBundles()
+ {
+ return "../presto-mysql/pom.xml";
+ }
+
+ @Override
+ protected ExpectedQueryRunner createExpectedQueryRunner()
+ throws Exception
+ {
+ return createMySqlQueryRunner(mysqlContainer.getJdbcUrl(), ImmutableMap.of(), TpchTable.getTables());
+ }
+}
diff --git a/presto-flight-shim/src/test/java/com/facebook/presto/flightshim/TestFlightShimPostgresPlugin.java b/presto-flight-shim/src/test/java/com/facebook/presto/flightshim/TestFlightShimPostgresPlugin.java
new file mode 100644
index 0000000000000..af22d32c5eeb5
--- /dev/null
+++ b/presto-flight-shim/src/test/java/com/facebook/presto/flightshim/TestFlightShimPostgresPlugin.java
@@ -0,0 +1,66 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.facebook.presto.flightshim;
+
+import com.facebook.presto.testing.ExpectedQueryRunner;
+import com.google.common.collect.ImmutableMap;
+import io.airlift.tpch.TpchTable;
+import org.testcontainers.postgresql.PostgreSQLContainer;
+import org.testng.annotations.Test;
+
+import static com.facebook.presto.plugin.postgresql.PostgreSqlQueryRunner.createPostgreSqlQueryRunner;
+
+@Test
+public class TestFlightShimPostgresPlugin
+ extends AbstractTestFlightShimJdbcPlugins
+{
+ private static final String DATABASE_USERNAME = "testuser";
+ private static final String DATABASE_PASSWORD = "testpass";
+ private final PostgreSQLContainer postgresContainer;
+
+ public TestFlightShimPostgresPlugin()
+ {
+ this.postgresContainer = new PostgreSQLContainer("postgres:14")
+ .withDatabaseName("tpch")
+ .withUsername(DATABASE_USERNAME)
+ .withPassword(DATABASE_PASSWORD);
+ postgresContainer.start();
+ closables.add(postgresContainer);
+ }
+
+ @Override
+ protected String getConnectorId()
+ {
+ return "postgresql";
+ }
+
+ @Override
+ protected String getConnectionUrl()
+ {
+ return addDatabaseCredentialsToJdbcUrl(postgresContainer.getJdbcUrl(), DATABASE_USERNAME, DATABASE_PASSWORD);
+ }
+
+ @Override
+ protected String getPluginBundles()
+ {
+ return "../presto-postgresql/pom.xml";
+ }
+
+ @Override
+ protected ExpectedQueryRunner createExpectedQueryRunner()
+ throws Exception
+ {
+ return createPostgreSqlQueryRunner(postgresContainer.getJdbcUrl(), ImmutableMap.of(), TpchTable.getTables());
+ }
+}
diff --git a/presto-flight-shim/src/test/java/com/facebook/presto/flightshim/TestFlightShimRequest.java b/presto-flight-shim/src/test/java/com/facebook/presto/flightshim/TestFlightShimRequest.java
new file mode 100644
index 0000000000000..636f7f2006584
--- /dev/null
+++ b/presto-flight-shim/src/test/java/com/facebook/presto/flightshim/TestFlightShimRequest.java
@@ -0,0 +1,94 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.facebook.presto.flightshim;
+
+import com.facebook.presto.common.type.BigintType;
+import com.facebook.presto.common.type.VarcharType;
+import com.facebook.presto.plugin.jdbc.JdbcColumnHandle;
+import com.facebook.presto.plugin.jdbc.JdbcTableHandle;
+import com.facebook.presto.plugin.jdbc.JdbcTransactionHandle;
+import com.facebook.presto.plugin.jdbc.JdbcTypeHandle;
+import com.facebook.presto.spi.SchemaTableName;
+import com.google.common.collect.ImmutableList;
+import org.testng.annotations.Test;
+
+import java.nio.charset.StandardCharsets;
+import java.sql.Types;
+import java.util.Optional;
+
+import static com.facebook.presto.flightshim.AbstractTestFlightShimJdbcPlugins.COLUMN_HANDLE_JSON_CODEC;
+import static com.facebook.presto.flightshim.AbstractTestFlightShimJdbcPlugins.TABLE_HANDLE_JSON_CODEC;
+import static com.facebook.presto.flightshim.AbstractTestFlightShimJdbcPlugins.TRANSACTION_HANDLE_JSON_CODEC;
+import static com.facebook.presto.flightshim.AbstractTestFlightShimJdbcPlugins.createJdbcSplit;
+import static com.facebook.presto.flightshim.AbstractTestFlightShimPlugins.LINESTATUS_COLUMN;
+import static com.facebook.presto.flightshim.AbstractTestFlightShimPlugins.ORDERKEY_COLUMN;
+import static com.facebook.presto.flightshim.AbstractTestFlightShimPlugins.TPCH_TABLE;
+import static org.testng.Assert.assertEquals;
+import static org.testng.internal.junit.ArrayAsserts.assertArrayEquals;
+
+public class TestFlightShimRequest
+{
+ @Test
+ public void testJsonRoundTrip()
+ {
+ FlightShimRequest expected = createTpchCustomerRequest();
+ String json = AbstractTestFlightShimPlugins.REQUEST_JSON_CODEC.toJson(expected);
+ FlightShimRequest copy = AbstractTestFlightShimPlugins.REQUEST_JSON_CODEC.fromJson(json);
+ assertEquals(copy.getConnectorId(), expected.getConnectorId());
+ assertEquals(copy.getSplitBytes(), expected.getSplitBytes());
+ assertArrayEquals(copy.getColumnHandlesBytes().toArray(), expected.getColumnHandlesBytes().toArray());
+ assertEquals(copy.getTableHandleBytes(), expected.getTableHandleBytes());
+ assertEquals(copy.getTableLayoutHandleBytes(), expected.getTableLayoutHandleBytes());
+ assertEquals(copy.getTransactionHandleBytes(), expected.getTransactionHandleBytes());
+ }
+
+ FlightShimRequest createTpchCustomerRequest()
+ {
+ String split = createJdbcSplit("postgresql", "tpch", TPCH_TABLE);
+ byte[] splitBytes = split.getBytes(StandardCharsets.UTF_8);
+
+ JdbcColumnHandle custkeyHandle = new JdbcColumnHandle(
+ "postgresql",
+ ORDERKEY_COLUMN,
+ new JdbcTypeHandle(Types.BIGINT, "bigint", 8, 0),
+ BigintType.BIGINT,
+ false,
+ Optional.empty());
+ byte[] custkeyBytes = COLUMN_HANDLE_JSON_CODEC.toJsonBytes(custkeyHandle);
+
+ JdbcColumnHandle nameHandle = new JdbcColumnHandle(
+ "postgresql",
+ LINESTATUS_COLUMN,
+ new JdbcTypeHandle(Types.VARCHAR, "varchar", 32, 0),
+ VarcharType.createVarcharType(32),
+ false,
+ Optional.empty());
+ byte[] nameBytes = COLUMN_HANDLE_JSON_CODEC.toJsonBytes(nameHandle);
+
+ JdbcTableHandle tableHandle = new JdbcTableHandle("postgresql", new SchemaTableName("tpch", TPCH_TABLE), "postgresql", "tpch", TPCH_TABLE);
+ byte[] tableHandleBytes = TABLE_HANDLE_JSON_CODEC.toJsonBytes(tableHandle);
+
+ JdbcTransactionHandle transactionHandle = new JdbcTransactionHandle();
+ byte[] transactionHandleBytes = TRANSACTION_HANDLE_JSON_CODEC.toJsonBytes(transactionHandle);
+
+ return new FlightShimRequest(
+ "postgresql",
+ splitBytes,
+ ImmutableList.of(custkeyBytes, nameBytes),
+ tableHandleBytes,
+ Optional.empty(),
+ transactionHandleBytes
+ );
+ }
+}
diff --git a/presto-flight-shim/src/test/java/com/facebook/presto/flightshim/TestFlightShimSingleStorePlugin.java b/presto-flight-shim/src/test/java/com/facebook/presto/flightshim/TestFlightShimSingleStorePlugin.java
new file mode 100644
index 0000000000000..41751a7501c1f
--- /dev/null
+++ b/presto-flight-shim/src/test/java/com/facebook/presto/flightshim/TestFlightShimSingleStorePlugin.java
@@ -0,0 +1,59 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.facebook.presto.flightshim;
+
+import com.facebook.presto.plugin.singlestore.DockerizedSingleStoreServer;
+import com.facebook.presto.plugin.singlestore.SingleStoreQueryRunner;
+import com.facebook.presto.testing.ExpectedQueryRunner;
+import com.google.common.collect.ImmutableMap;
+import io.airlift.tpch.TpchTable;
+import org.testng.annotations.Test;
+
+@Test
+public class TestFlightShimSingleStorePlugin
+ extends AbstractTestFlightShimJdbcPlugins
+{
+ private final DockerizedSingleStoreServer singleStoreServer;
+
+ public TestFlightShimSingleStorePlugin()
+ {
+ this.singleStoreServer = new DockerizedSingleStoreServer();
+ closables.add(singleStoreServer);
+ }
+
+ @Override
+ protected String getConnectorId()
+ {
+ return "singlestore";
+ }
+
+ @Override
+ protected String getConnectionUrl()
+ {
+ return singleStoreServer.getJdbcUrl();
+ }
+
+ @Override
+ protected String getPluginBundles()
+ {
+ return "../presto-singlestore/pom.xml";
+ }
+
+ @Override
+ protected ExpectedQueryRunner createExpectedQueryRunner()
+ throws Exception
+ {
+ return SingleStoreQueryRunner.createSingleStoreQueryRunner(singleStoreServer, ImmutableMap.of(), TpchTable.getTables());
+ }
+}
diff --git a/presto-flight-shim/src/test/java/com/facebook/presto/flightshim/TestFlightShimTpchPlugin.java b/presto-flight-shim/src/test/java/com/facebook/presto/flightshim/TestFlightShimTpchPlugin.java
new file mode 100644
index 0000000000000..b5d097c004076
--- /dev/null
+++ b/presto-flight-shim/src/test/java/com/facebook/presto/flightshim/TestFlightShimTpchPlugin.java
@@ -0,0 +1,154 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.facebook.presto.flightshim;
+
+import com.facebook.presto.Session;
+import com.facebook.presto.testing.ExpectedQueryRunner;
+import com.facebook.presto.testing.LocalQueryRunner;
+import com.facebook.presto.tpch.TpchConnectorFactory;
+import com.google.common.collect.ImmutableList;
+import com.google.common.collect.ImmutableMap;
+import org.apache.arrow.flight.FlightClient;
+import org.apache.arrow.flight.FlightStream;
+import org.apache.arrow.flight.Ticket;
+import org.apache.arrow.memory.BufferAllocator;
+import org.testng.annotations.Test;
+
+import java.util.concurrent.CancellationException;
+
+import static com.facebook.airlift.testing.Assertions.assertGreaterThan;
+import static com.facebook.presto.testing.TestingSession.testSessionBuilder;
+import static org.testng.Assert.assertEquals;
+import static org.testng.Assert.assertNotNull;
+
+@Test(singleThreaded = true)
+public class TestFlightShimTpchPlugin
+ extends AbstractTestFlightShimPlugins
+{
+ public static final int SPLITS_PER_NODE = 4;
+
+ protected String getConnectorId()
+ {
+ return "tpch";
+ }
+
+ protected String getPluginBundles()
+ {
+ return "../presto-tpch/pom.xml";
+ }
+
+ @Override
+ protected ExpectedQueryRunner createExpectedQueryRunner()
+ {
+ Session session = testSessionBuilder()
+ .setCatalog("tpch")
+ .setSchema("tiny")
+ .build();
+ return new LocalQueryRunner(session);
+ }
+
+ @Override
+ protected int getExpectedTotalParts()
+ {
+ return SPLITS_PER_NODE;
+ }
+
+ @Override
+ protected void createTables()
+ {
+ ((LocalQueryRunner) getExpectedQueryRunner()).createCatalog("tpch", new TpchConnectorFactory(SPLITS_PER_NODE), ImmutableMap.of());
+ }
+
+ @Test
+ public void testConnectorGetStream() throws Exception
+ {
+ try (BufferAllocator bufferAllocator = allocator.newChildAllocator("connector-test-client", 0, Long.MAX_VALUE);
+ FlightClient client = createFlightClient(bufferAllocator, server.getPort())) {
+ Ticket ticket = new Ticket(REQUEST_JSON_CODEC.toJsonBytes(createTpchTableRequest(0, 1, ImmutableList.of(getOrderKeyColumn()))));
+
+ int rowCount = 0;
+ try (FlightStream stream = client.getStream(ticket, CALL_OPTIONS)) {
+ while (stream.next()) {
+ rowCount += stream.getRoot().getRowCount();
+ }
+ }
+
+ assertGreaterThan(rowCount, 0);
+ }
+ }
+
+ @Test
+ public void testStopStreamAtLimit() throws Exception
+ {
+ int rowLimit = 500;
+ try (BufferAllocator bufferAllocator = allocator.newChildAllocator("connector-test-client", 0, Long.MAX_VALUE);
+ FlightClient client = createFlightClient(bufferAllocator, server.getPort())) {
+ Ticket ticket = new Ticket(REQUEST_JSON_CODEC.toJsonBytes(createTpchTableRequest(0, 1, ImmutableList.of(getOrderKeyColumn()))));
+
+ int rowCount = 0;
+ try (FlightStream stream = client.getStream(ticket, CALL_OPTIONS)) {
+ while (stream.next()) {
+ rowCount += stream.getRoot().getRowCount();
+ if (rowCount >= rowLimit) {
+ break;
+ }
+ }
+ }
+
+ assertEquals(rowCount, rowLimit);
+ }
+ }
+
+ @Test
+ public void testCancelStream() throws Exception
+ {
+ String cancelMessage = "READ COMPLETE";
+ try (BufferAllocator bufferAllocator = allocator.newChildAllocator("connector-test-client", 0, Long.MAX_VALUE);
+ FlightClient client = createFlightClient(bufferAllocator, server.getPort())) {
+ Ticket ticket = new Ticket(REQUEST_JSON_CODEC.toJsonBytes(createTpchTableRequest(0, 1, ImmutableList.of(getOrderKeyColumn()))));
+
+ // Cancel stream explicitly
+ int rowCount = 0;
+ try (FlightStream stream = client.getStream(ticket, CALL_OPTIONS)) {
+ while (stream.next()) {
+ rowCount += stream.getRoot().getRowCount();
+ if (rowCount >= 500) {
+ stream.cancel("Cancel", new CancellationException(cancelMessage));
+ break;
+ }
+ }
+
+ // Drain any remaining messages to properly release messages
+ try {
+ do {
+ Thread.sleep(100);
+ }
+ while (stream.next());
+ }
+ catch (final Exception e) {
+ assertNotNull(e.getCause());
+ assertEquals(e.getCause().getMessage(), cancelMessage);
+ }
+ }
+
+ assertGreaterThan(rowCount, 0);
+ }
+ }
+
+ @Test
+ public void testReadFromMultipleSplits()
+ {
+ assertSelectQueryFromColumns(ImmutableList.of(ORDERKEY_COLUMN, LINENUMBER_COLUMN, SHIPINSTRUCT_COLUMN));
+ }
+}
diff --git a/presto-flight-shim/src/test/resources/certs/ca.crt b/presto-flight-shim/src/test/resources/certs/ca.crt
new file mode 100644
index 0000000000000..3d3f516aa0148
--- /dev/null
+++ b/presto-flight-shim/src/test/resources/certs/ca.crt
@@ -0,0 +1,33 @@
+-----BEGIN CERTIFICATE-----
+MIIFozCCA4ugAwIBAgIUTnZhzGzxAzaTJ5DDQnmBp8jvBUMwDQYJKoZIhvcNAQEL
+BQAwYDELMAkGA1UEBhMCVVMxDjAMBgNVBAgMBVN0YXRlMQ0wCwYDVQQHDARDaXR5
+MQ4wDAYDVQQKDAVNeU9yZzEPMA0GA1UECwwGTXlVbml0MREwDwYDVQQDDAhNeVJv
+b3RDQTAgFw0yNTA1MjIxNjA2MDNaGA8yMTI1MDQyODE2MDYwM1owYDELMAkGA1UE
+BhMCVVMxDjAMBgNVBAgMBVN0YXRlMQ0wCwYDVQQHDARDaXR5MQ4wDAYDVQQKDAVN
+eU9yZzEPMA0GA1UECwwGTXlVbml0MREwDwYDVQQDDAhNeVJvb3RDQTCCAiIwDQYJ
+KoZIhvcNAQEBBQADggIPADCCAgoCggIBALmNRxPZt9vvFvIQcdCM3xcAOiTepahR
+zsK86AwzXbwcWwOw+rQP0iXiMGjP9wJy0C1hU+j7F+gu5+d5j3DK9/iTlIqByjN+
+gffsZDSBBZBozJxARBWKaEFKRKbhCmr+v92ZN1KMYhhwc5W3SRp6y3kvsWYM4/JE
+pRzYUy0505g3p7PnpQzZZSSzV33+9fn5ggafVB9FSJqOuBKUBTITt+Rf1pNV2ZGC
+f2X39KOIIRz2Hz1L6hdqeMQN9V5KQ7C6oPfl+ho97JHaZirtHr1XZJ6YQRy7Mldp
+jOPj4MMKqi8m+HhyCSznIDNiMp0OACX3CZ5RmKRYfnunOuw+fvI28HyOXMgfWKa2
+o/2/YlAzNrD50hJPwQMTlKWJm2gY2n5x2FWT6/8aXeCnsJALK6Dj6Ax0wxkAFfIo
+76REHlXX2fIiz0cciYwYtvwhjp5efqX22B7LDkhu7fJ42yUd6g3crbmGM8OOoXgL
+w3MWx30FatTyDT8un2ZvVDJEADW3+WWtyrZWHFMVFrVnN7Di4MAuWsRIVtuO6PEV
+6pPS55NBmvKzWAoYBmpH103GlIxvZ6CCTeKqFbcrI77smrIO5CLmzl5yjb/urmy4
+GLYHM+EPB5pJKQO8g6fyM369mhEjBxt/RJGv9E0Jw7KmnHn5V2qSn4Yi41dUp7Cp
+pQVIaYlJGn65AgMBAAGjUzBRMB0GA1UdDgQWBBRjt9KoJO8CO/Cy/xTJztrbdUlv
+9DAfBgNVHSMEGDAWgBRjt9KoJO8CO/Cy/xTJztrbdUlv9DAPBgNVHRMBAf8EBTAD
+AQH/MA0GCSqGSIb3DQEBCwUAA4ICAQApizBpSaZDYjBhOxD6+B5em18OIf5ZIjHk
+iAhuppR+HaqyAHJgCO707dZVmFz9ECUZI9mvKcmj+h0Wh/mK4cSiDunFB9yUr67U
+wV5F2/u/JAAq6VsbdrDiZPUwET8U9ai14LMEgPPF+Zif+wnopRav5lbiPoJVUjqr
+wVoP2AHijIP46YCWwXqOTJMC79ccUMBZeDwF4bOquIADLmEnANp6fiMI+eE6OLFs
+fDtjFqybRUZqzewv2lpzH2ZYEYk4bIk76TGkOYtrwJ+iQj77ZZFSBW5zkry/zaG3
+/5Ufjv65T9Zr1jmIMigcmCHwNsCLOYzIKjRaiuLGs4B9s8SGEauTRhP0dG5ndWPw
+50NeSNJr37MHdKky44WAFlAk9BAKlghOaC5m2RyMof8DwYKPEe5epe1wBotiPqSX
+doaZvch6wkuo8xvFKqH6rBTWJLMwuFt7m3XrGqGYlE+1gvuEfn7ZGzG00sl218mZ
+MfsLqJfft92ARC1/qJvUFr5mM6SV4eQeTl1tAtv6Xfczr+3/iqc5gbeG25dXQclO
+y1qIKthAoXFq6rAZ+bvfASiVV1OQS76nWSiYYS1dDPQJ/g4aawOkUYr9OjS5HNQ+
+rNAcLB+I0oaZQzZ85098qAVAJ76eFATb4ieDK6m0j6Fq5ddwrlYzxEEr7TscfHW2
+zPCtinUe0w==
+-----END CERTIFICATE-----
diff --git a/presto-flight-shim/src/test/resources/certs/client.crt b/presto-flight-shim/src/test/resources/certs/client.crt
new file mode 100644
index 0000000000000..0e2bc9ff7f662
--- /dev/null
+++ b/presto-flight-shim/src/test/resources/certs/client.crt
@@ -0,0 +1,27 @@
+-----BEGIN CERTIFICATE-----
+MIIEkDCCAnigAwIBAgIUUIByH9V7DhUf5n3Qd7tPxpPixQYwDQYJKoZIhvcNAQEL
+BQAwYDELMAkGA1UEBhMCVVMxDjAMBgNVBAgMBVN0YXRlMQ0wCwYDVQQHDARDaXR5
+MQ4wDAYDVQQKDAVNeU9yZzEPMA0GA1UECwwGTXlVbml0MREwDwYDVQQDDAhNeVJv
+b3RDQTAgFw0yNTA1MjIxNjA4MzFaGA8yMTI1MDQyODE2MDgzMVowXjELMAkGA1UE
+BhMCVVMxDjAMBgNVBAgMBVN0YXRlMQ0wCwYDVQQHDARDaXR5MQ4wDAYDVQQKDAVN
+eU9yZzEPMA0GA1UECwwGTXlVbml0MQ8wDQYDVQQDDAZjbGllbnQwggEiMA0GCSqG
+SIb3DQEBAQUAA4IBDwAwggEKAoIBAQC8DK/RdBF6I+k+DMGrjhMMBCnpPNwJtzJU
+uXYcHFYEdBnHY/rpjk/fi+7jD8bppynCZPakrDX+5VIMzS4HBU/CHY26eR2ItiWq
+DoDkPAlCdgeKIGNYYEvVSuUW5YQX6fuD8PfCpCP5zK7DJC2xTTsyEjBzD+MnIB7T
+ja9/22Djo2Ib2l9BEBOD+k79caPFtSqDQVdS5JLJ/P7BeqGuFS8bEgtLwwCzRxPK
+kG64rXb9F0IErGwjXi/70BA24EW0uGAzzeY5Pnlx5MulEIjuxII/dTuoo+/uirN0
+wxURKSzTMyzzoJqGX9Ng9+Z+VqWFrnqckcmFcD4T0sPoHpWZsYIHAgMBAAGjQjBA
+MB0GA1UdDgQWBBRW/6j5rNwnwTBAXY5igXm3ETzv/DAfBgNVHSMEGDAWgBRjt9Ko
+JO8CO/Cy/xTJztrbdUlv9DANBgkqhkiG9w0BAQsFAAOCAgEAVLpfZDgkL1Dz/+Fq
+vSl2IoxOFNd2DTa6yM8/1wpvMVTA024lp0ttyoz0o1621hTRexcXTimZqUNAtPV6
+Gwmb2ACLN4XtLk9QT9XjDtWKzPxCJ+ze7rrhj1jYqv9yUebdkJoMKfcbwYi0gtpt
+HlaJqNoKgzZxOCGhTtdS3ypb9nDCyx3fmFk5mIYfzEszoMmqNL006ANlJ0IKkFZj
+vUkkFyMGLerInmTDRjDLkUCkNKaJUYjZhf/FNwVtc1A9a/bDJMEYVog+CY0dpXKb
+1IGaXzB4ewhuQKuhb/LCZT1pNm/cGCY2cRGFy9EVAuZ5FV0ajh1HYwdpGDzucoUD
+UaTHIK40E2/kZorJa37Xyn7Lekgun6YpfOudBkKg5mlV6qoL9W/lTZPjMHs2ufvW
+/A5S4okR4JmhC44TMgAv90MU9yEP90OkzW6egatBShWySJ3Bn5W+ebQSwJ38wgTy
+e6j5jWh7xiiPC4TJbSXVMGEfJw/c2hx4R/83MqBhVLPfoapaUCDUWniv6n7zl5ML
+k7WIZzXSK212/H+eVXFJ1Gq6zztOPkN9QGgr+dbsCzdLWPJzZDmI7+lgT2EKxIV/
+VMtHVOe2bLkePTNA2+vXQhs0p7JDtzyATAyMdhJwPljt8X+HJxEoAP4Dk6PcK3Bd
+E92yOuW2jl6FzgqKqtVQvxIiBgc=
+-----END CERTIFICATE-----
diff --git a/presto-flight-shim/src/test/resources/certs/client.key b/presto-flight-shim/src/test/resources/certs/client.key
new file mode 100644
index 0000000000000..72434bc294aff
--- /dev/null
+++ b/presto-flight-shim/src/test/resources/certs/client.key
@@ -0,0 +1,28 @@
+-----BEGIN PRIVATE KEY-----
+MIIEvgIBADANBgkqhkiG9w0BAQEFAASCBKgwggSkAgEAAoIBAQC8DK/RdBF6I+k+
+DMGrjhMMBCnpPNwJtzJUuXYcHFYEdBnHY/rpjk/fi+7jD8bppynCZPakrDX+5VIM
+zS4HBU/CHY26eR2ItiWqDoDkPAlCdgeKIGNYYEvVSuUW5YQX6fuD8PfCpCP5zK7D
+JC2xTTsyEjBzD+MnIB7Tja9/22Djo2Ib2l9BEBOD+k79caPFtSqDQVdS5JLJ/P7B
+eqGuFS8bEgtLwwCzRxPKkG64rXb9F0IErGwjXi/70BA24EW0uGAzzeY5Pnlx5Mul
+EIjuxII/dTuoo+/uirN0wxURKSzTMyzzoJqGX9Ng9+Z+VqWFrnqckcmFcD4T0sPo
+HpWZsYIHAgMBAAECggEAViw6JXFa0O3D5HtUBJmGgOsniYoqCwm4NrsGNLuHb2ME
+rSpTwNNGJtqpDcQdEtVXfY1muO9xjuznPJaJkQ4ODpYcbGcz8YIGoHck+XHJjHsp
+2VIeNFFsbsFzWZqzfYHrj/rMjpVJJx90tlfN2IHbroZHTXLqVPOTLL6wvZZ6P9XF
+zqpWABKOaEDbenhSFFZeF5KR3NG9HSTm4YLuekumkH+QgrveDfDwXG4hAHqg836o
+OF3NPaij6VlSR18nuyW0wMs/Ceu13P+GALqHmz98pFyVgHWQFryL9IccvJQDyEnt
+saeG4IAVlJbZDGTnRgANLhpwBr7XhMG1aK+wmOMRgQKBgQDkcatiATlr9L8gfnHb
+6pmX//AZLdXuQLXfuTvu638Brhm770noLgfIC+HIp5kCHxT2Xj5Vn+MSnYD6R6Wh
+chApRKJUdsuz1iOq23YJjvsSLWCGpl9IxR7WY27uGOPIjQcOd1PRbkCq9AgUJwyn
+ryca3sbYh/XQOWGLbJNIQs/S/QKBgQDSu6PVeMaS3276KblvGIvvaSAQDQWxXcC+
+sA4CBmvjzx3xx5GAox/w7tcKmK/KQxNhaYy6N7xLc1YUJ9FbnT2PZQJhtP2d2Gat
+Zre/+Qa+u84cR5hj9EI+B8FjW7D/psEj16KjHCds/SET6ngPM+RdB4N9daVFCurt
+p0f717yiUwKBgBTJDun06I+dDkLbnmp/FwiQff0cgYmTE7lOdliPzteNSsQhypy4
+i3a1Ng72yOI7h8G+43cQ/C02bYTYPgbJhRTsLMT4piIvysECBORrwQZvYIf/3U2W
+ue6Rz4cUdq1Jv6meS98TZAjp+U40G1+qfSlhub/75u7SOcDg2SnLAnPVAoGBAIOO
+EmRE5qpwA+b2P0Ykq89E8Hg0uPYWEiq427XV7mqkNQxoSuRkcZ9Ga0a5NRzurN2m
+N+1UuB7eHMGubdtkmTa4lzkJ9T4iB09/DX0x6E0QD0bGR1M2/FefHdJ6PlAK+Q34
+Ixbyj4ZRq+G0AUl0Wr7c3vBmjktA2pKMWLrW3nLzAoGBAKTl7qX6CD42gAJuT5Hp
+rrXqlppVIyRvuXzXtX/Xq81IUHlBgS/t9HPyqDzmTKfxD8540kI+15bWPDHSJxiQ
+ccqPaKyXhBXstDwGmlPKVzJUxk0dz5NHs+8gItUDOg78pM3siXN7vW9XBCH7mCDA
+4zet/C0YCAiFVT+ipMoXy8Nc
+-----END PRIVATE KEY-----
diff --git a/presto-flight-shim/src/test/resources/certs/server.crt b/presto-flight-shim/src/test/resources/certs/server.crt
new file mode 100644
index 0000000000000..8fb80c1bba51d
--- /dev/null
+++ b/presto-flight-shim/src/test/resources/certs/server.crt
@@ -0,0 +1,27 @@
+-----BEGIN CERTIFICATE-----
+MIIEkzCCAnugAwIBAgIUUIByH9V7DhUf5n3Qd7tPxpPixQgwDQYJKoZIhvcNAQEL
+BQAwYDELMAkGA1UEBhMCVVMxDjAMBgNVBAgMBVN0YXRlMQ0wCwYDVQQHDARDaXR5
+MQ4wDAYDVQQKDAVNeU9yZzEPMA0GA1UECwwGTXlVbml0MREwDwYDVQQDDAhNeVJv
+b3RDQTAgFw0yNTA1MjMwMTQzNTBaGA8yMTI1MDQyOTAxNDM1MFowYTELMAkGA1UE
+BhMCVVMxDjAMBgNVBAgMBVN0YXRlMQ0wCwYDVQQHDARDaXR5MQ4wDAYDVQQKDAVN
+eU9yZzEPMA0GA1UECwwGTXlVbml0MRIwEAYDVQQDDAlsb2NhbGhvc3QwggEiMA0G
+CSqGSIb3DQEBAQUAA4IBDwAwggEKAoIBAQCtkJ+r1F8+YOVuwWLxbGVsJKw3BESh
+tCsU+IXHJeaNJdBr59B/4h5WM37wOnnecmyEZTh47FXXkb5h0xVlHES7eTAD+NPl
+WHufGJ9PR1kvQyZ0fyNRFXLzUID/dl7atHBtlrqE5Bhg7xqAyPZjUjhkAZPgrT1/
+8+gYmbWPbw3Ba3+XRupq3Kn+EVVJi7wk4cj8jf6g1aex6sMOkSYNsanb+JdEryev
+goju+EtHgCHL6cB0eJs8PfMWiibgWLE2pkI0bdbGjTNVDDygZoO8Qr/YrGvXXYqt
+0D7IKSUiO8bnrvZh6ITPEcQ3ePQRGEpqh8ggKaVq3RVkC3t3QMRWGpsvAgMBAAGj
+QjBAMB0GA1UdDgQWBBRf0XdDhjNSIZQsvSFbe+hdsydllzAfBgNVHSMEGDAWgBRj
+t9KoJO8CO/Cy/xTJztrbdUlv9DANBgkqhkiG9w0BAQsFAAOCAgEAec7y1Odyg3x/
+Uj0jfZWYNE0BuR114UVwEYhxFi9tRAGxjTlsl6ATCSYBWU+fwUkC29C2r3bu59fp
+/KvYPRrwGPyOtXwHR2cmwJ7QrUhlIwPipWO4Kal3/EKWvrV9rzdnOd2QYqEMS6f0
+UixGLcT5p5KmEsH3W8Y9Uk94g/z12ZgdGeKKyY7hWnu1d47b2TS4oiRx+d6AacAD
+1BzJRUhDjS2Vfe2cnpOqBHJWyCT1BxsfxKAc3rLa6JznbulHQPCE5WWBolHb8Tob
+Yf32sIJydOcWU+zJ9VsGuglQ8dQInMemW8k5y48ACqHb00lAoucGJ3Izy9tJiBeU
+C8TcmRxQSjoMGhFGNIjBAvXak0UzobUKE3YyABBbUdWLofs0N325K1aSga4vXmlO
+OPzP4FWMLZyDicVMUGLP9jcyeb/gFMzjoU2En59gRNVDNNo01Lyj9MhGHF/jvV7p
+Zf782GvXMT/NSPtXqmABSV/Svy70vXogeQxTii9YOZePKWQnFhcEwin/7d9bf4d/
+nUqtzDFb5FiDeFA+H9FyeNaeub3OtvsZUacAVDCT1t9/8uShjJo34v8WbJegepcY
+Tpdkm4x0DWvv/QT3JBqC0wprsmBVzuqTJ/jBxYye02bds1bM7xePAuVOwuSb6Azg
+gLrweiDMakgmZSkGgUonnjblYbHGTHw=
+-----END CERTIFICATE-----
diff --git a/presto-flight-shim/src/test/resources/certs/server.key b/presto-flight-shim/src/test/resources/certs/server.key
new file mode 100644
index 0000000000000..0e49bebd60beb
--- /dev/null
+++ b/presto-flight-shim/src/test/resources/certs/server.key
@@ -0,0 +1,28 @@
+-----BEGIN PRIVATE KEY-----
+MIIEvwIBADANBgkqhkiG9w0BAQEFAASCBKkwggSlAgEAAoIBAQCtkJ+r1F8+YOVu
+wWLxbGVsJKw3BEShtCsU+IXHJeaNJdBr59B/4h5WM37wOnnecmyEZTh47FXXkb5h
+0xVlHES7eTAD+NPlWHufGJ9PR1kvQyZ0fyNRFXLzUID/dl7atHBtlrqE5Bhg7xqA
+yPZjUjhkAZPgrT1/8+gYmbWPbw3Ba3+XRupq3Kn+EVVJi7wk4cj8jf6g1aex6sMO
+kSYNsanb+JdEryevgoju+EtHgCHL6cB0eJs8PfMWiibgWLE2pkI0bdbGjTNVDDyg
+ZoO8Qr/YrGvXXYqt0D7IKSUiO8bnrvZh6ITPEcQ3ePQRGEpqh8ggKaVq3RVkC3t3
+QMRWGpsvAgMBAAECggEAIK9C+lNAallJ62z8inU8tjxDuAqUOBVbJZRVcPbIr1zn
+HmLlpyd4Sghhh7CjYYoPuHDtTQxIcBNwlDBxb3x+zwUXzy+tC5v5j7DN01qex2Ew
+XTDSAEN3Ra2r1S+/1hSztVd0oXDozFxKk+UETRjfKKoJZH6LPcy7MOLFR5EEuJ8L
+0kvGdEtuNCmZ1vPBwqR3IKQS9NsB1IdTtK0g2LdtVzM3U6F173CrAx51qNeAL30j
+Np+I0rfm7vYVco6nDQXJB86hzwwBnLMzmZR2E0z+JStQCjQtEJN9wp+NBnViMb8C
+mZl0K/PH3ZKNEs1Aw/TsRpPu6Fc+sN6iIs2oOGiKfQKBgQDa0Wpfj+SHflLfmrRU
+PplGNjWdJiyuXROqX18iNE8nAD0eRqAFdzj9yU1IW49KCzuHInEl2pP9yrDZTWXB
+Bht4C+Vk13mrBE3Sc1LDrks5EhDLaaolLgx1B+JN1X2DpfuzO8WHrXR11PCzFTAp
+yDSVd451CFFXMseS1V9UxCy3lQKBgQDLDrLX/0hGhG+a5RUaAE+hZk+tU9RyjYm6
+/5lIoDjDwA9Yst69JCTHDApkdZ6IrjPDZrxkAQR6QwsGo+zRGkHV2wCoqR/RxcT5
+RBcbe/8xL86ZKwnhAheP6ssgZeK5zOG1iLol319kXXuo6NueN+YlocmsppRvAOq7
+/qMnhzXGswKBgQCpke2wHo9HnNJWK8ohGt2mtm232ZR4jvKlbgEIPac1Hw89/hcW
+BT0qFqyILUQOakP4Re2PGyLiYwfHbh4zhisVTYq4Ke9EYzJ3qxzxPYlXsbNIHxtW
+cqf+rVxnWtFIiwFR9TjvGrEMezcIYJwRVO/DAIJqGUcHnvdfx3B3/Qp2PQKBgQCk
+y7UR37kEog8BotHRXFdEIgigHtzYa05QWYhJjN8E3yaVUfW7g03lzTvR9DNJsjeI
+aiSS9NBxeV/Fb9yOh8TOjwKl3zxXvy3xLvWh9KxTev0tCeTmnBALWP6puIadTE4S
+Snjoq7R7e/MUToeOjMdX20oVuMvWmuPm1u4K8o0OSQKBgQCec+QLllYXk22A8/e/
+f5HhSYr161lEFFmuzKhuuy+esyCQU/KZmxQH0UqnsL3Ww4ofq42lteqyUJnriHsx
+QP5FTIMKH8W+Xels1i6jCC+MVXAXraAF27dOlmKxWMN7mnElZ/7lQKmBq64wil35
+sfcJA4FDxVM2Amv4KRo/w1C/zQ==
+-----END PRIVATE KEY-----
diff --git a/presto-flight-shim/src/test/resources/split_additional_predicate.json b/presto-flight-shim/src/test/resources/split_additional_predicate.json
new file mode 100644
index 0000000000000..d4dac67bb731c
--- /dev/null
+++ b/presto-flight-shim/src/test/resources/split_additional_predicate.json
@@ -0,0 +1,91 @@
+{
+"connectorId" : "postgresql",
+"schemaName" : "tpch",
+"tableName" : "orders",
+"tupleDomain" : {
+ "columnDomains" : [ {
+ "column" : {
+ "@type" : "postgresql",
+ "connectorId" : "postgresql",
+ "columnName" : "orderkey",
+ "jdbcTypeHandle" : {
+ "jdbcType" : -5,
+ "jdbcTypeName" : "int8",
+ "columnSize" : 19,
+ "decimalDigits" : 0
+ },
+ "columnType" : "bigint",
+ "nullable" : true
+ },
+ "domain" : {
+ "values" : {
+ "@type" : "sortable",
+ "type" : "bigint",
+ "ranges" : [ {
+ "low" : {
+ "type" : "bigint",
+ "valueBlock" : "CgAAAExPTkdfQVJSQVkBAAAAAAEAAAAAAAAA",
+ "bound" : "EXACTLY"
+ },
+ "high" : {
+ "type" : "bigint",
+ "valueBlock" : "CgAAAExPTkdfQVJSQVkBAAAAAAEAAAAAAAAA",
+ "bound" : "EXACTLY"
+ }
+ }, {
+ "low" : {
+ "type" : "bigint",
+ "valueBlock" : "CgAAAExPTkdfQVJSQVkBAAAAAAIAAAAAAAAA",
+ "bound" : "EXACTLY"
+ },
+ "high" : {
+ "type" : "bigint",
+ "valueBlock" : "CgAAAExPTkdfQVJSQVkBAAAAAAIAAAAAAAAA",
+ "bound" : "EXACTLY"
+ }
+ }, {
+ "low" : {
+ "type" : "bigint",
+ "valueBlock" : "CgAAAExPTkdfQVJSQVkBAAAAAAMAAAAAAAAA",
+ "bound" : "EXACTLY"
+ },
+ "high" : {
+ "type" : "bigint",
+ "valueBlock" : "CgAAAExPTkdfQVJSQVkBAAAAAAMAAAAAAAAA",
+ "bound" : "EXACTLY"
+ }
+ } ]
+ },
+ "nullAllowed" : false
+ }
+ } ]
+},
+"additionalProperty" : {
+ "translatedString" : "((orderkey) IN ((?) , (?) , (?)))",
+ "boundConstantValues" : [ {
+ "@type" : "constant",
+ "valueBlock" : "CgAAAExPTkdfQVJSQVkBAAAAAAEAAAAAAAAA",
+ "type" : "bigint",
+ "sourceLocation" : {
+ "line" : 1,
+ "column" : 22
+ }
+ }, {
+ "@type" : "constant",
+ "valueBlock" : "CgAAAExPTkdfQVJSQVkBAAAAAAIAAAAAAAAA",
+ "type" : "bigint",
+ "sourceLocation" : {
+ "line" : 1,
+ "column" : 22
+ }
+ }, {
+ "@type" : "constant",
+ "valueBlock" : "CgAAAExPTkdfQVJSQVkBAAAAAAMAAAAAAAAA",
+ "type" : "bigint",
+ "sourceLocation" : {
+ "line" : 1,
+ "column" : 22
+ }
+ } ]
+}
+}
diff --git a/presto-flight-shim/src/test/resources/split_tuple_domain.json b/presto-flight-shim/src/test/resources/split_tuple_domain.json
new file mode 100644
index 0000000000000..b7fc0082fa403
--- /dev/null
+++ b/presto-flight-shim/src/test/resources/split_tuple_domain.json
@@ -0,0 +1,40 @@
+{
+ "connectorId" : "postgresql",
+ "schemaName" : "tpch",
+ "tableName" : "lineitem",
+ "tupleDomain" : {
+ "columnDomains" : [ {
+ "column" : {
+ "connectorId" : "postgresql",
+ "columnName" : "orderkey",
+ "jdbcTypeHandle" : {
+ "jdbcType" : -5,
+ "jdbcTypeName" : "bigint",
+ "columnSize" : 8,
+ "decimalDigits" : 0
+ },
+ "columnType" : "bigint",
+ "nullable" : false
+ },
+ "domain" : {
+ "values" : {
+ "@type" : "sortable",
+ "type" : "bigint",
+ "ranges" : [ {
+ "low" : {
+ "type" : "bigint",
+ "valueBlock" : "CgAAAExPTkdfQVJSQVkBAAAAAAMAAAAAAAAA",
+ "bound" : "EXACTLY"
+ },
+ "high" : {
+ "type" : "bigint",
+ "valueBlock" : "CgAAAExPTkdfQVJSQVkBAAAAAAMAAAAAAAAA",
+ "bound" : "EXACTLY"
+ }
+ } ]
+ },
+ "nullAllowed" : false
+ }
+ } ]
+ }
+}
diff --git a/presto-tpch/src/main/java/com/facebook/presto/tpch/TpchColumnHandle.java b/presto-tpch/src/main/java/com/facebook/presto/tpch/TpchColumnHandle.java
index 979df181ddb5e..015e31a3a9293 100644
--- a/presto-tpch/src/main/java/com/facebook/presto/tpch/TpchColumnHandle.java
+++ b/presto-tpch/src/main/java/com/facebook/presto/tpch/TpchColumnHandle.java
@@ -16,6 +16,7 @@
import com.facebook.presto.common.Subfield;
import com.facebook.presto.common.type.Type;
import com.facebook.presto.spi.ColumnHandle;
+import com.facebook.presto.spi.ColumnMetadata;
import com.fasterxml.jackson.annotation.JsonCreator;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.google.common.collect.ImmutableList;
@@ -62,6 +63,14 @@ public Type getType()
return type;
}
+ public ColumnMetadata getColumnMetadata()
+ {
+ return ColumnMetadata.builder()
+ .setName(columnName)
+ .setType(type)
+ .build();
+ }
+
@Override
public String toString()
{