From 45b23178d14e3d94df020220b8b018077c0e5355 Mon Sep 17 00:00:00 2001 From: zsxwing Date: Fri, 13 Feb 2015 17:16:47 +0800 Subject: [PATCH 01/31] A standard RPC interface and An Akka implementation --- .../scala/org/apache/spark/SparkEnv.scala | 40 +- .../spark/deploy/worker/DriverWrapper.scala | 7 +- .../spark/deploy/worker/WorkerWatcher.scala | 58 +-- .../CoarseGrainedExecutorBackend.scala | 2 +- .../apache/spark/rpc/ActionScheduler.scala | 207 +++++++++ .../scala/org/apache/spark/rpc/RpcEnv.scala | 345 ++++++++++++++ .../apache/spark/rpc/akka/AkkaRpcEnv.scala | 287 ++++++++++++ .../scheduler/OutputCommitCoordinator.scala | 33 +- .../org/apache/spark/util/AkkaUtils.scala | 2 +- .../deploy/worker/WorkerWatcherSuite.scala | 39 +- .../org/apache/spark/rpc/RpcEnvSuite.scala | 430 ++++++++++++++++++ .../spark/rpc/akka/AkkaRpcEnvSuite.scala | 49 ++ 12 files changed, 1420 insertions(+), 79 deletions(-) create mode 100644 core/src/main/scala/org/apache/spark/rpc/ActionScheduler.scala create mode 100644 core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala create mode 100644 core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala create mode 100644 core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala create mode 100644 core/src/test/scala/org/apache/spark/rpc/akka/AkkaRpcEnvSuite.scala diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala index 2a0c7e756dd3..5203521e6d47 100644 --- a/core/src/main/scala/org/apache/spark/SparkEnv.scala +++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala @@ -34,8 +34,10 @@ import org.apache.spark.metrics.MetricsSystem import org.apache.spark.network.BlockTransferService import org.apache.spark.network.netty.NettyBlockTransferService import org.apache.spark.network.nio.NioBlockTransferService +import org.apache.spark.rpc.akka.AkkaRpcEnv +import org.apache.spark.rpc.{RpcEndpointRef, RpcEndpoint, RpcEnv} import org.apache.spark.scheduler.{OutputCommitCoordinator, LiveListenerBus} -import org.apache.spark.scheduler.OutputCommitCoordinator.OutputCommitCoordinatorActor +import org.apache.spark.scheduler.OutputCommitCoordinator.OutputCommitCoordinatorEndpoint import org.apache.spark.serializer.Serializer import org.apache.spark.shuffle.{ShuffleMemoryManager, ShuffleManager} import org.apache.spark.storage._ @@ -54,7 +56,7 @@ import org.apache.spark.util.{AkkaUtils, Utils} @DeveloperApi class SparkEnv ( val executorId: String, - val actorSystem: ActorSystem, + val rpcEnv: RpcEnv, val serializer: Serializer, val closureSerializer: Serializer, val cacheManager: CacheManager, @@ -71,6 +73,9 @@ class SparkEnv ( val outputCommitCoordinator: OutputCommitCoordinator, val conf: SparkConf) extends Logging { + // TODO Remove actorSystem + val actorSystem = rpcEnv.asInstanceOf[AkkaRpcEnv].actorSystem + private[spark] var isStopped = false private val pythonWorkers = mutable.HashMap[(String, Map[String, String]), PythonWorkerFactory]() @@ -91,7 +96,8 @@ class SparkEnv ( blockManager.master.stop() metricsSystem.stop() outputCommitCoordinator.stop() - actorSystem.shutdown() + rpcEnv.shutdown() + // Unfortunately Akka's awaitTermination doesn't actually wait for the Netty server to shut // down, but let's call it anyway in case it gets fixed in a later release // UPDATE: In Akka 2.1.x, this hangs if there are remote actors, so we can't call it. 
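Aside: the constructor change above keeps a temporary actorSystem member by unwrapping the AkkaRpcEnv (see the TODO), so actor-based callers keep working during the migration while converted code goes through the new rpcEnv field. A minimal sketch of the two access paths, assuming code that already holds a SparkEnv; the endpoint lookup mirrors the executor-side branch added later in this file:

    import org.apache.spark.SparkEnv

    val env = SparkEnv.get
    // Transitional path, only valid while the AkkaRpcEnv cast above remains:
    val actorSystem = env.actorSystem
    // Transport-agnostic path through the new RPC abstraction:
    val coordinatorRef = env.rpcEnv.setupDriverEndpointRef("OutputCommitCoordinator")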
@@ -236,16 +242,15 @@ object SparkEnv extends Logging { val securityManager = new SecurityManager(conf) // Create the ActorSystem for Akka and get the port it binds to. - val (actorSystem, boundPort) = { - val actorSystemName = if (isDriver) driverActorSystemName else executorActorSystemName - AkkaUtils.createActorSystem(actorSystemName, hostname, port, conf, securityManager) - } + val actorSystemName = if (isDriver) driverActorSystemName else executorActorSystemName + val rpcEnv = RpcEnv.create(actorSystemName, hostname, port, conf, securityManager) + val actorSystem = rpcEnv.asInstanceOf[AkkaRpcEnv].actorSystem // Figure out which port Akka actually bound to in case the original port is 0 or occupied. if (isDriver) { - conf.set("spark.driver.port", boundPort.toString) + conf.set("spark.driver.port", rpcEnv.address.port.toString) } else { - conf.set("spark.executor.port", boundPort.toString) + conf.set("spark.executor.port", rpcEnv.address.port.toString) } // Create an instance of the class with the given name, possibly initializing it with our conf @@ -290,6 +295,15 @@ object SparkEnv extends Logging { } } + def registerOrLookupEndpoint(name: String, endpointCreator: => RpcEndpoint): RpcEndpointRef = { + if (isDriver) { + logInfo("Registering " + name) + rpcEnv.setupEndpoint(name, endpointCreator) + } else { + rpcEnv.setupDriverEndpointRef(name) + } + } + val mapOutputTracker = if (isDriver) { new MapOutputTrackerMaster(conf) } else { @@ -377,13 +391,13 @@ object SparkEnv extends Logging { val outputCommitCoordinator = mockOutputCommitCoordinator.getOrElse { new OutputCommitCoordinator(conf) } - val outputCommitCoordinatorActor = registerOrLookup("OutputCommitCoordinator", - new OutputCommitCoordinatorActor(outputCommitCoordinator)) - outputCommitCoordinator.coordinatorActor = Some(outputCommitCoordinatorActor) + val outputCommitCoordinatorRef = registerOrLookupEndpoint("OutputCommitCoordinator", + new OutputCommitCoordinatorEndpoint(rpcEnv, outputCommitCoordinator)) + outputCommitCoordinator.coordinatorRef = Some(outputCommitCoordinatorRef) val envInstance = new SparkEnv( executorId, - actorSystem, + rpcEnv, serializer, closureSerializer, cacheManager, diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/DriverWrapper.scala b/core/src/main/scala/org/apache/spark/deploy/worker/DriverWrapper.scala index ab467a5ee8c6..00fdcc0922a7 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/DriverWrapper.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/DriverWrapper.scala @@ -21,6 +21,7 @@ import java.io.File import akka.actor._ +import org.apache.spark.rpc.RpcEnv import org.apache.spark.{SecurityManager, SparkConf} import org.apache.spark.util.{AkkaUtils, ChildFirstURLClassLoader, MutableURLClassLoader, Utils} @@ -32,9 +33,9 @@ object DriverWrapper { args.toList match { case workerUrl :: userJar :: mainClass :: extraArgs => val conf = new SparkConf() - val (actorSystem, _) = AkkaUtils.createActorSystem("Driver", + val rpcEnv = RpcEnv.create("Driver", Utils.localHostName(), 0, conf, new SecurityManager(conf)) - actorSystem.actorOf(Props(classOf[WorkerWatcher], workerUrl), name = "workerWatcher") + rpcEnv.setupEndpoint("workerWatcher", new WorkerWatcher(rpcEnv, workerUrl)) val currentLoader = Thread.currentThread.getContextClassLoader val userJarUrl = new File(userJar).toURI().toURL() @@ -51,7 +52,7 @@ object DriverWrapper { val mainMethod = clazz.getMethod("main", classOf[Array[String]]) mainMethod.invoke(null, extraArgs.toArray[String]) - actorSystem.shutdown() 
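Aside: DriverWrapper above is the smallest end-to-end use of the new API — create an RpcEnv, register the watchdog endpoint, run user code, then shut the environment down. Condensed to its RPC-related steps (a sketch only; the user-main invocation is elided and workerUrl is the argument the launching worker passes in):

    import org.apache.spark.{SecurityManager, SparkConf}
    import org.apache.spark.deploy.worker.WorkerWatcher
    import org.apache.spark.rpc.RpcEnv
    import org.apache.spark.util.Utils

    def runWrappedDriver(workerUrl: String): Unit = {
      val conf = new SparkConf()
      val rpcEnv = RpcEnv.create("Driver", Utils.localHostName(), 0, conf, new SecurityManager(conf))
      // Fate sharing: terminate this JVM if the worker that launched it goes away.
      rpcEnv.setupEndpoint("workerWatcher", new WorkerWatcher(rpcEnv, workerUrl))
      // ... load and invoke the user's main class here ...
      rpcEnv.shutdown()
    }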
+ rpcEnv.shutdown() case _ => System.err.println("Usage: DriverWrapper [options]") diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/WorkerWatcher.scala b/core/src/main/scala/org/apache/spark/deploy/worker/WorkerWatcher.scala index 63a8ac817b61..ce8c095dfcbe 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/WorkerWatcher.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/WorkerWatcher.scala @@ -17,26 +17,23 @@ package org.apache.spark.deploy.worker -import akka.actor.{Actor, Address, AddressFromURIString} -import akka.remote.{AssociatedEvent, AssociationErrorEvent, AssociationEvent, DisassociatedEvent, RemotingLifecycleEvent} - import org.apache.spark.Logging import org.apache.spark.deploy.DeployMessages.SendHeartbeat -import org.apache.spark.util.ActorLogReceive +import org.apache.spark.rpc._ /** * Actor which connects to a worker process and terminates the JVM if the connection is severed. * Provides fate sharing between a worker and its associated child processes. */ -private[spark] class WorkerWatcher(workerUrl: String) - extends Actor with ActorLogReceive with Logging { - - override def preStart() { - context.system.eventStream.subscribe(self, classOf[RemotingLifecycleEvent]) +private[spark] class WorkerWatcher(override val rpcEnv: RpcEnv, workerUrl: String) + extends NetworkRpcEndpoint with Logging { + override def onStart() { logInfo(s"Connecting to worker $workerUrl") - val worker = context.actorSelection(workerUrl) - worker ! SendHeartbeat // need to send a message here to initiate connection + if (!isTesting) { + val worker = rpcEnv.setupEndpointRefByUrl(workerUrl) + worker.send(SendHeartbeat) // need to send a message here to initiate connection + } } // Used to avoid shutting down JVM during tests @@ -45,30 +42,37 @@ private[spark] class WorkerWatcher(workerUrl: String) private var isTesting = false // Lets us filter events only from the worker's actor system - private val expectedHostPort = AddressFromURIString(workerUrl).hostPort - private def isWorker(address: Address) = address.hostPort == expectedHostPort + private val expectedHostPort = new java.net.URI(workerUrl) + private def isWorker(address: RpcAddress) = { + expectedHostPort.getHost == address.host && expectedHostPort.getPort == address.port + } def exitNonZero() = if (isTesting) isShutDown = true else System.exit(-1) - override def receiveWithLogging = { - case AssociatedEvent(localAddress, remoteAddress, inbound) if isWorker(remoteAddress) => - logInfo(s"Successfully connected to $workerUrl") + override def receive(sender: RpcEndpointRef) = { + case e => logWarning(s"Received unexpected actor system event: $e") + } - case AssociationErrorEvent(cause, localAddress, remoteAddress, inbound, _) - if isWorker(remoteAddress) => - // These logs may not be seen if the worker (and associated pipe) has died - logError(s"Could not initialize connection to worker $workerUrl. Exiting.") - logError(s"Error was: $cause") - exitNonZero() + override def onConnected(remoteAddress: RpcAddress): Unit = { + if (isWorker(remoteAddress)) { + logInfo(s"Successfully connected to $workerUrl") + } + } - case DisassociatedEvent(localAddress, remoteAddress, inbound) if isWorker(remoteAddress) => + override def onDisconnected(remoteAddress: RpcAddress): Unit = { + if (isWorker(remoteAddress)) { // This log message will never be seen logError(s"Lost connection to worker actor $workerUrl. 
Exiting.") exitNonZero() + } + } - case e: AssociationEvent => - // pass through association events relating to other remote actor systems - - case e => logWarning(s"Received unexpected actor system event: $e") + override def onNetworkError(cause: Throwable, remoteAddress: RpcAddress): Unit = { + if (isWorker(remoteAddress)) { + // These logs may not be seen if the worker (and associated pipe) has died + logError(s"Could not initialize connection to worker $workerUrl. Exiting.") + logError(s"Error was: $cause") + exitNonZero() + } } } diff --git a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala index dd19e4947db1..1138b03e279e 100644 --- a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala +++ b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala @@ -169,7 +169,7 @@ private[spark] object CoarseGrainedExecutorBackend extends Logging { driverUrl, executorId, sparkHostPort, cores, userClassPath, env), name = "Executor") workerUrl.foreach { url => - env.actorSystem.actorOf(Props(classOf[WorkerWatcher], url), name = "WorkerWatcher") + env.rpcEnv.setupEndpoint("WorkerWatcher", new WorkerWatcher(env.rpcEnv, url)) } env.actorSystem.awaitTermination() } diff --git a/core/src/main/scala/org/apache/spark/rpc/ActionScheduler.scala b/core/src/main/scala/org/apache/spark/rpc/ActionScheduler.scala new file mode 100644 index 000000000000..f16af0634abe --- /dev/null +++ b/core/src/main/scala/org/apache/spark/rpc/ActionScheduler.scala @@ -0,0 +1,207 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.rpc + +import java.util.concurrent.atomic.AtomicReference +import java.util.concurrent.{SynchronousQueue, TimeUnit, ThreadPoolExecutor} + +import scala.concurrent.duration.FiniteDuration +import scala.concurrent.ExecutionContext +import scala.util.control.NonFatal + +import io.netty.util.{Timeout, TimerTask, HashedWheelTimer} + +import org.apache.spark.util.Utils +import org.apache.spark.{Logging, SparkConf} + +/** + * It's very common that executing some actions in other threads to avoid blocking the event loop + * in a RpcEndpoint. [[ActionScheduler]] is designed for such use cases. + */ +private[spark] trait ActionScheduler { + + /** + * Run the action in the IO thread pool. The thread name will be `name` when running this action. + */ + def executeIOAction(name: String)(action: => Unit): Unit + + /** + * Run the action in the CPU thread pool. The thread name will be `name` when running this action. + */ + def executeCPUAction(name: String)(action: => Unit): Unit + + /** + * Run the action after `delay`. The thread name will be `name` when running this action. 
+ */ + def schedule(name: String, delay: FiniteDuration)(action: => Unit): Cancellable + + + /** + * Run the action every `interval`. The thread name will be `name` when running this action. + */ + def schedulePeriodically(name: String, interval: FiniteDuration)(action: => Unit): Cancellable = { + schedulePeriodically(name, interval, interval)(action) + } + + /** + * Run the action every `interval`. The thread name will be `name` when running this action. + */ + def schedulePeriodically( + name: String, delay: FiniteDuration, interval: FiniteDuration)(action: => Unit): Cancellable +} + +private[spark] trait Cancellable { + // Should be reentrant + def cancel(): Unit +} + +private[rpc] object NopCancellable extends Cancellable { + override def cancel(): Unit = {} +} + +private[rpc] class SettableCancellable extends Cancellable { + + @volatile private var underlying: Cancellable = NopCancellable + + @volatile private var isCancelled = false + + // Should be called only once + def set(c: Cancellable): Unit = { + underlying = c + if (isCancelled) { + underlying.cancel() + } + } + + override def cancel(): Unit = { + isCancelled = true + underlying.cancel() + } +} + +private[spark] class ActionSchedulerImpl(conf: SparkConf) extends ActionScheduler with Logging { + + val maxIOThreadNumber = conf.getInt("spark.rpc.io.maxThreads", 1000) + + private val ioExecutor = new ThreadPoolExecutor( + 0, + maxIOThreadNumber, + 60L, + TimeUnit.SECONDS, + new SynchronousQueue[Runnable](), Utils.namedThreadFactory("rpc-io")) + + private val cpuExecutor = ExecutionContext.fromExecutorService(null, e => { + e match { + case NonFatal(e) => logError(e.getMessage, e) + case e => + Thread.getDefaultUncaughtExceptionHandler.uncaughtException(Thread.currentThread, e) + } + }) + + private val timer = new HashedWheelTimer(Utils.namedThreadFactory("rpc-timer")) + + // Need a name to distinguish between different actions because they use the same thread pool + override def executeIOAction(name: String)(action: => Unit): Unit = { + ioExecutor.execute(new Runnable { + + override def run(): Unit = { + val previousThreadName = Thread.currentThread().getName + Thread.currentThread().setName(name) + try { + action + } finally { + Thread.currentThread().setName(previousThreadName) + } + } + + }) + } + + // Need a name to distinguish between different actions because they use the same thread pool + override def executeCPUAction(name: String)(action: => Unit): Unit = { + cpuExecutor.execute(new Runnable { + + override def run(): Unit = { + val previousThreadName = Thread.currentThread().getName + Thread.currentThread().setName(name) + try { + action + } finally { + Thread.currentThread().setName(previousThreadName) + } + } + + }) + } + + def schedule(name: String, delay: FiniteDuration)(action: => Unit): Cancellable = { + val timeout = timer.newTimeout(new TimerTask { + + override def run(timeout: Timeout): Unit = { + val previousThreadName = Thread.currentThread().getName + Thread.currentThread().setName(name) + try { + action + } finally { + Thread.currentThread().setName(previousThreadName) + } + } + + }, delay.toNanos, TimeUnit.NANOSECONDS) + new Cancellable { + override def cancel(): Unit = timeout.cancel() + } + } + + override def schedulePeriodically( + name: String, delay: FiniteDuration, interval: FiniteDuration)(action: => Unit): + Cancellable = { + val initial = new SettableCancellable + val cancellable = new AtomicReference[SettableCancellable](initial) + def actionOnce: Unit = { + if (cancellable.get != null) { + action 
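Aside: a usage sketch of the ActionScheduler contract above — not part of the patch; the thread names, the worker reference and the message passed to it are assumed for illustration:

    import scala.concurrent.duration._
    import org.apache.spark.rpc.{ActionScheduler, Cancellable, RpcEndpointRef, RpcEnv}

    def startHeartbeats(rpcEnv: RpcEnv, workerRef: RpcEndpointRef, msg: Any): Cancellable = {
      val scheduler: ActionScheduler = rpcEnv.scheduler
      // Blocking work goes to the unbounded I/O pool so an endpoint's message loop stays free;
      // the pool thread is renamed to the given name while the action runs.
      scheduler.executeIOAction("example-io") {
        // e.g. slow disk or network work
      }
      // Periodic send; keep the handle so onStop() can cancel it (cancel() is safe to call repeatedly).
      scheduler.schedulePeriodically("example-heartbeat", 10.seconds) {
        workerRef.send(msg)
      }
    }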
+ val c = cancellable.get + if (c != null) { + val s = new SettableCancellable + if (cancellable.compareAndSet(c, s)) { + s.set(schedule(name, interval)(actionOnce)) + } else { + // has been cancelled + assert(cancellable.get == null) + } + } + } + } + initial.set(schedule(name, delay)(actionOnce)) + new Cancellable { + override def cancel(): Unit = { + var c = cancellable.get + while (c != null) { + if (cancellable.compareAndSet(c, null)) { + c.cancel() + return + } else { + // Already schedule another action, retry to cancel it + c = cancellable.get + } + } + } + } + } +} diff --git a/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala new file mode 100644 index 000000000000..5f05b9d25846 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala @@ -0,0 +1,345 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.rpc + +import java.net.URI + +import scala.concurrent.Future +import scala.concurrent.duration.FiniteDuration +import scala.reflect.ClassTag + +import org.apache.spark.{SecurityManager, SparkConf} +import org.apache.spark.util.Utils + +/** + * An RPC environment. + */ +private[spark] trait RpcEnv { + + private[rpc] def endpointRef(endpoint: RpcEndpoint): RpcEndpointRef + + def scheduler: ActionScheduler + + def systemName: String + + /** + * Return the address that [[RpcEnv]] is listening to. + */ + def address: RpcAddress + + /** + * Register a [[RpcEndpoint]] with a name and return its [[RpcEndpointRef]]. + */ + def setupEndpoint(name: String, endpoint: RpcEndpoint): RpcEndpointRef + + /** + * Retrieve a [[RpcEndpointRef]] which is located in the driver via its name. + */ + def setupDriverEndpointRef(name: String): RpcEndpointRef + + /** + * Retrieve the [[RpcEndpointRef]] represented by `url`. + */ + def setupEndpointRefByUrl(url: String): RpcEndpointRef + + /** + * Retrieve the [[RpcEndpointRef]] represented by `systemName`, `address` and `endpointName` + */ + def setupEndpointRef( + systemName: String, address: RpcAddress, endpointName: String): RpcEndpointRef + + /** + * Stop [[RpcEndpoint]] specified by `endpoint`. + */ + def stop(endpoint: RpcEndpointRef): Unit + + /** + * Shutdown this [[RpcEnv]] asynchronously. If need to make sure [[RpcEnv]] exits successfully, + * call [[awaitTermination()]] straight after [[shutdown()]]. + */ + def shutdown(): Unit + + /** + * Wait until [[RpcEnv]] exits. + * + * TODO do we need a timeout parameter? 
+ */ + def awaitTermination(): Unit + + /** + * Create a URI used to create a [[RpcEndpointRef]] + */ + def newURI(systemName: String, address: RpcAddress, endpointName: String): String +} + +private[spark] case class RpcEnvConfig( + conf: SparkConf, + name: String, + host: String, + port: Int, + securityManager: SecurityManager) + +/** + * A RpcEnv implementation must have a companion object with an + * `apply(config: RpcEnvConfig): RpcEnv` method so that it can be created via Reflection. + * + * {{{ + * object MyCustomRpcEnv { + * def apply(config: RpcEnvConfig): RpcEnv = { + * ... + * } + * } + * }}} + */ +private[spark] object RpcEnv { + + private def getRpcEnvCompanion(conf: SparkConf): AnyRef = { + // Add more RpcEnv implementations here + val rpcEnvNames = Map("akka" -> "org.apache.spark.rpc.akka.AkkaRpcEnv") + val rpcEnvName = conf.get("spark.rpc", "akka") + val rpcEnvClassName = rpcEnvNames.getOrElse(rpcEnvName.toLowerCase, rpcEnvName) + val companion = Class.forName( + rpcEnvClassName + "$", true, Utils.getContextOrSparkClassLoader).getField("MODULE$").get(null) + companion + } + + def create( + name: String, + host: String, + port: Int, + conf: SparkConf, + securityManager: SecurityManager): RpcEnv = { + // Using Reflection to create the RpcEnv to avoid to depend on Akka directly + val config = RpcEnvConfig(conf, name, host, port, securityManager) + val companion = getRpcEnvCompanion(conf) + companion.getClass.getMethod("apply", classOf[RpcEnvConfig]). + invoke(companion, config).asInstanceOf[RpcEnv] + } + +} + +/** + * An end point for the RPC that defines what functions to trigger given a message. + * + * RpcEndpoint will be guaranteed that `onStart`, `receive` and `onStop` will + * be called in sequence. + * + * The lift-cycle will be: + * + * constructor onStart receive* onStop + * + * If any error is thrown from one of RpcEndpoint methods except `onError`, [[RpcEndpoint.onError]] + * will be invoked with the cause. If onError throws an error, [[RpcEnv]] will ignore it. + */ +private[spark] trait RpcEndpoint { + + /** + * The [[RpcEnv]] that this [[RpcEndpoint]] is registered to. + */ + val rpcEnv: RpcEnv + + /** + * Provide the implicit sender. `self` will become valid when `onStart` is called. + * + * Note: Because before `onStart`, [[RpcEndpoint]] has not yet been registered and there is not + * valid [[RpcEndpointRef]] for it. So don't call `self` before `onStart` is called. In the other + * words, don't call [[RpcEndpointRef.send]] in the constructor of [[RpcEndpoint]]. + */ + implicit final def self: RpcEndpointRef = { + require(rpcEnv != null, "rpcEnv has not been initialized") + rpcEnv.endpointRef(this) + } + + /** + * Same assumption like Actor: messages sent to a RpcEndpoint will be delivered in sequence, and + * messages from the same RpcEndpoint will be delivered in order. + * + * @param sender + * @return + */ + def receive(sender: RpcEndpointRef): PartialFunction[Any, Unit] + + /** + * Call onError when any exception is thrown during handling messages. + * + * @param cause + */ + def onError(cause: Throwable): Unit = { + // By default, throw e and let RpcEnv handle it + throw cause + } + + /** + * Invoked before [[RpcEndpoint]] starts to handle any message. + */ + def onStart(): Unit = { + // By default, do nothing. + } + + /** + * Invoked when [[RpcEndpoint]] is stopping. + */ + def onStop(): Unit = { + // By default, do nothing. + } + + /** + * An convenient method to stop [[RpcEndpoint]]. 
+ */ + final def stop(): Unit = { + val _self = self + if (_self != null) { + rpcEnv.stop(self) + } + } +} + +/** + * A RpcEndoint interested in network events. + * + * [[NetworkRpcEndpoint]] will be guaranteed that `onStart`, `receive` , `onConnected`, + * `onDisconnected`, `onNetworkError` and `onStop` will be called in sequence. + * + * The lift-cycle will be: + * + * constructor onStart (receive|onConnected|onDisconnected|onNetworkError)* onStop + * + * If any error is thrown from `onConnected`, `onDisconnected` or `onNetworkError`, + * [[RpcEndpoint.onError)]] will be invoked with the cause. If onError throws an error, + * [[RpcEnv]] will ignore it. + */ +private[spark] trait NetworkRpcEndpoint extends RpcEndpoint { + + /** + * Invoked when `remoteAddress` is connected to the current node. + */ + def onConnected(remoteAddress: RpcAddress): Unit = { + // By default, do nothing. + } + + /** + * Invoked when `remoteAddress` is lost. + */ + def onDisconnected(remoteAddress: RpcAddress): Unit = { + // By default, do nothing. + } + + /** + * Invoked when some network error happens in the connection between the current node and + * `remoteAddress`. + */ + def onNetworkError(cause: Throwable, remoteAddress: RpcAddress): Unit = { + // By default, do nothing. + } +} + +private[spark] object RpcEndpoint { + final val noSender: RpcEndpointRef = null +} + +/** + * A reference for a remote [[RpcEndpoint]]. [[RpcEndpointRef]] is thread-safe. + */ +private[spark] trait RpcEndpointRef { + + /** + * return the address for the [[RpcEndpointRef]] + */ + def address: RpcAddress + + def name: String + + /** + * Send a message to the corresponding [[RpcEndpoint]] and return a `Future` to receive the reply + * within a default timeout. + */ + def ask[T: ClassTag](message: Any): Future[T] + + /** + * Send a message to the corresponding [[RpcEndpoint]] and return a `Future` to receive the reply + * within the specified timeout. + */ + def ask[T: ClassTag](message: Any, timeout: FiniteDuration): Future[T] + + /** + * Send a message to the corresponding [[RpcEndpoint]] and get its result within a default + * timeout, or throw a SparkException if this fails even after the default number of retries. + * + * Note: this is a blocking action which may cost a lot of time, so don't call it in an message + * loop of [[RpcEndpoint]]. + * + * @param message the message to send + * @tparam T type of the reply message + * @return the reply message from the corresponding [[RpcEndpoint]] + */ + def askWithReply[T: ClassTag](message: Any): T + + /** + * Send a message to the corresponding [[RpcEndpoint]] and get its result within a specified + * timeout, throw a SparkException if this fails even after the specified number of retries. + * + * Note: this is a blocking action which may cost a lot of time, so don't call it in an message + * loop of [[RpcEndpoint]]. + * + * @param message the message to send + * @param timeout the timeout duration + * @tparam T type of the reply message + * @return the reply message from the corresponding [[RpcEndpoint]] + */ + def askWithReply[T: ClassTag](message: Any, timeout: FiniteDuration): T + + /** + * Sends a one-way asynchronous message. Fire-and-forget semantics. + * + * If invoked from within an [[RpcEndpoint]] then `self` is implicitly passed on as the implicit + * 'sender' argument. If not then no sender is available. 
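Aside: putting the RpcEndpoint and RpcEndpointRef pieces above together, a minimal echo endpoint in the same style as the RpcEnvSuite tests later in this patch. A sketch only; the endpoint name is illustrative and `env` is assumed to be an already-created RpcEnv:

    import org.apache.spark.rpc.{RpcEndpoint, RpcEndpointRef, RpcEnv}

    class EchoEndpoint(override val rpcEnv: RpcEnv) extends RpcEndpoint {
      override def onStart(): Unit = {
        // `self` is valid from here on; messages are delivered to receive() one at a time
      }
      override def receive(sender: RpcEndpointRef) = {
        case msg: String => sender.send(msg)  // reply by sending back to the sender's ref
      }
      override def onStop(): Unit = {
        // release any resources held by the endpoint
      }
    }

    def demo(env: RpcEnv): Unit = {
      val echoRef = env.setupEndpoint("echo", new EchoEndpoint(env))
      val reply = echoRef.askWithReply[String]("hello")  // blocking ask with the default timeout
      assert(reply == "hello")
      env.stop(echoRef)  // eventually triggers onStop()
    }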
+ * + * This `sender` reference is then available in the receiving [[RpcEndpoint]] as the `sender` + * parameter of [[RpcEndpoint.receive]] + */ + def send(message: Any)(implicit sender: RpcEndpointRef = RpcEndpoint.noSender): Unit + + def toURI: URI +} + +/** + * Represent a host with a port + */ +private[spark] case class RpcAddress(host: String, port: Int) { + // TODO do we need to add the type of RpcEnv in the address? + + val hostPort: String = host + ":" + port + + override val toString: String = hostPort +} + +private[spark] object RpcAddress { + + /** + * Return the [[RpcAddress]] represented by `uri`. + */ + def fromURIString(uri: String): RpcAddress = { + val u = new java.net.URI(uri) + RpcAddress(u.getHost, u.getPort) + } + + def fromSparkURL(sparkUrl: String): RpcAddress = { + val (host, port) = Utils.extractHostPortFromSparkUrl(sparkUrl) + RpcAddress(host, port) + } +} diff --git a/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala new file mode 100644 index 000000000000..ba4c0ef4cbd6 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala @@ -0,0 +1,287 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.rpc.akka + +import java.net.URI +import java.util.concurrent.{ConcurrentHashMap, CountDownLatch} + +import scala.concurrent.Await +import scala.concurrent.duration._ +import scala.concurrent.Future +import scala.language.postfixOps +import scala.reflect.ClassTag +import scala.util.control.NonFatal + +import _root_.akka.actor.{ActorSystem, ExtendedActorSystem, Actor, ActorRef, Props, Address} +import akka.pattern.{ask => akkaAsk} +import akka.remote._ + +import org.apache.spark.{Logging, SparkConf} +import org.apache.spark.rpc._ +import org.apache.spark.util.{SparkUncaughtExceptionHandler, ActorLogReceive, AkkaUtils} + +/** + * A RpcEnv implementation based on Akka. + * + * TODO Once we remove all usages of Akka in other place, we can move this file to a new project and + * remove Akka from the dependencies. + * + * @param actorSystem + * @param conf + * @param boundPort + */ +private[spark] class AkkaRpcEnv private ( + val actorSystem: ActorSystem, conf: SparkConf, boundPort: Int) extends RpcEnv with Logging { + + private val defaultAddress: RpcAddress = { + val address = actorSystem.asInstanceOf[ExtendedActorSystem].provider.getDefaultAddress + // In some test case, ActorSystem doesn't bind to any address. 
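Aside: endpoints that care about connection state extend the NetworkRpcEndpoint variant defined above and get the onConnected / onDisconnected / onNetworkError callbacks; the rewritten WorkerWatcher earlier in this patch is the real user. A stripped-down sketch with an illustrative class name and a log-only policy:

    import org.apache.spark.Logging
    import org.apache.spark.rpc.{NetworkRpcEndpoint, RpcAddress, RpcEndpointRef, RpcEnv}

    class ConnectionLogger(override val rpcEnv: RpcEnv) extends NetworkRpcEndpoint with Logging {
      override def receive(sender: RpcEndpointRef) = {
        case msg => logInfo(s"Ignoring message: $msg")
      }
      override def onConnected(remoteAddress: RpcAddress): Unit =
        logInfo(s"Connected to $remoteAddress")
      override def onDisconnected(remoteAddress: RpcAddress): Unit =
        logWarning(s"Lost connection to $remoteAddress")
      override def onNetworkError(cause: Throwable, remoteAddress: RpcAddress): Unit =
        logError(s"Network error talking to $remoteAddress", cause)
    }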
+ // So just use some default value since they are only some unit tests + RpcAddress(address.host.getOrElse("localhost"), address.port.getOrElse(boundPort)) + } + + override val systemName: String = actorSystem.name + + override val address: RpcAddress = defaultAddress + + /** + * A lookup table to search a [[RpcEndpointRef]] for a [[RpcEndpoint]]. We need it to make + * [[RpcEndpoint.self]] work. + */ + private val endpointToRef = new ConcurrentHashMap[RpcEndpoint, RpcEndpointRef]() + + override val scheduler = new ActionSchedulerImpl(conf) + + /** + * Need this map to remove `RpcEndpoint` from `endpointToRef` via a `RpcEndpointRef` + */ + private val refToEndpoint = new ConcurrentHashMap[RpcEndpointRef, RpcEndpoint]() + + private def registerEndpoint(endpoint: RpcEndpoint, endpointRef: RpcEndpointRef): Unit = { + endpointToRef.put(endpoint, endpointRef) + refToEndpoint.put(endpointRef, endpoint) + } + + private def unregisterEndpoint(endpointRef: RpcEndpointRef): Unit = { + val endpoint = refToEndpoint.remove(endpointRef) + if (endpoint != null) { + endpointToRef.remove(endpoint) + } + } + + /** + * Retrieve the [[RpcEndpointRef]] of `endpoint`. + */ + override def endpointRef(endpoint: RpcEndpoint): RpcEndpointRef = { + val endpointRef = endpointToRef.get(endpoint) + require(endpointRef != null, s"Cannot find RpcEndpointRef of ${endpoint} in ${this}") + endpointRef + } + + override def setupEndpoint(name: String, endpoint: RpcEndpoint): RpcEndpointRef = { + val latch = new CountDownLatch(1) + try { + @volatile var endpointRef: AkkaRpcEndpointRef = null + val actorRef = actorSystem.actorOf(Props(new Actor with ActorLogReceive with Logging { + + // Wait until `endpointRef` is set. TODO better solution? + latch.await() + require(endpointRef != null) + registerEndpoint(endpoint, endpointRef) + + var isNetworkRpcEndpoint = false + + override def preStart(): Unit = { + if (endpoint.isInstanceOf[NetworkRpcEndpoint]) { + isNetworkRpcEndpoint = true + // Listen for remote client network events only when it's `NetworkRpcEndpoint` + context.system.eventStream.subscribe(self, classOf[RemotingLifecycleEvent]) + } + safelyCall(endpoint) { + endpoint.onStart() + } + } + + override def receiveWithLogging: Receive = if (isNetworkRpcEndpoint) { + case AssociatedEvent(_, remoteAddress, _) => + safelyCall(endpoint) { + endpoint.asInstanceOf[NetworkRpcEndpoint]. + onConnected(akkaAddressToRpcAddress(remoteAddress)) + } + + case DisassociatedEvent(_, remoteAddress, _) => + safelyCall(endpoint) { + endpoint.asInstanceOf[NetworkRpcEndpoint]. + onDisconnected(akkaAddressToRpcAddress(remoteAddress)) + } + + case AssociationErrorEvent(cause, localAddress, remoteAddress, inbound, _) => + safelyCall(endpoint) { + endpoint.asInstanceOf[NetworkRpcEndpoint]. + onNetworkError(cause, akkaAddressToRpcAddress(remoteAddress)) + } + case e: RemotingLifecycleEvent => + // TODO ignore? 
+ + case message: Any => + logDebug("Received RPC message: " + message) + safelyCall(endpoint) { + val pf = endpoint.receive(new AkkaRpcEndpointRef(defaultAddress, sender(), conf)) + if (pf.isDefinedAt(message)) { + pf.apply(message) + } + } + } else { + case message: Any => + logDebug("Received RPC message: " + message) + safelyCall(endpoint) { + val pf = endpoint.receive(new AkkaRpcEndpointRef(defaultAddress, sender(), conf)) + if (pf.isDefinedAt(message)) { + pf.apply(message) + } + } + } + + override def postStop(): Unit = { + unregisterEndpoint(endpoint.self) + safelyCall(endpoint) { + endpoint.onStop() + } + } + + }), name = name) + endpointRef = new AkkaRpcEndpointRef(defaultAddress, actorRef, conf) + endpointRef + } finally { + latch.countDown() + } + } + + private def safelyCall(endpoint: RpcEndpoint)(action: => Unit): Unit = { + try { + action + } catch { + case NonFatal(e) => { + try { + endpoint.onError(e) + } catch { + case NonFatal(e) => logError(s"Ignore error: ${e.getMessage}", e) + } + } + } + } + + private def akkaAddressToRpcAddress(address: Address): RpcAddress = { + RpcAddress(address.host.getOrElse(defaultAddress.host), + address.port.getOrElse(defaultAddress.port)) + } + + override def setupDriverEndpointRef(name: String): RpcEndpointRef = { + new AkkaRpcEndpointRef(defaultAddress, AkkaUtils.makeDriverRef(name, conf, actorSystem), conf) + } + + override def setupEndpointRefByUrl(url: String): RpcEndpointRef = { + val timeout = Duration.create(conf.getLong("spark.akka.lookupTimeout", 30), "seconds") + val ref = Await.result(actorSystem.actorSelection(url).resolveOne(timeout), timeout) + // TODO defaultAddress is wrong + new AkkaRpcEndpointRef(defaultAddress, ref, conf) + } + + override def setupEndpointRef( + systemName: String, address: RpcAddress, endpointName: String): RpcEndpointRef = { + setupEndpointRefByUrl(newURI(systemName, address, endpointName)) + } + + override def newURI(systemName: String, address: RpcAddress, endpointName: String): String = { + AkkaUtils.address( + AkkaUtils.protocol(actorSystem), systemName, address.host, address.port, endpointName) + } + + + override def shutdown(): Unit = { + actorSystem.shutdown() + } + + override def stop(endpoint: RpcEndpointRef): Unit = { + require(endpoint.isInstanceOf[AkkaRpcEndpointRef]) + actorSystem.stop(endpoint.asInstanceOf[AkkaRpcEndpointRef].actorRef) + } + + override def awaitTermination(): Unit = { + actorSystem.awaitTermination() + } + + override def toString = s"${getClass.getSimpleName}($actorSystem)" +} + +private[spark] object AkkaRpcEnv { + + def apply(config: RpcEnvConfig): RpcEnv = { + val (actorSystem, boundPort) = AkkaUtils.createActorSystem( + config.name, config.host, config.port, config.conf, config.securityManager) + new AkkaRpcEnv(actorSystem, config.conf, boundPort) + } + +} + +private[akka] class AkkaRpcEndpointRef( + @transient defaultAddress: RpcAddress, + val actorRef: ActorRef, + @transient conf: SparkConf) extends RpcEndpointRef with Serializable with Logging { + // `defaultAddress` and `conf` won't be used after initialization. So it's safe to be transient. 
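Aside: the companion object above is also how the test suites at the end of this patch build the Akka-backed environment directly, bypassing the reflection in RpcEnv.create. A sketch with illustrative name, host and port values:

    import org.apache.spark.{SecurityManager, SparkConf}
    import org.apache.spark.rpc.RpcEnvConfig
    import org.apache.spark.rpc.akka.AkkaRpcEnv

    val conf = new SparkConf()
    val rpcEnv = AkkaRpcEnv(
      RpcEnvConfig(conf, "test", "localhost", 12345, new SecurityManager(conf)))
    // rpcEnv.address now reports the host and port the underlying ActorSystem bound to
    rpcEnv.shutdown()
    rpcEnv.awaitTermination()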
+ + private[this] val maxRetries = conf.getInt("spark.akka.num.retries", 3) + private[this] val retryWaitMs = conf.getLong("spark.akka.retry.wait", 3000) + private[this] val defaultTimeout = conf.getLong("spark.akka.lookupTimeout", 30) seconds + + override val address: RpcAddress = { + val akkaAddress = actorRef.path.address + RpcAddress(akkaAddress.host.getOrElse(defaultAddress.host), + akkaAddress.port.getOrElse(defaultAddress.port)) + } + + override val name: String = actorRef.path.name + + override def ask[T: ClassTag](message: Any): Future[T] = ask(message, defaultTimeout) + + override def ask[T: ClassTag](message: Any, timeout: FiniteDuration): Future[T] = { + actorRef.ask(message)(timeout).mapTo[T] + } + + override def askWithReply[T: ClassTag](message: Any): T = askWithReply(message, defaultTimeout) + + override def askWithReply[T: ClassTag](message: Any, timeout: FiniteDuration): T = { + // TODO: Consider removing multiple attempts + AkkaUtils.askWithReply(message, actorRef, maxRetries, retryWaitMs, timeout) + } + + override def send(message: Any)(implicit sender: RpcEndpointRef = RpcEndpoint.noSender): Unit = { + implicit val actorSender: ActorRef = + if (sender == null) { + Actor.noSender + } else { + require(sender.isInstanceOf[AkkaRpcEndpointRef]) + sender.asInstanceOf[AkkaRpcEndpointRef].actorRef + } + actorRef ! message + } + + override def toString: String = s"${getClass.getSimpleName}($actorRef)" + + override def toURI: URI = new URI(actorRef.path.toString) +} diff --git a/core/src/main/scala/org/apache/spark/scheduler/OutputCommitCoordinator.scala b/core/src/main/scala/org/apache/spark/scheduler/OutputCommitCoordinator.scala index 759df023a6dc..e60cff3d0692 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/OutputCommitCoordinator.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/OutputCommitCoordinator.scala @@ -19,10 +19,8 @@ package org.apache.spark.scheduler import scala.collection.mutable -import akka.actor.{ActorRef, Actor} - import org.apache.spark._ -import org.apache.spark.util.{AkkaUtils, ActorLogReceive} +import org.apache.spark.rpc.{RpcEndpointRef, RpcEnv, RpcEndpoint} private sealed trait OutputCommitCoordinationMessage extends Serializable @@ -43,10 +41,7 @@ private case class AskPermissionToCommitOutput(stage: Int, task: Long, taskAttem private[spark] class OutputCommitCoordinator(conf: SparkConf) extends Logging { // Initialized by SparkEnv - var coordinatorActor: Option[ActorRef] = None - private val timeout = AkkaUtils.askTimeout(conf) - private val maxAttempts = AkkaUtils.numRetries(conf) - private val retryInterval = AkkaUtils.retryWaitMs(conf) + var coordinatorRef: Option[RpcEndpointRef] = None private type StageId = Int private type PartitionId = Long @@ -81,9 +76,9 @@ private[spark] class OutputCommitCoordinator(conf: SparkConf) extends Logging { partition: PartitionId, attempt: TaskAttemptId): Boolean = { val msg = AskPermissionToCommitOutput(stage, partition, attempt) - coordinatorActor match { - case Some(actor) => - AkkaUtils.askWithReply[Boolean](msg, actor, maxAttempts, retryInterval, timeout) + coordinatorRef match { + case Some(endpointRef) => + endpointRef.askWithReply[Boolean](msg) case None => logError( "canCommit called after coordinator was stopped (is SparkEnv shutdown in progress)?") @@ -125,8 +120,8 @@ private[spark] class OutputCommitCoordinator(conf: SparkConf) extends Logging { } def stop(): Unit = synchronized { - coordinatorActor.foreach(_ ! 
StopCoordinator) - coordinatorActor = None + coordinatorRef.foreach(_ send StopCoordinator) + coordinatorRef = None authorizedCommittersByStage.clear() } @@ -157,16 +152,18 @@ private[spark] class OutputCommitCoordinator(conf: SparkConf) extends Logging { private[spark] object OutputCommitCoordinator { // This actor is used only for RPC - class OutputCommitCoordinatorActor(outputCommitCoordinator: OutputCommitCoordinator) - extends Actor with ActorLogReceive with Logging { + class OutputCommitCoordinatorEndpoint( + override val rpcEnv: RpcEnv, outputCommitCoordinator: OutputCommitCoordinator) + extends RpcEndpoint with Logging { - override def receiveWithLogging = { + override def receive(sender: RpcEndpointRef) = { case AskPermissionToCommitOutput(stage, partition, taskAttempt) => - sender ! outputCommitCoordinator.handleAskPermissionToCommit(stage, partition, taskAttempt) + sender.send( + outputCommitCoordinator.handleAskPermissionToCommit(stage, partition, taskAttempt)) case StopCoordinator => logInfo("OutputCommitCoordinator stopped!") - context.stop(self) - sender ! true + sender.send(true) + stop() } } } diff --git a/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala b/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala index 3d9c6192ff7f..5e216a53cb21 100644 --- a/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala +++ b/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala @@ -182,7 +182,7 @@ private[spark] object AkkaUtils extends Logging { message: Any, actor: ActorRef, maxAttempts: Int, - retryInterval: Int, + retryInterval: Long, timeout: FiniteDuration): T = { // TODO: Consider removing multiple attempts if (actor == null) { diff --git a/core/src/test/scala/org/apache/spark/deploy/worker/WorkerWatcherSuite.scala b/core/src/test/scala/org/apache/spark/deploy/worker/WorkerWatcherSuite.scala index 5e538d6fab2a..864753b63948 100644 --- a/core/src/test/scala/org/apache/spark/deploy/worker/WorkerWatcherSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/worker/WorkerWatcherSuite.scala @@ -17,32 +17,39 @@ package org.apache.spark.deploy.worker -import akka.actor.{ActorSystem, AddressFromURIString, Props} -import akka.testkit.TestActorRef -import akka.remote.DisassociatedEvent +import akka.actor.AddressFromURIString +import org.apache.spark.SparkConf +import org.apache.spark.SecurityManager +import org.apache.spark.rpc.{RpcAddress, RpcEnv} import org.scalatest.FunSuite class WorkerWatcherSuite extends FunSuite { test("WorkerWatcher shuts down on valid disassociation") { - val actorSystem = ActorSystem("test") - val targetWorkerUrl = "akka://1.2.3.4/user/Worker" + val conf = new SparkConf() + val rpcEnv = RpcEnv.create("test", "localhost", 12345, conf, new SecurityManager(conf)) + val targetWorkerUrl = "akka://test@1.2.3.4:1234/user/Worker" val targetWorkerAddress = AddressFromURIString(targetWorkerUrl) - val actorRef = TestActorRef[WorkerWatcher](Props(classOf[WorkerWatcher], targetWorkerUrl))(actorSystem) - val workerWatcher = actorRef.underlyingActor + val workerWatcher = new WorkerWatcher(rpcEnv, targetWorkerUrl) workerWatcher.setTesting(testing = true) - actorRef.underlyingActor.receive(new DisassociatedEvent(null, targetWorkerAddress, false)) - assert(actorRef.underlyingActor.isShutDown) + rpcEnv.setupEndpoint("worker-watcher", workerWatcher) + workerWatcher.onDisconnected( + RpcAddress(targetWorkerAddress.host.get, targetWorkerAddress.port.get)) + assert(workerWatcher.isShutDown) + rpcEnv.shutdown() } test("WorkerWatcher stays alive on invalid 
disassociation") { - val actorSystem = ActorSystem("test") - val targetWorkerUrl = "akka://1.2.3.4/user/Worker" - val otherAkkaURL = "akka://4.3.2.1/user/OtherActor" + val conf = new SparkConf() + val rpcEnv = RpcEnv.create("test", "localhost", 12345, conf, new SecurityManager(conf)) + val targetWorkerUrl = "akka://test@1.2.3.4:1234/user/Worker" + val otherAkkaURL = "akka://test@4.3.2.1:1234/user/OtherActor" val otherAkkaAddress = AddressFromURIString(otherAkkaURL) - val actorRef = TestActorRef[WorkerWatcher](Props(classOf[WorkerWatcher], targetWorkerUrl))(actorSystem) - val workerWatcher = actorRef.underlyingActor + val workerWatcher = new WorkerWatcher(rpcEnv, targetWorkerUrl) workerWatcher.setTesting(testing = true) - actorRef.underlyingActor.receive(new DisassociatedEvent(null, otherAkkaAddress, false)) - assert(!actorRef.underlyingActor.isShutDown) + rpcEnv.setupEndpoint("worker-watcher", workerWatcher) + workerWatcher.onDisconnected( + RpcAddress(otherAkkaAddress.host.get, otherAkkaAddress.port.get)) + assert(!workerWatcher.isShutDown) + rpcEnv.shutdown() } } diff --git a/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala b/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala new file mode 100644 index 000000000000..c9455f37cbd9 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala @@ -0,0 +1,430 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.rpc + +import java.util.concurrent.{TimeUnit, CountDownLatch, TimeoutException} + +import scala.collection.mutable +import scala.concurrent.duration._ +import scala.language.postfixOps + +import org.scalatest.{BeforeAndAfterAll, FunSuite} +import org.scalatest.concurrent.Eventually._ + +import org.apache.spark.SparkConf + +/** + * Common tests for an RpcEnv implementation. 
+ */ +abstract class RpcEnvSuite extends FunSuite with BeforeAndAfterAll { + + var env: RpcEnv = _ + + override def beforeAll(): Unit = { + val conf = new SparkConf() + env = createRpcEnv(conf, 12345) + } + + override def afterAll(): Unit = { + if(env != null) { + env.shutdown() + } + } + + def createRpcEnv(conf: SparkConf, port: Int): RpcEnv + + test("send a message locally") { + @volatile var message: String = null + val rpcEndpointRef = env.setupEndpoint("send-locally", new RpcEndpoint { + override val rpcEnv = env + + override def receive(sender: RpcEndpointRef) = { + case msg: String => message = msg + } + }) + rpcEndpointRef.send("hello") + eventually(timeout(5 seconds), interval(10 millis)) { + assert("hello" === message) + } + } + + test("send a message remotely") { + @volatile var message: String = null + // Set up a RpcEndpoint using env + env.setupEndpoint("send-remotely", new RpcEndpoint { + override val rpcEnv = env + + override def receive(sender: RpcEndpointRef) = { + case msg: String => message = msg + } + }) + + val anotherEnv = createRpcEnv(new SparkConf(), 13345) + // Use anotherEnv to find out the RpcEndpointRef + val rpcEndpointRef = anotherEnv.setupEndpointRef(env.systemName, env.address, "send-remotely") + try { + rpcEndpointRef.send("hello") + eventually(timeout(5 seconds), interval(10 millis)) { + assert("hello" === message) + } + } finally { + anotherEnv.shutdown() + } + } + + test("send a RpcEndpointRef") { + val endpoint = new RpcEndpoint { + override val rpcEnv = env + + override def receive(sender: RpcEndpointRef) = { + case "Hello" => sender.send(self) + case "Echo" => sender.send("Echo") + } + } + val rpcEndpointRef = env.setupEndpoint("send-ref", endpoint) + + val newRpcEndpointRef = rpcEndpointRef.askWithReply[RpcEndpointRef]("Hello") + val reply = newRpcEndpointRef.askWithReply[String]("Echo") + assert("Echo" === reply) + } + + test("ask a message locally") { + val rpcEndpointRef = env.setupEndpoint("ask-locally", new RpcEndpoint { + override val rpcEnv = env + + override def receive(sender: RpcEndpointRef) = { + case msg: String => { + sender.send(msg) + } + } + }) + val reply = rpcEndpointRef.askWithReply[String]("hello") + assert("hello" === reply) + } + + test("ask a message remotely") { + env.setupEndpoint("ask-remotely", new RpcEndpoint { + override val rpcEnv = env + + override def receive(sender: RpcEndpointRef) = { + case msg: String => { + sender.send(msg) + } + } + }) + + val anotherEnv = createRpcEnv(new SparkConf(), 13345) + // Use anotherEnv to find out the RpcEndpointRef + val rpcEndpointRef = anotherEnv.setupEndpointRef(env.systemName, env.address, "ask-remotely") + try { + val reply = rpcEndpointRef.askWithReply[String]("hello") + assert("hello" === reply) + } finally { + anotherEnv.shutdown() + } + } + + test("ask a message timeout") { + env.setupEndpoint("ask-timeout", new RpcEndpoint { + override val rpcEnv = env + + override def receive(sender: RpcEndpointRef) = { + case msg: String => { + Thread.sleep(100) + sender.send(msg) + } + } + }) + + val conf = new SparkConf() + conf.set("spark.akka.retry.wait", "0") + conf.set("spark.akka.num.retries", "1") + val anotherEnv = createRpcEnv(conf, 13345) + // Use anotherEnv to find out the RpcEndpointRef + val rpcEndpointRef = anotherEnv.setupEndpointRef(env.systemName, env.address, "ask-timeout") + try { + val e = intercept[Exception] { + rpcEndpointRef.askWithReply[String]("hello", 1 millis) + } + assert(e.isInstanceOf[TimeoutException] || e.getCause.isInstanceOf[TimeoutException]) + } 
finally { + anotherEnv.shutdown() + } + } + + test("ping pong") { + val pongRef = env.setupEndpoint("pong", new RpcEndpoint { + override val rpcEnv = env + + override def receive(sender: RpcEndpointRef) = { + case Ping(id) => sender.send(Pong(id)) + } + }) + + val pingRef = env.setupEndpoint("ping", new RpcEndpoint { + override val rpcEnv = env + + var requester: RpcEndpointRef = _ + + override def receive(sender: RpcEndpointRef) = { + case Start => { + requester = sender + pongRef.send(Ping(1)) + } + case p @ Pong(id) => { + if (id < 10) { + sender.send(Ping(id + 1)) + } else { + requester.send(p) + } + } + } + }) + + val reply = pingRef.askWithReply[Pong](Start) + assert(Pong(10) === reply) + } + + test("onStart and onStop") { + val stopLatch = new CountDownLatch(1) + val calledMethods = mutable.ArrayBuffer[String]() + + val endpoint = new RpcEndpoint { + override val rpcEnv = env + + override def onStart(): Unit = { + calledMethods += "start" + } + + override def receive(sender: RpcEndpointRef) = { + case msg: String => + } + + override def onStop(): Unit = { + calledMethods += "stop" + stopLatch.countDown() + } + } + val rpcEndpointRef = env.setupEndpoint("start-stop-test", endpoint) + rpcEndpointRef.send("message") + env.stop(rpcEndpointRef) + stopLatch.await(10, TimeUnit.SECONDS) + assert(List("start", "stop") === calledMethods) + } + + test("onError: error in onStart") { + @volatile var e: Throwable = null + env.setupEndpoint("onError-onStart", new RpcEndpoint { + override val rpcEnv = env + + override def onStart(): Unit = { + throw new RuntimeException("Oops!") + } + + override def receive(sender: RpcEndpointRef) = { + case m => + } + + override def onError(cause: Throwable): Unit = { + e = cause + } + }) + + eventually(timeout(5 seconds), interval(10 millis)) { + assert(e.getMessage === "Oops!") + } + } + + test("onError: error in onStop") { + @volatile var e: Throwable = null + val endpointRef = env.setupEndpoint("onError-onStop", new RpcEndpoint { + override val rpcEnv = env + + override def receive(sender: RpcEndpointRef) = { + case m => + } + + override def onError(cause: Throwable): Unit = { + e = cause + } + + override def onStop(): Unit = { + throw new RuntimeException("Oops!") + } + }) + + env.stop(endpointRef) + + eventually(timeout(5 seconds), interval(10 millis)) { + assert(e.getMessage === "Oops!") + } + } + + test("onError: error in receive") { + @volatile var e: Throwable = null + val endpointRef = env.setupEndpoint("onError-receive", new RpcEndpoint { + override val rpcEnv = env + + override def receive(sender: RpcEndpointRef) = { + case m => throw new RuntimeException("Oops!") + } + + override def onError(cause: Throwable): Unit = { + e = cause + } + }) + + endpointRef.send("Foo") + + eventually(timeout(5 seconds), interval(10 millis)) { + assert(e.getMessage === "Oops!") + } + } + + test("self: call in onStart") { + @volatile var callSelfSuccessfully = false + + env.setupEndpoint("self-onStart", new RpcEndpoint { + override val rpcEnv = env + + override def onStart(): Unit = { + self + callSelfSuccessfully = true + } + + override def receive(sender: RpcEndpointRef) = { + case m => + } + }) + + eventually(timeout(5 seconds), interval(10 millis)) { + // Calling `self` in `onStart` is fine + assert(callSelfSuccessfully === true) + } + } + + test("self: call in receive") { + @volatile var callSelfSuccessfully = false + + val endpointRef = env.setupEndpoint("self-receive", new RpcEndpoint { + override val rpcEnv = env + + override def receive(sender: RpcEndpointRef) = 
{ + case m => { + self + callSelfSuccessfully = true + } + } + }) + + endpointRef.send("Foo") + + eventually(timeout(5 seconds), interval(10 millis)) { + // Calling `self` in `receive` is fine + assert(callSelfSuccessfully === true) + } + } + + test("self: call in onStop") { + @volatile var e: Throwable = null + + val endpointRef = env.setupEndpoint("self-onStop", new RpcEndpoint { + override val rpcEnv = env + + override def receive(sender: RpcEndpointRef) = { + case m => + } + + override def onStop(): Unit = { + self + } + + override def onError(cause: Throwable): Unit = { + e = cause + } + }) + + env.stop(endpointRef) + + eventually(timeout(5 seconds), interval(10 millis)) { + // Calling `self` in `onStop` is invalid + assert(e != null) + assert(e.getMessage.contains("Cannot find RpcEndpointRef")) + } + } + + test("call receive in sequence") { + // If a RpcEnv implementation breaks the `receive` contract, hope this test can expose it + for(i <- 0 until 100) { + @volatile var result = 0 + val endpointRef = env.setupEndpoint(s"receive-in-sequence-$i", new RpcEndpoint { + override val rpcEnv = env + + override def receive(sender: RpcEndpointRef) = { + case m => result += 1 + } + + }) + + (0 until 10) foreach { _ => + new Thread { + override def run() { + (0 until 100) foreach { _ => + endpointRef.send("Hello") + } + } + }.start() + } + + eventually(timeout(5 seconds), interval(5 millis)) { + assert(result == 1000) + } + + env.stop(endpointRef) + } + } + + test("stop(RpcEndpointRef) reentrant") { + @volatile var onStopCount = 0 + val endpointRef = env.setupEndpoint("stop-reentrant", new RpcEndpoint { + override val rpcEnv = env + + override def receive(sender: RpcEndpointRef) = { + case m => + } + + override def onStop(): Unit = { + onStopCount += 1 + } + }) + + env.stop(endpointRef) + env.stop(endpointRef) + + eventually(timeout(5 seconds), interval(5 millis)) { + // Calling stop twice should only trigger onStop once. + assert(onStopCount == 1) + } + } +} + +case object Start + +case class Ping(id: Int) + +case class Pong(id: Int) diff --git a/core/src/test/scala/org/apache/spark/rpc/akka/AkkaRpcEnvSuite.scala b/core/src/test/scala/org/apache/spark/rpc/akka/AkkaRpcEnvSuite.scala new file mode 100644 index 000000000000..a4c49a237a9f --- /dev/null +++ b/core/src/test/scala/org/apache/spark/rpc/akka/AkkaRpcEnvSuite.scala @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.rpc.akka + +import org.apache.spark.rpc._ +import org.apache.spark.{SecurityManager, SparkConf} + +class AkkaRpcEnvSuite extends RpcEnvSuite { + + override def createRpcEnv(conf: SparkConf, port: Int): RpcEnv = { + AkkaRpcEnv(RpcEnvConfig(conf, s"test-$port", "localhost", port, new SecurityManager(conf))) + } + + test("setupEndpointRef: systemName, address, endpointName") { + val ref = env.setupEndpoint("test_endpoint", new RpcEndpoint { + override val rpcEnv = env + + override def receive(sender: RpcEndpointRef) = { + case _ => + } + }) + val conf = new SparkConf() + val newRpcEnv = + AkkaRpcEnv(RpcEnvConfig(conf, "test", "localhost", 12346, new SecurityManager(conf))) + try { + val newRef = newRpcEnv.setupEndpointRef(s"test-${env.address.port}", ref.address, "test_endpoint") + assert(s"akka.tcp://test-${env.address.port}@localhost:12345/user/test_endpoint" === + newRef.asInstanceOf[AkkaRpcEndpointRef].actorRef.path.toString) + } finally { + newRpcEnv.shutdown() + } + } + +} From 155b98726c85bc2eb3b5d664c44729f5a6db8303 Mon Sep 17 00:00:00 2001 From: zsxwing Date: Thu, 26 Feb 2015 11:53:59 +0800 Subject: [PATCH 02/31] Change newURI to uriOf and add some comments --- .../org/apache/spark/deploy/worker/WorkerWatcher.scala | 4 ++++ .../main/scala/org/apache/spark/rpc/ActionScheduler.scala | 5 ++++- core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala | 2 +- .../main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala | 8 ++++++-- 4 files changed, 15 insertions(+), 4 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/WorkerWatcher.scala b/core/src/main/scala/org/apache/spark/deploy/worker/WorkerWatcher.scala index ce8c095dfcbe..86f87b0b6015 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/WorkerWatcher.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/WorkerWatcher.scala @@ -37,6 +37,10 @@ private[spark] class WorkerWatcher(override val rpcEnv: RpcEnv, workerUrl: Strin } // Used to avoid shutting down JVM during tests + // In the normal case, exitNonZero will call `System.exit(-1)` to shutdown the JVM. In the unit + // test, the user should call `setTesting(true)` so that `exitNonZero` will set `isShutDown` to + // true rather than calling `System.exit`. The user can check `isShutDown` to know if + // `exitNonZero` is called. private[deploy] var isShutDown = false private[deploy] def setTesting(testing: Boolean) = isTesting = testing private var isTesting = false diff --git a/core/src/main/scala/org/apache/spark/rpc/ActionScheduler.scala b/core/src/main/scala/org/apache/spark/rpc/ActionScheduler.scala index f16af0634abe..8997848155a1 100644 --- a/core/src/main/scala/org/apache/spark/rpc/ActionScheduler.scala +++ b/core/src/main/scala/org/apache/spark/rpc/ActionScheduler.scala @@ -66,7 +66,10 @@ private[spark] trait ActionScheduler { } private[spark] trait Cancellable { - // Should be reentrant + /** + * Cancel the corresponding work. The caller may call this method multiple times and call it in + * any thread. 
+ */ def cancel(): Unit } diff --git a/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala index 5f05b9d25846..0083fb1e74c7 100644 --- a/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala @@ -84,7 +84,7 @@ private[spark] trait RpcEnv { /** * Create a URI used to create a [[RpcEndpointRef]] */ - def newURI(systemName: String, address: RpcAddress, endpointName: String): String + def uriOf(systemName: String, address: RpcAddress, endpointName: String): String } private[spark] case class RpcEnvConfig( diff --git a/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala index ba4c0ef4cbd6..19e0e3018fb0 100644 --- a/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala @@ -172,6 +172,10 @@ private[spark] class AkkaRpcEnv private ( } } + /** + * Run `action` safely to avoid to crash the thread. If any non-fatal exception happens, it will + * call `endpoint.onError`. If `endpoint.onError` throws any non-fatal exception, just log it. + */ private def safelyCall(endpoint: RpcEndpoint)(action: => Unit): Unit = { try { action @@ -204,10 +208,10 @@ private[spark] class AkkaRpcEnv private ( override def setupEndpointRef( systemName: String, address: RpcAddress, endpointName: String): RpcEndpointRef = { - setupEndpointRefByUrl(newURI(systemName, address, endpointName)) + setupEndpointRefByUrl(uriOf(systemName, address, endpointName)) } - override def newURI(systemName: String, address: RpcAddress, endpointName: String): String = { + override def uriOf(systemName: String, address: RpcAddress, endpointName: String): String = { AkkaUtils.address( AkkaUtils.protocol(actorSystem), systemName, address.host, address.port, endpointName) } From 2a579f498b0b2aaeb3b995d2fe92d48787a876bd Mon Sep 17 00:00:00 2001 From: zsxwing Date: Thu, 26 Feb 2015 20:01:19 +0800 Subject: [PATCH 03/31] Remove RpcEnv.systemName --- .../main/scala/org/apache/spark/rpc/RpcEnv.scala | 2 -- .../org/apache/spark/rpc/akka/AkkaRpcEnv.scala | 2 -- .../scala/org/apache/spark/rpc/RpcEnvSuite.scala | 16 ++++++++-------- .../apache/spark/rpc/akka/AkkaRpcEnvSuite.scala | 8 ++++---- 4 files changed, 12 insertions(+), 16 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala index 0083fb1e74c7..8f1954ed7c0e 100644 --- a/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala @@ -35,8 +35,6 @@ private[spark] trait RpcEnv { def scheduler: ActionScheduler - def systemName: String - /** * Return the address that [[RpcEnv]] is listening to. 
*/ diff --git a/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala index 19e0e3018fb0..d78d03398c8e 100644 --- a/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala @@ -55,8 +55,6 @@ private[spark] class AkkaRpcEnv private ( RpcAddress(address.host.getOrElse("localhost"), address.port.getOrElse(boundPort)) } - override val systemName: String = actorSystem.name - override val address: RpcAddress = defaultAddress /** diff --git a/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala b/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala index c9455f37cbd9..52cb65d5e86a 100644 --- a/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala +++ b/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala @@ -37,7 +37,7 @@ abstract class RpcEnvSuite extends FunSuite with BeforeAndAfterAll { override def beforeAll(): Unit = { val conf = new SparkConf() - env = createRpcEnv(conf, 12345) + env = createRpcEnv(conf, "local", 12345) } override def afterAll(): Unit = { @@ -46,7 +46,7 @@ abstract class RpcEnvSuite extends FunSuite with BeforeAndAfterAll { } } - def createRpcEnv(conf: SparkConf, port: Int): RpcEnv + def createRpcEnv(conf: SparkConf, name: String, port: Int): RpcEnv test("send a message locally") { @volatile var message: String = null @@ -74,9 +74,9 @@ abstract class RpcEnvSuite extends FunSuite with BeforeAndAfterAll { } }) - val anotherEnv = createRpcEnv(new SparkConf(), 13345) + val anotherEnv = createRpcEnv(new SparkConf(), "remote" ,13345) // Use anotherEnv to find out the RpcEndpointRef - val rpcEndpointRef = anotherEnv.setupEndpointRef(env.systemName, env.address, "send-remotely") + val rpcEndpointRef = anotherEnv.setupEndpointRef("local", env.address, "send-remotely") try { rpcEndpointRef.send("hello") eventually(timeout(5 seconds), interval(10 millis)) { @@ -128,9 +128,9 @@ abstract class RpcEnvSuite extends FunSuite with BeforeAndAfterAll { } }) - val anotherEnv = createRpcEnv(new SparkConf(), 13345) + val anotherEnv = createRpcEnv(new SparkConf(), "remote", 13345) // Use anotherEnv to find out the RpcEndpointRef - val rpcEndpointRef = anotherEnv.setupEndpointRef(env.systemName, env.address, "ask-remotely") + val rpcEndpointRef = anotherEnv.setupEndpointRef("local", env.address, "ask-remotely") try { val reply = rpcEndpointRef.askWithReply[String]("hello") assert("hello" === reply) @@ -154,9 +154,9 @@ abstract class RpcEnvSuite extends FunSuite with BeforeAndAfterAll { val conf = new SparkConf() conf.set("spark.akka.retry.wait", "0") conf.set("spark.akka.num.retries", "1") - val anotherEnv = createRpcEnv(conf, 13345) + val anotherEnv = createRpcEnv(conf, "remote", 13345) // Use anotherEnv to find out the RpcEndpointRef - val rpcEndpointRef = anotherEnv.setupEndpointRef(env.systemName, env.address, "ask-timeout") + val rpcEndpointRef = anotherEnv.setupEndpointRef("local", env.address, "ask-timeout") try { val e = intercept[Exception] { rpcEndpointRef.askWithReply[String]("hello", 1 millis) diff --git a/core/src/test/scala/org/apache/spark/rpc/akka/AkkaRpcEnvSuite.scala b/core/src/test/scala/org/apache/spark/rpc/akka/AkkaRpcEnvSuite.scala index a4c49a237a9f..c2989f6b2ca3 100644 --- a/core/src/test/scala/org/apache/spark/rpc/akka/AkkaRpcEnvSuite.scala +++ b/core/src/test/scala/org/apache/spark/rpc/akka/AkkaRpcEnvSuite.scala @@ -22,8 +22,8 @@ import org.apache.spark.{SecurityManager, SparkConf} class AkkaRpcEnvSuite 
extends RpcEnvSuite { - override def createRpcEnv(conf: SparkConf, port: Int): RpcEnv = { - AkkaRpcEnv(RpcEnvConfig(conf, s"test-$port", "localhost", port, new SecurityManager(conf))) + override def createRpcEnv(conf: SparkConf, name: String, port: Int): RpcEnv = { + AkkaRpcEnv(RpcEnvConfig(conf, name, "localhost", port, new SecurityManager(conf))) } test("setupEndpointRef: systemName, address, endpointName") { @@ -38,8 +38,8 @@ class AkkaRpcEnvSuite extends RpcEnvSuite { val newRpcEnv = AkkaRpcEnv(RpcEnvConfig(conf, "test", "localhost", 12346, new SecurityManager(conf))) try { - val newRef = newRpcEnv.setupEndpointRef(s"test-${env.address.port}", ref.address, "test_endpoint") - assert(s"akka.tcp://test-${env.address.port}@localhost:12345/user/test_endpoint" === + val newRef = newRpcEnv.setupEndpointRef("local", ref.address, "test_endpoint") + assert(s"akka.tcp://local@localhost:12345/user/test_endpoint" === newRef.asInstanceOf[AkkaRpcEndpointRef].actorRef.path.toString) } finally { newRpcEnv.shutdown() From 04a106ed368aa3cd0eedefa5e04188ef47d0340c Mon Sep 17 00:00:00 2001 From: zsxwing Date: Thu, 26 Feb 2015 20:07:52 +0800 Subject: [PATCH 04/31] Remove NopCancellable and add a const NOP in object SettableCancellable --- .../scala/org/apache/spark/rpc/ActionScheduler.scala | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/rpc/ActionScheduler.scala b/core/src/main/scala/org/apache/spark/rpc/ActionScheduler.scala index 8997848155a1..c5755e922ce8 100644 --- a/core/src/main/scala/org/apache/spark/rpc/ActionScheduler.scala +++ b/core/src/main/scala/org/apache/spark/rpc/ActionScheduler.scala @@ -73,13 +73,9 @@ private[spark] trait Cancellable { def cancel(): Unit } -private[rpc] object NopCancellable extends Cancellable { - override def cancel(): Unit = {} -} - private[rpc] class SettableCancellable extends Cancellable { - @volatile private var underlying: Cancellable = NopCancellable + @volatile private var underlying: Cancellable = SettableCancellable.NOP @volatile private var isCancelled = false @@ -97,6 +93,12 @@ private[rpc] class SettableCancellable extends Cancellable { } } +private[rpc] object SettableCancellable { + val NOP = new Cancellable { + override def cancel(): Unit = {} + } +} + private[spark] class ActionSchedulerImpl(conf: SparkConf) extends ActionScheduler with Logging { val maxIOThreadNumber = conf.getInt("spark.rpc.io.maxThreads", 1000) From 7b9e0c9ca34c04ec023538a79a9e2ee182e3afa8 Mon Sep 17 00:00:00 2001 From: zsxwing Date: Fri, 27 Feb 2015 20:07:14 +0800 Subject: [PATCH 05/31] Fix the indentation --- core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala index 8f1954ed7c0e..68405a927870 100644 --- a/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala @@ -59,7 +59,7 @@ private[spark] trait RpcEnv { * Retrieve the [[RpcEndpointRef]] represented by `systemName`, `address` and `endpointName` */ def setupEndpointRef( - systemName: String, address: RpcAddress, endpointName: String): RpcEndpointRef + systemName: String, address: RpcAddress, endpointName: String): RpcEndpointRef /** * Stop [[RpcEndpoint]] specified by `endpoint`. 
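As a quick illustration of how the renamed `uriOf` and the name-based `createRpcEnv`/`create` signatures fit together, here is a minimal sketch; the system name, host, port and endpoint name below are placeholders rather than values taken from these patches:

import org.apache.spark.{SecurityManager, SparkConf}
import org.apache.spark.rpc.{RpcAddress, RpcEndpointRef, RpcEnv}

object UriOfSketch {
  def lookupRemoteEndpoint(): RpcEndpointRef = {
    val conf = new SparkConf()
    val rpcEnv = RpcEnv.create("sketch", "localhost", 0, conf, new SecurityManager(conf))
    // uriOf builds an implementation-specific URI (the Akka implementation renders
    // akka.tcp://<systemName>@<host>:<port>/user/<endpointName>), so callers never
    // format such strings by hand.
    val remoteAddress = RpcAddress("remote-host", 7077)
    val uri = rpcEnv.uriOf("remoteSystem", remoteAddress, "remote_endpoint")
    println(s"would connect to $uri")
    // setupEndpointRef builds the same URI internally and returns a reference to it.
    rpcEnv.setupEndpointRef("remoteSystem", remoteAddress, "remote_endpoint")
  }
}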
From fe7d1fffcca4f497d9ed49ae7469ff0cc30836c3 Mon Sep 17 00:00:00 2001 From: zsxwing Date: Wed, 4 Mar 2015 17:13:17 +0800 Subject: [PATCH 06/31] Add explicit reply in rpc --- .../spark/deploy/worker/WorkerWatcher.scala | 43 ++-- .../scala/org/apache/spark/rpc/RpcEnv.scala | 201 ++++++++++-------- .../apache/spark/rpc/akka/AkkaRpcEnv.scala | 170 +++++++++++---- .../scheduler/OutputCommitCoordinator.scala | 8 +- .../deploy/worker/WorkerWatcherSuite.scala | 10 +- .../org/apache/spark/rpc/RpcEnvSuite.scala | 176 ++++++++++++--- .../spark/rpc/akka/AkkaRpcEnvSuite.scala | 2 +- 7 files changed, 419 insertions(+), 191 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/WorkerWatcher.scala b/core/src/main/scala/org/apache/spark/deploy/worker/WorkerWatcher.scala index 86f87b0b6015..2808a953641f 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/WorkerWatcher.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/WorkerWatcher.scala @@ -26,7 +26,7 @@ import org.apache.spark.rpc._ * Provides fate sharing between a worker and its associated child processes. */ private[spark] class WorkerWatcher(override val rpcEnv: RpcEnv, workerUrl: String) - extends NetworkRpcEndpoint with Logging { + extends RpcEndpoint with Logging { override def onStart() { logInfo(s"Connecting to worker $workerUrl") @@ -53,30 +53,25 @@ private[spark] class WorkerWatcher(override val rpcEnv: RpcEnv, workerUrl: Strin def exitNonZero() = if (isTesting) isShutDown = true else System.exit(-1) - override def receive(sender: RpcEndpointRef) = { + override def receive = { + case AssociatedEvent(remoteAddress) => + if (isWorker(remoteAddress)) { + logInfo(s"Successfully connected to $workerUrl") + } + case DisassociatedEvent(remoteAddress) => + if (isWorker(remoteAddress)) { + // This log message will never be seen + logError(s"Lost connection to worker actor $workerUrl. Exiting.") + exitNonZero() + } + case NetworkErrorEvent(remoteAddress, cause) => + if (isWorker(remoteAddress)) { + // These logs may not be seen if the worker (and associated pipe) has died + logError(s"Could not initialize connection to worker $workerUrl. Exiting.") + logError(s"Error was: $cause") + exitNonZero() + } case e => logWarning(s"Received unexpected actor system event: $e") } - override def onConnected(remoteAddress: RpcAddress): Unit = { - if (isWorker(remoteAddress)) { - logInfo(s"Successfully connected to $workerUrl") - } - } - - override def onDisconnected(remoteAddress: RpcAddress): Unit = { - if (isWorker(remoteAddress)) { - // This log message will never be seen - logError(s"Lost connection to worker actor $workerUrl. Exiting.") - exitNonZero() - } - } - - override def onNetworkError(cause: Throwable, remoteAddress: RpcAddress): Unit = { - if (isWorker(remoteAddress)) { - // These logs may not be seen if the worker (and associated pipe) has died - logError(s"Could not initialize connection to worker $workerUrl. 
Exiting.") - logError(s"Error was: $cause") - exitNonZero() - } - } } diff --git a/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala index 68405a927870..e9569831a038 100644 --- a/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala @@ -23,7 +23,7 @@ import scala.concurrent.Future import scala.concurrent.duration.FiniteDuration import scala.reflect.ClassTag -import org.apache.spark.{SecurityManager, SparkConf} +import org.apache.spark.{SparkException, SecurityManager, SparkConf} import org.apache.spark.util.Utils /** @@ -31,8 +31,15 @@ import org.apache.spark.util.Utils */ private[spark] trait RpcEnv { + /** + * Return RpcEndpointRef of the registered [[RpcEndpoint]]. Will be used to implement + * [[RpcEndpoint.self]]. + */ private[rpc] def endpointRef(endpoint: RpcEndpoint): RpcEndpointRef + /** + * Return an ActionScheduler for the caller to run long-time actions out of the current thread. + */ def scheduler: ActionScheduler /** @@ -41,10 +48,17 @@ private[spark] trait RpcEnv { def address: RpcAddress /** - * Register a [[RpcEndpoint]] with a name and return its [[RpcEndpointRef]]. + * Register a [[RpcEndpoint]] with a name and return its [[RpcEndpointRef]]. [[RpcEnv]] does not + * guarantee thread-safety. */ def setupEndpoint(name: String, endpoint: RpcEndpoint): RpcEndpointRef + /** + * Register a [[RpcEndpoint]] with a name and return its [[RpcEndpointRef]]. [[RpcEnv]] should + * make sure thread-safely sending messages to [[RpcEndpoint]]. + */ + def setupThreadSafeEndpoint(name: String, endpoint: RpcEndpoint): RpcEndpointRef + /** * Retrieve a [[RpcEndpointRef]] which is located in the driver via its name. */ @@ -80,7 +94,8 @@ private[spark] trait RpcEnv { def awaitTermination(): Unit /** - * Create a URI used to create a [[RpcEndpointRef]] + * Create a URI used to create a [[RpcEndpointRef]]. Use this one to create the URI instead of + * creating it manually because different [[RpcEnv]] may have different formats. */ def uriOf(systemName: String, address: RpcAddress, endpointName: String): String } @@ -117,11 +132,11 @@ private[spark] object RpcEnv { } def create( - name: String, - host: String, - port: Int, - conf: SparkConf, - securityManager: SecurityManager): RpcEnv = { + name: String, + host: String, + port: Int, + conf: SparkConf, + securityManager: SecurityManager): RpcEnv = { // Using Reflection to create the RpcEnv to avoid to depend on Akka directly val config = RpcEnvConfig(conf, name, host, port, securityManager) val companion = getRpcEnvCompanion(conf) @@ -141,8 +156,11 @@ private[spark] object RpcEnv { * * constructor onStart receive* onStop * - * If any error is thrown from one of RpcEndpoint methods except `onError`, [[RpcEndpoint.onError]] - * will be invoked with the cause. If onError throws an error, [[RpcEnv]] will ignore it. + * Note: `receive` can be called concurrently. If you want `receive` is thread-safe, please use + * [[RpcEnv.setupThreadSafeEndpoint]] + * + * If any error is thrown from one of [[RpcEndpoint]] methods except `onError`, `onError` will be + * invoked with the cause. If `onError` throws an error, [[RpcEnv]] will ignore it. */ private[spark] trait RpcEndpoint { @@ -152,25 +170,31 @@ private[spark] trait RpcEndpoint { val rpcEnv: RpcEnv /** - * Provide the implicit sender. `self` will become valid when `onStart` is called. + * The [[RpcEndpointRef]] of this [[RpcEndpoint]]. 
`self` will become valid when `onStart` is + * called. * * Note: Because before `onStart`, [[RpcEndpoint]] has not yet been registered and there is not - * valid [[RpcEndpointRef]] for it. So don't call `self` before `onStart` is called. In the other - * words, don't call [[RpcEndpointRef.send]] in the constructor of [[RpcEndpoint]]. + * valid [[RpcEndpointRef]] for it. So don't call `self` before `onStart` is called. */ - implicit final def self: RpcEndpointRef = { + final def self: RpcEndpointRef = { require(rpcEnv != null, "rpcEnv has not been initialized") rpcEnv.endpointRef(this) } /** - * Same assumption like Actor: messages sent to a RpcEndpoint will be delivered in sequence, and - * messages from the same RpcEndpoint will be delivered in order. - * - * @param sender - * @return + * Process messages from [[RpcEndpointRef.send]] or [[RpcResponse.reply)]] */ - def receive(sender: RpcEndpointRef): PartialFunction[Any, Unit] + def receive: PartialFunction[Any, Unit] = { + case _ => + // network events will be passed here by default, so do nothing by default to avoid noise. + } + + /** + * Process messages from [[RpcEndpointRef.sendWithReply]] or [[RpcResponse.replyWithSender)]] + */ + def receiveAndReply(response: RpcResponse): PartialFunction[Any, Unit] = { + case _ => response.fail(new SparkException(self + " won't reply anything")) + } /** * Call onError when any exception is thrown during handling messages. @@ -207,49 +231,6 @@ private[spark] trait RpcEndpoint { } } -/** - * A RpcEndoint interested in network events. - * - * [[NetworkRpcEndpoint]] will be guaranteed that `onStart`, `receive` , `onConnected`, - * `onDisconnected`, `onNetworkError` and `onStop` will be called in sequence. - * - * The lift-cycle will be: - * - * constructor onStart (receive|onConnected|onDisconnected|onNetworkError)* onStop - * - * If any error is thrown from `onConnected`, `onDisconnected` or `onNetworkError`, - * [[RpcEndpoint.onError)]] will be invoked with the cause. If onError throws an error, - * [[RpcEnv]] will ignore it. - */ -private[spark] trait NetworkRpcEndpoint extends RpcEndpoint { - - /** - * Invoked when `remoteAddress` is connected to the current node. - */ - def onConnected(remoteAddress: RpcAddress): Unit = { - // By default, do nothing. - } - - /** - * Invoked when `remoteAddress` is lost. - */ - def onDisconnected(remoteAddress: RpcAddress): Unit = { - // By default, do nothing. - } - - /** - * Invoked when some network error happens in the connection between the current node and - * `remoteAddress`. - */ - def onNetworkError(cause: Throwable, remoteAddress: RpcAddress): Unit = { - // By default, do nothing. - } -} - -private[spark] object RpcEndpoint { - final val noSender: RpcEndpointRef = null -} - /** * A reference for a remote [[RpcEndpoint]]. [[RpcEndpointRef]] is thread-safe. */ @@ -262,18 +243,6 @@ private[spark] trait RpcEndpointRef { def name: String - /** - * Send a message to the corresponding [[RpcEndpoint]] and return a `Future` to receive the reply - * within a default timeout. - */ - def ask[T: ClassTag](message: Any): Future[T] - - /** - * Send a message to the corresponding [[RpcEndpoint]] and return a `Future` to receive the reply - * within the specified timeout. - */ - def ask[T: ClassTag](message: Any, timeout: FiniteDuration): Future[T] - /** * Send a message to the corresponding [[RpcEndpoint]] and get its result within a default * timeout, or throw a SparkException if this fails even after the default number of retries. 
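// A minimal endpoint written against the new receive/receiveAndReply split, as a
// sketch only (the class name and message handling below are placeholders, not
// part of this patch):
import org.apache.spark.SparkException
import org.apache.spark.rpc.{RpcEndpoint, RpcEnv, RpcResponse}

class EchoEndpoint(override val rpcEnv: RpcEnv) extends RpcEndpoint {

  // One-way messages from RpcEndpointRef.send land here; nothing is sent back.
  override def receive: PartialFunction[Any, Unit] = {
    case msg: String => println(s"one-way message: $msg")
  }

  // Messages from RpcEndpointRef.sendWithReply land here and are answered (or
  // failed) through the RpcResponse callback.
  override def receiveAndReply(response: RpcResponse): PartialFunction[Any, Unit] = {
    case msg: String => response.reply(msg)
    case other => response.fail(new SparkException(s"Unexpected message: $other"))
  }
}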
@@ -288,8 +257,9 @@ private[spark] trait RpcEndpointRef { def askWithReply[T: ClassTag](message: Any): T /** - * Send a message to the corresponding [[RpcEndpoint]] and get its result within a specified - * timeout, throw a SparkException if this fails even after the specified number of retries. + * Send a message to the corresponding [[RpcEndpoint.receive]] and get its result within a + * specified timeout, throw a SparkException if this fails even after the specified number of + * retries. * * Note: this is a blocking action which may cost a lot of time, so don't call it in an message * loop of [[RpcEndpoint]]. @@ -303,14 +273,29 @@ private[spark] trait RpcEndpointRef { /** * Sends a one-way asynchronous message. Fire-and-forget semantics. + */ + def send(message: Any): Unit + + /** + * Send a message to the corresponding [[RpcEndpoint.receiveAndReply]] asynchronously. + * Fire-and-forget semantics. * - * If invoked from within an [[RpcEndpoint]] then `self` is implicitly passed on as the implicit - * 'sender' argument. If not then no sender is available. - * - * This `sender` reference is then available in the receiving [[RpcEndpoint]] as the `sender` - * parameter of [[RpcEndpoint.receive]] + * The receiver will reply to sender's [[RpcEndpoint.receive]] or [[RpcEndpoint.receiveAndReply]] + * depending on which one of [[RpcResponse.reply]]s is called. */ - def send(message: Any)(implicit sender: RpcEndpointRef = RpcEndpoint.noSender): Unit + def sendWithReply(message: Any, sender: RpcEndpointRef): Unit + + /** + * Send a message to the corresponding [[RpcEndpoint.receiveAndReply)]] and return a `Future` to + * receive the reply within a default timeout. + */ + def sendWithReply[T: ClassTag](message: Any): Future[T] + + /** + * Send a message to the corresponding [[RpcEndpoint.receiveAndReply)]] and return a `Future` to + * receive the reply within the specified timeout. + */ + def sendWithReply[T: ClassTag](message: Any, timeout: FiniteDuration): Future[T] def toURI: URI } @@ -341,3 +326,53 @@ private[spark] object RpcAddress { RpcAddress(host, port) } } + +/** + * Indicate that a new connection is established. + * + * @param address the remote address of the connection + */ +private[spark] case class AssociatedEvent(address: RpcAddress) + +/** + * Indicate a disconnection from a remote address. + * + * @param address the remote address of the connection + */ +private[spark] case class DisassociatedEvent(address: RpcAddress) + +/** + * Indicate a network error. + * @param address the remote address of the connection which this error happens on. + * @param cause the cause of the network error. + */ +private[spark] case class NetworkErrorEvent(address: RpcAddress, cause: Throwable) + +/** + * A callback that [[RpcEndpoint]] can use it to send back a message or failure. + */ +private[spark] trait RpcResponse { + + /** + * Reply a message to the sender. If the sender is [[RpcEndpoint]], its [[RpcEndpoint.receive]] + * will be called. + */ + def reply(response: Any): Unit + + /** + * Reply a message to the corresponding [[RpcEndpoint.receiveAndReply]]. If you use this one to + * reply, it means you expect the target [[RpcEndpoint]] should reply you something. + * + * TODO better method name? + * + * @param response the response message + * @param sender who replies this message. The target [[RpcEndpoint]] will use `sender` to send + * back something. + */ + def replyWithSender(response: Any, sender: RpcEndpointRef): Unit + + /** + * Report a failure to the sender. 
+ */ + def fail(e: Throwable): Unit +} diff --git a/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala index d78d03398c8e..2ab065ed8a85 100644 --- a/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala @@ -27,13 +27,14 @@ import scala.language.postfixOps import scala.reflect.ClassTag import scala.util.control.NonFatal -import _root_.akka.actor.{ActorSystem, ExtendedActorSystem, Actor, ActorRef, Props, Address} +import akka.actor.{ActorSystem, ExtendedActorSystem, Actor, ActorRef, Props, Address} import akka.pattern.{ask => akkaAsk} -import akka.remote._ - -import org.apache.spark.{Logging, SparkConf} +import akka.remote.{AssociatedEvent => AkkaAssociatedEvent} +import akka.remote.{DisassociatedEvent => AkkaDisassociatedEvent} +import akka.remote.{AssociationErrorEvent, RemotingLifecycleEvent} +import org.apache.spark.{SparkException, Logging, SparkConf} import org.apache.spark.rpc._ -import org.apache.spark.util.{SparkUncaughtExceptionHandler, ActorLogReceive, AkkaUtils} +import org.apache.spark.util.{ActorLogReceive, AkkaUtils} /** * A RpcEnv implementation based on Akka. @@ -92,6 +93,10 @@ private[spark] class AkkaRpcEnv private ( } override def setupEndpoint(name: String, endpoint: RpcEndpoint): RpcEndpointRef = { + setupThreadSafeEndpoint(name, endpoint) + } + + override def setupThreadSafeEndpoint(name: String, endpoint: RpcEndpoint): RpcEndpointRef = { val latch = new CountDownLatch(1) try { @volatile var endpointRef: AkkaRpcEndpointRef = null @@ -102,57 +107,84 @@ private[spark] class AkkaRpcEnv private ( require(endpointRef != null) registerEndpoint(endpoint, endpointRef) - var isNetworkRpcEndpoint = false - override def preStart(): Unit = { - if (endpoint.isInstanceOf[NetworkRpcEndpoint]) { - isNetworkRpcEndpoint = true - // Listen for remote client network events only when it's `NetworkRpcEndpoint` - context.system.eventStream.subscribe(self, classOf[RemotingLifecycleEvent]) - } + // Listen for remote client network events + context.system.eventStream.subscribe(self, classOf[RemotingLifecycleEvent]) safelyCall(endpoint) { endpoint.onStart() } } - override def receiveWithLogging: Receive = if (isNetworkRpcEndpoint) { - case AssociatedEvent(_, remoteAddress, _) => + override def receiveWithLogging: Receive = { + case AkkaAssociatedEvent(_, remoteAddress, _) => safelyCall(endpoint) { - endpoint.asInstanceOf[NetworkRpcEndpoint]. - onConnected(akkaAddressToRpcAddress(remoteAddress)) + val event = AssociatedEvent(akkaAddressToRpcAddress(remoteAddress)) + val pf = endpoint.receive + if (pf.isDefinedAt(event)) { + pf.apply(event) + } } - case DisassociatedEvent(_, remoteAddress, _) => + case AkkaDisassociatedEvent(_, remoteAddress, _) => safelyCall(endpoint) { - endpoint.asInstanceOf[NetworkRpcEndpoint]. - onDisconnected(akkaAddressToRpcAddress(remoteAddress)) + val event = DisassociatedEvent(akkaAddressToRpcAddress(remoteAddress)) + val pf = endpoint.receive + if (pf.isDefinedAt(event)) { + pf.apply(event) + } } case AssociationErrorEvent(cause, localAddress, remoteAddress, inbound, _) => safelyCall(endpoint) { - endpoint.asInstanceOf[NetworkRpcEndpoint]. - onNetworkError(cause, akkaAddressToRpcAddress(remoteAddress)) + val event = NetworkErrorEvent(akkaAddressToRpcAddress(remoteAddress), cause) + val pf = endpoint.receive + if (pf.isDefinedAt(event)) { + pf.apply(event) + } } case e: RemotingLifecycleEvent => // TODO ignore? 
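// A caller-side sketch of the two delivery modes dispatched above (placeholder
// message values, not part of this patch): send is wrapped as
// AkkaMessage(msg, reply = false) and goes to `receive`, while sendWithReply and
// askWithReply are wrapped with reply = true, go to `receiveAndReply`, and are
// completed by the reply or AkkaFailure sent back.
import scala.concurrent.Await
import scala.concurrent.duration._
import org.apache.spark.rpc.RpcEndpointRef

object ReplySketch {
  def demo(ref: RpcEndpointRef): Unit = {
    // Fire-and-forget; no reply is expected.
    ref.send("ping")

    // Asynchronous ask: the Future is completed by reply(...) or fail(...).
    val future = ref.sendWithReply[String]("ping")
    val answer = Await.result(future, 5.seconds)

    // Blocking ask with the built-in retry loop.
    val answer2 = ref.askWithReply[String]("ping", 5.seconds)
    println(s"$answer / $answer2")
  }
}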
- case message: Any => - logDebug("Received RPC message: " + message) - safelyCall(endpoint) { - val pf = endpoint.receive(new AkkaRpcEndpointRef(defaultAddress, sender(), conf)) - if (pf.isDefinedAt(message)) { - pf.apply(message) - } - } - } else { - case message: Any => - logDebug("Received RPC message: " + message) + case AkkaMessage(message: Any, reply: Boolean)=> + logDebug("Received RPC message: " + AkkaMessage(message, reply)) safelyCall(endpoint) { - val pf = endpoint.receive(new AkkaRpcEndpointRef(defaultAddress, sender(), conf)) - if (pf.isDefinedAt(message)) { - pf.apply(message) + val s = sender() + val pf = + if (reply) { + endpoint.receiveAndReply(new RpcResponse { + override def fail(e: Throwable): Unit = { + s ! AkkaFailure(e) + } + + override def reply(response: Any): Unit = { + s ! AkkaMessage(response, false) + } + + override def replyWithSender(response: Any, sender: RpcEndpointRef): Unit = { + s.!(AkkaMessage(response, true))( + sender.asInstanceOf[AkkaRpcEndpointRef].actorRef) + } + }) + } else { + endpoint.receive + } + try { + if (pf.isDefinedAt(message)) { + pf.apply(message) + } + } catch { + case NonFatal(e) => + if (reply) { + // If the sender asks a reply, we should send the error back to the sender + s ! AkkaFailure(e) + } else { + throw e + } } } + case message: Any => { + logWarning(s"Unknown message: $message") + } } override def postStop(): Unit = { @@ -259,20 +291,39 @@ private[akka] class AkkaRpcEndpointRef( override val name: String = actorRef.path.name - override def ask[T: ClassTag](message: Any): Future[T] = ask(message, defaultTimeout) - - override def ask[T: ClassTag](message: Any, timeout: FiniteDuration): Future[T] = { - actorRef.ask(message)(timeout).mapTo[T] - } - override def askWithReply[T: ClassTag](message: Any): T = askWithReply(message, defaultTimeout) override def askWithReply[T: ClassTag](message: Any, timeout: FiniteDuration): T = { // TODO: Consider removing multiple attempts - AkkaUtils.askWithReply(message, actorRef, maxRetries, retryWaitMs, timeout) + var attempts = 0 + var lastException: Exception = null + while (attempts < maxRetries) { + attempts += 1 + try { + val future = sendWithReply[T](message, timeout) + val result = Await.result(future, timeout) + if (result == null) { + throw new SparkException("Actor returned null") + } + return result + } catch { + case ie: InterruptedException => throw ie + case e: Exception => + lastException = e + logWarning(s"Error sending message [message = $message] in $attempts attempts", e) + } + Thread.sleep(retryWaitMs) + } + + throw new SparkException( + s"Error sending message [message = $message]", lastException) } - override def send(message: Any)(implicit sender: RpcEndpointRef = RpcEndpoint.noSender): Unit = { + override def send(message: Any): Unit = { + actorRef ! AkkaMessage(message, false) + } + + override def sendWithReply(message: Any, sender: RpcEndpointRef): Unit = { implicit val actorSender: ActorRef = if (sender == null) { Actor.noSender @@ -280,10 +331,41 @@ private[akka] class AkkaRpcEndpointRef( require(sender.isInstanceOf[AkkaRpcEndpointRef]) sender.asInstanceOf[AkkaRpcEndpointRef].actorRef } - actorRef ! message + actorRef ! 
AkkaMessage(message, true) + } + + + override def sendWithReply[T: ClassTag](message: Any): Future[T] = { + sendWithReply(message, defaultTimeout) + } + + override def sendWithReply[T: ClassTag](message: Any, timeout: FiniteDuration): Future[T] = { + import scala.concurrent.ExecutionContext.Implicits.global + actorRef.ask(AkkaMessage(message, true))(timeout).flatMap { + case AkkaMessage(message, reply) => + if (reply) { + Future.failed(new SparkException("The sender cannot reply")) + } else { + Future.successful(message) + } + case AkkaFailure(e) => + Future.failed(e) + }.mapTo[T] } override def toString: String = s"${getClass.getSimpleName}($actorRef)" override def toURI: URI = new URI(actorRef.path.toString) } + +/** + * A wrapper to `message` so that the receiver knows if the sender expects a reply. + * @param message + * @param reply if the sender expects a reply message + */ +private[akka] case class AkkaMessage(message: Any, reply: Boolean) + +/** + * A reply with the failure error from the receiver to the sender + */ +private[akka] case class AkkaFailure(e: Throwable) diff --git a/core/src/main/scala/org/apache/spark/scheduler/OutputCommitCoordinator.scala b/core/src/main/scala/org/apache/spark/scheduler/OutputCommitCoordinator.scala index e60cff3d0692..4f3bc490233b 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/OutputCommitCoordinator.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/OutputCommitCoordinator.scala @@ -20,7 +20,7 @@ package org.apache.spark.scheduler import scala.collection.mutable import org.apache.spark._ -import org.apache.spark.rpc.{RpcEndpointRef, RpcEnv, RpcEndpoint} +import org.apache.spark.rpc.{RpcResponse, RpcEndpointRef, RpcEnv, RpcEndpoint} private sealed trait OutputCommitCoordinationMessage extends Serializable @@ -156,13 +156,13 @@ private[spark] object OutputCommitCoordinator { override val rpcEnv: RpcEnv, outputCommitCoordinator: OutputCommitCoordinator) extends RpcEndpoint with Logging { - override def receive(sender: RpcEndpointRef) = { + override def receiveAndReply(response: RpcResponse) = { case AskPermissionToCommitOutput(stage, partition, taskAttempt) => - sender.send( + response.reply( outputCommitCoordinator.handleAskPermissionToCommit(stage, partition, taskAttempt)) case StopCoordinator => logInfo("OutputCommitCoordinator stopped!") - sender.send(true) + response.reply(true) stop() } } diff --git a/core/src/test/scala/org/apache/spark/deploy/worker/WorkerWatcherSuite.scala b/core/src/test/scala/org/apache/spark/deploy/worker/WorkerWatcherSuite.scala index 864753b63948..fa7d7132ce8f 100644 --- a/core/src/test/scala/org/apache/spark/deploy/worker/WorkerWatcherSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/worker/WorkerWatcherSuite.scala @@ -20,7 +20,7 @@ package org.apache.spark.deploy.worker import akka.actor.AddressFromURIString import org.apache.spark.SparkConf import org.apache.spark.SecurityManager -import org.apache.spark.rpc.{RpcAddress, RpcEnv} +import org.apache.spark.rpc.{RpcAddress, RpcEnv, DisassociatedEvent} import org.scalatest.FunSuite class WorkerWatcherSuite extends FunSuite { @@ -32,8 +32,8 @@ class WorkerWatcherSuite extends FunSuite { val workerWatcher = new WorkerWatcher(rpcEnv, targetWorkerUrl) workerWatcher.setTesting(testing = true) rpcEnv.setupEndpoint("worker-watcher", workerWatcher) - workerWatcher.onDisconnected( - RpcAddress(targetWorkerAddress.host.get, targetWorkerAddress.port.get)) + workerWatcher.receive( + DisassociatedEvent(RpcAddress(targetWorkerAddress.host.get, 
targetWorkerAddress.port.get))) assert(workerWatcher.isShutDown) rpcEnv.shutdown() } @@ -47,8 +47,8 @@ class WorkerWatcherSuite extends FunSuite { val workerWatcher = new WorkerWatcher(rpcEnv, targetWorkerUrl) workerWatcher.setTesting(testing = true) rpcEnv.setupEndpoint("worker-watcher", workerWatcher) - workerWatcher.onDisconnected( - RpcAddress(otherAkkaAddress.host.get, otherAkkaAddress.port.get)) + workerWatcher.receive( + DisassociatedEvent(RpcAddress(otherAkkaAddress.host.get, otherAkkaAddress.port.get))) assert(!workerWatcher.isShutDown) rpcEnv.shutdown() } diff --git a/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala b/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala index 52cb65d5e86a..d49f9bf54653 100644 --- a/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala +++ b/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala @@ -20,13 +20,14 @@ package org.apache.spark.rpc import java.util.concurrent.{TimeUnit, CountDownLatch, TimeoutException} import scala.collection.mutable +import scala.concurrent.Await import scala.concurrent.duration._ import scala.language.postfixOps import org.scalatest.{BeforeAndAfterAll, FunSuite} import org.scalatest.concurrent.Eventually._ -import org.apache.spark.SparkConf +import org.apache.spark.{SparkException, SparkConf} /** * Common tests for an RpcEnv implementation. @@ -53,7 +54,7 @@ abstract class RpcEnvSuite extends FunSuite with BeforeAndAfterAll { val rpcEndpointRef = env.setupEndpoint("send-locally", new RpcEndpoint { override val rpcEnv = env - override def receive(sender: RpcEndpointRef) = { + override def receive = { case msg: String => message = msg } }) @@ -69,7 +70,7 @@ abstract class RpcEnvSuite extends FunSuite with BeforeAndAfterAll { env.setupEndpoint("send-remotely", new RpcEndpoint { override val rpcEnv = env - override def receive(sender: RpcEndpointRef) = { + override def receive = { case msg: String => message = msg } }) @@ -84,6 +85,7 @@ abstract class RpcEnvSuite extends FunSuite with BeforeAndAfterAll { } } finally { anotherEnv.shutdown() + anotherEnv.awaitTermination() } } @@ -91,9 +93,9 @@ abstract class RpcEnvSuite extends FunSuite with BeforeAndAfterAll { val endpoint = new RpcEndpoint { override val rpcEnv = env - override def receive(sender: RpcEndpointRef) = { - case "Hello" => sender.send(self) - case "Echo" => sender.send("Echo") + override def receiveAndReply(response: RpcResponse) = { + case "Hello" => response.reply(self) + case "Echo" => response.reply("Echo") } } val rpcEndpointRef = env.setupEndpoint("send-ref", endpoint) @@ -107,9 +109,9 @@ abstract class RpcEnvSuite extends FunSuite with BeforeAndAfterAll { val rpcEndpointRef = env.setupEndpoint("ask-locally", new RpcEndpoint { override val rpcEnv = env - override def receive(sender: RpcEndpointRef) = { + override def receiveAndReply(response: RpcResponse) = { case msg: String => { - sender.send(msg) + response.reply(msg) } } }) @@ -121,9 +123,9 @@ abstract class RpcEnvSuite extends FunSuite with BeforeAndAfterAll { env.setupEndpoint("ask-remotely", new RpcEndpoint { override val rpcEnv = env - override def receive(sender: RpcEndpointRef) = { + override def receiveAndReply(response: RpcResponse) = { case msg: String => { - sender.send(msg) + response.reply(msg) } } }) @@ -136,6 +138,7 @@ abstract class RpcEnvSuite extends FunSuite with BeforeAndAfterAll { assert("hello" === reply) } finally { anotherEnv.shutdown() + anotherEnv.awaitTermination() } } @@ -143,10 +146,10 @@ abstract class RpcEnvSuite extends FunSuite with 
BeforeAndAfterAll { env.setupEndpoint("ask-timeout", new RpcEndpoint { override val rpcEnv = env - override def receive(sender: RpcEndpointRef) = { + override def receiveAndReply(response: RpcResponse) = { case msg: String => { Thread.sleep(100) - sender.send(msg) + response.reply(msg) } } }) @@ -164,6 +167,7 @@ abstract class RpcEnvSuite extends FunSuite with BeforeAndAfterAll { assert(e.isInstanceOf[TimeoutException] || e.getCause.isInstanceOf[TimeoutException]) } finally { anotherEnv.shutdown() + anotherEnv.awaitTermination() } } @@ -171,26 +175,26 @@ abstract class RpcEnvSuite extends FunSuite with BeforeAndAfterAll { val pongRef = env.setupEndpoint("pong", new RpcEndpoint { override val rpcEnv = env - override def receive(sender: RpcEndpointRef) = { - case Ping(id) => sender.send(Pong(id)) + override def receiveAndReply(response: RpcResponse) = { + case Ping(id) => response.replyWithSender(Pong(id), self) } }) val pingRef = env.setupEndpoint("ping", new RpcEndpoint { override val rpcEnv = env - var requester: RpcEndpointRef = _ + var requester: RpcResponse = _ - override def receive(sender: RpcEndpointRef) = { + override def receiveAndReply(response: RpcResponse) = { case Start => { - requester = sender - pongRef.send(Ping(1)) + requester = response + pongRef.sendWithReply(Ping(1), self) } case p @ Pong(id) => { if (id < 10) { - sender.send(Ping(id + 1)) + response.replyWithSender(Ping(id + 1), self) } else { - requester.send(p) + requester.reply(p) } } } @@ -211,7 +215,7 @@ abstract class RpcEnvSuite extends FunSuite with BeforeAndAfterAll { calledMethods += "start" } - override def receive(sender: RpcEndpointRef) = { + override def receive = { case msg: String => } @@ -236,7 +240,7 @@ abstract class RpcEnvSuite extends FunSuite with BeforeAndAfterAll { throw new RuntimeException("Oops!") } - override def receive(sender: RpcEndpointRef) = { + override def receive = { case m => } @@ -255,7 +259,7 @@ abstract class RpcEnvSuite extends FunSuite with BeforeAndAfterAll { val endpointRef = env.setupEndpoint("onError-onStop", new RpcEndpoint { override val rpcEnv = env - override def receive(sender: RpcEndpointRef) = { + override def receive = { case m => } @@ -280,7 +284,7 @@ abstract class RpcEnvSuite extends FunSuite with BeforeAndAfterAll { val endpointRef = env.setupEndpoint("onError-receive", new RpcEndpoint { override val rpcEnv = env - override def receive(sender: RpcEndpointRef) = { + override def receive = { case m => throw new RuntimeException("Oops!") } @@ -307,7 +311,7 @@ abstract class RpcEnvSuite extends FunSuite with BeforeAndAfterAll { callSelfSuccessfully = true } - override def receive(sender: RpcEndpointRef) = { + override def receive = { case m => } }) @@ -324,7 +328,7 @@ abstract class RpcEnvSuite extends FunSuite with BeforeAndAfterAll { val endpointRef = env.setupEndpoint("self-receive", new RpcEndpoint { override val rpcEnv = env - override def receive(sender: RpcEndpointRef) = { + override def receive = { case m => { self callSelfSuccessfully = true @@ -346,7 +350,7 @@ abstract class RpcEnvSuite extends FunSuite with BeforeAndAfterAll { val endpointRef = env.setupEndpoint("self-onStop", new RpcEndpoint { override val rpcEnv = env - override def receive(sender: RpcEndpointRef) = { + override def receive = { case m => } @@ -372,10 +376,10 @@ abstract class RpcEnvSuite extends FunSuite with BeforeAndAfterAll { // If a RpcEnv implementation breaks the `receive` contract, hope this test can expose it for(i <- 0 until 100) { @volatile var result = 0 - val endpointRef = 
env.setupEndpoint(s"receive-in-sequence-$i", new RpcEndpoint { + val endpointRef = env.setupThreadSafeEndpoint(s"receive-in-sequence-$i", new RpcEndpoint { override val rpcEnv = env - override def receive(sender: RpcEndpointRef) = { + override def receive = { case m => result += 1 } @@ -404,7 +408,7 @@ abstract class RpcEnvSuite extends FunSuite with BeforeAndAfterAll { val endpointRef = env.setupEndpoint("stop-reentrant", new RpcEndpoint { override val rpcEnv = env - override def receive(sender: RpcEndpointRef) = { + override def receive = { case m => } @@ -421,6 +425,118 @@ abstract class RpcEnvSuite extends FunSuite with BeforeAndAfterAll { assert(onStopCount == 1) } } + + test("sendWithReply") { + val endpointRef = env.setupEndpoint("sendWithReply", new RpcEndpoint { + override val rpcEnv = env + + override def receiveAndReply(response: RpcResponse) = { + case m => response.reply("ack") + } + }) + + val f = endpointRef.sendWithReply[String]("Hi") + val ack = Await.result(f, 5 seconds) + assert("ack" === ack) + + env.stop(endpointRef) + } + + test("sendWithReply: remotely") { + env.setupEndpoint("sendWithReply-remotely", new RpcEndpoint { + override val rpcEnv = env + + override def receiveAndReply(response: RpcResponse) = { + case m => response.reply("ack") + } + }) + + val anotherEnv = createRpcEnv(new SparkConf(), "remote", 13345) + // Use anotherEnv to find out the RpcEndpointRef + val rpcEndpointRef = anotherEnv.setupEndpointRef("local", env.address, "sendWithReply-remotely") + try { + val f = rpcEndpointRef.sendWithReply[String]("hello") + val ack = Await.result(f, 5 seconds) + assert("ack" === ack) + } finally { + anotherEnv.shutdown() + anotherEnv.awaitTermination() + } + } + + test("sendWithReply: error") { + val endpointRef = env.setupEndpoint("sendWithReply-error", new RpcEndpoint { + override val rpcEnv = env + + override def receiveAndReply(response: RpcResponse) = { + case m => response.fail(new SparkException("Oops")) + } + }) + + val f = endpointRef.sendWithReply[String]("Hi") + val e = intercept[SparkException] { + Await.result(f, 5 seconds) + } + assert("Oops" === e.getMessage) + + env.stop(endpointRef) + } + + test("sendWithReply: remotely error") { + env.setupEndpoint("sendWithReply-remotely-error", new RpcEndpoint { + override val rpcEnv = env + + override def receiveAndReply(response: RpcResponse) = { + case msg: String => response.fail(new SparkException("Oops")) + } + }) + + val anotherEnv = createRpcEnv(new SparkConf(), "remote", 13345) + // Use anotherEnv to find out the RpcEndpointRef + val rpcEndpointRef = anotherEnv.setupEndpointRef( + "local", env.address, "sendWithReply-remotely-error") + try { + val f = rpcEndpointRef.sendWithReply[String]("hello") + val e = intercept[SparkException] { + Await.result(f, 5 seconds) + } + assert("Oops" === e.getMessage) + } finally { + anotherEnv.shutdown() + anotherEnv.awaitTermination() + } + } + + test("network events") { + val events = new mutable.ArrayBuffer[Any] with mutable.SynchronizedBuffer[Any] + env.setupThreadSafeEndpoint("network-events", new RpcEndpoint { + override val rpcEnv = env + + override def receive = { + case "hello" => + case m => events += m + } + }) + + val anotherEnv = createRpcEnv(new SparkConf(), "remote", 13345) + // Use anotherEnv to find out the RpcEndpointRef + val rpcEndpointRef = anotherEnv.setupEndpointRef( + "local", env.address, "network-events") + val remoteAddress = anotherEnv.address + rpcEndpointRef.send("hello") + eventually(timeout(5 seconds), interval(5 millis)) { + assert(events 
=== List(AssociatedEvent(remoteAddress))) + } + + anotherEnv.shutdown() + anotherEnv.awaitTermination() + eventually(timeout(5 seconds), interval(5 millis)) { + assert(events.size == 3) + assert(events(0) === AssociatedEvent(remoteAddress)) + assert(events(1).asInstanceOf[NetworkErrorEvent].address === remoteAddress) + assert(events(2) === DisassociatedEvent(remoteAddress)) + } + } } case object Start diff --git a/core/src/test/scala/org/apache/spark/rpc/akka/AkkaRpcEnvSuite.scala b/core/src/test/scala/org/apache/spark/rpc/akka/AkkaRpcEnvSuite.scala index c2989f6b2ca3..99df2bd4e174 100644 --- a/core/src/test/scala/org/apache/spark/rpc/akka/AkkaRpcEnvSuite.scala +++ b/core/src/test/scala/org/apache/spark/rpc/akka/AkkaRpcEnvSuite.scala @@ -30,7 +30,7 @@ class AkkaRpcEnvSuite extends RpcEnvSuite { val ref = env.setupEndpoint("test_endpoint", new RpcEndpoint { override val rpcEnv = env - override def receive(sender: RpcEndpointRef) = { + override def receive = { case _ => } }) From 3751c973b7c71553e50032a5bbb4ce5d837d1cea Mon Sep 17 00:00:00 2001 From: zsxwing Date: Fri, 6 Mar 2015 10:25:37 +0800 Subject: [PATCH 07/31] Rename RpcResponse to RpcCallContext --- .../scala/org/apache/spark/rpc/RpcEnv.scala | 10 ++++----- .../apache/spark/rpc/akka/AkkaRpcEnv.scala | 2 +- .../scheduler/OutputCommitCoordinator.scala | 4 ++-- .../org/apache/spark/rpc/RpcEnvSuite.scala | 22 +++++++++---------- 4 files changed, 19 insertions(+), 19 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala index e9569831a038..2bcb892b991b 100644 --- a/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala @@ -182,7 +182,7 @@ private[spark] trait RpcEndpoint { } /** - * Process messages from [[RpcEndpointRef.send]] or [[RpcResponse.reply)]] + * Process messages from [[RpcEndpointRef.send]] or [[RpcCallContext.reply)]] */ def receive: PartialFunction[Any, Unit] = { case _ => @@ -190,9 +190,9 @@ private[spark] trait RpcEndpoint { } /** - * Process messages from [[RpcEndpointRef.sendWithReply]] or [[RpcResponse.replyWithSender)]] + * Process messages from [[RpcEndpointRef.sendWithReply]] or [[RpcCallContext.replyWithSender)]] */ - def receiveAndReply(response: RpcResponse): PartialFunction[Any, Unit] = { + def receiveAndReply(response: RpcCallContext): PartialFunction[Any, Unit] = { case _ => response.fail(new SparkException(self + " won't reply anything")) } @@ -281,7 +281,7 @@ private[spark] trait RpcEndpointRef { * Fire-and-forget semantics. * * The receiver will reply to sender's [[RpcEndpoint.receive]] or [[RpcEndpoint.receiveAndReply]] - * depending on which one of [[RpcResponse.reply]]s is called. + * depending on which one of [[RpcCallContext.reply]]s is called. */ def sendWithReply(message: Any, sender: RpcEndpointRef): Unit @@ -351,7 +351,7 @@ private[spark] case class NetworkErrorEvent(address: RpcAddress, cause: Throwabl /** * A callback that [[RpcEndpoint]] can use it to send back a message or failure. */ -private[spark] trait RpcResponse { +private[spark] trait RpcCallContext { /** * Reply a message to the sender. 
If the sender is [[RpcEndpoint]], its [[RpcEndpoint.receive]] diff --git a/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala index 2ab065ed8a85..2f26167ceefc 100644 --- a/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala @@ -151,7 +151,7 @@ private[spark] class AkkaRpcEnv private ( val s = sender() val pf = if (reply) { - endpoint.receiveAndReply(new RpcResponse { + endpoint.receiveAndReply(new RpcCallContext { override def fail(e: Throwable): Unit = { s ! AkkaFailure(e) } diff --git a/core/src/main/scala/org/apache/spark/scheduler/OutputCommitCoordinator.scala b/core/src/main/scala/org/apache/spark/scheduler/OutputCommitCoordinator.scala index 4f3bc490233b..5237649e9377 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/OutputCommitCoordinator.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/OutputCommitCoordinator.scala @@ -20,7 +20,7 @@ package org.apache.spark.scheduler import scala.collection.mutable import org.apache.spark._ -import org.apache.spark.rpc.{RpcResponse, RpcEndpointRef, RpcEnv, RpcEndpoint} +import org.apache.spark.rpc.{RpcCallContext, RpcEndpointRef, RpcEnv, RpcEndpoint} private sealed trait OutputCommitCoordinationMessage extends Serializable @@ -156,7 +156,7 @@ private[spark] object OutputCommitCoordinator { override val rpcEnv: RpcEnv, outputCommitCoordinator: OutputCommitCoordinator) extends RpcEndpoint with Logging { - override def receiveAndReply(response: RpcResponse) = { + override def receiveAndReply(response: RpcCallContext) = { case AskPermissionToCommitOutput(stage, partition, taskAttempt) => response.reply( outputCommitCoordinator.handleAskPermissionToCommit(stage, partition, taskAttempt)) diff --git a/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala b/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala index d49f9bf54653..7fdf280f9f83 100644 --- a/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala +++ b/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala @@ -93,7 +93,7 @@ abstract class RpcEnvSuite extends FunSuite with BeforeAndAfterAll { val endpoint = new RpcEndpoint { override val rpcEnv = env - override def receiveAndReply(response: RpcResponse) = { + override def receiveAndReply(response: RpcCallContext) = { case "Hello" => response.reply(self) case "Echo" => response.reply("Echo") } @@ -109,7 +109,7 @@ abstract class RpcEnvSuite extends FunSuite with BeforeAndAfterAll { val rpcEndpointRef = env.setupEndpoint("ask-locally", new RpcEndpoint { override val rpcEnv = env - override def receiveAndReply(response: RpcResponse) = { + override def receiveAndReply(response: RpcCallContext) = { case msg: String => { response.reply(msg) } @@ -123,7 +123,7 @@ abstract class RpcEnvSuite extends FunSuite with BeforeAndAfterAll { env.setupEndpoint("ask-remotely", new RpcEndpoint { override val rpcEnv = env - override def receiveAndReply(response: RpcResponse) = { + override def receiveAndReply(response: RpcCallContext) = { case msg: String => { response.reply(msg) } @@ -146,7 +146,7 @@ abstract class RpcEnvSuite extends FunSuite with BeforeAndAfterAll { env.setupEndpoint("ask-timeout", new RpcEndpoint { override val rpcEnv = env - override def receiveAndReply(response: RpcResponse) = { + override def receiveAndReply(response: RpcCallContext) = { case msg: String => { Thread.sleep(100) response.reply(msg) @@ -175,7 +175,7 @@ abstract class RpcEnvSuite extends 
FunSuite with BeforeAndAfterAll { val pongRef = env.setupEndpoint("pong", new RpcEndpoint { override val rpcEnv = env - override def receiveAndReply(response: RpcResponse) = { + override def receiveAndReply(response: RpcCallContext) = { case Ping(id) => response.replyWithSender(Pong(id), self) } }) @@ -183,9 +183,9 @@ abstract class RpcEnvSuite extends FunSuite with BeforeAndAfterAll { val pingRef = env.setupEndpoint("ping", new RpcEndpoint { override val rpcEnv = env - var requester: RpcResponse = _ + var requester: RpcCallContext = _ - override def receiveAndReply(response: RpcResponse) = { + override def receiveAndReply(response: RpcCallContext) = { case Start => { requester = response pongRef.sendWithReply(Ping(1), self) @@ -430,7 +430,7 @@ abstract class RpcEnvSuite extends FunSuite with BeforeAndAfterAll { val endpointRef = env.setupEndpoint("sendWithReply", new RpcEndpoint { override val rpcEnv = env - override def receiveAndReply(response: RpcResponse) = { + override def receiveAndReply(response: RpcCallContext) = { case m => response.reply("ack") } }) @@ -446,7 +446,7 @@ abstract class RpcEnvSuite extends FunSuite with BeforeAndAfterAll { env.setupEndpoint("sendWithReply-remotely", new RpcEndpoint { override val rpcEnv = env - override def receiveAndReply(response: RpcResponse) = { + override def receiveAndReply(response: RpcCallContext) = { case m => response.reply("ack") } }) @@ -468,7 +468,7 @@ abstract class RpcEnvSuite extends FunSuite with BeforeAndAfterAll { val endpointRef = env.setupEndpoint("sendWithReply-error", new RpcEndpoint { override val rpcEnv = env - override def receiveAndReply(response: RpcResponse) = { + override def receiveAndReply(response: RpcCallContext) = { case m => response.fail(new SparkException("Oops")) } }) @@ -486,7 +486,7 @@ abstract class RpcEnvSuite extends FunSuite with BeforeAndAfterAll { env.setupEndpoint("sendWithReply-remotely-error", new RpcEndpoint { override val rpcEnv = env - override def receiveAndReply(response: RpcResponse) = { + override def receiveAndReply(response: RpcCallContext) = { case msg: String => response.fail(new SparkException("Oops")) } }) From 28e6d0f10a68094c99bb99f5dd4888ee4faf069b Mon Sep 17 00:00:00 2001 From: zsxwing Date: Fri, 6 Mar 2015 10:55:48 +0800 Subject: [PATCH 08/31] Add onXXX for network events and remove the companion objects of network events --- .../spark/deploy/worker/WorkerWatcher.scala | 39 +++++++++------- .../scala/org/apache/spark/rpc/RpcEnv.scala | 46 +++++++++---------- .../apache/spark/rpc/akka/AkkaRpcEnv.scala | 33 ++++--------- .../deploy/worker/WorkerWatcherSuite.scala | 9 ++-- .../org/apache/spark/rpc/RpcEnvSuite.scala | 27 ++++++++--- 5 files changed, 79 insertions(+), 75 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/WorkerWatcher.scala b/core/src/main/scala/org/apache/spark/deploy/worker/WorkerWatcher.scala index 2808a953641f..5aee74aa63ed 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/WorkerWatcher.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/WorkerWatcher.scala @@ -54,24 +54,29 @@ private[spark] class WorkerWatcher(override val rpcEnv: RpcEnv, workerUrl: Strin def exitNonZero() = if (isTesting) isShutDown = true else System.exit(-1) override def receive = { - case AssociatedEvent(remoteAddress) => - if (isWorker(remoteAddress)) { - logInfo(s"Successfully connected to $workerUrl") - } - case DisassociatedEvent(remoteAddress) => - if (isWorker(remoteAddress)) { - // This log message will never be seen - logError(s"Lost 
connection to worker actor $workerUrl. Exiting.") - exitNonZero() - } - case NetworkErrorEvent(remoteAddress, cause) => - if (isWorker(remoteAddress)) { - // These logs may not be seen if the worker (and associated pipe) has died - logError(s"Could not initialize connection to worker $workerUrl. Exiting.") - logError(s"Error was: $cause") - exitNonZero() - } case e => logWarning(s"Received unexpected actor system event: $e") } + override def onConnected(remoteAddress: RpcAddress): Unit = { + if (isWorker(remoteAddress)) { + logInfo(s"Successfully connected to $workerUrl") + } + } + + override def onDisconnected(remoteAddress: RpcAddress): Unit = { + if (isWorker(remoteAddress)) { + // This log message will never be seen + logError(s"Lost connection to worker actor $workerUrl. Exiting.") + exitNonZero() + } + } + + override def onNetworkError(cause: Throwable, remoteAddress: RpcAddress): Unit = { + if (isWorker(remoteAddress)) { + // These logs may not be seen if the worker (and associated pipe) has died + logError(s"Could not initialize connection to worker $workerUrl. Exiting.") + logError(s"Error was: $cause") + exitNonZero() + } + } } diff --git a/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala index 2bcb892b991b..e79c639c0ce6 100644 --- a/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala @@ -185,8 +185,7 @@ private[spark] trait RpcEndpoint { * Process messages from [[RpcEndpointRef.send]] or [[RpcCallContext.reply)]] */ def receive: PartialFunction[Any, Unit] = { - case _ => - // network events will be passed here by default, so do nothing by default to avoid noise. + case _ => throw new SparkException(self + " does not implement 'receive'") } /** @@ -220,6 +219,28 @@ private[spark] trait RpcEndpoint { // By default, do nothing. } + /** + * Invoked when `remoteAddress` is connected to the current node. + */ + def onConnected(remoteAddress: RpcAddress): Unit = { + // By default, do nothing. + } + + /** + * Invoked when `remoteAddress` is lost. + */ + def onDisconnected(remoteAddress: RpcAddress): Unit = { + // By default, do nothing. + } + + /** + * Invoked when some network error happens in the connection between the current node and + * `remoteAddress`. + */ + def onNetworkError(cause: Throwable, remoteAddress: RpcAddress): Unit = { + // By default, do nothing. + } + /** * An convenient method to stop [[RpcEndpoint]]. */ @@ -327,27 +348,6 @@ private[spark] object RpcAddress { } } -/** - * Indicate that a new connection is established. - * - * @param address the remote address of the connection - */ -private[spark] case class AssociatedEvent(address: RpcAddress) - -/** - * Indicate a disconnection from a remote address. - * - * @param address the remote address of the connection - */ -private[spark] case class DisassociatedEvent(address: RpcAddress) - -/** - * Indicate a network error. - * @param address the remote address of the connection which this error happens on. - * @param cause the cause of the network error. - */ -private[spark] case class NetworkErrorEvent(address: RpcAddress, cause: Throwable) - /** * A callback that [[RpcEndpoint]] can use it to send back a message or failure. 
*/ diff --git a/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala index 2f26167ceefc..19ec1bf60cb8 100644 --- a/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala @@ -29,9 +29,7 @@ import scala.util.control.NonFatal import akka.actor.{ActorSystem, ExtendedActorSystem, Actor, ActorRef, Props, Address} import akka.pattern.{ask => akkaAsk} -import akka.remote.{AssociatedEvent => AkkaAssociatedEvent} -import akka.remote.{DisassociatedEvent => AkkaDisassociatedEvent} -import akka.remote.{AssociationErrorEvent, RemotingLifecycleEvent} +import akka.remote.{AssociationEvent, AssociatedEvent, DisassociatedEvent, AssociationErrorEvent} import org.apache.spark.{SparkException, Logging, SparkConf} import org.apache.spark.rpc._ import org.apache.spark.util.{ActorLogReceive, AkkaUtils} @@ -109,41 +107,30 @@ private[spark] class AkkaRpcEnv private ( override def preStart(): Unit = { // Listen for remote client network events - context.system.eventStream.subscribe(self, classOf[RemotingLifecycleEvent]) + context.system.eventStream.subscribe(self, classOf[AssociationEvent]) safelyCall(endpoint) { endpoint.onStart() } } override def receiveWithLogging: Receive = { - case AkkaAssociatedEvent(_, remoteAddress, _) => + case AssociatedEvent(_, remoteAddress, _) => safelyCall(endpoint) { - val event = AssociatedEvent(akkaAddressToRpcAddress(remoteAddress)) - val pf = endpoint.receive - if (pf.isDefinedAt(event)) { - pf.apply(event) - } + endpoint.onConnected(akkaAddressToRpcAddress(remoteAddress)) } - case AkkaDisassociatedEvent(_, remoteAddress, _) => + case DisassociatedEvent(_, remoteAddress, _) => safelyCall(endpoint) { - val event = DisassociatedEvent(akkaAddressToRpcAddress(remoteAddress)) - val pf = endpoint.receive - if (pf.isDefinedAt(event)) { - pf.apply(event) - } + endpoint.onDisconnected(akkaAddressToRpcAddress(remoteAddress)) } case AssociationErrorEvent(cause, localAddress, remoteAddress, inbound, _) => safelyCall(endpoint) { - val event = NetworkErrorEvent(akkaAddressToRpcAddress(remoteAddress), cause) - val pf = endpoint.receive - if (pf.isDefinedAt(event)) { - pf.apply(event) - } + endpoint.onNetworkError(cause, akkaAddressToRpcAddress(remoteAddress)) } - case e: RemotingLifecycleEvent => - // TODO ignore? + + case e: AssociationEvent => + // TODO ignore? 
case AkkaMessage(message: Any, reply: Boolean)=> logDebug("Received RPC message: " + AkkaMessage(message, reply)) diff --git a/core/src/test/scala/org/apache/spark/deploy/worker/WorkerWatcherSuite.scala b/core/src/test/scala/org/apache/spark/deploy/worker/WorkerWatcherSuite.scala index fa7d7132ce8f..6a6f29dd613c 100644 --- a/core/src/test/scala/org/apache/spark/deploy/worker/WorkerWatcherSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/worker/WorkerWatcherSuite.scala @@ -20,7 +20,7 @@ package org.apache.spark.deploy.worker import akka.actor.AddressFromURIString import org.apache.spark.SparkConf import org.apache.spark.SecurityManager -import org.apache.spark.rpc.{RpcAddress, RpcEnv, DisassociatedEvent} +import org.apache.spark.rpc.{RpcAddress, RpcEnv} import org.scalatest.FunSuite class WorkerWatcherSuite extends FunSuite { @@ -32,8 +32,8 @@ class WorkerWatcherSuite extends FunSuite { val workerWatcher = new WorkerWatcher(rpcEnv, targetWorkerUrl) workerWatcher.setTesting(testing = true) rpcEnv.setupEndpoint("worker-watcher", workerWatcher) - workerWatcher.receive( - DisassociatedEvent(RpcAddress(targetWorkerAddress.host.get, targetWorkerAddress.port.get))) + workerWatcher.onDisconnected( + RpcAddress(targetWorkerAddress.host.get, targetWorkerAddress.port.get)) assert(workerWatcher.isShutDown) rpcEnv.shutdown() } @@ -47,8 +47,7 @@ class WorkerWatcherSuite extends FunSuite { val workerWatcher = new WorkerWatcher(rpcEnv, targetWorkerUrl) workerWatcher.setTesting(testing = true) rpcEnv.setupEndpoint("worker-watcher", workerWatcher) - workerWatcher.receive( - DisassociatedEvent(RpcAddress(otherAkkaAddress.host.get, otherAkkaAddress.port.get))) + workerWatcher.onDisconnected(RpcAddress(otherAkkaAddress.host.get, otherAkkaAddress.port.get)) assert(!workerWatcher.isShutDown) rpcEnv.shutdown() } diff --git a/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala b/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala index 7fdf280f9f83..e3c59b5ff17b 100644 --- a/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala +++ b/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala @@ -508,14 +508,27 @@ abstract class RpcEnvSuite extends FunSuite with BeforeAndAfterAll { } test("network events") { - val events = new mutable.ArrayBuffer[Any] with mutable.SynchronizedBuffer[Any] + val events = new mutable.ArrayBuffer[(Any, Any)] with mutable.SynchronizedBuffer[(Any, Any)] env.setupThreadSafeEndpoint("network-events", new RpcEndpoint { override val rpcEnv = env override def receive = { case "hello" => - case m => events += m + case m => events += "receive" -> m } + + override def onConnected(remoteAddress: RpcAddress): Unit = { + events += "onConnected" -> remoteAddress + } + + override def onDisconnected(remoteAddress: RpcAddress): Unit = { + events += "onDisconnected" -> remoteAddress + } + + override def onNetworkError(cause: Throwable, remoteAddress: RpcAddress): Unit = { + events += "onNetworkError" -> remoteAddress + } + }) val anotherEnv = createRpcEnv(new SparkConf(), "remote", 13345) @@ -525,16 +538,16 @@ abstract class RpcEnvSuite extends FunSuite with BeforeAndAfterAll { val remoteAddress = anotherEnv.address rpcEndpointRef.send("hello") eventually(timeout(5 seconds), interval(5 millis)) { - assert(events === List(AssociatedEvent(remoteAddress))) + assert(events === List(("onConnected", remoteAddress))) } anotherEnv.shutdown() anotherEnv.awaitTermination() eventually(timeout(5 seconds), interval(5 millis)) { - assert(events.size == 3) - assert(events(0) === 
AssociatedEvent(remoteAddress)) - assert(events(1).asInstanceOf[NetworkErrorEvent].address === remoteAddress) - assert(events(2) === DisassociatedEvent(remoteAddress)) + assert(events === List( + ("onConnected", remoteAddress), + ("onNetworkError", remoteAddress), + ("onDisconnected", remoteAddress))) } } } From ffc1280c57708a05002133319c2e02c622a89068 Mon Sep 17 00:00:00 2001 From: zsxwing Date: Fri, 6 Mar 2015 11:48:19 +0800 Subject: [PATCH 09/31] Rename 'fail' to 'sendFailure' and other minor code style changes --- core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala | 4 ++-- .../scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala | 10 +++++----- .../test/scala/org/apache/spark/rpc/RpcEnvSuite.scala | 4 ++-- .../org/apache/spark/rpc/akka/AkkaRpcEnvSuite.scala | 2 +- 4 files changed, 10 insertions(+), 10 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala index e79c639c0ce6..d41fa6ec3c19 100644 --- a/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala @@ -192,7 +192,7 @@ private[spark] trait RpcEndpoint { * Process messages from [[RpcEndpointRef.sendWithReply]] or [[RpcCallContext.replyWithSender)]] */ def receiveAndReply(response: RpcCallContext): PartialFunction[Any, Unit] = { - case _ => response.fail(new SparkException(self + " won't reply anything")) + case _ => response.sendFailure(new SparkException(self + " won't reply anything")) } /** @@ -374,5 +374,5 @@ private[spark] trait RpcCallContext { /** * Report a failure to the sender. */ - def fail(e: Throwable): Unit + def sendFailure(e: Throwable): Unit } diff --git a/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala index 19ec1bf60cb8..f90ac375e07e 100644 --- a/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala @@ -20,9 +20,8 @@ package org.apache.spark.rpc.akka import java.net.URI import java.util.concurrent.{ConcurrentHashMap, CountDownLatch} -import scala.concurrent.Await +import scala.concurrent.{Await, Future} import scala.concurrent.duration._ -import scala.concurrent.Future import scala.language.postfixOps import scala.reflect.ClassTag import scala.util.control.NonFatal @@ -139,7 +138,7 @@ private[spark] class AkkaRpcEnv private ( val pf = if (reply) { endpoint.receiveAndReply(new RpcCallContext { - override def fail(e: Throwable): Unit = { + override def sendFailure(e: Throwable): Unit = { s ! 
AkkaFailure(e) } @@ -329,9 +328,10 @@ private[akka] class AkkaRpcEndpointRef( override def sendWithReply[T: ClassTag](message: Any, timeout: FiniteDuration): Future[T] = { import scala.concurrent.ExecutionContext.Implicits.global actorRef.ask(AkkaMessage(message, true))(timeout).flatMap { - case AkkaMessage(message, reply) => + case msg @ AkkaMessage(message, reply) => if (reply) { - Future.failed(new SparkException("The sender cannot reply")) + logError(s"Receive $msg but the sender cannot reply") + Future.failed(new SparkException(s"Receive $msg but the sender cannot reply")) } else { Future.successful(message) } diff --git a/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala b/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala index e3c59b5ff17b..e592cbeb7dd5 100644 --- a/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala +++ b/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala @@ -469,7 +469,7 @@ abstract class RpcEnvSuite extends FunSuite with BeforeAndAfterAll { override val rpcEnv = env override def receiveAndReply(response: RpcCallContext) = { - case m => response.fail(new SparkException("Oops")) + case m => response.sendFailure(new SparkException("Oops")) } }) @@ -487,7 +487,7 @@ abstract class RpcEnvSuite extends FunSuite with BeforeAndAfterAll { override val rpcEnv = env override def receiveAndReply(response: RpcCallContext) = { - case msg: String => response.fail(new SparkException("Oops")) + case msg: String => response.sendFailure(new SparkException("Oops")) } }) diff --git a/core/src/test/scala/org/apache/spark/rpc/akka/AkkaRpcEnvSuite.scala b/core/src/test/scala/org/apache/spark/rpc/akka/AkkaRpcEnvSuite.scala index 99df2bd4e174..c6139a1ad89a 100644 --- a/core/src/test/scala/org/apache/spark/rpc/akka/AkkaRpcEnvSuite.scala +++ b/core/src/test/scala/org/apache/spark/rpc/akka/AkkaRpcEnvSuite.scala @@ -39,7 +39,7 @@ class AkkaRpcEnvSuite extends RpcEnvSuite { AkkaRpcEnv(RpcEnvConfig(conf, "test", "localhost", 12346, new SecurityManager(conf))) try { val newRef = newRpcEnv.setupEndpointRef("local", ref.address, "test_endpoint") - assert(s"akka.tcp://local@localhost:12345/user/test_endpoint" === + assert("akka.tcp://local@localhost:12345/user/test_endpoint" === newRef.asInstanceOf[AkkaRpcEndpointRef].actorRef.path.toString) } finally { newRpcEnv.shutdown() From 51e6667bc7d96705b5c3a093201cb9981d182f98 Mon Sep 17 00:00:00 2001 From: zsxwing Date: Fri, 6 Mar 2015 14:34:39 +0800 Subject: [PATCH 10/31] Add 'sender' to RpcCallContext and rename the parameter of receiveAndReply to 'context' --- .../scala/org/apache/spark/rpc/RpcEnv.scala | 19 +++----- .../apache/spark/rpc/akka/AkkaRpcEnv.scala | 7 ++- .../scheduler/OutputCommitCoordinator.scala | 6 +-- .../org/apache/spark/rpc/RpcEnvSuite.scala | 44 +++++++++---------- 4 files changed, 34 insertions(+), 42 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala index d41fa6ec3c19..2aa0f645ba57 100644 --- a/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala @@ -191,8 +191,8 @@ private[spark] trait RpcEndpoint { /** * Process messages from [[RpcEndpointRef.sendWithReply]] or [[RpcCallContext.replyWithSender)]] */ - def receiveAndReply(response: RpcCallContext): PartialFunction[Any, Unit] = { - case _ => response.sendFailure(new SparkException(self + " won't reply anything")) + def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { + case _ => 
context.sendFailure(new SparkException(self + " won't reply anything")) } /** @@ -360,19 +360,12 @@ private[spark] trait RpcCallContext { def reply(response: Any): Unit /** - * Reply a message to the corresponding [[RpcEndpoint.receiveAndReply]]. If you use this one to - * reply, it means you expect the target [[RpcEndpoint]] should reply you something. - * - * TODO better method name? - * - * @param response the response message - * @param sender who replies this message. The target [[RpcEndpoint]] will use `sender` to send - * back something. + * Report a failure to the sender. */ - def replyWithSender(response: Any, sender: RpcEndpointRef): Unit + def sendFailure(e: Throwable): Unit /** - * Report a failure to the sender. + * The sender of this message. */ - def sendFailure(e: Throwable): Unit + def sender: RpcEndpointRef } diff --git a/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala index f90ac375e07e..842f93f7a5dc 100644 --- a/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala @@ -146,10 +146,9 @@ private[spark] class AkkaRpcEnv private ( s ! AkkaMessage(response, false) } - override def replyWithSender(response: Any, sender: RpcEndpointRef): Unit = { - s.!(AkkaMessage(response, true))( - sender.asInstanceOf[AkkaRpcEndpointRef].actorRef) - } + // Some RpcEndpoints need to know the sender's address + override val sender: RpcEndpointRef = + new AkkaRpcEndpointRef(defaultAddress, s, conf) }) } else { endpoint.receive diff --git a/core/src/main/scala/org/apache/spark/scheduler/OutputCommitCoordinator.scala b/core/src/main/scala/org/apache/spark/scheduler/OutputCommitCoordinator.scala index 5237649e9377..ce9903de0b3d 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/OutputCommitCoordinator.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/OutputCommitCoordinator.scala @@ -156,13 +156,13 @@ private[spark] object OutputCommitCoordinator { override val rpcEnv: RpcEnv, outputCommitCoordinator: OutputCommitCoordinator) extends RpcEndpoint with Logging { - override def receiveAndReply(response: RpcCallContext) = { + override def receiveAndReply(context: RpcCallContext) = { case AskPermissionToCommitOutput(stage, partition, taskAttempt) => - response.reply( + context.reply( outputCommitCoordinator.handleAskPermissionToCommit(stage, partition, taskAttempt)) case StopCoordinator => logInfo("OutputCommitCoordinator stopped!") - response.reply(true) + context.reply(true) stop() } } diff --git a/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala b/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala index e592cbeb7dd5..4ce409c1a47b 100644 --- a/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala +++ b/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala @@ -93,9 +93,9 @@ abstract class RpcEnvSuite extends FunSuite with BeforeAndAfterAll { val endpoint = new RpcEndpoint { override val rpcEnv = env - override def receiveAndReply(response: RpcCallContext) = { - case "Hello" => response.reply(self) - case "Echo" => response.reply("Echo") + override def receiveAndReply(context: RpcCallContext) = { + case "Hello" => context.reply(self) + case "Echo" => context.reply("Echo") } } val rpcEndpointRef = env.setupEndpoint("send-ref", endpoint) @@ -109,9 +109,9 @@ abstract class RpcEnvSuite extends FunSuite with BeforeAndAfterAll { val rpcEndpointRef = env.setupEndpoint("ask-locally", new RpcEndpoint { override 
val rpcEnv = env - override def receiveAndReply(response: RpcCallContext) = { + override def receiveAndReply(context: RpcCallContext) = { case msg: String => { - response.reply(msg) + context.reply(msg) } } }) @@ -123,9 +123,9 @@ abstract class RpcEnvSuite extends FunSuite with BeforeAndAfterAll { env.setupEndpoint("ask-remotely", new RpcEndpoint { override val rpcEnv = env - override def receiveAndReply(response: RpcCallContext) = { + override def receiveAndReply(context: RpcCallContext) = { case msg: String => { - response.reply(msg) + context.reply(msg) } } }) @@ -146,10 +146,10 @@ abstract class RpcEnvSuite extends FunSuite with BeforeAndAfterAll { env.setupEndpoint("ask-timeout", new RpcEndpoint { override val rpcEnv = env - override def receiveAndReply(response: RpcCallContext) = { + override def receiveAndReply(context: RpcCallContext) = { case msg: String => { Thread.sleep(100) - response.reply(msg) + context.reply(msg) } } }) @@ -175,8 +175,8 @@ abstract class RpcEnvSuite extends FunSuite with BeforeAndAfterAll { val pongRef = env.setupEndpoint("pong", new RpcEndpoint { override val rpcEnv = env - override def receiveAndReply(response: RpcCallContext) = { - case Ping(id) => response.replyWithSender(Pong(id), self) + override def receiveAndReply(context: RpcCallContext) = { + case Ping(id) => context.sender.sendWithReply(Pong(id), self) } }) @@ -185,14 +185,14 @@ abstract class RpcEnvSuite extends FunSuite with BeforeAndAfterAll { var requester: RpcCallContext = _ - override def receiveAndReply(response: RpcCallContext) = { + override def receiveAndReply(context: RpcCallContext) = { case Start => { - requester = response + requester = context pongRef.sendWithReply(Ping(1), self) } case p @ Pong(id) => { if (id < 10) { - response.replyWithSender(Ping(id + 1), self) + context.sender.sendWithReply(Ping(id + 1), self) } else { requester.reply(p) } @@ -430,8 +430,8 @@ abstract class RpcEnvSuite extends FunSuite with BeforeAndAfterAll { val endpointRef = env.setupEndpoint("sendWithReply", new RpcEndpoint { override val rpcEnv = env - override def receiveAndReply(response: RpcCallContext) = { - case m => response.reply("ack") + override def receiveAndReply(context: RpcCallContext) = { + case m => context.reply("ack") } }) @@ -446,8 +446,8 @@ abstract class RpcEnvSuite extends FunSuite with BeforeAndAfterAll { env.setupEndpoint("sendWithReply-remotely", new RpcEndpoint { override val rpcEnv = env - override def receiveAndReply(response: RpcCallContext) = { - case m => response.reply("ack") + override def receiveAndReply(context: RpcCallContext) = { + case m => context.reply("ack") } }) @@ -468,8 +468,8 @@ abstract class RpcEnvSuite extends FunSuite with BeforeAndAfterAll { val endpointRef = env.setupEndpoint("sendWithReply-error", new RpcEndpoint { override val rpcEnv = env - override def receiveAndReply(response: RpcCallContext) = { - case m => response.sendFailure(new SparkException("Oops")) + override def receiveAndReply(context: RpcCallContext) = { + case m => context.sendFailure(new SparkException("Oops")) } }) @@ -486,8 +486,8 @@ abstract class RpcEnvSuite extends FunSuite with BeforeAndAfterAll { env.setupEndpoint("sendWithReply-remotely-error", new RpcEndpoint { override val rpcEnv = env - override def receiveAndReply(response: RpcCallContext) = { - case msg: String => response.sendFailure(new SparkException("Oops")) + override def receiveAndReply(context: RpcCallContext) = { + case msg: String => context.sendFailure(new SparkException("Oops")) } }) From 
7cdd95e7dd7856023e1779519e2b747c63100bed Mon Sep 17 00:00:00 2001
From: zsxwing
Date: Fri, 6 Mar 2015 14:50:14 +0800
Subject: [PATCH 11/31] Add docs for RpcEnv

---
 core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala | 8 ++++++--
 1 file changed, 6 insertions(+), 2 deletions(-)

diff --git a/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala
index 2aa0f645ba57..13e912ad18ac 100644
--- a/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala
+++ b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala
@@ -27,7 +27,11 @@ import org.apache.spark.{SparkException, SecurityManager, SparkConf}
 import org.apache.spark.util.Utils
 
 /**
- * An RPC environment.
+ * An RPC environment. [[RpcEndpoint]]s need to register themselves with a name to [[RpcEnv]] to
+ * receive messages. Then [[RpcEnv]] will process messages sent from [[RpcEndpointRef]] or remote
+ * nodes, and deliver them to the corresponding [[RpcEndpoint]]s.
+ *
+ * [[RpcEnv]] also provides some methods to retrieve [[RpcEndpointRef]]s given a name or URI.
  */
 private[spark] trait RpcEnv {
 
@@ -189,7 +193,7 @@ private[spark] trait RpcEndpoint {
   }
 
   /**
-   * Process messages from [[RpcEndpointRef.sendWithReply]] or [[RpcCallContext.replyWithSender)]]
+   * Process messages from [[RpcEndpointRef.sendWithReply]]
    */
   def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
     case _ => context.sendFailure(new SparkException(self + " won't reply anything"))

From 4d3419126b11ae2770a565d1867df86e86468333 Mon Sep 17 00:00:00 2001
From: zsxwing
Date: Fri, 6 Mar 2015 16:24:58 +0800
Subject: [PATCH 12/31] Remove scheduler from RpcEnv

---
 core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala     | 5 -----
 .../main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala | 2 --
 2 files changed, 7 deletions(-)

diff --git a/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala
index 13e912ad18ac..62a65191e321 100644
--- a/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala
+++ b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala
@@ -41,11 +41,6 @@ private[spark] trait RpcEnv {
    */
   private[rpc] def endpointRef(endpoint: RpcEndpoint): RpcEndpointRef
 
-  /**
-   * Return an ActionScheduler for the caller to run long-time actions out of the current thread.
-   */
-  def scheduler: ActionScheduler
-
   /**
   * Return the address that [[RpcEnv]] is listening to.
*/ diff --git a/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala index 842f93f7a5dc..120a70ad54ad 100644 --- a/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala @@ -61,8 +61,6 @@ private[spark] class AkkaRpcEnv private ( */ private val endpointToRef = new ConcurrentHashMap[RpcEndpoint, RpcEndpointRef]() - override val scheduler = new ActionSchedulerImpl(conf) - /** * Need this map to remove `RpcEndpoint` from `endpointToRef` via a `RpcEndpointRef` */ From 07f128fc14d522ff2313d7dde4e698bc916bdbc7 Mon Sep 17 00:00:00 2001 From: zsxwing Date: Fri, 6 Mar 2015 16:45:41 +0800 Subject: [PATCH 13/31] Remove ActionScheduler.scala --- .../apache/spark/rpc/ActionScheduler.scala | 212 ------------------ 1 file changed, 212 deletions(-) delete mode 100644 core/src/main/scala/org/apache/spark/rpc/ActionScheduler.scala diff --git a/core/src/main/scala/org/apache/spark/rpc/ActionScheduler.scala b/core/src/main/scala/org/apache/spark/rpc/ActionScheduler.scala deleted file mode 100644 index c5755e922ce8..000000000000 --- a/core/src/main/scala/org/apache/spark/rpc/ActionScheduler.scala +++ /dev/null @@ -1,212 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.rpc - -import java.util.concurrent.atomic.AtomicReference -import java.util.concurrent.{SynchronousQueue, TimeUnit, ThreadPoolExecutor} - -import scala.concurrent.duration.FiniteDuration -import scala.concurrent.ExecutionContext -import scala.util.control.NonFatal - -import io.netty.util.{Timeout, TimerTask, HashedWheelTimer} - -import org.apache.spark.util.Utils -import org.apache.spark.{Logging, SparkConf} - -/** - * It's very common that executing some actions in other threads to avoid blocking the event loop - * in a RpcEndpoint. [[ActionScheduler]] is designed for such use cases. - */ -private[spark] trait ActionScheduler { - - /** - * Run the action in the IO thread pool. The thread name will be `name` when running this action. - */ - def executeIOAction(name: String)(action: => Unit): Unit - - /** - * Run the action in the CPU thread pool. The thread name will be `name` when running this action. - */ - def executeCPUAction(name: String)(action: => Unit): Unit - - /** - * Run the action after `delay`. The thread name will be `name` when running this action. - */ - def schedule(name: String, delay: FiniteDuration)(action: => Unit): Cancellable - - - /** - * Run the action every `interval`. The thread name will be `name` when running this action. 
- */ - def schedulePeriodically(name: String, interval: FiniteDuration)(action: => Unit): Cancellable = { - schedulePeriodically(name, interval, interval)(action) - } - - /** - * Run the action every `interval`. The thread name will be `name` when running this action. - */ - def schedulePeriodically( - name: String, delay: FiniteDuration, interval: FiniteDuration)(action: => Unit): Cancellable -} - -private[spark] trait Cancellable { - /** - * Cancel the corresponding work. The caller may call this method multiple times and call it in - * any thread. - */ - def cancel(): Unit -} - -private[rpc] class SettableCancellable extends Cancellable { - - @volatile private var underlying: Cancellable = SettableCancellable.NOP - - @volatile private var isCancelled = false - - // Should be called only once - def set(c: Cancellable): Unit = { - underlying = c - if (isCancelled) { - underlying.cancel() - } - } - - override def cancel(): Unit = { - isCancelled = true - underlying.cancel() - } -} - -private[rpc] object SettableCancellable { - val NOP = new Cancellable { - override def cancel(): Unit = {} - } -} - -private[spark] class ActionSchedulerImpl(conf: SparkConf) extends ActionScheduler with Logging { - - val maxIOThreadNumber = conf.getInt("spark.rpc.io.maxThreads", 1000) - - private val ioExecutor = new ThreadPoolExecutor( - 0, - maxIOThreadNumber, - 60L, - TimeUnit.SECONDS, - new SynchronousQueue[Runnable](), Utils.namedThreadFactory("rpc-io")) - - private val cpuExecutor = ExecutionContext.fromExecutorService(null, e => { - e match { - case NonFatal(e) => logError(e.getMessage, e) - case e => - Thread.getDefaultUncaughtExceptionHandler.uncaughtException(Thread.currentThread, e) - } - }) - - private val timer = new HashedWheelTimer(Utils.namedThreadFactory("rpc-timer")) - - // Need a name to distinguish between different actions because they use the same thread pool - override def executeIOAction(name: String)(action: => Unit): Unit = { - ioExecutor.execute(new Runnable { - - override def run(): Unit = { - val previousThreadName = Thread.currentThread().getName - Thread.currentThread().setName(name) - try { - action - } finally { - Thread.currentThread().setName(previousThreadName) - } - } - - }) - } - - // Need a name to distinguish between different actions because they use the same thread pool - override def executeCPUAction(name: String)(action: => Unit): Unit = { - cpuExecutor.execute(new Runnable { - - override def run(): Unit = { - val previousThreadName = Thread.currentThread().getName - Thread.currentThread().setName(name) - try { - action - } finally { - Thread.currentThread().setName(previousThreadName) - } - } - - }) - } - - def schedule(name: String, delay: FiniteDuration)(action: => Unit): Cancellable = { - val timeout = timer.newTimeout(new TimerTask { - - override def run(timeout: Timeout): Unit = { - val previousThreadName = Thread.currentThread().getName - Thread.currentThread().setName(name) - try { - action - } finally { - Thread.currentThread().setName(previousThreadName) - } - } - - }, delay.toNanos, TimeUnit.NANOSECONDS) - new Cancellable { - override def cancel(): Unit = timeout.cancel() - } - } - - override def schedulePeriodically( - name: String, delay: FiniteDuration, interval: FiniteDuration)(action: => Unit): - Cancellable = { - val initial = new SettableCancellable - val cancellable = new AtomicReference[SettableCancellable](initial) - def actionOnce: Unit = { - if (cancellable.get != null) { - action - val c = cancellable.get - if (c != null) { - val s = new 
SettableCancellable - if (cancellable.compareAndSet(c, s)) { - s.set(schedule(name, interval)(actionOnce)) - } else { - // has been cancelled - assert(cancellable.get == null) - } - } - } - } - initial.set(schedule(name, delay)(actionOnce)) - new Cancellable { - override def cancel(): Unit = { - var c = cancellable.get - while (c != null) { - if (cancellable.compareAndSet(c, null)) { - c.cancel() - return - } else { - // Already schedule another action, retry to cancel it - c = cancellable.get - } - } - } - } - } -} From 3e56123010d48d8e2a917589b53926da1bc0cad8 Mon Sep 17 00:00:00 2001 From: zsxwing Date: Sat, 7 Mar 2015 01:32:14 +0800 Subject: [PATCH 14/31] Use lazy to eliminate CountDownLatch --- .../apache/spark/rpc/akka/AkkaRpcEnv.scala | 168 ++++++++++-------- 1 file changed, 90 insertions(+), 78 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala index 120a70ad54ad..9639d63b22d0 100644 --- a/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala @@ -18,7 +18,7 @@ package org.apache.spark.rpc.akka import java.net.URI -import java.util.concurrent.{ConcurrentHashMap, CountDownLatch} +import java.util.concurrent.ConcurrentHashMap import scala.concurrent.{Await, Future} import scala.concurrent.duration._ @@ -92,97 +92,94 @@ private[spark] class AkkaRpcEnv private ( } override def setupThreadSafeEndpoint(name: String, endpoint: RpcEndpoint): RpcEndpointRef = { - val latch = new CountDownLatch(1) - try { - @volatile var endpointRef: AkkaRpcEndpointRef = null - val actorRef = actorSystem.actorOf(Props(new Actor with ActorLogReceive with Logging { + @volatile var endpointRef: AkkaRpcEndpointRef = null + // Use lazy because the Actor needs to use `endpointRef`. + // So `actorRef` should be created after assigning `endpointRef`. + lazy val actorRef = actorSystem.actorOf(Props(new Actor with ActorLogReceive with Logging { + + require(endpointRef != null) + registerEndpoint(endpoint, endpointRef) + + override def preStart(): Unit = { + // Listen for remote client network events + context.system.eventStream.subscribe(self, classOf[AssociationEvent]) + safelyCall(endpoint) { + endpoint.onStart() + } + } - // Wait until `endpointRef` is set. TODO better solution? - latch.await() - require(endpointRef != null) - registerEndpoint(endpoint, endpointRef) + override def receiveWithLogging: Receive = { + case AssociatedEvent(_, remoteAddress, _) => + safelyCall(endpoint) { + endpoint.onConnected(akkaAddressToRpcAddress(remoteAddress)) + } - override def preStart(): Unit = { - // Listen for remote client network events - context.system.eventStream.subscribe(self, classOf[AssociationEvent]) + case DisassociatedEvent(_, remoteAddress, _) => safelyCall(endpoint) { - endpoint.onStart() + endpoint.onDisconnected(akkaAddressToRpcAddress(remoteAddress)) } - } - override def receiveWithLogging: Receive = { - case AssociatedEvent(_, remoteAddress, _) => - safelyCall(endpoint) { - endpoint.onConnected(akkaAddressToRpcAddress(remoteAddress)) - } + case AssociationErrorEvent(cause, localAddress, remoteAddress, inbound, _) => + safelyCall(endpoint) { + endpoint.onNetworkError(cause, akkaAddressToRpcAddress(remoteAddress)) + } - case DisassociatedEvent(_, remoteAddress, _) => - safelyCall(endpoint) { - endpoint.onDisconnected(akkaAddressToRpcAddress(remoteAddress)) - } + case e: AssociationEvent => + // TODO ignore? 
- case AssociationErrorEvent(cause, localAddress, remoteAddress, inbound, _) => - safelyCall(endpoint) { - endpoint.onNetworkError(cause, akkaAddressToRpcAddress(remoteAddress)) - } + case AkkaMessage(message: Any, reply: Boolean)=> + logDebug("Received RPC message: " + AkkaMessage(message, reply)) + safelyCall(endpoint) { + val s = sender() + val pf = + if (reply) { + endpoint.receiveAndReply(new RpcCallContext { + override def sendFailure(e: Throwable): Unit = { + s ! AkkaFailure(e) + } - case e: AssociationEvent => - // TODO ignore? + override def reply(response: Any): Unit = { + s ! AkkaMessage(response, false) + } - case AkkaMessage(message: Any, reply: Boolean)=> - logDebug("Received RPC message: " + AkkaMessage(message, reply)) - safelyCall(endpoint) { - val s = sender() - val pf = + // Some RpcEndpoints need to know the sender's address + override val sender: RpcEndpointRef = + new AkkaRpcEndpointRef(defaultAddress, s, conf) + }) + } else { + endpoint.receive + } + try { + if (pf.isDefinedAt(message)) { + pf.apply(message) + } + } catch { + case NonFatal(e) => if (reply) { - endpoint.receiveAndReply(new RpcCallContext { - override def sendFailure(e: Throwable): Unit = { - s ! AkkaFailure(e) - } - - override def reply(response: Any): Unit = { - s ! AkkaMessage(response, false) - } - - // Some RpcEndpoints need to know the sender's address - override val sender: RpcEndpointRef = - new AkkaRpcEndpointRef(defaultAddress, s, conf) - }) + // If the sender asks a reply, we should send the error back to the sender + s ! AkkaFailure(e) } else { - endpoint.receive - } - try { - if (pf.isDefinedAt(message)) { - pf.apply(message) + throw e } - } catch { - case NonFatal(e) => - if (reply) { - // If the sender asks a reply, we should send the error back to the sender - s ! AkkaFailure(e) - } else { - throw e - } - } } - case message: Any => { - logWarning(s"Unknown message: $message") } + case message: Any => { + logWarning(s"Unknown message: $message") } + } - override def postStop(): Unit = { - unregisterEndpoint(endpoint.self) - safelyCall(endpoint) { - endpoint.onStop() - } + override def postStop(): Unit = { + unregisterEndpoint(endpoint.self) + safelyCall(endpoint) { + endpoint.onStop() } + } - }), name = name) - endpointRef = new AkkaRpcEndpointRef(defaultAddress, actorRef, conf) - endpointRef - } finally { - latch.countDown() - } + }), name = name) + endpointRef = new AkkaRpcEndpointRef(defaultAddress, actorRef, conf, initInConstructor = false) + // Now actorRef can be created safely + endpointRef.init() + endpointRef } /** @@ -258,21 +255,36 @@ private[spark] object AkkaRpcEnv { private[akka] class AkkaRpcEndpointRef( @transient defaultAddress: RpcAddress, - val actorRef: ActorRef, - @transient conf: SparkConf) extends RpcEndpointRef with Serializable with Logging { + @transient _actorRef: => ActorRef, + @transient conf: SparkConf, + @transient initInConstructor: Boolean = true) + extends RpcEndpointRef with Serializable with Logging { // `defaultAddress` and `conf` won't be used after initialization. So it's safe to be transient. 
private[this] val maxRetries = conf.getInt("spark.akka.num.retries", 3) private[this] val retryWaitMs = conf.getLong("spark.akka.retry.wait", 3000) private[this] val defaultTimeout = conf.getLong("spark.akka.lookupTimeout", 30) seconds - override val address: RpcAddress = { + lazy val actorRef = _actorRef + + override lazy val address: RpcAddress = { val akkaAddress = actorRef.path.address RpcAddress(akkaAddress.host.getOrElse(defaultAddress.host), akkaAddress.port.getOrElse(defaultAddress.port)) } - override val name: String = actorRef.path.name + override lazy val name: String = actorRef.path.name + + private[akka] def init(): Unit = { + // Initialize the lazy vals + actorRef + address + name + } + + if (initInConstructor) { + init() + } override def askWithReply[T: ClassTag](message: Any): T = askWithReply(message, defaultTimeout) From 5f87700f87c3c34f543a1691ebb12b4e1160bfcd Mon Sep 17 00:00:00 2001 From: zsxwing Date: Sat, 7 Mar 2015 01:41:26 +0800 Subject: [PATCH 15/31] Move the logical of processing message to a private function --- .../apache/spark/rpc/akka/AkkaRpcEnv.scala | 73 ++++++++++--------- 1 file changed, 39 insertions(+), 34 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala index 9639d63b22d0..f9b66b0a1dd4 100644 --- a/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala @@ -127,41 +127,10 @@ private[spark] class AkkaRpcEnv private ( case e: AssociationEvent => // TODO ignore? - case AkkaMessage(message: Any, reply: Boolean)=> - logDebug("Received RPC message: " + AkkaMessage(message, reply)) + case m: AkkaMessage => + logDebug(s"Received RPC message: $m") safelyCall(endpoint) { - val s = sender() - val pf = - if (reply) { - endpoint.receiveAndReply(new RpcCallContext { - override def sendFailure(e: Throwable): Unit = { - s ! AkkaFailure(e) - } - - override def reply(response: Any): Unit = { - s ! AkkaMessage(response, false) - } - - // Some RpcEndpoints need to know the sender's address - override val sender: RpcEndpointRef = - new AkkaRpcEndpointRef(defaultAddress, s, conf) - }) - } else { - endpoint.receive - } - try { - if (pf.isDefinedAt(message)) { - pf.apply(message) - } - } catch { - case NonFatal(e) => - if (reply) { - // If the sender asks a reply, we should send the error back to the sender - s ! AkkaFailure(e) - } else { - throw e - } - } + processMessage(endpoint, m, sender) } case message: Any => { logWarning(s"Unknown message: $message") @@ -182,6 +151,42 @@ private[spark] class AkkaRpcEnv private ( endpointRef } + private def processMessage(endpoint: RpcEndpoint, m: AkkaMessage, _sender: ActorRef): Unit = { + val message = m.message + val reply = m.reply + val pf = + if (reply) { + endpoint.receiveAndReply(new RpcCallContext { + override def sendFailure(e: Throwable): Unit = { + _sender ! AkkaFailure(e) + } + + override def reply(response: Any): Unit = { + _sender ! AkkaMessage(response, false) + } + + // Some RpcEndpoints need to know the sender's address + override val sender: RpcEndpointRef = + new AkkaRpcEndpointRef(defaultAddress, _sender, conf) + }) + } else { + endpoint.receive + } + try { + if (pf.isDefinedAt(message)) { + pf.apply(message) + } + } catch { + case NonFatal(e) => + if (reply) { + // If the sender asks a reply, we should send the error back to the sender + _sender ! 
AkkaFailure(e) + } else { + throw e + } + } + } + /** * Run `action` safely to avoid to crash the thread. If any non-fatal exception happens, it will * call `endpoint.onError`. If `endpoint.onError` throws any non-fatal exception, just log it. From c425022dd2709ffbb98d8ff5f026ad50e36278e0 Mon Sep 17 00:00:00 2001 From: zsxwing Date: Tue, 10 Mar 2015 17:25:41 +0800 Subject: [PATCH 16/31] Fix the code style --- core/src/main/scala/org/apache/spark/SparkEnv.scala | 4 ++-- .../org/apache/spark/deploy/worker/DriverWrapper.scala | 6 ++---- 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala index 5203521e6d47..6f7144177135 100644 --- a/core/src/main/scala/org/apache/spark/SparkEnv.scala +++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala @@ -34,8 +34,8 @@ import org.apache.spark.metrics.MetricsSystem import org.apache.spark.network.BlockTransferService import org.apache.spark.network.netty.NettyBlockTransferService import org.apache.spark.network.nio.NioBlockTransferService -import org.apache.spark.rpc.akka.AkkaRpcEnv import org.apache.spark.rpc.{RpcEndpointRef, RpcEndpoint, RpcEnv} +import org.apache.spark.rpc.akka.AkkaRpcEnv import org.apache.spark.scheduler.{OutputCommitCoordinator, LiveListenerBus} import org.apache.spark.scheduler.OutputCommitCoordinator.OutputCommitCoordinatorEndpoint import org.apache.spark.serializer.Serializer @@ -56,7 +56,7 @@ import org.apache.spark.util.{AkkaUtils, Utils} @DeveloperApi class SparkEnv ( val executorId: String, - val rpcEnv: RpcEnv, + private[spark] val rpcEnv: RpcEnv, val serializer: Serializer, val closureSerializer: Serializer, val cacheManager: CacheManager, diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/DriverWrapper.scala b/core/src/main/scala/org/apache/spark/deploy/worker/DriverWrapper.scala index 00fdcc0922a7..ac40731fe95d 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/DriverWrapper.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/DriverWrapper.scala @@ -19,11 +19,9 @@ package org.apache.spark.deploy.worker import java.io.File -import akka.actor._ - -import org.apache.spark.rpc.RpcEnv import org.apache.spark.{SecurityManager, SparkConf} -import org.apache.spark.util.{AkkaUtils, ChildFirstURLClassLoader, MutableURLClassLoader, Utils} +import org.apache.spark.rpc.RpcEnv +import org.apache.spark.util.{ChildFirstURLClassLoader, MutableURLClassLoader, Utils} /** * Utility object for launching driver programs such that they share fate with the Worker process. 
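
A minimal usage sketch of the RpcEnv / RpcEndpoint API as it stands after PATCH 16 (illustrative only, not part of the patch series). The package name, the EchoEndpoint class, the Echo message, and the "echo-demo" system name are invented for this example; production code would use Logging rather than println.

{{{
// Sketch only: the rpc API is private[spark], so this assumes a (hypothetical)
// subpackage of org.apache.spark.
package org.apache.spark.demo

import org.apache.spark.{SecurityManager, SparkConf, SparkException}
import org.apache.spark.rpc._

// Hypothetical message type, used only for this sketch.
case class Echo(text: String)

// A hypothetical endpoint built on the API introduced by these patches.
class EchoEndpoint(override val rpcEnv: RpcEnv) extends RpcEndpoint {

  // One-way messages sent via RpcEndpointRef.send arrive here.
  override def receive: PartialFunction[Any, Unit] = {
    case Echo(text) => println(s"got one-way echo: $text")
  }

  // Messages sent via sendWithReply/askWithReply arrive here; answer through the context.
  override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
    case Echo(text) => context.reply(text)
    case other => context.sendFailure(new SparkException(s"Unexpected message: $other"))
  }

  // Network events are delivered through the onXXX callbacks added in PATCH 08.
  override def onDisconnected(remoteAddress: RpcAddress): Unit = {
    println(s"lost connection to $remoteAddress")
  }
}

object EchoDemo {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf()
    val rpcEnv = RpcEnv.create("echo-demo", "localhost", 0, conf, new SecurityManager(conf))
    val ref = rpcEnv.setupEndpoint("echo", new EchoEndpoint(rpcEnv))

    ref.send(Echo("fire-and-forget"))                    // handled by receive
    val reply = ref.askWithReply[String](Echo("hello"))  // blocking request/reply
    println(reply)

    rpcEnv.shutdown()
    rpcEnv.awaitTermination()
  }
}
}}}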
From 3007c093b41fba8cc457490f2e98ebb1271e5c63 Mon Sep 17 00:00:00 2001 From: zsxwing Date: Tue, 10 Mar 2015 17:35:10 +0800 Subject: [PATCH 17/31] Move setupDriverEndpointRef to RpcUtils and rename to makeDriverRef --- .../scala/org/apache/spark/SparkEnv.scala | 4 +-- .../scala/org/apache/spark/rpc/RpcEnv.scala | 5 --- .../apache/spark/rpc/akka/AkkaRpcEnv.scala | 6 +--- .../org/apache/spark/util/RpcUtils.scala | 35 +++++++++++++++++++ 4 files changed, 38 insertions(+), 12 deletions(-) create mode 100644 core/src/main/scala/org/apache/spark/util/RpcUtils.scala diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala index 6f7144177135..4a2ed82a40de 100644 --- a/core/src/main/scala/org/apache/spark/SparkEnv.scala +++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala @@ -41,7 +41,7 @@ import org.apache.spark.scheduler.OutputCommitCoordinator.OutputCommitCoordinato import org.apache.spark.serializer.Serializer import org.apache.spark.shuffle.{ShuffleMemoryManager, ShuffleManager} import org.apache.spark.storage._ -import org.apache.spark.util.{AkkaUtils, Utils} +import org.apache.spark.util.{AkkaUtils, RpcUtils, Utils} /** * :: DeveloperApi :: @@ -300,7 +300,7 @@ object SparkEnv extends Logging { logInfo("Registering " + name) rpcEnv.setupEndpoint(name, endpointCreator) } else { - rpcEnv.setupDriverEndpointRef(name) + RpcUtils.makeDriverRef(name, conf, rpcEnv) } } diff --git a/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala index 62a65191e321..6a6df48e864c 100644 --- a/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala @@ -58,11 +58,6 @@ private[spark] trait RpcEnv { */ def setupThreadSafeEndpoint(name: String, endpoint: RpcEndpoint): RpcEndpointRef - /** - * Retrieve a [[RpcEndpointRef]] which is located in the driver via its name. - */ - def setupDriverEndpointRef(name: String): RpcEndpointRef - /** * Retrieve the [[RpcEndpointRef]] represented by `url`. */ diff --git a/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala index f9b66b0a1dd4..4d7039342ec3 100644 --- a/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala @@ -210,12 +210,8 @@ private[spark] class AkkaRpcEnv private ( address.port.getOrElse(defaultAddress.port)) } - override def setupDriverEndpointRef(name: String): RpcEndpointRef = { - new AkkaRpcEndpointRef(defaultAddress, AkkaUtils.makeDriverRef(name, conf, actorSystem), conf) - } - override def setupEndpointRefByUrl(url: String): RpcEndpointRef = { - val timeout = Duration.create(conf.getLong("spark.akka.lookupTimeout", 30), "seconds") + val timeout = AkkaUtils.lookupTimeout(conf) val ref = Await.result(actorSystem.actorSelection(url).resolveOne(timeout), timeout) // TODO defaultAddress is wrong new AkkaRpcEndpointRef(defaultAddress, ref, conf) diff --git a/core/src/main/scala/org/apache/spark/util/RpcUtils.scala b/core/src/main/scala/org/apache/spark/util/RpcUtils.scala new file mode 100644 index 000000000000..6665b17c3d5d --- /dev/null +++ b/core/src/main/scala/org/apache/spark/util/RpcUtils.scala @@ -0,0 +1,35 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util + +import org.apache.spark.{SparkEnv, SparkConf} +import org.apache.spark.rpc.{RpcAddress, RpcEndpointRef, RpcEnv} + +object RpcUtils { + + /** + * Retrieve a [[RpcEndpointRef]] which is located in the driver via its name. + */ + def makeDriverRef(name: String, conf: SparkConf, rpcEnv: RpcEnv): RpcEndpointRef = { + val driverActorSystemName = SparkEnv.driverActorSystemName + val driverHost: String = conf.get("spark.driver.host", "localhost") + val driverPort: Int = conf.getInt("spark.driver.port", 7077) + Utils.checkHost(driverHost, "Expected hostname") + rpcEnv.setupEndpointRef(driverActorSystemName, RpcAddress(driverHost, driverPort), name) + } +} From 92884061d4bd7aee672cc1a02094a187f0b24b2d Mon Sep 17 00:00:00 2001 From: zsxwing Date: Tue, 10 Mar 2015 17:51:29 +0800 Subject: [PATCH 18/31] Document thread-safety for setupThreadSafeEndpoint --- core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala index 6a6df48e864c..285fee6899b5 100644 --- a/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala @@ -55,6 +55,14 @@ private[spark] trait RpcEnv { /** * Register a [[RpcEndpoint]] with a name and return its [[RpcEndpointRef]]. [[RpcEnv]] should * make sure thread-safely sending messages to [[RpcEndpoint]]. + * + * Thread-safety means processing of one message happens before processing of the next message by + * the same [[RpcEndpoint]]. In the other words, changes to internal fields of a [[RpcEndpoint]] + * are visible when processing the next message, and fields in the [[RpcEndpoint]] need not be + * volatile or equivalent. + * + * However, there is no guarantee that the same thread will be executing the same [[RpcEndpoint]] + * for different messages. 
*/ def setupThreadSafeEndpoint(name: String, endpoint: RpcEndpoint): RpcEndpointRef From 7fc95e17aef1ff20bea15377c06a0378c4094fc1 Mon Sep 17 00:00:00 2001 From: zsxwing Date: Wed, 11 Mar 2015 14:40:46 +0800 Subject: [PATCH 19/31] Implement askWithReply in RpcEndpointRef --- .../scala/org/apache/spark/rpc/RpcEnv.scala | 44 ++++++++++++++++--- .../apache/spark/rpc/akka/AkkaRpcEnv.scala | 40 +---------------- 2 files changed, 38 insertions(+), 46 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala index 285fee6899b5..680ca0ffb019 100644 --- a/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala @@ -19,11 +19,12 @@ package org.apache.spark.rpc import java.net.URI -import scala.concurrent.Future -import scala.concurrent.duration.FiniteDuration +import scala.concurrent.{Await, Future} +import scala.concurrent.duration._ +import scala.language.postfixOps import scala.reflect.ClassTag -import org.apache.spark.{SparkException, SecurityManager, SparkConf} +import org.apache.spark.{Logging, SparkException, SecurityManager, SparkConf} import org.apache.spark.util.Utils /** @@ -257,7 +258,12 @@ private[spark] trait RpcEndpoint { /** * A reference for a remote [[RpcEndpoint]]. [[RpcEndpointRef]] is thread-safe. */ -private[spark] trait RpcEndpointRef { +private[spark] abstract class RpcEndpointRef(@transient conf: SparkConf) + extends Serializable with Logging { + + private[this] val maxRetries = conf.getInt("spark.akka.num.retries", 3) + private[this] val retryWaitMs = conf.getLong("spark.akka.retry.wait", 3000) + private[this] val defaultTimeout = conf.getLong("spark.akka.lookupTimeout", 30) seconds /** * return the address for the [[RpcEndpointRef]] @@ -277,7 +283,7 @@ private[spark] trait RpcEndpointRef { * @tparam T type of the reply message * @return the reply message from the corresponding [[RpcEndpoint]] */ - def askWithReply[T: ClassTag](message: Any): T + def askWithReply[T: ClassTag](message: Any): T = askWithReply(message, defaultTimeout) /** * Send a message to the corresponding [[RpcEndpoint.receive]] and get its result within a @@ -292,7 +298,31 @@ private[spark] trait RpcEndpointRef { * @tparam T type of the reply message * @return the reply message from the corresponding [[RpcEndpoint]] */ - def askWithReply[T: ClassTag](message: Any, timeout: FiniteDuration): T + def askWithReply[T: ClassTag](message: Any, timeout: FiniteDuration): T = { + // TODO: Consider removing multiple attempts + var attempts = 0 + var lastException: Exception = null + while (attempts < maxRetries) { + attempts += 1 + try { + val future = sendWithReply[T](message, timeout) + val result = Await.result(future, timeout) + if (result == null) { + throw new SparkException("Actor returned null") + } + return result + } catch { + case ie: InterruptedException => throw ie + case e: Exception => + lastException = e + logWarning(s"Error sending message [message = $message] in $attempts attempts", e) + } + Thread.sleep(retryWaitMs) + } + + throw new SparkException( + s"Error sending message [message = $message]", lastException) + } /** * Sends a one-way asynchronous message. Fire-and-forget semantics. @@ -312,7 +342,7 @@ private[spark] trait RpcEndpointRef { * Send a message to the corresponding [[RpcEndpoint.receiveAndReply)]] and return a `Future` to * receive the reply within a default timeout. 
*/ - def sendWithReply[T: ClassTag](message: Any): Future[T] + def sendWithReply[T: ClassTag](message: Any): Future[T] = sendWithReply(message, defaultTimeout) /** * Send a message to the corresponding [[RpcEndpoint.receiveAndReply)]] and return a `Future` to diff --git a/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala index 4d7039342ec3..c6bc9d1a91d2 100644 --- a/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala @@ -259,12 +259,7 @@ private[akka] class AkkaRpcEndpointRef( @transient _actorRef: => ActorRef, @transient conf: SparkConf, @transient initInConstructor: Boolean = true) - extends RpcEndpointRef with Serializable with Logging { - // `defaultAddress` and `conf` won't be used after initialization. So it's safe to be transient. - - private[this] val maxRetries = conf.getInt("spark.akka.num.retries", 3) - private[this] val retryWaitMs = conf.getLong("spark.akka.retry.wait", 3000) - private[this] val defaultTimeout = conf.getLong("spark.akka.lookupTimeout", 30) seconds + extends RpcEndpointRef(conf) with Logging { lazy val actorRef = _actorRef @@ -287,34 +282,6 @@ private[akka] class AkkaRpcEndpointRef( init() } - override def askWithReply[T: ClassTag](message: Any): T = askWithReply(message, defaultTimeout) - - override def askWithReply[T: ClassTag](message: Any, timeout: FiniteDuration): T = { - // TODO: Consider removing multiple attempts - var attempts = 0 - var lastException: Exception = null - while (attempts < maxRetries) { - attempts += 1 - try { - val future = sendWithReply[T](message, timeout) - val result = Await.result(future, timeout) - if (result == null) { - throw new SparkException("Actor returned null") - } - return result - } catch { - case ie: InterruptedException => throw ie - case e: Exception => - lastException = e - logWarning(s"Error sending message [message = $message] in $attempts attempts", e) - } - Thread.sleep(retryWaitMs) - } - - throw new SparkException( - s"Error sending message [message = $message]", lastException) - } - override def send(message: Any): Unit = { actorRef ! AkkaMessage(message, false) } @@ -330,11 +297,6 @@ private[akka] class AkkaRpcEndpointRef( actorRef ! AkkaMessage(message, true) } - - override def sendWithReply[T: ClassTag](message: Any): Future[T] = { - sendWithReply(message, defaultTimeout) - } - override def sendWithReply[T: ClassTag](message: Any, timeout: FiniteDuration): Future[T] = { import scala.concurrent.ExecutionContext.Implicits.global actorRef.ask(AkkaMessage(message, true))(timeout).flatMap { From ec7c5b06fb08ca6ddf8e561fbf2994242696f2dc Mon Sep 17 00:00:00 2001 From: zsxwing Date: Wed, 11 Mar 2015 15:22:22 +0800 Subject: [PATCH 20/31] Fix docs --- core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala index 680ca0ffb019..168c17b77e3f 100644 --- a/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala @@ -152,8 +152,7 @@ private[spark] object RpcEnv { /** * An end point for the RPC that defines what functions to trigger given a message. * - * RpcEndpoint will be guaranteed that `onStart`, `receive` and `onStop` will - * be called in sequence. 
+ * It is guaranteed that `onStart`, `receive` and `onStop` will be called in sequence. * * The lift-cycle will be: * @@ -245,7 +244,7 @@ private[spark] trait RpcEndpoint { } /** - * An convenient method to stop [[RpcEndpoint]]. + * A convenient method to stop [[RpcEndpoint]]. */ final def stop(): Unit = { val _self = self From e5df4ca8357c0c5f5f1dd46f26433d08d43fc252 Mon Sep 17 00:00:00 2001 From: zsxwing Date: Wed, 11 Mar 2015 15:22:44 +0800 Subject: [PATCH 21/31] Handle AkkaFailure(e) in Actor --- .../main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala index c6bc9d1a91d2..658056cb29de 100644 --- a/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala @@ -132,6 +132,12 @@ private[spark] class AkkaRpcEnv private ( safelyCall(endpoint) { processMessage(endpoint, m, sender) } + case AkkaFailure(e) => + try { + endpoint.onError(e) + } catch { + case NonFatal(e) => logError(s"Ignore error: ${e.getMessage}", e) + } case message: Any => { logWarning(s"Unknown message: $message") } From 08564ae75fbdf02897dd5bd452fc4dbcd44532e4 Mon Sep 17 00:00:00 2001 From: zsxwing Date: Wed, 11 Mar 2015 15:48:11 +0800 Subject: [PATCH 22/31] Add RpcEnvFactory to create RpcEnv --- .../scala/org/apache/spark/rpc/RpcEnv.scala | 36 +++++++++---------- .../apache/spark/rpc/akka/AkkaRpcEnv.scala | 9 +++-- .../spark/rpc/akka/AkkaRpcEnvSuite.scala | 7 ++-- 3 files changed, 25 insertions(+), 27 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala index 168c17b77e3f..a1ce7f82f58a 100644 --- a/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala @@ -111,27 +111,18 @@ private[spark] case class RpcEnvConfig( securityManager: SecurityManager) /** - * A RpcEnv implementation must have a companion object with an - * `apply(config: RpcEnvConfig): RpcEnv` method so that it can be created via Reflection. - * - * {{{ - * object MyCustomRpcEnv { - * def apply(config: RpcEnvConfig): RpcEnv = { - * ... - * } - * } - * }}} + * A RpcEnv implementation must have a [[RpcEnvFactory]] implementation with an empty constructor + * so that it can be created via Reflection. */ private[spark] object RpcEnv { - private def getRpcEnvCompanion(conf: SparkConf): AnyRef = { + private def getRpcEnvFactory(conf: SparkConf): RpcEnvFactory = { // Add more RpcEnv implementations here - val rpcEnvNames = Map("akka" -> "org.apache.spark.rpc.akka.AkkaRpcEnv") + val rpcEnvNames = Map("akka" -> "org.apache.spark.rpc.akka.AkkaRpcEnvFactory") val rpcEnvName = conf.get("spark.rpc", "akka") - val rpcEnvClassName = rpcEnvNames.getOrElse(rpcEnvName.toLowerCase, rpcEnvName) - val companion = Class.forName( - rpcEnvClassName + "$", true, Utils.getContextOrSparkClassLoader).getField("MODULE$").get(null) - companion + val rpcEnvFactoryClassName = rpcEnvNames.getOrElse(rpcEnvName.toLowerCase, rpcEnvName) + Class.forName(rpcEnvFactoryClassName, true, Utils.getContextOrSparkClassLoader). 
+ newInstance().asInstanceOf[RpcEnvFactory] } def create( @@ -142,13 +133,20 @@ private[spark] object RpcEnv { securityManager: SecurityManager): RpcEnv = { // Using Reflection to create the RpcEnv to avoid to depend on Akka directly val config = RpcEnvConfig(conf, name, host, port, securityManager) - val companion = getRpcEnvCompanion(conf) - companion.getClass.getMethod("apply", classOf[RpcEnvConfig]). - invoke(companion, config).asInstanceOf[RpcEnv] + getRpcEnvFactory(conf).create(config) } } +/** + * A factory class to create the [[RpcEnv]]. It must have an empty constructor so that it can be + * created using Reflection. + */ +private[spark] trait RpcEnvFactory { + + def create(config: RpcEnvConfig): RpcEnv +} + /** * An end point for the RPC that defines what functions to trigger given a message. * diff --git a/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala index 658056cb29de..97fc21c1535c 100644 --- a/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala @@ -43,7 +43,7 @@ import org.apache.spark.util.{ActorLogReceive, AkkaUtils} * @param conf * @param boundPort */ -private[spark] class AkkaRpcEnv private ( +private[spark] class AkkaRpcEnv private[akka] ( val actorSystem: ActorSystem, conf: SparkConf, boundPort: Int) extends RpcEnv with Logging { private val defaultAddress: RpcAddress = { @@ -250,14 +250,13 @@ private[spark] class AkkaRpcEnv private ( override def toString = s"${getClass.getSimpleName}($actorSystem)" } -private[spark] object AkkaRpcEnv { +private[spark] class AkkaRpcEnvFactory extends RpcEnvFactory { - def apply(config: RpcEnvConfig): RpcEnv = { + def create(config: RpcEnvConfig): RpcEnv = { val (actorSystem, boundPort) = AkkaUtils.createActorSystem( - config.name, config.host, config.port, config.conf, config.securityManager) + config.name, config.host, config.port, config.conf, config.securityManager) new AkkaRpcEnv(actorSystem, config.conf, boundPort) } - } private[akka] class AkkaRpcEndpointRef( diff --git a/core/src/test/scala/org/apache/spark/rpc/akka/AkkaRpcEnvSuite.scala b/core/src/test/scala/org/apache/spark/rpc/akka/AkkaRpcEnvSuite.scala index c6139a1ad89a..58214c063723 100644 --- a/core/src/test/scala/org/apache/spark/rpc/akka/AkkaRpcEnvSuite.scala +++ b/core/src/test/scala/org/apache/spark/rpc/akka/AkkaRpcEnvSuite.scala @@ -23,7 +23,8 @@ import org.apache.spark.{SecurityManager, SparkConf} class AkkaRpcEnvSuite extends RpcEnvSuite { override def createRpcEnv(conf: SparkConf, name: String, port: Int): RpcEnv = { - AkkaRpcEnv(RpcEnvConfig(conf, name, "localhost", port, new SecurityManager(conf))) + new AkkaRpcEnvFactory().create( + RpcEnvConfig(conf, name, "localhost", port, new SecurityManager(conf))) } test("setupEndpointRef: systemName, address, endpointName") { @@ -35,8 +36,8 @@ class AkkaRpcEnvSuite extends RpcEnvSuite { } }) val conf = new SparkConf() - val newRpcEnv = - AkkaRpcEnv(RpcEnvConfig(conf, "test", "localhost", 12346, new SecurityManager(conf))) + val newRpcEnv = new AkkaRpcEnvFactory().create( + RpcEnvConfig(conf, "test", "localhost", 12346, new SecurityManager(conf))) try { val newRef = newRpcEnv.setupEndpointRef("local", ref.address, "test_endpoint") assert("akka.tcp://local@localhost:12345/user/test_endpoint" === From e8dfec329050d5db778e4698dde724e99aba0136 Mon Sep 17 00:00:00 2001 From: zsxwing Date: Thu, 12 Mar 2015 09:56:11 +0800 Subject: [PATCH 23/31] Remove 
'sendWithReply(message: Any, sender: RpcEndpointRef): Unit' --- .../scala/org/apache/spark/rpc/RpcEnv.scala | 9 ----- .../apache/spark/rpc/akka/AkkaRpcEnv.scala | 11 ------- .../org/apache/spark/rpc/RpcEnvSuite.scala | 33 ------------------- 3 files changed, 53 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala index a1ce7f82f58a..88215b513eaa 100644 --- a/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala @@ -326,15 +326,6 @@ private[spark] abstract class RpcEndpointRef(@transient conf: SparkConf) */ def send(message: Any): Unit - /** - * Send a message to the corresponding [[RpcEndpoint.receiveAndReply]] asynchronously. - * Fire-and-forget semantics. - * - * The receiver will reply to sender's [[RpcEndpoint.receive]] or [[RpcEndpoint.receiveAndReply]] - * depending on which one of [[RpcCallContext.reply]]s is called. - */ - def sendWithReply(message: Any, sender: RpcEndpointRef): Unit - /** * Send a message to the corresponding [[RpcEndpoint.receiveAndReply)]] and return a `Future` to * receive the reply within a default timeout. diff --git a/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala index 97fc21c1535c..901114085529 100644 --- a/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala @@ -291,17 +291,6 @@ private[akka] class AkkaRpcEndpointRef( actorRef ! AkkaMessage(message, false) } - override def sendWithReply(message: Any, sender: RpcEndpointRef): Unit = { - implicit val actorSender: ActorRef = - if (sender == null) { - Actor.noSender - } else { - require(sender.isInstanceOf[AkkaRpcEndpointRef]) - sender.asInstanceOf[AkkaRpcEndpointRef].actorRef - } - actorRef ! 
AkkaMessage(message, true) - } - override def sendWithReply[T: ClassTag](message: Any, timeout: FiniteDuration): Future[T] = { import scala.concurrent.ExecutionContext.Implicits.global actorRef.ask(AkkaMessage(message, true))(timeout).flatMap { diff --git a/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala b/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala index 4ce409c1a47b..61ab2e43f830 100644 --- a/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala +++ b/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala @@ -171,39 +171,6 @@ abstract class RpcEnvSuite extends FunSuite with BeforeAndAfterAll { } } - test("ping pong") { - val pongRef = env.setupEndpoint("pong", new RpcEndpoint { - override val rpcEnv = env - - override def receiveAndReply(context: RpcCallContext) = { - case Ping(id) => context.sender.sendWithReply(Pong(id), self) - } - }) - - val pingRef = env.setupEndpoint("ping", new RpcEndpoint { - override val rpcEnv = env - - var requester: RpcCallContext = _ - - override def receiveAndReply(context: RpcCallContext) = { - case Start => { - requester = context - pongRef.sendWithReply(Ping(1), self) - } - case p @ Pong(id) => { - if (id < 10) { - context.sender.sendWithReply(Ping(id + 1), self) - } else { - requester.reply(p) - } - } - } - }) - - val reply = pingRef.askWithReply[Pong](Start) - assert(Pong(10) === reply) - } - test("onStart and onStop") { val stopLatch = new CountDownLatch(1) val calledMethods = mutable.ArrayBuffer[String]() From 2cc3f788b23d2ce3ba468bd603c70613481b1219 Mon Sep 17 00:00:00 2001 From: zsxwing Date: Tue, 17 Mar 2015 10:45:01 +0800 Subject: [PATCH 24/31] Add an asynchronous version of setupEndpointRefByUrl --- .../spark/deploy/worker/WorkerWatcher.scala | 3 +- .../scala/org/apache/spark/rpc/RpcEnv.scala | 31 ++++++++++++++++--- .../apache/spark/rpc/akka/AkkaRpcEnv.scala | 18 ++++------- 3 files changed, 33 insertions(+), 19 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/WorkerWatcher.scala b/core/src/main/scala/org/apache/spark/deploy/worker/WorkerWatcher.scala index 5aee74aa63ed..ee3b19f902e5 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/WorkerWatcher.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/WorkerWatcher.scala @@ -31,8 +31,7 @@ private[spark] class WorkerWatcher(override val rpcEnv: RpcEnv, workerUrl: Strin override def onStart() { logInfo(s"Connecting to worker $workerUrl") if (!isTesting) { - val worker = rpcEnv.setupEndpointRefByUrl(workerUrl) - worker.send(SendHeartbeat) // need to send a message here to initiate connection + rpcEnv.asyncSetupEndpointRefByUrl(workerUrl) } } diff --git a/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala index 88215b513eaa..114221919229 100644 --- a/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala @@ -25,7 +25,7 @@ import scala.language.postfixOps import scala.reflect.ClassTag import org.apache.spark.{Logging, SparkException, SecurityManager, SparkConf} -import org.apache.spark.util.Utils +import org.apache.spark.util.{AkkaUtils, Utils} /** * An RPC environment. [[RpcEndpoint]]s need to register itself with a name to [[RpcEnv]] to @@ -34,7 +34,9 @@ import org.apache.spark.util.Utils * * [[RpcEnv]] also provides some methods to retrieve [[RpcEndpointRef]]s given name or uri. 
*/ -private[spark] trait RpcEnv { +private[spark] abstract class RpcEnv(conf: SparkConf) { + + private[spark] val defaultLookupTimeout = AkkaUtils.lookupTimeout(conf) /** * Return RpcEndpointRef of the registered [[RpcEndpoint]]. Will be used to implement @@ -68,15 +70,34 @@ private[spark] trait RpcEnv { def setupThreadSafeEndpoint(name: String, endpoint: RpcEndpoint): RpcEndpointRef /** - * Retrieve the [[RpcEndpointRef]] represented by `url`. + * Retrieve the [[RpcEndpointRef]] represented by `url` asynchronously. */ - def setupEndpointRefByUrl(url: String): RpcEndpointRef + def asyncSetupEndpointRefByUrl(url: String): Future[RpcEndpointRef] + + /** + * Retrieve the [[RpcEndpointRef]] represented by `url`. This is a blocking action. + */ + def setupEndpointRefByUrl(url: String): RpcEndpointRef = { + Await.result(asyncSetupEndpointRefByUrl(url), defaultLookupTimeout) + } /** * Retrieve the [[RpcEndpointRef]] represented by `systemName`, `address` and `endpointName` + * asynchronously. + */ + def asyncSetupEndpointRef( + systemName: String, address: RpcAddress, endpointName: String): Future[RpcEndpointRef] = { + asyncSetupEndpointRefByUrl(uriOf(systemName, address, endpointName)) + } + + /** + * Retrieve the [[RpcEndpointRef]] represented by `systemName`, `address` and `endpointName`. + * This is a blocking action. */ def setupEndpointRef( - systemName: String, address: RpcAddress, endpointName: String): RpcEndpointRef + systemName: String, address: RpcAddress, endpointName: String): RpcEndpointRef = { + setupEndpointRefByUrl(uriOf(systemName, address, endpointName)) + } /** * Stop [[RpcEndpoint]] specified by `endpoint`. diff --git a/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala index 901114085529..dc83271c4753 100644 --- a/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala @@ -44,7 +44,8 @@ import org.apache.spark.util.{ActorLogReceive, AkkaUtils} * @param boundPort */ private[spark] class AkkaRpcEnv private[akka] ( - val actorSystem: ActorSystem, conf: SparkConf, boundPort: Int) extends RpcEnv with Logging { + val actorSystem: ActorSystem, conf: SparkConf, boundPort: Int) + extends RpcEnv(conf) with Logging { private val defaultAddress: RpcAddress = { val address = actorSystem.asInstanceOf[ExtendedActorSystem].provider.getDefaultAddress @@ -216,16 +217,10 @@ private[spark] class AkkaRpcEnv private[akka] ( address.port.getOrElse(defaultAddress.port)) } - override def setupEndpointRefByUrl(url: String): RpcEndpointRef = { - val timeout = AkkaUtils.lookupTimeout(conf) - val ref = Await.result(actorSystem.actorSelection(url).resolveOne(timeout), timeout) - // TODO defaultAddress is wrong - new AkkaRpcEndpointRef(defaultAddress, ref, conf) - } - - override def setupEndpointRef( - systemName: String, address: RpcAddress, endpointName: String): RpcEndpointRef = { - setupEndpointRefByUrl(uriOf(systemName, address, endpointName)) + override def asyncSetupEndpointRefByUrl(url: String): Future[RpcEndpointRef] = { + import scala.concurrent.ExecutionContext.Implicits.global + actorSystem.actorSelection(url).resolveOne(defaultLookupTimeout). 
+ map(new AkkaRpcEndpointRef(defaultAddress, _, conf)) } override def uriOf(systemName: String, address: RpcAddress, endpointName: String): String = { @@ -233,7 +228,6 @@ private[spark] class AkkaRpcEnv private[akka] ( AkkaUtils.protocol(actorSystem), systemName, address.host, address.port, endpointName) } - override def shutdown(): Unit = { actorSystem.shutdown() } From 385b9c3e34d39adcddccfbeaeea29cbb0419270a Mon Sep 17 00:00:00 2001 From: zsxwing Date: Tue, 24 Mar 2015 14:09:41 +0800 Subject: [PATCH 25/31] Fix the code style and add docs --- .../spark/deploy/worker/WorkerWatcher.scala | 2 +- .../main/scala/org/apache/spark/rpc/RpcEnv.scala | 8 +++++++- .../org/apache/spark/rpc/akka/AkkaRpcEnv.scala | 15 +++++++++------ .../scala/org/apache/spark/rpc/RpcEnvSuite.scala | 1 - 4 files changed, 17 insertions(+), 9 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/WorkerWatcher.scala b/core/src/main/scala/org/apache/spark/deploy/worker/WorkerWatcher.scala index ee3b19f902e5..3bf1dfd9b6c1 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/WorkerWatcher.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/WorkerWatcher.scala @@ -53,7 +53,7 @@ private[spark] class WorkerWatcher(override val rpcEnv: RpcEnv, workerUrl: Strin def exitNonZero() = if (isTesting) isShutDown = true else System.exit(-1) override def receive = { - case e => logWarning(s"Received unexpected actor system event: $e") + case e => logWarning(s"Received unexpected message: $e") } override def onConnected(remoteAddress: RpcAddress): Unit = { diff --git a/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala index 114221919229..e1b4e6dce41d 100644 --- a/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala @@ -293,6 +293,7 @@ private[spark] abstract class RpcEndpointRef(@transient conf: SparkConf) /** * Send a message to the corresponding [[RpcEndpoint]] and get its result within a default * timeout, or throw a SparkException if this fails even after the default number of retries. + * Because this method retries, the message handling in the receiver side should be idempotent. * * Note: this is a blocking action which may cost a lot of time, so don't call it in an message * loop of [[RpcEndpoint]]. @@ -306,7 +307,8 @@ private[spark] abstract class RpcEndpointRef(@transient conf: SparkConf) /** * Send a message to the corresponding [[RpcEndpoint.receive]] and get its result within a * specified timeout, throw a SparkException if this fails even after the specified number of - * retries. + * retries. Because this method retries, the message handling in the receiver side should be + * idempotent. * * Note: this is a blocking action which may cost a lot of time, so don't call it in an message * loop of [[RpcEndpoint]]. @@ -350,12 +352,16 @@ private[spark] abstract class RpcEndpointRef(@transient conf: SparkConf) /** * Send a message to the corresponding [[RpcEndpoint.receiveAndReply)]] and return a `Future` to * receive the reply within a default timeout. + * + * This method only sends the message once and never retries. */ def sendWithReply[T: ClassTag](message: Any): Future[T] = sendWithReply(message, defaultTimeout) /** * Send a message to the corresponding [[RpcEndpoint.receiveAndReply)]] and return a `Future` to * receive the reply within the specified timeout. + * + * This method only sends the message once and never retries. 
*/ def sendWithReply[T: ClassTag](message: Any, timeout: FiniteDuration): Future[T] diff --git a/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala index dc83271c4753..09dd065e0dab 100644 --- a/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala @@ -133,15 +133,18 @@ private[spark] class AkkaRpcEnv private[akka] ( safelyCall(endpoint) { processMessage(endpoint, m, sender) } + case AkkaFailure(e) => try { endpoint.onError(e) } catch { case NonFatal(e) => logError(s"Ignore error: ${e.getMessage}", e) } + case message: Any => { logWarning(s"Unknown message: $message") } + } override def postStop(): Unit = { @@ -160,9 +163,9 @@ private[spark] class AkkaRpcEnv private[akka] ( private def processMessage(endpoint: RpcEndpoint, m: AkkaMessage, _sender: ActorRef): Unit = { val message = m.message - val reply = m.reply + val needReply = m.needReply val pf = - if (reply) { + if (needReply) { endpoint.receiveAndReply(new RpcCallContext { override def sendFailure(e: Throwable): Unit = { _sender ! AkkaFailure(e) @@ -185,7 +188,7 @@ private[spark] class AkkaRpcEnv private[akka] ( } } catch { case NonFatal(e) => - if (reply) { + if (needReply) { // If the sender asks a reply, we should send the error back to the sender _sender ! AkkaFailure(e) } else { @@ -241,7 +244,7 @@ private[spark] class AkkaRpcEnv private[akka] ( actorSystem.awaitTermination() } - override def toString = s"${getClass.getSimpleName}($actorSystem)" + override def toString: String = s"${getClass.getSimpleName}($actorSystem)" } private[spark] class AkkaRpcEnvFactory extends RpcEnvFactory { @@ -308,9 +311,9 @@ private[akka] class AkkaRpcEndpointRef( /** * A wrapper to `message` so that the receiver knows if the sender expects a reply. 
* @param message - * @param reply if the sender expects a reply message + * @param needReply if the sender expects a reply message */ -private[akka] case class AkkaMessage(message: Any, reply: Boolean) +private[akka] case class AkkaMessage(message: Any, needReply: Boolean) /** * A reply with the failure error from the receiver to the sender diff --git a/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala b/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala index 61ab2e43f830..e07bdb963757 100644 --- a/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala +++ b/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala @@ -192,7 +192,6 @@ abstract class RpcEnvSuite extends FunSuite with BeforeAndAfterAll { } } val rpcEndpointRef = env.setupEndpoint("start-stop-test", endpoint) - rpcEndpointRef.send("message") env.stop(rpcEndpointRef) stopLatch.await(10, TimeUnit.SECONDS) assert(List("start", "stop") === calledMethods) From 9ffa9979cc08b0338411465f736f030a67122304 Mon Sep 17 00:00:00 2001 From: zsxwing Date: Tue, 24 Mar 2015 16:02:35 +0800 Subject: [PATCH 26/31] Fix MiMa tests --- .../apache/spark/scheduler/OutputCommitCoordinator.scala | 6 +++--- project/MimaExcludes.scala | 4 +++- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/scheduler/OutputCommitCoordinator.scala b/core/src/main/scala/org/apache/spark/scheduler/OutputCommitCoordinator.scala index 477e1b0b033f..f748f394d134 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/OutputCommitCoordinator.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/OutputCommitCoordinator.scala @@ -32,8 +32,8 @@ private case class AskPermissionToCommitOutput(stage: Int, task: Long, taskAttem * policy. * * OutputCommitCoordinator is instantiated in both the drivers and executors. On executors, it is - * configured with a reference to the driver's OutputCommitCoordinatorActor, so requests to commit - * output will be forwarded to the driver's OutputCommitCoordinator. + * configured with a reference to the driver's OutputCommitCoordinatorEndpoint, so requests to + * commit output will be forwarded to the driver's OutputCommitCoordinator. * * This class was introduced in SPARK-4879; see that JIRA issue (and the associated pull requests) * for an extensive design discussion. 
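
The idempotence caveat documented in the preceding patch is what makes this forwarding safe: askWithReply retries, so the same commit request may reach the driver more than once and must receive the same answer each time. Below is a minimal sketch of that request/reply shape on top of the RpcEndpoint and RpcEndpointRef API from this series; the AskPermissionToCommit message, CommitArbiterEndpoint class and canCommit helper are illustrative stand-ins, not the actual OutputCommitCoordinator code.

{{{
import scala.collection.mutable

import org.apache.spark.rpc.{RpcCallContext, RpcEndpoint, RpcEndpointRef, RpcEnv}

object CommitArbiterExample {

  // Hypothetical message type; the real coordinator uses its own case class.
  case class AskPermissionToCommit(stage: Int, partition: Long, attempt: Long)

  class CommitArbiterEndpoint(override val rpcEnv: RpcEnv) extends RpcEndpoint {

    // Which attempt has been authorized for each (stage, partition). Assumes the endpoint
    // is registered so that messages are handled one at a time (e.g. via
    // setupThreadSafeEndpoint), so no extra synchronization is added here.
    private val authorized = mutable.Map.empty[(Int, Long), Long]

    override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
      case AskPermissionToCommit(stage, partition, attempt) =>
        // Idempotent: a retried request from the same attempt gets the same answer,
        // which the retry behaviour of askWithReply requires.
        val winner = authorized.getOrElseUpdate((stage, partition), attempt)
        context.reply(winner == attempt)
    }
  }

  // Executor side: block on the driver's decision using the default retry/timeout policy.
  def canCommit(driverRef: RpcEndpointRef, stage: Int, partition: Long, attempt: Long): Boolean =
    driverRef.askWithReply[Boolean](AskPermissionToCommit(stage, partition, attempt))
}
}}}
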
@@ -152,7 +152,7 @@ private[spark] class OutputCommitCoordinator(conf: SparkConf) extends Logging { private[spark] object OutputCommitCoordinator { // This actor is used only for RPC - class OutputCommitCoordinatorEndpoint( + private[spark] class OutputCommitCoordinatorEndpoint( override val rpcEnv: RpcEnv, outputCommitCoordinator: OutputCommitCoordinator) extends RpcEndpoint with Logging { diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala index 56f5dbe53fad..36b224daa72a 100644 --- a/project/MimaExcludes.scala +++ b/project/MimaExcludes.scala @@ -50,7 +50,9 @@ object MimaExcludes { ProblemFilters.exclude[IncompatibleResultTypeProblem]( "org.apache.spark.broadcast.HttpBroadcastFactory.newBroadcast"), ProblemFilters.exclude[IncompatibleResultTypeProblem]( - "org.apache.spark.broadcast.TorrentBroadcastFactory.newBroadcast") + "org.apache.spark.broadcast.TorrentBroadcastFactory.newBroadcast"), + ProblemFilters.exclude[MissingClassProblem]( + "org.apache.spark.scheduler.OutputCommitCoordinator$OutputCommitCoordinatorActor") ) case v if v.startsWith("1.3") => From b221398cacee911e61f60b2117c22a83b7e7fac3 Mon Sep 17 00:00:00 2001 From: zsxwing Date: Mon, 30 Mar 2015 09:11:42 +0800 Subject: [PATCH 27/31] Move send methods above ask methods --- .../scala/org/apache/spark/rpc/RpcEnv.scala | 42 +++++++++---------- 1 file changed, 21 insertions(+), 21 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala index e1b4e6dce41d..870d527a8643 100644 --- a/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala @@ -290,6 +290,27 @@ private[spark] abstract class RpcEndpointRef(@transient conf: SparkConf) def name: String + /** + * Sends a one-way asynchronous message. Fire-and-forget semantics. + */ + def send(message: Any): Unit + + /** + * Send a message to the corresponding [[RpcEndpoint.receiveAndReply)]] and return a `Future` to + * receive the reply within a default timeout. + * + * This method only sends the message once and never retries. + */ + def sendWithReply[T: ClassTag](message: Any): Future[T] = sendWithReply(message, defaultTimeout) + + /** + * Send a message to the corresponding [[RpcEndpoint.receiveAndReply)]] and return a `Future` to + * receive the reply within the specified timeout. + * + * This method only sends the message once and never retries. + */ + def sendWithReply[T: ClassTag](message: Any, timeout: FiniteDuration): Future[T] + /** * Send a message to the corresponding [[RpcEndpoint]] and get its result within a default * timeout, or throw a SparkException if this fails even after the default number of retries. @@ -344,27 +365,6 @@ private[spark] abstract class RpcEndpointRef(@transient conf: SparkConf) s"Error sending message [message = $message]", lastException) } - /** - * Sends a one-way asynchronous message. Fire-and-forget semantics. - */ - def send(message: Any): Unit - - /** - * Send a message to the corresponding [[RpcEndpoint.receiveAndReply)]] and return a `Future` to - * receive the reply within a default timeout. - * - * This method only sends the message once and never retries. - */ - def sendWithReply[T: ClassTag](message: Any): Future[T] = sendWithReply(message, defaultTimeout) - - /** - * Send a message to the corresponding [[RpcEndpoint.receiveAndReply)]] and return a `Future` to - * receive the reply within the specified timeout. - * - * This method only sends the message once and never retries. 
- */ - def sendWithReply[T: ClassTag](message: Any, timeout: FiniteDuration): Future[T] - def toURI: URI } From f459380683d3e69d4f0967529227aab4782d4269 Mon Sep 17 00:00:00 2001 From: zsxwing Date: Mon, 30 Mar 2015 09:15:09 +0800 Subject: [PATCH 28/31] Add RpcAddress.fromURI and rename urls to uris --- .../spark/deploy/worker/WorkerWatcher.scala | 8 +++---- .../scala/org/apache/spark/rpc/RpcEnv.scala | 24 ++++++++++++------- .../apache/spark/rpc/akka/AkkaRpcEnv.scala | 4 ++-- 3 files changed, 20 insertions(+), 16 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/WorkerWatcher.scala b/core/src/main/scala/org/apache/spark/deploy/worker/WorkerWatcher.scala index 9584571d25b0..83fb991891a4 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/WorkerWatcher.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/WorkerWatcher.scala @@ -31,7 +31,7 @@ private[spark] class WorkerWatcher(override val rpcEnv: RpcEnv, workerUrl: Strin override def onStart() { logInfo(s"Connecting to worker $workerUrl") if (!isTesting) { - rpcEnv.asyncSetupEndpointRefByUrl(workerUrl) + rpcEnv.asyncSetupEndpointRefByURI(workerUrl) } } @@ -45,10 +45,8 @@ private[spark] class WorkerWatcher(override val rpcEnv: RpcEnv, workerUrl: Strin private var isTesting = false // Lets us filter events only from the worker's actor system - private val expectedHostPort = new java.net.URI(workerUrl) - private def isWorker(address: RpcAddress) = { - expectedHostPort.getHost == address.host && expectedHostPort.getPort == address.port - } + private val expectedAddress = RpcAddress.fromURIString(workerUrl) + private def isWorker(address: RpcAddress) = expectedAddress == address private def exitNonZero() = if (isTesting) isShutDown = true else System.exit(-1) diff --git a/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala index 870d527a8643..ad9a2888d4c3 100644 --- a/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala @@ -70,15 +70,15 @@ private[spark] abstract class RpcEnv(conf: SparkConf) { def setupThreadSafeEndpoint(name: String, endpoint: RpcEndpoint): RpcEndpointRef /** - * Retrieve the [[RpcEndpointRef]] represented by `url` asynchronously. + * Retrieve the [[RpcEndpointRef]] represented by `uri` asynchronously. */ - def asyncSetupEndpointRefByUrl(url: String): Future[RpcEndpointRef] + def asyncSetupEndpointRefByURI(uri: String): Future[RpcEndpointRef] /** - * Retrieve the [[RpcEndpointRef]] represented by `url`. This is a blocking action. + * Retrieve the [[RpcEndpointRef]] represented by `uri`. This is a blocking action. 
*/ - def setupEndpointRefByUrl(url: String): RpcEndpointRef = { - Await.result(asyncSetupEndpointRefByUrl(url), defaultLookupTimeout) + def setupEndpointRefByURI(uri: String): RpcEndpointRef = { + Await.result(asyncSetupEndpointRefByURI(uri), defaultLookupTimeout) } /** @@ -87,7 +87,7 @@ private[spark] abstract class RpcEnv(conf: SparkConf) { */ def asyncSetupEndpointRef( systemName: String, address: RpcAddress, endpointName: String): Future[RpcEndpointRef] = { - asyncSetupEndpointRefByUrl(uriOf(systemName, address, endpointName)) + asyncSetupEndpointRefByURI(uriOf(systemName, address, endpointName)) } /** @@ -96,7 +96,7 @@ private[spark] abstract class RpcEnv(conf: SparkConf) { */ def setupEndpointRef( systemName: String, address: RpcAddress, endpointName: String): RpcEndpointRef = { - setupEndpointRefByUrl(uriOf(systemName, address, endpointName)) + setupEndpointRefByURI(uriOf(systemName, address, endpointName)) } /** @@ -381,12 +381,18 @@ private[spark] case class RpcAddress(host: String, port: Int) { private[spark] object RpcAddress { + /** + * Return the [[RpcAddress]] represented by `uri`. + */ + def fromURI(uri: URI): RpcAddress = { + RpcAddress(uri.getHost, uri.getPort) + } + /** * Return the [[RpcAddress]] represented by `uri`. */ def fromURIString(uri: String): RpcAddress = { - val u = new java.net.URI(uri) - RpcAddress(u.getHost, u.getPort) + fromURI(new java.net.URI(uri)) } def fromSparkURL(sparkUrl: String): RpcAddress = { diff --git a/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala index 09dd065e0dab..2684f023c933 100644 --- a/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala @@ -220,9 +220,9 @@ private[spark] class AkkaRpcEnv private[akka] ( address.port.getOrElse(defaultAddress.port)) } - override def asyncSetupEndpointRefByUrl(url: String): Future[RpcEndpointRef] = { + override def asyncSetupEndpointRefByURI(uri: String): Future[RpcEndpointRef] = { import scala.concurrent.ExecutionContext.Implicits.global - actorSystem.actorSelection(url).resolveOne(defaultLookupTimeout). + actorSystem.actorSelection(uri).resolveOne(defaultLookupTimeout). map(new AkkaRpcEndpointRef(defaultAddress, _, conf)) } From 8bd10973798f80a30d7bf37ff3fdf40a953e5da7 Mon Sep 17 00:00:00 2001 From: zsxwing Date: Mon, 30 Mar 2015 10:00:28 +0800 Subject: [PATCH 29/31] Fix docs and the code style --- .../main/scala/org/apache/spark/rpc/RpcEnv.scala | 16 +++++++++++----- .../org/apache/spark/rpc/akka/AkkaRpcEnv.scala | 16 +++++++--------- 2 files changed, 18 insertions(+), 14 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala index ad9a2888d4c3..3511a1086207 100644 --- a/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala @@ -41,6 +41,9 @@ private[spark] abstract class RpcEnv(conf: SparkConf) { /** * Return RpcEndpointRef of the registered [[RpcEndpoint]]. Will be used to implement * [[RpcEndpoint.self]]. + * + * Note: This method won't return null. `IllegalArgumentException` will be thrown if calling this + * on a non-existent endpoint. 
*/ private[rpc] def endpointRef(endpoint: RpcEndpoint): RpcEndpointRef @@ -203,14 +206,16 @@ private[spark] trait RpcEndpoint { } /** - * Process messages from [[RpcEndpointRef.send]] or [[RpcCallContext.reply)]] + * Process messages from [[RpcEndpointRef.send]] or [[RpcCallContext.reply)]]. If receiving a + * unmatched message, [[SparkException]] will be thrown and sent to `onError`. */ def receive: PartialFunction[Any, Unit] = { case _ => throw new SparkException(self + " does not implement 'receive'") } /** - * Process messages from [[RpcEndpointRef.sendWithReply]] + * Process messages from [[RpcEndpointRef.sendWithReply]]. If receiving a unmatched message, + * [[SparkException]] will be thrown and sent to `onError`. */ def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { case _ => context.sendFailure(new SparkException(self + " won't reply anything")) @@ -314,7 +319,8 @@ private[spark] abstract class RpcEndpointRef(@transient conf: SparkConf) /** * Send a message to the corresponding [[RpcEndpoint]] and get its result within a default * timeout, or throw a SparkException if this fails even after the default number of retries. - * Because this method retries, the message handling in the receiver side should be idempotent. + * The default `timeout` will be used in every trial of calling `sendWithReply`. Because this + * method retries, the message handling in the receiver side should be idempotent. * * Note: this is a blocking action which may cost a lot of time, so don't call it in an message * loop of [[RpcEndpoint]]. @@ -328,8 +334,8 @@ private[spark] abstract class RpcEndpointRef(@transient conf: SparkConf) /** * Send a message to the corresponding [[RpcEndpoint.receive]] and get its result within a * specified timeout, throw a SparkException if this fails even after the specified number of - * retries. Because this method retries, the message handling in the receiver side should be - * idempotent. + * retries. `timeout` will be used in every trial of calling `sendWithReply`. Because this method + * retries, the message handling in the receiver side should be idempotent. * * Note: this is a blocking action which may cost a lot of time, so don't call it in an message * loop of [[RpcEndpoint]]. diff --git a/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala index 2684f023c933..f1a01df3f797 100644 --- a/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala @@ -98,7 +98,7 @@ private[spark] class AkkaRpcEnv private[akka] ( // So `actorRef` should be created after assigning `endpointRef`. 
lazy val actorRef = actorSystem.actorOf(Props(new Actor with ActorLogReceive with Logging { - require(endpointRef != null) + assert(endpointRef != null) registerEndpoint(endpoint, endpointRef) override def preStart(): Unit = { @@ -135,10 +135,8 @@ private[spark] class AkkaRpcEnv private[akka] ( } case AkkaFailure(e) => - try { - endpoint.onError(e) - } catch { - case NonFatal(e) => logError(s"Ignore error: ${e.getMessage}", e) + safelyCall(endpoint) { + throw e } case message: Any => { @@ -164,7 +162,7 @@ private[spark] class AkkaRpcEnv private[akka] ( private def processMessage(endpoint: RpcEndpoint, m: AkkaMessage, _sender: ActorRef): Unit = { val message = m.message val needReply = m.needReply - val pf = + val pf: PartialFunction[Any, Unit] = if (needReply) { endpoint.receiveAndReply(new RpcCallContext { override def sendFailure(e: Throwable): Unit = { @@ -183,9 +181,9 @@ private[spark] class AkkaRpcEnv private[akka] ( endpoint.receive } try { - if (pf.isDefinedAt(message)) { - pf.apply(message) - } + pf.applyOrElse[Any, Unit](message, { message => + throw new SparkException(s"Unmatched message $message from ${_sender}") + }) } catch { case NonFatal(e) => if (needReply) { From f6f3287092097f40d8f159321a55b5164cb10968 Mon Sep 17 00:00:00 2001 From: zsxwing Date: Mon, 30 Mar 2015 10:14:47 +0800 Subject: [PATCH 30/31] Remove RpcEndpointRef.toURI --- core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala | 1 - core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala | 1 - 2 files changed, 2 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala index 3511a1086207..7985941d949c 100644 --- a/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala @@ -371,7 +371,6 @@ private[spark] abstract class RpcEndpointRef(@transient conf: SparkConf) s"Error sending message [message = $message]", lastException) } - def toURI: URI } /** diff --git a/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala index f1a01df3f797..a5b2597f72bb 100644 --- a/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala @@ -303,7 +303,6 @@ private[akka] class AkkaRpcEndpointRef( override def toString: String = s"${getClass.getSimpleName}($actorRef)" - override def toURI: URI = new URI(actorRef.path.toString) } /** From fe3df4cbd9efa052803f0c3d12544874b649728b Mon Sep 17 00:00:00 2001 From: zsxwing Date: Mon, 30 Mar 2015 10:39:20 +0800 Subject: [PATCH 31/31] Move registerEndpoint and use actorSystem.dispatcher in asyncSetupEndpointRefByURI --- .../src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala index a5b2597f72bb..769d59b7b334 100644 --- a/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala +++ b/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala @@ -99,7 +99,6 @@ private[spark] class AkkaRpcEnv private[akka] ( lazy val actorRef = actorSystem.actorOf(Props(new Actor with ActorLogReceive with Logging { assert(endpointRef != null) - registerEndpoint(endpoint, endpointRef) override def preStart(): Unit = { // Listen for remote client network events @@ -154,6 +153,7 @@ private[spark] class AkkaRpcEnv private[akka] 
( }), name = name) endpointRef = new AkkaRpcEndpointRef(defaultAddress, actorRef, conf, initInConstructor = false) + registerEndpoint(endpoint, endpointRef) // Now actorRef can be created safely endpointRef.init() endpointRef @@ -219,7 +219,7 @@ private[spark] class AkkaRpcEnv private[akka] ( } override def asyncSetupEndpointRefByURI(uri: String): Future[RpcEndpointRef] = { - import scala.concurrent.ExecutionContext.Implicits.global + import actorSystem.dispatcher actorSystem.actorSelection(uri).resolveOne(defaultLookupTimeout). map(new AkkaRpcEndpointRef(defaultAddress, _, conf)) }
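
Taken together, these patches leave a compact, transport-agnostic surface for callers. The following is a minimal end-to-end usage sketch of that surface, assuming it compiles inside the spark-core module (the org.apache.spark.rpc package is private[spark]); the Echo message and the "echo" endpoint are hypothetical examples, not part of the patches.

{{{
import scala.concurrent.Await
import scala.concurrent.duration._

import org.apache.spark.{SecurityManager, SparkConf}
import org.apache.spark.rpc.{RpcCallContext, RpcEndpoint, RpcEnv}

object RpcEnvExample {

  case class Echo(msg: String)

  def main(args: Array[String]): Unit = {
    val conf = new SparkConf()
    // RpcEnv.create picks the implementation named by "spark.rpc" ("akka" by default)
    // through the RpcEnvFactory reflection mechanism introduced in this series.
    val env = RpcEnv.create("example", "localhost", 0, conf, new SecurityManager(conf))

    val echoRef = env.setupEndpoint("echo", new RpcEndpoint {
      override val rpcEnv: RpcEnv = env

      // Fire-and-forget messages from RpcEndpointRef.send land here.
      override def receive: PartialFunction[Any, Unit] = {
        case Echo(msg) => println(s"received $msg")
      }

      // Ask-style messages land here; reply (or sendFailure) through the context.
      override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
        case Echo(msg) => context.reply(Echo(msg.reverse))
      }
    })

    echoRef.send(Echo("one-way"))                             // fire-and-forget
    val future = echoRef.sendWithReply[Echo](Echo("async"))   // Future-based, never retried
    println(Await.result(future, 10.seconds))
    println(echoRef.askWithReply[Echo](Echo("blocking")))     // blocking, retried on failure

    env.stop(echoRef)
    env.shutdown()
    env.awaitTermination()
  }
}
}}}
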