Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
package org.apache.spark.shuffle

import java.util

import org.apache.spark.SparkEnv
import org.apache.spark.SparkConf
import org.apache.spark.util.collection.{AppendOnlyMap, ExternalAppendOnlyMap}
import scala.collection.mutable.ArrayBuffer
import scala.collection.mutable.MutableList
import scala.collection.Iterator

/**
* Created by vladio on 7/14/15.
*/
/**
 * Samples the head of a shuffle record iterator to decide whether map-side
 * partial aggregation is worthwhile. It buffers up to `partialAggCheckInterval`
 * records, counts the distinct keys seen, and disables partial aggregation when
 * the distinct-key ratio exceeds `partialAggReduction` (i.e. combining would not
 * shrink the data enough). The buffered records can then be replayed via
 * [[getRestoredIterator]] so no input is lost.
 *
 * @param iterator the shuffle records; partially consumed by [[enableAggregation]]
 */
private[spark] class ShuffleAggregationManager[K, V](
iterator: Iterator[Product2[K, V]]) {

// Tunables; overwritten by withConf() when a SparkConf is available.
private[this] var isSpillEnabled = true
private[this] var partialAggCheckInterval = 10000
private[this] var partialAggReduction = 0.5
// Flipped to false once sampling shows too many distinct keys.
private[this] var partialAggEnabled = true

// Tracks the distinct keys observed while sampling (value is a dummy flag).
private[this] val uniqueKeysMap = createNormalMap[K, Boolean]
// Buffer of sampled records, replayed by getRestoredIterator().
private[this] val iteratedElements = ArrayBuffer[Product2[K, V]]()
private[this] var numIteratedRecords = 0

/**
 * Loads the sampling thresholds from `conf` and returns this instance
 * for chaining.
 */
private[spark] def withConf(conf: SparkConf): this.type = {
isSpillEnabled = conf.getBoolean("spark.shuffle.spill", true)
partialAggCheckInterval = conf.getInt("spark.partialAgg.interval", 10000)
partialAggReduction = conf.getDouble("spark.partialAgg.reduction", 0.5)
this
}

// Combiner functions for building an ExternalAppendOnlyMap that groups
// values per key into an ArrayBuffer.
def createCombiner[T](i: T) = ArrayBuffer[T](i)
def mergeValue[T](buffer: ArrayBuffer[T], i: T): ArrayBuffer[T] = buffer += i
def mergeCombiners[T](buf1: ArrayBuffer[T], buf2: ArrayBuffer[T]): ArrayBuffer[T] =
buf1 ++= buf2

// NOTE: method type parameters renamed (K0) so they no longer shadow the
// class-level K, which previously triggered shadowing warnings.
def createExternalMap[K0, T] = new ExternalAppendOnlyMap[K0, T, ArrayBuffer[T]](
createCombiner[T], mergeValue[T], mergeCombiners[T])

def createNormalMap[K0, T] = new AppendOnlyMap[K0, T]()

// Pick up the running application's conf when executing inside Spark;
// tests call withConf() explicitly instead.
if (SparkEnv.get != null) {
withConf(SparkEnv.get.conf)
}

/**
 * Returns an iterator over the full record stream: first the records buffered
 * during sampling, then whatever remains of the original iterator.
 */
def getRestoredIterator(): Iterator[Product2[K, V]] = {
if (iterator.hasNext)
iteratedElements.toIterator ++ iterator
else
iteratedElements.toIterator
}

/**
 * Samples up to `partialAggCheckInterval` records and returns whether partial
 * (map-side) aggregation should be enabled. Consumes records from the input
 * iterator; callers must use [[getRestoredIterator]] afterwards to see the
 * complete stream.
 */
def enableAggregation(): Boolean = {
var reachedCheckInterval = false
while (iterator.hasNext && partialAggEnabled && !reachedCheckInterval) {
val kv = iterator.next()

iteratedElements += kv
numIteratedRecords += 1

uniqueKeysMap.update(kv._1, true)

if (numIteratedRecords == partialAggCheckInterval) {
// Too many distinct keys relative to the sample size means combining
// would barely reduce the data, so skip the pre-aggregation.
if (uniqueKeysMap.size > numIteratedRecords * partialAggReduction) {
partialAggEnabled = false
}

reachedCheckInterval = true
}
}

partialAggEnabled
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -50,9 +50,16 @@ private[spark] class HashShuffleWriter[K, V](

/** Write a bunch of records to this task's output */
override def write(records: Iterator[Product2[K, V]]): Unit = {
// Decide if it's optimal to do the pre-aggregation.
val aggManager = new ShuffleAggregationManager[K, V](records)

val iter = if (dep.aggregator.isDefined) {
if (dep.mapSideCombine) {
dep.aggregator.get.combineValuesByKey(records, context)
if (aggManager.enableAggregation()) {
dep.aggregator.get.combineValuesByKey(aggManager.getRestoredIterator(), context)
} else {
aggManager.getRestoredIterator()
}
} else {
records
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ import org.apache.spark._
import org.apache.spark.executor.ShuffleWriteMetrics
import org.apache.spark.scheduler.MapStatus
import org.apache.spark.serializer.Serializer
import org.apache.spark.shuffle.{IndexShuffleBlockResolver, ShuffleWriter, BaseShuffleHandle}
import org.apache.spark.shuffle.{ShuffleAggregationManager, IndexShuffleBlockResolver, ShuffleWriter, BaseShuffleHandle}
import org.apache.spark.storage.ShuffleBlockId
import org.apache.spark.util.collection.ExternalSorter

Expand Down Expand Up @@ -50,7 +50,14 @@ private[spark] class SortShuffleWriter[K, V, C](

/** Write a bunch of records to this task's output */
override def write(records: Iterator[Product2[K, V]]): Unit = {
sorter = if (dep.mapSideCombine) {
// Decide if it's optimal to do the pre-aggregation.
val aggManager = new ShuffleAggregationManager[K, V](records)
val enableAggregation = aggManager.enableAggregation()
val enableMapSideCombine = (dep.mapSideCombine & enableAggregation)

System.out.println("SortShuffleWriter - Enable Pre-Aggregation: " + enableMapSideCombine)

sorter = if (enableMapSideCombine) {
require(dep.aggregator.isDefined, "Map-side combine without Aggregator specified!")
new ExternalSorter[K, V, C](
dep.aggregator, Some(dep.partitioner), dep.keyOrdering, dep.serializer)
Expand All @@ -70,7 +77,8 @@ private[spark] class SortShuffleWriter[K, V, C](
new ExternalSorter[K, V, V](
aggregator = None, Some(dep.partitioner), ordering = None, dep.serializer)
}
sorter.insertAll(records)

sorter.insertAll(aggManager.getRestoredIterator())

// Don't bother including the time to open the merged output file in the shuffle write time,
// because it just opens a single file, so is typically too fast to measure accurately
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
package org.apache.spark

import org.apache.spark.shuffle.ShuffleAggregationManager
import org.apache.spark.shuffle.sort.SortShuffleWriter._
import org.mockito.Mockito._

/**
* Created by vladio on 7/15/15.
*/
/**
 * Tests for [[ShuffleAggregationManager]]: verifies the distinct-key sampling
 * decision and that sampled records are replayed losslessly.
 */
class ShuffleAggregationManagerSuite extends SparkFunSuite {

  /** Builds a minimal SparkConf with a small sampling interval for tiny test inputs. */
  private def makeConf(interval: Int = 4, reduction: Double = 0.5): SparkConf = {
    val conf = new SparkConf(loadDefaults = false)
    conf.set("spark.partialAgg.interval", interval.toString)
    conf.set("spark.partialAgg.reduction", reduction.toString)
    conf
  }

  test("conditions for doing the pre-aggregation") {
    val conf = makeConf()

    // This test will pass if the first 4 elements of a set contains at most 2 unique keys.
    // Generate the records.
    val records = Iterator(
      (1, "Vlad"), (2, "Marius"), (1, "Marian"),
      (2, "Cornel"), (3, "Patricia"), (4, "Georgeta"))

    // Test.
    val aggManager = new ShuffleAggregationManager[Int, String](records).withConf(conf)
    assert(aggManager.enableAggregation())
  }

  test("conditions for skipping the pre-aggregation") {
    val conf = makeConf()

    // First 4 elements contain 3 unique keys (> 4 * 0.5), so aggregation is skipped.
    val records = Iterator(
      (1, "Vlad"), (2, "Marius"), (3, "Marian"),
      (2, "Cornel"), (3, "Patricia"), (4, "Georgeta"))

    val aggManager = new ShuffleAggregationManager[Int, String](records).withConf(conf)
    assert(!aggManager.enableAggregation())
  }

  test("restoring the iterator") {
    val conf = makeConf()

    val records = Iterator(
      (1, "Vlad"), (2, "Marius"), (1, "Marian"),
      (2, "Cornel"), (3, "Patricia"), (4, "Georgeta"))
    val recordsCopy = Iterator(
      (1, "Vlad"), (2, "Marius"), (1, "Marian"),
      (2, "Cornel"), (3, "Patricia"), (4, "Georgeta"))

    val aggManager = new ShuffleAggregationManager[Int, String](records).withConf(conf)
    assert(aggManager.enableAggregation())

    // The restored iterator must replay the sampled prefix followed by the rest,
    // matching the original record stream element-for-element.
    val restoredRecords = aggManager.getRestoredIterator()
    assert(restoredRecords.hasNext)

    while (restoredRecords.hasNext && recordsCopy.hasNext) {
      val kv1 = restoredRecords.next()
      val kv2 = recordsCopy.next()

      assert(kv1 == kv2)
    }

    assert(!restoredRecords.hasNext)
    assert(!recordsCopy.hasNext)
  }
}