diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NestedColumnAliasing.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NestedColumnAliasing.scala index ea85014a37bd8..43a6006f9b5c0 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NestedColumnAliasing.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NestedColumnAliasing.scala @@ -155,53 +155,6 @@ object NestedColumnAliasing { case MapType(keyType, valueType, _) => totalFieldNum(keyType) + totalFieldNum(valueType) case _ => 1 // UDT and others } -} - -/** - * This prunes unnessary nested columns from `Generate` and optional `Project` on top - * of it. - */ -object GeneratorNestedColumnAliasing { - def unapply(plan: LogicalPlan): Option[LogicalPlan] = plan match { - // Either `nestedPruningOnExpressions` or `nestedSchemaPruningEnabled` is enabled, we - // need to prune nested columns through Project and under Generate. The difference is - // when `nestedSchemaPruningEnabled` is on, nested columns will be pruned further at - // file format readers if it is supported. - case Project(projectList, g: Generate) if (SQLConf.get.nestedPruningOnExpressions || - SQLConf.get.nestedSchemaPruningEnabled) && canPruneGenerator(g.generator) => - // On top on `Generate`, a `Project` that might have nested column accessors. - // We try to get alias maps for both project list and generator's children expressions. - NestedColumnAliasing.getAliasSubMap(projectList ++ g.generator.children).map { - case (nestedFieldToAlias, attrToAliases) => - val newChild = pruneGenerate(g, nestedFieldToAlias, attrToAliases) - Project(NestedColumnAliasing.getNewProjectList(projectList, nestedFieldToAlias), newChild) - } - - case g: Generate if SQLConf.get.nestedSchemaPruningEnabled && - canPruneGenerator(g.generator) => - NestedColumnAliasing.getAliasSubMap(g.generator.children).map { - case (nestedFieldToAlias, attrToAliases) => - pruneGenerate(g, nestedFieldToAlias, attrToAliases) - } - - case _ => - None - } - - private def pruneGenerate( - g: Generate, - nestedFieldToAlias: Map[ExtractValue, Alias], - attrToAliases: Map[ExprId, Seq[Alias]]): LogicalPlan = { - val newGenerator = g.generator.transform { - case f: ExtractValue if nestedFieldToAlias.contains(f) => - nestedFieldToAlias(f).toAttribute - }.asInstanceOf[Generator] - - // Defer updating `Generate.unrequiredChildIndex` to next round of `ColumnPruning`. - val newGenerate = g.copy(generator = newGenerator) - - NestedColumnAliasing.replaceChildrenWithAliases(newGenerate, attrToAliases) - } /** * This is a while-list for pruning nested fields at `Generator`. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 0fdf6b022d885..935d62015afa1 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -598,24 +598,31 @@ object ColumnPruning extends Rule[LogicalPlan] { s.copy(child = prunedChild(child, s.references)) // prune unrequired references - case p @ Project(_, g: Generate) => - val currP = if (p.references != g.outputSet) { - val requiredAttrs = p.references -- g.producedAttributes ++ g.generator.references - val newChild = prunedChild(g.child, requiredAttrs) - val unrequired = g.generator.references -- p.references - val unrequiredIndices = newChild.output.zipWithIndex.filter(t => unrequired.contains(t._1)) - .map(_._2) - p.copy(child = g.copy(child = newChild, unrequiredChildIndex = unrequiredIndices)) - } else { - p - } - // If we can prune nested column on Project + Generate, do it now. - // Otherwise by transforming down to Generate, it could be pruned individually, - // and causes nested column on top Project unable to resolve. - GeneratorNestedColumnAliasing.unapply(currP).getOrElse(currP) - - // prune unrequired nested fields from `Generate`. - case GeneratorNestedColumnAliasing(p) => p + case p @ Project(_, g: Generate) if p.references != g.outputSet => + val requiredAttrs = p.references -- g.producedAttributes ++ g.generator.references + val newChild = prunedChild(g.child, requiredAttrs) + val unrequired = g.generator.references -- p.references + val unrequiredIndices = newChild.output.zipWithIndex.filter(t => unrequired.contains(t._1)) + .map(_._2) + p.copy(child = g.copy(child = newChild, unrequiredChildIndex = unrequiredIndices)) + + // prune unrequired nested fields + case p @ Project(projectList, g: Generate) if SQLConf.get.nestedPruningOnExpressions && + NestedColumnAliasing.canPruneGenerator(g.generator) => + NestedColumnAliasing.getAliasSubMap(projectList ++ g.generator.children).map { + case (nestedFieldToAlias, attrToAliases) => + val newGenerator = g.generator.transform { + case f: ExtractValue if nestedFieldToAlias.contains(f) => + nestedFieldToAlias(f).toAttribute + }.asInstanceOf[Generator] + + // Defer updating `Generate.unrequiredChildIndex` to next round of `ColumnPruning`. + val newGenerate = g.copy(generator = newGenerator) + + val newChild = NestedColumnAliasing.replaceChildrenWithAliases(newGenerate, attrToAliases) + + Project(NestedColumnAliasing.getNewProjectList(projectList, nestedFieldToAlias), newChild) + }.getOrElse(p) // Eliminate unneeded attributes from right side of a Left Existence Join. case j @ Join(_, right, LeftExistence(_), _, _) => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/SchemaPruningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/SchemaPruningSuite.scala index 5977e867f788a..a3d4905e82cee 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/SchemaPruningSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/SchemaPruningSuite.scala @@ -301,38 +301,6 @@ abstract class SchemaPruningSuite checkAnswer(query, Row("Y.", 1) :: Row("X.", 1) :: Row(null, 2) :: Row(null, 2) :: Nil) } - testSchemaPruning("select explode of nested field of array of struct") { - // Config combinations - val configs = Seq((true, true), (true, false), (false, true), (false, false)) - - configs.foreach { case (nestedPruning, nestedPruningOnExpr) => - withSQLConf( - SQLConf.NESTED_SCHEMA_PRUNING_ENABLED.key -> nestedPruning.toString, - SQLConf.NESTED_PRUNING_ON_EXPRESSIONS.key -> nestedPruningOnExpr.toString) { - val query1 = spark.table("contacts") - .select(explode(col("friends.first"))) - if (nestedPruning) { - // If `NESTED_SCHEMA_PRUNING_ENABLED` is enabled, - // even disabling `NESTED_PRUNING_ON_EXPRESSIONS`, - // nested schema is still pruned at scan node. - checkScan(query1, "struct>>") - } else { - checkScan(query1, "struct>>") - } - checkAnswer(query1, Row("Susan") :: Nil) - - val query2 = spark.table("contacts") - .select(explode(col("friends.first")), col("friends.middle")) - if (nestedPruning) { - checkScan(query2, "struct>>") - } else { - checkScan(query2, "struct>>") - } - checkAnswer(query2, Row("Susan", Array("Z.")) :: Nil) - } - } - } - protected def testSchemaPruning(testName: String)(testThunk: => Unit): Unit = { test(s"Spark vectorized reader - without partition data column - $testName") { withSQLConf(vectorizedReaderEnabledKey -> "true") {