Skip to content

Commit

Permalink
Reintroduce InvalidMessage.
Browse files Browse the repository at this point in the history
This improves compatibility with the legacy implementation and clarifies
error reporting.

Fix #1548.
  • Loading branch information
aaugustin committed Nov 17, 2024
1 parent d8891a1 commit 59d4dcf
Show file tree
Hide file tree
Showing 18 changed files with 136 additions and 40 deletions.
8 changes: 8 additions & 0 deletions docs/project/changelog.rst
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,14 @@ notice.

*In development*

Bug fixes
.........

* Wrapped errors when reading the opening handshake request or response in
:exc:`~exceptions.InvalidMessage` so that :func:`~asyncio.client.connect`
raises :exc:`~exceptions.InvalidHandshake` or a subclass when the opening
handshake fails.

.. _14.1:

14.1
Expand Down
4 changes: 2 additions & 2 deletions docs/reference/exceptions.rst
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@ also reported by :func:`~websockets.asyncio.server.serve` in logs.

.. autoexception:: InvalidHandshake

.. autoexception:: InvalidMessage

.. autoexception:: SecurityError

.. autoexception:: InvalidStatus
Expand Down Expand Up @@ -74,8 +76,6 @@ Legacy exceptions

These exceptions are only used by the legacy :mod:`asyncio` implementation.

.. autoexception:: InvalidMessage

.. autoexception:: InvalidStatusCode

.. autoexception:: AbortHandshake
Expand Down
4 changes: 3 additions & 1 deletion src/websockets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
"InvalidHeader",
"InvalidHeaderFormat",
"InvalidHeaderValue",
"InvalidMessage",
"InvalidOrigin",
"InvalidParameterName",
"InvalidParameterValue",
Expand Down Expand Up @@ -71,6 +72,7 @@
InvalidHeader,
InvalidHeaderFormat,
InvalidHeaderValue,
InvalidMessage,
InvalidOrigin,
InvalidParameterName,
InvalidParameterValue,
Expand Down Expand Up @@ -122,6 +124,7 @@
"InvalidHeader": ".exceptions",
"InvalidHeaderFormat": ".exceptions",
"InvalidHeaderValue": ".exceptions",
"InvalidMessage": ".exceptions",
"InvalidOrigin": ".exceptions",
"InvalidParameterName": ".exceptions",
"InvalidParameterValue": ".exceptions",
Expand Down Expand Up @@ -159,7 +162,6 @@
"WebSocketClientProtocol": ".legacy.client",
# .legacy.exceptions
"AbortHandshake": ".legacy.exceptions",
"InvalidMessage": ".legacy.exceptions",
"InvalidStatusCode": ".legacy.exceptions",
"RedirectHandshake": ".legacy.exceptions",
"WebSocketProtocolError": ".legacy.exceptions",
Expand Down
6 changes: 4 additions & 2 deletions src/websockets/asyncio/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

from ..client import ClientProtocol, backoff
from ..datastructures import HeadersLike
from ..exceptions import InvalidStatus, SecurityError
from ..exceptions import InvalidMessage, InvalidStatus, SecurityError
from ..extensions.base import ClientExtensionFactory
from ..extensions.permessage_deflate import enable_client_permessage_deflate
from ..headers import validate_subprotocols
Expand Down Expand Up @@ -147,7 +147,9 @@ def process_exception(exc: Exception) -> Exception | None:
That exception will be raised, breaking out of the retry loop.
"""
if isinstance(exc, (EOFError, OSError, asyncio.TimeoutError)):
if isinstance(exc, (OSError, asyncio.TimeoutError)):
return None
if isinstance(exc, InvalidMessage) and isinstance(exc.__cause__, EOFError):
return None
if isinstance(exc, InvalidStatus) and exc.response.status_code in [
500, # Internal Server Error
Expand Down
6 changes: 5 additions & 1 deletion src/websockets/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
InvalidHandshake,
InvalidHeader,
InvalidHeaderValue,
InvalidMessage,
InvalidStatus,
InvalidUpgrade,
NegotiationError,
Expand Down Expand Up @@ -318,7 +319,10 @@ def parse(self) -> Generator[None]:
self.reader.read_to_eof,
)
except Exception as exc:
self.handshake_exc = exc
self.handshake_exc = InvalidMessage(
"did not receive a valid HTTP response"
)
self.handshake_exc.__cause__ = exc
self.send_eof()
self.parser = self.discard()
next(self.parser) # start coroutine
Expand Down
11 changes: 9 additions & 2 deletions src/websockets/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
* :exc:`InvalidURI`
* :exc:`InvalidHandshake`
* :exc:`SecurityError`
* :exc:`InvalidMessage` (legacy)
* :exc:`InvalidMessage`
* :exc:`InvalidStatus`
* :exc:`InvalidStatusCode` (legacy)
* :exc:`InvalidHeader`
Expand Down Expand Up @@ -48,6 +48,7 @@
"InvalidHeader",
"InvalidHeaderFormat",
"InvalidHeaderValue",
"InvalidMessage",
"InvalidOrigin",
"InvalidUpgrade",
"NegotiationError",
Expand Down Expand Up @@ -185,6 +186,13 @@ class SecurityError(InvalidHandshake):
"""


