diff --git a/aiodns/__init__.py b/aiodns/__init__.py index 046e202..a9cdf2e 100644 --- a/aiodns/__init__.py +++ b/aiodns/__init__.py @@ -55,7 +55,11 @@ def __init__(self, nameservers: Optional[Sequence[str]] = None, raise RuntimeError( 'aiodns needs a SelectorEventLoop on Windows. See more: https://github.com/saghul/aiodns/issues/86') kwargs.pop('sock_state_cb', None) - self._channel = pycares.Channel(sock_state_cb=self._sock_state_cb, **kwargs) + timeout = kwargs.pop('timeout', None) + self._timeout = timeout + self._channel = pycares.Channel(sock_state_cb=self._sock_state_cb, + timeout=timeout, + **kwargs) if nameservers: self.nameservers = nameservers self._read_fds = set() # type: Set[int] @@ -119,7 +123,7 @@ def _sock_state_cb(self, fd: int, readable: bool, writable: bool) -> None: self.loop.add_writer(fd, self._handle_event, fd, WRITE) self._write_fds.add(fd) if self._timer is None: - self._timer = self.loop.call_later(1.0, self._timer_cb) + self._start_timer() else: # socket is now closed if fd in self._read_fds: @@ -146,6 +150,15 @@ def _handle_event(self, fd: int, event: Any) -> None: def _timer_cb(self) -> None: if self._read_fds or self._write_fds: self._channel.process_fd(pycares.ARES_SOCKET_BAD, pycares.ARES_SOCKET_BAD) - self._timer = self.loop.call_later(1.0, self._timer_cb) + self._start_timer() else: self._timer = None + + def _start_timer(self): + timeout = self._timeout + if timeout is None or timeout < 0 or timeout > 1: + timeout = 1 + elif timeout == 0: + timeout = 0.1 + + self._timer = self.loop.call_later(timeout, self._timer_cb) diff --git a/tests.py b/tests.py index 45c93c8..a2e6c61 100755 --- a/tests.py +++ b/tests.py @@ -99,7 +99,7 @@ def test_query_bad_class(self): self.assertRaises(ValueError, self.resolver.query, 'google.com', 'A', "INVALIDCLASS") def test_query_timeout(self): - self.resolver = aiodns.DNSResolver(timeout=0.1, loop=self.loop) + self.resolver = aiodns.DNSResolver(timeout=0.1, tries=1, loop=self.loop) self.resolver.nameservers = ['1.2.3.4'] f = self.resolver.query('google.com', 'A') started = time.monotonic()