core/src/main/scala/org/apache/spark/MapOutputTracker.scala (30 additions, 0 deletions)
@@ -232,6 +232,11 @@ private[spark] class MapOutputTrackerMaster(conf: SparkConf)
protected val mapStatuses = new TimeStampedHashMap[Int, Array[MapStatus]]()
private val cachedSerializedStatuses = new TimeStampedHashMap[Int, Array[Byte]]()

// For each shuffleId we also maintain a Map from reducerId -> (location, size)
// Lazily populated whenever the statuses are requested from DAGScheduler
private val statusByReducer =
new TimeStampedHashMap[Int, HashMap[Int, Array[(BlockManagerId, Long)]]]()
Contributor: should we consider sampling the map tasks to speed up the sort?

[A rough sketch of this sampling idea appears after this file's diff, below.]

// For cleaning up TimeStampedHashMaps
private val metadataCleaner =
new MetadataCleaner(MetadataCleanerType.MAP_OUTPUT_TRACKER, this.cleanup, conf)
@@ -276,6 +281,7 @@ private[spark] class MapOutputTrackerMaster(conf: SparkConf)
/** Unregister shuffle data */
override def unregisterShuffle(shuffleId: Int) {
mapStatuses.remove(shuffleId)
statusByReducer.remove(shuffleId)
cachedSerializedStatuses.remove(shuffleId)
}

@@ -284,6 +290,30 @@ private[spark] class MapOutputTrackerMaster(conf: SparkConf)
cachedSerializedStatuses.contains(shuffleId) || mapStatuses.contains(shuffleId)
}

// Return the locations and block sizes of the map outputs for each reducer of this shuffle,
// or None if no map output statuses are registered for it. The returned map is keyed by
// reducerId; for each reducer the value is an array of (location, size) pairs with one entry
// per map task, in map-task order.
//
// This method is not thread-safe because TimeStampedHashMap is not thread-safe, but it is
// only called from the DAGScheduler, which is single-threaded.
def getStatusByReducer(shuffleId: Int): Option[Map[Int, Array[(BlockManagerId, Long)]]] = {
Contributor: comment on the thread safety

Contributor: also comment on the semantics of the return value (what does the Int mean, what does the index in the array mean, etc.)

Contributor Author: Added comments -- this method is not thread-safe, as TimeStampedHashMap is not thread-safe. However, we only call this from the DAGScheduler, which is single-threaded, AFAIK.

if (!statusByReducer.contains(shuffleId) && mapStatuses.contains(shuffleId)) {
val statuses = mapStatuses(shuffleId)
if (statuses.length > 0) {
val numReducers = statuses(0).compressedSizes.length
statusByReducer(shuffleId) = new HashMap[Int, Array[(BlockManagerId, Long)]]
var r = 0
while (r < numReducers) {
val locs = statuses.map { s =>
(s.location, MapOutputTracker.decompressSize(s.compressedSizes(r)))
}
statusByReducer(shuffleId) += (r -> locs)
r = r + 1
}
}
}
statusByReducer.get(shuffleId)
}

def incrementEpoch() {
epochLock.synchronized {
epoch += 1
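For reference on the two review threads above (the return-value semantics and the suggestion to sample the map tasks), here is a minimal, self-contained Scala sketch of what getStatusByReducer computes, plus a rough version of the sampling idea. StatusByReducerSketch, Location, SimpleStatus, and sampleStatuses are hypothetical names for illustration only and are not part of this PR; the real MapStatus stores log-compressed sizes that decompressSize expands.

object StatusByReducerSketch {
  import scala.collection.mutable.HashMap
  import scala.util.Random

  // Hypothetical stand-ins for the real Spark types (illustration only).
  case class Location(host: String, executorId: String)
  case class SimpleStatus(location: Location, sizes: Array[Long]) // sizes(r) = bytes for reducer r

  // Build reducerId -> Array[(location, size)]: the Int key is the reducer id and the
  // array holds one entry per map task, in map-task order.
  def byReducer(statuses: Array[SimpleStatus]): HashMap[Int, Array[(Location, Long)]] = {
    val result = new HashMap[Int, Array[(Location, Long)]]
    if (statuses.nonEmpty) {
      for (r <- 0 until statuses(0).sizes.length) {
        result(r) = statuses.map(s => (s.location, s.sizes(r)))
      }
    }
    result
  }

  // The sampling idea from the first review comment, roughly: with very many map tasks,
  // rank only a random sample of the outputs instead of all of them.
  def sampleStatuses(statuses: Array[SimpleStatus], maxSample: Int): Array[SimpleStatus] =
    if (statuses.length <= maxSample) statuses
    else Random.shuffle(statuses.toSeq).take(maxSample).toArray
}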
core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -39,8 +39,9 @@ import org.apache.spark.executor.TaskMetrics
import org.apache.spark.partial.{ApproximateActionListener, ApproximateEvaluator, PartialResult}
import org.apache.spark.rdd.RDD
import org.apache.spark.storage._
import org.apache.spark.util.{CallSite, SystemClock, Clock, Utils}
import org.apache.spark.storage.BlockManagerMessages.BlockManagerHeartbeat
import org.apache.spark.util.{CallSite, SystemClock, Clock, Utils}
import org.apache.spark.util.collection.{Utils => CollectionUtils}

/**
* The high-level scheduling layer that implements stage-oriented scheduling. It computes a DAG of
@@ -121,6 +122,9 @@ class DAGScheduler(

private[scheduler] var eventProcessActor: ActorRef = _

// Number of preferred locations to use for reducer tasks
private[scheduler] val NUM_REDUCER_PREF_LOCS = 5

private def initializeEventProcessActor() {
// blocking the thread until supervisor is started, which ensures eventProcessActor is
// not null before any job is submitted
@@ -1252,6 +1256,19 @@ class DAGScheduler(
return locs
}
}
case s: ShuffleDependency[_, _, _] =>
Contributor: add some inline comment explaining this case

Contributor Author: Done

// Assign preferred locations for reducers by looking at map output locations and sizes
val mapStatuses = mapOutputTracker.getStatusByReducer(s.shuffleId)
mapStatuses.map { status =>
// Get the map output locations for this reducer
if (status.contains(partition)) {
// Select the locations with the largest map outputs as preferred locations for the reducer
val topLocs = CollectionUtils.takeOrdered(status(partition).iterator,
NUM_REDUCER_PREF_LOCS)(Ordering.by[(BlockManagerId, Long), Long](_._2).reverse).toSeq

return topLocs.map(_._1).map(loc => TaskLocation(loc.host, loc.executorId))
}
}
case _ =>
}
Nil
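To spell out the new ShuffleDependency case above: it ranks this reducer's map outputs by size and keeps the hosts of the largest NUM_REDUCER_PREF_LOCS of them as preferred locations. Below is a minimal sketch of that ranking using plain Scala collections; ReducerLocalitySketch and pickPreferredHosts are illustrative names rather than the PR's API, and the takeOrdered call in the diff achieves the same effect without fully sorting the array.

object ReducerLocalitySketch {
  // (host, executorId, bytes this map task produced for the reducer) -- illustrative layout.
  type ReducerOutput = (String, String, Long)

  // Rank the reducer's map outputs by size, descending, and keep the hosts of the
  // largest few; equivalent in effect to the takeOrdered call in the diff above.
  def pickPreferredHosts(outputs: Seq[ReducerOutput], numPrefLocs: Int = 5): Seq[String] =
    outputs.sortBy(-_._3).take(numPrefLocs).map(_._1)

  def main(args: Array[String]): Unit = {
    val outputs = Seq(
      ("hostA", "exec-hostA", 10L << 20),  // 10 MB
      ("hostB", "exec-hostB", 1L << 20),   //  1 MB
      ("hostC", "exec-hostC", 7L << 20))   //  7 MB
    // The two largest outputs win, so hostA and hostC are preferred.
    assert(pickPreferredHosts(outputs, numPrefLocs = 2) == Seq("hostA", "hostC"))
  }
}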
core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
@@ -407,6 +407,50 @@ class DAGSchedulerSuite extends TestKit(ActorSystem("DAGSchedulerSuite")) with F
assertDataStructuresEmpty
}

test("shuffle with reducer locality") {
// Create a shuffleMapRdd with 1 partition
val shuffleMapRdd = new MyRDD(sc, 1, Nil)
val shuffleDep = new ShuffleDependency(shuffleMapRdd, null)
val shuffleId = shuffleDep.shuffleId
val reduceRdd = new MyRDD(sc, 1, List(shuffleDep))
submit(reduceRdd, Array(0))
complete(taskSets(0), Seq(
(Success, makeMapStatus("hostA", 1))))
assert(mapOutputTracker.getServerStatuses(shuffleId, 0).map(_._1) ===
Array(makeBlockManagerId("hostA")))

// The reducer should run on the same host that the map task ran on
val reduceTaskSet = taskSets(1)
assertLocations(reduceTaskSet, Seq(Seq("hostA")))
complete(reduceTaskSet, Seq((Success, 42)))
assert(results === Map(0 -> 42))
assertDataStructuresEmpty
}

test("reducer locality with different sizes") {
val numMapTasks = scheduler.NUM_REDUCER_PREF_LOCS + 1
// Create a shuffleMapRdd with more partitions than NUM_REDUCER_PREF_LOCS
val shuffleMapRdd = new MyRDD(sc, numMapTasks, Nil)
val shuffleDep = new ShuffleDependency(shuffleMapRdd, null)
val shuffleId = shuffleDep.shuffleId
val reduceRdd = new MyRDD(sc, 1, List(shuffleDep))
submit(reduceRdd, Array(0))

val statuses = (1 to numMapTasks).map { i =>
(Success, makeMapStatus("host" + i, 1, (10*i).toByte))
}
complete(taskSets(0), statuses)

// The reducer should prefer the later hosts, whose map outputs are larger
val hosts = (1 to numMapTasks).map(i => "host" + i).reverse.take(numMapTasks - 1)

val reduceTaskSet = taskSets(1)
assertLocations(reduceTaskSet, Seq(hosts))
complete(reduceTaskSet, Seq((Success, 42)))
assert(results === Map(0 -> 42))
assertDataStructuresEmpty
}

test("run trivial shuffle with fetch failure") {
val shuffleMapRdd = new MyRDD(sc, 2, Nil)
val shuffleDep = new ShuffleDependency(shuffleMapRdd, null)
@@ -694,12 +738,12 @@ class DAGSchedulerSuite extends TestKit(ActorSystem("DAGSchedulerSuite")) with F
private def assertLocations(taskSet: TaskSet, hosts: Seq[Seq[String]]) {
assert(hosts.size === taskSet.tasks.size)
for ((taskLocs, expectedLocs) <- taskSet.tasks.map(_.preferredLocations).zip(hosts)) {
assert(taskLocs.map(_.host) === expectedLocs)
assert(taskLocs.map(_.host).toSet === expectedLocs.toSet)
}
}

private def makeMapStatus(host: String, reduces: Int): MapStatus =
new MapStatus(makeBlockManagerId(host), Array.fill[Byte](reduces)(2))
private def makeMapStatus(host: String, reduces: Int, sizes: Byte = 2): MapStatus =
new MapStatus(makeBlockManagerId(host), Array.fill[Byte](reduces)(sizes))

private def makeBlockManagerId(host: String): BlockManagerId =
BlockManagerId("exec-" + host, host, 12345, 0)
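A note on the "reducer locality with different sizes" test above: it can pass raw bytes like (10*i).toByte as sizes because MapStatus stores each per-reducer size as a single compressed byte using a monotonic log-style encoding, so a larger byte always decodes to a larger size. The sketch below illustrates such an encoding; the constants are illustrative and not necessarily the exact ones Spark uses.

object SizeEncodingSketch {
  // A log-base encoding squeezes a wide range of sizes into one byte while
  // preserving relative order, which is all the locality ranking needs.
  private val LogBase = 1.1

  def encode(size: Long): Byte =
    if (size <= 1L) size.toByte
    else math.min(255.0, math.ceil(math.log(size.toDouble) / math.log(LogBase))).toInt.toByte

  def decode(b: Byte): Long =
    if (b == 0) 0L else math.pow(LogBase, (b & 0xFF).toDouble).toLong

  def main(args: Array[String]): Unit = {
    // Larger encoded bytes always decode to larger sizes, so the test's encoded
    // sizes 10, 20, 30, ... make the later hosts rank above host1.
    assert(decode(10) < decode(20) && decode(20) < decode(30))
  }
}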