
Commit 760ab1f

do not reuse worker if there are any exceptions
1 parent 7abb224 commit 760ab1f

File tree

3 files changed: +30 −2 lines


core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala

Lines changed: 5 additions & 1 deletion
@@ -73,9 +73,10 @@ private[spark] class PythonRDD(
     // Start a thread to feed the process input from our parent's iterator
     val writerThread = new WriterThread(env, worker, split, context)
 
+    var complete_cleanly = false
     context.addTaskCompletionListener { context =>
       writerThread.shutdownOnTaskCompletion()
-      if (!context.isInterrupted) {
+      if (reuse_worker && complete_cleanly) {
         env.releasePythonWorker(pythonExec, envVars.toMap, worker)
       } else {
         try {
@@ -141,6 +142,7 @@
           stream.readFully(update)
           accumulator += Collections.singletonList(update)
         }
+        complete_cleanly = true
         null
       }
     } catch {
@@ -235,11 +237,13 @@
     } catch {
       case e: Exception if context.isCompleted || context.isInterrupted =>
         logDebug("Exception thrown after task completion (likely due to cleanup)", e)
+        worker.shutdownOutput()
 
       case e: Exception =>
         // We must avoid throwing exceptions here, because the thread uncaught exception handler
         // will kill the whole executor (see org.apache.spark.executor.Executor).
         _exception = e
+        worker.shutdownOutput()
       }
     }
   }
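
The effect of the PythonRDD change: the task-completion listener hands the worker back to the pool only when reuse is enabled and the result stream was read to its normal end; any exception leaves complete_cleanly false, so the worker is shut down instead of being given to the next task. Below is a minimal sketch of that "release only on clean completion" pattern in plain Python; Pool, Worker, and run_task are hypothetical stand-ins, not Spark's internals.

# Hedged sketch, not Spark's code: return a worker to the pool only if the
# task body finished without raising; otherwise discard it.
from collections import deque


class Worker(object):
    def close(self):
        pass


class Pool(object):
    def __init__(self):
        self.idle = deque()

    def borrow(self):
        return self.idle.popleft() if self.idle else Worker()

    def release(self, worker):
        self.idle.append(worker)


def run_task(pool, reuse_worker, body):
    worker = pool.borrow()
    complete_cleanly = False
    try:
        body(worker)
        complete_cleanly = True  # only reached if the body did not raise
    finally:
        if reuse_worker and complete_cleanly:
            pool.release(worker)  # safe to hand this worker to the next task
        else:
            worker.close()  # its streams may be half-written; discard it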

core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala

Lines changed: 1 addition & 1 deletion
@@ -282,7 +282,7 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String
   }
 
   def releaseWorker(worker: Socket) {
-    if (useDaemon && envVars.get("SPARK_REUSE_WORKER").isDefined) {
+    if (useDaemon) {
       synchronized {
         lastActivity = System.currentTimeMillis()
         idleWorkers.enqueue(worker)
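
With the clean-completion check now made in PythonRDD, releaseWorker no longer consults SPARK_REUSE_WORKER: any daemon-mode worker that reaches this method is stamped with the current time and put back on the idle queue. The sketch below illustrates a timestamped idle queue with a reaper that closes long-idle workers; this is an assumption inferred from the lastActivity/idleWorkers names above, not the factory's actual code.

# Hedged sketch: a timestamped idle queue plus a reaper for workers left idle
# past a timeout. Illustrative only; not PythonWorkerFactory itself.
import threading
import time
from collections import deque


class IdleWorkerQueue(object):
    def __init__(self):
        self._lock = threading.Lock()
        self._idle = deque()
        self._last_activity = 0.0

    def release(self, worker):
        with self._lock:
            self._last_activity = time.time()
            self._idle.append(worker)

    def reap_idle(self, timeout_seconds):
        # Close every queued worker if nothing has been released recently.
        with self._lock:
            if self._idle and time.time() - self._last_activity > timeout_seconds:
                while self._idle:
                    self._idle.popleft().close()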

python/pyspark/tests.py

Lines changed: 24 additions & 0 deletions
@@ -1077,11 +1077,35 @@ def run():
         except OSError:
             self.fail("daemon had been killed")
 
+        # run a normal job
+        rdd = self.sc.parallelize(range(100), 1)
+        self.assertEqual(100, rdd.map(str).count())
+
     def test_fd_leak(self):
         N = 1100  # fd limit is 1024 by default
         rdd = self.sc.parallelize(range(N), N)
         self.assertEquals(N, rdd.count())
 
+    def test_after_exception(self):
+        def raise_exception(_):
+            raise Exception()
+        rdd = self.sc.parallelize(range(100), 1)
+        self.assertRaises(Exception, lambda: rdd.foreach(raise_exception))
+        self.assertEqual(100, rdd.map(str).count())
+
+    def test_after_jvm_exception(self):
+        tempFile = tempfile.NamedTemporaryFile(delete=False)
+        tempFile.write("Hello World!")
+        tempFile.close()
+        data = self.sc.textFile(tempFile.name, 1)
+        filtered_data = data.filter(lambda x: True)
+        self.assertEqual(1, filtered_data.count())
+        os.unlink(tempFile.name)
+        self.assertRaises(Exception, lambda: filtered_data.count())
+
+        rdd = self.sc.parallelize(range(100), 1)
+        self.assertEqual(100, rdd.map(str).count())
+
 
 
 class TestSparkSubmit(unittest.TestCase):
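
These tests pin down the user-visible guarantee: a job that fails inside the Python worker (or on the JVM side, as in test_after_jvm_exception) must not poison later jobs that may reuse a worker. A standalone way to exercise the same behavior against a local context might look like the following sketch; it assumes a local PySpark installation, and the master string and app name are arbitrary.

# Hedged, standalone illustration of what the new tests assert: a failed
# action must not break subsequent jobs. Assumes PySpark is installed.
from pyspark import SparkContext


def fail(_):
    raise Exception("boom")


if __name__ == "__main__":
    sc = SparkContext("local[1]", "worker-reuse-check")
    try:
        rdd = sc.parallelize(range(100), 1)
        try:
            rdd.foreach(fail)  # first action fails inside the Python worker
        except Exception:
            pass  # expected failure
        # the next job must still succeed, whether or not a worker was reused
        assert rdd.map(str).count() == 100
    finally:
        sc.stop()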
