Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion hathor/p2p/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,7 +248,8 @@ def do_discovery(self) -> None:
Do a discovery and connect on all discovery strategies.
"""
for peer_discovery in self.peer_discoveries:
peer_discovery.discover_and_connect(self.connect_to)
coro = peer_discovery.discover_and_connect(self.connect_to)
Deferred.fromCoroutine(coro)

def disable_rate_limiter(self) -> None:
"""Disable global rate limiter."""
Expand Down
20 changes: 10 additions & 10 deletions hathor/p2p/peer_discovery.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,13 @@

import socket
from abc import ABC, abstractmethod
from typing import Any, Callable, Generator
from typing import Callable

from structlog import get_logger
from twisted.internet import defer
from twisted.internet.defer import inlineCallbacks
from twisted.names.client import lookupAddress, lookupText
from twisted.names.dns import Record_A, Record_TXT, RRHeader
from typing_extensions import override

logger = get_logger()

Expand All @@ -30,7 +30,7 @@ class PeerDiscovery(ABC):
"""

@abstractmethod
def discover_and_connect(self, connect_to: Callable[[str], None]) -> Any:
async def discover_and_connect(self, connect_to: Callable[[str], None]) -> None:
""" This method must discover the peers and call `connect_to` for each of them.

:param connect_to: Function which will be called for each discovered peer.
Expand All @@ -51,7 +51,8 @@ def __init__(self, descriptions: list[str]):
self.log = logger.new()
self.descriptions = descriptions

def discover_and_connect(self, connect_to: Callable[[str], None]) -> Any:
@override
async def discover_and_connect(self, connect_to: Callable[[str], None]) -> None:
for description in self.descriptions:
connect_to(description)

Expand All @@ -70,18 +71,17 @@ def __init__(self, hosts: list[str], default_port: int = 40403, test_mode: int =
self.default_port = default_port
self.test_mode = test_mode

@inlineCallbacks
def discover_and_connect(self, connect_to: Callable[[str], None]) -> Generator[Any, Any, None]:
@override
async def discover_and_connect(self, connect_to: Callable[[str], None]) -> None:
""" Run DNS lookup for host and connect to it
This is executed when starting the DNS Peer Discovery and first connecting to the network
"""
for host in self.hosts:
url_list = yield self.dns_seed_lookup(host)
url_list = await self.dns_seed_lookup(host)
for url in url_list:
connect_to(url)

@inlineCallbacks
def dns_seed_lookup(self, host: str) -> Generator[Any, Any, list[str]]:
async def dns_seed_lookup(self, host: str) -> list[str]:
""" Run a DNS lookup for TXT, A, and AAAA records and return a list of connection strings.
"""
if self.test_mode:
Expand All @@ -97,7 +97,7 @@ def dns_seed_lookup(self, host: str) -> Generator[Any, Any, list[str]]:
d2.addErrback(self.errback),

d = defer.gatherResults([d1, d2])
results = yield d
results = await d
unique_urls: set[str] = set()
for urls in results:
unique_urls.update(urls)
Expand Down
10 changes: 4 additions & 6 deletions hathor/p2p/peer_id.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,14 @@
import hashlib
from enum import Enum
from math import inf
from typing import TYPE_CHECKING, Any, Generator, Optional, cast
from typing import TYPE_CHECKING, Any, Optional, cast

from cryptography import x509
from cryptography.exceptions import InvalidSignature
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import hashes, serialization
from cryptography.hazmat.primitives.asymmetric import padding, rsa
from OpenSSL.crypto import X509, PKey
from twisted.internet.defer import inlineCallbacks
from twisted.internet.interfaces import ISSLTransport
from twisted.internet.ssl import Certificate, CertificateOptions, TLSVersion, trustRootFromCertificates

Expand Down Expand Up @@ -324,8 +323,7 @@ def _get_certificate_options(self) -> CertificateOptions:
)
return certificate_options

@inlineCallbacks
def validate_entrypoint(self, protocol: 'HathorProtocol') -> Generator[Any, Any, bool]:
async def validate_entrypoint(self, protocol: 'HathorProtocol') -> bool:
""" Validates if connection entrypoint is one of the peer entrypoints
"""
found_entrypoint = False
Expand All @@ -349,7 +347,7 @@ def validate_entrypoint(self, protocol: 'HathorProtocol') -> Generator[Any, Any,
host = connection_string_to_host(entrypoint)
# TODO: don't use `daa.TEST_MODE` for this
test_mode = not_none(DifficultyAdjustmentAlgorithm.singleton).TEST_MODE
result = yield discover_dns(host, test_mode)
result = await discover_dns(host, test_mode)
if protocol.connection_string in result:
# Found the entrypoint
found_entrypoint = True
Expand All @@ -369,7 +367,7 @@ def validate_entrypoint(self, protocol: 'HathorProtocol') -> Generator[Any, Any,
found_entrypoint = True
break
test_mode = not_none(DifficultyAdjustmentAlgorithm.singleton).TEST_MODE
result = yield discover_dns(host, test_mode)
result = await discover_dns(host, test_mode)
if connection_host in [connection_string_to_host(x) for x in result]:
# Found the entrypoint
found_entrypoint = True
Expand Down
5 changes: 3 additions & 2 deletions hathor/p2p/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

import time
from enum import Enum
from typing import TYPE_CHECKING, Any, Generator, Optional, cast
from typing import TYPE_CHECKING, Any, Coroutine, Generator, Optional, cast

from structlog import get_logger
from twisted.internet.defer import Deferred
Expand Down Expand Up @@ -311,7 +311,8 @@ def recv_message(self, cmd: ProtocolMessages, payload: str) -> Optional[Deferred
fn = self.state.cmd_map.get(cmd)
if fn is not None:
try:
return fn(payload)
result = fn(payload)
return Deferred.fromCoroutine(result) if isinstance(result, Coroutine) else result
except Exception:
self.log.warn('recv_message processing error', exc_info=True)
raise
Expand Down
8 changes: 6 additions & 2 deletions hathor/p2p/states/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import TYPE_CHECKING, Callable, Optional, Union
from collections.abc import Coroutine
from typing import TYPE_CHECKING, Any, Callable, Optional

from structlog import get_logger
from twisted.internet.defer import Deferred
Expand All @@ -27,7 +28,10 @@

class BaseState:
protocol: 'HathorProtocol'
cmd_map: dict[ProtocolMessages, Union[Callable[[str], None], Callable[[str], Deferred[None]]]]
cmd_map: dict[
ProtocolMessages,
Callable[[str], None] | Callable[[str], Deferred[None]] | Callable[[str], Coroutine[Deferred[None], Any, None]]
]

def __init__(self, protocol: 'HathorProtocol'):
self.log = logger.new(**protocol.get_logger_context())
Expand Down
8 changes: 3 additions & 5 deletions hathor/p2p/states/peer_id.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import TYPE_CHECKING, Any, Generator
from typing import TYPE_CHECKING

from structlog import get_logger
from twisted.internet.defer import inlineCallbacks

from hathor.conf import HathorSettings
from hathor.p2p.messages import ProtocolMessages
Expand Down Expand Up @@ -77,8 +76,7 @@ def send_peer_id(self) -> None:
}
self.send_message(ProtocolMessages.PEER_ID, json_dumps(hello))

@inlineCallbacks
def handle_peer_id(self, payload: str) -> Generator[Any, Any, None]:
async def handle_peer_id(self, payload: str) -> None:
""" Executed when a PEER-ID is received. It basically checks
the identity of the peer. Only after this step, the peer connection
is considered established and ready to communicate.
Expand Down Expand Up @@ -117,7 +115,7 @@ def handle_peer_id(self, payload: str) -> Generator[Any, Any, None]:
protocol.send_error_and_close_connection('We are already connected.')
return

