diff --git a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala index fbf2dc73ea075..b4bca1e9401e2 100644 --- a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala +++ b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala @@ -51,7 +51,7 @@ private[spark] class CoarseGrainedExecutorBackend( userClassPath: Seq[URL], env: SparkEnv, resourcesFileOpt: Option[String]) - extends ThreadSafeRpcEndpoint with ExecutorBackend with Logging { + extends IsolatedRpcEndpoint with ExecutorBackend with Logging { private implicit val formats = DefaultFormats diff --git a/core/src/main/scala/org/apache/spark/rpc/RpcEndpoint.scala b/core/src/main/scala/org/apache/spark/rpc/RpcEndpoint.scala index 97eed540b8f59..4728759e7fb0d 100644 --- a/core/src/main/scala/org/apache/spark/rpc/RpcEndpoint.scala +++ b/core/src/main/scala/org/apache/spark/rpc/RpcEndpoint.scala @@ -146,3 +146,19 @@ private[spark] trait RpcEndpoint { * [[ThreadSafeRpcEndpoint]] for different messages. */ private[spark] trait ThreadSafeRpcEndpoint extends RpcEndpoint + +/** + * An endpoint that uses a dedicated thread pool for delivering messages. + */ +private[spark] trait IsolatedRpcEndpoint extends RpcEndpoint { + + /** + * How many threads to use for delivering messages. By default, use a single thread. + * + * Note that requesting more than one thread means that the endpoint should be able to handle + * messages arriving from many threads at once, and all the things that entails (including + * messages being delivered to the endpoint out of order). + */ + def threadCount(): Int = 1 + +} diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala b/core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala index 2f923d7902b05..27c943da88105 100644 --- a/core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala +++ b/core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala @@ -17,20 +17,16 @@ package org.apache.spark.rpc.netty -import java.util.concurrent.{ConcurrentHashMap, ConcurrentMap, LinkedBlockingQueue, ThreadPoolExecutor, TimeUnit} +import java.util.concurrent.{ConcurrentHashMap, ConcurrentMap, CountDownLatch} import javax.annotation.concurrent.GuardedBy import scala.collection.JavaConverters._ import scala.concurrent.Promise -import scala.util.control.NonFatal -import org.apache.spark.{SparkConf, SparkContext, SparkException} +import org.apache.spark.SparkException import org.apache.spark.internal.Logging -import org.apache.spark.internal.config.EXECUTOR_ID -import org.apache.spark.internal.config.Network.RPC_NETTY_DISPATCHER_NUM_THREADS import org.apache.spark.network.client.RpcResponseCallback import org.apache.spark.rpc._ -import org.apache.spark.util.ThreadUtils /** * A message dispatcher, responsible for routing RPC messages to the appropriate endpoint(s). @@ -40,20 +36,23 @@ import org.apache.spark.util.ThreadUtils */ private[netty] class Dispatcher(nettyEnv: NettyRpcEnv, numUsableCores: Int) extends Logging { - private class EndpointData( - val name: String, - val endpoint: RpcEndpoint, - val ref: NettyRpcEndpointRef) { - val inbox = new Inbox(ref, endpoint) - } - - private val endpoints: ConcurrentMap[String, EndpointData] = - new ConcurrentHashMap[String, EndpointData] + private val endpoints: ConcurrentMap[String, MessageLoop] = + new ConcurrentHashMap[String, MessageLoop] private val endpointRefs: ConcurrentMap[RpcEndpoint, RpcEndpointRef] = new ConcurrentHashMap[RpcEndpoint, RpcEndpointRef] - // Track the receivers whose inboxes may contain messages. - private val receivers = new LinkedBlockingQueue[EndpointData] + private val shutdownLatch = new CountDownLatch(1) + private lazy val sharedLoop = new SharedMessageLoop(nettyEnv.conf, this, numUsableCores) + + private def getMessageLoop(name: String, endpoint: RpcEndpoint): MessageLoop = { + endpoint match { + case e: IsolatedRpcEndpoint => + new DedicatedMessageLoop(name, e, this) + case _ => + sharedLoop.register(name, endpoint) + sharedLoop + } + } /** * True if the dispatcher has been stopped. Once stopped, all messages posted will be bounced @@ -69,13 +68,11 @@ private[netty] class Dispatcher(nettyEnv: NettyRpcEnv, numUsableCores: Int) exte if (stopped) { throw new IllegalStateException("RpcEnv has been stopped") } - if (endpoints.putIfAbsent(name, new EndpointData(name, endpoint, endpointRef)) != null) { + if (endpoints.putIfAbsent(name, getMessageLoop(name, endpoint)) != null) { throw new IllegalArgumentException(s"There is already an RpcEndpoint called $name") } - val data = endpoints.get(name) - endpointRefs.put(data.endpoint, data.ref) - receivers.offer(data) // for the OnStart message } + endpointRefs.put(endpoint, endpointRef) endpointRef } @@ -85,10 +82,9 @@ private[netty] class Dispatcher(nettyEnv: NettyRpcEnv, numUsableCores: Int) exte // Should be idempotent private def unregisterRpcEndpoint(name: String): Unit = { - val data = endpoints.remove(name) - if (data != null) { - data.inbox.stop() - receivers.offer(data) // for the OnStop message + val loop = endpoints.remove(name) + if (loop != null) { + loop.unregister(name) } // Don't clean `endpointRefs` here because it's possible that some messages are being processed // now and they can use `getRpcEndpointRef`. So `endpointRefs` will be cleaned in Inbox via @@ -155,14 +151,13 @@ private[netty] class Dispatcher(nettyEnv: NettyRpcEnv, numUsableCores: Int) exte message: InboxMessage, callbackIfStopped: (Exception) => Unit): Unit = { val error = synchronized { - val data = endpoints.get(endpointName) + val loop = endpoints.get(endpointName) if (stopped) { Some(new RpcEnvStoppedException()) - } else if (data == null) { + } else if (loop == null) { Some(new SparkException(s"Could not find $endpointName.")) } else { - data.inbox.post(message) - receivers.offer(data) + loop.post(endpointName, message) None } } @@ -177,15 +172,23 @@ private[netty] class Dispatcher(nettyEnv: NettyRpcEnv, numUsableCores: Int) exte } stopped = true } - // Stop all endpoints. This will queue all endpoints for processing by the message loops. - endpoints.keySet().asScala.foreach(unregisterRpcEndpoint) - // Enqueue a message that tells the message loops to stop. - receivers.offer(PoisonPill) - threadpool.shutdown() + var stopSharedLoop = false + endpoints.asScala.foreach { case (name, loop) => + unregisterRpcEndpoint(name) + if (!loop.isInstanceOf[SharedMessageLoop]) { + loop.stop() + } else { + stopSharedLoop = true + } + } + if (stopSharedLoop) { + sharedLoop.stop() + } + shutdownLatch.countDown() } def awaitTermination(): Unit = { - threadpool.awaitTermination(Long.MaxValue, TimeUnit.MILLISECONDS) + shutdownLatch.await() } /** @@ -194,61 +197,4 @@ private[netty] class Dispatcher(nettyEnv: NettyRpcEnv, numUsableCores: Int) exte def verify(name: String): Boolean = { endpoints.containsKey(name) } - - private def getNumOfThreads(conf: SparkConf): Int = { - val availableCores = - if (numUsableCores > 0) numUsableCores else Runtime.getRuntime.availableProcessors() - - val modNumThreads = conf.get(RPC_NETTY_DISPATCHER_NUM_THREADS) - .getOrElse(math.max(2, availableCores)) - - conf.get(EXECUTOR_ID).map { id => - val role = if (id == SparkContext.DRIVER_IDENTIFIER) "driver" else "executor" - conf.getInt(s"spark.$role.rpc.netty.dispatcher.numThreads", modNumThreads) - }.getOrElse(modNumThreads) - } - - /** Thread pool used for dispatching messages. */ - private val threadpool: ThreadPoolExecutor = { - val numThreads = getNumOfThreads(nettyEnv.conf) - val pool = ThreadUtils.newDaemonFixedThreadPool(numThreads, "dispatcher-event-loop") - for (i <- 0 until numThreads) { - pool.execute(new MessageLoop) - } - pool - } - - /** Message loop used for dispatching messages. */ - private class MessageLoop extends Runnable { - override def run(): Unit = { - try { - while (true) { - try { - val data = receivers.take() - if (data == PoisonPill) { - // Put PoisonPill back so that other MessageLoops can see it. - receivers.offer(PoisonPill) - return - } - data.inbox.process(Dispatcher.this) - } catch { - case NonFatal(e) => logError(e.getMessage, e) - } - } - } catch { - case _: InterruptedException => // exit - case t: Throwable => - try { - // Re-submit a MessageLoop so that Dispatcher will still work if - // UncaughtExceptionHandler decides to not kill JVM. - threadpool.execute(new MessageLoop) - } finally { - throw t - } - } - } - } - - /** A poison endpoint that indicates MessageLoop should exit its message loop. */ - private val PoisonPill = new EndpointData(null, null, null) } diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/Inbox.scala b/core/src/main/scala/org/apache/spark/rpc/netty/Inbox.scala index 44d2622a42f58..2ed03f7430c32 100644 --- a/core/src/main/scala/org/apache/spark/rpc/netty/Inbox.scala +++ b/core/src/main/scala/org/apache/spark/rpc/netty/Inbox.scala @@ -54,9 +54,7 @@ private[netty] case class RemoteProcessConnectionError(cause: Throwable, remoteA /** * An inbox that stores messages for an [[RpcEndpoint]] and posts messages to it thread-safely. */ -private[netty] class Inbox( - val endpointRef: NettyRpcEndpointRef, - val endpoint: RpcEndpoint) +private[netty] class Inbox(val endpointName: String, val endpoint: RpcEndpoint) extends Logging { inbox => // Give this an alias so we can use it more clearly in closures. @@ -195,7 +193,7 @@ private[netty] class Inbox( * Exposed for testing. */ protected def onDrop(message: InboxMessage): Unit = { - logWarning(s"Drop $message because $endpointRef is stopped") + logWarning(s"Drop $message because endpoint $endpointName is stopped") } /** diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/MessageLoop.scala b/core/src/main/scala/org/apache/spark/rpc/netty/MessageLoop.scala new file mode 100644 index 0000000000000..c985c72f2adce --- /dev/null +++ b/core/src/main/scala/org/apache/spark/rpc/netty/MessageLoop.scala @@ -0,0 +1,194 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.rpc.netty + +import java.util.concurrent._ + +import scala.util.control.NonFatal + +import org.apache.spark.{SparkConf, SparkContext} +import org.apache.spark.internal.Logging +import org.apache.spark.internal.config.EXECUTOR_ID +import org.apache.spark.internal.config.Network._ +import org.apache.spark.rpc.{IsolatedRpcEndpoint, RpcEndpoint} +import org.apache.spark.util.ThreadUtils + +/** + * A message loop used by [[Dispatcher]] to deliver messages to endpoints. + */ +private sealed abstract class MessageLoop(dispatcher: Dispatcher) extends Logging { + + // List of inboxes with pending messages, to be processed by the message loop. + private val active = new LinkedBlockingQueue[Inbox]() + + // Message loop task; should be run in all threads of the message loop's pool. + protected val receiveLoopRunnable = new Runnable() { + override def run(): Unit = receiveLoop() + } + + protected val threadpool: ExecutorService + + private var stopped = false + + def post(endpointName: String, message: InboxMessage): Unit + + def unregister(name: String): Unit + + def stop(): Unit = { + synchronized { + if (!stopped) { + setActive(MessageLoop.PoisonPill) + threadpool.shutdown() + stopped = true + } + } + threadpool.awaitTermination(Long.MaxValue, TimeUnit.MILLISECONDS) + } + + protected final def setActive(inbox: Inbox): Unit = active.offer(inbox) + + private def receiveLoop(): Unit = { + try { + while (true) { + try { + val inbox = active.take() + if (inbox == MessageLoop.PoisonPill) { + // Put PoisonPill back so that other threads can see it. + setActive(MessageLoop.PoisonPill) + return + } + inbox.process(dispatcher) + } catch { + case NonFatal(e) => logError(e.getMessage, e) + } + } + } catch { + case _: InterruptedException => // exit + case t: Throwable => + try { + // Re-submit a receive task so that message delivery will still work if + // UncaughtExceptionHandler decides to not kill JVM. + threadpool.execute(receiveLoopRunnable) + } finally { + throw t + } + } + } +} + +private object MessageLoop { + /** A poison inbox that indicates the message loop should stop processing messages. */ + val PoisonPill = new Inbox(null, null) +} + +/** + * A message loop that serves multiple RPC endpoints, using a shared thread pool. + */ +private class SharedMessageLoop( + conf: SparkConf, + dispatcher: Dispatcher, + numUsableCores: Int) + extends MessageLoop(dispatcher) { + + private val endpoints = new ConcurrentHashMap[String, Inbox]() + + private def getNumOfThreads(conf: SparkConf): Int = { + val availableCores = + if (numUsableCores > 0) numUsableCores else Runtime.getRuntime.availableProcessors() + + val modNumThreads = conf.get(RPC_NETTY_DISPATCHER_NUM_THREADS) + .getOrElse(math.max(2, availableCores)) + + conf.get(EXECUTOR_ID).map { id => + val role = if (id == SparkContext.DRIVER_IDENTIFIER) "driver" else "executor" + conf.getInt(s"spark.$role.rpc.netty.dispatcher.numThreads", modNumThreads) + }.getOrElse(modNumThreads) + } + + /** Thread pool used for dispatching messages. */ + override protected val threadpool: ThreadPoolExecutor = { + val numThreads = getNumOfThreads(conf) + val pool = ThreadUtils.newDaemonFixedThreadPool(numThreads, "dispatcher-event-loop") + for (i <- 0 until numThreads) { + pool.execute(receiveLoopRunnable) + } + pool + } + + override def post(endpointName: String, message: InboxMessage): Unit = { + val inbox = endpoints.get(endpointName) + inbox.post(message) + setActive(inbox) + } + + override def unregister(name: String): Unit = { + val inbox = endpoints.remove(name) + if (inbox != null) { + inbox.stop() + // Mark active to handle the OnStop message. + setActive(inbox) + } + } + + def register(name: String, endpoint: RpcEndpoint): Unit = { + val inbox = new Inbox(name, endpoint) + endpoints.put(name, inbox) + // Mark active to handle the OnStart message. + setActive(inbox) + } +} + +/** + * A message loop that is dedicated to a single RPC endpoint. + */ +private class DedicatedMessageLoop( + name: String, + endpoint: IsolatedRpcEndpoint, + dispatcher: Dispatcher) + extends MessageLoop(dispatcher) { + + private val inbox = new Inbox(name, endpoint) + + override protected val threadpool = if (endpoint.threadCount() > 1) { + ThreadUtils.newDaemonCachedThreadPool(s"dispatcher-$name", endpoint.threadCount()) + } else { + ThreadUtils.newDaemonSingleThreadExecutor(s"dispatcher-$name") + } + + (1 to endpoint.threadCount()).foreach { _ => + threadpool.submit(receiveLoopRunnable) + } + + // Mark active to handle the OnStart message. + setActive(inbox) + + override def post(endpointName: String, message: InboxMessage): Unit = { + require(endpointName == name) + inbox.post(message) + setActive(inbox) + } + + override def unregister(endpointName: String): Unit = synchronized { + require(endpointName == name) + inbox.stop() + // Mark active to handle the OnStop message. + setActive(inbox) + setActive(MessageLoop.PoisonPill) + threadpool.shutdown() + } +} diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala index 4958389ae4257..6e990d1335897 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala @@ -111,7 +111,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val rpcEnv: Rp private val reviveThread = ThreadUtils.newDaemonSingleThreadScheduledExecutor("driver-revive-thread") - class DriverEndpoint extends ThreadSafeRpcEndpoint with Logging { + class DriverEndpoint extends IsolatedRpcEndpoint with Logging { override val rpcEnv: RpcEnv = CoarseGrainedSchedulerBackend.this.rpcEnv diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala index faf6f713c838f..02d0e1a834909 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterEndpoint.scala @@ -30,7 +30,7 @@ import org.apache.spark.SparkConf import org.apache.spark.annotation.DeveloperApi import org.apache.spark.internal.{config, Logging} import org.apache.spark.network.shuffle.ExternalBlockStoreClient -import org.apache.spark.rpc.{RpcCallContext, RpcEndpointRef, RpcEnv, ThreadSafeRpcEndpoint} +import org.apache.spark.rpc.{IsolatedRpcEndpoint, RpcCallContext, RpcEndpointRef, RpcEnv} import org.apache.spark.scheduler._ import org.apache.spark.storage.BlockManagerMessages._ import org.apache.spark.util.{RpcUtils, ThreadUtils, Utils} @@ -46,7 +46,7 @@ class BlockManagerMasterEndpoint( conf: SparkConf, listenerBus: LiveListenerBus, externalBlockStoreClient: Option[ExternalBlockStoreClient]) - extends ThreadSafeRpcEndpoint with Logging { + extends IsolatedRpcEndpoint with Logging { // Mapping from block manager id to the block manager's information. private val blockManagerInfo = new mutable.HashMap[BlockManagerId, BlockManagerInfo] diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveEndpoint.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveEndpoint.scala index f90595ab924b4..29e21142ce449 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveEndpoint.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerSlaveEndpoint.scala @@ -21,7 +21,7 @@ import scala.concurrent.{ExecutionContext, Future} import org.apache.spark.{MapOutputTracker, SparkEnv} import org.apache.spark.internal.Logging -import org.apache.spark.rpc.{RpcCallContext, RpcEnv, ThreadSafeRpcEndpoint} +import org.apache.spark.rpc.{IsolatedRpcEndpoint, RpcCallContext, RpcEnv} import org.apache.spark.storage.BlockManagerMessages._ import org.apache.spark.util.{ThreadUtils, Utils} @@ -34,7 +34,7 @@ class BlockManagerSlaveEndpoint( override val rpcEnv: RpcEnv, blockManager: BlockManager, mapOutputTracker: MapOutputTracker) - extends ThreadSafeRpcEndpoint with Logging { + extends IsolatedRpcEndpoint with Logging { private val asyncThreadPool = ThreadUtils.newDaemonCachedThreadPool("block-manager-slave-async-thread-pool", 100) diff --git a/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala b/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala index 5929fbf85a1f4..c10f2c244e133 100644 --- a/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala +++ b/core/src/test/scala/org/apache/spark/rpc/RpcEnvSuite.scala @@ -36,7 +36,6 @@ import org.scalatest.concurrent.Eventually._ import org.apache.spark.{SecurityManager, SparkConf, SparkEnv, SparkException, SparkFunSuite} import org.apache.spark.deploy.SparkHadoopUtil import org.apache.spark.internal.config._ -import org.apache.spark.internal.config.Network import org.apache.spark.util.{ThreadUtils, Utils} /** @@ -954,6 +953,40 @@ abstract class RpcEnvSuite extends SparkFunSuite with BeforeAndAfterAll { verify(endpoint, never()).onDisconnected(any()) verify(endpoint, never()).onNetworkError(any(), any()) } + + test("isolated endpoints") { + val latch = new CountDownLatch(1) + val singleThreadedEnv = createRpcEnv( + new SparkConf().set(Network.RPC_NETTY_DISPATCHER_NUM_THREADS, 1), "singleThread", 0) + try { + val blockingEndpoint = singleThreadedEnv.setupEndpoint("blocking", new IsolatedRpcEndpoint { + override val rpcEnv: RpcEnv = singleThreadedEnv + + override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { + case m => + latch.await() + context.reply(m) + } + }) + + val nonBlockingEndpoint = singleThreadedEnv.setupEndpoint("non-blocking", new RpcEndpoint { + override val rpcEnv: RpcEnv = singleThreadedEnv + + override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = { + case m => context.reply(m) + } + }) + + val to = new RpcTimeout(5.seconds, "test-timeout") + val blockingFuture = blockingEndpoint.ask[String]("hi", to) + assert(nonBlockingEndpoint.askSync[String]("hello", to) === "hello") + latch.countDown() + assert(ThreadUtils.awaitResult(blockingFuture, 5.seconds) === "hi") + } finally { + latch.countDown() + singleThreadedEnv.shutdown() + } + } } class UnserializableClass diff --git a/core/src/test/scala/org/apache/spark/rpc/netty/InboxSuite.scala b/core/src/test/scala/org/apache/spark/rpc/netty/InboxSuite.scala index e5539566e4b6f..c74c728b3e3f3 100644 --- a/core/src/test/scala/org/apache/spark/rpc/netty/InboxSuite.scala +++ b/core/src/test/scala/org/apache/spark/rpc/netty/InboxSuite.scala @@ -29,12 +29,9 @@ class InboxSuite extends SparkFunSuite { test("post") { val endpoint = new TestRpcEndpoint - val endpointRef = mock(classOf[NettyRpcEndpointRef]) - when(endpointRef.name).thenReturn("hello") - val dispatcher = mock(classOf[Dispatcher]) - val inbox = new Inbox(endpointRef, endpoint) + val inbox = new Inbox("name", endpoint) val message = OneWayMessage(null, "hi") inbox.post(message) inbox.process(dispatcher) @@ -51,10 +48,9 @@ class InboxSuite extends SparkFunSuite { test("post: with reply") { val endpoint = new TestRpcEndpoint - val endpointRef = mock(classOf[NettyRpcEndpointRef]) val dispatcher = mock(classOf[Dispatcher]) - val inbox = new Inbox(endpointRef, endpoint) + val inbox = new Inbox("name", endpoint) val message = RpcMessage(null, "hi", null) inbox.post(message) inbox.process(dispatcher) @@ -65,13 +61,10 @@ class InboxSuite extends SparkFunSuite { test("post: multiple threads") { val endpoint = new TestRpcEndpoint - val endpointRef = mock(classOf[NettyRpcEndpointRef]) - when(endpointRef.name).thenReturn("hello") - val dispatcher = mock(classOf[Dispatcher]) val numDroppedMessages = new AtomicInteger(0) - val inbox = new Inbox(endpointRef, endpoint) { + val inbox = new Inbox("name", endpoint) { override def onDrop(message: InboxMessage): Unit = { numDroppedMessages.incrementAndGet() } @@ -107,12 +100,10 @@ class InboxSuite extends SparkFunSuite { test("post: Associated") { val endpoint = new TestRpcEndpoint - val endpointRef = mock(classOf[NettyRpcEndpointRef]) val dispatcher = mock(classOf[Dispatcher]) - val remoteAddress = RpcAddress("localhost", 11111) - val inbox = new Inbox(endpointRef, endpoint) + val inbox = new Inbox("name", endpoint) inbox.post(RemoteProcessConnected(remoteAddress)) inbox.process(dispatcher) @@ -121,12 +112,11 @@ class InboxSuite extends SparkFunSuite { test("post: Disassociated") { val endpoint = new TestRpcEndpoint - val endpointRef = mock(classOf[NettyRpcEndpointRef]) val dispatcher = mock(classOf[Dispatcher]) val remoteAddress = RpcAddress("localhost", 11111) - val inbox = new Inbox(endpointRef, endpoint) + val inbox = new Inbox("name", endpoint) inbox.post(RemoteProcessDisconnected(remoteAddress)) inbox.process(dispatcher) @@ -135,13 +125,12 @@ class InboxSuite extends SparkFunSuite { test("post: AssociationError") { val endpoint = new TestRpcEndpoint - val endpointRef = mock(classOf[NettyRpcEndpointRef]) val dispatcher = mock(classOf[Dispatcher]) val remoteAddress = RpcAddress("localhost", 11111) val cause = new RuntimeException("Oops") - val inbox = new Inbox(endpointRef, endpoint) + val inbox = new Inbox("name", endpoint) inbox.post(RemoteProcessConnectionError(cause, remoteAddress)) inbox.process(dispatcher)