diff --git a/.github/workflows/arrow-flight-tests.yml b/.github/workflows/arrow-flight-tests.yml
new file mode 100644
index 0000000000000..ee77c122536e1
--- /dev/null
+++ b/.github/workflows/arrow-flight-tests.yml
@@ -0,0 +1,83 @@
+name: arrow flight tests
+
+on:
+  pull_request:
+
+env:
+  CONTINUOUS_INTEGRATION: true
+  MAVEN_OPTS: "-Xmx1024M -XX:+ExitOnOutOfMemoryError"
+  MAVEN_INSTALL_OPTS: "-Xmx2G -XX:+ExitOnOutOfMemoryError"
+  MAVEN_FAST_INSTALL: "-B -V --quiet -T 1C -DskipTests -Dair.check.skip-all --no-transfer-progress -Dmaven.javadoc.skip=true"
+  MAVEN_TEST: "-B -Dair.check.skip-all -Dmaven.javadoc.skip=true -DLogTestDurationListener.enabled=true --no-transfer-progress --fail-at-end"
+  RETRY: .github/bin/retry
+
+jobs:
+  # Sets the `codechange` output from a path filter that excludes presto-docs/**
+  changes:
+    runs-on: ubuntu-latest
+    permissions:
+      pull-requests: read
+    outputs:
+      codechange: ${{ steps.filter.outputs.codechange }}
+    steps:
+      - uses: dorny/paths-filter@v3
+        id: filter
+        with:
+          filters: |
+            codechange:
+              - '!presto-docs/**'
+  test:
+    runs-on: ubuntu-latest
+    needs: changes
+    strategy:
+      fail-fast: false
+      matrix:
+        modules:
+          - ":presto-base-arrow-flight"  # Only run tests for the `presto-base-arrow-flight` module
+
+    timeout-minutes: 80
+    concurrency:
+      group: ${{ github.workflow }}-test-${{ matrix.modules }}-${{ github.event.pull_request.number }}
+      cancel-in-progress: true
+
+    steps:
+      # Checkout the code only if there are changes in the relevant files
+      - uses: actions/checkout@v4
+        if: needs.changes.outputs.codechange == 'true'
+        with:
+          show-progress: false
+
+      # Set up Java for the build environment
+      - uses: actions/setup-java@v4
+        if: needs.changes.outputs.codechange == 'true'
+        with:
+          distribution: 'temurin'
+          java-version: '8'
+
+      # Cache Maven dependencies to speed up the build
+      - name: Cache local Maven repository
+        if: needs.changes.outputs.codechange == 'true'
+        id: cache-maven
+        uses: actions/cache@v4
+        with:
+          path: ~/.m2/repository
+          key: ${{ runner.os }}-maven-2-${{ hashFiles('**/pom.xml') }}
+          restore-keys: |
+            ${{ runner.os }}-maven-2-
+
+      # Resolve Maven dependencies (if cache is not found)
+      - name: Populate Maven cache
+        if: steps.cache-maven.outputs.cache-hit != 'true' && needs.changes.outputs.codechange == 'true'
+        run: ./mvnw de.qaware.maven:go-offline-maven-plugin:resolve-dependencies --no-transfer-progress && .github/bin/download_nodejs
+
+      # Install dependencies for the target module
+      - name: Maven Install
+        if: needs.changes.outputs.codechange == 'true'
+        run: |
+          export MAVEN_OPTS="${MAVEN_INSTALL_OPTS}"
+          ./mvnw install ${MAVEN_FAST_INSTALL} -am -pl ${{ matrix.modules }}
+
+      # Run Maven tests for the target module
+      - name: Maven Tests
+        if: needs.changes.outputs.codechange == 'true'
+        run: ./mvnw test ${MAVEN_TEST} -pl ${{ matrix.modules }}
diff --git a/.github/workflows/test-other-modules.yml b/.github/workflows/test-other-modules.yml
index 7b5ba858bdc4d..02cf887feecc6 100644
--- a/.github/workflows/test-other-modules.yml
+++ b/.github/workflows/test-other-modules.yml
@@ -87,4 +87,5 @@ jobs:
 !presto-test-coverage,
 !presto-iceberg,
 !presto-singlestore,
- !presto-native-sidecar-plugin'
+ !presto-native-sidecar-plugin,
+ !presto-base-arrow-flight'
diff --git a/pom.xml b/pom.xml
index 4b2f740db855b..cf8e996d137f5 100644
--- a/pom.xml
+++ b/pom.xml
@@ -89,6 +89,7 @@
 2.11.0 3.14.0 5.1.0 + 17.0.0 + + com.fasterxml.jackson.datatype + jackson-datatype-jsr310 + + + + + + org.apache.arrow + flight-core + + + org.slf4j + slf4j-api + + + + + + com.facebook.airlift + bootstrap + + + + com.facebook.airlift + log + + + + com.google.guava + guava + + + + javax.inject + javax.inject + + + + com.facebook.presto + presto-spi + + + + io.airlift + slice + + + + com.facebook.airlift + log-manager + + + + com.fasterxml.jackson.core + jackson-annotations + + + + com.facebook.presto + presto-common + + + + com.fasterxml.jackson.core + jackson-databind + + + + com.google.code.findbugs + jsr305 + true + + + + com.google.inject + guice + + + + com.facebook.airlift +