entrypoint_valid = yield peer.validate_entrypoint(protocol)
entrypoint_valid = await peer.validate_entrypoint(protocol)
if not entrypoint_valid:
protocol.send_error_and_close_connection('Connection string is not in the entrypoints.')
return
Expand Down
8 changes: 3 additions & 5 deletions hathor/p2p/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

import datetime
import re
from typing import Any, Generator, Optional
from typing import Any, Optional
from urllib.parse import parse_qs, urlparse

import requests
Expand All @@ -25,7 +25,6 @@
from cryptography.hazmat.primitives.serialization import load_pem_private_key
from cryptography.x509 import Certificate
from cryptography.x509.oid import NameOID
from twisted.internet.defer import inlineCallbacks
from twisted.internet.interfaces import IAddress

from hathor.conf.get_settings import get_settings
Expand Down Expand Up @@ -100,15 +99,14 @@ def connection_string_to_host(connection_string: str) -> str:
return urlparse(connection_string).netloc.split(':')[0]


@inlineCallbacks
def discover_dns(host: str, test_mode: int = 0) -> Generator[Any, Any, list[str]]:
async def discover_dns(host: str, test_mode: int = 0) -> list[str]:
""" Start a DNS peer discovery object and execute a search for the host

Returns the DNS string from the requested host
E.g., localhost -> tcp://127.0.0.1:40403
"""
discovery = DNSPeerDiscovery([], test_mode=test_mode)
result = yield discovery.dns_seed_lookup(host)
result = await discovery.dns_seed_lookup(host)
return result


