diff --git a/common/repositories.bzl b/common/repositories.bzl index d0d8bc7901567..e7f65e87e49bd 100644 --- a/common/repositories.bzl +++ b/common/repositories.bzl @@ -50,8 +50,8 @@ js_library( http_archive( name = "linux_beta_firefox", - url = "https://ftp.mozilla.org/pub/firefox/releases/149.0b1/linux-x86_64/en-US/firefox-149.0b1.tar.xz", - sha256 = "f36c7db981d3098145c55b6cac5f9665066b9bb37b68531cc9dce59b72726c49", + url = "https://ftp.mozilla.org/pub/firefox/releases/148.0b15/linux-x86_64/en-US/firefox-148.0b15.tar.xz", + sha256 = "23621cf9537fd8d52d3a4e83bba48984f705facd4c10819aa07ae2531e11e2e5", build_file_content = """ load("@aspect_rules_js//js:defs.bzl", "js_library") package(default_visibility = ["//visibility:public"]) @@ -72,8 +72,8 @@ js_library( dmg_archive( name = "mac_beta_firefox", - url = "https://ftp.mozilla.org/pub/firefox/releases/149.0b1/mac/en-US/Firefox%20149.0b1.dmg", - sha256 = "3d1f4abb063b47b392af528cf7501132fc405e2a9f1eccab71913b0fdf3e538c", + url = "https://ftp.mozilla.org/pub/firefox/releases/148.0b15/mac/en-US/Firefox%20148.0b15.dmg", + sha256 = "0da5fee250eb13165dda25f9e29d238e51f9e0d56b9355b9c8746f6bc8d5c1fe", build_file_content = """ load("@aspect_rules_js//js:defs.bzl", "js_library") package(default_visibility = ["//visibility:public"]) @@ -277,8 +277,8 @@ js_library( http_archive( name = "linux_beta_chrome", - url = "https://storage.googleapis.com/chrome-for-testing-public/146.0.7680.31/linux64/chrome-linux64.zip", - sha256 = "5b1961b081f0156a1923a9d9d1bfffdf00f82e8722152c35eb5eb742d63ceeb8", + url = "https://storage.googleapis.com/chrome-for-testing-public/146.0.7680.16/linux64/chrome-linux64.zip", + sha256 = "6c3241cf5eab6b5eaed9b0b741bae799377dea26985aed08cda51fb75433218e", build_file_content = """ load("@aspect_rules_js//js:defs.bzl", "js_library") package(default_visibility = ["//visibility:public"]) @@ -298,8 +298,8 @@ js_library( ) http_archive( name = "mac_beta_chrome", - url = "https://storage.googleapis.com/chrome-for-testing-public/146.0.7680.31/mac-arm64/chrome-mac-arm64.zip", - sha256 = "207867110edc624316b18684065df4eb06b938a3fd9141790a726ab280e2640f", + url = "https://storage.googleapis.com/chrome-for-testing-public/146.0.7680.16/mac-arm64/chrome-mac-arm64.zip", + sha256 = "b39fe2de33190da209845e5d21ea44c75d66d0f4c33c5a293d8b6a259d3c4029", strip_prefix = "chrome-mac-arm64", patch_cmds = [ "mv 'Google Chrome for Testing.app' Chrome.app", @@ -319,8 +319,8 @@ js_library( ) http_archive( name = "linux_beta_chromedriver", - url = "https://storage.googleapis.com/chrome-for-testing-public/146.0.7680.31/linux64/chromedriver-linux64.zip", - sha256 = "a8c7be8669829ed697759390c8c42b4bca3f884fd20980e078129f5282dabe1a", + url = "https://storage.googleapis.com/chrome-for-testing-public/146.0.7680.16/linux64/chromedriver-linux64.zip", + sha256 = "c6927758a816f0a2f5f10609b34f74080a8c0f08feaf177a68943d8d4aae3a72", strip_prefix = "chromedriver-linux64", build_file_content = """ load("@aspect_rules_js//js:defs.bzl", "js_library") @@ -337,8 +337,8 @@ js_library( http_archive( name = "mac_beta_chromedriver", - url = "https://storage.googleapis.com/chrome-for-testing-public/146.0.7680.31/mac-arm64/chromedriver-mac-arm64.zip", - sha256 = "84c3717c0eeba663d0b8890a0fc06faa6fe158227876fc6954461730ccc81634", + url = "https://storage.googleapis.com/chrome-for-testing-public/146.0.7680.16/mac-arm64/chromedriver-mac-arm64.zip", + sha256 = "29c44a53be87fccea4a7887a7ed2b45b5812839e357e091c6a784ee17bb8da78", strip_prefix = "chromedriver-mac-arm64", build_file_content = """ load("@aspect_rules_js//js:defs.bzl", "js_library") diff --git a/java/src/org/openqa/selenium/grid/TemplateGridServerCommand.java b/java/src/org/openqa/selenium/grid/TemplateGridServerCommand.java index 8924095ff0aad..0c8e82b4b692a 100644 --- a/java/src/org/openqa/selenium/grid/TemplateGridServerCommand.java +++ b/java/src/org/openqa/selenium/grid/TemplateGridServerCommand.java @@ -18,6 +18,7 @@ package org.openqa.selenium.grid; import java.io.Closeable; +import java.net.URI; import java.util.Collections; import java.util.List; import java.util.Optional; @@ -47,7 +48,10 @@ public Server asServer(Config initialConfig) { Handlers handler = createHandlers(config); return new NettyServer( - new BaseServerOptions(config), handler.httpHandler, handler.websocketHandler) { + new BaseServerOptions(config), + handler.httpHandler, + handler.websocketHandler, + handler.tcpTunnelResolver) { @Override public void stop() { @@ -92,12 +96,23 @@ public abstract static class Handlers implements Closeable { public final BiFunction, Optional>> websocketHandler; + /** Optional resolver for direct TCP tunnel of WebSocket connections. May be null. */ + public final Function> tcpTunnelResolver; + public Handlers( HttpHandler http, BiFunction, Optional>> websocketHandler) { + this(http, websocketHandler, null); + } + + public Handlers( + HttpHandler http, + BiFunction, Optional>> websocketHandler, + Function> tcpTunnelResolver) { this.httpHandler = Require.nonNull("HTTP handler", http); this.websocketHandler = websocketHandler == null ? (str, sink) -> Optional.empty() : websocketHandler; + this.tcpTunnelResolver = tcpTunnelResolver; } @Override diff --git a/java/src/org/openqa/selenium/grid/node/ProxyNodeWebsockets.java b/java/src/org/openqa/selenium/grid/node/ProxyNodeWebsockets.java index 998811a3202b4..a2569e67e1ce2 100644 --- a/java/src/org/openqa/selenium/grid/node/ProxyNodeWebsockets.java +++ b/java/src/org/openqa/selenium/grid/node/ProxyNodeWebsockets.java @@ -251,6 +251,9 @@ private Consumer createWsEndPoint( LOG.info("Establishing connection to " + uri); AtomicBoolean connectionReleased = new AtomicBoolean(false); + // Set to true as soon as the browser signals it is closing so the send lambda can stop + // forwarding data frames without racing against the JDK WebSocket output stream being closed. + AtomicBoolean upstreamClosing = new AtomicBoolean(false); HttpClient client = clientFactory.createClient(ClientConfig.defaultConfig().baseUri(uri)); try { @@ -258,22 +261,66 @@ private Consumer createWsEndPoint( client.openSocket( new HttpRequest(GET, uri.toString()), new ForwardingListener( - node, downstream, sessionConsumer, sessionId, connectionReleased)); + node, + downstream, + sessionConsumer, + sessionId, + connectionReleased, + client, + upstreamClosing)); return (msg) -> { - try { - upstream.send(msg); - } finally { + // Fast path: once the browser has signalled close, there is no point sending further + // data frames — the JDK WebSocket output is already closing and the send would either + // be dropped or throw "Output closed". For the CloseMessage echo we skip the actual + // network write (the JDK stack handles the protocol-level echo internally when it fires + // onClose) and go straight to resource cleanup. + if (upstreamClosing.get()) { if (msg instanceof CloseMessage) { if (connectionReleased.compareAndSet(false, true)) { node.releaseConnection(sessionId); + try { + client.close(); + } catch (Exception e) { + LOG.log(Level.FINE, "Failed to close client after upstream close for " + uri, e); + } } + } else { + LOG.log(Level.FINE, "Dropping in-flight data frame for closing session " + sessionId); + } + return; + } + + // Slow path: upstream is (was) open — attempt the send and catch the narrow race where + // the browser closes between the upstreamClosing check above and the actual write. + try { + upstream.send(msg); + } catch (Exception e) { + LOG.log( + Level.FINE, + "Could not forward message to browser WebSocket for session " + + sessionId + + " (connection likely closed concurrently)", + e); + if (connectionReleased.compareAndSet(false, true)) { + node.releaseConnection(sessionId); try { client.close(); - } catch (Exception e) { - LOG.log(Level.WARNING, "Failed to shutdown the client of " + uri, e); + } catch (Exception ce) { + LOG.log(Level.FINE, "Failed to close client after send error for " + uri, ce); } } + return; + } + if (msg instanceof CloseMessage) { + if (connectionReleased.compareAndSet(false, true)) { + node.releaseConnection(sessionId); + } + try { + client.close(); + } catch (Exception e) { + LOG.log(Level.WARNING, "Failed to shutdown the client of " + uri, e); + } } }; } catch (Exception e) { @@ -289,18 +336,24 @@ private static class ForwardingListener implements WebSocket.Listener { private final Consumer sessionConsumer; private final SessionId sessionId; private final AtomicBoolean connectionReleased; + private final HttpClient client; + private final AtomicBoolean upstreamClosing; public ForwardingListener( Node node, Consumer downstream, Consumer sessionConsumer, SessionId sessionId, - AtomicBoolean connectionReleased) { + AtomicBoolean connectionReleased, + HttpClient client, + AtomicBoolean upstreamClosing) { this.node = node; this.downstream = Objects.requireNonNull(downstream); this.sessionConsumer = Objects.requireNonNull(sessionConsumer); this.sessionId = Objects.requireNonNull(sessionId); this.connectionReleased = Objects.requireNonNull(connectionReleased); + this.client = Objects.requireNonNull(client); + this.upstreamClosing = Objects.requireNonNull(upstreamClosing); } @Override @@ -311,9 +364,19 @@ public void onBinary(byte[] data) { @Override public void onClose(int code, String reason) { + // Signal the send lambda before forwarding the close downstream so that any data frames + // still queued in the Netty pipeline are discarded rather than attempted on a closing stream. + upstreamClosing.set(true); downstream.accept(new CloseMessage(code, reason)); if (connectionReleased.compareAndSet(false, true)) { node.releaseConnection(sessionId); + // Close the HttpClient eagerly so the connection slot is freed even if the client-side + // Close echo never arrives (e.g. the client dropped the TCP connection). + try { + client.close(); + } catch (Exception e) { + LOG.log(Level.FINE, "Failed to close client on upstream WebSocket close", e); + } } } @@ -325,9 +388,15 @@ public void onText(CharSequence data) { @Override public void onError(Throwable cause) { + upstreamClosing.set(true); LOG.log(Level.WARNING, "Error proxying websocket command", cause); if (connectionReleased.compareAndSet(false, true)) { node.releaseConnection(sessionId); + try { + client.close(); + } catch (Exception e) { + LOG.log(Level.FINE, "Failed to close client after WebSocket error", e); + } } } } diff --git a/java/src/org/openqa/selenium/grid/router/httpd/RouterServer.java b/java/src/org/openqa/selenium/grid/router/httpd/RouterServer.java index 6f78c69340a62..ba27c6aa76402 100644 --- a/java/src/org/openqa/selenium/grid/router/httpd/RouterServer.java +++ b/java/src/org/openqa/selenium/grid/router/httpd/RouterServer.java @@ -32,15 +32,19 @@ import java.io.Closeable; import java.io.IOException; import java.io.UncheckedIOException; +import java.net.URI; import java.net.URL; import java.time.Duration; import java.util.Collections; import java.util.Map; +import java.util.Optional; import java.util.Set; +import java.util.function.Function; import java.util.logging.Level; import java.util.logging.Logger; import java.util.stream.Stream; import org.openqa.selenium.BuildInfo; +import org.openqa.selenium.NoSuchSessionException; import org.openqa.selenium.UsernameAndPassword; import org.openqa.selenium.cli.CliCommand; import org.openqa.selenium.grid.TemplateGridServerCommand; @@ -67,6 +71,8 @@ import org.openqa.selenium.grid.sessionqueue.remote.RemoteNewSessionQueue; import org.openqa.selenium.grid.web.GridUiRoute; import org.openqa.selenium.internal.Require; +import org.openqa.selenium.remote.HttpSessionId; +import org.openqa.selenium.remote.SessionId; import org.openqa.selenium.remote.http.ClientConfig; import org.openqa.selenium.remote.http.Contents; import org.openqa.selenium.remote.http.HttpClient; @@ -183,7 +189,25 @@ protected Handlers createHandlers(Config config) { // access to it. Routable routeWithLiveness = Route.combine(route, get("/readyz").to(() -> readinessCheck)); - return new Handlers(routeWithLiveness, new ProxyWebsocketsIntoGrid(clientFactory, sessions)) { + // Resolve a request URI to the Node URI for direct TCP tunnelling of WebSocket connections. + // Falls back to ProxyWebsocketsIntoGrid (the websocketHandler) when the session is not found. + Function> tcpTunnelResolver = + uri -> + HttpSessionId.getSessionId(uri) + .map(SessionId::new) + .flatMap( + id -> { + try { + return Optional.of(sessions.getUri(id)); + } catch (NoSuchSessionException e) { + return Optional.empty(); + } + }); + + return new Handlers( + routeWithLiveness, + new ProxyWebsocketsIntoGrid(clientFactory, sessions), + tcpTunnelResolver) { @Override public void close() { router.close(); diff --git a/java/src/org/openqa/selenium/netty/server/NettyServer.java b/java/src/org/openqa/selenium/netty/server/NettyServer.java index 4423129736d12..2331eecfb1883 100644 --- a/java/src/org/openqa/selenium/netty/server/NettyServer.java +++ b/java/src/org/openqa/selenium/netty/server/NettyServer.java @@ -36,11 +36,13 @@ import java.net.BindException; import java.net.InetSocketAddress; import java.net.MalformedURLException; +import java.net.URI; import java.net.URL; import java.security.cert.CertificateException; import java.util.Optional; import java.util.function.BiFunction; import java.util.function.Consumer; +import java.util.function.Function; import javax.net.ssl.SSLException; import org.openqa.selenium.grid.server.BaseServerOptions; import org.openqa.selenium.grid.server.Server; @@ -62,6 +64,7 @@ public class NettyServer implements Server { private final BiFunction, Optional>> websocketHandler; private final SslContext sslCtx; private final boolean allowCors; + private final Function> tcpTunnelResolver; private Channel channel; @@ -73,9 +76,28 @@ public NettyServer( BaseServerOptions options, HttpHandler handler, BiFunction, Optional>> websocketHandler) { + this(options, handler, websocketHandler, null); + } + + /** + * Creates a {@link NettyServer} with an optional TCP-level tunnel resolver for WebSocket + * connections. When {@code tcpTunnelResolver} is non-null, WebSocket upgrade requests that + * contain a Selenium session ID are intercepted before the normal WebSocket handler: the Router + * opens a raw TCP connection to the resolved Node URI and bridges the two sockets directly, + * removing itself from the WebSocket data path entirely. + * + * @param tcpTunnelResolver maps a request URI to the target Node URI. Return {@link + * Optional#empty()} to fall through to the normal WebSocket handler. + */ + public NettyServer( + BaseServerOptions options, + HttpHandler handler, + BiFunction, Optional>> websocketHandler, + Function> tcpTunnelResolver) { Require.nonNull("Server options", options); Require.nonNull("Handler", handler); this.websocketHandler = Require.nonNull("Factory for websocket connections", websocketHandler); + this.tcpTunnelResolver = tcpTunnelResolver; InternalLoggerFactory.setDefaultFactory(JdkLoggerFactory.INSTANCE); @@ -155,7 +177,9 @@ public NettyServer start() { b.group(bossGroup, workerGroup) .channel(NioServerSocketChannel.class) .handler(new LoggingHandler(LogLevel.DEBUG)) - .childHandler(new SeleniumHttpInitializer(sslCtx, handler, websocketHandler, allowCors)); + .childHandler( + new SeleniumHttpInitializer( + sslCtx, handler, websocketHandler, allowCors, tcpTunnelResolver)); try { // Using a flag to avoid binding to the host, useful in environments like Docker, diff --git a/java/src/org/openqa/selenium/netty/server/SeleniumHttpInitializer.java b/java/src/org/openqa/selenium/netty/server/SeleniumHttpInitializer.java index 532b87e9f950e..19326a28ff104 100644 --- a/java/src/org/openqa/selenium/netty/server/SeleniumHttpInitializer.java +++ b/java/src/org/openqa/selenium/netty/server/SeleniumHttpInitializer.java @@ -25,9 +25,11 @@ import io.netty.handler.ssl.SslContext; import io.netty.handler.stream.ChunkedWriteHandler; import io.netty.util.AttributeKey; +import java.net.URI; import java.util.Optional; import java.util.function.BiFunction; import java.util.function.Consumer; +import java.util.function.Function; import org.openqa.selenium.internal.Require; import org.openqa.selenium.remote.http.HttpHandler; import org.openqa.selenium.remote.http.Message; @@ -40,16 +42,27 @@ class SeleniumHttpInitializer extends ChannelInitializer { private final BiFunction, Optional>> webSocketHandler; private final SslContext sslCtx; private final boolean allowCors; + private final Function> tcpTunnelResolver; SeleniumHttpInitializer( SslContext sslCtx, HttpHandler seleniumHandler, BiFunction, Optional>> webSocketHandler, boolean allowCors) { + this(sslCtx, seleniumHandler, webSocketHandler, allowCors, null); + } + + SeleniumHttpInitializer( + SslContext sslCtx, + HttpHandler seleniumHandler, + BiFunction, Optional>> webSocketHandler, + boolean allowCors, + Function> tcpTunnelResolver) { this.sslCtx = sslCtx; this.seleniumHandler = Require.nonNull("HTTP handler", seleniumHandler); this.webSocketHandler = Require.nonNull("WebSocket handler", webSocketHandler); this.allowCors = allowCors; + this.tcpTunnelResolver = tcpTunnelResolver; } @Override @@ -63,6 +76,10 @@ protected void initChannel(SocketChannel ch) { // Websocket magic ch.pipeline().addLast("ws-compression", new WebSocketServerCompressionHandler()); + // TCP tunnel intercepts WS upgrades before the normal WS handler when configured. + if (tcpTunnelResolver != null) { + ch.pipeline().addLast("tcp-tunnel", new TcpUpgradeTunnelHandler(tcpTunnelResolver)); + } ch.pipeline().addLast("ws-protocol", new WebSocketUpgradeHandler(KEY, webSocketHandler)); ch.pipeline().addLast("netty-to-se-messages", new MessageInboundConverter()); ch.pipeline().addLast("se-to-netty-messages", new MessageOutboundConverter()); diff --git a/java/src/org/openqa/selenium/netty/server/TcpTunnelHandler.java b/java/src/org/openqa/selenium/netty/server/TcpTunnelHandler.java new file mode 100644 index 0000000000000..b87b94fd9e95f --- /dev/null +++ b/java/src/org/openqa/selenium/netty/server/TcpTunnelHandler.java @@ -0,0 +1,72 @@ +// Licensed to the Software Freedom Conservancy (SFC) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The SFC licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.openqa.selenium.netty.server; + +import io.netty.channel.Channel; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelInboundHandlerAdapter; +import java.util.logging.Level; +import java.util.logging.Logger; + +/** + * Forwards every inbound {@link io.netty.buffer.ByteBuf} to a target {@link Channel}. Used on both + * ends of a transparent TCP tunnel once the WebSocket upgrade handshake has been proxied. + */ +class TcpTunnelHandler extends ChannelInboundHandlerAdapter { + + private static final Logger LOG = Logger.getLogger(TcpTunnelHandler.class.getName()); + + private final Channel target; + + TcpTunnelHandler(Channel target) { + this.target = target; + } + + @Override + public void channelRead(ChannelHandlerContext ctx, Object msg) { + target + .writeAndFlush(msg) + .addListener( + future -> { + if (!future.isSuccess()) { + LOG.log( + Level.WARNING, + "TCP tunnel write failed on " + + ctx.channel() + + " -> " + + target + + ", closing both channels", + future.cause()); + ctx.close(); + target.close(); + } + }); + } + + @Override + public void channelInactive(ChannelHandlerContext ctx) { + target.close(); + } + + @Override + public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) { + LOG.log(Level.WARNING, "TCP tunnel error, closing both channels", cause); + ctx.close(); + target.close(); + } +} diff --git a/java/src/org/openqa/selenium/netty/server/TcpUpgradeTunnelHandler.java b/java/src/org/openqa/selenium/netty/server/TcpUpgradeTunnelHandler.java new file mode 100644 index 0000000000000..985cdf247e6fe --- /dev/null +++ b/java/src/org/openqa/selenium/netty/server/TcpUpgradeTunnelHandler.java @@ -0,0 +1,330 @@ +// Licensed to the Software Freedom Conservancy (SFC) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The SFC licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.openqa.selenium.netty.server; + +import io.netty.bootstrap.Bootstrap; +import io.netty.buffer.Unpooled; +import io.netty.channel.Channel; +import io.netty.channel.ChannelFuture; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelInboundHandlerAdapter; +import io.netty.channel.ChannelInitializer; +import io.netty.channel.ChannelPipeline; +import io.netty.channel.socket.SocketChannel; +import io.netty.channel.socket.nio.NioSocketChannel; +import io.netty.handler.codec.http.DefaultFullHttpResponse; +import io.netty.handler.codec.http.DefaultHttpRequest; +import io.netty.handler.codec.http.HttpClientCodec; +import io.netty.handler.codec.http.HttpHeaderNames; +import io.netty.handler.codec.http.HttpObject; +import io.netty.handler.codec.http.HttpRequest; +import io.netty.handler.codec.http.HttpResponse; +import io.netty.handler.codec.http.HttpResponseStatus; +import io.netty.handler.codec.http.HttpVersion; +import io.netty.handler.ssl.SslContext; +import io.netty.handler.ssl.SslContextBuilder; +import io.netty.handler.ssl.util.InsecureTrustManagerFactory; +import io.netty.util.ReferenceCountUtil; +import java.net.URI; +import java.util.Optional; +import java.util.function.Function; +import java.util.logging.Level; +import java.util.logging.Logger; +import javax.net.ssl.SSLException; + +/** + * Netty handler placed in the server pipeline before {@link WebSocketUpgradeHandler}. When it sees + * an HTTP WebSocket upgrade request that carries a Selenium session ID, it resolves the Node URI + * and establishes a transparent TCP tunnel, removing the Router from the data path entirely. + * + *

If no Node URI is found for the session (or the request is not a WS upgrade), the request is + * passed to the next handler in the pipeline (falling through to {@link WebSocketUpgradeHandler}). + * + *

If the Node URI uses {@code https}, an SSL handler is added to the node-side channel so that + * the Router transparently terminates TLS with the client and re-establishes it with the Node. + * + *

If the TCP connect to the Node fails (e.g. the Node is unreachable in a Kubernetes + * port-forward topology), the original upgrade request is fired back through the pipeline so the + * normal {@link WebSocketUpgradeHandler} / {@code ProxyWebsocketsIntoGrid} path can handle it. + */ +class TcpUpgradeTunnelHandler extends ChannelInboundHandlerAdapter { + + private static final Logger LOG = Logger.getLogger(TcpUpgradeTunnelHandler.class.getName()); + + /** + * Lazily-initialised, process-wide SSL context used when connecting to HTTPS nodes. All node + * certificates are trusted because Grid nodes commonly use self-signed certificates for internal + * cluster communication. The external client↔Router TLS boundary is separate and unaffected. + */ + private static volatile SslContext clientSslContext; + + private final Function> nodeUriResolver; + + /** + * @param nodeUriResolver maps an HTTP request URI (e.g. {@code /session//bidi}) to the Node + * URI. Return {@link Optional#empty()} to fall through to the normal WS handler. + */ + TcpUpgradeTunnelHandler(Function> nodeUriResolver) { + this.nodeUriResolver = nodeUriResolver; + } + + @Override + public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception { + if (!(msg instanceof HttpRequest)) { + ctx.fireChannelRead(msg); + return; + } + + HttpRequest req = (HttpRequest) msg; + + if (!isWebSocketUpgrade(req)) { + ctx.fireChannelRead(req); + return; + } + + String uri = req.uri(); + Optional maybeNodeUri = nodeUriResolver.apply(uri); + + if (maybeNodeUri.isEmpty()) { + ctx.fireChannelRead(req); + return; + } + + URI nodeUri = maybeNodeUri.get(); + Channel clientChannel = ctx.channel(); + + // Pause client reads while connecting so we don't lose or mis-process data. + clientChannel.config().setAutoRead(false); + + boolean useTls = "https".equalsIgnoreCase(nodeUri.getScheme()); + int port = nodeUri.getPort() != -1 ? nodeUri.getPort() : (useTls ? 443 : 80); + String host = nodeUri.getHost(); + + SslContext nodeSslCtx = null; + if (useTls) { + try { + nodeSslCtx = buildClientSslContext(); + } catch (SSLException e) { + LOG.log( + Level.WARNING, + "Failed to build SSL context for HTTPS node at " + + host + + ":" + + port + + ", falling back to WebSocket handler", + e); + clientChannel.config().setAutoRead(true); + ctx.fireChannelRead(req); + return; + } + } + final SslContext finalNodeSslCtx = nodeSslCtx; + + Bootstrap bootstrap = + new Bootstrap() + .group(clientChannel.eventLoop()) + .channel(NioSocketChannel.class) + .handler( + new ChannelInitializer() { + @Override + protected void initChannel(SocketChannel ch) { + if (finalNodeSslCtx != null) { + // SSL handler must be first so the codec operates on plaintext. + ch.pipeline() + .addLast("ssl", finalNodeSslCtx.newHandler(ch.alloc(), host, port)); + } + ch.pipeline().addLast("http-codec", new HttpClientCodec()); + ch.pipeline() + .addLast( + "upgrade-handler", new NodeUpgradeResponseHandler(clientChannel, req)); + } + }); + + ChannelFuture connectFuture = bootstrap.connect(host, port); + connectFuture.addListener( + future -> { + if (!future.isSuccess()) { + // The Node is unreachable (wrong network, K8s port-forward topology, etc.). + // Re-enable reads and pass the request to the next handler so that + // ProxyWebsocketsIntoGrid can try to handle it via its own HTTP client. + LOG.log( + Level.WARNING, + "TCP tunnel connect failed for " + + host + + ":" + + port + + ", falling back to WebSocket handler", + future.cause()); + clientChannel.config().setAutoRead(true); + ctx.fireChannelRead(req); + } + // On success, NodeUpgradeResponseHandler.channelActive sends the request. + }); + } + + private static boolean isWebSocketUpgrade(HttpRequest req) { + return req.headers().containsValue(HttpHeaderNames.CONNECTION, "Upgrade", true) + && req.headers().contains(HttpHeaderNames.SEC_WEBSOCKET_VERSION); + } + + private static SslContext buildClientSslContext() throws SSLException { + if (clientSslContext == null) { + synchronized (TcpUpgradeTunnelHandler.class) { + if (clientSslContext == null) { + // InsecureTrustManagerFactory is appropriate here: Grid nodes commonly use self-signed + // certificates for intra-cluster communication, and the trust boundary that matters to + // end users is the client↔Router TLS connection, not this Router↔Node hop. + clientSslContext = + SslContextBuilder.forClient() + .trustManager(InsecureTrustManagerFactory.INSTANCE) + .build(); + } + } + } + return clientSslContext; + } + + // --------------------------------------------------------------------------- + // Inner handler attached to the node-side channel + // --------------------------------------------------------------------------- + + private static final class NodeUpgradeResponseHandler extends ChannelInboundHandlerAdapter { + + private final Channel clientChannel; + private final HttpRequest upgradeRequest; + private boolean tunnelEstablished = false; + + NodeUpgradeResponseHandler(Channel clientChannel, HttpRequest upgradeRequest) { + this.clientChannel = clientChannel; + this.upgradeRequest = upgradeRequest; + } + + @Override + public void channelActive(ChannelHandlerContext ctx) { + // Forward the original upgrade request to the Node. + DefaultHttpRequest nodeReq = + new DefaultHttpRequest( + upgradeRequest.protocolVersion(), upgradeRequest.method(), upgradeRequest.uri()); + nodeReq.headers().set(upgradeRequest.headers()); + ctx.writeAndFlush(nodeReq); + } + + @Override + public void channelRead(ChannelHandlerContext ctx, Object msg) { + try { + if (tunnelEstablished || !(msg instanceof HttpObject)) { + // Tunnel is live or not HTTP; any stale buffered data is discarded. + return; + } + + if (!(msg instanceof HttpResponse)) { + // LastHttpContent or other codec artefact before the 101 — skip. + return; + } + + HttpResponse resp = (HttpResponse) msg; + + if (resp.status().code() != 101) { + LOG.warning("Node rejected WebSocket upgrade: " + resp.status()); + ctx.close(); + clientChannel.close(); + return; + } + + tunnelEstablished = true; + Channel nodeChannel = ctx.channel(); + + // Build a proper Netty HTTP 101 response, copying all headers from the Node's response. + // Writing a DefaultFullHttpResponse goes through HttpResponseEncoder, which correctly + // encodes it, and the HttpServerKeepAliveHandler does not close the channel for 101. + DefaultFullHttpResponse clientResponse = + new DefaultFullHttpResponse( + HttpVersion.HTTP_1_1, + HttpResponseStatus.SWITCHING_PROTOCOLS, + Unpooled.EMPTY_BUFFER); + clientResponse.headers().set(resp.headers()); + + clientChannel + .writeAndFlush(clientResponse) + .addListener( + writeFuture -> { + if (!writeFuture.isSuccess()) { + LOG.log( + Level.WARNING, + "Failed to write 101 response to client", + writeFuture.cause()); + clientChannel.close(); + nodeChannel.close(); + return; + } + + // Rewire node channel: remove HTTP codec and this handler, add byte tunnel. + // The "ssl" handler (if present) is intentionally left in place — it + // transparently handles TLS framing for the raw byte stream. + nodeChannel.pipeline().remove("upgrade-handler"); + nodeChannel.pipeline().remove("http-codec"); + nodeChannel.pipeline().addLast("tunnel", new TcpTunnelHandler(clientChannel)); + + // Rewire client channel: replace the tcp-tunnel intercept handler with a raw + // byte tunnel, then strip remaining HTTP/WS handlers that are no longer needed. + ChannelPipeline cp = clientChannel.pipeline(); + cp.replace("tcp-tunnel", "tunnel", new TcpTunnelHandler(nodeChannel)); + for (String name : + new String[] { + "codec", + "keep-alive", + "chunked-write", + "ws-compression", + "ws-protocol", + "netty-to-se-messages", + "se-to-netty-messages", + "se-websocket-handler", + "se-request", + "se-response", + "se-handler" + }) { + if (cp.get(name) != null) { + cp.remove(name); + } + } + + // Re-enable reads on the client now that the tunnel is live. + clientChannel.config().setAutoRead(true); + }); + + } finally { + ReferenceCountUtil.release(msg); + } + } + + @Override + public void channelInactive(ChannelHandlerContext ctx) { + if (!tunnelEstablished) { + LOG.warning("Node channel closed before tunnel was established"); + clientChannel.close(); + } + } + + @Override + public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) { + LOG.log(Level.WARNING, "Error during node upgrade handshake", cause); + ctx.close(); + clientChannel.close(); + } + } +} diff --git a/java/test/org/openqa/selenium/grid/router/BUILD.bazel b/java/test/org/openqa/selenium/grid/router/BUILD.bazel index 4b8096bb9bd24..d8fea7a2f56c4 100644 --- a/java/test/org/openqa/selenium/grid/router/BUILD.bazel +++ b/java/test/org/openqa/selenium/grid/router/BUILD.bazel @@ -139,6 +139,10 @@ java_test_suite( "//java/src/org/openqa/selenium/firefox", "//java/src/org/openqa/selenium/grid", "//java/src/org/openqa/selenium/grid/config", + "//java/src/org/openqa/selenium/grid/distributor", + "//java/src/org/openqa/selenium/grid/distributor/local", + "//java/src/org/openqa/selenium/grid/distributor/selector", + "//java/src/org/openqa/selenium/grid/node/local", "//java/src/org/openqa/selenium/json", "//java/src/org/openqa/selenium/remote", "//java/src/org/openqa/selenium/support", diff --git a/java/test/org/openqa/selenium/grid/router/TunnelWebsocketTest.java b/java/test/org/openqa/selenium/grid/router/TunnelWebsocketTest.java new file mode 100644 index 0000000000000..cc09ef5377027 --- /dev/null +++ b/java/test/org/openqa/selenium/grid/router/TunnelWebsocketTest.java @@ -0,0 +1,635 @@ +// Licensed to the Software Freedom Conservancy (SFC) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The SFC licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.openqa.selenium.grid.router; + +import static java.util.concurrent.TimeUnit.SECONDS; +import static org.assertj.core.api.Assertions.assertThat; +import static org.openqa.selenium.remote.Dialect.W3C; +import static org.openqa.selenium.remote.http.HttpMethod.GET; + +import java.net.ServerSocket; +import java.net.URI; +import java.net.URISyntaxException; +import java.net.URL; +import java.time.Duration; +import java.time.Instant; +import java.util.Collections; +import java.util.Map; +import java.util.Optional; +import java.util.Set; +import java.util.UUID; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Function; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.openqa.selenium.ImmutableCapabilities; +import org.openqa.selenium.MutableCapabilities; +import org.openqa.selenium.NoSuchSessionException; +import org.openqa.selenium.SessionNotCreatedException; +import org.openqa.selenium.events.EventBus; +import org.openqa.selenium.events.local.GuavaEventBus; +import org.openqa.selenium.grid.config.MapConfig; +import org.openqa.selenium.grid.data.CreateSessionResponse; +import org.openqa.selenium.grid.data.DefaultSlotMatcher; +import org.openqa.selenium.grid.data.RequestId; +import org.openqa.selenium.grid.data.Session; +import org.openqa.selenium.grid.data.SessionRequest; +import org.openqa.selenium.grid.distributor.local.LocalDistributor; +import org.openqa.selenium.grid.distributor.selector.DefaultSlotSelector; +import org.openqa.selenium.grid.node.local.LocalNode; +import org.openqa.selenium.grid.security.Secret; +import org.openqa.selenium.grid.server.BaseServerOptions; +import org.openqa.selenium.grid.server.Server; +import org.openqa.selenium.grid.sessionmap.SessionMap; +import org.openqa.selenium.grid.sessionmap.local.LocalSessionMap; +import org.openqa.selenium.grid.sessionqueue.local.LocalNewSessionQueue; +import org.openqa.selenium.grid.testing.PassthroughHttpClient; +import org.openqa.selenium.grid.testing.TestSessionFactory; +import org.openqa.selenium.internal.Either; +import org.openqa.selenium.netty.server.NettyServer; +import org.openqa.selenium.remote.HttpSessionId; +import org.openqa.selenium.remote.SessionId; +import org.openqa.selenium.remote.http.BinaryMessage; +import org.openqa.selenium.remote.http.HttpClient; +import org.openqa.selenium.remote.http.HttpHandler; +import org.openqa.selenium.remote.http.HttpRequest; +import org.openqa.selenium.remote.http.HttpResponse; +import org.openqa.selenium.remote.http.TextMessage; +import org.openqa.selenium.remote.http.WebSocket; +import org.openqa.selenium.remote.tracing.DefaultTestTracer; +import org.openqa.selenium.remote.tracing.Tracer; +import org.openqa.selenium.support.ui.FluentWait; + +class TunnelWebsocketTest { + + private final HttpHandler nullHandler = req -> new HttpResponse(); + private final MapConfig emptyConfig = new MapConfig(Collections.emptyMap()); + + private Server tunnelServer; + private Server backendServer; + private SessionMap sessions; + + @BeforeEach + void setUp() { + Tracer tracer = DefaultTestTracer.createTracer(); + EventBus events = new GuavaEventBus(); + sessions = new LocalSessionMap(tracer, events); + } + + @AfterEach + void tearDown() { + if (tunnelServer != null) { + tunnelServer.stop(); + } + if (backendServer != null) { + backendServer.stop(); + } + } + + private Function> createResolver() { + return uri -> + HttpSessionId.getSessionId(uri) + .map(SessionId::new) + .flatMap( + id -> { + try { + return Optional.of(sessions.getUri(id)); + } catch (NoSuchSessionException e) { + return Optional.empty(); + } + }); + } + + private Server createEchoBackend( + String response, CountDownLatch receivedLatch, AtomicReference received) { + return new NettyServer( + new BaseServerOptions(emptyConfig), + nullHandler, + (uri, sink) -> + Optional.of( + msg -> { + if (msg instanceof TextMessage) { + received.set(((TextMessage) msg).text()); + receivedLatch.countDown(); + if (!response.isEmpty()) { + sink.accept(new TextMessage(response)); + } + } + })) + .start(); + } + + @Test + void shouldForwardTextMessageToBackend() throws URISyntaxException, InterruptedException { + AtomicReference received = new AtomicReference<>(); + CountDownLatch latch = new CountDownLatch(1); + + backendServer = createEchoBackend("", latch, received); + + SessionId id = new SessionId(UUID.randomUUID()); + sessions.add( + new Session( + id, + backendServer.getUrl().toURI(), + new ImmutableCapabilities(), + new ImmutableCapabilities(), + Instant.now())); + + tunnelServer = + new NettyServer( + new BaseServerOptions(emptyConfig), + nullHandler, + (uri, sink) -> Optional.empty(), + createResolver()) + .start(); + + HttpClient.Factory factory = HttpClient.Factory.createDefault(); + try (WebSocket socket = + factory + .createClient(tunnelServer.getUrl()) + .openSocket( + new HttpRequest(GET, "/session/" + id + "/bidi"), new WebSocket.Listener() {})) { + + socket.sendText("Hello tunnel"); + + assertThat(latch.await(5, SECONDS)).isTrue(); + assertThat(received.get()).isEqualTo("Hello tunnel"); + } + } + + @Test + void shouldForwardTextMessageFromBackendToClient() + throws URISyntaxException, InterruptedException { + backendServer = createEchoBackend("pong", new CountDownLatch(1), new AtomicReference<>()); + + SessionId id = new SessionId(UUID.randomUUID()); + sessions.add( + new Session( + id, + backendServer.getUrl().toURI(), + new ImmutableCapabilities(), + new ImmutableCapabilities(), + Instant.now())); + + tunnelServer = + new NettyServer( + new BaseServerOptions(emptyConfig), + nullHandler, + (uri, sink) -> Optional.empty(), + createResolver()) + .start(); + + HttpClient.Factory factory = HttpClient.Factory.createDefault(); + CountDownLatch latch = new CountDownLatch(1); + AtomicReference reply = new AtomicReference<>(); + + try (WebSocket socket = + factory + .createClient(tunnelServer.getUrl()) + .openSocket( + new HttpRequest(GET, "/session/" + id + "/bidi"), + new WebSocket.Listener() { + @Override + public void onText(CharSequence data) { + reply.set(data.toString()); + latch.countDown(); + } + })) { + + socket.sendText("ping"); + + assertThat(latch.await(5, SECONDS)).isTrue(); + assertThat(reply.get()).isEqualTo("pong"); + } + } + + @Test + void shouldForwardBinaryMessages() throws URISyntaxException, InterruptedException { + byte[] payload = new byte[] {1, 2, 3, 4}; + + AtomicReference received = new AtomicReference<>(); + CountDownLatch latch = new CountDownLatch(1); + + backendServer = + new NettyServer( + new BaseServerOptions(emptyConfig), + nullHandler, + (uri, sink) -> + Optional.of( + msg -> { + if (msg instanceof BinaryMessage) { + received.set(((BinaryMessage) msg).data()); + latch.countDown(); + } + })) + .start(); + + SessionId id = new SessionId(UUID.randomUUID()); + sessions.add( + new Session( + id, + backendServer.getUrl().toURI(), + new ImmutableCapabilities(), + new ImmutableCapabilities(), + Instant.now())); + + tunnelServer = + new NettyServer( + new BaseServerOptions(emptyConfig), + nullHandler, + (uri, sink) -> Optional.empty(), + createResolver()) + .start(); + + HttpClient.Factory factory = HttpClient.Factory.createDefault(); + try (WebSocket socket = + factory + .createClient(tunnelServer.getUrl()) + .openSocket( + new HttpRequest(GET, "/session/" + id + "/bidi"), new WebSocket.Listener() {})) { + + socket.sendBinary(payload); + + assertThat(latch.await(5, SECONDS)).isTrue(); + assertThat(received.get()).isEqualTo(payload); + } + } + + @Test + void shouldFallBackToWebSocketHandlerWhenSessionNotFound() { + // No session in the map — tunnel resolver returns empty, falling through to the WS handler + // which also returns empty. WebSocketUpgradeHandler responds with 400 Bad Request. + tunnelServer = + new NettyServer( + new BaseServerOptions(emptyConfig), + nullHandler, + (uri, sink) -> Optional.empty(), + createResolver()) + .start(); + + HttpClient.Factory factory = HttpClient.Factory.createDefault(); + SessionId unknownId = new SessionId(UUID.randomUUID()); + + boolean exceptionThrown = false; + try { + factory + .createClient(tunnelServer.getUrl()) + .openSocket( + new HttpRequest(GET, "/session/" + unknownId + "/bidi"), new WebSocket.Listener() {}); + } catch (Exception e) { + // Expected: connection is rejected (400) because the session is not in the map. + exceptionThrown = true; + } + assertThat(exceptionThrown).as("Expected openSocket to fail for unknown session").isTrue(); + } + + @Test + void shouldFallBackToWebSocketHandlerWhenNodeIsUnreachable() throws Exception { + // Allocate a port then immediately close the socket so nothing is listening on it. + int closedPort; + try (ServerSocket ss = new ServerSocket(0)) { + closedPort = ss.getLocalPort(); + } + + SessionId id = new SessionId(UUID.randomUUID()); + sessions.add( + new Session( + id, + new URI("http://127.0.0.1:" + closedPort), + new ImmutableCapabilities(), + new ImmutableCapabilities(), + Instant.now())); + + tunnelServer = + new NettyServer( + new BaseServerOptions(emptyConfig), + nullHandler, + (uri, sink) -> Optional.empty(), + createResolver()) + .start(); + + HttpClient.Factory factory = HttpClient.Factory.createDefault(); + boolean exceptionThrown = false; + try { + factory + .createClient(tunnelServer.getUrl()) + .openSocket( + new HttpRequest(GET, "/session/" + id + "/bidi"), new WebSocket.Listener() {}); + } catch (Exception e) { + // Expected: TCP connect fails, falls back to the WS handler (which returns empty) → 400. + // The important thing is a graceful rejection, not an abrupt channel close. + exceptionThrown = true; + } + assertThat(exceptionThrown) + .as("Expected openSocket to fail gracefully when node is unreachable") + .isTrue(); + } + + @Test + void shouldTunnelWebSocketThroughHttpsNode() throws URISyntaxException, InterruptedException { + // Start the backend with a self-signed certificate so its URL is https://. + // The tunnel handler detects the https scheme and adds a TLS handler on the node-side channel. + MapConfig httpsConfig = + new MapConfig(Map.of("server", Map.of("https-self-signed", true, "hostname", "localhost"))); + AtomicReference received = new AtomicReference<>(); + CountDownLatch latch = new CountDownLatch(1); + + backendServer = + new NettyServer( + new BaseServerOptions(httpsConfig), + nullHandler, + (uri, sink) -> + Optional.of( + msg -> { + if (msg instanceof TextMessage) { + received.set(((TextMessage) msg).text()); + latch.countDown(); + } + })) + .start(); + + // backendServer.getUrl() is now https://localhost: + SessionId id = new SessionId(UUID.randomUUID()); + sessions.add( + new Session( + id, + backendServer.getUrl().toURI(), + new ImmutableCapabilities(), + new ImmutableCapabilities(), + Instant.now())); + + tunnelServer = + new NettyServer( + new BaseServerOptions(emptyConfig), + nullHandler, + (uri, sink) -> Optional.empty(), + createResolver()) + .start(); + + HttpClient.Factory factory = HttpClient.Factory.createDefault(); + try (WebSocket socket = + factory + .createClient(tunnelServer.getUrl()) + .openSocket( + new HttpRequest(GET, "/session/" + id + "/bidi"), new WebSocket.Listener() {})) { + + socket.sendText("secure-hello"); + + assertThat(latch.await(5, SECONDS)).isTrue(); + assertThat(received.get()).isEqualTo("secure-hello"); + } + } + + @Test + void shouldSupportMultipleMessagesOnSameConnection() + throws URISyntaxException, InterruptedException { + int messageCount = 5; + CountDownLatch latch = new CountDownLatch(messageCount); + AtomicReference count = new AtomicReference<>(0); + + backendServer = + new NettyServer( + new BaseServerOptions(emptyConfig), + nullHandler, + (uri, sink) -> + Optional.of( + msg -> { + if (msg instanceof TextMessage) { + count.updateAndGet(c -> c + 1); + latch.countDown(); + } + })) + .start(); + + SessionId id = new SessionId(UUID.randomUUID()); + sessions.add( + new Session( + id, + backendServer.getUrl().toURI(), + new ImmutableCapabilities(), + new ImmutableCapabilities(), + Instant.now())); + + tunnelServer = + new NettyServer( + new BaseServerOptions(emptyConfig), + nullHandler, + (uri, sink) -> Optional.empty(), + createResolver()) + .start(); + + HttpClient.Factory factory = HttpClient.Factory.createDefault(); + try (WebSocket socket = + factory + .createClient(tunnelServer.getUrl()) + .openSocket( + new HttpRequest(GET, "/session/" + id + "/bidi"), new WebSocket.Listener() {})) { + + for (int i = 0; i < messageCount; i++) { + socket.sendText("msg-" + i); + } + + assertThat(latch.await(10, SECONDS)).isTrue(); + assertThat(count.get()).isEqualTo(messageCount); + } + } + + /** + * Integration test that exercises the full Grid session lifecycle with BiDi enabled. + * + *

Flow: client requests {@code webSocketUrl: true} → LocalDistributor → LocalNode + * (TestSessionFactory returns {@code webSocketUrl} capability pointing to the Router) → + * LocalSessionMap registration → client reads {@code webSocketUrl} from capabilities → connects + * to Router BiDi WebSocket → TCP tunnel → stub backend server. + * + *

This mirrors what a real WebDriver BiDi client does: request {@code webSocketUrl: true}, + * receive a {@code webSocketUrl} in the response capabilities, and connect to it. + */ + @Test + void shouldTunnelBiDiThroughFullGridSessionLifecycle() + throws URISyntaxException, InterruptedException { + // Stub backend — simulates a Node's BiDi WebSocket endpoint. Echoes a fixed reply. + AtomicReference received = new AtomicReference<>(); + CountDownLatch receivedLatch = new CountDownLatch(1); + AtomicReference reply = new AtomicReference<>(); + CountDownLatch replyLatch = new CountDownLatch(1); + + backendServer = + new NettyServer( + new BaseServerOptions(emptyConfig), + nullHandler, + (uri, sink) -> + Optional.of( + msg -> { + if (msg instanceof TextMessage) { + received.set(((TextMessage) msg).text()); + receivedLatch.countDown(); + sink.accept(new TextMessage("bidi-ack")); + } + })) + .start(); + + URI backendUri = backendServer.getUrl().toURI(); + + // Wire up in-process Grid components — mirrors how Standalone sets up the session path. + Tracer tracer = DefaultTestTracer.createTracer(); + GuavaEventBus bus = new GuavaEventBus(); + Secret secret = new Secret("test"); + ImmutableCapabilities stereotype = new ImmutableCapabilities("browserName", "chrome"); + + LocalSessionMap gridSessions = new LocalSessionMap(tracer, bus); + LocalNewSessionQueue queue = + new LocalNewSessionQueue( + tracer, + new DefaultSlotMatcher(), + Duration.ofSeconds(2), + Duration.ofSeconds(5), + Duration.ofSeconds(1), + secret, + 5); + + // routerUrl is set after tunnelServer starts so the TestSessionFactory can embed the Router's + // WebSocket URL in the returned webSocketUrl capability (the real Grid does the same). + AtomicReference routerUrl = new AtomicReference<>(); + + // TestSessionFactory: session URI → backendServer (so the TCP tunnel connects there). + // The returned capabilities include webSocketUrl pointing to the Router's BiDi endpoint, + // which is what a real Node would return after the Router rewrites the capability. + LocalNode node = + LocalNode.builder(tracer, bus, backendUri, backendUri, secret) + .add( + stereotype, + new TestSessionFactory( + stereotype, + (id, caps) -> { + URL rUrl = routerUrl.get(); + MutableCapabilities returnedCaps = new MutableCapabilities(caps); + returnedCaps.setCapability( + "webSocketUrl", + "ws://" + + rUrl.getHost() + + ":" + + rUrl.getPort() + + "/session/" + + id + + "/bidi"); + return new Session(id, backendUri, stereotype, returnedCaps, Instant.now()); + })) + .build(); + + LocalDistributor distributor = + new LocalDistributor( + tracer, + bus, + new PassthroughHttpClient.Factory(node), + gridSessions, + queue, + new DefaultSlotSelector(), + secret, + Duration.ofMinutes(5), + false, + Duration.ofSeconds(5), + Runtime.getRuntime().availableProcessors(), + new DefaultSlotMatcher(), + Duration.ofSeconds(30)); + distributor.add(node); + + // Wait for node capacity, then start the Router so routerUrl is known before newSession(). + new FluentWait<>(distributor) + .withTimeout(Duration.ofSeconds(5)) + .pollingEvery(Duration.ofMillis(100)) + .until(d -> d.getStatus().hasCapacity()); + + HttpClient.Factory clientFactory = HttpClient.Factory.createDefault(); + Function> tcpTunnelResolver = + uri -> + HttpSessionId.getSessionId(uri) + .map(SessionId::new) + .flatMap( + id -> { + try { + return Optional.of(gridSessions.getUri(id)); + } catch (NoSuchSessionException e) { + return Optional.empty(); + } + }); + + tunnelServer = + new NettyServer( + new BaseServerOptions(emptyConfig), + nullHandler, + new ProxyWebsocketsIntoGrid(clientFactory, gridSessions), + tcpTunnelResolver) + .start(); + routerUrl.set(tunnelServer.getUrl()); + + // Create a session with webSocketUrl: true — BiDi explicitly enabled by the client. + // LocalDistributor registers the session in gridSessions automatically. + SessionRequest sessionRequest = + new SessionRequest( + new RequestId(UUID.randomUUID()), + Instant.now(), + Set.of(W3C), + Set.of(new ImmutableCapabilities("browserName", "chrome", "webSocketUrl", true)), + Map.of(), + Map.of()); + Either result = + distributor.newSession(sessionRequest); + assertThat(result.isRight()).as("Session creation should succeed").isTrue(); + + // Read webSocketUrl from the returned capabilities — this is how a real client locates the + // BiDi endpoint, not by constructing the path manually. + // The capabilities system deserialises URL-like strings as URI objects, so avoid casting. + Object webSocketUrlCap = + result.right().getSession().getCapabilities().getCapability("webSocketUrl"); + assertThat(webSocketUrlCap).as("webSocketUrl capability must be present").isNotNull(); + + // Connect using the path from webSocketUrl (e.g. /session//bidi). + // The host/port points at the Router which uses the TCP tunnel to reach the backend. + String wsPath = new URI(webSocketUrlCap.toString()).getPath(); + + try (WebSocket socket = + clientFactory + .createClient(tunnelServer.getUrl()) + .openSocket( + new HttpRequest(GET, wsPath), + new WebSocket.Listener() { + @Override + public void onText(CharSequence data) { + reply.set(data.toString()); + replyLatch.countDown(); + } + })) { + + socket.sendText("{\"method\":\"session.new\"}"); + + // Verify client → backend direction. + assertThat(receivedLatch.await(5, SECONDS)).isTrue(); + assertThat(received.get()).isEqualTo("{\"method\":\"session.new\"}"); + + // Verify backend → client direction (the echo reply). + assertThat(replyLatch.await(5, SECONDS)).isTrue(); + assertThat(reply.get()).isEqualTo("bidi-ack"); + } + + distributor.close(); + bus.close(); + } +}