Skip to content

Commit

Permalink
Standardize style for testing exceptions.
Browse files Browse the repository at this point in the history
  • Loading branch information
aaugustin committed Jan 21, 2024
1 parent e21811e commit 45d8de7
Show file tree
Hide file tree
Showing 6 changed files with 265 additions and 157 deletions.
19 changes: 12 additions & 7 deletions tests/legacy/test_client_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -1331,20 +1331,24 @@ def test_checking_origin_succeeds(self):

@with_server(origins=["http://localhost"])
def test_checking_origin_fails(self):
with self.assertRaisesRegex(
InvalidHandshake, "server rejected WebSocket connection: HTTP 403"
):
with self.assertRaises(InvalidHandshake) as raised:
self.start_client(origin="http://otherhost")
self.assertEqual(
str(raised.exception),
"server rejected WebSocket connection: HTTP 403",
)

@with_server(origins=["http://localhost"])
def test_checking_origins_fails_with_multiple_headers(self):
with self.assertRaisesRegex(
InvalidHandshake, "server rejected WebSocket connection: HTTP 400"
):
with self.assertRaises(InvalidHandshake) as raised:
self.start_client(
origin="http://localhost",
extra_headers=[("Origin", "http://otherhost")],
)
self.assertEqual(
str(raised.exception),
"server rejected WebSocket connection: HTTP 400",
)

@with_server(origins=[None])
@with_client()
Expand Down Expand Up @@ -1574,8 +1578,9 @@ async def run_client():
pass # work around bug in coverage

with self.assertLogs("websockets", logging.INFO) as logs:
with self.assertRaisesRegex(Exception, "BOOM"):
with self.assertRaises(Exception) as raised:
self.loop.run_until_complete(run_client())
self.assertEqual(str(raised.exception), "BOOM")

