Skip to content

Commit b928cd4

Browse files
author
Andrew Or
committed
Fix potential leak + write tests for it
Previously there was an opportunity to leak in `completedStageIds`. This commit fixes that and includes tests that ensure it is fixed.
1 parent 7c4c364 commit b928cd4

File tree

2 files changed

+216
-57
lines changed

2 files changed

+216
-57
lines changed

core/src/main/scala/org/apache/spark/ui/scope/RDDOperationGraphListener.scala

Lines changed: 30 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -82,46 +82,60 @@ private[ui] class RDDOperationGraphListener(conf: SparkConf) extends SparkListen
8282
stageIds += stageId
8383
stageIdToJobId(stageId) = jobId
8484
stageIdToGraph(stageId) = RDDOperationGraph.makeOperationGraph(stageInfo)
85-
86-
// Remove state for old stages if necessary
87-
if (stageIds.size >= retainedStages) {
88-
val toRemove = math.max(retainedStages / 10, 1)
89-
stageIds.take(toRemove).foreach { id => cleanStage(id) }
90-
stageIds.trimStart(toRemove)
91-
}
85+
trimStagesIfNecessary()
9286
}
9387

94-
// Remove state for old jobs if necessary
95-
if (jobIds.size >= retainedJobs) {
96-
val toRemove = math.max(retainedJobs / 10, 1)
97-
jobIds.take(toRemove).foreach { id => cleanJob(id) }
98-
jobIds.trimStart(toRemove)
99-
}
88+
trimJobsIfNecessary()
10089
}
10190

10291
/** Keep track of stages that have completed. */
override def onStageCompleted(stageCompleted: SparkListenerStageCompleted): Unit = synchronized {
  val stageId = stageCompleted.stageInfo.stageId
  // Note: Only record completion if the stage has not already been cleaned up.
  // Otherwise, we may never clean this stage from `completedStageIds` (cleanup
  // is driven off `stageIdToJobId`), leaking the entry forever.
  if (stageIdToJobId.contains(stageId)) {
    // Use the extracted local rather than re-dereferencing `stageCompleted.stageInfo.stageId`
    completedStageIds += stageId
  }
}
106100

107101
/** On job end, find all stages in this job that are skipped and mark them as such. */
override def onJobEnd(jobEnd: SparkListenerJobEnd): Unit = synchronized {
  val jobId = jobEnd.jobId
  jobIdToStageIds.get(jobId).foreach { stageIds =>
    // A stage is "skipped" if it never completed before its job ended.
    // Note: Only do this if the job has not already been cleaned up;
    // otherwise we may never clean this job from `jobIdToSkippedStageIds`.
    jobIdToSkippedStageIds(jobId) = stageIds.filterNot(completedStageIds.contains)
  }
}
115111

112+
/** Clean metadata for old stages if we have exceeded the number to retain. */
private def trimStagesIfNecessary(): Unit = {
  if (stageIds.size >= retainedStages) {
    // Evict the oldest ~10% (at least one) so cleanup cost is amortized
    val evictCount = math.max(retainedStages / 10, 1)
    stageIds.take(evictCount).foreach(cleanStage)
    stageIds.trimStart(evictCount)
  }
}
120+
121+
/** Clean metadata for old jobs if we have exceeded the number to retain. */
private def trimJobsIfNecessary(): Unit = {
  if (jobIds.size >= retainedJobs) {
    // Evict the oldest ~10% (at least one) so cleanup cost is amortized
    val evictCount = math.max(retainedJobs / 10, 1)
    jobIds.take(evictCount).foreach(cleanJob)
    jobIds.trimStart(evictCount)
  }
}
129+
116130
/** Clean metadata for the given stage, its job, and all other stages that belong to the job. */
private[ui] def cleanStage(stageId: Int): Unit = {
  completedStageIds.remove(stageId)
  stageIdToGraph.remove(stageId)
  // Removing the stage -> job mapping also tears down the owning job, which in
  // turn cleans up any sibling stages (fate sharing between a job and its stages).
  stageIdToJobId.remove(stageId).foreach(cleanJob)
}
122136