class InvalidMessage(InvalidHandshake):
"""
Raised when a handshake request or response is malformed.
"""


class InvalidStatus(InvalidHandshake):
"""
Raised when a handshake response rejects the WebSocket upgrade.
Expand Down Expand Up @@ -410,7 +418,6 @@ class ConcurrencyError(WebSocketException, RuntimeError):
deprecated_aliases={
# deprecated in 14.0 - 2024-11-09
"AbortHandshake": ".legacy.exceptions",
"InvalidMessage": ".legacy.exceptions",
"InvalidStatusCode": ".legacy.exceptions",
"RedirectHandshake": ".legacy.exceptions",
"WebSocketProtocolError": ".legacy.exceptions",
Expand Down
3 changes: 2 additions & 1 deletion src/websockets/legacy/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from ..exceptions import (
InvalidHeader,
InvalidHeaderValue,
InvalidMessage,
NegotiationError,
SecurityError,
)
Expand All @@ -34,7 +35,7 @@
from ..http11 import USER_AGENT
from ..typing import ExtensionHeader, LoggerLike, Origin, Subprotocol
from ..uri import WebSocketURI, parse_uri
from .exceptions import InvalidMessage, InvalidStatusCode, RedirectHandshake
from .exceptions import InvalidStatusCode, RedirectHandshake
from .handshake import build_request, check_response
from .http import read_response
from .protocol import WebSocketCommonProtocol
Expand Down
9 changes: 2 additions & 7 deletions src/websockets/legacy/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,18 +3,13 @@
from .. import datastructures
from ..exceptions import (
InvalidHandshake,
# InvalidMessage was incorrectly moved here in versions 14.0 and 14.1.
InvalidMessage, # noqa: F401
ProtocolError as WebSocketProtocolError, # noqa: F401
)
from ..typing import StatusLike


class InvalidMessage(InvalidHandshake):
"""
Raised when a handshake request or response is malformed.
"""


class InvalidStatusCode(InvalidHandshake):
"""
Raised when a handshake response status code is invalid.
Expand Down
3 changes: 2 additions & 1 deletion src/websockets/legacy/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from ..exceptions import (
InvalidHandshake,
InvalidHeader,
InvalidMessage,
InvalidOrigin,
InvalidUpgrade,
NegotiationError,
Expand All @@ -32,7 +33,7 @@
from ..http11 import SERVER
from ..protocol import State
from ..typing import ExtensionHeader, LoggerLike, Origin, StatusLike, Subprotocol
from .exceptions import AbortHandshake, InvalidMessage
from .exceptions import AbortHandshake
from .handshake import build_response, check_request
from .http import read_request
from .protocol import WebSocketCommonProtocol, broadcast
Expand Down
6 changes: 5 additions & 1 deletion src/websockets/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
InvalidHandshake,
InvalidHeader,
InvalidHeaderValue,
InvalidMessage,
InvalidOrigin,
InvalidUpgrade,
NegotiationError,
Expand Down Expand Up @@ -552,7 +553,10 @@ def parse(self) -> Generator[None]:
self.reader.read_line,
)
except Exception as exc:
self.handshake_exc = exc
self.handshake_exc = InvalidMessage(
"did not receive a valid HTTP request"
)
self.handshake_exc.__cause__ = exc
self.send_eof()
self.parser = self.discard()
next(self.parser) # start coroutine
Expand Down
27 changes: 20 additions & 7 deletions tests/asyncio/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from websockets.client import backoff
from websockets.exceptions import (
InvalidHandshake,
InvalidMessage,
InvalidStatus,
InvalidURI,
SecurityError,
Expand Down Expand Up @@ -151,30 +152,32 @@ async def test_reconnect(self):
iterations = 0
successful = 0

def process_request(connection, request):
async def process_request(connection, request):
nonlocal iterations
iterations += 1
# Retriable errors
if iterations == 1:
connection.transport.close()
await asyncio.sleep(3 * MS)
elif iterations == 2:
connection.transport.close()
elif iterations == 3:
return connection.respond(http.HTTPStatus.SERVICE_UNAVAILABLE, "🚒")
# Fatal error
elif iterations == 5:
elif iterations == 6:
return connection.respond(http.HTTPStatus.PAYMENT_REQUIRED, "💸")

async with serve(*args, process_request=process_request) as server:
with self.assertRaises(InvalidStatus) as raised:
async with short_backoff_delay():
async for client in connect(get_uri(server)):
async for client in connect(get_uri(server), open_timeout=3 * MS):
self.assertEqual(client.protocol.state.name, "OPEN")
successful += 1

self.assertEqual(
str(raised.exception),
"server rejected WebSocket connection: HTTP 402",
)
self.assertEqual(iterations, 5)
self.assertEqual(iterations, 6)
self.assertEqual(successful, 2)

async def test_reconnect_with_custom_process_exception(self):
Expand Down Expand Up @@ -393,11 +396,16 @@ def close_connection(self, request):
self.close_transport()

async with serve(*args, process_request=close_connection) as server:
with self.assertRaises(EOFError) as raised:
with self.assertRaises(InvalidMessage) as raised:
async with connect(get_uri(server)):
self.fail("did not raise")
self.assertEqual(
str(raised.exception),
"did not receive a valid HTTP response",
)
self.assertIsInstance(raised.exception.__cause__, EOFError)
self.assertEqual(
str(raised.exception.__cause__),
"connection closed while reading HTTP status line",
)

Expand Down Expand Up @@ -443,11 +451,16 @@ async def junk(reader, writer):
server = await asyncio.start_server(junk, "localhost", 0)
host, port = get_host_port(server)
async with server:
with self.assertRaises(ValueError) as raised:
with self.assertRaises(InvalidMessage) as raised:
async with connect(f"ws://{host}:{port}"):
self.fail("did not raise")
self.assertEqual(
str(raised.exception),
"did not receive a valid HTTP response",
)
self.assertIsInstance(raised.exception.__cause__, ValueError)
self.assertEqual(
str(raised.exception.__cause__),
"unsupported protocol; expected HTTP/1.1: "
"220 smtp.invalid ESMTP Postfix",
)
Expand Down
4 changes: 4 additions & 0 deletions tests/asyncio/test_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -473,6 +473,10 @@ async def test_junk_handshake(self):
)
self.assertEqual(
[str(record.exc_info[1]) for record in logs.records],
["did not receive a valid HTTP request"],
)
self.assertEqual(
[str(record.exc_info[1].__cause__) for record in logs.records],
["invalid HTTP request line: HELO relay.invalid"],
)

