diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala
index 151c910bf1ae..a99e7f09455e 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala
@@ -206,6 +206,7 @@ private[spark] abstract class BasePythonRunner[IN, OUT](
               sock = serverSocket.get.accept()
               // Wait for function call from python side.
               sock.setSoTimeout(10000)
+              authHelper.authClient(sock)
               val input = new DataInputStream(sock.getInputStream())
               input.readInt() match {
                 case BarrierTaskContextMessageProtocol.BARRIER_FUNCTION =>
@@ -324,8 +325,6 @@ private[spark] abstract class BasePythonRunner[IN, OUT](
     def barrierAndServe(sock: Socket): Unit = {
       require(serverSocket.isDefined, "No available ServerSocket to redirect the barrier() call.")
 
-      authHelper.authClient(sock)
-
       val out = new DataOutputStream(new BufferedOutputStream(sock.getOutputStream))
       try {
         context.asInstanceOf[BarrierTaskContext].barrier()
diff --git a/python/pyspark/java_gateway.py b/python/pyspark/java_gateway.py
index fa2d5e8db716..b06503b53be9 100644
--- a/python/pyspark/java_gateway.py
+++ b/python/pyspark/java_gateway.py
@@ -134,7 +134,7 @@ def killChild():
     return gateway
 
 
-def do_server_auth(conn, auth_secret):
+def _do_server_auth(conn, auth_secret):
     """
     Performs the authentication protocol defined by the SocketAuthHelper class on the given
     file-like object 'conn'.
@@ -147,6 +147,36 @@ def do_server_auth(conn, auth_secret):
         raise Exception("Unexpected reply from iterator server.")
 
 
+def local_connect_and_auth(port, auth_secret):
+    """
+    Connect to local host, authenticate with it, and return a (sockfile,sock) for that connection.
+    Handles IPV4 & IPV6, does some error handling.
+    :param port
+    :param auth_secret
+    :return: a tuple with (sockfile, sock)
+    """
+    sock = None
+    errors = []
+    # Support for both IPv4 and IPv6.
+    # On most of IPv6-ready systems, IPv6 will take precedence.
+    for res in socket.getaddrinfo("127.0.0.1", port, socket.AF_UNSPEC, socket.SOCK_STREAM):
+        af, socktype, proto, _, sa = res
+        try:
+            sock = socket.socket(af, socktype, proto)
+            sock.settimeout(15)
+            sock.connect(sa)
+            sockfile = sock.makefile("rwb", 65536)
+            _do_server_auth(sockfile, auth_secret)
+            return (sockfile, sock)
+        except socket.error as e:
+            emsg = _exception_message(e)
+            errors.append("tried to connect to %s, but an error occured: %s" % (sa, emsg))
+            sock.close()
+            sock = None
+    else:
+        raise Exception("could not open socket: %s" % errors)
+
+
 def ensure_callback_server_started(gw):
     """
     Start callback server if not already started. The callback server is needed if the Java
diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index b061074a28ab..380475e706fb 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -39,7 +39,7 @@
 else:
     from itertools import imap as map, ifilter as filter
 
-from pyspark.java_gateway import do_server_auth
+from pyspark.java_gateway import local_connect_and_auth
 from pyspark.serializers import NoOpSerializer, CartesianDeserializer, \
     BatchedSerializer, CloudPickleSerializer, PairDeserializer, \
     PickleSerializer, pack_long, AutoBatchedSerializer, write_with_length, \
@@ -141,33 +141,10 @@ def _parse_memory(s):
 
 
 def _load_from_socket(sock_info, serializer):
-    port, auth_secret = sock_info
-    sock = None
-    errors = []
-    # Support for both IPv4 and IPv6.
-    # On most of IPv6-ready systems, IPv6 will take precedence.
-    for res in socket.getaddrinfo("localhost", port, socket.AF_UNSPEC, socket.SOCK_STREAM):
-        af, socktype, proto, canonname, sa = res
-        sock = socket.socket(af, socktype, proto)
-        try:
-            sock.settimeout(15)
-            sock.connect(sa)
-        except socket.error as e:
-            emsg = _exception_message(e)
-            errors.append("tried to connect to %s, but an error occured: %s" % (sa, emsg))
-            sock.close()
-            sock = None
-            continue
-        break
-    if not sock:
-        raise Exception("could not open socket: %s" % errors)
+    (sockfile, sock) = local_connect_and_auth(*sock_info)
     # The RDD materialization time is unpredicable, if we set a timeout for socket reading
     # operation, it will very possibly fail. See SPARK-18281.
     sock.settimeout(None)
-
-    sockfile = sock.makefile("rwb", 65536)
-    do_server_auth(sockfile, auth_secret)
-
     # The socket will be automatically closed when garbage-collected.
     return serializer.load_stream(sockfile)
 
diff --git a/python/pyspark/taskcontext.py b/python/pyspark/taskcontext.py
index c0312e5265c6..53fc2b29e066 100644
--- a/python/pyspark/taskcontext.py
+++ b/python/pyspark/taskcontext.py
@@ -18,7 +18,7 @@
 from __future__ import print_function
 import socket
 
-from pyspark.java_gateway import do_server_auth
+from pyspark.java_gateway import local_connect_and_auth
 from pyspark.serializers import write_int, UTF8Deserializer
 
 
@@ -108,38 +108,14 @@ def _load_from_socket(port, auth_secret):
     """
     Load data from a given socket, this is a blocking method thus only return when the socket
     connection has been closed.
-
-    This is copied from context.py, while modified the message protocol.
     """
-    sock = None
-    # Support for both IPv4 and IPv6.
-    # On most of IPv6-ready systems, IPv6 will take precedence.
-    for res in socket.getaddrinfo("localhost", port, socket.AF_UNSPEC, socket.SOCK_STREAM):
-        af, socktype, proto, canonname, sa = res
-        sock = socket.socket(af, socktype, proto)
-        try:
-            # Do not allow timeout for socket reading operation.
-            sock.settimeout(None)
-            sock.connect(sa)
-        except socket.error:
-            sock.close()
-            sock = None
-            continue
-        break
-    if not sock:
-        raise Exception("could not open socket")
-
-    # We don't really need a socket file here, it's just for convenience that we can reuse the
-    # do_server_auth() function and data serialization methods.
-    sockfile = sock.makefile("rwb", 65536)
-
+    (sockfile, sock) = local_connect_and_auth(port, auth_secret)
+    # The barrier() call may block forever, so no timeout
+    sock.settimeout(None)
     # Make a barrier() function call.
     write_int(BARRIER_FUNCTION, sockfile)
     sockfile.flush()
 
-    # Do server auth.
-    do_server_auth(sockfile, auth_secret)
-
     # Collect result.
     res = UTF8Deserializer().loads(sockfile)
 
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index d54a5b8e396e..fcca8708a232 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -27,7 +27,7 @@
 
 from pyspark.accumulators import _accumulatorRegistry
 from pyspark.broadcast import Broadcast, _broadcastRegistry
-from pyspark.java_gateway import do_server_auth
+from pyspark.java_gateway import local_connect_and_auth
 from pyspark.taskcontext import BarrierTaskContext, TaskContext
 from pyspark.files import SparkFiles
 from pyspark.rdd import PythonEvalType
@@ -364,8 +364,5 @@ def process():
     # Read information about how to connect back to the JVM from the environment.
     java_port = int(os.environ["PYTHON_WORKER_FACTORY_PORT"])
     auth_secret = os.environ["PYTHON_WORKER_FACTORY_SECRET"]
-    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
-    sock.connect(("127.0.0.1", java_port))
-    sock_file = sock.makefile("rwb", 65536)
-    do_server_auth(sock_file, auth_secret)
+    (sock_file, _) = local_connect_and_auth(java_port, auth_secret)
     main(sock_file, sock_file)
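
For context, the handshake that _do_server_auth runs over the sockfile returned by local_connect_and_auth looks roughly like the sketch below. It is reconstructed from the docstring and the "Unexpected reply from iterator server." context line visible in the java_gateway.py hunk, using the existing pyspark.serializers helpers; it is not part of this change, and the name _do_server_auth_sketch is purely illustrative.

# Illustrative sketch only: the client half of the SocketAuthHelper handshake
# that _do_server_auth performs after local_connect_and_auth opens the socket.
from pyspark.serializers import write_with_length, UTF8Deserializer

def _do_server_auth_sketch(conn, auth_secret):
    # Send the shared secret, length-prefixed and UTF-8 encoded.
    write_with_length(auth_secret.encode("utf-8"), conn)
    conn.flush()
    # The JVM side validates the secret and is expected to answer "ok".
    reply = UTF8Deserializer().loads(conn)
    if reply != "ok":
        conn.close()
        raise Exception("Unexpected reply from iterator server.")

The authHelper.authClient(sock) call moved into the accept loop in PythonRunner.scala above is presumably the server half of this same exchange, which is why every connection is now authenticated before any message protocol bytes are read.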