Skip to content

Commit

Permalink
Support ws:// to wss:// redirects.
Browse files Browse the repository at this point in the history
Fix #1454.
  • Loading branch information
aaugustin committed Jul 21, 2024
1 parent 650d08c commit e05f6dc
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 5 deletions.
14 changes: 10 additions & 4 deletions src/websockets/legacy/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -558,21 +558,27 @@ def handle_redirect(self, uri: str) -> None:
raise SecurityError("redirect from WSS to WS")

same_origin = (
old_wsuri.host == new_wsuri.host and old_wsuri.port == new_wsuri.port
old_wsuri.secure == new_wsuri.secure
and old_wsuri.host == new_wsuri.host
and old_wsuri.port == new_wsuri.port
)

# Rewrite the host and port arguments for cross-origin redirects.
# Rewrite secure, host, and port for cross-origin redirects.
# This preserves connection overrides with the host and port
# arguments if the redirect points to the same host and port.
if not same_origin:
# Replace the host and port argument passed to the protocol factory.
factory = self._create_connection.args[0]
# Support TLS upgrade.
if not old_wsuri.secure and new_wsuri.secure:
factory.keywords["secure"] = True
self._create_connection.keywords.setdefault("ssl", True)
# Replace secure, host, and port arguments of the protocol factory.
factory = functools.partial(
factory.func,
*factory.args,
**dict(factory.keywords, host=new_wsuri.host, port=new_wsuri.port),
)
# Replace the host and port argument passed to create_connection.
# Replace secure, host, and port arguments of create_connection.
self._create_connection = functools.partial(
self._create_connection.func,
*(factory, new_wsuri.host, new_wsuri.port),
Expand Down
13 changes: 12 additions & 1 deletion tests/legacy/test_client_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,8 @@ async def redirect_request(path, headers, test, status):
location = "/"
elif path == "/infinite":
location = get_server_uri(test.server, test.secure, "/infinite")
elif path == "/force_secure":
location = get_server_uri(test.server, True, "/")
elif path == "/force_insecure":
location = get_server_uri(test.server, False, "/")
elif path == "/missing_location":
Expand Down Expand Up @@ -1290,7 +1292,16 @@ def test_connection_error_during_closing_handshake(self, close):
class ClientServerTests(
CommonClientServerTests, ClientServerTestsMixin, AsyncioTestCase
):
pass

def test_redirect_secure(self):
with temp_test_redirecting_server(self):
# websockets doesn't support serving non-TLS and TLS connections
# from the same server and this test suite makes it difficult to
# run two servers. Therefore, we expect the redirect to create a
# TLS client connection to a non-TLS server, which will fail.
with self.assertRaises(ssl.SSLError):
with self.temp_client("/force_secure"):
self.fail("did not raise")


class SecureClientServerTests(
Expand Down

0 comments on commit e05f6dc

Please sign in to comment.