Skip to content

Commit 441ec34

Browse files
committed
[SPARK-4740] Create multiple concurrent connections between two peer nodes in Netty.
It's been reported that when the number of disks is large and the number of nodes is small, Netty network throughput is low compared with NIO. We suspect the problem is that only a small number of disks are utilized to serve shuffle files at any given point, due to connection reuse. This patch adds a new config parameter to specify the number of concurrent connections between two peer nodes, which defaults to 2. Author: Reynold Xin <[email protected]> Closes #3625 from rxin/SPARK-4740 and squashes the following commits: ad4241a [Reynold Xin] Updated javadoc. f33c72b [Reynold Xin] Code review feedback. 0fefabb [Reynold Xin] Use double check in synchronization. 41dfcb2 [Reynold Xin] Added test case. 9076b4a [Reynold Xin] Fixed two NPEs. 3e1306c [Reynold Xin] Minor style fix. 4f21673 [Reynold Xin] [SPARK-4740] Create multiple concurrent connections between two peer nodes in Netty. (cherry picked from commit 2b9b726) Signed-off-by: Reynold Xin <[email protected]>
1 parent b0d64e5 commit 441ec34

File tree

3 files changed

+180
-46
lines changed

3 files changed

+180
-46
lines changed

network/common/src/main/java/org/apache/spark/network/client/TransportClientFactory.java

Lines changed: 85 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
import java.net.InetSocketAddress;
2323
import java.net.SocketAddress;
2424
import java.util.List;
25+
import java.util.Random;
2526
import java.util.concurrent.ConcurrentHashMap;
2627
import java.util.concurrent.atomic.AtomicReference;
2728

@@ -42,6 +43,7 @@
4243
import org.apache.spark.network.TransportContext;
4344
import org.apache.spark.network.server.TransportChannelHandler;
4445
import org.apache.spark.network.util.IOMode;
46+
import org.apache.spark.network.util.JavaUtils;
4547
import org.apache.spark.network.util.NettyUtils;
4648
import org.apache.spark.network.util.TransportConf;
4749

