
Commit 7e4b266

Merge pull request #563 from jey/python-optimization
Optimize PySpark worker invocation
2 parents: b350f34 + c75bed0

7 files changed: +382 −61 lines

core/src/main/scala/spark/SparkEnv.scala

Lines changed: 12 additions & 0 deletions
@@ -1,5 +1,8 @@
 package spark
 
+import collection.mutable
+import serializer.Serializer
+
 import akka.actor.{Actor, ActorRef, Props, ActorSystemImpl, ActorSystem}
 import akka.remote.RemoteActorRefProvider
 
@@ -9,6 +12,7 @@ import spark.storage.BlockManagerMaster
 import spark.network.ConnectionManager
 import spark.serializer.{Serializer, SerializerManager}
 import spark.util.AkkaUtils
+import spark.api.python.PythonWorkerFactory
 
 
 /**
@@ -37,7 +41,10 @@ class SparkEnv (
     // If executorId is NOT found, return defaultHostPort
     var executorIdToHostPort: Option[(String, String) => String]) {
 
+  private val pythonWorkers = mutable.HashMap[(String, Map[String, String]), PythonWorkerFactory]()
+
   def stop() {
+    pythonWorkers.foreach { case(key, worker) => worker.stop() }
     httpFileServer.stop()
     mapOutputTracker.stop()
     shuffleFetcher.stop()
@@ -50,6 +57,11 @@ class SparkEnv (
     actorSystem.awaitTermination()
   }
 
+  def createPythonWorker(pythonExec: String, envVars: Map[String, String]): java.net.Socket = {
+    synchronized {
+      pythonWorkers.getOrElseUpdate((pythonExec, envVars), new PythonWorkerFactory(pythonExec, envVars)).create()
+    }
+  }
 
   def resolveExecutorIdToHostPort(executorId: String, defaultHostPort: String): String = {
     val env = SparkEnv.get

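For orientation, here is a minimal sketch of how a task-side caller is expected to use the new SparkEnv hook; it mirrors the PythonRDD.compute change below, but the interpreter name and environment map are illustrative, not part of the patch:

    // Hypothetical caller; in this commit the only real caller is PythonRDD.compute.
    val env = spark.SparkEnv.get
    val worker: java.net.Socket =
      env.createPythonWorker("python", Map("PYTHONPATH" -> "/path/to/deps"))  // illustrative values

    // Workers are pooled per (pythonExec, envVars) key, so repeated calls with the same
    // interpreter and environment reuse one PythonWorkerFactory (and one Python daemon).
    val dataOut = new java.io.DataOutputStream(worker.getOutputStream)
    val dataIn  = new java.io.DataInputStream(worker.getInputStream)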
core/src/main/scala/spark/api/python/PythonRDD.scala

Lines changed: 36 additions & 40 deletions
@@ -2,10 +2,9 @@ package spark.api.python
 
 import java.io._
 import java.net._
-import java.util.{List => JList, ArrayList => JArrayList, Collections}
+import java.util.{List => JList, ArrayList => JArrayList, Map => JMap, Collections}
 
 import scala.collection.JavaConversions._
-import scala.io.Source
 
 import spark.api.java.{JavaSparkContext, JavaPairRDD, JavaRDD}
 import spark.broadcast.Broadcast
@@ -16,7 +15,7 @@ import spark.rdd.PipedRDD
 private[spark] class PythonRDD[T: ClassManifest](
     parent: RDD[T],
     command: Seq[String],
-    envVars: java.util.Map[String, String],
+    envVars: JMap[String, String],
     preservePartitoning: Boolean,
     pythonExec: String,
     broadcastVars: JList[Broadcast[Array[Byte]]],
@@ -25,7 +24,7 @@ private[spark] class PythonRDD[T: ClassManifest](
 
   // Similar to Runtime.exec(), if we are given a single string, split it into words
   // using a standard StringTokenizer (i.e. by spaces)
-  def this(parent: RDD[T], command: String, envVars: java.util.Map[String, String],
+  def this(parent: RDD[T], command: String, envVars: JMap[String, String],
     preservePartitoning: Boolean, pythonExec: String,
     broadcastVars: JList[Broadcast[Array[Byte]]],
     accumulator: Accumulator[JList[Array[Byte]]]) =
@@ -36,35 +35,18 @@ private[spark] class PythonRDD[T: ClassManifest](
 
   override val partitioner = if (preservePartitoning) parent.partitioner else None
 
-  override def compute(split: Partition, context: TaskContext): Iterator[Array[Byte]] = {
-    val SPARK_HOME = new ProcessBuilder().environment().get("SPARK_HOME")
-
-    val pb = new ProcessBuilder(Seq(pythonExec, SPARK_HOME + "/python/pyspark/worker.py"))
-    // Add the environmental variables to the process.
-    val currentEnvVars = pb.environment()
-
-    for ((variable, value) <- envVars) {
-      currentEnvVars.put(variable, value)
-    }
 
-    val proc = pb.start()
+  override def compute(split: Partition, context: TaskContext): Iterator[Array[Byte]] = {
+    val startTime = System.currentTimeMillis
     val env = SparkEnv.get
-
-    // Start a thread to print the process's stderr to ours
-    new Thread("stderr reader for " + pythonExec) {
-      override def run() {
-        for (line <- Source.fromInputStream(proc.getErrorStream).getLines) {
-          System.err.println(line)
-        }
-      }
-    }.start()
+    val worker = env.createPythonWorker(pythonExec, envVars.toMap)
 
     // Start a thread to feed the process input from our parent's iterator
     new Thread("stdin writer for " + pythonExec) {
      override def run() {
        SparkEnv.set(env)
-        val out = new PrintWriter(proc.getOutputStream)
-        val dOut = new DataOutputStream(proc.getOutputStream)
+        val out = new PrintWriter(worker.getOutputStream)
+        val dOut = new DataOutputStream(worker.getOutputStream)
         // Partition index
         dOut.writeInt(split.index)
         // sparkFilesDir
@@ -88,16 +70,21 @@ private[spark] class PythonRDD[T: ClassManifest](
        }
        dOut.flush()
        out.flush()
-        proc.getOutputStream.close()
+        worker.shutdownOutput()
      }
    }.start()
 
    // Return an iterator that read lines from the process's stdout
-    val stream = new DataInputStream(proc.getInputStream)
+    val stream = new DataInputStream(worker.getInputStream)
    return new Iterator[Array[Byte]] {
      def next(): Array[Byte] = {
        val obj = _nextObj
-        _nextObj = read()
+        if (hasNext) {
+          // FIXME: can deadlock if worker is waiting for us to
+          // respond to current message (currently irrelevant because
+          // output is shutdown before we read any input)
+          _nextObj = read()
+        }
        obj
      }
 
@@ -108,30 +95,39 @@ private[spark] class PythonRDD[T: ClassManifest](
            val obj = new Array[Byte](length)
            stream.readFully(obj)
            obj
+          case -3 =>
+            // Timing data from worker
+            val bootTime = stream.readLong()
+            val initTime = stream.readLong()
+            val finishTime = stream.readLong()
+            val boot = bootTime - startTime
+            val init = initTime - bootTime
+            val finish = finishTime - initTime
+            val total = finishTime - startTime
+            logInfo("Times: total = %s, boot = %s, init = %s, finish = %s".format(total, boot, init, finish))
+            read
          case -2 =>
            // Signals that an exception has been thrown in python
            val exLength = stream.readInt()
            val obj = new Array[Byte](exLength)
            stream.readFully(obj)
            throw new PythonException(new String(obj))
          case -1 =>
-            // We've finished the data section of the output, but we can still read some
-            // accumulator updates; let's do that, breaking when we get EOFException
-            while (true) {
-              val len2 = stream.readInt()
+            // We've finished the data section of the output, but we can still
+            // read some accumulator updates; let's do that, breaking when we
+            // get a negative length record.
+            var len2 = stream.readInt()
+            while (len2 >= 0) {
              val update = new Array[Byte](len2)
              stream.readFully(update)
              accumulator += Collections.singletonList(update)
+              len2 = stream.readInt()
            }
            new Array[Byte](0)
        }
      } catch {
        case eof: EOFException => {
-          val exitStatus = proc.waitFor()
-          if (exitStatus != 0) {
-            throw new Exception("Subprocess exited with status " + exitStatus)
-          }
-          new Array[Byte](0)
+          throw new SparkException("Python worker exited unexpectedly (crashed)", eof)
        }
        case e => throw e
      }
@@ -159,7 +155,7 @@ private class PairwiseRDD(prev: RDD[Array[Byte]]) extends
   override def compute(split: Partition, context: TaskContext) =
     prev.iterator(split, context).grouped(2).map {
       case Seq(a, b) => (a, b)
-      case x => throw new Exception("PairwiseRDD: unexpected value: " + x)
+      case x => throw new SparkException("PairwiseRDD: unexpected value: " + x)
     }
   val asJavaPairRDD : JavaPairRDD[Array[Byte], Array[Byte]] = JavaPairRDD.fromRDD(this)
 }
@@ -215,7 +211,7 @@ private[spark] object PythonRDD {
      dOut.write(s)
      dOut.writeByte(Pickle.STOP)
    } else {
-      throw new Exception("Unexpected RDD type")
+      throw new SparkException("Unexpected RDD type")
    }
  }
 

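As a reading aid for the rewritten iterator above, here is a self-contained sketch of the length-prefixed framing it decodes; the sentinel values (-1, -2, -3) come from the diff, while the function name drain and the way records are collected are illustrative only:

    import java.io.DataInputStream

    // Sketch of the worker -> JVM stream framing handled by read() above:
    //   len >= 0 : a data record of `len` bytes
    //   -3       : three longs of timing data (boot / init / finish timestamps)
    //   -2       : a serialized Python exception of readInt() bytes
    //   -1       : end of data, then accumulator updates until a negative length
    def drain(stream: DataInputStream): Seq[Array[Byte]] = {
      val results = collection.mutable.ArrayBuffer[Array[Byte]]()
      var done = false
      while (!done) {
        stream.readInt() match {
          case length if length >= 0 =>
            val obj = new Array[Byte](length)
            stream.readFully(obj)
            results += obj
          case -3 =>
            // timing data; the real code logs total/boot/init/finish, we just consume it
            stream.readLong(); stream.readLong(); stream.readLong()
          case -2 =>
            val exLength = stream.readInt()
            val err = new Array[Byte](exLength)
            stream.readFully(err)
            throw new RuntimeException(new String(err))  // PythonException in the real code
          case -1 =>
            var len2 = stream.readInt()
            while (len2 >= 0) {                  // accumulator updates; negative length ends them
              val update = new Array[Byte](len2)
              stream.readFully(update)           // the real code feeds these into an Accumulator
              len2 = stream.readInt()
            }
            done = true
          case other =>
            throw new RuntimeException("unexpected length record: " + other)
        }
      }
      results.toSeq
    }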
core/src/main/scala/spark/api/python/PythonWorkerFactory.scala

Lines changed: 95 additions & 0 deletions
@@ -0,0 +1,95 @@
+package spark.api.python
+
+import java.io.{DataInputStream, IOException}
+import java.net.{Socket, SocketException, InetAddress}
+
+import scala.collection.JavaConversions._
+
+import spark._
+
+private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String, String])
+    extends Logging {
+  var daemon: Process = null
+  val daemonHost = InetAddress.getByAddress(Array(127, 0, 0, 1))
+  var daemonPort: Int = 0
+
+  def create(): Socket = {
+    synchronized {
+      // Start the daemon if it hasn't been started
+      startDaemon()
+
+      // Attempt to connect, restart and retry once if it fails
+      try {
+        new Socket(daemonHost, daemonPort)
+      } catch {
+        case exc: SocketException => {
+          logWarning("Python daemon unexpectedly quit, attempting to restart")
+          stopDaemon()
+          startDaemon()
+          new Socket(daemonHost, daemonPort)
+        }
+        case e => throw e
+      }
+    }
+  }
+
+  def stop() {
+    stopDaemon()
+  }
+
+  private def startDaemon() {
+    synchronized {
+      // Is it already running?
+      if (daemon != null) {
+        return
+      }
+
+      try {
+        // Create and start the daemon
+        val sparkHome = new ProcessBuilder().environment().get("SPARK_HOME")
+        val pb = new ProcessBuilder(Seq(pythonExec, sparkHome + "/python/pyspark/daemon.py"))
+        val workerEnv = pb.environment()
+        workerEnv.putAll(envVars)
+        daemon = pb.start()
+        daemonPort = new DataInputStream(daemon.getInputStream).readInt()
+
+        // Redirect the stderr to ours
+        new Thread("stderr reader for " + pythonExec) {
+          override def run() {
+            scala.util.control.Exception.ignoring(classOf[IOException]) {
+              // FIXME HACK: We copy the stream on the level of bytes to
+              // attempt to dodge encoding problems.
+              val in = daemon.getErrorStream
+              var buf = new Array[Byte](1024)
+              var len = in.read(buf)
+              while (len != -1) {
+                System.err.write(buf, 0, len)
+                len = in.read(buf)
+              }
+            }
+          }
+        }.start()
+      } catch {
+        case e => {
+          stopDaemon()
+          throw e
+        }
+      }
+
+      // Important: don't close daemon's stdin (daemon.getOutputStream) so it can correctly
+      // detect our disappearance.
+    }
+  }
+
+  private def stopDaemon() {
+    synchronized {
+      // Request shutdown of existing daemon by sending SIGTERM
+      if (daemon != null) {
+        daemon.destroy()
+      }
+
+      daemon = null
+      daemonPort = 0
+    }
+  }
+}

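The factory's handshake relies on daemon.py (touched elsewhere in this commit, not shown here) announcing its listening port on stdout, which startDaemon reads back with a single readInt(); the daemon is therefore assumed to emit the port as a 4-byte big-endian integer. A self-contained sketch of that handshake, with a local ServerSocket standing in for the daemon, is:

    import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream}
    import java.net.{InetAddress, ServerSocket, Socket}

    object HandshakeSketch {
      def main(args: Array[String]): Unit = {
        // Stand-in for the Python daemon: bind an ephemeral loopback port and "announce" it
        // the same way daemon.py is assumed to, as a big-endian 32-bit integer.
        val loopback = InetAddress.getByAddress(Array[Byte](127, 0, 0, 1))
        val server = new ServerSocket(0, 1, loopback)
        val announce = new ByteArrayOutputStream()
        new DataOutputStream(announce).writeInt(server.getLocalPort)

        // JVM side, mirroring PythonWorkerFactory.startDaemon/create:
        // read the announced port, then open a plain TCP connection to 127.0.0.1.
        val daemonPort = new DataInputStream(new ByteArrayInputStream(announce.toByteArray)).readInt()
        val worker = new Socket(loopback, daemonPort)

        println("connected to stand-in daemon on port " + daemonPort + ": " + worker.isConnected)
        worker.close()
        server.close()
      }
    }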