diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
index fe3a48440991a..b5a2231b696f5 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -1136,11 +1136,11 @@ private[spark] class DAGScheduler(
     val taskIdToLocations: Map[Int, Seq[TaskLocation]] = try {
       stage match {
         case s: ShuffleMapStage =>
-          partitionsToCompute.map { id => (id, getPreferredLocs(stage.rdd, id))}.toMap
+          partitionsToCompute.map { id => (id, getPreferredLocsInternal(stage.rdd, id))}.toMap
         case s: ResultStage =>
           partitionsToCompute.map { id =>
             val p = s.partitions(id)
-            (id, getPreferredLocs(stage.rdd, p))
+            (id, getPreferredLocsInternal(stage.rdd, p))
           }.toMap
       }
     } catch {
@@ -1152,7 +1152,8 @@ private[spark] class DAGScheduler(
       return
     }
 
-    stage.makeNewStageAttempt(partitionsToCompute.size, taskIdToLocations.values.toSeq)
+    val taskLocalityPrefs = taskIdToLocations.values.map(_.filter(_ != WildcardLocation)).toSeq
+    stage.makeNewStageAttempt(partitionsToCompute.size, taskLocalityPrefs)
 
     // If there are tasks to execute, record the submission time of the stage. Otherwise,
     // post the even without the submission time, which indicates that this stage was
@@ -2054,7 +2055,7 @@ private[spark] class DAGScheduler(
   /**
    * Gets the locality information associated with a partition of a particular RDD.
    *
-   * This method is thread-safe and is called from both DAGScheduler and SparkContext.
+   * This method is thread-safe and is called from SparkContext.
    *
    * @param rdd whose partitions are to be looked at
    * @param partition to lookup locality information for
@@ -2062,6 +2063,20 @@ private[spark] class DAGScheduler(
    */
   private[spark]
   def getPreferredLocs(rdd: RDD[_], partition: Int): Seq[TaskLocation] = {
+    getPreferredLocsInternal(rdd, partition).filter(_ != WildcardLocation)
+  }
+
+  /**
+   * Gets the locality information associated with a partition of a particular RDD, which may
+   * include a [[WildcardLocation]].
+   *
+   * This method is thread-safe and is called from DAGScheduler only.
+   *
+   * @param rdd whose partitions are to be looked at
+   * @param partition to lookup locality information for
+   * @return list of machines that are preferred by the partition
+   */
+  private def getPreferredLocsInternal(rdd: RDD[_], partition: Int): Seq[TaskLocation] = {
     getPreferredLocsInternal(rdd, partition, new HashSet)
   }
 
@@ -2090,7 +2105,10 @@ private[spark] class DAGScheduler(
     // If the RDD has some placement preferences (as is the case for input RDDs), get those
     val rddPrefs = rdd.preferredLocations(rdd.partitions(partition)).toList
     if (rddPrefs.nonEmpty) {
-      return rddPrefs.map(TaskLocation(_))
+      return rddPrefs.map {
+        case WildcardLocation.host => WildcardLocation
+        case host => TaskLocation(host)
+      }
     }
 
     // If the RDD has narrow dependencies, pick the first partition of the first narrow dependency
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskLocation.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskLocation.scala
index 06b52935c696c..9f97daf2876aa 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskLocation.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskLocation.scala
@@ -49,6 +49,22 @@ private [spark] case class HDFSCacheTaskLocation(override val host: String) exte
   override def toString: String = TaskLocation.inMemoryLocationTag + host
 }
 
+/**
+ * A location that can match any host. This can be used as one of the locations in
+ * `RDD.getPreferredLocations` to indicate that the task can be assigned to any host if none of
+ * the other desired locations can be satisfied immediately.
+ *
+ * This location is only used internally by DAGScheduler to skip delayed scheduling for individual
+ * RDDs. `DAGScheduler.getPreferredLocs` does not contain this location.
+ *
+ * @note This class is experimental and may be replaced by a more complete solution for delayed
+ * scheduling.
+ */
+private [spark] case object WildcardLocation extends TaskLocation {
+  override val host: String = "*"
+  override def toString: String = host
+}
+
 private[spark] object TaskLocation {
   // We identify hosts on which the block is cached with this prefix. Because this prefix contains
   // underscores, which are not legal characters in hostnames, there should be no potential for
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala
index 5c0bc497dd1b3..b3e0ecf8f1bf6 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala
@@ -217,7 +217,8 @@ private[spark] class TaskSetManager(
       resolveRacks: Boolean = true,
       speculatable: Boolean = false): Unit = {
     val pendingTaskSetToAddTo = if (speculatable) pendingSpeculatableTasks else pendingTasks
-    for (loc <- tasks(index).preferredLocations) {
+    val preferredLocations = tasks(index).preferredLocations
+    for (loc <- preferredLocations.filter(_ != WildcardLocation)) {
       loc match {
         case e: ExecutorCacheTaskLocation =>
           pendingTaskSetToAddTo.forExecutor.getOrElseUpdate(e.executorId, new ArrayBuffer) += index
@@ -244,7 +245,7 @@ private[spark] class TaskSetManager(
       }
     }
 
-    if (tasks(index).preferredLocations == Nil) {
+    if (preferredLocations == Nil || preferredLocations.contains(WildcardLocation)) {
       pendingTaskSetToAddTo.noPrefs += index
     }
 
diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala
index 34bcae8abd512..5a4c6afb417bb 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala
@@ -1796,4 +1796,22 @@ class TaskSetManagerSuite extends SparkFunSuite with LocalSparkContext with Logg
     manager.handleFailedTask(offerResult.get.taskId, TaskState.FAILED, reason)
     assert(sched.taskSetsFailed.contains(taskSet.id))
   }
+
+  test("Tasks with wildcard location can run immediately if preferred location not available") {
+    sc = new SparkContext("local", "test")
+    sched = new FakeTaskScheduler(sc, ("exec1", "host1"), ("exec2", "host2"), ("exec3", "host3"))
+    val taskSet = FakeTask.createTaskSet(3,
+      Seq(TaskLocation("host1"), WildcardLocation),
+      Seq(TaskLocation("host1"), WildcardLocation),
+      Seq(TaskLocation("host2"), WildcardLocation)
+    )
+    val clock = new ManualClock
+    val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES, clock = clock)
+    assert(manager.resourceOffer("exec1", "host1", NODE_LOCAL).get.index === 0)
+    // The second task is not scheduled, as the offer does not satisfy its locality level.
+    assert(manager.resourceOffer("exec2", "host2", NODE_LOCAL).get.index === 2)
+    assert(manager.resourceOffer("exec3", "host3", NODE_LOCAL).isEmpty)
+    // The second task is scheduled immediately on a non-preferred host at the NO_PREF locality level.
+    assert(manager.resourceOffer("exec3", "host3", NO_PREF).get.index === 1)
+  }
 }
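
Usage sketch (not part of the patch): since WildcardLocation is private[spark], user code cannot
reference it directly; instead, an RDD opts in by returning the plain "*" host string from
getPreferredLocations, which the new branch in getPreferredLocsInternal maps to WildcardLocation.
The following is a minimal, hypothetical illustration; EagerRDD and the host names are made up
for this sketch.

    import org.apache.spark.{Partition, SparkContext, TaskContext}
    import org.apache.spark.rdd.RDD

    // Hypothetical RDD that prefers host1 but also advertises "*". Per the
    // TaskSetManager change above, its task index is added to the noPrefs pending
    // list (in addition to the host1 list), so the task can launch on any host at
    // the NO_PREF locality level instead of waiting out spark.locality.wait.
    class EagerRDD(sc: SparkContext) extends RDD[Int](sc, Nil) {
      override protected def getPartitions: Array[Partition] =
        Array(new Partition { override def index: Int = 0 })

      override def compute(split: Partition, context: TaskContext): Iterator[Int] =
        Iterator(split.index)

      override protected def getPreferredLocations(split: Partition): Seq[String] =
        Seq("host1", "*")
    }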