1818 */
1919package org .neo4j .driver .v1 .integration ;
2020
21+ import org .junit .After ;
2122import org .junit .Before ;
2223import org .junit .Test ;
2324
2425import java .io .IOException ;
26+ import java .net .ServerSocket ;
27+ import java .net .Socket ;
28+ import java .net .SocketException ;
2529import java .nio .ByteBuffer ;
2630import java .nio .channels .ByteChannel ;
27- import java .security .GeneralSecurityException ;
28- import java .security .KeyManagementException ;
2931import java .security .KeyStore ;
30- import java .security .KeyStoreException ;
31- import java .security .NoSuchAlgorithmException ;
32- import java .security .UnrecoverableKeyException ;
3332import java .security .cert .CertificateException ;
3433import java .security .cert .X509Certificate ;
34+ import java .util .concurrent .ExecutorService ;
35+ import java .util .concurrent .Future ;
3536import javax .net .ssl .KeyManagerFactory ;
3637import javax .net .ssl .SSLContext ;
38+ import javax .net .ssl .SSLServerSocketFactory ;
3739import javax .net .ssl .TrustManager ;
3840import javax .net .ssl .X509TrustManager ;
3941
42+ import static java .util .concurrent .Executors .newSingleThreadExecutor ;
43+ import static java .util .concurrent .TimeUnit .SECONDS ;
44+ import static org .junit .Assert .assertNull ;
45+ import static org .junit .Assert .assertTrue ;
46+ import static org .neo4j .driver .v1 .util .DaemonThreadFactory .daemon ;
47+
4048/**
4149 * This tests that the TLSSocketChannel handles every combination of network buffer sizes that we
4250 * can reasonably expect to see in the wild. It exhaustively tests power-of-two sizes up to 2^16
4351 * for the following variables:
44- *
52+ * <p>
4553 * - Network frame size
4654 * - Bolt message size
4755 * - Read buffer size
48- *
56+ * <p>
4957 * It tests every possible combination, and it does this currently only for the read path, expanding
5058 * to the write path as well would be useful. For each size, it sets up a TLS server and tests the
5159 * handshake, transferring the data, and verifying the data is correct after decryption.
5260 */
5361public abstract class TLSSocketChannelFragmentation
5462{
55- protected SSLContext sslCtx ;
63+ SSLContext sslCtx ;
64+ ServerSocket serverSocket ;
65+ volatile byte [] blobOfData ;
66+
67+ private ExecutorService serverExecutor ;
68+ private Future <?> serverTask ;
5669
5770 @ Before
58- public void setup () throws Throwable
71+ public void setUp () throws Throwable
72+ {
73+ sslCtx = createSSLContext ();
74+ serverSocket = createServerSocket ( sslCtx );
75+ serverExecutor = createServerExecutor ();
76+ serverTask = launchServer ( serverExecutor , createServerRunnable ( sslCtx ) );
77+ }
78+
79+ @ After
80+ public void tearDown () throws Exception
5981 {
60- createSSLContext ();
61- createServer ();
82+ serverSocket .close ();
83+ serverExecutor .shutdownNow ();
84+ assertTrue ( "Unable to terminate server socket" , serverExecutor .awaitTermination ( 30 , SECONDS ) );
85+
86+ assertNull ( serverTask .get ( 30 , SECONDS ) );
6287 }
6388
6489 @ Test
@@ -67,51 +92,104 @@ public void shouldHandleFuzziness() throws Throwable
6792 // Given
6893 int networkFrameSize , userBufferSize , blobOfDataSize ;
6994
70- for ( int dataBlobMagnitude = 1 ; dataBlobMagnitude < 16 ; dataBlobMagnitude += 2 )
95+ for ( int dataBlobMagnitude = 1 ; dataBlobMagnitude < 16 ; dataBlobMagnitude += 2 )
7196 {
7297 blobOfDataSize = (int ) Math .pow ( 2 , dataBlobMagnitude );
98+ blobOfData = blobOfData ( blobOfDataSize );
7399
74- for ( int frameSizeMagnitude = 1 ; frameSizeMagnitude < 16 ; frameSizeMagnitude += 2 )
100+ for ( int frameSizeMagnitude = 1 ; frameSizeMagnitude < 16 ; frameSizeMagnitude += 2 )
75101 {
76102 networkFrameSize = (int ) Math .pow ( 2 , frameSizeMagnitude );
77- for ( int userBufferMagnitude = 1 ; userBufferMagnitude < 16 ; userBufferMagnitude += 2 )
103+ for ( int userBufferMagnitude = 1 ; userBufferMagnitude < 16 ; userBufferMagnitude += 2 )
78104 {
79105 userBufferSize = (int ) Math .pow ( 2 , userBufferMagnitude );
80- testForBufferSizes ( blobOfDataSize , networkFrameSize , userBufferSize );
106+ testForBufferSizes ( blobOfData , networkFrameSize , userBufferSize );
81107 }
82108 }
83109 }
84110 }
85111
86- protected void createSSLContext ()
87- throws KeyStoreException , IOException , NoSuchAlgorithmException , CertificateException ,
88- UnrecoverableKeyException , KeyManagementException
112+ protected abstract void testForBufferSizes ( byte [] blobOfData , int networkFrameSize , int userBufferSize )
113+ throws Exception ;
114+
115+ protected abstract Runnable createServerRunnable ( SSLContext sslContext ) throws IOException ;
116+
117+ private static SSLContext createSSLContext () throws Exception
89118 {
90- KeyStore ks = KeyStore .getInstance ("JKS" );
119+ KeyStore ks = KeyStore .getInstance ( "JKS" );
91120 char [] password = "password" .toCharArray ();
92- ks .load ( getClass () .getResourceAsStream ( "/keystore.jks" ), password );
93- KeyManagerFactory kmf = KeyManagerFactory .getInstance ("SunX509" );
94- kmf .init (ks , password );
121+ ks .load ( TLSSocketChannelFragmentation . class .getResourceAsStream ( "/keystore.jks" ), password );
122+ KeyManagerFactory kmf = KeyManagerFactory .getInstance ( "SunX509" );
123+ kmf .init ( ks , password );
95124
96- sslCtx = SSLContext .getInstance ("TLS" );
97- sslCtx .init ( kmf .getKeyManagers (), new TrustManager []{new X509TrustManager () {
98- public void checkClientTrusted ( X509Certificate [] chain , String authType ) throws CertificateException
125+ SSLContext sslCtx = SSLContext .getInstance ( "TLS" );
126+ sslCtx .init ( kmf .getKeyManagers (), new TrustManager []{new X509TrustManager ()
127+ {
128+ @ Override
129+ public void checkClientTrusted ( X509Certificate [] chain , String authType ) throws CertificateException
99130 {
100131 }
101132
102- public void checkServerTrusted (X509Certificate [] chain , String authType ) throws CertificateException {
133+ @ Override
134+ public void checkServerTrusted ( X509Certificate [] chain , String authType ) throws CertificateException
135+ {
103136 }
104137
105- public X509Certificate [] getAcceptedIssuers () {
138+ @ Override
139+ public X509Certificate [] getAcceptedIssuers ()
140+ {
106141 return null ;
107142 }
108143 }}, null );
144+
145+ return sslCtx ;
109146 }
110147
111- protected abstract void testForBufferSizes ( int blobOfDataSize , int networkFrameSize , int userBufferSize ) throws IOException ,
112- GeneralSecurityException ;
148+ private static ServerSocket createServerSocket ( SSLContext sslContext ) throws IOException
149+ {
150+ SSLServerSocketFactory ssf = sslContext .getServerSocketFactory ();
151+ return ssf .createServerSocket ( 0 );
152+ }
153+
154+ private ExecutorService createServerExecutor ()
155+ {
156+ return newSingleThreadExecutor ( daemon ( getClass ().getSimpleName () + "-Server-" ) );
157+ }
113158
114- protected abstract void createServer () throws IOException ;
159+ private Future <?> launchServer ( ExecutorService executor , Runnable runnable )
160+ {
161+ return executor .submit ( runnable );
162+ }
163+
164+ static byte [] blobOfData ( int dataBlobSize )
165+ {
166+ byte [] blobOfData = new byte [dataBlobSize ];
167+ // If the blob is all zeros, we'd miss data corruption problems in assertions, so
168+ // fill the data blob with different values.
169+ for ( int i = 0 ; i < blobOfData .length ; i ++ )
170+ {
171+ blobOfData [i ] = (byte ) (i % 128 );
172+ }
173+
174+ return blobOfData ;
175+ }
176+
177+ static Socket accept ( ServerSocket serverSocket ) throws IOException
178+ {
179+ try
180+ {
181+ return serverSocket .accept ();
182+ }
183+ catch ( SocketException e )
184+ {
185+ String message = e .getMessage ();
186+ if ( "Socket closed" .equalsIgnoreCase ( message ) )
187+ {
188+ return null ;
189+ }
190+ throw e ;
191+ }
192+ }
115193
116194 /**
117195 * Delegates to underlying channel, but only reads up to the set amount at a time, used to emulate
@@ -122,7 +200,7 @@ protected static class LittleAtATimeChannel implements ByteChannel
122200 private final ByteChannel delegate ;
123201 private final int maxFrameSize ;
124202
125- public LittleAtATimeChannel ( ByteChannel delegate , int maxFrameSize )
203+ LittleAtATimeChannel ( ByteChannel delegate , int maxFrameSize )
126204 {
127205
128206 this .delegate = delegate ;
@@ -152,7 +230,7 @@ public int write( ByteBuffer src ) throws IOException
152230 }
153231 finally
154232 {
155- src .limit (originalLimit );
233+ src .limit ( originalLimit );
156234 }
157235 }
158236
@@ -167,7 +245,7 @@ public int read( ByteBuffer dst ) throws IOException
167245 }
168246 finally
169247 {
170- dst .limit (originalLimit );
248+ dst .limit ( originalLimit );
171249 }
172250 }
173251 }
0 commit comments