123137
/** Clean metadata for the given job and all stages that belong to it. */
124-
private def cleanJob(jobId: Int): Unit = {
138+
private[ui] def cleanJob(jobId: Int): Unit = {
125139
jobIdToSkippedStageIds.remove(jobId)
126140
jobIdToStageIds.remove(jobId).foreach { stageIds =>
127141
stageIds.foreach { stageId => cleanStage(stageId) }

core/src/test/scala/org/apache/spark/ui/scope/RDDOperationGraphListenerSuite.scala

Lines changed: 186 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -20,67 +20,212 @@ package org.apache.spark.ui.scope
2020
import org.scalatest.FunSuite
2121

2222
import org.apache.spark.SparkConf
23-
import org.apache.spark.scheduler.{SparkListenerJobStart, SparkListenerStageSubmitted, StageInfo}
23+
import org.apache.spark.scheduler._
24+
import org.apache.spark.scheduler.SparkListenerStageSubmitted
25+
import org.apache.spark.scheduler.SparkListenerStageCompleted
26+
import org.apache.spark.scheduler.SparkListenerJobStart
2427

28+
/**
29+
* Tests that this listener populates and cleans up its data structures properly.
30+
*/
2531
class RDDOperationGraphListenerSuite extends FunSuite {
2632
private var jobIdCounter = 0
2733
private var stageIdCounter = 0
34+
private val maxRetainedJobs = 10
35+
private val maxRetainedStages = 10
36+
private val conf = new SparkConf()
37+
.set("spark.ui.retainedJobs", maxRetainedJobs.toString)
38+
.set("spark.ui.retainedStages", maxRetainedStages.toString)
2839

29-
/** Run a job with the specified number of stages. */
30-
private def runOneJob(numStages: Int, listener: RDDOperationGraphListener): Unit = {
31-
assert(numStages > 0, "I will not run a job with 0 stages for you.")
32-
val stageInfos = (0 until numStages).map { _ =>
33-
val stageInfo = new StageInfo(stageIdCounter, 0, "s", 0, Seq.empty, Seq.empty, "d")
34-
stageIdCounter += 1
35-
stageInfo
36-
}
37-
listener.onJobStart(new SparkListenerJobStart(jobIdCounter, 0, stageInfos))
38-
jobIdCounter += 1
39-
}
40-
41-
test("listener cleans up metadata") {
42-
43-
val conf = new SparkConf()
44-
.set("spark.ui.retainedStages", "10")
45-
.set("spark.ui.retainedJobs", "10")
46-
40+
test("run normal jobs") {
  val startingJobId = jobIdCounter
  val startingStageId = stageIdCounter
  val listener = new RDDOperationGraphListener(conf)
  // A fresh listener starts with every data structure empty
  assert(listener.jobIdToStageIds.isEmpty)
  assert(listener.jobIdToSkippedStageIds.isEmpty)
  assert(listener.stageIdToJobId.isEmpty)
  assert(listener.stageIdToGraph.isEmpty)
  assert(listener.completedStageIds.isEmpty)
  assert(listener.jobIds.isEmpty)
  assert(listener.stageIds.isEmpty)

  // Run a few jobs, but not enough to trigger clean up yet
  for (numStages <- 1 to 3) startJob(numStages, listener) // start 3 jobs and 6 stages
  for (i <- 0 to 5) endStage(startingStageId + i, listener) // finish all 6 stages
  for (i <- 0 to 2) endJob(startingJobId + i, listener) // finish all 3 jobs

  assert(listener.jobIdToStageIds.size === 3)
  assert(listener.jobIdToStageIds(startingJobId).size === 1)
  assert(listener.jobIdToStageIds(startingJobId + 1).size === 2)
  assert(listener.jobIdToStageIds(startingJobId + 2).size === 3)
  assert(listener.jobIdToSkippedStageIds.size === 3)
  assert(listener.jobIdToSkippedStageIds.values.forall(_.isEmpty)) // no skipped stages
  assert(listener.stageIdToJobId.size === 6)
  assert(listener.stageIdToJobId(startingStageId) === startingJobId)
  assert(listener.stageIdToJobId(startingStageId + 1) === startingJobId + 1)
  assert(listener.stageIdToJobId(startingStageId + 2) === startingJobId + 1)
  assert(listener.stageIdToJobId(startingStageId + 3) === startingJobId + 2)
  assert(listener.stageIdToJobId(startingStageId + 4) === startingJobId + 2)
  assert(listener.stageIdToJobId(startingStageId + 5) === startingJobId + 2)
  assert(listener.stageIdToGraph.size === 6)
  assert(listener.completedStageIds.size === 6)
  assert(listener.jobIds.size === 3)
  assert(listener.stageIds.size === 6)
}
75+
76+
test("run jobs with skipped stages") {
  val startingJobId = jobIdCounter
  val startingStageId = stageIdCounter
  val listener = new RDDOperationGraphListener(conf)

  // Run a few jobs, but not enough to trigger clean up yet.
  // Leave some stages unfinished so that they are marked as skipped.
  for (numStages <- 1 to 3) startJob(numStages, listener) // start 3 jobs and 6 stages
  for (i <- 4 to 5) endStage(startingStageId + i, listener) // finish only last 2 stages
  for (i <- 0 to 2) endJob(startingJobId + i, listener) // finish all 3 jobs

  assert(listener.jobIdToSkippedStageIds.size === 3)
  assert(listener.jobIdToSkippedStageIds(startingJobId).size === 1)
  assert(listener.jobIdToSkippedStageIds(startingJobId + 1).size === 2)
  assert(listener.jobIdToSkippedStageIds(startingJobId + 2).size === 1) // 2 stages not skipped
  assert(listener.completedStageIds.size === 2)

  // The rest should be the same as before
  assert(listener.jobIdToStageIds.size === 3)
  assert(listener.jobIdToStageIds(startingJobId).size === 1)
  assert(listener.jobIdToStageIds(startingJobId + 1).size === 2)
  assert(listener.jobIdToStageIds(startingJobId + 2).size === 3)
  assert(listener.stageIdToJobId.size === 6)
  assert(listener.stageIdToJobId(startingStageId) === startingJobId)
  assert(listener.stageIdToJobId(startingStageId + 1) === startingJobId + 1)
  assert(listener.stageIdToJobId(startingStageId + 2) === startingJobId + 1)
  assert(listener.stageIdToJobId(startingStageId + 3) === startingJobId + 2)
  assert(listener.stageIdToJobId(startingStageId + 4) === startingJobId + 2)
  assert(listener.stageIdToJobId(startingStageId + 5) === startingJobId + 2)
  assert(listener.stageIdToGraph.size === 6)
  assert(listener.jobIds.size === 3)
  assert(listener.stageIds.size === 6)
}
109+
110+
test("clean up metadata") {
  val startingJobId = jobIdCounter
  val startingStageId = stageIdCounter
  val listener = new RDDOperationGraphListener(conf)

  // Run many jobs and stages to trigger clean up
  for (iteration <- 1 to 10000) {
    // Note: this must be less than `maxRetainedStages`
    val numStages = iteration % (maxRetainedStages - 2) + 1
    val firstStageIdForJob = stageIdCounter
    val jobId = startJob(numStages, listener)
    // End some, but not all, stages that belong to this job.
    // This is to ensure that we have both completed and skipped stages.
    for (sid <- firstStageIdForJob until stageIdCounter if sid % 2 == 0) {
      endStage(sid, listener)
    }
    // End all jobs
    endJob(jobId, listener)
  }

  // Ensure we never exceed the max retained thresholds
  assert(listener.jobIdToStageIds.size <= maxRetainedJobs)
  assert(listener.jobIdToSkippedStageIds.size <= maxRetainedJobs)
  assert(listener.stageIdToJobId.size <= maxRetainedStages)
  assert(listener.stageIdToGraph.size <= maxRetainedStages)
  assert(listener.completedStageIds.size <= maxRetainedStages)
  assert(listener.jobIds.size <= maxRetainedJobs)
  assert(listener.stageIds.size <= maxRetainedStages)

  // Also ensure we're actually populating these data structures.
  // Otherwise the previous group of asserts will be meaningless.
  assert(listener.jobIdToStageIds.nonEmpty)
  assert(listener.jobIdToSkippedStageIds.nonEmpty)
  assert(listener.stageIdToJobId.nonEmpty)
  assert(listener.stageIdToGraph.nonEmpty)
  assert(listener.completedStageIds.nonEmpty)
  assert(listener.jobIds.nonEmpty)
  assert(listener.stageIds.nonEmpty)

  // Ensure we clean up old jobs and stages, not arbitrary ones
  assert(!listener.jobIdToStageIds.contains(startingJobId))
  assert(!listener.jobIdToSkippedStageIds.contains(startingJobId))
  assert(!listener.stageIdToJobId.contains(startingStageId))
  assert(!listener.stageIdToGraph.contains(startingStageId))
  assert(!listener.completedStageIds.contains(startingStageId))
  assert(!listener.stageIds.contains(startingStageId))
  assert(!listener.jobIds.contains(startingJobId))
}
158+
159+
test("fate sharing between jobs and stages") {
  val startingJobId = jobIdCounter
  val startingStageId = stageIdCounter
  val listener = new RDDOperationGraphListener(conf)

  // Run 3 jobs and 8 stages, finishing all 3 jobs but only 2 stages
  startJob(5, listener)
  startJob(1, listener)
  startJob(2, listener)
  for (i <- 0 until 8) startStage(startingStageId + i, listener)
  endStage(startingStageId + 3, listener)
  endStage(startingStageId + 4, listener)
  for (i <- 0 until 3) endJob(startingJobId + i, listener)

  // First, assert the old stuff
  assert(listener.jobIdToStageIds.size === 3)
  assert(listener.jobIdToSkippedStageIds.size === 3)
  assert(listener.stageIdToJobId.size === 8)
  assert(listener.stageIdToGraph.size === 8)
  assert(listener.completedStageIds.size === 2)

  // Cleaning the third job should clean all of its stages
  listener.cleanJob(startingJobId + 2)
  assert(listener.jobIdToStageIds.size === 2)
  assert(listener.jobIdToSkippedStageIds.size === 2)
  assert(listener.stageIdToJobId.size === 6)
  assert(listener.stageIdToGraph.size === 6)
  assert(listener.completedStageIds.size === 2)

  // Cleaning one of the stages in the first job should clean that job and all of its stages.
  // Note that we still keep around the last stage because it belongs to a different job.
  listener.cleanStage(startingStageId)
  assert(listener.jobIdToStageIds.size === 1)
  assert(listener.jobIdToSkippedStageIds.size === 1)
  assert(listener.stageIdToJobId.size === 1)
  assert(listener.stageIdToGraph.size === 1)
  assert(listener.completedStageIds.size === 0)
}
197+
198+
/** Start a job with the specified number of stages and return its job ID. */
private def startJob(numStages: Int, listener: RDDOperationGraphListener): Int = {
  assert(numStages > 0, "I will not run a job with 0 stages for you.")
  // Allocate fresh, monotonically increasing stage IDs for this job
  val stageInfos = Seq.tabulate(numStages) { _ =>
    val info = new StageInfo(stageIdCounter, 0, "s", 0, Seq.empty, Seq.empty, "d")
    stageIdCounter += 1
    info
  }
  val jobId = jobIdCounter
  listener.onJobStart(new SparkListenerJobStart(jobId, 0, stageInfos))
  // Also submit every stage that belongs to this job
  for (info <- stageInfos) startStage(info.stageId, listener)
  jobIdCounter += 1
  jobId
}
213+
214+
/** Submit the stage specified by the given ID to the listener. */
private def startStage(stageId: Int, listener: RDDOperationGraphListener): Unit = {
  listener.onStageSubmitted(
    new SparkListenerStageSubmitted(new StageInfo(stageId, 0, "s", 0, Seq.empty, Seq.empty, "d")))
}
219+
220+
/** Mark the stage specified by the given ID as completed. */
private def endStage(stageId: Int, listener: RDDOperationGraphListener): Unit = {
  listener.onStageCompleted(
    new SparkListenerStageCompleted(new StageInfo(stageId, 0, "s", 0, Seq.empty, Seq.empty, "d")))
}
225+
226+
/** Mark the job specified by the given ID as finished (successfully). */
private def endJob(jobId: Int, listener: RDDOperationGraphListener): Unit = {
  val event = new SparkListenerJobEnd(jobId, 0, JobSucceeded)
  listener.onJobEnd(event)
}
85230

86231
}

0 commit comments

Comments
 (0)