TransportClientFactory.java
@@ -22,6 +22,7 @@
import java.net.InetSocketAddress;
import java.net.SocketAddress;
import java.util.List;
import java.util.Random;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicReference;

@@ -42,6 +43,7 @@
import org.apache.spark.network.TransportContext;
import org.apache.spark.network.server.TransportChannelHandler;
import org.apache.spark.network.util.IOMode;
import org.apache.spark.network.util.JavaUtils;
import org.apache.spark.network.util.NettyUtils;
import org.apache.spark.network.util.TransportConf;

@@ -56,12 +58,31 @@
* TransportClient, all given {@link TransportClientBootstrap}s will be run.
*/
public class TransportClientFactory implements Closeable {

/** A simple data structure to track the pool of clients between two peer nodes. */
private static class ClientPool {
TransportClient[] clients;
Object[] locks;

public ClientPool(int size) {
clients = new TransportClient[size];
locks = new Object[size];
for (int i = 0; i < size; i++) {
locks[i] = new Object();
}
}
}
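The locks array above parallels clients one-to-one so that each slot has its own monitor: threads connecting to the same peer serialize only when they draw the same slot index, rather than contending on a single pool-wide lock. A minimal sketch of the granularity this buys (hypothetical standalone code, not part of this change):

```java
// Sketch: per-slot monitors let two threads create connections to the
// same peer concurrently, as long as they target different slots.
ClientPool pool = new ClientPool(2);
// Thread A, slot 0:
synchronized (pool.locks[0]) {
  if (pool.clients[0] == null) { /* create and store clients[0] */ }
}
// Thread B, slot 1 (does not block on thread A):
synchronized (pool.locks[1]) {
  if (pool.clients[1] == null) { /* create and store clients[1] */ }
}
```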

private final Logger logger = LoggerFactory.getLogger(TransportClientFactory.class);

private final TransportContext context;
private final TransportConf conf;
private final List<TransportClientBootstrap> clientBootstraps;
- private final ConcurrentHashMap<SocketAddress, TransportClient> connectionPool;
private final ConcurrentHashMap<SocketAddress, ClientPool> connectionPool;

/** Random number generator for picking connections between peers. */
private final Random rand;
private final int numConnectionsPerPeer;

private final Class<? extends Channel> socketChannelClass;
private EventLoopGroup workerGroup;
@@ -73,7 +94,9 @@ public TransportClientFactory(
this.context = Preconditions.checkNotNull(context);
this.conf = context.getConf();
this.clientBootstraps = Lists.newArrayList(Preconditions.checkNotNull(clientBootstraps));
- this.connectionPool = new ConcurrentHashMap<SocketAddress, TransportClient>();
this.connectionPool = new ConcurrentHashMap<SocketAddress, ClientPool>();
this.numConnectionsPerPeer = conf.numConnectionsPerPeer();
this.rand = new Random();

IOMode ioMode = IOMode.valueOf(conf.ioMode());
this.socketChannelClass = NettyUtils.getClientChannelClass(ioMode);
@@ -84,10 +107,14 @@
}

/**
- * Create a new {@link TransportClient} connecting to the given remote host / port. This will
- * reuse TransportClients if they are still active and are for the same remote address. Prior
- * to the creation of a new TransportClient, we will execute all {@link TransportClientBootstrap}s
- * that are registered with this factory.
* Create a {@link TransportClient} connecting to the given remote host / port.
*
* We maintain an array of clients (size determined by spark.shuffle.io.numConnectionsPerPeer)
* and randomly pick one to use. If no client was previously created in the randomly selected
* spot, this function creates a new client and places it there.
*
* Prior to the creation of a new TransportClient, we will execute all
* {@link TransportClientBootstrap}s that are registered with this factory.
*
* This blocks until a connection is successfully established and fully bootstrapped.
*
@@ -97,23 +124,48 @@ public TransportClient createClient(String remoteHost, int remotePort) throws IOException {
// Get connection from the connection pool first.
// If it is not found or not active, create a new one.
final InetSocketAddress address = new InetSocketAddress(remoteHost, remotePort);
- TransportClient cachedClient = connectionPool.get(address);
- if (cachedClient != null) {
- if (cachedClient.isActive()) {
- logger.trace("Returning cached connection to {}: {}", address, cachedClient);
- return cachedClient;
- } else {
- logger.info("Found inactive connection to {}, closing it.", address);
- connectionPool.remove(address, cachedClient); // Remove inactive clients.

// Create the ClientPool if we don't have it yet.
ClientPool clientPool = connectionPool.get(address);
if (clientPool == null) {
connectionPool.putIfAbsent(address, new ClientPool(numConnectionsPerPeer));
clientPool = connectionPool.get(address);
}

int clientIndex = rand.nextInt(numConnectionsPerPeer);
TransportClient cachedClient = clientPool.clients[clientIndex];

if (cachedClient != null && cachedClient.isActive()) {
logger.trace("Returning cached connection to {}: {}", address, cachedClient);
return cachedClient;
}

// If we reach here, we don't have an existing connection open. Let's create a new one.
// Multiple threads might race here to create new connections. Keep only one of them active.
synchronized (clientPool.locks[clientIndex]) {
cachedClient = clientPool.clients[clientIndex];

if (cachedClient != null) {
if (cachedClient.isActive()) {
logger.trace("Returning cached connection to {}: {}", address, cachedClient);
return cachedClient;
} else {
logger.info("Found inactive connection to {}, creating a new one.", address);
}
}
clientPool.clients[clientIndex] = createClient(address);
return clientPool.clients[clientIndex];
}
}
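A side note on the create-if-absent sequence at the top of this method: calling putIfAbsent and then get is the Java 7 idiom for atomically installing a value in a ConcurrentHashMap. On a Java 8+ target (not something this codebase assumes here), the same step could be written more compactly; a sketch:

```java
// Hypothetical Java 8 rewrite of the putIfAbsent/get pair:
// computeIfAbsent installs a new ClientPool only when no mapping exists
// and returns whichever pool won the race, in a single atomic step.
ClientPool clientPool = connectionPool.computeIfAbsent(
    address, key -> new ClientPool(numConnectionsPerPeer));
```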

/** Create a completely new {@link TransportClient} to the remote address. */
private TransportClient createClient(InetSocketAddress address) throws IOException {
logger.debug("Creating new connection to " + address);

Bootstrap bootstrap = new Bootstrap();
bootstrap.group(workerGroup)
.channel(socketChannelClass)
// Disable Nagle's Algorithm since we don't want packets to wait
.option(ChannelOption.TCP_NODELAY, true)
.option(ChannelOption.SO_KEEPALIVE, true)
.option(ChannelOption.CONNECT_TIMEOUT_MILLIS, conf.connectionTimeoutMs())
@@ -130,7 +182,7 @@ public void initChannel(SocketChannel ch) {
});

// Connect to the remote server
- long preConnect = System.currentTimeMillis();
long preConnect = System.nanoTime();
ChannelFuture cf = bootstrap.connect(address);
if (!cf.awaitUninterruptibly(conf.connectionTimeoutMs())) {
throw new IOException(
@@ -143,43 +195,37 @@ public void initChannel(SocketChannel ch) {
assert client != null : "Channel future completed successfully with null client";

// Execute any client bootstraps synchronously before marking the Client as successful.
- long preBootstrap = System.currentTimeMillis();
long preBootstrap = System.nanoTime();
logger.debug("Connection to {} successful, running bootstraps...", address);
try {
for (TransportClientBootstrap clientBootstrap : clientBootstraps) {
clientBootstrap.doBootstrap(client);
}
} catch (Exception e) { // catch non-RuntimeExceptions too as bootstrap may be written in Scala
- long bootstrapTime = System.currentTimeMillis() - preBootstrap;
- logger.error("Exception while bootstrapping client after " + bootstrapTime + " ms", e);
long bootstrapTimeMs = (System.nanoTime() - preBootstrap) / 1000000;
logger.error("Exception while bootstrapping client after " + bootstrapTimeMs + " ms", e);
client.close();
throw Throwables.propagate(e);
}
- long postBootstrap = System.currentTimeMillis();

- // Successful connection & bootstrap -- in the event that two threads raced to create a client,
- // use the first one that was put into the connectionPool and close the one we made here.
- TransportClient oldClient = connectionPool.putIfAbsent(address, client);
- if (oldClient == null) {
- logger.debug("Successfully created connection to {} after {} ms ({} ms spent in bootstraps)",
- address, postBootstrap - preConnect, postBootstrap - preBootstrap);
- return client;
- } else {
- logger.debug("Two clients were created concurrently after {} ms, second will be disposed.",
- postBootstrap - preConnect);
- client.close();
- return oldClient;
- }
long postBootstrap = System.nanoTime();

logger.debug("Successfully created connection to {} after {} ms ({} ms spent in bootstraps)",
address, (postBootstrap - preConnect) / 1000000, (postBootstrap - preBootstrap) / 1000000);

return client;
}
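The move from currentTimeMillis to nanoTime in this method is the safer choice for measuring elapsed time, since nanoTime is monotonic and unaffected by wall-clock adjustments. The literal / 1000000 divisors could equivalently go through TimeUnit; a small self-contained sketch (illustrative, not part of this change):

```java
import java.util.concurrent.TimeUnit;

public class ElapsedMsExample {
  public static void main(String[] args) throws InterruptedException {
    long start = System.nanoTime();    // monotonic start point
    Thread.sleep(25);                  // stand-in for connect + bootstrap
    long elapsedMs = TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - start);
    System.out.println("elapsed: " + elapsedMs + " ms");
  }
}
```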

/** Close all connections in the connection pool, and shutdown the worker thread pool. */
@Override
public void close() {
- for (TransportClient client : connectionPool.values()) {
- try {
- client.close();
- } catch (RuntimeException e) {
- logger.warn("Ignoring exception during close", e);
// Go through all clients and close them if they are active.
for (ClientPool clientPool : connectionPool.values()) {
for (int i = 0; i < clientPool.clients.length; i++) {
TransportClient client = clientPool.clients[i];
if (client != null) {
clientPool.clients[i] = null;
JavaUtils.closeQuietly(client);
}
}
}
connectionPool.clear();
TransportConf.java
@@ -40,6 +40,11 @@ public int connectionTimeoutMs() {
return conf.getInt("spark.shuffle.io.connectionTimeout", 120) * 1000;
}

/** Number of concurrent connections between two nodes for fetching data. */
public int numConnectionsPerPeer() {
return conf.getInt("spark.shuffle.io.numConnectionsPerPeer", 2);
}
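For reference, the new setting is read through whatever ConfigProvider backs the TransportConf, and the default of 2 applies when the key is absent. A minimal sketch using the SystemPropertyConfigProvider that the test suite below also relies on (illustrative only):

```java
// Sketch: override the default of 2 via a system property.
System.setProperty("spark.shuffle.io.numConnectionsPerPeer", "4");
TransportConf conf = new TransportConf(new SystemPropertyConfigProvider());
assert conf.numConnectionsPerPeer() == 4;
```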

/** Requested maximum length of the queue of incoming connections. Default -1 for no backlog. */
public int backLog() { return conf.getInt("spark.shuffle.io.backLog", -1); }

TransportClientFactorySuite.java
@@ -18,7 +18,11 @@
package org.apache.spark.network;

import java.io.IOException;
- import java.util.concurrent.TimeoutException;
import java.util.Collections;
import java.util.HashSet;
import java.util.NoSuchElementException;
import java.util.Set;
import java.util.concurrent.atomic.AtomicInteger;

import org.junit.After;
import org.junit.Before;
@@ -32,6 +36,7 @@
import org.apache.spark.network.server.NoOpRpcHandler;
import org.apache.spark.network.server.RpcHandler;
import org.apache.spark.network.server.TransportServer;
import org.apache.spark.network.util.ConfigProvider;
import org.apache.spark.network.util.JavaUtils;
import org.apache.spark.network.util.SystemPropertyConfigProvider;
import org.apache.spark.network.util.TransportConf;
@@ -57,16 +62,94 @@ public void tearDown() {
JavaUtils.closeQuietly(server2);
}

/**
* Request a bunch of clients to a single server to test that we create up to
* maxConnections distinct clients.
*
* If concurrent is true, create multiple threads to create clients in parallel.
*/
private void testClientReuse(final int maxConnections, boolean concurrent)
throws IOException, InterruptedException {
TransportConf conf = new TransportConf(new ConfigProvider() {
@Override
public String get(String name) {
if (name.equals("spark.shuffle.io.numConnectionsPerPeer")) {
return Integer.toString(maxConnections);
} else {
throw new NoSuchElementException();
}
}
});

RpcHandler rpcHandler = new NoOpRpcHandler();
TransportContext context = new TransportContext(conf, rpcHandler);
final TransportClientFactory factory = context.createClientFactory();
final Set<TransportClient> clients = Collections.synchronizedSet(
new HashSet<TransportClient>());

final AtomicInteger failed = new AtomicInteger();
Thread[] attempts = new Thread[maxConnections * 10];

// Launch a bunch of threads to create new clients.
for (int i = 0; i < attempts.length; i++) {
attempts[i] = new Thread() {
@Override
public void run() {
try {
TransportClient client =
factory.createClient(TestUtils.getLocalHost(), server1.getPort());
assert (client.isActive());
clients.add(client);
} catch (IOException e) {
failed.incrementAndGet();
}
}
};

if (concurrent) {
attempts[i].start();
} else {
attempts[i].run();
}
}

// Wait until all the threads complete.
for (int i = 0; i < attempts.length; i++) {
attempts[i].join();
}

assert(failed.get() == 0);
assert(clients.size() == maxConnections);

for (TransportClient client : clients) {
client.close();
}
}
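One caveat on the checks in this helper as written: bare Java assert statements are no-ops unless the JVM runs with -ea, so a regression could pass silently. The JUnit assertions used elsewhere in this suite do not have that failure mode; assuming the usual static import of org.junit.Assert, the equivalent checks would be:

```java
// JUnit equivalents of the bare asserts above; these fail regardless of -ea.
assertEquals(0, failed.get());
assertEquals(maxConnections, clients.size());
```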

@Test
public void reuseClientsUpToConfigVariable() throws Exception {
testClientReuse(1, false);
testClientReuse(2, false);
testClientReuse(3, false);
testClientReuse(4, false);
}

@Test
- public void createAndReuseBlockClients() throws IOException {
public void reuseClientsUpToConfigVariableConcurrent() throws Exception {
testClientReuse(1, true);
testClientReuse(2, true);
testClientReuse(3, true);
testClientReuse(4, true);
}

@Test
public void returnDifferentClientsForDifferentServers() throws IOException {
TransportClientFactory factory = context.createClientFactory();
TransportClient c1 = factory.createClient(TestUtils.getLocalHost(), server1.getPort());
- TransportClient c2 = factory.createClient(TestUtils.getLocalHost(), server1.getPort());
- TransportClient c3 = factory.createClient(TestUtils.getLocalHost(), server2.getPort());
TransportClient c2 = factory.createClient(TestUtils.getLocalHost(), server2.getPort());
assertTrue(c1.isActive());
- assertTrue(c3.isActive());
- assertTrue(c1 == c2);
- assertTrue(c1 != c3);
assertTrue(c2.isActive());
assertTrue(c1 != c2);
factory.close();
}
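The dropped assertTrue(c1 == c2) reflects the semantics of the new pool: with numConnectionsPerPeer slots chosen at random, two createClient calls to the same server may or may not return the same instance, so only distinctness across servers remains a safe assertion. A hypothetical illustration (host and port stand in for a live server):

```java
// With numConnectionsPerPeer > 1, identity between same-peer clients is
// not guaranteed in either direction:
TransportClient a = factory.createClient(host, port);
TransportClient b = factory.createClient(host, port);
// Neither (a == b) nor (a != b) can be asserted; both clients are valid.
```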
