diff --git a/core/src/main/scala/org/apache/spark/scheduler/Pool.scala b/core/src/main/scala/org/apache/spark/scheduler/Pool.scala index 2e2851eb9070b..7333b31524f2a 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/Pool.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/Pool.scala @@ -59,6 +59,8 @@ private[spark] class Pool( } } + override def isSchedulable: Boolean = true + override def addSchedulable(schedulable: Schedulable): Unit = { require(schedulable != null) schedulableQueue.add(schedulable) @@ -105,7 +107,7 @@ private[spark] class Pool( val sortedSchedulableQueue = schedulableQueue.asScala.toSeq.sortWith(taskSetSchedulingAlgorithm.comparator) for (schedulable <- sortedSchedulableQueue) { - sortedTaskSetQueue ++= schedulable.getSortedTaskSetQueue + sortedTaskSetQueue ++= schedulable.getSortedTaskSetQueue.filter(_.isSchedulable) } sortedTaskSetQueue } diff --git a/core/src/main/scala/org/apache/spark/scheduler/Schedulable.scala b/core/src/main/scala/org/apache/spark/scheduler/Schedulable.scala index 8cc239c81d11a..0626f8fb8150a 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/Schedulable.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/Schedulable.scala @@ -39,6 +39,7 @@ private[spark] trait Schedulable { def stageId: Int def name: String + def isSchedulable: Boolean def addSchedulable(schedulable: Schedulable): Unit def removeSchedulable(schedulable: Schedulable): Unit def getSchedulableByName(name: String): Schedulable diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala index 107c517ca06bc..2fcf13d5268f8 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala @@ -535,7 +535,7 @@ private[spark] class TaskSchedulerImpl( val availableResources = shuffledOffers.map(_.resources).toArray val availableCpus = shuffledOffers.map(o => o.cores).toArray val resourceProfileIds = shuffledOffers.map(o => o.resourceProfileId).toArray - val sortedTaskSets = rootPool.getSortedTaskSetQueue.filterNot(_.isZombie) + val sortedTaskSets = rootPool.getSortedTaskSetQueue for (taskSet <- sortedTaskSets) { logDebug("parentName: %s, name: %s, runningTasks: %s".format( taskSet.parent.name, taskSet.name, taskSet.runningTasks)) diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala index 673fe4fe27519..78fd412ef154c 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala @@ -951,6 +951,9 @@ private[spark] class TaskSetManager( null } + override def isSchedulable: Boolean = !isZombie && + (pendingTasks.all.nonEmpty || pendingSpeculatableTasks.all.nonEmpty) + override def addSchedulable(schedulable: Schedulable): Unit = {} override def removeSchedulable(schedulable: Schedulable): Unit = {}