diff --git a/pom.xml b/pom.xml
index 81a0126539b1..e8164315606a 100644
--- a/pom.xml
+++ b/pom.xml
@@ -2333,6 +2333,27 @@
         <version>${parquet.version}</version>
         <scope>${parquet.deps.scope}</scope>
       </dependency>
+      <dependency>
+        <groupId>org.apache.parquet</groupId>
+        <artifactId>parquet-encoding</artifactId>
+        <version>${parquet.version}</version>
+        <scope>${parquet.test.deps.scope}</scope>
+        <classifier>tests</classifier>
+      </dependency>
+      <dependency>
+        <groupId>org.apache.parquet</groupId>
+        <artifactId>parquet-common</artifactId>
+        <version>${parquet.version}</version>
+        <scope>${parquet.test.deps.scope}</scope>
+        <classifier>tests</classifier>
+      </dependency>
+      <dependency>
+        <groupId>org.apache.parquet</groupId>
+        <artifactId>parquet-column</artifactId>
+        <version>${parquet.version}</version>
+        <scope>${parquet.test.deps.scope}</scope>
+        <classifier>tests</classifier>
+      </dependency>
       <dependency>
         <groupId>org.apache.parquet</groupId>
         <artifactId>parquet-hadoop</artifactId>
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/vectorized/ColumnVector.java b/sql/catalyst/src/main/java/org/apache/spark/sql/vectorized/ColumnVector.java
index 5ef4fba193e0..0033680199a3 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/vectorized/ColumnVector.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/vectorized/ColumnVector.java
@@ -216,7 +216,7 @@ public double[] getDoubles(int rowId, int count) {
    * the struct type, and each child vector is responsible to store the data for its corresponding
    * struct field.
    */
-  public final ColumnarRow getStruct(int rowId) {
+  public ColumnarRow getStruct(int rowId) {
     if (isNullAt(rowId)) return null;
     return new ColumnarRow(this, rowId);
   }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index 8ba2b9f8fd2a..9620d1397f30 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -886,6 +886,14 @@ object SQLConf {
     .booleanConf
     .createWithDefault(true)
 
+  val PARQUET_VECTORIZED_READER_NESTED_COLUMN_ENABLED =
+    buildConf("spark.sql.parquet.enableNestedColumnVectorizedReader")
+      .doc("Enables vectorized Parquet decoding for nested columns (e.g., struct, list, map). " +
+        s"Note to enable this ${PARQUET_VECTORIZED_READER_ENABLED} also needs to be enabled.")
+      .version("3.3.0")
+      .booleanConf
+      .createWithDefault(true)
+
   val PARQUET_RECORD_FILTER_ENABLED = buildConf("spark.sql.parquet.recordLevelFilter.enabled")
     .doc("If true, enables Parquet's native record-level filtering using the pushed down " +
       "filters. " +
@@ -3612,6 +3620,9 @@ class SQLConf extends Serializable with Logging {
 
   def parquetVectorizedReaderEnabled: Boolean = getConf(PARQUET_VECTORIZED_READER_ENABLED)
 
+  def parquetVectorizedReaderNestedColumnEnabled: Boolean =
+    getConf(PARQUET_VECTORIZED_READER_NESTED_COLUMN_ENABLED)
+
   def parquetVectorizedReaderBatchSize: Int = getConf(PARQUET_VECTORIZED_READER_BATCH_SIZE)
 
   def columnBatchSize: Int = getConf(COLUMN_BATCH_SIZE)
diff --git a/sql/core/pom.xml b/sql/core/pom.xml
index 27260ce67ae7..752dcd902b4d 100644
--- a/sql/core/pom.xml
+++ b/sql/core/pom.xml
@@ -106,6 +106,24 @@
       <groupId>org.apache.parquet</groupId>
       <artifactId>parquet-column</artifactId>
     </dependency>
+    <dependency>
+      <groupId>org.apache.parquet</groupId>
+      <artifactId>parquet-encoding</artifactId>
+      <scope>test</scope>
+      <classifier>tests</classifier>
+    </dependency>
+    <dependency>
+      <groupId>org.apache.parquet</groupId>
+      <artifactId>parquet-common</artifactId>
+      <scope>test</scope>
+      <classifier>tests</classifier>
+    </dependency>
+    <dependency>
+      <groupId>org.apache.parquet</groupId>
+      <artifactId>parquet-column</artifactId>
+      <scope>test</scope>
+      <classifier>tests</classifier>
+    </dependency>
     <dependency>
       <groupId>org.apache.parquet</groupId>
       <artifactId>parquet-hadoop</artifactId>
diff --git a/sql/core/src/main/java/org/apache/parquet/io/ColumnIOUtil.java b/sql/core/src/main/java/org/apache/parquet/io/ColumnIOUtil.java
new file mode 100644
index 000000000000..c1732cc206de
--- /dev/null
+++ b/sql/core/src/main/java/org/apache/parquet/io/ColumnIOUtil.java
@@ -0,0 +1,38 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.parquet.io;
+
+/**
+ * This is a workaround since both methods below are not public in {@link ColumnIO}.
+ * TODO(SPARK-36511): we should remove this once PARQUET-2050 is released with Parquet 1.13.
+ */
+public class ColumnIOUtil {
+  private ColumnIOUtil() {}
+
+  public static int getDefinitionLevel(ColumnIO column) {
+    return column.getDefinitionLevel();
+  }
+
+  public static int getRepetitionLevel(ColumnIO column) {
+    return column.getRepetitionLevel();
+  }
+
+  public static String[] getFieldPath(ColumnIO column) {
+    return column.getFieldPath();
+  }
+}
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnarBatchReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnarBatchReader.java
index 40ed0b2454c1..9df95892f28b 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnarBatchReader.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnarBatchReader.java
@@ -176,8 +176,7 @@ public void initBatch(
       // Initialize the missing columns once.
       if (colId == -1) {
         OnHeapColumnVector missingCol = new OnHeapColumnVector(capacity, dt);
-        missingCol.putNulls(0, capacity);
-        missingCol.setIsConstant();
+        missingCol.setAllNull();
         orcVectorWrappers[i] = missingCol;
       } else {
         orcVectorWrappers[i] = OrcColumnVectorUtils.toOrcColumnVector(
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/ParquetColumn.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/ParquetColumn.java
new file mode 100644
index 000000000000..c4f6f8778226
--- /dev/null
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/ParquetColumn.java
@@ -0,0 +1,318 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */ + +package org.apache.spark.sql.execution.datasources.parquet; + +import java.util.ArrayList; +import java.util.List; +import java.util.Set; + +import com.google.common.base.Preconditions; +import org.apache.spark.memory.MemoryMode; +import org.apache.spark.sql.execution.vectorized.OffHeapColumnVector; +import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector; +import org.apache.spark.sql.execution.vectorized.WritableColumnVector; +import org.apache.spark.sql.types.ArrayType; +import org.apache.spark.sql.types.DataType; +import org.apache.spark.sql.types.DataTypes; +import org.apache.spark.sql.types.MapType; +import org.apache.spark.sql.types.StructType; + +/** + * Contains necessary information representing a Parquet column, either of primitive or nested type. + */ +final class ParquetColumn { + private final ParquetType type; + private final List children; + private final WritableColumnVector vector; + + /** + * Repetition & Definition levels + * These are allocated only for leaf columns; for non-leaf columns, they simply maintain + * references to that of the former. + */ + private WritableColumnVector repetitionLevels; + private WritableColumnVector definitionLevels; + + /** Whether this column is primitive (i.e., leaf column) */ + private final boolean isPrimitive; + + /** Reader for this column - only set if 'isPrimitive' is true */ + private VectorizedColumnReader columnReader; + + ParquetColumn( + ParquetType type, + WritableColumnVector vector, + int capacity, + MemoryMode memoryMode, + Set missingColumns) { + DataType sparkType = type.sparkType(); + if (!sparkType.sameType(vector.dataType())) { + throw new IllegalArgumentException("Spark type: " + type.sparkType() + + " doesn't match the type: " + vector.dataType() + " in column vector"); + } + + this.type = type; + this.vector = vector; + this.children = new ArrayList<>(); + this.isPrimitive = type.isPrimitive(); + + if (missingColumns.contains(type)) { + vector.setAllNull(); + return; + } + + if (isPrimitive) { + // TODO: avoid allocating these if not necessary, for instance, the node is of top-level + // and is not repeated, or the node is not top-level but its max repetition level is 0. + repetitionLevels = allocateLevelsVector(capacity, memoryMode); + definitionLevels = allocateLevelsVector(capacity, memoryMode); + } else { + Preconditions.checkArgument(type.children().size() == vector.getNumChildren()); + for (int i = 0; i < type.children().size(); i++) { + ParquetColumn childColumn = new ParquetColumn(type.children().apply(i), + vector.getChild(i), capacity, memoryMode, missingColumns); + children.add(childColumn); + + // only use levels from non-missing child + if (!childColumn.vector.isAllNull()) { + this.repetitionLevels = childColumn.repetitionLevels; + this.definitionLevels = childColumn.definitionLevels; + } + } + + // this can happen if all the fields of a struct are missing, in which case we should mark + // the struct itself as a missing column + if (repetitionLevels == null) { + vector.setAllNull(); + } + } + } + + /** + * Returns all the children of this column. + */ + List getChildren() { + return children; + } + + /** + * Returns all the leaf columns in depth-first order. + */ + List getLeaves() { + List result = new ArrayList<>(); + getLeavesHelper(this, result); + return result; + } + + /** + * Assembles this column and calculate collection offsets recursively. + * This is a no-op for primitive columns. 
+ */ + void assemble() { + // nothing to do if the column itself is missing + if (vector.isAllNull()) return; + + DataType type = this.type.sparkType(); + if (type instanceof ArrayType || type instanceof MapType) { + for (ParquetColumn child : children) { + child.assemble(); + } + calculateCollectionOffsets(); + } else if (type instanceof StructType) { + for (ParquetColumn child : children) { + child.assemble(); + } + calculateStructOffsets(); + } + } + + void reset() { + // nothing to do if the column itself is missing + if (vector.isAllNull()) return; + + vector.reset(); + repetitionLevels.reset(); + definitionLevels.reset(); + for (ParquetColumn childColumn : children) { + childColumn.reset(); + } + } + + ParquetType getType() { + return this.type; + } + + WritableColumnVector getValueVector() { + return this.vector; + } + + WritableColumnVector getRepetitionLevelVector() { + return this.repetitionLevels; + } + + WritableColumnVector getDefinitionLevelVector() { + return this.definitionLevels; + } + + VectorizedColumnReader getColumnReader() { + return this.columnReader; + } + + void setColumnReader(VectorizedColumnReader reader) { + if (!isPrimitive) { + throw new IllegalStateException("can't set reader for non-primitive column"); + } + this.columnReader = reader; + } + + private static void getLeavesHelper(ParquetColumn column, List coll) { + if (column.isPrimitive) { + coll.add(column); + } else { + for (ParquetColumn child : column.children) { + getLeavesHelper(child, coll); + } + } + } + + private void calculateCollectionOffsets() { + int maxDefinitionLevel = type.definitionLevel(); + int maxElementRepetitionLevel = type.repetitionLevel(); + + // There are 4 cases when calculating definition levels: + // 1. definitionLevel == maxDefinitionLevel + // ==> value is defined and not null + // 2. definitionLevel == maxDefinitionLevel - 1 + // ==> value is null + // 3. definitionLevel < maxDefinitionLevel - 1 + // ==> value doesn't exist since one of its optional parent is null + // 4. definitionLevel > maxDefinitionLevel + // ==> value is a nested element within an array or map + // + // `i` is the index over all leaf elements of this array, while `offset` is the index over + // all top-level elements of this array. + int rowId = 0; + for (int i = 0, offset = 0; i < definitionLevels.getElementsAppended(); + i = getNextCollectionStart(maxElementRepetitionLevel, i)) { + vector.reserve(rowId + 1); + int definitionLevel = definitionLevels.getInt(i); + if (definitionLevel == maxDefinitionLevel - 1) { + // the collection is null + vector.putNull(rowId++); + } else if (definitionLevel == maxDefinitionLevel) { + // collection is defined but empty + vector.putNotNull(rowId); + vector.putArray(rowId, offset, 0); + rowId++; + } else if (definitionLevel > maxDefinitionLevel) { + // collection is defined and non-empty: find out how many top element there is till the + // start of the next array. 
+ vector.putNotNull(rowId); + int length = getCollectionSize(maxElementRepetitionLevel, i + 1); + vector.putArray(rowId, offset, length); + offset += length; + rowId++; + } + } + vector.addElementsAppended(rowId); + } + + private void calculateStructOffsets() { + int maxRepetitionLevel = type.repetitionLevel(); + int maxDefinitionLevel = type.definitionLevel(); + + vector.reserve(definitionLevels.getElementsAppended()); + int rowId = 0; + int nonnullRowId = 0; + boolean hasRepetitionLevels = repetitionLevels.getElementsAppended() > 0; + for (int i = 0; i < definitionLevels.getElementsAppended(); i++) { + // if repetition level > maxRepetitionLevel, the value is a nested element (e.g., an array + // element in struct>), and we should skip the definition level since it doesn't + // represent with the struct. + if (!hasRepetitionLevels || repetitionLevels.getInt(i) <= maxRepetitionLevel) { + if (definitionLevels.getInt(i) == maxDefinitionLevel - 1) { + // the struct is null + vector.putNull(rowId); + rowId++; + } else if (definitionLevels.getInt(i) >= maxDefinitionLevel) { + vector.putNotNull(rowId); + vector.putStruct(rowId, nonnullRowId); + rowId++; + nonnullRowId++; + } + } + } + vector.addElementsAppended(rowId); + } + + private static WritableColumnVector allocateLevelsVector(int capacity, MemoryMode memoryMode) { + switch (memoryMode) { + case ON_HEAP: + return new OnHeapColumnVector(capacity, DataTypes.IntegerType); + case OFF_HEAP: + return new OffHeapColumnVector(capacity, DataTypes.IntegerType); + default: + throw new IllegalArgumentException("Unknown memory mode: " + memoryMode); + } + } + + private int getNextCollectionStart(int maxRepetitionLevel, int elementIndex) { + int idx = elementIndex + 1; + for (; idx < repetitionLevels.getElementsAppended(); idx++) { + if (repetitionLevels.getInt(idx) <= maxRepetitionLevel) { + break; + } + } + return idx; + } + + private int getCollectionSize(int maxRepetitionLevel, int idx) { + int size = 1; + for (; idx < repetitionLevels.getElementsAppended(); idx++) { + if (repetitionLevels.getInt(idx) <= maxRepetitionLevel) { + break; + } else if (repetitionLevels.getInt(idx) <= maxRepetitionLevel + 1) { + // only count elements which belong to the current collection + // For instance, suppose we have the following Parquet schema: + // + // message schema { max rl max dl + // optional group col (LIST) { 0 1 + // repeated group list { 1 2 + // optional group element (LIST) { 1 3 + // repeated group list { 2 4 + // required int32 element; 2 4 + // } + // } + // } + // } + // } + // + // For a list such as: [[[0, 1], [2, 3]], [[4, 5], [6, 7]]], the repetition & definition + // levels would be: + // + // repetition levels: [0, 2, 1, 2, 0, 2, 1, 2] + // definition levels: [2, 2, 2, 2, 2, 2, 2, 2] + // + // when calculating collection size for the outer array, we should only count repetition + // levels whose value is <= 1 (which is the max repetition level for the inner array) + size++; + } + } + return size; + } +} diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/ParquetReadState.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/ParquetReadState.java index b26088753465..e4f75e5cfa05 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/ParquetReadState.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/ParquetReadState.java @@ -17,6 +17,8 @@ package org.apache.spark.sql.execution.datasources.parquet; +import 
org.apache.parquet.column.ColumnDescriptor; + import java.util.ArrayList; import java.util.Iterator; import java.util.List; @@ -45,21 +47,48 @@ final class ParquetReadState { /** Maximum definition level for the Parquet column */ final int maxDefinitionLevel; + /** Maximum repetition level for the Parquet column */ + final int maxRepetitionLevel; + + /** Whether this column is required */ + final boolean isRequired; + /** The current index over all rows within the column chunk. This is used to check if the * current row should be skipped by comparing against the row ranges. */ long rowId; - /** The offset in the current batch to put the next value */ - int offset; + /** The offset to put new values into definition & repetition level vector */ + int levelOffset; + + /** The offset to put new values into value vector */ + int valueOffset; /** The remaining number of values to read in the current page */ int valuesToReadInPage; - /** The remaining number of values to read in the current batch */ - int valuesToReadInBatch; + /** The remaining number of rows to read in the current batch */ + int rowsToReadInBatch; + + // The following are only used when reading repeated values + + /** When processing repeated values, whether we've found the beginning of the first list after the + * current batch. */ + boolean lastListCompleted; - ParquetReadState(int maxDefinitionLevel, PrimitiveIterator.OfLong rowIndexes) { - this.maxDefinitionLevel = maxDefinitionLevel; + /** When processing repeated types, the number of accumulated definition levels to process */ + int numBatchedDefLevels; + + /** When processing repeated types, whether we should skip the current batch of definition + * levels. */ + boolean shouldSkip; + + ParquetReadState( + ColumnDescriptor descriptor, + boolean isRequired, + PrimitiveIterator.OfLong rowIndexes) { + this.maxDefinitionLevel = descriptor.getMaxDefinitionLevel(); + this.maxRepetitionLevel = descriptor.getMaxRepetitionLevel(); + this.isRequired = isRequired; this.rowRanges = constructRanges(rowIndexes); nextRange(); } @@ -101,8 +130,12 @@ private Iterator constructRanges(PrimitiveIterator.OfLong rowIndexes) * Must be called at the beginning of reading a new batch. */ void resetForNewBatch(int batchSize) { - this.offset = 0; - this.valuesToReadInBatch = batchSize; + this.levelOffset = 0; + this.valueOffset = 0; + this.rowsToReadInBatch = batchSize; + this.lastListCompleted = this.maxRepetitionLevel == 0; + this.numBatchedDefLevels = 0; + this.shouldSkip = false; } /** @@ -127,16 +160,6 @@ long currentRangeEnd() { return currentRange.end; } - /** - * Advance the current offset and rowId to the new values. - */ - void advanceOffsetAndRowId(int newOffset, long newRowId) { - valuesToReadInBatch -= (newOffset - offset); - valuesToReadInPage -= (newRowId - rowId); - offset = newOffset; - rowId = newRowId; - } - /** * Advance to the next range. 
*/ diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/SpecificParquetRecordReaderBase.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/SpecificParquetRecordReaderBase.java index 6264d6341c65..59ead1b0f096 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/SpecificParquetRecordReaderBase.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/SpecificParquetRecordReaderBase.java @@ -18,6 +18,7 @@ package org.apache.spark.sql.execution.datasources.parquet; +import java.io.Closeable; import java.io.File; import java.io.IOException; import java.lang.reflect.InvocationTargetException; @@ -29,6 +30,7 @@ import java.util.Map; import java.util.Set; +import org.apache.parquet.column.page.PageReadStore; import scala.Option; import org.apache.hadoop.conf.Configuration; @@ -64,9 +66,9 @@ * this way, albeit at a higher cost to implement. This base class is reusable. */ public abstract class SpecificParquetRecordReaderBase extends RecordReader { - protected Path file; protected MessageType fileSchema; - protected MessageType requestedSchema; + protected MessageType requestedParquetSchema; + protected ParquetType requestedSchema; protected StructType sparkSchema; /** @@ -75,31 +77,42 @@ public abstract class SpecificParquetRecordReaderBase extends RecordReader fileMetadata = reader.getFileMetaData().getKeyValueMetaData(); + ParquetFileReader fileReader = new ParquetFileReader( + HadoopInputFile.fromPath(file, configuration), options); + this.reader = new ParquetRowGroupReaderImpl(fileReader); + this.fileSchema = fileReader.getFileMetaData().getSchema(); + Map fileMetadata = fileReader.getFileMetaData().getKeyValueMetaData(); ReadSupport readSupport = getReadSupportInstance(getReadSupportClass(configuration)); ReadSupport.ReadContext readContext = readSupport.init(new InitContext( taskAttemptContext.getConfiguration(), toSetMultiMap(fileMetadata), fileSchema)); - this.requestedSchema = readContext.getRequestedSchema(); - reader.setRequestedSchema(requestedSchema); + this.requestedParquetSchema = readContext.getRequestedSchema(); + fileReader.setRequestedSchema(requestedParquetSchema); String sparkRequestedSchemaString = configuration.get(ParquetReadSupport$.MODULE$.SPARK_ROW_REQUESTED_SCHEMA()); - this.sparkSchema = StructType$.MODULE$.fromString(sparkRequestedSchemaString); - this.totalRowCount = reader.getFilteredRecordCount(); + StructType sparkRequestedSchema = StructType$.MODULE$.fromString(sparkRequestedSchemaString); + boolean caseSensitive = configuration.getBoolean(SQLConf.CASE_SENSITIVE().key(), + Boolean.getBoolean(SQLConf.CASE_SENSITIVE().defaultValueString())); + ParquetToSparkSchemaConverter converter = new ParquetToSparkSchemaConverter(configuration); + this.requestedSchema = converter.convertTypeInfo(requestedParquetSchema, + Option.apply(sparkRequestedSchema), caseSensitive); + this.sparkSchema = (StructType) requestedSchema.sparkType(); + this.totalRowCount = fileReader.getFilteredRecordCount(); // For test purpose. 
// If the last external accumulator is `NumRowGroupsAccumulator`, the row group number to read @@ -111,7 +124,7 @@ public void initialize(InputSplit inputSplit, TaskAttemptContext taskAttemptCont if (accu.isDefined() && accu.get().getClass().getSimpleName().equals("NumRowGroupsAcc")) { @SuppressWarnings("unchecked") AccumulatorV2 intAccum = (AccumulatorV2) accu.get(); - intAccum.add(reader.getRowGroups().size()); + intAccum.add(fileReader.getRowGroups().size()); } } } @@ -148,18 +161,19 @@ protected void initialize(String path, List columns) throws IOException config.setBoolean(SQLConf.PARQUET_BINARY_AS_STRING().key() , false); config.setBoolean(SQLConf.PARQUET_INT96_AS_TIMESTAMP().key(), false); - this.file = new Path(path); - long length = this.file.getFileSystem(config).getFileStatus(this.file).getLen(); + Path file = new Path(path); + long length = file.getFileSystem(config).getFileStatus(file).getLen(); ParquetReadOptions options = HadoopReadOptions .builder(config) .withRange(0, length) .build(); - this.reader = ParquetFileReader.open(HadoopInputFile.fromPath(file, config), options); - this.fileSchema = reader.getFooter().getFileMetaData().getSchema(); + ParquetFileReader fileReader = ParquetFileReader.open( + HadoopInputFile.fromPath(file, config), options); + this.fileSchema = fileReader.getFooter().getFileMetaData().getSchema(); if (columns == null) { - this.requestedSchema = fileSchema; + this.requestedParquetSchema = fileSchema; } else { if (columns.size() > 0) { Types.MessageTypeBuilder builder = Types.buildMessage(); @@ -170,14 +184,35 @@ protected void initialize(String path, List columns) throws IOException } builder.addFields(fileSchema.getType(s)); } - this.requestedSchema = builder.named(ParquetSchemaConverter.SPARK_PARQUET_SCHEMA_NAME()); + this.requestedParquetSchema = + builder.named(ParquetSchemaConverter.SPARK_PARQUET_SCHEMA_NAME()); } else { - this.requestedSchema = ParquetSchemaConverter.EMPTY_MESSAGE(); + this.requestedParquetSchema = ParquetSchemaConverter.EMPTY_MESSAGE(); } } - reader.setRequestedSchema(requestedSchema); - this.sparkSchema = new ParquetToSparkSchemaConverter(config).convert(requestedSchema); - this.totalRowCount = reader.getFilteredRecordCount(); + fileReader.setRequestedSchema(requestedParquetSchema); + this.requestedSchema = new ParquetToSparkSchemaConverter(config) + .convertTypeInfo(requestedParquetSchema, Option.empty(), true); + this.sparkSchema = (StructType) requestedSchema.sparkType(); + this.totalRowCount = fileReader.getFilteredRecordCount(); + this.reader = new ParquetRowGroupReaderImpl(fileReader); + } + + protected void initialize( + MessageType fileSchema, + MessageType requestedSchema, + ParquetRowGroupReader rowGroupReader, + int totalRowCount) throws IOException { + this.reader = rowGroupReader; + this.fileSchema = fileSchema; + this.requestedParquetSchema = requestedSchema; + Configuration config = new Configuration(); + config.setBoolean(SQLConf.PARQUET_BINARY_AS_STRING().key() , false); + config.setBoolean(SQLConf.PARQUET_INT96_AS_TIMESTAMP().key(), false); + this.requestedSchema = new ParquetToSparkSchemaConverter(config) + .convertTypeInfo(requestedSchema, Option.empty(), true); + this.sparkSchema = (StructType) this.requestedSchema.sparkType(); + this.totalRowCount = totalRowCount; } @Override @@ -222,4 +257,31 @@ private static ReadSupport getReadSupportInstance( throw new BadConfigurationException("could not instantiate read support class", e); } } + + interface ParquetRowGroupReader extends Closeable { + /** + * Read 
the next row group from this reader. Returns null if there is no more row group. + */ + PageReadStore readNextRowGroup() throws IOException; + } + + private static class ParquetRowGroupReaderImpl implements ParquetRowGroupReader { + private final ParquetFileReader reader; + + ParquetRowGroupReaderImpl(ParquetFileReader reader) { + this.reader = reader; + } + + @Override + public PageReadStore readNextRowGroup() throws IOException { + return reader.readNextFilteredRowGroup(); + } + + @Override + public void close() throws IOException { + if (reader != null) { + reader.close(); + } + } + } } diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java index 92dea08102df..f4fc8e616c66 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java @@ -19,8 +19,8 @@ import java.io.IOException; import java.time.ZoneId; -import java.util.PrimitiveIterator; +import com.google.common.base.Preconditions; import org.apache.parquet.bytes.ByteBufferInputStream; import org.apache.parquet.bytes.BytesInput; import org.apache.parquet.bytes.BytesUtils; @@ -38,7 +38,6 @@ import org.apache.spark.sql.execution.vectorized.WritableColumnVector; import org.apache.spark.sql.types.Decimal; -import static org.apache.parquet.column.ValuesType.REPETITION_LEVEL; import static org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName.INT64; /** @@ -65,6 +64,11 @@ public class VectorizedColumnReader { */ private VectorizedRleValuesReader defColumn; + /** + * Vectorized RLE decoder for repetition levels + */ + private VectorizedRleValuesReader repColumn; + /** * Factory to get type-specific vector updater. */ @@ -88,16 +92,16 @@ public class VectorizedColumnReader { public VectorizedColumnReader( ColumnDescriptor descriptor, - LogicalTypeAnnotation logicalTypeAnnotation, - PageReader pageReader, - PrimitiveIterator.OfLong rowIndexes, + boolean isRequiredColumn, + PageReadStore pageReadStore, ZoneId convertTz, String datetimeRebaseMode, String int96RebaseMode) throws IOException { this.descriptor = descriptor; - this.pageReader = pageReader; - this.readState = new ParquetReadState(descriptor.getMaxDefinitionLevel(), rowIndexes); - this.logicalTypeAnnotation = logicalTypeAnnotation; + this.pageReader = pageReadStore.getPageReader(descriptor); + this.readState = new ParquetReadState(descriptor, isRequiredColumn, + pageReadStore.getRowIndexes().orElse(null)); + this.logicalTypeAnnotation = descriptor.getPrimitiveType().getLogicalTypeAnnotation(); this.updaterFactory = new ParquetVectorUpdaterFactory( logicalTypeAnnotation, convertTz, datetimeRebaseMode, int96RebaseMode); @@ -149,37 +153,54 @@ private boolean isLazyDecodingSupported(PrimitiveType.PrimitiveTypeName typeName /** * Reads `total` values from this columnReader into column. 
*/ - void readBatch(int total, WritableColumnVector column) throws IOException { + void readBatch( + int total, + WritableColumnVector values, + WritableColumnVector repetitionLevels, + WritableColumnVector definitionLevels) throws IOException { WritableColumnVector dictionaryIds = null; - ParquetVectorUpdater updater = updaterFactory.getUpdater(descriptor, column.dataType()); + ParquetVectorUpdater updater = updaterFactory.getUpdater(descriptor, values.dataType()); if (dictionary != null) { // SPARK-16334: We only maintain a single dictionary per row batch, so that it can be used to // decode all previous dictionary encoded pages if we ever encounter a non-dictionary encoded // page. - dictionaryIds = column.reserveDictionaryIds(total); + dictionaryIds = values.reserveDictionaryIds(total); } readState.resetForNewBatch(total); - while (readState.valuesToReadInBatch > 0) { + while (readState.rowsToReadInBatch > 0 || !readState.lastListCompleted) { if (readState.valuesToReadInPage == 0) { int pageValueCount = readPage(); + if (pageValueCount < 0) { + // we've read all the pages + break; + } readState.resetForNewPage(pageValueCount, pageFirstRowIndex); } + PrimitiveType.PrimitiveTypeName typeName = descriptor.getPrimitiveType().getPrimitiveTypeName(); if (isCurrentPageDictionaryEncoded) { // Save starting offset in case we need to decode dictionary IDs. - int startOffset = readState.offset; + int startOffset = readState.valueOffset; // Save starting row index so we can check if we need to eagerly decode dict ids later long startRowId = readState.rowId; + Preconditions.checkNotNull(dictionaryIds, "dictionaryIds == null when " + + "isCurrentPageDictionaryEncoded is true"); + // Read and decode dictionary ids. - defColumn.readIntegers(readState, dictionaryIds, column, - (VectorizedValuesReader) dataColumn); + if (readState.maxRepetitionLevel == 0) { + defColumn.readIntegers(readState, dictionaryIds, values, definitionLevels, + (VectorizedValuesReader) dataColumn); + } else { + repColumn.readIntegersNested(readState, repetitionLevels, defColumn, definitionLevels, + dictionaryIds, values, (VectorizedValuesReader) dataColumn); + } // TIMESTAMP_MILLIS encoded as INT64 can't be lazily decoded as we need to post process // the values to add microseconds precision. - if (column.hasDictionary() || (startRowId == pageFirstRowIndex && + if (values.hasDictionary() || (startRowId == pageFirstRowIndex && isLazyDecodingSupported(typeName))) { // Column vector supports lazy decoding of dictionary values so just set the dictionary. // We can't do this if startRowId is not the first row index in the page AND the column @@ -204,26 +225,34 @@ void readBatch(int total, WritableColumnVector column) throws IOException { boolean isUnsignedInt64 = updaterFactory.isUnsignedIntTypeMatched(64); boolean needTransform = castLongToInt || isUnsignedInt32 || isUnsignedInt64; - column.setDictionary(new ParquetDictionary(dictionary, needTransform)); + values.setDictionary(new ParquetDictionary(dictionary, needTransform)); } else { - updater.decodeDictionaryIds(readState.offset - startOffset, startOffset, column, + updater.decodeDictionaryIds(readState.valueOffset - startOffset, startOffset, values, dictionaryIds, dictionary); } } else { - if (column.hasDictionary() && readState.offset != 0) { + if (values.hasDictionary() && readState.valueOffset != 0) { // This batch already has dictionary encoded values but this new page is not. The batch // does not support a mix of dictionary and not so we will decode the dictionary. 
- updater.decodeDictionaryIds(readState.offset, 0, column, dictionaryIds, dictionary); + updater.decodeDictionaryIds(readState.valueOffset, 0, values, dictionaryIds, dictionary); } - column.setDictionary(null); + values.setDictionary(null); VectorizedValuesReader valuesReader = (VectorizedValuesReader) dataColumn; - defColumn.readBatch(readState, column, valuesReader, updater); + if (readState.maxRepetitionLevel == 0) { + defColumn.readBatch(readState, values, definitionLevels, valuesReader, updater); + } else { + repColumn.readBatchNested(readState, repetitionLevels, defColumn, definitionLevels, + values, valuesReader, updater); + } } } } private int readPage() { DataPage page = pageReader.readPage(); + if (page == null) { + return -1; + } this.pageFirstRowIndex = page.getFirstRowIndex().orElse(0L); return page.accept(new DataPage.Visitor() { @@ -286,18 +315,15 @@ private int readPageV1(DataPageV1 page) throws IOException { } int pageValueCount = page.getValueCount(); - int bitWidth = BytesUtils.getWidthFromMaxInt(descriptor.getMaxDefinitionLevel()); - - this.defColumn = new VectorizedRleValuesReader(bitWidth); + int dlBitWidth = BytesUtils.getWidthFromMaxInt(descriptor.getMaxDefinitionLevel()); + this.defColumn = new VectorizedRleValuesReader(dlBitWidth); + int rlBitWidth = BytesUtils.getWidthFromMaxInt(descriptor.getMaxRepetitionLevel()); + this.repColumn = new VectorizedRleValuesReader(rlBitWidth); try { BytesInput bytes = page.getBytes(); ByteBufferInputStream in = bytes.toInputStream(); - // only used now to consume the repetition level data - page.getRlEncoding() - .getValuesReader(descriptor, REPETITION_LEVEL) - .initFromPage(pageValueCount, in); - + repColumn.initFromPage(pageValueCount, in); defColumn.initFromPage(pageValueCount, in); initDataReader(pageValueCount, page.getValueEncoding(), in); return pageValueCount; @@ -308,11 +334,16 @@ private int readPageV1(DataPageV1 page) throws IOException { private int readPageV2(DataPageV2 page) throws IOException { int pageValueCount = page.getValueCount(); - int bitWidth = BytesUtils.getWidthFromMaxInt(descriptor.getMaxDefinitionLevel()); // do not read the length from the stream. v2 pages handle dividing the page bytes. 
- defColumn = new VectorizedRleValuesReader(bitWidth, false); + int dlBitWidth = BytesUtils.getWidthFromMaxInt(descriptor.getMaxDefinitionLevel()); + defColumn = new VectorizedRleValuesReader(dlBitWidth, false); defColumn.initFromPage(pageValueCount, page.getDefinitionLevels().toInputStream()); + + int rlBitWidth = BytesUtils.getWidthFromMaxInt(descriptor.getMaxRepetitionLevel()); + repColumn = new VectorizedRleValuesReader(rlBitWidth, false); + repColumn.initFromPage(pageValueCount, page.getRepetitionLevels().toInputStream()); + try { initDataReader(pageValueCount, page.getDataEncoding(), page.getData().toInputStream()); return pageValueCount; diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedParquetRecordReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedParquetRecordReader.java index 9f7836ae4818..ca05498eb428 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedParquetRecordReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedParquetRecordReader.java @@ -20,29 +20,35 @@ import java.io.IOException; import java.time.ZoneId; import java.util.Arrays; +import java.util.HashSet; import java.util.List; +import java.util.Set; +import com.google.common.base.Preconditions; import org.apache.hadoop.mapreduce.InputSplit; import org.apache.hadoop.mapreduce.TaskAttemptContext; import org.apache.parquet.column.ColumnDescriptor; import org.apache.parquet.column.page.PageReadStore; -import org.apache.parquet.schema.Type; +import org.apache.parquet.schema.GroupType; +import org.apache.parquet.schema.MessageType; +import org.apache.parquet.schema.Type; import org.apache.spark.memory.MemoryMode; import org.apache.spark.sql.catalyst.InternalRow; import org.apache.spark.sql.execution.vectorized.ColumnVectorUtils; import org.apache.spark.sql.execution.vectorized.WritableColumnVector; import org.apache.spark.sql.execution.vectorized.OffHeapColumnVector; import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector; -import org.apache.spark.sql.vectorized.ColumnarBatch; import org.apache.spark.sql.types.StructField; import org.apache.spark.sql.types.StructType; +import org.apache.spark.sql.vectorized.ColumnarBatch; +import scala.collection.JavaConverters; /** * A specialized RecordReader that reads into InternalRows or ColumnarBatches directly using the * Parquet column APIs. This is somewhat based on parquet-mr's ColumnReader. * - * TODO: handle complex types, decimal requiring more than 8 bytes, INT96. Schema mismatch. + * TODO: decimal requiring more than 8 bytes, INT96. Schema mismatch. * All of these can be handled efficiently and easily with codegen. * * This class can either return InternalRows or ColumnarBatches. With whole stage codegen @@ -62,10 +68,10 @@ public class VectorizedParquetRecordReader extends SpecificParquetRecordReaderBa private int numBatched = 0; /** - * For each request column, the reader to read this column. This is NULL if this column - * is missing from the file, in which case we populate the attribute with NULL. + * Encapsulate writable column vectors with other Parquet related info such as + * repetition / definition levels. */ - private VectorizedColumnReader[] columnReaders; + private ParquetColumn[] columns; /** * The number of rows that have been returned. 
@@ -78,9 +84,10 @@ public class VectorizedParquetRecordReader extends SpecificParquetRecordReaderBa private long totalCountLoadedSoFar = 0; /** - * For each column, true if the column is missing in the file and we'll instead return NULLs. + * For each leaf column, if it is in the set, it means the column is missing in the file and + * we'll instead return NULLs. */ - private boolean[] missingColumns; + private Set missingColumns; /** * The timezone that timestamp INT96 values should be converted to. Null if no conversion. Here to @@ -114,8 +121,6 @@ public class VectorizedParquetRecordReader extends SpecificParquetRecordReaderBa */ private ColumnarBatch columnarBatch; - private WritableColumnVector[] columnVectors; - /** * If true, this class returns batches instead of rows. */ @@ -165,6 +170,16 @@ public void initialize(String path, List columns) throws IOException, initializeInternal(); } + @Override + public void initialize( + MessageType fileSchema, + MessageType requestedSchema, + ParquetRowGroupReader rowGroupReader, + int totalRowCount) throws IOException { + super.initialize(fileSchema, requestedSchema, rowGroupReader, totalRowCount); + initializeInternal(); + } + @Override public void close() throws IOException { if (columnarBatch != null) { @@ -218,12 +233,20 @@ private void initBatch( } } + WritableColumnVector[] columnVectors; if (memMode == MemoryMode.OFF_HEAP) { columnVectors = OffHeapColumnVector.allocateColumns(capacity, batchSchema); } else { columnVectors = OnHeapColumnVector.allocateColumns(capacity, batchSchema); } columnarBatch = new ColumnarBatch(columnVectors); + + columns = new ParquetColumn[sparkSchema.fields().length]; + for (int i = 0; i < columns.length; i++) { + columns[i] = new ParquetColumn(requestedSchema.children().apply(i), + columnVectors[i], capacity, memMode, missingColumns); + } + if (partitionColumns != null) { int partitionIdx = sparkSchema.fields().length; for (int i = 0; i < partitionColumns.fields().length; i++) { @@ -231,14 +254,6 @@ private void initBatch( columnVectors[i + partitionIdx].setIsConstant(); } } - - // Initialize missing columns with nulls. - for (int i = 0; i < missingColumns.length; i++) { - if (missingColumns[i]) { - columnVectors[i].putNulls(0, capacity); - columnVectors[i].setIsConstant(); - } - } } private void initBatch() { @@ -270,18 +285,26 @@ public void enableReturningBatches() { * Advances to the next batch of rows. Returns false if there are no more. 
*/ public boolean nextBatch() throws IOException { - for (WritableColumnVector vector : columnVectors) { - vector.reset(); + for (ParquetColumn column : columns) { + column.reset(); } + columnarBatch.setNumRows(0); if (rowsReturned >= totalRowCount) return false; checkEndOfRowGroup(); - int num = (int) Math.min((long) capacity, totalCountLoadedSoFar - rowsReturned); - for (int i = 0; i < columnReaders.length; ++i) { - if (columnReaders[i] == null) continue; - columnReaders[i].readBatch(num, columnVectors[i]); + int num = (int) Math.min(capacity, totalCountLoadedSoFar - rowsReturned); + for (ParquetColumn col : columns) { + for (ParquetColumn leafCol : col.getLeaves()) { + VectorizedColumnReader columnReader = leafCol.getColumnReader(); + if (columnReader != null) { + columnReader.readBatch(num, leafCol.getValueVector(), + leafCol.getRepetitionLevelVector(), leafCol.getDefinitionLevelVector()); + } + } + col.assemble(); } + rowsReturned += num; columnarBatch.setNumRows(num); numBatched = num; @@ -290,55 +313,89 @@ public boolean nextBatch() throws IOException { } private void initializeInternal() throws IOException, UnsupportedOperationException { - // Check that the requested schema is supported. - missingColumns = new boolean[requestedSchema.getFieldCount()]; - List columns = requestedSchema.getColumns(); - List paths = requestedSchema.getPaths(); - for (int i = 0; i < requestedSchema.getFieldCount(); ++i) { - Type t = requestedSchema.getFields().get(i); - if (!t.isPrimitive() || t.isRepetition(Type.Repetition.REPEATED)) { - throw new UnsupportedOperationException("Complex types not supported."); - } + missingColumns = new HashSet<>(); + for (ParquetType columnType : JavaConverters.seqAsJavaList(requestedSchema.children())) { + checkColumn(columnType); + } + } - String[] colPath = paths.get(i); - if (fileSchema.containsPath(colPath)) { - ColumnDescriptor fd = fileSchema.getColumnDescription(colPath); - if (!fd.equals(columns.get(i))) { + /** + * Check whether a column from requested schema is missing from the file schema, or whether it + * conforms to the type of the file schema. + */ + private void checkColumn(ParquetType columnType) throws IOException { + String[] path = JavaConverters.seqAsJavaList(columnType.path()).toArray(new String[0]); + if (containsPath(fileSchema, path)) { + if (columnType.isPrimitive()) { + ColumnDescriptor desc = columnType.descriptor().get(); + ColumnDescriptor fd = fileSchema.getColumnDescription(desc.getPath()); + if (!fd.equals(desc)) { throw new UnsupportedOperationException("Schema evolution not supported."); } - missingColumns[i] = false; } else { - if (columns.get(i).getMaxDefinitionLevel() == 0) { - // Column is missing in data but the required data is non-nullable. This file is invalid. - throw new IOException("Required column is missing in data file. Col: " + - Arrays.toString(colPath)); + for (ParquetType childType : JavaConverters.seqAsJavaList(columnType.children())) { + checkColumn(childType); } - missingColumns[i] = true; } + } else { // a missing column which is either primitive or complex + if (columnType.required()) { + // Column is missing in data but the required data is non-nullable. This file is invalid. + throw new IOException("Required column is missing in data file. Col: " + + Arrays.toString(path)); + } + missingColumns.add(columnType); } } + /** + * Checks whether the given 'path' exists in 'parquetType'. 
The difference between this and + * {@link MessageType#containsPath(String[])} is that the latter only support paths to leaf + * nodes, while this support paths both to leaf and non-leaf nodes. + */ + private boolean containsPath(Type parquetType, String[] path) { + return containsPath(parquetType, path, 0); + } + + private boolean containsPath(Type parquetType, String[] path, int depth) { + if (path.length == depth) return true; + if (parquetType instanceof GroupType) { + String fieldName = path[depth]; + GroupType parquetGroupType = (GroupType) parquetType; + if (parquetGroupType.containsField(fieldName)) { + return containsPath(parquetGroupType.getType(fieldName), path, depth + 1); + } + } + return false; + } + private void checkEndOfRowGroup() throws IOException { if (rowsReturned != totalCountLoadedSoFar) return; - PageReadStore pages = reader.readNextFilteredRowGroup(); + PageReadStore pages = reader.readNextRowGroup(); if (pages == null) { throw new IOException("expecting more rows but reached last block. Read " + rowsReturned + " out of " + totalRowCount); } - List columns = requestedSchema.getColumns(); - List types = requestedSchema.asGroupType().getFields(); - columnReaders = new VectorizedColumnReader[columns.size()]; - for (int i = 0; i < columns.size(); ++i) { - if (missingColumns[i]) continue; - columnReaders[i] = new VectorizedColumnReader( - columns.get(i), - types.get(i).getLogicalTypeAnnotation(), - pages.getPageReader(columns.get(i)), - pages.getRowIndexes().orElse(null), - convertTz, - datetimeRebaseMode, - int96RebaseMode); + for (ParquetColumn column : columns) { + initColumnReader(pages, column); } totalCountLoadedSoFar += pages.getRowCount(); } + + private void initColumnReader(PageReadStore pages, ParquetColumn column) throws IOException { + if (!missingColumns.contains(column.getType())) { + if (column.getType().isPrimitive()) { + ParquetType colType = column.getType(); + Preconditions.checkArgument(colType.isPrimitive()); + VectorizedColumnReader reader = new VectorizedColumnReader( + colType.descriptor().get(), colType.required(), pages, convertTz, datetimeRebaseMode, + int96RebaseMode); + column.setColumnReader(reader); + } else { + // not in missing columns and is a complex type: this must be a struct + for (ParquetColumn childCol : column.getChildren()) { + initColumnReader(pages, childCol); + } + } + } + } } diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedRleValuesReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedRleValuesReader.java index af739a52d8ed..c24048f34700 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedRleValuesReader.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedRleValuesReader.java @@ -166,9 +166,10 @@ public int readInteger() { public void readBatch( ParquetReadState state, WritableColumnVector values, + WritableColumnVector defLevels, VectorizedValuesReader valueReader, ParquetVectorUpdater updater) { - readBatchInternal(state, values, values, valueReader, updater); + readBatchInternal(state, values, values, defLevels, valueReader, updater); } /** @@ -179,21 +180,23 @@ public void readIntegers( ParquetReadState state, WritableColumnVector values, WritableColumnVector nulls, - VectorizedValuesReader data) { - readBatchInternal(state, values, nulls, data, new ParquetVectorUpdaterFactory.IntegerUpdater()); + WritableColumnVector defLevels, + 
VectorizedValuesReader valueReader) throws IOException { + readBatchInternal(state, values, nulls, defLevels, valueReader, + new ParquetVectorUpdaterFactory.IntegerUpdater()); } private void readBatchInternal( ParquetReadState state, WritableColumnVector values, WritableColumnVector nulls, + WritableColumnVector defLevels, VectorizedValuesReader valueReader, ParquetVectorUpdater updater) { - int offset = state.offset; - long rowId = state.rowId; - int leftInBatch = state.valuesToReadInBatch; + int leftInBatch = state.rowsToReadInBatch; int leftInPage = state.valuesToReadInPage; + long rowId = state.rowId; while (leftInBatch > 0 && leftInPage > 0) { if (this.currentCount == 0) this.readNextGroup(); @@ -227,44 +230,352 @@ private void readBatchInternal( switch (mode) { case RLE: if (currentValue == state.maxDefinitionLevel) { - updater.readValues(n, offset, values, valueReader); - } else { - nulls.putNulls(offset, n); + updater.readValues(n, state.valueOffset, values, valueReader); + state.valueOffset += n; + } else if (!state.isRequired && currentValue == state.maxDefinitionLevel - 1) { + // only add null if this represents a null element, but not the case when a + // struct is null + nulls.putNulls(state.valueOffset, n); + state.valueOffset += n; } + defLevels.putInts(state.levelOffset, n, currentValue); break; case PACKED: for (int i = 0; i < n; ++i) { - if (currentBuffer[currentBufferIdx++] == state.maxDefinitionLevel) { - updater.readValue(offset + i, values, valueReader); - } else { - nulls.putNull(offset + i); + int currentValue = currentBuffer[currentBufferIdx++]; + if (currentValue == state.maxDefinitionLevel) { + updater.readValue(state.valueOffset++, values, valueReader); + } else if (!state.isRequired && currentValue == state.maxDefinitionLevel - 1) { + // only add null if this represents a null element, but not the case when a + // struct is null + nulls.putNull(state.valueOffset++); } + defLevels.putInt(state.levelOffset + i, currentValue); } break; } - offset += n; + state.levelOffset += n; leftInBatch -= n; rowId += n; leftInPage -= n; currentCount -= n; + defLevels.addElementsAppended(n); } } - state.advanceOffsetAndRowId(offset, rowId); + state.rowsToReadInBatch = leftInBatch; + state.valuesToReadInPage = leftInPage; + state.rowId = rowId; + } + + public void readBatchNested( + ParquetReadState state, + WritableColumnVector repLevels, + VectorizedRleValuesReader defLevelsReader, + WritableColumnVector defLevels, + WritableColumnVector values, + VectorizedValuesReader valueReader, + ParquetVectorUpdater updater) { + readBatchNestedInternal(state, repLevels, defLevelsReader, defLevels, values, values, true, + valueReader, updater); + } + + public void readIntegersNested( + ParquetReadState state, + WritableColumnVector repLevels, + VectorizedRleValuesReader defLevelsReader, + WritableColumnVector defLevels, + WritableColumnVector values, + WritableColumnVector nulls, + VectorizedValuesReader valueReader) { + readBatchNestedInternal(state, repLevels, defLevelsReader, defLevels, values, nulls, false, + valueReader, new ParquetVectorUpdaterFactory.IntegerUpdater()); } /** - * Skip the next `n` values (either null or non-null) from this definition level reader and - * `valueReader`. + * Keep reading repetition level values from the page until either: 1) we've read enough + * top-level rows to fill the current batch, or 2) we've drained the data page completely. 
+ * + * @param valuesReused whether `values` vector is reused for `nulls` */ - private void skipValues( - int n, + public void readBatchNestedInternal( ParquetReadState state, - VectorizedValuesReader valuesReader, + WritableColumnVector repLevels, + VectorizedRleValuesReader defLevelsReader, + WritableColumnVector defLevels, + WritableColumnVector values, + WritableColumnVector nulls, + boolean valuesReused, + VectorizedValuesReader valueReader, ParquetVectorUpdater updater) { + + int leftInBatch = state.rowsToReadInBatch; + int leftInPage = state.valuesToReadInPage; + long rowId = state.rowId; + + DefLevelProcessor defLevelProcessor = new DefLevelProcessor(defLevelsReader, state, defLevels, + values, nulls, valuesReused, valueReader, updater); + + while ((leftInBatch > 0 || !state.lastListCompleted) && leftInPage > 0) { + if (currentCount == 0 && !readNextGroup()) break; + + // values to read in the current RLE/PACKED block, must be <= what's left in the page + int valuesLeftInBlock = Math.min(leftInPage, currentCount); + + // the current row range start and end + long rangeStart = state.currentRangeStart(); + long rangeEnd = state.currentRangeEnd(); + + switch (mode) { + case RLE: + // this RLE block is consist of top-level rows, so we'll need to check + // if the rows should be skipped according to row indexes. + if (currentValue == 0) { + if (leftInBatch == 0) { + state.lastListCompleted = true; + } else { + // # of rows to read in the block, must be <= what's left in the current batch + int n = Math.min(leftInBatch, valuesLeftInBlock); + + if (rowId + n < rangeStart) { + // need to skip all rows in [rowId, rowId + n) + defLevelProcessor.skipValues(n); + rowId += n; + currentCount -= n; + leftInPage -= n; + } else if (rowId > rangeEnd) { + // the current row index already beyond the current range: move to the next range + // and repeat + state.nextRange(); + } else { + // the range [rowId, rowId + n) overlaps with the current row range + long start = Math.max(rangeStart, rowId); + long end = Math.min(rangeEnd, rowId + n - 1); + + // skip the rows in [rowId, start) + int toSkip = (int) (start - rowId); + if (toSkip > 0) { + defLevelProcessor.skipValues(toSkip); + rowId += toSkip; + currentCount -= toSkip; + leftInPage -= toSkip; + } + + // read the rows in [start, end] + n = (int) (end - start + 1); + + leftInBatch -= n; + if (n > 0) { + repLevels.appendInts(n, 0); + defLevelProcessor.readValues(n); + } + + rowId += n; + currentCount -= n; + leftInPage -= n; + } + } + } else { + // not a top-level row: just read all the repetition levels in the block if the row + // should be included according to row indexes, else skip the rows. 
+ if (!state.shouldSkip) { + repLevels.appendInts(valuesLeftInBlock, currentValue); + } + state.numBatchedDefLevels += valuesLeftInBlock; + leftInPage -= valuesLeftInBlock; + currentCount -= valuesLeftInBlock; + } + break; + case PACKED: + int i = 0; + + for (; i < valuesLeftInBlock; i++) { + int currentValue = currentBuffer[currentBufferIdx + i]; + if (currentValue == 0) { + if (leftInBatch == 0) { + state.lastListCompleted = true; + break; + } else if (rowId < rangeStart) { + // this is a top-level row, therefore check if we should skip it with row indexes + // the row is before the current range, skip it + defLevelProcessor.skipValues(1); + } else if (rowId > rangeEnd) { + // the row is after the current range, move to the next range and compare again + state.nextRange(); + break; + } else { + // the row is in the current range, decrement the row counter and read it + leftInBatch--; + repLevels.appendInt(0); + defLevelProcessor.readValues(1); + } + rowId++; + } else { + if (!state.shouldSkip) { + repLevels.appendInt(currentValue); + } + state.numBatchedDefLevels += 1; + } + } + + leftInPage -= i; + currentCount -= i; + currentBufferIdx += i; + break; + } + } + + // process all the batched def levels + defLevelProcessor.finish(); + + state.rowsToReadInBatch = leftInBatch; + state.valuesToReadInPage = leftInPage; + state.rowId = rowId; + } + + private static class DefLevelProcessor { + private final VectorizedRleValuesReader reader; + private final ParquetReadState state; + private final WritableColumnVector defLevels; + private final WritableColumnVector values; + private final WritableColumnVector nulls; + private final boolean valuesReused; + private final VectorizedValuesReader valueReader; + private final ParquetVectorUpdater updater; + + DefLevelProcessor( + VectorizedRleValuesReader reader, + ParquetReadState state, + WritableColumnVector defLevels, + WritableColumnVector values, + WritableColumnVector nulls, + boolean valuesReused, + VectorizedValuesReader valueReader, + ParquetVectorUpdater updater) { + this.reader = reader; + this.state = state; + this.defLevels = defLevels; + this.values = values; + this.nulls = nulls; + this.valuesReused = valuesReused; + this.valueReader = valueReader; + this.updater = updater; + } + + void readValues(int n) { + if (!state.shouldSkip) { + state.numBatchedDefLevels += n; + } else { + reader.skipValues(state.numBatchedDefLevels, state, valueReader, updater); + state.numBatchedDefLevels = n; + state.shouldSkip = false; + } + } + + void skipValues(int n) { + if (state.shouldSkip) { + state.numBatchedDefLevels += n; + } else { + reader.readValues(state.numBatchedDefLevels, state, defLevels, values, nulls, valuesReused, + valueReader, updater); + state.numBatchedDefLevels = n; + state.shouldSkip = true; + } + } + + void finish() { + if (state.numBatchedDefLevels > 0) { + if (state.shouldSkip) { + reader.skipValues(state.numBatchedDefLevels, state, valueReader, updater); + } else { + reader.readValues(state.numBatchedDefLevels, state, defLevels, values, nulls, + valuesReused, valueReader, updater); + } + state.numBatchedDefLevels = 0; + } + } + } + + /** + * Read the next 'total' values (either null or non-null) from this definition level reader and + * 'valueReader'. The definition levels are read into 'defLevels'. If a value is not + * null, it is appended to 'values'. Otherwise, a null bit will be set in 'nulls'. + * + * This is only used when reading repeated values. 
+ */ + private void readValues( + int total, + ParquetReadState state, + WritableColumnVector defLevels, + WritableColumnVector values, + WritableColumnVector nulls, + boolean valuesReused, + VectorizedValuesReader valueReader, + ParquetVectorUpdater updater) { + defLevels.reserveAdditional(total); + values.reserveAdditional(total); + if (!valuesReused) { + nulls.reserveAdditional(total); + } + int n = total; + int initialValueOffset = state.valueOffset; while (n > 0) { if (this.currentCount == 0) this.readNextGroup(); int num = Math.min(n, this.currentCount); + switch (mode) { + case RLE: + if (currentValue == state.maxDefinitionLevel) { + updater.readValues(num, state.valueOffset, values, valueReader); + state.valueOffset += num; + } else if (!state.isRequired && currentValue == state.maxDefinitionLevel - 1) { + // only add null if this represents a null element, but not the case when a + // collection is null or empty. + nulls.putNulls(state.valueOffset, num); + state.valueOffset += num; + } + defLevels.putInts(state.levelOffset, num, currentValue); + break; + case PACKED: + for (int i = 0; i < num; ++i) { + int currentValue = currentBuffer[currentBufferIdx++]; + if (currentValue == state.maxDefinitionLevel) { + updater.readValue(state.valueOffset++, values, valueReader); + } else if (!state.isRequired && currentValue == state.maxDefinitionLevel - 1) { + // only add null if this represents a null element, but not the case when a + // collection is null or empty. + nulls.putNull(state.valueOffset++); + } + defLevels.putInt(state.levelOffset + i, currentValue); + } + break; + } + state.levelOffset += num; + currentCount -= num; + n -= num; + } + defLevels.addElementsAppended(total); + + int valuesRead = state.valueOffset - initialValueOffset; + values.addElementsAppended(valuesRead); + if (!valuesReused) { + nulls.addElementsAppended(valuesRead); + } + } + + /** + * Skip the next 'total' values (either null or non-null) from this definition level reader and + * 'valuesReader'. + * + * This is used in reading both non-repeated and repeated values. + */ + private void skipValues( + int total, + ParquetReadState state, + VectorizedValuesReader valuesReader, + ParquetVectorUpdater updater) { + while (total > 0) { + if (this.currentCount == 0) this.readNextGroup(); + int num = Math.min(total, this.currentCount); switch (mode) { case RLE: // we only need to skip non-null values from `valuesReader` since nulls are represented @@ -283,7 +594,7 @@ private void skipValues( break; } currentCount -= num; - n -= num; + total -= num; } } @@ -497,7 +808,12 @@ private int readIntLittleEndianPaddedOnBitWidth() throws IOException { /** * Reads the next group. */ - private void readNextGroup() { + private boolean readNextGroup() { + if (in.available() <= 0) { + currentCount = 0; + return false; + } + try { int header = readUnsignedVarInt(); this.mode = (header & 1) == 0 ? 
MODE.RLE : MODE.PACKED; @@ -505,7 +821,7 @@ private void readNextGroup() { case RLE: this.currentCount = header >>> 1; this.currentValue = readIntLittleEndianPaddedOnBitWidth(); - return; + break; case PACKED: int numGroups = header >>> 1; this.currentCount = numGroups * 8; @@ -521,12 +837,13 @@ private void readNextGroup() { this.packer.unpack8Values(buffer, buffer.position(), this.currentBuffer, valueIndex); valueIndex += 8; } - return; + break; default: throw new ParquetDecodingException("not a valid mode " + this.mode); } } catch (IOException e) { throw new ParquetDecodingException("Failed to read from input stream", e); } + return true; } } diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java index b4b6903cda24..e5dc033acaf2 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java @@ -64,6 +64,9 @@ public static OffHeapColumnVector[] allocateColumns(int capacity, StructField[] private long lengthData; private long offsetData; + // Only set if type is Struct + private long structOffsetData; + public OffHeapColumnVector(int capacity, DataType type) { super(capacity, type); @@ -71,6 +74,7 @@ public OffHeapColumnVector(int capacity, DataType type) { data = 0; lengthData = 0; offsetData = 0; + structOffsetData = 0; reserveInternal(capacity); reset(); @@ -91,10 +95,12 @@ public void close() { Platform.freeMemory(data); Platform.freeMemory(lengthData); Platform.freeMemory(offsetData); + Platform.freeMemory(structOffsetData); nulls = 0; data = 0; lengthData = 0; offsetData = 0; + structOffsetData = 0; } // @@ -132,7 +138,7 @@ public void putNotNulls(int rowId, int count) { @Override public boolean isNullAt(int rowId) { - return Platform.getByte(null, nulls + rowId) == 1; + return isAllNull || Platform.getByte(null, nulls + rowId) == 1; } // @@ -527,6 +533,11 @@ public int getArrayOffset(int rowId) { return Platform.getInt(null, offsetData + 4L * rowId); } + @Override + public int getStructOffset(int rowId) { + return Platform.getInt(null, structOffsetData + 4L * rowId); + } + // APIs dealing with ByteArrays @Override public int putByteArray(int rowId, byte[] value, int offset, int length) { @@ -536,6 +547,11 @@ public int putByteArray(int rowId, byte[] value, int offset, int length) { return result; } + @Override + public void putStruct(int rowId, int offset) { + Platform.putInt(null, structOffsetData + 4L * rowId, offset); + } + // Split out the slow path. 
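  // For struct columns (see isStruct()), reserveInternal below also keeps a 4-byte offset per
  // row in structOffsetData, mirroring how array columns track lengthData and offsetData.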
@Override protected void reserveInternal(int newCapacity) { @@ -545,6 +561,9 @@ protected void reserveInternal(int newCapacity) { Platform.reallocateMemory(lengthData, oldCapacity * 4L, newCapacity * 4L); this.offsetData = Platform.reallocateMemory(offsetData, oldCapacity * 4L, newCapacity * 4L); + } else if (isStruct()) { + this.structOffsetData = + Platform.reallocateMemory(structOffsetData, oldCapacity * 4L, newCapacity * 4L); } else if (type instanceof ByteType || type instanceof BooleanType) { this.data = Platform.reallocateMemory(data, oldCapacity, newCapacity); } else if (type instanceof ShortType) { diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java index 3fb96d872cd8..9bed6bd0e728 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java @@ -73,6 +73,9 @@ public static OnHeapColumnVector[] allocateColumns(int capacity, StructField[] f private int[] arrayLengths; private int[] arrayOffsets; + // Only set if type is Struct + private int[] structOffsets; + public OnHeapColumnVector(int capacity, DataType type) { super(capacity, type); @@ -92,6 +95,7 @@ public void close() { doubleData = null; arrayLengths = null; arrayOffsets = null; + structOffsets = null; } // @@ -127,7 +131,7 @@ public void putNotNulls(int rowId, int count) { @Override public boolean isNullAt(int rowId) { - return nulls[rowId] == 1; + return isAllNull || nulls[rowId] == 1; } // @@ -492,6 +496,11 @@ public int getArrayOffset(int rowId) { return arrayOffsets[rowId]; } + @Override + public int getStructOffset(int rowId) { + return structOffsets[rowId]; + } + @Override public void putArray(int rowId, int offset, int length) { arrayOffsets[rowId] = offset; @@ -510,6 +519,11 @@ public int putByteArray(int rowId, byte[] value, int offset, int length) { return result; } + @Override + public void putStruct(int rowId, int offset) { + structOffsets[rowId] = offset; + } + // Spilt this function out since it is the slow path. 
@Override protected void reserveInternal(int newCapacity) { @@ -522,6 +536,12 @@ protected void reserveInternal(int newCapacity) { } arrayLengths = newLengths; arrayOffsets = newOffsets; + } else if (isStruct()) { + int[] newOffsets = new int[newCapacity]; + if (this.structOffsets != null) { + System.arraycopy(this.structOffsets, 0, newOffsets, 0, capacity); + } + structOffsets = newOffsets; } else if (type instanceof BooleanType) { if (byteData == null || byteData.length < newCapacity) { byte[] newData = new byte[newCapacity]; diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java index 8f7dcf237440..150327a31ac9 100644 --- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java +++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java @@ -26,6 +26,7 @@ import org.apache.spark.sql.vectorized.ColumnVector; import org.apache.spark.sql.vectorized.ColumnarArray; import org.apache.spark.sql.vectorized.ColumnarMap; +import org.apache.spark.sql.vectorized.ColumnarRow; import org.apache.spark.unsafe.array.ByteArrayMethods; import org.apache.spark.unsafe.types.CalendarInterval; import org.apache.spark.unsafe.types.UTF8String; @@ -51,7 +52,7 @@ public abstract class WritableColumnVector extends ColumnVector { * Resets this column for writing. The currently stored values are no longer accessible. */ public void reset() { - if (isConstant) return; + if (isConstant || isAllNull) return; if (childColumns != null) { for (WritableColumnVector c: childColumns) { @@ -81,6 +82,10 @@ public void close() { dictionary = null; } + public void reserveAdditional(int additionalCapacity) { + reserve(elementsAppended + additionalCapacity); + } + public void reserve(int requiredCapacity) { if (requiredCapacity < 0) { throwUnsupportedException(requiredCapacity, null); @@ -115,7 +120,7 @@ private void throwUnsupportedException(int requiredCapacity, Throwable cause) { @Override public boolean hasNull() { - return numNulls > 0; + return isAllNull || numNulls > 0; } @Override @@ -350,6 +355,11 @@ public WritableColumnVector reserveDictionaryIds(int capacity) { */ public abstract void putArray(int rowId, int offset, int length); + /** + * Puts a struct with the given offset which indicates its position across all non-null entries. + */ + public abstract void putStruct(int rowId, int offset); + /** * Sets values from [value + offset, value + offset + count) to the values at rowId. */ @@ -433,6 +443,7 @@ public final int appendNull() { } public final int appendNotNull() { + assert (!(dataType() instanceof StructType)); // Use appendStruct() reserve(elementsAppended + 1); putNotNull(elementsAppended); return elementsAppended++; @@ -625,21 +636,23 @@ public final int appendArray(int length) { * common non-struct case. 
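   * For example, calling appendStruct(true) on a struct&lt;int, string&gt; column marks the parent row
   * as null and still appends a placeholder null to every child vector, keeping the children
   * aligned with the parent row count; putStruct then records which offset in the children
   * belongs to this row.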
*/ public final int appendStruct(boolean isNull) { + reserve(elementsAppended + 1); if (isNull) { // This is the same as appendNull but without the assertion for struct types - reserve(elementsAppended + 1); putNull(elementsAppended); - elementsAppended++; for (WritableColumnVector c: childColumns) { - if (c.type instanceof StructType) { + if (c.isStruct()) { c.appendStruct(true); } else { c.appendNull(); } } } else { - appendNotNull(); + reserve(elementsAppended + 1); + putNotNull(elementsAppended); } + putStruct(elementsAppended, elementsAppended); + elementsAppended++; return elementsAppended; } @@ -659,6 +672,12 @@ public final ColumnarMap getMap(int rowId) { return new ColumnarMap(getChild(0), getChild(1), getArrayOffset(rowId), getArrayLength(rowId)); } + @Override + public final ColumnarRow getStruct(int rowId) { + if (isNullAt(rowId)) return null; + return new ColumnarRow(this, getStructOffset(rowId)); + } + public WritableColumnVector arrayData() { return childColumns[0]; } @@ -667,19 +686,49 @@ public WritableColumnVector arrayData() { public abstract int getArrayOffset(int rowId); + public abstract int getStructOffset(int rowId); + @Override public WritableColumnVector getChild(int ordinal) { return childColumns[ordinal]; } + public int getNumChildren() { + return childColumns.length; + } + /** * Returns the elements appended. */ public final int getElementsAppended() { return elementsAppended; } + /** + * Add `num` to the elements appended. This is useful when calling the `putXXX` APIs, for + * keeping track of how many values have been added in the vector. + */ + public final void addElementsAppended(int num) { + this.elementsAppended += num; + } + /** * Marks this column as being constant. */ public final void setIsConstant() { isConstant = true; } + /** + * Marks this column only contain null values. + */ + public final void setAllNull() { + isAllNull = true; + } + + /** + * Returns true iff 'setAllNull' is called, which means this is a constant column vector with + * all values being null. If this returns false, it doesn't necessarily mean the vector + * contains non-null values. Rather it means the null-ness is unknown. + */ + public final boolean isAllNull() { + return isAllNull; + } + /** * Maximum number of rows that can be stored in this column. */ @@ -702,6 +751,11 @@ public WritableColumnVector arrayData() { */ protected boolean isConstant; + /** + * True if this column's values are all null. + */ + protected boolean isAllNull; + /** * Default size of each array length value. This grows as necessary. */ @@ -727,6 +781,10 @@ protected boolean isArray() { DecimalType.isByteArrayDecimalType(type); } + protected boolean isStruct() { + return type instanceof StructType || type instanceof CalendarIntervalType; + } + /** * Sets up the common state and also handles creating the child columns if this is a nested * type. 
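As a rough illustration of how the new struct bookkeeping in WritableColumnVector is meant to fit together, here is a minimal Scala sketch (not part of the patch; the column shape, field name, and values are invented for the example). It appends one non-null and one null struct row to an OnHeapColumnVector and reads them back through getStruct, which now resolves rows via the offsets recorded by putStruct/appendStruct:

  import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector
  import org.apache.spark.sql.types.{IntegerType, StructType}

  val col = new OnHeapColumnVector(4, new StructType().add("i", IntegerType))
  col.getChild(0).appendInt(7)             // field value backing row 0
  col.appendStruct(false)                  // row 0: non-null, struct offset 0 is recorded
  col.appendStruct(true)                   // row 1: null struct, children get a placeholder null
  assert(col.getStruct(0).getInt(0) == 7)  // resolved through the new getStructOffset path
  assert(col.getStruct(1) == null)         // null rows still surface as null
  col.close()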
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala index d3ac077ccf4a..e575a418bc7b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala @@ -45,6 +45,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjectio import org.apache.spark.sql.catalyst.parser.LegacyTypeStringParser import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.errors.QueryExecutionErrors +import org.apache.spark.sql.execution.WholeStageCodegenExec import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.execution.vectorized.{OffHeapColumnVector, OnHeapColumnVector} import org.apache.spark.sql.internal.SQLConf @@ -169,8 +170,24 @@ class ParquetFileFormat override def supportBatch(sparkSession: SparkSession, schema: StructType): Boolean = { val conf = sparkSession.sessionState.conf conf.parquetVectorizedReaderEnabled && conf.wholeStageEnabled && - schema.length <= conf.wholeStageMaxNumFields && - schema.forall(_.dataType.isInstanceOf[AtomicType]) + isBatchReadSupported(conf, schema) && !WholeStageCodegenExec.isTooManyFields(conf, schema) + } + + private def isBatchReadSupported(sqlConf: SQLConf, dt: DataType): Boolean = dt match { + case _: AtomicType => + true + case at: ArrayType => + sqlConf.parquetVectorizedReaderNestedColumnEnabled && + isBatchReadSupported(sqlConf, at.elementType) + case mt: MapType => + sqlConf.parquetVectorizedReaderNestedColumnEnabled && + isBatchReadSupported(sqlConf, mt.keyType) && + isBatchReadSupported(sqlConf, mt.valueType) + case st: StructType => + sqlConf.parquetVectorizedReaderNestedColumnEnabled && + st.fields.forall(f => isBatchReadSupported(sqlConf, f.dataType)) + case _ => + false } override def vectorTypes( @@ -239,7 +256,7 @@ class ParquetFileFormat val enableOffHeapColumnVector = sqlConf.offHeapColumnVectorEnabled val enableVectorizedReader: Boolean = sqlConf.parquetVectorizedReaderEnabled && - resultSchema.forall(_.dataType.isInstanceOf[AtomicType]) + resultSchema.map(_.dataType).forall(isBatchReadSupported(sqlConf, _)) val enableRecordFilter: Boolean = sqlConf.parquetRecordFilterEnabled val timestampConversion: Boolean = sqlConf.isParquetINT96TimestampConversion val capacity = sqlConf.parquetVectorizedReaderBatchSize diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRowConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRowConverter.scala index e0af5d8dd869..f982402f65a2 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRowConverter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRowConverter.scala @@ -25,8 +25,9 @@ import scala.collection.JavaConverters._ import scala.collection.mutable.ArrayBuffer import org.apache.parquet.column.Dictionary +import org.apache.parquet.io.ColumnIOFactory import org.apache.parquet.io.api.{Binary, Converter, GroupConverter, PrimitiveConverter} -import org.apache.parquet.schema.{GroupType, Type} +import org.apache.parquet.schema.{GroupType, Type, Types} import org.apache.parquet.schema.LogicalTypeAnnotation._ import 
org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName.{BINARY, FIXED_LEN_BYTE_ARRAY, INT32, INT64, INT96} @@ -601,7 +602,9 @@ private[parquet] class ParquetRowConverter( // Here we try to convert field `list` into a Catalyst type to see whether the converted type // matches the Catalyst array element type. If it doesn't match, then it's case 1; otherwise, // it's case 2. - val guessedElementType = schemaConverter.convertField(repeatedType) + val messageType = Types.buildMessage().addField(repeatedType).named("foo") + val column = new ColumnIOFactory().getColumnIO(messageType) + val guessedElementType = schemaConverter.convertField(column.getChild(0)).sparkType if (DataType.equalsIgnoreCompatibleNullability(guessedElementType, elementType)) { // If the repeated field corresponds to the element type, creates a new converter using the diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaConverter.scala index f3bfd99368de..fedcfc0fb9a6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaConverter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaConverter.scala @@ -17,9 +17,8 @@ package org.apache.spark.sql.execution.datasources.parquet -import scala.collection.JavaConverters._ - import org.apache.hadoop.conf.Configuration +import org.apache.parquet.io.{ColumnIO, ColumnIOFactory, GroupColumnIO, PrimitiveColumnIO} import org.apache.parquet.schema._ import org.apache.parquet.schema.LogicalTypeAnnotation._ import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName._ @@ -30,7 +29,6 @@ import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types._ - /** * This converter class is used to convert Parquet [[MessageType]] to Spark SQL [[StructType]]. * @@ -60,40 +58,106 @@ class ParquetToSparkSchemaConverter( /** * Converts Parquet [[MessageType]] `parquetSchema` to a Spark SQL [[StructType]]. */ - def convert(parquetSchema: MessageType): StructType = convert(parquetSchema.asGroupType()) + def convert(parquetSchema: MessageType): StructType = { + val column = new ColumnIOFactory().getColumnIO(parquetSchema) + val converted = convertInternal(column) + converted.sparkType.asInstanceOf[StructType] + } - private def convert(parquetSchema: GroupType): StructType = { - val fields = parquetSchema.getFields.asScala.map { field => - field.getRepetition match { - case OPTIONAL => - StructField(field.getName, convertField(field), nullable = true) + /** + * Convert `parquetSchema` into a [[ParquetType]] which contains its corresponding Spark + * SQL [[StructType]] along with other information such as the maximum repetition and definition + * level of each node, column descriptor for the leave nodes, etc. + * + * If `sparkReadSchema` is not empty, when deriving Spark SQL type from a Parquet field this will + * check if the same field also exists in the schema. If so, it will use the Spark SQL type. + * This is necessary since conversion from Parquet to Spark could cause precision loss. For + * instance, Spark read schema is smallint/tinyint but Parquet only support int. 
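+ * For example, if the Spark read schema declares a column as ShortType while the Parquet file
+ * stores it as INT32, keeping the Spark SQL type avoids widening the column back to IntegerType.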
+ */
+  def convertTypeInfo(
+      parquetSchema: MessageType,
+      sparkReadSchema: Option[StructType] = None,
+      caseSensitive: Boolean = true): ParquetType = {
+    val column = new ColumnIOFactory().getColumnIO(parquetSchema)
+    convertInternal(column, sparkReadSchema, caseSensitive)
+  }
-      case REQUIRED =>
-        StructField(field.getName, convertField(field), nullable = false)
+  private def convertInternal(
+      groupColumn: GroupColumnIO,
+      sparkReadSchema: Option[StructType] = None,
+      caseSensitive: Boolean = true): ParquetType = {
+    val converted = (0 until groupColumn.getChildrenCount).map { i =>
+      val field = groupColumn.getChild(i)
+      var fieldReadType = sparkReadSchema.flatMap { schema =>
+        schema.find(f => isSameFieldName(f.name, field.getName, caseSensitive)).map(_.dataType)
+      }
+
+      // if a field is repeated here then it is neither contained by a `LIST` nor `MAP`
+      // annotated group (these should've been handled in `convertGroupField`), e.g.:
+      //
+      // message schema {
+      //   repeated int32 int_array;
+      // }
+      // or
+      // message schema {
+      //   repeated group struct_array {
+      //     optional int32 field;
+      //   }
+      // }
+      //
+      // the corresponding Spark read type should be an array and we should pass the element type
+      // to the group or primitive type conversion method.
+      if (field.getType.getRepetition == REPEATED) {
+        fieldReadType = fieldReadType.flatMap {
+          case at: ArrayType => Some(at.elementType)
+          case _ =>
+            throw QueryCompilationErrors.illegalParquetTypeError(groupColumn.toString)
+        }
+      }
+
+      var convertedField = convertField(field, fieldReadType)
+
+      field.getType.getRepetition match {
+        case OPTIONAL | REQUIRED =>
+          val nullable = field.getType.getRepetition == OPTIONAL
+          (StructField(field.getType.getName, convertedField.sparkType, nullable = nullable),
+            convertedField)
        case REPEATED =>
          // A repeated field that is neither contained by a `LIST`- or `MAP`-annotated group nor
          // annotated by `LIST` or `MAP` should be interpreted as a required list of required
          // elements where the element type is the type of the field.
-          val arrayType = ArrayType(convertField(field), containsNull = false)
-          StructField(field.getName, arrayType, nullable = false)
+          val arrayType = ArrayType(convertedField.sparkType, containsNull = false)
+          (StructField(field.getType.getName, arrayType, nullable = false),
+            ParquetType(arrayType, None, convertedField.repetitionLevel - 1,
+              convertedField.definitionLevel - 1, required = true, convertedField.path,
+              Seq(convertedField.copy(required = true))))
      }
    }
-    StructType(fields.toSeq)
+    ParquetType(StructType(converted.map(_._1)), groupColumn, converted.map(_._2))
  }
+  private def isSameFieldName(left: String, right: String, caseSensitive: Boolean): Boolean =
+    if (!caseSensitive) left.equalsIgnoreCase(right)
+    else left == right
+
  /**
   * Converts a Parquet [[Type]] to a Spark SQL [[DataType]].
*/ - def convertField(parquetType: Type): DataType = parquetType match { - case t: PrimitiveType => convertPrimitiveField(t) - case t: GroupType => convertGroupField(t.asGroupType()) + def convertField( + field: ColumnIO, + sparkReadType: Option[DataType] = None): ParquetType = field match { + case primitiveColumn: PrimitiveColumnIO => convertPrimitiveField(primitiveColumn, sparkReadType) + case groupColumn: GroupColumnIO => convertGroupField(groupColumn, sparkReadType) } - private def convertPrimitiveField(field: PrimitiveType): DataType = { - val typeName = field.getPrimitiveTypeName - val typeAnnotation = field.getLogicalTypeAnnotation + private def convertPrimitiveField( + primitiveColumn: PrimitiveColumnIO, + sparkReadType: Option[DataType] = None): ParquetType = { + val parquetType = primitiveColumn.getType.asPrimitiveType() + val typeAnnotation = primitiveColumn.getType.getLogicalTypeAnnotation + val typeName = primitiveColumn.getPrimitive def typeString = if (typeAnnotation == null) s"$typeName" else s"$typeName ($typeAnnotation)" @@ -108,7 +172,7 @@ class ParquetToSparkSchemaConverter( // specified in field.getDecimalMetadata. This is useful when interpreting decimal types stored // as binaries with variable lengths. def makeDecimalType(maxPrecision: Int = -1): DecimalType = { - val decimalLogicalTypeAnnotation = field.getLogicalTypeAnnotation + val decimalLogicalTypeAnnotation = typeAnnotation .asInstanceOf[DecimalLogicalTypeAnnotation] val precision = decimalLogicalTypeAnnotation.getPrecision val scale = decimalLogicalTypeAnnotation.getScale @@ -120,7 +184,7 @@ class ParquetToSparkSchemaConverter( DecimalType(precision, scale) } - typeName match { + val sparkType = sparkReadType.getOrElse(typeName match { case BOOLEAN => BooleanType case FLOAT => FloatType @@ -195,17 +259,23 @@ class ParquetToSparkSchemaConverter( case FIXED_LEN_BYTE_ARRAY => typeAnnotation match { case _: DecimalLogicalTypeAnnotation => - makeDecimalType(Decimal.maxPrecisionForBytes(field.getTypeLength)) + makeDecimalType(Decimal.maxPrecisionForBytes(parquetType.getTypeLength)) case _: IntervalLogicalTypeAnnotation => typeNotImplemented() case _ => illegalType() } case _ => illegalType() - } + }) + + ParquetType(sparkType, primitiveColumn) } - private def convertGroupField(field: GroupType): DataType = { - Option(field.getLogicalTypeAnnotation).fold(convert(field): DataType) { + private def convertGroupField( + groupColumn: GroupColumnIO, + sparkReadType: Option[DataType] = None): ParquetType = { + val field = groupColumn.getType.asGroupType() + Option(field.getLogicalTypeAnnotation).fold( + convertInternal(groupColumn, sparkReadType.map(_.asInstanceOf[StructType]))) { // A Parquet list is represented as a 3-level structure: // // group (LIST) { @@ -222,17 +292,37 @@ class ParquetToSparkSchemaConverter( case _: ListLogicalTypeAnnotation => ParquetSchemaConverter.checkConversionRequirement( field.getFieldCount == 1, s"Invalid list type $field") + ParquetSchemaConverter.checkConversionRequirement( + sparkReadType.forall(_.isInstanceOf[ArrayType]), + s"Invalid Spark read type: expected $field to be list type but found $sparkReadType") - val repeatedType = field.getType(0) + val repeated = groupColumn.getChild(0) + val repeatedType = repeated.getType ParquetSchemaConverter.checkConversionRequirement( repeatedType.isRepetition(REPEATED), s"Invalid list type $field") + val sparkReadElementType = sparkReadType.map(_.asInstanceOf[ArrayType].elementType) if (isElementType(repeatedType, field.getName)) { - 
ArrayType(convertField(repeatedType), containsNull = false) + var converted = convertField(repeated, sparkReadElementType) + val convertedType = sparkReadElementType.getOrElse(converted.sparkType) + + // legacy format such as: + // optional group my_list (LIST) { + // repeated int32 element; + // } + // we should mark the primitive field as required + if (repeatedType.isPrimitive) converted = converted.copy(required = true) + + ParquetType(ArrayType(convertedType, containsNull = false), + groupColumn, Seq(converted)) } else { - val elementType = repeatedType.asGroupType().getType(0) + val element = repeated.asInstanceOf[GroupColumnIO].getChild(0) + val elementType = element.getType val optional = elementType.isRepetition(OPTIONAL) - ArrayType(convertField(elementType), containsNull = optional) + val converted = convertField(element, sparkReadElementType) + val convertedType = sparkReadElementType.getOrElse(converted.sparkType) + ParquetType(ArrayType(convertedType, containsNull = optional), + groupColumn, Seq(converted)) } // scalastyle:off @@ -243,20 +333,30 @@ class ParquetToSparkSchemaConverter( ParquetSchemaConverter.checkConversionRequirement( field.getFieldCount == 1 && !field.getType(0).isPrimitive, s"Invalid map type: $field") + ParquetSchemaConverter.checkConversionRequirement( + sparkReadType.forall(_.isInstanceOf[MapType]), + s"Invalid Spark read type: expected $field to be map type but found $sparkReadType") - val keyValueType = field.getType(0).asGroupType() + val keyValue = groupColumn.getChild(0).asInstanceOf[GroupColumnIO] + val keyValueType = keyValue.getType.asGroupType() ParquetSchemaConverter.checkConversionRequirement( keyValueType.isRepetition(REPEATED) && keyValueType.getFieldCount == 2, s"Invalid map type: $field") - val keyType = keyValueType.getType(0) - val valueType = keyValueType.getType(1) + val key = keyValue.getChild(0) + val value = keyValue.getChild(1) + val sparkReadKeyType = sparkReadType.map(_.asInstanceOf[MapType].keyType) + val sparkReadValueType = sparkReadType.map(_.asInstanceOf[MapType].valueType) + val valueType = value.getType val valueOptional = valueType.isRepetition(OPTIONAL) - MapType( - convertField(keyType), - convertField(valueType), - valueContainsNull = valueOptional) - + val convertedKey = convertField(key, sparkReadKeyType) + val convertedValue = convertField(value, sparkReadValueType) + val convertedKeyType = sparkReadKeyType.getOrElse(convertedKey.sparkType) + val convertedValueType = sparkReadValueType.getOrElse(convertedValue.sparkType) + ParquetType( + MapType(convertedKeyType, convertedValueType, + valueContainsNull = valueOptional), + groupColumn, Seq(convertedKey, convertedValue)) case _ => throw QueryCompilationErrors.unrecognizedParquetTypeError(field.toString) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetType.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetType.scala new file mode 100644 index 000000000000..ddf4a95a337c --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetType.scala @@ -0,0 +1,68 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources.parquet + +import scala.collection.mutable + +import org.apache.parquet.column.ColumnDescriptor +import org.apache.parquet.io.{ColumnIOUtil, GroupColumnIO, PrimitiveColumnIO} +import org.apache.parquet.schema.Type.Repetition + +import org.apache.spark.sql.types.DataType + +/** + * Rich type information for a Parquet type together with its SparkSQL type. + */ +case class ParquetType( + sparkType: DataType, + descriptor: Option[ColumnDescriptor], // only set when this is a primitive type + repetitionLevel: Int, + definitionLevel: Int, + required: Boolean, + path: Seq[String], + children: Seq[ParquetType]) { + + def isPrimitive: Boolean = descriptor.nonEmpty + + /** + * Returns all the leaves (i.e., primitive columns) of this, in depth-first order. + */ + def leaves: Seq[ParquetType] = { + val buffer = mutable.ArrayBuffer[ParquetType]() + leaves0(buffer) + buffer.toSeq + } + + private def leaves0(buffer: mutable.ArrayBuffer[ParquetType]): Unit = { + children.foreach(_.leaves0(buffer)) + } +} + +object ParquetType { + def apply(sparkType: DataType, io: PrimitiveColumnIO): ParquetType = { + this(sparkType, Some(io.getColumnDescriptor), ColumnIOUtil.getRepetitionLevel(io), + ColumnIOUtil.getDefinitionLevel(io), io.getType.isRepetition(Repetition.REQUIRED), + ColumnIOUtil.getFieldPath(io), Seq.empty) + } + + def apply(sparkType: DataType, io: GroupColumnIO, children: Seq[ParquetType]): ParquetType = { + this(sparkType, None, ColumnIOUtil.getRepetitionLevel(io), + ColumnIOUtil.getDefinitionLevel(io), io.getType.isRepetition(Repetition.REQUIRED), + ColumnIOUtil.getFieldPath(io), children) + } +} diff --git a/sql/core/src/test/java/org/apache/parquet/column/page/TestDataPage.java b/sql/core/src/test/java/org/apache/parquet/column/page/TestDataPage.java new file mode 100644 index 000000000000..da3081919589 --- /dev/null +++ b/sql/core/src/test/java/org/apache/parquet/column/page/TestDataPage.java @@ -0,0 +1,44 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.parquet.column.page; + +import java.util.Optional; + +/** + * A hack to create Parquet data pages with customized first row index. We have to put it under + * 'org.apache.parquet.column.page' since the constructor of `DataPage` is package-private. 
+ */ +public class TestDataPage extends DataPage { + private final DataPage wrapped; + + public TestDataPage(DataPage wrapped, long firstRowIndex) { + super(wrapped.getCompressedSize(), wrapped.getUncompressedSize(), wrapped.getValueCount(), + firstRowIndex); + this.wrapped = wrapped; + } + + @Override + public Optional getIndexRowCount() { + return Optional.empty(); + } + + @Override + public T accept(Visitor visitor) { + return wrapped.accept(visitor); + } +} diff --git a/sql/core/src/test/resources/sql-tests/results/explain-aqe.sql.out b/sql/core/src/test/resources/sql-tests/results/explain-aqe.sql.out index f5e5b46d29ce..f98fb1eb2a57 100644 --- a/sql/core/src/test/resources/sql-tests/results/explain-aqe.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/explain-aqe.sql.out @@ -1125,7 +1125,8 @@ struct -- !query output == Physical Plan == *Filter v#x IN ([a],null) -+- FileScan parquet default.t[v#x] Batched: false, DataFilters: [v#x IN ([a],null)], Format: Parquet, Location [not included in comparison]/{warehouse_dir}/t], PartitionFilters: [], PushedFilters: [In(v, [[a],null])], ReadSchema: struct> ++- *ColumnarToRow + +- FileScan parquet default.t[v#x] Batched: true, DataFilters: [v#x IN ([a],null)], Format: Parquet, Location [not included in comparison]/{warehouse_dir}/t], PartitionFilters: [], PushedFilters: [In(v, [[a],null])], ReadSchema: struct> -- !query diff --git a/sql/core/src/test/resources/sql-tests/results/explain.sql.out b/sql/core/src/test/resources/sql-tests/results/explain.sql.out index 4e552d51a395..a563eda1e7b0 100644 --- a/sql/core/src/test/resources/sql-tests/results/explain.sql.out +++ b/sql/core/src/test/resources/sql-tests/results/explain.sql.out @@ -1067,7 +1067,8 @@ struct -- !query output == Physical Plan == *Filter v#x IN ([a],null) -+- FileScan parquet default.t[v#x] Batched: false, DataFilters: [v#x IN ([a],null)], Format: Parquet, Location [not included in comparison]/{warehouse_dir}/t], PartitionFilters: [], PushedFilters: [In(v, [[a],null])], ReadSchema: struct> ++- *ColumnarToRow + +- FileScan parquet default.t[v#x] Batched: true, DataFilters: [v#x IN ([a],null)], Format: Parquet, Location [not included in comparison]/{warehouse_dir}/t], PartitionFilters: [], PushedFilters: [In(v, [[a],null])], ReadSchema: struct> -- !query diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/DataSourceReadBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/DataSourceReadBenchmark.scala index 0fc43c7052d0..b6623d6f9d8e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/DataSourceReadBenchmark.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/DataSourceReadBenchmark.scala @@ -66,15 +66,23 @@ object DataSourceReadBenchmark extends SqlBasedBenchmark { try f finally tableNames.foreach(spark.catalog.dropTempView) } - private def prepareTable(dir: File, df: DataFrame, partition: Option[String] = None): Unit = { + private def prepareTable( + dir: File, + df: DataFrame, + partition: Option[String] = None, + isComplexType: Boolean = false): Unit = { val testDf = if (partition.isDefined) { df.write.partitionBy(partition.get) } else { df.write } - saveAsCsvTable(testDf, dir.getCanonicalPath + "/csv") - saveAsJsonTable(testDf, dir.getCanonicalPath + "/json") + // don't create CSV & JSON tables when benchmarking complex types as they don't support them + if (!isComplexType) { + saveAsCsvTable(testDf, dir.getCanonicalPath + "/csv") + saveAsJsonTable(testDf, 
dir.getCanonicalPath + "/json") + } + saveAsParquetTable(testDf, dir.getCanonicalPath + "/parquet") saveAsOrcTable(testDf, dir.getCanonicalPath + "/orc") } @@ -540,6 +548,106 @@ object DataSourceReadBenchmark extends SqlBasedBenchmark { } } + /** + * Similar to [[numericScanBenchmark]] but accessed column is a struct field. + */ + def nestedNumericScanBenchmark(values: Int, dataType: DataType): Unit = { + val sqlBenchmark = new Benchmark( + s"SQL Single ${dataType.sql} Column Scan in Struct", + values, + output = output) + + withTempPath { dir => + withTempTable("t1", "parquetTable", "orcTable") { + import spark.implicits._ + spark.range(values).map(_ => Random.nextLong).createOrReplaceTempView("t1") + + prepareTable(dir, + spark.sql(s"SELECT named_struct('f', CAST(value as ${dataType.sql})) as col FROM t1"), + isComplexType = true) + + sqlBenchmark.addCase("SQL Parquet MR") { _ => + withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "false") { + spark.sql("select sum(col.f) from parquetTable").noop() + } + } + + sqlBenchmark.addCase("SQL Parquet Vectorized (Disabled Nested Column)") { _ => + withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_NESTED_COLUMN_ENABLED.key -> "false") { + spark.sql("select sum(col.f) from parquetTable").noop() + } + } + + sqlBenchmark.addCase("SQL Parquet Vectorized (Enabled Nested Column)") { _ => + withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_NESTED_COLUMN_ENABLED.key -> "true") { + spark.sql("select sum(col.f) from parquetTable").noop() + } + } + + sqlBenchmark.run() + } + } + } + + def nestedColumnScanBenchmark(values: Int): Unit = { + val benchmark = new Benchmark(s"SQL Nested Column Scan", values, minNumIters = 10, + output = output) + + withTempPath { dir => + withTempTable("t1", "parquetTable", "orcTable") { + import spark.implicits._ + spark.range(values).map(_ => Random.nextLong).map { x => + val arrayOfStructColumn = (0 until 5).map(i => (x + i, s"$x" * 5)) + val mapOfStructColumn = Map( + s"$x" -> (x * 0.1, (x, s"$x" * 100)), + (s"$x" * 2) -> (x * 0.2, (x, s"$x" * 200)), + (s"$x" * 3) -> (x * 0.3, (x, s"$x" * 300))) + (arrayOfStructColumn, mapOfStructColumn) + }.toDF("col1", "col2").createOrReplaceTempView("t1") + + prepareTable(dir, spark.sql(s"SELECT * FROM t1"), isComplexType = true) + + benchmark.addCase("SQL ORC MR") { _ => + withSQLConf(SQLConf.ORC_VECTORIZED_READER_ENABLED.key -> "false") { + spark.sql("SELECT SUM(SIZE(col1)), SUM(SIZE(col2)) FROM orcTable").noop() + } + } + + benchmark.addCase("SQL ORC Vectorized (Disabled Nested Column)") { _ => + withSQLConf(SQLConf.ORC_VECTORIZED_READER_NESTED_COLUMN_ENABLED.key -> "false") { + spark.sql("SELECT SUM(SIZE(col1)), SUM(SIZE(col2)) FROM orcTable").noop() + } + } + + benchmark.addCase("SQL ORC Vectorized (Enabled Nested Column)") { _ => + withSQLConf(SQLConf.ORC_VECTORIZED_READER_NESTED_COLUMN_ENABLED.key -> "true") { + spark.sql("SELECT SUM(SIZE(col1)), SUM(SIZE(col2)) FROM orcTable").noop() + } + } + + benchmark.addCase("SQL Parquet MR") { _ => + withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "false") { + spark.sql("SELECT SUM(SIZE(col1)), SUM(SIZE(col2)) FROM parquetTable").noop() + } + } + + benchmark.addCase("SQL Parquet Vectorized (Disabled Nested Column)") { _ => + withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_NESTED_COLUMN_ENABLED.key -> "false") { + spark.sql("SELECT SUM(SIZE(col1)), SUM(SIZE(col2)) FROM parquetTable").noop() + } + } + + benchmark.addCase("SQL Parquet Vectorized (Enabled Nested Column)") { _ => + 
withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_NESTED_COLUMN_ENABLED.key -> "true") { + spark.sql("SELECT SUM(SIZE(col1)), SUM(SIZE(col2)) FROM parquetTable").noop() + } + } + + benchmark.run() + } + } + } + override def runBenchmarkSuite(mainArgs: Array[String]): Unit = { runBenchmark("SQL Single Numeric Column Scan") { Seq(ByteType, ShortType, IntegerType, LongType, FloatType, DoubleType).foreach { @@ -565,5 +673,13 @@ object DataSourceReadBenchmark extends SqlBasedBenchmark { columnsBenchmark(1024 * 1024 * 1, columnWidth) } } + runBenchmark("SQL Single Numeric Column Scan in Struct") { + Seq(ByteType, ShortType, IntegerType, LongType, FloatType, DoubleType).foreach { + dataType => nestedNumericScanBenchmark(1024 * 1024 * 15, dataType) + } + } + runBenchmark("SQL Nested Column Scan") { + nestedColumnScanBenchmark(1024 * 1024) + } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileBasedDataSourceTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileBasedDataSourceTest.scala index c2dc20b0099a..01fac6336155 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileBasedDataSourceTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileBasedDataSourceTest.scala @@ -38,6 +38,8 @@ private[sql] trait FileBasedDataSourceTest extends SQLTestUtils { protected val dataSourceName: String // The SQL config key for enabling vectorized reader. protected val vectorizedReaderEnabledKey: String + // The SQL config key for enabling vectorized reader for nested types. + protected val vectorizedReaderNestedEnabledKey: String /** * Reads data source file from given `path` as `DataFrame` and passes it to given function. @@ -52,7 +54,8 @@ private[sql] trait FileBasedDataSourceTest extends SQLTestUtils { f(spark.read.format(dataSourceName).load(path.toString)) } if (testVectorized) { - withSQLConf(vectorizedReaderEnabledKey -> "true") { + withSQLConf(vectorizedReaderEnabledKey -> "true", + vectorizedReaderNestedEnabledKey -> "true") { f(spark.read.format(dataSourceName).load(path.toString)) } } @@ -66,7 +69,8 @@ private[sql] trait FileBasedDataSourceTest extends SQLTestUtils { (data: Seq[T]) (f: String => Unit): Unit = { withTempPath { file => - spark.createDataFrame(data).write.format(dataSourceName).save(file.getCanonicalPath) + spark.createDataFrame(data).coalesce(1) + .write.format(dataSourceName).save(file.getCanonicalPath) f(file.getCanonicalPath) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcTest.scala index 4243318ac1dd..0e98e2a2803d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcTest.scala @@ -56,6 +56,8 @@ abstract class OrcTest extends QueryTest with FileBasedDataSourceTest with Befor override protected val dataSourceName: String = "orc" override protected val vectorizedReaderEnabledKey: String = SQLConf.ORC_VECTORIZED_READER_ENABLED.key + override protected val vectorizedReaderNestedEnabledKey: String = + SQLConf.ORC_VECTORIZED_READER_NESTED_COLUMN_ENABLED.key protected override def beforeAll(): Unit = { super.beforeAll() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcV1SchemaPruningSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcV1SchemaPruningSuite.scala index 2ce38dae47db..4d33eacecc13 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcV1SchemaPruningSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcV1SchemaPruningSuite.scala @@ -25,6 +25,8 @@ class OrcV1SchemaPruningSuite extends SchemaPruningSuite { override protected val dataSourceName: String = "orc" override protected val vectorizedReaderEnabledKey: String = SQLConf.ORC_VECTORIZED_READER_ENABLED.key + override protected val vectorizedReaderNestedEnabledKey: String = + SQLConf.ORC_VECTORIZED_READER_NESTED_COLUMN_ENABLED.key override protected def sparkConf: SparkConf = super diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcV2SchemaPruningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcV2SchemaPruningSuite.scala index 47254f4231d5..107a2b791202 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcV2SchemaPruningSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcV2SchemaPruningSuite.scala @@ -29,6 +29,8 @@ class OrcV2SchemaPruningSuite extends SchemaPruningSuite with AdaptiveSparkPlanH override protected val dataSourceName: String = "orc" override protected val vectorizedReaderEnabledKey: String = SQLConf.ORC_VECTORIZED_READER_ENABLED.key + override protected val vectorizedReaderNestedEnabledKey: String = + SQLConf.ORC_VECTORIZED_READER_NESTED_COLUMN_ENABLED.key override protected def sparkConf: SparkConf = super diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetColumnIndexSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetColumnIndexSuite.scala index 5f1c5b5cdb4e..4b1e0445fe6f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetColumnIndexSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetColumnIndexSuite.scala @@ -96,10 +96,23 @@ class ParquetColumnIndexSuite extends QueryTest with ParquetTest with SharedSpar test("SPARK-36123: reading from unaligned pages - test filters with nulls") { // insert 50 null values in [400, 450) to verify that they are skipped during processing row // range [500, 1000) against the second page of col_2 [400, 800) - var df = spark.range(0, 2000).map { i => + val df = spark.range(0, 2000).map { i => val strVal = if (i >= 400 && i < 450) null else i + ":" + "o" * (i / 100).toInt (i, strVal) }.toDF() checkUnalignedPages(df)(actions: _*) } + + test("SPARK-34861: reading unaligned pages - struct type") { + val df = (0 until 2000).map(i => Tuple1((i.toLong, i + ":" + "o" * (i / 100)))).toDF("s") + checkUnalignedPages(df)( + df => df.filter("s._1 = 500"), + df => df.filter("s._1 = 500 or s._1 = 1500"), + df => df.filter("s._1 = 500 or s._1 = 501 or s._1 = 1500"), + df => df.filter("s._1 = 500 or s._1 = 501 or s._1 = 1000 or s._1 = 1500"), + // range filter + df => df.filter("s._1 >= 500 and s._1 < 1000"), + df => df.filter("(s._1 >= 500 and s._1 < 1000) or (s._1 >= 1500 and s._1 < 1600)") + ) + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala index a330b82de2d0..8943b1680cfc 100644 --- 
a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala @@ -290,6 +290,332 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSparkSession } } + test("vectorized reader: array") { + val data = Seq( + Tuple1(null), + Tuple1(Seq()), + Tuple1(Seq("a", "b", "c")), + Tuple1(Seq(null)) + ) + + withParquetFile(data) { file => + readParquetFile(file) { df => + checkAnswer(df.sort("_1"), + Row(null) :: Row(Seq()) :: Row(Seq(null)) :: Row(Seq("a", "b", "c")) :: Nil + ) + } + } + } + + test("vectorized reader: missing array") { + withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_NESTED_COLUMN_ENABLED.key -> "true") { + val data = Seq( + Tuple1(null), + Tuple1(Seq()), + Tuple1(Seq("a", "b", "c")), + Tuple1(Seq(null)) + ) + + val readSchema = new StructType().add("_2", new ArrayType( + new StructType().add("a", LongType, nullable = true), + containsNull = true) + ) + + withParquetFile(data) { file => + checkAnswer(spark.read.schema(readSchema).parquet(file), + Row(null) :: Row(null) :: Row(null) :: Row(null) :: Nil + ) + } + } + } + + test("vectorized reader: array of array") { + val data = Seq( + Tuple1(Seq(Seq(0, 1), Seq(2, 3))), + Tuple1(Seq(Seq(4, 5), Seq(6, 7))) + ) + + withParquetFile(data) { file => + readParquetFile(file) { df => + checkAnswer(df.sort("_1"), + Row(Seq(Seq(0, 1), Seq(2, 3))) :: Row(Seq(Seq(4, 5), Seq(6, 7))) :: Nil + ) + } + } + } + + test("vectorized reader: array of struct") { + val data = Seq( + Tuple1(Tuple2("a", null)), + Tuple1(null), + Tuple1(Tuple2(null, null)), + Tuple1(Tuple2(null, Seq("b", "c"))), + Tuple1(Tuple2("d", Seq("e", "f"))), + Tuple1(null) + ) + + withParquetFile(data) { file => + readParquetFile(file) { df => + checkAnswer(df, + Row(Row("a", null)) :: Row(null) :: Row(Row(null, null)) :: + Row(Row(null, Seq("b", "c"))) :: Row(Row("d", Seq("e", "f"))) :: Row(null) :: Nil + ) + } + } + } + + test("vectorized reader: array of nested struct") { + val data = Seq( + Tuple1(Tuple2("a", null)), + Tuple1(Tuple2("b", Seq(Tuple2("c", "d")))), + Tuple1(null), + Tuple1(Tuple2("e", Seq(Tuple2("f", null), Tuple2(null, "g")))), + Tuple1(Tuple2(null, null)), + Tuple1(Tuple2(null, Seq(null))), + Tuple1(Tuple2(null, Seq(Tuple2(null, null), Tuple2("h", null), null))), + Tuple1(Tuple2("i", Seq())), + Tuple1(null) + ) + + withParquetFile(data) { file => + readParquetFile(file) { df => + checkAnswer(df, + Row(Row("a", null)) :: + Row(Row("b", Seq(Row("c", "d")))) :: + Row(null) :: + Row(Row("e", Seq(Row("f", null), Row(null, "g")))) :: + Row(Row(null, null)) :: + Row(Row(null, Seq(null))) :: + Row(Row(null, Seq(Row(null, null), Row("h", null), null))) :: + Row(Row("i", Seq())) :: + Row(null) :: + Nil) + } + } + } + + test("vectorized reader: required array with required elements") { + Seq(true, false).foreach { dictionaryEnabled => + def makeRawParquetFile(path: Path, expected: Seq[Seq[String]]): Unit = { + val schemaStr = + """message spark_schema { + | required group _1 (LIST) { + | repeated group list { + | required binary element (UTF8); + | } + | } + |} + """.stripMargin + val schema = MessageTypeParser.parseMessageType(schemaStr) + val writer = createParquetWriter(schema, path, dictionaryEnabled) + + val factory = new SimpleGroupFactory(schema) + expected.foreach { values => + val group = factory.newGroup() + val list = group.addGroup(0) + values.foreach { value => + list.addGroup(0).append("element", value) + } + 
writer.write(group) + } + writer.close() + } + + // write the following into the Parquet file: + // 0: [ "a", "b" ] + // 1: [ ] + // 2: [ "c", "d" ] + withTempDir { dir => + val path = new Path(dir.toURI.toString, "part-r-0.parquet") + val expected = Seq(Seq("a", "b"), Seq(), Seq("c", "d")) + makeRawParquetFile(path, expected) + readParquetFile(path.toString) { df => checkAnswer(df, expected.map(Row(_))) } + } + } + } + + test("vectorized reader: optional array with required elements") { + Seq(true, false).foreach { dictionaryEnabled => + def makeRawParquetFile(path: Path, expected: Seq[Seq[String]]): Unit = { + val schemaStr = + """message spark_schema { + | optional group _1 (LIST) { + | repeated group list { + | required binary element (UTF8); + | } + | } + |} + """.stripMargin + val schema = MessageTypeParser.parseMessageType(schemaStr) + val writer = createParquetWriter(schema, path, dictionaryEnabled) + + val factory = new SimpleGroupFactory(schema) + expected.foreach { values => + val group = factory.newGroup() + if (values != null) { + val list = group.addGroup(0) + values.foreach { value => + list.addGroup(0).append("element", value) + } + } + writer.write(group) + } + writer.close() + } + + // write the following into the Parquet file: + // 0: [ "a", "b" ] + // 1: null + // 2: [ "c", "d" ] + // 3: [ ] + // 4: [ "e", "f" ] + withTempDir { dir => + val path = new Path(dir.toURI.toString, "part-r-0.parquet") + val expected = Seq(Seq("a", "b"), null, Seq("c", "d"), Seq(), Seq("e", "f")) + makeRawParquetFile(path, expected) + readParquetFile(path.toString) { df => checkAnswer(df, expected.map(Row(_))) } + } + } + } + + test("vectorized reader: required array with optional elements") { + Seq(true, false).foreach { dictionaryEnabled => + def makeRawParquetFile(path: Path, expected: Seq[Seq[String]]): Unit = { + val schemaStr = + """message spark_schema { + | required group _1 (LIST) { + | repeated group list { + | optional binary element (UTF8); + | } + | } + |} + """.stripMargin + val schema = MessageTypeParser.parseMessageType(schemaStr) + val writer = createParquetWriter(schema, path, dictionaryEnabled) + + val factory = new SimpleGroupFactory(schema) + expected.foreach { values => + val group = factory.newGroup() + if (values != null) { + val list = group.addGroup(0) + values.foreach { value => + val group = list.addGroup(0) + if (value != null) group.append("element", value) + } + } + writer.write(group) + } + writer.close() + } + + // write the following into the Parquet file: + // 0: [ "a", null ] + // 3: [ ] + // 4: [ null, "b" ] + withTempDir { dir => + val path = new Path(dir.toURI.toString, "part-r-0.parquet") + val expected = Seq(Seq("a", null), Seq(), Seq(null, "b")) + makeRawParquetFile(path, expected) + readParquetFile(path.toString) { df => checkAnswer(df, expected.map(Row(_))) } + } + } + } + + test("vectorized reader: required array with legacy format") { + Seq(true, false).foreach { dictionaryEnabled => + def makeRawParquetFile(path: Path, expected: Seq[Seq[String]]): Unit = { + val schemaStr = + """message spark_schema { + | repeated binary element (UTF8); + |} + """.stripMargin + val schema = MessageTypeParser.parseMessageType(schemaStr) + val writer = createParquetWriter(schema, path, dictionaryEnabled) + + val factory = new SimpleGroupFactory(schema) + expected.foreach { values => + val group = factory.newGroup() + values.foreach(group.append("element", _)) + writer.write(group) + } + writer.close() + } + + // write the following into the Parquet file: + // 0: [ 
"a", "b" ] + // 3: [ ] + // 4: [ "c", "d" ] + withTempDir { dir => + val path = new Path(dir.toURI.toString, "part-r-0.parquet") + val expected = Seq(Seq("a", "b"), Seq(), Seq("c", "d")) + makeRawParquetFile(path, expected) + readParquetFile(path.toString) { df => checkAnswer(df, expected.map(Row(_))) } + } + } + } + + test("vectorized reader: struct") { + val data = Seq( + Tuple1(null), + Tuple1((1, "a")), + Tuple1((2, null)), + Tuple1((3, "b")), + Tuple1(null) + ) + + withParquetFile(data) { file => + readParquetFile(file) { df => + checkAnswer(df.sort("_1"), + Row(null) :: Row(null) :: Row(Row(1, "a")) :: Row(Row(2, null)) :: Row(Row(3, "b")) :: Nil + ) + } + } + } + + test("vectorized reader: missing all struct fields") { + withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_NESTED_COLUMN_ENABLED.key -> "true") { + val data = Seq( + Tuple1((1, "a")), + Tuple1((2, null)), + Tuple1(null) + ) + + val readSchema = new StructType().add("_1", + new StructType() + .add("_3", IntegerType, nullable = true) + .add("_4", LongType, nullable = true), + nullable = true) + + withParquetFile(data) { file => + checkAnswer(spark.read.schema(readSchema).parquet(file), + Row(null) :: Row(null) :: Row(null) :: Nil + ) + } + } + } + + test("vectorized reader: missing some struct fields") { + withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_NESTED_COLUMN_ENABLED.key -> "true") { + val data = Seq( + Tuple1((1, "a")), + Tuple1((2, null)), + Tuple1(null) + ) + + val readSchema = new StructType().add("_1", + new StructType() + .add("_1", IntegerType, nullable = true) + .add("_3", LongType, nullable = true), + nullable = true) + + withParquetFile(data) { file => + checkAnswer(spark.read.schema(readSchema).parquet(file), + Row(null) :: Row(Row(1, null)) :: Row(Row(2, null)) :: Nil + ) + } + } + } + test("SPARK-34817: Support for unsigned Parquet logical types") { val parquetSchema = MessageTypeParser.parseMessageType( """message root { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaPruningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaPruningSuite.scala index cab93bd96fff..6a93b72472c7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaPruningSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaPruningSuite.scala @@ -31,6 +31,8 @@ abstract class ParquetSchemaPruningSuite extends SchemaPruningSuite with Adaptiv override protected val dataSourceName: String = "parquet" override protected val vectorizedReaderEnabledKey: String = SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key + override protected val vectorizedReaderNestedEnabledKey: String = + SQLConf.ORC_VECTORIZED_READER_NESTED_COLUMN_ENABLED.key } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTest.scala index 7a4a382f7f5c..16eb854cf250 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTest.scala @@ -49,6 +49,8 @@ private[sql] trait ParquetTest extends FileBasedDataSourceTest { override protected val dataSourceName: String = "parquet" override protected val vectorizedReaderEnabledKey: String = SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key + override protected val 
vectorizedReaderNestedEnabledKey: String = + SQLConf.PARQUET_VECTORIZED_READER_NESTED_COLUMN_ENABLED.key /** * Reads the parquet file at `path` diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetVectorizedSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetVectorizedSuite.scala new file mode 100644 index 000000000000..832df35b5b1e --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetVectorizedSuite.scala @@ -0,0 +1,751 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources.parquet + +import java.util.{Optional, PrimitiveIterator} + +import scala.collection.mutable.ArrayBuffer +import scala.language.implicitConversions + +import org.apache.parquet.column.{ColumnDescriptor, ParquetProperties} +import org.apache.parquet.column.impl.ColumnWriteStoreV1 +import org.apache.parquet.column.page._ +import org.apache.parquet.column.page.mem.MemPageStore +import org.apache.parquet.io.ParquetDecodingException +import org.apache.parquet.io.api.Binary +import org.apache.parquet.schema.{MessageType, MessageTypeParser} +import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName + +import org.apache.spark.memory.MemoryMode +import org.apache.spark.sql.{QueryTest, Row} +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.RowOrdering +import org.apache.spark.sql.catalyst.util.DateTimeUtils +import org.apache.spark.sql.execution.datasources.parquet.SpecificParquetRecordReaderBase.ParquetRowGroupReader +import org.apache.spark.sql.execution.vectorized.ColumnVectorUtils +import org.apache.spark.sql.test.SharedSparkSession +import org.apache.spark.sql.types._ + +class ParquetVectorizedSuite extends QueryTest with ParquetTest with SharedSparkSession { + private val VALUES: Seq[String] = ('a' to 'z').map(_.toString) + private val NUM_VALUES: Int = VALUES.length + private val BATCH_SIZE_CONFIGS: Seq[Int] = Seq(1, 3, 5, 7, 10, 20, 40) + private val PAGE_SIZE_CONFIGS: Seq[Seq[Int]] = Seq(Seq(6, 6, 7, 7), Seq(4, 9, 4, 9)) + + implicit def toStrings(ints: Seq[Int]): Seq[String] = ints.map(i => ('a' + i).toChar.toString) + + test("primitive type - no column index") { + BATCH_SIZE_CONFIGS.foreach { batchSize => + PAGE_SIZE_CONFIGS.foreach { pageSizes => + Seq(true, false).foreach { dictionaryEnabled => + testPrimitiveString(None, None, pageSizes, VALUES, batchSize, + dictionaryEnabled = dictionaryEnabled) + } + } + } + } + + test("primitive type - column index with ranges") { + BATCH_SIZE_CONFIGS.foreach { batchSize => + PAGE_SIZE_CONFIGS.foreach { pageSizes => + Seq(true, false).foreach { dictionaryEnabled => + var ranges = Seq((0L, 
9L)) + testPrimitiveString(None, Some(ranges), pageSizes, 0 to 9, batchSize, + dictionaryEnabled = dictionaryEnabled) + + ranges = Seq((30, 50)) + testPrimitiveString(None, Some(ranges), pageSizes, Seq.empty, batchSize, + dictionaryEnabled = dictionaryEnabled) + + ranges = Seq((15, 25)) + testPrimitiveString(None, Some(ranges), pageSizes, 15 to 19, batchSize, + dictionaryEnabled = dictionaryEnabled) + + ranges = Seq((19, 20)) + testPrimitiveString(None, Some(ranges), pageSizes, 19 to 20, batchSize, + dictionaryEnabled = dictionaryEnabled) + + ranges = Seq((0, 3), (5, 7), (15, 18)) + testPrimitiveString(None, Some(ranges), pageSizes, + toStrings(Seq(0, 1, 2, 3, 5, 6, 7, 15, 16, 17, 18)), + batchSize, dictionaryEnabled = dictionaryEnabled) + } + } + } + } + + test("primitive type - column index with ranges and nulls") { + BATCH_SIZE_CONFIGS.foreach { batchSize => + PAGE_SIZE_CONFIGS.foreach { pageSizes => + Seq(true, false).foreach { dictionaryEnabled => + val valuesWithNulls = VALUES.zipWithIndex.map { + case (v, i) => if (i % 2 == 0) null else v + } + testPrimitiveString(None, None, pageSizes, valuesWithNulls, batchSize, valuesWithNulls, + dictionaryEnabled) + + val ranges = Seq((5L, 7L)) + testPrimitiveString(None, Some(ranges), pageSizes, Seq("f", null, "h"), + batchSize, valuesWithNulls, dictionaryEnabled) + } + } + } + } + + test("primitive type - column index with ranges and first row indexes") { + BATCH_SIZE_CONFIGS.foreach { batchSize => + Seq(true, false).foreach { dictionaryEnabled => + // Single page + val firstRowIndex = 10 + var ranges = Seq((0L, 9L)) + testPrimitiveString(Some(Seq(firstRowIndex)), Some(ranges), Seq(VALUES.length), + Seq.empty, batchSize, dictionaryEnabled = dictionaryEnabled) + + ranges = Seq((15, 25)) + testPrimitiveString(Some(Seq(firstRowIndex)), Some(ranges), Seq(VALUES.length), + 5 to 15, batchSize, dictionaryEnabled = dictionaryEnabled) + + ranges = Seq((15, 35)) + testPrimitiveString(Some(Seq(firstRowIndex)), Some(ranges), Seq(VALUES.length), + 5 to 19, batchSize, dictionaryEnabled = dictionaryEnabled) + + ranges = Seq((15, 39)) + testPrimitiveString(Some(Seq(firstRowIndex)), Some(ranges), Seq(VALUES.length), + 5 to 19, batchSize, dictionaryEnabled = dictionaryEnabled) + + // Row indexes: [ [10, 16), [20, 26), [30, 37), [40, 47) ] + // Values: [ [0, 6), [6, 12), [12, 19), [19, 26) ] + var pageSizes = Seq(6, 6, 7, 7) + var firstRowIndexes = Seq(10L, 20, 30, 40) + + ranges = Seq((0L, 9L)) + testPrimitiveString(Some(firstRowIndexes), Some(ranges), pageSizes, + Seq.empty, batchSize, dictionaryEnabled = dictionaryEnabled) + + ranges = Seq((15, 25)) + testPrimitiveString(Some(firstRowIndexes), Some(ranges), pageSizes, + 5 to 9, batchSize, dictionaryEnabled = dictionaryEnabled) + + ranges = Seq((15, 35)) + testPrimitiveString(Some(firstRowIndexes), Some(ranges), pageSizes, + 5 to 14, batchSize, dictionaryEnabled = dictionaryEnabled) + + ranges = Seq((15, 60)) + testPrimitiveString(Some(firstRowIndexes), Some(ranges), pageSizes, + 5 to 19, batchSize, dictionaryEnabled = dictionaryEnabled) + + ranges = Seq((12, 22), (28, 38)) + testPrimitiveString(Some(firstRowIndexes), Some(ranges), pageSizes, + toStrings(Seq(2, 3, 4, 5, 6, 7, 8, 12, 13, 14, 15, 16, 17, 18)), batchSize, + dictionaryEnabled = dictionaryEnabled) + + // Row indexes: [ [10, 11), [40, 52), [100, 112), [200, 201) ] + // Values: [ [0, 1), [1, 13), [13, 25), [25, 26] ] + pageSizes = Seq(1, 12, 12, 1) + firstRowIndexes = Seq(10L, 40, 100, 200) + ranges = Seq((0L, 9L)) + 
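+ // The requested range [0, 9] ends before the first page's first row index (10),
+ // so none of the row-index intervals listed above overlap it, no rows are
+ // selected, and the call below expects an empty result.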
testPrimitiveString(Some(firstRowIndexes), Some(ranges), pageSizes, + Seq.empty, batchSize, dictionaryEnabled = dictionaryEnabled) + + ranges = Seq((300, 350)) + testPrimitiveString(Some(firstRowIndexes), Some(ranges), pageSizes, + Seq.empty, batchSize, dictionaryEnabled = dictionaryEnabled) + + ranges = Seq((50, 80)) + testPrimitiveString(Some(firstRowIndexes), Some(ranges), pageSizes, + (11 to 12), batchSize, dictionaryEnabled = dictionaryEnabled) + + ranges = Seq((0, 150)) + testPrimitiveString(Some(firstRowIndexes), Some(ranges), pageSizes, + 0 to 24, batchSize, dictionaryEnabled = dictionaryEnabled) + + // with nulls + val valuesWithNulls = VALUES.zipWithIndex.map { + case (v, i) => if (i % 2 == 0) null else v + } + ranges = Seq((20, 45)) // select values in [1, 5] + testPrimitiveString(Some(firstRowIndexes), Some(ranges), pageSizes, + Seq("b", null, "d", null, "f"), batchSize, valuesWithNulls, + dictionaryEnabled = dictionaryEnabled) + + ranges = Seq((8, 12), (80, 104)) + testPrimitiveString(Some(firstRowIndexes), Some(ranges), pageSizes, + Seq(null, "n", null, "p", null, "r"), batchSize, valuesWithNulls, + dictionaryEnabled = dictionaryEnabled) + } + } + } + + test("nested type - single page, no column index") { + (1 to 4).foreach { batchSize => + Seq(true, false).foreach { dictionaryEnabled => + testNestedStringArrayOneLevel(None, None, Seq(4), + Seq(Seq("a", "b", "c", "d")), + Seq(0, 1, 1, 1), Seq(3, 3, 3, 3), Seq("a", "b", "c", "d"), batchSize, + dictionaryEnabled = dictionaryEnabled) + + testNestedStringArrayOneLevel(None, None, Seq(4), + Seq(Seq("a", "b"), Seq("c", "d")), + Seq(0, 1, 0, 1), Seq(3, 3, 3, 3), Seq("a", "b", "c", "d"), batchSize, + dictionaryEnabled = dictionaryEnabled) + + testNestedStringArrayOneLevel(None, None, Seq(4), + Seq(Seq("a"), Seq("b"), Seq("c"), Seq("d")), + Seq(0, 0, 0, 0), Seq(3, 3, 3, 3), Seq("a", "b", "c", "d"), batchSize, + dictionaryEnabled = dictionaryEnabled) + + testNestedStringArrayOneLevel(None, None, Seq(4), + Seq(Seq("a"), Seq(null), Seq("c"), Seq(null)), + Seq(0, 0, 0, 0), Seq(3, 2, 3, 2), Seq("a", null, "c", null), batchSize, + dictionaryEnabled = dictionaryEnabled) + + testNestedStringArrayOneLevel(None, None, Seq(4), + Seq(Seq("a"), Seq(null, null, null)), + Seq(0, 0, 1, 1), Seq(3, 2, 2, 2), Seq("a", null, null, null), batchSize, + dictionaryEnabled = dictionaryEnabled) + + testNestedStringArrayOneLevel(None, None, Seq(6), + Seq(Seq("a"), Seq(null, null, null), null, Seq()), + Seq(0, 0, 1, 1, 0, 0), Seq(3, 2, 2, 2, 0, 1), Seq("a", null, null, null, null, null), + batchSize, dictionaryEnabled = dictionaryEnabled) + + testNestedStringArrayOneLevel(None, None, Seq(8), + Seq(Seq("a"), Seq(), Seq(), null, Seq("b", null, "c"), null), + Seq(0, 0, 0, 0, 0, 1, 1, 0), Seq(3, 1, 1, 0, 3, 2, 3, 0), + Seq("a", null, null, null, "b", null, "c", null), batchSize, + dictionaryEnabled = dictionaryEnabled) + } + } + } + + test("nested type - multiple page, no column index") { + BATCH_SIZE_CONFIGS.foreach { batchSize => + Seq(Seq(2, 3, 2, 3)).foreach { pageSizes => + Seq(true, false).foreach { dictionaryEnabled => + testNestedStringArrayOneLevel(None, None, pageSizes, + Seq(Seq("a"), Seq(), Seq("b", null, "c"), Seq("d", "e"), Seq(null), Seq(), null), + Seq(0, 0, 0, 1, 1, 0, 1, 0, 0, 0), Seq(3, 1, 3, 2, 3, 3, 3, 2, 1, 0), + Seq("a", null, "b", null, "c", "d", "e", null, null, null), batchSize, + dictionaryEnabled = dictionaryEnabled) + } + } + } + } + + test("nested type - multiple page, no column index, batch span multiple pages") { + (1 to 6).foreach { 
batchSize => + Seq(true, false).foreach { dictionaryEnabled => + // a list across multiple pages + testNestedStringArrayOneLevel(None, None, Seq(1, 5), + Seq(Seq("a"), Seq("b", "c", "d", "e", "f")), + Seq(0, 0, 1, 1, 1, 1), Seq.fill(6)(3), Seq("a", "b", "c", "d", "e", "f"), batchSize, + dictionaryEnabled = dictionaryEnabled) + + testNestedStringArrayOneLevel(None, None, Seq(1, 3, 2), + Seq(Seq("a"), Seq("b", "c", "d"), Seq("e", "f")), + Seq(0, 0, 1, 1, 0, 1), Seq.fill(6)(3), Seq("a", "b", "c", "d", "e", "f"), batchSize, + dictionaryEnabled = dictionaryEnabled) + + testNestedStringArrayOneLevel(None, None, Seq(2, 2, 2), + Seq(Seq("a", "b"), Seq("c", "d"), Seq("e", "f")), + Seq(0, 1, 0, 1, 0, 1), Seq.fill(6)(3), Seq("a", "b", "c", "d", "e", "f"), batchSize, + dictionaryEnabled = dictionaryEnabled) + } + } + } + + test("nested type - RLE encoding") { + (1 to 8).foreach { batchSize => + Seq(Seq(26), Seq(4, 3, 11, 4, 4), Seq(18, 8)).foreach { pageSizes => + Seq(true, false).foreach { dictionaryEnabled => + testNestedStringArrayOneLevel(None, None, pageSizes, + (0 to 6).map(i => Seq(('a' + i).toChar.toString)) ++ + Seq((7 to 17).map(i => ('a' + i).toChar.toString)) ++ + (18 to 25).map(i => Seq(('a' + i).toChar.toString)), + Seq.fill(8)(0) ++ Seq.fill(10)(1) ++ Seq.fill(8)(0), Seq.fill(26)(3), + batchSize = batchSize, dictionaryEnabled = dictionaryEnabled) + } + } + } + } + + test("nested type - column index with ranges") { + (1 to 8).foreach { batchSize => + Seq(Seq(8), Seq(6, 2), Seq(1, 5, 2)).foreach { pageSizes => + Seq(true, false).foreach { dictionaryEnabled => + var ranges = Seq((1L, 2L)) + testNestedStringArrayOneLevel(None, Some(ranges), pageSizes, + Seq(Seq("b", "c", "d", "e", "f"), Seq("g", "h")), + Seq(0, 0, 1, 1, 1, 1, 0, 1), Seq.fill(8)(3), + Seq("a", "b", "c", "d", "e", "f", "g", "h"), + batchSize, dictionaryEnabled = dictionaryEnabled) + + ranges = Seq((3L, 5L)) + testNestedStringArrayOneLevel(None, Some(ranges), pageSizes, + Seq(), + Seq(0, 0, 1, 1, 1, 1, 0, 1), Seq.fill(8)(3), + Seq("a", "b", "c", "d", "e", "f", "g", "h"), + batchSize, dictionaryEnabled = dictionaryEnabled) + + ranges = Seq((0L, 0L)) + testNestedStringArrayOneLevel(None, Some(ranges), pageSizes, + Seq(Seq("a")), + Seq(0, 0, 1, 1, 1, 1, 0, 1), Seq.fill(8)(3), + Seq("a", "b", "c", "d", "e", "f", "g", "h"), + batchSize, dictionaryEnabled = dictionaryEnabled) + } + } + } + } + + test("nested type - column index with ranges and RLE encoding") { + BATCH_SIZE_CONFIGS.foreach { batchSize => + Seq(Seq(26), Seq(4, 3, 11, 4, 4), Seq(18, 8)).foreach { pageSizes => + Seq(true, false).foreach { dictionaryEnabled => + var ranges = Seq((0L, 2L)) + testNestedStringArrayOneLevel(None, Some(ranges), pageSizes, + Seq(Seq("a"), Seq("b"), Seq("c")), + Seq.fill(8)(0) ++ Seq.fill(10)(1) ++ Seq.fill(8)(0), Seq.fill(26)(3), + batchSize = batchSize, dictionaryEnabled = dictionaryEnabled) + + ranges = Seq((4L, 6L)) + testNestedStringArrayOneLevel(None, Some(ranges), pageSizes, + Seq(Seq("e"), Seq("f"), Seq("g")), + Seq.fill(8)(0) ++ Seq.fill(10)(1) ++ Seq.fill(8)(0), Seq.fill(26)(3), + batchSize = batchSize, dictionaryEnabled = dictionaryEnabled) + + ranges = Seq((6L, 9L)) + testNestedStringArrayOneLevel(None, Some(ranges), pageSizes, + Seq(Seq("g")) ++ Seq((7 to 17).map(i => ('a' + i).toChar.toString)) ++ + Seq(Seq("s"), Seq("t")), + Seq.fill(8)(0) ++ Seq.fill(10)(1) ++ Seq.fill(8)(0), Seq.fill(26)(3), + batchSize = batchSize, dictionaryEnabled = dictionaryEnabled) + + ranges = Seq((4L, 6L), (14L, 20L)) + 
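+ // Row layout used throughout this test: rows 0-6 are the single-element lists
+ // ["a"] .. ["g"], row 7 is the eleven-element list ["h" .. "r"], and rows 8-15
+ // are the single-element lists ["s"] .. ["z"]. The ranges (4, 6) and (14, 20)
+ // therefore select rows 4-6 and 14-15, i.e. the lists for "e", "f", "g", "y"
+ // and "z" asserted below.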
testNestedStringArrayOneLevel(None, Some(ranges), pageSizes, + Seq(Seq("e"), Seq("f"), Seq("g"), Seq("y"), Seq("z")), + Seq.fill(8)(0) ++ Seq.fill(10)(1) ++ Seq.fill(8)(0), Seq.fill(26)(3), + batchSize = batchSize, dictionaryEnabled = dictionaryEnabled) + } + } + } + } + + test("nested type - column index with ranges and nulls") { + BATCH_SIZE_CONFIGS.foreach { batchSize => + Seq(Seq(16), Seq(8, 8), Seq(4, 4, 4, 4), Seq(2, 6, 4, 4)).foreach { pageSizes => + Seq(true, false).foreach { dictionaryEnabled => + testNestedStringArrayOneLevel(None, None, pageSizes, + Seq(Seq("a", null), Seq("c", "d"), Seq(), Seq("f", null, "h"), + Seq("i", "j", "k", null), Seq(), null, null, Seq()), + Seq(0, 1, 0, 1, 0, 0, 1, 1, 0, 1, 1, 1, 0, 0, 0, 0), + Seq(3, 2, 3, 3, 1, 3, 2, 3, 3, 3, 3, 2, 1, 0, 0, 1), + (0 to 15), + batchSize = batchSize, dictionaryEnabled = dictionaryEnabled) + + var ranges = Seq((0L, 15L)) + testNestedStringArrayOneLevel(None, Some(ranges), pageSizes, + Seq(Seq("a", null), Seq("c", "d"), Seq(), Seq("f", null, "h"), + Seq("i", "j", "k", null), Seq(), null, null, Seq()), + Seq(0, 1, 0, 1, 0, 0, 1, 1, 0, 1, 1, 1, 0, 0, 0, 0), + Seq(3, 2, 3, 3, 1, 3, 2, 3, 3, 3, 3, 2, 1, 0, 0, 1), + (0 to 15), + batchSize = batchSize, dictionaryEnabled = dictionaryEnabled) + + ranges = Seq((0L, 2L)) + testNestedStringArrayOneLevel(None, Some(ranges), pageSizes, + Seq(Seq("a", null), Seq("c", "d"), Seq()), + Seq(0, 1, 0, 1, 0, 0, 1, 1, 0, 1, 1, 1, 0, 0, 0, 0), + Seq(3, 2, 3, 3, 1, 3, 2, 3, 3, 3, 3, 2, 1, 0, 0, 1), + (0 to 15), + batchSize = batchSize, dictionaryEnabled = dictionaryEnabled) + + ranges = Seq((3L, 7L)) + testNestedStringArrayOneLevel(None, Some(ranges), pageSizes, + Seq(Seq("f", null, "h"), Seq("i", "j", "k", null), Seq(), null, null), + Seq(0, 1, 0, 1, 0, 0, 1, 1, 0, 1, 1, 1, 0, 0, 0, 0), + Seq(3, 2, 3, 3, 1, 3, 2, 3, 3, 3, 3, 2, 1, 0, 0, 1), + (0 to 15), + batchSize = batchSize, dictionaryEnabled = dictionaryEnabled) + + ranges = Seq((5, 12L)) + testNestedStringArrayOneLevel(None, Some(ranges), pageSizes, + Seq(Seq(), null, null, Seq()), + Seq(0, 1, 0, 1, 0, 0, 1, 1, 0, 1, 1, 1, 0, 0, 0, 0), + Seq(3, 2, 3, 3, 1, 3, 2, 3, 3, 3, 3, 2, 1, 0, 0, 1), + (0 to 15), + batchSize = batchSize, dictionaryEnabled = dictionaryEnabled) + + ranges = Seq((5, 12L)) + testNestedStringArrayOneLevel(None, Some(ranges), pageSizes, + Seq(Seq(), null, null, Seq()), + Seq(0, 1, 0, 1, 0, 0, 1, 1, 0, 1, 1, 1, 0, 0, 0, 0), + Seq(3, 2, 3, 3, 1, 3, 2, 3, 3, 3, 3, 2, 1, 0, 0, 1), + (0 to 15), + batchSize = batchSize, dictionaryEnabled = dictionaryEnabled) + + ranges = Seq((0L, 0L), (2, 3), (5, 7), (8, 10)) + testNestedStringArrayOneLevel(None, Some(ranges), pageSizes, + Seq(Seq("a", null), Seq(), Seq("f", null, "h"), Seq(), null, null, Seq()), + Seq(0, 1, 0, 1, 0, 0, 1, 1, 0, 1, 1, 1, 0, 0, 0, 0), + Seq(3, 2, 3, 3, 1, 3, 2, 3, 3, 3, 3, 2, 1, 0, 0, 1), + (0 to 15), + batchSize = batchSize, dictionaryEnabled = dictionaryEnabled) + } + } + } + } + + test("nested type - column index with ranges, nulls and first row indexes") { + BATCH_SIZE_CONFIGS.foreach { batchSize => + Seq(true, false).foreach { dictionaryEnabled => + val pageSizes = Seq(4, 4, 4, 4) + var firstRowIndexes = Seq(10L, 20, 30, 40) + var ranges = Seq((0L, 5L)) + testNestedStringArrayOneLevel(Some(firstRowIndexes), Some(ranges), pageSizes, + Seq(), + Seq(0, 1, 0, 1, 0, 0, 1, 1, 0, 1, 1, 1, 0, 0, 0, 0), + Seq(3, 2, 3, 3, 1, 3, 2, 3, 3, 3, 3, 2, 1, 0, 0, 1), + (0 to 15), + batchSize = batchSize, dictionaryEnabled = dictionaryEnabled) + + ranges = Seq((5L, 15)) + 
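+ // With four levels per page and first row indexes [10, 20, 30, 40], the first
+ // page holds the two rows at row indexes 10 and 11, i.e. [ "a", null ] and
+ // [ "c", "d" ]. The range (5, 15) overlaps only those row indexes, so the call
+ // below expects exactly those two lists.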
testNestedStringArrayOneLevel(Some(firstRowIndexes), Some(ranges), pageSizes, + Seq(Seq("a", null), Seq("c", "d")), + Seq(0, 1, 0, 1, 0, 0, 1, 1, 0, 1, 1, 1, 0, 0, 0, 0), + Seq(3, 2, 3, 3, 1, 3, 2, 3, 3, 3, 3, 2, 1, 0, 0, 1), + (0 to 15), + batchSize = batchSize, dictionaryEnabled = dictionaryEnabled) + + ranges = Seq((25, 28)) + testNestedStringArrayOneLevel(Some(firstRowIndexes), Some(ranges), pageSizes, + Seq(), + Seq(0, 1, 0, 1, 0, 0, 1, 1, 0, 1, 1, 1, 0, 0, 0, 0), + Seq(3, 2, 3, 3, 1, 3, 2, 3, 3, 3, 3, 2, 1, 0, 0, 1), + (0 to 15), + batchSize = batchSize, dictionaryEnabled = dictionaryEnabled) + + ranges = Seq((35, 45)) + testNestedStringArrayOneLevel(Some(firstRowIndexes), Some(ranges), pageSizes, + Seq(Seq(), null, null, Seq()), + Seq(0, 1, 0, 1, 0, 0, 1, 1, 0, 1, 1, 1, 0, 0, 0, 0), + Seq(3, 2, 3, 3, 1, 3, 2, 3, 3, 3, 3, 2, 1, 0, 0, 1), + (0 to 15), + batchSize = batchSize, dictionaryEnabled = dictionaryEnabled) + + ranges = Seq((45, 55)) + testNestedStringArrayOneLevel(Some(firstRowIndexes), Some(ranges), pageSizes, + Seq(), + Seq(0, 1, 0, 1, 0, 0, 1, 1, 0, 1, 1, 1, 0, 0, 0, 0), + Seq(3, 2, 3, 3, 1, 3, 2, 3, 3, 3, 3, 2, 1, 0, 0, 1), + (0 to 15), + batchSize = batchSize, dictionaryEnabled = dictionaryEnabled) + + ranges = Seq((45, 55)) + testNestedStringArrayOneLevel(Some(firstRowIndexes), Some(ranges), pageSizes, + Seq(), + Seq(0, 1, 0, 1, 0, 0, 1, 1, 0, 1, 1, 1, 0, 0, 0, 0), + Seq(3, 2, 3, 3, 1, 3, 2, 3, 3, 3, 3, 2, 1, 0, 0, 1), + (0 to 15), + batchSize = batchSize, dictionaryEnabled = dictionaryEnabled) + + ranges = Seq((15, 29), (31, 35)) + testNestedStringArrayOneLevel(Some(firstRowIndexes), Some(ranges), pageSizes, + Seq(Seq(), Seq("f", null, "h")), + Seq(0, 1, 0, 1, 0, 0, 1, 1, 0, 1, 1, 1, 0, 0, 0, 0), + Seq(3, 2, 3, 3, 1, 3, 2, 3, 3, 3, 3, 2, 1, 0, 0, 1), + (0 to 15), + batchSize = batchSize, dictionaryEnabled = dictionaryEnabled) + } + } + } + + private def testNestedStringArrayOneLevel( + firstRowIndexesOpt: Option[Seq[Long]], + rangesOpt: Option[Seq[(Long, Long)]], + pageSizes: Seq[Int], + expected: Seq[Seq[String]], + rls: Seq[Int], + dls: Seq[Int], + values: Seq[String] = VALUES, + batchSize: Int, + dictionaryEnabled: Boolean = false): Unit = { + assert(pageSizes.sum == rls.length && rls.length == dls.length) + firstRowIndexesOpt.foreach(a => assert(pageSizes.length == a.length)) + + val parquetSchema = MessageTypeParser.parseMessageType( + s"""message root { + | optional group _1 (LIST) { + | repeated group list { + | optional binary a(UTF8); + | } + | } + |} + |""".stripMargin + ) + + val maxRepLevel = 1 + val maxDefLevel = 3 + val ty = parquetSchema.getType("_1", "list", "a").asPrimitiveType() + val cd = new ColumnDescriptor(Seq("_1", "list", "a").toArray, ty, maxRepLevel, maxDefLevel) + + var i = 0 + var numRows = 0 + val memPageStore = new MemPageStore(expected.length) + val pageFirstRowIndexes = ArrayBuffer.empty[Long] + pageSizes.foreach { size => + pageFirstRowIndexes += numRows + numRows += rls.slice(i, i + size).count(_ == 0) + writeDataPage(cd, memPageStore, rls.slice(i, i + size), dls.slice(i, i + size), + values.slice(i, i + size), maxDefLevel, dictionaryEnabled) + i += size + } + + checkAnswer(expected.length, parquetSchema, + TestPageReadStore(memPageStore, firstRowIndexesOpt.getOrElse(pageFirstRowIndexes).toSeq, + rangesOpt), expected.map(i => Row(i)), batchSize) + } + + private def testPrimitiveString( + firstRowIndexesOpt: Option[Seq[Long]], + rangesOpt: Option[Seq[(Long, Long)]], + pageSizes: Seq[Int], + expected: Seq[String], + batchSize: Int, + 
actual: Seq[String] = VALUES, + dictionaryEnabled: Boolean = false): Unit = { + assert(pageSizes.sum == actual.length) + firstRowIndexesOpt.foreach(a => assert(pageSizes.length == a.length)) + + val isRequiredStr = if (!expected.contains(null)) "required" else "optional" + val parquetSchema: MessageType = MessageTypeParser.parseMessageType( + s"""message root { + | $isRequiredStr binary a(UTF8); + |} + |""".stripMargin + ) + val maxDef = if (actual.contains(null)) 1 else 0 + val ty = parquetSchema.asGroupType().getType("a").asPrimitiveType() + val cd = new ColumnDescriptor(Seq("a").toArray, ty, 0, maxDef) + val rls = Array.fill[Int](actual.length)(0) + val dls = actual.map(v => if (v == null) 0 else 1) + + val memPageStore = new MemPageStore(expected.length) + + var i = 0 + val pageFirstRowIndexes = ArrayBuffer.empty[Long] + pageSizes.foreach { size => + pageFirstRowIndexes += i + writeDataPage(cd, memPageStore, rls.slice(i, i + size), dls.slice(i, i + size), + actual.slice(i, i + size), maxDef, dictionaryEnabled) + i += size + } + + checkAnswer(expected.length, parquetSchema, + TestPageReadStore(memPageStore, firstRowIndexesOpt.getOrElse(pageFirstRowIndexes).toSeq, + rangesOpt), expected.map(i => Row(i)), batchSize) + } + + /** + * Write a single data page using repetition levels (`rls`), definition levels (`dls`) and + * values (`values`) provided. + * + * Note that this requires `rls`, `dls` and `values` to have the same number of elements. For + * null values, the corresponding slots in `values` will be skipped. + */ + private def writeDataPage( + cd: ColumnDescriptor, + pageWriteStore: PageWriteStore, + rls: Seq[Int], + dls: Seq[Int], + values: Seq[Any], + maxDl: Int, + dictionaryEnabled: Boolean = false): Unit = { + val columnWriterStore = new ColumnWriteStoreV1(pageWriteStore, + ParquetProperties.builder() + .withPageSize(4096) + .withDictionaryEncoding(dictionaryEnabled) + .build()) + val columnWriter = columnWriterStore.getColumnWriter(cd) + + rls.zip(dls).zipWithIndex.foreach { case ((rl, dl), i) => + if (dl < maxDl) { + columnWriter.writeNull(rl, dl) + } else { + cd.getPrimitiveType.getPrimitiveTypeName match { + case PrimitiveTypeName.INT32 => + columnWriter.write(values(i).asInstanceOf[Int], rl, dl) + case PrimitiveTypeName.INT64 => + columnWriter.write(values(i).asInstanceOf[Long], rl, dl) + case PrimitiveTypeName.BOOLEAN => + columnWriter.write(values(i).asInstanceOf[Boolean], rl, dl) + case PrimitiveTypeName.FLOAT => + columnWriter.write(values(i).asInstanceOf[Float], rl, dl) + case PrimitiveTypeName.DOUBLE => + columnWriter.write(values(i).asInstanceOf[Double], rl, dl) + case PrimitiveTypeName.BINARY => + columnWriter.write(Binary.fromString(values(i).asInstanceOf[String]), rl, dl) + case _ => + throw new IllegalStateException(s"Unexpected type: " + + s"${cd.getPrimitiveType.getPrimitiveTypeName}") + } + } + columnWriterStore.endRecord() + } + columnWriterStore.flush() + } + + private def checkAnswer( + totalRowCount: Int, + fileSchema: MessageType, + readStore: PageReadStore, + expected: Seq[Row], + batchSize: Int = NUM_VALUES): Unit = { + import collection.JavaConverters._ + + val recordReader = new VectorizedParquetRecordReader( + DateTimeUtils.getZoneId("EST"), "CORRECTED", "CORRECTED", true, batchSize) + recordReader.initialize(fileSchema, fileSchema, + TestParquetRowGroupReader(Seq(readStore)), totalRowCount) + + // convert both actual and expected rows into collections + val schema = recordReader.sparkSchema + val expectedRowIt = ColumnVectorUtils.toBatch( + schema, 
MemoryMode.ON_HEAP, expected.iterator.asJava).rowIterator() + + val rowOrdering = RowOrdering.createNaturalAscendingOrdering(schema.map(_.dataType)) + var i = 0 + while (expectedRowIt.hasNext && recordReader.nextKeyValue()) { + val expectedRow = expectedRowIt.next() + val actualRow = recordReader.getCurrentValue.asInstanceOf[InternalRow] + assert(rowOrdering.compare(expectedRow, actualRow) == 0, { + val expectedRowStr = toDebugString(schema, expectedRow) + val actualRowStr = toDebugString(schema, actualRow) + s"at index $i, expected row: $expectedRowStr doesn't match actual row: $actualRowStr" + }) + i += 1 + } + } + + private def toDebugString(schema: StructType, row: InternalRow): String = { + if (row == null) "null" + else { + val fieldStrings = schema.fields.zipWithIndex.map { case (f, i) => + f.dataType match { + case IntegerType => + row.getInt(i).toString + case StringType => + val utf8Str = row.getUTF8String(i) + if (utf8Str == null) "null" + else utf8Str.toString + case ArrayType(_, _) => + val elements = row.getArray(i) + if (elements == null) "null" + else "[" + elements.array.mkString(", ") + "]" + case _ => + throw new IllegalStateException(s"Unsupported data type: ${f.dataType}") + } + } + fieldStrings.mkString(", ") + } + } + + case class TestParquetRowGroupReader(groups: Seq[PageReadStore]) extends ParquetRowGroupReader { + private var index: Int = 0 + + override def readNextRowGroup(): PageReadStore = { + if (index == groups.length) { + null + } else { + val res = groups(index) + index += 1 + res + } + } + + override def close(): Unit = {} + } + + private case class TestPageReadStore( + original: PageReadStore, + firstRowIndexes: Seq[Long], + rangesOpt: Option[Seq[(Long, Long)]] = None) extends PageReadStore { + override def getPageReader(descriptor: ColumnDescriptor): PageReader = { + val originalReader = original.getPageReader(descriptor) + TestPageReader(originalReader, firstRowIndexes) + } + + override def getRowCount: Long = original.getRowCount + + override def getRowIndexes: Optional[PrimitiveIterator.OfLong] = { + rangesOpt.map { ranges => + Optional.of(new PrimitiveIterator.OfLong { + private var currentRangeIdx: Int = 0 + private var currentRowIdx: Long = -1 + + override def nextLong(): Long = { + if (!hasNext) throw new NoSuchElementException("No more element") + val res = currentRowIdx + currentRowIdx += 1 + res + } + + override def hasNext: Boolean = { + while (currentRangeIdx < ranges.length) { + if (currentRowIdx > ranges(currentRangeIdx)._2) { + // we've exhausted the current range - move to the next range + currentRangeIdx += 1 + currentRowIdx = -1 + } else { + if (currentRowIdx == -1) { + currentRowIdx = ranges(currentRangeIdx)._1 + } + return true + } + } + false + } + }) + }.getOrElse(Optional.empty()) + } + } + + private case class TestPageReader( + wrapped: PageReader, + firstRowIndexes: Seq[Long]) extends PageReader { + private var index = 0 + + override def readDictionaryPage(): DictionaryPage = wrapped.readDictionaryPage() + override def getTotalValueCount: Long = wrapped.getTotalValueCount + override def readPage(): DataPage = { + val wrappedPage = try { + wrapped.readPage() + } catch { + case _: ParquetDecodingException => + null + } + if (wrappedPage == null) { + wrappedPage + } else { + val res = new TestDataPage(wrappedPage, firstRowIndexes(index)) + index += 1 + res + } + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnVectorSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnVectorSuite.scala index 43f48abb9734..6721e1ad2584 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnVectorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnVectorSuite.scala @@ -336,8 +336,10 @@ class ColumnVectorSuite extends SparkFunSuite with BeforeAndAfterEach { val c2 = testVector.getChild(1) c1.putInt(0, 123) c2.putDouble(0, 3.45) + testVector.putStruct(0, 0) c1.putInt(1, 456) c2.putDouble(1, 5.67) + testVector.putStruct(1, 1) assert(testVector.getStruct(0).get(0, IntegerType) === 123) assert(testVector.getStruct(0).get(1, DoubleType) === 3.45) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala index 5fe0a2aef8a8..979e5f486da2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala @@ -921,12 +921,14 @@ class ColumnarBatchSuite extends SparkFunSuite { c1.putInt(0, 123) c2.putDouble(0, 3.45) + column.putStruct(0, 0) column.putNull(1) assert(column.getStruct(1) == null) - c1.putInt(2, 456) - c2.putDouble(2, 5.67) + c1.putInt(1, 456) + c2.putDouble(1, 5.67) + column.putStruct(2, 1) val s = column.getStruct(0) assert(s.getInt(0) == 123) @@ -982,6 +984,7 @@ class ColumnarBatchSuite extends SparkFunSuite { (0 until 6).foreach { i => c0.putInt(i, i) c1.putLong(i, i * 10) + data.putStruct(i, i) } // Arrays in column: [(0, 0), (1, 10)], [(1, 10), (2, 20), (3, 30)], // [(4, 40), (5, 50)] @@ -1018,6 +1021,10 @@ class ColumnarBatchSuite extends SparkFunSuite { c1.putArray(1, 2, 1) c1.putArray(2, 3, 3) + column.putStruct(0, 0) + column.putStruct(1, 1) + column.putStruct(2, 2) + assert(column.getStruct(0).getInt(0) === 0) assert(column.getStruct(0).getArray(1).toIntArray() === Array(0, 1)) assert(column.getStruct(1).getInt(0) === 1) @@ -1038,6 +1045,9 @@ class ColumnarBatchSuite extends SparkFunSuite { c0.putInt(0, 0) c0.putInt(1, 1) c0.putInt(2, 2) + column.putStruct(0, 0) + column.putStruct(1, 1) + column.putStruct(2, 2) val c1c0 = c1.getChild(0) val c1c1 = c1.getChild(1) // Structs in c1: (7, 70), (8, 80), (9, 90) @@ -1047,6 +1057,9 @@ class ColumnarBatchSuite extends SparkFunSuite { c1c1.putInt(0, 70) c1c1.putInt(1, 80) c1c1.putInt(2, 90) + c1.putStruct(0, 0) + c1.putStruct(1, 1) + c1.putStruct(2, 2) assert(column.getStruct(0).getInt(0) === 0) assert(column.getStruct(0).getStruct(1, 2).toSeq(subSchema) === Seq(7, 70))
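// A minimal end-to-end sketch of the behavior these suites exercise, assuming a
// running SparkSession named `spark`; the DataFrame shape and the output path are
// illustrative only and not part of this patch. Both flags are set because the
// nested-column flag only takes effect when the vectorized Parquet reader itself
// is enabled.
import org.apache.spark.sql.functions.{array, col, struct}

spark.conf.set("spark.sql.parquet.enableVectorizedReader", "true")
spark.conf.set("spark.sql.parquet.enableNestedColumnVectorizedReader", "true")

// Write a Parquet file containing a struct column and an array column.
val df = spark.range(0, 10)
  .withColumn("s", struct(col("id").as("a"), (col("id") * 2).as("b")))
  .withColumn("arr", array(col("id"), col("id") + 1))
df.write.mode("overwrite").parquet("/tmp/nested_vectorized_demo")

// With both flags on, the scan of the struct and array columns can be handled by
// the vectorized Parquet reader instead of falling back to the row-based path.
spark.read.parquet("/tmp/nested_vectorized_demo").show()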