diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala
index 85c0ff01cfba..56ddbcc2d567 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcFileFormat.scala
@@ -115,27 +115,12 @@ class OrcFileFormat
     }
   }
 
-  private def supportBatchForNestedColumn(
-      sparkSession: SparkSession,
-      schema: StructType): Boolean = {
-    val hasNestedColumn = schema.map(_.dataType).exists {
-      case _: ArrayType | _: MapType | _: StructType => true
-      case _ => false
-    }
-    if (hasNestedColumn) {
-      sparkSession.sessionState.conf.orcVectorizedReaderNestedColumnEnabled
-    } else {
-      true
-    }
-  }
-
   override def supportBatch(sparkSession: SparkSession, schema: StructType): Boolean = {
     val conf = sparkSession.sessionState.conf
     conf.orcVectorizedReaderEnabled && conf.wholeStageEnabled &&
       schema.length <= conf.wholeStageMaxNumFields &&
-      schema.forall(s => supportDataType(s.dataType) &&
-        !s.dataType.isInstanceOf[UserDefinedType[_]]) &&
-      supportBatchForNestedColumn(sparkSession, schema)
+      schema.forall(s => OrcUtils.supportColumnarReads(
+        s.dataType, sparkSession.sessionState.conf.orcVectorizedReaderNestedColumnEnabled))
   }
 
   override def isSplitable(
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala
index a8647726fe02..391ead85c4ec 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/orc/OrcUtils.scala
@@ -259,4 +259,27 @@ object OrcUtils extends Logging {
     OrcConf.MAPRED_INPUT_SCHEMA.setString(conf, resultSchemaString)
     resultSchemaString
   }
+
+  /**
+   * Checks if `dataType` supports columnar reads.
+   *
+   * @param dataType Data type of the ORC column.
+   * @param nestedColumnEnabled True if columnar reads are enabled for nested column types.
+   * @return True if the data type supports columnar reads.
+   */
+  def supportColumnarReads(
+      dataType: DataType,
+      nestedColumnEnabled: Boolean): Boolean = {
+    dataType match {
+      case _: AtomicType => true
+      case st: StructType if nestedColumnEnabled =>
+        st.forall(f => supportColumnarReads(f.dataType, nestedColumnEnabled))
+      case ArrayType(elementType, _) if nestedColumnEnabled =>
+        supportColumnarReads(elementType, nestedColumnEnabled)
+      case MapType(keyType, valueType, _) if nestedColumnEnabled =>
+        supportColumnarReads(keyType, nestedColumnEnabled) &&
+          supportColumnarReads(valueType, nestedColumnEnabled)
+      case _ => false
+    }
+  }
 }
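Note: below is a minimal sketch of the semantics the new `OrcUtils.supportColumnarReads` helper provides, assuming Spark's internal classes are on the classpath (e.g. inside a Spark unit test or a spark-shell built from this branch); the expected results follow directly from the pattern match above.

```scala
import org.apache.spark.sql.execution.datasources.orc.OrcUtils
import org.apache.spark.sql.types._

// Flat atomic columns are always supported, with or without the flag.
assert(OrcUtils.supportColumnarReads(IntegerType, nestedColumnEnabled = false))

// Nested types are supported only when the flag is on.
val nested = ArrayType(StructType(Seq(StructField("i", IntegerType))))
assert(OrcUtils.supportColumnarReads(nested, nestedColumnEnabled = true))
assert(!OrcUtils.supportColumnarReads(nested, nestedColumnEnabled = false))

// The check recurses, so a nested type with a non-atomic leaf
// (e.g. CalendarIntervalType) still falls through to `case _ => false`.
assert(!OrcUtils.supportColumnarReads(ArrayType(CalendarIntervalType), nestedColumnEnabled = true))
```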
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcPartitionReaderFactory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcPartitionReaderFactory.scala
index 414252cc1248..930adc08e77a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcPartitionReaderFactory.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcPartitionReaderFactory.scala
@@ -35,7 +35,7 @@ import org.apache.spark.sql.execution.datasources.orc.{OrcColumnarBatchReader, O
 import org.apache.spark.sql.execution.datasources.v2._
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.sources.Filter
-import org.apache.spark.sql.types.{AtomicType, StructType}
+import org.apache.spark.sql.types.StructType
 import org.apache.spark.sql.vectorized.ColumnarBatch
 import org.apache.spark.util.{SerializableConfiguration, Utils}
 
@@ -64,7 +64,8 @@ case class OrcPartitionReaderFactory(
   override def supportColumnarReads(partition: InputPartition): Boolean = {
     sqlConf.orcVectorizedReaderEnabled && sqlConf.wholeStageEnabled &&
       resultSchema.length <= sqlConf.wholeStageMaxNumFields &&
-      resultSchema.forall(_.dataType.isInstanceOf[AtomicType])
+      resultSchema.forall(s => OrcUtils.supportColumnarReads(
+        s.dataType, sqlConf.orcVectorizedReaderNestedColumnEnabled))
   }
 
   private def pushDownPredicates(filePath: Path, conf: Configuration): Unit = {
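Note: with both entry points delegating to the shared helper, the v1 path (`OrcFileFormat.supportBatch`) and the v2 path (`OrcPartitionReaderFactory.supportColumnarReads`) now gate nested types behind the same flag. A usage sketch follows, assuming default whole-stage codegen settings and fewer columns than `spark.sql.codegen.maxFields`; `/tmp/nested_orc` is a placeholder path.

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.internal.SQLConf

val spark = SparkSession.builder().master("local[*]").getOrCreate()
import spark.implicits._

// Write an ORC file with a nested (array-of-struct) column.
Seq((1, Seq((1, "a")))).toDF("id", "arr")
  .write.mode("overwrite").orc("/tmp/nested_orc")

// With the flag off (the default) the scan falls back to the row-based
// reader; with it on, the whole schema passes supportColumnarReads and
// the scan node reports itself as columnar.
spark.conf.set(SQLConf.ORC_VECTORIZED_READER_NESTED_COLUMN_ENABLED.key, "true")
val plan = spark.read.orc("/tmp/nested_orc").queryExecution.executedPlan
assert(plan.find(_.supportsColumnar).isDefined)
```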
s"$x" * 200)), + (s"$x" * 3) -> (x * 0.3, (x, s"$x" * 300))) + (x, stringColumn, structColumn, arrayColumn, mapColumn) + }.toDF("int_col", "string_col", "struct_col", "array_col", "map_col") + df.write.format("orc").save(path) + + withSQLConf(SQLConf.ORC_VECTORIZED_READER_NESTED_COLUMN_ENABLED.key -> "true") { + val readDf = spark.read.orc(path) + val vectorizationEnabled = readDf.queryExecution.executedPlan.find { + case scan @ (_: FileSourceScanExec | _: BatchScanExec) => scan.supportsColumnar + case _ => false + }.isDefined + assert(vectorizationEnabled) + checkAnswer(readDf, df) + } + } + } } class OrcV1QuerySuite extends OrcQuerySuite { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcSourceSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcSourceSuite.scala index 9acf59cbd94e..fb75ea1dd43e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcSourceSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcSourceSuite.scala @@ -33,7 +33,6 @@ import org.scalatest.BeforeAndAfterAll import org.apache.spark.{SPARK_VERSION_SHORT, SparkException} import org.apache.spark.sql.{Row, SPARK_VERSION_METADATA_KEY} -import org.apache.spark.sql.execution.FileSourceScanExec import org.apache.spark.sql.execution.datasources.{CommonFileDataSourceSuite, SchemaMergeUtils} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSparkSession @@ -553,7 +552,6 @@ abstract class OrcSuite extends OrcTest with BeforeAndAfterAll with CommonFileDa } class OrcSourceSuite extends OrcSuite with SharedSparkSession { - import testImplicits._ protected override def beforeAll(): Unit = { super.beforeAll() @@ -617,33 +615,6 @@ class OrcSourceSuite extends OrcSuite with SharedSparkSession { } } - test("SPARK-34862: Support ORC vectorized reader for nested column") { - withTempPath { dir => - val path = dir.getCanonicalPath - val df = spark.range(10).map { x => - val stringColumn = s"$x" * 10 - val structColumn = (x, s"$x" * 100) - val arrayColumn = (0 until 5).map(i => (x + i, s"$x" * 5)) - val mapColumn = Map( - s"$x" -> (x * 0.1, (x, s"$x" * 100)), - (s"$x" * 2) -> (x * 0.2, (x, s"$x" * 200)), - (s"$x" * 3) -> (x * 0.3, (x, s"$x" * 300))) - (x, stringColumn, structColumn, arrayColumn, mapColumn) - }.toDF("int_col", "string_col", "struct_col", "array_col", "map_col") - df.write.format("orc").save(path) - - withSQLConf(SQLConf.ORC_VECTORIZED_READER_NESTED_COLUMN_ENABLED.key -> "true") { - val readDf = spark.read.orc(path) - val vectorizationEnabled = readDf.queryExecution.executedPlan.find { - case scan: FileSourceScanExec => scan.supportsColumnar - case _ => false - }.isDefined - assert(vectorizationEnabled) - checkAnswer(readDf, df) - } - } - } - test("SPARK-34897: Support reconcile schemas based on index after nested column pruning") { withTable("t1") { spark.sql(