diff --git a/core/src/main/scala/org/apache/spark/Aggregator.scala b/core/src/main/scala/org/apache/spark/Aggregator.scala
index 59fdf659c9e1..3d8ce3f1fc48 100644
--- a/core/src/main/scala/org/apache/spark/Aggregator.scala
+++ b/core/src/main/scala/org/apache/spark/Aggregator.scala
@@ -19,6 +19,7 @@ package org.apache.spark
 
 import org.apache.spark.annotation.DeveloperApi
 import org.apache.spark.util.collection.{AppendOnlyMap, ExternalAppendOnlyMap}
+import org.apache.spark.serializer.Serializer
 
 /**
  * :: DeveloperApi ::
@@ -27,12 +28,14 @@ import org.apache.spark.util.collection.{AppendOnlyMap, ExternalAppendOnlyMap}
  * @param createCombiner function to create the initial value of the aggregation.
  * @param mergeValue function to merge a new value into the aggregation result.
  * @param mergeCombiners function to merge outputs from multiple mergeValue function.
+ * @param serializer serializer used when spilling the in-memory map to disk.
  */
 @DeveloperApi
 case class Aggregator[K, V, C] (
     createCombiner: V => C,
     mergeValue: (C, V) => C,
-    mergeCombiners: (C, C) => C) {
+    mergeCombiners: (C, C) => C,
+    serializer: Serializer = SparkEnv.get.serializer) {
 
   private val externalSorting = SparkEnv.get.conf.getBoolean("spark.shuffle.spill", true)
 
@@ -54,7 +57,8 @@ case class Aggregator[K, V, C] (
       }
       combiners.iterator
     } else {
-      val combiners = new ExternalAppendOnlyMap[K, V, C](createCombiner, mergeValue, mergeCombiners)
+      val combiners = new ExternalAppendOnlyMap[K, V, C](
+        createCombiner, mergeValue, mergeCombiners, serializer)
       while (iter.hasNext) {
         val (k, v) = iter.next()
         combiners.insert(k, v)
@@ -83,7 +87,8 @@ case class Aggregator[K, V, C] (
       }
       combiners.iterator
     } else {
-      val combiners = new ExternalAppendOnlyMap[K, C, C](identity, mergeCombiners, mergeCombiners)
+      val combiners = new ExternalAppendOnlyMap[K, C, C](
+        identity, mergeCombiners, mergeCombiners, serializer)
       while (iter.hasNext) {
         val (k, c) = iter.next()
         combiners.insert(k, c)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala
index b152f95f96c7..90199a47f1dd 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala
@@ -80,6 +80,7 @@ abstract class AggregateFunction
   override def dataType = base.dataType
 
   def update(input: Row): Unit
+  def merge(other: AggregateFunction): Unit
   override def eval(input: Row): Any
 
   // Do we really need this?
@@ -189,6 +190,16 @@ case class AverageFunction(expr: Expression, base: AggregateExpression)
     count += 1
     sum.update(addFunction, input)
   }
+
+  override def merge(other: AggregateFunction): Unit = {
+    other match {
+      case avg: AverageFunction => {
+        count += avg.count
+        sum.update(Add(sum, avg.sum), EmptyRow)
+      }
+      case _ => throw new TreeNodeException(this, s"Types do not match ${this.dataType} != ${other.dataType}")
+    }
+  }
 }
 
 case class CountFunction(expr: Expression, base: AggregateExpression) extends AggregateFunction {
@@ -203,6 +214,15 @@ case class CountFunction(expr: Expression, base: AggregateExpression) extends Ag
     }
   }
 
+  override def merge(other: AggregateFunction): Unit = {
+    other match {
+      case c: CountFunction => {
+        count += c.count
+      }
+      case _ => throw new TreeNodeException(this, s"Types do not match ${this.dataType} != ${other.dataType}")
+    }
+  }
+
   override def eval(input: Row): Any = count
 }
 
@@ -217,6 +237,15 @@ case class SumFunction(expr: Expression, base: AggregateExpression) extends Aggr
     sum.update(addFunction, input)
   }
 
+  override def merge(other: AggregateFunction): Unit = {
+    other match {
+      case s: SumFunction => {
+        sum.update(Add(sum, s.sum), EmptyRow)
+      }
+      case _ => throw new TreeNodeException(this, s"Types do not match ${this.dataType} != ${other.dataType}")
+    }
+  }
+
   override def eval(input: Row): Any = sum.eval(null)
 }
 
@@ -234,6 +263,19 @@ case class SumDistinctFunction(expr: Expression, base: AggregateExpression)
     }
   }
 
+  override def merge(other: AggregateFunction): Unit = {
+    other match {
+      case sd: SumDistinctFunction => {
+        // TODO(lamuguo): Switch to HashSet union once the Scala rebase brings in
+        // support for it. Related change: https://github.com/scala/scala/pull/3322
+        for (item <- sd.seen) {
+          seen += item
+        }
+      }
+      case _ => throw new TreeNodeException(this, s"Types do not match ${this.dataType} != ${other.dataType}")
+    }
+  }
+
   override def eval(input: Row): Any =
     seen.reduceLeft(base.dataType.asInstanceOf[NumericType].numeric.asInstanceOf[Numeric[Any]].plus)
 }
@@ -252,6 +294,17 @@ case class CountDistinctFunction(expr: Seq[Expression], base: AggregateExpressio
     }
   }
 
+  override def merge(other: AggregateFunction): Unit = {
+    other match {
+      case cd: CountDistinctFunction => {
+        for (item <- cd.seen) {
+          seen += item
+        }
+      }
+      case _ => throw new TreeNodeException(this, s"Types do not match ${this.dataType} != ${other.dataType}")
+    }
+  }
+
   override def eval(input: Row): Any = seen.size
 }
 
@@ -266,5 +319,16 @@ case class FirstFunction(expr: Expression, base: AggregateExpression) extends Ag
     }
   }
 
+  override def merge(other: AggregateFunction): Unit = {
+    other match {
+      case second: FirstFunction => {
+        if (result == null) {
+          result = second.result
+        }
+      }
+      case _ => throw new TreeNodeException(this, s"Types do not match ${this.dataType} != ${other.dataType}")
+    }
+  }
+
   override def eval(input: Row): Any = result
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala
index 36b3b956da96..6b5bd7b23da2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala
@@ -17,13 +17,13 @@
 
 package org.apache.spark.sql.execution
 
-import java.util.HashMap
-
 import org.apache.spark.annotation.DeveloperApi
-import org.apache.spark.SparkContext
+import org.apache.spark.{Logging, SparkConf, Aggregator, SparkContext}
 import org.apache.spark.sql.catalyst.errors._
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.plans.physical._
+import org.apache.spark.sql.execution.SparkSqlSerializer
+import scala.collection.mutable.ArrayBuffer
 
 /**
  * :: DeveloperApi ::
@@ -42,7 +42,7 @@ case class Aggregate(
     groupingExpressions: Seq[Expression],
     aggregateExpressions: Seq[NamedExpression],
     child: SparkPlan)(@transient sc: SparkContext)
-  extends UnaryNode with NoBind {
+  extends UnaryNode with NoBind with Logging {
 
   override def requiredChildDistribution =
     if (partial) {
@@ -155,48 +155,64 @@ case class Aggregate(
       }
     } else {
       child.execute().mapPartitions { iter =>
-        val hashTable = new HashMap[Row, Array[AggregateFunction]]
         val groupingProjection = new MutableProjection(groupingExpressions, childOutput)
-
-        var currentRow: Row = null
-        while (iter.hasNext) {
-          currentRow = iter.next()
-          val currentGroup = groupingProjection(currentRow)
-          var currentBuffer = hashTable.get(currentGroup)
-          if (currentBuffer == null) {
-            currentBuffer = newAggregateBuffer()
-            hashTable.put(currentGroup.copy(), currentBuffer)
+        def createCombiner(row: Row) = mergeValue(newAggregateBuffer(), row)
+        def mergeValue(buffer: Array[AggregateFunction], row: Row) = {
+          var i = 0
+          while (i < buffer.length) {
+            buffer(i).update(row)
+            i += 1
+          }
+          buffer
+        }
+        def mergeCombiners(buf1: Array[AggregateFunction], buf2: Array[AggregateFunction]) = {
+          if (buf1.length != buf2.length) {
+            throw new TreeNodeException(this, s"Unequal aggregate buffer length ${buf1.length} != ${buf2.length}")
           }
-          var i = 0
-          while (i < currentBuffer.length) {
-            currentBuffer(i).update(currentRow)
+          var i = 0
+          while (i < buf1.length) {
+            buf1(i).merge(buf2(i))
             i += 1
           }
+          buf1
         }
-
+        val aggregator = new Aggregator[Row, Row, Array[AggregateFunction]](
+          createCombiner, mergeValue, mergeCombiners, new SparkSqlSerializer(new SparkConf(false)))
+
+        val aggIter = aggregator.combineValuesByKey(
+          new Iterator[(Row, Row)] {  // (groupKey, row)
+            override final def hasNext: Boolean = iter.hasNext
+
+            override final def next(): (Row, Row) = {
+              val row = iter.next()
+              // TODO: copy() is needed because the grouping projection reuses and
+              // mutates the row it returns. Address the root cause and remove copy().
+              (groupingProjection(row).copy(), row)
+            }
+          },
+          null
+        )
 
         new Iterator[Row] {
-          private[this] val hashTableIter = hashTable.entrySet().iterator()
           private[this] val aggregateResults = new GenericMutableRow(computedAggregates.length)
-          private[this] val resultProjection =
-            new MutableProjection(resultExpressions, computedSchema ++ namedGroups.map(_._2))
+          private[this] val resultProjection = new MutableProjection(
+            resultExpressions, computedSchema ++ namedGroups.map(_._2))
          private[this] val joinedRow = new JoinedRow
 
-          override final def hasNext: Boolean = hashTableIter.hasNext
+          override final def hasNext: Boolean = aggIter.hasNext
 
           override final def next(): Row = {
-            val currentEntry = hashTableIter.next()
-            val currentGroup = currentEntry.getKey
-            val currentBuffer = currentEntry.getValue
+            val entry = aggIter.next()
+            val group = entry._1
+            val data = entry._2
 
             var i = 0
-            while (i < currentBuffer.length) {
-              // Evaluating an aggregate buffer returns the result. No row is required since we
-              // already added all rows in the group using update.
-              aggregateResults(i) = currentBuffer(i).eval(EmptyRow)
+            while (i < data.length) {
+              aggregateResults(i) = data(i).eval(EmptyRow)
               i += 1
             }
-            resultProjection(joinedRow(aggregateResults, currentGroup))
+
+            resultProjection(joinedRow(aggregateResults, group))
           }
         }
       }
     }
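
The heart of the patch is the new `AggregateFunction.merge` contract: each partition folds its rows into private function instances via `update`, and `mergeCombiners` then reconciles two buffers element-wise via `buf1(i).merge(buf2(i))`. Below is a minimal, dependency-free sketch of that contract; `AggFunc`, `Average`, and everything else in it are illustrative stand-ins, not catalyst API:

```scala
// Pared-down model of the merge contract: update() folds one value into a
// partial state; merge() folds another instance's partial state into this one.
trait AggFunc[T] {
  def update(value: T): Unit          // fold one input value into this partial state
  def merge(other: AggFunc[T]): Unit  // fold another partition's partial state into this one
  def eval: Any                       // final result, analogous to eval(EmptyRow)
}

final class Average extends AggFunc[Double] {
  private var count = 0L
  private var sum = 0.0

  def update(value: Double): Unit = { count += 1; sum += value }

  // Mirrors AverageFunction.merge: both the counts and the sums must be added,
  // or rows seen on other partitions would be dropped from the average.
  def merge(other: AggFunc[Double]): Unit = other match {
    case a: Average => count += a.count; sum += a.sum
    case _ => throw new IllegalArgumentException(s"Types do not match: $other")
  }

  def eval: Any = if (count == 0) null else sum / count
}

object MergeContractDemo {
  def main(args: Array[String]): Unit = {
    val partition1 = new Average
    Seq(1.0, 2.0).foreach(partition1.update)

    val partition2 = new Average
    Seq(6.0).foreach(partition2.update)

    // Combine partial states, as mergeCombiners does buffer slot by buffer slot.
    partition1.merge(partition2)
    println(partition1.eval)  // 3.0
  }
}
```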
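On the execution side, the patch swaps the per-partition `java.util.HashMap` for `Aggregator.combineValuesByKey`, which can spill through `ExternalAppendOnlyMap` using the supplied `SparkSqlSerializer`. The following is a rough, in-memory-only model of the combine semantics involved (no spilling; `combineValuesByKey` here is a local stand-in for Spark's, and the `(count, sum)` buffer plays the role of `Array[AggregateFunction]`):

```scala
import scala.collection.mutable

object CombinePathSketch {
  // Build one combiner per key: createCombiner seeds a buffer from the first
  // value, mergeValue folds each further value in.
  def combineValuesByKey[K, V, C](
      iter: Iterator[(K, V)],
      createCombiner: V => C,
      mergeValue: (C, V) => C): Iterator[(K, C)] = {
    val combiners = mutable.HashMap.empty[K, C]
    iter.foreach { case (k, v) =>
      combiners(k) = combiners.get(k).map(mergeValue(_, v)).getOrElse(createCombiner(v))
    }
    combiners.iterator
  }

  def main(args: Array[String]): Unit = {
    type Buffer = (Long, Double)  // (count, sum): enough for COUNT/SUM/AVG
    val createCombiner = (v: Double) => (1L, v)
    val mergeValue = (c: Buffer, v: Double) => (c._1 + 1, c._2 + v)
    // Role of mergeCombiners: reconcile two partitions' buffers per key,
    // as ExternalAppendOnlyMap does when merging spilled partial maps.
    val mergeCombiners = (a: Buffer, b: Buffer) => (a._1 + b._1, a._2 + b._2)

    // Two "partitions" aggregated independently...
    val p1 = combineValuesByKey(
      Iterator("a" -> 1.0, "b" -> 2.0, "a" -> 3.0), createCombiner, mergeValue).toMap
    val p2 = combineValuesByKey(
      Iterator("a" -> 4.0, "b" -> 6.0), createCombiner, mergeValue).toMap

    // ...then merged per key.
    val merged = (p1.toSeq ++ p2.toSeq).groupBy(_._1).map {
      case (k, kvs) => k -> kvs.map(_._2).reduce(mergeCombiners)
    }
    println(merged)  // a -> (3,8.0), b -> (2,8.0); map order may vary
  }
}
```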