Commit ef4ff00

aarondav authored and JoshRosen committed
SPARK-2282: Reuse Socket for sending accumulator updates to Pyspark
Prior to this change, every PySpark task completion opened a new socket to the accumulator server, passed its updates through, and then quit. I'm not entirely sure why PySpark always sends accumulator updates, but regardless this causes a very rapid buildup of ephemeral TCP connections that remain in the TIME_WAIT state for around a minute before being cleaned up.

Rather than trying to allow these sockets to be cleaned up faster, this patch simply reuses the connection between task completions (since they're fed updates in a single-threaded manner by the DAGScheduler anyway). The only tricky part here was making sure that the AccumulatorServer was able to shut down in a timely manner (i.e., stop polling for new data), and this was accomplished via minor feats of magic.

I have confirmed that this patch eliminates the buildup of ephemeral sockets due to the accumulator updates. However, I did note that there were still a significant number of sockets being created against the PySpark daemon port, but my machine was not able to create enough sockets fast enough to fail. This may not be the last time we've seen this issue, though.

Author: Aaron Davidson <[email protected]>

Closes #1503 from aarondav/accum and squashes the following commits:

b3e12f7 [Aaron Davidson] SPARK-2282: Reuse Socket for sending accumulator updates to Pyspark
1 parent 492a195 commit ef4ff00
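
The essence of the fix on the JVM side is connection reuse: rather than opening a throwaway socket per task completion (each of which then lingers in TIME_WAIT), the accumulator param keeps a single socket and lazily reopens it only if it is missing or has been closed. Below is a minimal standalone sketch of that pattern, written in Python rather than Scala; the class and method names are illustrative and not part of Spark.

    import socket

    class ReusableConnection:
        """Keep one TCP connection alive across many calls instead of creating
        a throwaway socket per call, which piles up TIME_WAIT entries."""

        def __init__(self, host, port):
            self.host, self.port = host, port
            self._sock = None

        def _socket(self):
            # Lazily connect, and reconnect only if we closed the old socket
            # (mirroring the socket.isClosed check in the Scala change below).
            if self._sock is None or self._sock.fileno() == -1:
                self._sock = socket.create_connection((self.host, self.port))
            return self._sock

        def send(self, payload):
            # Every call reuses the same underlying connection.
            self._socket().sendall(payload)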

2 files changed: +42 lines added, -12 lines removed


core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala

Lines changed: 15 additions & 5 deletions
@@ -731,19 +731,30 @@ private class PythonAccumulatorParam(@transient serverHost: String, serverPort:
 
   val bufferSize = SparkEnv.get.conf.getInt("spark.buffer.size", 65536)
 
+  /**
+   * We try to reuse a single Socket to transfer accumulator updates, as they are all added
+   * by the DAGScheduler's single-threaded actor anyway.
+   */
+  @transient var socket: Socket = _
+
+  def openSocket(): Socket = synchronized {
+    if (socket == null || socket.isClosed) {
+      socket = new Socket(serverHost, serverPort)
+    }
+    socket
+  }
+
   override def zero(value: JList[Array[Byte]]): JList[Array[Byte]] = new JArrayList
 
   override def addInPlace(val1: JList[Array[Byte]], val2: JList[Array[Byte]])
-      : JList[Array[Byte]] = {
+      : JList[Array[Byte]] = synchronized {
     if (serverHost == null) {
       // This happens on the worker node, where we just want to remember all the updates
       val1.addAll(val2)
       val1
     } else {
       // This happens on the master, where we pass the updates to Python through a socket
-      val socket = new Socket(serverHost, serverPort)
-      // SPARK-2282: Immediately reuse closed sockets because we create one per task.
-      socket.setReuseAddress(true)
+      val socket = openSocket()
       val in = socket.getInputStream
       val out = new DataOutputStream(new BufferedOutputStream(socket.getOutputStream, bufferSize))
       out.writeInt(val2.size)
@@ -757,7 +768,6 @@ private class PythonAccumulatorParam(@transient serverHost: String, serverPort:
       if (byteRead == -1) {
         throw new SparkException("EOF reached before Python server acknowledged")
       }
-      socket.close()
       null
     }
   }
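
Between the two hunks above (in lines the diff does not show), the JVM writes each batch of updates to the now-reused socket: a 4-byte big-endian count of updates, then each pickled update, and finally it blocks on a one-byte acknowledgement from Python. A rough Python sketch of a client speaking that framing follows; the helper name send_updates is illustrative, and the per-update layout (4-byte length prefix, then the pickled bytes) is an assumption inferred from pickleSer._read_with_length on the Python side.

    import socket
    import struct

    def send_updates(sock, updates):
        """Send pickled accumulator updates over an already-open socket and
        wait for the server's one-byte acknowledgement.

        `sock` is a connected socket.socket; `updates` is a list of bytes
        objects (pickled accumulator values). Illustrative, not a Spark API.
        """
        out = sock.makefile("wb")
        out.write(struct.pack("!i", len(updates)))   # number of updates
        for blob in updates:
            out.write(struct.pack("!i", len(blob)))  # length prefix...
            out.write(blob)                          # ...then the pickled payload
        out.flush()
        ack = sock.recv(1)                           # block until Python acks
        if not ack:
            raise EOFError("EOF reached before Python server acknowledged")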

python/pyspark/accumulators.py

Lines changed: 27 additions & 7 deletions
@@ -86,6 +86,7 @@
 Exception:...
 """
 
+import select
 import struct
 import SocketServer
 import threading
@@ -209,19 +210,38 @@ def addInPlace(self, value1, value2):
 
 
 class _UpdateRequestHandler(SocketServer.StreamRequestHandler):
+    """
+    This handler will keep polling updates from the same socket until the
+    server is shutdown.
+    """
+
     def handle(self):
         from pyspark.accumulators import _accumulatorRegistry
-        num_updates = read_int(self.rfile)
-        for _ in range(num_updates):
-            (aid, update) = pickleSer._read_with_length(self.rfile)
-            _accumulatorRegistry[aid] += update
-        # Write a byte in acknowledgement
-        self.wfile.write(struct.pack("!b", 1))
+        while not self.server.server_shutdown:
+            # Poll every 1 second for new data -- don't block in case of shutdown.
+            r, _, _ = select.select([self.rfile], [], [], 1)
+            if self.rfile in r:
+                num_updates = read_int(self.rfile)
+                for _ in range(num_updates):
+                    (aid, update) = pickleSer._read_with_length(self.rfile)
+                    _accumulatorRegistry[aid] += update
+                # Write a byte in acknowledgement
+                self.wfile.write(struct.pack("!b", 1))
+
+class AccumulatorServer(SocketServer.TCPServer):
+    """
+    A simple TCP server that intercepts shutdown() in order to interrupt
+    our continuous polling on the handler.
+    """
+    server_shutdown = False
 
+    def shutdown(self):
+        self.server_shutdown = True
+        SocketServer.TCPServer.shutdown(self)
 
 def _start_update_server():
     """Start a TCP server to receive accumulator updates in a daemon thread, and returns it"""
-    server = SocketServer.TCPServer(("localhost", 0), _UpdateRequestHandler)
+    server = AccumulatorServer(("localhost", 0), _UpdateRequestHandler)
     thread = threading.Thread(target=server.serve_forever)
     thread.daemon = True
     thread.start()