Skip to content
Closed
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import java.net.InetSocketAddress;
import java.net.SocketAddress;
import java.util.List;
import java.util.Random;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicReference;

Expand Down Expand Up @@ -56,12 +57,28 @@
* TransportClient, all given {@link TransportClientBootstrap}s will be run.
*/
public class TransportClientFactory implements Closeable {

/** A simple data structure to track the pool of clients between two peer nodes. */
private class ClientPool {
TransportClient[] clients;
Object[] locks;

public ClientPool() {
clients = new TransportClient[numConnectionsPerPeer];
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can make this a private static class if we make this a constructor parameter.

locks = new Object[numConnectionsPerPeer];
}
}

private final Logger logger = LoggerFactory.getLogger(TransportClientFactory.class);

private final TransportContext context;
private final TransportConf conf;
private final List<TransportClientBootstrap> clientBootstraps;
private final ConcurrentHashMap<SocketAddress, TransportClient> connectionPool;
private final ConcurrentHashMap<SocketAddress, ClientPool> connectionPool;

/** Random number generator for picking connections between peers. */
private final Random rand;
private final int numConnectionsPerPeer;

private final Class<? extends Channel> socketChannelClass;
private EventLoopGroup workerGroup;
Expand All @@ -73,7 +90,9 @@ public TransportClientFactory(
this.context = Preconditions.checkNotNull(context);
this.conf = context.getConf();
this.clientBootstraps = Lists.newArrayList(Preconditions.checkNotNull(clientBootstraps));
this.connectionPool = new ConcurrentHashMap<SocketAddress, TransportClient>();
this.connectionPool = new ConcurrentHashMap<SocketAddress, ClientPool>();
this.numConnectionsPerPeer = conf.numConnectionsPerPeer();
this.rand = new Random();

IOMode ioMode = IOMode.valueOf(conf.ioMode());
this.socketChannelClass = NettyUtils.getClientChannelClass(ioMode);
Expand All @@ -97,23 +116,45 @@ public TransportClient createClient(String remoteHost, int remotePort) throws IO
// Get connection from the connection pool first.
// If it is not found or not active, create a new one.
final InetSocketAddress address = new InetSocketAddress(remoteHost, remotePort);
TransportClient cachedClient = connectionPool.get(address);

// Create the ClientPool if we don't have it yet.
ClientPool clientPool = connectionPool.get(address);
if (clientPool == null) {
clientPool = connectionPool.putIfAbsent(address, new ClientPool());
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

putIfAbsent returns the previous value, so this would be null

}

int clientIndex = rand.nextInt(numConnectionsPerPeer);
TransportClient cachedClient = clientPool.clients[clientIndex];
if (cachedClient != null) {
if (cachedClient.isActive()) {
logger.trace("Returning cached connection to {}: {}", address, cachedClient);
return cachedClient;
} else {
logger.info("Found inactive connection to {}, closing it.", address);
connectionPool.remove(address, cachedClient); // Remove inactive clients.
clientPool.clients[clientIndex] = null; // Remove inactive clients.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shouldn't this be behind a lock?

}
}

// If we reach here, we don't have an existing connection open. Let's create a new one.
// Multiple threads might race here to create new connections. Let's keep only one of them
// active at anytime.
synchronized (clientPool.locks[clientIndex]) {
if (clientPool.clients[clientIndex] == null || !clientPool.clients[clientIndex].isActive()) {
clientPool.clients[clientIndex] = createClient(address);
}
}

return clientPool.clients[clientIndex];
}

/** Create a completely new {@link TransportClient} to the remote address. */
private TransportClient createClient(InetSocketAddress address) throws IOException {
logger.debug("Creating new connection to " + address);

Bootstrap bootstrap = new Bootstrap();
bootstrap.group(workerGroup)
.channel(socketChannelClass)
// Disable Nagle's Algorithm since we don't want packets to wait
// Disable Nagle's Algorithm since we don't want packets to wait
.option(ChannelOption.TCP_NODELAY, true)
.option(ChannelOption.SO_KEEPALIVE, true)
.option(ChannelOption.CONNECT_TIMEOUT_MILLIS, conf.connectionTimeoutMs())
Expand All @@ -130,7 +171,7 @@ public void initChannel(SocketChannel ch) {
});

// Connect to the remote server
long preConnect = System.currentTimeMillis();
long preConnect = System.nanoTime();
ChannelFuture cf = bootstrap.connect(address);
if (!cf.awaitUninterruptibly(conf.connectionTimeoutMs())) {
throw new IOException(
Expand All @@ -143,43 +184,41 @@ public void initChannel(SocketChannel ch) {
assert client != null : "Channel future completed successfully with null client";

// Execute any client bootstraps synchronously before marking the Client as successful.
long preBootstrap = System.currentTimeMillis();
long preBootstrap = System.nanoTime();
logger.debug("Connection to {} successful, running bootstraps...", address);
try {
for (TransportClientBootstrap clientBootstrap : clientBootstraps) {
clientBootstrap.doBootstrap(client);
}
} catch (Exception e) { // catch non-RuntimeExceptions too as bootstrap may be written in Scala
long bootstrapTime = System.currentTimeMillis() - preBootstrap;
logger.error("Exception while bootstrapping client after " + bootstrapTime + " ms", e);
long bootstrapTimeMs = (System.nanoTime() - preBootstrap) / 1000000;
logger.error("Exception while bootstrapping client after " + bootstrapTimeMs + " ms", e);
client.close();
throw Throwables.propagate(e);
}
long postBootstrap = System.currentTimeMillis();

// Successful connection & bootstrap -- in the event that two threads raced to create a client,
// use the first one that was put into the connectionPool and close the one we made here.
TransportClient oldClient = connectionPool.putIfAbsent(address, client);
if (oldClient == null) {
logger.debug("Successfully created connection to {} after {} ms ({} ms spent in bootstraps)",
address, postBootstrap - preConnect, postBootstrap - preBootstrap);
return client;
} else {
logger.debug("Two clients were created concurrently after {} ms, second will be disposed.",
postBootstrap - preConnect);
client.close();
return oldClient;
}
long postBootstrap = System.nanoTime();

logger.debug("Successfully created connection to {} after {} ms ({} ms spent in bootstraps)",
address, (postBootstrap - preConnect) / 1000000, (postBootstrap - preBootstrap) / 1000000);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

2 spaces for indent


return client;
}

/** Close all connections in the connection pool, and shutdown the worker thread pool. */
@Override
public void close() {
for (TransportClient client : connectionPool.values()) {
try {
client.close();
} catch (RuntimeException e) {
logger.warn("Ignoring exception during close", e);
// Go through all clients and close them if they are active.
for (ClientPool clientPool : connectionPool.values()) {
for (int i = 0; i < clientPool.clients.length; i++) {
TransportClient client = clientPool.clients[i];
if (client != null) {
clientPool.clients[i] = null;
try {
client.close();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Q: Can we use JavaUtils.closeQuietly(client) here?

} catch (RuntimeException e) {
logger.warn("Ignoring exception during close", e);
}
}
}
}
connectionPool.clear();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,11 @@ public int connectionTimeoutMs() {
return conf.getInt("spark.shuffle.io.connectionTimeout", 120) * 1000;
}

/** Number of concurrent connections between two nodes for fetching data. **/
public int numConnectionsPerPeer() {
return conf.getInt("spark.shuffle.io.numConnectionsPerPeer", 2);
}

/** Requested maximum length of the queue of incoming connections. Default -1 for no backlog. */
public int backLog() { return conf.getInt("spark.shuffle.io.backLog", -1); }

Expand Down