diff --git a/hathor/builder/builder.py b/hathor/builder/builder.py index 7794979fa..7bc00f13d 100644 --- a/hathor/builder/builder.py +++ b/hathor/builder/builder.py @@ -334,6 +334,10 @@ def _get_or_create_rocksdb_storage(self) -> RocksDBStorage: return self._rocksdb_storage def _get_p2p_manager(self) -> ConnectionsManager: + from hathor.p2p.sync_v1.factory import SyncV11Factory + from hathor.p2p.sync_v2.factory import SyncV2Factory + from hathor.p2p.sync_version import SyncVersion + enable_ssl = True reactor = self._get_reactor() my_peer = self._get_peer_id() @@ -348,9 +352,13 @@ def _get_p2p_manager(self) -> ConnectionsManager: ssl=enable_ssl, whitelist_only=False, rng=self._rng, - enable_sync_v1=self._enable_sync_v1, - enable_sync_v2=self._enable_sync_v2, ) + p2p_manager.add_sync_factory(SyncVersion.V1_1, SyncV11Factory(p2p_manager)) + p2p_manager.add_sync_factory(SyncVersion.V2, SyncV2Factory(p2p_manager)) + if self._enable_sync_v1: + p2p_manager.enable_sync_version(SyncVersion.V1_1) + if self._enable_sync_v2: + p2p_manager.enable_sync_version(SyncVersion.V2) return p2p_manager def _get_or_create_indexes_manager(self) -> IndexesManager: diff --git a/hathor/builder/cli_builder.py b/hathor/builder/cli_builder.py index 11a36ddda..fc897867b 100644 --- a/hathor/builder/cli_builder.py +++ b/hathor/builder/cli_builder.py @@ -65,6 +65,9 @@ def create_manager(self, reactor: Reactor) -> HathorManager: from hathor.event.websocket.factory import EventWebsocketFactory from hathor.p2p.netfilter.utils import add_peer_id_blacklist from hathor.p2p.peer_discovery import BootstrapPeerDiscovery, DNSPeerDiscovery + from hathor.p2p.sync_v1.factory import SyncV11Factory + from hathor.p2p.sync_v2.factory import SyncV2Factory + from hathor.p2p.sync_version import SyncVersion from hathor.storage import RocksDBStorage from hathor.transaction.storage import ( TransactionCacheStorage, @@ -233,9 +236,13 @@ def create_manager(self, reactor: Reactor) -> HathorManager: ssl=True, whitelist_only=False, rng=Random(), - enable_sync_v1=enable_sync_v1, - enable_sync_v2=enable_sync_v2, ) + p2p_manager.add_sync_factory(SyncVersion.V1_1, SyncV11Factory(p2p_manager)) + p2p_manager.add_sync_factory(SyncVersion.V2, SyncV2Factory(p2p_manager)) + if enable_sync_v1: + p2p_manager.enable_sync_version(SyncVersion.V1_1) + if enable_sync_v2: + p2p_manager.enable_sync_version(SyncVersion.V2) self.manager = HathorManager( reactor, diff --git a/hathor/p2p/manager.py b/hathor/p2p/manager.py index 24dc4a01f..1682cd9f9 100644 --- a/hathor/p2p/manager.py +++ b/hathor/p2p/manager.py @@ -86,6 +86,7 @@ class GlobalRateLimiter: handshaking_peers: set[HathorProtocol] whitelist_only: bool _sync_factories: dict[SyncVersion, SyncAgentFactory] + _enabled_sync_versions: set[SyncVersion] rate_limiter: RateLimiter @@ -96,15 +97,7 @@ def __init__(self, pubsub: PubSubManager, ssl: bool, rng: Random, - whitelist_only: bool, - enable_sync_v1: bool, - enable_sync_v2: bool) -> None: - from hathor.p2p.sync_v1.factory import SyncV11Factory - from hathor.p2p.sync_v2.factory import SyncV2Factory - - if not (enable_sync_v1 or enable_sync_v2): - raise TypeError(f'{type(self).__name__}() at least one sync version is required') - + whitelist_only: bool) -> None: self.log = logger.new() self.rng = rng self.manager = None @@ -184,23 +177,59 @@ def __init__(self, # Parameter to explicitly enable whitelist-only mode, when False it will still check the whitelist for sync-v1 self.whitelist_only = whitelist_only - self.enable_sync_v1 = enable_sync_v1 - self.enable_sync_v2 = enable_sync_v2 - # Timestamp when the last discovery ran self._last_discovery: float = 0. # sync-manager factories self._sync_factories = {} - if enable_sync_v1: - self._sync_factories[SyncVersion.V1_1] = SyncV11Factory(self) - if enable_sync_v2: - self._sync_factories[SyncVersion.V2] = SyncV2Factory(self) + self._enabled_sync_versions = set() + + def add_sync_factory(self, sync_version: SyncVersion, sync_factory: SyncAgentFactory) -> None: + """Add factory for the given sync version, must use a sync version that does not already exist.""" + # XXX: to allow code in `set_manager` to safely use the the available sync versions, we add this restriction: + assert self.manager is None, 'Cannot modify sync factories after a manager is set' + if sync_version in self._sync_factories: + raise ValueError('sync version already exists') + self._sync_factories[sync_version] = sync_factory + + def get_available_sync_versions(self) -> set[SyncVersion]: + """What sync versions the manager is capable of using, they are not necessarily enabled.""" + return set(self._sync_factories.keys()) + + def is_sync_version_available(self, sync_version: SyncVersion) -> bool: + """Whether the given sync version is available for use, is not necessarily enabled.""" + return sync_version in self._sync_factories + + def get_enabled_sync_versions(self) -> set[SyncVersion]: + """What sync versions are enabled for use, it is necessarily a subset of the available versions.""" + return self._enabled_sync_versions.copy() + + def is_sync_version_enabled(self, sync_version: SyncVersion) -> bool: + """Whether the given sync version is enabled for use, being enabled implies being available.""" + return sync_version in self._enabled_sync_versions + + def enable_sync_version(self, sync_version: SyncVersion) -> None: + """Enable using the given sync version on new connections, it must be available before being enabled.""" + assert sync_version in self._sync_factories + if sync_version in self._enabled_sync_versions: + self.log.info('tried to enable a sync verison that was already enabled, nothing to do') + return + self._enabled_sync_versions.add(sync_version) + + def disable_sync_version(self, sync_version: SyncVersion) -> None: + """Disable using the given sync version, it WILL NOT close connections using the given version.""" + if sync_version not in self._enabled_sync_versions: + self.log.info('tried to disable a sync verison that was already disabled, nothing to do') + return + self._enabled_sync_versions.discard(sync_version) def set_manager(self, manager: 'HathorManager') -> None: """Set the manager. This method must be called before start().""" + if len(self._enabled_sync_versions) == 0: + raise TypeError('Class built incorrectly without any enabled sync version') + self.manager = manager - if self.enable_sync_v2: + if self.is_sync_version_available(SyncVersion.V2): assert self.manager.tx_storage.indexes is not None indexes = self.manager.tx_storage.indexes self.log.debug('enable sync-v2 indexes') @@ -235,6 +264,10 @@ def enable_rate_limiter(self, max_hits: int = 16, window_seconds: float = 1) -> ) def start(self) -> None: + """Listen on the given address descriptions and start accepting and processing connections.""" + if self.manager is None: + raise TypeError('Class was built incorrectly without a HathorManager.') + self.lc_reconnect.start(5, now=False) self.lc_sync_update.start(self.lc_sync_update_interval, now=False) @@ -278,20 +311,9 @@ def _get_peers_count(self) -> PeerConnectionsMetrics: len(self.peer_storage) ) - def get_sync_versions(self) -> set[SyncVersion]: - """Set of versions that were enabled and are supported.""" - assert self.manager is not None - if self.manager.has_sync_version_capability(): - return set(self._sync_factories.keys()) - else: - assert SyncVersion.V1_1 in self._sync_factories, \ - 'sync-versions capability disabled, but sync-v1 not enabled' - # XXX: this is to make it easy to simulate old behavior if we disable the sync-version capability - return {SyncVersion.V1_1} - def get_sync_factory(self, sync_version: SyncVersion) -> SyncAgentFactory: - """Get the sync factory for a given version, support MUST be checked beforehand or it will raise an assert.""" - assert sync_version in self._sync_factories, 'get_sync_factory must be called for a supported version' + """Get the sync factory for a given version, MUST be available or it will raise an assert.""" + assert sync_version in self._sync_factories, f'sync_version {sync_version} is not available' return self._sync_factories[sync_version] def has_synced_peer(self) -> bool: diff --git a/hathor/p2p/states/hello.py b/hathor/p2p/states/hello.py index d6cc80fca..56f514dd7 100644 --- a/hathor/p2p/states/hello.py +++ b/hathor/p2p/states/hello.py @@ -64,10 +64,10 @@ def _get_hello_data(self) -> dict[str, Any]: return data def _get_sync_versions(self) -> set[SyncVersion]: - """Shortcut to ConnectionManager.get_sync_versions""" + """Shortcut to ConnectionManager.get_enabled_sync_versions""" connections_manager = self.protocol.connections assert connections_manager is not None - return connections_manager.get_sync_versions() + return connections_manager.get_enabled_sync_versions() def on_enter(self) -> None: # After a connection is made, we just send a HELLO message. diff --git a/tests/others/test_cli_builder.py b/tests/others/test_cli_builder.py index 1c9c05be9..3aabf4b3d 100644 --- a/tests/others/test_cli_builder.py +++ b/tests/others/test_cli_builder.py @@ -57,8 +57,8 @@ def test_all_default(self): self.assertIsInstance(manager.tx_storage.indexes, RocksDBIndexesManager) self.assertIsNone(manager.wallet) self.assertEqual('unittests', manager.network) - self.assertIn(SyncVersion.V1_1, manager.connections._sync_factories) - self.assertNotIn(SyncVersion.V2, manager.connections._sync_factories) + self.assertTrue(manager.connections.is_sync_version_enabled(SyncVersion.V1_1)) + self.assertFalse(manager.connections.is_sync_version_enabled(SyncVersion.V2)) self.assertFalse(self.resources_builder._built_prometheus) self.assertFalse(self.resources_builder._built_status) self.assertFalse(manager._enable_event_queue) @@ -103,13 +103,13 @@ def test_memory_storage_with_rocksdb_indexes(self): def test_sync_bridge(self): manager = self._build(['--memory-storage', '--x-sync-bridge']) - self.assertIn(SyncVersion.V1_1, manager.connections._sync_factories) - self.assertIn(SyncVersion.V2, manager.connections._sync_factories) + self.assertTrue(manager.connections.is_sync_version_enabled(SyncVersion.V1_1)) + self.assertTrue(manager.connections.is_sync_version_enabled(SyncVersion.V2)) def test_sync_v2_only(self): manager = self._build(['--memory-storage', '--x-sync-v2-only']) - self.assertNotIn(SyncVersion.V1_1, manager.connections._sync_factories) - self.assertIn(SyncVersion.V2, manager.connections._sync_factories) + self.assertFalse(manager.connections.is_sync_version_enabled(SyncVersion.V1_1)) + self.assertTrue(manager.connections.is_sync_version_enabled(SyncVersion.V2)) def test_keypair_wallet(self): manager = self._build(['--memory-storage', '--wallet', 'keypair']) diff --git a/tests/p2p/test_sync.py b/tests/p2p/test_sync.py index b42a1c808..bad0f654f 100644 --- a/tests/p2p/test_sync.py +++ b/tests/p2p/test_sync.py @@ -268,7 +268,7 @@ def test_downloader(self): self.assertTrue(isinstance(conn.proto1.state, PeerIdState)) self.assertTrue(isinstance(conn.proto2.state, PeerIdState)) - downloader = conn.proto2.connections._sync_factories[SyncVersion.V1_1].get_downloader() + downloader = conn.proto2.connections.get_sync_factory(SyncVersion.V1_1).get_downloader() node_sync1 = NodeSyncTimestamp(conn.proto1, downloader, reactor=conn.proto1.node.reactor) node_sync1.start() @@ -361,7 +361,7 @@ def _downloader_bug_setup(self): # create the peer that will experience the bug self.manager_bug = self.create_peer(self.network) - self.downloader = self.manager_bug.connections._sync_factories[SyncVersion.V1_1].get_downloader() + self.downloader = self.manager_bug.connections.get_sync_factory(SyncVersion.V1_1).get_downloader() self.downloader.window_size = 1 self.conn1 = FakeConnection(self.manager_bug, self.manager1) self.conn2 = FakeConnection(self.manager_bug, self.manager2) diff --git a/tests/p2p/test_whitelist.py b/tests/p2p/test_whitelist.py index 7f1b28759..7d408e71b 100644 --- a/tests/p2p/test_whitelist.py +++ b/tests/p2p/test_whitelist.py @@ -14,10 +14,10 @@ def test_sync_v11_whitelist_no_no(self): network = 'testnet' manager1 = self.create_peer(network) - self.assertEqual(set(manager1.connections._sync_factories.keys()), {SyncVersion.V1_1}) + self.assertEqual(manager1.connections.get_enabled_sync_versions(), {SyncVersion.V1_1}) manager2 = self.create_peer(network) - self.assertEqual(set(manager2.connections._sync_factories.keys()), {SyncVersion.V1_1}) + self.assertEqual(manager2.connections.get_enabled_sync_versions(), {SyncVersion.V1_1}) conn = FakeConnection(manager1, manager2) self.assertFalse(conn.tr1.disconnecting) @@ -36,10 +36,10 @@ def test_sync_v11_whitelist_yes_no(self): network = 'testnet' manager1 = self.create_peer(network) - self.assertEqual(set(manager1.connections._sync_factories.keys()), {SyncVersion.V1_1}) + self.assertEqual(manager1.connections.get_enabled_sync_versions(), {SyncVersion.V1_1}) manager2 = self.create_peer(network) - self.assertEqual(set(manager2.connections._sync_factories.keys()), {SyncVersion.V1_1}) + self.assertEqual(manager2.connections.get_enabled_sync_versions(), {SyncVersion.V1_1}) manager1.peers_whitelist.append(manager2.my_peer.id) @@ -60,10 +60,10 @@ def test_sync_v11_whitelist_yes_yes(self): network = 'testnet' manager1 = self.create_peer(network) - self.assertEqual(set(manager1.connections._sync_factories.keys()), {SyncVersion.V1_1}) + self.assertEqual(manager1.connections.get_enabled_sync_versions(), {SyncVersion.V1_1}) manager2 = self.create_peer(network) - self.assertEqual(set(manager2.connections._sync_factories.keys()), {SyncVersion.V1_1}) + self.assertEqual(manager2.connections.get_enabled_sync_versions(), {SyncVersion.V1_1}) manager1.peers_whitelist.append(manager2.my_peer.id) manager2.peers_whitelist.append(manager1.my_peer.id) diff --git a/tests/unittest.py b/tests/unittest.py index ab64814bb..cb97cb3fb 100644 --- a/tests/unittest.py +++ b/tests/unittest.py @@ -248,14 +248,8 @@ def create_peer(self, network, peer_id=None, wallet=None, tx_storage=None, unloc manager = self.create_peer_from_builder(builder, start_manager=start_manager) # XXX: just making sure that tests set this up correctly - if enable_sync_v2: - assert SyncVersion.V2 in manager.connections._sync_factories - else: - assert SyncVersion.V2 not in manager.connections._sync_factories - if enable_sync_v1: - assert SyncVersion.V1_1 in manager.connections._sync_factories - else: - assert SyncVersion.V1_1 not in manager.connections._sync_factories + assert manager.connections.is_sync_version_enabled(SyncVersion.V2) == enable_sync_v2 + assert manager.connections.is_sync_version_enabled(SyncVersion.V1_1) == enable_sync_v1 return manager