diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/ParquetReadState.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/ParquetReadState.java
index 28dcc44b28ca..b26088753465 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/ParquetReadState.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/ParquetReadState.java
@@ -17,13 +17,38 @@
 package org.apache.spark.sql.execution.datasources.parquet;
 
+import java.util.ArrayList;
+import java.util.Iterator;
+import java.util.List;
+import java.util.PrimitiveIterator;
+
 /**
  * Helper class to store intermediate state while reading a Parquet column chunk.
  */
 final class ParquetReadState {
-  /** Maximum definition level */
+  /** A special row range used when there are no row indexes (hence all rows must be included) */
+  private static final RowRange MAX_ROW_RANGE = new RowRange(Long.MIN_VALUE, Long.MAX_VALUE);
+
+  /**
+   * A special row range used when the row indexes are present AND all the row ranges have been
+   * processed. This serves as a sentinel at the end indicating that all rows that come after
+   * the last row range should be skipped.
+   */
+  private static final RowRange END_ROW_RANGE = new RowRange(Long.MAX_VALUE, Long.MIN_VALUE);
+
+  /** Iterator over all row ranges, only non-null if the column index is present */
+  private final Iterator<RowRange> rowRanges;
+
+  /** The current row range */
+  private RowRange currentRange;
+
+  /** Maximum definition level for the Parquet column */
   final int maxDefinitionLevel;
 
+  /** The current index over all rows within the column chunk. This is used to check if the
+   * current row should be skipped by comparing against the row ranges. */
+  long rowId;
+
   /** The offset in the current batch to put the next value */
   int offset;
 
@@ -33,31 +58,108 @@ final class ParquetReadState {
   /** The remaining number of values to read in the current batch */
   int valuesToReadInBatch;
 
-  ParquetReadState(int maxDefinitionLevel) {
+  ParquetReadState(int maxDefinitionLevel, PrimitiveIterator.OfLong rowIndexes) {
     this.maxDefinitionLevel = maxDefinitionLevel;
+    this.rowRanges = constructRanges(rowIndexes);
+    nextRange();
   }
 
   /**
-   * Called at the beginning of reading a new batch.
+   * Construct a list of row ranges from the given `rowIndexes`. For example, suppose the
+   * `rowIndexes` are `[0, 1, 2, 4, 5, 7, 8, 9]`; they will be converted into 3 row ranges:
+   * `[0-2], [4-5], [7-9]`.
    */
-  void resetForBatch(int batchSize) {
+  private Iterator<RowRange> constructRanges(PrimitiveIterator.OfLong rowIndexes) {
+    if (rowIndexes == null) {
+      return null;
+    }
+
+    List<RowRange> rowRanges = new ArrayList<>();
+    long currentStart = Long.MIN_VALUE;
+    long previous = Long.MIN_VALUE;
+
+    while (rowIndexes.hasNext()) {
+      long idx = rowIndexes.nextLong();
+      if (currentStart == Long.MIN_VALUE) {
+        currentStart = idx;
+      } else if (previous + 1 != idx) {
+        RowRange range = new RowRange(currentStart, previous);
+        rowRanges.add(range);
+        currentStart = idx;
+      }
+      previous = idx;
+    }
+
+    if (previous != Long.MIN_VALUE) {
+      rowRanges.add(new RowRange(currentStart, previous));
+    }
+
+    return rowRanges.iterator();
+  }
+
+  /**
+   * Must be called at the beginning of reading a new batch.
+   */
+  void resetForNewBatch(int batchSize) {
     this.offset = 0;
     this.valuesToReadInBatch = batchSize;
   }
 
   /**
-   * Called at the beginning of reading a new page.
+   * Must be called at the beginning of reading a new page.
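To make the range-construction logic above concrete, here is a small standalone sketch, not part of the patch, that mirrors how `constructRanges` collapses a sorted iterator of row indexes into contiguous ranges. The `RowRangeSketch` class and its local `RowRange` record are illustrative stand-ins for the private helper struct in the patch, using the `[0, 1, 2, 4, 5, 7, 8, 9]` example from the Javadoc:

```java
import java.util.ArrayList;
import java.util.List;
import java.util.PrimitiveIterator;
import java.util.stream.LongStream;

public class RowRangeSketch {
  record RowRange(long start, long end) {}

  static List<RowRange> toRanges(PrimitiveIterator.OfLong rowIndexes) {
    List<RowRange> ranges = new ArrayList<>();
    long currentStart = Long.MIN_VALUE;
    long previous = Long.MIN_VALUE;
    while (rowIndexes.hasNext()) {
      long idx = rowIndexes.nextLong();
      if (currentStart == Long.MIN_VALUE) {
        currentStart = idx;                                  // first index opens the first range
      } else if (previous + 1 != idx) {
        ranges.add(new RowRange(currentStart, previous));    // gap found: close the current range
        currentStart = idx;
      }
      previous = idx;
    }
    if (previous != Long.MIN_VALUE) {
      ranges.add(new RowRange(currentStart, previous));      // close the trailing range
    }
    return ranges;
  }

  public static void main(String[] args) {
    // [0, 1, 2, 4, 5, 7, 8, 9] collapses into [0-2], [4-5], [7-9]
    PrimitiveIterator.OfLong it = LongStream.of(0, 1, 2, 4, 5, 7, 8, 9).iterator();
    System.out.println(toRanges(it));
  }
}
```

The logic assumes the indexes arrive sorted and strictly increasing, so a range closes exactly when a gap appears between consecutive indexes.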
*/ - void resetForPage(int totalValuesInPage) { + void resetForNewPage(int totalValuesInPage, long pageFirstRowIndex) { this.valuesToReadInPage = totalValuesInPage; + this.rowId = pageFirstRowIndex; } /** - * Advance the current offset to the new values. + * Returns the start index of the current row range. */ - void advanceOffset(int newOffset) { + long currentRangeStart() { + return currentRange.start; + } + + /** + * Returns the end index of the current row range. + */ + long currentRangeEnd() { + return currentRange.end; + } + + /** + * Advance the current offset and rowId to the new values. + */ + void advanceOffsetAndRowId(int newOffset, long newRowId) { valuesToReadInBatch -= (newOffset - offset); - valuesToReadInPage -= (newOffset - offset); + valuesToReadInPage -= (newRowId - rowId); offset = newOffset; + rowId = newRowId; + } + + /** + * Advance to the next range. + */ + void nextRange() { + if (rowRanges == null) { + currentRange = MAX_ROW_RANGE; + } else if (!rowRanges.hasNext()) { + currentRange = END_ROW_RANGE; + } else { + currentRange = rowRanges.next(); + } + } + + /** + * Helper struct to represent a range of row indexes `[start, end]`. + */ + private static class RowRange { + final long start; + final long end; + + RowRange(long start, long end) { + this.start = start; + this.end = end; + } } } diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/ParquetVectorUpdater.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/ParquetVectorUpdater.java index b91d507a3878..9bb852987e65 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/ParquetVectorUpdater.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/ParquetVectorUpdater.java @@ -30,12 +30,20 @@ public interface ParquetVectorUpdater { * @param values destination values vector * @param valuesReader reader to read values from */ - void updateBatch( + void readValues( int total, int offset, WritableColumnVector values, VectorizedValuesReader valuesReader); + /** + * Skip a batch of `total` values from `valuesReader`. + * + * @param total total number of values to skip + * @param valuesReader reader to skip values from + */ + void skipValues(int total, VectorizedValuesReader valuesReader); + /** * Read a single value from `valuesReader` into `values`, at `offset`. 
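Since every concrete updater below repeats the same pattern, a toy example may help show the contract that `readValues`/`skipValues`/`readValue` are expected to satisfy. The `Toy*` types here are invented for illustration only; they are not Spark or Parquet classes:

```java
interface ToyValuesReader {           // stand-in for VectorizedValuesReader
  int readInteger();
  void skipIntegers(int total);
}

interface ToyVector {                 // stand-in for WritableColumnVector
  void putInt(int offset, int value);
}

class ToyIntegerUpdater {
  // Materialize `total` values into the vector starting at `offset`.
  void readValues(int total, int offset, ToyVector values, ToyValuesReader reader) {
    for (int i = 0; i < total; i++) {
      readValue(offset + i, values, reader);
    }
  }

  // Skipping must consume exactly as many encoded values as a read of the same size would,
  // otherwise the value stream and the definition-level stream drift apart for the page.
  void skipValues(int total, ToyValuesReader reader) {
    reader.skipIntegers(total);
  }

  void readValue(int offset, ToyVector values, ToyValuesReader reader) {
    values.putInt(offset, reader.readInteger());
  }
}
```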
* @@ -43,7 +51,7 @@ void updateBatch( * @param values destination value vector * @param valuesReader reader to read values from */ - void update(int offset, WritableColumnVector values, VectorizedValuesReader valuesReader); + void readValue(int offset, WritableColumnVector values, VectorizedValuesReader valuesReader); /** * Process a batch of `total` values starting from `offset` in `values`, whose null slots diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/ParquetVectorUpdaterFactory.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/ParquetVectorUpdaterFactory.java index 62e34fe549f0..2282dc798463 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/ParquetVectorUpdaterFactory.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/ParquetVectorUpdaterFactory.java @@ -185,7 +185,7 @@ boolean isUnsignedIntTypeMatched(int bitWidth) { private static class BooleanUpdater implements ParquetVectorUpdater { @Override - public void updateBatch( + public void readValues( int total, int offset, WritableColumnVector values, @@ -194,7 +194,12 @@ public void updateBatch( } @Override - public void update( + public void skipValues(int total, VectorizedValuesReader valuesReader) { + valuesReader.skipBooleans(total); + } + + @Override + public void readValue( int offset, WritableColumnVector values, VectorizedValuesReader valuesReader) { @@ -213,7 +218,7 @@ public void decodeSingleDictionaryId( private static class IntegerUpdater implements ParquetVectorUpdater { @Override - public void updateBatch( + public void readValues( int total, int offset, WritableColumnVector values, @@ -222,7 +227,12 @@ public void updateBatch( } @Override - public void update( + public void skipValues(int total, VectorizedValuesReader valuesReader) { + valuesReader.skipIntegers(total); + } + + @Override + public void readValue( int offset, WritableColumnVector values, VectorizedValuesReader valuesReader) { @@ -241,7 +251,7 @@ public void decodeSingleDictionaryId( private static class UnsignedIntegerUpdater implements ParquetVectorUpdater { @Override - public void updateBatch( + public void readValues( int total, int offset, WritableColumnVector values, @@ -250,7 +260,12 @@ public void updateBatch( } @Override - public void update( + public void skipValues(int total, VectorizedValuesReader valuesReader) { + valuesReader.skipIntegers(total); + } + + @Override + public void readValue( int offset, WritableColumnVector values, VectorizedValuesReader valuesReader) { @@ -270,7 +285,7 @@ public void decodeSingleDictionaryId( private static class ByteUpdater implements ParquetVectorUpdater { @Override - public void updateBatch( + public void readValues( int total, int offset, WritableColumnVector values, @@ -279,7 +294,12 @@ public void updateBatch( } @Override - public void update( + public void skipValues(int total, VectorizedValuesReader valuesReader) { + valuesReader.skipBytes(total); + } + + @Override + public void readValue( int offset, WritableColumnVector values, VectorizedValuesReader valuesReader) { @@ -298,7 +318,7 @@ public void decodeSingleDictionaryId( private static class ShortUpdater implements ParquetVectorUpdater { @Override - public void updateBatch( + public void readValues( int total, int offset, WritableColumnVector values, @@ -307,7 +327,12 @@ public void updateBatch( } @Override - public void update( + public void skipValues(int total, VectorizedValuesReader valuesReader) { + 
valuesReader.skipShorts(total); + } + + @Override + public void readValue( int offset, WritableColumnVector values, VectorizedValuesReader valuesReader) { @@ -332,7 +357,7 @@ private static class IntegerWithRebaseUpdater implements ParquetVectorUpdater { } @Override - public void updateBatch( + public void readValues( int total, int offset, WritableColumnVector values, @@ -341,7 +366,12 @@ public void updateBatch( } @Override - public void update( + public void skipValues(int total, VectorizedValuesReader valuesReader) { + valuesReader.skipIntegers(total); + } + + @Override + public void readValue( int offset, WritableColumnVector values, VectorizedValuesReader valuesReader) { @@ -362,7 +392,7 @@ public void decodeSingleDictionaryId( private static class LongUpdater implements ParquetVectorUpdater { @Override - public void updateBatch( + public void readValues( int total, int offset, WritableColumnVector values, @@ -371,7 +401,12 @@ public void updateBatch( } @Override - public void update( + public void skipValues(int total, VectorizedValuesReader valuesReader) { + valuesReader.skipLongs(total); + } + + @Override + public void readValue( int offset, WritableColumnVector values, VectorizedValuesReader valuesReader) { @@ -390,7 +425,7 @@ public void decodeSingleDictionaryId( private static class DowncastLongUpdater implements ParquetVectorUpdater { @Override - public void updateBatch( + public void readValues( int total, int offset, WritableColumnVector values, @@ -401,7 +436,12 @@ public void updateBatch( } @Override - public void update( + public void skipValues(int total, VectorizedValuesReader valuesReader) { + valuesReader.skipLongs(total); + } + + @Override + public void readValue( int offset, WritableColumnVector values, VectorizedValuesReader valuesReader) { @@ -420,7 +460,7 @@ public void decodeSingleDictionaryId( private static class UnsignedLongUpdater implements ParquetVectorUpdater { @Override - public void updateBatch( + public void readValues( int total, int offset, WritableColumnVector values, @@ -429,7 +469,12 @@ public void updateBatch( } @Override - public void update( + public void skipValues(int total, VectorizedValuesReader valuesReader) { + valuesReader.skipLongs(total); + } + + @Override + public void readValue( int offset, WritableColumnVector values, VectorizedValuesReader valuesReader) { @@ -457,7 +502,7 @@ private static class LongWithRebaseUpdater implements ParquetVectorUpdater { } @Override - public void updateBatch( + public void readValues( int total, int offset, WritableColumnVector values, @@ -466,7 +511,12 @@ public void updateBatch( } @Override - public void update( + public void skipValues(int total, VectorizedValuesReader valuesReader) { + valuesReader.skipLongs(total); + } + + @Override + public void readValue( int offset, WritableColumnVector values, VectorizedValuesReader valuesReader) { @@ -487,18 +537,23 @@ public void decodeSingleDictionaryId( private static class LongAsMicrosUpdater implements ParquetVectorUpdater { @Override - public void updateBatch( + public void readValues( int total, int offset, WritableColumnVector values, VectorizedValuesReader valuesReader) { for (int i = 0; i < total; ++i) { - update(offset + i, values, valuesReader); + readValue(offset + i, values, valuesReader); } } @Override - public void update( + public void skipValues(int total, VectorizedValuesReader valuesReader) { + valuesReader.skipLongs(total); + } + + @Override + public void readValue( int offset, WritableColumnVector values, VectorizedValuesReader 
valuesReader) { @@ -524,18 +579,23 @@ private static class LongAsMicrosRebaseUpdater implements ParquetVectorUpdater { } @Override - public void updateBatch( + public void readValues( int total, int offset, WritableColumnVector values, VectorizedValuesReader valuesReader) { for (int i = 0; i < total; ++i) { - update(offset + i, values, valuesReader); + readValue(offset + i, values, valuesReader); } } @Override - public void update( + public void skipValues(int total, VectorizedValuesReader valuesReader) { + valuesReader.skipLongs(total); + } + + @Override + public void readValue( int offset, WritableColumnVector values, VectorizedValuesReader valuesReader) { @@ -557,7 +617,7 @@ public void decodeSingleDictionaryId( private static class FloatUpdater implements ParquetVectorUpdater { @Override - public void updateBatch( + public void readValues( int total, int offset, WritableColumnVector values, @@ -566,7 +626,12 @@ public void updateBatch( } @Override - public void update( + public void skipValues(int total, VectorizedValuesReader valuesReader) { + valuesReader.skipFloats(total); + } + + @Override + public void readValue( int offset, WritableColumnVector values, VectorizedValuesReader valuesReader) { @@ -585,7 +650,7 @@ public void decodeSingleDictionaryId( private static class DoubleUpdater implements ParquetVectorUpdater { @Override - public void updateBatch( + public void readValues( int total, int offset, WritableColumnVector values, @@ -594,7 +659,12 @@ public void updateBatch( } @Override - public void update( + public void skipValues(int total, VectorizedValuesReader valuesReader) { + valuesReader.skipDoubles(total); + } + + @Override + public void readValue( int offset, WritableColumnVector values, VectorizedValuesReader valuesReader) { @@ -613,7 +683,7 @@ public void decodeSingleDictionaryId( private static class BinaryUpdater implements ParquetVectorUpdater { @Override - public void updateBatch( + public void readValues( int total, int offset, WritableColumnVector values, @@ -622,7 +692,12 @@ public void updateBatch( } @Override - public void update( + public void skipValues(int total, VectorizedValuesReader valuesReader) { + valuesReader.skipBinary(total); + } + + @Override + public void readValue( int offset, WritableColumnVector values, VectorizedValuesReader valuesReader) { @@ -642,18 +717,23 @@ public void decodeSingleDictionaryId( private static class BinaryToSQLTimestampUpdater implements ParquetVectorUpdater { @Override - public void updateBatch( + public void readValues( int total, int offset, WritableColumnVector values, VectorizedValuesReader valuesReader) { for (int i = 0; i < total; i++) { - update(offset + i, values, valuesReader); + readValue(offset + i, values, valuesReader); } } @Override - public void update( + public void skipValues(int total, VectorizedValuesReader valuesReader) { + valuesReader.skipFixedLenByteArray(total, 12); + } + + @Override + public void readValue( int offset, WritableColumnVector values, VectorizedValuesReader valuesReader) { @@ -681,18 +761,23 @@ private static class BinaryToSQLTimestampConvertTzUpdater implements ParquetVect } @Override - public void updateBatch( + public void readValues( int total, int offset, WritableColumnVector values, VectorizedValuesReader valuesReader) { for (int i = 0; i < total; i++) { - update(offset + i, values, valuesReader); + readValue(offset + i, values, valuesReader); } } @Override - public void update( + public void skipValues(int total, VectorizedValuesReader valuesReader) { + 
valuesReader.skipFixedLenByteArray(total, 12); + } + + @Override + public void readValue( int offset, WritableColumnVector values, VectorizedValuesReader valuesReader) { @@ -723,18 +808,23 @@ private static class BinaryToSQLTimestampRebaseUpdater implements ParquetVectorU } @Override - public void updateBatch( + public void readValues( int total, int offset, WritableColumnVector values, VectorizedValuesReader valuesReader) { for (int i = 0; i < total; i++) { - update(offset + i, values, valuesReader); + readValue(offset + i, values, valuesReader); } } @Override - public void update( + public void skipValues(int total, VectorizedValuesReader valuesReader) { + valuesReader.skipFixedLenByteArray(total, 12); + } + + @Override + public void readValue( int offset, WritableColumnVector values, VectorizedValuesReader valuesReader) { @@ -767,18 +857,23 @@ private static class BinaryToSQLTimestampConvertTzRebaseUpdater implements Parqu } @Override - public void updateBatch( + public void readValues( int total, int offset, WritableColumnVector values, VectorizedValuesReader valuesReader) { for (int i = 0; i < total; i++) { - update(offset + i, values, valuesReader); + readValue(offset + i, values, valuesReader); } } @Override - public void update( + public void skipValues(int total, VectorizedValuesReader valuesReader) { + valuesReader.skipFixedLenByteArray(total, 12); + } + + @Override + public void readValue( int offset, WritableColumnVector values, VectorizedValuesReader valuesReader) { @@ -811,18 +906,23 @@ private static class FixedLenByteArrayUpdater implements ParquetVectorUpdater { } @Override - public void updateBatch( + public void readValues( int total, int offset, WritableColumnVector values, VectorizedValuesReader valuesReader) { for (int i = 0; i < total; i++) { - update(offset + i, values, valuesReader); + readValue(offset + i, values, valuesReader); } } @Override - public void update( + public void skipValues(int total, VectorizedValuesReader valuesReader) { + valuesReader.skipFixedLenByteArray(total, arrayLen); + } + + @Override + public void readValue( int offset, WritableColumnVector values, VectorizedValuesReader valuesReader) { @@ -848,18 +948,23 @@ private static class FixedLenByteArrayAsIntUpdater implements ParquetVectorUpdat } @Override - public void updateBatch( + public void readValues( int total, int offset, WritableColumnVector values, VectorizedValuesReader valuesReader) { for (int i = 0; i < total; i++) { - update(offset + i, values, valuesReader); + readValue(offset + i, values, valuesReader); } } @Override - public void update( + public void skipValues(int total, VectorizedValuesReader valuesReader) { + valuesReader.skipFixedLenByteArray(total, arrayLen); + } + + @Override + public void readValue( int offset, WritableColumnVector values, VectorizedValuesReader valuesReader) { @@ -886,18 +991,23 @@ private static class FixedLenByteArrayAsLongUpdater implements ParquetVectorUpda } @Override - public void updateBatch( + public void readValues( int total, int offset, WritableColumnVector values, VectorizedValuesReader valuesReader) { for (int i = 0; i < total; i++) { - update(offset + i, values, valuesReader); + readValue(offset + i, values, valuesReader); } } @Override - public void update( + public void skipValues(int total, VectorizedValuesReader valuesReader) { + valuesReader.skipFixedLenByteArray(total, arrayLen); + } + + @Override + public void readValue( int offset, WritableColumnVector values, VectorizedValuesReader valuesReader) { diff --git 
a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java index c61ee460880a..92dea08102df 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java @@ -19,6 +19,7 @@ import java.io.IOException; import java.time.ZoneId; +import java.util.PrimitiveIterator; import org.apache.parquet.bytes.ByteBufferInputStream; import org.apache.parquet.bytes.BytesInput; @@ -74,6 +75,12 @@ public class VectorizedColumnReader { */ private final ParquetReadState readState; + /** + * The index for the first row in the current page, among all rows across all pages in the + * column chunk for this reader. If there is no column index, the value is 0. + */ + private long pageFirstRowIndex; + private final PageReader pageReader; private final ColumnDescriptor descriptor; private final LogicalTypeAnnotation logicalTypeAnnotation; @@ -83,12 +90,13 @@ public VectorizedColumnReader( ColumnDescriptor descriptor, LogicalTypeAnnotation logicalTypeAnnotation, PageReader pageReader, + PrimitiveIterator.OfLong rowIndexes, ZoneId convertTz, String datetimeRebaseMode, String int96RebaseMode) throws IOException { this.descriptor = descriptor; this.pageReader = pageReader; - this.readState = new ParquetReadState(descriptor.getMaxDefinitionLevel()); + this.readState = new ParquetReadState(descriptor.getMaxDefinitionLevel(), rowIndexes); this.logicalTypeAnnotation = logicalTypeAnnotation; this.updaterFactory = new ParquetVectorUpdaterFactory( logicalTypeAnnotation, convertTz, datetimeRebaseMode, int96RebaseMode); @@ -151,18 +159,19 @@ void readBatch(int total, WritableColumnVector column) throws IOException { // page. dictionaryIds = column.reserveDictionaryIds(total); } - readState.resetForBatch(total); + readState.resetForNewBatch(total); while (readState.valuesToReadInBatch > 0) { - // Compute the number of values we want to read in this page. if (readState.valuesToReadInPage == 0) { int pageValueCount = readPage(); - readState.resetForPage(pageValueCount); + readState.resetForNewPage(pageValueCount, pageFirstRowIndex); } PrimitiveType.PrimitiveTypeName typeName = descriptor.getPrimitiveType().getPrimitiveTypeName(); if (isCurrentPageDictionaryEncoded) { // Save starting offset in case we need to decode dictionary IDs. int startOffset = readState.offset; + // Save starting row index so we can check if we need to eagerly decode dict ids later + long startRowId = readState.rowId; // Read and decode dictionary ids. defColumn.readIntegers(readState, dictionaryIds, column, @@ -170,10 +179,12 @@ void readBatch(int total, WritableColumnVector column) throws IOException { // TIMESTAMP_MILLIS encoded as INT64 can't be lazily decoded as we need to post process // the values to add microseconds precision. - if (column.hasDictionary() || (startOffset == 0 && isLazyDecodingSupported(typeName))) { + if (column.hasDictionary() || (startRowId == pageFirstRowIndex && + isLazyDecodingSupported(typeName))) { // Column vector supports lazy decoding of dictionary values so just set the dictionary. - // We can't do this if rowId != 0 AND the column doesn't have a dictionary (i.e. some - // non-dictionary encoded values have already been added). 
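The new guard above replaces the old `startOffset == 0` check with a row-index comparison: lazy dictionary decoding (keeping dictionary ids in the vector and attaching the dictionary later) is only safe if either the vector already carries a dictionary, or nothing from this page has been materialized yet. A minimal pure-function restatement of that decision, with hypothetical names and not part of the patch:

```java
final class LazyDictDecodeCheck {
  static boolean canLazilyDecode(
      boolean vectorHasDictionary,
      long startRowId,
      long pageFirstRowIndex,
      boolean lazyDecodingSupportedForType) {
    return vectorHasDictionary
        || (startRowId == pageFirstRowIndex && lazyDecodingSupportedForType);
  }

  public static void main(String[] args) {
    // Decoding started past the page's first row and the vector has no dictionary yet:
    // some values were already materialized, so decode eagerly.
    System.out.println(canLazilyDecode(false, 105L, 100L, true));  // false
    // Decoding starts at the page's first row for a supported type: attach the dictionary lazily.
    System.out.println(canLazilyDecode(false, 100L, 100L, true));  // true
  }
}
```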
+ // We can't do this if startRowId is not the first row index in the page AND the column + // doesn't have a dictionary (i.e. some non-dictionary encoded values have already been + // added). PrimitiveType primitiveType = descriptor.getPrimitiveType(); // We need to make sure that we initialize the right type for the dictionary otherwise @@ -213,6 +224,8 @@ void readBatch(int total, WritableColumnVector column) throws IOException { private int readPage() { DataPage page = pageReader.readPage(); + this.pageFirstRowIndex = page.getFirstRowIndex().orElse(0L); + return page.accept(new DataPage.Visitor() { @Override public Integer visit(DataPageV1 dataPageV1) { @@ -268,7 +281,6 @@ private void initDataReader( } private int readPageV1(DataPageV1 page) throws IOException { - // Initialize the decoders. if (page.getDlEncoding() != Encoding.RLE && descriptor.getMaxDefinitionLevel() != 0) { throw new UnsupportedOperationException("Unsupported encoding: " + page.getDlEncoding()); } diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedParquetRecordReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedParquetRecordReader.java index 32455278c4fb..9f7836ae4818 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedParquetRecordReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedParquetRecordReader.java @@ -334,6 +334,7 @@ private void checkEndOfRowGroup() throws IOException { columns.get(i), types.get(i).getLogicalTypeAnnotation(), pages.getPageReader(columns.get(i)), + pages.getRowIndexes().orElse(null), convertTz, datetimeRebaseMode, int96RebaseMode); diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedPlainValuesReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedPlainValuesReader.java index 6a0038dbdc44..39591be3b4be 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedPlainValuesReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedPlainValuesReader.java @@ -61,6 +61,14 @@ public final void readBooleans(int total, WritableColumnVector c, int rowId) { } } + @Override + public final void skipBooleans(int total) { + // TODO: properly vectorize this + for (int i = 0; i < total; i++) { + readBoolean(); + } + } + private ByteBuffer getBuffer(int length) { try { return in.slice(length).order(ByteOrder.LITTLE_ENDIAN); @@ -84,6 +92,11 @@ public final void readIntegers(int total, WritableColumnVector c, int rowId) { } } + @Override + public void skipIntegers(int total) { + in.skip(total * 4L); + } + @Override public final void readUnsignedIntegers(int total, WritableColumnVector c, int rowId) { int requiredBytes = total * 4; @@ -140,6 +153,11 @@ public final void readLongs(int total, WritableColumnVector c, int rowId) { } } + @Override + public void skipLongs(int total) { + in.skip(total * 8L); + } + @Override public final void readUnsignedLongs(int total, WritableColumnVector c, int rowId) { int requiredBytes = total * 8; @@ -197,6 +215,11 @@ public final void readFloats(int total, WritableColumnVector c, int rowId) { } } + @Override + public void skipFloats(int total) { + in.skip(total * 4L); + } + @Override public final void readDoubles(int total, WritableColumnVector c, int rowId) { int requiredBytes = total * 8; @@ -212,6 +235,11 @@ public final 
void readDoubles(int total, WritableColumnVector c, int rowId) { } } + @Override + public void skipDoubles(int total) { + in.skip(total * 8L); + } + @Override public final void readBytes(int total, WritableColumnVector c, int rowId) { // Bytes are stored as a 4-byte little endian int. Just read the first byte. @@ -226,6 +254,11 @@ public final void readBytes(int total, WritableColumnVector c, int rowId) { } } + @Override + public final void skipBytes(int total) { + in.skip(total * 4L); + } + @Override public final void readShorts(int total, WritableColumnVector c, int rowId) { int requiredBytes = total * 4; @@ -236,6 +269,11 @@ public final void readShorts(int total, WritableColumnVector c, int rowId) { } } + @Override + public void skipShorts(int total) { + in.skip(total * 4L); + } + @Override public final boolean readBoolean() { // TODO: vectorize decoding and keep boolean[] instead of currentByte @@ -300,6 +338,14 @@ public final void readBinary(int total, WritableColumnVector v, int rowId) { } } + @Override + public void skipBinary(int total) { + for (int i = 0; i < total; i++) { + int len = readInteger(); + in.skip(len); + } + } + @Override public final Binary readBinary(int len) { ByteBuffer buffer = getBuffer(len); @@ -312,4 +358,9 @@ public final Binary readBinary(int len) { return Binary.fromConstantByteArray(bytes); } } + + @Override + public void skipFixedLenByteArray(int total, int len) { + in.skip(total * (long) len); + } } diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedRleValuesReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedRleValuesReader.java index 538b69877e2d..03bda0fedbd2 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedRleValuesReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedRleValuesReader.java @@ -156,18 +156,12 @@ public int readInteger() { } /** - * Reads `total` ints into `c` filling them in starting at `c[rowId]`. This reader - * reads the definition levels and then will read from `data` for the non-null values. - * If the value is null, c will be populated with `nullValue`. Note that `nullValue` is only - * necessary for readIntegers because we also use it to decode dictionaryIds and want to make - * sure it always has a value in range. - * - * This is a batched version of this logic: - * if (this.readInt() == level) { - * c[rowId] = data.readInteger(); - * } else { - * c[rowId] = null; - * } + * Reads a batch of values into vector `values`, using `valueReader`. The related states such + * as row index, offset, number of values left in the batch and page, etc, are tracked by + * `state`. The type-specific `updater` is used to update or skip values. + *

+ * This reader reads the definition levels and then will read from `valueReader` for the + * non-null values. If the value is null, `values` will be populated with null value. */ public void readBatch( ParquetReadState state, @@ -175,36 +169,68 @@ public void readBatch( VectorizedValuesReader valueReader, ParquetVectorUpdater updater) throws IOException { int offset = state.offset; - int left = Math.min(state.valuesToReadInBatch, state.valuesToReadInPage); + long rowId = state.rowId; + int leftInBatch = state.valuesToReadInBatch; + int leftInPage = state.valuesToReadInPage; - while (left > 0) { + while (leftInBatch > 0 && leftInPage > 0) { if (this.currentCount == 0) this.readNextGroup(); - int n = Math.min(left, this.currentCount); - - switch (mode) { - case RLE: - if (currentValue == state.maxDefinitionLevel) { - updater.updateBatch(n, offset, values, valueReader); - } else { - values.putNulls(offset, n); - } - break; - case PACKED: - for (int i = 0; i < n; ++i) { - if (currentBuffer[currentBufferIdx++] == state.maxDefinitionLevel) { - updater.update(offset + i, values, valueReader); + int n = Math.min(leftInBatch, Math.min(leftInPage, this.currentCount)); + + long rangeStart = state.currentRangeStart(); + long rangeEnd = state.currentRangeEnd(); + + if (rowId + n < rangeStart) { + updater.skipValues(n, valueReader); + advance(n); + rowId += n; + leftInPage -= n; + } else if (rowId > rangeEnd) { + state.nextRange(); + } else { + // the range [rowId, rowId + n) overlaps with the current row range in state + long start = Math.max(rangeStart, rowId); + long end = Math.min(rangeEnd, rowId + n - 1); + + // skip the part [rowId, start) + int toSkip = (int) (start - rowId); + if (toSkip > 0) { + updater.skipValues(toSkip, valueReader); + advance(toSkip); + rowId += toSkip; + leftInPage -= toSkip; + } + + // read the part [start, end] + n = (int) (end - start + 1); + + switch (mode) { + case RLE: + if (currentValue == state.maxDefinitionLevel) { + updater.readValues(n, offset, values, valueReader); } else { - values.putNull(offset + i); + values.putNulls(offset, n); } - } - break; + break; + case PACKED: + for (int i = 0; i < n; ++i) { + if (currentBuffer[currentBufferIdx++] == state.maxDefinitionLevel) { + updater.readValue(offset + i, values, valueReader); + } else { + values.putNull(offset + i); + } + } + break; + } + offset += n; + leftInBatch -= n; + rowId += n; + leftInPage -= n; + currentCount -= n; } - offset += n; - left -= n; - currentCount -= n; } - state.advanceOffset(offset); + state.advanceOffsetAndRowId(offset, rowId); } /** @@ -217,36 +243,68 @@ public void readIntegers( WritableColumnVector nulls, VectorizedValuesReader data) throws IOException { int offset = state.offset; - int left = Math.min(state.valuesToReadInBatch, state.valuesToReadInPage); + long rowId = state.rowId; + int leftInBatch = state.valuesToReadInBatch; + int leftInPage = state.valuesToReadInPage; - while (left > 0) { + while (leftInBatch > 0 && leftInPage > 0) { if (this.currentCount == 0) this.readNextGroup(); - int n = Math.min(left, this.currentCount); - - switch (mode) { - case RLE: - if (currentValue == state.maxDefinitionLevel) { - data.readIntegers(n, values, offset); - } else { - nulls.putNulls(offset, n); - } - break; - case PACKED: - for (int i = 0; i < n; ++i) { - if (currentBuffer[currentBufferIdx++] == state.maxDefinitionLevel) { - values.putInt(offset + i, data.readInteger()); + int n = Math.min(leftInBatch, Math.min(leftInPage, this.currentCount)); + + long rangeStart = 
state.currentRangeStart(); + long rangeEnd = state.currentRangeEnd(); + + if (rowId + n < rangeStart) { + data.skipIntegers(n); + advance(n); + rowId += n; + leftInPage -= n; + } else if (rowId > rangeEnd) { + state.nextRange(); + } else { + // the range [rowId, rowId + n) overlaps with the current row range in state + long start = Math.max(rangeStart, rowId); + long end = Math.min(rangeEnd, rowId + n - 1); + + // skip the part [rowId, start) + int toSkip = (int) (start - rowId); + if (toSkip > 0) { + data.skipIntegers(toSkip); + advance(toSkip); + rowId += toSkip; + leftInPage -= toSkip; + } + + // read the part [start, end] + n = (int) (end - start + 1); + + switch (mode) { + case RLE: + if (currentValue == state.maxDefinitionLevel) { + data.readIntegers(n, values, offset); } else { - nulls.putNull(offset + i); + nulls.putNulls(offset, n); } - } - break; + break; + case PACKED: + for (int i = 0; i < n; ++i) { + if (currentBuffer[currentBufferIdx++] == state.maxDefinitionLevel) { + values.putInt(offset + i, data.readInteger()); + } else { + nulls.putNull(offset + i); + } + } + break; + } + rowId += n; + leftInPage -= n; + offset += n; + leftInBatch -= n; + currentCount -= n; } - offset += n; - left -= n; - currentCount -= n; } - state.advanceOffset(offset); + state.advanceOffsetAndRowId(offset, rowId); } @@ -346,6 +404,71 @@ public Binary readBinary(int len) { throw new UnsupportedOperationException("only readInts is valid."); } + @Override + public void skipIntegers(int total) { + int left = total; + while (left > 0) { + if (this.currentCount == 0) this.readNextGroup(); + int n = Math.min(left, this.currentCount); + advance(n); + left -= n; + } + } + + @Override + public void skipBooleans(int total) { + throw new UnsupportedOperationException("only skipIntegers is valid"); + } + + @Override + public void skipBytes(int total) { + throw new UnsupportedOperationException("only skipIntegers is valid"); + } + + @Override + public void skipShorts(int total) { + throw new UnsupportedOperationException("only skipIntegers is valid"); + } + + @Override + public void skipLongs(int total) { + throw new UnsupportedOperationException("only skipIntegers is valid"); + } + + @Override + public void skipFloats(int total) { + throw new UnsupportedOperationException("only skipIntegers is valid"); + } + + @Override + public void skipDoubles(int total) { + throw new UnsupportedOperationException("only skipIntegers is valid"); + } + + @Override + public void skipBinary(int total) { + throw new UnsupportedOperationException("only skipIntegers is valid"); + } + + @Override + public void skipFixedLenByteArray(int total, int len) { + throw new UnsupportedOperationException("only skipIntegers is valid"); + } + + /** + * Advance and skip the next `n` values in the current block. `n` MUST be <= `currentCount`. + */ + private void advance(int n) { + switch (mode) { + case RLE: + break; + case PACKED: + currentBufferIdx += n; + break; + } + currentCount -= n; + } + /** * Reads the next varint encoded int. 
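The skip/read decision that both rewritten loops above share can be isolated into a small helper for clarity. The sketch below, a hypothetical `RowRangeStep` class rather than anything in the patch, reproduces the three cases for a run of `n` consecutive values starting at `rowId` against the current row range `[rangeStart, rangeEnd]`:

```java
final class RowRangeStep {
  enum Action { SKIP_ALL, NEXT_RANGE, SKIP_THEN_READ }

  record Step(Action action, int toSkip, int toRead) {}

  static Step plan(long rowId, int n, long rangeStart, long rangeEnd) {
    if (rowId + n < rangeStart) {
      // The whole run ends before the current range: skip everything.
      return new Step(Action.SKIP_ALL, n, 0);
    } else if (rowId > rangeEnd) {
      // The run starts after the current range: advance to the next range and re-evaluate.
      return new Step(Action.NEXT_RANGE, 0, 0);
    } else {
      // Overlap: skip the prefix [rowId, start), then read [start, end].
      long start = Math.max(rangeStart, rowId);
      long end = Math.min(rangeEnd, rowId + n - 1);
      return new Step(Action.SKIP_THEN_READ, (int) (start - rowId), (int) (end - start + 1));
    }
  }

  public static void main(String[] args) {
    // Current range is [4, 5]; a run of 8 values starting at row 0 skips rows 0-3 and reads 4-5.
    System.out.println(plan(0L, 8, 4L, 5L));
    // A run of 3 values starting at row 0 lies entirely before the range and is skipped.
    System.out.println(plan(0L, 3, 4L, 5L));
    // A run starting at row 6 is past the range; move on to the next range.
    System.out.println(plan(6L, 4, 4L, 5L));
  }
}
```

In the patch this logic is inlined in `readBatch` and `readIntegers`, and any skipped prefix must also advance the definition-level decoder via `advance(toSkip)` so the RLE/PACKED state stays aligned with the values stream.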
*/ diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedValuesReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedValuesReader.java index a2d663fd8c8b..fc4eac94d1c4 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedValuesReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedValuesReader.java @@ -50,4 +50,17 @@ public interface VectorizedValuesReader { void readFloats(int total, WritableColumnVector c, int rowId); void readDoubles(int total, WritableColumnVector c, int rowId); void readBinary(int total, WritableColumnVector c, int rowId); + + /* + * Skips `total` values + */ + void skipBooleans(int total); + void skipBytes(int total); + void skipShorts(int total); + void skipIntegers(int total); + void skipLongs(int total); + void skipFloats(int total); + void skipDoubles(int total); + void skipBinary(int total); + void skipFixedLenByteArray(int total, int len); } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetColumnIndexSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetColumnIndexSuite.scala new file mode 100644 index 000000000000..f10b7013185b --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetColumnIndexSuite.scala @@ -0,0 +1,126 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
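For the plain-encoded path, the new skip methods mostly reduce to advancing the underlying byte stream by a fixed width per value, as `VectorizedPlainValuesReader` does earlier in this patch. Below is a condensed sketch over a plain `ByteBuffer`, used here as a stand-in for Parquet's `ByteBufferInputStream`; the `PlainSkipSketch` class is illustrative only:

```java
import java.nio.ByteBuffer;
import java.nio.ByteOrder;

final class PlainSkipSketch {
  private final ByteBuffer in;

  PlainSkipSketch(ByteBuffer in) {
    this.in = in.order(ByteOrder.LITTLE_ENDIAN);
  }

  private void skip(long bytes) {
    in.position(in.position() + (int) bytes);
  }

  // "byte" and "short" columns are stored as 4-byte INT32 values in PLAIN encoding,
  // which is why the patch skips total * 4 bytes for them as well.
  void skipIntegers(int total) { skip(total * 4L); }
  void skipLongs(int total)    { skip(total * 8L); }
  void skipFloats(int total)   { skip(total * 4L); }
  void skipDoubles(int total)  { skip(total * 8L); }
  void skipFixedLenByteArray(int total, int len) { skip(total * (long) len); }

  void skipBinary(int total) {
    // Variable-length values: each one is prefixed by a 4-byte little-endian length.
    for (int i = 0; i < total; i++) {
      int len = in.getInt();
      skip(len);
    }
  }
}
```

Booleans are the exception: in PLAIN encoding they are bit-packed, which is why the patch's `skipBooleans` falls back to reading and discarding one value at a time (with a TODO to vectorize it).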
+ */ + +package org.apache.spark.sql.execution.datasources.parquet + +import org.apache.spark.sql.{DataFrame, QueryTest} +import org.apache.spark.sql.test.SharedSparkSession + +class ParquetColumnIndexSuite extends QueryTest with ParquetTest with SharedSparkSession { + import testImplicits._ + + /** + * create parquet file with two columns and unaligned pages + * pages will be of the following layout + * col_1 500 500 500 500 + * |---------|---------|---------|---------| + * |-------|-----|-----|---|---|---|---|---| + * col_2 400 300 200 200 200 200 200 200 + */ + def checkUnalignedPages(actions: (DataFrame => DataFrame)*): Unit = { + withTempPath(file => { + val ds = spark.range(0, 2000).map(i => (i, i + ":" + "o" * (i / 100).toInt)) + ds.coalesce(1) + .write + .option("parquet.page.size", "4096") + .parquet(file.getCanonicalPath) + + val parquetDf = spark.read.parquet(file.getCanonicalPath) + + actions.foreach { action => + checkAnswer(action(parquetDf), action(ds.toDF())) + } + }) + } + + test("reading from unaligned pages - test filters") { + checkUnalignedPages( + // single value filter + df => df.filter("_1 = 500"), + df => df.filter("_1 = 500 or _1 = 1500"), + df => df.filter("_1 = 500 or _1 = 501 or _1 = 1500"), + df => df.filter("_1 = 500 or _1 = 501 or _1 = 1000 or _1 = 1500"), + // range filter + df => df.filter("_1 >= 500 and _1 < 1000"), + df => df.filter("(_1 >= 500 and _1 < 1000) or (_1 >= 1500 and _1 < 1600)") + ) + } + + test("test reading unaligned pages - test all types") { + withTempPath(file => { + val df = spark.range(0, 2000).selectExpr( + "id as _1", + "cast(id as short) as _3", + "cast(id as int) as _4", + "cast(id as float) as _5", + "cast(id as double) as _6", + "cast(id as decimal(20,0)) as _7", + "cast(cast(1618161925000 + id * 1000 * 60 * 60 * 24 as timestamp) as date) as _9", + "cast(1618161925000 + id as timestamp) as _10" + ) + df.coalesce(1) + .write + .option("parquet.page.size", "4096") + .parquet(file.getCanonicalPath) + + val parquetDf = spark.read.parquet(file.getCanonicalPath) + val singleValueFilterExpr = "_1 = 500 or _1 = 1500" + checkAnswer( + parquetDf.filter(singleValueFilterExpr), + df.filter(singleValueFilterExpr) + ) + val rangeFilterExpr = "_1 > 500 " + checkAnswer( + parquetDf.filter(rangeFilterExpr), + df.filter(rangeFilterExpr) + ) + }) + } + + test("test reading unaligned pages - test all types (dict encode)") { + withTempPath(file => { + val df = spark.range(0, 2000).selectExpr( + "id as _1", + "cast(id % 10 as byte) as _2", + "cast(id % 10 as short) as _3", + "cast(id % 10 as int) as _4", + "cast(id % 10 as float) as _5", + "cast(id % 10 as double) as _6", + "cast(id % 10 as decimal(20,0)) as _7", + "cast(id % 2 as boolean) as _8", + "cast(cast(1618161925000 + (id % 10) * 1000 * 60 * 60 * 24 as timestamp) as date) as _9", + "cast(1618161925000 + (id % 10) as timestamp) as _10" + ) + df.coalesce(1) + .write + .option("parquet.page.size", "4096") + .parquet(file.getCanonicalPath) + + val parquetDf = spark.read.parquet(file.getCanonicalPath) + val singleValueFilterExpr = "_1 = 500 or _1 = 1500" + checkAnswer( + parquetDf.filter(singleValueFilterExpr), + df.filter(singleValueFilterExpr) + ) + val rangeFilterExpr = "_1 > 500" + checkAnswer( + parquetDf.filter(rangeFilterExpr), + df.filter(rangeFilterExpr) + ) + }) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala index 
bc4234f01b5f..a330b82de2d0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala @@ -368,7 +368,9 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSparkSession private def createParquetWriter( schema: MessageType, path: Path, - dictionaryEnabled: Boolean = false): ParquetWriter[Group] = { + dictionaryEnabled: Boolean = false, + pageSize: Int = 1024, + dictionaryPageSize: Int = 1024): ParquetWriter[Group] = { val hadoopConf = spark.sessionState.newHadoopConf() ExampleParquetWriter @@ -378,11 +380,77 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSparkSession .withWriterVersion(PARQUET_1_0) .withCompressionCodec(GZIP) .withRowGroupSize(1024 * 1024) - .withPageSize(1024) + .withPageSize(pageSize) + .withDictionaryPageSize(dictionaryPageSize) .withConf(hadoopConf) .build() } + test("SPARK-34859: test multiple pages with different sizes and nulls") { + def makeRawParquetFile( + path: Path, + dictionaryEnabled: Boolean, + n: Int, + pageSize: Int): Seq[Option[Int]] = { + val schemaStr = + """ + |message root { + | optional boolean _1; + | optional int32 _2; + | optional int64 _3; + | optional float _4; + | optional double _5; + |} + """.stripMargin + + val schema = MessageTypeParser.parseMessageType(schemaStr) + val writer = createParquetWriter(schema, path, + dictionaryEnabled = dictionaryEnabled, pageSize = pageSize, dictionaryPageSize = pageSize) + + val rand = scala.util.Random + val expected = (0 until n).map { i => + if (rand.nextBoolean()) { + None + } else { + Some(i) + } + } + expected.foreach { opt => + val record = new SimpleGroup(schema) + opt match { + case Some(i) => + record.add(0, i % 2 == 0) + record.add(1, i) + record.add(2, i.toLong) + record.add(3, i.toFloat) + record.add(4, i.toDouble) + case _ => + } + writer.write(record) + } + + writer.close() + expected + } + + Seq(true, false).foreach { dictionaryEnabled => + Seq(64, 128, 89).foreach { pageSize => + withTempDir { dir => + val path = new Path(dir.toURI.toString, "part-r-0.parquet") + val expected = makeRawParquetFile(path, dictionaryEnabled, 1000, pageSize) + readParquetFile(path.toString) { df => + checkAnswer(df, expected.map { + case None => + Row(null, null, null, null, null) + case Some(i) => + Row(i % 2 == 0, i, i.toLong, i.toFloat, i.toDouble) + }) + } + } + } + } + } + test("read raw Parquet file") { def makeRawParquetFile(path: Path): Unit = { val schemaStr =