1717package org .springframework .boot .rsocket .netty ;
1818
1919import java .net .InetSocketAddress ;
20+ import java .nio .channels .ClosedChannelException ;
2021import java .time .Duration ;
2122import java .util .Arrays ;
2223import java .util .concurrent .Callable ;
2324
2425import io .netty .buffer .PooledByteBufAllocator ;
26+ import io .netty .handler .ssl .SslContextBuilder ;
27+ import io .netty .handler .ssl .SslProvider ;
28+ import io .netty .handler .ssl .util .InsecureTrustManagerFactory ;
2529import io .rsocket .ConnectionSetupPayload ;
2630import io .rsocket .Payload ;
2731import io .rsocket .RSocket ;
2832import io .rsocket .SocketAcceptor ;
33+ import io .rsocket .transport .netty .client .TcpClientTransport ;
2934import io .rsocket .transport .netty .client .WebsocketClientTransport ;
3035import io .rsocket .util .DefaultPayload ;
3136import org .assertj .core .api .Assertions ;
3237import org .junit .jupiter .api .AfterEach ;
3338import org .junit .jupiter .api .Test ;
3439import org .mockito .InOrder ;
3540import reactor .core .publisher .Mono ;
41+ import reactor .netty .tcp .TcpClient ;
42+ import reactor .test .StepVerifier ;
3643
3744import org .springframework .boot .rsocket .server .RSocketServer ;
3845import org .springframework .boot .rsocket .server .RSocketServerCustomizer ;
46+ import org .springframework .boot .rsocket .server .RSocketServer .Transport ;
47+ import org .springframework .boot .web .server .Ssl ;
3948import org .springframework .core .codec .CharSequenceEncoder ;
4049import org .springframework .core .codec .StringDecoder ;
4150import org .springframework .core .io .buffer .NettyDataBufferFactory ;
4554import org .springframework .util .SocketUtils ;
4655
4756import static org .assertj .core .api .Assertions .assertThat ;
57+ import static org .assertj .core .api .Assertions .assertThatExceptionOfType ;
58+ import static org .assertj .core .api .Assertions .assertThatThrownBy ;
4859import static org .mockito .ArgumentMatchers .any ;
4960import static org .mockito .BDDMockito .will ;
5061import static org .mockito .Mockito .inOrder ;
5566 *
5667 * @author Brian Clozel
5768 * @author Leo Li
69+ * @author Chris Bono
5870 */
5971class NettyRSocketServerFactoryTests {
6072
@@ -93,7 +105,7 @@ void specificPort() {
93105 this .server .start ();
94106 return port ;
95107 });
96- this .requester = createRSocketTcpClient ();
108+ this .requester = createRSocketTcpClient (false );
97109 String payload = "test payload" ;
98110 String response = this .requester .route ("test" ).data (payload ).retrieveMono (String .class ).block (TIMEOUT );
99111 assertThat (this .server .address ().getPort ()).isEqualTo (specificPort );
@@ -106,7 +118,7 @@ void websocketTransport() {
106118 factory .setTransport (RSocketServer .Transport .WEBSOCKET );
107119 this .server = factory .create (new EchoRequestResponseAcceptor ());
108120 this .server .start ();
109- this .requester = createRSocketWebSocketClient ();
121+ this .requester = createRSocketWebSocketClient (false );
110122 String payload = "test payload" ;
111123 String response = this .requester .route ("test" ).data (payload ).retrieveMono (String .class ).block (TIMEOUT );
112124 assertThat (response ).isEqualTo (payload );
@@ -121,7 +133,7 @@ void websocketTransportWithReactorResource() {
121133 factory .setResourceFactory (resourceFactory );
122134 this .server = factory .create (new EchoRequestResponseAcceptor ());
123135 this .server .start ();
124- this .requester = createRSocketWebSocketClient ();
136+ this .requester = createRSocketWebSocketClient (false );
125137 String payload = "test payload" ;
126138 String response = this .requester .route ("test" ).data (payload ).retrieveMono (String .class ).block (TIMEOUT );
127139 assertThat (response ).isEqualTo (payload );
@@ -144,16 +156,94 @@ void serverCustomizers() {
144156 }
145157 }
146158
147- private RSocketRequester createRSocketTcpClient () {
148- Assertions .assertThat (this .server ).isNotNull ();
149- InetSocketAddress address = this .server .address ();
150- return createRSocketRequesterBuilder ().tcp (address .getHostString (), address .getPort ());
159+ @ Test
160+ void tcpTransportBasicSslFromClassPath () {
161+ testBasicSslWithKeyStore ("classpath:test.jks" , "password" , Transport .TCP );
162+ }
163+
164+ @ Test
165+ void tcpTransportBasicSslFromFileSystem () {
166+ testBasicSslWithKeyStore ("src/test/resources/test.jks" , "password" , Transport .TCP );
167+ }
168+
169+ @ Test
170+ void websocketTransportBasicSslFromClassPath () {
171+ testBasicSslWithKeyStore ("classpath:test.jks" , "password" , Transport .WEBSOCKET );
172+ }
173+
174+ @ Test
175+ void websocketTransportBasicSslFromFileSystem () {
176+ testBasicSslWithKeyStore ("src/test/resources/test.jks" , "password" , Transport .WEBSOCKET );
177+ }
178+
179+ private void testBasicSslWithKeyStore (String keyStore , String keyPassword , Transport transport ) {
180+ NettyRSocketServerFactory factory = getFactory ();
181+ factory .setTransport (transport );
182+ Ssl ssl = new Ssl ();
183+ ssl .setKeyStore (keyStore );
184+ ssl .setKeyPassword (keyPassword );
185+ factory .setSsl (ssl );
186+ this .server = factory .create (new EchoRequestResponseAcceptor ());
187+ this .server .start ();
188+ this .requester = (transport == Transport .TCP ) ? createRSocketTcpClient (true )
189+ : createRSocketWebSocketClient (true );
190+ String payload = "test payload" ;
191+ Mono <String > responseMono = this .requester .route ("test" ).data (payload ).retrieveMono (String .class );
192+ StepVerifier .create (responseMono ).expectNext (payload ).verifyComplete ();
151193 }
152194
153- private RSocketRequester createRSocketWebSocketClient () {
195+ @ Test
196+ void tcpTransportSslRejectsInsecureClient () {
197+ NettyRSocketServerFactory factory = getFactory ();
198+ factory .setTransport (Transport .TCP );
199+ Ssl ssl = new Ssl ();
200+ ssl .setKeyStore ("classpath:test.jks" );
201+ ssl .setKeyPassword ("password" );
202+ factory .setSsl (ssl );
203+ this .server = factory .create (new EchoRequestResponseAcceptor ());
204+ this .server .start ();
205+ this .requester = createRSocketTcpClient (false );
206+ String payload = "test payload" ;
207+ Mono <String > responseMono = this .requester .route ("test" ).data (payload ).retrieveMono (String .class );
208+ StepVerifier .create (responseMono )
209+ .verifyErrorSatisfies ((ex ) -> assertThatExceptionOfType (ClosedChannelException .class ));
210+ }
211+
212+ @ Test
213+ void websocketTransportSslRejectsInsecureClient () {
214+ NettyRSocketServerFactory factory = getFactory ();
215+ factory .setTransport (Transport .WEBSOCKET );
216+ Ssl ssl = new Ssl ();
217+ ssl .setKeyStore ("classpath:test.jks" );
218+ ssl .setKeyPassword ("password" );
219+ factory .setSsl (ssl );
220+ this .server = factory .create (new EchoRequestResponseAcceptor ());
221+ this .server .start ();
222+ // For WebSocket, the SSL failure results in a hang on the initial connect call
223+ assertThatThrownBy (() -> createRSocketWebSocketClient (false )).isInstanceOf (IllegalStateException .class )
224+ .hasStackTraceContaining ("Timeout on blocking read" );
225+ }
226+
227+ private RSocketRequester createRSocketTcpClient (boolean ssl ) {
228+ TcpClient tcpClient = createTcpClient (ssl );
229+ return createRSocketRequesterBuilder ().connect (TcpClientTransport .create (tcpClient )).block (TIMEOUT );
230+ }
231+
232+ private RSocketRequester createRSocketWebSocketClient (boolean ssl ) {
233+ TcpClient tcpClient = createTcpClient (ssl );
234+ return createRSocketRequesterBuilder ().connect (WebsocketClientTransport .create (tcpClient )).block (TIMEOUT );
235+ }
236+
237+ private TcpClient createTcpClient (boolean ssl ) {
154238 Assertions .assertThat (this .server ).isNotNull ();
155239 InetSocketAddress address = this .server .address ();
156- return createRSocketRequesterBuilder ().transport (WebsocketClientTransport .create (address ));
240+ TcpClient tcpClient = TcpClient .create ().host (address .getHostName ()).port (address .getPort ());
241+ if (ssl ) {
242+ SslContextBuilder builder = SslContextBuilder .forClient ().sslProvider (SslProvider .JDK )
243+ .trustManager (InsecureTrustManagerFactory .INSTANCE );
244+ tcpClient = tcpClient .secure ((spec ) -> spec .sslContext (builder ));
245+ }
246+ return tcpClient ;
157247 }
158248
159249 private RSocketRequester .Builder createRSocketRequesterBuilder () {
0 commit comments