Commit 35cd9f7

Fixed a bug where partially executed partitions could be put into the cache (in task killing).
1 parent 640f9a0 commit 35cd9f7

5 files changed: 87 additions and 14 deletions (+87 −14)

core/src/main/scala/org/apache/spark/CacheManager.scala

Lines changed: 11 additions & 4 deletions

@@ -47,7 +47,12 @@ private[spark] class CacheManager(blockManager: BlockManager) extends Logging {
   if (loading.contains(key)) {
     logInfo("Another thread is loading %s, waiting for it to finish...".format(key))
     while (loading.contains(key)) {
-      try {loading.wait()} catch {case _ : Throwable =>}
+      try {
+        loading.wait()
+      } catch {
+        case e: Exception =>
+          logWarning(s"Got an exception while waiting for another thread to load $key", e)
+      }
     }
     logInfo("Finished waiting for %s".format(key))
     /* See whether someone else has successfully loaded it. The main way this would fail
@@ -72,7 +77,9 @@ private[spark] class CacheManager(blockManager: BlockManager) extends Logging {
     val computedValues = rdd.computeOrReadCheckpoint(split, context)

     // Persist the result, so long as the task is not running locally
-    if (context.runningLocally) { return computedValues }
+    if (context.runningLocally) {
+      return computedValues
+    }

     // Keep track of blocks with updated statuses
     var updatedBlocks = Seq[(BlockId, BlockStatus)]()
@@ -88,7 +95,7 @@ private[spark] class CacheManager(blockManager: BlockManager) extends Logging {
     updatedBlocks = blockManager.put(key, computedValues, storageLevel, tellMaster = true)
     blockManager.get(key) match {
       case Some(values) =>
-        new InterruptibleIterator(context, values.asInstanceOf[Iterator[T]])
+        values.asInstanceOf[Iterator[T]]
       case None =>
         logInfo("Failure to store %s".format(key))
         throw new Exception("Block manager failed to return persisted valued")
@@ -107,7 +114,7 @@ private[spark] class CacheManager(blockManager: BlockManager) extends Logging {
     val metrics = context.taskMetrics
     metrics.updatedBlocks = Some(updatedBlocks)

-    returnValue
+    new InterruptibleIterator(context, returnValue)

   } finally {
     loading.synchronized {
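
Why this matters: blockManager.put materializes the computed iterator in order to store the block. Before this commit, a killed task's InterruptibleIterator simply reported hasNext == false once the task was interrupted, so put stored whatever prefix of the partition had been produced, and later reads saw a silently truncated cached block. With the InterruptibleIterator change below, the iterator throws TaskKilledException instead, the put aborts, and nothing partial is cached. The following is a minimal, self-contained sketch of that difference; it is illustrative stand-in code, not Spark's implementation.

// Stand-in model (not Spark code): a cache write that materializes a partition iterator,
// shown once with a "quiet" wrapper (pre-fix behaviour) and once with a fail-fast wrapper
// (post-fix behaviour, mirroring the new hasNext that throws TaskKilledException).
object PartialCacheSketch {
  class KilledException extends RuntimeException

  @volatile var killed = false

  // Pre-fix: stop quietly once killed, so consumers see a shorter sequence.
  def quiet[T](data: Iterator[T]): Iterator[T] = new Iterator[T] {
    def hasNext: Boolean = !killed && data.hasNext
    def next(): T = data.next()
  }

  // Post-fix: fail fast, so consumers abort instead of storing a truncated result.
  def failFast[T](data: Iterator[T]): Iterator[T] = new Iterator[T] {
    def hasNext: Boolean = if (killed) throw new KilledException else data.hasNext
    def next(): T = data.next()
  }

  def main(args: Array[String]): Unit = {
    val cache = scala.collection.mutable.Map.empty[String, Vector[Int]]

    // Simulates a partition whose task gets killed after about 10 elements.
    def partition: Iterator[Int] = (1 to 1000).iterator.map { x =>
      if (x > 10) killed = true
      x
    }

    killed = false
    cache("quiet") = quiet(partition).toVector              // stores a truncated block
    println(cache("quiet").size)                            // ~11, not 1000

    killed = false
    try { cache("failFast") = failFast(partition).toVector  // aborts before caching anything
    } catch { case _: KilledException => }
    println(cache.contains("failFast"))                     // false: nothing partial cached
  }
}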

core/src/main/scala/org/apache/spark/InterruptibleIterator.scala

Lines changed: 11 additions & 1 deletion

@@ -24,7 +24,17 @@ package org.apache.spark
 private[spark] class InterruptibleIterator[+T](val context: TaskContext, val delegate: Iterator[T])
   extends Iterator[T] {

-  def hasNext: Boolean = !context.interrupted && delegate.hasNext
+  def hasNext: Boolean = {
+    // TODO(aarondav/rxin): Check Thread.interrupted instead of context.interrupted if interrupt
+    // is allowed. The assumption is that Thread.interrupted does not have a memory fence in read
+    // (just a volatile field in C), while context.interrupted is a volatile in the JVM, which
+    // introduces an expensive read fence.
+    if (context.interrupted) {
+      throw new TaskKilledException
+    } else {
+      delegate.hasNext
+    }
+  }

   def next(): T = delegate.next()
 }
core/src/main/scala/org/apache/spark/TaskKilledException.scala (new file)

Lines changed: 23 additions & 0 deletions

@@ -0,0 +1,23 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark
+
+/**
+ * Exception for a task getting killed.
+ */
+private[spark] class TaskKilledException extends RuntimeException

core/src/main/scala/org/apache/spark/executor/Executor.scala

Lines changed: 3 additions & 5 deletions

@@ -161,8 +161,6 @@ private[spark] class Executor(
   class TaskRunner(execBackend: ExecutorBackend, taskId: Long, serializedTask: ByteBuffer)
     extends Runnable {

-    object TaskKilledException extends Exception
-
     @volatile private var killed = false
     @volatile private var task: Task[Any] = _

@@ -200,7 +198,7 @@ private[spark] class Executor(
       // causes a NonLocalReturnControl exception to be thrown. The NonLocalReturnControl
       // exception will be caught by the catch block, leading to an incorrect ExceptionFailure
       // for the task.
-      throw TaskKilledException
+      throw new TaskKilledException
     }

     attemptedTask = Some(task)
@@ -214,7 +212,7 @@ private[spark] class Executor(

     // If the task has been killed, let's fail it.
     if (task.killed) {
-      throw TaskKilledException
+      throw new TaskKilledException
     }

     val resultSer = SparkEnv.get.serializer.newInstance()
@@ -257,7 +255,7 @@ private[spark] class Executor(
       execBackend.statusUpdate(taskId, TaskState.FAILED, ser.serialize(reason))
     }

-    case TaskKilledException | _: InterruptedException if task.killed => {
+    case _: TaskKilledException | _: InterruptedException if task.killed => {
       logInfo("Executor killed task " + taskId)
       execBackend.statusUpdate(taskId, TaskState.KILLED, ser.serialize(TaskKilled))
     }
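
A side note on the catch clause: the old TaskKilledException was a singleton object nested inside TaskRunner, so the pattern "case TaskKilledException | _: InterruptedException" matched it by value; the new top-level class must be matched by type, hence "case _: TaskKilledException | ...". A tiny generic illustration of that difference in Scala catch patterns (not Spark code):

// Illustrative only: matching a singleton exception object versus an exception class.
object CatchPatternSketch {
  object SingletonError extends Exception      // matched by reference to the one instance
  class ClassError extends RuntimeException    // matched by type with "_: ClassError"

  def classify(work: () => Unit): String =
    try { work(); "ok" }
    catch {
      case SingletonError => "matched the singleton exception object, by reference"
      case _: ClassError  => "matched an instance of the exception class, by type"
    }

  def main(args: Array[String]): Unit = {
    println(classify(() => throw SingletonError))   // the old Executor style
    println(classify(() => throw new ClassError))   // the style this commit switches to
  }
}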

core/src/test/scala/org/apache/spark/JobCancellationSuite.scala

Lines changed: 39 additions & 4 deletions

@@ -84,6 +84,35 @@ class JobCancellationSuite extends FunSuite with ShouldMatchers with BeforeAndAf
     assert(sc.parallelize(1 to 10, 2).count === 10)
   }

+  test("do not put partially executed partitions into cache") {
+    // In this test case, we create a scenario in which a partition is only partially executed,
+    // and make sure CacheManager does not put that partially executed partition into the
+    // BlockManager.
+    import JobCancellationSuite._
+    sc = new SparkContext("local", "test")
+
+    // Run from 1 to 10, and then block and wait for the task to be killed.
+    val rdd = sc.parallelize(1 to 1000, 2).map { x =>
+      if (x > 10) {
+        taskStartedSemaphore.release()
+        taskCancelledSemaphore.acquire()
+      }
+      x
+    }.cache()
+
+    val rdd1 = rdd.map(x => x)
+
+    future {
+      taskStartedSemaphore.acquire()
+      sc.cancelAllJobs()
+      taskCancelledSemaphore.release(100000)
+    }
+
+    intercept[SparkException] { rdd1.count() }
+    // If the partial block is put into cache, rdd.count() would return a number less than 1000.
+    assert(rdd.count() === 1000)
+  }
+
   test("job group") {
     sc = new SparkContext("local[2]", "test")

@@ -113,15 +142,15 @@ class JobCancellationSuite extends FunSuite with ShouldMatchers with BeforeAndAf
     // Once A is cancelled, job B should finish fairly quickly.
     assert(jobB.get() === 100)
   }
-  /*
-  test("two jobs sharing the same stage") {
+
+  ignore("two jobs sharing the same stage") {
     // sem1: make sure cancel is issued after some tasks are launched
     // sem2: make sure the first stage is not finished until cancel is issued
     val sem1 = new Semaphore(0)
     val sem2 = new Semaphore(0)

     sc = new SparkContext("local[2]", "test")
-    sc.dagScheduler.addSparkListener(new SparkListener {
+    sc.addSparkListener(new SparkListener {
       override def onTaskStart(taskStart: SparkListenerTaskStart) {
         sem1.release()
       }
@@ -147,7 +176,7 @@ class JobCancellationSuite extends FunSuite with ShouldMatchers with BeforeAndAf
     intercept[SparkException] { f1.get() }
     intercept[SparkException] { f2.get() }
   }
-  */
+
   def testCount() {
     // Cancel before launching any tasks
     {
@@ -206,3 +235,9 @@ class JobCancellationSuite extends FunSuite with ShouldMatchers with BeforeAndAf
     }
   }
 }
+
+
+object JobCancellationSuite {
+  val taskStartedSemaphore = new Semaphore(0)
+  val taskCancelledSemaphore = new Semaphore(0)
+}
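
The new test coordinates the tasks and the driver with two semaphores kept in the JobCancellationSuite companion object: for every element greater than 10, the map function releases taskStartedSemaphore and then blocks on taskCancelledSemaphore; a driver-side future waits for that signal, cancels all jobs, and releases the semaphore many times so the blocked tasks can proceed and observe the kill. Keeping the semaphores in a companion object, rather than as fields of the suite instance, means the serialized task closure refers to JVM-static state that the tasks and the driver share in local mode. A compact, generic version of that two-semaphore handshake, as a standalone sketch (not Spark code):

// Generic handshake between a worker and a controller, mirroring the test's coordination.
import java.util.concurrent.Semaphore

object HandshakeSketch {
  val started   = new Semaphore(0)   // worker -> controller: "I have started"
  val cancelled = new Semaphore(0)   // controller -> worker: "cancellation has been issued"
  @volatile var killed = false

  def main(args: Array[String]): Unit = {
    val worker = new Thread(() => {
      started.release()              // announce that work is underway
      cancelled.acquire()            // block until the controller has acted
      if (killed) println("worker observed the kill and stopped early")
    })
    worker.start()

    started.acquire()                // wait for the worker to reach its blocking point
    killed = true                    // stands in for sc.cancelAllJobs()
    cancelled.release()              // unblock the worker so it can observe the cancellation
    worker.join()
  }
}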
