Skip to content

Commit 6a18dae

Browse files
committed
Return AbstractSubscribableChannel from @bean methods
Declare SubscribableChannel @beans in WebSocketMessageBrokerConfigurationSupport as AbstractSubscribableChannel to avoid the need for casting when registering interceptors. Issue: SPR-11065
1 parent 0340cc5 commit 6a18dae

File tree

3 files changed

+76
-90
lines changed

3 files changed

+76
-90
lines changed

spring-messaging/src/main/java/org/springframework/messaging/simp/config/WebSocketMessageBrokerConfigurationSupport.java

Lines changed: 8 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -16,24 +16,15 @@
1616

1717
package org.springframework.messaging.simp.config;
1818

19-
import java.util.ArrayList;
20-
import java.util.List;
21-
2219
import org.springframework.context.annotation.Bean;
2320
import org.springframework.messaging.Message;
24-
import org.springframework.messaging.SubscribableChannel;
2521
import org.springframework.messaging.handler.websocket.SubProtocolWebSocketHandler;
2622
import org.springframework.messaging.simp.SimpMessageSendingOperations;
2723
import org.springframework.messaging.simp.SimpMessagingTemplate;
2824
import org.springframework.messaging.simp.handler.*;
29-
import org.springframework.messaging.simp.handler.SimpAnnotationMethodMessageHandler;
25+
import org.springframework.messaging.support.channel.AbstractSubscribableChannel;
3026
import org.springframework.messaging.support.channel.ExecutorSubscribableChannel;
31-
import org.springframework.messaging.support.converter.ByteArrayMessageConverter;
32-
import org.springframework.messaging.support.converter.CompositeMessageConverter;
33-
import org.springframework.messaging.support.converter.DefaultContentTypeResolver;
34-
import org.springframework.messaging.support.converter.MappingJackson2MessageConverter;
35-
import org.springframework.messaging.support.converter.MessageConverter;
36-
import org.springframework.messaging.support.converter.StringMessageConverter;
27+
import org.springframework.messaging.support.converter.*;
3728
import org.springframework.scheduling.concurrent.ThreadPoolTaskExecutor;
3829
import org.springframework.scheduling.concurrent.ThreadPoolTaskScheduler;
3930
import org.springframework.util.ClassUtils;
@@ -44,6 +35,9 @@
4435
import org.springframework.web.socket.WebSocketHandler;
4536
import org.springframework.web.socket.server.config.SockJsServiceRegistration;
4637

38+
import java.util.ArrayList;
39+
import java.util.List;
40+
4741

