diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NestedColumnAliasing.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NestedColumnAliasing.scala
index 95aefb6422d67..43a6006f9b5c0 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NestedColumnAliasing.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NestedColumnAliasing.scala
@@ -54,7 +54,7 @@ object NestedColumnAliasing {
   /**
    * Return a replaced project list.
    */
-  private def getNewProjectList(
+  def getNewProjectList(
       projectList: Seq[NamedExpression],
       nestedFieldToAlias: Map[ExtractValue, Alias]): Seq[NamedExpression] = {
     projectList.map(_.transform {
@@ -66,7 +66,7 @@ object NestedColumnAliasing {
   /**
    * Return a plan with new children replaced with aliases.
    */
-  private def replaceChildrenWithAliases(
+  def replaceChildrenWithAliases(
       plan: LogicalPlan,
       attrToAliases: Map[ExprId, Seq[Alias]]): LogicalPlan = {
     plan.withNewChildren(plan.children.map { plan =>
@@ -107,10 +107,10 @@ object NestedColumnAliasing {
    *    1. ExtractValue -> Alias: A new alias is created for each nested field.
    *    2. ExprId -> Seq[Alias]: A reference attribute has multiple aliases pointing it.
    */
-  private def getAliasSubMap(projectList: Seq[NamedExpression])
+  def getAliasSubMap(exprList: Seq[Expression])
     : Option[(Map[ExtractValue, Alias], Map[ExprId, Seq[Alias]])] = {
     val (nestedFieldReferences, otherRootReferences) =
-      projectList.flatMap(collectRootReferenceAndExtractValue).partition {
+      exprList.flatMap(collectRootReferenceAndExtractValue).partition {
         case _: ExtractValue => true
         case _ => false
       }
@@ -155,4 +155,15 @@ object NestedColumnAliasing {
     case MapType(keyType, valueType, _) => totalFieldNum(keyType) + totalFieldNum(valueType)
     case _ => 1 // UDT and others
   }
+
+  /**
+   * This is a whitelist for pruning nested fields at `Generator`.
+   */
+  def canPruneGenerator(g: Generator): Boolean = g match {
+    case _: Explode => true
+    case _: Stack => true
+    case _: PosExplode => true
+    case _: Inline => true
+    case _ => false
+  }
 }
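The now-public helpers and the `canPruneGenerator` whitelist can be exercised directly from the Catalyst DSL. A minimal REPL-style sketch (the column and field names are made up for illustration): whitelisted generators merely re-emit parts of their input, so rewriting that input to a pruned alias is safe, while other generators such as `json_tuple` are conservatively skipped.

```scala
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions.{Explode, JsonTuple}
import org.apache.spark.sql.catalyst.optimizer.NestedColumnAliasing

// Explode only forwards elements of its child, so its input can be pruned.
assert(NestedColumnAliasing.canPruneGenerator(Explode('c.getField("e"))))
// json_tuple is not on the whitelist and is left untouched.
assert(!NestedColumnAliasing.canPruneGenerator(JsonTuple(Seq('c.getField("j"), "f1"))))
```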
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index c99d2c06fac63..b59cbd802b86b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -588,6 +588,24 @@ object ColumnPruning extends Rule[LogicalPlan] {
         .map(_._2)
       p.copy(child = g.copy(child = newChild, unrequiredChildIndex = unrequiredIndices))

+    // prune unrequired nested fields
+    case p @ Project(projectList, g: Generate) if SQLConf.get.nestedPruningOnExpressions &&
+        NestedColumnAliasing.canPruneGenerator(g.generator) =>
+      NestedColumnAliasing.getAliasSubMap(projectList ++ g.generator.children).map {
+        case (nestedFieldToAlias, attrToAliases) =>
+          val newGenerator = g.generator.transform {
+            case f: ExtractValue if nestedFieldToAlias.contains(f) =>
+              nestedFieldToAlias(f).toAttribute
+          }.asInstanceOf[Generator]
+
+          // Defer updating `Generate.unrequiredChildIndex` to next round of `ColumnPruning`.
+          val newGenerate = g.copy(generator = newGenerator)
+
+          val newChild = NestedColumnAliasing.replaceChildrenWithAliases(newGenerate, attrToAliases)
+
+          Project(NestedColumnAliasing.getNewProjectList(projectList, nestedFieldToAlias), newChild)
+      }.getOrElse(p)
+
     // Eliminate unneeded attributes from right side of a Left Existence Join.
     case j @ Join(_, right, LeftExistence(_), _, _) =>
       j.copy(right = prunedChild(right, j.references))
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
index 57f5128fd4fbe..fbdb1c5f957d9 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala
@@ -1656,6 +1656,16 @@ object SQLConf {
     .booleanConf
     .createWithDefault(false)

+  val NESTED_PRUNING_ON_EXPRESSIONS =
+    buildConf("spark.sql.optimizer.expression.nestedPruning.enabled")
+      .internal()
+      .doc("Prune nested fields from expressions in an operator which are unnecessary in " +
+        "satisfying a query. Note that this optimization doesn't prune nested fields from " +
+        "physical data source scanning. For pruning nested fields from scanning, please use " +
+        "`spark.sql.optimizer.nestedSchemaPruning.enabled` config.")
+      .booleanConf
+      .createWithDefault(false)
+
   val TOP_K_SORT_FALLBACK_THRESHOLD =
     buildConf("spark.sql.execution.topKSortFallbackThreshold")
       .internal()
@@ -2315,6 +2325,8 @@ class SQLConf extends Serializable with Logging {
   def serializerNestedSchemaPruningEnabled: Boolean =
     getConf(SERIALIZER_NESTED_SCHEMA_PRUNING_ENABLED)

+  def nestedPruningOnExpressions: Boolean = getConf(NESTED_PRUNING_ON_EXPRESSIONS)
+
   def csvColumnPruning: Boolean = getConf(SQLConf.CSV_PARSER_COLUMN_PRUNING)

   def legacySizeOfNull: Boolean = getConf(SQLConf.LEGACY_SIZE_OF_NULL)
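End to end, the new `ColumnPruning` case plus the flag behave as sketched below (assumes a local `SparkSession` and a hypothetical Parquet path; `c` stands for a struct column with an array field `e`):

```scala
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().master("local[*]").getOrCreate()
spark.conf.set("spark.sql.optimizer.expression.nestedPruning.enabled", "true")

spark.read.parquet("/tmp/nested")   // hypothetical dataset with c: struct<..., e: array<...>>
  .selectExpr("explode(c.e) AS v")
  .explain(true)
// Expected: Generate is fed by an alias of c.e alone. As the doc string above
// notes, this flag does not prune the physical scan; nestedSchemaPruning covers that.
```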
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala
index 78ae131328644..75ff07637fccc 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.optimizer

 import scala.reflect.runtime.universe.TypeTag

+import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
 import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.dsl.plans._
 import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
@@ -26,7 +27,8 @@ import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.{Inner, PlanTest}
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.rules.RuleExecutor
-import org.apache.spark.sql.types.StringType
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.types.{StringType, StructType}

 class ColumnPruningSuite extends PlanTest {

@@ -101,6 +103,81 @@ class ColumnPruningSuite extends PlanTest {
     comparePlans(optimized, correctAnswer)
   }

+  test("Nested column pruning for Generate") {
+    def runTest(
+        origGenerator: Generator,
+        replacedGenerator: Seq[String] => Generator,
+        aliasedExprs: Seq[String] => Seq[Expression],
+        unrequiredChildIndex: Seq[Int],
+        generatorOutputNames: Seq[String]) {
+      withSQLConf(SQLConf.NESTED_PRUNING_ON_EXPRESSIONS.key -> "true") {
+        val structType = StructType.fromDDL("d double, e array<string>, f double, g double, " +
+          "h array<struct<h1: int, h2: double>>")
+        val input = LocalRelation('a.int, 'b.int, 'c.struct(structType))
+        val generatorOutputs = generatorOutputNames.map(UnresolvedAttribute(_))
+
+        val selectedExprs = Seq(UnresolvedAttribute("a"), 'c.getField("d")) ++
+          generatorOutputs
+
+        val query =
+          input
+            .generate(origGenerator, outputNames = generatorOutputNames)
+            .select(selectedExprs: _*)
+            .analyze
+
+        val optimized = Optimize.execute(query)
+
+        val aliases = NestedColumnAliasingSuite.collectGeneratedAliases(optimized)
+
+        val selectedFields = UnresolvedAttribute("a") +: aliasedExprs(aliases)
+        val finalSelectedExprs = Seq(UnresolvedAttribute("a"), $"${aliases(0)}".as("c.d")) ++
+          generatorOutputs
+
+        val correctAnswer =
+          input
+            .select(selectedFields: _*)
+            .generate(replacedGenerator(aliases),
+              unrequiredChildIndex = unrequiredChildIndex,
+              outputNames = generatorOutputNames)
+            .select(finalSelectedExprs: _*)
+            .analyze
+
+        comparePlans(optimized, correctAnswer)
+      }
+    }
+
+    runTest(
+      Explode('c.getField("e")),
+      aliases => Explode($"${aliases(1)}".as("c.e")),
+      aliases => Seq('c.getField("d").as(aliases(0)), 'c.getField("e").as(aliases(1))),
+      Seq(2),
+      Seq("explode")
+    )
+    runTest(Stack(2 :: 'c.getField("f") :: 'c.getField("g") :: Nil),
+      aliases => Stack(2 :: $"${aliases(1)}".as("c.f") :: $"${aliases(2)}".as("c.g") :: Nil),
+      aliases => Seq(
+        'c.getField("d").as(aliases(0)),
+        'c.getField("f").as(aliases(1)),
+        'c.getField("g").as(aliases(2))),
+      Seq(2, 3),
+      Seq("stack")
+    )
+    runTest(
+      PosExplode('c.getField("e")),
+      aliases => PosExplode($"${aliases(1)}".as("c.e")),
+      aliases => Seq('c.getField("d").as(aliases(0)), 'c.getField("e").as(aliases(1))),
+      Seq(2),
+      Seq("pos", "explode")
+    )
+    runTest(
+      Inline('c.getField("h")),
+      aliases => Inline($"${aliases(1)}".as("c.h")),
+      aliases => Seq('c.getField("d").as(aliases(0)), 'c.getField("h").as(aliases(1))),
+      Seq(2),
+      Seq("h1", "h2")
+    )
+  }
+
   test("Column pruning for Project on Sort") {
     val input = LocalRelation('a.int, 'b.string, 'c.double)
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/NestedColumnAliasingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/NestedColumnAliasingSuite.scala
index ab2bd6dff1265..2351d8321c5f3 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/NestedColumnAliasingSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/NestedColumnAliasingSuite.scala
@@ -29,6 +29,8 @@ import org.apache.spark.sql.types.{StringType, StructField, StructType}

 class NestedColumnAliasingSuite extends SchemaPruningTest {

+  import NestedColumnAliasingSuite._
+
   object Optimize extends RuleExecutor[LogicalPlan] {
     val batches = Batch("Nested column pruning", FixedPoint(100),
       ColumnPruning,
@@ -264,9 +266,10 @@ class NestedColumnAliasingSuite extends SchemaPruningTest {
       .analyze
     comparePlans(optimized, expected)
   }
+}

-
-  private def collectGeneratedAliases(query: LogicalPlan): ArrayBuffer[String] = {
+object NestedColumnAliasingSuite {
+  def collectGeneratedAliases(query: LogicalPlan): ArrayBuffer[String] = {
     val aliases = ArrayBuffer[String]()
     query.transformAllExpressions {
       case a @ Alias(_, name) if name.startsWith("_gen_alias_") =>
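To see the rewrite the suite asserts without running the whole test, one can apply `ColumnPruning` by hand; a REPL-style sketch along the lines of the `Explode` case (assumes `spark-catalyst` on the classpath; the schema is abbreviated):

```scala
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions.Explode
import org.apache.spark.sql.catalyst.optimizer.ColumnPruning
import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.StructType

SQLConf.get.setConfString(SQLConf.NESTED_PRUNING_ON_EXPRESSIONS.key, "true")

val input = LocalRelation('a.int, 'b.int,
  'c.struct(StructType.fromDDL("d double, e array<string>")))
val query = input
  .generate(Explode('c.getField("e")), outputNames = Seq("explode"))
  .select('a, 'c.getField("d"), 'explode)
  .analyze
// One pass aliases c.d and c.e into a Project below the Generate; the
// deferred unrequiredChildIndex update lands on the next ColumnPruning round.
println(ColumnPruning(query).treeString)
```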
diff --git a/sql/core/benchmarks/MiscBenchmark-results.txt b/sql/core/benchmarks/MiscBenchmark-results.txt
index 85acd57893655..c4ae052095656 100644
--- a/sql/core/benchmarks/MiscBenchmark-results.txt
+++ b/sql/core/benchmarks/MiscBenchmark-results.txt
@@ -105,6 +105,13 @@ generate big struct array:               Best/Avg Time(ms)    Rate(M/s)   Per Ro
 generate big struct array wholestage off           708 /  776          0.1       11803.5       1.0X
 generate big struct array wholestage on            535 /  589          0.1        8913.9       1.3X

+OpenJDK 64-Bit Server VM 1.8.0_212-b04 on Linux 3.10.0-862.3.2.el7.x86_64
+Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz
+generate big nested struct array:         Best Time(ms)   Avg Time(ms)   Stdev(ms)    Rate(M/s)   Per Row(ns)   Relative
+------------------------------------------------------------------------------------------------------------------------
+generate big nested struct array wholestage off          540            553          19          0.1        8997.4       1.0X
+generate big nested struct array wholestage on           523            554          31          0.1        8725.0       1.0X
+

 ================================================================================================
 generate regular generator
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/MiscBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/MiscBenchmark.scala
index c4662c8999e42..bafc0337bdc0e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/MiscBenchmark.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/MiscBenchmark.scala
@@ -18,6 +18,7 @@
 package org.apache.spark.sql.execution.benchmark

 import org.apache.spark.benchmark.Benchmark
+import org.apache.spark.sql.internal.SQLConf

 /**
  * Benchmark to measure whole stage codegen performance.
@@ -130,6 +131,20 @@ object MiscBenchmark extends SqlBasedBenchmark {
       df.selectExpr("*", "explode(arr) as arr_col")
         .select("col", "arr_col.*").count
     }
+
+    withSQLConf(SQLConf.NESTED_PRUNING_ON_EXPRESSIONS.key -> "true") {
+      codegenBenchmark("generate big nested struct array", M) {
+        import spark.implicits._
+        val df = spark.sparkContext.parallelize(Seq(("1",
+          Array.fill(M)({
+            val i = math.random
+            (i.toString, (i + 1).toString, (i + 2).toString, (i + 3).toString)
+          })))).toDF("col", "arr")
+          .selectExpr("col", "struct(col, arr) as st")
+          .selectExpr("col", "st.col as col1", "explode(st.arr) as arr_col")
+        df.collect()
+      }
+    }
   }
 }
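For a quick sanity check of the plan change outside the benchmark harness, a downsized version of the same query shape works (reusing the `spark` session and flag from the earlier sketch; 8 elements stand in for the benchmark's M):

```scala
import spark.implicits._

val df = Seq(("1", Array.fill(8)(math.random.toString)))  // tiny stand-in for the M-element array
  .toDF("col", "arr")
  .selectExpr("col", "struct(col, arr) as st")
  .selectExpr("col", "st.col as col1", "explode(st.arr) as arr_col")
df.explain(true)
// Expected: the Project pushed below Generate carries aliases of st.col and
// st.arr only, instead of keeping the whole st struct alive.
```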