20 changes: 17 additions & 3 deletions R/pkg/R/backend.R
@@ -108,13 +108,27 @@ invokeJava <- function(isStatic, objId, methodName, ...) {
conn <- get(".sparkRCon", .sparkREnv)
writeBin(requestMessage, conn)

# TODO: check the status code to output error information
returnStatus <- readInt(conn)
handleErrors(returnStatus, conn)

# Backend will send +1 as keep alive value to prevent various connection timeouts
# on very long running jobs. See spark.r.heartBeatInterval
while (returnStatus == 1) {
returnStatus <- readInt(conn)
handleErrors(returnStatus, conn)
}

readObject(conn)
}

# Helper function to check for returned errors and print appropriate error message to user
handleErrors <- function(returnStatus, conn) {
if (length(returnStatus) == 0) {
stop("No status is returned. Java SparkR backend might have failed.")
}
if (returnStatus != 0) {

# 0 is success and +1 is reserved for heartbeats. Other negative values indicate errors.
if (returnStatus < 0) {
stop(readString(conn))
}
readObject(conn)
}
2 changes: 1 addition & 1 deletion R/pkg/R/client.R
@@ -19,7 +19,7 @@

# Creates a SparkR client connection object
# if one doesn't already exist
connectBackend <- function(hostname, port, timeout = 6000) {
connectBackend <- function(hostname, port, timeout) {
if (exists(".sparkRcon", envir = .sparkREnv)) {
if (isOpen(.sparkREnv[[".sparkRCon"]])) {
cat("SparkRBackend client connection already exists\n")
8 changes: 6 additions & 2 deletions R/pkg/R/sparkR.R
@@ -154,6 +154,7 @@ sparkR.sparkContext <- function(
packages <- processSparkPackages(sparkPackages)

existingPort <- Sys.getenv("EXISTING_SPARKR_BACKEND_PORT", "")
connectionTimeout <- as.numeric(Sys.getenv("SPARKR_BACKEND_CONNECTION_TIMEOUT", "6000"))
if (existingPort != "") {
if (length(packages) != 0) {
warning(paste("sparkPackages has no effect when using spark-submit or sparkR shell",
@@ -187,14 +188,17 @@ sparkR.sparkContext <- function(
backendPort <- readInt(f)
monitorPort <- readInt(f)
rLibPath <- readString(f)
connectionTimeout <- readInt(f)
close(f)
file.remove(path)
if (length(backendPort) == 0 || backendPort == 0 ||
length(monitorPort) == 0 || monitorPort == 0 ||
length(rLibPath) != 1) {
stop("JVM failed to launch")
}
assign(".monitorConn", socketConnection(port = monitorPort), envir = .sparkREnv)
assign(".monitorConn",
socketConnection(port = monitorPort, timeout = connectionTimeout),
envir = .sparkREnv)
assign(".backendLaunched", 1, envir = .sparkREnv)
if (rLibPath != "") {
assign(".libPath", rLibPath, envir = .sparkREnv)
@@ -204,7 +208,7 @@

.sparkREnv$backendPort <- backendPort
tryCatch({
connectBackend("localhost", backendPort)
connectBackend("localhost", backendPort, timeout = connectionTimeout)
},
error = function(err) {
stop("Failed to connect JVM\n")
4 changes: 3 additions & 1 deletion R/pkg/inst/worker/daemon.R
@@ -18,6 +18,7 @@
# Worker daemon

rLibDir <- Sys.getenv("SPARKR_RLIBDIR")
connectionTimeout <- as.integer(Sys.getenv("SPARKR_BACKEND_CONNECTION_TIMEOUT", "6000"))
dirs <- strsplit(rLibDir, ",")[[1]]
script <- file.path(dirs[[1]], "SparkR", "worker", "worker.R")

@@ -26,7 +27,8 @@ script <- file.path(dirs[[1]], "SparkR", "worker", "worker.R")
suppressPackageStartupMessages(library(SparkR))

port <- as.integer(Sys.getenv("SPARKR_WORKER_PORT"))
inputCon <- socketConnection(port = port, open = "rb", blocking = TRUE, timeout = 3600)
inputCon <- socketConnection(
port = port, open = "rb", blocking = TRUE, timeout = connectionTimeout)

while (TRUE) {
ready <- socketSelect(list(inputCon))
7 changes: 5 additions & 2 deletions R/pkg/inst/worker/worker.R
@@ -90,6 +90,7 @@ bootTime <- currentTimeSecs()
bootElap <- elapsedSecs()

rLibDir <- Sys.getenv("SPARKR_RLIBDIR")
connectionTimeout <- as.integer(Sys.getenv("SPARKR_BACKEND_CONNECTION_TIMEOUT", "6000"))
dirs <- strsplit(rLibDir, ",")[[1]]
# Set libPaths to include SparkR package as loadNamespace needs this
# TODO: Figure out if we can avoid this by not loading any objects that require
@@ -98,8 +99,10 @@ dirs <- strsplit(rLibDir, ",")[[1]]
suppressPackageStartupMessages(library(SparkR))

port <- as.integer(Sys.getenv("SPARKR_WORKER_PORT"))
inputCon <- socketConnection(port = port, blocking = TRUE, open = "rb")
outputCon <- socketConnection(port = port, blocking = TRUE, open = "wb")
inputCon <- socketConnection(
port = port, blocking = TRUE, open = "rb", timeout = connectionTimeout)
outputCon <- socketConnection(
port = port, blocking = TRUE, open = "wb", timeout = connectionTimeout)

# read the index of the current partition inside the RDD
partition <- SparkR:::readInt(inputCon)
15 changes: 13 additions & 2 deletions core/src/main/scala/org/apache/spark/api/r/RBackend.scala
@@ -22,12 +22,13 @@ import java.net.{InetAddress, InetSocketAddress, ServerSocket}
import java.util.concurrent.TimeUnit

import io.netty.bootstrap.ServerBootstrap
import io.netty.channel.{ChannelFuture, ChannelInitializer, EventLoopGroup}
import io.netty.channel.{ChannelFuture, ChannelInitializer, ChannelOption, EventLoopGroup}
import io.netty.channel.nio.NioEventLoopGroup
import io.netty.channel.socket.SocketChannel
import io.netty.channel.socket.nio.NioServerSocketChannel
import io.netty.handler.codec.LengthFieldBasedFrameDecoder
import io.netty.handler.codec.bytes.{ByteArrayDecoder, ByteArrayEncoder}
import io.netty.handler.timeout.ReadTimeoutHandler

import org.apache.spark.SparkConf
import org.apache.spark.internal.Logging
@@ -43,7 +44,10 @@ private[spark] class RBackend {

def init(): Int = {
val conf = new SparkConf()
bossGroup = new NioEventLoopGroup(conf.getInt("spark.r.numRBackendThreads", 2))
val backendConnectionTimeout = conf.getInt(
"spark.r.backendConnectionTimeout", SparkRDefaults.DEFAULT_CONNECTION_TIMEOUT)
bossGroup = new NioEventLoopGroup(
conf.getInt("spark.r.numRBackendThreads", SparkRDefaults.DEFAULT_NUM_RBACKEND_THREADS))
val workerGroup = bossGroup
val handler = new RBackendHandler(this)

@@ -63,6 +67,7 @@
// initialBytesToStrip = 4, i.e. strip out the length field itself
new LengthFieldBasedFrameDecoder(Integer.MAX_VALUE, 0, 4, 0, 4))
.addLast("decoder", new ByteArrayDecoder())
.addLast("readTimeoutHandler", new ReadTimeoutHandler(backendConnectionTimeout))
.addLast("handler", handler)
}
})
@@ -110,6 +115,11 @@ private[spark] object RBackend extends Logging {
val boundPort = sparkRBackend.init()
val serverSocket = new ServerSocket(0, 1, InetAddress.getByName("localhost"))
val listenPort = serverSocket.getLocalPort()
// Connection timeout is set by socket client. To make it configurable we will pass the
// timeout value to client inside the temp file
val conf = new SparkConf()
val backendConnectionTimeout = conf.getInt(
"spark.r.backendConnectionTimeout", SparkRDefaults.DEFAULT_CONNECTION_TIMEOUT)

Contributor:
Are we sure SparkConf has been initialized successfully at this point? Or, to put it another way, in which cases does this code path get called? Is this the spark-submit case, or the shell, etc.?

Contributor (Author):
This is for spark-submit. Basically the JVM starts before the R process, so the only way for the R process to get these configuration parameters is from the JVM. In this case, RBackend sets the environment variables based on the configs.

For the other mode, where the JVM is started after the R process, we are sending this timeout value through the TCP connection.

At least that is my current understanding of how the deploy modes work. In our production environment we launch the R process from the JVM.

// tell the R process via temporary file
val path = args(0)
@@ -118,6 +128,7 @@
dos.writeInt(boundPort)
dos.writeInt(listenPort)
SerDe.writeString(dos, RUtils.rPackages.getOrElse(""))
dos.writeInt(backendConnectionTimeout)
dos.close()
f.renameTo(new File(path))

39 changes: 35 additions & 4 deletions core/src/main/scala/org/apache/spark/api/r/RBackendHandler.scala
@@ -18,16 +18,19 @@
package org.apache.spark.api.r

import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream}
import java.util.concurrent.TimeUnit

import scala.collection.mutable.HashMap
import scala.language.existentials

import io.netty.channel.{ChannelHandlerContext, SimpleChannelInboundHandler}
import io.netty.channel.ChannelHandler.Sharable
import io.netty.handler.timeout.ReadTimeoutException

import org.apache.spark.api.r.SerDe._
import org.apache.spark.internal.Logging
import org.apache.spark.util.Utils
import org.apache.spark.SparkConf
import org.apache.spark.util.{ThreadUtils, Utils}

/**
* Handler for RBackend
@@ -83,7 +86,29 @@ private[r] class RBackendHandler(server: RBackend)
writeString(dos, s"Error: unknown method $methodName")
}
} else {
// To avoid timeouts when reading results in SparkR driver, we will be regularly sending
// heartbeat responses. We use special code +1 to signal the client that backend is
// alive and it should continue blocking for result.
val execService = ThreadUtils.newDaemonSingleThreadScheduledExecutor("SparkRKeepAliveThread")
Contributor:
I'm not sure how expensive it is to create and destroy an executor service each time. Can we just schedule at a fixed rate when we get the request and then cancel the schedule at the end of the request?

Contributor (Author):
I was not sure about this either. I used this method based on advice from @zsxwing.

Contributor:
Hmm - my question on whether we can reuse this still stands. @zsxwing do you think that's possible?

Member:
@shivaram we can reuse it. scheduleAtFixedRate returns a ScheduledFuture which can be used to cancel the task. However, there is no awaitTermination for a ScheduledFuture after cancelling it, so we would need some extra work.

Contributor:
I took a closer look at this and (a) it looks like the executor just calls cancel on the tasks during shutdown, so that part of the behavior is the same as calling cancel on the task we have [1]. But you are right that if we wanted to wait for termination we'd need to do some extra work. We could use the get call, but it's unclear what its semantics are. It might be easier to just set up a semaphore or mutex that is shared by the runnable and the outside thread.

But overall it looks like thread pool creation takes only around 100 microseconds [2], and I also benchmarked this locally.

[1] http://www.docjar.com/html/api/java/util/concurrent/ScheduledThreadPoolExecutor.java.html line 367
[2] http://stackoverflow.com/a/5483467/4577954
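To make the reuse alternative above concrete, here is a minimal, hypothetical Scala sketch (not part of this patch); the KeepAliveScheduler name and the withHeartbeat helper are illustrative assumptions, not Spark APIs. The idea is a single shared scheduler plus a per-request ScheduledFuture that is cancelled when the method call returns:

```scala
import java.util.concurrent.{Executors, ScheduledExecutorService, TimeUnit}

// Hypothetical sketch: reuse one scheduler for all requests and cancel the
// per-request heartbeat task when the method call finishes.
class KeepAliveScheduler(intervalSecs: Long) {
  private val scheduler: ScheduledExecutorService =
    Executors.newSingleThreadScheduledExecutor()

  def withHeartbeat[T](sendPing: Runnable)(body: => T): T = {
    // Schedule the heartbeat for this request only.
    val ping = scheduler.scheduleAtFixedRate(
      sendPing, intervalSecs, intervalSecs, TimeUnit.SECONDS)
    try {
      body
    } finally {
      // cancel(false) lets an in-flight ping finish; waiting for full
      // termination would need extra coordination, as noted above.
      ping.cancel(false)
    }
  }
}
```

Under this sketch, handleMethodCall would run inside withHeartbeat with the existing pingRunner passed as sendPing, instead of creating and shutting down a new executor per request.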

val pingRunner = new Runnable {
override def run(): Unit = {
val pingBaos = new ByteArrayOutputStream()
val pingDaos = new DataOutputStream(pingBaos)
writeInt(pingDaos, +1)
ctx.write(pingBaos.toByteArray)
}
}
val conf = new SparkConf()
val heartBeatInterval = conf.getInt(
"spark.r.heartBeatInterval", SparkRDefaults.DEFAULT_HEARTBEAT_INTERVAL)
Member:
Should this be documented too?

Contributor (Author):
Done.

val backendConnectionTimeout = conf.getInt(
"spark.r.backendConnectionTimeout", SparkRDefaults.DEFAULT_CONNECTION_TIMEOUT)
val interval = Math.min(heartBeatInterval, backendConnectionTimeout - 1)

execService.scheduleAtFixedRate(pingRunner, interval, interval, TimeUnit.SECONDS)
handleMethodCall(isStatic, objId, methodName, numArgs, dis, dos)
execService.shutdown()
execService.awaitTermination(1, TimeUnit.SECONDS)
}

val reply = bos.toByteArray
@@ -95,9 +120,15 @@
}

override def exceptionCaught(ctx: ChannelHandlerContext, cause: Throwable): Unit = {
// Close the connection when an exception is raised.
cause.printStackTrace()
ctx.close()
cause match {
case timeout: ReadTimeoutException =>
// Do nothing. We don't want to timeout on read
logWarning("Ignoring read timeout in RBackendHandler")
case _ =>
// Close the connection when an exception is raised.
cause.printStackTrace()
ctx.close()
}
}

def handleMethodCall(
3 changes: 3 additions & 0 deletions core/src/main/scala/org/apache/spark/api/r/RRunner.scala
@@ -333,6 +333,8 @@ private[r] object RRunner {
var rCommand = sparkConf.get("spark.sparkr.r.command", "Rscript")
rCommand = sparkConf.get("spark.r.command", rCommand)

val rConnectionTimeout = sparkConf.getInt(
"spark.r.backendConnectionTimeout", SparkRDefaults.DEFAULT_CONNECTION_TIMEOUT)
val rOptions = "--vanilla"
val rLibDir = RUtils.sparkRPackagePath(isDriver = false)
val rExecScript = rLibDir(0) + "/SparkR/worker/" + script
@@ -344,6 +346,7 @@
pb.environment().put("R_TESTS", "")
pb.environment().put("SPARKR_RLIBDIR", rLibDir.mkString(","))
pb.environment().put("SPARKR_WORKER_PORT", port.toString)
pb.environment().put("SPARKR_BACKEND_CONNECTION_TIMEOUT", rConnectionTimeout.toString)
pb.redirectErrorStream(true) // redirect stderr into stdout
val proc = pb.start()
val errThread = startStdoutThread(proc)
30 changes: 30 additions & 0 deletions core/src/main/scala/org/apache/spark/api/r/SparkRDefaults.scala
@@ -0,0 +1,30 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.api.r

private[spark] object SparkRDefaults {

// Default value for spark.r.backendConnectionTimeout config
val DEFAULT_CONNECTION_TIMEOUT: Int = 6000

// Default value for spark.r.heartBeatInterval config
val DEFAULT_HEARTBEAT_INTERVAL: Int = 100

// Default value for spark.r.numRBackendThreads config
val DEFAULT_NUM_RBACKEND_THREADS = 2
}
7 changes: 6 additions & 1 deletion core/src/main/scala/org/apache/spark/deploy/RRunner.scala
@@ -25,7 +25,7 @@ import scala.collection.JavaConverters._
import org.apache.hadoop.fs.Path

import org.apache.spark.{SparkException, SparkUserAppException}
import org.apache.spark.api.r.{RBackend, RUtils}
import org.apache.spark.api.r.{RBackend, RUtils, SparkRDefaults}
import org.apache.spark.util.RedirectThread

/**
@@ -51,6 +51,10 @@ object RRunner {
cmd
}

// Connection timeout set by R process on its connection to RBackend in seconds.
val backendConnectionTimeout = sys.props.getOrElse(
"spark.r.backendConnectionTimeout", SparkRDefaults.DEFAULT_CONNECTION_TIMEOUT.toString)

// Check if the file path exists.
// If not, change directory to current working directory for YARN cluster mode
val rF = new File(rFile)
@@ -81,6 +85,7 @@
val builder = new ProcessBuilder((Seq(rCommand, rFileNormalized) ++ otherArgs).asJava)
val env = builder.environment()
env.put("EXISTING_SPARKR_BACKEND_PORT", sparkRBackendPort.toString)
env.put("SPARKR_BACKEND_CONNECTION_TIMEOUT", backendConnectionTimeout)
val rPackageDir = RUtils.sparkRPackagePath(isDriver = true)
// Put the R package directories into an env variable of comma-separated paths
env.put("SPARKR_PACKAGE_DIR", rPackageDir.mkString(","))
15 changes: 15 additions & 0 deletions docs/configuration.md
@@ -1874,6 +1874,21 @@ showDF(properties, numRows = 200, truncate = FALSE)
<code>spark.r.shell.command</code> is used for sparkR shell while <code>spark.r.driver.command</code> is used for running R script.
</td>
</tr>
<tr>
<td><code>spark.r.backendConnectionTimeout</code></td>
<td>6000</td>
<td>
Connection timeout, in seconds, set by the R process on its connection to the RBackend.
</td>
</tr>
<tr>
<td><code>spark.r.heartBeatInterval</code></td>
<td>100</td>
<td>
Interval for heartbeats sent from the SparkR backend to the R process to prevent connection timeouts.
</td>
</tr>

</table>
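For context only, and not part of the patch: a hedged Scala sketch of how these two settings interact. The backend reads both through SparkConf, and the heartbeat interval is kept strictly below the connection timeout (mirroring the interval computation in RBackendHandler) so keep-alive messages arrive before the read timeout fires:

```scala
import org.apache.spark.SparkConf

// Illustrative values only; these are the defaults from SparkRDefaults.
val conf = new SparkConf()
  .set("spark.r.backendConnectionTimeout", "6000") // seconds
  .set("spark.r.heartBeatInterval", "100")         // seconds

val timeoutSecs = conf.getInt("spark.r.backendConnectionTimeout", 6000)
val heartBeatSecs = conf.getInt("spark.r.heartBeatInterval", 100)
// Ping at least once per timeout window.
val pingIntervalSecs = math.min(heartBeatSecs, timeoutSecs - 1)
```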

#### Deploy