diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala index 8aad0b7dee054..24f2dfba3f753 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/functions.scala @@ -546,6 +546,9 @@ case class HyperLogLogPlusPlus(child: Expression, relativeSD: Double = 0.05) AttributeReference(s"MS[$i]", LongType)() } + override val inputAggBufferAttributes: Seq[AttributeReference] = + aggBufferAttributes.map(_.newInstance()) + /** Fill all words with zeros. */ override def initialize(buffer: MutableRow): Unit = { var word = 0 diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala index 9ba3a9c980457..25b41c02dcb19 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/interfaces.scala @@ -231,9 +231,6 @@ abstract class ImperativeAggregate extends AggregateFunction2 { * Use `fieldNumber + inputAggBufferOffset` to access fields of `inputAggBuffer`. */ def merge(mutableAggBuffer: MutableRow, inputAggBuffer: InternalRow): Unit - - final lazy val inputAggBufferAttributes: Seq[AttributeReference] = - aggBufferAttributes.map(_.newInstance()) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index b078c8b6b05ca..bd4ffee7184a9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -195,19 +195,22 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { converted match { case None => Nil // Cannot convert to new aggregation code path. case Some(logical.Aggregate(groupingExpressions, resultExpressions, child)) => - // Extracts all distinct aggregate expressions from the resultExpressions. + // A single aggregate expression might appear multiple times in resultExpressions. + // In order to avoid evaluating an individual aggregate function multiple times, we'll + // build a set of the distinct aggregate expressions and build a function which can + // be used to re-write expressions so that they reference the single copy of the + // aggregate function which actually gets computed. val aggregateExpressions = resultExpressions.flatMap { expr => expr.collect { case agg: AggregateExpression2 => agg } - }.toSet.toSeq + }.distinct // For those distinct aggregate expressions, we create a map from the // aggregate function to the corresponding attribute of the function. - val aggregateFunctionMap = aggregateExpressions.map { agg => + val aggregateFunctionToAttribute = aggregateExpressions.map { agg => val aggregateFunction = agg.aggregateFunction - val attribtue = Alias(aggregateFunction, aggregateFunction.toString)().toAttribute - (aggregateFunction, agg.isDistinct) -> - (aggregateFunction -> attribtue) + val attribute = Alias(aggregateFunction, aggregateFunction.toString)().toAttribute + (aggregateFunction, agg.isDistinct) -> attribute }.toMap val (functionsWithDistinct, functionsWithoutDistinct) = @@ -220,6 +223,40 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { "code path.") } + val namedGroupingExpressions = groupingExpressions.map { + case ne: NamedExpression => ne -> ne + // If the expression is not a NamedExpressions, we add an alias. + // So, when we generate the result of the operator, the Aggregate Operator + // can directly get the Seq of attributes representing the grouping expressions. + case other => + val withAlias = Alias(other, other.toString)() + other -> withAlias + } + val groupExpressionMap = namedGroupingExpressions.toMap + + // The original `resultExpressions` are a set of expressions which may reference + // aggregate expressions, grouping column values, and constants. When aggregate operator + // emits output rows, we will use `resultExpressions` to generate an output projection + // which takes the grouping columns and final aggregate result buffer as input. + // Thus, we must re-write the result expressions so that their attributes match up with + // the attributes of the final result projection's input row: + val rewrittenResultExpressions = resultExpressions.map { expr => + expr.transformDown { + case AggregateExpression2(aggregateFunction, _, isDistinct) => + // The final aggregation buffer's attributes will be `finalAggregationAttributes`, + // so replace each aggregate expression by its corresponding attribute in the set: + aggregateFunctionToAttribute(aggregateFunction, isDistinct) + case expression => + // Since we're using `namedGroupingAttributes` to extract the grouping key + // columns, we need to replace grouping key expressions with their corresponding + // attributes. We do not rely on the equality check at here since attributes may + // differ cosmetically. Instead, we use semanticEquals. + groupExpressionMap.collectFirst { + case (expr, ne) if expr semanticEquals expression => ne.toAttribute + }.getOrElse(expression) + }.asInstanceOf[NamedExpression] + } + val aggregateOperator = if (aggregateExpressions.map(_.aggregateFunction).exists(!_.supportsPartial)) { if (functionsWithDistinct.nonEmpty) { @@ -227,26 +264,26 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { "aggregate functions which don't support partial aggregation.") } else { aggregate.Utils.planAggregateWithoutPartial( - groupingExpressions, + namedGroupingExpressions.map(_._2), aggregateExpressions, - aggregateFunctionMap, - resultExpressions, + aggregateFunctionToAttribute, + rewrittenResultExpressions, planLater(child)) } } else if (functionsWithDistinct.isEmpty) { aggregate.Utils.planAggregateWithoutDistinct( - groupingExpressions, + namedGroupingExpressions.map(_._2), aggregateExpressions, - aggregateFunctionMap, - resultExpressions, + aggregateFunctionToAttribute, + rewrittenResultExpressions, planLater(child)) } else { aggregate.Utils.planAggregateWithOneDistinct( - groupingExpressions, + namedGroupingExpressions.map(_._2), functionsWithDistinct, functionsWithoutDistinct, - aggregateFunctionMap, - resultExpressions, + aggregateFunctionToAttribute, + rewrittenResultExpressions, planLater(child)) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala index 5f7341e88c7c9..0c629d96af1d5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala @@ -321,7 +321,7 @@ abstract class AggregationIterator( // Initializing the function used to generate the output row. protected val generateOutput: (InternalRow, MutableRow) => InternalRow = { val rowToBeEvaluated = new JoinedRow - val safeOutputRow = new GenericMutableRow(resultExpressions.length) + val safeOutputRow = new SpecificMutableRow(resultExpressions.map(_.dataType)) val mutableOutput = if (outputsUnsafeRows) { UnsafeProjection.create(resultExpressions.map(_.dataType).toArray).apply(safeOutputRow) } else { @@ -359,7 +359,8 @@ abstract class AggregationIterator( val expressionAggEvalProjection = newMutableProjection(evalExpressions, bufferSchemata)() val aggregateResultSchema = nonCompleteAggregateAttributes ++ completeAggregateAttributes // TODO: Use unsafe row. - val aggregateResult = new GenericMutableRow(aggregateResultSchema.length) + val aggregateResult = new SpecificMutableRow(aggregateResultSchema.map(_.dataType)) + expressionAggEvalProjection.target(aggregateResult) val resultProjection = newMutableProjection( resultExpressions, groupingKeyAttributes ++ aggregateResultSchema)() @@ -367,7 +368,7 @@ abstract class AggregationIterator( (currentGroupingKey: InternalRow, currentBuffer: MutableRow) => { // Generate results for all expression-based aggregate functions. - expressionAggEvalProjection.target(aggregateResult)(currentBuffer) + expressionAggEvalProjection(currentBuffer) // Generate results for all imperative aggregate functions. var i = 0 while (i < allImperativeAggregateFunctions.length) { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala index 3cd22af30592c..0c70e4086f2dc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala @@ -31,7 +31,10 @@ case class TungstenAggregate( requiredChildDistributionExpressions: Option[Seq[Expression]], groupingExpressions: Seq[NamedExpression], nonCompleteAggregateExpressions: Seq[AggregateExpression2], + nonCompleteAggregateAttributes: Seq[Attribute], completeAggregateExpressions: Seq[AggregateExpression2], + completeAggregateAttributes: Seq[Attribute], + initialInputBufferOffset: Int, resultExpressions: Seq[NamedExpression], child: SparkPlan) extends UnaryNode { @@ -77,7 +80,10 @@ case class TungstenAggregate( new TungstenAggregationIterator( groupingExpressions, nonCompleteAggregateExpressions, + nonCompleteAggregateAttributes, completeAggregateExpressions, + completeAggregateAttributes, + initialInputBufferOffset, resultExpressions, newMutableProjection, child.output, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala index a6f4c1d92f6dc..ffcfea5e12723 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.execution.aggregate +import scala.collection.mutable.ArrayBuffer + import org.apache.spark.unsafe.KVIterator import org.apache.spark.{InternalAccumulator, Logging, SparkEnv, TaskContext} import org.apache.spark.sql.catalyst.expressions._ @@ -60,8 +62,12 @@ import org.apache.spark.sql.types.StructType * @param nonCompleteAggregateExpressions * [[AggregateExpression2]] containing [[AggregateFunction2]]s with mode [[Partial]], * [[PartialMerge]], or [[Final]]. + * @param nonCompleteAggregateAttributes the attributes of the nonCompleteAggregateExpressions' + * outputs when they are stored in the final aggregation buffer. * @param completeAggregateExpressions * [[AggregateExpression2]] containing [[AggregateFunction2]]s with mode [[Complete]]. + * @param completeAggregateAttributes the attributes of completeAggregateExpressions' outputs + * when they are stored in the final aggregation buffer. * @param resultExpressions * expressions for generating output rows. * @param newMutableProjection @@ -72,7 +78,10 @@ import org.apache.spark.sql.types.StructType class TungstenAggregationIterator( groupingExpressions: Seq[NamedExpression], nonCompleteAggregateExpressions: Seq[AggregateExpression2], + nonCompleteAggregateAttributes: Seq[Attribute], completeAggregateExpressions: Seq[AggregateExpression2], + completeAggregateAttributes: Seq[Attribute], + initialInputBufferOffset: Int, resultExpressions: Seq[NamedExpression], newMutableProjection: (Seq[Expression], Seq[Attribute]) => (() => MutableProjection), originalInputAttributes: Seq[Attribute], @@ -128,19 +137,73 @@ class TungstenAggregationIterator( completeAggregateExpressions.map(_.mode).distinct.headOption } - // All aggregate functions. TungstenAggregationIterator only handles expression-based aggregate. - // If there is any functions that is an ImperativeAggregateFunction, we throw an - // IllegalStateException. - private[this] val allAggregateFunctions: Array[DeclarativeAggregate] = { - if (!allAggregateExpressions.forall( - _.aggregateFunction.isInstanceOf[DeclarativeAggregate])) { - throw new IllegalStateException( - "Only ExpressionAggregateFunctions should be passed in TungstenAggregationIterator.") + // Initialize all AggregateFunctions by binding references, if necessary, + // and setting inputBufferOffset and mutableBufferOffset. + private def initializeAllAggregateFunctions( + startingInputBufferOffset: Int): Array[AggregateFunction2] = { + var mutableBufferOffset = 0 + var inputBufferOffset: Int = startingInputBufferOffset + val functions = new Array[AggregateFunction2](allAggregateExpressions.length) + var i = 0 + while (i < allAggregateExpressions.length) { + val func = allAggregateExpressions(i).aggregateFunction + val aggregateExpressionIsNonComplete = i < nonCompleteAggregateExpressions.length + // We need to use this mode instead of func.mode in order to handle aggregation mode switching + // when switching to sort-based aggregation: + val mode = + if (aggregateExpressionIsNonComplete) aggregationMode._1 else aggregationMode._2 + val funcWithBoundReferences = mode match { + case Some(Partial) | Some(Complete) if func.isInstanceOf[ImperativeAggregate] => + // We need to create BoundReferences if the function is not an + // expression-based aggregate function (it does not support code-gen) and the mode of + // this function is Partial or Complete because we will call eval of this + // function's children in the update method of this aggregate function. + // Those eval calls require BoundReferences to work. + BindReferences.bindReference(func, originalInputAttributes) + case _ => + // We only need to set inputBufferOffset for aggregate functions with mode + // PartialMerge and Final. + func match { + case function: ImperativeAggregate => + function.withNewInputAggBufferOffset(inputBufferOffset) + case _ => + } + inputBufferOffset += func.aggBufferSchema.length + func + } + // Set mutableBufferOffset for this function. It is important that setting + // mutableBufferOffset happens after all potential bindReference operations + // because bindReference will create a new instance of the function. + funcWithBoundReferences match { + case function: ImperativeAggregate => + function.withNewMutableAggBufferOffset(mutableBufferOffset) + case _ => + } + mutableBufferOffset += funcWithBoundReferences.aggBufferSchema.length + functions(i) = funcWithBoundReferences + i += 1 } + functions + } + + private[this] var allAggregateFunctions: Array[AggregateFunction2] = + initializeAllAggregateFunctions(initialInputBufferOffset) - allAggregateExpressions - .map(_.aggregateFunction.asInstanceOf[DeclarativeAggregate]) - .toArray + // Positions of those imperative aggregate functions in allAggregateFunctions. + // For example, we have func1, func2, func3, func4 in aggregateFunctions, and + // func2 and func3 are imperative aggregate functions. + // ImperativeAggregateFunctionPositions will be [1, 2]. + private[this] val allImperativeAggregateFunctionPositions: Array[Int] = { + val positions = new ArrayBuffer[Int]() + var i = 0 + while (i < allAggregateFunctions.length) { + allAggregateFunctions(i) match { + case agg: DeclarativeAggregate => + case _ => positions += i + } + i += 1 + } + positions.toArray } /////////////////////////////////////////////////////////////////////////// @@ -149,9 +212,14 @@ class TungstenAggregationIterator( // rows. /////////////////////////////////////////////////////////////////////////// - // The projection used to initialize buffer values. - private[this] val initialProjection: MutableProjection = { - val initExpressions = allAggregateFunctions.flatMap(_.initialValues) + // The projection used to initialize buffer values for all expression-based aggregates. + private[this] val expressionAggInitialProjection: MutableProjection = { + val initExpressions = allAggregateFunctions.flatMap { + case ae: DeclarativeAggregate => ae.initialValues + // For the positions corresponding to imperative aggregate functions, we'll use special + // no-op expressions which are ignored during projection code-generation. + case i: ImperativeAggregate => Seq.fill(i.aggBufferAttributes.length)(NoOp) + } newMutableProjection(initExpressions, Nil)() } @@ -164,10 +232,27 @@ class TungstenAggregationIterator( val bufferRowSize: Int = bufferSchema.length val genericMutableBuffer = new GenericMutableRow(bufferRowSize) - val unsafeProjection = - UnsafeProjection.create(bufferSchema.map(_.dataType)) - val buffer = unsafeProjection.apply(genericMutableBuffer) - initialProjection.target(buffer)(EmptyRow) + // TODO(josh): figure out whether we have to use + val useUnsafeBuffer = bufferSchema.map(_.dataType).forall(UnsafeRow.isMutable) + + val buffer = /* if (useUnsafeBuffer) */ { + val unsafeProjection = + UnsafeProjection.create(bufferSchema.map(_.dataType)) + unsafeProjection.apply(genericMutableBuffer) +// } else { +// genericMutableBuffer + } + expressionAggInitialProjection.target(buffer)(EmptyRow) + // TODO(josh): this can be done more cleanly + val allImperativeAggregateFunctions: Array[ImperativeAggregate] = + allImperativeAggregateFunctionPositions + .map(allAggregateFunctions) + .map(_.asInstanceOf[ImperativeAggregate]) + var i = 0 + while (i < allImperativeAggregateFunctions.length) { + allImperativeAggregateFunctions(i).initialize(buffer) + i += 1 + } buffer } @@ -181,72 +266,124 @@ class TungstenAggregationIterator( aggregationMode match { // Partial-only case (Some(Partial), None) => - val updateExpressions = allAggregateFunctions.flatMap(_.updateExpressions) - val updateProjection = + val updateExpressions = allAggregateFunctions.flatMap { + case ae: DeclarativeAggregate => ae.updateExpressions + case agg: AggregateFunction2 => Seq.fill(agg.aggBufferAttributes.length)(NoOp) + } + val imperativeAggregateFunctions: Array[ImperativeAggregate] = + allAggregateFunctions.collect { case func: ImperativeAggregate => func} + val expressionAggUpdateProjection = newMutableProjection(updateExpressions, aggregationBufferAttributes ++ inputAttributes)() (currentBuffer: UnsafeRow, row: InternalRow) => { - updateProjection.target(currentBuffer) - updateProjection(joinedRow(currentBuffer, row)) + expressionAggUpdateProjection.target(currentBuffer) + // Process all expression-based aggregate functions. + expressionAggUpdateProjection(joinedRow(currentBuffer, row)) + // Process all imperative aggregate functions + var i = 0 + while (i < imperativeAggregateFunctions.length) { + imperativeAggregateFunctions(i).update(currentBuffer, row) + i += 1 + } } // PartialMerge-only or Final-only case (Some(PartialMerge), None) | (Some(Final), None) => - val mergeExpressions = allAggregateFunctions.flatMap(_.mergeExpressions) - val mergeProjection = + val mergeExpressions = allAggregateFunctions.flatMap { + case ae: DeclarativeAggregate => ae.mergeExpressions + case agg: AggregateFunction2 => Seq.fill(agg.aggBufferAttributes.length)(NoOp) + } + val imperativeAggregateFunctions: Array[ImperativeAggregate] = + allAggregateFunctions.collect { case func: ImperativeAggregate => func} + // This projection is used to merge buffer values for all expression-based aggregates. + val expressionAggMergeProjection = newMutableProjection(mergeExpressions, aggregationBufferAttributes ++ inputAttributes)() (currentBuffer: UnsafeRow, row: InternalRow) => { - mergeProjection.target(currentBuffer) - mergeProjection(joinedRow(currentBuffer, row)) + // Process all expression-based aggregate functions. + expressionAggMergeProjection.target(currentBuffer)(joinedRow(currentBuffer, row)) + // Process all imperative aggregate functions. + var i = 0 + while (i < imperativeAggregateFunctions.length) { + imperativeAggregateFunctions(i).merge(currentBuffer, row) + i += 1 + } } // Final-Complete case (Some(Final), Some(Complete)) => - val nonCompleteAggregateFunctions: Array[DeclarativeAggregate] = - allAggregateFunctions.take(nonCompleteAggregateExpressions.length) - val completeAggregateFunctions: Array[DeclarativeAggregate] = + val completeAggregateFunctions: Array[AggregateFunction2] = allAggregateFunctions.takeRight(completeAggregateExpressions.length) + val completeImperativeAggregateFunctions: Array[ImperativeAggregate] = + completeAggregateFunctions.collect { case func: ImperativeAggregate => func } + val nonCompleteAggregateFunctions: Array[AggregateFunction2] = + allAggregateFunctions.take(nonCompleteAggregateExpressions.length) + val nonCompleteImperativeAggregateFunctions: Array[ImperativeAggregate] = + nonCompleteAggregateFunctions.collect { case func: ImperativeAggregate => func } val completeOffsetExpressions = Seq.fill(completeAggregateFunctions.map(_.aggBufferAttributes.length).sum)(NoOp) val mergeExpressions = - nonCompleteAggregateFunctions.flatMap(_.mergeExpressions) ++ completeOffsetExpressions + nonCompleteAggregateFunctions.flatMap { + case ae: DeclarativeAggregate => ae.mergeExpressions + case agg: AggregateFunction2 => Seq.fill(agg.aggBufferAttributes.length)(NoOp) + } ++ completeOffsetExpressions val finalMergeProjection = newMutableProjection(mergeExpressions, aggregationBufferAttributes ++ inputAttributes)() // We do not touch buffer values of aggregate functions with the Final mode. val finalOffsetExpressions = Seq.fill(nonCompleteAggregateFunctions.map(_.aggBufferAttributes.length).sum)(NoOp) - val updateExpressions = - finalOffsetExpressions ++ completeAggregateFunctions.flatMap(_.updateExpressions) + val updateExpressions = finalOffsetExpressions ++ completeAggregateFunctions.flatMap { + case ae: DeclarativeAggregate => ae.updateExpressions + case agg: AggregateFunction2 => Seq.fill(agg.aggBufferAttributes.length)(NoOp) + } val completeUpdateProjection = newMutableProjection(updateExpressions, aggregationBufferAttributes ++ inputAttributes)() (currentBuffer: UnsafeRow, row: InternalRow) => { val input = joinedRow(currentBuffer, row) - // For all aggregate functions with mode Complete, update the given currentBuffer. + // For all aggregate functions with mode Complete, update buffers. completeUpdateProjection.target(currentBuffer)(input) + var i = 0 + while (i < completeImperativeAggregateFunctions.length) { + completeImperativeAggregateFunctions(i).update(currentBuffer, row) + i += 1 + } // For all aggregate functions with mode Final, merge buffer values in row to // currentBuffer. finalMergeProjection.target(currentBuffer)(input) + i = 0 + while (i < nonCompleteImperativeAggregateFunctions.length) { + nonCompleteImperativeAggregateFunctions(i).merge(currentBuffer, row) + i += 1 + } } // Complete-only case (None, Some(Complete)) => - val completeAggregateFunctions: Array[DeclarativeAggregate] = + val completeAggregateFunctions: Array[AggregateFunction2] = allAggregateFunctions.takeRight(completeAggregateExpressions.length) + // All imperative aggregate functions with mode Complete. + val completeImperativeAggregateFunctions: Array[ImperativeAggregate] = + completeAggregateFunctions.collect { case func: ImperativeAggregate => func } - val updateExpressions = - completeAggregateFunctions.flatMap(_.updateExpressions) - val completeUpdateProjection = + val updateExpressions = completeAggregateFunctions.flatMap { + case ae: DeclarativeAggregate => ae.updateExpressions + case agg: AggregateFunction2 => Seq.fill(agg.aggBufferAttributes.length)(NoOp) + } + val completeExpressionAggUpdateProjection = newMutableProjection(updateExpressions, aggregationBufferAttributes ++ inputAttributes)() (currentBuffer: UnsafeRow, row: InternalRow) => { - completeUpdateProjection.target(currentBuffer) - // For all aggregate functions with mode Complete, update the given currentBuffer. - completeUpdateProjection(joinedRow(currentBuffer, row)) + // For all aggregate functions with mode Complete, update buffers. + completeExpressionAggUpdateProjection.target(currentBuffer)(joinedRow(currentBuffer, row)) + var i = 0 + while (i < completeImperativeAggregateFunctions.length) { + completeImperativeAggregateFunctions(i).update(currentBuffer, row) + i += 1 + } } // Grouping only. @@ -280,17 +417,41 @@ class TungstenAggregationIterator( // resultExpressions. case (Some(Final), None) | (Some(Final) | None, Some(Complete)) => val joinedRow = new JoinedRow() + val evalExpressions = allAggregateFunctions.map { + case ae: DeclarativeAggregate => ae.evaluateExpression + case agg: AggregateFunction2 => NoOp + } + val expressionAggEvalProjection = newMutableProjection(evalExpressions, bufferAttributes)() + // These are the attributes of the row produced by `expressionAggEvalProjection` + val aggregateResultSchema = nonCompleteAggregateAttributes ++ completeAggregateAttributes + // TODO: Use unsafe row. + val aggregateResult = new SpecificMutableRow(aggregateResultSchema.map(_.dataType)) + expressionAggEvalProjection.target(aggregateResult) val resultProjection = - UnsafeProjection.create(resultExpressions, groupingAttributes ++ bufferAttributes) + UnsafeProjection.create(resultExpressions, groupingAttributes ++ aggregateResultSchema) + + val allImperativeAggregateFunctions: Array[ImperativeAggregate] = + allImperativeAggregateFunctionPositions + .map(allAggregateFunctions) + .map(_.asInstanceOf[ImperativeAggregate]) (currentGroupingKey: UnsafeRow, currentBuffer: UnsafeRow) => { - resultProjection(joinedRow(currentGroupingKey, currentBuffer)) + // Generate results for all expression-based aggregate functions. + expressionAggEvalProjection(currentBuffer) + // Generate results for all imperative aggregate functions. + var i = 0 + while (i < allImperativeAggregateFunctions.length) { + aggregateResult.update( + allImperativeAggregateFunctionPositions(i), + allImperativeAggregateFunctions(i).eval(currentBuffer)) + i += 1 + } + resultProjection(joinedRow(currentGroupingKey, aggregateResult)) } // Grouping-only: a output row is generated from values of grouping expressions. case (None, None) => - val resultProjection = - UnsafeProjection.create(resultExpressions, groupingAttributes) + val resultProjection = UnsafeProjection.create(resultExpressions, groupingAttributes) (currentGroupingKey: UnsafeRow, currentBuffer: UnsafeRow) => { resultProjection(currentGroupingKey) @@ -467,8 +628,8 @@ class TungstenAggregationIterator( // When needsProcess is false, the format of input rows is groupingKey + aggregation buffer. // We need to project the aggregation buffer part from an input row. val buffer = createNewAggregationBuffer() - // The originalInputAttributes are using cloneBufferAttributes. So, we need to use - // allAggregateFunctions.flatMap(_.cloneBufferAttributes). + // The originalInputAttributes are using inputAggBufferAttributes. So, we need to use + // allAggregateFunctions.flatMap(_.inputAggBufferAttributes). val bufferExtractor = newMutableProjection( allAggregateFunctions.flatMap(_.inputAggBufferAttributes), originalInputAttributes)() @@ -497,8 +658,10 @@ class TungstenAggregationIterator( } aggregationMode = newAggregationMode + allAggregateFunctions = initializeAllAggregateFunctions(startingInputBufferOffset = 0) + // Basically the value of the KVIterator returned by externalSorter - // will just aggregation buffer. At here, we use cloneBufferAttributes. + // will just aggregation buffer. At here, we use inputAggBufferAttributes. val newInputAttributes: Seq[Attribute] = allAggregateFunctions.flatMap(_.inputAggBufferAttributes) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala index fd02be1225f27..622a7f79d926a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala @@ -341,6 +341,9 @@ private[sql] case class ScalaUDAF( override val aggBufferAttributes: Seq[AttributeReference] = aggBufferSchema.toAttributes + override val inputAggBufferAttributes: Seq[AttributeReference] = + aggBufferAttributes.map(_.newInstance()) + private[this] lazy val childrenSchema: StructType = { val inputFields = children.zipWithIndex.map { case (child, index) => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala index e1c2d9475a10f..ebee37f35a8df 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/utils.scala @@ -17,8 +17,6 @@ package org.apache.spark.sql.execution.aggregate -import scala.collection.mutable - import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.aggregate._ import org.apache.spark.sql.execution.{UnsafeFixedWidthAggregationMap, SparkPlan} @@ -38,102 +36,70 @@ object Utils { } def planAggregateWithoutPartial( - groupingExpressions: Seq[Expression], + groupingExpressions: Seq[NamedExpression], aggregateExpressions: Seq[AggregateExpression2], - aggregateFunctionMap: Map[(AggregateFunction2, Boolean), (AggregateFunction2, Attribute)], + aggregateFunctionToAttribute: Map[(AggregateFunction2, Boolean), Attribute], resultExpressions: Seq[NamedExpression], child: SparkPlan): Seq[SparkPlan] = { - val namedGroupingExpressions = groupingExpressions.map { - case ne: NamedExpression => ne -> ne - // If the expression is not a NamedExpressions, we add an alias. - // So, when we generate the result of the operator, the Aggregate Operator - // can directly get the Seq of attributes representing the grouping expressions. - case other => - val withAlias = Alias(other, other.toString)() - other -> withAlias - } - val groupExpressionMap = namedGroupingExpressions.toMap - val namedGroupingAttributes = namedGroupingExpressions.map(_._2.toAttribute) - + val groupingAttributes = groupingExpressions.map(_.toAttribute) val completeAggregateExpressions = aggregateExpressions.map(_.copy(mode = Complete)) - val completeAggregateAttributes = - completeAggregateExpressions.map { - expr => aggregateFunctionMap(expr.aggregateFunction, expr.isDistinct)._2 - } - - val rewrittenResultExpressions = resultExpressions.map { expr => - expr.transformDown { - case agg: AggregateExpression2 => - aggregateFunctionMap(agg.aggregateFunction, agg.isDistinct)._2 - case expression => - // We do not rely on the equality check at here since attributes may - // different cosmetically. Instead, we use semanticEquals. - groupExpressionMap.collectFirst { - case (expr, ne) if expr semanticEquals expression => ne.toAttribute - }.getOrElse(expression) - }.asInstanceOf[NamedExpression] + val completeAggregateAttributes = completeAggregateExpressions.map { + expr => aggregateFunctionToAttribute(expr.aggregateFunction, expr.isDistinct) } SortBasedAggregate( - requiredChildDistributionExpressions = Some(namedGroupingAttributes), - groupingExpressions = namedGroupingExpressions.map(_._2), + requiredChildDistributionExpressions = Some(groupingAttributes), + groupingExpressions = groupingAttributes, nonCompleteAggregateExpressions = Nil, nonCompleteAggregateAttributes = Nil, completeAggregateExpressions = completeAggregateExpressions, completeAggregateAttributes = completeAggregateAttributes, initialInputBufferOffset = 0, - resultExpressions = rewrittenResultExpressions, + resultExpressions = resultExpressions, child = child ) :: Nil } def planAggregateWithoutDistinct( - groupingExpressions: Seq[Expression], + groupingExpressions: Seq[NamedExpression], aggregateExpressions: Seq[AggregateExpression2], - aggregateFunctionMap: Map[(AggregateFunction2, Boolean), (AggregateFunction2, Attribute)], + aggregateFunctionToAttribute: Map[(AggregateFunction2, Boolean), Attribute], resultExpressions: Seq[NamedExpression], child: SparkPlan): Seq[SparkPlan] = { // Check if we can use TungstenAggregate. val usesTungstenAggregate = child.sqlContext.conf.unsafeEnabled && - aggregateExpressions.forall(_.aggregateFunction.isInstanceOf[DeclarativeAggregate]) && supportsTungstenAggregate( groupingExpressions, aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes)) // 1. Create an Aggregate Operator for partial aggregations. - val namedGroupingExpressions = groupingExpressions.map { - case ne: NamedExpression => ne -> ne - // If the expression is not a NamedExpressions, we add an alias. - // So, when we generate the result of the operator, the Aggregate Operator - // can directly get the Seq of attributes representing the grouping expressions. - case other => - val withAlias = Alias(other, other.toString)() - other -> withAlias - } - val groupExpressionMap = namedGroupingExpressions.toMap - val namedGroupingAttributes = namedGroupingExpressions.map(_._2.toAttribute) + + val groupingAttributes = groupingExpressions.map(_.toAttribute) val partialAggregateExpressions = aggregateExpressions.map(_.copy(mode = Partial)) val partialAggregateAttributes = partialAggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes) val partialResultExpressions = - namedGroupingAttributes ++ + groupingAttributes ++ partialAggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes) val partialAggregate = if (usesTungstenAggregate) { TungstenAggregate( requiredChildDistributionExpressions = None: Option[Seq[Expression]], - groupingExpressions = namedGroupingExpressions.map(_._2), + groupingExpressions = groupingExpressions, nonCompleteAggregateExpressions = partialAggregateExpressions, + nonCompleteAggregateAttributes = partialAggregateAttributes, completeAggregateExpressions = Nil, + completeAggregateAttributes = Nil, + initialInputBufferOffset = 0, resultExpressions = partialResultExpressions, child = child) } else { SortBasedAggregate( requiredChildDistributionExpressions = None: Option[Seq[Expression]], - groupingExpressions = namedGroupingExpressions.map(_._2), + groupingExpressions = groupingExpressions, nonCompleteAggregateExpressions = partialAggregateExpressions, nonCompleteAggregateAttributes = partialAggregateAttributes, completeAggregateExpressions = Nil, @@ -145,58 +111,33 @@ object Utils { // 2. Create an Aggregate Operator for final aggregations. val finalAggregateExpressions = aggregateExpressions.map(_.copy(mode = Final)) - val finalAggregateAttributes = - finalAggregateExpressions.map { - expr => aggregateFunctionMap(expr.aggregateFunction, expr.isDistinct)._2 - } + // The attributes of the final aggregation buffer, which is presented as input to the result + // projection: + val finalAggregateAttributes = finalAggregateExpressions.map { + expr => aggregateFunctionToAttribute(expr.aggregateFunction, expr.isDistinct) + } val finalAggregate = if (usesTungstenAggregate) { - val rewrittenResultExpressions = resultExpressions.map { expr => - expr.transformDown { - case agg: AggregateExpression2 => - // aggregateFunctionMap contains unique aggregate functions. - val aggregateFunction = - aggregateFunctionMap(agg.aggregateFunction, agg.isDistinct)._1 - aggregateFunction.asInstanceOf[DeclarativeAggregate].evaluateExpression - case expression => - // We do not rely on the equality check at here since attributes may - // different cosmetically. Instead, we use semanticEquals. - groupExpressionMap.collectFirst { - case (expr, ne) if expr semanticEquals expression => ne.toAttribute - }.getOrElse(expression) - }.asInstanceOf[NamedExpression] - } - TungstenAggregate( - requiredChildDistributionExpressions = Some(namedGroupingAttributes), - groupingExpressions = namedGroupingAttributes, + requiredChildDistributionExpressions = Some(groupingAttributes), + groupingExpressions = groupingAttributes, nonCompleteAggregateExpressions = finalAggregateExpressions, + nonCompleteAggregateAttributes = finalAggregateAttributes, completeAggregateExpressions = Nil, - resultExpressions = rewrittenResultExpressions, + completeAggregateAttributes = Nil, + initialInputBufferOffset = groupingExpressions.length, + resultExpressions = resultExpressions, child = partialAggregate) } else { - val rewrittenResultExpressions = resultExpressions.map { expr => - expr.transformDown { - case agg: AggregateExpression2 => - aggregateFunctionMap(agg.aggregateFunction, agg.isDistinct)._2 - case expression => - // We do not rely on the equality check at here since attributes may - // different cosmetically. Instead, we use semanticEquals. - groupExpressionMap.collectFirst { - case (expr, ne) if expr semanticEquals expression => ne.toAttribute - }.getOrElse(expression) - }.asInstanceOf[NamedExpression] - } - SortBasedAggregate( - requiredChildDistributionExpressions = Some(namedGroupingAttributes), - groupingExpressions = namedGroupingAttributes, + requiredChildDistributionExpressions = Some(groupingAttributes), + groupingExpressions = groupingAttributes, nonCompleteAggregateExpressions = finalAggregateExpressions, nonCompleteAggregateAttributes = finalAggregateAttributes, completeAggregateExpressions = Nil, completeAggregateAttributes = Nil, - initialInputBufferOffset = namedGroupingAttributes.length, - resultExpressions = rewrittenResultExpressions, + initialInputBufferOffset = groupingExpressions.length, + resultExpressions = resultExpressions, child = partialAggregate) } @@ -204,10 +145,13 @@ object Utils { } def planAggregateWithOneDistinct( - groupingExpressions: Seq[Expression], + // The grouping expressions are original groupingExpressions and + // distinct columns. For example, for avg(distinct value) ... group by key + // the grouping expressions of this Aggregate Operator will be [key, value]. + groupingExpressions: Seq[NamedExpression], functionsWithDistinct: Seq[AggregateExpression2], functionsWithoutDistinct: Seq[AggregateExpression2], - aggregateFunctionMap: Map[(AggregateFunction2, Boolean), (AggregateFunction2, Attribute)], + aggregateFunctionToAttribute: Map[(AggregateFunction2, Boolean), Attribute], resultExpressions: Seq[NamedExpression], child: SparkPlan): Seq[SparkPlan] = { @@ -221,20 +165,8 @@ object Utils { aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes)) // 1. Create an Aggregate Operator for partial aggregations. - // The grouping expressions are original groupingExpressions and - // distinct columns. For example, for avg(distinct value) ... group by key - // the grouping expressions of this Aggregate Operator will be [key, value]. - val namedGroupingExpressions = groupingExpressions.map { - case ne: NamedExpression => ne -> ne - // If the expression is not a NamedExpressions, we add an alias. - // So, when we generate the result of the operator, the Aggregate Operator - // can directly get the Seq of attributes representing the grouping expressions. - case other => - val withAlias = Alias(other, other.toString)() - other -> withAlias - } - val groupExpressionMap = namedGroupingExpressions.toMap - val namedGroupingAttributes = namedGroupingExpressions.map(_._2.toAttribute) + + val groupingAttributes = groupingExpressions.map(_.toAttribute) // It is safe to call head at here since functionsWithDistinct has at least one // AggregateExpression2. @@ -253,22 +185,25 @@ object Utils { val partialAggregateAttributes = partialAggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes) val partialAggregateGroupingExpressions = - (namedGroupingExpressions ++ namedDistinctColumnExpressions).map(_._2) + groupingExpressions ++ namedDistinctColumnExpressions.map(_._2) val partialAggregateResult = - namedGroupingAttributes ++ + groupingAttributes ++ distinctColumnAttributes ++ partialAggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes) val partialAggregate = if (usesTungstenAggregate) { TungstenAggregate( - requiredChildDistributionExpressions = None: Option[Seq[Expression]], + requiredChildDistributionExpressions = None, groupingExpressions = partialAggregateGroupingExpressions, nonCompleteAggregateExpressions = partialAggregateExpressions, + nonCompleteAggregateAttributes = partialAggregateAttributes, completeAggregateExpressions = Nil, + completeAggregateAttributes = Nil, + initialInputBufferOffset = 0, resultExpressions = partialAggregateResult, child = child) } else { SortBasedAggregate( - requiredChildDistributionExpressions = None: Option[Seq[Expression]], + requiredChildDistributionExpressions = None, groupingExpressions = partialAggregateGroupingExpressions, nonCompleteAggregateExpressions = partialAggregateExpressions, nonCompleteAggregateAttributes = partialAggregateAttributes, @@ -284,41 +219,41 @@ object Utils { val partialMergeAggregateAttributes = partialMergeAggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes) val partialMergeAggregateResult = - namedGroupingAttributes ++ + groupingAttributes ++ distinctColumnAttributes ++ partialMergeAggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes) val partialMergeAggregate = if (usesTungstenAggregate) { TungstenAggregate( - requiredChildDistributionExpressions = Some(namedGroupingAttributes), - groupingExpressions = namedGroupingAttributes ++ distinctColumnAttributes, + requiredChildDistributionExpressions = Some(groupingAttributes), + groupingExpressions = groupingAttributes ++ distinctColumnAttributes, nonCompleteAggregateExpressions = partialMergeAggregateExpressions, + nonCompleteAggregateAttributes = partialMergeAggregateAttributes, completeAggregateExpressions = Nil, + completeAggregateAttributes = Nil, + initialInputBufferOffset = (groupingAttributes ++ distinctColumnAttributes).length, resultExpressions = partialMergeAggregateResult, child = partialAggregate) } else { SortBasedAggregate( - requiredChildDistributionExpressions = Some(namedGroupingAttributes), - groupingExpressions = namedGroupingAttributes ++ distinctColumnAttributes, + requiredChildDistributionExpressions = Some(groupingAttributes), + groupingExpressions = groupingAttributes ++ distinctColumnAttributes, nonCompleteAggregateExpressions = partialMergeAggregateExpressions, nonCompleteAggregateAttributes = partialMergeAggregateAttributes, completeAggregateExpressions = Nil, completeAggregateAttributes = Nil, - initialInputBufferOffset = (namedGroupingAttributes ++ distinctColumnAttributes).length, + initialInputBufferOffset = (groupingAttributes ++ distinctColumnAttributes).length, resultExpressions = partialMergeAggregateResult, child = partialAggregate) } // 3. Create an Aggregate Operator for partial merge aggregations. val finalAggregateExpressions = functionsWithoutDistinct.map(_.copy(mode = Final)) - val finalAggregateAttributes = - finalAggregateExpressions.map { - expr => aggregateFunctionMap(expr.aggregateFunction, expr.isDistinct)._2 - } - // Create a map to store those rewritten aggregate functions. We always need to use - // both function and its corresponding isDistinct flag as the key because function itself - // does not knows if it is has distinct keyword or now. - val rewrittenAggregateFunctions = - mutable.Map.empty[(AggregateFunction2, Boolean), AggregateFunction2] + // The attributes of the final aggregation buffer, which is presented as input to the result + // projection: + val finalAggregateAttributes = finalAggregateExpressions.map { + expr => aggregateFunctionToAttribute(expr.aggregateFunction, expr.isDistinct) + } + val (completeAggregateExpressions, completeAggregateAttributes) = functionsWithDistinct.map { // Children of an AggregateFunction with DISTINCT keyword has already // been evaluated. At here, we need to replace original children @@ -328,9 +263,6 @@ object Utils { case expr if distinctColumnExpressionMap.contains(expr) => distinctColumnExpressionMap(expr).toAttribute }.asInstanceOf[AggregateFunction2] - // Because we have rewritten the aggregate function, we use rewrittenAggregateFunctions - // to track the old version and the new version of this function. - rewrittenAggregateFunctions += (aggregateFunction, true) -> rewrittenAggregateFunction // We rewrite the aggregate function to a non-distinct aggregation because // its input will have distinct arguments. // We just keep the isDistinct setting to true, so when users look at the query plan, @@ -338,66 +270,31 @@ object Utils { val rewrittenAggregateExpression = AggregateExpression2(rewrittenAggregateFunction, Complete, true) - val aggregateFunctionAttribute = - aggregateFunctionMap(agg.aggregateFunction, true)._2 - (rewrittenAggregateExpression -> aggregateFunctionAttribute) + val aggregateFunctionAttribute = aggregateFunctionToAttribute(agg.aggregateFunction, true) + (rewrittenAggregateExpression, aggregateFunctionAttribute) }.unzip val finalAndCompleteAggregate = if (usesTungstenAggregate) { - val rewrittenResultExpressions = resultExpressions.map { expr => - expr.transform { - case agg: AggregateExpression2 => - val function = agg.aggregateFunction - val isDistinct = agg.isDistinct - val aggregateFunction = - if (rewrittenAggregateFunctions.contains(function, isDistinct)) { - // If this function has been rewritten, we get the rewritten version from - // rewrittenAggregateFunctions. - rewrittenAggregateFunctions(function, isDistinct) - } else { - // Oterwise, we get it from aggregateFunctionMap, which contains unique - // aggregate functions that have not been rewritten. - aggregateFunctionMap(function, isDistinct)._1 - } - aggregateFunction.asInstanceOf[DeclarativeAggregate].evaluateExpression - case expression => - // We do not rely on the equality check at here since attributes may - // different cosmetically. Instead, we use semanticEquals. - groupExpressionMap.collectFirst { - case (expr, ne) if expr semanticEquals expression => ne.toAttribute - }.getOrElse(expression) - }.asInstanceOf[NamedExpression] - } - TungstenAggregate( - requiredChildDistributionExpressions = Some(namedGroupingAttributes), - groupingExpressions = namedGroupingAttributes, + requiredChildDistributionExpressions = Some(groupingAttributes), + groupingExpressions = groupingAttributes, nonCompleteAggregateExpressions = finalAggregateExpressions, + nonCompleteAggregateAttributes = finalAggregateAttributes, completeAggregateExpressions = completeAggregateExpressions, - resultExpressions = rewrittenResultExpressions, + completeAggregateAttributes = completeAggregateAttributes, + initialInputBufferOffset = (groupingAttributes ++ distinctColumnAttributes).length, + resultExpressions = resultExpressions, child = partialMergeAggregate) } else { - val rewrittenResultExpressions = resultExpressions.map { expr => - expr.transform { - case agg: AggregateExpression2 => - aggregateFunctionMap(agg.aggregateFunction, agg.isDistinct)._2 - case expression => - // We do not rely on the equality check at here since attributes may - // different cosmetically. Instead, we use semanticEquals. - groupExpressionMap.collectFirst { - case (expr, ne) if expr semanticEquals expression => ne.toAttribute - }.getOrElse(expression) - }.asInstanceOf[NamedExpression] - } SortBasedAggregate( - requiredChildDistributionExpressions = Some(namedGroupingAttributes), - groupingExpressions = namedGroupingAttributes, + requiredChildDistributionExpressions = Some(groupingAttributes), + groupingExpressions = groupingAttributes, nonCompleteAggregateExpressions = finalAggregateExpressions, nonCompleteAggregateAttributes = finalAggregateAttributes, completeAggregateExpressions = completeAggregateExpressions, completeAggregateAttributes = completeAggregateAttributes, - initialInputBufferOffset = (namedGroupingAttributes ++ distinctColumnAttributes).length, - resultExpressions = rewrittenResultExpressions, + initialInputBufferOffset = (groupingAttributes ++ distinctColumnAttributes).length, + resultExpressions = resultExpressions, child = partialMergeAggregate) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIteratorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIteratorSuite.scala index 7ca677a6c72ad..0cc4988ff681c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIteratorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIteratorSuite.scala @@ -38,8 +38,8 @@ class TungstenAggregationIteratorSuite extends SparkFunSuite with SharedSQLConte () => new InterpretedMutableProjection(expr, schema) } val dummyAccum = SQLMetrics.createLongMetric(sparkContext, "dummy") - iter = new TungstenAggregationIterator(Seq.empty, Seq.empty, Seq.empty, - Seq.empty, newMutableProjection, Seq.empty, None, dummyAccum, dummyAccum) + iter = new TungstenAggregationIterator(Seq.empty, Seq.empty, Seq.empty, Seq.empty, Seq.empty, + 0, Seq.empty, newMutableProjection, Seq.empty, None, dummyAccum, dummyAccum) val numPages = iter.getHashMap.getNumDataPages assert(numPages === 1) } finally { diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala index 18bbdb9908142..9362ae05dca7b 100644 --- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala +++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUDFs.scala @@ -614,7 +614,9 @@ private[hive] case class HiveUDAFFunction( buffer = function.getNewAggregationBuffer } - override def aggBufferAttributes: Seq[AttributeReference] = Nil + override val aggBufferAttributes: Seq[AttributeReference] = Nil + + override val inputAggBufferAttributes: Seq[AttributeReference] = Nil // We rely on Hive to check the input data types, so use `AnyDataType` here to bypass our // catalyst type checking framework.