diff --git a/requirements.txt b/requirements.txt index 65e240832..4794d120e 100644 --- a/requirements.txt +++ b/requirements.txt @@ -12,8 +12,8 @@ types-contextvars==2.4.7.2 types-PyYAML==6.0.12.10 types-dataclasses==0.6.6 pytest==7.4.0 -trio==0.22.1 -anyio@git+https://github.com/agronholm/anyio.git +trio==0.21.0 +anyio==3.7.1 # Documentation mkdocs==1.4.3 diff --git a/starlette/_utils.py b/starlette/_utils.py index f06dd557c..26854f3d4 100644 --- a/starlette/_utils.py +++ b/starlette/_utils.py @@ -2,12 +2,20 @@ import functools import sys import typing +from contextlib import contextmanager if sys.version_info >= (3, 10): # pragma: no cover from typing import TypeGuard else: # pragma: no cover from typing_extensions import TypeGuard +has_exceptiongroups = True +if sys.version_info < (3, 11): # pragma: no cover + try: + from exceptiongroup import BaseExceptionGroup + except ImportError: + has_exceptiongroups = False + T = typing.TypeVar("T") AwaitableCallable = typing.Callable[..., typing.Awaitable[T]] @@ -66,3 +74,15 @@ async def __aenter__(self) -> SupportsAsyncCloseType: async def __aexit__(self, *args: typing.Any) -> typing.Union[None, bool]: await self.entered.close() return None + + +@contextmanager +def collapse_excgroups() -> typing.Generator[None, None, None]: + try: + yield + except BaseException as exc: + if has_exceptiongroups: + while isinstance(exc, BaseExceptionGroup) and len(exc.exceptions) == 1: + exc = exc.exceptions[0] # pragma: no cover + + raise exc diff --git a/starlette/middleware/base.py b/starlette/middleware/base.py index ee99ee6cb..ad3ffcfee 100644 --- a/starlette/middleware/base.py +++ b/starlette/middleware/base.py @@ -1,18 +1,14 @@ -import sys import typing -from contextlib import contextmanager import anyio from anyio.abc import ObjectReceiveStream, ObjectSendStream +from starlette._utils import collapse_excgroups from starlette.background import BackgroundTask from starlette.requests import ClientDisconnect, Request from starlette.responses import ContentStream, Response, StreamingResponse from starlette.types import ASGIApp, Message, Receive, Scope, Send -if sys.version_info < (3, 11): # pragma: no cover - from exceptiongroup import BaseExceptionGroup - RequestResponseEndpoint = typing.Callable[[Request], typing.Awaitable[Response]] DispatchFunction = typing.Callable[ [Request, RequestResponseEndpoint], typing.Awaitable[Response] @@ -20,17 +16,6 @@ T = typing.TypeVar("T") -@contextmanager -def _convert_excgroups() -> typing.Generator[None, None, None]: - try: - yield - except BaseException as exc: - while isinstance(exc, BaseExceptionGroup) and len(exc.exceptions) == 1: - exc = exc.exceptions[0] - - raise exc - - class _CachedRequest(Request): """ If the user calls Request.body() from their dispatch function @@ -201,7 +186,7 @@ async def body_stream() -> typing.AsyncGenerator[bytes, None]: response.raw_headers = message["headers"] return response - with _convert_excgroups(): + with collapse_excgroups(): async with anyio.create_task_group() as task_group: response = await self.dispatch_func(request, call_next) await response(scope, wrapped_receive, send) diff --git a/tests/middleware/test_wsgi.py b/tests/middleware/test_wsgi.py index ad3975403..fe527e373 100644 --- a/tests/middleware/test_wsgi.py +++ b/tests/middleware/test_wsgi.py @@ -2,11 +2,9 @@ import pytest +from starlette._utils import collapse_excgroups from starlette.middleware.wsgi import WSGIMiddleware, build_environ -if sys.version_info < (3, 11): # pragma: no cover - from exceptiongroup import ExceptionGroup - def hello_world(environ, start_response): status = "200 OK" @@ -69,12 +67,9 @@ def test_wsgi_exception(test_client_factory): # The HTTP protocol implementations would catch this error and return 500. app = WSGIMiddleware(raise_exception) client = test_client_factory(app) - with pytest.raises(ExceptionGroup) as exc: + with pytest.raises(RuntimeError), collapse_excgroups(): client.get("/") - assert len(exc.value.exceptions) == 1 - assert isinstance(exc.value.exceptions[0], RuntimeError) - def test_wsgi_exc_info(test_client_factory): # Note that we're testing the WSGI app directly here.