Skip to content

Commit b4810b8

Browse files
cbo-indeedbclozel
authored andcommitted
Add SSL support to RSocketServer
See gh-19399
1 parent dd02404 commit b4810b8

File tree

7 files changed

+205
-15
lines changed

7 files changed

+205
-15
lines changed

spring-boot-project/spring-boot-autoconfigure/src/main/java/org/springframework/boot/autoconfigure/rsocket/RSocketProperties.java

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,12 +19,15 @@
1919
import java.net.InetAddress;
2020

2121
import org.springframework.boot.context.properties.ConfigurationProperties;
22+
import org.springframework.boot.context.properties.NestedConfigurationProperty;
2223
import org.springframework.boot.rsocket.server.RSocketServer;
24+
import org.springframework.boot.web.server.Ssl;
2325

2426
/**
2527
* {@link ConfigurationProperties properties} for RSocket support.
2628
*
2729
* @author Brian Clozel
30+
* @author Chris Bono
2831
* @since 2.2.0
2932
*/
3033
@ConfigurationProperties("spring.rsocket")
@@ -59,6 +62,9 @@ public static class Server {
5962
*/
6063
private String mappingPath;
6164

65+
@NestedConfigurationProperty
66+
private Ssl ssl;
67+
6268
public Integer getPort() {
6369
return this.port;
6470
}
@@ -91,6 +97,14 @@ public void setMappingPath(String mappingPath) {
9197
this.mappingPath = mappingPath;
9298
}
9399

100+
public Ssl getSsl() {
101+
return this.ssl;
102+
}
103+
104+
public void setSsl(Ssl ssl) {
105+
this.ssl = ssl;
106+
}
107+
94108
}
95109

96110
}

spring-boot-project/spring-boot-autoconfigure/src/main/java/org/springframework/boot/autoconfigure/rsocket/RSocketServerAutoConfiguration.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,7 @@ RSocketServerFactory rSocketServerFactory(RSocketProperties properties, ReactorR
9797
PropertyMapper map = PropertyMapper.get().alwaysApplyingWhenNonNull();
9898
map.from(properties.getServer().getAddress()).to(factory::setAddress);
9999
map.from(properties.getServer().getPort()).to(factory::setPort);
100+
map.from(properties.getServer().getSsl()).to(factory::setSsl);
100101
factory.setRSocketServerCustomizers(customizers.orderedStream().collect(Collectors.toList()));
101102
return factory;
102103
}

spring-boot-project/spring-boot-autoconfigure/src/test/java/org/springframework/boot/autoconfigure/rsocket/RSocketServerAutoConfigurationTests.java

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,18 @@ void shouldSetLocalServerPortWhenRSocketServerPortIsSet() {
9191
});
9292
}
9393

