
Commit 2aea0da

davies authored and JoshRosen committed
[SPARK-3030] [PySpark] Reuse Python worker
Reuse the Python worker to avoid the overhead of fork()ing a Python process for each task. The factory also tracks the broadcasts already sent to each worker, so repeated broadcasts are not re-sent. This reduces the time for a dummy task from 22ms to 13ms (-40%) and helps reduce latency for Spark Streaming.

For a job with a broadcast variable (43M after compression):

```
b = sc.broadcast(set(range(30000000)))
print sc.parallelize(range(24000), 100).filter(lambda x: x in b.value).count()
```

It finishes in 281s without worker reuse and in 65s with worker reuse (4 CPUs). After reusing the worker, it saves about 9 seconds of transferring and deserializing the broadcast for each task.

Worker reuse is enabled by default and can be disabled with `spark.python.worker.reuse = false`.

Author: Davies Liu <[email protected]>

Closes #2259 from davies/reuse-worker and squashes the following commits:

f11f617 [Davies Liu] Merge branch 'master' into reuse-worker
3939f20 [Davies Liu] fix bug in serializer in mllib
cf1c55e [Davies Liu] address comments
3133a60 [Davies Liu] fix accumulator with reused worker
760ab1f [Davies Liu] do not reuse worker if there are any exceptions
7abb224 [Davies Liu] refactor: sychronized with itself
ac3206e [Davies Liu] renaming
8911f44 [Davies Liu] synchronized getWorkerBroadcasts()
6325fc1 [Davies Liu] bugfix: bid >= 0
e0131a2 [Davies Liu] fix name of config
583716e [Davies Liu] only reuse completed and not interrupted worker
ace2917 [Davies Liu] kill python worker after timeout
6123d0f [Davies Liu] track broadcasts for each worker
8d2f08c [Davies Liu] reuse python worker
1 parent 0f8c4ed commit 2aea0da
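
As context for the numbers quoted above, a small PySpark harness in the same spirit could toggle the new flag per run and time the job. This is only a sketch, not the author's benchmark script: the `run` helper and the smaller broadcast size are invented here.

```
import time

from pyspark import SparkConf, SparkContext


def run(reuse):
    # Toggle the flag introduced by this commit; "true" is the default.
    conf = (SparkConf().setAppName("reuse-worker-bench")
            .set("spark.python.worker.reuse", "true" if reuse else "false"))
    sc = SparkContext(conf=conf)
    b = sc.broadcast(set(range(1000000)))  # smaller than the 30M-element set above
    start = time.time()
    n = sc.parallelize(range(24000), 100).filter(lambda x: x in b.value).count()
    print "reuse=%s count=%d took %.1fs" % (reuse, n, time.time() - start)
    sc.stop()


run(True)
run(False)
```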

File tree

9 files changed: +208, -51 lines


core/src/main/scala/org/apache/spark/SparkEnv.scala

Lines changed: 8 additions & 0 deletions

@@ -108,6 +108,14 @@ class SparkEnv (
       pythonWorkers.get(key).foreach(_.stopWorker(worker))
     }
   }
+
+  private[spark]
+  def releasePythonWorker(pythonExec: String, envVars: Map[String, String], worker: Socket) {
+    synchronized {
+      val key = (pythonExec, envVars)
+      pythonWorkers.get(key).foreach(_.releaseWorker(worker))
+    }
+  }
 }
 
 object SparkEnv extends Logging {

core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala

Lines changed: 46 additions & 12 deletions

@@ -23,6 +23,7 @@ import java.nio.charset.Charset
 import java.util.{List => JList, ArrayList => JArrayList, Map => JMap, Collections}
 
 import scala.collection.JavaConversions._
+import scala.collection.mutable
 import scala.language.existentials
 import scala.reflect.ClassTag
 import scala.util.{Try, Success, Failure}
@@ -52,6 +53,7 @@ private[spark] class PythonRDD(
   extends RDD[Array[Byte]](parent) {
 
   val bufferSize = conf.getInt("spark.buffer.size", 65536)
+  val reuse_worker = conf.getBoolean("spark.python.worker.reuse", true)
 
   override def getPartitions = parent.partitions
 
@@ -63,19 +65,26 @@ private[spark] class PythonRDD(
     val localdir = env.blockManager.diskBlockManager.localDirs.map(
       f => f.getPath()).mkString(",")
     envVars += ("SPARK_LOCAL_DIRS" -> localdir) // it's also used in monitor thread
+    if (reuse_worker) {
+      envVars += ("SPARK_REUSE_WORKER" -> "1")
+    }
     val worker: Socket = env.createPythonWorker(pythonExec, envVars.toMap)
 
     // Start a thread to feed the process input from our parent's iterator
     val writerThread = new WriterThread(env, worker, split, context)
 
+    var complete_cleanly = false
     context.addTaskCompletionListener { context =>
       writerThread.shutdownOnTaskCompletion()
-
-      // Cleanup the worker socket. This will also cause the Python worker to exit.
-      try {
-        worker.close()
-      } catch {
-        case e: Exception => logWarning("Failed to close worker socket", e)
+      if (reuse_worker && complete_cleanly) {
+        env.releasePythonWorker(pythonExec, envVars.toMap, worker)
+      } else {
+        try {
+          worker.close()
+        } catch {
+          case e: Exception =>
+            logWarning("Failed to close worker socket", e)
+        }
       }
     }
 
@@ -133,6 +142,7 @@ private[spark] class PythonRDD(
             stream.readFully(update)
             accumulator += Collections.singletonList(update)
           }
+          complete_cleanly = true
          null
        }
      } catch {
@@ -195,29 +205,45 @@ private[spark] class PythonRDD(
           PythonRDD.writeUTF(include, dataOut)
         }
         // Broadcast variables
-        dataOut.writeInt(broadcastVars.length)
+        val oldBids = PythonRDD.getWorkerBroadcasts(worker)
+        val newBids = broadcastVars.map(_.id).toSet
+        // number of different broadcasts
+        val cnt = oldBids.diff(newBids).size + newBids.diff(oldBids).size
+        dataOut.writeInt(cnt)
+        for (bid <- oldBids) {
+          if (!newBids.contains(bid)) {
+            // remove the broadcast from worker
+            dataOut.writeLong(- bid - 1)  // bid >= 0
+            oldBids.remove(bid)
+          }
+        }
         for (broadcast <- broadcastVars) {
-          dataOut.writeLong(broadcast.id)
-          dataOut.writeInt(broadcast.value.length)
-          dataOut.write(broadcast.value)
+          if (!oldBids.contains(broadcast.id)) {
+            // send new broadcast
+            dataOut.writeLong(broadcast.id)
+            dataOut.writeInt(broadcast.value.length)
+            dataOut.write(broadcast.value)
+            oldBids.add(broadcast.id)
+          }
         }
         dataOut.flush()
         // Serialized command:
         dataOut.writeInt(command.length)
         dataOut.write(command)
         // Data values
         PythonRDD.writeIteratorToStream(parent.iterator(split, context), dataOut)
+        dataOut.writeInt(SpecialLengths.END_OF_DATA_SECTION)
         dataOut.flush()
       } catch {
         case e: Exception if context.isCompleted || context.isInterrupted =>
           logDebug("Exception thrown after task completion (likely due to cleanup)", e)
+          worker.shutdownOutput()
 
         case e: Exception =>
           // We must avoid throwing exceptions here, because the thread uncaught exception handler
           // will kill the whole executor (see org.apache.spark.executor.Executor).
           _exception = e
-      } finally {
-        Try(worker.shutdownOutput()) // kill Python worker process
+          worker.shutdownOutput()
       }
     }
   }
@@ -278,6 +304,14 @@ private object SpecialLengths {
 private[spark] object PythonRDD extends Logging {
   val UTF8 = Charset.forName("UTF-8")
 
+  // remember the broadcasts sent to each worker
+  private val workerBroadcasts = new mutable.WeakHashMap[Socket, mutable.Set[Long]]()
+  private def getWorkerBroadcasts(worker: Socket) = {
+    synchronized {
+      workerBroadcasts.getOrElseUpdate(worker, new mutable.HashSet[Long]())
+    }
+  }
+
   /**
    * Adapter for calling SparkContext#runJob from Python.
   *
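
The broadcast bookkeeping above changes what a long-lived worker sees on the wire: only the delta of broadcasts is sent, and a removal is encoded as the negative value `-bid - 1` (safe because `bid >= 0`). A rough decoder for that framing, written as an illustrative sketch rather than the actual `pyspark/worker.py` code (the real worker also unpickles the payload):

```
import struct


def read_int(stream):
    return struct.unpack("!i", stream.read(4))[0]


def read_long(stream):
    return struct.unpack("!q", stream.read(8))[0]


def decode_broadcast_delta(stream, registry):
    """registry: dict of broadcast id -> value kept by a reused worker."""
    num_changes = read_int(stream)
    for _ in range(num_changes):
        bid = read_long(stream)
        if bid < 0:                      # negative id means "drop this broadcast"
            registry.pop(-bid - 1, None)
        else:                            # otherwise a new broadcast body follows inline
            length = read_int(stream)
            registry[bid] = stream.read(length)
```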

core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala

Lines changed: 74 additions & 11 deletions

@@ -40,7 +40,10 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String
   var daemon: Process = null
   val daemonHost = InetAddress.getByAddress(Array(127, 0, 0, 1))
   var daemonPort: Int = 0
-  var daemonWorkers = new mutable.WeakHashMap[Socket, Int]()
+  val daemonWorkers = new mutable.WeakHashMap[Socket, Int]()
+  val idleWorkers = new mutable.Queue[Socket]()
+  var lastActivity = 0L
+  new MonitorThread().start()
 
   var simpleWorkers = new mutable.WeakHashMap[Socket, Process]()
 
@@ -51,6 +54,11 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String
 
   def create(): Socket = {
     if (useDaemon) {
+      synchronized {
+        if (idleWorkers.size > 0) {
+          return idleWorkers.dequeue()
+        }
+      }
       createThroughDaemon()
     } else {
       createSimpleWorker()
@@ -199,9 +207,44 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String
     }
   }
 
+  /**
+   * Monitor all the idle workers, kill them after timeout.
+   */
+  private class MonitorThread extends Thread(s"Idle Worker Monitor for $pythonExec") {
+
+    setDaemon(true)
+
+    override def run() {
+      while (true) {
+        synchronized {
+          if (lastActivity + IDLE_WORKER_TIMEOUT_MS < System.currentTimeMillis()) {
+            cleanupIdleWorkers()
+            lastActivity = System.currentTimeMillis()
+          }
+        }
+        Thread.sleep(10000)
+      }
+    }
+  }
+
+  private def cleanupIdleWorkers() {
+    while (idleWorkers.length > 0) {
+      val worker = idleWorkers.dequeue()
+      try {
+        // the worker will exit after closing the socket
+        worker.close()
+      } catch {
+        case e: Exception =>
+          logWarning("Failed to close worker socket", e)
+      }
+    }
+  }
+
   private def stopDaemon() {
     synchronized {
       if (useDaemon) {
+        cleanupIdleWorkers()
+
         // Request shutdown of existing daemon by sending SIGTERM
         if (daemon != null) {
           daemon.destroy()
@@ -220,23 +263,43 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String
   }
 
   def stopWorker(worker: Socket) {
-    if (useDaemon) {
-      if (daemon != null) {
-        daemonWorkers.get(worker).foreach { pid =>
-          // tell daemon to kill worker by pid
-          val output = new DataOutputStream(daemon.getOutputStream)
-          output.writeInt(pid)
-          output.flush()
-          daemon.getOutputStream.flush()
+    synchronized {
+      if (useDaemon) {
+        if (daemon != null) {
+          daemonWorkers.get(worker).foreach { pid =>
+            // tell daemon to kill worker by pid
+            val output = new DataOutputStream(daemon.getOutputStream)
+            output.writeInt(pid)
+            output.flush()
+            daemon.getOutputStream.flush()
+          }
         }
+      } else {
+        simpleWorkers.get(worker).foreach(_.destroy())
       }
-    } else {
-      simpleWorkers.get(worker).foreach(_.destroy())
     }
     worker.close()
   }
+
+  def releaseWorker(worker: Socket) {
+    if (useDaemon) {
+      synchronized {
+        lastActivity = System.currentTimeMillis()
+        idleWorkers.enqueue(worker)
+      }
+    } else {
+      // Cleanup the worker socket. This will also cause the Python worker to exit.
+      try {
+        worker.close()
+      } catch {
+        case e: Exception =>
+          logWarning("Failed to close worker socket", e)
+      }
+    }
+  }
 }
 
 private object PythonWorkerFactory {
   val PROCESS_WAIT_TIMEOUT_MS = 10000
+  val IDLE_WORKER_TIMEOUT_MS = 60000  // kill idle workers after 1 minute
 }
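
The factory now parks released workers in `idleWorkers`, and a daemon `MonitorThread` closes them once nothing has happened for `IDLE_WORKER_TIMEOUT_MS`. A compact sketch of that pattern in Python (names such as `IdlePool` are invented for illustration; the real implementation is the Scala above):

```
import threading
import time

IDLE_TIMEOUT_S = 60          # analogous to IDLE_WORKER_TIMEOUT_MS


class IdlePool(object):
    """Park reusable worker sockets and close them all after a quiet period."""

    def __init__(self):
        self.idle = []
        self.last_activity = time.time()
        self.lock = threading.Lock()
        t = threading.Thread(target=self._monitor)
        t.daemon = True
        t.start()

    def acquire(self):
        with self.lock:
            if self.idle:
                return self.idle.pop()
        return None              # caller falls back to forking a new worker

    def release(self, worker):
        with self.lock:
            self.last_activity = time.time()
            self.idle.append(worker)

    def _monitor(self):
        while True:
            with self.lock:
                if time.time() - self.last_activity > IDLE_TIMEOUT_S:
                    while self.idle:
                        self.idle.pop().close()   # worker exits when its socket closes
                    self.last_activity = time.time()
            time.sleep(10)
```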

docs/configuration.md

Lines changed: 10 additions & 0 deletions

@@ -206,6 +206,16 @@ Apart from these, the following properties are also available, and may be useful
   used during aggregation goes above this amount, it will spill the data into disks.
   </td>
 </tr>
+<tr>
+  <td><code>spark.python.worker.reuse</code></td>
+  <td>true</td>
+  <td>
+    Reuse Python worker or not. If yes, it will use a fixed number of Python workers and
+    does not need to fork() a Python process for every task. This is most useful when
+    there is a large broadcast, since the broadcast then does not need to be transferred
+    from the JVM to the Python worker for every task.
+  </td>
+</tr>
 <tr>
   <td><code>spark.executorEnv.[EnvironmentVariableName]</code></td>
   <td>(none)</td>
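
A minimal way to override the documented default from application code (the property can equally be set in `spark-defaults.conf`); this is just standard `SparkConf` usage, not code from the commit:

```
from pyspark import SparkConf, SparkContext

# Disable reuse to get the pre-commit behaviour of one fork() per task.
conf = SparkConf().set("spark.python.worker.reuse", "false")
sc = SparkContext(conf=conf)
```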

python/pyspark/daemon.py

Lines changed: 19 additions & 19 deletions

@@ -23,6 +23,7 @@
 import sys
 import traceback
 import time
+import gc
 from errno import EINTR, ECHILD, EAGAIN
 from socket import AF_INET, SOCK_STREAM, SOMAXCONN
 from signal import SIGHUP, SIGTERM, SIGCHLD, SIG_DFL, SIG_IGN
@@ -46,35 +47,20 @@ def worker(sock):
     signal.signal(SIGCHLD, SIG_DFL)
     signal.signal(SIGTERM, SIG_DFL)
 
-    # Blocks until the socket is closed by draining the input stream
-    # until it raises an exception or returns EOF.
-    def waitSocketClose(sock):
-        try:
-            while True:
-                # Empty string is returned upon EOF (and only then).
-                if sock.recv(4096) == '':
-                    return
-        except:
-            pass
-
     # Read the socket using fdopen instead of socket.makefile() because the latter
     # seems to be very slow; note that we need to dup() the file descriptor because
     # otherwise writes also cause a seek that makes us miss data on the read side.
     infile = os.fdopen(os.dup(sock.fileno()), "a+", 65536)
     outfile = os.fdopen(os.dup(sock.fileno()), "a+", 65536)
     exit_code = 0
     try:
-        # Acknowledge that the fork was successful
-        write_int(os.getpid(), outfile)
-        outfile.flush()
         worker_main(infile, outfile)
     except SystemExit as exc:
-        exit_code = exc.code
+        exit_code = compute_real_exit_code(exc.code)
     finally:
         outfile.flush()
-        # The Scala side will close the socket upon task completion.
-        waitSocketClose(sock)
-        os._exit(compute_real_exit_code(exit_code))
+        if exit_code:
+            os._exit(exit_code)
 
 
 # Cleanup zombie children
@@ -111,6 +97,8 @@ def handle_sigterm(*args):
     signal.signal(SIGTERM, handle_sigterm)  # Gracefully exit on SIGTERM
     signal.signal(SIGHUP, SIG_IGN)  # Don't die on SIGHUP
 
+    reuse = os.environ.get("SPARK_REUSE_WORKER")
+
     # Initialization complete
     try:
         while True:
@@ -163,7 +151,19 @@ def handle_sigterm(*args):
                 # in child process
                 listen_sock.close()
                 try:
-                    worker(sock)
+                    # Acknowledge that the fork was successful
+                    outfile = sock.makefile("w")
+                    write_int(os.getpid(), outfile)
+                    outfile.flush()
+                    outfile.close()
+                    while True:
+                        worker(sock)
+                        if not reuse:
+                            # wait for closing
+                            while sock.recv(1024):
+                                pass
+                            break
+                        gc.collect()
                 except:
                     traceback.print_exc()
                     os._exit(1)
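
Stripped of the daemon's forking and signal plumbing, the child's new control flow reduces to the loop below; `serve` and `handle_one_task` are invented stand-ins for the `worker(sock)` path above, shown only to make the reuse/one-shot branch explicit:

```
import gc


def serve(sock, reuse):
    # Child process: keep handling tasks over the same socket while reuse is on.
    while True:
        handle_one_task(sock)          # stand-in for worker(sock) above
        if not reuse:
            # single-shot mode: wait for the JVM to close the socket, then stop
            while sock.recv(1024):
                pass
            break
        gc.collect()                   # trim memory between tasks of a long-lived worker


def handle_one_task(sock):
    raise NotImplementedError("placeholder for running one task over the socket")
```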

python/pyspark/mllib/_common.py

Lines changed: 5 additions & 7 deletions

@@ -21,7 +21,7 @@
 from numpy import ndarray, float64, int64, int32, array_equal, array
 from pyspark import SparkContext, RDD
 from pyspark.mllib.linalg import SparseVector
-from pyspark.serializers import Serializer
+from pyspark.serializers import FramedSerializer
 
 
 """
@@ -451,18 +451,16 @@ def _serialize_rating(r):
     return ba
 
 
-class RatingDeserializer(Serializer):
+class RatingDeserializer(FramedSerializer):
 
-    def loads(self, stream):
-        length = struct.unpack("!i", stream.read(4))[0]
-        ba = stream.read(length)
-        res = ndarray(shape=(3, ), buffer=ba, dtype=float64, offset=4)
+    def loads(self, string):
+        res = ndarray(shape=(3, ), buffer=string, dtype=float64, offset=4)
         return int(res[0]), int(res[1]), res[2]
 
     def load_stream(self, stream):
         while True:
             try:
-                yield self.loads(stream)
+                yield self._read_with_length(stream)
             except struct.error:
                 return
             except EOFError:
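
The switch to `FramedSerializer` means each rating arrives as a length-prefixed frame, so `loads()` now receives the frame body rather than the raw stream. A self-contained illustration of that framing (not the actual `pyspark.serializers` code; a real framed serializer would also pass the bytes through `loads()`):

```
import struct
from io import BytesIO


def write_with_length(payload, stream):
    stream.write(struct.pack("!i", len(payload)))
    stream.write(payload)


def read_with_length(stream):
    header = stream.read(4)
    if not header:
        raise EOFError
    (length,) = struct.unpack("!i", header)
    return stream.read(length)   # the frame body a loads() method would receive


buf = BytesIO()
write_with_length(b"rating-bytes", buf)
buf.seek(0)
print read_with_length(buf)
```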
