diff --git a/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala b/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala index 83ae57b7f151..da7bec3e5c7e 100644 --- a/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala +++ b/core/src/main/scala/org/apache/spark/HeartbeatReceiver.scala @@ -17,11 +17,10 @@ package org.apache.spark -import akka.actor.Actor import org.apache.spark.executor.TaskMetrics -import org.apache.spark.storage.BlockManagerId +import org.apache.spark.rpc.{RpcEnv, RpcEndpointRef, RpcEndpoint} import org.apache.spark.scheduler.TaskScheduler -import org.apache.spark.util.ActorLogReceive +import org.apache.spark.storage.BlockManagerId /** * A heartbeat from executors to the driver. This is a shared message used by several internal @@ -37,13 +36,13 @@ private[spark] case class HeartbeatResponse(reregisterBlockManager: Boolean) /** * Lives in the driver to receive heartbeats from executors.. */ -private[spark] class HeartbeatReceiver(scheduler: TaskScheduler) - extends Actor with ActorLogReceive with Logging { +private[spark] class HeartbeatReceiver(override val rpcEnv: RpcEnv, scheduler: TaskScheduler) + extends RpcEndpoint { - override def receiveWithLogging = { + override def receive(sender: RpcEndpointRef) = { case Heartbeat(executorId, taskMetrics, blockManagerId) => val response = HeartbeatResponse( !scheduler.executorHeartbeatReceived(executorId, taskMetrics, blockManagerId)) - sender ! response + sender.send(response) } } diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala index 6e4edc7c80d7..df1c88802bf8 100644 --- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala +++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala @@ -21,13 +21,10 @@ import java.io._ import java.util.concurrent.ConcurrentHashMap import java.util.zip.{GZIPInputStream, GZIPOutputStream} -import scala.collection.mutable.{HashSet, HashMap, Map} -import scala.concurrent.Await +import scala.collection.mutable.{HashSet, Map} import scala.collection.JavaConversions._ -import akka.actor._ -import akka.pattern.ask - +import org.apache.spark.rpc.{RpcEnv, RpcEndpointRef, RpcEndpoint} import org.apache.spark.scheduler.MapStatus import org.apache.spark.shuffle.MetadataFetchFailedException import org.apache.spark.storage.BlockManagerId @@ -39,14 +36,14 @@ private[spark] case class GetMapOutputStatuses(shuffleId: Int) private[spark] case object StopMapOutputTracker extends MapOutputTrackerMessage /** Actor class for MapOutputTrackerMaster */ -private[spark] class MapOutputTrackerMasterActor(tracker: MapOutputTrackerMaster, conf: SparkConf) - extends Actor with ActorLogReceive with Logging { +private[spark] class MapOutputTrackerMasterActor(override val rpcEnv: RpcEnv, + tracker: MapOutputTrackerMaster, conf: SparkConf) extends RpcEndpoint with Logging { val maxAkkaFrameSize = AkkaUtils.maxFrameSizeBytes(conf) - override def receiveWithLogging = { + override def receive(sender: RpcEndpointRef) = { case GetMapOutputStatuses(shuffleId: Int) => - val hostPort = sender.path.address.hostPort - logInfo("Asked to send map output locations for shuffle " + shuffleId + " to " + hostPort) + logInfo( + "Asked to send map output locations for shuffle " + shuffleId + " to " + sender) val mapOutputStatuses = tracker.getSerializedMapOutputStatuses(shuffleId) val serializedSize = mapOutputStatuses.size if (serializedSize > maxAkkaFrameSize) { @@ -60,12 +57,12 @@ private[spark] class 
MapOutputTrackerMasterActor(tracker: MapOutputTrackerMaster logError(msg, exception) throw exception } - sender ! mapOutputStatuses + sender.send(mapOutputStatuses) case StopMapOutputTracker => logInfo("MapOutputTrackerActor stopped!") - sender ! true - context.stop(self) + sender.send(true) + stop() } } @@ -75,12 +72,9 @@ private[spark] class MapOutputTrackerMasterActor(tracker: MapOutputTrackerMaster * (driver and executor) use different HashMap to store its metadata. */ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging { - private val timeout = AkkaUtils.askTimeout(conf) - private val retryAttempts = AkkaUtils.numRetries(conf) - private val retryIntervalMs = AkkaUtils.retryWaitMs(conf) /** Set to the MapOutputTrackerActor living on the driver. */ - var trackerActor: ActorRef = _ + var trackerActor: RpcEndpointRef = _ /** * This HashMap has different behavior for the driver and the executors. @@ -108,9 +102,9 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging * Send a message to the trackerActor and get its result within a default timeout, or * throw a SparkException if this fails. */ - protected def askTracker(message: Any): Any = { + protected def askTracker[T](message: Any): T = { try { - AkkaUtils.askWithReply(message, trackerActor, retryAttempts, retryIntervalMs, timeout) + trackerActor.askWithReply(message) } catch { case e: Exception => logError("Error communicating with MapOutputTracker", e) @@ -120,7 +114,7 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging /** Send a one-way message to the trackerActor, to which we expect it to reply with true. */ protected def sendTracker(message: Any) { - val response = askTracker(message) + val response = askTracker[Boolean](message) if (response != true) { throw new SparkException( "Error reply received from MapOutputTracker. 
Expecting true, got " + response.toString) @@ -160,8 +154,7 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging logInfo("Doing the fetch; tracker actor = " + trackerActor) // This try-finally prevents hangs due to timeouts: try { - val fetchedBytes = - askTracker(GetMapOutputStatuses(shuffleId)).asInstanceOf[Array[Byte]] + val fetchedBytes = askTracker[Array[Byte]](GetMapOutputStatuses(shuffleId)) fetchedStatuses = MapOutputTracker.deserializeMapStatuses(fetchedBytes) logInfo("Got the output locations") mapStatuses.put(shuffleId, fetchedStatuses) diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index ff5d796ee276..06581d553f7d 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -36,7 +36,6 @@ import org.apache.hadoop.mapred.{FileInputFormat, InputFormat, JobConf, Sequence import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat, Job => NewHadoopJob} import org.apache.hadoop.mapreduce.lib.input.{FileInputFormat => NewFileInputFormat} import org.apache.mesos.MesosNativeLibrary -import akka.actor.Props import org.apache.spark.annotation.{DeveloperApi, Experimental} import org.apache.spark.broadcast.Broadcast @@ -323,8 +322,10 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli // Create and start the scheduler private[spark] var (schedulerBackend, taskScheduler) = SparkContext.createTaskScheduler(this, master) - private val heartbeatReceiver = env.actorSystem.actorOf( - Props(new HeartbeatReceiver(taskScheduler)), "HeartbeatReceiver") + + private val heartbeatReceiver = env.rpcEnv.setupEndpoint("HeartbeatReceiver", + new HeartbeatReceiver(env.rpcEnv, taskScheduler)) + @volatile private[spark] var dagScheduler: DAGScheduler = _ try { dagScheduler = new DAGScheduler(this) @@ -413,9 +414,8 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli Some(Utils.getThreadDump()) } else { val (host, port) = env.blockManager.master.getActorSystemHostPortForExecutor(executorId).get - val actorRef = AkkaUtils.makeExecutorRef("ExecutorActor", conf, host, port, env.actorSystem) - Some(AkkaUtils.askWithReply[Array[ThreadStackTrace]](TriggerThreadDump, actorRef, - AkkaUtils.numRetries(conf), AkkaUtils.retryWaitMs(conf), AkkaUtils.askTimeout(conf))) + val endpointRef = env.rpcEnv.setupDriverEndpointRef("ExecutorActor") + Some(endpointRef.askWithReply[Array[ThreadStackTrace]](TriggerThreadDump)) } } catch { case e: Exception => @@ -1214,7 +1214,7 @@ class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationCli if (dagSchedulerCopy != null) { env.metricsSystem.report() metadataCleaner.cancel() - env.actorSystem.stop(heartbeatReceiver) + env.rpcEnv.stop(heartbeatReceiver) cleaner.foreach(_.stop()) dagSchedulerCopy.stop() taskScheduler = null diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala index 4d418037bd33..e06fb759104b 100644 --- a/core/src/main/scala/org/apache/spark/SparkEnv.scala +++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala @@ -24,7 +24,6 @@ import scala.collection.JavaConversions._ import scala.collection.mutable import scala.util.Properties -import akka.actor._ import com.google.common.collect.MapMaker import org.apache.spark.annotation.DeveloperApi @@ -34,11 +33,13 @@ import org.apache.spark.metrics.MetricsSystem import 
org.apache.spark.network.BlockTransferService import org.apache.spark.network.netty.NettyBlockTransferService import org.apache.spark.network.nio.NioBlockTransferService +import org.apache.spark.rpc.akka.AkkaRpcEnv +import org.apache.spark.rpc.{RpcEndpointRef, RpcEndpoint, RpcEnv} import org.apache.spark.scheduler.LiveListenerBus import org.apache.spark.serializer.Serializer import org.apache.spark.shuffle.{ShuffleMemoryManager, ShuffleManager} import org.apache.spark.storage._ -import org.apache.spark.util.{AkkaUtils, Utils} +import org.apache.spark.util.Utils /** * :: DeveloperApi :: @@ -53,7 +54,7 @@ import org.apache.spark.util.{AkkaUtils, Utils} @DeveloperApi class SparkEnv ( val executorId: String, - val actorSystem: ActorSystem, + val rpcEnv: RpcEnv, val serializer: Serializer, val closureSerializer: Serializer, val cacheManager: CacheManager, @@ -69,6 +70,9 @@ class SparkEnv ( val shuffleMemoryManager: ShuffleMemoryManager, val conf: SparkConf) extends Logging { + // TODO actorSystem is used by Streaming + val actorSystem = rpcEnv.asInstanceOf[AkkaRpcEnv].actorSystem + private[spark] var isStopped = false private val pythonWorkers = mutable.HashMap[(String, Map[String, String]), PythonWorkerFactory]() @@ -86,7 +90,7 @@ class SparkEnv ( blockManager.stop() blockManager.master.stop() metricsSystem.stop() - actorSystem.shutdown() + rpcEnv.stopAll() // Unfortunately Akka's awaitTermination doesn't actually wait for the Netty server to shut // down, but let's call it anyway in case it gets fixed in a later release // UPDATE: In Akka 2.1.x, this hangs if there are remote actors, so we can't call it. @@ -212,16 +216,14 @@ object SparkEnv extends Logging { val securityManager = new SecurityManager(conf) // Create the ActorSystem for Akka and get the port it binds to. - val (actorSystem, boundPort) = { - val actorSystemName = if (isDriver) driverActorSystemName else executorActorSystemName - AkkaUtils.createActorSystem(actorSystemName, hostname, port, conf, securityManager) - } + val actorSystemName = if (isDriver) driverActorSystemName else executorActorSystemName + val rpcEnv = RpcEnv.create(actorSystemName, hostname, port, conf, securityManager) // Figure out which port Akka actually bound to in case the original port is 0 or occupied. 
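
For reference, a minimal sketch of the endpoint lifecycle this patch introduces, using only the calls that appear in the diff (RpcEnv.create, setupEndpoint, receive(sender), send, askWithReply, stopAll). The EchoEndpoint class, the Echo message, the "localhost" host and the package name are illustrative assumptions, not part of the patch; the package is placed under org.apache.spark only so the private[spark] RPC classes would be visible.

package org.apache.spark.rpc.example

import org.apache.spark.{SecurityManager, SparkConf}
import org.apache.spark.rpc.{RpcEndpoint, RpcEndpointRef, RpcEnv}

// Hypothetical request message for the example.
case class Echo(text: String)

// New style: extend RpcEndpoint instead of Actor with ActorLogReceive, and take
// the RpcEnv through the constructor.
class EchoEndpoint(override val rpcEnv: RpcEnv) extends RpcEndpoint {
  // receive is handed the sender's RpcEndpointRef explicitly, replacing Akka's
  // implicit `sender` and `sender ! reply`.
  override def receive(sender: RpcEndpointRef) = {
    case Echo(text) => sender.send(text.toUpperCase)
  }
}

object EchoExample {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf(false)
    // RpcEnv.create replaces AkkaUtils.createActorSystem; the bound port is read
    // from rpcEnv.boundPort instead of the (actorSystem, boundPort) tuple.
    val rpcEnv = RpcEnv.create("echo", "localhost", 0, conf, new SecurityManager(conf))
    val ref = rpcEnv.setupEndpoint("echo-endpoint", new EchoEndpoint(rpcEnv))
    // askWithReply blocks for the response, replacing ask + Await.result with an
    // explicit timeout at every call site.
    println(ref.askWithReply[String](Echo("hello")))
    rpcEnv.stopAll()
  }
}
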
if (isDriver) { - conf.set("spark.driver.port", boundPort.toString) + conf.set("spark.driver.port", rpcEnv.boundPort.toString) } else { - conf.set("spark.executor.port", boundPort.toString) + conf.set("spark.executor.port", rpcEnv.boundPort.toString) } // Create an instance of the class with the given name, possibly initializing it with our conf @@ -257,12 +259,12 @@ object SparkEnv extends Logging { val closureSerializer = instantiateClassFromConf[Serializer]( "spark.closure.serializer", "org.apache.spark.serializer.JavaSerializer") - def registerOrLookup(name: String, newActor: => Actor): ActorRef = { + def registerOrLookup(name: String, endpointCreator: => RpcEndpoint): RpcEndpointRef = { if (isDriver) { logInfo("Registering " + name) - actorSystem.actorOf(Props(newActor), name = name) + rpcEnv.setupEndpoint(name, endpointCreator) } else { - AkkaUtils.makeDriverRef(name, conf, actorSystem) + rpcEnv.setupDriverEndpointRef(name) } } @@ -274,9 +276,9 @@ object SparkEnv extends Logging { // Have to assign trackerActor after initialization as MapOutputTrackerActor // requires the MapOutputTracker itself - mapOutputTracker.trackerActor = registerOrLookup( - "MapOutputTracker", - new MapOutputTrackerMasterActor(mapOutputTracker.asInstanceOf[MapOutputTrackerMaster], conf)) + mapOutputTracker.trackerActor = registerOrLookup("MapOutputTracker", + new MapOutputTrackerMasterActor( + rpcEnv, mapOutputTracker.asInstanceOf[MapOutputTrackerMaster], conf)) // Let the user specify short names for shuffle managers val shortShuffleMgrNames = Map( @@ -298,10 +300,10 @@ object SparkEnv extends Logging { val blockManagerMaster = new BlockManagerMaster(registerOrLookup( "BlockManagerMaster", - new BlockManagerMasterActor(isLocal, conf, listenerBus)), conf, isDriver) + new BlockManagerMasterActor(rpcEnv, isLocal, conf, listenerBus)), conf, isDriver) // NB: blockManager is not valid until initialize() is called later. - val blockManager = new BlockManager(executorId, actorSystem, blockManagerMaster, + val blockManager = new BlockManager(executorId, rpcEnv, blockManagerMaster, serializer, conf, mapOutputTracker, shuffleManager, blockTransferService, securityManager, numUsableCores) @@ -348,7 +350,7 @@ object SparkEnv extends Logging { new SparkEnv( executorId, - actorSystem, + rpcEnv, serializer, closureSerializer, cacheManager, diff --git a/core/src/main/scala/org/apache/spark/deploy/Client.scala b/core/src/main/scala/org/apache/spark/deploy/Client.scala index 7c1c831c248f..c803701fe9d6 100644 --- a/core/src/main/scala/org/apache/spark/deploy/Client.scala +++ b/core/src/main/scala/org/apache/spark/deploy/Client.scala @@ -17,31 +17,24 @@ package org.apache.spark.deploy -import scala.concurrent._ - -import akka.actor._ -import akka.pattern.ask -import akka.remote.{AssociationErrorEvent, DisassociatedEvent, RemotingLifecycleEvent} import org.apache.log4j.{Level, Logger} import org.apache.spark.{Logging, SecurityManager, SparkConf} import org.apache.spark.deploy.DeployMessages._ import org.apache.spark.deploy.master.{DriverState, Master} -import org.apache.spark.util.{ActorLogReceive, AkkaUtils, Utils} +import org.apache.spark.rpc.{RpcAddress, RpcEnv, RpcEndpointRef, NetworkRpcEndpoint} +import org.apache.spark.util.Utils /** * Proxy that relays messages to the driver. 
*/ -private class ClientActor(driverArgs: ClientArguments, conf: SparkConf) - extends Actor with ActorLogReceive with Logging { - - var masterActor: ActorSelection = _ - val timeout = AkkaUtils.askTimeout(conf) +private class ClientActor(override val rpcEnv: RpcEnv, driverArgs: ClientArguments, conf: SparkConf) + extends NetworkRpcEndpoint with Logging { - override def preStart() = { - masterActor = context.actorSelection(Master.toAkkaUrl(driverArgs.master)) + var masterActor: RpcEndpointRef = _ - context.system.eventStream.subscribe(self, classOf[RemotingLifecycleEvent]) + override def onStart() = { + masterActor = Master.toEndpointRef(rpcEnv, driverArgs.master) println(s"Sending ${driverArgs.cmd} command to ${driverArgs.master}") @@ -77,11 +70,11 @@ private class ClientActor(driverArgs: ClientArguments, conf: SparkConf) driverArgs.supervise, command) - masterActor ! RequestSubmitDriver(driverDescription) + masterActor.send(RequestSubmitDriver(driverDescription)) case "kill" => val driverId = driverArgs.driverId - masterActor ! RequestKillDriver(driverId) + masterActor.send(RequestKillDriver(driverId)) } } @@ -90,9 +83,8 @@ private class ClientActor(driverArgs: ClientArguments, conf: SparkConf) println(s"... waiting before polling master for driver state") Thread.sleep(5000) println("... polling master for driver state") - val statusFuture = (masterActor ? RequestDriverStatus(driverId))(timeout) - .mapTo[DriverStatusResponse] - val statusResponse = Await.result(statusFuture, timeout) + val statusResponse = + masterActor.askWithReply[DriverStatusResponse](RequestDriverStatus(driverId)) statusResponse.found match { case false => @@ -116,8 +108,7 @@ private class ClientActor(driverArgs: ClientArguments, conf: SparkConf) } } - override def receiveWithLogging = { - + override def receive(sender: RpcEndpointRef) = { case SubmitDriverResponse(success, driverId, message) => println(message) if (success) pollAndReportStatus(driverId.get) else System.exit(-1) @@ -126,14 +117,17 @@ private class ClientActor(driverArgs: ClientArguments, conf: SparkConf) println(message) if (success) pollAndReportStatus(driverId) else System.exit(-1) - case DisassociatedEvent(_, remoteAddress, _) => - println(s"Error connecting to master ${driverArgs.master} ($remoteAddress), exiting.") - System.exit(-1) + } + + override def onDisconnected(remoteAddress: RpcAddress): Unit = { + println(s"Error connecting to master ${driverArgs.master} ($remoteAddress), exiting.") + System.exit(-1) + } - case AssociationErrorEvent(cause, _, remoteAddress, _, _) => - println(s"Error connecting to master ${driverArgs.master} ($remoteAddress), exiting.") - println(s"Cause was: $cause") - System.exit(-1) + override def onNetworkError(cause: Throwable, remoteAddress: RpcAddress): Unit = { + println(s"Error connecting to master ${driverArgs.master} ($remoteAddress), exiting.") + println(s"Cause was: $cause") + System.exit(-1) } } @@ -157,13 +151,13 @@ object Client { conf.set("akka.loglevel", driverArgs.logLevel.toString.replace("WARN", "WARNING")) Logger.getRootLogger.setLevel(driverArgs.logLevel) - val (actorSystem, _) = AkkaUtils.createActorSystem( + val rpcEnv = RpcEnv.create( "driverClient", Utils.localHostName(), 0, conf, new SecurityManager(conf)) // Verify driverArgs.master is a valid url so that we can use it in ClientActor safely - Master.toAkkaUrl(driverArgs.master) - actorSystem.actorOf(Props(classOf[ClientActor], driverArgs, conf)) + Utils.extractHostPortFromSparkUrl(driverArgs.master) + rpcEnv.setupEndpoint("client-actor", new 
ClientActor(rpcEnv, driverArgs, conf)) - actorSystem.awaitTermination() + rpcEnv.awaitTermination() } } diff --git a/core/src/main/scala/org/apache/spark/deploy/LocalSparkCluster.scala b/core/src/main/scala/org/apache/spark/deploy/LocalSparkCluster.scala index 9a7a113c9571..f23ae9a65a62 100644 --- a/core/src/main/scala/org/apache/spark/deploy/LocalSparkCluster.scala +++ b/core/src/main/scala/org/apache/spark/deploy/LocalSparkCluster.scala @@ -19,11 +19,10 @@ package org.apache.spark.deploy import scala.collection.mutable.ArrayBuffer -import akka.actor.ActorSystem - import org.apache.spark.{Logging, SparkConf} import org.apache.spark.deploy.worker.Worker import org.apache.spark.deploy.master.Master +import org.apache.spark.rpc.RpcEnv import org.apache.spark.util.Utils /** @@ -37,22 +36,22 @@ class LocalSparkCluster(numWorkers: Int, coresPerWorker: Int, memoryPerWorker: I extends Logging { private val localHostname = Utils.localHostName() - private val masterActorSystems = ArrayBuffer[ActorSystem]() - private val workerActorSystems = ArrayBuffer[ActorSystem]() + private val masterActorSystems = ArrayBuffer[RpcEnv]() + private val workerActorSystems = ArrayBuffer[RpcEnv]() def start(): Array[String] = { logInfo("Starting a local Spark cluster with " + numWorkers + " workers.") /* Start the Master */ val conf = new SparkConf(false) - val (masterSystem, masterPort, _) = Master.startSystemAndActor(localHostname, 0, 0, conf) + val (masterSystem, _) = Master.startSystemAndActor(localHostname, 0, 0, conf) masterActorSystems += masterSystem - val masterUrl = "spark://" + localHostname + ":" + masterPort + val masterUrl = "spark://" + localHostname + ":" + masterSystem.boundPort val masters = Array(masterUrl) /* Start the Workers */ for (workerNum <- 1 to numWorkers) { - val (workerSystem, _) = Worker.startSystemAndActor(localHostname, 0, 0, coresPerWorker, + val workerSystem = Worker.startSystemAndActor(localHostname, 0, 0, coresPerWorker, memoryPerWorker, masters, null, Some(workerNum)) workerActorSystems += workerSystem } @@ -65,9 +64,9 @@ class LocalSparkCluster(numWorkers: Int, coresPerWorker: Int, memoryPerWorker: I // Stop the workers before the master so they don't get upset that it disconnected // TODO: In Akka 2.1.x, ActorSystem.awaitTermination hangs when you have remote actors! // This is unfortunate, but for now we just comment it out. 
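
As a companion to the ClientActor change above, a hedged sketch of the NetworkRpcEndpoint callback shape that replaces subscribing to Akka's RemotingLifecycleEvent stream. Only onStart, receive, onDisconnected, onNetworkError and Master.toEndpointRef come from the patch; the MasterWatcher name, its log messages and the package are hypothetical.

package org.apache.spark.deploy.example

import org.apache.spark.Logging
import org.apache.spark.deploy.master.Master
import org.apache.spark.rpc.{NetworkRpcEndpoint, RpcAddress, RpcEndpointRef, RpcEnv}

class MasterWatcher(override val rpcEnv: RpcEnv, masterUrl: String)
  extends NetworkRpcEndpoint with Logging {

  private var master: RpcEndpointRef = _

  // onStart replaces Actor.preStart; there is no event-stream subscription to set up.
  override def onStart(): Unit = {
    master = Master.toEndpointRef(rpcEnv, masterUrl)
  }

  override def receive(sender: RpcEndpointRef) = {
    case msg => logInfo(s"Received $msg from $sender")
  }

  // Replaces the `case DisassociatedEvent(...)` branch of receiveWithLogging.
  override def onDisconnected(remoteAddress: RpcAddress): Unit = {
    logWarning(s"Lost connection to master at $remoteAddress")
  }

  // Replaces the `case AssociationErrorEvent(...)` branch.
  override def onNetworkError(cause: Throwable, remoteAddress: RpcAddress): Unit = {
    logWarning(s"Network error while talking to $remoteAddress", cause)
  }
}
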
- workerActorSystems.foreach(_.shutdown()) + workerActorSystems.foreach(_.stopAll()) // workerActorSystems.foreach(_.awaitTermination()) - masterActorSystems.foreach(_.shutdown()) + masterActorSystems.foreach(_.stopAll()) // masterActorSystems.foreach(_.awaitTermination()) masterActorSystems.clear() workerActorSystems.clear() diff --git a/core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala b/core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala index 39a7b0319b6a..d9db5bf6b232 100644 --- a/core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala +++ b/core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala @@ -17,105 +17,99 @@ package org.apache.spark.deploy.client -import java.util.concurrent.TimeoutException +import java.util.concurrent.{ScheduledFuture, TimeUnit, Executors, TimeoutException} -import scala.concurrent.Await import scala.concurrent.duration._ -import akka.actor._ -import akka.pattern.ask -import akka.remote.{AssociationErrorEvent, DisassociatedEvent, RemotingLifecycleEvent} - import org.apache.spark.{Logging, SparkConf} import org.apache.spark.deploy.{ApplicationDescription, ExecutorState} import org.apache.spark.deploy.DeployMessages._ import org.apache.spark.deploy.master.Master -import org.apache.spark.util.{ActorLogReceive, Utils, AkkaUtils} +import org.apache.spark.rpc._ +import org.apache.spark.util.Utils /** * Interface allowing applications to speak with a Spark deploy cluster. Takes a master URL, * an app description, and a listener for cluster events, and calls back the listener when various * events occur. - * - * @param masterUrls Each url should look like spark://host:port. */ private[spark] class AppClient( - actorSystem: ActorSystem, - masterUrls: Array[String], + rpcEnv: RpcEnv, + masterAddresses: Set[RpcAddress], appDescription: ApplicationDescription, listener: AppClientListener, conf: SparkConf) extends Logging { - val masterAkkaUrls = masterUrls.map(Master.toAkkaUrl) + val REGISTRATION_TIMEOUT = 20.seconds.toMillis - val REGISTRATION_TIMEOUT = 20.seconds val REGISTRATION_RETRIES = 3 - var masterAddress: Address = null - var actor: ActorRef = null + var masterAddress: RpcAddress = null + var actor: RpcEndpointRef = null var appId: String = null var registered = false - var activeMasterUrl: String = null - class ClientActor extends Actor with ActorLogReceive with Logging { - var master: ActorSelection = null + class ClientActor(override val rpcEnv: RpcEnv) extends NetworkRpcEndpoint with Logging { + var master: RpcEndpointRef = null var alreadyDisconnected = false // To avoid calling listener.disconnected() multiple times var alreadyDead = false // To avoid calling listener.dead() multiple times - var registrationRetryTimer: Option[Cancellable] = None + var registrationRetryTimer: Option[ScheduledFuture[_]] = None - override def preStart() { - context.system.eventStream.subscribe(self, classOf[RemotingLifecycleEvent]) + private val scheduler = + Executors.newScheduledThreadPool(1, Utils.namedThreadFactory("client-actor")) + + override def onStart() { try { registerWithMaster() } catch { case e: Exception => logWarning("Failed to connect to master", e) markDisconnected() - context.stop(self) + stop() } } def tryRegisterAllMasters() { - for (masterAkkaUrl <- masterAkkaUrls) { - logInfo("Connecting to master " + masterAkkaUrl + "...") - val actor = context.actorSelection(masterAkkaUrl) - actor ! 
RegisterApplication(appDescription) + for (masterAddress <- masterAddresses) { + logInfo("Connecting to master " + masterAddress + "...") + val actor = Master.toEndpointRef(rpcEnv, masterAddress) + actor.send(RegisterApplication(appDescription)) } } def registerWithMaster() { tryRegisterAllMasters() - import context.dispatcher var retries = 0 registrationRetryTimer = Some { - context.system.scheduler.schedule(REGISTRATION_TIMEOUT, REGISTRATION_TIMEOUT) { - Utils.tryOrExit { - retries += 1 - if (registered) { - registrationRetryTimer.foreach(_.cancel()) - } else if (retries >= REGISTRATION_RETRIES) { - markDead("All masters are unresponsive! Giving up.") - } else { - tryRegisterAllMasters() + scheduler.scheduleAtFixedRate(new Runnable { + override def run(): Unit = { + Utils.tryOrExit { + retries += 1 + if (registered) { + registrationRetryTimer.foreach(_.cancel(true)) + } else if (retries >= REGISTRATION_RETRIES) { + markDead("All masters are unresponsive! Giving up.") + } else { + tryRegisterAllMasters() + } } } - } + }, REGISTRATION_TIMEOUT, REGISTRATION_TIMEOUT, TimeUnit.MILLISECONDS) } } def changeMaster(url: String) { - // activeMasterUrl is a valid Spark url since we receive it from master. - activeMasterUrl = url - master = context.actorSelection(Master.toAkkaUrl(activeMasterUrl)) - masterAddress = Master.toAkkaAddress(activeMasterUrl) + // url is a valid Spark url since we receive it from master. + master = Master.toEndpointRef(rpcEnv, url) + masterAddress = master.address } - private def isPossibleMaster(remoteUrl: Address) = { - masterAkkaUrls.map(AddressFromURIString(_).hostPort).contains(remoteUrl.hostPort) + private def isPossibleMaster(remoteUrl: RpcAddress) = { + masterAddresses.contains(remoteUrl) } - override def receiveWithLogging = { + override def receive(sender: RpcEndpointRef) = { case RegisteredApplication(appId_, masterUrl) => appId = appId_ registered = true @@ -124,13 +118,13 @@ private[spark] class AppClient( case ApplicationRemoved(message) => markDead("Master removed our application: %s".format(message)) - context.stop(self) + stop() case ExecutorAdded(id: Int, workerId: String, hostPort: String, cores: Int, memory: Int) => val fullId = appId + "/" + id logInfo("Executor added: %s on %s (%s) with %d cores".format(fullId, workerId, hostPort, cores)) - master ! ExecutorStateChanged(appId, id, ExecutorState.RUNNING, None, None) + master.send(ExecutorStateChanged(appId, id, ExecutorState.RUNNING, None, None)) listener.executorAdded(fullId, workerId, hostPort, cores, memory) case ExecutorUpdated(id, state, message, exitStatus) => @@ -145,19 +139,25 @@ private[spark] class AppClient( logInfo("Master has changed, new master is at " + masterUrl) changeMaster(masterUrl) alreadyDisconnected = false - sender ! MasterChangeAcknowledged(appId) + sender.send(MasterChangeAcknowledged(appId)) + + case StopAppClient => + markDead("Application has been stopped.") + sender.send(true) + stop() + } - case DisassociatedEvent(_, address, _) if address == masterAddress => + override def onDisconnected(address: RpcAddress): Unit = { + if (address == masterAddress) { logWarning(s"Connection to $address failed; waiting for master to reconnect...") markDisconnected() + } + } - case AssociationErrorEvent(cause, _, address, _, _) if isPossibleMaster(address) => - logWarning(s"Could not connect to $address: $cause") - - case StopAppClient => - markDead("Application has been stopped.") - sender ! 
true - context.stop(self) + override def onNetworkError(cause: Throwable, remoteAddress: RpcAddress): Unit = { + if (isPossibleMaster(remoteAddress)) { + logWarning(s"Could not connect to $remoteAddress: $cause") + } } /** @@ -177,23 +177,21 @@ private[spark] class AppClient( } } - override def postStop() { - registrationRetryTimer.foreach(_.cancel()) + override def onStop() { + registrationRetryTimer.foreach(_.cancel(true)) } } def start() { // Just launch an actor; it will call back into the listener. - actor = actorSystem.actorOf(Props(new ClientActor)) + actor = rpcEnv.setupEndpoint("client-actor", new ClientActor(rpcEnv)) } def stop() { if (actor != null) { try { - val timeout = AkkaUtils.askTimeout(conf) - val future = actor.ask(StopAppClient)(timeout) - Await.result(future, timeout) + actor.askWithReply(StopAppClient) } catch { case e: TimeoutException => logInfo("Stop request to Master timed out; it may already be shut down.") diff --git a/core/src/main/scala/org/apache/spark/deploy/client/TestClient.scala b/core/src/main/scala/org/apache/spark/deploy/client/TestClient.scala index 88a0862b96af..165cc1d196bf 100644 --- a/core/src/main/scala/org/apache/spark/deploy/client/TestClient.scala +++ b/core/src/main/scala/org/apache/spark/deploy/client/TestClient.scala @@ -17,9 +17,10 @@ package org.apache.spark.deploy.client +import org.apache.spark.rpc.{RpcAddress, RpcEnv} import org.apache.spark.{SecurityManager, SparkConf, Logging} import org.apache.spark.deploy.{ApplicationDescription, Command} -import org.apache.spark.util.{AkkaUtils, Utils} +import org.apache.spark.util.Utils private[spark] object TestClient { @@ -46,13 +47,14 @@ private[spark] object TestClient { def main(args: Array[String]) { val url = args(0) val conf = new SparkConf - val (actorSystem, _) = AkkaUtils.createActorSystem("spark", Utils.localIpAddress, 0, + val rpcEnv = RpcEnv.create("spark", Utils.localIpAddress, 0, conf = conf, securityManager = new SecurityManager(conf)) val desc = new ApplicationDescription("TestClient", Some(1), 512, Command("spark.deploy.client.TestExecutor", Seq(), Map(), Seq(), Seq(), Seq()), "ignored") val listener = new TestListener - val client = new AppClient(actorSystem, Array(url), desc, listener, new SparkConf) + val client = + new AppClient(rpcEnv, Set(RpcAddress.fromSparkURL(url)), desc, listener, new SparkConf) client.start() - actorSystem.awaitTermination() + rpcEnv.awaitTermination() } } diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ApplicationInfo.scala b/core/src/main/scala/org/apache/spark/deploy/master/ApplicationInfo.scala index ede0a9dbefb8..12c0d94001b9 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ApplicationInfo.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ApplicationInfo.scala @@ -22,10 +22,8 @@ import java.util.Date import scala.collection.mutable import scala.collection.mutable.ArrayBuffer -import akka.actor.ActorRef - -import org.apache.spark.annotation.DeveloperApi import org.apache.spark.deploy.ApplicationDescription +import org.apache.spark.rpc.RpcEndpointRef import org.apache.spark.util.Utils private[spark] class ApplicationInfo( @@ -33,7 +31,7 @@ private[spark] class ApplicationInfo( val id: String, val desc: ApplicationDescription, val submitDate: Date, - val driver: ActorRef, + val driver: RpcEndpointRef, defaultCores: Int) extends Serializable { diff --git a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala index 
d92d99310a58..04d0c9b0ccf2 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala @@ -20,17 +20,12 @@ package org.apache.spark.deploy.master import java.io.FileNotFoundException import java.net.URLEncoder import java.text.SimpleDateFormat +import java.util.concurrent.{ScheduledFuture, TimeUnit, Executors} import java.util.Date import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet} -import scala.concurrent.Await -import scala.concurrent.duration._ -import scala.language.postfixOps import scala.util.Random -import akka.actor._ -import akka.pattern.ask -import akka.remote.{DisassociatedEvent, RemotingLifecycleEvent} import akka.serialization.Serialization import akka.serialization.SerializationExtension import org.apache.hadoop.fs.Path @@ -44,18 +39,25 @@ import org.apache.spark.deploy.master.DriverState.DriverState import org.apache.spark.deploy.master.MasterMessages._ import org.apache.spark.deploy.master.ui.MasterWebUI import org.apache.spark.metrics.MetricsSystem +import org.apache.spark.rpc.akka.AkkaRpcEnv +import org.apache.spark.rpc.{RpcAddress, NetworkRpcEndpoint, RpcEnv, RpcEndpointRef} import org.apache.spark.scheduler.{EventLoggingListener, ReplayListenerBus} import org.apache.spark.ui.SparkUI -import org.apache.spark.util.{ActorLogReceive, AkkaUtils, SignalLogger, Utils} +import org.apache.spark.util.{SignalLogger, Utils} private[spark] class Master( + override val rpcEnv: RpcEnv, host: String, port: Int, webUiPort: Int, val securityMgr: SecurityManager) - extends Actor with ActorLogReceive with Logging with LeaderElectable { + extends NetworkRpcEndpoint with Logging with LeaderElectable { - import context.dispatcher // to use Akka's scheduler.schedule() + val scheduler = Executors.newScheduledThreadPool(1, + Utils.namedThreadFactory("check-worker-timeout")) + + // TODO hide the actor system + private def internalActorSystem = rpcEnv.asInstanceOf[AkkaRpcEnv].actorSystem val conf = new SparkConf val hadoopConf = SparkHadoopUtil.get.newConfiguration(conf) @@ -69,12 +71,12 @@ private[spark] class Master( val workers = new HashSet[WorkerInfo] val idToWorker = new HashMap[String, WorkerInfo] - val addressToWorker = new HashMap[Address, WorkerInfo] + val addressToWorker = new HashMap[RpcAddress, WorkerInfo] val apps = new HashSet[ApplicationInfo] val idToApp = new HashMap[String, ApplicationInfo] - val actorToApp = new HashMap[ActorRef, ApplicationInfo] - val addressToApp = new HashMap[Address, ApplicationInfo] + val actorToApp = new HashMap[RpcEndpointRef, ApplicationInfo] + val addressToApp = new HashMap[RpcAddress, ApplicationInfo] val waitingApps = new ArrayBuffer[ApplicationInfo] val completedApps = new ArrayBuffer[ApplicationInfo] var nextAppNumber = 0 @@ -108,7 +110,7 @@ private[spark] class Master( var leaderElectionAgent: LeaderElectionAgent = _ - private var recoveryCompletionTask: Cancellable = _ + private var recoveryCompletionTask: ScheduledFuture[_] = _ // As a temporary workaround before better ways of configuring memory, we allow users to set // a flag that will perform round-robin scheduling across the nodes (spreading out each app @@ -121,14 +123,15 @@ private[spark] class Master( throw new SparkException("spark.deploy.defaultCores must be positive") } - override def preStart() { + override def onStart() { logInfo("Starting Spark master at " + masterUrl) logInfo(s"Running Spark version ${org.apache.spark.SPARK_VERSION}") // Listen for remote client disconnection events, 
since they don't go through Akka's watch() - context.system.eventStream.subscribe(self, classOf[RemotingLifecycleEvent]) webUi.bind() masterWebUiUrl = "http://" + masterPublicAddress + ":" + webUi.boundPort - context.system.scheduler.schedule(0 millis, WORKER_TIMEOUT millis, self, CheckForWorkerTimeOut) + scheduler.scheduleAtFixedRate(new Runnable { + override def run(): Unit = self.send(CheckForWorkerTimeOut) + }, 0, WORKER_TIMEOUT, TimeUnit.MILLISECONDS) masterMetricsSystem.registerSource(masterSource) masterMetricsSystem.start() @@ -142,16 +145,16 @@ private[spark] class Master( case "ZOOKEEPER" => logInfo("Persisting recovery state to ZooKeeper") val zkFactory = - new ZooKeeperRecoveryModeFactory(conf, SerializationExtension(context.system)) + new ZooKeeperRecoveryModeFactory(conf, SerializationExtension(internalActorSystem)) (zkFactory.createPersistenceEngine(), zkFactory.createLeaderElectionAgent(this)) case "FILESYSTEM" => val fsFactory = - new FileSystemRecoveryModeFactory(conf, SerializationExtension(context.system)) + new FileSystemRecoveryModeFactory(conf, SerializationExtension(internalActorSystem)) (fsFactory.createPersistenceEngine(), fsFactory.createLeaderElectionAgent(this)) case "CUSTOM" => val clazz = Class.forName(conf.get("spark.deploy.recoveryMode.factory")) val factory = clazz.getConstructor(conf.getClass, Serialization.getClass) - .newInstance(conf, SerializationExtension(context.system)) + .newInstance(conf, SerializationExtension(internalActorSystem)) .asInstanceOf[StandaloneRecoveryModeFactory] (factory.createPersistenceEngine(), factory.createLeaderElectionAgent(this)) case _ => @@ -161,17 +164,17 @@ private[spark] class Master( leaderElectionAgent = leaderElectionAgent_ } - override def preRestart(reason: Throwable, message: Option[Any]) { - super.preRestart(reason, message) // calls postStop()! - logError("Master actor restarted due to exception", reason) + override def onError(reason: Throwable) { + logError("Master actor is crashed due to exception", reason) + throw reason // throw it so that the master will be restarted } - override def postStop() { + override def onStop() { masterMetricsSystem.report() applicationMetricsSystem.report() // prevent the CompleteRecovery message sending to restarted master if (recoveryCompletionTask != null) { - recoveryCompletionTask.cancel() + recoveryCompletionTask.cancel(true) } webUi.stop() masterMetricsSystem.stop() @@ -181,14 +184,14 @@ private[spark] class Master( } override def electedLeader() { - self ! ElectedLeader + self.send(ElectedLeader) } override def revokedLeadership() { - self ! RevokedLeadership + self.send(RevokedLeadership) } - override def receiveWithLogging = { + override def receive(sender: RpcEndpointRef) = { case ElectedLeader => { val (storedApps, storedDrivers, storedWorkers) = persistenceEngine.readPersistedData() state = if (storedApps.isEmpty && storedDrivers.isEmpty && storedWorkers.isEmpty) { @@ -199,8 +202,9 @@ private[spark] class Master( logInfo("I have been elected leader! 
New state: " + state) if (state == RecoveryState.RECOVERING) { beginRecovery(storedApps, storedDrivers, storedWorkers) - recoveryCompletionTask = context.system.scheduler.scheduleOnce(WORKER_TIMEOUT millis, self, - CompleteRecovery) + recoveryCompletionTask = scheduler.schedule(new Runnable { + override def run(): Unit = self.send(CompleteRecovery) + }, WORKER_TIMEOUT, TimeUnit.MILLISECONDS) } } @@ -218,20 +222,20 @@ private[spark] class Master( if (state == RecoveryState.STANDBY) { // ignore, don't send response } else if (idToWorker.contains(id)) { - sender ! RegisterWorkerFailed("Duplicate worker ID") + sender.send(RegisterWorkerFailed("Duplicate worker ID")) } else { val worker = new WorkerInfo(id, workerHost, workerPort, cores, memory, sender, workerUiPort, publicAddress) if (registerWorker(worker)) { persistenceEngine.addWorker(worker) - sender ! RegisteredWorker(masterUrl, masterWebUiUrl) + sender.send(RegisteredWorker(masterUrl, masterWebUiUrl)) schedule() } else { - val workerAddress = worker.actor.path.address + val workerAddress = worker.actor.address logWarning("Worker registration failed. Attempted to re-register worker at same " + "address: " + workerAddress) - sender ! RegisterWorkerFailed("Attempted to re-register worker at same address: " - + workerAddress) + sender.send(RegisterWorkerFailed("Attempted to re-register worker at same address: " + + workerAddress)) } } } @@ -239,7 +243,7 @@ private[spark] class Master( case RequestSubmitDriver(description) => { if (state != RecoveryState.ALIVE) { val msg = s"Can only accept driver submissions in ALIVE state. Current state: $state." - sender ! SubmitDriverResponse(false, None, msg) + sender.send(SubmitDriverResponse(false, None, msg)) } else { logInfo("Driver submitted " + description.command.mainClass) val driver = createDriver(description) @@ -251,15 +255,15 @@ private[spark] class Master( // TODO: It might be good to instead have the submission client poll the master to determine // the current status of the driver. For now it's simply "fire and forget". - sender ! SubmitDriverResponse(true, Some(driver.id), - s"Driver successfully submitted as ${driver.id}") + sender.send(SubmitDriverResponse(true, Some(driver.id), + s"Driver successfully submitted as ${driver.id}")) } } case RequestKillDriver(driverId) => { if (state != RecoveryState.ALIVE) { val msg = s"Can only kill drivers in ALIVE state. Current state: $state." - sender ! KillDriverResponse(driverId, success = false, msg) + sender.send(KillDriverResponse(driverId, success = false, msg)) } else { logInfo("Asked to kill driver " + driverId) val driver = drivers.find(_.id == driverId) @@ -267,23 +271,23 @@ private[spark] class Master( case Some(d) => if (waitingDrivers.contains(d)) { waitingDrivers -= d - self ! DriverStateChanged(driverId, DriverState.KILLED, None) + self.send(DriverStateChanged(driverId, DriverState.KILLED, None)) } else { // We just notify the worker to kill the driver here. The final bookkeeping occurs // on the return path when the worker submits a state change back to the master // to notify it that the driver was successfully killed. d.worker.foreach { w => - w.actor ! KillDriver(driverId) + w.actor.send(KillDriver(driverId)) } } // TODO: It would be nice for this to be a synchronous response val msg = s"Kill request for $driverId submitted" logInfo(msg) - sender ! 
KillDriverResponse(driverId, success = true, msg) + sender.send(KillDriverResponse(driverId, success = true, msg)) case None => val msg = s"Driver $driverId has already finished or does not exist" logWarning(msg) - sender ! KillDriverResponse(driverId, success = false, msg) + sender.send(KillDriverResponse(driverId, success = false, msg)) } } } @@ -291,10 +295,10 @@ private[spark] class Master( case RequestDriverStatus(driverId) => { (drivers ++ completedDrivers).find(_.id == driverId) match { case Some(driver) => - sender ! DriverStatusResponse(found = true, Some(driver.state), - driver.worker.map(_.id), driver.worker.map(_.hostPort), driver.exception) + sender.send(DriverStatusResponse(found = true, Some(driver.state), + driver.worker.map(_.id), driver.worker.map(_.hostPort), driver.exception)) case None => - sender ! DriverStatusResponse(found = false, None, None, None, None) + sender.send(DriverStatusResponse(found = false, None, None, None, None)) } } @@ -307,7 +311,7 @@ private[spark] class Master( registerApplication(app) logInfo("Registered app " + description.name + " with ID " + app.id) persistenceEngine.addApplication(app) - sender ! RegisteredApplication(app.id, masterUrl) + sender.send(RegisteredApplication(app.id, masterUrl)) schedule() } } @@ -319,7 +323,7 @@ private[spark] class Master( val appInfo = idToApp(appId) exec.state = state if (state == ExecutorState.RUNNING) { appInfo.resetRetryCount() } - exec.application.driver ! ExecutorUpdated(execId, state, message, exitStatus) + exec.application.driver.send(ExecutorUpdated(execId, state, message, exitStatus)) if (ExecutorState.isFinished(state)) { // Remove this executor from the worker and app logInfo(s"Removing executor ${exec.fullId} because it is $state") @@ -364,7 +368,7 @@ private[spark] class Master( if (workers.map(_.id).contains(workerId)) { logWarning(s"Got heartbeat from unregistered worker $workerId." + " Asking it to re-register.") - sender ! ReconnectWorker(masterUrl) + sender.send(ReconnectWorker(masterUrl)) } else { logWarning(s"Got heartbeat from unregistered worker $workerId." + " This worker was never registered, so ignoring the heartbeat.") @@ -412,17 +416,9 @@ private[spark] class Master( if (canCompleteRecovery) { completeRecovery() } } - case DisassociatedEvent(_, address, _) => { - // The disconnected client could've been either a worker or an app; remove whichever it was - logInfo(s"$address got disassociated, removing it.") - addressToWorker.get(address).foreach(removeWorker) - addressToApp.get(address).foreach(finishApplication) - if (state == RecoveryState.RECOVERING && canCompleteRecovery) { completeRecovery() } - } - case RequestMasterState => { - sender ! MasterStateResponse(host, port, workers.toArray, apps.toArray, completedApps.toArray, - drivers.toArray, completedDrivers.toArray, state) + sender.send(MasterStateResponse(host, port, workers.toArray, apps.toArray, + completedApps.toArray, drivers.toArray, completedDrivers.toArray, state)) } case CheckForWorkerTimeOut => { @@ -430,10 +426,18 @@ private[spark] class Master( } case RequestWebUIPort => { - sender ! 
WebUIPortResponse(webUi.boundPort) + sender.send(WebUIPortResponse(webUi.boundPort)) } } + override def onDisconnected(address: RpcAddress): Unit = { + // The disconnected client could've been either a worker or an app; remove whichever it was + logInfo(s"$address got disassociated, removing it.") + addressToWorker.get(address).foreach(removeWorker) + addressToApp.get(address).foreach(finishApplication) + if (state == RecoveryState.RECOVERING && canCompleteRecovery) { completeRecovery() } + } + def canCompleteRecovery = workers.count(_.state == WorkerState.UNKNOWN) == 0 && apps.count(_.state == ApplicationState.UNKNOWN) == 0 @@ -445,7 +449,7 @@ private[spark] class Master( try { registerApplication(app) app.state = ApplicationState.UNKNOWN - app.driver ! MasterChanged(masterUrl, masterWebUiUrl) + app.driver.send(MasterChanged(masterUrl, masterWebUiUrl)) } catch { case e: Exception => logInfo("App " + app.id + " had exception on reconnect") } @@ -462,7 +466,7 @@ private[spark] class Master( try { registerWorker(worker) worker.state = WorkerState.UNKNOWN - worker.actor ! MasterChanged(masterUrl, masterWebUiUrl) + worker.actor.send(MasterChanged(masterUrl, masterWebUiUrl)) } catch { case e: Exception => logInfo("Worker " + worker.id + " had exception on reconnect") } @@ -584,10 +588,10 @@ private[spark] class Master( def launchExecutor(worker: WorkerInfo, exec: ExecutorDesc) { logInfo("Launching executor " + exec.fullId + " on worker " + worker.id) worker.addExecutor(exec) - worker.actor ! LaunchExecutor(masterUrl, - exec.application.id, exec.id, exec.application.desc, exec.cores, exec.memory) - exec.application.driver ! ExecutorAdded( - exec.id, worker.id, worker.hostPort, exec.cores, exec.memory) + worker.actor.send(LaunchExecutor(masterUrl, + exec.application.id, exec.id, exec.application.desc, exec.cores, exec.memory)) + exec.application.driver.send(ExecutorAdded( + exec.id, worker.id, worker.hostPort, exec.cores, exec.memory)) } def registerWorker(worker: WorkerInfo): Boolean = { @@ -599,7 +603,7 @@ private[spark] class Master( workers -= w } - val workerAddress = worker.actor.path.address + val workerAddress = worker.actor.address if (addressToWorker.contains(workerAddress)) { val oldWorker = addressToWorker(workerAddress) if (oldWorker.state == WorkerState.UNKNOWN) { @@ -622,11 +626,11 @@ private[spark] class Master( logInfo("Removing worker " + worker.id + " on " + worker.host + ":" + worker.port) worker.setState(WorkerState.DEAD) idToWorker -= worker.id - addressToWorker -= worker.actor.path.address + addressToWorker -= worker.actor.address for (exec <- worker.executors.values) { logInfo("Telling app of lost executor: " + exec.id) - exec.application.driver ! 
ExecutorUpdated( - exec.id, ExecutorState.LOST, Some("worker lost"), None) + exec.application.driver.send(ExecutorUpdated( + exec.id, ExecutorState.LOST, Some("worker lost"), None)) exec.application.removeExecutor(exec) } for (driver <- worker.drivers.values) { @@ -648,14 +652,14 @@ private[spark] class Master( schedule() } - def createApplication(desc: ApplicationDescription, driver: ActorRef): ApplicationInfo = { + def createApplication(desc: ApplicationDescription, driver: RpcEndpointRef): ApplicationInfo = { val now = System.currentTimeMillis() val date = new Date(now) new ApplicationInfo(now, newApplicationId(date), desc, date, driver, defaultCores) } def registerApplication(app: ApplicationInfo): Unit = { - val appAddress = app.driver.path.address + val appAddress = app.driver.address if (addressToWorker.contains(appAddress)) { logInfo("Attempted to re-register application at same address: " + appAddress) return @@ -679,7 +683,7 @@ private[spark] class Master( apps -= app idToApp -= app.id actorToApp -= app.driver - addressToApp -= app.driver.path.address + addressToApp -= app.driver.address if (completedApps.size >= RETAINED_APPLICATIONS) { val toRemove = math.max(RETAINED_APPLICATIONS / 10, 1) completedApps.take(toRemove).foreach( a => { @@ -696,19 +700,19 @@ private[spark] class Master( for (exec <- app.executors.values) { exec.worker.removeExecutor(exec) - exec.worker.actor ! KillExecutor(masterUrl, exec.application.id, exec.id) + exec.worker.actor.send(KillExecutor(masterUrl, exec.application.id, exec.id)) exec.state = ExecutorState.KILLED } app.markFinished(state) if (state != ApplicationState.FINISHED) { - app.driver ! ApplicationRemoved(state.toString) + app.driver.send(ApplicationRemoved(state.toString)) } persistenceEngine.removeApplication(app) schedule() // Tell all workers that the application has finished, so they can clean up any app state. workers.foreach { w => - w.actor ! ApplicationFinished(app.id) + w.actor.send(ApplicationFinished(app.id)) } } } @@ -818,7 +822,7 @@ private[spark] class Master( logInfo("Launching driver " + driver.id + " on worker " + worker.id) worker.addDriver(driver) driver.worker = Some(worker) - worker.actor ! LaunchDriver(driver.id, driver.desc) + worker.actor.send(LaunchDriver(driver.id, driver.desc)) driver.state = DriverState.RUNNING } @@ -851,8 +855,8 @@ private[spark] object Master extends Logging { SignalLogger.register(log) val conf = new SparkConf val args = new MasterArguments(argStrings, conf) - val (actorSystem, _, _) = startSystemAndActor(args.host, args.port, args.webUiPort, conf) - actorSystem.awaitTermination() + val (rpcEnv, _) = startSystemAndActor(args.host, args.port, args.webUiPort, conf) + rpcEnv.awaitTermination() } /** @@ -860,34 +864,29 @@ private[spark] object Master extends Logging { * * @throws SparkException if the url is invalid */ - def toAkkaUrl(sparkUrl: String): String = { - val (host, port) = Utils.extractHostPortFromSparkUrl(sparkUrl) - "akka.tcp://%s@%s:%s/user/%s".format(systemName, host, port, actorName) + def toEndpointRef(rpcEnv: RpcEnv, sparkUrl: String): RpcEndpointRef = { + val address = RpcAddress.fromSparkURL(sparkUrl) + toEndpointRef(rpcEnv, address) } /** - * Returns an akka `Address` for the Master actor given a sparkUrl `spark://host:port`. + * Returns an `akka.tcp://...` URL for the Master actor given a RpcAddress. 
* - * @throws SparkException if the url is invalid */ - def toAkkaAddress(sparkUrl: String): Address = { - val (host, port) = Utils.extractHostPortFromSparkUrl(sparkUrl) - Address("akka.tcp", systemName, host, port) + def toEndpointRef(rpcEnv: RpcEnv, address: RpcAddress): RpcEndpointRef = { + rpcEnv.setupEndpointRef(systemName, address, actorName) } def startSystemAndActor( host: String, port: Int, webUiPort: Int, - conf: SparkConf): (ActorSystem, Int, Int) = { + conf: SparkConf): (RpcEnv, Int) = { val securityMgr = new SecurityManager(conf) - val (actorSystem, boundPort) = AkkaUtils.createActorSystem(systemName, host, port, conf = conf, - securityManager = securityMgr) - val actor = actorSystem.actorOf(Props(classOf[Master], host, boundPort, webUiPort, - securityMgr), actorName) - val timeout = AkkaUtils.askTimeout(conf) - val respFuture = actor.ask(RequestWebUIPort)(timeout) - val resp = Await.result(respFuture, timeout).asInstanceOf[WebUIPortResponse] - (actorSystem, boundPort, resp.webUIBoundPort) + val rpcEnv = RpcEnv.create(systemName, host, port, conf = conf, securityManager = securityMgr) + val actor = rpcEnv.setupEndpoint(actorName, + new Master(rpcEnv, host, rpcEnv.boundPort, webUiPort, securityMgr)) + val resp = actor.askWithReply[WebUIPortResponse](RequestWebUIPort) + (rpcEnv, resp.webUIBoundPort) } } diff --git a/core/src/main/scala/org/apache/spark/deploy/master/WorkerInfo.scala b/core/src/main/scala/org/apache/spark/deploy/master/WorkerInfo.scala index e94aae93e449..c76460fcb1c9 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/WorkerInfo.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/WorkerInfo.scala @@ -19,9 +19,7 @@ package org.apache.spark.deploy.master import scala.collection.mutable -import akka.actor.ActorRef - -import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.rpc.RpcEndpointRef import org.apache.spark.util.Utils private[spark] class WorkerInfo( @@ -30,7 +28,7 @@ private[spark] class WorkerInfo( val port: Int, val cores: Int, val memory: Int, - val actor: ActorRef, + val actor: RpcEndpointRef, val webUiPort: Int, val publicAddress: String) extends Serializable { diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperLeaderElectionAgent.scala b/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperLeaderElectionAgent.scala index 8eaa0ad94851..017844dd57de 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperLeaderElectionAgent.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperLeaderElectionAgent.scala @@ -17,8 +17,6 @@ package org.apache.spark.deploy.master -import akka.actor.ActorRef - import org.apache.spark.{Logging, SparkConf} import org.apache.spark.deploy.master.MasterMessages._ import org.apache.curator.framework.CuratorFramework diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala b/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala index 3aae2b95d739..fb6f28146a2e 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala @@ -19,10 +19,8 @@ package org.apache.spark.deploy.master.ui import javax.servlet.http.HttpServletRequest -import scala.concurrent.Await import scala.xml.Node -import akka.pattern.ask import org.json4s.JValue import org.apache.spark.deploy.{ExecutorState, JsonProtocol} @@ -33,14 +31,12 @@ import org.apache.spark.util.Utils private[spark] class 
ApplicationPage(parent: MasterWebUI) extends WebUIPage("app") { - private val master = parent.masterActorRef - private val timeout = parent.timeout + private def master = parent.masterEndpointRef /** Executor details for a particular application */ override def renderJson(request: HttpServletRequest): JValue = { val appId = request.getParameter("appId") - val stateFuture = (master ? RequestMasterState)(timeout).mapTo[MasterStateResponse] - val state = Await.result(stateFuture, timeout) + val state = master.askWithReply[MasterStateResponse](RequestMasterState) val app = state.activeApps.find(_.id == appId).getOrElse({ state.completedApps.find(_.id == appId).getOrElse(null) }) @@ -50,8 +46,7 @@ private[spark] class ApplicationPage(parent: MasterWebUI) extends WebUIPage("app /** Executor details for a particular application */ def render(request: HttpServletRequest): Seq[Node] = { val appId = request.getParameter("appId") - val stateFuture = (master ? RequestMasterState)(timeout).mapTo[MasterStateResponse] - val state = Await.result(stateFuture, timeout) + val state = master.askWithReply[MasterStateResponse](RequestMasterState) val app = state.activeApps.find(_.id == appId).getOrElse({ state.completedApps.find(_.id == appId).getOrElse(null) }) diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala index 7ca3b08a2872..67bad341532e 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterPage.scala @@ -19,10 +19,8 @@ package org.apache.spark.deploy.master.ui import javax.servlet.http.HttpServletRequest -import scala.concurrent.Await import scala.xml.Node -import akka.pattern.ask import org.json4s.JValue import org.apache.spark.deploy.JsonProtocol @@ -32,19 +30,16 @@ import org.apache.spark.ui.{WebUIPage, UIUtils} import org.apache.spark.util.Utils private[spark] class MasterPage(parent: MasterWebUI) extends WebUIPage("") { - private val master = parent.masterActorRef - private val timeout = parent.timeout + private def master = parent.masterEndpointRef override def renderJson(request: HttpServletRequest): JValue = { - val stateFuture = (master ? RequestMasterState)(timeout).mapTo[MasterStateResponse] - val state = Await.result(stateFuture, timeout) + val state = master.askWithReply[MasterStateResponse](RequestMasterState) JsonProtocol.writeMasterState(state) } /** Index view listing applications and executors */ def render(request: HttpServletRequest): Seq[Node] = { - val stateFuture = (master ? 
RequestMasterState)(timeout).mapTo[MasterStateResponse] - val state = Await.result(stateFuture, timeout) + val state = master.askWithReply[MasterStateResponse](RequestMasterState) val workerHeaders = Seq("Id", "Address", "State", "Cores", "Memory") val workers = state.workers.sortBy(_.id) diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala index 73400c5affb5..2c9d2c6b1e52 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala @@ -30,7 +30,7 @@ private[spark] class MasterWebUI(val master: Master, requestedPort: Int) extends WebUI(master.securityMgr, requestedPort, master.conf, name = "MasterUI") with Logging { - val masterActorRef = master.self + def masterEndpointRef = master.self val timeout = AkkaUtils.askTimeout(master.conf) initialize() diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala b/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala index 28cab36c7b9e..9b063e3c41f9 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/DriverRunner.scala @@ -19,10 +19,11 @@ package org.apache.spark.deploy.worker import java.io._ +import org.apache.spark.rpc.RpcEndpointRef + import scala.collection.JavaConversions._ import scala.collection.Map -import akka.actor.ActorRef import com.google.common.base.Charsets.UTF_8 import com.google.common.io.Files import org.apache.hadoop.conf.Configuration @@ -44,7 +45,7 @@ private[spark] class DriverRunner( val workDir: File, val sparkHome: File, val driverDesc: DriverDescription, - val worker: ActorRef, + val worker: RpcEndpointRef, val workerUrl: String) extends Logging { @@ -98,7 +99,7 @@ private[spark] class DriverRunner( finalState = Some(state) - worker ! DriverStateChanged(driverId, state, finalException) + worker.send(DriverStateChanged(driverId, state, finalException)) } }.start() } diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/DriverWrapper.scala b/core/src/main/scala/org/apache/spark/deploy/worker/DriverWrapper.scala index 05e242e6df70..dfcd34e51a5e 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/DriverWrapper.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/DriverWrapper.scala @@ -17,10 +17,10 @@ package org.apache.spark.deploy.worker -import akka.actor._ +import org.apache.spark.rpc.RpcEnv import org.apache.spark.{SecurityManager, SparkConf} -import org.apache.spark.util.{AkkaUtils, Utils} +import org.apache.spark.util.Utils /** * Utility object for launching driver programs such that they share fate with the Worker process. 
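
The same ask-pattern rewrite recurs in the two UI pages above. A hedged sketch of the new call shape, with the old Akka form kept as a comment for contrast: MasterStateFetcher is hypothetical, while RequestMasterState and MasterStateResponse are the existing DeployMessages types, and the package is chosen so they remain visible.

package org.apache.spark.deploy.master.ui

import org.apache.spark.deploy.DeployMessages.{MasterStateResponse, RequestMasterState}
import org.apache.spark.rpc.RpcEndpointRef

object MasterStateFetcher {
  // Before (Akka):
  //   val stateFuture = (master ? RequestMasterState)(timeout).mapTo[MasterStateResponse]
  //   val state = Await.result(stateFuture, timeout)
  // After: the retry and timeout settings previously threaded through
  // AkkaUtils.askWithReply are handled by the endpoint ref itself, so the UI
  // pages no longer keep a `timeout` field around.
  def fetch(master: RpcEndpointRef): MasterStateResponse =
    master.askWithReply[MasterStateResponse](RequestMasterState)
}
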
@@ -30,16 +30,16 @@ object DriverWrapper { args.toList match { case workerUrl :: mainClass :: extraArgs => val conf = new SparkConf() - val (actorSystem, _) = AkkaUtils.createActorSystem("Driver", + val rpcEnv = RpcEnv.create("Driver", Utils.localHostName(), 0, conf, new SecurityManager(conf)) - actorSystem.actorOf(Props(classOf[WorkerWatcher], workerUrl), name = "workerWatcher") + rpcEnv.setupEndpoint("workerWatcher", new WorkerWatcher(rpcEnv, workerUrl)) // Delegate to supplied main class val clazz = Class.forName(args(1)) val mainMethod = clazz.getMethod("main", classOf[Array[String]]) mainMethod.invoke(null, extraArgs.toArray[String]) - actorSystem.shutdown() + rpcEnv.stopAll() case _ => System.err.println("Usage: DriverWrapper [options]") diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala b/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala index acbdf0d8bd7b..7e77c9028e9c 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala @@ -21,13 +21,13 @@ import java.io._ import scala.collection.JavaConversions._ -import akka.actor.ActorRef import com.google.common.base.Charsets.UTF_8 import com.google.common.io.Files import org.apache.spark.{SparkConf, Logging} import org.apache.spark.deploy.{ApplicationDescription, Command, ExecutorState} import org.apache.spark.deploy.DeployMessages.ExecutorStateChanged +import org.apache.spark.rpc.RpcEndpointRef import org.apache.spark.util.logging.FileAppender /** @@ -40,7 +40,7 @@ private[spark] class ExecutorRunner( val appDesc: ApplicationDescription, val cores: Int, val memory: Int, - val worker: ActorRef, + val worker: RpcEndpointRef, val workerId: String, val host: String, val sparkHome: File, @@ -94,7 +94,7 @@ private[spark] class ExecutorRunner( } exitCode = Some(process.waitFor()) } - worker ! ExecutorStateChanged(appId, execId, state, message, exitCode) + worker.send(ExecutorStateChanged(appId, execId, state, message, exitCode)) } /** Stop this executor runner, including killing the process it launched */ @@ -151,7 +151,7 @@ private[spark] class ExecutorRunner( val exitCode = process.waitFor() state = ExecutorState.EXITED val message = "Command exited with code " + exitCode - worker ! 
ExecutorStateChanged(appId, execId, state, Some(message), Some(exitCode)) + worker.send(ExecutorStateChanged(appId, execId, state, Some(message), Some(exitCode))) } catch { case interrupted: InterruptedException => { logInfo("Runner thread for executor " + fullId + " interrupted") diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala index 13599830123d..80cb0542bd37 100755 --- a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala @@ -20,42 +20,45 @@ package org.apache.spark.deploy.worker import java.io.File import java.io.IOException import java.text.SimpleDateFormat +import java.util.concurrent.{ScheduledFuture, TimeUnit, Executors} import java.util.{UUID, Date} import scala.collection.JavaConversions._ import scala.collection.mutable.{HashMap, HashSet} +import scala.concurrent.ExecutionContext import scala.concurrent.duration._ import scala.language.postfixOps import scala.util.Random -import akka.actor._ -import akka.remote.{DisassociatedEvent, RemotingLifecycleEvent} - -import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkException} +import org.apache.spark.{Logging, SecurityManager, SparkConf} import org.apache.spark.deploy.{ExecutorDescription, ExecutorState} import org.apache.spark.deploy.DeployMessages._ import org.apache.spark.deploy.master.{DriverState, Master} import org.apache.spark.deploy.worker.ui.WorkerWebUI import org.apache.spark.metrics.MetricsSystem -import org.apache.spark.util.{ActorLogReceive, AkkaUtils, SignalLogger, Utils} +import org.apache.spark.rpc.{RpcAddress, RpcEnv, RpcEndpointRef, NetworkRpcEndpoint} +import org.apache.spark.util.{SignalLogger, Utils} -/** - * @param masterAkkaUrls Each url should be a valid akka url. 
- */ private[spark] class Worker( + override val rpcEnv: RpcEnv, host: String, port: Int, webUiPort: Int, cores: Int, memory: Int, - masterAkkaUrls: Array[String], + masterAddresses: Set[RpcAddress], actorSystemName: String, actorName: String, workDirPath: String = null, val conf: SparkConf, val securityMgr: SecurityManager) - extends Actor with ActorLogReceive with Logging { - import context.dispatcher + extends NetworkRpcEndpoint with Logging { + + val scheduler = Executors.newScheduledThreadPool(1, + Utils.namedThreadFactory("worker-scheduler")) + + implicit val cleanupExecutor = ExecutionContext.fromExecutor( + Executors.newFixedThreadPool(1, Utils.namedThreadFactory("cleanup-thread"))) Utils.checkHost(host, "Expected hostname") assert (port > 0) @@ -89,8 +92,8 @@ private[spark] class Worker( val APP_DATA_RETENTION_SECS = conf.getLong("spark.worker.cleanup.appDataTtl", 7 * 24 * 3600) val testing: Boolean = sys.props.contains("spark.testing") - var master: ActorSelection = null - var masterAddress: Address = null + var master: RpcEndpointRef = null + var masterAddress: RpcAddress = null var activeMasterUrl: String = "" var activeMasterWebUiUrl : String = "" val akkaUrl = "akka.tcp://%s@%s:%s/user/%s".format(actorSystemName, host, port, actorName) @@ -128,7 +131,7 @@ private[spark] class Worker( val metricsSystem = MetricsSystem.createMetricsSystem("worker", conf, securityMgr) val workerSource = new WorkerSource(this) - var registrationRetryTimer: Option[Cancellable] = None + var registrationRetryTimer: Option[ScheduledFuture[_]] = None def coresFree: Int = cores - coresUsed def memoryFree: Int = memory - memoryUsed @@ -151,14 +154,13 @@ private[spark] class Worker( } } - override def preStart() { + override def onStart() { assert(!registered) logInfo("Starting Spark worker %s:%d with %d cores, %s RAM".format( host, port, cores, Utils.megabytesToString(memory))) logInfo(s"Running Spark version ${org.apache.spark.SPARK_VERSION}") logInfo("Spark home: " + sparkHome) createWorkDir() - context.system.eventStream.subscribe(self, classOf[RemotingLifecycleEvent]) shuffleService.startIfEnabled() webUi = new WorkerWebUI(this, workDir, webUiPort) webUi.bind() @@ -174,19 +176,20 @@ private[spark] class Worker( // activeMasterUrl it's a valid Spark url since we receive it from master. activeMasterUrl = url activeMasterWebUiUrl = uiUrl - master = context.actorSelection(Master.toAkkaUrl(activeMasterUrl)) - masterAddress = Master.toAkkaAddress(activeMasterUrl) + master = Master.toEndpointRef(rpcEnv, activeMasterUrl) + masterAddress = master.address connected = true // Cancel any outstanding re-registration attempts because we found a new master - registrationRetryTimer.foreach(_.cancel()) + registrationRetryTimer.foreach(_.cancel(true)) registrationRetryTimer = None } private def tryRegisterAllMasters() { - for (masterAkkaUrl <- masterAkkaUrls) { - logInfo("Connecting to master " + masterAkkaUrl + "...") - val actor = context.actorSelection(masterAkkaUrl) - actor ! 
RegisterWorker(workerId, host, port, cores, memory, webUi.boundPort, publicAddress) + for (masterAddress <- masterAddresses) { + logInfo("Connecting to master " + masterAddress + "...") + val actor = Master.toEndpointRef(rpcEnv, masterAddress) + actor.send( + RegisterWorker(workerId, host, port, cores, memory, webUi.boundPort, publicAddress)) } } @@ -199,7 +202,7 @@ private[spark] class Worker( Utils.tryOrExit { connectionAttemptCount += 1 if (registered) { - registrationRetryTimer.foreach(_.cancel()) + registrationRetryTimer.foreach(_.cancel(true)) registrationRetryTimer = None } else if (connectionAttemptCount <= TOTAL_REGISTRATION_RETRIES) { logInfo(s"Retrying connection to master (attempt # $connectionAttemptCount)") @@ -224,8 +227,8 @@ private[spark] class Worker( * less likely scenario. */ if (master != null) { - master ! RegisterWorker( - workerId, host, port, cores, memory, webUi.boundPort, publicAddress) + master.send(RegisterWorker( + workerId, host, port, cores, memory, webUi.boundPort, publicAddress)) } else { // We are retrying the initial registration tryRegisterAllMasters() @@ -233,10 +236,12 @@ private[spark] class Worker( // We have exceeded the initial registration retry threshold // All retries from now on should use a higher interval if (connectionAttemptCount == INITIAL_REGISTRATION_RETRIES) { - registrationRetryTimer.foreach(_.cancel()) + registrationRetryTimer.foreach(_.cancel(true)) registrationRetryTimer = Some { - context.system.scheduler.schedule(PROLONGED_REGISTRATION_RETRY_INTERVAL, - PROLONGED_REGISTRATION_RETRY_INTERVAL, self, ReregisterWithMaster) + scheduler.scheduleAtFixedRate(new Runnable { + override def run(): Unit = self.send(ReregisterWithMaster) + }, PROLONGED_REGISTRATION_RETRY_INTERVAL.toMillis, + PROLONGED_REGISTRATION_RETRY_INTERVAL.toMillis, TimeUnit.MILLISECONDS) } } } else { @@ -255,8 +260,10 @@ private[spark] class Worker( tryRegisterAllMasters() connectionAttemptCount = 0 registrationRetryTimer = Some { - context.system.scheduler.schedule(INITIAL_REGISTRATION_RETRY_INTERVAL, - INITIAL_REGISTRATION_RETRY_INTERVAL, self, ReregisterWithMaster) + scheduler.scheduleAtFixedRate(new Runnable { + override def run(): Unit = self.send(ReregisterWithMaster) + }, INITIAL_REGISTRATION_RETRY_INTERVAL.toMillis, + INITIAL_REGISTRATION_RETRY_INTERVAL.toMillis, TimeUnit.MILLISECONDS) } case Some(_) => logInfo("Not spawning another attempt to register with the master, since there is an" + @@ -264,20 +271,23 @@ private[spark] class Worker( } } - override def receiveWithLogging = { + override def receive(sender: RpcEndpointRef) = { case RegisteredWorker(masterUrl, masterWebUiUrl) => logInfo("Successfully registered with master " + masterUrl) registered = true changeMaster(masterUrl, masterWebUiUrl) - context.system.scheduler.schedule(0 millis, HEARTBEAT_MILLIS millis, self, SendHeartbeat) + scheduler.scheduleAtFixedRate(new Runnable { + override def run(): Unit = self.send(SendHeartbeat) + }, 0, HEARTBEAT_MILLIS, TimeUnit.MILLISECONDS) if (CLEANUP_ENABLED) { logInfo(s"Worker cleanup enabled; old application directories will be deleted in: $workDir") - context.system.scheduler.schedule(CLEANUP_INTERVAL_MILLIS millis, - CLEANUP_INTERVAL_MILLIS millis, self, WorkDirCleanup) + scheduler.scheduleAtFixedRate(new Runnable { + override def run(): Unit = self.send(WorkDirCleanup) + }, CLEANUP_INTERVAL_MILLIS, CLEANUP_INTERVAL_MILLIS, TimeUnit.MILLISECONDS) } case SendHeartbeat => - if (connected) { master ! 
Heartbeat(workerId) } + if (connected) { master.send(Heartbeat(workerId)) } case WorkDirCleanup => // Spin up a separate thread (in a future) to do the dir cleanup; don't tie up worker actor @@ -310,10 +320,10 @@ private[spark] class Worker( val execs = executors.values. map(e => new ExecutorDescription(e.appId, e.execId, e.cores, e.state)) - sender ! WorkerSchedulerStateResponse(workerId, execs.toList, drivers.keys.toSeq) + sender.send(WorkerSchedulerStateResponse(workerId, execs.toList, drivers.keys.toSeq)) case Heartbeat => - logInfo(s"Received heartbeat from driver ${sender.path}") + logInfo(s"Received heartbeat from driver ${sender.address}") case RegisterWorkerFailed(message) => if (!registered) { @@ -355,7 +365,7 @@ private[spark] class Worker( manager.start() coresUsed += cores_ memoryUsed += memory_ - master ! ExecutorStateChanged(appId, execId, manager.state, None, None) + master.send(ExecutorStateChanged(appId, execId, manager.state, None, None)) } catch { case e: Exception => { logError(s"Failed to launch executor $appId/$execId for ${appDesc.name}.", e) @@ -363,14 +373,14 @@ private[spark] class Worker( executors(appId + "/" + execId).kill() executors -= appId + "/" + execId } - master ! ExecutorStateChanged(appId, execId, ExecutorState.FAILED, - Some(e.toString), None) + master.send(ExecutorStateChanged(appId, execId, ExecutorState.FAILED, + Some(e.toString), None)) } } } case ExecutorStateChanged(appId, execId, state, message, exitStatus) => - master ! ExecutorStateChanged(appId, execId, state, message, exitStatus) + master.send(ExecutorStateChanged(appId, execId, state, message, exitStatus)) val fullId = appId + "/" + execId if (ExecutorState.isFinished(state)) { executors.get(fullId) match { @@ -437,22 +447,18 @@ private[spark] class Worker( case _ => logDebug(s"Driver $driverId changed state to $state") } - master ! DriverStateChanged(driverId, state, exception) + master.send(DriverStateChanged(driverId, state, exception)) val driver = drivers.remove(driverId).get finishedDrivers(driverId) = driver memoryUsed -= driver.driverDesc.mem coresUsed -= driver.driverDesc.cores } - case x: DisassociatedEvent if x.remoteAddress == masterAddress => - logInfo(s"$x Disassociated !") - masterDisconnected() - case RequestWorkerState => - sender ! WorkerStateResponse(host, port, workerId, executors.values.toList, + sender.send(WorkerStateResponse(host, port, workerId, executors.values.toList, finishedExecutors.values.toList, drivers.values.toList, finishedDrivers.values.toList, activeMasterUrl, cores, memory, - coresUsed, memoryUsed, activeMasterWebUiUrl) + coresUsed, memoryUsed, activeMasterWebUiUrl)) case ReregisterWithMaster => reregisterWithMaster() @@ -462,6 +468,13 @@ private[spark] class Worker( maybeCleanupApplication(id) } + override def onDisconnected(remoteAddress: RpcAddress): Unit = { + if (remoteAddress == masterAddress ) { + logInfo(s"$remoteAddress Disassociated !") + masterDisconnected() + } + } + private def masterDisconnected() { logError("Connection to master failed! 
Waiting for master to reconnect...") connected = false @@ -485,9 +498,9 @@ private[spark] class Worker( "worker-%s-%s-%d".format(createDateFormat.format(new Date), host, port) } - override def postStop() { + override def onStop() { metricsSystem.report() - registrationRetryTimer.foreach(_.cancel()) + registrationRetryTimer.foreach(_.cancel(true)) executors.values.foreach(_.kill()) drivers.values.foreach(_.kill()) shuffleService.stop() @@ -501,9 +514,9 @@ private[spark] object Worker extends Logging { SignalLogger.register(log) val conf = new SparkConf val args = new WorkerArguments(argStrings, conf) - val (actorSystem, _) = startSystemAndActor(args.host, args.port, args.webUiPort, args.cores, + val rpcEnv = startSystemAndActor(args.host, args.port, args.webUiPort, args.cores, args.memory, args.masters, args.workDir) - actorSystem.awaitTermination() + rpcEnv.awaitTermination() } def startSystemAndActor( @@ -514,19 +527,19 @@ private[spark] object Worker extends Logging { memory: Int, masterUrls: Array[String], workDir: String, - workerNumber: Option[Int] = None): (ActorSystem, Int) = { + workerNumber: Option[Int] = None): RpcEnv = { // The LocalSparkCluster runs multiple local sparkWorkerX actor systems val conf = new SparkConf val systemName = "sparkWorker" + workerNumber.map(_.toString).getOrElse("") val actorName = "Worker" val securityMgr = new SecurityManager(conf) - val (actorSystem, boundPort) = AkkaUtils.createActorSystem(systemName, host, port, - conf = conf, securityManager = securityMgr) - val masterAkkaUrls = masterUrls.map(Master.toAkkaUrl) - actorSystem.actorOf(Props(classOf[Worker], host, boundPort, webUiPort, cores, memory, - masterAkkaUrls, systemName, actorName, workDir, conf, securityMgr), name = actorName) - (actorSystem, boundPort) + + val masterAddresses = masterUrls.map(RpcAddress.fromSparkURL).toSet + val rpcEnv = RpcEnv.create(systemName, host, port, conf = conf, securityManager = securityMgr) + rpcEnv.setupEndpoint(actorName, new Worker(rpcEnv, host, rpcEnv.boundPort, webUiPort, cores, + memory, masterAddresses, systemName, actorName, workDir, conf, securityMgr)) + rpcEnv } } diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/WorkerWatcher.scala b/core/src/main/scala/org/apache/spark/deploy/worker/WorkerWatcher.scala index 63a8ac817b61..ce8c095dfcbe 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/WorkerWatcher.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/WorkerWatcher.scala @@ -17,26 +17,23 @@ package org.apache.spark.deploy.worker -import akka.actor.{Actor, Address, AddressFromURIString} -import akka.remote.{AssociatedEvent, AssociationErrorEvent, AssociationEvent, DisassociatedEvent, RemotingLifecycleEvent} - import org.apache.spark.Logging import org.apache.spark.deploy.DeployMessages.SendHeartbeat -import org.apache.spark.util.ActorLogReceive +import org.apache.spark.rpc._ /** * Actor which connects to a worker process and terminates the JVM if the connection is severed. * Provides fate sharing between a worker and its associated child processes. */ -private[spark] class WorkerWatcher(workerUrl: String) - extends Actor with ActorLogReceive with Logging { - - override def preStart() { - context.system.eventStream.subscribe(self, classOf[RemotingLifecycleEvent]) +private[spark] class WorkerWatcher(override val rpcEnv: RpcEnv, workerUrl: String) + extends NetworkRpcEndpoint with Logging { + override def onStart() { logInfo(s"Connecting to worker $workerUrl") - val worker = context.actorSelection(workerUrl) - worker ! 
SendHeartbeat // need to send a message here to initiate connection + if (!isTesting) { + val worker = rpcEnv.setupEndpointRefByUrl(workerUrl) + worker.send(SendHeartbeat) // need to send a message here to initiate connection + } } // Used to avoid shutting down JVM during tests @@ -45,30 +42,37 @@ private[spark] class WorkerWatcher(workerUrl: String) private var isTesting = false // Lets us filter events only from the worker's actor system - private val expectedHostPort = AddressFromURIString(workerUrl).hostPort - private def isWorker(address: Address) = address.hostPort == expectedHostPort + private val expectedHostPort = new java.net.URI(workerUrl) + private def isWorker(address: RpcAddress) = { + expectedHostPort.getHost == address.host && expectedHostPort.getPort == address.port + } def exitNonZero() = if (isTesting) isShutDown = true else System.exit(-1) - override def receiveWithLogging = { - case AssociatedEvent(localAddress, remoteAddress, inbound) if isWorker(remoteAddress) => - logInfo(s"Successfully connected to $workerUrl") + override def receive(sender: RpcEndpointRef) = { + case e => logWarning(s"Received unexpected actor system event: $e") + } - case AssociationErrorEvent(cause, localAddress, remoteAddress, inbound, _) - if isWorker(remoteAddress) => - // These logs may not be seen if the worker (and associated pipe) has died - logError(s"Could not initialize connection to worker $workerUrl. Exiting.") - logError(s"Error was: $cause") - exitNonZero() + override def onConnected(remoteAddress: RpcAddress): Unit = { + if (isWorker(remoteAddress)) { + logInfo(s"Successfully connected to $workerUrl") + } + } - case DisassociatedEvent(localAddress, remoteAddress, inbound) if isWorker(remoteAddress) => + override def onDisconnected(remoteAddress: RpcAddress): Unit = { + if (isWorker(remoteAddress)) { // This log message will never be seen logError(s"Lost connection to worker actor $workerUrl. Exiting.") exitNonZero() + } + } - case e: AssociationEvent => - // pass through association events relating to other remote actor systems - - case e => logWarning(s"Received unexpected actor system event: $e") + override def onNetworkError(cause: Throwable, remoteAddress: RpcAddress): Unit = { + if (isWorker(remoteAddress)) { + // These logs may not be seen if the worker (and associated pipe) has died + logError(s"Could not initialize connection to worker $workerUrl. Exiting.") + logError(s"Error was: $cause") + exitNonZero() + } } } diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerPage.scala b/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerPage.scala index 327b90503280..c3f9b91e7056 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerPage.scala @@ -17,10 +17,8 @@ package org.apache.spark.deploy.worker.ui -import scala.concurrent.Await import scala.xml.Node -import akka.pattern.ask import javax.servlet.http.HttpServletRequest import org.json4s.JValue @@ -34,17 +32,15 @@ import org.apache.spark.util.Utils private[spark] class WorkerPage(parent: WorkerWebUI) extends WebUIPage("") { val workerActor = parent.worker.self val worker = parent.worker - val timeout = parent.timeout override def renderJson(request: HttpServletRequest): JValue = { - val stateFuture = (workerActor ? 
RequestWorkerState)(timeout).mapTo[WorkerStateResponse] - val workerState = Await.result(stateFuture, timeout) + val workerState = workerActor.askWithReply[WorkerStateResponse](RequestWorkerState) JsonProtocol.writeWorkerState(workerState) } def render(request: HttpServletRequest): Seq[Node] = { - val stateFuture = (workerActor ? RequestWorkerState)(timeout).mapTo[WorkerStateResponse] - val workerState = Await.result(stateFuture, timeout) + val workerState = workerActor.askWithReply[WorkerStateResponse](RequestWorkerState) + JsonProtocol.writeWorkerState(workerState) val executorHeaders = Seq("ExecutorID", "Cores", "State", "Memory", "Job Details", "Logs") val runningExecutors = workerState.executors diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala b/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala index 7ac81a2d87ef..9c2f51e27a56 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala @@ -38,8 +38,6 @@ class WorkerWebUI( extends WebUI(worker.securityMgr, requestedPort, worker.conf, name = "WorkerUI") with Logging { - val timeout = AkkaUtils.askTimeout(worker.conf) - initialize() /** Initialize all components of the server. */ diff --git a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala index 9a4adfbbb3d7..ea2ea7f43e63 100644 --- a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala +++ b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala @@ -19,19 +19,14 @@ package org.apache.spark.executor import java.nio.ByteBuffer -import scala.concurrent.Await - -import akka.actor.{Actor, ActorSelection, Props} -import akka.pattern.Patterns -import akka.remote.{RemotingLifecycleEvent, DisassociatedEvent} - import org.apache.spark.{Logging, SecurityManager, SparkConf, SparkEnv} import org.apache.spark.TaskState.TaskState import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.deploy.worker.WorkerWatcher +import org.apache.spark.rpc.{RpcEnv, RpcAddress, NetworkRpcEndpoint, RpcEndpointRef} import org.apache.spark.scheduler.TaskDescription import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages._ -import org.apache.spark.util.{ActorLogReceive, AkkaUtils, SignalLogger, Utils} +import org.apache.spark.util.{SignalLogger, Utils} private[spark] class CoarseGrainedExecutorBackend( driverUrl: String, @@ -39,21 +34,23 @@ private[spark] class CoarseGrainedExecutorBackend( hostPort: String, cores: Int, env: SparkEnv) - extends Actor with ActorLogReceive with ExecutorBackend with Logging { + extends NetworkRpcEndpoint with ExecutorBackend with Logging { + + override val rpcEnv = env.rpcEnv Utils.checkHostPort(hostPort, "Expected hostport") var executor: Executor = null - var driver: ActorSelection = null + var driver: RpcEndpointRef = _ - override def preStart() { + override def onStart(): Unit = { + // self is valid now. So now we can use `send` logInfo("Connecting to driver: " + driverUrl) - driver = context.actorSelection(driverUrl) - driver ! 
RegisterExecutor(executorId, hostPort, cores) - context.system.eventStream.subscribe(self, classOf[RemotingLifecycleEvent]) + driver = rpcEnv.setupEndpointRefByUrl(driverUrl) + driver.send(RegisterExecutor(executorId, hostPort, cores, self)) } - override def receiveWithLogging = { + override def receive(sender: RpcEndpointRef) = { case RegisteredExecutor => logInfo("Successfully registered with driver") val (hostname, _) = Utils.parseHostPort(hostPort) @@ -83,19 +80,20 @@ private[spark] class CoarseGrainedExecutorBackend( executor.killTask(taskId, interruptThread) } - case x: DisassociatedEvent => - logError(s"Driver $x disassociated! Shutting down.") - System.exit(1) - case StopExecutor => logInfo("Driver commanded a shutdown") executor.stop() - context.stop(self) - context.system.shutdown() + stop() + rpcEnv.stopAll() + } + + override def onDisconnected(remoteAddress: RpcAddress): Unit = { + logError(s"Driver $remoteAddress disassociated! Shutting down.") + System.exit(1) } override def statusUpdate(taskId: Long, state: TaskState, data: ByteBuffer) { - driver ! StatusUpdate(executorId, taskId, state, data) + driver.send(StatusUpdate(executorId, taskId, state, data)) } } @@ -118,14 +116,14 @@ private[spark] object CoarseGrainedExecutorBackend extends Logging { // Bootstrap to fetch the driver's Spark properties. val executorConf = new SparkConf val port = executorConf.getInt("spark.executor.port", 0) - val (fetcher, _) = AkkaUtils.createActorSystem( + val rpcEnv = RpcEnv.create( "driverPropsFetcher", hostname, port, executorConf, new SecurityManager(executorConf)) - val driver = fetcher.actorSelection(driverUrl) - val timeout = AkkaUtils.askTimeout(executorConf) - val fut = Patterns.ask(driver, RetrieveSparkProps, timeout) - val props = Await.result(fut, timeout).asInstanceOf[Seq[(String, String)]] ++ + val driver = rpcEnv.setupEndpointRefByUrl(driverUrl) + val props = driver.askWithReply[Seq[(String, String)]](RetrieveSparkProps) ++ Seq[(String, String)](("spark.app.id", appId)) - fetcher.shutdown() + + rpcEnv.stopAll() + rpcEnv.awaitTermination() // Create SparkEnv using properties we fetched from the driver. val driverConf = new SparkConf().setAll(props) @@ -138,14 +136,12 @@ private[spark] object CoarseGrainedExecutorBackend extends Logging { // Start the CoarseGrainedExecutorBackend actor. 
val sparkHostPort = hostname + ":" + boundPort - env.actorSystem.actorOf( - Props(classOf[CoarseGrainedExecutorBackend], - driverUrl, executorId, sparkHostPort, cores, env), - name = "Executor") + env.rpcEnv.setupEndpoint("Executor", new CoarseGrainedExecutorBackend( + driverUrl, executorId, sparkHostPort, cores, env)) workerUrl.foreach { url => - env.actorSystem.actorOf(Props(classOf[WorkerWatcher], url), name = "WorkerWatcher") + env.rpcEnv.setupEndpoint("WorkerWatcher", new WorkerWatcher(env.rpcEnv, url)) } - env.actorSystem.awaitTermination() + env.rpcEnv.awaitTermination() } } diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala index 6660b98eb8ce..bbff1dfe2b6b 100644 --- a/core/src/main/scala/org/apache/spark/executor/Executor.scala +++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala @@ -26,8 +26,6 @@ import scala.collection.JavaConversions._ import scala.collection.mutable.{ArrayBuffer, HashMap} import scala.util.control.NonFatal -import akka.actor.Props - import org.apache.spark._ import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.scheduler._ @@ -81,8 +79,8 @@ private[spark] class Executor( } // Create an actor for receiving RPCs from the driver - private val executorActor = env.actorSystem.actorOf( - Props(new ExecutorActor(executorId)), "ExecutorActor") + private val executorActor = env.rpcEnv.setupEndpoint( + "ExecutorActor", new ExecutorActor(env.rpcEnv, executorId)) // Create our ClassLoader // do this after SparkEnv creation so can access the SecurityManager @@ -128,7 +126,7 @@ private[spark] class Executor( def stop() { env.metricsSystem.report() - env.actorSystem.stop(executorActor) + env.rpcEnv.stop(executorActor) isStopped = true threadPool.shutdown() if (!isLocal) { @@ -361,10 +359,7 @@ private[spark] class Executor( def startDriverHeartbeater() { val interval = conf.getInt("spark.executor.heartbeatInterval", 10000) - val timeout = AkkaUtils.lookupTimeout(conf) - val retryAttempts = AkkaUtils.numRetries(conf) - val retryIntervalMs = AkkaUtils.retryWaitMs(conf) - val heartbeatReceiverRef = AkkaUtils.makeDriverRef("HeartbeatReceiver", conf, env.actorSystem) + val heartbeatReceiverRef = env.rpcEnv.setupDriverEndpointRef("HeartbeatReceiver") val t = new Thread() { override def run() { @@ -397,8 +392,7 @@ private[spark] class Executor( val message = Heartbeat(executorId, tasksMetrics.toArray, env.blockManager.blockManagerId) try { - val response = AkkaUtils.askWithReply[HeartbeatResponse](message, heartbeatReceiverRef, - retryAttempts, retryIntervalMs, timeout) + val response = heartbeatReceiverRef.askWithReply[HeartbeatResponse](message) if (response.reregisterBlockManager) { logWarning("Told to re-register on heartbeat") env.blockManager.reregister() diff --git a/core/src/main/scala/org/apache/spark/executor/ExecutorActor.scala b/core/src/main/scala/org/apache/spark/executor/ExecutorActor.scala index 41925f7e97e8..5c7d07195808 100644 --- a/core/src/main/scala/org/apache/spark/executor/ExecutorActor.scala +++ b/core/src/main/scala/org/apache/spark/executor/ExecutorActor.scala @@ -17,10 +17,8 @@ package org.apache.spark.executor -import akka.actor.Actor -import org.apache.spark.Logging - -import org.apache.spark.util.{Utils, ActorLogReceive} +import org.apache.spark.rpc.{RpcEnv, RpcEndpointRef, RpcEndpoint} +import org.apache.spark.util.Utils /** * Driver -> Executor message to trigger a thread dump. 
@@ -31,11 +29,11 @@ private[spark] case object TriggerThreadDump * Actor that runs inside of executors to enable driver -> executor RPC. */ private[spark] -class ExecutorActor(executorId: String) extends Actor with ActorLogReceive with Logging { +class ExecutorActor(override val rpcEnv: RpcEnv, executorId: String) extends RpcEndpoint { - override def receiveWithLogging = { + override def receive(sender: RpcEndpointRef) = { case TriggerThreadDump => - sender ! Utils.getThreadDump() + sender.send(Utils.getThreadDump()) } } diff --git a/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala new file mode 100644 index 000000000000..c911484f45b7 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala @@ -0,0 +1,370 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.rpc + +import java.util.concurrent.ConcurrentHashMap + +import scala.concurrent.Future +import scala.concurrent.duration.FiniteDuration +import scala.reflect.ClassTag + +import com.google.common.annotations.VisibleForTesting + +import org.apache.spark.{SecurityManager, SparkConf} +import org.apache.spark.util.Utils + +/** + * An RPC environment. + */ +trait RpcEnv { + + /** + * A lookup table to search a [[RpcEndpointRef]] for a [[RpcEndpoint]]. We need it to make + * [[RpcEndpoint.self]] work. + */ + private val endpointToRef = new ConcurrentHashMap[RpcEndpoint, RpcEndpointRef]() + + /** + * Need this map to remove `RpcEndpoint` from `endpointToRef` via a `RpcEndpointRef` + */ + private val refToEndpoint = new ConcurrentHashMap[RpcEndpointRef, RpcEndpoint]() + + protected def registerEndpoint(endpoint: RpcEndpoint, endpointRef: RpcEndpointRef): Unit = { + endpointToRef.put(endpoint, endpointRef) + refToEndpoint.put(endpointRef, endpoint) + } + + protected def unregisterEndpoint(endpointRef: RpcEndpointRef): Unit = { + val endpoint = refToEndpoint.remove(endpointRef) + if (endpoint != null) { + endpointToRef.remove(endpoint) + } + } + + /** + * Retrieve the [[RpcEndpointRef]] of `endpoint`. + */ + def endpointRef(endpoint: RpcEndpoint): RpcEndpointRef = { + val endpointRef = endpointToRef.get(endpoint) + require(endpointRef != null, s"Cannot find RpcEndpointRef of ${endpoint} in ${this}") + endpointRef + } + + /** + * Return the port that [[RpcEnv]] is listening to. + */ + def boundPort: Int + + /** + * Register a [[RpcEndpoint]] with a name and return its [[RpcEndpointRef]]. + */ + def setupEndpoint(name: String, endpointCreator: => RpcEndpoint): RpcEndpointRef + + /** + * Retrieve a [[RpcEndpointRef]] which is located in the driver via its name. + */ + def setupDriverEndpointRef(name: String): RpcEndpointRef + + /** + * Retrieve the [[RpcEndpointRef]] represented by `url`. 
+ */ + def setupEndpointRefByUrl(url: String): RpcEndpointRef + + /** + * Retrieve the [[RpcEndpointRef]] represented by `systemName`, `address` and `endpointName`. + */ + def setupEndpointRef( + systemName: String, address: RpcAddress, endpointName: String): RpcEndpointRef + + /** + * Stop the [[RpcEndpoint]] specified by `endpoint`. + */ + def stop(endpoint: RpcEndpointRef): Unit + + /** + * Shut down this [[RpcEnv]] asynchronously. If you need to make sure [[RpcEnv]] exits successfully, + * call [[awaitTermination()]] straight after [[stopAll()]]. + */ + def stopAll(): Unit + + /** + * Wait until [[RpcEnv]] exits. + * + * TODO do we need a timeout parameter? + */ + def awaitTermination(): Unit +} + +private[rpc] case class RpcEnvConfig( + conf: SparkConf, + name: String, + host: String, + port: Int, + securityManager: SecurityManager) + +/** + * A RpcEnv implementation must have a companion object with an + * `apply(config: RpcEnvConfig): RpcEnv` method so that it can be created via reflection. + * + * {{{ + * object MyCustomRpcEnv { + * def apply(config: RpcEnvConfig): RpcEnv = { + * ... + * } + * } + * }}} + */ +object RpcEnv { + + private def getRpcEnvCompanion(conf: SparkConf): AnyRef = { + // Add more RpcEnv implementations here + val rpcEnvNames = Map("akka" -> "org.apache.spark.rpc.akka.AkkaRpcEnv") + val rpcEnvName = conf.get("spark.rpc", "akka") + val rpcEnvClassName = rpcEnvNames.getOrElse(rpcEnvName.toLowerCase, rpcEnvName) + val companion = Class.forName( + rpcEnvClassName + "$", true, Utils.getContextOrSparkClassLoader).getField("MODULE$").get(null) + companion + } + + def create( + name: String, + host: String, + port: Int, + conf: SparkConf, + securityManager: SecurityManager): RpcEnv = { + // Use reflection to create the RpcEnv so that we do not depend on Akka directly + val config = RpcEnvConfig(conf, name, host, port, securityManager) + val companion = getRpcEnvCompanion(conf) + companion.getClass.getMethod("apply", classOf[RpcEnvConfig]). + invoke(companion, config).asInstanceOf[RpcEnv] + } + + // TODO Remove it + @VisibleForTesting + def create(name: String, conf: SparkConf): RpcEnv = { + val companion = getRpcEnvCompanion(conf) + companion.getClass.getMethod("apply", classOf[String], classOf[SparkConf]). + invoke(companion, name, conf).asInstanceOf[RpcEnv] + } + +} + +/** + * An endpoint for RPC that defines what functions to trigger given a message. + * + * An [[RpcEndpoint]] is guaranteed that `onStart`, `receive` and `onStop` will + * be called in sequence. + * + * The life-cycle will be: + * + * constructor onStart receive* onStop + * + * If any error is thrown from one of the RpcEndpoint methods except `onError`, [[RpcEndpoint.onError]] + * will be invoked with the cause. If `onError` throws an error, it will force [[RpcEndpoint]] to + * restart by creating a new one. + */ +trait RpcEndpoint { + + /** + * The [[RpcEnv]] that this [[RpcEndpoint]] is registered to. + */ + val rpcEnv: RpcEnv + + /** + * Provide the implicit sender. `self` will become valid when `onStart` is called. + * + * Note: because [[RpcEndpoint]] has not yet been registered before `onStart`, there is no + * valid [[RpcEndpointRef]] for it until then. So don't call `self` before `onStart` is called; in other + * words, don't call [[RpcEndpointRef.send]] in the constructor of [[RpcEndpoint]].
+ */ + implicit final def self: RpcEndpointRef = { + require(rpcEnv != null, "rpcEnv has not been initialized") + rpcEnv.endpointRef(this) + } + + /** + * Same assumption as Actor: messages sent to a RpcEndpoint will be delivered in sequence, and + * messages from the same RpcEndpoint will be delivered in order. + * + * @param sender the [[RpcEndpointRef]] of the message sender + * @return a partial function that handles messages sent to this endpoint + */ + def receive(sender: RpcEndpointRef): PartialFunction[Any, Unit] + + /** + * Invoked when any exception is thrown while handling messages. + * + * @param cause the exception that was thrown + */ + def onError(cause: Throwable): Unit = { + // By default, rethrow the cause and let RpcEnv handle it + throw cause + } + + /** + * Invoked before [[RpcEndpoint]] starts to handle any message. + */ + def onStart(): Unit = { + // By default, do nothing. + } + + /** + * Invoked when [[RpcEndpoint]] is stopping. + */ + def onStop(): Unit = { + // By default, do nothing. + } + + /** + * A convenient method to stop [[RpcEndpoint]]. + */ + final def stop(): Unit = { + rpcEnv.stop(self) + } +} + +/** + * A [[RpcEndpoint]] interested in network events. + * + * A [[NetworkRpcEndpoint]] is guaranteed that `onStart`, `receive`, `onConnected`, + * `onDisconnected`, `onNetworkError` and `onStop` will be called in sequence. + * + * The life-cycle will be: + * + * constructor onStart (receive|onConnected|onDisconnected|onNetworkError)* onStop + * + * If any error is thrown from `onConnected`, `onDisconnected` or `onNetworkError`, + * [[RpcEndpoint.onError]] will be invoked with the cause. If `onError` throws an error, it will + * force [[RpcEndpoint]] to restart by creating a new one. + */ +trait NetworkRpcEndpoint extends RpcEndpoint { + + /** + * Invoked when `remoteAddress` is connected to the current node. + */ + def onConnected(remoteAddress: RpcAddress): Unit = { + // By default, do nothing. + } + + /** + * Invoked when `remoteAddress` is lost. + */ + def onDisconnected(remoteAddress: RpcAddress): Unit = { + // By default, do nothing. + } + + /** + * Invoked when some network error happens in the connection between the current node and + * `remoteAddress`. + */ + def onNetworkError(cause: Throwable, remoteAddress: RpcAddress): Unit = { + // By default, do nothing. + } +} + +object RpcEndpoint { + final val noSender: RpcEndpointRef = null +} + +/** + * A reference for a remote [[RpcEndpoint]]. [[RpcEndpointRef]] is thread-safe. + */ +trait RpcEndpointRef { + + /** + * Return the address for the [[RpcEndpointRef]]. + */ + def address: RpcAddress + + /** + * Send a message to the corresponding [[RpcEndpoint]] and return a `Future` to receive the reply + * within a default timeout. + */ + def ask[T: ClassTag](message: Any): Future[T] + + /** + * Send a message to the corresponding [[RpcEndpoint]] and return a `Future` to receive the reply + * within the specified timeout. + */ + def ask[T: ClassTag](message: Any, timeout: FiniteDuration): Future[T] + + /** + * Send a message to the corresponding [[RpcEndpoint]] and get its result within a default + * timeout, or throw a SparkException if this fails even after the default number of retries. + * + * Note: this is a blocking action which may cost a lot of time, so don't call it in a message + * loop of [[RpcEndpoint]].
+ * + * @param message the message to send + * @tparam T type of the reply message + * @return the reply message from the corresponding [[RpcEndpoint]] + */ + def askWithReply[T](message: Any): T + + /** + * Send a message to the corresponding [[RpcEndpoint]] and get its result within a specified + * timeout, or throw a SparkException if this fails even after the specified number of retries. + * + * Note: this is a blocking action which may cost a lot of time, so don't call it in a message + * loop of [[RpcEndpoint]]. + * + * @param message the message to send + * @param timeout the timeout duration + * @tparam T type of the reply message + * @return the reply message from the corresponding [[RpcEndpoint]] + */ + def askWithReply[T](message: Any, timeout: FiniteDuration): T + + /** + * Sends a one-way asynchronous message. Fire-and-forget semantics. + * + * If invoked from within an [[RpcEndpoint]] then `self` is implicitly passed on as the + * 'sender' argument. If not, then no sender is available. + * + * This `sender` reference is then available in the receiving [[RpcEndpoint]] as the `sender` + * parameter of [[RpcEndpoint.receive]]. + */ + def send(message: Any)(implicit sender: RpcEndpointRef = RpcEndpoint.noSender): Unit +} + +/** + * Represents a host and port. + */ +case class RpcAddress(host: String, port: Int) { + // TODO do we need to add the type of RpcEnv in the address? + + val hostPort: String = host + ":" + port + + override val toString: String = hostPort +} + +object RpcAddress { + + /** + * Return the [[RpcAddress]] represented by `uri`. + */ + def fromURIString(uri: String): RpcAddress = { + val u = new java.net.URI(uri) + RpcAddress(u.getHost, u.getPort) + } + + def fromSparkURL(sparkUrl: String): RpcAddress = { + val (host, port) = Utils.extractHostPortFromSparkUrl(sparkUrl) + RpcAddress(host, port) + } +} diff --git a/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala b/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala new file mode 100644 index 000000000000..5a9c40e613b8 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala @@ -0,0 +1,235 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
+ */ + +package org.apache.spark.rpc.akka + +import java.util.concurrent.CountDownLatch + +import scala.concurrent.Await +import scala.concurrent.duration._ +import scala.concurrent.Future +import scala.language.postfixOps +import scala.reflect.ClassTag +import scala.util.control.NonFatal + +import _root_.akka.actor._ +import akka.pattern.{ask => akkaAsk} +import akka.remote._ +import com.google.common.annotations.VisibleForTesting + +import org.apache.spark.{Logging, SparkConf} +import org.apache.spark.rpc._ +import org.apache.spark.util.{ActorLogReceive, AkkaUtils} + +/** + * A RpcEnv implementation based on Akka. + * + * TODO Once we remove all usages of Akka in other place, we can move this file to a new project and + * remove Akka from the dependencies. + * + * @param actorSystem + * @param conf + * @param boundPort + */ +private[spark] class AkkaRpcEnv private ( + val actorSystem: ActorSystem, conf: SparkConf, val boundPort: Int) extends RpcEnv { + + private val defaultAddress: RpcAddress = { + val address = actorSystem.asInstanceOf[ExtendedActorSystem].provider.getDefaultAddress + AkkaUtils.akkaAddressToRpcAddress(address) + } + + override def setupEndpoint(name: String, endpointCreator: => RpcEndpoint): RpcEndpointRef = { + val latch = new CountDownLatch(1) + try { + @volatile var endpointRef: AkkaRpcEndpointRef = null + val actorRef = actorSystem.actorOf(Props(new Actor with ActorLogReceive with Logging { + + val endpoint = endpointCreator + // Wait until `endpointRef` is set. TODO better solution? + latch.await() + require(endpointRef != null) + registerEndpoint(endpoint, endpointRef) + + var isNetworkRpcEndpoint = false + + override def preStart(): Unit = { + if (endpoint.isInstanceOf[NetworkRpcEndpoint]) { + isNetworkRpcEndpoint = true + // Listen for remote client network events only when it's `NetworkRpcEndpoint` + context.system.eventStream.subscribe(self, classOf[RemotingLifecycleEvent]) + } + endpoint.onStart() + } + + override def receiveWithLogging: Receive = if (isNetworkRpcEndpoint) { + case AssociatedEvent(_, remoteAddress, _) => + try { + endpoint.asInstanceOf[NetworkRpcEndpoint]. + onConnected(akkaAddressToRpcAddress(remoteAddress)) + } catch { + case NonFatal(e) => endpoint.onError(e) + } + + case DisassociatedEvent(_, remoteAddress, _) => + try { + endpoint.asInstanceOf[NetworkRpcEndpoint]. + onDisconnected(akkaAddressToRpcAddress(remoteAddress)) + } catch { + case NonFatal(e) => endpoint.onError(e) + } + + case AssociationErrorEvent(cause, localAddress, remoteAddress, inbound, _) => + try { + endpoint.asInstanceOf[NetworkRpcEndpoint]. + onNetworkError(cause, akkaAddressToRpcAddress(remoteAddress)) + } catch { + case NonFatal(e) => endpoint.onError(e) + } + case e: RemotingLifecycleEvent => + // TODO ignore? 
+ + case message: Any => + logDebug("Received RPC message: " + message) + try { + val pf = endpoint.receive(new AkkaRpcEndpointRef(defaultAddress, sender(), conf)) + if (pf.isDefinedAt(message)) { + pf.apply(message) + } + } catch { + case NonFatal(e) => endpoint.onError(e) + } + } else { + case message: Any => + logDebug("Received RPC message: " + message) + try { + val pf = endpoint.receive(new AkkaRpcEndpointRef(defaultAddress, sender(), conf)) + if (pf.isDefinedAt(message)) { + pf.apply(message) + } + } catch { + case NonFatal(e) => endpoint.onError(e) + } + } + + override def postStop(): Unit = { + endpoint.onStop() + } + + }), name = name) + endpointRef = new AkkaRpcEndpointRef(defaultAddress, actorRef, conf) + endpointRef + } finally { + latch.countDown() + } + } + + private def akkaAddressToRpcAddress(address: Address): RpcAddress = { + RpcAddress(address.host.getOrElse(defaultAddress.host), + address.port.getOrElse(defaultAddress.port)) + } + + override def setupDriverEndpointRef(name: String): RpcEndpointRef = { + new AkkaRpcEndpointRef(defaultAddress, AkkaUtils.makeDriverRef(name, conf, actorSystem), conf) + } + + override def setupEndpointRefByUrl(url: String): RpcEndpointRef = { + val timeout = Duration.create(conf.getLong("spark.akka.lookupTimeout", 30), "seconds") + val ref = Await.result(actorSystem.actorSelection(url).resolveOne(timeout), timeout) + new AkkaRpcEndpointRef(defaultAddress, ref, conf) + } + + override def setupEndpointRef( + systemName: String, address: RpcAddress, endpointName: String): RpcEndpointRef = { + setupEndpointRefByUrl( + "akka.tcp://%s@%s:%s/user/%s".format(systemName, address.host, address.port, endpointName)) + } + + override def stopAll(): Unit = { + actorSystem.shutdown() + } + + override def stop(endpoint: RpcEndpointRef): Unit = { + require(endpoint.isInstanceOf[AkkaRpcEndpointRef]) + unregisterEndpoint(endpoint) + actorSystem.stop(endpoint.asInstanceOf[AkkaRpcEndpointRef].actorRef) + } + + override def awaitTermination(): Unit = { + actorSystem.awaitTermination() + } + + override def toString = s"${getClass.getSimpleName}($actorSystem)" +} + +private[rpc] object AkkaRpcEnv { + + def apply(config: RpcEnvConfig): RpcEnv = { + val (actorSystem, boundPort) = AkkaUtils.createActorSystem( + config.name, config.host, config.port, config.conf, config.securityManager) + new AkkaRpcEnv(actorSystem, config.conf, boundPort) + } + + // TODO Remove it + @VisibleForTesting + def apply(name: String, conf: SparkConf): AkkaRpcEnv = { + new AkkaRpcEnv(ActorSystem(name), conf, -1) + } +} + +private[akka] class AkkaRpcEndpointRef( + @transient defaultAddress: RpcAddress, + val actorRef: ActorRef, + @transient conf: SparkConf) extends RpcEndpointRef with Serializable with Logging { + // `defaultAddress` and `conf` won't be used after initialization. So it's safe to be transient. 
+ + private[this] val maxRetries = conf.getInt("spark.akka.num.retries", 3) + private[this] val retryWaitMs = conf.getLong("spark.akka.retry.wait", 3000) + private[this] val defaultTimeout = conf.getLong("spark.akka.lookupTimeout", 30) seconds + + override val address: RpcAddress = { + val akkaAddress = actorRef.path.address + RpcAddress(akkaAddress.host.getOrElse(defaultAddress.host), + akkaAddress.port.getOrElse(defaultAddress.port)) + } + + override def ask[T: ClassTag](message: Any): Future[T] = ask(message, defaultTimeout) + + override def ask[T: ClassTag](message: Any, timeout: FiniteDuration): Future[T] = { + actorRef.ask(message)(timeout).mapTo[T] + } + + override def askWithReply[T](message: Any): T = askWithReply(message, defaultTimeout) + + override def askWithReply[T](message: Any, timeout: FiniteDuration): T = { + // TODO: Consider removing multiple attempts + AkkaUtils.askWithReply(message, actorRef, maxRetries, retryWaitMs, timeout) + } + + override def send(message: Any)(implicit sender: RpcEndpointRef = RpcEndpoint.noSender): Unit = { + implicit val actorSender: ActorRef = + if (sender == null) { + Actor.noSender + } else { + require(sender.isInstanceOf[AkkaRpcEndpointRef]) + sender.asInstanceOf[AkkaRpcEndpointRef].actorRef + } + actorRef ! message + } + + override def toString: String = s"${getClass.getSimpleName}($actorRef)" +} diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index 8cb15918baa8..9e6f11686a4b 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -171,11 +171,8 @@ class DAGScheduler( taskMetrics: Array[(Long, Int, Int, TaskMetrics)], // (taskId, stageId, stateAttempt, metrics) blockManagerId: BlockManagerId): Boolean = { listenerBus.post(SparkListenerExecutorMetricsUpdate(execId, taskMetrics)) - implicit val timeout = Timeout(600 seconds) - - Await.result( - blockManagerMaster.driverActor ? BlockManagerHeartbeat(blockManagerId), - timeout.duration).asInstanceOf[Boolean] + blockManagerMaster.driverActor.askWithReply[Boolean]( + BlockManagerHeartbeat(blockManagerId), 600 seconds) } // Called by TaskScheduler when an executor fails. 
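As a reading aid (not part of the patch): below is a minimal sketch of how the RpcEnv / RpcEndpoint / RpcEndpointRef API introduced above might be used, assuming the default Akka-backed implementation selected by the "spark.rpc" setting. The Echo message, the EchoEndpoint class, the endpoint name and the host are hypothetical and exist only for illustration.

import org.apache.spark.{SecurityManager, SparkConf}
import org.apache.spark.rpc.{RpcEndpoint, RpcEndpointRef, RpcEnv}

// Hypothetical message and endpoint, used only to illustrate the API in this patch.
case class Echo(text: String)

class EchoEndpoint(override val rpcEnv: RpcEnv) extends RpcEndpoint {
  // Messages are handled one at a time; the reply goes back through the sender's ref.
  override def receive(sender: RpcEndpointRef): PartialFunction[Any, Unit] = {
    case Echo(text) => sender.send("echo: " + text)
  }
}

object EchoExample {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf()
    // Port 0 lets the underlying transport pick a free port.
    val rpcEnv = RpcEnv.create("echo", "localhost", 0, conf, new SecurityManager(conf))
    val ref = rpcEnv.setupEndpoint("echo-endpoint", new EchoEndpoint(rpcEnv))
    // Blocking request/reply, the same pattern the pages and backends above now use.
    val reply = ref.askWithReply[String](Echo("hello"))
    println(reply)
    rpcEnv.stopAll()          // shut the RpcEnv down asynchronously ...
    rpcEnv.awaitTermination() // ... and wait for it to exit
  }
}

This send/askWithReply pattern is what replaces the Akka `!`/`?` and Await.result calls throughout the rest of this patch; because the concrete transport is resolved via the "spark.rpc" configuration, none of the call sites reference Akka directly.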
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala index a1dfb0106259..b814fcecf540 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala @@ -18,14 +18,13 @@ package org.apache.spark.scheduler import java.nio.ByteBuffer +import java.util.concurrent.{TimeUnit, Executors} import java.util.{TimerTask, Timer} import java.util.concurrent.atomic.AtomicLong -import scala.concurrent.duration._ import scala.collection.mutable.ArrayBuffer import scala.collection.mutable.HashMap import scala.collection.mutable.HashSet -import scala.language.postfixOps import scala.util.Random import org.apache.spark._ @@ -142,11 +141,11 @@ private[spark] class TaskSchedulerImpl( if (!isLocal && conf.getBoolean("spark.speculation", false)) { logInfo("Starting speculative execution thread") - import sc.env.actorSystem.dispatcher - sc.env.actorSystem.scheduler.schedule(SPECULATION_INTERVAL milliseconds, - SPECULATION_INTERVAL milliseconds) { - Utils.tryOrExit { checkSpeculatableTasks() } - } + val scheduler = + Executors.newScheduledThreadPool(1, Utils.namedThreadFactory("task-scheduler-speculation")) + scheduler.scheduleAtFixedRate(new Runnable { + override def run(): Unit = Utils.tryOrExit { checkSpeculatableTasks() } + }, SPECULATION_INTERVAL, SPECULATION_INTERVAL, TimeUnit.MILLISECONDS) } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala index 1da6fe976da5..f2cc88b5a384 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedClusterMessage.scala @@ -20,6 +20,7 @@ package org.apache.spark.scheduler.cluster import java.nio.ByteBuffer import org.apache.spark.TaskState.TaskState +import org.apache.spark.rpc.RpcEndpointRef import org.apache.spark.util.{SerializableBuffer, Utils} private[spark] sealed trait CoarseGrainedClusterMessage extends Serializable @@ -39,8 +40,8 @@ private[spark] object CoarseGrainedClusterMessages { case class RegisterExecutorFailed(message: String) extends CoarseGrainedClusterMessage // Executors to driver - case class RegisterExecutor(executorId: String, hostPort: String, cores: Int) - extends CoarseGrainedClusterMessage { + case class RegisterExecutor(executorId: String, hostPort: String, cores: Int, + executorRef: RpcEndpointRef) extends CoarseGrainedClusterMessage { Utils.checkHostPort(hostPort, "Expected host port") } diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala index 5786d367464f..d107af1168a2 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala @@ -18,19 +18,15 @@ package org.apache.spark.scheduler.cluster import java.util.concurrent.atomic.AtomicInteger +import java.util.concurrent.{TimeUnit, Executors} import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet} -import scala.concurrent.Await -import scala.concurrent.duration._ - -import akka.actor._ -import akka.pattern.ask -import akka.remote.{DisassociatedEvent, 
RemotingLifecycleEvent} import org.apache.spark.{ExecutorAllocationClient, Logging, SparkEnv, SparkException, TaskState} +import org.apache.spark.rpc.{RpcAddress, NetworkRpcEndpoint, RpcEnv, RpcEndpointRef} import org.apache.spark.scheduler._ import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages._ -import org.apache.spark.util.{ActorLogReceive, SerializableBuffer, AkkaUtils, Utils} +import org.apache.spark.util.{SerializableBuffer, AkkaUtils, Utils} /** * A scheduler backend that waits for coarse grained executors to connect to it through Akka. @@ -41,7 +37,7 @@ import org.apache.spark.util.{ActorLogReceive, SerializableBuffer, AkkaUtils, Ut * (spark.deploy.*). */ private[spark] -class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val actorSystem: ActorSystem) +class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: RpcEnv) extends ExecutorAllocationClient with SchedulerBackend with Logging { // Use an atomic variable to track total number of cores in the cluster for simplicity and speed @@ -63,7 +59,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val actorSyste private val executorDataMap = new HashMap[String, ExecutorData] - // Number of executors requested from the cluster manager that have not registered yet + // Number of executors requested from the cluster manager that have not registered yet private var numPendingExecutors = 0 private val listenerBus = scheduler.sc.listenerBus @@ -71,34 +67,37 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val actorSyste // Executors we have requested the cluster manager to kill that have not died yet private val executorsPendingToRemove = new HashSet[String] - class DriverActor(sparkProperties: Seq[(String, String)]) extends Actor with ActorLogReceive { + class DriverActor(override val rpcEnv: RpcEnv, sparkProperties: Seq[(String, String)]) + extends NetworkRpcEndpoint with Logging { + override protected def log = CoarseGrainedSchedulerBackend.this.log - private val addressToExecutorId = new HashMap[Address, String] + private val addressToExecutorId = new HashMap[RpcAddress, String] - override def preStart() { - // Listen for remote client disconnection events, since they don't go through Akka's watch() - context.system.eventStream.subscribe(self, classOf[RemotingLifecycleEvent]) + private val reviveScheduler = + Executors.newScheduledThreadPool(1, Utils.namedThreadFactory("driver-revive-scheduler")) + override def onStart() { // Periodically revive offers to allow delay scheduling to work val reviveInterval = conf.getLong("spark.scheduler.revive.interval", 1000) - import context.dispatcher - context.system.scheduler.schedule(0.millis, reviveInterval.millis, self, ReviveOffers) + reviveScheduler.scheduleAtFixedRate(new Runnable { + override def run(): Unit = self.send(ReviveOffers) + }, 0, reviveInterval, TimeUnit.MILLISECONDS) } - def receiveWithLogging = { - case RegisterExecutor(executorId, hostPort, cores) => + def receive(sender: RpcEndpointRef) = { + case RegisterExecutor(executorId, hostPort, cores, executorRef) => Utils.checkHostPort(hostPort, "Host port expected " + hostPort) if (executorDataMap.contains(executorId)) { - sender ! RegisterExecutorFailed("Duplicate executor ID: " + executorId) + sender.send(RegisterExecutorFailed("Duplicate executor ID: " + executorId)) } else { logInfo("Registered executor: " + sender + " with ID " + executorId) - sender !
RegisteredExecutor + sender.send(RegisteredExecutor) - addressToExecutorId(sender.path.address) = executorId + addressToExecutorId(sender.address) = executorId totalCoreCount.addAndGet(cores) totalRegisteredExecutors.addAndGet(1) val (host, _) = Utils.parseHostPort(hostPort) - val data = new ExecutorData(sender, sender.path.address, host, cores, cores) + val data = new ExecutorData(sender, sender.address, host, cores, cores) // This must be synchronized because variables mutated // in this block are read when requesting executors CoarseGrainedSchedulerBackend.this.synchronized { @@ -132,33 +131,29 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val actorSyste case KillTask(taskId, executorId, interruptThread) => executorDataMap.get(executorId) match { case Some(executorInfo) => - executorInfo.executorActor ! KillTask(taskId, executorId, interruptThread) + executorInfo.executorActor.send(KillTask(taskId, executorId, interruptThread)) case None => // Ignoring the task kill since the executor is not registered. logWarning(s"Attempted to kill task $taskId for unknown executor $executorId.") } case StopDriver => - sender ! true - context.stop(self) + sender.send(true) + stop() case StopExecutors => logInfo("Asking each executor to shut down") for ((_, executorData) <- executorDataMap) { - executorData.executorActor ! StopExecutor + executorData.executorActor.send(StopExecutor) } - sender ! true + sender.send(true) case RemoveExecutor(executorId, reason) => removeExecutor(executorId, reason) - sender ! true - - case DisassociatedEvent(_, address, _) => - addressToExecutorId.get(address).foreach(removeExecutor(_, - "remote Akka client disassociated")) + sender.send(true) case RetrieveSparkProps => - sender ! sparkProperties + sender.send(sparkProperties) } // Make fake resource offers on all executors @@ -198,11 +193,16 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val actorSyste else { val executorData = executorDataMap(task.executorId) executorData.freeCores -= scheduler.CPUS_PER_TASK - executorData.executorActor !
LaunchTask(new SerializableBuffer(serializedTask)) + executorData.executorActor.send(LaunchTask(new SerializableBuffer(serializedTask))) } } } + override def onDisconnected(remoteAddress: RpcAddress): Unit = { + addressToExecutorId.get(remoteAddress).foreach(removeExecutor(_, + "remote Akka client disassociated")) + } + // Remove a disconnected slave from the cluster def removeExecutor(executorId: String, reason: String): Unit = { executorDataMap.get(executorId) match { @@ -222,7 +222,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val actorSyste } } - var driverActor: ActorRef = null + var driverActor: RpcEndpointRef = null val taskIdsOnSlave = new HashMap[String, HashSet[String]] override def start() { @@ -233,16 +233,15 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val actorSyste } } // TODO (prashant) send conf instead of properties - driverActor = actorSystem.actorOf( - Props(new DriverActor(properties)), name = CoarseGrainedSchedulerBackend.ACTOR_NAME) + driverActor = rpcEnv.setupEndpoint(CoarseGrainedSchedulerBackend.ACTOR_NAME, + new DriverActor(rpcEnv, properties)) } def stopExecutors() { try { if (driverActor != null) { logInfo("Shutting down all executors") - val future = driverActor.ask(StopExecutors)(timeout) - Await.ready(future, timeout) + driverActor.askWithReply(StopExecutors) } } catch { case e: Exception => @@ -254,8 +253,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val actorSyste stopExecutors() try { if (driverActor != null) { - val future = driverActor.ask(StopDriver)(timeout) - Await.ready(future, timeout) + driverActor.askWithReply(StopDriver) } } catch { case e: Exception => @@ -264,11 +262,11 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val actorSyste } override def reviveOffers() { - driverActor ! ReviveOffers + driverActor.send(ReviveOffers) } override def killTask(taskId: Long, executorId: String, interruptThread: Boolean) { - driverActor ! KillTask(taskId, executorId, interruptThread) + driverActor.send(KillTask(taskId, executorId, interruptThread)) } override def defaultParallelism(): Int = { @@ -278,8 +276,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val actorSyste // Called by subclasses when notified of a lost worker def removeExecutor(executorId: String, reason: String) { try { - val future = driverActor.ask(RemoveExecutor(executorId, reason))(timeout) - Await.ready(future, timeout) + driverActor.askWithReply(RemoveExecutor(executorId, reason)) } catch { case e: Exception => throw new SparkException("Error notifying standalone scheduler's driver actor", e) diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/ExecutorData.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/ExecutorData.scala index eb52ddfb1eab..9b245b23fe07 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/ExecutorData.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/ExecutorData.scala @@ -17,20 +17,20 @@ package org.apache.spark.scheduler.cluster -import akka.actor.{Address, ActorRef} +import org.apache.spark.rpc.{RpcAddress, RpcEndpointRef} /** * Grouping of data for an executor used by CoarseGrainedSchedulerBackend. 
* - * @param executorActor The ActorRef representing this executor + * @param executorActor The RpcEndpointRef representing this executor * @param executorAddress The network address of this executor * @param executorHost The hostname that this executor is running on * @param freeCores The current number of cores available for work on the executor * @param totalCores The total number of cores available to the executor */ private[cluster] class ExecutorData( - val executorActor: ActorRef, - val executorAddress: Address, + val executorActor: RpcEndpointRef, + val executorAddress: RpcAddress, override val executorHost: String, var freeCores: Int, override val totalCores: Int diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/SimrSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/SimrSchedulerBackend.scala index ee10aa061f4e..a851a01c07ba 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/SimrSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/SimrSchedulerBackend.scala @@ -27,7 +27,7 @@ private[spark] class SimrSchedulerBackend( scheduler: TaskSchedulerImpl, sc: SparkContext, driverFilePath: String) - extends CoarseGrainedSchedulerBackend(scheduler, sc.env.actorSystem) + extends CoarseGrainedSchedulerBackend(scheduler, sc.env.rpcEnv) with Logging { val tmpPath = new Path(driverFilePath + "_tmp") diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala index 7eb87a564d6f..bd97f94e92f5 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala @@ -17,6 +17,7 @@ package org.apache.spark.scheduler.cluster +import org.apache.spark.rpc.RpcAddress import org.apache.spark.{Logging, SparkConf, SparkContext, SparkEnv} import org.apache.spark.deploy.{ApplicationDescription, Command} import org.apache.spark.deploy.client.{AppClient, AppClientListener} @@ -27,7 +28,7 @@ private[spark] class SparkDeploySchedulerBackend( scheduler: TaskSchedulerImpl, sc: SparkContext, masters: Array[String]) - extends CoarseGrainedSchedulerBackend(scheduler, sc.env.actorSystem) + extends CoarseGrainedSchedulerBackend(scheduler, sc.env.rpcEnv) with AppClientListener with Logging { @@ -79,7 +80,8 @@ private[spark] class SparkDeploySchedulerBackend( val appDesc = new ApplicationDescription(sc.appName, maxCores, sc.executorMemory, command, appUIAddress, sc.eventLogDir) - client = new AppClient(sc.env.actorSystem, masters, appDesc, this, conf) + val masterAddresses = masters.map(RpcAddress.fromSparkURL).toSet + client = new AppClient(sc.env.rpcEnv, masterAddresses, appDesc, this, conf) client.start() waitForRegistration() diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala index f14aaeea0a25..37ac0e882b6b 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala @@ -18,18 +18,15 @@ package org.apache.spark.scheduler.cluster import scala.concurrent.{Future, ExecutionContext} - -import akka.actor.{Actor, ActorRef, Props} -import akka.remote.{DisassociatedEvent, RemotingLifecycleEvent} +import scala.util.control.NonFatal 
import org.apache.spark.SparkContext +import org.apache.spark.rpc.{RpcAddress, RpcEndpointRef, RpcEnv, NetworkRpcEndpoint} import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages._ import org.apache.spark.scheduler.TaskSchedulerImpl import org.apache.spark.ui.JettyUtils import org.apache.spark.util.{AkkaUtils, Utils} -import scala.util.control.NonFatal - /** * Abstract Yarn scheduler backend that contains common logic * between the client and cluster Yarn scheduler backends. @@ -37,7 +34,7 @@ import scala.util.control.NonFatal private[spark] abstract class YarnSchedulerBackend( scheduler: TaskSchedulerImpl, sc: SparkContext) - extends CoarseGrainedSchedulerBackend(scheduler, sc.env.actorSystem) { + extends CoarseGrainedSchedulerBackend(scheduler, sc.env.rpcEnv) { if (conf.getOption("spark.scheduler.minRegisteredResourcesRatio").isEmpty) { minRegisteredRatio = 0.8 @@ -45,10 +42,8 @@ private[spark] abstract class YarnSchedulerBackend( protected var totalExpectedExecutors = 0 - private val yarnSchedulerActor: ActorRef = - actorSystem.actorOf( - Props(new YarnSchedulerActor), - name = YarnSchedulerBackend.ACTOR_NAME) + private val yarnSchedulerActor: RpcEndpointRef = + rpcEnv.setupEndpoint(YarnSchedulerBackend.ACTOR_NAME, new YarnSchedulerActor(rpcEnv)) private implicit val askTimeout = AkkaUtils.askTimeout(sc.conf) @@ -57,16 +52,14 @@ private[spark] abstract class YarnSchedulerBackend( * This includes executors already pending or running. */ override def doRequestTotalExecutors(requestedTotal: Int): Boolean = { - AkkaUtils.askWithReply[Boolean]( - RequestExecutors(requestedTotal), yarnSchedulerActor, askTimeout) + yarnSchedulerActor.askWithReply(RequestExecutors(requestedTotal)) } /** * Request that the ApplicationMaster kill the specified executors. */ override def doKillExecutors(executorIds: Seq[String]): Boolean = { - AkkaUtils.askWithReply[Boolean]( - KillExecutors(executorIds), yarnSchedulerActor, askTimeout) + yarnSchedulerActor.askWithReply(KillExecutors(executorIds)) } override def sufficientResourcesRegistered(): Boolean = { @@ -98,18 +91,13 @@ private[spark] abstract class YarnSchedulerBackend( /** * An actor that communicates with the ApplicationMaster. */ - private class YarnSchedulerActor extends Actor { - private var amActor: Option[ActorRef] = None + private class YarnSchedulerActor(override val rpcEnv: RpcEnv) extends NetworkRpcEndpoint { + private var amActor: Option[RpcEndpointRef] = None implicit val askAmActorExecutor = ExecutionContext.fromExecutor( Utils.newDaemonCachedThreadPool("yarn-scheduler-ask-am-executor")) - override def preStart(): Unit = { - // Listen for disassociation events - context.system.eventStream.subscribe(self, classOf[RemotingLifecycleEvent]) - } - - override def receive = { + override def receive(sender: RpcEndpointRef) = { case RegisterClusterManager => logInfo(s"ApplicationMaster registered as $sender") amActor = Some(sender) @@ -117,39 +105,38 @@ private[spark] abstract class YarnSchedulerBackend( case r: RequestExecutors => amActor match { case Some(actor) => - val driverActor = sender Future { - driverActor ! AkkaUtils.askWithReply[Boolean](r, actor, askTimeout) + sender.send(actor.askWithReply[Boolean](r)) } onFailure { case NonFatal(e) => logError(s"Sending $r to AM was unsuccessful", e) } case None => logWarning("Attempted to request executors before the AM has registered!") - sender ! false + sender.send(false) } case k: KillExecutors => amActor match { case Some(actor) => - val driverActor = sender Future { - driverActor ! 
AkkaUtils.askWithReply[Boolean](k, actor, askTimeout) + sender.send(actor.askWithReply[Boolean](k)) } onFailure { case NonFatal(e) => logError(s"Sending $k to AM was unsuccessful", e) } case None => logWarning("Attempted to kill executors before the AM has registered!") - sender ! false + sender.send(false) } case AddWebUIFilter(filterName, filterParams, proxyBase) => addWebUIFilter(filterName, filterParams, proxyBase) - sender ! true + sender.send(true) + } - case d: DisassociatedEvent => - if (amActor.isDefined && sender == amActor.get) { - logWarning(s"ApplicationMaster has disassociated: $d") - } + override def onDisconnected(remoteAddress: RpcAddress): Unit = { + if (amActor.isDefined && remoteAddress == amActor.get.address) { + logWarning(s"ApplicationMaster has disassociated: $remoteAddress") + } } } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala index 5289661eb896..371c2debaa21 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala @@ -47,7 +47,7 @@ private[spark] class CoarseMesosSchedulerBackend( scheduler: TaskSchedulerImpl, sc: SparkContext, master: String) - extends CoarseGrainedSchedulerBackend(scheduler, sc.env.actorSystem) + extends CoarseGrainedSchedulerBackend(scheduler, sc.env.rpcEnv) with MScheduler with Logging { diff --git a/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala index 05b6fa54564b..bd2e0980e362 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala @@ -19,13 +19,11 @@ package org.apache.spark.scheduler.local import java.nio.ByteBuffer -import akka.actor.{Actor, ActorRef, Props} - import org.apache.spark.{Logging, SparkContext, SparkEnv, TaskState} import org.apache.spark.TaskState.TaskState import org.apache.spark.executor.{Executor, ExecutorBackend} +import org.apache.spark.rpc.{RpcEnv, RpcEndpointRef, RpcEndpoint} import org.apache.spark.scheduler.{SchedulerBackend, TaskSchedulerImpl, WorkerOffer} -import org.apache.spark.util.ActorLogReceive private case class ReviveOffers() @@ -41,10 +39,11 @@ private case class StopExecutor() * and the TaskSchedulerImpl. 
*/ private[spark] class LocalActor( + override val rpcEnv: RpcEnv, scheduler: TaskSchedulerImpl, executorBackend: LocalBackend, private val totalCores: Int) - extends Actor with ActorLogReceive with Logging { + extends RpcEndpoint with Logging { private var freeCores = totalCores @@ -54,7 +53,7 @@ private[spark] class LocalActor( private val executor = new Executor( localExecutorId, localExecutorHostname, SparkEnv.get, isLocal = true) - override def receiveWithLogging = { + override def receive(sender: RpcEndpointRef) = { case ReviveOffers => reviveOffers() @@ -91,31 +90,30 @@ private[spark] class LocalBackend(scheduler: TaskSchedulerImpl, val totalCores: extends SchedulerBackend with ExecutorBackend { private val appId = "local-" + System.currentTimeMillis - var localActor: ActorRef = null + var localActor: RpcEndpointRef = null override def start() { - localActor = SparkEnv.get.actorSystem.actorOf( - Props(new LocalActor(scheduler, this, totalCores)), - "LocalBackendActor") + localActor = SparkEnv.get.rpcEnv.setupEndpoint("LocalBackendActor", + new LocalActor(SparkEnv.get.rpcEnv, scheduler, this, totalCores)) } override def stop() { - localActor ! StopExecutor + localActor.send(StopExecutor) } override def reviveOffers() { - localActor ! ReviveOffers + localActor.send(ReviveOffers) } override def defaultParallelism() = scheduler.conf.getInt("spark.default.parallelism", totalCores) override def killTask(taskId: Long, executorId: String, interruptThread: Boolean) { - localActor ! KillTask(taskId, interruptThread) + localActor.send(KillTask(taskId, interruptThread)) } override def statusUpdate(taskId: Long, state: TaskState, serializedData: ByteBuffer) { - localActor ! StatusUpdate(taskId, state, serializedData) + localActor.send(StatusUpdate(taskId, state, serializedData)) } override def applicationId(): String = appId diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index 8bc5a1cd18b6..1f25338d5b74 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -26,7 +26,6 @@ import scala.concurrent.ExecutionContext.Implicits.global import scala.concurrent.duration._ import scala.util.Random -import akka.actor.{ActorSystem, Props} import sun.nio.ch.DirectBuffer import org.apache.spark._ @@ -37,6 +36,7 @@ import org.apache.spark.network.buffer.{ManagedBuffer, NioManagedBuffer} import org.apache.spark.network.netty.SparkTransportConf import org.apache.spark.network.shuffle.ExternalShuffleClient import org.apache.spark.network.shuffle.protocol.ExecutorShuffleInfo +import org.apache.spark.rpc.RpcEnv import org.apache.spark.serializer.Serializer import org.apache.spark.shuffle.ShuffleManager import org.apache.spark.shuffle.hash.HashShuffleManager @@ -64,7 +64,7 @@ private[spark] class BlockResult( */ private[spark] class BlockManager( executorId: String, - actorSystem: ActorSystem, + rpcEnv: RpcEnv, val master: BlockManagerMaster, defaultSerializer: Serializer, maxMemory: Long, @@ -136,9 +136,9 @@ private[spark] class BlockManager( // Whether to compress shuffle output temporarily spilled to disk private val compressShuffleSpill = conf.getBoolean("spark.shuffle.spill.compress", true) - private val slaveActor = actorSystem.actorOf( - Props(new BlockManagerSlaveActor(this, mapOutputTracker)), - name = "BlockManagerActor" + BlockManager.ID_GENERATOR.next) + private val slaveActor = rpcEnv.setupEndpoint( + name = 
"BlockManagerActor" + BlockManager.ID_GENERATOR.next, + new BlockManagerSlaveActor(rpcEnv, this, mapOutputTracker)) // Pending re-registration action being executed asynchronously or null if none is pending. // Accesses should synchronize on asyncReregisterLock. @@ -167,7 +167,7 @@ private[spark] class BlockManager( */ def this( execId: String, - actorSystem: ActorSystem, + rpcEnv: RpcEnv, master: BlockManagerMaster, serializer: Serializer, conf: SparkConf, @@ -176,7 +176,7 @@ private[spark] class BlockManager( blockTransferService: BlockTransferService, securityManager: SecurityManager, numUsableCores: Int) = { - this(execId, actorSystem, master, serializer, BlockManager.getMaxMemory(conf), + this(execId, rpcEnv, master, serializer, BlockManager.getMaxMemory(conf), conf, mapOutputTracker, shuffleManager, blockTransferService, securityManager, numUsableCores) } @@ -1206,7 +1206,7 @@ private[spark] class BlockManager( shuffleClient.close() } diskBlockManager.stop() - actorSystem.stop(slaveActor) + rpcEnv.stop(slaveActor) blockInfo.clear() memoryStore.clear() diskStore.clear() diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala index b63c7f191155..1906de736244 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala @@ -20,20 +20,17 @@ package org.apache.spark.storage import scala.concurrent.{Await, Future} import scala.concurrent.ExecutionContext.Implicits.global -import akka.actor._ - import org.apache.spark.{Logging, SparkConf, SparkException} +import org.apache.spark.rpc.RpcEndpointRef import org.apache.spark.storage.BlockManagerMessages._ import org.apache.spark.util.AkkaUtils private[spark] class BlockManagerMaster( - var driverActor: ActorRef, + var driverActor: RpcEndpointRef, conf: SparkConf, isDriver: Boolean) extends Logging { - private val AKKA_RETRY_ATTEMPTS: Int = AkkaUtils.numRetries(conf) - private val AKKA_RETRY_INTERVAL_MS: Int = AkkaUtils.retryWaitMs(conf) val DRIVER_AKKA_ACTOR_NAME = "BlockManagerMaster" @@ -46,7 +43,8 @@ class BlockManagerMaster( } /** Register the BlockManager's id with the driver. */ - def registerBlockManager(blockManagerId: BlockManagerId, maxMemSize: Long, slaveActor: ActorRef) { + def registerBlockManager(blockManagerId: BlockManagerId, maxMemSize: Long, + slaveActor: RpcEndpointRef) { logInfo("Trying to register BlockManager") tell(RegisterBlockManager(blockManagerId, maxMemSize, slaveActor)) logInfo("Registered BlockManager") @@ -218,8 +216,7 @@ class BlockManagerMaster( * throw a SparkException if this fails. 
*/ private def askDriverWithReply[T](message: Any): T = { - AkkaUtils.askWithReply(message, driverActor, AKKA_RETRY_ATTEMPTS, AKKA_RETRY_INTERVAL_MS, - timeout) + driverActor.askWithReply(message) } } diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala index 64133464d8da..5627c7a48d71 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala @@ -17,29 +17,34 @@ package org.apache.spark.storage +import java.util.concurrent.{ScheduledFuture, TimeUnit, Executors} import java.util.{HashMap => JHashMap} import scala.collection.mutable import scala.collection.JavaConversions._ -import scala.concurrent.Future -import scala.concurrent.duration._ +import scala.concurrent.{ExecutionContext, Future} -import akka.actor.{Actor, ActorRef, Cancellable} -import akka.pattern.ask - -import org.apache.spark.{Logging, SparkConf, SparkException} +import org.apache.spark.{Logging, SparkConf} import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.rpc.{RpcEnv, RpcEndpointRef, RpcEndpoint} import org.apache.spark.scheduler._ import org.apache.spark.storage.BlockManagerMessages._ -import org.apache.spark.util.{ActorLogReceive, AkkaUtils, Utils} +import org.apache.spark.util.Utils /** * BlockManagerMasterActor is an actor on the master node to track statuses of * all slaves' block managers. */ private[spark] -class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus: LiveListenerBus) - extends Actor with ActorLogReceive with Logging { +class BlockManagerMasterActor(override val rpcEnv: RpcEnv, val isLocal: Boolean, conf: SparkConf, + listenerBus: LiveListenerBus) extends RpcEndpoint with Logging { + + val scheduler = Executors.newScheduledThreadPool(1, + Utils.namedThreadFactory("block-manager-master-actor-heartbeat-scheduler")) + + implicit val executor = ExecutionContext.fromExecutor( + Executors.newFixedThreadPool(Runtime.getRuntime.availableProcessors(), + Utils.namedThreadFactory("block-manager-master-actor-ask-timeout-executor"))) // Mapping from block manager id to the block manager's information. private val blockManagerInfo = new mutable.HashMap[BlockManagerId, BlockManagerInfo] @@ -50,84 +55,82 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus // Mapping from block id to the set of block managers that have the block. 
private val blockLocations = new JHashMap[BlockId, mutable.HashSet[BlockManagerId]] - private val akkaTimeout = AkkaUtils.askTimeout(conf) - val slaveTimeout = conf.getLong("spark.storage.blockManagerSlaveTimeoutMs", 120 * 1000) val checkTimeoutInterval = conf.getLong("spark.storage.blockManagerTimeoutIntervalMs", 60000) - var timeoutCheckingTask: Cancellable = null + var timeoutCheckingTask: ScheduledFuture[_] = null - override def preStart() { - import context.dispatcher - timeoutCheckingTask = context.system.scheduler.schedule(0.seconds, - checkTimeoutInterval.milliseconds, self, ExpireDeadHosts) - super.preStart() + override def onStart() { + super.onStart() + timeoutCheckingTask = scheduler.scheduleAtFixedRate(new Runnable { + override def run(): Unit = self.send(ExpireDeadHosts) + }, 0, checkTimeoutInterval, TimeUnit.MILLISECONDS) } - override def receiveWithLogging = { + override def receive(sender: RpcEndpointRef) = { case RegisterBlockManager(blockManagerId, maxMemSize, slaveActor) => register(blockManagerId, maxMemSize, slaveActor) - sender ! true + sender.send(true) case UpdateBlockInfo( blockManagerId, blockId, storageLevel, deserializedSize, size, tachyonSize) => - sender ! updateBlockInfo( - blockManagerId, blockId, storageLevel, deserializedSize, size, tachyonSize) + sender.send(updateBlockInfo( + blockManagerId, blockId, storageLevel, deserializedSize, size, tachyonSize)) case GetLocations(blockId) => - sender ! getLocations(blockId) + sender.send(getLocations(blockId)) case GetLocationsMultipleBlockIds(blockIds) => - sender ! getLocationsMultipleBlockIds(blockIds) + sender.send(getLocationsMultipleBlockIds(blockIds)) case GetPeers(blockManagerId) => - sender ! getPeers(blockManagerId) + sender.send(getPeers(blockManagerId)) case GetActorSystemHostPortForExecutor(executorId) => - sender ! getActorSystemHostPortForExecutor(executorId) + sender.send(getActorSystemHostPortForExecutor(executorId)) case GetMemoryStatus => - sender ! memoryStatus + sender.send(memoryStatus) case GetStorageStatus => - sender ! storageStatus + sender.send(storageStatus) case GetBlockStatus(blockId, askSlaves) => - sender ! blockStatus(blockId, askSlaves) + sender.send(blockStatus(blockId, askSlaves)) case GetMatchingBlockIds(filter, askSlaves) => - sender ! getMatchingBlockIds(filter, askSlaves) + sender.send(getMatchingBlockIds(filter, askSlaves)) case RemoveRdd(rddId) => - sender ! removeRdd(rddId) + sender.send(removeRdd(rddId)) case RemoveShuffle(shuffleId) => - sender ! removeShuffle(shuffleId) + sender.send(removeShuffle(shuffleId)) case RemoveBroadcast(broadcastId, removeFromDriver) => - sender ! removeBroadcast(broadcastId, removeFromDriver) + sender.send(removeBroadcast(broadcastId, removeFromDriver)) case RemoveBlock(blockId) => removeBlockFromWorkers(blockId) - sender ! true + sender.send(true) case RemoveExecutor(execId) => removeExecutor(execId) - sender ! true + sender.send(true) case StopBlockManagerMaster => - sender ! true + sender.send(true) if (timeoutCheckingTask != null) { - timeoutCheckingTask.cancel() + timeoutCheckingTask.cancel(true) } - context.stop(self) + stop() case ExpireDeadHosts => expireDeadHosts() case BlockManagerHeartbeat(blockManagerId) => - sender ! 
heartbeatReceived(blockManagerId) + sender.send(heartbeatReceived(blockManagerId)) case other => logWarning("Got unknown message: " + other) @@ -148,22 +151,20 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus // Ask the slaves to remove the RDD, and put the result in a sequence of Futures. // The dispatcher is used as an implicit argument into the Future sequence construction. - import context.dispatcher val removeMsg = RemoveRdd(rddId) Future.sequence( blockManagerInfo.values.map { bm => - bm.slaveActor.ask(removeMsg)(akkaTimeout).mapTo[Int] + bm.slaveActor.ask[Int](removeMsg) }.toSeq ) } private def removeShuffle(shuffleId: Int): Future[Seq[Boolean]] = { // Nothing to do in the BlockManagerMasterActor data structures - import context.dispatcher val removeMsg = RemoveShuffle(shuffleId) Future.sequence( blockManagerInfo.values.map { bm => - bm.slaveActor.ask(removeMsg)(akkaTimeout).mapTo[Boolean] + bm.slaveActor.ask[Boolean](removeMsg) }.toSeq ) } @@ -174,14 +175,13 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus * from the executors, but not from the driver. */ private def removeBroadcast(broadcastId: Long, removeFromDriver: Boolean): Future[Seq[Int]] = { - import context.dispatcher val removeMsg = RemoveBroadcast(broadcastId, removeFromDriver) val requiredBlockManagers = blockManagerInfo.values.filter { info => removeFromDriver || !info.blockManagerId.isDriver } Future.sequence( requiredBlockManagers.map { bm => - bm.slaveActor.ask(removeMsg)(akkaTimeout).mapTo[Int] + bm.slaveActor.ask[Int](removeMsg) }.toSeq ) } @@ -251,7 +251,7 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus // Remove the block from the slave's BlockManager. // Doesn't actually wait for a confirmation and the message might get lost. // If message loss becomes frequent, we should add retry logic here. 
- blockManager.get.slaveActor.ask(RemoveBlock(blockId))(akkaTimeout) + blockManager.get.slaveActor.send(RemoveBlock(blockId)) } } } @@ -281,7 +281,6 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus private def blockStatus( blockId: BlockId, askSlaves: Boolean): Map[BlockManagerId, Future[Option[BlockStatus]]] = { - import context.dispatcher val getBlockStatus = GetBlockStatus(blockId) /* * Rather than blocking on the block status query, master actor should simply return @@ -291,7 +290,7 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus blockManagerInfo.values.map { info => val blockStatusFuture = if (askSlaves) { - info.slaveActor.ask(getBlockStatus)(akkaTimeout).mapTo[Option[BlockStatus]] + info.slaveActor.ask[Option[BlockStatus]](getBlockStatus) } else { Future { info.getStatus(blockId) } } @@ -310,13 +309,12 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus private def getMatchingBlockIds( filter: BlockId => Boolean, askSlaves: Boolean): Future[Seq[BlockId]] = { - import context.dispatcher val getMatchingBlockIds = GetMatchingBlockIds(filter) Future.sequence( blockManagerInfo.values.map { info => val future = if (askSlaves) { - info.slaveActor.ask(getMatchingBlockIds)(akkaTimeout).mapTo[Seq[BlockId]] + info.slaveActor.ask[Seq[BlockId]](getMatchingBlockIds) } else { Future { info.blocks.keys.filter(filter).toSeq } } @@ -325,7 +323,7 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus ).map(_.flatten.toSeq) } - private def register(id: BlockManagerId, maxMemSize: Long, slaveActor: ActorRef) { + private def register(id: BlockManagerId, maxMemSize: Long, slaveActor: RpcEndpointRef) { val time = System.currentTimeMillis() if (!blockManagerInfo.contains(id)) { blockManagerIdByExecutor.get(id.executorId) match { @@ -419,11 +417,9 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus private def getActorSystemHostPortForExecutor(executorId: String): Option[(String, Int)] = { for ( blockManagerId <- blockManagerIdByExecutor.get(executorId); - info <- blockManagerInfo.get(blockManagerId); - host <- info.slaveActor.path.address.host; - port <- info.slaveActor.path.address.port + info <- blockManagerInfo.get(blockManagerId) ) yield { - (host, port) + (info.slaveActor.address.host, info.slaveActor.address.port) } } } @@ -446,7 +442,7 @@ private[spark] class BlockManagerInfo( val blockManagerId: BlockManagerId, timeMs: Long, val maxMem: Long, - val slaveActor: ActorRef) + val slaveActor: RpcEndpointRef) extends Logging { private var _lastSeenMs: Long = timeMs diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala index 3f32099d08cc..26e1b311c032 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala @@ -19,8 +19,7 @@ package org.apache.spark.storage import java.io.{Externalizable, ObjectInput, ObjectOutput} -import akka.actor.ActorRef - +import org.apache.spark.rpc.RpcEndpointRef import org.apache.spark.util.Utils private[spark] object BlockManagerMessages { @@ -52,7 +51,7 @@ private[spark] object BlockManagerMessages { case class RegisterBlockManager( blockManagerId: BlockManagerId, maxMemSize: Long, - sender: ActorRef) + sender: RpcEndpointRef) extends ToBlockManagerMaster case class UpdateBlockInfo( diff --git 
a/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveActor.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveActor.scala index 8462871e798a..b2462e1a3291 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveActor.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveActor.scala @@ -17,13 +17,15 @@ package org.apache.spark.storage -import scala.concurrent.Future +import java.util.concurrent.Executors -import akka.actor.{ActorRef, Actor} +import scala.concurrent.ExecutionContext +import scala.concurrent.Future import org.apache.spark.{Logging, MapOutputTracker, SparkEnv} +import org.apache.spark.rpc.{RpcEnv, RpcEndpoint, RpcEndpointRef} import org.apache.spark.storage.BlockManagerMessages._ -import org.apache.spark.util.ActorLogReceive +import org.apache.spark.util.Utils /** * An actor to take commands from the master to execute options. For example, @@ -31,14 +33,15 @@ import org.apache.spark.util.ActorLogReceive */ private[storage] class BlockManagerSlaveActor( + override val rpcEnv: RpcEnv, blockManager: BlockManager, - mapOutputTracker: MapOutputTracker) - extends Actor with ActorLogReceive with Logging { + mapOutputTracker: MapOutputTracker) extends RpcEndpoint with Logging { - import context.dispatcher + implicit val executor = ExecutionContext.fromExecutorService(Executors.newScheduledThreadPool(1, + Utils.namedThreadFactory("block-manager-slave-actor-executor"))) // Operations that involve removing blocks may be slow and should be done asynchronously - override def receiveWithLogging = { + override def receive(sender: RpcEndpointRef) = { case RemoveBlock(blockId) => doAsync[Boolean]("removing block " + blockId, sender) { blockManager.removeBlock(blockId) @@ -64,25 +67,25 @@ class BlockManagerSlaveActor( } case GetBlockStatus(blockId, _) => - sender ! blockManager.getStatus(blockId) + sender.send(blockManager.getStatus(blockId)) case GetMatchingBlockIds(filter, _) => - sender ! blockManager.getMatchingBlockIds(filter) + sender.send(blockManager.getMatchingBlockIds(filter)) } - private def doAsync[T](actionMessage: String, responseActor: ActorRef)(body: => T) { + private def doAsync[T](actionMessage: String, responseActor: RpcEndpointRef)(body: => T) { val future = Future { logDebug(actionMessage) body } future.onSuccess { case response => logDebug("Done " + actionMessage + ", response is " + response) - responseActor ! response + responseActor.send(response) logDebug("Sent response: " + response + " to " + responseActor) } future.onFailure { case t: Throwable => logError("Error in " + actionMessage, t) - responseActor ! 
null.asInstanceOf[T] + responseActor.send(null.asInstanceOf[T]) } } } diff --git a/core/src/main/scala/org/apache/spark/util/ActorLogReceive.scala b/core/src/main/scala/org/apache/spark/util/ActorLogReceive.scala index 332d0cbb2dc0..142a32eb1dce 100644 --- a/core/src/main/scala/org/apache/spark/util/ActorLogReceive.scala +++ b/core/src/main/scala/org/apache/spark/util/ActorLogReceive.scala @@ -43,7 +43,7 @@ private[spark] trait ActorLogReceive { private val _receiveWithLogging = receiveWithLogging - override def isDefinedAt(o: Any): Boolean = _receiveWithLogging.isDefinedAt(o) + final override def isDefinedAt(o: Any): Boolean = _receiveWithLogging.isDefinedAt(o) override def apply(o: Any): Unit = { if (log.isDebugEnabled) { diff --git a/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala b/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala index 4c9b1e3c46f0..0f9055162b9f 100644 --- a/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala +++ b/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala @@ -17,11 +17,13 @@ package org.apache.spark.util +import org.apache.spark.rpc.RpcAddress + import scala.collection.JavaConversions.mapAsJavaMap import scala.concurrent.Await import scala.concurrent.duration.{Duration, FiniteDuration} -import akka.actor.{ActorRef, ActorSystem, ExtendedActorSystem} +import akka.actor.{Address, ActorRef, ActorSystem, ExtendedActorSystem} import akka.pattern.ask import com.typesafe.config.ConfigFactory @@ -178,7 +180,7 @@ private[spark] object AkkaUtils extends Logging { message: Any, actor: ActorRef, maxAttempts: Int, - retryInterval: Int, + retryInterval: Long, timeout: FiniteDuration): T = { // TODO: Consider removing multiple attempts if (actor == null) { @@ -233,4 +235,9 @@ private[spark] object AkkaUtils extends Logging { logInfo(s"Connecting to $name: $url") Await.result(actorSystem.actorSelection(url).resolveOne(timeout), timeout) } + + def akkaAddressToRpcAddress(akkaAddress: Address): RpcAddress = { + // TODO How to handle that a remoteAddress doesn't have host & port + RpcAddress(akkaAddress.host.getOrElse("localhost"), akkaAddress.port.getOrElse(-1)) + } } diff --git a/core/src/test/scala/org/apache/spark/HeartbeatReceiverSuite.scala b/core/src/test/scala/org/apache/spark/HeartbeatReceiverSuite.scala new file mode 100644 index 000000000000..0a8c1c5ea129 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/HeartbeatReceiverSuite.scala @@ -0,0 +1,68 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark + +import org.apache.spark.executor.TaskMetrics +import org.apache.spark.storage.BlockManagerId +import org.scalatest.FunSuite +import org.mockito.Mockito._ +import org.mockito.Matchers +import org.mockito.Matchers._ + +import org.apache.spark.scheduler.TaskScheduler + +class HeartbeatReceiverSuite extends FunSuite with LocalSparkContext { + + test("HeartbeatReceiver") { + sc = new SparkContext("local[2]", "test") + val scheduler = mock(classOf[TaskScheduler]) + when(scheduler.executorHeartbeatReceived(any(), any(), any())).thenReturn(true) + + sc.env.rpcEnv.setupEndpoint("heartbeat", new HeartbeatReceiver(sc.env.rpcEnv, scheduler)) + val receiverRef = sc.env.rpcEnv.setupDriverEndpointRef("heartbeat") + + val metrics = new TaskMetrics + metrics.jvmGCTime = 100 + val blockManagerId = BlockManagerId("executor-1", "localhost", 12345) + val response = receiverRef.askWithReply[HeartbeatResponse]( + Heartbeat("executor-1", Array(1L -> metrics), blockManagerId)) + + verify(scheduler).executorHeartbeatReceived( + Matchers.eq("executor-1"), Matchers.eq(Array(1L -> metrics)), Matchers.eq(blockManagerId)) + assert(false === response.reregisterBlockManager) + } + + test("HeartbeatReceiver re-register") { + sc = new SparkContext("local[2]", "test") + val scheduler = mock(classOf[TaskScheduler]) + when(scheduler.executorHeartbeatReceived(any(), any(), any())).thenReturn(false) + + sc.env.rpcEnv.setupEndpoint("heartbeat", new HeartbeatReceiver(sc.env.rpcEnv, scheduler)) + val receiverRef = sc.env.rpcEnv.setupDriverEndpointRef("heartbeat") + + val metrics = new TaskMetrics + metrics.jvmGCTime = 100 + val blockManagerId = BlockManagerId("executor-1", "localhost", 12345) + val response = receiverRef.askWithReply[HeartbeatResponse]( + Heartbeat("executor-1", Array(1L -> metrics), blockManagerId)) + + verify(scheduler).executorHeartbeatReceived( + Matchers.eq("executor-1"), Matchers.eq(Array(1L -> metrics)), Matchers.eq(blockManagerId)) + assert(true === response.reregisterBlockManager) + } +} diff --git a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala index d27880f4bc32..a06738321a39 100644 --- a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala +++ b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala @@ -17,34 +17,31 @@ package org.apache.spark -import scala.concurrent.Await - -import akka.actor._ -import akka.testkit.TestActorRef +import org.mockito.Mockito._ import org.scalatest.FunSuite +import org.apache.spark.rpc.{RpcAddress, RpcEnv, RpcEndpoint, RpcEndpointRef} import org.apache.spark.scheduler.{CompressedMapStatus, MapStatus} import org.apache.spark.shuffle.FetchFailedException import org.apache.spark.storage.BlockManagerId -import org.apache.spark.util.AkkaUtils class MapOutputTrackerSuite extends FunSuite { private val conf = new SparkConf test("master start and stop") { - val actorSystem = ActorSystem("test") + val rpcEnv = RpcEnv.create("test", conf) val tracker = new MapOutputTrackerMaster(conf) - tracker.trackerActor = - actorSystem.actorOf(Props(new MapOutputTrackerMasterActor(tracker, conf))) + tracker.trackerActor = rpcEnv.setupEndpoint( + "MapOutputTracker", new MapOutputTrackerMasterActor(rpcEnv, tracker, conf)) tracker.stop() - actorSystem.shutdown() + rpcEnv.stopAll() } test("master register shuffle and fetch") { - val actorSystem = ActorSystem("test") + val rpcEnv = RpcEnv.create("test", conf) val tracker = new MapOutputTrackerMaster(conf) - 
tracker.trackerActor = - actorSystem.actorOf(Props(new MapOutputTrackerMasterActor(tracker, conf))) + tracker.trackerActor = rpcEnv.setupEndpoint( + "MapOutputTracker", new MapOutputTrackerMasterActor(rpcEnv, tracker, conf)) tracker.registerShuffle(10, 2) assert(tracker.containsShuffle(10)) val size1000 = MapStatus.decompressSize(MapStatus.compressSize(1000L)) @@ -57,13 +54,14 @@ class MapOutputTrackerSuite extends FunSuite { assert(statuses.toSeq === Seq((BlockManagerId("a", "hostA", 1000), size1000), (BlockManagerId("b", "hostB", 1000), size10000))) tracker.stop() - actorSystem.shutdown() + rpcEnv.stopAll() } test("master register and unregister shuffle") { - val actorSystem = ActorSystem("test") + val rpcEnv = RpcEnv.create("test", conf) val tracker = new MapOutputTrackerMaster(conf) - tracker.trackerActor = actorSystem.actorOf(Props(new MapOutputTrackerMasterActor(tracker, conf))) + tracker.trackerActor = rpcEnv.setupEndpoint( + "MapOutputTracker", new MapOutputTrackerMasterActor(rpcEnv, tracker, conf)) tracker.registerShuffle(10, 2) val compressedSize1000 = MapStatus.compressSize(1000L) val compressedSize10000 = MapStatus.compressSize(10000L) @@ -78,14 +76,14 @@ class MapOutputTrackerSuite extends FunSuite { assert(tracker.getServerStatuses(10, 0).isEmpty) tracker.stop() - actorSystem.shutdown() + rpcEnv.stopAll() } test("master register shuffle and unregister map output and fetch") { - val actorSystem = ActorSystem("test") + val rpcEnv = RpcEnv.create("test", conf) val tracker = new MapOutputTrackerMaster(conf) - tracker.trackerActor = - actorSystem.actorOf(Props(new MapOutputTrackerMasterActor(tracker, conf))) + tracker.trackerActor = rpcEnv.setupEndpoint( + "MapOutputTracker", new MapOutputTrackerMasterActor(rpcEnv, tracker, conf)) tracker.registerShuffle(10, 2) val compressedSize1000 = MapStatus.compressSize(1000L) val compressedSize10000 = MapStatus.compressSize(10000L) @@ -104,25 +102,23 @@ class MapOutputTrackerSuite extends FunSuite { intercept[FetchFailedException] { tracker.getServerStatuses(10, 1) } tracker.stop() - actorSystem.shutdown() + rpcEnv.stopAll() } test("remote fetch") { val hostname = "localhost" - val (actorSystem, boundPort) = AkkaUtils.createActorSystem("spark", hostname, 0, conf = conf, + val rpcEnv = RpcEnv.create("spark", hostname, 0, conf = conf, securityManager = new SecurityManager(conf)) val masterTracker = new MapOutputTrackerMaster(conf) - masterTracker.trackerActor = actorSystem.actorOf( - Props(new MapOutputTrackerMasterActor(masterTracker, conf)), "MapOutputTracker") + masterTracker.trackerActor = rpcEnv.setupEndpoint( + "MapOutputTracker", new MapOutputTrackerMasterActor(rpcEnv, masterTracker, conf)) - val (slaveSystem, _) = AkkaUtils.createActorSystem("spark-slave", hostname, 0, conf = conf, + val slaveRpcEnv = RpcEnv.create("spark-slave", hostname, 0, conf = conf, securityManager = new SecurityManager(conf)) val slaveTracker = new MapOutputTrackerWorker(conf) - val selection = slaveSystem.actorSelection( - s"akka.tcp://spark@localhost:$boundPort/user/MapOutputTracker") - val timeout = AkkaUtils.lookupTimeout(conf) - slaveTracker.trackerActor = Await.result(selection.resolveOne(timeout), timeout) + val selection = s"akka.tcp://spark@localhost:${rpcEnv.boundPort}/user/MapOutputTracker" + slaveTracker.trackerActor = slaveRpcEnv.setupEndpointRefByUrl(selection) masterTracker.registerShuffle(10, 1) masterTracker.incrementEpoch() @@ -147,8 +143,8 @@ class MapOutputTrackerSuite extends FunSuite { masterTracker.stop() slaveTracker.stop() - 
actorSystem.shutdown() - slaveSystem.shutdown() + rpcEnv.stopAll() + slaveRpcEnv.stopAll() } test("remote fetch below akka frame size") { @@ -157,19 +153,19 @@ class MapOutputTrackerSuite extends FunSuite { newConf.set("spark.akka.askTimeout", "1") // Fail fast val masterTracker = new MapOutputTrackerMaster(conf) - val actorSystem = ActorSystem("test") - val actorRef = TestActorRef[MapOutputTrackerMasterActor]( - Props(new MapOutputTrackerMasterActor(masterTracker, newConf)))(actorSystem) - val masterActor = actorRef.underlyingActor + val rpcEnv = RpcEnv.create("test", conf) + val masterActor = new MapOutputTrackerMasterActor(rpcEnv, masterTracker, newConf) + rpcEnv.setupEndpoint("MapOutputTracker", masterActor) // Frame size should be ~123B, and no exception should be thrown masterTracker.registerShuffle(10, 1) masterTracker.registerMapOutput(10, 0, MapStatus( BlockManagerId("88", "mph", 1000), Array.fill[Long](10)(0))) - masterActor.receive(GetMapOutputStatuses(10)) + val sender = mock(classOf[RpcEndpointRef]) + masterActor.receive(sender).apply(GetMapOutputStatuses(10)) // masterTracker.stop() // this throws an exception - actorSystem.shutdown() + rpcEnv.stopAll() } test("remote fetch exceeds akka frame size") { @@ -178,10 +174,8 @@ class MapOutputTrackerSuite extends FunSuite { newConf.set("spark.akka.askTimeout", "1") // Fail fast val masterTracker = new MapOutputTrackerMaster(conf) - val actorSystem = ActorSystem("test") - val actorRef = TestActorRef[MapOutputTrackerMasterActor]( - Props(new MapOutputTrackerMasterActor(masterTracker, newConf)))(actorSystem) - val masterActor = actorRef.underlyingActor + val rpcEnv = RpcEnv.create("test", conf) + val masterActor = new MapOutputTrackerMasterActor(rpcEnv, masterTracker, newConf) // Frame size should be ~1.1MB, and MapOutputTrackerMasterActor should throw exception. // Note that the size is hand-selected here because map output statuses are compressed before @@ -191,9 +185,11 @@ class MapOutputTrackerSuite { masterTracker.registerMapOutput(20, i, new CompressedMapStatus( BlockManagerId("999", "mps", 1000), Array.fill[Long](4000000)(0))) } - intercept[SparkException] { masterActor.receive(GetMapOutputStatuses(20)) } + intercept[SparkException] { + masterActor.receive(RpcEndpoint.noSender).apply(GetMapOutputStatuses(20)) + } // masterTracker.stop() // this throws an exception - actorSystem.shutdown() + rpcEnv.stopAll() } } diff --git a/core/src/test/scala/org/apache/spark/deploy/master/MasterSuite.scala b/core/src/test/scala/org/apache/spark/deploy/master/MasterSuite.scala deleted file mode 100644 index 3d2335f9b363..000000000000 --- a/core/src/test/scala/org/apache/spark/deploy/master/MasterSuite.scala +++ /dev/null @@ -1,50 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.deploy.master - -import akka.actor.Address -import org.scalatest.FunSuite - -import org.apache.spark.SparkException - -class MasterSuite extends FunSuite { - - test("toAkkaUrl") { - val akkaUrl = Master.toAkkaUrl("spark://1.2.3.4:1234") - assert("akka.tcp://sparkMaster@1.2.3.4:1234/user/Master" === akkaUrl) - } - - test("toAkkaUrl: a typo url") { - val e = intercept[SparkException] { - Master.toAkkaUrl("spark://1.2. 3.4:1234") - } - assert("Invalid master URL: spark://1.2. 3.4:1234" === e.getMessage) - } - - test("toAkkaAddress") { - val address = Master.toAkkaAddress("spark://1.2.3.4:1234") - assert(Address("akka.tcp", "sparkMaster", "1.2.3.4", 1234) === address) - } - - test("toAkkaAddress: a typo url") { - val e = intercept[SparkException] { - Master.toAkkaAddress("spark://1.2. 3.4:1234") - } - assert("Invalid master URL: spark://1.2. 3.4:1234" === e.getMessage) - } -} diff --git a/core/src/test/scala/org/apache/spark/deploy/worker/WorkerWatcherSuite.scala b/core/src/test/scala/org/apache/spark/deploy/worker/WorkerWatcherSuite.scala index 5e538d6fab2a..8d5cb38ef729 100644 --- a/core/src/test/scala/org/apache/spark/deploy/worker/WorkerWatcherSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/worker/WorkerWatcherSuite.scala @@ -17,32 +17,35 @@ package org.apache.spark.deploy.worker -import akka.actor.{ActorSystem, AddressFromURIString, Props} -import akka.testkit.TestActorRef -import akka.remote.DisassociatedEvent +import akka.actor.AddressFromURIString +import org.apache.spark.SparkConf +import org.apache.spark.rpc.RpcEnv +import org.apache.spark.util.AkkaUtils import org.scalatest.FunSuite class WorkerWatcherSuite extends FunSuite { test("WorkerWatcher shuts down on valid disassociation") { - val actorSystem = ActorSystem("test") - val targetWorkerUrl = "akka://1.2.3.4/user/Worker" + val rpcEnv = RpcEnv.create("test", new SparkConf()) + val targetWorkerUrl = "akka://test@1.2.3.4:1234/user/Worker" val targetWorkerAddress = AddressFromURIString(targetWorkerUrl) - val actorRef = TestActorRef[WorkerWatcher](Props(classOf[WorkerWatcher], targetWorkerUrl))(actorSystem) - val workerWatcher = actorRef.underlyingActor + val workerWatcher = new WorkerWatcher(rpcEnv, targetWorkerUrl) workerWatcher.setTesting(testing = true) - actorRef.underlyingActor.receive(new DisassociatedEvent(null, targetWorkerAddress, false)) - assert(actorRef.underlyingActor.isShutDown) + rpcEnv.setupEndpoint("worker-watcher", workerWatcher) + workerWatcher.onDisconnected(AkkaUtils.akkaAddressToRpcAddress(targetWorkerAddress)) + assert(workerWatcher.isShutDown) + rpcEnv.stopAll() } test("WorkerWatcher stays alive on invalid disassociation") { - val actorSystem = ActorSystem("test") - val targetWorkerUrl = "akka://1.2.3.4/user/Worker" - val otherAkkaURL = "akka://4.3.2.1/user/OtherActor" + val rpcEnv = RpcEnv.create("test", new SparkConf()) + val targetWorkerUrl = "akka://test@1.2.3.4:1234/user/Worker" + val otherAkkaURL = "akka://test@4.3.2.1:1234/user/OtherActor" val otherAkkaAddress = AddressFromURIString(otherAkkaURL) - val actorRef = TestActorRef[WorkerWatcher](Props(classOf[WorkerWatcher], targetWorkerUrl))(actorSystem) - val workerWatcher = actorRef.underlyingActor + val workerWatcher = new WorkerWatcher(rpcEnv, targetWorkerUrl) workerWatcher.setTesting(testing = true) - actorRef.underlyingActor.receive(new DisassociatedEvent(null, otherAkkaAddress, false)) - 
assert(!actorRef.underlyingActor.isShutDown) + rpcEnv.setupEndpoint("worker-watcher", workerWatcher) + workerWatcher.onDisconnected(AkkaUtils.akkaAddressToRpcAddress(otherAkkaAddress)) + assert(!workerWatcher.isShutDown) + rpcEnv.stopAll() } } diff --git a/core/src/test/scala/org/apache/spark/executor/ExecutorActorSuite.scala b/core/src/test/scala/org/apache/spark/executor/ExecutorActorSuite.scala new file mode 100644 index 000000000000..96f7f7a2f214 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/executor/ExecutorActorSuite.scala @@ -0,0 +1,33 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.executor + +import org.apache.spark.util.ThreadStackTrace +import org.apache.spark.{LocalSparkContext, SparkContext} +import org.scalatest.FunSuite + +class ExecutorActorSuite extends FunSuite with LocalSparkContext { + + test("ExecutorActor") { + sc = new SparkContext("local[2]", "test") + sc.env.rpcEnv.setupEndpoint("executor-actor", new ExecutorActor(sc.env.rpcEnv, "executor-1")) + val receiverRef = sc.env.rpcEnv.setupDriverEndpointRef("executor-actor") + val response = receiverRef.askWithReply[Array[ThreadStackTrace]](TriggerThreadDump) + assert(response.size > 0) + } + } diff --git a/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala b/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala new file mode 100644 index 000000000000..7a7c4e06b35f --- /dev/null +++ b/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala @@ -0,0 +1,158 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.rpc + +import scala.concurrent.duration._ +import scala.language.postfixOps + +import org.scalatest.{BeforeAndAfterAll, FunSuite} +import org.scalatest.concurrent.Eventually._ + +/** + * Common tests for an RpcEnv implementation. 
+ */ +abstract class RpcEnvSuite extends FunSuite with BeforeAndAfterAll { + + var env: RpcEnv = _ + + override def beforeAll(): Unit = { + env = createRpcEnv + } + + override def afterAll(): Unit = { + if(env != null) { + env.stopAll() + } + } + + def createRpcEnv: RpcEnv + + test("send a message locally") { + @volatile var message: String = null + val rpcEndpointRef = env.setupEndpoint("send_test", new RpcEndpoint { + override val rpcEnv = env + + override def receive(sender: RpcEndpointRef) = { + case msg: String => message = msg + } + }) + rpcEndpointRef.send("hello") + Thread.sleep(2000) + assert("hello" === message) + } + + test("ask a message locally") { + val rpcEndpointRef = env.setupEndpoint("ask_test", new RpcEndpoint { + override val rpcEnv = env + + override def receive(sender: RpcEndpointRef) = { + case msg: String => sender.send(msg) + } + }) + val reply = rpcEndpointRef.askWithReply[String]("hello") + assert("hello" === reply) + } + + test("ping pong") { + case object Start + + case class Ping(id: Int) + + case class Pong(id: Int) + + val pongRef = env.setupEndpoint("pong", new RpcEndpoint { + override val rpcEnv = env + + override def receive(sender: RpcEndpointRef) = { + case Ping(id) => sender.send(Pong(id)) + } + }) + + val pingRef = env.setupEndpoint("ping", new RpcEndpoint { + override val rpcEnv = env + + var requester: RpcEndpointRef = _ + + override def receive(sender: RpcEndpointRef) = { + case Start => { + requester = sender + pongRef.send(Ping(1)) + } + case p @ Pong(id) => { + if (id < 10) { + sender.send(Ping(id + 1)) + } else { + requester.send(p) + } + } + } + }) + + val reply = pingRef.askWithReply[Pong](Start) + assert(Pong(10) === reply) + } + + test("register and unregister") { + val endpoint = new RpcEndpoint { + override val rpcEnv = env + + override def receive(sender: RpcEndpointRef) = { + case msg: String => sender.send(msg) + } + } + val rpcEndpointRef = env.setupEndpoint("register_test", endpoint) + + eventually(timeout(5 seconds), interval(200 milliseconds)) { + assert(rpcEndpointRef eq env.endpointRef(endpoint)) + } + endpoint.stop() + + val e = intercept[IllegalArgumentException] { + env.endpointRef(endpoint) + } + assert(e.getMessage.contains("Cannot find RpcEndpointRef")) + } + + test("fault tolerance") { + case class SetState(state: Int) + + case object Crash + + case object GetState + + val rpcEndpointRef = env.setupEndpoint("fault_tolerance", new RpcEndpoint { + override val rpcEnv = env + + var state: Int = 0 + + override def receive(sender: RpcEndpointRef) = { + case SetState(state) => this.state = state + case Crash => throw new RuntimeException("Oops") + case GetState => sender.send(state) + } + }) + assert(0 === rpcEndpointRef.askWithReply[Int](GetState)) + + rpcEndpointRef.send(SetState(10)) + assert(10 === rpcEndpointRef.askWithReply[Int](GetState)) + + rpcEndpointRef.send(Crash) + // RpcEndpoint is crashed. Should reset its state. + assert(0 === rpcEndpointRef.askWithReply[Int](GetState)) + } +} diff --git a/core/src/test/scala/org/apache/spark/rpc/akka/AkkaRpcEnvSuite.scala b/core/src/test/scala/org/apache/spark/rpc/akka/AkkaRpcEnvSuite.scala new file mode 100644 index 000000000000..8f6d2aaaa82d --- /dev/null +++ b/core/src/test/scala/org/apache/spark/rpc/akka/AkkaRpcEnvSuite.scala @@ -0,0 +1,50 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.rpc.akka + +import org.apache.spark.rpc._ +import org.apache.spark.{SecurityManager, SparkConf} + +class AkkaRpcEnvSuite extends RpcEnvSuite { + + override def createRpcEnv: RpcEnv = { + val conf = new SparkConf() + AkkaRpcEnv(RpcEnvConfig(conf, "test", "localhost", 12345, new SecurityManager(conf))) + } + + test("setupEndpointRef: systemName, address, endpointName") { + val ref = env.setupEndpoint("test_endpoint", new RpcEndpoint { + override val rpcEnv = env + + override def receive(sender: RpcEndpointRef) = { + case _ => + } + }) + val conf = new SparkConf() + val newRpcEnv = + AkkaRpcEnv(RpcEnvConfig(conf, "test", "localhost", 12346, new SecurityManager(conf))) + try { + val newRef = newRpcEnv.setupEndpointRef("test", ref.address, "test_endpoint") + assert("akka.tcp://test@localhost:12345/user/test_endpoint" === + newRef.asInstanceOf[AkkaRpcEndpointRef].actorRef.path.toString) + } finally { + newRpcEnv.stopAll() + } + } + +} diff --git a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala index 24f41bf8cccd..9475e4150baf 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala @@ -26,6 +26,7 @@ import org.scalatest.Matchers import org.apache.spark.{LocalSparkContext, SparkContext} import org.apache.spark.executor.TaskMetrics +import org.apache.spark.rpc.akka.AkkaRpcEnv import org.apache.spark.util.ResetSystemProperties class SparkListenerSuite extends FunSuite with LocalSparkContext with Matchers with BeforeAndAfter @@ -271,7 +272,8 @@ class SparkListenerSuite extends FunSuite with LocalSparkContext with Matchers // Make a task whose result is larger than the akka frame size System.setProperty("spark.akka.frameSize", "1") val akkaFrameSize = - sc.env.actorSystem.settings.config.getBytes("akka.remote.netty.tcp.maximum-frame-size").toInt + sc.env.rpcEnv.asInstanceOf[AkkaRpcEnv].actorSystem.settings. 
+ config.getBytes("akka.remote.netty.tcp.maximum-frame-size").toInt val result = sc.parallelize(Seq(1), 1) .map { x => 1.to(akkaFrameSize).toArray } .reduce { case (x, y) => x } diff --git a/core/src/test/scala/org/apache/spark/scheduler/TaskResultGetterSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/TaskResultGetterSuite.scala index e3a3803e6483..6578d2f39092 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/TaskResultGetterSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/TaskResultGetterSuite.scala @@ -27,6 +27,7 @@ import org.scalatest.{BeforeAndAfter, FunSuite} import org.scalatest.concurrent.Eventually._ import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext, SparkEnv} +import org.apache.spark.rpc.akka.AkkaRpcEnv import org.apache.spark.storage.TaskResultBlockId /** @@ -86,7 +87,8 @@ class TaskResultGetterSuite extends FunSuite with BeforeAndAfter with LocalSpark test("handling results larger than Akka frame size") { sc = new SparkContext("local", "test", conf) val akkaFrameSize = - sc.env.actorSystem.settings.config.getBytes("akka.remote.netty.tcp.maximum-frame-size").toInt + sc.env.rpcEnv.asInstanceOf[AkkaRpcEnv].actorSystem.settings. + config.getBytes("akka.remote.netty.tcp.maximum-frame-size").toInt val result = sc.parallelize(Seq(1), 1).map(x => 1.to(akkaFrameSize).toArray).reduce((x, y) => x) assert(result === 1.to(akkaFrameSize).toArray) @@ -111,7 +113,8 @@ class TaskResultGetterSuite extends FunSuite with BeforeAndAfter with LocalSpark val resultGetter = new ResultDeletingTaskResultGetter(sc.env, scheduler) scheduler.taskResultGetter = resultGetter val akkaFrameSize = - sc.env.actorSystem.settings.config.getBytes("akka.remote.netty.tcp.maximum-frame-size").toInt + sc.env.rpcEnv.asInstanceOf[AkkaRpcEnv].actorSystem.settings. 
+ config.getBytes("akka.remote.netty.tcp.maximum-frame-size").toInt val result = sc.parallelize(Seq(1), 1).map(x => 1.to(akkaFrameSize).toArray).reduce((x, y) => x) assert(resultGetter.removeBlockSuccessfully) assert(result === 1.to(akkaFrameSize).toArray) diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala index c2903c859799..efc796f4ab16 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerReplicationSuite.scala @@ -22,25 +22,24 @@ import scala.concurrent.duration._ import scala.language.implicitConversions import scala.language.postfixOps -import akka.actor.{ActorSystem, Props} import org.mockito.Mockito.{mock, when} -import org.scalatest.{BeforeAndAfter, FunSuite, Matchers, PrivateMethodTester} +import org.scalatest.{BeforeAndAfter, FunSuite, Matchers} import org.scalatest.concurrent.Eventually._ import org.apache.spark.{MapOutputTrackerMaster, SparkConf, SparkContext, SecurityManager} import org.apache.spark.network.BlockTransferService import org.apache.spark.network.nio.NioBlockTransferService +import org.apache.spark.rpc.RpcEnv import org.apache.spark.scheduler.LiveListenerBus import org.apache.spark.serializer.KryoSerializer import org.apache.spark.shuffle.hash.HashShuffleManager import org.apache.spark.storage.StorageLevel._ -import org.apache.spark.util.{AkkaUtils, SizeEstimator} /** Testsuite that tests block replication in BlockManager */ class BlockManagerReplicationSuite extends FunSuite with Matchers with BeforeAndAfter { private val conf = new SparkConf(false) - var actorSystem: ActorSystem = null + var rpcEnv: RpcEnv = null var master: BlockManagerMaster = null val securityMgr = new SecurityManager(conf) val mapOutputTracker = new MapOutputTrackerMaster(conf) @@ -61,7 +60,7 @@ class BlockManagerReplicationSuite extends FunSuite with Matchers with BeforeAnd maxMem: Long, name: String = SparkContext.DRIVER_IDENTIFIER): BlockManager = { val transfer = new NioBlockTransferService(conf, securityMgr) - val store = new BlockManager(name, actorSystem, master, serializer, maxMem, conf, + val store = new BlockManager(name, rpcEnv, master, serializer, maxMem, conf, mapOutputTracker, shuffleManager, transfer, securityMgr, 0) store.initialize("app-id") allStores += store @@ -69,12 +68,11 @@ class BlockManagerReplicationSuite extends FunSuite with Matchers with BeforeAnd } before { - val (actorSystem, boundPort) = AkkaUtils.createActorSystem( + this.rpcEnv = RpcEnv.create( "test", "localhost", 0, conf = conf, securityManager = securityMgr) - this.actorSystem = actorSystem conf.set("spark.authenticate", "false") - conf.set("spark.driver.port", boundPort.toString) + conf.set("spark.driver.port", rpcEnv.boundPort.toString) conf.set("spark.storage.unrollFraction", "0.4") conf.set("spark.storage.unrollMemoryThreshold", "512") @@ -84,7 +82,8 @@ class BlockManagerReplicationSuite extends FunSuite with Matchers with BeforeAnd conf.set("spark.storage.cachedPeersTtl", "10") master = new BlockManagerMaster( - actorSystem.actorOf(Props(new BlockManagerMasterActor(true, conf, new LiveListenerBus))), + rpcEnv.setupEndpoint("block-manager-master", + new BlockManagerMasterActor(rpcEnv, true, conf, new LiveListenerBus)), conf, true) allStores.clear() } @@ -92,9 +91,9 @@ class BlockManagerReplicationSuite extends FunSuite with Matchers with BeforeAnd after { allStores.foreach { 
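// For illustration only: the before/after shape that BlockManagerReplicationSuite
// here and BlockManagerSuite below now share, condensed into one standalone
// sketch under the RpcEnv API from this patch; "demo-endpoint" and
// "RpcEnvLifecycleSketch" are invented names.
import org.apache.spark.{SecurityManager, SparkConf}
import org.apache.spark.rpc.{RpcEndpoint, RpcEndpointRef, RpcEnv}

object RpcEnvLifecycleSketch {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf(false)
    // "before": one RpcEnv takes the place of the old ActorSystem, and its
    // bound port feeds spark.driver.port just as boundPort used to.
    val env = RpcEnv.create("test", "localhost", 0, conf = conf,
      securityManager = new SecurityManager(conf))
    conf.set("spark.driver.port", env.boundPort.toString)

    // Endpoints are registered by name instead of actorSystem.actorOf(Props(...)).
    val ref = env.setupEndpoint("demo-endpoint", new RpcEndpoint {
      override val rpcEnv = env
      override def receive(sender: RpcEndpointRef) = {
        case msg: String => sender.send(msg.reverse)
      }
    })
    assert(ref.askWithReply[String]("olleh") == "hello")

    // "after": stopAll + awaitTermination replace shutdown + awaitTermination.
    env.stopAll()
    env.awaitTermination()
  }
}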
_.stop() } allStores.clear() - actorSystem.shutdown() - actorSystem.awaitTermination() - actorSystem = null + rpcEnv.stopAll() + rpcEnv.awaitTermination() + rpcEnv = null master = null } @@ -262,7 +261,7 @@ class BlockManagerReplicationSuite extends FunSuite with Matchers with BeforeAnd val failableTransfer = mock(classOf[BlockTransferService]) // this wont actually work when(failableTransfer.hostName).thenReturn("some-hostname") when(failableTransfer.port).thenReturn(1000) - val failableStore = new BlockManager("failable-store", actorSystem, master, serializer, + val failableStore = new BlockManager("failable-store", rpcEnv, master, serializer, 10000, conf, mapOutputTracker, shuffleManager, failableTransfer, securityMgr, 0) failableStore.initialize("app-id") allStores += failableStore // so that this gets stopped after test diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala index ffe6f039145e..d736f2eb29d0 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala @@ -19,18 +19,12 @@ package org.apache.spark.storage import java.nio.{ByteBuffer, MappedByteBuffer} import java.util.Arrays -import java.util.concurrent.TimeUnit import scala.collection.mutable.ArrayBuffer -import scala.concurrent.Await import scala.concurrent.duration._ import scala.language.implicitConversions import scala.language.postfixOps -import akka.actor._ -import akka.pattern.ask -import akka.util.Timeout - import org.mockito.Mockito.{mock, when} import org.scalatest._ @@ -40,6 +34,7 @@ import org.scalatest.concurrent.Timeouts._ import org.apache.spark.{MapOutputTrackerMaster, SparkConf, SparkContext, SecurityManager} import org.apache.spark.executor.DataReadMethod import org.apache.spark.network.nio.NioBlockTransferService +import org.apache.spark.rpc.RpcEnv import org.apache.spark.scheduler.LiveListenerBus import org.apache.spark.serializer.{JavaSerializer, KryoSerializer} import org.apache.spark.shuffle.hash.HashShuffleManager @@ -53,7 +48,7 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfterEach private val conf = new SparkConf(false) var store: BlockManager = null var store2: BlockManager = null - var actorSystem: ActorSystem = null + var rpcEnv: RpcEnv = null var master: BlockManagerMaster = null conf.set("spark.authenticate", "false") val securityMgr = new SecurityManager(conf) @@ -72,27 +67,27 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfterEach maxMem: Long, name: String = SparkContext.DRIVER_IDENTIFIER): BlockManager = { val transfer = new NioBlockTransferService(conf, securityMgr) - val manager = new BlockManager(name, actorSystem, master, serializer, maxMem, conf, + val manager = new BlockManager(name, rpcEnv, master, serializer, maxMem, conf, mapOutputTracker, shuffleManager, transfer, securityMgr, 0) manager.initialize("app-id") manager } override def beforeEach(): Unit = { - val (actorSystem, boundPort) = AkkaUtils.createActorSystem( + this.rpcEnv = RpcEnv.create( "test", "localhost", 0, conf = conf, securityManager = securityMgr) - this.actorSystem = actorSystem // Set the arch to 64-bit and compressedOops to true to get a deterministic test-case System.setProperty("os.arch", "amd64") conf.set("os.arch", "amd64") conf.set("spark.test.useCompressedOops", "true") - conf.set("spark.driver.port", boundPort.toString) + conf.set("spark.driver.port", 
rpcEnv.boundPort.toString) conf.set("spark.storage.unrollFraction", "0.4") conf.set("spark.storage.unrollMemoryThreshold", "512") master = new BlockManagerMaster( - actorSystem.actorOf(Props(new BlockManagerMasterActor(true, conf, new LiveListenerBus))), + rpcEnv.setupEndpoint("block-manager-master", + new BlockManagerMasterActor(rpcEnv, true, conf, new LiveListenerBus)), conf, true) val initialize = PrivateMethod[Unit]('initialize) @@ -108,9 +103,9 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfterEach store2.stop() store2 = null } - actorSystem.shutdown() - actorSystem.awaitTermination() - actorSystem = null + rpcEnv.stopAll() + rpcEnv.awaitTermination() + rpcEnv = null master = null } @@ -357,11 +352,9 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfterEach master.removeExecutor(store.blockManagerId.executorId) assert(master.getLocations("a1").size == 0, "a1 was not removed from master") - implicit val timeout = Timeout(30, TimeUnit.SECONDS) - val reregister = !Await.result( - master.driverActor ? BlockManagerHeartbeat(store.blockManagerId), - timeout.duration).asInstanceOf[Boolean] - assert(reregister == true) + val reregister = ! master.driverActor.askWithReply[Boolean]( + BlockManagerHeartbeat(store.blockManagerId)) + assert(reregister === true) } test("reregistration on block update") { @@ -785,7 +778,7 @@ class BlockManagerSuite extends FunSuite with Matchers with BeforeAndAfterEach test("block store put failure") { // Use Java serializer so we can create an unserializable error. val transfer = new NioBlockTransferService(conf, securityMgr) - store = new BlockManager(SparkContext.DRIVER_IDENTIFIER, actorSystem, master, + store = new BlockManager(SparkContext.DRIVER_IDENTIFIER, rpcEnv, master, new JavaSerializer(conf), 1200, conf, mapOutputTracker, shuffleManager, transfer, securityMgr, 0) diff --git a/core/src/test/scala/org/apache/spark/util/AkkaUtilsSuite.scala b/core/src/test/scala/org/apache/spark/util/AkkaUtilsSuite.scala index 6bbf72e929dc..bbad58c11627 100644 --- a/core/src/test/scala/org/apache/spark/util/AkkaUtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/AkkaUtilsSuite.scala @@ -17,10 +17,7 @@ package org.apache.spark.util -import scala.concurrent.Await - -import akka.actor._ - +import org.apache.spark.rpc.RpcEnv import org.scalatest.FunSuite import org.apache.spark._ @@ -40,14 +37,13 @@ class AkkaUtilsSuite extends FunSuite with LocalSparkContext with ResetSystemPro val securityManager = new SecurityManager(conf) val hostname = "localhost" - val (actorSystem, boundPort) = AkkaUtils.createActorSystem("spark", hostname, 0, + val rpcEnv = RpcEnv.create("spark", hostname, 0, conf = conf, securityManager = securityManager) - System.setProperty("spark.hostPort", hostname + ":" + boundPort) + System.setProperty("spark.hostPort", hostname + ":" + rpcEnv.boundPort) assert(securityManager.isAuthenticationEnabled() === true) - val masterTracker = new MapOutputTrackerMaster(conf) - masterTracker.trackerActor = actorSystem.actorOf( - Props(new MapOutputTrackerMasterActor(masterTracker, conf)), "MapOutputTracker") + masterTracker.trackerActor = rpcEnv.setupEndpoint( + "MapOutputTracker", new MapOutputTrackerMasterActor(rpcEnv, masterTracker, conf)) val badconf = new SparkConf badconf.set("spark.authenticate", "true") @@ -56,49 +52,45 @@ class AkkaUtilsSuite extends FunSuite with LocalSparkContext with ResetSystemPro assert(securityManagerBad.isAuthenticationEnabled() === true) - val (slaveSystem, _) = 
AkkaUtils.createActorSystem("spark-slave", hostname, 0, + val slaveRpcEnv = RpcEnv.create("spark-slave", hostname, 0, conf = conf, securityManager = securityManagerBad) val slaveTracker = new MapOutputTrackerWorker(conf) - val selection = slaveSystem.actorSelection( - s"akka.tcp://spark@localhost:$boundPort/user/MapOutputTracker") - val timeout = AkkaUtils.lookupTimeout(conf) + val selection = s"akka.tcp://spark@localhost:${rpcEnv.boundPort}/user/MapOutputTracker" intercept[akka.actor.ActorNotFound] { - slaveTracker.trackerActor = Await.result(selection.resolveOne(timeout), timeout) + slaveTracker.trackerActor = slaveRpcEnv.setupEndpointRefByUrl(selection) } - actorSystem.shutdown() - slaveSystem.shutdown() + rpcEnv.stopAll() + slaveRpcEnv.stopAll() } test("remote fetch security off") { val conf = new SparkConf conf.set("spark.authenticate", "false") conf.set("spark.authenticate.secret", "bad") - val securityManager = new SecurityManager(conf); + val securityManager = new SecurityManager(conf) val hostname = "localhost" - val (actorSystem, boundPort) = AkkaUtils.createActorSystem("spark", hostname, 0, + val rpcEnv = RpcEnv.create("spark", hostname, 0, conf = conf, securityManager = securityManager) - System.setProperty("spark.hostPort", hostname + ":" + boundPort) + System.setProperty("spark.hostPort", hostname + ":" + rpcEnv.boundPort) assert(securityManager.isAuthenticationEnabled() === false) val masterTracker = new MapOutputTrackerMaster(conf) - masterTracker.trackerActor = actorSystem.actorOf( - Props(new MapOutputTrackerMasterActor(masterTracker, conf)), "MapOutputTracker") + masterTracker.trackerActor = rpcEnv.setupEndpoint( + "MapOutputTracker", new MapOutputTrackerMasterActor(rpcEnv, masterTracker, conf)) val badconf = new SparkConf badconf.set("spark.authenticate", "false") badconf.set("spark.authenticate.secret", "good") - val securityManagerBad = new SecurityManager(badconf); + val securityManagerBad = new SecurityManager(badconf) - val (slaveSystem, _) = AkkaUtils.createActorSystem("spark-slave", hostname, 0, + val slaveRpcEnv = RpcEnv.create("spark-slave", hostname, 0, conf = badconf, securityManager = securityManagerBad) val slaveTracker = new MapOutputTrackerWorker(conf) - val selection = slaveSystem.actorSelection( - s"akka.tcp://spark@localhost:$boundPort/user/MapOutputTracker") - val timeout = AkkaUtils.lookupTimeout(conf) - slaveTracker.trackerActor = Await.result(selection.resolveOne(timeout), timeout) + val selection = s"akka.tcp://spark@localhost:${rpcEnv.boundPort}/user/MapOutputTracker" + slaveTracker.trackerActor = slaveRpcEnv.setupEndpointRefByUrl(selection) assert(securityManagerBad.isAuthenticationEnabled() === false) @@ -116,41 +108,39 @@ class AkkaUtilsSuite extends FunSuite with LocalSparkContext with ResetSystemPro assert(slaveTracker.getServerStatuses(10, 0).toSeq === Seq((BlockManagerId("a", "hostA", 1000), size1000))) - actorSystem.shutdown() - slaveSystem.shutdown() + rpcEnv.stopAll() + slaveRpcEnv.stopAll() } test("remote fetch security pass") { val conf = new SparkConf conf.set("spark.authenticate", "true") conf.set("spark.authenticate.secret", "good") - val securityManager = new SecurityManager(conf); + val securityManager = new SecurityManager(conf) val hostname = "localhost" - val (actorSystem, boundPort) = AkkaUtils.createActorSystem("spark", hostname, 0, + val rpcEnv = RpcEnv.create("spark", hostname, 0, conf = conf, securityManager = securityManager) - System.setProperty("spark.hostPort", hostname + ":" + boundPort) + 
System.setProperty("spark.hostPort", hostname + ":" + rpcEnv.boundPort) assert(securityManager.isAuthenticationEnabled() === true) val masterTracker = new MapOutputTrackerMaster(conf) - masterTracker.trackerActor = actorSystem.actorOf( - Props(new MapOutputTrackerMasterActor(masterTracker, conf)), "MapOutputTracker") + masterTracker.trackerActor = rpcEnv.setupEndpoint( + "MapOutputTracker", new MapOutputTrackerMasterActor(rpcEnv, masterTracker, conf)) val goodconf = new SparkConf goodconf.set("spark.authenticate", "true") goodconf.set("spark.authenticate.secret", "good") - val securityManagerGood = new SecurityManager(goodconf); + val securityManagerGood = new SecurityManager(goodconf) assert(securityManagerGood.isAuthenticationEnabled() === true) - val (slaveSystem, _) = AkkaUtils.createActorSystem("spark-slave", hostname, 0, + val slaveRpcEnv = RpcEnv.create("spark-slave", hostname, 0, conf = goodconf, securityManager = securityManagerGood) val slaveTracker = new MapOutputTrackerWorker(conf) - val selection = slaveSystem.actorSelection( - s"akka.tcp://spark@localhost:$boundPort/user/MapOutputTracker") - val timeout = AkkaUtils.lookupTimeout(conf) - slaveTracker.trackerActor = Await.result(selection.resolveOne(timeout), timeout) + val selection = s"akka.tcp://spark@localhost:${rpcEnv.boundPort}/user/MapOutputTracker" + slaveTracker.trackerActor = slaveRpcEnv.setupEndpointRefByUrl(selection) masterTracker.registerShuffle(10, 1) masterTracker.incrementEpoch() @@ -166,8 +156,8 @@ class AkkaUtilsSuite extends FunSuite with LocalSparkContext with ResetSystemPro assert(slaveTracker.getServerStatuses(10, 0).toSeq === Seq((BlockManagerId("a", "hostA", 1000), size1000))) - actorSystem.shutdown() - slaveSystem.shutdown() + rpcEnv.stopAll() + slaveRpcEnv.stopAll() } test("remote fetch security off client") { @@ -175,38 +165,36 @@ class AkkaUtilsSuite extends FunSuite with LocalSparkContext with ResetSystemPro conf.set("spark.authenticate", "true") conf.set("spark.authenticate.secret", "good") - val securityManager = new SecurityManager(conf); + val securityManager = new SecurityManager(conf) val hostname = "localhost" - val (actorSystem, boundPort) = AkkaUtils.createActorSystem("spark", hostname, 0, + val rpcEnv = RpcEnv.create("spark", hostname, 0, conf = conf, securityManager = securityManager) - System.setProperty("spark.hostPort", hostname + ":" + boundPort) + System.setProperty("spark.hostPort", hostname + ":" + rpcEnv.boundPort) assert(securityManager.isAuthenticationEnabled() === true) val masterTracker = new MapOutputTrackerMaster(conf) - masterTracker.trackerActor = actorSystem.actorOf( - Props(new MapOutputTrackerMasterActor(masterTracker, conf)), "MapOutputTracker") + masterTracker.trackerActor = rpcEnv.setupEndpoint( + "MapOutputTracker", new MapOutputTrackerMasterActor(rpcEnv, masterTracker, conf)) val badconf = new SparkConf badconf.set("spark.authenticate", "false") badconf.set("spark.authenticate.secret", "bad") - val securityManagerBad = new SecurityManager(badconf); + val securityManagerBad = new SecurityManager(badconf) assert(securityManagerBad.isAuthenticationEnabled() === false) - val (slaveSystem, _) = AkkaUtils.createActorSystem("spark-slave", hostname, 0, + val slaveRpcEnv = RpcEnv.create("spark-slave", hostname, 0, conf = badconf, securityManager = securityManagerBad) val slaveTracker = new MapOutputTrackerWorker(conf) - val selection = slaveSystem.actorSelection( - s"akka.tcp://spark@localhost:$boundPort/user/MapOutputTracker") - val timeout = AkkaUtils.lookupTimeout(conf) 
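// For illustration only: the by-URL lookup these AkkaUtilsSuite hunks converge
// on, sketched standalone with two RpcEnvs standing in for driver and executor.
// It assumes the Akka-backed RpcEnv, where endpoint URLs keep the familiar
// akka.tcp://<system>@<host>:<port>/user/<name> shape; "pong-server",
// "spark-worker", and "RemoteLookupSketch" are invented names.
import org.apache.spark.{SecurityManager, SparkConf}
import org.apache.spark.rpc.{RpcEndpoint, RpcEndpointRef, RpcEnv}

object RemoteLookupSketch {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf()
    // The "driver" side registers an endpoint under a well-known name.
    val driverEnv = RpcEnv.create("spark", "localhost", 0, conf = conf,
      securityManager = new SecurityManager(conf))
    driverEnv.setupEndpoint("pong-server", new RpcEndpoint {
      override val rpcEnv = driverEnv
      override def receive(sender: RpcEndpointRef) = {
        case "ping" => sender.send("pong")
      }
    })

    // The "executor" side resolves that endpoint purely from its URL.
    val workerEnv = RpcEnv.create("spark-worker", "localhost", 0, conf = conf,
      securityManager = new SecurityManager(conf))
    val url = s"akka.tcp://spark@localhost:${driverEnv.boundPort}/user/pong-server"
    val ref = workerEnv.setupEndpointRefByUrl(url)
    assert(ref.askWithReply[String]("ping") == "pong")

    workerEnv.stopAll()
    driverEnv.stopAll()
  }
}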
+ val selection = s"akka.tcp://spark@localhost:${rpcEnv.boundPort}/user/MapOutputTracker" intercept[akka.actor.ActorNotFound] { - slaveTracker.trackerActor = Await.result(selection.resolveOne(timeout), timeout) + slaveTracker.trackerActor = slaveRpcEnv.setupEndpointRefByUrl(selection) } - actorSystem.shutdown() - slaveSystem.shutdown() + rpcEnv.stopAll() + slaveRpcEnv.stopAll() } } diff --git a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala index 4544382094f9..c0aed252f2a1 100644 --- a/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/UtilsSuite.scala @@ -29,7 +29,7 @@ import com.google.common.base.Charsets.UTF_8 import com.google.common.io.Files import org.scalatest.FunSuite -import org.apache.spark.SparkConf +import org.apache.spark.{SparkException, SparkConf} class UtilsSuite extends FunSuite with ResetSystemProperties { @@ -381,4 +381,18 @@ class UtilsSuite extends FunSuite with ResetSystemProperties { require(cnt === 2, "prepare should be called twice") require(time < 500, "preparation time should not count") } + + test("extractHostPortFromSparkUrl") { + val (host, port) = Utils.extractHostPortFromSparkUrl("spark://1.2.3.4:1234") + assert("1.2.3.4" === host) + assert(1234 === port) + } + + test("extractHostPortFromSparkUrl: a typo url") { + val e = intercept[SparkException] { + Utils.extractHostPortFromSparkUrl("spark://1.2. 3.4:1234") + } + assert("Invalid master URL: spark://1.2. 3.4:1234" === e.getMessage) + } + } diff --git a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala index 132ff2443fc0..7864b2e712e4 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/ReceivedBlockHandlerSuite.scala @@ -24,7 +24,6 @@ import scala.collection.mutable.ArrayBuffer import scala.concurrent.duration._ import scala.language.postfixOps -import akka.actor.{ActorSystem, Props} import com.google.common.io.Files import org.apache.commons.io.FileUtils import org.apache.hadoop.conf.Configuration @@ -33,6 +32,7 @@ import org.scalatest.concurrent.Eventually._ import org.apache.spark._ import org.apache.spark.network.nio.NioBlockTransferService +import org.apache.spark.rpc.RpcEnv import org.apache.spark.scheduler.LiveListenerBus import org.apache.spark.serializer.KryoSerializer import org.apache.spark.shuffle.hash.HashShuffleManager @@ -56,7 +56,7 @@ class ReceivedBlockHandlerSuite extends FunSuite with BeforeAndAfter with Matche val manualClock = new ManualClock val blockManagerSize = 10000000 - var actorSystem: ActorSystem = null + var rpcEnv: RpcEnv = null var blockManagerMaster: BlockManagerMaster = null var blockManager: BlockManager = null var tempDirectory: File = null @@ -64,14 +64,14 @@ class ReceivedBlockHandlerSuite extends FunSuite with BeforeAndAfter with Matche before { val (actorSystem, boundPort) = AkkaUtils.createActorSystem( "test", "localhost", 0, conf = conf, securityManager = securityMgr) - this.actorSystem = actorSystem - conf.set("spark.driver.port", boundPort.toString) + rpcEnv = RpcEnv.create("test", "localhost", 0, conf = conf, securityManager = securityMgr) + conf.set("spark.driver.port", rpcEnv.boundPort.toString) - blockManagerMaster = new BlockManagerMaster( - actorSystem.actorOf(Props(new BlockManagerMasterActor(true, conf, new 
LiveListenerBus))), + blockManagerMaster = new BlockManagerMaster(rpcEnv.setupEndpoint("block-manager-master-actor", + new BlockManagerMasterActor(rpcEnv, true, conf, new LiveListenerBus)), conf, true) - blockManager = new BlockManager("bm", actorSystem, blockManagerMaster, serializer, + blockManager = new BlockManager("bm", rpcEnv, blockManagerMaster, serializer, blockManagerSize, conf, mapOutputTracker, shuffleManager, new NioBlockTransferService(conf, securityMgr), securityMgr, 0) blockManager.initialize("app-id") @@ -89,9 +89,9 @@ class ReceivedBlockHandlerSuite extends FunSuite with BeforeAndAfter with Matche blockManagerMaster.stop() blockManagerMaster = null } - actorSystem.shutdown() - actorSystem.awaitTermination() - actorSystem = null + rpcEnv.stopAll() + rpcEnv.awaitTermination() + rpcEnv = null if (tempDirectory != null && tempDirectory.exists()) { FileUtils.deleteDirectory(tempDirectory)
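// For illustration only: the two behaviours the UtilsSuite cases earlier in
// this hunk set pin down for Utils.extractHostPortFromSparkUrl, shown as a tiny
// standalone sketch. "SparkUrlSketch" is an invented name, and the sketch is
// assumed to live under org.apache.spark.util, like the suite, so the helper is
// reachable regardless of its visibility.
package org.apache.spark.util

import org.apache.spark.SparkException

object SparkUrlSketch {
  def main(args: Array[String]): Unit = {
    // A well-formed spark:// URL splits into its host and port.
    val (host, port) = Utils.extractHostPortFromSparkUrl("spark://1.2.3.4:1234")
    assert(host == "1.2.3.4" && port == 1234)

    // The same malformed URL the suite uses is rejected with a SparkException.
    try {
      Utils.extractHostPortFromSparkUrl("spark://1.2. 3.4:1234")
    } catch {
      case e: SparkException =>
        println("rejected: " + e.getMessage) // "Invalid master URL: spark://1.2. 3.4:1234"
    }
  }
}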