94+
@Test
95+
void shouldUseSslWhenRocketServerSslIsConfigured() {
96+
reactiveWebContextRunner()
97+
.withPropertyValues("spring.rsocket.server.ssl.keyStore=classpath:rsocket/test.jks",
98+
"spring.rsocket.server.ssl.keyPassword=password", "spring.rsocket.server.port=0")
99+
.run((context) -> assertThat(context).hasSingleBean(RSocketServerFactory.class)
100+
.hasSingleBean(RSocketServerBootstrap.class).hasSingleBean(RSocketServerCustomizer.class)
101+
.getBean(RSocketServerFactory.class)
102+
.hasFieldOrPropertyWithValue("ssl.keyStore", "classpath:rsocket/test.jks")
103+
.hasFieldOrPropertyWithValue("ssl.keyPassword", "password"));
104+
}
105+
94106
@Test
95107
void shouldUseCustomServerBootstrap() {
96108
contextRunner().withUserConfiguration(CustomServerBootstrapConfig.class).run((context) -> assertThat(context)
Binary file not shown.

spring-boot-project/spring-boot/src/main/java/org/springframework/boot/rsocket/netty/NettyRSocketServerFactory.java

Lines changed: 64 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,9 @@
3737
import org.springframework.boot.rsocket.server.RSocketServer;
3838
import org.springframework.boot.rsocket.server.RSocketServerCustomizer;
3939
import org.springframework.boot.rsocket.server.RSocketServerFactory;
40+
import org.springframework.boot.web.embedded.netty.SslServerCustomizer;
41+
import org.springframework.boot.web.server.Ssl;
42+
import org.springframework.boot.web.server.SslStoreProvider;
4043
import org.springframework.http.client.reactive.ReactorResourceFactory;
4144
import org.springframework.util.Assert;
4245

@@ -45,6 +48,7 @@
4548
* by Netty.
4649
*
4750
* @author Brian Clozel
51+
* @author Chris Bono
4852
* @since 2.2.0
4953
*/
5054
public class NettyRSocketServerFactory implements RSocketServerFactory, ConfigurableRSocketServerFactory {
@@ -61,6 +65,10 @@ public class NettyRSocketServerFactory implements RSocketServerFactory, Configur
6165

6266
private List<RSocketServerCustomizer> rSocketServerCustomizers = new ArrayList<>();
6367

68+
private Ssl ssl;
69+
70+
private SslStoreProvider sslStoreProvider;
71+
6472
@Override
6573
public void setPort(int port) {
6674
this.port = port;
@@ -76,6 +84,16 @@ public void setTransport(RSocketServer.Transport transport) {
7684
this.transport = transport;
7785
}
7886

87+
@Override
88+
public void setSsl(Ssl ssl) {
89+
this.ssl = ssl;
90+
}
91+
92+
@Override
93+
public void setSslStoreProvider(SslStoreProvider sslStoreProvider) {
94+
this.sslStoreProvider = sslStoreProvider;
95+
}
96+
7997
/**
8098
* Set the {@link ReactorResourceFactory} to get the shared resources from.
8199
* @param resourceFactory the server resources
@@ -133,21 +151,41 @@ private ServerTransport<CloseableChannel> createTransport() {
133151
}
134152

135153
private ServerTransport<CloseableChannel> createWebSocketTransport() {
154+
HttpServer httpServer;
136155
if (this.resourceFactory != null) {
137-
HttpServer httpServer = HttpServer.create().runOn(this.resourceFactory.getLoopResources())
156+
httpServer = HttpServer.create().runOn(this.resourceFactory.getLoopResources())
138157
.bindAddress(this::getListenAddress);
139-
return WebsocketServerTransport.create(httpServer);
140158
}
141-
return WebsocketServerTransport.create(getListenAddress());
159+
else {
160+
InetSocketAddress listenAddress = this.getListenAddress();
161+
httpServer = HttpServer.create().host(listenAddress.getHostName()).port(listenAddress.getPort());
162+
}
163+
164+
if (this.ssl != null && this.ssl.isEnabled()) {
165+
SslServerCustomizer sslServerCustomizer = new SslServerCustomizer(this.ssl, null, this.sslStoreProvider);
166+
httpServer = sslServerCustomizer.apply(httpServer);
167+
}
168+
169+
return WebsocketServerTransport.create(httpServer);
142170
}
143171

144172
private ServerTransport<CloseableChannel> createTcpTransport() {
173+
TcpServer tcpServer;
145174
if (this.resourceFactory != null) {
146-
TcpServer tcpServer = TcpServer.create().runOn(this.resourceFactory.getLoopResources())
175+
tcpServer = TcpServer.create().runOn(this.resourceFactory.getLoopResources())
147176
.bindAddress(this::getListenAddress);
148-
return TcpServerTransport.create(tcpServer);
149177
}
150-
return TcpServerTransport.create(getListenAddress());
178+
else {
179+
InetSocketAddress listenAddress = this.getListenAddress();
180+
tcpServer = TcpServer.create().host(listenAddress.getHostName()).port(listenAddress.getPort());
181+
}
182+
183+
if (this.ssl != null && this.ssl.isEnabled()) {
184+
TcpSslServerCustomizer sslServerCustomizer = new TcpSslServerCustomizer(this.ssl, this.sslStoreProvider);
185+
tcpServer = sslServerCustomizer.apply(tcpServer);
186+
}
187+
188+
return TcpServerTransport.create(tcpServer);
151189
}
152190

153191
private InetSocketAddress getListenAddress() {
@@ -157,4 +195,24 @@ private InetSocketAddress getListenAddress() {
157195
return new InetSocketAddress(this.port);
158196
}
159197

198+
private static final class TcpSslServerCustomizer extends SslServerCustomizer {
199+
200+
private TcpSslServerCustomizer(Ssl ssl, SslStoreProvider sslStoreProvider) {
201+
super(ssl, null, sslStoreProvider);
202+
}
203+
204+
// This does not override the apply in parent - currently just leveraging the
205+
// parent for its "getContextBuilder()" method. This should be refactored when
206+
// we add the concept of http/tcp customizers for RSocket.
207+
private TcpServer apply(TcpServer server) {
208+
try {
209+
return server.secure((contextSpec) -> contextSpec.sslContext(getContextBuilder()));
210+
}
211+
catch (Exception ex) {
212+
throw new IllegalStateException(ex);
213+
}
214+
}
215+
216+
}
217+
160218
}

spring-boot-project/spring-boot/src/main/java/org/springframework/boot/rsocket/server/ConfigurableRSocketServerFactory.java

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,9 @@
1818

1919
import java.net.InetAddress;
2020

21+
import org.springframework.boot.web.server.Ssl;
22+
import org.springframework.boot.web.server.SslStoreProvider;
23+
2124
/**
2225
* A configurable {@link RSocketServerFactory}.
2326
*
@@ -45,4 +48,16 @@ public interface ConfigurableRSocketServerFactory {
4548
*/
4649
void setTransport(RSocketServer.Transport transport);
4750

51+
/**
52+
* Sets the SSL configuration that will be applied to the server's default connector.
53+
* @param ssl the SSL configuration
54+
*/
55+
void setSsl(Ssl ssl);
56+
57+
/**
58+
* Sets a provider that will be used to obtain SSL stores.
59+
* @param sslStoreProvider the SSL store provider
60+
*/
61+
void setSslStoreProvider(SslStoreProvider sslStoreProvider);
62+
4863
}

spring-boot-project/spring-boot/src/test/java/org/springframework/boot/rsocket/netty/NettyRSocketServerFactoryTests.java

Lines changed: 99 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -17,25 +17,34 @@
1717
package org.springframework.boot.rsocket.netty;
1818

1919
import java.net.InetSocketAddress;
20+
import java.nio.channels.ClosedChannelException;
2021
import java.time.Duration;
2122
import java.util.Arrays;
2223
import java.util.concurrent.Callable;
2324

2425
import io.netty.buffer.PooledByteBufAllocator;
26+
import io.netty.handler.ssl.SslContextBuilder;
27+
import io.netty.handler.ssl.SslProvider;
28+
import io.netty.handler.ssl.util.InsecureTrustManagerFactory;
2529
import io.rsocket.ConnectionSetupPayload;
2630
import io.rsocket.Payload;
2731
import io.rsocket.RSocket;
2832
import io.rsocket.SocketAcceptor;
33+
import io.rsocket.transport.netty.client.TcpClientTransport;
2934
import io.rsocket.transport.netty.client.WebsocketClientTransport;
3035
import io.rsocket.util.DefaultPayload;
3136
import org.assertj.core.api.Assertions;
3237
import org.junit.jupiter.api.AfterEach;
3338
import org.junit.jupiter.api.Test;
3439
import org.mockito.InOrder;
3540
import reactor.core.publisher.Mono;
41+
import reactor.netty.tcp.TcpClient;
42+
import reactor.test.StepVerifier;
3643

3744
import org.springframework.boot.rsocket.server.RSocketServer;
3845
import org.springframework.boot.rsocket.server.RSocketServerCustomizer;
46+
import org.springframework.boot.rsocket.server.RSocketServer.Transport;
47+
import org.springframework.boot.web.server.Ssl;
3948
import org.springframework.core.codec.CharSequenceEncoder;
4049
import org.springframework.core.codec.StringDecoder;
4150
import org.springframework.core.io.buffer.NettyDataBufferFactory;
@@ -45,6 +54,8 @@
4554
import org.springframework.util.SocketUtils;
4655

4756
import static org.assertj.core.api.Assertions.assertThat;
57+
import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
58+
import static org.assertj.core.api.Assertions.assertThatThrownBy;
4859
import static org.mockito.ArgumentMatchers.any;
4960
import static org.mockito.BDDMockito.will;
5061
import static org.mockito.Mockito.inOrder;
@@ -55,6 +66,7 @@
5566
*
5667
* @author Brian Clozel
5768
* @author Leo Li
69+
* @author Chris Bono
5870
*/
5971
class NettyRSocketServerFactoryTests {
6072

@@ -93,7 +105,7 @@ void specificPort() {
93105
this.server.start();
94106
return port;
95107
});
96-
this.requester = createRSocketTcpClient();
108+
this.requester = createRSocketTcpClient(false);
97109
String payload = "test payload";
98110
String response = this.requester.route("test").data(payload).retrieveMono(String.class).block(TIMEOUT);
99111
assertThat(this.server.address().getPort()).isEqualTo(specificPort);
@@ -106,7 +118,7 @@ void websocketTransport() {
106118
factory.setTransport(RSocketServer.Transport.WEBSOCKET);
107119
this.server = factory.create(new EchoRequestResponseAcceptor());
108120
this.server.start();
109-
this.requester = createRSocketWebSocketClient();
121+
this.requester = createRSocketWebSocketClient(false);
110122
String payload = "test payload";
111123
String response = this.requester.route("test").data(payload).retrieveMono(String.class).block(TIMEOUT);
112124
assertThat(response).isEqualTo(payload);
@@ -121,7 +133,7 @@ void websocketTransportWithReactorResource() {
121133
factory.setResourceFactory(resourceFactory);
122134
this.server = factory.create(new EchoRequestResponseAcceptor());
123135
this.server.start();
124-
this.requester = createRSocketWebSocketClient();
136+
this.requester = createRSocketWebSocketClient(false);
125137
String payload = "test payload";
126138
String response = this.requester.route("test").data(payload).retrieveMono(String.class).block(TIMEOUT);
127139
assertThat(response).isEqualTo(payload);
@@ -144,16 +156,94 @@ void serverCustomizers() {
144156
}
145157
}
146158

147-
private RSocketRequester createRSocketTcpClient() {
148-
Assertions.assertThat(this.server).isNotNull();
149-
InetSocketAddress address = this.server.address();
150-
return createRSocketRequesterBuilder().tcp(address.getHostString(), address.getPort());
159+
@Test
160+
void tcpTransportBasicSslFromClassPath() {
161+
testBasicSslWithKeyStore("classpath:test.jks", "password", Transport.TCP);
162+
}
163+
164+
@Test
165+
void tcpTransportBasicSslFromFileSystem() {
166+
testBasicSslWithKeyStore("src/test/resources/test.jks", "password", Transport.TCP);
167+
}
168+
169+
@Test
170+
void websocketTransportBasicSslFromClassPath() {
171+
testBasicSslWithKeyStore("classpath:test.jks", "password", Transport.WEBSOCKET);
172+
}
173+
174+
@Test
175+
void websocketTransportBasicSslFromFileSystem() {
176+
testBasicSslWithKeyStore("src/test/resources/test.jks", "password", Transport.WEBSOCKET);
177+
}
178+
179+
private void testBasicSslWithKeyStore(String keyStore, String keyPassword, Transport transport) {
180+
NettyRSocketServerFactory factory = getFactory();
181+
factory.setTransport(transport);
182+
Ssl ssl = new Ssl();
183+
ssl.setKeyStore(keyStore);
184+
ssl.setKeyPassword(keyPassword);
185+
factory.setSsl(ssl);
186+
this.server = factory.create(new EchoRequestResponseAcceptor());
187+
this.server.start();
188+
this.requester = (transport == Transport.TCP) ? createRSocketTcpClient(true)
189+
: createRSocketWebSocketClient(true);
190+
String payload = "test payload";
191+
Mono<String> responseMono = this.requester.route("test").data(payload).retrieveMono(String.class);
192+
StepVerifier.create(responseMono).expectNext(payload).verifyComplete();
151193
}
152194

153-
private RSocketRequester createRSocketWebSocketClient() {
195+
@Test
196+
void tcpTransportSslRejectsInsecureClient() {
197+
NettyRSocketServerFactory factory = getFactory();
198+
factory.setTransport(Transport.TCP);
199+
Ssl ssl = new Ssl();
200+
ssl.setKeyStore("classpath:test.jks");
201+
ssl.setKeyPassword("password");
202+
factory.setSsl(ssl);
203+
this.server = factory.create(new EchoRequestResponseAcceptor());
204+
this.server.start();
205+
this.requester = createRSocketTcpClient(false);
206+
String payload = "test payload";
207+
Mono<String> responseMono = this.requester.route("test").data(payload).retrieveMono(String.class);
208+
StepVerifier.create(responseMono)
209+
.verifyErrorSatisfies((ex) -> assertThatExceptionOfType(ClosedChannelException.class));
210+
}
211+
212+
@Test
213+
void websocketTransportSslRejectsInsecureClient() {
214+
NettyRSocketServerFactory factory = getFactory();
215+
factory.setTransport(Transport.WEBSOCKET);
216+
Ssl ssl = new Ssl();
217+
ssl.setKeyStore("classpath:test.jks");
218+
ssl.setKeyPassword("password");
219+
factory.setSsl(ssl);
220+
this.server = factory.create(new EchoRequestResponseAcceptor());
221+
this.server.start();
222+
// For WebSocket, the SSL failure results in a hang on the initial connect call
223+
assertThatThrownBy(() -> createRSocketWebSocketClient(false)).isInstanceOf(IllegalStateException.class)
224+
.hasStackTraceContaining("Timeout on blocking read");
225+
}
226+
227+
private RSocketRequester createRSocketTcpClient(boolean ssl) {
228+
TcpClient tcpClient = createTcpClient(ssl);
229+
return createRSocketRequesterBuilder().connect(TcpClientTransport.create(tcpClient)).block(TIMEOUT);
230+
}
231+
232+
private RSocketRequester createRSocketWebSocketClient(boolean ssl) {
233+
TcpClient tcpClient = createTcpClient(ssl);
234+
return createRSocketRequesterBuilder().connect(WebsocketClientTransport.create(tcpClient)).block(TIMEOUT);
235+
}
236+
237+
private TcpClient createTcpClient(boolean ssl) {
154238
Assertions.assertThat(this.server).isNotNull();
155239
InetSocketAddress address = this.server.address();
156-
return createRSocketRequesterBuilder().transport(WebsocketClientTransport.create(address));
240+
TcpClient tcpClient = TcpClient.create().host(address.getHostName()).port(address.getPort());
241+
if (ssl) {
242+
SslContextBuilder builder = SslContextBuilder.forClient().sslProvider(SslProvider.JDK)
243+
.trustManager(InsecureTrustManagerFactory.INSTANCE);
244+
tcpClient = tcpClient.secure((spec) -> spec.sslContext(builder));
245+
}
246+
return tcpClient;
157247
}
158248

159249
private RSocketRequester.Builder createRSocketRequesterBuilder() {

0 commit comments

Comments
 (0)