Skip to content

Commit c88bfc5

Browse files
committed
Refactor state management in StompSubProtocolHandler
Closes gh-35591
1 parent a96558c commit c88bfc5

File tree

3 files changed

+80
-34
lines changed

3 files changed

+80
-34
lines changed

spring-websocket/src/main/java/org/springframework/web/socket/messaging/StompSubProtocolHandler.java

Lines changed: 68 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
import java.util.Set;
2626
import java.util.concurrent.ConcurrentHashMap;
2727
import java.util.concurrent.atomic.AtomicInteger;
28+
import java.util.function.Consumer;
2829

2930
import org.apache.commons.logging.Log;
3031
import org.apache.commons.logging.LogFactory;
@@ -108,10 +109,9 @@ public class StompSubProtocolHandler implements SubProtocolHandler, ApplicationE
108109
@Nullable
109110
private MessageHeaderInitializer headerInitializer;
110111

111-
@Nullable
112-
private Map<String, MessageChannel> orderedHandlingMessageChannels;
112+
private final Map<String, SessionInfo> sessions = new ConcurrentHashMap<>();
113113

114-
private final Map<String, Principal> stompAuthentications = new ConcurrentHashMap<>();
114+
private boolean preserveReceiveOrder;
115115

116116
@Nullable
117117
private Boolean immutableMessageInterceptorPresent;
@@ -208,7 +208,7 @@ public MessageHeaderInitializer getHeaderInitializer() {
208208
* @since 6.1
209209
*/
210210
public void setPreserveReceiveOrder(boolean preserveReceiveOrder) {
211-
this.orderedHandlingMessageChannels = (preserveReceiveOrder ? new ConcurrentHashMap<>() : null);
211+
this.preserveReceiveOrder = preserveReceiveOrder;
212212
}
213213

214214
/**
@@ -217,7 +217,7 @@ public void setPreserveReceiveOrder(boolean preserveReceiveOrder) {
217217
* @since 6.1
218218
*/
219219
public boolean isPreserveReceiveOrder() {
220-
return (this.orderedHandlingMessageChannels != null);
220+
return this.preserveReceiveOrder;
221221
}
222222

223223
@Override
@@ -252,7 +252,7 @@ public Stats getStats() {
252252
*/
253253
@Override
254254
public void handleMessageFromClient(WebSocketSession session,
255-
WebSocketMessage<?> webSocketMessage, MessageChannel targetChannel) {
255+
WebSocketMessage<?> webSocketMessage, MessageChannel channel) {
256256

257257
List<Message<byte[]>> messages;
258258
try {
@@ -295,35 +295,36 @@ else if (webSocketMessage instanceof BinaryMessage binaryMessage) {
295295
return;
296296
}
297297

298-
MessageChannel channelToUse = targetChannel;
299-
if (this.orderedHandlingMessageChannels != null) {
300-
channelToUse = this.orderedHandlingMessageChannels.computeIfAbsent(
301-
session.getId(), id -> new OrderedMessageChannelDecorator(targetChannel, logger));
302-
}
298+
SessionInfo info = this.sessions.get(session.getId());
299+
MessageChannel channelToUse = (info != null ? info.getMessageChannelToUse() : null);
303300

304301
for (Message<byte[]> message : messages) {
305-
StompHeaderAccessor headerAccessor =
306-
MessageHeaderAccessor.getAccessor(message, StompHeaderAccessor.class);
302+
StompHeaderAccessor headerAccessor = MessageHeaderAccessor.getAccessor(message, StompHeaderAccessor.class);
307303
Assert.state(headerAccessor != null, "No StompHeaderAccessor");
308304

309305
StompCommand command = headerAccessor.getCommand();
310-
boolean isConnect = StompCommand.CONNECT.equals(command) || StompCommand.STOMP.equals(command);
311-
306+
boolean isConnect = (StompCommand.CONNECT.equals(command) || StompCommand.STOMP.equals(command));
307+
String sessionId = session.getId();
312308
boolean sent = false;
309+
313310
try {
311+
if (isConnect) {
312+
channelToUse = (this.preserveReceiveOrder ? new OrderedMessageChannelDecorator(channel, logger) : channel);
313+
info = new SessionInfo(channelToUse, session.getPrincipal());
314+
SessionInfo prevInfo = this.sessions.putIfAbsent(sessionId, info);
315+
Assert.state(prevInfo == null, "Session already exists");
316+
headerAccessor.setUserChangeCallback(info);
317+
}
318+
else {
319+
Assert.state(channelToUse != null, "Unknown session: " + sessionId);
320+
}
314321

315-
headerAccessor.setSessionId(session.getId());
322+
headerAccessor.setSessionId(sessionId);
316323
headerAccessor.setSessionAttributes(session.getAttributes());
317324
headerAccessor.setUser(getUser(session));
318-
if (isConnect) {
319-
headerAccessor.setUserChangeCallback(user -> {
320-
if (user != null && user != session.getPrincipal()) {
321-
this.stompAuthentications.put(session.getId(), user);
322-
}
323-
});
324-
}
325325
headerAccessor.setHeader(SimpMessageHeaderAccessor.HEART_BEAT_HEADER, headerAccessor.getHeartbeat());
326-
if (!detectImmutableMessageInterceptor(targetChannel)) {
326+
327+
if (!detectImmutableMessageInterceptor(channel)) {
327328
headerAccessor.setImmutable();
328329
}
329330

@@ -363,24 +364,29 @@ else if (StompCommand.UNSUBSCRIBE.equals(command)) {
363364
}
364365
catch (Throwable ex) {
365366
if (logger.isDebugEnabled()) {
366-
logger.debug("Failed to send message to MessageChannel in session " + session.getId(), ex);
367+
logger.debug("Failed to send message to MessageChannel in session " + sessionId, ex);
367368
}
368369
else if (logger.isErrorEnabled()) {
369370
// Skip for unsent CONNECT or SUBSCRIBE (likely authentication/authorization issues)
370371
if (sent || !(isConnect || StompCommand.SUBSCRIBE.equals(command))) {
371372
logger.error("Failed to send message to MessageChannel in session " +
372-
session.getId() + ":" + ex.getMessage());
373+
sessionId + ":" + ex.getMessage());
373374
}
374375
}
375376
handleError(session, ex, message);
376377
}
378+
379+
if (!sent && isConnect) {
380+
this.sessions.remove(sessionId);
381+
break;
382+
}
377383
}
378384
}
379385

380386
@Nullable
381387
private Principal getUser(WebSocketSession session) {
382-
Principal user = this.stompAuthentications.get(session.getId());
383-
return (user != null ? user : session.getPrincipal());
388+
SessionInfo info = this.sessions.get(session.getId());
389+
return (info != null ? info.getUser() : session.getPrincipal());
384390
}
385391

386392
private void handleError(WebSocketSession session, Throwable ex, @Nullable Message<byte[]> clientMessage) {
@@ -685,10 +691,7 @@ public void afterSessionEnded(WebSocketSession session, CloseStatus closeStatus,
685691
outputChannel.send(message);
686692
}
687693
finally {
688-
if (this.orderedHandlingMessageChannels != null) {
689-
this.orderedHandlingMessageChannels.remove(session.getId());
690-
}
691-
this.stompAuthentications.remove(session.getId());
694+
this.sessions.remove(session.getId());
692695
SimpAttributesContextHolder.resetAttributes();
693696
simpAttributes.sessionCompleted();
694697
}
@@ -718,6 +721,39 @@ public String toString() {
718721
}
719722

720723

724+
private static class SessionInfo implements Consumer<Principal> {
725+
726+
private final MessageChannel channel;
727+
728+
@Nullable
729+
private final Principal webSocketUser;
730+
731+
@Nullable
732+
private volatile Principal stompUser;
733+
734+
SessionInfo(MessageChannel channel, @Nullable Principal user) {
735+
this.channel = channel;
736+
this.webSocketUser = user;
737+
}
738+
739+
public MessageChannel getMessageChannelToUse() {
740+
return this.channel;
741+
}
742+
743+
@Nullable
744+
public Principal getUser() {
745+
return (this.stompUser != null ? this.stompUser : this.webSocketUser);
746+
}
747+
748+
@Override
749+
public void accept(@Nullable Principal stompUser) {
750+
if (stompUser != null && stompUser != this.webSocketUser) {
751+
this.stompUser = stompUser;
752+
}
753+
}
754+
}
755+
756+
721757
/**
722758
* Contract for access to session counters.
723759
* @since 5.2

spring-websocket/src/test/java/org/springframework/web/socket/config/annotation/WebSocketMessageBrokerConfigurationSupportTests.java

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,13 +101,22 @@ void clientInboundChannelSendMessage() throws Exception {
101101
session.setOpen(true);
102102
webSocketHandler.afterConnectionEstablished(session);
103103

104+
webSocketHandler.handleMessage(session,
105+
StompTextMessageBuilder.create(StompCommand.CONNECT).headers("destination:/foo").build());
106+
104107
webSocketHandler.handleMessage(session,
105108
StompTextMessageBuilder.create(StompCommand.SEND).headers("destination:/foo").build());
106109

107110
Message<?> message = channel.messages.get(0);
108111
StompHeaderAccessor accessor = StompHeaderAccessor.getAccessor(message, StompHeaderAccessor.class);
109112
assertThat(accessor).isNotNull();
110113
assertThat(accessor.isMutable()).isFalse();
114+
assertThat(accessor.getMessageType()).isEqualTo(SimpMessageType.CONNECT);
115+
116+
message = channel.messages.get(1);
117+
accessor = StompHeaderAccessor.getAccessor(message, StompHeaderAccessor.class);
118+
assertThat(accessor).isNotNull();
119+
assertThat(accessor.isMutable()).isFalse();
111120
assertThat(accessor.getMessageType()).isEqualTo(SimpMessageType.MESSAGE);
112121
assertThat(accessor.getDestination()).isEqualTo("/foo");
113122
}

spring-websocket/src/test/java/org/springframework/web/socket/messaging/StompWebSocketIntegrationTests.java

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -89,9 +89,10 @@ void sendMessageToController(
8989

9090
super.setup(server, webSocketClient, testInfo);
9191

92-
TextMessage message = create(StompCommand.SEND).headers("destination:/app/simple").build();
92+
TextMessage m1 = create(StompCommand.CONNECT).headers("accept-version:1.1").build();
93+
TextMessage m2 = create(StompCommand.SEND).headers("destination:/app/simple").build();
9394

94-
try (WebSocketSession session = execute(new TestClientWebSocketHandler(0, message), "/ws").get()) {
95+
try (WebSocketSession session = execute(new TestClientWebSocketHandler(0, m1, m2), "/ws").get()) {
9596
assertThat(session).isNotNull();
9697
SimpleController controller = this.wac.getBean(SimpleController.class);
9798
assertThat(controller.latch.await(TIMEOUT, TimeUnit.SECONDS)).isTrue();

0 commit comments

Comments
 (0)