diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala
index efa4f3f166d9..e37f2283e00c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFileFormat.scala
@@ -310,6 +310,9 @@ class ParquetFileFormat
     hadoopConf.set(
       SQLConf.SESSION_LOCAL_TIMEZONE.key,
       sparkSession.sessionState.conf.sessionLocalTimeZone)
+    hadoopConf.setBoolean(
+      SQLConf.NESTED_SCHEMA_PRUNING_ENABLED.key,
+      sparkSession.sessionState.conf.nestedSchemaPruningEnabled)
     hadoopConf.setBoolean(
       SQLConf.CASE_SENSITIVE.key,
       sparkSession.sessionState.conf.caseSensitiveAnalysis)
@@ -424,11 +427,12 @@ class ParquetFileFormat
       } else {
         logDebug(s"Falling back to parquet-mr")
         // ParquetRecordReader returns UnsafeRow
+        val readSupport = new ParquetReadSupport(convertTz, enableVectorizedReader = false)
         val reader = if (pushed.isDefined && enableRecordFilter) {
           val parquetFilter = FilterCompat.get(pushed.get, null)
-          new ParquetRecordReader[UnsafeRow](new ParquetReadSupport(convertTz), parquetFilter)
+          new ParquetRecordReader[UnsafeRow](readSupport, parquetFilter)
         } else {
-          new ParquetRecordReader[UnsafeRow](new ParquetReadSupport(convertTz))
+          new ParquetRecordReader[UnsafeRow](readSupport)
         }
         val iter = new RecordReaderIterator(reader)
         // SPARK-23457 Register a task completion lister before `initialization`.
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadSupport.scala
index 3319e73f2b31..df7766520290 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadSupport.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetReadSupport.scala
@@ -49,15 +49,16 @@ import org.apache.spark.sql.types._
  * Due to this reason, we no longer rely on [[ReadContext]] to pass requested schema from [[init()]]
  * to [[prepareForRead()]], but use a private `var` for simplicity.
  */
-private[parquet] class ParquetReadSupport(val convertTz: Option[TimeZone])
-  extends ReadSupport[UnsafeRow] with Logging {
+private[parquet] class ParquetReadSupport(val convertTz: Option[TimeZone],
+    enableVectorizedReader: Boolean)
+  extends ReadSupport[UnsafeRow] with Logging {
   private var catalystRequestedSchema: StructType = _
 
   def this() {
     // We need a zero-arg constructor for SpecificParquetRecordReaderBase. But that is only
     // used in the vectorized reader, where we get the convertTz value directly, and the value here
     // is ignored.
-    this(None)
+    this(None, enableVectorizedReader = true)
   }
 
   /**
@@ -65,18 +66,48 @@ private[parquet] class ParquetReadSupport(val convertTz: Option[TimeZone])
    * readers. Responsible for figuring out Parquet requested schema used for column pruning.
    */
   override def init(context: InitContext): ReadContext = {
+    val conf = context.getConfiguration
     catalystRequestedSchema = {
-      val conf = context.getConfiguration
       val schemaString = conf.get(ParquetReadSupport.SPARK_ROW_REQUESTED_SCHEMA)
       assert(schemaString != null, "Parquet requested schema not set.")
       StructType.fromString(schemaString)
     }
 
-    val caseSensitive = context.getConfiguration.getBoolean(SQLConf.CASE_SENSITIVE.key,
+    val caseSensitive = conf.getBoolean(SQLConf.CASE_SENSITIVE.key,
       SQLConf.CASE_SENSITIVE.defaultValue.get)
-    val parquetRequestedSchema = ParquetReadSupport.clipParquetSchema(
-      context.getFileSchema, catalystRequestedSchema, caseSensitive)
-
+    val schemaPruningEnabled = conf.getBoolean(SQLConf.NESTED_SCHEMA_PRUNING_ENABLED.key,
+      SQLConf.NESTED_SCHEMA_PRUNING_ENABLED.defaultValue.get)
+    val parquetFileSchema = context.getFileSchema
+    val parquetClippedSchema = ParquetReadSupport.clipParquetSchema(parquetFileSchema,
+      catalystRequestedSchema, caseSensitive)
+
+    // We pass two schemas to ParquetRecordMaterializer:
+    // - parquetRequestedSchema: the schema of the file data we want to read
+    // - catalystRequestedSchema: the schema of the rows we want to return
+    // The reader is responsible for reconciling the differences between the two.
+    val parquetRequestedSchema = if (schemaPruningEnabled && !enableVectorizedReader) {
+      // The parquet-mr reader requires that parquetRequestedSchema include only those fields
+      // present in the underlying parquetFileSchema. Therefore, we intersect the
+      // parquetClippedSchema with the parquetFileSchema.
+      ParquetReadSupport.intersectParquetGroups(parquetClippedSchema, parquetFileSchema)
+        .map(groupType => new MessageType(groupType.getName, groupType.getFields))
+        .getOrElse(ParquetSchemaConverter.EMPTY_MESSAGE)
+    } else {
+      // Spark's vectorized reader currently supports only atomic types. It also skips fields
+      // in parquetRequestedSchema which are not present in the file.
+      parquetClippedSchema
+    }
+    logDebug(
+      s"""Going to read the following fields from the Parquet file with the following schema:
+         |Parquet file schema:
+         |$parquetFileSchema
+         |Parquet clipped schema:
+         |$parquetClippedSchema
+         |Parquet requested schema:
+         |$parquetRequestedSchema
+         |Catalyst requested schema:
+         |${catalystRequestedSchema.treeString}
+       """.stripMargin)
     new ReadContext(parquetRequestedSchema, Map.empty[String, String].asJava)
   }
 
@@ -90,19 +121,7 @@ private[parquet] class ParquetReadSupport(val convertTz: Option[TimeZone])
       keyValueMetaData: JMap[String, String],
       fileSchema: MessageType,
       readContext: ReadContext): RecordMaterializer[UnsafeRow] = {
-    log.debug(s"Preparing for read Parquet file with message type: $fileSchema")
     val parquetRequestedSchema = readContext.getRequestedSchema
-
-    logInfo {
-      s"""Going to read the following fields from the Parquet file:
-         |
-         |Parquet form:
-         |$parquetRequestedSchema
-         |Catalyst form:
-         |$catalystRequestedSchema
-       """.stripMargin
-    }
-
     new ParquetRecordMaterializer(
       parquetRequestedSchema,
       ParquetReadSupport.expandUDT(catalystRequestedSchema),
@@ -322,6 +341,35 @@ private[parquet] object ParquetReadSupport {
     }
   }
 
+  /**
+   * Computes the structural intersection between two Parquet group types.
+   * This is used to create the requested schema for the ReadContext of the parquet-mr reader.
+   * The parquet-mr reader does not support nested field access to a non-existent field,
+   * whereas the Parquet library supports reading a non-existent field via regular field access.
+   */
+  private def intersectParquetGroups(
+      groupType1: GroupType, groupType2: GroupType): Option[GroupType] = {
+    val fields =
+      groupType1.getFields.asScala
+        .filter(field => groupType2.containsField(field.getName))
+        .flatMap {
+          case field1: GroupType =>
+            val field2 = groupType2.getType(field1.getName)
+            if (field2.isPrimitive) {
+              None
+            } else {
+              intersectParquetGroups(field1, field2.asGroupType)
+            }
+          case field1 => Some(field1)
+        }
+
+    if (fields.nonEmpty) {
+      Some(groupType1.withNewFields(fields.asJava))
+    } else {
+      None
+    }
+  }
+
   def expandUDT(schema: StructType): StructType = {
     def expand(dataType: DataType): DataType = {
       dataType match {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRowConverter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRowConverter.scala
index 004a96d13413..b772b6b77d1c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRowConverter.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetRowConverter.scala
@@ -130,8 +130,8 @@ private[parquet] class ParquetRowConverter(
   extends ParquetGroupConverter(updater) with Logging {
 
   assert(
-    parquetType.getFieldCount == catalystType.length,
-    s"""Field counts of the Parquet schema and the Catalyst schema don't match:
+    parquetType.getFieldCount <= catalystType.length,
+    s"""Field count of the Parquet schema is greater than the field count of the Catalyst schema:
        |
        |Parquet schema:
        |$parquetType
@@ -182,10 +182,11 @@ private[parquet] class ParquetRowConverter(
 
   // Converters for each field.
   private val fieldConverters: Array[Converter with HasParentContainerUpdater] = {
-    parquetType.getFields.asScala.zip(catalystType).zipWithIndex.map {
-      case ((parquetFieldType, catalystField), ordinal) =>
-        // Converted field value should be set to the `ordinal`-th cell of `currentRow`
-        newConverter(parquetFieldType, catalystField.dataType, new RowUpdater(currentRow, ordinal))
+    parquetType.getFields.asScala.map { parquetField =>
+      val fieldIndex = catalystType.fieldIndex(parquetField.getName)
+      val catalystField = catalystType(fieldIndex)
+      // Converted field value should be set to the `fieldIndex`-th cell of `currentRow`
+      newConverter(parquetField, catalystField.dataType, new RowUpdater(currentRow, fieldIndex))
     }.toArray
   }
 
@@ -193,7 +194,7 @@ private[parquet] class ParquetRowConverter(
 
   override def end(): Unit = {
     var i = 0
-    while (i < currentRow.numFields) {
+    while (i < fieldConverters.length) {
       fieldConverters(i).updater.end()
       i += 1
     }
@@ -203,10 +204,14 @@ private[parquet] class ParquetRowConverter(
   override def start(): Unit = {
     var i = 0
     while (i < currentRow.numFields) {
-      fieldConverters(i).updater.start()
       currentRow.setNullAt(i)
       i += 1
     }
+    i = 0
+    while (i < fieldConverters.length) {
+      fieldConverters(i).updater.start()
+      i += 1
+    }
   }
 
   /**
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/SchemaPruningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/SchemaPruningSuite.scala
index b0314e621e6e..22317fe8d13a 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/SchemaPruningSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/SchemaPruningSuite.scala
@@ -135,7 +135,7 @@ abstract class SchemaPruningSuite
       Row("X.", 1) :: Row("Y.", 1) :: Row(null, 2) :: Row(null, 2) :: Nil)
   }
 
-  ignore("partial schema intersection - select missing subfield") {
+  testSchemaPruning("partial schema intersection - select missing subfield") {
     val query = sql("select name.middle, address from contacts where p=2")
     checkScan(query, "struct<name:struct<middle:string>,address:string>")
     checkAnswer(query.orderBy("id"),