Skip to content

Commit 4f21673

Browse files
committed
[SPARK-4740] Create multiple concurrent connections between two peer nodes in Netty.
1 parent 6f61e1f commit 4f21673

File tree

2 files changed

+78
-35
lines changed

2 files changed

+78
-35
lines changed

network/common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java

Lines changed: 73 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
import java.net.InetSocketAddress;
2323
import java.net.SocketAddress;
2424
import java.util.List;
25+
import java.util.Random;
2526
import java.util.concurrent.ConcurrentHashMap;
2627
import java.util.concurrent.atomic.AtomicReference;
2728

@@ -56,12 +57,27 @@
5657
* TransportClient, all given {@link TransportClientBootstrap}s will be run.
5758
*/
5859
public class TransportClientFactory implements Closeable {
60+
61+
private class ClientPool {
62+
TransportClient[] clients;
63+
Object[] locks;
64+
65+
public ClientPool() {
66+
clients = new TransportClient[numConnectionsPerPeer];
67+
locks = new Object[numConnectionsPerPeer];
68+
}
69+
}
70+
5971
private final Logger logger = LoggerFactory.getLogger(TransportClientFactory.class);
6072

6173
private final TransportContext context;
6274
private final TransportConf conf;
6375
private final List<TransportClientBootstrap> clientBootstraps;
64-
private final ConcurrentHashMap<SocketAddress, TransportClient> connectionPool;
76+
private final ConcurrentHashMap<SocketAddress, ClientPool> connectionPool;
77+
78+
/** Random number generator for picking connections between peers. */
79+
private final Random rand;
80+
private final int numConnectionsPerPeer;
6581

6682
private final Class<? extends Channel> socketChannelClass;
6783
private EventLoopGroup workerGroup;
@@ -73,7 +89,9 @@ public TransportClientFactory(
7389
this.context = Preconditions.checkNotNull(context);
7490
this.conf = context.getConf();
7591
this.clientBootstraps = Lists.newArrayList(Preconditions.checkNotNull(clientBootstraps));
76-
this.connectionPool = new ConcurrentHashMap<SocketAddress, TransportClient>();
92+
this.connectionPool = new ConcurrentHashMap<SocketAddress, ClientPool>();
93+
this.numConnectionsPerPeer = conf.numConnectionsPerPeer();
94+
this.rand = new Random();
7795

7896
IOMode ioMode = IOMode.valueOf(conf.ioMode());
7997
this.socketChannelClass = NettyUtils.getClientChannelClass(ioMode);
@@ -97,27 +115,49 @@ public TransportClient createClient(String remoteHost, int remotePort) throws IO
97115
// Get connection from the connection pool first.
98116
// If it is not found or not active, create a new one.
99117
final InetSocketAddress address = new InetSocketAddress(remoteHost, remotePort);
100-
TransportClient cachedClient = connectionPool.get(address);
118+
119+
// Create the ClientPool if we don't have it yet.
120+
ClientPool clientPool = connectionPool.get(address);
121+
if (clientPool == null) {
122+
clientPool = connectionPool.putIfAbsent(address, new ClientPool());
123+
}
124+
125+
int clientIndex = rand.nextInt(numConnectionsPerPeer);
126+
TransportClient cachedClient = clientPool.clients[clientIndex];
101127
if (cachedClient != null) {
102128
if (cachedClient.isActive()) {
103129
logger.trace("Returning cached connection to {}: {}", address, cachedClient);
104130
return cachedClient;
105131
} else {
106132
logger.info("Found inactive connection to {}, closing it.", address);
107-
connectionPool.remove(address, cachedClient); // Remove inactive clients.
133+
clientPool.clients[clientIndex] = null; // Remove inactive clients.
108134
}
109135
}
110136

137+
// If we reach here, we don't have an existing connection open. Let's create a new one.
138+
// Multiple threads might race here to create new connections. Let's keep only one of them
139+
// active at anytime.
140+
synchronized (clientPool.locks[clientIndex]) {
141+
if (clientPool.clients[clientIndex] == null || !clientPool.clients[clientIndex].isActive()) {
142+
clientPool.clients[clientIndex] = createClient(address);
143+
}
144+
}
145+
146+
return clientPool.clients[clientIndex];
147+
}
148+
149+
/** Create a completely new {@link TransportClient} to the remote address. */
150+
private TransportClient createClient(InetSocketAddress address) throws IOException {
111151
logger.debug("Creating new connection to " + address);
112152

113153
Bootstrap bootstrap = new Bootstrap();
114154
bootstrap.group(workerGroup)
115-
.channel(socketChannelClass)
116-
// Disable Nagle's Algorithm since we don't want packets to wait
117-
.option(ChannelOption.TCP_NODELAY, true)
118-
.option(ChannelOption.SO_KEEPALIVE, true)
119-
.option(ChannelOption.CONNECT_TIMEOUT_MILLIS, conf.connectionTimeoutMs())
120-
.option(ChannelOption.ALLOCATOR, pooledAllocator);
155+
.channel(socketChannelClass)
156+
// Disable Nagle's Algorithm since we don't want packets to wait
157+
.option(ChannelOption.TCP_NODELAY, true)
158+
.option(ChannelOption.SO_KEEPALIVE, true)
159+
.option(ChannelOption.CONNECT_TIMEOUT_MILLIS, conf.connectionTimeoutMs())
160+
.option(ChannelOption.ALLOCATOR, pooledAllocator);
121161

122162
final AtomicReference<TransportClient> clientRef = new AtomicReference<TransportClient>();
123163

@@ -130,11 +170,11 @@ public void initChannel(SocketChannel ch) {
130170
});
131171

132172
// Connect to the remote server
133-
long preConnect = System.currentTimeMillis();
173+
long preConnect = System.nanoTime();
134174
ChannelFuture cf = bootstrap.connect(address);
135175
if (!cf.awaitUninterruptibly(conf.connectionTimeoutMs())) {
136176
throw new IOException(
137-
String.format("Connecting to %s timed out (%s ms)", address, conf.connectionTimeoutMs()));
177+
String.format("Connecting to %s timed out (%s ms)", address, conf.connectionTimeoutMs()));
138178
} else if (cf.cause() != null) {
139179
throw new IOException(String.format("Failed to connect to %s", address), cf.cause());
140180
}
@@ -143,43 +183,41 @@ public void initChannel(SocketChannel ch) {
143183
assert client != null : "Channel future completed successfully with null client";
144184

145185
// Execute any client bootstraps synchronously before marking the Client as successful.
146-
long preBootstrap = System.currentTimeMillis();
186+
long preBootstrap = System.nanoTime();
147187
logger.debug("Connection to {} successful, running bootstraps...", address);
148188
try {
149189
for (TransportClientBootstrap clientBootstrap : clientBootstraps) {
150190
clientBootstrap.doBootstrap(client);
151191
}
152192
} catch (Exception e) { // catch non-RuntimeExceptions too as bootstrap may be written in Scala
153-
long bootstrapTime = System.currentTimeMillis() - preBootstrap;
154-
logger.error("Exception while bootstrapping client after " + bootstrapTime + " ms", e);
193+
long bootstrapTimeMs = (System.nanoTime() - preBootstrap) / 1000000;
194+
logger.error("Exception while bootstrapping client after " + bootstrapTimeMs + " ms", e);
155195
client.close();
156196
throw Throwables.propagate(e);
157197
}
158-
long postBootstrap = System.currentTimeMillis();
159-
160-
// Successful connection & bootstrap -- in the event that two threads raced to create a client,
161-
// use the first one that was put into the connectionPool and close the one we made here.
162-
TransportClient oldClient = connectionPool.putIfAbsent(address, client);
163-
if (oldClient == null) {
164-
logger.debug("Successfully created connection to {} after {} ms ({} ms spent in bootstraps)",
165-
address, postBootstrap - preConnect, postBootstrap - preBootstrap);
166-
return client;
167-
} else {
168-
logger.debug("Two clients were created concurrently after {} ms, second will be disposed.",
169-
postBootstrap - preConnect);
170-
client.close();
171-
return oldClient;
172-
}
198+
long postBootstrap = System.nanoTime();
199+
200+
logger.debug("Successfully created connection to {} after {} ms ({} ms spent in bootstraps)",
201+
address, (postBootstrap - preConnect) / 1000000, (postBootstrap - preBootstrap) / 1000000);
202+
203+
return client;
173204
}
174205

