@@ -54,7 +54,7 @@ object NestedColumnAliasing {
/**
* Return a replaced project list.
*/
-private def getNewProjectList(
+def getNewProjectList(
projectList: Seq[NamedExpression],
nestedFieldToAlias: Map[ExtractValue, Alias]): Seq[NamedExpression] = {
projectList.map(_.transform {
@@ -66,7 +66,7 @@ object NestedColumnAliasing {
/**
* Return a plan with new children replaced with aliases.
*/
-private def replaceChildrenWithAliases(
+def replaceChildrenWithAliases(
plan: LogicalPlan,
attrToAliases: Map[ExprId, Seq[Alias]]): LogicalPlan = {
plan.withNewChildren(plan.children.map { plan =>
@@ -107,10 +107,10 @@ object NestedColumnAliasing {
* 1. ExtractValue -> Alias: A new alias is created for each nested field.
* 2. ExprId -> Seq[Alias]: A reference attribute may have multiple aliases pointing to it.
*/
-private def getAliasSubMap(projectList: Seq[NamedExpression])
+def getAliasSubMap(exprList: Seq[Expression])
: Option[(Map[ExtractValue, Alias], Map[ExprId, Seq[Alias]])] = {
val (nestedFieldReferences, otherRootReferences) =
-projectList.flatMap(collectRootReferenceAndExtractValue).partition {
+exprList.flatMap(collectRootReferenceAndExtractValue).partition {
case _: ExtractValue => true
case _ => false
}
@@ -155,4 +155,15 @@ object NestedColumnAliasing {
case MapType(keyType, valueType, _) => totalFieldNum(keyType) + totalFieldNum(valueType)
case _ => 1 // UDT and others
}
+
+/**
+* This is a whitelist for pruning nested fields at `Generator`.
+*/
+def canPruneGenerator(g: Generator): Boolean = g match {
+case _: Explode => true
+case _: Stack => true
+case _: PosExplode => true
+case _: Inline => true
+case _ => false
+}
}
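For intuition, here is a sketch of the two maps `getAliasSubMap` returns for an expression list like `a, c.d, c.e` over an input `(a int, b int, c struct<d: double, e: array<string>, ...>)`. This is an illustrative trace, not runnable code: the `_gen_alias_` prefix matches what the test suites check for, but real aliases carry freshly generated `ExprId`s.

```scala
// Illustrative trace only; alias names and ExprIds are generated at optimization time.
// exprList = Seq(a, c.d, c.e)
//
// nestedFieldToAlias: one new alias per nested field access
//   GetStructField(c, "d") -> Alias(GetStructField(c, "d"), "_gen_alias_0")
//   GetStructField(c, "e") -> Alias(GetStructField(c, "e"), "_gen_alias_1")
//
// attrToAliases: the root attribute, keyed by ExprId, mapped to all of its aliases
//   c#3 -> Seq(_gen_alias_0, _gen_alias_1)
//
// `a` is a plain root reference, so it is left alone; if the whole struct `c`
// itself were also referenced, its nested accesses would not be aliased at all.
```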
@@ -588,6 +588,24 @@ object ColumnPruning extends Rule[LogicalPlan] {
.map(_._2)
p.copy(child = g.copy(child = newChild, unrequiredChildIndex = unrequiredIndices))
+
+// Prune unrequired nested fields.
+case p @ Project(projectList, g: Generate) if SQLConf.get.nestedPruningOnExpressions &&
+NestedColumnAliasing.canPruneGenerator(g.generator) =>
+NestedColumnAliasing.getAliasSubMap(projectList ++ g.generator.children).map {
+case (nestedFieldToAlias, attrToAliases) =>
+val newGenerator = g.generator.transform {
+case f: ExtractValue if nestedFieldToAlias.contains(f) =>
+nestedFieldToAlias(f).toAttribute
+}.asInstanceOf[Generator]
+
+// Defer updating `Generate.unrequiredChildIndex` to the next round of `ColumnPruning`.
+val newGenerate = g.copy(generator = newGenerator)
+
+val newChild = NestedColumnAliasing.replaceChildrenWithAliases(newGenerate, attrToAliases)
+
+Project(NestedColumnAliasing.getNewProjectList(projectList, nestedFieldToAlias), newChild)
+}.getOrElse(p)

// Eliminate unneeded attributes from right side of a Left Existence Join.
case j @ Join(_, right, LeftExistence(_), _, _) =>
j.copy(right = prunedChild(right, j.references))
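Sketched on the `Explode` case exercised in the test suite below, the new `Project`-over-`Generate` branch rewrites the plan roughly as follows (alias names are illustrative; `unrequiredChildIndex` is deliberately left to the next `ColumnPruning` round, per the comment in the code):

```scala
// Before:
//   Project [a, c.d, ec]
//   +- Generate explode(c.e) AS ec
//      +- LocalRelation [a, b, c]
//
// After one pass of this branch (aliases illustrative):
//   Project [a, _gen_alias_0 AS `c.d`, ec]
//   +- Generate explode(_gen_alias_1) AS ec
//      +- Project [a, c.d AS _gen_alias_0, c.e AS _gen_alias_1]
//         +- LocalRelation [a, b, c]
```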
@@ -1656,6 +1656,16 @@ object SQLConf {
.booleanConf
.createWithDefault(false)

+val NESTED_PRUNING_ON_EXPRESSIONS =
+buildConf("spark.sql.optimizer.expression.nestedPruning.enabled")
+.internal()
+.doc("Prune nested fields from expressions in an operator that are unnecessary for " +
+"satisfying a query. Note that this optimization doesn't prune nested fields from " +
+"physical data source scanning. For pruning nested fields from scanning, please use " +
+"the `spark.sql.optimizer.nestedSchemaPruning.enabled` config.")
+.booleanConf
+.createWithDefault(false)
+
val TOP_K_SORT_FALLBACK_THRESHOLD =
buildConf("spark.sql.execution.topKSortFallbackThreshold")
.internal()
@@ -2315,6 +2325,8 @@ class SQLConf extends Serializable with Logging {
def serializerNestedSchemaPruningEnabled: Boolean =
getConf(SERIALIZER_NESTED_SCHEMA_PRUNING_ENABLED)

+def nestedPruningOnExpressions: Boolean = getConf(NESTED_PRUNING_ON_EXPRESSIONS)
+
def csvColumnPruning: Boolean = getConf(SQLConf.CSV_PARSER_COLUMN_PRUNING)

def legacySizeOfNull: Boolean = getConf(SQLConf.LEGACY_SIZE_OF_NULL)
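A minimal, self-contained sketch of how the new flag is meant to be used; the session setup, data, and column names here are invented for illustration, and the exact optimized-plan shape depends on the Spark build:

```scala
import org.apache.spark.sql.SparkSession

object NestedPruningDemo extends App {
  val spark = SparkSession.builder()
    .master("local[*]")
    .appName("nested-pruning-demo")
    .getOrCreate()

  // Off by default. This flag only rewrites expressions in the logical plan;
  // scan-level pruning is governed separately by
  // spark.sql.optimizer.nestedSchemaPruning.enabled.
  spark.conf.set("spark.sql.optimizer.expression.nestedPruning.enabled", "true")

  val df = spark.range(3).selectExpr(
    "id AS a",
    "named_struct('d', CAST(id AS DOUBLE), 'e', array('x', 'y')) AS c")

  // With the flag on, ColumnPruning can alias c.d and c.e below the Generate,
  // so explode() no longer forces the whole struct `c` through the plan.
  df.selectExpr("a", "c.d", "explode(c.e) AS ec").explain(true)

  spark.stop()
}
```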
@@ -19,14 +19,16 @@ package org.apache.spark.sql.catalyst.optimizer

import scala.reflect.runtime.universe.TypeTag

+import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.{Inner, PlanTest}
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules.RuleExecutor
-import org.apache.spark.sql.types.StringType
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.types.{StringType, StructType}

class ColumnPruningSuite extends PlanTest {

@@ -101,6 +103,81 @@ class ColumnPruningSuite extends PlanTest {
comparePlans(optimized, correctAnswer)
}

test("Nested column pruning for Generate") {
def runTest(
origGenerator: Generator,
replacedGenerator: Seq[String] => Generator,
aliasedExprs: Seq[String] => Seq[Expression],
unrequiredChildIndex: Seq[Int],
generatorOutputNames: Seq[String]) {
withSQLConf(SQLConf.NESTED_PRUNING_ON_EXPRESSIONS.key -> "true") {
val structType = StructType.fromDDL("d double, e array<string>, f double, g double, " +
"h array<struct<h1: int, h2: double>>")
val input = LocalRelation('a.int, 'b.int, 'c.struct(structType))
val generatorOutputs = generatorOutputNames.map(UnresolvedAttribute(_))

val selectedExprs = Seq(UnresolvedAttribute("a"), 'c.getField("d")) ++
generatorOutputs

val query =
input
.generate(origGenerator, outputNames = generatorOutputNames)
.select(selectedExprs: _*)
.analyze

val optimized = Optimize.execute(query)

val aliases = NestedColumnAliasingSuite.collectGeneratedAliases(optimized)

val selectedFields = UnresolvedAttribute("a") +: aliasedExprs(aliases)
val finalSelectedExprs = Seq(UnresolvedAttribute("a"), $"${aliases(0)}".as("c.d")) ++
generatorOutputs

val correctAnswer =
input
.select(selectedFields: _*)
.generate(replacedGenerator(aliases),
unrequiredChildIndex = unrequiredChildIndex,
outputNames = generatorOutputNames)
.select(finalSelectedExprs: _*)
.analyze

comparePlans(optimized, correctAnswer)
}
}

runTest(
Explode('c.getField("e")),
aliases => Explode($"${aliases(1)}".as("c.e")),
aliases => Seq('c.getField("d").as(aliases(0)), 'c.getField("e").as(aliases(1))),
Seq(2),
Seq("explode")
)
runTest(Stack(2 :: 'c.getField("f") :: 'c.getField("g") :: Nil),
aliases => Stack(2 :: $"${aliases(1)}".as("c.f") :: $"${aliases(2)}".as("c.g") :: Nil),
aliases => Seq(
'c.getField("d").as(aliases(0)),
'c.getField("f").as(aliases(1)),
'c.getField("g").as(aliases(2))),
Seq(2, 3),
Seq("stack")
)
runTest(
PosExplode('c.getField("e")),
aliases => PosExplode($"${aliases(1)}".as("c.e")),
aliases => Seq('c.getField("d").as(aliases(0)), 'c.getField("e").as(aliases(1))),
Seq(2),
Seq("pos", "explode")
)
runTest(
Inline('c.getField("h")),
aliases => Inline($"${aliases(1)}".as("c.h")),
aliases => Seq('c.getField("d").as(aliases(0)), 'c.getField("h").as(aliases(1))),
Seq(2),
Seq("h1", "h2")
)
}

test("Column pruning for Project on Sort") {
val input = LocalRelation('a.int, 'b.string, 'c.double)

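A note on the expected `unrequiredChildIndex` values in the test above: `Generate` forwards its child's output alongside the generator's output, and `unrequiredChildIndex` lists the child output positions that nothing upstream needs. Reading the `Explode` case (positions shown reflect how the aliased child output is expected to line up):

```scala
// Child of Generate after aliasing: [a, _gen_alias_0 (c.d), _gen_alias_1 (c.e)]
//   position:                         0   1                  2
// Above the Generate, only `a`, the c.d alias, and the exploded column are
// selected; the array alias at position 2 is consumed by the generator itself
// and never read afterwards, hence Seq(2).
// The Stack case consumes the c.f and c.g aliases at positions 2 and 3,
// hence Seq(2, 3).
```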
@@ -29,6 +29,8 @@ import org.apache.spark.sql.types.{StringType, StructField, StructType}

class NestedColumnAliasingSuite extends SchemaPruningTest {

+import NestedColumnAliasingSuite._
+
object Optimize extends RuleExecutor[LogicalPlan] {
val batches = Batch("Nested column pruning", FixedPoint(100),
ColumnPruning,
@@ -264,9 +266,10 @@ class NestedColumnAliasingSuite extends SchemaPruningTest {
.analyze
comparePlans(optimized, expected)
}
+}

-
-private def collectGeneratedAliases(query: LogicalPlan): ArrayBuffer[String] = {
+object NestedColumnAliasingSuite {
+def collectGeneratedAliases(query: LogicalPlan): ArrayBuffer[String] = {
val aliases = ArrayBuffer[String]()
query.transformAllExpressions {
case a @ Alias(_, name) if name.startsWith("_gen_alias_") =>
7 changes: 7 additions & 0 deletions sql/core/benchmarks/MiscBenchmark-results.txt
@@ -105,6 +105,13 @@ generate big struct array: Best/Avg Time(ms) Rate(M/s) Per Ro
generate big struct array wholestage off 708 / 776 0.1 11803.5 1.0X
generate big struct array wholestage on 535 / 589 0.1 8913.9 1.3X

+OpenJDK 64-Bit Server VM 1.8.0_212-b04 on Linux 3.10.0-862.3.2.el7.x86_64
+Intel(R) Xeon(R) CPU E5-2670 v2 @ 2.50GHz
+generate big nested struct array: Best Time(ms) Avg Time(ms) Stdev(ms) Rate(M/s) Per Row(ns) Relative
+------------------------------------------------------------------------------------------------------------------------
+generate big nested struct array wholestage off 540 553 19 0.1 8997.4 1.0X
+generate big nested struct array wholestage on 523 554 31 0.1 8725.0 1.0X
+

================================================================================================
generate regular generator
@@ -18,6 +18,7 @@
package org.apache.spark.sql.execution.benchmark

import org.apache.spark.benchmark.Benchmark
+import org.apache.spark.sql.internal.SQLConf

/**
* Benchmark to measure whole stage codegen performance.
@@ -130,6 +131,20 @@ object MiscBenchmark extends SqlBasedBenchmark {
df.selectExpr("*", "explode(arr) as arr_col")
.select("col", "arr_col.*").count
}
+
+withSQLConf(SQLConf.NESTED_PRUNING_ON_EXPRESSIONS.key -> "true") {
+codegenBenchmark("generate big nested struct array", M) {
+import spark.implicits._
+val df = spark.sparkContext.parallelize(Seq(("1",
+Array.fill(M)({
+val i = math.random
+(i.toString, (i + 1).toString, (i + 2).toString, (i + 3).toString)
+})))).toDF("col", "arr")
+.selectExpr("col", "struct(col, arr) as st")
+.selectExpr("col", "st.col as col1", "explode(st.arr) as arr_col")
+df.collect()
+}
+}
}
}

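For reproducing the `generate big nested struct array` numbers recorded above: Spark benchmarks of this vintage are typically regenerated with something like `SPARK_GENERATE_BENCHMARK_FILES=1 build/sbt "sql/test:runMain org.apache.spark.sql.execution.benchmark.MiscBenchmark"`; treat the exact invocation as an assumption and check the header comment of `MiscBenchmark` itself for the authoritative command.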