Expand Down
15 changes: 6 additions & 9 deletions tests/p2p/test_peer_id.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,6 @@
import shutil
import tempfile

from twisted.internet.defer import inlineCallbacks

from hathor.conf import HathorSettings
from hathor.p2p.peer_id import InvalidPeerIdException, PeerId
from hathor.p2p.peer_storage import PeerStorage
Expand Down Expand Up @@ -212,24 +210,23 @@ def test_retry_logic(self):
class BasePeerIdTest(unittest.TestCase):
__test__ = False

@inlineCallbacks
def test_validate_entrypoint(self):
async def test_validate_entrypoint(self):
manager = self.create_peer('testnet', unlock_wallet=False)
peer_id = manager.my_peer
peer_id.entrypoints = ['tcp://127.0.0.1:40403']

# we consider that we are starting the connection to the peer
protocol = manager.connections.client_factory.buildProtocol('127.0.0.1')
protocol.connection_string = 'tcp://127.0.0.1:40403'
result = yield peer_id.validate_entrypoint(protocol)
result = await peer_id.validate_entrypoint(protocol)
self.assertTrue(result)
# if entrypoint is an URI
peer_id.entrypoints = ['uri_name']
result = yield peer_id.validate_entrypoint(protocol)
result = await peer_id.validate_entrypoint(protocol)
self.assertTrue(result)
# test invalid. DNS in test mode will resolve to '127.0.0.1:40403'
protocol.connection_string = 'tcp://45.45.45.45:40403'
result = yield peer_id.validate_entrypoint(protocol)
result = await peer_id.validate_entrypoint(protocol)
self.assertFalse(result)

# now test when receiving the connection - i.e. the peer starts it
Expand All @@ -242,11 +239,11 @@ def getPeer(self):
Peer = namedtuple('Peer', 'host')
return Peer(host='127.0.0.1')
protocol.transport = FakeTransport()
result = yield peer_id.validate_entrypoint(protocol)
result = await peer_id.validate_entrypoint(protocol)
self.assertTrue(result)
# if entrypoint is an URI
peer_id.entrypoints = ['uri_name']
result = yield peer_id.validate_entrypoint(protocol)
result = await peer_id.validate_entrypoint(protocol)
self.assertTrue(result)


Expand Down