diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala index 2eedd201ca355..c02b3395266a3 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala @@ -176,6 +176,9 @@ private[spark] class TaskSetManager( var currentLocalityIndex = 0 // Index of our current locality level in validLocalityLevels var lastLaunchTime = clock.getTimeMillis() // Time we last launched a task at this level + // Time we submitted this taskSet + val taskSetSubmittedTime = clock.getTimeMillis() + override def schedulableQueue: ConcurrentLinkedQueue[Schedulable] = null override def schedulingMode: SchedulingMode = SchedulingMode.NONE @@ -555,6 +558,12 @@ private[spark] class TaskSetManager( logDebug(s"Moving to ${myLocalityLevels(currentLocalityIndex + 1)} after waiting for " + s"${localityWaits(currentLocalityIndex)}ms") currentLocalityIndex += 1 + } else if (curTime - taskSetSubmittedTime >= + getTaskSetLocalityWait(myLocalityLevels(currentLocalityIndex))) { + // Jump to the next locality level + logDebug(s"Moving to ${myLocalityLevels(currentLocalityIndex + 1)} after waiting for " + + s"${getTaskSetLocalityWait(myLocalityLevels(currentLocalityIndex))}ms") + currentLocalityIndex += 1 } else { return myLocalityLevels(currentLocalityIndex) } @@ -878,6 +887,22 @@ private[spark] class TaskSetManager( } } + private def getTaskSetLocalityWait(level: TaskLocality.TaskLocality): Long = { + val defaultWait = conf.get("spark.taskset.locality.wait", "5s") + val localityWaitKey = level match { + case TaskLocality.PROCESS_LOCAL => "spark.taskset.locality.wait.process" + case TaskLocality.NODE_LOCAL => "spark.taskset.locality.wait.node" + case TaskLocality.RACK_LOCAL => "spark.taskset.locality.wait.rack" + case _ => null + } + + if (localityWaitKey != null) { + conf.getTimeAsMs(localityWaitKey, defaultWait) + } else { + 0L + } + } + /** * Compute the locality levels used in this TaskSet. Assumes that all tasks have already been * added to queues using addPendingTask. diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala index 1d7c8f4a61857..7280a8f2034a7 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala @@ -251,16 +251,54 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg // get chosen before the noPref task assert(manager.resourceOffer("exec1", "host1", NODE_LOCAL).get.index == 2) - // Offer host2, exec3 again, at NODE_LOCAL level: we should choose task 2 + // Offer host2, exec2 again, at NODE_LOCAL level: we should choose task 1 assert(manager.resourceOffer("exec2", "host2", NODE_LOCAL).get.index == 1) - // Offer host2, exec3 again, at NODE_LOCAL level: we should get noPref task + // Offer host2, exec2 again, at NODE_LOCAL level: we should get noPref task // after failing to find a node_Local task assert(manager.resourceOffer("exec2", "host2", NODE_LOCAL) == None) clock.advance(LOCALITY_WAIT_MS) assert(manager.resourceOffer("exec2", "host2", NO_PREF).get.index == 3) } + test("basic delay scheduling for taskset") { + sc = new SparkContext("local", "test") + val sched = new FakeTaskScheduler(sc, ("exec1", "host1"), ("exec2", "host2")) + val taskSet = FakeTask.createTaskSet(8, + Seq(TaskLocation("host1", "exec1")), + Seq(TaskLocation("host1", "exec1")), + Seq(TaskLocation("host1", "exec1")), + Seq(TaskLocation("host2")), + Seq(TaskLocation("host1", "exec1")), + Seq(TaskLocation("host1", "exec1")), + Seq(TaskLocation("host1")), + Seq() // Last task has no locality prefs + ) + val clock = new ManualClock + val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES, clock) + + val TASKSET_LOCALITY_WAIT_MS = conf.getTimeAsMs("spark.taskset.locality.wait", "5s") + // We just want to get a time less than "spark.locality.wait" + // The used key is not existing and doesn't matter + val LESS_LOCALITY_WAIT_MS = conf.getTimeAsMs("fake.spark.taskset.locality.wait", "1s") + + // First offer host1, exec1: first three tasks should be chosen + assert(manager.resourceOffer("exec1", "host1", NODE_LOCAL).get.index === 0) + clock.advance(LESS_LOCALITY_WAIT_MS) + assert(manager.resourceOffer("exec1", "host1", NODE_LOCAL).get.index === 1) + clock.advance(LESS_LOCALITY_WAIT_MS) + assert(manager.resourceOffer("exec1", "host1", NODE_LOCAL).get.index === 2) + clock.advance(LESS_LOCALITY_WAIT_MS) + assert(manager.resourceOffer("exec1", "host1", NODE_LOCAL).get.index === 4) + clock.advance(LESS_LOCALITY_WAIT_MS) + + // Only passed 1s from last launched time, so current locality level is still PROCESS_LOCAL + assert(manager.resourceOffer("exec2", "host2", NODE_LOCAL) == None) + clock.advance(LESS_LOCALITY_WAIT_MS) + // Passed 5s from the time taskset submitted, so current locality level is NODE_LOCAL + assert(manager.resourceOffer("exec2", "host2", NODE_LOCAL).get.index === 3) + } + test("we do not need to delay scheduling when we only have noPref tasks in the queue") { sc = new SparkContext("local", "test") val sched = new FakeTaskScheduler(sc, ("exec1", "host1"), ("exec3", "host2"))