2222import java .net .InetSocketAddress ;
2323import java .net .SocketAddress ;
2424import java .util .List ;
25+ import java .util .Random ;
2526import java .util .concurrent .ConcurrentHashMap ;
2627import java .util .concurrent .atomic .AtomicReference ;
2728
4243import org .apache .spark .network .TransportContext ;
4344import org .apache .spark .network .server .TransportChannelHandler ;
4445import org .apache .spark .network .util .IOMode ;
46+ import org .apache .spark .network .util .JavaUtils ;
4547import org .apache .spark .network .util .NettyUtils ;
4648import org .apache .spark .network .util .TransportConf ;
4749
5658 * TransportClient, all given {@link TransportClientBootstrap}s will be run.
5759 */
5860public class TransportClientFactory implements Closeable {
61+
62+ /** A simple data structure to track the pool of clients between two peer nodes. */
63+ private static class ClientPool {
64+ TransportClient [] clients ;
65+ Object [] locks ;
66+
67+ public ClientPool (int size ) {
68+ clients = new TransportClient [size ];
69+ locks = new Object [size ];
70+ for (int i = 0 ; i < size ; i ++) {
71+ locks [i ] = new Object ();
72+ }
73+ }
74+ }
75+
5976 private final Logger logger = LoggerFactory .getLogger (TransportClientFactory .class );
6077
6178 private final TransportContext context ;
6279 private final TransportConf conf ;
6380 private final List <TransportClientBootstrap > clientBootstraps ;
64- private final ConcurrentHashMap <SocketAddress , TransportClient > connectionPool ;
81+ private final ConcurrentHashMap <SocketAddress , ClientPool > connectionPool ;
82+
83+ /** Random number generator for picking connections between peers. */
84+ private final Random rand ;
85+ private final int numConnectionsPerPeer ;
6586
6687 private final Class <? extends Channel > socketChannelClass ;
6788 private EventLoopGroup workerGroup ;
@@ -73,7 +94,9 @@ public TransportClientFactory(
7394 this .context = Preconditions .checkNotNull (context );
7495 this .conf = context .getConf ();
7596 this .clientBootstraps = Lists .newArrayList (Preconditions .checkNotNull (clientBootstraps ));
76- this .connectionPool = new ConcurrentHashMap <SocketAddress , TransportClient >();
97+ this .connectionPool = new ConcurrentHashMap <SocketAddress , ClientPool >();
98+ this .numConnectionsPerPeer = conf .numConnectionsPerPeer ();
99+ this .rand = new Random ();
77100
78101 IOMode ioMode = IOMode .valueOf (conf .ioMode ());
79102 this .socketChannelClass = NettyUtils .getClientChannelClass (ioMode );
@@ -84,10 +107,14 @@ public TransportClientFactory(
84107 }
85108
86109 /**
87- * Create a new {@link TransportClient} connecting to the given remote host / port. This will
88- * reuse TransportClients if they are still active and are for the same remote address. Prior
89- * to the creation of a new TransportClient, we will execute all {@link TransportClientBootstrap}s
90- * that are registered with this factory.
110+ * Create a {@link TransportClient} connecting to the given remote host / port.
111+ *
112+ * We maintain an array of clients (size determined by spark.shuffle.io.numConnectionsPerPeer)
113+ * and randomly pick one to use. If no client was previously created in the randomly selected
114+ * spot, this function creates a new client and places it there.
115+ *
116+ * Prior to the creation of a new TransportClient, we will execute all
117+ * {@link TransportClientBootstrap}s that are registered with this factory.
91118 *
92119 * This blocks until a connection is successfully established and fully bootstrapped.
93120 *
@@ -97,23 +124,48 @@ public TransportClient createClient(String remoteHost, int remotePort) throws IO
97124 // Get connection from the connection pool first.
98125 // If it is not found or not active, create a new one.
99126 final InetSocketAddress address = new InetSocketAddress (remoteHost , remotePort );
100- TransportClient cachedClient = connectionPool .get (address );
101- if (cachedClient != null ) {
102- if (cachedClient .isActive ()) {
103- logger .trace ("Returning cached connection to {}: {}" , address , cachedClient );
104- return cachedClient ;
105- } else {
106- logger .info ("Found inactive connection to {}, closing it." , address );
107- connectionPool .remove (address , cachedClient ); // Remove inactive clients.
127+
128+ // Create the ClientPool if we don't have it yet.
129+ ClientPool clientPool = connectionPool .get (address );
130+ if (clientPool == null ) {
131+ connectionPool .putIfAbsent (address , new ClientPool (numConnectionsPerPeer ));
132+ clientPool = connectionPool .get (address );
133+ }
134+
135+ int clientIndex = rand .nextInt (numConnectionsPerPeer );
136+ TransportClient cachedClient = clientPool .clients [clientIndex ];
137+
138+ if (cachedClient != null && cachedClient .isActive ()) {
139+ logger .trace ("Returning cached connection to {}: {}" , address , cachedClient );
140+ return cachedClient ;
141+ }
142+
143+ // If we reach here, we don't have an existing connection open. Let's create a new one.
144+ // Multiple threads might race here to create new connections. Keep only one of them active.
145+ synchronized (clientPool .locks [clientIndex ]) {
146+ cachedClient = clientPool .clients [clientIndex ];
147+
148+ if (cachedClient != null ) {
149+ if (cachedClient .isActive ()) {
150+ logger .trace ("Returning cached connection to {}: {}" , address , cachedClient );
151+ return cachedClient ;
152+ } else {
153+ logger .info ("Found inactive connection to {}, creating a new one." , address );
154+ }
108155 }
156+ clientPool .clients [clientIndex ] = createClient (address );
157+ return clientPool .clients [clientIndex ];
109158 }
159+ }
110160
161+ /** Create a completely new {@link TransportClient} to the remote address. */
162+ private TransportClient createClient (InetSocketAddress address ) throws IOException {
111163 logger .debug ("Creating new connection to " + address );
112164
113165 Bootstrap bootstrap = new Bootstrap ();
114166 bootstrap .group (workerGroup )
115167 .channel (socketChannelClass )
116- // Disable Nagle's Algorithm since we don't want packets to wait
168+ // Disable Nagle's Algorithm since we don't want packets to wait
117169 .option (ChannelOption .TCP_NODELAY , true )
118170 .option (ChannelOption .SO_KEEPALIVE , true )
119171 .option (ChannelOption .CONNECT_TIMEOUT_MILLIS , conf .connectionTimeoutMs ())
@@ -130,7 +182,7 @@ public void initChannel(SocketChannel ch) {
130182 });
131183
132184 // Connect to the remote server
133- long preConnect = System .currentTimeMillis ();
185+ long preConnect = System .nanoTime ();
134186 ChannelFuture cf = bootstrap .connect (address );
135187 if (!cf .awaitUninterruptibly (conf .connectionTimeoutMs ())) {
136188 throw new IOException (
@@ -143,43 +195,37 @@ public void initChannel(SocketChannel ch) {
143195 assert client != null : "Channel future completed successfully with null client" ;
144196
145197 // Execute any client bootstraps synchronously before marking the Client as successful.
146- long preBootstrap = System .currentTimeMillis ();
198+ long preBootstrap = System .nanoTime ();
147199 logger .debug ("Connection to {} successful, running bootstraps..." , address );
148200 try {
149201 for (TransportClientBootstrap clientBootstrap : clientBootstraps ) {
150202 clientBootstrap .doBootstrap (client );
151203 }
152204 } catch (Exception e ) { // catch non-RuntimeExceptions too as bootstrap may be written in Scala
153- long bootstrapTime = System .currentTimeMillis () - preBootstrap ;
154- logger .error ("Exception while bootstrapping client after " + bootstrapTime + " ms" , e );
205+ long bootstrapTimeMs = ( System .nanoTime () - preBootstrap ) / 1000000 ;
206+ logger .error ("Exception while bootstrapping client after " + bootstrapTimeMs + " ms" , e );
155207 client .close ();
156208 throw Throwables .propagate (e );
157209 }
158- long postBootstrap = System .currentTimeMillis ();
159-
160- // Successful connection & bootstrap -- in the event that two threads raced to create a client,
161- // use the first one that was put into the connectionPool and close the one we made here.
162- TransportClient oldClient = connectionPool .putIfAbsent (address , client );
163- if (oldClient == null ) {
164- logger .debug ("Successfully created connection to {} after {} ms ({} ms spent in bootstraps)" ,
165- address , postBootstrap - preConnect , postBootstrap - preBootstrap );
166- return client ;
167- } else {
168- logger .debug ("Two clients were created concurrently after {} ms, second will be disposed." ,
169- postBootstrap - preConnect );
170- client .close ();
171- return oldClient ;
172- }
210+ long postBootstrap = System .nanoTime ();
211+
212+ logger .debug ("Successfully created connection to {} after {} ms ({} ms spent in bootstraps)" ,
213+ address , (postBootstrap - preConnect ) / 1000000 , (postBootstrap - preBootstrap ) / 1000000 );
214+
215+ return client ;
173216 }
174217
175218 /** Close all connections in the connection pool, and shutdown the worker thread pool. */
176219 @ Override
177220 public void close () {
178- for (TransportClient client : connectionPool .values ()) {
179- try {
180- client .close ();
181- } catch (RuntimeException e ) {
182- logger .warn ("Ignoring exception during close" , e );
221+ // Go through all client slots and close every client that has been created.
222+ for (ClientPool clientPool : connectionPool .values ()) {
223+ for (int i = 0 ; i < clientPool .clients .length ; i ++) {
224+ TransportClient client = clientPool .clients [i ];
225+ if (client != null ) {
226+ clientPool .clients [i ] = null ;
227+ JavaUtils .closeQuietly (client );
228+ }
183229 }
184230 }
185231 connectionPool .clear ();
0 commit comments