diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
index dd1b2595461f..3d1ead7eb40e 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -1383,6 +1383,7 @@ private[spark] class DAGScheduler(
 
     event.reason match {
       case Success =>
+        taskScheduler.markPartitionCompletedFromEventLoop(task.partitionId, task.stageId)
         task match {
           case rt: ResultTask[_, _] =>
             // Cast to ResultStage here because it's part of the ResultTask
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala
index 94221eb0d551..98ff9b8632fe 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala
@@ -109,4 +109,16 @@ private[spark] trait TaskScheduler {
    */
  def applicationAttemptId(): Option[String]
 
+  /**
+   * SPARK-25250: Marks the task as completed in all TaskSetManagers for the given stage.
+   * After stage failure and retry, there may be multiple TaskSetManagers for the stage. If an
+   * earlier attempt of a stage completes a task, we should ensure that the later attempts do not
+   * also submit that same task. That also means that a task completion from an earlier attempt
+   * can lead to the entire stage getting marked as successful. Whenever any task gets
+   * successfully completed, we simply mark the corresponding partition id as completed in all
+   * attempts for that particular stage. This method must be called from inside the DAGScheduler
+   * event loop, to ensure a consistent view of all task sets for the given stage.
+   */
+  def markPartitionCompletedFromEventLoop(partitionId: Int, stageId: Int): Unit
+
 }
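Reviewer note, not part of the patch: the Scaladoc above describes per-partition bookkeeping that is shared across all attempts of a stage. The minimal, self-contained Scala sketch below illustrates that idea; the object and method names are hypothetical and are not Spark APIs.

import scala.collection.mutable.{HashMap, HashSet}

// Illustrative only: every successful task reports its (stageId, partitionId); any attempt of
// that stage can then skip partitions that some earlier attempt has already finished.
object PartitionBookkeepingSketch {
  private val finished = new HashMap[Int, HashSet[Int]]

  def markPartitionCompleted(partitionId: Int, stageId: Int): Unit =
    finished.getOrElseUpdate(stageId, new HashSet[Int]).add(partitionId)

  def alreadyFinished(stageId: Int, partitionId: Int): Boolean =
    finished.get(stageId).exists(_.contains(partitionId))

  def main(args: Array[String]): Unit = {
    markPartitionCompleted(partitionId = 3, stageId = 1) // success from stage 1, attempt 0
    // A later attempt (attempt 1) of stage 1 consults the map before resubmitting partition 3.
    assert(alreadyFinished(stageId = 1, partitionId = 3))
    assert(!alreadyFinished(stageId = 1, partitionId = 0))
  }
}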
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala
index d551fb71a103..1d1295443884 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala
@@ -150,6 +150,8 @@ private[spark] class TaskSchedulerImpl(
 
   private[scheduler] var barrierCoordinator: RpcEndpoint = null
 
+  private[scheduler] val stageIdToFinishedPartitions = new HashMap[Int, HashSet[Int]]
+
   private def maybeInitBarrierCoordinator(): Unit = {
     if (barrierCoordinator == null) {
       barrierCoordinator = new BarrierCoordinator(barrierSyncTimeout, sc.listenerBus,
@@ -287,6 +289,10 @@ private[spark] class TaskSchedulerImpl(
     }
   }
 
+  override def markPartitionCompletedFromEventLoop(partitionId: Int, stageId: Int): Unit = {
+    stageIdToFinishedPartitions.getOrElseUpdate(stageId, new HashSet[Int]).add(partitionId)
+  }
+
   /**
    * Called to indicate that all task attempts (including speculated tasks) associated with the
    * given TaskSetManager have completed, so state associated with the TaskSetManager should be
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala
index b7bf06974fd5..b1f53beebd67 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala
@@ -808,7 +808,6 @@ private[spark] class TaskSetManager(
         if (tasksSuccessful == numTasks) {
           isZombie = true
         }
-        maybeFinishTaskSet()
       }
     }
   }
@@ -924,6 +923,9 @@ private[spark] class TaskSetManager(
         s" be re-executed (either because the task failed with a shuffle data fetch failure," +
         s" so the previous stage needs to be re-run, or because a different copy of the task" +
         s" has already succeeded).")
+    } else if (sched.stageIdToFinishedPartitions.get(stageId).exists(
+        partitions => partitions.contains(tasks(index).partitionId))) {
+      sched.markPartitionCompletedInAllTaskSets(stageId, tasks(index).partitionId, info)
     } else {
       addPendingTask(index)
     }
diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
index ed6a3d93b312..aecd4ba1605f 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
@@ -133,6 +133,8 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi
   /** Stages for which the DAGScheduler has called TaskScheduler.cancelTasks(). */
   val cancelledStages = new HashSet[Int]()
 
+  val completedPartitions = new HashMap[Int, HashSet[Int]]()
+
   val taskScheduler = new TaskScheduler() {
     override def schedulingMode: SchedulingMode = SchedulingMode.FIFO
     override def rootPool: Pool = new Pool("", schedulingMode, 0, 0)
@@ -160,6 +162,14 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi
     override def executorLost(executorId: String, reason: ExecutorLossReason): Unit = {}
     override def workerRemoved(workerId: String, host: String, message: String): Unit = {}
     override def applicationAttemptId(): Option[String] = None
+    // Since the method completeTasks in TaskSchedulerImpl.scala marks the partition as complete
+    // for all stage attempts of the given stage id, it does not need any info about
+    // stageAttemptId. Hence, completed partition ids are stored only per stage id to mock the
+    // method implementation here.
+    override def markPartitionCompletedFromEventLoop(partitionId: Int, stageId: Int): Unit = {
+      val partitionIds = completedPartitions.getOrElseUpdate(stageId, new HashSet[Int])
+      partitionIds.add(partitionId)
+    }
   }
 
   /** Length of time to wait while draining listener events. */
@@ -248,6 +258,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi
     cancelledStages.clear()
     cacheLocations.clear()
     results.clear()
+    completedPartitions.clear()
     securityMgr = new SecurityManager(conf)
     broadcastManager = new BroadcastManager(true, conf, securityMgr)
     mapOutputTracker = new MapOutputTrackerMaster(conf, broadcastManager, true) {
@@ -667,6 +678,7 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi
      override def executorLost(executorId: String, reason: ExecutorLossReason): Unit = {}
      override def workerRemoved(workerId: String, host: String, message: String): Unit = {}
      override def applicationAttemptId(): Option[String] = None
+      override def markPartitionCompletedFromEventLoop(partitionId: Int, stageId: Int): Unit = {}
     }
     val noKillScheduler = new DAGScheduler(
       sc,
@@ -2849,6 +2861,49 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi
     }
   }
 
+  // This test is similar to, and goes along with, "Completions in zombie tasksets update
+  // status of non-zombie taskset" in TaskSchedulerImplSuite.scala.
+  test("SPARK-25250: Late zombie task completions handled correctly even before" +
+    " new taskset launched") {
+    val shuffleMapRdd = new MyRDD(sc, 4, Nil)
+    val shuffleDep = new ShuffleDependency(shuffleMapRdd, new HashPartitioner(4))
+    val reduceRdd = new MyRDD(sc, 4, List(shuffleDep), tracker = mapOutputTracker)
+    submit(reduceRdd, Array(0, 1, 2, 3))
+
+    completeShuffleMapStageSuccessfully(0, 0, numShufflePartitions = 4)
+
+    // Fail stage 1, attempt 0, with a fetch failure.
+    runEvent(makeCompletionEvent(
+      taskSets(1).tasks(0),
+      FetchFailed(makeBlockManagerId("hostA"), shuffleDep.shuffleId, 0, 0, "ignored"),
+      null))
+
+    // This will trigger a resubmission of stage 0, since we've lost some of its
+    // map output, for the next iteration through the loop.
+    scheduler.resubmitFailedStages()
+    completeShuffleMapStageSuccessfully(0, 1, numShufflePartitions = 4)
+
+    // Tasksets 1 & 3 should be two different attempts for our reduce stage -- let's
+    // double-check the test setup.
+    val reduceStage = taskSets(1).stageId
+    assert(taskSets(3).stageId === reduceStage)
+
+    // Complete one task from the original taskset and make sure we update the TaskSchedulerImpl
+    // so it can notify all TaskSetManagers. Some of that is mocked out here; just check that the
+    // right event is recorded.
+    val taskToComplete = taskSets(1).tasks(3)
+
+    runEvent(makeCompletionEvent(taskToComplete, Success, Nil, Nil))
+    assert(completedPartitions.getOrElse(reduceStage, Set()) === Set(taskToComplete.partitionId))
+
+    // This will mark partition id 1 of stage 1 attempt 0 as complete, so we expect the status
+    // of that partition id to be reflected for stage 1 attempt 1 as well.
+    runEvent(makeCompletionEvent(
+      taskSets(1).tasks(1), Success, Nil, Nil))
+    assert(completedPartitions(reduceStage) === Set(
+      taskSets(3).tasks(1).partitionId, taskSets(3).tasks(3).partitionId))
+  }
+
   /**
    * Assert that the supplied TaskSet has exactly the given hosts as its preferred locations.
    * Note that this checks only the host and not the executor ID.
diff --git a/core/src/test/scala/org/apache/spark/scheduler/ExternalClusterManagerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/ExternalClusterManagerSuite.scala
index 30d0966691a3..8436e6928973 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/ExternalClusterManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/ExternalClusterManagerSuite.scala
@@ -94,4 +94,5 @@ private class DummyTaskScheduler extends TaskScheduler {
       accumUpdates: Array[(Long, Seq[AccumulatorV2[_, _]])],
       blockManagerId: BlockManagerId,
       executorMetrics: ExecutorMetrics): Boolean = true
+  override def markPartitionCompletedFromEventLoop(partitionId: Int, stageId: Int): Unit = {}
 }
diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala
index 016adb8b70e7..13efdf53a152 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala
@@ -31,7 +31,7 @@ import org.scalatest.mockito.MockitoSugar
 import org.apache.spark._
 import org.apache.spark.internal.Logging
 import org.apache.spark.internal.config
-import org.apache.spark.util.ManualClock
+import org.apache.spark.util.{AccumulatorV2, ManualClock}
 
 class FakeSchedulerBackend extends SchedulerBackend {
   def start() {}
@@ -125,6 +125,20 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext with B
        failedTaskSetReason = reason
        failedTaskSetException = exception
      }
+      override def taskEnded(
+          task: Task[_],
+          reason: TaskEndReason,
+          result: Any,
+          accumUpdates: Seq[AccumulatorV2[_, _]],
+          taskInfo: TaskInfo): Unit = {
+        if (reason == Success) {
+          // For SPARK-23433 / SPARK-25250, we need the DAGScheduler to let all tasksets know
+          // about completed partitions. The super implementation is not enough, because we've
+          // mocked out too much of the rest of the DAGScheduler.
+          taskScheduler.markPartitionCompletedFromEventLoop(task.partitionId, task.stageId)
+        }
+        super.taskEnded(task, reason, result, accumUpdates, taskInfo)
+      }
     }
     taskScheduler
   }
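Reviewer note, not part of the patch: the new DAGSchedulerSuite test exercises the scenario where a zombie attempt of the reduce stage keeps finishing tasks after a retry has been submitted. The standalone sketch below replays that scenario with hypothetical, simplified classes (Attempt is not a Spark type) to show why sharing one finished-partitions map per stage shrinks the pending work of the resubmitted attempt.

import scala.collection.mutable.{HashMap, HashSet}

// Illustrative only: two attempts of the same stage share one map of finished partitions.
object ZombieTaskSetSketch {
  val stageIdToFinishedPartitions = new HashMap[Int, HashSet[Int]]

  // A stage attempt, reduced to the set of partitions it still has to run.
  case class Attempt(stageId: Int, var pending: Set[Int])

  def taskSucceeded(stageId: Int, partitionId: Int, attempts: Seq[Attempt]): Unit = {
    stageIdToFinishedPartitions.getOrElseUpdate(stageId, new HashSet[Int]).add(partitionId)
    // Every attempt of the stage, zombie or not, drops the finished partition.
    attempts.filter(_.stageId == stageId).foreach(a => a.pending -= partitionId)
  }

  def main(args: Array[String]): Unit = {
    val attempt0 = Attempt(stageId = 1, pending = Set(0, 1, 2, 3)) // zombie after a fetch failure
    val attempt1 = Attempt(stageId = 1, pending = Set(0, 1, 2, 3)) // resubmitted attempt
    val attempts = Seq(attempt0, attempt1)

    taskSucceeded(1, 3, attempts) // late success from the zombie attempt
    taskSucceeded(1, 1, attempts)

    assert(stageIdToFinishedPartitions(1) == HashSet(1, 3))
    assert(attempt1.pending == Set(0, 2)) // attempt 1 only needs to run the remaining partitions
  }
}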