
Commit 08564ae

Add RpcEnvFactory to create RpcEnv
1 parent e5df4ca commit 08564ae

File tree

3 files changed

+25 -27 lines changed

core/src/main/scala/org/apache/spark/rpc/RpcEnv.scala

Lines changed: 17 additions & 19 deletions

```diff
@@ -111,27 +111,18 @@ private[spark] case class RpcEnvConfig(
     securityManager: SecurityManager)
 
 /**
- * A RpcEnv implementation must have a companion object with an
- * `apply(config: RpcEnvConfig): RpcEnv` method so that it can be created via Reflection.
- *
- * {{{
- * object MyCustomRpcEnv {
- *   def apply(config: RpcEnvConfig): RpcEnv = {
- *     ...
- *   }
- * }
- * }}}
+ * A RpcEnv implementation must have a [[RpcEnvFactory]] implementation with an empty constructor
+ * so that it can be created via Reflection.
  */
 private[spark] object RpcEnv {
 
-  private def getRpcEnvCompanion(conf: SparkConf): AnyRef = {
+  private def getRpcEnvFactory(conf: SparkConf): RpcEnvFactory = {
     // Add more RpcEnv implementations here
-    val rpcEnvNames = Map("akka" -> "org.apache.spark.rpc.akka.AkkaRpcEnv")
+    val rpcEnvNames = Map("akka" -> "org.apache.spark.rpc.akka.AkkaRpcEnvFactory")
     val rpcEnvName = conf.get("spark.rpc", "akka")
-    val rpcEnvClassName = rpcEnvNames.getOrElse(rpcEnvName.toLowerCase, rpcEnvName)
-    val companion = Class.forName(
-      rpcEnvClassName + "$", true, Utils.getContextOrSparkClassLoader).getField("MODULE$").get(null)
-    companion
+    val rpcEnvFactoryClassName = rpcEnvNames.getOrElse(rpcEnvName.toLowerCase, rpcEnvName)
+    Class.forName(rpcEnvFactoryClassName, true, Utils.getContextOrSparkClassLoader).
+      newInstance().asInstanceOf[RpcEnvFactory]
   }
 
   def create(
@@ -142,13 +133,20 @@ private[spark] object RpcEnv {
       securityManager: SecurityManager): RpcEnv = {
     // Using Reflection to create the RpcEnv to avoid to depend on Akka directly
     val config = RpcEnvConfig(conf, name, host, port, securityManager)
-    val companion = getRpcEnvCompanion(conf)
-    companion.getClass.getMethod("apply", classOf[RpcEnvConfig]).
-      invoke(companion, config).asInstanceOf[RpcEnv]
+    getRpcEnvFactory(conf).create(config)
   }
 
 }
 
+/**
+ * A factory class to create the [[RpcEnv]]. It must have an empty constructor so that it can be
+ * created using Reflection.
+ */
+private[spark] trait RpcEnvFactory {
+
+  def create(config: RpcEnvConfig): RpcEnv
+}
+
 /**
  * An end point for the RPC that defines what functions to trigger given a message.
  *
```
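With this change, adding a new transport means supplying an `RpcEnvFactory` with a no-arg constructor and pointing `spark.rpc` at its fully qualified class name (any value not found in `rpcEnvNames` is treated as a class name). This replaces the `MyCustomRpcEnv` companion-object example deleted from the scaladoc above; a minimal updated sketch follows. The package and `MyRpcEnvFactory` name are hypothetical, and the package sits under `org.apache.spark` only because `RpcEnvFactory` and `RpcEnvConfig` are `private[spark]`:

```scala
package org.apache.spark.rpc.example

import org.apache.spark.rpc.{RpcEnv, RpcEnvConfig, RpcEnvFactory}

// Hypothetical factory: the empty constructor is what getRpcEnvFactory
// instantiates via Class.forName(...).newInstance().
class MyRpcEnvFactory extends RpcEnvFactory {
  override def create(config: RpcEnvConfig): RpcEnv = {
    // A real implementation would stand up its transport here from
    // config.name, config.host, config.port, config.conf and
    // config.securityManager, then return the resulting RpcEnv.
    ???
  }
}
```

Selecting it would then be pure configuration, e.g. `conf.set("spark.rpc", "org.apache.spark.rpc.example.MyRpcEnvFactory")`, with no code change in `RpcEnv.create`.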

core/src/main/scala/org/apache/spark/rpc/akka/AkkaRpcEnv.scala

Lines changed: 4 additions & 5 deletions

```diff
@@ -43,7 +43,7 @@ import org.apache.spark.util.{ActorLogReceive, AkkaUtils}
  * @param conf
  * @param boundPort
  */
-private[spark] class AkkaRpcEnv private (
+private[spark] class AkkaRpcEnv private[akka] (
     val actorSystem: ActorSystem, conf: SparkConf, boundPort: Int) extends RpcEnv with Logging {
 
   private val defaultAddress: RpcAddress = {
@@ -250,14 +250,13 @@ private[spark] class AkkaRpcEnv private (
   override def toString = s"${getClass.getSimpleName}($actorSystem)"
 }
 
-private[spark] object AkkaRpcEnv {
+private[spark] class AkkaRpcEnvFactory extends RpcEnvFactory {
 
-  def apply(config: RpcEnvConfig): RpcEnv = {
+  def create(config: RpcEnvConfig): RpcEnv = {
     val (actorSystem, boundPort) = AkkaUtils.createActorSystem(
-        config.name, config.host, config.port, config.conf, config.securityManager)
+      config.name, config.host, config.port, config.conf, config.securityManager)
     new AkkaRpcEnv(actorSystem, config.conf, boundPort)
   }
-
 }
 
 private[akka] class AkkaRpcEndpointRef(
```
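The constructor's visibility widens from `private` to `private[akka]` because the creation logic moves out of the companion object: a plain `private` constructor is reachable from the companion, but not from the sibling `AkkaRpcEnvFactory` class, while `private[akka]` admits anything under `org.apache.spark.rpc.akka`. A toy illustration of that access rule, with hypothetical names:

```scala
package org.apache.spark.rpc.akka

// Hypothetical classes illustrating Scala's qualified-private constructors:
// private[akka] lets a sibling class in the same package call the
// constructor, which is exactly what AkkaRpcEnvFactory relies on here.
class Transport private[akka] (val port: Int)

class TransportFactory {
  def create(port: Int): Transport = new Transport(port) // compiles: same package
}
```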

core/src/test/scala/org/apache/spark/rpc/akka/AkkaRpcEnvSuite.scala

Lines changed: 4 additions & 3 deletions

```diff
@@ -23,7 +23,8 @@ import org.apache.spark.{SecurityManager, SparkConf}
 class AkkaRpcEnvSuite extends RpcEnvSuite {
 
   override def createRpcEnv(conf: SparkConf, name: String, port: Int): RpcEnv = {
-    AkkaRpcEnv(RpcEnvConfig(conf, name, "localhost", port, new SecurityManager(conf)))
+    new AkkaRpcEnvFactory().create(
+      RpcEnvConfig(conf, name, "localhost", port, new SecurityManager(conf)))
   }
 
   test("setupEndpointRef: systemName, address, endpointName") {
@@ -35,8 +36,8 @@ class AkkaRpcEnvSuite extends RpcEnvSuite {
       }
     })
     val conf = new SparkConf()
-    val newRpcEnv =
-      AkkaRpcEnv(RpcEnvConfig(conf, "test", "localhost", 12346, new SecurityManager(conf)))
+    val newRpcEnv = new AkkaRpcEnvFactory().create(
+      RpcEnvConfig(conf, "test", "localhost", 12346, new SecurityManager(conf)))
     try {
       val newRef = newRpcEnv.setupEndpointRef("local", ref.address, "test_endpoint")
       assert("akka.tcp://local@localhost:12345/user/test_endpoint" ===
```