Commit 554bda0

Merge pull request apache#147 from shivaram/sparkr-ec2-fixes
Bunch of fixes for longer running jobs
2 parents c662f29 + f34bb88 commit 554bda0

2 files changed: +67 -58 lines changed

pkg/R/sparkRClient.R

Lines changed: 1 addition & 1 deletion
@@ -2,7 +2,7 @@
 
 # Creates a SparkR client connection object
 # if one doesn't already exist
-connectBackend <- function(hostname, port, timeout = 60) {
+connectBackend <- function(hostname, port, timeout = 6000) {
   if (exists(".sparkRcon", envir = .sparkREnv)) {
     cat("SparkRBackend client connection already exists\n")
     return(get(".sparkRcon", envir = .sparkREnv))
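The only change here raises the backend connection's default timeout from 60 to 6000 seconds, so the R-to-JVM socket survives jobs that stay silent for long stretches instead of being dropped mid-run. As an illustration of the failure mode (a hypothetical JVM-side sketch, not SparkR code), a socket with a short read timeout gives up as soon as the peer is quiet for longer than the limit:

import java.net.{ServerSocket, Socket, SocketTimeoutException}

object IdleTimeoutDemo {
  def main(args: Array[String]): Unit = {
    val server = new ServerSocket(0)        // ephemeral local port
    val client = new Socket("localhost", server.getLocalPort)
    client.setSoTimeout(1000)               // 1s stands in for the old 60s default
    try {
      client.getInputStream.read()          // peer sends nothing, as during a long job
    } catch {
      case _: SocketTimeoutException =>
        println("idle connection dropped by read timeout")
    } finally {
      client.close()
      server.close()
    }
  }
}

At 6000 seconds, the connection tolerates well over an hour and a half of silence before giving up.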

pkg/src/src/main/scala/edu/berkeley/cs/amplab/sparkr/RRDD.scala

Lines changed: 66 additions & 57 deletions
@@ -117,68 +117,77 @@ private abstract class BaseRRDD[T: ClassTag, U: ClassTag](
     // Start a thread to feed the process input from our parent's iterator
     new Thread("stdin writer for R") {
       override def run() {
-        SparkEnv.set(env)
-        val streamStd = new BufferedOutputStream(proc.getOutputStream, bufferSize)
-        val printOutStd = new PrintStream(streamStd)
-        printOutStd.println(tempFileName)
-        printOutStd.println(rLibDir)
-        printOutStd.println(tempFileIn.getAbsolutePath())
-        printOutStd.flush()
-
-        streamStd.close()
-
-        val stream = new BufferedOutputStream(new FileOutputStream(tempFileIn), bufferSize)
-        val printOut = new PrintStream(stream)
-        val dataOut = new DataOutputStream(stream)
-
-        dataOut.writeInt(splitIndex)
-
-        dataOut.writeInt(func.length)
-        dataOut.write(func, 0, func.length)
-
-        // R worker process input serialization flag
-        dataOut.writeInt(if (parentSerialized) 1 else 0)
-        // R worker process output serialization flag
-        dataOut.writeInt(if (dataSerialized) 1 else 0)
-
-        dataOut.writeInt(packageNames.length)
-        dataOut.write(packageNames, 0, packageNames.length)
-
-        dataOut.writeInt(functionDependencies.length)
-        dataOut.write(functionDependencies, 0, functionDependencies.length)
-
-        dataOut.writeInt(broadcastVars.length)
-        broadcastVars.foreach { broadcast =>
-          // TODO(shivaram): Read a Long in R to avoid this cast
-          dataOut.writeInt(broadcast.id.toInt)
-          // TODO: Pass a byte array from R to avoid this cast ?
-          val broadcastByteArr = broadcast.value.asInstanceOf[Array[Byte]]
-          dataOut.writeInt(broadcastByteArr.length)
-          dataOut.write(broadcastByteArr, 0, broadcastByteArr.length)
-        }
-
-        dataOut.writeInt(numPartitions)
+        try {
+          SparkEnv.set(env)
+          val stream = new BufferedOutputStream(new FileOutputStream(tempFileIn), bufferSize)
+          val printOut = new PrintStream(stream)
+          val dataOut = new DataOutputStream(stream)
+
+          dataOut.writeInt(splitIndex)
+
+          dataOut.writeInt(func.length)
+          dataOut.write(func, 0, func.length)
+
+          // R worker process input serialization flag
+          dataOut.writeInt(if (parentSerialized) 1 else 0)
+          // R worker process output serialization flag
+          dataOut.writeInt(if (dataSerialized) 1 else 0)
+
+          dataOut.writeInt(packageNames.length)
+          dataOut.write(packageNames, 0, packageNames.length)
+
+          dataOut.writeInt(functionDependencies.length)
+          dataOut.write(functionDependencies, 0, functionDependencies.length)
+
+          dataOut.writeInt(broadcastVars.length)
+          broadcastVars.foreach { broadcast =>
+            // TODO(shivaram): Read a Long in R to avoid this cast
+            dataOut.writeInt(broadcast.id.toInt)
+            // TODO: Pass a byte array from R to avoid this cast ?
+            val broadcastByteArr = broadcast.value.asInstanceOf[Array[Byte]]
+            dataOut.writeInt(broadcastByteArr.length)
+            dataOut.write(broadcastByteArr, 0, broadcastByteArr.length)
+          }
 
-        if (!iter.hasNext) {
-          dataOut.writeInt(0)
-        } else {
-          dataOut.writeInt(1)
-        }
+          dataOut.writeInt(numPartitions)
 
-        for (elem <- iter) {
-          if (parentSerialized) {
-            val elemArr = elem.asInstanceOf[Array[Byte]]
-            dataOut.writeInt(elemArr.length)
-            dataOut.write(elemArr, 0, elemArr.length)
+          if (!iter.hasNext) {
+            dataOut.writeInt(0)
           } else {
-            printOut.println(elem)
+            dataOut.writeInt(1)
+          }
+
+          for (elem <- iter) {
+            if (parentSerialized) {
+              val elemArr = elem.asInstanceOf[Array[Byte]]
+              dataOut.writeInt(elemArr.length)
+              dataOut.write(elemArr, 0, elemArr.length)
+            } else {
+              printOut.println(elem)
+            }
           }
-        }
 
-        printOut.flush()
-        dataOut.flush()
-        stream.flush()
-        stream.close()
+          printOut.flush()
+          dataOut.flush()
+          stream.flush()
+          stream.close()
+
+          // NOTE: We need to write out the temp file before writing out the
+          // file name to stdin. Otherwise the R process could read partial state
+          val streamStd = new BufferedOutputStream(proc.getOutputStream, bufferSize)
+          val printOutStd = new PrintStream(streamStd)
+          printOutStd.println(tempFileName)
+          printOutStd.println(rLibDir)
+          printOutStd.println(tempFileIn.getAbsolutePath())
+          printOutStd.flush()
+
+          streamStd.close()
+        } catch {
+          // TODO: We should propogate this error to the task thread
+          case e: Exception =>
+            System.err.println("R Writer thread got an exception " + e)
+            e.printStackTrace()
+        }
       }
     }.start()
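The restructuring above does two things. First, the writer body is wrapped in try/catch, so an exception can no longer kill the stdin-writer thread silently. Second, per the NOTE in the diff, the temp file carrying the length-prefixed task data is now fully written, flushed, and closed before its name is sent to the R process on stdin; under the old ordering the R worker could open the file early and read partial state. A minimal sketch of that ordering invariant, using hypothetical names rather than SparkR's:

import java.io.{BufferedOutputStream, File, FileOutputStream, PrintStream}

object TempFileProtocol {
  // Hypothetical helper: the worker may only learn the path once the
  // payload is complete on disk, so write and close first, announce second.
  def writeThenAnnounce(payload: Array[Byte], workerStdin: PrintStream): Unit = {
    val tempFile = File.createTempFile("rworker", ".bin")
    val out = new BufferedOutputStream(new FileOutputStream(tempFile))
    try {
      out.write(payload)
    } finally {
      out.close()                                 // flush + close before announcing
    }
    workerStdin.println(tempFile.getAbsolutePath) // safe: file contents are final
    workerStdin.flush()
  }
}

Announcing first and writing second would reintroduce the race this commit fixes, since the reader's first read can happen at any point after the announce.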

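The new catch block still only logs the failure; its TODO notes that the error should instead reach the task thread. One common pattern for that (a sketch under that assumption, not what this commit implements) is to record the failure in a volatile field and rethrow it when the task thread joins the writer:

// Hypothetical pattern for the TODO above: surface writer failures in the caller.
class StdinWriter(body: () => Unit) extends Thread("stdin writer for R") {
  @volatile var failure: Throwable = null        // written here, read by the task thread

  override def run(): Unit = {
    try {
      body()
    } catch {
      case t: Throwable => failure = t           // remember instead of only logging
    }
  }
}

object StdinWriterUsage {
  def main(args: Array[String]): Unit = {
    val writer = new StdinWriter(() => throw new RuntimeException("boom"))
    writer.start()
    writer.join()                                    // task thread waits for the writer...
    if (writer.failure != null) throw writer.failure // ...and sees its failure
  }
}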