10 changes: 7 additions & 3 deletions core/src/main/scala/org/apache/spark/Aggregator.scala
@@ -19,6 +19,7 @@ package org.apache.spark

import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.util.collection.{AppendOnlyMap, ExternalAppendOnlyMap}
import org.apache.spark.serializer.Serializer

/**
* :: DeveloperApi ::
@@ -32,7 +33,8 @@ import org.apache.spark.util.collection.{AppendOnlyMap, ExternalAppendOnlyMap}
case class Aggregator[K, V, C] (
createCombiner: V => C,
mergeValue: (C, V) => C,
mergeCombiners: (C, C) => C) {
mergeCombiners: (C, C) => C,
serializer: Serializer = SparkEnv.get.serializer) {
Contributor: also update the documentation above to add the new parameter.

Author: Done

private val externalSorting = SparkEnv.get.conf.getBoolean("spark.shuffle.spill", true)

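The new serializer parameter lets a caller thread a specific serializer into the spill path (the ExternalAppendOnlyMap below) instead of falling back to SparkEnv's default. A minimal sketch of constructing an Aggregator under this patch's signature; the word-count combiner logic and the Kryo choice are illustrative only:

import org.apache.spark.{Aggregator, SparkConf}
import org.apache.spark.serializer.KryoSerializer

// Values are per-record counts (Int); combiners are running totals (Long).
val aggregator = new Aggregator[String, Int, Long](
  createCombiner = (v: Int) => v.toLong,             // seed a combiner from the first value
  mergeValue = (c: Long, v: Int) => c + v,           // fold another value into the combiner
  mergeCombiners = (c1: Long, c2: Long) => c1 + c2,  // merge two partial combiners
  serializer = new KryoSerializer(new SparkConf(false)))  // used only if the map spills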
@@ -54,7 +56,8 @@ case class Aggregator[K, V, C] (
}
combiners.iterator
} else {
val combiners = new ExternalAppendOnlyMap[K, V, C](createCombiner, mergeValue, mergeCombiners)
val combiners = new ExternalAppendOnlyMap[K, V, C](
createCombiner, mergeValue, mergeCombiners, serializer)
while (iter.hasNext) {
val (k, v) = iter.next()
combiners.insert(k, v)
@@ -83,7 +86,8 @@ case class Aggregator[K, V, C] (
}
combiners.iterator
} else {
val combiners = new ExternalAppendOnlyMap[K, C, C](identity, mergeCombiners, mergeCombiners)
val combiners = new ExternalAppendOnlyMap[K, C, C](
identity, mergeCombiners, mergeCombiners, serializer)
while (iter.hasNext) {
val (k, c) = iter.next()
combiners.insert(k, c)
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregates.scala
@@ -80,6 +80,7 @@ abstract class AggregateFunction
override def dataType = base.dataType

def update(input: Row): Unit
def merge(other: AggregateFunction): Unit
override def eval(input: Row): Any

// Do we really need this?
@@ -189,6 +190,16 @@ case class AverageFunction(expr: Expression, base: AggregateExpression)
count += 1
sum.update(addFunction, input)
}

override def merge(other: AggregateFunction): Unit = {
other match {
case avg: AverageFunction => {
count += avg.count
sum.update(Add(sum, avg.sum), EmptyRow)
}
case _ => throw new TreeNodeException(this, s"Types do not match ${this.dataType} != ${other.dataType}")
}
}
}
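
The merge contract added above is what makes two independently built partial aggregates combinable: each function folds another instance's state into its own. For AverageFunction that state is the (count, sum) pair. A self-contained model of the same arithmetic in plain Scala (a hypothetical stand-in, not the catalyst classes):

// Simplified stand-in for AverageFunction's partial state.
case class AvgState(var count: Long, var sum: Double) {
  def update(value: Double): Unit = { count += 1; sum += value }
  def merge(other: AvgState): Unit = { count += other.count; sum += other.sum }
  def eval: Double = sum / count
}

// Partition 1 saw 10.0 and 20.0; partition 2 saw 30.0.
val p1 = AvgState(0, 0.0); p1.update(10.0); p1.update(20.0)
val p2 = AvgState(0, 0.0); p2.update(30.0)
p1.merge(p2)
assert(p1.count == 3 && p1.sum == 60.0 && p1.eval == 20.0)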

case class CountFunction(expr: Expression, base: AggregateExpression) extends AggregateFunction {
@@ -203,6 +214,15 @@ case class CountFunction(expr: Expression, base: AggregateExpression) extends AggregateFunction {
}
}

override def merge(other: AggregateFunction): Unit = {
other match {
case c: CountFunction => {
count += c.count
}
case _ => throw new TreeNodeException(this, s"Types do not match ${this.dataType} != ${other.dataType}")
}
}

override def eval(input: Row): Any = count
}

@@ -217,6 +237,15 @@ case class SumFunction(expr: Expression, base: AggregateExpression) extends AggregateFunction {
sum.update(addFunction, input)
}

override def merge(other: AggregateFunction): Unit = {
other match {
case s: SumFunction => {
sum.update(Add(sum, s.sum), EmptyRow)
}
case _ => throw new TreeNodeException(this, s"Types do not match ${this.dataType} != ${other.dataType}")
}
}

override def eval(input: Row): Any = sum.eval(null)
}

@@ -234,6 +263,19 @@ case class SumDistinctFunction(expr: Expression, base: AggregateExpression)
}
}

override def merge(other: AggregateFunction): Unit = {
other match {
case sd: SumDistinctFunction => {
// TODO(lamuguo): Change to HashSet union once a Scala rebase picks up the related change:
// https://github.com/scala/scala/pull/3322
for (item <- sd.seen) {
seen += item
}
}
case _ => throw new TreeNodeException(this, s"Types do not match ${this.dataType} != ${other.dataType}")
}
}

override def eval(input: Row): Any =
seen.reduceLeft(base.dataType.asInstanceOf[NumericType].numeric.asInstanceOf[Numeric[Any]].plus)
}
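
Regarding the TODO above: independent of the upstream Scala change, the element-wise copy loops in SumDistinctFunction and CountDistinctFunction amount to a set union, which a mutable HashSet already supports in bulk via ++=. A small sketch with illustrative contents, assuming the seen sets are mutable HashSets as elsewhere in this file:

import scala.collection.mutable

// Stand-ins for the distinct-value sets kept by the two functions.
val seen = mutable.HashSet[Any](1, 2, 3)
val otherSeen = mutable.HashSet[Any](3, 4)

seen ++= otherSeen  // same effect as: for (item <- otherSeen) { seen += item }
assert(seen == Set(1, 2, 3, 4))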
@@ -252,6 +294,17 @@ case class CountDistinctFunction(expr: Seq[Expression], base: AggregateExpression)
}
}

override def merge(other: AggregateFunction): Unit = {
other match {
case cd: CountDistinctFunction => {
for (item <- cd.seen) {
seen += item
}
}
case _ => throw new TreeNodeException(this, s"Types do not match ${this.dataType} != ${other.dataType}")
}
}

override def eval(input: Row): Any = seen.size
}

@@ -266,5 +319,16 @@ case class FirstFunction(expr: Expression, base: AggregateExpression) extends AggregateFunction {
}
}

override def merge(other: AggregateFunction): Unit = {
other match {
case second: FirstFunction => {
if (result == null) {
result = second.result
}
}
case _ => throw new TreeNodeException(this, s"Types do not match ${this.dataType} != ${other.dataType}")
}
}

override def eval(input: Row): Any = result
}
sql/core/src/main/scala/org/apache/spark/sql/execution/Aggregate.scala
@@ -17,13 +17,13 @@

package org.apache.spark.sql.execution

import java.util.HashMap

import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.SparkContext
import org.apache.spark.{Logging, SparkConf, Aggregator, SparkContext}
import org.apache.spark.sql.catalyst.errors._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.execution.SparkSqlSerializer
import scala.collection.mutable.ArrayBuffer

/**
* :: DeveloperApi ::
Expand All @@ -42,7 +42,7 @@ case class Aggregate(
groupingExpressions: Seq[Expression],
aggregateExpressions: Seq[NamedExpression],
child: SparkPlan)(@transient sc: SparkContext)
extends UnaryNode with NoBind {
extends UnaryNode with NoBind with Logging {

override def requiredChildDistribution =
if (partial) {
@@ -155,48 +155,60 @@ case class Aggregate(
}
} else {
child.execute().mapPartitions { iter =>
val hashTable = new HashMap[Row, Array[AggregateFunction]]
val groupingProjection = new MutableProjection(groupingExpressions, childOutput)

var currentRow: Row = null
while (iter.hasNext) {
currentRow = iter.next()
val currentGroup = groupingProjection(currentRow)
var currentBuffer = hashTable.get(currentGroup)
if (currentBuffer == null) {
currentBuffer = newAggregateBuffer()
hashTable.put(currentGroup.copy(), currentBuffer)
val groupingProjection = new
MutableProjection(groupingExpressions, childOutput)
Contributor: no need to wrap this line

Author: Done

// TODO: Can't use "Array[AggregateFunction]" directly, due to lack of
// "concat(AggregateFunction, AggregateFunction)". Should add
// AggregateFunction.update(agg: AggregateFunction) in the future.
def createCombiner(row: Row) = mergeValue(newAggregateBuffer(), row)
def mergeValue(buffer: Array[AggregateFunction], row: Row) = {
for (i <- 0 to buffer.length - 1) {
Contributor: It'd be better to rewrite this using a while loop, since while loops perform much better than for loops in Scala.

Author: Done

buffer(i).update(row)
}

var i = 0
while (i < currentBuffer.length) {
currentBuffer(i).update(currentRow)
i += 1
buffer
}
def mergeCombiners(buf1: Array[AggregateFunction], buf2: Array[AggregateFunction]) = {
if (buf1.length != buf2.length) {
throw new TreeNodeException(this, s"Unequal aggregate buffer length ${buf1.length} != ${buf2.length}")
}
for (i <- 0 to buf1.length - 1) {
Contributor: while loop here too

Author: Done

buf1(i).merge(buf2(i))
}
buf1
}
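
Per the review comments above (which the author reports addressing), the indexed for comprehensions in mergeValue and mergeCombiners can be rewritten as while loops to avoid the Range and closure overhead of Scala's for on this hot path. A sketch of how the two helpers might look after that rewrite, assuming the surrounding definitions in this patch:

def mergeValue(buffer: Array[AggregateFunction], row: Row): Array[AggregateFunction] = {
  var i = 0
  while (i < buffer.length) {   // while loop instead of for (i <- 0 to buffer.length - 1)
    buffer(i).update(row)
    i += 1
  }
  buffer
}

def mergeCombiners(
    buf1: Array[AggregateFunction],
    buf2: Array[AggregateFunction]): Array[AggregateFunction] = {
  if (buf1.length != buf2.length) {
    throw new TreeNodeException(
      this, s"Unequal aggregate buffer length ${buf1.length} != ${buf2.length}")
  }
  var i = 0
  while (i < buf1.length) {
    buf1(i).merge(buf2(i))
    i += 1
  }
  buf1
}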

val aggregator = new Aggregator[Row, Row, Array[AggregateFunction]](
createCombiner, mergeValue, mergeCombiners, new SparkSqlSerializer(new SparkConf(false)))

val aggIter = aggregator.combineValuesByKey(
new Iterator[(Row, Row)] { // (groupKey, row)
override final def hasNext: Boolean = iter.hasNext

override final def next(): (Row, Row) = {
val row = iter.next()
// TODO: copy() here for suppressing reference problems. Please clearly address
// the root-cause and remove copy() here.
(groupingProjection(row).copy(), row)
}
},
null
)
new Iterator[Row] {
private[this] val hashTableIter = hashTable.entrySet().iterator()
private[this] val aggregateResults = new GenericMutableRow(computedAggregates.length)
private[this] val resultProjection =
new MutableProjection(resultExpressions, computedSchema ++ namedGroups.map(_._2))
private[this] val resultProjection = new MutableProjection(
resultExpressions, computedSchema ++ namedGroups.map(_._2))
private[this] val joinedRow = new JoinedRow

override final def hasNext: Boolean = hashTableIter.hasNext
override final def hasNext: Boolean = aggIter.hasNext

override final def next(): Row = {
val currentEntry = hashTableIter.next()
val currentGroup = currentEntry.getKey
val currentBuffer = currentEntry.getValue

var i = 0
while (i < currentBuffer.length) {
// Evaluating an aggregate buffer returns the result. No row is required since we
// already added all rows in the group using update.
aggregateResults(i) = currentBuffer(i).eval(EmptyRow)
i += 1
val entry = aggIter.next()
val group = entry._1
val data = entry._2

for (i <- 0 to data.length - 1) {
Contributor: again, while loop here.

Author: Done

Author:
Hi there,

I made some changes per the comments a couple of days ago here: #867 (comment). Please take another look. Thanks!

Best Regards,
Xiaofeng

aggregateResults(i) = data(i).eval(EmptyRow)
}
resultProjection(joinedRow(aggregateResults, currentGroup))
resultProjection(joinedRow(aggregateResults, group))
}
}
}
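
One note on the .copy() in the key-extraction iterator above (the TODO about reference problems): the usual root cause is that a MutableProjection reuses a single mutable output row across calls, so without the copy every (groupKey, row) pair handed to the Aggregator would alias the same key object and later groups would clobber earlier ones. A stripped-down illustration of that aliasing hazard, using an ordinary mutable buffer as a hypothetical stand-in for the projected row:

import scala.collection.mutable.ArrayBuffer

// A reused, in-place-mutated "projection output".
val reused = ArrayBuffer[Any](null)
def project(key: Int): ArrayBuffer[Any] = { reused(0) = key; reused }

val withoutCopy = Seq(1, 2, 3).map(k => project(k))           // all three alias `reused`
val withCopy    = Seq(1, 2, 3).map(k => project(k).clone())   // independent snapshots

assert(withoutCopy.map(_(0)) == Seq(3, 3, 3))  // aliasing: only the last key survives
assert(withCopy.map(_(0)) == Seq(1, 2, 3))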