Skip to content

Commit 8d2f08c

Browse files
committed
reuse python worker
1 parent 6a72a36 commit 8d2f08c

File tree

7 files changed

+67
-34
lines changed

7 files changed

+67
-34
lines changed

core/src/main/scala/org/apache/spark/SparkEnv.scala

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,14 @@ class SparkEnv (
105105
pythonWorkers.get(key).foreach(_.stopWorker(worker))
106106
}
107107
}
108+
109+
private[spark]
110+
def releasePythonWorker(pythonExec: String, envVars: Map[String, String], worker: Socket) {
111+
synchronized {
112+
val key = (pythonExec, envVars)
113+
pythonWorkers.get(key).foreach(_.releaseWorker(worker))
114+
}
115+
}
108116
}
109117

110118
object SparkEnv extends Logging {

core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@ private[spark] class PythonRDD(
5252
extends RDD[Array[Byte]](parent) {
5353

5454
val bufferSize = conf.getInt("spark.buffer.size", 65536)
55+
val reuse_worker = conf.getBoolean("spark.python.reuse.worker", true)
5556

5657
override def getPartitions = parent.partitions
5758

@@ -63,20 +64,17 @@ private[spark] class PythonRDD(
6364
val localdir = env.blockManager.diskBlockManager.localDirs.map(
6465
f => f.getPath()).mkString(",")
6566
envVars += ("SPARK_LOCAL_DIRS" -> localdir) // it's also used in monitor thread
67+
if (reuse_worker) {
68+
envVars += ("SPARK_REUSE_WORKER" -> "1")
69+
}
6670
val worker: Socket = env.createPythonWorker(pythonExec, envVars.toMap)
6771

6872
// Start a thread to feed the process input from our parent's iterator
6973
val writerThread = new WriterThread(env, worker, split, context)
7074

7175
context.addTaskCompletionListener { context =>
7276
writerThread.shutdownOnTaskCompletion()
73-
74-
// Cleanup the worker socket. This will also cause the Python worker to exit.
75-
try {
76-
worker.close()
77-
} catch {
78-
case e: Exception => logWarning("Failed to close worker socket", e)
79-
}
77+
env.releasePythonWorker(pythonExec, envVars.toMap, worker)
8078
}
8179

8280
writerThread.start()
@@ -207,6 +205,7 @@ private[spark] class PythonRDD(
207205
dataOut.write(command)
208206
// Data values
209207
PythonRDD.writeIteratorToStream(parent.iterator(split, context), dataOut)
208+
dataOut.writeInt(SpecialLengths.END_OF_DATA_SECTION)
210209
dataOut.flush()
211210
} catch {
212211
case e: Exception if context.isCompleted || context.isInterrupted =>
@@ -216,8 +215,6 @@ private[spark] class PythonRDD(
216215
// We must avoid throwing exceptions here, because the thread uncaught exception handler
217216
// will kill the whole executor (see org.apache.spark.executor.Executor).
218217
_exception = e
219-
} finally {
220-
Try(worker.shutdownOutput()) // kill Python worker process
221218
}
222219
}
223220
}

core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String
4141
val daemonHost = InetAddress.getByAddress(Array(127, 0, 0, 1))
4242
var daemonPort: Int = 0
4343
var daemonWorkers = new mutable.WeakHashMap[Socket, Int]()
44+
var idleWorkers = new mutable.Queue[Socket]()
4445

4546
var simpleWorkers = new mutable.WeakHashMap[Socket, Process]()
4647

@@ -51,6 +52,9 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String
5152

5253
def create(): Socket = {
5354
if (useDaemon) {
55+
if (idleWorkers.length > 0) {
56+
return idleWorkers.dequeue()
57+
}
5458
createThroughDaemon()
5559
} else {
5660
createSimpleWorker()
@@ -235,6 +239,20 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String
235239
}
236240
worker.close()
237241
}
242+
243+
def releaseWorker(worker: Socket) {
244+
if (useDaemon && envVars.get("SPARK_REUSE_WORKER").isDefined) {
245+
idleWorkers.enqueue(worker)
246+
} else {
247+
// Cleanup the worker socket. This will also cause the Python worker to exit.
248+
try {
249+
worker.close()
250+
} catch {
251+
case e: Exception =>
252+
logWarning("Failed to close worker socket", e)
253+
}
254+
}
255+
}
238256
}
239257

240258
private object PythonWorkerFactory {

docs/configuration.md

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -206,6 +206,16 @@ Apart from these, the following properties are also available, and may be useful
206206
used during aggregation goes above this amount, it will spill the data into disks.
207207
</td>
208208
</tr>
209+
<tr>
210+
<td><code>spark.python.worker.reuse</code></td>
211+
<td>true</td>
212+
<td>
213+
Whether to reuse Python workers. If enabled, Spark uses a fixed number of Python workers and
214+
does not need to fork() a Python process for every task. It is especially useful
215+
if there is a large broadcast, since the broadcast then does not need to be transferred
216+
from the JVM to the Python worker for every task.
217+
</td>
218+
</tr>
209219
<tr>
210220
<td><code>spark.executorEnv.[EnvironmentVariableName]</code></td>
211221
<td>(none)</td>

python/pyspark/daemon.py

Lines changed: 20 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
import sys
2424
import traceback
2525
import time
26+
import gc
2627
from errno import EINTR, ECHILD, EAGAIN
2728
from socket import AF_INET, SOCK_STREAM, SOMAXCONN
2829
from signal import SIGHUP, SIGTERM, SIGCHLD, SIG_DFL, SIG_IGN
@@ -42,43 +43,24 @@ def worker(sock):
4243
"""
4344
Called by a worker process after the fork().
4445
"""
45-
# Redirect stdout to stderr
46-
os.dup2(2, 1)
47-
sys.stdout = sys.stderr # The sys.stdout object is different from file descriptor 1
48-
4946
signal.signal(SIGHUP, SIG_DFL)
5047
signal.signal(SIGCHLD, SIG_DFL)
5148
signal.signal(SIGTERM, SIG_DFL)
5249

53-
# Blocks until the socket is closed by draining the input stream
54-
# until it raises an exception or returns EOF.
55-
def waitSocketClose(sock):
56-
try:
57-
while True:
58-
# Empty string is returned upon EOF (and only then).
59-
if sock.recv(4096) == '':
60-
return
61-
except:
62-
pass
63-
6450
# Read the socket using fdopen instead of socket.makefile() because the latter
6551
# seems to be very slow; note that we need to dup() the file descriptor because
6652
# otherwise writes also cause a seek that makes us miss data on the read side.
6753
infile = os.fdopen(os.dup(sock.fileno()), "a+", 65536)
6854
outfile = os.fdopen(os.dup(sock.fileno()), "a+", 65536)
6955
exit_code = 0
7056
try:
71-
# Acknowledge that the fork was successful
72-
write_int(os.getpid(), outfile)
73-
outfile.flush()
7457
worker_main(infile, outfile)
7558
except SystemExit as exc:
76-
exit_code = exc.code
59+
exit_code = compute_real_exit_code(exc.code)
7760
finally:
7861
outfile.flush()
79-
# The Scala side will close the socket upon task completion.
80-
waitSocketClose(sock)
81-
os._exit(compute_real_exit_code(exit_code))
62+
if exit_code:
63+
os._exit(exit_code)
8264

8365

8466
# Cleanup zombie children
@@ -102,6 +84,7 @@ def manager():
10284
listen_sock.listen(max(1024, SOMAXCONN))
10385
listen_host, listen_port = listen_sock.getsockname()
10486
write_int(listen_port, sys.stdout)
87+
sys.stdout.flush()
10588

10689
def shutdown(code):
10790
signal.signal(SIGTERM, SIG_DFL)
@@ -114,8 +97,9 @@ def handle_sigterm(*args):
11497
signal.signal(SIGTERM, handle_sigterm) # Gracefully exit on SIGTERM
11598
signal.signal(SIGHUP, SIG_IGN) # Don't die on SIGHUP
11699

100+
reuse = os.environ.get("SPARK_REUSE_WORKER")
101+
117102
# Initialization complete
118-
sys.stdout.close()
119103
try:
120104
while True:
121105
try:
@@ -167,7 +151,19 @@ def handle_sigterm(*args):
167151
# in child process
168152
listen_sock.close()
169153
try:
170-
worker(sock)
154+
# Acknowledge that the fork was successful
155+
outfile = sock.makefile("w")
156+
write_int(os.getpid(), outfile)
157+
outfile.flush()
158+
outfile.close()
159+
while True:
160+
worker(sock)
161+
if not reuse:
162+
# wait for closing
163+
while sock.recv(1024):
164+
pass
165+
break
166+
gc.collect()
171167
except:
172168
traceback.print_exc()
173169
os._exit(1)

python/pyspark/serializers.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,8 @@ def _write_with_length(self, obj, stream):
144144

145145
def _read_with_length(self, stream):
146146
length = read_int(stream)
147+
if length == SpecialLengths.END_OF_DATA_SECTION:
148+
raise EOFError
147149
obj = stream.read(length)
148150
if obj == "":
149151
raise EOFError
@@ -431,6 +433,8 @@ class UTF8Deserializer(Serializer):
431433

432434
def loads(self, stream):
433435
length = read_int(stream)
436+
if length == SpecialLengths.END_OF_DATA_SECTION:
437+
raise EOFError
434438
return stream.read(length).decode('utf8')
435439

436440
def load_stream(self, stream):

python/run-tests

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ echo "Running PySpark tests. Output is in python/unit-tests.log."
5050

5151
# Try to test with Python 2.6, since that's the minimum version that we support:
5252
if [ $(which python2.6) ]; then
53-
export PYSPARK_PYTHON="python2.6"
53+
export PYSPARK_PYTHON="pypy"
5454
fi
5555

5656
echo "Testing with Python version:"

0 commit comments

Comments
 (0)