diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NestedColumnAliasing.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NestedColumnAliasing.scala index 43a6006f9b5c0..ea85014a37bd8 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NestedColumnAliasing.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NestedColumnAliasing.scala @@ -155,6 +155,53 @@ object NestedColumnAliasing { case MapType(keyType, valueType, _) => totalFieldNum(keyType) + totalFieldNum(valueType) case _ => 1 // UDT and others } +} + +/** + * This prunes unnecessary nested columns from `Generate` and optional `Project` on top + * of it. + */ +object GeneratorNestedColumnAliasing { + def unapply(plan: LogicalPlan): Option[LogicalPlan] = plan match { + // If either `nestedPruningOnExpressions` or `nestedSchemaPruningEnabled` is enabled, we + // need to prune nested columns through Project and under Generate. The difference is + // when `nestedSchemaPruningEnabled` is on, nested columns will be pruned further at + // file format readers if it is supported. + case Project(projectList, g: Generate) if (SQLConf.get.nestedPruningOnExpressions || + SQLConf.get.nestedSchemaPruningEnabled) && canPruneGenerator(g.generator) => + // On top of `Generate`, a `Project` that might have nested column accessors. + // We try to get alias maps for both project list and generator's children expressions. 
+ NestedColumnAliasing.getAliasSubMap(projectList ++ g.generator.children).map { + case (nestedFieldToAlias, attrToAliases) => + val newChild = pruneGenerate(g, nestedFieldToAlias, attrToAliases) + Project(NestedColumnAliasing.getNewProjectList(projectList, nestedFieldToAlias), newChild) + } + + case g: Generate if SQLConf.get.nestedSchemaPruningEnabled && + canPruneGenerator(g.generator) => + NestedColumnAliasing.getAliasSubMap(g.generator.children).map { + case (nestedFieldToAlias, attrToAliases) => + pruneGenerate(g, nestedFieldToAlias, attrToAliases) + } + + case _ => + None + } + + private def pruneGenerate( + g: Generate, + nestedFieldToAlias: Map[ExtractValue, Alias], + attrToAliases: Map[ExprId, Seq[Alias]]): LogicalPlan = { + val newGenerator = g.generator.transform { + case f: ExtractValue if nestedFieldToAlias.contains(f) => + nestedFieldToAlias(f).toAttribute + }.asInstanceOf[Generator] + + // Defer updating `Generate.unrequiredChildIndex` to next round of `ColumnPruning`. + val newGenerate = g.copy(generator = newGenerator) + + NestedColumnAliasing.replaceChildrenWithAliases(newGenerate, attrToAliases) + } /** * This is a while-list for pruning nested fields at `Generator`. 
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 935d62015afa1..0fdf6b022d885 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -598,31 +598,24 @@ object ColumnPruning extends Rule[LogicalPlan] { s.copy(child = prunedChild(child, s.references)) // prune unrequired references - case p @ Project(_, g: Generate) if p.references != g.outputSet => - val requiredAttrs = p.references -- g.producedAttributes ++ g.generator.references - val newChild = prunedChild(g.child, requiredAttrs) - val unrequired = g.generator.references -- p.references - val unrequiredIndices = newChild.output.zipWithIndex.filter(t => unrequired.contains(t._1)) - .map(_._2) - p.copy(child = g.copy(child = newChild, unrequiredChildIndex = unrequiredIndices)) - - // prune unrequired nested fields - case p @ Project(projectList, g: Generate) if SQLConf.get.nestedPruningOnExpressions && - NestedColumnAliasing.canPruneGenerator(g.generator) => - NestedColumnAliasing.getAliasSubMap(projectList ++ g.generator.children).map { - case (nestedFieldToAlias, attrToAliases) => - val newGenerator = g.generator.transform { - case f: ExtractValue if nestedFieldToAlias.contains(f) => - nestedFieldToAlias(f).toAttribute - }.asInstanceOf[Generator] - - // Defer updating `Generate.unrequiredChildIndex` to next round of `ColumnPruning`. 
- val newGenerate = g.copy(generator = newGenerator) - - val newChild = NestedColumnAliasing.replaceChildrenWithAliases(newGenerate, attrToAliases) - - Project(NestedColumnAliasing.getNewProjectList(projectList, nestedFieldToAlias), newChild) - }.getOrElse(p) + case p @ Project(_, g: Generate) => + val currP = if (p.references != g.outputSet) { + val requiredAttrs = p.references -- g.producedAttributes ++ g.generator.references + val newChild = prunedChild(g.child, requiredAttrs) + val unrequired = g.generator.references -- p.references + val unrequiredIndices = newChild.output.zipWithIndex.filter(t => unrequired.contains(t._1)) + .map(_._2) + p.copy(child = g.copy(child = newChild, unrequiredChildIndex = unrequiredIndices)) + } else { + p + } + // If we can prune nested columns on Project + Generate, do it now. + // Otherwise, by transforming down to Generate, they could be pruned individually, + // which would leave the nested column accessors in the top Project unable to resolve. + GeneratorNestedColumnAliasing.unapply(currP).getOrElse(currP) + + // prune unrequired nested fields from `Generate`. + case GeneratorNestedColumnAliasing(p) => p // Eliminate unneeded attributes from right side of a Left Existence Join. 
case j @ Join(_, right, LeftExistence(_), _, _) => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/SchemaPruningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/SchemaPruningSuite.scala index a3d4905e82cee..5977e867f788a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/SchemaPruningSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/SchemaPruningSuite.scala @@ -301,6 +301,38 @@ abstract class SchemaPruningSuite checkAnswer(query, Row("Y.", 1) :: Row("X.", 1) :: Row(null, 2) :: Row(null, 2) :: Nil) } + testSchemaPruning("select explode of nested field of array of struct") { + // Config combinations + val configs = Seq((true, true), (true, false), (false, true), (false, false)) + + configs.foreach { case (nestedPruning, nestedPruningOnExpr) => + withSQLConf( + SQLConf.NESTED_SCHEMA_PRUNING_ENABLED.key -> nestedPruning.toString, + SQLConf.NESTED_PRUNING_ON_EXPRESSIONS.key -> nestedPruningOnExpr.toString) { + val query1 = spark.table("contacts") + .select(explode(col("friends.first"))) + if (nestedPruning) { + // If `NESTED_SCHEMA_PRUNING_ENABLED` is enabled, + // even disabling `NESTED_PRUNING_ON_EXPRESSIONS`, + // nested schema is still pruned at scan node. + checkScan(query1, "struct>>") + } else { + checkScan(query1, "struct>>") + } + checkAnswer(query1, Row("Susan") :: Nil) + + val query2 = spark.table("contacts") + .select(explode(col("friends.first")), col("friends.middle")) + if (nestedPruning) { + checkScan(query2, "struct>>") + } else { + checkScan(query2, "struct>>") + } + checkAnswer(query2, Row("Susan", Array("Z.")) :: Nil) + } + } + } + protected def testSchemaPruning(testName: String)(testThunk: => Unit): Unit = { test(s"Spark vectorized reader - without partition data column - $testName") { withSQLConf(vectorizedReaderEnabledKey -> "true") {