Skip to content

Commit cdf3dca

Browse files
authored
[PR #9454/b20908e backport][3.10] Simplify DNS throttle implementation (#9457)
1 parent ee87a04 commit cdf3dca

File tree

5 files changed

+368
-141
lines changed

5 files changed

+368
-141
lines changed

Diff for: CHANGES/9454.misc.rst

+1
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Simplified DNS resolution throttling code to reduce chance of race conditions -- by :user:`bdraco`.

Diff for: aiohttp/connector.py

+59-37
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from contextlib import suppress
1010
from http import HTTPStatus
1111
from http.cookies import SimpleCookie
12-
from itertools import cycle, islice
12+
from itertools import chain, cycle, islice
1313
from time import monotonic
1414
from types import TracebackType
1515
from typing import (
@@ -50,8 +50,14 @@
5050
)
5151
from .client_proto import ResponseHandler
5252
from .client_reqrep import ClientRequest, Fingerprint, _merge_ssl_params
53-
from .helpers import ceil_timeout, is_ip_address, noop, sentinel
54-
from .locks import EventResultOrError
53+
from .helpers import (
54+
ceil_timeout,
55+
is_ip_address,
56+
noop,
57+
sentinel,
58+
set_exception,
59+
set_result,
60+
)
5561
from .resolver import DefaultResolver
5662

5763
try:
@@ -840,7 +846,9 @@ def __init__(
840846

841847
self._use_dns_cache = use_dns_cache
842848
self._cached_hosts = _DNSCacheTable(ttl=ttl_dns_cache)
843-
self._throttle_dns_events: Dict[Tuple[str, int], EventResultOrError] = {}
849+
self._throttle_dns_futures: Dict[
850+
Tuple[str, int], Set["asyncio.Future[None]"]
851+
] = {}
844852
self._family = family
845853
self._local_addr_infos = aiohappyeyeballs.addr_to_addr_infos(local_addr)
846854
self._happy_eyeballs_delay = happy_eyeballs_delay
@@ -849,8 +857,8 @@ def __init__(
849857

850858
def close(self) -> Awaitable[None]:
851859
"""Close all ongoing DNS calls."""
852-
for ev in self._throttle_dns_events.values():
853-
ev.cancel()
860+
for fut in chain.from_iterable(self._throttle_dns_futures.values()):
861+
fut.cancel()
854862

855863
for t in self._resolve_host_tasks:
856864
t.cancel()
@@ -918,18 +926,35 @@ async def _resolve_host(
918926
await trace.send_dns_cache_hit(host)
919927
return result
920928

929+
futures: Set["asyncio.Future[None]"]
921930
#
922931
# If multiple connectors are resolving the same host, we wait
923932
# for the first one to resolve and then use the result for all of them.
924-
# We use a throttle event to ensure that we only resolve the host once
933+
# We use a throttle to ensure that we only resolve the host once
925934
# and then use the result for all the waiters.
926935
#
936+
if key in self._throttle_dns_futures:
937+
# get futures early, before any await (#4014)
938+
futures = self._throttle_dns_futures[key]
939+
future: asyncio.Future[None] = self._loop.create_future()
940+
futures.add(future)
941+
if traces:
942+
for trace in traces:
943+
await trace.send_dns_cache_hit(host)
944+
try:
945+
await future
946+
finally:
947+
futures.discard(future)
948+
return self._cached_hosts.next_addrs(key)
949+
950+
# update dict early, before any await (#4014)
951+
self._throttle_dns_futures[key] = futures = set()
927952
# In this case we need to create a task to ensure that we can shield
928953
# the task from cancellation as cancelling this lookup should not cancel
929954
# the underlying lookup or else the cancel event will get broadcast to
930955
# all the waiters across all connections.
931956
#
932-
coro = self._resolve_host_with_throttle(key, host, port, traces)
957+
coro = self._resolve_host_with_throttle(key, host, port, futures, traces)
933958
loop = asyncio.get_running_loop()
934959
if sys.version_info >= (3, 12):
935960
# Optimization for Python 3.12, try to send immediately
@@ -957,42 +982,39 @@ async def _resolve_host_with_throttle(
957982
key: Tuple[str, int],
958983
host: str,
959984
port: int,
985+
futures: Set["asyncio.Future[None]"],
960986
traces: Optional[Sequence["Trace"]],
961987
) -> List[ResolveResult]:
962-
"""Resolve host with a dns events throttle."""
963-
if key in self._throttle_dns_events:
964-
# get event early, before any await (#4014)
965-
event = self._throttle_dns_events[key]
988+
"""Resolve host and set result for all waiters.
989+
990+
This method must be run in a task and shielded from cancellation
991+
to avoid cancelling the underlying lookup.
992+
"""
993+
if traces:
994+
for trace in traces:
995+
await trace.send_dns_cache_miss(host)
996+
try:
966997
if traces:
967998
for trace in traces:
968-
await trace.send_dns_cache_hit(host)
969-
await event.wait()
970-
else:
971-
# update dict early, before any await (#4014)
972-
self._throttle_dns_events[key] = EventResultOrError(self._loop)
999+
await trace.send_dns_resolvehost_start(host)
1000+
1001+
addrs = await self._resolver.resolve(host, port, family=self._family)
9731002
if traces:
9741003
for trace in traces:
975-
await trace.send_dns_cache_miss(host)
976-
try:
977-
978-
if traces:
979-
for trace in traces:
980-
await trace.send_dns_resolvehost_start(host)
981-
982-
addrs = await self._resolver.resolve(host, port, family=self._family)
983-
if traces:
984-
for trace in traces:
985-
await trace.send_dns_resolvehost_end(host)
1004+
await trace.send_dns_resolvehost_end(host)
9861005

987-
self._cached_hosts.add(key, addrs)
988-
self._throttle_dns_events[key].set()
989-
except BaseException as e:
990-
# any DNS exception, independently of the implementation
991-
# is set for the waiters to raise the same exception.
992-
self._throttle_dns_events[key].set(exc=e)
993-
raise
994-
finally:
995-
self._throttle_dns_events.pop(key)
1006+
self._cached_hosts.add(key, addrs)
1007+
for fut in futures:
1008+
set_result(fut, None)
1009+
except BaseException as e:
1010+
# any DNS exception is set for the waiters to raise the same exception.
1011+
# This coro is always run in task that is shielded from cancellation so
1012+
# we should never be propagating cancellation here.
1013+
for fut in futures:
1014+
set_exception(fut, e)
1015+
raise
1016+
finally:
1017+
self._throttle_dns_futures.pop(key)
9961018

9971019
return self._cached_hosts.next_addrs(key)
9981020

Diff for: aiohttp/locks.py

-41
This file was deleted.

0 commit comments

Comments
 (0)