Commit 6c08b07

Addressed code review feedback.
1 parent 4e5faa2 commit 6c08b07

File tree

5 files changed: +57 −40 lines

  core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
  core/src/main/scala/org/apache/spark/scheduler/Stage.scala
  core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala
  core/src/main/scala/org/apache/spark/ui/jobs/PoolTable.scala
  core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala

core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala

Lines changed: 42 additions & 30 deletions
@@ -677,7 +677,7 @@ class DAGScheduler(
   }

   private[scheduler] def handleBeginEvent(task: Task[_], taskInfo: TaskInfo) {
-    val stageInfo = stageIdToStage(task.stageId).info
+    val stageInfo = stageIdToStage(task.stageId).latestInfo
     listenerBus.post(SparkListenerTaskStart(task.stageId, stageInfo.attemptId, taskInfo))
     submitWaitingStages()
   }
@@ -696,8 +696,8 @@ class DAGScheduler(
       // is in the process of getting stopped.
       val stageFailedMessage = "Stage cancelled because SparkContext was shut down"
       runningStages.foreach { stage =>
-        stage.info.stageFailed(stageFailedMessage)
-        listenerBus.post(SparkListenerStageCompleted(stage.info))
+        stage.latestInfo.stageFailed(stageFailedMessage)
+        listenerBus.post(SparkListenerStageCompleted(stage.latestInfo))
       }
       listenerBus.post(SparkListenerJobEnd(job.jobId, JobFailed(error)))
     }
@@ -782,7 +782,16 @@ class DAGScheduler(
     logDebug("submitMissingTasks(" + stage + ")")
     // Get our pending tasks and remember them in our pendingTasks entry
     stage.pendingTasks.clear()
-    var tasks = ArrayBuffer[Task[_]]()
+
+    // First figure out the indexes of partition ids to compute.
+    val partitionsToCompute: Seq[Int] = {
+      if (stage.isShuffleMap) {
+        (0 until stage.numPartitions).filter(id => stage.outputLocs(id) == Nil)
+      } else {
+        val job = stage.resultOfJob.get
+        (0 until job.numPartitions).filter(id => !job.finished(id))
+      }
+    }

     val properties = if (jobIdToActiveJob.contains(jobId)) {
       jobIdToActiveJob(stage.jobId).properties
@@ -796,7 +805,8 @@
     // serializable. If tasks are not serializable, a SparkListenerStageCompleted event
     // will be posted, which should always come after a corresponding SparkListenerStageSubmitted
     // event.
-    listenerBus.post(SparkListenerStageSubmitted(stage.info, properties))
+    stage.latestInfo = StageInfo.fromStage(stage, Some(partitionsToCompute.size))
+    listenerBus.post(SparkListenerStageSubmitted(stage.latestInfo, properties))

     // TODO: Maybe we can keep the taskBinary in Stage to avoid serializing it multiple times.
     // Broadcasted binary for the task, used to dispatch tasks to executors. Note that we broadcast
@@ -827,25 +837,22 @@
         return
     }

-    if (stage.isShuffleMap) {
-      for (p <- 0 until stage.numPartitions if stage.outputLocs(p) == Nil) {
-        val locs = getPreferredLocs(stage.rdd, p)
-        val part = stage.rdd.partitions(p)
-        tasks += new ShuffleMapTask(stage.id, taskBinary, part, locs)
+    val tasks: Seq[Task[_]] = if (stage.isShuffleMap) {
+      partitionsToCompute.map { id =>
+        val locs = getPreferredLocs(stage.rdd, id)
+        val part = stage.rdd.partitions(id)
+        new ShuffleMapTask(stage.id, taskBinary, part, locs)
       }
     } else {
-      // This is a final stage; figure out its job's missing partitions
       val job = stage.resultOfJob.get
-      for (id <- 0 until job.numPartitions if !job.finished(id)) {
+      partitionsToCompute.map { id =>
         val p: Int = job.partitions(id)
         val part = stage.rdd.partitions(p)
         val locs = getPreferredLocs(stage.rdd, p)
-        tasks += new ResultTask(stage.id, taskBinary, part, locs, id)
+        new ResultTask(stage.id, taskBinary, part, locs, id)
       }
     }

-    stage.info = StageInfo.fromStage(stage, Some(tasks.size))
-
     if (tasks.size > 0) {
       // Preemptively serialize a task to make sure it can be serialized. We are catching this
       // exception here because it would be fairly hard to catch the non-serializable exception
@@ -872,11 +879,11 @@
       logDebug("New pending tasks: " + stage.pendingTasks)
       taskScheduler.submitTasks(
         new TaskSet(tasks.toArray, stage.id, stage.newAttemptId(), stage.jobId, properties))
-      stage.info.submissionTime = Some(clock.getTime())
+      stage.latestInfo.submissionTime = Some(clock.getTime())
     } else {
       // Because we posted SparkListenerStageSubmitted earlier, we should post
       // SparkListenerStageCompleted here in case there are no tasks to run.
-      listenerBus.post(SparkListenerStageCompleted(stage.info))
+      listenerBus.post(SparkListenerStageCompleted(stage.latestInfo))
       logDebug("Stage " + stage + " is actually done; %b %d %d".format(
         stage.isAvailable, stage.numAvailableOutputs, stage.numPartitions))
       runningStages -= stage
@@ -890,7 +897,7 @@
   private[scheduler] def handleTaskCompletion(event: CompletionEvent) {
     val task = event.task
     val stageId = task.stageId
-    val stageInfo = stageIdToStage(task.stageId).info
+    val stageInfo = stageIdToStage(task.stageId).latestInfo
     val taskType = Utils.getFormattedClassName(task)

     // The success case is dealt with separately below, since we need to compute accumulator
@@ -906,14 +913,19 @@
     }
     val stage = stageIdToStage(task.stageId)

-    def markStageAsFinished(stage: Stage) = {
-      val serviceTime = stage.info.submissionTime match {
+    def markStageAsFinished(stage: Stage, isSuccessful: Boolean) = {
+      val serviceTime = stage.latestInfo.submissionTime match {
         case Some(t) => "%.03f".format((clock.getTime() - t) / 1000.0)
         case _ => "Unknown"
       }
-      logInfo("%s (%s) finished in %s s".format(stage, stage.name, serviceTime))
-      stage.info.completionTime = Some(clock.getTime())
-      listenerBus.post(SparkListenerStageCompleted(stage.info))
+      if (isSuccessful) {
+        logInfo("%s (%s) finished in %s s".format(stage, stage.name, serviceTime))
+      } else {
+
+        logInfo("%s (%s) failed in %s s".format(stage, stage.name, serviceTime))
+      }
+      stage.latestInfo.completionTime = Some(clock.getTime())
+      listenerBus.post(SparkListenerStageCompleted(stage.latestInfo))
       runningStages -= stage
     }
     event.reason match {
@@ -928,7 +940,7 @@
             val name = acc.name.get
             val stringPartialValue = Accumulators.stringifyPartialValue(partialValue)
             val stringValue = Accumulators.stringifyValue(acc.value)
-            stage.info.accumulables(id) = AccumulableInfo(id, name, stringValue)
+            stage.latestInfo.accumulables(id) = AccumulableInfo(id, name, stringValue)
             event.taskInfo.accumulables +=
               AccumulableInfo(id, name, Some(stringPartialValue), stringValue)
           }
@@ -951,7 +963,7 @@
                 job.numFinished += 1
                 // If the whole job has finished, remove it
                 if (job.numFinished == job.numPartitions) {
-                  markStageAsFinished(stage)
+                  markStageAsFinished(stage, isSuccessful = true)
                   cleanupStateForJobAndIndependentStages(job)
                   listenerBus.post(SparkListenerJobEnd(job.jobId, JobSucceeded))
                 }
@@ -980,7 +992,7 @@
               stage.addOutputLoc(smt.partitionId, status)
             }
             if (runningStages.contains(stage) && stage.pendingTasks.isEmpty) {
-              markStageAsFinished(stage)
+              markStageAsFinished(stage, isSuccessful = true)
               logInfo("looking for newly runnable stages")
               logInfo("running: " + runningStages)
               logInfo("waiting: " + waitingStages)
@@ -1033,7 +1045,7 @@
      case FetchFailed(bmAddress, shuffleId, mapId, reduceId) =>
        // Mark the stage that the reducer was in as unrunnable
        val failedStage = stageIdToStage(task.stageId)
-       listenerBus.post(SparkListenerStageCompleted(failedStage.info))
+       markStageAsFinished(failedStage, isSuccessful = false)
        runningStages -= failedStage
        // TODO: Cancel running tasks in the stage
        logInfo("Marking " + failedStage + " (" + failedStage.name +
@@ -1147,7 +1159,7 @@
     }
     val dependentJobs: Seq[ActiveJob] =
       activeJobs.filter(job => stageDependsOn(job.finalStage, failedStage)).toSeq
-    failedStage.info.completionTime = Some(clock.getTime())
+    failedStage.latestInfo.completionTime = Some(clock.getTime())
     for (job <- dependentJobs) {
       failJobAndIndependentStages(job, s"Job aborted due to stage failure: $reason")
     }
@@ -1187,8 +1199,8 @@
       if (runningStages.contains(stage)) {
         try { // cancelTasks will fail if a SchedulerBackend does not implement killTask
           taskScheduler.cancelTasks(stageId, shouldInterruptThread)
-          stage.info.stageFailed(failureReason)
-          listenerBus.post(SparkListenerStageCompleted(stage.info))
+          stage.latestInfo.stageFailed(failureReason)
+          listenerBus.post(SparkListenerStageCompleted(stage.latestInfo))
         } catch {
           case e: UnsupportedOperationException =>
             logInfo(s"Could not cancel tasks for stage $stageId", e)
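
The structural change above: instead of appending to a mutable ArrayBuffer inside for loops, submitMissingTasks now derives partitionsToCompute once up front and maps it to tasks, which also lets StageInfo.fromStage receive an accurate task count before SparkListenerStageSubmitted is posted. Below is a minimal, self-contained sketch of that filter-then-map pattern; FakeTask and FakeStage are simplified stand-ins for illustration, not the real Spark classes.

// Simplified stand-ins, not the real Spark classes.
case class FakeTask(stageId: Int, partition: Int)

class FakeStage(val id: Int, val numPartitions: Int, completed: Set[Int]) {
  // Plays the role of stage.outputLocs(id) == Nil / !job.finished(id):
  // true when the partition still needs to be computed.
  def isMissing(partition: Int): Boolean = !completed(partition)
}

object PartitionsToComputeDemo extends App {
  val stage = new FakeStage(id = 0, numPartitions = 5, completed = Set(1, 3))

  // First figure out which partition ids to compute (filter)...
  val partitionsToCompute: Seq[Int] =
    (0 until stage.numPartitions).filter(stage.isMissing)

  // ...then build exactly one immutable task per missing partition (map).
  // The count is known before any task is built, which is what allows the
  // stage to be announced with an accurate number of tasks.
  val tasks: Seq[FakeTask] = partitionsToCompute.map(p => FakeTask(stage.id, p))

  println(partitionsToCompute) // Vector(0, 2, 4)
  println(tasks.size)          // 3
}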

core/src/main/scala/org/apache/spark/scheduler/Stage.scala

Lines changed: 5 additions & 2 deletions
@@ -43,6 +43,9 @@ import org.apache.spark.util.CallSite
  * stage, the callSite gives the user code that created the RDD being shuffled. For a result
  * stage, the callSite gives the user code that executes the associated action (e.g. count()).
  *
+ * A single stage can consist of multiple attempts. In that case, the latestInfo field will
+ * be updated for each attempt.
+ *
  */
 private[spark] class Stage(
     val id: Int,
@@ -71,8 +74,8 @@
   val name = callSite.shortForm
   val details = callSite.longForm

-  /** Pointer to the [StageInfo] object, set by DAGScheduler. */
-  var info: StageInfo = StageInfo.fromStage(this)
+  /** Pointer to the latest [StageInfo] object, set by DAGScheduler. */
+  var latestInfo: StageInfo = StageInfo.fromStage(this)

   def isAvailable: Boolean = {
     if (!isShuffleMap) {
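
The new doc comment is the core of this change: a Stage object is reused across attempts, so a single mutable info field would let a retry silently overwrite the record of the attempt that listeners already saw. A minimal sketch of the latest-attempt pattern follows; MiniStage and MiniStageInfo are hypothetical miniatures that echo the diff, not the real Stage/StageInfo classes.

// Hypothetical miniature of the Stage/StageInfo relationship.
case class MiniStageInfo(stageId: Int, attemptId: Int,
    var submissionTime: Option[Long] = None)

class MiniStage(val id: Int) {
  private var nextAttemptId = 0
  def newAttemptId(): Int = { val a = nextAttemptId; nextAttemptId += 1; a }

  // Like Stage.latestInfo: always points at the info of the newest attempt.
  var latestInfo: MiniStageInfo = MiniStageInfo(id, attemptId = 0)
}

object LatestInfoDemo extends App {
  val stage = new MiniStage(id = 7)

  // Each (re)submission installs a fresh info object instead of mutating the
  // old one, so a listener holding the previous attempt's info keeps a
  // consistent snapshot of that attempt.
  for (_ <- 1 to 2) {
    stage.latestInfo = MiniStageInfo(stage.id, stage.newAttemptId())
    stage.latestInfo.submissionTime = Some(System.currentTimeMillis())
  }
  println(stage.latestInfo.attemptId) // 1, i.e. the second attempt
}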

core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala

Lines changed: 6 additions & 5 deletions
@@ -46,13 +46,14 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging {
   // Map from stageId to StageInfo
   val activeStages = new HashMap[Int, StageInfo]

-  // Map from (stageId, attemptId) to StageInfo
+  // Map from (stageId, attemptId) to StageUIData
   val stageIdToData = new HashMap[(Int, Int), StageUIData]

   val completedStages = ListBuffer[StageInfo]()
   val failedStages = ListBuffer[StageInfo]()

-  val poolToActiveStages = HashMap[String, HashMap[(Int, Int), StageInfo]]()
+  // Map from pool name to a hash map (map from stage id to StageInfo).
+  val poolToActiveStages = HashMap[String, HashMap[Int, StageInfo]]()

   val executorIdToBlockManagerId = HashMap[String, BlockManagerId]()

@@ -72,7 +73,7 @@
     }

     poolToActiveStages.get(stageData.schedulingPool).foreach { hashMap =>
-      hashMap.remove((stage.stageId, stage.attemptId))
+      hashMap.remove(stage.stageId)
     }
     activeStages.remove(stage.stageId)
     if (stage.failureReason.isEmpty) {
@@ -109,8 +110,8 @@
       p => Option(p.getProperty(SparkContext.SPARK_JOB_DESCRIPTION))
     }

-    val stages = poolToActiveStages.getOrElseUpdate(poolName, new HashMap[(Int, Int), StageInfo]())
-    stages((stage.stageId, stage.attemptId)) = stage
+    val stages = poolToActiveStages.getOrElseUpdate(poolName, new HashMap[Int, StageInfo])
+    stages(stage.stageId) = stage
   }

   override def onTaskStart(taskStart: SparkListenerTaskStart) = synchronized {
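
Re-keying poolToActiveStages from (stageId, attemptId) to plain stageId assumes that only one attempt of a stage is expected to be active in a pool at a time, so a resubmitted stage replaces its failed predecessor's entry rather than sitting beside it. A small sketch of the resulting put/remove flow, with a plain mutable.HashMap standing in for the listener state and an attempt id standing in for StageInfo:

import scala.collection.mutable.HashMap

object PoolKeyDemo extends App {
  // Pool name -> (stage id -> attempt id).
  val poolToActiveStages = HashMap[String, HashMap[Int, Int]]()

  def onStageSubmitted(pool: String, stageId: Int, attemptId: Int): Unit = {
    val stages = poolToActiveStages.getOrElseUpdate(pool, new HashMap[Int, Int])
    // Keyed by stage id alone: a retry overwrites the failed attempt's entry.
    stages(stageId) = attemptId
  }

  def onStageCompleted(pool: String, stageId: Int): Unit = {
    poolToActiveStages.get(pool).foreach(_.remove(stageId))
  }

  onStageSubmitted("default", stageId = 3, attemptId = 0)
  onStageSubmitted("default", stageId = 3, attemptId = 1) // resubmission
  println(poolToActiveStages("default").size) // 1, not 2

  onStageCompleted("default", stageId = 3)
  println(poolToActiveStages("default").isEmpty) // true
}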

core/src/main/scala/org/apache/spark/ui/jobs/PoolTable.scala

Lines changed: 2 additions & 2 deletions
@@ -34,7 +34,7 @@ private[ui] class PoolTable(pools: Seq[Schedulable], parent: JobProgressTab) {
   }

   private def poolTable(
-      makeRow: (Schedulable, HashMap[String, HashMap[(Int, Int), StageInfo]]) => Seq[Node],
+      makeRow: (Schedulable, HashMap[String, HashMap[Int, StageInfo]]) => Seq[Node],
       rows: Seq[Schedulable]): Seq[Node] = {
     <table class="table table-bordered table-striped table-condensed sortable table-fixed">
       <thead>
@@ -53,7 +53,7 @@

   private def poolRow(
       p: Schedulable,
-      poolToActiveStages: HashMap[String, HashMap[(Int, Int), StageInfo]]): Seq[Node] = {
+      poolToActiveStages: HashMap[String, HashMap[Int, StageInfo]]): Seq[Node] = {
     val activeStages = poolToActiveStages.get(p.name) match {
       case Some(stages) => stages.size
       case None => 0

core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala

Lines changed: 2 additions & 1 deletion
@@ -43,7 +43,8 @@ private[ui] class StagePage(parent: JobProgressTab) extends WebUIPage("stage") {
           <h4>Summary Metrics</h4> No tasks have started yet
           <h4>Tasks</h4> No tasks have started yet
         </div>
-      return UIUtils.headerSparkPage("Details for Stage %s".format(stageId), content, parent)
+      return UIUtils.headerSparkPage(
+        s"Details for Stage $stageId (Attempt $stageAttemptId)", content, parent)
     }

     val stageData = stageDataOption.get