@@ -56,12 +58,31 @@
5658
* TransportClient, all given {@link TransportClientBootstrap}s will be run.
5759
*/
5860
public class TransportClientFactory implements Closeable {
61+
62+
/** A simple data structure to track the pool of clients between two peer nodes. */
63+
private static class ClientPool {
64+
TransportClient[] clients;
65+
Object[] locks;
66+
67+
public ClientPool(int size) {
68+
clients = new TransportClient[size];
69+
locks = new Object[size];
70+
for (int i = 0; i < size; i++) {
71+
locks[i] = new Object();
72+
}
73+
}
74+
}
75+
5976
private final Logger logger = LoggerFactory.getLogger(TransportClientFactory.class);
6077

6178
private final TransportContext context;
6279
private final TransportConf conf;
6380
private final List<TransportClientBootstrap> clientBootstraps;
64-
private final ConcurrentHashMap<SocketAddress, TransportClient> connectionPool;
81+
private final ConcurrentHashMap<SocketAddress, ClientPool> connectionPool;
82+
83+
/** Random number generator for picking connections between peers. */
84+
private final Random rand;
85+
private final int numConnectionsPerPeer;
6586

6687
private final Class<? extends Channel> socketChannelClass;
6788
private EventLoopGroup workerGroup;
@@ -73,7 +94,9 @@ public TransportClientFactory(
7394
this.context = Preconditions.checkNotNull(context);
7495
this.conf = context.getConf();
7596
this.clientBootstraps = Lists.newArrayList(Preconditions.checkNotNull(clientBootstraps));
76-
this.connectionPool = new ConcurrentHashMap<SocketAddress, TransportClient>();
97+
this.connectionPool = new ConcurrentHashMap<SocketAddress, ClientPool>();
98+
this.numConnectionsPerPeer = conf.numConnectionsPerPeer();
99+
this.rand = new Random();
77100

78101
IOMode ioMode = IOMode.valueOf(conf.ioMode());
79102
this.socketChannelClass = NettyUtils.getClientChannelClass(ioMode);
@@ -84,10 +107,14 @@ public TransportClientFactory(
84107
}
85108

86109
/**
87-
* Create a new {@link TransportClient} connecting to the given remote host / port. This will
88-
* reuse TransportClients if they are still active and are for the same remote address. Prior
89-
* to the creation of a new TransportClient, we will execute all {@link TransportClientBootstrap}s
90-
* that are registered with this factory.
110+
* Create a {@link TransportClient} connecting to the given remote host / port.
111+
*
112+
* We maintain an array of clients (size determined by spark.shuffle.io.numConnectionsPerPeer)
113+
* and randomly pick one to use. If no client was previously created in the randomly selected
114+
* spot (or that client is no longer active), this function creates a new client and places it there.
115+
*
116+
* Prior to the creation of a new TransportClient, we will execute all
117+
* {@link TransportClientBootstrap}s that are registered with this factory.
91118
*
92119
* This blocks until a connection is successfully established and fully bootstrapped.
93120
*
@@ -97,23 +124,48 @@ public TransportClient createClient(String remoteHost, int remotePort) throws IO
97124
// Get connection from the connection pool first.
98125
// If it is not found or not active, create a new one.
99126
final InetSocketAddress address = new InetSocketAddress(remoteHost, remotePort);
100-
TransportClient cachedClient = connectionPool.get(address);
101-
if (cachedClient != null) {
102-
if (cachedClient.isActive()) {
103-
logger.trace("Returning cached connection to {}: {}", address, cachedClient);
104-
return cachedClient;
105-
} else {
106-
logger.info("Found inactive connection to {}, closing it.", address);
107-
connectionPool.remove(address, cachedClient); // Remove inactive clients.
127+
128+
// Create the ClientPool if we don't have it yet.
129+
ClientPool clientPool = connectionPool.get(address);
130+
if (clientPool == null) {
131+
connectionPool.putIfAbsent(address, new ClientPool(numConnectionsPerPeer));
132+
clientPool = connectionPool.get(address);
133+
}
134+
135+
int clientIndex = rand.nextInt(numConnectionsPerPeer);
136+
TransportClient cachedClient = clientPool.clients[clientIndex];
137+
138+
if (cachedClient != null && cachedClient.isActive()) {
139+
logger.trace("Returning cached connection to {}: {}", address, cachedClient);
140+
return cachedClient;
141+
}
142+
143+
// If we reach here, we don't have an existing connection open. Let's create a new one.
144+
// Multiple threads might race here to create new connections. Keep only one of them active.
145+
synchronized (clientPool.locks[clientIndex]) {
146+
cachedClient = clientPool.clients[clientIndex];
147+
148+
if (cachedClient != null) {
149+
if (cachedClient.isActive()) {
150+
logger.trace("Returning cached connection to {}: {}", address, cachedClient);
151+
return cachedClient;
152+
} else {
153+
logger.info("Found inactive connection to {}, creating a new one.", address);
154+
}
108155
}
156+
clientPool.clients[clientIndex] = createClient(address);
157+
return clientPool.clients[clientIndex];
109158
}
159+
}
110160

161+
/** Create a completely new {@link TransportClient} to the remote address. */
162+
private TransportClient createClient(InetSocketAddress address) throws IOException {
111163
logger.debug("Creating new connection to " + address);
112164

113165
Bootstrap bootstrap = new Bootstrap();
114166
bootstrap.group(workerGroup)
115167
.channel(socketChannelClass)
116-
// Disable Nagle's Algorithm since we don't want packets to wait
168+
// Disable Nagle's Algorithm since we don't want packets to wait
117169
.option(ChannelOption.TCP_NODELAY, true)
118170
.option(ChannelOption.SO_KEEPALIVE, true)
119171
.option(ChannelOption.CONNECT_TIMEOUT_MILLIS, conf.connectionTimeoutMs())
@@ -130,7 +182,7 @@ public void initChannel(SocketChannel ch) {
130182
});
131183

132184
// Connect to the remote server
133-
long preConnect = System.currentTimeMillis();
185+
long preConnect = System.nanoTime();
134186
ChannelFuture cf = bootstrap.connect(address);
135187
if (!cf.awaitUninterruptibly(conf.connectionTimeoutMs())) {
136188
throw new IOException(
@@ -143,43 +195,37 @@ public void initChannel(SocketChannel ch) {
143195
assert client != null : "Channel future completed successfully with null client";
144196

145197
// Execute any client bootstraps synchronously before marking the Client as successful.
146-
long preBootstrap = System.currentTimeMillis();
198+
long preBootstrap = System.nanoTime();
147199
logger.debug("Connection to {} successful, running bootstraps...", address);
148200
try {
149201
for (TransportClientBootstrap clientBootstrap : clientBootstraps) {
150202
clientBootstrap.doBootstrap(client);
151203
}
152204
} catch (Exception e) { // catch non-RuntimeExceptions too as bootstrap may be written in Scala
153-
long bootstrapTime = System.currentTimeMillis() - preBootstrap;
154-
logger.error("Exception while bootstrapping client after " + bootstrapTime + " ms", e);
205+
long bootstrapTimeMs = (System.nanoTime() - preBootstrap) / 1000000;
206+
logger.error("Exception while bootstrapping client after " + bootstrapTimeMs + " ms", e);
155207
client.close();
156208
throw Throwables.propagate(e);
157209
}
158-
long postBootstrap = System.currentTimeMillis();
159-
160-
// Successful connection & bootstrap -- in the event that two threads raced to create a client,
161-
// use the first one that was put into the connectionPool and close the one we made here.
162-
TransportClient oldClient = connectionPool.putIfAbsent(address, client);
163-
if (oldClient == null) {
164-
logger.debug("Successfully created connection to {} after {} ms ({} ms spent in bootstraps)",
165-
address, postBootstrap - preConnect, postBootstrap - preBootstrap);
166-
return client;
167-
} else {
168-
logger.debug("Two clients were created concurrently after {} ms, second will be disposed.",
169-
postBootstrap - preConnect);
170-
client.close();
171-
return oldClient;
172-
}
210+
long postBootstrap = System.nanoTime();
211+
212+
logger.debug("Successfully created connection to {} after {} ms ({} ms spent in bootstraps)",
213+
address, (postBootstrap - preConnect) / 1000000, (postBootstrap - preBootstrap) / 1000000);
214+
215+
return client;
173216
}
174217

175218
/** Close all connections in the connection pool, and shutdown the worker thread pool. */
176219
@Override
177220
public void close() {
178-
for (TransportClient client : connectionPool.values()) {
179-
try {
180-
client.close();
181-
} catch (RuntimeException e) {
182-
logger.warn("Ignoring exception during close", e);
221+
// Go through every client pool and close each client that has been created (non-null slots).
222+
for (ClientPool clientPool : connectionPool.values()) {
223+
for (int i = 0; i < clientPool.clients.length; i++) {
224+
TransportClient client = clientPool.clients[i];
225+
if (client != null) {
226+
clientPool.clients[i] = null;
227+
JavaUtils.closeQuietly(client);
228+
}
183229
}
184230
}
185231
connectionPool.clear();

network/common/src/main/java/org/apache/spark/network/util/TransportConf.java

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,11 @@ public int connectionTimeoutMs() {
4040
return conf.getInt("spark.shuffle.io.connectionTimeout", 120) * 1000;
4141
}
4242

43+
/** Number of concurrent connections between two nodes for fetching data (default: 2). */
44+
public int numConnectionsPerPeer() {
45+
return conf.getInt("spark.shuffle.io.numConnectionsPerPeer", 2);
46+
}
47+
4348
/** Requested maximum length of the queue of incoming connections. Default -1 for no backlog. */
4449
public int backLog() { return conf.getInt("spark.shuffle.io.backLog", -1); }
4550

network/common/src/test/java/org/apache/spark/network/TransportClientFactorySuite.java

Lines changed: 90 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,11 @@
1818
package org.apache.spark.network;
1919

2020
import java.io.IOException;
21-
import java.util.concurrent.TimeoutException;
21+
import java.util.Collections;
22+
import java.util.HashSet;
23+
import java.util.NoSuchElementException;
24+
import java.util.Set;
25+
import java.util.concurrent.atomic.AtomicInteger;
2226

2327
import org.junit.After;
2428
import org.junit.Before;
@@ -32,6 +36,7 @@
3236
import org.apache.spark.network.server.NoOpRpcHandler;
3337
import org.apache.spark.network.server.RpcHandler;
3438
import org.apache.spark.network.server.TransportServer;
39+
import org.apache.spark.network.util.ConfigProvider;
3540
import org.apache.spark.network.util.JavaUtils;
3641
import org.apache.spark.network.util.SystemPropertyConfigProvider;
3742
import org.apache.spark.network.util.TransportConf;
@@ -57,16 +62,94 @@ public void tearDown() {
5762
JavaUtils.closeQuietly(server2);
5863
}
5964

65+
/**
66+
* Request a bunch of clients to a single server to test that
67+
* we create up to maxConnections clients.
68+
*
69+
* If concurrent is true, create multiple threads to create clients in parallel.
70+
*/
71+
private void testClientReuse(final int maxConnections, boolean concurrent)
72+
throws IOException, InterruptedException {
73+
TransportConf conf = new TransportConf(new ConfigProvider() {
74+
@Override
75+
public String get(String name) {
76+
if (name.equals("spark.shuffle.io.numConnectionsPerPeer")) {
77+
return Integer.toString(maxConnections);
78+
} else {
79+
throw new NoSuchElementException();
80+
}
81+
}
82+
});
83+
84+
RpcHandler rpcHandler = new NoOpRpcHandler();
85+
TransportContext context = new TransportContext(conf, rpcHandler);
86+
final TransportClientFactory factory = context.createClientFactory();
87+
final Set<TransportClient> clients = Collections.synchronizedSet(
88+
new HashSet<TransportClient>());
89+
90+
final AtomicInteger failed = new AtomicInteger();
91+
Thread[] attempts = new Thread[maxConnections * 10];
92+
93+
// Launch a bunch of threads to create new clients.
94+
for (int i = 0; i < attempts.length; i++) {
95+
attempts[i] = new Thread() {
96+
@Override
97+
public void run() {
98+
try {
99+
TransportClient client =
100+
factory.createClient(TestUtils.getLocalHost(), server1.getPort());
101+
assert (client.isActive());
102+
clients.add(client);
103+
} catch (IOException e) {
104+
failed.incrementAndGet();
105+
}
106+
}
107+
};
108+
109+
if (concurrent) {
110+
attempts[i].start();
111+
} else {
112+
attempts[i].run();
113+
}
114+
}
115+
116+
// Wait until all the threads complete.
117+
for (int i = 0; i < attempts.length; i++) {
118+
attempts[i].join();
119+
}
120+
121+
assert(failed.get() == 0);
122+
assert(clients.size() == maxConnections);
123+
124+
for (TransportClient client : clients) {
125+
client.close();
126+
}
127+
}
128+
129+
@Test
130+
public void reuseClientsUpToConfigVariable() throws Exception {
131+
testClientReuse(1, false);
132+
testClientReuse(2, false);
133+
testClientReuse(3, false);
134+
testClientReuse(4, false);
135+
}
136+
60137
@Test
61-
public void createAndReuseBlockClients() throws IOException {
138+
public void reuseClientsUpToConfigVariableConcurrent() throws Exception {
139+
testClientReuse(1, true);
140+
testClientReuse(2, true);
141+
testClientReuse(3, true);
142+
testClientReuse(4, true);
143+
}
144+
145+
@Test
146+
public void returnDifferentClientsForDifferentServers() throws IOException {
62147
TransportClientFactory factory = context.createClientFactory();
63148
TransportClient c1 = factory.createClient(TestUtils.getLocalHost(), server1.getPort());
64-
TransportClient c2 = factory.createClient(TestUtils.getLocalHost(), server1.getPort());
65-
TransportClient c3 = factory.createClient(TestUtils.getLocalHost(), server2.getPort());
149+
TransportClient c2 = factory.createClient(TestUtils.getLocalHost(), server2.getPort());
66150
assertTrue(c1.isActive());
67-
assertTrue(c3.isActive());
68-
assertTrue(c1 == c2);
69-
assertTrue(c1 != c3);
151+
assertTrue(c2.isActive());
152+
assertTrue(c1 != c2);
70153
factory.close();
71154
}
72155

0 commit comments

Comments
 (0)