diff --git a/CHANGELOG.md b/CHANGELOG.md index 13bbfcdb79..865023a09f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,6 +6,10 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/). ## [UNRELEASED] +### Fixed + +* Explicitly close all async generators to ensure predictable behavior + ### Removed * Drop support for Python 3.8 diff --git a/httpx/_client.py b/httpx/_client.py index 13cd933673..df46a59c91 100644 --- a/httpx/_client.py +++ b/httpx/_client.py @@ -6,6 +6,7 @@ import time import typing import warnings +from collections.abc import AsyncGenerator from contextlib import asynccontextmanager, contextmanager from types import TracebackType @@ -46,7 +47,7 @@ TimeoutTypes, ) from ._urls import URL, QueryParams -from ._utils import URLPattern, get_environment_proxies +from ._utils import URLPattern, get_environment_proxies, safe_async_iterate if typing.TYPE_CHECKING: import ssl # pragma: no cover @@ -172,9 +173,10 @@ def __init__( self._response = response self._start = start - async def __aiter__(self) -> typing.AsyncIterator[bytes]: - async for chunk in self._stream: - yield chunk + async def __aiter__(self) -> AsyncGenerator[bytes]: + async with safe_async_iterate(self._stream) as iterator: + async for chunk in iterator: + yield chunk async def aclose(self) -> None: elapsed = time.perf_counter() - self._start diff --git a/httpx/_content.py b/httpx/_content.py index 6f479a0885..f18b371f6e 100644 --- a/httpx/_content.py +++ b/httpx/_content.py @@ -2,6 +2,7 @@ import inspect import warnings +from collections.abc import AsyncGenerator from json import dumps as json_dumps from typing import ( Any, @@ -10,6 +11,7 @@ Iterable, Iterator, Mapping, + NoReturn, ) from urllib.parse import urlencode @@ -23,7 +25,7 @@ ResponseContent, SyncByteStream, ) -from ._utils import peek_filelike_length, primitive_value_to_str +from ._utils import peek_filelike_length, primitive_value_to_str, safe_async_iterate __all__ = ["ByteStream"] @@ -35,7 +37,7 @@ def __init__(self, stream: bytes) -> None: def __iter__(self) -> Iterator[bytes]: yield self._stream - async def __aiter__(self) -> AsyncIterator[bytes]: + async def __aiter__(self) -> AsyncGenerator[bytes]: yield self._stream @@ -85,8 +87,9 @@ async def __aiter__(self) -> AsyncIterator[bytes]: chunk = await self._stream.aread(self.CHUNK_SIZE) else: # Otherwise iterate. - async for part in self._stream: - yield part + async with safe_async_iterate(self._stream) as iterator: + async for part in iterator: + yield part class UnattachedStream(AsyncByteStream, SyncByteStream): @@ -99,9 +102,8 @@ class UnattachedStream(AsyncByteStream, SyncByteStream): def __iter__(self) -> Iterator[bytes]: raise StreamClosed() - async def __aiter__(self) -> AsyncIterator[bytes]: + def __aiter__(self) -> NoReturn: raise StreamClosed() - yield b"" # pragma: no cover def encode_content( diff --git a/httpx/_models.py b/httpx/_models.py index 2cc86321a4..4b24ee5810 100644 --- a/httpx/_models.py +++ b/httpx/_models.py @@ -7,7 +7,7 @@ import re import typing import urllib.request -from collections.abc import Mapping +from collections.abc import AsyncGenerator, Mapping from http.cookiejar import Cookie, CookieJar from ._content import ByteStream, UnattachedStream, encode_request, encode_response @@ -46,7 +46,7 @@ SyncByteStream, ) from ._urls import URL -from ._utils import to_bytes_or_str, to_str +from ._utils import safe_async_iterate, to_bytes_or_str, to_str __all__ = ["Cookies", "Headers", "Request", "Response"] @@ -485,7 +485,9 @@ async def aread(self) -> bytes: """ if not hasattr(self, "_content"): assert isinstance(self.stream, typing.AsyncIterable) - self._content = b"".join([part async for part in self.stream]) + async with safe_async_iterate(self.stream) as iterator: + self._content = b"".join([part async for part in iterator]) + if not isinstance(self.stream, ByteStream): # If a streaming request has been read entirely into memory, then # we can replace the stream with a raw bytes implementation, @@ -976,12 +978,11 @@ async def aread(self) -> bytes: Read and return the response content. """ if not hasattr(self, "_content"): - self._content = b"".join([part async for part in self.aiter_bytes()]) + async with safe_async_iterate(self.aiter_bytes()) as iterator: + self._content = b"".join([part async for part in iterator]) return self._content - async def aiter_bytes( - self, chunk_size: int | None = None - ) -> typing.AsyncIterator[bytes]: + async def aiter_bytes(self, chunk_size: int | None = None) -> AsyncGenerator[bytes]: """ A byte-iterator over the decoded response content. This allows us to handle gzip, deflate, brotli, and zstd encoded responses. @@ -994,19 +995,19 @@ async def aiter_bytes( decoder = self._get_content_decoder() chunker = ByteChunker(chunk_size=chunk_size) with request_context(request=self._request): - async for raw_bytes in self.aiter_raw(): - decoded = decoder.decode(raw_bytes) - for chunk in chunker.decode(decoded): - yield chunk + async with safe_async_iterate(self.aiter_raw()) as iterator: + async for raw_bytes in iterator: + decoded = decoder.decode(raw_bytes) + for chunk in chunker.decode(decoded): + yield chunk + decoded = decoder.flush() for chunk in chunker.decode(decoded): yield chunk # pragma: no cover for chunk in chunker.flush(): yield chunk - async def aiter_text( - self, chunk_size: int | None = None - ) -> typing.AsyncIterator[str]: + async def aiter_text(self, chunk_size: int | None = None) -> AsyncGenerator[str]: """ A str-iterator over the decoded response content that handles both gzip, deflate, etc but also detects the content's @@ -1015,28 +1016,28 @@ async def aiter_text( decoder = TextDecoder(encoding=self.encoding or "utf-8") chunker = TextChunker(chunk_size=chunk_size) with request_context(request=self._request): - async for byte_content in self.aiter_bytes(): - text_content = decoder.decode(byte_content) - for chunk in chunker.decode(text_content): - yield chunk + async with safe_async_iterate(self.aiter_bytes()) as iterator: + async for byte_content in iterator: + text_content = decoder.decode(byte_content) + for chunk in chunker.decode(text_content): + yield chunk text_content = decoder.flush() for chunk in chunker.decode(text_content): yield chunk # pragma: no cover for chunk in chunker.flush(): yield chunk - async def aiter_lines(self) -> typing.AsyncIterator[str]: + async def aiter_lines(self) -> AsyncGenerator[str]: decoder = LineDecoder() with request_context(request=self._request): - async for text in self.aiter_text(): - for line in decoder.decode(text): - yield line + async with safe_async_iterate(self.aiter_text()) as iterator: + async for text in iterator: + for line in decoder.decode(text): + yield line for line in decoder.flush(): yield line - async def aiter_raw( - self, chunk_size: int | None = None - ) -> typing.AsyncIterator[bytes]: + async def aiter_raw(self, chunk_size: int | None = None) -> AsyncGenerator[bytes]: """ A byte-iterator over the raw response content. """ @@ -1052,10 +1053,11 @@ async def aiter_raw( chunker = ByteChunker(chunk_size=chunk_size) with request_context(request=self._request): - async for raw_stream_bytes in self.stream: - self._num_bytes_downloaded += len(raw_stream_bytes) - for chunk in chunker.decode(raw_stream_bytes): - yield chunk + async with safe_async_iterate(self.stream) as iterator: + async for raw_stream_bytes in iterator: + self._num_bytes_downloaded += len(raw_stream_bytes) + for chunk in chunker.decode(raw_stream_bytes): + yield chunk for chunk in chunker.flush(): yield chunk diff --git a/httpx/_multipart.py b/httpx/_multipart.py index b4761af9b2..1e5d522ba0 100644 --- a/httpx/_multipart.py +++ b/httpx/_multipart.py @@ -5,6 +5,7 @@ import os import re import typing +from collections.abc import AsyncGenerator from pathlib import Path from ._types import ( @@ -295,6 +296,6 @@ def __iter__(self) -> typing.Iterator[bytes]: for chunk in self.iter_chunks(): yield chunk - async def __aiter__(self) -> typing.AsyncIterator[bytes]: + async def __aiter__(self) -> AsyncGenerator[bytes]: for chunk in self.iter_chunks(): yield chunk diff --git a/httpx/_transports/asgi.py b/httpx/_transports/asgi.py index 2bc4efae0e..910e564469 100644 --- a/httpx/_transports/asgi.py +++ b/httpx/_transports/asgi.py @@ -1,6 +1,7 @@ from __future__ import annotations import typing +from collections.abc import AsyncGenerator from .._models import Request, Response from .._types import AsyncByteStream @@ -56,7 +57,7 @@ class ASGIResponseStream(AsyncByteStream): def __init__(self, body: list[bytes]) -> None: self._body = body - async def __aiter__(self) -> typing.AsyncIterator[bytes]: + async def __aiter__(self) -> AsyncGenerator[bytes]: yield b"".join(self._body) diff --git a/httpx/_transports/default.py b/httpx/_transports/default.py index fc8c70970a..6242a5cfdc 100644 --- a/httpx/_transports/default.py +++ b/httpx/_transports/default.py @@ -28,6 +28,7 @@ import contextlib import typing +from collections.abc import AsyncGenerator from types import TracebackType if typing.TYPE_CHECKING: @@ -55,6 +56,7 @@ from .._models import Request, Response from .._types import AsyncByteStream, CertTypes, ProxyTypes, SyncByteStream from .._urls import URL +from .._utils import safe_async_iterate from .base import AsyncBaseTransport, BaseTransport T = typing.TypeVar("T", bound="HTTPTransport") @@ -266,10 +268,11 @@ class AsyncResponseStream(AsyncByteStream): def __init__(self, httpcore_stream: typing.AsyncIterable[bytes]) -> None: self._httpcore_stream = httpcore_stream - async def __aiter__(self) -> typing.AsyncIterator[bytes]: + async def __aiter__(self) -> AsyncGenerator[bytes]: with map_httpcore_exceptions(): - async for part in self._httpcore_stream: - yield part + async with safe_async_iterate(self._httpcore_stream) as iterator: + async for part in iterator: + yield part async def aclose(self) -> None: if hasattr(self._httpcore_stream, "aclose"): diff --git a/httpx/_types.py b/httpx/_types.py index 704dfdffc8..44dac6cff9 100644 --- a/httpx/_types.py +++ b/httpx/_types.py @@ -94,7 +94,6 @@ def __iter__(self) -> Iterator[bytes]: raise NotImplementedError( "The '__iter__' method must be implemented." ) # pragma: no cover - yield b"" # pragma: no cover def close(self) -> None: """ @@ -104,11 +103,10 @@ def close(self) -> None: class AsyncByteStream: - async def __aiter__(self) -> AsyncIterator[bytes]: + def __aiter__(self) -> AsyncIterator[bytes]: raise NotImplementedError( "The '__aiter__' method must be implemented." ) # pragma: no cover - yield b"" # pragma: no cover async def aclose(self) -> None: pass diff --git a/httpx/_utils.py b/httpx/_utils.py index 7fe827da4d..b6acaba15c 100644 --- a/httpx/_utils.py +++ b/httpx/_utils.py @@ -4,6 +4,9 @@ import os import re import typing +from collections.abc import AsyncGenerator, AsyncIterable, AsyncIterator +from contextlib import asynccontextmanager +from inspect import isasyncgen from urllib.request import getproxies from ._types import PrimitiveData @@ -11,6 +14,8 @@ if typing.TYPE_CHECKING: # pragma: no cover from ._urls import URL +T = typing.TypeVar("T") + def primitive_value_to_str(value: PrimitiveData) -> str: """ @@ -240,3 +245,19 @@ def is_ipv6_hostname(hostname: str) -> bool: except Exception: return False return True + + +@asynccontextmanager +async def safe_async_iterate( + iterable_or_iterator: AsyncIterable[T] | AsyncIterator[T], / +) -> AsyncGenerator[AsyncIterator[T]]: + iterator = ( + iterable_or_iterator + if isinstance(iterable_or_iterator, AsyncIterator) + else iterable_or_iterator.__aiter__() + ) + try: + yield iterator + finally: + if isasyncgen(iterator): + await iterator.aclose()