Expand Down
4 changes: 0 additions & 4 deletions tests/legacy/test_exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,6 @@
class ExceptionsTests(unittest.TestCase):
def test_str(self):
for exception, exception_str in [
(
InvalidMessage("malformed HTTP message"),
"malformed HTTP message",
),
(
InvalidStatusCode(403, Headers()),
"server rejected WebSocket connection: HTTP 403",
Expand Down
21 changes: 18 additions & 3 deletions tests/sync/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,12 @@
import time
import unittest

from websockets.exceptions import InvalidHandshake, InvalidStatus, InvalidURI
from websockets.exceptions import (
InvalidHandshake,
InvalidMessage,
InvalidStatus,
InvalidURI,
)
from websockets.extensions.permessage_deflate import PerMessageDeflate
from websockets.sync.client import *

Expand Down Expand Up @@ -149,11 +154,16 @@ def close_connection(self, request):
self.close_socket()

with run_server(process_request=close_connection) as server:
with self.assertRaises(EOFError) as raised:
with self.assertRaises(InvalidMessage) as raised:
with connect(get_uri(server)):
self.fail("did not raise")
self.assertEqual(
str(raised.exception),
"did not receive a valid HTTP response",
)
self.assertIsInstance(raised.exception.__cause__, EOFError)
self.assertEqual(
str(raised.exception.__cause__),
"connection closed while reading HTTP status line",
)

Expand Down Expand Up @@ -203,11 +213,16 @@ def handle(self):
thread = threading.Thread(target=server.serve_forever, args=(MS,))
thread.start()
try:
with self.assertRaises(ValueError) as raised:
with self.assertRaises(InvalidMessage) as raised:
with connect(f"ws://{host}:{port}"):
self.fail("did not raise")
self.assertEqual(
str(raised.exception),
"did not receive a valid HTTP response",
)
self.assertIsInstance(raised.exception.__cause__, ValueError)
self.assertEqual(
str(raised.exception.__cause__),
"unsupported protocol; expected HTTP/1.1: "
"220 smtp.invalid ESMTP Postfix",
)
Expand Down
4 changes: 4 additions & 0 deletions tests/sync/test_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -311,6 +311,10 @@ def test_junk_handshake(self):
)
self.assertEqual(
[str(record.exc_info[1]) for record in logs.records],
["did not receive a valid HTTP request"],
)
self.assertEqual(
[str(record.exc_info[1].__cause__) for record in logs.records],
["invalid HTTP request line: HELO relay.invalid"],
)

Expand Down
Loading

0 comments on commit 59d4dcf

Please sign in to comment.