From 5208bbb6c8690ddee8ba949f387f9e51b9b3d3aa Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Tue, 26 Mar 2024 11:26:16 -1000 Subject: [PATCH 1/3] Add support for getaddrinfo fixes #23 --- aiodns/__init__.py | 6 ++++++ tests.py | 20 ++++++++++++++++++++ 2 files changed, 26 insertions(+) diff --git a/aiodns/__init__.py b/aiodns/__init__.py index 2574514..85e1416 100644 --- a/aiodns/__init__.py +++ b/aiodns/__init__.py @@ -110,6 +110,12 @@ def gethostbyname(self, host: str, family: socket.AddressFamily) -> asyncio.Futu cb = functools.partial(self._callback, fut) self._channel.gethostbyname(host, family, cb) return fut + + def getaddrinfo(self, host: str, family: socket.AddressFamily = 0, port: Optional[int] = None, proto: int = 0, type: int = 0, flags: int = 0) -> asyncio.Future: + fut = asyncio.Future(loop=self.loop) # type: asyncio.Future + cb = functools.partial(self._callback, fut) + self._channel.getaddrinfo(host, port, cb, family=family, type=type, proto=proto, flags=flags) + return fut def gethostbyaddr(self, name: str) -> asyncio.Future: fut = asyncio.Future(loop=self.loop) # type: asyncio.Future diff --git a/tests.py b/tests.py index 635f7d0..7de75f6 100755 --- a/tests.py +++ b/tests.py @@ -151,6 +151,26 @@ def test_gethostbyname(self): result = self.loop.run_until_complete(f) self.assertTrue(result) + def test_getaddrinfo_address_family_0(self): + f = self.resolver.getaddrinfo('google.com') + result = self.loop.run_until_complete(f) + self.assertTrue(result) + self.assertTrue(len(result.nodes) > 1) + + def test_getaddrinfo_address_family_af_inet(self): + f = self.resolver.getaddrinfo('google.com', socket.AF_INET) + result = self.loop.run_until_complete(f) + self.assertTrue(result) + self.assertTrue(len(result.nodes) == 1) + self.assertTrue(result.nodes[0].family == socket.AF_INET) + + def test_getaddrinfo_address_family_af_inet6(self): + f = self.resolver.getaddrinfo('google.com', socket.AF_INET6) + result = self.loop.run_until_complete(f) + self.assertTrue(result) + self.assertTrue(len(result.nodes) == 1) + self.assertTrue(result.nodes[0].family == socket.AF_INET6) + @unittest.skipIf(sys.platform == 'win32', 'skipped on Windows') def test_gethostbyaddr(self): f = self.resolver.gethostbyaddr('127.0.0.1') From 8e0d9c9a6b075a01aac77f2182d5f922d9ba0397 Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Tue, 26 Mar 2024 12:12:13 -1000 Subject: [PATCH 2/3] fix typing --- aiodns/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/aiodns/__init__.py b/aiodns/__init__.py index 85e1416..cb5c77c 100644 --- a/aiodns/__init__.py +++ b/aiodns/__init__.py @@ -111,7 +111,7 @@ def gethostbyname(self, host: str, family: socket.AddressFamily) -> asyncio.Futu self._channel.gethostbyname(host, family, cb) return fut - def getaddrinfo(self, host: str, family: socket.AddressFamily = 0, port: Optional[int] = None, proto: int = 0, type: int = 0, flags: int = 0) -> asyncio.Future: + def getaddrinfo(self, host: str, family: socket.AddressFamily = socket.AF_UNSPEC, port: Optional[int] = None, proto: int = 0, type: int = 0, flags: int = 0) -> asyncio.Future: fut = asyncio.Future(loop=self.loop) # type: asyncio.Future cb = functools.partial(self._callback, fut) self._channel.getaddrinfo(host, port, cb, family=family, type=type, proto=proto, flags=flags) From d9254918e9b3c9745f706f849f9e0d41a359041e Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Tue, 26 Mar 2024 12:20:32 -1000 Subject: [PATCH 3/3] adjust test for macos behavior difference --- tests.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/tests.py b/tests.py index 7de75f6..bf1bdb9 100755 --- a/tests.py +++ b/tests.py @@ -161,15 +161,13 @@ def test_getaddrinfo_address_family_af_inet(self): f = self.resolver.getaddrinfo('google.com', socket.AF_INET) result = self.loop.run_until_complete(f) self.assertTrue(result) - self.assertTrue(len(result.nodes) == 1) - self.assertTrue(result.nodes[0].family == socket.AF_INET) + self.assertTrue(all(node.family == socket.AF_INET for node in result.nodes)) def test_getaddrinfo_address_family_af_inet6(self): f = self.resolver.getaddrinfo('google.com', socket.AF_INET6) result = self.loop.run_until_complete(f) self.assertTrue(result) - self.assertTrue(len(result.nodes) == 1) - self.assertTrue(result.nodes[0].family == socket.AF_INET6) + self.assertTrue(all(node.family == socket.AF_INET6 for node in result.nodes)) @unittest.skipIf(sys.platform == 'win32', 'skipped on Windows') def test_gethostbyaddr(self):