diff --git a/core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala b/core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala index 27c943da88105..41d6d146a86d7 100644 --- a/core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala +++ b/core/src/main/scala/org/apache/spark/rpc/netty/Dispatcher.scala @@ -22,6 +22,7 @@ import javax.annotation.concurrent.GuardedBy import scala.collection.JavaConverters._ import scala.concurrent.Promise +import scala.util.control.NonFatal import org.apache.spark.SparkException import org.apache.spark.internal.Logging @@ -44,16 +45,6 @@ private[netty] class Dispatcher(nettyEnv: NettyRpcEnv, numUsableCores: Int) exte private val shutdownLatch = new CountDownLatch(1) private lazy val sharedLoop = new SharedMessageLoop(nettyEnv.conf, this, numUsableCores) - private def getMessageLoop(name: String, endpoint: RpcEndpoint): MessageLoop = { - endpoint match { - case e: IsolatedRpcEndpoint => - new DedicatedMessageLoop(name, e, this) - case _ => - sharedLoop.register(name, endpoint) - sharedLoop - } - } - /** * True if the dispatcher has been stopped. Once stopped, all messages posted will be bounced * immediately. @@ -68,11 +59,31 @@ private[netty] class Dispatcher(nettyEnv: NettyRpcEnv, numUsableCores: Int) exte if (stopped) { throw new IllegalStateException("RpcEnv has been stopped") } - if (endpoints.putIfAbsent(name, getMessageLoop(name, endpoint)) != null) { + if (endpoints.containsKey(name)) { throw new IllegalArgumentException(s"There is already an RpcEndpoint called $name") } + + // This must be done before assigning RpcEndpoint to MessageLoop, as MessageLoop sets Inbox be + // active when registering, and endpointRef must be put into endpointRefs before onStart is + // called. + endpointRefs.put(endpoint, endpointRef) + + var messageLoop: MessageLoop = null + try { + messageLoop = endpoint match { + case e: IsolatedRpcEndpoint => + new DedicatedMessageLoop(name, e, this) + case _ => + sharedLoop.register(name, endpoint) + sharedLoop + } + endpoints.put(name, messageLoop) + } catch { + case NonFatal(e) => + endpointRefs.remove(endpoint) + throw e + } } - endpointRefs.put(endpoint, endpointRef) endpointRef }