Skip to content

Commit 3e56123

Browse files
committed
Use lazy to eliminate CountDownLatch
1 parent 07f128f commit 3e56123

File tree

1 file changed

+90
-78
lines changed

1 file changed

+90
-78
lines changed

core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala

Lines changed: 90 additions & 78 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
package org.apache.spark.rpc.akka
1919

2020
import java.net.URI
21-
import java.util.concurrent.{ConcurrentHashMap, CountDownLatch}
21+
import java.util.concurrent.ConcurrentHashMap
2222

2323
import scala.concurrent.{Await, Future}
2424
import scala.concurrent.duration._
@@ -92,97 +92,94 @@ private[spark] class AkkaRpcEnv private (
9292
}
9393

9494
override def setupThreadSafeEndpoint(name: String, endpoint: RpcEndpoint): RpcEndpointRef = {
95-
val latch = new CountDownLatch(1)
96-
try {
97-
@volatile var endpointRef: AkkaRpcEndpointRef = null
98-
val actorRef = actorSystem.actorOf(Props(new Actor with ActorLogReceive with Logging {
95+
@volatile var endpointRef: AkkaRpcEndpointRef = null
96+
// Use lazy because the Actor needs to use `endpointRef`.
97+
// So `actorRef` should be created after assigning `endpointRef`.
98+
lazy val actorRef = actorSystem.actorOf(Props(new Actor with ActorLogReceive with Logging {
99+
100+
require(endpointRef != null)
101+
registerEndpoint(endpoint, endpointRef)
102+
103+
override def preStart(): Unit = {
104+
// Listen for remote client network events
105+
context.system.eventStream.subscribe(self, classOf[AssociationEvent])
106+
safelyCall(endpoint) {
107+
endpoint.onStart()
108+
}
109+
}
99110

100-
// Wait until `endpointRef` is set. TODO better solution?
101-
latch.await()
102-
require(endpointRef != null)
103-
registerEndpoint(endpoint, endpointRef)
111+
override def receiveWithLogging: Receive = {
112+
case AssociatedEvent(_, remoteAddress, _) =>
113+
safelyCall(endpoint) {
114+
endpoint.onConnected(akkaAddressToRpcAddress(remoteAddress))
115+
}
104116

105-
override def preStart(): Unit = {
106-
// Listen for remote client network events
107-
context.system.eventStream.subscribe(self, classOf[AssociationEvent])
117+
case DisassociatedEvent(_, remoteAddress, _) =>
108118
safelyCall(endpoint) {
109-
endpoint.onStart()
119+
endpoint.onDisconnected(akkaAddressToRpcAddress(remoteAddress))
110120
}
111-
}
112121

113-
override def receiveWithLogging: Receive = {
114-
case AssociatedEvent(_, remoteAddress, _) =>
115-
safelyCall(endpoint) {
116-
endpoint.onConnected(akkaAddressToRpcAddress(remoteAddress))
117-
}
122+
case AssociationErrorEvent(cause, localAddress, remoteAddress, inbound, _) =>
123+
safelyCall(endpoint) {
124+
endpoint.onNetworkError(cause, akkaAddressToRpcAddress(remoteAddress))
125+
}
118126

119-
case DisassociatedEvent(_, remoteAddress, _) =>
120-
safelyCall(endpoint) {
121-
endpoint.onDisconnected(akkaAddressToRpcAddress(remoteAddress))
122-
}
127+
case e: AssociationEvent =>
128+
// TODO ignore?
123129

124-
case AssociationErrorEvent(cause, localAddress, remoteAddress, inbound, _) =>
125-
safelyCall(endpoint) {
126-
endpoint.onNetworkError(cause, akkaAddressToRpcAddress(remoteAddress))
127-
}
130+
case AkkaMessage(message: Any, reply: Boolean)=>
131+
logDebug("Received RPC message: " + AkkaMessage(message, reply))
132+
safelyCall(endpoint) {
133+
val s = sender()
134+
val pf =
135+
if (reply) {
136+
endpoint.receiveAndReply(new RpcCallContext {
137+
override def sendFailure(e: Throwable): Unit = {
138+
s ! AkkaFailure(e)
139+
}
128140

129-
case e: AssociationEvent =>
130-
// TODO ignore?
141+
override def reply(response: Any): Unit = {
142+
s ! AkkaMessage(response, false)
143+
}
131144

132-
case AkkaMessage(message: Any, reply: Boolean)=>
133-
logDebug("Received RPC message: " + AkkaMessage(message, reply))
134-
safelyCall(endpoint) {
135-
val s = sender()
136-
val pf =
145+
// Some RpcEndpoints need to know the sender's address
146+
override val sender: RpcEndpointRef =
147+
new AkkaRpcEndpointRef(defaultAddress, s, conf)
148+
})
149+
} else {
150+
endpoint.receive
151+
}
152+
try {
153+
if (pf.isDefinedAt(message)) {
154+
pf.apply(message)
155+
}
156+
} catch {
157+
case NonFatal(e) =>
137158
if (reply) {
138-
endpoint.receiveAndReply(new RpcCallContext {
139-
override def sendFailure(e: Throwable): Unit = {
140-
s ! AkkaFailure(e)
141-
}
142-
143-
override def reply(response: Any): Unit = {
144-
s ! AkkaMessage(response, false)
145-
}
146-
147-
// Some RpcEndpoints need to know the sender's address
148-
override val sender: RpcEndpointRef =
149-
new AkkaRpcEndpointRef(defaultAddress, s, conf)
150-
})
159+
// If the sender asks a reply, we should send the error back to the sender
160+
s ! AkkaFailure(e)
151161
} else {
152-
endpoint.receive
153-
}
154-
try {
155-
if (pf.isDefinedAt(message)) {
156-
pf.apply(message)
162+
throw e
157163
}
158-
} catch {
159-
case NonFatal(e) =>
160-
if (reply) {
161-
// If the sender asks a reply, we should send the error back to the sender
162-
s ! AkkaFailure(e)
163-
} else {
164-
throw e
165-
}
166-
}
167164
}
168-
case message: Any => {
169-
logWarning(s"Unknown message: $message")
170165
}
166+
case message: Any => {
167+
logWarning(s"Unknown message: $message")
171168
}
169+
}
172170

173-
override def postStop(): Unit = {
174-
unregisterEndpoint(endpoint.self)
175-
safelyCall(endpoint) {
176-
endpoint.onStop()
177-
}
171+
override def postStop(): Unit = {
172+
unregisterEndpoint(endpoint.self)
173+
safelyCall(endpoint) {
174+
endpoint.onStop()
178175
}
176+
}
179177

180-
}), name = name)
181-
endpointRef = new AkkaRpcEndpointRef(defaultAddress, actorRef, conf)
182-
endpointRef
183-
} finally {
184-
latch.countDown()
185-
}
178+
}), name = name)
179+
endpointRef = new AkkaRpcEndpointRef(defaultAddress, actorRef, conf, initInConstructor = false)
180+
// Now actorRef can be created safely
181+
endpointRef.init()
182+
endpointRef
186183
}
187184

188185
/**
@@ -258,21 +255,36 @@ private[spark] object AkkaRpcEnv {
258255

259256
private[akka] class AkkaRpcEndpointRef(
260257
@transient defaultAddress: RpcAddress,
261-
val actorRef: ActorRef,
262-
@transient conf: SparkConf) extends RpcEndpointRef with Serializable with Logging {
258+
@transient _actorRef: => ActorRef,
259+
@transient conf: SparkConf,
260+
@transient initInConstructor: Boolean = true)
261+
extends RpcEndpointRef with Serializable with Logging {
263262
// `defaultAddress` and `conf` won't be used after initialization. So it's safe to be transient.
264263

265264
private[this] val maxRetries = conf.getInt("spark.akka.num.retries", 3)
266265
private[this] val retryWaitMs = conf.getLong("spark.akka.retry.wait", 3000)
267266
private[this] val defaultTimeout = conf.getLong("spark.akka.lookupTimeout", 30) seconds
268267

269-
override val address: RpcAddress = {
268+
lazy val actorRef = _actorRef
269+
270+
override lazy val address: RpcAddress = {
270271
val akkaAddress = actorRef.path.address
271272
RpcAddress(akkaAddress.host.getOrElse(defaultAddress.host),
272273
akkaAddress.port.getOrElse(defaultAddress.port))
273274
}
274275

275-
override val name: String = actorRef.path.name
276+
override lazy val name: String = actorRef.path.name
277+
278+
private[akka] def init(): Unit = {
279+
// Initialize the lazy vals
280+
actorRef
281+
address
282+
name
283+
}
284+
285+
if (initInConstructor) {
286+
init()
287+
}
276288

277289
override def askWithReply[T: ClassTag](message: Any): T = askWithReply(message, defaultTimeout)
278290

0 commit comments

Comments
 (0)