From bca146c531060f32fa351397e13f0f778fe35b25 Mon Sep 17 00:00:00 2001 From: jinxing Date: Tue, 4 Apr 2017 23:25:19 +0800 Subject: [PATCH 1/9] Sort tasks based on their size. --- .../apache/spark/scheduler/DAGScheduler.scala | 28 ++++++++ .../spark/scheduler/TaskSetManager.scala | 64 +++++++++++++++---- 2 files changed, 80 insertions(+), 12 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index 09717316833a7..a6498d725e865 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -471,6 +471,34 @@ class DAGScheduler( missing.toList } + /** + * Get ancestor splits in ShuffledRDD. + */ + private[spark] def parentSplitsInShuffledRDD(stageId: Int, pId: Int): Option[Map[Int, Set[Int]]] = + { + stageIdToStage.get(stageId) match { + case Some(stage) => + val waitingForVisit = new Stack[Tuple2[RDD[_], Int]] + waitingForVisit.push((stage.rdd, pId)) + val ret = new HashMap[Int, HashSet[Int]]() + while(waitingForVisit.nonEmpty) { + val (rdd, split) = waitingForVisit.pop() + rdd.dependencies.foreach { + case dep: ShuffleDependency[_, _, _] => + ret.getOrElseUpdate(dep.shuffleId, new HashSet[Int]()).add(split) + case dep: NarrowDependency[_] => + dep.getParents(split).foreach { + case parentSplit => + waitingForVisit.push((dep.rdd, parentSplit)) + } + } + } + Some(ret.mapValues(_.toSet).toMap) + case None => + None + } + } + /** * Registers the given jobId among the jobs that need the given stage and * all of that stage's ancestors. diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala index a41b059fa7dec..50ed8776960d1 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala @@ -20,6 +20,7 @@ package org.apache.spark.scheduler import java.io.NotSerializableException import java.nio.ByteBuffer import java.util.concurrent.ConcurrentLinkedQueue +import java.util.concurrent.atomic.AtomicBoolean import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet} import scala.math.max @@ -32,6 +33,8 @@ import org.apache.spark.TaskState.TaskState import org.apache.spark.util.{AccumulatorV2, Clock, SystemClock, Utils} import org.apache.spark.util.collection.MedianHeap + +// scalastyle:off /** * Schedules the tasks within a single TaskSet in the TaskSchedulerImpl. This class keeps track of * each task, retries tasks if they fail (up to a limited number of times), and @@ -125,20 +128,20 @@ private[spark] class TaskSetManager( // of failures. // Duplicates are handled in dequeueTaskFromList, which ensures that a // task hasn't already started running before launching it. - private val pendingTasksForExecutor = new HashMap[String, ArrayBuffer[Int]] + private var pendingTasksForExecutor = new HashMap[String, ArrayBuffer[Int]] // Set of pending tasks for each host. Similar to pendingTasksForExecutor, // but at host level. - private val pendingTasksForHost = new HashMap[String, ArrayBuffer[Int]] + private var pendingTasksForHost = new HashMap[String, ArrayBuffer[Int]] // Set of pending tasks for each rack -- similar to the above. 
- private val pendingTasksForRack = new HashMap[String, ArrayBuffer[Int]] + private var pendingTasksForRack = new HashMap[String, ArrayBuffer[Int]] // Set containing pending tasks with no locality preferences. private[scheduler] var pendingTasksWithNoPrefs = new ArrayBuffer[Int] // Set containing all pending tasks (also used as a stack, as above). - private val allPendingTasks = new ArrayBuffer[Int] + private var allPendingTasks = new ArrayBuffer[Int] // Tasks that can be speculated. Since these will be a small fraction of total // tasks, we'll just hold them in a HashSet. @@ -168,12 +171,11 @@ private[spark] class TaskSetManager( t.epoch = epoch } - // Add all our tasks to the pending lists. We do this in reverse order - // of task index so that tasks with low indices get launched first. + val sortedPendingTasks = new AtomicBoolean(false) + for (i <- (0 until numTasks).reverse) { addPendingTask(i) } - /** * Track the set of locality levels which are valid given the tasks locality preferences and * the set of currently available executors. This is updated as executors are added and removed. @@ -438,6 +440,11 @@ private[spark] class TaskSetManager( blacklist.isExecutorBlacklistedForTaskSet(execId) } if (!isZombie && !offerBlacklisted) { + if (!sortedPendingTasks.get()) { + sortedPendingTasks.set(true) + sortPendingTasks() + } + val curTime = clock.getTimeMillis() var allowedLocality = maxLocality @@ -512,6 +519,42 @@ private[spark] class TaskSetManager( } } + private[this] def sortPendingTasks(): Unit = { + val taskIndexs = (0 until numTasks).toArray + implicit def ord = new Ordering[Int] { + override def compare(x: Int, y: Int): Int = + getTaskInputSizeFromShuffledRDD(tasks(x)) compare + getTaskInputSizeFromShuffledRDD(tasks(y)) + } + if (tasks.nonEmpty) { + // Sort the tasks based on their input size from ShuffledRDD. + pendingTasksForExecutor.foreach { + case (k, v) => pendingTasksForExecutor(k) = v.sorted + } + pendingTasksForHost.foreach { + case (k, v) => pendingTasksForHost(k) = v.sorted + } + pendingTasksForRack.foreach { + case (k, v) => pendingTasksForRack(k) = v.sorted + } + pendingTasksWithNoPrefs = pendingTasksWithNoPrefs.sorted + allPendingTasks = allPendingTasks.sorted + } + } + + private[this] def getTaskInputSizeFromShuffledRDD(task: Task[_]): Long = { + sched.dagScheduler.parentSplitsInShuffledRDD(task.stageId, task.partitionId) match { + case Some(parentSplits) => + parentSplits.map { + case (shuffleId, splits) => + splits.map(SparkEnv.get.mapOutputTracker.getMapSizesByExecutorId(shuffleId, _) + .flatMap(_._2.map(_._2)).sum).sum + }.sum + case None => + 0 + } + } + private def maybeFinishTaskSet() { if (isZombie && runningTasks == 0) { sched.taskSetFinished(this) @@ -560,7 +603,6 @@ private[spark] class TaskSetManager( emptyKeys.foreach(id => pendingTasks.remove(id)) hasTasks } - while (currentLocalityIndex < myLocalityLevels.length - 1) { val moreTasks = myLocalityLevels(currentLocalityIndex) match { case TaskLocality.PROCESS_LOCAL => moreTasksToRunIn(pendingTasksForExecutor) @@ -573,15 +615,11 @@ private[spark] class TaskSetManager( // be scheduled at a particular locality level, there is no point in waiting // for the locality wait timeout (SPARK-4939). 
lastLaunchTime = curTime - logDebug(s"No tasks for locality level ${myLocalityLevels(currentLocalityIndex)}, " + - s"so moving to locality level ${myLocalityLevels(currentLocalityIndex + 1)}") currentLocalityIndex += 1 } else if (curTime - lastLaunchTime >= localityWaits(currentLocalityIndex)) { // Jump to the next locality level, and reset lastLaunchTime so that the next locality // wait timer doesn't immediately expire lastLaunchTime += localityWaits(currentLocalityIndex) - logDebug(s"Moving to ${myLocalityLevels(currentLocalityIndex + 1)} after waiting for " + - s"${localityWaits(currentLocalityIndex)}ms") currentLocalityIndex += 1 } else { return myLocalityLevels(currentLocalityIndex) @@ -833,6 +871,7 @@ private[spark] class TaskSetManager( s" has already succeeded).") } else { addPendingTask(index) + sortPendingTasks() } if (!isZombie && reason.countTowardsTaskFailures) { @@ -904,6 +943,7 @@ private[spark] class TaskSetManager( copiesRunning(index) -= 1 tasksSuccessful -= 1 addPendingTask(index) + sortPendingTasks() // Tell the DAGScheduler that this task was resubmitted so that it doesn't think our // stage finishes when a total of tasks.size tasks finish. sched.dagScheduler.taskEnded( From f757e4125935f7237a29bb07313ffb46ccbb3cd0 Mon Sep 17 00:00:00 2001 From: jinxing Date: Wed, 5 Apr 2017 11:06:41 +0800 Subject: [PATCH 2/9] Add unit test. --- .../spark/scheduler/TaskSetManager.scala | 43 ++++++++++++++----- .../spark/scheduler/TaskSetManagerSuite.scala | 14 +++++- 2 files changed, 46 insertions(+), 11 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala index 50ed8776960d1..c484e7fd7bedf 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala @@ -33,8 +33,6 @@ import org.apache.spark.TaskState.TaskState import org.apache.spark.util.{AccumulatorV2, Clock, SystemClock, Utils} import org.apache.spark.util.collection.MedianHeap - -// scalastyle:off /** * Schedules the tasks within a single TaskSet in the TaskSchedulerImpl. This class keeps track of * each task, retries tasks if they fail (up to a limited number of times), and @@ -173,9 +171,14 @@ private[spark] class TaskSetManager( val sortedPendingTasks = new AtomicBoolean(false) + val taskInputSizeFromShuffledRDD = HashMap[Task[_], Long]() + + // Add all our tasks to the pending lists. We do this in reverse order + // of task index so that tasks with low indices get launched first. for (i <- (0 until numTasks).reverse) { addPendingTask(i) } + /** * Track the set of locality levels which are valid given the tasks locality preferences and * the set of currently available executors. This is updated as executors are added and removed. 
@@ -542,16 +545,31 @@ private[spark] class TaskSetManager( } } + // Visible for testing + private[spark] def setTaskInputSizeFromShuffledRDD(inputSize: Map[Task[_], Long]) = { + taskInputSizeFromShuffledRDD.clear() + inputSize.foreach{ + case (k, v) => taskInputSizeFromShuffledRDD(k) = v + } + } + private[this] def getTaskInputSizeFromShuffledRDD(task: Task[_]): Long = { - sched.dagScheduler.parentSplitsInShuffledRDD(task.stageId, task.partitionId) match { - case Some(parentSplits) => - parentSplits.map { - case (shuffleId, splits) => - splits.map(SparkEnv.get.mapOutputTracker.getMapSizesByExecutorId(shuffleId, _) - .flatMap(_._2.map(_._2)).sum).sum - }.sum + taskInputSizeFromShuffledRDD.get(task) match { + case Some(size) => size case None => - 0 + val size = + sched.dagScheduler.parentSplitsInShuffledRDD(task.stageId, task.partitionId) match { + case Some(parentSplits) => + parentSplits.map { + case (shuffleId, splits) => + splits.map(sched.mapOutputTracker.getMapSizesByExecutorId(shuffleId, _) + .flatMap(_._2.map(_._2)).sum).sum + }.sum + case None => + 0L + } + taskInputSizeFromShuffledRDD(task) = size + size } } @@ -603,6 +621,7 @@ private[spark] class TaskSetManager( emptyKeys.foreach(id => pendingTasks.remove(id)) hasTasks } + while (currentLocalityIndex < myLocalityLevels.length - 1) { val moreTasks = myLocalityLevels(currentLocalityIndex) match { case TaskLocality.PROCESS_LOCAL => moreTasksToRunIn(pendingTasksForExecutor) @@ -615,11 +634,15 @@ private[spark] class TaskSetManager( // be scheduled at a particular locality level, there is no point in waiting // for the locality wait timeout (SPARK-4939). lastLaunchTime = curTime + logDebug(s"No tasks for locality level ${myLocalityLevels(currentLocalityIndex)}, " + + s"so moving to locality level ${myLocalityLevels(currentLocalityIndex + 1)}") currentLocalityIndex += 1 } else if (curTime - lastLaunchTime >= localityWaits(currentLocalityIndex)) { // Jump to the next locality level, and reset lastLaunchTime so that the next locality // wait timer doesn't immediately expire lastLaunchTime += localityWaits(currentLocalityIndex) + logDebug(s"Moving to ${myLocalityLevels(currentLocalityIndex + 1)} after waiting for " + + s"${localityWaits(currentLocalityIndex)}ms") currentLocalityIndex += 1 } else { return myLocalityLevels(currentLocalityIndex) diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala index 9ca6b8b0fe635..88d93273f592b 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala @@ -180,7 +180,6 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg } } - test("TaskSet with no preferences") { sc = new SparkContext("local", "test") sched = new FakeTaskScheduler(sc, ("exec1", "host1")) @@ -1139,6 +1138,19 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg .updateBlacklistForFailedTask(anyString(), anyString(), anyInt()) } + test("Schedule tasks based on size of input from ShuffledRDD.") { + sc = new SparkContext("local", "test") + sched = new FakeTaskScheduler(sc) + val taskSet = FakeTask.createTaskSet(4) + val clock = new ManualClock() + val manager = new TaskSetManager(sched, taskSet, 1, clock = clock) + manager.setTaskInputSizeFromShuffledRDD(taskSet.tasks.zip(Seq(1L, 100L, 1000L, 10000L)).toMap) + assert(manager.resourceOffer("exec", "host", ANY).get.index 
=== 3) + assert(manager.resourceOffer("exec", "host", ANY).get.index === 2) + assert(manager.resourceOffer("exec", "host", ANY).get.index === 1) + assert(manager.resourceOffer("exec", "host", ANY).get.index === 0) + } + private def createTaskResult( id: Int, accumUpdates: Seq[AccumulatorV2[_, _]] = Seq.empty): DirectTaskResult[Int] = { From 6d18a0903cd4bfb42d7dc33f46854993625c5d7b Mon Sep 17 00:00:00 2001 From: jinxing Date: Wed, 5 Apr 2017 11:39:22 +0800 Subject: [PATCH 3/9] Refine unit test. --- .../org/apache/spark/scheduler/TaskSetManagerSuite.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala index 88d93273f592b..47f896b502003 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala @@ -1144,9 +1144,9 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg val taskSet = FakeTask.createTaskSet(4) val clock = new ManualClock() val manager = new TaskSetManager(sched, taskSet, 1, clock = clock) - manager.setTaskInputSizeFromShuffledRDD(taskSet.tasks.zip(Seq(1L, 100L, 1000L, 10000L)).toMap) - assert(manager.resourceOffer("exec", "host", ANY).get.index === 3) + manager.setTaskInputSizeFromShuffledRDD(taskSet.tasks.zip(Seq(1L, 100L, 10000L, 1000L)).toMap) assert(manager.resourceOffer("exec", "host", ANY).get.index === 2) + assert(manager.resourceOffer("exec", "host", ANY).get.index === 3) assert(manager.resourceOffer("exec", "host", ANY).get.index === 1) assert(manager.resourceOffer("exec", "host", ANY).get.index === 0) } From e4af778a4ec9f200642ccbf4b2f3fc619a5b1232 Mon Sep 17 00:00:00 2001 From: jinxing Date: Wed, 5 Apr 2017 11:50:56 +0800 Subject: [PATCH 4/9] small fix --- .../scala/org/apache/spark/scheduler/TaskSetManager.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala index c484e7fd7bedf..97bd12a6ceef0 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala @@ -126,11 +126,11 @@ private[spark] class TaskSetManager( // of failures. // Duplicates are handled in dequeueTaskFromList, which ensures that a // task hasn't already started running before launching it. - private var pendingTasksForExecutor = new HashMap[String, ArrayBuffer[Int]] + private val pendingTasksForExecutor = new HashMap[String, ArrayBuffer[Int]] // Set of pending tasks for each host. Similar to pendingTasksForExecutor, // but at host level. - private var pendingTasksForHost = new HashMap[String, ArrayBuffer[Int]] + private val pendingTasksForHost = new HashMap[String, ArrayBuffer[Int]] // Set of pending tasks for each rack -- similar to the above. 
private var pendingTasksForRack = new HashMap[String, ArrayBuffer[Int]] From 462d92ed13e1763f832de41a1ac06a6d837131cd Mon Sep 17 00:00:00 2001 From: jinxing Date: Wed, 5 Apr 2017 18:05:54 +0800 Subject: [PATCH 5/9] small fix --- .../main/scala/org/apache/spark/scheduler/TaskSetManager.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala index 97bd12a6ceef0..de0233f665599 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala @@ -133,7 +133,7 @@ private[spark] class TaskSetManager( private val pendingTasksForHost = new HashMap[String, ArrayBuffer[Int]] // Set of pending tasks for each rack -- similar to the above. - private var pendingTasksForRack = new HashMap[String, ArrayBuffer[Int]] + private val pendingTasksForRack = new HashMap[String, ArrayBuffer[Int]] // Set containing pending tasks with no locality preferences. private[scheduler] var pendingTasksWithNoPrefs = new ArrayBuffer[Int] From 97afe0a8f77ea987adfcbfd95dbd8312670586bd Mon Sep 17 00:00:00 2001 From: jinxing Date: Wed, 5 Apr 2017 20:28:30 +0800 Subject: [PATCH 6/9] Refine according to Owen's comments. --- .../spark/scheduler/TaskSetManager.scala | 25 ++++++++----------- 1 file changed, 10 insertions(+), 15 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala index de0233f665599..400aed3fa103a 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala @@ -169,7 +169,7 @@ private[spark] class TaskSetManager( t.epoch = epoch } - val sortedPendingTasks = new AtomicBoolean(false) + private val sortedPendingTasks = new AtomicBoolean(false) val taskInputSizeFromShuffledRDD = HashMap[Task[_], Long]() @@ -443,8 +443,7 @@ private[spark] class TaskSetManager( blacklist.isExecutorBlacklistedForTaskSet(execId) } if (!isZombie && !offerBlacklisted) { - if (!sortedPendingTasks.get()) { - sortedPendingTasks.set(true) + if (sortedPendingTasks.compareAndSet(false, true)) { sortPendingTasks() } @@ -524,33 +523,29 @@ private[spark] class TaskSetManager( private[this] def sortPendingTasks(): Unit = { val taskIndexs = (0 until numTasks).toArray - implicit def ord = new Ordering[Int] { - override def compare(x: Int, y: Int): Int = - getTaskInputSizeFromShuffledRDD(tasks(x)) compare - getTaskInputSizeFromShuffledRDD(tasks(y)) + def ordFunc(x: Int, y: Int): Boolean = { + getTaskInputSizeFromShuffledRDD(tasks(x)) < getTaskInputSizeFromShuffledRDD(tasks(y)) } if (tasks.nonEmpty) { // Sort the tasks based on their input size from ShuffledRDD. 
pendingTasksForExecutor.foreach { - case (k, v) => pendingTasksForExecutor(k) = v.sorted + case (k, v) => pendingTasksForExecutor(k) = v.sortWith(ordFunc) } pendingTasksForHost.foreach { - case (k, v) => pendingTasksForHost(k) = v.sorted + case (k, v) => pendingTasksForHost(k) = v.sortWith(ordFunc) } pendingTasksForRack.foreach { - case (k, v) => pendingTasksForRack(k) = v.sorted + case (k, v) => pendingTasksForRack(k) = v.sortWith(ordFunc) } - pendingTasksWithNoPrefs = pendingTasksWithNoPrefs.sorted - allPendingTasks = allPendingTasks.sorted + pendingTasksWithNoPrefs = pendingTasksWithNoPrefs.sortWith(ordFunc) + allPendingTasks = allPendingTasks.sortWith(ordFunc) } } // Visible for testing private[spark] def setTaskInputSizeFromShuffledRDD(inputSize: Map[Task[_], Long]) = { taskInputSizeFromShuffledRDD.clear() - inputSize.foreach{ - case (k, v) => taskInputSizeFromShuffledRDD(k) = v - } + taskInputSizeFromShuffledRDD ++= inputSize } private[this] def getTaskInputSizeFromShuffledRDD(task: Task[_]): Long = { From b0c3abcd9861187f738e3f683db9b3eb9c477a54 Mon Sep 17 00:00:00 2001 From: jinxing Date: Wed, 5 Apr 2017 22:26:22 +0800 Subject: [PATCH 7/9] small fix. --- .../scala/org/apache/spark/scheduler/TaskSetManager.scala | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala index 400aed3fa103a..80234f4e02fa2 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala @@ -553,16 +553,14 @@ private[spark] class TaskSetManager( case Some(size) => size case None => val size = - sched.dagScheduler.parentSplitsInShuffledRDD(task.stageId, task.partitionId) match { - case Some(parentSplits) => + sched.dagScheduler.parentSplitsInShuffledRDD(task.stageId, task.partitionId).map { + case parentSplits => parentSplits.map { case (shuffleId, splits) => splits.map(sched.mapOutputTracker.getMapSizesByExecutorId(shuffleId, _) .flatMap(_._2.map(_._2)).sum).sum }.sum - case None => - 0L - } + }.getOrElse(0L) taskInputSizeFromShuffledRDD(task) = size size } From e3a15c3fffd699770738caff2e03f066bf0e149c Mon Sep 17 00:00:00 2001 From: jinxing Date: Sun, 9 Apr 2017 01:25:24 +0800 Subject: [PATCH 8/9] Refine for squito's comments. 
--- .../apache/spark/scheduler/DAGScheduler.scala | 51 ++++++++++++--- .../org/apache/spark/scheduler/TaskSet.scala | 5 +- .../spark/scheduler/TaskSetManager.scala | 63 ++++++------------- .../spark/scheduler/DAGSchedulerSuite.scala | 37 +++++++++++ .../spark/scheduler/TaskSetManagerSuite.scala | 7 ++- 5 files changed, 108 insertions(+), 55 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index a6498d725e865..236581dcc867f 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -483,14 +483,27 @@ class DAGScheduler( val ret = new HashMap[Int, HashSet[Int]]() while(waitingForVisit.nonEmpty) { val (rdd, split) = waitingForVisit.pop() - rdd.dependencies.foreach { - case dep: ShuffleDependency[_, _, _] => - ret.getOrElseUpdate(dep.shuffleId, new HashSet[Int]()).add(split) - case dep: NarrowDependency[_] => - dep.getParents(split).foreach { - case parentSplit => - waitingForVisit.push((dep.rdd, parentSplit)) - } + if (getCacheLocs(rdd)(split) == Nil) { + rdd.dependencies.foreach { + case dep: ShuffleDependency[_, _, _] => + val noPartitionerConflict = rdd.partitioner match { + case Some(partitioner) => + partitioner.isInstanceOf[HashPartitioner] && + dep.partitioner.isInstanceOf[HashPartitioner] && + partitioner.numPartitions == dep.partitioner.numPartitions + case None => true + } + if (noPartitionerConflict) { + ret.getOrElseUpdate(dep.shuffleId, new HashSet[Int]()).add(split) + } + case dep: NarrowDependency[_] => + dep.getParents(split).foreach { + case parentSplit => + if (getCacheLocs(dep.rdd)(parentSplit) == Nil) { + waitingForVisit.push((dep.rdd, parentSplit)) + } + } + } } } Some(ret.mapValues(_.toSet).toMap) @@ -1086,7 +1099,8 @@ class DAGScheduler( logInfo(s"Submitting ${tasks.size} missing tasks from $stage (${stage.rdd}) (first 15 " + s"tasks are for partitions ${tasks.take(15).map(_.partitionId)})") taskScheduler.submitTasks(new TaskSet( - tasks.toArray, stage.id, stage.latestInfo.attemptId, jobId, properties)) + tasks.toArray, stage.id, stage.latestInfo.attemptId, jobId, properties, + Some(getTaskInputSizesFromShuffledRDD(tasks)))) stage.latestInfo.submissionTime = Some(clock.getTimeMillis()) } else { // Because we posted SparkListenerStageSubmitted earlier, we should mark @@ -1108,6 +1122,25 @@ class DAGScheduler( } } + // Visible for testing. + private[spark] def getTaskInputSizesFromShuffledRDD(tasks: Seq[Task[_]]): Map[Task[_], Long] = { + val taskInputSizeFromShuffledRDD = HashMap[Task[_], Long]() + tasks.foreach { + case task => + val size = + parentSplitsInShuffledRDD(task.stageId, task.partitionId).map { + case parentSplits => + parentSplits.map { + case (shuffleId, splits) => + splits.map(mapOutputTracker.getMapSizesByExecutorId(shuffleId, _) + .flatMap(_._2.map(_._2)).sum).sum + }.sum + }.getOrElse(0L) + taskInputSizeFromShuffledRDD(task) = size + } + taskInputSizeFromShuffledRDD.toMap + } + /** * Merge local values from a task into the corresponding accumulators previously registered * here on the driver. 
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSet.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSet.scala index 517c8991aed78..945f6a176876a 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSet.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSet.scala @@ -19,6 +19,8 @@ package org.apache.spark.scheduler import java.util.Properties +import scala.collection.Map + /** * A set of tasks submitted together to the low-level TaskScheduler, usually representing * missing partitions of a particular stage. @@ -28,7 +30,8 @@ private[spark] class TaskSet( val stageId: Int, val stageAttemptId: Int, val priority: Int, - val properties: Properties) { + val properties: Properties, + val taskInputSizesFromShuffledRDDOpt: Option[Map[Task[_], Long]] = None) { val id: String = stageId + "." + stageAttemptId override def toString: String = "TaskSet " + id diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala index 80234f4e02fa2..a6fdedd7655c1 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala @@ -171,8 +171,6 @@ private[spark] class TaskSetManager( private val sortedPendingTasks = new AtomicBoolean(false) - val taskInputSizeFromShuffledRDD = HashMap[Task[_], Long]() - // Add all our tasks to the pending lists. We do this in reverse order // of task index so that tasks with low indices get launched first. for (i <- (0 until numTasks).reverse) { @@ -522,47 +520,26 @@ private[spark] class TaskSetManager( } private[this] def sortPendingTasks(): Unit = { - val taskIndexs = (0 until numTasks).toArray - def ordFunc(x: Int, y: Int): Boolean = { - getTaskInputSizeFromShuffledRDD(tasks(x)) < getTaskInputSizeFromShuffledRDD(tasks(y)) - } - if (tasks.nonEmpty) { - // Sort the tasks based on their input size from ShuffledRDD. - pendingTasksForExecutor.foreach { - case (k, v) => pendingTasksForExecutor(k) = v.sortWith(ordFunc) - } - pendingTasksForHost.foreach { - case (k, v) => pendingTasksForHost(k) = v.sortWith(ordFunc) - } - pendingTasksForRack.foreach { - case (k, v) => pendingTasksForRack(k) = v.sortWith(ordFunc) - } - pendingTasksWithNoPrefs = pendingTasksWithNoPrefs.sortWith(ordFunc) - allPendingTasks = allPendingTasks.sortWith(ordFunc) - } - } - - // Visible for testing - private[spark] def setTaskInputSizeFromShuffledRDD(inputSize: Map[Task[_], Long]) = { - taskInputSizeFromShuffledRDD.clear() - taskInputSizeFromShuffledRDD ++= inputSize - } - - private[this] def getTaskInputSizeFromShuffledRDD(task: Task[_]): Long = { - taskInputSizeFromShuffledRDD.get(task) match { - case Some(size) => size - case None => - val size = - sched.dagScheduler.parentSplitsInShuffledRDD(task.stageId, task.partitionId).map { - case parentSplits => - parentSplits.map { - case (shuffleId, splits) => - splits.map(sched.mapOutputTracker.getMapSizesByExecutorId(shuffleId, _) - .flatMap(_._2.map(_._2)).sum).sum - }.sum - }.getOrElse(0L) - taskInputSizeFromShuffledRDD(task) = size - size + taskSet.taskInputSizesFromShuffledRDDOpt match { + case Some(taskInputSizeFromShuffledRDD) => + def ordFunc(x: Int, y: Int): Boolean = { + taskInputSizeFromShuffledRDD(tasks(x)) < taskInputSizeFromShuffledRDD(tasks(y)) + } + if (tasks.nonEmpty) { + // Sort the tasks based on their input size from ShuffledRDD. 
+ pendingTasksForExecutor.foreach { + case (k, v) => pendingTasksForExecutor(k) = v.sortWith(ordFunc) + } + pendingTasksForHost.foreach { + case (k, v) => pendingTasksForHost(k) = v.sortWith(ordFunc) + } + pendingTasksForRack.foreach { + case (k, v) => pendingTasksForRack(k) = v.sortWith(ordFunc) + } + pendingTasksWithNoPrefs = pendingTasksWithNoPrefs.sortWith(ordFunc) + allPendingTasks = allPendingTasks.sortWith(ordFunc) + } + case None => // no-op } } diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala index a10941b579fe2..70c466ff4d359 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala @@ -2277,6 +2277,43 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with Timeou (Success, 1))) } + test("Tasks input size from shuffled RDD should be correct.") { + val rddA = new MyRDD(sc, 2, Nil) + val shuffleDepA = new ShuffleDependency(rddA, new HashPartitioner(2)) + val shuffleIdA = shuffleDepA.shuffleId + + val rddB = new MyRDD(sc, 2, Nil) + val shuffleDepB = new ShuffleDependency(rddB, new HashPartitioner(2)) + val shuffleIdB = shuffleDepB.shuffleId + + val rddC = new MyRDD(sc, 2, List(shuffleDepA, shuffleDepB), tracker = mapOutputTracker) + submit(rddC, Array(0, 1)) + + def compressAndDecompress(sizes: Array[Long]): Array[Long] = { + sizes.map(size => MapStatus.decompressSize(MapStatus.compressSize(size))) + } + assert(taskSets(0).stageId === 0 && taskSets(0).stageAttemptId === 0) + complete(taskSets(0), Seq( + (Success, MapStatus(makeBlockManagerId("hostA"), compressAndDecompress(Array(10, 1000)))), + (Success, MapStatus(makeBlockManagerId("hostA"), compressAndDecompress(Array(100, 10000)))))) + assert(taskSets(1).stageId === 1 && taskSets(1).stageAttemptId === 0) + complete(taskSets(1), Seq( + (Success, MapStatus(makeBlockManagerId("hostB"), compressAndDecompress(Array(20, 2000)))), + (Success, MapStatus(makeBlockManagerId("hostB"), compressAndDecompress(Array(200, 20000)))))) + + assert(taskSets(2).stageId === 2 && taskSets(2).stageAttemptId === 0) + assert(taskSets(2).taskInputSizesFromShuffledRDDOpt != None) + taskSets(2).taskInputSizesFromShuffledRDDOpt match { + case Some(inputSize) => + assert(inputSize(taskSets(2).tasks(0)) === + compressAndDecompress(Array(10, 100, 20, 200)).sum) + assert(inputSize(taskSets(2).tasks(1)) === + compressAndDecompress(Array(1000, 10000, 2000, 20000)).sum) + case None => + throw new DAGSchedulerSuiteDummyException + } + } + /** * Assert that the supplied TaskSet has exactly the given hosts as its preferred locations. * Note that this checks only the host and not the executor ID. 
diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala index 47f896b502003..05a4e86220dfb 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala @@ -1141,10 +1141,13 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg test("Schedule tasks based on size of input from ShuffledRDD.") { sc = new SparkContext("local", "test") sched = new FakeTaskScheduler(sc) - val taskSet = FakeTask.createTaskSet(4) + val tasks = Array.tabulate[Task[_]](4) { i => + new FakeTask(stageId = 0, partitionId = i, prefLocs = Nil) + } + val taskSet = new TaskSet(tasks, stageId = 0, stageAttemptId = 0, priority = 0, null, + Some(tasks.zip(Seq(1L, 100L, 10000L, 1000L)).toMap)) val clock = new ManualClock() val manager = new TaskSetManager(sched, taskSet, 1, clock = clock) - manager.setTaskInputSizeFromShuffledRDD(taskSet.tasks.zip(Seq(1L, 100L, 10000L, 1000L)).toMap) assert(manager.resourceOffer("exec", "host", ANY).get.index === 2) assert(manager.resourceOffer("exec", "host", ANY).get.index === 3) assert(manager.resourceOffer("exec", "host", ANY).get.index === 1) From 7212089c6d90b30f22d6f94538558dcf3e892a44 Mon Sep 17 00:00:00 2001 From: jinxing Date: Fri, 14 Apr 2017 15:29:16 +0800 Subject: [PATCH 9/9] Simplify the code. --- .../apache/spark/scheduler/DAGScheduler.scala | 84 ++++++------------- .../org/apache/spark/scheduler/TaskSet.scala | 5 +- .../spark/scheduler/TaskSetManager.scala | 35 +------- .../spark/scheduler/DAGSchedulerSuite.scala | 26 +++--- .../spark/scheduler/TaskSetManagerSuite.scala | 16 ---- 5 files changed, 40 insertions(+), 126 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index 236581dcc867f..05a848d7ed96e 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -471,47 +471,6 @@ class DAGScheduler( missing.toList } - /** - * Get ancestor splits in ShuffledRDD. - */ - private[spark] def parentSplitsInShuffledRDD(stageId: Int, pId: Int): Option[Map[Int, Set[Int]]] = - { - stageIdToStage.get(stageId) match { - case Some(stage) => - val waitingForVisit = new Stack[Tuple2[RDD[_], Int]] - waitingForVisit.push((stage.rdd, pId)) - val ret = new HashMap[Int, HashSet[Int]]() - while(waitingForVisit.nonEmpty) { - val (rdd, split) = waitingForVisit.pop() - if (getCacheLocs(rdd)(split) == Nil) { - rdd.dependencies.foreach { - case dep: ShuffleDependency[_, _, _] => - val noPartitionerConflict = rdd.partitioner match { - case Some(partitioner) => - partitioner.isInstanceOf[HashPartitioner] && - dep.partitioner.isInstanceOf[HashPartitioner] && - partitioner.numPartitions == dep.partitioner.numPartitions - case None => true - } - if (noPartitionerConflict) { - ret.getOrElseUpdate(dep.shuffleId, new HashSet[Int]()).add(split) - } - case dep: NarrowDependency[_] => - dep.getParents(split).foreach { - case parentSplit => - if (getCacheLocs(dep.rdd)(parentSplit) == Nil) { - waitingForVisit.push((dep.rdd, parentSplit)) - } - } - } - } - } - Some(ret.mapValues(_.toSet).toMap) - case None => - None - } - } - /** * Registers the given jobId among the jobs that need the given stage and * all of that stage's ancestors. 
@@ -1098,9 +1057,12 @@ class DAGScheduler( if (tasks.size > 0) { logInfo(s"Submitting ${tasks.size} missing tasks from $stage (${stage.rdd}) (first 15 " + s"tasks are for partitions ${tasks.take(15).map(_.partitionId)})") + def ordFunc(x: Task[_], y: Task[_]): Boolean = { + inputSizeFromShuffledRDD(stageIdToStage(x.stageId).rdd, x.partitionId) > + inputSizeFromShuffledRDD(stageIdToStage(y.stageId).rdd, y.partitionId) + } taskScheduler.submitTasks(new TaskSet( - tasks.toArray, stage.id, stage.latestInfo.attemptId, jobId, properties, - Some(getTaskInputSizesFromShuffledRDD(tasks)))) + tasks.sortWith(ordFunc).toArray, stage.id, stage.latestInfo.attemptId, jobId, properties)) stage.latestInfo.submissionTime = Some(clock.getTimeMillis()) } else { // Because we posted SparkListenerStageSubmitted earlier, we should mark @@ -1122,23 +1084,27 @@ class DAGScheduler( } } - // Visible for testing. - private[spark] def getTaskInputSizesFromShuffledRDD(tasks: Seq[Task[_]]): Map[Task[_], Long] = { - val taskInputSizeFromShuffledRDD = HashMap[Task[_], Long]() - tasks.foreach { - case task => - val size = - parentSplitsInShuffledRDD(task.stageId, task.partitionId).map { - case parentSplits => - parentSplits.map { - case (shuffleId, splits) => - splits.map(mapOutputTracker.getMapSizesByExecutorId(shuffleId, _) - .flatMap(_._2.map(_._2)).sum).sum - }.sum - }.getOrElse(0L) - taskInputSizeFromShuffledRDD(task) = size + private[scheduler] def inputSizeFromShuffledRDD(rdd: RDD[_], pId: Int): Long = + { + var ret = 0L + val waitingForVisit = new Stack[Tuple2[RDD[_], Int]] + if (getCacheLocs(rdd)(pId) == Nil) { + waitingForVisit.push((rdd, pId)) + } + while(waitingForVisit.nonEmpty) { + val (rdd, split) = waitingForVisit.pop() + rdd.dependencies.foreach { + case dep: ShuffleDependency[_, _, _] => + if (rdd.partitioner.isEmpty || rdd.partitioner == Some(dep.partitioner)) { + ret += mapOutputTracker.getStatistics(dep).bytesByPartitionId(split) + } + case dep: NarrowDependency[_] => + dep.getParents(split).foreach { + case parentSplit => ret += inputSizeFromShuffledRDD(dep.rdd, parentSplit) + } + } } - taskInputSizeFromShuffledRDD.toMap + ret } /** diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSet.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSet.scala index 945f6a176876a..517c8991aed78 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSet.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSet.scala @@ -19,8 +19,6 @@ package org.apache.spark.scheduler import java.util.Properties -import scala.collection.Map - /** * A set of tasks submitted together to the low-level TaskScheduler, usually representing * missing partitions of a particular stage. @@ -30,8 +28,7 @@ private[spark] class TaskSet( val stageId: Int, val stageAttemptId: Int, val priority: Int, - val properties: Properties, - val taskInputSizesFromShuffledRDDOpt: Option[Map[Task[_], Long]] = None) { + val properties: Properties) { val id: String = stageId + "." 
+ stageAttemptId override def toString: String = "TaskSet " + id diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala index a6fdedd7655c1..a41b059fa7dec 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala @@ -20,7 +20,6 @@ package org.apache.spark.scheduler import java.io.NotSerializableException import java.nio.ByteBuffer import java.util.concurrent.ConcurrentLinkedQueue -import java.util.concurrent.atomic.AtomicBoolean import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet} import scala.math.max @@ -139,7 +138,7 @@ private[spark] class TaskSetManager( private[scheduler] var pendingTasksWithNoPrefs = new ArrayBuffer[Int] // Set containing all pending tasks (also used as a stack, as above). - private var allPendingTasks = new ArrayBuffer[Int] + private val allPendingTasks = new ArrayBuffer[Int] // Tasks that can be speculated. Since these will be a small fraction of total // tasks, we'll just hold them in a HashSet. @@ -169,8 +168,6 @@ private[spark] class TaskSetManager( t.epoch = epoch } - private val sortedPendingTasks = new AtomicBoolean(false) - // Add all our tasks to the pending lists. We do this in reverse order // of task index so that tasks with low indices get launched first. for (i <- (0 until numTasks).reverse) { @@ -441,10 +438,6 @@ private[spark] class TaskSetManager( blacklist.isExecutorBlacklistedForTaskSet(execId) } if (!isZombie && !offerBlacklisted) { - if (sortedPendingTasks.compareAndSet(false, true)) { - sortPendingTasks() - } - val curTime = clock.getTimeMillis() var allowedLocality = maxLocality @@ -519,30 +512,6 @@ private[spark] class TaskSetManager( } } - private[this] def sortPendingTasks(): Unit = { - taskSet.taskInputSizesFromShuffledRDDOpt match { - case Some(taskInputSizeFromShuffledRDD) => - def ordFunc(x: Int, y: Int): Boolean = { - taskInputSizeFromShuffledRDD(tasks(x)) < taskInputSizeFromShuffledRDD(tasks(y)) - } - if (tasks.nonEmpty) { - // Sort the tasks based on their input size from ShuffledRDD. - pendingTasksForExecutor.foreach { - case (k, v) => pendingTasksForExecutor(k) = v.sortWith(ordFunc) - } - pendingTasksForHost.foreach { - case (k, v) => pendingTasksForHost(k) = v.sortWith(ordFunc) - } - pendingTasksForRack.foreach { - case (k, v) => pendingTasksForRack(k) = v.sortWith(ordFunc) - } - pendingTasksWithNoPrefs = pendingTasksWithNoPrefs.sortWith(ordFunc) - allPendingTasks = allPendingTasks.sortWith(ordFunc) - } - case None => // no-op - } - } - private def maybeFinishTaskSet() { if (isZombie && runningTasks == 0) { sched.taskSetFinished(this) @@ -864,7 +833,6 @@ private[spark] class TaskSetManager( s" has already succeeded).") } else { addPendingTask(index) - sortPendingTasks() } if (!isZombie && reason.countTowardsTaskFailures) { @@ -936,7 +904,6 @@ private[spark] class TaskSetManager( copiesRunning(index) -= 1 tasksSuccessful -= 1 addPendingTask(index) - sortPendingTasks() // Tell the DAGScheduler that this task was resubmitted so that it doesn't think our // stage finishes when a total of tasks.size tasks finish. 
       sched.dagScheduler.taskEnded(
diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
index 70c466ff4d359..1f3ad99b2f803 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
@@ -2277,13 +2277,14 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with Timeou
       (Success, 1)))
   }
 
-  test("Tasks input size from shuffled RDD should be correct.") {
+  test("Tasks should be in descending order by input size from ShuffledRDD.") {
+    val partitioner = new HashPartitioner(2)
     val rddA = new MyRDD(sc, 2, Nil)
-    val shuffleDepA = new ShuffleDependency(rddA, new HashPartitioner(2))
+    val shuffleDepA = new ShuffleDependency(rddA, partitioner)
     val shuffleIdA = shuffleDepA.shuffleId
 
     val rddB = new MyRDD(sc, 2, Nil)
-    val shuffleDepB = new ShuffleDependency(rddB, new HashPartitioner(2))
+    val shuffleDepB = new ShuffleDependency(rddB, partitioner)
     val shuffleIdB = shuffleDepB.shuffleId
 
     val rddC = new MyRDD(sc, 2, List(shuffleDepA, shuffleDepB), tracker = mapOutputTracker)
     submit(rddC, Array(0, 1))
 
     def compressAndDecompress(sizes: Array[Long]): Array[Long] = {
       sizes.map(size => MapStatus.decompressSize(MapStatus.compressSize(size)))
     }
     assert(taskSets(0).stageId === 0 && taskSets(0).stageAttemptId === 0)
     complete(taskSets(0), Seq(
       (Success, MapStatus(makeBlockManagerId("hostA"), compressAndDecompress(Array(10, 1000)))),
       (Success, MapStatus(makeBlockManagerId("hostA"), compressAndDecompress(Array(100, 10000))))))
     assert(taskSets(1).stageId === 1 && taskSets(1).stageAttemptId === 0)
     complete(taskSets(1), Seq(
       (Success, MapStatus(makeBlockManagerId("hostB"), compressAndDecompress(Array(20, 2000)))),
       (Success, MapStatus(makeBlockManagerId("hostB"), compressAndDecompress(Array(200, 20000))))))
 
     assert(taskSets(2).stageId === 2 && taskSets(2).stageAttemptId === 0)
-    assert(taskSets(2).taskInputSizesFromShuffledRDDOpt != None)
-    taskSets(2).taskInputSizesFromShuffledRDDOpt match {
-      case Some(inputSize) =>
-        assert(inputSize(taskSets(2).tasks(0)) ===
-          compressAndDecompress(Array(10, 100, 20, 200)).sum)
-        assert(inputSize(taskSets(2).tasks(1)) ===
-          compressAndDecompress(Array(1000, 10000, 2000, 20000)).sum)
-      case None =>
-        throw new DAGSchedulerSuiteDummyException
-    }
+
+    // Tasks input size from shuffled RDD should be correct
+    assert(scheduler.inputSizeFromShuffledRDD(rddC, 0) ===
+      compressAndDecompress(Array(10, 100, 20, 200)).sum)
+    assert(scheduler.inputSizeFromShuffledRDD(rddC, 1) ===
+      compressAndDecompress(Array(1000, 10000, 2000, 20000)).sum)
+
+    // Tasks should be in descending order by input size from ShuffledRDD.
+    assert(taskSets(2).tasks(0).partitionId === 1 && taskSets(2).tasks(1).partitionId === 0)
   }
 
   /**
diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala
index 05a4e86220dfb..081d8ddeb9d19 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala
@@ -1138,22 +1138,6 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg
       .updateBlacklistForFailedTask(anyString(), anyString(), anyInt())
   }
 
-  test("Schedule tasks based on size of input from ShuffledRDD.") {
-    sc = new SparkContext("local", "test")
-    sched = new FakeTaskScheduler(sc)
-    val tasks = Array.tabulate[Task[_]](4) { i =>
-      new FakeTask(stageId = 0, partitionId = i, prefLocs = Nil)
-    }
-    val taskSet = new TaskSet(tasks, stageId = 0, stageAttemptId = 0, priority = 0, null,
-      Some(tasks.zip(Seq(1L, 100L, 10000L, 1000L)).toMap))
-    val clock = new ManualClock()
-    val manager = new TaskSetManager(sched, taskSet, 1, clock = clock)
-    assert(manager.resourceOffer("exec", "host", ANY).get.index === 2)
-    assert(manager.resourceOffer("exec", "host", ANY).get.index === 3)
-    assert(manager.resourceOffer("exec", "host", ANY).get.index === 1)
-    assert(manager.resourceOffer("exec", "host", ANY).get.index === 0)
-  }
-
   private def createTaskResult(
       id: Int,
       accumUpdates: Seq[AccumulatorV2[_, _]] = Seq.empty): DirectTaskResult[Int] = {
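
A minimal, standalone sketch of the ordering the series arrives at in PATCH 9/9: DAGScheduler estimates each task's shuffle input with inputSizeFromShuffledRDD and submits the TaskSet with the largest tasks first, and TaskSetManager's existing reverse-index insertion into the pending lists then launches those tasks earliest. The sketch is not Spark code and is not part of the patch; ShuffleStats, inputSize and partitionsByDescendingInput are invented names, and it deliberately skips the cached-partition and NarrowDependency handling that the real inputSizeFromShuffledRDD performs.

object SortTasksBySizeSketch {

  // Hypothetical stand-in for MapOutputTracker.getStatistics(dep): bytesByPartitionId(i)
  // is the total number of bytes the map side wrote for reduce partition i.
  final case class ShuffleStats(bytesByPartitionId: Array[Long])

  // Total shuffle input for one reduce partition, summed over all parent shuffle dependencies.
  def inputSize(parentShuffles: Seq[ShuffleStats], partition: Int): Long =
    parentShuffles.map(_.bytesByPartitionId(partition)).sum

  // Partition ids ordered largest-input-first, i.e. the order that tasks.sortWith(ordFunc)
  // produces before the TaskSet is submitted.
  def partitionsByDescendingInput(parentShuffles: Seq[ShuffleStats], numPartitions: Int): Seq[Int] =
    (0 until numPartitions).sortBy(p => -inputSize(parentShuffles, p))

  def main(args: Array[String]): Unit = {
    // Same shape as the DAGSchedulerSuite test above: two parent shuffles, two reduce partitions.
    val statsA = ShuffleStats(Array(10L + 100L, 1000L + 10000L))
    val statsB = ShuffleStats(Array(20L + 200L, 2000L + 20000L))
    // Partition 1 carries far more shuffle input than partition 0, so it is scheduled first.
    println(partitionsByDescendingInput(Seq(statsA, statsB), numPartitions = 2)) // Vector(1, 0)
  }
}

Sorting once at TaskSet construction, as the final patch does, keeps the TaskSetManager pending queues untouched and avoids the extra sortPendingTasks calls on task failure and resubmission that the earlier patches in the series needed.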