diff --git a/.github/workflows/arrow-flight-tests.yml b/.github/workflows/arrow-flight-tests.yml index c8dec648623d5..9d03b35ce8ba0 100644 --- a/.github/workflows/arrow-flight-tests.yml +++ b/.github/workflows/arrow-flight-tests.yml @@ -39,7 +39,9 @@ jobs: matrix: java: ['17'] modules: - - :presto-base-arrow-flight # Only run tests for the `presto-base-arrow-flight` module + - :presto-common-arrow + - :presto-base-arrow-flight + - :presto-flight-shim timeout-minutes: 80 concurrency: diff --git a/pom.xml b/pom.xml index ea19425f3a7ec..9bba335e86a7d 100644 --- a/pom.xml +++ b/pom.xml @@ -241,6 +241,7 @@ presto-sql-helpers/presto-sql-invoked-functions-plugin presto-sql-helpers/presto-native-sql-invoked-functions-plugin presto-lance + presto-flight-shim diff --git a/presto-common-arrow/pom.xml b/presto-common-arrow/pom.xml index c2507c9f136d4..c7cec4d0aa299 100644 --- a/presto-common-arrow/pom.xml +++ b/presto-common-arrow/pom.xml @@ -17,6 +17,17 @@ + + org.apache.arrow + arrow-memory-core + + + org.slf4j + slf4j-api + + + + org.apache.arrow arrow-vector @@ -56,10 +67,62 @@ jakarta.inject jakarta.inject-api + + + + com.facebook.presto + presto-main-base + test + + + + com.facebook.airlift + log + test + + + + joda-time + joda-time + test + + + + org.jdbi + jdbi3-core + test + + + + org.testng + testng + test + + + + org.apache.arrow + arrow-memory-netty + test + + + org.slf4j + slf4j-api + + + + + org.apache.maven.plugins + maven-dependency-plugin + + + org.apache.arrow:arrow-memory-core + + + org.basepom.maven duplicate-finder-maven-plugin @@ -84,4 +147,19 @@ + + + + java17 + + 17 + + + + -Xss10M + --add-opens=java.base/java.nio=org.apache.arrow.memory.core,ALL-UNNAMED + + + + diff --git a/presto-common-arrow/src/main/java/com/facebook/plugin/arrow/ArrowBatchSource.java b/presto-common-arrow/src/main/java/com/facebook/plugin/arrow/ArrowBatchSource.java new file mode 100644 index 0000000000000..8aec16ad66243 --- /dev/null +++ b/presto-common-arrow/src/main/java/com/facebook/plugin/arrow/ArrowBatchSource.java @@ -0,0 +1,118 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.plugin.arrow; + +import com.facebook.presto.common.Page; +import com.facebook.presto.common.block.Block; +import com.facebook.presto.spi.ColumnMetadata; +import com.facebook.presto.spi.ConnectorPageSource; +import com.google.common.collect.ImmutableList; +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.vector.AllocationHelper; +import org.apache.arrow.vector.ValueVector; +import org.apache.arrow.vector.VectorSchemaRoot; + +import java.io.Closeable; +import java.io.IOException; +import java.util.List; + +import static com.facebook.plugin.arrow.BlockArrowWriter.createArrowWriters; +import static java.util.Objects.requireNonNull; + +public class ArrowBatchSource + implements Closeable +{ + private final VectorSchemaRoot root; + private final List writers; + private final int maxRowsPerBatch; + private final ConnectorPageSource pageSource; + + private Page currentPage; + private int currentPosition; + + public ArrowBatchSource(BufferAllocator allocator, List columns, ConnectorPageSource pageSource, int maxRowsPerBatch) + { + this.pageSource = requireNonNull(pageSource, "pageSource is null"); + this.maxRowsPerBatch = maxRowsPerBatch; + ImmutableList.Builder writerBuilder = ImmutableList.builder(); + this.root = createArrowWriters(allocator, columns, writerBuilder); + this.writers = writerBuilder.build(); + } + + public VectorSchemaRoot getVectorSchemaRoot() + { + return root; + } + + /** + * Loads the next record batch from the source. + * Returns false if there are no more batches from the source. + */ + public boolean nextBatch() + { + // Release previous buffers + root.clear(); + + // Reserve capacity for next batch + allocateVectorCapacity(root, maxRowsPerBatch); + + int batchRowIndex = 0; + while (batchRowIndex < maxRowsPerBatch) { + if (currentPage == null || currentPosition >= currentPage.getPositionCount()) { + if (pageSource.isFinished()) { + break; + } + + currentPage = pageSource.getNextPage(); + currentPosition = 0; + + if (currentPage == null || currentPage.getPositionCount() == 0) { + continue; + } + } + + for (int column = 0; column < writers.size(); column++) { + Block block = currentPage.getBlock(column); + BlockArrowWriter.ArrowVectorWriter writer = writers.get(column); + + if (block.isNull(currentPosition)) { + writer.writeNull(batchRowIndex); + } + else { + writer.writeBlock(batchRowIndex, block, currentPosition); + } + } + currentPosition++; + batchRowIndex++; + } + root.setRowCount(batchRowIndex); + return batchRowIndex > 0; + } + + @Override + public void close() + throws IOException + { + root.close(); + pageSource.close(); + } + + private static void allocateVectorCapacity(VectorSchemaRoot root, int capacity) + { + for (ValueVector vector : root.getFieldVectors()) { + vector.setInitialCapacity(capacity); + AllocationHelper.allocateNew(vector, capacity); + } + } +} diff --git a/presto-common-arrow/src/main/java/com/facebook/plugin/arrow/BlockArrowWriter.java b/presto-common-arrow/src/main/java/com/facebook/plugin/arrow/BlockArrowWriter.java new file mode 100644 index 0000000000000..ae49b4a1aeecc --- /dev/null +++ b/presto-common-arrow/src/main/java/com/facebook/plugin/arrow/BlockArrowWriter.java @@ -0,0 +1,893 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.plugin.arrow; + +import com.facebook.presto.common.block.Block; +import com.facebook.presto.common.block.IntArrayBlock; +import com.facebook.presto.common.type.ArrayType; +import com.facebook.presto.common.type.BigintType; +import com.facebook.presto.common.type.BooleanType; +import com.facebook.presto.common.type.CharType; +import com.facebook.presto.common.type.DateType; +import com.facebook.presto.common.type.DecimalType; +import com.facebook.presto.common.type.DoubleType; +import com.facebook.presto.common.type.IntegerType; +import com.facebook.presto.common.type.MapType; +import com.facebook.presto.common.type.RealType; +import com.facebook.presto.common.type.RowType; +import com.facebook.presto.common.type.SmallintType; +import com.facebook.presto.common.type.TimeType; +import com.facebook.presto.common.type.TimestampType; +import com.facebook.presto.common.type.TimestampWithTimeZoneType; +import com.facebook.presto.common.type.TinyintType; +import com.facebook.presto.common.type.Type; +import com.facebook.presto.common.type.VarbinaryType; +import com.facebook.presto.spi.ColumnMetadata; +import com.google.common.collect.ImmutableList; +import io.airlift.slice.Slice; +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.vector.BaseVariableWidthVector; +import org.apache.arrow.vector.BigIntVector; +import org.apache.arrow.vector.BitVector; +import org.apache.arrow.vector.DateDayVector; +import org.apache.arrow.vector.DecimalVector; +import org.apache.arrow.vector.FieldVector; +import org.apache.arrow.vector.Float4Vector; +import org.apache.arrow.vector.Float8Vector; +import org.apache.arrow.vector.IntVector; +import org.apache.arrow.vector.SmallIntVector; +import org.apache.arrow.vector.TimeMilliVector; +import org.apache.arrow.vector.TimeStampVector; +import org.apache.arrow.vector.TinyIntVector; +import org.apache.arrow.vector.VectorSchemaRoot; +import org.apache.arrow.vector.complex.BaseRepeatedValueVector; +import org.apache.arrow.vector.complex.ListVector; +import org.apache.arrow.vector.complex.MapVector; +import org.apache.arrow.vector.complex.StructVector; +import org.apache.arrow.vector.types.Types; +import org.apache.arrow.vector.types.pojo.ArrowType; +import org.apache.arrow.vector.types.pojo.Field; +import org.apache.arrow.vector.types.pojo.FieldType; +import org.apache.arrow.vector.types.pojo.Schema; + +import java.math.BigDecimal; +import java.math.BigInteger; +import java.math.MathContext; +import java.util.List; +import java.util.Map; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.stream.Collectors; + +import static com.facebook.plugin.arrow.ArrowErrorCode.ARROW_FLIGHT_TYPE_ERROR; +import static com.facebook.presto.common.type.Decimals.decodeUnscaledValue; +import static com.facebook.presto.common.type.StandardTypes.ARRAY; +import static com.facebook.presto.common.type.StandardTypes.MAP; +import static com.facebook.presto.common.type.StandardTypes.ROW; +import static com.facebook.presto.common.type.Varchars.isVarcharType; +import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.collect.ImmutableList.toImmutableList; +import static com.google.common.collect.ImmutableMap.toImmutableMap; +import static java.lang.Float.intBitsToFloat; +import static java.lang.Math.toIntExact; +import static java.lang.String.format; + +public class BlockArrowWriter +{ + private BlockArrowWriter() + { + } + + public static Field prestoToArrowField(ColumnMetadata column) + { + Field field; + Map metadata = column.getProperties().entrySet().stream().collect(toImmutableMap(Map.Entry::getKey, Object::toString)); + if (column.getType().getTypeSignature().getBase().equals(ARRAY)) { + ArrayType arrayType = (ArrayType) column.getType(); + Field childField = prestoToArrowField(ColumnMetadata.builder().setName(BaseRepeatedValueVector.DATA_VECTOR_NAME).setType(arrayType.getElementType()).build()); + field = new Field(column.getName(), new FieldType(column.isNullable(), ArrowType.List.INSTANCE, null, metadata), ImmutableList.of(childField)); + } + else if (column.getType().getTypeSignature().getBase().equals(MAP)) { + MapType mapType = (MapType) column.getType(); + // NOTE: Arrow key type must be non-nullable + Field keyField = prestoToArrowField(ColumnMetadata.builder().setName(MapVector.KEY_NAME).setType(mapType.getKeyType()).setNullable(false).build()); + Field valueField = prestoToArrowField(ColumnMetadata.builder().setName(MapVector.VALUE_NAME).setType(mapType.getValueType()).build()); + Field entriesField = new Field(MapVector.DATA_VECTOR_NAME, FieldType.notNullable(ArrowType.Struct.INSTANCE), ImmutableList.of(keyField, valueField)); + field = new Field(column.getName(), new FieldType(column.isNullable(), new ArrowType.Map(false), null, metadata), ImmutableList.of(entriesField)); + } + else if (column.getType().getTypeSignature().getBase().equals(ROW)) { + RowType rowType = (RowType) column.getType(); + List rowFields = rowType.getFields(); + + AtomicInteger childCount = new AtomicInteger(); + List childFields = rowFields.stream().map(f -> prestoToArrowField( + ColumnMetadata.builder().setName(f.getName().orElse(format("$child%s$", childCount.incrementAndGet()))).setType(f.getType()).build())) + .collect(toImmutableList()); + + field = new Field(column.getName(), new FieldType(column.isNullable(), ArrowType.Struct.INSTANCE, null, metadata), childFields); + } + else { + ArrowType arrowType = prestoToArrowType(column.getType()); + field = new Field(column.getName(), new FieldType(column.isNullable(), arrowType, null, metadata), ImmutableList.of()); + } + + return field; + } + + public static ArrowType prestoToArrowType(Type type) + { + if (type.equals(BooleanType.BOOLEAN)) { + return ArrowType.Bool.INSTANCE; + } + else if (type.equals(TinyintType.TINYINT)) { + return Types.MinorType.TINYINT.getType(); + } + else if (type.equals(SmallintType.SMALLINT)) { + return Types.MinorType.SMALLINT.getType(); + } + else if (type.equals(IntegerType.INTEGER)) { + return Types.MinorType.INT.getType(); + } + else if (type.equals(BigintType.BIGINT)) { + return Types.MinorType.BIGINT.getType(); + } + else if (type.equals(RealType.REAL)) { + return Types.MinorType.FLOAT4.getType(); + } + else if (type.equals(DoubleType.DOUBLE)) { + return Types.MinorType.FLOAT8.getType(); + } + else if (type instanceof DecimalType) { + DecimalType decimalType = (DecimalType) type; + return new ArrowType.Decimal(decimalType.getPrecision(), decimalType.getScale(), 128); + } + else if (type.equals(VarbinaryType.VARBINARY)) { + return ArrowType.Binary.INSTANCE; + } + else if (type instanceof CharType) { + return ArrowType.Utf8.INSTANCE; + } + else if (isVarcharType(type)) { + return ArrowType.Utf8.INSTANCE; + } + else if (type instanceof DateType) { + return Types.MinorType.DATEDAY.getType(); + } + else if (type instanceof TimeType) { + return Types.MinorType.TIMEMILLI.getType(); + } + else if (type instanceof TimestampType) { + return Types.MinorType.TIMESTAMPMILLI.getType(); + } + else if (type instanceof TimestampWithTimeZoneType) { + // Read as plain timestamp, timezone not supplied with type + return Types.MinorType.TIMESTAMPMILLI.getType(); + } + throw new ArrowException(ARROW_FLIGHT_TYPE_ERROR, "Unsupported type: " + type); + } + + public static VectorSchemaRoot createArrowWriters(BufferAllocator allocator, List columns, ImmutableList.Builder writerBuilder) + { + List fields = columns.stream().map(BlockArrowWriter::prestoToArrowField).collect(Collectors.toList()); + Schema schema = new Schema(fields); + VectorSchemaRoot root = VectorSchemaRoot.create(schema, allocator); + + final List vectors = root.getFieldVectors(); + checkArgument(vectors.size() == columns.size(), "Unexpected list of vectors: %s", schema); + for (int i = 0; i < vectors.size(); i++) { + ColumnMetadata columnMetadata = columns.get(i); + writerBuilder.add(createArrowWriter(vectors.get(i), columnMetadata.getType())); + } + + return root; + } + + private static ArrowVectorWriter createArrowWriter(FieldVector vector, Type type) + { + Class javaType = type.getJavaType(); + + switch (vector.getMinorType()) { + case BIT: + checkArgument(javaType == boolean.class, "Unexpected type for BitVector: %s", type); + return new ArrowBitWriter((BitVector) vector, new BlockPrimitiveGetter(type)); + case TINYINT: + checkArgument(javaType == long.class, "Unexpected type for TinyIntVector: %s", type); + return new ArrowTinyIntWriter((TinyIntVector) vector, new BlockPrimitiveGetter(type)); + case SMALLINT: + checkArgument(javaType == long.class, "Unexpected type for SmallIntVector: %s", type); + return new ArrowSmallIntWriter((SmallIntVector) vector, new BlockPrimitiveGetter(type)); + case INT: + checkArgument(javaType == long.class, "Unexpected type for IntVector: %s", type); + return new ArrowIntWriter((IntVector) vector, new BlockPrimitiveGetter(type)); + case BIGINT: + checkArgument(javaType == long.class, "Unexpected type for BigIntVector: %s", type); + return new ArrowLongWriter((BigIntVector) vector, new BlockPrimitiveGetter(type)); + case FLOAT4: + checkArgument(javaType == long.class, "Unexpected type for Float4Vector: %s", type); + return new ArrowRealWriter((Float4Vector) vector, new BlockPrimitiveGetter(type)); + case FLOAT8: + checkArgument(javaType == double.class, "Unexpected type for Float8Vector: %s", type); + return new ArrowDoubleWriter((Float8Vector) vector, new BlockPrimitiveGetter(type)); + case DECIMAL: + checkArgument(type instanceof DecimalType, "Expected DecimalType but got %", type); + return new ArrowDecimalWriter((DecimalVector) vector, createDecimalBlockGetter((DecimalType) type, javaType)); + case VARBINARY: + case VARCHAR: + checkArgument(javaType == Slice.class, "Unexpected type for BaseVariableWidthVector: %s", type); + return new ArrowVariableWidthWriter((BaseVariableWidthVector) vector, new BlockSliceGetter(type)); + case DATEDAY: + checkArgument(javaType == long.class, "Unexpected type for DateDayVector: %s", type); + return new ArrowDateWriter((DateDayVector) vector, new BlockPrimitiveGetter(type)); + case TIMEMILLI: + checkArgument(javaType == long.class, "Unexpected type for TimeMilliVector: %s", type); + return new ArrowTimeWriter((TimeMilliVector) vector, new BlockPrimitiveGetter(type)); + case TIMESTAMPMILLI: + checkArgument(javaType == long.class, "Unexpected type for TimeStampVector: %s", type); + return new ArrowTimeStampWriter((TimeStampVector) vector, new BlockPrimitiveGetter(type)); + case LIST: + checkArgument(type instanceof ArrayType, "Unexpected type for ListVector: %s", type); + return new ArrowListWriter((ListVector) vector, new BlockArrayGetter((ArrayType) type)); + case MAP: + checkArgument(type instanceof MapType, "Unexpected type for MapVector: %s", type); + return new ArrowMapWriter((MapVector) vector, new BlockMapGetter((MapType) type)); + case STRUCT: + checkArgument(type instanceof RowType, "Unexpected type for StructVector: %s", type); + return new ArrowStructWriter((StructVector) vector, new BlockRowGetter((RowType) type)); + default: + throw new ArrowException(ARROW_FLIGHT_TYPE_ERROR, "Unsupported Arrow type: " + vector.getMinorType().name()); + } + } + + private static BlockValueGetter createDecimalBlockGetter(DecimalType decimalType, Class javaType) + { + if (javaType == long.class) { + return new BlockShortDecimalGetter(decimalType); + } + else if (javaType == Slice.class) { + return new BlockLongDecimalGetter(decimalType); + } + throw new ArrowException(ARROW_FLIGHT_TYPE_ERROR, "Unexpected type for DecimalVector: " + decimalType); + } + + private abstract static class BlockValueGetter + { + public boolean getBoolean(Block block, int position) + { + throw new UnsupportedOperationException(getClass().getName()); + } + + public int getInt(Block block, int position) + { + throw new UnsupportedOperationException(getClass().getName()); + } + + public long getLong(Block block, int position) + { + throw new UnsupportedOperationException(getClass().getName()); + } + + public float getFloat(Block block, int position) + { + throw new UnsupportedOperationException(getClass().getName()); + } + + public double getDouble(Block block, int position) + { + throw new UnsupportedOperationException(getClass().getName()); + } + + public BigDecimal getDecimal(Block block, int position) + { + throw new UnsupportedOperationException(getClass().getName()); + } + + public byte[] getBytes(Block block, int position) + { + throw new UnsupportedOperationException(getClass().getName()); + } + + public Block getChildBlock(Block block, int position) + { + throw new UnsupportedOperationException(getClass().getName()); + } + } + + private static class BlockPrimitiveGetter + extends BlockValueGetter + { + protected Type type; + + public BlockPrimitiveGetter(Type type) + { + this.type = type; + } + + @Override + public boolean getBoolean(Block block, int position) + { + return type.getBoolean(block, position); + } + + @Override + public int getInt(Block block, int position) + { + return (int) getLong(block, position); + } + + @Override + public long getLong(Block block, int position) + { + if (block instanceof IntArrayBlock) { + return block.toLong(position); + } + + return type.getLong(block, position); + } + + @Override + public float getFloat(Block block, int position) + { + long value = getLong(block, position); + return intBitsToFloat(toIntExact(value)); + } + + @Override + public double getDouble(Block block, int position) + { + return type.getDouble(block, position); + } + } + + private static class BlockShortDecimalGetter + extends BlockValueGetter + { + protected DecimalType type; + + public BlockShortDecimalGetter(DecimalType type) + { + this.type = type; + } + + @Override + public BigDecimal getDecimal(Block block, int position) + { + long value = type.getLong(block, position); + BigInteger unscaledValue = BigInteger.valueOf(value); + return new BigDecimal(unscaledValue, type.getScale(), new MathContext(type.getPrecision())); + } + } + + private static class BlockLongDecimalGetter + extends BlockValueGetter + { + protected DecimalType type; + + public BlockLongDecimalGetter(DecimalType type) + { + this.type = type; + } + + @Override + public BigDecimal getDecimal(Block block, int position) + { + Slice value = type.getSlice(block, position); + BigInteger unscaledValue = decodeUnscaledValue(value); + return new BigDecimal(unscaledValue, type.getScale(), new MathContext(type.getPrecision())); + } + } + + private static class BlockSliceGetter + extends BlockValueGetter + { + protected Type type; + + public BlockSliceGetter(Type type) + { + this.type = type; + } + + @Override + public byte[] getBytes(Block block, int position) + { + Slice value = type.getSlice(block, position); + return value.getBytes(0, value.length()); + } + } + + private static class BlockArrayGetter + extends BlockValueGetter + { + protected ArrayType type; + + public BlockArrayGetter(ArrayType type) + { + this.type = type; + } + + public ArrayType getType() + { + return type; + } + + @Override + public Block getChildBlock(Block block, int position) + { + return type.getObject(block, position); + } + } + + private static class BlockMapGetter + extends BlockValueGetter + { + protected MapType type; + + public BlockMapGetter(MapType type) + { + this.type = type; + } + + public MapType getType() + { + return type; + } + + @Override + public Block getChildBlock(Block block, int position) + { + return type.getObject(block, position); + } + } + + private static class BlockRowGetter + extends BlockValueGetter + { + protected RowType type; + + public BlockRowGetter(RowType type) + { + this.type = type; + } + + public RowType getType() + { + return type; + } + + @Override + public Block getChildBlock(Block block, int position) + { + return type.getObject(block, position); + } + } + + /** + * Writers for specific Arrow vector types to set values from a Block. + * Fixed-width vectors are pre-allocated, but setSafe should be used + * for all other vectors. + */ + public abstract static class ArrowVectorWriter + { + public abstract void writeNull(int index); + + public abstract void writeBlock(int index, Block block, int position); + } + + private static class ArrowBitWriter + extends ArrowVectorWriter + { + private final BitVector vector; + protected final BlockValueGetter blockGetter; + + public ArrowBitWriter(BitVector vector, BlockValueGetter blockGetter) + { + this.vector = vector; + this.blockGetter = blockGetter; + } + + @Override + public void writeNull(int index) + { + vector.setNull(index); + } + + @Override + public void writeBlock(int index, Block block, int position) + { + vector.set(index, blockGetter.getBoolean(block, position) ? 1 : 0); + } + } + + private static class ArrowTinyIntWriter + extends ArrowVectorWriter + { + private final TinyIntVector vector; + private final BlockValueGetter blockGetter; + + public ArrowTinyIntWriter(TinyIntVector vector, BlockValueGetter blockGetter) + { + this.vector = vector; + this.blockGetter = blockGetter; + } + + @Override + public void writeNull(int index) + { + vector.setNull(index); + } + + @Override + public void writeBlock(int index, Block block, int position) + { + vector.set(index, blockGetter.getInt(block, position)); + } + } + + private static class ArrowSmallIntWriter + extends ArrowVectorWriter + { + private final SmallIntVector vector; + private final BlockValueGetter blockGetter; + + public ArrowSmallIntWriter(SmallIntVector vector, BlockValueGetter blockGetter) + { + this.vector = vector; + this.blockGetter = blockGetter; + } + + @Override + public void writeNull(int index) + { + vector.setNull(index); + } + + @Override + public void writeBlock(int index, Block block, int position) + { + vector.set(index, blockGetter.getInt(block, position)); + } + } + + private static class ArrowIntWriter + extends ArrowVectorWriter + { + private final IntVector vector; + private final BlockValueGetter blockGetter; + + public ArrowIntWriter(IntVector vector, BlockValueGetter blockGetter) + { + this.vector = vector; + this.blockGetter = blockGetter; + } + + @Override + public void writeNull(int index) + { + vector.setNull(index); + } + + @Override + public void writeBlock(int index, Block block, int position) + { + vector.set(index, blockGetter.getInt(block, position)); + } + } + + private static class ArrowLongWriter + extends ArrowVectorWriter + { + private final BigIntVector vector; + private final BlockValueGetter blockGetter; + + public ArrowLongWriter(BigIntVector vector, BlockValueGetter blockGetter) + { + this.vector = vector; + this.blockGetter = blockGetter; + } + + @Override + public void writeNull(int index) + { + vector.setNull(index); + } + + @Override + public void writeBlock(int index, Block block, int position) + { + vector.set(index, blockGetter.getLong(block, position)); + } + } + + private static class ArrowRealWriter + extends ArrowVectorWriter + { + private final Float4Vector vector; + private final BlockValueGetter blockGetter; + + public ArrowRealWriter(Float4Vector vector, BlockValueGetter blockGetter) + { + this.vector = vector; + this.blockGetter = blockGetter; + } + + @Override + public void writeNull(int index) + { + vector.setNull(index); + } + + @Override + public void writeBlock(int index, Block block, int position) + { + vector.set(index, blockGetter.getFloat(block, position)); + } + } + + private static class ArrowDoubleWriter + extends ArrowVectorWriter + { + private final Float8Vector vector; + private final BlockValueGetter blockGetter; + + public ArrowDoubleWriter(Float8Vector vector, BlockValueGetter blockGetter) + { + this.vector = vector; + this.blockGetter = blockGetter; + } + + @Override + public void writeNull(int index) + { + vector.setNull(index); + } + + @Override + public void writeBlock(int index, Block block, int position) + { + vector.set(index, blockGetter.getDouble(block, position)); + } + } + + private static class ArrowDecimalWriter + extends ArrowVectorWriter + { + private final DecimalVector vector; + private final BlockValueGetter blockGetter; + + public ArrowDecimalWriter(DecimalVector vector, BlockValueGetter blockGetter) + { + this.vector = vector; + this.blockGetter = blockGetter; + } + + @Override + public void writeNull(int index) + { + vector.setNull(index); + } + + @Override + public void writeBlock(int index, Block block, int position) + { + vector.set(index, blockGetter.getDecimal(block, position)); + } + } + + private static class ArrowVariableWidthWriter + extends ArrowVectorWriter + { + private final BaseVariableWidthVector vector; + private final BlockValueGetter blockGetter; + + public ArrowVariableWidthWriter(BaseVariableWidthVector vector, BlockValueGetter blockGetter) + { + this.vector = vector; + this.blockGetter = blockGetter; + } + + @Override + public void writeNull(int index) + { + vector.setNull(index); + } + + @Override + public void writeBlock(int index, Block block, int position) + { + vector.setSafe(index, blockGetter.getBytes(block, position)); + } + } + + private static class ArrowDateWriter + extends ArrowVectorWriter + { + private final DateDayVector vector; + private final BlockValueGetter blockGetter; + + public ArrowDateWriter(DateDayVector vector, BlockValueGetter blockGetter) + { + this.vector = vector; + this.blockGetter = blockGetter; + } + + @Override + public void writeNull(int index) + { + vector.setNull(index); + } + + @Override + public void writeBlock(int index, Block block, int position) + { + // Value is supplied as days since epoch UTC + vector.set(index, blockGetter.getInt(block, position)); + } + } + + private static class ArrowTimeWriter + extends ArrowVectorWriter + { + private final TimeMilliVector vector; + private final BlockValueGetter blockGetter; + + public ArrowTimeWriter(TimeMilliVector vector, BlockValueGetter blockGetter) + { + this.vector = vector; + this.blockGetter = blockGetter; + } + + @Override + public void writeNull(int index) + { + vector.setNull(index); + } + + @Override + public void writeBlock(int index, Block block, int position) + { + vector.set(index, blockGetter.getInt(block, position)); + } + } + + private static class ArrowTimeStampWriter + extends ArrowVectorWriter + { + private final TimeStampVector vector; + private final BlockValueGetter blockGetter; + + public ArrowTimeStampWriter(TimeStampVector vector, BlockValueGetter blockGetter) + { + this.vector = vector; + this.blockGetter = blockGetter; + } + + @Override + public void writeNull(int index) + { + vector.setNull(index); + } + + @Override + public void writeBlock(int index, Block block, int position) + { + vector.set(index, blockGetter.getLong(block, position)); + } + } + + private static class ArrowListWriter + extends ArrowVectorWriter + { + private final ListVector vector; + private final BlockArrayGetter blockArrayGetter; + private final ArrowVectorWriter childWriter; + + public ArrowListWriter(ListVector vector, BlockArrayGetter blockArrayGetter) + { + this.vector = vector; + this.blockArrayGetter = blockArrayGetter; + this.childWriter = createArrowWriter(vector.getDataVector(), blockArrayGetter.getType().getElementType()); + } + + @Override + public void writeNull(int index) + { + vector.setNull(index); + } + + @Override + public void writeBlock(int index, Block block, int position) + { + int dataIndex = vector.startNewValue(index); + Block elementBlock = blockArrayGetter.getChildBlock(block, position); + for (int i = 0; i < elementBlock.getPositionCount(); ++i) { + childWriter.writeBlock(dataIndex + i, elementBlock, i); + } + vector.endValue(index, elementBlock.getPositionCount()); + } + } + + private static class ArrowMapWriter + extends ArrowVectorWriter + { + private final MapVector vector; + private final StructVector structVector; + private final BlockMapGetter blockMapGetter; + private final ArrowVectorWriter keyWriter; + private final ArrowVectorWriter valueWriter; + + public ArrowMapWriter(MapVector vector, BlockMapGetter blockMapGetter) + { + this.vector = vector; + this.structVector = (StructVector) vector.getDataVector(); + this.blockMapGetter = blockMapGetter; + this.keyWriter = createArrowWriter((FieldVector) structVector.getChildByOrdinal(0), blockMapGetter.getType().getKeyType()); + this.valueWriter = createArrowWriter((FieldVector) structVector.getChildByOrdinal(1), blockMapGetter.getType().getValueType()); + } + + @Override + public void writeNull(int index) + { + vector.setNull(index); + } + + @Override + public void writeBlock(int index, Block block, int position) + { + Block singleMapBlock = blockMapGetter.getChildBlock(block, position); + int dataIndex = vector.startNewValue(index); + int numPairs = singleMapBlock.getPositionCount() / 2; + for (int i = 0; i < numPairs; ++i) { + keyWriter.writeBlock(dataIndex + i, singleMapBlock, i * 2); + valueWriter.writeBlock(dataIndex + i, singleMapBlock, (i * 2) + 1); + structVector.setIndexDefined(dataIndex + i); + } + vector.endValue(index, numPairs); + } + } + + private static class ArrowStructWriter + extends ArrowVectorWriter + { + private final StructVector vector; + private final BlockRowGetter blockRowGetter; + private final List childWriters; + + public ArrowStructWriter(StructVector vector, BlockRowGetter blockRowGetter) + { + this.vector = vector; + this.blockRowGetter = blockRowGetter; + ImmutableList.Builder writerBuilder = ImmutableList.builder(); + List fieldVectors = vector.getChildrenFromFields(); + for (int i = 0; i < fieldVectors.size(); i++) { + Type childType = blockRowGetter.getType().getTypeParameters().get(i); + writerBuilder.add(createArrowWriter(fieldVectors.get(i), childType)); + } + this.childWriters = writerBuilder.build(); + } + + @Override + public void writeNull(int index) + { + vector.setNull(index); + } + + @Override + public void writeBlock(int index, Block block, int position) + { + Block singleRowBlock = blockRowGetter.getChildBlock(block, position); + for (int i = 0; i < childWriters.size(); ++i) { + childWriters.get(i).writeBlock(index, singleRowBlock, i); + } + vector.setIndexDefined(index); + } + } +} diff --git a/presto-common-arrow/src/test/java/com/facebook/plugin/arrow/TestArrowBatchSource.java b/presto-common-arrow/src/test/java/com/facebook/plugin/arrow/TestArrowBatchSource.java new file mode 100644 index 0000000000000..a922c9f00c0a9 --- /dev/null +++ b/presto-common-arrow/src/test/java/com/facebook/plugin/arrow/TestArrowBatchSource.java @@ -0,0 +1,452 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.plugin.arrow; + +import com.facebook.presto.common.Page; +import com.facebook.presto.common.block.Block; +import com.facebook.presto.common.type.Type; +import com.facebook.presto.common.type.TypeManager; +import com.facebook.presto.spi.ColumnMetadata; +import com.facebook.presto.spi.ConnectorPageSource; +import com.google.common.collect.ImmutableList; +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.memory.RootAllocator; +import org.apache.arrow.vector.BigIntVector; +import org.apache.arrow.vector.BitVector; +import org.apache.arrow.vector.DateDayVector; +import org.apache.arrow.vector.FieldVector; +import org.apache.arrow.vector.Float4Vector; +import org.apache.arrow.vector.Float8Vector; +import org.apache.arrow.vector.IntVector; +import org.apache.arrow.vector.SmallIntVector; +import org.apache.arrow.vector.TimeMilliVector; +import org.apache.arrow.vector.TimeStampMilliVector; +import org.apache.arrow.vector.TinyIntVector; +import org.apache.arrow.vector.VarCharVector; +import org.apache.arrow.vector.VectorSchemaRoot; +import org.apache.arrow.vector.complex.ListVector; +import org.apache.arrow.vector.complex.MapVector; +import org.apache.arrow.vector.complex.StructVector; +import org.apache.arrow.vector.complex.impl.UnionListWriter; +import org.apache.arrow.vector.complex.impl.UnionMapWriter; +import org.apache.arrow.vector.types.Types; +import org.apache.arrow.vector.types.pojo.ArrowType; +import org.apache.arrow.vector.types.pojo.Field; +import org.apache.arrow.vector.types.pojo.FieldType; +import org.apache.arrow.vector.util.Text; +import org.testng.annotations.AfterClass; +import org.testng.annotations.BeforeClass; +import org.testng.annotations.Test; + +import java.io.IOException; +import java.time.LocalDateTime; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.concurrent.TimeUnit; + +import static com.facebook.presto.testing.TestingEnvironment.FUNCTION_AND_TYPE_MANAGER; +import static com.facebook.presto.util.DateTimeUtils.parseTimestampWithoutTimeZone; +import static org.testng.Assert.assertEquals; +import static org.testng.Assert.assertFalse; +import static org.testng.Assert.assertTrue; + +public class TestArrowBatchSource +{ + private static final int MAX_ROWS_PER_BATCH = 1000; + private BufferAllocator allocator; + private TestArrowBlockBuilder arrowBlockBuilder; + + @BeforeClass + public void setUp() + { + // Initialize the Arrow allocator + allocator = new RootAllocator(Integer.MAX_VALUE); + arrowBlockBuilder = new TestArrowBlockBuilder(FUNCTION_AND_TYPE_MANAGER); + } + + @AfterClass + public void tearDown() + { + allocator.close(); + } + + @Test + public void testPrimitiveTypes() + throws IOException + { + try (BitVector bitVector = new BitVector("bitVector", allocator); + TinyIntVector tinyIntVector = new TinyIntVector("tinyIntVector", allocator); + SmallIntVector smallIntVector = new SmallIntVector("smallIntVector", allocator); + IntVector intVector = new IntVector("intVector", allocator); + BigIntVector longVector = new BigIntVector("bigIntVector", allocator); + Float4Vector floatVector = new Float4Vector("floatVector", allocator); + Float8Vector doubleVector = new Float8Vector("doubleVector", allocator); + VectorSchemaRoot expectedRoot = new VectorSchemaRoot(Arrays.asList(bitVector, tinyIntVector, smallIntVector, intVector, longVector, floatVector, doubleVector))) { + final int numPages = 5; + final int numValues = 10; + intVector.allocateNew(numValues); + + for (int i = 0; i < numValues; i++) { + bitVector.setSafe(i, i % 2); + tinyIntVector.setSafe(i, i); + smallIntVector.setSafe(i, i); + intVector.setSafe(i, i); + longVector.setSafe(i, i); + floatVector.setSafe(i, i * 1.1f); + doubleVector.setSafe(i, i * 1.1); + } + + // Add null values + bitVector.setNull(2); + tinyIntVector.setNull(3); + smallIntVector.setNull(4); + intVector.setNull(5); + longVector.setNull(6); + floatVector.setNull(7); + doubleVector.setNull(8); + + expectedRoot.setRowCount(numValues); + + TestArrowPageSource pageSource = TestArrowPageSource.create(arrowBlockBuilder, expectedRoot, 5); + + // Set for 1 batch per page + int batchCount = 0; + try (ArrowBatchSource arrowBatchSource = new ArrowBatchSource(allocator, pageSource.getColumns(), pageSource, numValues)) { + while (arrowBatchSource.nextBatch()) { + assertTrue(expectedRoot.equals(arrowBatchSource.getVectorSchemaRoot())); + ++batchCount; + } + } + assertEquals(batchCount, numPages); + } + } + + @Test + public void testDateTimeTypes() + throws IOException + { + try (IntVector intVector = new IntVector("id", allocator); + DateDayVector dateVector = new DateDayVector("date", allocator); + TimeMilliVector timeVector = new TimeMilliVector("time", allocator); + TimeStampMilliVector timestampVector = new TimeStampMilliVector("timestamp", allocator); + VectorSchemaRoot expectedRoot = new VectorSchemaRoot(Arrays.asList(intVector, dateVector, timeVector, timestampVector))) { + List values = ImmutableList.of( + "1970-01-01T00:00:00", + "2024-01-01T01:01:01", + "2024-01-02T12:00:00", + "2112-12-31T23:58:00", + "1968-07-05T08:15:12.345"); + + for (int i = 0; i < values.size(); i++) { + intVector.setSafe(i, i); + LocalDateTime dateTime = LocalDateTime.parse(values.get(i)); + // First vector value is explicitly set to 0 to ensure no issues with parsing + dateVector.setSafe(i, i == 0 ? 0 : (int) dateTime.toLocalDate().toEpochDay()); + timeVector.setSafe(i, i == 0 ? 0 : (int) TimeUnit.NANOSECONDS.toMillis(dateTime.toLocalTime().toNanoOfDay())); + timestampVector.setSafe(i, i == 0 ? 0 : parseTimestampWithoutTimeZone(values.get(i).replace("T", " "))); + } + + // Add null values at last row + dateVector.setNull(values.size()); + timeVector.setNull(values.size()); + timestampVector.setNull(values.size()); + + expectedRoot.setRowCount(values.size() + 1); + + TestArrowPageSource pageSource = TestArrowPageSource.create(arrowBlockBuilder, expectedRoot, 1); + + try (ArrowBatchSource arrowBatchSource = new ArrowBatchSource(allocator, pageSource.getColumns(), pageSource, MAX_ROWS_PER_BATCH)) { + assertTrue(arrowBatchSource.nextBatch()); + assertTrue(expectedRoot.equals(arrowBatchSource.getVectorSchemaRoot())); + assertFalse(arrowBatchSource.nextBatch()); + } + } + } + + @Test + public void testVarCharType() + throws IOException + { + final int numValues = 100; + final int maxRowsPerBatch = 33; + + try (IntVector intVector = new IntVector("id", allocator); + VarCharVector stringVector = new VarCharVector("varchar", allocator); + VectorSchemaRoot expectedRoot = new VectorSchemaRoot(Arrays.asList(intVector, stringVector))) { + final String stringData = "abcdefghijklmnopqrstuvwxyz"; + for (int i = 0; i < numValues; i++) { + intVector.setSafe(i, i); + String value = stringData.substring(0, i % stringData.length()); + stringVector.setSafe(i, new Text(value)); + } + + // Add null values + stringVector.setNull(2); + + expectedRoot.setRowCount(numValues); + + TestArrowPageSource pageSource = TestArrowPageSource.create(arrowBlockBuilder, expectedRoot, 1); + + int actualRowCount = 0; + try (ArrowBatchSource arrowBatchSource = new ArrowBatchSource(allocator, pageSource.getColumns(), pageSource, maxRowsPerBatch)) { + while (arrowBatchSource.nextBatch()) { + try (VectorSchemaRoot expectedSlice = expectedRoot.slice(actualRowCount, Math.min(numValues - actualRowCount, maxRowsPerBatch))) { + assertTrue(expectedSlice.equals(arrowBatchSource.getVectorSchemaRoot())); + } + actualRowCount += arrowBatchSource.getVectorSchemaRoot().getRowCount(); + } + assertEquals(actualRowCount, numValues); + } + } + } + + @Test + public void testArrayType() + throws IOException + { + try (IntVector intVector = new IntVector("id", allocator); + ListVector listVectorInt = ListVector.empty("array-int", allocator); + ListVector listVectorVarchar = ListVector.empty("array-varchar", allocator)) { + // Add the element vectors + listVectorInt.addOrGetVector(FieldType.nullable(Types.MinorType.INT.getType())); + listVectorVarchar.addOrGetVector(FieldType.nullable(Types.MinorType.VARCHAR.getType())); + listVectorInt.allocateNew(); + listVectorVarchar.allocateNew(); + + final int numValues = 10; + final String stringData = "abcdefghijklmnopqrstuvwxyz"; + final UnionListWriter writerInt = listVectorInt.getWriter(); + final UnionListWriter writerVarchar = listVectorVarchar.getWriter(); + for (int i = 0; i < numValues; i++) { + intVector.setSafe(i, i); + writerInt.setPosition(i); + // Need to set nulls during write for lists + if (i == 5) { + writerInt.writeNull(); + writerVarchar.writeNull(); + } + else { + writerInt.startList(); + writerVarchar.startList(); + for (int j = 0; j < i + i; j++) { + writerInt.integer().writeInt(i * j); + String stringValue = stringData.substring(0, i % stringData.length()); + writerVarchar.writeVarChar(new Text(stringValue)); + } + } + writerInt.endList(); + writerVarchar.endList(); + } + + try (VectorSchemaRoot expectedRoot = new VectorSchemaRoot(Arrays.asList(intVector, listVectorInt, listVectorVarchar))) { + expectedRoot.setRowCount(numValues); + + TestArrowPageSource pageSource = TestArrowPageSource.create(arrowBlockBuilder, expectedRoot, 1); + + try (ArrowBatchSource arrowBatchSource = new ArrowBatchSource(allocator, pageSource.getColumns(), pageSource, MAX_ROWS_PER_BATCH)) { + assertTrue(arrowBatchSource.nextBatch()); + assertTrue(expectedRoot.equals(arrowBatchSource.getVectorSchemaRoot())); + assertFalse(arrowBatchSource.nextBatch()); + } + } + } + } + + @Test + void testMapType() + throws IOException + { + try (IntVector intVector = new IntVector("id", allocator); + MapVector mapLongVector = MapVector.empty("map-long-long", allocator, false); + MapVector mapVarcharVector = MapVector.empty("map-long-varchar", allocator, false)) { + UnionMapWriter mapLongWriter = mapLongVector.getWriter(); + UnionMapWriter mapVarcharWriter = mapVarcharVector.getWriter(); + mapLongWriter.allocate(); + mapVarcharWriter.allocate(); + + final int numValues = 10; + final String stringData = "abcdefghijklmnopqrstuvwxyz"; + for (int i = 0; i < numValues; i++) { + intVector.setSafe(i, i); + mapLongWriter.setPosition(i); + mapLongWriter.startMap(); + + mapVarcharWriter.setPosition(i); + mapVarcharWriter.startMap(); + + for (int j = 0; j < i; j++) { + mapLongWriter.startEntry(); + mapLongWriter.key().bigInt().writeBigInt(j); + mapLongWriter.value().bigInt().writeBigInt(i * j); + mapLongWriter.endEntry(); + mapVarcharWriter.startEntry(); + mapVarcharWriter.key().bigInt().writeBigInt(j * j); + String stringValue = stringData.substring(0, i % stringData.length()); + mapVarcharWriter.value().varChar().writeVarChar(new Text(stringValue)); + mapVarcharWriter.endEntry(); + } + mapLongWriter.endMap(); + mapVarcharWriter.endMap(); + } + + try (VectorSchemaRoot expectedRoot = new VectorSchemaRoot(Arrays.asList(intVector, mapLongVector, mapVarcharVector))) { + expectedRoot.setRowCount(numValues); + + TestArrowPageSource pageSource = TestArrowPageSource.create(arrowBlockBuilder, expectedRoot, 1); + + try (ArrowBatchSource arrowBatchSource = new ArrowBatchSource(allocator, pageSource.getColumns(), pageSource, MAX_ROWS_PER_BATCH)) { + assertTrue(arrowBatchSource.nextBatch()); + assertTrue(expectedRoot.equals(arrowBatchSource.getVectorSchemaRoot())); + assertFalse(arrowBatchSource.nextBatch()); + } + } + } + } + + @Test + void testRowType() + throws IOException + { + try (IntVector intVector = new IntVector("id", allocator); + StructVector structVector = StructVector.empty("struct", allocator)) { + final BigIntVector childLongVector + = structVector.addOrGet("long", FieldType.nullable(new ArrowType.Int(64, true)), BigIntVector.class); + final VarCharVector childVarcharVector + = structVector.addOrGet("varchar", FieldType.nullable(ArrowType.Utf8.INSTANCE), VarCharVector.class); + childLongVector.allocateNew(); + childVarcharVector.allocateNew(); + + final int numValues = 10; + final String stringData = "abcdefghijklmnopqrstuvwxyz"; + for (int i = 0; i < numValues; i++) { + intVector.setSafe(i, i); + childLongVector.setSafe(i, i * i); + String stringValue = stringData.substring(0, i % stringData.length()); + childVarcharVector.setSafe(i, new Text(stringValue)); + structVector.setIndexDefined(i); + } + + try (VectorSchemaRoot expectedRoot = new VectorSchemaRoot(Arrays.asList(intVector, structVector))) { + expectedRoot.setRowCount(numValues); + + TestArrowPageSource pageSource = TestArrowPageSource.create(arrowBlockBuilder, expectedRoot, 1); + + try (ArrowBatchSource arrowBatchSource = new ArrowBatchSource(allocator, pageSource.getColumns(), pageSource, MAX_ROWS_PER_BATCH)) { + assertTrue(arrowBatchSource.nextBatch()); + assertTrue(expectedRoot.equals(arrowBatchSource.getVectorSchemaRoot())); + assertFalse(arrowBatchSource.nextBatch()); + } + } + } + } + + private static class TestArrowBlockBuilder + extends ArrowBlockBuilder + { + public TestArrowBlockBuilder(TypeManager typeManager) + { + super(typeManager); + } + + public Type getPrestoType(Field field) + { + return getPrestoTypeFromArrowField(field); + } + } + + private static class TestArrowPageSource + implements ConnectorPageSource + { + private final int rowCount; + private final List blocks; + private final List columns; + private int pagesRemaining; + + private TestArrowPageSource(int rowCount, List blocks, List columns, int numPages) + { + this.rowCount = rowCount; + this.blocks = blocks; + this.columns = columns; + this.pagesRemaining = numPages; + } + + public static TestArrowPageSource create(TestArrowBlockBuilder arrowBlockBuilder, VectorSchemaRoot root, int numPages) + { + List blocks = new ArrayList<>(); + List columns = new ArrayList<>(); + for (FieldVector vector : root.getFieldVectors()) { + ColumnMetadata column = ColumnMetadata.builder() + .setName(vector.getName()) + .setType(arrowBlockBuilder.getPrestoType(vector.getField())) + .setNullable(vector.getField().isNullable()).build(); + Block block = arrowBlockBuilder.buildBlockFromFieldVector(vector, column.getType(), null); + blocks.add(block); + columns.add(column); + } + + return new TestArrowPageSource(root.getRowCount(), blocks, columns, numPages); + } + + public List getColumns() + { + return columns; + } + + @Override + public long getCompletedBytes() + { + return 0; + } + + @Override + public long getCompletedPositions() + { + return 0; + } + + @Override + public long getReadTimeNanos() + { + return 0; + } + + @Override + public boolean isFinished() + { + return pagesRemaining == 0; + } + + @Override + public long getSystemMemoryUsage() + { + return 0; + } + + @Override + public Page getNextPage() + { + if (pagesRemaining == 0) { + return null; + } + --pagesRemaining; + return new Page(rowCount, blocks.toArray(new Block[0])); + } + + @Override + public void close() + { + } + } +} diff --git a/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestArrowBlockBuilder.java b/presto-common-arrow/src/test/java/com/facebook/plugin/arrow/TestArrowBlockBuilder.java similarity index 100% rename from presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestArrowBlockBuilder.java rename to presto-common-arrow/src/test/java/com/facebook/plugin/arrow/TestArrowBlockBuilder.java diff --git a/presto-flight-shim/etc/flightshim.properties b/presto-flight-shim/etc/flightshim.properties new file mode 100644 index 0000000000000..a86cb845c4685 --- /dev/null +++ b/presto-flight-shim/etc/flightshim.properties @@ -0,0 +1,10 @@ +flight-shim.server=localhost +flight-shim.server.port=9999 +flight-shim.server-ssl-certificate-file=src/test/resources/certs/server.crt +flight-shim.server-ssl-key-file=src/test/resources/certs/server.key + +plugin.bundles=\ + ../presto-oracle/pom.xml,\ + ../presto-postgresql/pom.xml,\ + ../presto-mysql/pom.xml,\ + ../presto-singlestore/pom.xml diff --git a/presto-flight-shim/etc/log.properties b/presto-flight-shim/etc/log.properties new file mode 100644 index 0000000000000..9b50e185c74ee --- /dev/null +++ b/presto-flight-shim/etc/log.properties @@ -0,0 +1,13 @@ +# +# WARNING +# ^^^^^^^ +# This configuration file is for development only and should NOT be used be +# used in production. For example configuration, see the Presto documentation. +# +com.=INFO +com.facebook.presto.execution=DEBUG +com.facebook.presto=INFO +com.sun.jersey.guice.spi.container.GuiceComponentProviderFactory=WARN +com.ning.http.client=WARN +com.facebook.presto.server.PluginManager=DEBUG +com.facebook.presto.flightshim=DEBUG diff --git a/presto-flight-shim/pom.xml b/presto-flight-shim/pom.xml new file mode 100644 index 0000000000000..83cf59ec6704e --- /dev/null +++ b/presto-flight-shim/pom.xml @@ -0,0 +1,425 @@ + + + 4.0.0 + + + com.facebook.presto + presto-root + 0.297-SNAPSHOT + + + presto-flight-shim + presto-flight-shim + Presto - Shim with Apache Arrow Flight server for Presto native federation + + + ${project.parent.basedir} + -Xss10M + true + + + + + + javax.annotation + javax.annotation-api + 1.3.2 + + + + + + + com.facebook.presto + presto-common-arrow + + + + org.apache.arrow + arrow-memory-core + + + org.slf4j + slf4j-api + + + + + + org.apache.arrow + arrow-memory-netty + runtime + + + org.slf4j + slf4j-api + + + + + + org.apache.arrow + arrow-vector + + + org.slf4j + slf4j-api + + + + com.fasterxml.jackson.datatype + jackson-datatype-jsr310 + + + + + + org.apache.arrow + flight-core + + + org.slf4j + slf4j-api + + + + + + com.facebook.airlift + bootstrap + + + + com.facebook.airlift + log + + + + com.google.guava + guava + + + + jakarta.inject + jakarta.inject-api + + + + javax.inject + javax.inject + + + + com.facebook.presto + presto-spi + + + + com.facebook.airlift + json + + + + com.fasterxml.jackson.core + jackson-annotations + + + + com.facebook.presto + presto-common + + + + com.fasterxml.jackson.core + jackson-core + + + + com.fasterxml.jackson.core + jackson-databind + + + + jakarta.annotation + jakarta.annotation-api + + + + jakarta.validation + jakarta.validation-api + + + + com.google.inject + guice + + + + com.google.http-client + google-http-client + 1.47.1 + + + + com.facebook.airlift + configuration + + + + com.facebook.airlift + concurrent + + + + com.facebook.presto + presto-main-base + + + + com.facebook.presto + presto-base-jdbc + + + + com.facebook.presto + presto-analyzer + + + + org.weakref + jmxutils + + + + io.airlift.resolver + resolver + + + + + + org.jdbi + jdbi3-core + test + + + + org.testng + testng + test + + + + io.airlift.tpch + tpch + test + + + + com.facebook.presto + presto-tpch + test + + + + com.facebook.presto + presto-testng-services + test + + + + com.facebook.airlift + testing + test + + + + com.facebook.presto + presto-tests + test + + + + com.facebook.presto + presto-main-base + test-jar + test + + + + com.facebook.presto + presto-base-arrow-flight + test + + + + com.h2database + h2 + test + + + + org.testcontainers + testcontainers + test + + + org.slf4j + slf4j-api + + + org.jetbrains + annotations + + + + + + org.testcontainers + testcontainers-postgresql + test + + + + com.facebook.presto + presto-postgresql + ${project.version} + test + + + + com.facebook.presto + presto-postgresql + ${project.version} + test-jar + test + + + + org.testcontainers + testcontainers-mysql + test + + + + com.facebook.presto + presto-mysql + ${project.version} + test + + + + com.facebook.presto + presto-mysql + ${project.version} + test-jar + test + + + + com.facebook.presto + presto-oracle + ${project.version} + test + + + + com.facebook.presto + presto-oracle + ${project.version} + test-jar + test + + + + org.testcontainers + oracle-xe + test + 1.14.3 + + + org.slf4j + slf4j-api + + + org.jetbrains + annotations + + + + + + com.facebook.presto + presto-singlestore + ${project.version} + test + + + + com.facebook.presto + presto-singlestore + ${project.version} + test-jar + test + + + + + + + org.apache.maven.plugins + maven-enforcer-plugin + + + org.apache.maven.plugins + maven-dependency-plugin + + + com.fasterxml.jackson.core:jackson-databind + com.facebook.airlift:log-manager + javax.inject:javax.inject + com.facebook.presto:presto-base-jdbc + com.google.errorprone:error_prone_annotations + + + + + org.basepom.maven + duplicate-finder-maven-plugin + 1.2.1 + + + module-info + META-INF.versions.9.module-info + + + arrow-git.properties + about.html + + + + + + check + + + + + + + + + + java17 + + 17 + + + + -Xss10M + --add-opens=java.base/java.nio=org.apache.arrow.memory.core,ALL-UNNAMED + + + + + diff --git a/presto-flight-shim/src/main/java/com/facebook/presto/flightshim/FlightShimConfig.java b/presto-flight-shim/src/main/java/com/facebook/presto/flightshim/FlightShimConfig.java new file mode 100644 index 0000000000000..9aeff050e2077 --- /dev/null +++ b/presto-flight-shim/src/main/java/com/facebook/presto/flightshim/FlightShimConfig.java @@ -0,0 +1,150 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.flightshim; + +import com.facebook.airlift.configuration.Config; +import com.facebook.airlift.configuration.ConfigDescription; +import jakarta.validation.constraints.Max; +import jakarta.validation.constraints.Min; + +public class FlightShimConfig +{ + public static final String CONFIG_PREFIX = "flight-shim"; + private static final int MAX_ROWS_PER_BATCH_DEFAULT = 10000; + private static final int READ_THREAD_POOL_SIZE = 16; + private String serverName; + private Integer serverPort; + private String serverSSLCertificateFile; + private String serverSSLKeyFile; + private String clientSSLCertificateFile; + private String clientSSLKeyFile; + private boolean serverSslEnabled = true; + private int maxRowsPerBatch = MAX_ROWS_PER_BATCH_DEFAULT; + private int readThreadPoolSize = READ_THREAD_POOL_SIZE; + + public String getServerName() + { + return serverName; + } + + @Config("server") + public FlightShimConfig setServerName(String serverName) + { + this.serverName = serverName; + return this; + } + + public Integer getServerPort() + { + return serverPort; + } + + @Config("server.port") + public FlightShimConfig setServerPort(Integer serverPort) + { + this.serverPort = serverPort; + return this; + } + + public boolean getServerSslEnabled() + { + return serverSslEnabled; + } + + @Config("server-ssl-enabled") + public FlightShimConfig setServerSslEnabled(boolean serverSslEnabled) + { + this.serverSslEnabled = serverSslEnabled; + return this; + } + + public String getServerSSLCertificateFile() + { + return serverSSLCertificateFile; + } + + @Config("server-ssl-certificate-file") + public FlightShimConfig setServerSSLCertificateFile(String serverSSLCertificateFile) + { + this.serverSSLCertificateFile = serverSSLCertificateFile; + return this; + } + + public String getServerSSLKeyFile() + { + return serverSSLKeyFile; + } + + @Config("server-ssl-key-file") + public FlightShimConfig setServerSSLKeyFile(String serverSSLKeyFile) + { + this.serverSSLKeyFile = serverSSLKeyFile; + return this; + } + + public String getClientSSLCertificateFile() + { + return clientSSLCertificateFile; + } + + @ConfigDescription("Path to the client SSL certificate used for mTLS authentication with the Flight server") + @Config("client-ssl-certificate-file") + public FlightShimConfig setClientSSLCertificateFile(String clientSSLCertificateFile) + { + this.clientSSLCertificateFile = clientSSLCertificateFile; + return this; + } + + public String getClientSSLKeyFile() + { + return clientSSLKeyFile; + } + + @ConfigDescription("Path to the client SSL key used for mTLS authentication with the Flight server") + @Config("client-ssl-key-file") + public FlightShimConfig setClientSSLKeyFile(String clientSSLKeyFile) + { + this.clientSSLKeyFile = clientSSLKeyFile; + return this; + } + + public int getMaxRowsPerBatch() + { + return maxRowsPerBatch; + } + + @Config("max-rows-per-batch") + @Min(1) + @Max(1000000) + @ConfigDescription("Sets the maximum number of rows an Arrow record batch will have before sending to the client") + public FlightShimConfig setMaxRowsPerBatch(int maxRowsPerBatch) + { + this.maxRowsPerBatch = maxRowsPerBatch; + return this; + } + + @Config("thread-pool-size") + @Min(1) + @ConfigDescription("Size of thread pool to used to handle read requests") + public FlightShimConfig setReadThreadPoolSize(int readThreadPoolSize) + { + this.readThreadPoolSize = readThreadPoolSize; + return this; + } + + public int getReadThreadPoolSize() + { + return this.readThreadPoolSize; + } +} diff --git a/presto-flight-shim/src/main/java/com/facebook/presto/flightshim/FlightShimModule.java b/presto-flight-shim/src/main/java/com/facebook/presto/flightshim/FlightShimModule.java new file mode 100644 index 0000000000000..9d272d91e1d69 --- /dev/null +++ b/presto-flight-shim/src/main/java/com/facebook/presto/flightshim/FlightShimModule.java @@ -0,0 +1,190 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.flightshim; + +import com.facebook.airlift.configuration.AbstractConfigurationAwareModule; +import com.facebook.airlift.json.JsonObjectMapperProvider; +import com.facebook.presto.SystemSessionProperties; +import com.facebook.presto.block.BlockJsonSerde; +import com.facebook.presto.common.block.Block; +import com.facebook.presto.common.block.BlockEncoding; +import com.facebook.presto.common.block.BlockEncodingManager; +import com.facebook.presto.common.block.BlockEncodingSerde; +import com.facebook.presto.common.type.Type; +import com.facebook.presto.common.type.TypeManager; +import com.facebook.presto.connector.ConnectorManager; +import com.facebook.presto.cost.HistoryBasedOptimizationConfig; +import com.facebook.presto.execution.QueryManagerConfig; +import com.facebook.presto.execution.TaskManagerConfig; +import com.facebook.presto.execution.scheduler.NodeSchedulerConfig; +import com.facebook.presto.execution.warnings.WarningCollectorConfig; +import com.facebook.presto.memory.MemoryManagerConfig; +import com.facebook.presto.memory.NodeMemoryConfig; +import com.facebook.presto.metadata.AnalyzePropertyManager; +import com.facebook.presto.metadata.BuiltInProcedureRegistry; +import com.facebook.presto.metadata.ColumnPropertyManager; +import com.facebook.presto.metadata.FunctionAndTypeManager; +import com.facebook.presto.metadata.HandleJsonModule; +import com.facebook.presto.metadata.InMemoryNodeManager; +import com.facebook.presto.metadata.InternalNodeManager; +import com.facebook.presto.metadata.MaterializedViewPropertyManager; +import com.facebook.presto.metadata.Metadata; +import com.facebook.presto.metadata.MetadataManager; +import com.facebook.presto.metadata.SchemaPropertyManager; +import com.facebook.presto.metadata.SessionPropertyManager; +import com.facebook.presto.metadata.SessionPropertyProviderConfig; +import com.facebook.presto.metadata.StaticCatalogStoreConfig; +import com.facebook.presto.metadata.TableFunctionRegistry; +import com.facebook.presto.metadata.TablePropertyManager; +import com.facebook.presto.nodeManager.PluginNodeManager; +import com.facebook.presto.server.PluginManagerConfig; +import com.facebook.presto.server.security.SecurityConfig; +import com.facebook.presto.sessionpropertyproviders.NativeWorkerSessionPropertyProvider; +import com.facebook.presto.spi.NodeManager; +import com.facebook.presto.spi.analyzer.ViewDefinition; +import com.facebook.presto.spi.procedure.ProcedureRegistry; +import com.facebook.presto.spi.session.WorkerSessionPropertyProvider; +import com.facebook.presto.spiller.NodeSpillConfig; +import com.facebook.presto.sql.SqlEnvironmentConfig; +import com.facebook.presto.sql.analyzer.FeaturesConfig; +import com.facebook.presto.sql.analyzer.FunctionsConfig; +import com.facebook.presto.sql.analyzer.JavaFeaturesConfig; +import com.facebook.presto.sql.planner.CompilerConfig; +import com.facebook.presto.tracing.TracingConfig; +import com.facebook.presto.transaction.NoOpTransactionManager; +import com.facebook.presto.transaction.TransactionManager; +import com.facebook.presto.transaction.TransactionManagerConfig; +import com.facebook.presto.type.TypeDeserializer; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.google.inject.Binder; +import com.google.inject.Provides; +import com.google.inject.Scopes; +import com.google.inject.multibindings.MapBinder; +import jakarta.inject.Singleton; +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.memory.RootAllocator; + +import java.util.concurrent.ExecutorService; +import java.util.concurrent.SynchronousQueue; +import java.util.concurrent.ThreadPoolExecutor; +import java.util.concurrent.TimeUnit; + +import static com.facebook.airlift.concurrent.Threads.threadsNamed; +import static com.facebook.airlift.configuration.ConfigBinder.configBinder; +import static com.facebook.airlift.json.JsonBinder.jsonBinder; +import static com.facebook.airlift.json.JsonCodecBinder.jsonCodecBinder; +import static com.google.inject.multibindings.MapBinder.newMapBinder; +import static com.google.inject.multibindings.Multibinder.newSetBinder; +import static org.weakref.jmx.guice.ExportBinder.newExporter; + +public class FlightShimModule + extends AbstractConfigurationAwareModule +{ + @Override + protected void setup(Binder binder) + { + binder.bind(ConnectorManager.class).toProvider(() -> null); + binder.bind(FlightShimPluginManager.class).in(Scopes.SINGLETON); + binder.bind(BufferAllocator.class).to(RootAllocator.class).in(Scopes.SINGLETON); + binder.bind(FlightShimProducer.class).in(Scopes.SINGLETON); + + binder.bind(FlightShimServerExecutionMBean.class).in(Scopes.SINGLETON); + newExporter(binder).export(FlightShimServerExecutionMBean.class).withGeneratedName(); + + configBinder(binder).bindConfig(FlightShimConfig.class, FlightShimConfig.CONFIG_PREFIX); + configBinder(binder).bindConfig(PluginManagerConfig.class); + configBinder(binder).bindConfig(StaticCatalogStoreConfig.class); + + // configs + configBinder(binder).bindConfig(QueryManagerConfig.class); + configBinder(binder).bindConfig(TaskManagerConfig.class); + configBinder(binder).bindConfig(NodeSchedulerConfig.class); + configBinder(binder).bindConfig(WarningCollectorConfig.class); + configBinder(binder).bindConfig(MemoryManagerConfig.class); + configBinder(binder).bindConfig(NodeMemoryConfig.class); + configBinder(binder).bindConfig(SessionPropertyProviderConfig.class); + configBinder(binder).bindConfig(SecurityConfig.class); + configBinder(binder).bindConfig(NodeSpillConfig.class); + configBinder(binder).bindConfig(SqlEnvironmentConfig.class); + configBinder(binder).bindConfig(CompilerConfig.class); + configBinder(binder).bindConfig(TracingConfig.class); + + // json codecs + jsonCodecBinder(binder).bindJsonCodec(ViewDefinition.class); + + // Worker session property providers + MapBinder mapBinder = + newMapBinder(binder, String.class, WorkerSessionPropertyProvider.class); + mapBinder.addBinding("native-worker").to(NativeWorkerSessionPropertyProvider.class).in(Scopes.SINGLETON); + + // history statistics + configBinder(binder).bindConfig(HistoryBasedOptimizationConfig.class); + + // property managers + binder.bind(SystemSessionProperties.class).in(Scopes.SINGLETON); + binder.bind(SessionPropertyManager.class).in(Scopes.SINGLETON); + binder.bind(SchemaPropertyManager.class).in(Scopes.SINGLETON); + binder.bind(TablePropertyManager.class).in(Scopes.SINGLETON); + binder.bind(MaterializedViewPropertyManager.class).in(Scopes.SINGLETON); + binder.bind(ColumnPropertyManager.class).in(Scopes.SINGLETON); + binder.bind(AnalyzePropertyManager.class).in(Scopes.SINGLETON); + + // transaction manager + configBinder(binder).bindConfig(TransactionManagerConfig.class); + // Install no-op transaction manager on workers, since only coordinators manage transactions. + binder.bind(TransactionManager.class).to(NoOpTransactionManager.class).in(Scopes.SINGLETON); + + // metadata + binder.bind(FunctionAndTypeManager.class).in(Scopes.SINGLETON); + binder.bind(TableFunctionRegistry.class).in(Scopes.SINGLETON); + binder.bind(MetadataManager.class).in(Scopes.SINGLETON); + binder.bind(Metadata.class).to(MetadataManager.class).in(Scopes.SINGLETON); + binder.bind(ProcedureRegistry.class).to(BuiltInProcedureRegistry.class).in(Scopes.SINGLETON); + + // type + binder.bind(TypeManager.class).to(FunctionAndTypeManager.class).in(Scopes.SINGLETON); + jsonBinder(binder).addDeserializerBinding(Type.class).to(TypeDeserializer.class); + newSetBinder(binder, Type.class); + binder.bind(TypeDeserializer.class).in(Scopes.SINGLETON); + + // block encodings + binder.bind(BlockEncodingManager.class).in(Scopes.SINGLETON); + binder.bind(BlockEncodingSerde.class).to(BlockEncodingManager.class).in(Scopes.SINGLETON); + newSetBinder(binder, BlockEncoding.class); + jsonBinder(binder).addSerializerBinding(Block.class).to(BlockJsonSerde.Serializer.class); + jsonBinder(binder).addDeserializerBinding(Block.class).to(BlockJsonSerde.Deserializer.class); + + // handle resolver + binder.install(new HandleJsonModule()); + binder.bind(ObjectMapper.class).toProvider(JsonObjectMapperProvider.class); + + // features config + configBinder(binder).bindConfig(FeaturesConfig.class); + configBinder(binder).bindConfig(FunctionsConfig.class); + configBinder(binder).bindConfig(JavaFeaturesConfig.class); + + // Node manager binding + binder.bind(InternalNodeManager.class).to(InMemoryNodeManager.class).in(Scopes.SINGLETON); + binder.bind(PluginNodeManager.class).in(Scopes.SINGLETON); + binder.bind(NodeManager.class).to(PluginNodeManager.class).in(Scopes.SINGLETON); + } + + @Provides + @Singleton + @ForFlightShimServer + public static ExecutorService createFlightShimServerExecutor(FlightShimConfig config) + { + return new ThreadPoolExecutor(0, config.getReadThreadPoolSize(), 1L, TimeUnit.MINUTES, new SynchronousQueue<>(), threadsNamed("flight-shim-%s"), new ThreadPoolExecutor.CallerRunsPolicy()); + } +} diff --git a/presto-flight-shim/src/main/java/com/facebook/presto/flightshim/FlightShimPluginManager.java b/presto-flight-shim/src/main/java/com/facebook/presto/flightshim/FlightShimPluginManager.java new file mode 100644 index 0000000000000..0494deb853322 --- /dev/null +++ b/presto-flight-shim/src/main/java/com/facebook/presto/flightshim/FlightShimPluginManager.java @@ -0,0 +1,469 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.flightshim; + +import com.facebook.airlift.json.JsonCodec; +import com.facebook.airlift.json.JsonCodecFactory; +import com.facebook.airlift.json.JsonObjectMapperProvider; +import com.facebook.airlift.log.Logger; +import com.facebook.presto.GroupByHashPageIndexerFactory; +import com.facebook.presto.PagesIndexPageSorter; +import com.facebook.presto.block.BlockJsonSerde; +import com.facebook.presto.common.block.Block; +import com.facebook.presto.common.block.BlockEncodingManager; +import com.facebook.presto.common.type.Type; +import com.facebook.presto.connector.ConnectorContextInstance; +import com.facebook.presto.cost.ConnectorFilterStatsCalculatorService; +import com.facebook.presto.cost.FilterStatsCalculator; +import com.facebook.presto.cost.ScalarStatsCalculator; +import com.facebook.presto.cost.StatsNormalizer; +import com.facebook.presto.metadata.InMemoryNodeManager; +import com.facebook.presto.metadata.Metadata; +import com.facebook.presto.metadata.StaticCatalogStoreConfig; +import com.facebook.presto.nodeManager.PluginNodeManager; +import com.facebook.presto.operator.PagesIndex; +import com.facebook.presto.server.PluginInstaller; +import com.facebook.presto.server.PluginManagerConfig; +import com.facebook.presto.server.PluginManagerUtil; +import com.facebook.presto.spi.ColumnHandle; +import com.facebook.presto.spi.ColumnMetadata; +import com.facebook.presto.spi.ConnectorHandleResolver; +import com.facebook.presto.spi.ConnectorSession; +import com.facebook.presto.spi.ConnectorSplit; +import com.facebook.presto.spi.ConnectorTableHandle; +import com.facebook.presto.spi.ConnectorTableLayoutHandle; +import com.facebook.presto.spi.CoordinatorPlugin; +import com.facebook.presto.spi.Plugin; +import com.facebook.presto.spi.PrestoException; +import com.facebook.presto.spi.classloader.ThreadContextClassLoader; +import com.facebook.presto.spi.connector.Connector; +import com.facebook.presto.spi.connector.ConnectorContext; +import com.facebook.presto.spi.connector.ConnectorFactory; +import com.facebook.presto.spi.connector.ConnectorTransactionHandle; +import com.facebook.presto.spi.procedure.ProcedureRegistry; +import com.facebook.presto.spi.relation.DeterminismEvaluator; +import com.facebook.presto.spi.relation.DomainTranslator; +import com.facebook.presto.spi.relation.ExpressionOptimizer; +import com.facebook.presto.spi.relation.ExpressionOptimizerProvider; +import com.facebook.presto.spi.relation.PredicateCompiler; +import com.facebook.presto.spi.relation.RowExpression; +import com.facebook.presto.spi.relation.RowExpressionService; +import com.facebook.presto.sql.gen.JoinCompiler; +import com.facebook.presto.sql.gen.RowExpressionPredicateCompiler; +import com.facebook.presto.sql.planner.planPrinter.RowExpressionFormatter; +import com.facebook.presto.sql.relational.FunctionResolution; +import com.facebook.presto.sql.relational.RowExpressionDeterminismEvaluator; +import com.facebook.presto.sql.relational.RowExpressionDomainTranslator; +import com.facebook.presto.sql.relational.RowExpressionOptimizer; +import com.facebook.presto.type.TypeDeserializer; +import com.fasterxml.jackson.core.JsonParser; +import com.fasterxml.jackson.databind.DeserializationContext; +import com.fasterxml.jackson.databind.JsonDeserializer; +import com.google.common.annotations.VisibleForTesting; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableSet; +import com.google.common.io.Files; +import com.google.inject.Inject; +import io.airlift.resolver.ArtifactResolver; +import jakarta.annotation.PreDestroy; + +import java.io.File; +import java.io.IOException; +import java.lang.reflect.InvocationTargetException; +import java.lang.reflect.Method; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.atomic.AtomicBoolean; + +import static com.facebook.presto.server.PluginManagerUtil.SPI_PACKAGES; +import static com.facebook.presto.spi.StandardErrorCode.GENERIC_INTERNAL_ERROR; +import static com.facebook.presto.util.PropertiesUtil.loadProperties; +import static com.google.api.client.util.Preconditions.checkState; +import static com.google.common.base.MoreObjects.firstNonNull; +import static java.lang.String.format; +import static java.util.Objects.requireNonNull; + +public class FlightShimPluginManager +{ + private static final Logger log = Logger.get(FlightShimPluginManager.class); + private static final String SERVICES_FILE = "META-INF/services/" + Plugin.class.getName(); + private final Map connectorFactories = new ConcurrentHashMap<>(); + private final Map connectors = new ConcurrentHashMap<>(); + private final File installedPluginsDir; + private final List plugins; + private final ArtifactResolver resolver; + private final AtomicBoolean pluginsLoading = new AtomicBoolean(); + private final AtomicBoolean pluginsLoaded = new AtomicBoolean(); + private final PluginInstaller pluginInstaller; + private final File catalogConfigurationDir; + private final Set disabledCatalogs; + private final AtomicBoolean catalogsLoading = new AtomicBoolean(); + private final AtomicBoolean catalogsLoaded = new AtomicBoolean(); + private final Map catalogPropertiesMap = new ConcurrentHashMap<>(); + private final AtomicBoolean stopped = new AtomicBoolean(); + private final Metadata metadata; + private final TypeDeserializer typeDeserializer; + private final BlockEncodingManager blockEncodingManager; + private final ProcedureRegistry procedureRegistry; + + @Inject + public FlightShimPluginManager( + PluginManagerConfig pluginManagerConfig, + StaticCatalogStoreConfig catalogStoreConfig, + Metadata metadata, + TypeDeserializer typeDeserializer, + BlockEncodingManager blockEncodingManager, + ProcedureRegistry procedureRegistry) + { + requireNonNull(pluginManagerConfig, "pluginManagerConfig is null"); + requireNonNull(catalogStoreConfig, "catalogStoreConfig is null"); + this.metadata = requireNonNull(metadata, "metadata is null"); + this.typeDeserializer = requireNonNull(typeDeserializer, "typeDeserializer is null"); + this.blockEncodingManager = requireNonNull(blockEncodingManager, "blockEncodingManager is null"); + this.installedPluginsDir = pluginManagerConfig.getInstalledPluginsDir(); + if (pluginManagerConfig.getPlugins() == null) { + this.plugins = ImmutableList.of(); + } + else { + this.plugins = ImmutableList.copyOf(pluginManagerConfig.getPlugins()); + } + this.resolver = new ArtifactResolver(pluginManagerConfig.getMavenLocalRepository(), pluginManagerConfig.getMavenRemoteRepository()); + this.pluginInstaller = new FlightServerPluginInstaller(); + this.catalogConfigurationDir = catalogStoreConfig.getCatalogConfigurationDir(); + this.disabledCatalogs = ImmutableSet.copyOf(firstNonNull(catalogStoreConfig.getDisabledCatalogs(), ImmutableList.of())); + this.procedureRegistry = requireNonNull(procedureRegistry, "procedureRegistry is null"); + } + + @PreDestroy + public synchronized void stop() + { + if (stopped.getAndSet(true)) { + return; + } + + for (Map.Entry entry : connectors.entrySet()) { + Connector connector = entry.getValue().getConnector(); + try (ThreadContextClassLoader ignored = new ThreadContextClassLoader(connector.getClass().getClassLoader())) { + connector.shutdown(); + } + catch (Throwable t) { + log.error(t, "Error shutting down connector: %s", entry.getKey()); + } + } + } + + public void loadPlugins() + throws Exception + { + PluginManagerUtil.loadPlugins( + pluginsLoading, + pluginsLoaded, + installedPluginsDir, + plugins, + null, + resolver, + SPI_PACKAGES, + null, + SERVICES_FILE, + pluginInstaller, + getClass().getClassLoader()); + } + + public void loadCatalogs() + throws Exception + { + if (!catalogsLoading.compareAndSet(false, true)) { + return; + } + + for (File file : listFiles(catalogConfigurationDir)) { + if (file.isFile() && file.getName().endsWith(".properties")) { + loadCatalog(file); + } + } + + catalogsLoaded.set(true); + } + + private void loadCatalog(File file) + throws Exception + { + String catalogName = Files.getNameWithoutExtension(file.getName()); + + log.info("-- Loading catalog properties %s --", file); + Map properties = loadProperties(file); + checkState(properties.containsKey("connector.name"), "Catalog configuration %s does not contain connector.name", file.getAbsoluteFile()); + + loadCatalog(catalogName, properties); + } + + private void loadCatalog(String catalogName, Map properties) + { + if (disabledCatalogs.contains(catalogName)) { + log.info("Skipping disabled catalog %s", catalogName); + return; + } + + log.info("-- Loading catalog %s --", catalogName); + + String connectorName = null; + ImmutableMap.Builder connectorProperties = ImmutableMap.builder(); + for (Map.Entry entry : properties.entrySet()) { + if (entry.getKey().equals("connector.name")) { + connectorName = entry.getValue(); + } + else { + connectorProperties.put(entry.getKey(), entry.getValue()); + } + } + + checkState(connectorName != null, "Configuration for catalog %s does not contain connector.name", catalogName); + + catalogPropertiesMap.put(catalogName, new CatalogPropertiesHolder(connectorName, connectorProperties.build())); + + log.info("-- Added catalog %s using connector %s --", catalogName, connectorName); + } + + @VisibleForTesting + void setCatalogProperties(String catalogName, String connectorName, Map properties) + { + catalogPropertiesMap.put(catalogName, new CatalogPropertiesHolder(connectorName, ImmutableMap.copyOf(properties))); + } + + private static List listFiles(File installedPluginsDir) + { + if (installedPluginsDir != null && installedPluginsDir.isDirectory()) { + File[] files = installedPluginsDir.listFiles(); + if (files != null) { + return ImmutableList.copyOf(files); + } + } + return ImmutableList.of(); + } + + private void registerPlugin(Plugin plugin) + { + for (ConnectorFactory factory : plugin.getConnectorFactories()) { + log.info("Registering connector %s", factory.getName()); + connectorFactories.put(factory.getName(), factory); + } + } + + public ConnectorHolder getConnector(String connectorId) + { + log.debug("FlightShimPluginManager getting connector: %s", connectorId); + CatalogPropertiesHolder catalogPropertiesHolder = catalogPropertiesMap.get(connectorId); + if (catalogPropertiesHolder == null) { + throw new IllegalArgumentException("Properties not loaded for " + connectorId); + } + final ImmutableMap config = catalogPropertiesHolder.getCatalogProperties(); + + // Create connector instances from factories as needed + return connectors.computeIfAbsent(catalogPropertiesHolder.getConnectorName(), name -> { + log.debug("Loading connector: %s", connectorId); + ConnectorFactory factory = connectorFactories.get(name); + requireNonNull(factory, format("No connector factory for %s", connectorId)); + + final RowExpressionDomainTranslator domainTranslator = new RowExpressionDomainTranslator(metadata); + final PredicateCompiler predicateCompiler = new RowExpressionPredicateCompiler(metadata); + final DeterminismEvaluator determinismEvaluator = new RowExpressionDeterminismEvaluator(metadata.getFunctionAndTypeManager()); + final RowExpressionService rowExpressionService = new RowExpressionService() + { + @Override + public DomainTranslator getDomainTranslator() + { + return domainTranslator; + } + + @Override + public ExpressionOptimizer getExpressionOptimizer(ConnectorSession session) + { + return new RowExpressionOptimizer(metadata); + } + + @Override + public PredicateCompiler getPredicateCompiler() + { + return predicateCompiler; + } + + @Override + public DeterminismEvaluator getDeterminismEvaluator() + { + return determinismEvaluator; + } + + @Override + public String formatRowExpression(ConnectorSession session, RowExpression expression) + { + return new RowExpressionFormatter(metadata.getFunctionAndTypeManager()).formatRowExpression(session, expression); + } + }; + + final ExpressionOptimizerProvider expressionOptimizerProvider = (ConnectorSession session) -> new RowExpressionOptimizer(metadata); + + ConnectorContext context = new ConnectorContextInstance( + new PluginNodeManager(new InMemoryNodeManager(), "flightconnector"), + metadata.getFunctionAndTypeManager(), + procedureRegistry, + metadata.getFunctionAndTypeManager(), + new FunctionResolution(metadata.getFunctionAndTypeManager().getFunctionAndTypeResolver()), + new PagesIndexPageSorter(new PagesIndex.TestingFactory(false)), + new GroupByHashPageIndexerFactory(new JoinCompiler(metadata)), + rowExpressionService, + new ConnectorFilterStatsCalculatorService(new FilterStatsCalculator(metadata, new ScalarStatsCalculator(metadata, expressionOptimizerProvider), new StatsNormalizer())), + blockEncodingManager, + () -> false); + + try (ThreadContextClassLoader ignored = new ThreadContextClassLoader(factory.getClass().getClassLoader())) { + ConnectorHolder holder = new ConnectorHolder(factory.create(name, config, context), factory.getHandleResolver(), typeDeserializer, blockEncodingManager); + log.debug("Finished loading connector: %s", connectorId); + return holder; + } + }); + } + + private class FlightServerPluginInstaller + implements PluginInstaller + { + @Override + public void installPlugin(Plugin plugin) + { + registerPlugin(plugin); + } + + @Override + public void installCoordinatorPlugin(CoordinatorPlugin plugin) {} + } + + private static class CatalogPropertiesHolder + { + private final String connectorName; + private final ImmutableMap catalogProperties; + + CatalogPropertiesHolder(String connectorName, ImmutableMap catalogProperties) + { + this.connectorName = connectorName; + this.catalogProperties = catalogProperties; + } + + String getConnectorName() + { + return connectorName; + } + + ImmutableMap getCatalogProperties() + { + return catalogProperties; + } + } + + public static class ConnectorHolder + { + private final Connector connector; + private final JsonCodec codecSplit; + private final JsonCodec codecColumnHandle; + private final JsonCodec codecTableHandle; + private final JsonCodec codecTableLayoutHandle; + private final JsonCodec codecTransactionHandle; + private final Method getColumnMetadataMethod; + + ConnectorHolder(Connector connector, ConnectorHandleResolver resolver, TypeDeserializer typeDeserializer, BlockEncodingManager blockEncodingManager) + { + this.connector = connector; + + JsonObjectMapperProvider provider = new JsonObjectMapperProvider(); + JsonDeserializer columnDeserializer = new JsonDeserializer() + { + @Override + public ColumnHandle deserialize(JsonParser p, DeserializationContext ctxt) + throws IOException + { + return p.readValueAs(resolver.getColumnHandleClass()); + } + }; + BlockJsonSerde.Deserializer blockDeserializer = new BlockJsonSerde.Deserializer(blockEncodingManager); + provider.setJsonDeserializers(ImmutableMap.of( + Type.class, typeDeserializer, + ColumnHandle.class, columnDeserializer, + Block.class, blockDeserializer)); + JsonCodecFactory jsonCodecFactory = new JsonCodecFactory(provider); + + this.codecSplit = jsonCodecFactory.jsonCodec(resolver.getSplitClass()); + this.codecColumnHandle = jsonCodecFactory.jsonCodec(resolver.getColumnHandleClass()); + this.codecTableHandle = jsonCodecFactory.jsonCodec(resolver.getTableHandleClass()); + this.codecTableLayoutHandle = jsonCodecFactory.jsonCodec(resolver.getTableLayoutHandleClass()); + this.codecTransactionHandle = jsonCodecFactory.jsonCodec(resolver.getTransactionHandleClass()); + this.getColumnMetadataMethod = reflectGetColumnMetadata(resolver); + } + + Connector getConnector() + { + return connector; + } + + JsonCodec getCodecSplit() + { + return codecSplit; + } + + JsonCodec getCodecColumnHandle() + { + return codecColumnHandle; + } + + JsonCodec getCodecTableHandle() + { + return codecTableHandle; + } + + JsonCodec getCodecTableLayoutHandle() + { + return codecTableLayoutHandle; + } + + JsonCodec getCodecTransactionHandle() + { + return codecTransactionHandle; + } + + ColumnMetadata getColumnMetadata(ColumnHandle handle) + { + try { + return (ColumnMetadata) getColumnMetadataMethod.invoke(handle); + } + catch (InvocationTargetException | IllegalAccessException e) { + throw new PrestoException(GENERIC_INTERNAL_ERROR, "Unable to invoke method for getColumnMetadata", e); + } + } + + private static Method reflectGetColumnMetadata(ConnectorHandleResolver resolver) + { + try { + return resolver.getColumnHandleClass().getMethod("getColumnMetadata"); + } + catch (NoSuchMethodException e) { + try { + return resolver.getColumnHandleClass().getMethod("toColumnMetadata"); + } + catch (NoSuchMethodException e2) { + throw new PrestoException(GENERIC_INTERNAL_ERROR, "Unable to get column metadata from ColumnHandle", e); + } + } + } + } +} diff --git a/presto-flight-shim/src/main/java/com/facebook/presto/flightshim/FlightShimProducer.java b/presto-flight-shim/src/main/java/com/facebook/presto/flightshim/FlightShimProducer.java new file mode 100644 index 0000000000000..79fefac418c96 --- /dev/null +++ b/presto-flight-shim/src/main/java/com/facebook/presto/flightshim/FlightShimProducer.java @@ -0,0 +1,203 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.flightshim; + +import com.facebook.airlift.json.JsonCodec; +import com.facebook.airlift.log.Logger; +import com.facebook.plugin.arrow.ArrowBatchSource; +import com.facebook.presto.Session; +import com.facebook.presto.common.RuntimeStats; +import com.facebook.presto.execution.QueryIdGenerator; +import com.facebook.presto.spi.ColumnHandle; +import com.facebook.presto.spi.ColumnMetadata; +import com.facebook.presto.spi.ConnectorId; +import com.facebook.presto.spi.ConnectorPageSource; +import com.facebook.presto.spi.ConnectorSession; +import com.facebook.presto.spi.ConnectorSplit; +import com.facebook.presto.spi.ConnectorTableHandle; +import com.facebook.presto.spi.ConnectorTableLayoutHandle; +import com.facebook.presto.spi.SplitContext; +import com.facebook.presto.spi.TableHandle; +import com.facebook.presto.spi.connector.Connector; +import com.facebook.presto.spi.connector.ConnectorPageSourceProvider; +import com.facebook.presto.spi.connector.ConnectorRecordSetProvider; +import com.facebook.presto.spi.connector.ConnectorTransactionHandle; +import com.facebook.presto.spi.security.Identity; +import com.facebook.presto.spi.transaction.IsolationLevel; +import com.facebook.presto.split.RecordPageSourceProvider; +import org.apache.arrow.flight.BackpressureStrategy; +import org.apache.arrow.flight.CallStatus; +import org.apache.arrow.flight.NoOpFlightProducer; +import org.apache.arrow.flight.Ticket; +import org.apache.arrow.memory.BufferAllocator; + +import javax.inject.Inject; + +import java.io.Closeable; +import java.util.List; +import java.util.Optional; +import java.util.concurrent.ExecutorService; + +import static com.facebook.airlift.json.JsonCodec.jsonCodec; +import static com.facebook.presto.metadata.SessionPropertyManager.createTestingSessionPropertyManager; +import static com.facebook.presto.testing.TestingSession.DEFAULT_TIME_ZONE_KEY; +import static com.google.common.base.Preconditions.checkState; +import static com.google.common.collect.ImmutableList.toImmutableList; +import static java.lang.String.format; +import static java.util.Collections.unmodifiableList; +import static java.util.Locale.ENGLISH; +import static java.util.Objects.requireNonNull; + +public class FlightShimProducer + extends NoOpFlightProducer + implements Closeable +{ + private static final Logger log = Logger.get(FlightShimProducer.class); + private static final JsonCodec REQUEST_JSON_CODEC = jsonCodec(FlightShimRequest.class); + private static final int CLIENT_POLL_TIME = 5000; // Backpressure poll time ms + private final BufferAllocator allocator; + private final FlightShimPluginManager pluginManager; + private final FlightShimConfig config; + private final ExecutorService shimExecutor; + + @Inject + public FlightShimProducer(BufferAllocator allocator, FlightShimPluginManager pluginManager, FlightShimConfig config, @ForFlightShimServer ExecutorService shimExecutor) + { + this.allocator = allocator.newChildAllocator("flight-shim", 0, Long.MAX_VALUE); + this.pluginManager = requireNonNull(pluginManager, "pluginManager is null"); + this.config = requireNonNull(config, "config is null"); + this.shimExecutor = requireNonNull(shimExecutor, "shimExecutor is null"); + } + + @Override + public void getStream(CallContext context, Ticket ticket, ServerStreamListener listener) + { + log.debug("Received GetStream request"); + shimExecutor.submit(() -> runGetStreamAsync(context, ticket, listener)); + } + + private void runGetStreamAsync(CallContext context, Ticket ticket, ServerStreamListener listener) + { + log.debug("Starting GetStream processing"); + int columnCount = 0; + int rowCount = 0; + int batchCount = 0; + try { + final BackpressureStrategy backpressureStrategy = new BackpressureStrategy.CallbackBackpressureStrategy(); + backpressureStrategy.register(listener); + + FlightShimRequest request = REQUEST_JSON_CODEC.fromJson(ticket.getBytes()); + log.debug("Request for connector: %s", request.getConnectorId()); + + FlightShimPluginManager.ConnectorHolder connectorHolder = pluginManager.getConnector(request.getConnectorId()); + requireNonNull(connectorHolder, format("Requested connector not loaded: %s", request.getConnectorId())); + + Connector connector = connectorHolder.getConnector(); + ConnectorSplit split = connectorHolder.getCodecSplit().fromJson(request.getSplitBytes()); + + List columnHandles = request.getColumnHandlesBytes().stream().map( + columnHandleBytes -> connectorHolder.getCodecColumnHandle().fromJson(columnHandleBytes) + ).collect(toImmutableList()); + + List columnsMetadata = columnHandles.stream() + .map(connectorHolder::getColumnMetadata).collect(toImmutableList()); + + ConnectorTableHandle connectorTableHandle = connectorHolder.getCodecTableHandle().fromJson(request.getTableHandleBytes()); + ConnectorTransactionHandle transactionHandle = connectorHolder.getCodecTransactionHandle().fromJson(request.getTransactionHandleBytes()); + Optional connectorTableLayoutHandle = + request.getTableLayoutHandleBytes().map( tableLayoutHandleBytes -> connectorHolder.getCodecTableLayoutHandle().fromJson(tableLayoutHandleBytes)); + TableHandle tableHandle = new TableHandle(new ConnectorId(request.getConnectorId()), connectorTableHandle, transactionHandle, connectorTableLayoutHandle); + + // Create a dummy session to load the connector + QueryIdGenerator queryIdGenerator = new QueryIdGenerator(); + Session session = Session.builder(createTestingSessionPropertyManager()) + .setQueryId(queryIdGenerator.createNextQueryId()) + .setIdentity(new Identity("user", Optional.empty())) + .setTimeZoneKey(DEFAULT_TIME_ZONE_KEY) + .setLocale(ENGLISH).build(); + ConnectorId connectorId = new ConnectorId(request.getConnectorId()); + ConnectorSession connectorSession = session.toConnectorSession(connectorId); + + ConnectorPageSourceProvider connectorPageSourceProvider = getConnectorPageSourceProvider(connector, connectorId); + ConnectorPageSource connectorPageSource = connectorPageSourceProvider.createPageSource( + transactionHandle, + connectorSession, + split, + null, + unmodifiableList(columnHandles), + new SplitContext(false), + new RuntimeStats()); + + try (ArrowBatchSource batchSource = new ArrowBatchSource(allocator, columnsMetadata, connectorPageSource, config.getMaxRowsPerBatch())) { + listener.setUseZeroCopy(true); + listener.start(batchSource.getVectorSchemaRoot()); + columnCount = batchSource.getVectorSchemaRoot().getFieldVectors().size(); + while (batchSource.nextBatch()) { + BackpressureStrategy.WaitResult waitResult; + while ((waitResult = backpressureStrategy.waitForListener(CLIENT_POLL_TIME)) == BackpressureStrategy.WaitResult.TIMEOUT) { + log.debug(format("Waiting for client to read from connector %s", request.getConnectorId())); + } + if (waitResult != BackpressureStrategy.WaitResult.READY) { + log.info(format("Read stopped from connector %s due to client wait result: %s", request.getConnectorId(), waitResult)); + break; + } + rowCount += batchSource.getVectorSchemaRoot().getRowCount(); + batchCount++; + listener.putNext(); + } + listener.completed(); + } + } + catch (Throwable t) { + final String message = "Error getting connector flight stream"; + log.error(t, message); + listener.error(CallStatus.INTERNAL.withCause(t).withDescription(format("%s [%s]", message, t)).toRuntimeException()); + } + finally { + log.debug(format("Processing GetStream completed [columns=%d, rows=%d, batches=%d]", columnCount, rowCount, batchCount)); + } + } + + private ConnectorPageSourceProvider getConnectorPageSourceProvider(Connector connector, ConnectorId connectorId) + { + ConnectorPageSourceProvider connectorPageSourceProvider = null; + try { + connectorPageSourceProvider = connector.getPageSourceProvider(); + requireNonNull(connectorPageSourceProvider, format("Connector %s returned a null page source provider", connectorId)); + } + catch (UnsupportedOperationException ignored) { + } + + if (connectorPageSourceProvider == null) { + ConnectorRecordSetProvider connectorRecordSetProvider = null; + try { + connectorRecordSetProvider = connector.getRecordSetProvider(); + requireNonNull(connectorRecordSetProvider, format("Connector %s returned a null record set provider", connectorId)); + } + catch (UnsupportedOperationException ignored) { + } + checkState(connectorRecordSetProvider != null, "Connector %s has neither a PageSource or RecordSet provider", connectorId); + connectorPageSourceProvider = new RecordPageSourceProvider(connectorRecordSetProvider); + } + return connectorPageSourceProvider; + } + + @Override + public void close() + { + shimExecutor.shutdownNow(); + pluginManager.stop(); + allocator.close(); + } +} diff --git a/presto-flight-shim/src/main/java/com/facebook/presto/flightshim/FlightShimRequest.java b/presto-flight-shim/src/main/java/com/facebook/presto/flightshim/FlightShimRequest.java new file mode 100644 index 0000000000000..33f6753515b9e --- /dev/null +++ b/presto-flight-shim/src/main/java/com/facebook/presto/flightshim/FlightShimRequest.java @@ -0,0 +1,86 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.flightshim; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; +import com.google.common.collect.ImmutableList; + +import java.util.List; +import java.util.Optional; + +import static java.util.Objects.requireNonNull; + +public class FlightShimRequest +{ + private final String connectorId; + private final byte[] splitBytes; + private final List columnHandlesBytes; + private final byte[] tableHandleBytes; + private final Optional tableLayoutHandleBytes; + private final byte[] transactionHandleBytes; + + @JsonCreator + public FlightShimRequest( + @JsonProperty("connectorId") String connectorId, + @JsonProperty("splitBytes") byte[] splitBytes, + @JsonProperty("columnHandlesBytes") List columnHandlesBytes, + @JsonProperty("tableHandleBytes") byte[] tableHandleBytes, + @JsonProperty("tableLayoutHandleBytes") Optional tableLayoutHandleBytes, + @JsonProperty("transactionHandleBytes") byte[] transactionHandleBytes) + { + this.connectorId = requireNonNull(connectorId, "connectorId is null"); + this.splitBytes = requireNonNull(splitBytes, "splitBytes is null"); + this.columnHandlesBytes = ImmutableList.copyOf(requireNonNull(columnHandlesBytes, "columnHandlesBytes is null")); + this.tableHandleBytes = requireNonNull(tableHandleBytes, "tableHandleBytes is null"); + this.tableLayoutHandleBytes = tableLayoutHandleBytes; + this.transactionHandleBytes = requireNonNull(transactionHandleBytes, "transactionHandleBytes is null"); + } + + @JsonProperty + public String getConnectorId() + { + return connectorId; + } + + @JsonProperty + public byte[] getSplitBytes() + { + return splitBytes; + } + + @JsonProperty + public List getColumnHandlesBytes() + { + return columnHandlesBytes; + } + + @JsonProperty + public byte[] getTableHandleBytes() + { + return tableHandleBytes; + } + + @JsonProperty + public Optional getTableLayoutHandleBytes() + { + return tableLayoutHandleBytes; + } + + @JsonProperty + public byte[] getTransactionHandleBytes() + { + return transactionHandleBytes; + } +} diff --git a/presto-flight-shim/src/main/java/com/facebook/presto/flightshim/FlightShimServer.java b/presto-flight-shim/src/main/java/com/facebook/presto/flightshim/FlightShimServer.java new file mode 100644 index 0000000000000..581c5fa3a9867 --- /dev/null +++ b/presto-flight-shim/src/main/java/com/facebook/presto/flightshim/FlightShimServer.java @@ -0,0 +1,133 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.flightshim; + +import com.facebook.airlift.bootstrap.Bootstrap; +import com.facebook.airlift.json.JsonModule; +import com.facebook.airlift.log.Logger; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.google.inject.Injector; +import com.google.inject.Key; +import com.google.inject.Module; +import org.apache.arrow.flight.FlightServer; +import org.apache.arrow.flight.Location; +import org.apache.arrow.flight.grpc.ContextPropagatingExecutorService; +import org.apache.arrow.memory.BufferAllocator; + +import java.io.File; +import java.util.Map; +import java.util.concurrent.ExecutorService; + +import static java.lang.String.format; + +public class FlightShimServer +{ + private FlightShimServer() + { + } + + public static Injector initialize(Map config, Module... extraModules) + { + Bootstrap app = new Bootstrap(ImmutableList.builder() + .add(new FlightShimModule()) + .add(new JsonModule()) + .add(extraModules) + .build()); + + if (config != null && !config.isEmpty()) { + // Required config was provided instead of vm option -Dconfig= + app.setRequiredConfigurationProperties(config); + } + + return app.initialize(); + } + + public static FlightServer start(Injector injector, FlightServer.Builder builder) + throws Exception + { + FlightShimPluginManager pluginManager = injector.getInstance(FlightShimPluginManager.class); + pluginManager.loadPlugins(); + pluginManager.loadCatalogs(); + + builder.allocator(injector.getInstance(BufferAllocator.class)); + FlightShimConfig config = injector.getInstance(FlightShimConfig.class); + + if (config.getServerName() == null || config.getServerPort() == null) { + throw new IllegalArgumentException("Required configuration 'flight-shim.server' and 'flight-shim.server.port' not set"); + } + + if (config.getServerSslEnabled()) { + builder.location(Location.forGrpcTls(config.getServerName(), config.getServerPort())); + if (config.getServerSSLCertificateFile() == null || config.getServerSSLKeyFile() == null) { + throw new IllegalArgumentException("'flight-shim.server-ssl-enabled' is enabled but 'flight-shim.server-ssl-certificate-file' or 'flight-shim.server-ssl-key-file' not set"); + } + File certChainFile = new File(config.getServerSSLCertificateFile()); + File privateKeyFile = new File(config.getServerSSLKeyFile()); + builder.useTls(certChainFile, privateKeyFile); + + // Check if client cert is provided for mTLS + if (config.getClientSSLCertificateFile() != null) { + File clientCertFile = new File(config.getClientSSLCertificateFile()); + builder.useMTlsClientVerification(clientCertFile); + } + } + else { + builder.location(Location.forGrpcInsecure(config.getServerName(), config.getServerPort())); + } + + ExecutorService executor = injector.getInstance(Key.get(ExecutorService.class, ForFlightShimServer.class)); + builder.executor(new ContextPropagatingExecutorService(executor)); + + FlightShimProducer producer = injector.getInstance(FlightShimProducer.class); + builder.producer(producer); + + FlightServer server = builder.build(); + server.start(); + + return server; + } + + public static void main(String[] args) + { + Logger log = Logger.get(FlightShimModule.class); + + final Map config; + if (System.getProperty("config") == null) { + log.info("FlightShim server using default config, override with -Dconfig="); + ImmutableMap.Builder configBuilder = ImmutableMap.builder(); + configBuilder.put("flight-shim.server", "localhost"); + configBuilder.put("flight-shim.server.port", String.valueOf(9443)); + configBuilder.put("flight-shim.server-ssl-certificate-file", "src/test/resources/certs/server.crt"); + configBuilder.put("flight-shim.server-ssl-key-file", "src/test/resources/certs/server.key"); + config = configBuilder.build(); + } + else { + log.info("FlightShim server using config from: " + System.getProperty("config")); + config = ImmutableMap.of(); + } + Injector injector = initialize(config); + + log.info("FlightShim server initializing"); + try (FlightServer server = start(injector, FlightServer.builder()); + FlightShimProducer producer = injector.getInstance(FlightShimProducer.class)) { + log.info(format("======== FlightShim Server started on port: %s ========", server.getPort())); + server.awaitTermination(); + } + catch (Throwable t) { + log.error(t); + System.exit(1); + } + } +} diff --git a/presto-flight-shim/src/main/java/com/facebook/presto/flightshim/FlightShimServerExecutionMBean.java b/presto-flight-shim/src/main/java/com/facebook/presto/flightshim/FlightShimServerExecutionMBean.java new file mode 100644 index 0000000000000..4cc2cfafb10bd --- /dev/null +++ b/presto-flight-shim/src/main/java/com/facebook/presto/flightshim/FlightShimServerExecutionMBean.java @@ -0,0 +1,41 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.flightshim; + +import com.facebook.airlift.concurrent.ThreadPoolExecutorMBean; +import org.weakref.jmx.Managed; +import org.weakref.jmx.Nested; + +import javax.inject.Inject; + +import java.util.concurrent.ExecutorService; +import java.util.concurrent.ThreadPoolExecutor; + +public class FlightShimServerExecutionMBean +{ + private final ThreadPoolExecutorMBean executorMBean; + + @Inject + public FlightShimServerExecutionMBean(@ForFlightShimServer ExecutorService executor) + { + this.executorMBean = new ThreadPoolExecutorMBean((ThreadPoolExecutor) executor); + } + + @Managed + @Nested + public ThreadPoolExecutorMBean getExecutor() + { + return executorMBean; + } +} diff --git a/presto-flight-shim/src/main/java/com/facebook/presto/flightshim/ForFlightShimServer.java b/presto-flight-shim/src/main/java/com/facebook/presto/flightshim/ForFlightShimServer.java new file mode 100644 index 0000000000000..f100bc2359e4a --- /dev/null +++ b/presto-flight-shim/src/main/java/com/facebook/presto/flightshim/ForFlightShimServer.java @@ -0,0 +1,31 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.flightshim; + +import jakarta.inject.Qualifier; + +import java.lang.annotation.Retention; +import java.lang.annotation.Target; + +import static java.lang.annotation.ElementType.FIELD; +import static java.lang.annotation.ElementType.METHOD; +import static java.lang.annotation.ElementType.PARAMETER; +import static java.lang.annotation.RetentionPolicy.RUNTIME; + +@Retention(RUNTIME) +@Target({FIELD, PARAMETER, METHOD}) +@Qualifier +public @interface ForFlightShimServer +{ +} diff --git a/presto-flight-shim/src/test/java/com/facebook/presto/flightshim/AbstractTestFlightShimJdbcPlugins.java b/presto-flight-shim/src/test/java/com/facebook/presto/flightshim/AbstractTestFlightShimJdbcPlugins.java new file mode 100644 index 0000000000000..621ca3fac0798 --- /dev/null +++ b/presto-flight-shim/src/test/java/com/facebook/presto/flightshim/AbstractTestFlightShimJdbcPlugins.java @@ -0,0 +1,346 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.flightshim; + +import com.facebook.airlift.json.JsonCodec; +import com.facebook.presto.common.type.ArrayType; +import com.facebook.presto.common.type.CharType; +import com.facebook.presto.common.type.DecimalType; +import com.facebook.presto.common.type.Type; +import com.facebook.presto.common.type.VarcharType; +import com.facebook.presto.plugin.jdbc.JdbcColumnHandle; +import com.facebook.presto.plugin.jdbc.JdbcTableHandle; +import com.facebook.presto.plugin.jdbc.JdbcTransactionHandle; +import com.facebook.presto.plugin.jdbc.JdbcTypeHandle; +import com.facebook.presto.spi.SchemaTableName; +import com.facebook.presto.tpch.TpchColumnHandle; +import com.google.common.base.Preconditions; +import com.google.common.collect.ImmutableList; +import org.apache.arrow.flight.FlightClient; +import org.apache.arrow.flight.FlightStream; +import org.apache.arrow.flight.Ticket; +import org.apache.arrow.memory.BufferAllocator; +import org.testng.annotations.Test; + +import java.io.IOException; +import java.nio.charset.StandardCharsets; +import java.nio.file.Files; +import java.sql.Types; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Optional; + +import static com.facebook.airlift.json.JsonCodec.jsonCodec; +import static com.facebook.airlift.testing.Assertions.assertGreaterThan; +import static com.facebook.presto.common.type.BigintType.BIGINT; +import static com.facebook.presto.common.type.BooleanType.BOOLEAN; +import static com.facebook.presto.common.type.Chars.isCharType; +import static com.facebook.presto.common.type.DateType.DATE; +import static com.facebook.presto.common.type.DoubleType.DOUBLE; +import static com.facebook.presto.common.type.IntegerType.INTEGER; +import static com.facebook.presto.common.type.RealType.REAL; +import static com.facebook.presto.common.type.SmallintType.SMALLINT; +import static com.facebook.presto.common.type.TimeType.TIME; +import static com.facebook.presto.common.type.TimeWithTimeZoneType.TIME_WITH_TIME_ZONE; +import static com.facebook.presto.common.type.TimestampType.TIMESTAMP; +import static com.facebook.presto.common.type.TimestampWithTimeZoneType.TIMESTAMP_WITH_TIME_ZONE; +import static com.facebook.presto.common.type.TinyintType.TINYINT; +import static com.facebook.presto.common.type.VarbinaryType.VARBINARY; +import static com.facebook.presto.common.type.Varchars.isVarcharType; +import static com.facebook.presto.util.ResourceFileUtils.getResourceFile; +import static java.lang.String.format; + +@Test(singleThreaded = true) +public abstract class AbstractTestFlightShimJdbcPlugins + extends AbstractTestFlightShimPlugins +{ + public static final JsonCodec COLUMN_HANDLE_JSON_CODEC = jsonCodec(JdbcColumnHandle.class); + public static final JsonCodec TABLE_HANDLE_JSON_CODEC = jsonCodec(JdbcTableHandle.class); + public static final JsonCodec TRANSACTION_HANDLE_JSON_CODEC = jsonCodec(JdbcTransactionHandle.class); + + protected abstract String getConnectionUrl(); + + @Override + protected Map getConnectorProperties() + { + Map connectorProperties = new HashMap<>(); + connectorProperties.putIfAbsent("connection-url", getConnectionUrl()); + return connectorProperties; + } + + @Test + void testJdbcSplitWithTupleDomain() throws Exception + { + try (BufferAllocator bufferAllocator = allocator.newChildAllocator("connector-test-client", 0, Long.MAX_VALUE); + FlightClient client = createFlightClient(bufferAllocator, server.getPort())) { + Ticket ticket = new Ticket(REQUEST_JSON_CODEC.toJsonBytes(createTpchTableRequestWithTupleDomain())); + + int rowCount = 0; + try (FlightStream stream = client.getStream(ticket, CALL_OPTIONS)) { + while (stream.next()) { + rowCount += stream.getRoot().getRowCount(); + } + } + + assertGreaterThan(rowCount, 0); + } + } + + @Test + void testJdbcSplitWithAdditionalPredicate() throws Exception + { + try (BufferAllocator bufferAllocator = allocator.newChildAllocator("connector-test-client", 0, Long.MAX_VALUE); + FlightClient client = createFlightClient(bufferAllocator, server.getPort())) { + Ticket ticket = new Ticket(REQUEST_JSON_CODEC.toJsonBytes(createTpchTableRequestWithAdditionalPredicate())); + + int rowCount = 0; + try (FlightStream stream = client.getStream(ticket, CALL_OPTIONS)) { + while (stream.next()) { + rowCount += stream.getRoot().getRowCount(); + } + } + + assertGreaterThan(rowCount, 0); + } + } + + @Test + public void testSelectColumns() + { + assertSelectQueryFromColumns(ImmutableList.of(ORDERKEY_COLUMN, LINENUMBER_COLUMN)); + } + + @Test + public void testDateColumns() + { + assertSelectQueryFromColumns(ImmutableList.of(ORDERKEY_COLUMN, SHIPDATE_COLUMN)); + } + + @Test + public void testFloatingPointColumns() + { + assertSelectQueryFromColumns(ImmutableList.of(ORDERKEY_COLUMN, QUANTITY_COLUMN, EXTENDEDPRICE_COLUMN)); + } + + private JdbcColumnHandle convertToJdbcColumnHandle(TpchColumnHandle columnHandle) + { + Type type = columnHandle.getType(); + return new JdbcColumnHandle( + getConnectorId(), + columnHandle.getColumnName(), + new JdbcTypeHandle(jdbcDataType(type), type.getDisplayName(), columnSize(type), 0), + type, + false, + Optional.empty()); + } + + @Override + protected FlightShimRequest createTpchTableRequest(int partNumber, int totalParts, List columnHandles) + { + Preconditions.checkArgument(totalParts == 1, "JDBC request must be of a single partition"); + String split = createJdbcSplit(getConnectorId(), "tpch", TPCH_TABLE); + byte[] splitBytes = split.getBytes(StandardCharsets.UTF_8); + + ImmutableList.Builder columnBuilder = ImmutableList.builder(); + for (TpchColumnHandle columnHandle : columnHandles) { + columnBuilder.add(COLUMN_HANDLE_JSON_CODEC.toJsonBytes(convertToJdbcColumnHandle(columnHandle))); + } + + JdbcTableHandle tableHandle = new JdbcTableHandle(getConnectorId(), new SchemaTableName("tpch", TPCH_TABLE), getConnectorId(), "tpch", TPCH_TABLE); + byte[] tableHandleBytes = TABLE_HANDLE_JSON_CODEC.toJsonBytes(tableHandle); + byte[] transactionHandleBytes = TRANSACTION_HANDLE_JSON_CODEC.toJsonBytes(new JdbcTransactionHandle()); + + return new FlightShimRequest(getConnectorId(), splitBytes, columnBuilder.build(), tableHandleBytes, Optional.empty(), transactionHandleBytes); + } + + protected FlightShimRequest createTpchTableRequestWithTupleDomain() throws Exception + { + JdbcColumnHandle orderKeyHandle = convertToJdbcColumnHandle(getOrderKeyColumn()); + byte[] splitBytes = Files.readAllBytes(getResourceFile("split_tuple_domain.json").toPath()); + + List columnHandles = ImmutableList.of(orderKeyHandle); + ImmutableList.Builder columnBuilder = ImmutableList.builder(); + for (JdbcColumnHandle columnHandle : columnHandles) { + columnBuilder.add(COLUMN_HANDLE_JSON_CODEC.toJsonBytes(columnHandle)); + } + + JdbcTableHandle tableHandle = new JdbcTableHandle(getConnectorId(), new SchemaTableName("tpch", TPCH_TABLE), getConnectorId(), "tpch", TPCH_TABLE); + byte[] tableHandleBytes = TABLE_HANDLE_JSON_CODEC.toJsonBytes(tableHandle); + byte[] transactionHandleBytes = TRANSACTION_HANDLE_JSON_CODEC.toJsonBytes(new JdbcTransactionHandle()); + + return new FlightShimRequest( + getConnectorId(), + splitBytes, + columnBuilder.build(), + tableHandleBytes, + Optional.empty(), + transactionHandleBytes); + } + + protected FlightShimRequest createTpchTableRequestWithAdditionalPredicate() + throws IOException + { + // Query is: "SELECT orderkey FROM orders WHERE orderkey IN (1, 2, 3)" + JdbcColumnHandle orderKeyHandle = convertToJdbcColumnHandle(getOrderKeyColumn()); + byte[] splitBytes = Files.readAllBytes(getResourceFile("split_additional_predicate.json").toPath()); + + List columnHandles = ImmutableList.of(orderKeyHandle); + ImmutableList.Builder columnBuilder = ImmutableList.builder(); + for (JdbcColumnHandle columnHandle : columnHandles) { + columnBuilder.add(COLUMN_HANDLE_JSON_CODEC.toJsonBytes(columnHandle)); + } + + JdbcTableHandle tableHandle = new JdbcTableHandle(getConnectorId(), new SchemaTableName("tpch", TPCH_TABLE), getConnectorId(), "tpch", TPCH_TABLE); + byte[] tableHandleBytes = TABLE_HANDLE_JSON_CODEC.toJsonBytes(tableHandle); + byte[] transactionHandleBytes = TRANSACTION_HANDLE_JSON_CODEC.toJsonBytes(new JdbcTransactionHandle()); + + return new FlightShimRequest( + getConnectorId(), + splitBytes, + columnBuilder.build(), + tableHandleBytes, + Optional.empty(), + transactionHandleBytes); + } + + protected static String removeDatabaseFromJdbcUrl(String jdbcUrl) + { + return jdbcUrl.replaceFirst("/[^/?]+([?]|$)", "/$1"); + } + + protected static String addDatabaseCredentialsToJdbcUrl(String jdbcUrl, String username, String password) + { + return jdbcUrl + (jdbcUrl.contains("?") ? "&" : "?") + + "user=" + username + "&password=" + password; + } + + protected static String createJdbcSplit(String connectorId, String schemaName, String tableName) + { + return format("{\n" + + " \"connectorId\" : \"%s\",\n" + + " \"schemaName\" : \"%s\",\n" + + " \"tableName\" : \"%s\",\n" + + " \"tupleDomain\" : {\n" + + " \"columnDomains\" : [ ]\n" + + " }\n" + + "}", connectorId, schemaName, tableName); + } + + static int jdbcDataType(Type type) + { + if (type.equals(BOOLEAN)) { + return Types.BOOLEAN; + } + if (type.equals(BIGINT)) { + return Types.BIGINT; + } + if (type.equals(INTEGER)) { + return Types.INTEGER; + } + if (type.equals(SMALLINT)) { + return Types.SMALLINT; + } + if (type.equals(TINYINT)) { + return Types.TINYINT; + } + if (type.equals(REAL)) { + return Types.REAL; + } + if (type.equals(DOUBLE)) { + return Types.DOUBLE; + } + if (type instanceof DecimalType) { + return Types.DECIMAL; + } + if (isVarcharType(type)) { + return Types.VARCHAR; + } + if (isCharType(type)) { + return Types.CHAR; + } + if (type.equals(VARBINARY)) { + return Types.VARBINARY; + } + if (type.equals(TIME)) { + return Types.TIME; + } + if (type.equals(TIME_WITH_TIME_ZONE)) { + return Types.TIME_WITH_TIMEZONE; + } + if (type.equals(TIMESTAMP)) { + return Types.TIMESTAMP; + } + if (type.equals(TIMESTAMP_WITH_TIME_ZONE)) { + return Types.TIMESTAMP_WITH_TIMEZONE; + } + if (type.equals(DATE)) { + return Types.DATE; + } + if (type instanceof ArrayType) { + return Types.ARRAY; + } + return Types.JAVA_OBJECT; + } + + static int columnSize(Type type) + { + if (type.equals(BIGINT)) { + return 19; // 2**63-1 + } + if (type.equals(INTEGER)) { + return 10; // 2**31-1 + } + if (type.equals(SMALLINT)) { + return 5; // 2**15-1 + } + if (type.equals(TINYINT)) { + return 3; // 2**7-1 + } + if (type instanceof DecimalType) { + return ((DecimalType) type).getPrecision(); + } + if (type.equals(REAL)) { + return 24; // IEEE 754 + } + if (type.equals(DOUBLE)) { + return 53; // IEEE 754 + } + if (isVarcharType(type)) { + return ((VarcharType) type).getLength(); + } + if (isCharType(type)) { + return ((CharType) type).getLength(); + } + if (type.equals(VARBINARY)) { + return Integer.MAX_VALUE; + } + if (type.equals(TIME)) { + return 8; // 00:00:00 + } + if (type.equals(TIME_WITH_TIME_ZONE)) { + return 8 + 6; // 00:00:00+00:00 + } + if (type.equals(DATE)) { + return 14; // +5881580-07-11 (2**31-1 days) + } + if (type.equals(TIMESTAMP)) { + return 15 + 8; + } + if (type.equals(TIMESTAMP_WITH_TIME_ZONE)) { + return 15 + 8 + 6; + } + return 0; + } +} diff --git a/presto-flight-shim/src/test/java/com/facebook/presto/flightshim/AbstractTestFlightShimPlugins.java b/presto-flight-shim/src/test/java/com/facebook/presto/flightshim/AbstractTestFlightShimPlugins.java new file mode 100644 index 0000000000000..2341ef131cb0c --- /dev/null +++ b/presto-flight-shim/src/test/java/com/facebook/presto/flightshim/AbstractTestFlightShimPlugins.java @@ -0,0 +1,544 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.flightshim; + +import com.facebook.airlift.json.JsonCodec; +import com.facebook.plugin.arrow.ArrowBlockBuilder; +import com.facebook.presto.Session; +import com.facebook.presto.common.Page; +import com.facebook.presto.common.QualifiedObjectName; +import com.facebook.presto.common.block.Block; +import com.facebook.presto.common.type.BigintType; +import com.facebook.presto.common.type.DateType; +import com.facebook.presto.common.type.DoubleType; +import com.facebook.presto.common.type.IntegerType; +import com.facebook.presto.common.type.Type; +import com.facebook.presto.common.type.TypeManager; +import com.facebook.presto.common.type.VarcharType; +import com.facebook.presto.cost.StatsCalculator; +import com.facebook.presto.metadata.Metadata; +import com.facebook.presto.spi.Plugin; +import com.facebook.presto.spi.eventlistener.EventListener; +import com.facebook.presto.split.PageSourceManager; +import com.facebook.presto.split.SplitManager; +import com.facebook.presto.sql.expressions.ExpressionOptimizerManager; +import com.facebook.presto.sql.planner.ConnectorPlanOptimizerManager; +import com.facebook.presto.sql.planner.NodePartitioningManager; +import com.facebook.presto.sql.planner.sanity.PlanCheckerProviderManager; +import com.facebook.presto.testing.MaterializedResult; +import com.facebook.presto.testing.QueryRunner; +import com.facebook.presto.testing.TestingAccessControlManager; +import com.facebook.presto.tests.AbstractTestQueryFramework; +import com.facebook.presto.tpch.TpchColumnHandle; +import com.facebook.presto.tpch.TpchTableHandle; +import com.facebook.presto.tpch.TpchTableLayoutHandle; +import com.facebook.presto.tpch.TpchTransactionHandle; +import com.facebook.presto.transaction.TransactionManager; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.google.inject.Injector; +import org.apache.arrow.flight.CallOption; +import org.apache.arrow.flight.CallOptions; +import org.apache.arrow.flight.FlightClient; +import org.apache.arrow.flight.FlightServer; +import org.apache.arrow.flight.FlightStream; +import org.apache.arrow.flight.Location; +import org.apache.arrow.flight.Ticket; +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.vector.FieldVector; +import org.intellij.lang.annotations.Language; +import org.testng.annotations.AfterClass; +import org.testng.annotations.BeforeClass; +import org.testng.annotations.Test; + +import java.io.ByteArrayInputStream; +import java.io.IOException; +import java.io.InputStream; +import java.net.ServerSocket; +import java.nio.charset.StandardCharsets; +import java.nio.file.Files; +import java.nio.file.Paths; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.Locale; +import java.util.Map; +import java.util.Optional; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.locks.Lock; +import java.util.stream.Collectors; + +import static com.facebook.airlift.json.JsonCodec.jsonCodec; +import static com.facebook.presto.tpch.TpchMetadata.TINY_SCALE_FACTOR; +import static java.lang.String.format; + +@Test(singleThreaded = true) +public abstract class AbstractTestFlightShimPlugins + extends AbstractTestQueryFramework +{ + public static final JsonCodec REQUEST_JSON_CODEC = jsonCodec(FlightShimRequest.class); + public static final JsonCodec TPCH_COLUMN_JSON_CODEC = jsonCodec(TpchColumnHandle.class); + public static final JsonCodec TPCH_TABLE_HANDLE_JSON_CODEC = jsonCodec(TpchTableHandle.class); + public static final JsonCodec TPCH_TABLE_LAYOUT_HANDLE_JSON_CODEC = jsonCodec(TpchTableLayoutHandle.class); + public static final JsonCodec TPCH_TRANSACTION_HANDLE_JSON_CODEC = jsonCodec(TpchTransactionHandle.class); + public static final String TPCH_TABLE = "lineitem"; + public static final String ORDERKEY_COLUMN = "orderkey"; + public static final String LINENUMBER_COLUMN = "linenumber"; + public static final String LINESTATUS_COLUMN = "linestatus"; + public static final String EXTENDEDPRICE_COLUMN = "extendedprice"; + public static final String QUANTITY_COLUMN = "quantity"; + public static final String SHIPDATE_COLUMN = "shipdate"; + public static final String SHIPINSTRUCT_COLUMN = "shipinstruct"; + protected final List closables = new ArrayList<>(); + protected static final CallOption CALL_OPTIONS = CallOptions.timeout(300, TimeUnit.SECONDS); + protected BufferAllocator allocator; + protected FlightServer server; + private ArrowBlockBuilder blockBuilder; + + protected abstract String getConnectorId(); + + protected abstract String getPluginBundles(); + + @BeforeClass + public void setup() + throws Exception + { + ImmutableMap.Builder configBuilder = ImmutableMap.builder(); + configBuilder.put("flight-shim.server", "localhost"); + configBuilder.put("flight-shim.server.port", String.valueOf(findUnusedPort())); + configBuilder.put("flight-shim.server-ssl-certificate-file", "src/test/resources/certs/server.crt"); + configBuilder.put("flight-shim.server-ssl-key-file", "src/test/resources/certs/server.key"); + configBuilder.put("plugin.bundles", getPluginBundles()); + + // Allow for 3 batches using testing tpch db + configBuilder.put("flight-shim.max-rows-per-batch", String.valueOf(500)); + + Injector injector = FlightShimServer.initialize(configBuilder.build()); + + server = FlightShimServer.start(injector, FlightServer.builder()); + allocator = injector.getInstance(BufferAllocator.class); + + // Set test properties after catalogs have been loaded + FlightShimPluginManager pluginManager = injector.getInstance(FlightShimPluginManager.class); + Map connectorProperties = getConnectorProperties(); + pluginManager.setCatalogProperties(getConnectorId(), getConnectorId(), connectorProperties); + + TypeManager typeManager = injector.getInstance(TypeManager.class); + blockBuilder = new ArrowBlockBuilder(typeManager); + + // Make sure these resources close properly + closables.add(server); + closables.add(injector.getInstance(FlightShimProducer.class)); + closables.add(allocator); + } + + @Override + @AfterClass(alwaysRun = true) + public void close() + throws Exception + { + super.close(); + for (AutoCloseable closeable : closables) { + closeable.close(); + } + closables.clear(); + } + + protected Map getConnectorProperties() + { + return ImmutableMap.of(); + } + + protected int getExpectedTotalParts() + { + return 1; + } + + @Override + protected QueryRunner createQueryRunner() + { + return new FlightShimQueryRunner(getExpectedTotalParts()); + } + + /* + @Test + void testWithMtls() throws Exception + { + InputStream trustedCertificate = new ByteArrayInputStream(Files.readAllBytes(Paths.get("src/test/resources/certs/server.crt"))); + InputStream clientCertificate = new ByteArrayInputStream(Files.readAllBytes(Paths.get("src/test/resources/certs/client.crt"))); + InputStream clientKey = new ByteArrayInputStream(Files.readAllBytes(Paths.get("src/test/resources/certs/client.key"))); + + Location location = Location.forGrpcTls("localhost", server.getPort()); + + try (BufferAllocator bufferAllocator = allocator.newChildAllocator("connector-test-client", 0, Long.MAX_VALUE); + FlightClient client = FlightClient.builder(bufferAllocator, location).useTls().clientCertificate(clientCertificate, clientKey).trustedCertificates(trustedCertificate).build()) { + Ticket ticket = new Ticket(REQUEST_JSON_CODEC.toJsonBytes(createTpchTableRequest(ImmutableList.of(getOrderKeyColumn())))); + + int rowCount = 0; + try (FlightStream stream = client.getStream(ticket, CALL_OPTIONS)) { + while (stream.next()) { + rowCount += stream.getRoot().getRowCount(); + } + } + + assertGreaterThan(rowCount, 0); + } + } + + @Test + public void testSelectColumns() + { + assertSelectQueryFromColumns(ImmutableList.of(ORDERKEY_COLUMN, LINENUMBER_COLUMN)); + } + + @Test + public void testDateColumns() + { + assertSelectQueryFromColumns(ImmutableList.of(ORDERKEY_COLUMN, SHIPDATE_COLUMN)); + } + + @Test + public void testFloatingPointColumns() + { + assertSelectQueryFromColumns(ImmutableList.of(ORDERKEY_COLUMN, QUANTITY_COLUMN, EXTENDEDPRICE_COLUMN)); + } + + protected void assertSelectQueryFromColumns(List tpchColumnNames) + { + @Language("SQL") String query = format("SELECT %s FROM %s", String.join(",", tpchColumnNames), TPCH_TABLE); + assertQuery(query); + }*/ + + protected List getHandlesFromSelectQuery(String sql) + { + String sqlLower = sql.toLowerCase(Locale.ENGLISH); + int start = sqlLower.indexOf("select"); + if (start < 0) { + throw new RuntimeException("Expected 'SELECT' in query: " + sql); + } + start += "select".length(); + int stop = sqlLower.indexOf("from"); + if (stop < start) { + throw new RuntimeException("Expected 'FROM' in query: " + sql); + } + String columnsString = sql.substring(start, stop); + List columns = Arrays.stream(columnsString.split(",")).map(String::trim).collect(Collectors.toList()); + + if (columns.isEmpty()) { + throw new RuntimeException("No columns found in query: " + sql); + } + + ImmutableList.Builder columnHandlesBuilder = ImmutableList.builder(); + for (String column : columns) { + switch (column) { + case ORDERKEY_COLUMN: + columnHandlesBuilder.add(getOrderKeyColumn()); + break; + case LINENUMBER_COLUMN: + columnHandlesBuilder.add(getLineNumberColumn()); + break; + case QUANTITY_COLUMN: + columnHandlesBuilder.add(getQuantityColumn()); + break; + case EXTENDEDPRICE_COLUMN: + columnHandlesBuilder.add(getExtendedPriceColumn()); + break; + case LINESTATUS_COLUMN: + columnHandlesBuilder.add(getLineStatusColumn()); + break; + case SHIPDATE_COLUMN: + columnHandlesBuilder.add(getShipDateColumn()); + break; + case SHIPINSTRUCT_COLUMN: + columnHandlesBuilder.add(getShipInstructColumn()); + break; + default: + throw new RuntimeException("Unknown column handle for: " + column); + } + } + + return columnHandlesBuilder.build(); + } + + protected void assertSelectQueryFromColumns(List tpchColumnNames) + { + @Language("SQL") String query = format("SELECT %s FROM %s", String.join(",", tpchColumnNames), TPCH_TABLE); + assertQuery(query); + } + + protected class FlightShimQueryRunner + implements QueryRunner + { + private int totalParts; + + public FlightShimQueryRunner(int totalParts) + { + this.totalParts = totalParts; + } + + public FlightShimQueryRunner() + { + this(1); + } + + @Override + public Session getDefaultSession() + { + return ((QueryRunner) getExpectedQueryRunner()).getDefaultSession(); + } + + protected int getTotalParts() + { + return totalParts; + } + + protected FlightShimRequest getRequestForPart(int partNumber, List columnHandles) + { + return createTpchTableRequest(partNumber, getTotalParts(), columnHandles); + } + + @Override + public MaterializedResult execute(String sql) + { + try (BufferAllocator bufferAllocator = allocator.newChildAllocator("connector-test-client", 0, Long.MAX_VALUE); + FlightClient client = createFlightClient(bufferAllocator, server.getPort())) { + List columnHandles = getHandlesFromSelectQuery(sql); + List pages = new ArrayList<>(); + + for (int i = 0; i < getTotalParts(); i++) { + Ticket ticket = new Ticket(REQUEST_JSON_CODEC.toJsonBytes(getRequestForPart(i, columnHandles))); + try (FlightStream stream = client.getStream(ticket, CALL_OPTIONS)) { + while (stream.next()) { + List blocks = new ArrayList<>(); + for (TpchColumnHandle columnHandle : columnHandles) { + FieldVector vector = stream.getRoot().getVector(columnHandle.getColumnName()); + Block block = blockBuilder.buildBlockFromFieldVector(vector, columnHandle.getType(), null); + blocks.add(block); + } + pages.add(new Page(stream.getRoot().getRowCount(), blocks.toArray(new Block[0]))); + } + } + } + + List types = columnHandles.stream().map(TpchColumnHandle::getType).collect(Collectors.toList()); + MaterializedResult.Builder resultBuilder = MaterializedResult.resultBuilder(getSession(), types); + resultBuilder.pages(pages); + return resultBuilder.build(); + } + catch (Exception e) { + throw new RuntimeException(e); + } + } + + @Override + public int getNodeCount() + { + return 0; + } + + @Override + public TransactionManager getTransactionManager() + { + throw new UnsupportedOperationException(); + } + + @Override + public Metadata getMetadata() + { + throw new UnsupportedOperationException(); + } + + @Override + public SplitManager getSplitManager() + { + throw new UnsupportedOperationException(); + } + + @Override + public PageSourceManager getPageSourceManager() + { + throw new UnsupportedOperationException(); + } + + @Override + public NodePartitioningManager getNodePartitioningManager() + { + throw new UnsupportedOperationException(); + } + + @Override + public ConnectorPlanOptimizerManager getPlanOptimizerManager() + { + throw new UnsupportedOperationException(); + } + + @Override + public PlanCheckerProviderManager getPlanCheckerProviderManager() + { + throw new UnsupportedOperationException(); + } + + @Override + public StatsCalculator getStatsCalculator() + { + throw new UnsupportedOperationException(); + } + + @Override + public List getEventListeners() + { + throw new UnsupportedOperationException(); + } + + @Override + public TestingAccessControlManager getAccessControl() + { + throw new UnsupportedOperationException(); + } + + @Override + public ExpressionOptimizerManager getExpressionManager() + { + throw new UnsupportedOperationException(); + } + + @Override + public MaterializedResult execute(Session session, String sql) + { + return execute(sql); + } + + @Override + public List listTables(Session session, String catalog, String schema) + { + throw new UnsupportedOperationException(); + } + + @Override + public boolean tableExists(Session session, String table) + { + return false; + } + + @Override + public void installPlugin(Plugin plugin) + { + } + + @Override + public void createCatalog(String catalogName, String connectorName, Map properties) + { + } + + @Override + public void loadFunctionNamespaceManager(String functionNamespaceManagerName, String catalogName, Map properties) + { + } + + @Override + public Lock getExclusiveLock() + { + return null; + } + + @Override + public void close() + { + } + } + + private int findUnusedPort() + throws IOException + { + try (ServerSocket socket = new ServerSocket(0)) { + return socket.getLocalPort(); + } + } + + protected TpchColumnHandle getOrderKeyColumn() + { + return new TpchColumnHandle(ORDERKEY_COLUMN, BigintType.BIGINT); + } + + protected TpchColumnHandle getLineNumberColumn() + { + return new TpchColumnHandle(LINENUMBER_COLUMN, IntegerType.INTEGER); + } + + protected TpchColumnHandle getLineStatusColumn() + { + return new TpchColumnHandle(LINESTATUS_COLUMN, VarcharType.createVarcharType(32)); + } + + protected TpchColumnHandle getQuantityColumn() + { + return new TpchColumnHandle(QUANTITY_COLUMN, DoubleType.DOUBLE); + } + + protected TpchColumnHandle getExtendedPriceColumn() + { + return new TpchColumnHandle(EXTENDEDPRICE_COLUMN, DoubleType.DOUBLE); + } + + protected TpchColumnHandle getShipDateColumn() + { + return new TpchColumnHandle(SHIPDATE_COLUMN, DateType.DATE); + } + + protected TpchColumnHandle getShipInstructColumn() + { + return new TpchColumnHandle(SHIPINSTRUCT_COLUMN, VarcharType.createVarcharType(64)); + } + + protected FlightShimRequest createTpchTableRequest(int partNumber, int totalParts, List columnHandles) + { + String split = createTpchSplit(TPCH_TABLE, partNumber, totalParts); + byte[] splitBytes = split.getBytes(StandardCharsets.UTF_8); + + ImmutableList.Builder columnBuilder = ImmutableList.builder(); + for (TpchColumnHandle columnHandle : columnHandles) { + columnBuilder.add(TPCH_COLUMN_JSON_CODEC.toJsonBytes(columnHandle)); + } + + TpchTableHandle tableHandle = new TpchTableHandle(TPCH_TABLE, 1.0); + byte[] tableHandleBytes = TPCH_TABLE_HANDLE_JSON_CODEC.toJsonBytes(tableHandle); + + byte[] transactionHandleBytes = TPCH_TRANSACTION_HANDLE_JSON_CODEC.toJsonBytes(TpchTransactionHandle.INSTANCE); + + return new FlightShimRequest(getConnectorId(), splitBytes, columnBuilder.build(), tableHandleBytes, Optional.empty(), transactionHandleBytes); + } + + protected static String createTpchSplit(String tableName, int partNumber, int totalParts) + { + return format("{\n" + + " \"tableHandle\" : {\n" + + " \"tableName\" : \"%s\",\n" + + " \"scaleFactor\" : %.2f\n" + + " },\n" + + " \"partNumber\" : %d,\n" + + " \"totalParts\" : %d,\n" + + " \"addresses\" : [ \"127.0.0.1:9999\" ],\n" + + " \"predicate\" : {\n" + + " \"columnDomains\" : [ ]\n" + + " }\n" + + "}", tableName, TINY_SCALE_FACTOR, partNumber, totalParts); + } + + protected static FlightClient createFlightClient(BufferAllocator allocator, int serverPort) throws IOException + { + InputStream trustedCertificate = new ByteArrayInputStream(Files.readAllBytes(Paths.get("src/test/resources/certs/server.crt"))); + Location location = Location.forGrpcTls("localhost", serverPort); + return FlightClient.builder(allocator, location).useTls().trustedCertificates(trustedCertificate).build(); + } +} diff --git a/presto-flight-shim/src/test/java/com/facebook/presto/flightshim/TestFlightShimConfig.java b/presto-flight-shim/src/test/java/com/facebook/presto/flightshim/TestFlightShimConfig.java new file mode 100644 index 0000000000000..631c3bf7c5f48 --- /dev/null +++ b/presto-flight-shim/src/test/java/com/facebook/presto/flightshim/TestFlightShimConfig.java @@ -0,0 +1,70 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.flightshim; + +import com.google.common.collect.ImmutableMap; +import org.testng.annotations.Test; + +import java.util.Map; + +import static com.facebook.airlift.configuration.testing.ConfigAssertions.assertFullMapping; +import static com.facebook.airlift.configuration.testing.ConfigAssertions.assertRecordedDefaults; +import static com.facebook.airlift.configuration.testing.ConfigAssertions.recordDefaults; + +public class TestFlightShimConfig +{ + @Test + public void testDefaults() + { + assertRecordedDefaults(recordDefaults(FlightShimConfig.class) + .setServerName(null) + .setServerPort(null) + .setServerSslEnabled(true) + .setServerSSLCertificateFile(null) + .setServerSSLKeyFile(null) + .setReadThreadPoolSize(16) + .setClientSSLCertificateFile(null) + .setClientSSLKeyFile(null) + .setMaxRowsPerBatch(10000)); + } + + @Test + public void testExplicitPropertyMappings() + { + Map properties = new ImmutableMap.Builder() + .put("server", "localhost") + .put("server.port", "9432") + .put("server-ssl-enabled", "false") + .put("server-ssl-certificate-file", "/some/path/server.cert") + .put("server-ssl-key-file", "/some/path/server.key") + .put("client-ssl-certificate-file", "/some/path/client.cert") + .put("client-ssl-key-file", "/some/path/client.key") + .put("thread-pool-size", "8") + .put("max-rows-per-batch", "1000") + .build(); + + FlightShimConfig expected = new FlightShimConfig() + .setServerName("localhost") + .setServerPort(9432) + .setServerSslEnabled(false) + .setServerSSLCertificateFile("/some/path/server.cert") + .setServerSSLKeyFile("/some/path/server.key") + .setClientSSLCertificateFile("/some/path/client.cert") + .setClientSSLKeyFile("/some/path/client.key") + .setReadThreadPoolSize(8) + .setMaxRowsPerBatch(1000); + + assertFullMapping(properties, expected); + } +} diff --git a/presto-flight-shim/src/test/java/com/facebook/presto/flightshim/TestFlightShimMySqlPlugin.java b/presto-flight-shim/src/test/java/com/facebook/presto/flightshim/TestFlightShimMySqlPlugin.java new file mode 100644 index 0000000000000..c6d9d9eedaefd --- /dev/null +++ b/presto-flight-shim/src/test/java/com/facebook/presto/flightshim/TestFlightShimMySqlPlugin.java @@ -0,0 +1,69 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.flightshim; + +import com.facebook.presto.testing.ExpectedQueryRunner; +import com.google.common.collect.ImmutableMap; +import io.airlift.tpch.TpchTable; +import org.testcontainers.mysql.MySQLContainer; +import org.testng.annotations.Test; + +import static com.facebook.presto.plugin.mysql.MySqlQueryRunner.createMySqlQueryRunner; + +@Test +public class TestFlightShimMySqlPlugin + extends AbstractTestFlightShimJdbcPlugins +{ + private static final String DATABASE_USERNAME = "testuser"; + private static final String DATABASE_PASSWORD = "testpass"; + private final MySQLContainer mysqlContainer; + + public TestFlightShimMySqlPlugin() + { + this.mysqlContainer = new MySQLContainer("mysql:8.0") + .withDatabaseName("tpch") + .withUsername(DATABASE_USERNAME) + .withPassword(DATABASE_PASSWORD); + mysqlContainer.start(); + closables.add(mysqlContainer); + } + + @Override + protected String getConnectorId() + { + return "mysql"; + } + + @Override + protected String getConnectionUrl() + { + return addDatabaseCredentialsToJdbcUrl( + removeDatabaseFromJdbcUrl(mysqlContainer.getJdbcUrl()), + DATABASE_USERNAME, + DATABASE_PASSWORD); + } + + @Override + protected String getPluginBundles() + { + return "../presto-mysql/pom.xml"; + } + + @Override + protected ExpectedQueryRunner createExpectedQueryRunner() + throws Exception + { + return createMySqlQueryRunner(mysqlContainer.getJdbcUrl(), ImmutableMap.of(), TpchTable.getTables()); + } +} diff --git a/presto-flight-shim/src/test/java/com/facebook/presto/flightshim/TestFlightShimPostgresPlugin.java b/presto-flight-shim/src/test/java/com/facebook/presto/flightshim/TestFlightShimPostgresPlugin.java new file mode 100644 index 0000000000000..af22d32c5eeb5 --- /dev/null +++ b/presto-flight-shim/src/test/java/com/facebook/presto/flightshim/TestFlightShimPostgresPlugin.java @@ -0,0 +1,66 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.flightshim; + +import com.facebook.presto.testing.ExpectedQueryRunner; +import com.google.common.collect.ImmutableMap; +import io.airlift.tpch.TpchTable; +import org.testcontainers.postgresql.PostgreSQLContainer; +import org.testng.annotations.Test; + +import static com.facebook.presto.plugin.postgresql.PostgreSqlQueryRunner.createPostgreSqlQueryRunner; + +@Test +public class TestFlightShimPostgresPlugin + extends AbstractTestFlightShimJdbcPlugins +{ + private static final String DATABASE_USERNAME = "testuser"; + private static final String DATABASE_PASSWORD = "testpass"; + private final PostgreSQLContainer postgresContainer; + + public TestFlightShimPostgresPlugin() + { + this.postgresContainer = new PostgreSQLContainer("postgres:14") + .withDatabaseName("tpch") + .withUsername(DATABASE_USERNAME) + .withPassword(DATABASE_PASSWORD); + postgresContainer.start(); + closables.add(postgresContainer); + } + + @Override + protected String getConnectorId() + { + return "postgresql"; + } + + @Override + protected String getConnectionUrl() + { + return addDatabaseCredentialsToJdbcUrl(postgresContainer.getJdbcUrl(), DATABASE_USERNAME, DATABASE_PASSWORD); + } + + @Override + protected String getPluginBundles() + { + return "../presto-postgresql/pom.xml"; + } + + @Override + protected ExpectedQueryRunner createExpectedQueryRunner() + throws Exception + { + return createPostgreSqlQueryRunner(postgresContainer.getJdbcUrl(), ImmutableMap.of(), TpchTable.getTables()); + } +} diff --git a/presto-flight-shim/src/test/java/com/facebook/presto/flightshim/TestFlightShimRequest.java b/presto-flight-shim/src/test/java/com/facebook/presto/flightshim/TestFlightShimRequest.java new file mode 100644 index 0000000000000..636f7f2006584 --- /dev/null +++ b/presto-flight-shim/src/test/java/com/facebook/presto/flightshim/TestFlightShimRequest.java @@ -0,0 +1,94 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.flightshim; + +import com.facebook.presto.common.type.BigintType; +import com.facebook.presto.common.type.VarcharType; +import com.facebook.presto.plugin.jdbc.JdbcColumnHandle; +import com.facebook.presto.plugin.jdbc.JdbcTableHandle; +import com.facebook.presto.plugin.jdbc.JdbcTransactionHandle; +import com.facebook.presto.plugin.jdbc.JdbcTypeHandle; +import com.facebook.presto.spi.SchemaTableName; +import com.google.common.collect.ImmutableList; +import org.testng.annotations.Test; + +import java.nio.charset.StandardCharsets; +import java.sql.Types; +import java.util.Optional; + +import static com.facebook.presto.flightshim.AbstractTestFlightShimJdbcPlugins.COLUMN_HANDLE_JSON_CODEC; +import static com.facebook.presto.flightshim.AbstractTestFlightShimJdbcPlugins.TABLE_HANDLE_JSON_CODEC; +import static com.facebook.presto.flightshim.AbstractTestFlightShimJdbcPlugins.TRANSACTION_HANDLE_JSON_CODEC; +import static com.facebook.presto.flightshim.AbstractTestFlightShimJdbcPlugins.createJdbcSplit; +import static com.facebook.presto.flightshim.AbstractTestFlightShimPlugins.LINESTATUS_COLUMN; +import static com.facebook.presto.flightshim.AbstractTestFlightShimPlugins.ORDERKEY_COLUMN; +import static com.facebook.presto.flightshim.AbstractTestFlightShimPlugins.TPCH_TABLE; +import static org.testng.Assert.assertEquals; +import static org.testng.internal.junit.ArrayAsserts.assertArrayEquals; + +public class TestFlightShimRequest +{ + @Test + public void testJsonRoundTrip() + { + FlightShimRequest expected = createTpchCustomerRequest(); + String json = AbstractTestFlightShimPlugins.REQUEST_JSON_CODEC.toJson(expected); + FlightShimRequest copy = AbstractTestFlightShimPlugins.REQUEST_JSON_CODEC.fromJson(json); + assertEquals(copy.getConnectorId(), expected.getConnectorId()); + assertEquals(copy.getSplitBytes(), expected.getSplitBytes()); + assertArrayEquals(copy.getColumnHandlesBytes().toArray(), expected.getColumnHandlesBytes().toArray()); + assertEquals(copy.getTableHandleBytes(), expected.getTableHandleBytes()); + assertEquals(copy.getTableLayoutHandleBytes(), expected.getTableLayoutHandleBytes()); + assertEquals(copy.getTransactionHandleBytes(), expected.getTransactionHandleBytes()); + } + + FlightShimRequest createTpchCustomerRequest() + { + String split = createJdbcSplit("postgresql", "tpch", TPCH_TABLE); + byte[] splitBytes = split.getBytes(StandardCharsets.UTF_8); + + JdbcColumnHandle custkeyHandle = new JdbcColumnHandle( + "postgresql", + ORDERKEY_COLUMN, + new JdbcTypeHandle(Types.BIGINT, "bigint", 8, 0), + BigintType.BIGINT, + false, + Optional.empty()); + byte[] custkeyBytes = COLUMN_HANDLE_JSON_CODEC.toJsonBytes(custkeyHandle); + + JdbcColumnHandle nameHandle = new JdbcColumnHandle( + "postgresql", + LINESTATUS_COLUMN, + new JdbcTypeHandle(Types.VARCHAR, "varchar", 32, 0), + VarcharType.createVarcharType(32), + false, + Optional.empty()); + byte[] nameBytes = COLUMN_HANDLE_JSON_CODEC.toJsonBytes(nameHandle); + + JdbcTableHandle tableHandle = new JdbcTableHandle("postgresql", new SchemaTableName("tpch", TPCH_TABLE), "postgresql", "tpch", TPCH_TABLE); + byte[] tableHandleBytes = TABLE_HANDLE_JSON_CODEC.toJsonBytes(tableHandle); + + JdbcTransactionHandle transactionHandle = new JdbcTransactionHandle(); + byte[] transactionHandleBytes = TRANSACTION_HANDLE_JSON_CODEC.toJsonBytes(transactionHandle); + + return new FlightShimRequest( + "postgresql", + splitBytes, + ImmutableList.of(custkeyBytes, nameBytes), + tableHandleBytes, + Optional.empty(), + transactionHandleBytes + ); + } +} diff --git a/presto-flight-shim/src/test/java/com/facebook/presto/flightshim/TestFlightShimSingleStorePlugin.java b/presto-flight-shim/src/test/java/com/facebook/presto/flightshim/TestFlightShimSingleStorePlugin.java new file mode 100644 index 0000000000000..41751a7501c1f --- /dev/null +++ b/presto-flight-shim/src/test/java/com/facebook/presto/flightshim/TestFlightShimSingleStorePlugin.java @@ -0,0 +1,59 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.flightshim; + +import com.facebook.presto.plugin.singlestore.DockerizedSingleStoreServer; +import com.facebook.presto.plugin.singlestore.SingleStoreQueryRunner; +import com.facebook.presto.testing.ExpectedQueryRunner; +import com.google.common.collect.ImmutableMap; +import io.airlift.tpch.TpchTable; +import org.testng.annotations.Test; + +@Test +public class TestFlightShimSingleStorePlugin + extends AbstractTestFlightShimJdbcPlugins +{ + private final DockerizedSingleStoreServer singleStoreServer; + + public TestFlightShimSingleStorePlugin() + { + this.singleStoreServer = new DockerizedSingleStoreServer(); + closables.add(singleStoreServer); + } + + @Override + protected String getConnectorId() + { + return "singlestore"; + } + + @Override + protected String getConnectionUrl() + { + return singleStoreServer.getJdbcUrl(); + } + + @Override + protected String getPluginBundles() + { + return "../presto-singlestore/pom.xml"; + } + + @Override + protected ExpectedQueryRunner createExpectedQueryRunner() + throws Exception + { + return SingleStoreQueryRunner.createSingleStoreQueryRunner(singleStoreServer, ImmutableMap.of(), TpchTable.getTables()); + } +} diff --git a/presto-flight-shim/src/test/java/com/facebook/presto/flightshim/TestFlightShimTpchPlugin.java b/presto-flight-shim/src/test/java/com/facebook/presto/flightshim/TestFlightShimTpchPlugin.java new file mode 100644 index 0000000000000..b5d097c004076 --- /dev/null +++ b/presto-flight-shim/src/test/java/com/facebook/presto/flightshim/TestFlightShimTpchPlugin.java @@ -0,0 +1,154 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.presto.flightshim; + +import com.facebook.presto.Session; +import com.facebook.presto.testing.ExpectedQueryRunner; +import com.facebook.presto.testing.LocalQueryRunner; +import com.facebook.presto.tpch.TpchConnectorFactory; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import org.apache.arrow.flight.FlightClient; +import org.apache.arrow.flight.FlightStream; +import org.apache.arrow.flight.Ticket; +import org.apache.arrow.memory.BufferAllocator; +import org.testng.annotations.Test; + +import java.util.concurrent.CancellationException; + +import static com.facebook.airlift.testing.Assertions.assertGreaterThan; +import static com.facebook.presto.testing.TestingSession.testSessionBuilder; +import static org.testng.Assert.assertEquals; +import static org.testng.Assert.assertNotNull; + +@Test(singleThreaded = true) +public class TestFlightShimTpchPlugin + extends AbstractTestFlightShimPlugins +{ + public static final int SPLITS_PER_NODE = 4; + + protected String getConnectorId() + { + return "tpch"; + } + + protected String getPluginBundles() + { + return "../presto-tpch/pom.xml"; + } + + @Override + protected ExpectedQueryRunner createExpectedQueryRunner() + { + Session session = testSessionBuilder() + .setCatalog("tpch") + .setSchema("tiny") + .build(); + return new LocalQueryRunner(session); + } + + @Override + protected int getExpectedTotalParts() + { + return SPLITS_PER_NODE; + } + + @Override + protected void createTables() + { + ((LocalQueryRunner) getExpectedQueryRunner()).createCatalog("tpch", new TpchConnectorFactory(SPLITS_PER_NODE), ImmutableMap.of()); + } + + @Test + public void testConnectorGetStream() throws Exception + { + try (BufferAllocator bufferAllocator = allocator.newChildAllocator("connector-test-client", 0, Long.MAX_VALUE); + FlightClient client = createFlightClient(bufferAllocator, server.getPort())) { + Ticket ticket = new Ticket(REQUEST_JSON_CODEC.toJsonBytes(createTpchTableRequest(0, 1, ImmutableList.of(getOrderKeyColumn())))); + + int rowCount = 0; + try (FlightStream stream = client.getStream(ticket, CALL_OPTIONS)) { + while (stream.next()) { + rowCount += stream.getRoot().getRowCount(); + } + } + + assertGreaterThan(rowCount, 0); + } + } + + @Test + public void testStopStreamAtLimit() throws Exception + { + int rowLimit = 500; + try (BufferAllocator bufferAllocator = allocator.newChildAllocator("connector-test-client", 0, Long.MAX_VALUE); + FlightClient client = createFlightClient(bufferAllocator, server.getPort())) { + Ticket ticket = new Ticket(REQUEST_JSON_CODEC.toJsonBytes(createTpchTableRequest(0, 1, ImmutableList.of(getOrderKeyColumn())))); + + int rowCount = 0; + try (FlightStream stream = client.getStream(ticket, CALL_OPTIONS)) { + while (stream.next()) { + rowCount += stream.getRoot().getRowCount(); + if (rowCount >= rowLimit) { + break; + } + } + } + + assertEquals(rowCount, rowLimit); + } + } + + @Test + public void testCancelStream() throws Exception + { + String cancelMessage = "READ COMPLETE"; + try (BufferAllocator bufferAllocator = allocator.newChildAllocator("connector-test-client", 0, Long.MAX_VALUE); + FlightClient client = createFlightClient(bufferAllocator, server.getPort())) { + Ticket ticket = new Ticket(REQUEST_JSON_CODEC.toJsonBytes(createTpchTableRequest(0, 1, ImmutableList.of(getOrderKeyColumn())))); + + // Cancel stream explicitly + int rowCount = 0; + try (FlightStream stream = client.getStream(ticket, CALL_OPTIONS)) { + while (stream.next()) { + rowCount += stream.getRoot().getRowCount(); + if (rowCount >= 500) { + stream.cancel("Cancel", new CancellationException(cancelMessage)); + break; + } + } + + // Drain any remaining messages to properly release messages + try { + do { + Thread.sleep(100); + } + while (stream.next()); + } + catch (final Exception e) { + assertNotNull(e.getCause()); + assertEquals(e.getCause().getMessage(), cancelMessage); + } + } + + assertGreaterThan(rowCount, 0); + } + } + + @Test + public void testReadFromMultipleSplits() + { + assertSelectQueryFromColumns(ImmutableList.of(ORDERKEY_COLUMN, LINENUMBER_COLUMN, SHIPINSTRUCT_COLUMN)); + } +} diff --git a/presto-flight-shim/src/test/resources/certs/ca.crt b/presto-flight-shim/src/test/resources/certs/ca.crt new file mode 100644 index 0000000000000..3d3f516aa0148 --- /dev/null +++ b/presto-flight-shim/src/test/resources/certs/ca.crt @@ -0,0 +1,33 @@ +-----BEGIN CERTIFICATE----- +MIIFozCCA4ugAwIBAgIUTnZhzGzxAzaTJ5DDQnmBp8jvBUMwDQYJKoZIhvcNAQEL +BQAwYDELMAkGA1UEBhMCVVMxDjAMBgNVBAgMBVN0YXRlMQ0wCwYDVQQHDARDaXR5 +MQ4wDAYDVQQKDAVNeU9yZzEPMA0GA1UECwwGTXlVbml0MREwDwYDVQQDDAhNeVJv +b3RDQTAgFw0yNTA1MjIxNjA2MDNaGA8yMTI1MDQyODE2MDYwM1owYDELMAkGA1UE +BhMCVVMxDjAMBgNVBAgMBVN0YXRlMQ0wCwYDVQQHDARDaXR5MQ4wDAYDVQQKDAVN +eU9yZzEPMA0GA1UECwwGTXlVbml0MREwDwYDVQQDDAhNeVJvb3RDQTCCAiIwDQYJ +KoZIhvcNAQEBBQADggIPADCCAgoCggIBALmNRxPZt9vvFvIQcdCM3xcAOiTepahR +zsK86AwzXbwcWwOw+rQP0iXiMGjP9wJy0C1hU+j7F+gu5+d5j3DK9/iTlIqByjN+ +gffsZDSBBZBozJxARBWKaEFKRKbhCmr+v92ZN1KMYhhwc5W3SRp6y3kvsWYM4/JE +pRzYUy0505g3p7PnpQzZZSSzV33+9fn5ggafVB9FSJqOuBKUBTITt+Rf1pNV2ZGC +f2X39KOIIRz2Hz1L6hdqeMQN9V5KQ7C6oPfl+ho97JHaZirtHr1XZJ6YQRy7Mldp +jOPj4MMKqi8m+HhyCSznIDNiMp0OACX3CZ5RmKRYfnunOuw+fvI28HyOXMgfWKa2 +o/2/YlAzNrD50hJPwQMTlKWJm2gY2n5x2FWT6/8aXeCnsJALK6Dj6Ax0wxkAFfIo +76REHlXX2fIiz0cciYwYtvwhjp5efqX22B7LDkhu7fJ42yUd6g3crbmGM8OOoXgL +w3MWx30FatTyDT8un2ZvVDJEADW3+WWtyrZWHFMVFrVnN7Di4MAuWsRIVtuO6PEV +6pPS55NBmvKzWAoYBmpH103GlIxvZ6CCTeKqFbcrI77smrIO5CLmzl5yjb/urmy4 +GLYHM+EPB5pJKQO8g6fyM369mhEjBxt/RJGv9E0Jw7KmnHn5V2qSn4Yi41dUp7Cp +pQVIaYlJGn65AgMBAAGjUzBRMB0GA1UdDgQWBBRjt9KoJO8CO/Cy/xTJztrbdUlv +9DAfBgNVHSMEGDAWgBRjt9KoJO8CO/Cy/xTJztrbdUlv9DAPBgNVHRMBAf8EBTAD +AQH/MA0GCSqGSIb3DQEBCwUAA4ICAQApizBpSaZDYjBhOxD6+B5em18OIf5ZIjHk +iAhuppR+HaqyAHJgCO707dZVmFz9ECUZI9mvKcmj+h0Wh/mK4cSiDunFB9yUr67U +wV5F2/u/JAAq6VsbdrDiZPUwET8U9ai14LMEgPPF+Zif+wnopRav5lbiPoJVUjqr +wVoP2AHijIP46YCWwXqOTJMC79ccUMBZeDwF4bOquIADLmEnANp6fiMI+eE6OLFs +fDtjFqybRUZqzewv2lpzH2ZYEYk4bIk76TGkOYtrwJ+iQj77ZZFSBW5zkry/zaG3 +/5Ufjv65T9Zr1jmIMigcmCHwNsCLOYzIKjRaiuLGs4B9s8SGEauTRhP0dG5ndWPw +50NeSNJr37MHdKky44WAFlAk9BAKlghOaC5m2RyMof8DwYKPEe5epe1wBotiPqSX +doaZvch6wkuo8xvFKqH6rBTWJLMwuFt7m3XrGqGYlE+1gvuEfn7ZGzG00sl218mZ +MfsLqJfft92ARC1/qJvUFr5mM6SV4eQeTl1tAtv6Xfczr+3/iqc5gbeG25dXQclO +y1qIKthAoXFq6rAZ+bvfASiVV1OQS76nWSiYYS1dDPQJ/g4aawOkUYr9OjS5HNQ+ +rNAcLB+I0oaZQzZ85098qAVAJ76eFATb4ieDK6m0j6Fq5ddwrlYzxEEr7TscfHW2 +zPCtinUe0w== +-----END CERTIFICATE----- diff --git a/presto-flight-shim/src/test/resources/certs/client.crt b/presto-flight-shim/src/test/resources/certs/client.crt new file mode 100644 index 0000000000000..0e2bc9ff7f662 --- /dev/null +++ b/presto-flight-shim/src/test/resources/certs/client.crt @@ -0,0 +1,27 @@ +-----BEGIN CERTIFICATE----- +MIIEkDCCAnigAwIBAgIUUIByH9V7DhUf5n3Qd7tPxpPixQYwDQYJKoZIhvcNAQEL +BQAwYDELMAkGA1UEBhMCVVMxDjAMBgNVBAgMBVN0YXRlMQ0wCwYDVQQHDARDaXR5 +MQ4wDAYDVQQKDAVNeU9yZzEPMA0GA1UECwwGTXlVbml0MREwDwYDVQQDDAhNeVJv +b3RDQTAgFw0yNTA1MjIxNjA4MzFaGA8yMTI1MDQyODE2MDgzMVowXjELMAkGA1UE +BhMCVVMxDjAMBgNVBAgMBVN0YXRlMQ0wCwYDVQQHDARDaXR5MQ4wDAYDVQQKDAVN +eU9yZzEPMA0GA1UECwwGTXlVbml0MQ8wDQYDVQQDDAZjbGllbnQwggEiMA0GCSqG +SIb3DQEBAQUAA4IBDwAwggEKAoIBAQC8DK/RdBF6I+k+DMGrjhMMBCnpPNwJtzJU +uXYcHFYEdBnHY/rpjk/fi+7jD8bppynCZPakrDX+5VIMzS4HBU/CHY26eR2ItiWq +DoDkPAlCdgeKIGNYYEvVSuUW5YQX6fuD8PfCpCP5zK7DJC2xTTsyEjBzD+MnIB7T +ja9/22Djo2Ib2l9BEBOD+k79caPFtSqDQVdS5JLJ/P7BeqGuFS8bEgtLwwCzRxPK +kG64rXb9F0IErGwjXi/70BA24EW0uGAzzeY5Pnlx5MulEIjuxII/dTuoo+/uirN0 +wxURKSzTMyzzoJqGX9Ng9+Z+VqWFrnqckcmFcD4T0sPoHpWZsYIHAgMBAAGjQjBA +MB0GA1UdDgQWBBRW/6j5rNwnwTBAXY5igXm3ETzv/DAfBgNVHSMEGDAWgBRjt9Ko +JO8CO/Cy/xTJztrbdUlv9DANBgkqhkiG9w0BAQsFAAOCAgEAVLpfZDgkL1Dz/+Fq +vSl2IoxOFNd2DTa6yM8/1wpvMVTA024lp0ttyoz0o1621hTRexcXTimZqUNAtPV6 +Gwmb2ACLN4XtLk9QT9XjDtWKzPxCJ+ze7rrhj1jYqv9yUebdkJoMKfcbwYi0gtpt +HlaJqNoKgzZxOCGhTtdS3ypb9nDCyx3fmFk5mIYfzEszoMmqNL006ANlJ0IKkFZj +vUkkFyMGLerInmTDRjDLkUCkNKaJUYjZhf/FNwVtc1A9a/bDJMEYVog+CY0dpXKb +1IGaXzB4ewhuQKuhb/LCZT1pNm/cGCY2cRGFy9EVAuZ5FV0ajh1HYwdpGDzucoUD +UaTHIK40E2/kZorJa37Xyn7Lekgun6YpfOudBkKg5mlV6qoL9W/lTZPjMHs2ufvW +/A5S4okR4JmhC44TMgAv90MU9yEP90OkzW6egatBShWySJ3Bn5W+ebQSwJ38wgTy +e6j5jWh7xiiPC4TJbSXVMGEfJw/c2hx4R/83MqBhVLPfoapaUCDUWniv6n7zl5ML +k7WIZzXSK212/H+eVXFJ1Gq6zztOPkN9QGgr+dbsCzdLWPJzZDmI7+lgT2EKxIV/ +VMtHVOe2bLkePTNA2+vXQhs0p7JDtzyATAyMdhJwPljt8X+HJxEoAP4Dk6PcK3Bd +E92yOuW2jl6FzgqKqtVQvxIiBgc= +-----END CERTIFICATE----- diff --git a/presto-flight-shim/src/test/resources/certs/client.key b/presto-flight-shim/src/test/resources/certs/client.key new file mode 100644 index 0000000000000..72434bc294aff --- /dev/null +++ b/presto-flight-shim/src/test/resources/certs/client.key @@ -0,0 +1,28 @@ +-----BEGIN PRIVATE KEY----- +MIIEvgIBADANBgkqhkiG9w0BAQEFAASCBKgwggSkAgEAAoIBAQC8DK/RdBF6I+k+ +DMGrjhMMBCnpPNwJtzJUuXYcHFYEdBnHY/rpjk/fi+7jD8bppynCZPakrDX+5VIM +zS4HBU/CHY26eR2ItiWqDoDkPAlCdgeKIGNYYEvVSuUW5YQX6fuD8PfCpCP5zK7D +JC2xTTsyEjBzD+MnIB7Tja9/22Djo2Ib2l9BEBOD+k79caPFtSqDQVdS5JLJ/P7B +eqGuFS8bEgtLwwCzRxPKkG64rXb9F0IErGwjXi/70BA24EW0uGAzzeY5Pnlx5Mul +EIjuxII/dTuoo+/uirN0wxURKSzTMyzzoJqGX9Ng9+Z+VqWFrnqckcmFcD4T0sPo +HpWZsYIHAgMBAAECggEAViw6JXFa0O3D5HtUBJmGgOsniYoqCwm4NrsGNLuHb2ME +rSpTwNNGJtqpDcQdEtVXfY1muO9xjuznPJaJkQ4ODpYcbGcz8YIGoHck+XHJjHsp +2VIeNFFsbsFzWZqzfYHrj/rMjpVJJx90tlfN2IHbroZHTXLqVPOTLL6wvZZ6P9XF +zqpWABKOaEDbenhSFFZeF5KR3NG9HSTm4YLuekumkH+QgrveDfDwXG4hAHqg836o +OF3NPaij6VlSR18nuyW0wMs/Ceu13P+GALqHmz98pFyVgHWQFryL9IccvJQDyEnt +saeG4IAVlJbZDGTnRgANLhpwBr7XhMG1aK+wmOMRgQKBgQDkcatiATlr9L8gfnHb +6pmX//AZLdXuQLXfuTvu638Brhm770noLgfIC+HIp5kCHxT2Xj5Vn+MSnYD6R6Wh +chApRKJUdsuz1iOq23YJjvsSLWCGpl9IxR7WY27uGOPIjQcOd1PRbkCq9AgUJwyn +ryca3sbYh/XQOWGLbJNIQs/S/QKBgQDSu6PVeMaS3276KblvGIvvaSAQDQWxXcC+ +sA4CBmvjzx3xx5GAox/w7tcKmK/KQxNhaYy6N7xLc1YUJ9FbnT2PZQJhtP2d2Gat +Zre/+Qa+u84cR5hj9EI+B8FjW7D/psEj16KjHCds/SET6ngPM+RdB4N9daVFCurt +p0f717yiUwKBgBTJDun06I+dDkLbnmp/FwiQff0cgYmTE7lOdliPzteNSsQhypy4 +i3a1Ng72yOI7h8G+43cQ/C02bYTYPgbJhRTsLMT4piIvysECBORrwQZvYIf/3U2W +ue6Rz4cUdq1Jv6meS98TZAjp+U40G1+qfSlhub/75u7SOcDg2SnLAnPVAoGBAIOO +EmRE5qpwA+b2P0Ykq89E8Hg0uPYWEiq427XV7mqkNQxoSuRkcZ9Ga0a5NRzurN2m +N+1UuB7eHMGubdtkmTa4lzkJ9T4iB09/DX0x6E0QD0bGR1M2/FefHdJ6PlAK+Q34 +Ixbyj4ZRq+G0AUl0Wr7c3vBmjktA2pKMWLrW3nLzAoGBAKTl7qX6CD42gAJuT5Hp +rrXqlppVIyRvuXzXtX/Xq81IUHlBgS/t9HPyqDzmTKfxD8540kI+15bWPDHSJxiQ +ccqPaKyXhBXstDwGmlPKVzJUxk0dz5NHs+8gItUDOg78pM3siXN7vW9XBCH7mCDA +4zet/C0YCAiFVT+ipMoXy8Nc +-----END PRIVATE KEY----- diff --git a/presto-flight-shim/src/test/resources/certs/server.crt b/presto-flight-shim/src/test/resources/certs/server.crt new file mode 100644 index 0000000000000..8fb80c1bba51d --- /dev/null +++ b/presto-flight-shim/src/test/resources/certs/server.crt @@ -0,0 +1,27 @@ +-----BEGIN CERTIFICATE----- +MIIEkzCCAnugAwIBAgIUUIByH9V7DhUf5n3Qd7tPxpPixQgwDQYJKoZIhvcNAQEL +BQAwYDELMAkGA1UEBhMCVVMxDjAMBgNVBAgMBVN0YXRlMQ0wCwYDVQQHDARDaXR5 +MQ4wDAYDVQQKDAVNeU9yZzEPMA0GA1UECwwGTXlVbml0MREwDwYDVQQDDAhNeVJv +b3RDQTAgFw0yNTA1MjMwMTQzNTBaGA8yMTI1MDQyOTAxNDM1MFowYTELMAkGA1UE +BhMCVVMxDjAMBgNVBAgMBVN0YXRlMQ0wCwYDVQQHDARDaXR5MQ4wDAYDVQQKDAVN +eU9yZzEPMA0GA1UECwwGTXlVbml0MRIwEAYDVQQDDAlsb2NhbGhvc3QwggEiMA0G +CSqGSIb3DQEBAQUAA4IBDwAwggEKAoIBAQCtkJ+r1F8+YOVuwWLxbGVsJKw3BESh +tCsU+IXHJeaNJdBr59B/4h5WM37wOnnecmyEZTh47FXXkb5h0xVlHES7eTAD+NPl +WHufGJ9PR1kvQyZ0fyNRFXLzUID/dl7atHBtlrqE5Bhg7xqAyPZjUjhkAZPgrT1/ +8+gYmbWPbw3Ba3+XRupq3Kn+EVVJi7wk4cj8jf6g1aex6sMOkSYNsanb+JdEryev +goju+EtHgCHL6cB0eJs8PfMWiibgWLE2pkI0bdbGjTNVDDygZoO8Qr/YrGvXXYqt +0D7IKSUiO8bnrvZh6ITPEcQ3ePQRGEpqh8ggKaVq3RVkC3t3QMRWGpsvAgMBAAGj +QjBAMB0GA1UdDgQWBBRf0XdDhjNSIZQsvSFbe+hdsydllzAfBgNVHSMEGDAWgBRj +t9KoJO8CO/Cy/xTJztrbdUlv9DANBgkqhkiG9w0BAQsFAAOCAgEAec7y1Odyg3x/ +Uj0jfZWYNE0BuR114UVwEYhxFi9tRAGxjTlsl6ATCSYBWU+fwUkC29C2r3bu59fp +/KvYPRrwGPyOtXwHR2cmwJ7QrUhlIwPipWO4Kal3/EKWvrV9rzdnOd2QYqEMS6f0 +UixGLcT5p5KmEsH3W8Y9Uk94g/z12ZgdGeKKyY7hWnu1d47b2TS4oiRx+d6AacAD +1BzJRUhDjS2Vfe2cnpOqBHJWyCT1BxsfxKAc3rLa6JznbulHQPCE5WWBolHb8Tob +Yf32sIJydOcWU+zJ9VsGuglQ8dQInMemW8k5y48ACqHb00lAoucGJ3Izy9tJiBeU +C8TcmRxQSjoMGhFGNIjBAvXak0UzobUKE3YyABBbUdWLofs0N325K1aSga4vXmlO +OPzP4FWMLZyDicVMUGLP9jcyeb/gFMzjoU2En59gRNVDNNo01Lyj9MhGHF/jvV7p +Zf782GvXMT/NSPtXqmABSV/Svy70vXogeQxTii9YOZePKWQnFhcEwin/7d9bf4d/ +nUqtzDFb5FiDeFA+H9FyeNaeub3OtvsZUacAVDCT1t9/8uShjJo34v8WbJegepcY +Tpdkm4x0DWvv/QT3JBqC0wprsmBVzuqTJ/jBxYye02bds1bM7xePAuVOwuSb6Azg +gLrweiDMakgmZSkGgUonnjblYbHGTHw= +-----END CERTIFICATE----- diff --git a/presto-flight-shim/src/test/resources/certs/server.key b/presto-flight-shim/src/test/resources/certs/server.key new file mode 100644 index 0000000000000..0e49bebd60beb --- /dev/null +++ b/presto-flight-shim/src/test/resources/certs/server.key @@ -0,0 +1,28 @@ +-----BEGIN PRIVATE KEY----- +MIIEvwIBADANBgkqhkiG9w0BAQEFAASCBKkwggSlAgEAAoIBAQCtkJ+r1F8+YOVu +wWLxbGVsJKw3BEShtCsU+IXHJeaNJdBr59B/4h5WM37wOnnecmyEZTh47FXXkb5h +0xVlHES7eTAD+NPlWHufGJ9PR1kvQyZ0fyNRFXLzUID/dl7atHBtlrqE5Bhg7xqA +yPZjUjhkAZPgrT1/8+gYmbWPbw3Ba3+XRupq3Kn+EVVJi7wk4cj8jf6g1aex6sMO +kSYNsanb+JdEryevgoju+EtHgCHL6cB0eJs8PfMWiibgWLE2pkI0bdbGjTNVDDyg +ZoO8Qr/YrGvXXYqt0D7IKSUiO8bnrvZh6ITPEcQ3ePQRGEpqh8ggKaVq3RVkC3t3 +QMRWGpsvAgMBAAECggEAIK9C+lNAallJ62z8inU8tjxDuAqUOBVbJZRVcPbIr1zn +HmLlpyd4Sghhh7CjYYoPuHDtTQxIcBNwlDBxb3x+zwUXzy+tC5v5j7DN01qex2Ew +XTDSAEN3Ra2r1S+/1hSztVd0oXDozFxKk+UETRjfKKoJZH6LPcy7MOLFR5EEuJ8L +0kvGdEtuNCmZ1vPBwqR3IKQS9NsB1IdTtK0g2LdtVzM3U6F173CrAx51qNeAL30j +Np+I0rfm7vYVco6nDQXJB86hzwwBnLMzmZR2E0z+JStQCjQtEJN9wp+NBnViMb8C +mZl0K/PH3ZKNEs1Aw/TsRpPu6Fc+sN6iIs2oOGiKfQKBgQDa0Wpfj+SHflLfmrRU +PplGNjWdJiyuXROqX18iNE8nAD0eRqAFdzj9yU1IW49KCzuHInEl2pP9yrDZTWXB +Bht4C+Vk13mrBE3Sc1LDrks5EhDLaaolLgx1B+JN1X2DpfuzO8WHrXR11PCzFTAp +yDSVd451CFFXMseS1V9UxCy3lQKBgQDLDrLX/0hGhG+a5RUaAE+hZk+tU9RyjYm6 +/5lIoDjDwA9Yst69JCTHDApkdZ6IrjPDZrxkAQR6QwsGo+zRGkHV2wCoqR/RxcT5 +RBcbe/8xL86ZKwnhAheP6ssgZeK5zOG1iLol319kXXuo6NueN+YlocmsppRvAOq7 +/qMnhzXGswKBgQCpke2wHo9HnNJWK8ohGt2mtm232ZR4jvKlbgEIPac1Hw89/hcW +BT0qFqyILUQOakP4Re2PGyLiYwfHbh4zhisVTYq4Ke9EYzJ3qxzxPYlXsbNIHxtW +cqf+rVxnWtFIiwFR9TjvGrEMezcIYJwRVO/DAIJqGUcHnvdfx3B3/Qp2PQKBgQCk +y7UR37kEog8BotHRXFdEIgigHtzYa05QWYhJjN8E3yaVUfW7g03lzTvR9DNJsjeI +aiSS9NBxeV/Fb9yOh8TOjwKl3zxXvy3xLvWh9KxTev0tCeTmnBALWP6puIadTE4S +Snjoq7R7e/MUToeOjMdX20oVuMvWmuPm1u4K8o0OSQKBgQCec+QLllYXk22A8/e/ +f5HhSYr161lEFFmuzKhuuy+esyCQU/KZmxQH0UqnsL3Ww4ofq42lteqyUJnriHsx +QP5FTIMKH8W+Xels1i6jCC+MVXAXraAF27dOlmKxWMN7mnElZ/7lQKmBq64wil35 +sfcJA4FDxVM2Amv4KRo/w1C/zQ== +-----END PRIVATE KEY----- diff --git a/presto-flight-shim/src/test/resources/split_additional_predicate.json b/presto-flight-shim/src/test/resources/split_additional_predicate.json new file mode 100644 index 0000000000000..d4dac67bb731c --- /dev/null +++ b/presto-flight-shim/src/test/resources/split_additional_predicate.json @@ -0,0 +1,91 @@ +{ +"connectorId" : "postgresql", +"schemaName" : "tpch", +"tableName" : "orders", +"tupleDomain" : { + "columnDomains" : [ { + "column" : { + "@type" : "postgresql", + "connectorId" : "postgresql", + "columnName" : "orderkey", + "jdbcTypeHandle" : { + "jdbcType" : -5, + "jdbcTypeName" : "int8", + "columnSize" : 19, + "decimalDigits" : 0 + }, + "columnType" : "bigint", + "nullable" : true + }, + "domain" : { + "values" : { + "@type" : "sortable", + "type" : "bigint", + "ranges" : [ { + "low" : { + "type" : "bigint", + "valueBlock" : "CgAAAExPTkdfQVJSQVkBAAAAAAEAAAAAAAAA", + "bound" : "EXACTLY" + }, + "high" : { + "type" : "bigint", + "valueBlock" : "CgAAAExPTkdfQVJSQVkBAAAAAAEAAAAAAAAA", + "bound" : "EXACTLY" + } + }, { + "low" : { + "type" : "bigint", + "valueBlock" : "CgAAAExPTkdfQVJSQVkBAAAAAAIAAAAAAAAA", + "bound" : "EXACTLY" + }, + "high" : { + "type" : "bigint", + "valueBlock" : "CgAAAExPTkdfQVJSQVkBAAAAAAIAAAAAAAAA", + "bound" : "EXACTLY" + } + }, { + "low" : { + "type" : "bigint", + "valueBlock" : "CgAAAExPTkdfQVJSQVkBAAAAAAMAAAAAAAAA", + "bound" : "EXACTLY" + }, + "high" : { + "type" : "bigint", + "valueBlock" : "CgAAAExPTkdfQVJSQVkBAAAAAAMAAAAAAAAA", + "bound" : "EXACTLY" + } + } ] + }, + "nullAllowed" : false + } + } ] +}, +"additionalProperty" : { + "translatedString" : "((orderkey) IN ((?) , (?) , (?)))", + "boundConstantValues" : [ { + "@type" : "constant", + "valueBlock" : "CgAAAExPTkdfQVJSQVkBAAAAAAEAAAAAAAAA", + "type" : "bigint", + "sourceLocation" : { + "line" : 1, + "column" : 22 + } + }, { + "@type" : "constant", + "valueBlock" : "CgAAAExPTkdfQVJSQVkBAAAAAAIAAAAAAAAA", + "type" : "bigint", + "sourceLocation" : { + "line" : 1, + "column" : 22 + } + }, { + "@type" : "constant", + "valueBlock" : "CgAAAExPTkdfQVJSQVkBAAAAAAMAAAAAAAAA", + "type" : "bigint", + "sourceLocation" : { + "line" : 1, + "column" : 22 + } + } ] +} +} diff --git a/presto-flight-shim/src/test/resources/split_tuple_domain.json b/presto-flight-shim/src/test/resources/split_tuple_domain.json new file mode 100644 index 0000000000000..b7fc0082fa403 --- /dev/null +++ b/presto-flight-shim/src/test/resources/split_tuple_domain.json @@ -0,0 +1,40 @@ +{ + "connectorId" : "postgresql", + "schemaName" : "tpch", + "tableName" : "lineitem", + "tupleDomain" : { + "columnDomains" : [ { + "column" : { + "connectorId" : "postgresql", + "columnName" : "orderkey", + "jdbcTypeHandle" : { + "jdbcType" : -5, + "jdbcTypeName" : "bigint", + "columnSize" : 8, + "decimalDigits" : 0 + }, + "columnType" : "bigint", + "nullable" : false + }, + "domain" : { + "values" : { + "@type" : "sortable", + "type" : "bigint", + "ranges" : [ { + "low" : { + "type" : "bigint", + "valueBlock" : "CgAAAExPTkdfQVJSQVkBAAAAAAMAAAAAAAAA", + "bound" : "EXACTLY" + }, + "high" : { + "type" : "bigint", + "valueBlock" : "CgAAAExPTkdfQVJSQVkBAAAAAAMAAAAAAAAA", + "bound" : "EXACTLY" + } + } ] + }, + "nullAllowed" : false + } + } ] + } +} diff --git a/presto-tpch/src/main/java/com/facebook/presto/tpch/TpchColumnHandle.java b/presto-tpch/src/main/java/com/facebook/presto/tpch/TpchColumnHandle.java index 979df181ddb5e..015e31a3a9293 100644 --- a/presto-tpch/src/main/java/com/facebook/presto/tpch/TpchColumnHandle.java +++ b/presto-tpch/src/main/java/com/facebook/presto/tpch/TpchColumnHandle.java @@ -16,6 +16,7 @@ import com.facebook.presto.common.Subfield; import com.facebook.presto.common.type.Type; import com.facebook.presto.spi.ColumnHandle; +import com.facebook.presto.spi.ColumnMetadata; import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonProperty; import com.google.common.collect.ImmutableList; @@ -62,6 +63,14 @@ public Type getType() return type; } + public ColumnMetadata getColumnMetadata() + { + return ColumnMetadata.builder() + .setName(columnName) + .setType(type) + .build(); + } + @Override public String toString() {