diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala b/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala
index 2beca6fddb27..69a74146fad1 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala
@@ -77,7 +77,7 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String
   @GuardedBy("self")
   private var daemon: Process = null
-  val daemonHost = InetAddress.getByAddress(Array(127, 0, 0, 1))
+  val daemonHost = InetAddress.getLoopbackAddress()
   @GuardedBy("self")
   private var daemonPort: Int = 0
   @GuardedBy("self")
@@ -153,7 +153,7 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String
   private def createSimpleWorker(): (Socket, Option[Int]) = {
     var serverSocket: ServerSocket = null
     try {
-      serverSocket = new ServerSocket(0, 1, InetAddress.getByAddress(Array(127, 0, 0, 1)))
+      serverSocket = new ServerSocket(0, 1, InetAddress.getLoopbackAddress())
 
       // Create and start the worker
       val pb = new ProcessBuilder(Arrays.asList(pythonExec, "-m", workerModule))
@@ -164,6 +164,9 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String
       workerEnv.put("PYTHONUNBUFFERED", "YES")
       workerEnv.put("PYTHON_WORKER_FACTORY_PORT", serverSocket.getLocalPort.toString)
       workerEnv.put("PYTHON_WORKER_FACTORY_SECRET", authHelper.secret)
+      if (Utils.preferIPv6) {
+        workerEnv.put("SPARK_PREFER_IPV6", "True")
+      }
       val worker = pb.start()
 
       // Redirect worker stdout and stderr
@@ -211,6 +214,9 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String
       workerEnv.putAll(envVars.asJava)
       workerEnv.put("PYTHONPATH", pythonPath)
       workerEnv.put("PYTHON_WORKER_FACTORY_SECRET", authHelper.secret)
+      if (Utils.preferIPv6) {
+        workerEnv.put("SPARK_PREFER_IPV6", "True")
+      }
       // This is equivalent to setting the -u flag; we use it because ipython doesn't support -u:
       workerEnv.put("PYTHONUNBUFFERED", "YES")
       daemon = pb.start()
diff --git a/python/pyspark/daemon.py b/python/pyspark/daemon.py
index 6676bf911937..81b6481f70ea 100644
--- a/python/pyspark/daemon.py
+++ b/python/pyspark/daemon.py
@@ -25,7 +25,7 @@
 import time
 import gc
 from errno import EINTR, EAGAIN
-from socket import AF_INET, SOCK_STREAM, SOMAXCONN
+from socket import AF_INET, AF_INET6, SOCK_STREAM, SOMAXCONN
 from signal import SIGHUP, SIGTERM, SIGCHLD, SIG_DFL, SIG_IGN, SIGINT
 
 from pyspark.worker import main as worker_main
@@ -86,11 +86,17 @@ def manager():
     # Create a new process group to corral our children
     os.setpgid(0, 0)
 
-    # Create a listening socket on the AF_INET loopback interface
-    listen_sock = socket.socket(AF_INET, SOCK_STREAM)
-    listen_sock.bind(("127.0.0.1", 0))
-    listen_sock.listen(max(1024, SOMAXCONN))
-    listen_host, listen_port = listen_sock.getsockname()
+    # Create a listening socket on the loopback interface
+    if os.environ.get("SPARK_PREFER_IPV6", "false").lower() == "true":
+        listen_sock = socket.socket(AF_INET6, SOCK_STREAM)
+        listen_sock.bind(("::1", 0, 0, 0))
+        listen_sock.listen(max(1024, SOMAXCONN))
+        listen_host, listen_port, _, _ = listen_sock.getsockname()
+    else:
+        listen_sock = socket.socket(AF_INET, SOCK_STREAM)
+        listen_sock.bind(("127.0.0.1", 0))
+        listen_sock.listen(max(1024, SOMAXCONN))
+        listen_host, listen_port = listen_sock.getsockname()
 
     # re-open stdin/stdout in 'wb' mode
     stdin_bin = os.fdopen(sys.stdin.fileno(), "rb", 4)
diff --git a/python/pyspark/java_gateway.py b/python/pyspark/java_gateway.py
index a41ccfafde4e..aee206dd6b3e 100644
--- a/python/pyspark/java_gateway.py
+++ b/python/pyspark/java_gateway.py
@@ -193,8 +193,10 @@ def local_connect_and_auth(port, auth_secret):
     sock = None
     errors = []
     # Support for both IPv4 and IPv6.
-    # On most of IPv6-ready systems, IPv6 will take precedence.
-    for res in socket.getaddrinfo("127.0.0.1", port, socket.AF_UNSPEC, socket.SOCK_STREAM):
+    addr = "127.0.0.1"
+    if os.environ.get("SPARK_PREFER_IPV6", "false").lower() == "true":
+        addr = "::1"
+    for res in socket.getaddrinfo(addr, port, socket.AF_UNSPEC, socket.SOCK_STREAM):
         af, socktype, proto, _, sa = res
         try:
             sock = socket.socket(af, socktype, proto)