diff --git a/.github/workflows/arrow-flight-tests.yml b/.github/workflows/arrow-flight-tests.yml
new file mode 100644
index 0000000000000..ee77c122536e1
--- /dev/null
+++ b/.github/workflows/arrow-flight-tests.yml
@@ -0,0 +1,82 @@
+name: arrow flight tests
+
+on:
+ pull_request:
+
+env:
+ CONTINUOUS_INTEGRATION: true
+ MAVEN_OPTS: "-Xmx1024M -XX:+ExitOnOutOfMemoryError"
+ MAVEN_INSTALL_OPTS: "-Xmx2G -XX:+ExitOnOutOfMemoryError"
+ MAVEN_FAST_INSTALL: "-B -V --quiet -T 1C -DskipTests -Dair.check.skip-all --no-transfer-progress -Dmaven.javadoc.skip=true"
+ MAVEN_TEST: "-B -Dair.check.skip-all -Dmaven.javadoc.skip=true -DLogTestDurationListener.enabled=true --no-transfer-progress --fail-at-end"
+ RETRY: .github/bin/retry
+
+jobs:
+ changes:
+ runs-on: ubuntu-latest
+ permissions:
+ pull-requests: read
+ outputs:
+ codechange: ${{ steps.filter.outputs.codechange }}
+ steps:
+ - uses: dorny/paths-filter@v2
+ id: filter
+ with:
+ filters: |
+ codechange:
+ - '!presto-docs/**'
+ test:
+ runs-on: ubuntu-latest
+ needs: changes
+ strategy:
+ fail-fast: false
+ matrix:
+ modules:
+ - ":presto-base-arrow-flight" # Only run tests for the `presto-base-arrow-flight` module
+
+ timeout-minutes: 80
+ concurrency:
+ group: ${{ github.workflow }}-test-${{ matrix.modules }}-${{ github.event.pull_request.number }}
+ cancel-in-progress: true
+
+ steps:
+ # Checkout the code only if there are changes in the relevant files
+ - uses: actions/checkout@v4
+ if: needs.changes.outputs.codechange == 'true'
+ with:
+ show-progress: false
+
+ # Set up Java for the build environment
+ - uses: actions/setup-java@v2
+ if: needs.changes.outputs.codechange == 'true'
+ with:
+ distribution: 'temurin'
+ java-version: 8
+
+ # Cache Maven dependencies to speed up the build
+ - name: Cache local Maven repository
+ if: needs.changes.outputs.codechange == 'true'
+ id: cache-maven
+ uses: actions/cache@v2
+ with:
+ path: ~/.m2/repository
+ key: ${{ runner.os }}-maven-2-${{ hashFiles('**/pom.xml') }}
+ restore-keys: |
+ ${{ runner.os }}-maven-2-
+
+ # Resolve Maven dependencies (if cache is not found)
+ - name: Populate Maven cache
+ if: steps.cache-maven.outputs.cache-hit != 'true' && needs.changes.outputs.codechange == 'true'
+ run: ./mvnw de.qaware.maven:go-offline-maven-plugin:resolve-dependencies --no-transfer-progress && .github/bin/download_nodejs
+
+ # Install dependencies for the target module
+ - name: Maven Install
+ if: needs.changes.outputs.codechange == 'true'
+ run: |
+ export MAVEN_OPTS="${MAVEN_INSTALL_OPTS}"
+ ./mvnw install ${MAVEN_FAST_INSTALL} -am -pl ${{ matrix.modules }}
+
+ # Run Maven tests for the target module
+ - name: Maven Tests
+ if: needs.changes.outputs.codechange == 'true'
+ run: ./mvnw test ${MAVEN_TEST} -pl ${{ matrix.modules }}
diff --git a/.github/workflows/test-other-modules.yml b/.github/workflows/test-other-modules.yml
index 7b5ba858bdc4d..02cf887feecc6 100644
--- a/.github/workflows/test-other-modules.yml
+++ b/.github/workflows/test-other-modules.yml
@@ -87,4 +87,5 @@ jobs:
!presto-test-coverage,
!presto-iceberg,
!presto-singlestore,
- !presto-native-sidecar-plugin'
+ !presto-native-sidecar-plugin,
+ !presto-base-arrow-flight'
diff --git a/pom.xml b/pom.xml
index 4b2f740db855b..cf8e996d137f5 100644
--- a/pom.xml
+++ b/pom.xml
@@ -89,6 +89,7 @@
2.11.0
3.14.0
5.1.0
+ 17.0.0
+
+ com.fasterxml.jackson.datatype
+ jackson-datatype-jsr310
+
+
+
+
+
+ org.apache.arrow
+ flight-core
+
+
+ org.slf4j
+ slf4j-api
+
+
+
+
+
+ com.facebook.airlift
+ bootstrap
+
+
+
+ com.facebook.airlift
+ log
+
+
+
+ com.google.guava
+ guava
+
+
+
+ javax.inject
+ javax.inject
+
+
+
+ com.facebook.presto
+ presto-spi
+
+
+
+ io.airlift
+ slice
+
+
+
+ com.facebook.airlift
+ log-manager
+
+
+
+ com.fasterxml.jackson.core
+ jackson-annotations
+
+
+
+ com.facebook.presto
+ presto-common
+
+
+
+ com.fasterxml.jackson.core
+ jackson-databind
+
+
+
+ com.google.code.findbugs
+ jsr305
+ true
+
+
+
+ com.google.inject
+ guice
+
+
+
+ com.facebook.airlift
+ configuration
+
+
+
+ joda-time
+ joda-time
+
+
+
+ org.jdbi
+ jdbi3-core
+
+
+
+
+ org.testng
+ testng
+ test
+
+
+
+ io.airlift.tpch
+ tpch
+ test
+
+
+
+ com.facebook.presto
+ presto-tpch
+ test
+
+
+
+ com.facebook.airlift
+ json
+ test
+
+
+
+ com.facebook.presto
+ presto-testng-services
+ test
+
+
+
+ com.facebook.airlift
+ testing
+ test
+
+
+
+ com.facebook.presto
+ presto-main
+ test
+
+
+
+ com.facebook.presto
+ presto-tests
+ test
+
+
+
+ com.h2database
+ h2
+ test
+
+
+
+ org.apache.arrow
+ arrow-jdbc
+ ${dep.arrow.version}
+ test
+
+
+ org.slf4j
+ slf4j-api
+
+
+
+
+
+
+
+
+ org.apache.maven.plugins
+ maven-surefire-plugin
+
+ -Xss10M
+
+
+
+ org.apache.maven.plugins
+ maven-enforcer-plugin
+
+
+ org.apache.maven.plugins
+ maven-dependency-plugin
+
+
+ org.basepom.maven
+ duplicate-finder-maven-plugin
+ 1.2.1
+
+
+ module-info
+ META-INF.versions.9.module-info
+
+
+ arrow-git.properties
+ about.html
+
+
+
+
+
+ check
+
+
+
+
+
+
+
diff --git a/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowBlockBuilder.java b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowBlockBuilder.java
new file mode 100644
index 0000000000000..3a5ce28cd76fd
--- /dev/null
+++ b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowBlockBuilder.java
@@ -0,0 +1,722 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.facebook.plugin.arrow;
+
+import com.facebook.presto.common.block.Block;
+import com.facebook.presto.common.block.BlockBuilder;
+import com.facebook.presto.common.block.DictionaryBlock;
+import com.facebook.presto.common.type.ArrayType;
+import com.facebook.presto.common.type.BigintType;
+import com.facebook.presto.common.type.BooleanType;
+import com.facebook.presto.common.type.CharType;
+import com.facebook.presto.common.type.DateType;
+import com.facebook.presto.common.type.DecimalType;
+import com.facebook.presto.common.type.Decimals;
+import com.facebook.presto.common.type.DoubleType;
+import com.facebook.presto.common.type.IntegerType;
+import com.facebook.presto.common.type.MapType;
+import com.facebook.presto.common.type.RealType;
+import com.facebook.presto.common.type.RowType;
+import com.facebook.presto.common.type.SmallintType;
+import com.facebook.presto.common.type.TimeType;
+import com.facebook.presto.common.type.TimestampType;
+import com.facebook.presto.common.type.TinyintType;
+import com.facebook.presto.common.type.Type;
+import com.facebook.presto.common.type.TypeManager;
+import com.facebook.presto.common.type.VarbinaryType;
+import com.facebook.presto.common.type.VarcharType;
+import com.google.common.base.CharMatcher;
+import io.airlift.slice.Slice;
+import io.airlift.slice.Slices;
+import org.apache.arrow.vector.BigIntVector;
+import org.apache.arrow.vector.BitVector;
+import org.apache.arrow.vector.DateDayVector;
+import org.apache.arrow.vector.DateMilliVector;
+import org.apache.arrow.vector.DecimalVector;
+import org.apache.arrow.vector.FieldVector;
+import org.apache.arrow.vector.Float4Vector;
+import org.apache.arrow.vector.Float8Vector;
+import org.apache.arrow.vector.IntVector;
+import org.apache.arrow.vector.NullVector;
+import org.apache.arrow.vector.SmallIntVector;
+import org.apache.arrow.vector.TimeMicroVector;
+import org.apache.arrow.vector.TimeMilliVector;
+import org.apache.arrow.vector.TimeSecVector;
+import org.apache.arrow.vector.TimeStampMicroVector;
+import org.apache.arrow.vector.TimeStampMilliTZVector;
+import org.apache.arrow.vector.TimeStampMilliVector;
+import org.apache.arrow.vector.TimeStampSecVector;
+import org.apache.arrow.vector.TinyIntVector;
+import org.apache.arrow.vector.ValueVector;
+import org.apache.arrow.vector.VarBinaryVector;
+import org.apache.arrow.vector.VarCharVector;
+import org.apache.arrow.vector.complex.ListVector;
+import org.apache.arrow.vector.complex.MapVector;
+import org.apache.arrow.vector.complex.StructVector;
+import org.apache.arrow.vector.dictionary.Dictionary;
+import org.apache.arrow.vector.dictionary.DictionaryProvider;
+import org.apache.arrow.vector.types.pojo.ArrowType;
+import org.apache.arrow.vector.types.pojo.Field;
+
+import javax.inject.Inject;
+
+import java.math.BigDecimal;
+import java.nio.charset.StandardCharsets;
+import java.time.Duration;
+import java.time.LocalTime;
+import java.util.List;
+import java.util.Optional;
+import java.util.concurrent.TimeUnit;
+
+import static com.facebook.plugin.arrow.ArrowErrorCode.ARROW_FLIGHT_TYPE_ERROR;
+import static com.facebook.presto.common.Utils.checkArgument;
+import static com.facebook.presto.common.type.TypeSignature.parseTypeSignature;
+import static com.google.common.collect.ImmutableList.toImmutableList;
+import static java.lang.String.format;
+import static java.util.Objects.requireNonNull;
+
+public class ArrowBlockBuilder
+{
+ private final TypeManager typeManager;
+
+ @Inject
+ public ArrowBlockBuilder(TypeManager typeManager)
+ {
+ this.typeManager = requireNonNull(typeManager, "typeManager is null");
+ }
+ public Block buildBlockFromFieldVector(FieldVector vector, Type type, DictionaryProvider dictionaryProvider)
+ {
+ // Use Arrow dictionary to create a DictionaryBlock
+ if (dictionaryProvider != null && vector.getField().getDictionary() != null) {
+ Dictionary dictionary = dictionaryProvider.lookup(vector.getField().getDictionary().getId());
+ if (dictionary != null) {
+ Type prestoType = getPrestoTypeFromArrowField(dictionary.getVector().getField());
+ BlockBuilder dictionaryBuilder = prestoType.createBlockBuilder(null, vector.getValueCount());
+ assignBlockFromValueVector(dictionary.getVector(), prestoType, dictionaryBuilder, 0, dictionary.getVector().getValueCount());
+ return buildDictionaryBlock(vector, dictionaryBuilder.build());
+ }
+ }
+
+ BlockBuilder builder = type.createBlockBuilder(null, vector.getValueCount());
+ assignBlockFromValueVector(vector, type, builder, 0, vector.getValueCount());
+ return builder.build();
+ }
+
+ protected Type getPrestoTypeFromArrowField(Field field)
+ {
+ switch (field.getType().getTypeID()) {
+ case Int:
+ ArrowType.Int intType = (ArrowType.Int) field.getType();
+ return getPrestoTypeForArrowIntType(intType);
+ case Binary:
+ case LargeBinary:
+ case FixedSizeBinary:
+ return VarbinaryType.VARBINARY;
+ case Date:
+ return DateType.DATE;
+ case Timestamp:
+ return TimestampType.TIMESTAMP;
+ case Utf8:
+ case LargeUtf8:
+ return VarcharType.VARCHAR;
+ case FloatingPoint:
+ ArrowType.FloatingPoint floatingPoint = (ArrowType.FloatingPoint) field.getType();
+ return getPrestoTypeForArrowFloatingPointType(floatingPoint);
+ case Decimal:
+ ArrowType.Decimal decimalType = (ArrowType.Decimal) field.getType();
+ return DecimalType.createDecimalType(decimalType.getPrecision(), decimalType.getScale());
+ case Bool:
+ return BooleanType.BOOLEAN;
+ case Time:
+ return TimeType.TIME;
+ case List: {
+ List children = field.getChildren();
+ checkArgument(children.size() == 1, "Arrow List expected to have 1 child Field, got: " + children.size());
+ return new ArrayType(getPrestoTypeFromArrowField(field.getChildren().get(0)));
+ }
+ case Map: {
+ List children = field.getChildren();
+ checkArgument(children.size() == 1, "Arrow Map expected to have 1 child Field for entries, got: " + children.size());
+ Field entryField = children.get(0);
+ checkArgument(entryField.getChildren().size() == 2, "Arrow Map entries expected to have 2 child Fields, got: " + children.size());
+ Type keyType = getPrestoTypeFromArrowField(entryField.getChildren().get(0));
+ Type valueType = getPrestoTypeFromArrowField(entryField.getChildren().get(1));
+ return typeManager.getType(parseTypeSignature(format("map(%s,%s)", keyType.getTypeSignature(), valueType.getTypeSignature())));
+ }
+ case Struct: {
+ List children = field.getChildren().stream().map(child -> new RowType.Field(Optional.of(child.getName()), getPrestoTypeFromArrowField(child))).collect(toImmutableList());
+ return RowType.from(children);
+ }
+ default:
+ throw new UnsupportedOperationException("The data type " + field.getType().getTypeID() + " is not supported.");
+ }
+ }
+
+ private Type getPrestoTypeForArrowFloatingPointType(ArrowType.FloatingPoint floatingPoint)
+ {
+ switch (floatingPoint.getPrecision()) {
+ case SINGLE:
+ return RealType.REAL;
+ case DOUBLE:
+ return DoubleType.DOUBLE;
+ default:
+ throw new ArrowException(ARROW_FLIGHT_TYPE_ERROR, "Unexpected floating point precision: " + floatingPoint.getPrecision());
+ }
+ }
+
+ private Type getPrestoTypeForArrowIntType(ArrowType.Int intType)
+ {
+ switch (intType.getBitWidth()) {
+ case 64:
+ return BigintType.BIGINT;
+ case 32:
+ return IntegerType.INTEGER;
+ case 16:
+ return SmallintType.SMALLINT;
+ case 8:
+ return TinyintType.TINYINT;
+ default:
+ throw new ArrowException(ARROW_FLIGHT_TYPE_ERROR, "Unexpected bit width: " + intType.getBitWidth());
+ }
+ }
+
+ private DictionaryBlock buildDictionaryBlock(FieldVector fieldVector, Block dictionaryblock)
+ {
+ if (fieldVector instanceof IntVector) {
+ // Get the Arrow indices vector
+ IntVector indicesVector = (IntVector) fieldVector;
+ int[] ids = new int[indicesVector.getValueCount()];
+ for (int i = 0; i < indicesVector.getValueCount(); i++) {
+ ids[i] = indicesVector.get(i);
+ }
+ return new DictionaryBlock(ids.length, dictionaryblock, ids);
+ }
+ else if (fieldVector instanceof SmallIntVector) {
+ // Get the SmallInt indices vector
+ SmallIntVector smallIntIndicesVector = (SmallIntVector) fieldVector;
+ int[] ids = new int[smallIntIndicesVector.getValueCount()];
+ for (int i = 0; i < smallIntIndicesVector.getValueCount(); i++) {
+ ids[i] = smallIntIndicesVector.get(i);
+ }
+ return new DictionaryBlock(ids.length, dictionaryblock, ids);
+ }
+ else if (fieldVector instanceof TinyIntVector) {
+ // Get the TinyInt indices vector
+ TinyIntVector tinyIntIndicesVector = (TinyIntVector) fieldVector;
+ int[] ids = new int[tinyIntIndicesVector.getValueCount()];
+ for (int i = 0; i < tinyIntIndicesVector.getValueCount(); i++) {
+ ids[i] = tinyIntIndicesVector.get(i);
+ }
+ return new DictionaryBlock(ids.length, dictionaryblock, ids);
+ }
+ else {
+ // Handle the case where the FieldVector is of an unsupported type
+ throw new IllegalArgumentException("Unsupported FieldVector type: " + fieldVector.getClass());
+ }
+ }
+
+ private void assignBlockFromValueVector(ValueVector vector, Type type, BlockBuilder builder, int startIndex, int endIndex)
+ {
+ if (vector instanceof BitVector) {
+ assignBlockFromBitVector((BitVector) vector, type, builder, startIndex, endIndex);
+ }
+ else if (vector instanceof TinyIntVector) {
+ assignBlockFromTinyIntVector((TinyIntVector) vector, type, builder, startIndex, endIndex);
+ }
+ else if (vector instanceof IntVector) {
+ assignBlockFromIntVector((IntVector) vector, type, builder, startIndex, endIndex);
+ }
+ else if (vector instanceof SmallIntVector) {
+ assignBlockFromSmallIntVector((SmallIntVector) vector, type, builder, startIndex, endIndex);
+ }
+ else if (vector instanceof BigIntVector) {
+ assignBlockFromBigIntVector((BigIntVector) vector, type, builder, startIndex, endIndex);
+ }
+ else if (vector instanceof DecimalVector) {
+ assignBlockFromDecimalVector((DecimalVector) vector, type, builder, startIndex, endIndex);
+ }
+ else if (vector instanceof NullVector) {
+ assignBlockFromNullVector((NullVector) vector, type, builder, startIndex, endIndex);
+ }
+ else if (vector instanceof TimeStampMicroVector) {
+ assignBlockFromTimeStampMicroVector((TimeStampMicroVector) vector, type, builder, startIndex, endIndex);
+ }
+ else if (vector instanceof TimeStampMilliVector) {
+ assignBlockFromTimeStampMilliVector((TimeStampMilliVector) vector, type, builder, startIndex, endIndex);
+ }
+ else if (vector instanceof Float4Vector) {
+ assignBlockFromFloat4Vector((Float4Vector) vector, type, builder, startIndex, endIndex);
+ }
+ else if (vector instanceof Float8Vector) {
+ assignBlockFromFloat8Vector((Float8Vector) vector, type, builder, startIndex, endIndex);
+ }
+ else if (vector instanceof VarCharVector) {
+ if (type instanceof CharType) {
+ assignCharTypeBlockFromVarcharVector((VarCharVector) vector, type, builder, startIndex, endIndex);
+ }
+ else if (type instanceof TimeType) {
+ assignTimeTypeBlockFromVarcharVector((VarCharVector) vector, type, builder, startIndex, endIndex);
+ }
+ else {
+ assignBlockFromVarCharVector((VarCharVector) vector, type, builder, startIndex, endIndex);
+ }
+ }
+ else if (vector instanceof VarBinaryVector) {
+ assignBlockFromVarBinaryVector((VarBinaryVector) vector, type, builder, startIndex, endIndex);
+ }
+ else if (vector instanceof DateDayVector) {
+ assignBlockFromDateDayVector((DateDayVector) vector, type, builder, startIndex, endIndex);
+ }
+ else if (vector instanceof DateMilliVector) {
+ assignBlockFromDateMilliVector((DateMilliVector) vector, type, builder, startIndex, endIndex);
+ }
+ else if (vector instanceof TimeMilliVector) {
+ assignBlockFromTimeMilliVector((TimeMilliVector) vector, type, builder, startIndex, endIndex);
+ }
+ else if (vector instanceof TimeSecVector) {
+ assignBlockFromTimeSecVector((TimeSecVector) vector, type, builder, startIndex, endIndex);
+ }
+ else if (vector instanceof TimeStampSecVector) {
+ assignBlockFromTimeStampSecVector((TimeStampSecVector) vector, type, builder, startIndex, endIndex);
+ }
+ else if (vector instanceof TimeMicroVector) {
+ assignBlockFromTimeMicroVector((TimeMicroVector) vector, type, builder, startIndex, endIndex);
+ }
+ else if (vector instanceof TimeStampMilliTZVector) {
+ assignBlockFromTimeMilliTZVector((TimeStampMilliTZVector) vector, type, builder, startIndex, endIndex);
+ }
+ else if (vector instanceof MapVector) {
+ // NOTE: MapVector is also instanceof ListVector, so check for Map first
+ assignBlockFromMapVector((MapVector) vector, type, builder, startIndex, endIndex);
+ }
+ else if (vector instanceof ListVector) {
+ assignBlockFromListVector((ListVector) vector, type, builder, startIndex, endIndex);
+ }
+ else if (vector instanceof StructVector) {
+ assignBlockFromStructVector((StructVector) vector, type, builder, startIndex, endIndex);
+ }
+ else {
+ throw new UnsupportedOperationException("Unsupported vector type: " + vector.getClass());
+ }
+ }
+
+ public void assignBlockFromBitVector(BitVector vector, Type type, BlockBuilder builder, int startIndex, int endIndex)
+ {
+ for (int i = startIndex; i < endIndex; i++) {
+ if (vector.isNull(i)) {
+ builder.appendNull();
+ }
+ else {
+ type.writeBoolean(builder, vector.get(i) == 1);
+ }
+ }
+ }
+
+ public void assignBlockFromIntVector(IntVector vector, Type type, BlockBuilder builder, int startIndex, int endIndex)
+ {
+ for (int i = startIndex; i < endIndex; i++) {
+ if (vector.isNull(i)) {
+ builder.appendNull();
+ }
+ else {
+ type.writeLong(builder, vector.get(i));
+ }
+ }
+ }
+
+ public void assignBlockFromSmallIntVector(SmallIntVector vector, Type type, BlockBuilder builder, int startIndex, int endIndex)
+ {
+ for (int i = startIndex; i < endIndex; i++) {
+ if (vector.isNull(i)) {
+ builder.appendNull();
+ }
+ else {
+ type.writeLong(builder, vector.get(i));
+ }
+ }
+ }
+
+ public void assignBlockFromTinyIntVector(TinyIntVector vector, Type type, BlockBuilder builder, int startIndex, int endIndex)
+ {
+ for (int i = startIndex; i < endIndex; i++) {
+ if (vector.isNull(i)) {
+ builder.appendNull();
+ }
+ else {
+ type.writeLong(builder, vector.get(i));
+ }
+ }
+ }
+
+ public void assignBlockFromBigIntVector(BigIntVector vector, Type type, BlockBuilder builder, int startIndex, int endIndex)
+ {
+ for (int i = startIndex; i < endIndex; i++) {
+ if (vector.isNull(i)) {
+ builder.appendNull();
+ }
+ else {
+ type.writeLong(builder, vector.get(i));
+ }
+ }
+ }
+
+ public void assignBlockFromDecimalVector(DecimalVector vector, Type type, BlockBuilder builder, int startIndex, int endIndex)
+ {
+ if (!(type instanceof DecimalType)) {
+ throw new IllegalArgumentException("Type must be a DecimalType for DecimalVector");
+ }
+
+ DecimalType decimalType = (DecimalType) type;
+
+ for (int i = startIndex; i < endIndex; i++) {
+ if (vector.isNull(i)) {
+ builder.appendNull();
+ }
+ else {
+ BigDecimal decimal = vector.getObject(i); // Get the BigDecimal value
+ if (decimalType.isShort()) {
+ builder.writeLong(decimal.unscaledValue().longValue());
+ }
+ else {
+ Slice slice = Decimals.encodeScaledValue(decimal);
+ decimalType.writeSlice(builder, slice, 0, slice.length());
+ }
+ }
+ }
+ }
+
+ public void assignBlockFromNullVector(NullVector vector, Type type, BlockBuilder builder, int startIndex, int endIndex)
+ {
+ for (int i = startIndex; i < endIndex; i++) {
+ builder.appendNull();
+ }
+ }
+
+ public void assignBlockFromTimeStampMicroVector(TimeStampMicroVector vector, Type type, BlockBuilder builder, int startIndex, int endIndex)
+ {
+ if (!(type instanceof TimestampType)) {
+ throw new IllegalArgumentException("Expected TimestampType but got " + type.getClass().getName());
+ }
+
+ for (int i = startIndex; i < endIndex; i++) {
+ if (vector.isNull(i)) {
+ builder.appendNull();
+ }
+ else {
+ long micros = vector.get(i);
+ long millis = TimeUnit.MICROSECONDS.toMillis(micros);
+ type.writeLong(builder, millis);
+ }
+ }
+ }
+
+ public void assignBlockFromTimeStampMilliVector(TimeStampMilliVector vector, Type type, BlockBuilder builder, int startIndex, int endIndex)
+ {
+ if (!(type instanceof TimestampType)) {
+ throw new IllegalArgumentException("Expected TimestampType but got " + type.getClass().getName());
+ }
+
+ for (int i = startIndex; i < endIndex; i++) {
+ if (vector.isNull(i)) {
+ builder.appendNull();
+ }
+ else {
+ long millis = vector.get(i);
+ type.writeLong(builder, millis);
+ }
+ }
+ }
+
+ public void assignBlockFromFloat8Vector(Float8Vector vector, Type type, BlockBuilder builder, int startIndex, int endIndex)
+ {
+ for (int i = startIndex; i < endIndex; i++) {
+ if (vector.isNull(i)) {
+ builder.appendNull();
+ }
+ else {
+ type.writeDouble(builder, vector.get(i));
+ }
+ }
+ }
+
+ public void assignBlockFromFloat4Vector(Float4Vector vector, Type type, BlockBuilder builder, int startIndex, int endIndex)
+ {
+ for (int i = startIndex; i < endIndex; i++) {
+ if (vector.isNull(i)) {
+ builder.appendNull();
+ }
+ else {
+ int intBits = Float.floatToIntBits(vector.get(i));
+ type.writeLong(builder, intBits);
+ }
+ }
+ }
+
+ public void assignBlockFromVarBinaryVector(VarBinaryVector vector, Type type, BlockBuilder builder, int startIndex, int endIndex)
+ {
+ for (int i = startIndex; i < endIndex; i++) {
+ if (vector.isNull(i)) {
+ builder.appendNull();
+ }
+ else {
+ byte[] value = vector.get(i);
+ type.writeSlice(builder, Slices.wrappedBuffer(value));
+ }
+ }
+ }
+
+ public void assignBlockFromVarCharVector(VarCharVector vector, Type type, BlockBuilder builder, int startIndex, int endIndex)
+ {
+ if (!(type instanceof VarcharType)) {
+ throw new IllegalArgumentException("Expected VarcharType but got " + type.getClass().getName());
+ }
+
+ for (int i = startIndex; i < endIndex; i++) {
+ if (vector.isNull(i)) {
+ builder.appendNull();
+ }
+ else {
+ // Directly create a Slice from the raw byte array
+ byte[] rawBytes = vector.get(i);
+ Slice slice = Slices.wrappedBuffer(rawBytes);
+ // Write the Slice directly to the builder
+ type.writeSlice(builder, slice);
+ }
+ }
+ }
+
+ public void assignBlockFromDateDayVector(DateDayVector vector, Type type, BlockBuilder builder, int startIndex, int endIndex)
+ {
+ if (!(type instanceof DateType)) {
+ throw new IllegalArgumentException("Expected DateType but got " + type.getClass().getName());
+ }
+
+ for (int i = startIndex; i < endIndex; i++) {
+ if (vector.isNull(i)) {
+ builder.appendNull();
+ }
+ else {
+ type.writeLong(builder, vector.get(i));
+ }
+ }
+ }
+
+ public void assignBlockFromDateMilliVector(DateMilliVector vector, Type type, BlockBuilder builder, int startIndex, int endIndex)
+ {
+ if (!(type instanceof DateType)) {
+ throw new IllegalArgumentException("Expected DateType but got " + type.getClass().getName());
+ }
+
+ for (int i = startIndex; i < endIndex; i++) {
+ if (vector.isNull(i)) {
+ builder.appendNull();
+ }
+ else {
+ DateType dateType = (DateType) type;
+ long days = TimeUnit.MILLISECONDS.toDays(vector.get(i));
+ dateType.writeLong(builder, days);
+ }
+ }
+ }
+
+ public void assignBlockFromTimeSecVector(TimeSecVector vector, Type type, BlockBuilder builder, int startIndex, int endIndex)
+ {
+ if (!(type instanceof TimeType)) {
+ throw new IllegalArgumentException("Type must be a TimeType for TimeSecVector");
+ }
+
+ for (int i = startIndex; i < endIndex; i++) {
+ if (vector.isNull(i)) {
+ builder.appendNull();
+ }
+ else {
+ int value = vector.get(i);
+ long millis = TimeUnit.SECONDS.toMillis(value);
+ type.writeLong(builder, millis);
+ }
+ }
+ }
+
+ public void assignBlockFromTimeMilliVector(TimeMilliVector vector, Type type, BlockBuilder builder, int startIndex, int endIndex)
+ {
+ if (!(type instanceof TimeType)) {
+ throw new IllegalArgumentException("Type must be a TimeType for TimeSecVector");
+ }
+
+ for (int i = startIndex; i < endIndex; i++) {
+ if (vector.isNull(i)) {
+ builder.appendNull();
+ }
+ else {
+ long millis = vector.get(i);
+ type.writeLong(builder, millis);
+ }
+ }
+ }
+
+ public void assignBlockFromTimeMicroVector(TimeMicroVector vector, Type type, BlockBuilder builder, int startIndex, int endIndex)
+ {
+ if (!(type instanceof TimeType)) {
+ throw new IllegalArgumentException("Type must be a TimeType for TimemicroVector");
+ }
+ for (int i = startIndex; i < endIndex; i++) {
+ if (vector.isNull(i)) {
+ builder.appendNull();
+ }
+ else {
+ long value = vector.get(i);
+ long micro = TimeUnit.MICROSECONDS.toMillis(value);
+ type.writeLong(builder, micro);
+ }
+ }
+ }
+
+ public void assignBlockFromTimeStampSecVector(TimeStampSecVector vector, Type type, BlockBuilder builder, int startIndex, int endIndex)
+ {
+ if (!(type instanceof TimestampType)) {
+ throw new IllegalArgumentException("Type must be a TimestampType for TimeStampSecVector");
+ }
+
+ for (int i = startIndex; i < endIndex; i++) {
+ if (vector.isNull(i)) {
+ builder.appendNull();
+ }
+ else {
+ long value = vector.get(i);
+ long millis = TimeUnit.SECONDS.toMillis(value);
+ type.writeLong(builder, millis);
+ }
+ }
+ }
+
+ public void assignCharTypeBlockFromVarcharVector(VarCharVector vector, Type type, BlockBuilder builder, int startIndex, int endIndex)
+ {
+ for (int i = startIndex; i < endIndex; i++) {
+ if (vector.isNull(i)) {
+ builder.appendNull();
+ }
+ else {
+ String value = new String(vector.get(i), StandardCharsets.UTF_8);
+ type.writeSlice(builder, Slices.utf8Slice(CharMatcher.is(' ').trimTrailingFrom(value)));
+ }
+ }
+ }
+
+ public void assignTimeTypeBlockFromVarcharVector(VarCharVector vector, Type type, BlockBuilder builder, int startIndex, int endIndex)
+ {
+ for (int i = startIndex; i < endIndex; i++) {
+ if (vector.isNull(i)) {
+ builder.appendNull();
+ }
+ else {
+ String timeString = new String(vector.get(i), StandardCharsets.UTF_8);
+ LocalTime time = LocalTime.parse(timeString);
+ long millis = Duration.between(LocalTime.MIN, time).toMillis();
+ type.writeLong(builder, millis);
+ }
+ }
+ }
+
+ public void assignBlockFromTimeMilliTZVector(TimeStampMilliTZVector vector, Type type, BlockBuilder builder, int startIndex, int endIndex)
+ {
+ if (!(type instanceof TimestampType)) {
+ throw new IllegalArgumentException("Type must be a TimestampType for TimeStampMilliTZVector");
+ }
+
+ for (int i = startIndex; i < endIndex; i++) {
+ if (vector.isNull(i)) {
+ builder.appendNull();
+ }
+ else {
+ long millis = vector.get(i);
+ type.writeLong(builder, millis);
+ }
+ }
+ }
+
+ public void assignBlockFromListVector(ListVector vector, Type type, BlockBuilder builder, int startIndex, int endIndex)
+ {
+ if (!(type instanceof ArrayType)) {
+ throw new IllegalArgumentException("Type must be an ArrayType for ListVector");
+ }
+
+ ArrayType arrayType = (ArrayType) type;
+ Type elementType = arrayType.getElementType();
+
+ for (int i = startIndex; i < endIndex; i++) {
+ if (vector.isNull(i)) {
+ builder.appendNull();
+ }
+ else {
+ BlockBuilder elementBuilder = builder.beginBlockEntry();
+ assignBlockFromValueVector(
+ vector.getDataVector(), elementType, elementBuilder, vector.getElementStartIndex(i), vector.getElementEndIndex(i));
+ builder.closeEntry();
+ }
+ }
+ }
+
+ public void assignBlockFromMapVector(MapVector vector, Type type, BlockBuilder builder, int startIndex, int endIndex)
+ {
+ if (!(type instanceof MapType)) {
+ throw new IllegalArgumentException("Type must be a MapType for MapVector");
+ }
+
+ MapType mapType = (MapType) type;
+ StructVector entryVector = (StructVector) vector.getDataVector();
+ ValueVector keyVector = entryVector.getChildByOrdinal(0);
+ ValueVector valueVector = entryVector.getChildByOrdinal(1);
+
+ for (int i = startIndex; i < endIndex; i++) {
+ if (vector.isNull(i)) {
+ builder.appendNull();
+ }
+ else {
+ BlockBuilder entryBuilder = builder.beginBlockEntry();
+ int entryStart = vector.getElementStartIndex(i);
+ int entryEnd = vector.getElementEndIndex(i);
+ for (int entryIndex = entryStart; entryIndex < entryEnd; entryIndex++) {
+ assignBlockFromValueVector(keyVector, mapType.getKeyType(), entryBuilder, entryIndex, entryIndex + 1);
+ assignBlockFromValueVector(valueVector, mapType.getValueType(), entryBuilder, entryIndex, entryIndex + 1);
+ }
+ builder.closeEntry();
+ }
+ }
+ }
+
+ public void assignBlockFromStructVector(StructVector vector, Type type, BlockBuilder builder, int startIndex, int endIndex)
+ {
+ if (!(type instanceof RowType)) {
+ throw new IllegalArgumentException("Type must be a RowType for StructVector");
+ }
+
+ RowType rowType = (RowType) type;
+
+ for (int i = startIndex; i < endIndex; i++) {
+ if (vector.isNull(i)) {
+ builder.appendNull();
+ }
+ else {
+ BlockBuilder childBuilder = builder.beginBlockEntry();
+ List childTypes = rowType.getTypeParameters();
+ for (int childIndex = 0; childIndex < childTypes.size(); childIndex++) {
+ Type childType = childTypes.get(childIndex);
+ ValueVector childVector = vector.getChildByOrdinal(childIndex);
+ assignBlockFromValueVector(childVector, childType, childBuilder, i, i + 1);
+ }
+ builder.closeEntry();
+ }
+ }
+ }
+}
diff --git a/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowColumnHandle.java b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowColumnHandle.java
new file mode 100644
index 0000000000000..6a1430b007822
--- /dev/null
+++ b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowColumnHandle.java
@@ -0,0 +1,61 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.facebook.plugin.arrow;
+
+import com.facebook.presto.common.type.Type;
+import com.facebook.presto.spi.ColumnHandle;
+import com.facebook.presto.spi.ColumnMetadata;
+import com.fasterxml.jackson.annotation.JsonCreator;
+import com.fasterxml.jackson.annotation.JsonProperty;
+
+import static java.util.Objects.requireNonNull;
+
+public class ArrowColumnHandle
+ implements ColumnHandle
+{
+ private final String columnName;
+ private final Type columnType;
+
+ @JsonCreator
+ public ArrowColumnHandle(
+ @JsonProperty("columnName") String columnName,
+ @JsonProperty("columnType") Type columnType)
+ {
+ this.columnName = requireNonNull(columnName, "columnName is null");
+ this.columnType = requireNonNull(columnType, "columnType is null");
+ }
+
+ @JsonProperty
+ public String getColumnName()
+ {
+ return columnName;
+ }
+
+ @JsonProperty
+ public Type getColumnType()
+ {
+ return columnType;
+ }
+
+ public ColumnMetadata getColumnMetadata()
+ {
+ return new ColumnMetadata(columnName, columnType);
+ }
+
+ @Override
+ public String toString()
+ {
+ return columnName + ":" + columnType;
+ }
+}
diff --git a/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowConnector.java b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowConnector.java
new file mode 100644
index 0000000000000..3bae2b16c484b
--- /dev/null
+++ b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowConnector.java
@@ -0,0 +1,77 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.facebook.plugin.arrow;
+
+import com.facebook.presto.spi.connector.Connector;
+import com.facebook.presto.spi.connector.ConnectorMetadata;
+import com.facebook.presto.spi.connector.ConnectorPageSourceProvider;
+import com.facebook.presto.spi.connector.ConnectorSplitManager;
+import com.facebook.presto.spi.connector.ConnectorTransactionHandle;
+import com.facebook.presto.spi.transaction.IsolationLevel;
+import com.google.inject.Inject;
+import org.apache.arrow.memory.BufferAllocator;
+
+import static java.util.Objects.requireNonNull;
+
+public class ArrowConnector
+ implements Connector
+{
+ private final ConnectorMetadata metadata;
+ private final ConnectorSplitManager splitManager;
+ private final ConnectorPageSourceProvider pageSourceProvider;
+ private final BufferAllocator connectorAllocator;
+
+ @Inject
+ public ArrowConnector(
+ ConnectorMetadata metadata,
+ ConnectorSplitManager splitManager,
+ ConnectorPageSourceProvider pageSourceProvider,
+ BufferAllocator connectorAllocator)
+ {
+ this.metadata = requireNonNull(metadata, "Metadata is null");
+ this.splitManager = requireNonNull(splitManager, "splitManager is null");
+ this.pageSourceProvider = requireNonNull(pageSourceProvider, "pageSourceProvider is null");
+ this.connectorAllocator = requireNonNull(connectorAllocator, "connectorAllocator is null");
+ }
+
+ @Override
+ public ConnectorTransactionHandle beginTransaction(IsolationLevel isolationLevel, boolean readOnly)
+ {
+ return ArrowTransactionHandle.INSTANCE;
+ }
+
+ @Override
+ public ConnectorMetadata getMetadata(ConnectorTransactionHandle transactionHandle)
+ {
+ return metadata;
+ }
+
+ @Override
+ public ConnectorSplitManager getSplitManager()
+ {
+ return splitManager;
+ }
+
+ @Override
+ public ConnectorPageSourceProvider getPageSourceProvider()
+ {
+ return pageSourceProvider;
+ }
+
+ @Override
+ public void shutdown()
+ {
+ connectorAllocator.close();
+ }
+}
diff --git a/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowConnectorFactory.java b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowConnectorFactory.java
new file mode 100644
index 0000000000000..ce95cccdd051e
--- /dev/null
+++ b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowConnectorFactory.java
@@ -0,0 +1,104 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.facebook.plugin.arrow;
+
+import com.facebook.airlift.bootstrap.Bootstrap;
+import com.facebook.presto.common.type.TypeManager;
+import com.facebook.presto.spi.ConnectorHandleResolver;
+import com.facebook.presto.spi.NodeManager;
+import com.facebook.presto.spi.classloader.ThreadContextClassLoader;
+import com.facebook.presto.spi.connector.Connector;
+import com.facebook.presto.spi.connector.ConnectorContext;
+import com.facebook.presto.spi.connector.ConnectorFactory;
+import com.facebook.presto.spi.function.FunctionMetadataManager;
+import com.facebook.presto.spi.function.StandardFunctionResolution;
+import com.facebook.presto.spi.relation.RowExpressionService;
+import com.google.common.collect.ImmutableList;
+import com.google.inject.ConfigurationException;
+import com.google.inject.Injector;
+import com.google.inject.Module;
+
+import java.util.List;
+import java.util.Map;
+
+import static com.facebook.plugin.arrow.ArrowErrorCode.ARROW_INTERNAL_ERROR;
+import static com.google.common.base.Preconditions.checkArgument;
+import static com.google.common.base.Strings.isNullOrEmpty;
+import static com.google.common.base.Throwables.throwIfUnchecked;
+import static com.google.inject.util.Modules.override;
+import static java.util.Objects.requireNonNull;
+
+public class ArrowConnectorFactory
+ implements ConnectorFactory
+{
+ private final String name;
+ private final Module module;
+ private final ImmutableList extraModules;
+ private final ClassLoader classLoader;
+
+ public ArrowConnectorFactory(String name, Module module, List extraModules, ClassLoader classLoader)
+ {
+ checkArgument(!isNullOrEmpty(name), "name is null or empty");
+ this.name = name;
+ this.module = requireNonNull(module, "module is null");
+ this.extraModules = ImmutableList.copyOf(requireNonNull(extraModules, "extraModules is null"));
+ this.classLoader = requireNonNull(classLoader, "classLoader is null");
+ }
+
+ @Override
+ public String getName()
+ {
+ return name;
+ }
+
+ @Override
+ public ConnectorHandleResolver getHandleResolver()
+ {
+ return new ArrowHandleResolver();
+ }
+
+ @Override
+ public Connector create(String catalogName, Map requiredConfig, ConnectorContext context)
+ {
+ requireNonNull(requiredConfig, "requiredConfig is null");
+
+ try (ThreadContextClassLoader ignored = new ThreadContextClassLoader(classLoader)) {
+ Bootstrap app = new Bootstrap(ImmutableList.builder()
+ .add(binder -> {
+ binder.bind(TypeManager.class).toInstance(context.getTypeManager());
+ binder.bind(FunctionMetadataManager.class).toInstance(context.getFunctionMetadataManager());
+ binder.bind(StandardFunctionResolution.class).toInstance(context.getStandardFunctionResolution());
+ binder.bind(RowExpressionService.class).toInstance(context.getRowExpressionService());
+ binder.bind(NodeManager.class).toInstance(context.getNodeManager());
+ })
+ .add(override(new ArrowModule(catalogName)).with(module))
+ .addAll(extraModules)
+ .build());
+
+ Injector injector = app
+ .doNotInitializeLogging()
+ .setRequiredConfigurationProperties(requiredConfig)
+ .initialize();
+
+ return injector.getInstance(ArrowConnector.class);
+ }
+ catch (ConfigurationException ex) {
+ throw new ArrowException(ARROW_INTERNAL_ERROR, "The connector instance could not be created.", ex);
+ }
+ catch (Exception e) {
+ throwIfUnchecked(e);
+ throw new RuntimeException(e);
+ }
+ }
+}
diff --git a/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowConnectorId.java b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowConnectorId.java
new file mode 100644
index 0000000000000..dce08bac4ac24
--- /dev/null
+++ b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowConnectorId.java
@@ -0,0 +1,53 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.facebook.plugin.arrow;
+
+import java.util.Objects;
+
+import static java.util.Objects.requireNonNull;
+
+public class ArrowConnectorId
+{
+ private final String id;
+
+ public ArrowConnectorId(String id)
+ {
+ this.id = requireNonNull(id, "id is null");
+ }
+
+ @Override
+ public String toString()
+ {
+ return id;
+ }
+
+ @Override
+ public int hashCode()
+ {
+ return Objects.hash(id);
+ }
+
+ @Override
+ public boolean equals(Object obj)
+ {
+ if (this == obj) {
+ return true;
+ }
+ if ((obj == null) || (getClass() != obj.getClass())) {
+ return false;
+ }
+ ArrowConnectorId other = (ArrowConnectorId) obj;
+ return Objects.equals(this.id, other.id);
+ }
+}
diff --git a/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowErrorCode.java b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowErrorCode.java
new file mode 100644
index 0000000000000..6b217e4f56cdf
--- /dev/null
+++ b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowErrorCode.java
@@ -0,0 +1,44 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.facebook.plugin.arrow;
+
+import com.facebook.presto.common.ErrorCode;
+import com.facebook.presto.common.ErrorType;
+import com.facebook.presto.spi.ErrorCodeSupplier;
+
+import static com.facebook.presto.common.ErrorType.EXTERNAL;
+import static com.facebook.presto.common.ErrorType.INTERNAL_ERROR;
+
+public enum ArrowErrorCode
+ implements ErrorCodeSupplier
+{
+ ARROW_FLIGHT_INFO_ERROR(0, EXTERNAL),
+ ARROW_INTERNAL_ERROR(1, INTERNAL_ERROR),
+ ARROW_FLIGHT_CLIENT_ERROR(2, EXTERNAL),
+ ARROW_FLIGHT_METADATA_ERROR(3, EXTERNAL),
+ ARROW_FLIGHT_TYPE_ERROR(4, EXTERNAL);
+
+ private final ErrorCode errorCode;
+
+ ArrowErrorCode(int code, ErrorType type)
+ {
+ errorCode = new ErrorCode(code + 0x0510_0000, name(), type);
+ }
+
+ @Override
+ public ErrorCode toErrorCode()
+ {
+ return errorCode;
+ }
+}
diff --git a/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowException.java b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowException.java
new file mode 100644
index 0000000000000..ba2c6edba589c
--- /dev/null
+++ b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowException.java
@@ -0,0 +1,31 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.facebook.plugin.arrow;
+
+import com.facebook.presto.spi.ErrorCodeSupplier;
+import com.facebook.presto.spi.PrestoException;
+
+public class ArrowException
+ extends PrestoException
+{
+ public ArrowException(ErrorCodeSupplier errorCode, String message)
+ {
+ super(errorCode, message);
+ }
+
+ public ArrowException(ErrorCodeSupplier errorCode, String message, Throwable cause)
+ {
+ super(errorCode, message, cause);
+ }
+}
diff --git a/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowFlightConfig.java b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowFlightConfig.java
new file mode 100644
index 0000000000000..00698846a85a3
--- /dev/null
+++ b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowFlightConfig.java
@@ -0,0 +1,85 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.facebook.plugin.arrow;
+
+import com.facebook.airlift.configuration.Config;
+
+public class ArrowFlightConfig
+{
+ private String server;
+ private Boolean verifyServer;
+ private String flightServerSSLCertificate;
+ private Boolean arrowFlightServerSslEnabled;
+ private Integer arrowFlightPort;
+
+ public String getFlightServerName()
+ {
+ return server;
+ }
+
+ @Config("arrow-flight.server")
+ public ArrowFlightConfig setFlightServerName(String server)
+ {
+ this.server = server;
+ return this;
+ }
+
+ public Boolean getVerifyServer()
+ {
+ return verifyServer;
+ }
+
+ @Config("arrow-flight.server.verify")
+ public ArrowFlightConfig setVerifyServer(Boolean verifyServer)
+ {
+ this.verifyServer = verifyServer;
+ return this;
+ }
+
+ public Integer getArrowFlightPort()
+ {
+ return arrowFlightPort;
+ }
+
+ @Config("arrow-flight.server.port")
+ public ArrowFlightConfig setArrowFlightPort(Integer arrowFlightPort)
+ {
+ this.arrowFlightPort = arrowFlightPort;
+ return this;
+ }
+
+ public String getFlightServerSSLCertificate()
+ {
+ return flightServerSSLCertificate;
+ }
+
+ @Config("arrow-flight.server-ssl-certificate")
+ public ArrowFlightConfig setFlightServerSSLCertificate(String flightServerSSLCertificate)
+ {
+ this.flightServerSSLCertificate = flightServerSSLCertificate;
+ return this;
+ }
+
+ public Boolean getArrowFlightServerSslEnabled()
+ {
+ return arrowFlightServerSslEnabled;
+ }
+
+ @Config("arrow-flight.server-ssl-enabled")
+ public ArrowFlightConfig setArrowFlightServerSslEnabled(Boolean arrowFlightServerSslEnabled)
+ {
+ this.arrowFlightServerSslEnabled = arrowFlightServerSslEnabled;
+ return this;
+ }
+}
diff --git a/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowHandleResolver.java b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowHandleResolver.java
new file mode 100644
index 0000000000000..8b231b98a6ee6
--- /dev/null
+++ b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowHandleResolver.java
@@ -0,0 +1,55 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.facebook.plugin.arrow;
+
+import com.facebook.presto.spi.ColumnHandle;
+import com.facebook.presto.spi.ConnectorHandleResolver;
+import com.facebook.presto.spi.ConnectorSplit;
+import com.facebook.presto.spi.ConnectorTableHandle;
+import com.facebook.presto.spi.ConnectorTableLayoutHandle;
+import com.facebook.presto.spi.connector.ConnectorTransactionHandle;
+
+public class ArrowHandleResolver
+ implements ConnectorHandleResolver
+{
+ @Override
+ public Class extends ConnectorTableHandle> getTableHandleClass()
+ {
+ return ArrowTableHandle.class;
+ }
+
+ @Override
+ public Class extends ConnectorTableLayoutHandle> getTableLayoutHandleClass()
+ {
+ return ArrowTableLayoutHandle.class;
+ }
+
+ @Override
+ public Class extends ColumnHandle> getColumnHandleClass()
+ {
+ return ArrowColumnHandle.class;
+ }
+
+ @Override
+ public Class extends ConnectorSplit> getSplitClass()
+ {
+ return ArrowSplit.class;
+ }
+
+ @Override
+ public Class extends ConnectorTransactionHandle> getTransactionHandleClass()
+ {
+ return ArrowTransactionHandle.class;
+ }
+}
diff --git a/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowMetadata.java b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowMetadata.java
new file mode 100644
index 0000000000000..977895966b291
--- /dev/null
+++ b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowMetadata.java
@@ -0,0 +1,194 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.facebook.plugin.arrow;
+
+import com.facebook.presto.common.type.Type;
+import com.facebook.presto.spi.ColumnHandle;
+import com.facebook.presto.spi.ColumnMetadata;
+import com.facebook.presto.spi.ConnectorSession;
+import com.facebook.presto.spi.ConnectorTableHandle;
+import com.facebook.presto.spi.ConnectorTableLayout;
+import com.facebook.presto.spi.ConnectorTableLayoutHandle;
+import com.facebook.presto.spi.ConnectorTableLayoutResult;
+import com.facebook.presto.spi.ConnectorTableMetadata;
+import com.facebook.presto.spi.Constraint;
+import com.facebook.presto.spi.NotFoundException;
+import com.facebook.presto.spi.SchemaTableName;
+import com.facebook.presto.spi.SchemaTablePrefix;
+import com.facebook.presto.spi.connector.ConnectorMetadata;
+import com.google.common.collect.ImmutableList;
+import com.google.common.collect.ImmutableMap;
+import org.apache.arrow.vector.types.pojo.Field;
+import org.apache.arrow.vector.types.pojo.Schema;
+
+import javax.inject.Inject;
+
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Optional;
+import java.util.Set;
+
+import static com.facebook.plugin.arrow.ArrowErrorCode.ARROW_FLIGHT_METADATA_ERROR;
+import static com.google.common.base.Preconditions.checkArgument;
+import static java.util.Objects.requireNonNull;
+
+public class ArrowMetadata
+ implements ConnectorMetadata
+{
+ private final BaseArrowFlightClientHandler clientHandler;
+ private final ArrowBlockBuilder arrowBlockBuilder;
+
+ @Inject
+ public ArrowMetadata(BaseArrowFlightClientHandler clientHandler, ArrowBlockBuilder arrowBlockBuilder)
+ {
+ this.clientHandler = requireNonNull(clientHandler, "clientHandler is null");
+ this.arrowBlockBuilder = requireNonNull(arrowBlockBuilder, "arrowBlockBuilder is null");
+ }
+
+ @Override
+ public final List listSchemaNames(ConnectorSession session)
+ {
+ return clientHandler.listSchemaNames(session);
+ }
+
+ @Override
+ public final List listTables(ConnectorSession session, Optional schemaName)
+ {
+ return clientHandler.listTables(session, schemaName);
+ }
+
+ @Override
+ public ConnectorTableHandle getTableHandle(ConnectorSession session, SchemaTableName tableName)
+ {
+ if (!listSchemaNames(session).contains(tableName.getSchemaName())) {
+ return null;
+ }
+
+ if (!listTables(session, Optional.ofNullable(tableName.getSchemaName())).contains(tableName)) {
+ return null;
+ }
+ return new ArrowTableHandle(tableName.getSchemaName(), tableName.getTableName());
+ }
+
+ public List getColumnsList(String schema, String table, ConnectorSession connectorSession)
+ {
+ try {
+ Schema flightSchema = clientHandler.getSchemaForTable(schema, table, connectorSession);
+ return flightSchema.getFields();
+ }
+ catch (Exception e) {
+ throw new ArrowException(ARROW_FLIGHT_METADATA_ERROR, "Table columns could not be listed for table: " + table, e);
+ }
+ }
+
+ @Override
+ public Map getColumnHandles(ConnectorSession session, ConnectorTableHandle tableHandle)
+ {
+ Map columnHandles = new HashMap<>();
+
+ String schemaValue = ((ArrowTableHandle) tableHandle).getSchema();
+ String tableValue = ((ArrowTableHandle) tableHandle).getTable();
+ List columnList = getColumnsList(schemaValue, tableValue, session);
+
+ for (Field field : columnList) {
+ String columnName = field.getName();
+ Type type = getPrestoTypeFromArrowField(field);
+ columnHandles.put(columnName, new ArrowColumnHandle(columnName, type));
+ }
+ return columnHandles;
+ }
+
+ @Override
+ public List getTableLayouts(ConnectorSession session, ConnectorTableHandle table, Constraint constraint, Optional> desiredColumns)
+ {
+ checkArgument(table instanceof ArrowTableHandle,
+ "Invalid table handle: Expected an instance of ArrowTableHandle but received %s",
+ table.getClass().getSimpleName());
+ checkArgument(desiredColumns.orElse(Collections.emptySet()).stream().allMatch(f -> f instanceof ArrowColumnHandle),
+ "Invalid column handles: Expected desired columns to be of type ArrowColumnHandle");
+
+ ArrowTableHandle tableHandle = (ArrowTableHandle) table;
+
+ List columns = new ArrayList<>();
+ if (desiredColumns.isPresent()) {
+ List arrowColumns = new ArrayList<>(desiredColumns.get());
+ columns = (List) (List>) arrowColumns;
+ }
+
+ ConnectorTableLayout layout = new ConnectorTableLayout(new ArrowTableLayoutHandle(tableHandle, columns, constraint.getSummary()));
+ return ImmutableList.of(new ConnectorTableLayoutResult(layout, constraint.getSummary()));
+ }
+
+ @Override
+ public ConnectorTableLayout getTableLayout(ConnectorSession session, ConnectorTableLayoutHandle handle)
+ {
+ return new ConnectorTableLayout(handle);
+ }
+
+ @Override
+ public ConnectorTableMetadata getTableMetadata(ConnectorSession session, ConnectorTableHandle table)
+ {
+ List meta = new ArrayList<>();
+ List columnList = getColumnsList(((ArrowTableHandle) table).getSchema(), ((ArrowTableHandle) table).getTable(), session);
+
+ for (Field field : columnList) {
+ String columnName = field.getName();
+ Type fieldType = getPrestoTypeFromArrowField(field);
+ meta.add(new ColumnMetadata(columnName, fieldType));
+ }
+ return new ConnectorTableMetadata(new SchemaTableName(((ArrowTableHandle) table).getSchema(), ((ArrowTableHandle) table).getTable()), meta);
+ }
+
+ @Override
+ public ColumnMetadata getColumnMetadata(ConnectorSession session, ConnectorTableHandle tableHandle, ColumnHandle columnHandle)
+ {
+ return ((ArrowColumnHandle) columnHandle).getColumnMetadata();
+ }
+
+ @Override
+ public Map> listTableColumns(ConnectorSession session, SchemaTablePrefix prefix)
+ {
+ requireNonNull(prefix, "prefix is null");
+ ImmutableMap.Builder> columns = ImmutableMap.builder();
+ List tables;
+ if (prefix.getSchemaName() != null && prefix.getTableName() != null) {
+ tables = ImmutableList.of(new SchemaTableName(prefix.getSchemaName(), prefix.getTableName()));
+ }
+ else {
+ tables = listTables(session, Optional.of(prefix.getSchemaName()));
+ }
+
+ for (SchemaTableName tableName : tables) {
+ try {
+ ConnectorTableHandle tableHandle = getTableHandle(session, tableName);
+ columns.put(tableName, getTableMetadata(session, tableHandle).getColumns());
+ }
+ catch (ClassCastException | NotFoundException e) {
+ throw new ArrowException(ARROW_FLIGHT_METADATA_ERROR, "Table columns could not be listed for table: " + tableName, e);
+ }
+ catch (Exception e) {
+ throw new ArrowException(ARROW_FLIGHT_METADATA_ERROR, e.getMessage(), e);
+ }
+ }
+ return columns.build();
+ }
+
+ private Type getPrestoTypeFromArrowField(Field field)
+ {
+ return arrowBlockBuilder.getPrestoTypeFromArrowField(field);
+ }
+}
diff --git a/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowModule.java b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowModule.java
new file mode 100644
index 0000000000000..ed85f961b291e
--- /dev/null
+++ b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowModule.java
@@ -0,0 +1,54 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.facebook.plugin.arrow;
+
+import com.facebook.presto.spi.ConnectorHandleResolver;
+import com.facebook.presto.spi.connector.Connector;
+import com.facebook.presto.spi.connector.ConnectorMetadata;
+import com.facebook.presto.spi.connector.ConnectorPageSourceProvider;
+import com.facebook.presto.spi.connector.ConnectorSplitManager;
+import com.google.inject.Binder;
+import com.google.inject.Module;
+import com.google.inject.Scopes;
+import org.apache.arrow.memory.BufferAllocator;
+import org.apache.arrow.memory.RootAllocator;
+
+import static com.facebook.airlift.configuration.ConfigBinder.configBinder;
+import static java.util.Objects.requireNonNull;
+
+public class ArrowModule
+ implements Module
+{
+ protected final String connectorId;
+
+ public ArrowModule(String connectorId)
+ {
+ this.connectorId = requireNonNull(connectorId, "connector id is null");
+ }
+
+ public void configure(Binder binder)
+ {
+ configBinder(binder).bindConfig(ArrowFlightConfig.class);
+ binder.bind(BufferAllocator.class).to(RootAllocator.class).in(Scopes.SINGLETON);
+ binder.bind(ConnectorSplitManager.class).to(ArrowSplitManager.class).in(Scopes.SINGLETON);
+ binder.bind(ConnectorMetadata.class).to(ArrowMetadata.class).in(Scopes.SINGLETON);
+ binder.bind(ArrowConnector.class).in(Scopes.SINGLETON);
+ binder.bind(ArrowConnectorId.class).toInstance(new ArrowConnectorId(connectorId));
+ binder.bind(ConnectorHandleResolver.class).to(ArrowHandleResolver.class).in(Scopes.SINGLETON);
+ binder.bind(ArrowPageSourceProvider.class).in(Scopes.SINGLETON);
+ binder.bind(ConnectorPageSourceProvider.class).to(ArrowPageSourceProvider.class).in(Scopes.SINGLETON);
+ binder.bind(Connector.class).to(ArrowConnector.class).in(Scopes.SINGLETON);
+ binder.bind(ArrowBlockBuilder.class).to(ArrowBlockBuilder.class).in(Scopes.SINGLETON);
+ }
+}
diff --git a/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowPageSource.java b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowPageSource.java
new file mode 100644
index 0000000000000..fcacb204da7d4
--- /dev/null
+++ b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowPageSource.java
@@ -0,0 +1,125 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.facebook.plugin.arrow;
+
+import com.facebook.airlift.log.Logger;
+import com.facebook.presto.common.Page;
+import com.facebook.presto.common.block.Block;
+import com.facebook.presto.common.type.Type;
+import com.facebook.presto.spi.ConnectorPageSource;
+import com.facebook.presto.spi.ConnectorSession;
+import org.apache.arrow.vector.FieldVector;
+
+import java.util.ArrayList;
+import java.util.List;
+
+import static com.facebook.plugin.arrow.ArrowErrorCode.ARROW_FLIGHT_CLIENT_ERROR;
+import static java.util.Objects.requireNonNull;
+
+public class ArrowPageSource
+ implements ConnectorPageSource
+{
+ private static final Logger logger = Logger.get(ArrowPageSource.class);
+ private final List columnHandles;
+ private final ArrowBlockBuilder arrowBlockBuilder;
+ private final ClientClosingFlightStream flightStreamAndClient;
+ private boolean completed;
+ private int currentPosition;
+
+ public ArrowPageSource(
+ ArrowSplit split,
+ List columnHandles,
+ BaseArrowFlightClientHandler clientHandler,
+ ConnectorSession connectorSession,
+ ArrowBlockBuilder arrowBlockBuilder)
+ {
+ requireNonNull(split, "split is null");
+ this.columnHandles = requireNonNull(columnHandles, "columnHandles is null");
+ requireNonNull(clientHandler, "clientHandler is null");
+ this.arrowBlockBuilder = requireNonNull(arrowBlockBuilder, "arrowBlockBuilder is null");
+ this.flightStreamAndClient = clientHandler.getFlightStream(split, connectorSession);
+ }
+
+ @Override
+ public long getCompletedBytes()
+ {
+ return 0;
+ }
+
+ @Override
+ public long getCompletedPositions()
+ {
+ return currentPosition;
+ }
+
+ @Override
+ public long getReadTimeNanos()
+ {
+ return 0;
+ }
+
+ @Override
+ public boolean isFinished()
+ {
+ return completed;
+ }
+
+ @Override
+ public long getSystemMemoryUsage()
+ {
+ return 0;
+ }
+
+ @Override
+ public Page getNextPage()
+ {
+ logger.debug("Reading next Arrow record batch");
+
+ if (!flightStreamAndClient.next()) {
+ // No more streams, end pages
+ completed = true;
+ logger.debug("Finished reading Arrow record batches");
+ return null;
+ }
+
+ currentPosition = currentPosition + 1;
+
+ // Create blocks from the loaded Arrow record batch
+ List blocks = new ArrayList<>();
+ List vectors = flightStreamAndClient.getRoot().getFieldVectors();
+ for (int columnIndex = 0; columnIndex < columnHandles.size(); columnIndex++) {
+ FieldVector vector = vectors.get(columnIndex);
+ Type type = columnHandles.get(columnIndex).getColumnType();
+ Block block = arrowBlockBuilder.buildBlockFromFieldVector(vector, type, flightStreamAndClient.getDictionaryProvider());
+ blocks.add(block);
+ }
+
+ if (logger.isDebugEnabled()) {
+ logger.debug("Read Arrow record batch with rows: %s, columns: %s", flightStreamAndClient.getRoot().getRowCount(), vectors.size());
+ }
+
+ return new Page(flightStreamAndClient.getRoot().getRowCount(), blocks.toArray(new Block[0]));
+ }
+
+ @Override
+ public void close()
+ {
+ try {
+ flightStreamAndClient.close();
+ }
+ catch (Exception e) {
+ throw new ArrowException(ARROW_FLIGHT_CLIENT_ERROR, e.getMessage(), e);
+ }
+ }
+}
diff --git a/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowPageSourceProvider.java b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowPageSourceProvider.java
new file mode 100644
index 0000000000000..216939831a2c8
--- /dev/null
+++ b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowPageSourceProvider.java
@@ -0,0 +1,54 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.facebook.plugin.arrow;
+
+import com.facebook.presto.spi.ColumnHandle;
+import com.facebook.presto.spi.ConnectorPageSource;
+import com.facebook.presto.spi.ConnectorSession;
+import com.facebook.presto.spi.ConnectorSplit;
+import com.facebook.presto.spi.SplitContext;
+import com.facebook.presto.spi.connector.ConnectorPageSourceProvider;
+import com.facebook.presto.spi.connector.ConnectorTransactionHandle;
+import com.google.common.collect.ImmutableList;
+
+import javax.inject.Inject;
+
+import java.util.List;
+
+import static java.util.Objects.requireNonNull;
+
+public class ArrowPageSourceProvider
+ implements ConnectorPageSourceProvider
+{
+ private final BaseArrowFlightClientHandler clientHandler;
+ private final ArrowBlockBuilder arrowBlockBuilder;
+
+ @Inject
+ public ArrowPageSourceProvider(BaseArrowFlightClientHandler clientHandler, ArrowBlockBuilder arrowBlockBuilder)
+ {
+ this.clientHandler = requireNonNull(clientHandler, "clientHandler is null");
+ this.arrowBlockBuilder = requireNonNull(arrowBlockBuilder, "arrowBlockBuilder is null");
+ }
+
+ @Override
+ public ConnectorPageSource createPageSource(ConnectorTransactionHandle transactionHandle, ConnectorSession session, ConnectorSplit split, List columns, SplitContext splitContext)
+ {
+ ImmutableList.Builder columnHandles = ImmutableList.builder();
+ for (ColumnHandle handle : columns) {
+ columnHandles.add((ArrowColumnHandle) handle);
+ }
+ ArrowSplit arrowSplit = (ArrowSplit) split;
+ return new ArrowPageSource(arrowSplit, columnHandles.build(), clientHandler, session, arrowBlockBuilder);
+ }
+}
diff --git a/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowPlugin.java b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowPlugin.java
new file mode 100644
index 0000000000000..45197317cf81f
--- /dev/null
+++ b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowPlugin.java
@@ -0,0 +1,51 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.facebook.plugin.arrow;
+
+import com.facebook.presto.spi.Plugin;
+import com.facebook.presto.spi.connector.ConnectorFactory;
+import com.google.common.collect.ImmutableList;
+import com.google.inject.Module;
+
+import static com.google.common.base.MoreObjects.firstNonNull;
+import static com.google.common.base.Preconditions.checkArgument;
+import static com.google.common.base.Strings.isNullOrEmpty;
+import static java.util.Objects.requireNonNull;
+
+public class ArrowPlugin
+ implements Plugin
+{
+ private final String name;
+ private final Module module;
+ private final ImmutableList extraModules;
+
+ public ArrowPlugin(String name, Module module, Module... extraModules)
+ {
+ checkArgument(!isNullOrEmpty(name), "name is null or empty");
+ this.name = name;
+ this.module = requireNonNull(module, "module is null");
+ this.extraModules = ImmutableList.copyOf(requireNonNull(extraModules, "extraModules is null"));
+ }
+
+ private static ClassLoader getClassLoader()
+ {
+ return firstNonNull(Thread.currentThread().getContextClassLoader(), ArrowPlugin.class.getClassLoader());
+ }
+
+ @Override
+ public Iterable getConnectorFactories()
+ {
+ return ImmutableList.of(new ArrowConnectorFactory(name, module, extraModules, getClassLoader()));
+ }
+}
diff --git a/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowSplit.java b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowSplit.java
new file mode 100644
index 0000000000000..9308bd60aa934
--- /dev/null
+++ b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowSplit.java
@@ -0,0 +1,81 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.facebook.plugin.arrow;
+
+import com.facebook.presto.spi.ConnectorSplit;
+import com.facebook.presto.spi.HostAddress;
+import com.facebook.presto.spi.NodeProvider;
+import com.facebook.presto.spi.schedule.NodeSelectionStrategy;
+import com.fasterxml.jackson.annotation.JsonCreator;
+import com.fasterxml.jackson.annotation.JsonProperty;
+
+import javax.annotation.Nullable;
+
+import java.util.Collections;
+import java.util.List;
+
+public class ArrowSplit
+ implements ConnectorSplit
+{
+ private final String schemaName;
+ private final String tableName;
+ private final byte[] flightEndpointBytes;
+
+ @JsonCreator
+ public ArrowSplit(
+ @JsonProperty("schemaName") @Nullable String schemaName,
+ @JsonProperty("tableName") String tableName,
+ @JsonProperty("flightEndpointBytes") byte[] flightEndpointBytes)
+ {
+ this.schemaName = schemaName;
+ this.tableName = tableName;
+ this.flightEndpointBytes = flightEndpointBytes;
+ }
+
+ @Override
+ public NodeSelectionStrategy getNodeSelectionStrategy()
+ {
+ return NodeSelectionStrategy.NO_PREFERENCE;
+ }
+
+ @Override
+ public List getPreferredNodes(NodeProvider nodeProvider)
+ {
+ return Collections.emptyList();
+ }
+
+ @Override
+ public Object getInfo()
+ {
+ return this.getInfoMap();
+ }
+
+ @JsonProperty
+ public String getSchemaName()
+ {
+ return schemaName;
+ }
+
+ @JsonProperty
+ public String getTableName()
+ {
+ return tableName;
+ }
+
+ @JsonProperty
+ public byte[] getFlightEndpointBytes()
+ {
+ return flightEndpointBytes;
+ }
+}
diff --git a/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowSplitManager.java b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowSplitManager.java
new file mode 100644
index 0000000000000..129a5e11f6355
--- /dev/null
+++ b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowSplitManager.java
@@ -0,0 +1,57 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.facebook.plugin.arrow;
+
+import com.facebook.presto.spi.ConnectorSession;
+import com.facebook.presto.spi.ConnectorSplitSource;
+import com.facebook.presto.spi.ConnectorTableLayoutHandle;
+import com.facebook.presto.spi.FixedSplitSource;
+import com.facebook.presto.spi.connector.ConnectorSplitManager;
+import com.facebook.presto.spi.connector.ConnectorTransactionHandle;
+import org.apache.arrow.flight.FlightInfo;
+
+import javax.inject.Inject;
+
+import java.util.List;
+
+import static com.google.common.collect.ImmutableList.toImmutableList;
+import static java.util.Objects.requireNonNull;
+
+public class ArrowSplitManager
+ implements ConnectorSplitManager
+{
+ private final BaseArrowFlightClientHandler clientHandler;
+
+ @Inject
+ public ArrowSplitManager(BaseArrowFlightClientHandler clientHandler)
+ {
+ this.clientHandler = requireNonNull(clientHandler, "clientHandler is null");
+ }
+
+ @Override
+ public ConnectorSplitSource getSplits(ConnectorTransactionHandle transactionHandle, ConnectorSession session, ConnectorTableLayoutHandle layout, SplitSchedulingContext splitSchedulingContext)
+ {
+ ArrowTableLayoutHandle tableLayoutHandle = (ArrowTableLayoutHandle) layout;
+ ArrowTableHandle tableHandle = tableLayoutHandle.getTable();
+ FlightInfo flightInfo = clientHandler.getFlightInfoForTableScan(tableLayoutHandle, session);
+ List splits = flightInfo.getEndpoints()
+ .stream()
+ .map(info -> new ArrowSplit(
+ tableHandle.getSchema(),
+ tableHandle.getTable(),
+ info.serialize().array()))
+ .collect(toImmutableList());
+ return new FixedSplitSource(splits);
+ }
+}
diff --git a/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowTableHandle.java b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowTableHandle.java
new file mode 100644
index 0000000000000..cef04e4372a5a
--- /dev/null
+++ b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowTableHandle.java
@@ -0,0 +1,73 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.facebook.plugin.arrow;
+
+import com.facebook.presto.spi.ConnectorTableHandle;
+import com.fasterxml.jackson.annotation.JsonCreator;
+import com.fasterxml.jackson.annotation.JsonProperty;
+
+import java.util.Objects;
+
+public class ArrowTableHandle
+ implements ConnectorTableHandle
+{
+ private final String schema;
+ private final String table;
+
+ @JsonCreator
+ public ArrowTableHandle(
+ @JsonProperty("schema") String schema,
+ @JsonProperty("table") String table)
+ {
+ this.schema = schema;
+ this.table = table;
+ }
+
+ @JsonProperty("schema")
+ public String getSchema()
+ {
+ return schema;
+ }
+
+ @JsonProperty("table")
+ public String getTable()
+ {
+ return table;
+ }
+
+ @Override
+ public String toString()
+ {
+ return schema + ":" + table;
+ }
+
+ @Override
+ public boolean equals(Object o)
+ {
+ if (this == o) {
+ return true;
+ }
+ if (o == null || getClass() != o.getClass()) {
+ return false;
+ }
+ ArrowTableHandle that = (ArrowTableHandle) o;
+ return Objects.equals(schema, that.schema) && Objects.equals(table, that.table);
+ }
+
+ @Override
+ public int hashCode()
+ {
+ return Objects.hash(schema, table);
+ }
+}
diff --git a/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowTableLayoutHandle.java b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowTableLayoutHandle.java
new file mode 100644
index 0000000000000..dfb4d91985e91
--- /dev/null
+++ b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowTableLayoutHandle.java
@@ -0,0 +1,87 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.facebook.plugin.arrow;
+
+import com.facebook.presto.common.predicate.TupleDomain;
+import com.facebook.presto.spi.ColumnHandle;
+import com.facebook.presto.spi.ConnectorTableLayoutHandle;
+import com.fasterxml.jackson.annotation.JsonCreator;
+import com.fasterxml.jackson.annotation.JsonProperty;
+
+import java.util.List;
+import java.util.Objects;
+
+import static java.util.Objects.requireNonNull;
+
+public class ArrowTableLayoutHandle
+ implements ConnectorTableLayoutHandle
+{
+ private final ArrowTableHandle table;
+ private final List columnHandles;
+ private final TupleDomain tupleDomain;
+
+ @JsonCreator
+ public ArrowTableLayoutHandle(
+ @JsonProperty("table") ArrowTableHandle table,
+ @JsonProperty("columnHandles") List columnHandles,
+ @JsonProperty("tupleDomain") TupleDomain domain)
+ {
+ this.table = requireNonNull(table, "table is null");
+ this.columnHandles = requireNonNull(columnHandles, "columnHandles is null");
+ this.tupleDomain = requireNonNull(domain, "tupleDomain is null");
+ }
+
+ @JsonProperty("table")
+ public ArrowTableHandle getTable()
+ {
+ return table;
+ }
+
+ @JsonProperty("tupleDomain")
+ public TupleDomain getTupleDomain()
+ {
+ return tupleDomain;
+ }
+
+ @JsonProperty("columnHandles")
+ public List getColumnHandles()
+ {
+ return columnHandles;
+ }
+
+ @Override
+ public String toString()
+ {
+ return "table:" + table + ", columnHandles:" + columnHandles + ", tupleDomain:" + tupleDomain;
+ }
+
+ @Override
+ public boolean equals(Object o)
+ {
+ if (this == o) {
+ return true;
+ }
+ if (o == null || getClass() != o.getClass()) {
+ return false;
+ }
+ ArrowTableLayoutHandle arrowTableLayoutHandle = (ArrowTableLayoutHandle) o;
+ return Objects.equals(table, arrowTableLayoutHandle.table) && Objects.equals(columnHandles, arrowTableLayoutHandle.columnHandles) && Objects.equals(tupleDomain, arrowTableLayoutHandle.tupleDomain);
+ }
+
+ @Override
+ public int hashCode()
+ {
+ return Objects.hash(table, columnHandles, tupleDomain);
+ }
+}
diff --git a/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowTransactionHandle.java b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowTransactionHandle.java
new file mode 100644
index 0000000000000..07eb7385cfbcf
--- /dev/null
+++ b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowTransactionHandle.java
@@ -0,0 +1,22 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.facebook.plugin.arrow;
+
+import com.facebook.presto.spi.connector.ConnectorTransactionHandle;
+
+public enum ArrowTransactionHandle
+ implements ConnectorTransactionHandle
+{
+ INSTANCE
+}
diff --git a/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/BaseArrowFlightClientHandler.java b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/BaseArrowFlightClientHandler.java
new file mode 100644
index 0000000000000..573265ef637d4
--- /dev/null
+++ b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/BaseArrowFlightClientHandler.java
@@ -0,0 +1,151 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.facebook.plugin.arrow;
+
+import com.facebook.presto.spi.ConnectorSession;
+import com.facebook.presto.spi.SchemaTableName;
+import org.apache.arrow.flight.CallOption;
+import org.apache.arrow.flight.FlightClient;
+import org.apache.arrow.flight.FlightDescriptor;
+import org.apache.arrow.flight.FlightEndpoint;
+import org.apache.arrow.flight.FlightInfo;
+import org.apache.arrow.flight.FlightRuntimeException;
+import org.apache.arrow.flight.Location;
+import org.apache.arrow.memory.BufferAllocator;
+import org.apache.arrow.vector.types.pojo.Schema;
+
+import java.io.IOException;
+import java.io.InputStream;
+import java.net.URISyntaxException;
+import java.nio.ByteBuffer;
+import java.nio.file.Paths;
+import java.util.List;
+import java.util.Optional;
+
+import static com.facebook.plugin.arrow.ArrowErrorCode.ARROW_FLIGHT_CLIENT_ERROR;
+import static com.facebook.plugin.arrow.ArrowErrorCode.ARROW_FLIGHT_INFO_ERROR;
+import static com.facebook.plugin.arrow.ArrowErrorCode.ARROW_FLIGHT_METADATA_ERROR;
+import static java.nio.file.Files.newInputStream;
+import static java.util.Objects.requireNonNull;
+
+public abstract class BaseArrowFlightClientHandler
+{
+ private final ArrowFlightConfig config;
+ private final BufferAllocator allocator;
+
+ public BaseArrowFlightClientHandler(BufferAllocator allocator, ArrowFlightConfig config)
+ {
+ this.allocator = requireNonNull(allocator, "allocator is null");
+ this.config = requireNonNull(config, "config is null");
+ }
+
+ protected FlightClient createFlightClient()
+ {
+ Location location;
+ if (config.getArrowFlightServerSslEnabled() != null && !config.getArrowFlightServerSslEnabled()) {
+ location = Location.forGrpcInsecure(config.getFlightServerName(), config.getArrowFlightPort());
+ }
+ else {
+ location = Location.forGrpcTls(config.getFlightServerName(), config.getArrowFlightPort());
+ }
+ return createFlightClient(location);
+ }
+
+ protected FlightClient createFlightClient(Location location)
+ {
+ try {
+ Optional trustedCertificate = Optional.empty();
+ FlightClient.Builder flightClientBuilder = FlightClient.builder(allocator, location);
+ if (config.getVerifyServer() != null && !config.getVerifyServer()) {
+ flightClientBuilder.verifyServer(false);
+ }
+ else if (config.getFlightServerSSLCertificate() != null) {
+ trustedCertificate = Optional.of(newInputStream(Paths.get(config.getFlightServerSSLCertificate())));
+ flightClientBuilder.trustedCertificates(trustedCertificate.get()).useTls();
+ }
+
+ FlightClient flightClient = flightClientBuilder.build();
+ if (trustedCertificate.isPresent()) {
+ trustedCertificate.get().close();
+ }
+
+ return flightClient;
+ }
+ catch (Exception e) {
+ throw new ArrowException(ARROW_FLIGHT_CLIENT_ERROR, "Error creating flight client: " + e.getMessage(), e);
+ }
+ }
+
+ public abstract CallOption[] getCallOptions(ConnectorSession connectorSession);
+
+ protected FlightInfo getFlightInfo(FlightDescriptor flightDescriptor, ConnectorSession connectorSession)
+ {
+ try (FlightClient client = createFlightClient()) {
+ CallOption[] callOptions = getCallOptions(connectorSession);
+ return client.getInfo(flightDescriptor, callOptions);
+ }
+ catch (InterruptedException e) {
+ throw new ArrowException(ARROW_FLIGHT_INFO_ERROR, "Error getting flight information: " + e.getMessage(), e);
+ }
+ }
+
+ protected ClientClosingFlightStream getFlightStream(ArrowSplit split, ConnectorSession connectorSession)
+ {
+ ByteBuffer endpointBytes = ByteBuffer.wrap(split.getFlightEndpointBytes());
+ try {
+ FlightEndpoint endpoint = FlightEndpoint.deserialize(endpointBytes);
+ FlightClient client = endpoint.getLocations().stream()
+ .findAny()
+ .map(this::createFlightClient)
+ .orElseGet(this::createFlightClient);
+ return new ClientClosingFlightStream(
+ client.getStream(endpoint.getTicket(), getCallOptions(connectorSession)),
+ client);
+ }
+ catch (FlightRuntimeException | IOException | URISyntaxException e) {
+ throw new ArrowException(ARROW_FLIGHT_CLIENT_ERROR, e.getMessage(), e);
+ }
+ }
+
+ public Schema getSchema(FlightDescriptor flightDescriptor, ConnectorSession connectorSession)
+ {
+ try (FlightClient client = createFlightClient()) {
+ CallOption[] callOptions = this.getCallOptions(connectorSession);
+ return client.getSchema(flightDescriptor, callOptions).getSchema();
+ }
+ catch (InterruptedException e) {
+ throw new ArrowException(ARROW_FLIGHT_METADATA_ERROR, "Error getting schema for flight: " + e.getMessage(), e);
+ }
+ }
+
+ public abstract List listSchemaNames(ConnectorSession session);
+
+ public abstract List listTables(ConnectorSession session, Optional schemaName);
+
+ protected abstract FlightDescriptor getFlightDescriptorForSchema(String schemaName, String tableName);
+
+ protected abstract FlightDescriptor getFlightDescriptorForTableScan(ArrowTableLayoutHandle tableLayoutHandle);
+
+ public Schema getSchemaForTable(String schemaName, String tableName, ConnectorSession connectorSession)
+ {
+ FlightDescriptor flightDescriptor = getFlightDescriptorForSchema(schemaName, tableName);
+ return getSchema(flightDescriptor, connectorSession);
+ }
+
+ public FlightInfo getFlightInfoForTableScan(ArrowTableLayoutHandle tableLayoutHandle, ConnectorSession session)
+ {
+ FlightDescriptor flightDescriptor = getFlightDescriptorForTableScan(tableLayoutHandle);
+ return getFlightInfo(flightDescriptor, session);
+ }
+}
diff --git a/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ClientClosingFlightStream.java b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ClientClosingFlightStream.java
new file mode 100644
index 0000000000000..30ca33d2276a9
--- /dev/null
+++ b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ClientClosingFlightStream.java
@@ -0,0 +1,60 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.facebook.plugin.arrow;
+
+import org.apache.arrow.flight.FlightStream;
+import org.apache.arrow.vector.VectorSchemaRoot;
+import org.apache.arrow.vector.dictionary.DictionaryProvider;
+
+import static java.util.Objects.requireNonNull;
+
+public class ClientClosingFlightStream
+ implements AutoCloseable
+{
+ private final FlightStream flightStream;
+ private final AutoCloseable flightClient;
+
+ public ClientClosingFlightStream(FlightStream flightStream, AutoCloseable flightClient)
+ {
+ this.flightStream = requireNonNull(flightStream, "flightStream is null");
+ this.flightClient = requireNonNull(flightClient, "flightClient is null");
+ }
+
+ public VectorSchemaRoot getRoot()
+ {
+ return flightStream.getRoot();
+ }
+
+ public DictionaryProvider getDictionaryProvider()
+ {
+ return flightStream.getDictionaryProvider();
+ }
+
+ public boolean next()
+ {
+ return flightStream.next();
+ }
+
+ @Override
+ public void close()
+ throws Exception
+ {
+ try {
+ flightStream.close();
+ }
+ finally {
+ flightClient.close();
+ }
+ }
+}
diff --git a/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/ArrowFlightQueryRunner.java b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/ArrowFlightQueryRunner.java
new file mode 100644
index 0000000000000..f105359f0ae73
--- /dev/null
+++ b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/ArrowFlightQueryRunner.java
@@ -0,0 +1,107 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.facebook.plugin.arrow;
+
+import com.facebook.airlift.log.Logger;
+import com.facebook.airlift.log.Logging;
+import com.facebook.plugin.arrow.testingConnector.TestingArrowFlightPlugin;
+import com.facebook.plugin.arrow.testingServer.TestingArrowProducer;
+import com.facebook.presto.Session;
+import com.facebook.presto.tests.DistributedQueryRunner;
+import com.google.common.collect.ImmutableMap;
+import org.apache.arrow.flight.FlightServer;
+import org.apache.arrow.flight.Location;
+import org.apache.arrow.memory.RootAllocator;
+
+import java.io.File;
+import java.util.Map;
+
+import static com.facebook.presto.testing.TestingSession.testSessionBuilder;
+
+public class ArrowFlightQueryRunner
+{
+ private ArrowFlightQueryRunner()
+ {
+ throw new UnsupportedOperationException("This is a utility class and cannot be instantiated");
+ }
+
+ public static DistributedQueryRunner createQueryRunner(int flightServerPort) throws Exception
+ {
+ return createQueryRunner(ImmutableMap.of("arrow-flight.server.port", String.valueOf(flightServerPort)));
+ }
+
+ private static DistributedQueryRunner createQueryRunner(Map catalogProperties) throws Exception
+ {
+ return createQueryRunner(ImmutableMap.of(), catalogProperties);
+ }
+
+ private static DistributedQueryRunner createQueryRunner(
+ Map extraProperties,
+ Map catalogProperties)
+ throws Exception
+ {
+ Session session = testSessionBuilder()
+ .setCatalog("arrow")
+ .setSchema("tpch")
+ .build();
+
+ DistributedQueryRunner queryRunner = DistributedQueryRunner.builder(session).setExtraProperties(extraProperties).build();
+
+ try {
+ queryRunner.installPlugin(new TestingArrowFlightPlugin());
+
+ ImmutableMap.Builder properties = ImmutableMap.builder()
+ .putAll(catalogProperties)
+ .put("arrow-flight.server", "localhost")
+ .put("arrow-flight.server-ssl-enabled", "true")
+ .put("arrow-flight.server-ssl-certificate", "src/test/resources/server.crt")
+ .put("arrow-flight.server.verify", "true");
+
+ queryRunner.createCatalog("arrow", "arrow", properties.build());
+
+ return queryRunner;
+ }
+ catch (Exception e) {
+ throw new RuntimeException("Failed to create ArrowQueryRunner", e);
+ }
+ }
+
+ public static void main(String[] args)
+ throws Exception
+ {
+ Logging.initialize();
+
+ RootAllocator allocator = new RootAllocator(Long.MAX_VALUE);
+ Location serverLocation = Location.forGrpcTls("localhost", 9443);
+ File certChainFile = new File("src/test/resources/server.crt");
+ File privateKeyFile = new File("src/test/resources/server.key");
+ FlightServer server = FlightServer.builder(allocator, serverLocation, new TestingArrowProducer(allocator))
+ .useTls(certChainFile, privateKeyFile)
+ .build();
+
+ server.start();
+
+ Logger log = Logger.get(ArrowFlightQueryRunner.class);
+
+ log.info("Server listening on port " + server.getPort());
+
+ DistributedQueryRunner queryRunner = createQueryRunner(
+ ImmutableMap.of("http-server.http.port", "8080"),
+ ImmutableMap.of("arrow-flight.server.port", String.valueOf(9443)));
+ Thread.sleep(10);
+ log.info("======== SERVER STARTED ========");
+ log.info("\n====\n%s\n====", queryRunner.getCoordinator().getBaseUrl());
+ }
+}
diff --git a/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/ArrowMetadataUtil.java b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/ArrowMetadataUtil.java
new file mode 100644
index 0000000000000..c4ab656d41bb3
--- /dev/null
+++ b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/ArrowMetadataUtil.java
@@ -0,0 +1,75 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.facebook.plugin.arrow;
+
+import com.facebook.airlift.json.JsonCodec;
+import com.facebook.airlift.json.JsonCodecFactory;
+import com.facebook.airlift.json.JsonObjectMapperProvider;
+import com.facebook.presto.common.type.StandardTypes;
+import com.facebook.presto.common.type.Type;
+import com.fasterxml.jackson.databind.DeserializationContext;
+import com.fasterxml.jackson.databind.deser.std.FromStringDeserializer;
+import com.google.common.collect.ImmutableMap;
+
+import java.util.Map;
+
+import static com.facebook.presto.common.type.BigintType.BIGINT;
+import static com.facebook.presto.common.type.VarcharType.VARCHAR;
+import static com.google.common.base.Preconditions.checkArgument;
+import static java.util.Locale.ENGLISH;
+import static org.testng.Assert.assertEquals;
+
+final class ArrowMetadataUtil
+{
+ private ArrowMetadataUtil() {}
+
+ public static final JsonCodec COLUMN_CODEC;
+ public static final JsonCodec TABLE_CODEC;
+
+ static {
+ JsonObjectMapperProvider provider = new JsonObjectMapperProvider();
+ provider.setJsonDeserializers(ImmutableMap.of(Type.class, new TestingTypeDeserializer()));
+ JsonCodecFactory codecFactory = new JsonCodecFactory(provider);
+ COLUMN_CODEC = codecFactory.jsonCodec(ArrowColumnHandle.class);
+ TABLE_CODEC = codecFactory.jsonCodec(ArrowTableHandle.class);
+ }
+
+ public static final class TestingTypeDeserializer
+ extends FromStringDeserializer
+ {
+ private final Map types = ImmutableMap.of(
+ StandardTypes.BIGINT, BIGINT,
+ StandardTypes.VARCHAR, VARCHAR);
+
+ public TestingTypeDeserializer()
+ {
+ super(Type.class);
+ }
+
+ @Override
+ protected Type _deserialize(String value, DeserializationContext context)
+ {
+ Type type = types.get(value.toLowerCase(ENGLISH));
+ checkArgument(type != null, "Unknown type %s", value);
+ return type;
+ }
+ }
+
+ public static void assertJsonRoundTrip(JsonCodec codec, T object)
+ {
+ String json = codec.toJson(object);
+ T copy = codec.fromJson(json);
+ assertEquals(copy, object);
+ }
+}
diff --git a/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestArrowBlockBuilder.java b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestArrowBlockBuilder.java
new file mode 100644
index 0000000000000..f909e93c43774
--- /dev/null
+++ b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestArrowBlockBuilder.java
@@ -0,0 +1,826 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.facebook.plugin.arrow;
+
+import com.facebook.airlift.log.Logger;
+import com.facebook.presto.common.block.ArrayBlock;
+import com.facebook.presto.common.block.Block;
+import com.facebook.presto.common.block.BlockBuilder;
+import com.facebook.presto.common.block.DictionaryBlock;
+import com.facebook.presto.common.type.ArrayType;
+import com.facebook.presto.common.type.BigintType;
+import com.facebook.presto.common.type.BooleanType;
+import com.facebook.presto.common.type.DateType;
+import com.facebook.presto.common.type.DecimalType;
+import com.facebook.presto.common.type.Decimals;
+import com.facebook.presto.common.type.DoubleType;
+import com.facebook.presto.common.type.IntegerType;
+import com.facebook.presto.common.type.RowType;
+import com.facebook.presto.common.type.SmallintType;
+import com.facebook.presto.common.type.TimestampType;
+import com.facebook.presto.common.type.TinyintType;
+import com.facebook.presto.common.type.Type;
+import com.facebook.presto.common.type.VarcharType;
+import io.airlift.slice.Slice;
+import org.apache.arrow.memory.BufferAllocator;
+import org.apache.arrow.memory.RootAllocator;
+import org.apache.arrow.vector.BigIntVector;
+import org.apache.arrow.vector.BitVector;
+import org.apache.arrow.vector.DateDayVector;
+import org.apache.arrow.vector.DecimalVector;
+import org.apache.arrow.vector.FieldVector;
+import org.apache.arrow.vector.Float8Vector;
+import org.apache.arrow.vector.IntVector;
+import org.apache.arrow.vector.SmallIntVector;
+import org.apache.arrow.vector.TimeStampMicroVector;
+import org.apache.arrow.vector.TimeStampMilliVector;
+import org.apache.arrow.vector.TinyIntVector;
+import org.apache.arrow.vector.ValueVector;
+import org.apache.arrow.vector.VarCharVector;
+import org.apache.arrow.vector.complex.ListVector;
+import org.apache.arrow.vector.complex.impl.UnionListWriter;
+import org.apache.arrow.vector.dictionary.Dictionary;
+import org.apache.arrow.vector.dictionary.DictionaryEncoder;
+import org.apache.arrow.vector.dictionary.DictionaryProvider;
+import org.apache.arrow.vector.types.pojo.ArrowType;
+import org.apache.arrow.vector.types.pojo.DictionaryEncoding;
+import org.apache.arrow.vector.types.pojo.Field;
+import org.apache.arrow.vector.types.pojo.FieldType;
+import org.apache.arrow.vector.util.Text;
+import org.testng.annotations.AfterClass;
+import org.testng.annotations.BeforeClass;
+import org.testng.annotations.Ignore;
+import org.testng.annotations.Test;
+
+import java.math.BigDecimal;
+import java.math.BigInteger;
+import java.nio.charset.StandardCharsets;
+import java.time.LocalDate;
+import java.util.Arrays;
+import java.util.List;
+import java.util.Optional;
+
+import static com.facebook.presto.testing.TestingEnvironment.FUNCTION_AND_TYPE_MANAGER;
+import static org.testng.Assert.assertEquals;
+import static org.testng.Assert.assertNotNull;
+import static org.testng.Assert.assertTrue;
+
+public class TestArrowBlockBuilder
+{
+ private static final Logger logger = Logger.get(TestArrowBlockBuilder.class);
+ private static final int DICTIONARY_LENGTH = 10;
+ private static final int VECTOR_LENGTH = 50;
+ private BufferAllocator allocator;
+ private ArrowBlockBuilder arrowBlockBuilder;
+
+ @BeforeClass
+ public void setUp()
+ {
+ // Initialize the Arrow allocator
+ allocator = new RootAllocator(Integer.MAX_VALUE);
+ logger.debug("Allocator initialized: %s", allocator.getName());
+ arrowBlockBuilder = new ArrowBlockBuilder(FUNCTION_AND_TYPE_MANAGER);
+ }
+
+ @AfterClass
+ public void tearDown()
+ {
+ allocator.close();
+ }
+
+ @Test
+ public void testBuildBlockFromBitVector()
+ {
+ // Create a BitVector and populate it with values
+ try (BitVector bitVector = new BitVector("bitVector", allocator)) {
+ bitVector.allocateNew(3); // Allocating space for 3 elements
+
+ bitVector.set(0, 1); // Set value to 1 (true)
+ bitVector.set(1, 0); // Set value to 0 (false)
+ bitVector.setNull(2); // Set null value
+
+ bitVector.setValueCount(3);
+
+ // Build the block from the vector
+ Block resultBlock = arrowBlockBuilder.buildBlockFromFieldVector(bitVector, BooleanType.BOOLEAN, null);
+
+ // Now verify the result block
+ assertEquals(3, resultBlock.getPositionCount()); // Should have 3 positions
+ assertTrue(resultBlock.isNull(2)); // The 3rd element should be null
+ }
+ }
+
+ @Test
+ public void testBuildBlockFromTinyIntVector()
+ {
+ // Create a TinyIntVector and populate it with values
+ try (TinyIntVector tinyIntVector = new TinyIntVector("tinyIntVector", allocator)) {
+ tinyIntVector.allocateNew(3); // Allocating space for 3 elements
+ tinyIntVector.set(0, 10);
+ tinyIntVector.set(1, 20);
+ tinyIntVector.setNull(2); // Set null value
+
+ tinyIntVector.setValueCount(3);
+
+ // Build the block from the vector
+ Block resultBlock = arrowBlockBuilder.buildBlockFromFieldVector(tinyIntVector, TinyintType.TINYINT, null);
+
+ // Now verify the result block
+ assertEquals(3, resultBlock.getPositionCount()); // Should have 3 positions
+ assertTrue(resultBlock.isNull(2)); // The 3rd element should be null
+ }
+ }
+
+ @Test
+ public void testBuildBlockFromSmallIntVector()
+ {
+ // Create a SmallIntVector and populate it with values
+ try (SmallIntVector smallIntVector = new SmallIntVector("smallIntVector", allocator)) {
+ smallIntVector.allocateNew(3); // Allocating space for 3 elements
+ smallIntVector.set(0, 10);
+ smallIntVector.set(1, 20);
+ smallIntVector.setNull(2); // Set null value
+
+ smallIntVector.setValueCount(3);
+
+ // Build the block from the vector
+ Block resultBlock = arrowBlockBuilder.buildBlockFromFieldVector(smallIntVector, SmallintType.SMALLINT, null);
+
+ // Now verify the result block
+ assertEquals(3, resultBlock.getPositionCount()); // Should have 3 positions
+ assertTrue(resultBlock.isNull(2)); // The 3rd element should be null
+ }
+ }
+
+ @Test
+ public void testBuildBlockFromIntVector()
+ {
+ // Create an IntVector and populate it with values
+ try (IntVector intVector = new IntVector("intVector", allocator)) {
+ intVector.allocateNew(3); // Allocating space for 3 elements
+ intVector.set(0, 10);
+ intVector.set(1, 20);
+ intVector.set(2, 30);
+
+ intVector.setValueCount(3);
+
+ // Build the block from the vector
+ Block resultBlock = arrowBlockBuilder.buildBlockFromFieldVector(intVector, IntegerType.INTEGER, null);
+
+ // Now verify the result block
+ assertEquals(3, resultBlock.getPositionCount()); // Should have 3 positions
+ assertEquals(10, resultBlock.getInt(0)); // The 1st element should be 10
+ assertEquals(20, resultBlock.getInt(1)); // The 2nd element should be 20
+ assertEquals(30, resultBlock.getInt(2)); // The 3rd element should be 30
+ }
+ }
+
+ @Test
+ public void testBuildBlockFromBigIntVector()
+ throws InstantiationException, IllegalAccessException
+ {
+ // Create a BigIntVector and populate it with values
+ try (BigIntVector bigIntVector = new BigIntVector("bigIntVector", allocator)) {
+ bigIntVector.allocateNew(3); // Allocating space for 3 elements
+
+ bigIntVector.set(0, 10L);
+ bigIntVector.set(1, 20L);
+ bigIntVector.set(2, 30L);
+
+ bigIntVector.setValueCount(3);
+
+ // Build the block from the vector
+ Block resultBlock = arrowBlockBuilder.buildBlockFromFieldVector(bigIntVector, BigintType.BIGINT, null);
+
+ // Now verify the result block
+ assertEquals(10L, resultBlock.getInt(0)); // The 1st element should be 10L
+ assertEquals(20L, resultBlock.getInt(1)); // The 2nd element should be 20L
+ assertEquals(30L, resultBlock.getInt(2)); // The 3rd element should be 30L
+ }
+ }
+
+ @Test
+ public void testBuildBlockFromDecimalVector()
+ {
+ // Create a DecimalVector and populate it with values
+ try (DecimalVector decimalVector = new DecimalVector("decimalVector", allocator, 10, 2)) { // Precision = 10, Scale = 2
+ decimalVector.allocateNew(2); // Allocating space for 2 elements
+ decimalVector.set(0, new BigDecimal("123.45"));
+
+ decimalVector.setValueCount(2);
+
+ // Build the block from the vector
+ Block resultBlock = arrowBlockBuilder.buildBlockFromFieldVector(decimalVector, DecimalType.createDecimalType(10, 2), null);
+
+ // Now verify the result block
+ assertEquals(2, resultBlock.getPositionCount()); // Should have 2 positions
+ assertTrue(resultBlock.isNull(1)); // The 2nd element should be null
+ }
+ }
+
+ @Test
+ public void testBuildBlockFromTimeStampMicroVector()
+ {
+ // Create a TimeStampMicroVector and populate it with values
+ try (TimeStampMicroVector timestampMicroVector = new TimeStampMicroVector("timestampMicroVector", allocator)) {
+ timestampMicroVector.allocateNew(3); // Allocating space for 3 elements
+ timestampMicroVector.set(0, 1000000L); // 1 second in microseconds
+ timestampMicroVector.set(1, 2000000L); // 2 seconds in microseconds
+ timestampMicroVector.setNull(2); // Set null value
+
+ timestampMicroVector.setValueCount(3);
+
+ // Build the block from the vector
+ Block resultBlock = arrowBlockBuilder.buildBlockFromFieldVector(timestampMicroVector, TimestampType.TIMESTAMP, null);
+
+ // Now verify the result block
+ assertEquals(3, resultBlock.getPositionCount()); // Should have 3 positions
+ assertTrue(resultBlock.isNull(2)); // The 3rd element should be null
+ assertEquals(1000L, resultBlock.getLong(0)); // The 1st element should be 1000ms (1 second)
+ assertEquals(2000L, resultBlock.getLong(1)); // The 2nd element should be 2000ms (2 seconds)
+ }
+ }
+
+ @Test
+ public void testBuildBlockFromListVector()
+ {
+ // Create a root allocator for Arrow vectors
+ try (BufferAllocator allocator = new RootAllocator();
+ ListVector listVector = ListVector.empty("listVector", allocator)) {
+ // Allocate the vector and get the writer
+ listVector.allocateNew();
+ UnionListWriter listWriter = listVector.getWriter();
+
+ int[] data = new int[] {1, 2, 3, 10, 20, 30, 100, 200, 300, 1000, 2000, 3000};
+ int tmpIndex = 0;
+
+ for (int i = 0; i < 4; i++) { // 4 lists to be added
+ listWriter.startList();
+ for (int j = 0; j < 3; j++) { // Each list has 3 integers
+ listWriter.writeInt(data[tmpIndex]);
+ tmpIndex++;
+ }
+ listWriter.endList();
+ }
+
+ // Set the number of lists
+ listVector.setValueCount(4);
+
+ // Create Presto ArrayType for Integer
+ ArrayType arrayType = new ArrayType(IntegerType.INTEGER);
+
+ // Call the method to test
+ Block block = arrowBlockBuilder.buildBlockFromFieldVector(listVector, arrayType, null);
+ assertTrue(block instanceof ArrayBlock);
+
+ // Validate the result
+ assertEquals(block.getPositionCount(), 4); // 4 lists in the block
+ for (int j = 0; j < block.getPositionCount(); j++) {
+ Block subBlock = block.getBlock(j);
+ assertEquals(subBlock.getPositionCount(), 3); // each list should have 3 elements
+ }
+ }
+ }
+
+ @Test
+ public void testProcessDictionaryVector()
+ {
+ try (VarCharVector dictionaryVector = new VarCharVector("dictionary", allocator);
+ VarCharVector rawVector = new VarCharVector("raw", allocator)) {
+ dictionaryVector.allocateNew(DICTIONARY_LENGTH);
+ for (int i = 0; i < DICTIONARY_LENGTH; i++) {
+ dictionaryVector.setSafe(i, String.valueOf(i).getBytes(StandardCharsets.UTF_8));
+ }
+ dictionaryVector.setValueCount(DICTIONARY_LENGTH);
+
+ rawVector.allocateNew(VECTOR_LENGTH);
+ for (int i = 0; i < VECTOR_LENGTH; i++) {
+ int value = i % DICTIONARY_LENGTH;
+ rawVector.setSafe(i, String.valueOf(value).getBytes(StandardCharsets.UTF_8));
+ }
+ rawVector.setValueCount(VECTOR_LENGTH);
+
+ // Encode the vector with a dictionary
+ Dictionary dictionary = new Dictionary(dictionaryVector, new DictionaryEncoding(1L, false, new ArrowType.Int(16, true)));
+ try (ValueVector encodedVector = DictionaryEncoder.encode(rawVector, dictionary);
+ DictionaryProvider.MapDictionaryProvider dictionaryProvider = new DictionaryProvider.MapDictionaryProvider(dictionary)) {
+ // Process the dictionary vector
+ assertTrue(encodedVector instanceof FieldVector);
+ Block result = arrowBlockBuilder.buildBlockFromFieldVector((FieldVector) encodedVector, VarcharType.VARCHAR, dictionaryProvider);
+
+ // Verify the result
+ assertNotNull(result, "The BlockBuilder should not be null.");
+ assertEquals(result.getPositionCount(), 50);
+ }
+ }
+ }
+
+ @Test
+ public void testBuildBlockFromDictionaryVector()
+ {
+ // Initialize a dictionary vector
+ // Example: dictionary contains 3 string values
+ VarCharVector dictionaryVector = new VarCharVector("dictionary", allocator);
+ dictionaryVector.allocateNew(3); // allocating 3 elements in dictionary
+
+ // Fill dictionaryVector with some values
+ dictionaryVector.set(0, "apple".getBytes());
+ dictionaryVector.set(1, "banana".getBytes());
+ dictionaryVector.set(2, "cherry".getBytes());
+ dictionaryVector.setValueCount(3);
+
+ Dictionary dictionary = new Dictionary(dictionaryVector, new DictionaryEncoding(1L, false, new ArrowType.Int(32, true)));
+
+ FieldType indexFieldType = new FieldType(false, dictionary.getEncoding().getIndexType(), dictionary.getEncoding());
+ Field indexField = new Field("indices", indexFieldType, null);
+ try (DictionaryProvider.MapDictionaryProvider dictionaryProvider = new DictionaryProvider.MapDictionaryProvider(dictionary);
+ IntVector indicesVector = (IntVector) indexField.createVector(allocator)) {
+ indicesVector.allocateNew(4); // allocating space for values
+
+ // Set up index values (this would reference the dictionary)
+ indicesVector.set(0, 0); // First index points to "apple"
+ indicesVector.set(1, 1); // Second index points to "banana"
+ indicesVector.set(2, 2);
+ indicesVector.set(3, 2); // Third index points to "cherry"
+ indicesVector.setValueCount(4);
+
+ // Call the method under test
+ Block block = arrowBlockBuilder.buildBlockFromFieldVector(indicesVector, VarcharType.VARCHAR, dictionaryProvider);
+
+ // Assertions to check the dictionary block's behavior
+ assertNotNull(block);
+ assertTrue(block instanceof DictionaryBlock);
+ DictionaryBlock dictionaryBlock = (DictionaryBlock) block;
+
+ // Verify the dictionary block contains the right dictionary
+ for (int i = 0; i < dictionaryBlock.getPositionCount(); i++) {
+ // Get the slice (string value) at the given position
+ Slice slice = dictionaryBlock.getSlice(i, 0, dictionaryBlock.getSliceLength(i));
+
+ // Assert based on the expected values
+ if (i == 0) {
+ assertEquals(slice.toStringUtf8(), "apple");
+ }
+ else if (i == 1) {
+ assertEquals(slice.toStringUtf8(), "banana");
+ }
+ else if (i == 2) {
+ assertEquals(slice.toStringUtf8(), "cherry");
+ }
+ else if (i == 3) {
+ assertEquals(slice.toStringUtf8(), "cherry");
+ }
+ }
+ }
+ }
+
+ @Test
+ public void testBuildBlockFromDictionaryVectorSmallInt()
+ {
+ // Initialize a dictionary vector
+ // Example: dictionary contains 3 string values
+ VarCharVector dictionaryVector = new VarCharVector("dictionary", allocator);
+ dictionaryVector.allocateNew(3); // allocating 3 elements in dictionary
+
+ // Fill dictionaryVector with some values
+ dictionaryVector.set(0, "apple".getBytes());
+ dictionaryVector.set(1, "banana".getBytes());
+ dictionaryVector.set(2, "cherry".getBytes());
+ dictionaryVector.setValueCount(3);
+
+ Dictionary dictionary = new Dictionary(dictionaryVector, new DictionaryEncoding(1L, false, new ArrowType.Int(16, true)));
+
+ FieldType indexFieldType = new FieldType(false, dictionary.getEncoding().getIndexType(), dictionary.getEncoding());
+ Field indexField = new Field("indices", indexFieldType, null);
+ try (DictionaryProvider.MapDictionaryProvider dictionaryProvider = new DictionaryProvider.MapDictionaryProvider(dictionary);
+ SmallIntVector indicesVector = (SmallIntVector) indexField.createVector(allocator)) {
+ indicesVector.allocateNew(3); // allocating space for 3 values
+ indicesVector.set(0, (short) 0);
+ indicesVector.set(1, (short) 1);
+ indicesVector.set(2, (short) 2);
+ indicesVector.setValueCount(3);
+
+ // Call the method under test
+ Block block = arrowBlockBuilder.buildBlockFromFieldVector(indicesVector, VarcharType.VARCHAR, dictionaryProvider);
+
+ // Assertions to check the dictionary block's behavior
+ assertNotNull(block);
+ assertTrue(block instanceof DictionaryBlock);
+ DictionaryBlock dictionaryBlock = (DictionaryBlock) block;
+
+ // Verify the dictionary block contains the right dictionary
+ for (int i = 0; i < dictionaryBlock.getPositionCount(); i++) {
+ // Get the slice (string value) at the given position
+ Slice slice = dictionaryBlock.getSlice(i, 0, dictionaryBlock.getSliceLength(i));
+
+ // Assert based on the expected values
+ if (i == 0) {
+ assertEquals(slice.toStringUtf8(), "apple");
+ }
+ else if (i == 1) {
+ assertEquals(slice.toStringUtf8(), "banana");
+ }
+ else if (i == 2) {
+ assertEquals(slice.toStringUtf8(), "cherry");
+ }
+ }
+ }
+ }
+
+ @Test
+ public void testBuildBlockFromDictionaryVectorTinyInt()
+ {
+ // Initialize a dictionary vector
+ // Example: dictionary contains 3 string values
+ VarCharVector dictionaryVector = new VarCharVector("dictionary", allocator);
+ dictionaryVector.allocateNew(3); // allocating 3 elements in dictionary
+
+ // Fill dictionaryVector with some values
+ dictionaryVector.set(0, "apple".getBytes());
+ dictionaryVector.set(1, "banana".getBytes());
+ dictionaryVector.set(2, "cherry".getBytes());
+ dictionaryVector.setValueCount(3);
+
+ Dictionary dictionary = new Dictionary(dictionaryVector, new DictionaryEncoding(1L, false, new ArrowType.Int(8, true)));
+
+ FieldType indexFieldType = new FieldType(false, dictionary.getEncoding().getIndexType(), dictionary.getEncoding());
+ Field indexField = new Field("indices", indexFieldType, null);
+ try (DictionaryProvider.MapDictionaryProvider dictionaryProvider = new DictionaryProvider.MapDictionaryProvider(dictionary);
+ TinyIntVector indicesVector = (TinyIntVector) indexField.createVector(allocator)) {
+ indicesVector.allocateNew(3); // allocating space for 3 values
+ indicesVector.set(0, (byte) 0);
+ indicesVector.set(1, (byte) 1);
+ indicesVector.set(2, (byte) 2);
+ indicesVector.setValueCount(3);
+
+ // Call the method under test
+ Block block = arrowBlockBuilder.buildBlockFromFieldVector(indicesVector, VarcharType.VARCHAR, dictionaryProvider);
+
+ // Assertions to check the dictionary block's behavior
+ assertNotNull(block);
+ assertTrue(block instanceof DictionaryBlock);
+ DictionaryBlock dictionaryBlock = (DictionaryBlock) block;
+
+ // Verify the dictionary block contains the right dictionary
+ for (int i = 0; i < dictionaryBlock.getPositionCount(); i++) {
+ // Get the slice (string value) at the given position
+ Slice slice = dictionaryBlock.getSlice(i, 0, dictionaryBlock.getSliceLength(i));
+
+ // Assert based on the expected values
+ if (i == 0) {
+ assertEquals(slice.toStringUtf8(), "apple");
+ }
+ else if (i == 1) {
+ assertEquals(slice.toStringUtf8(), "banana");
+ }
+ else if (i == 2) {
+ assertEquals(slice.toStringUtf8(), "cherry");
+ }
+ }
+ }
+ }
+
+ @Test
+ public void testAssignVarcharType()
+ {
+ try (VarCharVector vector = new VarCharVector("varCharVector", allocator)) {
+ vector.allocateNew(3);
+
+ String value = "test_string";
+ vector.set(0, new Text(value));
+ vector.setValueCount(1);
+
+ Type varcharType = VarcharType.createUnboundedVarcharType();
+ BlockBuilder builder = varcharType.createBlockBuilder(null, 1);
+
+ arrowBlockBuilder.assignBlockFromVarCharVector(vector, varcharType, builder, 0, vector.getValueCount());
+
+ Block block = builder.build();
+ Slice result = varcharType.getSlice(block, 0);
+ assertEquals(result.toStringUtf8(), value);
+ }
+ }
+
+ @Test
+ public void testAssignSmallintType()
+ {
+ try (SmallIntVector vector = new SmallIntVector("smallIntVector", allocator)) {
+ vector.allocateNew(3);
+
+ short value = 42;
+ vector.set(0, value);
+ vector.setValueCount(1);
+
+ Type smallintType = SmallintType.SMALLINT;
+ BlockBuilder builder = smallintType.createBlockBuilder(null, 1);
+
+ arrowBlockBuilder.assignBlockFromSmallIntVector(vector, smallintType, builder, 0, vector.getValueCount());
+
+ Block block = builder.build();
+ long result = smallintType.getLong(block, 0);
+ assertEquals(result, value);
+ }
+ }
+
+ @Test
+ public void testAssignTinyintType()
+ {
+ try (TinyIntVector vector = new TinyIntVector("tinyIntVector", allocator)) {
+ vector.allocateNew(3);
+
+ byte value = 7;
+ vector.set(0, value);
+ vector.setValueCount(1);
+
+ Type tinyintType = TinyintType.TINYINT;
+ BlockBuilder builder = tinyintType.createBlockBuilder(null, 1);
+
+ arrowBlockBuilder.assignBlockFromTinyIntVector(vector, tinyintType, builder, 0, vector.getValueCount());
+
+ Block block = builder.build();
+ long result = tinyintType.getLong(block, 0);
+ assertEquals(result, value);
+ }
+ }
+
+ @Test
+ public void testAssignBigintType()
+ {
+ try (BigIntVector vector = new BigIntVector("bigIntVector", allocator)) {
+ vector.allocateNew(3);
+
+ long value = 123456789L;
+ vector.set(0, value);
+ vector.setValueCount(1);
+
+ Type bigintType = BigintType.BIGINT;
+ BlockBuilder builder = bigintType.createBlockBuilder(null, 1);
+
+ arrowBlockBuilder.assignBlockFromBigIntVector(vector, bigintType, builder, 0, vector.getValueCount());
+
+ Block block = builder.build();
+ long result = bigintType.getLong(block, 0);
+ assertEquals(result, value);
+ }
+ }
+
+ @Test
+ public void testAssignIntegerType()
+ {
+ try (IntVector vector = new IntVector("IntVector", allocator)) {
+ vector.allocateNew(3);
+
+ int value = 42;
+ vector.set(0, value);
+ vector.setValueCount(1);
+
+ Type integerType = IntegerType.INTEGER;
+ BlockBuilder builder = integerType.createBlockBuilder(null, 1);
+
+ arrowBlockBuilder.assignBlockFromIntVector(vector, integerType, builder, 0, vector.getValueCount());
+
+ Block block = builder.build();
+ long result = integerType.getLong(block, 0);
+ assertEquals(result, value);
+ }
+ }
+
+ @Test
+ public void testAssignDoubleType()
+ {
+ try (Float8Vector vector = new Float8Vector("Float8Vector", allocator)) {
+ vector.allocateNew(3);
+
+ double value = 42.42;
+ vector.set(0, value);
+ vector.setValueCount(1);
+
+ Type doubleType = DoubleType.DOUBLE;
+ BlockBuilder builder = doubleType.createBlockBuilder(null, 1);
+
+ arrowBlockBuilder.assignBlockFromFloat8Vector(vector, doubleType, builder, 0, vector.getValueCount());
+
+ Block block = builder.build();
+ double result = doubleType.getDouble(block, 0);
+ assertEquals(result, value, 0.001);
+ }
+ }
+
+ @Test
+ public void testAssignBooleanType()
+ {
+ try (BitVector vector = new BitVector("BitVector", allocator)) {
+ vector.allocateNew(3);
+
+ boolean value = true;
+ vector.set(0, 1);
+ vector.setValueCount(1);
+
+ Type booleanType = BooleanType.BOOLEAN;
+ BlockBuilder builder = booleanType.createBlockBuilder(null, 1);
+
+ arrowBlockBuilder.assignBlockFromBitVector(vector, booleanType, builder, 0, vector.getValueCount());
+
+ Block block = builder.build();
+ boolean result = booleanType.getBoolean(block, 0);
+ assertEquals(result, value);
+ }
+ }
+
+ @Test
+ public void testAssignArrayType()
+ {
+ try (ListVector vector = ListVector.empty("ListVector", allocator)) {
+ UnionListWriter writer = vector.getWriter();
+ writer.allocate();
+
+ writer.setPosition(0); // optional
+ writer.startList();
+ writer.integer().writeInt(1);
+ writer.integer().writeInt(2);
+ writer.integer().writeInt(3);
+ writer.endList();
+
+ writer.setValueCount(1);
+
+ Type elementType = IntegerType.INTEGER;
+ ArrayType arrayType = new ArrayType(elementType);
+ BlockBuilder builder = arrayType.createBlockBuilder(null, 1);
+
+ arrowBlockBuilder.assignBlockFromListVector(vector, arrayType, builder, 0, vector.getValueCount());
+
+ Block block = builder.build();
+ List values = Arrays.asList(1, 2, 3);
+ Block arrayBlock = arrayType.getObject(block, 0);
+ assertEquals(arrayBlock.getPositionCount(), values.size());
+ for (int i = 0; i < values.size(); i++) {
+ assertEquals(elementType.getLong(arrayBlock, i), values.get(i).longValue());
+ }
+ }
+ }
+
+ @Ignore("RowType not implemented")
+ @Test
+ public void testAssignRowType()
+ {
+ RowType.Field field1 = new RowType.Field(Optional.of("field1"), IntegerType.INTEGER);
+ RowType.Field field2 = new RowType.Field(Optional.of("field2"), VarcharType.createUnboundedVarcharType());
+ RowType rowType = RowType.from(Arrays.asList(field1, field2));
+ BlockBuilder builder = rowType.createBlockBuilder(null, 1);
+
+ List