diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcArrayColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcArrayColumnVector.java
index 6e13e97b4cbc..b0c818f5a4df 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcArrayColumnVector.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcArrayColumnVector.java
@@ -18,6 +18,7 @@
 package org.apache.spark.sql.execution.datasources.orc;
 
 import org.apache.hadoop.hive.ql.exec.vector.ColumnVector;
+import org.apache.hadoop.hive.ql.exec.vector.ListColumnVector;
 
 import org.apache.spark.sql.types.ArrayType;
 import org.apache.spark.sql.types.DataType;
@@ -31,26 +32,22 @@
  */
 public class OrcArrayColumnVector extends OrcColumnVector {
   private final OrcColumnVector data;
-  private final long[] offsets;
-  private final long[] lengths;
 
   OrcArrayColumnVector(
       DataType type,
       ColumnVector vector,
-      OrcColumnVector data,
-      long[] offsets,
-      long[] lengths) {
+      OrcColumnVector data) {
 
     super(type, vector);
 
     this.data = data;
-    this.offsets = offsets;
-    this.lengths = lengths;
   }
 
   @Override
   public ColumnarArray getArray(int rowId) {
-    return new ColumnarArray(data, (int) offsets[rowId], (int) lengths[rowId]);
+    int offsets = (int) ((ListColumnVector) baseData).offsets[rowId];
+    int lengths = (int) ((ListColumnVector) baseData).lengths[rowId];
+    return new ColumnarArray(data, offsets, lengths);
   }
 
   @Override
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnVector.java
index 0becd2572f99..7fe1b306142e 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnVector.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnVector.java
@@ -29,7 +29,7 @@
  * this column vector is used to adapt Hive ColumnVector with Spark ColumnarVector.
  */
 public abstract class OrcColumnVector extends org.apache.spark.sql.vectorized.ColumnVector {
-  private final ColumnVector baseData;
+  protected final ColumnVector baseData;
   private int batchSize;
 
   OrcColumnVector(DataType type, ColumnVector vector) {
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnVectorUtils.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnVectorUtils.java
index 3bc7cc8f8014..89f6996e4610 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnVectorUtils.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcColumnVectorUtils.java
@@ -53,15 +53,13 @@ static OrcColumnVector toOrcColumnVector(DataType type, ColumnVector vector) {
       ListColumnVector listVector = (ListColumnVector) vector;
       OrcColumnVector dataVector = toOrcColumnVector(
         ((ArrayType) type).elementType(), listVector.child);
-      return new OrcArrayColumnVector(
-        type, vector, dataVector, listVector.offsets, listVector.lengths);
+      return new OrcArrayColumnVector(type, vector, dataVector);
     } else if (vector instanceof MapColumnVector) {
       MapColumnVector mapVector = (MapColumnVector) vector;
       MapType mapType = (MapType) type;
       OrcColumnVector keysVector = toOrcColumnVector(mapType.keyType(), mapVector.keys);
       OrcColumnVector valuesVector = toOrcColumnVector(mapType.valueType(), mapVector.values);
-      return new OrcMapColumnVector(
-        type, vector, keysVector, valuesVector, mapVector.offsets, mapVector.lengths);
+      return new OrcMapColumnVector(type, vector, keysVector, valuesVector);
     } else {
       throw new IllegalArgumentException(
         String.format("OrcColumnVectorUtils.toOrcColumnVector should not take %s as type " +
diff --git a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcMapColumnVector.java b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcMapColumnVector.java
index ace8d157792d..7eedd8b59412 100644
--- a/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcMapColumnVector.java
+++ b/sql/core/src/main/java/org/apache/spark/sql/execution/datasources/orc/OrcMapColumnVector.java
@@ -18,6 +18,7 @@
 package org.apache.spark.sql.execution.datasources.orc;
 
 import org.apache.hadoop.hive.ql.exec.vector.ColumnVector;
+import org.apache.hadoop.hive.ql.exec.vector.MapColumnVector;
 
 import org.apache.spark.sql.types.DataType;
 import org.apache.spark.sql.types.Decimal;
@@ -32,28 +33,24 @@
 public class OrcMapColumnVector extends OrcColumnVector {
   private final OrcColumnVector keys;
   private final OrcColumnVector values;
-  private final long[] offsets;
-  private final long[] lengths;
 
   OrcMapColumnVector(
       DataType type,
       ColumnVector vector,
       OrcColumnVector keys,
-      OrcColumnVector values,
-      long[] offsets,
-      long[] lengths) {
+      OrcColumnVector values) {
 
     super(type, vector);
 
     this.keys = keys;
     this.values = values;
-    this.offsets = offsets;
-    this.lengths = lengths;
   }
 
   @Override
   public ColumnarMap getMap(int ordinal) {
-    return new ColumnarMap(keys, values, (int) offsets[ordinal], (int) lengths[ordinal]);
+    int offsets = (int) ((MapColumnVector) baseData).offsets[ordinal];
+    int lengths = (int) ((MapColumnVector) baseData).lengths[ordinal];
+    return new ColumnarMap(keys, values, offsets, lengths);
   }
 
   @Override
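For context, a minimal sketch of why caching the offsets/lengths arrays at construction time was fragile. It assumes the usual behavior of Hive's MultiValuedColumnVector.ensureSize, which allocates fresh offsets/lengths arrays when a later batch needs more capacity than the vector currently has; the object name, sizes, and scenario below are illustrative only and are not part of the patch.

import org.apache.hadoop.hive.ql.exec.vector.{ListColumnVector, LongColumnVector}

object StaleOffsetsSketch {
  def main(args: Array[String]): Unit = {
    val list = new ListColumnVector(4, new LongColumnVector())
    // What the old OrcArrayColumnVector constructor effectively did: keep a
    // reference to the offsets array as it exists right now.
    val cachedOffsets = list.offsets

    // A later, larger batch forces the vector to grow; ensureSize is expected
    // to replace the offsets/lengths arrays with bigger ones.
    list.ensureSize(8, true)

    // The cached reference still points at the old 4-element array.
    println(cachedOffsets eq list.offsets)  // false once the arrays were reallocated
    // Indexing cachedOffsets(rowId) for rowId >= 4 throws
    // ArrayIndexOutOfBoundsException, while list.offsets(rowId) stays valid.
    // Hence the patch reads offsets/lengths through baseData on every
    // getArray/getMap call instead of capturing them at construction time.
  }
}
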
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcQuerySuite.scala
index 8bc92f8d57ad..038606b854d9 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcQuerySuite.scala
@@ -744,6 +744,28 @@ abstract class OrcQuerySuite extends OrcQueryTest with SharedSparkSession {
     }
   }
 
+  test("SPARK-37728: Reading nested columns with ORC vectorized reader should not " +
+    "cause ArrayIndexOutOfBoundsException") {
+    withTempPath { dir =>
+      val path = dir.getCanonicalPath
+      val df = spark.range(100).map { _ =>
+        val arrayColumn = (0 until 50).map(_ => (0 until 1000).map(k => k.toString))
+        arrayColumn
+      }.toDF("record").repartition(1)
+      df.write.format("orc").save(path)
+
+      withSQLConf(SQLConf.ORC_VECTORIZED_READER_NESTED_COLUMN_ENABLED.key -> "true") {
+        val readDf = spark.read.orc(path)
+        val vectorizationEnabled = readDf.queryExecution.executedPlan.find {
+          case scan @ (_: FileSourceScanExec | _: BatchScanExec) => scan.supportsColumnar
+          case _ => false
+        }.isDefined
+        assert(vectorizationEnabled)
+        checkAnswer(readDf, df)
+      }
+    }
+  }
+
   test("SPARK-36594: ORC vectorized reader should properly check maximal number of fields") {
     withTempPath { dir =>
       val path = dir.getCanonicalPath
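Beyond the test harness above, a standalone sketch of how an application would exercise the same read path. The scratch path is hypothetical, and the conf string is assumed to be the SQL conf behind SQLConf.ORC_VECTORIZED_READER_NESTED_COLUMN_ENABLED on the Spark 3.3 line; adjust both for your environment.

import org.apache.spark.sql.SparkSession

object OrcNestedReadSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[*]")
      // Assumed name of the conf behind ORC_VECTORIZED_READER_NESTED_COLUMN_ENABLED.
      .config("spark.sql.orc.enableNestedColumnVectorizedReader", "true")
      .getOrCreate()
    import spark.implicits._

    val path = "/tmp/orc-nested-read-sketch"  // hypothetical scratch location

    // Rows with large nested arrays, mirroring the regression test, so the
    // vectorized reader has to fill more than one column batch.
    spark.range(100)
      .map(_ => (0 until 50).map(_ => (0 until 1000).map(_.toString)))
      .toDF("record")
      .repartition(1)
      .write.mode("overwrite").format("orc").save(path)

    // Before the patch this read could fail with ArrayIndexOutOfBoundsException
    // once the cached offsets/lengths arrays went stale between batches.
    println(spark.read.orc(path).count())

    spark.stop()
  }
}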