175206
/** Close all connections in the connection pool, and shutdown the worker thread pool. */
176207
@Override
177208
public void close() {
178-
for (TransportClient client : connectionPool.values()) {
179-
try {
180-
client.close();
181-
} catch (RuntimeException e) {
182-
logger.warn("Ignoring exception during close", e);
209+
// Go through all clients and close them if they are active.
210+
for (ClientPool clientPool : connectionPool.values()) {
211+
for (int i = 0; i < clientPool.clients.length; i++) {
212+
TransportClient client = clientPool.clients[i];
213+
if (client != null) {
214+
clientPool.clients[i] = null;
215+
try {
216+
client.close();
217+
} catch (RuntimeException e) {
218+
logger.warn("Ignoring exception during close", e);
219+
}
220+
}
183221
}
184222
}
185223
connectionPool.clear();

network/common/src/main/java/org/apache/spark/network/util/TransportConf.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,11 @@ public int connectionTimeoutMs() {
4040
return conf.getInt("spark.shuffle.io.connectionTimeout", 120) * 1000;
4141
}
4242

43+
/** Number of concurrent connections between two nodes for fetching data. **/
44+
public int numConnectionsPerPeer() {
45+
return conf.getInt("spark.shuffle.io.numConnectionsPerPeer", 2);
46+
}
47+
4348
/** Requested maximum length of the queue of incoming connections. Default -1 for no backlog. */
4449
public int backLog() { return conf.getInt("spark.shuffle.io.backLog", -1); }
4550

0 commit comments

Comments
 (0)