|
19 | 19 | import java.io.IOException; |
20 | 20 | import java.io.InputStream; |
21 | 21 | import java.io.OutputStream; |
| 22 | +import java.net.InetAddress; |
| 23 | +import java.net.InetSocketAddress; |
22 | 24 | import java.net.URI; |
| 25 | +import java.net.UnknownHostException; |
23 | 26 | import java.time.Duration; |
24 | 27 | import java.util.ArrayList; |
| 28 | +import java.util.LinkedHashMap; |
25 | 29 | import java.util.List; |
| 30 | +import java.util.Map; |
26 | 31 | import java.util.Objects; |
| 32 | +import java.util.concurrent.Callable; |
27 | 33 | import java.util.concurrent.CountDownLatch; |
28 | 34 | import java.util.concurrent.TimeUnit; |
| 35 | +import java.util.function.Function; |
| 36 | +import java.util.stream.Collectors; |
| 37 | + |
| 38 | +import javax.websocket.ClientEndpointConfig; |
| 39 | +import javax.websocket.ClientEndpointConfig.Configurator; |
| 40 | +import javax.websocket.Endpoint; |
| 41 | +import javax.websocket.HandshakeResponse; |
| 42 | +import javax.websocket.WebSocketContainer; |
29 | 43 |
|
30 | 44 | import org.apache.tomcat.websocket.WsWebSocketContainer; |
31 | 45 | import org.awaitility.Awaitility; |
|
34 | 48 | import org.junit.jupiter.api.Disabled; |
35 | 49 | import org.junit.jupiter.api.Test; |
36 | 50 |
|
| 51 | +import org.springframework.http.HttpHeaders; |
| 52 | +import org.springframework.util.concurrent.ListenableFuture; |
37 | 53 | import org.springframework.web.client.RestTemplate; |
38 | 54 | import org.springframework.web.socket.CloseStatus; |
39 | 55 | import org.springframework.web.socket.PingMessage; |
40 | 56 | import org.springframework.web.socket.PongMessage; |
41 | 57 | import org.springframework.web.socket.TextMessage; |
| 58 | +import org.springframework.web.socket.WebSocketExtension; |
| 59 | +import org.springframework.web.socket.WebSocketHandler; |
42 | 60 | import org.springframework.web.socket.WebSocketMessage; |
43 | 61 | import org.springframework.web.socket.WebSocketSession; |
| 62 | +import org.springframework.web.socket.adapter.standard.StandardWebSocketHandlerAdapter; |
| 63 | +import org.springframework.web.socket.adapter.standard.StandardWebSocketSession; |
| 64 | +import org.springframework.web.socket.adapter.standard.WebSocketToStandardExtensionAdapter; |
44 | 65 | import org.springframework.web.socket.client.WebSocketClient; |
45 | 66 | import org.springframework.web.socket.client.standard.StandardWebSocketClient; |
46 | 67 | import org.springframework.web.socket.handler.TextWebSocketHandler; |
@@ -94,7 +115,16 @@ void triggerReload() throws Exception { |
94 | 115 | (msgs) -> msgs.size() == 2); |
95 | 116 | assertThat(messages.get(0)).contains("http://livereload.com/protocols/official-7"); |
96 | 117 | assertThat(messages.get(1)).contains("command\":\"reload\""); |
| 118 | + } |
97 | 119 |
|
| 120 | + @Test // gh-26813 |
| 121 | + void triggerReloadWithUppercaseHeaders() throws Exception { |
| 122 | + LiveReloadWebSocketHandler handler = connect(UppercaseWebSocketClient::new); |
| 123 | + this.server.triggerReload(); |
| 124 | + List<String> messages = await().atMost(Duration.ofSeconds(10)).until(handler::getMessages, |
| 125 | + (msgs) -> msgs.size() == 2); |
| 126 | + assertThat(messages.get(0)).contains("http://livereload.com/protocols/official-7"); |
| 127 | + assertThat(messages.get(1)).contains("command\":\"reload\""); |
98 | 128 | } |
99 | 129 |
|
100 | 130 | @Test |
@@ -126,7 +156,13 @@ void serverClose() throws Exception { |
126 | 156 | } |
127 | 157 |
|
128 | 158 | private LiveReloadWebSocketHandler connect() throws Exception { |
129 | | - WebSocketClient client = new StandardWebSocketClient(new WsWebSocketContainer()); |
| 159 | + return connect(StandardWebSocketClient::new); |
| 160 | + } |
| 161 | + |
| 162 | + private LiveReloadWebSocketHandler connect(Function<WebSocketContainer, WebSocketClient> clientFactory) |
| 163 | + throws Exception { |
| 164 | + WsWebSocketContainer webSocketContainer = new WsWebSocketContainer(); |
| 165 | + WebSocketClient client = clientFactory.apply(webSocketContainer); |
130 | 166 | LiveReloadWebSocketHandler handler = new LiveReloadWebSocketHandler(); |
131 | 167 | client.doHandshake(handler, "ws://localhost:" + this.port + "/livereload"); |
132 | 168 | handler.awaitHello(); |
@@ -246,4 +282,69 @@ CloseStatus getCloseStatus() { |
246 | 282 |
|
247 | 283 | } |
248 | 284 |
|
| 285 | + static class UppercaseWebSocketClient extends StandardWebSocketClient { |
| 286 | + |
| 287 | + private final WebSocketContainer webSocketContainer; |
| 288 | + |
| 289 | + UppercaseWebSocketClient(WebSocketContainer webSocketContainer) { |
| 290 | + super(webSocketContainer); |
| 291 | + this.webSocketContainer = webSocketContainer; |
| 292 | + } |
| 293 | + |
| 294 | + @Override |
| 295 | + protected ListenableFuture<WebSocketSession> doHandshakeInternal(WebSocketHandler webSocketHandler, |
| 296 | + HttpHeaders headers, URI uri, List<String> protocols, List<WebSocketExtension> extensions, |
| 297 | + Map<String, Object> attributes) { |
| 298 | + InetSocketAddress localAddress = new InetSocketAddress(getLocalHost(), uri.getPort()); |
| 299 | + InetSocketAddress remoteAddress = new InetSocketAddress(uri.getHost(), uri.getPort()); |
| 300 | + StandardWebSocketSession session = new StandardWebSocketSession(headers, attributes, localAddress, |
| 301 | + remoteAddress); |
| 302 | + ClientEndpointConfig endpointConfig = ClientEndpointConfig.Builder.create() |
| 303 | + .configurator(new UppercaseWebSocketClientConfigurator(headers)).preferredSubprotocols(protocols) |
| 304 | + .extensions(extensions.stream().map(WebSocketToStandardExtensionAdapter::new) |
| 305 | + .collect(Collectors.toList())) |
| 306 | + .build(); |
| 307 | + endpointConfig.getUserProperties().putAll(getUserProperties()); |
| 308 | + Endpoint endpoint = new StandardWebSocketHandlerAdapter(webSocketHandler, session); |
| 309 | + Callable<WebSocketSession> connectTask = () -> { |
| 310 | + this.webSocketContainer.connectToServer(endpoint, endpointConfig, uri); |
| 311 | + return session; |
| 312 | + }; |
| 313 | + return getTaskExecutor().submitListenable(connectTask); |
| 314 | + } |
| 315 | + |
| 316 | + private InetAddress getLocalHost() { |
| 317 | + try { |
| 318 | + return InetAddress.getLocalHost(); |
| 319 | + } |
| 320 | + catch (UnknownHostException ex) { |
| 321 | + return InetAddress.getLoopbackAddress(); |
| 322 | + } |
| 323 | + } |
| 324 | + |
| 325 | + } |
| 326 | + |
| 327 | + private static class UppercaseWebSocketClientConfigurator extends Configurator { |
| 328 | + |
| 329 | + private final HttpHeaders headers; |
| 330 | + |
| 331 | + UppercaseWebSocketClientConfigurator(HttpHeaders headers) { |
| 332 | + this.headers = headers; |
| 333 | + } |
| 334 | + |
| 335 | + @Override |
| 336 | + public void beforeRequest(Map<String, List<String>> requestHeaders) { |
| 337 | + Map<String, List<String>> uppercaseRequestHeaders = new LinkedHashMap<>(); |
| 338 | + requestHeaders.forEach((key, value) -> uppercaseRequestHeaders.put(key.toUpperCase(), value)); |
| 339 | + requestHeaders.clear(); |
| 340 | + requestHeaders.putAll(uppercaseRequestHeaders); |
| 341 | + requestHeaders.putAll(this.headers); |
| 342 | + } |
| 343 | + |
| 344 | + @Override |
| 345 | + public void afterResponse(HandshakeResponse response) { |
| 346 | + } |
| 347 | + |
| 348 | + } |
| 349 | + |
249 | 350 | } |
0 commit comments