diff --git a/core/src/main/scala/org/apache/spark/shuffle/ShuffleAggregationManager.scala b/core/src/main/scala/org/apache/spark/shuffle/ShuffleAggregationManager.scala
new file mode 100644
index 0000000000000..74b9f9e8bb717
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/shuffle/ShuffleAggregationManager.scala
@@ -0,0 +1,110 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle
+
+import scala.collection.Iterator
+import scala.collection.mutable.ArrayBuffer
+
+import org.apache.spark.{SparkConf, SparkEnv}
+import org.apache.spark.util.collection.{AppendOnlyMap, ExternalAppendOnlyMap}
+
+/**
+ * Samples the head of a shuffle input iterator to decide whether map-side (partial)
+ * aggregation is worthwhile. The first `spark.partialAgg.interval` records are buffered
+ * and their distinct keys counted; if the ratio of distinct keys to sampled records
+ * exceeds `spark.partialAgg.reduction`, partial aggregation is disabled.
+ *
+ * The buffered records can be replayed together with the rest of the input through
+ * [[getRestoredIterator]], so no record is lost by the sampling pass.
+ */
+private[spark] class ShuffleAggregationManager[K, V](
+    iterator: Iterator[Product2[K, V]]) {
+
+  private[this] var isSpillEnabled = true
+  private[this] var partialAggCheckInterval = 10000
+  private[this] var partialAggReduction = 0.5
+  private[this] var partialAggEnabled = true
+
+  // Distinct keys observed during sampling.
+  private[this] var uniqueKeysMap = createNormalMap[K, Boolean]
+  // Records consumed during sampling; replayed by getRestoredIterator().
+  private[this] var iteratedElements = ArrayBuffer[Product2[K, V]]()
+  private[this] var numIteratedRecords = 0
+
+  /** Overrides the sampling parameters from the given configuration. */
+  private[spark] def withConf(conf: SparkConf): this.type = {
+    isSpillEnabled = conf.getBoolean("spark.shuffle.spill", true)
+    partialAggCheckInterval = conf.getInt("spark.partialAgg.interval", 10000)
+    partialAggReduction = conf.getDouble("spark.partialAgg.reduction", 0.5)
+    this
+  }
+
+  // Functions for creating an ExternalAppendOnlyMap.
+  def createCombiner[T](i: T) = ArrayBuffer[T](i)
+  def mergeValue[T](buffer: ArrayBuffer[T], i: T): ArrayBuffer[T] = buffer += i
+  def mergeCombiners[T](buf1: ArrayBuffer[T], buf2: ArrayBuffer[T]): ArrayBuffer[T] =
+    buf1 ++= buf2
+
+  // KM is deliberately distinct from the class's K so the key type is not shadowed.
+  def createExternalMap[KM, T] = new ExternalAppendOnlyMap[KM, T, ArrayBuffer[T]](
+    createCombiner[T], mergeValue[T], mergeCombiners[T])
+
+  def createNormalMap[KM, T] = new AppendOnlyMap[KM, T]()
+
+  if (SparkEnv.get != null) {
+    withConf(SparkEnv.get.conf)
+  }
+
+  /**
+   * Returns an iterator over the complete input: first the records buffered while
+   * sampling, then whatever remains of the original iterator.
+   */
+  def getRestoredIterator(): Iterator[Product2[K, V]] = {
+    if (iterator.hasNext) {
+      iteratedElements.toIterator ++ iterator
+    } else {
+      iteratedElements.toIterator
+    }
+  }
+
+  /**
+   * Samples up to `partialAggCheckInterval` records and returns true if map-side
+   * aggregation is expected to pay off (i.e. the sampled keys repeat often enough).
+   */
+  def enableAggregation(): Boolean = {
+    var sampling = true
+    while (iterator.hasNext && partialAggEnabled && sampling) {
+      val kv = iterator.next()
+
+      iteratedElements += kv
+      numIteratedRecords += 1
+
+      uniqueKeysMap.update(kv._1, true)
+
+      if (numIteratedRecords == partialAggCheckInterval) {
+        // Too many distinct keys: combining on the map side would barely shrink the data.
+        if (uniqueKeysMap.size > numIteratedRecords * partialAggReduction) {
+          partialAggEnabled = false
+        }
+        sampling = false
+      }
+    }
+
+    partialAggEnabled
+  }
+}
diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala
index eb87cee15903c..1c45daf18d647 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala
@@ -50,9 +50,16 @@ private[spark] class HashShuffleWriter[K, V](
 
   /** Write a bunch of records to this task's output */
   override def write(records: Iterator[Product2[K, V]]): Unit = {
+    // Decide if it's optimal to do the pre-aggregation.
+    val aggManager = new ShuffleAggregationManager[K, V](records)
+
     val iter = if (dep.aggregator.isDefined) {
       if (dep.mapSideCombine) {
-        dep.aggregator.get.combineValuesByKey(records, context)
+        if (aggManager.enableAggregation()) {
+          dep.aggregator.get.combineValuesByKey(aggManager.getRestoredIterator(), context)
+        } else {
+          aggManager.getRestoredIterator()
+        }
       } else {
         records
       }
diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala
index 5865e7640c1cf..d110cbe0ae9c2 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala
@@ -21,7 +21,7 @@
 import org.apache.spark._
 import org.apache.spark.executor.ShuffleWriteMetrics
 import org.apache.spark.scheduler.MapStatus
 import org.apache.spark.serializer.Serializer
-import org.apache.spark.shuffle.{IndexShuffleBlockResolver, ShuffleWriter, BaseShuffleHandle}
+import org.apache.spark.shuffle.{BaseShuffleHandle, IndexShuffleBlockResolver, ShuffleAggregationManager, ShuffleWriter}
 import org.apache.spark.storage.ShuffleBlockId
 import org.apache.spark.util.collection.ExternalSorter
@@ -50,7 +50,14 @@ private[spark] class SortShuffleWriter[K, V, C](
 
   /** Write a bunch of records to this task's output */
   override def write(records: Iterator[Product2[K, V]]): Unit = {
-    sorter = if (dep.mapSideCombine) {
+    // Decide if it's optimal to do the pre-aggregation.
+    val aggManager = new ShuffleAggregationManager[K, V](records)
+    val enableAggregation = aggManager.enableAggregation()
+    val enableMapSideCombine = (dep.mapSideCombine && enableAggregation)
+
+    logDebug("SortShuffleWriter - Enable Pre-Aggregation: " + enableMapSideCombine)
+
+    sorter = if (enableMapSideCombine) {
       require(dep.aggregator.isDefined, "Map-side combine without Aggregator specified!")
       new ExternalSorter[K, V, C](
         dep.aggregator, Some(dep.partitioner), dep.keyOrdering, dep.serializer)
@@ -70,7 +77,8 @@ private[spark] class SortShuffleWriter[K, V, C](
       new ExternalSorter[K, V, V](
         aggregator = None, Some(dep.partitioner), ordering = None, dep.serializer)
     }
-    sorter.insertAll(records)
+
+    sorter.insertAll(aggManager.getRestoredIterator())
 
     // Don't bother including the time to open the merged output file in the shuffle write time,
     // because it just opens a single file, so is typically too fast to measure accurately
diff --git a/core/src/test/scala/org/apache/spark/ShuffleAggregationManagerSuite.scala b/core/src/test/scala/org/apache/spark/ShuffleAggregationManagerSuite.scala
new file mode 100644
index 0000000000000..445341480d165
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/ShuffleAggregationManagerSuite.scala
@@ -0,0 +1,76 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark
+
+import org.apache.spark.shuffle.ShuffleAggregationManager
+
+/**
+ * Tests for the sampling heuristic that decides whether map-side aggregation is worthwhile.
+ */
+class ShuffleAggregationManagerSuite extends SparkFunSuite {
+
+  test("conditions for doing the pre-aggregation") {
+    val conf = new SparkConf(loadDefaults = false)
+    conf.set("spark.partialAgg.interval", "4")
+    conf.set("spark.partialAgg.reduction", "0.5")
+
+    // This test will pass if the first 4 elements of a set contains at most 2 unique keys.
+    // Generate the records.
+    val records = Iterator((1, "Vlad"), (2, "Marius"), (1, "Marian"), (2, "Cornel"), (3, "Patricia"), (4, "Georgeta"))
+
+    // Test.
+    val aggManager = new ShuffleAggregationManager[Int, String](records).withConf(conf)
+    assert(aggManager.enableAggregation())
+  }
+
+  test("conditions for skipping the pre-aggregation") {
+    val conf = new SparkConf(loadDefaults = false)
+    conf.set("spark.partialAgg.interval", "4")
+    conf.set("spark.partialAgg.reduction", "0.5")
+
+    val records = Iterator((1, "Vlad"), (2, "Marius"), (3, "Marian"), (2, "Cornel"), (3, "Patricia"), (4, "Georgeta"))
+
+    val aggManager = new ShuffleAggregationManager[Int, String](records).withConf(conf)
+    assert(!aggManager.enableAggregation())
+  }
+
+  test("restoring the iterator") {
+    val conf = new SparkConf(loadDefaults = false)
+    conf.set("spark.partialAgg.interval", "4")
+    conf.set("spark.partialAgg.reduction", "0.5")
+
+    val records = Iterator((1, "Vlad"), (2, "Marius"), (1, "Marian"), (2, "Cornel"), (3, "Patricia"), (4, "Georgeta"))
+    val recordsCopy = Iterator((1, "Vlad"), (2, "Marius"), (1, "Marian"), (2, "Cornel"), (3, "Patricia"), (4, "Georgeta"))
+
+    val aggManager = new ShuffleAggregationManager[Int, String](records).withConf(conf)
+    assert(aggManager.enableAggregation())
+
+    val restoredRecords = aggManager.getRestoredIterator()
+    assert(restoredRecords.hasNext)
+
+    while (restoredRecords.hasNext && recordsCopy.hasNext) {
+      val kv1 = restoredRecords.next()
+      val kv2 = recordsCopy.next()
+
+      assert(kv1 == kv2)
+    }
+
+    assert(!restoredRecords.hasNext)
+    assert(!recordsCopy.hasNext)
+  }
+}