diff --git a/aiodns/__init__.py b/aiodns/__init__.py index ec394b9..3380701 100644 --- a/aiodns/__init__.py +++ b/aiodns/__init__.py @@ -1,13 +1,14 @@ from __future__ import annotations import asyncio +import contextlib import functools import logging import socket import sys import warnings import weakref -from collections.abc import Callable, Iterable, Sequence +from collections.abc import Callable, Iterable, Iterator, Sequence from types import TracebackType from typing import TYPE_CHECKING, Any, Literal, TypeVar, overload @@ -158,7 +159,10 @@ def nameservers(self, value: Iterable[str | bytes]) -> None: def _callback( self, fut: asyncio.Future[_T], result: _T, errorno: int | None ) -> None: - if fut.cancelled(): + # The future can already be done if pycares raised synchronously + # and _capture_ares_error set the exception before c-ares delivered + # the same error through this callback. + if fut.done(): return if errorno is not None: fut.set_exception( @@ -191,7 +195,8 @@ def _query_callback( errorno: int | None, ) -> None: """Callback for query that converts results to compatible format.""" - if fut.cancelled(): + # See _callback for why we guard on done() rather than cancelled(). + if fut.done(): return if errorno is not None: fut.set_exception( @@ -217,6 +222,25 @@ def _get_query_future_callback( cb = functools.partial(self._query_callback, future, qtype) return future, cb + @contextlib.contextmanager + def _capture_ares_error(self, fut: asyncio.Future[_T]) -> Iterator[None]: + # When pycares raises synchronously (e.g. ARES_EBADNAME for a + # malformed hostname), c-ares may also invoke the callback first, + # leaving the future already done. Route the error through the + # future so callers can rely on `await` to raise. + try: + yield + except pycares.AresError as exc: + if fut.done(): + return + # pycares always raises (errno, message), but be defensive: + # an args-less AresError should still resolve the future to + # avoid an indefinite hang on `await`. + errno = exc.args[0] if exc.args else error.ARES_EFORMERR + fut.set_exception( + error.DNSError(errno, pycares.errno.strerror(errno)) + ) + @overload def query( self, host: str, qtype: Literal['A'], qclass: str | None = ... @@ -283,12 +307,13 @@ def query( raise ValueError(f'invalid query class: {qclass}') from e fut, cb = self._get_query_future_callback(qtype_int) - if qclass_int is not None: - self._channel.query( - host, qtype_int, query_class=qclass_int, callback=cb - ) - else: - self._channel.query(host, qtype_int, callback=cb) + with self._capture_ares_error(fut): + if qclass_int is not None: + self._channel.query( + host, qtype_int, query_class=qclass_int, callback=cb + ) + else: + self._channel.query(host, qtype_int, callback=cb) return fut def query_dns( @@ -308,12 +333,13 @@ def query_dns( fut: asyncio.Future[pycares.DNSResult] fut, cb = self._get_future_callback() - if qclass_int is not None: - self._channel.query( - host, qtype_int, query_class=qclass_int, callback=cb - ) - else: - self._channel.query(host, qtype_int, callback=cb) + with self._capture_ares_error(fut): + if qclass_int is not None: + self._channel.query( + host, qtype_int, query_class=qclass_int, callback=cb + ) + else: + self._channel.query(host, qtype_int, callback=cb) return fut def _gethostbyname_callback( @@ -324,7 +350,8 @@ def _gethostbyname_callback( errorno: int | None, ) -> None: """Callback for gethostbyname that converts AddrInfoResult.""" - if fut.cancelled(): + # See _callback for why we guard on done() rather than cancelled(). + if fut.done(): return if errorno is not None: fut.set_exception( @@ -365,7 +392,8 @@ def gethostbyname( ) else: cb = functools.partial(self._gethostbyname_callback, fut, host) - self._channel.getaddrinfo(host, None, family=family, callback=cb) + with self._capture_ares_error(fut): + self._channel.getaddrinfo(host, None, family=family, callback=cb) return fut def getaddrinfo( @@ -379,15 +407,16 @@ def getaddrinfo( ) -> asyncio.Future[pycares.AddrInfoResult]: fut: asyncio.Future[pycares.AddrInfoResult] fut, cb = self._get_future_callback() - self._channel.getaddrinfo( - host, - port, - family=family, - type=type, - proto=proto, - flags=flags, - callback=cb, - ) + with self._capture_ares_error(fut): + self._channel.getaddrinfo( + host, + port, + family=family, + type=type, + proto=proto, + flags=flags, + callback=cb, + ) return fut def getnameinfo( @@ -397,13 +426,15 @@ def getnameinfo( ) -> asyncio.Future[pycares.NameInfoResult]: fut: asyncio.Future[pycares.NameInfoResult] fut, cb = self._get_future_callback() - self._channel.getnameinfo(sockaddr, flags, callback=cb) + with self._capture_ares_error(fut): + self._channel.getnameinfo(sockaddr, flags, callback=cb) return fut def gethostbyaddr(self, name: str) -> asyncio.Future[pycares.HostResult]: fut: asyncio.Future[pycares.HostResult] fut, cb = self._get_future_callback() - self._channel.gethostbyaddr(name, callback=cb) + with self._capture_ares_error(fut): + self._channel.gethostbyaddr(name, callback=cb) return fut def cancel(self) -> None: diff --git a/tests/test_aiodns.py b/tests/test_aiodns.py index 914fc91..9e2e895 100755 --- a/tests/test_aiodns.py +++ b/tests/test_aiodns.py @@ -1398,5 +1398,88 @@ async def test_query_callback_error() -> None: resolver._closed = True +async def _assert_malformed_name_routes_through_future( + fut: asyncio.Future[Any], +) -> None: + assert isinstance(fut, asyncio.Future) + with pytest.raises(aiodns.error.DNSError) as exc_info: + await fut + assert exc_info.value.args[0] == aiodns.error.ARES_EBADNAME + + +@pytest.mark.asyncio +async def test_query_dns_malformed_name_routes_through_future() -> None: + """Synchronous pycares.AresError must be routed through the future. + + Regression test for https://github.com/aio-libs/aiodns/issues/231: + previously a malformed name raised AresError synchronously, leaving + the internally-created future orphaned with an unretrieved exception. + """ + async with aiodns.DNSResolver() as resolver: + await _assert_malformed_name_routes_through_future( + resolver.query_dns('example test.com', 'A') + ) + + +@pytest.mark.asyncio +async def test_query_malformed_name_routes_through_future() -> None: + """Same as above for the deprecated query() entry point.""" + async with aiodns.DNSResolver() as resolver: + with warnings.catch_warnings(): + warnings.simplefilter('ignore', DeprecationWarning) + fut = resolver.query('example test.com', 'A') + await _assert_malformed_name_routes_through_future(fut) + + +def _call_resolver_entry_point( + resolver: aiodns.DNSResolver, channel_method: str +) -> asyncio.Future[Any]: + if channel_method == 'getaddrinfo': + return resolver.getaddrinfo('host') + if channel_method == 'getnameinfo': + return resolver.getnameinfo(('127.0.0.1', 0)) + assert channel_method == 'gethostbyaddr' + return resolver.gethostbyaddr('127.0.0.1') + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + 'channel_method', ['getaddrinfo', 'getnameinfo', 'gethostbyaddr'] +) +async def test_wrapped_entry_points_route_sync_ares_error( + channel_method: str, +) -> None: + """Each wrapper routes a synchronous AresError to the returned future. + + pycares does not currently validate inputs to these entry points + synchronously, so we inject an AresError via mock to prove the + wrapper is wired up and would not regress to a sync raise. + """ + async with aiodns.DNSResolver() as resolver: + exc = pycares.AresError( + aiodns.error.ARES_EBADNAME, 'Misformatted domain name' + ) + with unittest.mock.patch.object( + resolver._channel, channel_method, side_effect=exc + ): + fut = _call_resolver_entry_point(resolver, channel_method) + await _assert_malformed_name_routes_through_future(fut) + + +@pytest.mark.asyncio +async def test_capture_ares_error_leaves_done_future_untouched() -> None: + """When the callback already finished the future, the captured + AresError must be discarded; reaching the body of the context + manager would call set_exception() on a done future and raise + InvalidStateError, so passing through cleanly is the assertion.""" + async with aiodns.DNSResolver() as resolver: + fut: asyncio.Future[None] = asyncio.get_running_loop().create_future() + fut.set_result(None) + with resolver._capture_ares_error(fut): + raise pycares.AresError( + aiodns.error.ARES_EBADNAME, 'Misformatted domain name' + ) + + if __name__ == '__main__': # pragma: no cover unittest.main(verbosity=2)