diff --git a/pom.xml b/pom.xml
index 81a0126539b1..e8164315606a 100644
--- a/pom.xml
+++ b/pom.xml
@@ -2333,6 +2333,27 @@
        <version>${parquet.version}</version>
        <scope>${parquet.deps.scope}</scope>
      </dependency>
+      <dependency>
+        <groupId>org.apache.parquet</groupId>
+        <artifactId>parquet-encoding</artifactId>
+        <version>${parquet.version}</version>
+        <scope>${parquet.test.deps.scope}</scope>
+        <classifier>tests</classifier>
+      </dependency>
+      <dependency>
+        <groupId>org.apache.parquet</groupId>
+        <artifactId>parquet-common</artifactId>
+        <version>${parquet.version}</version>
+        <scope>${parquet.test.deps.scope}</scope>
+        <classifier>tests</classifier>
+      </dependency>
+      <dependency>
+        <groupId>org.apache.parquet</groupId>
+        <artifactId>parquet-column</artifactId>
+        <version>${parquet.version}</version>
+        <scope>${parquet.test.deps.scope}</scope>
+        <classifier>tests</classifier>
+      </dependency>
      <dependency>
        <groupId>org.apache.parquet</groupId>
        <artifactId>parquet-hadoop</artifactId>
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/vectorized/ColumnVector.java b/sql/catalyst/src/main/java/org/apache/spark/sql/vectorized/ColumnVector.java
index 5ef4fba193e0..0033680199a3 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/vectorized/ColumnVector.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/vectorized/ColumnVector.java
@@ -216,7 +216,7 @@ public double[] getDoubles(int rowId, int count) {
* the struct type, and each child vector is responsible to store the data for its corresponding
* struct field.
*/
- public final ColumnarRow getStruct(int rowId) {
+ public ColumnarRow getStruct(int rowId) {
if (isNullAt(rowId)) return null;
return new ColumnarRow(this, rowId);
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index 8ba2b9f8fd2a..9620d1397f30 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -886,6 +886,14 @@ object SQLConf {
.booleanConf
.createWithDefault(true)
+ val PARQUET_VECTORIZED_READER_NESTED_COLUMN_ENABLED =
+ buildConf("spark.sql.parquet.enableNestedColumnVectorizedReader")
+ .doc("Enables vectorized Parquet decoding for nested columns (e.g., struct, list, map). " +
+ s"Note to enable this ${PARQUET_VECTORIZED_READER_ENABLED} also needs to be enabled.")
+ .version("3.3.0")
+ .booleanConf
+ .createWithDefault(true)
+
val PARQUET_RECORD_FILTER_ENABLED = buildConf("spark.sql.parquet.recordLevelFilter.enabled")
.doc("If true, enables Parquet's native record-level filtering using the pushed down " +
"filters. " +
@@ -3612,6 +3620,9 @@ class SQLConf extends Serializable with Logging {
def parquetVectorizedReaderEnabled: Boolean = getConf(PARQUET_VECTORIZED_READER_ENABLED)
+ def parquetVectorizedReaderNestedColumnEnabled: Boolean =
+ getConf(PARQUET_VECTORIZED_READER_NESTED_COLUMN_ENABLED)
+
def parquetVectorizedReaderBatchSize: Int = getConf(PARQUET_VECTORIZED_READER_BATCH_SIZE)
def columnBatchSize: Int = getConf(COLUMN_BATCH_SIZE)
diff --git a/sql/core/pom.xml b/sql/core/pom.xml
index 27260ce67ae7..752dcd902b4d 100644
--- a/sql/core/pom.xml
+++ b/sql/core/pom.xml
@@ -106,6 +106,24 @@
      <groupId>org.apache.parquet</groupId>
      <artifactId>parquet-column</artifactId>
    </dependency>
+    <dependency>
+      <groupId>org.apache.parquet</groupId>
+      <artifactId>parquet-encoding</artifactId>
+      <scope>test</scope>
+      <classifier>tests</classifier>
+    </dependency>
+    <dependency>
+      <groupId>org.apache.parquet</groupId>
+      <artifactId>parquet-common</artifactId>
+      <scope>test</scope>
+      <classifier>tests</classifier>
+    </dependency>
+    <dependency>
+      <groupId>org.apache.parquet</groupId>
+      <artifactId>parquet-column</artifactId>
+      <scope>test</scope>
+      <classifier>tests</classifier>
+    </dependency>
    <dependency>
      <groupId>org.apache.parquet</groupId>
      <artifactId>parquet-hadoop</artifactId>
diff --git a/sql/core/src/main/java/org/apache/parquet/io/ColumnIOUtil.java b/sql/core/src/main/java/org/apache/parquet/io/ColumnIOUtil.java
new file mode 100644
index 000000000000..c1732cc206de
--- /dev/null
+++ b/sql/core/src/main/java/org/apache/parquet/io/ColumnIOUtil.java
@@ -0,0 +1,38 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.parquet.io;
+
+/**
+ * This is a workaround since the methods below are not public in {@link ColumnIO}.
+ * TODO(SPARK-36511): we should remove this once PARQUET-2050 is released with Parquet 1.13.
+ */
+public class ColumnIOUtil {
+ private ColumnIOUtil() {}
+
+ public static int getDefinitionLevel(ColumnIO column) {
+ return column.getDefinitionLevel();
+ }
+
+ public static int getRepetitionLevel(ColumnIO column) {
+ return column.getRepetitionLevel();
+ }
+
+ public static String[] getFieldPath(ColumnIO column) {
+ return column.getFieldPath();
+ }
+}
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnarBatchReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnarBatchReader.java
index 40ed0b2454c1..9df95892f28b 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnarBatchReader.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnarBatchReader.java
@@ -176,8 +176,7 @@ public void initBatch(
// Initialize the missing columns once.
if (colId == -1) {
OnHeapColumnVector missingCol = new OnHeapColumnVector(capacity, dt);
- missingCol.putNulls(0, capacity);
- missingCol.setIsConstant();
+ missingCol.setAllNull();
orcVectorWrappers[i] = missingCol;
} else {
orcVectorWrappers[i] = OrcColumnVectorUtils.toOrcColumnVector(
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/ParquetColumn.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/ParquetColumn.java
new file mode 100644
index 000000000000..c4f6f8778226
--- /dev/null
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/ParquetColumn.java
@@ -0,0 +1,318 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.datasources.parquet;
+
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Set;
+
+import com.google.common.base.Preconditions;
+import org.apache.spark.memory.MemoryMode;
+import org.apache.spark.sql.execution.vectorized.OffHeapColumnVector;
+import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector;
+import org.apache.spark.sql.execution.vectorized.WritableColumnVector;
+import org.apache.spark.sql.types.ArrayType;
+import org.apache.spark.sql.types.DataType;
+import org.apache.spark.sql.types.DataTypes;
+import org.apache.spark.sql.types.MapType;
+import org.apache.spark.sql.types.StructType;
+
+/**
+ * Contains necessary information representing a Parquet column, either of primitive or nested type.
+ */
+final class ParquetColumn {
+ private final ParquetType type;
+  private final List<ParquetColumn> children;
+ private final WritableColumnVector vector;
+
+ /**
+ * Repetition & Definition levels
+   * These are allocated only for leaf columns; non-leaf columns simply keep references to
+   * those of one of their leaf descendants.
+ */
+ private WritableColumnVector repetitionLevels;
+ private WritableColumnVector definitionLevels;
+
+ /** Whether this column is primitive (i.e., leaf column) */
+ private final boolean isPrimitive;
+
+ /** Reader for this column - only set if 'isPrimitive' is true */
+ private VectorizedColumnReader columnReader;
+
+ ParquetColumn(
+ ParquetType type,
+ WritableColumnVector vector,
+ int capacity,
+ MemoryMode memoryMode,
+      Set<ParquetType> missingColumns) {
+ DataType sparkType = type.sparkType();
+ if (!sparkType.sameType(vector.dataType())) {
+ throw new IllegalArgumentException("Spark type: " + type.sparkType() +
+ " doesn't match the type: " + vector.dataType() + " in column vector");
+ }
+
+ this.type = type;
+ this.vector = vector;
+ this.children = new ArrayList<>();
+ this.isPrimitive = type.isPrimitive();
+
+ if (missingColumns.contains(type)) {
+ vector.setAllNull();
+ return;
+ }
+
+ if (isPrimitive) {
+      // TODO: avoid allocating these if not necessary, for instance, when the node is top-level
+      // and not repeated, or when the node is not top-level but its max repetition level is 0.
+ repetitionLevels = allocateLevelsVector(capacity, memoryMode);
+ definitionLevels = allocateLevelsVector(capacity, memoryMode);
+ } else {
+ Preconditions.checkArgument(type.children().size() == vector.getNumChildren());
+ for (int i = 0; i < type.children().size(); i++) {
+ ParquetColumn childColumn = new ParquetColumn(type.children().apply(i),
+ vector.getChild(i), capacity, memoryMode, missingColumns);
+ children.add(childColumn);
+
+ // only use levels from non-missing child
+ if (!childColumn.vector.isAllNull()) {
+ this.repetitionLevels = childColumn.repetitionLevels;
+ this.definitionLevels = childColumn.definitionLevels;
+ }
+ }
+
+ // this can happen if all the fields of a struct are missing, in which case we should mark
+ // the struct itself as a missing column
+ if (repetitionLevels == null) {
+ vector.setAllNull();
+ }
+ }
+ }
+
+ /**
+ * Returns all the children of this column.
+ */
+  List<ParquetColumn> getChildren() {
+ return children;
+ }
+
+ /**
+ * Returns all the leaf columns in depth-first order.
+ */
+  List<ParquetColumn> getLeaves() {
+    List<ParquetColumn> result = new ArrayList<>();
+ getLeavesHelper(this, result);
+ return result;
+ }
+
+ /**
+   * Assembles this column and calculates collection offsets recursively.
+ * This is a no-op for primitive columns.
+ */
+ void assemble() {
+ // nothing to do if the column itself is missing
+ if (vector.isAllNull()) return;
+
+ DataType type = this.type.sparkType();
+ if (type instanceof ArrayType || type instanceof MapType) {
+ for (ParquetColumn child : children) {
+ child.assemble();
+ }
+ calculateCollectionOffsets();
+ } else if (type instanceof StructType) {
+ for (ParquetColumn child : children) {
+ child.assemble();
+ }
+ calculateStructOffsets();
+ }
+ }
+
+ void reset() {
+ // nothing to do if the column itself is missing
+ if (vector.isAllNull()) return;
+
+ vector.reset();
+ repetitionLevels.reset();
+ definitionLevels.reset();
+ for (ParquetColumn childColumn : children) {
+ childColumn.reset();
+ }
+ }
+
+ ParquetType getType() {
+ return this.type;
+ }
+
+ WritableColumnVector getValueVector() {
+ return this.vector;
+ }
+
+ WritableColumnVector getRepetitionLevelVector() {
+ return this.repetitionLevels;
+ }
+
+ WritableColumnVector getDefinitionLevelVector() {
+ return this.definitionLevels;
+ }
+
+ VectorizedColumnReader getColumnReader() {
+ return this.columnReader;
+ }
+
+ void setColumnReader(VectorizedColumnReader reader) {
+ if (!isPrimitive) {
+ throw new IllegalStateException("can't set reader for non-primitive column");
+ }
+ this.columnReader = reader;
+ }
+
+  private static void getLeavesHelper(ParquetColumn column, List<ParquetColumn> coll) {
+ if (column.isPrimitive) {
+ coll.add(column);
+ } else {
+ for (ParquetColumn child : column.children) {
+ getLeavesHelper(child, coll);
+ }
+ }
+ }
+
+ private void calculateCollectionOffsets() {
+ int maxDefinitionLevel = type.definitionLevel();
+ int maxElementRepetitionLevel = type.repetitionLevel();
+
+ // There are 4 cases when calculating definition levels:
+ // 1. definitionLevel == maxDefinitionLevel
+ // ==> value is defined and not null
+ // 2. definitionLevel == maxDefinitionLevel - 1
+ // ==> value is null
+ // 3. definitionLevel < maxDefinitionLevel - 1
+    //      ==> value doesn't exist since one of its optional parents is null
+ // 4. definitionLevel > maxDefinitionLevel
+ // ==> value is a nested element within an array or map
+ //
+ // `i` is the index over all leaf elements of this array, while `offset` is the index over
+ // all top-level elements of this array.
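+    // As an illustration (a sketch, not exhaustive): for a top-level ARRAY<INT> column written
+    // with the standard 3-level list layout and required elements, so that in this method
+    // maxDefinitionLevel is 1 and maxElementRepetitionLevel is 0, the rows [[1, 2], [], null]
+    // arrive with:
+    //   definition levels: [2, 2, 1, 0]
+    //   repetition levels: [0, 1, 0, 0]
+    // The first two entries fall under case 4 (elements of a non-empty list), the third under
+    // case 1 (a list that is defined but empty), and the last under case 2 (a null list).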
+ int rowId = 0;
+ for (int i = 0, offset = 0; i < definitionLevels.getElementsAppended();
+ i = getNextCollectionStart(maxElementRepetitionLevel, i)) {
+ vector.reserve(rowId + 1);
+ int definitionLevel = definitionLevels.getInt(i);
+ if (definitionLevel == maxDefinitionLevel - 1) {
+ // the collection is null
+ vector.putNull(rowId++);
+ } else if (definitionLevel == maxDefinitionLevel) {
+ // collection is defined but empty
+ vector.putNotNull(rowId);
+ vector.putArray(rowId, offset, 0);
+ rowId++;
+ } else if (definitionLevel > maxDefinitionLevel) {
+        // collection is defined and non-empty: find out how many top-level elements there are
+        // until the start of the next array.
+ vector.putNotNull(rowId);
+ int length = getCollectionSize(maxElementRepetitionLevel, i + 1);
+ vector.putArray(rowId, offset, length);
+ offset += length;
+ rowId++;
+ }
+ }
+ vector.addElementsAppended(rowId);
+ }
+
+ private void calculateStructOffsets() {
+ int maxRepetitionLevel = type.repetitionLevel();
+ int maxDefinitionLevel = type.definitionLevel();
+
+ vector.reserve(definitionLevels.getElementsAppended());
+ int rowId = 0;
+ int nonnullRowId = 0;
+ boolean hasRepetitionLevels = repetitionLevels.getElementsAppended() > 0;
+ for (int i = 0; i < definitionLevels.getElementsAppended(); i++) {
+      // if repetition level > maxRepetitionLevel, the value is a nested element (e.g., an array
+      // element in struct<array<int>>), and we should skip this definition level since it
+      // doesn't correspond to the struct itself.
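+      // For example, for a top-level struct containing an array field, a row whose array holds
+      // three elements contributes three definition levels from the leaf, but only the first one
+      // (repetition level 0) creates a slot in the struct vector; the other two belong to the
+      // inner array and are skipped here.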
+ if (!hasRepetitionLevels || repetitionLevels.getInt(i) <= maxRepetitionLevel) {
+ if (definitionLevels.getInt(i) == maxDefinitionLevel - 1) {
+ // the struct is null
+ vector.putNull(rowId);
+ rowId++;
+ } else if (definitionLevels.getInt(i) >= maxDefinitionLevel) {
+ vector.putNotNull(rowId);
+ vector.putStruct(rowId, nonnullRowId);
+ rowId++;
+ nonnullRowId++;
+ }
+ }
+ }
+ vector.addElementsAppended(rowId);
+ }
+
+ private static WritableColumnVector allocateLevelsVector(int capacity, MemoryMode memoryMode) {
+ switch (memoryMode) {
+ case ON_HEAP:
+ return new OnHeapColumnVector(capacity, DataTypes.IntegerType);
+ case OFF_HEAP:
+ return new OffHeapColumnVector(capacity, DataTypes.IntegerType);
+ default:
+ throw new IllegalArgumentException("Unknown memory mode: " + memoryMode);
+ }
+ }
+
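+  /**
+   * Returns the index, strictly after 'elementIndex', of the first level that starts a new
+   * collection, i.e. the next index whose repetition level is <= 'maxRepetitionLevel'. Returns
+   * the total number of appended levels if no such index exists.
+   */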
+ private int getNextCollectionStart(int maxRepetitionLevel, int elementIndex) {
+ int idx = elementIndex + 1;
+ for (; idx < repetitionLevels.getElementsAppended(); idx++) {
+ if (repetitionLevels.getInt(idx) <= maxRepetitionLevel) {
+ break;
+ }
+ }
+ return idx;
+ }
+
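+  /**
+   * Returns the number of elements in the collection whose first element sits just before 'idx',
+   * counting only levels that belong directly to this collection (repetition level ==
+   * 'maxRepetitionLevel' + 1) and stopping once the next collection starts (repetition level <=
+   * 'maxRepetitionLevel').
+   */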
+ private int getCollectionSize(int maxRepetitionLevel, int idx) {
+ int size = 1;
+ for (; idx < repetitionLevels.getElementsAppended(); idx++) {
+ if (repetitionLevels.getInt(idx) <= maxRepetitionLevel) {
+ break;
+ } else if (repetitionLevels.getInt(idx) <= maxRepetitionLevel + 1) {
+ // only count elements which belong to the current collection
+ // For instance, suppose we have the following Parquet schema:
+ //
+ // message schema { max rl max dl
+ // optional group col (LIST) { 0 1
+ // repeated group list { 1 2
+ // optional group element (LIST) { 1 3
+ // repeated group list { 2 4
+ // required int32 element; 2 4
+ // }
+ // }
+ // }
+ // }
+ // }
+ //
+ // For a list such as: [[[0, 1], [2, 3]], [[4, 5], [6, 7]]], the repetition & definition
+ // levels would be:
+ //
+ // repetition levels: [0, 2, 1, 2, 0, 2, 1, 2]
+ // definition levels: [2, 2, 2, 2, 2, 2, 2, 2]
+ //
+ // when calculating collection size for the outer array, we should only count repetition
+ // levels whose value is <= 1 (which is the max repetition level for the inner array)
+ size++;
+ }
+ }
+ return size;
+ }
+}
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/ParquetReadState.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/ParquetReadState.java
index b26088753465..e4f75e5cfa05 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/ParquetReadState.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/ParquetReadState.java
@@ -17,6 +17,8 @@
package org.apache.spark.sql.execution.datasources.parquet;
+import org.apache.parquet.column.ColumnDescriptor;
+
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
@@ -45,21 +47,48 @@ final class ParquetReadState {
/** Maximum definition level for the Parquet column */
final int maxDefinitionLevel;
+ /** Maximum repetition level for the Parquet column */
+ final int maxRepetitionLevel;
+
+ /** Whether this column is required */
+ final boolean isRequired;
+
/** The current index over all rows within the column chunk. This is used to check if the
* current row should be skipped by comparing against the row ranges. */
long rowId;
- /** The offset in the current batch to put the next value */
- int offset;
+ /** The offset to put new values into definition & repetition level vector */
+ int levelOffset;
+
+ /** The offset to put new values into value vector */
+ int valueOffset;
/** The remaining number of values to read in the current page */
int valuesToReadInPage;
- /** The remaining number of values to read in the current batch */
- int valuesToReadInBatch;
+ /** The remaining number of rows to read in the current batch */
+ int rowsToReadInBatch;
+
+ // The following are only used when reading repeated values
+
+ /** When processing repeated values, whether we've found the beginning of the first list after the
+ * current batch. */
+ boolean lastListCompleted;
- ParquetReadState(int maxDefinitionLevel, PrimitiveIterator.OfLong rowIndexes) {
- this.maxDefinitionLevel = maxDefinitionLevel;
+ /** When processing repeated types, the number of accumulated definition levels to process */
+ int numBatchedDefLevels;
+
+ /** When processing repeated types, whether we should skip the current batch of definition
+ * levels. */
+ boolean shouldSkip;
+
+ ParquetReadState(
+ ColumnDescriptor descriptor,
+ boolean isRequired,
+ PrimitiveIterator.OfLong rowIndexes) {
+ this.maxDefinitionLevel = descriptor.getMaxDefinitionLevel();
+ this.maxRepetitionLevel = descriptor.getMaxRepetitionLevel();
+ this.isRequired = isRequired;
this.rowRanges = constructRanges(rowIndexes);
nextRange();
}
@@ -101,8 +130,12 @@ private Iterator<RowRange> constructRanges(PrimitiveIterator.OfLong rowIndexes)
* Must be called at the beginning of reading a new batch.
*/
void resetForNewBatch(int batchSize) {
- this.offset = 0;
- this.valuesToReadInBatch = batchSize;
+ this.levelOffset = 0;
+ this.valueOffset = 0;
+ this.rowsToReadInBatch = batchSize;
+ this.lastListCompleted = this.maxRepetitionLevel == 0;
+ this.numBatchedDefLevels = 0;
+ this.shouldSkip = false;
}
/**
@@ -127,16 +160,6 @@ long currentRangeEnd() {
return currentRange.end;
}
- /**
- * Advance the current offset and rowId to the new values.
- */
- void advanceOffsetAndRowId(int newOffset, long newRowId) {
- valuesToReadInBatch -= (newOffset - offset);
- valuesToReadInPage -= (newRowId - rowId);
- offset = newOffset;
- rowId = newRowId;
- }
-
/**
* Advance to the next range.
*/
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/SpecificParquetRecordReaderBase.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/SpecificParquetRecordReaderBase.java
index 6264d6341c65..59ead1b0f096 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/SpecificParquetRecordReaderBase.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/SpecificParquetRecordReaderBase.java
@@ -18,6 +18,7 @@
package org.apache.spark.sql.execution.datasources.parquet;
+import java.io.Closeable;
import java.io.File;
import java.io.IOException;
import java.lang.reflect.InvocationTargetException;
@@ -29,6 +30,7 @@
import java.util.Map;
import java.util.Set;
+import org.apache.parquet.column.page.PageReadStore;
import scala.Option;
import org.apache.hadoop.conf.Configuration;
@@ -64,9 +66,9 @@
* this way, albeit at a higher cost to implement. This base class is reusable.
*/
public abstract class SpecificParquetRecordReaderBase<T> extends RecordReader<Void, T> {
- protected Path file;
protected MessageType fileSchema;
- protected MessageType requestedSchema;
+ protected MessageType requestedParquetSchema;
+ protected ParquetType requestedSchema;
protected StructType sparkSchema;
/**
@@ -75,31 +77,42 @@ public abstract class SpecificParquetRecordReaderBase<T> extends RecordReader<Void, T>
-    Map<String, String> fileMetadata = reader.getFileMetaData().getKeyValueMetaData();
+ ParquetFileReader fileReader = new ParquetFileReader(
+ HadoopInputFile.fromPath(file, configuration), options);
+ this.reader = new ParquetRowGroupReaderImpl(fileReader);
+ this.fileSchema = fileReader.getFileMetaData().getSchema();
+    Map<String, String> fileMetadata = fileReader.getFileMetaData().getKeyValueMetaData();
    ReadSupport<T> readSupport = getReadSupportInstance(getReadSupportClass(configuration));
ReadSupport.ReadContext readContext = readSupport.init(new InitContext(
taskAttemptContext.getConfiguration(), toSetMultiMap(fileMetadata), fileSchema));
- this.requestedSchema = readContext.getRequestedSchema();
- reader.setRequestedSchema(requestedSchema);
+ this.requestedParquetSchema = readContext.getRequestedSchema();
+ fileReader.setRequestedSchema(requestedParquetSchema);
String sparkRequestedSchemaString =
configuration.get(ParquetReadSupport$.MODULE$.SPARK_ROW_REQUESTED_SCHEMA());
- this.sparkSchema = StructType$.MODULE$.fromString(sparkRequestedSchemaString);
- this.totalRowCount = reader.getFilteredRecordCount();
+ StructType sparkRequestedSchema = StructType$.MODULE$.fromString(sparkRequestedSchemaString);
+ boolean caseSensitive = configuration.getBoolean(SQLConf.CASE_SENSITIVE().key(),
+      Boolean.parseBoolean(SQLConf.CASE_SENSITIVE().defaultValueString()));
+ ParquetToSparkSchemaConverter converter = new ParquetToSparkSchemaConverter(configuration);
+ this.requestedSchema = converter.convertTypeInfo(requestedParquetSchema,
+ Option.apply(sparkRequestedSchema), caseSensitive);
+ this.sparkSchema = (StructType) requestedSchema.sparkType();
+ this.totalRowCount = fileReader.getFilteredRecordCount();
// For test purpose.
// If the last external accumulator is `NumRowGroupsAccumulator`, the row group number to read
@@ -111,7 +124,7 @@ public void initialize(InputSplit inputSplit, TaskAttemptContext taskAttemptCont
if (accu.isDefined() && accu.get().getClass().getSimpleName().equals("NumRowGroupsAcc")) {
@SuppressWarnings("unchecked")
        AccumulatorV2<Integer, Integer> intAccum = (AccumulatorV2<Integer, Integer>) accu.get();
- intAccum.add(reader.getRowGroups().size());
+ intAccum.add(fileReader.getRowGroups().size());
}
}
}
@@ -148,18 +161,19 @@ protected void initialize(String path, List<String> columns) throws IOException
config.setBoolean(SQLConf.PARQUET_BINARY_AS_STRING().key() , false);
config.setBoolean(SQLConf.PARQUET_INT96_AS_TIMESTAMP().key(), false);
- this.file = new Path(path);
- long length = this.file.getFileSystem(config).getFileStatus(this.file).getLen();
+ Path file = new Path(path);
+ long length = file.getFileSystem(config).getFileStatus(file).getLen();
ParquetReadOptions options = HadoopReadOptions
.builder(config)
.withRange(0, length)
.build();
- this.reader = ParquetFileReader.open(HadoopInputFile.fromPath(file, config), options);
- this.fileSchema = reader.getFooter().getFileMetaData().getSchema();
+ ParquetFileReader fileReader = ParquetFileReader.open(
+ HadoopInputFile.fromPath(file, config), options);
+ this.fileSchema = fileReader.getFooter().getFileMetaData().getSchema();
if (columns == null) {
- this.requestedSchema = fileSchema;
+ this.requestedParquetSchema = fileSchema;
} else {
if (columns.size() > 0) {
Types.MessageTypeBuilder builder = Types.buildMessage();
@@ -170,14 +184,35 @@ protected void initialize(String path, List<String> columns) throws IOException
}
builder.addFields(fileSchema.getType(s));
}
- this.requestedSchema = builder.named(ParquetSchemaConverter.SPARK_PARQUET_SCHEMA_NAME());
+ this.requestedParquetSchema =
+ builder.named(ParquetSchemaConverter.SPARK_PARQUET_SCHEMA_NAME());
} else {
- this.requestedSchema = ParquetSchemaConverter.EMPTY_MESSAGE();
+ this.requestedParquetSchema = ParquetSchemaConverter.EMPTY_MESSAGE();
}
}
- reader.setRequestedSchema(requestedSchema);
- this.sparkSchema = new ParquetToSparkSchemaConverter(config).convert(requestedSchema);
- this.totalRowCount = reader.getFilteredRecordCount();
+ fileReader.setRequestedSchema(requestedParquetSchema);
+ this.requestedSchema = new ParquetToSparkSchemaConverter(config)
+ .convertTypeInfo(requestedParquetSchema, Option.empty(), true);
+ this.sparkSchema = (StructType) requestedSchema.sparkType();
+ this.totalRowCount = fileReader.getFilteredRecordCount();
+ this.reader = new ParquetRowGroupReaderImpl(fileReader);
+ }
+
+ protected void initialize(
+ MessageType fileSchema,
+ MessageType requestedSchema,
+ ParquetRowGroupReader rowGroupReader,
+ int totalRowCount) throws IOException {
+ this.reader = rowGroupReader;
+ this.fileSchema = fileSchema;
+ this.requestedParquetSchema = requestedSchema;
+ Configuration config = new Configuration();
+ config.setBoolean(SQLConf.PARQUET_BINARY_AS_STRING().key() , false);
+ config.setBoolean(SQLConf.PARQUET_INT96_AS_TIMESTAMP().key(), false);
+ this.requestedSchema = new ParquetToSparkSchemaConverter(config)
+ .convertTypeInfo(requestedSchema, Option.empty(), true);
+ this.sparkSchema = (StructType) this.requestedSchema.sparkType();
+ this.totalRowCount = totalRowCount;
}
@Override
@@ -222,4 +257,31 @@ private static <T> ReadSupport<T> getReadSupportInstance(
throw new BadConfigurationException("could not instantiate read support class", e);
}
}
+
+ interface ParquetRowGroupReader extends Closeable {
+ /**
+     * Reads the next row group from this reader. Returns null if there are no more row groups.
+ */
+ PageReadStore readNextRowGroup() throws IOException;
+ }
+
+ private static class ParquetRowGroupReaderImpl implements ParquetRowGroupReader {
+ private final ParquetFileReader reader;
+
+ ParquetRowGroupReaderImpl(ParquetFileReader reader) {
+ this.reader = reader;
+ }
+
+ @Override
+ public PageReadStore readNextRowGroup() throws IOException {
+ return reader.readNextFilteredRowGroup();
+ }
+
+ @Override
+ public void close() throws IOException {
+ if (reader != null) {
+ reader.close();
+ }
+ }
+ }
}
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java
index 92dea08102df..f4fc8e616c66 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedColumnReader.java
@@ -19,8 +19,8 @@
import java.io.IOException;
import java.time.ZoneId;
-import java.util.PrimitiveIterator;
+import com.google.common.base.Preconditions;
import org.apache.parquet.bytes.ByteBufferInputStream;
import org.apache.parquet.bytes.BytesInput;
import org.apache.parquet.bytes.BytesUtils;
@@ -38,7 +38,6 @@
import org.apache.spark.sql.execution.vectorized.WritableColumnVector;
import org.apache.spark.sql.types.Decimal;
-import static org.apache.parquet.column.ValuesType.REPETITION_LEVEL;
import static org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName.INT64;
/**
@@ -65,6 +64,11 @@ public class VectorizedColumnReader {
*/
private VectorizedRleValuesReader defColumn;
+ /**
+ * Vectorized RLE decoder for repetition levels
+ */
+ private VectorizedRleValuesReader repColumn;
+
/**
* Factory to get type-specific vector updater.
*/
@@ -88,16 +92,16 @@ public class VectorizedColumnReader {
public VectorizedColumnReader(
ColumnDescriptor descriptor,
- LogicalTypeAnnotation logicalTypeAnnotation,
- PageReader pageReader,
- PrimitiveIterator.OfLong rowIndexes,
+ boolean isRequiredColumn,
+ PageReadStore pageReadStore,
ZoneId convertTz,
String datetimeRebaseMode,
String int96RebaseMode) throws IOException {
this.descriptor = descriptor;
- this.pageReader = pageReader;
- this.readState = new ParquetReadState(descriptor.getMaxDefinitionLevel(), rowIndexes);
- this.logicalTypeAnnotation = logicalTypeAnnotation;
+ this.pageReader = pageReadStore.getPageReader(descriptor);
+ this.readState = new ParquetReadState(descriptor, isRequiredColumn,
+ pageReadStore.getRowIndexes().orElse(null));
+ this.logicalTypeAnnotation = descriptor.getPrimitiveType().getLogicalTypeAnnotation();
this.updaterFactory = new ParquetVectorUpdaterFactory(
logicalTypeAnnotation, convertTz, datetimeRebaseMode, int96RebaseMode);
@@ -149,37 +153,54 @@ private boolean isLazyDecodingSupported(PrimitiveType.PrimitiveTypeName typeName
/**
* Reads `total` values from this columnReader into column.
*/
- void readBatch(int total, WritableColumnVector column) throws IOException {
+ void readBatch(
+ int total,
+ WritableColumnVector values,
+ WritableColumnVector repetitionLevels,
+ WritableColumnVector definitionLevels) throws IOException {
WritableColumnVector dictionaryIds = null;
- ParquetVectorUpdater updater = updaterFactory.getUpdater(descriptor, column.dataType());
+ ParquetVectorUpdater updater = updaterFactory.getUpdater(descriptor, values.dataType());
if (dictionary != null) {
// SPARK-16334: We only maintain a single dictionary per row batch, so that it can be used to
// decode all previous dictionary encoded pages if we ever encounter a non-dictionary encoded
// page.
- dictionaryIds = column.reserveDictionaryIds(total);
+ dictionaryIds = values.reserveDictionaryIds(total);
}
readState.resetForNewBatch(total);
- while (readState.valuesToReadInBatch > 0) {
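+    // For repeated (nested) columns, a batch is only complete once we've reached the start of
+    // the next top-level row, so we keep reading even after 'rowsToReadInBatch' hits zero until
+    // the last list of the batch has been fully consumed.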
+ while (readState.rowsToReadInBatch > 0 || !readState.lastListCompleted) {
if (readState.valuesToReadInPage == 0) {
int pageValueCount = readPage();
+ if (pageValueCount < 0) {
+ // we've read all the pages
+ break;
+ }
readState.resetForNewPage(pageValueCount, pageFirstRowIndex);
}
+
PrimitiveType.PrimitiveTypeName typeName =
descriptor.getPrimitiveType().getPrimitiveTypeName();
if (isCurrentPageDictionaryEncoded) {
// Save starting offset in case we need to decode dictionary IDs.
- int startOffset = readState.offset;
+ int startOffset = readState.valueOffset;
// Save starting row index so we can check if we need to eagerly decode dict ids later
long startRowId = readState.rowId;
+ Preconditions.checkNotNull(dictionaryIds, "dictionaryIds == null when " +
+ "isCurrentPageDictionaryEncoded is true");
+
// Read and decode dictionary ids.
- defColumn.readIntegers(readState, dictionaryIds, column,
- (VectorizedValuesReader) dataColumn);
+ if (readState.maxRepetitionLevel == 0) {
+ defColumn.readIntegers(readState, dictionaryIds, values, definitionLevels,
+ (VectorizedValuesReader) dataColumn);
+ } else {
+ repColumn.readIntegersNested(readState, repetitionLevels, defColumn, definitionLevels,
+ dictionaryIds, values, (VectorizedValuesReader) dataColumn);
+ }
// TIMESTAMP_MILLIS encoded as INT64 can't be lazily decoded as we need to post process
// the values to add microseconds precision.
- if (column.hasDictionary() || (startRowId == pageFirstRowIndex &&
+ if (values.hasDictionary() || (startRowId == pageFirstRowIndex &&
isLazyDecodingSupported(typeName))) {
// Column vector supports lazy decoding of dictionary values so just set the dictionary.
// We can't do this if startRowId is not the first row index in the page AND the column
@@ -204,26 +225,34 @@ void readBatch(int total, WritableColumnVector column) throws IOException {
boolean isUnsignedInt64 = updaterFactory.isUnsignedIntTypeMatched(64);
boolean needTransform = castLongToInt || isUnsignedInt32 || isUnsignedInt64;
- column.setDictionary(new ParquetDictionary(dictionary, needTransform));
+ values.setDictionary(new ParquetDictionary(dictionary, needTransform));
} else {
- updater.decodeDictionaryIds(readState.offset - startOffset, startOffset, column,
+ updater.decodeDictionaryIds(readState.valueOffset - startOffset, startOffset, values,
dictionaryIds, dictionary);
}
} else {
- if (column.hasDictionary() && readState.offset != 0) {
+ if (values.hasDictionary() && readState.valueOffset != 0) {
// This batch already has dictionary encoded values but this new page is not. The batch
// does not support a mix of dictionary and not so we will decode the dictionary.
- updater.decodeDictionaryIds(readState.offset, 0, column, dictionaryIds, dictionary);
+ updater.decodeDictionaryIds(readState.valueOffset, 0, values, dictionaryIds, dictionary);
}
- column.setDictionary(null);
+ values.setDictionary(null);
VectorizedValuesReader valuesReader = (VectorizedValuesReader) dataColumn;
- defColumn.readBatch(readState, column, valuesReader, updater);
+ if (readState.maxRepetitionLevel == 0) {
+ defColumn.readBatch(readState, values, definitionLevels, valuesReader, updater);
+ } else {
+ repColumn.readBatchNested(readState, repetitionLevels, defColumn, definitionLevels,
+ values, valuesReader, updater);
+ }
}
}
}
private int readPage() {
DataPage page = pageReader.readPage();
+ if (page == null) {
+ return -1;
+ }
this.pageFirstRowIndex = page.getFirstRowIndex().orElse(0L);
    return page.accept(new DataPage.Visitor<Integer>() {
@@ -286,18 +315,15 @@ private int readPageV1(DataPageV1 page) throws IOException {
}
int pageValueCount = page.getValueCount();
- int bitWidth = BytesUtils.getWidthFromMaxInt(descriptor.getMaxDefinitionLevel());
-
- this.defColumn = new VectorizedRleValuesReader(bitWidth);
+ int dlBitWidth = BytesUtils.getWidthFromMaxInt(descriptor.getMaxDefinitionLevel());
+ this.defColumn = new VectorizedRleValuesReader(dlBitWidth);
+ int rlBitWidth = BytesUtils.getWidthFromMaxInt(descriptor.getMaxRepetitionLevel());
+ this.repColumn = new VectorizedRleValuesReader(rlBitWidth);
try {
BytesInput bytes = page.getBytes();
ByteBufferInputStream in = bytes.toInputStream();
- // only used now to consume the repetition level data
- page.getRlEncoding()
- .getValuesReader(descriptor, REPETITION_LEVEL)
- .initFromPage(pageValueCount, in);
-
+ repColumn.initFromPage(pageValueCount, in);
defColumn.initFromPage(pageValueCount, in);
initDataReader(pageValueCount, page.getValueEncoding(), in);
return pageValueCount;
@@ -308,11 +334,16 @@ private int readPageV1(DataPageV1 page) throws IOException {
private int readPageV2(DataPageV2 page) throws IOException {
int pageValueCount = page.getValueCount();
- int bitWidth = BytesUtils.getWidthFromMaxInt(descriptor.getMaxDefinitionLevel());
// do not read the length from the stream. v2 pages handle dividing the page bytes.
- defColumn = new VectorizedRleValuesReader(bitWidth, false);
+ int dlBitWidth = BytesUtils.getWidthFromMaxInt(descriptor.getMaxDefinitionLevel());
+ defColumn = new VectorizedRleValuesReader(dlBitWidth, false);
defColumn.initFromPage(pageValueCount, page.getDefinitionLevels().toInputStream());
+
+ int rlBitWidth = BytesUtils.getWidthFromMaxInt(descriptor.getMaxRepetitionLevel());
+ repColumn = new VectorizedRleValuesReader(rlBitWidth, false);
+ repColumn.initFromPage(pageValueCount, page.getRepetitionLevels().toInputStream());
+
try {
initDataReader(pageValueCount, page.getDataEncoding(), page.getData().toInputStream());
return pageValueCount;
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedParquetRecordReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedParquetRecordReader.java
index 9f7836ae4818..ca05498eb428 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedParquetRecordReader.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedParquetRecordReader.java
@@ -20,29 +20,35 @@
import java.io.IOException;
import java.time.ZoneId;
import java.util.Arrays;
+import java.util.HashSet;
import java.util.List;
+import java.util.Set;
+import com.google.common.base.Preconditions;
import org.apache.hadoop.mapreduce.InputSplit;
import org.apache.hadoop.mapreduce.TaskAttemptContext;
import org.apache.parquet.column.ColumnDescriptor;
import org.apache.parquet.column.page.PageReadStore;
-import org.apache.parquet.schema.Type;
+import org.apache.parquet.schema.GroupType;
+import org.apache.parquet.schema.MessageType;
+import org.apache.parquet.schema.Type;
import org.apache.spark.memory.MemoryMode;
import org.apache.spark.sql.catalyst.InternalRow;
import org.apache.spark.sql.execution.vectorized.ColumnVectorUtils;
import org.apache.spark.sql.execution.vectorized.WritableColumnVector;
import org.apache.spark.sql.execution.vectorized.OffHeapColumnVector;
import org.apache.spark.sql.execution.vectorized.OnHeapColumnVector;
-import org.apache.spark.sql.vectorized.ColumnarBatch;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;
+import org.apache.spark.sql.vectorized.ColumnarBatch;
+import scala.collection.JavaConverters;
/**
* A specialized RecordReader that reads into InternalRows or ColumnarBatches directly using the
* Parquet column APIs. This is somewhat based on parquet-mr's ColumnReader.
*
- * TODO: handle complex types, decimal requiring more than 8 bytes, INT96. Schema mismatch.
+ * TODO: decimal requiring more than 8 bytes, INT96. Schema mismatch.
* All of these can be handled efficiently and easily with codegen.
*
* This class can either return InternalRows or ColumnarBatches. With whole stage codegen
@@ -62,10 +68,10 @@ public class VectorizedParquetRecordReader extends SpecificParquetRecordReaderBa
private int numBatched = 0;
/**
- * For each request column, the reader to read this column. This is NULL if this column
- * is missing from the file, in which case we populate the attribute with NULL.
+ * Encapsulate writable column vectors with other Parquet related info such as
+ * repetition / definition levels.
*/
- private VectorizedColumnReader[] columnReaders;
+ private ParquetColumn[] columns;
/**
* The number of rows that have been returned.
@@ -78,9 +84,10 @@ public class VectorizedParquetRecordReader extends SpecificParquetRecordReaderBa
private long totalCountLoadedSoFar = 0;
/**
- * For each column, true if the column is missing in the file and we'll instead return NULLs.
+ * For each leaf column, if it is in the set, it means the column is missing in the file and
+ * we'll instead return NULLs.
*/
- private boolean[] missingColumns;
+  private Set<ParquetType> missingColumns;
/**
* The timezone that timestamp INT96 values should be converted to. Null if no conversion. Here to
@@ -114,8 +121,6 @@ public class VectorizedParquetRecordReader extends SpecificParquetRecordReaderBa
*/
private ColumnarBatch columnarBatch;
- private WritableColumnVector[] columnVectors;
-
/**
* If true, this class returns batches instead of rows.
*/
@@ -165,6 +170,16 @@ public void initialize(String path, List<String> columns) throws IOException,
initializeInternal();
}
+ @Override
+ public void initialize(
+ MessageType fileSchema,
+ MessageType requestedSchema,
+ ParquetRowGroupReader rowGroupReader,
+ int totalRowCount) throws IOException {
+ super.initialize(fileSchema, requestedSchema, rowGroupReader, totalRowCount);
+ initializeInternal();
+ }
+
@Override
public void close() throws IOException {
if (columnarBatch != null) {
@@ -218,12 +233,20 @@ private void initBatch(
}
}
+ WritableColumnVector[] columnVectors;
if (memMode == MemoryMode.OFF_HEAP) {
columnVectors = OffHeapColumnVector.allocateColumns(capacity, batchSchema);
} else {
columnVectors = OnHeapColumnVector.allocateColumns(capacity, batchSchema);
}
columnarBatch = new ColumnarBatch(columnVectors);
+
+ columns = new ParquetColumn[sparkSchema.fields().length];
+ for (int i = 0; i < columns.length; i++) {
+ columns[i] = new ParquetColumn(requestedSchema.children().apply(i),
+ columnVectors[i], capacity, memMode, missingColumns);
+ }
+
if (partitionColumns != null) {
int partitionIdx = sparkSchema.fields().length;
for (int i = 0; i < partitionColumns.fields().length; i++) {
@@ -231,14 +254,6 @@ private void initBatch(
columnVectors[i + partitionIdx].setIsConstant();
}
}
-
- // Initialize missing columns with nulls.
- for (int i = 0; i < missingColumns.length; i++) {
- if (missingColumns[i]) {
- columnVectors[i].putNulls(0, capacity);
- columnVectors[i].setIsConstant();
- }
- }
}
private void initBatch() {
@@ -270,18 +285,26 @@ public void enableReturningBatches() {
* Advances to the next batch of rows. Returns false if there are no more.
*/
public boolean nextBatch() throws IOException {
- for (WritableColumnVector vector : columnVectors) {
- vector.reset();
+ for (ParquetColumn column : columns) {
+ column.reset();
}
+
columnarBatch.setNumRows(0);
if (rowsReturned >= totalRowCount) return false;
checkEndOfRowGroup();
- int num = (int) Math.min((long) capacity, totalCountLoadedSoFar - rowsReturned);
- for (int i = 0; i < columnReaders.length; ++i) {
- if (columnReaders[i] == null) continue;
- columnReaders[i].readBatch(num, columnVectors[i]);
+ int num = (int) Math.min(capacity, totalCountLoadedSoFar - rowsReturned);
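+    // First read the values of all leaf columns under each requested top-level column, then
+    // assemble the nested structure (offsets and nulls) from the repetition/definition levels.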
+ for (ParquetColumn col : columns) {
+ for (ParquetColumn leafCol : col.getLeaves()) {
+ VectorizedColumnReader columnReader = leafCol.getColumnReader();
+ if (columnReader != null) {
+ columnReader.readBatch(num, leafCol.getValueVector(),
+ leafCol.getRepetitionLevelVector(), leafCol.getDefinitionLevelVector());
+ }
+ }
+ col.assemble();
}
+
rowsReturned += num;
columnarBatch.setNumRows(num);
numBatched = num;
@@ -290,55 +313,89 @@ public boolean nextBatch() throws IOException {
}
private void initializeInternal() throws IOException, UnsupportedOperationException {
- // Check that the requested schema is supported.
- missingColumns = new boolean[requestedSchema.getFieldCount()];
-    List<ColumnDescriptor> columns = requestedSchema.getColumns();
-    List<String[]> paths = requestedSchema.getPaths();
- for (int i = 0; i < requestedSchema.getFieldCount(); ++i) {
- Type t = requestedSchema.getFields().get(i);
- if (!t.isPrimitive() || t.isRepetition(Type.Repetition.REPEATED)) {
- throw new UnsupportedOperationException("Complex types not supported.");
- }
+ missingColumns = new HashSet<>();
+ for (ParquetType columnType : JavaConverters.seqAsJavaList(requestedSchema.children())) {
+ checkColumn(columnType);
+ }
+ }
- String[] colPath = paths.get(i);
- if (fileSchema.containsPath(colPath)) {
- ColumnDescriptor fd = fileSchema.getColumnDescription(colPath);
- if (!fd.equals(columns.get(i))) {
+ /**
+   * Checks whether a column in the requested schema is missing from the file schema and, if it
+   * is present, whether its type conforms to the corresponding type in the file schema.
+ */
+ private void checkColumn(ParquetType columnType) throws IOException {
+ String[] path = JavaConverters.seqAsJavaList(columnType.path()).toArray(new String[0]);
+ if (containsPath(fileSchema, path)) {
+ if (columnType.isPrimitive()) {
+ ColumnDescriptor desc = columnType.descriptor().get();
+ ColumnDescriptor fd = fileSchema.getColumnDescription(desc.getPath());
+ if (!fd.equals(desc)) {
throw new UnsupportedOperationException("Schema evolution not supported.");
}
- missingColumns[i] = false;
} else {
- if (columns.get(i).getMaxDefinitionLevel() == 0) {
- // Column is missing in data but the required data is non-nullable. This file is invalid.
- throw new IOException("Required column is missing in data file. Col: " +
- Arrays.toString(colPath));
+ for (ParquetType childType : JavaConverters.seqAsJavaList(columnType.children())) {
+ checkColumn(childType);
}
- missingColumns[i] = true;
}
+ } else { // a missing column which is either primitive or complex
+ if (columnType.required()) {
+ // Column is missing in data but the required data is non-nullable. This file is invalid.
+ throw new IOException("Required column is missing in data file. Col: " +
+ Arrays.toString(path));
+ }
+ missingColumns.add(columnType);
}
}
+ /**
+   * Checks whether the given 'path' exists in 'parquetType'. The difference between this and
+   * {@link MessageType#containsPath(String[])} is that the latter only supports paths to leaf
+   * nodes, while this supports paths to both leaf and non-leaf nodes.
+ */
+ private boolean containsPath(Type parquetType, String[] path) {
+ return containsPath(parquetType, path, 0);
+ }
+
+ private boolean containsPath(Type parquetType, String[] path, int depth) {
+ if (path.length == depth) return true;
+ if (parquetType instanceof GroupType) {
+ String fieldName = path[depth];
+ GroupType parquetGroupType = (GroupType) parquetType;
+ if (parquetGroupType.containsField(fieldName)) {
+ return containsPath(parquetGroupType.getType(fieldName), path, depth + 1);
+ }
+ }
+ return false;
+ }
+
private void checkEndOfRowGroup() throws IOException {
if (rowsReturned != totalCountLoadedSoFar) return;
- PageReadStore pages = reader.readNextFilteredRowGroup();
+ PageReadStore pages = reader.readNextRowGroup();
if (pages == null) {
throw new IOException("expecting more rows but reached last block. Read "
+ rowsReturned + " out of " + totalRowCount);
}
-    List<ColumnDescriptor> columns = requestedSchema.getColumns();
-    List<Type> types = requestedSchema.asGroupType().getFields();
- columnReaders = new VectorizedColumnReader[columns.size()];
- for (int i = 0; i < columns.size(); ++i) {
- if (missingColumns[i]) continue;
- columnReaders[i] = new VectorizedColumnReader(
- columns.get(i),
- types.get(i).getLogicalTypeAnnotation(),
- pages.getPageReader(columns.get(i)),
- pages.getRowIndexes().orElse(null),
- convertTz,
- datetimeRebaseMode,
- int96RebaseMode);
+ for (ParquetColumn column : columns) {
+ initColumnReader(pages, column);
}
totalCountLoadedSoFar += pages.getRowCount();
}
+
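+  /**
+   * Creates a column reader for every leaf column under 'column' that is present in the file,
+   * recursing into the children of struct columns.
+   */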
+ private void initColumnReader(PageReadStore pages, ParquetColumn column) throws IOException {
+ if (!missingColumns.contains(column.getType())) {
+ if (column.getType().isPrimitive()) {
+ ParquetType colType = column.getType();
+ Preconditions.checkArgument(colType.isPrimitive());
+ VectorizedColumnReader reader = new VectorizedColumnReader(
+ colType.descriptor().get(), colType.required(), pages, convertTz, datetimeRebaseMode,
+ int96RebaseMode);
+ column.setColumnReader(reader);
+ } else {
+ // not in missing columns and is a complex type: this must be a struct
+ for (ParquetColumn childCol : column.getChildren()) {
+ initColumnReader(pages, childCol);
+ }
+ }
+ }
+ }
}
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedRleValuesReader.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedRleValuesReader.java
index af739a52d8ed..c24048f34700 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedRleValuesReader.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/parquet/VectorizedRleValuesReader.java
@@ -166,9 +166,10 @@ public int readInteger() {
public void readBatch(
ParquetReadState state,
WritableColumnVector values,
+ WritableColumnVector defLevels,
VectorizedValuesReader valueReader,
ParquetVectorUpdater updater) {
- readBatchInternal(state, values, values, valueReader, updater);
+ readBatchInternal(state, values, values, defLevels, valueReader, updater);
}
/**
@@ -179,21 +180,23 @@ public void readIntegers(
ParquetReadState state,
WritableColumnVector values,
WritableColumnVector nulls,
- VectorizedValuesReader data) {
- readBatchInternal(state, values, nulls, data, new ParquetVectorUpdaterFactory.IntegerUpdater());
+ WritableColumnVector defLevels,
+ VectorizedValuesReader valueReader) throws IOException {
+ readBatchInternal(state, values, nulls, defLevels, valueReader,
+ new ParquetVectorUpdaterFactory.IntegerUpdater());
}
private void readBatchInternal(
ParquetReadState state,
WritableColumnVector values,
WritableColumnVector nulls,
+ WritableColumnVector defLevels,
VectorizedValuesReader valueReader,
ParquetVectorUpdater updater) {
- int offset = state.offset;
- long rowId = state.rowId;
- int leftInBatch = state.valuesToReadInBatch;
+ int leftInBatch = state.rowsToReadInBatch;
int leftInPage = state.valuesToReadInPage;
+ long rowId = state.rowId;
while (leftInBatch > 0 && leftInPage > 0) {
if (this.currentCount == 0) this.readNextGroup();
@@ -227,44 +230,352 @@ private void readBatchInternal(
switch (mode) {
case RLE:
if (currentValue == state.maxDefinitionLevel) {
- updater.readValues(n, offset, values, valueReader);
- } else {
- nulls.putNulls(offset, n);
+ updater.readValues(n, state.valueOffset, values, valueReader);
+ state.valueOffset += n;
+ } else if (!state.isRequired && currentValue == state.maxDefinitionLevel - 1) {
+ // only add null if this represents a null element, but not the case when a
+ // struct is null
+ nulls.putNulls(state.valueOffset, n);
+ state.valueOffset += n;
}
+ defLevels.putInts(state.levelOffset, n, currentValue);
break;
case PACKED:
for (int i = 0; i < n; ++i) {
- if (currentBuffer[currentBufferIdx++] == state.maxDefinitionLevel) {
- updater.readValue(offset + i, values, valueReader);
- } else {
- nulls.putNull(offset + i);
+ int currentValue = currentBuffer[currentBufferIdx++];
+ if (currentValue == state.maxDefinitionLevel) {
+ updater.readValue(state.valueOffset++, values, valueReader);
+ } else if (!state.isRequired && currentValue == state.maxDefinitionLevel - 1) {
+ // only add null if this represents a null element, but not the case when a
+ // struct is null
+ nulls.putNull(state.valueOffset++);
}
+ defLevels.putInt(state.levelOffset + i, currentValue);
}
break;
}
- offset += n;
+ state.levelOffset += n;
leftInBatch -= n;
rowId += n;
leftInPage -= n;
currentCount -= n;
+ defLevels.addElementsAppended(n);
}
}
- state.advanceOffsetAndRowId(offset, rowId);
+ state.rowsToReadInBatch = leftInBatch;
+ state.valuesToReadInPage = leftInPage;
+ state.rowId = rowId;
+ }
+
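+  /**
+   * Reads a batch of repetition levels, definition levels and values for a repeated (nested)
+   * column, delegating to 'readBatchNestedInternal'.
+   */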
+ public void readBatchNested(
+ ParquetReadState state,
+ WritableColumnVector repLevels,
+ VectorizedRleValuesReader defLevelsReader,
+ WritableColumnVector defLevels,
+ WritableColumnVector values,
+ VectorizedValuesReader valueReader,
+ ParquetVectorUpdater updater) {
+ readBatchNestedInternal(state, repLevels, defLevelsReader, defLevels, values, values, true,
+ valueReader, updater);
+ }
+
+ public void readIntegersNested(
+ ParquetReadState state,
+ WritableColumnVector repLevels,
+ VectorizedRleValuesReader defLevelsReader,
+ WritableColumnVector defLevels,
+ WritableColumnVector values,
+ WritableColumnVector nulls,
+ VectorizedValuesReader valueReader) {
+ readBatchNestedInternal(state, repLevels, defLevelsReader, defLevels, values, nulls, false,
+ valueReader, new ParquetVectorUpdaterFactory.IntegerUpdater());
}
/**
- * Skip the next `n` values (either null or non-null) from this definition level reader and
- * `valueReader`.
+ * Keep reading repetition level values from the page until either: 1) we've read enough
+ * top-level rows to fill the current batch, or 2) we've drained the data page completely.
+ *
+ * @param valuesReused whether `values` vector is reused for `nulls`
*/
- private void skipValues(
- int n,
+ public void readBatchNestedInternal(
ParquetReadState state,
- VectorizedValuesReader valuesReader,
+ WritableColumnVector repLevels,
+ VectorizedRleValuesReader defLevelsReader,
+ WritableColumnVector defLevels,
+ WritableColumnVector values,
+ WritableColumnVector nulls,
+ boolean valuesReused,
+ VectorizedValuesReader valueReader,
ParquetVectorUpdater updater) {
+
+ int leftInBatch = state.rowsToReadInBatch;
+ int leftInPage = state.valuesToReadInPage;
+ long rowId = state.rowId;
+
+ DefLevelProcessor defLevelProcessor = new DefLevelProcessor(defLevelsReader, state, defLevels,
+ values, nulls, valuesReused, valueReader, updater);
+
+ while ((leftInBatch > 0 || !state.lastListCompleted) && leftInPage > 0) {
+ if (currentCount == 0 && !readNextGroup()) break;
+
+ // values to read in the current RLE/PACKED block, must be <= what's left in the page
+ int valuesLeftInBlock = Math.min(leftInPage, currentCount);
+
+ // the current row range start and end
+ long rangeStart = state.currentRangeStart();
+ long rangeEnd = state.currentRangeEnd();
+
+ switch (mode) {
+ case RLE:
+          // this RLE block consists of top-level rows, so we'll need to check
+          // whether the rows should be skipped according to the row indexes.
+ if (currentValue == 0) {
+ if (leftInBatch == 0) {
+ state.lastListCompleted = true;
+ } else {
+ // # of rows to read in the block, must be <= what's left in the current batch
+ int n = Math.min(leftInBatch, valuesLeftInBlock);
+
+ if (rowId + n < rangeStart) {
+ // need to skip all rows in [rowId, rowId + n)
+ defLevelProcessor.skipValues(n);
+ rowId += n;
+ currentCount -= n;
+ leftInPage -= n;
+ } else if (rowId > rangeEnd) {
+ // the current row index already beyond the current range: move to the next range
+ // and repeat
+ state.nextRange();
+ } else {
+ // the range [rowId, rowId + n) overlaps with the current row range
+ long start = Math.max(rangeStart, rowId);
+ long end = Math.min(rangeEnd, rowId + n - 1);
+
+ // skip the rows in [rowId, start)
+ int toSkip = (int) (start - rowId);
+ if (toSkip > 0) {
+ defLevelProcessor.skipValues(toSkip);
+ rowId += toSkip;
+ currentCount -= toSkip;
+ leftInPage -= toSkip;
+ }
+
+ // read the rows in [start, end]
+ n = (int) (end - start + 1);
+
+ leftInBatch -= n;
+ if (n > 0) {
+ repLevels.appendInts(n, 0);
+ defLevelProcessor.readValues(n);
+ }
+
+ rowId += n;
+ currentCount -= n;
+ leftInPage -= n;
+ }
+ }
+ } else {
+ // not a top-level row: just read all the repetition levels in the block if the row
+ // should be included according to row indexes, else skip the rows.
+ if (!state.shouldSkip) {
+ repLevels.appendInts(valuesLeftInBlock, currentValue);
+ }
+ state.numBatchedDefLevels += valuesLeftInBlock;
+ leftInPage -= valuesLeftInBlock;
+ currentCount -= valuesLeftInBlock;
+ }
+ break;
+ case PACKED:
+ int i = 0;
+
+ for (; i < valuesLeftInBlock; i++) {
+ int currentValue = currentBuffer[currentBufferIdx + i];
+ if (currentValue == 0) {
+ if (leftInBatch == 0) {
+ state.lastListCompleted = true;
+ break;
+ } else if (rowId < rangeStart) {
+              // this is a top-level row, so check whether to skip it using the row indexes;
+              // here the row is before the current range, so skip it
+ defLevelProcessor.skipValues(1);
+ } else if (rowId > rangeEnd) {
+ // the row is after the current range, move to the next range and compare again
+ state.nextRange();
+ break;
+ } else {
+ // the row is in the current range, decrement the row counter and read it
+ leftInBatch--;
+ repLevels.appendInt(0);
+ defLevelProcessor.readValues(1);
+ }
+ rowId++;
+ } else {
+ if (!state.shouldSkip) {
+ repLevels.appendInt(currentValue);
+ }
+ state.numBatchedDefLevels += 1;
+ }
+ }
+
+ leftInPage -= i;
+ currentCount -= i;
+ currentBufferIdx += i;
+ break;
+ }
+ }
+
+ // process all the batched def levels
+ defLevelProcessor.finish();
+
+ state.rowsToReadInBatch = leftInBatch;
+ state.valuesToReadInPage = leftInPage;
+ state.rowId = rowId;
+ }
+
+ private static class DefLevelProcessor {
+ private final VectorizedRleValuesReader reader;
+ private final ParquetReadState state;
+ private final WritableColumnVector defLevels;
+ private final WritableColumnVector values;
+ private final WritableColumnVector nulls;
+ private final boolean valuesReused;
+ private final VectorizedValuesReader valueReader;
+ private final ParquetVectorUpdater updater;
+
+ DefLevelProcessor(
+ VectorizedRleValuesReader reader,
+ ParquetReadState state,
+ WritableColumnVector defLevels,
+ WritableColumnVector values,
+ WritableColumnVector nulls,
+ boolean valuesReused,
+ VectorizedValuesReader valueReader,
+ ParquetVectorUpdater updater) {
+ this.reader = reader;
+ this.state = state;
+ this.defLevels = defLevels;
+ this.values = values;
+ this.nulls = nulls;
+ this.valuesReused = valuesReused;
+ this.valueReader = valueReader;
+ this.updater = updater;
+ }
+
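+    // Consecutive values to read (or to skip) are accumulated in 'state.numBatchedDefLevels'
+    // and only flushed to the underlying readers when switching between reading and skipping,
+    // or when 'finish()' is called, so the RLE reader operates on runs rather than single
+    // values.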
+ void readValues(int n) {
+ if (!state.shouldSkip) {
+ state.numBatchedDefLevels += n;
+ } else {
+ reader.skipValues(state.numBatchedDefLevels, state, valueReader, updater);
+ state.numBatchedDefLevels = n;
+ state.shouldSkip = false;
+ }
+ }
+
+ void skipValues(int n) {
+ if (state.shouldSkip) {
+ state.numBatchedDefLevels += n;
+ } else {
+ reader.readValues(state.numBatchedDefLevels, state, defLevels, values, nulls, valuesReused,
+ valueReader, updater);
+ state.numBatchedDefLevels = n;
+ state.shouldSkip = true;
+ }
+ }
+
+ void finish() {
+ if (state.numBatchedDefLevels > 0) {
+ if (state.shouldSkip) {
+ reader.skipValues(state.numBatchedDefLevels, state, valueReader, updater);
+ } else {
+ reader.readValues(state.numBatchedDefLevels, state, defLevels, values, nulls,
+ valuesReused, valueReader, updater);
+ }
+ state.numBatchedDefLevels = 0;
+ }
+ }
+ }
+
+ /**
+ * Read the next 'total' values (either null or non-null) from this definition level reader and
+ * 'valueReader'. The definition levels are read into 'defLevels'. If a value is not
+ * null, it is appended to 'values'. Otherwise, a null bit will be set in 'nulls'.
+ *
+ * This is only used when reading repeated values.
+ */
+ private void readValues(
+ int total,
+ ParquetReadState state,
+ WritableColumnVector defLevels,
+ WritableColumnVector values,
+ WritableColumnVector nulls,
+ boolean valuesReused,
+ VectorizedValuesReader valueReader,
+ ParquetVectorUpdater updater) {
+ defLevels.reserveAdditional(total);
+ values.reserveAdditional(total);
+ if (!valuesReused) {
+ nulls.reserveAdditional(total);
+ }
+ int n = total;
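+ // remember the starting value offset so we can compute, after the loop, how many value
+ // slots (non-null values and null elements) were consumed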
+ int initialValueOffset = state.valueOffset;
while (n > 0) {
if (this.currentCount == 0) this.readNextGroup();
int num = Math.min(n, this.currentCount);
+ switch (mode) {
+ case RLE:
+ if (currentValue == state.maxDefinitionLevel) {
+ updater.readValues(num, state.valueOffset, values, valueReader);
+ state.valueOffset += num;
+ } else if (!state.isRequired && currentValue == state.maxDefinitionLevel - 1) {
+ // only add null if this represents a null element, but not the case when a
+ // collection is null or empty.
+ nulls.putNulls(state.valueOffset, num);
+ state.valueOffset += num;
+ }
+ defLevels.putInts(state.levelOffset, num, currentValue);
+ break;
+ case PACKED:
+ for (int i = 0; i < num; ++i) {
+ int currentValue = currentBuffer[currentBufferIdx++];
+ if (currentValue == state.maxDefinitionLevel) {
+ updater.readValue(state.valueOffset++, values, valueReader);
+ } else if (!state.isRequired && currentValue == state.maxDefinitionLevel - 1) {
+ // only add null if this represents a null element, but not the case when a
+ // collection is null or empty.
+ nulls.putNull(state.valueOffset++);
+ }
+ defLevels.putInt(state.levelOffset + i, currentValue);
+ }
+ break;
+ }
+ state.levelOffset += num;
+ currentCount -= num;
+ n -= num;
+ }
+ defLevels.addElementsAppended(total);
+
+ int valuesRead = state.valueOffset - initialValueOffset;
+ values.addElementsAppended(valuesRead);
+ if (!valuesReused) {
+ nulls.addElementsAppended(valuesRead);
+ }
+ }
+
+ /**
+ * Skip the next 'total' values (either null or non-null) from this definition level reader and
+ * 'valuesReader'.
+ *
+ * This is used in reading both non-repeated and repeated values.
+ */
+ private void skipValues(
+ int total,
+ ParquetReadState state,
+ VectorizedValuesReader valuesReader,
+ ParquetVectorUpdater updater) {
+ while (total > 0) {
+ if (this.currentCount == 0) this.readNextGroup();
+ int num = Math.min(total, this.currentCount);
switch (mode) {
case RLE:
// we only need to skip non-null values from `valuesReader` since nulls are represented
@@ -283,7 +594,7 @@ private void skipValues(
break;
}
currentCount -= num;
- n -= num;
+ total -= num;
}
}
@@ -497,7 +808,12 @@ private int readIntLittleEndianPaddedOnBitWidth() throws IOException {
/**
* Reads the next group.
*/
- private void readNextGroup() {
+ private boolean readNextGroup() {
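+ // if the page has no more level data, report it to the caller instead of failing while
+ // reading the next group header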
+ if (in.available() <= 0) {
+ currentCount = 0;
+ return false;
+ }
+
try {
int header = readUnsignedVarInt();
this.mode = (header & 1) == 0 ? MODE.RLE : MODE.PACKED;
@@ -505,7 +821,7 @@ private void readNextGroup() {
case RLE:
this.currentCount = header >>> 1;
this.currentValue = readIntLittleEndianPaddedOnBitWidth();
- return;
+ break;
case PACKED:
int numGroups = header >>> 1;
this.currentCount = numGroups * 8;
@@ -521,12 +837,13 @@ private void readNextGroup() {
this.packer.unpack8Values(buffer, buffer.position(), this.currentBuffer, valueIndex);
valueIndex += 8;
}
- return;
+ break;
default:
throw new ParquetDecodingException("not a valid mode " + this.mode);
}
} catch (IOException e) {
throw new ParquetDecodingException("Failed to read from input stream", e);
}
+ return true;
}
}
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java
index b4b6903cda24..e5dc033acaf2 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OffHeapColumnVector.java
@@ -64,6 +64,9 @@ public static OffHeapColumnVector[] allocateColumns(int capacity, StructField[]
private long lengthData;
private long offsetData;
+ // Only set if type is Struct
+ private long structOffsetData;
+
public OffHeapColumnVector(int capacity, DataType type) {
super(capacity, type);
@@ -71,6 +74,7 @@ public OffHeapColumnVector(int capacity, DataType type) {
data = 0;
lengthData = 0;
offsetData = 0;
+ structOffsetData = 0;
reserveInternal(capacity);
reset();
@@ -91,10 +95,12 @@ public void close() {
Platform.freeMemory(data);
Platform.freeMemory(lengthData);
Platform.freeMemory(offsetData);
+ Platform.freeMemory(structOffsetData);
nulls = 0;
data = 0;
lengthData = 0;
offsetData = 0;
+ structOffsetData = 0;
}
//
@@ -132,7 +138,7 @@ public void putNotNulls(int rowId, int count) {
@Override
public boolean isNullAt(int rowId) {
- return Platform.getByte(null, nulls + rowId) == 1;
+ return isAllNull || Platform.getByte(null, nulls + rowId) == 1;
}
//
@@ -527,6 +533,11 @@ public int getArrayOffset(int rowId) {
return Platform.getInt(null, offsetData + 4L * rowId);
}
+ @Override
+ public int getStructOffset(int rowId) {
+ return Platform.getInt(null, structOffsetData + 4L * rowId);
+ }
+
// APIs dealing with ByteArrays
@Override
public int putByteArray(int rowId, byte[] value, int offset, int length) {
@@ -536,6 +547,11 @@ public int putByteArray(int rowId, byte[] value, int offset, int length) {
return result;
}
+ @Override
+ public void putStruct(int rowId, int offset) {
+ Platform.putInt(null, structOffsetData + 4L * rowId, offset);
+ }
+
// Split out the slow path.
@Override
protected void reserveInternal(int newCapacity) {
@@ -545,6 +561,9 @@ protected void reserveInternal(int newCapacity) {
Platform.reallocateMemory(lengthData, oldCapacity * 4L, newCapacity * 4L);
this.offsetData =
Platform.reallocateMemory(offsetData, oldCapacity * 4L, newCapacity * 4L);
+ } else if (isStruct()) {
+ this.structOffsetData =
+ Platform.reallocateMemory(structOffsetData, oldCapacity * 4L, newCapacity * 4L);
} else if (type instanceof ByteType || type instanceof BooleanType) {
this.data = Platform.reallocateMemory(data, oldCapacity, newCapacity);
} else if (type instanceof ShortType) {
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java
index 3fb96d872cd8..9bed6bd0e728 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/OnHeapColumnVector.java
@@ -73,6 +73,9 @@ public static OnHeapColumnVector[] allocateColumns(int capacity, StructField[] f
private int[] arrayLengths;
private int[] arrayOffsets;
+ // Only set if type is Struct
+ private int[] structOffsets;
+
public OnHeapColumnVector(int capacity, DataType type) {
super(capacity, type);
@@ -92,6 +95,7 @@ public void close() {
doubleData = null;
arrayLengths = null;
arrayOffsets = null;
+ structOffsets = null;
}
//
@@ -127,7 +131,7 @@ public void putNotNulls(int rowId, int count) {
@Override
public boolean isNullAt(int rowId) {
- return nulls[rowId] == 1;
+ return isAllNull || nulls[rowId] == 1;
}
//
@@ -492,6 +496,11 @@ public int getArrayOffset(int rowId) {
return arrayOffsets[rowId];
}
+ @Override
+ public int getStructOffset(int rowId) {
+ return structOffsets[rowId];
+ }
+
@Override
public void putArray(int rowId, int offset, int length) {
arrayOffsets[rowId] = offset;
@@ -510,6 +519,11 @@ public int putByteArray(int rowId, byte[] value, int offset, int length) {
return result;
}
+ @Override
+ public void putStruct(int rowId, int offset) {
+ structOffsets[rowId] = offset;
+ }
+
// Spilt this function out since it is the slow path.
@Override
protected void reserveInternal(int newCapacity) {
@@ -522,6 +536,12 @@ protected void reserveInternal(int newCapacity) {
}
arrayLengths = newLengths;
arrayOffsets = newOffsets;
+ } else if (isStruct()) {
+ int[] newOffsets = new int[newCapacity];
+ if (this.structOffsets != null) {
+ System.arraycopy(this.structOffsets, 0, newOffsets, 0, capacity);
+ }
+ structOffsets = newOffsets;
} else if (type instanceof BooleanType) {
if (byteData == null || byteData.length < newCapacity) {
byte[] newData = new byte[newCapacity];
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java
index 8f7dcf237440..150327a31ac9 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/vectorized/WritableColumnVector.java
@@ -26,6 +26,7 @@
import org.apache.spark.sql.vectorized.ColumnVector;
import org.apache.spark.sql.vectorized.ColumnarArray;
import org.apache.spark.sql.vectorized.ColumnarMap;
+import org.apache.spark.sql.vectorized.ColumnarRow;
import org.apache.spark.unsafe.array.ByteArrayMethods;
import org.apache.spark.unsafe.types.CalendarInterval;
import org.apache.spark.unsafe.types.UTF8String;
@@ -51,7 +52,7 @@ public abstract class WritableColumnVector extends ColumnVector {
* Resets this column for writing. The currently stored values are no longer accessible.
*/
public void reset() {
- if (isConstant) return;
+ if (isConstant || isAllNull) return;
if (childColumns != null) {
for (WritableColumnVector c: childColumns) {
@@ -81,6 +82,10 @@ public void close() {
dictionary = null;
}
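+ /**
+ * Reserves enough capacity to append 'additionalCapacity' more elements to this vector.
+ */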
+ public void reserveAdditional(int additionalCapacity) {
+ reserve(elementsAppended + additionalCapacity);
+ }
+
public void reserve(int requiredCapacity) {
if (requiredCapacity < 0) {
throwUnsupportedException(requiredCapacity, null);
@@ -115,7 +120,7 @@ private void throwUnsupportedException(int requiredCapacity, Throwable cause) {
@Override
public boolean hasNull() {
- return numNulls > 0;
+ return isAllNull || numNulls > 0;
}
@Override
@@ -350,6 +355,11 @@ public WritableColumnVector reserveDictionaryIds(int capacity) {
*/
public abstract void putArray(int rowId, int offset, int length);
+ /**
+ * Puts a struct at rowId with the given offset, which indicates its position across all non-null entries.
+ */
+ public abstract void putStruct(int rowId, int offset);
+
/**
* Sets values from [value + offset, value + offset + count) to the values at rowId.
*/
@@ -433,6 +443,7 @@ public final int appendNull() {
}
public final int appendNotNull() {
+ assert (!(dataType() instanceof StructType)); // Use appendStruct()
reserve(elementsAppended + 1);
putNotNull(elementsAppended);
return elementsAppended++;
@@ -625,21 +636,23 @@ public final int appendArray(int length) {
* common non-struct case.
*/
public final int appendStruct(boolean isNull) {
+ reserve(elementsAppended + 1);
if (isNull) {
// This is the same as appendNull but without the assertion for struct types
- reserve(elementsAppended + 1);
putNull(elementsAppended);
- elementsAppended++;
for (WritableColumnVector c: childColumns) {
- if (c.type instanceof StructType) {
+ if (c.isStruct()) {
c.appendStruct(true);
} else {
c.appendNull();
}
}
} else {
- appendNotNull();
+ reserve(elementsAppended + 1);
+ putNotNull(elementsAppended);
}
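+ // record the struct offset; for append-style writes it is simply the current row position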
+ putStruct(elementsAppended, elementsAppended);
+ elementsAppended++;
return elementsAppended;
}
@@ -659,6 +672,12 @@ public final ColumnarMap getMap(int rowId) {
return new ColumnarMap(getChild(0), getChild(1), getArrayOffset(rowId), getArrayLength(rowId));
}
+ @Override
+ public final ColumnarRow getStruct(int rowId) {
+ if (isNullAt(rowId)) return null;
+ return new ColumnarRow(this, getStructOffset(rowId));
+ }
+
public WritableColumnVector arrayData() {
return childColumns[0];
}
@@ -667,19 +686,49 @@ public WritableColumnVector arrayData() {
public abstract int getArrayOffset(int rowId);
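+ /**
+ * Returns the struct offset previously set via `putStruct` for the given rowId.
+ */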
+ public abstract int getStructOffset(int rowId);
+
@Override
public WritableColumnVector getChild(int ordinal) { return childColumns[ordinal]; }
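+ /**
+ * Returns the number of child column vectors.
+ */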
+ public int getNumChildren() {
+ return childColumns.length;
+ }
+
/**
* Returns the elements appended.
*/
public final int getElementsAppended() { return elementsAppended; }
+ /**
+ * Adds `num` to the number of elements appended. This is useful when writing values directly
+ * with the `putXXX` APIs, to keep track of how many values have been added to the vector.
+ */
+ public final void addElementsAppended(int num) {
+ this.elementsAppended += num;
+ }
+
/**
* Marks this column as being constant.
*/
public final void setIsConstant() { isConstant = true; }
+ /**
+ * Marks this column as containing only null values.
+ */
+ public final void setAllNull() {
+ isAllNull = true;
+ }
+
+ /**
+ * Returns true iff 'setAllNull' has been called, which means this is a constant column vector
+ * whose values are all null. If this returns false, it doesn't necessarily mean the vector
+ * contains non-null values; rather, its null-ness is simply unknown.
+ */
+ public final boolean isAllNull() {
+ return isAllNull;
+ }
+
/**
* Maximum number of rows that can be stored in this column.
*/
@@ -702,6 +751,11 @@ public WritableColumnVector arrayData() {
*/
protected boolean isConstant;
+ /**
+ * True if this column's values are all null.
+ */
+ protected boolean isAllNull;
+
/**
* Default size of each array length value. This grows as necessary.
*/
@@ -727,6 +781,10 @@ protected boolean isArray() {
DecimalType.isByteArrayDecimalType(type);
}
+ protected boolean isStruct() {
+ return type instanceof StructType || type instanceof CalendarIntervalType;
+ }
+
/**
* Sets up the common state and also handles creating the child columns if this is a nested
* type.
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala
index d3ac077ccf4a..e575a418bc7b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala
@@ -45,6 +45,7 @@ import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjectio
import org.apache.spark.sql.catalyst.parser.LegacyTypeStringParser
import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.errors.QueryExecutionErrors
+import org.apache.spark.sql.execution.WholeStageCodegenExec
import org.apache.spark.sql.execution.datasources._
import org.apache.spark.sql.execution.vectorized.{OffHeapColumnVector, OnHeapColumnVector}
import org.apache.spark.sql.internal.SQLConf
@@ -169,8 +170,24 @@ class ParquetFileFormat
override def supportBatch(sparkSession: SparkSession, schema: StructType): Boolean = {
val conf = sparkSession.sessionState.conf
conf.parquetVectorizedReaderEnabled && conf.wholeStageEnabled &&
- schema.length <= conf.wholeStageMaxNumFields &&
- schema.forall(_.dataType.isInstanceOf[AtomicType])
+ isBatchReadSupported(conf, schema) && !WholeStageCodegenExec.isTooManyFields(conf, schema)
+ }
+
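+ /**
+ * Returns true iff the given (possibly nested) data type is supported by the vectorized
+ * Parquet reader, taking the nested column vectorization config into account for arrays,
+ * maps and structs.
+ */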
+ private def isBatchReadSupported(sqlConf: SQLConf, dt: DataType): Boolean = dt match {
+ case _: AtomicType =>
+ true
+ case at: ArrayType =>
+ sqlConf.parquetVectorizedReaderNestedColumnEnabled &&
+ isBatchReadSupported(sqlConf, at.elementType)
+ case mt: MapType =>
+ sqlConf.parquetVectorizedReaderNestedColumnEnabled &&
+ isBatchReadSupported(sqlConf, mt.keyType) &&
+ isBatchReadSupported(sqlConf, mt.valueType)
+ case st: StructType =>
+ sqlConf.parquetVectorizedReaderNestedColumnEnabled &&
+ st.fields.forall(f => isBatchReadSupported(sqlConf, f.dataType))
+ case _ =>
+ false
}
override def vectorTypes(
@@ -239,7 +256,7 @@ class ParquetFileFormat
val enableOffHeapColumnVector = sqlConf.offHeapColumnVectorEnabled
val enableVectorizedReader: Boolean =
sqlConf.parquetVectorizedReaderEnabled &&
- resultSchema.forall(_.dataType.isInstanceOf[AtomicType])
+ resultSchema.map(_.dataType).forall(isBatchReadSupported(sqlConf, _))
val enableRecordFilter: Boolean = sqlConf.parquetRecordFilterEnabled
val timestampConversion: Boolean = sqlConf.isParquetINT96TimestampConversion
val capacity = sqlConf.parquetVectorizedReaderBatchSize
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRowConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRowConverter.scala
index e0af5d8dd869..f982402f65a2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRowConverter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRowConverter.scala
@@ -25,8 +25,9 @@ import scala.collection.JavaConverters._
import scala.collection.mutable.ArrayBuffer
import org.apache.parquet.column.Dictionary
+import org.apache.parquet.io.ColumnIOFactory
import org.apache.parquet.io.api.{Binary, Converter, GroupConverter, PrimitiveConverter}
-import org.apache.parquet.schema.{GroupType, Type}
+import org.apache.parquet.schema.{GroupType, Type, Types}
import org.apache.parquet.schema.LogicalTypeAnnotation._
import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName.{BINARY, FIXED_LEN_BYTE_ARRAY, INT32, INT64, INT96}
@@ -601,7 +602,9 @@ private[parquet] class ParquetRowConverter(
// Here we try to convert field `list` into a Catalyst type to see whether the converted type
// matches the Catalyst array element type. If it doesn't match, then it's case 1; otherwise,
// it's case 2.
- val guessedElementType = schemaConverter.convertField(repeatedType)
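+ // wrap the repeated field in a dummy message type so that we can obtain a ColumnIO for it
+ // and reuse the ColumnIO-based conversion in `convertField`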
+ val messageType = Types.buildMessage().addField(repeatedType).named("foo")
+ val column = new ColumnIOFactory().getColumnIO(messageType)
+ val guessedElementType = schemaConverter.convertField(column.getChild(0)).sparkType
if (DataType.equalsIgnoreCompatibleNullability(guessedElementType, elementType)) {
// If the repeated field corresponds to the element type, creates a new converter using the
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaConverter.scala
index f3bfd99368de..fedcfc0fb9a6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaConverter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaConverter.scala
@@ -17,9 +17,8 @@
package org.apache.spark.sql.execution.datasources.parquet
-import scala.collection.JavaConverters._
-
import org.apache.hadoop.conf.Configuration
+import org.apache.parquet.io.{ColumnIO, ColumnIOFactory, GroupColumnIO, PrimitiveColumnIO}
import org.apache.parquet.schema._
import org.apache.parquet.schema.LogicalTypeAnnotation._
import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName._
@@ -30,7 +29,6 @@ import org.apache.spark.sql.errors.QueryCompilationErrors
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
-
/**
* This converter class is used to convert Parquet [[MessageType]] to Spark SQL [[StructType]].
*
@@ -60,40 +58,106 @@ class ParquetToSparkSchemaConverter(
/**
* Converts Parquet [[MessageType]] `parquetSchema` to a Spark SQL [[StructType]].
*/
- def convert(parquetSchema: MessageType): StructType = convert(parquetSchema.asGroupType())
+ def convert(parquetSchema: MessageType): StructType = {
+ val column = new ColumnIOFactory().getColumnIO(parquetSchema)
+ val converted = convertInternal(column)
+ converted.sparkType.asInstanceOf[StructType]
+ }
- private def convert(parquetSchema: GroupType): StructType = {
- val fields = parquetSchema.getFields.asScala.map { field =>
- field.getRepetition match {
- case OPTIONAL =>
- StructField(field.getName, convertField(field), nullable = true)
+ /**
+ * Convert `parquetSchema` into a [[ParquetType]] which contains its corresponding Spark
+ * SQL [[StructType]] along with other information such as the maximum repetition and definition
+ * level of each node, column descriptors for the leaf nodes, etc.
+ *
+ * If `sparkReadSchema` is not empty, when deriving Spark SQL type from a Parquet field this will
+ * check if the same field also exists in the schema. If so, it will use the Spark SQL type.
+ * This is necessary since conversion from Parquet to Spark could cause precision loss. For
+ * instance, the Spark read schema may be smallint/tinyint while Parquet only supports int.
+ */
+ def convertTypeInfo(
+ parquetSchema: MessageType,
+ sparkReadSchema: Option[StructType] = None,
+ caseSensitive: Boolean = true): ParquetType = {
+ val column = new ColumnIOFactory().getColumnIO(parquetSchema)
+ convertInternal(column, sparkReadSchema, caseSensitive)
+ }
- case REQUIRED =>
- StructField(field.getName, convertField(field), nullable = false)
+ private def convertInternal(
+ groupColumn: GroupColumnIO,
+ sparkReadSchema: Option[StructType] = None,
+ caseSensitive: Boolean = true): ParquetType = {
+ val converted = (0 until groupColumn.getChildrenCount).map { i =>
+ val field = groupColumn.getChild(i)
+ var fieldReadType = sparkReadSchema.flatMap { schema =>
+ schema.find(f => isSameFieldName(f.name, field.getName, caseSensitive)).map(_.dataType)
+ }
+
+ // if a field is repeated here then it is contained by neither a `LIST`- nor a `MAP`-annotated
+ // group (those cases are handled in `convertGroupField`), e.g.:
+ //
+ // message schema {
+ // repeated int32 int_array;
+ // }
+ // or
+ // message schema {
+ // repeated group struct_array {
+ // optional int32 field;
+ // }
+ // }
+ //
+ // the corresponding Spark read type should be an array and we should pass the element type
+ // to the group or primitive type conversion method.
+ if (field.getType.getRepetition == REPEATED) {
+ fieldReadType = fieldReadType.flatMap {
+ case at: ArrayType => Some(at.elementType)
+ case _ =>
+ throw QueryCompilationErrors.illegalParquetTypeError(groupColumn.toString)
+ }
+ }
+
+ var convertedField = convertField(field, fieldReadType)
+
+ field.getType.getRepetition match {
+ case OPTIONAL | REQUIRED =>
+ val nullable = field.getType.getRepetition == OPTIONAL
+ (StructField(field.getType.getName, convertedField.sparkType, nullable = nullable),
+ convertedField)
case REPEATED =>
// A repeated field that is neither contained by a `LIST`- or `MAP`-annotated group nor
// annotated by `LIST` or `MAP` should be interpreted as a required list of required
// elements where the element type is the type of the field.
- val arrayType = ArrayType(convertField(field), containsNull = false)
- StructField(field.getName, arrayType, nullable = false)
+ val arrayType = ArrayType(convertedField.sparkType, containsNull = false)
+ (StructField(field.getType.getName, arrayType, nullable = false),
+ ParquetType(arrayType, None, convertedField.repetitionLevel - 1,
+ convertedField.definitionLevel - 1, required = true, convertedField.path,
+ Seq(convertedField.copy(required = true))))
}
}
- StructType(fields.toSeq)
+ ParquetType(StructType(converted.map(_._1)), groupColumn, converted.map(_._2))
}
+ private def isSameFieldName(left: String, right: String, caseSensitive: Boolean): Boolean =
+ if (!caseSensitive) left.equalsIgnoreCase(right)
+ else left == right
+
/**
* Converts a Parquet [[Type]] to a Spark SQL [[DataType]].
*/
- def convertField(parquetType: Type): DataType = parquetType match {
- case t: PrimitiveType => convertPrimitiveField(t)
- case t: GroupType => convertGroupField(t.asGroupType())
+ def convertField(
+ field: ColumnIO,
+ sparkReadType: Option[DataType] = None): ParquetType = field match {
+ case primitiveColumn: PrimitiveColumnIO => convertPrimitiveField(primitiveColumn, sparkReadType)
+ case groupColumn: GroupColumnIO => convertGroupField(groupColumn, sparkReadType)
}
- private def convertPrimitiveField(field: PrimitiveType): DataType = {
- val typeName = field.getPrimitiveTypeName
- val typeAnnotation = field.getLogicalTypeAnnotation
+ private def convertPrimitiveField(
+ primitiveColumn: PrimitiveColumnIO,
+ sparkReadType: Option[DataType] = None): ParquetType = {
+ val parquetType = primitiveColumn.getType.asPrimitiveType()
+ val typeAnnotation = primitiveColumn.getType.getLogicalTypeAnnotation
+ val typeName = primitiveColumn.getPrimitive
def typeString =
if (typeAnnotation == null) s"$typeName" else s"$typeName ($typeAnnotation)"
@@ -108,7 +172,7 @@ class ParquetToSparkSchemaConverter(
// specified in field.getDecimalMetadata. This is useful when interpreting decimal types stored
// as binaries with variable lengths.
def makeDecimalType(maxPrecision: Int = -1): DecimalType = {
- val decimalLogicalTypeAnnotation = field.getLogicalTypeAnnotation
+ val decimalLogicalTypeAnnotation = typeAnnotation
.asInstanceOf[DecimalLogicalTypeAnnotation]
val precision = decimalLogicalTypeAnnotation.getPrecision
val scale = decimalLogicalTypeAnnotation.getScale
@@ -120,7 +184,7 @@ class ParquetToSparkSchemaConverter(
DecimalType(precision, scale)
}
- typeName match {
+ val sparkType = sparkReadType.getOrElse(typeName match {
case BOOLEAN => BooleanType
case FLOAT => FloatType
@@ -195,17 +259,23 @@ class ParquetToSparkSchemaConverter(
case FIXED_LEN_BYTE_ARRAY =>
typeAnnotation match {
case _: DecimalLogicalTypeAnnotation =>
- makeDecimalType(Decimal.maxPrecisionForBytes(field.getTypeLength))
+ makeDecimalType(Decimal.maxPrecisionForBytes(parquetType.getTypeLength))
case _: IntervalLogicalTypeAnnotation => typeNotImplemented()
case _ => illegalType()
}
case _ => illegalType()
- }
+ })
+
+ ParquetType(sparkType, primitiveColumn)
}
- private def convertGroupField(field: GroupType): DataType = {
- Option(field.getLogicalTypeAnnotation).fold(convert(field): DataType) {
+ private def convertGroupField(
+ groupColumn: GroupColumnIO,
+ sparkReadType: Option[DataType] = None): ParquetType = {
+ val field = groupColumn.getType.asGroupType()
+ Option(field.getLogicalTypeAnnotation).fold(
+ convertInternal(groupColumn, sparkReadType.map(_.asInstanceOf[StructType]))) {
// A Parquet list is represented as a 3-level structure:
//
// group (LIST) {
@@ -222,17 +292,37 @@ class ParquetToSparkSchemaConverter(
case _: ListLogicalTypeAnnotation =>
ParquetSchemaConverter.checkConversionRequirement(
field.getFieldCount == 1, s"Invalid list type $field")
+ ParquetSchemaConverter.checkConversionRequirement(
+ sparkReadType.forall(_.isInstanceOf[ArrayType]),
+ s"Invalid Spark read type: expected $field to be list type but found $sparkReadType")
- val repeatedType = field.getType(0)
+ val repeated = groupColumn.getChild(0)
+ val repeatedType = repeated.getType
ParquetSchemaConverter.checkConversionRequirement(
repeatedType.isRepetition(REPEATED), s"Invalid list type $field")
+ val sparkReadElementType = sparkReadType.map(_.asInstanceOf[ArrayType].elementType)
if (isElementType(repeatedType, field.getName)) {
- ArrayType(convertField(repeatedType), containsNull = false)
+ var converted = convertField(repeated, sparkReadElementType)
+ val convertedType = sparkReadElementType.getOrElse(converted.sparkType)
+
+ // legacy format such as:
+ // optional group my_list (LIST) {
+ // repeated int32 element;
+ // }
+ // we should mark the primitive field as required
+ if (repeatedType.isPrimitive) converted = converted.copy(required = true)
+
+ ParquetType(ArrayType(convertedType, containsNull = false),
+ groupColumn, Seq(converted))
} else {
- val elementType = repeatedType.asGroupType().getType(0)
+ val element = repeated.asInstanceOf[GroupColumnIO].getChild(0)
+ val elementType = element.getType
val optional = elementType.isRepetition(OPTIONAL)
- ArrayType(convertField(elementType), containsNull = optional)
+ val converted = convertField(element, sparkReadElementType)
+ val convertedType = sparkReadElementType.getOrElse(converted.sparkType)
+ ParquetType(ArrayType(convertedType, containsNull = optional),
+ groupColumn, Seq(converted))
}
// scalastyle:off
@@ -243,20 +333,30 @@ class ParquetToSparkSchemaConverter(
ParquetSchemaConverter.checkConversionRequirement(
field.getFieldCount == 1 && !field.getType(0).isPrimitive,
s"Invalid map type: $field")
+ ParquetSchemaConverter.checkConversionRequirement(
+ sparkReadType.forall(_.isInstanceOf[MapType]),
+ s"Invalid Spark read type: expected $field to be map type but found $sparkReadType")
- val keyValueType = field.getType(0).asGroupType()
+ val keyValue = groupColumn.getChild(0).asInstanceOf[GroupColumnIO]
+ val keyValueType = keyValue.getType.asGroupType()
ParquetSchemaConverter.checkConversionRequirement(
keyValueType.isRepetition(REPEATED) && keyValueType.getFieldCount == 2,
s"Invalid map type: $field")
- val keyType = keyValueType.getType(0)
- val valueType = keyValueType.getType(1)
+ val key = keyValue.getChild(0)
+ val value = keyValue.getChild(1)
+ val sparkReadKeyType = sparkReadType.map(_.asInstanceOf[MapType].keyType)
+ val sparkReadValueType = sparkReadType.map(_.asInstanceOf[MapType].valueType)
+ val valueType = value.getType
val valueOptional = valueType.isRepetition(OPTIONAL)
- MapType(
- convertField(keyType),
- convertField(valueType),
- valueContainsNull = valueOptional)
-
+ val convertedKey = convertField(key, sparkReadKeyType)
+ val convertedValue = convertField(value, sparkReadValueType)
+ val convertedKeyType = sparkReadKeyType.getOrElse(convertedKey.sparkType)
+ val convertedValueType = sparkReadValueType.getOrElse(convertedValue.sparkType)
+ ParquetType(
+ MapType(convertedKeyType, convertedValueType,
+ valueContainsNull = valueOptional),
+ groupColumn, Seq(convertedKey, convertedValue))
case _ =>
throw QueryCompilationErrors.unrecognizedParquetTypeError(field.toString)
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetType.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetType.scala
new file mode 100644
index 000000000000..ddf4a95a337c
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetType.scala
@@ -0,0 +1,68 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.datasources.parquet
+
+import scala.collection.mutable
+
+import org.apache.parquet.column.ColumnDescriptor
+import org.apache.parquet.io.{ColumnIOUtil, GroupColumnIO, PrimitiveColumnIO}
+import org.apache.parquet.schema.Type.Repetition
+
+import org.apache.spark.sql.types.DataType
+
+/**
+ * Rich type information for a Parquet type, together with its corresponding Spark SQL type.
+ */
+case class ParquetType(
+ sparkType: DataType,
+ descriptor: Option[ColumnDescriptor], // only set when this is a primitive type
+ repetitionLevel: Int,
+ definitionLevel: Int,
+ required: Boolean,
+ path: Seq[String],
+ children: Seq[ParquetType]) {
+
+ def isPrimitive: Boolean = descriptor.nonEmpty
+
+ /**
+ * Returns all the leaves (i.e., primitive columns) of this, in depth-first order.
+ */
+ def leaves: Seq[ParquetType] = {
+ val buffer = mutable.ArrayBuffer[ParquetType]()
+ leaves0(buffer)
+ buffer.toSeq
+ }
+
+ private def leaves0(buffer: mutable.ArrayBuffer[ParquetType]): Unit = {
+ if (isPrimitive) {
+ // a primitive column is itself a leaf
+ buffer += this
+ } else {
+ children.foreach(_.leaves0(buffer))
+ }
+ }
+}
+
+object ParquetType {
+ def apply(sparkType: DataType, io: PrimitiveColumnIO): ParquetType = {
+ this(sparkType, Some(io.getColumnDescriptor), ColumnIOUtil.getRepetitionLevel(io),
+ ColumnIOUtil.getDefinitionLevel(io), io.getType.isRepetition(Repetition.REQUIRED),
+ ColumnIOUtil.getFieldPath(io), Seq.empty)
+ }
+
+ def apply(sparkType: DataType, io: GroupColumnIO, children: Seq[ParquetType]): ParquetType = {
+ this(sparkType, None, ColumnIOUtil.getRepetitionLevel(io),
+ ColumnIOUtil.getDefinitionLevel(io), io.getType.isRepetition(Repetition.REQUIRED),
+ ColumnIOUtil.getFieldPath(io), children)
+ }
+}
diff --git a/sql/core/src/test/java/org/apache/parquet/column/page/TestDataPage.java b/sql/core/src/test/java/org/apache/parquet/column/page/TestDataPage.java
new file mode 100644
index 000000000000..da3081919589
--- /dev/null
+++ b/sql/core/src/test/java/org/apache/parquet/column/page/TestDataPage.java
@@ -0,0 +1,44 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.parquet.column.page;
+
+import java.util.Optional;
+
+/**
+ * A hack to create Parquet data pages with customized first row index. We have to put it under
+ * 'org.apache.parquet.column.page' since the constructor of `DataPage` is package-private.
+ */
+public class TestDataPage extends DataPage {
+ private final DataPage wrapped;
+
+ public TestDataPage(DataPage wrapped, long firstRowIndex) {
+ super(wrapped.getCompressedSize(), wrapped.getUncompressedSize(), wrapped.getValueCount(),
+ firstRowIndex);
+ this.wrapped = wrapped;
+ }
+
+ @Override
+ public Optional<Long> getIndexRowCount() {
+ return Optional.empty();
+ }
+
+ @Override
+ public <T> T accept(Visitor<T> visitor) {
+ return wrapped.accept(visitor);
+ }
+}
diff --git a/sql/core/src/test/resources/sql-tests/results/explain-aqe.sql.out b/sql/core/src/test/resources/sql-tests/results/explain-aqe.sql.out
index f5e5b46d29ce..f98fb1eb2a57 100644
--- a/sql/core/src/test/resources/sql-tests/results/explain-aqe.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/explain-aqe.sql.out
@@ -1125,7 +1125,8 @@ struct
-- !query output
== Physical Plan ==
*Filter v#x IN ([a],null)
-+- FileScan parquet default.t[v#x] Batched: false, DataFilters: [v#x IN ([a],null)], Format: Parquet, Location [not included in comparison]/{warehouse_dir}/t], PartitionFilters: [], PushedFilters: [In(v, [[a],null])], ReadSchema: struct<v:array<string>>
++- *ColumnarToRow
+ +- FileScan parquet default.t[v#x] Batched: true, DataFilters: [v#x IN ([a],null)], Format: Parquet, Location [not included in comparison]/{warehouse_dir}/t], PartitionFilters: [], PushedFilters: [In(v, [[a],null])], ReadSchema: struct<v:array<string>>
-- !query
diff --git a/sql/core/src/test/resources/sql-tests/results/explain.sql.out b/sql/core/src/test/resources/sql-tests/results/explain.sql.out
index 4e552d51a395..a563eda1e7b0 100644
--- a/sql/core/src/test/resources/sql-tests/results/explain.sql.out
+++ b/sql/core/src/test/resources/sql-tests/results/explain.sql.out
@@ -1067,7 +1067,8 @@ struct
-- !query output
== Physical Plan ==
*Filter v#x IN ([a],null)
-+- FileScan parquet default.t[v#x] Batched: false, DataFilters: [v#x IN ([a],null)], Format: Parquet, Location [not included in comparison]/{warehouse_dir}/t], PartitionFilters: [], PushedFilters: [In(v, [[a],null])], ReadSchema: struct<v:array<string>>
++- *ColumnarToRow
+ +- FileScan parquet default.t[v#x] Batched: true, DataFilters: [v#x IN ([a],null)], Format: Parquet, Location [not included in comparison]/{warehouse_dir}/t], PartitionFilters: [], PushedFilters: [In(v, [[a],null])], ReadSchema: struct<v:array<string>>
-- !query
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/DataSourceReadBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/DataSourceReadBenchmark.scala
index 0fc43c7052d0..b6623d6f9d8e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/DataSourceReadBenchmark.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/DataSourceReadBenchmark.scala
@@ -66,15 +66,23 @@ object DataSourceReadBenchmark extends SqlBasedBenchmark {
try f finally tableNames.foreach(spark.catalog.dropTempView)
}
- private def prepareTable(dir: File, df: DataFrame, partition: Option[String] = None): Unit = {
+ private def prepareTable(
+ dir: File,
+ df: DataFrame,
+ partition: Option[String] = None,
+ isComplexType: Boolean = false): Unit = {
val testDf = if (partition.isDefined) {
df.write.partitionBy(partition.get)
} else {
df.write
}
- saveAsCsvTable(testDf, dir.getCanonicalPath + "/csv")
- saveAsJsonTable(testDf, dir.getCanonicalPath + "/json")
+ // don't create CSV & JSON tables when benchmarking complex types as they don't support them
+ if (!isComplexType) {
+ saveAsCsvTable(testDf, dir.getCanonicalPath + "/csv")
+ saveAsJsonTable(testDf, dir.getCanonicalPath + "/json")
+ }
+
saveAsParquetTable(testDf, dir.getCanonicalPath + "/parquet")
saveAsOrcTable(testDf, dir.getCanonicalPath + "/orc")
}
@@ -540,6 +548,106 @@ object DataSourceReadBenchmark extends SqlBasedBenchmark {
}
}
+ /**
+ * Similar to [[numericScanBenchmark]] but accessed column is a struct field.
+ */
+ def nestedNumericScanBenchmark(values: Int, dataType: DataType): Unit = {
+ val sqlBenchmark = new Benchmark(
+ s"SQL Single ${dataType.sql} Column Scan in Struct",
+ values,
+ output = output)
+
+ withTempPath { dir =>
+ withTempTable("t1", "parquetTable", "orcTable") {
+ import spark.implicits._
+ spark.range(values).map(_ => Random.nextLong).createOrReplaceTempView("t1")
+
+ prepareTable(dir,
+ spark.sql(s"SELECT named_struct('f', CAST(value as ${dataType.sql})) as col FROM t1"),
+ isComplexType = true)
+
+ sqlBenchmark.addCase("SQL Parquet MR") { _ =>
+ withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "false") {
+ spark.sql("select sum(col.f) from parquetTable").noop()
+ }
+ }
+
+ sqlBenchmark.addCase("SQL Parquet Vectorized (Disabled Nested Column)") { _ =>
+ withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_NESTED_COLUMN_ENABLED.key -> "false") {
+ spark.sql("select sum(col.f) from parquetTable").noop()
+ }
+ }
+
+ sqlBenchmark.addCase("SQL Parquet Vectorized (Enabled Nested Column)") { _ =>
+ withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_NESTED_COLUMN_ENABLED.key -> "true") {
+ spark.sql("select sum(col.f) from parquetTable").noop()
+ }
+ }
+
+ sqlBenchmark.run()
+ }
+ }
+ }
+
+ def nestedColumnScanBenchmark(values: Int): Unit = {
+ val benchmark = new Benchmark(s"SQL Nested Column Scan", values, minNumIters = 10,
+ output = output)
+
+ withTempPath { dir =>
+ withTempTable("t1", "parquetTable", "orcTable") {
+ import spark.implicits._
+ spark.range(values).map(_ => Random.nextLong).map { x =>
+ val arrayOfStructColumn = (0 until 5).map(i => (x + i, s"$x" * 5))
+ val mapOfStructColumn = Map(
+ s"$x" -> (x * 0.1, (x, s"$x" * 100)),
+ (s"$x" * 2) -> (x * 0.2, (x, s"$x" * 200)),
+ (s"$x" * 3) -> (x * 0.3, (x, s"$x" * 300)))
+ (arrayOfStructColumn, mapOfStructColumn)
+ }.toDF("col1", "col2").createOrReplaceTempView("t1")
+
+ prepareTable(dir, spark.sql(s"SELECT * FROM t1"), isComplexType = true)
+
+ benchmark.addCase("SQL ORC MR") { _ =>
+ withSQLConf(SQLConf.ORC_VECTORIZED_READER_ENABLED.key -> "false") {
+ spark.sql("SELECT SUM(SIZE(col1)), SUM(SIZE(col2)) FROM orcTable").noop()
+ }
+ }
+
+ benchmark.addCase("SQL ORC Vectorized (Disabled Nested Column)") { _ =>
+ withSQLConf(SQLConf.ORC_VECTORIZED_READER_NESTED_COLUMN_ENABLED.key -> "false") {
+ spark.sql("SELECT SUM(SIZE(col1)), SUM(SIZE(col2)) FROM orcTable").noop()
+ }
+ }
+
+ benchmark.addCase("SQL ORC Vectorized (Enabled Nested Column)") { _ =>
+ withSQLConf(SQLConf.ORC_VECTORIZED_READER_NESTED_COLUMN_ENABLED.key -> "true") {
+ spark.sql("SELECT SUM(SIZE(col1)), SUM(SIZE(col2)) FROM orcTable").noop()
+ }
+ }
+
+ benchmark.addCase("SQL Parquet MR") { _ =>
+ withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key -> "false") {
+ spark.sql("SELECT SUM(SIZE(col1)), SUM(SIZE(col2)) FROM parquetTable").noop()
+ }
+ }
+
+ benchmark.addCase("SQL Parquet Vectorized (Disabled Nested Column)") { _ =>
+ withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_NESTED_COLUMN_ENABLED.key -> "false") {
+ spark.sql("SELECT SUM(SIZE(col1)), SUM(SIZE(col2)) FROM parquetTable").noop()
+ }
+ }
+
+ benchmark.addCase("SQL Parquet Vectorized (Enabled Nested Column)") { _ =>
+ withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_NESTED_COLUMN_ENABLED.key -> "true") {
+ spark.sql("SELECT SUM(SIZE(col1)), SUM(SIZE(col2)) FROM parquetTable").noop()
+ }
+ }
+
+ benchmark.run()
+ }
+ }
+ }
+
override def runBenchmarkSuite(mainArgs: Array[String]): Unit = {
runBenchmark("SQL Single Numeric Column Scan") {
Seq(ByteType, ShortType, IntegerType, LongType, FloatType, DoubleType).foreach {
@@ -565,5 +673,13 @@ object DataSourceReadBenchmark extends SqlBasedBenchmark {
columnsBenchmark(1024 * 1024 * 1, columnWidth)
}
}
+ runBenchmark("SQL Single Numeric Column Scan in Struct") {
+ Seq(ByteType, ShortType, IntegerType, LongType, FloatType, DoubleType).foreach {
+ dataType => nestedNumericScanBenchmark(1024 * 1024 * 15, dataType)
+ }
+ }
+ runBenchmark("SQL Nested Column Scan") {
+ nestedColumnScanBenchmark(1024 * 1024)
+ }
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileBasedDataSourceTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileBasedDataSourceTest.scala
index c2dc20b0099a..01fac6336155 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileBasedDataSourceTest.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileBasedDataSourceTest.scala
@@ -38,6 +38,8 @@ private[sql] trait FileBasedDataSourceTest extends SQLTestUtils {
protected val dataSourceName: String
// The SQL config key for enabling vectorized reader.
protected val vectorizedReaderEnabledKey: String
+ // The SQL config key for enabling vectorized reader for nested types.
+ protected val vectorizedReaderNestedEnabledKey: String
/**
* Reads data source file from given `path` as `DataFrame` and passes it to given function.
@@ -52,7 +54,8 @@ private[sql] trait FileBasedDataSourceTest extends SQLTestUtils {
f(spark.read.format(dataSourceName).load(path.toString))
}
if (testVectorized) {
- withSQLConf(vectorizedReaderEnabledKey -> "true") {
+ withSQLConf(vectorizedReaderEnabledKey -> "true",
+ vectorizedReaderNestedEnabledKey -> "true") {
f(spark.read.format(dataSourceName).load(path.toString))
}
}
@@ -66,7 +69,8 @@ private[sql] trait FileBasedDataSourceTest extends SQLTestUtils {
(data: Seq[T])
(f: String => Unit): Unit = {
withTempPath { file =>
- spark.createDataFrame(data).write.format(dataSourceName).save(file.getCanonicalPath)
+ spark.createDataFrame(data).coalesce(1)
+ .write.format(dataSourceName).save(file.getCanonicalPath)
f(file.getCanonicalPath)
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcTest.scala
index 4243318ac1dd..0e98e2a2803d 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcTest.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcTest.scala
@@ -56,6 +56,8 @@ abstract class OrcTest extends QueryTest with FileBasedDataSourceTest with Befor
override protected val dataSourceName: String = "orc"
override protected val vectorizedReaderEnabledKey: String =
SQLConf.ORC_VECTORIZED_READER_ENABLED.key
+ override protected val vectorizedReaderNestedEnabledKey: String =
+ SQLConf.ORC_VECTORIZED_READER_NESTED_COLUMN_ENABLED.key
protected override def beforeAll(): Unit = {
super.beforeAll()
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcV1SchemaPruningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcV1SchemaPruningSuite.scala
index 2ce38dae47db..4d33eacecc13 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcV1SchemaPruningSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcV1SchemaPruningSuite.scala
@@ -25,6 +25,8 @@ class OrcV1SchemaPruningSuite extends SchemaPruningSuite {
override protected val dataSourceName: String = "orc"
override protected val vectorizedReaderEnabledKey: String =
SQLConf.ORC_VECTORIZED_READER_ENABLED.key
+ override protected val vectorizedReaderNestedEnabledKey: String =
+ SQLConf.ORC_VECTORIZED_READER_NESTED_COLUMN_ENABLED.key
override protected def sparkConf: SparkConf =
super
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcV2SchemaPruningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcV2SchemaPruningSuite.scala
index 47254f4231d5..107a2b791202 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcV2SchemaPruningSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcV2SchemaPruningSuite.scala
@@ -29,6 +29,8 @@ class OrcV2SchemaPruningSuite extends SchemaPruningSuite with AdaptiveSparkPlanH
override protected val dataSourceName: String = "orc"
override protected val vectorizedReaderEnabledKey: String =
SQLConf.ORC_VECTORIZED_READER_ENABLED.key
+ override protected val vectorizedReaderNestedEnabledKey: String =
+ SQLConf.ORC_VECTORIZED_READER_NESTED_COLUMN_ENABLED.key
override protected def sparkConf: SparkConf =
super
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetColumnIndexSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetColumnIndexSuite.scala
index 5f1c5b5cdb4e..4b1e0445fe6f 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetColumnIndexSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetColumnIndexSuite.scala
@@ -96,10 +96,23 @@ class ParquetColumnIndexSuite extends QueryTest with ParquetTest with SharedSpar
test("SPARK-36123: reading from unaligned pages - test filters with nulls") {
// insert 50 null values in [400, 450) to verify that they are skipped during processing row
// range [500, 1000) against the second page of col_2 [400, 800)
- var df = spark.range(0, 2000).map { i =>
+ val df = spark.range(0, 2000).map { i =>
val strVal = if (i >= 400 && i < 450) null else i + ":" + "o" * (i / 100).toInt
(i, strVal)
}.toDF()
checkUnalignedPages(df)(actions: _*)
}
+
+ test("SPARK-34861: reading unaligned pages - struct type") {
+ val df = (0 until 2000).map(i => Tuple1((i.toLong, i + ":" + "o" * (i / 100)))).toDF("s")
+ checkUnalignedPages(df)(
+ df => df.filter("s._1 = 500"),
+ df => df.filter("s._1 = 500 or s._1 = 1500"),
+ df => df.filter("s._1 = 500 or s._1 = 501 or s._1 = 1500"),
+ df => df.filter("s._1 = 500 or s._1 = 501 or s._1 = 1000 or s._1 = 1500"),
+ // range filter
+ df => df.filter("s._1 >= 500 and s._1 < 1000"),
+ df => df.filter("(s._1 >= 500 and s._1 < 1000) or (s._1 >= 1500 and s._1 < 1600)")
+ )
+ }
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala
index a330b82de2d0..8943b1680cfc 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala
@@ -290,6 +290,332 @@ class ParquetIOSuite extends QueryTest with ParquetTest with SharedSparkSession
}
}
+ test("vectorized reader: array") {
+ val data = Seq(
+ Tuple1(null),
+ Tuple1(Seq()),
+ Tuple1(Seq("a", "b", "c")),
+ Tuple1(Seq(null))
+ )
+
+ withParquetFile(data) { file =>
+ readParquetFile(file) { df =>
+ checkAnswer(df.sort("_1"),
+ Row(null) :: Row(Seq()) :: Row(Seq(null)) :: Row(Seq("a", "b", "c")) :: Nil
+ )
+ }
+ }
+ }
+
+ test("vectorized reader: missing array") {
+ withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_NESTED_COLUMN_ENABLED.key -> "true") {
+ val data = Seq(
+ Tuple1(null),
+ Tuple1(Seq()),
+ Tuple1(Seq("a", "b", "c")),
+ Tuple1(Seq(null))
+ )
+
+ val readSchema = new StructType().add("_2", new ArrayType(
+ new StructType().add("a", LongType, nullable = true),
+ containsNull = true)
+ )
+
+ withParquetFile(data) { file =>
+ checkAnswer(spark.read.schema(readSchema).parquet(file),
+ Row(null) :: Row(null) :: Row(null) :: Row(null) :: Nil
+ )
+ }
+ }
+ }
+
+ test("vectorized reader: array of array") {
+ val data = Seq(
+ Tuple1(Seq(Seq(0, 1), Seq(2, 3))),
+ Tuple1(Seq(Seq(4, 5), Seq(6, 7)))
+ )
+
+ withParquetFile(data) { file =>
+ readParquetFile(file) { df =>
+ checkAnswer(df.sort("_1"),
+ Row(Seq(Seq(0, 1), Seq(2, 3))) :: Row(Seq(Seq(4, 5), Seq(6, 7))) :: Nil
+ )
+ }
+ }
+ }
+
+ test("vectorized reader: array of struct") {
+ val data = Seq(
+ Tuple1(Tuple2("a", null)),
+ Tuple1(null),
+ Tuple1(Tuple2(null, null)),
+ Tuple1(Tuple2(null, Seq("b", "c"))),
+ Tuple1(Tuple2("d", Seq("e", "f"))),
+ Tuple1(null)
+ )
+
+ withParquetFile(data) { file =>
+ readParquetFile(file) { df =>
+ checkAnswer(df,
+ Row(Row("a", null)) :: Row(null) :: Row(Row(null, null)) ::
+ Row(Row(null, Seq("b", "c"))) :: Row(Row("d", Seq("e", "f"))) :: Row(null) :: Nil
+ )
+ }
+ }
+ }
+
+ test("vectorized reader: array of nested struct") {
+ val data = Seq(
+ Tuple1(Tuple2("a", null)),
+ Tuple1(Tuple2("b", Seq(Tuple2("c", "d")))),
+ Tuple1(null),
+ Tuple1(Tuple2("e", Seq(Tuple2("f", null), Tuple2(null, "g")))),
+ Tuple1(Tuple2(null, null)),
+ Tuple1(Tuple2(null, Seq(null))),
+ Tuple1(Tuple2(null, Seq(Tuple2(null, null), Tuple2("h", null), null))),
+ Tuple1(Tuple2("i", Seq())),
+ Tuple1(null)
+ )
+
+ withParquetFile(data) { file =>
+ readParquetFile(file) { df =>
+ checkAnswer(df,
+ Row(Row("a", null)) ::
+ Row(Row("b", Seq(Row("c", "d")))) ::
+ Row(null) ::
+ Row(Row("e", Seq(Row("f", null), Row(null, "g")))) ::
+ Row(Row(null, null)) ::
+ Row(Row(null, Seq(null))) ::
+ Row(Row(null, Seq(Row(null, null), Row("h", null), null))) ::
+ Row(Row("i", Seq())) ::
+ Row(null) ::
+ Nil)
+ }
+ }
+ }
+
+ test("vectorized reader: required array with required elements") {
+ Seq(true, false).foreach { dictionaryEnabled =>
+ def makeRawParquetFile(path: Path, expected: Seq[Seq[String]]): Unit = {
+ val schemaStr =
+ """message spark_schema {
+ | required group _1 (LIST) {
+ | repeated group list {
+ | required binary element (UTF8);
+ | }
+ | }
+ |}
+ """.stripMargin
+ val schema = MessageTypeParser.parseMessageType(schemaStr)
+ val writer = createParquetWriter(schema, path, dictionaryEnabled)
+
+ val factory = new SimpleGroupFactory(schema)
+ expected.foreach { values =>
+ val group = factory.newGroup()
+ val list = group.addGroup(0)
+ values.foreach { value =>
+ list.addGroup(0).append("element", value)
+ }
+ writer.write(group)
+ }
+ writer.close()
+ }
+
+ // write the following into the Parquet file:
+ // 0: [ "a", "b" ]
+ // 1: [ ]
+ // 2: [ "c", "d" ]
+ withTempDir { dir =>
+ val path = new Path(dir.toURI.toString, "part-r-0.parquet")
+ val expected = Seq(Seq("a", "b"), Seq(), Seq("c", "d"))
+ makeRawParquetFile(path, expected)
+ readParquetFile(path.toString) { df => checkAnswer(df, expected.map(Row(_))) }
+ }
+ }
+ }
+
+ test("vectorized reader: optional array with required elements") {
+ Seq(true, false).foreach { dictionaryEnabled =>
+ def makeRawParquetFile(path: Path, expected: Seq[Seq[String]]): Unit = {
+ val schemaStr =
+ """message spark_schema {
+ | optional group _1 (LIST) {
+ | repeated group list {
+ | required binary element (UTF8);
+ | }
+ | }
+ |}
+ """.stripMargin
+ val schema = MessageTypeParser.parseMessageType(schemaStr)
+ val writer = createParquetWriter(schema, path, dictionaryEnabled)
+
+ val factory = new SimpleGroupFactory(schema)
+ expected.foreach { values =>
+ val group = factory.newGroup()
+ if (values != null) {
+ val list = group.addGroup(0)
+ values.foreach { value =>
+ list.addGroup(0).append("element", value)
+ }
+ }
+ writer.write(group)
+ }
+ writer.close()
+ }
+
+ // write the following into the Parquet file:
+ // 0: [ "a", "b" ]
+ // 1: null
+ // 2: [ "c", "d" ]
+ // 3: [ ]
+ // 4: [ "e", "f" ]
+ withTempDir { dir =>
+ val path = new Path(dir.toURI.toString, "part-r-0.parquet")
+ val expected = Seq(Seq("a", "b"), null, Seq("c", "d"), Seq(), Seq("e", "f"))
+ makeRawParquetFile(path, expected)
+ readParquetFile(path.toString) { df => checkAnswer(df, expected.map(Row(_))) }
+ }
+ }
+ }
+
+ test("vectorized reader: required array with optional elements") {
+ Seq(true, false).foreach { dictionaryEnabled =>
+ def makeRawParquetFile(path: Path, expected: Seq[Seq[String]]): Unit = {
+ val schemaStr =
+ """message spark_schema {
+ | required group _1 (LIST) {
+ | repeated group list {
+ | optional binary element (UTF8);
+ | }
+ | }
+ |}
+ """.stripMargin
+ val schema = MessageTypeParser.parseMessageType(schemaStr)
+ val writer = createParquetWriter(schema, path, dictionaryEnabled)
+
+ val factory = new SimpleGroupFactory(schema)
+ expected.foreach { values =>
+ val group = factory.newGroup()
+ if (values != null) {
+ val list = group.addGroup(0)
+ values.foreach { value =>
+ val group = list.addGroup(0)
+ if (value != null) group.append("element", value)
+ }
+ }
+ writer.write(group)
+ }
+ writer.close()
+ }
+
+ // write the following into the Parquet file:
+ // 0: [ "a", null ]
+ // 1: [ ]
+ // 2: [ null, "b" ]
+ withTempDir { dir =>
+ val path = new Path(dir.toURI.toString, "part-r-0.parquet")
+ val expected = Seq(Seq("a", null), Seq(), Seq(null, "b"))
+ makeRawParquetFile(path, expected)
+ readParquetFile(path.toString) { df => checkAnswer(df, expected.map(Row(_))) }
+ }
+ }
+ }
+
+ test("vectorized reader: required array with legacy format") {
+ Seq(true, false).foreach { dictionaryEnabled =>
+ def makeRawParquetFile(path: Path, expected: Seq[Seq[String]]): Unit = {
+ val schemaStr =
+ """message spark_schema {
+ | repeated binary element (UTF8);
+ |}
+ """.stripMargin
+ val schema = MessageTypeParser.parseMessageType(schemaStr)
+ val writer = createParquetWriter(schema, path, dictionaryEnabled)
+
+ val factory = new SimpleGroupFactory(schema)
+ expected.foreach { values =>
+ val group = factory.newGroup()
+ values.foreach(group.append("element", _))
+ writer.write(group)
+ }
+ writer.close()
+ }
+
+ // write the following into the Parquet file:
+ // 0: [ "a", "b" ]
+ // 1: [ ]
+ // 2: [ "c", "d" ]
+ withTempDir { dir =>
+ val path = new Path(dir.toURI.toString, "part-r-0.parquet")
+ val expected = Seq(Seq("a", "b"), Seq(), Seq("c", "d"))
+ makeRawParquetFile(path, expected)
+ readParquetFile(path.toString) { df => checkAnswer(df, expected.map(Row(_))) }
+ }
+ }
+ }
+
+ test("vectorized reader: struct") {
+ val data = Seq(
+ Tuple1(null),
+ Tuple1((1, "a")),
+ Tuple1((2, null)),
+ Tuple1((3, "b")),
+ Tuple1(null)
+ )
+
+ withParquetFile(data) { file =>
+ readParquetFile(file) { df =>
+ checkAnswer(df.sort("_1"),
+ Row(null) :: Row(null) :: Row(Row(1, "a")) :: Row(Row(2, null)) :: Row(Row(3, "b")) :: Nil
+ )
+ }
+ }
+ }
+
+ test("vectorized reader: missing all struct fields") {
+ withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_NESTED_COLUMN_ENABLED.key -> "true") {
+ val data = Seq(
+ Tuple1((1, "a")),
+ Tuple1((2, null)),
+ Tuple1(null)
+ )
+
+ val readSchema = new StructType().add("_1",
+ new StructType()
+ .add("_3", IntegerType, nullable = true)
+ .add("_4", LongType, nullable = true),
+ nullable = true)
+
+ withParquetFile(data) { file =>
+ checkAnswer(spark.read.schema(readSchema).parquet(file),
+ Row(null) :: Row(null) :: Row(null) :: Nil
+ )
+ }
+ }
+ }
+
+ test("vectorized reader: missing some struct fields") {
+ withSQLConf(SQLConf.PARQUET_VECTORIZED_READER_NESTED_COLUMN_ENABLED.key -> "true") {
+ val data = Seq(
+ Tuple1((1, "a")),
+ Tuple1((2, null)),
+ Tuple1(null)
+ )
+
+ val readSchema = new StructType().add("_1",
+ new StructType()
+ .add("_1", IntegerType, nullable = true)
+ .add("_3", LongType, nullable = true),
+ nullable = true)
+
+ withParquetFile(data) { file =>
+ checkAnswer(spark.read.schema(readSchema).parquet(file),
+ Row(null) :: Row(Row(1, null)) :: Row(Row(2, null)) :: Nil
+ )
+ }
+ }
+ }
+
test("SPARK-34817: Support for unsigned Parquet logical types") {
val parquetSchema = MessageTypeParser.parseMessageType(
"""message root {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaPruningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaPruningSuite.scala
index cab93bd96fff..6a93b72472c7 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaPruningSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaPruningSuite.scala
@@ -31,6 +31,8 @@ abstract class ParquetSchemaPruningSuite extends SchemaPruningSuite with Adaptiv
override protected val dataSourceName: String = "parquet"
override protected val vectorizedReaderEnabledKey: String =
SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key
+ override protected val vectorizedReaderNestedEnabledKey: String =
+ SQLConf.PARQUET_VECTORIZED_READER_NESTED_COLUMN_ENABLED.key
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTest.scala
index 7a4a382f7f5c..16eb854cf250 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTest.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTest.scala
@@ -49,6 +49,8 @@ private[sql] trait ParquetTest extends FileBasedDataSourceTest {
override protected val dataSourceName: String = "parquet"
override protected val vectorizedReaderEnabledKey: String =
SQLConf.PARQUET_VECTORIZED_READER_ENABLED.key
+ override protected val vectorizedReaderNestedEnabledKey: String =
+ SQLConf.PARQUET_VECTORIZED_READER_NESTED_COLUMN_ENABLED.key
/**
* Reads the parquet file at `path`
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetVectorizedSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetVectorizedSuite.scala
new file mode 100644
index 000000000000..832df35b5b1e
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetVectorizedSuite.scala
@@ -0,0 +1,751 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.datasources.parquet
+
+import java.util.{Optional, PrimitiveIterator}
+
+import scala.collection.mutable.ArrayBuffer
+import scala.language.implicitConversions
+
+import org.apache.parquet.column.{ColumnDescriptor, ParquetProperties}
+import org.apache.parquet.column.impl.ColumnWriteStoreV1
+import org.apache.parquet.column.page._
+import org.apache.parquet.column.page.mem.MemPageStore
+import org.apache.parquet.io.ParquetDecodingException
+import org.apache.parquet.io.api.Binary
+import org.apache.parquet.schema.{MessageType, MessageTypeParser}
+import org.apache.parquet.schema.PrimitiveType.PrimitiveTypeName
+
+import org.apache.spark.memory.MemoryMode
+import org.apache.spark.sql.{QueryTest, Row}
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.RowOrdering
+import org.apache.spark.sql.catalyst.util.DateTimeUtils
+import org.apache.spark.sql.execution.datasources.parquet.SpecificParquetRecordReaderBase.ParquetRowGroupReader
+import org.apache.spark.sql.execution.vectorized.ColumnVectorUtils
+import org.apache.spark.sql.test.SharedSparkSession
+import org.apache.spark.sql.types._
+
+class ParquetVectorizedSuite extends QueryTest with ParquetTest with SharedSparkSession {
+ private val VALUES: Seq[String] = ('a' to 'z').map(_.toString)
+ private val NUM_VALUES: Int = VALUES.length
+ private val BATCH_SIZE_CONFIGS: Seq[Int] = Seq(1, 3, 5, 7, 10, 20, 40)
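+ // Each entry in PAGE_SIZE_CONFIGS splits the 26 VALUES into Parquet pages of the given
+ // sizes (each configuration sums to NUM_VALUES).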
+ private val PAGE_SIZE_CONFIGS: Seq[Seq[Int]] = Seq(Seq(6, 6, 7, 7), Seq(4, 9, 4, 9))
+
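+ // Implicitly maps integer indexes to the corresponding letters, e.g. Seq(0, 1, 2) becomes
+ // Seq("a", "b", "c"), so expected values can be written as ranges like `0 to 9`.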
+ implicit def toStrings(ints: Seq[Int]): Seq[String] = ints.map(i => ('a' + i).toChar.toString)
+
+ test("primitive type - no column index") {
+ BATCH_SIZE_CONFIGS.foreach { batchSize =>
+ PAGE_SIZE_CONFIGS.foreach { pageSizes =>
+ Seq(true, false).foreach { dictionaryEnabled =>
+ testPrimitiveString(None, None, pageSizes, VALUES, batchSize,
+ dictionaryEnabled = dictionaryEnabled)
+ }
+ }
+ }
+ }
+
+ test("primitive type - column index with ranges") {
+ BATCH_SIZE_CONFIGS.foreach { batchSize =>
+ PAGE_SIZE_CONFIGS.foreach { pageSizes =>
+ Seq(true, false).foreach { dictionaryEnabled =>
+ var ranges = Seq((0L, 9L))
+ testPrimitiveString(None, Some(ranges), pageSizes, 0 to 9, batchSize,
+ dictionaryEnabled = dictionaryEnabled)
+
+ ranges = Seq((30, 50))
+ testPrimitiveString(None, Some(ranges), pageSizes, Seq.empty, batchSize,
+ dictionaryEnabled = dictionaryEnabled)
+
+ ranges = Seq((15, 25))
+ testPrimitiveString(None, Some(ranges), pageSizes, 15 to 19, batchSize,
+ dictionaryEnabled = dictionaryEnabled)
+
+ ranges = Seq((19, 20))
+ testPrimitiveString(None, Some(ranges), pageSizes, 19 to 20, batchSize,
+ dictionaryEnabled = dictionaryEnabled)
+
+ ranges = Seq((0, 3), (5, 7), (15, 18))
+ testPrimitiveString(None, Some(ranges), pageSizes,
+ toStrings(Seq(0, 1, 2, 3, 5, 6, 7, 15, 16, 17, 18)),
+ batchSize, dictionaryEnabled = dictionaryEnabled)
+ }
+ }
+ }
+ }
+
+ test("primitive type - column index with ranges and nulls") {
+ BATCH_SIZE_CONFIGS.foreach { batchSize =>
+ PAGE_SIZE_CONFIGS.foreach { pageSizes =>
+ Seq(true, false).foreach { dictionaryEnabled =>
+ val valuesWithNulls = VALUES.zipWithIndex.map {
+ case (v, i) => if (i % 2 == 0) null else v
+ }
+ testPrimitiveString(None, None, pageSizes, valuesWithNulls, batchSize, valuesWithNulls,
+ dictionaryEnabled)
+
+ val ranges = Seq((5L, 7L))
+ testPrimitiveString(None, Some(ranges), pageSizes, Seq("f", null, "h"),
+ batchSize, valuesWithNulls, dictionaryEnabled)
+ }
+ }
+ }
+ }
+
+ test("primitive type - column index with ranges and first row indexes") {
+ BATCH_SIZE_CONFIGS.foreach { batchSize =>
+ Seq(true, false).foreach { dictionaryEnabled =>
+ // Single page
+ val firstRowIndex = 10
+ var ranges = Seq((0L, 9L))
+ testPrimitiveString(Some(Seq(firstRowIndex)), Some(ranges), Seq(VALUES.length),
+ Seq.empty, batchSize, dictionaryEnabled = dictionaryEnabled)
+
+ ranges = Seq((15, 25))
+ testPrimitiveString(Some(Seq(firstRowIndex)), Some(ranges), Seq(VALUES.length),
+ 5 to 15, batchSize, dictionaryEnabled = dictionaryEnabled)
+
+ ranges = Seq((15, 35))
+ testPrimitiveString(Some(Seq(firstRowIndex)), Some(ranges), Seq(VALUES.length),
+ 5 to 19, batchSize, dictionaryEnabled = dictionaryEnabled)
+
+ ranges = Seq((15, 39))
+ testPrimitiveString(Some(Seq(firstRowIndex)), Some(ranges), Seq(VALUES.length),
+ 5 to 19, batchSize, dictionaryEnabled = dictionaryEnabled)
+
+ // Row indexes: [ [10, 16), [20, 26), [30, 37), [40, 47) ]
+ // Values: [ [0, 6), [6, 12), [12, 19), [19, 26) ]
+ var pageSizes = Seq(6, 6, 7, 7)
+ var firstRowIndexes = Seq(10L, 20, 30, 40)
+
+ ranges = Seq((0L, 9L))
+ testPrimitiveString(Some(firstRowIndexes), Some(ranges), pageSizes,
+ Seq.empty, batchSize, dictionaryEnabled = dictionaryEnabled)
+
+ ranges = Seq((15, 25))
+ testPrimitiveString(Some(firstRowIndexes), Some(ranges), pageSizes,
+ 5 to 9, batchSize, dictionaryEnabled = dictionaryEnabled)
+
+ ranges = Seq((15, 35))
+ testPrimitiveString(Some(firstRowIndexes), Some(ranges), pageSizes,
+ 5 to 14, batchSize, dictionaryEnabled = dictionaryEnabled)
+
+ ranges = Seq((15, 60))
+ testPrimitiveString(Some(firstRowIndexes), Some(ranges), pageSizes,
+ 5 to 19, batchSize, dictionaryEnabled = dictionaryEnabled)
+
+ ranges = Seq((12, 22), (28, 38))
+ testPrimitiveString(Some(firstRowIndexes), Some(ranges), pageSizes,
+ toStrings(Seq(2, 3, 4, 5, 6, 7, 8, 12, 13, 14, 15, 16, 17, 18)), batchSize,
+ dictionaryEnabled = dictionaryEnabled)
+
+ // Row indexes: [ [10, 11), [40, 52), [100, 112), [200, 201) ]
+ // Values: [ [0, 1), [1, 13), [13, 25), [25, 26) ]
+ pageSizes = Seq(1, 12, 12, 1)
+ firstRowIndexes = Seq(10L, 40, 100, 200)
+ ranges = Seq((0L, 9L))
+ testPrimitiveString(Some(firstRowIndexes), Some(ranges), pageSizes,
+ Seq.empty, batchSize, dictionaryEnabled = dictionaryEnabled)
+
+ ranges = Seq((300, 350))
+ testPrimitiveString(Some(firstRowIndexes), Some(ranges), pageSizes,
+ Seq.empty, batchSize, dictionaryEnabled = dictionaryEnabled)
+
+ ranges = Seq((50, 80))
+ testPrimitiveString(Some(firstRowIndexes), Some(ranges), pageSizes,
+ (11 to 12), batchSize, dictionaryEnabled = dictionaryEnabled)
+
+ ranges = Seq((0, 150))
+ testPrimitiveString(Some(firstRowIndexes), Some(ranges), pageSizes,
+ 0 to 24, batchSize, dictionaryEnabled = dictionaryEnabled)
+
+ // with nulls
+ val valuesWithNulls = VALUES.zipWithIndex.map {
+ case (v, i) => if (i % 2 == 0) null else v
+ }
+ ranges = Seq((20, 45)) // select values in [1, 5]
+ testPrimitiveString(Some(firstRowIndexes), Some(ranges), pageSizes,
+ Seq("b", null, "d", null, "f"), batchSize, valuesWithNulls,
+ dictionaryEnabled = dictionaryEnabled)
+
+ ranges = Seq((8, 12), (80, 104))
+ testPrimitiveString(Some(firstRowIndexes), Some(ranges), pageSizes,
+ Seq(null, "n", null, "p", null, "r"), batchSize, valuesWithNulls,
+ dictionaryEnabled = dictionaryEnabled)
+ }
+ }
+ }
+
+ test("nested type - single page, no column index") {
+ (1 to 4).foreach { batchSize =>
+ Seq(true, false).foreach { dictionaryEnabled =>
+ testNestedStringArrayOneLevel(None, None, Seq(4),
+ Seq(Seq("a", "b", "c", "d")),
+ Seq(0, 1, 1, 1), Seq(3, 3, 3, 3), Seq("a", "b", "c", "d"), batchSize,
+ dictionaryEnabled = dictionaryEnabled)
+
+ testNestedStringArrayOneLevel(None, None, Seq(4),
+ Seq(Seq("a", "b"), Seq("c", "d")),
+ Seq(0, 1, 0, 1), Seq(3, 3, 3, 3), Seq("a", "b", "c", "d"), batchSize,
+ dictionaryEnabled = dictionaryEnabled)
+
+ testNestedStringArrayOneLevel(None, None, Seq(4),
+ Seq(Seq("a"), Seq("b"), Seq("c"), Seq("d")),
+ Seq(0, 0, 0, 0), Seq(3, 3, 3, 3), Seq("a", "b", "c", "d"), batchSize,
+ dictionaryEnabled = dictionaryEnabled)
+
+ testNestedStringArrayOneLevel(None, None, Seq(4),
+ Seq(Seq("a"), Seq(null), Seq("c"), Seq(null)),
+ Seq(0, 0, 0, 0), Seq(3, 2, 3, 2), Seq("a", null, "c", null), batchSize,
+ dictionaryEnabled = dictionaryEnabled)
+
+ testNestedStringArrayOneLevel(None, None, Seq(4),
+ Seq(Seq("a"), Seq(null, null, null)),
+ Seq(0, 0, 1, 1), Seq(3, 2, 2, 2), Seq("a", null, null, null), batchSize,
+ dictionaryEnabled = dictionaryEnabled)
+
+ testNestedStringArrayOneLevel(None, None, Seq(6),
+ Seq(Seq("a"), Seq(null, null, null), null, Seq()),
+ Seq(0, 0, 1, 1, 0, 0), Seq(3, 2, 2, 2, 0, 1), Seq("a", null, null, null, null, null),
+ batchSize, dictionaryEnabled = dictionaryEnabled)
+
+ testNestedStringArrayOneLevel(None, None, Seq(8),
+ Seq(Seq("a"), Seq(), Seq(), null, Seq("b", null, "c"), null),
+ Seq(0, 0, 0, 0, 0, 1, 1, 0), Seq(3, 1, 1, 0, 3, 2, 3, 0),
+ Seq("a", null, null, null, "b", null, "c", null), batchSize,
+ dictionaryEnabled = dictionaryEnabled)
+ }
+ }
+ }
+
+ test("nested type - multiple page, no column index") {
+ BATCH_SIZE_CONFIGS.foreach { batchSize =>
+ Seq(Seq(2, 3, 2, 3)).foreach { pageSizes =>
+ Seq(true, false).foreach { dictionaryEnabled =>
+ testNestedStringArrayOneLevel(None, None, pageSizes,
+ Seq(Seq("a"), Seq(), Seq("b", null, "c"), Seq("d", "e"), Seq(null), Seq(), null),
+ Seq(0, 0, 0, 1, 1, 0, 1, 0, 0, 0), Seq(3, 1, 3, 2, 3, 3, 3, 2, 1, 0),
+ Seq("a", null, "b", null, "c", "d", "e", null, null, null), batchSize,
+ dictionaryEnabled = dictionaryEnabled)
+ }
+ }
+ }
+ }
+
+ test("nested type - multiple page, no column index, batch span multiple pages") {
+ (1 to 6).foreach { batchSize =>
+ Seq(true, false).foreach { dictionaryEnabled =>
+ // a list across multiple pages
+ testNestedStringArrayOneLevel(None, None, Seq(1, 5),
+ Seq(Seq("a"), Seq("b", "c", "d", "e", "f")),
+ Seq(0, 0, 1, 1, 1, 1), Seq.fill(6)(3), Seq("a", "b", "c", "d", "e", "f"), batchSize,
+ dictionaryEnabled = dictionaryEnabled)
+
+ testNestedStringArrayOneLevel(None, None, Seq(1, 3, 2),
+ Seq(Seq("a"), Seq("b", "c", "d"), Seq("e", "f")),
+ Seq(0, 0, 1, 1, 0, 1), Seq.fill(6)(3), Seq("a", "b", "c", "d", "e", "f"), batchSize,
+ dictionaryEnabled = dictionaryEnabled)
+
+ testNestedStringArrayOneLevel(None, None, Seq(2, 2, 2),
+ Seq(Seq("a", "b"), Seq("c", "d"), Seq("e", "f")),
+ Seq(0, 1, 0, 1, 0, 1), Seq.fill(6)(3), Seq("a", "b", "c", "d", "e", "f"), batchSize,
+ dictionaryEnabled = dictionaryEnabled)
+ }
+ }
+ }
+
+ test("nested type - RLE encoding") {
+ (1 to 8).foreach { batchSize =>
+ Seq(Seq(26), Seq(4, 3, 11, 4, 4), Seq(18, 8)).foreach { pageSizes =>
+ Seq(true, false).foreach { dictionaryEnabled =>
+ testNestedStringArrayOneLevel(None, None, pageSizes,
+ (0 to 6).map(i => Seq(('a' + i).toChar.toString)) ++
+ Seq((7 to 17).map(i => ('a' + i).toChar.toString)) ++
+ (18 to 25).map(i => Seq(('a' + i).toChar.toString)),
+ Seq.fill(8)(0) ++ Seq.fill(10)(1) ++ Seq.fill(8)(0), Seq.fill(26)(3),
+ batchSize = batchSize, dictionaryEnabled = dictionaryEnabled)
+ }
+ }
+ }
+ }
+
+ test("nested type - column index with ranges") {
+ (1 to 8).foreach { batchSize =>
+ Seq(Seq(8), Seq(6, 2), Seq(1, 5, 2)).foreach { pageSizes =>
+ Seq(true, false).foreach { dictionaryEnabled =>
+ var ranges = Seq((1L, 2L))
+ testNestedStringArrayOneLevel(None, Some(ranges), pageSizes,
+ Seq(Seq("b", "c", "d", "e", "f"), Seq("g", "h")),
+ Seq(0, 0, 1, 1, 1, 1, 0, 1), Seq.fill(8)(3),
+ Seq("a", "b", "c", "d", "e", "f", "g", "h"),
+ batchSize, dictionaryEnabled = dictionaryEnabled)
+
+ ranges = Seq((3L, 5L))
+ testNestedStringArrayOneLevel(None, Some(ranges), pageSizes,
+ Seq(),
+ Seq(0, 0, 1, 1, 1, 1, 0, 1), Seq.fill(8)(3),
+ Seq("a", "b", "c", "d", "e", "f", "g", "h"),
+ batchSize, dictionaryEnabled = dictionaryEnabled)
+
+ ranges = Seq((0L, 0L))
+ testNestedStringArrayOneLevel(None, Some(ranges), pageSizes,
+ Seq(Seq("a")),
+ Seq(0, 0, 1, 1, 1, 1, 0, 1), Seq.fill(8)(3),
+ Seq("a", "b", "c", "d", "e", "f", "g", "h"),
+ batchSize, dictionaryEnabled = dictionaryEnabled)
+ }
+ }
+ }
+ }
+
+ test("nested type - column index with ranges and RLE encoding") {
+ BATCH_SIZE_CONFIGS.foreach { batchSize =>
+ Seq(Seq(26), Seq(4, 3, 11, 4, 4), Seq(18, 8)).foreach { pageSizes =>
+ Seq(true, false).foreach { dictionaryEnabled =>
+ var ranges = Seq((0L, 2L))
+ testNestedStringArrayOneLevel(None, Some(ranges), pageSizes,
+ Seq(Seq("a"), Seq("b"), Seq("c")),
+ Seq.fill(8)(0) ++ Seq.fill(10)(1) ++ Seq.fill(8)(0), Seq.fill(26)(3),
+ batchSize = batchSize, dictionaryEnabled = dictionaryEnabled)
+
+ ranges = Seq((4L, 6L))
+ testNestedStringArrayOneLevel(None, Some(ranges), pageSizes,
+ Seq(Seq("e"), Seq("f"), Seq("g")),
+ Seq.fill(8)(0) ++ Seq.fill(10)(1) ++ Seq.fill(8)(0), Seq.fill(26)(3),
+ batchSize = batchSize, dictionaryEnabled = dictionaryEnabled)
+
+ ranges = Seq((6L, 9L))
+ testNestedStringArrayOneLevel(None, Some(ranges), pageSizes,
+ Seq(Seq("g")) ++ Seq((7 to 17).map(i => ('a' + i).toChar.toString)) ++
+ Seq(Seq("s"), Seq("t")),
+ Seq.fill(8)(0) ++ Seq.fill(10)(1) ++ Seq.fill(8)(0), Seq.fill(26)(3),
+ batchSize = batchSize, dictionaryEnabled = dictionaryEnabled)
+
+ ranges = Seq((4L, 6L), (14L, 20L))
+ testNestedStringArrayOneLevel(None, Some(ranges), pageSizes,
+ Seq(Seq("e"), Seq("f"), Seq("g"), Seq("y"), Seq("z")),
+ Seq.fill(8)(0) ++ Seq.fill(10)(1) ++ Seq.fill(8)(0), Seq.fill(26)(3),
+ batchSize = batchSize, dictionaryEnabled = dictionaryEnabled)
+ }
+ }
+ }
+ }
+
+ test("nested type - column index with ranges and nulls") {
+ BATCH_SIZE_CONFIGS.foreach { batchSize =>
+ Seq(Seq(16), Seq(8, 8), Seq(4, 4, 4, 4), Seq(2, 6, 4, 4)).foreach { pageSizes =>
+ Seq(true, false).foreach { dictionaryEnabled =>
+ testNestedStringArrayOneLevel(None, None, pageSizes,
+ Seq(Seq("a", null), Seq("c", "d"), Seq(), Seq("f", null, "h"),
+ Seq("i", "j", "k", null), Seq(), null, null, Seq()),
+ Seq(0, 1, 0, 1, 0, 0, 1, 1, 0, 1, 1, 1, 0, 0, 0, 0),
+ Seq(3, 2, 3, 3, 1, 3, 2, 3, 3, 3, 3, 2, 1, 0, 0, 1),
+ (0 to 15),
+ batchSize = batchSize, dictionaryEnabled = dictionaryEnabled)
+
+ var ranges = Seq((0L, 15L))
+ testNestedStringArrayOneLevel(None, Some(ranges), pageSizes,
+ Seq(Seq("a", null), Seq("c", "d"), Seq(), Seq("f", null, "h"),
+ Seq("i", "j", "k", null), Seq(), null, null, Seq()),
+ Seq(0, 1, 0, 1, 0, 0, 1, 1, 0, 1, 1, 1, 0, 0, 0, 0),
+ Seq(3, 2, 3, 3, 1, 3, 2, 3, 3, 3, 3, 2, 1, 0, 0, 1),
+ (0 to 15),
+ batchSize = batchSize, dictionaryEnabled = dictionaryEnabled)
+
+ ranges = Seq((0L, 2L))
+ testNestedStringArrayOneLevel(None, Some(ranges), pageSizes,
+ Seq(Seq("a", null), Seq("c", "d"), Seq()),
+ Seq(0, 1, 0, 1, 0, 0, 1, 1, 0, 1, 1, 1, 0, 0, 0, 0),
+ Seq(3, 2, 3, 3, 1, 3, 2, 3, 3, 3, 3, 2, 1, 0, 0, 1),
+ (0 to 15),
+ batchSize = batchSize, dictionaryEnabled = dictionaryEnabled)
+
+ ranges = Seq((3L, 7L))
+ testNestedStringArrayOneLevel(None, Some(ranges), pageSizes,
+ Seq(Seq("f", null, "h"), Seq("i", "j", "k", null), Seq(), null, null),
+ Seq(0, 1, 0, 1, 0, 0, 1, 1, 0, 1, 1, 1, 0, 0, 0, 0),
+ Seq(3, 2, 3, 3, 1, 3, 2, 3, 3, 3, 3, 2, 1, 0, 0, 1),
+ (0 to 15),
+ batchSize = batchSize, dictionaryEnabled = dictionaryEnabled)
+
+ ranges = Seq((5, 12L))
+ testNestedStringArrayOneLevel(None, Some(ranges), pageSizes,
+ Seq(Seq(), null, null, Seq()),
+ Seq(0, 1, 0, 1, 0, 0, 1, 1, 0, 1, 1, 1, 0, 0, 0, 0),
+ Seq(3, 2, 3, 3, 1, 3, 2, 3, 3, 3, 3, 2, 1, 0, 0, 1),
+ (0 to 15),
+ batchSize = batchSize, dictionaryEnabled = dictionaryEnabled)
+
+ ranges = Seq((0L, 0L), (2, 3), (5, 7), (8, 10))
+ testNestedStringArrayOneLevel(None, Some(ranges), pageSizes,
+ Seq(Seq("a", null), Seq(), Seq("f", null, "h"), Seq(), null, null, Seq()),
+ Seq(0, 1, 0, 1, 0, 0, 1, 1, 0, 1, 1, 1, 0, 0, 0, 0),
+ Seq(3, 2, 3, 3, 1, 3, 2, 3, 3, 3, 3, 2, 1, 0, 0, 1),
+ (0 to 15),
+ batchSize = batchSize, dictionaryEnabled = dictionaryEnabled)
+ }
+ }
+ }
+ }
+
+ test("nested type - column index with ranges, nulls and first row indexes") {
+ BATCH_SIZE_CONFIGS.foreach { batchSize =>
+ Seq(true, false).foreach { dictionaryEnabled =>
+ val pageSizes = Seq(4, 4, 4, 4)
+ var firstRowIndexes = Seq(10L, 20, 30, 40)
+ var ranges = Seq((0L, 5L))
+ testNestedStringArrayOneLevel(Some(firstRowIndexes), Some(ranges), pageSizes,
+ Seq(),
+ Seq(0, 1, 0, 1, 0, 0, 1, 1, 0, 1, 1, 1, 0, 0, 0, 0),
+ Seq(3, 2, 3, 3, 1, 3, 2, 3, 3, 3, 3, 2, 1, 0, 0, 1),
+ (0 to 15),
+ batchSize = batchSize, dictionaryEnabled = dictionaryEnabled)
+
+ ranges = Seq((5L, 15))
+ testNestedStringArrayOneLevel(Some(firstRowIndexes), Some(ranges), pageSizes,
+ Seq(Seq("a", null), Seq("c", "d")),
+ Seq(0, 1, 0, 1, 0, 0, 1, 1, 0, 1, 1, 1, 0, 0, 0, 0),
+ Seq(3, 2, 3, 3, 1, 3, 2, 3, 3, 3, 3, 2, 1, 0, 0, 1),
+ (0 to 15),
+ batchSize = batchSize, dictionaryEnabled = dictionaryEnabled)
+
+ ranges = Seq((25, 28))
+ testNestedStringArrayOneLevel(Some(firstRowIndexes), Some(ranges), pageSizes,
+ Seq(),
+ Seq(0, 1, 0, 1, 0, 0, 1, 1, 0, 1, 1, 1, 0, 0, 0, 0),
+ Seq(3, 2, 3, 3, 1, 3, 2, 3, 3, 3, 3, 2, 1, 0, 0, 1),
+ (0 to 15),
+ batchSize = batchSize, dictionaryEnabled = dictionaryEnabled)
+
+ ranges = Seq((35, 45))
+ testNestedStringArrayOneLevel(Some(firstRowIndexes), Some(ranges), pageSizes,
+ Seq(Seq(), null, null, Seq()),
+ Seq(0, 1, 0, 1, 0, 0, 1, 1, 0, 1, 1, 1, 0, 0, 0, 0),
+ Seq(3, 2, 3, 3, 1, 3, 2, 3, 3, 3, 3, 2, 1, 0, 0, 1),
+ (0 to 15),
+ batchSize = batchSize, dictionaryEnabled = dictionaryEnabled)
+
+ ranges = Seq((45, 55))
+ testNestedStringArrayOneLevel(Some(firstRowIndexes), Some(ranges), pageSizes,
+ Seq(),
+ Seq(0, 1, 0, 1, 0, 0, 1, 1, 0, 1, 1, 1, 0, 0, 0, 0),
+ Seq(3, 2, 3, 3, 1, 3, 2, 3, 3, 3, 3, 2, 1, 0, 0, 1),
+ (0 to 15),
+ batchSize = batchSize, dictionaryEnabled = dictionaryEnabled)
+
+ ranges = Seq((15, 29), (31, 35))
+ testNestedStringArrayOneLevel(Some(firstRowIndexes), Some(ranges), pageSizes,
+ Seq(Seq(), Seq("f", null, "h")),
+ Seq(0, 1, 0, 1, 0, 0, 1, 1, 0, 1, 1, 1, 0, 0, 0, 0),
+ Seq(3, 2, 3, 3, 1, 3, 2, 3, 3, 3, 3, 2, 1, 0, 0, 1),
+ (0 to 15),
+ batchSize = batchSize, dictionaryEnabled = dictionaryEnabled)
+ }
+ }
+ }
+
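+ /**
+ * Write a one-level string array column (`_1: array<string>`) page by page from the given
+ * repetition levels (`rls`), definition levels (`dls`) and leaf `values`, then read it back
+ * with the vectorized reader and compare the result against `expected`. `firstRowIndexesOpt`
+ * and `rangesOpt` emulate the page first-row indexes and row ranges of a Parquet column index.
+ */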
+ private def testNestedStringArrayOneLevel(
+ firstRowIndexesOpt: Option[Seq[Long]],
+ rangesOpt: Option[Seq[(Long, Long)]],
+ pageSizes: Seq[Int],
+ expected: Seq[Seq[String]],
+ rls: Seq[Int],
+ dls: Seq[Int],
+ values: Seq[String] = VALUES,
+ batchSize: Int,
+ dictionaryEnabled: Boolean = false): Unit = {
+ assert(pageSizes.sum == rls.length && rls.length == dls.length)
+ firstRowIndexesOpt.foreach(a => assert(pageSizes.length == a.length))
+
+ val parquetSchema = MessageTypeParser.parseMessageType(
+ s"""message root {
+ | optional group _1 (LIST) {
+ | repeated group list {
+ | optional binary a(UTF8);
+ | }
+ | }
+ |}
+ |""".stripMargin
+ )
+
+ val maxRepLevel = 1
+ val maxDefLevel = 3
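+ // For this schema: rl = 0 starts a new record and rl = 1 continues the current list;
+ // dl = 0 means the list itself is null, dl = 1 an empty list, dl = 2 a null element,
+ // and dl = 3 a non-null element.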
+ val ty = parquetSchema.getType("_1", "list", "a").asPrimitiveType()
+ val cd = new ColumnDescriptor(Seq("_1", "list", "a").toArray, ty, maxRepLevel, maxDefLevel)
+
+ var i = 0
+ var numRows = 0
+ val memPageStore = new MemPageStore(expected.length)
+ val pageFirstRowIndexes = ArrayBuffer.empty[Long]
+ pageSizes.foreach { size =>
+ pageFirstRowIndexes += numRows
+ numRows += rls.slice(i, i + size).count(_ == 0)
+ writeDataPage(cd, memPageStore, rls.slice(i, i + size), dls.slice(i, i + size),
+ values.slice(i, i + size), maxDefLevel, dictionaryEnabled)
+ i += size
+ }
+
+ checkAnswer(expected.length, parquetSchema,
+ TestPageReadStore(memPageStore, firstRowIndexesOpt.getOrElse(pageFirstRowIndexes).toSeq,
+ rangesOpt), expected.map(i => Row(i)), batchSize)
+ }
+
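+ /**
+ * Write the `actual` strings as a single flat string column, split into pages of
+ * `pageSizes`, then read it back with the vectorized reader. `firstRowIndexesOpt` and
+ * `rangesOpt` emulate the page first-row indexes and row ranges of a Parquet column index;
+ * `expected` is what the reader should return for the selected rows.
+ */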
+ private def testPrimitiveString(
+ firstRowIndexesOpt: Option[Seq[Long]],
+ rangesOpt: Option[Seq[(Long, Long)]],
+ pageSizes: Seq[Int],
+ expected: Seq[String],
+ batchSize: Int,
+ actual: Seq[String] = VALUES,
+ dictionaryEnabled: Boolean = false): Unit = {
+ assert(pageSizes.sum == actual.length)
+ firstRowIndexesOpt.foreach(a => assert(pageSizes.length == a.length))
+
+ val isRequiredStr = if (!expected.contains(null)) "required" else "optional"
+ val parquetSchema: MessageType = MessageTypeParser.parseMessageType(
+ s"""message root {
+ | $isRequiredStr binary a(UTF8);
+ |}
+ |""".stripMargin
+ )
+ val maxDef = if (actual.contains(null)) 1 else 0
+ val ty = parquetSchema.asGroupType().getType("a").asPrimitiveType()
+ val cd = new ColumnDescriptor(Seq("a").toArray, ty, 0, maxDef)
+ val rls = Array.fill[Int](actual.length)(0)
+ val dls = actual.map(v => if (v == null) 0 else 1)
+
+ val memPageStore = new MemPageStore(expected.length)
+
+ var i = 0
+ val pageFirstRowIndexes = ArrayBuffer.empty[Long]
+ pageSizes.foreach { size =>
+ pageFirstRowIndexes += i
+ writeDataPage(cd, memPageStore, rls.slice(i, i + size), dls.slice(i, i + size),
+ actual.slice(i, i + size), maxDef, dictionaryEnabled)
+ i += size
+ }
+
+ checkAnswer(expected.length, parquetSchema,
+ TestPageReadStore(memPageStore, firstRowIndexesOpt.getOrElse(pageFirstRowIndexes).toSeq,
+ rangesOpt), expected.map(i => Row(i)), batchSize)
+ }
+
+ /**
+ * Write a single data page using the provided repetition levels (`rls`), definition
+ * levels (`dls`) and values (`values`).
+ *
+ * Note that this requires `rls`, `dls` and `values` to have the same number of elements. For
+ * null values, the corresponding slots in `values` will be skipped.
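+ *
+ * For example, with the one-level list schema used in this suite, writing the two records
+ * [ ["a"], null ] corresponds to rls = [0, 0] and dls = [3, 0]; the value slot paired with
+ * the null record is skipped.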
+ */
+ private def writeDataPage(
+ cd: ColumnDescriptor,
+ pageWriteStore: PageWriteStore,
+ rls: Seq[Int],
+ dls: Seq[Int],
+ values: Seq[Any],
+ maxDl: Int,
+ dictionaryEnabled: Boolean = false): Unit = {
+ val columnWriterStore = new ColumnWriteStoreV1(pageWriteStore,
+ ParquetProperties.builder()
+ .withPageSize(4096)
+ .withDictionaryEncoding(dictionaryEnabled)
+ .build())
+ val columnWriter = columnWriterStore.getColumnWriter(cd)
+
+ rls.zip(dls).zipWithIndex.foreach { case ((rl, dl), i) =>
+ if (dl < maxDl) {
+ columnWriter.writeNull(rl, dl)
+ } else {
+ cd.getPrimitiveType.getPrimitiveTypeName match {
+ case PrimitiveTypeName.INT32 =>
+ columnWriter.write(values(i).asInstanceOf[Int], rl, dl)
+ case PrimitiveTypeName.INT64 =>
+ columnWriter.write(values(i).asInstanceOf[Long], rl, dl)
+ case PrimitiveTypeName.BOOLEAN =>
+ columnWriter.write(values(i).asInstanceOf[Boolean], rl, dl)
+ case PrimitiveTypeName.FLOAT =>
+ columnWriter.write(values(i).asInstanceOf[Float], rl, dl)
+ case PrimitiveTypeName.DOUBLE =>
+ columnWriter.write(values(i).asInstanceOf[Double], rl, dl)
+ case PrimitiveTypeName.BINARY =>
+ columnWriter.write(Binary.fromString(values(i).asInstanceOf[String]), rl, dl)
+ case _ =>
+ throw new IllegalStateException(s"Unexpected type: " +
+ s"${cd.getPrimitiveType.getPrimitiveTypeName}")
+ }
+ }
+ columnWriterStore.endRecord()
+ }
+ columnWriterStore.flush()
+ }
+
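+ /**
+ * Read all rows from `readStore` with a VectorizedParquetRecordReader using the given
+ * `batchSize`, and compare them row by row against `expected` (converted into an on-heap
+ * columnar batch).
+ */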
+ private def checkAnswer(
+ totalRowCount: Int,
+ fileSchema: MessageType,
+ readStore: PageReadStore,
+ expected: Seq[Row],
+ batchSize: Int = NUM_VALUES): Unit = {
+ import scala.collection.JavaConverters._
+
+ val recordReader = new VectorizedParquetRecordReader(
+ DateTimeUtils.getZoneId("EST"), "CORRECTED", "CORRECTED", true, batchSize)
+ recordReader.initialize(fileSchema, fileSchema,
+ TestParquetRowGroupReader(Seq(readStore)), totalRowCount)
+
+ // convert both actual and expected rows into collections
+ val schema = recordReader.sparkSchema
+ val expectedRowIt = ColumnVectorUtils.toBatch(
+ schema, MemoryMode.ON_HEAP, expected.iterator.asJava).rowIterator()
+
+ val rowOrdering = RowOrdering.createNaturalAscendingOrdering(schema.map(_.dataType))
+ var i = 0
+ while (expectedRowIt.hasNext && recordReader.nextKeyValue()) {
+ val expectedRow = expectedRowIt.next()
+ val actualRow = recordReader.getCurrentValue.asInstanceOf[InternalRow]
+ assert(rowOrdering.compare(expectedRow, actualRow) == 0, {
+ val expectedRowStr = toDebugString(schema, expectedRow)
+ val actualRowStr = toDebugString(schema, actualRow)
+ s"at index $i, expected row: $expectedRowStr doesn't match actual row: $actualRowStr"
+ })
+ i += 1
+ }
+ }
+
+ private def toDebugString(schema: StructType, row: InternalRow): String = {
+ if (row == null) "null"
+ else {
+ val fieldStrings = schema.fields.zipWithIndex.map { case (f, i) =>
+ f.dataType match {
+ case IntegerType =>
+ row.getInt(i).toString
+ case StringType =>
+ val utf8Str = row.getUTF8String(i)
+ if (utf8Str == null) "null"
+ else utf8Str.toString
+ case ArrayType(_, _) =>
+ val elements = row.getArray(i)
+ if (elements == null) "null"
+ else "[" + elements.array.mkString(", ") + "]"
+ case _ =>
+ throw new IllegalStateException(s"Unsupported data type: ${f.dataType}")
+ }
+ }
+ fieldStrings.mkString(", ")
+ }
+ }
+
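+ // Serves the given page stores as row groups, one per readNextRowGroup() call, and returns
+ // null once they are exhausted.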
+ case class TestParquetRowGroupReader(groups: Seq[PageReadStore]) extends ParquetRowGroupReader {
+ private var index: Int = 0
+
+ override def readNextRowGroup(): PageReadStore = {
+ if (index == groups.length) {
+ null
+ } else {
+ val res = groups(index)
+ index += 1
+ res
+ }
+ }
+
+ override def close(): Unit = {}
+ }
+
+ private case class TestPageReadStore(
+ original: PageReadStore,
+ firstRowIndexes: Seq[Long],
+ rangesOpt: Option[Seq[(Long, Long)]] = None) extends PageReadStore {
+ override def getPageReader(descriptor: ColumnDescriptor): PageReader = {
+ val originalReader = original.getPageReader(descriptor)
+ TestPageReader(originalReader, firstRowIndexes)
+ }
+
+ override def getRowCount: Long = original.getRowCount
+
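+ // Expose `rangesOpt` as the row-index iterator a Parquet column index would produce,
+ // treating each (start, end) pair as an inclusive range; when no ranges are given, no
+ // filtering information is reported.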
+ override def getRowIndexes: Optional[PrimitiveIterator.OfLong] = {
+ rangesOpt.map { ranges =>
+ Optional.of(new PrimitiveIterator.OfLong {
+ private var currentRangeIdx: Int = 0
+ private var currentRowIdx: Long = -1
+
+ override def nextLong(): Long = {
+ if (!hasNext) throw new NoSuchElementException("No more elements")
+ val res = currentRowIdx
+ currentRowIdx += 1
+ res
+ }
+
+ override def hasNext: Boolean = {
+ while (currentRangeIdx < ranges.length) {
+ if (currentRowIdx > ranges(currentRangeIdx)._2) {
+ // we've exhausted the current range - move to the next range
+ currentRangeIdx += 1
+ currentRowIdx = -1
+ } else {
+ if (currentRowIdx == -1) {
+ currentRowIdx = ranges(currentRangeIdx)._1
+ }
+ return true
+ }
+ }
+ false
+ }
+ })
+ }.getOrElse(Optional.empty())
+ }
+ }
+
+ private case class TestPageReader(
+ wrapped: PageReader,
+ firstRowIndexes: Seq[Long]) extends PageReader {
+ private var index = 0
+
+ override def readDictionaryPage(): DictionaryPage = wrapped.readDictionaryPage()
+ override def getTotalValueCount: Long = wrapped.getTotalValueCount
+ override def readPage(): DataPage = {
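+ // Translate a decoding failure from the wrapped reader into a null page so the caller
+ // treats it as end of input.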
+ val wrappedPage = try {
+ wrapped.readPage()
+ } catch {
+ case _: ParquetDecodingException =>
+ null
+ }
+ if (wrappedPage == null) {
+ wrappedPage
+ } else {
+ val res = new TestDataPage(wrappedPage, firstRowIndexes(index))
+ index += 1
+ res
+ }
+ }
+ }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnVectorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnVectorSuite.scala
index 43f48abb9734..6721e1ad2584 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnVectorSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnVectorSuite.scala
@@ -336,8 +336,10 @@ class ColumnVectorSuite extends SparkFunSuite with BeforeAndAfterEach {
val c2 = testVector.getChild(1)
c1.putInt(0, 123)
c2.putDouble(0, 3.45)
+ testVector.putStruct(0, 0)
c1.putInt(1, 456)
c2.putDouble(1, 5.67)
+ testVector.putStruct(1, 1)
assert(testVector.getStruct(0).get(0, IntegerType) === 123)
assert(testVector.getStruct(0).get(1, DoubleType) === 3.45)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala
index 5fe0a2aef8a8..979e5f486da2 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/vectorized/ColumnarBatchSuite.scala
@@ -921,12 +921,14 @@ class ColumnarBatchSuite extends SparkFunSuite {
c1.putInt(0, 123)
c2.putDouble(0, 3.45)
+ column.putStruct(0, 0)
column.putNull(1)
assert(column.getStruct(1) == null)
- c1.putInt(2, 456)
- c2.putDouble(2, 5.67)
+ c1.putInt(1, 456)
+ c2.putDouble(1, 5.67)
+ column.putStruct(2, 1)
val s = column.getStruct(0)
assert(s.getInt(0) == 123)
@@ -982,6 +984,7 @@ class ColumnarBatchSuite extends SparkFunSuite {
(0 until 6).foreach { i =>
c0.putInt(i, i)
c1.putLong(i, i * 10)
+ data.putStruct(i, i)
}
// Arrays in column: [(0, 0), (1, 10)], [(1, 10), (2, 20), (3, 30)],
// [(4, 40), (5, 50)]
@@ -1018,6 +1021,10 @@ class ColumnarBatchSuite extends SparkFunSuite {
c1.putArray(1, 2, 1)
c1.putArray(2, 3, 3)
+ column.putStruct(0, 0)
+ column.putStruct(1, 1)
+ column.putStruct(2, 2)
+
assert(column.getStruct(0).getInt(0) === 0)
assert(column.getStruct(0).getArray(1).toIntArray() === Array(0, 1))
assert(column.getStruct(1).getInt(0) === 1)
@@ -1038,6 +1045,9 @@ class ColumnarBatchSuite extends SparkFunSuite {
c0.putInt(0, 0)
c0.putInt(1, 1)
c0.putInt(2, 2)
+ column.putStruct(0, 0)
+ column.putStruct(1, 1)
+ column.putStruct(2, 2)
val c1c0 = c1.getChild(0)
val c1c1 = c1.getChild(1)
// Structs in c1: (7, 70), (8, 80), (9, 90)
@@ -1047,6 +1057,9 @@ class ColumnarBatchSuite extends SparkFunSuite {
c1c1.putInt(0, 70)
c1c1.putInt(1, 80)
c1c1.putInt(2, 90)
+ c1.putStruct(0, 0)
+ c1.putStruct(1, 1)
+ c1.putStruct(2, 2)
assert(column.getStruct(0).getInt(0) === 0)
assert(column.getStruct(0).getStruct(1, 2).toSeq(subSchema) === Seq(7, 70))