diff --git a/hathor/builder/builder.py b/hathor/builder/builder.py index ea3afd8f7..0a107d20e 100644 --- a/hathor/builder/builder.py +++ b/hathor/builder/builder.py @@ -34,6 +34,7 @@ from hathor.indexes import IndexesManager, MemoryIndexesManager, RocksDBIndexesManager from hathor.manager import HathorManager from hathor.mining.cpu_mining_service import CpuMiningService +from hathor.p2p import P2PDependencies from hathor.p2p.manager import ConnectionsManager from hathor.p2p.peer import PrivatePeer from hathor.pubsub import PubSubManager @@ -64,12 +65,10 @@ class SyncSupportLevel(IntEnum): @classmethod def add_factories( cls, - settings: HathorSettingsType, p2p_manager: ConnectionsManager, + dependencies: P2PDependencies, sync_v1_support: 'SyncSupportLevel', sync_v2_support: 'SyncSupportLevel', - vertex_parser: VertexParser, - vertex_handler: VertexHandler, ) -> None: """Adds the sync factory to the manager according to the support level.""" from hathor.p2p.sync_v1.factory import SyncV11Factory @@ -78,18 +77,12 @@ def add_factories( # sync-v1 support: if sync_v1_support > cls.UNAVAILABLE: - p2p_manager.add_sync_factory(SyncVersion.V1_1, SyncV11Factory(p2p_manager, vertex_parser=vertex_parser)) + p2p_manager.add_sync_factory(SyncVersion.V1_1, SyncV11Factory(dependencies)) if sync_v1_support is cls.ENABLED: p2p_manager.enable_sync_version(SyncVersion.V1_1) # sync-v2 support: if sync_v2_support > cls.UNAVAILABLE: - sync_v2_factory = SyncV2Factory( - settings, - p2p_manager, - vertex_parser=vertex_parser, - vertex_handler=vertex_handler, - ) - p2p_manager.add_sync_factory(SyncVersion.V2, sync_v2_factory) + p2p_manager.add_sync_factory(SyncVersion.V2, SyncV2Factory(dependencies)) if sync_v2_support is cls.ENABLED: p2p_manager.enable_sync_version(SyncVersion.V2) @@ -263,7 +256,6 @@ def build(self) -> BuildArtifacts: wallet=wallet, rng=self._rng, checkpoints=self._checkpoints, - capabilities=self._capabilities, environment_info=get_environment_info(self._cmdline, str(peer.id)), bit_signaling_service=bit_signaling_service, verification_service=verification_service, @@ -415,25 +407,31 @@ def _get_or_create_p2p_manager(self) -> ConnectionsManager: return self._p2p_manager enable_ssl = True - reactor = self._get_reactor() my_peer = self._get_peer() - self._p2p_manager = ConnectionsManager( + dependencies = P2PDependencies( + reactor=self._get_reactor(), settings=self._get_or_create_settings(), - reactor=reactor, + vertex_parser=self._get_or_create_vertex_parser(), + tx_storage=self._get_or_create_tx_storage(), + vertex_handler=self._get_or_create_vertex_handler(), + verification_service=self._get_or_create_verification_service(), + capabilities=self._get_or_create_capabilities(), + whitelist_only=False, + ) + + self._p2p_manager = ConnectionsManager( + dependencies=dependencies, my_peer=my_peer, pubsub=self._get_or_create_pubsub(), ssl=enable_ssl, - whitelist_only=False, rng=self._rng, ) SyncSupportLevel.add_factories( - self._get_or_create_settings(), self._p2p_manager, + dependencies, self._sync_v1_support, self._sync_v2_support, - self._get_or_create_vertex_parser(), - self._get_or_create_vertex_handler(), ) return self._p2p_manager @@ -642,6 +640,13 @@ def _get_or_create_poa_block_producer(self) -> PoaBlockProducer | None: return self._poa_block_producer + def _get_or_create_capabilities(self) -> list[str]: + if self._capabilities is None: + settings = self._get_or_create_settings() + self._capabilities = settings.get_default_capabilities() + + return self._capabilities + def use_memory(self) -> 'Builder': self.check_if_can_modify() self._storage_type = StorageType.MEMORY diff --git a/hathor/builder/cli_builder.py b/hathor/builder/cli_builder.py index 464d9b319..d052cbe8f 100644 --- a/hathor/builder/cli_builder.py +++ b/hathor/builder/cli_builder.py @@ -34,6 +34,7 @@ from hathor.indexes import IndexesManager, MemoryIndexesManager, RocksDBIndexesManager from hathor.manager import HathorManager from hathor.mining.cpu_mining_service import CpuMiningService +from hathor.p2p import P2PDependencies from hathor.p2p.manager import ConnectionsManager from hathor.p2p.peer import PrivatePeer from hathor.p2p.peer_endpoint import PeerEndpoint @@ -317,16 +318,7 @@ def create_manager(self, reactor: Reactor) -> HathorManager: ) cpu_mining_service = CpuMiningService() - - p2p_manager = ConnectionsManager( - settings=settings, - reactor=reactor, - my_peer=peer, - pubsub=pubsub, - ssl=True, - whitelist_only=False, - rng=Random(), - ) + capabilities = settings.get_default_capabilities() vertex_handler = VertexHandler( reactor=reactor, @@ -340,13 +332,30 @@ def create_manager(self, reactor: Reactor) -> HathorManager: log_vertex_bytes=self._args.log_vertex_bytes, ) + p2p_dependencies = P2PDependencies( + reactor=reactor, + settings=settings, + vertex_parser=vertex_parser, + tx_storage=tx_storage, + vertex_handler=vertex_handler, + verification_service=verification_service, + whitelist_only=False, + capabilities=capabilities, + ) + + p2p_manager = ConnectionsManager( + dependencies=p2p_dependencies, + my_peer=peer, + pubsub=pubsub, + ssl=True, + rng=Random(), + ) + SyncSupportLevel.add_factories( - settings, p2p_manager, + p2p_dependencies, sync_v1_support, sync_v2_support, - vertex_parser, - vertex_handler, ) from hathor.consensus.poa import PoaBlockProducer, PoaSignerFile diff --git a/hathor/cli/quick_test.py b/hathor/cli/quick_test.py index 2bf6f16fe..2bdd20711 100644 --- a/hathor/cli/quick_test.py +++ b/hathor/cli/quick_test.py @@ -12,14 +12,19 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import os from argparse import ArgumentParser -from typing import Any +from typing import TYPE_CHECKING from structlog import get_logger from hathor.cli.run_node import RunNode +if TYPE_CHECKING: + from hathor.transaction import Vertex + logger = get_logger() @@ -30,18 +35,17 @@ def __init__(self, vertex_handler, manager, n_blocks): self._manager = manager self._n_blocks = n_blocks - def on_new_vertex(self, *args: Any, **kwargs: Any) -> bool: + def on_new_vertex(self, vertex: Vertex, *, fails_silently: bool) -> bool: from hathor.transaction import Block from hathor.transaction.base_transaction import GenericVertex msg: str | None = None - res = self._vertex_handler.on_new_vertex(*args, **kwargs) + res = self._vertex_handler.on_new_vertex(vertex=vertex, fails_silently=fails_silently) if self._n_blocks is None: should_quit = res msg = 'added a tx' else: - vertex = args[0] should_quit = False assert isinstance(vertex, GenericVertex) @@ -77,7 +81,7 @@ def prepare(self, *, register_resources: bool = True) -> None: self.log.info('patching vertex_handler.on_new_vertex to quit on success') p2p_factory = self.manager.connections.get_sync_factory(SyncVersion.V2) assert isinstance(p2p_factory, SyncV2Factory) - p2p_factory.vertex_handler = VertexHandlerWrapper( + p2p_factory.dependencies.vertex_handler = VertexHandlerWrapper( self.manager.vertex_handler, self.manager, self._args.quit_after_n_blocks, diff --git a/hathor/conf/settings.py b/hathor/conf/settings.py index db235f2b7..06418c93d 100644 --- a/hathor/conf/settings.py +++ b/hathor/conf/settings.py @@ -457,6 +457,14 @@ def from_yaml(cls, *, filepath: str) -> 'HathorSettings': validators=_VALIDATORS ) + def get_default_capabilities(self) -> list[str]: + """Return the default capabilities.""" + return [ + self.CAPABILITY_WHITELIST, + self.CAPABILITY_SYNC_VERSION, + self.CAPABILITY_GET_BEST_BLOCKCHAIN + ] + def _parse_checkpoints(checkpoints: Union[dict[int, str], list[Checkpoint]]) -> list[Checkpoint]: """Parse a dictionary of raw checkpoint data into a list of checkpoints.""" diff --git a/hathor/consensus/block_consensus.py b/hathor/consensus/block_consensus.py index 419a66268..c4a7d289b 100644 --- a/hathor/consensus/block_consensus.py +++ b/hathor/consensus/block_consensus.py @@ -432,7 +432,7 @@ def remove_first_block_markers(self, block: Block) -> None: storage = block.storage from hathor.transaction.storage.traversal import BFSTimestampWalk - bfs = BFSTimestampWalk(storage, is_dag_verifications=True, is_left_to_right=False) + bfs = BFSTimestampWalk(storage.get_vertex, is_dag_verifications=True, is_left_to_right=False) for tx in bfs.run(block, skip_root=True): if tx.is_block: bfs.skip_neighbors(tx) @@ -469,7 +469,7 @@ def _score_block_dfs(self, block: BaseTransaction, used: set[bytes], else: from hathor.transaction.storage.traversal import BFSTimestampWalk - bfs = BFSTimestampWalk(storage, is_dag_verifications=True, is_left_to_right=False) + bfs = BFSTimestampWalk(storage.get_vertex, is_dag_verifications=True, is_left_to_right=False) for tx in bfs.run(parent, skip_root=False): assert not tx.is_block diff --git a/hathor/consensus/transaction_consensus.py b/hathor/consensus/transaction_consensus.py index 12d55b270..a5a79e5c0 100644 --- a/hathor/consensus/transaction_consensus.py +++ b/hathor/consensus/transaction_consensus.py @@ -344,7 +344,9 @@ def remove_voided_by(self, tx: Transaction, voided_hash: bytes) -> bool: self.log.debug('remove_voided_by', tx=tx.hash_hex, voided_hash=voided_hash.hex()) - bfs = BFSTimestampWalk(tx.storage, is_dag_funds=True, is_dag_verifications=True, is_left_to_right=True) + bfs = BFSTimestampWalk( + tx.storage.get_vertex, is_dag_funds=True, is_dag_verifications=True, is_left_to_right=True + ) check_list: list[BaseTransaction] = [] for tx2 in bfs.run(tx, skip_root=False): assert tx2.storage is not None @@ -400,7 +402,7 @@ def add_voided_by(self, tx: Transaction, voided_hash: bytes) -> bool: is_dag_verifications = False from hathor.transaction.storage.traversal import BFSTimestampWalk - bfs = BFSTimestampWalk(tx.storage, is_dag_funds=True, is_dag_verifications=is_dag_verifications, + bfs = BFSTimestampWalk(tx.storage.get_vertex, is_dag_funds=True, is_dag_verifications=is_dag_verifications, is_left_to_right=True) check_list: list[Transaction] = [] for tx2 in bfs.run(tx, skip_root=False): diff --git a/hathor/indexes/mempool_tips_index.py b/hathor/indexes/mempool_tips_index.py index 290b4f865..7885be339 100644 --- a/hathor/indexes/mempool_tips_index.py +++ b/hathor/indexes/mempool_tips_index.py @@ -209,7 +209,7 @@ def iter(self, tx_storage: 'TransactionStorage', max_timestamp: Optional[float] def iter_all(self, tx_storage: 'TransactionStorage') -> Iterator[Transaction]: from hathor.transaction.storage.traversal import BFSTimestampWalk - bfs = BFSTimestampWalk(tx_storage, is_dag_verifications=True, is_left_to_right=False) + bfs = BFSTimestampWalk(tx_storage.get_vertex, is_dag_verifications=True, is_left_to_right=False) for tx in bfs.run(self.iter(tx_storage), skip_root=False): assert isinstance(tx, Transaction) if tx.get_metadata().first_block is not None: diff --git a/hathor/manager.py b/hathor/manager.py index cc86dd9dc..aa50e3766 100644 --- a/hathor/manager.py +++ b/hathor/manager.py @@ -111,7 +111,6 @@ def __init__( vertex_parser: VertexParser, hostname: Optional[str] = None, wallet: Optional[BaseWallet] = None, - capabilities: Optional[list[str]] = None, checkpoints: Optional[list[Checkpoint]] = None, rng: Optional[Random] = None, environment_info: Optional[EnvironmentInfo] = None, @@ -230,12 +229,6 @@ def __init__( # List of whitelisted peers self.peers_whitelist: list[PeerId] = [] - # List of capabilities of the peer - if capabilities is not None: - self.capabilities = capabilities - else: - self.capabilities = self.get_default_capabilities() - # This is included in some logs to provide more context self.environment_info = environment_info @@ -246,14 +239,6 @@ def __init__( self.lc_check_sync_state.clock = self.reactor self.lc_check_sync_state_interval = self.CHECK_SYNC_STATE_INTERVAL - def get_default_capabilities(self) -> list[str]: - """Return the default capabilities for this manager.""" - return [ - self._settings.CAPABILITY_WHITELIST, - self._settings.CAPABILITY_SYNC_VERSION, - self._settings.CAPABILITY_GET_BEST_BLOCKCHAIN - ] - def start(self) -> None: """ A factory must be started only once. And it is usually automatically started. """ @@ -986,9 +971,6 @@ def on_new_tx( return success - def has_sync_version_capability(self) -> bool: - return self._settings.CAPABILITY_SYNC_VERSION in self.capabilities - def add_peer_to_whitelist(self, peer_id: PeerId) -> None: if not self._settings.ENABLE_PEER_WHITELIST: return diff --git a/hathor/metrics.py b/hathor/metrics.py index b53752342..4db9060dc 100644 --- a/hathor/metrics.py +++ b/hathor/metrics.py @@ -109,8 +109,8 @@ class Metrics: # Variables to store the last block when we updated the RocksDB storage metrics last_txstorage_data_block: Optional[int] = None - # Peers connected - connected_peers: int = 0 + # Peers ready + ready_peers: int = 0 # Peers handshaking handshaking_peers: int = 0 # Peers connecting @@ -200,7 +200,7 @@ def handle_publish(self, key: HathorEvents, args: EventArguments) -> None: ): peers_connection_metrics: PeerConnectionsMetrics = data["peers_count"] - self.connected_peers = peers_connection_metrics.connected_peers_count + self.ready_peers = peers_connection_metrics.ready_peers_count self.connecting_peers = peers_connection_metrics.connecting_peers_count self.handshaking_peers = peers_connection_metrics.handshaking_peers_count self.known_peers = peers_connection_metrics.known_peers_count @@ -247,24 +247,26 @@ def collect_peer_connection_metrics(self) -> None: """ self.peer_connection_metrics.clear() - for connection in self.connections.connections: - if not connection._peer: + for connection in self.connections.get_connected_peers(): + peer = connection.get_peer_if_set() + if not peer: # A connection without peer will not be able to communicate # So we can just discard it for the sake of the metrics continue + metrics = connection.get_metrics() metric = PeerConnectionMetrics( - connection_string=str(connection.entrypoint) if connection.entrypoint else "", + connection_string=str(connection.addr), peer_id=str(connection.peer.id), network=settings.NETWORK_NAME, - received_messages=connection.metrics.received_messages, - sent_messages=connection.metrics.sent_messages, - received_bytes=connection.metrics.received_bytes, - sent_bytes=connection.metrics.sent_bytes, - received_txs=connection.metrics.received_txs, - discarded_txs=connection.metrics.discarded_txs, - received_blocks=connection.metrics.received_blocks, - discarded_blocks=connection.metrics.discarded_blocks, + received_messages=metrics.received_messages, + sent_messages=metrics.sent_messages, + received_bytes=metrics.received_bytes, + sent_bytes=metrics.sent_bytes, + received_txs=metrics.received_txs, + discarded_txs=metrics.discarded_txs, + received_blocks=metrics.received_blocks, + discarded_blocks=metrics.discarded_blocks, ) self.peer_connection_metrics.append(metric) diff --git a/hathor/p2p/__init__.py b/hathor/p2p/__init__.py index e69de29bb..00b4ca1a0 100644 --- a/hathor/p2p/__init__.py +++ b/hathor/p2p/__init__.py @@ -0,0 +1,19 @@ +# Copyright 2024 Hathor Labs +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from hathor.p2p.dependencies.p2p_dependencies import P2PDependencies + +__all__ = [ + 'P2PDependencies', +] diff --git a/hathor/p2p/dependencies/__init__.py b/hathor/p2p/dependencies/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/hathor/p2p/dependencies/p2p_dependencies.py b/hathor/p2p/dependencies/p2p_dependencies.py new file mode 100644 index 000000000..46c263ad4 --- /dev/null +++ b/hathor/p2p/dependencies/p2p_dependencies.py @@ -0,0 +1,68 @@ +# Copyright 2024 Hathor Labs +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from hathor.conf.settings import HathorSettings +from hathor.p2p.dependencies.protocols import ( + P2PTransactionStorageProtocol, + P2PVerificationServiceProtocol, + P2PVertexHandlerProtocol, +) +from hathor.reactor import ReactorProtocol +from hathor.transaction.vertex_parser import VertexParser + + +class P2PDependencies: + """A simple class to unify all node dependencies that are required by P2P.""" + + __slots__ = ( + 'reactor', + 'settings', + 'vertex_parser', + 'vertex_handler', + 'verification_service', + 'tx_storage', + 'capabilities', + 'whitelist_only', + '_has_sync_version_capability', + ) + + def __init__( + self, + *, + reactor: ReactorProtocol, + settings: HathorSettings, + vertex_parser: VertexParser, + vertex_handler: P2PVertexHandlerProtocol, + verification_service: P2PVerificationServiceProtocol, + tx_storage: P2PTransactionStorageProtocol, + capabilities: list[str], + whitelist_only: bool, + ) -> None: + self.reactor = reactor + self.settings = settings + self.vertex_parser = vertex_parser + self.vertex_handler = vertex_handler + self.verification_service = verification_service + self.tx_storage = tx_storage + + # List of capabilities of the peer + self.capabilities = capabilities + + # Parameter to explicitly enable whitelist-only mode, when False it will still check the whitelist for sync-v1 + self.whitelist_only = whitelist_only + + self._has_sync_version_capability = settings.CAPABILITY_SYNC_VERSION in capabilities + + def has_sync_version_capability(self) -> bool: + return self._has_sync_version_capability diff --git a/hathor/p2p/dependencies/protocols.py b/hathor/p2p/dependencies/protocols.py new file mode 100644 index 000000000..662d61703 --- /dev/null +++ b/hathor/p2p/dependencies/protocols.py @@ -0,0 +1,46 @@ +# Copyright 2024 Hathor Labs +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Protocol + +from hathor.indexes.height_index import HeightInfo +from hathor.transaction import Block, Vertex +from hathor.types import VertexId + + +class P2PVertexHandlerProtocol(Protocol): + """Abstract the VertexHandler as a Python protocol to be used in P2P classes.""" + + def on_new_vertex(self, vertex: Vertex, *, fails_silently: bool = True) -> bool: ... + + +class P2PVerificationServiceProtocol(Protocol): + """Abstract the VerificationService as a Python protocol to be used in P2P classes.""" + + def verify_basic(self, vertex: Vertex) -> None: ... + + +class P2PTransactionStorageProtocol(Protocol): + """Abstract the TransactionStorage as a Python protocol to be used in P2P classes.""" + + def get_vertex(self, vertex_id: VertexId) -> Vertex: ... + def get_block(self, block_id: VertexId) -> Block: ... + def transaction_exists(self, vertex_id: VertexId) -> bool: ... + def can_validate_full(self, vertex: Vertex) -> bool: ... + def compare_bytes_with_local_tx(self, vertex: Vertex) -> bool: ... + def get_best_block(self) -> Block: ... + def get_n_height_tips(self, n_blocks: int) -> list[HeightInfo]: ... + def get_mempool_tips(self) -> set[VertexId]: ... + def get_block_id_by_height(self, height: int) -> VertexId | None: ... + def partial_vertex_exists(self, vertex_id: VertexId) -> bool: ... diff --git a/hathor/p2p/factory.py b/hathor/p2p/factory.py index 832f2e501..9b80fc5e7 100644 --- a/hathor/p2p/factory.py +++ b/hathor/p2p/factory.py @@ -17,9 +17,10 @@ from twisted.internet import protocol from twisted.internet.interfaces import IAddress -from hathor.conf.settings import HathorSettings +from hathor.p2p import P2PDependencies from hathor.p2p.manager import ConnectionsManager from hathor.p2p.peer import PrivatePeer +from hathor.p2p.peer_endpoint import PeerAddress from hathor.p2p.protocol import HathorLineReceiver @@ -31,25 +32,24 @@ def __init__( my_peer: PrivatePeer, p2p_manager: ConnectionsManager, *, - settings: HathorSettings, + dependencies: P2PDependencies, use_ssl: bool, ): super().__init__() - self._settings = settings self.my_peer = my_peer self.p2p_manager = p2p_manager + self.dependencies = dependencies self.use_ssl = use_ssl def buildProtocol(self, addr: IAddress) -> HathorLineReceiver: - p = HathorLineReceiver( + return HathorLineReceiver( + addr=PeerAddress.from_address(addr), my_peer=self.my_peer, p2p_manager=self.p2p_manager, + dependencies=self.dependencies, use_ssl=self.use_ssl, inbound=self.inbound, - settings=self._settings ) - p.factory = self - return p class HathorServerFactory(_HathorLineReceiverFactory, protocol.ServerFactory): diff --git a/hathor/p2p/manager.py b/hathor/p2p/manager.py index d53c7be83..352d6f4f2 100644 --- a/hathor/p2p/manager.py +++ b/hathor/p2p/manager.py @@ -12,21 +12,24 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + from typing import TYPE_CHECKING, Any, Iterable, NamedTuple, Optional from structlog import get_logger from twisted.internet import endpoints from twisted.internet.address import IPv4Address, IPv6Address from twisted.internet.defer import Deferred -from twisted.internet.interfaces import IListeningPort, IProtocol, IProtocolFactory, IStreamClientEndpoint +from twisted.internet.interfaces import IListeningPort, IProtocol, IProtocolFactory from twisted.internet.task import LoopingCall from twisted.protocols.tls import TLSMemoryBIOFactory, TLSMemoryBIOProtocol from twisted.python.failure import Failure from twisted.web.client import Agent -from hathor.conf.settings import HathorSettings +from hathor.p2p import P2PDependencies from hathor.p2p.netfilter.factory import NetfilterFactory from hathor.p2p.peer import PrivatePeer, PublicPeer, UnverifiedPeer +from hathor.p2p.peer_connections import PeerConnections from hathor.p2p.peer_discovery import PeerDiscovery from hathor.p2p.peer_endpoint import PeerAddress, PeerEndpoint from hathor.p2p.peer_id import PeerId @@ -38,7 +41,6 @@ from hathor.p2p.sync_version import SyncVersion from hathor.p2p.utils import parse_whitelist from hathor.pubsub import HathorEvents, PubSubManager -from hathor.reactor import ReactorProtocol as Reactor from hathor.transaction import BaseTransaction from hathor.util import Random @@ -59,15 +61,10 @@ class _SyncRotateInfo(NamedTuple): to_enable: set[PeerId] -class _ConnectingPeer(NamedTuple): - entrypoint: PeerEndpoint - endpoint_deferred: Deferred - - class PeerConnectionsMetrics(NamedTuple): connecting_peers_count: int handshaking_peers_count: int - connected_peers_count: int + ready_peers_count: int known_peers_count: int @@ -79,11 +76,6 @@ class GlobalRateLimiter: SEND_TIPS = 'NodeSyncTimestamp.send_tips' manager: Optional['HathorManager'] - connections: set[HathorProtocol] - connected_peers: dict[PeerId, HathorProtocol] - connecting_peers: dict[IStreamClientEndpoint, _ConnectingPeer] - handshaking_peers: set[HathorProtocol] - whitelist_only: bool unverified_peer_storage: UnverifiedPeerStorage verified_peer_storage: VerifiedPeerStorage _sync_factories: dict[SyncVersion, SyncAgentFactory] @@ -93,24 +85,23 @@ class GlobalRateLimiter: def __init__( self, - settings: HathorSettings, - reactor: Reactor, + dependencies: P2PDependencies, my_peer: PrivatePeer, pubsub: PubSubManager, ssl: bool, rng: Random, - whitelist_only: bool, ) -> None: self.log = logger.new() - self._settings = settings + self.dependencies = dependencies + self._settings = dependencies.settings self.rng = rng self.manager = None - self.MAX_ENABLED_SYNC = settings.MAX_ENABLED_SYNC - self.SYNC_UPDATE_INTERVAL = settings.SYNC_UPDATE_INTERVAL - self.PEER_DISCOVERY_INTERVAL = settings.PEER_DISCOVERY_INTERVAL + self.MAX_ENABLED_SYNC = self._settings.MAX_ENABLED_SYNC + self.SYNC_UPDATE_INTERVAL = self._settings.SYNC_UPDATE_INTERVAL + self.PEER_DISCOVERY_INTERVAL = self._settings.PEER_DISCOVERY_INTERVAL - self.reactor = reactor + self.reactor = dependencies.reactor self.my_peer = my_peer # List of address descriptions to listen for new connections (eg: [tcp:8000]) @@ -129,10 +120,16 @@ def __init__( from hathor.p2p.factory import HathorClientFactory, HathorServerFactory self.use_ssl = ssl self.server_factory = HathorServerFactory( - self.my_peer, p2p_manager=self, use_ssl=self.use_ssl, settings=self._settings + my_peer=self.my_peer, + p2p_manager=self, + dependencies=dependencies, + use_ssl=self.use_ssl, ) self.client_factory = HathorClientFactory( - self.my_peer, p2p_manager=self, use_ssl=self.use_ssl, settings=self._settings + my_peer=self.my_peer, + p2p_manager=self, + dependencies=dependencies, + use_ssl=self.use_ssl, ) # Global maximum number of connections. @@ -142,17 +139,7 @@ def __init__( self.rate_limiter = RateLimiter(self.reactor) self.enable_rate_limiter() - # All connections. - self.connections = set() - - # List of pending connections. - self.connecting_peers = {} - - # List of peers connected but still not ready to communicate. - self.handshaking_peers = set() - - # List of peers connected and ready to communicate. - self.connected_peers = {} + self._connections = PeerConnections() # List of peers received from the network. # We cannot trust their identity before we connect to them. @@ -187,9 +174,6 @@ def __init__( # Pubsub object to publish events self.pubsub = pubsub - # Parameter to explicitly enable whitelist-only mode, when False it will still check the whitelist for sync-v1 - self.whitelist_only = whitelist_only - # Timestamp when the last discovery ran self._last_discovery: float = 0. @@ -320,12 +304,12 @@ def stop(self) -> None: def _get_peers_count(self) -> PeerConnectionsMetrics: """Get a dict containing the count of peers in each state""" - + peer_counts = self._connections.get_peer_counts() return PeerConnectionsMetrics( - len(self.connecting_peers), - len(self.handshaking_peers), - len(self.connected_peers), - len(self.verified_peer_storage) + connecting_peers_count=peer_counts.connecting, + handshaking_peers_count=peer_counts.handshaking, + ready_peers_count=peer_counts.ready, + known_peers_count=len(self.verified_peer_storage) ) def get_sync_factory(self, sync_version: SyncVersion) -> SyncAgentFactory: @@ -338,9 +322,7 @@ def has_synced_peer(self) -> bool: """ connections = list(self.iter_ready_connections()) for conn in connections: - assert conn.state is not None - assert isinstance(conn.state, ReadyState) - if conn.state.is_synced(): + if conn.is_synced(): return True return False @@ -357,37 +339,29 @@ def send_tx_to_peers(self, tx: BaseTransaction) -> None: connections = list(self.iter_ready_connections()) self.rng.shuffle(connections) for conn in connections: - assert conn.state is not None - assert isinstance(conn.state, ReadyState) - conn.state.send_tx_to_peer(tx) + conn.send_tx_to_peer(tx) def disconnect_all_peers(self, *, force: bool = False) -> None: """Disconnect all peers.""" - for conn in self.iter_all_connections(): + for conn in self.get_connected_peers(): conn.disconnect(force=force) - def on_connection_failure(self, failure: Failure, peer: Optional[UnverifiedPeer | PublicPeer], - endpoint: IStreamClientEndpoint) -> None: - connecting_peer = self.connecting_peers[endpoint] - entrypoint = connecting_peer.entrypoint - self.log.warn('connection failure', entrypoint=str(entrypoint), failure=failure.getErrorMessage()) - self.connecting_peers.pop(endpoint) - + def on_connection_failure(self, failure: Failure, endpoint: PeerEndpoint) -> None: + self.log.warn('connection failure', endpoint=str(endpoint), failure=failure.getErrorMessage()) + self._connections.on_failed_to_connect(addr=endpoint.addr) self.pubsub.publish( HathorEvents.NETWORK_PEER_CONNECTION_FAILED, - peer=peer, peers_count=self._get_peers_count() ) def on_peer_connect(self, protocol: HathorProtocol) -> None: - """Called when a new connection is established.""" - if len(self.connections) >= self.max_connections: + """Called when a new connection is established from both inbound and outbound peers.""" + if len(self._connections.connected_peers()) >= self.max_connections: self.log.warn('reached maximum number of connections', max_connections=self.max_connections) protocol.disconnect(force=True) return - self.connections.add(protocol) - self.handshaking_peers.add(protocol) + self._connections.on_connected(protocol=protocol) self.pubsub.publish( HathorEvents.NETWORK_PEER_CONNECTED, protocol=protocol, @@ -396,12 +370,9 @@ def on_peer_connect(self, protocol: HathorProtocol) -> None: def on_peer_ready(self, protocol: HathorProtocol) -> None: """Called when a peer is ready.""" - assert protocol.peer is not None self.verified_peer_storage.add_or_replace(protocol.peer) - assert protocol.peer.id is not None - - self.handshaking_peers.remove(protocol) self.unverified_peer_storage.pop(protocol.peer.id, None) + connection_to_drop = self._connections.on_ready(addr=protocol.addr, peer_id=protocol.peer.id) # we emit the event even if it's a duplicate peer as a matching # NETWORK_PEER_DISCONNECTED will be emitted regardless @@ -411,21 +382,17 @@ def on_peer_ready(self, protocol: HathorProtocol) -> None: peers_count=self._get_peers_count() ) - if protocol.peer.id in self.connected_peers: + if connection_to_drop: # connected twice to same peer - self.log.warn('duplicate connection to peer', protocol=protocol) - conn = self.get_connection_to_drop(protocol) - self.reactor.callLater(0, self.drop_connection, conn) - if conn == protocol: - # the new connection is being dropped, so don't save it to connected_peers + self.log.warn('duplicate connection to peer', addr=str(protocol.addr), peer_id=str(protocol.peer.id)) + self.reactor.callLater(0, self.drop_connection, connection_to_drop) + if connection_to_drop == protocol: return - self.connected_peers[protocol.peer.id] = protocol - # In case it was a retry, we must reset the data only here, after it gets ready protocol.peer.info.reset_retry_timestamp() - if len(self.connected_peers) <= self.MAX_ENABLED_SYNC: + if len(self._connections.ready_peers()) <= self.MAX_ENABLED_SYNC: protocol.enable_sync() if protocol.peer.id in self.always_enable_sync: @@ -437,59 +404,59 @@ def on_peer_ready(self, protocol: HathorProtocol) -> None: def relay_peer_to_ready_connections(self, peer: PublicPeer) -> None: """Relay peer to all ready connections.""" for conn in self.iter_ready_connections(): - if conn.peer == peer: + if conn.get_peer() == peer: continue - assert isinstance(conn.state, ReadyState) - conn.state.send_peers([peer]) - - def on_peer_disconnect(self, protocol: HathorProtocol) -> None: - """Called when a peer disconnect.""" - self.connections.discard(protocol) - if protocol in self.handshaking_peers: - self.handshaking_peers.remove(protocol) - if protocol._peer is not None: - existing_protocol = self.connected_peers.pop(protocol.peer.id, None) - if existing_protocol is None: - # in this case, the connection was closed before it got to READY state - return - if existing_protocol != protocol: - # this is the case we're closing a duplicate connection. We need to set the - # existing protocol object back to connected_peers, as that connection is still ongoing. - # A check for duplicate connections is done during PEER_ID state, but there's still a - # chance it can happen if both connections start at the same time and none of them has - # reached READY state while the other is on PEER_ID state - self.connected_peers[protocol.peer.id] = existing_protocol + conn.send_peers([peer]) + + def on_handshake_disconnect(self, *, addr: PeerAddress) -> None: + """Called when a peer disconnects from a handshaking state (HELLO or PEER-ID).""" + self._connections.on_handshake_disconnect(addr=addr) + self.pubsub.publish( + HathorEvents.NETWORK_PEER_DISCONNECTED, + peers_count=self._get_peers_count() + ) + + def on_ready_disconnect(self, *, addr: PeerAddress, peer_id: PeerId) -> None: + """Called when a peer disconnects from the READY state.""" + self._connections.on_ready_disconnect(addr=addr, peer_id=peer_id) self.pubsub.publish( HathorEvents.NETWORK_PEER_DISCONNECTED, - protocol=protocol, peers_count=self._get_peers_count() ) - def iter_all_connections(self) -> Iterable[HathorProtocol]: - """Iterate over all connections.""" - for conn in self.connections: - yield conn + def on_unknown_disconnect(self, *, addr: PeerAddress) -> None: + """Called when a peer disconnects from an unknown state (None).""" + self._connections.on_unknown_disconnect(addr=addr) + self.pubsub.publish( + HathorEvents.NETWORK_PEER_DISCONNECTED, + peers_count=self._get_peers_count() + ) + + def iter_connecting_outbound_peers(self) -> Iterable[PeerAddress]: + yield from self._connections.connecting_outbound_peers() + + def iter_handshaking_peers(self) -> Iterable[HathorProtocol]: + yield from self._connections.handshaking_peers().values() def iter_ready_connections(self) -> Iterable[HathorProtocol]: """Iterate over ready connections.""" - for conn in self.connected_peers.values(): - yield conn + yield from self._connections.ready_peers().values() - def iter_not_ready_endpoints(self) -> Iterable[PeerEndpoint]: + def iter_not_ready_endpoints(self) -> Iterable[PeerAddress]: """Iterate over not-ready connections.""" - for connecting_peer in self.connecting_peers.values(): - yield connecting_peer.entrypoint - for protocol in self.handshaking_peers: - if protocol.entrypoint is not None: - yield protocol.entrypoint - else: - self.log.warn('handshaking protocol has empty connection string', protocol=protocol) - - def is_peer_connected(self, peer_id: PeerId) -> bool: + yield from self._connections.not_ready_peers() + + def get_connected_peers(self) -> Iterable[HathorProtocol]: + yield from self._connections.connected_peers().values() + + def get_ready_peer_by_id(self, peer_id: PeerId) -> HathorProtocol | None: + return self._connections.get_ready_peer_by_id(peer_id) + + def is_peer_ready(self, peer_id: PeerId) -> bool: """ :type peer_id: string (peer.id) """ - return peer_id in self.connected_peers + return self._connections.is_peer_ready(peer_id) def on_receive_peer(self, peer: UnverifiedPeer, origin: Optional[ReadyState] = None) -> None: """ Update a peer information in our storage, and instantly attempt to connect @@ -506,7 +473,7 @@ def peers_cleanup(self) -> None: to_be_removed: list[PublicPeer] = [] for peer in self.verified_peer_storage.values(): assert peer.id is not None - if self.is_peer_connected(peer.id): + if self.is_peer_ready(peer.id): continue dt = now - peer.info.last_seen if dt > self.max_peer_unseen_dt: @@ -581,10 +548,10 @@ def connect_to_if_not_connected(self, peer: UnverifiedPeer | PublicPeer, now: in # It makes no sense to keep storing peers that have disconnected and have no entrypoints # We will never be able to connect to them anymore and they will only keep spending memory # and other resources when used in APIs, so we are removing them here - if peer.id not in self.connected_peers: + if not self.is_peer_ready(peer.id): self.verified_peer_storage.remove(peer) return - if peer.id in self.connected_peers: + if self.is_peer_ready(peer.id): return assert peer.id is not None @@ -592,82 +559,57 @@ def connect_to_if_not_connected(self, peer: UnverifiedPeer | PublicPeer, now: in addr = self.rng.choice(peer.info.entrypoints) self.connect_to(addr.with_id(peer.id), peer) - def _connect_to_callback( - self, - protocol: IProtocol, - peer: UnverifiedPeer | PublicPeer | None, - endpoint: IStreamClientEndpoint, - entrypoint: PeerEndpoint, - ) -> None: - """Called when we successfully connect to a peer.""" - if isinstance(protocol, HathorProtocol): - protocol.on_outbound_connect(entrypoint, peer) - else: - assert isinstance(protocol, TLSMemoryBIOProtocol) - assert isinstance(protocol.wrappedProtocol, HathorProtocol) - protocol.wrappedProtocol.on_outbound_connect(entrypoint, peer) - self.connecting_peers.pop(endpoint) + def _connect_to_callback(self, protocol: IProtocol, addr: PeerAddress, peer_id: PeerId | None) -> None: + """Called when we successfully connect to an outbound peer.""" + if isinstance(protocol, TLSMemoryBIOProtocol): + protocol = protocol.wrappedProtocol + assert isinstance(protocol, HathorProtocol) + assert protocol.addr == addr + protocol.on_outbound_connect(peer_id) def connect_to( self, - entrypoint: PeerEndpoint, + endpoint: PeerEndpoint, peer: UnverifiedPeer | PublicPeer | None = None, - use_ssl: bool | None = None, - ) -> None: - """ Attempt to connect to a peer, even if a connection already exists. - Usually you should call `connect_to_if_not_connected`. - - If `use_ssl` is True, then the connection will be wraped by a TLS. - """ - if entrypoint.peer_id is not None and peer is not None and entrypoint.peer_id != peer.id: - self.log.debug('skipping because the entrypoint peer_id does not match the actual peer_id', - entrypoint=str(entrypoint)) - return - - for connecting_peer in self.connecting_peers.values(): - if connecting_peer.entrypoint.addr == entrypoint.addr: - self.log.debug( - 'skipping because we are already connecting to this endpoint', - entrypoint=str(entrypoint), - ) - return - - if self.localhost_only and not entrypoint.addr.is_localhost(): - self.log.debug('skip because of simple localhost check', entrypoint=str(entrypoint)) - return + ) -> Deferred[IProtocol] | None: + """Attempt to connect to a peer though a specific endpoint.""" + if endpoint.peer_id is not None and peer is not None: + assert endpoint.peer_id == peer.id, 'the entrypoint peer_id does not match the actual peer_id' - if use_ssl is None: - use_ssl = self.use_ssl + already_exists = self._connections.on_connecting(addr=endpoint.addr) + if already_exists: + self.log.debug('skipping because we are already connected(ing) to this endpoint', endpoint=str(endpoint)) + return None - endpoint = entrypoint.addr.to_client_endpoint(self.reactor) + if self.localhost_only and not endpoint.addr.is_localhost(): + self.log.debug('skip because of simple localhost check', endpoint=str(endpoint)) + return None - factory: IProtocolFactory - if use_ssl: - factory = TLSMemoryBIOFactory(self.my_peer.certificate_options, True, self.client_factory) - else: - factory = self.client_factory + factory: IProtocolFactory = self.client_factory + if self.use_ssl: + factory = TLSMemoryBIOFactory(self.my_peer.certificate_options, True, factory) if peer is not None: now = int(self.reactor.seconds()) peer.info.increment_retry_attempt(now) - deferred = endpoint.connect(factory) - self.connecting_peers[endpoint] = _ConnectingPeer(entrypoint, deferred) + peer_id = peer.id if peer else endpoint.peer_id + deferred = endpoint.addr.to_client_endpoint(self.reactor).connect(factory) + deferred \ + .addCallback(self._connect_to_callback, endpoint.addr, peer_id) \ + .addErrback(self.on_connection_failure, endpoint) - deferred.addCallback(self._connect_to_callback, peer, endpoint, entrypoint) - deferred.addErrback(self.on_connection_failure, peer, endpoint) - self.log.info('connecting to', entrypoint=str(entrypoint), peer=str(peer)) + self.log.info('connecting to', endpoint=str(endpoint), peer_id=str(peer_id)) self.pubsub.publish( HathorEvents.NETWORK_PEER_CONNECTING, peer=peer, peers_count=self._get_peers_count() ) + return deferred - def listen(self, description: str, use_ssl: Optional[bool] = None) -> None: + def listen(self, description: str) -> None: """ Start to listen for new connection according to the description. - If `ssl` is True, then the connection will be wraped by a TLS. - :Example: `manager.listen(description='tcp:8000')` @@ -677,14 +619,9 @@ def listen(self, description: str, use_ssl: Optional[bool] = None) -> None: """ endpoint = endpoints.serverFromString(self.reactor, description) - if use_ssl is None: - use_ssl = self.use_ssl - - factory: IProtocolFactory - if use_ssl: - factory = TLSMemoryBIOFactory(self.my_peer.certificate_options, False, self.server_factory) - else: - factory = self.server_factory + factory: IProtocolFactory = self.server_factory + if self.use_ssl: + factory = TLSMemoryBIOFactory(self.my_peer.certificate_options, False, factory) factory = NetfilterFactory(self, factory) @@ -721,42 +658,17 @@ def _add_hostname_entrypoint(self, hostname: str, address: IPv4Address | IPv6Add hostname_entrypoint = PeerAddress.from_hostname_address(hostname, address) self.my_peer.info.entrypoints.append(hostname_entrypoint) - def get_connection_to_drop(self, protocol: HathorProtocol) -> HathorProtocol: - """ When there are duplicate connections, determine which one should be dropped. - - We keep the connection initiated by the peer with larger id. A simple (peer_id1 > peer_id2) - on the peer id string is used for this comparison. - """ - assert protocol.peer is not None - assert protocol.peer.id is not None - assert protocol.my_peer.id is not None - other_connection = self.connected_peers[protocol.peer.id] - if bytes(protocol.my_peer.id) > bytes(protocol.peer.id): - # connection started by me is kept - if not protocol.inbound: - # other connection is dropped - return other_connection - else: - # this was started by peer, so drop it - return protocol - else: - # connection started by peer is kept - if not protocol.inbound: - return protocol - else: - return other_connection - def drop_connection(self, protocol: HathorProtocol) -> None: """ Drop a connection """ - assert protocol.peer is not None - self.log.debug('dropping connection', peer_id=protocol.peer.id, protocol=type(protocol).__name__) + protocol_peer = protocol.get_peer() + self.log.debug('dropping connection', peer_id=protocol_peer.id, protocol=type(protocol).__name__) protocol.send_error_and_close_connection('Connection droped') def drop_connection_by_peer_id(self, peer_id: PeerId) -> None: """ Drop a connection by peer id """ - protocol = self.connected_peers.get(peer_id) + protocol = self.get_ready_peer_by_id(peer_id) if protocol: self.drop_connection(protocol) @@ -781,25 +693,24 @@ def set_always_enable_sync(self, values: list[PeerId]) -> None: self.log.info('update always_enable_sync', new=new, to_enable=to_enable, to_disable=to_disable) for peer_id in new: - if peer_id not in self.connected_peers: - continue - self.connected_peers[peer_id].enable_sync() + if peer := self.get_ready_peer_by_id(peer_id): + peer.enable_sync() for peer_id in to_disable: - if peer_id not in self.connected_peers: - continue - self.connected_peers[peer_id].disable_sync() + if peer := self.get_ready_peer_by_id(peer_id): + peer.disable_sync() self.always_enable_sync = new def _calculate_sync_rotate(self) -> _SyncRotateInfo: """Calculate new sync rotation.""" + ready_peers = self._connections.ready_peers().values() current_enabled: set[PeerId] = set() - for peer_id, conn in self.connected_peers.items(): + for conn in ready_peers: if conn.is_sync_enabled(): - current_enabled.add(peer_id) + current_enabled.add(conn.peer.id) - candidates = list(self.connected_peers.keys()) + candidates = [conn.peer.id for conn in ready_peers] self.rng.shuffle(candidates) selected_peers: set[PeerId] = set(candidates[:self.MAX_ENABLED_SYNC]) @@ -837,13 +748,31 @@ def _sync_rotate_if_needed(self, *, force: bool = False) -> None: ) for peer_id in info.to_disable: - self.connected_peers[peer_id].disable_sync() + peer = self.get_ready_peer_by_id(peer_id) + assert peer is not None + peer.disable_sync() for peer_id in info.to_enable: - self.connected_peers[peer_id].enable_sync() + peer = self.get_ready_peer_by_id(peer_id) + assert peer is not None + peer.enable_sync() def reload_entrypoints_and_connections(self) -> None: """Kill all connections and reload entrypoints from the original peer config file.""" self.log.warn('Killing all connections and resetting entrypoints...') self.disconnect_all_peers(force=True) self.my_peer.reload_entrypoints_from_source_file() + + def get_peers_whitelist(self) -> list[PeerId]: + assert self.manager is not None + return self.manager.peers_whitelist + + def get_verified_peers(self) -> Iterable[PublicPeer]: + return self.verified_peer_storage.values() + + def get_randbytes(self, n: int) -> bytes: + return self.rng.randbytes(n) + + def is_peer_whitelisted(self, peer_id: PeerId) -> bool: + assert self.manager is not None + return peer_id in self.manager.peers_whitelist diff --git a/hathor/p2p/peer.py b/hathor/p2p/peer.py index 53f43369d..1bb5f9f5b 100644 --- a/hathor/p2p/peer.py +++ b/hathor/p2p/peer.py @@ -137,31 +137,15 @@ async def validate_entrypoint(self, protocol: HathorProtocol) -> bool: # Entrypoint validation with connection string and connection host # Entrypoints have the format tcp://IP|name:port for entrypoint in self.entrypoints: - if protocol.entrypoint is not None: - # Connection string has the format tcp://IP:port - # So we must consider that the entrypoint could be in name format - if protocol.entrypoint.addr == entrypoint: - return True - # TODO: don't use `daa.TEST_MODE` for this - test_mode = not_none(DifficultyAdjustmentAlgorithm.singleton).TEST_MODE - result = await discover_dns(entrypoint.host, test_mode) - if protocol.entrypoint.addr in [endpoint.addr for endpoint in result]: - return True - else: - # When the peer is the server part of the connection we don't have the full entrypoint description - # So we can only validate the host from the protocol - assert protocol.transport is not None - connection_remote = protocol.transport.getPeer() - connection_host = getattr(connection_remote, 'host', None) - if connection_host is None: - continue - # Connection host has only the IP - # So we must consider that the entrypoint could be in name format and we just validate the host - if connection_host == entrypoint.host: - return True - test_mode = not_none(DifficultyAdjustmentAlgorithm.singleton).TEST_MODE - result = await discover_dns(entrypoint.host, test_mode) - if connection_host in [entrypoint.addr.host for entrypoint in result]: + # Connection string has the format tcp://IP:port + # So we must consider that the entrypoint could be in name format + if protocol.addr == entrypoint: + return True + # TODO: don't use `daa.TEST_MODE` for this + test_mode = not_none(DifficultyAdjustmentAlgorithm.singleton).TEST_MODE + result = await discover_dns(entrypoint.host, test_mode) + for endpoint in result: + if protocol.addr == endpoint.addr: return True return False diff --git a/hathor/p2p/peer_connections.py b/hathor/p2p/peer_connections.py new file mode 100644 index 000000000..d42d4acc1 --- /dev/null +++ b/hathor/p2p/peer_connections.py @@ -0,0 +1,200 @@ +# Copyright 2024 Hathor Labs +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass + +from hathor.p2p.peer_endpoint import PeerAddress +from hathor.p2p.peer_id import PeerId +from hathor.p2p.protocol import HathorProtocol + + +@dataclass(slots=True, frozen=True, kw_only=True) +class PeerCounts: + """Simple wrapper for metrics.""" + connecting: int + handshaking: int + ready: int + + +class PeerConnections: + """ + This class represents all peer connections made by a ConnectionsManager. + It's also responsible for reacting for state changes on those connections. + """ + + __slots__ = ('_connecting_outbound', '_handshaking', '_ready', '_addr_by_id') + + def __init__(self) -> None: + # Peers that are in the "connecting" state, between starting a connection and Twisted calling `connectionMade`. + # This is only for outbound peers, that is, connections initiated by us. + # They're uniquely identified by the address we're connecting to. + self._connecting_outbound: set[PeerAddress] = set() + + # Peers that are handshaking, in a state after being connected and before reaching the READY state. + # They're uniquely identified by the address we're connected to. + self._handshaking: dict[PeerAddress, HathorProtocol] = {} + + # Peers that are in the READY state. + # They're uniquely identified by the address we're connected to. + # Note: there may be peers with duplicate PeerIds in this structure. + self._ready: dict[PeerAddress, HathorProtocol] = {} + + # Auxiliary structure for uniquely identifying READY peers by their PeerId. When there are peers with + # duplicate PeerIds, this identifies the connection we chose to keep. + self._addr_by_id: dict[PeerId, PeerAddress] = {} + + def connecting_outbound_peers(self) -> set[PeerAddress]: + """Get connecting outbound peers.""" + return self._connecting_outbound.copy() + + def handshaking_peers(self) -> dict[PeerAddress, HathorProtocol]: + """Get handshaking peers.""" + return self._handshaking.copy() + + def ready_peers(self) -> dict[PeerAddress, HathorProtocol]: + """Get ready peers, not including possible PeerId duplicates.""" + return { + addr: self._ready[addr] + for addr in self._addr_by_id.values() + } + + def not_ready_peers(self) -> list[PeerAddress]: + """Get not ready peers, that is, peers that are either connecting or handshaking.""" + return list(self._connecting_outbound) + list(self._handshaking) + + def connected_peers(self) -> dict[PeerAddress, HathorProtocol]: + """ + Get peers that are connected, that is, peers that are either handshaking or ready. + Does not include possible PeerId duplicates. + """ + return self.handshaking_peers() | self.ready_peers() + + def all_peers(self) -> list[PeerAddress]: + """Get all peers, ready or not. Does not include possible PeerId duplicates.""" + return self.not_ready_peers() + list(self.ready_peers()) + + def get_ready_peer_by_id(self, peer_id: PeerId) -> HathorProtocol | None: + """ + Get a ready peer by its PeerId. If there are connections with duplicate PeerIds, + we return the one that we chose to keep. + """ + addr = self._addr_by_id.get(peer_id) + return self._ready[addr] if addr else None + + def get_peer_counts(self) -> PeerCounts: + """Return the peer counts, for metrics.""" + return PeerCounts( + connecting=len(self._connecting_outbound), + handshaking=len(self._handshaking), + ready=len(self._ready), + ) + + def is_peer_ready(self, peer_id: PeerId) -> bool: + """Return whether a peer is ready, by its PeerId.""" + return peer_id in self._addr_by_id + + def on_connecting(self, *, addr: PeerAddress) -> bool: + """ + Callback for when an outbound connection is initiated. + Returns True if this address already exists, either connecting or connected, and False otherwise.""" + if addr in self.all_peers(): + return True + + self._connecting_outbound.add(addr) + return False + + def on_failed_to_connect(self, *, addr: PeerAddress) -> None: + """Callback for when an outbound connection fails before getting connected.""" + assert addr in self._connecting_outbound + assert addr not in self.connected_peers() + self._connecting_outbound.remove(addr) + + def on_connected(self, *, protocol: HathorProtocol) -> None: + """Callback for when an outbound connection gets connected.""" + assert protocol.addr not in self.connected_peers() + + if protocol.inbound: + assert protocol.addr not in self._connecting_outbound + else: + assert protocol.addr in self._connecting_outbound + self._connecting_outbound.remove(protocol.addr) + + self._handshaking[protocol.addr] = protocol + + def on_handshake_disconnect(self, *, addr: PeerAddress) -> None: + """ + Callback for when a connection is closed during a handshaking state, that is, + after getting connected and before getting READY. + """ + assert addr not in self._connecting_outbound + assert addr in self._handshaking + assert addr not in self._ready + self._handshaking.pop(addr) + + def on_ready(self, *, addr: PeerAddress, peer_id: PeerId) -> HathorProtocol | None: + """ + Callback for when a connection gets to the READY state. + If the PeerId of this connection is duplicate, return the protocol that we should disconnect. + Return None otherwise. + """ + assert addr not in self._connecting_outbound + assert addr in self._handshaking + assert addr not in self._ready + + protocol = self._handshaking.pop(addr) + self._ready[addr] = protocol # We always index it by address, even if its PeerId is duplicate. + + connection_to_drop: HathorProtocol | None = None + + # If there's an existing connection with the same PeerId, this is a duplicate connection + if old_connection := self.get_ready_peer_by_id(protocol.peer.id): + # We choose to drop either the new or the old connection. + if self._should_drop_new_connection(protocol): + # We return early when we drop the new connection, + # so we don't override the old connection in _addr_by_id with it below. + return protocol + + # When dropping the old connection, we do override it in _addr_by_id below. + connection_to_drop = old_connection + + self._addr_by_id[peer_id] = addr + return connection_to_drop + + def on_ready_disconnect(self, *, addr: PeerAddress, peer_id: PeerId) -> None: + """Callback for when a connection is closed during the READY state.""" + assert addr not in self._connecting_outbound + assert addr not in self._handshaking + assert addr in self._ready + self._ready.pop(addr) + + if self._addr_by_id[peer_id] == addr: + self._addr_by_id.pop(peer_id) + + def on_unknown_disconnect(self, *, addr: PeerAddress) -> None: + """Callback for when a connection is closed during an unknown state.""" + assert addr not in self._handshaking + assert addr not in self._ready + if addr in self._connecting_outbound: + self._connecting_outbound.remove(addr) + + @staticmethod + def _should_drop_new_connection(new_conn: HathorProtocol) -> bool: + """ + When there are connections with duplicate PeerIds, determine which one should be dropped, the old or the new. + Return True if we should drop the new connection, and False otherwise. + + The logic to determine this is `(my_peer_id > other_peer_id) XNOR new_conn.inbound`. + """ + my_peer_is_larger = bytes(new_conn.my_peer.id) > bytes(new_conn.peer.id) + return my_peer_is_larger == new_conn.inbound diff --git a/hathor/p2p/peer_discovery/bootstrap.py b/hathor/p2p/peer_discovery/bootstrap.py index 55b5e9f16..6e71f310e 100644 --- a/hathor/p2p/peer_discovery/bootstrap.py +++ b/hathor/p2p/peer_discovery/bootstrap.py @@ -15,6 +15,8 @@ from typing import Callable from structlog import get_logger +from twisted.internet.defer import Deferred +from twisted.internet.interfaces import IProtocol from typing_extensions import override from hathor.p2p.peer_endpoint import PeerEndpoint @@ -37,6 +39,6 @@ def __init__(self, entrypoints: list[PeerEndpoint]): self.entrypoints = entrypoints @override - async def discover_and_connect(self, connect_to: Callable[[PeerEndpoint], None]) -> None: + async def discover_and_connect(self, connect_to: Callable[[PeerEndpoint], Deferred[IProtocol] | None]) -> None: for entrypoint in self.entrypoints: connect_to(entrypoint) diff --git a/hathor/p2p/peer_discovery/dns.py b/hathor/p2p/peer_discovery/dns.py index c5dfe74d6..0debde977 100644 --- a/hathor/p2p/peer_discovery/dns.py +++ b/hathor/p2p/peer_discovery/dns.py @@ -19,6 +19,7 @@ from structlog import get_logger from twisted.internet.defer import Deferred, gatherResults +from twisted.internet.interfaces import IProtocol from twisted.names.client import lookupAddress, lookupText from twisted.names.dns import Record_A, Record_TXT, RRHeader from typing_extensions import override @@ -53,7 +54,7 @@ def do_lookup_text(self, host: str) -> Deferred[LookupResult]: return lookupText(host) @override - async def discover_and_connect(self, connect_to: Callable[[PeerEndpoint], None]) -> None: + async def discover_and_connect(self, connect_to: Callable[[PeerEndpoint], Deferred[IProtocol] | None]) -> None: """ Run DNS lookup for host and connect to it This is executed when starting the DNS Peer Discovery and first connecting to the network """ diff --git a/hathor/p2p/peer_discovery/peer_discovery.py b/hathor/p2p/peer_discovery/peer_discovery.py index 7d040fae2..49019e739 100644 --- a/hathor/p2p/peer_discovery/peer_discovery.py +++ b/hathor/p2p/peer_discovery/peer_discovery.py @@ -15,6 +15,9 @@ from abc import ABC, abstractmethod from typing import Callable +from twisted.internet.defer import Deferred +from twisted.internet.interfaces import IProtocol + from hathor.p2p.peer_endpoint import PeerEndpoint @@ -23,7 +26,7 @@ class PeerDiscovery(ABC): """ @abstractmethod - async def discover_and_connect(self, connect_to: Callable[[PeerEndpoint], None]) -> None: + async def discover_and_connect(self, connect_to: Callable[[PeerEndpoint], Deferred[IProtocol] | None]) -> None: """ This method must discover the peers and call `connect_to` for each of them. :param connect_to: Function which will be called for each discovered peer. diff --git a/hathor/p2p/peer_endpoint.py b/hathor/p2p/peer_endpoint.py index c7cafce20..47ff422ed 100644 --- a/hathor/p2p/peer_endpoint.py +++ b/hathor/p2p/peer_endpoint.py @@ -131,7 +131,7 @@ def from_hostname_address(cls, hostname: str, address: IPv4Address | IPv6Address @classmethod def from_address(cls, address: IAddress) -> Self: - """Create an Entrypoint from a Twisted IAddress.""" + """Create a PeerAddress from a Twisted IAddress.""" if not isinstance(address, (IPv4Address, IPv6Address)): raise NotImplementedError(f'address: {address}') return cls.parse(f'{address.type}://{address.host}:{address.port}') diff --git a/hathor/p2p/protocol.py b/hathor/p2p/protocol.py index e05e63b55..5159eb075 100644 --- a/hathor/p2p/protocol.py +++ b/hathor/p2p/protocol.py @@ -14,7 +14,7 @@ import time from enum import Enum -from typing import TYPE_CHECKING, Optional, cast +from typing import TYPE_CHECKING, Iterable, Optional, cast from structlog import get_logger from twisted.internet import defer @@ -24,16 +24,17 @@ from twisted.protocols.basic import LineReceiver from twisted.python.failure import Failure -from hathor.conf.settings import HathorSettings +from hathor.p2p import P2PDependencies from hathor.p2p.messages import ProtocolMessages -from hathor.p2p.peer import PrivatePeer, PublicPeer, UnverifiedPeer -from hathor.p2p.peer_endpoint import PeerEndpoint +from hathor.p2p.peer import PrivatePeer, PublicPeer +from hathor.p2p.peer_endpoint import PeerAddress from hathor.p2p.peer_id import PeerId from hathor.p2p.rate_limiter import RateLimiter from hathor.p2p.states import BaseState, HelloState, PeerIdState, ReadyState from hathor.p2p.sync_version import SyncVersion from hathor.p2p.utils import format_address from hathor.profiler import get_cpu_profiler +from hathor.transaction import BaseTransaction if TYPE_CHECKING: from hathor.manager import HathorManager # noqa: F401 @@ -83,7 +84,6 @@ class WarningFlags(str, Enum): state: Optional[BaseState] connection_time: float _state_instances: dict[PeerState, BaseState] - entrypoint: Optional[PeerEndpoint] warning_flags: set[str] aborting: bool diff_timestamp: Optional[int] @@ -101,13 +101,16 @@ def __init__( my_peer: PrivatePeer, p2p_manager: 'ConnectionsManager', *, - settings: HathorSettings, + dependencies: P2PDependencies, use_ssl: bool, inbound: bool, + addr: PeerAddress, ) -> None: - self._settings = settings + self.dependencies = dependencies + self._settings = dependencies.settings self.my_peer = my_peer self.connections = p2p_manager + self.addr = addr assert p2p_manager.manager is not None self.node = p2p_manager.manager @@ -147,10 +150,6 @@ def __init__( self.ratelimit: RateLimiter = RateLimiter(self.reactor) # self.ratelimit.set_limit(self.RateLimitKeys.GLOBAL, 120, 60) - # Connection string of the peer - # Used to validate if entrypoints has this string - self.entrypoint: Optional[PeerEndpoint] = None - # Peer id sent in the connection url that is expected to connect (optional) self.expected_peer_id: PeerId | None = None @@ -175,7 +174,7 @@ def change_state(self, state_enum: PeerState) -> None: """Called to change the state of the connection.""" if state_enum not in self._state_instances: state_cls = state_enum.value - instance = state_cls(self, self._settings) + instance = state_cls(self, dependencies=self.dependencies) instance.state_name = state_enum.name self._state_instances[state_enum] = instance new_state = self._state_instances[state_enum] @@ -254,14 +253,11 @@ def on_connect(self) -> None: if self.connections: self.connections.on_peer_connect(self) - def on_outbound_connect(self, entrypoint: PeerEndpoint, peer: UnverifiedPeer | PublicPeer | None) -> None: + def on_outbound_connect(self, peer_id: PeerId | None) -> None: """Called when we successfully establish an outbound connection to a peer.""" - # Save the used entrypoint in protocol so we can validate that it matches the entrypoints data - if entrypoint.peer_id is not None and peer is not None: - assert entrypoint.peer_id == peer.id - - self.expected_peer_id = peer.id if peer else entrypoint.peer_id - self.entrypoint = entrypoint + # Save the peer_id so we can validate that it matches the one we'll receive in the PEER-ID state + assert not self.inbound + self.expected_peer_id = peer_id def on_peer_ready(self) -> None: assert self.connections is not None @@ -282,11 +278,33 @@ def on_disconnect(self, reason: Failure) -> None: self._idle_timeout_call_later = None self.aborting = True self.update_log_context() - if self.state: - self.state.on_exit() + + if not self.state: + # TODO: This should never happen, it can only happen if an exception was raised in the middle of our + # connection callback (connectionMade/on_connect). In that case, we may have not initialized our state + # yet. We should improve this by making an initial non-None state. + self.log.error( + 'disconnecting protocol with no state. check for previous exceptions', + addr=str(self.addr), + peer_id=str(self.get_peer_id()), + ) + self.connections.on_unknown_disconnect(addr=self.addr) + return + self.state.on_exit() + state_name = self.state.state_name + + if self.is_state(self.PeerState.HELLO) or self.is_state(self.PeerState.PEER_ID): self.state = None - if self.connections: - self.connections.on_peer_disconnect(self) + self.connections.on_handshake_disconnect(addr=self.addr) + return + + if self.is_state(self.PeerState.READY): + self.state = None + self.connections.on_ready_disconnect(addr=self.addr, peer_id=self.peer.id) + return + + self.state = None + raise AssertionError(f'disconnected in unexpected state: {state_name or "unknown"}') def send_message(self, cmd: ProtocolMessages, payload: Optional[str] = None) -> None: """ A generic message which must be implemented to send a message @@ -394,6 +412,27 @@ def disable_sync(self) -> None: self.log.info('disable sync') self.state.sync_agent.disable_sync() + def is_synced(self) -> bool: + assert isinstance(self.state, ReadyState) + return self.state.is_synced() + + def send_tx_to_peer(self, tx: BaseTransaction) -> None: + assert isinstance(self.state, ReadyState) + return self.state.send_tx_to_peer(tx) + + def get_peer(self) -> PublicPeer: + return self.peer + + def get_peer_if_set(self) -> PublicPeer | None: + return self._peer + + def send_peers(self, peers: Iterable[PublicPeer]) -> None: + assert isinstance(self.state, ReadyState) + self.state.send_peers(peers) + + def get_metrics(self) -> 'ConnectionMetrics': + return self.metrics + class HathorLineReceiver(LineReceiver, HathorProtocol): """ Implements HathorProtocol in a LineReceiver protocol. @@ -401,6 +440,10 @@ class HathorLineReceiver(LineReceiver, HathorProtocol): """ MAX_LENGTH = 65536 + def makeConnection(self, transport: ITransport) -> None: + assert self.addr == PeerAddress.from_address(transport.getPeer()) + super().makeConnection(transport) + def connectionMade(self) -> None: super(HathorLineReceiver, self).connectionMade() self.setLineMode() diff --git a/hathor/p2p/resources/add_peers.py b/hathor/p2p/resources/add_peers.py index c8faeb5dc..a7aa3f212 100644 --- a/hathor/p2p/resources/add_peers.py +++ b/hathor/p2p/resources/add_peers.py @@ -72,7 +72,7 @@ def render_POST(self, request: Request) -> bytes: def already_connected(endpoint: PeerEndpoint) -> bool: # ignore peers that we're already trying to connect for ready_endpoint in self.manager.connections.iter_not_ready_endpoints(): - if endpoint.addr == ready_endpoint.addr: + if endpoint.addr == ready_endpoint: return True # remove peers we already know about diff --git a/hathor/p2p/resources/status.py b/hathor/p2p/resources/status.py index 68edb9f0e..896220b9a 100644 --- a/hathor/p2p/resources/status.py +++ b/hathor/p2p/resources/status.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from twisted.web.http import Request + import hathor from hathor.api_util import Resource, set_cors from hathor.cli.openapi_files.register import register_resource @@ -34,7 +36,7 @@ def __init__(self, manager): self.manager = manager self.reactor = manager.reactor - def render_GET(self, request): + def render_GET(self, request: Request) -> bytes: request.setHeader(b'content-type', b'application/json; charset=utf-8') set_cors(request, 'GET') @@ -42,17 +44,14 @@ def render_GET(self, request): connecting_peers = [] # TODO: refactor as not to use a private item - for endpoint, deferred in self.manager.connections.connecting_peers.items(): - host = getattr(endpoint, '_host', '') - port = getattr(endpoint, '_port', '') - connecting_peers.append({'deferred': str(deferred), 'address': '{}:{}'.format(host, port)}) + for address in self.manager.connections.iter_connecting_outbound_peers(): + connecting_peers.append({'address': str(address)}) handshaking_peers = [] # TODO: refactor as not to use a private item - for conn in self.manager.connections.handshaking_peers: - remote = conn.transport.getPeer() + for conn in self.manager.connections.iter_handshaking_peers(): handshaking_peers.append({ - 'address': '{}:{}'.format(remote.host, remote.port), + 'address': str(conn.addr), 'state': conn.state.state_name, 'uptime': now - conn.connection_time, 'app_version': conn.app_version, @@ -60,7 +59,6 @@ def render_GET(self, request): connected_peers = [] for conn in self.manager.connections.iter_ready_connections(): - remote = conn.transport.getPeer() status = {} status[conn.state.sync_agent.name] = conn.state.sync_agent.get_status() connected_peers.append({ @@ -68,7 +66,7 @@ def render_GET(self, request): 'app_version': conn.app_version, 'current_time': now, 'uptime': now - conn.connection_time, - 'address': '{}:{}'.format(remote.host, remote.port), + 'address': str(conn.addr), 'state': conn.state.state_name, # 'received_bytes': conn.received_bytes, 'rtt': list(conn.state.rtt_window), @@ -134,7 +132,7 @@ def render_GET(self, request): 'id': '5578ab3bcaa861fb9d07135b8b167dd230d4487b147be8fd2c94a79bd349d123', 'app_version': 'Hathor v0.14.0-beta', 'uptime': 118.37029600143433, - 'address': '192.168.1.1:54321', + 'address': 'tcp://192.168.1.1:54321', 'state': 'READY', 'last_message': 1539271481, 'plugins': { @@ -149,8 +147,7 @@ def render_GET(self, request): 'peer_best_blockchain': [_openapi_height_info] } _openapi_connecting_peer = { - 'deferred': '>', # noqa - 'address': '192.168.1.1:54321' + 'address': 'tcp://192.168.1.1:54321' } StatusResource.openapi = { @@ -201,7 +198,7 @@ def render_GET(self, request): 'connected_peers': [_openapi_connected_peer], 'handshaking_peers': [ { - 'address': '192.168.1.1:54321', + 'address': 'tcp://192.168.1.1:54321', 'state': 'HELLO', 'uptime': 0.0010249614715576172, 'app_version': 'Unknown' diff --git a/hathor/p2p/states/base.py b/hathor/p2p/states/base.py index f08401cc0..75a69140e 100644 --- a/hathor/p2p/states/base.py +++ b/hathor/p2p/states/base.py @@ -19,6 +19,7 @@ from twisted.internet.defer import Deferred from hathor.conf.settings import HathorSettings +from hathor.p2p import P2PDependencies from hathor.p2p.messages import ProtocolMessages if TYPE_CHECKING: @@ -34,9 +35,10 @@ class BaseState: Callable[[str], None] | Callable[[str], Deferred[None]] | Callable[[str], Coroutine[Deferred[None], Any, None]] ] - def __init__(self, protocol: 'HathorProtocol', settings: HathorSettings): + def __init__(self, protocol: 'HathorProtocol', *, dependencies: P2PDependencies): self.log = logger.new(**protocol.get_logger_context()) - self._settings = settings + self.dependencies = dependencies + self._settings: HathorSettings = dependencies.settings self.protocol = protocol self.cmd_map = { ProtocolMessages.ERROR: self.handle_error, diff --git a/hathor/p2p/states/hello.py b/hathor/p2p/states/hello.py index 47c9cf4e5..9c034b7cb 100644 --- a/hathor/p2p/states/hello.py +++ b/hathor/p2p/states/hello.py @@ -18,8 +18,8 @@ import hathor from hathor.conf.get_settings import get_global_settings -from hathor.conf.settings import HathorSettings from hathor.exception import HathorError +from hathor.p2p import P2PDependencies from hathor.p2p.messages import ProtocolMessages from hathor.p2p.states.base import BaseState from hathor.p2p.sync_version import SyncVersion @@ -33,8 +33,8 @@ class HelloState(BaseState): - def __init__(self, protocol: 'HathorProtocol', settings: HathorSettings) -> None: - super().__init__(protocol, settings) + def __init__(self, protocol: 'HathorProtocol', *, dependencies: P2PDependencies) -> None: + super().__init__(protocol, dependencies=dependencies) self.log = logger.new(**protocol.get_logger_context()) self.cmd_map.update({ ProtocolMessages.HELLO: self.handle_hello, @@ -55,11 +55,11 @@ def _get_hello_data(self) -> dict[str, Any]: 'network': self._settings.NETWORK_NAME, 'remote_address': format_address(remote), 'genesis_short_hash': get_genesis_short_hash(), - 'timestamp': protocol.node.reactor.seconds(), + 'timestamp': self.dependencies.reactor.seconds(), 'settings_dict': get_settings_hello_dict(self._settings), - 'capabilities': protocol.node.capabilities, + 'capabilities': self.dependencies.capabilities, } - if self.protocol.node.has_sync_version_capability(): + if self.dependencies.has_sync_version_capability(): data['sync_versions'] = [x.value for x in self._get_sync_versions()] return data @@ -143,7 +143,7 @@ def handle_hello(self, payload: str) -> None: protocol.send_error_and_close_connection('Different genesis.') return - dt = data['timestamp'] - protocol.node.reactor.seconds() + dt = data['timestamp'] - self.dependencies.reactor.seconds() if abs(dt) > self._settings.MAX_FUTURE_TIMESTAMP_ALLOWED / 2: protocol.send_error_and_close_connection('Nodes timestamps too far apart.') return diff --git a/hathor/p2p/states/peer_id.py b/hathor/p2p/states/peer_id.py index 77e8a051e..2ca93ea59 100644 --- a/hathor/p2p/states/peer_id.py +++ b/hathor/p2p/states/peer_id.py @@ -16,7 +16,7 @@ from structlog import get_logger -from hathor.conf.settings import HathorSettings +from hathor.p2p import P2PDependencies from hathor.p2p.messages import ProtocolMessages from hathor.p2p.peer import PublicPeer from hathor.p2p.peer_id import PeerId @@ -30,8 +30,8 @@ class PeerIdState(BaseState): - def __init__(self, protocol: 'HathorProtocol', settings: HathorSettings) -> None: - super().__init__(protocol, settings) + def __init__(self, protocol: 'HathorProtocol', *, dependencies: P2PDependencies) -> None: + super().__init__(protocol, dependencies=dependencies) self.log = logger.new(remote=protocol.get_short_remote()) self.cmd_map.update({ ProtocolMessages.PEER_ID: self.handle_peer_id, @@ -111,7 +111,7 @@ async def handle_peer_id(self, payload: str) -> None: return if protocol.connections is not None: - if protocol.connections.is_peer_connected(peer.id): + if protocol.connections.is_peer_ready(peer.id): protocol.send_error_and_close_connection('We are already connected.') return @@ -120,9 +120,6 @@ async def handle_peer_id(self, payload: str) -> None: protocol.send_error_and_close_connection('Connection string is not in the entrypoints.') return - if protocol.entrypoint is not None and protocol.entrypoint.peer_id is not None: - assert protocol.entrypoint.peer_id == peer.id - if protocol.use_ssl: certificate_valid = peer.validate_certificate(protocol) if not certificate_valid: @@ -149,7 +146,7 @@ def _should_block_peer(self, peer_id: PeerId) -> bool: Currently this is only because the peer is not in a whitelist and whitelist blocking is active. """ - peer_is_whitelisted = peer_id in self.protocol.node.peers_whitelist + peer_is_whitelisted = self.protocol.connections.is_peer_whitelisted(peer_id) # never block whitelisted peers if peer_is_whitelisted: return False @@ -164,10 +161,8 @@ def _should_block_peer(self, peer_id: PeerId) -> bool: return True # otherwise we block non-whitelisted peers when on "whitelist-only mode" - if self.protocol.connections is not None: - protocol_is_whitelist_only = self.protocol.connections.whitelist_only - if protocol_is_whitelist_only and not peer_is_whitelisted: - return True + if self.dependencies.whitelist_only and not peer_is_whitelisted: + return True # default is not blocking, this will be sync-v2 peers not on whitelist when not on whitelist-only mode return False diff --git a/hathor/p2p/states/ready.py b/hathor/p2p/states/ready.py index 1bed1c745..fb3db3ab9 100644 --- a/hathor/p2p/states/ready.py +++ b/hathor/p2p/states/ready.py @@ -18,8 +18,8 @@ from structlog import get_logger from twisted.internet.task import LoopingCall -from hathor.conf.settings import HathorSettings from hathor.indexes.height_index import HeightInfo +from hathor.p2p import P2PDependencies from hathor.p2p.messages import ProtocolMessages from hathor.p2p.peer import PublicPeer, UnverifiedPeer from hathor.p2p.states.base import BaseState @@ -35,12 +35,12 @@ class ReadyState(BaseState): - def __init__(self, protocol: 'HathorProtocol', settings: HathorSettings) -> None: - super().__init__(protocol, settings) + def __init__(self, protocol: 'HathorProtocol', *, dependencies: P2PDependencies) -> None: + super().__init__(protocol, dependencies=dependencies) self.log = logger.new(**self.protocol.get_logger_context()) - self.reactor = self.protocol.node.reactor + self.reactor = self.dependencies.reactor # It triggers an event to send a ping message if necessary. self.lc_ping = LoopingCall(self.send_ping_if_necessary) @@ -85,7 +85,7 @@ def __init__(self, protocol: 'HathorProtocol', settings: HathorSettings) -> None self.lc_get_best_blockchain: Optional[LoopingCall] = None # if the peer has the GET-BEST-BLOCKCHAIN capability - common_capabilities = protocol.capabilities & set(protocol.node.capabilities) + common_capabilities = protocol.capabilities & set(self.dependencies.capabilities) if (self._settings.CAPABILITY_GET_BEST_BLOCKCHAIN in common_capabilities): # set the loop to get the best blockchain from the peer self.lc_get_best_blockchain = LoopingCall(self.send_get_best_blockchain) @@ -106,7 +106,7 @@ def __init__(self, protocol: 'HathorProtocol', settings: HathorSettings) -> None self.log.debug(f'loading {sync_version}') sync_factory = connections.get_sync_factory(sync_version) - self.sync_agent: SyncAgent = sync_factory.create_sync_agent(self.protocol, reactor=self.reactor) + self.sync_agent: SyncAgent = sync_factory.create_sync_agent(self.protocol) self.cmd_map.update(self.sync_agent.get_cmd_dict()) def on_enter(self) -> None: @@ -155,7 +155,7 @@ def handle_get_peers(self, payload: str) -> None: """ Executed when a GET-PEERS command is received. It just responds with a list of all known peers. """ - for peer in self.protocol.connections.verified_peer_storage.values(): + for peer in self.protocol.connections.get_verified_peers(): self.send_peers([peer]) def send_peers(self, peer_list: Iterable[PublicPeer]) -> None: @@ -195,8 +195,7 @@ def send_ping(self) -> None: """ # Add a salt number to prevent peers from faking rtt. self.ping_start_time = self.reactor.seconds() - rng = self.protocol.connections.rng - self.ping_salt = rng.randbytes(self.ping_salt_size).hex() + self.ping_salt = self.protocol.connections.get_randbytes(self.ping_salt_size).hex() self.send_message(ProtocolMessages.PING, self.ping_salt) def send_pong(self, salt: str) -> None: @@ -255,7 +254,7 @@ def handle_get_best_blockchain(self, payload: str) -> None: ) return - best_blockchain = self.protocol.node.tx_storage.get_n_height_tips(n_blocks) + best_blockchain = self.dependencies.tx_storage.get_n_height_tips(n_blocks) self.send_best_blockchain(best_blockchain) def send_best_blockchain(self, best_blockchain: list[HeightInfo]) -> None: diff --git a/hathor/p2p/sync_factory.py b/hathor/p2p/sync_factory.py index f4883f21a..da32bd68b 100644 --- a/hathor/p2p/sync_factory.py +++ b/hathor/p2p/sync_factory.py @@ -16,7 +16,6 @@ from typing import TYPE_CHECKING from hathor.p2p.sync_agent import SyncAgent -from hathor.reactor import ReactorProtocol as Reactor if TYPE_CHECKING: from hathor.p2p.protocol import HathorProtocol @@ -24,5 +23,5 @@ class SyncAgentFactory(ABC): @abstractmethod - def create_sync_agent(self, protocol: 'HathorProtocol', reactor: Reactor) -> SyncAgent: - pass + def create_sync_agent(self, protocol: 'HathorProtocol') -> SyncAgent: + raise NotImplementedError diff --git a/hathor/p2p/sync_v1/agent.py b/hathor/p2p/sync_v1/agent.py index 68fe401ec..7291db0e9 100644 --- a/hathor/p2p/sync_v1/agent.py +++ b/hathor/p2p/sync_v1/agent.py @@ -22,14 +22,13 @@ from twisted.internet.defer import CancelledError, Deferred, inlineCallbacks from twisted.internet.interfaces import IDelayedCall -from hathor.conf.get_settings import get_global_settings +from hathor.p2p import P2PDependencies from hathor.p2p.messages import GetNextPayload, GetTipsPayload, NextPayload, ProtocolMessages, TipsPayload from hathor.p2p.sync_agent import SyncAgent from hathor.p2p.sync_v1.downloader import Downloader -from hathor.reactor import ReactorProtocol as Reactor from hathor.transaction import BaseTransaction +from hathor.transaction.storage import TransactionStorage from hathor.transaction.storage.exceptions import TransactionDoesNotExist -from hathor.transaction.vertex_parser import VertexParser from hathor.util import json_dumps, json_loads logger = get_logger() @@ -64,9 +63,8 @@ def __init__( self, protocol: 'HathorProtocol', downloader: Downloader, - reactor: Reactor, *, - vertex_parser: VertexParser, + dependencies: P2PDependencies, ) -> None: """ :param protocol: Protocol of the connection. @@ -75,13 +73,17 @@ def __init__( :param reactor: Reactor to schedule later calls. (default=twisted.internet.reactor) :type reactor: Reactor """ - self._settings = get_global_settings() - self.vertex_parser = vertex_parser + self._settings = dependencies.settings + self.dependencies = dependencies self.protocol = protocol - self.manager = protocol.node self.downloader = downloader + self.reactor = dependencies.reactor - self.reactor: Reactor = reactor + # Since Sync-v1 does not support multiprocess P2P, the dependencies.tx_storage will always be a concrete + # TransactionStorage in the same process, with no IPC. + # This reduces the number of IPC endpoints we have to implement. + assert isinstance(self.dependencies.tx_storage, TransactionStorage) + self.tx_storage = self.dependencies.tx_storage # Rate limit for this connection. assert protocol.connections is not None @@ -184,7 +186,7 @@ def is_synced(self) -> bool: See the `send_tx_to_peer_if_possible` method for the exact process and to understand why this condition has to be this way. """ - return self.manager.tx_storage.latest_timestamp - self.synced_timestamp <= self.sync_threshold + return self.tx_storage.latest_timestamp - self.synced_timestamp <= self.sync_threshold def is_errored(self) -> bool: # XXX: this sync manager does not have an error state, this method exists for API parity with sync-v2 @@ -203,7 +205,7 @@ def send_tx_to_peer_if_possible(self, tx: BaseTransaction) -> None: # parents' timestamps are below synced_timestamp, i.e., we know that the peer # has all the parents. for parent_hash in tx.parents: - parent = self.protocol.node.tx_storage.get_transaction(parent_hash) + parent = self.tx_storage.get_vertex(parent_hash) if parent.timestamp > self.synced_timestamp: return @@ -286,7 +288,7 @@ def sync_from_timestamp(self, next_timestamp: int) -> Generator[Deferred, Any, N next_offset=payload.next_offset, hashes=len(payload.hashes)) count = 0 for h in payload.hashes: - if not self.manager.tx_storage.transaction_exists(h): + if not self.tx_storage.transaction_exists(h): pending.add(self.get_data(h)) count += 1 self.log.debug('...', next_ts=next_timestamp, count=count, pending=len(pending)) @@ -321,18 +323,18 @@ def find_synced_timestamp(self) -> Generator[Deferred, Any, Optional[int]]: # Maximum of ceil(log(k)), where k is the number of items between the new one and the latest item. prev_cur = None cur = self.peer_timestamp - local_merkle_tree, _ = self.manager.tx_storage.get_merkle_tree(cur) + local_merkle_tree, _ = self.tx_storage.get_merkle_tree(cur) step = 1 while tips.merkle_tree != local_merkle_tree: - if cur <= self.manager.tx_storage.first_timestamp: + if cur <= self.tx_storage.first_timestamp: raise Exception( 'We cannot go before genesis. Peer is probably running with wrong configuration or database.' ) prev_cur = cur - assert self.manager.tx_storage.first_timestamp > 0 - cur = max(cur - step, self.manager.tx_storage.first_timestamp) + assert self.tx_storage.first_timestamp > 0 + cur = max(cur - step, self.tx_storage.first_timestamp) tips = (yield self.get_peer_tips(cur)) - local_merkle_tree, _ = self.manager.tx_storage.get_merkle_tree(cur) + local_merkle_tree, _ = self.tx_storage.get_merkle_tree(cur) step *= 2 # Here, both nodes are synced at timestamp `cur` and not synced at timestamp `prev_cur`. @@ -348,7 +350,7 @@ def find_synced_timestamp(self) -> Generator[Deferred, Any, Optional[int]]: while high - low > 1: mid = (low + high + 1) // 2 tips = (yield self.get_peer_tips(mid)) - local_merkle_tree, _ = self.manager.tx_storage.get_merkle_tree(mid) + local_merkle_tree, _ = self.tx_storage.get_merkle_tree(mid) if tips.merkle_tree == local_merkle_tree: low = mid else: @@ -442,9 +444,9 @@ def send_next(self, timestamp: int, offset: int = 0) -> None: from hathor.indexes.timestamp_index import RangeIdx count = self.MAX_HASHES - assert self.manager.tx_storage.indexes is not None + assert self.tx_storage.indexes is not None from_idx = RangeIdx(timestamp, offset) - hashes, next_idx = self.manager.tx_storage.indexes.sorted_all.get_hashes_and_next_idx(from_idx, count) + hashes, next_idx = self.tx_storage.indexes.sorted_all.get_hashes_and_next_idx(from_idx, count) if next_idx is None: # this means we've reached the end and there's nothing else to sync next_timestamp, next_offset = inf, 0 @@ -524,7 +526,7 @@ def _send_tips(self, timestamp: Optional[int] = None, include_hashes: bool = Fal """ Send a TIPS message. """ if timestamp is None: - timestamp = self.manager.tx_storage.latest_timestamp + timestamp = self.tx_storage.latest_timestamp # All tips # intervals = self.manager.tx_storage.get_all_tips(timestamp) @@ -532,7 +534,7 @@ def _send_tips(self, timestamp: Optional[int] = None, include_hashes: bool = Fal # raise Exception('No tips for timestamp {}'.format(timestamp)) # Calculate list of hashes to be sent - merkle_tree, hashes = self.manager.tx_storage.get_merkle_tree(timestamp) + merkle_tree, hashes = self.tx_storage.get_merkle_tree(timestamp) has_more = False if not include_hashes: @@ -577,7 +579,7 @@ def handle_get_data(self, payload: str) -> None: hash_hex = payload # self.log.debug('handle_get_data', payload=hash_hex) try: - tx = self.protocol.node.tx_storage.get_transaction(bytes.fromhex(hash_hex)) + tx = self.tx_storage.get_vertex(bytes.fromhex(hash_hex)) self.send_data(tx) except TransactionDoesNotExist: # In case the tx does not exist we send a NOT-FOUND message @@ -605,7 +607,7 @@ def handle_data(self, payload: str) -> None: data = base64.b64decode(payload) try: - tx = self.vertex_parser.deserialize(data) + tx = self.dependencies.vertex_parser.deserialize(data) except struct.error: # Invalid data for tx decode return @@ -614,11 +616,10 @@ def handle_data(self, payload: str) -> None: self.log.debug('tx received from peer', tx=tx.hash_hex, peer=self.protocol.get_peer_id()) - if self.protocol.node.tx_storage.get_genesis(tx.hash): + if self.tx_storage.get_genesis(tx.hash): # We just got the data of a genesis tx/block. What should we do? # Will it reduce peer reputation score? return - tx.storage = self.protocol.node.tx_storage key = self.get_data_key(tx.hash) deferred = self.deferred_by_key.pop(key, None) @@ -627,16 +628,16 @@ def handle_data(self, payload: str) -> None: assert tx.timestamp is not None self.requested_data_arrived(tx.timestamp) deferred.callback(tx) - elif self.manager.tx_storage.transaction_exists(tx.hash): + elif self.tx_storage.transaction_exists(tx.hash): # transaction already added to the storage, ignore it # XXX: maybe we could add a hash blacklist and punish peers propagating known bad txs - self.manager.tx_storage.compare_bytes_with_local_tx(tx) + self.tx_storage.compare_bytes_with_local_tx(tx) return else: self.log.info('tx received in real time from peer', tx=tx.hash_hex, peer=self.protocol.get_peer_id()) # If we have not requested the data, it is a new transaction being propagated # in the network, thus, we propagate it as well. - success = self.manager.vertex_handler.on_new_vertex(tx) + success = self.dependencies.vertex_handler.on_new_vertex(tx) if success: self.protocol.connections.send_tx_to_peers(tx) self.update_received_stats(tx, success) @@ -682,12 +683,12 @@ def on_tx_success(self, tx: 'BaseTransaction') -> 'BaseTransaction': # the parameter of the second callback is the return of the first # so I need to return the same tx to guarantee that all peers will receive it if tx: - if self.manager.tx_storage.transaction_exists(tx.hash): - self.manager.tx_storage.compare_bytes_with_local_tx(tx) + if self.tx_storage.transaction_exists(tx.hash): + self.tx_storage.compare_bytes_with_local_tx(tx) success = True else: # Add tx to the DAG. - success = self.manager.vertex_handler.on_new_vertex(tx) + success = self.dependencies.vertex_handler.on_new_vertex(tx) if success: self.protocol.connections.send_tx_to_peers(tx) # Updating stats data diff --git a/hathor/p2p/sync_v1/downloader.py b/hathor/p2p/sync_v1/downloader.py index d8b3c12cf..8e22d9f7f 100644 --- a/hathor/p2p/sync_v1/downloader.py +++ b/hathor/p2p/sync_v1/downloader.py @@ -22,10 +22,10 @@ from twisted.python.failure import Failure from hathor.conf.get_settings import get_global_settings +from hathor.p2p import P2PDependencies from hathor.transaction.storage.exceptions import TransactionDoesNotExist if TYPE_CHECKING: - from hathor.manager import HathorManager from hathor.p2p.sync_v1.agent import NodeSyncTimestamp from hathor.transaction import BaseTransaction @@ -145,10 +145,10 @@ class Downloader: # Size of the sliding window used to download transactions. window_size: int - def __init__(self, manager: 'HathorManager', window_size: int = 100): - self._settings = get_global_settings() + def __init__(self, dependencies: P2PDependencies, window_size: int = 100): + self.dependencies = dependencies + self._settings = dependencies.settings self.log = logger.new() - self.manager = manager self.pending_transactions = {} self.waiting_deque = deque() @@ -172,7 +172,7 @@ def get_tx(self, tx_id: bytes, connection: 'NodeSyncTimestamp') -> Deferred: # If I already have this tx in the storage just return a defer already success # In the node_sync code we already handle this case but in a race condition situation # we might get here but it's not common - tx = self.manager.tx_storage.get_transaction(tx_id) + tx = self.dependencies.tx_storage.get_vertex(tx_id) self.log.debug('requesting to download a tx that is already in the storage', tx=tx_id.hex()) return defer.succeed(tx) except TransactionDoesNotExist: diff --git a/hathor/p2p/sync_v1/factory.py b/hathor/p2p/sync_v1/factory.py index 2a205d728..68fd740e1 100644 --- a/hathor/p2p/sync_v1/factory.py +++ b/hathor/p2p/sync_v1/factory.py @@ -14,34 +14,29 @@ from typing import TYPE_CHECKING, Optional -from hathor.p2p.manager import ConnectionsManager +from hathor.p2p import P2PDependencies from hathor.p2p.sync_agent import SyncAgent from hathor.p2p.sync_factory import SyncAgentFactory from hathor.p2p.sync_v1.agent import NodeSyncTimestamp from hathor.p2p.sync_v1.downloader import Downloader -from hathor.reactor import ReactorProtocol as Reactor -from hathor.transaction.vertex_parser import VertexParser if TYPE_CHECKING: from hathor.p2p.protocol import HathorProtocol class SyncV11Factory(SyncAgentFactory): - def __init__(self, connections: ConnectionsManager, *, vertex_parser: VertexParser): - self.connections = connections - self.vertex_parser = vertex_parser + def __init__(self, dependencies: P2PDependencies) -> None: + self.dependencies = dependencies self._downloader: Optional[Downloader] = None def get_downloader(self) -> Downloader: if self._downloader is None: - assert self.connections.manager is not None - self._downloader = Downloader(self.connections.manager) + self._downloader = Downloader(self.dependencies) return self._downloader - def create_sync_agent(self, protocol: 'HathorProtocol', reactor: Reactor) -> SyncAgent: + def create_sync_agent(self, protocol: 'HathorProtocol') -> SyncAgent: return NodeSyncTimestamp( protocol, downloader=self.get_downloader(), - reactor=reactor, - vertex_parser=self.vertex_parser + dependencies=self.dependencies, ) diff --git a/hathor/p2p/sync_v2/agent.py b/hathor/p2p/sync_v2/agent.py index 5393080b4..2a3b75f64 100644 --- a/hathor/p2p/sync_v2/agent.py +++ b/hathor/p2p/sync_v2/agent.py @@ -24,8 +24,8 @@ from twisted.internet.defer import Deferred, inlineCallbacks from twisted.internet.task import LoopingCall, deferLater -from hathor.conf.settings import HathorSettings from hathor.exception import InvalidNewTransaction +from hathor.p2p import P2PDependencies from hathor.p2p.messages import ProtocolMessages from hathor.p2p.sync_agent import SyncAgent from hathor.p2p.sync_v2.blockchain_streaming_client import BlockchainStreamingClient, StreamingError @@ -38,17 +38,14 @@ TransactionsStreamingServer, ) from hathor.p2p.sync_v2.transaction_streaming_client import TransactionStreamingClient -from hathor.reactor import ReactorProtocol as Reactor from hathor.transaction import BaseTransaction, Block, Transaction +from hathor.transaction.genesis import is_genesis from hathor.transaction.storage.exceptions import TransactionDoesNotExist -from hathor.transaction.vertex_parser import VertexParser from hathor.types import VertexId from hathor.util import collect_n -from hathor.vertex_handler import VertexHandler if TYPE_CHECKING: from hathor.p2p.protocol import HathorProtocol - from hathor.transaction.storage import TransactionStorage logger = get_logger() @@ -88,30 +85,23 @@ class NodeBlockSync(SyncAgent): def __init__( self, - settings: HathorSettings, protocol: 'HathorProtocol', - reactor: Reactor, *, - vertex_parser: VertexParser, - vertex_handler: VertexHandler, + dependencies: P2PDependencies, ) -> None: """ :param protocol: Protocol of the connection. :type protocol: HathorProtocol - - :param reactor: Reactor to schedule later calls. (default=twisted.internet.reactor) - :type reactor: Reactor """ - self._settings = settings - self.vertex_parser = vertex_parser - self.vertex_handler = vertex_handler + self.dependencies = dependencies + self._settings = dependencies.settings + self.vertex_parser = dependencies.vertex_parser self.protocol = protocol - self.tx_storage: 'TransactionStorage' = protocol.node.tx_storage self.state = PeerState.UNKNOWN self.DEFAULT_STREAMING_LIMIT = DEFAULT_STREAMING_LIMIT - self.reactor: Reactor = reactor + self.reactor = dependencies.reactor self._is_streaming: bool = False # Create logger with context @@ -152,7 +142,7 @@ def __init__( # Saves if I am in the middle of a mempool sync # we don't execute any sync while in the middle of it - self.mempool_manager = SyncMempoolManager(self) + self.mempool_manager = SyncMempoolManager(self, dependencies=self.dependencies) self._receiving_tips: Optional[list[VertexId]] = None self.max_receiving_tips: int = self._settings.MAX_MEMPOOL_RECEIVING_TIPS @@ -178,9 +168,7 @@ def __init__( def get_status(self) -> dict[str, Any]: """ Return the status of the sync. """ - assert self.tx_storage.indexes is not None - assert self.tx_storage.indexes.mempool_tips is not None - tips = self.tx_storage.indexes.mempool_tips.get() + tips = self.dependencies.tx_storage.get_mempool_tips() tips_limited, tips_has_more = collect_n(iter(tips), MAX_MEMPOOL_STATUS_TIPS) res = { 'is_enabled': self.is_sync_enabled(), @@ -347,7 +335,7 @@ def run_sync_mempool(self) -> Generator[Any, Any, None]: def get_my_best_block(self) -> _HeightInfo: """Return my best block info.""" - bestblock = self.tx_storage.get_best_block() + bestblock = self.dependencies.tx_storage.get_best_block() meta = bestblock.get_metadata() assert meta.validation.is_fully_connected() return _HeightInfo(height=bestblock.get_height(), id=bestblock.hash) @@ -358,7 +346,6 @@ def run_sync_blocks(self) -> Generator[Any, Any, bool]: Notice that we might already have all other peer's blocks while the other peer is still syncing. """ - assert self.tx_storage.indexes is not None self.state = PeerState.SYNCING_BLOCKS # Get my best block. @@ -381,7 +368,7 @@ def run_sync_blocks(self) -> Generator[Any, Any, bool]: # Not synced but same blockchain? if self.peer_best_block.height <= my_best_block.height: # Is peer behind me at the same blockchain? - common_block_hash = self.tx_storage.indexes.height.get(self.peer_best_block.height) + common_block_hash = self.dependencies.tx_storage.get_block_id_by_height(self.peer_best_block.height) if common_block_hash == self.peer_best_block.id: # If yes, nothing to sync from this peer. if not self.is_synced(): @@ -459,15 +446,13 @@ def send_get_tips(self) -> None: def handle_get_tips(self, _payload: str) -> None: """ Handle a GET-TIPS message. """ - assert self.tx_storage.indexes is not None - assert self.tx_storage.indexes.mempool_tips is not None if self._is_streaming: self.log.warn('can\'t send while streaming') # XXX: or can we? self.send_message(ProtocolMessages.MEMPOOL_END) return self.log.debug('handle_get_tips') # TODO Use a streaming of tips - for tx_id in self.tx_storage.indexes.mempool_tips.get(): + for tx_id in self.dependencies.tx_storage.get_mempool_tips(): self.send_tips(tx_id) self.log.debug('tips end') self.send_message(ProtocolMessages.TIPS_END) @@ -489,7 +474,7 @@ def handle_tips(self, payload: str) -> None: # filter-out txs we already have try: self._receiving_tips.extend( - VertexId(tx_id) for tx_id in data if not self.tx_storage.partial_vertex_exists(tx_id) + VertexId(tx_id) for tx_id in data if not self.dependencies.tx_storage.partial_vertex_exists(tx_id) ) except ValueError: self.protocol.send_error_and_close_connection('Invalid trasaction ID received') @@ -533,7 +518,9 @@ def start_blockchain_streaming(self, start_block: _HeightInfo, end_block: _HeightInfo) -> Deferred[StreamEnd]: """Request peer to start streaming blocks to us.""" - self._blk_streaming_client = BlockchainStreamingClient(self, start_block, end_block) + self._blk_streaming_client = BlockchainStreamingClient( + self, start_block, end_block, dependencies=self.dependencies + ) quantity = self._blk_streaming_client._blk_max_quantity self.log.info('requesting blocks streaming', start_block=start_block, @@ -597,7 +584,7 @@ def find_best_common_block(self, for info in block_info_list: try: # We must check only fully validated transactions. - blk = self.tx_storage.get_transaction(info.id) + blk = self.dependencies.tx_storage.get_vertex(info.id) except TransactionDoesNotExist: hi = info else: @@ -616,12 +603,12 @@ def on_block_complete(self, blk: Block, vertex_list: list[BaseTransaction]) -> G # Note: Any vertex and block could have already been added by another concurrent syncing peer. try: for tx in vertex_list: - if not self.tx_storage.transaction_exists(tx.hash): - self.vertex_handler.on_new_vertex(tx, fails_silently=False) + if not self.dependencies.tx_storage.transaction_exists(tx.hash): + self.dependencies.vertex_handler.on_new_vertex(tx, fails_silently=False) yield deferLater(self.reactor, 0, lambda: None) - if not self.tx_storage.transaction_exists(blk.hash): - self.vertex_handler.on_new_vertex(blk, fails_silently=False) + if not self.dependencies.tx_storage.transaction_exists(blk.hash): + self.dependencies.vertex_handler.on_new_vertex(blk, fails_silently=False) except InvalidNewTransaction: self.protocol.send_error_and_close_connection('invalid vertex received') @@ -643,7 +630,6 @@ def send_get_peer_block_hashes(self, heights: list[int]) -> None: def handle_get_peer_block_hashes(self, payload: str) -> None: """ Handle a GET-PEER-BLOCK-HASHES message. """ - assert self.tx_storage.indexes is not None heights = json.loads(payload) if len(heights) > 20: self.log.info('too many heights', heights_qty=len(heights)) @@ -651,10 +637,10 @@ def handle_get_peer_block_hashes(self, payload: str) -> None: return data = [] for h in heights: - blk_hash = self.tx_storage.indexes.height.get(h) + blk_hash = self.dependencies.tx_storage.get_block_id_by_height(h) if blk_hash is None: break - blk = self.tx_storage.get_transaction(blk_hash) + blk = self.dependencies.tx_storage.get_vertex(blk_hash) if blk.get_metadata().voided_by: break data.append((h, blk_hash.hex())) @@ -705,7 +691,7 @@ def handle_get_next_blocks(self, payload: str) -> None: def _validate_block(self, _hash: VertexId) -> Optional[Block]: """Validate block given in the GET-NEXT-BLOCKS and GET-TRANSACTIONS-BFS messages.""" try: - blk = self.tx_storage.get_transaction(_hash) + blk = self.dependencies.tx_storage.get_vertex(_hash) except TransactionDoesNotExist: self.log.debug('requested block not found', blk_id=_hash.hex()) self.send_message(ProtocolMessages.NOT_FOUND, _hash.hex()) @@ -782,7 +768,6 @@ def handle_blocks(self, payload: str) -> None: if not isinstance(blk, Block): # Not a block. Punish peer? return - blk.storage = self.tx_storage assert self._blk_streaming_client is not None self._blk_streaming_client.handle_blocks(blk) @@ -843,7 +828,7 @@ def send_get_best_block(self) -> None: def handle_get_best_block(self, _payload: str) -> None: """ Handle a GET-BEST-BLOCK message. """ - best_block = self.tx_storage.get_best_block() + best_block = self.dependencies.tx_storage.get_best_block() meta = best_block.get_metadata() assert meta.validation.is_fully_connected() payload = BestBlockPayload( @@ -865,9 +850,9 @@ def handle_best_block(self, payload: str) -> None: def start_transactions_streaming(self, partial_blocks: list[Block]) -> Deferred[StreamEnd]: """Request peer to start streaming transactions to us.""" - self._tx_streaming_client = TransactionStreamingClient(self, - partial_blocks, - limit=self.DEFAULT_STREAMING_LIMIT) + self._tx_streaming_client = TransactionStreamingClient( + self, partial_blocks, limit=self.DEFAULT_STREAMING_LIMIT, dependencies=self.dependencies + ) start_from: list[bytes] = [] first_block_hash = partial_blocks[0].hash @@ -956,7 +941,7 @@ def handle_get_transactions_bfs(self, payload: str) -> None: start_from_txs = [] for start_from_hash in data.start_from: try: - tx = self.tx_storage.get_transaction(start_from_hash) + tx = self.dependencies.tx_storage.get_vertex(start_from_hash) except TransactionDoesNotExist: # In case the tx does not exist we send a NOT-FOUND message self.log.debug('requested start_from_hash not found', start_from_hash=start_from_hash.hex()) @@ -982,11 +967,14 @@ def send_transactions_bfs(self, """ if self._tx_streaming_server is not None and self._tx_streaming_server.is_running: self.stop_tx_streaming_server(StreamEnd.PER_REQUEST) - self._tx_streaming_server = TransactionsStreamingServer(self, - start_from, - first_block, - last_block, - limit=self.DEFAULT_STREAMING_LIMIT) + self._tx_streaming_server = TransactionsStreamingServer( + self, + start_from, + first_block, + last_block, + limit=self.DEFAULT_STREAMING_LIMIT, + dependencies=self.dependencies, + ) self._tx_streaming_server.start() def send_transaction(self, tx: Transaction) -> None: @@ -1033,7 +1021,6 @@ def handle_transaction(self, payload: str) -> None: self.log.warn('not a transaction', hash=tx.hash_hex) # Not a transaction. Punish peer? return - tx.storage = self.tx_storage assert self._tx_streaming_client is not None self._tx_streaming_client.handle_transaction(tx) @@ -1047,7 +1034,7 @@ def get_tx(self, tx_id: bytes) -> Generator[Deferred, Any, BaseTransaction]: self.log.debug('tx in cache', tx=tx_id.hex()) return tx try: - tx = self.tx_storage.get_transaction(tx_id) + tx = self.dependencies.tx_storage.get_vertex(tx_id) except TransactionDoesNotExist: tx = yield self.get_data(tx_id, 'mempool') assert tx is not None @@ -1117,7 +1104,7 @@ def handle_get_data(self, payload: str) -> None: origin = data.get('origin', '') # self.log.debug('handle_get_data', payload=hash_hex) try: - tx = self.protocol.node.tx_storage.get_transaction(bytes.fromhex(txid_hex)) + tx = self.dependencies.tx_storage.get_vertex(bytes.fromhex(txid_hex)) self.send_data(tx, origin=origin) except TransactionDoesNotExist: # In case the tx does not exist we send a NOT-FOUND message @@ -1152,25 +1139,23 @@ def handle_data(self, payload: str) -> None: return assert tx is not None - if self.protocol.node.tx_storage.get_genesis(tx.hash): + if is_genesis(tx.hash, settings=self.dependencies.settings): # We just got the data of a genesis tx/block. What should we do? # Will it reduce peer reputation score? return - tx.storage = self.protocol.node.tx_storage - - if self.tx_storage.partial_vertex_exists(tx.hash): + if self.dependencies.tx_storage.partial_vertex_exists(tx.hash): # transaction already added to the storage, ignore it # XXX: maybe we could add a hash blacklist and punish peers propagating known bad txs - self.tx_storage.compare_bytes_with_local_tx(tx) + self.dependencies.tx_storage.compare_bytes_with_local_tx(tx) return else: # If we have not requested the data, it is a new transaction being propagated # in the network, thus, we propagate it as well. - if self.tx_storage.can_validate_full(tx): + if self.dependencies.tx_storage.can_validate_full(tx): self.log.debug('tx received in real time from peer', tx=tx.hash_hex, peer=self.protocol.get_peer_id()) try: - success = self.vertex_handler.on_new_vertex(tx, fails_silently=False) + success = self.dependencies.vertex_handler.on_new_vertex(tx, fails_silently=False) if success: self.protocol.connections.send_tx_to_peers(tx) except InvalidNewTransaction: diff --git a/hathor/p2p/sync_v2/blockchain_streaming_client.py b/hathor/p2p/sync_v2/blockchain_streaming_client.py index e78ec056b..3859a576d 100644 --- a/hathor/p2p/sync_v2/blockchain_streaming_client.py +++ b/hathor/p2p/sync_v2/blockchain_streaming_client.py @@ -17,6 +17,7 @@ from structlog import get_logger from twisted.internet.defer import Deferred +from hathor.p2p import P2PDependencies from hathor.p2p.sync_v2.exception import ( BlockNotConnectedToPreviousBlock, InvalidVertexError, @@ -35,11 +36,17 @@ class BlockchainStreamingClient: - def __init__(self, sync_agent: 'NodeBlockSync', start_block: '_HeightInfo', end_block: '_HeightInfo') -> None: + def __init__( + self, + sync_agent: 'NodeBlockSync', + start_block: '_HeightInfo', + end_block: '_HeightInfo', + *, + dependencies: P2PDependencies, + ) -> None: + self.dependencies = dependencies self.sync_agent = sync_agent self.protocol = self.sync_agent.protocol - self.tx_storage = self.sync_agent.tx_storage - self.vertex_handler = self.sync_agent.vertex_handler self.log = logger.new(peer=self.protocol.get_short_peer_id()) @@ -99,10 +106,9 @@ def handle_blocks(self, blk: Block) -> None: # Check for repeated blocks. is_duplicated = False - if self.tx_storage.partial_vertex_exists(blk.hash): + if self.dependencies.tx_storage.partial_vertex_exists(blk.hash): # We reached a block we already have. Skip it. self._blk_repeated += 1 - is_duplicated = True if self._blk_repeated > self.max_repeated_blocks: self.log.info('too many repeated block received', total_repeated=self._blk_repeated) self.fails(TooManyRepeatedVerticesError()) @@ -124,9 +130,9 @@ def handle_blocks(self, blk: Block) -> None: else: self.log.debug('block received', blk_id=blk.hash.hex()) - if self.tx_storage.can_validate_full(blk): + if self.dependencies.tx_storage.can_validate_full(blk): try: - self.vertex_handler.on_new_vertex(blk, fails_silently=False) + self.dependencies.vertex_handler.on_new_vertex(blk, fails_silently=False) except HathorError: self.fails(InvalidVertexError(blk.hash.hex())) return diff --git a/hathor/p2p/sync_v2/factory.py b/hathor/p2p/sync_v2/factory.py index b9be356b3..4e46c8dfb 100644 --- a/hathor/p2p/sync_v2/factory.py +++ b/hathor/p2p/sync_v2/factory.py @@ -14,38 +14,21 @@ from typing import TYPE_CHECKING -from hathor.conf.settings import HathorSettings -from hathor.p2p.manager import ConnectionsManager +from hathor.p2p import P2PDependencies from hathor.p2p.sync_agent import SyncAgent from hathor.p2p.sync_factory import SyncAgentFactory from hathor.p2p.sync_v2.agent import NodeBlockSync -from hathor.reactor import ReactorProtocol as Reactor -from hathor.transaction.vertex_parser import VertexParser -from hathor.vertex_handler import VertexHandler if TYPE_CHECKING: from hathor.p2p.protocol import HathorProtocol class SyncV2Factory(SyncAgentFactory): - def __init__( - self, - settings: HathorSettings, - connections: ConnectionsManager, - *, - vertex_parser: VertexParser, - vertex_handler: VertexHandler, - ): - self._settings = settings - self.connections = connections - self.vertex_parser = vertex_parser - self.vertex_handler = vertex_handler + def __init__(self, dependencies: P2PDependencies) -> None: + self.dependencies = dependencies - def create_sync_agent(self, protocol: 'HathorProtocol', reactor: Reactor) -> SyncAgent: + def create_sync_agent(self, protocol: 'HathorProtocol') -> SyncAgent: return NodeBlockSync( - self._settings, - protocol, - reactor=reactor, - vertex_parser=self.vertex_parser, - vertex_handler=self.vertex_handler, + protocol=protocol, + dependencies=self.dependencies, ) diff --git a/hathor/p2p/sync_v2/mempool.py b/hathor/p2p/sync_v2/mempool.py index 03651642e..c7ff2c363 100644 --- a/hathor/p2p/sync_v2/mempool.py +++ b/hathor/p2p/sync_v2/mempool.py @@ -19,6 +19,7 @@ from twisted.internet.defer import Deferred, inlineCallbacks from hathor.exception import InvalidNewTransaction +from hathor.p2p import P2PDependencies from hathor.transaction import BaseTransaction if TYPE_CHECKING: @@ -30,15 +31,14 @@ class SyncMempoolManager: """Manage the sync-v2 mempool with one peer. """ - def __init__(self, sync_agent: 'NodeBlockSync'): + def __init__(self, sync_agent: 'NodeBlockSync', *, dependencies: P2PDependencies): """Initialize the sync-v2 mempool manager.""" self.log = logger.new(peer=sync_agent.protocol.get_short_peer_id()) + self.dependencies = dependencies # Shortcuts. self.sync_agent = sync_agent - self.vertex_handler = self.sync_agent.vertex_handler - self.tx_storage = self.sync_agent.tx_storage - self.reactor = self.sync_agent.reactor + self.reactor = dependencies.reactor self._deferred: Optional[Deferred[bool]] = None @@ -90,7 +90,7 @@ def _unsafe_run(self) -> Generator[Deferred, Any, bool]: if not self.missing_tips: # No missing tips? Let's get them! tx_hashes: list[bytes] = yield self.sync_agent.get_tips() - self.missing_tips.update(h for h in tx_hashes if not self.tx_storage.transaction_exists(h)) + self.missing_tips.update(h for h in tx_hashes if not self.dependencies.tx_storage.transaction_exists(h)) while self.missing_tips: self.log.debug('We have missing tips! Let\'s start!', missing_tips=[x.hex() for x in self.missing_tips]) @@ -127,20 +127,20 @@ def _next_missing_dep(self, tx: BaseTransaction) -> Optional[bytes]: """Get the first missing dependency found of tx.""" assert not tx.is_block for txin in tx.inputs: - if not self.tx_storage.transaction_exists(txin.tx_id): + if not self.dependencies.tx_storage.transaction_exists(txin.tx_id): return txin.tx_id for parent in tx.parents: - if not self.tx_storage.transaction_exists(parent): + if not self.dependencies.tx_storage.transaction_exists(parent): return parent return None def _add_tx(self, tx: BaseTransaction) -> None: """Add tx to the DAG.""" self.missing_tips.discard(tx.hash) - if self.tx_storage.transaction_exists(tx.hash): + if self.dependencies.tx_storage.transaction_exists(tx.hash): return try: - success = self.vertex_handler.on_new_vertex(tx, fails_silently=False) + success = self.dependencies.vertex_handler.on_new_vertex(tx, fails_silently=False) if success: self.sync_agent.protocol.connections.send_tx_to_peers(tx) except InvalidNewTransaction: diff --git a/hathor/p2p/sync_v2/streamers.py b/hathor/p2p/sync_v2/streamers.py index 5b215102f..11308d2e7 100644 --- a/hathor/p2p/sync_v2/streamers.py +++ b/hathor/p2p/sync_v2/streamers.py @@ -19,6 +19,7 @@ from twisted.internet.interfaces import IConsumer, IDelayedCall, IPushProducer from zope.interface import implementer +from hathor.p2p import P2PDependencies from hathor.transaction import BaseTransaction, Block, Transaction from hathor.transaction.storage.traversal import BFSOrderWalk from hathor.util import not_none @@ -68,7 +69,6 @@ def __str__(self): class _StreamingServerBase: def __init__(self, sync_agent: 'NodeBlockSync', *, limit: int = DEFAULT_STREAMING_LIMIT): self.sync_agent = sync_agent - self.tx_storage = self.sync_agent.tx_storage self.protocol: 'HathorProtocol' = sync_agent.protocol assert self.protocol.transport is not None @@ -212,16 +212,20 @@ class TransactionsStreamingServer(_StreamingServerBase): should there be interruptions or issues. """ - def __init__(self, - sync_agent: 'NodeBlockSync', - start_from: list[BaseTransaction], - first_block: Block, - last_block: Block, - *, - limit: int = DEFAULT_STREAMING_LIMIT) -> None: + def __init__( + self, + sync_agent: 'NodeBlockSync', + start_from: list[BaseTransaction], + first_block: Block, + last_block: Block, + *, + limit: int = DEFAULT_STREAMING_LIMIT, + dependencies: P2PDependencies, + ) -> None: # XXX: is limit needed for tx streaming? Or let's always send all txs for # a block? Very unlikely we'll reach this limit super().__init__(sync_agent, limit=limit) + self.dependencies = dependencies self.first_block: Block = first_block self.last_block: Block = last_block @@ -232,7 +236,12 @@ def __init__(self, assert tx.get_metadata().first_block == self.first_block.hash self.current_block: Optional[Block] = self.first_block - self.bfs = BFSOrderWalk(self.tx_storage, is_dag_verifications=True, is_dag_funds=True, is_left_to_right=False) + self.bfs = BFSOrderWalk( + self.dependencies.tx_storage.get_vertex, + is_dag_verifications=True, + is_dag_funds=True, + is_left_to_right=False, + ) self.iter = self.get_iter() def _stop_streaming_server(self, response_code: StreamEnd) -> None: @@ -296,7 +305,7 @@ def send_next(self) -> None: # Check if tx is confirmed by the `self.current_block` or any next block. assert cur_metadata.first_block is not None assert self.current_block is not None - first_block = self.tx_storage.get_block(cur_metadata.first_block) + first_block = self.dependencies.tx_storage.get_block(cur_metadata.first_block) if not_none(first_block.static_metadata.height) < not_none(self.current_block.static_metadata.height): self.log.debug('skipping tx: out of current block') self.bfs.skip_neighbors(cur) diff --git a/hathor/p2p/sync_v2/transaction_streaming_client.py b/hathor/p2p/sync_v2/transaction_streaming_client.py index e784a41cc..d63f2fc96 100644 --- a/hathor/p2p/sync_v2/transaction_streaming_client.py +++ b/hathor/p2p/sync_v2/transaction_streaming_client.py @@ -18,6 +18,7 @@ from structlog import get_logger from twisted.internet.defer import Deferred, inlineCallbacks +from hathor.p2p import P2PDependencies from hathor.p2p.sync_v2.exception import ( InvalidVertexError, StreamingError, @@ -37,16 +38,18 @@ class TransactionStreamingClient: - def __init__(self, - sync_agent: 'NodeBlockSync', - partial_blocks: list['Block'], - *, - limit: int) -> None: + def __init__( + self, + sync_agent: 'NodeBlockSync', + partial_blocks: list['Block'], + *, + limit: int, + dependencies: P2PDependencies, + ) -> None: + self.dependencies = dependencies self.sync_agent = sync_agent self.protocol = self.sync_agent.protocol - self.tx_storage = self.sync_agent.tx_storage - self.verification_service = self.protocol.node.verification_service - self.reactor = sync_agent.reactor + self.reactor = self.dependencies.reactor self.log = logger.new(peer=self.protocol.get_short_peer_id()) @@ -153,7 +156,7 @@ def _process_transaction(self, tx: BaseTransaction) -> Generator[Any, Any, None] # Run basic verification. if not tx.is_genesis: try: - self.verification_service.verify_basic(tx) + self.dependencies.verification_service.verify_basic(tx) except TxValidationError as e: self.fails(InvalidVertexError(repr(e))) return @@ -194,7 +197,7 @@ def _process_transaction(self, tx: BaseTransaction) -> Generator[Any, Any, None] def _update_dependencies(self, tx: BaseTransaction) -> None: """Update _existing_deps and _waiting_for with the dependencies.""" for dep in tx.get_all_dependencies(): - if self.tx_storage.transaction_exists(dep) or dep in self._db: + if self.dependencies.tx_storage.transaction_exists(dep) or dep in self._db: self._existing_deps.add(dep) else: self._waiting_for.add(dep) diff --git a/hathor/prometheus.py b/hathor/prometheus.py index 5418c44d4..965d68443 100644 --- a/hathor/prometheus.py +++ b/hathor/prometheus.py @@ -183,7 +183,11 @@ def set_new_metrics(self) -> None: """ Update metric_gauges dict with new data from metrics """ for metric_name in METRIC_INFO.keys(): - self.metric_gauges[metric_name].set(getattr(self.metrics, metric_name)) + metric_attr = metric_name + if metric_attr == 'connected_peers': + # TODO: Improve this. Temporary backwards compatibility workaround after a rename in metrics. + metric_attr = 'ready_peers' + self.metric_gauges[metric_name].set(getattr(self.metrics, metric_attr)) self._set_rocksdb_tx_storage_metrics() self._set_new_peer_connection_metrics() diff --git a/hathor/reactor/memory_reactor.py b/hathor/reactor/memory_reactor.py new file mode 100644 index 000000000..2c32a706b --- /dev/null +++ b/hathor/reactor/memory_reactor.py @@ -0,0 +1,53 @@ +# Copyright 2024 Hathor Labs +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from collections.abc import Mapping, Sequence +from typing import AnyStr + +from twisted.internet.interfaces import IProcessProtocol, IProcessTransport +from twisted.internet.task import Clock +from twisted.internet.testing import MemoryReactor as TwistedMemoryReactor + + +class MemoryReactor(TwistedMemoryReactor): + """A drop-in replacement for Twisted's own MemoryReactor that adds support for IReactorProcess.""" + + def run(self) -> None: + """ + We have to override TwistedMemoryReactor.run() because the original Twisted implementation weirdly calls stop() + inside run(), and we need the reactor running during our tests. + """ + self.running = True + + def spawnProcess( + self, + processProtocol: IProcessProtocol, + executable: bytes | str, + args: Sequence[bytes | str], + env: Mapping[AnyStr, AnyStr] | None = None, + path: bytes | str | None = None, + uid: int | None = None, + gid: int | None = None, + usePTY: bool = False, + childFDs: Mapping[int, int | str] | None = None, + ) -> IProcessTransport: + raise NotImplementedError + + +class MemoryReactorClock(MemoryReactor, Clock): + """A drop-in replacement for Twisted's own MemoryReactorClock that adds support for IReactorProcess.""" + + def __init__(self) -> None: + MemoryReactor.__init__(self) + Clock.__init__(self) diff --git a/hathor/reactor/reactor.py b/hathor/reactor/reactor.py index b92c80062..e33687af7 100644 --- a/hathor/reactor/reactor.py +++ b/hathor/reactor/reactor.py @@ -15,7 +15,7 @@ from typing import cast from structlog import get_logger -from twisted.internet.interfaces import IReactorCore, IReactorTCP, IReactorTime +from twisted.internet.interfaces import IReactorCore, IReactorProcess, IReactorSocket, IReactorTCP, IReactorTime from zope.interface.verify import verifyObject from hathor.reactor.reactor_protocol import ReactorProtocol @@ -76,6 +76,8 @@ def initialize_global_reactor(*, use_asyncio_reactor: bool = False) -> ReactorPr assert verifyObject(IReactorTime, twisted_reactor) is True assert verifyObject(IReactorCore, twisted_reactor) is True assert verifyObject(IReactorTCP, twisted_reactor) is True + assert verifyObject(IReactorProcess, twisted_reactor) is True + assert verifyObject(IReactorSocket, twisted_reactor) is True # We cast to ReactorProtocol, our own type that stubs the necessary Twisted zope interfaces, to aid typing. _reactor = cast(ReactorProtocol, twisted_reactor) diff --git a/hathor/reactor/reactor_process_protocol.py b/hathor/reactor/reactor_process_protocol.py new file mode 100644 index 000000000..0b709ee9c --- /dev/null +++ b/hathor/reactor/reactor_process_protocol.py @@ -0,0 +1,38 @@ +# Copyright 2024 Hathor Labs +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from collections.abc import Mapping, Sequence +from typing import AnyStr, Protocol + +from twisted.internet.interfaces import IProcessProtocol, IProcessTransport, IReactorProcess +from zope.interface import implementer + + +@implementer(IReactorProcess) +class ReactorProcessProtocol(Protocol): + """A Python protocol that stubs Twisted's IReactorProcess interface.""" + + def spawnProcess( + self, + processProtocol: IProcessProtocol, + executable: bytes | str, + args: Sequence[bytes | str], + env: Mapping[AnyStr, AnyStr] | None = None, + path: bytes | str | None = None, + uid: int | None = None, + gid: int | None = None, + usePTY: bool = False, + childFDs: Mapping[int, int | str] | None = None, + ) -> IProcessTransport: + ... diff --git a/hathor/reactor/reactor_protocol.py b/hathor/reactor/reactor_protocol.py index 7c301d052..83d69a7b7 100644 --- a/hathor/reactor/reactor_protocol.py +++ b/hathor/reactor/reactor_protocol.py @@ -15,6 +15,8 @@ from typing import Protocol from hathor.reactor.reactor_core_protocol import ReactorCoreProtocol +from hathor.reactor.reactor_process_protocol import ReactorProcessProtocol +from hathor.reactor.reactor_socket_protocol import ReactorSocketProtocol from hathor.reactor.reactor_tcp_protocol import ReactorTCPProtocol from hathor.reactor.reactor_time_protocol import ReactorTimeProtocol @@ -23,9 +25,10 @@ class ReactorProtocol( ReactorCoreProtocol, ReactorTimeProtocol, ReactorTCPProtocol, + ReactorProcessProtocol, + ReactorSocketProtocol, Protocol, ): """ - A Python protocol that represents the intersection of Twisted's IReactorCore+IReactorTime+IReactorTCP interfaces. + A Python protocol that represents an intersection of the Twisted reactor interfaces that we use. """ - pass diff --git a/hathor/reactor/reactor_socket_protocol.py b/hathor/reactor/reactor_socket_protocol.py new file mode 100644 index 000000000..1ad07aa37 --- /dev/null +++ b/hathor/reactor/reactor_socket_protocol.py @@ -0,0 +1,45 @@ +# Copyright 2024 Hathor Labs +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from socket import AddressFamily +from typing import Protocol + +from twisted.internet.interfaces import IListeningPort, IReactorSocket +from twisted.internet.protocol import DatagramProtocol, ServerFactory +from zope.interface import implementer + + +@implementer(IReactorSocket) +class ReactorSocketProtocol(Protocol): + """A Python protocol that stubs Twisted's IReactorSocket interface.""" + + def adoptStreamPort( + self, + fileDescriptor: int, + addressFamily: AddressFamily, + factory: ServerFactory, + ) -> IListeningPort: + ... + + def adoptStreamConnection(self, fileDescriptor: int, addressFamily: AddressFamily, factory: ServerFactory) -> None: + ... + + def adoptDatagramPort( + self, + fileDescriptor: int, + addressFamily: AddressFamily, + protocol: DatagramProtocol, + maxPacketSize: int, + ) -> IListeningPort: + ... diff --git a/hathor/reward_lock/reward_lock.py b/hathor/reward_lock/reward_lock.py index 85b6871e8..856386e7d 100644 --- a/hathor/reward_lock/reward_lock.py +++ b/hathor/reward_lock/reward_lock.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + from typing import TYPE_CHECKING, Iterator, Optional from hathor.conf.settings import HathorSettings @@ -19,11 +21,11 @@ from hathor.util import not_none if TYPE_CHECKING: - from hathor.transaction.storage.vertex_storage_protocol import VertexStorageProtocol + from hathor.transaction.storage import TransactionStorage from hathor.transaction.transaction import RewardLockedInfo, Transaction -def iter_spent_rewards(tx: 'Transaction', storage: 'VertexStorageProtocol') -> Iterator[Block]: +def iter_spent_rewards(tx: 'Transaction', storage: TransactionStorage) -> Iterator[Block]: """Iterate over all the rewards being spent, assumes tx has been verified.""" for input_tx in tx.inputs: spent_tx = storage.get_vertex(input_tx.tx_id) @@ -41,7 +43,7 @@ def is_spent_reward_locked(settings: HathorSettings, tx: 'Transaction') -> bool: def get_spent_reward_locked_info( settings: HathorSettings, tx: 'Transaction', - storage: 'VertexStorageProtocol', + storage: TransactionStorage, ) -> Optional['RewardLockedInfo']: """Check if any input block reward is locked, returning the locked information if any, or None if they are all unlocked.""" @@ -54,7 +56,7 @@ def get_spent_reward_locked_info( return None -def get_minimum_best_height(storage: 'VertexStorageProtocol') -> int: +def get_minimum_best_height(storage: TransactionStorage) -> int: """Return the height of the current best block that shall be used for `min_height` verification.""" import math diff --git a/hathor/simulator/fake_connection.py b/hathor/simulator/fake_connection.py index b3a29afc9..aefd804b4 100644 --- a/hathor/simulator/fake_connection.py +++ b/hathor/simulator/fake_connection.py @@ -25,6 +25,7 @@ from hathor.p2p.peer import PrivatePeer from hathor.p2p.peer_endpoint import PeerAddress, PeerEndpoint from hathor.p2p.peer_id import PeerId +from hathor.transaction.storage import TransactionStorage if TYPE_CHECKING: from hathor.manager import HathorManager @@ -103,6 +104,14 @@ def entrypoint(self) -> PeerEndpoint: return entrypoint.with_id(self.manager1.my_peer.id) return entrypoint.with_id(self._fake_bootstrap_id) + @property + def peer_addr1(self) -> PeerAddress: + return PeerAddress.from_address(self.addr1) + + @property + def peer_addr2(self) -> PeerAddress: + return PeerAddress.from_address(self.addr2) + @property def proto1(self): return self._proto1 @@ -149,14 +158,18 @@ def is_both_synced(self, *, errmsgs: Optional[list[str]] = None) -> bool: self.log.debug('peer not synced', peer1_synced=state1_is_synced, peer2_synced=state2_is_synced) errmsgs.append('peer not synced') return False - [best_block_info1] = state1.protocol.node.tx_storage.get_n_height_tips(1) - [best_block_info2] = state2.protocol.node.tx_storage.get_n_height_tips(1) + [best_block_info1] = state1.dependencies.tx_storage.get_n_height_tips(1) + [best_block_info2] = state2.dependencies.tx_storage.get_n_height_tips(1) if best_block_info1.id != best_block_info2.id: self.log.debug('best block is different') errmsgs.append('best block is different') return False - tips1 = {i.data for i in state1.protocol.node.tx_storage.get_tx_tips()} - tips2 = {i.data for i in state2.protocol.node.tx_storage.get_tx_tips()} + tx_storage1 = state1.dependencies.tx_storage + tx_storage2 = state2.dependencies.tx_storage + assert isinstance(tx_storage1, TransactionStorage) + assert isinstance(tx_storage2, TransactionStorage) + tips1 = {i.data for i in tx_storage1.get_tx_tips()} + tips2 = {i.data for i in tx_storage2.get_tx_tips()} if tips1 != tips2: self.log.debug('tx tips are different') errmsgs.append('tx tips are different') @@ -277,12 +290,9 @@ def reconnect(self) -> None: # When _fake_bootstrap_id is set we don't pass the peer because that's how bootstrap calls connect_to() peer = self._proto1.my_peer.to_unverified_peer() if self._fake_bootstrap_id is False else None - self.manager2.connections.connect_to(self.entrypoint, peer) - - connecting_peers = list(self.manager2.connections.connecting_peers.values()) - for connecting_peer in connecting_peers: - if connecting_peer.entrypoint.addr == self.entrypoint.addr: - connecting_peer.endpoint_deferred.callback(self._proto2) + deferred = self.manager2.connections.connect_to(self.entrypoint, peer) + assert deferred is not None + deferred.callback(self._proto2) self.tr1 = HathorStringTransport(self._proto2.my_peer, peer_address=self.addr2) self.tr2 = HathorStringTransport(self._proto1.my_peer, peer_address=self.addr1) diff --git a/hathor/simulator/clock.py b/hathor/simulator/heap_clock.py similarity index 91% rename from hathor/simulator/clock.py rename to hathor/simulator/heap_clock.py index 3e0aeb4f5..92961a9c3 100644 --- a/hathor/simulator/clock.py +++ b/hathor/simulator/heap_clock.py @@ -17,9 +17,10 @@ from twisted.internet.base import DelayedCall from twisted.internet.interfaces import IDelayedCall, IReactorTime -from twisted.internet.testing import MemoryReactor from zope.interface import implementer +from hathor.reactor.memory_reactor import MemoryReactor + @implementer(IReactorTime) class HeapClock: @@ -94,10 +95,3 @@ class MemoryReactorHeapClock(MemoryReactor, HeapClock): def __init__(self): MemoryReactor.__init__(self) HeapClock.__init__(self) - - def run(self): - """ - We have to override MemoryReactor.run() because the original Twisted implementation weirdly calls stop() inside - run(), and we need the reactor running during our tests. - """ - self.running = True diff --git a/hathor/simulator/simulator.py b/hathor/simulator/simulator.py index a31862909..c74fadc3c 100644 --- a/hathor/simulator/simulator.py +++ b/hathor/simulator/simulator.py @@ -27,7 +27,7 @@ from hathor.feature_activation.feature_service import FeatureService from hathor.manager import HathorManager from hathor.p2p.peer import PrivatePeer -from hathor.simulator.clock import HeapClock, MemoryReactorHeapClock +from hathor.simulator.heap_clock import HeapClock, MemoryReactorHeapClock from hathor.simulator.miner.geometric_miner import GeometricMiner from hathor.simulator.patches import SimulatorCpuMiningService, SimulatorVertexVerifier from hathor.simulator.tx_generator import RandomTransactionGenerator diff --git a/hathor/sysctl/p2p/manager.py b/hathor/sysctl/p2p/manager.py index 9f9856a42..28d009607 100644 --- a/hathor/sysctl/p2p/manager.py +++ b/hathor/sysctl/p2p/manager.py @@ -235,7 +235,7 @@ def set_kill_connection(self, peer_id: str, force: bool = False) -> None: peer_id_obj = PeerId(peer_id) except ValueError: raise SysctlException('invalid peer-id') - conn = self.connections.connected_peers.get(peer_id_obj, None) + conn = self.connections.get_ready_peer_by_id(peer_id_obj) if conn is None: self.log.warn('Killing connection', peer_id=peer_id) raise SysctlException('peer-id is not connected') diff --git a/hathor/transaction/base_transaction.py b/hathor/transaction/base_transaction.py index 41cbab100..7b3faa091 100644 --- a/hathor/transaction/base_transaction.py +++ b/hathor/transaction/base_transaction.py @@ -672,7 +672,9 @@ def update_accumulated_weight(self, *, stop_value: float = inf, save_file: bool # directly verified by a block. from hathor.transaction.storage.traversal import BFSTimestampWalk - bfs_walk = BFSTimestampWalk(self.storage, is_dag_funds=True, is_dag_verifications=True, is_left_to_right=True) + bfs_walk = BFSTimestampWalk( + self.storage.get_vertex, is_dag_funds=True, is_dag_verifications=True, is_left_to_right=True + ) for tx in bfs_walk.run(self, skip_root=True): accumulated_weight += weight_to_work(tx.weight) if accumulated_weight > stop_value: diff --git a/hathor/transaction/block.py b/hathor/transaction/block.py index 9f5f5a06d..fc1a90510 100644 --- a/hathor/transaction/block.py +++ b/hathor/transaction/block.py @@ -360,7 +360,9 @@ def iter_transactions_in_this_block(self) -> Iterator[BaseTransaction]: """Return an iterator of the transactions that have this block as meta.first_block.""" from hathor.transaction.storage.traversal import BFSOrderWalk assert self.storage is not None - bfs = BFSOrderWalk(self.storage, is_dag_verifications=True, is_dag_funds=True, is_left_to_right=False) + bfs = BFSOrderWalk( + self.storage.get_vertex, is_dag_verifications=True, is_dag_funds=True, is_left_to_right=False + ) for tx in bfs.run(self, skip_root=True): tx_meta = tx.get_metadata() if tx_meta.first_block != self.hash: diff --git a/hathor/transaction/storage/__init__.py b/hathor/transaction/storage/__init__.py index 4fbdd6ae7..e46ff6035 100644 --- a/hathor/transaction/storage/__init__.py +++ b/hathor/transaction/storage/__init__.py @@ -15,7 +15,6 @@ from hathor.transaction.storage.cache_storage import TransactionCacheStorage from hathor.transaction.storage.memory_storage import TransactionMemoryStorage from hathor.transaction.storage.transaction_storage import TransactionStorage -from hathor.transaction.storage.vertex_storage_protocol import VertexStorageProtocol try: from hathor.transaction.storage.rocksdb_storage import TransactionRocksDBStorage @@ -27,5 +26,4 @@ 'TransactionMemoryStorage', 'TransactionCacheStorage', 'TransactionRocksDBStorage', - 'VertexStorageProtocol' ] diff --git a/hathor/transaction/storage/transaction_storage.py b/hathor/transaction/storage/transaction_storage.py index a6ee50aa9..999486d6a 100644 --- a/hathor/transaction/storage/transaction_storage.py +++ b/hathor/transaction/storage/transaction_storage.py @@ -533,12 +533,14 @@ def get_transaction(self, hash_bytes: bytes) -> BaseTransaction: self.post_get_validation(tx) return tx - def get_block_by_height(self, height: int) -> Optional[Block]: - """Return a block in the best blockchain from the height index. This is fast.""" + def get_block_id_by_height(self, height: int) -> VertexId | None: assert self.indexes is not None - ancestor_hash = self.indexes.height.get(height) + return self.indexes.height.get(height) - return None if ancestor_hash is None else self.get_block(ancestor_hash) + def get_block_by_height(self, height: int) -> Optional[Block]: + """Return a block in the best blockchain from the height index. This is fast.""" + block_id = self.get_block_id_by_height(height) + return None if block_id is None else self.get_block(block_id) def get_metadata(self, hash_bytes: bytes) -> Optional[TransactionMetadata]: """Returns the transaction metadata with hash `hash_bytes`. @@ -1012,7 +1014,7 @@ def iter_mempool_from_tx_tips(self) -> Iterator[Transaction]: from hathor.transaction.storage.traversal import BFSTimestampWalk root = self.iter_mempool_tips_from_tx_tips() - walk = BFSTimestampWalk(self, is_dag_funds=True, is_dag_verifications=True, is_left_to_right=False) + walk = BFSTimestampWalk(self.get_vertex, is_dag_funds=True, is_dag_verifications=True, is_left_to_right=False) for tx in walk.run(root): tx_meta = tx.get_metadata() # XXX: skip blocks and tx-tips that have already been confirmed @@ -1137,6 +1139,11 @@ def partial_vertex_exists(self, vertex_id: VertexId) -> bool: with self.allow_partially_validated_context(): return self.transaction_exists(vertex_id) + def get_mempool_tips(self) -> set[VertexId]: + assert self.indexes is not None + assert self.indexes.mempool_tips is not None + return self.indexes.mempool_tips.get() + class BaseTransactionStorage(TransactionStorage): indexes: Optional[IndexesManager] diff --git a/hathor/transaction/storage/traversal.py b/hathor/transaction/storage/traversal.py index 7900cb8d6..9f2e780ac 100644 --- a/hathor/transaction/storage/traversal.py +++ b/hathor/transaction/storage/traversal.py @@ -18,11 +18,10 @@ from abc import ABC, abstractmethod from collections import deque from itertools import chain -from typing import TYPE_CHECKING, Iterable, Iterator, Optional, Union +from typing import TYPE_CHECKING, Callable, Iterable, Iterator, Optional, Union if TYPE_CHECKING: - from hathor.transaction import BaseTransaction # noqa: F401 - from hathor.transaction.storage import VertexStorageProtocol + from hathor.transaction import BaseTransaction, Vertex # noqa: F401 from hathor.types import VertexId @@ -50,7 +49,7 @@ class GenericWalk(ABC): def __init__( self, - storage: VertexStorageProtocol, + vertex_getter: Callable[[VertexId], Vertex], *, is_dag_funds: bool = False, is_dag_verifications: bool = False, @@ -64,7 +63,7 @@ def __init__( :param is_dag_verifications: Add neighbors from the DAG of verifications :param is_left_to_right: Decide which side of the DAG we will walk to """ - self.storage = storage + self.vertex_getter = vertex_getter self.seen = set() self.is_dag_funds = is_dag_funds @@ -119,7 +118,7 @@ def add_neighbors(self, tx: 'BaseTransaction') -> None: for _hash in it: if _hash not in self.seen: self.seen.add(_hash) - neighbor = self.storage.get_vertex(_hash) + neighbor = self.vertex_getter(_hash) self._push_visit(neighbor) def skip_neighbors(self, tx: 'BaseTransaction') -> None: @@ -164,14 +163,14 @@ class BFSTimestampWalk(GenericWalk): def __init__( self, - storage: VertexStorageProtocol, + vertex_getter: Callable[[VertexId], Vertex], *, is_dag_funds: bool = False, is_dag_verifications: bool = False, is_left_to_right: bool = True, ) -> None: super().__init__( - storage, + vertex_getter=vertex_getter, is_dag_funds=is_dag_funds, is_dag_verifications=is_dag_verifications, is_left_to_right=is_left_to_right @@ -200,14 +199,14 @@ class BFSOrderWalk(GenericWalk): def __init__( self, - storage: VertexStorageProtocol, + vertex_getter: Callable[[VertexId], Vertex], *, is_dag_funds: bool = False, is_dag_verifications: bool = False, is_left_to_right: bool = True, ) -> None: super().__init__( - storage, + vertex_getter=vertex_getter, is_dag_funds=is_dag_funds, is_dag_verifications=is_dag_verifications, is_left_to_right=is_left_to_right @@ -231,14 +230,14 @@ class DFSWalk(GenericWalk): def __init__( self, - storage: VertexStorageProtocol, + vertex_getter: Callable[[VertexId], Vertex], *, is_dag_funds: bool = False, is_dag_verifications: bool = False, is_left_to_right: bool = True, ) -> None: super().__init__( - storage, + vertex_getter=vertex_getter, is_dag_funds=is_dag_funds, is_dag_verifications=is_dag_verifications, is_left_to_right=is_left_to_right diff --git a/hathor/transaction/storage/vertex_storage_protocol.py b/hathor/transaction/storage/vertex_storage_protocol.py deleted file mode 100644 index a35b3cd78..000000000 --- a/hathor/transaction/storage/vertex_storage_protocol.py +++ /dev/null @@ -1,48 +0,0 @@ -# Copyright 2024 Hathor Labs -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from abc import abstractmethod -from typing import Protocol - -from hathor.transaction import BaseTransaction, Block -from hathor.types import VertexId - - -class VertexStorageProtocol(Protocol): - """ - This Protocol currently represents a subset of TransactionStorage methods. Its main use case is for verification - methods that can receive a RocksDB storage or an ephemeral simple memory storage. - - Therefore, objects returned by this protocol may or may not have an `object.storage` pointer set. - """ - - @abstractmethod - def get_vertex(self, vertex_id: VertexId) -> BaseTransaction: - """Return a vertex from the storage.""" - raise NotImplementedError - - @abstractmethod - def get_block(self, block_id: VertexId) -> Block: - """Return a block from the storage.""" - raise NotImplementedError - - @abstractmethod - def get_parent_block(self, block: Block) -> Block: - """Get the parent block of a block.""" - raise NotImplementedError - - @abstractmethod - def get_best_block_tips(self) -> list[VertexId]: - """Return a list of blocks that are heads in a best chain.""" - raise NotImplementedError diff --git a/hathor/verification/transaction_verifier.py b/hathor/verification/transaction_verifier.py index 906df38c2..cb7893b5f 100644 --- a/hathor/verification/transaction_verifier.py +++ b/hathor/verification/transaction_verifier.py @@ -51,8 +51,6 @@ def __init__(self, *, settings: HathorSettings, daa: DifficultyAdjustmentAlgorit def verify_parents_basic(self, tx: Transaction) -> None: """Verify number and non-duplicity of parents.""" - assert tx.storage is not None - # check if parents are duplicated parents_set = set(tx.parents) if len(tx.parents) > len(parents_set): diff --git a/hathor/websocket/factory.py b/hathor/websocket/factory.py index b96dcbdff..a8f432dd5 100644 --- a/hathor/websocket/factory.py +++ b/hathor/websocket/factory.py @@ -158,7 +158,7 @@ def _send_metrics(self): 'blocks': self.metrics.blocks, 'best_block_height': self.metrics.best_block_height, 'hash_rate': self.metrics.hash_rate, - 'peers': self.metrics.connected_peers, + 'peers': self.metrics.ready_peers, 'type': 'dashboard:metrics', 'time': self.reactor.seconds(), }) diff --git a/tests/cli/test_events_simulator.py b/tests/cli/test_events_simulator.py index 2a4ee941f..826d73737 100644 --- a/tests/cli/test_events_simulator.py +++ b/tests/cli/test_events_simulator.py @@ -17,13 +17,13 @@ from hathor.cli.events_simulator.event_forwarding_websocket_factory import EventForwardingWebsocketFactory from hathor.cli.events_simulator.events_simulator import create_parser, execute from hathor.conf.get_settings import get_global_settings -from tests.test_memory_reactor_clock import TestMemoryReactorClock +from hathor.reactor.memory_reactor import MemoryReactorClock def test_events_simulator() -> None: parser = create_parser() args = parser.parse_args(['--scenario', 'ONLY_LOAD']) - reactor = TestMemoryReactorClock() + reactor = MemoryReactorClock() execute(args, reactor) reactor.advance(1) diff --git a/tests/event/websocket/test_factory.py b/tests/event/websocket/test_factory.py index 24feeab98..294206ff5 100644 --- a/tests/event/websocket/test_factory.py +++ b/tests/event/websocket/test_factory.py @@ -21,7 +21,7 @@ from hathor.event.websocket.factory import EventWebsocketFactory from hathor.event.websocket.protocol import EventWebsocketProtocol from hathor.event.websocket.response import EventResponse, InvalidRequestType -from hathor.simulator.clock import MemoryReactorHeapClock +from hathor.simulator.heap_clock import MemoryReactorHeapClock from tests.utils import EventMocker diff --git a/tests/others/test_metrics.py b/tests/others/test_metrics.py index b46f6985b..8a8dc546e 100644 --- a/tests/others/test_metrics.py +++ b/tests/others/test_metrics.py @@ -5,7 +5,7 @@ from hathor.p2p.manager import PeerConnectionsMetrics from hathor.p2p.peer import PrivatePeer -from hathor.p2p.peer_endpoint import PeerEndpoint +from hathor.p2p.peer_endpoint import PeerAddress, PeerEndpoint from hathor.p2p.protocol import HathorProtocol from hathor.pubsub import HathorEvents from hathor.simulator.utils import add_new_blocks @@ -42,7 +42,7 @@ def test_p2p_network_events(self): # Assertion self.assertEquals(manager.metrics.connecting_peers, 3) self.assertEquals(manager.metrics.handshaking_peers, 4) - self.assertEquals(manager.metrics.connected_peers, 5) + self.assertEquals(manager.metrics.ready_peers, 5) self.assertEquals(manager.metrics.known_peers, 6) manager.metrics.stop() @@ -60,25 +60,35 @@ def test_connections_manager_integration(self): wallet = Wallet(directory=tmpdir) wallet.unlock(b'teste') manager = self.create_peer('testnet', tx_storage=tx_storage, wallet=wallet) + p2p_manager = manager.connections - manager.connections.verified_peer_storage.update({ + p2p_manager.verified_peer_storage.update({ "1": PrivatePeer.auto_generated(), "2": PrivatePeer.auto_generated(), "3": PrivatePeer.auto_generated(), }) - manager.connections.connected_peers.update({"1": Mock(), "2": Mock()}) - manager.connections.handshaking_peers.update({Mock()}) + peer1 = Mock() + peer1.addr = PeerAddress.parse('tcp://localhost:40403') + peer2 = Mock() + peer2.addr = PeerAddress.parse('tcp://localhost:40404') + peer3 = Mock() + peer3.addr = PeerAddress.parse('tcp://localhost:40405') + p2p_manager._connections.on_connected(protocol=peer1) + p2p_manager._connections.on_connected(protocol=peer2) + p2p_manager._connections.on_connected(protocol=peer3) + p2p_manager._connections.on_ready(addr=peer1.addr, peer_id=Mock()) + p2p_manager._connections.on_ready(addr=peer2.addr, peer_id=Mock()) # Execution endpoint = PeerEndpoint.parse('tcp://127.0.0.1:8005') # This will trigger sending to the pubsub one of the network events - manager.connections.connect_to(endpoint, use_ssl=True) + manager.connections.connect_to(endpoint) self.run_to_completion() # Assertion self.assertEquals(manager.metrics.known_peers, 3) - self.assertEquals(manager.metrics.connected_peers, 2) + self.assertEquals(manager.metrics.ready_peers, 2) self.assertEquals(manager.metrics.handshaking_peers, 1) self.assertEquals(manager.metrics.connecting_peers, 1) @@ -215,18 +225,20 @@ def test_peer_connections_data_collection(self): self.use_memory_storage = True manager = self.create_peer('testnet') self.assertIsInstance(manager.tx_storage, TransactionMemoryStorage) - - my_peer = manager.my_peer + port = 40403 def build_hathor_protocol(): + nonlocal port protocol = HathorProtocol( - my_peer=my_peer, + my_peer=manager.my_peer, p2p_manager=manager.connections, use_ssl=False, inbound=False, - settings=self._settings + addr=PeerAddress.parse(f'tcp://localhost:{port}'), + dependencies=Mock(), ) protocol._peer = PrivatePeer.auto_generated().to_public_peer() + port += 1 return protocol @@ -246,9 +258,12 @@ def build_hathor_protocol(): fake_peers[2].metrics.discarded_blocks = 3 fake_peers[2].metrics.discarded_txs = 3 - manager.connections.connections.add(fake_peers[0]) - manager.connections.connections.add(fake_peers[1]) - manager.connections.connections.add(fake_peers[2]) + manager.connections._connections._addr_by_id[fake_peers[0].peer.id] = fake_peers[0].addr + manager.connections._connections._addr_by_id[fake_peers[1].peer.id] = fake_peers[1].addr + manager.connections._connections._addr_by_id[fake_peers[2].peer.id] = fake_peers[2].addr + manager.connections._connections._ready[fake_peers[0].addr] = fake_peers[0] + manager.connections._connections._ready[fake_peers[1].addr] = fake_peers[1] + manager.connections._connections._ready[fake_peers[2].addr] = fake_peers[2] # Execution manager.metrics._collect_data() diff --git a/tests/p2p/test_bootstrap.py b/tests/p2p/test_bootstrap.py index 82aa932bb..16bec0d0a 100644 --- a/tests/p2p/test_bootstrap.py +++ b/tests/p2p/test_bootstrap.py @@ -1,6 +1,8 @@ from typing import Callable +from unittest.mock import Mock from twisted.internet.defer import Deferred +from twisted.internet.interfaces import IProtocol from twisted.names.dns import TXT, A, Record_A, Record_TXT, RRHeader from typing_extensions import override @@ -9,28 +11,38 @@ from hathor.p2p.peer_discovery import DNSPeerDiscovery, PeerDiscovery from hathor.p2p.peer_discovery.dns import LookupResult from hathor.p2p.peer_endpoint import PeerAddress, PeerEndpoint, Protocol +from hathor.p2p.peer_id import PeerId from hathor.pubsub import PubSubManager +from hathor.reactor.memory_reactor import MemoryReactorClock from tests import unittest -from tests.test_memory_reactor_clock import TestMemoryReactorClock class MockPeerDiscovery(PeerDiscovery): - def __init__(self, mocked_host_ports: list[tuple[str, int]]): - self.mocked_host_ports = mocked_host_ports + def __init__(self, mocked_addrs: list[tuple[str, int, str | None]]): + self.mocked_addrs = mocked_addrs @override - async def discover_and_connect(self, connect_to: Callable[[PeerEndpoint], None]) -> None: - for host, port in self.mocked_host_ports: - addr = PeerAddress(Protocol.TCP, host, port) - connect_to(addr.with_id()) + async def discover_and_connect(self, connect_to: Callable[[PeerEndpoint], Deferred[IProtocol] | None]) -> None: + for host, port, peer_id_str in self.mocked_addrs: + peer_id = PeerId(peer_id_str) if peer_id_str is not None else None + connect_to(PeerAddress(Protocol.TCP, host, port).with_id(peer_id)) class MockDNSPeerDiscovery(DNSPeerDiscovery): - def __init__(self, reactor: TestMemoryReactorClock, bootstrap_txt: list[tuple[str, int]], bootstrap_a: list[str]): + def __init__( + self, + reactor: MemoryReactorClock, + bootstrap_txt: list[tuple[str, int, str | None]], + bootstrap_a: list[str], + ): super().__init__(['test.example']) self.reactor = reactor self.mocked_lookup_a = [RRHeader(type=A, payload=Record_A(address)) for address in bootstrap_a] - txt_entries = [f'tcp://{h}:{p}'.encode() for h, p in bootstrap_txt] + txt_entries = [] + for host, port, peer_id_str in bootstrap_txt: + peer_id = PeerId(peer_id_str) if peer_id_str is not None else None + addr_and_id = PeerAddress(Protocol.TCP, host, port).with_id(peer_id) + txt_entries.append(str(addr_and_id).encode()) self.mocked_lookup_txt = [RRHeader(type=TXT, payload=Record_TXT(*txt_entries))] def do_lookup_address(self, host: str) -> Deferred[LookupResult]: @@ -50,46 +62,64 @@ class BootstrapTestCase(unittest.TestCase): def test_mock_discovery(self) -> None: pubsub = PubSubManager(self.clock) peer = PrivatePeer.auto_generated() - connections = ConnectionsManager(self._settings, self.clock, peer, pubsub, True, self.rng, True) + connections = ConnectionsManager( + dependencies=Mock(), + my_peer=peer, + pubsub=pubsub, + rng=self.rng, + ssl=True, + ) host_ports1 = [ - ('foobar', 1234), - ('127.0.0.99', 9999), + ('foobar', 1234, None), + ('127.0.0.99', 9999, None), + ('192.168.0.1', 1111, 'c0f19299c2a4dcbb6613a14011ff07b63d6cb809e4cee25e9c1ccccdd6628696') ] host_ports2 = [ - ('baz', 456), - ('127.0.0.88', 8888), + ('baz', 456, None), + ('127.0.0.88', 8888, None), + ('192.168.0.2', 2222, 'bc5119d47bb4ea7c19100bd97fb11f36970482108bd3d45ff101ee4f6bbec872') ] connections.add_peer_discovery(MockPeerDiscovery(host_ports1)) connections.add_peer_discovery(MockPeerDiscovery(host_ports2)) connections.do_discovery() self.clock.advance(1) - connecting_entrypoints = {str(entrypoint) for entrypoint, _ in connections.connecting_peers.values()} - self.assertEqual(connecting_entrypoints, { + connecting_addrs = {str(addr) for addr in connections._connections.connecting_outbound_peers()} + self.assertEqual(connecting_addrs, { 'tcp://foobar:1234', 'tcp://127.0.0.99:9999', 'tcp://baz:456', 'tcp://127.0.0.88:8888', + 'tcp://192.168.0.1:1111', + 'tcp://192.168.0.2:2222', }) def test_dns_discovery(self) -> None: pubsub = PubSubManager(self.clock) peer = PrivatePeer.auto_generated() - connections = ConnectionsManager(self._settings, self.clock, peer, pubsub, True, self.rng, True) + connections = ConnectionsManager( + dependencies=Mock(), + my_peer=peer, + pubsub=pubsub, + rng=self.rng, + ssl=True, + ) bootstrap_a = [ '127.0.0.99', '127.0.0.88', ] bootstrap_txt = [ - ('foobar', 1234), - ('baz', 456), + ('foobar', 1234, None), + ('baz', 456, None), + ('192.168.0.1', 1111, 'c0f19299c2a4dcbb6613a14011ff07b63d6cb809e4cee25e9c1ccccdd6628696') ] connections.add_peer_discovery(MockDNSPeerDiscovery(self.clock, bootstrap_txt, bootstrap_a)) connections.do_discovery() self.clock.advance(1) - connecting_entrypoints = {str(entrypoint) for entrypoint, _ in connections.connecting_peers.values()} - self.assertEqual(connecting_entrypoints, { + connecting_addrs = {str(addr) for addr in connections._connections.connecting_outbound_peers()} + self.assertEqual(connecting_addrs, { 'tcp://127.0.0.99:40403', 'tcp://127.0.0.88:40403', 'tcp://foobar:1234', 'tcp://baz:456', + 'tcp://192.168.0.1:1111' }) diff --git a/tests/p2p/test_connections.py b/tests/p2p/test_connections.py index b27897ca4..64482a2bc 100644 --- a/tests/p2p/test_connections.py +++ b/tests/p2p/test_connections.py @@ -18,8 +18,8 @@ def test_manager_connections(self) -> None: manager: HathorManager = self.create_peer('testnet', enable_sync_v1=True, enable_sync_v2=False) endpoint = PeerEndpoint.parse('tcp://127.0.0.1:8005') - manager.connections.connect_to(endpoint, use_ssl=True) + manager.connections.connect_to(endpoint) - self.assertIn(endpoint, manager.connections.iter_not_ready_endpoints()) - self.assertNotIn(endpoint, manager.connections.iter_ready_connections()) - self.assertNotIn(endpoint, manager.connections.iter_all_connections()) + self.assertIn(endpoint.addr, manager.connections.iter_not_ready_endpoints()) + self.assertNotIn(endpoint.addr, [conn.addr for conn in manager.connections.iter_ready_connections()]) + self.assertNotIn(endpoint.addr, [conn.addr for conn in manager.connections.get_connected_peers()]) diff --git a/tests/p2p/test_get_best_blockchain.py b/tests/p2p/test_get_best_blockchain.py index ff0d95149..67ed0bc74 100644 --- a/tests/p2p/test_get_best_blockchain.py +++ b/tests/p2p/test_get_best_blockchain.py @@ -32,8 +32,8 @@ def test_get_best_blockchain(self) -> None: self.simulator.add_connection(conn12) self.simulator.run(3600) - connected_peers1 = list(manager1.connections.connected_peers.values()) - connected_peers2 = list(manager2.connections.connected_peers.values()) + connected_peers1 = list(manager1.connections.iter_ready_connections()) + connected_peers2 = list(manager2.connections.iter_ready_connections()) self.assertEqual(1, len(connected_peers1)) self.assertEqual(1, len(connected_peers2)) @@ -94,13 +94,13 @@ def test_handle_get_best_blockchain(self) -> None: self.assertTrue(self.simulator.run(7200, trigger=trigger)) miner.stop() - connected_peers1 = list(manager1.connections.connected_peers.values()) + connected_peers1 = list(manager1.connections.iter_ready_connections()) self.assertEqual(1, len(connected_peers1)) protocol2 = connected_peers1[0] state2 = protocol2.state assert isinstance(state2, ReadyState) - connected_peers2 = list(manager2.connections.connected_peers.values()) + connected_peers2 = list(manager2.connections.iter_ready_connections()) self.assertEqual(1, len(connected_peers2)) protocol1 = connected_peers2[0] state1 = protocol1.state @@ -134,7 +134,7 @@ def test_handle_get_best_blockchain(self) -> None: self.assertFalse(conn12.tr1.disconnecting) self.assertFalse(conn12.tr2.disconnecting) - connected_peers2 = list(manager2.connections.connected_peers.values()) + connected_peers2 = list(manager2.connections.iter_ready_connections()) self.assertEqual(1, len(connected_peers2)) protocol1 = connected_peers2[0] state1 = protocol1.state @@ -153,13 +153,13 @@ def test_handle_best_blockchain(self) -> None: self.simulator.add_connection(conn12) self.simulator.run(60) - connected_peers1 = list(manager1.connections.connected_peers.values()) + connected_peers1 = list(manager1.connections.iter_ready_connections()) self.assertEqual(1, len(connected_peers1)) protocol2 = connected_peers1[0] state2 = protocol2.state assert isinstance(state2, ReadyState) - connected_peers2 = list(manager2.connections.connected_peers.values()) + connected_peers2 = list(manager2.connections.iter_ready_connections()) self.assertEqual(1, len(connected_peers2)) protocol1 = connected_peers2[0] state1 = protocol1.state @@ -204,27 +204,27 @@ def test_node_without_get_best_blockchain_capability(self) -> None: manager1 = self.create_peer() manager2 = self.create_peer() - cababilities_without_get_best_blockchain = [ + capabilities_without_get_best_blockchain = [ self._settings.CAPABILITY_WHITELIST, self._settings.CAPABILITY_SYNC_VERSION, ] - manager2.capabilities = cababilities_without_get_best_blockchain + manager2.connections.dependencies.capabilities = capabilities_without_get_best_blockchain conn12 = FakeConnection(manager1, manager2, latency=0.05) self.simulator.add_connection(conn12) self.simulator.run(60) # assert the nodes are connected - connected_peers1 = list(manager1.connections.connected_peers.values()) + connected_peers1 = list(manager1.connections.iter_ready_connections()) self.assertEqual(1, len(connected_peers1)) - connected_peers2 = list(manager2.connections.connected_peers.values()) + connected_peers2 = list(manager2.connections.iter_ready_connections()) self.assertEqual(1, len(connected_peers2)) # assert the peers have the proper capabilities protocol2 = connected_peers1[0] - self.assertTrue(protocol2.capabilities.issuperset(set(cababilities_without_get_best_blockchain))) + self.assertTrue(protocol2.capabilities.issuperset(set(capabilities_without_get_best_blockchain))) protocol1 = connected_peers2[0] - default_capabilities = manager2.get_default_capabilities() + default_capabilities = self._settings.get_default_capabilities() self.assertTrue(protocol1.capabilities.issuperset(set(default_capabilities))) # assert the peers don't engage in get_best_blockchain messages @@ -313,13 +313,13 @@ def test_stop_looping_on_exit(self) -> None: self.simulator.add_connection(conn12) self.simulator.run(60) - connected_peers1 = list(manager1.connections.connected_peers.values()) + connected_peers1 = list(manager1.connections.iter_ready_connections()) self.assertEqual(1, len(connected_peers1)) protocol2 = connected_peers1[0] state2 = protocol2.state assert isinstance(state2, ReadyState) - connected_peers2 = list(manager2.connections.connected_peers.values()) + connected_peers2 = list(manager2.connections.iter_ready_connections()) self.assertEqual(1, len(connected_peers2)) protocol1 = connected_peers2[0] state1 = protocol1.state diff --git a/tests/p2p/test_peer_id.py b/tests/p2p/test_peer_id.py index 56dfaf79b..0f22ccee5 100644 --- a/tests/p2p/test_peer_id.py +++ b/tests/p2p/test_peer_id.py @@ -2,13 +2,13 @@ import shutil import tempfile from typing import cast -from unittest.mock import Mock import pytest +from twisted.internet.address import IPv4Address from twisted.internet.interfaces import ITransport from hathor.p2p.peer import InvalidPeerIdException, PrivatePeer, PublicPeer, UnverifiedPeer -from hathor.p2p.peer_endpoint import PeerAddress, PeerEndpoint +from hathor.p2p.peer_endpoint import PeerAddress from hathor.p2p.peer_id import PeerId from hathor.p2p.peer_storage import VerifiedPeerStorage from tests import unittest @@ -149,7 +149,7 @@ def test_retry_connection(self) -> None: def test_validate_certificate(self) -> None: builder = TestBuilder() artifacts = builder.build() - protocol = artifacts.p2p_manager.server_factory.buildProtocol(Mock()) + protocol = artifacts.p2p_manager.server_factory.buildProtocol(IPv4Address('TCP', 'localhost', 40403)) peer = PrivatePeer.auto_generated() @@ -268,8 +268,8 @@ async def test_validate_entrypoint(self) -> None: peer.info.entrypoints = [PeerAddress.parse('tcp://127.0.0.1:40403')] # we consider that we are starting the connection to the peer - protocol = manager.connections.client_factory.buildProtocol('127.0.0.1') - protocol.entrypoint = PeerEndpoint.parse('tcp://127.0.0.1:40403') + protocol = manager.connections.client_factory.buildProtocol(IPv4Address('TCP', 'localhost', 40403)) + protocol.addr = PeerAddress.parse('tcp://127.0.0.1:40403') result = await peer.info.validate_entrypoint(protocol) self.assertTrue(result) # if entrypoint is an URI @@ -277,24 +277,12 @@ async def test_validate_entrypoint(self) -> None: result = await peer.info.validate_entrypoint(protocol) self.assertTrue(result) # test invalid. DNS in test mode will resolve to '127.0.0.1:40403' - protocol.entrypoint = PeerEndpoint.parse('tcp://45.45.45.45:40403') + protocol.addr = PeerAddress.parse('tcp://45.45.45.45:40403') result = await peer.info.validate_entrypoint(protocol) self.assertFalse(result) - # now test when receiving the connection - i.e. the peer starts it - protocol.entrypoint = None - peer.info.entrypoints = [PeerAddress.parse('tcp://127.0.0.1:40403')] - - from collections import namedtuple - DummyPeer = namedtuple('DummyPeer', 'host') - - class FakeTransport: - def getPeer(self) -> DummyPeer: - return DummyPeer(host='127.0.0.1') - protocol.transport = FakeTransport() - result = await peer.info.validate_entrypoint(protocol) - self.assertTrue(result) # if entrypoint is an URI + protocol.addr = PeerAddress.parse('tcp://127.0.0.1:40403') peer.info.entrypoints = [PeerAddress.parse('tcp://uri_name:40403')] result = await peer.info.validate_entrypoint(protocol) self.assertTrue(result) diff --git a/tests/p2p/test_protocol.py b/tests/p2p/test_protocol.py index 841a45929..2ae01fc35 100644 --- a/tests/p2p/test_protocol.py +++ b/tests/p2p/test_protocol.py @@ -2,6 +2,7 @@ from typing import Optional from unittest.mock import Mock, patch +import pytest from twisted.internet import defer from twisted.internet.protocol import Protocol from twisted.python.failure import Failure @@ -10,7 +11,7 @@ from hathor.p2p.manager import ConnectionsManager from hathor.p2p.messages import ProtocolMessages from hathor.p2p.peer import PrivatePeer -from hathor.p2p.peer_endpoint import PeerAddress +from hathor.p2p.peer_endpoint import PeerAddress, PeerEndpoint from hathor.p2p.protocol import HathorLineReceiver, HathorProtocol from hathor.simulator import FakeConnection from hathor.util import json_dumps, json_loadb @@ -117,8 +118,11 @@ def test_rate_limit(self) -> None: # Test empty disconnect self.conn.proto1.state = None - self.conn.proto1.connections = None - self.conn.proto1.on_disconnect(Failure(Exception())) + with pytest.raises(AssertionError): + # TODO: This raises because we are trying to disconnect a protocol with no state, but it's not possible + # for a protocol to have no state after it's handshaking. We have to update this when we introduce the + # new non-None initial state for protocols. + self.conn.proto1.on_disconnect(Failure(Exception())) def test_invalid_size(self) -> None: self.conn.tr1.clear() @@ -201,6 +205,44 @@ def test_valid_hello(self) -> None: self.assertFalse(self.conn.tr1.disconnecting) self.assertFalse(self.conn.tr2.disconnecting) + def test_invalid_duplicate_addr(self) -> None: + """ + We try to connect to an already connected entrypoint in each state, + and it should never add the new connection to connecting_outbound_peers. + """ + # We also specifically compare localhost with 127.0.0.1, because they are considered the same. + assert self.conn.addr2.type == 'TCP' and self.conn.addr2.host == '127.0.0.1' + entrypoint = PeerEndpoint.parse(f'tcp://localhost:{self.conn.addr2.port}') + + self.manager1.connections.connect_to(entrypoint) + assert self.manager1.connections._connections.connecting_outbound_peers() == set() + assert self.manager1.connections._connections.handshaking_peers() == {self.conn.peer_addr2: self.conn.proto1} + assert self.manager1.connections._connections.ready_peers() == {} + self._check_result_only_cmd(self.conn.peek_tr1_value(), b'HELLO') + self._check_result_only_cmd(self.conn.peek_tr2_value(), b'HELLO') + + self.conn.run_one_step() # HELLO + self.manager1.connections.connect_to(entrypoint) + assert self.manager1.connections._connections.connecting_outbound_peers() == set() + assert self.manager1.connections._connections.handshaking_peers() == {self.conn.peer_addr2: self.conn.proto1} + assert self.manager1.connections._connections.ready_peers() == {} + self._check_result_only_cmd(self.conn.peek_tr1_value(), b'PEER-ID') + self._check_result_only_cmd(self.conn.peek_tr2_value(), b'PEER-ID') + + self.conn.run_one_step() # PEER-ID + self.manager1.connections.connect_to(entrypoint) + assert self.manager1.connections._connections.connecting_outbound_peers() == set() + assert self.manager1.connections._connections.handshaking_peers() == {self.conn.peer_addr2: self.conn.proto1} + assert self.manager1.connections._connections.ready_peers() == {} + self._check_result_only_cmd(self.conn.peek_tr1_value(), b'READY') + self._check_result_only_cmd(self.conn.peek_tr2_value(), b'READY') + + self.conn.run_one_step() # READY + self.manager1.connections.connect_to(entrypoint) + assert self.manager1.connections._connections.connecting_outbound_peers() == set() + assert self.manager1.connections._connections.handshaking_peers() == {} + assert self.manager1.connections._connections.ready_peers() == {self.conn.peer_addr2: self.conn.proto1} + def test_invalid_same_peer_id(self) -> None: manager3 = self.create_peer(self.network, peer=self.peer1) conn = FakeConnection(self.manager1, manager3) @@ -251,21 +293,24 @@ def test_invalid_same_peer_id2(self) -> None: # one of the peers will close the connection. We don't know which one, as it depends # on the peer ids - if self.conn.tr1.disconnecting or self.conn.tr2.disconnecting: - conn_dead = self.conn + if bytes(self.peer1.id) > bytes(self.peer2.id): + tr_dead = self.conn.tr1 + tr_dead_value = self.conn.peek_tr1_value() + proto_alive = conn.proto2 conn_alive = conn - elif conn.tr1.disconnecting or conn.tr2.disconnecting: - conn_dead = conn - conn_alive = self.conn else: - raise Exception('It should never happen.') - self._check_result_only_cmd(conn_dead.peek_tr1_value() + conn_dead.peek_tr2_value(), b'ERROR') + tr_dead = conn.tr2 + tr_dead_value = conn.peek_tr2_value() + proto_alive = self.conn.proto1 + conn_alive = self.conn + + self._check_result_only_cmd(tr_dead_value, b'ERROR') # at this point, the connection must be closing as the error was detected on READY state - self.assertIn(True, [conn_dead.tr1.disconnecting, conn_dead.tr2.disconnecting]) - # check connected_peers - connected_peers = list(self.manager1.connections.connected_peers.values()) - self.assertEquals(1, len(connected_peers)) - self.assertIn(connected_peers[0], [conn_alive.proto1, conn_alive.proto2]) + self.assertTrue(tr_dead.disconnecting) + # check ready_peers + ready_peers = list(self.manager1.connections.iter_ready_connections()) + self.assertEquals(1, len(ready_peers)) + self.assertEquals(ready_peers[0], proto_alive) # connection is still up self.assertIsConnected(conn_alive) @@ -345,32 +390,32 @@ def test_send_invalid_unicode(self) -> None: self.assertTrue(self.conn.tr1.disconnecting) def test_on_disconnect(self) -> None: - self.assertIn(self.conn.proto1, self.manager1.connections.handshaking_peers) + self.assertIn(self.conn.proto1, self.manager1.connections.iter_handshaking_peers()) self.conn.disconnect(Failure(Exception('testing'))) - self.assertNotIn(self.conn.proto1, self.manager1.connections.handshaking_peers) + self.assertNotIn(self.conn.proto1, self.manager1.connections.iter_handshaking_peers()) def test_on_disconnect_after_hello(self) -> None: self.conn.run_one_step() # HELLO - self.assertIn(self.conn.proto1, self.manager1.connections.handshaking_peers) + self.assertIn(self.conn.proto1, self.manager1.connections.iter_handshaking_peers()) self.conn.disconnect(Failure(Exception('testing'))) - self.assertNotIn(self.conn.proto1, self.manager1.connections.handshaking_peers) + self.assertNotIn(self.conn.proto1, self.manager1.connections.iter_handshaking_peers()) def test_on_disconnect_after_peer(self) -> None: self.conn.run_one_step() # HELLO - self.assertIn(self.conn.proto1, self.manager1.connections.handshaking_peers) + self.assertIn(self.conn.proto1, self.manager1.connections.iter_handshaking_peers()) # No peer id in the peer_storage (known_peers) self.assertNotIn(self.peer2.id, self.manager1.connections.verified_peer_storage) # The peer READY now depends on a message exchange from both peers, so we need one more step self.conn.run_one_step() # PEER-ID self.conn.run_one_step() # READY - self.assertIn(self.conn.proto1, self.manager1.connections.connected_peers.values()) + self.assertIn(self.conn.proto1, self.manager1.connections.iter_ready_connections()) # Peer id 2 in the peer_storage (known_peers) after connection self.assertIn(self.peer2.id, self.manager1.connections.verified_peer_storage) - self.assertNotIn(self.conn.proto1, self.manager1.connections.handshaking_peers) + self.assertNotIn(self.conn.proto1, self.manager1.connections.iter_handshaking_peers()) self.conn.disconnect(Failure(Exception('testing'))) # Peer id 2 in the peer_storage (known_peers) after disconnection but before looping call self.assertIn(self.peer2.id, self.manager1.connections.verified_peer_storage) - self.assertNotIn(self.conn.proto1, self.manager1.connections.connected_peers.values()) + self.assertNotIn(self.conn.proto1, self.manager1.connections.iter_ready_connections()) self.clock.advance(10) # Peer id 2 removed from peer_storage (known_peers) after disconnection and after looping call @@ -386,9 +431,9 @@ def test_invalid_expected_peer_id(self) -> None: p2p_manager: ConnectionsManager = self.manager2.connections # Initially, manager1 and manager2 are handshaking, from the setup - assert p2p_manager.connecting_peers == {} - assert p2p_manager.handshaking_peers == {self.conn.proto2} - assert p2p_manager.connected_peers == {} + assert p2p_manager._connections.connecting_outbound_peers() == set() + assert p2p_manager._connections.handshaking_peers() == {self.conn.peer_addr1: self.conn.proto2} + assert p2p_manager._connections.ready_peers() == {} # We change our peer id (on manager1) new_peer = PrivatePeer.auto_generated() @@ -406,9 +451,9 @@ def test_invalid_expected_peer_id_bootstrap(self) -> None: p2p_manager: ConnectionsManager = self.manager1.connections # Initially, manager1 and manager2 are handshaking, from the setup - assert p2p_manager.connecting_peers == {} - assert p2p_manager.handshaking_peers == {self.conn.proto1} - assert p2p_manager.connected_peers == {} + assert p2p_manager._connections.connecting_outbound_peers() == set() + assert p2p_manager._connections.handshaking_peers() == {self.conn.peer_addr2: self.conn.proto1} + assert p2p_manager._connections.ready_peers() == {} # We create a new manager3, and use it as a bootstrap in manager1 peer3 = PrivatePeer.auto_generated() @@ -416,9 +461,12 @@ def test_invalid_expected_peer_id_bootstrap(self) -> None: conn = FakeConnection(manager1=manager3, manager2=self.manager1, fake_bootstrap_id=peer3.id) # Now manager1 and manager3 are handshaking - assert p2p_manager.connecting_peers == {} - assert p2p_manager.handshaking_peers == {self.conn.proto1, conn.proto2} - assert p2p_manager.connected_peers == {} + assert p2p_manager._connections.connecting_outbound_peers() == set() + assert p2p_manager._connections.handshaking_peers() == { + self.conn.peer_addr2: self.conn.proto1, + conn.peer_addr1: conn.proto2, + } + assert p2p_manager._connections.ready_peers() == {} # We change our peer id (on manager3) new_peer = PrivatePeer.auto_generated() @@ -436,18 +484,21 @@ def test_valid_unset_peer_id_bootstrap(self) -> None: p2p_manager: ConnectionsManager = self.manager1.connections # Initially, manager1 and manager2 are handshaking, from the setup - assert p2p_manager.connecting_peers == {} - assert p2p_manager.handshaking_peers == {self.conn.proto1} - assert p2p_manager.connected_peers == {} + assert p2p_manager._connections.connecting_outbound_peers() == set() + assert p2p_manager._connections.handshaking_peers() == {self.conn.peer_addr2: self.conn.proto1} + assert p2p_manager._connections.ready_peers() == {} # We create a new manager3, and use it as a bootstrap in manager1, but without the peer_id manager3: HathorManager = self.create_peer(self.network) conn = FakeConnection(manager1=manager3, manager2=self.manager1, fake_bootstrap_id=None) # Now manager1 and manager3 are handshaking - assert p2p_manager.connecting_peers == {} - assert p2p_manager.handshaking_peers == {self.conn.proto1, conn.proto2} - assert p2p_manager.connected_peers == {} + assert p2p_manager._connections.connecting_outbound_peers() == set() + assert p2p_manager._connections.handshaking_peers() == { + self.conn.peer_addr2: self.conn.proto1, + conn.peer_addr1: conn.proto2, + } + assert p2p_manager._connections.ready_peers() == {} # We change our peer id (on manager3) new_peer = PrivatePeer.auto_generated() diff --git a/tests/p2p/test_sync.py b/tests/p2p/test_sync.py index fc7712495..533d14192 100644 --- a/tests/p2p/test_sync.py +++ b/tests/p2p/test_sync.py @@ -2,6 +2,7 @@ from hathor.checkpoint import Checkpoint as cp from hathor.crypto.util import decode_address +from hathor.p2p import P2PDependencies from hathor.p2p.protocol import PeerIdState from hathor.p2p.sync_version import SyncVersion from hathor.simulator import FakeConnection @@ -268,13 +269,29 @@ def test_downloader(self) -> None: downloader = conn.proto2.connections.get_sync_factory(SyncVersion.V1_1).get_downloader() - node_sync1 = NodeSyncTimestamp( - conn.proto1, downloader, reactor=conn.proto1.node.reactor, vertex_parser=self.manager1.vertex_parser + p2p_dependencies1 = P2PDependencies( + reactor=self.manager1.reactor, + settings=self._settings, + vertex_parser=self.manager1.vertex_parser, + tx_storage=self.manager1.tx_storage, + vertex_handler=self.manager1.vertex_handler, + verification_service=self.manager1.verification_service, + capabilities=[], + whitelist_only=False, ) - node_sync1.start() - node_sync2 = NodeSyncTimestamp( - conn.proto2, downloader, reactor=conn.proto2.node.reactor, vertex_parser=manager2.vertex_parser + p2p_dependencies2 = P2PDependencies( + reactor=manager2.reactor, + settings=self._settings, + vertex_parser=manager2.vertex_parser, + tx_storage=manager2.tx_storage, + vertex_handler=manager2.vertex_handler, + verification_service=manager2.verification_service, + capabilities=[], + whitelist_only=False, ) + node_sync1 = NodeSyncTimestamp(conn.proto1, downloader, dependencies=p2p_dependencies1) + node_sync1.start() + node_sync2 = NodeSyncTimestamp(conn.proto2, downloader, dependencies=p2p_dependencies2) node_sync2.start() self.assertTrue(isinstance(conn.proto1.state, PeerIdState)) diff --git a/tests/p2p/test_sync_rate_limiter.py b/tests/p2p/test_sync_rate_limiter.py index 04d091c27..e550b7da4 100644 --- a/tests/p2p/test_sync_rate_limiter.py +++ b/tests/p2p/test_sync_rate_limiter.py @@ -31,7 +31,7 @@ def test_sync_rate_limiter(self) -> None: manager2.connections.disable_rate_limiter() manager2.connections.enable_rate_limiter(8, 2) - connected_peers2 = list(manager2.connections.connected_peers.values()) + connected_peers2 = list(manager2.connections.iter_ready_connections()) self.assertEqual(1, len(connected_peers2)) protocol1 = connected_peers2[0] assert isinstance(protocol1.state, ReadyState) @@ -64,7 +64,7 @@ def test_sync_rate_limiter_disconnect(self) -> None: connections.rate_limiter.reset(connections.GlobalRateLimiter.SEND_TIPS) connections.enable_rate_limiter(1, 1) - connected_peers2 = list(manager2.connections.connected_peers.values()) + connected_peers2 = list(manager2.connections.iter_ready_connections()) self.assertEqual(1, len(connected_peers2)) protocol1 = connected_peers2[0] @@ -114,7 +114,7 @@ def test_sync_rate_limiter_delayed_calls_draining(self) -> None: connections.rate_limiter.reset(connections.GlobalRateLimiter.SEND_TIPS) connections.enable_rate_limiter(1, 1) - connected_peers2 = list(manager2.connections.connected_peers.values()) + connected_peers2 = list(manager2.connections.iter_ready_connections()) self.assertEqual(1, len(connected_peers2)) protocol1 = connected_peers2[0] @@ -154,7 +154,7 @@ def test_sync_rate_limiter_delayed_calls_stop(self) -> None: connections.rate_limiter.reset(connections.GlobalRateLimiter.SEND_TIPS) connections.enable_rate_limiter(1, 1) - connected_peers2 = list(manager2.connections.connected_peers.values()) + connected_peers2 = list(manager2.connections.iter_ready_connections()) self.assertEqual(1, len(connected_peers2)) protocol1 = connected_peers2[0] diff --git a/tests/p2p/test_sync_v2.py b/tests/p2p/test_sync_v2.py index 842f75bda..04fde283b 100644 --- a/tests/p2p/test_sync_v2.py +++ b/tests/p2p/test_sync_v2.py @@ -197,7 +197,7 @@ def test_exceeds_streaming_and_mempool_limits(self) -> None: blk = manager1.tx_storage.get_best_block() tx_parents = [manager1.tx_storage.get_transaction(x) for x in blk.parents[1:]] self.assertEqual(len(tx_parents), 2) - dfs = DFSWalk(manager1.tx_storage, is_dag_verifications=True, is_left_to_right=False) + dfs = DFSWalk(manager1.tx_storage.get_vertex, is_dag_verifications=True, is_left_to_right=False) cnt = 0 for tx in dfs.run(tx_parents): if tx.get_metadata().first_block == blk.hash: diff --git a/tests/poa/test_poa_block_producer.py b/tests/poa/test_poa_block_producer.py index 3be849470..807570581 100644 --- a/tests/poa/test_poa_block_producer.py +++ b/tests/poa/test_poa_block_producer.py @@ -22,14 +22,14 @@ from hathor.consensus.poa import PoaBlockProducer from hathor.crypto.util import get_public_key_bytes_compressed from hathor.manager import HathorManager +from hathor.reactor.memory_reactor import MemoryReactorClock from hathor.transaction.poa import PoaBlock from tests.poa.utils import get_settings, get_signer -from tests.test_memory_reactor_clock import TestMemoryReactorClock from tests.unittest import TestBuilder def _get_manager(settings: HathorSettings) -> HathorManager: - reactor = TestMemoryReactorClock() + reactor = MemoryReactorClock() reactor.advance(settings.GENESIS_BLOCK_TIMESTAMP) artifacts = TestBuilder() \ @@ -45,7 +45,7 @@ def test_poa_block_producer_one_signer() -> None: settings = get_settings(signer, time_between_blocks=10) manager = _get_manager(settings) reactor = manager.reactor - assert isinstance(reactor, TestMemoryReactorClock) + assert isinstance(reactor, MemoryReactorClock) manager = Mock(wraps=manager) producer = PoaBlockProducer(settings=settings, reactor=reactor, poa_signer=signer) producer.manager = manager @@ -103,7 +103,7 @@ def test_poa_block_producer_two_signers() -> None: settings = get_settings(signer1, signer2, time_between_blocks=10) manager = _get_manager(settings) reactor = manager.reactor - assert isinstance(reactor, TestMemoryReactorClock) + assert isinstance(reactor, MemoryReactorClock) manager = Mock(wraps=manager) producer = PoaBlockProducer(settings=settings, reactor=reactor, poa_signer=signer1) producer.manager = manager diff --git a/tests/pubsub/test_pubsub2.py b/tests/pubsub/test_pubsub2.py index d0ede02ac..d258bd76d 100644 --- a/tests/pubsub/test_pubsub2.py +++ b/tests/pubsub/test_pubsub2.py @@ -16,9 +16,9 @@ from unittest.mock import Mock, patch import pytest -from twisted.internet.testing import MemoryReactorClock from hathor.pubsub import HathorEvents, PubSubManager +from hathor.reactor.memory_reactor import MemoryReactorClock @pytest.mark.parametrize('is_in_main_thread', [False, True]) diff --git a/tests/resources/p2p/test_status.py b/tests/resources/p2p/test_status.py index 646ba6903..b73634bd4 100644 --- a/tests/resources/p2p/test_status.py +++ b/tests/resources/p2p/test_status.py @@ -1,4 +1,3 @@ -from twisted.internet import endpoints from twisted.internet.address import IPv4Address from twisted.internet.defer import inlineCallbacks @@ -68,18 +67,18 @@ def test_handshaking(self): self.assertEqual(server_data['network'], 'testnet') self.assertGreater(server_data['uptime'], 0) - handshake_peer = self.conn1.proto1.transport.getPeer() - handshake_address = '{}:{}'.format(handshake_peer.host, handshake_peer.port) - self.assertEqual(len(known_peers), 0) self.assertEqual(len(connections['connected_peers']), 0) self.assertEqual(len(connections['handshaking_peers']), 1) - self.assertEqual(connections['handshaking_peers'][0]['address'], handshake_address) + self.assertEqual(connections['handshaking_peers'][0]['address'], str(self.conn1.proto1.addr)) @inlineCallbacks def test_get_with_one_peer(self): + assert self.conn1.peek_tr1_value().startswith(b'HELLO') self.conn1.run_one_step() # HELLO + assert self.conn1.peek_tr1_value().startswith(b'PEER-ID') self.conn1.run_one_step() # PEER-ID + assert self.conn1.peek_tr1_value().startswith(b'READY') self.conn1.run_one_step() # READY self.conn1.run_one_step() # BOTH PEERS ARE READY NOW @@ -100,17 +99,14 @@ def test_get_with_one_peer(self): @inlineCallbacks def test_connecting_peers(self): - address = '192.168.1.1:54321' - endpoint = endpoints.clientFromString(self.manager.reactor, 'tcp:{}'.format(address)) - deferred = endpoint.connect - self.manager.connections.connecting_peers[endpoint] = deferred + peer_address = PeerAddress.parse('tcp://192.168.1.1:54321') + self.manager.connections._connections._connecting_outbound.add(peer_address) response = yield self.web.get("status") data = response.json_value() connecting = data['connections']['connecting_peers'] self.assertEqual(len(connecting), 1) - self.assertEqual(connecting[0]['address'], address) - self.assertIsNotNone(connecting[0]['deferred']) + self.assertEqual(connecting[0]['address'], str(peer_address)) class SyncV1StatusTest(unittest.SyncV1Params, BaseStatusTest): diff --git a/tests/sysctl/test_p2p.py b/tests/sysctl/test_p2p.py index ec0366888..a3703676a 100644 --- a/tests/sysctl/test_p2p.py +++ b/tests/sysctl/test_p2p.py @@ -2,6 +2,7 @@ import tempfile from unittest.mock import MagicMock +from hathor.p2p.peer_endpoint import PeerAddress from hathor.p2p.peer_id import PeerId from hathor.sysctl import ConnectionsManagerSysctl from hathor.sysctl.exception import SysctlException @@ -172,7 +173,9 @@ def test_kill_one_connection(self): peer_id = '0e2bd0d8cd1fb6d040801c32ec27e8986ce85eb8810b6c878dcad15bce3b5b1e' conn = MagicMock() - p2p_manager.connected_peers[PeerId(peer_id)] = conn + conn.addr = PeerAddress.parse('tcp://localhost:40403') + p2p_manager._connections.on_connected(protocol=conn) + p2p_manager._connections.on_ready(addr=conn.addr, peer_id=PeerId(peer_id)) self.assertEqual(conn.disconnect.call_count, 0) sysctl.unsafe_set('kill_connection', peer_id) self.assertEqual(conn.disconnect.call_count, 1) diff --git a/tests/test_memory_reactor_clock.py b/tests/test_memory_reactor_clock.py deleted file mode 100644 index 48e8a6d48..000000000 --- a/tests/test_memory_reactor_clock.py +++ /dev/null @@ -1,26 +0,0 @@ -# Copyright 2023 Hathor Labs -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from twisted.internet.testing import MemoryReactorClock - - -class TestMemoryReactorClock(MemoryReactorClock): - __test__ = False - - def run(self): - """ - We have to override MemoryReactor.run() because the original Twisted implementation weirdly calls stop() inside - run(), and we need the reactor running during our tests. - """ - self.running = True diff --git a/tests/tx/test_merged_mining.py b/tests/tx/test_merged_mining.py index ebf032bb1..9367b9604 100644 --- a/tests/tx/test_merged_mining.py +++ b/tests/tx/test_merged_mining.py @@ -25,7 +25,7 @@ async def test_coordinator(self): from cryptography.hazmat.primitives.asymmetric import ec from hathor.crypto.util import get_address_b58_from_public_key - from hathor.simulator.clock import MemoryReactorHeapClock + from hathor.simulator.heap_clock import MemoryReactorHeapClock super().setUp() self.manager = self.create_peer('testnet') diff --git a/tests/tx/test_stratum.py b/tests/tx/test_stratum.py index 1445684b4..128e32a11 100644 --- a/tests/tx/test_stratum.py +++ b/tests/tx/test_stratum.py @@ -7,7 +7,7 @@ import pytest from twisted.internet.testing import StringTransportWithDisconnection -from hathor.simulator.clock import MemoryReactorHeapClock +from hathor.simulator.heap_clock import MemoryReactorHeapClock from hathor.stratum import ( INVALID_PARAMS, INVALID_REQUEST, diff --git a/tests/tx/test_traversal.py b/tests/tx/test_traversal.py index 9f730c545..30216e4f4 100644 --- a/tests/tx/test_traversal.py +++ b/tests/tx/test_traversal.py @@ -89,7 +89,7 @@ def test_right_to_left(self): class BaseBFSTimestampWalkTestCase(_TraversalTestCase): def gen_walk(self, **kwargs): - return BFSTimestampWalk(self.manager.tx_storage, **kwargs) + return BFSTimestampWalk(self.manager.tx_storage.get_vertex, **kwargs) def _run_lr(self, walk, skip_root=True): seen = set() @@ -120,7 +120,7 @@ class SyncV2BFSTimestampWalkTestCase(unittest.SyncV2Params, BaseBFSTimestampWalk class BaseBFSOrderWalkTestCase(_TraversalTestCase): def gen_walk(self, **kwargs): - return BFSOrderWalk(self.manager.tx_storage, **kwargs) + return BFSOrderWalk(self.manager.tx_storage.get_vertex, **kwargs) def _run_lr(self, walk, skip_root=True): seen = set() @@ -168,7 +168,7 @@ class SyncBridgeBFSOrderWalkTestCase(unittest.SyncBridgeParams, SyncV2BFSOrderWa class BaseDFSWalkTestCase(_TraversalTestCase): def gen_walk(self, **kwargs): - return DFSWalk(self.manager.tx_storage, **kwargs) + return DFSWalk(self.manager.tx_storage.get_vertex, **kwargs) def _run_lr(self, walk, skip_root=True): seen = set() diff --git a/tests/unittest.py b/tests/unittest.py index afb11c1b0..c165b4cd1 100644 --- a/tests/unittest.py +++ b/tests/unittest.py @@ -23,13 +23,13 @@ from hathor.p2p.sync_version import SyncVersion from hathor.pubsub import PubSubManager from hathor.reactor import ReactorProtocol as Reactor, get_global_reactor -from hathor.simulator.clock import MemoryReactorHeapClock +from hathor.reactor.memory_reactor import MemoryReactorClock +from hathor.simulator.heap_clock import MemoryReactorHeapClock from hathor.transaction import BaseTransaction, Block, Transaction from hathor.transaction.storage.transaction_storage import TransactionStorage from hathor.types import VertexId from hathor.util import Random, not_none from hathor.wallet import BaseWallet, HDWallet, Wallet -from tests.test_memory_reactor_clock import TestMemoryReactorClock logger = get_logger() main = ut_main @@ -115,7 +115,7 @@ class TestCase(unittest.TestCase): def setUp(self) -> None: self.tmpdirs: list[str] = [] - self.clock = TestMemoryReactorClock() + self.clock = MemoryReactorClock() self.clock.advance(time.time()) self.reactor = self.clock self.log = logger.new()