1616
1717package org .springframework .web .socket .messaging ;
1818
19+ import java .io .IOException ;
1920import java .util .ArrayList ;
2021import java .util .Arrays ;
2122import java .util .HashSet ;
2425import java .util .Set ;
2526import java .util .TreeMap ;
2627import java .util .concurrent .ConcurrentHashMap ;
28+ import java .util .concurrent .locks .ReentrantLock ;
2729
2830import org .apache .commons .logging .Log ;
2931import org .apache .commons .logging .LogFactory ;
6466public class SubProtocolWebSocketHandler implements WebSocketHandler ,
6567 SubProtocolCapable , MessageHandler , SmartLifecycle {
6668
69+ /**
70+ * Sessions connected to this handler use a sub-protocol. Hence we expect to
71+ * receive some client messages. If we don't receive any within a minute, the
72+ * connection isn't doing well (proxy issue, slow network?) and can be closed.
73+ * @see #checkSessions()
74+ */
75+ private final int TIME_TO_FIRST_MESSAGE = 60 * 1000 ;
76+
77+
6778 private final Log logger = LogFactory .getLog (SubProtocolWebSocketHandler .class );
6879
80+
6981 private final MessageChannel clientInboundChannel ;
7082
7183 private final SubscribableChannel clientOutboundChannel ;
@@ -75,12 +87,16 @@ public class SubProtocolWebSocketHandler implements WebSocketHandler,
7587
7688 private SubProtocolHandler defaultProtocolHandler ;
7789
78- private final Map <String , WebSocketSession > sessions = new ConcurrentHashMap <String , WebSocketSession >();
90+ private final Map <String , WebSocketSessionHolder > sessions = new ConcurrentHashMap <String , WebSocketSessionHolder >();
7991
8092 private int sendTimeLimit = 10 * 1000 ;
8193
8294 private int sendBufferSizeLimit = 512 * 1024 ;
8395
96+ private volatile long lastSessionCheckTime = System .currentTimeMillis ();
97+
98+ private final ReentrantLock sessionCheckLock = new ReentrantLock ();
99+
84100 private final Object lifecycleMonitor = new Object ();
85101
86102 private volatile boolean running = false ;
@@ -214,12 +230,12 @@ public final void stop() {
214230 this .clientOutboundChannel .unsubscribe (this );
215231
216232 // Notify sessions to stop flushing messages
217- for (WebSocketSession session : this .sessions .values ()) {
233+ for (WebSocketSessionHolder holder : this .sessions .values ()) {
218234 try {
219- session .close (CloseStatus .GOING_AWAY );
235+ holder . getSession () .close (CloseStatus .GOING_AWAY );
220236 }
221237 catch (Throwable t ) {
222- logger .error ("Failed to close session id '" + session . getId () + "': " + t .getMessage ());
238+ logger .error ("Failed to close '" + holder . getSession () + "': " + t .getMessage ());
223239 }
224240 }
225241 }
@@ -235,15 +251,11 @@ public final void stop(Runnable callback) {
235251
236252 @ Override
237253 public void afterConnectionEstablished (WebSocketSession session ) throws Exception {
238-
239254 session = new ConcurrentWebSocketSessionDecorator (session , getSendTimeLimit (), getSendBufferSizeLimit ());
240-
241- this .sessions .put (session .getId (), session );
255+ this .sessions .put (session .getId (), new WebSocketSessionHolder (session ));
242256 if (logger .isDebugEnabled ()) {
243- logger .debug ("Started WebSocket session=" + session .getId () +
244- ", number of sessions=" + this .sessions .size ());
257+ logger .debug ("Started session " + session .getId () + ", number of sessions=" + this .sessions .size ());
245258 }
246-
247259 findProtocolHandler (session ).afterSessionStarted (session , this .clientInboundChannel );
248260 }
249261
@@ -283,41 +295,49 @@ protected final SubProtocolHandler findProtocolHandler(WebSocketSession session)
283295
284296 @ Override
285297 public void handleMessage (WebSocketSession session , WebSocketMessage <?> message ) throws Exception {
286- findProtocolHandler (session ).handleMessageFromClient (session , message , this .clientInboundChannel );
298+ SubProtocolHandler protocolHandler = findProtocolHandler (session );
299+ protocolHandler .handleMessageFromClient (session , message , this .clientInboundChannel );
300+ WebSocketSessionHolder holder = this .sessions .get (session .getId ());
301+ if (holder != null ) {
302+ holder .setHasHandledMessages ();
303+ }
304+ else {
305+ // Should never happen
306+ throw new IllegalStateException ("Session not found: " + session );
307+ }
308+ checkSessions ();
287309 }
288310
289311 @ Override
290312 public void handleMessage (Message <?> message ) throws MessagingException {
291-
292313 String sessionId = resolveSessionId (message );
293314 if (sessionId == null ) {
294315 logger .error ("sessionId not found in message " + message );
295316 return ;
296317 }
297-
298- WebSocketSession session = this .sessions .get (sessionId );
299- if (session == null ) {
318+ WebSocketSessionHolder holder = this .sessions .get (sessionId );
319+ if (holder == null ) {
300320 logger .error ("Session not found for session with id '" + sessionId + "', ignoring message " + message );
301321 return ;
302322 }
303-
323+ WebSocketSession session = holder . getSession ();
304324 try {
305325 findProtocolHandler (session ).handleMessageToClient (session , message );
306326 }
307327 catch (SessionLimitExceededException ex ) {
308328 try {
309- logger .error ("Terminating session id '" + sessionId + "'" , ex );
329+ logger .error ("Terminating '" + session + "'" , ex );
310330
311331 // Session may be unresponsive so clear first
312332 clearSession (session , ex .getStatus ());
313333 session .close (ex .getStatus ());
314334 }
315335 catch (Exception secondException ) {
316- logger .error ("Exception terminating session id '" + sessionId + "'" , secondException );
336+ logger .error ("Exception terminating '" + sessionId + "'" , secondException );
317337 }
318338 }
319339 catch (Exception e ) {
320- logger .error ("Failed to send message to client " + message , e );
340+ logger .error ("Failed to send message to client " + message + " in " + session , e );
321341 }
322342 }
323343
@@ -337,6 +357,43 @@ private String resolveSessionId(Message<?> message) {
337357 return null ;
338358 }
339359
360+ /**
361+ * Periodically check sessions to ensure they have received at least one
362+ * message or otherwise close them.
363+ */
364+ private void checkSessions () throws IOException {
365+ long currentTime = System .currentTimeMillis ();
366+ if (!isRunning () && currentTime - this .lastSessionCheckTime < TIME_TO_FIRST_MESSAGE ) {
367+ return ;
368+ }
369+ try {
370+ if (this .sessionCheckLock .tryLock ()) {
371+ for (WebSocketSessionHolder holder : this .sessions .values ()) {
372+ if (holder .hasHandledMessages ()) {
373+ continue ;
374+ }
375+ long timeSinceCreated = currentTime - holder .getCreateTime ();
376+ if (holder .hasHandledMessages () || timeSinceCreated < TIME_TO_FIRST_MESSAGE ) {
377+ continue ;
378+ }
379+ WebSocketSession session = holder .getSession ();
380+ if (logger .isErrorEnabled ()) {
381+ logger .error ("No messages received after " + timeSinceCreated + " ms. Closing " + holder );
382+ }
383+ try {
384+ session .close (CloseStatus .PROTOCOL_ERROR );
385+ }
386+ catch (Throwable t ) {
387+ logger .error ("Failed to close " + session , t );
388+ }
389+ }
390+ }
391+ }
392+ finally {
393+ this .sessionCheckLock .unlock ();
394+ }
395+ }
396+
340397 @ Override
341398 public void handleTransportError (WebSocketSession session , Throwable exception ) throws Exception {
342399 }
@@ -356,4 +413,45 @@ public boolean supportsPartialMessages() {
356413 return false ;
357414 }
358415
416+
417+ private static class WebSocketSessionHolder {
418+
419+ private final WebSocketSession session ;
420+
421+ private final long createTime = System .currentTimeMillis ();
422+
423+ private volatile boolean handledMessages ;
424+
425+
426+ private WebSocketSessionHolder (WebSocketSession session ) {
427+ this .session = session ;
428+ }
429+
430+ public WebSocketSession getSession () {
431+ return this .session ;
432+ }
433+
434+ public long getCreateTime () {
435+ return this .createTime ;
436+ }
437+
438+ public void setHasHandledMessages () {
439+ this .handledMessages = true ;
440+ }
441+
442+ public boolean hasHandledMessages () {
443+ return this .handledMessages ;
444+ }
445+
446+ @ Override
447+ public String toString () {
448+ if (this .session instanceof ConcurrentWebSocketSessionDecorator ) {
449+ return ((ConcurrentWebSocketSessionDecorator ) this .session ).getLastSession ().toString ();
450+ }
451+ else {
452+ return this .session .toString ();
453+ }
454+ }
455+ }
456+
359457}
0 commit comments