Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGES/11074.bugfix.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fixed connector not waiting for connections to close before returning from :meth:`~aiohttp.BaseConnector.close` (partial backport of :pr:`3733`) -- by :user:`atemate` and :user:`bdraco`.
1 change: 1 addition & 0 deletions CHANGES/1925.bugfix.rst
15 changes: 15 additions & 0 deletions aiohttp/client_proto.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from .base_protocol import BaseProtocol
from .client_exceptions import (
ClientConnectionError,
ClientOSError,
ClientPayloadError,
ServerDisconnectedError,
Expand All @@ -14,6 +15,7 @@
EMPTY_BODY_STATUS_CODES,
BaseTimerContext,
set_exception,
set_result,
)
from .http import HttpResponseParser, RawResponseMessage
from .http_exceptions import HttpProcessingError
Expand Down Expand Up @@ -43,6 +45,7 @@ def __init__(self, loop: asyncio.AbstractEventLoop) -> None:
self._read_timeout_handle: Optional[asyncio.TimerHandle] = None

self._timeout_ceil_threshold: Optional[float] = 5
self.closed: asyncio.Future[None] = self._loop.create_future()

@property
def upgraded(self) -> bool:
Expand Down Expand Up @@ -83,6 +86,18 @@ def connection_lost(self, exc: Optional[BaseException]) -> None:

connection_closed_cleanly = original_connection_error is None

if connection_closed_cleanly:
set_result(self.closed, None)
else:
assert original_connection_error is not None
set_exception(
self.closed,
ClientConnectionError(
f"Connection lost: {original_connection_error !s}",
),
original_connection_error,
)

if self._payload_parser is not None:
with suppress(Exception): # FIXME: log this somehow?
self._payload_parser.feed_eof()
Expand Down
60 changes: 49 additions & 11 deletions aiohttp/connector.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import asyncio
import functools
import logging
import random
import socket
import sys
Expand Down Expand Up @@ -131,6 +132,14 @@ def __del__(self) -> None:
)


async def _wait_for_close(waiters: List[Awaitable[object]]) -> None:
"""Wait for all waiters to finish closing."""
results = await asyncio.gather(*waiters, return_exceptions=True)
for res in results:
if isinstance(res, Exception):
logging.error("Error while closing connector: %r", res)


class Connection:

_source_traceback = None
Expand Down Expand Up @@ -222,10 +231,14 @@ def closed(self) -> bool:
class _TransportPlaceholder:
"""placeholder for BaseConnector.connect function"""

__slots__ = ()
__slots__ = ("closed",)

def __init__(self, closed_future: asyncio.Future[Optional[Exception]]) -> None:
"""Initialize a placeholder for a transport."""
self.closed = closed_future

def close(self) -> None:
"""Close the placeholder transport."""
"""Close the placeholder."""


class BaseConnector:
Expand Down Expand Up @@ -322,6 +335,10 @@ def __init__(

self._cleanup_closed_disabled = not enable_cleanup_closed
self._cleanup_closed_transports: List[Optional[asyncio.Transport]] = []
self._placeholder_future: asyncio.Future[Optional[Exception]] = (
loop.create_future()
)
self._placeholder_future.set_result(None)
self._cleanup_closed()

def __del__(self, _warnings: Any = warnings) -> None:
Expand Down Expand Up @@ -454,18 +471,30 @@ def _cleanup_closed(self) -> None:

def close(self) -> Awaitable[None]:
"""Close all opened transports."""
self._close()
return _DeprecationWaiter(noop())
if not (waiters := self._close()):
# If there are no connections to close, we can return a noop
# awaitable to avoid scheduling a task on the event loop.
return _DeprecationWaiter(noop())
coro = _wait_for_close(waiters)
if sys.version_info >= (3, 12):
# Optimization for Python 3.12, try to close connections
# immediately to avoid having to schedule the task on the event loop.
task = asyncio.Task(coro, loop=self._loop, eager_start=True)
else:
task = self._loop.create_task(coro)
return _DeprecationWaiter(task)

def _close(self) -> List[Awaitable[object]]:
waiters: List[Awaitable[object]] = []

def _close(self) -> None:
if self._closed:
return
return waiters

self._closed = True

try:
if self._loop.is_closed():
return
return waiters

# cancel cleanup task
if self._cleanup_handle:
Expand All @@ -476,16 +505,20 @@ def _close(self) -> None:
self._cleanup_closed_handle.cancel()

for data in self._conns.values():
for proto, t0 in data:
for proto, _ in data:
proto.close()
waiters.append(proto.closed)

for proto in self._acquired:
proto.close()
waiters.append(proto.closed)

for transport in self._cleanup_closed_transports:
if transport is not None:
transport.abort()

return waiters

finally:
self._conns.clear()
self._acquired.clear()
Expand Down Expand Up @@ -546,7 +579,9 @@ async def connect(
if (conn := await self._get(key, traces)) is not None:
return conn

placeholder = cast(ResponseHandler, _TransportPlaceholder())
placeholder = cast(
ResponseHandler, _TransportPlaceholder(self._placeholder_future)
)
self._acquired.add(placeholder)
if self._limit_per_host:
self._acquired_per_host[key].add(placeholder)
Expand Down Expand Up @@ -898,15 +933,18 @@ def __init__(
self._resolve_host_tasks: Set["asyncio.Task[List[ResolveResult]]"] = set()
self._socket_factory = socket_factory

def close(self) -> Awaitable[None]:
def _close(self) -> List[Awaitable[object]]:
"""Close all ongoing DNS calls."""
for fut in chain.from_iterable(self._throttle_dns_futures.values()):
fut.cancel()

waiters = super()._close()

for t in self._resolve_host_tasks:
t.cancel()
waiters.append(t)

return super().close()
return waiters

@property
def family(self) -> int:
Expand Down
7 changes: 6 additions & 1 deletion tests/test_client_request.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,8 @@ def protocol(loop, transport):
protocol.transport = transport
protocol._drain_helper.return_value = loop.create_future()
protocol._drain_helper.return_value.set_result(None)
protocol.closed = loop.create_future()
protocol.closed.set_result(None)
return protocol


Expand Down Expand Up @@ -1404,7 +1406,10 @@ async def send(self, conn):

async def create_connection(req, traces, timeout):
assert isinstance(req, CustomRequest)
return mock.Mock()
proto = mock.Mock()
proto.closed = loop.create_future()
proto.closed.set_result(None)
return proto

connector = BaseConnector(loop=loop)
connector._create_connection = create_connection
Expand Down
9 changes: 8 additions & 1 deletion tests/test_client_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@ async def make_conn():

conn = loop.run_until_complete(make_conn())
proto = mock.Mock()
proto.closed = loop.create_future()
proto.closed.set_result(None)
conn._conns["a"] = deque([(proto, 123)])
yield conn
loop.run_until_complete(conn.close())
Expand Down Expand Up @@ -429,7 +431,10 @@ async def test_reraise_os_error(create_session) -> None:

async def create_connection(req, traces, timeout):
# return self.transport, self.protocol
return mock.Mock()
proto = mock.Mock()
proto.closed = session._loop.create_future()
proto.closed.set_result(None)
return proto

session._connector._create_connection = create_connection
session._connector._release = mock.Mock()
Expand Down Expand Up @@ -464,6 +469,8 @@ async def connect(req, traces, timeout):
async def create_connection(req, traces, timeout):
# return self.transport, self.protocol
conn = mock.Mock()
conn.closed = session._loop.create_future()
conn.closed.set_result(None)
return conn

session._connector.connect = connect
Expand Down
Loading
Loading