5 changes: 3 additions & 2 deletions core/src/main/scala/org/apache/spark/SparkEnv.scala
@@ -18,6 +18,7 @@
 package org.apache.spark
 
 import java.io.File
+import java.net.Socket
 
 import scala.collection.JavaConversions._
 import scala.collection.mutable
@@ -102,10 +103,10 @@ class SparkEnv (
   }
 
   private[spark]
-  def destroyPythonWorker(pythonExec: String, envVars: Map[String, String]) {
+  def destroyPythonWorker(pythonExec: String, envVars: Map[String, String], worker: Socket) {
     synchronized {
       val key = (pythonExec, envVars)
-      pythonWorkers(key).stop()
+      pythonWorkers.get(key).foreach(_.stopWorker(worker))
     }
   }
 }
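Note on the hunk above: destroyPythonWorker now takes the worker's socket and stops only that worker, and the pythonWorkers.get(key).foreach(...) form makes a missing factory a no-op rather than a NoSuchElementException. A rough Python sketch of the registry pattern, with hypothetical names (not Spark code):

    class WorkerFactory(object):
        def __init__(self):
            self.workers = {}            # socket -> worker process

        def stop_worker(self, sock):
            proc = self.workers.pop(sock, None)
            if proc is not None:
                proc.kill()              # stop one worker, not the whole factory

    factories = {}                       # (python_exec, env_items) -> WorkerFactory

    def destroy_python_worker(python_exec, env_vars, sock):
        key = (python_exec, tuple(sorted(env_vars.items())))
        factory = factories.get(key)     # like pythonWorkers.get(key).foreach(...)
        if factory is not None:
            factory.stop_worker(sock)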
core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
@@ -62,8 +62,8 @@ private[spark] class PythonRDD(
     val env = SparkEnv.get
     val localdir = env.blockManager.diskBlockManager.localDirs.map(
       f => f.getPath()).mkString(",")
-    val worker: Socket = env.createPythonWorker(pythonExec,
-      envVars.toMap + ("SPARK_LOCAL_DIR" -> localdir))
+    envVars += ("SPARK_LOCAL_DIR" -> localdir) // it's also used in monitor thread
+    val worker: Socket = env.createPythonWorker(pythonExec, envVars.toMap)
 
     // Start a thread to feed the process input from our parent's iterator
     val writerThread = new WriterThread(env, worker, split, context)
@@ -241,7 +241,7 @@ private[spark] class PythonRDD(
       if (!context.completed) {
         try {
           logWarning("Incomplete task interrupted: Attempting to kill Python Worker")
-          env.destroyPythonWorker(pythonExec, envVars.toMap)
+          env.destroyPythonWorker(pythonExec, envVars.toMap, worker)
         } catch {
           case e: Exception =>
             logError("Exception when trying to kill worker", e)
@@ -685,9 +685,8 @@ private[spark] object PythonRDD extends Logging {
 
   /**
    * Convert an RDD of serialized Python dictionaries to Scala Maps (no recursive conversions).
-   * This function is outdated, PySpark does not use it anymore
    */
-  @deprecated
+  @deprecated("PySpark does not use it anymore", "1.1")
  def pythonToJavaMap(pyRDD: JavaRDD[Array[Byte]]): JavaRDD[Map[String, _]] = {
     pyRDD.rdd.mapPartitions { iter =>
       val unpickle = new Unpickler
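The reordering in the first PythonRDD hunk matters because the monitor thread later recomputes the worker key from the same envVars; adding SPARK_LOCAL_DIR to the shared map, rather than to a throwaway copy, keeps the create-time and destroy-time keys identical. A small Python illustration of the mismatch (hypothetical values, key construction simplified):

    env_vars = {"PYTHONPATH": "/opt/spark/python"}   # hypothetical

    # Before: the worker was created with a modified copy of the env, so a
    # later lookup with the unmodified env_vars computed a different key.
    key_at_create = tuple(sorted(dict(env_vars, SPARK_LOCAL_DIR="/tmp/x").items()))
    key_in_monitor = tuple(sorted(env_vars.items()))
    assert key_at_create != key_in_monitor

    # After: mutate the shared map first, so both sides agree on the key.
    env_vars["SPARK_LOCAL_DIR"] = "/tmp/x"
    assert tuple(sorted(env_vars.items())) == key_at_create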
core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala
@@ -17,9 +17,11 @@
 
 package org.apache.spark.api.python
 
-import java.io.{DataInputStream, InputStream, OutputStreamWriter}
+import java.lang.Runtime
+import java.io.{DataOutputStream, DataInputStream, InputStream, OutputStreamWriter}
 import java.net.{InetAddress, ServerSocket, Socket, SocketException}
 
+import scala.collection.mutable
 import scala.collection.JavaConversions._
 
 import org.apache.spark._
@@ -39,6 +41,9 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String
   var daemon: Process = null
   val daemonHost = InetAddress.getByAddress(Array(127, 0, 0, 1))
   var daemonPort: Int = 0
+  var daemonWorkers = new mutable.WeakHashMap[Socket, Int]()
+
+  var simpleWorkers = new mutable.WeakHashMap[Socket, Process]()
 
   val pythonPath = PythonUtils.mergePythonPaths(
     PythonUtils.sparkPythonPath,
@@ -58,25 +63,31 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String
    * to avoid the high cost of forking from Java. This currently only works on UNIX-based systems.
    */
   private def createThroughDaemon(): Socket = {
+
+    def createSocket(): Socket = {
+      val socket = new Socket(daemonHost, daemonPort)
+      val pid = new DataInputStream(socket.getInputStream).readInt()
+      if (pid < 0) {
+        throw new IllegalStateException("Python daemon failed to launch worker")
+      }
+      daemonWorkers.put(socket, pid)
+      socket
+    }
+
     synchronized {
       // Start the daemon if it hasn't been started
       startDaemon()
 
       // Attempt to connect, restart and retry once if it fails
       try {
-        val socket = new Socket(daemonHost, daemonPort)
-        val launchStatus = new DataInputStream(socket.getInputStream).readInt()
-        if (launchStatus != 0) {
-          throw new IllegalStateException("Python daemon failed to launch worker")
-        }
-        socket
+        createSocket()
       } catch {
         case exc: SocketException =>
           logWarning("Failed to open socket to Python daemon:", exc)
           logWarning("Assuming that daemon unexpectedly quit, attempting to restart")
           stopDaemon()
           startDaemon()
-          new Socket(daemonHost, daemonPort)
+          createSocket()
       }
     }
   }
@@ -107,7 +118,9 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String
       // Wait for it to connect to our socket
       serverSocket.setSoTimeout(10000)
       try {
-        return serverSocket.accept()
+        val socket = serverSocket.accept()
+        simpleWorkers.put(socket, worker)
+        return socket
       } catch {
         case e: Exception =>
           throw new SparkException("Python worker did not connect back in time", e)
@@ -189,19 +202,40 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String
 
   private def stopDaemon() {
     synchronized {
-      // Request shutdown of existing daemon by sending SIGTERM
-      if (daemon != null) {
-        daemon.destroy()
-      }
+      if (useDaemon) {
+        // Request shutdown of existing daemon by sending SIGTERM
+        if (daemon != null) {
+          daemon.destroy()
+        }
 
-      daemon = null
-      daemonPort = 0
+        daemon = null
+        daemonPort = 0
+      } else {
+        simpleWorkers.mapValues(_.destroy())
+      }
     }
   }
 
   def stop() {
     stopDaemon()
   }
 
+  def stopWorker(worker: Socket) {
+    if (useDaemon) {
+      if (daemon != null) {
+        daemonWorkers.get(worker).foreach { pid =>
+          // tell daemon to kill worker by pid
+          val output = new DataOutputStream(daemon.getOutputStream)
+          output.writeInt(pid)
+          output.flush()
+          daemon.getOutputStream.flush()
+        }
+      }
+    } else {
+      simpleWorkers.get(worker).foreach(_.destroy())
+    }
+    worker.close()
+  }
 }
 
 private object PythonWorkerFactory {
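stopWorker() is the JVM half of the new kill protocol: in daemon mode it looks up the pid recorded by createSocket() and writes it as a 4-byte big-endian int to the daemon's stdin; in non-daemon mode it destroys the launched process directly. A hedged Python sketch of the same logic (illustrative only; the real implementation is the Scala above):

    import struct

    daemon_workers = {}   # socket -> pid, filled in when the daemon acks a fork
    simple_workers = {}   # socket -> subprocess.Popen, for non-daemon workers

    def stop_worker(worker_sock, daemon_stdin=None, use_daemon=True):
        if use_daemon:
            pid = daemon_workers.get(worker_sock)
            if pid is not None and daemon_stdin is not None:
                # 4-byte big-endian int, matching Java's DataOutputStream.writeInt
                daemon_stdin.write(struct.pack(">i", pid))
                daemon_stdin.flush()
        else:
            proc = simple_workers.get(worker_sock)
            if proc is not None:
                proc.kill()
        worker_sock.close()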
24 changes: 18 additions & 6 deletions python/pyspark/daemon.py
@@ -26,7 +26,7 @@
 from socket import AF_INET, SOCK_STREAM, SOMAXCONN
 from signal import SIGHUP, SIGTERM, SIGCHLD, SIG_DFL, SIG_IGN
 from pyspark.worker import main as worker_main
-from pyspark.serializers import write_int
+from pyspark.serializers import read_int, write_int
 
 
 def compute_real_exit_code(exit_code):
@@ -67,7 +67,8 @@ def waitSocketClose(sock):
     outfile = os.fdopen(os.dup(sock.fileno()), "a+", 65536)
     exit_code = 0
     try:
-        write_int(0, outfile) # Acknowledge that the fork was successful
+        # Acknowledge that the fork was successful
+        write_int(os.getpid(), outfile)
         outfile.flush()
         worker_main(infile, outfile)
     except SystemExit as exc:
@@ -125,14 +126,23 @@ def handle_sigchld(*args):
                 else:
                     raise
         if 0 in ready_fds:
-            # Spark told us to exit by closing stdin
-            shutdown(0)
+            try:
+                worker_pid = read_int(sys.stdin)
+            except EOFError:
+                # Spark told us to exit by closing stdin
+                shutdown(0)
+            try:
+                os.kill(worker_pid, signal.SIGKILL)
+            except OSError:
+                pass # process already died
+
+
         if listen_sock in ready_fds:
             sock, addr = listen_sock.accept()
             # Launch a worker process
             try:
-                fork_return_code = os.fork()
-                if fork_return_code == 0:
+                pid = os.fork()
+                if pid == 0:
                     listen_sock.close()
                     try:
                         worker(sock)
@@ -143,11 +153,13 @@ def handle_sigchld(*args):
                         os._exit(0)
                 else:
                     sock.close()
+
             except OSError as e:
                 print >> sys.stderr, "Daemon failed to fork PySpark worker: %s" % e
                 outfile = os.fdopen(os.dup(sock.fileno()), "a+", 65536)
                 write_int(-1, outfile) # Signal that the fork failed
                 outfile.flush()
+                outfile.close()
                 sock.close()
     finally:
         shutdown(1)
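Taken together, the daemon's protocol after this patch is: acknowledge each fork by writing the child's pid on the worker socket (negative means the fork failed), and treat readable stdin as either a pid to SIGKILL or, on EOF, a shutdown request. A self-contained Python sketch of that event loop, simplified to skip SIGCHLD handling and worker bookkeeping:

    import os
    import select
    import signal
    import struct
    import sys

    def serve(listen_sock):
        while True:
            ready, _, _ = select.select([0, listen_sock], [], [])
            if 0 in ready:
                data = os.read(0, 4)
                if len(data) < 4:
                    sys.exit(0)              # stdin closed: Spark wants us to exit
                (worker_pid,) = struct.unpack(">i", data)
                try:
                    os.kill(worker_pid, signal.SIGKILL)   # kill one worker by pid
                except OSError:
                    pass                     # worker already died
            if listen_sock in ready:
                sock, _ = listen_sock.accept()
                pid = os.fork()
                if pid == 0:
                    # child: ack with our pid, then the real worker would run
                    listen_sock.close()
                    sock.sendall(struct.pack(">i", os.getpid()))
                    os._exit(0)              # worker_main(...) would go here
                sock.close()                 # parent keeps only the listener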
51 changes: 51 additions & 0 deletions python/pyspark/tests.py
@@ -783,6 +783,57 @@ def test_termination_sigterm(self):
         self.do_termination_test(lambda daemon: os.kill(daemon.pid, SIGTERM))
 
 
+class TestWorker(PySparkTestCase):
+    def test_cancel_task(self):
+        temp = tempfile.NamedTemporaryFile(delete=True)
+        temp.close()
+        path = temp.name
+        def sleep(x):
+            import os, time
+            with open(path, 'w') as f:
+                f.write("%d %d" % (os.getppid(), os.getpid()))
+            time.sleep(100)
+
+        # start job in background thread
+        def run():
+            self.sc.parallelize(range(1)).foreach(sleep)
+        import threading
+        t = threading.Thread(target=run)
+        t.daemon = True
+        t.start()
+
+        daemon_pid, worker_pid = 0, 0
+        while True:
+            if os.path.exists(path):
+                data = open(path).read().split(' ')
+                daemon_pid, worker_pid = map(int, data)
+                break
+            time.sleep(0.1)
+
+        # cancel jobs
+        self.sc.cancelAllJobs()
+        t.join()
+
+        for i in range(50):
+            try:
+                os.kill(worker_pid, 0)
+                time.sleep(0.1)
+            except OSError:
+                break # worker was killed
+        else:
+            self.fail("worker has not been killed after 5 seconds")
+
+        try:
+            os.kill(daemon_pid, 0)
+        except OSError:
+            self.fail("daemon had been killed")
+
+    def test_fd_leak(self):
+        N = 1100 # fd limit is 1024 by default
+        rdd = self.sc.parallelize(range(N), N)
+        self.assertEquals(N, rdd.count())
+
+
 class TestSparkSubmit(unittest.TestCase):
     def setUp(self):
         self.programDir = tempfile.mkdtemp()
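The polling in test_cancel_task relies on the standard signal-0 probe: os.kill(pid, 0) delivers no signal but raises OSError once the target process is gone. A minimal helper showing the idiom:

    import errno
    import os

    def is_alive(pid):
        # Signal 0 performs existence/permission checks only; nothing is sent.
        try:
            os.kill(pid, 0)
        except OSError as e:
            if e.errno == errno.ESRCH:   # no such process
                return False
            raise                        # e.g. EPERM: exists, but not ours
        return True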