diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NestedColumnAliasing.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NestedColumnAliasing.scala index ea85014a37bd8..fd431f1e53ff9 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NestedColumnAliasing.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NestedColumnAliasing.scala @@ -106,8 +106,11 @@ object NestedColumnAliasing { * * 1. ExtractValue -> Alias: A new alias is created for each nested field. * 2. ExprId -> Seq[Alias]: A reference attribute has multiple aliases pointing it. + * + * @param exprList a sequence of expressions that possibly access nested fields. + * @param skipAttrs a set of attributes we do not want to replace nested fields within. */ - def getAliasSubMap(exprList: Seq[Expression]) + def getAliasSubMap(exprList: Seq[Expression], skipAttrs: AttributeSet = AttributeSet.empty) : Option[(Map[ExtractValue, Alias], Map[ExprId, Seq[Alias]])] = { val (nestedFieldReferences, otherRootReferences) = exprList.flatMap(collectRootReferenceAndExtractValue).partition { @@ -116,7 +119,10 @@ object NestedColumnAliasing { } val aliasSub = nestedFieldReferences.asInstanceOf[Seq[ExtractValue]] - .filter(!_.references.subsetOf(AttributeSet(otherRootReferences))) + .filter { nestedRef => + !nestedRef.references.subsetOf(AttributeSet(otherRootReferences)) && + !nestedRef.references.subsetOf(skipAttrs) + } .groupBy(_.references.head) .flatMap { case (attr, nestedFields: Seq[ExtractValue]) => // Each expression can contain multiple nested fields. @@ -179,7 +185,10 @@ object GeneratorNestedColumnAliasing { case g: Generate if SQLConf.get.nestedSchemaPruningEnabled && canPruneGenerator(g.generator) => - NestedColumnAliasing.getAliasSubMap(g.generator.children).map { + // For the child outputs required by the operator on top of `Generate`, we do not want + // to prune it. 
+ val requiredAttrs = AttributeSet(g.requiredChildOutput) + NestedColumnAliasing.getAliasSubMap(g.generator.children, requiredAttrs).map { case (nestedFieldToAlias, attrToAliases) => pruneGenerate(g, nestedFieldToAlias, attrToAliases) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/SchemaPruningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/SchemaPruningSuite.scala index 5977e867f788a..17a1eb85da555 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/SchemaPruningSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/SchemaPruningSuite.scala @@ -333,6 +333,14 @@ abstract class SchemaPruningSuite } } + testSchemaPruning("select explode of nested field of array of struct and " + + "all remaining nested fields") { + val query = spark.table("contacts") + .select(explode(col("friends.first")), col("friends.middle"), col("friends.last")) + checkScan(query, "struct<friends:array<struct<first:string,middle:string,last:string>>>") + checkAnswer(query, Row("Susan", Array("Z."), Array("Smith")) :: Nil) + } + protected def testSchemaPruning(testName: String)(testThunk: => Unit): Unit = { test(s"Spark vectorized reader - without partition data column - $testName") { withSQLConf(vectorizedReaderEnabledKey -> "true") {