configuration + + + + joda-time + joda-time + + + + org.jdbi + jdbi3-core + + + + + org.testng + testng + test + + + + io.airlift.tpch + tpch + test + + + + com.facebook.presto + presto-tpch + test + + + + com.facebook.airlift + json + test + + + + com.facebook.presto + presto-testng-services + test + + + + com.facebook.airlift + testing + test + + + + com.facebook.presto + presto-main + test + + + + com.facebook.presto + presto-tests + test + + + + com.h2database + h2 + test + + + + org.apache.arrow + arrow-jdbc + ${dep.arrow.version} + test + + + org.slf4j + slf4j-api + + + + + + + + + org.apache.maven.plugins + maven-surefire-plugin + + -Xss10M + + + + org.apache.maven.plugins + maven-enforcer-plugin + + + org.apache.maven.plugins + maven-dependency-plugin + + + org.basepom.maven + duplicate-finder-maven-plugin + 1.2.1 + + + module-info + META-INF.versions.9.module-info + + + arrow-git.properties + about.html + + + + + + check + + + + + + +
diff --git a/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowBlockBuilder.java b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowBlockBuilder.java
new file mode 100644
index 0000000000000..3a5ce28cd76fd
--- /dev/null
+++ b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowBlockBuilder.java
@@ -0,0 +1,752 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.facebook.plugin.arrow;
+
+import com.facebook.presto.common.block.Block;
+import com.facebook.presto.common.block.BlockBuilder;
+import com.facebook.presto.common.block.DictionaryBlock;
+import com.facebook.presto.common.type.ArrayType;
+import com.facebook.presto.common.type.BigintType;
+import com.facebook.presto.common.type.BooleanType;
+import com.facebook.presto.common.type.CharType;
+import com.facebook.presto.common.type.DateType;
+import com.facebook.presto.common.type.DecimalType;
+import com.facebook.presto.common.type.Decimals;
+import com.facebook.presto.common.type.DoubleType;
+import com.facebook.presto.common.type.IntegerType;
+import com.facebook.presto.common.type.MapType;
+import com.facebook.presto.common.type.RealType;
+import com.facebook.presto.common.type.RowType;
+import com.facebook.presto.common.type.SmallintType;
+import com.facebook.presto.common.type.TimeType;
+import com.facebook.presto.common.type.TimestampType;
+import com.facebook.presto.common.type.TinyintType;
+import com.facebook.presto.common.type.Type;
+import com.facebook.presto.common.type.TypeManager;
+import com.facebook.presto.common.type.VarbinaryType;
+import com.facebook.presto.common.type.VarcharType;
+import com.google.common.base.CharMatcher;
+import io.airlift.slice.Slice;
+import io.airlift.slice.Slices;
+import org.apache.arrow.vector.BigIntVector;
+import org.apache.arrow.vector.BitVector;
+import org.apache.arrow.vector.DateDayVector;
+import org.apache.arrow.vector.DateMilliVector;
+import org.apache.arrow.vector.DecimalVector;
+import org.apache.arrow.vector.FieldVector;
+import org.apache.arrow.vector.Float4Vector;
+import org.apache.arrow.vector.Float8Vector;
+import org.apache.arrow.vector.IntVector;
+import org.apache.arrow.vector.NullVector;
+import org.apache.arrow.vector.SmallIntVector;
+import org.apache.arrow.vector.TimeMicroVector;
+import org.apache.arrow.vector.TimeMilliVector;
+import org.apache.arrow.vector.TimeSecVector;
+import org.apache.arrow.vector.TimeStampMicroVector;
+import org.apache.arrow.vector.TimeStampMilliTZVector;
+import org.apache.arrow.vector.TimeStampMilliVector;
+import org.apache.arrow.vector.TimeStampSecVector;
+import org.apache.arrow.vector.TinyIntVector;
+import org.apache.arrow.vector.ValueVector;
+import org.apache.arrow.vector.VarBinaryVector;
+import org.apache.arrow.vector.VarCharVector;
+import org.apache.arrow.vector.complex.ListVector;
+import org.apache.arrow.vector.complex.MapVector;
+import org.apache.arrow.vector.complex.StructVector;
+import org.apache.arrow.vector.dictionary.Dictionary;
+import org.apache.arrow.vector.dictionary.DictionaryProvider;
+import org.apache.arrow.vector.types.pojo.ArrowType;
+import org.apache.arrow.vector.types.pojo.Field;
+
+import javax.inject.Inject;
+
+import java.math.BigDecimal;
+import java.nio.charset.StandardCharsets;
+import java.time.Duration;
+import java.time.LocalTime;
+import java.util.List;
+import java.util.Optional;
+import java.util.concurrent.TimeUnit;
+
+import static com.facebook.plugin.arrow.ArrowErrorCode.ARROW_FLIGHT_TYPE_ERROR;
+import static com.facebook.presto.common.Utils.checkArgument;
+import static com.facebook.presto.common.type.TypeSignature.parseTypeSignature;
+import static com.google.common.collect.ImmutableList.toImmutableList;
+import static java.lang.String.format;
+import static java.util.Objects.requireNonNull;
+
+/**
+ * Converts Apache Arrow vectors into Presto {@link Block}s. Dictionary-encoded vectors become
+ * {@link DictionaryBlock}s; any other supported vector is copied value by value through a
+ * {@link BlockBuilder}.
+ */
+public class ArrowBlockBuilder
+{
+    private final TypeManager typeManager;
+
+    @Inject
+    public ArrowBlockBuilder(TypeManager typeManager)
+    {
+        this.typeManager = requireNonNull(typeManager, "typeManager is null");
+    }
+
+    /**
+     * Builds a Presto block holding every value of {@code vector}.
+     *
+     * @param vector the Arrow vector to convert
+     * @param type the Presto type used when no dictionary is found for the vector
+     * @param dictionaryProvider source of Arrow dictionaries; may be null
+     * @return a DictionaryBlock when a dictionary is found for the vector, otherwise a block of type {@code type}
+     */
+    public Block buildBlockFromFieldVector(FieldVector vector, Type type, DictionaryProvider dictionaryProvider)
+    {
+        // Use Arrow dictionary to create a DictionaryBlock
+        if (dictionaryProvider != null && vector.getField().getDictionary() != null) {
+            Dictionary dictionary = dictionaryProvider.lookup(vector.getField().getDictionary().getId());
+            if (dictionary != null) {
+                Type prestoType = getPrestoTypeFromArrowField(dictionary.getVector().getField());
+                BlockBuilder dictionaryBuilder = prestoType.createBlockBuilder(null, dictionary.getVector().getValueCount());
+                assignBlockFromValueVector(dictionary.getVector(), prestoType, dictionaryBuilder, 0, dictionary.getVector().getValueCount());
+                return buildDictionaryBlock(vector, dictionaryBuilder.build());
+            }
+        }
+
+        BlockBuilder builder = type.createBlockBuilder(null, vector.getValueCount());
+        assignBlockFromValueVector(vector, type, builder, 0, vector.getValueCount());
+        return builder.build();
+    }
+
+    /**
+     * Maps an Arrow field to the Presto type used to represent its values.
+     *
+     * @throws UnsupportedOperationException if the Arrow type ID is not handled here
+     */
+    protected Type getPrestoTypeFromArrowField(Field field)
+    {
+        switch (field.getType().getTypeID()) {
+            case Int:
+                ArrowType.Int intType = (ArrowType.Int) field.getType();
+                return getPrestoTypeForArrowIntType(intType);
+            case Binary:
+            case LargeBinary:
+            case FixedSizeBinary:
+                return VarbinaryType.VARBINARY;
+            case Date:
+                return DateType.DATE;
+            case Timestamp:
+                return TimestampType.TIMESTAMP;
+            case Utf8:
+            case LargeUtf8:
+                return VarcharType.VARCHAR;
+            case FloatingPoint:
+                ArrowType.FloatingPoint floatingPoint = (ArrowType.FloatingPoint) field.getType();
+                return getPrestoTypeForArrowFloatingPointType(floatingPoint);
+            case Decimal:
+                ArrowType.Decimal decimalType = (ArrowType.Decimal) field.getType();
+                return DecimalType.createDecimalType(decimalType.getPrecision(), decimalType.getScale());
+            case Bool:
+                return BooleanType.BOOLEAN;
+            case Time:
+                return TimeType.TIME;
+            case List: {
+                List<Field> children = field.getChildren();
+                checkArgument(children.size() == 1, "Arrow List expected to have 1 child Field, got: " + children.size());
+                return new ArrayType(getPrestoTypeFromArrowField(field.getChildren().get(0)));
+            }
+            case Map: {
+                List<Field> children = field.getChildren();
+                checkArgument(children.size() == 1, "Arrow Map expected to have 1 child Field for entries, got: " + children.size());
+                Field entryField = children.get(0);
+                checkArgument(entryField.getChildren().size() == 2, "Arrow Map entries expected to have 2 child Fields, got: " + entryField.getChildren().size());
+                Type keyType = getPrestoTypeFromArrowField(entryField.getChildren().get(0));
+                Type valueType = getPrestoTypeFromArrowField(entryField.getChildren().get(1));
+                return typeManager.getType(parseTypeSignature(format("map(%s,%s)", keyType.getTypeSignature(), valueType.getTypeSignature())));
+            }
+            case Struct: {
+                List<RowType.Field> children = field.getChildren().stream().map(child -> new RowType.Field(Optional.of(child.getName()), getPrestoTypeFromArrowField(child))).collect(toImmutableList());
+                return RowType.from(children);
+            }
+            default:
+                throw new UnsupportedOperationException("The data type " + field.getType().getTypeID() + " is not supported.");
+        }
+    }
+
+    private Type getPrestoTypeForArrowFloatingPointType(ArrowType.FloatingPoint floatingPoint)
+    {
+        switch (floatingPoint.getPrecision()) {
+            case SINGLE:
+                return RealType.REAL;
+            case DOUBLE:
+                return DoubleType.DOUBLE;
+            default:
+                throw new ArrowException(ARROW_FLIGHT_TYPE_ERROR, "Unexpected floating point precision: " + floatingPoint.getPrecision());
+        }
+    }
+
+    private Type getPrestoTypeForArrowIntType(ArrowType.Int intType)
+    {
+        switch (intType.getBitWidth()) {
+            case 64:
+                return BigintType.BIGINT;
+            case 32:
+                return IntegerType.INTEGER;
+            case 16:
+                return SmallintType.SMALLINT;
+            case 8:
+                return TinyintType.TINYINT;
+            default:
+                throw new ArrowException(ARROW_FLIGHT_TYPE_ERROR, "Unexpected bit width: " + intType.getBitWidth());
+        }
+    }
+
+    /**
+     * Wraps {@code dictionaryblock} in a DictionaryBlock whose ids are copied from
+     * {@code fieldVector}, which must be an IntVector, SmallIntVector or TinyIntVector.
+     */
+    private DictionaryBlock buildDictionaryBlock(FieldVector fieldVector, Block dictionaryblock)
+    {
+        if (fieldVector instanceof IntVector) {
+            // Get the Arrow indices vector
+            IntVector indicesVector = (IntVector) fieldVector;
+            int[] ids = new int[indicesVector.getValueCount()];
+            for (int i = 0; i < indicesVector.getValueCount(); i++) {
+                ids[i] = indicesVector.get(i);
+            }
+            return new DictionaryBlock(ids.length, dictionaryblock, ids);
+        }
+        else if (fieldVector instanceof SmallIntVector) {
+            // Get the SmallInt indices vector
+            SmallIntVector smallIntIndicesVector = (SmallIntVector) fieldVector;
+            int[] ids = new int[smallIntIndicesVector.getValueCount()];
+            for (int i = 0; i < smallIntIndicesVector.getValueCount(); i++) {
+                ids[i] = smallIntIndicesVector.get(i);
+            }
+            return new DictionaryBlock(ids.length, dictionaryblock, ids);
+        }
+        else if (fieldVector instanceof TinyIntVector) {
+            // Get the TinyInt indices vector
+            TinyIntVector tinyIntIndicesVector = (TinyIntVector) fieldVector;
+            int[] ids = new int[tinyIntIndicesVector.getValueCount()];
+            for (int i = 0; i < tinyIntIndicesVector.getValueCount(); i++) {
+                ids[i] = tinyIntIndicesVector.get(i);
+            }
+            return new DictionaryBlock(ids.length, dictionaryblock, ids);
+        }
+        else {
+            // Handle the case where the FieldVector is of an unsupported type
+            throw new IllegalArgumentException("Unsupported FieldVector type: " + fieldVector.getClass());
+        }
+    }
+
+    /**
+     * Appends positions {@code startIndex} (inclusive) to {@code endIndex} (exclusive) of
+     * {@code vector} to {@code builder}, dispatching on the concrete Arrow vector class.
+     */
+    private void assignBlockFromValueVector(ValueVector vector, Type type, BlockBuilder builder, int startIndex, int endIndex)
+    {
+        if (vector instanceof BitVector) {
+            assignBlockFromBitVector((BitVector) vector, type, builder, startIndex, endIndex);
+        }
+        else if (vector instanceof TinyIntVector) {
+            assignBlockFromTinyIntVector((TinyIntVector) vector, type, builder, startIndex, endIndex);
+        }
+        else if (vector instanceof IntVector) {
+            assignBlockFromIntVector((IntVector) vector, type, builder, startIndex, endIndex);
+        }
+        else if (vector instanceof SmallIntVector) {
+            assignBlockFromSmallIntVector((SmallIntVector) vector, type, builder, startIndex, endIndex);
+        }
+        else if (vector instanceof BigIntVector) {
+            assignBlockFromBigIntVector((BigIntVector) vector, type, builder, startIndex, endIndex);
+        }
+        else if (vector instanceof DecimalVector) {
+            assignBlockFromDecimalVector((DecimalVector) vector, type, builder, startIndex, endIndex);
+        }
+        else if (vector instanceof NullVector) {
+            assignBlockFromNullVector((NullVector) vector, type, builder, startIndex, endIndex);
+        }
+        else if (vector instanceof TimeStampMicroVector) {
+            assignBlockFromTimeStampMicroVector((TimeStampMicroVector) vector, type, builder, startIndex, endIndex);
+        }
+        else if (vector instanceof TimeStampMilliVector) {
+            assignBlockFromTimeStampMilliVector((TimeStampMilliVector) vector, type, builder, startIndex, endIndex);
+        }
+        else if (vector instanceof Float4Vector) {
+            assignBlockFromFloat4Vector((Float4Vector) vector, type, builder, startIndex, endIndex);
+        }
+        else if (vector instanceof Float8Vector) {
+            assignBlockFromFloat8Vector((Float8Vector) vector, type, builder, startIndex, endIndex);
+        }
+        else if (vector instanceof VarCharVector) {
+            if (type instanceof CharType) {
+                assignCharTypeBlockFromVarcharVector((VarCharVector) vector, type, builder, startIndex, endIndex);
+            }
+            else if (type instanceof TimeType) {
+                assignTimeTypeBlockFromVarcharVector((VarCharVector) vector, type, builder, startIndex, endIndex);
+            }
+            else {
+                assignBlockFromVarCharVector((VarCharVector) vector, type, builder, startIndex, endIndex);
+            }
+        }
+        else if (vector instanceof VarBinaryVector) {
+            assignBlockFromVarBinaryVector((VarBinaryVector) vector, type, builder, startIndex, endIndex);
+        }
+        else if (vector instanceof DateDayVector) {
+            assignBlockFromDateDayVector((DateDayVector) vector, type, builder, startIndex, endIndex);
+        }
+        else if (vector instanceof DateMilliVector) {
+            assignBlockFromDateMilliVector((DateMilliVector) vector, type, builder, startIndex, endIndex);
+        }
+        else if (vector instanceof TimeMilliVector) {
+            assignBlockFromTimeMilliVector((TimeMilliVector) vector, type, builder, startIndex, endIndex);
+        }
+        else if (vector instanceof TimeSecVector) {
+            assignBlockFromTimeSecVector((TimeSecVector) vector, type, builder, startIndex, endIndex);
+        }
+        else if (vector instanceof TimeStampSecVector) {
+            assignBlockFromTimeStampSecVector((TimeStampSecVector) vector, type, builder, startIndex, endIndex);
+        }
+        else if (vector instanceof TimeMicroVector) {
+            assignBlockFromTimeMicroVector((TimeMicroVector) vector, type, builder, startIndex, endIndex);
+        }
+        else if (vector instanceof TimeStampMilliTZVector) {
+            assignBlockFromTimeMilliTZVector((TimeStampMilliTZVector) vector, type, builder, startIndex, endIndex);
+        }
+        else if (vector instanceof MapVector) {
+            // NOTE: MapVector is also instanceof ListVector, so check for Map first
+            assignBlockFromMapVector((MapVector) vector, type, builder, startIndex, endIndex);
+        }
+        else if (vector instanceof ListVector) {
+            assignBlockFromListVector((ListVector) vector, type, builder, startIndex, endIndex);
+        }
+        else if (vector instanceof StructVector) {
+            assignBlockFromStructVector((StructVector) vector, type, builder, startIndex, endIndex);
+        }
+        else {
+            throw new UnsupportedOperationException("Unsupported vector type: " + vector.getClass());
+        }
+    }
+
+    public void assignBlockFromBitVector(BitVector vector, Type type, BlockBuilder builder, int startIndex, int endIndex)
+    {
+        for (int i = startIndex; i < endIndex; i++) {
+            if (vector.isNull(i)) {
+                builder.appendNull();
+            }
+            else {
+                type.writeBoolean(builder, vector.get(i) == 1);
+            }
+        }
+    }
+
+    public void assignBlockFromIntVector(IntVector vector, Type type, BlockBuilder builder, int startIndex, int endIndex)
+    {
+        for (int i = startIndex; i < endIndex; i++) {
+            if (vector.isNull(i)) {
+                builder.appendNull();
+            }
+            else {
+                type.writeLong(builder, vector.get(i));
+            }
+        }
+    }
+
+    public void assignBlockFromSmallIntVector(SmallIntVector vector, Type type, BlockBuilder builder, int startIndex, int endIndex)
+    {
+        for (int i = startIndex; i < endIndex; i++) {
+            if (vector.isNull(i)) {
+                builder.appendNull();
+            }
+            else {
+                type.writeLong(builder, vector.get(i));
+            }
+        }
+    }
+
+    public void assignBlockFromTinyIntVector(TinyIntVector vector, Type type, BlockBuilder builder, int startIndex, int endIndex)
+    {
+        for (int i = startIndex; i < endIndex; i++) {
+            if (vector.isNull(i)) {
+                builder.appendNull();
+            }
+            else {
+                type.writeLong(builder, vector.get(i));
+            }
+        }
+    }
+
+    public void assignBlockFromBigIntVector(BigIntVector vector, Type type, BlockBuilder builder, int startIndex, int endIndex)
+    {
+        for (int i = startIndex; i < endIndex; i++) {
+            if (vector.isNull(i)) {
+                builder.appendNull();
+            }
+            else {
+                type.writeLong(builder, vector.get(i));
+            }
+        }
+    }
+
+    public void assignBlockFromDecimalVector(DecimalVector vector, Type type, BlockBuilder builder, int startIndex, int endIndex)
+    {
+        if (!(type instanceof DecimalType)) {
+            throw new IllegalArgumentException("Type must be a DecimalType for DecimalVector");
+        }
+
+        DecimalType decimalType = (DecimalType) type;
+
+        for (int i = startIndex; i < endIndex; i++) {
+            if (vector.isNull(i)) {
+                builder.appendNull();
+            }
+            else {
+                BigDecimal decimal = vector.getObject(i); // Get the BigDecimal value
+                if (decimalType.isShort()) {
+                    // Write through the type, like every other converter here, instead of calling the builder directly
+                    decimalType.writeLong(builder, decimal.unscaledValue().longValue());
+                }
+                else {
+                    Slice slice = Decimals.encodeScaledValue(decimal);
+                    decimalType.writeSlice(builder, slice, 0, slice.length());
+                }
+            }
+        }
+    }
+
+    public void assignBlockFromNullVector(NullVector vector, Type type, BlockBuilder builder, int startIndex, int endIndex)
+    {
+        for (int i = startIndex; i < endIndex; i++) {
+            builder.appendNull();
+        }
+    }
+
+    public void assignBlockFromTimeStampMicroVector(TimeStampMicroVector vector, Type type, BlockBuilder builder, int startIndex, int endIndex)
+    {
+        if (!(type instanceof TimestampType)) {
+            throw new IllegalArgumentException("Expected TimestampType but got " + type.getClass().getName());
+        }
+
+        for (int i = startIndex; i < endIndex; i++) {
+            if (vector.isNull(i)) {
+                builder.appendNull();
+            }
+            else {
+                long micros = vector.get(i);
+                long millis = TimeUnit.MICROSECONDS.toMillis(micros);
+                type.writeLong(builder, millis);
+            }
+        }
+    }
+
+    public void assignBlockFromTimeStampMilliVector(TimeStampMilliVector vector, Type type, BlockBuilder builder, int startIndex, int endIndex)
+    {
+        if (!(type instanceof TimestampType)) {
+            throw new IllegalArgumentException("Expected TimestampType but got " + type.getClass().getName());
+        }
+
+        for (int i = startIndex; i < endIndex; i++) {
+            if (vector.isNull(i)) {
+                builder.appendNull();
+            }
+            else {
+                long millis = vector.get(i);
+                type.writeLong(builder, millis);
+            }
+        }
+    }
+
+    public void assignBlockFromFloat8Vector(Float8Vector vector, Type type, BlockBuilder builder, int startIndex, int endIndex)
+    {
+        for (int i = startIndex; i < endIndex; i++) {
+            if (vector.isNull(i)) {
+                builder.appendNull();
+            }
+            else {
+                type.writeDouble(builder, vector.get(i));
+            }
+        }
+    }
+
+    public void assignBlockFromFloat4Vector(Float4Vector vector, Type type, BlockBuilder builder, int startIndex, int endIndex)
+    {
+        for (int i = startIndex; i < endIndex; i++) {
+            if (vector.isNull(i)) {
+                builder.appendNull();
+            }
+            else {
+                int intBits = Float.floatToIntBits(vector.get(i));
+                type.writeLong(builder, intBits);
+            }
+        }
+    }
+
+    public void assignBlockFromVarBinaryVector(VarBinaryVector vector, Type type, BlockBuilder builder, int startIndex, int endIndex)
+    {
+        for (int i = startIndex; i < endIndex; i++) {
+            if (vector.isNull(i)) {
+                builder.appendNull();
+            }
+            else {
+                byte[] value = vector.get(i);
+                type.writeSlice(builder, Slices.wrappedBuffer(value));
+            }
+        }
+    }
+
+    public void assignBlockFromVarCharVector(VarCharVector vector, Type type, BlockBuilder builder, int startIndex, int endIndex)
+    {
+        if (!(type instanceof VarcharType)) {
+            throw new IllegalArgumentException("Expected VarcharType but got " + type.getClass().getName());
+        }
+
+        for (int i = startIndex; i < endIndex; i++) {
+            if (vector.isNull(i)) {
+                builder.appendNull();
+            }
+            else {
+                // Directly create a Slice from the raw byte array
+                byte[] rawBytes = vector.get(i);
+                Slice slice = Slices.wrappedBuffer(rawBytes);
+                // Write the Slice directly to the builder
+                type.writeSlice(builder, slice);
+            }
+        }
+    }
+
+    public void assignBlockFromDateDayVector(DateDayVector vector, Type type, BlockBuilder builder, int startIndex, int endIndex)
+    {
+        if (!(type instanceof DateType)) {
+            throw new IllegalArgumentException("Expected DateType but got " + type.getClass().getName());
+        }
+
+        for (int i = startIndex; i < endIndex; i++) {
+            if (vector.isNull(i)) {
+                builder.appendNull();
+            }
+            else {
+                type.writeLong(builder, vector.get(i));
+            }
+        }
+    }
+
+    public void assignBlockFromDateMilliVector(DateMilliVector vector, Type type, BlockBuilder builder, int startIndex, int endIndex)
+    {
+        if (!(type instanceof DateType)) {
+            throw new IllegalArgumentException("Expected DateType but got " + type.getClass().getName());
+        }
+
+        for (int i = startIndex; i < endIndex; i++) {
+            if (vector.isNull(i)) {
+                builder.appendNull();
+            }
+            else {
+                DateType dateType = (DateType) type;
+                long days = TimeUnit.MILLISECONDS.toDays(vector.get(i));
+                dateType.writeLong(builder, days);
+            }
+        }
+    }
+
+    public void assignBlockFromTimeSecVector(TimeSecVector vector, Type type, BlockBuilder builder, int startIndex, int endIndex)
+    {
+        if (!(type instanceof TimeType)) {
+            throw new IllegalArgumentException("Type must be a TimeType for TimeSecVector");
+        }
+
+        for (int i = startIndex; i < endIndex; i++) {
+            if (vector.isNull(i)) {
+                builder.appendNull();
+            }
+            else {
+                int value = vector.get(i);
+                long millis = TimeUnit.SECONDS.toMillis(value);
+                type.writeLong(builder, millis);
+            }
+        }
+    }
+
+    public void assignBlockFromTimeMilliVector(TimeMilliVector vector, Type type, BlockBuilder builder, int startIndex, int endIndex)
+    {
+        if (!(type instanceof TimeType)) {
+            throw new IllegalArgumentException("Type must be a TimeType for TimeMilliVector");
+        }
+
+        for (int i = startIndex; i < endIndex; i++) {
+            if (vector.isNull(i)) {
+                builder.appendNull();
+            }
+            else {
+                long millis = vector.get(i);
+                type.writeLong(builder, millis);
+            }
+        }
+    }
+
+    public void assignBlockFromTimeMicroVector(TimeMicroVector vector, Type type, BlockBuilder builder, int startIndex, int endIndex)
+    {
+        if (!(type instanceof TimeType)) {
+            throw new IllegalArgumentException("Type must be a TimeType for TimeMicroVector");
+        }
+        for (int i = startIndex; i < endIndex; i++) {
+            if (vector.isNull(i)) {
+                builder.appendNull();
+            }
+            else {
+                long value = vector.get(i);
+                long millis = TimeUnit.MICROSECONDS.toMillis(value);
+                type.writeLong(builder, millis);
+            }
+        }
+    }
+
+    public void assignBlockFromTimeStampSecVector(TimeStampSecVector vector, Type type, BlockBuilder builder, int startIndex, int endIndex)
+    {
+        if (!(type instanceof TimestampType)) {
+            throw new IllegalArgumentException("Type must be a TimestampType for TimeStampSecVector");
+        }
+
+        for (int i = startIndex; i < endIndex; i++) {
+            if (vector.isNull(i)) {
+                builder.appendNull();
+            }
+            else {
+                long value = vector.get(i);
+                long millis = TimeUnit.SECONDS.toMillis(value);
+                type.writeLong(builder, millis);
+            }
+        }
+    }
+
+    public void assignCharTypeBlockFromVarcharVector(VarCharVector vector, Type type, BlockBuilder builder, int startIndex, int endIndex)
+    {
+        for (int i = startIndex; i < endIndex; i++) {
+            if (vector.isNull(i)) {
+                builder.appendNull();
+            }
+            else {
+                String value = new String(vector.get(i), StandardCharsets.UTF_8);
+                type.writeSlice(builder, Slices.utf8Slice(CharMatcher.is(' ').trimTrailingFrom(value)));
+            }
+        }
+    }
+
+    public void assignTimeTypeBlockFromVarcharVector(VarCharVector vector, Type type, BlockBuilder builder, int startIndex, int endIndex)
+    {
+        for (int i = startIndex; i < endIndex; i++) {
+            if (vector.isNull(i)) {
+                builder.appendNull();
+            }
+            else {
+                String timeString = new String(vector.get(i), StandardCharsets.UTF_8);
+                LocalTime time = LocalTime.parse(timeString);
+                long millis = Duration.between(LocalTime.MIN, time).toMillis();
+                type.writeLong(builder, millis);
+            }
+        }
+    }
+
+    public void assignBlockFromTimeMilliTZVector(TimeStampMilliTZVector vector, Type type, BlockBuilder builder, int startIndex, int endIndex)
+    {
+        if (!(type instanceof TimestampType)) {
+            throw new IllegalArgumentException("Type must be a TimestampType for TimeStampMilliTZVector");
+        }
+
+        for (int i = startIndex; i < endIndex; i++) {
+            if (vector.isNull(i)) {
+                builder.appendNull();
+            }
+            else {
+                long millis = vector.get(i);
+                type.writeLong(builder, millis);
+            }
+        }
+    }
+
+    public void assignBlockFromListVector(ListVector vector, Type type, BlockBuilder builder, int startIndex, int endIndex)
+    {
+        if (!(type instanceof ArrayType)) {
+            throw new IllegalArgumentException("Type must be an ArrayType for ListVector");
+        }
+
+        ArrayType arrayType = (ArrayType) type;
+        Type elementType = arrayType.getElementType();
+
+        for (int i = startIndex; i < endIndex; i++) {
+            if (vector.isNull(i)) {
+                builder.appendNull();
+            }
+            else {
+                BlockBuilder elementBuilder = builder.beginBlockEntry();
+                assignBlockFromValueVector(
+                        vector.getDataVector(), elementType, elementBuilder, vector.getElementStartIndex(i), vector.getElementEndIndex(i));
+                builder.closeEntry();
+            }
+        }
+    }
+
+    public void assignBlockFromMapVector(MapVector vector, Type type, BlockBuilder builder, int startIndex, int endIndex)
+    {
+        if (!(type instanceof MapType)) {
+            throw new IllegalArgumentException("Type must be a MapType for MapVector");
+        }
+
+        MapType mapType = (MapType) type;
+        StructVector entryVector = (StructVector) vector.getDataVector();
+        ValueVector keyVector = entryVector.getChildByOrdinal(0);
+        ValueVector valueVector = entryVector.getChildByOrdinal(1);
+
+        for (int i = startIndex; i < endIndex; i++) {
+            if (vector.isNull(i)) {
+                builder.appendNull();
+            }
+            else {
+                BlockBuilder entryBuilder = builder.beginBlockEntry();
+                int entryStart = vector.getElementStartIndex(i);
+                int entryEnd = vector.getElementEndIndex(i);
+                // Keys and values are appended alternately, one key/value pair per map entry
+                for (int entryIndex = entryStart; entryIndex < entryEnd; entryIndex++) {
+                    assignBlockFromValueVector(keyVector, mapType.getKeyType(), entryBuilder, entryIndex, entryIndex + 1);
+                    assignBlockFromValueVector(valueVector, mapType.getValueType(), entryBuilder, entryIndex, entryIndex + 1);
+                }
+                builder.closeEntry();
+            }
+        }
+    }
+
+    public void assignBlockFromStructVector(StructVector vector, Type type, BlockBuilder builder, int startIndex, int endIndex)
+    {
+        if (!(type instanceof RowType)) {
+            throw new IllegalArgumentException("Type must be a RowType for StructVector");
+        }
+
+        RowType rowType = (RowType) type;
+
+        for (int i = startIndex; i < endIndex; i++) {
+            if (vector.isNull(i)) {
+                builder.appendNull();
+            }
+            else {
+                BlockBuilder childBuilder = builder.beginBlockEntry();
+                List<Type> childTypes = rowType.getTypeParameters();
+                // Every child vector is read at the same row index i as the struct itself
+                for (int childIndex = 0; childIndex < childTypes.size(); childIndex++) {
+                    Type childType = childTypes.get(childIndex);
+                    ValueVector childVector = vector.getChildByOrdinal(childIndex);
+                    assignBlockFromValueVector(childVector, childType, childBuilder, i, i + 1);
+                }
+                builder.closeEntry();
+            }
+        }
+    }
+}
diff --git a/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowColumnHandle.java b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowColumnHandle.java
new file mode 100644
index 0000000000000..6a1430b007822
--- /dev/null
+++ b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowColumnHandle.java
@@ -0,0 +1,61 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ +package com.facebook.plugin.arrow; + +import com.facebook.presto.common.type.Type; +import com.facebook.presto.spi.ColumnHandle; +import com.facebook.presto.spi.ColumnMetadata; +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; + +import static java.util.Objects.requireNonNull; + +public class ArrowColumnHandle + implements ColumnHandle +{ + private final String columnName; + private final Type columnType; + + @JsonCreator + public ArrowColumnHandle( + @JsonProperty("columnName") String columnName, + @JsonProperty("columnType") Type columnType) + { + this.columnName = requireNonNull(columnName, "columnName is null"); + this.columnType = requireNonNull(columnType, "columnType is null"); + } + + @JsonProperty + public String getColumnName() + { + return columnName; + } + + @JsonProperty + public Type getColumnType() + { + return columnType; + } + + public ColumnMetadata getColumnMetadata() + { + return new ColumnMetadata(columnName, columnType); + } + + @Override + public String toString() + { + return columnName + ":" + columnType; + } +} diff --git a/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowConnector.java b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowConnector.java new file mode 100644 index 0000000000000..3bae2b16c484b --- /dev/null +++ b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowConnector.java @@ -0,0 +1,77 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.plugin.arrow; + +import com.facebook.presto.spi.connector.Connector; +import com.facebook.presto.spi.connector.ConnectorMetadata; +import com.facebook.presto.spi.connector.ConnectorPageSourceProvider; +import com.facebook.presto.spi.connector.ConnectorSplitManager; +import com.facebook.presto.spi.connector.ConnectorTransactionHandle; +import com.facebook.presto.spi.transaction.IsolationLevel; +import com.google.inject.Inject; +import org.apache.arrow.memory.BufferAllocator; + +import static java.util.Objects.requireNonNull; + +public class ArrowConnector + implements Connector +{ + private final ConnectorMetadata metadata; + private final ConnectorSplitManager splitManager; + private final ConnectorPageSourceProvider pageSourceProvider; + private final BufferAllocator connectorAllocator; + + @Inject + public ArrowConnector( + ConnectorMetadata metadata, + ConnectorSplitManager splitManager, + ConnectorPageSourceProvider pageSourceProvider, + BufferAllocator connectorAllocator) + { + this.metadata = requireNonNull(metadata, "Metadata is null"); + this.splitManager = requireNonNull(splitManager, "splitManager is null"); + this.pageSourceProvider = requireNonNull(pageSourceProvider, "pageSourceProvider is null"); + this.connectorAllocator = requireNonNull(connectorAllocator, "connectorAllocator is null"); + } + + @Override + public ConnectorTransactionHandle beginTransaction(IsolationLevel isolationLevel, boolean readOnly) + { + return ArrowTransactionHandle.INSTANCE; + } + + @Override + public ConnectorMetadata getMetadata(ConnectorTransactionHandle transactionHandle) + { + return metadata; + } + + @Override + public ConnectorSplitManager getSplitManager() + { + return splitManager; + } + + @Override + public ConnectorPageSourceProvider getPageSourceProvider() + { + return pageSourceProvider; + } + + @Override + public void shutdown() 
+ { + connectorAllocator.close(); + } +} diff --git a/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowConnectorFactory.java b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowConnectorFactory.java new file mode 100644 index 0000000000000..ce95cccdd051e --- /dev/null +++ b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowConnectorFactory.java @@ -0,0 +1,104 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.plugin.arrow; + +import com.facebook.airlift.bootstrap.Bootstrap; +import com.facebook.presto.common.type.TypeManager; +import com.facebook.presto.spi.ConnectorHandleResolver; +import com.facebook.presto.spi.NodeManager; +import com.facebook.presto.spi.classloader.ThreadContextClassLoader; +import com.facebook.presto.spi.connector.Connector; +import com.facebook.presto.spi.connector.ConnectorContext; +import com.facebook.presto.spi.connector.ConnectorFactory; +import com.facebook.presto.spi.function.FunctionMetadataManager; +import com.facebook.presto.spi.function.StandardFunctionResolution; +import com.facebook.presto.spi.relation.RowExpressionService; +import com.google.common.collect.ImmutableList; +import com.google.inject.ConfigurationException; +import com.google.inject.Injector; +import com.google.inject.Module; + +import java.util.List; +import java.util.Map; + +import static com.facebook.plugin.arrow.ArrowErrorCode.ARROW_INTERNAL_ERROR; +import static 
com.google.common.base.Preconditions.checkArgument; +import static com.google.common.base.Strings.isNullOrEmpty; +import static com.google.common.base.Throwables.throwIfUnchecked; +import static com.google.inject.util.Modules.override; +import static java.util.Objects.requireNonNull; + +public class ArrowConnectorFactory + implements ConnectorFactory +{ + private final String name; + private final Module module; + private final ImmutableList extraModules; + private final ClassLoader classLoader; + + public ArrowConnectorFactory(String name, Module module, List extraModules, ClassLoader classLoader) + { + checkArgument(!isNullOrEmpty(name), "name is null or empty"); + this.name = name; + this.module = requireNonNull(module, "module is null"); + this.extraModules = ImmutableList.copyOf(requireNonNull(extraModules, "extraModules is null")); + this.classLoader = requireNonNull(classLoader, "classLoader is null"); + } + + @Override + public String getName() + { + return name; + } + + @Override + public ConnectorHandleResolver getHandleResolver() + { + return new ArrowHandleResolver(); + } + + @Override + public Connector create(String catalogName, Map requiredConfig, ConnectorContext context) + { + requireNonNull(requiredConfig, "requiredConfig is null"); + + try (ThreadContextClassLoader ignored = new ThreadContextClassLoader(classLoader)) { + Bootstrap app = new Bootstrap(ImmutableList.builder() + .add(binder -> { + binder.bind(TypeManager.class).toInstance(context.getTypeManager()); + binder.bind(FunctionMetadataManager.class).toInstance(context.getFunctionMetadataManager()); + binder.bind(StandardFunctionResolution.class).toInstance(context.getStandardFunctionResolution()); + binder.bind(RowExpressionService.class).toInstance(context.getRowExpressionService()); + binder.bind(NodeManager.class).toInstance(context.getNodeManager()); + }) + .add(override(new ArrowModule(catalogName)).with(module)) + .addAll(extraModules) + .build()); + + Injector injector = app + 
.doNotInitializeLogging() + .setRequiredConfigurationProperties(requiredConfig) + .initialize(); + + return injector.getInstance(ArrowConnector.class); + } + catch (ConfigurationException ex) { + throw new ArrowException(ARROW_INTERNAL_ERROR, "The connector instance could not be created.", ex); + } + catch (Exception e) { + throwIfUnchecked(e); + throw new RuntimeException(e); + } + } +} diff --git a/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowConnectorId.java b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowConnectorId.java new file mode 100644 index 0000000000000..dce08bac4ac24 --- /dev/null +++ b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowConnectorId.java @@ -0,0 +1,53 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.plugin.arrow; + +import java.util.Objects; + +import static java.util.Objects.requireNonNull; + +public class ArrowConnectorId +{ + private final String id; + + public ArrowConnectorId(String id) + { + this.id = requireNonNull(id, "id is null"); + } + + @Override + public String toString() + { + return id; + } + + @Override + public int hashCode() + { + return Objects.hash(id); + } + + @Override + public boolean equals(Object obj) + { + if (this == obj) { + return true; + } + if ((obj == null) || (getClass() != obj.getClass())) { + return false; + } + ArrowConnectorId other = (ArrowConnectorId) obj; + return Objects.equals(this.id, other.id); + } +} diff --git a/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowErrorCode.java b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowErrorCode.java new file mode 100644 index 0000000000000..6b217e4f56cdf --- /dev/null +++ b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowErrorCode.java @@ -0,0 +1,44 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.plugin.arrow; + +import com.facebook.presto.common.ErrorCode; +import com.facebook.presto.common.ErrorType; +import com.facebook.presto.spi.ErrorCodeSupplier; + +import static com.facebook.presto.common.ErrorType.EXTERNAL; +import static com.facebook.presto.common.ErrorType.INTERNAL_ERROR; + +public enum ArrowErrorCode + implements ErrorCodeSupplier +{ + ARROW_FLIGHT_INFO_ERROR(0, EXTERNAL), + ARROW_INTERNAL_ERROR(1, INTERNAL_ERROR), + ARROW_FLIGHT_CLIENT_ERROR(2, EXTERNAL), + ARROW_FLIGHT_METADATA_ERROR(3, EXTERNAL), + ARROW_FLIGHT_TYPE_ERROR(4, EXTERNAL); + + private final ErrorCode errorCode; + + ArrowErrorCode(int code, ErrorType type) + { + errorCode = new ErrorCode(code + 0x0510_0000, name(), type); + } + + @Override + public ErrorCode toErrorCode() + { + return errorCode; + } +} diff --git a/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowException.java b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowException.java new file mode 100644 index 0000000000000..ba2c6edba589c --- /dev/null +++ b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowException.java @@ -0,0 +1,31 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.plugin.arrow; + +import com.facebook.presto.spi.ErrorCodeSupplier; +import com.facebook.presto.spi.PrestoException; + +public class ArrowException + extends PrestoException +{ + public ArrowException(ErrorCodeSupplier errorCode, String message) + { + super(errorCode, message); + } + + public ArrowException(ErrorCodeSupplier errorCode, String message, Throwable cause) + { + super(errorCode, message, cause); + } +} diff --git a/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowFlightConfig.java b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowFlightConfig.java new file mode 100644 index 0000000000000..00698846a85a3 --- /dev/null +++ b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowFlightConfig.java @@ -0,0 +1,85 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.plugin.arrow; + +import com.facebook.airlift.configuration.Config; + +public class ArrowFlightConfig +{ + private String server; + private Boolean verifyServer; + private String flightServerSSLCertificate; + private Boolean arrowFlightServerSslEnabled; + private Integer arrowFlightPort; + + public String getFlightServerName() + { + return server; + } + + @Config("arrow-flight.server") + public ArrowFlightConfig setFlightServerName(String server) + { + this.server = server; + return this; + } + + public Boolean getVerifyServer() + { + return verifyServer; + } + + @Config("arrow-flight.server.verify") + public ArrowFlightConfig setVerifyServer(Boolean verifyServer) + { + this.verifyServer = verifyServer; + return this; + } + + public Integer getArrowFlightPort() + { + return arrowFlightPort; + } + + @Config("arrow-flight.server.port") + public ArrowFlightConfig setArrowFlightPort(Integer arrowFlightPort) + { + this.arrowFlightPort = arrowFlightPort; + return this; + } + + public String getFlightServerSSLCertificate() + { + return flightServerSSLCertificate; + } + + @Config("arrow-flight.server-ssl-certificate") + public ArrowFlightConfig setFlightServerSSLCertificate(String flightServerSSLCertificate) + { + this.flightServerSSLCertificate = flightServerSSLCertificate; + return this; + } + + public Boolean getArrowFlightServerSslEnabled() + { + return arrowFlightServerSslEnabled; + } + + @Config("arrow-flight.server-ssl-enabled") + public ArrowFlightConfig setArrowFlightServerSslEnabled(Boolean arrowFlightServerSslEnabled) + { + this.arrowFlightServerSslEnabled = arrowFlightServerSslEnabled; + return this; + } +} diff --git a/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowHandleResolver.java b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowHandleResolver.java new file mode 100644 index 0000000000000..8b231b98a6ee6 --- /dev/null +++ 
b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowHandleResolver.java @@ -0,0 +1,55 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.plugin.arrow; + +import com.facebook.presto.spi.ColumnHandle; +import com.facebook.presto.spi.ConnectorHandleResolver; +import com.facebook.presto.spi.ConnectorSplit; +import com.facebook.presto.spi.ConnectorTableHandle; +import com.facebook.presto.spi.ConnectorTableLayoutHandle; +import com.facebook.presto.spi.connector.ConnectorTransactionHandle; + +public class ArrowHandleResolver + implements ConnectorHandleResolver +{ + @Override + public Class getTableHandleClass() + { + return ArrowTableHandle.class; + } + + @Override + public Class getTableLayoutHandleClass() + { + return ArrowTableLayoutHandle.class; + } + + @Override + public Class getColumnHandleClass() + { + return ArrowColumnHandle.class; + } + + @Override + public Class getSplitClass() + { + return ArrowSplit.class; + } + + @Override + public Class getTransactionHandleClass() + { + return ArrowTransactionHandle.class; + } +} diff --git a/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowMetadata.java b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowMetadata.java new file mode 100644 index 0000000000000..977895966b291 --- /dev/null +++ b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowMetadata.java @@ -0,0 +1,194 @@ +/* + * Licensed under the Apache 
License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.plugin.arrow; + +import com.facebook.presto.common.type.Type; +import com.facebook.presto.spi.ColumnHandle; +import com.facebook.presto.spi.ColumnMetadata; +import com.facebook.presto.spi.ConnectorSession; +import com.facebook.presto.spi.ConnectorTableHandle; +import com.facebook.presto.spi.ConnectorTableLayout; +import com.facebook.presto.spi.ConnectorTableLayoutHandle; +import com.facebook.presto.spi.ConnectorTableLayoutResult; +import com.facebook.presto.spi.ConnectorTableMetadata; +import com.facebook.presto.spi.Constraint; +import com.facebook.presto.spi.NotFoundException; +import com.facebook.presto.spi.SchemaTableName; +import com.facebook.presto.spi.SchemaTablePrefix; +import com.facebook.presto.spi.connector.ConnectorMetadata; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import org.apache.arrow.vector.types.pojo.Field; +import org.apache.arrow.vector.types.pojo.Schema; + +import javax.inject.Inject; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.Set; + +import static com.facebook.plugin.arrow.ArrowErrorCode.ARROW_FLIGHT_METADATA_ERROR; +import static com.google.common.base.Preconditions.checkArgument; +import static java.util.Objects.requireNonNull; + +public class ArrowMetadata + implements ConnectorMetadata +{ + private 
final BaseArrowFlightClientHandler clientHandler; + private final ArrowBlockBuilder arrowBlockBuilder; + + @Inject + public ArrowMetadata(BaseArrowFlightClientHandler clientHandler, ArrowBlockBuilder arrowBlockBuilder) + { + this.clientHandler = requireNonNull(clientHandler, "clientHandler is null"); + this.arrowBlockBuilder = requireNonNull(arrowBlockBuilder, "arrowBlockBuilder is null"); + } + + @Override + public final List listSchemaNames(ConnectorSession session) + { + return clientHandler.listSchemaNames(session); + } + + @Override + public final List listTables(ConnectorSession session, Optional schemaName) + { + return clientHandler.listTables(session, schemaName); + } + + @Override + public ConnectorTableHandle getTableHandle(ConnectorSession session, SchemaTableName tableName) + { + if (!listSchemaNames(session).contains(tableName.getSchemaName())) { + return null; + } + + if (!listTables(session, Optional.ofNullable(tableName.getSchemaName())).contains(tableName)) { + return null; + } + return new ArrowTableHandle(tableName.getSchemaName(), tableName.getTableName()); + } + + public List getColumnsList(String schema, String table, ConnectorSession connectorSession) + { + try { + Schema flightSchema = clientHandler.getSchemaForTable(schema, table, connectorSession); + return flightSchema.getFields(); + } + catch (Exception e) { + throw new ArrowException(ARROW_FLIGHT_METADATA_ERROR, "Table columns could not be listed for table: " + table, e); + } + } + + @Override + public Map getColumnHandles(ConnectorSession session, ConnectorTableHandle tableHandle) + { + Map columnHandles = new HashMap<>(); + + String schemaValue = ((ArrowTableHandle) tableHandle).getSchema(); + String tableValue = ((ArrowTableHandle) tableHandle).getTable(); + List columnList = getColumnsList(schemaValue, tableValue, session); + + for (Field field : columnList) { + String columnName = field.getName(); + Type type = getPrestoTypeFromArrowField(field); + columnHandles.put(columnName, 
new ArrowColumnHandle(columnName, type)); + } + return columnHandles; + } + + @Override + public List getTableLayouts(ConnectorSession session, ConnectorTableHandle table, Constraint constraint, Optional> desiredColumns) + { + checkArgument(table instanceof ArrowTableHandle, + "Invalid table handle: Expected an instance of ArrowTableHandle but received %s", + table.getClass().getSimpleName()); + checkArgument(desiredColumns.orElse(Collections.emptySet()).stream().allMatch(f -> f instanceof ArrowColumnHandle), + "Invalid column handles: Expected desired columns to be of type ArrowColumnHandle"); + + ArrowTableHandle tableHandle = (ArrowTableHandle) table; + + List columns = new ArrayList<>(); + if (desiredColumns.isPresent()) { + List arrowColumns = new ArrayList<>(desiredColumns.get()); + columns = (List) (List) arrowColumns; + } + + ConnectorTableLayout layout = new ConnectorTableLayout(new ArrowTableLayoutHandle(tableHandle, columns, constraint.getSummary())); + return ImmutableList.of(new ConnectorTableLayoutResult(layout, constraint.getSummary())); + } + + @Override + public ConnectorTableLayout getTableLayout(ConnectorSession session, ConnectorTableLayoutHandle handle) + { + return new ConnectorTableLayout(handle); + } + + @Override + public ConnectorTableMetadata getTableMetadata(ConnectorSession session, ConnectorTableHandle table) + { + List meta = new ArrayList<>(); + List columnList = getColumnsList(((ArrowTableHandle) table).getSchema(), ((ArrowTableHandle) table).getTable(), session); + + for (Field field : columnList) { + String columnName = field.getName(); + Type fieldType = getPrestoTypeFromArrowField(field); + meta.add(new ColumnMetadata(columnName, fieldType)); + } + return new ConnectorTableMetadata(new SchemaTableName(((ArrowTableHandle) table).getSchema(), ((ArrowTableHandle) table).getTable()), meta); + } + + @Override + public ColumnMetadata getColumnMetadata(ConnectorSession session, ConnectorTableHandle tableHandle, ColumnHandle 
columnHandle) + { + return ((ArrowColumnHandle) columnHandle).getColumnMetadata(); + } + + @Override + public Map> listTableColumns(ConnectorSession session, SchemaTablePrefix prefix) + { + requireNonNull(prefix, "prefix is null"); + ImmutableMap.Builder> columns = ImmutableMap.builder(); + List tables; + if (prefix.getSchemaName() != null && prefix.getTableName() != null) { + tables = ImmutableList.of(new SchemaTableName(prefix.getSchemaName(), prefix.getTableName())); + } + else { + tables = listTables(session, Optional.of(prefix.getSchemaName())); + } + + for (SchemaTableName tableName : tables) { + try { + ConnectorTableHandle tableHandle = getTableHandle(session, tableName); + columns.put(tableName, getTableMetadata(session, tableHandle).getColumns()); + } + catch (ClassCastException | NotFoundException e) { + throw new ArrowException(ARROW_FLIGHT_METADATA_ERROR, "Table columns could not be listed for table: " + tableName, e); + } + catch (Exception e) { + throw new ArrowException(ARROW_FLIGHT_METADATA_ERROR, e.getMessage(), e); + } + } + return columns.build(); + } + + private Type getPrestoTypeFromArrowField(Field field) + { + return arrowBlockBuilder.getPrestoTypeFromArrowField(field); + } +} diff --git a/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowModule.java b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowModule.java new file mode 100644 index 0000000000000..ed85f961b291e --- /dev/null +++ b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowModule.java @@ -0,0 +1,54 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.plugin.arrow; + +import com.facebook.presto.spi.ConnectorHandleResolver; +import com.facebook.presto.spi.connector.Connector; +import com.facebook.presto.spi.connector.ConnectorMetadata; +import com.facebook.presto.spi.connector.ConnectorPageSourceProvider; +import com.facebook.presto.spi.connector.ConnectorSplitManager; +import com.google.inject.Binder; +import com.google.inject.Module; +import com.google.inject.Scopes; +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.memory.RootAllocator; + +import static com.facebook.airlift.configuration.ConfigBinder.configBinder; +import static java.util.Objects.requireNonNull; + +public class ArrowModule + implements Module +{ + protected final String connectorId; + + public ArrowModule(String connectorId) + { + this.connectorId = requireNonNull(connectorId, "connector id is null"); + } + + public void configure(Binder binder) + { + configBinder(binder).bindConfig(ArrowFlightConfig.class); + binder.bind(BufferAllocator.class).to(RootAllocator.class).in(Scopes.SINGLETON); + binder.bind(ConnectorSplitManager.class).to(ArrowSplitManager.class).in(Scopes.SINGLETON); + binder.bind(ConnectorMetadata.class).to(ArrowMetadata.class).in(Scopes.SINGLETON); + binder.bind(ArrowConnector.class).in(Scopes.SINGLETON); + binder.bind(ArrowConnectorId.class).toInstance(new ArrowConnectorId(connectorId)); + binder.bind(ConnectorHandleResolver.class).to(ArrowHandleResolver.class).in(Scopes.SINGLETON); + binder.bind(ArrowPageSourceProvider.class).in(Scopes.SINGLETON); + 
binder.bind(ConnectorPageSourceProvider.class).to(ArrowPageSourceProvider.class).in(Scopes.SINGLETON); + binder.bind(Connector.class).to(ArrowConnector.class).in(Scopes.SINGLETON); + binder.bind(ArrowBlockBuilder.class).to(ArrowBlockBuilder.class).in(Scopes.SINGLETON); + } +} diff --git a/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowPageSource.java b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowPageSource.java new file mode 100644 index 0000000000000..fcacb204da7d4 --- /dev/null +++ b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowPageSource.java @@ -0,0 +1,125 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.plugin.arrow; + +import com.facebook.airlift.log.Logger; +import com.facebook.presto.common.Page; +import com.facebook.presto.common.block.Block; +import com.facebook.presto.common.type.Type; +import com.facebook.presto.spi.ConnectorPageSource; +import com.facebook.presto.spi.ConnectorSession; +import org.apache.arrow.vector.FieldVector; + +import java.util.ArrayList; +import java.util.List; + +import static com.facebook.plugin.arrow.ArrowErrorCode.ARROW_FLIGHT_CLIENT_ERROR; +import static java.util.Objects.requireNonNull; + +public class ArrowPageSource + implements ConnectorPageSource +{ + private static final Logger logger = Logger.get(ArrowPageSource.class); + private final List columnHandles; + private final ArrowBlockBuilder arrowBlockBuilder; + private final ClientClosingFlightStream flightStreamAndClient; + private boolean completed; + private int currentPosition; + + public ArrowPageSource( + ArrowSplit split, + List columnHandles, + BaseArrowFlightClientHandler clientHandler, + ConnectorSession connectorSession, + ArrowBlockBuilder arrowBlockBuilder) + { + requireNonNull(split, "split is null"); + this.columnHandles = requireNonNull(columnHandles, "columnHandles is null"); + requireNonNull(clientHandler, "clientHandler is null"); + this.arrowBlockBuilder = requireNonNull(arrowBlockBuilder, "arrowBlockBuilder is null"); + this.flightStreamAndClient = clientHandler.getFlightStream(split, connectorSession); + } + + @Override + public long getCompletedBytes() + { + return 0; + } + + @Override + public long getCompletedPositions() + { + return currentPosition; + } + + @Override + public long getReadTimeNanos() + { + return 0; + } + + @Override + public boolean isFinished() + { + return completed; + } + + @Override + public long getSystemMemoryUsage() + { + return 0; + } + + @Override + public Page getNextPage() + { + logger.debug("Reading next Arrow record batch"); + + if (!flightStreamAndClient.next()) { + // No more streams, end 
pages + completed = true; + logger.debug("Finished reading Arrow record batches"); + return null; + } + + currentPosition = currentPosition + 1; + + // Create blocks from the loaded Arrow record batch + List blocks = new ArrayList<>(); + List vectors = flightStreamAndClient.getRoot().getFieldVectors(); + for (int columnIndex = 0; columnIndex < columnHandles.size(); columnIndex++) { + FieldVector vector = vectors.get(columnIndex); + Type type = columnHandles.get(columnIndex).getColumnType(); + Block block = arrowBlockBuilder.buildBlockFromFieldVector(vector, type, flightStreamAndClient.getDictionaryProvider()); + blocks.add(block); + } + + if (logger.isDebugEnabled()) { + logger.debug("Read Arrow record batch with rows: %s, columns: %s", flightStreamAndClient.getRoot().getRowCount(), vectors.size()); + } + + return new Page(flightStreamAndClient.getRoot().getRowCount(), blocks.toArray(new Block[0])); + } + + @Override + public void close() + { + try { + flightStreamAndClient.close(); + } + catch (Exception e) { + throw new ArrowException(ARROW_FLIGHT_CLIENT_ERROR, e.getMessage(), e); + } + } +} diff --git a/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowPageSourceProvider.java b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowPageSourceProvider.java new file mode 100644 index 0000000000000..216939831a2c8 --- /dev/null +++ b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowPageSourceProvider.java @@ -0,0 +1,54 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.plugin.arrow; + +import com.facebook.presto.spi.ColumnHandle; +import com.facebook.presto.spi.ConnectorPageSource; +import com.facebook.presto.spi.ConnectorSession; +import com.facebook.presto.spi.ConnectorSplit; +import com.facebook.presto.spi.SplitContext; +import com.facebook.presto.spi.connector.ConnectorPageSourceProvider; +import com.facebook.presto.spi.connector.ConnectorTransactionHandle; +import com.google.common.collect.ImmutableList; + +import javax.inject.Inject; + +import java.util.List; + +import static java.util.Objects.requireNonNull; + +public class ArrowPageSourceProvider + implements ConnectorPageSourceProvider +{ + private final BaseArrowFlightClientHandler clientHandler; + private final ArrowBlockBuilder arrowBlockBuilder; + + @Inject + public ArrowPageSourceProvider(BaseArrowFlightClientHandler clientHandler, ArrowBlockBuilder arrowBlockBuilder) + { + this.clientHandler = requireNonNull(clientHandler, "clientHandler is null"); + this.arrowBlockBuilder = requireNonNull(arrowBlockBuilder, "arrowBlockBuilder is null"); + } + + @Override + public ConnectorPageSource createPageSource(ConnectorTransactionHandle transactionHandle, ConnectorSession session, ConnectorSplit split, List columns, SplitContext splitContext) + { + ImmutableList.Builder columnHandles = ImmutableList.builder(); + for (ColumnHandle handle : columns) { + columnHandles.add((ArrowColumnHandle) handle); + } + ArrowSplit arrowSplit = (ArrowSplit) split; + return new ArrowPageSource(arrowSplit, columnHandles.build(), clientHandler, session, arrowBlockBuilder); + } +} diff --git a/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowPlugin.java b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowPlugin.java new file mode 100644 index 0000000000000..45197317cf81f --- /dev/null +++ 
b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowPlugin.java @@ -0,0 +1,51 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.plugin.arrow; + +import com.facebook.presto.spi.Plugin; +import com.facebook.presto.spi.connector.ConnectorFactory; +import com.google.common.collect.ImmutableList; +import com.google.inject.Module; + +import static com.google.common.base.MoreObjects.firstNonNull; +import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.base.Strings.isNullOrEmpty; +import static java.util.Objects.requireNonNull; + +public class ArrowPlugin + implements Plugin +{ + private final String name; + private final Module module; + private final ImmutableList extraModules; + + public ArrowPlugin(String name, Module module, Module... 
extraModules) + { + checkArgument(!isNullOrEmpty(name), "name is null or empty"); + this.name = name; + this.module = requireNonNull(module, "module is null"); + this.extraModules = ImmutableList.copyOf(requireNonNull(extraModules, "extraModules is null")); + } + + private static ClassLoader getClassLoader() + { + return firstNonNull(Thread.currentThread().getContextClassLoader(), ArrowPlugin.class.getClassLoader()); + } + + @Override + public Iterable getConnectorFactories() + { + return ImmutableList.of(new ArrowConnectorFactory(name, module, extraModules, getClassLoader())); + } +} diff --git a/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowSplit.java b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowSplit.java new file mode 100644 index 0000000000000..9308bd60aa934 --- /dev/null +++ b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowSplit.java @@ -0,0 +1,81 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.plugin.arrow; + +import com.facebook.presto.spi.ConnectorSplit; +import com.facebook.presto.spi.HostAddress; +import com.facebook.presto.spi.NodeProvider; +import com.facebook.presto.spi.schedule.NodeSelectionStrategy; +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; + +import javax.annotation.Nullable; + +import java.util.Collections; +import java.util.List; + +public class ArrowSplit + implements ConnectorSplit +{ + private final String schemaName; + private final String tableName; + private final byte[] flightEndpointBytes; + + @JsonCreator + public ArrowSplit( + @JsonProperty("schemaName") @Nullable String schemaName, + @JsonProperty("tableName") String tableName, + @JsonProperty("flightEndpointBytes") byte[] flightEndpointBytes) + { + this.schemaName = schemaName; + this.tableName = tableName; + this.flightEndpointBytes = flightEndpointBytes; + } + + @Override + public NodeSelectionStrategy getNodeSelectionStrategy() + { + return NodeSelectionStrategy.NO_PREFERENCE; + } + + @Override + public List getPreferredNodes(NodeProvider nodeProvider) + { + return Collections.emptyList(); + } + + @Override + public Object getInfo() + { + return this.getInfoMap(); + } + + @JsonProperty + public String getSchemaName() + { + return schemaName; + } + + @JsonProperty + public String getTableName() + { + return tableName; + } + + @JsonProperty + public byte[] getFlightEndpointBytes() + { + return flightEndpointBytes; + } +} diff --git a/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowSplitManager.java b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowSplitManager.java new file mode 100644 index 0000000000000..129a5e11f6355 --- /dev/null +++ b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowSplitManager.java @@ -0,0 +1,57 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file 
except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.plugin.arrow; + +import com.facebook.presto.spi.ConnectorSession; +import com.facebook.presto.spi.ConnectorSplitSource; +import com.facebook.presto.spi.ConnectorTableLayoutHandle; +import com.facebook.presto.spi.FixedSplitSource; +import com.facebook.presto.spi.connector.ConnectorSplitManager; +import com.facebook.presto.spi.connector.ConnectorTransactionHandle; +import org.apache.arrow.flight.FlightInfo; + +import javax.inject.Inject; + +import java.util.List; + +import static com.google.common.collect.ImmutableList.toImmutableList; +import static java.util.Objects.requireNonNull; + +public class ArrowSplitManager + implements ConnectorSplitManager +{ + private final BaseArrowFlightClientHandler clientHandler; + + @Inject + public ArrowSplitManager(BaseArrowFlightClientHandler clientHandler) + { + this.clientHandler = requireNonNull(clientHandler, "clientHandler is null"); + } + + @Override + public ConnectorSplitSource getSplits(ConnectorTransactionHandle transactionHandle, ConnectorSession session, ConnectorTableLayoutHandle layout, SplitSchedulingContext splitSchedulingContext) + { + ArrowTableLayoutHandle tableLayoutHandle = (ArrowTableLayoutHandle) layout; + ArrowTableHandle tableHandle = tableLayoutHandle.getTable(); + FlightInfo flightInfo = clientHandler.getFlightInfoForTableScan(tableLayoutHandle, session); + List splits = flightInfo.getEndpoints() + .stream() + .map(info -> new ArrowSplit( + tableHandle.getSchema(), + tableHandle.getTable(), + info.serialize().array())) + 
.collect(toImmutableList()); + return new FixedSplitSource(splits); + } +} diff --git a/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowTableHandle.java b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowTableHandle.java new file mode 100644 index 0000000000000..cef04e4372a5a --- /dev/null +++ b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowTableHandle.java @@ -0,0 +1,73 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.plugin.arrow; + +import com.facebook.presto.spi.ConnectorTableHandle; +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; + +import java.util.Objects; + +public class ArrowTableHandle + implements ConnectorTableHandle +{ + private final String schema; + private final String table; + + @JsonCreator + public ArrowTableHandle( + @JsonProperty("schema") String schema, + @JsonProperty("table") String table) + { + this.schema = schema; + this.table = table; + } + + @JsonProperty("schema") + public String getSchema() + { + return schema; + } + + @JsonProperty("table") + public String getTable() + { + return table; + } + + @Override + public String toString() + { + return schema + ":" + table; + } + + @Override + public boolean equals(Object o) + { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + ArrowTableHandle that = (ArrowTableHandle) o; + return Objects.equals(schema, that.schema) && Objects.equals(table, that.table); + } + + @Override + public int hashCode() + { + return Objects.hash(schema, table); + } +} diff --git a/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowTableLayoutHandle.java b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowTableLayoutHandle.java new file mode 100644 index 0000000000000..dfb4d91985e91 --- /dev/null +++ b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowTableLayoutHandle.java @@ -0,0 +1,87 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.plugin.arrow; + +import com.facebook.presto.common.predicate.TupleDomain; +import com.facebook.presto.spi.ColumnHandle; +import com.facebook.presto.spi.ConnectorTableLayoutHandle; +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; + +import java.util.List; +import java.util.Objects; + +import static java.util.Objects.requireNonNull; + +public class ArrowTableLayoutHandle + implements ConnectorTableLayoutHandle +{ + private final ArrowTableHandle table; + private final List columnHandles; + private final TupleDomain tupleDomain; + + @JsonCreator + public ArrowTableLayoutHandle( + @JsonProperty("table") ArrowTableHandle table, + @JsonProperty("columnHandles") List columnHandles, + @JsonProperty("tupleDomain") TupleDomain domain) + { + this.table = requireNonNull(table, "table is null"); + this.columnHandles = requireNonNull(columnHandles, "columnHandles is null"); + this.tupleDomain = requireNonNull(domain, "tupleDomain is null"); + } + + @JsonProperty("table") + public ArrowTableHandle getTable() + { + return table; + } + + @JsonProperty("tupleDomain") + public TupleDomain getTupleDomain() + { + return tupleDomain; + } + + @JsonProperty("columnHandles") + public List getColumnHandles() + { + return columnHandles; + } + + @Override + public String toString() + { + return "table:" + table + ", columnHandles:" + columnHandles + ", tupleDomain:" + tupleDomain; + } + + @Override + public boolean equals(Object o) + { + if (this == o) { + return true; + } + if (o == null 
|| getClass() != o.getClass()) { + return false; + } + ArrowTableLayoutHandle arrowTableLayoutHandle = (ArrowTableLayoutHandle) o; + return Objects.equals(table, arrowTableLayoutHandle.table) && Objects.equals(columnHandles, arrowTableLayoutHandle.columnHandles) && Objects.equals(tupleDomain, arrowTableLayoutHandle.tupleDomain); + } + + @Override + public int hashCode() + { + return Objects.hash(table, columnHandles, tupleDomain); + } +} diff --git a/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowTransactionHandle.java b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowTransactionHandle.java new file mode 100644 index 0000000000000..07eb7385cfbcf --- /dev/null +++ b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ArrowTransactionHandle.java @@ -0,0 +1,22 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+package com.facebook.plugin.arrow;
+
+import com.facebook.presto.spi.connector.ConnectorTransactionHandle;
+
+/**
+ * Transaction handle for the Arrow connector, modelled as an enum with a
+ * single constant so that exactly one instance exists.
+ */
+public enum ArrowTransactionHandle
+        implements ConnectorTransactionHandle
+{
+    INSTANCE
+}
diff --git a/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/BaseArrowFlightClientHandler.java b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/BaseArrowFlightClientHandler.java
new file mode 100644
index 0000000000000..573265ef637d4
--- /dev/null
+++ b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/BaseArrowFlightClientHandler.java
@@ -0,0 +1,151 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ +package com.facebook.plugin.arrow; + +import com.facebook.presto.spi.ConnectorSession; +import com.facebook.presto.spi.SchemaTableName; +import org.apache.arrow.flight.CallOption; +import org.apache.arrow.flight.FlightClient; +import org.apache.arrow.flight.FlightDescriptor; +import org.apache.arrow.flight.FlightEndpoint; +import org.apache.arrow.flight.FlightInfo; +import org.apache.arrow.flight.FlightRuntimeException; +import org.apache.arrow.flight.Location; +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.vector.types.pojo.Schema; + +import java.io.IOException; +import java.io.InputStream; +import java.net.URISyntaxException; +import java.nio.ByteBuffer; +import java.nio.file.Paths; +import java.util.List; +import java.util.Optional; + +import static com.facebook.plugin.arrow.ArrowErrorCode.ARROW_FLIGHT_CLIENT_ERROR; +import static com.facebook.plugin.arrow.ArrowErrorCode.ARROW_FLIGHT_INFO_ERROR; +import static com.facebook.plugin.arrow.ArrowErrorCode.ARROW_FLIGHT_METADATA_ERROR; +import static java.nio.file.Files.newInputStream; +import static java.util.Objects.requireNonNull; + +public abstract class BaseArrowFlightClientHandler +{ + private final ArrowFlightConfig config; + private final BufferAllocator allocator; + + public BaseArrowFlightClientHandler(BufferAllocator allocator, ArrowFlightConfig config) + { + this.allocator = requireNonNull(allocator, "allocator is null"); + this.config = requireNonNull(config, "config is null"); + } + + protected FlightClient createFlightClient() + { + Location location; + if (config.getArrowFlightServerSslEnabled() != null && !config.getArrowFlightServerSslEnabled()) { + location = Location.forGrpcInsecure(config.getFlightServerName(), config.getArrowFlightPort()); + } + else { + location = Location.forGrpcTls(config.getFlightServerName(), config.getArrowFlightPort()); + } + return createFlightClient(location); + } + + protected FlightClient createFlightClient(Location location) + { + try { 
+ Optional trustedCertificate = Optional.empty(); + FlightClient.Builder flightClientBuilder = FlightClient.builder(allocator, location); + if (config.getVerifyServer() != null && !config.getVerifyServer()) { + flightClientBuilder.verifyServer(false); + } + else if (config.getFlightServerSSLCertificate() != null) { + trustedCertificate = Optional.of(newInputStream(Paths.get(config.getFlightServerSSLCertificate()))); + flightClientBuilder.trustedCertificates(trustedCertificate.get()).useTls(); + } + + FlightClient flightClient = flightClientBuilder.build(); + if (trustedCertificate.isPresent()) { + trustedCertificate.get().close(); + } + + return flightClient; + } + catch (Exception e) { + throw new ArrowException(ARROW_FLIGHT_CLIENT_ERROR, "Error creating flight client: " + e.getMessage(), e); + } + } + + public abstract CallOption[] getCallOptions(ConnectorSession connectorSession); + + protected FlightInfo getFlightInfo(FlightDescriptor flightDescriptor, ConnectorSession connectorSession) + { + try (FlightClient client = createFlightClient()) { + CallOption[] callOptions = getCallOptions(connectorSession); + return client.getInfo(flightDescriptor, callOptions); + } + catch (InterruptedException e) { + throw new ArrowException(ARROW_FLIGHT_INFO_ERROR, "Error getting flight information: " + e.getMessage(), e); + } + } + + protected ClientClosingFlightStream getFlightStream(ArrowSplit split, ConnectorSession connectorSession) + { + ByteBuffer endpointBytes = ByteBuffer.wrap(split.getFlightEndpointBytes()); + try { + FlightEndpoint endpoint = FlightEndpoint.deserialize(endpointBytes); + FlightClient client = endpoint.getLocations().stream() + .findAny() + .map(this::createFlightClient) + .orElseGet(this::createFlightClient); + return new ClientClosingFlightStream( + client.getStream(endpoint.getTicket(), getCallOptions(connectorSession)), + client); + } + catch (FlightRuntimeException | IOException | URISyntaxException e) { + throw new 
ArrowException(ARROW_FLIGHT_CLIENT_ERROR, e.getMessage(), e); + } + } + + public Schema getSchema(FlightDescriptor flightDescriptor, ConnectorSession connectorSession) + { + try (FlightClient client = createFlightClient()) { + CallOption[] callOptions = this.getCallOptions(connectorSession); + return client.getSchema(flightDescriptor, callOptions).getSchema(); + } + catch (InterruptedException e) { + throw new ArrowException(ARROW_FLIGHT_METADATA_ERROR, "Error getting schema for flight: " + e.getMessage(), e); + } + } + + public abstract List listSchemaNames(ConnectorSession session); + + public abstract List listTables(ConnectorSession session, Optional schemaName); + + protected abstract FlightDescriptor getFlightDescriptorForSchema(String schemaName, String tableName); + + protected abstract FlightDescriptor getFlightDescriptorForTableScan(ArrowTableLayoutHandle tableLayoutHandle); + + public Schema getSchemaForTable(String schemaName, String tableName, ConnectorSession connectorSession) + { + FlightDescriptor flightDescriptor = getFlightDescriptorForSchema(schemaName, tableName); + return getSchema(flightDescriptor, connectorSession); + } + + public FlightInfo getFlightInfoForTableScan(ArrowTableLayoutHandle tableLayoutHandle, ConnectorSession session) + { + FlightDescriptor flightDescriptor = getFlightDescriptorForTableScan(tableLayoutHandle); + return getFlightInfo(flightDescriptor, session); + } +} diff --git a/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ClientClosingFlightStream.java b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ClientClosingFlightStream.java new file mode 100644 index 0000000000000..30ca33d2276a9 --- /dev/null +++ b/presto-base-arrow-flight/src/main/java/com/facebook/plugin/arrow/ClientClosingFlightStream.java @@ -0,0 +1,60 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.facebook.plugin.arrow;
+
+import org.apache.arrow.flight.FlightStream;
+import org.apache.arrow.vector.VectorSchemaRoot;
+import org.apache.arrow.vector.dictionary.DictionaryProvider;
+
+import static java.util.Objects.requireNonNull;
+
+/**
+ * Pairs a {@link FlightStream} with the client that produced it so that
+ * closing this object closes the stream first and then the client.
+ */
+public class ClientClosingFlightStream
+        implements AutoCloseable
+{
+    private final FlightStream flightStream;
+    private final AutoCloseable flightClient;
+
+    /**
+     * @param flightStream stream to delegate reads to; must not be null
+     * @param flightClient resource closed after the stream; must not be null
+     */
+    public ClientClosingFlightStream(FlightStream flightStream, AutoCloseable flightClient)
+    {
+        this.flightStream = requireNonNull(flightStream, "flightStream is null");
+        this.flightClient = requireNonNull(flightClient, "flightClient is null");
+    }
+
+    /** Delegates to {@link FlightStream#getRoot()}. */
+    public VectorSchemaRoot getRoot()
+    {
+        return flightStream.getRoot();
+    }
+
+    /** Delegates to {@link FlightStream#getDictionaryProvider()}. */
+    public DictionaryProvider getDictionaryProvider()
+    {
+        return flightStream.getDictionaryProvider();
+    }
+
+    /** Delegates to {@link FlightStream#next()}. */
+    public boolean next()
+    {
+        return flightStream.next();
+    }
+
+    /**
+     * Closes the stream, then the client. The client is closed in a
+     * {@code finally} block, so it is closed even if closing the stream throws.
+     */
+    @Override
+    public void close()
+            throws Exception
+    {
+        try {
+            flightStream.close();
+        }
+        finally {
+            flightClient.close();
+        }
+    }
+}
diff --git a/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/ArrowFlightQueryRunner.java b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/ArrowFlightQueryRunner.java
new file mode 100644
index 0000000000000..f105359f0ae73
--- /dev/null
+++ b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/ArrowFlightQueryRunner.java
@@ -0,0 +1,107 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with
the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.facebook.plugin.arrow; + +import com.facebook.airlift.log.Logger; +import com.facebook.airlift.log.Logging; +import com.facebook.plugin.arrow.testingConnector.TestingArrowFlightPlugin; +import com.facebook.plugin.arrow.testingServer.TestingArrowProducer; +import com.facebook.presto.Session; +import com.facebook.presto.tests.DistributedQueryRunner; +import com.google.common.collect.ImmutableMap; +import org.apache.arrow.flight.FlightServer; +import org.apache.arrow.flight.Location; +import org.apache.arrow.memory.RootAllocator; + +import java.io.File; +import java.util.Map; + +import static com.facebook.presto.testing.TestingSession.testSessionBuilder; + +public class ArrowFlightQueryRunner +{ + private ArrowFlightQueryRunner() + { + throw new UnsupportedOperationException("This is a utility class and cannot be instantiated"); + } + + public static DistributedQueryRunner createQueryRunner(int flightServerPort) throws Exception + { + return createQueryRunner(ImmutableMap.of("arrow-flight.server.port", String.valueOf(flightServerPort))); + } + + private static DistributedQueryRunner createQueryRunner(Map catalogProperties) throws Exception + { + return createQueryRunner(ImmutableMap.of(), catalogProperties); + } + + private static DistributedQueryRunner createQueryRunner( + Map extraProperties, + Map catalogProperties) + throws Exception + { + Session session = testSessionBuilder() + .setCatalog("arrow") + .setSchema("tpch") + .build(); + + DistributedQueryRunner queryRunner = 
DistributedQueryRunner.builder(session).setExtraProperties(extraProperties).build(); + + try { + queryRunner.installPlugin(new TestingArrowFlightPlugin()); + + ImmutableMap.Builder properties = ImmutableMap.builder() + .putAll(catalogProperties) + .put("arrow-flight.server", "localhost") + .put("arrow-flight.server-ssl-enabled", "true") + .put("arrow-flight.server-ssl-certificate", "src/test/resources/server.crt") + .put("arrow-flight.server.verify", "true"); + + queryRunner.createCatalog("arrow", "arrow", properties.build()); + + return queryRunner; + } + catch (Exception e) { + throw new RuntimeException("Failed to create ArrowQueryRunner", e); + } + } + + public static void main(String[] args) + throws Exception + { + Logging.initialize(); + + RootAllocator allocator = new RootAllocator(Long.MAX_VALUE); + Location serverLocation = Location.forGrpcTls("localhost", 9443); + File certChainFile = new File("src/test/resources/server.crt"); + File privateKeyFile = new File("src/test/resources/server.key"); + FlightServer server = FlightServer.builder(allocator, serverLocation, new TestingArrowProducer(allocator)) + .useTls(certChainFile, privateKeyFile) + .build(); + + server.start(); + + Logger log = Logger.get(ArrowFlightQueryRunner.class); + + log.info("Server listening on port " + server.getPort()); + + DistributedQueryRunner queryRunner = createQueryRunner( + ImmutableMap.of("http-server.http.port", "8080"), + ImmutableMap.of("arrow-flight.server.port", String.valueOf(9443))); + Thread.sleep(10); + log.info("======== SERVER STARTED ========"); + log.info("\n====\n%s\n====", queryRunner.getCoordinator().getBaseUrl()); + } +} diff --git a/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/ArrowMetadataUtil.java b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/ArrowMetadataUtil.java new file mode 100644 index 0000000000000..c4ab656d41bb3 --- /dev/null +++ 
b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/ArrowMetadataUtil.java
@@ -0,0 +1,75 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.facebook.plugin.arrow;
+
+import com.facebook.airlift.json.JsonCodec;
+import com.facebook.airlift.json.JsonCodecFactory;
+import com.facebook.airlift.json.JsonObjectMapperProvider;
+import com.facebook.presto.common.type.StandardTypes;
+import com.facebook.presto.common.type.Type;
+import com.fasterxml.jackson.databind.DeserializationContext;
+import com.fasterxml.jackson.databind.deser.std.FromStringDeserializer;
+import com.google.common.collect.ImmutableMap;
+
+import java.util.Map;
+
+import static com.facebook.presto.common.type.BigintType.BIGINT;
+import static com.facebook.presto.common.type.VarcharType.VARCHAR;
+import static com.google.common.base.Preconditions.checkArgument;
+import static java.util.Locale.ENGLISH;
+import static org.testng.Assert.assertEquals;
+
+/**
+ * Test-only JSON helpers for the Arrow connector handles.
+ * <p>
+ * Exposes codecs for {@link ArrowColumnHandle} and {@link ArrowTableHandle} that are
+ * able to deserialize Presto {@link Type} values through {@link TestingTypeDeserializer},
+ * plus a round-trip assertion built on top of any codec.
+ */
+final class ArrowMetadataUtil
+{
+    private ArrowMetadataUtil() {}
+
+    public static final JsonCodec<ArrowColumnHandle> COLUMN_CODEC;
+    public static final JsonCodec<ArrowTableHandle> TABLE_CODEC;
+
+    static {
+        // Both codecs share one mapper provider so that Type fields are resolved
+        // by the testing deserializer registered here.
+        JsonObjectMapperProvider provider = new JsonObjectMapperProvider();
+        provider.setJsonDeserializers(ImmutableMap.of(Type.class, new TestingTypeDeserializer()));
+        JsonCodecFactory codecFactory = new JsonCodecFactory(provider);
+        COLUMN_CODEC = codecFactory.jsonCodec(ArrowColumnHandle.class);
+        TABLE_CODEC = codecFactory.jsonCodec(ArrowTableHandle.class);
+    }
+
+    /**
+     * Resolves a type name to a Presto {@link Type} for tests.
+     * <p>
+     * Only {@code bigint} and {@code varchar} are known; the lookup lower-cases the
+     * incoming name using the English locale before matching.
+     */
+    public static final class TestingTypeDeserializer
+            extends FromStringDeserializer<Type>
+    {
+        private final Map<String, Type> types = ImmutableMap.of(
+                StandardTypes.BIGINT, BIGINT,
+                StandardTypes.VARCHAR, VARCHAR);
+
+        public TestingTypeDeserializer()
+        {
+            super(Type.class);
+        }
+
+        /**
+         * @param value type name as it appears in the JSON document
+         * @param context unused
+         * @return the matching type
+         * @throws IllegalArgumentException if the name is not one of the known types
+         */
+        @Override
+        protected Type _deserialize(String value, DeserializationContext context)
+        {
+            Type type = types.get(value.toLowerCase(ENGLISH));
+            checkArgument(type != null, "Unknown type %s", value);
+            return type;
+        }
+    }
+
+    /**
+     * Serializes {@code object} with {@code codec}, parses the JSON back, and asserts
+     * that the parsed copy equals the original.
+     *
+     * @param codec codec used for both directions
+     * @param object value to round-trip
+     * @param <T> type handled by the codec
+     */
+    public static <T> void assertJsonRoundTrip(JsonCodec<T> codec, T object)
+    {
+        String json = codec.toJson(object);
+        T copy = codec.fromJson(json);
+        assertEquals(copy, object);
+    }
+}
diff --git a/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestArrowBlockBuilder.java b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestArrowBlockBuilder.java
new file mode 100644
index 0000000000000..f909e93c43774
--- /dev/null
+++ b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestArrowBlockBuilder.java
@@ -0,0 +1,826 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.facebook.plugin.arrow;
+
+import com.facebook.airlift.log.Logger;
+import com.facebook.presto.common.block.ArrayBlock;
+import com.facebook.presto.common.block.Block;
+import com.facebook.presto.common.block.BlockBuilder;
+import com.facebook.presto.common.block.DictionaryBlock;
+import com.facebook.presto.common.type.ArrayType;
+import com.facebook.presto.common.type.BigintType;
+import com.facebook.presto.common.type.BooleanType;
+import com.facebook.presto.common.type.DateType;
+import com.facebook.presto.common.type.DecimalType;
+import com.facebook.presto.common.type.Decimals;
+import com.facebook.presto.common.type.DoubleType;
+import com.facebook.presto.common.type.IntegerType;
+import com.facebook.presto.common.type.RowType;
+import com.facebook.presto.common.type.SmallintType;
+import com.facebook.presto.common.type.TimestampType;
+import com.facebook.presto.common.type.TinyintType;
+import com.facebook.presto.common.type.Type;
+import com.facebook.presto.common.type.VarcharType;
+import io.airlift.slice.Slice;
+import org.apache.arrow.memory.BufferAllocator;
+import org.apache.arrow.memory.RootAllocator;
+import org.apache.arrow.vector.BigIntVector;
+import org.apache.arrow.vector.BitVector;
+import org.apache.arrow.vector.DateDayVector;
+import org.apache.arrow.vector.DecimalVector;
+import org.apache.arrow.vector.FieldVector;
+import org.apache.arrow.vector.Float8Vector;
+import org.apache.arrow.vector.IntVector;
+import org.apache.arrow.vector.SmallIntVector;
+import org.apache.arrow.vector.TimeStampMicroVector;
+import org.apache.arrow.vector.TimeStampMilliVector;
+import org.apache.arrow.vector.TinyIntVector;
+import org.apache.arrow.vector.ValueVector;
+import org.apache.arrow.vector.VarCharVector;
+import org.apache.arrow.vector.complex.ListVector;
+import org.apache.arrow.vector.complex.impl.UnionListWriter;
+import org.apache.arrow.vector.dictionary.Dictionary;
+import org.apache.arrow.vector.dictionary.DictionaryEncoder;
+import org.apache.arrow.vector.dictionary.DictionaryProvider;
+import org.apache.arrow.vector.types.pojo.ArrowType;
+import org.apache.arrow.vector.types.pojo.DictionaryEncoding;
+import org.apache.arrow.vector.types.pojo.Field;
+import org.apache.arrow.vector.types.pojo.FieldType;
+import org.apache.arrow.vector.util.Text;
+import org.testng.annotations.AfterClass;
+import org.testng.annotations.BeforeClass;
+import org.testng.annotations.Ignore;
+import org.testng.annotations.Test;
+
+import java.math.BigDecimal;
+import java.math.BigInteger;
+import java.nio.charset.StandardCharsets;
+import java.time.LocalDate;
+import java.util.Arrays;
+import java.util.List;
+import java.util.Optional;
+
+import static com.facebook.presto.testing.TestingEnvironment.FUNCTION_AND_TYPE_MANAGER;
+import static org.testng.Assert.assertEquals;
+import static org.testng.Assert.assertNotNull;
+import static org.testng.Assert.assertTrue;
+
+/**
+ * Unit tests for {@link ArrowBlockBuilder}.
+ * <p>
+ * The {@code testBuildBlockFrom*} tests convert a whole Arrow vector through
+ * {@code buildBlockFromFieldVector}; the {@code testAssign*} tests drive the
+ * per-vector {@code assignBlockFrom*} methods against an explicit
+ * {@link BlockBuilder}. All assertions use TestNG's (actual, expected) order.
+ */
+public class TestArrowBlockBuilder
+{
+    private static final Logger logger = Logger.get(TestArrowBlockBuilder.class);
+    private static final int DICTIONARY_LENGTH = 10;
+    private static final int VECTOR_LENGTH = 50;
+    private BufferAllocator allocator;
+    private ArrowBlockBuilder arrowBlockBuilder;
+
+    @BeforeClass
+    public void setUp()
+    {
+        // One allocator is shared by every test in the class and closed in tearDown
+        allocator = new RootAllocator(Integer.MAX_VALUE);
+        logger.debug("Allocator initialized: %s", allocator.getName());
+        arrowBlockBuilder = new ArrowBlockBuilder(FUNCTION_AND_TYPE_MANAGER);
+    }
+
+    @AfterClass
+    public void tearDown()
+    {
+        allocator.close();
+    }
+
+    @Test
+    public void testBuildBlockFromBitVector()
+    {
+        try (BitVector bitVector = new BitVector("bitVector", allocator)) {
+            bitVector.allocateNew(3);
+
+            bitVector.set(0, 1); // true
+            bitVector.set(1, 0); // false
+            bitVector.setNull(2);
+
+            bitVector.setValueCount(3);
+
+            Block resultBlock = arrowBlockBuilder.buildBlockFromFieldVector(bitVector, BooleanType.BOOLEAN, null);
+
+            assertEquals(resultBlock.getPositionCount(), 3);
+            assertTrue(resultBlock.isNull(2));
+        }
+    }
+
+    @Test
+    public void testBuildBlockFromTinyIntVector()
+    {
+        try (TinyIntVector tinyIntVector = new TinyIntVector("tinyIntVector", allocator)) {
+            tinyIntVector.allocateNew(3);
+            tinyIntVector.set(0, 10);
+            tinyIntVector.set(1, 20);
+            tinyIntVector.setNull(2);
+
+            tinyIntVector.setValueCount(3);
+
+            Block resultBlock = arrowBlockBuilder.buildBlockFromFieldVector(tinyIntVector, TinyintType.TINYINT, null);
+
+            assertEquals(resultBlock.getPositionCount(), 3);
+            assertTrue(resultBlock.isNull(2));
+        }
+    }
+
+    @Test
+    public void testBuildBlockFromSmallIntVector()
+    {
+        try (SmallIntVector smallIntVector = new SmallIntVector("smallIntVector", allocator)) {
+            smallIntVector.allocateNew(3);
+            smallIntVector.set(0, 10);
+            smallIntVector.set(1, 20);
+            smallIntVector.setNull(2);
+
+            smallIntVector.setValueCount(3);
+
+            Block resultBlock = arrowBlockBuilder.buildBlockFromFieldVector(smallIntVector, SmallintType.SMALLINT, null);
+
+            assertEquals(resultBlock.getPositionCount(), 3);
+            assertTrue(resultBlock.isNull(2));
+        }
+    }
+
+    @Test
+    public void testBuildBlockFromIntVector()
+    {
+        try (IntVector intVector = new IntVector("intVector", allocator)) {
+            intVector.allocateNew(3);
+            intVector.set(0, 10);
+            intVector.set(1, 20);
+            intVector.set(2, 30);
+
+            intVector.setValueCount(3);
+
+            Block resultBlock = arrowBlockBuilder.buildBlockFromFieldVector(intVector, IntegerType.INTEGER, null);
+
+            assertEquals(resultBlock.getPositionCount(), 3);
+            assertEquals(resultBlock.getInt(0), 10);
+            assertEquals(resultBlock.getInt(1), 20);
+            assertEquals(resultBlock.getInt(2), 30);
+        }
+    }
+
+    @Test
+    public void testBuildBlockFromBigIntVector()
+    {
+        try (BigIntVector bigIntVector = new BigIntVector("bigIntVector", allocator)) {
+            bigIntVector.allocateNew(3);
+
+            bigIntVector.set(0, 10L);
+            bigIntVector.set(1, 20L);
+            bigIntVector.set(2, 30L);
+
+            bigIntVector.setValueCount(3);
+
+            Block resultBlock = arrowBlockBuilder.buildBlockFromFieldVector(bigIntVector, BigintType.BIGINT, null);
+
+            // BIGINT values are 64-bit, so read them back with the long accessor
+            assertEquals(resultBlock.getLong(0), 10L);
+            assertEquals(resultBlock.getLong(1), 20L);
+            assertEquals(resultBlock.getLong(2), 30L);
+        }
+    }
+
+    @Test
+    public void testBuildBlockFromDecimalVector()
+    {
+        // Precision = 10, Scale = 2
+        try (DecimalVector decimalVector = new DecimalVector("decimalVector", allocator, 10, 2)) {
+            decimalVector.allocateNew(2);
+            // Only position 0 is written; position 1 is left unset
+            decimalVector.set(0, new BigDecimal("123.45"));
+
+            decimalVector.setValueCount(2);
+
+            Block resultBlock = arrowBlockBuilder.buildBlockFromFieldVector(decimalVector, DecimalType.createDecimalType(10, 2), null);
+
+            assertEquals(resultBlock.getPositionCount(), 2);
+            assertTrue(resultBlock.isNull(1)); // the unset position is expected to come back as null
+        }
+    }
+
+    @Test
+    public void testBuildBlockFromTimeStampMicroVector()
+    {
+        try (TimeStampMicroVector timestampMicroVector = new TimeStampMicroVector("timestampMicroVector", allocator)) {
+            timestampMicroVector.allocateNew(3);
+            timestampMicroVector.set(0, 1000000L); // 1 second in microseconds
+            timestampMicroVector.set(1, 2000000L); // 2 seconds in microseconds
+            timestampMicroVector.setNull(2);
+
+            timestampMicroVector.setValueCount(3);
+
+            Block resultBlock = arrowBlockBuilder.buildBlockFromFieldVector(timestampMicroVector, TimestampType.TIMESTAMP, null);
+
+            assertEquals(resultBlock.getPositionCount(), 3);
+            assertTrue(resultBlock.isNull(2));
+            assertEquals(resultBlock.getLong(0), 1000L); // microseconds are expected back as milliseconds
+            assertEquals(resultBlock.getLong(1), 2000L);
+        }
+    }
+
+    @Test
+    public void testBuildBlockFromListVector()
+    {
+        // This test uses its own allocator (named differently so it does not shadow the field)
+        try (BufferAllocator listAllocator = new RootAllocator();
+                ListVector listVector = ListVector.empty("listVector", listAllocator)) {
+            listVector.allocateNew();
+            UnionListWriter listWriter = listVector.getWriter();
+
+            int[] data = new int[] {1, 2, 3, 10, 20, 30, 100, 200, 300, 1000, 2000, 3000};
+            int tmpIndex = 0;
+
+            // Write 4 lists of 3 integers each
+            for (int i = 0; i < 4; i++) {
+                listWriter.startList();
+                for (int j = 0; j < 3; j++) {
+                    listWriter.writeInt(data[tmpIndex]);
+                    tmpIndex++;
+                }
+                listWriter.endList();
+            }
+
+            listVector.setValueCount(4);
+
+            ArrayType arrayType = new ArrayType(IntegerType.INTEGER);
+
+            Block block = arrowBlockBuilder.buildBlockFromFieldVector(listVector, arrayType, null);
+            assertTrue(block instanceof ArrayBlock);
+
+            assertEquals(block.getPositionCount(), 4);
+            for (int j = 0; j < block.getPositionCount(); j++) {
+                Block subBlock = block.getBlock(j);
+                assertEquals(subBlock.getPositionCount(), 3);
+            }
+        }
+    }
+
+    @Test
+    public void testProcessDictionaryVector()
+    {
+        try (VarCharVector dictionaryVector = new VarCharVector("dictionary", allocator);
+                VarCharVector rawVector = new VarCharVector("raw", allocator)) {
+            dictionaryVector.allocateNew(DICTIONARY_LENGTH);
+            for (int i = 0; i < DICTIONARY_LENGTH; i++) {
+                dictionaryVector.setSafe(i, String.valueOf(i).getBytes(StandardCharsets.UTF_8));
+            }
+            dictionaryVector.setValueCount(DICTIONARY_LENGTH);
+
+            rawVector.allocateNew(VECTOR_LENGTH);
+            for (int i = 0; i < VECTOR_LENGTH; i++) {
+                int value = i % DICTIONARY_LENGTH;
+                rawVector.setSafe(i, String.valueOf(value).getBytes(StandardCharsets.UTF_8));
+            }
+            rawVector.setValueCount(VECTOR_LENGTH);
+
+            // Encode the raw vector against the dictionary, using 16-bit signed indices
+            Dictionary dictionary = new Dictionary(dictionaryVector, new DictionaryEncoding(1L, false, new ArrowType.Int(16, true)));
+            try (ValueVector encodedVector = DictionaryEncoder.encode(rawVector, dictionary);
+                    DictionaryProvider.MapDictionaryProvider dictionaryProvider = new DictionaryProvider.MapDictionaryProvider(dictionary)) {
+                assertTrue(encodedVector instanceof FieldVector);
+                Block result = arrowBlockBuilder.buildBlockFromFieldVector((FieldVector) encodedVector, VarcharType.VARCHAR, dictionaryProvider);
+
+                assertNotNull(result, "The BlockBuilder should not be null.");
+                assertEquals(result.getPositionCount(), VECTOR_LENGTH);
+            }
+        }
+    }
+
+    @Test
+    public void testBuildBlockFromDictionaryVector()
+    {
+        Dictionary dictionary = createFruitDictionary(32);
+        try (DictionaryProvider.MapDictionaryProvider dictionaryProvider = new DictionaryProvider.MapDictionaryProvider(dictionary);
+                IntVector indicesVector = (IntVector) createIndexField(dictionary).createVector(allocator)) {
+            indicesVector.allocateNew(4);
+
+            indicesVector.set(0, 0); // "apple"
+            indicesVector.set(1, 1); // "banana"
+            indicesVector.set(2, 2); // "cherry"
+            indicesVector.set(3, 2); // "cherry" again
+            indicesVector.setValueCount(4);
+
+            Block block = arrowBlockBuilder.buildBlockFromFieldVector(indicesVector, VarcharType.VARCHAR, dictionaryProvider);
+
+            assertVarcharDictionaryBlock(block, "apple", "banana", "cherry", "cherry");
+        }
+    }
+
+    @Test
+    public void testBuildBlockFromDictionaryVectorSmallInt()
+    {
+        Dictionary dictionary = createFruitDictionary(16);
+        try (DictionaryProvider.MapDictionaryProvider dictionaryProvider = new DictionaryProvider.MapDictionaryProvider(dictionary);
+                SmallIntVector indicesVector = (SmallIntVector) createIndexField(dictionary).createVector(allocator)) {
+            indicesVector.allocateNew(3);
+            indicesVector.set(0, (short) 0);
+            indicesVector.set(1, (short) 1);
+            indicesVector.set(2, (short) 2);
+            indicesVector.setValueCount(3);
+
+            Block block = arrowBlockBuilder.buildBlockFromFieldVector(indicesVector, VarcharType.VARCHAR, dictionaryProvider);
+
+            assertVarcharDictionaryBlock(block, "apple", "banana", "cherry");
+        }
+    }
+
+    @Test
+    public void testBuildBlockFromDictionaryVectorTinyInt()
+    {
+        Dictionary dictionary = createFruitDictionary(8);
+        try (DictionaryProvider.MapDictionaryProvider dictionaryProvider = new DictionaryProvider.MapDictionaryProvider(dictionary);
+                TinyIntVector indicesVector = (TinyIntVector) createIndexField(dictionary).createVector(allocator)) {
+            indicesVector.allocateNew(3);
+            indicesVector.set(0, (byte) 0);
+            indicesVector.set(1, (byte) 1);
+            indicesVector.set(2, (byte) 2);
+            indicesVector.setValueCount(3);
+
+            Block block = arrowBlockBuilder.buildBlockFromFieldVector(indicesVector, VarcharType.VARCHAR, dictionaryProvider);
+
+            assertVarcharDictionaryBlock(block, "apple", "banana", "cherry");
+        }
+    }
+
+    @Test
+    public void testAssignVarcharType()
+    {
+        try (VarCharVector vector = new VarCharVector("varCharVector", allocator)) {
+            vector.allocateNew(3);
+
+            String value = "test_string";
+            vector.set(0, new Text(value));
+            vector.setValueCount(1);
+
+            Type varcharType = VarcharType.createUnboundedVarcharType();
+            BlockBuilder builder = varcharType.createBlockBuilder(null, 1);
+
+            arrowBlockBuilder.assignBlockFromVarCharVector(vector, varcharType, builder, 0, vector.getValueCount());
+
+            Block block = builder.build();
+            Slice result = varcharType.getSlice(block, 0);
+            assertEquals(result.toStringUtf8(), value);
+        }
+    }
+
+    @Test
+    public void testAssignSmallintType()
+    {
+        try (SmallIntVector vector = new SmallIntVector("smallIntVector", allocator)) {
+            vector.allocateNew(3);
+
+            short value = 42;
+            vector.set(0, value);
+            vector.setValueCount(1);
+
+            Type smallintType = SmallintType.SMALLINT;
+            BlockBuilder builder = smallintType.createBlockBuilder(null, 1);
+
+            arrowBlockBuilder.assignBlockFromSmallIntVector(vector, smallintType, builder, 0, vector.getValueCount());
+
+            Block block = builder.build();
+            long result = smallintType.getLong(block, 0);
+            assertEquals(result, value);
+        }
+    }
+
+    @Test
+    public void testAssignTinyintType()
+    {
+        try (TinyIntVector vector = new TinyIntVector("tinyIntVector", allocator)) {
+            vector.allocateNew(3);
+
+            byte value = 7;
+            vector.set(0, value);
+            vector.setValueCount(1);
+
+            Type tinyintType = TinyintType.TINYINT;
+            BlockBuilder builder = tinyintType.createBlockBuilder(null, 1);
+
+            arrowBlockBuilder.assignBlockFromTinyIntVector(vector, tinyintType, builder, 0, vector.getValueCount());
+
+            Block block = builder.build();
+            long result = tinyintType.getLong(block, 0);
+            assertEquals(result, value);
+        }
+    }
+
+    @Test
+    public void testAssignBigintType()
+    {
+        try (BigIntVector vector = new BigIntVector("bigIntVector", allocator)) {
+            vector.allocateNew(3);
+
+            long value = 123456789L;
+            vector.set(0, value);
+            vector.setValueCount(1);
+
+            Type bigintType = BigintType.BIGINT;
+            BlockBuilder builder = bigintType.createBlockBuilder(null, 1);
+
+            arrowBlockBuilder.assignBlockFromBigIntVector(vector, bigintType, builder, 0, vector.getValueCount());
+
+            Block block = builder.build();
+            long result = bigintType.getLong(block, 0);
+            assertEquals(result, value);
+        }
+    }
+
+    @Test
+    public void testAssignIntegerType()
+    {
+        try (IntVector vector = new IntVector("IntVector", allocator)) {
+            vector.allocateNew(3);
+
+            int value = 42;
+            vector.set(0, value);
+            vector.setValueCount(1);
+
+            Type integerType = IntegerType.INTEGER;
+            BlockBuilder builder = integerType.createBlockBuilder(null, 1);
+
+            arrowBlockBuilder.assignBlockFromIntVector(vector, integerType, builder, 0, vector.getValueCount());
+
+            Block block = builder.build();
+            long result = integerType.getLong(block, 0);
+            assertEquals(result, value);
+        }
+    }
+
+    @Test
+    public void testAssignDoubleType()
+    {
+        try (Float8Vector vector = new Float8Vector("Float8Vector", allocator)) {
+            vector.allocateNew(3);
+
+            double value = 42.42;
+            vector.set(0, value);
+            vector.setValueCount(1);
+
+            Type doubleType = DoubleType.DOUBLE;
+            BlockBuilder builder = doubleType.createBlockBuilder(null, 1);
+
+            arrowBlockBuilder.assignBlockFromFloat8Vector(vector, doubleType, builder, 0, vector.getValueCount());
+
+            Block block = builder.build();
+            double result = doubleType.getDouble(block, 0);
+            assertEquals(result, value, 0.001);
+        }
+    }
+
+    @Test
+    public void testAssignBooleanType()
+    {
+        try (BitVector vector = new BitVector("BitVector", allocator)) {
+            vector.allocateNew(3);
+
+            boolean value = true;
+            vector.set(0, 1);
+            vector.setValueCount(1);
+
+            Type booleanType = BooleanType.BOOLEAN;
+            BlockBuilder builder = booleanType.createBlockBuilder(null, 1);
+
+            arrowBlockBuilder.assignBlockFromBitVector(vector, booleanType, builder, 0, vector.getValueCount());
+
+            Block block = builder.build();
+            boolean result = booleanType.getBoolean(block, 0);
+            assertEquals(result, value);
+        }
+    }
+
+    @Test
+    public void testAssignArrayType()
+    {
+        try (ListVector vector = ListVector.empty("ListVector", allocator)) {
+            UnionListWriter writer = vector.getWriter();
+            writer.allocate();
+
+            writer.setPosition(0); // optional
+            writer.startList();
+            writer.integer().writeInt(1);
+            writer.integer().writeInt(2);
+            writer.integer().writeInt(3);
+            writer.endList();
+
+            writer.setValueCount(1);
+
+            Type elementType = IntegerType.INTEGER;
+            ArrayType arrayType = new ArrayType(elementType);
+            BlockBuilder builder = arrayType.createBlockBuilder(null, 1);
+
+            arrowBlockBuilder.assignBlockFromListVector(vector, arrayType, builder, 0, vector.getValueCount());
+
+            Block block = builder.build();
+            List<Integer> values = Arrays.asList(1, 2, 3);
+            Block arrayBlock = arrayType.getObject(block, 0);
+            assertEquals(arrayBlock.getPositionCount(), values.size());
+            for (int i = 0; i < values.size(); i++) {
+                assertEquals(elementType.getLong(arrayBlock, i), values.get(i).longValue());
+            }
+        }
+    }
+
+    @Ignore("RowType not implemented")
+    @Test
+    public void testAssignRowType()
+    {
+        RowType.Field field1 = new RowType.Field(Optional.of("field1"), IntegerType.INTEGER);
+        RowType.Field field2 = new RowType.Field(Optional.of("field2"), VarcharType.createUnboundedVarcharType());
+        RowType rowType = RowType.from(Arrays.asList(field1, field2));
+        BlockBuilder builder = rowType.createBlockBuilder(null, 1);
+
+        List<Object> rowValues = Arrays.asList(42, "test");
+        // TODO: arrowBlockBuilder.(rowType, builder, rowValues);
+
+        Block block = builder.build();
+        Block rowBlock = rowType.getObject(block, 0);
+        assertEquals(IntegerType.INTEGER.getLong(rowBlock, 0), 42);
+        assertEquals(VarcharType.createUnboundedVarcharType().getSlice(rowBlock, 1).toStringUtf8(), "test");
+    }
+
+    @Test
+    public void testAssignDateType()
+    {
+        try (DateDayVector vector = new DateDayVector("DateDayVector", allocator)) {
+            vector.allocateNew(3);
+
+            LocalDate value = LocalDate.of(2020, 1, 1);
+            vector.set(0, (int) value.toEpochDay());
+            vector.setValueCount(1);
+
+            Type dateType = DateType.DATE;
+            BlockBuilder builder = dateType.createBlockBuilder(null, 1);
+
+            arrowBlockBuilder.assignBlockFromDateDayVector(vector, dateType, builder, 0, vector.getValueCount());
+
+            Block block = builder.build();
+            long result = dateType.getLong(block, 0);
+            assertEquals(result, value.toEpochDay());
+        }
+    }
+
+    @Test
+    public void testAssignTimestampType()
+    {
+        try (TimeStampMilliVector vector = new TimeStampMilliVector("TimeStampMilliVector", allocator)) {
+            vector.allocateNew(3);
+
+            long value = 1609459200000L; // Jan 1, 2021, 00:00:00 UTC
+            vector.set(0, value);
+            vector.setValueCount(1);
+
+            Type timestampType = TimestampType.TIMESTAMP;
+            BlockBuilder builder = timestampType.createBlockBuilder(null, 1);
+
+            arrowBlockBuilder.assignBlockFromTimeStampMilliVector(vector, timestampType, builder, 0, vector.getValueCount());
+
+            Block block = builder.build();
+            long result = timestampType.getLong(block, 0);
+            assertEquals(result, value);
+        }
+    }
+
+    @Test
+    public void testAssignTimestampTypeWithSqlTimestamp()
+    {
+        try (TimeStampMilliVector vector = new TimeStampMilliVector("TimeStampMilliVector", allocator)) {
+            vector.allocateNew(3);
+
+            java.sql.Timestamp timestamp = java.sql.Timestamp.valueOf("2021-01-01 00:00:00");
+            long expectedMillis = timestamp.getTime();
+            vector.set(0, expectedMillis);
+            vector.setValueCount(1);
+
+            Type timestampType = TimestampType.TIMESTAMP;
+            BlockBuilder builder = timestampType.createBlockBuilder(null, 1);
+
+            arrowBlockBuilder.assignBlockFromTimeStampMilliVector(vector, timestampType, builder, 0, vector.getValueCount());
+
+            Block block = builder.build();
+            long result = timestampType.getLong(block, 0);
+            assertEquals(result, expectedMillis);
+        }
+    }
+
+    @Test
+    public void testAssignShortDecimal()
+    {
+        try (DecimalVector vector = new DecimalVector("DecimalVector", allocator, 10, 2)) {
+            vector.allocateNew(3);
+
+            BigDecimal decimalValue = new BigDecimal("12345.67");
+            vector.set(0, decimalValue);
+            vector.setValueCount(1);
+
+            DecimalType shortDecimalType = DecimalType.createDecimalType(10, 2); // Precision: 10, Scale: 2
+            BlockBuilder builder = shortDecimalType.createBlockBuilder(null, 1);
+
+            arrowBlockBuilder.assignBlockFromDecimalVector(vector, shortDecimalType, builder, 0, vector.getValueCount());
+
+            Block block = builder.build();
+            long unscaledValue = shortDecimalType.getLong(block, 0); // Unscaled value: 1234567
+            BigDecimal result = BigDecimal.valueOf(unscaledValue).movePointLeft(shortDecimalType.getScale());
+            assertEquals(result, decimalValue);
+        }
+    }
+
+    @Test
+    public void testAssignLongDecimal()
+    {
+        try (DecimalVector vector = new DecimalVector("DecimalVector", allocator, 38, 10)) {
+            vector.allocateNew(3);
+
+            BigDecimal decimalValue = new BigDecimal("1234567890.1234567890");
+            vector.set(0, decimalValue);
+            vector.setValueCount(1);
+
+            // Precision 38, scale 10
+            DecimalType longDecimalType = DecimalType.createDecimalType(38, 10);
+            BlockBuilder builder = longDecimalType.createBlockBuilder(null, 1);
+
+            arrowBlockBuilder.assignBlockFromDecimalVector(vector, longDecimalType, builder, 0, vector.getValueCount());
+
+            Block block = builder.build();
+            Slice unscaledSlice = longDecimalType.getSlice(block, 0);
+            BigInteger unscaledValue = Decimals.decodeUnscaledValue(unscaledSlice);
+            BigDecimal result = new BigDecimal(unscaledValue).movePointLeft(longDecimalType.getScale());
+            // The decoded value must equal the value originally written to the vector
+            assertEquals(result, decimalValue);
+        }
+    }
+
+    @Test
+    public void testVarcharVector()
+    {
+        try (VarCharVector vector = new VarCharVector("VarCharVector", allocator)) {
+            vector.allocateNew(3);
+
+            vector.set(0, new Text("apple").getBytes());
+            vector.set(1, new Text("fig").getBytes());
+            vector.setValueCount(2);
+
+            Block resultblock = arrowBlockBuilder.buildBlockFromFieldVector(vector, VarcharType.VARCHAR, null);
+
+            assertEquals(resultblock.getPositionCount(), 2);
+
+            // Every block position must hold the same text as the matching vector slot
+            for (int i = 0; i < vector.getValueCount(); i++) {
+                Slice slice = resultblock.getSlice(i, 0, resultblock.getSliceLength(i));
+                assertEquals(slice.toStringUtf8(), new String(vector.get(i), StandardCharsets.UTF_8));
+            }
+        }
+    }
+
+    /**
+     * Builds a three-entry varchar dictionary ("apple", "banana", "cherry") on the shared
+     * allocator, with dictionary id 1 and signed integer indices of the given width.
+     * <p>
+     * NOTE(review): the dictionary vector is not closed here or by the callers directly;
+     * presumably closing the MapDictionaryProvider that wraps the dictionary releases it - confirm.
+     *
+     * @param indexBitWidth bit width of the index type (8, 16 or 32 in the tests)
+     */
+    private Dictionary createFruitDictionary(int indexBitWidth)
+    {
+        VarCharVector dictionaryVector = new VarCharVector("dictionary", allocator);
+        dictionaryVector.allocateNew(3);
+
+        dictionaryVector.set(0, "apple".getBytes(StandardCharsets.UTF_8));
+        dictionaryVector.set(1, "banana".getBytes(StandardCharsets.UTF_8));
+        dictionaryVector.set(2, "cherry".getBytes(StandardCharsets.UTF_8));
+        dictionaryVector.setValueCount(3);
+
+        return new Dictionary(dictionaryVector, new DictionaryEncoding(1L, false, new ArrowType.Int(indexBitWidth, true)));
+    }
+
+    /**
+     * Describes the non-nullable "indices" field whose type and encoding come from the dictionary.
+     */
+    private static Field createIndexField(Dictionary dictionary)
+    {
+        FieldType indexFieldType = new FieldType(false, dictionary.getEncoding().getIndexType(), dictionary.getEncoding());
+        return new Field("indices", indexFieldType, null);
+    }
+
+    /**
+     * Asserts that {@code block} is a {@link DictionaryBlock} with exactly one position per
+     * expected value, and that each position decodes to the expected string.
+     */
+    private static void assertVarcharDictionaryBlock(Block block, String... expected)
+    {
+        assertNotNull(block);
+        assertTrue(block instanceof DictionaryBlock);
+        DictionaryBlock dictionaryBlock = (DictionaryBlock) block;
+
+        // Guard against a shorter block silently skipping some of the expected values
+        assertEquals(dictionaryBlock.getPositionCount(), expected.length);
+        for (int i = 0; i < expected.length; i++) {
+            Slice slice = dictionaryBlock.getSlice(i, 0, dictionaryBlock.getSliceLength(i));
+            assertEquals(slice.toStringUtf8(), expected[i]);
+        }
+    }
+}
diff --git a/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestArrowColumnHandle.java b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestArrowColumnHandle.java
new file mode 100644
index 0000000000000..1d9c490180abc
--- /dev/null
+++ b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestArrowColumnHandle.java
@@ -0,0 +1,83 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.facebook.plugin.arrow;
+
+import com.facebook.presto.common.type.IntegerType;
+import com.facebook.presto.spi.ColumnMetadata;
+import org.testng.annotations.Test;
+
+import java.util.Locale;
+
+import static com.facebook.presto.testing.assertions.Assert.assertEquals;
+import static org.testng.Assert.assertNotNull;
+
+/**
+ * Tests for {@link ArrowColumnHandle}: construction and getters, null rejection,
+ * the derived {@link ColumnMetadata}, and the string form.
+ */
+public class TestArrowColumnHandle
+{
+    private static final String COLUMN_NAME = "testColumn";
+
+    @Test
+    public void testConstructorAndGetters()
+    {
+        ArrowColumnHandle handle = newIntegerHandle();
+
+        assertEquals(handle.getColumnName(), COLUMN_NAME, "Column name should match the input");
+        assertEquals(handle.getColumnType(), IntegerType.INTEGER, "Column type should match the input");
+    }
+
+    @Test(expectedExceptions = NullPointerException.class)
+    public void testConstructorWithNullColumnName()
+    {
+        // A null name must be rejected by the constructor
+        new ArrowColumnHandle(null, IntegerType.INTEGER);
+    }
+
+    @Test(expectedExceptions = NullPointerException.class)
+    public void testConstructorWithNullColumnType()
+    {
+        // A null type must be rejected by the constructor
+        new ArrowColumnHandle(COLUMN_NAME, null);
+    }
+
+    @Test
+    public void testGetColumnMetadata()
+    {
+        ColumnMetadata metadata = newIntegerHandle().getColumnMetadata();
+
+        assertNotNull(metadata, "ColumnMetadata should not be null");
+        // The metadata name is compared against the lower-cased column name
+        String expectedName = COLUMN_NAME.toLowerCase(Locale.ENGLISH);
+        assertEquals(metadata.getName(), expectedName, "ColumnMetadata name should match the column name");
+        assertEquals(metadata.getType(), IntegerType.INTEGER, "ColumnMetadata type should match the column type");
+    }
+
+    @Test
+    public void testToString()
+    {
+        String actual = newIntegerHandle().toString();
+
+        assertEquals(actual, COLUMN_NAME + ":" + IntegerType.INTEGER, "toString() should return the correct string representation");
+    }
+
+    /**
+     * Creates the INTEGER-typed handle shared by the positive tests.
+     */
+    private static ArrowColumnHandle newIntegerHandle()
+    {
+        return new ArrowColumnHandle(COLUMN_NAME, IntegerType.INTEGER);
+    }
+}
diff --git a/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestArrowFlightEchoQueries.java b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestArrowFlightEchoQueries.java
new file mode 100644
index 0000000000000..8263defd59585
--- /dev/null
+++ b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestArrowFlightEchoQueries.java
@@ -0,0 +1,600 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.facebook.plugin.arrow;
+
+import com.facebook.airlift.json.JsonCodec;
+import com.facebook.airlift.log.Logger;
+import com.facebook.plugin.arrow.testingServer.TestingArrowFlightRequest;
+import com.facebook.plugin.arrow.testingServer.TestingArrowFlightResponse;
+import com.facebook.presto.common.function.OperatorType;
+import com.facebook.presto.common.type.ArrayType;
+import com.facebook.presto.common.type.MapType;
+import com.facebook.presto.common.type.RowType;
+import com.facebook.presto.common.type.Type;
+import com.facebook.presto.testing.MaterializedResult;
+import com.facebook.presto.testing.QueryRunner;
+import com.facebook.presto.tests.AbstractTestQueryFramework;
+import com.facebook.presto.tests.DistributedQueryRunner;
+import com.google.common.collect.ImmutableList;
+import org.apache.arrow.flight.Action;
+import org.apache.arrow.flight.AsyncPutListener;
+import org.apache.arrow.flight.CallOption;
+import org.apache.arrow.flight.CallOptions;
+import org.apache.arrow.flight.CallStatus;
+import org.apache.arrow.flight.FlightClient;
+import org.apache.arrow.flight.FlightDescriptor;
+import org.apache.arrow.flight.FlightEndpoint;
+import org.apache.arrow.flight.FlightInfo;
+import org.apache.arrow.flight.FlightProducer;
+import org.apache.arrow.flight.FlightServer;
+import org.apache.arrow.flight.FlightStream;
+import org.apache.arrow.flight.Location;
+import org.apache.arrow.flight.NoOpFlightProducer;
+import org.apache.arrow.flight.PutResult;
+import org.apache.arrow.flight.Result;
+import org.apache.arrow.flight.Ticket;
+import org.apache.arrow.memory.BufferAllocator;
+import org.apache.arrow.memory.RootAllocator;
+import org.apache.arrow.vector.BigIntVector;
+import org.apache.arrow.vector.FieldVector;
+import org.apache.arrow.vector.IntVector;
+import org.apache.arrow.vector.VarCharVector;
+import org.apache.arrow.vector.VectorSchemaRoot;
+import org.apache.arrow.vector.complex.ListVector;
+import org.apache.arrow.vector.complex.MapVector;
+import org.apache.arrow.vector.complex.StructVector;
+import org.apache.arrow.vector.complex.impl.UnionListWriter;
+import org.apache.arrow.vector.complex.impl.UnionMapWriter;
+import org.apache.arrow.vector.dictionary.Dictionary;
+import org.apache.arrow.vector.dictionary.DictionaryEncoder;
+import org.apache.arrow.vector.dictionary.DictionaryProvider;
+import org.apache.arrow.vector.ipc.ArrowStreamReader;
+import org.apache.arrow.vector.ipc.ArrowStreamWriter;
+import org.apache.arrow.vector.types.Types;
+import org.apache.arrow.vector.types.pojo.ArrowType;
+import org.apache.arrow.vector.types.pojo.DictionaryEncoding;
+import org.apache.arrow.vector.types.pojo.Field;
+import org.apache.arrow.vector.types.pojo.FieldType;
+import org.apache.arrow.vector.types.pojo.Schema;
+import org.apache.arrow.vector.util.ByteArrayReadableSeekableByteChannel;
+import org.apache.arrow.vector.util.DictionaryUtility;
+import org.apache.arrow.vector.util.Text;
+import org.testng.annotations.AfterClass;
+import org.testng.annotations.BeforeClass;
+import org.testng.annotations.Test;
+
+import java.io.ByteArrayInputStream;
+import java.io.ByteArrayOutputStream;
+import java.io.File;
+import java.io.IOException;
+import java.io.InputStream;
+import java.lang.invoke.MethodHandle;
+import java.nio.charset.StandardCharsets;
+import java.nio.file.Files;
+import java.nio.file.Paths;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.Iterator;
+import java.util.List;
+import java.util.Map;
+import java.util.Optional;
+import java.util.Set;
+import java.util.TreeSet;
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.TimeUnit;
+
+import static com.facebook.airlift.json.JsonCodec.jsonCodec;
+import static com.facebook.presto.common.block.MethodHandleUtil.compose;
+import static com.facebook.presto.common.block.MethodHandleUtil.nativeValueGetter;
+import static com.facebook.presto.common.type.BigintType.BIGINT;
+import static com.facebook.presto.common.type.IntegerType.INTEGER;
+import static com.facebook.presto.common.type.VarcharType.VARCHAR;
+import static com.facebook.presto.testing.MaterializedResult.resultBuilder;
+import static com.facebook.presto.testing.TestingEnvironment.getOperatorMethodHandle;
+import static com.facebook.presto.testing.assertions.Assert.assertEquals;
+import static java.lang.String.format;
+import static java.nio.channels.Channels.newChannel;
+
+/**
+ * Round-trip ("echo") tests for the Arrow Flight connector.
+ * <p>
+ * Each test builds Arrow vectors locally, uploads them to an in-process TLS Flight
+ * server ({@link TestingEchoFlightProducer}) under a table name, reads the table back
+ * through Presto with {@code SELECT *}, and compares the result with the values that
+ * were written.
+ */
+public class TestArrowFlightEchoQueries
+        extends AbstractTestQueryFramework
+{
+    private static final Logger logger = Logger.get(TestArrowFlightEchoQueries.class);
+    private static final CallOption CALL_OPTIONS = CallOptions.timeout(300, TimeUnit.SECONDS);
+    private RootAllocator allocator;
+    private FlightServer server;
+    private DistributedQueryRunner arrowFlightQueryRunner;
+    private JsonCodec<TestingArrowFlightRequest> requestCodec;
+    private JsonCodec<TestingArrowFlightResponse> responseCodec;
+
+    /**
+     * Starts the TLS-enabled echo Flight server that the query runner connects to.
+     */
+    @BeforeClass
+    public void setup()
+            throws Exception
+    {
+        arrowFlightQueryRunner = getDistributedQueryRunner();
+        File certChainFile = new File("src/test/resources/server.crt");
+        File privateKeyFile = new File("src/test/resources/server.key");
+
+        allocator = new RootAllocator(Long.MAX_VALUE);
+
+        requestCodec = jsonCodec(TestingArrowFlightRequest.class);
+        responseCodec = jsonCodec(TestingArrowFlightResponse.class);
+
+        server = FlightServer.builder(allocator, getServerLocation(), new TestingEchoFlightProducer(allocator, requestCodec, responseCodec))
+                .useTls(certChainFile, privateKeyFile)
+                .build();
+
+        server.start();
+        logger.info("Server listening on port %s", server.getPort());
+    }
+
+    /**
+     * Shuts down the query runner first, then the server, then the allocator backing it.
+     */
+    @AfterClass(alwaysRun = true)
+    public void close()
+            throws InterruptedException
+    {
+        arrowFlightQueryRunner.close();
+        server.close();
+        allocator.close();
+    }
+
+    public static int getServerPort()
+    {
+        return 9444;
+    }
+
+    public static Location getServerLocation()
+    {
+        return Location.forGrpcTls("localhost", getServerPort());
+    }
+
+    @Override
+    protected QueryRunner createQueryRunner()
+            throws Exception
+    {
+        // Use the same accessor as the server so the two ports cannot drift apart.
+        return ArrowFlightQueryRunner.createQueryRunner(getServerPort());
+    }
+
+    @Test
+    public void testVarCharVector() throws Exception
+    {
+        try (BufferAllocator bufferAllocator = allocator.newChildAllocator("echo-test-client", 0, Long.MAX_VALUE);
+                IntVector intVector = new IntVector("id", bufferAllocator);
+                VarCharVector stringVector = new VarCharVector("c", bufferAllocator);
+                VectorSchemaRoot root = new VectorSchemaRoot(Arrays.asList(intVector, stringVector));
+                FlightClient client = createFlightClient(bufferAllocator)) {
+            MaterializedResult.Builder expectedBuilder = resultBuilder(getSession(), INTEGER, VARCHAR);
+
+            final int numValues = 10;
+            final String stringData = "abcdefghijklmnopqrstuvwxyz";
+            for (int i = 0; i < numValues; i++) {
+                intVector.setSafe(i, i);
+                // Row i carries the first i letters, so row 0 is the empty string.
+                String value = stringData.substring(0, i % stringData.length());
+                stringVector.setSafe(i, new Text(value));
+                expectedBuilder.row(i, value);
+            }
+            root.setRowCount(numValues);
+
+            String tableName = "varchar";
+            addTableToServer(client, root, tableName);
+
+            MaterializedResult actual = computeActual(format("SELECT * FROM %s", tableName));
+
+            assertEquals(actual.getRowCount(), numValues);
+            assertEquals(actual, expectedBuilder.build());
+
+            removeTableFromServer(client, tableName);
+        }
+    }
+
+    @Test
+    public void testListVector() throws Exception
+    {
+        try (BufferAllocator bufferAllocator = allocator.newChildAllocator("echo-test-client", 0, Long.MAX_VALUE);
+                IntVector intVector = new IntVector("id", bufferAllocator);
+                ListVector listVectorInt = ListVector.empty("array-int", bufferAllocator);
+                ListVector listVectorVarchar = ListVector.empty("array-varchar", bufferAllocator)) {
+            // Add the element vectors
+            listVectorInt.addOrGetVector(FieldType.nullable(Types.MinorType.INT.getType()));
+            listVectorVarchar.addOrGetVector(FieldType.nullable(Types.MinorType.VARCHAR.getType()));
+            listVectorInt.allocateNew();
+            listVectorVarchar.allocateNew();
+
+            try (VectorSchemaRoot expectedRoot = new VectorSchemaRoot(Arrays.asList(intVector, listVectorInt, listVectorVarchar));
+                    FlightClient client = createFlightClient(bufferAllocator)) {
+                MaterializedResult.Builder expectedBuilder = resultBuilder(getSession(), INTEGER, new ArrayType(INTEGER), new ArrayType(VARCHAR));
+
+                final int numValues = 10;
+                final String stringData = "abcdefghijklmnopqrstuvwxyz";
+                final UnionListWriter writerInt = listVectorInt.getWriter();
+                final UnionListWriter writerVarchar = listVectorVarchar.getWriter();
+                for (int i = 0; i < numValues; i++) {
+                    intVector.setSafe(i, i);
+
+                    List<Integer> intArray = new ArrayList<>();
+                    List<String> stringArray = new ArrayList<>();
+                    // Position both writers explicitly so the two list columns are written at the same row.
+                    writerInt.setPosition(i);
+                    writerVarchar.setPosition(i);
+                    writerInt.startList();
+                    writerVarchar.startList();
+                    // Lists have 0..3 elements depending on the row index.
+                    for (int j = 0; j < i % 4; j++) {
+                        writerInt.integer().writeInt(i * j);
+                        String stringValue = stringData.substring(0, i % stringData.length());
+                        writerVarchar.writeVarChar(new Text(stringValue));
+                        intArray.add(i * j);
+                        stringArray.add(stringValue);
+                    }
+                    writerInt.endList();
+                    writerVarchar.endList();
+
+                    expectedBuilder.row(i, intArray, stringArray);
+                }
+                expectedRoot.setRowCount(numValues);
+
+                String tableName = "arrays";
+                addTableToServer(client, expectedRoot, tableName);
+
+                MaterializedResult actual = computeActual(format("SELECT * FROM %s", tableName));
+
+                assertEquals(actual.getRowCount(), numValues);
+                assertEquals(actual, expectedBuilder.build());
+
+                removeTableFromServer(client, tableName);
+            }
+        }
+    }
+
+    @Test
+    public void testMapVector() throws Exception
+    {
+        try (BufferAllocator bufferAllocator = allocator.newChildAllocator("echo-test-client", 0, Long.MAX_VALUE);
+                IntVector intVector = new IntVector("id", bufferAllocator);
+                MapVector mapVector = MapVector.empty("map-int-long", bufferAllocator, false)) {
+            UnionMapWriter mapWriter = mapVector.getWriter();
+            mapWriter.allocate();
+
+            MaterializedResult.Builder expectedBuilder = resultBuilder(getSession(), INTEGER, createMapType(INTEGER, BIGINT));
+
+            final int numValues = 10;
+            for (int i = 0; i < numValues; i++) {
+                intVector.setSafe(i, i);
+                mapWriter.setPosition(i);
+                mapWriter.startMap();
+
+                // Row i holds i entries: j -> i * j
+                Map<Integer, Long> expectedMap = new HashMap<>();
+                for (int j = 0; j < i; j++) {
+                    mapWriter.startEntry();
+                    mapWriter.key().integer().writeInt(j);
+                    mapWriter.value().bigInt().writeBigInt(i * j);
+                    mapWriter.endEntry();
+                    expectedMap.put(j, (long) i * j);
+                }
+                mapWriter.endMap();
+                expectedBuilder.row(i, expectedMap);
+            }
+            mapWriter.setValueCount(numValues);
+
+            try (VectorSchemaRoot expectedRoot = new VectorSchemaRoot(Arrays.asList(intVector, mapVector));
+                    FlightClient client = createFlightClient(bufferAllocator)) {
+                expectedRoot.setRowCount(numValues);
+
+                String tableName = "map";
+                addTableToServer(client, expectedRoot, tableName);
+
+                MaterializedResult actual = computeActual(format("SELECT * FROM %s", tableName));
+                assertEquals(actual.getRowCount(), numValues);
+                assertEquals(actual, expectedBuilder.build());
+
+                removeTableFromServer(client, tableName);
+            }
+        }
+    }
+
+    @Test
+    public void testStructVector() throws Exception
+    {
+        try (BufferAllocator bufferAllocator = allocator.newChildAllocator("echo-test-client", 0, Long.MAX_VALUE);
+                IntVector intVector = new IntVector("id", bufferAllocator);
+                StructVector structVector = StructVector.empty("struct", bufferAllocator)) {
+            MaterializedResult.Builder expectedBuilder = resultBuilder(getSession(), INTEGER,
+                    RowType.from(ImmutableList.of(
+                            new RowType.Field(Optional.of("int"), INTEGER),
+                            new RowType.Field(Optional.of("long"), BIGINT))));
+
+            final IntVector childIntVector
+                    = structVector.addOrGet("int", FieldType.nullable(new ArrowType.Int(32, true)), IntVector.class);
+            final BigIntVector childLongVector
+                    = structVector.addOrGet("long", FieldType.nullable(new ArrowType.Int(64, true)), BigIntVector.class);
+            childIntVector.allocateNew();
+            childLongVector.allocateNew();
+
+            final int numValues = 10;
+            for (int i = 0; i < numValues; i++) {
+                intVector.setSafe(i, i);
+                childIntVector.setSafe(i, i + i);
+                childLongVector.setSafe(i, i * i);
+                // Mark the struct slot itself as non-null.
+                structVector.setIndexDefined(i);
+                expectedBuilder.row(i, ImmutableList.of(i + i, (long) i * i));
+            }
+
+            try (VectorSchemaRoot expectedRoot = new VectorSchemaRoot(Arrays.asList(intVector, structVector));
+                    FlightClient client = createFlightClient(bufferAllocator)) {
+                expectedRoot.setRowCount(numValues);
+
+                String tableName = "structs";
+                addTableToServer(client, expectedRoot, tableName);
+
+                MaterializedResult actual = computeActual(format("SELECT * FROM %s", tableName));
+
+                assertEquals(actual.getRowCount(), numValues);
+                assertEquals(actual, expectedBuilder.build());
+
+                removeTableFromServer(client, tableName);
+            }
+        }
+    }
+
+    @Test
+    public void testDictionaryVector() throws Exception
+    {
+        try (BufferAllocator bufferAllocator = allocator.newChildAllocator("echo-test-client", 0, Long.MAX_VALUE);
+                IntVector intVector = new IntVector("id", bufferAllocator);
+                VarCharVector rawVector = new VarCharVector("varchar", bufferAllocator);
+                VarCharVector dictionaryVector = new VarCharVector("dictionary", bufferAllocator)) {
+            intVector.allocateNew();
+            rawVector.allocateNew();
+            dictionaryVector.allocateNew(3); // allocating 3 elements in dictionary
+
+            // Fill dictionaryVector with some values (explicit charset: do not depend on the platform default)
+            dictionaryVector.set(0, "apple".getBytes(StandardCharsets.UTF_8));
+            dictionaryVector.set(1, "banana".getBytes(StandardCharsets.UTF_8));
+            dictionaryVector.set(2, "cherry".getBytes(StandardCharsets.UTF_8));
+            dictionaryVector.setValueCount(3);
+
+            MaterializedResult.Builder expectedBuilder = resultBuilder(getSession(), INTEGER, VARCHAR);
+
+            final int numValues = 10;
+            for (int i = 0; i < numValues; i++) {
+                intVector.setSafe(i, i);
+                // Cycle through the dictionary entries in reverse order.
+                Text rawValue = dictionaryVector.getObject((numValues - i) % dictionaryVector.getValueCount());
+                rawVector.setSafe(i, rawValue);
+                expectedBuilder.row(i, rawValue.toString());
+            }
+            rawVector.setValueCount(numValues);
+
+            Dictionary dictionary = new Dictionary(dictionaryVector, new DictionaryEncoding(1L, false, null));
+
+            try (FieldVector encodedVector = (FieldVector) DictionaryEncoder.encode(rawVector, dictionary);
+                    VectorSchemaRoot root = new VectorSchemaRoot(Arrays.asList(intVector, encodedVector));
+                    DictionaryProvider.MapDictionaryProvider dictionaryProvider = new DictionaryProvider.MapDictionaryProvider(dictionary);
+                    FlightClient client = createFlightClient(bufferAllocator)) {
+                root.setRowCount(numValues);
+
+                String tableName = "dictionary";
+                addTableToServer(client, root, tableName, dictionaryProvider);
+
+                MaterializedResult actual = computeActual(format("SELECT * FROM %s", tableName));
+
+                assertEquals(actual.getRowCount(), numValues);
+                assertEquals(actual, expectedBuilder.build());
+
+                removeTableFromServer(client, tableName);
+            }
+        }
+    }
+
+    /**
+     * Builds a {@link MapType} for the given key/value types, wiring in the key
+     * equality and hash-code method handles that the constructor requires.
+     */
+    private static MapType createMapType(Type keyType, Type valueType)
+    {
+        MethodHandle keyNativeEquals = getOperatorMethodHandle(OperatorType.EQUAL, keyType, keyType);
+        MethodHandle keyBlockEquals = compose(keyNativeEquals, nativeValueGetter(keyType), nativeValueGetter(keyType));
+        MethodHandle keyNativeHashCode = getOperatorMethodHandle(OperatorType.HASH_CODE, keyType);
+        MethodHandle keyBlockHashCode = compose(keyNativeHashCode, nativeValueGetter(keyType));
+
+        return new MapType(
+                keyType,
+                valueType,
+                keyBlockEquals,
+                keyBlockHashCode);
+    }
+
+    /**
+     * Creates a TLS Flight client for the test server that trusts the test certificate.
+     */
+    private static FlightClient createFlightClient(BufferAllocator allocator) throws IOException
+    {
+        InputStream trustedCertificate = new ByteArrayInputStream(Files.readAllBytes(Paths.get("src/test/resources/server.crt")));
+        return FlightClient.builder(allocator, getServerLocation()).verifyServer(true).useTls().trustedCertificates(trustedCertificate).build();
+    }
+
+    private void addTableToServer(FlightClient client, VectorSchemaRoot root, String tableName)
+    {
+        addTableToServer(client, root, tableName, null);
+    }
+
+    /**
+     * Uploads {@code root} as a single batch to the echo server under {@code tableName}.
+     *
+     * @param dictionaryProvider dictionaries referenced by {@code root}, or {@code null} if there are none
+     */
+    private void addTableToServer(FlightClient client, VectorSchemaRoot root, String tableName, DictionaryProvider dictionaryProvider)
+    {
+        TestingArrowFlightRequest putRequest = new TestingArrowFlightRequest(Optional.empty(), Optional.of(tableName), Optional.empty());
+        final FlightClient.ClientStreamListener stream;
+
+        if (dictionaryProvider == null) {
+            stream = client.startPut(FlightDescriptor.command(requestCodec.toJsonBytes(putRequest)),
+                    root, new AsyncPutListener(), CALL_OPTIONS);
+        }
+        else {
+            stream = client.startPut(FlightDescriptor.command(requestCodec.toJsonBytes(putRequest)),
+                    root, dictionaryProvider, new AsyncPutListener(), CALL_OPTIONS);
+        }
+        stream.putNext();
+        stream.completed();
+        // Wait for the server to finish processing the upload before the table is queried.
+        stream.getResult();
+    }
+
+    /**
+     * Asks the echo server to drop {@code tableName} via the "drop" action.
+     */
+    private void removeTableFromServer(FlightClient client, String tableName)
+    {
+        TestingArrowFlightRequest dropRequest = new TestingArrowFlightRequest(Optional.empty(), Optional.of(tableName), Optional.empty());
+        Iterator<Result> iterator = client.doAction(new Action("drop", requestCodec.toJsonBytes(dropRequest)), CALL_OPTIONS);
+        // The result is intentionally ignored; presumably hasNext() is called only to drive the action to completion.
+        iterator.hasNext();
+    }
+
+    /**
+     * In-memory Flight producer: stores each uploaded table as serialized Arrow stream
+     * bytes keyed by table name, and serves those bytes back on request.
+     */
+    private static class TestingEchoFlightProducer
+            extends NoOpFlightProducer
+    {
+        private final BufferAllocator allocator;
+        private final Map<String, byte[]> tableMap = new ConcurrentHashMap<>();
+        private final JsonCodec<TestingArrowFlightRequest> requestCodec;
+        private final JsonCodec<TestingArrowFlightResponse> responseCodec;
+
+        public TestingEchoFlightProducer(BufferAllocator allocator, JsonCodec<TestingArrowFlightRequest> requestCodec, JsonCodec<TestingArrowFlightResponse> responseCodec)
+        {
+            this.allocator = allocator;
+            this.requestCodec = requestCodec;
+            this.responseCodec = responseCodec;
+        }
+
+        /**
+         * Serializes every uploaded batch (and its dictionaries) to the Arrow stream format
+         * and stores the bytes under the table name carried in the descriptor command.
+         */
+        @Override
+        public Runnable acceptPut(FlightProducer.CallContext context, FlightStream flightStream, FlightProducer.StreamListener<PutResult> ackStream)
+        {
+            return () -> {
+                TestingArrowFlightRequest request = requestCodec.fromJson(flightStream.getDescriptor().getCommand());
+                if (!request.getTable().isPresent()) {
+                    throw new IllegalArgumentException("Table name must be specified");
+                }
+
+                final ByteArrayOutputStream outputStream = new ByteArrayOutputStream();
+                try (ArrowStreamWriter writer = new ArrowStreamWriter(flightStream.getRoot(), flightStream.getDictionaryProvider(), newChannel(outputStream))) {
+                    while (flightStream.next()) {
+                        writer.writeBatch();
+                    }
+                }
+                catch (IOException e) {
+                    throw new RuntimeException("Error receiving table batches", e);
+                }
+
+                tableMap.put(request.getTable().get(), outputStream.toByteArray());
+            };
+        }
+
+        /**
+         * Handles the "discovery" action (list schemas or tables) and the "drop" action
+         * (remove a stored table). Exactly one terminal signal, {@code onError} or
+         * {@code onCompleted}, is sent per call.
+         */
+        @Override
+        public void doAction(CallContext context, Action action, StreamListener<Result> listener)
+        {
+            try {
+                TestingArrowFlightRequest request = requestCodec.fromJson(action.getBody());
+
+                if ("discovery".equals(action.getType())) {
+                    TestingArrowFlightResponse response;
+                    if (!request.getSchema().isPresent()) {
+                        // Return the list of schemas
+                        response = new TestingArrowFlightResponse(ImmutableList.of("tpch"), ImmutableList.of());
+                    }
+                    else {
+                        // Return the list of tables
+                        response = new TestingArrowFlightResponse(ImmutableList.of(), new ArrayList<>(tableMap.keySet()));
+                    }
+
+                    listener.onNext(new Result(responseCodec.toJsonBytes(response)));
+                    listener.onCompleted();
+                }
+                else if ("drop".equals(action.getType())) {
+                    if (!request.getTable().isPresent() || null == tableMap.remove(request.getTable().get())) {
+                        listener.onError(CallStatus.INVALID_ARGUMENT.withDescription("Table not found: " + request.getTable()).toRuntimeException());
+                        // The stream is already terminated by onError; do not also signal completion.
+                        return;
+                    }
+                    listener.onCompleted();
+                }
+                else {
+                    listener.onError(CallStatus.INVALID_ARGUMENT.withDescription("Invalid action: " + action.getType() + ", request: " + request.toString()).toRuntimeException());
+                }
+            }
+            catch (Exception e) {
+                listener.onError(e);
+            }
+        }
+
+        /**
+         * Returns a single-endpoint {@link FlightInfo} for a stored table; the ticket is the
+         * UTF-8 encoded table name, which {@link #getStream} decodes again.
+         */
+        @Override
+        public FlightInfo getFlightInfo(CallContext callContext, FlightDescriptor flightDescriptor)
+        {
+            TestingArrowFlightRequest request = requestCodec.fromJson(flightDescriptor.getCommand());
+
+            if (!request.getTable().isPresent()) {
+                throw new IllegalArgumentException("Table name must be specified");
+            }
+
+            if (!tableMap.containsKey(request.getTable().get())) {
+                throw new IllegalArgumentException("Unknown table requested");
+            }
+
+            byte[] arrowFileBytes = tableMap.get(request.getTable().get());
+
+            Schema schema;
+            try (ArrowStreamReader reader = new ArrowStreamReader(new ByteArrayReadableSeekableByteChannel(arrowFileBytes), allocator)) {
+                schema = generateSchema(reader.getVectorSchemaRoot().getSchema(), reader, new TreeSet<>());
+            }
+            catch (IOException e) {
+                throw new RuntimeException("Error deserializing Arrow file", e);
+            }
+
+            FlightEndpoint endpoint = new FlightEndpoint(new Ticket(request.getTable().get().getBytes(StandardCharsets.UTF_8)));
+            // Record and byte counts are reported as -1.
+            return new FlightInfo(schema, flightDescriptor, Collections.singletonList(endpoint), -1, -1);
+        }
+
+        /**
+         * Replays the stored Arrow stream bytes of the requested table to the caller, batch by batch.
+         */
+        @Override
+        public void getStream(CallContext callContext, Ticket ticket, ServerStreamListener serverStreamListener)
+        {
+            String tableName = new String(ticket.getBytes(), StandardCharsets.UTF_8);
+
+            if (!tableMap.containsKey(tableName)) {
+                throw new IllegalArgumentException("Unknown table requested");
+            }
+
+            byte[] arrowFileBytes = tableMap.get(tableName);
+
+            try (ArrowStreamReader reader = new ArrowStreamReader(new ByteArrayReadableSeekableByteChannel(arrowFileBytes), allocator)) {
+                boolean started = false;
+                // NOTE: need to read first batch to initialize dictionaries
+                while (reader.loadNextBatch()) {
+                    if (!started) {
+                        serverStreamListener.start(reader.getVectorSchemaRoot(), reader);
+                        started = true;
+                    }
+                    serverStreamListener.putNext();
+                }
+                serverStreamListener.completed();
+            }
+            catch (IOException e) {
+                throw new RuntimeException("Error deserializing Arrow file", e);
+            }
+        }
+
+        /**
+         * From org.apache.arrow.flight.DictionaryUtils which is package private
+         * <p>
+         * Returns {@code originalSchema} unchanged when no field needs conversion to the
+         * message format; otherwise returns a new schema with every field converted.
+         */
+        static Schema generateSchema(
+                final Schema originalSchema, final DictionaryProvider provider, Set<Long> dictionaryIds)
+        {
+            // first determine if a new schema needs to be created.
+            boolean createSchema = false;
+            for (Field field : originalSchema.getFields()) {
+                if (DictionaryUtility.needConvertToMessageFormat(field)) {
+                    createSchema = true;
+                    break;
+                }
+            }
+
+            if (!createSchema) {
+                return originalSchema;
+            }
+            else {
+                final List<Field> fields = new ArrayList<>(originalSchema.getFields().size());
+                for (final Field field : originalSchema.getFields()) {
+                    fields.add(DictionaryUtility.toMessageFormat(field, provider, dictionaryIds));
+                }
+                return new Schema(fields, originalSchema.getCustomMetadata());
+            }
+        }
+    }
+}
diff --git a/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestArrowFlightIntegrationSmokeTest.java b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestArrowFlightIntegrationSmokeTest.java
new file mode 100644
index 0000000000000..d0983c2f5213d
--- /dev/null
+++ b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestArrowFlightIntegrationSmokeTest.java
@@ -0,0 +1,72 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package com.facebook.plugin.arrow;
+
+import com.facebook.airlift.log.Logger;
+import com.facebook.plugin.arrow.testingServer.TestingArrowProducer;
+import com.facebook.presto.testing.QueryRunner;
+import com.facebook.presto.tests.AbstractTestIntegrationSmokeTest;
+import com.facebook.presto.tests.DistributedQueryRunner;
+import org.apache.arrow.flight.FlightServer;
+import org.apache.arrow.flight.Location;
+import org.apache.arrow.memory.RootAllocator;
+import org.testng.annotations.AfterClass;
+import org.testng.annotations.BeforeClass;
+
+import java.io.File;
+
+/**
+ * Runs the inherited integration smoke tests against a TLS-enabled Arrow Flight
+ * server backed by {@link TestingArrowProducer}.
+ */
+public class TestArrowFlightIntegrationSmokeTest
+        extends AbstractTestIntegrationSmokeTest
+{
+    private static final Logger logger = Logger.get(TestArrowFlightIntegrationSmokeTest.class);
+    // Single source of truth for the port shared by the Flight server and the query runner.
+    private static final int SERVER_PORT = 9442;
+    private RootAllocator allocator;
+    private FlightServer server;
+    private Location serverLocation;
+    private DistributedQueryRunner arrowFlightQueryRunner;
+
+    /**
+     * Starts the Flight server on {@code SERVER_PORT} using the test certificate and key.
+     */
+    @BeforeClass
+    public void setup()
+            throws Exception
+    {
+        arrowFlightQueryRunner = getDistributedQueryRunner();
+        File certChainFile = new File("src/test/resources/server.crt");
+        File privateKeyFile = new File("src/test/resources/server.key");
+
+        allocator = new RootAllocator(Long.MAX_VALUE);
+        serverLocation = Location.forGrpcTls("127.0.0.1", SERVER_PORT);
+        server = FlightServer.builder(allocator, serverLocation, new TestingArrowProducer(allocator))
+                .useTls(certChainFile, privateKeyFile)
+                .build();
+
+        server.start();
+        logger.info("Server listening on port %s", server.getPort());
+    }
+
+    @Override
+    protected QueryRunner createQueryRunner()
+            throws Exception
+    {
+        return ArrowFlightQueryRunner.createQueryRunner(SERVER_PORT);
+    }
+
+    /**
+     * Tears down in the same order as the sibling Arrow Flight test classes:
+     * the query runner first, then the server, then the allocator backing the server.
+     */
+    @AfterClass(alwaysRun = true)
+    public void close()
+            throws InterruptedException
+    {
+        arrowFlightQueryRunner.close();
+        server.close();
+        allocator.close();
+    }
+}
diff --git a/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestArrowFlightQueries.java
b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestArrowFlightQueries.java new file mode 100644 index 0000000000000..7bc8e9f050f05 --- /dev/null +++ b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestArrowFlightQueries.java @@ -0,0 +1,176 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.plugin.arrow; + +import com.facebook.airlift.log.Logger; +import com.facebook.plugin.arrow.testingServer.TestingArrowProducer; +import com.facebook.presto.Session; +import com.facebook.presto.common.type.TimeZoneKey; +import com.facebook.presto.testing.MaterializedResult; +import com.facebook.presto.testing.QueryRunner; +import com.facebook.presto.tests.AbstractTestQueries; +import com.facebook.presto.tests.DistributedQueryRunner; +import org.apache.arrow.flight.FlightServer; +import org.apache.arrow.flight.Location; +import org.apache.arrow.memory.RootAllocator; +import org.testng.annotations.AfterClass; +import org.testng.annotations.BeforeClass; +import org.testng.annotations.Test; + +import java.io.File; +import java.time.LocalDate; +import java.time.LocalDateTime; +import java.time.LocalTime; +import java.time.ZoneId; +import java.time.ZonedDateTime; +import java.time.format.DateTimeFormatter; + +import static com.facebook.presto.common.type.CharType.createCharType; +import static com.facebook.presto.common.type.DateType.DATE; +import static com.facebook.presto.common.type.IntegerType.INTEGER; +import 
static com.facebook.presto.common.type.TimeType.TIME; +import static com.facebook.presto.common.type.TimestampType.TIMESTAMP; +import static com.facebook.presto.common.type.VarcharType.VARCHAR; +import static com.facebook.presto.common.type.VarcharType.createVarcharType; +import static com.facebook.presto.testing.MaterializedResult.resultBuilder; +import static java.lang.String.format; +import static org.testng.Assert.assertTrue; + +public class TestArrowFlightQueries + extends AbstractTestQueries +{ + private static final Logger logger = Logger.get(TestArrowFlightQueries.class); + private RootAllocator allocator; + private FlightServer server; + private Location serverLocation; + private DistributedQueryRunner arrowFlightQueryRunner; + + @BeforeClass + public void setup() + throws Exception + { + arrowFlightQueryRunner = getDistributedQueryRunner(); + File certChainFile = new File("src/test/resources/server.crt"); + File privateKeyFile = new File("src/test/resources/server.key"); + + allocator = new RootAllocator(Long.MAX_VALUE); + serverLocation = Location.forGrpcTls("localhost", 9443); + server = FlightServer.builder(allocator, serverLocation, new TestingArrowProducer(allocator)) + .useTls(certChainFile, privateKeyFile) + .build(); + + server.start(); + logger.info("Server listening on port %s", server.getPort()); + } + + @AfterClass(alwaysRun = true) + public void close() + throws InterruptedException + { + arrowFlightQueryRunner.close(); + server.close(); + allocator.close(); + } + + @Override + protected QueryRunner createQueryRunner() + throws Exception + { + return ArrowFlightQueryRunner.createQueryRunner(9443); + } + + @Test + public void testShowCharColumns() + { + MaterializedResult actual = computeActual("SHOW COLUMNS FROM member"); + + MaterializedResult expectedUnparametrizedVarchar = resultBuilder(getSession(), VARCHAR, VARCHAR, VARCHAR, VARCHAR) + .row("id", "integer", "", "") + .row("name", "varchar", "", "") + .row("sex", "char", "", "") + 
.row("state", "char", "", "") + .build(); + + MaterializedResult expectedParametrizedVarchar = resultBuilder(getSession(), VARCHAR, VARCHAR, VARCHAR, VARCHAR) + .row("id", "integer", "", "") + .row("name", "varchar(50)", "", "") + .row("sex", "char(1)", "", "") + .row("state", "char(5)", "", "") + .build(); + + assertTrue(actual.equals(expectedParametrizedVarchar) || actual.equals(expectedUnparametrizedVarchar), + format("%s matches neither %s nor %s", actual, expectedParametrizedVarchar, expectedUnparametrizedVarchar)); + } + + @Test + public void testPredicateOnCharColumn() + { + MaterializedResult actualRow = computeActual("SELECT * from member WHERE state = 'CD'"); + MaterializedResult expectedRow = resultBuilder(getSession(), INTEGER, createVarcharType(50), createCharType(1), createCharType(5)) + .row(2, "MARY", "F", "CD ") + .build(); + assertTrue(actualRow.equals(expectedRow)); + } + + @Test + public void testSelectTime() + { + MaterializedResult actualRow = computeActual("SELECT * from event WHERE id = 1"); + Session session = getSession(); + MaterializedResult expectedRow = resultBuilder(session, INTEGER, DATE, TIME, TIMESTAMP) + .row(1, + getDate("2004-12-31"), + getTimeAtZone("23:59:59", session.getTimeZoneKey()), + getDateTimeAtZone("2005-12-31 23:59:59", session.getTimeZoneKey())) + .build(); + assertTrue(actualRow.equals(expectedRow)); + } + + private LocalDate getDate(String dateString) + { + DateTimeFormatter formatter = DateTimeFormatter.ofPattern("yyyy-MM-dd"); + LocalDate localDate = LocalDate.parse(dateString, formatter); + + return localDate; + } + + private LocalTime getTimeAtZone(String timeString, TimeZoneKey timeZoneKey) + { + DateTimeFormatter formatter = DateTimeFormatter.ofPattern("HH:mm:ss"); + LocalTime localTime = LocalTime.parse(timeString, formatter); + + LocalDateTime localDateTime = LocalDateTime.of(LocalDate.of(1970, 1, 1), localTime); + ZonedDateTime localZonedDateTime = localDateTime.atZone(ZoneId.systemDefault()); + + ZoneId 
zoneId = ZoneId.of(timeZoneKey.getId()); + ZonedDateTime zonedDateTime = localZonedDateTime.withZoneSameInstant(zoneId); + + LocalTime localTimeAtZone = zonedDateTime.toLocalTime(); + return localTimeAtZone; + } + + private LocalDateTime getDateTimeAtZone(String dateTimeString, TimeZoneKey timeZoneKey) + { + DateTimeFormatter formatter = DateTimeFormatter.ofPattern("yyyy-MM-dd HH:mm:ss"); + LocalDateTime localDateTime = LocalDateTime.parse(dateTimeString, formatter); + + ZonedDateTime localZonedDateTime = localDateTime.atZone(ZoneId.systemDefault()); + + ZoneId zoneId = ZoneId.of(timeZoneKey.getId()); + ZonedDateTime zonedDateTime = localZonedDateTime.withZoneSameInstant(zoneId); + + LocalDateTime localDateTimeAtZone = zonedDateTime.toLocalDateTime(); + return localDateTimeAtZone; + } +} diff --git a/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestArrowHandleResolver.java b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestArrowHandleResolver.java new file mode 100644 index 0000000000000..ea95e9fec01b0 --- /dev/null +++ b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestArrowHandleResolver.java @@ -0,0 +1,67 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.plugin.arrow; + +import org.testng.annotations.Test; + +import static org.testng.Assert.assertEquals; + +public class TestArrowHandleResolver +{ + @Test + public void testGetTableHandleClass() + { + ArrowHandleResolver resolver = new ArrowHandleResolver(); + assertEquals( + resolver.getTableHandleClass(), + ArrowTableHandle.class, + "getTableHandleClass should return ArrowTableHandle class."); + } + @Test + public void testGetTableLayoutHandleClass() + { + ArrowHandleResolver resolver = new ArrowHandleResolver(); + assertEquals( + resolver.getTableLayoutHandleClass(), + ArrowTableLayoutHandle.class, + "getTableLayoutHandleClass should return ArrowTableLayoutHandle class."); + } + @Test + public void testGetColumnHandleClass() + { + ArrowHandleResolver resolver = new ArrowHandleResolver(); + assertEquals( + resolver.getColumnHandleClass(), + ArrowColumnHandle.class, + "getColumnHandleClass should return ArrowColumnHandle class."); + } + @Test + public void testGetSplitClass() + { + ArrowHandleResolver resolver = new ArrowHandleResolver(); + assertEquals( + resolver.getSplitClass(), + ArrowSplit.class, + "getSplitClass should return ArrowSplit class."); + } + @Test + public void testGetTransactionHandleClass() + { + ArrowHandleResolver resolver = new ArrowHandleResolver(); + assertEquals( + resolver.getTransactionHandleClass(), + ArrowTransactionHandle.class, + "getTransactionHandleClass should return ArrowTransactionHandle class."); + } +} diff --git a/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestArrowPlugin.java b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestArrowPlugin.java new file mode 100644 index 0000000000000..3ea8fb68d9236 --- /dev/null +++ b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestArrowPlugin.java @@ -0,0 +1,33 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.facebook.plugin.arrow; + +import com.facebook.plugin.arrow.testingConnector.TestingArrowFlightPlugin; +import com.facebook.presto.spi.connector.ConnectorFactory; +import org.testng.annotations.Test; + +import static com.facebook.airlift.testing.Assertions.assertInstanceOf; +import static com.google.common.collect.Iterables.getOnlyElement; + +public class TestArrowPlugin +{ + @Test + public void testStartup() + { + ArrowPlugin plugin = new TestingArrowFlightPlugin(); + ConnectorFactory factory = getOnlyElement(plugin.getConnectorFactories()); + assertInstanceOf(factory, ArrowConnectorFactory.class); + } +} diff --git a/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestArrowSplit.java b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestArrowSplit.java new file mode 100644 index 0000000000000..1bc77989d444c --- /dev/null +++ b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestArrowSplit.java @@ -0,0 +1,78 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.plugin.arrow; + +import com.facebook.presto.spi.HostAddress; +import com.facebook.presto.spi.schedule.NodeSelectionStrategy; +import org.apache.arrow.flight.FlightEndpoint; +import org.apache.arrow.flight.Location; +import org.apache.arrow.flight.Ticket; +import org.testng.annotations.BeforeMethod; +import org.testng.annotations.Test; + +import java.net.URISyntaxException; +import java.nio.ByteBuffer; +import java.util.List; + +import static org.testng.Assert.assertEquals; +import static org.testng.Assert.assertNotNull; +import static org.testng.Assert.assertTrue; + +@Test(singleThreaded = true) +public class TestArrowSplit +{ + private ArrowSplit arrowSplit; + private String schemaName; + private String tableName; + private FlightEndpoint flightEndpoint; + + @BeforeMethod + public void setUp() + throws URISyntaxException + { + schemaName = "testSchema"; + tableName = "testTable"; + byte[] ticketArray = new byte[] {1, 2, 3, 4}; + Ticket ticket = new Ticket(ByteBuffer.wrap(ticketArray).array()); // Wrap the byte array in a Ticket + Location location = new Location("http://localhost:8080"); + flightEndpoint = new FlightEndpoint(ticket, location); + // Instantiate ArrowSplit with mock data + arrowSplit = new ArrowSplit(schemaName, tableName, flightEndpoint.serialize().array()); + } + + @Test + public void testConstructorAndGetters() + { + // Test that the constructor correctly initializes fields + assertEquals(arrowSplit.getSchemaName(), schemaName, "Schema name should match."); + assertEquals(arrowSplit.getTableName(), tableName, "Table name should match."); + assertEquals(arrowSplit.getFlightEndpointBytes(), flightEndpoint.serialize().array(), "Byte array should match"); + } + + @Test + public void testNodeSelectionStrategy() + { + // Test that the node selection strategy is NO_PREFERENCE + assertEquals(arrowSplit.getNodeSelectionStrategy(), NodeSelectionStrategy.NO_PREFERENCE, "Node selection strategy should be NO_PREFERENCE."); + } + + 
@Test + public void testGetPreferredNodes() + { + // Test that the preferred nodes list is empty + List preferredNodes = arrowSplit.getPreferredNodes(null); + assertNotNull(preferredNodes, "Preferred nodes list should not be null."); + assertTrue(preferredNodes.isEmpty(), "Preferred nodes list should be empty."); + } +} diff --git a/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestArrowTableHandle.java b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestArrowTableHandle.java new file mode 100644 index 0000000000000..2061fe5036534 --- /dev/null +++ b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestArrowTableHandle.java @@ -0,0 +1,37 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.plugin.arrow; + +import com.facebook.airlift.testing.EquivalenceTester; +import org.testng.annotations.Test; + +import static com.facebook.plugin.arrow.ArrowMetadataUtil.TABLE_CODEC; +import static com.facebook.plugin.arrow.ArrowMetadataUtil.assertJsonRoundTrip; + +public class TestArrowTableHandle +{ + @Test + public void testJsonRoundTrip() + { + assertJsonRoundTrip(TABLE_CODEC, new ArrowTableHandle("schema", "table")); + } + + @Test + public void testEquivalence() + { + EquivalenceTester.equivalenceTester() + .addEquivalentGroup( + new ArrowTableHandle("tm_engine", "employees")).check(); + } +} diff --git a/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestArrowTableLayoutHandle.java b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestArrowTableLayoutHandle.java new file mode 100644 index 0000000000000..0cdf7eb14e67c --- /dev/null +++ b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/TestArrowTableLayoutHandle.java @@ -0,0 +1,116 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.facebook.plugin.arrow; + +import com.facebook.presto.common.predicate.TupleDomain; +import com.facebook.presto.common.type.BigintType; +import com.facebook.presto.common.type.IntegerType; +import com.facebook.presto.common.type.VarcharType; +import com.facebook.presto.spi.ColumnHandle; +import org.testng.annotations.Test; + +import java.util.Arrays; +import java.util.Collections; +import java.util.List; + +import static org.testng.Assert.assertEquals; +import static org.testng.Assert.assertNotEquals; + +public class TestArrowTableLayoutHandle +{ + @Test + public void testConstructorAndGetters() + { + ArrowTableHandle tableHandle = new ArrowTableHandle("schema", "table"); + List columnHandles = Arrays.asList( + new ArrowColumnHandle("column1", IntegerType.INTEGER), + new ArrowColumnHandle("column2", VarcharType.VARCHAR)); + TupleDomain tupleDomain = TupleDomain.all(); + + ArrowTableLayoutHandle layoutHandle = new ArrowTableLayoutHandle(tableHandle, columnHandles, tupleDomain); + + assertEquals(layoutHandle.getTable(), tableHandle, "Table handle mismatch."); + assertEquals(layoutHandle.getColumnHandles(), columnHandles, "Column handles mismatch."); + assertEquals(layoutHandle.getTupleDomain(), tupleDomain, "Tuple domain mismatch."); + } + + @Test + public void testToString() + { + ArrowTableHandle tableHandle = new ArrowTableHandle("schema", "table"); + List columnHandles = Arrays.asList( + new ArrowColumnHandle("column1", IntegerType.INTEGER), + new ArrowColumnHandle("column2", BigintType.BIGINT)); + TupleDomain tupleDomain = TupleDomain.all(); + + ArrowTableLayoutHandle layoutHandle = new ArrowTableLayoutHandle(tableHandle, columnHandles, tupleDomain); + + String expectedString = "table:" + tableHandle + ", columnHandles:" + columnHandles + ", tupleDomain:" + tupleDomain; + assertEquals(layoutHandle.toString(), expectedString, "toString output mismatch."); + } + + @Test + public void testEqualsAndHashCode() + { + ArrowTableHandle tableHandle1 = 
new ArrowTableHandle("schema", "table"); + ArrowTableHandle tableHandle2 = new ArrowTableHandle("schema", "different_table"); + + List columnHandles1 = Arrays.asList( + new ArrowColumnHandle("column1", IntegerType.INTEGER), + new ArrowColumnHandle("column2", VarcharType.VARCHAR)); + List columnHandles2 = Collections.singletonList( + new ArrowColumnHandle("column1", IntegerType.INTEGER)); + + TupleDomain tupleDomain1 = TupleDomain.all(); + TupleDomain tupleDomain2 = TupleDomain.none(); + + ArrowTableLayoutHandle layoutHandle1 = new ArrowTableLayoutHandle(tableHandle1, columnHandles1, tupleDomain1); + ArrowTableLayoutHandle layoutHandle2 = new ArrowTableLayoutHandle(tableHandle1, columnHandles1, tupleDomain1); + ArrowTableLayoutHandle layoutHandle3 = new ArrowTableLayoutHandle(tableHandle2, columnHandles1, tupleDomain1); + ArrowTableLayoutHandle layoutHandle4 = new ArrowTableLayoutHandle(tableHandle1, columnHandles2, tupleDomain1); + ArrowTableLayoutHandle layoutHandle5 = new ArrowTableLayoutHandle(tableHandle1, columnHandles1, tupleDomain2); + + // Test equality + assertEquals(layoutHandle1, layoutHandle2, "Handles with same attributes should be equal."); + assertNotEquals(layoutHandle1, layoutHandle3, "Handles with different tableHandles should not be equal."); + assertNotEquals(layoutHandle1, layoutHandle4, "Handles with different columnHandles should not be equal."); + assertNotEquals(layoutHandle1, layoutHandle5, "Handles with different tupleDomains should not be equal."); + assertNotEquals(layoutHandle1, null, "Handle should not be equal to null."); + assertNotEquals(layoutHandle1, new Object(), "Handle should not be equal to an object of another class."); + + // Test hash codes + assertEquals(layoutHandle1.hashCode(), layoutHandle2.hashCode(), "Equal handles should have same hash code."); + assertNotEquals(layoutHandle1.hashCode(), layoutHandle3.hashCode(), "Handles with different tableHandles should have different hash codes."); + 
assertNotEquals(layoutHandle1.hashCode(), layoutHandle4.hashCode(), "Handles with different columnHandles should have different hash codes."); + } + + @Test(expectedExceptions = NullPointerException.class, expectedExceptionsMessageRegExp = "table is null") + public void testConstructorNullTableHandle() + { + new ArrowTableLayoutHandle(null, Collections.emptyList(), TupleDomain.all()); + } + + @Test(expectedExceptions = NullPointerException.class, expectedExceptionsMessageRegExp = "columnHandles is null") + public void testConstructorNullColumnHandles() + { + new ArrowTableLayoutHandle(new ArrowTableHandle("schema", "table"), null, TupleDomain.all()); + } + + @Test(expectedExceptions = NullPointerException.class, expectedExceptionsMessageRegExp = "tupleDomain is null") + public void testConstructorNullTupleDomain() + { + new ArrowTableLayoutHandle(new ArrowTableHandle("schema", "table"), Collections.emptyList(), null); + } +} diff --git a/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/testingConnector/TestingArrowBlockBuilder.java b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/testingConnector/TestingArrowBlockBuilder.java new file mode 100644 index 0000000000000..a8752e0690f5e --- /dev/null +++ b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/testingConnector/TestingArrowBlockBuilder.java @@ -0,0 +1,57 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.plugin.arrow.testingConnector; + +import com.facebook.plugin.arrow.ArrowBlockBuilder; +import com.facebook.presto.common.type.CharType; +import com.facebook.presto.common.type.TimeType; +import com.facebook.presto.common.type.Type; +import com.facebook.presto.common.type.TypeManager; +import com.facebook.presto.common.type.VarcharType; +import org.apache.arrow.vector.types.pojo.Field; + +import javax.inject.Inject; + +import java.util.Optional; + +public class TestingArrowBlockBuilder + extends ArrowBlockBuilder +{ + @Inject + public TestingArrowBlockBuilder(TypeManager typeManager) + { + super(typeManager); + } + + @Override + protected Type getPrestoTypeFromArrowField(Field field) + { + String columnLength = field.getMetadata().get("columnLength"); + int length = columnLength != null ? Integer.parseInt(columnLength) : 0; + + String nativeType = Optional.ofNullable(field.getMetadata().get("columnNativeType")).orElse(""); + + switch (nativeType) { + case "CHAR": + case "CHARACTER": + return CharType.createCharType(length); + case "VARCHAR": + return VarcharType.createVarcharType(length); + case "TIME": + return TimeType.TIME; + default: + return super.getPrestoTypeFromArrowField(field); + } + } +} diff --git a/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/testingConnector/TestingArrowFlightClientHandler.java b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/testingConnector/TestingArrowFlightClientHandler.java new file mode 100644 index 0000000000000..4d35cab2cc1bb --- /dev/null +++ b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/testingConnector/TestingArrowFlightClientHandler.java @@ -0,0 +1,153 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.plugin.arrow.testingConnector; + +import com.facebook.airlift.json.JsonCodec; +import com.facebook.plugin.arrow.ArrowFlightConfig; +import com.facebook.plugin.arrow.ArrowTableHandle; +import com.facebook.plugin.arrow.ArrowTableLayoutHandle; +import com.facebook.plugin.arrow.BaseArrowFlightClientHandler; +import com.facebook.plugin.arrow.testingServer.TestingArrowFlightRequest; +import com.facebook.plugin.arrow.testingServer.TestingArrowFlightResponse; +import com.facebook.presto.spi.ConnectorSession; +import com.facebook.presto.spi.SchemaTableName; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import org.apache.arrow.flight.Action; +import org.apache.arrow.flight.CallOption; +import org.apache.arrow.flight.CallOptions; +import org.apache.arrow.flight.FlightClient; +import org.apache.arrow.flight.FlightDescriptor; +import org.apache.arrow.flight.Result; +import org.apache.arrow.flight.auth2.BearerCredentialWriter; +import org.apache.arrow.flight.grpc.CredentialCallOption; +import org.apache.arrow.memory.BufferAllocator; + +import javax.inject.Inject; + +import java.util.ArrayList; +import java.util.Iterator; +import java.util.List; +import java.util.Optional; +import java.util.concurrent.TimeUnit; + +import static com.facebook.presto.common.Utils.checkArgument; +import static java.util.Locale.ENGLISH; +import static java.util.Objects.requireNonNull; + +public class TestingArrowFlightClientHandler + extends BaseArrowFlightClientHandler +{ + private final JsonCodec 
requestCodec; + private final JsonCodec responseCodec; + + @Inject + public TestingArrowFlightClientHandler( + BufferAllocator allocator, + ArrowFlightConfig config, + JsonCodec requestCodec, + JsonCodec responseCodec) + { + super(allocator, config); + this.requestCodec = requireNonNull(requestCodec, "requestCodec is null"); + this.responseCodec = requireNonNull(responseCodec, "responseCodec is null"); + } + + @Override + public CallOption[] getCallOptions(ConnectorSession connectorSession) + { + return new CallOption[] { + new CredentialCallOption(new BearerCredentialWriter(null)), + CallOptions.timeout(300, TimeUnit.SECONDS) + }; + } + + @Override + public FlightDescriptor getFlightDescriptorForSchema(String schemaName, String tableName) + { + TestingArrowFlightRequest request = TestingArrowFlightRequest.createDescribeTableRequest(schemaName, tableName); + return FlightDescriptor.command(requestCodec.toBytes(request)); + } + + @Override + public List listSchemaNames(ConnectorSession session) + { + List res; + try (FlightClient client = createFlightClient()) { + List names1 = new ArrayList<>(); + TestingArrowFlightRequest request = TestingArrowFlightRequest.createListSchemaRequest(); + Iterator iterator = client.doAction(new Action("discovery", requestCodec.toJsonBytes(request)), getCallOptions(session)); + while (iterator.hasNext()) { + Result result = iterator.next(); + TestingArrowFlightResponse response = responseCodec.fromJson(result.getBody()); + checkArgument(response != null, "response is null"); + checkArgument(response.getSchemaNames() != null, "response.getSchemaNames() is null"); + names1.addAll(response.getSchemaNames()); + } + res = names1; + } + catch (InterruptedException e) { + throw new RuntimeException(e); + } + List listSchemas = res; + List names = new ArrayList<>(); + for (String value : listSchemas) { + names.add(value.toLowerCase(ENGLISH)); + } + return ImmutableList.copyOf(names); + } + + @Override + public List listTables(ConnectorSession 
session, Optional schemaName) + { + String schemaValue = schemaName.orElse(""); + List res; + try (FlightClient client = createFlightClient()) { + List names = new ArrayList<>(); + TestingArrowFlightRequest request = TestingArrowFlightRequest.createListTablesRequest(schemaName.orElse("")); + Iterator iterator = client.doAction(new Action("discovery", requestCodec.toJsonBytes(request)), getCallOptions(session)); + while (iterator.hasNext()) { + Result result = iterator.next(); + TestingArrowFlightResponse response = responseCodec.fromJson(result.getBody()); + checkArgument(response != null, "response is null"); + checkArgument(response.getTableNames() != null, "response.getTableNames() is null"); + names.addAll(response.getTableNames()); + } + res = names; + } + catch (InterruptedException e) { + throw new RuntimeException(e); + } + List listTables = res; + List tables = new ArrayList<>(); + for (String value : listTables) { + tables.add(new SchemaTableName(schemaValue.toLowerCase(ENGLISH), value.toLowerCase(ENGLISH))); + } + + return tables; + } + + @Override + public FlightDescriptor getFlightDescriptorForTableScan(ArrowTableLayoutHandle tableLayoutHandle) + { + ArrowTableHandle tableHandle = tableLayoutHandle.getTable(); + String query = new TestingArrowQueryBuilder().buildSql( + tableHandle.getSchema(), + tableHandle.getTable(), + tableLayoutHandle.getColumnHandles(), ImmutableMap.of(), + tableLayoutHandle.getTupleDomain()); + TestingArrowFlightRequest request = TestingArrowFlightRequest.createQueryRequest(tableHandle.getSchema(), tableHandle.getTable(), query); + return FlightDescriptor.command(requestCodec.toBytes(request)); + } +} diff --git a/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/testingConnector/TestingArrowFlightPlugin.java b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/testingConnector/TestingArrowFlightPlugin.java new file mode 100644 index 0000000000000..196098c3151c7 --- /dev/null +++ 
b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/testingConnector/TestingArrowFlightPlugin.java
@@ -0,0 +1,26 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package com.facebook.plugin.arrow.testingConnector;
+
+import com.facebook.airlift.json.JsonModule;
+import com.facebook.plugin.arrow.ArrowPlugin;
+
+/**
+ * {@link ArrowPlugin} variant used by the tests: it hands a {@link TestingArrowModule} and a
+ * {@link JsonModule} to the parent constructor.
+ */
+public class TestingArrowFlightPlugin
+        extends ArrowPlugin
+{
+    /**
+     * Creates the plugin. NOTE(review): "arrow" is presumably the connector name the parent
+     * registers under; confirm against the ArrowPlugin constructor.
+     */
+    public TestingArrowFlightPlugin()
+    {
+        super("arrow", new TestingArrowModule(), new JsonModule());
+    }
+}
diff --git a/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/testingConnector/TestingArrowModule.java b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/testingConnector/TestingArrowModule.java
new file mode 100644
index 0000000000000..de9bff4cb83a9
--- /dev/null
+++ b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/testingConnector/TestingArrowModule.java
@@ -0,0 +1,40 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.plugin.arrow.testingConnector; + +import com.facebook.plugin.arrow.ArrowBlockBuilder; +import com.facebook.plugin.arrow.BaseArrowFlightClientHandler; +import com.facebook.plugin.arrow.testingServer.TestingArrowFlightRequest; +import com.facebook.plugin.arrow.testingServer.TestingArrowFlightResponse; +import com.google.inject.Binder; +import com.google.inject.Module; +import com.google.inject.Scopes; + +import static com.facebook.airlift.json.JsonCodecBinder.jsonCodecBinder; + +public class TestingArrowModule + implements Module +{ + @Override + public void configure(Binder binder) + { + // Concrete implementation of the BaseFlightClientHandler + binder.bind(BaseArrowFlightClientHandler.class).to(TestingArrowFlightClientHandler.class).in(Scopes.SINGLETON); + // Override the ArrowBlockBuilder with an implementation that handles h2 types + binder.bind(ArrowBlockBuilder.class).to(TestingArrowBlockBuilder.class).in(Scopes.SINGLETON); + // Request/response objects + jsonCodecBinder(binder).bindJsonCodec(TestingArrowFlightRequest.class); + jsonCodecBinder(binder).bindJsonCodec(TestingArrowFlightResponse.class); + } +} diff --git a/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/testingConnector/TestingArrowQueryBuilder.java b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/testingConnector/TestingArrowQueryBuilder.java new file mode 100644 index 0000000000000..6229dadab1921 --- /dev/null +++ b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/testingConnector/TestingArrowQueryBuilder.java @@ -0,0 +1,306 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.plugin.arrow.testingConnector; + +import com.facebook.plugin.arrow.ArrowColumnHandle; +import com.facebook.presto.common.predicate.Domain; +import com.facebook.presto.common.predicate.Range; +import com.facebook.presto.common.predicate.TupleDomain; +import com.facebook.presto.common.type.BigintType; +import com.facebook.presto.common.type.BooleanType; +import com.facebook.presto.common.type.CharType; +import com.facebook.presto.common.type.DateType; +import com.facebook.presto.common.type.DoubleType; +import com.facebook.presto.common.type.IntegerType; +import com.facebook.presto.common.type.RealType; +import com.facebook.presto.common.type.SmallintType; +import com.facebook.presto.common.type.TimeType; +import com.facebook.presto.common.type.TimestampType; +import com.facebook.presto.common.type.TinyintType; +import com.facebook.presto.common.type.Type; +import com.facebook.presto.common.type.VarcharType; +import com.facebook.presto.spi.ColumnHandle; +import com.google.common.base.Joiner; +import com.google.common.collect.ImmutableList; +import io.airlift.slice.Slice; + +import java.sql.Time; +import java.sql.Timestamp; +import java.text.SimpleDateFormat; +import java.time.ZoneId; +import java.util.ArrayList; +import java.util.Date; +import java.util.List; +import java.util.Map; +import java.util.TimeZone; +import java.util.concurrent.TimeUnit; +import java.util.stream.Collectors; + +import static com.google.common.base.Preconditions.checkArgument; +import static com.google.common.base.Preconditions.checkState; +import 
static com.google.common.base.Strings.isNullOrEmpty; +import static com.google.common.collect.Iterables.getOnlyElement; +import static java.lang.Float.intBitsToFloat; +import static java.lang.Math.toIntExact; +import static java.lang.String.format; +import static java.util.Objects.requireNonNull; +import static java.util.stream.Collectors.joining; + +public class TestingArrowQueryBuilder +{ + // not all databases support booleans, so use 1=1 and 1=0 instead + private static final String ALWAYS_TRUE = "1=1"; + private static final String ALWAYS_FALSE = "1=0"; + public static final String DATE_FORMAT = "yyyy-MM-dd"; + public static final String TIMESTAMP_FORMAT = "yyyy-MM-dd HH:mm:ss.SSS"; + public static final String TIME_FORMAT = "HH:mm:ss"; + public static final TimeZone UTC_TIME_ZONE = TimeZone.getTimeZone(ZoneId.of("UTC")); + + public String buildSql( + String schema, + String table, + List columns, + Map columnExpressions, + TupleDomain tupleDomain) + { + StringBuilder sql = new StringBuilder(); + + sql.append("SELECT "); + sql.append(addColumnExpression(columns, columnExpressions)); + + sql.append(" FROM "); + if (!isNullOrEmpty(schema)) { + sql.append(quote(schema)).append('.'); + } + sql.append(quote(table)); + + List accumulator = new ArrayList<>(); + + if (tupleDomain != null && !tupleDomain.isAll()) { + List clauses = toConjuncts(columns, tupleDomain, accumulator); + if (!clauses.isEmpty()) { + sql.append(" WHERE ") + .append(Joiner.on(" AND ").join(clauses)); + } + } + + return sql.toString(); + } + + public static String convertEpochToString(long epochValue, Type type) + { + if (type instanceof DateType) { + long millis = TimeUnit.DAYS.toMillis(epochValue); + Date date = new Date(millis); + SimpleDateFormat dateFormat = new SimpleDateFormat(DATE_FORMAT); + dateFormat.setTimeZone(UTC_TIME_ZONE); + return dateFormat.format(date); + } + else if (type instanceof TimestampType) { + Timestamp timestamp = new Timestamp(epochValue); + SimpleDateFormat 
timestampFormat = new SimpleDateFormat(TIMESTAMP_FORMAT); + timestampFormat.setTimeZone(UTC_TIME_ZONE); + return timestampFormat.format(timestamp); + } + else if (type instanceof TimeType) { + long millis = TimeUnit.SECONDS.toMillis(epochValue / 1000); + Time time = new Time(millis); + SimpleDateFormat timeFormat = new SimpleDateFormat(TIME_FORMAT); + timeFormat.setTimeZone(UTC_TIME_ZONE); + return timeFormat.format(time); + } + else { + throw new UnsupportedOperationException(type + " is not supported."); + } + } + + protected static class TypeAndValue + { + private final Type type; + private final Object value; + + public TypeAndValue(Type type, Object value) + { + this.type = requireNonNull(type, "type is null"); + this.value = requireNonNull(value, "value is null"); + } + + public Type getType() + { + return type; + } + + public Object getValue() + { + return value; + } + } + + private String addColumnExpression(List columns, Map columnExpressions) + { + if (columns.isEmpty()) { + return "null"; + } + + return columns.stream() + .map(arrowColumnHandle -> { + String columnAlias = quote(arrowColumnHandle.getColumnName()); + String expression = columnExpressions.get(arrowColumnHandle.getColumnName()); + if (expression == null) { + return columnAlias; + } + return format("%s AS %s", expression, columnAlias); + }) + .collect(joining(", ")); + } + + private static boolean isAcceptedType(Type type) + { + Type validType = requireNonNull(type, "type is null"); + return validType.equals(BigintType.BIGINT) || + validType.equals(TinyintType.TINYINT) || + validType.equals(SmallintType.SMALLINT) || + validType.equals(IntegerType.INTEGER) || + validType.equals(DoubleType.DOUBLE) || + validType.equals(RealType.REAL) || + validType.equals(BooleanType.BOOLEAN) || + validType.equals(DateType.DATE) || + validType.equals(TimeType.TIME) || + validType.equals(TimestampType.TIMESTAMP) || + validType instanceof VarcharType || + validType instanceof CharType; + } + private List 
toConjuncts(List columns, TupleDomain tupleDomain, List accumulator) + { + ImmutableList.Builder builder = ImmutableList.builder(); + for (ArrowColumnHandle column : columns) { + Type type = column.getColumnType(); + if (isAcceptedType(type)) { + Domain domain = tupleDomain.getDomains().get().get(column); + if (domain != null) { + builder.add(toPredicate(column.getColumnName(), domain, column, accumulator)); + } + } + } + return builder.build(); + } + + private String toPredicate(String columnName, Domain domain, ArrowColumnHandle columnHandle, List accumulator) + { + checkArgument(domain.getType().isOrderable(), "Domain type must be orderable"); + + if (domain.getValues().isNone()) { + return domain.isNullAllowed() ? quote(columnName) + " IS NULL" : ALWAYS_FALSE; + } + + if (domain.getValues().isAll()) { + return domain.isNullAllowed() ? ALWAYS_TRUE : quote(columnName) + " IS NOT NULL"; + } + + List disjuncts = new ArrayList<>(); + List singleValues = new ArrayList<>(); + for (Range range : domain.getValues().getRanges().getOrderedRanges()) { + checkState(!range.isAll()); // Already checked + if (range.isSingleValue()) { + singleValues.add(range.getSingleValue()); + } + else { + List rangeConjuncts = new ArrayList<>(); + if (!range.isLowUnbounded()) { + rangeConjuncts.add(toPredicate(columnName, range.isLowInclusive() ? ">=" : ">", range.getLowBoundedValue(), columnHandle, accumulator)); + } + if (!range.isHighUnbounded()) { + rangeConjuncts.add(toPredicate(columnName, range.isHighInclusive() ? 
"<=" : "<", range.getHighBoundedValue(), columnHandle, accumulator)); + } + // If rangeConjuncts is null, then the range was ALL, which should already have been checked for + checkState(!rangeConjuncts.isEmpty()); + disjuncts.add("(" + Joiner.on(" AND ").join(rangeConjuncts) + ")"); + } + } + + // Add back all of the possible single values either as an equality or an IN predicate + if (singleValues.size() == 1) { + disjuncts.add(toPredicate(columnName, "=", getOnlyElement(singleValues), columnHandle, accumulator)); + } + else if (singleValues.size() > 1) { + for (Object value : singleValues) { + bindValue(value, columnHandle, accumulator); + } + String values = Joiner.on(",").join(singleValues.stream().map(v -> + parameterValueToString(columnHandle.getColumnType(), v)).collect(Collectors.toList())); + disjuncts.add(quote(columnName) + " IN (" + values + ")"); + } + + // Add nullability disjuncts + checkState(!disjuncts.isEmpty()); + if (domain.isNullAllowed()) { + disjuncts.add(quote(columnName) + " IS NULL"); + } + + return "(" + Joiner.on(" OR ").join(disjuncts) + ")"; + } + + private String toPredicate(String columnName, String operator, Object value, ArrowColumnHandle columnHandle, List accumulator) + { + bindValue(value, columnHandle, accumulator); + return quote(columnName) + " " + operator + " " + parameterValueToString(columnHandle.getColumnType(), value); + } + private String quote(String name) + { + return "\"" + name + "\""; + } + + private String quoteValue(String name) + { + return "'" + name + "'"; + } + + private void bindValue(Object value, ArrowColumnHandle columnHandle, List accumulator) + { + Type type = columnHandle.getColumnType(); + accumulator.add(new TypeAndValue(type, value)); + } + + public static String convertLongToFloatString(Long value) + { + float floatFromIntBits = intBitsToFloat(toIntExact(value)); + return String.valueOf(floatFromIntBits); + } + + private String parameterValueToString(Type type, Object value) + { + Class javaType = 
type.getJavaType(); + if (type instanceof DateType && javaType == long.class) { + return quoteValue(convertEpochToString((Long) value, type)); + } + else if (type instanceof TimeType && javaType == long.class) { + return quoteValue(convertEpochToString((Long) value, type)); + } + else if (type instanceof TimestampType && javaType == long.class) { + return quoteValue(convertEpochToString((Long) value, type)); + } + else if (type instanceof RealType && javaType == long.class) { + return convertLongToFloatString((Long) value); + } + else if (javaType == boolean.class || javaType == double.class || javaType == long.class) { + return value.toString(); + } + else if (javaType == Slice.class) { + return quoteValue(((Slice) value).toStringUtf8()); + } + else { + return quoteValue(value.toString()); + } + } +} diff --git a/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/testingServer/TestingArrowFlightRequest.java b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/testingServer/TestingArrowFlightRequest.java new file mode 100644 index 0000000000000..31d6bf0d4e20d --- /dev/null +++ b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/testingServer/TestingArrowFlightRequest.java @@ -0,0 +1,75 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package com.facebook.plugin.arrow.testingServer; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; + +import java.util.Optional; + +public class TestingArrowFlightRequest +{ + private final Optional schema; + private final Optional table; + private final Optional query; + + @JsonCreator + public TestingArrowFlightRequest( + @JsonProperty("schema") Optional schema, + @JsonProperty("table") Optional table, + @JsonProperty("query") Optional query) + { + this.schema = schema; + this.table = table; + this.query = query; + } + + public static TestingArrowFlightRequest createListSchemaRequest() + { + return new TestingArrowFlightRequest(Optional.empty(), Optional.empty(), Optional.empty()); + } + + public static TestingArrowFlightRequest createListTablesRequest(String schema) + { + return new TestingArrowFlightRequest(Optional.of(schema), Optional.empty(), Optional.empty()); + } + + public static TestingArrowFlightRequest createDescribeTableRequest(String schema, String table) + { + return new TestingArrowFlightRequest(Optional.of(schema), Optional.of(table), Optional.empty()); + } + + public static TestingArrowFlightRequest createQueryRequest(String schema, String table, String query) + { + return new TestingArrowFlightRequest(Optional.of(schema), Optional.of(table), Optional.of(query)); + } + + @JsonProperty + public Optional getSchema() + { + return schema; + } + + @JsonProperty + public Optional getTable() + { + return table; + } + + @JsonProperty + public Optional getQuery() + { + return query; + } +} diff --git a/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/testingServer/TestingArrowFlightResponse.java b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/testingServer/TestingArrowFlightResponse.java new file mode 100644 index 0000000000000..9eacc6632c7bd --- /dev/null +++ 
b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/testingServer/TestingArrowFlightResponse.java @@ -0,0 +1,47 @@
/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.facebook.plugin.arrow.testingServer;

