
Commit c8c1de4

simplify the patch
1 parent be652ed commit c8c1de4


4 files changed: +145 additions, -126 deletions

core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala

Lines changed: 26 additions & 19 deletions

@@ -249,29 +249,36 @@ private[spark] class TaskSchedulerImpl(
 
     // Take each TaskSet in our scheduling order, and then offer it each node in increasing order
     // of locality levels so that it gets a chance to launch local tasks on all of them.
+    // NOTE: the preferredLocality order: PROCESS_LOCAL, NODE_LOCAL, NOPREF, RACK_LOCAL, ANY
     var launchedTask = false
     for (taskSet <- sortedTaskSets; preferredLocality <- taskSet.myLocalityLevels) {
-      do {
-        launchedTask = false
-        for (i <- 0 until shuffledOffers.size) {
-          val execId = shuffledOffers(i).executorId
-          val host = shuffledOffers(i).host
-          if (availableCpus(i) >= CPUS_PER_TASK) {
-            for (task <- taskSet.resourceOffer(execId, host, preferredLocality,
-              TaskLocality.PROCESS_LOCAL, true)) {
-              tasks(i) += task
-              val tid = task.taskId
-              taskIdToTaskSetId(tid) = taskSet.taskSet.id
-              taskIdToExecutorId(tid) = execId
-              activeExecutorIds += execId
-              executorsByHost(host) += execId
-              availableCpus(i) -= CPUS_PER_TASK
-              assert(availableCpus(i) >= 0)
-              launchedTask = true
+      def launchTaskOnLocalityLevel(locality: TaskLocality.Value) {
+        do {
+          launchedTask = false
+          for (i <- 0 until shuffledOffers.size) {
+            val execId = shuffledOffers(i).executorId
+            val host = shuffledOffers(i).host
+            if (availableCpus(i) >= CPUS_PER_TASK) {
+              for (task <- taskSet.resourceOffer(execId, host, locality)) {
+                tasks(i) += task
+                val tid = task.taskId
+                taskIdToTaskSetId(tid) = taskSet.taskSet.id
+                taskIdToExecutorId(tid) = execId
+                activeExecutorIds += execId
+                executorsByHost(host) += execId
+                availableCpus(i) -= CPUS_PER_TASK
+                assert(availableCpus(i) >= 0)
+                launchedTask = true
+              }
             }
           }
-        }
-      } while (launchedTask)
+        } while (launchedTask)
+      }
+      launchTaskOnLocalityLevel(preferredLocality)
+      // search for noPref tasks after we have launched all NODE_LOCAL and nearer tasks
+      if (preferredLocality == TaskLocality.NODE_LOCAL) {
+        launchTaskOnLocalityLevel(TaskLocality.NOPREF)
+      }
     }
 
     if (tasks.size > 0) {
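
The change above wraps the old do/while offer loop in a local launchTaskOnLocalityLevel function so the same loop can run twice: once for the preferred level and, when that level is NODE_LOCAL, once more for NOPREF. The following standalone sketch (not part of the commit; TaskLocality and resourceOffer are simplified stand-ins) illustrates that two-pass pattern:

import scala.collection.mutable

// Standalone sketch of the two-pass offer pattern: drain one locality level with a
// do/while loop, and follow a NODE_LOCAL pass with a NOPREF pass.
object LocalityOfferSketch {

  object TaskLocality extends Enumeration {
    val PROCESS_LOCAL, NODE_LOCAL, NOPREF, RACK_LOCAL, ANY = Value
  }

  // Stand-in for TaskSetManager.resourceOffer: hand out a placeholder task id while any
  // tasks remain pending at the requested level, None otherwise.
  def resourceOffer(pending: mutable.Map[TaskLocality.Value, Int],
                    locality: TaskLocality.Value): Option[Int] = {
    val left = pending.getOrElse(locality, 0)
    if (left > 0) { pending(locality) = left - 1; Some(left) } else None
  }

  def main(args: Array[String]): Unit = {
    val pending = mutable.Map(TaskLocality.NODE_LOCAL -> 2, TaskLocality.NOPREF -> 1)
    var launchedTask = false

    def launchTaskOnLocalityLevel(locality: TaskLocality.Value): Unit = {
      do {
        launchedTask = false
        for (task <- resourceOffer(pending, locality)) {
          println(s"launched task $task at $locality")
          launchedTask = true
        }
      } while (launchedTask)
    }

    for (preferredLocality <- Seq(TaskLocality.PROCESS_LOCAL, TaskLocality.NODE_LOCAL)) {
      launchTaskOnLocalityLevel(preferredLocality)
      // NOPREF tasks are only searched once the NODE_LOCAL pass has been drained
      if (preferredLocality == TaskLocality.NODE_LOCAL) {
        launchTaskOnLocalityLevel(TaskLocality.NOPREF)
      }
    }
  }
}

Running the sketch prints the two NODE_LOCAL launches before the single NOPREF launch, which is the ordering the commit is after.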

core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala

Lines changed: 42 additions & 50 deletions

@@ -79,6 +79,7 @@ private[spark] class TaskSetManager(
   private val numFailures = new Array[Int](numTasks)
   // key is taskId, value is a Map of executor id to when it failed
   private val failedExecutors = new HashMap[Int, HashMap[String, Long]]()
+
   val taskAttempts = Array.fill[List[TaskInfo]](numTasks)(Nil)
   var tasksSuccessful = 0
 
@@ -113,6 +114,10 @@ private[spark] class TaskSetManager(
   // but at host level.
   private val pendingTasksForHost = new HashMap[String, ArrayBuffer[Int]]
 
+  // this collection is mainly for ensuring that NODE_LOCAL tasks are always scheduled
+  // before NOPREF ones; it contains all NODE_LOCAL tasks that have not been launched yet
+  private[scheduler] val nodeLocalTasks = new HashMap[String, HashSet[Int]]
+
   // Set of pending tasks for each rack -- similar to the above.
   private val pendingTasksForRack = new HashMap[String, ArrayBuffer[Int]]
 
@@ -188,6 +193,9 @@ private[spark] class TaskSetManager(
         hadAliveLocations = true
       }
       addTo(pendingTasksForHost.getOrElseUpdate(loc.host, new ArrayBuffer))
+      if (loc.executorId == None) {
+        nodeLocalTasks.getOrElseUpdate(loc.host, new HashSet[Int]) += index
+      }
       for (rack <- sched.getRackForHost(loc.host)) {
         addTo(pendingTasksForRack.getOrElseUpdate(rack, new ArrayBuffer))
         if(sched.hasHostAliveOnRack(rack)){
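
A small sketch (not part of the commit; TaskLocation here is a simplified stand-in) of the nodeLocalTasks bookkeeping added above: a pending task is registered as NODE_LOCAL for a host only when its preference is host-level, i.e. carries no executorId.

import scala.collection.mutable.{HashMap, HashSet}

// Simplified stand-in for a preferred location with an optional executor id
case class TaskLocation(host: String, executorId: Option[String] = None)

object NodeLocalTrackingSketch {

  // host -> indexes of NODE_LOCAL tasks that have not been launched yet
  val nodeLocalTasks = new HashMap[String, HashSet[Int]]

  def addPendingTask(index: Int, preferredLocations: Seq[TaskLocation]): Unit = {
    for (loc <- preferredLocations) {
      // Only host-level preferences (no executorId) count as NODE_LOCAL here
      if (loc.executorId == None) {
        nodeLocalTasks.getOrElseUpdate(loc.host, new HashSet[Int]) += index
      }
    }
  }

  def main(args: Array[String]): Unit = {
    addPendingTask(0, Seq(TaskLocation("host1")))                // host-level: tracked
    addPendingTask(1, Seq(TaskLocation("host1", Some("exec1")))) // executor-level: not tracked
    println(nodeLocalTasks)                                      // Map(host1 -> Set(0))
  }
}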
@@ -336,38 +344,22 @@ private[spark] class TaskSetManager(
    * Dequeue a pending task for a given node and return its index and locality level.
    * Only search for tasks matching the given locality constraint.
    *
-   * NOTE: minLocality is for avoiding duplicate traverse of the list (especially when we
-   * pass NOPREF as maxLocality after the others
-   *
    * @return An option containing (task index within the task set, locality, is speculative?)
    */
-  private def findTask(execId: String, host: String, maxLocality: TaskLocality.Value,
-      minLocality: TaskLocality.Value)
+  private def findTask(execId: String, host: String, maxLocality: TaskLocality.Value)
     : Option[(Int, TaskLocality.Value, Boolean)] =
   {
-    def withinAllowedLocality(locality: TaskLocality.TaskLocality): Boolean = {
-      TaskLocality.isAllowed(maxLocality, locality) && {
-        if (maxLocality != minLocality) {
-          minLocality < locality
-        } else {
-          true
-        }
-      }
-    }
-
-    if (withinAllowedLocality(TaskLocality.PROCESS_LOCAL)) {
-      for (index <- findTaskFromList(execId, getPendingTasksForExecutor(execId))) {
-        return Some((index, TaskLocality.PROCESS_LOCAL, false))
-      }
+    for (index <- findTaskFromList(execId, getPendingTasksForExecutor(execId))) {
+      return Some((index, TaskLocality.PROCESS_LOCAL, false))
     }
 
-    if (withinAllowedLocality(TaskLocality.NODE_LOCAL)) {
+    if (TaskLocality.isAllowed(maxLocality, TaskLocality.NODE_LOCAL)) {
       for (index <- findTaskFromList(execId, getPendingTasksForHost(host))) {
         return Some((index, TaskLocality.NODE_LOCAL, false))
       }
     }
 
-    if (withinAllowedLocality(TaskLocality.NOPREF)) {
+    if (TaskLocality.isAllowed(maxLocality, TaskLocality.NOPREF)) {
       // Look for noPref tasks after NODE_LOCAL for minimize cross-rack traffic
       for (index <- findTaskFromList(execId, pendingTasksWithNoPrefs)) {
         return Some((index, TaskLocality.PROCESS_LOCAL, false))

@@ -378,7 +370,7 @@ private[spark] class TaskSetManager(
       if (specTask != None) return specTask
     }
 
-    if (withinAllowedLocality(TaskLocality.RACK_LOCAL)) {
+    if (TaskLocality.isAllowed(maxLocality, TaskLocality.RACK_LOCAL)) {
       for {
         rack <- sched.getRackForHost(host)
         index <- findTaskFromList(execId, getPendingTasksForRack(rack))

@@ -387,7 +379,7 @@ private[spark] class TaskSetManager(
       }
     }
 
-    if (withinAllowedLocality(TaskLocality.ANY)) {
+    if (TaskLocality.isAllowed(maxLocality, TaskLocality.ANY)) {
      for (index <- findTaskFromList(execId, allPendingTasks)) {
        return Some((index, TaskLocality.ANY, false))
      }
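
With the minLocality parameter and the withinAllowedLocality helper gone, each branch above is gated only by TaskLocality.isAllowed(maxLocality, level). A minimal sketch (not part of the commit) of that gate, assuming isAllowed is simply an ordering comparison over the locality enumeration:

// Sketch of the locality gate used in findTask above, under the assumption that
// isAllowed compares positions in the enumeration order.
object IsAllowedSketch {

  object TaskLocality extends Enumeration {
    val PROCESS_LOCAL, NODE_LOCAL, NOPREF, RACK_LOCAL, ANY = Value

    // A level is allowed when it is no farther away than the constraint
    def isAllowed(constraint: Value, condition: Value): Boolean = condition <= constraint
  }

  def main(args: Array[String]): Unit = {
    import TaskLocality._
    println(isAllowed(NODE_LOCAL, PROCESS_LOCAL)) // true: closer tasks are always fine
    println(isAllowed(NODE_LOCAL, RACK_LOCAL))    // false: farther than the constraint
  }
}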
@@ -398,42 +390,36 @@ private[spark] class TaskSetManager(
 
   /**
    * Respond to an offer of a single executor from the scheduler by finding a task
+   *
+   * NOTE: this function is either called with a real preferredLocality level, which is
+   * adjusted by the delay scheduling algorithm, or with the special NOPREF locality,
+   * which is never adjusted
+   *
    * @param execId the executor Id of the offered resource
    * @param host the host Id of the offered resource
    * @param preferredLocality the maximum locality we want to schedule the tasks at
-   * @param bottomLocality the minimum locality we want to schedule the tasks at, this
-   *                       parameter is mainly used to avoid some duplicate traversing of
-   *                       the task lists, after we have determined that we have no candidate
-   *                       tasks on certain levels
-   * @param allowAdjustPrefLocality this parameter is mainly for scheduling noPref tasks, where
-   *                                we do not want to apply delay scheduling on this kind of tasks
    */
   def resourceOffer(
       execId: String,
       host: String,
-      preferredLocality: TaskLocality.TaskLocality,
-      bottomLocality: TaskLocality.TaskLocality,
-      allowAdjustPrefLocality: Boolean = true)
+      preferredLocality: TaskLocality.TaskLocality)
     : Option[TaskDescription] =
   {
     if (!isZombie) {
       val curTime = clock.getTime()
 
-      var allowedLocality = getAllowedLocalityLevel(curTime)
-
-      if (allowedLocality > preferredLocality) {
-        // We're not allowed to search for farther-away tasks
-        allowedLocality = preferredLocality
-      }
+      var allowedLocality = preferredLocality
 
-      val foundTask = {
-        if (allowAdjustPrefLocality) {
-          findTask(execId, host, allowedLocality, bottomLocality)
-        } else {
-          findTask(execId, host, preferredLocality, bottomLocality)
+      if (preferredLocality != TaskLocality.NOPREF ||
+        (nodeLocalTasks.contains(host) && nodeLocalTasks(host).size > 0)) {
+        allowedLocality = getAllowedLocalityLevel(curTime)
+        if (allowedLocality > preferredLocality) {
+          // We're not allowed to search for farther-away tasks
+          allowedLocality = preferredLocality
        }
      }
-      foundTask match {
+
+      findTask(execId, host, allowedLocality) match {
         case Some((index, taskLocality, speculative)) => {
           // Found a task; do some bookkeeping and return a task description
           val task = tasks(index)

@@ -446,8 +432,11 @@ private[spark] class TaskSetManager(
           taskInfos(taskId) = info
           taskAttempts(index) = info :: taskAttempts(index)
           // Update our locality level for delay scheduling
-          currentLocalityIndex = getLocalityIndex(taskLocality)
-          lastLaunchTime = curTime
+          // NOPREF will not affect the variables related to delay scheduling
+          if (preferredLocality != TaskLocality.NOPREF) {
+            currentLocalityIndex = getLocalityIndex(taskLocality)
+            lastLaunchTime = curTime
+          }
           // Serialize and return the task
           val startTime = clock.getTime()
           // We rely on the DAGScheduler to catch non-serializable closures and RDDs, so in here

@@ -471,13 +460,16 @@ private[spark] class TaskSetManager(
             taskName, taskId, host, taskLocality, serializedTask.limit))
 
           sched.dagScheduler.taskStarted(task, info)
+          if (taskLocality <= TaskLocality.NODE_LOCAL) {
+            for (preferedLocality <- tasks(index).preferredLocations) {
+              if (nodeLocalTasks.contains(preferedLocality.host)) {
+                nodeLocalTasks(preferedLocality.host) -= index
+              }
+            }
+          }
           return Some(new TaskDescription(taskId, execId, taskName, index, serializedTask))
         }
         case _ =>
-          if (preferredLocality != TaskLocality.NOPREF) {
-            return resourceOffer(execId, host, TaskLocality.NOPREF, preferredLocality,
-              allowAdjustPrefLocality = false)
-          }
       }
     }
     None
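
A condensed sketch (not part of the commit; names and types are simplified stand-ins) of the two NOPREF special cases in resourceOffer above: a plain NOPREF offer bypasses the delay-scheduling adjustment unless the host still has unlaunched NODE_LOCAL tasks, and launching a NOPREF task never advances the delay-scheduling state.

object NoPrefDelaySketch {

  object TaskLocality extends Enumeration {
    val PROCESS_LOCAL, NODE_LOCAL, NOPREF, RACK_LOCAL, ANY = Value
  }
  import TaskLocality._

  var currentLocalityIndex = 0 // stand-in for the delay-scheduling pointer
  var lastLaunchTime = 0L

  // Decide how far away we may search, mirroring the allowedLocality logic above
  def allowedLocalityFor(preferred: Value,
                         delayAllowed: Value,
                         hostHasPendingNodeLocal: Boolean): Value = {
    if (preferred != NOPREF || hostHasPendingNodeLocal) {
      // Respect delay scheduling, but never search farther than the offered level
      if (delayAllowed > preferred) preferred else delayAllowed
    } else {
      preferred // plain NOPREF offer: no delay-scheduling adjustment
    }
  }

  // Bookkeeping after a launch, mirroring the currentLocalityIndex update above
  def recordLaunch(preferred: Value, launchedAt: Value, now: Long): Unit = {
    if (preferred != NOPREF) { // NOPREF launches leave delay scheduling untouched
      currentLocalityIndex = launchedAt.id
      lastLaunchTime = now
    }
  }
}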

core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala

Lines changed: 2 additions & 5 deletions

@@ -80,9 +80,7 @@ class FakeTaskSetManager(
   override def resourceOffer(
       execId: String,
       host: String,
-      preferredLocality: TaskLocality.TaskLocality,
-      bottomLocality: TaskLocality.TaskLocality,
-      allowedAdjustPrefLocality: Boolean)
+      preferredLocality: TaskLocality.TaskLocality)
     : Option[TaskDescription] =
   {
     if (tasksSuccessful + numRunningTasks < numTasks) {

@@ -126,8 +124,7 @@ class TaskSchedulerImplSuite extends FunSuite with LocalSparkContext with Logging {
         manager.parent.name, manager.parent.runningTasks, manager.name, manager.runningTasks))
     }
     for (taskSet <- taskSetQueue) {
-      taskSet.resourceOffer("execId_1", "hostname_1", TaskLocality.ANY,
-        TaskLocality.PROCESS_LOCAL) match {
+      taskSet.resourceOffer("execId_1", "hostname_1", TaskLocality.ANY) match {
        case Some(task) =>
          return taskSet.stageId
        case None => {}
