2525import java .util .Set ;
2626import java .util .concurrent .ConcurrentHashMap ;
2727import java .util .concurrent .atomic .AtomicInteger ;
28+ import java .util .function .Consumer ;
2829
2930import org .apache .commons .logging .Log ;
3031import org .apache .commons .logging .LogFactory ;
@@ -108,10 +109,9 @@ public class StompSubProtocolHandler implements SubProtocolHandler, ApplicationE
108109 @ Nullable
109110 private MessageHeaderInitializer headerInitializer ;
110111
111- @ Nullable
112- private Map <String , MessageChannel > orderedHandlingMessageChannels ;
112+ private final Map <String , SessionInfo > sessions = new ConcurrentHashMap <>();
113113
114- private final Map < String , Principal > stompAuthentications = new ConcurrentHashMap <>() ;
114+ private boolean preserveReceiveOrder ;
115115
116116 @ Nullable
117117 private Boolean immutableMessageInterceptorPresent ;
@@ -208,7 +208,7 @@ public MessageHeaderInitializer getHeaderInitializer() {
208208 * @since 6.1
209209 */
210210 public void setPreserveReceiveOrder (boolean preserveReceiveOrder ) {
211- this .orderedHandlingMessageChannels = ( preserveReceiveOrder ? new ConcurrentHashMap <>() : null ) ;
211+ this .preserveReceiveOrder = preserveReceiveOrder ;
212212 }
213213
214214 /**
@@ -217,7 +217,7 @@ public void setPreserveReceiveOrder(boolean preserveReceiveOrder) {
217217 * @since 6.1
218218 */
219219 public boolean isPreserveReceiveOrder () {
220- return ( this .orderedHandlingMessageChannels != null ) ;
220+ return this .preserveReceiveOrder ;
221221 }
222222
223223 @ Override
@@ -252,7 +252,7 @@ public Stats getStats() {
252252 */
253253 @ Override
254254 public void handleMessageFromClient (WebSocketSession session ,
255- WebSocketMessage <?> webSocketMessage , MessageChannel targetChannel ) {
255+ WebSocketMessage <?> webSocketMessage , MessageChannel channel ) {
256256
257257 List <Message <byte []>> messages ;
258258 try {
@@ -295,35 +295,36 @@ else if (webSocketMessage instanceof BinaryMessage binaryMessage) {
295295 return ;
296296 }
297297
298- MessageChannel channelToUse = targetChannel ;
299- if (this .orderedHandlingMessageChannels != null ) {
300- channelToUse = this .orderedHandlingMessageChannels .computeIfAbsent (
301- session .getId (), id -> new OrderedMessageChannelDecorator (targetChannel , logger ));
302- }
298+ SessionInfo info = this .sessions .get (session .getId ());
299+ MessageChannel channelToUse = (info != null ? info .getMessageChannelToUse () : null );
303300
304301 for (Message <byte []> message : messages ) {
305- StompHeaderAccessor headerAccessor =
306- MessageHeaderAccessor .getAccessor (message , StompHeaderAccessor .class );
302+ StompHeaderAccessor headerAccessor = MessageHeaderAccessor .getAccessor (message , StompHeaderAccessor .class );
307303 Assert .state (headerAccessor != null , "No StompHeaderAccessor" );
308304
309305 StompCommand command = headerAccessor .getCommand ();
310- boolean isConnect = StompCommand .CONNECT .equals (command ) || StompCommand .STOMP .equals (command );
311-
306+ boolean isConnect = ( StompCommand .CONNECT .equals (command ) || StompCommand .STOMP .equals (command ) );
307+ String sessionId = session . getId ();
312308 boolean sent = false ;
309+
313310 try {
311+ if (isConnect ) {
312+ channelToUse = (this .preserveReceiveOrder ? new OrderedMessageChannelDecorator (channel , logger ) : channel );
313+ info = new SessionInfo (channelToUse , session .getPrincipal ());
314+ SessionInfo prevInfo = this .sessions .putIfAbsent (sessionId , info );
315+ Assert .state (prevInfo == null , "Session already exists" );
316+ headerAccessor .setUserChangeCallback (info );
317+ }
318+ else {
319+ Assert .state (channelToUse != null , "Unknown session: " + sessionId );
320+ }
314321
315- headerAccessor .setSessionId (session . getId () );
322+ headerAccessor .setSessionId (sessionId );
316323 headerAccessor .setSessionAttributes (session .getAttributes ());
317324 headerAccessor .setUser (getUser (session ));
318- if (isConnect ) {
319- headerAccessor .setUserChangeCallback (user -> {
320- if (user != null && user != session .getPrincipal ()) {
321- this .stompAuthentications .put (session .getId (), user );
322- }
323- });
324- }
325325 headerAccessor .setHeader (SimpMessageHeaderAccessor .HEART_BEAT_HEADER , headerAccessor .getHeartbeat ());
326- if (!detectImmutableMessageInterceptor (targetChannel )) {
326+
327+ if (!detectImmutableMessageInterceptor (channel )) {
327328 headerAccessor .setImmutable ();
328329 }
329330
@@ -363,24 +364,29 @@ else if (StompCommand.UNSUBSCRIBE.equals(command)) {
363364 }
364365 catch (Throwable ex ) {
365366 if (logger .isDebugEnabled ()) {
366- logger .debug ("Failed to send message to MessageChannel in session " + session . getId () , ex );
367+ logger .debug ("Failed to send message to MessageChannel in session " + sessionId , ex );
367368 }
368369 else if (logger .isErrorEnabled ()) {
369370 // Skip for unsent CONNECT or SUBSCRIBE (likely authentication/authorization issues)
370371 if (sent || !(isConnect || StompCommand .SUBSCRIBE .equals (command ))) {
371372 logger .error ("Failed to send message to MessageChannel in session " +
372- session . getId () + ":" + ex .getMessage ());
373+ sessionId + ":" + ex .getMessage ());
373374 }
374375 }
375376 handleError (session , ex , message );
376377 }
378+
379+ if (!sent && isConnect ) {
380+ this .sessions .remove (sessionId );
381+ break ;
382+ }
377383 }
378384 }
379385
380386 @ Nullable
381387 private Principal getUser (WebSocketSession session ) {
382- Principal user = this .stompAuthentications .get (session .getId ());
383- return (user != null ? user : session .getPrincipal ());
388+ SessionInfo info = this .sessions .get (session .getId ());
389+ return (info != null ? info . getUser () : session .getPrincipal ());
384390 }
385391
386392 private void handleError (WebSocketSession session , Throwable ex , @ Nullable Message <byte []> clientMessage ) {
@@ -685,10 +691,7 @@ public void afterSessionEnded(WebSocketSession session, CloseStatus closeStatus,
685691 outputChannel .send (message );
686692 }
687693 finally {
688- if (this .orderedHandlingMessageChannels != null ) {
689- this .orderedHandlingMessageChannels .remove (session .getId ());
690- }
691- this .stompAuthentications .remove (session .getId ());
694+ this .sessions .remove (session .getId ());
692695 SimpAttributesContextHolder .resetAttributes ();
693696 simpAttributes .sessionCompleted ();
694697 }
@@ -718,6 +721,39 @@ public String toString() {
718721 }
719722
720723
724+ private static class SessionInfo implements Consumer <Principal > {
725+
726+ private final MessageChannel channel ;
727+
728+ @ Nullable
729+ private final Principal webSocketUser ;
730+
731+ @ Nullable
732+ private volatile Principal stompUser ;
733+
734+ SessionInfo (MessageChannel channel , @ Nullable Principal user ) {
735+ this .channel = channel ;
736+ this .webSocketUser = user ;
737+ }
738+
739+ public MessageChannel getMessageChannelToUse () {
740+ return this .channel ;
741+ }
742+
743+ @ Nullable
744+ public Principal getUser () {
745+ return (this .stompUser != null ? this .stompUser : this .webSocketUser );
746+ }
747+
748+ @ Override
749+ public void accept (@ Nullable Principal stompUser ) {
750+ if (stompUser != null && stompUser != this .webSocketUser ) {
751+ this .stompUser = stompUser ;
752+ }
753+ }
754+ }
755+
756+
721757 /**
722758 * Contract for access to session counters.
723759 * @since 5.2
0 commit comments