@@ -21,7 +21,7 @@ import java.io._
2121import java .util .concurrent .ConcurrentHashMap
2222import java .util .zip .{GZIPInputStream , GZIPOutputStream }
2323
24- import scala .collection .mutable .{HashSet , Map , HashMap }
24+ import scala .collection .mutable .{HashMap , HashSet , Map }
2525import scala .collection .JavaConversions ._
2626import scala .reflect .ClassTag
2727
@@ -30,6 +30,7 @@ import org.apache.spark.scheduler.MapStatus
3030import org .apache .spark .shuffle .MetadataFetchFailedException
3131import org .apache .spark .storage .BlockManagerId
3232import org .apache .spark .util ._
33+ import org .apache .spark .util .collection .{Utils => CollectionUtils }
3334
3435private [spark] sealed trait MapOutputTrackerMessage
3536private [spark] case class GetMapOutputStatuses (shuffleId : Int )
@@ -232,11 +233,10 @@ private[spark] class MapOutputTrackerMaster(conf: SparkConf)
// Map statuses registered for each shuffle, keyed by shuffleId.
protected val mapStatuses: TimeStampedHashMap[Int, Array[MapStatus]] =
  new TimeStampedHashMap[Int, Array[MapStatus]]()
// Per-shuffle byte arrays, keyed by shuffleId (presumably the serialized form of the
// statuses above; the code that fills this map is outside this view -- confirm).
private val cachedSerializedStatuses: TimeStampedHashMap[Int, Array[Byte]] =
  new TimeStampedHashMap[Int, Array[Byte]]()

// For each shuffleId, a map from reducerId to the locations holding the largest map
// outputs for that reducer.
// Populated lazily, the first time the locations of a shuffle are requested
// (by the DAGScheduler).
private val shuffleIdToReduceLocations =
  new TimeStampedHashMap[Int, HashMap[Int, Array[BlockManagerId]]]()
240240
241241 // For cleaning up TimeStampedHashMaps
242242 private val metadataCleaner =
@@ -283,38 +283,47 @@ private[spark] class MapOutputTrackerMaster(conf: SparkConf)
/**
 * Stop tracking the given shuffle: drops its map statuses, its cached serialized
 * statuses and its cached per-reducer top locations.
 *
 * @param shuffleId id of the shuffle to forget
 */
override def unregisterShuffle(shuffleId: Int): Unit = {
  mapStatuses.remove(shuffleId)
  cachedSerializedStatuses.remove(shuffleId)
  shuffleIdToReduceLocations.remove(shuffleId)
}
288288
/**
 * Check if the given shuffle is being tracked, i.e. whether either the serialized-status
 * cache or the map-status table has an entry for it.
 */
def containsShuffle(shuffleId: Int): Boolean = {
  // The serialized cache is consulted first; the status table only if that misses.
  if (cachedSerializedStatuses.contains(shuffleId)) {
    true
  } else {
    mapStatuses.contains(shuffleId)
  }
}
293293
/**
 * Return the locations which hold the largest map outputs for a given shuffleId and
 * reducerId.
 *
 * On the first call for a shuffle, the top locations of every reducer in
 * `[0, numReducers)` are computed and cached; later calls for the same shuffle are served
 * from that cache. As a consequence `numReducers` and `numTopLocs` only take effect on the
 * call that populates the cache.
 *
 * This method is not thread-safe.
 *
 * @param shuffleId   shuffle whose map outputs are inspected
 * @param reducerId   reducer (partition) whose preferred locations are requested
 * @param numReducers number of reducers to pre-compute locations for
 * @param numTopLocs  maximum number of locations kept per reducer
 * @return the cached locations for `reducerId`, or None if nothing is cached for the
 *         shuffle or the reducer
 */
def getLocationsWithLargestOutputs(
    shuffleId: Int,
    reducerId: Int,
    numReducers: Int,
    numTopLocs: Int)
  : Option[Array[BlockManagerId]] = {
  if (!shuffleIdToReduceLocations.contains(shuffleId)) {
    // A single `get` replaces the previous contains-then-apply pair, so an entry that
    // disappears between the two lookups can no longer raise NoSuchElementException.
    mapStatuses.get(shuffleId).foreach { statuses =>
      // NOTE(review): the null check assumes the status array may contain null slots for
      // map outputs that are not (or no longer) registered -- confirm against the code
      // that fills `mapStatuses`. Without it such a slot would throw an NPE below; with it
      // nothing is cached and the computation is retried on a later call.
      if (statuses.nonEmpty && statuses.forall(_ != null)) {
        // Reversed so that "smallest first" in takeOrdered means "largest output first".
        val bySizeDescending = Ordering.by[(BlockManagerId, Long), Long](_._2).reverse
        // Built locally and published only once complete, so a failure part-way through
        // cannot leave a partially populated entry in the cache.
        val topLocsByReducer = new HashMap[Int, Array[BlockManagerId]]
        (0 until numReducers).foreach { r =>
          // Add up sizes of all blocks at the same location
          val sizeByLocation = statuses
            .groupBy(_.location)
            .iterator
            .map { case (loc, atLoc) => (loc, atLoc.map(_.getSizeForBlock(r)).sum) }
          val topLocs = CollectionUtils.takeOrdered(sizeByLocation, numTopLocs)(bySizeDescending)
          topLocsByReducer(r) = topLocs.map(_._1).toArray
        }
        shuffleIdToReduceLocations(shuffleId) = topLocsByReducer
      }
    }
  }
  shuffleIdToReduceLocations.get(shuffleId).flatMap(_.get(reducerId))
}
319328
320329 def incrementEpoch () {
@@ -364,7 +373,7 @@ private[spark] class MapOutputTrackerMaster(conf: SparkConf)
/**
 * Clear old entries from every per-shuffle table kept by this tracker by calling
 * `clearOldValues(cleanupTime)` on each of them.
 *
 * @param cleanupTime threshold passed unchanged to each table's `clearOldValues`
 */
private def cleanup(cleanupTime: Long): Unit = {
  mapStatuses.clearOldValues(cleanupTime)
  cachedSerializedStatuses.clearOldValues(cleanupTime)
  shuffleIdToReduceLocations.clearOldValues(cleanupTime)
}
369378}
370379
0 commit comments