Commit c232ec6

[SPARK-25253][PYSPARK] Refactor local connection & auth code
1 parent 6193a20 commit c232ec6

5 files changed: +41 -61 lines changed

core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala

Lines changed: 1 addition & 2 deletions

@@ -206,6 +206,7 @@ private[spark] abstract class BasePythonRunner[IN, OUT](
           sock = serverSocket.get.accept()
           // Wait for function call from python side.
           sock.setSoTimeout(10000)
+          authHelper.authClient(sock)
           val input = new DataInputStream(sock.getInputStream())
           input.readInt() match {
             case BarrierTaskContextMessageProtocol.BARRIER_FUNCTION =>
@@ -324,8 +325,6 @@ private[spark] abstract class BasePythonRunner[IN, OUT](
     def barrierAndServe(sock: Socket): Unit = {
       require(serverSocket.isDefined, "No available ServerSocket to redirect the barrier() call.")
 
-      authHelper.authClient(sock)
-
       val out = new DataOutputStream(new BufferedOutputStream(sock.getOutputStream))
       try {
         context.asInstanceOf[BarrierTaskContext].barrier()

python/pyspark/java_gateway.py

Lines changed: 34 additions & 1 deletion

@@ -134,7 +134,7 @@ def killChild():
     return gateway
 
 
-def do_server_auth(conn, auth_secret):
+def _do_server_auth(conn, auth_secret):
     """
     Performs the authentication protocol defined by the SocketAuthHelper class on the given
     file-like object 'conn'.
@@ -147,6 +147,39 @@ def do_server_auth(conn, auth_secret):
         raise Exception("Unexpected reply from iterator server.")
 
 
+def local_connect_and_auth(sock_info):
+    """
+    Connect to local host, authenticate with it, and return a (sockfile,sock) for that connection.
+    Handles IPV4 & IPV6, does some error handling.
+    :param sock_info: a tuple of (port, auth_secret) for connecting
+    :return: a tuple with (sockfile, sock)
+    """
+    port, auth_secret = sock_info
+    sock = None
+    errors = []
+    # Support for both IPv4 and IPv6.
+    # On most of IPv6-ready systems, IPv6 will take precedence.
+    for res in socket.getaddrinfo("127.0.0.1", port, socket.AF_UNSPEC, socket.SOCK_STREAM):
+        af, socktype, proto, canonname, sa = res
+        sock = socket.socket(af, socktype, proto)
+        try:
+            sock.settimeout(15)
+            sock.connect(sa)
+        except socket.error as e:
+            emsg = _exception_message(e)
+            errors.append("tried to connect to %s, but an error occured: %s" % (sa, emsg))
+            sock.close()
+            sock = None
+            continue
+        break
+    if not sock:
+        raise Exception("could not open socket: %s" % errors)
+
+    sockfile = sock.makefile("rwb", 65536)
+    _do_server_auth(sockfile, auth_secret)
+    return (sockfile, sock)
+
+
 def ensure_callback_server_started(gw):
     """
     Start callback server if not already started. The callback server is needed if the Java
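For orientation, a minimal usage sketch of the new helper (not part of the commit): connect to a local JVM-side server, authenticate, read one UTF-8 string, and close the socket. The function name read_one_string and the port/secret arguments are hypothetical; only local_connect_and_auth (added above) and UTF8Deserializer come from PySpark.

from pyspark.java_gateway import local_connect_and_auth
from pyspark.serializers import UTF8Deserializer

def read_one_string(port, auth_secret):
    # Hypothetical helper: connect and authenticate in one call, then read a
    # single UTF-8 string written by the JVM side and close the socket.
    sockfile, sock = local_connect_and_auth((port, auth_secret))
    try:
        return UTF8Deserializer().loads(sockfile)
    finally:
        sock.close()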

python/pyspark/rdd.py

Lines changed: 2 additions & 25 deletions

@@ -39,7 +39,7 @@
 else:
     from itertools import imap as map, ifilter as filter
 
-from pyspark.java_gateway import do_server_auth
+from pyspark.java_gateway import local_connect_and_auth
 from pyspark.serializers import NoOpSerializer, CartesianDeserializer, \
     BatchedSerializer, CloudPickleSerializer, PairDeserializer, \
     PickleSerializer, pack_long, AutoBatchedSerializer, write_with_length, \
@@ -141,33 +141,10 @@ def _parse_memory(s):
 
 
 def _load_from_socket(sock_info, serializer):
-    port, auth_secret = sock_info
-    sock = None
-    errors = []
-    # Support for both IPv4 and IPv6.
-    # On most of IPv6-ready systems, IPv6 will take precedence.
-    for res in socket.getaddrinfo("localhost", port, socket.AF_UNSPEC, socket.SOCK_STREAM):
-        af, socktype, proto, canonname, sa = res
-        sock = socket.socket(af, socktype, proto)
-        try:
-            sock.settimeout(15)
-            sock.connect(sa)
-        except socket.error as e:
-            emsg = _exception_message(e)
-            errors.append("tried to connect to %s, but an error occured: %s" % (sa, emsg))
-            sock.close()
-            sock = None
-            continue
-        break
-    if not sock:
-        raise Exception("could not open socket: %s" % errors)
+    (sockfile, sock) = local_connect_and_auth(sock_info)
     # The RDD materialization time is unpredicable, if we set a timeout for socket reading
     # operation, it will very possibly fail. See SPARK-18281.
     sock.settimeout(None)
-
-    sockfile = sock.makefile("rwb", 65536)
-    do_server_auth(sockfile, auth_secret)
-
     # The socket will be automatically closed when garbage-collected.
     return serializer.load_stream(sockfile)
 

python/pyspark/taskcontext.py

Lines changed: 2 additions & 28 deletions

@@ -18,7 +18,7 @@
 from __future__ import print_function
 import socket
 
-from pyspark.java_gateway import do_server_auth
+from pyspark.java_gateway import local_connect_and_auth
 from pyspark.serializers import write_int, UTF8Deserializer
 
 
@@ -108,38 +108,12 @@ def _load_from_socket(port, auth_secret):
     """
     Load data from a given socket, this is a blocking method thus only return when the socket
     connection has been closed.
-
-    This is copied from context.py, while modified the message protocol.
     """
-    sock = None
-    # Support for both IPv4 and IPv6.
-    # On most of IPv6-ready systems, IPv6 will take precedence.
-    for res in socket.getaddrinfo("localhost", port, socket.AF_UNSPEC, socket.SOCK_STREAM):
-        af, socktype, proto, canonname, sa = res
-        sock = socket.socket(af, socktype, proto)
-        try:
-            # Do not allow timeout for socket reading operation.
-            sock.settimeout(None)
-            sock.connect(sa)
-        except socket.error:
-            sock.close()
-            sock = None
-            continue
-        break
-    if not sock:
-        raise Exception("could not open socket")
-
-    # We don't really need a socket file here, it's just for convenience that we can reuse the
-    # do_server_auth() function and data serialization methods.
-    sockfile = sock.makefile("rwb", 65536)
-
+    (sockfile, sock) = local_connect_and_auth((port, auth_secret))
     # Make a barrier() function call.
     write_int(BARRIER_FUNCTION, sockfile)
     sockfile.flush()
 
-    # Do server auth.
-    do_server_auth(sockfile, auth_secret)
-
     # Collect result.
     res = UTF8Deserializer().loads(sockfile)
 
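As a reference point, a minimal sketch of the client-side barrier round-trip that the refactored _load_from_socket performs (not part of the commit; the function name call_barrier is hypothetical, the opcode is passed in rather than taken from the BARRIER_FUNCTION constant shown in the diff, and the socket is closed explicitly instead of being left to garbage collection).

from pyspark.java_gateway import local_connect_and_auth
from pyspark.serializers import write_int, UTF8Deserializer

def call_barrier(port, auth_secret, barrier_opcode):
    # Connect to the local barrier server and authenticate in one step.
    sockfile, sock = local_connect_and_auth((port, auth_secret))
    try:
        # Ask the JVM side to run barrier(), then block until it replies.
        write_int(barrier_opcode, sockfile)
        sockfile.flush()
        return UTF8Deserializer().loads(sockfile)
    finally:
        sock.close()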

python/pyspark/worker.py

Lines changed: 2 additions & 5 deletions

@@ -27,7 +27,7 @@
 
 from pyspark.accumulators import _accumulatorRegistry
 from pyspark.broadcast import Broadcast, _broadcastRegistry
-from pyspark.java_gateway import do_server_auth
+from pyspark.java_gateway import local_connect_and_auth
 from pyspark.taskcontext import BarrierTaskContext, TaskContext
 from pyspark.files import SparkFiles
 from pyspark.rdd import PythonEvalType
@@ -364,8 +364,5 @@ def process():
     # Read information about how to connect back to the JVM from the environment.
     java_port = int(os.environ["PYTHON_WORKER_FACTORY_PORT"])
     auth_secret = os.environ["PYTHON_WORKER_FACTORY_SECRET"]
-    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
-    sock.connect(("127.0.0.1", java_port))
-    sock_file = sock.makefile("rwb", 65536)
-    do_server_auth(sock_file, auth_secret)
+    (sock_file, _) = local_connect_and_auth((java_port, auth_secret))
     main(sock_file, sock_file)
