@@ -677,7 +677,7 @@ class DAGScheduler(
   }
 
   private[scheduler] def handleBeginEvent(task: Task[_], taskInfo: TaskInfo) {
-    val stageInfo = stageIdToStage(task.stageId).info
+    val stageInfo = stageIdToStage(task.stageId).latestInfo
     listenerBus.post(SparkListenerTaskStart(task.stageId, stageInfo.attemptId, taskInfo))
     submitWaitingStages()
   }
@@ -696,8 +696,8 @@ class DAGScheduler(
       // is in the process of getting stopped.
       val stageFailedMessage = "Stage cancelled because SparkContext was shut down"
       runningStages.foreach { stage =>
-        stage.info.stageFailed(stageFailedMessage)
-        listenerBus.post(SparkListenerStageCompleted(stage.info))
+        stage.latestInfo.stageFailed(stageFailedMessage)
+        listenerBus.post(SparkListenerStageCompleted(stage.latestInfo))
       }
       listenerBus.post(SparkListenerJobEnd(job.jobId, JobFailed(error)))
     }
@@ -782,7 +782,16 @@ class DAGScheduler(
     logDebug("submitMissingTasks(" + stage + ")")
     // Get our pending tasks and remember them in our pendingTasks entry
     stage.pendingTasks.clear()
-    var tasks = ArrayBuffer[Task[_]]()
+
+    // First figure out the indexes of partition ids to compute.
+    val partitionsToCompute: Seq[Int] = {
+      if (stage.isShuffleMap) {
+        (0 until stage.numPartitions).filter(id => stage.outputLocs(id) == Nil)
+      } else {
+        val job = stage.resultOfJob.get
+        (0 until job.numPartitions).filter(id => !job.finished(id))
+      }
+    }
 
     val properties = if (jobIdToActiveJob.contains(jobId)) {
       jobIdToActiveJob(stage.jobId).properties
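
Note, not part of the patch: a minimal sketch of what the two filters above compute, using made-up stand-ins for stage.outputLocs and job.finished. A shuffle-map stage recomputes every partition with no registered map output location; a result stage recomputes every job partition index not yet marked finished.

    // Toy stand-ins (assumed shapes) for stage.outputLocs and job.finished.
    val outputLocs = Array(List("hostA"), Nil, List("hostB"), Nil)
    val shuffleMapMissing = (0 until outputLocs.length).filter(id => outputLocs(id) == Nil)
    // shuffleMapMissing == Vector(1, 3)

    val finished = Array(true, false, true)
    val resultMissing = (0 until finished.length).filter(id => !finished(id))
    // resultMissing == Vector(1)
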
@@ -796,7 +805,8 @@ class DAGScheduler(
     // serializable. If tasks are not serializable, a SparkListenerStageCompleted event
     // will be posted, which should always come after a corresponding SparkListenerStageSubmitted
     // event.
-    listenerBus.post(SparkListenerStageSubmitted(stage.info, properties))
+    stage.latestInfo = StageInfo.fromStage(stage, Some(partitionsToCompute.size))
+    listenerBus.post(SparkListenerStageSubmitted(stage.latestInfo, properties))
 
     // TODO: Maybe we can keep the taskBinary in Stage to avoid serializing it multiple times.
     // Broadcasted binary for the task, used to dispatch tasks to executors. Note that we broadcast
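
Note, not part of the patch: because a fresh StageInfo is created for each submission and exposed as latestInfo, listeners can distinguish re-attempts of the same stage. A minimal sketch of a listener keyed by (stageId, attemptId), assuming the standard SparkListener callbacks and the attemptId field already used elsewhere in this patch:

    import scala.collection.mutable
    import org.apache.spark.scheduler._

    // Track which stage attempts are currently running.
    class StageAttemptTracker extends SparkListener {
      private val running = mutable.Set.empty[(Int, Int)]

      override def onStageSubmitted(event: SparkListenerStageSubmitted): Unit = {
        running += ((event.stageInfo.stageId, event.stageInfo.attemptId))
      }

      override def onStageCompleted(event: SparkListenerStageCompleted): Unit = {
        running -= ((event.stageInfo.stageId, event.stageInfo.attemptId))
      }
    }
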
@@ -827,25 +837,22 @@ class DAGScheduler(
         return
     }
 
-    if (stage.isShuffleMap) {
-      for (p <- 0 until stage.numPartitions if stage.outputLocs(p) == Nil) {
-        val locs = getPreferredLocs(stage.rdd, p)
-        val part = stage.rdd.partitions(p)
-        tasks += new ShuffleMapTask(stage.id, taskBinary, part, locs)
+    val tasks: Seq[Task[_]] = if (stage.isShuffleMap) {
+      partitionsToCompute.map { id =>
+        val locs = getPreferredLocs(stage.rdd, id)
+        val part = stage.rdd.partitions(id)
+        new ShuffleMapTask(stage.id, taskBinary, part, locs)
       }
     } else {
-      // This is a final stage; figure out its job's missing partitions
       val job = stage.resultOfJob.get
-      for (id <- 0 until job.numPartitions if !job.finished(id)) {
+      partitionsToCompute.map { id =>
         val p: Int = job.partitions(id)
         val part = stage.rdd.partitions(p)
         val locs = getPreferredLocs(stage.rdd, p)
-        tasks += new ResultTask(stage.id, taskBinary, part, locs, id)
+        new ResultTask(stage.id, taskBinary, part, locs, id)
       }
     }
 
-    stage.info = StageInfo.fromStage(stage, Some(tasks.size))
-
     if (tasks.size > 0) {
       // Preemptively serialize a task to make sure it can be serialized. We are catching this
       // exception here because it would be fairly hard to catch the non-serializable exception
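
Note, not part of the patch: in the result-stage branch above, id indexes into the partitions the job asked for, while job.partitions(id) is the actual RDD partition that gets computed. A toy illustration with hypothetical values:

    // The job requested RDD partitions 3, 7 and 9; only index 0 is finished.
    val jobPartitions = Array(3, 7, 9)
    val finished = Array(true, false, false)
    val partitionsToCompute = (0 until jobPartitions.length).filter(i => !finished(i))
    val rddPartitionIds = partitionsToCompute.map(i => jobPartitions(i))
    // partitionsToCompute == Vector(1, 2), rddPartitionIds == Vector(7, 9)
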
@@ -872,11 +879,11 @@ class DAGScheduler(
       logDebug("New pending tasks: " + stage.pendingTasks)
       taskScheduler.submitTasks(
         new TaskSet(tasks.toArray, stage.id, stage.newAttemptId(), stage.jobId, properties))
-      stage.info.submissionTime = Some(clock.getTime())
+      stage.latestInfo.submissionTime = Some(clock.getTime())
     } else {
       // Because we posted SparkListenerStageSubmitted earlier, we should post
       // SparkListenerStageCompleted here in case there are no tasks to run.
-      listenerBus.post(SparkListenerStageCompleted(stage.info))
+      listenerBus.post(SparkListenerStageCompleted(stage.latestInfo))
       logDebug("Stage " + stage + " is actually done; %b %d %d".format(
         stage.isAvailable, stage.numAvailableOutputs, stage.numPartitions))
       runningStages -= stage
@@ -890,7 +897,7 @@ class DAGScheduler(
   private[scheduler] def handleTaskCompletion(event: CompletionEvent) {
     val task = event.task
     val stageId = task.stageId
-    val stageInfo = stageIdToStage(task.stageId).info
+    val stageInfo = stageIdToStage(task.stageId).latestInfo
     val taskType = Utils.getFormattedClassName(task)
 
     // The success case is dealt with separately below, since we need to compute accumulator
@@ -906,14 +913,19 @@ class DAGScheduler(
     }
     val stage = stageIdToStage(task.stageId)
 
-    def markStageAsFinished(stage: Stage) = {
-      val serviceTime = stage.info.submissionTime match {
+    def markStageAsFinished(stage: Stage, isSuccessful: Boolean) = {
+      val serviceTime = stage.latestInfo.submissionTime match {
         case Some(t) => "%.03f".format((clock.getTime() - t) / 1000.0)
         case _ => "Unknown"
       }
-      logInfo("%s (%s) finished in %s s".format(stage, stage.name, serviceTime))
-      stage.info.completionTime = Some(clock.getTime())
-      listenerBus.post(SparkListenerStageCompleted(stage.info))
+      if (isSuccessful) {
+        logInfo("%s (%s) finished in %s s".format(stage, stage.name, serviceTime))
+      } else {
+
+        logInfo("%s (%s) failed in %s s".format(stage, stage.name, serviceTime))
+      }
+      stage.latestInfo.completionTime = Some(clock.getTime())
+      listenerBus.post(SparkListenerStageCompleted(stage.latestInfo))
       runningStages -= stage
     }
     event.reason match {
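
Note, not part of the patch: a quick worked example of the service-time string built in markStageAsFinished above, assuming clock values in milliseconds:

    // Millisecond timestamps formatted as seconds with three decimals.
    val submissionTime = 1000L
    val completionTime = 1534L
    val serviceTime = "%.03f".format((completionTime - submissionTime) / 1000.0)
    // serviceTime == "0.534"
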
@@ -928,7 +940,7 @@ class DAGScheduler(
                 val name = acc.name.get
                 val stringPartialValue = Accumulators.stringifyPartialValue(partialValue)
                 val stringValue = Accumulators.stringifyValue(acc.value)
-                stage.info.accumulables(id) = AccumulableInfo(id, name, stringValue)
+                stage.latestInfo.accumulables(id) = AccumulableInfo(id, name, stringValue)
                 event.taskInfo.accumulables +=
                   AccumulableInfo(id, name, Some(stringPartialValue), stringValue)
               }
@@ -951,7 +963,7 @@ class DAGScheduler(
                   job.numFinished += 1
                   // If the whole job has finished, remove it
                   if (job.numFinished == job.numPartitions) {
-                    markStageAsFinished(stage)
+                    markStageAsFinished(stage, isSuccessful = true)
                     cleanupStateForJobAndIndependentStages(job)
                     listenerBus.post(SparkListenerJobEnd(job.jobId, JobSucceeded))
                   }
@@ -980,7 +992,7 @@ class DAGScheduler(
               stage.addOutputLoc(smt.partitionId, status)
             }
             if (runningStages.contains(stage) && stage.pendingTasks.isEmpty) {
-              markStageAsFinished(stage)
+              markStageAsFinished(stage, isSuccessful = true)
               logInfo("looking for newly runnable stages")
               logInfo("running: " + runningStages)
               logInfo("waiting: " + waitingStages)
@@ -1033,7 +1045,7 @@ class DAGScheduler(
       case FetchFailed(bmAddress, shuffleId, mapId, reduceId) =>
         // Mark the stage that the reducer was in as unrunnable
         val failedStage = stageIdToStage(task.stageId)
-        listenerBus.post(SparkListenerStageCompleted(failedStage.info))
+        markStageAsFinished(failedStage, isSuccessful = false)
         runningStages -= failedStage
         // TODO: Cancel running tasks in the stage
         logInfo("Marking " + failedStage + " (" + failedStage.name +
@@ -1147,7 +1159,7 @@ class DAGScheduler(
     }
     val dependentJobs: Seq[ActiveJob] =
       activeJobs.filter(job => stageDependsOn(job.finalStage, failedStage)).toSeq
-    failedStage.info.completionTime = Some(clock.getTime())
+    failedStage.latestInfo.completionTime = Some(clock.getTime())
     for (job <- dependentJobs) {
       failJobAndIndependentStages(job, s"Job aborted due to stage failure: $reason")
     }
@@ -1187,8 +1199,8 @@ class DAGScheduler(
         if (runningStages.contains(stage)) {
           try { // cancelTasks will fail if a SchedulerBackend does not implement killTask
             taskScheduler.cancelTasks(stageId, shouldInterruptThread)
-            stage.info.stageFailed(failureReason)
-            listenerBus.post(SparkListenerStageCompleted(stage.info))
+            stage.latestInfo.stageFailed(failureReason)
+            listenerBus.post(SparkListenerStageCompleted(stage.latestInfo))
           } catch {
             case e: UnsupportedOperationException =>
               logInfo(s"Could not cancel tasks for stage $stageId", e)