Skip to content
Closed
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,7 @@ private[spark] abstract class BasePythonRunner[IN, OUT](
sock = serverSocket.get.accept()
// Wait for function call from python side.
sock.setSoTimeout(10000)
authHelper.authClient(sock)
val input = new DataInputStream(sock.getInputStream())
input.readInt() match {
case BarrierTaskContextMessageProtocol.BARRIER_FUNCTION =>
Expand Down Expand Up @@ -324,8 +325,6 @@ private[spark] abstract class BasePythonRunner[IN, OUT](
def barrierAndServe(sock: Socket): Unit = {
require(serverSocket.isDefined, "No available ServerSocket to redirect the barrier() call.")

authHelper.authClient(sock)

val out = new DataOutputStream(new BufferedOutputStream(sock.getOutputStream))
try {
context.asInstanceOf[BarrierTaskContext].barrier()
Expand Down
35 changes: 34 additions & 1 deletion python/pyspark/java_gateway.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ def killChild():
return gateway


def do_server_auth(conn, auth_secret):
def _do_server_auth(conn, auth_secret):
"""
Performs the authentication protocol defined by the SocketAuthHelper class on the given
file-like object 'conn'.
Expand All @@ -147,6 +147,39 @@ def do_server_auth(conn, auth_secret):
raise Exception("Unexpected reply from iterator server.")


def local_connect_and_auth(port, auth_secret):
"""
Connect to local host, authenticate with it, and return a (sockfile,sock) for that connection.
Handles IPV4 & IPV6, does some error handling.
:param port
:param auth_secret
:return: a tuple with (sockfile, sock)
"""
sock = None
errors = []
# Support for both IPv4 and IPv6.
# On most of IPv6-ready systems, IPv6 will take precedence.
for res in socket.getaddrinfo("127.0.0.1", port, socket.AF_UNSPEC, socket.SOCK_STREAM):
af, socktype, proto, _, sa = res
sock = socket.socket(af, socktype, proto)
try:
sock.settimeout(15)
sock.connect(sa)
except socket.error as e:
emsg = _exception_message(e)
errors.append("tried to connect to %s, but an error occured: %s" % (sa, emsg))
sock.close()
sock = None
continue
break
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Slight shorter (and more "python-compliant"?):

  • move the socket initialization (and the return) inside the try
  • get rid of the continue
  • use an else instead of the condition below

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

if not sock:
raise Exception("could not open socket: %s" % errors)

sockfile = sock.makefile("rwb", 65536)
_do_server_auth(sockfile, auth_secret)
return (sockfile, sock)


def ensure_callback_server_started(gw):
"""
Start callback server if not already started. The callback server is needed if the Java
Expand Down
27 changes: 2 additions & 25 deletions python/pyspark/rdd.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@
else:
from itertools import imap as map, ifilter as filter

from pyspark.java_gateway import do_server_auth
from pyspark.java_gateway import local_connect_and_auth
from pyspark.serializers import NoOpSerializer, CartesianDeserializer, \
BatchedSerializer, CloudPickleSerializer, PairDeserializer, \
PickleSerializer, pack_long, AutoBatchedSerializer, write_with_length, \
Expand Down Expand Up @@ -141,33 +141,10 @@ def _parse_memory(s):


def _load_from_socket(sock_info, serializer):
port, auth_secret = sock_info
sock = None
errors = []
# Support for both IPv4 and IPv6.
# On most of IPv6-ready systems, IPv6 will take precedence.
for res in socket.getaddrinfo("localhost", port, socket.AF_UNSPEC, socket.SOCK_STREAM):
af, socktype, proto, canonname, sa = res
sock = socket.socket(af, socktype, proto)
try:
sock.settimeout(15)
sock.connect(sa)
except socket.error as e:
emsg = _exception_message(e)
errors.append("tried to connect to %s, but an error occured: %s" % (sa, emsg))
sock.close()
sock = None
continue
break
if not sock:
raise Exception("could not open socket: %s" % errors)
(sockfile, sock) = local_connect_and_auth(*sock_info)
# The RDD materialization time is unpredicable, if we set a timeout for socket reading
# operation, it will very possibly fail. See SPARK-18281.
sock.settimeout(None)

sockfile = sock.makefile("rwb", 65536)
do_server_auth(sockfile, auth_secret)

# The socket will be automatically closed when garbage-collected.
return serializer.load_stream(sockfile)

Expand Down
30 changes: 2 additions & 28 deletions python/pyspark/taskcontext.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from __future__ import print_function
import socket

from pyspark.java_gateway import do_server_auth
from pyspark.java_gateway import local_connect_and_auth
from pyspark.serializers import write_int, UTF8Deserializer


Expand Down Expand Up @@ -108,38 +108,12 @@ def _load_from_socket(port, auth_secret):
"""
Load data from a given socket, this is a blocking method thus only return when the socket
connection has been closed.

This is copied from context.py, while modified the message protocol.
"""
sock = None
# Support for both IPv4 and IPv6.
# On most of IPv6-ready systems, IPv6 will take precedence.
for res in socket.getaddrinfo("localhost", port, socket.AF_UNSPEC, socket.SOCK_STREAM):
af, socktype, proto, canonname, sa = res
sock = socket.socket(af, socktype, proto)
try:
# Do not allow timeout for socket reading operation.
sock.settimeout(None)
sock.connect(sa)
except socket.error:
sock.close()
sock = None
continue
break
if not sock:
raise Exception("could not open socket")

# We don't really need a socket file here, it's just for convenience that we can reuse the
# do_server_auth() function and data serialization methods.
sockfile = sock.makefile("rwb", 65536)

(sockfile, sock) = local_connect_and_auth(port, auth_secret)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We must set sock timeout to None to allow barrier() call blocking forever.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

good catch, thanks! updated

# Make a barrier() function call.
write_int(BARRIER_FUNCTION, sockfile)
sockfile.flush()

# Do server auth.
do_server_auth(sockfile, auth_secret)

# Collect result.
res = UTF8Deserializer().loads(sockfile)

Expand Down
7 changes: 2 additions & 5 deletions python/pyspark/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@

from pyspark.accumulators import _accumulatorRegistry
from pyspark.broadcast import Broadcast, _broadcastRegistry
from pyspark.java_gateway import do_server_auth
from pyspark.java_gateway import local_connect_and_auth
from pyspark.taskcontext import BarrierTaskContext, TaskContext
from pyspark.files import SparkFiles
from pyspark.rdd import PythonEvalType
Expand Down Expand Up @@ -364,8 +364,5 @@ def process():
# Read information about how to connect back to the JVM from the environment.
java_port = int(os.environ["PYTHON_WORKER_FACTORY_PORT"])
auth_secret = os.environ["PYTHON_WORKER_FACTORY_SECRET"]
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
sock.connect(("127.0.0.1", java_port))
sock_file = sock.makefile("rwb", 65536)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@vanzin, BTW, did you test this on Windows too?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I quickly tested and seems working fine. Please ignore this comment.

do_server_auth(sock_file, auth_secret)
(sock_file, _) = local_connect_and_auth(java_port, auth_secret)
main(sock_file, sock_file)