
Commit 66e5a24

[SPARK-4238] [Core] Perform network-level retry of shuffle file fetches
This adds a RetryingBlockFetcher to the NettyBlockTransferService which is wrapped around our typical OneForOneBlockFetcher, adding retry logic in the event of an IOException. This sort of retry allows us to avoid marking an entire executor as failed due to garbage collection or high network load.

TODO:
- [ ] unit tests
- [ ] put in ExternalShuffleClient too
1 parent 4c42986 commit 66e5a24
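
In outline, the retry behaves like the following standalone sketch (the FetchStarter interface and fetchWithRetries helper are hypothetical names, not the shipped RetryingBlockFetcher API; a production fetcher would additionally track which blocks are still outstanding, retry only those, and do so from a separate thread):

// Hypothetical, simplified sketch of the retry pattern this commit introduces.
import java.io.IOException;

public class RetrySketch {
  interface FetchStarter {
    // Opens a fresh connection and starts fetching the given blocks.
    void createAndStart(String[] blockIds) throws IOException;
  }

  static void fetchWithRetries(FetchStarter starter, String[] blockIds,
      int maxRetries, long retryWaitMs) throws IOException, InterruptedException {
    for (int attempt = 0; ; attempt++) {
      try {
        starter.createAndStart(blockIds); // fresh client on every attempt
        return;
      } catch (IOException e) {           // only IO failures count as transient
        if (attempt >= maxRetries) {
          throw e;                        // retries exhausted: surface the failure
        }
        Thread.sleep(retryWaitMs);        // back off before reconnecting
      }
    }
  }
}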

File tree: 14 files changed (+320, -31 lines)


core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala

Lines changed: 17 additions & 4 deletions

@@ -27,7 +27,7 @@ import org.apache.spark.network.client.{TransportClientBootstrap, RpcResponseCal
 import org.apache.spark.network.netty.NettyMessages.{OpenBlocks, UploadBlock}
 import org.apache.spark.network.sasl.{SaslRpcHandler, SaslClientBootstrap}
 import org.apache.spark.network.server._
-import org.apache.spark.network.shuffle.{BlockFetchingListener, OneForOneBlockFetcher}
+import org.apache.spark.network.shuffle.{RetryingBlockFetcher, BlockFetchingListener, OneForOneBlockFetcher}
 import org.apache.spark.serializer.JavaSerializer
 import org.apache.spark.storage.{BlockId, StorageLevel}
 import org.apache.spark.util.Utils
@@ -71,9 +71,22 @@ class NettyBlockTransferService(conf: SparkConf, securityManager: SecurityManage
       listener: BlockFetchingListener): Unit = {
     logTrace(s"Fetch blocks from $host:$port (executor id $execId)")
     try {
-      val client = clientFactory.createClient(host, port)
-      new OneForOneBlockFetcher(client, blockIds.toArray, listener)
-        .start(OpenBlocks(blockIds.map(BlockId.apply)))
+      val blockFetchStarter = new RetryingBlockFetcher.BlockFetchStarter {
+        override def createAndStart(blockIds: Array[String], listener: BlockFetchingListener) {
+          val client = clientFactory.createClient(host, port)
+          new OneForOneBlockFetcher(client, blockIds.toArray, listener)
+            .start(OpenBlocks(blockIds.map(BlockId.apply)))
+        }
+      }
+
+      val maxRetries = transportConf.maxIORetries()
+      if (maxRetries > 0) {
+        // Note this Fetcher will correctly handle maxRetries == 0; we avoid it just in case there's
+        // a bug in this code.
+        new RetryingBlockFetcher(transportConf, blockFetchStarter, blockIds, listener).start()
+      } else {
+        blockFetchStarter.createAndStart(blockIds, listener)
+      }
     } catch {
       case e: Exception =>
         logError("Exception while beginning fetchBlocks", e)

network/common/src/main/java/org/apache/spark/network/client/TransportClient.java

Lines changed: 14 additions & 2 deletions

@@ -18,7 +18,9 @@
 package org.apache.spark.network.client;
 
 import java.io.Closeable;
+import java.io.IOException;
 import java.util.UUID;
+import java.util.concurrent.ExecutionException;
 import java.util.concurrent.TimeUnit;
 
 import com.google.common.base.Objects;
@@ -116,8 +118,12 @@ public void operationComplete(ChannelFuture future) throws Exception {
             serverAddr, future.cause());
           logger.error(errorMsg, future.cause());
           handler.removeFetchRequest(streamChunkId);
-          callback.onFailure(chunkIndex, new RuntimeException(errorMsg, future.cause()));
           channel.close();
+          try {
+            callback.onFailure(chunkIndex, new IOException(errorMsg, future.cause()));
+          } catch (Exception e) {
+            logger.error("Uncaught exception in RPC response callback handler!", e);
+          }
         }
       }
     });
@@ -147,8 +153,12 @@ public void operationComplete(ChannelFuture future) throws Exception {
             serverAddr, future.cause());
           logger.error(errorMsg, future.cause());
           handler.removeRpcRequest(requestId);
-          callback.onFailure(new RuntimeException(errorMsg, future.cause()));
           channel.close();
+          try {
+            callback.onFailure(new IOException(errorMsg, future.cause()));
+          } catch (Exception e) {
+            logger.error("Uncaught exception in RPC response callback handler!", e);
+          }
         }
       }
     });
@@ -175,6 +185,8 @@ public void onFailure(Throwable e) {
 
     try {
       return result.get(timeoutMs, TimeUnit.MILLISECONDS);
+    } catch (ExecutionException e) {
+      throw Throwables.propagate(e.getCause());
     } catch (Exception e) {
       throw Throwables.propagate(e);
     }
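
The new ExecutionException branch in the synchronous RPC path deserves a word: Guava's Throwables.propagate wraps checked exceptions in a RuntimeException, so propagating the ExecutionException itself would bury the interesting cause (often an IOException) one level deeper than retry logic expects. A small self-contained demonstration (PropagateDemo is a hypothetical demo class):

// Hypothetical demo of the unwrapping behavior.
import java.io.IOException;
import java.util.concurrent.ExecutionException;
import com.google.common.base.Throwables;

public class PropagateDemo {
  public static void main(String[] args) {
    ExecutionException wrapped = new ExecutionException(new IOException("connection reset"));
    try {
      // Unwrapping first means the rethrown exception's direct cause is the IOException.
      throw Throwables.propagate(wrapped.getCause());
    } catch (RuntimeException e) {
      System.out.println(e.getCause()); // java.io.IOException: connection reset
    }
  }
}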

network/common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java

Lines changed: 7 additions & 6 deletions

@@ -18,12 +18,12 @@
 package org.apache.spark.network.client;
 
 import java.io.Closeable;
+import java.io.IOException;
 import java.lang.reflect.Field;
 import java.net.InetSocketAddress;
 import java.net.SocketAddress;
 import java.util.List;
 import java.util.concurrent.ConcurrentHashMap;
-import java.util.concurrent.TimeoutException;
 import java.util.concurrent.atomic.AtomicReference;
 
 import com.google.common.base.Preconditions;
@@ -44,7 +44,6 @@
 import org.apache.spark.network.TransportContext;
 import org.apache.spark.network.server.TransportChannelHandler;
 import org.apache.spark.network.util.IOMode;
-import org.apache.spark.network.util.JavaUtils;
 import org.apache.spark.network.util.NettyUtils;
 import org.apache.spark.network.util.TransportConf;
 
@@ -93,15 +92,17 @@ public TransportClientFactory(
    *
    * Concurrency: This method is safe to call from multiple threads.
    */
-  public TransportClient createClient(String remoteHost, int remotePort) {
+  public TransportClient createClient(String remoteHost, int remotePort) throws IOException {
     // Get connection from the connection pool first.
     // If it is not found or not active, create a new one.
     final InetSocketAddress address = new InetSocketAddress(remoteHost, remotePort);
     TransportClient cachedClient = connectionPool.get(address);
     if (cachedClient != null) {
       if (cachedClient.isActive()) {
+        logger.trace("Returning cached connection to {}: {}", address, cachedClient);
         return cachedClient;
       } else {
+        logger.info("Found inactive connection to {}, closing it.", address);
         connectionPool.remove(address, cachedClient); // Remove inactive clients.
       }
     }
@@ -133,10 +134,10 @@ public void initChannel(SocketChannel ch) {
     long preConnect = System.currentTimeMillis();
     ChannelFuture cf = bootstrap.connect(address);
     if (!cf.awaitUninterruptibly(conf.connectionTimeoutMs())) {
-      throw new RuntimeException(
+      throw new IOException(
         String.format("Connecting to %s timed out (%s ms)", address, conf.connectionTimeoutMs()));
     } else if (cf.cause() != null) {
-      throw new RuntimeException(String.format("Failed to connect to %s", address), cf.cause());
+      throw new IOException(String.format("Failed to connect to %s", address), cf.cause());
     }
 
     TransportClient client = clientRef.get();
@@ -198,7 +199,7 @@ public void close() {
    */
   private PooledByteBufAllocator createPooledByteBufAllocator() {
     return new PooledByteBufAllocator(
-      PlatformDependent.directBufferPreferred(),
+      conf.preferDirectBufs() && PlatformDependent.directBufferPreferred(),
       getPrivateStaticField("DEFAULT_NUM_HEAP_ARENA"),
       getPrivateStaticField("DEFAULT_NUM_DIRECT_ARENA"),
       getPrivateStaticField("DEFAULT_PAGE_SIZE"),

network/common/src/main/java/org/apache/spark/network/client/TransportResponseHandler.java

Lines changed: 2 additions & 1 deletion

@@ -17,6 +17,7 @@
 
 package org.apache.spark.network.client;
 
+import java.io.IOException;
 import java.util.Map;
 import java.util.concurrent.ConcurrentHashMap;
 
@@ -94,7 +95,7 @@ public void channelUnregistered() {
       String remoteAddress = NettyUtils.getRemoteAddress(channel);
       logger.error("Still have {} requests outstanding when connection from {} is closed",
         numOutstandingRequests(), remoteAddress);
-      failOutstandingRequests(new RuntimeException("Connection from " + remoteAddress + " closed"));
+      failOutstandingRequests(new IOException("Connection from " + remoteAddress + " closed"));
     }
   }

network/common/src/main/java/org/apache/spark/network/protocol/MessageEncoder.java

Lines changed: 1 addition & 1 deletion

@@ -66,7 +66,7 @@ public void encode(ChannelHandlerContext ctx, Message in, List<Object> out) {
     // All messages have the frame length, message type, and message itself.
     int headerLength = 8 + msgType.encodedLength() + in.encodedLength();
     long frameLength = headerLength + bodyLength;
-    ByteBuf header = ctx.alloc().buffer(headerLength);
+    ByteBuf header = ctx.alloc().heapBuffer(headerLength);
     header.writeLong(frameLength);
     msgType.encode(header);
     in.encode(header);
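
The buffer() to heapBuffer() switch pins the small frame header to the heap even when the allocator prefers direct memory. A standalone sketch of the difference, assuming Netty 4's pooled allocator (HeapHeaderSketch is a hypothetical demo class):

// Illustrative only: generic buffer() follows the allocator's preference,
// heapBuffer() is always heap-backed.
import io.netty.buffer.ByteBuf;
import io.netty.buffer.PooledByteBufAllocator;

public class HeapHeaderSketch {
  public static void main(String[] args) {
    PooledByteBufAllocator alloc = new PooledByteBufAllocator(true); // prefer direct
    ByteBuf generic = alloc.buffer(16);    // typically direct under this allocator
    ByteBuf header = alloc.heapBuffer(16); // always heap
    System.out.println(generic.isDirect() + " / " + header.isDirect());
    generic.release(); // pooled buffers must be released
    header.release();
  }
}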

network/common/src/main/java/org/apache/spark/network/server/TransportServer.java

Lines changed: 6 additions & 2 deletions

@@ -28,6 +28,7 @@
 import io.netty.channel.ChannelOption;
 import io.netty.channel.EventLoopGroup;
 import io.netty.channel.socket.SocketChannel;
+import io.netty.util.internal.PlatformDependent;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
@@ -71,11 +72,14 @@ private void init(int portToBind) {
       NettyUtils.createEventLoop(ioMode, conf.serverThreads(), "shuffle-server");
     EventLoopGroup workerGroup = bossGroup;
 
+    PooledByteBufAllocator allocator = new PooledByteBufAllocator(
+      conf.preferDirectBufs() && PlatformDependent.directBufferPreferred());
+
     bootstrap = new ServerBootstrap()
       .group(bossGroup, workerGroup)
       .channel(NettyUtils.getServerChannelClass(ioMode))
-      .option(ChannelOption.ALLOCATOR, PooledByteBufAllocator.DEFAULT)
-      .childOption(ChannelOption.ALLOCATOR, PooledByteBufAllocator.DEFAULT);
+      .option(ChannelOption.ALLOCATOR, allocator)
+      .childOption(ChannelOption.ALLOCATOR, allocator);
 
     if (conf.backLog() > 0) {
       bootstrap.option(ChannelOption.SO_BACKLOG, conf.backLog());
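
Both the server above and the client factory now build their allocator the same way. A minimal sketch of the shared guard (assumed Netty 4 APIs; the boolean argument stands in for conf.preferDirectBufs()):

// Direct buffers only when the Spark config allows it AND Netty reports the
// platform can use them safely; otherwise fall back to heap arenas.
import io.netty.buffer.ByteBuf;
import io.netty.buffer.PooledByteBufAllocator;
import io.netty.util.internal.PlatformDependent;

public class AllocatorGuardSketch {
  static PooledByteBufAllocator createAllocator(boolean preferDirectBufs) {
    return new PooledByteBufAllocator(
      preferDirectBufs && PlatformDependent.directBufferPreferred());
  }

  public static void main(String[] args) {
    // false stands in for spark.shuffle.io.preferDirectBufs=false
    ByteBuf buf = createAllocator(false).buffer(16);
    System.out.println(buf.isDirect()); // false: heap-backed regardless of platform
    buf.release();
  }
}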

network/common/src/main/java/org/apache/spark/network/util/NettyUtils.java

Lines changed: 9 additions & 5 deletions

@@ -37,13 +37,17 @@
  * Utilities for creating various Netty constructs based on whether we're using EPOLL or NIO.
  */
 public class NettyUtils {
-  /** Creates a Netty EventLoopGroup based on the IOMode. */
-  public static EventLoopGroup createEventLoop(IOMode mode, int numThreads, String threadPrefix) {
-
-    ThreadFactory threadFactory = new ThreadFactoryBuilder()
+  /** Creates a new ThreadFactory which prefixes each thread with the given name. */
+  public static ThreadFactory createThreadFactory(String threadPoolPrefix) {
+    return new ThreadFactoryBuilder()
       .setDaemon(true)
-      .setNameFormat(threadPrefix + "-%d")
+      .setNameFormat(threadPoolPrefix + "-%d")
       .build();
+  }
+
+  /** Creates a Netty EventLoopGroup based on the IOMode. */
+  public static EventLoopGroup createEventLoop(IOMode mode, int numThreads, String threadPrefix) {
+    ThreadFactory threadFactory = createThreadFactory(threadPrefix);
 
     switch (mode) {
       case NIO:

network/common/src/main/java/org/apache/spark/network/util/TransportConf.java

Lines changed: 17 additions & 0 deletions

@@ -30,6 +30,11 @@ public TransportConf(ConfigProvider conf) {
   /** IO mode: nio or epoll */
   public String ioMode() { return conf.get("spark.shuffle.io.mode", "NIO").toUpperCase(); }
 
+  /** If true, we will prefer allocating off-heap byte buffers within Netty. */
+  public boolean preferDirectBufs() {
+    return conf.getBoolean("spark.shuffle.io.preferDirectBufs", true);
+  }
+
   /** Connect timeout in secs. Default 120 secs. */
   public int connectionTimeoutMs() {
     return conf.getInt("spark.shuffle.io.connectionTimeout", 120) * 1000;
@@ -58,4 +63,16 @@ public int connectionTimeoutMs() {
 
   /** Timeout for a single round trip of SASL token exchange, in milliseconds. */
   public int saslRTTimeout() { return conf.getInt("spark.shuffle.sasl.timeout", 30000); }
+
+  /**
+   * Max number of times we will try IO exceptions (such as connection timeouts) per request.
+   * If set to 0, we will not do any retries.
+   */
+  public int maxIORetries() { return conf.getInt("spark.shuffle.io.maxIORetries", 3); }
+
+  /**
+   * Time (in milliseconds) that we will wait in order to perform a retry after an IOException.
+   * Only relevant if maxIORetries > 0.
+   */
+  public int ioRetryWaitTime() { return conf.getInt("spark.shuffle.io.ioRetryWaitTime", 5000); }
 }
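
A hedged sketch of how these settings flow through TransportConf. The map-backed ConfigProvider below is illustrative (Spark normally supplies one backed by SparkConf), and it assumes ConfigProvider.get(name) throws NoSuchElementException for missing keys, which the default-returning helpers rely on; the keys and defaults match the diff above:

// Illustrative wiring of the new retry settings.
import java.util.HashMap;
import java.util.Map;
import java.util.NoSuchElementException;
import org.apache.spark.network.util.ConfigProvider;
import org.apache.spark.network.util.TransportConf;

public class RetryConfSketch {
  public static void main(String[] args) {
    final Map<String, String> settings = new HashMap<String, String>();
    settings.put("spark.shuffle.io.maxIORetries", "5");       // retry up to 5 times
    settings.put("spark.shuffle.io.ioRetryWaitTime", "2000"); // wait 2s between attempts

    TransportConf conf = new TransportConf(new ConfigProvider() {
      @Override
      public String get(String name) {
        if (!settings.containsKey(name)) {
          throw new NoSuchElementException(name);
        }
        return settings.get(name);
      }
    });

    System.out.println(conf.maxIORetries());    // 5
    System.out.println(conf.ioRetryWaitTime()); // 2000
  }
}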

network/common/src/test/java/org/apache/spark/network/TransportClientFactorySuite.java

Lines changed: 2 additions & 2 deletions

@@ -57,7 +57,7 @@ public void tearDown() {
   }
 
   @Test
-  public void createAndReuseBlockClients() throws TimeoutException {
+  public void createAndReuseBlockClients() throws Exception {
     TransportClientFactory factory = context.createClientFactory();
     TransportClient c1 = factory.createClient(TestUtils.getLocalHost(), server1.getPort());
     TransportClient c2 = factory.createClient(TestUtils.getLocalHost(), server1.getPort());
@@ -88,7 +88,7 @@ public void neverReturnInactiveClients() throws Exception {
   }
 
   @Test
-  public void closeBlockClientsWithFactory() throws TimeoutException {
+  public void closeBlockClientsWithFactory() throws Exception {
     TransportClientFactory factory = context.createClientFactory();
     TransportClient c1 = factory.createClient(TestUtils.getLocalHost(), server1.getPort());
     TransportClient c2 = factory.createClient(TestUtils.getLocalHost(), server2.getPort());

network/shuffle/src/main/java/org/apache/spark/network/shuffle/ExternalShuffleClient.java

Lines changed: 2 additions & 1 deletion

@@ -17,6 +17,7 @@
 
 package org.apache.spark.network.shuffle;
 
+import java.io.IOException;
 import java.util.List;
 
 import com.google.common.collect.Lists;
@@ -108,7 +109,7 @@ public void registerWithShuffleServer(
       String host,
       int port,
       String execId,
-      ExecutorShuffleInfo executorInfo) {
+      ExecutorShuffleInfo executorInfo) throws IOException {
     assert appId != null : "Called before init()";
     TransportClient client = clientFactory.createClient(host, port);
     byte[] registerExecutorMessage =
