2222import java .net .InetSocketAddress ;
2323import java .net .SocketAddress ;
2424import java .util .List ;
25+ import java .util .Random ;
2526import java .util .concurrent .ConcurrentHashMap ;
2627import java .util .concurrent .atomic .AtomicReference ;
2728
5657 * TransportClient, all given {@link TransportClientBootstrap}s will be run.
5758 */
5859public class TransportClientFactory implements Closeable {
60+
61+ private class ClientPool {
62+ TransportClient [] clients ;
63+ Object [] locks ;
64+
65+ public ClientPool () {
66+ clients = new TransportClient [numConnectionsPerPeer ];
67+ locks = new Object [numConnectionsPerPeer ];
68+ }
69+ }
70+
5971 private final Logger logger = LoggerFactory .getLogger (TransportClientFactory .class );
6072
6173 private final TransportContext context ;
6274 private final TransportConf conf ;
6375 private final List <TransportClientBootstrap > clientBootstraps ;
64- private final ConcurrentHashMap <SocketAddress , TransportClient > connectionPool ;
76+ private final ConcurrentHashMap <SocketAddress , ClientPool > connectionPool ;
77+
78+ /** Random number generator for picking connections between peers. */
79+ private final Random rand ;
80+ private final int numConnectionsPerPeer ;
6581
6682 private final Class <? extends Channel > socketChannelClass ;
6783 private EventLoopGroup workerGroup ;
@@ -73,7 +89,9 @@ public TransportClientFactory(
7389 this .context = Preconditions .checkNotNull (context );
7490 this .conf = context .getConf ();
7591 this .clientBootstraps = Lists .newArrayList (Preconditions .checkNotNull (clientBootstraps ));
76- this .connectionPool = new ConcurrentHashMap <SocketAddress , TransportClient >();
92+ this .connectionPool = new ConcurrentHashMap <SocketAddress , ClientPool >();
93+ this .numConnectionsPerPeer = conf .numConnectionsPerPeer ();
94+ this .rand = new Random ();
7795
7896 IOMode ioMode = IOMode .valueOf (conf .ioMode ());
7997 this .socketChannelClass = NettyUtils .getClientChannelClass (ioMode );
@@ -97,27 +115,49 @@ public TransportClient createClient(String remoteHost, int remotePort) throws IO
97115 // Get connection from the connection pool first.
98116 // If it is not found or not active, create a new one.
99117 final InetSocketAddress address = new InetSocketAddress (remoteHost , remotePort );
100- TransportClient cachedClient = connectionPool .get (address );
118+
119+ // Create the ClientPool if we don't have it yet.
120+ ClientPool clientPool = connectionPool .get (address );
121+ if (clientPool == null ) {
122+ clientPool = connectionPool .putIfAbsent (address , new ClientPool ());
123+ }
124+
125+ int clientIndex = rand .nextInt (numConnectionsPerPeer );
126+ TransportClient cachedClient = clientPool .clients [clientIndex ];
101127 if (cachedClient != null ) {
102128 if (cachedClient .isActive ()) {
103129 logger .trace ("Returning cached connection to {}: {}" , address , cachedClient );
104130 return cachedClient ;
105131 } else {
106132 logger .info ("Found inactive connection to {}, closing it." , address );
107- connectionPool . remove ( address , cachedClient ); // Remove inactive clients.
133+ clientPool . clients [ clientIndex ] = null ; // Remove inactive clients.
108134 }
109135 }
110136
137+ // If we reach here, we don't have an existing connection open. Let's create a new one.
138+ // Multiple threads might race here to create new connections. Let's keep only one of them
139+ // active at anytime.
140+ synchronized (clientPool .locks [clientIndex ]) {
141+ if (clientPool .clients [clientIndex ] == null || !clientPool .clients [clientIndex ].isActive ()) {
142+ clientPool .clients [clientIndex ] = createClient (address );
143+ }
144+ }
145+
146+ return clientPool .clients [clientIndex ];
147+ }
148+
149+ /** Create a completely new {@link TransportClient} to the remote address. */
150+ private TransportClient createClient (InetSocketAddress address ) throws IOException {
111151 logger .debug ("Creating new connection to " + address );
112152
113153 Bootstrap bootstrap = new Bootstrap ();
114154 bootstrap .group (workerGroup )
115- .channel (socketChannelClass )
116- // Disable Nagle's Algorithm since we don't want packets to wait
117- .option (ChannelOption .TCP_NODELAY , true )
118- .option (ChannelOption .SO_KEEPALIVE , true )
119- .option (ChannelOption .CONNECT_TIMEOUT_MILLIS , conf .connectionTimeoutMs ())
120- .option (ChannelOption .ALLOCATOR , pooledAllocator );
155+ .channel (socketChannelClass )
156+ // Disable Nagle's Algorithm since we don't want packets to wait
157+ .option (ChannelOption .TCP_NODELAY , true )
158+ .option (ChannelOption .SO_KEEPALIVE , true )
159+ .option (ChannelOption .CONNECT_TIMEOUT_MILLIS , conf .connectionTimeoutMs ())
160+ .option (ChannelOption .ALLOCATOR , pooledAllocator );
121161
122162 final AtomicReference <TransportClient > clientRef = new AtomicReference <TransportClient >();
123163
@@ -130,11 +170,11 @@ public void initChannel(SocketChannel ch) {
130170 });
131171
132172 // Connect to the remote server
133- long preConnect = System .currentTimeMillis ();
173+ long preConnect = System .nanoTime ();
134174 ChannelFuture cf = bootstrap .connect (address );
135175 if (!cf .awaitUninterruptibly (conf .connectionTimeoutMs ())) {
136176 throw new IOException (
137- String .format ("Connecting to %s timed out (%s ms)" , address , conf .connectionTimeoutMs ()));
177+ String .format ("Connecting to %s timed out (%s ms)" , address , conf .connectionTimeoutMs ()));
138178 } else if (cf .cause () != null ) {
139179 throw new IOException (String .format ("Failed to connect to %s" , address ), cf .cause ());
140180 }
@@ -143,43 +183,41 @@ public void initChannel(SocketChannel ch) {
143183 assert client != null : "Channel future completed successfully with null client" ;
144184
145185 // Execute any client bootstraps synchronously before marking the Client as successful.
146- long preBootstrap = System .currentTimeMillis ();
186+ long preBootstrap = System .nanoTime ();
147187 logger .debug ("Connection to {} successful, running bootstraps..." , address );
148188 try {
149189 for (TransportClientBootstrap clientBootstrap : clientBootstraps ) {
150190 clientBootstrap .doBootstrap (client );
151191 }
152192 } catch (Exception e ) { // catch non-RuntimeExceptions too as bootstrap may be written in Scala
153- long bootstrapTime = System .currentTimeMillis () - preBootstrap ;
154- logger .error ("Exception while bootstrapping client after " + bootstrapTime + " ms" , e );
193+ long bootstrapTimeMs = ( System .nanoTime () - preBootstrap ) / 1000000 ;
194+ logger .error ("Exception while bootstrapping client after " + bootstrapTimeMs + " ms" , e );
155195 client .close ();
156196 throw Throwables .propagate (e );
157197 }
158- long postBootstrap = System .currentTimeMillis ();
159-
160- // Successful connection & bootstrap -- in the event that two threads raced to create a client,
161- // use the first one that was put into the connectionPool and close the one we made here.
162- TransportClient oldClient = connectionPool .putIfAbsent (address , client );
163- if (oldClient == null ) {
164- logger .debug ("Successfully created connection to {} after {} ms ({} ms spent in bootstraps)" ,
165- address , postBootstrap - preConnect , postBootstrap - preBootstrap );
166- return client ;
167- } else {
168- logger .debug ("Two clients were created concurrently after {} ms, second will be disposed." ,
169- postBootstrap - preConnect );
170- client .close ();
171- return oldClient ;
172- }
198+ long postBootstrap = System .nanoTime ();
199+
200+ logger .debug ("Successfully created connection to {} after {} ms ({} ms spent in bootstraps)" ,
201+ address , (postBootstrap - preConnect ) / 1000000 , (postBootstrap - preBootstrap ) / 1000000 );
202+
203+ return client ;
173204 }
174205
175206 /** Close all connections in the connection pool, and shutdown the worker thread pool. */
176207 @ Override
177208 public void close () {
178- for (TransportClient client : connectionPool .values ()) {
179- try {
180- client .close ();
181- } catch (RuntimeException e ) {
182- logger .warn ("Ignoring exception during close" , e );
209+ // Go through all clients and close them if they are active.
210+ for (ClientPool clientPool : connectionPool .values ()) {
211+ for (int i = 0 ; i < clientPool .clients .length ; i ++) {
212+ TransportClient client = clientPool .clients [i ];
213+ if (client != null ) {
214+ clientPool .clients [i ] = null ;
215+ try {
216+ client .close ();
217+ } catch (RuntimeException e ) {
218+ logger .warn ("Ignoring exception during close" , e );
219+ }
220+ }
183221 }
184222 }
185223 connectionPool .clear ();
0 commit comments