
Commit 4e5faa2

[SPARK-2298] Encode stage attempt in SparkListener & UI.

1 parent 3a5962f
16 files changed, +141 −102 lines

core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala (11 additions, 6 deletions)

@@ -164,7 +164,7 @@ class DAGScheduler(
    */
   def executorHeartbeatReceived(
       execId: String,
-      taskMetrics: Array[(Long, Int, TaskMetrics)], // (taskId, stageId, metrics)
+      taskMetrics: Array[(Long, Int, Int, TaskMetrics)], // (taskId, stageId, stageAttemptId, metrics)
       blockManagerId: BlockManagerId): Boolean = {
     listenerBus.post(SparkListenerExecutorMetricsUpdate(execId, taskMetrics))
     implicit val timeout = Timeout(600 seconds)
@@ -677,7 +677,8 @@ class DAGScheduler(
   }

   private[scheduler] def handleBeginEvent(task: Task[_], taskInfo: TaskInfo) {
-    listenerBus.post(SparkListenerTaskStart(task.stageId, taskInfo))
+    val stageInfo = stageIdToStage(task.stageId).info
+    listenerBus.post(SparkListenerTaskStart(task.stageId, stageInfo.attemptId, taskInfo))
     submitWaitingStages()
   }

@@ -843,6 +844,8 @@ class DAGScheduler(
       }
     }

+    stage.info = StageInfo.fromStage(stage, Some(tasks.size))
+
     if (tasks.size > 0) {
       // Preemptively serialize a task to make sure it can be serialized. We are catching this
       // exception here because it would be fairly hard to catch the non-serializable exception
@@ -887,13 +890,14 @@ class DAGScheduler(
   private[scheduler] def handleTaskCompletion(event: CompletionEvent) {
     val task = event.task
     val stageId = task.stageId
+    val stageInfo = stageIdToStage(task.stageId).info
     val taskType = Utils.getFormattedClassName(task)

     // The success case is dealt with separately below, since we need to compute accumulator
     // updates before posting.
     if (event.reason != Success) {
-      listenerBus.post(SparkListenerTaskEnd(stageId, taskType, event.reason, event.taskInfo,
-        event.taskMetrics))
+      listenerBus.post(SparkListenerTaskEnd(stageId, stageInfo.attemptId, taskType, event.reason,
+        event.taskInfo, event.taskMetrics))
     }

     if (!stageIdToStage.contains(task.stageId)) {
@@ -935,8 +939,8 @@ class DAGScheduler(
             logError(s"Failed to update accumulators for $task", e)
           }
         }
-        listenerBus.post(SparkListenerTaskEnd(stageId, taskType, event.reason, event.taskInfo,
-          event.taskMetrics))
+        listenerBus.post(SparkListenerTaskEnd(stageId, stageInfo.attemptId, taskType, event.reason,
+          event.taskInfo, event.taskMetrics))
         stage.pendingTasks -= task
         task match {
           case rt: ResultTask[_, _] =>
@@ -1029,6 +1033,7 @@ class DAGScheduler(
       case FetchFailed(bmAddress, shuffleId, mapId, reduceId) =>
         // Mark the stage that the reducer was in as unrunnable
         val failedStage = stageIdToStage(task.stageId)
+        listenerBus.post(SparkListenerStageCompleted(failedStage.info))
         runningStages -= failedStage
         // TODO: Cancel running tasks in the stage
         logInfo("Marking " + failedStage + " (" + failedStage.name +

core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala (9 additions, 2 deletions)

@@ -39,14 +39,16 @@ case class SparkListenerStageSubmitted(stageInfo: StageInfo, properties: Properties)
 case class SparkListenerStageCompleted(stageInfo: StageInfo) extends SparkListenerEvent

 @DeveloperApi
-case class SparkListenerTaskStart(stageId: Int, taskInfo: TaskInfo) extends SparkListenerEvent
+case class SparkListenerTaskStart(stageId: Int, stageAttemptId: Int, taskInfo: TaskInfo)
+  extends SparkListenerEvent

 @DeveloperApi
 case class SparkListenerTaskGettingResult(taskInfo: TaskInfo) extends SparkListenerEvent

 @DeveloperApi
 case class SparkListenerTaskEnd(
     stageId: Int,
+    stageAttemptId: Int,
     taskType: String,
     reason: TaskEndReason,
     taskInfo: TaskInfo,
@@ -75,10 +77,15 @@
 @DeveloperApi
 case class SparkListenerUnpersistRDD(rddId: Int) extends SparkListenerEvent

+/**
+ * Periodic updates from executors.
+ * @param execId executor id
+ * @param taskMetrics sequence of (task id, stage id, stage attempt, metrics)
+ */
 @DeveloperApi
 case class SparkListenerExecutorMetricsUpdate(
     execId: String,
-    taskMetrics: Seq[(Long, Int, TaskMetrics)])
+    taskMetrics: Seq[(Long, Int, Int, TaskMetrics)])
   extends SparkListenerEvent

 @DeveloperApi
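Since these events are case classes, adding a field changes their constructors and extractors, so listener code that constructs or pattern-matches them must be updated and recompiled. A hedged sketch of a listener consuming the new fields (the class name is illustrative):

import org.apache.spark.scheduler.{SparkListener, SparkListenerExecutorMetricsUpdate, SparkListenerTaskEnd}

class AttemptAwareListener extends SparkListener {
  override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit =
    println(s"task ended in stage ${taskEnd.stageId}, attempt ${taskEnd.stageAttemptId}")

  override def onExecutorMetricsUpdate(update: SparkListenerExecutorMetricsUpdate): Unit = {
    // The tuples gained a stage-attempt element: (taskId, stageId, stageAttemptId, metrics).
    for ((taskId, stageId, stageAttemptId, _) <- update.taskMetrics) {
      println(s"heartbeat metrics: task $taskId, stage $stageId, attempt $stageAttemptId")
    }
  }
}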

core/src/main/scala/org/apache/spark/scheduler/Stage.scala (1 addition, 0 deletions)

@@ -116,6 +116,7 @@ private[spark] class Stage(
     }
   }

+  /** Return a new attempt id, starting with 0. */
   def newAttemptId(): Int = {
     val id = nextAttemptId
     nextAttemptId += 1
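The documented contract matters for consumers: attempt ids are per-stage counters starting at 0, so an attempt id alone is ambiguous and only the (stageId, attemptId) pair is unique, which is exactly how the UI maps below are keyed. A standalone illustration (plain Scala, not Spark code):

// Each stage counts its own attempts from 0.
class MiniStage(val id: Int) {
  private var nextAttemptId = 0
  def newAttemptId(): Int = {
    val attemptId = nextAttemptId
    nextAttemptId += 1
    attemptId
  }
}

object AttemptIdDemo extends App {
  val s1 = new MiniStage(1)
  val s2 = new MiniStage(2)
  // Prints List(0, 1, 0): attempt 0 of stage 2 would collide with attempt 0
  // of stage 1 unless the stage id is part of the key.
  println(List(s1.newAttemptId(), s1.newAttemptId(), s2.newAttemptId()))
}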

core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala (9 additions, 2 deletions)

@@ -29,6 +29,7 @@ import org.apache.spark.storage.RDDInfo
 @DeveloperApi
 class StageInfo(
     val stageId: Int,
+    val attemptId: Int,
     val name: String,
     val numTasks: Int,
     val rddInfos: Seq[RDDInfo],
@@ -56,9 +57,15 @@ private[spark] object StageInfo {
    * shuffle dependencies. Therefore, all ancestor RDDs related to this Stage's RDD through a
    * sequence of narrow dependencies should also be associated with this Stage.
    */
-  def fromStage(stage: Stage): StageInfo = {
+  def fromStage(stage: Stage, numTasks: Option[Int] = None): StageInfo = {
     val ancestorRddInfos = stage.rdd.getNarrowAncestors.map(RDDInfo.fromRdd)
     val rddInfos = Seq(RDDInfo.fromRdd(stage.rdd)) ++ ancestorRddInfos
-    new StageInfo(stage.id, stage.name, stage.numTasks, rddInfos, stage.details)
+    new StageInfo(
+      stage.id,
+      stage.attemptId,
+      stage.name,
+      numTasks.getOrElse(stage.numTasks),
+      rddInfos,
+      stage.details)
   }
 }
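The new numTasks parameter lets DAGScheduler.submitMissingTasks record how many tasks this particular attempt actually launches (a retry may rerun only the missing partitions), falling back to the stage's static task count. A simplified standalone mirror of the signature, with made-up values (not the Spark classes):

case class DemoStage(id: Int, attemptId: Int, name: String, numTasks: Int)
case class DemoStageInfo(stageId: Int, attemptId: Int, name: String, numTasks: Int)

object DemoStageInfo {
  // Same default-argument pattern as StageInfo.fromStage: the caller may
  // override the task count for this attempt.
  def fromStage(stage: DemoStage, numTasks: Option[Int] = None): DemoStageInfo =
    DemoStageInfo(stage.id, stage.attemptId, stage.name, numTasks.getOrElse(stage.numTasks))
}

object FromStageDemo extends App {
  val stage = DemoStage(id = 1, attemptId = 1, name = "map at Demo.scala:10", numTasks = 8)
  println(DemoStageInfo.fromStage(stage))          // this attempt launches all 8 tasks
  println(DemoStageInfo.fromStage(stage, Some(3))) // retry launches only 3 missing tasks
}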

core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala (4 additions, 4 deletions)

@@ -333,12 +333,12 @@ private[spark] class TaskSchedulerImpl(
       execId: String,
       taskMetrics: Array[(Long, TaskMetrics)], // taskId -> TaskMetrics
       blockManagerId: BlockManagerId): Boolean = {
-    val metricsWithStageIds = taskMetrics.flatMap {
-      case (id, metrics) => {
+
+    val metricsWithStageIds: Array[(Long, Int, Int, TaskMetrics)] = synchronized {
+      taskMetrics.flatMap { case (id, metrics) =>
         taskIdToTaskSetId.get(id)
           .flatMap(activeTaskSets.get)
-          .map(_.stageId)
-          .map(x => (id, x, metrics))
+          .map(taskSetMgr => (id, taskSetMgr.stageId, taskSetMgr.taskSet.attempt, metrics))
       }
     }
     dagScheduler.executorHeartbeatReceived(execId, metricsWithStageIds, blockManagerId)
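Two things change here: each tuple now carries the task set's attempt number, and the lookup is wrapped in synchronized so taskIdToTaskSetId and activeTaskSets cannot be mutated mid-heartbeat. The Option-based flatMap also silently drops metrics for tasks whose task set has already finished. The pattern in isolation (plain Scala, made-up data):

object HeartbeatEnrichDemo extends App {
  // Stand-ins for the scheduler's real maps:
  // taskId -> taskSetId, and taskSetId -> (stageId, attempt).
  val taskIdToTaskSetId = Map(7L -> "3.0")
  val activeTaskSets = Map("3.0" -> (3, 0))

  val heartbeat = Seq(7L -> "metrics-7", 9L -> "metrics-9")
  val enriched = heartbeat.flatMap { case (id, metrics) =>
    taskIdToTaskSetId.get(id)
      .flatMap(activeTaskSets.get)
      .map { case (stageId, attempt) => (id, stageId, attempt, metrics) }
  }
  // Prints List((7,3,0,metrics-7)): task 9 has no active task set, so its
  // metrics are dropped rather than throwing.
  println(enriched)
}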

core/src/main/scala/org/apache/spark/scheduler/TaskSet.scala (0 additions, 4 deletions)

@@ -31,9 +31,5 @@ private[spark] class TaskSet(
     val properties: Properties) {
   val id: String = stageId + "." + attempt

-  def kill(interruptThread: Boolean) {
-    tasks.foreach(_.kill(interruptThread))
-  }
-
   override def toString: String = "TaskSet " + id
 }

core/src/main/scala/org/apache/spark/ui/jobs/ExecutorTable.scala (3 additions, 3 deletions)

@@ -24,8 +24,8 @@ import org.apache.spark.ui.{ToolTips, UIUtils}
 import org.apache.spark.ui.jobs.UIData.StageUIData
 import org.apache.spark.util.Utils

-/** Page showing executor summary */
-private[ui] class ExecutorTable(stageId: Int, parent: JobProgressTab) {
+/** Stage summary grouped by executors. */
+private[ui] class ExecutorTable(stageId: Int, stageAttemptId: Int, parent: JobProgressTab) {
   private val listener = parent.listener

   def toNodeSeq: Seq[Node] = {
@@ -65,7 +65,7 @@ private[ui] class ExecutorTable(stageId: Int, parent: JobProgressTab) {
       executorIdToAddress.put(executorId, address)
     }

-    listener.stageIdToData.get(stageId) match {
+    listener.stageIdToData.get((stageId, stageAttemptId)) match {
       case Some(stageData: StageUIData) =>
         stageData.executorSummary.toSeq.sortBy(_._1).map { case (k, v) =>
           <tr>

core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala (21 additions, 17 deletions)

@@ -43,13 +43,16 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging {
   // How many stages to remember
   val retainedStages = conf.getInt("spark.ui.retainedStages", DEFAULT_RETAINED_STAGES)

-  val activeStages = HashMap[Int, StageInfo]()
+  // Map from stageId to StageInfo
+  val activeStages = new HashMap[Int, StageInfo]
+
+  // Map from (stageId, attemptId) to StageUIData
+  val stageIdToData = new HashMap[(Int, Int), StageUIData]
+
   val completedStages = ListBuffer[StageInfo]()
   val failedStages = ListBuffer[StageInfo]()

-  val stageIdToData = new HashMap[Int, StageUIData]
-
-  val poolToActiveStages = HashMap[String, HashMap[Int, StageInfo]]()
+  val poolToActiveStages = HashMap[String, HashMap[(Int, Int), StageInfo]]()

   val executorIdToBlockManagerId = HashMap[String, BlockManagerId]()
@@ -59,18 +62,19 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging {

   override def onStageCompleted(stageCompleted: SparkListenerStageCompleted) = synchronized {
     val stage = stageCompleted.stageInfo
-    val stageId = stage.stageId
-    val stageData = stageIdToData.getOrElseUpdate(stageId, {
-      logWarning("Stage completed for unknown stage " + stageId)
+    val stageData = stageIdToData.getOrElseUpdate((stage.stageId, stage.attemptId), {
+      logWarning("Stage completed for unknown stage " + stage.stageId)
       new StageUIData
     })

     for ((id, info) <- stageCompleted.stageInfo.accumulables) {
       stageData.accumulables(id) = info
     }

-    poolToActiveStages.get(stageData.schedulingPool).foreach(_.remove(stageId))
-    activeStages.remove(stageId)
+    poolToActiveStages.get(stageData.schedulingPool).foreach { hashMap =>
+      hashMap.remove((stage.stageId, stage.attemptId))
+    }
+    activeStages.remove(stage.stageId)
     if (stage.failureReason.isEmpty) {
       completedStages += stage
       trimIfNecessary(completedStages)
@@ -84,7 +88,7 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging {
   private def trimIfNecessary(stages: ListBuffer[StageInfo]) = synchronized {
     if (stages.size > retainedStages) {
       val toRemove = math.max(retainedStages / 10, 1)
-      stages.take(toRemove).foreach { s => stageIdToData.remove(s.stageId) }
+      stages.take(toRemove).foreach { s => stageIdToData.remove((s.stageId, s.attemptId)) }
       stages.trimStart(toRemove)
     }
   }
@@ -98,21 +102,21 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging {
       p => p.getProperty("spark.scheduler.pool", DEFAULT_POOL_NAME)
     }.getOrElse(DEFAULT_POOL_NAME)

-    val stageData = stageIdToData.getOrElseUpdate(stage.stageId, new StageUIData)
+    val stageData = stageIdToData.getOrElseUpdate((stage.stageId, stage.attemptId), new StageUIData)
     stageData.schedulingPool = poolName

     stageData.description = Option(stageSubmitted.properties).flatMap {
       p => Option(p.getProperty(SparkContext.SPARK_JOB_DESCRIPTION))
     }

-    val stages = poolToActiveStages.getOrElseUpdate(poolName, new HashMap[Int, StageInfo]())
-    stages(stage.stageId) = stage
+    val stages = poolToActiveStages.getOrElseUpdate(poolName, new HashMap[(Int, Int), StageInfo]())
+    stages((stage.stageId, stage.attemptId)) = stage
   }

   override def onTaskStart(taskStart: SparkListenerTaskStart) = synchronized {
     val taskInfo = taskStart.taskInfo
     if (taskInfo != null) {
-      val stageData = stageIdToData.getOrElseUpdate(taskStart.stageId, {
+      val stageData = stageIdToData.getOrElseUpdate((taskStart.stageId, taskStart.stageAttemptId), {
         logWarning("Task start for unknown stage " + taskStart.stageId)
         new StageUIData
       })
@@ -129,7 +133,7 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging {
   override def onTaskEnd(taskEnd: SparkListenerTaskEnd) = synchronized {
     val info = taskEnd.taskInfo
     if (info != null) {
-      val stageData = stageIdToData.getOrElseUpdate(taskEnd.stageId, {
+      val stageData = stageIdToData.getOrElseUpdate((taskEnd.stageId, taskEnd.stageAttemptId), {
         logWarning("Task end for unknown stage " + taskEnd.stageId)
         new StageUIData
       })
@@ -222,8 +226,8 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging {
   }

   override def onExecutorMetricsUpdate(executorMetricsUpdate: SparkListenerExecutorMetricsUpdate) {
-    for ((taskId, sid, taskMetrics) <- executorMetricsUpdate.taskMetrics) {
-      val stageData = stageIdToData.getOrElseUpdate(sid, {
+    for ((taskId, sid, sAttempt, taskMetrics) <- executorMetricsUpdate.taskMetrics) {
+      val stageData = stageIdToData.getOrElseUpdate((sid, sAttempt), {
         logWarning("Metrics update for task in unknown stage " + sid)
         new StageUIData
       })
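Keying stageIdToData by (stageId, attemptId) is the heart of the UI fix: previously a retried stage wrote into the same entry as its earlier attempt, mixing their counters. A minimal demonstration of the difference, outside Spark (names are illustrative):

import scala.collection.mutable

object CompositeKeyDemo extends App {
  // Per-attempt UI state, in the spirit of StageUIData (simplified).
  case class UIState(var tasksEnded: Int = 0)

  val byIdOnly = mutable.HashMap[Int, UIState]()
  val byIdAndAttempt = mutable.HashMap[(Int, Int), UIState]()

  // One task ends in each of two attempts of stage 3.
  for ((stageId, attemptId) <- Seq((3, 0), (3, 1))) {
    byIdOnly.getOrElseUpdate(stageId, UIState()).tasksEnded += 1
    byIdAndAttempt.getOrElseUpdate((stageId, attemptId), UIState()).tasksEnded += 1
  }

  println(byIdOnly)       // single entry: the two attempts are merged together
  println(byIdAndAttempt) // one entry per attempt, counted separately
}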

core/src/main/scala/org/apache/spark/ui/jobs/PoolTable.scala (2 additions, 2 deletions)

@@ -34,7 +34,7 @@ private[ui] class PoolTable(pools: Seq[Schedulable], parent: JobProgressTab) {
   }

   private def poolTable(
-      makeRow: (Schedulable, HashMap[String, HashMap[Int, StageInfo]]) => Seq[Node],
+      makeRow: (Schedulable, HashMap[String, HashMap[(Int, Int), StageInfo]]) => Seq[Node],
       rows: Seq[Schedulable]): Seq[Node] = {
     <table class="table table-bordered table-striped table-condensed sortable table-fixed">
       <thead>
@@ -53,7 +53,7 @@ private[ui] class PoolTable(pools: Seq[Schedulable], parent: JobProgressTab) {

   private def poolRow(
       p: Schedulable,
-      poolToActiveStages: HashMap[String, HashMap[Int, StageInfo]]): Seq[Node] = {
+      poolToActiveStages: HashMap[String, HashMap[(Int, Int), StageInfo]]): Seq[Node] = {
     val activeStages = poolToActiveStages.get(p.name) match {
       case Some(stages) => stages.size
       case None => 0

core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala (5 additions, 3 deletions)

@@ -34,7 +34,8 @@ private[ui] class StagePage(parent: JobProgressTab) extends WebUIPage("stage") {
   def render(request: HttpServletRequest): Seq[Node] = {
     listener.synchronized {
       val stageId = request.getParameter("id").toInt
-      val stageDataOption = listener.stageIdToData.get(stageId)
+      val stageAttemptId = request.getParameter("attempt").toInt
+      val stageDataOption = listener.stageIdToData.get((stageId, stageAttemptId))

       if (stageDataOption.isEmpty || stageDataOption.get.taskData.isEmpty) {
         val content =
@@ -49,7 +50,7 @@ private[ui] class StagePage(parent: JobProgressTab) extends WebUIPage("stage") {
       val tasks = stageData.taskData.values.toSeq.sortBy(_.taskInfo.launchTime)

       val numCompleted = tasks.count(_.taskInfo.finished)
-      val accumulables = listener.stageIdToData(stageId).accumulables
+      val accumulables = listener.stageIdToData((stageId, stageAttemptId)).accumulables
       val hasInput = stageData.inputBytes > 0
       val hasShuffleRead = stageData.shuffleReadBytes > 0
       val hasShuffleWrite = stageData.shuffleWriteBytes > 0
@@ -211,7 +212,8 @@ private[ui] class StagePage(parent: JobProgressTab) extends WebUIPage("stage") {
       def quantileRow(data: Seq[Node]): Seq[Node] = <tr>{data}</tr>
       Some(UIUtils.listingTable(quantileHeaders, quantileRow, listings, fixedWidth = true))
     }
-    val executorTable = new ExecutorTable(stageId, parent)
+
+    val executorTable = new ExecutorTable(stageId, stageAttemptId, parent)

     val maybeAccumulableTable: Seq[Node] =
       if (accumulables.size > 0) { <h4>Accumulators</h4> ++ accumulableTable } else Seq()
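Because render now reads a mandatory attempt parameter, every link into the stage page must carry it; a request without it fails at request.getParameter("attempt").toInt. A hypothetical link-building helper (the "id" and "attempt" parameter names come from StagePage.render above; the helper itself and the assumed /stages/stage/ path are not part of this commit):

object StageLinks {
  // Builds a per-attempt stage page URL, e.g. for links in UI tables.
  def stagePageUrl(uiRoot: String, stageId: Int, stageAttemptId: Int): String =
    s"$uiRoot/stages/stage/?id=$stageId&attempt=$stageAttemptId"
}

// StageLinks.stagePageUrl("http://localhost:4040", 3, 1)
//   => http://localhost:4040/stages/stage/?id=3&attempt=1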
