9
9
from contextlib import suppress
10
10
from http import HTTPStatus
11
11
from http .cookies import SimpleCookie
12
- from itertools import cycle , islice
12
+ from itertools import chain , cycle , islice
13
13
from time import monotonic
14
14
from types import TracebackType
15
15
from typing import (
50
50
)
51
51
from .client_proto import ResponseHandler
52
52
from .client_reqrep import ClientRequest , Fingerprint , _merge_ssl_params
53
- from .helpers import ceil_timeout , is_ip_address , noop , sentinel
54
- from .locks import EventResultOrError
53
+ from .helpers import (
54
+ ceil_timeout ,
55
+ is_ip_address ,
56
+ noop ,
57
+ sentinel ,
58
+ set_exception ,
59
+ set_result ,
60
+ )
55
61
from .resolver import DefaultResolver
56
62
57
63
try :
@@ -840,7 +846,9 @@ def __init__(
840
846
841
847
self ._use_dns_cache = use_dns_cache
842
848
self ._cached_hosts = _DNSCacheTable (ttl = ttl_dns_cache )
843
- self ._throttle_dns_events : Dict [Tuple [str , int ], EventResultOrError ] = {}
849
+ self ._throttle_dns_futures : Dict [
850
+ Tuple [str , int ], Set ["asyncio.Future[None]" ]
851
+ ] = {}
844
852
self ._family = family
845
853
self ._local_addr_infos = aiohappyeyeballs .addr_to_addr_infos (local_addr )
846
854
self ._happy_eyeballs_delay = happy_eyeballs_delay
@@ -849,8 +857,8 @@ def __init__(
849
857
850
858
def close (self ) -> Awaitable [None ]:
851
859
"""Close all ongoing DNS calls."""
852
- for ev in self ._throttle_dns_events .values ():
853
- ev .cancel ()
860
+ for fut in chain . from_iterable ( self ._throttle_dns_futures .values () ):
861
+ fut .cancel ()
854
862
855
863
for t in self ._resolve_host_tasks :
856
864
t .cancel ()
@@ -918,18 +926,35 @@ async def _resolve_host(
918
926
await trace .send_dns_cache_hit (host )
919
927
return result
920
928
929
+ futures : Set ["asyncio.Future[None]" ]
921
930
#
922
931
# If multiple connectors are resolving the same host, we wait
923
932
# for the first one to resolve and then use the result for all of them.
924
- # We use a throttle event to ensure that we only resolve the host once
933
+ # We use a throttle to ensure that we only resolve the host once
925
934
# and then use the result for all the waiters.
926
935
#
936
+ if key in self ._throttle_dns_futures :
937
+ # get futures early, before any await (#4014)
938
+ futures = self ._throttle_dns_futures [key ]
939
+ future : asyncio .Future [None ] = self ._loop .create_future ()
940
+ futures .add (future )
941
+ if traces :
942
+ for trace in traces :
943
+ await trace .send_dns_cache_hit (host )
944
+ try :
945
+ await future
946
+ finally :
947
+ futures .discard (future )
948
+ return self ._cached_hosts .next_addrs (key )
949
+
950
+ # update dict early, before any await (#4014)
951
+ self ._throttle_dns_futures [key ] = futures = set ()
927
952
# In this case we need to create a task to ensure that we can shield
928
953
# the task from cancellation as cancelling this lookup should not cancel
929
954
# the underlying lookup or else the cancel event will get broadcast to
930
955
# all the waiters across all connections.
931
956
#
932
- coro = self ._resolve_host_with_throttle (key , host , port , traces )
957
+ coro = self ._resolve_host_with_throttle (key , host , port , futures , traces )
933
958
loop = asyncio .get_running_loop ()
934
959
if sys .version_info >= (3 , 12 ):
935
960
# Optimization for Python 3.12, try to send immediately
@@ -957,42 +982,39 @@ async def _resolve_host_with_throttle(
957
982
key : Tuple [str , int ],
958
983
host : str ,
959
984
port : int ,
985
+ futures : Set ["asyncio.Future[None]" ],
960
986
traces : Optional [Sequence ["Trace" ]],
961
987
) -> List [ResolveResult ]:
962
- """Resolve host with a dns events throttle."""
963
- if key in self ._throttle_dns_events :
964
- # get event early, before any await (#4014)
965
- event = self ._throttle_dns_events [key ]
988
+ """Resolve host and set result for all waiters.
989
+
990
+ This method must be run in a task and shielded from cancellation
991
+ to avoid cancelling the underlying lookup.
992
+ """
993
+ if traces :
994
+ for trace in traces :
995
+ await trace .send_dns_cache_miss (host )
996
+ try :
966
997
if traces :
967
998
for trace in traces :
968
- await trace .send_dns_cache_hit (host )
969
- await event .wait ()
970
- else :
971
- # update dict early, before any await (#4014)
972
- self ._throttle_dns_events [key ] = EventResultOrError (self ._loop )
999
+ await trace .send_dns_resolvehost_start (host )
1000
+
1001
+ addrs = await self ._resolver .resolve (host , port , family = self ._family )
973
1002
if traces :
974
1003
for trace in traces :
975
- await trace .send_dns_cache_miss (host )
976
- try :
977
-
978
- if traces :
979
- for trace in traces :
980
- await trace .send_dns_resolvehost_start (host )
981
-
982
- addrs = await self ._resolver .resolve (host , port , family = self ._family )
983
- if traces :
984
- for trace in traces :
985
- await trace .send_dns_resolvehost_end (host )
1004
+ await trace .send_dns_resolvehost_end (host )
986
1005
987
- self ._cached_hosts .add (key , addrs )
988
- self ._throttle_dns_events [key ].set ()
989
- except BaseException as e :
990
- # any DNS exception, independently of the implementation
991
- # is set for the waiters to raise the same exception.
992
- self ._throttle_dns_events [key ].set (exc = e )
993
- raise
994
- finally :
995
- self ._throttle_dns_events .pop (key )
1006
+ self ._cached_hosts .add (key , addrs )
1007
+ for fut in futures :
1008
+ set_result (fut , None )
1009
+ except BaseException as e :
1010
+ # any DNS exception is set for the waiters to raise the same exception.
1011
+ # This coro is always run in task that is shielded from cancellation so
1012
+ # we should never be propagating cancellation here.
1013
+ for fut in futures :
1014
+ set_exception (fut , e )
1015
+ raise
1016
+ finally :
1017
+ self ._throttle_dns_futures .pop (key )
996
1018
997
1019
return self ._cached_hosts .next_addrs (key )
998
1020
0 commit comments