Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -264,12 +264,16 @@ case class WholeStageCodegen(plan: CodegenSupport, children: Seq[SparkPlan])
*/
private[sql] case class CollapseCodegenStages(sqlContext: SQLContext) extends Rule[SparkPlan] {

/**
 * Tells whether a single expression node can be used inside whole stage codegen.
 *
 * Every expression is accepted except a non-leaf [[CodegenFallback]]: CodegenFallback requires
 * the input to be an InternalRow. A leaf expression is accepted even when it is a
 * CodegenFallback.
 */
private def supportCodegen(e: Expression): Boolean = {
  val isLeaf = e.isInstanceOf[LeafExpression]
  val fallsBack = e.isInstanceOf[CodegenFallback]
  isLeaf || !fallsBack
}

private def supportCodegen(plan: SparkPlan): Boolean = plan match {
case plan: CodegenSupport if plan.supportCodegen =>
// Non-leaf with CodegenFallback does not work with whole stage codegen
val willFallback = plan.expressions.exists(
_.find(e => e.isInstanceOf[CodegenFallback] && !e.isInstanceOf[LeafExpression]).isDefined
)
val willFallback = plan.expressions.exists(_.find(e => !supportCodegen(e)).isDefined)
// the generated code will be huge if there are too many columns
val haveManyColumns = plan.output.length > 200
!willFallback && !haveManyColumns
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,10 @@ import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.errors._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.execution.{SparkPlan, UnaryNode, UnsafeFixedWidthAggregationMap}
import org.apache.spark.sql.execution.{CodegenSupport, SparkPlan, UnaryNode, UnsafeFixedWidthAggregationMap}
import org.apache.spark.sql.execution.metric.SQLMetrics
import org.apache.spark.sql.types.StructType

Expand All @@ -35,7 +36,7 @@ case class TungstenAggregate(
initialInputBufferOffset: Int,
resultExpressions: Seq[NamedExpression],
child: SparkPlan)
extends UnaryNode {
extends UnaryNode with CodegenSupport {

private[this] val aggregateBufferAttributes = {
aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes)
Expand Down Expand Up @@ -113,6 +114,86 @@ case class TungstenAggregate(
}
}

/**
 * Whether this aggregate can be compiled as part of whole stage codegen.
 *
 * It is supported only when there are no grouping expressions and every aggregate expression
 * is neither backed by an ImperativeAggregate nor evaluated in Final or Complete mode.
 */
override def supportCodegen: Boolean = {
  // ImperativeAggregate is not supported right now; a final aggregation only has one row, so
  // it does not need codegen.
  def canCodegen(aggExpr: AggregateExpression): Boolean = {
    val isImperative = aggExpr.aggregateFunction.isInstanceOf[ImperativeAggregate]
    !isImperative && aggExpr.mode != Final && aggExpr.mode != Complete
  }

  groupingExpressions.isEmpty && aggregateExpressions.forall(canCodegen)
}

// The variables used as aggregation buffer. Assigned in doProduce() before the child's code is
// produced and read back in doConsume(); it holds the default value (null) until then.
private var bufVars: Seq[ExprCode] = _

// The distinct aggregation modes found among aggregateExpressions.
private val modes = aggregateExpressions.map(_.mode).distinct

/**
 * Generates the source for this aggregate and asks the child to produce its own source.
 *
 * The generated code is wrapped in an `if (!flag)` guard driven by a boolean mutable state
 * that starts as false and is set to true on entry. Inside the guard it declares the
 * aggregation buffer variables, embeds the child's source, and finally hands the buffer
 * variables to consume().
 *
 * @param ctx the codegen context used for fresh names, mutable state and expression codegen
 * @return the RDD returned by the child's produce() together with the generated source
 */
protected override def doProduce(ctx: CodegenContext): (RDD[InternalRow], String) = {
// Guard flag, registered as mutable state with the initialiser "<flag> = false;".
val initAgg = ctx.freshName("initAgg")
ctx.addMutableState("boolean", initAgg, s"$initAgg = false;")

// generate variables for aggregation buffer
// NOTE(review): the cast assumes every aggregate function is a DeclarativeAggregate;
// supportCodegen rejects ImperativeAggregate, but nothing is re-checked here.
val functions = aggregateExpressions.map(_.aggregateFunction.asInstanceOf[DeclarativeAggregate])
val initExpr = functions.flatMap(f => f.initialValues)
// One (isNull, value) pair of freshly named variables per initial-value expression; the
// side effect on bufVars is what doConsume() later reads.
bufVars = initExpr.map { e =>
val isNull = ctx.freshName("bufIsNull")
val value = ctx.freshName("bufValue")
// The initial expression should not access any column
val ev = e.gen(ctx)
val initVars = s"""
| boolean $isNull = ${ev.isNull};
| ${ctx.javaType(e.dataType)} $value = ${ev.value};
""".stripMargin
// The code of each buffer variable is the expression's code followed by its declaration.
ExprCode(ev.code + initVars, isNull, value)
}

// NOTE(review): assumes the child implements CodegenSupport; the cast is unchecked here.
val (rdd, childSource) = child.asInstanceOf[CodegenSupport].produce(ctx, this)
val source =
s"""
| if (!$initAgg) {
| $initAgg = true;
|
| // initialize aggregation buffer
| ${bufVars.map(_.code).mkString("\n")}
|
| $childSource
|
| // output the result
| ${consume(ctx, bufVars)}
| }
""".stripMargin

(rdd, source)
}

/**
 * Generates the code that folds one input row into the aggregation buffer.
 *
 * @param ctx   the codegen context; its currentVars is set to the buffer variables followed
 *              by `input` before the expressions are generated
 * @param child the plan that produced the row; its output is appended to the buffer
 *              attributes to bind the update/merge expressions
 * @param input the generated variables for the columns of the current input row
 * @return source that first evaluates every update/merge expression and then stores the
 *         results into the buffer variables created by doProduce()
 */
override def doConsume(ctx: CodegenContext, child: SparkPlan, input: Seq[ExprCode]): String = {
  // only have DeclarativeAggregate
  val functions = aggregateExpressions.map(_.aggregateFunction.asInstanceOf[DeclarativeAggregate])
  // Pick update vs. merge expressions per aggregate expression rather than once for the whole
  // operator, so an operator that mixes Partial and PartialMerge expressions treats each one
  // according to its own mode.
  val updateExpr = aggregateExpressions.zip(functions).flatMap { case (aggExpr, function) =>
    aggExpr.mode match {
      case Partial | Complete => function.updateExpressions
      case PartialMerge | Final => function.mergeExpressions
    }
  }

  val inputAttr = functions.flatMap(_.aggBufferAttributes) ++ child.output
  val boundExpr = updateExpr.map(e => BindReferences.bindReference(e, inputAttr))
  ctx.currentVars = bufVars ++ input
  // TODO: support subexpression elimination
  val aggVals = boundExpr.map(_.gen(ctx))
  // The code of every expression runs before any buffer variable is assigned, so evaluating
  // one expression never happens after another expression's result was already stored.
  val updates = aggVals.zipWithIndex.map { case (ev, i) =>
    s"""
      | ${bufVars(i).isNull} = ${ev.isNull};
      | ${bufVars(i).value} = ${ev.value};
    """.stripMargin
  }

  s"""
    | // do aggregate
    | ${aggVals.map(_.code).mkString("\n")}
    | // update aggregation buffer
    | ${updates.mkString("")}
  """.stripMargin
}

override def simpleString: String = {
val allAggregateExpressions = aggregateExpressions

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,10 +46,10 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite {

/*
Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz
Single Int Column Scan: Avg Time(ms) Avg Rate(M/s) Relative Rate
-------------------------------------------------------------------------
Without whole stage codegen 6725.52 31.18 1.00 X
With whole stage codegen 2233.05 93.91 3.01 X
Single Int Column Scan: Avg Time(ms) Avg Rate(M/s) Relative Rate
-------------------------------------------------------------------------------
Without whole stage codegen 7775.53 26.97 1.00 X
With whole stage codegen 342.15 612.94 22.73 X
*/
benchmark.run()
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,10 @@

package org.apache.spark.sql.execution

import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.execution.aggregate.TungstenAggregate
import org.apache.spark.sql.functions.{avg, col, max}
import org.apache.spark.sql.test.SharedSQLContext

class WholeStageCodegenSuite extends SparkPlanTest with SharedSQLContext {
Expand All @@ -35,4 +38,13 @@ class WholeStageCodegenSuite extends SparkPlanTest with SharedSQLContext {
sortAnswers = false
)
}

test("Aggregate should be included in WholeStageCodegen") {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you add a test case with multiple aggregate expressions?

val df = sqlContext.range(10).groupBy().agg(max(col("id")), avg(col("id")))
val plan = df.queryExecution.executedPlan
assert(plan.find(p =>
p.isInstanceOf[WholeStageCodegen] &&
p.asInstanceOf[WholeStageCodegen].plan.isInstanceOf[TungstenAggregate]).isDefined)
assert(df.collect() === Array(Row(9, 4.5)))
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,9 @@ class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext {
expectedNumOfJobs: Int,
expectedMetrics: Map[Long, (String, Map[String, Any])]): Unit = {
val previousExecutionIds = sqlContext.listener.executionIdToData.keySet
df.collect()
withSQLConf("spark.sql.codegen.wholeStage" -> "false") {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@davies, why this change?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In this PR, SQLMetrics are not yet supported by whole-stage codegen.

df.collect()
}
sparkContext.listenerBus.waitUntilEmpty(10000)
val executionIds = sqlContext.listener.executionIdToData.keySet.diff(previousExecutionIds)
assert(executionIds.size === 1)
Expand Down