Commit 151b3c5

Add docs, timeout config, better failure handling
1 parent f6177d7 commit 151b3c5

File tree

8 files changed: +61 -24 lines changed

core/src/main/scala/org/apache/spark/SparkConf.scala

Lines changed: 2 additions & 2 deletions

@@ -218,8 +218,8 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging {
     getAll.filter { case (k, _) => isAkkaConf(k) }
 
   /**
-   * Returns the Spark application id, valid in the Driver after TaskScheduler registration in the
-   * driver and from the start in the Executor.
+   * Returns the Spark application id, valid in the Driver after TaskScheduler registration and
+   * from the start in the Executor.
    */
   def getAppId: String = get("spark.app.id")

core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala

Lines changed: 5 additions & 2 deletions

@@ -17,6 +17,8 @@
 
 package org.apache.spark.network.netty
 
+import org.apache.spark.network.util.TransportConf
+
 import scala.collection.JavaConversions._
 import scala.concurrent.{Future, Promise}
@@ -41,6 +43,7 @@ class NettyBlockTransferService(conf: SparkConf, securityManager: SecurityManage
   // TODO: Don't use Java serialization, use a more cross-version compatible serialization format.
   private val serializer = new JavaSerializer(conf)
   private val authEnabled = securityManager.isAuthenticationEnabled()
+  private val transportConf = SparkTransportConf.fromSparkConf(conf)
 
   private[this] var transportContext: TransportContext = _
   private[this] var server: TransportServer = _
@@ -53,10 +56,10 @@ class NettyBlockTransferService(conf: SparkConf, securityManager: SecurityManage
        (nettyRpcHandler, None)
      } else {
        (new SaslRpcHandler(nettyRpcHandler, securityManager),
-         Some(new SaslBootstrap(conf.getAppId, securityManager)))
+         Some(new SaslBootstrap(transportConf, conf.getAppId, securityManager)))
      }
    }
-   transportContext = new TransportContext(SparkTransportConf.fromSparkConf(conf), rpcHandler)
+   transportContext = new TransportContext(transportConf, rpcHandler)
    clientFactory = transportContext.createClientFactory(bootstrap.toList)
    server = transportContext.createServer()
    logInfo("Server created on " + server.getPort)

core/src/main/scala/org/apache/spark/storage/BlockManager.scala

Lines changed: 2 additions & 2 deletions

@@ -176,8 +176,8 @@ private[spark] class BlockManager(
    * the appId may not be known at BlockManager instantiation time (in particular for the driver,
    * where it is only learned after registration with the TaskScheduler).
    *
-   * This method initializes the BlockTransferService and ShuffleClient registers with the
-   * BlockManagerMaster, starts theBlockManagerWorker actor, and registers with a local shuffle
+   * This method initializes the BlockTransferService and ShuffleClient, registers with the
+   * BlockManagerMaster, starts the BlockManagerWorker actor, and registers with a local shuffle
    * service if configured.
    */
   def initialize(appId: String): Unit = {

network/common/src/main/java/org/apache/spark/network/client/TransportClientBootstrap.java

Lines changed: 11 additions & 1 deletion

@@ -17,6 +17,16 @@
 
 package org.apache.spark.network.client;
 
+/**
+ * A bootstrap which is executed on a TransportClient before it is returned to the user.
+ * This enables an initial exchange of information (e.g., SASL authentication tokens) on a once-per-
+ * connection basis.
+ *
+ * Since connections (and TransportClients) are reused as much as possible, it is generally
+ * reasonable to perform an expensive bootstrapping operation, as they often share a lifespan with
+ * the JVM itself.
+ */
 public interface TransportClientBootstrap {
-  public void doBootstrap(TransportClient client);
+  /** Performs the bootstrapping operation, throwing an exception on failure. */
+  public void doBootstrap(TransportClient client) throws RuntimeException;
 }
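
For context, a custom bootstrap only needs to implement doBootstrap and throw on failure. A minimal sketch, not part of this commit (PingBootstrap and its empty one-shot RPC are hypothetical), using the sendRpcSync call already shown in SaslBootstrap below:

import org.apache.spark.network.client.TransportClient;
import org.apache.spark.network.client.TransportClientBootstrap;

/** Hypothetical bootstrap that exchanges one empty RPC before the client is handed out. */
public class PingBootstrap implements TransportClientBootstrap {
  private static final long TIMEOUT_MS = 10000;

  @Override
  public void doBootstrap(TransportClient client) {
    // sendRpcSync throws on failure or timeout; per the new contract, the
    // factory logs the error, closes the client, and rethrows.
    client.sendRpcSync(new byte[0], TIMEOUT_MS);
  }
}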

network/common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java

Lines changed: 15 additions & 6 deletions

@@ -26,6 +26,8 @@
 import java.util.concurrent.TimeoutException;
 import java.util.concurrent.atomic.AtomicReference;
 
+import com.google.common.base.Preconditions;
+import com.google.common.base.Throwables;
 import com.google.common.collect.Lists;
 import io.netty.bootstrap.Bootstrap;
 import io.netty.buffer.PooledByteBufAllocator;
@@ -42,6 +44,7 @@
 import org.apache.spark.network.TransportContext;
 import org.apache.spark.network.server.TransportChannelHandler;
 import org.apache.spark.network.util.IOMode;
+import org.apache.spark.network.util.JavaUtils;
 import org.apache.spark.network.util.NettyUtils;
 import org.apache.spark.network.util.TransportConf;
 
@@ -60,19 +63,19 @@ public class TransportClientFactory implements Closeable {
 
   private final TransportContext context;
   private final TransportConf conf;
-  private final ConcurrentHashMap<SocketAddress, TransportClient> connectionPool;
   private final List<TransportClientBootstrap> clientBootstraps;
+  private final ConcurrentHashMap<SocketAddress, TransportClient> connectionPool;
 
   private final Class<? extends Channel> socketChannelClass;
   private EventLoopGroup workerGroup;
 
   public TransportClientFactory(
       TransportContext context,
       List<TransportClientBootstrap> clientBootstraps) {
-    this.context = context;
+    this.context = Preconditions.checkNotNull(context);
     this.conf = context.getConf();
+    this.clientBootstraps = Lists.newArrayList(Preconditions.checkNotNull(clientBootstraps));
     this.connectionPool = new ConcurrentHashMap<SocketAddress, TransportClient>();
-    this.clientBootstraps = Lists.newArrayList(clientBootstraps);
 
     IOMode ioMode = IOMode.valueOf(conf.ioMode());
     this.socketChannelClass = NettyUtils.getClientChannelClass(ioMode);
@@ -138,10 +141,16 @@ public void initChannel(SocketChannel ch) {
     TransportClient client = clientRef.get();
     assert client != null : "Channel future completed successfully with null client";
 
-    logger.debug("Connection to {} successful, running bootstraps...", address);
     // Execute any client bootstraps synchronously before marking the Client as successful.
-    for (TransportClientBootstrap clientBootstrap : clientBootstraps) {
-      clientBootstrap.doBootstrap(client);
+    logger.debug("Connection to {} successful, running bootstraps...", address);
+    try {
+      for (TransportClientBootstrap clientBootstrap : clientBootstraps) {
+        clientBootstrap.doBootstrap(client);
+      }
+    } catch (Exception e) { // catch Exception as the bootstrap may be written in Scala
+      logger.error("Exception while bootstrapping client", e);
+      client.close();
+      throw Throwables.propagate(e);
     }
 
     logger.debug("Successfully executed {} bootstraps for {}", clientBootstraps.size(), address);

network/common/src/main/java/org/apache/spark/network/util/TransportConf.java

Lines changed: 3 additions & 0 deletions

@@ -55,4 +55,7 @@ public int connectionTimeoutMs() {
 
   /** Send buffer size (SO_SNDBUF). */
   public int sendBuf() { return conf.getInt("spark.shuffle.io.sendBuffer", -1); }
+
+  /** Timeout for a single round trip of SASL token exchange, in milliseconds. */
+  public int saslRTTimeout() { return conf.getInt("spark.shuffle.sasl.timeout", 30000); }
 }
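
The new key can be exercised directly through TransportConf's ConfigProvider. A minimal sketch, assuming the single-argument TransportConf(ConfigProvider) constructor from this revision and that ConfigProvider.get(String) signals unknown keys with NoSuchElementException so getInt can fall back to its default:

import java.util.NoSuchElementException;

import org.apache.spark.network.util.ConfigProvider;
import org.apache.spark.network.util.TransportConf;

public class SaslTimeoutExample {
  public static void main(String[] args) {
    // Override only the new SASL timeout key; all other keys fall back to defaults.
    ConfigProvider provider = new ConfigProvider() {
      @Override
      public String get(String name) {
        if ("spark.shuffle.sasl.timeout".equals(name)) {
          return "10000"; // 10s instead of the 30s default
        }
        throw new NoSuchElementException(name);
      }
    };
    TransportConf conf = new TransportConf(provider);
    System.out.println(conf.saslRTTimeout()); // prints 10000
  }
}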

network/shuffle/src/main/java/org/apache/spark/network/sasl/SaslBootstrap.java

Lines changed: 20 additions & 8 deletions

@@ -24,31 +24,43 @@
 
 import org.apache.spark.network.client.TransportClient;
 import org.apache.spark.network.client.TransportClientBootstrap;
+import org.apache.spark.network.util.TransportConf;
 
+/**
+ * Bootstraps a {@link TransportClient} by performing SASL authentication on the connection. The
+ * server should be setup with a {@link SaslRpcHandler} with matching keys for the given appId.
+ */
 public class SaslBootstrap implements TransportClientBootstrap {
   private final Logger logger = LoggerFactory.getLogger(SaslBootstrap.class);
 
-  private final String secretKeyId;
+  private final TransportConf conf;
+  private final String appId;
   private final SecretKeyHolder secretKeyHolder;
 
-  public SaslBootstrap(String secretKeyId, SecretKeyHolder secretKeyHolder) {
-    this.secretKeyId = secretKeyId;
+  public SaslBootstrap(TransportConf conf, String appId, SecretKeyHolder secretKeyHolder) {
+    this.conf = conf;
+    this.appId = appId;
     this.secretKeyHolder = secretKeyHolder;
   }
 
+  /**
+   * Performs SASL authentication by sending a token, and then proceeding with the SASL
+   * challenge-response tokens until we either successfully authenticate or throw an exception
+   * due to mismatch.
+   */
   public void doBootstrap(TransportClient client) {
-    SparkSaslClient saslClient = new SparkSaslClient(secretKeyId, secretKeyHolder);
+    SparkSaslClient saslClient = new SparkSaslClient(appId, secretKeyHolder);
     try {
       byte[] payload = saslClient.firstToken();
 
       while (!saslClient.isComplete()) {
-        SaslMessage msg = new SaslMessage(secretKeyId, payload);
-        logger.info("Sending msg {} {}", secretKeyId, payload.length);
+        SaslMessage msg = new SaslMessage(appId, payload);
+        logger.info("Sending msg {} {}", appId, payload.length);
         ByteBuf buf = Unpooled.buffer(msg.encodedLength());
         msg.encode(buf);
 
-        byte[] response = client.sendRpcSync(buf.array(), 300000);
-        logger.info("Got response {} {}", secretKeyId, response.length);
+        byte[] response = client.sendRpcSync(buf.array(), conf.saslRTTimeout());
+        logger.info("Got response {} {}", appId, response.length);
         payload = saslClient.response(response);
       }
     } finally {
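
Both sides of this handshake resolve the shared secret through a SecretKeyHolder keyed by appId: SparkSaslClient here, and SaslRpcHandler on the server. A minimal sketch of such a holder (hypothetical; in Spark the real one is backed by the SecurityManager), assuming the two-method interface exercised by this commit's tests:

import org.apache.spark.network.sasl.SecretKeyHolder;

/** Hypothetical holder that serves one fixed user/secret pair for every appId. */
public class FixedSecretKeyHolder implements SecretKeyHolder {
  private final String secret;

  public FixedSecretKeyHolder(String secret) {
    this.secret = secret;
  }

  @Override
  public String getSaslUser(String appId) {
    return "spark"; // same SASL identity for all applications
  }

  @Override
  public String getSecretKey(String appId) {
    return secret; // client and server must agree on this value per appId
  }
}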

network/shuffle/src/test/java/org/apache/spark/network/sasl/SaslIntegrationSuite.java

Lines changed: 3 additions & 3 deletions

@@ -96,7 +96,7 @@ public void afterEach() {
   public void testGoodClient() {
     clientFactory = context.createClientFactory(
       Lists.<TransportClientBootstrap>newArrayList(
-        new SaslBootstrap("app-id", new TestSecretKeyHolder("good-key"))));
+        new SaslBootstrap(conf, "app-id", new TestSecretKeyHolder("good-key"))));
 
     TransportClient client = clientFactory.createClient(TestUtils.getLocalHost(), server.getPort());
     String msg = "Hello, World!";
@@ -108,7 +108,7 @@ public void testGoodClient() {
   public void testBadClient() {
     clientFactory = context.createClientFactory(
       Lists.<TransportClientBootstrap>newArrayList(
-        new SaslBootstrap("app-id", new TestSecretKeyHolder("bad-key"))));
+        new SaslBootstrap(conf, "app-id", new TestSecretKeyHolder("bad-key"))));
 
     try {
       // Bootstrap should fail on startup.
@@ -146,7 +146,7 @@ public void testNoSaslServer() {
     TransportContext context = new TransportContext(conf, handler);
     clientFactory = context.createClientFactory(
       Lists.<TransportClientBootstrap>newArrayList(
-        new SaslBootstrap("app-id", new TestSecretKeyHolder("key"))));
+        new SaslBootstrap(conf, "app-id", new TestSecretKeyHolder("key"))));
     TransportServer server = context.createServer();
     try {
       clientFactory.createClient(TestUtils.getLocalHost(), server.getPort());
