Skip to content

Commit

Permalink
Improve tests for sync implementation.
Browse files Browse the repository at this point in the history
  • Loading branch information
aaugustin committed Feb 11, 2024
1 parent 9b5273c commit de768cf
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 50 deletions.
8 changes: 4 additions & 4 deletions tests/sync/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,8 @@ def assertEval(self, client, expr, value):


@contextlib.contextmanager
def run_server(ws_handler=eval_shell, host="localhost", port=0, **kwargs):
with serve(ws_handler, host, port, **kwargs) as server:
def run_server(handler=eval_shell, host="localhost", port=0, **kwargs):
with serve(handler, host, port, **kwargs) as server:
thread = threading.Thread(target=server.serve_forever)
thread.start()
try:
Expand All @@ -37,8 +37,8 @@ def run_server(ws_handler=eval_shell, host="localhost", port=0, **kwargs):


@contextlib.contextmanager
def run_unix_server(path, ws_handler=eval_shell, **kwargs):
with unix_serve(ws_handler, path, **kwargs) as server:
def run_unix_server(path, handler=eval_shell, **kwargs):
with unix_serve(handler, path, **kwargs) as server:
thread = threading.Thread(target=server.serve_forever)
thread.start()
try:
Expand Down
59 changes: 31 additions & 28 deletions tests/sync/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import threading
import unittest

from websockets.exceptions import InvalidHandshake
from websockets.exceptions import InvalidHandshake, InvalidURI
from websockets.extensions.permessage_deflate import PerMessageDeflate
from websockets.sync.client import *

Expand All @@ -25,29 +25,6 @@ def test_connection(self):
with run_client(server) as client:
self.assertEqual(client.protocol.state.name, "OPEN")

def test_connection_fails(self):
"""Client connects to server but the handshake fails."""

def remove_accept_header(self, request, response):
del response.headers["Sec-WebSocket-Accept"]

# The connection will be open for the server but failed for the client.
# Use a connection handler that exits immediately to avoid an exception.
with run_server(do_nothing, process_response=remove_accept_header) as server:
with self.assertRaises(InvalidHandshake) as raised:
with run_client(server, close_timeout=MS):
self.fail("did not raise")
self.assertEqual(
str(raised.exception),
"missing Sec-WebSocket-Accept header",
)

def test_tcp_connection_fails(self):
"""Client fails to connect to server."""
with self.assertRaises(OSError):
with run_client("ws://localhost:54321"): # invalid port
self.fail("did not raise")

def test_existing_socket(self):
"""Client connects using a pre-existing socket."""
with run_server() as server:
Expand Down Expand Up @@ -103,6 +80,35 @@ def create_connection(*args, **kwargs):
with run_client(server, create_connection=create_connection) as client:
self.assertTrue(client.create_connection_ran)

def test_invalid_uri(self):
"""Client receives an invalid URI."""
with self.assertRaises(InvalidURI):
with run_client("http://localhost"): # invalid scheme
self.fail("did not raise")

def test_tcp_connection_fails(self):
"""Client fails to connect to server."""
with self.assertRaises(OSError):
with run_client("ws://localhost:54321"): # invalid port
self.fail("did not raise")

def test_handshake_fails(self):
"""Client connects to server but the handshake fails."""

def remove_accept_header(self, request, response):
del response.headers["Sec-WebSocket-Accept"]

# The connection will be open for the server but failed for the client.
# Use a connection handler that exits immediately to avoid an exception.
with run_server(do_nothing, process_response=remove_accept_header) as server:
with self.assertRaises(InvalidHandshake) as raised:
with run_client(server, close_timeout=MS):
self.fail("did not raise")
self.assertEqual(
str(raised.exception),
"missing Sec-WebSocket-Accept header",
)

def test_timeout_during_handshake(self):
"""Client times out before receiving handshake response from server."""
gate = threading.Event()
Expand All @@ -115,10 +121,7 @@ def stall_connection(self, request):
with run_server(do_nothing, process_request=stall_connection) as server:
try:
with self.assertRaises(TimeoutError) as raised:
# While it shouldn't take 50ms to open a connection, this
# test becomes flaky in CI when setting a smaller timeout,
# even after increasing WEBSOCKETS_TESTS_TIMEOUT_FACTOR.
with run_client(server, open_timeout=5 * MS):
with run_client(server, open_timeout=2 * MS):
self.fail("did not raise")
self.assertEqual(
str(raised.exception),
Expand Down
36 changes: 18 additions & 18 deletions tests/sync/test_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,21 +39,6 @@ def test_connection(self):
with run_client(server) as client:
self.assertEval(client, "ws.protocol.state.name", "OPEN")

def test_connection_fails(self):
"""Server receives connection from client but the handshake fails."""

def remove_key_header(self, request):
del request.headers["Sec-WebSocket-Key"]

with run_server(process_request=remove_key_header) as server:
with self.assertRaises(InvalidStatus) as raised:
with run_client(server):
self.fail("did not raise")
self.assertEqual(
str(raised.exception),
"server rejected WebSocket connection: HTTP 400",
)

def test_connection_handler_returns(self):
"""Connection handler returns."""
with run_server(do_nothing) as server:
Expand Down Expand Up @@ -81,8 +66,8 @@ def test_existing_socket(self):
"""Server receives connection using a pre-existing socket."""
with socket.create_server(("localhost", 0)) as sock:
with run_server(sock=sock):
# Build WebSocket URI to ensure we connect to the right socket.
with run_client("ws://{}:{}/".format(*sock.getsockname())) as client:
uri = "ws://{}:{}/".format(*sock.getsockname())
with run_client(uri) as client:
self.assertEval(client, "ws.protocol.state.name", "OPEN")

def test_select_subprotocol(self):
Expand Down Expand Up @@ -185,7 +170,7 @@ def process_response(ws, request, response):
self.assertEval(client, "ws.process_response_ran", "True")

def test_process_response_override_response(self):
"""Server runs process_response after processing the handshake."""
"""Server runs process_response and overrides the handshake response."""

def process_response(ws, request, response):
headers = response.headers.copy()
Expand Down Expand Up @@ -253,6 +238,21 @@ def create_connection(*args, **kwargs):
with run_client(server) as client:
self.assertEval(client, "ws.create_connection_ran", "True")

def test_handshake_fails(self):
"""Server receives connection from client but the handshake fails."""

def remove_key_header(self, request):
del request.headers["Sec-WebSocket-Key"]

with run_server(process_request=remove_key_header) as server:
with self.assertRaises(InvalidStatus) as raised:
with run_client(server):
self.fail("did not raise")
self.assertEqual(
str(raised.exception),
"server rejected WebSocket connection: HTTP 400",
)

def test_timeout_during_handshake(self):
"""Server times out before receiving handshake request from client."""
with run_server(open_timeout=MS) as server:
Expand Down

0 comments on commit de768cf

Please sign in to comment.