diff --git a/arrow/src/main/java/org/apache/iceberg/arrow/ArrowAllocation.java b/arrow/src/main/java/org/apache/iceberg/arrow/ArrowAllocation.java
new file mode 100644
index 000000000000..49882ce90690
--- /dev/null
+++ b/arrow/src/main/java/org/apache/iceberg/arrow/ArrowAllocation.java
@@ -0,0 +1,37 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.iceberg.arrow;
+
+import org.apache.arrow.memory.RootAllocator;
+
+public class ArrowAllocation {
+  static {
+    ROOT_ALLOCATOR = new RootAllocator(Long.MAX_VALUE);
+  }
+
+  private static final RootAllocator ROOT_ALLOCATOR;
+
+  private ArrowAllocation() {
+  }
+
+  public static RootAllocator rootAllocator() {
+    return ROOT_ALLOCATOR;
+  }
+}
diff --git a/arrow/src/main/java/org/apache/iceberg/arrow/vectorized/IcebergArrowVectors.java b/arrow/src/main/java/org/apache/iceberg/arrow/vectorized/IcebergArrowVectors.java
index d6fa260a58f6..a82fa57e1e43 100644
--- a/arrow/src/main/java/org/apache/iceberg/arrow/vectorized/IcebergArrowVectors.java
+++ b/arrow/src/main/java/org/apache/iceberg/arrow/vectorized/IcebergArrowVectors.java
@@ -21,8 +21,8 @@
 import org.apache.arrow.memory.BufferAllocator;
 import org.apache.arrow.vector.DecimalVector;
-import org.apache.arrow.vector.VarBinaryVector;
 import org.apache.arrow.vector.VarCharVector;
+import org.apache.arrow.vector.holders.NullableVarCharHolder;
 
 /**
  * The general way of getting a value at an index in the Arrow vector
@@ -64,38 +64,9 @@ public void setNullabilityHolder(NullabilityHolder nullabilityHolder) {
     }
   }
 
-  /**
-   * Extension of Arrow's @{@link VarBinaryVector}. The whole reason of having this implementation is to override the
-   * expensive {@link VarBinaryVector#isSet(int)} method.
-   */
-  public static class VarBinaryArrowVector extends VarBinaryVector {
-    private NullabilityHolder nullabilityHolder;
-
-    public VarBinaryArrowVector(
-        String name,
-        BufferAllocator allocator) {
-      super(name, allocator);
-    }
-
-    /**
-     * Same as {@link #isNull(int)}.
-     *
-     * @param index position of element
-     * @return 1 if element at given index is not null, 0 otherwise
-     */
-    @Override
-    public int isSet(int index) {
-      return nullabilityHolder.isNullAt(index) ^ 1;
-    }
-
-    public void setNullabilityHolder(NullabilityHolder nullabilityHolder) {
-      this.nullabilityHolder = nullabilityHolder;
-    }
-  }
-
   /**
    * Extension of Arrow's @{@link VarCharVector}. The reason of having this implementation is to override the expensive
-   * {@link VarCharVector#isSet(int)} method.
+   * {@link VarCharVector#isSet(int)} method called by {@link VarCharVector#get(int, NullableVarCharHolder)}
    */
   public static class VarcharArrowVector extends VarCharVector {
diff --git a/arrow/src/main/java/org/apache/iceberg/arrow/vectorized/VectorHolder.java b/arrow/src/main/java/org/apache/iceberg/arrow/vectorized/VectorHolder.java
index d59292f14101..b938d3845c19 100644
--- a/arrow/src/main/java/org/apache/iceberg/arrow/vectorized/VectorHolder.java
+++ b/arrow/src/main/java/org/apache/iceberg/arrow/vectorized/VectorHolder.java
@@ -19,35 +19,48 @@
 package org.apache.iceberg.arrow.vectorized;
 
+import org.apache.arrow.util.Preconditions;
 import org.apache.arrow.vector.FieldVector;
+import org.apache.iceberg.types.Type;
 import org.apache.parquet.column.ColumnDescriptor;
 import org.apache.parquet.column.Dictionary;
 
 /**
- * Container class for holding the Arrow vector holding a batch of values along with other state needed for reading
+ * Container class for holding the Arrow vector storing a batch of values along with other state needed for reading
  * values out of it.
  */
 public class VectorHolder {
   private final ColumnDescriptor columnDescriptor;
   private final FieldVector vector;
   private final boolean isDictionaryEncoded;
   private final Dictionary dictionary;
   private final NullabilityHolder nullabilityHolder;
-
-  public static final VectorHolder NULL_VECTOR_HOLDER = new VectorHolder(null, null, false, null, null);
+  private final Type icebergType;
 
   public VectorHolder(
-      ColumnDescriptor columnDescriptor,
-      FieldVector vector,
-      boolean isDictionaryEncoded,
-      Dictionary dictionary,
-      NullabilityHolder holder) {
+      ColumnDescriptor columnDescriptor, FieldVector vector, boolean isDictionaryEncoded,
+      Dictionary dictionary, NullabilityHolder holder, Type type) {
+    // All the fields except dictionary are not nullable unless it is a dummy holder
+    Preconditions.checkNotNull(columnDescriptor, "ColumnDescriptor cannot be null");
+    Preconditions.checkNotNull(vector, "Vector cannot be null");
+    Preconditions.checkNotNull(holder, "NullabilityHolder cannot be null");
+    Preconditions.checkNotNull(type, "IcebergType cannot be null");
     this.columnDescriptor = columnDescriptor;
     this.vector = vector;
     this.isDictionaryEncoded = isDictionaryEncoded;
     this.dictionary = dictionary;
     this.nullabilityHolder = holder;
+    this.icebergType = type;
+  }
+
+  // Only used for returning dummy holder
+  private VectorHolder() {
+    columnDescriptor = null;
+    vector = null;
+    isDictionaryEncoded = false;
+    dictionary = null;
+    nullabilityHolder = null;
+    icebergType = null;
+  }
 
   public ColumnDescriptor descriptor() {
@@ -69,4 +82,26 @@ public Dictionary dictionary() {
     return dictionary;
   }
 
   public NullabilityHolder nullabilityHolder() {
     return nullabilityHolder;
   }
+
+  public Type icebergType() {
+    return icebergType;
+  }
+
+  public int numValues() {
+    return vector.getValueCount();
+  }
+
+  public static VectorHolder dummyHolder(int numRows) {
+    return new VectorHolder() {
+      @Override
+      public int numValues() {
+        return numRows;
+      }
+    };
+  }
+
+  public boolean isDummy() {
+    return vector == null;
+  }
+
 }
diff --git a/arrow/src/main/java/org/apache/iceberg/arrow/vectorized/VectorizedArrowReader.java b/arrow/src/main/java/org/apache/iceberg/arrow/vectorized/VectorizedArrowReader.java
index cbe3eacc9139..dbde001b9764 100644
--- a/arrow/src/main/java/org/apache/iceberg/arrow/vectorized/VectorizedArrowReader.java
+++ b/arrow/src/main/java/org/apache/iceberg/arrow/vectorized/VectorizedArrowReader.java
@@ -48,9 +48,8 @@ import
org.apache.parquet.schema.PrimitiveType; /** - * {@link VectorizedReader VectorReader(s)} that read in a batch of values into Arrow vectors. - * It also takes care of allocating the right kind of Arrow vectors depending on the corresponding - * Iceberg/Parquet data types. + * {@link VectorizedReader VectorReader(s)} that read in a batch of values into Arrow vectors. It also takes care of + * allocating the right kind of Arrow vectors depending on the corresponding Iceberg/Parquet data types. */ public class VectorizedArrowReader implements VectorizedReader { public static final int DEFAULT_BATCH_SIZE = 5000; @@ -58,33 +57,30 @@ public class VectorizedArrowReader implements VectorizedReader { private static final int AVERAGE_VARIABLE_WIDTH_RECORD_SIZE = 10; private final ColumnDescriptor columnDescriptor; - private final int batchSize; private final VectorizedColumnIterator vectorizedColumnIterator; private final Types.NestedField icebergField; private final BufferAllocator rootAlloc; + + private int batchSize; private FieldVector vec; private Integer typeWidth; private ReadType readType; - private boolean reuseContainers = true; private NullabilityHolder nullabilityHolder; // In cases when Parquet employs fall back to plain encoding, we eagerly decode the dictionary encoded pages // before storing the values in the Arrow vector. This means even if the dictionary is present, data // present in the vector may not necessarily be dictionary encoded. private Dictionary dictionary; - private boolean allPagesDictEncoded; public VectorizedArrowReader( ColumnDescriptor desc, Types.NestedField icebergField, BufferAllocator ra, - int batchSize, boolean setArrowValidityVector) { this.icebergField = icebergField; - this.batchSize = (batchSize == 0) ? DEFAULT_BATCH_SIZE : batchSize; this.columnDescriptor = desc; this.rootAlloc = ra; - this.vectorizedColumnIterator = new VectorizedColumnIterator(desc, "", batchSize, setArrowValidityVector); + this.vectorizedColumnIterator = new VectorizedColumnIterator(desc, "", setArrowValidityVector); } private VectorizedArrowReader() { @@ -96,21 +92,37 @@ private VectorizedArrowReader() { } private enum ReadType { - FIXED_LENGTH_DECIMAL, INT_LONG_BACKED_DECIMAL, VARCHAR, VARBINARY, FIXED_WIDTH_BINARY, - BOOLEAN, INT, LONG, FLOAT, DOUBLE, TIMESTAMP_MILLIS + FIXED_LENGTH_DECIMAL, + INT_LONG_BACKED_DECIMAL, + VARCHAR, + VARBINARY, + FIXED_WIDTH_BINARY, + BOOLEAN, + INT, + LONG, + FLOAT, + DOUBLE, + TIMESTAMP_MILLIS + } + + @Override + public void setBatchSize(int batchSize) { + this.batchSize = (batchSize == 0) ? 
DEFAULT_BATCH_SIZE : batchSize; + this.vectorizedColumnIterator.setBatchSize(batchSize); } @Override - public VectorHolder read(int numValsToRead) { - if (vec == null || !reuseContainers) { - allocateFieldVector(); + public VectorHolder read(VectorHolder reuse, int numValsToRead) { + if (reuse == null) { + allocateFieldVector(this.vectorizedColumnIterator.producesDictionaryEncodedVector()); nullabilityHolder = new NullabilityHolder(batchSize); } else { vec.setValueCount(0); nullabilityHolder.reset(); } + boolean dictEncoded = vectorizedColumnIterator.producesDictionaryEncodedVector(); if (vectorizedColumnIterator.hasNext()) { - if (allPagesDictEncoded) { + if (dictEncoded) { vectorizedColumnIterator.nextBatchDictionaryIds((IntVector) vec, nullabilityHolder); } else { switch (readType) { @@ -123,7 +135,6 @@ public VectorHolder read(int numValsToRead) { vectorizedColumnIterator.nextBatchIntLongBackedDecimal(vec, typeWidth, nullabilityHolder); break; case VARBINARY: - ((IcebergArrowVectors.VarBinaryArrowVector) vec).setNullabilityHolder(nullabilityHolder); vectorizedColumnIterator.nextBatchVarWidthType(vec, nullabilityHolder); break; case VARCHAR: @@ -131,7 +142,6 @@ public VectorHolder read(int numValsToRead) { vectorizedColumnIterator.nextBatchVarWidthType(vec, nullabilityHolder); break; case FIXED_WIDTH_BINARY: - ((IcebergArrowVectors.VarBinaryArrowVector) vec).setNullabilityHolder(nullabilityHolder); vectorizedColumnIterator.nextBatchFixedWidthBinary(vec, typeWidth, nullabilityHolder); break; case BOOLEAN: @@ -157,11 +167,12 @@ public VectorHolder read(int numValsToRead) { } Preconditions.checkState(vec.getValueCount() == numValsToRead, "Number of values read, %s, does not equal expected, %s", vec.getValueCount(), numValsToRead); - return new VectorHolder(columnDescriptor, vec, allPagesDictEncoded, dictionary, nullabilityHolder); + return new VectorHolder(columnDescriptor, vec, dictEncoded, dictionary, + nullabilityHolder, icebergField.type()); } - private void allocateFieldVector() { - if (allPagesDictEncoded) { + private void allocateFieldVector(boolean dictionaryEncodedVector) { + if (dictionaryEncodedVector) { Field field = new Field( icebergField.name(), new FieldType(icebergField.isOptional(), new ArrowType.Int(Integer.SIZE, true), null, null), @@ -182,7 +193,7 @@ private void allocateFieldVector() { //TODO: Possibly use the uncompressed page size info to set the initial capacity vec.setInitialCapacity(batchSize * AVERAGE_VARIABLE_WIDTH_RECORD_SIZE); vec.allocateNewSafe(); - this.readType = ReadType.VARCHAR; + this.readType = ReadType.VARCHAR; this.typeWidth = UNKNOWN_WIDTH; break; case INT_8: @@ -190,31 +201,31 @@ private void allocateFieldVector() { case INT_32: this.vec = arrowField.createVector(rootAlloc); ((IntVector) vec).allocateNew(batchSize); - this.readType = ReadType.INT; + this.readType = ReadType.INT; this.typeWidth = (int) IntVector.TYPE_WIDTH; break; case DATE: this.vec = arrowField.createVector(rootAlloc); ((DateDayVector) vec).allocateNew(batchSize); - this.readType = ReadType.INT; + this.readType = ReadType.INT; this.typeWidth = (int) IntVector.TYPE_WIDTH; break; case INT_64: this.vec = arrowField.createVector(rootAlloc); ((BigIntVector) vec).allocateNew(batchSize); - this.readType = ReadType.LONG; + this.readType = ReadType.LONG; this.typeWidth = (int) BigIntVector.TYPE_WIDTH; break; case TIMESTAMP_MILLIS: this.vec = arrowField.createVector(rootAlloc); ((BigIntVector) vec).allocateNew(batchSize); - this.readType = ReadType.TIMESTAMP_MILLIS; + this.readType = 
ReadType.TIMESTAMP_MILLIS; this.typeWidth = (int) BigIntVector.TYPE_WIDTH; break; case TIMESTAMP_MICROS: this.vec = arrowField.createVector(rootAlloc); ((TimeStampMicroTZVector) vec).allocateNew(batchSize); - this.readType = ReadType.LONG; + this.readType = ReadType.LONG; this.typeWidth = (int) BigIntVector.TYPE_WIDTH; break; case DECIMAL: @@ -225,15 +236,15 @@ private void allocateFieldVector() { switch (primitive.getPrimitiveTypeName()) { case BINARY: case FIXED_LEN_BYTE_ARRAY: - this.readType = ReadType.FIXED_LENGTH_DECIMAL; + this.readType = ReadType.FIXED_LENGTH_DECIMAL; this.typeWidth = primitive.getTypeLength(); break; case INT64: - this.readType = ReadType.INT_LONG_BACKED_DECIMAL; + this.readType = ReadType.INT_LONG_BACKED_DECIMAL; this.typeWidth = (int) BigIntVector.TYPE_WIDTH; break; case INT32: - this.readType = ReadType.INT_LONG_BACKED_DECIMAL; + this.readType = ReadType.INT_LONG_BACKED_DECIMAL; this.typeWidth = (int) IntVector.TYPE_WIDTH; break; default: @@ -249,48 +260,48 @@ private void allocateFieldVector() { switch (primitive.getPrimitiveTypeName()) { case FIXED_LEN_BYTE_ARRAY: int len = ((Types.FixedType) icebergField.type()).length(); - this.vec = new IcebergArrowVectors.VarBinaryArrowVector(icebergField.name(), rootAlloc); + this.vec = arrowField.createVector(rootAlloc); vec.setInitialCapacity(batchSize * len); vec.allocateNew(); - this.readType = ReadType.FIXED_WIDTH_BINARY; + this.readType = ReadType.FIXED_WIDTH_BINARY; this.typeWidth = len; break; case BINARY: - this.vec = new IcebergArrowVectors.VarBinaryArrowVector(icebergField.name(), rootAlloc); + this.vec = arrowField.createVector(rootAlloc); //TODO: Possibly use the uncompressed page size info to set the initial capacity vec.setInitialCapacity(batchSize * AVERAGE_VARIABLE_WIDTH_RECORD_SIZE); vec.allocateNewSafe(); - this.readType = ReadType.VARBINARY; + this.readType = ReadType.VARBINARY; this.typeWidth = UNKNOWN_WIDTH; break; case INT32: this.vec = arrowField.createVector(rootAlloc); ((IntVector) vec).allocateNew(batchSize); - this.readType = ReadType.INT; + this.readType = ReadType.INT; this.typeWidth = (int) IntVector.TYPE_WIDTH; break; case FLOAT: this.vec = arrowField.createVector(rootAlloc); ((Float4Vector) vec).allocateNew(batchSize); - this.readType = ReadType.FLOAT; + this.readType = ReadType.FLOAT; this.typeWidth = (int) Float4Vector.TYPE_WIDTH; break; case BOOLEAN: this.vec = arrowField.createVector(rootAlloc); ((BitVector) vec).allocateNew(batchSize); - this.readType = ReadType.BOOLEAN; + this.readType = ReadType.BOOLEAN; this.typeWidth = UNKNOWN_WIDTH; break; case INT64: this.vec = arrowField.createVector(rootAlloc); ((BigIntVector) vec).allocateNew(batchSize); - this.readType = ReadType.LONG; + this.readType = ReadType.LONG; this.typeWidth = (int) BigIntVector.TYPE_WIDTH; break; case DOUBLE: this.vec = arrowField.createVector(rootAlloc); ((Float8Vector) vec).allocateNew(batchSize); - this.readType = ReadType.DOUBLE; + this.readType = ReadType.DOUBLE; this.typeWidth = (int) Float8Vector.TYPE_WIDTH; break; default: @@ -303,13 +314,9 @@ private void allocateFieldVector() { @Override public void setRowGroupInfo(PageReadStore source, Map metadata) { ColumnChunkMetaData chunkMetaData = metadata.get(ColumnPath.get(columnDescriptor.getPath())); - allPagesDictEncoded = !ParquetUtil.hasNonDictionaryPages(chunkMetaData); - dictionary = vectorizedColumnIterator.setRowGroupInfo(source.getPageReader(columnDescriptor), allPagesDictEncoded); - } - - @Override - public void reuseContainers(boolean reuse) { - 
this.reuseContainers = reuse; + this.dictionary = vectorizedColumnIterator.setRowGroupInfo( + source.getPageReader(columnDescriptor), + !ParquetUtil.hasNonDictionaryPages(chunkMetaData)); } @Override @@ -324,16 +331,30 @@ public String toString() { return columnDescriptor.toString(); } - public static final VectorizedArrowReader NULL_VALUES_READER = - new VectorizedArrowReader() { - @Override - public VectorHolder read(int numValsToRead) { - return VectorHolder.NULL_VECTOR_HOLDER; - } + public static VectorizedArrowReader nulls() { + return NullVectorReader.INSTANCE; + } + + private static final class NullVectorReader extends VectorizedArrowReader { + private static final NullVectorReader INSTANCE = new NullVectorReader(); + + @Override + public VectorHolder read(VectorHolder reuse, int numValsToRead) { + return VectorHolder.dummyHolder(numValsToRead); + } + + @Override + public void setRowGroupInfo(PageReadStore source, Map metadata) { + } + + @Override + public String toString() { + return "NullReader"; + } + + @Override + public void setBatchSize(int batchSize) {} + } - @Override - public void setRowGroupInfo(PageReadStore source, Map metadata) { - } - }; } diff --git a/arrow/src/main/java/org/apache/iceberg/arrow/vectorized/parquet/VectorizedColumnIterator.java b/arrow/src/main/java/org/apache/iceberg/arrow/vectorized/parquet/VectorizedColumnIterator.java index 2692cfc59747..cb9d27890a4a 100644 --- a/arrow/src/main/java/org/apache/iceberg/arrow/vectorized/parquet/VectorizedColumnIterator.java +++ b/arrow/src/main/java/org/apache/iceberg/arrow/vectorized/parquet/VectorizedColumnIterator.java @@ -36,20 +36,24 @@ public class VectorizedColumnIterator extends BaseColumnIterator { private final VectorizedPageIterator vectorizedPageIterator; - private final int batchSize; + private int batchSize; - public VectorizedColumnIterator(ColumnDescriptor desc, String writerVersion, int batchSize, - boolean setArrowValidityVector) { + public VectorizedColumnIterator(ColumnDescriptor desc, String writerVersion, boolean setArrowValidityVector) { super(desc); Preconditions.checkArgument(desc.getMaxRepetitionLevel() == 0, "Only non-nested columns are supported for vectorized reads"); - this.batchSize = batchSize; this.vectorizedPageIterator = new VectorizedPageIterator(desc, writerVersion, setArrowValidityVector); } + public void setBatchSize(int batchSize) { + this.batchSize = batchSize; + } + public Dictionary setRowGroupInfo(PageReader store, boolean allPagesDictEncoded) { - super.setPageSource(store); + // setPageSource can result in a data page read. 
If that happens, we need + // to know in advance whether all the pages in the row group are dictionary encoded or not this.vectorizedPageIterator.setAllPagesDictEncoded(allPagesDictEncoded); + super.setPageSource(store); return dictionary; } @@ -199,4 +203,8 @@ protected BasePageIterator pageIterator() { return vectorizedPageIterator; } + public boolean producesDictionaryEncodedVector() { + return vectorizedPageIterator.producesDictionaryEncodedVector(); + } + } diff --git a/arrow/src/main/java/org/apache/iceberg/arrow/vectorized/parquet/VectorizedDictionaryEncodedParquetValuesReader.java b/arrow/src/main/java/org/apache/iceberg/arrow/vectorized/parquet/VectorizedDictionaryEncodedParquetValuesReader.java index e71d61aa6f71..52e389ece40b 100644 --- a/arrow/src/main/java/org/apache/iceberg/arrow/vectorized/parquet/VectorizedDictionaryEncodedParquetValuesReader.java +++ b/arrow/src/main/java/org/apache/iceberg/arrow/vectorized/parquet/VectorizedDictionaryEncodedParquetValuesReader.java @@ -19,7 +19,6 @@ package org.apache.iceberg.arrow.vectorized.parquet; -import io.netty.buffer.ArrowBuf; import java.nio.ByteBuffer; import org.apache.arrow.vector.BaseVariableWidthVector; import org.apache.arrow.vector.BitVectorHelper; @@ -53,15 +52,14 @@ void readBatchOfDictionaryIds(IntVector intVector, int startOffset, int numValue case RLE: for (int i = 0; i < numValues; i++) { intVector.set(idx, currentValue); - nullabilityHolder.setNotNull(idx); + setNotNull(intVector, nullabilityHolder, idx); idx++; } break; case PACKED: for (int i = 0; i < numValues; i++) { - intVector.set(idx, packedValuesBuffer[packedValuesBufferIdx]); - nullabilityHolder.setNotNull(idx); - packedValuesBufferIdx++; + intVector.set(idx, packedValuesBuffer[packedValuesBufferIdx++]); + setNotNull(intVector, nullabilityHolder, idx); idx++; } break; @@ -72,7 +70,7 @@ void readBatchOfDictionaryIds(IntVector intVector, int startOffset, int numValue } void readBatchOfDictionaryEncodedLongs(FieldVector vector, int startOffset, int numValuesToRead, Dictionary dict, - NullabilityHolder nullabilityHolder) { + NullabilityHolder nullabilityHolder, int typeWidth) { int left = numValuesToRead; int idx = startOffset; while (left > 0) { @@ -83,24 +81,16 @@ void readBatchOfDictionaryEncodedLongs(FieldVector vector, int startOffset, int switch (mode) { case RLE: for (int i = 0; i < numValues; i++) { - vector.getDataBuffer().setLong(idx, dict.decodeToLong(currentValue)); - if (setArrowValidityVector) { - BitVectorHelper.setValidityBitToOne(vector.getValidityBuffer(), idx); - } else { - nullabilityHolder.setNotNull(idx); - } + vector.getDataBuffer().setLong(idx * typeWidth, dict.decodeToLong(currentValue)); + setNotNull(vector, nullabilityHolder, idx); idx++; } break; case PACKED: for (int i = 0; i < numValues; i++) { vector.getDataBuffer() - .setLong(idx, dict.decodeToLong(packedValuesBuffer[packedValuesBufferIdx++])); - if (setArrowValidityVector) { - BitVectorHelper.setValidityBitToOne(vector.getValidityBuffer(), idx); - } else { - nullabilityHolder.setNotNull(idx); - } + .setLong(idx * typeWidth, dict.decodeToLong(packedValuesBuffer[packedValuesBufferIdx++])); + setNotNull(vector, nullabilityHolder, idx); idx++; } break; @@ -110,8 +100,9 @@ void readBatchOfDictionaryEncodedLongs(FieldVector vector, int startOffset, int } } - void readBatchOfDictionaryEncodedTimestampMillis(FieldVector vector, int startOffset, int numValuesToRead, - Dictionary dict, NullabilityHolder nullabilityHolder) { + void readBatchOfDictionaryEncodedTimestampMillis( + FieldVector 
vector, int startOffset, int numValuesToRead, + Dictionary dict, NullabilityHolder nullabilityHolder, int typeWidth) { int left = numValuesToRead; int idx = startOffset; while (left > 0) { @@ -122,24 +113,16 @@ void readBatchOfDictionaryEncodedTimestampMillis(FieldVector vector, int startOf switch (mode) { case RLE: for (int i = 0; i < numValues; i++) { - vector.getDataBuffer().setLong(idx, dict.decodeToLong(currentValue) * 1000); - if (setArrowValidityVector) { - BitVectorHelper.setValidityBitToOne(vector.getValidityBuffer(), idx); - } else { - nullabilityHolder.setNotNull(idx); - } + vector.getDataBuffer().setLong(idx * typeWidth, dict.decodeToLong(currentValue) * 1000); + setNotNull(vector, nullabilityHolder, idx); idx++; } break; case PACKED: for (int i = 0; i < numValues; i++) { vector.getDataBuffer() - .setLong(idx, dict.decodeToLong(packedValuesBuffer[packedValuesBufferIdx++]) * 1000); - if (setArrowValidityVector) { - BitVectorHelper.setValidityBitToOne(vector.getValidityBuffer(), idx); - } else { - nullabilityHolder.setNotNull(idx); - } + .setLong(idx * typeWidth, dict.decodeToLong(packedValuesBuffer[packedValuesBufferIdx++]) * 1000); + setNotNull(vector, nullabilityHolder, idx); idx++; } break; @@ -150,7 +133,7 @@ void readBatchOfDictionaryEncodedTimestampMillis(FieldVector vector, int startOf } void readBatchOfDictionaryEncodedIntegers(FieldVector vector, int startOffset, int numValuesToRead, Dictionary dict, - NullabilityHolder nullabilityHolder) { + NullabilityHolder nullabilityHolder, int typeWidth) { int left = numValuesToRead; int idx = startOffset; while (left > 0) { @@ -158,27 +141,19 @@ void readBatchOfDictionaryEncodedIntegers(FieldVector vector, int startOffset, i this.readNextGroup(); } int num = Math.min(left, this.currentCount); - ArrowBuf dataBuffer = vector.getDataBuffer(); switch (mode) { case RLE: for (int i = 0; i < num; i++) { - dataBuffer.setInt(idx, dict.decodeToInt(currentValue)); - if (setArrowValidityVector) { - BitVectorHelper.setValidityBitToOne(vector.getValidityBuffer(), idx); - } else { - nullabilityHolder.setNotNull(idx); - } + vector.getDataBuffer().setInt(idx * typeWidth, dict.decodeToInt(currentValue)); + setNotNull(vector, nullabilityHolder, idx); idx++; } break; case PACKED: for (int i = 0; i < num; i++) { - dataBuffer.setInt(idx, dict.decodeToInt(packedValuesBuffer[packedValuesBufferIdx++])); - if (setArrowValidityVector) { - BitVectorHelper.setValidityBitToOne(vector.getValidityBuffer(), idx); - } else { - nullabilityHolder.setNotNull(idx); - } + vector.getDataBuffer() + .setInt(idx * typeWidth, dict.decodeToInt(packedValuesBuffer[packedValuesBufferIdx++])); + setNotNull(vector, nullabilityHolder, idx); idx++; } break; @@ -189,7 +164,7 @@ void readBatchOfDictionaryEncodedIntegers(FieldVector vector, int startOffset, i } void readBatchOfDictionaryEncodedFloats(FieldVector vector, int startOffset, int numValuesToRead, Dictionary dict, - NullabilityHolder nullabilityHolder) { + NullabilityHolder nullabilityHolder, int typeWidth) { int left = numValuesToRead; int idx = startOffset; while (left > 0) { @@ -200,23 +175,16 @@ void readBatchOfDictionaryEncodedFloats(FieldVector vector, int startOffset, int switch (mode) { case RLE: for (int i = 0; i < num; i++) { - vector.getDataBuffer().setFloat(idx, dict.decodeToFloat(currentValue)); - if (setArrowValidityVector) { - BitVectorHelper.setValidityBitToOne(vector.getValidityBuffer(), idx); - } else { - nullabilityHolder.setNotNull(idx); - } + vector.getDataBuffer().setFloat(idx * typeWidth, 
dict.decodeToFloat(currentValue)); + setNotNull(vector, nullabilityHolder, idx); idx++; } break; case PACKED: for (int i = 0; i < num; i++) { - vector.getDataBuffer().setFloat(idx, dict.decodeToFloat(packedValuesBuffer[packedValuesBufferIdx++])); - if (setArrowValidityVector) { - BitVectorHelper.setValidityBitToOne(vector.getValidityBuffer(), idx); - } else { - nullabilityHolder.setNotNull(idx); - } + vector.getDataBuffer() + .setFloat(idx * typeWidth, dict.decodeToFloat(packedValuesBuffer[packedValuesBufferIdx++])); + setNotNull(vector, nullabilityHolder, idx); idx++; } break; @@ -227,7 +195,7 @@ void readBatchOfDictionaryEncodedFloats(FieldVector vector, int startOffset, int } void readBatchOfDictionaryEncodedDoubles(FieldVector vector, int startOffset, int numValuesToRead, Dictionary dict, - NullabilityHolder nullabilityHolder) { + NullabilityHolder nullabilityHolder, int typeWidth) { int left = numValuesToRead; int idx = startOffset; while (left > 0) { @@ -238,24 +206,16 @@ void readBatchOfDictionaryEncodedDoubles(FieldVector vector, int startOffset, in switch (mode) { case RLE: for (int i = 0; i < num; i++) { - vector.getDataBuffer().setDouble(idx, dict.decodeToDouble(currentValue)); - nullabilityHolder.setNotNull(idx); - if (setArrowValidityVector) { - BitVectorHelper.setValidityBitToOne(vector.getValidityBuffer(), idx); - } else { - nullabilityHolder.setNotNull(idx); - } + vector.getDataBuffer().setDouble(idx * typeWidth, dict.decodeToDouble(currentValue)); + setNotNull(vector, nullabilityHolder, idx); idx++; } break; case PACKED: for (int i = 0; i < num; i++) { - vector.getDataBuffer().setDouble(idx, dict.decodeToDouble(packedValuesBuffer[packedValuesBufferIdx++])); - if (setArrowValidityVector) { - BitVectorHelper.setValidityBitToOne(vector.getValidityBuffer(), idx); - } else { - nullabilityHolder.setNotNull(idx); - } + vector.getDataBuffer() + .setDouble(idx * typeWidth, dict.decodeToDouble(packedValuesBuffer[packedValuesBufferIdx++])); + setNotNull(vector, nullabilityHolder, idx); idx++; } break; @@ -279,27 +239,14 @@ void readBatchOfDictionaryEncodedFixedWidthBinary(FieldVector vector, int typeWi case RLE: for (int i = 0; i < num; i++) { ByteBuffer buffer = dict.decodeToBinary(currentValue).toByteBuffer(); - vector.getDataBuffer().setBytes(idx * typeWidth, buffer.array(), - buffer.position() + buffer.arrayOffset(), buffer.limit() - buffer.position()); - if (setArrowValidityVector) { - BitVectorHelper.setValidityBitToOne(vector.getValidityBuffer(), idx); - } else { - nullabilityHolder.setNotNull(idx); - } + setFixedWidthBinary(vector, typeWidth, nullabilityHolder, idx, buffer); idx++; } break; case PACKED: for (int i = 0; i < num; i++) { ByteBuffer buffer = dict.decodeToBinary(packedValuesBuffer[packedValuesBufferIdx++]).toByteBuffer(); - vector.getDataBuffer() - .setBytes(idx * typeWidth, buffer.array(), - buffer.position() + buffer.arrayOffset(), buffer.limit() - buffer.position()); - if (setArrowValidityVector) { - BitVectorHelper.setValidityBitToOne(vector.getValidityBuffer(), idx); - } else { - nullabilityHolder.setNotNull(idx); - } + setFixedWidthBinary(vector, typeWidth, nullabilityHolder, idx, buffer); idx++; } break; @@ -309,6 +256,22 @@ void readBatchOfDictionaryEncodedFixedWidthBinary(FieldVector vector, int typeWi } } + private void setFixedWidthBinary( + FieldVector vector, int typeWidth, NullabilityHolder nullabilityHolder, + int idx, ByteBuffer buffer) { + vector.getDataBuffer() + .setBytes(idx * typeWidth, buffer.array(), + buffer.position() + 
buffer.arrayOffset(), buffer.limit() - buffer.position()); + setNotNull(vector, nullabilityHolder, idx); + } + + private void setNotNull(FieldVector vector, NullabilityHolder nullabilityHolder, int idx) { + nullabilityHolder.setNotNull(idx); + if (setArrowValidityVector) { + BitVectorHelper.setValidityBitToOne(vector.getValidityBuffer(), idx); + } + } + void readBatchOfDictionaryEncodedFixedLengthDecimals(FieldVector vector, int typeWidth, int startOffset, int numValuesToRead, Dictionary dict, NullabilityHolder nullabilityHolder) { @@ -405,7 +368,7 @@ void readBatchOfDictionaryEncodedIntLongBackedDecimals(FieldVector vector, int t ((DecimalVector) vector).set( idx, typeWidth == Integer.BYTES ? - dict.decodeToInt(currentValue) + dict.decodeToInt(packedValuesBuffer[packedValuesBufferIdx++]) : dict.decodeToLong(packedValuesBuffer[packedValuesBufferIdx++])); nullabilityHolder.setNotNull(idx); idx++; diff --git a/arrow/src/main/java/org/apache/iceberg/arrow/vectorized/parquet/VectorizedPageIterator.java b/arrow/src/main/java/org/apache/iceberg/arrow/vectorized/parquet/VectorizedPageIterator.java index 7cc32e06aecf..2aa6f2c07324 100644 --- a/arrow/src/main/java/org/apache/iceberg/arrow/vectorized/parquet/VectorizedPageIterator.java +++ b/arrow/src/main/java/org/apache/iceberg/arrow/vectorized/parquet/VectorizedPageIterator.java @@ -47,12 +47,19 @@ public VectorizedPageIterator(ColumnDescriptor desc, String writerVersion, boole this.setArrowValidityVector = setValidityVector; } - private boolean eagerDecodeDictionary; private ValuesAsBytesReader plainValuesReader = null; private VectorizedDictionaryEncodedParquetValuesReader dictionaryEncodedValuesReader = null; private boolean allPagesDictEncoded; private VectorizedParquetDefinitionLevelReader vectorizedDefinitionLevelReader; + private enum DictionaryDecodeMode { + NONE, // plain encoding + LAZY, + EAGER + } + + private DictionaryDecodeMode dictionaryDecodeMode; + public void setAllPagesDictEncoded(boolean allDictEncoded) { this.allPagesDictEncoded = allDictEncoded; } @@ -98,7 +105,7 @@ public int nextBatchIntegers( if (actualBatchSize <= 0) { return 0; } - if (eagerDecodeDictionary) { + if (dictionaryDecodeMode == DictionaryDecodeMode.EAGER) { vectorizedDefinitionLevelReader.readBatchOfDictionaryEncodedIntegers( vector, numValsInVector, @@ -132,7 +139,7 @@ public int nextBatchLongs( if (actualBatchSize <= 0) { return 0; } - if (eagerDecodeDictionary) { + if (dictionaryDecodeMode == DictionaryDecodeMode.EAGER) { vectorizedDefinitionLevelReader.readBatchOfDictionaryEncodedLongs( vector, numValsInVector, @@ -168,7 +175,7 @@ public int nextBatchTimestampMillis( if (actualBatchSize <= 0) { return 0; } - if (eagerDecodeDictionary) { + if (dictionaryDecodeMode == DictionaryDecodeMode.EAGER) { vectorizedDefinitionLevelReader.readBatchOfDictionaryEncodedTimestampMillis( vector, numValsInVector, @@ -202,7 +209,7 @@ public int nextBatchFloats( if (actualBatchSize <= 0) { return 0; } - if (eagerDecodeDictionary) { + if (dictionaryDecodeMode == DictionaryDecodeMode.EAGER) { vectorizedDefinitionLevelReader.readBatchOfDictionaryEncodedFloats( vector, numValsInVector, @@ -236,7 +243,7 @@ public int nextBatchDoubles( if (actualBatchSize <= 0) { return 0; } - if (eagerDecodeDictionary) { + if (dictionaryDecodeMode == DictionaryDecodeMode.EAGER) { vectorizedDefinitionLevelReader.readBatchOfDictionaryEncodedDoubles( vector, numValsInVector, @@ -274,7 +281,7 @@ public int nextBatchIntLongBackedDecimal( if (actualBatchSize <= 0) { return 0; } - if 
(eagerDecodeDictionary) { + if (dictionaryDecodeMode == DictionaryDecodeMode.EAGER) { vectorizedDefinitionLevelReader .readBatchOfDictionaryEncodedIntLongBackedDecimals( vector, @@ -312,7 +319,7 @@ public int nextBatchFixedLengthDecimal( if (actualBatchSize <= 0) { return 0; } - if (eagerDecodeDictionary) { + if (dictionaryDecodeMode == DictionaryDecodeMode.EAGER) { vectorizedDefinitionLevelReader.readBatchOfDictionaryEncodedFixedLengthDecimals( vector, numValsInVector, @@ -347,7 +354,7 @@ public int nextBatchVarWidthType( if (actualBatchSize <= 0) { return 0; } - if (eagerDecodeDictionary) { + if (dictionaryDecodeMode == DictionaryDecodeMode.EAGER) { vectorizedDefinitionLevelReader.readBatchOfDictionaryEncodedVarWidth( vector, numValsInVector, @@ -380,7 +387,7 @@ public int nextBatchFixedWidthBinary( if (actualBatchSize <= 0) { return 0; } - if (eagerDecodeDictionary) { + if (dictionaryDecodeMode == DictionaryDecodeMode.EAGER) { vectorizedDefinitionLevelReader.readBatchOfDictionaryEncodedFixedWidthBinary( vector, numValsInVector, @@ -403,6 +410,10 @@ public int nextBatchFixedWidthBinary( return actualBatchSize; } + public boolean producesDictionaryEncodedVector() { + return dictionaryDecodeMode == DictionaryDecodeMode.LAZY; + } + /** * Method for reading batches of booleans. */ @@ -426,8 +437,6 @@ public int nextBatchBoolean( @Override protected void initDataReader(Encoding dataEncoding, ByteBufferInputStream in, int valueCount) { ValuesReader previousReader = plainValuesReader; - this.eagerDecodeDictionary = dataEncoding.usesDictionary() && dictionary != null && - (ParquetUtil.isIntType(desc.getPrimitiveType()) || !allPagesDictEncoded); if (dataEncoding.usesDictionary()) { if (dictionary == null) { throw new ParquetDecodingException( @@ -437,12 +446,18 @@ protected void initDataReader(Encoding dataEncoding, ByteBufferInputStream in, i dictionaryEncodedValuesReader = new VectorizedDictionaryEncodedParquetValuesReader(desc.getMaxDefinitionLevel(), setArrowValidityVector); dictionaryEncodedValuesReader.initFromPage(valueCount, in); + if (ParquetUtil.isIntType(desc.getPrimitiveType()) || !allPagesDictEncoded) { + dictionaryDecodeMode = DictionaryDecodeMode.EAGER; + } else { + dictionaryDecodeMode = DictionaryDecodeMode.LAZY; + } } catch (IOException e) { throw new ParquetDecodingException("could not read page in col " + desc, e); } } else { plainValuesReader = new ValuesAsBytesReader(); plainValuesReader.initFromPage(valueCount, in); + dictionaryDecodeMode = DictionaryDecodeMode.NONE; } if (CorruptDeltaByteArrays.requiresSequentialReads(writerVersion, dataEncoding) && previousReader != null && previousReader instanceof RequiresPreviousReader) { diff --git a/arrow/src/main/java/org/apache/iceberg/arrow/vectorized/parquet/VectorizedParquetDefinitionLevelReader.java b/arrow/src/main/java/org/apache/iceberg/arrow/vectorized/parquet/VectorizedParquetDefinitionLevelReader.java index 86918f7de5d2..8a263483f89d 100644 --- a/arrow/src/main/java/org/apache/iceberg/arrow/vectorized/parquet/VectorizedParquetDefinitionLevelReader.java +++ b/arrow/src/main/java/org/apache/iceberg/arrow/vectorized/parquet/VectorizedParquetDefinitionLevelReader.java @@ -63,11 +63,10 @@ public void readBatchOfDictionaryIds( case PACKED: for (int i = 0; i < numValues; i++) { if (packedValuesBuffer[packedValuesBufferIdx++] == maxDefLevel) { - vector.set(idx, dictionaryEncodedValuesReader.readInteger()); + vector.getDataBuffer().setInt(idx * IntVector.TYPE_WIDTH, dictionaryEncodedValuesReader.readInteger()); + 
nullabilityHolder.setNotNull(idx); if (setArrowValidityVector) { BitVectorHelper.setValidityBitToOne(vector.getValidityBuffer(), idx); - } else { - nullabilityHolder.setNotNull(idx); } } else { setNull(nullabilityHolder, idx, vector.getValidityBuffer()); @@ -106,10 +105,9 @@ public void readBatchOfLongs( for (int i = 0; i < numValues; ++i) { if (packedValuesBuffer[packedValuesBufferIdx++] == maxDefLevel) { vector.getDataBuffer().setLong(bufferIdx * typeWidth, valuesReader.readLong()); + nullabilityHolder.setNotNull(bufferIdx); if (setArrowValidityVector) { BitVectorHelper.setValidityBitToOne(vector.getValidityBuffer(), bufferIdx); - } else { - nullabilityHolder.setNotNull(bufferIdx); } } else { setNull(nullabilityHolder, bufferIdx, vector.getValidityBuffer()); @@ -140,12 +138,11 @@ public void readBatchOfTimestampMillis(final FieldVector vector, final int start for (int i = 0; i < numValues; i++) { vector.getDataBuffer().setLong(bufferIdx * typeWidth, valuesReader.readLong() * 1000); } + nullabilityHolder.setNotNulls(bufferIdx, numValues); if (setArrowValidityVector) { for (int i = 0; i < numValues; i++) { BitVectorHelper.setValidityBitToOne(validityBuffer, bufferIdx + i); } - } else { - nullabilityHolder.setNotNulls(bufferIdx, numValues); } } else { setNulls(nullabilityHolder, bufferIdx, numValues, validityBuffer); @@ -156,10 +153,9 @@ public void readBatchOfTimestampMillis(final FieldVector vector, final int start for (int i = 0; i < numValues; i++) { if (packedValuesBuffer[packedValuesBufferIdx++] == maxDefLevel) { vector.getDataBuffer().setLong(bufferIdx * typeWidth, valuesReader.readLong() * 1000); + nullabilityHolder.setNotNull(bufferIdx); if (setArrowValidityVector) { BitVectorHelper.setValidityBitToOne(vector.getValidityBuffer(), bufferIdx); - } else { - nullabilityHolder.setNotNull(bufferIdx); } } else { setNull(nullabilityHolder, bufferIdx, vector.getValidityBuffer()); @@ -193,7 +189,7 @@ public void readBatchOfDictionaryEncodedLongs( case RLE: if (currentValue == maxDefLevel) { dictionaryEncodedValuesReader.readBatchOfDictionaryEncodedLongs(vector, - idx, numValues, dict, nullabilityHolder); + idx, numValues, dict, nullabilityHolder, typeWidth); } else { setNulls(nullabilityHolder, idx, numValues, validityBuffer); } @@ -202,11 +198,11 @@ public void readBatchOfDictionaryEncodedLongs( case PACKED: for (int i = 0; i < numValues; i++) { if (packedValuesBuffer[packedValuesBufferIdx++] == maxDefLevel) { - vector.getDataBuffer().setLong(idx, dict.decodeToLong(dictionaryEncodedValuesReader.readInteger())); + vector.getDataBuffer().setLong(idx * typeWidth, + dict.decodeToLong(dictionaryEncodedValuesReader.readInteger())); + nullabilityHolder.setNotNull(idx); if (setArrowValidityVector) { BitVectorHelper.setValidityBitToOne(vector.getValidityBuffer(), idx); - } else { - nullabilityHolder.setNotNull(idx); } } else { setNull(nullabilityHolder, idx, validityBuffer); @@ -240,7 +236,7 @@ public void readBatchOfDictionaryEncodedTimestampMillis( case RLE: if (currentValue == maxDefLevel) { dictionaryEncodedValuesReader.readBatchOfDictionaryEncodedTimestampMillis(vector, - idx, numValues, dict, nullabilityHolder); + idx, numValues, dict, nullabilityHolder, typeWidth); } else { setNulls(nullabilityHolder, idx, numValues, validityBuffer); } @@ -249,12 +245,11 @@ public void readBatchOfDictionaryEncodedTimestampMillis( case PACKED: for (int i = 0; i < numValues; i++) { if (packedValuesBuffer[packedValuesBufferIdx++] == maxDefLevel) { - vector.getDataBuffer().setLong(idx, + 
vector.getDataBuffer().setLong(idx * typeWidth, dict.decodeToLong(dictionaryEncodedValuesReader.readInteger()) * 1000); + nullabilityHolder.setNotNull(idx); if (setArrowValidityVector) { BitVectorHelper.setValidityBitToOne(vector.getValidityBuffer(), idx); - } else { - nullabilityHolder.setNotNull(idx); } } else { setNull(nullabilityHolder, idx, validityBuffer); @@ -293,10 +288,9 @@ public void readBatchOfIntegers(final FieldVector vector, final int startOffset, for (int i = 0; i < num; ++i) { if (packedValuesBuffer[packedValuesBufferIdx++] == maxDefLevel) { vector.getDataBuffer().setInt(bufferIdx * typeWidth, valuesReader.readInteger()); + nullabilityHolder.setNotNull(bufferIdx); if (setArrowValidityVector) { BitVectorHelper.setValidityBitToOne(vector.getValidityBuffer(), bufferIdx); - } else { - nullabilityHolder.setNotNull(bufferIdx); } } else { setNull(nullabilityHolder, bufferIdx, vector.getValidityBuffer()); @@ -329,7 +323,7 @@ public void readBatchOfDictionaryEncodedIntegers( case RLE: if (currentValue == maxDefLevel) { dictionaryEncodedValuesReader.readBatchOfDictionaryEncodedIntegers(vector, idx, - num, dict, nullabilityHolder); + num, dict, nullabilityHolder, typeWidth); } else { setNulls(nullabilityHolder, idx, num, vector.getValidityBuffer()); } @@ -338,11 +332,11 @@ public void readBatchOfDictionaryEncodedIntegers( case PACKED: for (int i = 0; i < num; i++) { if (packedValuesBuffer[packedValuesBufferIdx++] == maxDefLevel) { - vector.getDataBuffer().setInt(idx, dict.decodeToInt(dictionaryEncodedValuesReader.readInteger())); + vector.getDataBuffer() + .setInt(idx * typeWidth, dict.decodeToInt(dictionaryEncodedValuesReader.readInteger())); + nullabilityHolder.setNotNull(idx); if (setArrowValidityVector) { BitVectorHelper.setValidityBitToOne(vector.getValidityBuffer(), idx); - } else { - nullabilityHolder.setNotNull(idx); } } else { setNull(nullabilityHolder, idx, vector.getValidityBuffer()); @@ -381,10 +375,9 @@ public void readBatchOfFloats(final FieldVector vector, final int startOffset, f for (int i = 0; i < num; ++i) { if (packedValuesBuffer[packedValuesBufferIdx++] == maxDefLevel) { vector.getDataBuffer().setFloat(bufferIdx * typeWidth, valuesReader.readFloat()); + nullabilityHolder.setNotNull(bufferIdx); if (setArrowValidityVector) { BitVectorHelper.setValidityBitToOne(vector.getValidityBuffer(), bufferIdx); - } else { - nullabilityHolder.setNotNull(bufferIdx); } } else { setNull(nullabilityHolder, bufferIdx, vector.getValidityBuffer()); @@ -418,7 +411,7 @@ public void readBatchOfDictionaryEncodedFloats( case RLE: if (currentValue == maxDefLevel) { dictionaryEncodedValuesReader.readBatchOfDictionaryEncodedFloats(vector, idx, - num, dict, nullabilityHolder); + num, dict, nullabilityHolder, typeWidth); } else { setNulls(nullabilityHolder, idx, num, validityBuffer); } @@ -427,11 +420,11 @@ public void readBatchOfDictionaryEncodedFloats( case PACKED: for (int i = 0; i < num; i++) { if (packedValuesBuffer[packedValuesBufferIdx++] == maxDefLevel) { - vector.getDataBuffer().setFloat(idx, dict.decodeToFloat(dictionaryEncodedValuesReader.readInteger())); + vector.getDataBuffer() + .setFloat(idx * typeWidth, dict.decodeToFloat(dictionaryEncodedValuesReader.readInteger())); + nullabilityHolder.setNotNull(idx); if (setArrowValidityVector) { BitVectorHelper.setValidityBitToOne(vector.getValidityBuffer(), idx); - } else { - nullabilityHolder.setNotNull(idx); } } else { setNull(nullabilityHolder, idx, validityBuffer); @@ -471,10 +464,9 @@ public void readBatchOfDoubles( for (int i = 0; i 
< num; ++i) { if (packedValuesBuffer[packedValuesBufferIdx++] == maxDefLevel) { vector.getDataBuffer().setDouble(bufferIdx * typeWidth, valuesReader.readDouble()); + nullabilityHolder.setNotNull(bufferIdx); if (setArrowValidityVector) { BitVectorHelper.setValidityBitToOne(vector.getValidityBuffer(), bufferIdx); - } else { - nullabilityHolder.setNotNull(bufferIdx); } } else { setNull(nullabilityHolder, bufferIdx, vector.getValidityBuffer()); @@ -507,7 +499,7 @@ public void readBatchOfDictionaryEncodedDoubles( case RLE: if (currentValue == maxDefLevel) { dictionaryEncodedValuesReader.readBatchOfDictionaryEncodedDoubles(vector, idx, - num, dict, nullabilityHolder); + num, dict, nullabilityHolder, typeWidth); } else { setNulls(nullabilityHolder, idx, num, vector.getValidityBuffer()); } @@ -516,11 +508,11 @@ public void readBatchOfDictionaryEncodedDoubles( case PACKED: for (int i = 0; i < num; i++) { if (packedValuesBuffer[packedValuesBufferIdx++] == maxDefLevel) { - vector.getDataBuffer().setDouble(idx, dict.decodeToDouble(dictionaryEncodedValuesReader.readInteger())); + vector.getDataBuffer() + .setDouble(idx * typeWidth, dict.decodeToDouble(dictionaryEncodedValuesReader.readInteger())); + nullabilityHolder.setNotNull(idx); if (setArrowValidityVector) { BitVectorHelper.setValidityBitToOne(vector.getValidityBuffer(), idx); - } else { - nullabilityHolder.setNotNull(idx); } } else { setNull(nullabilityHolder, idx, vector.getValidityBuffer()); @@ -604,10 +596,9 @@ public void readBatchOfDictionaryEncodedFixedWidthBinary( ByteBuffer buffer = dict.decodeToBinary(dictionaryEncodedValuesReader.readInteger()).toByteBuffer(); vector.getDataBuffer().setBytes(idx * typeWidth, buffer.array(), buffer.position() + buffer.arrayOffset(), buffer.limit() - buffer.position()); + nullabilityHolder.setNotNull(idx); if (setArrowValidityVector) { BitVectorHelper.setValidityBitToOne(vector.getValidityBuffer(), idx); - } else { - nullabilityHolder.setNotNull(idx); } } else { setNull(nullabilityHolder, idx, vector.getValidityBuffer()); @@ -764,10 +755,9 @@ private void setVarWidthBinaryValue(FieldVector vector, ValuesAsBytesReader valu buffer.limit() - buffer.position()); // Similarly, we need to get the latest reference to the validity buffer as well // since reallocation changes reference of the validity buffers as well. 
+ nullabilityHolder.setNotNull(bufferIdx); if (setArrowValidityVector) { BitVectorHelper.setValidityBitToOne(vector.getValidityBuffer(), bufferIdx); - } else { - nullabilityHolder.setNotNull(bufferIdx); } } @@ -858,10 +848,9 @@ private void setIntLongBackedDecimal(FieldVector vector, int typeWidth, Nullabil ValuesAsBytesReader valuesReader, int bufferIdx, byte[] byteArray) { valuesReader.getBuffer(typeWidth).get(byteArray, 0, typeWidth); vector.getDataBuffer().setBytes(bufferIdx * DecimalVector.TYPE_WIDTH, byteArray); + nullabilityHolder.setNotNull(bufferIdx); if (setArrowValidityVector) { BitVectorHelper.setValidityBitToOne(vector.getValidityBuffer(), bufferIdx); - } else { - nullabilityHolder.setNotNull(bufferIdx); } } @@ -972,12 +961,11 @@ private void setNextNValuesInVector( if (currentValue == maxDefLevel) { ByteBuffer buffer = valuesReader.getBuffer(numValues * typeWidth); vector.getDataBuffer().setBytes(bufferIdx * typeWidth, buffer); + nullabilityHolder.setNotNulls(bufferIdx, numValues); if (setArrowValidityVector) { for (int i = 0; i < numValues; i++) { BitVectorHelper.setValidityBitToOne(validityBuffer, bufferIdx + i); } - } else { - nullabilityHolder.setNotNulls(bufferIdx, numValues); } } else { setNulls(nullabilityHolder, bufferIdx, numValues, validityBuffer); @@ -985,20 +973,18 @@ private void setNextNValuesInVector( } private void setNull(NullabilityHolder nullabilityHolder, int bufferIdx, ArrowBuf validityBuffer) { + nullabilityHolder.setNull(bufferIdx); if (setArrowValidityVector) { BitVectorHelper.setValidityBit(validityBuffer, bufferIdx, 0); - } else { - nullabilityHolder.setNull(bufferIdx); } } private void setNulls(NullabilityHolder nullabilityHolder, int idx, int numValues, ArrowBuf validityBuffer) { + nullabilityHolder.setNulls(idx, numValues); if (setArrowValidityVector) { for (int i = 0; i < numValues; i++) { BitVectorHelper.setValidityBit(validityBuffer, idx + i, 0); } - } else { - nullabilityHolder.setNulls(idx, numValues); } } diff --git a/build.gradle b/build.gradle index c5bed3291b45..f76c719fced3 100644 --- a/build.gradle +++ b/build.gradle @@ -410,6 +410,7 @@ project(':iceberg-spark') { compile project(':iceberg-parquet') compile project(':iceberg-arrow') compile project(':iceberg-hive') + compile project(':iceberg-arrow') compileOnly "org.apache.avro:avro" compileOnly("org.apache.spark:spark-hive_2.11") { @@ -428,6 +429,18 @@ project(':iceberg-spark') { exclude group: 'org.apache.avro', module: 'avro' } } + + test { + // For vectorized reads + // Allow unsafe memory access to avoid the costly check arrow does to check if index is within bounds + systemProperty("arrow.enable_unsafe_memory_access", "true") + // Disable expensive null check for every get(index) call. + // Iceberg manages nullability checks itself instead of relying on arrow. + systemProperty("arrow.enable_null_check_for_get", "false") + + // Vectorized reads need more memory + maxHeapSize '2500m' + } } project(':iceberg-spark3') { @@ -455,6 +468,17 @@ project(':iceberg-spark3') { testCompile project(path: ':iceberg-hive', configuration: 'testArtifacts') testCompile project(path: ':iceberg-api', configuration: 'testArtifacts') } + test { + // For vectorized reads + // Allow unsafe memory access to avoid the costly check arrow does to check if index is within bounds + systemProperty("arrow.enable_unsafe_memory_access", "true") + // Disable expensive null check for every get(index) call. + // Iceberg manages nullability checks itself instead of relying on arrow. 
+ systemProperty("arrow.enable_null_check_for_get", "false") + + // Vectorized reads need more memory + maxHeapSize '2500m' + } } project(':iceberg-pig') { diff --git a/core/src/main/java/org/apache/iceberg/TableProperties.java b/core/src/main/java/org/apache/iceberg/TableProperties.java index f2ad3e5451b4..39d066893fe1 100644 --- a/core/src/main/java/org/apache/iceberg/TableProperties.java +++ b/core/src/main/java/org/apache/iceberg/TableProperties.java @@ -77,6 +77,12 @@ private TableProperties() {} public static final String SPLIT_OPEN_FILE_COST = "read.split.open-file-cost"; public static final long SPLIT_OPEN_FILE_COST_DEFAULT = 4 * 1024 * 1024; // 4MB + public static final String PARQUET_VECTORIZATION_ENABLED = "read.parquet.vectorization.enabled"; + public static final boolean PARQUET_VECTORIZATION_ENABLED_DEFAULT = false; + + public static final String PARQUET_BATCH_SIZE = "read.parquet.vectorization.batch-size"; + public static final int PARQUET_BATCH_SIZE_DEFAULT = 5000; + public static final String OBJECT_STORE_ENABLED = "write.object-storage.enabled"; public static final boolean OBJECT_STORE_ENABLED_DEFAULT = false; diff --git a/parquet/src/main/java/org/apache/iceberg/parquet/ParquetUtil.java b/parquet/src/main/java/org/apache/iceberg/parquet/ParquetUtil.java index f92230c6eb1f..c4c8ebf30c12 100644 --- a/parquet/src/main/java/org/apache/iceberg/parquet/ParquetUtil.java +++ b/parquet/src/main/java/org/apache/iceberg/parquet/ParquetUtil.java @@ -280,6 +280,7 @@ public static boolean isIntType(PrimitiveType primitiveType) { case INT_8: case INT_16: case INT_32: + case DATE: return true; default: return false; diff --git a/parquet/src/main/java/org/apache/iceberg/parquet/VectorizedParquetReader.java b/parquet/src/main/java/org/apache/iceberg/parquet/VectorizedParquetReader.java index c3f87eee8296..6cb9da574caa 100644 --- a/parquet/src/main/java/org/apache/iceberg/parquet/VectorizedParquetReader.java +++ b/parquet/src/main/java/org/apache/iceberg/parquet/VectorizedParquetReader.java @@ -90,6 +90,7 @@ private static class FileIterator implements CloseableIterator { private final long totalValues; private final int batchSize; private final List> columnChunkMetadata; + private final boolean reuseContainers; private int nextRowGroup = 0; private long nextRowGroupStart = 0; private long valuesRead = 0; @@ -98,13 +99,15 @@ private static class FileIterator implements CloseableIterator { FileIterator(ReadConf conf) { this.reader = conf.reader(); this.shouldSkip = conf.shouldSkip(); - this.model = conf.vectorizedModel(); this.totalValues = conf.totalValues(); - this.model.reuseContainers(conf.reuseContainers()); + this.reuseContainers = conf.reuseContainers(); + this.model = conf.vectorizedModel(); this.batchSize = conf.batchSize(); + this.model.setBatchSize(this.batchSize); this.columnChunkMetadata = conf.columnChunkMetadataForRowGroups(); } + @Override public boolean hasNext() { return valuesRead < totalValues; @@ -118,10 +121,16 @@ public T next() { if (valuesRead >= nextRowGroupStart) { advance(); } - long numValuesToRead = Math.min(nextRowGroupStart - valuesRead, batchSize); + // batchSize is an integer, so casting to integer is safe - this.last = model.read((int) numValuesToRead); + int numValuesToRead = (int) Math.min(nextRowGroupStart - valuesRead, batchSize); + if (reuseContainers) { + this.last = model.read(last, numValuesToRead); + } else { + this.last = model.read(null, numValuesToRead); + } valuesRead += numValuesToRead; + return last; } diff --git 
a/parquet/src/main/java/org/apache/iceberg/parquet/VectorizedReader.java b/parquet/src/main/java/org/apache/iceberg/parquet/VectorizedReader.java index 3eb3303d4eef..25c16f09bfb3 100644 --- a/parquet/src/main/java/org/apache/iceberg/parquet/VectorizedReader.java +++ b/parquet/src/main/java/org/apache/iceberg/parquet/VectorizedReader.java @@ -31,25 +31,23 @@ public interface VectorizedReader { /** * Reads a batch of type @param <T> and of size numRows + * + * @param reuse container for the last batch to be reused for next batch * @param numRows number of rows to read * @return batch of records of type @param <T> */ - T read(int numRows); + T read(T reuse, int numRows); + + void setBatchSize(int batchSize); /** - * - * @param pages row group information for all the columns + * @param pages row group information for all the columns * @param metadata map of {@link ColumnPath} -> {@link ColumnChunkMetaData} for the row group */ void setRowGroupInfo(PageReadStore pages, Map metadata); /** - * Set up the reader to reuse the underlying containers used for storing batches - */ - void reuseContainers(boolean reuse); - - /** - * Release any resources allocated + * Release any resources allocated. */ void close(); } diff --git a/spark/src/jmh/java/org/apache/iceberg/spark/source/IcebergSourceBenchmark.java b/spark/src/jmh/java/org/apache/iceberg/spark/source/IcebergSourceBenchmark.java index 57863e0a0169..91568db0517c 100644 --- a/spark/src/jmh/java/org/apache/iceberg/spark/source/IcebergSourceBenchmark.java +++ b/spark/src/jmh/java/org/apache/iceberg/spark/source/IcebergSourceBenchmark.java @@ -27,6 +27,7 @@ import org.apache.hadoop.fs.FileSystem; import org.apache.hadoop.fs.Path; import org.apache.iceberg.Table; +import org.apache.iceberg.TableProperties; import org.apache.iceberg.UpdateProperties; import org.apache.iceberg.spark.SparkSchemaUtil; import org.apache.spark.sql.Dataset; @@ -92,15 +93,24 @@ protected void cleanupFiles() throws IOException { } } - protected void setupSpark() { - spark = SparkSession.builder() - .config("spark.ui.enabled", false) - .master("local") - .getOrCreate(); + protected void setupSpark(boolean enableDictionaryEncoding) { + SparkSession.Builder builder = SparkSession.builder() + .config("spark.ui.enabled", false); + if (!enableDictionaryEncoding) { + builder.config("parquet.dictionary.page.size", "1") + .config("parquet.enable.dictionary", false) + .config(TableProperties.PARQUET_DICT_SIZE_BYTES, "1"); + } + builder.master("local"); + spark = builder.getOrCreate(); Configuration sparkHadoopConf = spark.sessionState().newHadoopConf(); hadoopConf.forEach(entry -> sparkHadoopConf.set(entry.getKey(), entry.getValue())); } + protected void setupSpark() { + setupSpark(false); + } + protected void tearDownSpark() { spark.stop(); } diff --git a/spark/src/main/java/org/apache/iceberg/spark/data/vectorized/ArrowVectorAccessor.java b/spark/src/main/java/org/apache/iceberg/spark/data/vectorized/ArrowVectorAccessor.java new file mode 100644 index 000000000000..c9c9959c9e95 --- /dev/null +++ b/spark/src/main/java/org/apache/iceberg/spark/data/vectorized/ArrowVectorAccessor.java @@ -0,0 +1,95 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.iceberg.spark.data.vectorized; + +import org.apache.arrow.vector.ValueVector; +import org.apache.spark.sql.types.Decimal; +import org.apache.spark.sql.vectorized.ArrowColumnVector; +import org.apache.spark.sql.vectorized.ColumnarArray; +import org.apache.spark.unsafe.types.UTF8String; + +@SuppressWarnings("checkstyle:VisibilityModifier") +public abstract class ArrowVectorAccessor { + + private final ValueVector vector; + private final ArrowColumnVector[] childColumns; + + ArrowVectorAccessor(ValueVector vector) { + this.vector = vector; + this.childColumns = new ArrowColumnVector[0]; + } + + ArrowVectorAccessor(ValueVector vector, ArrowColumnVector[] children) { + this.vector = vector; + this.childColumns = children; + } + + final void close() { + for (ArrowColumnVector column : childColumns) { + // Closing an ArrowColumnVector is expected to not throw any exception + column.close(); + } + vector.close(); + } + + boolean getBoolean(int rowId) { + throw new UnsupportedOperationException("Unsupported type: boolean"); + } + + int getInt(int rowId) { + throw new UnsupportedOperationException("Unsupported type: int"); + } + + long getLong(int rowId) { + throw new UnsupportedOperationException("Unsupported type: long"); + } + + float getFloat(int rowId) { + throw new UnsupportedOperationException("Unsupported type: float"); + } + + double getDouble(int rowId) { + throw new UnsupportedOperationException("Unsupported type: double"); + } + + Decimal getDecimal(int rowId, int precision, int scale) { + throw new UnsupportedOperationException("Unsupported type: decimal"); + } + + UTF8String getUTF8String(int rowId) { + throw new UnsupportedOperationException("Unsupported type: UTF8String"); + } + + byte[] getBinary(int rowId) { + throw new UnsupportedOperationException("Unsupported type: binary"); + } + + ColumnarArray getArray(int rowId) { + throw new UnsupportedOperationException("Unsupported type: array"); + } + + ArrowColumnVector childColumn(int pos) { + return childColumns[pos]; + } + + public ValueVector getVector() { + return vector; + } +} diff --git a/spark/src/main/java/org/apache/iceberg/spark/data/vectorized/ArrowVectorAccessors.java b/spark/src/main/java/org/apache/iceberg/spark/data/vectorized/ArrowVectorAccessors.java new file mode 100644 index 000000000000..74732a3e4192 --- /dev/null +++ b/spark/src/main/java/org/apache/iceberg/spark/data/vectorized/ArrowVectorAccessors.java @@ -0,0 +1,508 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.iceberg.spark.data.vectorized; + +import io.netty.buffer.ArrowBuf; +import java.math.BigInteger; +import java.util.stream.IntStream; +import org.apache.arrow.vector.BigIntVector; +import org.apache.arrow.vector.BitVector; +import org.apache.arrow.vector.DateDayVector; +import org.apache.arrow.vector.FieldVector; +import org.apache.arrow.vector.Float4Vector; +import org.apache.arrow.vector.Float8Vector; +import org.apache.arrow.vector.IntVector; +import org.apache.arrow.vector.TimeStampMicroTZVector; +import org.apache.arrow.vector.VarBinaryVector; +import org.apache.arrow.vector.complex.ListVector; +import org.apache.arrow.vector.complex.StructVector; +import org.apache.arrow.vector.holders.NullableVarCharHolder; +import org.apache.iceberg.arrow.vectorized.IcebergArrowVectors; +import org.apache.iceberg.arrow.vectorized.VectorHolder; +import org.apache.parquet.Preconditions; +import org.apache.parquet.column.ColumnDescriptor; +import org.apache.parquet.column.Dictionary; +import org.apache.parquet.schema.PrimitiveType; +import org.apache.spark.sql.types.Decimal; +import org.apache.spark.sql.vectorized.ArrowColumnVector; +import org.apache.spark.sql.vectorized.ColumnarArray; +import org.apache.spark.unsafe.types.UTF8String; +import org.jetbrains.annotations.NotNull; + +public class ArrowVectorAccessors { + + private ArrowVectorAccessors() {} + + static ArrowVectorAccessor getVectorAccessor(VectorHolder holder) { + Dictionary dictionary = holder.dictionary(); + boolean isVectorDictEncoded = holder.isDictionaryEncoded(); + ColumnDescriptor desc = holder.descriptor(); + FieldVector vector = holder.vector(); + PrimitiveType primitive = desc.getPrimitiveType(); + if (isVectorDictEncoded) { + return getDictionaryVectorAccessor(dictionary, desc, vector, primitive); + } else { + return getPlainVectorAccessor(vector); + } + } + + @NotNull + private static ArrowVectorAccessor getDictionaryVectorAccessor( + Dictionary dictionary, + ColumnDescriptor desc, + FieldVector vector, PrimitiveType primitive) { + Preconditions.checkState(vector instanceof IntVector, "Dictionary ids should be stored in IntVectors only"); + if (primitive.getOriginalType() != null) { + switch (desc.getPrimitiveType().getOriginalType()) { + case ENUM: + case JSON: + case UTF8: + case BSON: + return new DictionaryStringAccessor((IntVector) vector, dictionary); + case INT_64: + case TIMESTAMP_MILLIS: + case TIMESTAMP_MICROS: + return new DictionaryLongAccessor((IntVector) vector, dictionary); + case DECIMAL: + switch (primitive.getPrimitiveTypeName()) { + case BINARY: + case FIXED_LEN_BYTE_ARRAY: + return new DictionaryDecimalBinaryAccessor( + (IntVector) vector, + dictionary); + case INT64: + return new DictionaryDecimalLongAccessor( + (IntVector) vector, + dictionary); + case INT32: + return new DictionaryDecimalIntAccessor( + (IntVector) vector, + dictionary); + default: + throw new UnsupportedOperationException( + "Unsupported base type for decimal: " + primitive.getPrimitiveTypeName()); + } + default: + throw new UnsupportedOperationException( + "Unsupported logical type: " + 
primitive.getOriginalType()); + } + } else { + switch (primitive.getPrimitiveTypeName()) { + case FIXED_LEN_BYTE_ARRAY: + case BINARY: + return new DictionaryBinaryAccessor((IntVector) vector, dictionary); + case FLOAT: + return new DictionaryFloatAccessor((IntVector) vector, dictionary); + case INT64: + return new DictionaryLongAccessor((IntVector) vector, dictionary); + case DOUBLE: + return new DictionaryDoubleAccessor((IntVector) vector, dictionary); + default: + throw new UnsupportedOperationException("Unsupported type: " + primitive); + } + } + } + + @NotNull + @SuppressWarnings("checkstyle:CyclomaticComplexity") + private static ArrowVectorAccessor getPlainVectorAccessor(FieldVector vector) { + if (vector instanceof BitVector) { + return new BooleanAccessor((BitVector) vector); + } else if (vector instanceof IntVector) { + return new IntAccessor((IntVector) vector); + } else if (vector instanceof BigIntVector) { + return new LongAccessor((BigIntVector) vector); + } else if (vector instanceof Float4Vector) { + return new FloatAccessor((Float4Vector) vector); + } else if (vector instanceof Float8Vector) { + return new DoubleAccessor((Float8Vector) vector); + } else if (vector instanceof IcebergArrowVectors.DecimalArrowVector) { + return new DecimalAccessor((IcebergArrowVectors.DecimalArrowVector) vector); + } else if (vector instanceof IcebergArrowVectors.VarcharArrowVector) { + return new StringAccessor((IcebergArrowVectors.VarcharArrowVector) vector); + } else if (vector instanceof VarBinaryVector) { + return new BinaryAccessor((VarBinaryVector) vector); + } else if (vector instanceof DateDayVector) { + return new DateAccessor((DateDayVector) vector); + } else if (vector instanceof TimeStampMicroTZVector) { + return new TimestampAccessor((TimeStampMicroTZVector) vector); + } else if (vector instanceof ListVector) { + ListVector listVector = (ListVector) vector; + return new ArrayAccessor(listVector); + } else if (vector instanceof StructVector) { + StructVector structVector = (StructVector) vector; + return new StructAccessor(structVector); + } + throw new UnsupportedOperationException("Unsupported vector: " + vector.getClass()); + } + + private static class BooleanAccessor extends ArrowVectorAccessor { + + private final BitVector vector; + + BooleanAccessor(BitVector vector) { + super(vector); + this.vector = vector; + } + + @Override + final boolean getBoolean(int rowId) { + return vector.get(rowId) == 1; + } + } + + private static class IntAccessor extends ArrowVectorAccessor { + + private final IntVector vector; + + IntAccessor(IntVector vector) { + super(vector); + this.vector = vector; + } + + @Override + final int getInt(int rowId) { + return vector.get(rowId); + } + } + + private static class LongAccessor extends ArrowVectorAccessor { + + private final BigIntVector vector; + + LongAccessor(BigIntVector vector) { + super(vector); + this.vector = vector; + } + + @Override + final long getLong(int rowId) { + return vector.get(rowId); + } + } + + private static class DictionaryLongAccessor extends ArrowVectorAccessor { + private final IntVector offsetVector; + private final long[] decodedDictionary; + + DictionaryLongAccessor(IntVector vector, Dictionary dictionary) { + super(vector); + this.offsetVector = vector; + this.decodedDictionary = IntStream.rangeClosed(0, dictionary.getMaxId()) + .mapToLong(dictionary::decodeToLong) + .toArray(); + } + + @Override + final long getLong(int rowId) { + return decodedDictionary[offsetVector.get(rowId)]; + } + } + + private static class 
FloatAccessor extends ArrowVectorAccessor { + + private final Float4Vector vector; + + FloatAccessor(Float4Vector vector) { + super(vector); + this.vector = vector; + } + + @Override + final float getFloat(int rowId) { + return vector.get(rowId); + } + } + + private static class DictionaryFloatAccessor extends ArrowVectorAccessor { + private final IntVector offsetVector; + private final float[] decodedDictionary; + + DictionaryFloatAccessor(IntVector vector, Dictionary dictionary) { + super(vector); + this.offsetVector = vector; + this.decodedDictionary = new float[dictionary.getMaxId() + 1]; + for (int i = 0; i <= dictionary.getMaxId(); i++) { + decodedDictionary[i] = dictionary.decodeToFloat(i); + } + } + + @Override + final float getFloat(int rowId) { + return decodedDictionary[offsetVector.get(rowId)]; + } + } + + private static class DoubleAccessor extends ArrowVectorAccessor { + + private final Float8Vector vector; + + DoubleAccessor(Float8Vector vector) { + super(vector); + this.vector = vector; + } + + @Override + final double getDouble(int rowId) { + return vector.get(rowId); + } + } + + private static class DictionaryDoubleAccessor extends ArrowVectorAccessor { + private final IntVector offsetVector; + private final double[] decodedDictionary; + + DictionaryDoubleAccessor(IntVector vector, Dictionary dictionary) { + super(vector); + this.offsetVector = vector; + this.decodedDictionary = IntStream.rangeClosed(0, dictionary.getMaxId()) + .mapToDouble(dictionary::decodeToDouble) + .toArray(); + } + + @Override + final double getDouble(int rowId) { + return decodedDictionary[offsetVector.get(rowId)]; + } + } + + private static class StringAccessor extends ArrowVectorAccessor { + + private final IcebergArrowVectors.VarcharArrowVector vector; + private final NullableVarCharHolder stringResult = new NullableVarCharHolder(); + + StringAccessor(IcebergArrowVectors.VarcharArrowVector vector) { + super(vector); + this.vector = vector; + } + + @Override + final UTF8String getUTF8String(int rowId) { + vector.get(rowId, stringResult); + if (stringResult.isSet == 0) { + return null; + } else { + return UTF8String.fromAddress( + null, + stringResult.buffer.memoryAddress() + stringResult.start, + stringResult.end - stringResult.start); + } + } + } + + private static class DictionaryStringAccessor extends ArrowVectorAccessor { + private final UTF8String[] decodedDictionary; + private final IntVector offsetVector; + + DictionaryStringAccessor(IntVector vector, Dictionary dictionary) { + super(vector); + this.offsetVector = vector; + this.decodedDictionary = IntStream.rangeClosed(0, dictionary.getMaxId()) + .mapToObj(dictionary::decodeToBinary) + .map(binary -> UTF8String.fromBytes(binary.getBytes())) + .toArray(UTF8String[]::new); + } + + @Override + final UTF8String getUTF8String(int rowId) { + int offset = offsetVector.get(rowId); + return decodedDictionary[offset]; + } + } + + private static class BinaryAccessor extends ArrowVectorAccessor { + + private final VarBinaryVector vector; + + BinaryAccessor(VarBinaryVector vector) { + super(vector); + this.vector = vector; + } + + @Override + final byte[] getBinary(int rowId) { + return vector.get(rowId); + } + } + + private static class DictionaryBinaryAccessor extends ArrowVectorAccessor { + private final IntVector offsetVector; + private final byte[][] decodedDictionary; + + DictionaryBinaryAccessor(IntVector vector, Dictionary dictionary) { + super(vector); + this.offsetVector = vector; + this.decodedDictionary = IntStream.rangeClosed(0, 
dictionary.getMaxId()) + .mapToObj(dictionary::decodeToBinary) + .map(binary -> binary.getBytes()) + .toArray(byte[][]::new); + } + + @Override + final byte[] getBinary(int rowId) { + int offset = offsetVector.get(rowId); + return decodedDictionary[offset]; + } + } + + private static class DateAccessor extends ArrowVectorAccessor { + + private final DateDayVector vector; + + DateAccessor(DateDayVector vector) { + super(vector); + this.vector = vector; + } + + @Override + final int getInt(int rowId) { + return vector.get(rowId); + } + } + + private static class TimestampAccessor extends ArrowVectorAccessor { + + private final TimeStampMicroTZVector vector; + + TimestampAccessor(TimeStampMicroTZVector vector) { + super(vector); + this.vector = vector; + } + + @Override + final long getLong(int rowId) { + return vector.get(rowId); + } + } + + private static class ArrayAccessor extends ArrowVectorAccessor { + + private final ListVector vector; + private final ArrowColumnVector arrayData; + + ArrayAccessor(ListVector vector) { + super(vector); + this.vector = vector; + this.arrayData = new ArrowColumnVector(vector.getDataVector()); + } + + @Override + final ColumnarArray getArray(int rowId) { + ArrowBuf offsets = vector.getOffsetBuffer(); + int index = rowId * ListVector.OFFSET_WIDTH; + int start = offsets.getInt(index); + int end = offsets.getInt(index + ListVector.OFFSET_WIDTH); + return new ColumnarArray(arrayData, start, end - start); + } + } + + /** + * Use {@link IcebergArrowColumnVector#getChild(int)} to get hold of the {@link ArrowColumnVector} vectors holding the + * struct values. + */ + private static class StructAccessor extends ArrowVectorAccessor { + StructAccessor(StructVector structVector) { + super(structVector, IntStream.range(0, structVector.size()) + .mapToObj(structVector::getVectorById) + .map(ArrowColumnVector::new) + .toArray(ArrowColumnVector[]::new)); + } + } + + private static class DecimalAccessor extends ArrowVectorAccessor { + + private final IcebergArrowVectors.DecimalArrowVector vector; + + DecimalAccessor(IcebergArrowVectors.DecimalArrowVector vector) { + super(vector); + this.vector = vector; + } + + @Override + final Decimal getDecimal(int rowId, int precision, int scale) { + return Decimal.apply(vector.getObject(rowId), precision, scale); + } + } + + @SuppressWarnings("checkstyle:VisibilityModifier") + private abstract static class DictionaryDecimalAccessor extends ArrowVectorAccessor { + final Decimal[] cache; + Dictionary parquetDictionary; + final IntVector offsetVector; + + private DictionaryDecimalAccessor(IntVector vector, Dictionary dictionary) { + super(vector); + this.offsetVector = vector; + this.parquetDictionary = dictionary; + this.cache = new Decimal[dictionary.getMaxId() + 1]; + } + } + + private static class DictionaryDecimalBinaryAccessor extends DictionaryDecimalAccessor { + + DictionaryDecimalBinaryAccessor(IntVector vector, Dictionary dictionary) { + super(vector, dictionary); + } + + @Override + final Decimal getDecimal(int rowId, int precision, int scale) { + int dictId = offsetVector.get(rowId); + if (cache[dictId] == null) { + cache[dictId] = Decimal.apply( + new BigInteger(parquetDictionary.decodeToBinary(dictId).getBytes()).longValue(), + precision, + scale); + } + return cache[dictId]; + } + } + + private static class DictionaryDecimalLongAccessor extends DictionaryDecimalAccessor { + + DictionaryDecimalLongAccessor(IntVector vector, Dictionary dictionary) { + super(vector, dictionary); + } + + @Override + final Decimal 
getDecimal(int rowId, int precision, int scale) { + int dictId = offsetVector.get(rowId); + if (cache[dictId] == null) { + cache[dictId] = Decimal.apply(parquetDictionary.decodeToLong(dictId), precision, scale); + } + return cache[dictId]; + } + } + + private static class DictionaryDecimalIntAccessor extends DictionaryDecimalAccessor { + + DictionaryDecimalIntAccessor(IntVector vector, Dictionary dictionary) { + super(vector, dictionary); + } + + @Override + final Decimal getDecimal(int rowId, int precision, int scale) { + int dictId = offsetVector.get(rowId); + if (cache[dictId] == null) { + cache[dictId] = Decimal.apply(parquetDictionary.decodeToInt(dictId), precision, scale); + } + return cache[dictId]; + } + } +} diff --git a/spark/src/main/java/org/apache/iceberg/spark/data/vectorized/ColumnarBatchReader.java b/spark/src/main/java/org/apache/iceberg/spark/data/vectorized/ColumnarBatchReader.java new file mode 100644 index 000000000000..c76321ecd61d --- /dev/null +++ b/spark/src/main/java/org/apache/iceberg/spark/data/vectorized/ColumnarBatchReader.java @@ -0,0 +1,110 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.iceberg.spark.data.vectorized; + +import java.util.List; +import java.util.Map; +import org.apache.iceberg.arrow.vectorized.VectorHolder; +import org.apache.iceberg.arrow.vectorized.VectorizedArrowReader; +import org.apache.iceberg.parquet.VectorizedReader; +import org.apache.parquet.Preconditions; +import org.apache.parquet.column.page.PageReadStore; +import org.apache.parquet.hadoop.metadata.ColumnChunkMetaData; +import org.apache.parquet.hadoop.metadata.ColumnPath; +import org.apache.spark.sql.vectorized.ColumnVector; +import org.apache.spark.sql.vectorized.ColumnarBatch; + +/** + * {@link VectorizedReader} that returns Spark's {@link ColumnarBatch} to support Spark's vectorized read path. The + * {@link ColumnarBatch} returned is created by passing in the Arrow vectors populated via delegated read calls to + * {@linkplain VectorizedArrowReader VectorReader(s)}. 
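+ *
+ * <p>A minimal, illustrative read loop over a single row group; {@code reader} is an instance of this
+ * class, and {@code pageStore}, {@code metadataByPath}, {@code totalRows} and {@code batchSize} are
+ * caller-supplied stand-ins rather than anything defined here:
+ * <pre>
+ *   reader.setRowGroupInfo(pageStore, metadataByPath);
+ *   ColumnarBatch batch = null;
+ *   for (long remaining = totalRows; remaining &gt; 0; remaining -= batch.numRows()) {
+ *     batch = reader.read(batch, (int) Math.min(remaining, batchSize));
+ *   }
+ * </pre>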
+ */ +public class ColumnarBatchReader implements VectorizedReader { + private final VectorizedArrowReader[] readers; + private final VectorHolder[] vectorHolders; + + public ColumnarBatchReader(List> readers) { + this.readers = readers.stream() + .map(VectorizedArrowReader.class::cast) + .toArray(VectorizedArrowReader[]::new); + this.vectorHolders = new VectorHolder[readers.size()]; + } + + @Override + public final void setRowGroupInfo(PageReadStore pageStore, Map metaData) { + for (VectorizedArrowReader reader : readers) { + if (reader != null) { + reader.setRowGroupInfo(pageStore, metaData); + } + } + } + + @Override + public final ColumnarBatch read(ColumnarBatch reuse, int numRowsToRead) { + Preconditions.checkArgument(numRowsToRead > 0, "Invalid number of rows to read: %s", numRowsToRead); + ColumnVector[] arrowColumnVectors = new ColumnVector[readers.length]; + + if (reuse == null) { + closeVectors(); + } + + for (int i = 0; i < readers.length; i += 1) { + vectorHolders[i] = readers[i].read(vectorHolders[i], numRowsToRead); + int numRowsInVector = vectorHolders[i].numValues(); + Preconditions.checkState( + numRowsInVector == numRowsToRead, + "Number of rows in the vector %s didn't match expected %s ", numRowsInVector, + numRowsToRead); + arrowColumnVectors[i] = + IcebergArrowColumnVector.forHolder(vectorHolders[i], numRowsInVector); + } + ColumnarBatch batch = new ColumnarBatch(arrowColumnVectors); + batch.setNumRows(numRowsToRead); + return batch; + } + + private void closeVectors() { + for (int i = 0; i < vectorHolders.length; i++) { + if (vectorHolders[i] != null) { + // Release any resources used by the vector + if (vectorHolders[i].vector() != null) { + vectorHolders[i].vector().close(); + } + vectorHolders[i] = null; + } + } + } + + @Override + public void close() { + for (VectorizedReader reader : readers) { + reader.close(); + } + } + + @Override + public void setBatchSize(int batchSize) { + for (VectorizedArrowReader reader : readers) { + if (reader != null) { + reader.setBatchSize(batchSize); + } + } + } +} diff --git a/spark/src/main/java/org/apache/iceberg/spark/data/vectorized/IcebergArrowColumnVector.java b/spark/src/main/java/org/apache/iceberg/spark/data/vectorized/IcebergArrowColumnVector.java new file mode 100644 index 000000000000..9d10cd935512 --- /dev/null +++ b/spark/src/main/java/org/apache/iceberg/spark/data/vectorized/IcebergArrowColumnVector.java @@ -0,0 +1,153 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.iceberg.spark.data.vectorized; + +import org.apache.iceberg.arrow.vectorized.NullabilityHolder; +import org.apache.iceberg.arrow.vectorized.VectorHolder; +import org.apache.iceberg.spark.SparkSchemaUtil; +import org.apache.spark.sql.types.Decimal; +import org.apache.spark.sql.vectorized.ArrowColumnVector; +import org.apache.spark.sql.vectorized.ColumnVector; +import org.apache.spark.sql.vectorized.ColumnarArray; +import org.apache.spark.sql.vectorized.ColumnarMap; +import org.apache.spark.unsafe.types.UTF8String; + +/** + * Implementation of Spark's {@link ColumnVector} interface. The code for this class is heavily inspired from Spark's + * {@link ArrowColumnVector} The main difference is in how nullability checks are made in this class by relying on + * {@link NullabilityHolder} instead of the validity vector in the Arrow vector. + */ +public class IcebergArrowColumnVector extends ColumnVector { + + private final ArrowVectorAccessor accessor; + private final NullabilityHolder nullabilityHolder; + + public IcebergArrowColumnVector(VectorHolder holder) { + super(SparkSchemaUtil.convert(holder.icebergType())); + this.nullabilityHolder = holder.nullabilityHolder(); + this.accessor = ArrowVectorAccessors.getVectorAccessor(holder); + } + + @Override + public void close() { + accessor.close(); + } + + @Override + public boolean hasNull() { + return nullabilityHolder.hasNulls(); + } + + @Override + public int numNulls() { + return nullabilityHolder.numNulls(); + } + + @Override + public boolean isNullAt(int rowId) { + return nullabilityHolder.isNullAt(rowId) == 1; + } + + @Override + public boolean getBoolean(int rowId) { + return accessor.getBoolean(rowId); + } + + @Override + public byte getByte(int rowId) { + throw new UnsupportedOperationException("Unsupported type - byte"); + } + + @Override + public short getShort(int rowId) { + throw new UnsupportedOperationException("Unsupported type - short"); + } + + @Override + public int getInt(int rowId) { + return accessor.getInt(rowId); + } + + @Override + public long getLong(int rowId) { + return accessor.getLong(rowId); + } + + @Override + public float getFloat(int rowId) { + return accessor.getFloat(rowId); + } + + @Override + public double getDouble(int rowId) { + return accessor.getDouble(rowId); + } + + @Override + public ColumnarArray getArray(int rowId) { + if (isNullAt(rowId)) { + return null; + } + return accessor.getArray(rowId); + } + + @Override + public ColumnarMap getMap(int rowId) { + throw new UnsupportedOperationException("Unsupported type - map"); + } + + @Override + public Decimal getDecimal(int rowId, int precision, int scale) { + if (isNullAt(rowId)) { + return null; + } + return accessor.getDecimal(rowId, precision, scale); + } + + @Override + public UTF8String getUTF8String(int rowId) { + if (isNullAt(rowId)) { + return null; + } + return accessor.getUTF8String(rowId); + } + + @Override + public byte[] getBinary(int rowId) { + if (isNullAt(rowId)) { + return null; + } + return accessor.getBinary(rowId); + } + + @Override + public ArrowColumnVector getChild(int ordinal) { + return accessor.childColumn(ordinal); + } + + static ColumnVector forHolder(VectorHolder holder, int numRows) { + return holder.isDummy() ? 
new NullValuesColumnVector(numRows) : + new IcebergArrowColumnVector(holder); + } + + public ArrowVectorAccessor vectorAccessor() { + return accessor; + } +} diff --git a/spark/src/main/java/org/apache/iceberg/spark/data/vectorized/NullValuesColumnVector.java b/spark/src/main/java/org/apache/iceberg/spark/data/vectorized/NullValuesColumnVector.java new file mode 100644 index 000000000000..8770d13ab883 --- /dev/null +++ b/spark/src/main/java/org/apache/iceberg/spark/data/vectorized/NullValuesColumnVector.java @@ -0,0 +1,125 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.iceberg.spark.data.vectorized; + +import org.apache.iceberg.spark.SparkSchemaUtil; +import org.apache.iceberg.types.Type; +import org.apache.iceberg.types.Types; +import org.apache.spark.sql.types.Decimal; +import org.apache.spark.sql.vectorized.ColumnVector; +import org.apache.spark.sql.vectorized.ColumnarArray; +import org.apache.spark.sql.vectorized.ColumnarMap; +import org.apache.spark.unsafe.types.UTF8String; + +public class NullValuesColumnVector extends ColumnVector { + + private final int numNulls; + private static final Type NULL_TYPE = Types.IntegerType.get(); + + public NullValuesColumnVector(int nValues) { + super(SparkSchemaUtil.convert(NULL_TYPE)); + this.numNulls = nValues; + } + + @Override + public void close() { + + } + + @Override + public boolean hasNull() { + return true; + } + + @Override + public int numNulls() { + return numNulls; + } + + @Override + public boolean isNullAt(int rowId) { + return true; + } + + @Override + public boolean getBoolean(int rowId) { + throw new UnsupportedOperationException(); + } + + @Override + public byte getByte(int rowId) { + throw new UnsupportedOperationException(); + } + + @Override + public short getShort(int rowId) { + throw new UnsupportedOperationException(); + } + + @Override + public int getInt(int rowId) { + throw new UnsupportedOperationException(); + } + + @Override + public long getLong(int rowId) { + throw new UnsupportedOperationException(); + } + + @Override + public float getFloat(int rowId) { + throw new UnsupportedOperationException(); + } + + @Override + public double getDouble(int rowId) { + throw new UnsupportedOperationException(); + } + + @Override + public ColumnarArray getArray(int rowId) { + throw new UnsupportedOperationException(); + } + + @Override + public ColumnarMap getMap(int ordinal) { + throw new UnsupportedOperationException(); + } + + @Override + public Decimal getDecimal(int rowId, int precision, int scale) { + throw new UnsupportedOperationException(); + } + + @Override + public UTF8String getUTF8String(int rowId) { + throw new UnsupportedOperationException(); + } + + @Override + public byte[] getBinary(int rowId) { + throw new 
UnsupportedOperationException(); + } + + @Override + protected ColumnVector getChild(int ordinal) { + throw new UnsupportedOperationException(); + } +} diff --git a/spark/src/main/java/org/apache/iceberg/spark/data/vectorized/VectorizedSparkParquetReaders.java b/spark/src/main/java/org/apache/iceberg/spark/data/vectorized/VectorizedSparkParquetReaders.java new file mode 100644 index 000000000000..01cbe6f286ad --- /dev/null +++ b/spark/src/main/java/org/apache/iceberg/spark/data/vectorized/VectorizedSparkParquetReaders.java @@ -0,0 +1,131 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.iceberg.spark.data.vectorized; + +import java.util.List; +import java.util.Map; +import java.util.stream.IntStream; +import org.apache.arrow.memory.BufferAllocator; +import org.apache.iceberg.Schema; +import org.apache.iceberg.arrow.ArrowAllocation; +import org.apache.iceberg.arrow.vectorized.VectorizedArrowReader; +import org.apache.iceberg.parquet.TypeWithSchemaVisitor; +import org.apache.iceberg.parquet.VectorizedReader; +import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.relocated.com.google.common.collect.Maps; +import org.apache.iceberg.types.Types; +import org.apache.parquet.column.ColumnDescriptor; +import org.apache.parquet.schema.GroupType; +import org.apache.parquet.schema.MessageType; +import org.apache.parquet.schema.PrimitiveType; +import org.apache.parquet.schema.Type; + +public class VectorizedSparkParquetReaders { + + private VectorizedSparkParquetReaders() { + } + + public static ColumnarBatchReader buildReader( + Schema expectedSchema, + MessageType fileSchema, + boolean setArrowValidityVector) { + return (ColumnarBatchReader) + TypeWithSchemaVisitor.visit(expectedSchema.asStruct(), fileSchema, + new VectorizedReaderBuilder(expectedSchema, fileSchema, setArrowValidityVector)); + } + + private static class VectorizedReaderBuilder extends TypeWithSchemaVisitor> { + private final MessageType parquetSchema; + private final Schema icebergSchema; + private final BufferAllocator rootAllocator; + private final boolean setArrowValidityVector; + + VectorizedReaderBuilder( + Schema expectedSchema, + MessageType parquetSchema, + boolean setArrowValidityVector) { + this.parquetSchema = parquetSchema; + this.icebergSchema = expectedSchema; + this.rootAllocator = ArrowAllocation.rootAllocator() + .newChildAllocator("VectorizedReadBuilder", 0, Long.MAX_VALUE); + this.setArrowValidityVector = setArrowValidityVector; + } + + @Override + public VectorizedReader message( + Types.StructType expected, MessageType message, + List> fieldReaders) { + GroupType groupType = message.asGroupType(); + Map> readersById 
= Maps.newHashMap(); + List fields = groupType.getFields(); + + IntStream.range(0, fields.size()) + .forEach(pos -> readersById.put(fields.get(pos).getId().intValue(), fieldReaders.get(pos))); + + List icebergFields = expected != null ? + expected.fields() : ImmutableList.of(); + + List> reorderedFields = Lists.newArrayListWithExpectedSize( + icebergFields.size()); + + for (Types.NestedField field : icebergFields) { + int id = field.fieldId(); + VectorizedReader reader = readersById.get(id); + if (reader != null) { + reorderedFields.add(reader); + } else { + reorderedFields.add(VectorizedArrowReader.nulls()); + } + } + return new ColumnarBatchReader(reorderedFields); + } + + @Override + public VectorizedReader struct( + Types.StructType expected, GroupType groupType, + List> fieldReaders) { + if (expected != null) { + throw new UnsupportedOperationException("Vectorized reads are not supported yet for struct fields"); + } + return null; + } + + @Override + public VectorizedReader primitive( + org.apache.iceberg.types.Type.PrimitiveType expected, + PrimitiveType primitive) { + + // Create arrow vector for this field + int parquetFieldId = primitive.getId().intValue(); + ColumnDescriptor desc = parquetSchema.getColumnDescription(currentPath()); + // Nested types not yet supported for vectorized reads + if (desc.getMaxRepetitionLevel() > 0) { + return null; + } + Types.NestedField icebergField = icebergSchema.findField(parquetFieldId); + if (icebergField == null) { + return null; + } + // Set the validity buffer if null checking is enabled in arrow + return new VectorizedArrowReader(desc, icebergField, rootAllocator, setArrowValidityVector); + } + } +} diff --git a/spark/src/main/java/org/apache/iceberg/spark/source/BatchDataReader.java b/spark/src/main/java/org/apache/iceberg/spark/source/BatchDataReader.java new file mode 100644 index 000000000000..eeb3ad559858 --- /dev/null +++ b/spark/src/main/java/org/apache/iceberg/spark/source/BatchDataReader.java @@ -0,0 +1,76 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +package org.apache.iceberg.spark.source; + +import org.apache.arrow.vector.NullCheckingForGet; +import org.apache.iceberg.CombinedScanTask; +import org.apache.iceberg.FileFormat; +import org.apache.iceberg.FileScanTask; +import org.apache.iceberg.Schema; +import org.apache.iceberg.encryption.EncryptionManager; +import org.apache.iceberg.io.CloseableIterable; +import org.apache.iceberg.io.CloseableIterator; +import org.apache.iceberg.io.FileIO; +import org.apache.iceberg.io.InputFile; +import org.apache.iceberg.parquet.Parquet; +import org.apache.iceberg.relocated.com.google.common.base.Preconditions; +import org.apache.iceberg.spark.data.vectorized.VectorizedSparkParquetReaders; +import org.apache.spark.sql.vectorized.ColumnarBatch; + +class BatchDataReader extends BaseDataReader { + private final Schema expectedSchema; + private final boolean caseSensitive; + private final int batchSize; + + BatchDataReader( + CombinedScanTask task, Schema expectedSchema, FileIO fileIo, + EncryptionManager encryptionManager, boolean caseSensitive, int size) { + super(task, fileIo, encryptionManager); + this.expectedSchema = expectedSchema; + this.caseSensitive = caseSensitive; + this.batchSize = size; + } + + @Override + CloseableIterator open(FileScanTask task) { + CloseableIterable iter; + InputFile location = getInputFile(task); + Preconditions.checkNotNull(location, "Could not find InputFile associated with FileScanTask"); + if (task.file().format() == FileFormat.PARQUET) { + iter = Parquet.read(location) + .project(expectedSchema) + .split(task.start(), task.length()) + .createBatchedReaderFunc(fileSchema -> VectorizedSparkParquetReaders.buildReader(expectedSchema, + fileSchema, /* setArrowValidityVector */ NullCheckingForGet.NULL_CHECKING_ENABLED)) + .recordsPerBatch(batchSize) + .filter(task.residual()) + .caseSensitive(caseSensitive) + // Spark eagerly consumes the batches. So the underlying memory allocated could be reused + // without worrying about subsequent reads clobbering over each other. This improves + // read performance as every batch read doesn't have to pay the cost of allocating memory. 
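+ // With reuseContainers(), the previously returned batch is handed back as the "reuse" argument of
+ // VectorizedReader.read(reuse, numRows); ColumnarBatchReader releases its Arrow vectors only when
+ // that argument is null, so steady-state reads recycle the existing buffers.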
+ .reuseContainers() + .build(); + } else { + throw new UnsupportedOperationException( + "Format: " + task.file().format() + " not supported for batched reads"); + } + return iter.iterator(); + } +} diff --git a/spark/src/main/java/org/apache/iceberg/spark/source/Reader.java b/spark/src/main/java/org/apache/iceberg/spark/source/Reader.java index 1f3d26e4b185..d205c22a77f6 100644 --- a/spark/src/main/java/org/apache/iceberg/spark/source/Reader.java +++ b/spark/src/main/java/org/apache/iceberg/spark/source/Reader.java @@ -28,6 +28,7 @@ import org.apache.hadoop.fs.FileSystem; import org.apache.hadoop.fs.Path; import org.apache.iceberg.CombinedScanTask; +import org.apache.iceberg.FileFormat; import org.apache.iceberg.FileScanTask; import org.apache.iceberg.Schema; import org.apache.iceberg.SchemaParser; @@ -43,10 +44,12 @@ import org.apache.iceberg.hadoop.Util; import org.apache.iceberg.io.CloseableIterable; import org.apache.iceberg.io.FileIO; +import org.apache.iceberg.relocated.com.google.common.base.Preconditions; import org.apache.iceberg.relocated.com.google.common.collect.ImmutableSet; import org.apache.iceberg.relocated.com.google.common.collect.Lists; import org.apache.iceberg.spark.SparkFilters; import org.apache.iceberg.spark.SparkSchemaUtil; +import org.apache.iceberg.util.PropertyUtil; import org.apache.spark.broadcast.Broadcast; import org.apache.spark.sql.SparkSession; import org.apache.spark.sql.catalyst.InternalRow; @@ -59,14 +62,16 @@ import org.apache.spark.sql.sources.v2.reader.SupportsPushDownFilters; import org.apache.spark.sql.sources.v2.reader.SupportsPushDownRequiredColumns; import org.apache.spark.sql.sources.v2.reader.SupportsReportStatistics; +import org.apache.spark.sql.sources.v2.reader.SupportsScanColumnarBatch; import org.apache.spark.sql.types.DataType; import org.apache.spark.sql.types.StructField; import org.apache.spark.sql.types.StructType; +import org.apache.spark.sql.vectorized.ColumnarBatch; import org.slf4j.Logger; import org.slf4j.LoggerFactory; -class Reader implements DataSourceReader, SupportsPushDownFilters, SupportsPushDownRequiredColumns, - SupportsReportStatistics { +class Reader implements DataSourceReader, SupportsScanColumnarBatch, SupportsPushDownFilters, + SupportsPushDownRequiredColumns, SupportsReportStatistics { private static final Logger LOG = LoggerFactory.getLogger(Reader.class); private static final Filter[] NO_FILTERS = new Filter[0]; @@ -87,14 +92,17 @@ class Reader implements DataSourceReader, SupportsPushDownFilters, SupportsPushD private List filterExpressions = null; private Filter[] pushedFilters = NO_FILTERS; private final boolean localityPreferred; + private final boolean batchReadsEnabled; + private final int batchSize; // lazy variables private Schema schema = null; private StructType type = null; // cached because Spark accesses it multiple times private List tasks = null; // lazy cache of tasks + private Boolean readUsingBatch = null; Reader(Table table, Broadcast io, Broadcast encryptionManager, - boolean caseSensitive, DataSourceOptions options) { + boolean caseSensitive, DataSourceOptions options) { this.table = table; this.snapshotId = options.get("snapshot-id").map(Long::parseLong).orElse(null); this.asOfTimestamp = options.get("as-of-timestamp").map(Long::parseLong).orElse(null); @@ -145,6 +153,13 @@ class Reader implements DataSourceReader, SupportsPushDownFilters, SupportsPushD this.io = io; this.encryptionManager = encryptionManager; this.caseSensitive = caseSensitive; + + this.batchReadsEnabled = 
options.get("vectorization-enabled").map(Boolean::parseBoolean).orElse( + PropertyUtil.propertyAsBoolean(table.properties(), + TableProperties.PARQUET_VECTORIZATION_ENABLED, TableProperties.PARQUET_VECTORIZATION_ENABLED_DEFAULT)); + this.batchSize = options.get("batch-size").map(Integer::parseInt).orElse( + PropertyUtil.propertyAsInt(table.properties(), + TableProperties.PARQUET_BATCH_SIZE, TableProperties.PARQUET_BATCH_SIZE_DEFAULT)); } private Schema lazySchema() { @@ -178,6 +193,30 @@ public StructType readSchema() { return lazyType(); } + /** + * This is called in the Spark Driver when data is to be materialized into {@link ColumnarBatch} + */ + @Override + public List> planBatchInputPartitions() { + Preconditions.checkState(enableBatchRead(), "Batched reads not enabled"); + Preconditions.checkState(batchSize > 0, "Invalid batch size"); + String tableSchemaString = SchemaParser.toJson(table.schema()); + String expectedSchemaString = SchemaParser.toJson(lazySchema()); + + List> readTasks = Lists.newArrayList(); + for (CombinedScanTask task : tasks()) { + readTasks.add(new ReadTask<>( + task, tableSchemaString, expectedSchemaString, io, encryptionManager, caseSensitive, localityPreferred, + new BatchReaderFactory(batchSize))); + } + LOG.info("Batching input partitions with {} tasks.", readTasks.size()); + + return readTasks; + } + + /** + * This is called in the Spark Driver when data is to be materialized into {@link InternalRow} + */ @Override public List> planInputPartitions() { String tableSchemaString = SchemaParser.toJson(table.schema()); @@ -185,9 +224,9 @@ public List> planInputPartitions() { List> readTasks = Lists.newArrayList(); for (CombinedScanTask task : tasks()) { - readTasks.add( - new ReadTask(task, tableSchemaString, expectedSchemaString, io, encryptionManager, - caseSensitive, localityPreferred)); + readTasks.add(new ReadTask<>( + task, tableSchemaString, expectedSchemaString, io, encryptionManager, caseSensitive, localityPreferred, + InternalRowReaderFactory.INSTANCE)); } return readTasks; @@ -249,6 +288,31 @@ public Statistics estimateStatistics() { return new Stats(sizeInBytes, numRows); } + @Override + public boolean enableBatchRead() { + if (readUsingBatch == null) { + boolean allParquetFileScanTasks = + tasks().stream() + .allMatch(combinedScanTask -> !combinedScanTask.isDataTask() && combinedScanTask.files() + .stream() + .allMatch(fileScanTask -> fileScanTask.file().format().equals( + FileFormat.PARQUET))); + + boolean atLeastOneColumn = lazySchema().columns().size() > 0; + + boolean hasNoIdentityProjections = tasks().stream() + .allMatch(combinedScanTask -> combinedScanTask.files() + .stream() + .allMatch(fileScanTask -> fileScanTask.spec().identitySourceIds().isEmpty())); + + boolean onlyPrimitives = lazySchema().columns().stream().allMatch(c -> c.type().isPrimitiveType()); + + this.readUsingBatch = batchReadsEnabled && allParquetFileScanTasks && atLeastOneColumn && + hasNoIdentityProjections && onlyPrimitives; + } + return readUsingBatch; + } + private static void mergeIcebergHadoopConfs( Configuration baseConf, Map options) { options.keySet().stream() @@ -299,7 +363,7 @@ private List tasks() { try (CloseableIterable tasksIterable = scan.planTasks()) { this.tasks = Lists.newArrayList(tasksIterable); - } catch (IOException e) { + } catch (IOException e) { throw new RuntimeIOException(e, "Failed to close table scan: %s", scan); } } @@ -310,11 +374,11 @@ private List tasks() { @Override public String toString() { return String.format( - "IcebergScan(table=%s, 
type=%s, filters=%s, caseSensitive=%s)", - table, lazySchema().asStruct(), filterExpressions, caseSensitive); + "IcebergScan(table=%s, type=%s, filters=%s, caseSensitive=%s, batchedReads=%s)", + table, lazySchema().asStruct(), filterExpressions, caseSensitive, enableBatchRead()); } - private static class ReadTask implements InputPartition, Serializable { + private static class ReadTask implements Serializable, InputPartition { private final CombinedScanTask task; private final String tableSchemaString; private final String expectedSchemaString; @@ -322,6 +386,7 @@ private static class ReadTask implements InputPartition, Serializab private final Broadcast encryptionManager; private final boolean caseSensitive; private final boolean localityPreferred; + private final ReaderFactory readerFactory; private transient Schema tableSchema = null; private transient Schema expectedSchema = null; @@ -329,7 +394,7 @@ private static class ReadTask implements InputPartition, Serializab private ReadTask(CombinedScanTask task, String tableSchemaString, String expectedSchemaString, Broadcast io, Broadcast encryptionManager, - boolean caseSensitive, boolean localityPreferred) { + boolean caseSensitive, boolean localityPreferred, ReaderFactory readerFactory) { this.task = task; this.tableSchemaString = tableSchemaString; this.expectedSchemaString = expectedSchemaString; @@ -338,12 +403,13 @@ private ReadTask(CombinedScanTask task, String tableSchemaString, String expecte this.caseSensitive = caseSensitive; this.localityPreferred = localityPreferred; this.preferredLocations = getPreferredLocations(); + this.readerFactory = readerFactory; } @Override - public InputPartitionReader createPartitionReader() { - return new RowDataReader(task, lazyTableSchema(), lazyExpectedSchema(), io.value(), - encryptionManager.value(), caseSensitive); + public InputPartitionReader createPartitionReader() { + return readerFactory.create(task, lazyTableSchema(), lazyExpectedSchema(), io.value(), + encryptionManager.value(), caseSensitive); } @Override @@ -375,6 +441,40 @@ private String[] getPreferredLocations() { } } + private interface ReaderFactory extends Serializable { + InputPartitionReader create(CombinedScanTask task, Schema tableSchema, Schema expectedSchema, FileIO io, + EncryptionManager encryptionManager, boolean caseSensitive); + } + + private static class InternalRowReaderFactory implements ReaderFactory { + private static final InternalRowReaderFactory INSTANCE = new InternalRowReaderFactory(); + + private InternalRowReaderFactory() { + } + + @Override + public InputPartitionReader create(CombinedScanTask task, Schema tableSchema, Schema expectedSchema, + FileIO io, EncryptionManager encryptionManager, + boolean caseSensitive) { + return new RowDataReader(task, tableSchema, expectedSchema, io, encryptionManager, caseSensitive); + } + } + + private static class BatchReaderFactory implements ReaderFactory { + private final int batchSize; + + BatchReaderFactory(int batchSize) { + this.batchSize = batchSize; + } + + @Override + public InputPartitionReader create(CombinedScanTask task, Schema tableSchema, Schema expectedSchema, + FileIO io, EncryptionManager encryptionManager, + boolean caseSensitive) { + return new BatchDataReader(task, expectedSchema, io, encryptionManager, caseSensitive, batchSize); + } + } + private static class StructLikeInternalRow implements StructLike { private final DataType[] types; private InternalRow row = null; diff --git a/spark/src/test/java/org/apache/iceberg/spark/data/AvroDataTest.java 
b/spark/src/test/java/org/apache/iceberg/spark/data/AvroDataTest.java index 57e61efd6afa..966a0e656dd3 100644 --- a/spark/src/test/java/org/apache/iceberg/spark/data/AvroDataTest.java +++ b/spark/src/test/java/org/apache/iceberg/spark/data/AvroDataTest.java @@ -22,6 +22,7 @@ import java.io.IOException; import java.util.concurrent.atomic.AtomicInteger; import org.apache.iceberg.Schema; +import org.apache.iceberg.relocated.com.google.common.collect.Lists; import org.apache.iceberg.types.TypeUtil; import org.apache.iceberg.types.Types; import org.apache.iceberg.types.Types.ListType; @@ -66,6 +67,23 @@ public void testSimpleStruct() throws IOException { writeAndValidate(TypeUtil.assignIncreasingFreshIds(new Schema(SUPPORTED_PRIMITIVES.fields()))); } + @Test + public void testStructWithRequiredFields() throws IOException { + writeAndValidate(TypeUtil.assignIncreasingFreshIds(new Schema( + Lists.transform(SUPPORTED_PRIMITIVES.fields(), Types.NestedField::asRequired)))); + } + + @Test + public void testStructWithOptionalFields() throws IOException { + writeAndValidate(TypeUtil.assignIncreasingFreshIds(new Schema( + Lists.transform(SUPPORTED_PRIMITIVES.fields(), Types.NestedField::asOptional)))); + } + + @Test + public void testNestedStruct() throws IOException { + writeAndValidate(TypeUtil.assignIncreasingFreshIds(new Schema(required(1, "struct", SUPPORTED_PRIMITIVES)))); + } + @Test public void testArray() throws IOException { Schema schema = new Schema( diff --git a/spark/src/test/java/org/apache/iceberg/spark/data/RandomData.java b/spark/src/test/java/org/apache/iceberg/spark/data/RandomData.java index b5f0b7153b7a..f99c0fccb89c 100644 --- a/spark/src/test/java/org/apache/iceberg/spark/data/RandomData.java +++ b/spark/src/test/java/org/apache/iceberg/spark/data/RandomData.java @@ -20,7 +20,9 @@ package org.apache.iceberg.spark.data; import java.math.BigDecimal; +import java.math.BigInteger; import java.nio.ByteBuffer; +import java.util.Arrays; import java.util.Iterator; import java.util.List; import java.util.Map; @@ -33,6 +35,7 @@ import org.apache.avro.generic.GenericData.Record; import org.apache.iceberg.Schema; import org.apache.iceberg.avro.AvroSchemaUtil; +import org.apache.iceberg.relocated.com.google.common.base.Preconditions; import org.apache.iceberg.relocated.com.google.common.collect.Lists; import org.apache.iceberg.relocated.com.google.common.collect.Maps; import org.apache.iceberg.relocated.com.google.common.collect.Sets; @@ -49,10 +52,13 @@ public class RandomData { + // Default percentage of number of values that are null for optional fields + public static final float DEFAULT_NULL_PERCENTAGE = 0.05f; + private RandomData() {} public static List generateList(Schema schema, int numRecords, long seed) { - RandomDataGenerator generator = new RandomDataGenerator(schema, seed); + RandomDataGenerator generator = new RandomDataGenerator(schema, seed, DEFAULT_NULL_PERCENTAGE); List records = Lists.newArrayListWithExpectedSize(numRecords); for (int i = 0; i < numRecords; i += 1) { records.add((Record) TypeUtil.visit(schema, generator)); @@ -83,9 +89,27 @@ public InternalRow next() { } public static Iterable generate(Schema schema, int numRecords, long seed) { + return newIterable(() -> new RandomDataGenerator(schema, seed, DEFAULT_NULL_PERCENTAGE), schema, numRecords); + } + + public static Iterable generate(Schema schema, int numRecords, long seed, float nullPercentage) { + return newIterable(() -> new RandomDataGenerator(schema, seed, nullPercentage), schema, numRecords); + } + + 
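+ /**
+ * Generates records where roughly the first {@code numDictRecords} rows are drawn from a very small
+ * set of repeated values (so Parquet is likely to dictionary-encode them) and the remaining rows are
+ * fully random, forcing the writer to fall back from dictionary to plain encoding part-way through.
+ */
+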
public static Iterable generateFallbackData(Schema schema, int numRecords, long seed, long numDictRecords) { + return newIterable(() -> new FallbackDataGenerator(schema, seed, numDictRecords), schema, numRecords); + } + + public static Iterable generateDictionaryEncodableData( + Schema schema, int numRecords, long seed, float nullPercentage) { + return newIterable(() -> new DictionaryEncodedDataGenerator(schema, seed, nullPercentage), schema, numRecords); + } + + private static Iterable newIterable(Supplier newGenerator, + Schema schema, int numRecords) { return () -> new Iterator() { - private RandomDataGenerator generator = new RandomDataGenerator(schema, seed); private int count = 0; + private RandomDataGenerator generator = newGenerator.get(); @Override public boolean hasNext() { @@ -106,8 +130,14 @@ public Record next() { private static class RandomDataGenerator extends TypeUtil.CustomOrderSchemaVisitor { private final Map typeToSchema; private final Random random; - - private RandomDataGenerator(Schema schema, long seed) { + // Percentage of number of values that are null for optional fields + private final float nullPercentage; + + private RandomDataGenerator(Schema schema, long seed, float nullPercentage) { + Preconditions.checkArgument( + 0.0f <= nullPercentage && nullPercentage <= 1.0f, + "Percentage needs to be in the range (0.0, 1.0)"); + this.nullPercentage = nullPercentage; this.typeToSchema = AvroSchemaUtil.convertTypes(schema.asStruct(), "test"); this.random = new Random(seed); } @@ -131,21 +161,23 @@ public Record struct(Types.StructType struct, Iterable fieldResults) { @Override public Object field(Types.NestedField field, Supplier fieldResult) { - // return null 5% of the time when the value is optional - if (field.isOptional() && random.nextInt(20) == 1) { + if (field.isOptional() && isNull()) { return null; } return fieldResult.get(); } + private boolean isNull() { + return random.nextFloat() < nullPercentage; + } + @Override public Object list(Types.ListType list, Supplier elementResult) { int numElements = random.nextInt(20); List result = Lists.newArrayListWithExpectedSize(numElements); for (int i = 0; i < numElements; i += 1) { - // return null 5% of the time when the value is optional - if (list.isElementOptional() && random.nextInt(20) == 1) { + if (list.isElementOptional() && isNull()) { result.add(null); } else { result.add(elementResult.get()); @@ -170,8 +202,7 @@ public Object map(Types.MapType map, Supplier keyResult, Supplier keyResult, Supplier { @@ -295,4 +330,71 @@ public Object primitive(Type.PrimitiveType primitive) { } } } + + private static Object generateDictionaryEncodablePrimitive(Type.PrimitiveType primitive, Random random) { + int value = random.nextInt(3); + switch (primitive.typeId()) { + case BOOLEAN: + return true; // doesn't really matter for booleans since they are not dictionary encoded + case INTEGER: + case DATE: + return value; + case FLOAT: + return (float) value; + case DOUBLE: + return (double) value; + case LONG: + case TIME: + case TIMESTAMP: + return (long) value; + case STRING: + return String.valueOf(value); + case FIXED: + byte[] fixed = new byte[((Types.FixedType) primitive).length()]; + Arrays.fill(fixed, (byte) value); + return fixed; + case BINARY: + byte[] binary = new byte[value + 1]; + Arrays.fill(binary, (byte) value); + return binary; + case DECIMAL: + Types.DecimalType type = (Types.DecimalType) primitive; + BigInteger unscaled = new BigInteger(String.valueOf(value + 1)); + return new BigDecimal(unscaled, 
type.scale()); + default: + throw new IllegalArgumentException( + "Cannot generate random value for unknown type: " + primitive); + } + } + + private static class DictionaryEncodedDataGenerator extends RandomDataGenerator { + private DictionaryEncodedDataGenerator(Schema schema, long seed, float nullPercentage) { + super(schema, seed, nullPercentage); + } + + @Override + protected Object randomValue(Type.PrimitiveType primitive, Random random) { + return generateDictionaryEncodablePrimitive(primitive, random); + } + } + + private static class FallbackDataGenerator extends RandomDataGenerator { + private final long dictionaryEncodedRows; + private long rowCount = 0; + + private FallbackDataGenerator(Schema schema, long seed, long numDictionaryEncoded) { + super(schema, seed, DEFAULT_NULL_PERCENTAGE); + this.dictionaryEncodedRows = numDictionaryEncoded; + } + + @Override + protected Object randomValue(Type.PrimitiveType primitive, Random rand) { + this.rowCount += 1; + if (rowCount > dictionaryEncodedRows) { + return RandomUtil.generatePrimitive(primitive, rand); + } else { + return generateDictionaryEncodablePrimitive(primitive, rand); + } + } + } } diff --git a/spark/src/test/java/org/apache/iceberg/spark/data/TestHelpers.java b/spark/src/test/java/org/apache/iceberg/spark/data/TestHelpers.java index 433f87c75582..f603757c2c44 100644 --- a/spark/src/test/java/org/apache/iceberg/spark/data/TestHelpers.java +++ b/spark/src/test/java/org/apache/iceberg/spark/data/TestHelpers.java @@ -29,13 +29,16 @@ import java.time.temporal.ChronoUnit; import java.util.Collection; import java.util.Date; +import java.util.Iterator; import java.util.List; import java.util.Map; import java.util.UUID; +import org.apache.arrow.vector.ValueVector; import org.apache.avro.generic.GenericData; import org.apache.avro.generic.GenericData.Record; import org.apache.iceberg.Schema; import org.apache.iceberg.relocated.com.google.common.collect.Lists; +import org.apache.iceberg.spark.data.vectorized.IcebergArrowColumnVector; import org.apache.iceberg.types.Type; import org.apache.iceberg.types.Types; import org.apache.orc.storage.serde2.io.DateWritable; @@ -53,6 +56,8 @@ import org.apache.spark.sql.types.MapType; import org.apache.spark.sql.types.StructField; import org.apache.spark.sql.types.StructType; +import org.apache.spark.sql.vectorized.ColumnVector; +import org.apache.spark.sql.vectorized.ColumnarBatch; import org.apache.spark.unsafe.types.UTF8String; import org.junit.Assert; import scala.collection.Seq; @@ -78,6 +83,28 @@ public static void assertEqualsSafe(Types.StructType struct, Record rec, Row row } } + public static void assertEqualsBatch(Types.StructType struct, Iterator expected, ColumnarBatch batch, + boolean checkArrowValidityVector) { + for (int rowId = 0; rowId < batch.numRows(); rowId++) { + List fields = struct.fields(); + InternalRow row = batch.getRow(rowId); + Record rec = expected.next(); + for (int i = 0; i < fields.size(); i += 1) { + Type fieldType = fields.get(i).type(); + Object expectedValue = rec.get(i); + Object actualValue = row.isNullAt(i) ? 
null : row.get(i, convert(fieldType)); + assertEqualsUnsafe(fieldType, expectedValue, actualValue); + + if (checkArrowValidityVector) { + ColumnVector columnVector = batch.column(i); + ValueVector arrowVector = ((IcebergArrowColumnVector) columnVector).vectorAccessor().getVector(); + Assert.assertEquals("Nullability doesn't match", expectedValue == null, arrowVector.isNull(rowId)); + } + } + } + } + + private static void assertEqualsSafe(Types.ListType list, Collection expected, List actual) { Type elementType = list.elementType(); List expectedElements = Lists.newArrayList(expected); @@ -199,7 +226,7 @@ public static void assertEqualsUnsafe(Types.StructType struct, Record rec, Inter Type fieldType = fields.get(i).type(); Object expectedValue = rec.get(i); - Object actualValue = row.get(i, convert(fieldType)); + Object actualValue = row.isNullAt(i) ? null : row.get(i, convert(fieldType)); assertEqualsUnsafe(fieldType, expectedValue, actualValue); } diff --git a/spark/src/test/java/org/apache/iceberg/spark/data/TestParquetAvroReader.java b/spark/src/test/java/org/apache/iceberg/spark/data/TestParquetAvroReader.java index 1466deab2af2..d7bd696ad608 100644 --- a/spark/src/test/java/org/apache/iceberg/spark/data/TestParquetAvroReader.java +++ b/spark/src/test/java/org/apache/iceberg/spark/data/TestParquetAvroReader.java @@ -186,7 +186,7 @@ public void testWithOldReadPath() throws IOException { @Test public void testCorrectness() throws IOException { - Iterable records = RandomData.generate(COMPLEX_SCHEMA, 250_000, 34139); + Iterable records = RandomData.generate(COMPLEX_SCHEMA, 50_000, 34139); File testFile = temp.newFile(); Assert.assertTrue("Delete should succeed", testFile.delete()); diff --git a/spark/src/test/java/org/apache/iceberg/spark/data/TestParquetAvroWriter.java b/spark/src/test/java/org/apache/iceberg/spark/data/TestParquetAvroWriter.java index 0e97c37ffe79..dcfc873a5a67 100644 --- a/spark/src/test/java/org/apache/iceberg/spark/data/TestParquetAvroWriter.java +++ b/spark/src/test/java/org/apache/iceberg/spark/data/TestParquetAvroWriter.java @@ -74,7 +74,7 @@ public class TestParquetAvroWriter { @Test public void testCorrectness() throws IOException { - Iterable records = RandomData.generate(COMPLEX_SCHEMA, 250_000, 34139); + Iterable records = RandomData.generate(COMPLEX_SCHEMA, 50_000, 34139); File testFile = temp.newFile(); Assert.assertTrue("Delete should succeed", testFile.delete()); diff --git a/spark/src/test/java/org/apache/iceberg/spark/data/TestSparkParquetWriter.java b/spark/src/test/java/org/apache/iceberg/spark/data/TestSparkParquetWriter.java index 4ff784448e80..c75a87abc45c 100644 --- a/spark/src/test/java/org/apache/iceberg/spark/data/TestSparkParquetWriter.java +++ b/spark/src/test/java/org/apache/iceberg/spark/data/TestSparkParquetWriter.java @@ -71,7 +71,7 @@ public class TestSparkParquetWriter { @Test public void testCorrectness() throws IOException { - int numRows = 250_000; + int numRows = 50_000; Iterable records = RandomData.generateSpark(COMPLEX_SCHEMA, numRows, 19981); File testFile = temp.newFile(); diff --git a/spark/src/test/java/org/apache/iceberg/spark/data/parquet/vectorized/TestParquetDictionaryEncodedVectorizedReads.java b/spark/src/test/java/org/apache/iceberg/spark/data/parquet/vectorized/TestParquetDictionaryEncodedVectorizedReads.java new file mode 100644 index 000000000000..7f2d9c32cac8 --- /dev/null +++ b/spark/src/test/java/org/apache/iceberg/spark/data/parquet/vectorized/TestParquetDictionaryEncodedVectorizedReads.java @@ -0,0 +1,42 @@ +/* 
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.iceberg.spark.data.parquet.vectorized;
+
+import java.io.IOException;
+import org.apache.avro.generic.GenericData;
+import org.apache.iceberg.Schema;
+import org.apache.iceberg.spark.data.RandomData;
+import org.junit.Ignore;
+import org.junit.Test;
+
+public class TestParquetDictionaryEncodedVectorizedReads extends TestParquetVectorizedReads {
+
+  @Override
+  Iterable<GenericData.Record> generateData(Schema schema, int numRecords, long seed, float nullPercentage) {
+    return RandomData.generateDictionaryEncodableData(schema, numRecords, seed, nullPercentage);
+  }
+
+  @Test
+  @Override
+  @Ignore // Ignored since this code path is already tested in TestParquetVectorizedReads
+  public void testVectorizedReadsWithNewContainers() throws IOException {
+
+  }
+}
diff --git a/spark/src/test/java/org/apache/iceberg/spark/data/parquet/vectorized/TestParquetDictionaryFallbackToPlainEncodingVectorizedReads.java b/spark/src/test/java/org/apache/iceberg/spark/data/parquet/vectorized/TestParquetDictionaryFallbackToPlainEncodingVectorizedReads.java
new file mode 100644
index 000000000000..ad9d020c74f4
--- /dev/null
+++ b/spark/src/test/java/org/apache/iceberg/spark/data/parquet/vectorized/TestParquetDictionaryFallbackToPlainEncodingVectorizedReads.java
@@ -0,0 +1,70 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.iceberg.spark.data.parquet.vectorized;
+
+import java.io.File;
+import java.io.IOException;
+import org.apache.avro.generic.GenericData;
+import org.apache.iceberg.Files;
+import org.apache.iceberg.Schema;
+import org.apache.iceberg.TableProperties;
+import org.apache.iceberg.io.FileAppender;
+import org.apache.iceberg.parquet.Parquet;
+import org.apache.iceberg.spark.data.RandomData;
+import org.junit.Ignore;
+import org.junit.Test;
+
+public class TestParquetDictionaryFallbackToPlainEncodingVectorizedReads extends TestParquetVectorizedReads {
+  private static final int NUM_ROWS = 1_000_000;
+
+  @Override
+  protected int getNumRows() {
+    return NUM_ROWS;
+  }
+
+  @Override
+  Iterable<GenericData.Record> generateData(Schema schema, int numRecords, long seed, float nullPercentage) {
+    // TODO: take into account nullPercentage when generating fallback encoding data
+    return RandomData.generateFallbackData(schema, numRecords, seed, numRecords / 20);
+  }
+
+  @Override
+  FileAppender<GenericData.Record> getParquetWriter(Schema schema, File testFile) throws IOException {
+    return Parquet.write(Files.localOutput(testFile))
+        .schema(schema)
+        .named("test")
+        .set(TableProperties.PARQUET_DICT_SIZE_BYTES, "512000")
+        .build();
+  }
+
+  @Test
+  @Override
+  @Ignore // Fallback encoding not triggered when data is mostly null
+  public void testMostlyNullsForOptionalFields() {
+
+  }
+
+  @Test
+  @Override
+  @Ignore // Ignored since this code path is already tested in TestParquetVectorizedReads
+  public void testVectorizedReadsWithNewContainers() throws IOException {
+
+  }
+}
diff --git a/spark/src/test/java/org/apache/iceberg/spark/data/parquet/vectorized/TestParquetVectorizedReads.java b/spark/src/test/java/org/apache/iceberg/spark/data/parquet/vectorized/TestParquetVectorizedReads.java
new file mode 100644
index 000000000000..3e4f5f95c57e
--- /dev/null
+++ b/spark/src/test/java/org/apache/iceberg/spark/data/parquet/vectorized/TestParquetVectorizedReads.java
@@ -0,0 +1,196 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.iceberg.spark.data.parquet.vectorized;
+
+import java.io.File;
+import java.io.IOException;
+import java.util.Iterator;
+import org.apache.avro.generic.GenericData;
+import org.apache.iceberg.AssertHelpers;
+import org.apache.iceberg.Files;
+import org.apache.iceberg.Schema;
+import org.apache.iceberg.io.CloseableIterable;
+import org.apache.iceberg.io.FileAppender;
+import org.apache.iceberg.parquet.Parquet;
+import org.apache.iceberg.relocated.com.google.common.collect.Lists;
+import org.apache.iceberg.spark.data.AvroDataTest;
+import org.apache.iceberg.spark.data.RandomData;
+import org.apache.iceberg.spark.data.TestHelpers;
+import org.apache.iceberg.spark.data.vectorized.VectorizedSparkParquetReaders;
+import org.apache.iceberg.types.TypeUtil;
+import org.apache.iceberg.types.Types;
+import org.apache.parquet.schema.GroupType;
+import org.apache.parquet.schema.MessageType;
+import org.apache.parquet.schema.Type;
+import org.apache.spark.sql.vectorized.ColumnarBatch;
+import org.junit.Assert;
+import org.junit.Assume;
+import org.junit.Ignore;
+import org.junit.Test;
+
+import static org.apache.iceberg.types.Types.NestedField.required;
+
+public class TestParquetVectorizedReads extends AvroDataTest {
+  private static final int NUM_ROWS = 200_000;
+
+  @Override
+  protected void writeAndValidate(Schema schema) throws IOException {
+    writeAndValidate(schema, getNumRows(), 0L, RandomData.DEFAULT_NULL_PERCENTAGE, false, true);
+  }
+
+  private void writeAndValidate(
+      Schema schema, int numRecords, long seed, float nullPercentage,
+      boolean setAndCheckArrowValidityVector, boolean reuseContainers)
+      throws IOException {
+    // Write test data
+    Assume.assumeTrue("Parquet Avro cannot write non-string map keys", null == TypeUtil.find(
+        schema,
+        type -> type.isMapType() && type.asMapType().keyType() != Types.StringType.get()));
+
+    Iterable<GenericData.Record> expected = generateData(schema, numRecords, seed, nullPercentage);
+
+    // write a test parquet file using iceberg writer
+    File testFile = temp.newFile();
+    Assert.assertTrue("Delete should succeed", testFile.delete());
+
+    try (FileAppender<GenericData.Record> writer = getParquetWriter(schema, testFile)) {
+      writer.addAll(expected);
+    }
+    assertRecordsMatch(schema, numRecords, expected, testFile, setAndCheckArrowValidityVector, reuseContainers);
+  }
+
+  protected int getNumRows() {
+    return NUM_ROWS;
+  }
+
+  Iterable<GenericData.Record> generateData(Schema schema, int numRecords, long seed, float nullPercentage) {
+    return RandomData.generate(schema, numRecords, seed, nullPercentage);
+  }
+
+  FileAppender<GenericData.Record> getParquetWriter(Schema schema, File testFile) throws IOException {
+    return Parquet.write(Files.localOutput(testFile))
+        .schema(schema)
+        .named("test")
+        .build();
+  }
+
+  private void assertRecordsMatch(
+      Schema schema, int expectedSize, Iterable<GenericData.Record> expected, File testFile,
+      boolean setAndCheckArrowValidityBuffer, boolean reuseContainers)
+      throws IOException {
+    Parquet.ReadBuilder readBuilder = Parquet.read(Files.localInput(testFile))
+        .project(schema)
+        .recordsPerBatch(10000)
+        .createBatchedReaderFunc(type -> VectorizedSparkParquetReaders.buildReader(
+            schema,
+            type,
+            setAndCheckArrowValidityBuffer));
+    if (reuseContainers) {
+      readBuilder.reuseContainers();
+    }
+    try (CloseableIterable<ColumnarBatch> batchReader =
+        readBuilder.build()) {
+      Iterator<GenericData.Record> expectedIter = expected.iterator();
+      Iterator<ColumnarBatch> batches = batchReader.iterator();
+      int numRowsRead = 0;
+      while (batches.hasNext()) {
+        ColumnarBatch batch = batches.next();
+        numRowsRead += batch.numRows();
+        TestHelpers.assertEqualsBatch(schema.asStruct(), expectedIter, batch, setAndCheckArrowValidityBuffer);
+      }
+      Assert.assertEquals(expectedSize, numRowsRead);
+    }
+  }
+
+  @Test
+  @Ignore
+  public void testArray() {
+  }
+
+  @Test
+  @Ignore
+  public void testArrayOfStructs() {
+  }
+
+  @Test
+  @Ignore
+  public void testMap() {
+  }
+
+  @Test
+  @Ignore
+  public void testNumericMapKey() {
+  }
+
+  @Test
+  @Ignore
+  public void testComplexMapKey() {
+  }
+
+  @Test
+  @Ignore
+  public void testMapOfStructs() {
+  }
+
+  @Test
+  @Ignore
+  public void testMixedTypes() {
+  }
+
+  @Test
+  @Override
+  public void testNestedStruct() {
+    AssertHelpers.assertThrows(
+        "Vectorized reads are not supported yet for struct fields",
+        UnsupportedOperationException.class,
+        "Vectorized reads are not supported yet for struct fields",
+        () -> VectorizedSparkParquetReaders.buildReader(
+            TypeUtil.assignIncreasingFreshIds(new Schema(required(
+                1,
+                "struct",
+                SUPPORTED_PRIMITIVES))),
+            new MessageType("struct", new GroupType(Type.Repetition.OPTIONAL, "struct").withId(1)),
+            false));
+  }
+
+  @Test
+  public void testMostlyNullsForOptionalFields() throws IOException {
+    writeAndValidate(
+        TypeUtil.assignIncreasingFreshIds(new Schema(SUPPORTED_PRIMITIVES.fields())),
+        getNumRows(),
+        0L,
+        0.99f,
+        false,
+        true);
+  }
+
+  @Test
+  public void testSettingArrowValidityVector() throws IOException {
+    writeAndValidate(new Schema(
+        Lists.transform(SUPPORTED_PRIMITIVES.fields(), Types.NestedField::asOptional)),
+        getNumRows(), 0L, RandomData.DEFAULT_NULL_PERCENTAGE, true, true);
+  }
+
+  @Test
+  public void testVectorizedReadsWithNewContainers() throws IOException {
+    writeAndValidate(TypeUtil.assignIncreasingFreshIds(new Schema(SUPPORTED_PRIMITIVES.fields())),
+        getNumRows(), 0L, RandomData.DEFAULT_NULL_PERCENTAGE, true, false);
+  }
+}
diff --git a/spark/src/test/java/org/apache/iceberg/spark/source/TestParquetScan.java b/spark/src/test/java/org/apache/iceberg/spark/source/TestParquetScan.java
index fcecb17c5b4d..f171fdb5766c 100644
--- a/spark/src/test/java/org/apache/iceberg/spark/source/TestParquetScan.java
+++ b/spark/src/test/java/org/apache/iceberg/spark/source/TestParquetScan.java
@@ -31,6 +31,7 @@
 import org.apache.iceberg.PartitionSpec;
 import org.apache.iceberg.Schema;
 import org.apache.iceberg.Table;
+import org.apache.iceberg.TableProperties;
 import org.apache.iceberg.hadoop.HadoopTables;
 import org.apache.iceberg.io.FileAppender;
 import org.apache.iceberg.parquet.Parquet;
@@ -48,15 +49,15 @@
 import org.junit.BeforeClass;
 import org.junit.Rule;
 import org.junit.rules.TemporaryFolder;
+import org.junit.runner.RunWith;
+import org.junit.runners.Parameterized;
 
 import static org.apache.iceberg.Files.localOutput;
 
+@RunWith(Parameterized.class)
 public class TestParquetScan extends AvroDataTest {
   private static final Configuration CONF = new Configuration();
 
-  @Rule
-  public TemporaryFolder temp = new TemporaryFolder();
-
   private static SparkSession spark = null;
 
   @BeforeClass
@@ -71,6 +72,23 @@ public static void stopSpark() {
     currentSpark.stop();
   }
 
+  @Rule
+  public TemporaryFolder temp = new TemporaryFolder();
+
+  @Parameterized.Parameters
+  public static Object[][] parameters() {
+    return new Object[][] {
+        new Object[] { false },
+        new Object[] { true },
+    };
+  }
+
+  private final boolean vectorized;
+
+  public TestParquetScan(boolean vectorized) {
+    this.vectorized = vectorized;
+  }
+
   @Override
   protected void writeAndValidate(Schema schema) throws IOException {
     Assume.assumeTrue("Cannot handle non-string map keys in parquet-avro",
parquet-avro", @@ -108,6 +126,7 @@ protected void writeAndValidate(Schema schema) throws IOException { .build(); table.newAppend().appendFile(file).commit(); + table.updateProperties().set(TableProperties.PARQUET_VECTORIZATION_ENABLED, String.valueOf(vectorized)).commit(); Dataset df = spark.read() .format("iceberg") diff --git a/spark/src/test/java/org/apache/iceberg/spark/source/TestSparkReadProjection.java b/spark/src/test/java/org/apache/iceberg/spark/source/TestSparkReadProjection.java index 41b00918e18c..8bb951818258 100644 --- a/spark/src/test/java/org/apache/iceberg/spark/source/TestSparkReadProjection.java +++ b/spark/src/test/java/org/apache/iceberg/spark/source/TestSparkReadProjection.java @@ -31,6 +31,7 @@ import org.apache.iceberg.PartitionSpec; import org.apache.iceberg.Schema; import org.apache.iceberg.Table; +import org.apache.iceberg.TableProperties; import org.apache.iceberg.avro.Avro; import org.apache.iceberg.data.Record; import org.apache.iceberg.data.avro.DataWriter; @@ -66,14 +67,20 @@ public class TestSparkReadProjection extends TestReadProjection { @Parameterized.Parameters public static Object[][] parameters() { return new Object[][] { - new Object[] { "parquet" }, - new Object[] { "avro" }, - new Object[] { "orc" } + new Object[] { "parquet", false }, + new Object[] { "parquet", true }, + new Object[] { "avro", false }, + new Object[] { "orc", false } }; } - public TestSparkReadProjection(String format) { + private final FileFormat format; + private final boolean vectorized; + + public TestSparkReadProjection(String format, boolean vectorized) { super(format); + this.format = FileFormat.valueOf(format.toUpperCase(Locale.ROOT)); + this.vectorized = vectorized; } @BeforeClass @@ -96,9 +103,7 @@ protected Record writeAndRead(String desc, Schema writeSchema, Schema readSchema File dataFolder = new File(location, "data"); Assert.assertTrue("mkdirs should succeed", dataFolder.mkdirs()); - FileFormat fileFormat = FileFormat.valueOf(format.toUpperCase(Locale.ENGLISH)); - - File testFile = new File(dataFolder, fileFormat.addExtension(UUID.randomUUID().toString())); + File testFile = new File(dataFolder, format.addExtension(UUID.randomUUID().toString())); Table table = TestTables.create(location, desc, writeSchema, PartitionSpec.unpartitioned()); try { @@ -106,7 +111,7 @@ protected Record writeAndRead(String desc, Schema writeSchema, Schema readSchema // When tables are created, the column ids are reassigned. Schema tableSchema = table.schema(); - switch (fileFormat) { + switch (format) { case AVRO: try (FileAppender writer = Avro.write(localOutput(testFile)) .createWriterFunc(DataWriter::create) @@ -143,6 +148,8 @@ protected Record writeAndRead(String desc, Schema writeSchema, Schema readSchema table.newAppend().appendFile(file).commit(); + table.updateProperties().set(TableProperties.PARQUET_VECTORIZATION_ENABLED, String.valueOf(vectorized)).commit(); + // rewrite the read schema for the table's reassigned ids Map idMapping = Maps.newHashMap(); for (int id : allIds(writeSchema)) {