@@ -438,8 +438,8 @@ class DAGSchedulerSuite extends FunSuiteLike with BeforeAndAfter with LocalSpar
     val reduceRdd = new MyRDD(sc, 2, List(shuffleDep))
     submit(reduceRdd, Array(0, 1))
     complete(taskSets(0), Seq(
-      (Success, makeMapStatus("hostA", 1)),
-      (Success, makeMapStatus("hostB", 1))))
+      (Success, makeMapStatus("hostA", reduceRdd.partitions.size)),
+      (Success, makeMapStatus("hostB", reduceRdd.partitions.size))))
     // the 2nd ResultTask failed
     complete(taskSets(1), Seq(
       (Success, 42),
@@ -449,7 +449,7 @@ class DAGSchedulerSuite extends FunSuiteLike with BeforeAndAfter with LocalSpar
     // ask the scheduler to try it again
     scheduler.resubmitFailedStages()
     // have the 2nd attempt pass
-    complete(taskSets(2), Seq((Success, makeMapStatus("hostA", 1))))
+    complete(taskSets(2), Seq((Success, makeMapStatus("hostA", reduceRdd.partitions.size))))
     // we can see both result blocks now
     assert(mapOutputTracker.getServerStatuses(shuffleId, 0).map(_._1.host) === Array("hostA", "hostB"))
     complete(taskSets(3), Seq((Success, 43)))
@@ -464,8 +464,8 @@ class DAGSchedulerSuite extends FunSuiteLike with BeforeAndAfter with LocalSpar
     val reduceRdd = new MyRDD(sc, 2, List(shuffleDep))
     submit(reduceRdd, Array(0, 1))
     complete(taskSets(0), Seq(
-      (Success, makeMapStatus("hostA", 1)),
-      (Success, makeMapStatus("hostB", 1))))
+      (Success, makeMapStatus("hostA", reduceRdd.partitions.size)),
+      (Success, makeMapStatus("hostB", reduceRdd.partitions.size))))
     // The MapOutputTracker should know about both map output locations.
     assert(mapOutputTracker.getServerStatuses(shuffleId, 0).map(_._1.host) ===
       Array("hostA", "hostB"))
@@ -507,14 +507,18 @@ class DAGSchedulerSuite extends FunSuiteLike with BeforeAndAfter with LocalSpar
     assert(newEpoch > oldEpoch)
     val taskSet = taskSets(0)
     // should be ignored for being too old
-    runEvent(CompletionEvent(taskSet.tasks(0), Success, makeMapStatus("hostA", 1), null, createFakeTaskInfo(), null))
+    runEvent(CompletionEvent(taskSet.tasks(0), Success, makeMapStatus("hostA",
+      reduceRdd.partitions.size), null, createFakeTaskInfo(), null))
     // should work because it's a non-failed host
-    runEvent(CompletionEvent(taskSet.tasks(0), Success, makeMapStatus("hostB", 1), null, createFakeTaskInfo(), null))
+    runEvent(CompletionEvent(taskSet.tasks(0), Success, makeMapStatus("hostB",
+      reduceRdd.partitions.size), null, createFakeTaskInfo(), null))
     // should be ignored for being too old
-    runEvent(CompletionEvent(taskSet.tasks(0), Success, makeMapStatus("hostA", 1), null, createFakeTaskInfo(), null))
+    runEvent(CompletionEvent(taskSet.tasks(0), Success, makeMapStatus("hostA",
+      reduceRdd.partitions.size), null, createFakeTaskInfo(), null))
     // should work because it's a new epoch
     taskSet.tasks(1).epoch = newEpoch
-    runEvent(CompletionEvent(taskSet.tasks(1), Success, makeMapStatus("hostA", 1), null, createFakeTaskInfo(), null))
+    runEvent(CompletionEvent(taskSet.tasks(1), Success, makeMapStatus("hostA",
+      reduceRdd.partitions.size), null, createFakeTaskInfo(), null))
     assert(mapOutputTracker.getServerStatuses(shuffleId, 0).map(_._1) ===
       Array(makeBlockManagerId("hostB"), makeBlockManagerId("hostA")))
     complete(taskSets(1), Seq((Success, 42), (Success, 43)))
@@ -739,19 +743,63 @@ class DAGSchedulerSuite extends FunSuiteLike with BeforeAndAfter with LocalSpar
     assertDataStructuresEmpty
   }
 
746+ test(" shuffle with reducer locality" ) {
747+ // Create an shuffleMapRdd with 1 partition
748+ val shuffleMapRdd = new MyRDD (sc, 1 , Nil )
749+ val shuffleDep = new ShuffleDependency (shuffleMapRdd, null )
750+ val shuffleId = shuffleDep.shuffleId
751+ val reduceRdd = new MyRDD (sc, 1 , List (shuffleDep))
752+ submit(reduceRdd, Array (0 ))
753+ complete(taskSets(0 ), Seq (
754+ (Success , makeMapStatus(" hostA" , 1 ))))
755+ assert(mapOutputTracker.getServerStatuses(shuffleId, 0 ).map(_._1) ===
756+ Array (makeBlockManagerId(" hostA" )))
757+
758+ // Reducer should run on the same host that map task ran
759+ val reduceTaskSet = taskSets(1 )
760+ assertLocations(reduceTaskSet, Seq (Seq (" hostA" )))
761+ complete(reduceTaskSet, Seq ((Success , 42 )))
762+ assert(results === Map (0 -> 42 ))
763+ assertDataStructuresEmpty
764+ }
765+
766+ test(" reducer locality with different sizes" ) {
767+ val numMapTasks = scheduler.NUM_REDUCER_PREF_LOCS + 1
768+ // Create an shuffleMapRdd with more partitions
769+ val shuffleMapRdd = new MyRDD (sc, numMapTasks, Nil )
770+ val shuffleDep = new ShuffleDependency (shuffleMapRdd, null )
771+ val shuffleId = shuffleDep.shuffleId
772+ val reduceRdd = new MyRDD (sc, 1 , List (shuffleDep))
773+ submit(reduceRdd, Array (0 ))
774+
775+ val statuses = (1 to numMapTasks).map { i =>
776+ (Success , makeMapStatus(" host" + i, 1 , (10 * i).toByte))
777+ }
778+ complete(taskSets(0 ), statuses)
779+
780+ // Reducer should prefer the last hosts where output size is larger
781+ val hosts = (1 to numMapTasks).map(i => " host" + i).reverse.take(numMapTasks - 1 )
782+
783+ val reduceTaskSet = taskSets(1 )
784+ assertLocations(reduceTaskSet, Seq (hosts))
785+ complete(reduceTaskSet, Seq ((Success , 42 )))
786+ assert(results === Map (0 -> 42 ))
787+ assertDataStructuresEmpty
788+ }
789+
   /**
    * Assert that the supplied TaskSet has exactly the given hosts as its preferred locations.
    * Note that this checks only the host and not the executor ID.
    */
   private def assertLocations(taskSet: TaskSet, hosts: Seq[Seq[String]]) {
     assert(hosts.size === taskSet.tasks.size)
     for ((taskLocs, expectedLocs) <- taskSet.tasks.map(_.preferredLocations).zip(hosts)) {
-      assert(taskLocs.map(_.host) === expectedLocs)
+      assert(taskLocs.map(_.host).toSet === expectedLocs.toSet)
     }
   }
 
-  private def makeMapStatus(host: String, reduces: Int): MapStatus =
-    MapStatus(makeBlockManagerId(host), Array.fill[Long](reduces)(2))
+  private def makeMapStatus(host: String, reduces: Int, sizes: Byte = 2): MapStatus =
+    MapStatus(makeBlockManagerId(host), Array.fill[Long](reduces)(sizes))
 
   private def makeBlockManagerId(host: String): BlockManagerId =
     BlockManagerId("exec-" + host, host, 12345)
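
For context, the following is a minimal, self-contained sketch of the size-weighted reducer-locality idea these tests exercise: a reducer prefers the hosts holding the largest shares of its map output, capped at a small number of locations. The names here (ReducerLocalitySketch, NumReducerPrefLocs, preferredReducerLocations) and the simple take-the-largest rule are assumptions for illustration only; the actual DAGScheduler/MapOutputTracker logic behind scheduler.NUM_REDUCER_PREF_LOCS may differ (for example by applying a minimum-fraction threshold).

object ReducerLocalitySketch {
  // Assumed cap on preferred hosts per reducer; scheduler.NUM_REDUCER_PREF_LOCS plays this role in the suite.
  val NumReducerPrefLocs = 5

  // For one reduce partition, given the map-output bytes it would fetch from each host,
  // return up to NumReducerPrefLocs hosts holding the largest outputs.
  def preferredReducerLocations(bytesByHost: Map[String, Long]): Seq[String] =
    bytesByHost.toSeq
      .sortBy { case (_, bytes) => -bytes } // largest outputs first
      .take(NumReducerPrefLocs)
      .map { case (host, _) => host }

  def main(args: Array[String]): Unit = {
    // Mirrors the "reducer locality with different sizes" test above: host i produced 10 * i bytes,
    // so the highest-numbered hosts (largest outputs) should be preferred.
    val numMapTasks = NumReducerPrefLocs + 1
    val bytesByHost = (1 to numMapTasks).map(i => ("host" + i, 10L * i)).toMap
    println(preferredReducerLocations(bytesByHost))
  }
}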