diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala
index 79c54bc5aaa7..c4ffdb402fab 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala
@@ -36,6 +36,7 @@ import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection
+import org.apache.spark.sql.execution.WholeStageCodegenExec
 import org.apache.spark.sql.execution.datasources._
 import org.apache.spark.sql.sources._
 import org.apache.spark.sql.types._
@@ -118,7 +119,7 @@ class OrcFileFormat
   override def supportBatch(sparkSession: SparkSession, schema: StructType): Boolean = {
     val conf = sparkSession.sessionState.conf
     conf.orcVectorizedReaderEnabled && conf.wholeStageEnabled &&
-      schema.length <= conf.wholeStageMaxNumFields &&
+      !WholeStageCodegenExec.isTooManyFields(conf, schema) &&
       schema.forall(s => OrcUtils.supportColumnarReads(
         s.dataType, sparkSession.sessionState.conf.orcVectorizedReaderNestedColumnEnabled))
   }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcPartitionReaderFactory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcPartitionReaderFactory.scala
index 930adc08e77a..c5020cb79524 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcPartitionReaderFactory.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcPartitionReaderFactory.scala
@@ -30,6 +30,7 @@ import org.apache.orc.mapreduce.OrcInputFormat
 import org.apache.spark.broadcast.Broadcast
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.connector.read.{InputPartition, PartitionReader}
+import org.apache.spark.sql.execution.WholeStageCodegenExec
 import org.apache.spark.sql.execution.datasources.PartitionedFile
 import org.apache.spark.sql.execution.datasources.orc.{OrcColumnarBatchReader, OrcDeserializer, OrcFilters, OrcUtils}
 import org.apache.spark.sql.execution.datasources.v2._
@@ -63,7 +64,7 @@ case class OrcPartitionReaderFactory(
 
   override def supportColumnarReads(partition: InputPartition): Boolean = {
     sqlConf.orcVectorizedReaderEnabled && sqlConf.wholeStageEnabled &&
-      resultSchema.length <= sqlConf.wholeStageMaxNumFields &&
+      !WholeStageCodegenExec.isTooManyFields(sqlConf, resultSchema) &&
       resultSchema.forall(s => OrcUtils.supportColumnarReads(
         s.dataType, sqlConf.orcVectorizedReaderNestedColumnEnabled))
   }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcQuerySuite.scala
index 680c2cf2b42e..e4c33e96faa1 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcQuerySuite.scala
@@ -742,6 +742,32 @@ abstract class OrcQuerySuite extends OrcQueryTest with SharedSparkSession {
       }
     }
   }
+
+  test("SPARK-36594: ORC vectorized reader should properly check maximal number of fields") {
+    withTempPath { dir =>
+      val path = dir.getCanonicalPath
+      val df = spark.range(10).map { x =>
+        val stringColumn = s"$x" * 10
+        val structColumn = (x, s"$x" * 100)
+        val arrayColumn = (0 until 5).map(i => (x + i, s"$x" * 5))
+        val mapColumn = Map(s"$x" -> (x * 0.1, (x, s"$x" * 100)))
+        (x, stringColumn, structColumn, arrayColumn, mapColumn)
+      }.toDF("int_col", "string_col", "struct_col", "array_col", "map_col")
+      df.write.format("orc").save(path)
+
+      Seq(("5", false), ("10", true)).foreach {
+        case (maxNumFields, vectorizedEnabled) =>
+          withSQLConf(SQLConf.ORC_VECTORIZED_READER_NESTED_COLUMN_ENABLED.key -> "true",
+            SQLConf.WHOLESTAGE_MAX_NUM_FIELDS.key -> maxNumFields) {
+            val scanPlan = spark.read.orc(path).queryExecution.executedPlan
+            assert(scanPlan.find {
+              case scan @ (_: FileSourceScanExec | _: BatchScanExec) => scan.supportsColumnar
+              case _ => false
+            }.isDefined == vectorizedEnabled)
+          }
+      }
+    }
+  }
 }
 
 class OrcV1QuerySuite extends OrcQuerySuite {
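
A note on why the test's thresholds are 5 and 10: the old guard compared the number of top-level columns (schema.length) against spark.sql.codegen.maxFields, while whole-stage codegen decides via WholeStageCodegenExec.isTooManyFields, which counts leaf fields recursively through structs, arrays, and maps. The sketch below illustrates the discrepancy for the schema the test writes; numLeafFields is a hypothetical stand-in for the recursive count, not the actual Spark source. The schema has 5 top-level columns but 10 leaf fields, so vectorized reads should flip off at maxFields = 5 and on at 10, which is exactly what the new test asserts.

    // A minimal sketch, assuming only the public DataType API: numLeafFields is a
    // hypothetical stand-in for the recursive counting behind
    // WholeStageCodegenExec.isTooManyFields, not the actual implementation.
    import org.apache.spark.sql.types._

    object FieldCountSketch {
      // Count leaf fields, descending into structs, arrays, and maps.
      def numLeafFields(dt: DataType): Int = dt match {
        case s: StructType => s.fields.map(f => numLeafFields(f.dataType)).sum
        case a: ArrayType  => numLeafFields(a.elementType)
        case m: MapType    => numLeafFields(m.keyType) + numLeafFields(m.valueType)
        case _             => 1
      }

      def main(args: Array[String]): Unit = {
        // Schema produced by the test's tuple
        // (x, stringColumn, structColumn, arrayColumn, mapColumn).
        val pair = StructType(Seq(
          StructField("_1", LongType), StructField("_2", StringType)))
        val schema = StructType(Seq(
          StructField("int_col", LongType),
          StructField("string_col", StringType),
          StructField("struct_col", pair),
          StructField("array_col", ArrayType(pair)),
          StructField("map_col", MapType(StringType, StructType(Seq(
            StructField("_1", DoubleType), StructField("_2", pair)))))))

        println(schema.length)         // 5 top-level columns: the old check passes even at maxFields = 5
        println(numLeafFields(schema)) // 10 leaf fields: the new check needs maxFields >= 10
      }
    }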