# Iteration 1
self.assertEqual(
Expand Down
76 changes: 60 additions & 16 deletions tests/legacy/test_http.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,30 +31,48 @@ async def test_read_request(self):

async def test_read_request_empty(self):
self.stream.feed_eof()
with self.assertRaisesRegex(
EOFError, "connection closed while reading HTTP request line"
):
with self.assertRaises(EOFError) as raised:
await read_request(self.stream)
self.assertEqual(
str(raised.exception),
"connection closed while reading HTTP request line",
)

async def test_read_request_invalid_request_line(self):
self.stream.feed_data(b"GET /\r\n\r\n")
with self.assertRaisesRegex(ValueError, "invalid HTTP request line: GET /"):
with self.assertRaises(ValueError) as raised:
await read_request(self.stream)
self.assertEqual(
str(raised.exception),
"invalid HTTP request line: GET /",
)

async def test_read_request_unsupported_method(self):
self.stream.feed_data(b"OPTIONS * HTTP/1.1\r\n\r\n")
with self.assertRaisesRegex(ValueError, "unsupported HTTP method: OPTIONS"):
with self.assertRaises(ValueError) as raised:
await read_request(self.stream)
self.assertEqual(
str(raised.exception),
"unsupported HTTP method: OPTIONS",
)

async def test_read_request_unsupported_version(self):
self.stream.feed_data(b"GET /chat HTTP/1.0\r\n\r\n")
with self.assertRaisesRegex(ValueError, "unsupported HTTP version: HTTP/1.0"):
with self.assertRaises(ValueError) as raised:
await read_request(self.stream)
self.assertEqual(
str(raised.exception),
"unsupported HTTP version: HTTP/1.0",
)

async def test_read_request_invalid_header(self):
self.stream.feed_data(b"GET /chat HTTP/1.1\r\nOops\r\n")
with self.assertRaisesRegex(ValueError, "invalid HTTP header line: Oops"):
with self.assertRaises(ValueError) as raised:
await read_request(self.stream)
self.assertEqual(
str(raised.exception),
"invalid HTTP header line: Oops",
)

async def test_read_response(self):
# Example from the protocol overview in RFC 6455
Expand All @@ -73,40 +91,66 @@ async def test_read_response(self):

async def test_read_response_empty(self):
self.stream.feed_eof()
with self.assertRaisesRegex(
EOFError, "connection closed while reading HTTP status line"
):
with self.assertRaises(EOFError) as raised:
await read_response(self.stream)
self.assertEqual(
str(raised.exception),
"connection closed while reading HTTP status line",
)

async def test_read_request_invalid_status_line(self):
self.stream.feed_data(b"Hello!\r\n")
with self.assertRaisesRegex(ValueError, "invalid HTTP status line: Hello!"):
with self.assertRaises(ValueError) as raised:
await read_response(self.stream)
self.assertEqual(
str(raised.exception),
"invalid HTTP status line: Hello!",
)

async def test_read_response_unsupported_version(self):
self.stream.feed_data(b"HTTP/1.0 400 Bad Request\r\n\r\n")
with self.assertRaisesRegex(ValueError, "unsupported HTTP version: HTTP/1.0"):
with self.assertRaises(ValueError) as raised:
await read_response(self.stream)
self.assertEqual(
str(raised.exception),
"unsupported HTTP version: HTTP/1.0",
)

async def test_read_response_invalid_status(self):
self.stream.feed_data(b"HTTP/1.1 OMG WTF\r\n\r\n")
with self.assertRaisesRegex(ValueError, "invalid HTTP status code: OMG"):
with self.assertRaises(ValueError) as raised:
await read_response(self.stream)
self.assertEqual(
str(raised.exception),
"invalid HTTP status code: OMG",
)

async def test_read_response_unsupported_status(self):
self.stream.feed_data(b"HTTP/1.1 007 My name is Bond\r\n\r\n")
with self.assertRaisesRegex(ValueError, "unsupported HTTP status code: 007"):
with self.assertRaises(ValueError) as raised:
await read_response(self.stream)
self.assertEqual(
str(raised.exception),
"unsupported HTTP status code: 007",
)

async def test_read_response_invalid_reason(self):
self.stream.feed_data(b"HTTP/1.1 200 \x7f\r\n\r\n")
with self.assertRaisesRegex(ValueError, "invalid HTTP reason phrase: \\x7f"):
with self.assertRaises(ValueError) as raised:
await read_response(self.stream)
self.assertEqual(
str(raised.exception),
"invalid HTTP reason phrase: \x7f",
)

async def test_read_response_invalid_header(self):
self.stream.feed_data(b"HTTP/1.1 500 Internal Server Error\r\nOops\r\n")
with self.assertRaisesRegex(ValueError, "invalid HTTP header line: Oops"):
with self.assertRaises(ValueError) as raised:
await read_response(self.stream)
self.assertEqual(
str(raised.exception),
"invalid HTTP header line: Oops",
)

async def test_header_name(self):
self.stream.feed_data(b"foo bar: baz qux\r\n\r\n")
Expand Down
90 changes: 50 additions & 40 deletions tests/sync/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,13 @@ def remove_accept_header(self, request, response):
# The connection will be open for the server but failed for the client.
# Use a connection handler that exits immediately to avoid an exception.
with run_server(do_nothing, process_response=remove_accept_header) as server:
with self.assertRaisesRegex(
InvalidHandshake,
"missing Sec-WebSocket-Accept header",
):
with self.assertRaises(InvalidHandshake) as raised:
with run_client(server, close_timeout=MS):
self.fail("did not raise")
self.assertEqual(
str(raised.exception),
"missing Sec-WebSocket-Accept header",
)

def test_tcp_connection_fails(self):
"""Client fails to connect to server."""
Expand Down Expand Up @@ -107,15 +108,16 @@ def stall_connection(self, request):
# Use a connection handler that exits immediately to avoid an exception.
with run_server(do_nothing, process_request=stall_connection) as server:
try:
with self.assertRaisesRegex(
TimeoutError,
"timed out during handshake",
):
with self.assertRaises(TimeoutError) as raised:
# While it shouldn't take 50ms to open a connection, this
# test becomes flaky in CI when setting a smaller timeout,
# even after increasing WEBSOCKETS_TESTS_TIMEOUT_FACTOR.
with run_client(server, open_timeout=5 * MS):
self.fail("did not raise")
self.assertEqual(
str(raised.exception),
"timed out during handshake",
)
finally:
gate.set()

Expand All @@ -126,12 +128,13 @@ def close_connection(self, request):
self.close_socket()

with run_server(process_request=close_connection) as server:
with self.assertRaisesRegex(
ConnectionError,
"connection closed during handshake",
):
with self.assertRaises(ConnectionError) as raised:
with run_client(server):
self.fail("did not raise")
self.assertEqual(
str(raised.exception),
"connection closed during handshake",
)


class SecureClientTests(unittest.TestCase):
Expand Down Expand Up @@ -167,24 +170,26 @@ def test_set_server_hostname_explicitly(self):
def test_reject_invalid_server_certificate(self):
"""Client rejects certificate where server certificate isn't trusted."""
with run_server(ssl=SERVER_CONTEXT) as server:
with self.assertRaisesRegex(
ssl.SSLCertVerificationError,
r"certificate verify failed: self[ -]signed certificate",
):
with self.assertRaises(ssl.SSLCertVerificationError) as raised:
# The test certificate isn't trusted system-wide.
with run_client(server, secure=True):
self.fail("did not raise")
self.assertIn(
"certificate verify failed: self signed certificate",
str(raised.exception).replace("-", " "),
)

def test_reject_invalid_server_hostname(self):
"""Client rejects certificate where server hostname doesn't match."""
with run_server(ssl=SERVER_CONTEXT) as server:
with self.assertRaisesRegex(
ssl.SSLCertVerificationError,
r"certificate verify failed: Hostname mismatch",
):
with self.assertRaises(ssl.SSLCertVerificationError) as raised:
# This hostname isn't included in the test certificate.
with run_client(server, ssl=CLIENT_CONTEXT, server_hostname="invalid"):
self.fail("did not raise")
self.assertIn(
"certificate verify failed: Hostname mismatch",
str(raised.exception),
)


@unittest.skipUnless(hasattr(socket, "AF_UNIX"), "this test requires Unix sockets")
Expand Down Expand Up @@ -231,45 +236,50 @@ def test_set_server_hostname(self):
class ClientUsageErrorsTests(unittest.TestCase):
def test_ssl_without_secure_uri(self):
"""Client rejects ssl when URI isn't secure."""
with self.assertRaisesRegex(
TypeError,
"ssl argument is incompatible with a ws:// URI",
):
with self.assertRaises(TypeError) as raised:
connect("ws://localhost/", ssl=CLIENT_CONTEXT)
self.assertEqual(
str(raised.exception),
"ssl argument is incompatible with a ws:// URI",
)

def test_unix_without_path_or_sock(self):
"""Unix client requires path when sock isn't provided."""
with self.assertRaisesRegex(
TypeError,
"missing path argument",
):
with self.assertRaises(TypeError) as raised:
unix_connect()
self.assertEqual(
str(raised.exception),
"missing path argument",
)

def test_unix_with_path_and_sock(self):
"""Unix client rejects path when sock is provided."""
sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
self.addCleanup(sock.close)
with self.assertRaisesRegex(
TypeError,
"path and sock arguments are incompatible",
):
with self.assertRaises(TypeError) as raised:
unix_connect(path="/", sock=sock)
self.assertEqual(
str(raised.exception),
"path and sock arguments are incompatible",
)

def test_invalid_subprotocol(self):
"""Client rejects single value of subprotocols."""
with self.assertRaisesRegex(
TypeError,
"subprotocols must be a list",
):
with self.assertRaises(TypeError) as raised:
connect("ws://localhost/", subprotocols="chat")
self.assertEqual(
str(raised.exception),
"subprotocols must be a list, not a str",
)

def test_unsupported_compression(self):
"""Client rejects incorrect value of compression."""
with self.assertRaisesRegex(
ValueError,
"unsupported compression: False",
):
with self.assertRaises(ValueError) as raised:
connect("ws://localhost/", compression=False)
self.assertEqual(
str(raised.exception),
"unsupported compression: False",
)


class BackwardsCompatibilityTests(DeprecationTestCase):
Expand Down
Loading

0 comments on commit 45d8de7

Please sign in to comment.