Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
87 changes: 59 additions & 28 deletions aiodns/__init__.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
from __future__ import annotations

import asyncio
import contextlib
import functools
import logging
import socket
import sys
import warnings
import weakref
from collections.abc import Callable, Iterable, Sequence
from collections.abc import Callable, Iterable, Iterator, Sequence
from types import TracebackType
from typing import TYPE_CHECKING, Any, Literal, TypeVar, overload

Expand Down Expand Up @@ -158,7 +159,10 @@ def nameservers(self, value: Iterable[str | bytes]) -> None:
def _callback(
self, fut: asyncio.Future[_T], result: _T, errorno: int | None
) -> None:
if fut.cancelled():
# The future can already be done if pycares raised synchronously
# and _capture_ares_error set the exception before c-ares delivered
# the same error through this callback.
if fut.done():
return
if errorno is not None:
fut.set_exception(
Expand Down Expand Up @@ -191,7 +195,8 @@ def _query_callback(
errorno: int | None,
) -> None:
"""Callback for query that converts results to compatible format."""
if fut.cancelled():
# See _callback for why we guard on done() rather than cancelled().
if fut.done():
return
if errorno is not None:
fut.set_exception(
Expand All @@ -217,6 +222,25 @@ def _get_query_future_callback(
cb = functools.partial(self._query_callback, future, qtype)
return future, cb

@contextlib.contextmanager
def _capture_ares_error(self, fut: asyncio.Future[_T]) -> Iterator[None]:
# When pycares raises synchronously (e.g. ARES_EBADNAME for a
# malformed hostname), c-ares may also invoke the callback first,
# leaving the future already done. Route the error through the
# future so callers can rely on `await` to raise.
try:
yield
except pycares.AresError as exc:
if fut.done():
return
# pycares always raises (errno, message), but be defensive:
# an args-less AresError should still resolve the future to
# avoid an indefinite hang on `await`.
errno = exc.args[0] if exc.args else error.ARES_EFORMERR
fut.set_exception(
error.DNSError(errno, pycares.errno.strerror(errno))
)

@overload
def query(
self, host: str, qtype: Literal['A'], qclass: str | None = ...
Expand Down Expand Up @@ -283,12 +307,13 @@ def query(
raise ValueError(f'invalid query class: {qclass}') from e

fut, cb = self._get_query_future_callback(qtype_int)
if qclass_int is not None:
self._channel.query(
host, qtype_int, query_class=qclass_int, callback=cb
)
else:
self._channel.query(host, qtype_int, callback=cb)
with self._capture_ares_error(fut):
if qclass_int is not None:
self._channel.query(
host, qtype_int, query_class=qclass_int, callback=cb
)
else:
self._channel.query(host, qtype_int, callback=cb)
return fut

def query_dns(
Expand All @@ -308,12 +333,13 @@ def query_dns(

fut: asyncio.Future[pycares.DNSResult]
fut, cb = self._get_future_callback()
if qclass_int is not None:
self._channel.query(
host, qtype_int, query_class=qclass_int, callback=cb
)
else:
self._channel.query(host, qtype_int, callback=cb)
with self._capture_ares_error(fut):
if qclass_int is not None:
self._channel.query(
host, qtype_int, query_class=qclass_int, callback=cb
)
else:
self._channel.query(host, qtype_int, callback=cb)
return fut

def _gethostbyname_callback(
Expand All @@ -324,7 +350,8 @@ def _gethostbyname_callback(
errorno: int | None,
) -> None:
"""Callback for gethostbyname that converts AddrInfoResult."""
if fut.cancelled():
# See _callback for why we guard on done() rather than cancelled().
if fut.done():
return
if errorno is not None:
fut.set_exception(
Expand Down Expand Up @@ -365,7 +392,8 @@ def gethostbyname(
)
else:
cb = functools.partial(self._gethostbyname_callback, fut, host)
self._channel.getaddrinfo(host, None, family=family, callback=cb)
with self._capture_ares_error(fut):
self._channel.getaddrinfo(host, None, family=family, callback=cb)
return fut

def getaddrinfo(
Expand All @@ -379,15 +407,16 @@ def getaddrinfo(
) -> asyncio.Future[pycares.AddrInfoResult]:
fut: asyncio.Future[pycares.AddrInfoResult]
fut, cb = self._get_future_callback()
self._channel.getaddrinfo(
host,
port,
family=family,
type=type,
proto=proto,
flags=flags,
callback=cb,
)
with self._capture_ares_error(fut):
self._channel.getaddrinfo(
host,
port,
family=family,
type=type,
proto=proto,
flags=flags,
callback=cb,
)
return fut

def getnameinfo(
Expand All @@ -397,13 +426,15 @@ def getnameinfo(
) -> asyncio.Future[pycares.NameInfoResult]:
fut: asyncio.Future[pycares.NameInfoResult]
fut, cb = self._get_future_callback()
self._channel.getnameinfo(sockaddr, flags, callback=cb)
with self._capture_ares_error(fut):
self._channel.getnameinfo(sockaddr, flags, callback=cb)
return fut

def gethostbyaddr(self, name: str) -> asyncio.Future[pycares.HostResult]:
fut: asyncio.Future[pycares.HostResult]
fut, cb = self._get_future_callback()
self._channel.gethostbyaddr(name, callback=cb)
with self._capture_ares_error(fut):
self._channel.gethostbyaddr(name, callback=cb)
return fut

def cancel(self) -> None:
Expand Down
83 changes: 83 additions & 0 deletions tests/test_aiodns.py
Original file line number Diff line number Diff line change
Expand Up @@ -1398,5 +1398,88 @@ async def test_query_callback_error() -> None:
resolver._closed = True


async def _assert_malformed_name_routes_through_future(
fut: asyncio.Future[Any],
) -> None:
assert isinstance(fut, asyncio.Future)
with pytest.raises(aiodns.error.DNSError) as exc_info:
await fut
assert exc_info.value.args[0] == aiodns.error.ARES_EBADNAME


@pytest.mark.asyncio
async def test_query_dns_malformed_name_routes_through_future() -> None:
"""Synchronous pycares.AresError must be routed through the future.

Regression test for https://github.com/aio-libs/aiodns/issues/231:
previously a malformed name raised AresError synchronously, leaving
the internally-created future orphaned with an unretrieved exception.
"""
async with aiodns.DNSResolver() as resolver:
await _assert_malformed_name_routes_through_future(
resolver.query_dns('example test.com', 'A')
)


@pytest.mark.asyncio
async def test_query_malformed_name_routes_through_future() -> None:
"""Same as above for the deprecated query() entry point."""
async with aiodns.DNSResolver() as resolver:
with warnings.catch_warnings():
warnings.simplefilter('ignore', DeprecationWarning)
fut = resolver.query('example test.com', 'A')
await _assert_malformed_name_routes_through_future(fut)


def _call_resolver_entry_point(
resolver: aiodns.DNSResolver, channel_method: str
) -> asyncio.Future[Any]:
if channel_method == 'getaddrinfo':
return resolver.getaddrinfo('host')
if channel_method == 'getnameinfo':
return resolver.getnameinfo(('127.0.0.1', 0))
assert channel_method == 'gethostbyaddr'
return resolver.gethostbyaddr('127.0.0.1')


@pytest.mark.asyncio
@pytest.mark.parametrize(
'channel_method', ['getaddrinfo', 'getnameinfo', 'gethostbyaddr']
)
async def test_wrapped_entry_points_route_sync_ares_error(
channel_method: str,
) -> None:
"""Each wrapper routes a synchronous AresError to the returned future.

pycares does not currently validate inputs to these entry points
synchronously, so we inject an AresError via mock to prove the
wrapper is wired up and would not regress to a sync raise.
"""
async with aiodns.DNSResolver() as resolver:
exc = pycares.AresError(
aiodns.error.ARES_EBADNAME, 'Misformatted domain name'
)
with unittest.mock.patch.object(
resolver._channel, channel_method, side_effect=exc
):
fut = _call_resolver_entry_point(resolver, channel_method)
await _assert_malformed_name_routes_through_future(fut)


@pytest.mark.asyncio
async def test_capture_ares_error_leaves_done_future_untouched() -> None:
"""When the callback already finished the future, the captured
AresError must be discarded; reaching the body of the context
manager would call set_exception() on a done future and raise
InvalidStateError, so passing through cleanly is the assertion."""
async with aiodns.DNSResolver() as resolver:
fut: asyncio.Future[None] = asyncio.get_running_loop().create_future()
fut.set_result(None)
with resolver._capture_ares_error(fut):
raise pycares.AresError(
aiodns.error.ARES_EBADNAME, 'Misformatted domain name'
)


if __name__ == '__main__': # pragma: no cover
unittest.main(verbosity=2)
Loading