Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,10 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/).

## [UNRELEASED]

### Fixed

* Explicitly close all async generators to ensure predictable behavior
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👍🏼 Neat.


### Removed

* Drop support for Python 3.8
Expand Down
10 changes: 6 additions & 4 deletions httpx/_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import time
import typing
import warnings
from collections.abc import AsyncGenerator
from contextlib import asynccontextmanager, contextmanager
from types import TracebackType

Expand Down Expand Up @@ -46,7 +47,7 @@
TimeoutTypes,
)
from ._urls import URL, QueryParams
from ._utils import URLPattern, get_environment_proxies
from ._utils import URLPattern, get_environment_proxies, safe_async_iterate

if typing.TYPE_CHECKING:
import ssl # pragma: no cover
Expand Down Expand Up @@ -172,9 +173,10 @@ def __init__(
self._response = response
self._start = start

async def __aiter__(self) -> typing.AsyncIterator[bytes]:
async for chunk in self._stream:
yield chunk
async def __aiter__(self) -> AsyncGenerator[bytes]:
async with safe_async_iterate(self._stream) as iterator:
async for chunk in iterator:
yield chunk

async def aclose(self) -> None:
elapsed = time.perf_counter() - self._start
Expand Down
14 changes: 8 additions & 6 deletions httpx/_content.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import inspect
import warnings
from collections.abc import AsyncGenerator
from json import dumps as json_dumps
from typing import (
Any,
Expand All @@ -10,6 +11,7 @@
Iterable,
Iterator,
Mapping,
NoReturn,
)
from urllib.parse import urlencode

Expand All @@ -23,7 +25,7 @@
ResponseContent,
SyncByteStream,
)
from ._utils import peek_filelike_length, primitive_value_to_str
from ._utils import peek_filelike_length, primitive_value_to_str, safe_async_iterate

__all__ = ["ByteStream"]

Expand All @@ -35,7 +37,7 @@ def __init__(self, stream: bytes) -> None:
def __iter__(self) -> Iterator[bytes]:
yield self._stream

async def __aiter__(self) -> AsyncIterator[bytes]:
async def __aiter__(self) -> AsyncGenerator[bytes]:
yield self._stream


Expand Down Expand Up @@ -85,8 +87,9 @@ async def __aiter__(self) -> AsyncIterator[bytes]:
chunk = await self._stream.aread(self.CHUNK_SIZE)
else:
# Otherwise iterate.
async for part in self._stream:
yield part
async with safe_async_iterate(self._stream) as iterator:
async for part in iterator:
yield part


class UnattachedStream(AsyncByteStream, SyncByteStream):
Expand All @@ -99,9 +102,8 @@ class UnattachedStream(AsyncByteStream, SyncByteStream):
def __iter__(self) -> Iterator[bytes]:
raise StreamClosed()

async def __aiter__(self) -> AsyncIterator[bytes]:
def __aiter__(self) -> NoReturn:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is the async removed?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's not really necessary here. __aiter__() can return any arbitrary object that supports the AsyncIterator protocol, and the implementation here just raises NotImplementedError.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@tomchristie this is not an important part of this PR so if you object, I can just revert it.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

On another note, is either one of you at EuroPython?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

On another note, is either one of you at EuroPython?

I'm not. 😞

raise StreamClosed()
yield b"" # pragma: no cover


def encode_content(
Expand Down
60 changes: 31 additions & 29 deletions httpx/_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import re
import typing
import urllib.request
from collections.abc import Mapping
from collections.abc import AsyncGenerator, Mapping
from http.cookiejar import Cookie, CookieJar

from ._content import ByteStream, UnattachedStream, encode_request, encode_response
Expand Down Expand Up @@ -46,7 +46,7 @@
SyncByteStream,
)
from ._urls import URL
from ._utils import to_bytes_or_str, to_str
from ._utils import safe_async_iterate, to_bytes_or_str, to_str

__all__ = ["Cookies", "Headers", "Request", "Response"]

