
Commit a4cdb77

davies authored and JoshRosen committed
[SPARK-1740] [PySpark] kill the python worker
Kill only the Python worker related to cancelled tasks. The daemon will start a background thread to monitor all the opened sockets for all workers. If the socket is closed by the JVM, this thread will kill the worker. When a task is cancelled, the socket to the worker will be closed, and the worker will then be killed by the daemon.

Author: Davies Liu <[email protected]>

Closes #1643 from davies/kill and squashes the following commits:

8ffe9f3 [Davies Liu] kill worker by deamon, because runtime.exec() is too heavy
46ca150 [Davies Liu] address comment
acd751c [Davies Liu] kill the worker when task is canceled

(cherry picked from commit 55349f9)

Signed-off-by: Josh Rosen <[email protected]>
1 parent 7c6afda commit a4cdb77
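For context, here is a minimal sketch (not code from the patch) of the daemon-side protocol this change introduces: the JVM asks the daemon to kill a worker by writing that worker's pid to the daemon's stdin as a 4-byte big-endian int, while EOF on stdin still means "shut down". The helpers below (`read_int`, `handle_kill_requests`) are illustrative reimplementations written in the Python 2 style of daemon.py, not the patch itself.

import os
import select
import signal
import struct
import sys

def read_int(stream):
    # Same framing as pyspark.serializers.read_int and the JVM's
    # DataOutputStream.writeInt: a 4-byte big-endian signed int.
    data = stream.read(4)
    if len(data) < 4:
        raise EOFError
    return struct.unpack("!i", data)[0]

def handle_kill_requests():
    # Sketch of the daemon's select loop, reduced to the stdin branch.
    while True:
        ready_fds = select.select([0], [], [], 1.0)[0]
        if 0 not in ready_fds:
            continue
        try:
            worker_pid = read_int(sys.stdin)
        except EOFError:
            sys.exit(0)  # stdin closed: Spark asked the daemon to exit
        try:
            os.kill(worker_pid, signal.SIGKILL)  # kill the cancelled worker
        except OSError:
            pass  # worker already died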

5 files changed: +125 -28 lines


core/src/main/scala/org/apache/spark/SparkEnv.scala

Lines changed: 3 additions & 2 deletions
@@ -18,6 +18,7 @@
 package org.apache.spark
 
 import java.io.File
+import java.net.Socket
 
 import scala.collection.JavaConversions._
 import scala.collection.mutable
@@ -102,10 +103,10 @@ class SparkEnv (
   }
 
   private[spark]
-  def destroyPythonWorker(pythonExec: String, envVars: Map[String, String]) {
+  def destroyPythonWorker(pythonExec: String, envVars: Map[String, String], worker: Socket) {
     synchronized {
       val key = (pythonExec, envVars)
-      pythonWorkers(key).stop()
+      pythonWorkers.get(key).foreach(_.stopWorker(worker))
     }
   }
 }

core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala

Lines changed: 4 additions & 5 deletions
@@ -62,8 +62,8 @@ private[spark] class PythonRDD(
     val env = SparkEnv.get
     val localdir = env.blockManager.diskBlockManager.localDirs.map(
       f => f.getPath()).mkString(",")
-    val worker: Socket = env.createPythonWorker(pythonExec,
-      envVars.toMap + ("SPARK_LOCAL_DIR" -> localdir))
+    envVars += ("SPARK_LOCAL_DIR" -> localdir)  // it's also used in monitor thread
+    val worker: Socket = env.createPythonWorker(pythonExec, envVars.toMap)
 
     // Start a thread to feed the process input from our parent's iterator
     val writerThread = new WriterThread(env, worker, split, context)
@@ -241,7 +241,7 @@ private[spark] class PythonRDD(
       if (!context.completed) {
         try {
           logWarning("Incomplete task interrupted: Attempting to kill Python Worker")
-          env.destroyPythonWorker(pythonExec, envVars.toMap)
+          env.destroyPythonWorker(pythonExec, envVars.toMap, worker)
         } catch {
           case e: Exception =>
             logError("Exception when trying to kill worker", e)
@@ -685,9 +685,8 @@ private[spark] object PythonRDD extends Logging {
 
   /**
    * Convert an RDD of serialized Python dictionaries to Scala Maps (no recursive conversions).
-   * This function is outdated, PySpark does not use it anymore
    */
-  @deprecated
+  @deprecated("PySpark does not use it anymore", "1.1")
   def pythonToJavaMap(pyRDD: JavaRDD[Array[Byte]]): JavaRDD[Map[String, _]] = {
     pyRDD.rdd.mapPartitions { iter =>
       val unpickle = new Unpickler

core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala

Lines changed: 49 additions & 15 deletions
@@ -17,9 +17,11 @@
 
 package org.apache.spark.api.python
 
-import java.io.{DataInputStream, InputStream, OutputStreamWriter}
+import java.lang.Runtime
+import java.io.{DataOutputStream, DataInputStream, InputStream, OutputStreamWriter}
 import java.net.{InetAddress, ServerSocket, Socket, SocketException}
 
+import scala.collection.mutable
 import scala.collection.JavaConversions._
 
 import org.apache.spark._
@@ -39,6 +41,9 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String
   var daemon: Process = null
   val daemonHost = InetAddress.getByAddress(Array(127, 0, 0, 1))
   var daemonPort: Int = 0
+  var daemonWorkers = new mutable.WeakHashMap[Socket, Int]()
+
+  var simpleWorkers = new mutable.WeakHashMap[Socket, Process]()
 
   val pythonPath = PythonUtils.mergePythonPaths(
     PythonUtils.sparkPythonPath,
@@ -58,25 +63,31 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String
    * to avoid the high cost of forking from Java. This currently only works on UNIX-based systems.
    */
   private def createThroughDaemon(): Socket = {
+
+    def createSocket(): Socket = {
+      val socket = new Socket(daemonHost, daemonPort)
+      val pid = new DataInputStream(socket.getInputStream).readInt()
+      if (pid < 0) {
+        throw new IllegalStateException("Python daemon failed to launch worker")
+      }
+      daemonWorkers.put(socket, pid)
+      socket
+    }
+
     synchronized {
       // Start the daemon if it hasn't been started
       startDaemon()
 
       // Attempt to connect, restart and retry once if it fails
       try {
-        val socket = new Socket(daemonHost, daemonPort)
-        val launchStatus = new DataInputStream(socket.getInputStream).readInt()
-        if (launchStatus != 0) {
-          throw new IllegalStateException("Python daemon failed to launch worker")
-        }
-        socket
+        createSocket()
       } catch {
         case exc: SocketException =>
           logWarning("Failed to open socket to Python daemon:", exc)
           logWarning("Assuming that daemon unexpectedly quit, attempting to restart")
           stopDaemon()
           startDaemon()
-          new Socket(daemonHost, daemonPort)
+          createSocket()
       }
     }
   }
@@ -107,7 +118,9 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String
       // Wait for it to connect to our socket
       serverSocket.setSoTimeout(10000)
       try {
-        return serverSocket.accept()
+        val socket = serverSocket.accept()
+        simpleWorkers.put(socket, worker)
+        return socket
       } catch {
         case e: Exception =>
           throw new SparkException("Python worker did not connect back in time", e)
@@ -189,19 +202,40 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String
 
   private def stopDaemon() {
     synchronized {
-      // Request shutdown of existing daemon by sending SIGTERM
-      if (daemon != null) {
-        daemon.destroy()
-      }
+      if (useDaemon) {
+        // Request shutdown of existing daemon by sending SIGTERM
+        if (daemon != null) {
+          daemon.destroy()
+        }
 
-      daemon = null
-      daemonPort = 0
+        daemon = null
+        daemonPort = 0
+      } else {
+        simpleWorkers.mapValues(_.destroy())
+      }
     }
   }
 
   def stop() {
     stopDaemon()
   }
+
+  def stopWorker(worker: Socket) {
+    if (useDaemon) {
+      if (daemon != null) {
+        daemonWorkers.get(worker).foreach { pid =>
+          // tell daemon to kill worker by pid
+          val output = new DataOutputStream(daemon.getOutputStream)
+          output.writeInt(pid)
+          output.flush()
+          daemon.getOutputStream.flush()
+        }
+      }
+    } else {
+      simpleWorkers.get(worker).foreach(_.destroy())
    }
+    worker.close()
+  }
 }
 
 private object PythonWorkerFactory {
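The launch handshake performed by createSocket() above is language-agnostic: connect to the daemon's port and read a 4-byte big-endian int, which is now the forked worker's pid (a negative value still means the fork failed). A hedged Python sketch of that client side, where connect_to_daemon and daemon_port are illustrative names and the real port comes from the daemon's startup output:

import socket
import struct

def connect_to_daemon(daemon_port, host="127.0.0.1"):
    # Connect to the PySpark daemon and read the launch status:
    # the forked worker's pid on success, a negative value on failure.
    sock = socket.create_connection((host, daemon_port))
    data = sock.makefile("rb").read(4)
    pid = struct.unpack("!i", data)[0]
    if pid < 0:
        sock.close()
        raise RuntimeError("Python daemon failed to launch worker")
    return sock, pid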

python/pyspark/daemon.py

Lines changed: 18 additions & 6 deletions
@@ -26,7 +26,7 @@
 from socket import AF_INET, SOCK_STREAM, SOMAXCONN
 from signal import SIGHUP, SIGTERM, SIGCHLD, SIG_DFL, SIG_IGN
 from pyspark.worker import main as worker_main
-from pyspark.serializers import write_int
+from pyspark.serializers import read_int, write_int
 
 
 def compute_real_exit_code(exit_code):
@@ -67,7 +67,8 @@ def waitSocketClose(sock):
     outfile = os.fdopen(os.dup(sock.fileno()), "a+", 65536)
     exit_code = 0
     try:
-        write_int(0, outfile)  # Acknowledge that the fork was successful
+        # Acknowledge that the fork was successful
+        write_int(os.getpid(), outfile)
         outfile.flush()
         worker_main(infile, outfile)
     except SystemExit as exc:
@@ -125,14 +126,23 @@ def handle_sigchld(*args):
                 else:
                     raise
             if 0 in ready_fds:
-                # Spark told us to exit by closing stdin
-                shutdown(0)
+                try:
+                    worker_pid = read_int(sys.stdin)
+                except EOFError:
+                    # Spark told us to exit by closing stdin
+                    shutdown(0)
+                try:
+                    os.kill(worker_pid, signal.SIGKILL)
+                except OSError:
+                    pass  # process already died
+
+
             if listen_sock in ready_fds:
                 sock, addr = listen_sock.accept()
                 # Launch a worker process
                 try:
-                    fork_return_code = os.fork()
-                    if fork_return_code == 0:
+                    pid = os.fork()
+                    if pid == 0:
                         listen_sock.close()
                         try:
                             worker(sock)
@@ -143,11 +153,13 @@ def handle_sigchld(*args):
                            os._exit(0)
                        else:
                            sock.close()
+
                except OSError as e:
                    print >> sys.stderr, "Daemon failed to fork PySpark worker: %s" % e
                    outfile = os.fdopen(os.dup(sock.fileno()), "a+", 65536)
                    write_int(-1, outfile)  # Signal that the fork failed
                    outfile.flush()
+                   outfile.close()
                    sock.close()
     finally:
         shutdown(1)

python/pyspark/tests.py

Lines changed: 51 additions & 0 deletions
@@ -790,6 +790,57 @@ def test_termination_sigterm(self):
         self.do_termination_test(lambda daemon: os.kill(daemon.pid, SIGTERM))
 
 
+class TestWorker(PySparkTestCase):
+    def test_cancel_task(self):
+        temp = tempfile.NamedTemporaryFile(delete=True)
+        temp.close()
+        path = temp.name
+        def sleep(x):
+            import os, time
+            with open(path, 'w') as f:
+                f.write("%d %d" % (os.getppid(), os.getpid()))
+            time.sleep(100)
+
+        # start job in background thread
+        def run():
+            self.sc.parallelize(range(1)).foreach(sleep)
+        import threading
+        t = threading.Thread(target=run)
+        t.daemon = True
+        t.start()
+
+        daemon_pid, worker_pid = 0, 0
+        while True:
+            if os.path.exists(path):
+                data = open(path).read().split(' ')
+                daemon_pid, worker_pid = map(int, data)
+                break
+            time.sleep(0.1)
+
+        # cancel jobs
+        self.sc.cancelAllJobs()
+        t.join()
+
+        for i in range(50):
+            try:
+                os.kill(worker_pid, 0)
+                time.sleep(0.1)
+            except OSError:
+                break  # worker was killed
+        else:
+            self.fail("worker has not been killed after 5 seconds")
+
+        try:
+            os.kill(daemon_pid, 0)
+        except OSError:
+            self.fail("daemon had been killed")
+
+    def test_fd_leak(self):
+        N = 1100  # fd limit is 1024 by default
+        rdd = self.sc.parallelize(range(N), N)
+        self.assertEquals(N, rdd.count())
+
+
 class TestSparkSubmit(unittest.TestCase):
     def setUp(self):
         self.programDir = tempfile.mkdtemp()
