diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
index d9009e3848e58..f08acbefd37b2 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/FunctionRegistry.scala
@@ -192,6 +192,7 @@ object FunctionRegistry {
     expression[Average]("mean"),
     expression[Min]("min"),
     expression[StddevSamp]("stddev"),
+    expression[StddevSamp1]("stddev1"),
     expression[StddevPop]("stddev_pop"),
     expression[StddevSamp]("stddev_samp"),
     expression[Sum]("sum"),
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala
index 30f602227b17d..f14832398969d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/CentralMomentAgg.scala
@@ -57,7 +57,7 @@ abstract class CentralMomentAgg(child: Expression) extends ImperativeAggregate w
 
   override def dataType: DataType = DoubleType
 
-  override def inputTypes: Seq[AbstractDataType] = Seq(NumericType)
+  override def inputTypes: Seq[AbstractDataType] = Seq(DoubleType)
 
   override def checkInputDataTypes(): TypeCheckResult =
     TypeUtils.checkForNumericExpr(child.dataType, s"function $prettyName")
@@ -109,7 +109,7 @@ abstract class CentralMomentAgg(child: Expression) extends ImperativeAggregate w
    * Update the central moments buffer.
    */
   override def update(buffer: MutableRow, input: InternalRow): Unit = {
-    val v = Cast(child, DoubleType).eval(input)
+    val v = child.eval(input)
     if (v != null) {
      val updateValue = v match {
        case d: Double => d
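A note on the CentralMomentAgg change above: dropping the per-row Cast(child, DoubleType).eval(input) assumes the input is already a Double by the time update() runs, which the tightened inputTypes = Seq(DoubleType) is presumably meant to guarantee once, at analysis time, rather than on every row. A minimal sketch of that intent in plain Scala (hypothetical names, not the Spark API):

// Sketch only: the point is to pay for numeric widening once, not once per row.
def updatePerRowCast(buffer: Array[Double], raw: Any): Unit = {
  val v = raw match {                    // old path: per-row dispatch and widening
    case d: Double => d
    case i: Int    => i.toDouble
    case l: Long   => l.toDouble
    case _         => return
  }
  buffer(0) += v
}

def updatePreCast(buffer: Array[Double], v: Double): Unit = {
  buffer(0) += v                         // new path: input already widened upstream
}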
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Stddev.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Stddev.scala
index eec79a9033e36..6e9ad8fd13dfc 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Stddev.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Stddev.scala
@@ -17,7 +17,9 @@
 package org.apache.spark.sql.catalyst.expressions.aggregate
 
+import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.types._
 
 case class StddevSamp(child: Expression,
     mutableAggBufferOffset: Int = 0,
@@ -79,3 +81,116 @@ case class StddevPop(
     }
   }
 }
+
+// Compute standard deviation based on the online algorithm described at:
+// http://en.wikipedia.org/wiki/Algorithms_for_calculating_variance
+abstract class StddevAgg(child: Expression) extends DeclarativeAggregate {
+
+  def isSample: Boolean
+
+  override def children: Seq[Expression] = child :: Nil
+
+  override def nullable: Boolean = true
+
+  override def dataType: DataType = resultType
+
+  // Expected input data type.
+  // TODO: Right now, we replace old aggregate functions (based on AggregateExpression) with the
+  // new version at planning time (after the analysis phase). For now, NullType is added here
+  // to make it resolved when we have cases like `select stddev(null)`.
+  // We can use our analyzer to cast NullType to the default data type of the NumericType once
+  // we remove the old aggregate functions. Then we will not need NullType here.
+  override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(DoubleType, NullType))
+
+  private val resultType = DoubleType
+
+  private val count = AttributeReference("count", resultType, false)()
+  private val avg = AttributeReference("avg", resultType, false)()
+  private val mk = AttributeReference("mk", resultType, false)()
+
+  override val aggBufferAttributes = count :: avg :: mk :: Nil
+
+  override val initialValues: Seq[Expression] = Seq(
+    /* count = */ Literal(0.0),
+    /* avg = */ Literal(0.0),
+    /* mk = */ Literal(0.0)
+  )
+
+  override val updateExpressions: Seq[Expression] = {
+    val newCount = count + Literal(1.0)
+
+    // update the average
+    // avg = avg + (value - avg) / count
+    val newAvg = avg + (child - avg) / newCount
+
+    // update the sum of squared differences from the mean
+    // Mk = Mk + (value - preAvg) * (value - updatedAvg)
+    val newMk = mk + (child - avg) * (child - newAvg)
+
+    if (child.nullable) {
+      Seq(
+        /* count = */ If(IsNull(child), count, newCount),
+        /* avg = */ If(IsNull(child), avg, newAvg),
+        /* mk = */ If(IsNull(child), mk, newMk)
+      )
+    } else {
+      Seq(
+        /* count = */ newCount,
+        /* avg = */ newAvg,
+        /* mk = */ newMk
+      )
+    }
+  }
+
+  override val mergeExpressions: Seq[Expression] = {
+
+    // count merge
+    val newCount = count.left + count.right
+
+    // average merge
+    val newAvg = ((avg.left * count.left) + (avg.right * count.right)) / newCount
+
+    // merge the sums of squared differences
+    val newMk = {
+      val avgDelta = avg.right - avg.left
+      val mkDelta = (avgDelta * avgDelta) * (count.left * count.right) / newCount
+      mk.left + mk.right + mkDelta
+    }
+
+    Seq(
+      /* count = */ newCount,
+      /* avg = */ newAvg,
+      /* mk = */ newMk
+    )
+  }
+
+  override val evaluateExpression: Expression = {
+    // when count == 0, return null
+    // when count == 1, return 0
+    // when count > 1:
+    //   stddev_samp = sqrt(mk / (count - 1))
+    //   stddev_pop = sqrt(mk / count)
+    val varCol =
+      if (isSample) {
+        mk / (count - Literal(1.0))
+      } else {
+        mk / count
+      }
+
+    If(EqualTo(count, Literal(0.0)), Literal.create(null, resultType),
+      If(EqualTo(count, Literal(1.0)), Literal(0.0),
+        Sqrt(varCol)))
+  }
+}
+
+// Compute the population standard deviation of a column
+case class StddevPop1(child: Expression) extends StddevAgg(child) {
+  override def isSample: Boolean = false
+  override def prettyName: String = "stddev_pop"
+}
+
+// Compute the sample standard deviation of a column
+case class StddevSamp1(child: Expression) extends StddevAgg(child) {
+  override def isSample: Boolean = true
+  override def prettyName: String = "stddev_samp"
+}
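The StddevAgg recurrences above are the standard online (Welford-style) update plus the pairwise merge from Chan et al. A minimal standalone sketch in plain Scala (not Catalyst expressions; the names are illustrative) that mirrors update/merge/evaluate and can be used to sanity-check the expression versions:

// A sketch of the same recurrences StddevAgg encodes, on plain Doubles.
case class StddevBuffer(count: Double, avg: Double, mk: Double) {
  def update(x: Double): StddevBuffer = {
    val newCount = count + 1.0
    val newAvg = avg + (x - avg) / newCount                  // matches updateExpressions
    val newMk = mk + (x - avg) * (x - newAvg)
    StddevBuffer(newCount, newAvg, newMk)
  }
  def merge(that: StddevBuffer): StddevBuffer = {            // matches mergeExpressions
    val newCount = count + that.count
    val newAvg = (avg * count + that.avg * that.count) / newCount
    val delta = that.avg - avg
    StddevBuffer(newCount, newAvg, mk + that.mk + delta * delta * count * that.count / newCount)
  }
  def stddevSamp: Option[Double] =                           // matches evaluateExpression
    if (count == 0) None
    else if (count == 1) Some(0.0)
    else Some(math.sqrt(mk / (count - 1)))
}

For example, folding 1.0 and 2.0 into one buffer, 3.0 and 4.0 into another, and merging the two gives count = 4, avg = 2.5, mk = 5.0, so stddevSamp is sqrt(5/3), about 1.291, the same as a single pass over all four values.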
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala
index 59ef0f5836a3c..658d3c91bb8f4 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala
@@ -40,10 +40,14 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], () => Mu
   protected def create(expressions: Seq[Expression]): (() => MutableProjection) = {
     val ctx = newCodeGenContext()
-    val projectionCodes = expressions.zipWithIndex.map {
-      case (NoOp, _) => ""
-      case (e, i) =>
-        val evaluationCode = e.gen(ctx)
+    val (validExpr, index) = expressions.zipWithIndex.filter {
+      case (NoOp, _) => false
+      case _ => true
+    }.unzip
+    val exprVals = ctx.generateExpressions(validExpr, true)
+    val projectionCodes = exprVals.zip(index).map {
+      case (ev, i) =>
+        val e = expressions(i)
         if (e.nullable) {
           val isNull = s"isNull_$i"
           val value = s"value_$i"
@@ -51,22 +55,25 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], () => Mu
           ctx.addMutableState(ctx.javaType(e.dataType), value,
             s"this.$value = ${ctx.defaultValue(e.dataType)};")
           s"""
-            ${evaluationCode.code}
-            this.$isNull = ${evaluationCode.isNull};
-            this.$value = ${evaluationCode.value};
+            ${ev.code}
+            this.$isNull = ${ev.isNull};
+            this.$value = ${ev.value};
           """
         } else {
           val value = s"value_$i"
           ctx.addMutableState(ctx.javaType(e.dataType), value,
             s"this.$value = ${ctx.defaultValue(e.dataType)};")
           s"""
-            ${evaluationCode.code}
-            this.$value = ${evaluationCode.value};
+            ${ev.code}
+            this.$value = ${ev.value};
           """
         }
     }
-    val updates = expressions.zipWithIndex.map {
-      case (NoOp, _) => ""
+
+    // Reset the subexpression values for each row.
+    val subexprReset = ctx.subExprResetVariables.mkString("\n")
+
+    val updates = validExpr.zip(index).map {
       case (e, i) =>
         if (e.nullable) {
          if (e.dataType.isInstanceOf[DecimalType]) {
@@ -128,6 +135,7 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], () => Mu
 
       public java.lang.Object apply(java.lang.Object _i) {
         InternalRow ${ctx.INPUT_ROW} = (InternalRow) _i;
+        $subexprReset
         $allProjections
         // copy all the results into MutableRow
         $allUpdates
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
index 61e7469ee4be2..687ae217034e4 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateUnsafeProjection.scala
@@ -300,9 +300,8 @@ object GenerateUnsafeProjection extends CodeGenerator[Seq[Expression], UnsafePro
 
     val code =
       s"""
        $bufferHolder.reset();
-        $subexprReset
+        ${subexprReset.trim}
        ${writeExpressionsToBuffer(ctx, ctx.INPUT_ROW, exprEvals, exprTypes, bufferHolder)}
-
        $result.pointTo($bufferHolder.buffer, $bufferHolder.totalSize());
      """
     ExprCode(code, "false", result)
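The GenerateMutableProjection change above routes every non-NoOp expression through ctx.generateExpressions(validExpr, true), so common subexpressions shared by several projection slots are evaluated once per input row, and subExprResetVariables is emitted at the top of apply() to clear the cached values before each row. Roughly the same idea, hand-written (a conceptual sketch with hypothetical names, not the generated Java):

// Conceptual sketch: several outputs share one subexpression; it is computed at most once
// per row, and the "evaluated" flag is reset before each row (cf. the $subexprReset hook).
class CachedSubexprProjection(shared: Double => Double, outputs: Seq[Double => Double]) {
  private var subExprEvaluated = false
  private var subExprValue = 0.0

  private def sharedValue(row: Double): Double = {
    if (!subExprEvaluated) {
      subExprValue = shared(row)
      subExprEvaluated = true
    }
    subExprValue
  }

  def apply(row: Double): Seq[Double] = {
    subExprEvaluated = false                 // per-row reset, as in the generated apply()
    outputs.map(f => f(sharedValue(row)))
  }
}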
case "count" | "size" => // Turn count(*) into count(1) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala index c15fabab805a7..043cb4ddcbe14 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala @@ -23,6 +23,7 @@ import org.apache.spark.rdd.RDD import org.apache.spark.sql.SQLContext import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.{Attribute, BoundReference, Expression, LeafExpression} +import org.apache.spark.sql.catalyst.expressions.aggregate.ImperativeAggregate import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.rules.Rule @@ -190,7 +191,7 @@ case class WholeStageCodegen(plan: CodegenSupport, children: Seq[SparkPlan]) """ // try to compile, helpful for debug // println(s"${CodeFormatter.format(source)}") - CodeGenerator.compile(source) + // CodeGenerator.compile(source) rdd.mapPartitions { iter => val clazz = CodeGenerator.compile(source) @@ -264,12 +265,17 @@ case class WholeStageCodegen(plan: CodegenSupport, children: Seq[SparkPlan]) */ private[sql] case class CollapseCodegenStages(sqlContext: SQLContext) extends Rule[SparkPlan] { + private def supportCodegen(e: Expression): Boolean = e match { + case e: LeafExpression => true + case e: ImperativeAggregate => true + case e: CodegenFallback => false + case e => true + } + + private def supportCodegen(plan: SparkPlan): Boolean = plan match { case plan: CodegenSupport if plan.supportCodegen => - // Non-leaf with CodegenFallback does not work with whole stage codegen - val willFallback = plan.expressions.exists( - _.find(e => e.isInstanceOf[CodegenFallback] && !e.isInstanceOf[LeafExpression]).isDefined - ) + val willFallback = plan.expressions.exists(_.find(e => !supportCodegen(e)).isDefined) // the generated code will be huge if there are too many columns val haveManyColumns = plan.output.length > 200 !willFallback && !haveManyColumns diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala index 0c74df0aa5fdd..beed73cef0451 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala @@ -66,53 +66,9 @@ abstract class AggregationIterator( s"$aggregateExpressions can't have Partial/PartialMerge and Final/Complete in the same time.") } - // Initialize all AggregateFunctions by binding references if necessary, - // and set inputBufferOffset and mutableBufferOffset. 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala
index 0c74df0aa5fdd..beed73cef0451 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala
@@ -66,53 +66,9 @@ abstract class AggregationIterator(
       s"$aggregateExpressions can't have Partial/PartialMerge and Final/Complete in the same time.")
   }
 
-  // Initialize all AggregateFunctions by binding references if necessary,
-  // and set inputBufferOffset and mutableBufferOffset.
-  protected def initializeAggregateFunctions(
-      expressions: Seq[AggregateExpression],
-      startingInputBufferOffset: Int): Array[AggregateFunction] = {
-    var mutableBufferOffset = 0
-    var inputBufferOffset: Int = startingInputBufferOffset
-    val functions = new Array[AggregateFunction](expressions.length)
-    var i = 0
-    while (i < expressions.length) {
-      val func = expressions(i).aggregateFunction
-      val funcWithBoundReferences: AggregateFunction = expressions(i).mode match {
-        case Partial | Complete if func.isInstanceOf[ImperativeAggregate] =>
-          // We need to create BoundReferences if the function is not an
-          // expression-based aggregate function (it does not support code-gen) and the mode of
-          // this function is Partial or Complete because we will call eval of this
-          // function's children in the update method of this aggregate function.
-          // Those eval calls require BoundReferences to work.
-          BindReferences.bindReference(func, inputAttributes)
-        case _ =>
-          // We only need to set inputBufferOffset for aggregate functions with mode
-          // PartialMerge and Final.
-          val updatedFunc = func match {
-            case function: ImperativeAggregate =>
-              function.withNewInputAggBufferOffset(inputBufferOffset)
-            case function => function
-          }
-          inputBufferOffset += func.aggBufferSchema.length
-          updatedFunc
-      }
-      val funcWithUpdatedAggBufferOffset = funcWithBoundReferences match {
-        case function: ImperativeAggregate =>
-          // Set mutableBufferOffset for this function. It is important that setting
-          // mutableBufferOffset happens after all potential bindReference operations
-          // because bindReference will create a new instance of the function.
-          function.withNewMutableAggBufferOffset(mutableBufferOffset)
-        case function => function
-      }
-      mutableBufferOffset += funcWithUpdatedAggBufferOffset.aggBufferSchema.length
-      functions(i) = funcWithUpdatedAggBufferOffset
-      i += 1
-    }
-    functions
-  }
-
   protected val aggregateFunctions: Array[AggregateFunction] =
-    initializeAggregateFunctions(aggregateExpressions, initialInputBufferOffset)
+    AggregationIterator.initializeAggregateFunctions(
+      aggregateExpressions, inputAttributes, initialInputBufferOffset)
 
   // Positions of those imperative aggregate functions in allAggregateFunctions.
   // For example, we have func1, func2, func3, func4 in aggregateFunctions, and
@@ -259,3 +215,51 @@ abstract class AggregationIterator(
       }
     }
  }
+
+object AggregationIterator {
+  // Initialize all AggregateFunctions by binding references if necessary,
+  // and set inputBufferOffset and mutableBufferOffset.
+  def initializeAggregateFunctions(
+      expressions: Seq[AggregateExpression],
+      inputAttributes: Seq[Attribute],
+      startingInputBufferOffset: Int): Array[AggregateFunction] = {
+    var mutableBufferOffset = 0
+    var inputBufferOffset: Int = startingInputBufferOffset
+    val functions = new Array[AggregateFunction](expressions.length)
+    var i = 0
+    while (i < expressions.length) {
+      val func = expressions(i).aggregateFunction
+      val funcWithBoundReferences: AggregateFunction = expressions(i).mode match {
+        case Partial | Complete if func.isInstanceOf[ImperativeAggregate] =>
+          // We need to create BoundReferences if the function is not an
+          // expression-based aggregate function (it does not support code-gen) and the mode of
+          // this function is Partial or Complete because we will call eval of this
+          // function's children in the update method of this aggregate function.
+          // Those eval calls require BoundReferences to work.
+          BindReferences.bindReference(func, inputAttributes)
+        case _ =>
+          // We only need to set inputBufferOffset for aggregate functions with mode
+          // PartialMerge and Final.
+          val updatedFunc = func match {
+            case function: ImperativeAggregate =>
+              function.withNewInputAggBufferOffset(inputBufferOffset)
+            case function => function
+          }
+          inputBufferOffset += func.aggBufferSchema.length
+          updatedFunc
+      }
+      val funcWithUpdatedAggBufferOffset = funcWithBoundReferences match {
+        case function: ImperativeAggregate =>
+          // Set mutableBufferOffset for this function. It is important that setting
+          // mutableBufferOffset happens after all potential bindReference operations
+          // because bindReference will create a new instance of the function.
+          function.withNewMutableAggBufferOffset(mutableBufferOffset)
+        case function => function
+      }
+      mutableBufferOffset += funcWithUpdatedAggBufferOffset.aggBufferSchema.length
+      functions(i) = funcWithUpdatedAggBufferOffset
+      i += 1
+    }
+    functions
+  }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
index a9cf04388d2e8..110322a000fc9 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
@@ -21,9 +21,10 @@ import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.errors._
 import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
+import org.apache.spark.sql.catalyst.expressions.aggregate._
+import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode, GenerateUnsafeProjection}
 import org.apache.spark.sql.catalyst.plans.physical._
-import org.apache.spark.sql.execution.{SparkPlan, UnaryNode, UnsafeFixedWidthAggregationMap}
+import org.apache.spark.sql.execution.{CodegenSupport, SparkPlan, UnaryNode, UnsafeFixedWidthAggregationMap}
 import org.apache.spark.sql.execution.metric.SQLMetrics
 import org.apache.spark.sql.types.StructType
 
@@ -35,7 +36,7 @@ case class TungstenAggregate(
     initialInputBufferOffset: Int,
     resultExpressions: Seq[NamedExpression],
     child: SparkPlan)
-  extends UnaryNode {
+  extends UnaryNode with CodegenSupport {
 
   private[this] val aggregateBufferAttributes = {
     aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes)
@@ -112,6 +113,191 @@ case class TungstenAggregate(
     }
   }
 
+  override def supportCodegen: Boolean = {
+    groupingExpressions.isEmpty &&
+      // the final aggregation only has one row, so it does not need codegen
+      !aggregateExpressions.exists(e => e.mode == Final || e.mode == Complete)
+  }
+
+  // For declarative functions
+  private val declFunctions = aggregateExpressions.map(_.aggregateFunction)
+    .filter(_.isInstanceOf[DeclarativeAggregate])
+    .map(_.asInstanceOf[DeclarativeAggregate])
+  private var declBufVars: Seq[ExprCode] = _
+
+  // For imperative functions
+  private val impFunctions = AggregationIterator.initializeAggregateFunctions(
+    aggregateExpressions.filter(_.aggregateFunction.isInstanceOf[ImperativeAggregate]),
+    child.output,
+    0
+  ).map(_.asInstanceOf[ImperativeAggregate])
+  private var impBuffTerm: String = _
+  private var impFunctionTerms: Array[String] = _
+
+  private val modes = aggregateExpressions.map(_.mode).distinct
+
+  protected override def doProduce(ctx: CodegenContext): (RDD[InternalRow], String) = {
+    // generate buffer variables for declarative aggregation functions
+    val declInitExpr = declFunctions.flatMap(f => f.initialValues)
+    declBufVars = declInitExpr.map { e =>
+      val isNull = ctx.freshName("bufIsNull")
+      ctx.addMutableState("boolean", isNull, "")
+      val value = ctx.freshName("bufValue")
+      ctx.addMutableState(ctx.javaType(e.dataType), value, "")
+      // The initial expression should not access any column
+      val ev = e.gen(ctx)
+      val code = if (e.nullable) {
+        s"""
+           | $isNull = ${ev.isNull};
+           | if (!${ev.isNull}) {
+           |   $value = ${ev.value};
+           | }
+         """.stripMargin
+      } else {
+        s"$value = ${ev.value};"
+      }
+      ExprCode(ev.code + code, isNull, value)
+    }
+
+    // generate the aggregation buffer for imperative functions
+    val impInitCode = if (impFunctions.nonEmpty) {
+
+      // create the aggregation buffer
+      val aggBuffer = new SpecificMutableRow(
+        impFunctions.flatMap(_.aggBufferAttributes).map(_.dataType))
+      val buffIndex = ctx.references.length
+      ctx.references += aggBuffer
+      impBuffTerm = ctx.freshName("impBuff")
+      ctx.addMutableState("MutableRow", impBuffTerm,
+        s"this.$impBuffTerm = (MutableRow) references[$buffIndex];")
+
+      // create variables for imperative functions
+      val funcName = classOf[ImperativeAggregate].getName
+      impFunctionTerms = impFunctions.map { f =>
+        val idx = ctx.references.length
+        ctx.references += f
+        val funcTerm = ctx.freshName("aggFunc")
+        ctx.addMutableState(funcName, funcTerm, s"this.$funcTerm = ($funcName) references[$idx];")
+        funcTerm
+      }
+
+      // call initialize() of imperative functions
+      impFunctionTerms.map { f =>
+        s"$f.initialize($impBuffTerm);"
+      }.mkString("\n")
+    } else {
+      ""
+    }
+
+    // create variables for the result (aggregation buffer)
+    val modes = aggregateExpressions.map(_.mode).distinct
+    // The final aggregation only outputs one row, so it does not need codegen
+    assert(!modes.contains(Final) && !modes.contains(Complete))
+    assert(modes.contains(Partial) || modes.contains(PartialMerge))
+
+    // create variables for imperative functions
+    // TODO: the next operator should be Exchange, we could output the aggregation buffer
+    // directly without creating any variables, if there is no declarative function.
+    ctx.INPUT_ROW = impBuffTerm
+    ctx.currentVars = null
+    val impAttrs = impFunctions.flatMap(_.aggBufferAttributes)
+    val impBufVars = impAttrs.zipWithIndex.map { case (e, i) =>
+      BoundReference(i, e.dataType, e.nullable).gen(ctx)
+    }
+
+    val (rdd, childSource) = child.asInstanceOf[CodegenSupport].produce(ctx, this)
+
+    val initAgg = ctx.freshName("initAgg")
+    ctx.addMutableState("boolean", initAgg, s"$initAgg = false;")
+    val source =
+      s"""
+         | if (!$initAgg) {
+         |   $initAgg = true;
+         |
+         |   // initialize declarative aggregation buffer
+         |   ${declBufVars.map(_.code).mkString("\n")}
+         |
+         |   // initialize imperative aggregate buffer
+         |   $impInitCode
+         |
+         |   $childSource
+         |
+         |   // output the result
+         |   ${impBufVars.map(_.code).mkString("\n")}
+         |   ${consume(ctx, declBufVars ++ impBufVars)}
+         | }
+       """.stripMargin
+
+    (rdd, source)
+  }
+
+  override def doConsume(ctx: CodegenContext, child: SparkPlan, input: Seq[ExprCode]): String = {
+    // update expressions for declarative functions
+    val updateExpr = if (modes.contains(Partial)) {
+      declFunctions.flatMap(_.updateExpressions)
+    } else {
+      declFunctions.flatMap(_.mergeExpressions)
+    }
+
+    // evaluate the update expressions to update the buffer variables
+    val declInputAttr = declFunctions.flatMap(_.aggBufferAttributes) ++ child.output
+    val boundExpr = updateExpr.map(e => BindReferences.bindReference(e, declInputAttr))
+    ctx.currentVars = declBufVars ++ input
+    // TODO: eliminate common sub-expressions
+    val declUpdateCode = boundExpr.zipWithIndex.map { case (e, i) =>
+      val ev = e.gen(ctx)
+      if (e.nullable) {
+        s"""
+           | ${ev.code}
+           | ${declBufVars(i).isNull} = ${ev.isNull};
+           | if (!${ev.isNull}) {
+           |   ${declBufVars(i).value} = ${ev.value};
+           | }
+         """.stripMargin
+      } else {
+        s"""
+           | ${ev.code}
+           | ${declBufVars(i).isNull} = false;
+           | ${declBufVars(i).value} = ${ev.value};
+         """.stripMargin
+      }
+    }
+
+    val impUpdateCode = if (impFunctions.nonEmpty) {
+      // create an UnsafeRow as input for imperative functions
+      // TODO: only create the columns that are needed
+      val columns = child.output.zipWithIndex.map {
+        case (a, i) => new BoundReference(i, a.dataType, a.nullable)
+      }
+      ctx.currentVars = input
+      val rowCode = GenerateUnsafeProjection.createCode(ctx, columns)
+
+      // call the aggregate functions
+      // all aggregate expressions should have the same mode
+      val updates = if (modes.contains(Partial)) {
+        impFunctionTerms.map { f => s"$f.update($impBuffTerm, ${rowCode.value});" }
+      } else {
+        impFunctionTerms.map { f => s"$f.merge($impBuffTerm, ${rowCode.value});" }
+      }
+      s"""
+         | // create an UnsafeRow for imperative functions
+         | ${rowCode.code}
+         | // call update()/merge() on imperative functions
+         | ${updates.mkString("\n")}
       """.stripMargin
+    } else {
+      ""
+    }
+
+    s"""
+       | // declarative aggregation
+       | ${declUpdateCode.mkString("\n")}
+       |
+       | // imperative aggregation
+       | $impUpdateCode
+     """.stripMargin
+  }
+
   override def simpleString: String = {
     val allAggregateExpressions = aggregateExpressions
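The doProduce/doConsume pair above generates a single-buffer aggregation loop for the no-grouping, partial-only case: doProduce emits the one-time buffer initialization and the single output row, while doConsume emits the per-row body that folds each input row into the declarative buffer variables and calls update()/merge() on the imperative functions. A hedged, plain-Scala sketch of that control flow (hypothetical names, not the generated Java):

trait ImpAgg {
  def initialize(buf: Array[Double]): Unit
  def update(buf: Array[Double], v: Double): Unit
}

// One partition, no grouping keys: initialize once, fold every row, emit one buffer row.
def partialAggregate(rows: Iterator[Double], impAgg: ImpAgg): (Double, Double, Array[Double]) = {
  var count = 0.0                       // "declarative" buffer vars, seeded from initialValues
  var sum = 0.0
  val impBuf = new Array[Double](1)     // "imperative" buffer row
  impAgg.initialize(impBuf)

  while (rows.hasNext) {                // per-row body corresponds to doConsume
    val v = rows.next()
    count += 1.0                        // bound updateExpressions over (bufVars ++ input)
    sum += v
    impAgg.update(impBuf, v)            // update()/merge() call on the ImperativeAggregate
  }
  (count, sum, impBuf)                  // the single row handed to consume()
}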
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala
index 41799c596b6d3..9ee8525e20956 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala
@@ -250,7 +250,8 @@ class TungstenAggregationIterator(
           agg.copy(mode = Final)
         case other => other
       }
-      val newFunctions = initializeAggregateFunctions(newExpressions, 0)
+      val newFunctions = AggregationIterator.initializeAggregateFunctions(
+        newExpressions, originalInputAttributes, 0)
       val newInputAttributes = newFunctions.flatMap(_.inputAggBufferAttributes)
       sortBasedProcessRow = generateProcessRow(newExpressions, newFunctions, newInputAttributes)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala
index 83379ae90f703..11e2b92f6cf61 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala
@@ -87,12 +87,16 @@ object Utils {
       aggregateFunctionToAttribute: Map[(AggregateFunction, Boolean), Attribute],
       resultExpressions: Seq[NamedExpression],
       child: SparkPlan): Seq[SparkPlan] = {
-    // Check if we can use TungstenAggregate.
+
+    // Group aggregate expressions by type (declarative functions always come before imperative
+    // functions) to ease code generation.
+    val sortedAggExpressions =
+      aggregateExpressions.sortBy(_.aggregateFunction.isInstanceOf[ImperativeAggregate])
 
     // 1. Create an Aggregate Operator for partial aggregations.
     val groupingAttributes = groupingExpressions.map(_.toAttribute)
 
-    val partialAggregateExpressions = aggregateExpressions.map(_.copy(mode = Partial))
+    val partialAggregateExpressions = sortedAggExpressions.map(_.copy(mode = Partial))
     val partialAggregateAttributes =
       partialAggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes)
     val partialResultExpressions =
@@ -109,7 +113,7 @@ object Utils {
       child = child)
 
     // 2. Create an Aggregate Operator for final aggregations.
-    val finalAggregateExpressions = aggregateExpressions.map(_.copy(mode = Final))
+    val finalAggregateExpressions = sortedAggExpressions.map(_.copy(mode = Final))
     // The attributes of the final aggregation buffer, which is presented as input to the result
     // projection:
     val finalAggregateAttributes = finalAggregateExpressions.map {
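The sortBy in utils.scala only reorders the aggregate expressions so that every DeclarativeAggregate precedes every ImperativeAggregate in the shared buffer layout; since sortBy is stable, the relative order within each group is preserved. A tiny hypothetical illustration:

// false sorts before true, so declarative (isImperative = false) functions come first.
val aggs = Seq(("stddev(a)", true), ("sum(b)", false), ("stddev1(c)", false))
val sorted = aggs.sortBy(_._2).map(_._1)
// sorted == Seq("sum(b)", "stddev1(c)", "stddev(a)")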
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala
index 788b04fcf8c2e..6f1f5b9e83477 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala
@@ -27,12 +27,13 @@ import org.apache.spark.util.Benchmark
  *   build/sbt "sql/test-only *BenchmarkWholeStageCodegen"
  */
 class BenchmarkWholeStageCodegen extends SparkFunSuite {
-  def testWholeStage(values: Int): Unit = {
-    val conf = new SparkConf().setMaster("local[1]").setAppName("benchmark")
-    val sc = SparkContext.getOrCreate(conf)
-    val sqlContext = SQLContext.getOrCreate(sc)
+  lazy val conf = new SparkConf().setMaster("local[1]").setAppName("benchmark")
+  lazy val sc = SparkContext.getOrCreate(conf)
+  lazy val sqlContext = SQLContext.getOrCreate(sc)
 
-    val benchmark = new Benchmark("Single Int Column Scan", values)
+  def testRangeFilterAndAggregation(values: Int): Unit = {
+
+    val benchmark = new Benchmark("range/filter/aggregation", values)
 
     benchmark.addCase("Without whole stage codegen") { iter =>
       sqlContext.setConf("spark.sql.codegen.wholeStage", "false")
@@ -46,15 +47,64 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite {
 
     /*
       Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz
-      Single Int Column Scan:            Avg Time(ms)    Avg Rate(M/s)  Relative Rate
-      -------------------------------------------------------------------------
-      Without whole stage codegen             6725.52            31.18         1.00 X
-      With whole stage codegen                2233.05            93.91         3.01 X
+      Single Int Column Scan:            Avg Time(ms)    Avg Rate(M/s)  Relative Rate
+      -------------------------------------------------------------------------------
+      Without whole stage codegen             6585.36            31.85         1.00 X
+      With whole stage codegen                 343.80           609.99        19.15 X
+    */
+    benchmark.run()
+  }
+
+  def testImperativeAggregation(values: Int): Unit = {
+
+    val benchmark = new Benchmark("aggregation", values)
+
+    benchmark.addCase("ImpAgg w/o whole stage codegen") { iter =>
+      sqlContext.setConf("spark.sql.codegen.wholeStage", "false")
+      sqlContext.range(values).groupBy().agg("id" -> "stddev").collect()
+    }
+
+    benchmark.addCase("DeclAgg w/o whole stage codegen") { iter =>
+      sqlContext.setConf("spark.sql.codegen.wholeStage", "false")
+      sqlContext.range(values).groupBy().agg("id" -> "stddev1").collect()
+    }
+
+    benchmark.addCase("ImpAgg w whole stage codegen") { iter =>
+      sqlContext.setConf("spark.sql.codegen.wholeStage", "true")
+      sqlContext.range(values).groupBy().agg("id" -> "stddev").collect()
+    }
+
+    benchmark.addCase("DeclAgg w whole stage codegen") { iter =>
+      sqlContext.setConf("spark.sql.codegen.wholeStage", "true")
+      sqlContext.range(values).groupBy().agg("id" -> "stddev1").collect()
+    }
+
+    /*
+      Before optimizing CentralMomentAgg and the generated mutable projection:
+
+      Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz
+      aggregation:                       Avg Time(ms)    Avg Rate(M/s)  Relative Rate
+      -------------------------------------------------------------------------------
+      ImpAgg w/o whole stage codegen          9047.35            11.59         1.00 X
+      DeclAgg w/o whole stage codegen         6507.27            16.11         1.39 X
+      ImpAgg w whole stage codegen            6947.30            15.09         1.30 X
+      DeclAgg w whole stage codegen           1376.74            76.16         6.57 X
+
+      After optimization:
+
+      Intel(R) Core(TM) i7-4558U CPU @ 2.80GHz
+      aggregation:                       Avg Time(ms)    Avg Rate(M/s)  Relative Rate
+      -------------------------------------------------------------------------------
+      ImpAgg w/o whole stage codegen          6159.03            17.03         1.00 X
+      DeclAgg w/o whole stage codegen         5248.69            19.98         1.17 X
+      ImpAgg w whole stage codegen            4202.30            24.95         1.47 X
+      DeclAgg w whole stage codegen           1367.34            76.69         4.50 X
      */
     benchmark.run()
   }
 
-  ignore("benchmark") {
-    testWholeStage(1024 * 1024 * 200)
+  test("benchmark") {
+    testRangeFilterAndAggregation(1024 * 1024 * 200)
+    testImperativeAggregation(1024 * 1024 * 100)
   }
 }