Skip to content

Commit 93fd7c6

Browse files
committed
Merge pull request #26813 from francislavoie
* pr/26813: Polish 'Make livereload websocket headers case insensitive' Make livereload websocket headers case insensitive Closes gh-26813
2 parents a1e279f + 5ca687c commit 93fd7c6

File tree

2 files changed

+113
-10
lines changed

2 files changed

+113
-10
lines changed

spring-boot-project/spring-boot-devtools/src/main/java/org/springframework/boot/devtools/livereload/Connection.java

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2012-2019 the original author or authors.
2+
* Copyright 2012-2021 the original author or authors.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -23,25 +23,28 @@
2323
import java.net.SocketTimeoutException;
2424
import java.security.MessageDigest;
2525
import java.security.NoSuchAlgorithmException;
26+
import java.util.Locale;
2627
import java.util.regex.Matcher;
2728
import java.util.regex.Pattern;
2829

2930
import org.apache.commons.logging.Log;
3031
import org.apache.commons.logging.LogFactory;
3132

3233
import org.springframework.core.log.LogMessage;
34+
import org.springframework.util.Assert;
3335
import org.springframework.util.Base64Utils;
3436

3537
/**
3638
* A {@link LiveReloadServer} connection.
3739
*
3840
* @author Phillip Webb
41+
* @author Francis Lavoie
3942
*/
4043
class Connection {
4144

4245
private static final Log logger = LogFactory.getLog(Connection.class);
4346

44-
private static final Pattern WEBSOCKET_KEY_PATTERN = Pattern.compile("^Sec-WebSocket-Key:(.*)$", Pattern.MULTILINE);
47+
private static final Pattern WEBSOCKET_KEY_PATTERN = Pattern.compile("^sec-websocket-key:(.*)$", Pattern.MULTILINE);
4548

4649
public static final String WEBSOCKET_GUID = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
4750

@@ -68,19 +71,20 @@ class Connection {
6871
this.socket = socket;
6972
this.inputStream = new ConnectionInputStream(inputStream);
7073
this.outputStream = new ConnectionOutputStream(outputStream);
71-
this.header = this.inputStream.readHeader();
72-
logger.debug(LogMessage.format("Established livereload connection [%s]", this.header));
74+
String header = this.inputStream.readHeader();
75+
logger.debug(LogMessage.format("Established livereload connection [%s]", header));
76+
this.header = header.toLowerCase(Locale.ENGLISH);
7377
}
7478

7579
/**
7680
* Run the connection.
7781
* @throws Exception in case of errors
7882
*/
7983
void run() throws Exception {
80-
if (this.header.contains("Upgrade: websocket") && this.header.contains("Sec-WebSocket-Version: 13")) {
84+
if (this.header.contains("upgrade: websocket") && this.header.contains("sec-websocket-version: 13")) {
8185
runWebSocket();
8286
}
83-
if (this.header.contains("GET /livereload.js")) {
87+
if (this.header.contains("get /livereload.js")) {
8488
this.outputStream.writeHttp(getClass().getResourceAsStream("livereload.js"), "text/javascript");
8589
}
8690
}
@@ -140,9 +144,7 @@ private void writeWebSocketFrame(Frame frame) throws IOException {
140144

141145
private String getWebsocketAcceptResponse() throws NoSuchAlgorithmException {
142146
Matcher matcher = WEBSOCKET_KEY_PATTERN.matcher(this.header);
143-
if (!matcher.find()) {
144-
throw new IllegalStateException("No Sec-WebSocket-Key");
145-
}
147+
Assert.state(matcher.find(), "No Sec-WebSocket-Key");
146148
String response = matcher.group(1).trim() + WEBSOCKET_GUID;
147149
MessageDigest messageDigest = MessageDigest.getInstance("SHA-1");
148150
messageDigest.update(response.getBytes(), 0, response.length());

spring-boot-project/spring-boot-devtools/src/test/java/org/springframework/boot/devtools/livereload/LiveReloadServerTests.java

Lines changed: 102 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,13 +19,27 @@
1919
import java.io.IOException;
2020
import java.io.InputStream;
2121
import java.io.OutputStream;
22+
import java.net.InetAddress;
23+
import java.net.InetSocketAddress;
2224
import java.net.URI;
25+
import java.net.UnknownHostException;
2326
import java.time.Duration;
2427
import java.util.ArrayList;
28+
import java.util.LinkedHashMap;
2529
import java.util.List;
30+
import java.util.Map;
2631
import java.util.Objects;
32+
import java.util.concurrent.Callable;
2733
import java.util.concurrent.CountDownLatch;
2834
import java.util.concurrent.TimeUnit;
35+
import java.util.function.Function;
36+
import java.util.stream.Collectors;
37+
38+
import javax.websocket.ClientEndpointConfig;
39+
import javax.websocket.ClientEndpointConfig.Configurator;
40+
import javax.websocket.Endpoint;
41+
import javax.websocket.HandshakeResponse;
42+
import javax.websocket.WebSocketContainer;
2943

3044
import org.apache.tomcat.websocket.WsWebSocketContainer;
3145
import org.awaitility.Awaitility;
@@ -34,13 +48,20 @@
3448
import org.junit.jupiter.api.Disabled;
3549
import org.junit.jupiter.api.Test;
3650

51+
import org.springframework.http.HttpHeaders;
52+
import org.springframework.util.concurrent.ListenableFuture;
3753
import org.springframework.web.client.RestTemplate;
3854
import org.springframework.web.socket.CloseStatus;
3955
import org.springframework.web.socket.PingMessage;
4056
import org.springframework.web.socket.PongMessage;
4157
import org.springframework.web.socket.TextMessage;
58+
import org.springframework.web.socket.WebSocketExtension;
59+
import org.springframework.web.socket.WebSocketHandler;
4260
import org.springframework.web.socket.WebSocketMessage;
4361
import org.springframework.web.socket.WebSocketSession;
62+
import org.springframework.web.socket.adapter.standard.StandardWebSocketHandlerAdapter;
63+
import org.springframework.web.socket.adapter.standard.StandardWebSocketSession;
64+
import org.springframework.web.socket.adapter.standard.WebSocketToStandardExtensionAdapter;
4465
import org.springframework.web.socket.client.WebSocketClient;
4566
import org.springframework.web.socket.client.standard.StandardWebSocketClient;
4667
import org.springframework.web.socket.handler.TextWebSocketHandler;
@@ -94,7 +115,16 @@ void triggerReload() throws Exception {
94115
(msgs) -> msgs.size() == 2);
95116
assertThat(messages.get(0)).contains("http://livereload.com/protocols/official-7");
96117
assertThat(messages.get(1)).contains("command\":\"reload\"");
118+
}
97119

120+
@Test // gh-26813
121+
void triggerReloadWithUppercaseHeaders() throws Exception {
122+
LiveReloadWebSocketHandler handler = connect(UppercaseWebSocketClient::new);
123+
this.server.triggerReload();
124+
List<String> messages = await().atMost(Duration.ofSeconds(10)).until(handler::getMessages,
125+
(msgs) -> msgs.size() == 2);
126+
assertThat(messages.get(0)).contains("http://livereload.com/protocols/official-7");
127+
assertThat(messages.get(1)).contains("command\":\"reload\"");
98128
}
99129

100130
@Test
@@ -126,7 +156,13 @@ void serverClose() throws Exception {
126156
}
127157

128158
private LiveReloadWebSocketHandler connect() throws Exception {
129-
WebSocketClient client = new StandardWebSocketClient(new WsWebSocketContainer());
159+
return connect(StandardWebSocketClient::new);
160+
}
161+
162+
private LiveReloadWebSocketHandler connect(Function<WebSocketContainer, WebSocketClient> clientFactory)
163+
throws Exception {
164+
WsWebSocketContainer webSocketContainer = new WsWebSocketContainer();
165+
WebSocketClient client = clientFactory.apply(webSocketContainer);
130166
LiveReloadWebSocketHandler handler = new LiveReloadWebSocketHandler();
131167
client.doHandshake(handler, "ws://localhost:" + this.port + "/livereload");
132168
handler.awaitHello();
@@ -246,4 +282,69 @@ CloseStatus getCloseStatus() {
246282

247283
}
248284

285+
static class UppercaseWebSocketClient extends StandardWebSocketClient {
286+
287+
private final WebSocketContainer webSocketContainer;
288+
289+
UppercaseWebSocketClient(WebSocketContainer webSocketContainer) {
290+
super(webSocketContainer);
291+
this.webSocketContainer = webSocketContainer;
292+
}
293+
294+
@Override
295+
protected ListenableFuture<WebSocketSession> doHandshakeInternal(WebSocketHandler webSocketHandler,
296+
HttpHeaders headers, URI uri, List<String> protocols, List<WebSocketExtension> extensions,
297+
Map<String, Object> attributes) {
298+
InetSocketAddress localAddress = new InetSocketAddress(getLocalHost(), uri.getPort());
299+
InetSocketAddress remoteAddress = new InetSocketAddress(uri.getHost(), uri.getPort());
300+
StandardWebSocketSession session = new StandardWebSocketSession(headers, attributes, localAddress,
301+
remoteAddress);
302+
ClientEndpointConfig endpointConfig = ClientEndpointConfig.Builder.create()
303+
.configurator(new UppercaseWebSocketClientConfigurator(headers)).preferredSubprotocols(protocols)
304+
.extensions(extensions.stream().map(WebSocketToStandardExtensionAdapter::new)
305+
.collect(Collectors.toList()))
306+
.build();
307+
endpointConfig.getUserProperties().putAll(getUserProperties());
308+
Endpoint endpoint = new StandardWebSocketHandlerAdapter(webSocketHandler, session);
309+
Callable<WebSocketSession> connectTask = () -> {
310+
this.webSocketContainer.connectToServer(endpoint, endpointConfig, uri);
311+
return session;
312+
};
313+
return getTaskExecutor().submitListenable(connectTask);
314+
}
315+
316+
private InetAddress getLocalHost() {
317+
try {
318+
return InetAddress.getLocalHost();
319+
}
320+
catch (UnknownHostException ex) {
321+
return InetAddress.getLoopbackAddress();
322+
}
323+
}
324+
325+
}
326+
327+
private static class UppercaseWebSocketClientConfigurator extends Configurator {
328+
329+
private final HttpHeaders headers;
330+
331+
UppercaseWebSocketClientConfigurator(HttpHeaders headers) {
332+
this.headers = headers;
333+
}
334+
335+
@Override
336+
public void beforeRequest(Map<String, List<String>> requestHeaders) {
337+
Map<String, List<String>> uppercaseRequestHeaders = new LinkedHashMap<>();
338+
requestHeaders.forEach((key, value) -> uppercaseRequestHeaders.put(key.toUpperCase(), value));
339+
requestHeaders.clear();
340+
requestHeaders.putAll(uppercaseRequestHeaders);
341+
requestHeaders.putAll(this.headers);
342+
}
343+
344+
@Override
345+
public void afterResponse(HandshakeResponse response) {
346+
}
347+
348+
}
349+
249350
}

0 commit comments

Comments
 (0)