import com.fasterxml.jackson.annotation.JsonCreator;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.google.common.collect.ImmutableList;

import java.util.List;

import static java.util.Objects.requireNonNull;

/**
 * JSON payload returned by the testing Arrow Flight server: a list of schema
 * names and a list of table names. Both lists are copied defensively into
 * immutable lists on construction.
 */
public class TestingArrowFlightResponse
{
    private final List<String> schemaNames;
    private final List<String> tableNames;

    /**
     * @throws NullPointerException if either list is null
     */
    @JsonCreator
    public TestingArrowFlightResponse(
            @JsonProperty("schemaNames") List<String> schemaNames,
            @JsonProperty("tableNames") List<String> tableNames)
    {
        requireNonNull(schemaNames, "schemaNames is null");
        requireNonNull(tableNames, "tableNames is null");
        this.schemaNames = ImmutableList.copyOf(schemaNames);
        this.tableNames = ImmutableList.copyOf(tableNames);
    }

    @JsonProperty
    public List<String> getSchemaNames()
    {
        return schemaNames;
    }

    @JsonProperty
    public List<String> getTableNames()
    {
        return tableNames;
    }
}
diff --git a/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/testingServer/TestingArrowProducer.java b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/testingServer/TestingArrowProducer.java new file mode 100644 index 0000000000000..a29ba4aaf8724 --- /dev/null +++
b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/testingServer/TestingArrowProducer.java @@ -0,0 +1,309 @@
/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.facebook.plugin.arrow.testingServer;

