1818
1919import java .net .InetSocketAddress ;
2020import java .nio .channels .ClosedChannelException ;
21- import java .time .Duration ;
2221import java .util .Arrays ;
2322import java .util .concurrent .Callable ;
2423
3837import org .junit .jupiter .api .Test ;
3938import org .mockito .InOrder ;
4039import reactor .core .publisher .Mono ;
40+ import reactor .netty .http .client .HttpClient ;
4141import reactor .netty .tcp .TcpClient ;
4242import reactor .test .StepVerifier ;
4343
4444import org .springframework .boot .rsocket .server .RSocketServer ;
45- import org .springframework .boot .rsocket .server .RSocketServerCustomizer ;
4645import org .springframework .boot .rsocket .server .RSocketServer .Transport ;
46+ import org .springframework .boot .rsocket .server .RSocketServerCustomizer ;
4747import org .springframework .boot .web .server .Ssl ;
4848import org .springframework .core .codec .CharSequenceEncoder ;
4949import org .springframework .core .codec .StringDecoder ;
5555
5656import static org .assertj .core .api .Assertions .assertThat ;
5757import static org .assertj .core .api .Assertions .assertThatExceptionOfType ;
58- import static org .assertj .core .api .Assertions .assertThatThrownBy ;
5958import static org .mockito .ArgumentMatchers .any ;
6059import static org .mockito .BDDMockito .will ;
6160import static org .mockito .Mockito .inOrder ;
@@ -74,10 +73,11 @@ class NettyRSocketServerFactoryTests {
7473
7574 private RSocketRequester requester ;
7675
77- private static final Duration TIMEOUT = Duration .ofSeconds (3 );
78-
7976 @ AfterEach
8077 void tearDown () {
78+ if (this .requester != null ) {
79+ this .requester .rsocketClient ().dispose ();
80+ }
8181 if (this .server != null ) {
8282 try {
8383 this .server .stop ();
@@ -86,9 +86,6 @@ void tearDown() {
8686 // Ignore
8787 }
8888 }
89- if (this .requester != null ) {
90- this .requester .rsocketClient ().dispose ();
91- }
9289 }
9390
9491 private NettyRSocketServerFactory getFactory () {
@@ -105,11 +102,9 @@ void specificPort() {
105102 this .server .start ();
106103 return port ;
107104 });
108- this .requester = createRSocketTcpClient (false );
109- String payload = "test payload" ;
110- String response = this .requester .route ("test" ).data (payload ).retrieveMono (String .class ).block (TIMEOUT );
105+ this .requester = createRSocketTcpClient ();
111106 assertThat (this .server .address ().getPort ()).isEqualTo (specificPort );
112- assertThat ( response ). isEqualTo ( payload );
107+ checkEchoRequest ( );
113108 }
114109
115110 @ Test
@@ -118,10 +113,8 @@ void websocketTransport() {
118113 factory .setTransport (RSocketServer .Transport .WEBSOCKET );
119114 this .server = factory .create (new EchoRequestResponseAcceptor ());
120115 this .server .start ();
121- this .requester = createRSocketWebSocketClient (false );
122- String payload = "test payload" ;
123- String response = this .requester .route ("test" ).data (payload ).retrieveMono (String .class ).block (TIMEOUT );
124- assertThat (response ).isEqualTo (payload );
116+ this .requester = createRSocketWebSocketClient ();
117+ checkEchoRequest ();
125118 }
126119
127120 @ Test
@@ -133,10 +126,8 @@ void websocketTransportWithReactorResource() {
133126 factory .setResourceFactory (resourceFactory );
134127 this .server = factory .create (new EchoRequestResponseAcceptor ());
135128 this .server .start ();
136- this .requester = createRSocketWebSocketClient (false );
137- String payload = "test payload" ;
138- String response = this .requester .route ("test" ).data (payload ).retrieveMono (String .class ).block (TIMEOUT );
139- assertThat (response ).isEqualTo (payload );
129+ this .requester = createRSocketWebSocketClient ();
130+ checkEchoRequest ();
140131 }
141132
142133 @ Test
@@ -176,6 +167,12 @@ void websocketTransportBasicSslFromFileSystem() {
176167 testBasicSslWithKeyStore ("src/test/resources/test.jks" , "password" , Transport .WEBSOCKET );
177168 }
178169
170+ private void checkEchoRequest () {
171+ String payload = "test payload" ;
172+ Mono <String > response = this .requester .route ("test" ).data (payload ).retrieveMono (String .class );
173+ StepVerifier .create (response ).expectNext (payload ).verifyComplete ();
174+ }
175+
179176 private void testBasicSslWithKeyStore (String keyStore , String keyPassword , Transport transport ) {
180177 NettyRSocketServerFactory factory = getFactory ();
181178 factory .setTransport (transport );
@@ -185,11 +182,9 @@ private void testBasicSslWithKeyStore(String keyStore, String keyPassword, Trans
185182 factory .setSsl (ssl );
186183 this .server = factory .create (new EchoRequestResponseAcceptor ());
187184 this .server .start ();
188- this .requester = (transport == Transport .TCP ) ? createRSocketTcpClient (true )
189- : createRSocketWebSocketClient (true );
190- String payload = "test payload" ;
191- Mono <String > responseMono = this .requester .route ("test" ).data (payload ).retrieveMono (String .class );
192- StepVerifier .create (responseMono ).expectNext (payload ).verifyComplete ();
185+ this .requester = (transport == Transport .TCP ) ? createSecureRSocketTcpClient ()
186+ : createSecureRSocketWebSocketClient ();
187+ checkEchoRequest ();
193188 }
194189
195190 @ Test
@@ -202,48 +197,54 @@ void tcpTransportSslRejectsInsecureClient() {
202197 factory .setSsl (ssl );
203198 this .server = factory .create (new EchoRequestResponseAcceptor ());
204199 this .server .start ();
205- this .requester = createRSocketTcpClient (false );
200+ this .requester = createRSocketTcpClient ();
206201 String payload = "test payload" ;
207202 Mono <String > responseMono = this .requester .route ("test" ).data (payload ).retrieveMono (String .class );
208203 StepVerifier .create (responseMono )
209204 .verifyErrorSatisfies ((ex ) -> assertThatExceptionOfType (ClosedChannelException .class ));
210205 }
211206
212- @ Test
213- void websocketTransportSslRejectsInsecureClient () {
214- NettyRSocketServerFactory factory = getFactory ();
215- factory .setTransport (Transport .WEBSOCKET );
216- Ssl ssl = new Ssl ();
217- ssl .setKeyStore ("classpath:test.jks" );
218- ssl .setKeyPassword ("password" );
219- factory .setSsl (ssl );
220- this .server = factory .create (new EchoRequestResponseAcceptor ());
221- this .server .start ();
222- // For WebSocket, the SSL failure results in a hang on the initial connect call
223- assertThatThrownBy (() -> createRSocketWebSocketClient (false )).isInstanceOf (IllegalStateException .class )
224- .hasStackTraceContaining ("Timeout on blocking read" );
207+ private RSocketRequester createRSocketTcpClient () {
208+ return createRSocketRequesterBuilder ().transport (TcpClientTransport .create (createTcpClient ()));
209+ }
210+
211+ private RSocketRequester createRSocketWebSocketClient () {
212+ return createRSocketRequesterBuilder ().transport (WebsocketClientTransport .create (createHttpClient (), "/" ));
225213 }
226214
227- private RSocketRequester createRSocketTcpClient (boolean ssl ) {
228- TcpClient tcpClient = createTcpClient (ssl );
229- return createRSocketRequesterBuilder ().connect (TcpClientTransport .create (tcpClient )).block (TIMEOUT );
215+ private RSocketRequester createSecureRSocketTcpClient () {
216+ return createRSocketRequesterBuilder ().transport (TcpClientTransport .create (createSecureTcpClient ()));
230217 }
231218
232- private RSocketRequester createRSocketWebSocketClient ( boolean ssl ) {
233- TcpClient tcpClient = createTcpClient ( ssl );
234- return createRSocketRequesterBuilder (). connect (WebsocketClientTransport .create (tcpClient )). block ( TIMEOUT );
219+ private RSocketRequester createSecureRSocketWebSocketClient ( ) {
220+ return createRSocketRequesterBuilder ()
221+ . transport (WebsocketClientTransport .create (createSecureHttpClient (), "/" ) );
235222 }
236223
237- private TcpClient createTcpClient (boolean ssl ) {
224+ private HttpClient createSecureHttpClient () {
225+ HttpClient httpClient = createHttpClient ();
226+ SslContextBuilder builder = SslContextBuilder .forClient ().sslProvider (SslProvider .JDK )
227+ .trustManager (InsecureTrustManagerFactory .INSTANCE );
228+ return httpClient .secure ((spec ) -> spec .sslContext (builder ));
229+ }
230+
231+ private HttpClient createHttpClient () {
238232 Assertions .assertThat (this .server ).isNotNull ();
239233 InetSocketAddress address = this .server .address ();
240- TcpClient tcpClient = TcpClient .create ().host (address .getHostName ()).port (address .getPort ());
241- if (ssl ) {
242- SslContextBuilder builder = SslContextBuilder .forClient ().sslProvider (SslProvider .JDK )
243- .trustManager (InsecureTrustManagerFactory .INSTANCE );
244- tcpClient = tcpClient .secure ((spec ) -> spec .sslContext (builder ));
245- }
246- return tcpClient ;
234+ return HttpClient .create ().host (address .getHostName ()).port (address .getPort ());
235+ }
236+
237+ private TcpClient createSecureTcpClient () {
238+ TcpClient tcpClient = createTcpClient ();
239+ SslContextBuilder builder = SslContextBuilder .forClient ().sslProvider (SslProvider .JDK )
240+ .trustManager (InsecureTrustManagerFactory .INSTANCE );
241+ return tcpClient .secure ((spec ) -> spec .sslContext (builder ));
242+ }
243+
244+ private TcpClient createTcpClient () {
245+ Assertions .assertThat (this .server ).isNotNull ();
246+ InetSocketAddress address = this .server .address ();
247+ return TcpClient .create ().host (address .getHostName ()).port (address .getPort ());
247248 }
248249
249250 private RSocketRequester .Builder createRSocketRequesterBuilder () {
0 commit comments