Skip to content
Merged
78 changes: 50 additions & 28 deletions aiodns/__init__.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
from __future__ import annotations

import asyncio
import contextlib
import functools
import logging
import socket
import sys
import warnings
import weakref
from collections.abc import Callable, Iterable, Sequence
from collections.abc import Callable, Iterable, Iterator, Sequence
from types import TracebackType
from typing import TYPE_CHECKING, Any, Literal, TypeVar, overload

Expand Down Expand Up @@ -158,7 +159,7 @@ def nameservers(self, value: Iterable[str | bytes]) -> None:
def _callback(
self, fut: asyncio.Future[_T], result: _T, errorno: int | None
) -> None:
if fut.cancelled():
if fut.done():
return
if errorno is not None:
fut.set_exception(
Expand Down Expand Up @@ -191,7 +192,7 @@ def _query_callback(
errorno: int | None,
) -> None:
"""Callback for query that converts results to compatible format."""
if fut.cancelled():
if fut.done():
return
if errorno is not None:
fut.set_exception(
Expand All @@ -217,6 +218,21 @@ def _get_query_future_callback(
cb = functools.partial(self._query_callback, future, qtype)
return future, cb

@contextlib.contextmanager
def _capture_ares_error(self, fut: asyncio.Future[_T]) -> Iterator[None]:
# When pycares raises synchronously (e.g. ARES_EBADNAME for a
# malformed hostname), c-ares may also invoke the callback first,
# leaving the future already done. Route the error through the
# future so callers can rely on `await` to raise.
try:
yield
except pycares.AresError as exc:
if fut.done():
return
errno = exc.args[0]
message = exc.args[1] if len(exc.args) > 1 else ''
Comment thread
bdraco marked this conversation as resolved.
Outdated
fut.set_exception(error.DNSError(errno, message))

@overload
def query(
self, host: str, qtype: Literal['A'], qclass: str | None = ...
Expand Down Expand Up @@ -283,12 +299,13 @@ def query(
raise ValueError(f'invalid query class: {qclass}') from e

fut, cb = self._get_query_future_callback(qtype_int)
if qclass_int is not None:
self._channel.query(
host, qtype_int, query_class=qclass_int, callback=cb
)
else:
self._channel.query(host, qtype_int, callback=cb)
with self._capture_ares_error(fut):
if qclass_int is not None:
self._channel.query(
host, qtype_int, query_class=qclass_int, callback=cb
)
else:
self._channel.query(host, qtype_int, callback=cb)
return fut

def query_dns(
Expand All @@ -308,12 +325,13 @@ def query_dns(

fut: asyncio.Future[pycares.DNSResult]
fut, cb = self._get_future_callback()
if qclass_int is not None:
self._channel.query(
host, qtype_int, query_class=qclass_int, callback=cb
)
else:
self._channel.query(host, qtype_int, callback=cb)
with self._capture_ares_error(fut):
if qclass_int is not None:
self._channel.query(
host, qtype_int, query_class=qclass_int, callback=cb
)
else:
self._channel.query(host, qtype_int, callback=cb)
return fut

def _gethostbyname_callback(
Expand All @@ -324,7 +342,7 @@ def _gethostbyname_callback(
errorno: int | None,
) -> None:
"""Callback for gethostbyname that converts AddrInfoResult."""
if fut.cancelled():
if fut.done():
return
if errorno is not None:
fut.set_exception(
Expand Down Expand Up @@ -365,7 +383,8 @@ def gethostbyname(
)
else:
cb = functools.partial(self._gethostbyname_callback, fut, host)
self._channel.getaddrinfo(host, None, family=family, callback=cb)
with self._capture_ares_error(fut):
self._channel.getaddrinfo(host, None, family=family, callback=cb)
return fut

def getaddrinfo(
Expand All @@ -379,15 +398,16 @@ def getaddrinfo(
) -> asyncio.Future[pycares.AddrInfoResult]:
fut: asyncio.Future[pycares.AddrInfoResult]
fut, cb = self._get_future_callback()
self._channel.getaddrinfo(
host,
port,
family=family,
type=type,
proto=proto,
flags=flags,
callback=cb,
)
with self._capture_ares_error(fut):
self._channel.getaddrinfo(
host,
port,
family=family,
type=type,
proto=proto,
flags=flags,
callback=cb,
)
return fut

def getnameinfo(
Expand All @@ -397,13 +417,15 @@ def getnameinfo(
) -> asyncio.Future[pycares.NameInfoResult]:
fut: asyncio.Future[pycares.NameInfoResult]
fut, cb = self._get_future_callback()
self._channel.getnameinfo(sockaddr, flags, callback=cb)
with self._capture_ares_error(fut):
self._channel.getnameinfo(sockaddr, flags, callback=cb)
return fut

def gethostbyaddr(self, name: str) -> asyncio.Future[pycares.HostResult]:
fut: asyncio.Future[pycares.HostResult]
fut, cb = self._get_future_callback()
self._channel.gethostbyaddr(name, callback=cb)
with self._capture_ares_error(fut):
self._channel.gethostbyaddr(name, callback=cb)
return fut

def cancel(self) -> None:
Expand Down
50 changes: 50 additions & 0 deletions tests/test_aiodns.py
Original file line number Diff line number Diff line change
Expand Up @@ -1398,5 +1398,55 @@ async def test_query_callback_error() -> None:
resolver._closed = True


async def _assert_malformed_name_routes_through_future(
fut: asyncio.Future[Any],
) -> None:
assert isinstance(fut, asyncio.Future)
with pytest.raises(aiodns.error.DNSError) as exc_info:
await fut
assert exc_info.value.args[0] == aiodns.error.ARES_EBADNAME


@pytest.mark.asyncio
async def test_query_dns_malformed_name_routes_through_future() -> None:
"""Synchronous pycares.AresError must be routed through the future.

Regression test for https://github.com/aio-libs/aiodns/issues/231:
previously a malformed name raised AresError synchronously, leaving
the internally-created future orphaned with an unretrieved exception.
"""
async with aiodns.DNSResolver() as resolver:
await _assert_malformed_name_routes_through_future(
resolver.query_dns('example test.com', 'A')
)


@pytest.mark.asyncio
async def test_query_malformed_name_routes_through_future() -> None:
"""Same as above for the deprecated query() entry point."""
async with aiodns.DNSResolver() as resolver:
with warnings.catch_warnings():
warnings.simplefilter('ignore', DeprecationWarning)
fut = resolver.query('example test.com', 'A')
await _assert_malformed_name_routes_through_future(fut)


@pytest.mark.asyncio
async def test_capture_ares_error_leaves_done_future_untouched() -> None:
"""When the callback already finished the future, the captured
AresError must be discarded so the original result is preserved."""
async with aiodns.DNSResolver() as resolver:
fut: asyncio.Future[None] = asyncio.get_running_loop().create_future()
fut.set_result(None)
exc = pycares.AresError(
aiodns.error.ARES_EBADNAME, 'Misformatted domain name'
)
cm = resolver._capture_ares_error(fut)
cm.__enter__()
suppressed = cm.__exit__(type(exc), exc, exc.__traceback__)
assert suppressed
Comment thread
bdraco marked this conversation as resolved.
Outdated
assert fut.result() is None


if __name__ == '__main__': # pragma: no cover
unittest.main(verbosity=2)
Loading