diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala index 315bedb12e71..c925fa8d3ebf 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala @@ -103,6 +103,12 @@ object SQLConf { .booleanConf .createWithDefault(true) + val PREFER_SORTAGGREGATE = buildConf("spark.sql.aggregate.preferSortAggregate") + .internal() + .doc("When true, prefer sort aggregate over shuffle hash aggregate.") + .booleanConf + .createWithDefault(false) + val RADIX_SORT_ENABLED = buildConf("spark.sql.sort.enableRadixSort") .internal() .doc("When true, enable use of radix sort when possible. Radix sort is much faster but " + @@ -853,6 +859,8 @@ class SQLConf extends Serializable with Logging { def preferSortMergeJoin: Boolean = getConf(PREFER_SORTMERGEJOIN) + def preferSortAggregate: Boolean = getConf(PREFER_SORTAGGREGATE) + def enableRadixSort: Boolean = getConf(RADIX_SORT_ENABLED) def defaultSizeInBytes: Long = getConf(DEFAULT_SIZE_IN_BYTES) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala index c31fd92447c0..5963ff81f411 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegenExec.scala @@ -24,7 +24,7 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.plans.physical.Partitioning import org.apache.spark.sql.catalyst.rules.Rule -import org.apache.spark.sql.execution.aggregate.HashAggregateExec +import org.apache.spark.sql.execution.aggregate.{HashAggregateExec, SortAggregateExec} import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, SortMergeJoinExec} import org.apache.spark.sql.execution.metric.SQLMetrics import org.apache.spark.sql.internal.SQLConf @@ -38,7 +38,8 @@ trait CodegenSupport extends SparkPlan { /** Prefix used in the current operator's variable names. */ private def variablePrefix: String = this match { - case _: HashAggregateExec => "agg" + case _: HashAggregateExec => "hagg" + case _: SortAggregateExec => "sagg" case _: BroadcastHashJoinExec => "bhj" case _: SortMergeJoinExec => "smj" case _: RDDScanExec => "rdd" diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala index aa789af6f812..4c347ca14dde 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggUtils.scala @@ -21,7 +21,6 @@ import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.execution.streaming.{StateStoreRestoreExec, StateStoreSaveExec} -import org.apache.spark.sql.internal.SQLConf /** * Utility functions used by the query planner to convert our plan to new aggregation code path. 
@@ -35,23 +34,14 @@ object AggUtils { initialInputBufferOffset: Int = 0, resultExpressions: Seq[NamedExpression] = Nil, child: SparkPlan): SparkPlan = { - val useHash = HashAggregateExec.supportsAggregate( - aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes)) - if (useHash) { - HashAggregateExec( - requiredChildDistributionExpressions = requiredChildDistributionExpressions, - groupingExpressions = groupingExpressions, - aggregateExpressions = aggregateExpressions, - aggregateAttributes = aggregateAttributes, - initialInputBufferOffset = initialInputBufferOffset, - resultExpressions = resultExpressions, - child = child) - } else { - val objectHashEnabled = child.sqlContext.conf.useObjectHashAggregation - val useObjectHash = ObjectHashAggregateExec.supportsAggregate(aggregateExpressions) - - if (objectHashEnabled && useObjectHash) { - ObjectHashAggregateExec( + val hashAggregateOption = { + val preferSortAggregate = child.sqlContext.conf.preferSortAggregate + val useHash = HashAggregateExec.supportsAggregate( + aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes)) + if (preferSortAggregate) { + None + } else if (useHash) { + val agg = HashAggregateExec( requiredChildDistributionExpressions = requiredChildDistributionExpressions, groupingExpressions = groupingExpressions, aggregateExpressions = aggregateExpressions, @@ -59,17 +49,36 @@ object AggUtils { initialInputBufferOffset = initialInputBufferOffset, resultExpressions = resultExpressions, child = child) + Some(agg) } else { - SortAggregateExec( - requiredChildDistributionExpressions = requiredChildDistributionExpressions, - groupingExpressions = groupingExpressions, - aggregateExpressions = aggregateExpressions, - aggregateAttributes = aggregateAttributes, - initialInputBufferOffset = initialInputBufferOffset, - resultExpressions = resultExpressions, - child = child) + val objectHashEnabled = child.sqlContext.conf.useObjectHashAggregation + val useObjectHash = ObjectHashAggregateExec.supportsAggregate(aggregateExpressions) + + if (objectHashEnabled && useObjectHash) { + val agg = ObjectHashAggregateExec( + requiredChildDistributionExpressions = requiredChildDistributionExpressions, + groupingExpressions = groupingExpressions, + aggregateExpressions = aggregateExpressions, + aggregateAttributes = aggregateAttributes, + initialInputBufferOffset = initialInputBufferOffset, + resultExpressions = resultExpressions, + child = child) + Some(agg) + } else { + None + } } } + hashAggregateOption.getOrElse { + SortAggregateExec( + requiredChildDistributionExpressions = requiredChildDistributionExpressions, + groupingExpressions = groupingExpressions, + aggregateExpressions = aggregateExpressions, + aggregateAttributes = aggregateAttributes, + initialInputBufferOffset = initialInputBufferOffset, + resultExpressions = resultExpressions, + child = child) + } } def planAggregateWithoutDistinct( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregateCodegenHelper.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregateCodegenHelper.scala new file mode 100644 index 000000000000..af97dc2b9658 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregateCodegenHelper.scala @@ -0,0 +1,226 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.aggregate + +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.aggregate._ +import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.execution.CodegenSupport +import org.apache.spark.sql.types.StructType + +trait AggregateCodegenHelper { + self: AggregateExec with CodegenSupport => + + protected val groupingAttributes = groupingExpressions.map(_.toAttribute) + protected val groupingKeySchema = StructType.fromAttributes(groupingAttributes) + protected val bufferSchema = StructType.fromAttributes(aggregateBufferAttributes) + + protected lazy val declFunctions = + aggregateExpressions.map(_.aggregateFunction.asInstanceOf[DeclarativeAggregate]) + + protected var bufVars: Seq[ExprCode] = _ + + override def usedInputs: AttributeSet = inputSet + + protected def generateBufVarsInitCode(ctx: CodegenContext): String = { + // generate variables for aggregation buffer + val initExpr = declFunctions.flatMap(f => f.initialValues) + bufVars = initExpr.map { e => + val isNull = ctx.freshName("bufIsNull") + val value = ctx.freshName("bufValue") + ctx.addMutableState("boolean", isNull, "") + ctx.addMutableState(ctx.javaType(e.dataType), value, "") + // The initial expression should not access any column + val ev = e.genCode(ctx) + val initVars = s""" + | $isNull = ${ev.isNull}; + | $value = ${ev.value}; + """.stripMargin + ExprCode(ev.code + initVars, isNull, value) + } + evaluateVariables(bufVars) + } + + protected def generateBufVarsEvalCode(ctx: CodegenContext): String = { + val initAgg = ctx.freshName("initAgg") + ctx.addMutableState("boolean", initAgg, s"$initAgg = false;") + + val initBufVar = generateBufVarsInitCode(ctx) + + // generate variables for output + val (resultVars, genResult) = if (modes.contains(Final) || modes.contains(Complete)) { + // evaluate aggregate results + ctx.currentVars = bufVars + val aggResults = declFunctions.map(_.evaluateExpression).map { e => + BindReferences.bindReference(e, aggregateBufferAttributes).genCode(ctx) + } + val evaluateAggResults = evaluateVariables(aggResults) + // evaluate result expressions + ctx.currentVars = aggResults + val resultVars = resultExpressions.map { e => + BindReferences.bindReference(e, aggregateAttributes).genCode(ctx) + } + (resultVars, s""" + |$evaluateAggResults + |${evaluateVariables(resultVars)} + """.stripMargin) + } else if (modes.contains(Partial) || modes.contains(PartialMerge)) { + // output the aggregate buffer directly + (bufVars, "") + } else { + // no aggregate function, the result should be literals + val resultVars = resultExpressions.map(_.genCode(ctx)) + (resultVars, evaluateVariables(resultVars)) + } + + val doAgg = ctx.freshName("doAggregateWithoutKey") + ctx.addNewFunction(doAgg, + s""" + | private void $doAgg() throws java.io.IOException { + | // initialize aggregation buffer + | $initBufVar + | + | 
${child.asInstanceOf[CodegenSupport].produce(ctx, this)} + | } + """.stripMargin) + + val numOutput = metricTerm(ctx, "numOutputRows") + val aggTime = metricTerm(ctx, "aggTime") + val beforeAgg = ctx.freshName("beforeAgg") + s""" + | while (!$initAgg) { + | $initAgg = true; + | long $beforeAgg = System.nanoTime(); + | $doAgg(); + | $aggTime.add((System.nanoTime() - $beforeAgg) / 1000000); + | + | // output the result + | ${genResult.trim} + | + | $numOutput.add(1); + | ${consume(ctx, resultVars).trim} + | } + """.stripMargin + } + + protected def generateBufVarsUpdateCode(ctx: CodegenContext, input: Seq[ExprCode]): String = { + // only have DeclarativeAggregate + val functions = aggregateExpressions.map(_.aggregateFunction.asInstanceOf[DeclarativeAggregate]) + val inputAttrs = functions.flatMap(_.aggBufferAttributes) ++ child.output + val updateExpr = aggregateExpressions.flatMap { e => + e.mode match { + case Partial | Complete => + e.aggregateFunction.asInstanceOf[DeclarativeAggregate].updateExpressions + case PartialMerge | Final => + e.aggregateFunction.asInstanceOf[DeclarativeAggregate].mergeExpressions + } + } + ctx.currentVars = bufVars ++ input + val boundUpdateExpr = updateExpr.map(BindReferences.bindReference(_, inputAttrs)) + val subExprs = ctx.subexpressionEliminationForWholeStageCodegen(boundUpdateExpr) + val effectiveCodes = subExprs.codes.mkString("\n") + val aggVals = ctx.withSubExprEliminationExprs(subExprs.states) { + boundUpdateExpr.map(_.genCode(ctx)) + } + // aggregate buffer should be updated atomic + val updates = aggVals.zipWithIndex.map { case (ev, i) => + s""" + | ${bufVars(i).isNull} = ${ev.isNull}; + | ${bufVars(i).value} = ${ev.value}; + """.stripMargin + } + s""" + | // do aggregate + | // common sub-expressions + | $effectiveCodes + | // evaluate aggregate function + | ${evaluateVariables(aggVals)} + | // update aggregation buffer + | ${updates.mkString("\n").trim} + """.stripMargin + } + + /** + * This is called by generated Java class, should be public. + */ + def createUnsafeJoiner(): UnsafeRowJoiner = { + GenerateUnsafeRowJoiner.create(groupingKeySchema, bufferSchema) + } + + /** + * Generate the code for output. 
+ */ + protected def generateResultCode( + ctx: CodegenContext, + keyTerm: String, + bufferTerm: String, + self: String): String = { + if (modes.contains(Final) || modes.contains(Complete)) { + // generate output using resultExpressions + ctx.currentVars = null + ctx.INPUT_ROW = keyTerm + val keyVars = groupingExpressions.zipWithIndex.map { case (e, i) => + BoundReference(i, e.dataType, e.nullable).genCode(ctx) + } + val evaluateKeyVars = evaluateVariables(keyVars) + ctx.INPUT_ROW = bufferTerm + val bufferVars = aggregateBufferAttributes.zipWithIndex.map { case (e, i) => + BoundReference(i, e.dataType, e.nullable).genCode(ctx) + } + val evaluateBufferVars = evaluateVariables(bufferVars) + // evaluate the aggregation result + ctx.currentVars = bufferVars + val aggResults = declFunctions.map(_.evaluateExpression).map { e => + BindReferences.bindReference(e, aggregateBufferAttributes).genCode(ctx) + } + val evaluateAggResults = evaluateVariables(aggResults) + // generate the final result + ctx.currentVars = keyVars ++ aggResults + val inputAttrs = groupingAttributes ++ aggregateAttributes + val resultVars = resultExpressions.map { e => + BindReferences.bindReference(e, inputAttrs).genCode(ctx) + } + s""" + $evaluateKeyVars + $evaluateBufferVars + $evaluateAggResults + ${consume(ctx, resultVars)} + """ + + } else if (modes.contains(Partial) || modes.contains(PartialMerge)) { + // This should be the last operator in a stage, we should output UnsafeRow directly + val joinerTerm = ctx.freshName("unsafeRowJoiner") + ctx.addMutableState(classOf[UnsafeRowJoiner].getName, joinerTerm, + s"$joinerTerm = $self.createUnsafeJoiner();") + val resultRow = ctx.freshName("resultRow") + s""" + UnsafeRow $resultRow = $joinerTerm.join($keyTerm, $bufferTerm); + ${consume(ctx, null, resultRow)} + """ + + } else { + // generate result based on grouping key + ctx.INPUT_ROW = keyTerm + ctx.currentVars = null + val eval = resultExpressions.map{ e => + BindReferences.bindReference(e, groupingAttributes).genCode(ctx) + } + consume(ctx, eval) + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregateExec.scala new file mode 100644 index 000000000000..e79249603482 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregateExec.scala @@ -0,0 +1,44 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.execution.aggregate + +import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression +import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.execution.UnaryExecNode + +/** + * A base class for aggregate implementation. + */ +abstract class AggregateExec extends UnaryExecNode { + + def requiredChildDistributionExpressions: Option[Seq[Expression]] + def groupingExpressions: Seq[NamedExpression] + def aggregateExpressions: Seq[AggregateExpression] + def aggregateAttributes: Seq[Attribute] + def initialInputBufferOffset: Int + def resultExpressions: Seq[NamedExpression] + def child: SparkPlan + + // all the mode of aggregate expressions + protected val modes = aggregateExpressions.map(_.mode).distinct + + protected val aggregateBufferAttributes = { + aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes) + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala index 68c8e6ce62cb..22d5d5c96013 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala @@ -43,11 +43,7 @@ case class HashAggregateExec( initialInputBufferOffset: Int, resultExpressions: Seq[NamedExpression], child: SparkPlan) - extends UnaryExecNode with CodegenSupport { - - private[this] val aggregateBufferAttributes = { - aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes) - } + extends AggregateExec with CodegenSupport with AggregateCodegenHelper { require(HashAggregateExec.supportsAggregate(aggregateBufferAttributes)) @@ -127,11 +123,6 @@ case class HashAggregateExec( } } - // all the mode of aggregate expressions - private val modes = aggregateExpressions.map(_.mode).distinct - - override def usedInputs: AttributeSet = inputSet - override def supportCodegen: Boolean = { // ImperativeAggregate is not supported right now !aggregateExpressions.exists(_.aggregateFunction.isInstanceOf[ImperativeAggregate]) @@ -157,133 +148,16 @@ case class HashAggregateExec( } } - // The variables used as aggregation buffer - private var bufVars: Seq[ExprCode] = _ - private def doProduceWithoutKeys(ctx: CodegenContext): String = { - val initAgg = ctx.freshName("initAgg") - ctx.addMutableState("boolean", initAgg, s"$initAgg = false;") - - // generate variables for aggregation buffer - val functions = aggregateExpressions.map(_.aggregateFunction.asInstanceOf[DeclarativeAggregate]) - val initExpr = functions.flatMap(f => f.initialValues) - bufVars = initExpr.map { e => - val isNull = ctx.freshName("bufIsNull") - val value = ctx.freshName("bufValue") - ctx.addMutableState("boolean", isNull, "") - ctx.addMutableState(ctx.javaType(e.dataType), value, "") - // The initial expression should not access any column - val ev = e.genCode(ctx) - val initVars = s""" - | $isNull = ${ev.isNull}; - | $value = ${ev.value}; - """.stripMargin - ExprCode(ev.code + initVars, isNull, value) - } - val initBufVar = evaluateVariables(bufVars) - - // generate variables for output - val (resultVars, genResult) = if (modes.contains(Final) || modes.contains(Complete)) { - // evaluate aggregate results - ctx.currentVars = bufVars - val aggResults = functions.map(_.evaluateExpression).map { e => - BindReferences.bindReference(e, 
aggregateBufferAttributes).genCode(ctx) - } - val evaluateAggResults = evaluateVariables(aggResults) - // evaluate result expressions - ctx.currentVars = aggResults - val resultVars = resultExpressions.map { e => - BindReferences.bindReference(e, aggregateAttributes).genCode(ctx) - } - (resultVars, s""" - |$evaluateAggResults - |${evaluateVariables(resultVars)} - """.stripMargin) - } else if (modes.contains(Partial) || modes.contains(PartialMerge)) { - // output the aggregate buffer directly - (bufVars, "") - } else { - // no aggregate function, the result should be literals - val resultVars = resultExpressions.map(_.genCode(ctx)) - (resultVars, evaluateVariables(resultVars)) - } - - val doAgg = ctx.freshName("doAggregateWithoutKey") - ctx.addNewFunction(doAgg, - s""" - | private void $doAgg() throws java.io.IOException { - | // initialize aggregation buffer - | $initBufVar - | - | ${child.asInstanceOf[CodegenSupport].produce(ctx, this)} - | } - """.stripMargin) - - val numOutput = metricTerm(ctx, "numOutputRows") - val aggTime = metricTerm(ctx, "aggTime") - val beforeAgg = ctx.freshName("beforeAgg") - s""" - | while (!$initAgg) { - | $initAgg = true; - | long $beforeAgg = System.nanoTime(); - | $doAgg(); - | $aggTime.add((System.nanoTime() - $beforeAgg) / 1000000); - | - | // output the result - | ${genResult.trim} - | - | $numOutput.add(1); - | ${consume(ctx, resultVars).trim} - | } - """.stripMargin + generateBufVarsEvalCode(ctx) } protected override val shouldStopRequired = false private def doConsumeWithoutKeys(ctx: CodegenContext, input: Seq[ExprCode]): String = { - // only have DeclarativeAggregate - val functions = aggregateExpressions.map(_.aggregateFunction.asInstanceOf[DeclarativeAggregate]) - val inputAttrs = functions.flatMap(_.aggBufferAttributes) ++ child.output - val updateExpr = aggregateExpressions.flatMap { e => - e.mode match { - case Partial | Complete => - e.aggregateFunction.asInstanceOf[DeclarativeAggregate].updateExpressions - case PartialMerge | Final => - e.aggregateFunction.asInstanceOf[DeclarativeAggregate].mergeExpressions - } - } - ctx.currentVars = bufVars ++ input - val boundUpdateExpr = updateExpr.map(BindReferences.bindReference(_, inputAttrs)) - val subExprs = ctx.subexpressionEliminationForWholeStageCodegen(boundUpdateExpr) - val effectiveCodes = subExprs.codes.mkString("\n") - val aggVals = ctx.withSubExprEliminationExprs(subExprs.states) { - boundUpdateExpr.map(_.genCode(ctx)) - } - // aggregate buffer should be updated atomic - val updates = aggVals.zipWithIndex.map { case (ev, i) => - s""" - | ${bufVars(i).isNull} = ${ev.isNull}; - | ${bufVars(i).value} = ${ev.value}; - """.stripMargin - } - s""" - | // do aggregate - | // common sub-expressions - | $effectiveCodes - | // evaluate aggregate function - | ${evaluateVariables(aggVals)} - | // update aggregation buffer - | ${updates.mkString("\n").trim} - """.stripMargin + generateBufVarsUpdateCode(ctx, input) } - private val groupingAttributes = groupingExpressions.map(_.toAttribute) - private val groupingKeySchema = StructType.fromAttributes(groupingAttributes) - private val declFunctions = aggregateExpressions.map(_.aggregateFunction) - .filter(_.isInstanceOf[DeclarativeAggregate]) - .map(_.asInstanceOf[DeclarativeAggregate]) - private val bufferSchema = StructType.fromAttributes(aggregateBufferAttributes) - // The name for Fast HashMap private var fastHashMapTerm: String = _ private var isFastHashMapEnabled: Boolean = false @@ -327,13 +201,6 @@ case class HashAggregateExec( initialBuffer } - /** - * 
This is called by generated Java class, should be public. - */ - def createUnsafeJoiner(): UnsafeRowJoiner = { - GenerateUnsafeRowJoiner.create(groupingKeySchema, bufferSchema) - } - /** * Called by generated Java class to finish the aggregate and return a KVIterator. */ @@ -416,68 +283,6 @@ case class HashAggregateExec( } } - /** - * Generate the code for output. - */ - private def generateResultCode( - ctx: CodegenContext, - keyTerm: String, - bufferTerm: String, - plan: String): String = { - if (modes.contains(Final) || modes.contains(Complete)) { - // generate output using resultExpressions - ctx.currentVars = null - ctx.INPUT_ROW = keyTerm - val keyVars = groupingExpressions.zipWithIndex.map { case (e, i) => - BoundReference(i, e.dataType, e.nullable).genCode(ctx) - } - val evaluateKeyVars = evaluateVariables(keyVars) - ctx.INPUT_ROW = bufferTerm - val bufferVars = aggregateBufferAttributes.zipWithIndex.map { case (e, i) => - BoundReference(i, e.dataType, e.nullable).genCode(ctx) - } - val evaluateBufferVars = evaluateVariables(bufferVars) - // evaluate the aggregation result - ctx.currentVars = bufferVars - val aggResults = declFunctions.map(_.evaluateExpression).map { e => - BindReferences.bindReference(e, aggregateBufferAttributes).genCode(ctx) - } - val evaluateAggResults = evaluateVariables(aggResults) - // generate the final result - ctx.currentVars = keyVars ++ aggResults - val inputAttrs = groupingAttributes ++ aggregateAttributes - val resultVars = resultExpressions.map { e => - BindReferences.bindReference(e, inputAttrs).genCode(ctx) - } - s""" - $evaluateKeyVars - $evaluateBufferVars - $evaluateAggResults - ${consume(ctx, resultVars)} - """ - - } else if (modes.contains(Partial) || modes.contains(PartialMerge)) { - // This should be the last operator in a stage, we should output UnsafeRow directly - val joinerTerm = ctx.freshName("unsafeRowJoiner") - ctx.addMutableState(classOf[UnsafeRowJoiner].getName, joinerTerm, - s"$joinerTerm = $plan.createUnsafeJoiner();") - val resultRow = ctx.freshName("resultRow") - s""" - UnsafeRow $resultRow = $joinerTerm.join($keyTerm, $bufferTerm); - ${consume(ctx, null, resultRow)} - """ - - } else { - // generate result based on grouping key - ctx.INPUT_ROW = keyTerm - ctx.currentVars = null - val eval = resultExpressions.map{ e => - BindReferences.bindReference(e, groupingAttributes).genCode(ctx) - } - consume(ctx, eval) - } - } - /** * A required check for any fast hash map implementation (basically the common requirements * for row-based and vectorized). 
@@ -696,7 +501,6 @@ case class HashAggregateExec( } private def doConsumeWithKeys(ctx: CodegenContext, input: Seq[ExprCode]): String = { - // create grouping key ctx.currentVars = input val unsafeRowKeyCode = GenerateUnsafeProjection.createCode( @@ -755,7 +559,6 @@ case class HashAggregateExec( } } - def updateRowInFastHashMap(isVectorized: Boolean): Option[String] = { ctx.INPUT_ROW = fastRowBuffer val boundUpdateExpr = updateExpr.map(BindReferences.bindReference(_, inputAttr)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortAggregateExec.scala index be3198b8e7d8..153acf6137d6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortAggregateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortAggregateExec.scala @@ -22,9 +22,11 @@ import org.apache.spark.sql.catalyst.errors._ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ +import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.plans.physical._ -import org.apache.spark.sql.execution.{SparkPlan, UnaryExecNode} +import org.apache.spark.sql.execution.{CodegenSupport, SparkPlan} import org.apache.spark.sql.execution.metric.SQLMetrics +import org.apache.spark.sql.types.StructType import org.apache.spark.util.Utils /** @@ -38,11 +40,7 @@ case class SortAggregateExec( initialInputBufferOffset: Int, resultExpressions: Seq[NamedExpression], child: SparkPlan) - extends UnaryExecNode { - - private[this] val aggregateBufferAttributes = { - aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes) - } + extends AggregateExec with CodegenSupport with AggregateCodegenHelper { override def producedAttributes: AttributeSet = AttributeSet(aggregateAttributes) ++ @@ -50,7 +48,8 @@ case class SortAggregateExec( AttributeSet(aggregateBufferAttributes) override lazy val metrics = Map( - "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows")) + "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"), + "aggTime" -> SQLMetrics.createTimingMetric(sparkContext, "aggregate time")) override def output: Seq[Attribute] = resultExpressions.map(_.toAttribute) @@ -106,6 +105,126 @@ case class SortAggregateExec( } } + override def supportCodegen: Boolean = { + val aggregationBufferSchema = StructType.fromAttributes(aggregateBufferAttributes) + aggregationBufferSchema.forall(f => UnsafeRow.isMutable(f.dataType)) && + // ImperativeAggregate is not supported right now + !aggregateExpressions.exists(_.aggregateFunction.isInstanceOf[ImperativeAggregate]) + } + + override def inputRDDs(): Seq[RDD[InternalRow]] = { + child.asInstanceOf[CodegenSupport].inputRDDs() + } + + override protected def doProduce(ctx: CodegenContext): String = { + if (groupingExpressions.isEmpty) { + doProduceWithoutKeys(ctx) + } else { + doProduceWithKeys(ctx) + } + } + + override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = { + if (groupingExpressions.isEmpty) { + doConsumeWithoutKeys(ctx, input) + } else { + doConsumeWithKeys(ctx, input) + } + } + + private def doProduceWithoutKeys(ctx: CodegenContext): String = { + generateBufVarsEvalCode(ctx) + } + + private def doConsumeWithoutKeys(ctx: CodegenContext, input: Seq[ExprCode]): String = { + generateBufVarsUpdateCode(ctx, 
input) + } + + // The grouping keys of a current partition + private var currentGroupingKeyTerm: String = _ + + // The output code for a single partition + private var outputCode: String = _ + + private def generateOutputCode(ctx: CodegenContext): String = { + ctx.currentVars = bufVars + val bufferEv = GenerateUnsafeProjection.createCode( + ctx, aggregateBufferAttributes.map( + BindReferences.bindReference[Expression](_, aggregateBufferAttributes))) + val sortAggregate = ctx.addReferenceObj("sortAggregate", this) + s""" + |${bufferEv.code} + |${generateResultCode(ctx, currentGroupingKeyTerm, bufferEv.value, sortAggregate)} + """.stripMargin + } + + private def doProduceWithKeys(ctx: CodegenContext): String = { + val numOutput = metricTerm(ctx, "numOutputRows") + s""" + |${child.asInstanceOf[CodegenSupport].produce(ctx, this)} + | + |if ($currentGroupingKeyTerm != null) { + | // for the last aggregation + | do { + | $numOutput.add(1); + | $outputCode + | $currentGroupingKeyTerm = null; + | + | if (shouldStop()) return; + | } while (false); + |} + """.stripMargin + } + + def doConsumeWithKeys(ctx: CodegenContext, input: Seq[ExprCode]): String = { + // Create the grouping keys of a current partition + currentGroupingKeyTerm = ctx.freshName("currentGroupingKey") + ctx.addMutableState("UnsafeRow", currentGroupingKeyTerm, s"$currentGroupingKeyTerm = null;") + + // Generate buffer-handling code + val initBufVarsCodes = generateBufVarsInitCode(ctx) + val updateBufVarsCode = generateBufVarsUpdateCode(ctx, input) + + // Create grouping keys for input + ctx.currentVars = input + val groupingEv = GenerateUnsafeProjection.createCode( + ctx, groupingExpressions.map(BindReferences.bindReference[Expression](_, child.output))) + val groupingKeys = groupingEv.value + + // Generate code for output + outputCode = generateOutputCode(ctx) + + val numOutput = metricTerm(ctx, "numOutputRows") + s""" + |// generate grouping keys + |${groupingEv.code.trim} + | + |if ($currentGroupingKeyTerm == null) { + | $currentGroupingKeyTerm = $groupingKeys.copy(); + | // init aggregation buffer vars + | $initBufVarsCodes + | // do aggregation + | $updateBufVarsCode + |} else { + | if ($currentGroupingKeyTerm.equals($groupingKeys)) { + | $updateBufVarsCode + | } else { + | do { + | $numOutput.add(1); + | $outputCode + | } while (false); + | + | // init buffer vars for a next partition + | $currentGroupingKeyTerm = $groupingKeys.copy(); + | $initBufVarsCodes + | $updateBufVarsCode + | + | if (shouldStop()) return; + | } + |} + """.stripMargin + } + override def simpleString: String = toString(verbose = false) override def verboseString: String = toString(verbose = true) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala index e7079120bb7d..fcf2438b1072 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala @@ -29,7 +29,15 @@ case class Fact(date: Int, hour: Int, minute: Int, room_name: String, temp: Doub class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { import testImplicits._ - test("groupBy") { + private def testAgg(testName: String)(f: => Unit): Unit = { + test(testName) { + Seq("true", "false").map { preferAggregate => + withSQLConf(SQLConf.PREFER_SORTAGGREGATE.key -> preferAggregate) { f } + } + } + } + + testAgg("groupBy") { checkAnswer( testData2.groupBy("a").agg(sum($"b")), 
Seq(Row(1, 3), Row(2, 3), Row(3, 3)) @@ -87,7 +95,7 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { ) } - test("SPARK-17124 agg should be ordering preserving") { + testAgg("SPARK-17124 agg should be ordering preserving") { val df = spark.range(2) val ret = df.groupBy("id").agg("id" -> "sum", "id" -> "count", "id" -> "min") assert(ret.schema.map(_.name) == Seq("id", "sum(id)", "count(id)", "min(id)")) @@ -97,7 +105,7 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { ) } - test("SPARK-18952: regexes fail codegen when used as keys due to bad forward-slash escapes") { + testAgg("SPARK-18952: regexes fail codegen when used as keys due to bad forward-slash escapes") { val df = Seq(("some[thing]", "random-string")).toDF("key", "val") checkAnswer( @@ -106,7 +114,7 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { ) } - test("rollup") { + testAgg("rollup") { checkAnswer( courseSales.rollup("course", "year").sum("earnings"), Row("Java", 2012, 20000.0) :: @@ -119,7 +127,7 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { ) } - test("cube") { + testAgg("cube") { checkAnswer( courseSales.cube("course", "year").sum("earnings"), Row("Java", 2012, 20000.0) :: @@ -143,7 +151,7 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { assert(cube0.where("date IS NULL").count > 0) } - test("grouping and grouping_id") { + testAgg("grouping and grouping_id") { checkAnswer( courseSales.cube("course", "year") .agg(grouping("course"), grouping("year"), grouping_id("course", "year")), @@ -166,7 +174,7 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { } } - test("grouping/grouping_id inside window function") { + testAgg("grouping/grouping_id inside window function") { val w = Window.orderBy(sum("earnings")) checkAnswer( @@ -186,7 +194,7 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { ) } - test("rollup overlapping columns") { + testAgg("rollup overlapping columns") { checkAnswer( testData2.rollup($"a" + $"b" as "foo", $"b" as "bar").agg(sum($"a" - $"b") as "foo"), Row(2, 1, 0) :: Row(3, 2, -1) :: Row(3, 1, 1) :: Row(4, 2, 0) :: Row(4, 1, 2) :: Row(5, 2, 1) @@ -202,7 +210,7 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { ) } - test("cube overlapping columns") { + testAgg("cube overlapping columns") { checkAnswer( testData2.cube($"a" + $"b", $"b").agg(sum($"a" - $"b")), Row(2, 1, 0) :: Row(3, 2, -1) :: Row(3, 1, 1) :: Row(4, 2, 0) :: Row(4, 1, 2) :: Row(5, 2, 1) @@ -220,7 +228,7 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { ) } - test("spark.sql.retainGroupColumns config") { + testAgg("spark.sql.retainGroupColumns config") { checkAnswer( testData2.groupBy("a").agg(sum($"b")), Seq(Row(1, 3), Row(2, 3), Row(3, 3)) @@ -234,21 +242,21 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { spark.conf.set(SQLConf.DATAFRAME_RETAIN_GROUP_COLUMNS.key, true) } - test("agg without groups") { + testAgg("agg without groups") { checkAnswer( testData2.agg(sum('b)), Row(9) ) } - test("agg without groups and functions") { + testAgg("agg without groups and functions") { checkAnswer( testData2.agg(lit(1)), Row(1) ) } - test("average") { + testAgg("average") { checkAnswer( testData2.agg(avg('a), mean('a)), Row(2.0, 2.0)) @@ -274,7 +282,7 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { Row(new java.math.BigDecimal(2.0), new java.math.BigDecimal(6)) :: Nil) } - 
test("null average") { + testAgg("null average") { checkAnswer( testData3.agg(avg('b)), Row(2.0)) @@ -288,7 +296,7 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { Row(2.0, 2.0)) } - test("zero average") { + testAgg("zero average") { val emptyTableData = Seq.empty[(Int, Int)].toDF("a", "b") checkAnswer( emptyTableData.agg(avg('a)), @@ -299,7 +307,7 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { Row(null, null)) } - test("count") { + testAgg("count") { assert(testData2.count() === testData2.rdd.map(_ => 1).count()) checkAnswer( @@ -307,7 +315,7 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { Row(6, 6.0)) } - test("null count") { + testAgg("null count") { checkAnswer( testData3.groupBy('a).agg(count('b)), Seq(Row(1, 0), Row(2, 1)) @@ -329,7 +337,7 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { ) } - test("multiple column distinct count") { + testAgg("multiple column distinct count") { val df1 = Seq( ("a", "b", "c"), ("a", "b", "c"), @@ -354,14 +362,14 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { ) } - test("zero count") { + testAgg("zero count") { val emptyTableData = Seq.empty[(Int, Int)].toDF("a", "b") checkAnswer( emptyTableData.agg(count('a), sumDistinct('a)), // non-partial Row(0, null)) } - test("stddev") { + testAgg("stddev") { val testData2ADev = math.sqrt(4.0 / 5.0) checkAnswer( testData2.agg(stddev('a), stddev_pop('a), stddev_samp('a)), @@ -371,28 +379,28 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { Row(testData2ADev, math.sqrt(4 / 6.0), testData2ADev)) } - test("zero stddev") { + testAgg("zero stddev") { val emptyTableData = Seq.empty[(Int, Int)].toDF("a", "b") checkAnswer( emptyTableData.agg(stddev('a), stddev_pop('a), stddev_samp('a)), Row(null, null, null)) } - test("zero sum") { + testAgg("zero sum") { val emptyTableData = Seq.empty[(Int, Int)].toDF("a", "b") checkAnswer( emptyTableData.agg(sum('a)), Row(null)) } - test("zero sum distinct") { + testAgg("zero sum distinct") { val emptyTableData = Seq.empty[(Int, Int)].toDF("a", "b") checkAnswer( emptyTableData.agg(sumDistinct('a)), Row(null)) } - test("moments") { + testAgg("moments") { val absTol = 1e-8 val sparkVariance = testData2.agg(variance('a)) @@ -411,7 +419,7 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { checkAggregatesWithTol(sparkKurtosis, Row(-1.5), absTol) } - test("zero moments") { + testAgg("zero moments") { val input = Seq((1, 2)).toDF("a", "b") checkAnswer( input.agg(stddev('a), stddev_samp('a), stddev_pop('a), variance('a), @@ -433,7 +441,7 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { Double.NaN, Double.NaN)) } - test("null moments") { + testAgg("null moments") { val emptyTableData = Seq.empty[(Int, Int)].toDF("a", "b") checkAnswer( @@ -450,7 +458,7 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { Row(null, null, null, null, null)) } - test("collect functions") { + testAgg("collect functions") { val df = Seq((1, 2), (2, 2), (3, 4)).toDF("a", "b") checkAnswer( df.select(collect_list($"a"), collect_list($"b")), @@ -462,7 +470,7 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { ) } - test("collect functions structs") { + testAgg("collect functions structs") { val df = Seq((1, 2, 2), (2, 2, 2), (3, 4, 1)) .toDF("a", "x", "y") .select($"a", struct($"x", $"y").as("b")) @@ -476,7 +484,7 @@ class DataFrameAggregateSuite extends QueryTest with 
SharedSQLContext { ) } - test("collect_set functions cannot have maps") { + testAgg("collect_set functions cannot have maps") { val df = Seq((1, 3, 0), (2, 3, 0), (3, 4, 1)) .toDF("a", "x", "y") .select($"a", map($"x", $"y").as("b")) @@ -486,7 +494,7 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { assert(error.message.contains("collect_set() cannot have map type data")) } - test("SPARK-17641: collect functions should not collect null values") { + testAgg("SPARK-17641: collect functions should not collect null values") { val df = Seq(("1", 2), (null, 2), ("1", 4)).toDF("a", "b") checkAnswer( df.select(collect_list($"a"), collect_list($"b")), @@ -498,7 +506,7 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { ) } - test("SPARK-14664: Decimal sum/avg over window should work.") { + testAgg("SPARK-14664: Decimal sum/avg over window should work.") { checkAnswer( spark.sql("select sum(a) over () from values 1.0, 2.0, 3.0 T(a)"), Row(6.0) :: Row(6.0) :: Row(6.0) :: Nil) @@ -507,7 +515,7 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { Row(2.0) :: Row(2.0) :: Row(2.0) :: Nil) } - test("SQL decimal test (used for catching certain demical handling bugs in aggregates)") { + testAgg("SQL decimal test (used for catching certain demical handling bugs in aggregates)") { checkAnswer( decimalData.groupBy('a cast DecimalType(10, 2)).agg(avg('b cast DecimalType(10, 2))), Seq(Row(new java.math.BigDecimal(1.0), new java.math.BigDecimal(1.5)), @@ -515,7 +523,7 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { Row(new java.math.BigDecimal(3.0), new java.math.BigDecimal(1.5)))) } - test("SPARK-17616: distinct aggregate combined with a non-partial aggregate") { + testAgg("SPARK-17616: distinct aggregate combined with a non-partial aggregate") { val df = Seq((1, 3, "a"), (1, 2, "b"), (3, 4, "c"), (3, 4, "c"), (3, 5, "d")) .toDF("x", "y", "z") checkAnswer( @@ -523,7 +531,7 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { Seq(Row(1, 2, Seq("a", "b")), Row(3, 2, Seq("c", "c", "d")))) } - test("SPARK-18004 limit + aggregates") { + testAgg("SPARK-18004 limit + aggregates") { val df = Seq(("a", 1), ("b", 2), ("c", 1), ("d", 5)).toDF("id", "value") val limit2Df = df.limit(2) checkAnswer( @@ -531,7 +539,7 @@ class DataFrameAggregateSuite extends QueryTest with SharedSQLContext { limit2Df.select($"id")) } - test("SPARK-17237 remove backticks in a pivot result schema") { + testAgg("SPARK-17237 remove backticks in a pivot result schema") { val df = Seq((2, 3, 4), (3, 4, 5)).toDF("a", "x", "y") checkAnswer( df.groupBy("a").pivot("x").agg(count("y"), avg("y")).na.fill(0), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala index 02ccebd22bdf..5fdbe648e7c5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala @@ -37,6 +37,20 @@ class PlannerSuite extends SharedSQLContext { setupTestData() + test("Use sort aggregate if PREFER_SORTAGGREGATE is true") { + withSQLConf(SQLConf.PREFER_SORTAGGREGATE.key -> "true") { + val planner = spark.sessionState.planner + import planner._ + val query = testData.groupBy('value).agg(count('key)).queryExecution.analyzed + val planned = Aggregation(query).headOption + assert(planned.nonEmpty, s"An input query has no aggregation: $query") + planned.foreach { 
planned => + val aggregations = planned.collect { case n if n.nodeName contains "SortAggregate" => n } + assert(aggregations.size > 0, "") + } + } + } + private def testPartialAggregationPlan(query: LogicalPlan): Unit = { val planner = spark.sessionState.planner import planner._ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/MiscBenchmark.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/MiscBenchmark.scala index 01773c238b0d..91b64b437bf4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/MiscBenchmark.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/MiscBenchmark.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.execution.benchmark +import org.apache.spark.sql.functions._ import org.apache.spark.util.Benchmark /** @@ -28,6 +29,63 @@ import org.apache.spark.util.Benchmark */ class MiscBenchmark extends BenchmarkBase { + ignore("sort aggregate") { + import sparkSession.implicits._ + val preferSortAgg = "spark.sql.aggregate.preferSortAggregate" + val currnetValue = if (sparkSession.conf.contains(preferSortAgg)) { + Some(sparkSession.conf.get(preferSortAgg)) + } else { + None + } + + // Force a planner to use sort-based aggregate + sparkSession.conf.set(preferSortAgg, "true") + + try { + val N = 1L << 23 + + /* + range/limit/sum: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ---------------------------------------------------------------------------------------------- + range/limit/sum wholestage off 617 / 617 13.6 73.5 1.0X + range/limit/sum wholestage on 70 / 92 120.2 8.3 8.8X + */ + runBenchmark("range/limit/sum", N) { + sparkSession.range(N).groupBy().sum().collect() + } + + /* + aggregate non-sorted data: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ---------------------------------------------------------------------------------------------- + non-sorted data wholestage off 2540 / 2735 3.3 302.8 1.0X + non-sorted data wholestage on 1226 / 1528 6.8 146.1 2.1X + */ + val inputDf = sparkSession.range(N).selectExpr( + "id % 1024 AS key", "rand() AS value1", "rand() AS value2", "rand() AS value3") + runBenchmark("non-sorted data", N) { + inputDf.filter("value1 > 0.1").groupBy($"key") + .agg(sum("value1"), sum("value2"), avg("value3")).collect() + } + + // Sort and cache input data + val cachedDf = inputDf.sort($"key").cache + cachedDf.queryExecution.executedPlan.foreach(_ => {}) + + /* + aggregate cached and sorted data: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative + ---------------------------------------------------------------------------------------------- + cached and sorted data wholestage off 1455 / 1586 5.8 173.4 1.0X + cached and sorted data wholestage on 663 / 767 12.7 79.0 2.2X + */ + runBenchmark("cached and sorted data", N) { + cachedDf.filter("value1 > 0.1").groupBy($"key") + .agg(sum("value1"), sum("value2"), avg("value3")).collect() + } + } finally { + currnetValue.foreach(sparkSession.conf.set(preferSortAgg, _)) + } + } + ignore("filter & aggregate without group") { val N = 500L << 22 runBenchmark("range/filter/sum", N) {
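
Note: the whole-stage-codegen path added to SortAggregateExec above boils down to a streaming group-by over already-sorted input. Below is a minimal plain-Scala sketch of that control flow (not the generated Java; the tuple types and the sum-only buffer are illustrative assumptions), to make the structure of doConsumeWithKeys/doProduceWithKeys easier to follow.

```scala
// Mirrors the generated loop: track the current grouping key and an aggregation buffer,
// update the buffer while the key repeats, emit a result row on key change, and emit the
// last group once the partition is exhausted (the "for the last aggregation" block).
def sortAggregateSum(sortedInput: Iterator[(Int, Long)]): Seq[(Int, Long)] = {
  val out = scala.collection.mutable.ArrayBuffer.empty[(Int, Long)]
  var currentKey: Option[Int] = None   // corresponds to currentGroupingKeyTerm
  var buffer = 0L                      // corresponds to bufVars

  for ((key, value) <- sortedInput) {
    currentKey match {
      case Some(k) if k == key =>
        buffer += value                // same group: apply the update expressions only
      case Some(k) =>
        out += ((k, buffer))           // key changed: output the finished group
        currentKey = Some(key)
        buffer = 0L                    // re-initialize the buffer (initBufVarsCodes)
        buffer += value                // then apply the update expressions
      case None =>
        currentKey = Some(key)         // first row of the partition
        buffer += value
    }
  }
  currentKey.foreach(k => out += ((k, buffer)))  // emit the last group
  out
}

// sortAggregateSum(Iterator((1, 2L), (1, 3L), (2, 5L))) == Seq((1, 5L), (2, 5L))
```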
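
To try the new path end to end, the internal flag added to SQLConf above can be flipped at runtime. A small usage sketch follows (it assumes an existing SparkSession named `spark`; checking the plan via its string form is only illustrative):

```scala
import org.apache.spark.sql.functions._
import spark.implicits._

// Prefer sort aggregate over (object) hash aggregate.
spark.conf.set("spark.sql.aggregate.preferSortAggregate", "true")

val df = spark.range(1L << 20)
  .selectExpr("id % 1024 AS key", "id AS value")
  .groupBy($"key")
  .agg(sum($"value"), avg($"value"))

// With the flag on, the executed plan should show SortAggregate operators
// (and "sagg" variable prefixes in the generated code) instead of HashAggregate.
assert(df.queryExecution.executedPlan.toString.contains("SortAggregate"))
df.collect()

// Restore the default, hash-based behavior.
spark.conf.set("spark.sql.aggregate.preferSortAggregate", "false")
```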