
Commit f9a2ad8

simplify the logic in TaskSchedulerImpl
1 parent c8c1de4 commit f9a2ad8

4 files changed: +61 −59 lines

core/src/main/scala/org/apache/spark/scheduler/TaskLocality.scala

Lines changed: 1 addition & 1 deletion
@@ -22,7 +22,7 @@ import org.apache.spark.annotation.DeveloperApi
 @DeveloperApi
 object TaskLocality extends Enumeration {
   // Process local is expected to be used ONLY within TaskSetManager for now.
-  val PROCESS_LOCAL, NODE_LOCAL, NOPREF, RACK_LOCAL, ANY = Value
+  val PROCESS_LOCAL, NODE_LOCAL, NO_PREF, RACK_LOCAL, ANY = Value

   type TaskLocality = Value
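
A quick, standalone sketch of the renamed enumeration in use (illustration only, not the Spark source; the isAllowed helper shown here is a guess at the ordinal comparison the scheduler relies on, included only to make the example self-contained):

object TaskLocalitySketch extends Enumeration {
  // Ordered from most to least local; NO_PREF now sits between NODE_LOCAL and RACK_LOCAL.
  val PROCESS_LOCAL, NODE_LOCAL, NO_PREF, RACK_LOCAL, ANY = Value

  // Assumption: a task's locality "condition" is allowed if it is at least as local as the constraint.
  def isAllowed(constraint: Value, condition: Value): Boolean = condition <= constraint
}

object TaskLocalitySketchDemo extends App {
  import TaskLocalitySketch._
  assert(isAllowed(ANY, PROCESS_LOCAL))      // anything may run under an ANY constraint
  assert(isAllowed(NO_PREF, NODE_LOCAL))     // NODE_LOCAL is at least as local as NO_PREF
  assert(!isAllowed(NODE_LOCAL, RACK_LOCAL)) // RACK_LOCAL is not allowed under a NODE_LOCAL constraint
}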

core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala

Lines changed: 19 additions & 26 deletions
@@ -249,36 +249,29 @@ private[spark] class TaskSchedulerImpl(

     // Take each TaskSet in our scheduling order, and then offer it each node in increasing order
     // of locality levels so that it gets a chance to launch local tasks on all of them.
-    // NOTE: the preferredLocality order: PROCESS_LOCAL, NODE_LOCAL, NOPREF, RACK_local, ANY
+    // NOTE: the preferredLocality order: PROCESS_LOCAL, NODE_LOCAL, NOPREF, RACK_LOCAL, ANY
     var launchedTask = false
     for (taskSet <- sortedTaskSets; preferredLocality <- taskSet.myLocalityLevels) {
-      def launchTaskOnLocalityLevel(locality: TaskLocality.Value) {
-        do {
-          launchedTask = false
-          for (i <- 0 until shuffledOffers.size) {
-            val execId = shuffledOffers(i).executorId
-            val host = shuffledOffers(i).host
-            if (availableCpus(i) >= CPUS_PER_TASK) {
-              for (task <- taskSet.resourceOffer(execId, host, locality)) {
-                tasks(i) += task
-                val tid = task.taskId
-                taskIdToTaskSetId(tid) = taskSet.taskSet.id
-                taskIdToExecutorId(tid) = execId
-                activeExecutorIds += execId
-                executorsByHost(host) += execId
-                availableCpus(i) -= CPUS_PER_TASK
-                assert(availableCpus(i) >= 0)
-                launchedTask = true
-              }
+      do {
+        launchedTask = false
+        for (i <- 0 until shuffledOffers.size) {
+          val execId = shuffledOffers(i).executorId
+          val host = shuffledOffers(i).host
+          if (availableCpus(i) >= CPUS_PER_TASK) {
+            for (task <- taskSet.resourceOffer(execId, host, preferredLocality)) {
+              tasks(i) += task
+              val tid = task.taskId
+              taskIdToTaskSetId(tid) = taskSet.taskSet.id
+              taskIdToExecutorId(tid) = execId
+              activeExecutorIds += execId
+              executorsByHost(host) += execId
+              availableCpus(i) -= CPUS_PER_TASK
+              assert(availableCpus(i) >= 0)
+              launchedTask = true
             }
           }
-        } while (launchedTask)
-      }
-      launchTaskOnLocalityLevel(preferredLocality)
-      // search noPref task after we have launched all node_local and nearer tasks
-      if (preferredLocality == TaskLocality.NODE_LOCAL) {
-        launchTaskOnLocalityLevel(TaskLocality.NOPREF)
-      }
+        }
+      } while (launchedTask)
     }

     if (tasks.size > 0) {
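
To make the simplified control flow easier to read outside of the diff, here is a minimal, self-contained model of the loop above (Offer, TaskSetLike and makeOffers are hypothetical stand-ins, not the real Spark classes): every (task set, locality level) pair is drained with the same do/while loop until no more tasks launch, and the old NOPREF special case after NODE_LOCAL disappears because NO_PREF is now one of the task set's own locality levels.

import scala.collection.mutable.ArrayBuffer

object SchedulerLoopSketch {
  // Hypothetical stand-ins for WorkerOffer and TaskSetManager.
  case class Offer(executorId: String, host: String)
  trait TaskSetLike {
    def localityLevels: Seq[Int]   // ordered most-local first
    def resourceOffer(execId: String, host: String, maxLocality: Int): Option[Long]
  }

  // Returns (offer index, task id) pairs for every task launched in this round.
  def makeOffers(
      taskSets: Seq[TaskSetLike],
      offers: IndexedSeq[Offer],
      availableCpus: Array[Int],
      cpusPerTask: Int): Seq[(Int, Long)] = {
    val launched = new ArrayBuffer[(Int, Long)]
    for (taskSet <- taskSets; maxLocality <- taskSet.localityLevels) {
      var launchedTask = false
      do {
        launchedTask = false
        for (i <- offers.indices if availableCpus(i) >= cpusPerTask) {
          // Keep offering this task set at this locality level until it declines.
          for (taskId <- taskSet.resourceOffer(offers(i).executorId, offers(i).host, maxLocality)) {
            launched += ((i, taskId))
            availableCpus(i) -= cpusPerTask
            launchedTask = true
          }
        }
      } while (launchedTask)
    }
    launched.toSeq
  }
}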

core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala

Lines changed: 27 additions & 18 deletions
@@ -298,7 +298,7 @@ private[spark] class TaskSetManager(
       for (index <- speculatableTasks if canRunOnHost(index)) {
         val prefs = tasks(index).preferredLocations
         val executors = prefs.flatMap(_.executorId)
-        if (prefs.size == 0 || executors.contains(execId)) {
+        if (executors.contains(execId)) {
           speculatableTasks -= index
           return Some((index, TaskLocality.PROCESS_LOCAL))
         }
@@ -315,6 +315,16 @@ private[spark] class TaskSetManager(
         }
       }

+      if (TaskLocality.isAllowed(locality, TaskLocality.NO_PREF)) {
+        for (index <- speculatableTasks if canRunOnHost(index)) {
+          val locations = tasks(index).preferredLocations
+          if (locations.size == 0) {
+            speculatableTasks -= index
+            return Some((index, TaskLocality.PROCESS_LOCAL))
+          }
+        }
+      }
+
       // Check for rack-local tasks
       if (TaskLocality.isAllowed(locality, TaskLocality.RACK_LOCAL)) {
         for (rack <- sched.getRackForHost(host)) {
@@ -359,15 +369,11 @@ private[spark] class TaskSetManager(
       }
     }

-    if (TaskLocality.isAllowed(maxLocality, TaskLocality.NOPREF)) {
+    if (TaskLocality.isAllowed(maxLocality, TaskLocality.NO_PREF)) {
       // Look for noPref tasks after NODE_LOCAL for minimize cross-rack traffic
       for (index <- findTaskFromList(execId, pendingTasksWithNoPrefs)) {
         return Some((index, TaskLocality.PROCESS_LOCAL, false))
       }
-      // find a speculative task if all noPref tasks have been scheduled
-      val specTask = findSpeculativeTask(execId, host, maxLocality).map {
-        case (taskIndex, allowedLocality) => (taskIndex, allowedLocality, true)}
-      if (specTask != None) return specTask
     }

     if (TaskLocality.isAllowed(maxLocality, TaskLocality.RACK_LOCAL)) {
@@ -385,7 +391,9 @@ private[spark] class TaskSetManager(
       }
     }

-    None
+    // find a speculative task if all others tasks have been scheduled
+    findSpeculativeTask(execId, host, maxLocality).map {
+      case (taskIndex, allowedLocality) => (taskIndex, allowedLocality, true)}
   }

   /**
@@ -397,25 +405,25 @@ private[spark] class TaskSetManager(
    *
    * @param execId the executor Id of the offered resource
    * @param host the host Id of the offered resource
-   * @param preferredLocality the maximum locality we want to schedule the tasks at
+   * @param maxLocality the maximum locality we want to schedule the tasks at
    */
   def resourceOffer(
       execId: String,
       host: String,
-      preferredLocality: TaskLocality.TaskLocality)
+      maxLocality: TaskLocality.TaskLocality)
     : Option[TaskDescription] =
   {
     if (!isZombie) {
       val curTime = clock.getTime()

-      var allowedLocality = preferredLocality
+      var allowedLocality = maxLocality

-      if (preferredLocality != TaskLocality.NOPREF ||
+      if (maxLocality != TaskLocality.NO_PREF ||
         (nodeLocalTasks.contains(host) && nodeLocalTasks(host).size > 0)) {
         allowedLocality = getAllowedLocalityLevel(curTime)
-        if (allowedLocality > preferredLocality) {
+        if (allowedLocality > maxLocality) {
           // We're not allowed to search for farther-away tasks
-          allowedLocality = preferredLocality
+          allowedLocality = maxLocality
         }
       }

@@ -433,7 +441,7 @@ private[spark] class TaskSetManager(
         taskAttempts(index) = info :: taskAttempts(index)
         // Update our locality level for delay scheduling
         // NOPREF will not affect the variables related to delay scheduling
-        if (preferredLocality != TaskLocality.NOPREF) {
+        if (maxLocality != TaskLocality.NO_PREF) {
           currentLocalityIndex = getLocalityIndex(taskLocality)
           lastLaunchTime = curTime
         }
@@ -756,19 +764,17 @@ private[spark] class TaskSetManager(
         conf.get("spark.locality.wait.node", defaultWait).toLong
       case TaskLocality.RACK_LOCAL =>
         conf.get("spark.locality.wait.rack", defaultWait).toLong
-      case TaskLocality.ANY =>
-        0L
+      case _ => 0L
     }
   }

   /**
    * Compute the locality levels used in this TaskSet. Assumes that all tasks have already been
    * added to queues using addPendingTask.
    *
-   * NOTE: don't need to handle NOPREF here, because NOPREF is scheduled as PROCESS_LOCAL
    */
   private def computeValidLocalityLevels(): Array[TaskLocality.TaskLocality] = {
-    import TaskLocality.{PROCESS_LOCAL, NODE_LOCAL, RACK_LOCAL, ANY}
+    import TaskLocality.{PROCESS_LOCAL, NODE_LOCAL, NO_PREF, RACK_LOCAL, ANY}
     val levels = new ArrayBuffer[TaskLocality.TaskLocality]
     if (!pendingTasksForExecutor.isEmpty && getLocalityWait(PROCESS_LOCAL) != 0 &&
         pendingTasksForExecutor.keySet.exists(sched.isExecutorAlive(_))) {
@@ -778,6 +784,9 @@ private[spark] class TaskSetManager(
         pendingTasksForHost.keySet.exists(sched.hasExecutorsAliveOnHost(_))) {
       levels += NODE_LOCAL
     }
+    if (!pendingTasksWithNoPrefs.isEmpty) {
+      levels += NO_PREF
+    }
     if (!pendingTasksForRack.isEmpty && getLocalityWait(RACK_LOCAL) != 0 &&
         pendingTasksForRack.keySet.exists(sched.hasHostAliveOnRack(_))) {
       levels += RACK_LOCAL
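
For reference, a small self-contained sketch of the level computation after this change (LocalityLevelsSketch and validLevels are hypothetical simplifications of computeValidLocalityLevels: the pending-task queues are reduced to booleans and delay-scheduling waits are ignored). NO_PREF is now a level of its own between NODE_LOCAL and RACK_LOCAL, which is what the updated TaskSetManagerSuite assertions below check:

import scala.collection.mutable.ArrayBuffer

object LocalityLevelsSketch {
  object TaskLocality extends Enumeration {
    val PROCESS_LOCAL, NODE_LOCAL, NO_PREF, RACK_LOCAL, ANY = Value
  }
  import TaskLocality._

  // Simplified model: each flag says whether the corresponding pending queue is non-empty.
  def validLevels(
      hasExecutorLocal: Boolean,
      hasNodeLocal: Boolean,
      hasNoPrefs: Boolean,
      hasRackLocal: Boolean): Seq[TaskLocality.Value] = {
    val levels = new ArrayBuffer[TaskLocality.Value]
    if (hasExecutorLocal) levels += PROCESS_LOCAL
    if (hasNodeLocal) levels += NODE_LOCAL
    if (hasNoPrefs) levels += NO_PREF   // the new level added by this commit
    if (hasRackLocal) levels += RACK_LOCAL
    levels += ANY                       // ANY is always a valid level
    levels.toSeq
  }

  def main(args: Array[String]): Unit = {
    // A task set with only no-preference tasks now reports NO_PREF and ANY,
    // matching the updated expectations in TaskSetManagerSuite.
    println(validLevels(hasExecutorLocal = false, hasNodeLocal = false,
      hasNoPrefs = true, hasRackLocal = false))
  }
}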

core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala

Lines changed: 14 additions & 14 deletions
@@ -147,7 +147,7 @@ class LargeTask(stageId: Int) extends Task[Array[Byte]](stageId, 0) {
 }

 class TaskSetManagerSuite extends FunSuite with LocalSparkContext with Logging {
-  import TaskLocality.{ANY, PROCESS_LOCAL, NOPREF, NODE_LOCAL, RACK_LOCAL}
+  import TaskLocality.{ANY, PROCESS_LOCAL, NO_PREF, NODE_LOCAL, RACK_LOCAL}

   private val conf = new SparkConf

@@ -163,7 +163,7 @@ class TaskSetManagerSuite extends FunSuite with LocalSparkContext with Logging {

     // Offer a host with NOPREF as the constraint,
     // we should get a nopref task immediately since that's what we only have
-    var taskOption = manager.resourceOffer("exec1", "host1", NOPREF)
+    var taskOption = manager.resourceOffer("exec1", "host1", NO_PREF)
     assert(taskOption.isDefined)

     // Tell it the task has finished
@@ -180,15 +180,15 @@ class TaskSetManagerSuite extends FunSuite with LocalSparkContext with Logging {

     // First three offers should all find tasks
     for (i <- 0 until 3) {
-      var taskOption = manager.resourceOffer("exec1", "host1", NOPREF)
+      var taskOption = manager.resourceOffer("exec1", "host1", NO_PREF)
       assert(taskOption.isDefined)
       val task = taskOption.get
       assert(task.executorId === "exec1")
     }
     assert(sched.startedTasks.toSet === Set(0, 1, 2))

     // Re-offer the host -- now we should get no more tasks
-    assert(manager.resourceOffer("exec1", "host1", NOPREF) === None)
+    assert(manager.resourceOffer("exec1", "host1", NO_PREF) === None)

     // Finish the first two tasks
     manager.handleSuccessfulTask(0, createTaskResult(0))
@@ -245,7 +245,7 @@ class TaskSetManagerSuite extends FunSuite with LocalSparkContext with Logging {
     // Offer host2, exec3 again, at NODE_LOCAL level: we should get noPref task
     // after failing to find a node_Local task
     assert(manager.resourceOffer("exec3", "host2", NODE_LOCAL) == None)
-    assert(manager.resourceOffer("exec3", "host2", NOPREF).get.index == 3)
+    assert(manager.resourceOffer("exec3", "host2", NO_PREF).get.index == 3)
   }

   test("we do not need to delay scheduling when we only have noPref tasks in the queue") {
@@ -262,7 +262,7 @@ class TaskSetManagerSuite extends FunSuite with LocalSparkContext with Logging {
     assert(manager.resourceOffer("exec1", "host1", PROCESS_LOCAL).get.index === 0)
     assert(manager.resourceOffer("exec3", "host2", PROCESS_LOCAL).get.index === 1)
     assert(manager.resourceOffer("exec3", "host2", NODE_LOCAL) == None)
-    assert(manager.resourceOffer("exec3", "host2", NOPREF).get.index === 2)
+    assert(manager.resourceOffer("exec3", "host2", NO_PREF).get.index === 2)
   }

   test("delay scheduling with fallback") {
@@ -482,24 +482,24 @@ class TaskSetManagerSuite extends FunSuite with LocalSparkContext with Logging {
     val clock = new FakeClock
     val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES, clock)
     // Only ANY is valid
-    assert(manager.myLocalityLevels.sameElements(Array(ANY)))
+    assert(manager.myLocalityLevels.sameElements(Array(NO_PREF, ANY)))
     // Add a new executor
     sched.addExecutor("execD", "host1")
     manager.executorAdded()
     // Valid locality should contain NODE_LOCAL and ANY
-    assert(manager.myLocalityLevels.sameElements(Array(NODE_LOCAL, ANY)))
+    assert(manager.myLocalityLevels.sameElements(Array(NODE_LOCAL, NO_PREF, ANY)))
     // Add another executor
     sched.addExecutor("execC", "host2")
     manager.executorAdded()
     // Valid locality should contain PROCESS_LOCAL, NODE_LOCAL, RACK_LOCAL and ANY
-    assert(manager.myLocalityLevels.sameElements(Array(PROCESS_LOCAL, NODE_LOCAL, RACK_LOCAL, ANY)))
+    assert(manager.myLocalityLevels.sameElements(Array(PROCESS_LOCAL, NODE_LOCAL, NO_PREF, RACK_LOCAL, ANY)))
     // test if the valid locality is recomputed when the executor is lost
     sched.removeExecutor("execC")
     manager.executorLost("execC", "host2")
-    assert(manager.myLocalityLevels.sameElements(Array(NODE_LOCAL, ANY)))
+    assert(manager.myLocalityLevels.sameElements(Array(NODE_LOCAL, NO_PREF, ANY)))
     sched.removeExecutor("execD")
     manager.executorLost("execD", "host1")
-    assert(manager.myLocalityLevels.sameElements(Array(ANY)))
+    assert(manager.myLocalityLevels.sameElements(Array(NO_PREF, ANY)))
   }

   test("test RACK_LOCAL tasks") {
@@ -572,15 +572,15 @@ class TaskSetManagerSuite extends FunSuite with LocalSparkContext with Logging {

     assert(manager.resourceOffer("execA", "host1", PROCESS_LOCAL).get.index === 0)
     assert(manager.resourceOffer("execA", "host1", NODE_LOCAL) == None)
-    assert(manager.resourceOffer("execA", "host1", NOPREF) == None)
+    assert(manager.resourceOffer("execA", "host1", NO_PREF) == None)
     clock.advance(LOCALITY_WAIT)
     // schedule a node local task
     assert(manager.resourceOffer("execA", "host1", NODE_LOCAL).get.index === 1)
     manager.speculatableTasks += 1
     // schedule the nonPref task
-    assert(manager.resourceOffer("execA", "host1", NOPREF).get.index === 2)
+    assert(manager.resourceOffer("execA", "host1", NO_PREF).get.index === 2)
     // schedule the speculative task
-    assert(manager.resourceOffer("execB", "host2", NOPREF).get.index === 1)
+    assert(manager.resourceOffer("execB", "host2", NO_PREF).get.index === 1)
     clock.advance(LOCALITY_WAIT * 3)
     // schedule non-local tasks
     assert(manager.resourceOffer("execB", "host2", ANY).get.index === 3)
