-
-
Notifications
You must be signed in to change notification settings - Fork 971
Ensured explicit closing of async generators #3593
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -2,6 +2,7 @@ | |
|
|
||
| import inspect | ||
| import warnings | ||
| from collections.abc import AsyncGenerator | ||
| from json import dumps as json_dumps | ||
| from typing import ( | ||
| Any, | ||
|
|
@@ -10,6 +11,7 @@ | |
| Iterable, | ||
| Iterator, | ||
| Mapping, | ||
| NoReturn, | ||
| ) | ||
| from urllib.parse import urlencode | ||
|
|
||
|
|
@@ -23,7 +25,7 @@ | |
| ResponseContent, | ||
| SyncByteStream, | ||
| ) | ||
| from ._utils import peek_filelike_length, primitive_value_to_str | ||
| from ._utils import peek_filelike_length, primitive_value_to_str, safe_async_iterate | ||
|
|
||
| __all__ = ["ByteStream"] | ||
|
|
||
|
|
@@ -35,7 +37,7 @@ def __init__(self, stream: bytes) -> None: | |
| def __iter__(self) -> Iterator[bytes]: | ||
| yield self._stream | ||
|
|
||
| async def __aiter__(self) -> AsyncIterator[bytes]: | ||
| async def __aiter__(self) -> AsyncGenerator[bytes]: | ||
| yield self._stream | ||
|
|
||
|
|
||
|
|
@@ -85,8 +87,9 @@ async def __aiter__(self) -> AsyncIterator[bytes]: | |
| chunk = await self._stream.aread(self.CHUNK_SIZE) | ||
| else: | ||
| # Otherwise iterate. | ||
| async for part in self._stream: | ||
| yield part | ||
| async with safe_async_iterate(self._stream) as iterator: | ||
| async for part in iterator: | ||
| yield part | ||
|
|
||
|
|
||
| class UnattachedStream(AsyncByteStream, SyncByteStream): | ||
|
|
@@ -99,9 +102,8 @@ class UnattachedStream(AsyncByteStream, SyncByteStream): | |
| def __iter__(self) -> Iterator[bytes]: | ||
| raise StreamClosed() | ||
|
|
||
| async def __aiter__(self) -> AsyncIterator[bytes]: | ||
| def __aiter__(self) -> NoReturn: | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why is the async removed? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It's not really necessary here. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @tomchristie this is not an important part of this PR so if you object, I can just revert it. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. On another note, is either one of you at EuroPython? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
I'm not. 😞 |
||
| raise StreamClosed() | ||
| yield b"" # pragma: no cover | ||
|
|
||
|
|
||
| def encode_content( | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -4,13 +4,18 @@ | |
| import os | ||
| import re | ||
| import typing | ||
| from collections.abc import AsyncGenerator, AsyncIterable, AsyncIterator | ||
| from contextlib import asynccontextmanager | ||
| from inspect import isasyncgen | ||
| from urllib.request import getproxies | ||
|
|
||
| from ._types import PrimitiveData | ||
|
|
||
| if typing.TYPE_CHECKING: # pragma: no cover | ||
| from ._urls import URL | ||
|
|
||
| T = typing.TypeVar("T") | ||
|
|
||
|
|
||
| def primitive_value_to_str(value: PrimitiveData) -> str: | ||
| """ | ||
|
|
@@ -240,3 +245,19 @@ def is_ipv6_hostname(hostname: str) -> bool: | |
| except Exception: | ||
| return False | ||
| return True | ||
|
|
||
|
|
||
| @asynccontextmanager | ||
| async def safe_async_iterate( | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Interesting. I'm curious about this vs. https://docs.python.org/3/library/contextlib.html#contextlib.aclosing There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I originally went with |
||
| iterable_or_iterator: AsyncIterable[T] | AsyncIterator[T], / | ||
| ) -> AsyncGenerator[AsyncIterator[T]]: | ||
| iterator = ( | ||
| iterable_or_iterator | ||
| if isinstance(iterable_or_iterator, AsyncIterator) | ||
| else iterable_or_iterator.__aiter__() | ||
| ) | ||
| try: | ||
| yield iterator | ||
| finally: | ||
| if isasyncgen(iterator): | ||
| await iterator.aclose() | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
👍🏼 Neat.