Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

collect errors more reliably from websocket test client #2814

Merged
merged 17 commits into from
Dec 29, 2024
Merged
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 37 additions & 38 deletions starlette/testclient.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import contextlib
import enum
import inspect
import io
import json
Expand All @@ -9,7 +10,6 @@
import sys
import typing
from concurrent.futures import Future
from functools import cached_property
from types import GeneratorType
from urllib.parse import unquote, urljoin

Expand Down Expand Up @@ -85,6 +85,14 @@ class WebSocketDenialResponse( # type: ignore[misc]
"""


class _Eof(enum.Enum):
EOF = enum.auto()


EOF: typing.Final = _Eof.EOF
Eof = typing.Literal[_Eof.EOF]


class WebSocketTestSession:
def __init__(
self,
Expand All @@ -97,63 +105,53 @@ def __init__(
self.accepted_subprotocol = None
self.portal_factory = portal_factory
self._receive_queue: queue.Queue[Message] = queue.Queue()
self._send_queue: queue.Queue[Message | BaseException] = queue.Queue()
self._send_queue: queue.Queue[Message | Eof | BaseException] = queue.Queue()
self.extra_headers = None
self.should_close: anyio.Event

def __enter__(self) -> WebSocketTestSession:
self.exit_stack = contextlib.ExitStack()
self.portal = self.exit_stack.enter_context(self.portal_factory())
with contextlib.ExitStack() as stack:
self.portal = portal = stack.enter_context(self.portal_factory())

try:
_: Future[None] = self.portal.start_task_soon(self._run)
fut, cs = self.portal.start_task(self._run)
self.send({"type": "websocket.connect"})
message = self.receive()
self._raise_on_close(message)
except Exception:
self.exit_stack.close()
raise
self.accepted_subprotocol = message.get("subprotocol", None)
self.extra_headers = message.get("headers", None)
return self

@cached_property
def should_close(self) -> anyio.Event:
return anyio.Event()

async def _notify_close(self) -> None:
self.should_close.set()
self.accepted_subprotocol = message.get("subprotocol", None)
self.extra_headers = message.get("headers", None)
stack.callback(fut.result)
stack.callback(portal.call, cs.cancel)
stack.callback(self.close, 1000)
self.exit_stack = stack.pop_all()
return self

def __exit__(self, *args: typing.Any) -> None:
try:
self.close(1000)
finally:
self.portal.start_task_soon(self._notify_close)
self.exit_stack.close()
while not self._send_queue.empty():
self.exit_stack.close()

while True:
message = self._send_queue.get()
if message is EOF:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there an analogous to EOF from the standard library on 3.13?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It raises an exception

break
if isinstance(message, BaseException):
raise message
raise message # pragma: no cover (defensive, should be impossible)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why it should be impossible?

The except BaseException as exc below doesn't have a pragma: no cover, so I assume it's being hit?

Copy link
Member Author

@graingert graingert Dec 29, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is impossible because the exit stack will raise the exception out of fut.result() and so the queue won't be consumed.

This is only possible to be hit if ws.receive() is interrupted (eg with a KI) while waiting for an exception or message to be placed on the queue.

I'm currently sketching out another slight refactor that uses MemoryObjectStreams here instead that should clean this up a bit


async def _run(self) -> None:
async def _run(self, *, task_status: anyio.abc.TaskStatus[anyio.CancelScope]) -> None:
"""
The sub-thread in which the websocket session runs.
"""

async def run_app(tg: anyio.abc.TaskGroup) -> None:
try:
try:
await self.app(self.scope, self._asgi_receive, self._asgi_send)
except anyio.get_cancelled_exc_class():
...
self.should_close = anyio.Event()
graingert marked this conversation as resolved.
Show resolved Hide resolved
with anyio.CancelScope() as cs:
task_status.started(cs)
await self.app(self.scope, self._asgi_receive, self._asgi_send)
except BaseException as exc:
self._send_queue.put(exc)
raise
finally:
tg.cancel_scope.cancel()

async with anyio.create_task_group() as tg:
tg.start_soon(run_app, tg)
await self.should_close.wait()
tg.cancel_scope.cancel()
self.should_close.set()
finally:
self._send_queue.put(EOF) # TODO: use self._send_queue.shutdown() on 3.13+
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we use the if sys.version_info here?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it's better to stick to the EOF approach until someone puts up a Queue.shutdown backport, or we can use a MemoryObjectStream with portal


async def _asgi_receive(self) -> Message:
while self._receive_queue.empty():
Expand Down Expand Up @@ -202,6 +200,7 @@ def close(self, code: int = 1000, reason: str | None = None) -> None:

def receive(self) -> Message:
message = self._send_queue.get()
assert message is not EOF
if isinstance(message, BaseException):
raise message
return message
Expand Down
Loading