28 changes: 23 additions & 5 deletions core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -1136,11 +1136,11 @@ private[spark] class DAGScheduler(
val taskIdToLocations: Map[Int, Seq[TaskLocation]] = try {
stage match {
case s: ShuffleMapStage =>
partitionsToCompute.map { id => (id, getPreferredLocs(stage.rdd, id))}.toMap
partitionsToCompute.map { id => (id, getPreferredLocsInternal(stage.rdd, id))}.toMap
case s: ResultStage =>
partitionsToCompute.map { id =>
val p = s.partitions(id)
(id, getPreferredLocs(stage.rdd, p))
(id, getPreferredLocsInternal(stage.rdd, p))
}.toMap
}
} catch {
@@ -1152,7 +1152,8 @@ private[spark] class DAGScheduler(
return
}

stage.makeNewStageAttempt(partitionsToCompute.size, taskIdToLocations.values.toSeq)
val taskLocalityPrefs = taskIdToLocations.values.map(_.filter(_ != WildcardLocation)).toSeq
stage.makeNewStageAttempt(partitionsToCompute.size, taskLocalityPrefs)

// If there are tasks to execute, record the submission time of the stage. Otherwise,
// post the event without the submission time, which indicates that this stage was
@@ -2054,14 +2055,28 @@ private[spark] class DAGScheduler(
/**
* Gets the locality information associated with a partition of a particular RDD.
*
* This method is thread-safe and is called from both DAGScheduler and SparkContext.
* This method is thread-safe and is called from SparkContext.
*
* @param rdd whose partitions are to be looked at
* @param partition to lookup locality information for
* @return list of machines that are preferred by the partition
*/
private[spark]
def getPreferredLocs(rdd: RDD[_], partition: Int): Seq[TaskLocation] = {
getPreferredLocsInternal(rdd, partition).filter(_ != WildcardLocation)
}

/**
* Gets the locality information associated with a partition of a particular RDD, which may
* include a [[WildcardLocation]].
*
* This method is thread-safe and is called from DAGScheduler only.
*
* @param rdd whose partitions are to be looked at
* @param partition to lookup locality information for
* @return list of machines that are preferred by the partition
*/
private def getPreferredLocsInternal(rdd: RDD[_], partition: Int): Seq[TaskLocation] = {
getPreferredLocsInternal(rdd, partition, new HashSet)
}

@@ -2090,7 +2105,10 @@ private[spark] class DAGScheduler(
// If the RDD has some placement preferences (as is the case for input RDDs), get those
val rddPrefs = rdd.preferredLocations(rdd.partitions(partition)).toList
if (rddPrefs.nonEmpty) {
return rddPrefs.map(TaskLocation(_))
return rddPrefs.map {
case WildcardLocation.host => WildcardLocation
case host => TaskLocation(host)
}
}

// If the RDD has narrow dependencies, pick the first partition of the first narrow dependency
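For context on how these DAGScheduler changes are meant to be used: an RDD opts into the wildcard fallback simply by returning "*" (which is `WildcardLocation.host`) among its preferred-location strings, so `getPreferredLocsInternal` keeps the wildcard while the public `getPreferredLocs` still hides it. A minimal sketch, assuming a hypothetical RDD class and host names that are not part of this patch:

```scala
import org.apache.spark.{Partition, SparkContext, TaskContext}
import org.apache.spark.rdd.RDD

// Hypothetical example: prefers host1 but is willing to run on any host.
class PreferredButFlexibleRDD(sc: SparkContext) extends RDD[Int](sc, Nil) {

  override protected def getPartitions: Array[Partition] =
    Array(new Partition { override def index: Int = 0 })

  // "*" is WildcardLocation.host; DAGScheduler.getPreferredLocsInternal maps it
  // to WildcardLocation, while the public getPreferredLocs filters it out.
  override protected def getPreferredLocations(split: Partition): Seq[String] =
    Seq("host1", "*")

  override def compute(split: Partition, context: TaskContext): Iterator[Int] =
    Iterator(split.index)
}
```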
16 changes: 16 additions & 0 deletions core/src/main/scala/org/apache/spark/scheduler/TaskLocation.scala
@@ -49,6 +49,22 @@ private [spark] case class HDFSCacheTaskLocation(override val host: String) exte
override def toString: String = TaskLocation.inMemoryLocationTag + host
}

/**
* A location that can match any host. This can be used as one of the locations in
* `RDD.getPreferredLocations` to indicate that the task can be assigned to any host if none of
* the other desired locations can be satisfied immediately.
*
* This location is only used internally by DAGScheduler to skip delayed scheduling for individual
* RDDs. `DAGScheduler.getPreferredLocs` does not contain this location.
*
* @note This class is experimental and may be replaced by a more complete solution for delayed
* scheduling.
*/
private [spark] case object WildcardLocation extends TaskLocation {
override val host: String = "*"
override def toString: String = host
}

private[spark] object TaskLocation {
// We identify hosts on which the block is cached with this prefix. Because this prefix contains
// underscores, which are not legal characters in hostnames, there should be no potential for
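Two details of the new case object are worth calling out: case-object identity is what makes the `_ != WildcardLocation` filtering in DAGScheduler and TaskSetManager work, and `WildcardLocation.host` is used as a stable-identifier pattern that matches the literal string "*". A small illustrative sketch (not part of the patch; these types are `private[spark]`, so it would only compile inside Spark's scheduler package):

```scala
import org.apache.spark.scheduler.{TaskLocation, WildcardLocation}

val prefs: Seq[TaskLocation] = Seq(TaskLocation("host1"), WildcardLocation)

// What the public getPreferredLocs now exposes: the wildcard is stripped.
val external = prefs.filter(_ != WildcardLocation)  // Seq(HostTaskLocation("host1"))

// Stable-identifier pattern, as used when converting an RDD's preference strings:
// WildcardLocation.host ("*") matches by equality, anything else stays a host.
def toLocation(hostString: String): TaskLocation = hostString match {
  case WildcardLocation.host => WildcardLocation
  case host => TaskLocation(host)
}
```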
core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala
@@ -217,7 +217,8 @@ private[spark] class TaskSetManager(
resolveRacks: Boolean = true,
speculatable: Boolean = false): Unit = {
val pendingTaskSetToAddTo = if (speculatable) pendingSpeculatableTasks else pendingTasks
for (loc <- tasks(index).preferredLocations) {
val preferredLocations = tasks(index).preferredLocations
for (loc <- preferredLocations.filter(_ != WildcardLocation)) {
loc match {
case e: ExecutorCacheTaskLocation =>
pendingTaskSetToAddTo.forExecutor.getOrElseUpdate(e.executorId, new ArrayBuffer) += index
@@ -244,7 +245,7 @@
}
}

if (tasks(index).preferredLocations == Nil) {
if (preferredLocations == Nil || preferredLocations.contains(WildcardLocation)) {
pendingTaskSetToAddTo.noPrefs += index
}

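The net effect of the TaskSetManager change is that a wildcard-tagged task is still indexed under its concrete preferred hosts, but it is also added to noPrefs, so it can be picked up at the NO_PREF locality level instead of waiting out delay scheduling. A simplified sketch of that decision (a hypothetical helper, not the real addPendingTask API, ignoring executor- and rack-level queues):

```scala
import org.apache.spark.scheduler.{HostTaskLocation, TaskLocation, WildcardLocation}

def pendingQueuesFor(prefs: Seq[TaskLocation]): (Seq[String], Boolean) = {
  // Host-level queues still only get concrete locations; the wildcard is skipped.
  val hostQueues = prefs.collect { case HostTaskLocation(host) => host }
  // A task with no preferences at all, or with a wildcard fallback, also lands
  // in noPrefs and therefore becomes runnable at the NO_PREF locality level.
  val addToNoPrefs = prefs.isEmpty || prefs.contains(WildcardLocation)
  (hostQueues, addToNoPrefs)
}

// Prefers host1, but can be scheduled anywhere once NO_PREF is offered.
pendingQueuesFor(Seq(TaskLocation("host1"), WildcardLocation))  // (Seq("host1"), true)
```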
core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala
@@ -1796,4 +1796,22 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg
manager.handleFailedTask(offerResult.get.taskId, TaskState.FAILED, reason)
assert(sched.taskSetsFailed.contains(taskSet.id))
}

test("Tasks with wildcard location can run immediately if preferred location not available") {
sc = new SparkContext("local", "test")
sched = new FakeTaskScheduler(sc, ("exec1", "host1"), ("exec2", "host2"), ("exec3", "host3"))
val taskSet = FakeTask.createTaskSet(3,
Seq(TaskLocation("host1"), WildcardLocation),
Seq(TaskLocation("host1"), WildcardLocation),
Seq(TaskLocation("host2"), WildcardLocation)
)
val clock = new ManualClock
val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES, clock = clock)
assert(manager.resourceOffer("exec1", "host1", NODE_LOCAL).get.index === 0)
// Second task is not scheduled as it does not satisfy locality level.
assert(manager.resourceOffer("exec2", "host2", NODE_LOCAL).get.index === 2)
assert(manager.resourceOffer("exec3", "host3", NODE_LOCAL).isEmpty)
// Second task is scheduled immediately on a non-preferred host with NO_PREF locality level.
assert(manager.resourceOffer("exec3", "host3", NO_PREF).get.index === 1)
}
}