Skip to content

Commit 2191d80

Browse files
committed
Allow athentication at the STOMP level
This commit makes it possible for a ChannelInterceptor to override the user header in a Spring Message that contains a STOMP CONNECT frame. After the message is sent, the updated user header is observed and saved to be associated with session thereafter. Issue: SPR-14690
1 parent d4411f4 commit 2191d80

File tree

3 files changed

+226
-49
lines changed

3 files changed

+226
-49
lines changed

spring-websocket/src/main/java/org/springframework/web/socket/messaging/StompSubProtocolHandler.java

Lines changed: 31 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,8 @@ public class StompSubProtocolHandler implements SubProtocolHandler, ApplicationE
101101

102102
private MessageHeaderInitializer headerInitializer;
103103

104+
private final Map<String, Principal> stompAuthentications = new ConcurrentHashMap<String, Principal>();
105+
104106
private Boolean immutableMessageInterceptorPresent;
105107

106108
private ApplicationEventPublisher eventPublisher;
@@ -247,11 +249,10 @@ else if (webSocketMessage instanceof BinaryMessage) {
247249
try {
248250
StompHeaderAccessor headerAccessor =
249251
MessageHeaderAccessor.getAccessor(message, StompHeaderAccessor.class);
250-
Principal user = session.getPrincipal();
251252

252253
headerAccessor.setSessionId(session.getId());
253254
headerAccessor.setSessionAttributes(session.getAttributes());
254-
headerAccessor.setUser(user);
255+
headerAccessor.setUser(getUser(session));
255256
headerAccessor.setHeader(SimpMessageHeaderAccessor.HEART_BEAT_HEADER, headerAccessor.getHeartbeat());
256257
if (!detectImmutableMessageInterceptor(outputChannel)) {
257258
headerAccessor.setImmutable();
@@ -261,7 +262,8 @@ else if (webSocketMessage instanceof BinaryMessage) {
261262
logger.trace("From client: " + headerAccessor.getShortLogMessage(message.getPayload()));
262263
}
263264

264-
if (StompCommand.CONNECT.equals(headerAccessor.getCommand())) {
265+
boolean isConnect = StompCommand.CONNECT.equals(headerAccessor.getCommand());
266+
if (isConnect) {
265267
this.stats.incrementConnectCount();
266268
}
267269
else if (StompCommand.DISCONNECT.equals(headerAccessor.getCommand())) {
@@ -272,15 +274,23 @@ else if (StompCommand.DISCONNECT.equals(headerAccessor.getCommand())) {
272274
SimpAttributesContextHolder.setAttributesFromMessage(message);
273275
boolean sent = outputChannel.send(message);
274276

275-
if (sent && this.eventPublisher != null) {
276-
if (StompCommand.CONNECT.equals(headerAccessor.getCommand())) {
277-
publishEvent(new SessionConnectEvent(this, message, user));
278-
}
279-
else if (StompCommand.SUBSCRIBE.equals(headerAccessor.getCommand())) {
280-
publishEvent(new SessionSubscribeEvent(this, message, user));
277+
if (sent) {
278+
if (isConnect) {
279+
Principal user = headerAccessor.getUser();
280+
if (user != null && user != session.getPrincipal()) {
281+
this.stompAuthentications.put(session.getId(), user);
282+
}
281283
}
282-
else if (StompCommand.UNSUBSCRIBE.equals(headerAccessor.getCommand())) {
283-
publishEvent(new SessionUnsubscribeEvent(this, message, user));
284+
if (this.eventPublisher != null) {
285+
if (isConnect) {
286+
publishEvent(new SessionConnectEvent(this, message, getUser(session)));
287+
}
288+
else if (StompCommand.SUBSCRIBE.equals(headerAccessor.getCommand())) {
289+
publishEvent(new SessionSubscribeEvent(this, message, getUser(session)));
290+
}
291+
else if (StompCommand.UNSUBSCRIBE.equals(headerAccessor.getCommand())) {
292+
publishEvent(new SessionUnsubscribeEvent(this, message, getUser(session)));
293+
}
284294
}
285295
}
286296
}
@@ -298,6 +308,11 @@ else if (StompCommand.UNSUBSCRIBE.equals(headerAccessor.getCommand())) {
298308
}
299309
}
300310

311+
private Principal getUser(WebSocketSession session) {
312+
Principal user = this.stompAuthentications.get(session.getId());
313+
return user != null ? user : session.getPrincipal();
314+
}
315+
301316
private void handleError(WebSocketSession session, Throwable ex, Message<byte[]> clientMessage) {
302317
if (getErrorHandler() == null) {
303318
sendErrorMessage(session, ex);
@@ -395,7 +410,7 @@ else if (StompCommand.CONNECTED.equals(command)) {
395410
try {
396411
SimpAttributes simpAttributes = new SimpAttributes(session.getId(), session.getAttributes());
397412
SimpAttributesContextHolder.setAttributes(simpAttributes);
398-
Principal user = session.getPrincipal();
413+
Principal user = getUser(session);
399414
publishEvent(new SessionConnectedEvent(this, (Message<byte[]>) message, user));
400415
}
401416
finally {
@@ -535,7 +550,7 @@ protected StompHeaderAccessor toMutableAccessor(StompHeaderAccessor headerAccess
535550
private StompHeaderAccessor afterStompSessionConnected(Message<?> message, StompHeaderAccessor accessor,
536551
WebSocketSession session) {
537552

538-
Principal principal = session.getPrincipal();
553+
Principal principal = getUser(session);
539554
if (principal != null) {
540555
accessor = toMutableAccessor(accessor, message);
541556
accessor.setNativeHeader(CONNECTED_USER_HEADER, principal.getName());
@@ -574,12 +589,13 @@ public void afterSessionEnded(WebSocketSession session, CloseStatus closeStatus,
574589
try {
575590
SimpAttributesContextHolder.setAttributes(simpAttributes);
576591
if (this.eventPublisher != null) {
577-
Principal user = session.getPrincipal();
592+
Principal user = getUser(session);
578593
publishEvent(new SessionDisconnectEvent(this, message, session.getId(), closeStatus, user));
579594
}
580595
outputChannel.send(message);
581596
}
582597
finally {
598+
this.stompAuthentications.remove(session.getId());
583599
SimpAttributesContextHolder.resetAttributes();
584600
simpAttributes.sessionCompleted();
585601
}
@@ -592,7 +608,7 @@ private Message<byte[]> createDisconnectMessage(WebSocketSession session) {
592608
}
593609
headerAccessor.setSessionId(session.getId());
594610
headerAccessor.setSessionAttributes(session.getAttributes());
595-
headerAccessor.setUser(session.getPrincipal());
611+
headerAccessor.setUser(getUser(session));
596612
return MessageBuilder.createMessage(EMPTY_PAYLOAD, headerAccessor.getMessageHeaders());
597613
}
598614

spring-websocket/src/test/java/org/springframework/web/socket/messaging/StompSubProtocolHandlerTests.java

Lines changed: 74 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
package org.springframework.web.socket.messaging;
1818

1919
import java.io.IOException;
20+
import java.security.Principal;
2021
import java.util.ArrayList;
2122
import java.util.Arrays;
2223
import java.util.Collections;
@@ -34,6 +35,8 @@
3435
import org.springframework.context.PayloadApplicationEvent;
3536
import org.springframework.messaging.Message;
3637
import org.springframework.messaging.MessageChannel;
38+
import org.springframework.messaging.MessageHandler;
39+
import org.springframework.messaging.MessagingException;
3740
import org.springframework.messaging.simp.SimpAttributes;
3841
import org.springframework.messaging.simp.SimpAttributesContextHolder;
3942
import org.springframework.messaging.simp.SimpMessageHeaderAccessor;
@@ -56,19 +59,29 @@
5659
import org.springframework.web.socket.handler.TestWebSocketSession;
5760
import org.springframework.web.socket.sockjs.transport.SockJsSession;
5861

59-
import static org.hamcrest.Matchers.*;
60-
import static org.junit.Assert.*;
62+
import static org.hamcrest.Matchers.is;
63+
import static org.junit.Assert.assertArrayEquals;
64+
import static org.junit.Assert.assertEquals;
65+
import static org.junit.Assert.assertFalse;
66+
import static org.junit.Assert.assertNotNull;
67+
import static org.junit.Assert.assertThat;
68+
import static org.junit.Assert.assertTrue;
6169
import static org.mockito.Mockito.any;
62-
import static org.mockito.Mockito.*;
70+
import static org.mockito.Mockito.mock;
71+
import static org.mockito.Mockito.reset;
72+
import static org.mockito.Mockito.times;
73+
import static org.mockito.Mockito.verify;
74+
import static org.mockito.Mockito.verifyNoMoreInteractions;
75+
import static org.mockito.Mockito.verifyZeroInteractions;
76+
import static org.mockito.Mockito.when;
6377

6478
/**
6579
* Test fixture for {@link StompSubProtocolHandler} tests.
66-
*
6780
* @author Rossen Stoyanchev
6881
*/
6982
public class StompSubProtocolHandlerTests {
7083

71-
public static final byte[] EMPTY_PAYLOAD = new byte[0];
84+
private static final byte[] EMPTY_PAYLOAD = new byte[0];
7285

7386
private StompSubProtocolHandler protocolHandler;
7487

@@ -210,22 +223,26 @@ public void handleMessageToClientWithSimpHeartbeat() {
210223
public void handleMessageToClientWithHeartbeatSuppressingSockJsHeartbeat() throws IOException {
211224

212225
SockJsSession sockJsSession = Mockito.mock(SockJsSession.class);
226+
when(sockJsSession.getId()).thenReturn("s1");
213227
StompHeaderAccessor accessor = StompHeaderAccessor.create(StompCommand.CONNECTED);
214228
accessor.setHeartbeat(0, 10);
215229
Message<byte[]> message = MessageBuilder.createMessage(EMPTY_PAYLOAD, accessor.getMessageHeaders());
216230
this.protocolHandler.handleMessageToClient(sockJsSession, message);
217231

232+
verify(sockJsSession).getId();
218233
verify(sockJsSession).getPrincipal();
219234
verify(sockJsSession).disableHeartbeat();
220235
verify(sockJsSession).sendMessage(any(WebSocketMessage.class));
221236
verifyNoMoreInteractions(sockJsSession);
222237

223238
sockJsSession = Mockito.mock(SockJsSession.class);
239+
when(sockJsSession.getId()).thenReturn("s1");
224240
accessor = StompHeaderAccessor.create(StompCommand.CONNECTED);
225241
accessor.setHeartbeat(0, 0);
226242
message = MessageBuilder.createMessage(EMPTY_PAYLOAD, accessor.getMessageHeaders());
227243
this.protocolHandler.handleMessageToClient(sockJsSession, message);
228244

245+
verify(sockJsSession).getId();
229246
verify(sockJsSession).getPrincipal();
230247
verify(sockJsSession).sendMessage(any(WebSocketMessage.class));
231248
verifyNoMoreInteractions(sockJsSession);
@@ -352,6 +369,28 @@ public Message<?> preSend(Message<?> message, MessageChannel channel) {
352369
assertFalse(mutable.get());
353370
}
354371

372+
@Test // SPR-14690
373+
public void handleMessageFromClientWithTokenAuthentication() {
374+
ExecutorSubscribableChannel channel = new ExecutorSubscribableChannel();
375+
channel.addInterceptor(new AuthenticationInterceptor("[email protected]"));
376+
channel.addInterceptor(new ImmutableMessageChannelInterceptor());
377+
378+
TestMessageHandler messageHandler = new TestMessageHandler();
379+
channel.subscribe(messageHandler);
380+
381+
StompSubProtocolHandler handler = new StompSubProtocolHandler();
382+
handler.afterSessionStarted(this.session, channel);
383+
384+
TextMessage wsMessage = StompTextMessageBuilder.create(StompCommand.CONNECT).build();
385+
handler.handleMessageFromClient(this.session, wsMessage, channel);
386+
387+
assertEquals(1, messageHandler.getMessages().size());
388+
Message<?> message = messageHandler.getMessages().get(0);
389+
Principal user = SimpMessageHeaderAccessor.getUser(message.getHeaders());
390+
assertNotNull(user);
391+
assertEquals("[email protected]", user.getName());
392+
}
393+
355394
@Test
356395
public void handleMessageFromClientWithInvalidStompCommand() {
357396

@@ -504,4 +543,34 @@ public void publishEvent(Object event) {
504543
}
505544
}
506545

546+
private static class TestMessageHandler implements MessageHandler {
547+
548+
private final List<Message> messages = new ArrayList<>();
549+
550+
public List<Message> getMessages() {
551+
return this.messages;
552+
}
553+
554+
@Override
555+
public void handleMessage(Message<?> message) throws MessagingException {
556+
this.messages.add(message);
557+
}
558+
}
559+
560+
private static class AuthenticationInterceptor extends ChannelInterceptorAdapter {
561+
562+
private final String name;
563+
564+
565+
public AuthenticationInterceptor(String name) {
566+
this.name = name;
567+
}
568+
569+
@Override
570+
public Message<?> preSend(Message<?> message, MessageChannel channel) {
571+
TestPrincipal user = new TestPrincipal(name);
572+
MessageHeaderAccessor.getAccessor(message, StompHeaderAccessor.class).setUser(user);
573+
return message;
574+
}
575+
}
507576
}

0 commit comments

Comments
 (0)