Expand Down Expand Up @@ -485,7 +485,9 @@ async def aread(self) -> bytes:
"""
if not hasattr(self, "_content"):
assert isinstance(self.stream, typing.AsyncIterable)
self._content = b"".join([part async for part in self.stream])
async with safe_async_iterate(self.stream) as iterator:
self._content = b"".join([part async for part in iterator])

if not isinstance(self.stream, ByteStream):
# If a streaming request has been read entirely into memory, then
# we can replace the stream with a raw bytes implementation,
Expand Down Expand Up @@ -976,12 +978,11 @@ async def aread(self) -> bytes:
Read and return the response content.
"""
if not hasattr(self, "_content"):
self._content = b"".join([part async for part in self.aiter_bytes()])
async with safe_async_iterate(self.aiter_bytes()) as iterator:
self._content = b"".join([part async for part in iterator])
return self._content

async def aiter_bytes(
self, chunk_size: int | None = None
) -> typing.AsyncIterator[bytes]:
async def aiter_bytes(self, chunk_size: int | None = None) -> AsyncGenerator[bytes]:
"""
A byte-iterator over the decoded response content.
This allows us to handle gzip, deflate, brotli, and zstd encoded responses.
Expand All @@ -994,19 +995,19 @@ async def aiter_bytes(
decoder = self._get_content_decoder()
chunker = ByteChunker(chunk_size=chunk_size)
with request_context(request=self._request):
async for raw_bytes in self.aiter_raw():
decoded = decoder.decode(raw_bytes)
for chunk in chunker.decode(decoded):
yield chunk
async with safe_async_iterate(self.aiter_raw()) as iterator:
async for raw_bytes in iterator:
decoded = decoder.decode(raw_bytes)
for chunk in chunker.decode(decoded):
yield chunk

decoded = decoder.flush()
for chunk in chunker.decode(decoded):
yield chunk # pragma: no cover
for chunk in chunker.flush():
yield chunk

async def aiter_text(
self, chunk_size: int | None = None
) -> typing.AsyncIterator[str]:
async def aiter_text(self, chunk_size: int | None = None) -> AsyncGenerator[str]:
"""
A str-iterator over the decoded response content
that handles both gzip, deflate, etc but also detects the content's
Expand All @@ -1015,28 +1016,28 @@ async def aiter_text(
decoder = TextDecoder(encoding=self.encoding or "utf-8")
chunker = TextChunker(chunk_size=chunk_size)
with request_context(request=self._request):
async for byte_content in self.aiter_bytes():
text_content = decoder.decode(byte_content)
for chunk in chunker.decode(text_content):
yield chunk
async with safe_async_iterate(self.aiter_bytes()) as iterator:
async for byte_content in iterator:
text_content = decoder.decode(byte_content)
for chunk in chunker.decode(text_content):
yield chunk
text_content = decoder.flush()
for chunk in chunker.decode(text_content):
yield chunk # pragma: no cover
for chunk in chunker.flush():
yield chunk

async def aiter_lines(self) -> typing.AsyncIterator[str]:
async def aiter_lines(self) -> AsyncGenerator[str]:
decoder = LineDecoder()
with request_context(request=self._request):
async for text in self.aiter_text():
for line in decoder.decode(text):
yield line
async with safe_async_iterate(self.aiter_text()) as iterator:
async for text in iterator:
for line in decoder.decode(text):
yield line
for line in decoder.flush():
yield line

async def aiter_raw(
self, chunk_size: int | None = None
) -> typing.AsyncIterator[bytes]:
async def aiter_raw(self, chunk_size: int | None = None) -> AsyncGenerator[bytes]:
"""
A byte-iterator over the raw response content.
"""
Expand All @@ -1052,10 +1053,11 @@ async def aiter_raw(
chunker = ByteChunker(chunk_size=chunk_size)

with request_context(request=self._request):
async for raw_stream_bytes in self.stream:
self._num_bytes_downloaded += len(raw_stream_bytes)
for chunk in chunker.decode(raw_stream_bytes):
yield chunk
async with safe_async_iterate(self.stream) as iterator:
async for raw_stream_bytes in iterator:
self._num_bytes_downloaded += len(raw_stream_bytes)
for chunk in chunker.decode(raw_stream_bytes):
yield chunk

for chunk in chunker.flush():
yield chunk
Expand Down
3 changes: 2 additions & 1 deletion httpx/_multipart.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import os
import re
import typing
from collections.abc import AsyncGenerator
from pathlib import Path

from ._types import (
Expand Down Expand Up @@ -295,6 +296,6 @@ def __iter__(self) -> typing.Iterator[bytes]:
for chunk in self.iter_chunks():
yield chunk

async def __aiter__(self) -> typing.AsyncIterator[bytes]:
async def __aiter__(self) -> AsyncGenerator[bytes]:
for chunk in self.iter_chunks():
yield chunk
3 changes: 2 additions & 1 deletion httpx/_transports/asgi.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import typing
from collections.abc import AsyncGenerator

from .._models import Request, Response
from .._types import AsyncByteStream
Expand Down Expand Up @@ -56,7 +57,7 @@ class ASGIResponseStream(AsyncByteStream):
def __init__(self, body: list[bytes]) -> None:
self._body = body

async def __aiter__(self) -> typing.AsyncIterator[bytes]:
async def __aiter__(self) -> AsyncGenerator[bytes]:
yield b"".join(self._body)


Expand Down
9 changes: 6 additions & 3 deletions httpx/_transports/default.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@

import contextlib
import typing
from collections.abc import AsyncGenerator
from types import TracebackType

if typing.TYPE_CHECKING:
Expand Down Expand Up @@ -55,6 +56,7 @@
from .._models import Request, Response
from .._types import AsyncByteStream, CertTypes, ProxyTypes, SyncByteStream
from .._urls import URL
from .._utils import safe_async_iterate
from .base import AsyncBaseTransport, BaseTransport

T = typing.TypeVar("T", bound="HTTPTransport")
Expand Down Expand Up @@ -266,10 +268,11 @@ class AsyncResponseStream(AsyncByteStream):
def __init__(self, httpcore_stream: typing.AsyncIterable[bytes]) -> None:
self._httpcore_stream = httpcore_stream

async def __aiter__(self) -> typing.AsyncIterator[bytes]:
async def __aiter__(self) -> AsyncGenerator[bytes]:
with map_httpcore_exceptions():
async for part in self._httpcore_stream:
yield part
async with safe_async_iterate(self._httpcore_stream) as iterator:
async for part in iterator:
yield part

async def aclose(self) -> None:
if hasattr(self._httpcore_stream, "aclose"):
Expand Down
4 changes: 1 addition & 3 deletions httpx/_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,6 @@ def __iter__(self) -> Iterator[bytes]:
raise NotImplementedError(
"The '__iter__' method must be implemented."
) # pragma: no cover
yield b"" # pragma: no cover

def close(self) -> None:
"""
Expand All @@ -104,11 +103,10 @@ def close(self) -> None:


class AsyncByteStream:
async def __aiter__(self) -> AsyncIterator[bytes]:
def __aiter__(self) -> AsyncIterator[bytes]:
raise NotImplementedError(
"The '__aiter__' method must be implemented."
) # pragma: no cover
yield b"" # pragma: no cover

async def aclose(self) -> None:
pass
21 changes: 21 additions & 0 deletions httpx/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,18 @@
import os
import re
import typing
from collections.abc import AsyncGenerator, AsyncIterable, AsyncIterator
from contextlib import asynccontextmanager
from inspect import isasyncgen
from urllib.request import getproxies

from ._types import PrimitiveData

if typing.TYPE_CHECKING: # pragma: no cover
from ._urls import URL

T = typing.TypeVar("T")


def primitive_value_to_str(value: PrimitiveData) -> str:
"""
Expand Down Expand Up @@ -240,3 +245,19 @@ def is_ipv6_hostname(hostname: str) -> bool:
except Exception:
return False
return True


@asynccontextmanager
async def safe_async_iterate(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Interesting. I'm curious about this vs. aclosing.

https://docs.python.org/3/library/contextlib.html#contextlib.aclosing

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I originally went with aclosing() but it doesn't work very well with arbitrary async iterables.

iterable_or_iterator: AsyncIterable[T] | AsyncIterator[T], /
) -> AsyncGenerator[AsyncIterator[T]]:
iterator = (
iterable_or_iterator
if isinstance(iterable_or_iterator, AsyncIterator)
else iterable_or_iterator.__aiter__()
)
try:
yield iterator
finally:
if isasyncgen(iterator):
await iterator.aclose()