@@ -2,10 +2,9 @@ package spark.api.python

 import java.io._
 import java.net._
-import java.util.{List => JList, ArrayList => JArrayList, Collections}
+import java.util.{List => JList, ArrayList => JArrayList, Map => JMap, Collections}

 import scala.collection.JavaConversions._
-import scala.io.Source

 import spark.api.java.{JavaSparkContext, JavaPairRDD, JavaRDD}
 import spark.broadcast.Broadcast
@@ -16,7 +15,7 @@ import spark.rdd.PipedRDD
 private[spark] class PythonRDD[T: ClassManifest](
     parent: RDD[T],
     command: Seq[String],
-    envVars: java.util.Map[String, String],
+    envVars: JMap[String, String],
     preservePartitoning: Boolean,
     pythonExec: String,
     broadcastVars: JList[Broadcast[Array[Byte]]],
@@ -25,7 +24,7 @@ private[spark] class PythonRDD[T: ClassManifest](

   // Similar to Runtime.exec(), if we are given a single string, split it into words
   // using a standard StringTokenizer (i.e. by spaces)
-  def this(parent: RDD[T], command: String, envVars: java.util.Map[String, String],
+  def this(parent: RDD[T], command: String, envVars: JMap[String, String],
       preservePartitoning: Boolean, pythonExec: String,
       broadcastVars: JList[Broadcast[Array[Byte]]],
       accumulator: Accumulator[JList[Array[Byte]]]) =
@@ -36,35 +35,18 @@ private[spark] class PythonRDD[T: ClassManifest](

   override val partitioner = if (preservePartitoning) parent.partitioner else None

-  override def compute(split: Partition, context: TaskContext): Iterator[Array[Byte]] = {
-    val SPARK_HOME = new ProcessBuilder().environment().get("SPARK_HOME")
-
-    val pb = new ProcessBuilder(Seq(pythonExec, SPARK_HOME + "/python/pyspark/worker.py"))
-    // Add the environmental variables to the process.
-    val currentEnvVars = pb.environment()
-
-    for ((variable, value) <- envVars) {
-      currentEnvVars.put(variable, value)
-    }

-    val proc = pb.start()
+  override def compute(split: Partition, context: TaskContext): Iterator[Array[Byte]] = {
+    val startTime = System.currentTimeMillis
     val env = SparkEnv.get
-
-    // Start a thread to print the process's stderr to ours
-    new Thread("stderr reader for " + pythonExec) {
-      override def run() {
-        for (line <- Source.fromInputStream(proc.getErrorStream).getLines) {
-          System.err.println(line)
-        }
-      }
-    }.start()
+    val worker = env.createPythonWorker(pythonExec, envVars.toMap)

     // Start a thread to feed the process input from our parent's iterator
     new Thread("stdin writer for " + pythonExec) {
       override def run() {
         SparkEnv.set(env)
-        val out = new PrintWriter(proc.getOutputStream)
-        val dOut = new DataOutputStream(proc.getOutputStream)
+        val out = new PrintWriter(worker.getOutputStream)
+        val dOut = new DataOutputStream(worker.getOutputStream)
         // Partition index
         dOut.writeInt(split.index)
         // sparkFilesDir
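Note: the per-task subprocess spawn is replaced by a handle from SparkEnv.createPythonWorker, which this hunk uses like a java.net.Socket (getOutputStream, getInputStream, shutdownOutput); the factory itself is outside this diff. A minimal, hypothetical sketch of the reuse it implies, caching by (pythonExec, envVars) so an executor does not fork a fresh interpreter per partition. Names such as SimplePythonWorkerFactory and PythonWorkers are illustrative only, not the actual implementation:

import java.net.Socket
import scala.collection.mutable

// Hypothetical sketch only: one factory per (pythonExec, envVars) key, so
// Python workers can be pooled and reused instead of spawned per partition.
class SimplePythonWorkerFactory(pythonExec: String, envVars: Map[String, String]) {
  def create(): Socket = {
    // A real factory would launch (or contact) a long-lived Python process and
    // hand back a connected socket; this stub only marks the shape of the API.
    throw new UnsupportedOperationException("illustrative stub")
  }
}

object PythonWorkers {
  private val factories =
    mutable.Map[(String, Map[String, String]), SimplePythonWorkerFactory]()

  def createPythonWorker(pythonExec: String, envVars: Map[String, String]): Socket =
    synchronized {
      factories
        .getOrElseUpdate((pythonExec, envVars), new SimplePythonWorkerFactory(pythonExec, envVars))
        .create()
    }
}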
@@ -88,16 +70,21 @@ private[spark] class PythonRDD[T: ClassManifest](
         }
         dOut.flush()
         out.flush()
-        proc.getOutputStream.close()
+        worker.shutdownOutput()
       }
     }.start()

     // Return an iterator that read lines from the process's stdout
-    val stream = new DataInputStream(proc.getInputStream)
+    val stream = new DataInputStream(worker.getInputStream)
     return new Iterator[Array[Byte]] {
       def next(): Array[Byte] = {
         val obj = _nextObj
-        _nextObj = read()
+        if (hasNext) {
+          // FIXME: can deadlock if worker is waiting for us to
+          // respond to current message (currently irrelevant because
+          // output is shutdown before we read any input)
+          _nextObj = read()
+        }
         obj
       }

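The guard above stops the iterator from pre-fetching past the end-of-stream marker (hasNext and _nextObj come from surrounding iterator code this hunk does not show). As a standalone illustration of that look-ahead pattern, not the Spark code itself, assuming an empty record signals the end:

// Illustrative look-ahead iterator over a record source where an empty
// record means end-of-stream; it only reads ahead while hasNext holds,
// mirroring the guarded pre-fetch added above.
class LookAheadIterator(fetch: () => Array[Byte]) extends Iterator[Array[Byte]] {
  private var nextObj: Array[Byte] = fetch()
  def hasNext: Boolean = nextObj.nonEmpty
  def next(): Array[Byte] = {
    val obj = nextObj
    if (hasNext) {
      nextObj = fetch()  // safe: the stream has not ended yet
    }
    obj
  }
}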
@@ -108,30 +95,39 @@ private[spark] class PythonRDD[T: ClassManifest](
               val obj = new Array[Byte](length)
               stream.readFully(obj)
               obj
+            case -3 =>
+              // Timing data from worker
+              val bootTime = stream.readLong()
+              val initTime = stream.readLong()
+              val finishTime = stream.readLong()
+              val boot = bootTime - startTime
+              val init = initTime - bootTime
+              val finish = finishTime - initTime
+              val total = finishTime - startTime
+              logInfo("Times: total = %s, boot = %s, init = %s, finish = %s".format(total, boot, init, finish))
+              read
             case -2 =>
               // Signals that an exception has been thrown in python
               val exLength = stream.readInt()
               val obj = new Array[Byte](exLength)
               stream.readFully(obj)
               throw new PythonException(new String(obj))
             case -1 =>
-              // We've finished the data section of the output, but we can still read some
-              // accumulator updates; let's do that, breaking when we get EOFException
-              while (true) {
-                val len2 = stream.readInt()
+              // We've finished the data section of the output, but we can still
+              // read some accumulator updates; let's do that, breaking when we
+              // get a negative length record.
+              var len2 = stream.readInt()
+              while (len2 >= 0) {
                 val update = new Array[Byte](len2)
                 stream.readFully(update)
                 accumulator += Collections.singletonList(update)
+                len2 = stream.readInt()
               }
               new Array[Byte](0)
           }
         } catch {
           case eof: EOFException => {
-            val exitStatus = proc.waitFor()
-            if (exitStatus != 0) {
-              throw new Exception("Subprocess exited with status " + exitStatus)
-            }
-            new Array[Byte](0)
+            throw new SparkException("Python worker exited unexpectedly (crashed)", eof)
           }
           case e => throw e
         }
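For orientation, the stream read() consumes is framed with signed 32-bit lengths: a positive value prefixes a data record, -3 announces three timing longs, -2 prefixes a serialized Python exception, and -1 closes the data section, after which accumulator updates follow until a negative length ends them. Below is a hedged Scala sketch of a writer producing a stream the new loop would accept; the real framing is done by the Python worker and is not shown in this diff:

import java.io.{ByteArrayOutputStream, DataOutputStream}

// Illustrative encoder for the framing that read() expects. Not the actual
// worker-side code, which lives in pyspark and is outside this diff; the
// -3 timing block and -2 exception block are omitted for brevity.
object FramingSketch {
  def encode(records: Seq[Array[Byte]], accumUpdates: Seq[Array[Byte]]): Array[Byte] = {
    val bytes = new ByteArrayOutputStream()
    val out = new DataOutputStream(bytes)
    for (r <- records) {        // each data record is length-prefixed
      out.writeInt(r.length)
      out.write(r)
    }
    out.writeInt(-1)            // end of the data section
    for (u <- accumUpdates) {   // accumulator updates, also length-prefixed
      out.writeInt(u.length)
      out.write(u)
    }
    out.writeInt(-1)            // any negative length terminates the updates
    out.flush()
    bytes.toByteArray
  }
}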
@@ -159,7 +155,7 @@ private class PairwiseRDD(prev: RDD[Array[Byte]]) extends
   override def compute(split: Partition, context: TaskContext) =
     prev.iterator(split, context).grouped(2).map {
       case Seq(a, b) => (a, b)
-      case x => throw new Exception("PairwiseRDD: unexpected value: " + x)
+      case x => throw new SparkException("PairwiseRDD: unexpected value: " + x)
     }
   val asJavaPairRDD: JavaPairRDD[Array[Byte], Array[Byte]] = JavaPairRDD.fromRDD(this)
 }
@@ -215,7 +211,7 @@ private[spark] object PythonRDD {
       dOut.write(s)
       dOut.writeByte(Pickle.STOP)
     } else {
-      throw new Exception("Unexpected RDD type")
+      throw new SparkException("Unexpected RDD type")
     }
   }
