@@ -101,6 +101,8 @@ public class StompSubProtocolHandler implements SubProtocolHandler, ApplicationE
101101
102102 private MessageHeaderInitializer headerInitializer ;
103103
104+ private final Map <String , Principal > stompAuthentications = new ConcurrentHashMap <String , Principal >();
105+
104106 private Boolean immutableMessageInterceptorPresent ;
105107
106108 private ApplicationEventPublisher eventPublisher ;
@@ -247,11 +249,10 @@ else if (webSocketMessage instanceof BinaryMessage) {
247249 try {
248250 StompHeaderAccessor headerAccessor =
249251 MessageHeaderAccessor .getAccessor (message , StompHeaderAccessor .class );
250- Principal user = session .getPrincipal ();
251252
252253 headerAccessor .setSessionId (session .getId ());
253254 headerAccessor .setSessionAttributes (session .getAttributes ());
254- headerAccessor .setUser (user );
255+ headerAccessor .setUser (getUser ( session ) );
255256 headerAccessor .setHeader (SimpMessageHeaderAccessor .HEART_BEAT_HEADER , headerAccessor .getHeartbeat ());
256257 if (!detectImmutableMessageInterceptor (outputChannel )) {
257258 headerAccessor .setImmutable ();
@@ -261,7 +262,8 @@ else if (webSocketMessage instanceof BinaryMessage) {
261262 logger .trace ("From client: " + headerAccessor .getShortLogMessage (message .getPayload ()));
262263 }
263264
264- if (StompCommand .CONNECT .equals (headerAccessor .getCommand ())) {
265+ boolean isConnect = StompCommand .CONNECT .equals (headerAccessor .getCommand ());
266+ if (isConnect ) {
265267 this .stats .incrementConnectCount ();
266268 }
267269 else if (StompCommand .DISCONNECT .equals (headerAccessor .getCommand ())) {
@@ -272,15 +274,23 @@ else if (StompCommand.DISCONNECT.equals(headerAccessor.getCommand())) {
272274 SimpAttributesContextHolder .setAttributesFromMessage (message );
273275 boolean sent = outputChannel .send (message );
274276
275- if (sent && this . eventPublisher != null ) {
276- if (StompCommand . CONNECT . equals ( headerAccessor . getCommand ()) ) {
277- publishEvent ( new SessionConnectEvent ( this , message , user ) );
278- }
279- else if ( StompCommand . SUBSCRIBE . equals ( headerAccessor . getCommand ())) {
280- publishEvent ( new SessionSubscribeEvent ( this , message , user ));
277+ if (sent ) {
278+ if (isConnect ) {
279+ Principal user = headerAccessor . getUser ( );
280+ if ( user != null && user != session . getPrincipal ()) {
281+ this . stompAuthentications . put ( session . getId (), user );
282+ }
281283 }
282- else if (StompCommand .UNSUBSCRIBE .equals (headerAccessor .getCommand ())) {
283- publishEvent (new SessionUnsubscribeEvent (this , message , user ));
284+ if (this .eventPublisher != null ) {
285+ if (isConnect ) {
286+ publishEvent (new SessionConnectEvent (this , message , getUser (session )));
287+ }
288+ else if (StompCommand .SUBSCRIBE .equals (headerAccessor .getCommand ())) {
289+ publishEvent (new SessionSubscribeEvent (this , message , getUser (session )));
290+ }
291+ else if (StompCommand .UNSUBSCRIBE .equals (headerAccessor .getCommand ())) {
292+ publishEvent (new SessionUnsubscribeEvent (this , message , getUser (session )));
293+ }
284294 }
285295 }
286296 }
@@ -298,6 +308,11 @@ else if (StompCommand.UNSUBSCRIBE.equals(headerAccessor.getCommand())) {
298308 }
299309 }
300310
311+ private Principal getUser (WebSocketSession session ) {
312+ Principal user = this .stompAuthentications .get (session .getId ());
313+ return user != null ? user : session .getPrincipal ();
314+ }
315+
301316 private void handleError (WebSocketSession session , Throwable ex , Message <byte []> clientMessage ) {
302317 if (getErrorHandler () == null ) {
303318 sendErrorMessage (session , ex );
@@ -395,7 +410,7 @@ else if (StompCommand.CONNECTED.equals(command)) {
395410 try {
396411 SimpAttributes simpAttributes = new SimpAttributes (session .getId (), session .getAttributes ());
397412 SimpAttributesContextHolder .setAttributes (simpAttributes );
398- Principal user = session . getPrincipal ( );
413+ Principal user = getUser ( session );
399414 publishEvent (new SessionConnectedEvent (this , (Message <byte []>) message , user ));
400415 }
401416 finally {
@@ -535,7 +550,7 @@ protected StompHeaderAccessor toMutableAccessor(StompHeaderAccessor headerAccess
535550 private StompHeaderAccessor afterStompSessionConnected (Message <?> message , StompHeaderAccessor accessor ,
536551 WebSocketSession session ) {
537552
538- Principal principal = session . getPrincipal ( );
553+ Principal principal = getUser ( session );
539554 if (principal != null ) {
540555 accessor = toMutableAccessor (accessor , message );
541556 accessor .setNativeHeader (CONNECTED_USER_HEADER , principal .getName ());
@@ -574,12 +589,13 @@ public void afterSessionEnded(WebSocketSession session, CloseStatus closeStatus,
574589 try {
575590 SimpAttributesContextHolder .setAttributes (simpAttributes );
576591 if (this .eventPublisher != null ) {
577- Principal user = session . getPrincipal ( );
592+ Principal user = getUser ( session );
578593 publishEvent (new SessionDisconnectEvent (this , message , session .getId (), closeStatus , user ));
579594 }
580595 outputChannel .send (message );
581596 }
582597 finally {
598+ this .stompAuthentications .remove (session .getId ());
583599 SimpAttributesContextHolder .resetAttributes ();
584600 simpAttributes .sessionCompleted ();
585601 }
@@ -592,7 +608,7 @@ private Message<byte[]> createDisconnectMessage(WebSocketSession session) {
592608 }
593609 headerAccessor .setSessionId (session .getId ());
594610 headerAccessor .setSessionAttributes (session .getAttributes ());
595- headerAccessor .setUser (session . getPrincipal ( ));
611+ headerAccessor .setUser (getUser ( session ));
596612 return MessageBuilder .createMessage (EMPTY_PAYLOAD , headerAccessor .getMessageHeaders ());
597613 }
598614
0 commit comments