import com.facebook.airlift.json.JsonCodec;
import com.facebook.airlift.log.Logger;
import com.google.common.collect.ImmutableList;
import org.apache.arrow.adapter.jdbc.ArrowVectorIterator;
import org.apache.arrow.adapter.jdbc.JdbcToArrow;
import org.apache.arrow.adapter.jdbc.JdbcToArrowConfig;
import org.apache.arrow.adapter.jdbc.JdbcToArrowConfigBuilder;
import org.apache.arrow.flight.Action;
import org.apache.arrow.flight.ActionType;
import org.apache.arrow.flight.Criteria;
import org.apache.arrow.flight.FlightDescriptor;
import org.apache.arrow.flight.FlightEndpoint;
import org.apache.arrow.flight.FlightInfo;
import org.apache.arrow.flight.FlightProducer;
import org.apache.arrow.flight.FlightStream;
import org.apache.arrow.flight.PutResult;
import org.apache.arrow.flight.Result;
import org.apache.arrow.flight.Ticket;
import org.apache.arrow.memory.BufferAllocator;
import org.apache.arrow.vector.VectorLoader;
import org.apache.arrow.vector.VectorSchemaRoot;
import org.apache.arrow.vector.VectorUnloader;
import org.apache.arrow.vector.ipc.message.ArrowRecordBatch;
import org.apache.arrow.vector.types.DateUnit;
import org.apache.arrow.vector.types.FloatingPointPrecision;
import org.apache.arrow.vector.types.TimeUnit;
import org.apache.arrow.vector.types.pojo.ArrowType;
import org.apache.arrow.vector.types.pojo.Field;
import org.apache.arrow.vector.types.pojo.FieldType;
import org.apache.arrow.vector.types.pojo.Schema;

