Commit fc1c4e7

[SPARK-25253][PYSPARK] Refactor local connection & auth code
This eliminates some duplication in the code that connects to a server on localhost to talk directly to the JVM, and it gives consistent IPv6 and error handling. Two other incidental changes that shouldn't matter:

1) Python barrier tasks perform authentication immediately (rather than waiting for the BARRIER_FUNCTION indicator).
2) For `rdd._load_from_socket`, the timeout is only increased after authentication.

Closes #22247 from squito/py_connection_refactor.

Authored-by: Imran Rashid <[email protected]>
Signed-off-by: hyukjinkwon <[email protected]>
(cherry picked from commit 38391c9)
(cherry picked from commit a2a54a5)
1 parent bd12eb7 commit fc1c4e7
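
The heart of the change is the new helper, `local_connect_and_auth`, which every call site now shares. A minimal sketch of the calling pattern, using only names from the diffs below (the `port` and `auth_secret` values here are placeholders; real callers get them from the JVM or from environment variables):

    from pyspark.java_gateway import local_connect_and_auth

    port, auth_secret = 55555, "secret"  # placeholder values for illustration only
    # One call replaces the old hand-rolled socket code at each site: it resolves
    # localhost (IPv4 or IPv6), connects with a 15-second timeout, runs the auth
    # handshake, and returns a buffered file object plus the raw socket.
    (sockfile, sock) = local_connect_and_auth(port, auth_secret)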

File tree

3 files changed: +35 −28 lines

python/pyspark/java_gateway.py

Lines changed: 31 additions & 1 deletion
@@ -133,7 +133,7 @@ def killChild():
     return gateway
 
 
-def do_server_auth(conn, auth_secret):
+def _do_server_auth(conn, auth_secret):
     """
     Performs the authentication protocol defined by the SocketAuthHelper class on the given
     file-like object 'conn'.
@@ -144,3 +144,33 @@ def do_server_auth(conn, auth_secret):
     if reply != "ok":
         conn.close()
         raise Exception("Unexpected reply from iterator server.")
+
+
+def local_connect_and_auth(port, auth_secret):
+    """
+    Connect to local host, authenticate with it, and return a (sockfile, sock) for that connection.
+    Handles IPv4 & IPv6, does some error handling.
+    :param port: the port to connect to on localhost
+    :param auth_secret: the secret to authenticate with
+    :return: a tuple with (sockfile, sock)
+    """
+    sock = None
+    errors = []
+    # Support for both IPv4 and IPv6.
+    # On most IPv6-ready systems, IPv6 will take precedence.
+    for res in socket.getaddrinfo("127.0.0.1", port, socket.AF_UNSPEC, socket.SOCK_STREAM):
+        af, socktype, proto, _, sa = res
+        try:
+            sock = socket.socket(af, socktype, proto)
+            sock.settimeout(15)
+            sock.connect(sa)
+            sockfile = sock.makefile("rwb", 65536)
+            _do_server_auth(sockfile, auth_secret)
+            return (sockfile, sock)
+        except socket.error as e:
+            emsg = _exception_message(e)
+            errors.append("tried to connect to %s, but an error occurred: %s" % (sa, emsg))
+            sock.close()
+            sock = None
+    else:
+        raise Exception("could not open socket: %s" % errors)

python/pyspark/rdd.py

Lines changed: 2 additions & 22 deletions
@@ -39,7 +39,7 @@
 else:
     from itertools import imap as map, ifilter as filter
 
-from pyspark.java_gateway import do_server_auth
+from pyspark.java_gateway import local_connect_and_auth
 from pyspark.serializers import NoOpSerializer, CartesianDeserializer, \
     BatchedSerializer, CloudPickleSerializer, PairDeserializer, \
     PickleSerializer, pack_long, AutoBatchedSerializer, write_with_length, \
@@ -122,30 +122,10 @@ def _parse_memory(s):
 
 
 def _load_from_socket(sock_info, serializer):
-    port, auth_secret = sock_info
-    sock = None
-    # Support for both IPv4 and IPv6.
-    # On most IPv6-ready systems, IPv6 will take precedence.
-    for res in socket.getaddrinfo("localhost", port, socket.AF_UNSPEC, socket.SOCK_STREAM):
-        af, socktype, proto, canonname, sa = res
-        sock = socket.socket(af, socktype, proto)
-        try:
-            sock.settimeout(15)
-            sock.connect(sa)
-        except socket.error:
-            sock.close()
-            sock = None
-            continue
-        break
-    if not sock:
-        raise Exception("could not open socket")
+    (sockfile, sock) = local_connect_and_auth(*sock_info)
     # The RDD materialization time is unpredictable; if we set a timeout for the socket read
     # operation, it will very possibly fail. See SPARK-18281.
     sock.settimeout(None)
-
-    sockfile = sock.makefile("rwb", 65536)
-    do_server_auth(sockfile, auth_secret)
-
     # The socket will be automatically closed when garbage-collected.
     return serializer.load_stream(sockfile)
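
Reading the new `_load_from_socket` forward rather than as a diff, the whole function collapses to a few lines, and the ordering encodes incidental change (2) from the commit message: the read timeout is lifted only after `local_connect_and_auth` completes the handshake, which runs under the helper's 15-second connect timeout:

    def _load_from_socket(sock_info, serializer):
        # Connect and authenticate with the helper's 15-second timeout in force.
        (sockfile, sock) = local_connect_and_auth(*sock_info)
        # Then read without a timeout: RDD materialization time is unpredictable
        # (SPARK-18281). The socket is closed when garbage-collected.
        sock.settimeout(None)
        return serializer.load_stream(sockfile)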

python/pyspark/worker.py

Lines changed: 2 additions & 5 deletions
@@ -27,7 +27,7 @@
 
 from pyspark.accumulators import _accumulatorRegistry
 from pyspark.broadcast import Broadcast, _broadcastRegistry
-from pyspark.java_gateway import do_server_auth
+from pyspark.java_gateway import local_connect_and_auth
 from pyspark.taskcontext import TaskContext
 from pyspark.files import SparkFiles
 from pyspark.serializers import write_with_length, write_int, read_long, \
@@ -212,8 +212,5 @@ def process():
     # Read information about how to connect back to the JVM from the environment.
     java_port = int(os.environ["PYTHON_WORKER_FACTORY_PORT"])
     auth_secret = os.environ["PYTHON_WORKER_FACTORY_SECRET"]
-    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
-    sock.connect(("127.0.0.1", java_port))
-    sock_file = sock.makefile("rwb", 65536)
-    do_server_auth(sock_file, auth_secret)
+    (sock_file, _) = local_connect_and_auth(java_port, auth_secret)
     main(sock_file, sock_file)
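
The resulting connect-back sequence in worker.py, read without diff markers, is below. Dropping the hardcoded `AF_INET` socket in favor of the helper is what makes the worker's IPv6 behavior consistent with the other call sites; the raw socket is discarded because the worker only needs the buffered file object:

    # Read information about how to connect back to the JVM from the environment.
    java_port = int(os.environ["PYTHON_WORKER_FACTORY_PORT"])
    auth_secret = os.environ["PYTHON_WORKER_FACTORY_SECRET"]
    (sock_file, _) = local_connect_and_auth(java_port, auth_secret)
    main(sock_file, sock_file)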