4842
/**
4943
* Configuration support for broker-backed messaging over WebSocket using a higher-level
@@ -118,12 +112,12 @@ protected void registerStompEndpoints(StompEndpointRegistry registry) {
118112
}
119113

120114
@Bean
121-
public SubscribableChannel webSocketRequestChannel() {
115+
public AbstractSubscribableChannel webSocketRequestChannel() {
122116
return new ExecutorSubscribableChannel(webSocketChannelExecutor());
123117
}
124118

125119
@Bean
126-
public SubscribableChannel webSocketResponseChannel() {
120+
public AbstractSubscribableChannel webSocketResponseChannel() {
127121
return new ExecutorSubscribableChannel(webSocketChannelExecutor());
128122
}
129123

@@ -209,7 +203,7 @@ public SimpMessageSendingOperations brokerMessagingTemplate() {
209203
}
210204

211205
@Bean
212-
public SubscribableChannel brokerChannel() {
206+
public AbstractSubscribableChannel brokerChannel() {
213207
return new ExecutorSubscribableChannel(); // synchronous
214208
}
215209

spring-messaging/src/test/java/org/springframework/messaging/simp/config/WebSocketMessageBrokerConfigurationSupportTests.java

Lines changed: 65 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -16,19 +16,13 @@
1616

1717
package org.springframework.messaging.simp.config;
1818

19-
import java.util.List;
20-
import java.util.Map;
21-
2219
import org.junit.Before;
2320
import org.junit.Test;
24-
import org.mockito.ArgumentCaptor;
25-
import org.mockito.Mockito;
2621
import org.springframework.context.annotation.AnnotationConfigApplicationContext;
2722
import org.springframework.context.annotation.Bean;
2823
import org.springframework.context.annotation.Configuration;
2924
import org.springframework.messaging.Message;
3025
import org.springframework.messaging.MessageHandler;
31-
import org.springframework.messaging.SubscribableChannel;
3226
import org.springframework.messaging.handler.annotation.MessageMapping;
3327
import org.springframework.messaging.handler.annotation.SendTo;
3428
import org.springframework.messaging.handler.websocket.SubProtocolWebSocketHandler;
@@ -43,6 +37,8 @@
4337
import org.springframework.messaging.simp.stomp.StompHeaderAccessor;
4438
import org.springframework.messaging.simp.stomp.StompTextMessageBuilder;
4539
import org.springframework.messaging.support.MessageBuilder;
40+
import org.springframework.messaging.support.channel.AbstractSubscribableChannel;
41+
import org.springframework.messaging.support.channel.ExecutorSubscribableChannel;
4642
import org.springframework.messaging.support.converter.CompositeMessageConverter;
4743
import org.springframework.messaging.support.converter.DefaultContentTypeResolver;
4844
import org.springframework.stereotype.Controller;
@@ -52,9 +48,11 @@
5248
import org.springframework.web.socket.TextMessage;
5349
import org.springframework.web.socket.support.TestWebSocketSession;
5450

51+
import java.util.ArrayList;
52+
import java.util.List;
53+
import java.util.Map;
54+
5555
import static org.junit.Assert.*;
56-
import static org.mockito.Matchers.*;
57-
import static org.mockito.Mockito.*;
5856

5957

6058
/**
@@ -95,27 +93,20 @@ public void handlerMapping() {
9593
@Test
9694
public void webSocketRequestChannel() {
9795

98-
SubscribableChannel channel = this.cxtSimpleBroker.getBean("webSocketRequestChannel", SubscribableChannel.class);
99-
100-
ArgumentCaptor<MessageHandler> captor = ArgumentCaptor.forClass(MessageHandler.class);
101-
verify(channel, times(3)).subscribe(captor.capture());
96+
TestChannel channel = this.cxtSimpleBroker.getBean("webSocketRequestChannel", TestChannel.class);
97+
List<MessageHandler> handlers = channel.handlers;
10298

103-
List<MessageHandler> values = captor.getAllValues();
104-
assertEquals(3, values.size());
105-
106-
assertTrue(values.contains(cxtSimpleBroker.getBean(SimpAnnotationMethodMessageHandler.class)));
107-
assertTrue(values.contains(cxtSimpleBroker.getBean(UserDestinationMessageHandler.class)));
108-
assertTrue(values.contains(cxtSimpleBroker.getBean(SimpleBrokerMessageHandler.class)));
99+
assertEquals(3, handlers.size());
100+
assertTrue(handlers.contains(cxtSimpleBroker.getBean(SimpAnnotationMethodMessageHandler.class)));
101+
assertTrue(handlers.contains(cxtSimpleBroker.getBean(UserDestinationMessageHandler.class)));
102+
assertTrue(handlers.contains(cxtSimpleBroker.getBean(SimpleBrokerMessageHandler.class)));
109103
}
110104

111105
@Test
112106
public void webSocketRequestChannelWithStompBroker() {
113-
SubscribableChannel channel = this.cxtStompBroker.getBean("webSocketRequestChannel", SubscribableChannel.class);
107+
TestChannel channel = this.cxtStompBroker.getBean("webSocketRequestChannel", TestChannel.class);
108+
List<MessageHandler> values = channel.handlers;
114109

115-
ArgumentCaptor<MessageHandler> captor = ArgumentCaptor.forClass(MessageHandler.class);
116-
verify(channel, times(3)).subscribe(captor.capture());
117-
118-
List<MessageHandler> values = captor.getAllValues();
119110
assertEquals(3, values.size());
120111
assertTrue(values.contains(cxtStompBroker.getBean(SimpAnnotationMethodMessageHandler.class)));
121112
assertTrue(values.contains(cxtStompBroker.getBean(UserDestinationMessageHandler.class)));
@@ -125,16 +116,13 @@ public void webSocketRequestChannelWithStompBroker() {
125116
@Test
126117
public void webSocketRequestChannelSendMessage() throws Exception {
127118

128-
SubscribableChannel channel = this.cxtSimpleBroker.getBean("webSocketRequestChannel", SubscribableChannel.class);
119+
TestChannel channel = this.cxtSimpleBroker.getBean("webSocketRequestChannel", TestChannel.class);
129120
SubProtocolWebSocketHandler webSocketHandler = this.cxtSimpleBroker.getBean(SubProtocolWebSocketHandler.class);
130121

131122
TextMessage textMessage = StompTextMessageBuilder.create(StompCommand.SEND).headers("destination:/foo").build();
132123
webSocketHandler.handleMessage(new TestWebSocketSession(), textMessage);
133124

134-
ArgumentCaptor<Message> captor = ArgumentCaptor.forClass(Message.class);
135-
verify(channel).send(captor.capture());
136-
137-
Message message = captor.getValue();
125+
Message<?> message = channel.messages.get(0);
138126
StompHeaderAccessor headers = StompHeaderAccessor.wrap(message);
139127

140128
assertEquals(SimpMessageType.MESSAGE, headers.getMessageType());
@@ -143,15 +131,17 @@ public void webSocketRequestChannelSendMessage() throws Exception {
143131

144132
@Test
145133
public void webSocketResponseChannel() {
146-
SubscribableChannel channel = this.cxtSimpleBroker.getBean("webSocketResponseChannel", SubscribableChannel.class);
147-
verify(channel).subscribe(any(SubProtocolWebSocketHandler.class));
148-
verifyNoMoreInteractions(channel);
134+
TestChannel channel = this.cxtSimpleBroker.getBean("webSocketResponseChannel", TestChannel.class);
135+
List<MessageHandler> values = channel.handlers;
136+
137+
assertEquals(1, values.size());
138+
assertTrue(values.get(0) instanceof SubProtocolWebSocketHandler);
149139
}
150140

151141
@Test
152142
public void webSocketResponseChannelUsedByAnnotatedMethod() {
153143

154-
SubscribableChannel channel = this.cxtSimpleBroker.getBean("webSocketResponseChannel", SubscribableChannel.class);
144+
TestChannel channel = this.cxtSimpleBroker.getBean("webSocketResponseChannel", TestChannel.class);
155145
SimpAnnotationMethodMessageHandler messageHandler = this.cxtSimpleBroker.getBean(SimpAnnotationMethodMessageHandler.class);
156146

157147
StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.SUBSCRIBE);
@@ -160,12 +150,9 @@ public void webSocketResponseChannelUsedByAnnotatedMethod() {
160150
headers.setDestination("/foo");
161151
Message<?> message = MessageBuilder.withPayload(new byte[0]).setHeaders(headers).build();
162152

163-
when(channel.send(any(Message.class))).thenReturn(true);
164153
messageHandler.handleMessage(message);
165154

166-
ArgumentCaptor<Message> captor = ArgumentCaptor.forClass(Message.class);
167-
verify(channel).send(captor.capture());
168-
message = captor.getValue();
155+
message = channel.messages.get(0);
169156
headers = StompHeaderAccessor.wrap(message);
170157

171158
assertEquals(SimpMessageType.MESSAGE, headers.getMessageType());
@@ -175,7 +162,7 @@ public void webSocketResponseChannelUsedByAnnotatedMethod() {
175162

176163
@Test
177164
public void webSocketResponseChannelUsedBySimpleBroker() {
178-
SubscribableChannel channel = this.cxtSimpleBroker.getBean("webSocketResponseChannel", SubscribableChannel.class);
165+
TestChannel channel = this.cxtSimpleBroker.getBean("webSocketResponseChannel", TestChannel.class);
179166
SimpleBrokerMessageHandler broker = this.cxtSimpleBroker.getBean(SimpleBrokerMessageHandler.class);
180167

181168
StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.SUBSCRIBE);
@@ -193,12 +180,9 @@ public void webSocketResponseChannelUsedBySimpleBroker() {
193180
message = MessageBuilder.withPayload("bar".getBytes()).setHeaders(headers).build();
194181

195182
// message
196-
when(channel.send(any(Message.class))).thenReturn(true);
197183
broker.handleMessage(message);
198184

199-
ArgumentCaptor<Message> captor = ArgumentCaptor.forClass(Message.class);
200-
verify(channel).send(captor.capture());
201-
message = captor.getValue();
185+
message = channel.messages.get(0);
202186
headers = StompHeaderAccessor.wrap(message);
203187

204188
assertEquals(SimpMessageType.MESSAGE, headers.getMessageType());
@@ -208,45 +192,36 @@ public void webSocketResponseChannelUsedBySimpleBroker() {
208192

209193
@Test
210194
public void brokerChannel() {
211-
SubscribableChannel channel = this.cxtSimpleBroker.getBean("brokerChannel", SubscribableChannel.class);
212-
213-
ArgumentCaptor<MessageHandler> captor = ArgumentCaptor.forClass(MessageHandler.class);
214-
verify(channel, times(2)).subscribe(captor.capture());
195+
TestChannel channel = this.cxtSimpleBroker.getBean("brokerChannel", TestChannel.class);
196+
List<MessageHandler> handlers = channel.handlers;
215197

216-
List<MessageHandler> values = captor.getAllValues();
217-
assertEquals(2, values.size());
218-
assertTrue(values.contains(cxtSimpleBroker.getBean(UserDestinationMessageHandler.class)));
219-
assertTrue(values.contains(cxtSimpleBroker.getBean(SimpleBrokerMessageHandler.class)));
198+
assertEquals(2, handlers.size());
199+
assertTrue(handlers.contains(cxtSimpleBroker.getBean(UserDestinationMessageHandler.class)));
200+
assertTrue(handlers.contains(cxtSimpleBroker.getBean(SimpleBrokerMessageHandler.class)));
220201
}
221202

222203
@Test
223204
public void brokerChannelWithStompBroker() {
224-
SubscribableChannel channel = this.cxtStompBroker.getBean("brokerChannel", SubscribableChannel.class);
225-
226-
ArgumentCaptor<MessageHandler> captor = ArgumentCaptor.forClass(MessageHandler.class);
227-
verify(channel, times(2)).subscribe(captor.capture());
205+
TestChannel channel = this.cxtStompBroker.getBean("brokerChannel", TestChannel.class);
206+
List<MessageHandler> handlers = channel.handlers;
228207

229-
List<MessageHandler> values = captor.getAllValues();
230-
assertEquals(2, values.size());
231-
assertTrue(values.contains(cxtStompBroker.getBean(UserDestinationMessageHandler.class)));
232-
assertTrue(values.contains(cxtStompBroker.getBean(StompBrokerRelayMessageHandler.class)));
208+
assertEquals(2, handlers.size());
209+
assertTrue(handlers.contains(cxtStompBroker.getBean(UserDestinationMessageHandler.class)));
210+
assertTrue(handlers.contains(cxtStompBroker.getBean(StompBrokerRelayMessageHandler.class)));
233211
}
234212

235213
@Test
236214
public void brokerChannelUsedByAnnotatedMethod() {
237-
SubscribableChannel channel = this.cxtSimpleBroker.getBean("brokerChannel", SubscribableChannel.class);
215+
TestChannel channel = this.cxtSimpleBroker.getBean("brokerChannel", TestChannel.class);
238216
SimpAnnotationMethodMessageHandler messageHandler = this.cxtSimpleBroker.getBean(SimpAnnotationMethodMessageHandler.class);
239217

240218
StompHeaderAccessor headers = StompHeaderAccessor.create(StompCommand.SEND);
241219
headers.setDestination("/foo");
242220
Message<?> message = MessageBuilder.withPayload(new byte[0]).setHeaders(headers).build();
243221

244-
when(channel.send(any(Message.class))).thenReturn(true);
245222
messageHandler.handleMessage(message);
246223

247-
ArgumentCaptor<Message> captor = ArgumentCaptor.forClass(Message.class);
248-
verify(channel).send(captor.capture());
249-
message = captor.getValue();
224+
message = channel.messages.get(0);
250225
headers = StompHeaderAccessor.wrap(message);
251226

252227
assertEquals(SimpMessageType.MESSAGE, headers.getMessageType());
@@ -256,7 +231,7 @@ public void brokerChannelUsedByAnnotatedMethod() {
256231

257232
@Test
258233
public void brokerChannelUsedByUserDestinationMessageHandler() {
259-
SubscribableChannel channel = this.cxtSimpleBroker.getBean("brokerChannel", SubscribableChannel.class);
234+
TestChannel channel = this.cxtSimpleBroker.getBean("brokerChannel", TestChannel.class);
260235
UserDestinationMessageHandler messageHandler = this.cxtSimpleBroker.getBean(UserDestinationMessageHandler.class);
261236

262237
this.cxtSimpleBroker.getBean(UserSessionRegistry.class).registerSessionId("joe", "s1");
@@ -265,12 +240,9 @@ public void brokerChannelUsedByUserDestinationMessageHandler() {
265240
headers.setDestination("/user/joe/foo");
266241
Message<?> message = MessageBuilder.withPayload(new byte[0]).setHeaders(headers).build();
267242

268-
when(channel.send(any(Message.class))).thenReturn(true);
269243
messageHandler.handleMessage(message);
270244

271-
ArgumentCaptor<Message> captor = ArgumentCaptor.forClass(Message.class);
272-
verify(channel).send(captor.capture());
273-
message = captor.getValue();
245+
message = channel.messages.get(0);
274246
headers = StompHeaderAccessor.wrap(message);
275247

276248
assertEquals(SimpMessageType.MESSAGE, headers.getMessageType());
@@ -340,19 +312,39 @@ static class TestWebSocketMessageBrokerConfiguration extends DelegatingWebSocket
340312

341313
@Override
342314
@Bean
343-
public SubscribableChannel webSocketRequestChannel() {
344-
return Mockito.mock(SubscribableChannel.class);
315+
public AbstractSubscribableChannel webSocketRequestChannel() {
316+
return new TestChannel();
345317
}
346318

347319
@Override
348320
@Bean
349-
public SubscribableChannel webSocketResponseChannel() {
350-
return Mockito.mock(SubscribableChannel.class);
321+
public AbstractSubscribableChannel webSocketResponseChannel() {
322+
return new TestChannel();
323+
}
324+
325+
@Override
326+
public AbstractSubscribableChannel brokerChannel() {
327+
return new TestChannel();
328+
}
329+
}
330+
331+
private static class TestChannel extends ExecutorSubscribableChannel {
332+
333+
private final List<MessageHandler> handlers = new ArrayList<>();
334+
335+
private final List<Message<?>> messages = new ArrayList<>();
336+
337+
338+
@Override
339+
public boolean subscribeInternal(MessageHandler handler) {
340+
this.handlers.add(handler);
341+
return super.subscribeInternal(handler);
351342
}
352343

353344
@Override
354-
public SubscribableChannel brokerChannel() {
355-
return Mockito.mock(SubscribableChannel.class);
345+
public boolean sendInternal(Message<?> message, long timeout) {
346+
this.messages.add(message);
347+
return true;
356348
}
357349
}
358350

spring-messaging/src/test/java/org/springframework/messaging/simp/handler/SimpAnnotationMethodIntegrationTests.java

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -34,14 +34,14 @@
3434
import org.springframework.context.annotation.Bean;
3535
import org.springframework.context.annotation.ComponentScan;
3636
import org.springframework.context.annotation.Configuration;
37-
import org.springframework.messaging.SubscribableChannel;
3837
import org.springframework.messaging.handler.annotation.MessageExceptionHandler;
3938
import org.springframework.messaging.handler.annotation.MessageMapping;
4039
import org.springframework.messaging.simp.config.DelegatingWebSocketMessageBrokerConfiguration;
4140
import org.springframework.messaging.simp.config.MessageBrokerConfigurer;
4241
import org.springframework.messaging.simp.config.StompEndpointRegistry;
4342
import org.springframework.messaging.simp.config.WebSocketMessageBrokerConfigurer;
4443
import org.springframework.messaging.simp.stomp.StompCommand;
44+
import org.springframework.messaging.support.channel.AbstractSubscribableChannel;
4545
import org.springframework.messaging.support.channel.ExecutorSubscribableChannel;
4646
import org.springframework.stereotype.Controller;
4747
import org.springframework.web.socket.AbstractWebSocketIntegrationTests;
@@ -227,13 +227,13 @@ static class TestMessageBrokerConfiguration extends DelegatingWebSocketMessageBr
227227

228228
@Override
229229
@Bean
230-
public SubscribableChannel webSocketRequestChannel() {
230+
public AbstractSubscribableChannel webSocketRequestChannel() {
231231
return new ExecutorSubscribableChannel(); // synchronous
232232
}
233233

234234
@Override
235235
@Bean
236-
public SubscribableChannel webSocketResponseChannel() {
236+
public AbstractSubscribableChannel webSocketResponseChannel() {
237237
return new ExecutorSubscribableChannel(); // synchronous
238238
}
239239
}

0 commit comments

Comments
 (0)