@@ -117,68 +117,77 @@ private abstract class BaseRRDD[T: ClassTag, U: ClassTag](
117117 // Start a thread to feed the process input from our parent's iterator
118118 new Thread (" stdin writer for R" ) {
119119 override def run () {
120- SparkEnv .set(env)
121- val streamStd = new BufferedOutputStream (proc.getOutputStream, bufferSize)
122- val printOutStd = new PrintStream (streamStd)
123- printOutStd.println(tempFileName)
124- printOutStd.println(rLibDir)
125- printOutStd.println(tempFileIn.getAbsolutePath())
126- printOutStd.flush()
127-
128- streamStd.close()
129-
130- val stream = new BufferedOutputStream (new FileOutputStream (tempFileIn), bufferSize)
131- val printOut = new PrintStream (stream)
132- val dataOut = new DataOutputStream (stream)
133-
134- dataOut.writeInt(splitIndex)
135-
136- dataOut.writeInt(func.length)
137- dataOut.write(func, 0 , func.length)
138-
139- // R worker process input serialization flag
140- dataOut.writeInt(if (parentSerialized) 1 else 0 )
141- // R worker process output serialization flag
142- dataOut.writeInt(if (dataSerialized) 1 else 0 )
143-
144- dataOut.writeInt(packageNames.length)
145- dataOut.write(packageNames, 0 , packageNames.length)
146-
147- dataOut.writeInt(functionDependencies.length)
148- dataOut.write(functionDependencies, 0 , functionDependencies.length)
149-
150- dataOut.writeInt(broadcastVars.length)
151- broadcastVars.foreach { broadcast =>
152- // TODO(shivaram): Read a Long in R to avoid this cast
153- dataOut.writeInt(broadcast.id.toInt)
154- // TODO: Pass a byte array from R to avoid this cast ?
155- val broadcastByteArr = broadcast.value.asInstanceOf [Array [Byte ]]
156- dataOut.writeInt(broadcastByteArr.length)
157- dataOut.write(broadcastByteArr, 0 , broadcastByteArr.length)
158- }
159-
160- dataOut.writeInt(numPartitions)
120+ try {
121+ SparkEnv .set(env)
122+ val stream = new BufferedOutputStream (new FileOutputStream (tempFileIn), bufferSize)
123+ val printOut = new PrintStream (stream)
124+ val dataOut = new DataOutputStream (stream)
125+
126+ dataOut.writeInt(splitIndex)
127+
128+ dataOut.writeInt(func.length)
129+ dataOut.write(func, 0 , func.length)
130+
131+ // R worker process input serialization flag
132+ dataOut.writeInt(if (parentSerialized) 1 else 0 )
133+ // R worker process output serialization flag
134+ dataOut.writeInt(if (dataSerialized) 1 else 0 )
135+
136+ dataOut.writeInt(packageNames.length)
137+ dataOut.write(packageNames, 0 , packageNames.length)
138+
139+ dataOut.writeInt(functionDependencies.length)
140+ dataOut.write(functionDependencies, 0 , functionDependencies.length)
141+
142+ dataOut.writeInt(broadcastVars.length)
143+ broadcastVars.foreach { broadcast =>
144+ // TODO(shivaram): Read a Long in R to avoid this cast
145+ dataOut.writeInt(broadcast.id.toInt)
146+ // TODO: Pass a byte array from R to avoid this cast ?
147+ val broadcastByteArr = broadcast.value.asInstanceOf [Array [Byte ]]
148+ dataOut.writeInt(broadcastByteArr.length)
149+ dataOut.write(broadcastByteArr, 0 , broadcastByteArr.length)
150+ }
161151
162- if (! iter.hasNext) {
163- dataOut.writeInt(0 )
164- } else {
165- dataOut.writeInt(1 )
166- }
152+ dataOut.writeInt(numPartitions)
167153
168- for (elem <- iter) {
169- if (parentSerialized) {
170- val elemArr = elem.asInstanceOf [Array [Byte ]]
171- dataOut.writeInt(elemArr.length)
172- dataOut.write(elemArr, 0 , elemArr.length)
154+ if (! iter.hasNext) {
155+ dataOut.writeInt(0 )
173156 } else {
174- printOut.println(elem)
157+ dataOut.writeInt(1 )
158+ }
159+
160+ for (elem <- iter) {
161+ if (parentSerialized) {
162+ val elemArr = elem.asInstanceOf [Array [Byte ]]
163+ dataOut.writeInt(elemArr.length)
164+ dataOut.write(elemArr, 0 , elemArr.length)
165+ } else {
166+ printOut.println(elem)
167+ }
175168 }
176- }
177169
178- printOut.flush()
179- dataOut.flush()
180- stream.flush()
181- stream.close()
170+ printOut.flush()
171+ dataOut.flush()
172+ stream.flush()
173+ stream.close()
174+
175+ // NOTE: We need to write out the temp file before writing out the
176+ // file name to stdin. Otherwise the R process could read partial state
177+ val streamStd = new BufferedOutputStream (proc.getOutputStream, bufferSize)
178+ val printOutStd = new PrintStream (streamStd)
179+ printOutStd.println(tempFileName)
180+ printOutStd.println(rLibDir)
181+ printOutStd.println(tempFileIn.getAbsolutePath())
182+ printOutStd.flush()
183+
184+ streamStd.close()
185+ } catch {
186+ // TODO: We should propogate this error to the task thread
187+ case e : Exception =>
188+ System .err.println(" R Writer thread got an exception " + e)
189+ e.printStackTrace()
190+ }
182191 }
183192 }.start()
184193
0 commit comments