import java.io.IOException;
import java.sql.Connection;
import java.sql.DriverManager;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.ResultSetMetaData;
import java.sql.SQLException;
import java.sql.Statement;
import java.util.ArrayList;
import java.util.Calendar;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.TimeZone;
import java.util.concurrent.ThreadLocalRandom;

import static com.facebook.airlift.json.JsonCodec.jsonCodec;
import static com.facebook.presto.common.Utils.checkArgument;
import static org.apache.arrow.adapter.jdbc.JdbcToArrowUtils.jdbcToArrowSchema;

/**
 * Flight producer for tests, backed by a private in-memory H2 database that is
 * populated by {@link TestingH2DatabaseSetup} when the producer is constructed.
 *
 * Requests arrive as JSON-encoded {@link TestingArrowFlightRequest} payloads in
 * the ticket, descriptor command or action body. Query text and schema/table
 * names are upper-cased before they are sent to H2.
 */
public class TestingArrowProducer
        implements FlightProducer
{
    private static final Logger logger = Logger.get(TestingArrowProducer.class);

    private final BufferAllocator allocator;
    private final Connection connection;
    private final JsonCodec<TestingArrowFlightRequest> requestCodec;
    private final JsonCodec<TestingArrowFlightResponse> responseCodec;

    /**
     * Creates a uniquely named in-memory H2 database, loads the test data and
     * opens the connection used for all later requests.
     */
    public TestingArrowProducer(BufferAllocator allocator) throws Exception
    {
        this.allocator = allocator;
        String h2JdbcUrl = "jdbc:h2:mem:testdb" + System.nanoTime() + "_" + ThreadLocalRandom.current().nextInt() + ";DB_CLOSE_DELAY=-1";
        TestingH2DatabaseSetup.setup(h2JdbcUrl);
        this.connection = DriverManager.getConnection(h2JdbcUrl, "sa", "");
        this.requestCodec = jsonCodec(TestingArrowFlightRequest.class);
        this.responseCodec = jsonCodec(TestingArrowFlightResponse.class);
    }

    /**
     * Runs the query carried by the ticket and streams the result to the
     * listener in batches of up to 2048 rows.
     *
     * On failure the error is reported to the listener and then rethrown as a
     * {@link RuntimeException}.
     */
    @Override
    public void getStream(CallContext callContext, Ticket ticket, ServerStreamListener serverStreamListener)
    {
        try (Statement stmt = connection.createStatement()) {
            TestingArrowFlightRequest request = requestCodec.fromJson(ticket.getBytes());
            checkArgument(request != null, "Request is null");
            checkArgument(request.getQuery().isPresent(), "Query is missing");

            // Extract and validate the SQL query
            String query = request.getQuery().get();
            if (query.trim().isEmpty()) {
                throw new IllegalArgumentException("Query cannot be empty.");
            }

            logger.debug("Executing query: %s", query);

            // NOTE(review): upper-casing the whole statement also upper-cases string literals in it
            try (ResultSet resultSet = stmt.executeQuery(query.toUpperCase())) {
                JdbcToArrowConfig config = new JdbcToArrowConfigBuilder().setAllocator(allocator).setTargetBatchSize(2048)
                        .setCalendar(Calendar.getInstance(TimeZone.getDefault())).build();
                Schema schema = jdbcToArrowSchema(resultSet.getMetaData(), config);
                try (VectorSchemaRoot streamRoot = VectorSchemaRoot.create(schema, allocator)) {
                    VectorLoader loader = new VectorLoader(streamRoot);
                    serverStreamListener.start(streamRoot);

                    // The iterator is closed here so it is released even when streaming fails part-way
                    try (ArrowVectorIterator iterator = JdbcToArrow.sqlToArrowVectorIterator(resultSet, config)) {
                        while (iterator.hasNext()) {
                            try (VectorSchemaRoot iteratorRoot = iterator.next()) {
                                VectorUnloader vectorUnloader = new VectorUnloader(iteratorRoot);
                                try (ArrowRecordBatch batch = vectorUnloader.getRecordBatch()) {
                                    loader.load(batch);
                                    serverStreamListener.putNext();
                                }
                                streamRoot.clear();
                            }
                        }
                    }
                }
                serverStreamListener.completed();
            }
        }
        // Handle Arrow processing errors
        catch (IOException e) {
            logger.error("Arrow data processing failed", e);
            serverStreamListener.error(e);
            throw new RuntimeException("Failed to process Arrow data", e);
        }
        // Handle all other exceptions, including parsing errors
        catch (Exception e) {
            logger.error("Ticket processing failed", e);
            serverStreamListener.error(e);
            throw new RuntimeException("Failed to process the ticket", e);
        }
    }

    @Override
    public void listFlights(CallContext callContext, Criteria criteria, StreamListener<FlightInfo> streamListener)
    {
        throw new UnsupportedOperationException("This operation is not supported");
    }

    /**
     * Describes the result schema for the request in the descriptor command.
     *
     * A schema name is required. When a table name is present the columns are
     * read from {@code INFORMATION_SCHEMA.COLUMNS}; otherwise the query is
     * executed and its result-set metadata is used. The returned
     * {@link FlightInfo} has one endpoint whose ticket is the original command.
     *
     * @throws RuntimeException wrapping any failure
     */
    @Override
    public FlightInfo getFlightInfo(CallContext callContext, FlightDescriptor flightDescriptor)
    {
        try {
            TestingArrowFlightRequest request = requestCodec.fromJson(flightDescriptor.getCommand());
            checkArgument(request != null, "Request is null");

            checkArgument(request.getSchema().isPresent(), "Schema is missing");
            String schemaName = request.getSchema().get();
            Optional<String> tableName = request.getTable();
            Optional<String> selectStatement = request.getQuery();

            List<Field> fields;
            if (tableName.isPresent()) {
                fields = describeTable(schemaName, tableName.get());
            }
            else if (selectStatement.isPresent()) {
                fields = describeQuery(selectStatement.get());
            }
            else {
                throw new IllegalArgumentException("Either schema_name/table_name or select_statement must be provided.");
            }

            Schema schema = new Schema(fields);
            FlightEndpoint endpoint = new FlightEndpoint(new Ticket(flightDescriptor.getCommand()));
            return new FlightInfo(schema, flightDescriptor, Collections.singletonList(endpoint), -1, -1);
        }
        catch (Exception e) {
            logger.error(e);
            throw new RuntimeException("Failed to retrieve FlightInfo", e);
        }
    }

    // Builds one nullable Arrow field per column of the table, tagged with the native type and, when known, the character length
    private List<Field> describeTable(String schemaName, String tableName)
            throws SQLException
    {
        String sql = "SELECT * FROM INFORMATION_SCHEMA.COLUMNS WHERE TABLE_SCHEMA = ? AND TABLE_NAME = ?";
        List<Field> fields = new ArrayList<>();
        try (PreparedStatement statement = connection.prepareStatement(sql)) {
            statement.setString(1, schemaName.toUpperCase());
            statement.setString(2, tableName.toUpperCase());
            try (ResultSet rs = statement.executeQuery()) {
                while (rs.next()) {
                    String columnName = rs.getString("COLUMN_NAME");
                    String dataType = rs.getString("TYPE_NAME");
                    String charMaxLength = rs.getString("CHARACTER_MAXIMUM_LENGTH");
                    int precision = rs.getInt("NUMERIC_PRECISION");
                    int scale = rs.getInt("NUMERIC_SCALE");

                    ArrowType arrowType = convertSqlTypeToArrowType(dataType, precision, scale);
                    Map<String, String> metaDataMap = new HashMap<>();
                    metaDataMap.put("columnNativeType", dataType);
                    if (charMaxLength != null) {
                        metaDataMap.put("columnLength", charMaxLength);
                    }
                    FieldType fieldType = new FieldType(true, arrowType, null, metaDataMap);
                    fields.add(new Field(columnName, fieldType, null));
                }
            }
        }
        return fields;
    }

    // Executes the upper-cased statement and builds one nullable Arrow field per result column
    private List<Field> describeQuery(String selectStatement)
            throws SQLException
    {
        String sql = selectStatement.toUpperCase();
        logger.debug("Executing SELECT query: %s", sql);

        List<Field> fields = new ArrayList<>();
        try (Statement statement = connection.createStatement();
                ResultSet rs = statement.executeQuery(sql)) {
            ResultSetMetaData metaData = rs.getMetaData();
            int columnCount = metaData.getColumnCount();

            for (int i = 1; i <= columnCount; i++) {
                String columnName = metaData.getColumnName(i);
                String columnType = metaData.getColumnTypeName(i);
                int precision = metaData.getPrecision(i);
                int scale = metaData.getScale(i);

                ArrowType arrowType = convertSqlTypeToArrowType(columnType, precision, scale);
                fields.add(new Field(columnName, FieldType.nullable(arrowType), null));
            }
        }
        return fields;
    }

    @Override
    public Runnable acceptPut(CallContext callContext, FlightStream flightStream, StreamListener<PutResult> streamListener)
    {
        throw new UnsupportedOperationException("This operation is not supported");
    }

    /**
     * Answers a metadata listing. Without a schema in the request the response
     * carries every schema name; with a schema it carries the table names of
     * that schema. Failures are reported through {@code streamListener.onError}.
     */
    @Override
    public void doAction(CallContext callContext, Action action, StreamListener<Result> streamListener)
    {
        try {
            TestingArrowFlightRequest request = requestCodec.fromJson(action.getBody());
            Optional<String> schemaName = request.getSchema();

            TestingArrowFlightResponse response;
            if (!schemaName.isPresent()) {
                response = new TestingArrowFlightResponse(listSchemas(), ImmutableList.of());
            }
            else {
                response = new TestingArrowFlightResponse(ImmutableList.of(), listTables(schemaName.get()));
            }

            streamListener.onNext(new Result(responseCodec.toJsonBytes(response)));
            streamListener.onCompleted();
        }
        catch (Exception e) {
            streamListener.onError(e);
        }
    }

    // Reads every schema name from INFORMATION_SCHEMA.SCHEMATA
    private List<String> listSchemas()
            throws SQLException
    {
        try (Statement statement = connection.createStatement();
                ResultSet rs = statement.executeQuery("SELECT SCHEMA_NAME FROM INFORMATION_SCHEMA.SCHEMATA")) {
            return readFirstColumn(rs);
        }
    }

    // Reads the table names of the given schema, matching on the upper-cased schema name
    private List<String> listTables(String schemaName)
            throws SQLException
    {
        String sql = "SELECT TABLE_NAME FROM INFORMATION_SCHEMA.TABLES WHERE TABLE_SCHEMA = ?";
        try (PreparedStatement statement = connection.prepareStatement(sql)) {
            statement.setString(1, schemaName.toUpperCase());
            try (ResultSet rs = statement.executeQuery()) {
                return readFirstColumn(rs);
            }
        }
    }

    // Collects the first column of every remaining row as a string
    private static List<String> readFirstColumn(ResultSet rs)
            throws SQLException
    {
        List<String> names = new ArrayList<>();
        while (rs.next()) {
            names.add(rs.getString(1));
        }
        return names;
    }

    @Override
    public void listActions(CallContext callContext, StreamListener<ActionType> streamListener)
    {
        throw new UnsupportedOperationException("This operation is not supported");
    }

    /**
     * Maps a SQL type name (case-insensitive) to the Arrow type used in the
     * reported schema. {@code precision} and {@code scale} are only used for
     * DECIMAL/NUMERIC.
     *
     * @throws IllegalArgumentException for a type name that is not listed
     */
    private ArrowType convertSqlTypeToArrowType(String sqlType, int precision, int scale)
    {
        switch (sqlType.toUpperCase()) {
            case "VARCHAR":
            case "CHAR":
            case "CHARACTER VARYING":
            case "CHARACTER":
            case "CLOB":
                return new ArrowType.Utf8();
            case "INTEGER":
            case "INT":
                return new ArrowType.Int(32, true);
            case "BIGINT":
                return new ArrowType.Int(64, true);
            case "SMALLINT":
                return new ArrowType.Int(16, true);
            case "TINYINT":
                return new ArrowType.Int(8, true);
            case "DOUBLE":
            case "DOUBLE PRECISION":
            case "FLOAT":
                return new ArrowType.FloatingPoint(FloatingPointPrecision.DOUBLE);
            case "REAL":
                return new ArrowType.FloatingPoint(FloatingPointPrecision.SINGLE);
            case "BOOLEAN":
                return new ArrowType.Bool();
            case "DATE":
                return new ArrowType.Date(DateUnit.DAY);
            case "TIMESTAMP":
                return new ArrowType.Timestamp(TimeUnit.MILLISECOND, null);
            case "TIME":
                return new ArrowType.Time(TimeUnit.MILLISECOND, 32);
            case "DECIMAL":
            case "NUMERIC":
                return new ArrowType.Decimal(precision, scale);
            case "BINARY":
            case "VARBINARY":
                return new ArrowType.Binary();
            case "NULL":
                return new ArrowType.Null();
            default:
                throw new IllegalArgumentException("Unsupported SQL type: " + sqlType);
        }
    }
}
diff --git a/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/testingServer/TestingH2DatabaseSetup.java b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/testingServer/TestingH2DatabaseSetup.java new file mode 100644 index 0000000000000..0b8f8b5db9ef5 --- /dev/null +++ b/presto-base-arrow-flight/src/test/java/com/facebook/plugin/arrow/testingServer/TestingH2DatabaseSetup.java @@ -0,0 +1,273 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.facebook.plugin.arrow.testingServer; + +import com.facebook.airlift.log.Logger; +import com.facebook.presto.common.type.Type; +import com.facebook.presto.common.type.VarcharType; +import com.facebook.presto.spi.ColumnMetadata; +import com.facebook.presto.spi.ConnectorTableMetadata; +import com.facebook.presto.spi.RecordCursor; +import com.facebook.presto.spi.RecordSet; +import com.facebook.presto.spi.SchemaTableName; +import com.facebook.presto.tpch.TpchMetadata; +import com.facebook.presto.tpch.TpchTableHandle; +import com.google.common.base.Joiner; +import io.airlift.tpch.TpchTable; +import org.jdbi.v3.core.Handle; +import org.jdbi.v3.core.Jdbi; +import org.jdbi.v3.core.statement.PreparedBatch; +import org.joda.time.DateTimeZone; + +import java.sql.Connection; +import java.sql.Date; +import java.sql.DriverManager; +import java.sql.ResultSet; +import java.sql.Statement; +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.TimeUnit; +import java.util.stream.Collectors; + +import static com.facebook.presto.common.type.BigintType.BIGINT; +import static com.facebook.presto.common.type.BooleanType.BOOLEAN; +import static com.facebook.presto.common.type.DateType.DATE; +import static com.facebook.presto.common.type.DoubleType.DOUBLE; +import static com.facebook.presto.common.type.IntegerType.INTEGER; +import static com.facebook.presto.tpch.TpchMetadata.TINY_SCHEMA_NAME; +import static com.facebook.presto.tpch.TpchRecordSet.createTpchRecordSet; +import static 
com.google.common.collect.ImmutableList.toImmutableList; +import static io.airlift.tpch.TpchTable.CUSTOMER; +import static io.airlift.tpch.TpchTable.LINE_ITEM; +import static io.airlift.tpch.TpchTable.NATION; +import static io.airlift.tpch.TpchTable.ORDERS; +import static io.airlift.tpch.TpchTable.PART; +import static io.airlift.tpch.TpchTable.PART_SUPPLIER; +import static io.airlift.tpch.TpchTable.REGION; +import static io.airlift.tpch.TpchTable.SUPPLIER; +import static java.lang.String.format; +import static java.util.Collections.nCopies; + +public class TestingH2DatabaseSetup +{ + private static final Logger logger = Logger.get(TestingH2DatabaseSetup.class); + private TestingH2DatabaseSetup() + { + throw new UnsupportedOperationException("This is a utility class and cannot be instantiated"); + } + + public static void setup(String h2JdbcUrl) throws Exception + { + Class.forName("org.h2.Driver"); + + Connection conn = DriverManager.getConnection(h2JdbcUrl, "sa", ""); + + Jdbi jdbi = Jdbi.create(h2JdbcUrl, "sa", ""); + Handle handle = jdbi.open(); // Get a handle for the database connection + + TpchMetadata tpchMetadata = new TpchMetadata(""); + + Statement stmt = conn.createStatement(); + + // Create schema + stmt.execute("CREATE SCHEMA IF NOT EXISTS tpch"); + + stmt.execute("CREATE TABLE tpch.member (" + + " id INTEGER PRIMARY KEY," + + " name VARCHAR(50)," + + " sex CHAR(1)," + + " state CHAR(5)" + + ")"); + stmt.execute("INSERT INTO tpch.member VALUES(1, 'TOM', 'M', 'TMX '),(2, 'MARY', 'F', 'CD ')"); + + stmt.execute("CREATE TABLE tpch.event (" + + " id INTEGER PRIMARY KEY," + + " startDate DATE," + + " startTime TIME," + + " startTimestamp TIMESTAMP" + + ")"); + stmt.execute("INSERT INTO tpch.event VALUES(1, DATE '2004-12-31', TIME '23:59:59'," + + " TIMESTAMP '2005-12-31 23:59:59')"); + + stmt.execute("CREATE TABLE tpch.orders (\n" + + " orderkey BIGINT PRIMARY KEY,\n" + + " custkey BIGINT NOT NULL,\n" + + " orderstatus VARCHAR(1) NOT NULL,\n" + + " 
totalprice DOUBLE NOT NULL,\n" + + " orderdate DATE NOT NULL,\n" + + " orderpriority VARCHAR(15) NOT NULL,\n" + + " clerk VARCHAR(15) NOT NULL,\n" + + " shippriority INTEGER NOT NULL,\n" + + " comment VARCHAR(79) NOT NULL\n" + + ")"); + stmt.execute("CREATE INDEX custkey_index ON tpch.orders (custkey)"); + insertRows(tpchMetadata, ORDERS, handle); + + handle.execute("CREATE TABLE tpch.lineitem (\n" + + " orderkey BIGINT,\n" + + " partkey BIGINT NOT NULL,\n" + + " suppkey BIGINT NOT NULL,\n" + + " linenumber INTEGER,\n" + + " quantity DOUBLE NOT NULL,\n" + + " extendedprice DOUBLE NOT NULL,\n" + + " discount DOUBLE NOT NULL,\n" + + " tax DOUBLE NOT NULL,\n" + + " returnflag CHAR(1) NOT NULL,\n" + + " linestatus CHAR(1) NOT NULL,\n" + + " shipdate DATE NOT NULL,\n" + + " commitdate DATE NOT NULL,\n" + + " receiptdate DATE NOT NULL,\n" + + " shipinstruct VARCHAR(25) NOT NULL,\n" + + " shipmode VARCHAR(10) NOT NULL,\n" + + " comment VARCHAR(44) NOT NULL,\n" + + " PRIMARY KEY (orderkey, linenumber)" + + ")"); + insertRows(tpchMetadata, LINE_ITEM, handle); + + handle.execute(" CREATE TABLE tpch.partsupp (\n" + + " partkey BIGINT NOT NULL,\n" + + " suppkey BIGINT NOT NULL,\n" + + " availqty INTEGER NOT NULL,\n" + + " supplycost DOUBLE NOT NULL,\n" + + " comment VARCHAR(199) NOT NULL,\n" + + " PRIMARY KEY(partkey, suppkey)" + + ")"); + insertRows(tpchMetadata, PART_SUPPLIER, handle); + + handle.execute("CREATE TABLE tpch.nation (\n" + + " nationkey BIGINT PRIMARY KEY,\n" + + " name VARCHAR(25) NOT NULL,\n" + + " regionkey BIGINT NOT NULL,\n" + + " comment VARCHAR(152) NOT NULL\n" + + ")"); + insertRows(tpchMetadata, NATION, handle); + + handle.execute("CREATE TABLE tpch.region(\n" + + " regionkey BIGINT PRIMARY KEY,\n" + + " name VARCHAR(25) NOT NULL,\n" + + " comment VARCHAR(115) NOT NULL\n" + + ")"); + insertRows(tpchMetadata, REGION, handle); + handle.execute("CREATE TABLE tpch.part(\n" + + " partkey BIGINT PRIMARY KEY,\n" + + " name VARCHAR(55) NOT NULL,\n" + + " mfgr 
VARCHAR(25) NOT NULL,\n" + + " brand VARCHAR(10) NOT NULL,\n" + + " type VARCHAR(25) NOT NULL,\n" + + " size INTEGER NOT NULL,\n" + + " container VARCHAR(10) NOT NULL,\n" + + " retailprice DOUBLE NOT NULL,\n" + + " comment VARCHAR(23) NOT NULL\n" + + ")"); + insertRows(tpchMetadata, PART, handle); + handle.execute(" CREATE TABLE tpch.customer ( \n" + + " custkey BIGINT NOT NULL, \n" + + " name VARCHAR(25) NOT NULL, \n" + + " address VARCHAR(40) NOT NULL, \n" + + " nationkey BIGINT NOT NULL, \n" + + " phone VARCHAR(15) NOT NULL, \n" + + " acctbal DOUBLE NOT NULL, \n" + + " mktsegment VARCHAR(10) NOT NULL, \n" + + " comment VARCHAR(117) NOT NULL \n" + + " ) "); + insertRows(tpchMetadata, CUSTOMER, handle); + handle.execute(" CREATE TABLE tpch.supplier ( \n" + + " suppkey bigint NOT NULL, \n" + + " name varchar(25) NOT NULL, \n" + + " address varchar(40) NOT NULL, \n" + + " nationkey bigint NOT NULL, \n" + + " phone varchar(15) NOT NULL, \n" + + " acctbal double NOT NULL, \n" + + " comment varchar(101) NOT NULL \n" + + " ) "); + insertRows(tpchMetadata, SUPPLIER, handle); + + ResultSet resultSet1 = stmt.executeQuery("SELECT TABLE_NAME FROM INFORMATION_SCHEMA.TABLES WHERE TABLE_SCHEMA = 'TPCH'"); + List tables = new ArrayList<>(); + while (resultSet1.next()) { + String tableName = resultSet1.getString("TABLE_NAME"); + tables.add(tableName); + } + logger.info("Tables in 'tpch' schema: %s", tables.stream().collect(Collectors.joining(", "))); + + ResultSet resultSet = stmt.executeQuery("SELECT SCHEMA_NAME FROM INFORMATION_SCHEMA.SCHEMATA"); + List schemas = new ArrayList<>(); + while (resultSet.next()) { + String schemaName = resultSet.getString("SCHEMA_NAME"); + schemas.add(schemaName); + } + logger.info("Schemas: %s", schemas.stream().collect(Collectors.joining(", "))); + } + + private static void insertRows(TpchMetadata tpchMetadata, TpchTable tpchTable, Handle handle) + { + TpchTableHandle tableHandle = tpchMetadata.getTableHandle(null, new 
SchemaTableName(TINY_SCHEMA_NAME, tpchTable.getTableName())); + /* Delegate to the overload below with the metadata looked up for this table handle and the record set built by createTpchRecordSet at the handle's scale factor. */ insertRows(tpchMetadata.getTableMetadata(null, tableHandle), handle, createTpchRecordSet(tpchTable, tableHandle.getScaleFactor())); + } + + /** Inserts every row of {@code data} into the table {@code tpch.<table name>} through {@code handle}, using a positional INSERT with one placeholder per non-hidden column. Rows are sent in batches of at most 1000; when the cursor is exhausted the partially filled batch is executed only if it is non-empty. Throws IllegalArgumentException for any column type other than BOOLEAN, BIGINT, INTEGER, DOUBLE, a VarcharType or DATE. */ private static void insertRows(ConnectorTableMetadata tableMetadata, Handle handle, RecordSet data) + { + /* Hidden columns are excluded from the INSERT. NOTE(review): cursor fields are later read by the index into this filtered list, which presumably relies on hidden columns (if any) coming last - confirm. */ List columns = tableMetadata.getColumns().stream() + .filter(columnMetadata -> !columnMetadata.isHidden()) + .collect(toImmutableList()); + + /* The target schema is fixed to "tpch"; only the table name comes from the metadata. */ String schemaName = "tpch"; + String tableNameWithSchema = schemaName + "." + tableMetadata.getTable().getTableName(); + /* One "?" placeholder per inserted column, comma separated. */ String vars = Joiner.on(',').join(nCopies(columns.size(), "?")); + String sql = format("INSERT INTO %s VALUES (%s)", tableNameWithSchema, vars); + + RecordCursor cursor = data.cursor(); + /* Each iteration fills and executes one batch; the only exit is the return taken when the cursor has no more rows. */ while (true) { + // insert 1000 rows at a time + PreparedBatch batch = handle.prepareBatch(sql); + for (int row = 0; row < 1000; row++) { + if (!cursor.advanceNextPosition()) { + if (batch.size() > 0) { + batch.execute(); + } + return; + } + for (int column = 0; column < columns.size(); column++) { + Type type = columns.get(column).getType(); + if (BOOLEAN.equals(type)) { + batch.bind(column, cursor.getBoolean(column)); + } + else if (BIGINT.equals(type)) { + batch.bind(column, cursor.getLong(column)); + } + else if (INTEGER.equals(type)) { + batch.bind(column, (int) cursor.getLong(column)); + } + else if (DOUBLE.equals(type)) { + batch.bind(column, cursor.getDouble(column)); + } + else if (type instanceof VarcharType) { + batch.bind(column, cursor.getSlice(column).toStringUtf8()); + } + else if (DATE.equals(type)) { + long millisUtc = TimeUnit.DAYS.toMillis(cursor.getLong(column)); + // H2 expects dates to be millis at midnight in the JVM timezone + long localMillis = DateTimeZone.UTC.getMillisKeepLocal(DateTimeZone.getDefault(), millisUtc); + batch.bind(column, new Date(localMillis)); + } + else { + throw new IllegalArgumentException("Unsupported type " + type); + } + } + batch.add(); + } + batch.execute(); + } +
} +} diff --git a/presto-base-arrow-flight/src/test/resources/server.crt b/presto-base-arrow-flight/src/test/resources/server.crt new file mode 100644 index 0000000000000..f070253c9e119 --- /dev/null +++ b/presto-base-arrow-flight/src/test/resources/server.crt @@ -0,0 +1,32 @@ +-----BEGIN CERTIFICATE----- +MIIFlTCCA32gAwIBAgIUf2p3qdxzpsOofYDpXNXDJA1fjbgwDQYJKoZIhvcNAQEL +BQAwWTELMAkGA1UEBhMCQVUxEzARBgNVBAgMClNvbWUtU3RhdGUxITAfBgNVBAoM +GEludGVybmV0IFdpZGdpdHMgUHR5IEx0ZDESMBAGA1UEAwwJbG9jYWxob3N0MCAX +DTI0MTIyMTA3MzUwMFoYDzIxMjQxMTI3MDczNTAwWjBZMQswCQYDVQQGEwJBVTET +MBEGA1UECAwKU29tZS1TdGF0ZTEhMB8GA1UECgwYSW50ZXJuZXQgV2lkZ2l0cyBQ +dHkgTHRkMRIwEAYDVQQDDAlsb2NhbGhvc3QwggIiMA0GCSqGSIb3DQEBAQUAA4IC +DwAwggIKAoICAQCmsVs+OHGQ/GlR9+KpgcKkBvXZQAUtCJzFG27GXBYoltJDAxsi +o47KbGkUWxf5E+4oFbbz7s8OHSuQnjP2XVHbKvmgicf3uP98fKkQtjvl+xn8vT18 +s6a8Kv8hj+f6MRWGwpHa/sQKU9uZmRYTRh+32vDZRtAvmpzkf6+2K8B1fFbwVmuC +j3ULb+0iYHnomC3aMWBFXkxjEmsamx4YK74NtQU98+EjQZwhWgWXhW5shS1kSs4r +3N6++tonBz+tDKAhCMueRRJAQXGjKqL7qDZn7wpk53L/fZT9mgRYyA+PN2ND9L0H +nMGMjJl71p42kkgIGOpllsmK6+g0Bj6aC/uCEnX2AtM0g57Th7U2aLwgOmRshC0s +uuBHxMWUgzJsB1dscXrFPPB+XVcUlgcwRbGsQG5VzK/rYRV4y1FmF5vSLY60mF43 +hFDAcnh4mnnBkobca5Dl6PSUpFmDkF56IUgYQCTrE6hPYnrIJKhPcfpmr/tm/Acr +ra1sPp/QPSFIxI9j7Nzm/QsOBF3Zy4AbbbOmhJOjNtwEi59r9za/FhxV5kSN/YM9 +HyyYYebxW/jXF/7hzQzWYfJBz1SgdD9prl4ml8VMVJdhZmBTzhVciPXizRUKkbD+ +LvQKw8q4a24/VruUvW15J39qalhdyWf3vqGuORWowpG7oYnGXik3kYVT5QIDAQAB +o1MwUTAdBgNVHQ4EFgQUkkKtG5568IDoFtn3AK0Yes0CqGgwHwYDVR0jBBgwFoAU +kkKtG5568IDoFtn3AK0Yes0CqGgwDwYDVR0TAQH/BAUwAwEB/zANBgkqhkiG9w0B +AQsFAAOCAgEAVjwxyO5N4EGQHf9euKiTjksUMeJrJkPMFYS/1UmQB5/yc0TkOxv3 +M8yvTzkzxWAQV2X+CPbOLgqO2eCoSev2GHVUyHLLUEFATzAgTpF79Ovzq6JkSgFo +GFxTL0TuA+d++430wYyts7iYY9uT5Xt8xp4+SwXrCvLWxHosP3xGUmNvY4sP0hdS +beIvdGE/J8izgy5DRt2fWZ03mmwfKqiz/qhKGj9DDsHkA/1jyKIivP/nufr9dzDr +41lhk1N7qFWkOjbMd06NYySIe0MaapIkenjT1IOgqGw2f98RfSEomoaXuxN0VoSM +6dZ4rN97cER25X7/zE0zCZurjCLHzPTyuTspYGEK+9U4plOeWh0keQEcdpcHTAr+ 
+NqU3VlhXVxz91nVREpRJmKk+r4c+xdrfY3YkDcSJ1dazbd3eS05Ggx51KOqes8Zb +hFQfhIDqvaqXDNlkBezLpr4v/MU69+cp7SOn5uPnOccS6sd7fzl/PUZhEjd5ZIws +8SX79OwhQjbYZYRHJSPPasb8B1amULtoo0pJ5izSkXxileiGuXhRO5stiBux7SL+ +oJAztjuRf0IvP6LWOMHgqquzc2JiEDCz0DPnTCqoXGZlT2HPGzXOoDSTOmRdFx+L +qi/DeY+MpIMVov/rplqjydXw6AuQDxcV1GvyjMvaHxJG5MEBC/mVeqQ= +-----END CERTIFICATE----- \ No newline at end of file diff --git a/presto-base-arrow-flight/src/test/resources/server.key b/presto-base-arrow-flight/src/test/resources/server.key new file mode 100644 index 0000000000000..4524c26943f2d --- /dev/null +++ b/presto-base-arrow-flight/src/test/resources/server.key @@ -0,0 +1,52 @@ +-----BEGIN PRIVATE KEY----- +MIIJQwIBADANBgkqhkiG9w0BAQEFAASCCS0wggkpAgEAAoICAQCmsVs+OHGQ/GlR +9+KpgcKkBvXZQAUtCJzFG27GXBYoltJDAxsio47KbGkUWxf5E+4oFbbz7s8OHSuQ +njP2XVHbKvmgicf3uP98fKkQtjvl+xn8vT18s6a8Kv8hj+f6MRWGwpHa/sQKU9uZ +mRYTRh+32vDZRtAvmpzkf6+2K8B1fFbwVmuCj3ULb+0iYHnomC3aMWBFXkxjEmsa +mx4YK74NtQU98+EjQZwhWgWXhW5shS1kSs4r3N6++tonBz+tDKAhCMueRRJAQXGj +KqL7qDZn7wpk53L/fZT9mgRYyA+PN2ND9L0HnMGMjJl71p42kkgIGOpllsmK6+g0 +Bj6aC/uCEnX2AtM0g57Th7U2aLwgOmRshC0suuBHxMWUgzJsB1dscXrFPPB+XVcU +lgcwRbGsQG5VzK/rYRV4y1FmF5vSLY60mF43hFDAcnh4mnnBkobca5Dl6PSUpFmD +kF56IUgYQCTrE6hPYnrIJKhPcfpmr/tm/Acrra1sPp/QPSFIxI9j7Nzm/QsOBF3Z +y4AbbbOmhJOjNtwEi59r9za/FhxV5kSN/YM9HyyYYebxW/jXF/7hzQzWYfJBz1Sg +dD9prl4ml8VMVJdhZmBTzhVciPXizRUKkbD+LvQKw8q4a24/VruUvW15J39qalhd +yWf3vqGuORWowpG7oYnGXik3kYVT5QIDAQABAoICAE0MeIzTgSbPjQz+w82u9V1k +/DlNfrb4nqH7EqJsSS+8uvaPlnjV2fgV0SJAEt4mCLSNiPHKpfkzoYHopkMPknj4 +Lcc3OG94GtubMXhQi3I7tSDeBfBAh+a9Bw2n20WJb5ZJFCsCDHJrnXsrSAljpeCR +OjdsJGmEkVWK8ZiGM6D6dqMDhxEjpynAs/7qUh8hTDxpC0M1GaDHkCsNnQT2HxRt +4jznH97wgi7mUeReIBLYIgmUDCU5I9ppz/EvSA8AYXmze46uBYge19xgJlKlR3SW +CJtoYf7XOMlZ6f1xh8OeiesM0l0U51/EU2Nq6dl2lwXrIlkPsBve/AckBcalmDwh +tJgAOHM8VEMp0imJ8+SPGZupiIt+sTVSt4aq2uB5dtTBDcXGx4gVUYJoekI4yQtv +OWKAQAobq0Cutogx8hyXr2tSizp5pGpJDKx8/G/3YcNmRqCq+HQgwshlHQy/hPqs +QC/jcuOkr5GOZASF3wLCuLmAwFCELL96iVEJvPLe5Yb1u4HbCrsy9/N8rHV1s1qa 
+xMcGiIS4m/fwLbD+TjgM1foESeQciHH7GWVZ3Osn6Inpp6vLJKDhLhGBgpaYF72S +0fft6CkIy/CqbwXPe/3Wm0PL0SLTsSk0lKhw6Zeuvg0Z0jUrmDmlTrGo8j4zU7cA +Gc8uiZuFAU2ZoS9R7SZXAoIBAQDSfF8H4C8jn9IvTw/S3APkP7VyLKFSTitEfagt +qcAZr7sYWYGdYDL1hX+jBopAWdzbJ+N3MZBccH67f3XDZWVH4GDQWF6tyOpyseBe +42aU/yCr6ZSoxi91W9V+oyiTFh+t2cEkNyGey/pbXVBw45jncHpzIAaERm5gD8DD +MxOadIvDxEGquGs9MUgZg6ABFmqvQIfiZ7I3PvfoHJGrHsPyM9i6AqAkJe1rKiN3 +tt2wiDk6zDor4vFKn+5LaRQ5INbpgBuCJrvDsRh3MeKLb0ddCYcAStHeAjHJZbH4 +OiC5y9hpaIcSQN3Z4gEhdMd+Lhi1+z36uN5k7KCvbGmO3sG/AoIBAQDKvMiVqTgm +a36FFswwZJd5o+Hx6G8Nt5y7ZYLJgoq4wyTPYA5ZIHBp2Jtb//IGJMKpeLcRigxx +nKu4zrLZPHesgYzHbCH9i5fOBFl8yF9pZnbCU/zGOYhiVdn7MHZgJw8nlstWr1Jr +cMnxBBXEjgL65Lrse1noiPATNrvBsirFaJky6HhxRPbkijKKd0Xqc0FeHMxt6IHU +y6yZRzguI0M1A8RGR4CrqsOGrdv0vkMSiumkUdz8JW4w9R69n6ax/qNAZSMql0ss +PIPOqGtYRPhibOhXKPl0p30X32YXx/SKm8+L9Sr1ny76dab/bSnxStOdihGJLCs4 +l7vFkuJgTMtbAoIBAQCYA3CSfIsu3EbtGdlgvLsmxggh7C+aBJBlB6dFSzpMksi5 +njLo2MgU35Q9xgRk00GZGWbC94295RTyDuya8IjD7z2cWqYONnNz4BkeDndQli0f +WzOc7Hzr8iXvLqCoEatRYFmH8TUbvU8TWwI0dXtBcs9Mg82RDFi8kcPyddnri85A +1WVjiYsRh5z9qD0PbAQii6VXkvJ3ycc64B8oCbEUI/Oa6ziCws2DvswcsnnK+6bx +WvuMJHuFHJn55mrPk3MC8h1r0tN6UlVMCEAH2ZcdjzrrsB1/i/Au9n4gusJVzO1/ +uxkJysUujXWplvBYpav9CfVKNOeQ1gB6kP5vS1t7AoIBAB1DquCPkJ9bHOQxKkBC +BOt2EINOvdkJDAKw4HQd99A7uvCEOQ38dL2Smrpo85KXc9HqruJFPw6XQuJmU8Kv +y8aG3L9ciHuEzuDaF+C/O6aHN9VNMkuaumkXY2Oy1yOB/9oDFk7o98iyezPjFxFM +Pnng0mqYU54RRjY/zFJlWW8tbg+/JsOS5OCQYkNCfEEfaewf1BJ5YWRKEhv9/8oJ +JQZeCNLsN1KQT7D9H6bwX9YpXxhtCK0M6h7/AvT0OqeuzfnZn33iYON9yLjn7rbL +Hd93QQJz065XDuOHR8FfB5mKbCcTuKPD2pAks3pjU46U8n7nEyjtyz9cB6q5TRwB +eckCggEBAMd/3riFnoc2VF6LsA0VQzZiJleN0G9g3AF/5acqHHjMv6aJs/AjV7ZF +hFdiqTcQFEg87HSOzrtV2DVNWJc7DMcMIuwGhVec+D4wedJ2wrQv+MEgySltSBZq +wcPVn5IQiml38NnG/tPIrHETb0PIoa8+iu/Jg80o7j8M3+DKVKfCyfh334PjFK6W +B/mkgC9PfcfeA/Doby9pJsRnqmAJTeWjxbefksckI4PcRCLMEwggB2ReiIWef8Q+ +IooNIxypWtBWtpNEl5lhO7Y2f65Whp34TjooXmGBl1lj4szO8PNnf9QA865J86OS +kOoFda4Sn7LajkbSX0wTGMuXDpmx34M= +-----END PRIVATE KEY----- \ No newline at end of 
file diff --git a/presto-docs/src/main/sphinx/connector.rst b/presto-docs/src/main/sphinx/connector.rst index f07506b460cab..d337fe4ed12d1 100644 --- a/presto-docs/src/main/sphinx/connector.rst +++ b/presto-docs/src/main/sphinx/connector.rst @@ -9,6 +9,7 @@ from different data sources. :maxdepth: 1 connector/accumulo + connector/base-arrow-flight connector/bigquery connector/blackhole connector/cassandra diff --git a/presto-docs/src/main/sphinx/connector/base-arrow-flight.rst b/presto-docs/src/main/sphinx/connector/base-arrow-flight.rst new file mode 100644 index 0000000000000..5459e3cc3e43a --- /dev/null +++ b/presto-docs/src/main/sphinx/connector/base-arrow-flight.rst @@ -0,0 +1,95 @@ + +====================== +Arrow Flight Connector +====================== +This connector allows querying multiple data sources that are supported by an Arrow Flight server. Flight supports parallel transfers, allowing data to be streamed to or from a cluster of servers simultaneously. Official documentation for Arrow Flight can be found at `Arrow Flight RPC <https://arrow.apache.org/docs/format/Flight.html>`_. + +Getting Started with presto-base-arrow-flight module: Essential Abstract Methods for Developers +----------------------------------------------------------------------------------------------- +To create a plugin extending the presto-base-arrow-flight module, you need to implement certain abstract methods that are specific to your use case. Below are the required classes and their purposes: + +- ``BaseArrowFlightClientHandler.java`` + This class contains the core functionality for the Arrow Flight module. A plugin extending the presto-base-arrow-flight module should extend this abstract class and implement the following methods. + + - ``getCallOptions`` This method should return an array of credential call options to authenticate to the Flight server. + - ``getFlightDescriptorForSchema`` This method should return the flight descriptor to fetch the Arrow schema for the table.
+ - ``listSchemaNames`` This method should return the list of schemas in the catalog. + - ``listTables`` This method should return the list of tables in the catalog. + - ``getFlightDescriptorForTableScan`` This method should return the flight descriptor for fetching data from the table. + +- ``ArrowPlugin.java`` + Register your connector name by extending the ArrowPlugin class. +- ``ArrowBlockBuilder.java`` (optional override to customize data types) + This class builds Presto blocks from Arrow vectors. Extend this class and override the ``getPrestoTypeFromArrowField`` method if the conversion from Arrow vectors to Presto types needs to be customized. A binding for this class should be created in the ``Module`` for the plugin. + +A reference implementation of the presto-base-arrow-flight module is provided in the test folder, containing a Flight server and a connector implementation. +The testing Flight server in ``com.facebook.plugin.arrow.testingServer`` starts a local server and initializes an H2 database to fetch data from. The server defines ``TestingArrowFlightRequest`` and ``TestingArrowFlightResponse``, which are used for commands in the Flight calls, and the ``TestingArrowProducer`` handles the calls, including actions for ``listSchemaNames`` and ``listTables``. +The testing Flight connector in ``com.facebook.plugin.arrow.testingConnector`` implements the above classes to connect to the testing Flight server, which it uses as a data source for test queries. + + +Configuration +------------- +Create a catalog file +in ``etc/catalog`` named, for example, ``arrowmariadb.properties``, to +mount the Flight connector as the ``arrowmariadb`` catalog. +Create the file with the following contents, replacing the +connection properties as appropriate for your setup: + + +.. code-block:: none + + + connector.name=<CONNECTOR_NAME> + arrow-flight.server=<ENDPOINT> + arrow-flight.server.port=<PORT> + + + +Add any other properties that are required to connect to your Flight server.
+ +========================================== ============================================================== +Property Name Description +========================================== ============================================================== +``arrow-flight.server`` Endpoint of the Flight server +``arrow-flight.server.port`` Flight server port +``arrow-flight.server-ssl-certificate`` Pass SSL certificate +``arrow-flight.server.verify`` Whether to verify the server +``arrow-flight.server-ssl-enabled`` Whether the port is SSL enabled +========================================== ============================================================== + +Querying Arrow-Flight +--------------------- + +The Flight connector provides a schema for each supported *database*. +An example for MariaDB is shown below. +To see the available schemas, run ``SHOW SCHEMAS``:: + + SHOW SCHEMAS FROM arrowmariadb; + +To view the tables in the MariaDB database named ``user``, +run ``SHOW TABLES``:: + + SHOW TABLES FROM arrowmariadb.user; + +To see a list of the columns in the ``admin`` table in the ``user`` database, +use either of the following commands:: + + DESCRIBE arrowmariadb.user.admin; + SHOW COLUMNS FROM arrowmariadb.user.admin; + +Finally, you can access the ``admin`` table in the ``user`` database:: + + SELECT * FROM arrowmariadb.user.admin; + +If you used a different name for your catalog properties file, use +that catalog name instead of ``arrowmariadb`` in the above examples. + + +Flight Connector Limitations +---------------------------- + +* SELECT and DESCRIBE queries are supported. Implementing modules can add support for additional features. + +* The Flight connector can query only those data sources that are supported by the Flight server. + +* The Flight server must be running for the Flight connector to work.