Skip to content

Commit

Permalink
fix(webSocketRoute): allow no trailing slash in route matching
Browse files Browse the repository at this point in the history
  • Loading branch information
mxschmitt committed Dec 12, 2024
1 parent 4f2cdde commit 98e4bca
Show file tree
Hide file tree
Showing 3 changed files with 83 additions and 1 deletion.
6 changes: 5 additions & 1 deletion playwright/_impl/_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
Union,
cast,
)
from urllib.parse import urljoin
from urllib.parse import urljoin, urlparse

from playwright._impl._api_structures import NameValue
from playwright._impl._errors import (
Expand Down Expand Up @@ -157,6 +157,10 @@ def url_matches(
base_url = re.sub(r"^http", "ws", base_url)
if base_url:
match = urljoin(base_url, match)
parsed = urlparse(match)
if parsed.path == "":
parsed = parsed._replace(path="/")
match = parsed.geturl()
if isinstance(match, str):
match = glob_to_regex(match)
if isinstance(match, Pattern):
Expand Down
40 changes: 40 additions & 0 deletions tests/async/test_route_web_socket.py
Original file line number Diff line number Diff line change
Expand Up @@ -346,3 +346,43 @@ async def _handle_ws(ws: WebSocketRoute) -> None:
f"message: data=echo origin=ws://localhost:{server.PORT} lastEventId=",
],
)


async def test_should_work_with_no_trailing_slash(page: Page, server: Server) -> None:
log: list[str] = []

async def handle_ws(ws: WebSocketRoute) -> None:
def on_message(message: Union[str, bytes]) -> None:
if isinstance(message, bytes):
message = message.decode()
log.append(message)
ws.send("response")

ws.on_message(on_message)

# No trailing slash in the route pattern
await page.route_web_socket(f"ws://localhost:{server.PORT}", handle_ws)

await page.goto("about:blank")
await page.evaluate(
"""({ port }) => {
window.log = [];
// No trailing slash in WebSocket URL
window.ws = new WebSocket('ws://localhost:' + port);
window.ws.addEventListener('message', event => window.log.push(event.data));
}""",
{"port": server.PORT},
)

# Wait for WebSocket to be ready
await assert_equal(
lambda: page.evaluate("window.ws.readyState"), 1 # WebSocket.OPEN
)

await page.evaluate("window.ws.send('query')")

# Verify server received message
await assert_equal(lambda: log, ["query"])

# Verify client received response
await assert_equal(lambda: page.evaluate("window.log"), ["response"])
38 changes: 38 additions & 0 deletions tests/sync/test_route_web_socket.py
Original file line number Diff line number Diff line change
Expand Up @@ -340,3 +340,41 @@ def _handle_ws(ws: WebSocketRoute) -> None:
f"message: data=echo origin=ws://localhost:{server.PORT} lastEventId=",
],
)


def test_should_work_with_no_trailing_slash(page: Page, server: Server) -> None:
log: list[str] = []

async def handle_ws(ws: WebSocketRoute) -> None:
def on_message(message: Union[str, bytes]) -> None:
if isinstance(message, bytes):
message = message.decode()
log.append(message)
ws.send("response")

ws.on_message(on_message)

# No trailing slash in the route pattern
page.route_web_socket(f"ws://localhost:{server.PORT}", handle_ws)

page.goto("about:blank")
page.evaluate(
"""({ port }) => {
window.log = [];
// No trailing slash in WebSocket URL
window.ws = new WebSocket('ws://localhost:' + port);
window.ws.addEventListener('message', event => window.log.push(event.data));
}""",
{"port": server.PORT},
)

# Wait for WebSocket to be ready
assert_equal(lambda: page.evaluate("window.ws.readyState"), 1) # WebSocket.OPEN

page.evaluate("window.ws.send('query')")

# Verify server received message
assert_equal(lambda: log, ["query"])

# Verify client received response
assert_equal(lambda: page.evaluate("window.log"), ["response"])

0 comments on commit 98e4bca

Please sign in to comment.