diff --git a/newsfragments/1543.feature.rst b/newsfragments/1543.feature.rst new file mode 100644 index 0000000000..56a8d2ae0e --- /dev/null +++ b/newsfragments/1543.feature.rst @@ -0,0 +1,2 @@ +Implement the ``eth/64`` networking protocol according to +`EIP 2124 `_ \ No newline at end of file diff --git a/p2p/abc.py b/p2p/abc.py index 24bca4b1ed..6bb9e7cb12 100644 --- a/p2p/abc.py +++ b/p2p/abc.py @@ -483,15 +483,15 @@ class HandshakeReceiptAPI(ABC): THandshakeReceipt = TypeVar('THandshakeReceipt', bound=HandshakeReceiptAPI) -class HandshakerAPI(ABC): +class HandshakerAPI(ABC, Generic[TProtocol]): logger: ExtendedDebugLogger - protocol_class: Type[ProtocolAPI] + protocol_class: Type[TProtocol] @abstractmethod async def do_handshake(self, multiplexer: MultiplexerAPI, - protocol: ProtocolAPI) -> HandshakeReceiptAPI: + protocol: TProtocol) -> HandshakeReceiptAPI: """ Perform the actual handshake for the protocol. """ diff --git a/p2p/handshake.py b/p2p/handshake.py index dc96987b70..0f29b9d05c 100644 --- a/p2p/handshake.py +++ b/p2p/handshake.py @@ -28,6 +28,8 @@ MultiplexerAPI, NodeAPI, TransportAPI, + TProtocol, + ProtocolAPI, ) from p2p.connection import Connection from p2p.constants import DEVP2P_V5 @@ -58,7 +60,7 @@ ) -class Handshaker(HandshakerAPI): +class Handshaker(HandshakerAPI[TProtocol]): """ Base class that handles the handshake for a given protocol. The primary justification for this class's existence is to house parameters that are @@ -179,7 +181,7 @@ async def _do_p2p_handshake(transport: TransportAPI, async def negotiate_protocol_handshakes(transport: TransportAPI, p2p_handshake_params: DevP2PHandshakeParams, - protocol_handshakers: Sequence[HandshakerAPI], + protocol_handshakers: Sequence[HandshakerAPI[ProtocolAPI]], token: CancelToken, ) -> Tuple[MultiplexerAPI, DevP2PReceipt, Tuple[HandshakeReceiptAPI, ...]]: # noqa: E501 """ @@ -294,7 +296,7 @@ async def negotiate_protocol_handshakes(transport: TransportAPI, async def dial_out(remote: NodeAPI, private_key: keys.PrivateKey, p2p_handshake_params: DevP2PHandshakeParams, - protocol_handshakers: Sequence[HandshakerAPI], + protocol_handshakers: Sequence[HandshakerAPI[ProtocolAPI]], token: CancelToken) -> ConnectionAPI: """ Perform the auth and P2P handshakes with the given remote. @@ -345,7 +347,7 @@ async def receive_dial_in(reader: asyncio.StreamReader, writer: asyncio.StreamWriter, private_key: keys.PrivateKey, p2p_handshake_params: DevP2PHandshakeParams, - protocol_handshakers: Sequence[HandshakerAPI], + protocol_handshakers: Sequence[HandshakerAPI[ProtocolAPI]], token: CancelToken) -> Connection: transport = await Transport.receive_connection( reader=reader, diff --git a/p2p/multiplexer.py b/p2p/multiplexer.py index ecf3ab16af..d595adc443 100644 --- a/p2p/multiplexer.py +++ b/p2p/multiplexer.py @@ -21,6 +21,7 @@ from eth_utils import get_extended_debug_logger, ValidationError from eth_utils.toolz import cons +import rlp from p2p.abc import ( CommandAPI, @@ -36,6 +37,7 @@ CorruptTransport, UnknownProtocol, UnknownProtocolCommand, + MalformedMessage, ) from p2p.p2p_proto import BaseP2PProtocol from p2p.transport_state import TransportState @@ -84,7 +86,11 @@ async def stream_transport_messages(transport: TransportAPI, msg_proto = command_id_cache[command_id] command_type = msg_proto.get_command_type_for_command_id(command_id) - cmd = command_type.decode(msg, msg_proto.snappy_support) + + try: + cmd = command_type.decode(msg, msg_proto.snappy_support) + except rlp.exceptions.DeserializationError as err: + raise MalformedMessage(f"Failed to decode {msg} for {command_type}") from err yield msg_proto, cmd diff --git a/p2p/peer.py b/p2p/peer.py index b30812f834..61a3f91889 100644 --- a/p2p/peer.py +++ b/p2p/peer.py @@ -179,10 +179,10 @@ def setup_connection_tracker(self) -> BaseConnectionTracker: return NoopConnectionTracker() def __str__(self) -> str: - return f"{self.__class__.__name__} {self.session}" + return f"{self.__class__.__name__} {self.sub_proto} {self.session}" def __repr__(self) -> str: - return f"{self.__class__.__name__} {self.session!r}" + return f"{self.__class__.__name__} {self.sub_proto!r} {self.session!r}" # # Proxy Transport attributes @@ -474,7 +474,7 @@ def __init__(self, self.event_bus = event_bus @abstractmethod - async def get_handshakers(self) -> Tuple[HandshakerAPI, ...]: + async def get_handshakers(self) -> Tuple[HandshakerAPI[ProtocolAPI], ...]: ... async def handshake(self, remote: NodeAPI) -> BasePeer: diff --git a/p2p/tools/factories/connection.py b/p2p/tools/factories/connection.py index a482a8532a..072d5fb51f 100644 --- a/p2p/tools/factories/connection.py +++ b/p2p/tools/factories/connection.py @@ -7,7 +7,7 @@ from eth_keys import keys -from p2p.abc import ConnectionAPI, HandshakerAPI, NodeAPI +from p2p.abc import ConnectionAPI, HandshakerAPI, NodeAPI, ProtocolAPI from p2p.connection import Connection from p2p.constants import DEVP2P_V5 from p2p.handshake import ( @@ -24,8 +24,8 @@ @asynccontextmanager async def ConnectionPairFactory(*, - alice_handshakers: Tuple[HandshakerAPI, ...] = (), - bob_handshakers: Tuple[HandshakerAPI, ...] = (), + alice_handshakers: Tuple[HandshakerAPI[ProtocolAPI], ...] = (), + bob_handshakers: Tuple[HandshakerAPI[ProtocolAPI], ...] = (), alice_remote: NodeAPI = None, alice_private_key: keys.PrivateKey = None, alice_client_version: str = 'alice', diff --git a/p2p/tools/handshake.py b/p2p/tools/handshake.py index 9ecdc11500..f34cfdec97 100644 --- a/p2p/tools/handshake.py +++ b/p2p/tools/handshake.py @@ -5,7 +5,7 @@ from p2p.receipt import HandshakeReceipt -class NoopHandshaker(Handshaker): +class NoopHandshaker(Handshaker[ProtocolAPI]): def __init__(self, protocol_class: Type[ProtocolAPI]) -> None: self.protocol_class = protocol_class diff --git a/p2p/tools/paragon/peer.py b/p2p/tools/paragon/peer.py index dfa821b085..394ad710f3 100644 --- a/p2p/tools/paragon/peer.py +++ b/p2p/tools/paragon/peer.py @@ -1,6 +1,7 @@ from typing import ( Iterable, Tuple, + Any, ) from cached_property import cached_property @@ -9,7 +10,7 @@ from p2p.handshake import Handshaker from p2p.receipt import HandshakeReceipt -from p2p.abc import BehaviorAPI, ProtocolAPI +from p2p.abc import BehaviorAPI from p2p.constants import DEVP2P_V5 from p2p.peer import ( BasePeer, @@ -46,12 +47,12 @@ def __init__(self, super().__init__(client_version_string, listen_port, p2p_version) -class ParagonHandshaker(Handshaker): +class ParagonHandshaker(Handshaker[ParagonProtocol]): protocol_class = ParagonProtocol async def do_handshake(self, multiplexer: MultiplexerAPI, - protocol: ProtocolAPI) -> HandshakeReceipt: + protocol: ParagonProtocol) -> HandshakeReceipt: return HandshakeReceipt(protocol) @@ -59,7 +60,7 @@ class ParagonPeerFactory(BasePeerFactory): peer_class = ParagonPeer context: ParagonContext - async def get_handshakers(self) -> Tuple[HandshakerAPI, ...]: + async def get_handshakers(self) -> Tuple[HandshakerAPI[Any], ...]: return (ParagonHandshaker(),) diff --git a/tests/core/p2p-proto/test_eth_api.py b/tests/core/p2p-proto/test_eth_api.py index 1a0bdb78a7..830c0cd367 100644 --- a/tests/core/p2p-proto/test_eth_api.py +++ b/tests/core/p2p-proto/test_eth_api.py @@ -10,28 +10,34 @@ latest_mainnet_at, mine_block, ) +from eth.vm.forks import MuirGlacierVM, PetersburgVM from trinity._utils.assertions import assert_type_equality from trinity.db.eth1.header import AsyncHeaderDB -from trinity.protocol.eth.api import ETHAPI +from trinity.exceptions import WrongForkIDFailure +from trinity.protocol.eth.api import ETHAPI, ETHV63API from trinity.protocol.eth.commands import ( GetBlockHeaders, GetNodeData, NewBlock, Status, + StatusV63, ) -from trinity.protocol.eth.handshaker import ETHHandshakeReceipt +from trinity.protocol.eth.handshaker import ETHHandshakeReceipt, ETHV63HandshakeReceipt +from trinity.protocol.eth.proto import ETHProtocolV63, ETHProtocol from trinity.tools.factories.common import ( BlockHeadersQueryFactory, ) from trinity.tools.factories.eth import ( StatusPayloadFactory, + StatusV63PayloadFactory, ) from trinity.tools.factories import ( BlockHashFactory, ChainContextFactory, ETHPeerPairFactory, + ETHV63PeerPairFactory, ) @@ -60,8 +66,23 @@ def alice_chain(bob_chain): @pytest.fixture -async def alice_and_bob(alice_chain, bob_chain): - pair_factory = ETHPeerPairFactory( +def alice_chain_on_fork(bob_chain): + bob_genesis = bob_chain.headerdb.get_canonical_block_header_by_number(0) + + chain = build( + MiningChain, + latest_mainnet_at(0), + disable_pow_check(), + genesis(params={"timestamp": bob_genesis.timestamp}), + mine_block(), + ) + + return chain + + +@pytest.fixture(params=(ETHV63PeerPairFactory, ETHPeerPairFactory)) +async def alice_and_bob(alice_chain, bob_chain, request): + pair_factory = request.param( alice_client_version='alice', alice_peer_context=ChainContextFactory(headerdb=AsyncHeaderDB(alice_chain.headerdb.db)), bob_client_version='bob', @@ -83,14 +104,48 @@ def bob(alice_and_bob): return bob +@pytest.fixture +def protocol_specific_classes(alice): + if alice.connection.has_protocol(ETHProtocolV63): + return ETHV63API, ETHV63HandshakeReceipt, StatusV63, StatusV63PayloadFactory + elif alice.connection.has_protocol(ETHProtocol): + return ETHAPI, ETHHandshakeReceipt, Status, StatusPayloadFactory + else: + raise Exception("No ETH protocol found") + + +@pytest.fixture +def ETHAPI_class(protocol_specific_classes): + api_class, _, _, _ = protocol_specific_classes + return api_class + + +@pytest.fixture +def ETHHandshakeReceipt_class(protocol_specific_classes): + _, receipt_class, _, _ = protocol_specific_classes + return receipt_class + + +@pytest.fixture +def Status_class(protocol_specific_classes): + _, _, status_class, _ = protocol_specific_classes + return status_class + + +@pytest.fixture +def StatusPayloadFactory_class(protocol_specific_classes): + _, _, _, status_payload_factory_class = protocol_specific_classes + return status_payload_factory_class + + @pytest.mark.asyncio -async def test_eth_api_properties(alice): - assert alice.connection.has_logic(ETHAPI.name) - eth_api = alice.connection.get_logic(ETHAPI.name, ETHAPI) +async def test_eth_api_properties(alice, ETHAPI_class, ETHHandshakeReceipt_class): + assert alice.connection.has_logic(ETHAPI_class.name) + eth_api = alice.connection.get_logic(ETHAPI_class.name, ETHAPI_class) assert eth_api is alice.eth_api - eth_receipt = alice.connection.get_receipt_by_type(ETHHandshakeReceipt) + eth_receipt = alice.connection.get_receipt_by_type(ETHHandshakeReceipt_class) assert eth_api.network_id == eth_receipt.network_id assert eth_api.genesis_hash == eth_receipt.genesis_hash @@ -101,7 +156,7 @@ async def test_eth_api_properties(alice): @pytest.mark.asyncio -async def test_eth_api_head_info_updates_with_newblock(alice, bob, bob_chain): +async def test_eth_api_head_info_updates_with_newblock(alice, bob, bob_chain, ETHAPI_class): # mine two blocks on bob's chain bob_chain = build( bob_chain, @@ -118,8 +173,8 @@ async def _handle_new_block(connection, msg): bob_genesis = bob_chain.headerdb.get_canonical_block_header_by_number(0) - bob_eth_api = bob.connection.get_logic(ETHAPI.name, ETHAPI) - alice_eth_api = alice.connection.get_logic(ETHAPI.name, ETHAPI) + bob_eth_api = bob.connection.get_logic(ETHAPI_class.name, ETHAPI_class) + alice_eth_api = alice.connection.get_logic(ETHAPI_class.name, ETHAPI_class) assert alice_eth_api.head_info.head_hash == bob_genesis.hash assert alice_eth_api.head_info.head_td == bob_genesis.difficulty @@ -143,19 +198,19 @@ async def _handle_new_block(connection, msg): @pytest.mark.asyncio -async def test_eth_api_send_status(alice, bob): - payload = StatusPayloadFactory() +async def test_eth_api_send_status(alice, bob, StatusPayloadFactory_class, Status_class): + payload = StatusPayloadFactory_class() command_fut = asyncio.Future() async def _handle_cmd(connection, cmd): command_fut.set_result(cmd) - bob.connection.add_command_handler(Status, _handle_cmd) + bob.connection.add_command_handler(Status_class, _handle_cmd) alice.eth_api.send_status(payload) result = await asyncio.wait_for(command_fut, timeout=1) - assert isinstance(result, Status) + assert isinstance(result, Status_class) assert_type_equality(payload, result.payload) @@ -196,3 +251,22 @@ async def _handle_cmd(connection, cmd): result = await asyncio.wait_for(command_fut, timeout=1) assert isinstance(result, GetBlockHeaders) assert_type_equality(payload, result.payload) + + +@pytest.mark.asyncio +async def test_handshake_with_incompatible_fork_id(alice_chain, bob_chain): + + alice_chain = build( + alice_chain, + mine_block() + ) + + pair_factory = ETHPeerPairFactory( + alice_peer_context=ChainContextFactory( + headerdb=AsyncHeaderDB(alice_chain.headerdb.db), + vm_configuration=((1, PetersburgVM), (2, MuirGlacierVM)) + ), + ) + with pytest.raises(WrongForkIDFailure): + async with pair_factory as (alice, bob): + pass diff --git a/tests/core/p2p-proto/test_eth_proto.py b/tests/core/p2p-proto/test_eth_proto.py index 9fad8ac740..0f0205b6a2 100644 --- a/tests/core/p2p-proto/test_eth_proto.py +++ b/tests/core/p2p-proto/test_eth_proto.py @@ -16,6 +16,7 @@ Receipts, Status, Transactions, + StatusV63, ) from trinity.tools.factories import ( @@ -32,12 +33,14 @@ NewBlockHashFactory, NewBlockPayloadFactory, StatusPayloadFactory, + StatusV63PayloadFactory, ) @pytest.mark.parametrize( 'command_type,payload', ( + (StatusV63, StatusV63PayloadFactory()), (Status, StatusPayloadFactory()), (NewBlockHashes, tuple(NewBlockHashFactory.create_batch(2))), (Transactions, tuple(BaseTransactionFieldsFactory.create_batch(2))), diff --git a/tests/core/p2p-proto/test_peer.py b/tests/core/p2p-proto/test_peer.py index bc6a3b05bf..9f4dd3a22e 100644 --- a/tests/core/p2p-proto/test_peer.py +++ b/tests/core/p2p-proto/test_peer.py @@ -5,7 +5,7 @@ from p2p.disconnect import DisconnectReason from trinity.protocol.eth.peer import ETHPeer -from trinity.protocol.eth.proto import ETHProtocol +from trinity.protocol.eth.proto import ETHProtocol, ETHProtocolV63 from trinity.protocol.les.peer import LESPeer from trinity.protocol.les.proto import ( LESProtocolV1, @@ -21,6 +21,7 @@ from tests.core.peer_helpers import ( MockPeerPoolWithConnectedPeers, ) +from trinity.tools.factories.eth.proto import ETHV63PeerPairFactory @pytest.mark.asyncio @@ -43,6 +44,16 @@ async def test_LES_v2_peers(): assert isinstance(bob.sub_proto, LESProtocolV2) +@pytest.mark.asyncio +async def test_ETH_v63_peers(): + async with ETHV63PeerPairFactory() as (alice, bob): + assert isinstance(alice, ETHPeer) + assert isinstance(bob, ETHPeer) + + assert isinstance(alice.sub_proto, ETHProtocolV63) + assert isinstance(bob.sub_proto, ETHProtocolV63) + + @pytest.mark.asyncio async def test_ETH_peers(): async with ETHPeerPairFactory() as (alice, bob): diff --git a/tests/integration/test_trinity_cli.py b/tests/integration/test_trinity_cli.py index b5fe590a43..3132f01938 100644 --- a/tests/integration/test_trinity_cli.py +++ b/tests/integration/test_trinity_cli.py @@ -202,7 +202,7 @@ async def test_web3_commands_via_attached_console(command, attached_trinity.expect_exact("'listenAddr': '[::]") attached_trinity.expect_exact("'name': 'Trinity/") attached_trinity.expect_exact("'ports': AttributeDict({") - attached_trinity.expect_exact("'protocols': AttributeDict({'eth': AttributeDict({'version': 'eth/63'") # noqa: E501 + attached_trinity.expect_exact("'protocols': AttributeDict({'eth': AttributeDict({'version': 'eth/64'") # noqa: E501 attached_trinity.expect_exact("'difficulty': ") attached_trinity.expect_exact(f"'genesis': '{expected_genesis_hash}'") attached_trinity.expect_exact("'head': '0x") diff --git a/trinity/exceptions.py b/trinity/exceptions.py index 21eafef08e..d8d0b552eb 100644 --- a/trinity/exceptions.py +++ b/trinity/exceptions.py @@ -77,6 +77,16 @@ class WrongNetworkFailure(HandshakeFailure): register_error(WrongNetworkFailure, BLACKLIST_SECONDS_WRONG_NETWORK_OR_GENESIS) +class WrongForkIDFailure(HandshakeFailure): + """ + Disconnected from the peer because it has an incompatible ForkID + """ + pass + + +register_error(WrongForkIDFailure, BLACKLIST_SECONDS_WRONG_NETWORK_OR_GENESIS) + + class WrongGenesisFailure(HandshakeFailure): """ Disconnected from the peer because it has a different genesis than we do diff --git a/trinity/protocol/common/api.py b/trinity/protocol/common/api.py index 6affb3ea7c..c215921f50 100644 --- a/trinity/protocol/common/api.py +++ b/trinity/protocol/common/api.py @@ -3,17 +3,34 @@ from eth_typing import BlockNumber, Hash32 +from p2p.abc import ConnectionAPI from p2p.logic import Application from p2p.qualifiers import HasProtocol -from trinity.protocol.eth.api import ETHAPI -from trinity.protocol.eth.proto import ETHProtocol +from trinity.protocol.eth.api import ETHV63API, ETHAPI +from trinity.protocol.eth.proto import ETHProtocolV63, ETHProtocol from trinity.protocol.les.api import LESV1API, LESV2API from trinity.protocol.les.proto import LESProtocolV1, LESProtocolV2 from .abc import ChainInfoAPI, HeadInfoAPI -AnyETHLES = HasProtocol(ETHProtocol) | HasProtocol(LESProtocolV2) | HasProtocol(LESProtocolV1) +AnyETHLES = HasProtocol(ETHProtocol) | HasProtocol(ETHProtocolV63) | HasProtocol( + LESProtocolV2) | HasProtocol(LESProtocolV1) + + +def choose_eth_or_les_api( + connection: ConnectionAPI) -> Union[ETHAPI, ETHV63API, LESV1API, LESV2API]: + + if connection.has_protocol(ETHProtocol): + return connection.get_logic(ETHAPI.name, ETHAPI) + elif connection.has_protocol(ETHProtocolV63): + return connection.get_logic(ETHV63API.name, ETHV63API) + elif connection.has_protocol(LESProtocolV2): + return connection.get_logic(LESV2API.name, LESV2API) + elif connection.has_protocol(LESProtocolV1): + return connection.get_logic(LESV1API.name, LESV1API) + else: + raise Exception("Unreachable code path") class ChainInfo(Application, ChainInfoAPI): @@ -29,15 +46,8 @@ def network_id(self) -> int: def genesis_hash(self) -> Hash32: return self._get_logic().genesis_hash - def _get_logic(self) -> Union[ETHAPI, LESV1API, LESV2API]: - if self.connection.has_protocol(ETHProtocol): - return self.connection.get_logic(ETHAPI.name, ETHAPI) - elif self.connection.has_protocol(LESProtocolV2): - return self.connection.get_logic(LESV2API.name, LESV2API) - elif self.connection.has_protocol(LESProtocolV1): - return self.connection.get_logic(LESV1API.name, LESV1API) - else: - raise Exception("Unreachable code path") + def _get_logic(self) -> Union[ETHAPI, ETHV63API, LESV1API, LESV2API]: + return choose_eth_or_les_api(self.connection) class HeadInfo(Application, HeadInfoAPI): @@ -47,17 +57,8 @@ class HeadInfo(Application, HeadInfoAPI): @cached_property def _tracker(self) -> HeadInfoAPI: - if self.connection.has_protocol(ETHProtocol): - eth_logic = self.connection.get_logic(ETHAPI.name, ETHAPI) - return eth_logic.head_info - elif self.connection.has_protocol(LESProtocolV2): - les_v2_logic = self.connection.get_logic(LESV2API.name, LESV2API) - return les_v2_logic.head_info - elif self.connection.has_protocol(LESProtocolV1): - les_v1_logic = self.connection.get_logic(LESV1API.name, LESV1API) - return les_v1_logic.head_info - else: - raise Exception("Unreachable code path") + api = choose_eth_or_les_api(self.connection) + return api.head_info @property def head_td(self) -> int: diff --git a/trinity/protocol/common/peer.py b/trinity/protocol/common/peer.py index 6bb5180006..fc7a483267 100644 --- a/trinity/protocol/common/peer.py +++ b/trinity/protocol/common/peer.py @@ -62,15 +62,14 @@ from trinity.constants import TO_NETWORKING_BROADCAST_CONFIG from trinity.exceptions import BaseForkIDValidationError from trinity.protocol.common.abc import ChainInfoAPI, HeadInfoAPI -from trinity.protocol.common.api import ChainInfo, HeadInfo -from trinity.protocol.eth.api import ETHAPI +from trinity.protocol.common.api import ChainInfo, HeadInfo, choose_eth_or_les_api +from trinity.protocol.eth.api import ETHV63API, ETHAPI from trinity.protocol.eth.forkid import ( extract_fork_blocks, extract_forkid, validate_forkid, ) from trinity.protocol.les.api import LESV1API, LESV2API -from trinity.protocol.les.proto import LESProtocolV1, LESProtocolV2 from trinity.components.builtin.network_db.connection.tracker import ConnectionTrackerClient from trinity.components.builtin.network_db.eth1_peer_db.tracker import ( @@ -94,18 +93,8 @@ class BaseChainPeer(BasePeer): context: ChainContext @cached_property - def chain_api(self) -> Union[ETHAPI, LESV1API, LESV2API]: - if self.connection.has_logic(ETHAPI.name): - return self.connection.get_logic(ETHAPI.name, ETHAPI) - elif self.connection.has_logic(LESV1API.name): - if self.connection.has_protocol(LESProtocolV2): - return self.connection.get_logic(LESV2API.name, LESV2API) - elif self.connection.has_protocol(LESProtocolV1): - return self.connection.get_logic(LESV1API.name, LESV1API) - else: - raise Exception("Should be unreachable") - else: - raise Exception("Should be unreachable") + def chain_api(self) -> Union[ETHAPI, ETHV63API, LESV1API, LESV2API]: + return choose_eth_or_les_api(self.connection) @cached_property def head_info(self) -> HeadInfoAPI: diff --git a/trinity/protocol/eth/api.py b/trinity/protocol/eth/api.py index e998610740..25d5f83fbd 100644 --- a/trinity/protocol/eth/api.py +++ b/trinity/protocol/eth/api.py @@ -1,4 +1,5 @@ -from typing import Any, Sequence, Tuple, Union +from abc import abstractmethod +from typing import Any, Sequence, Tuple, Union, Generic, Type, TypeVar from cached_property import cached_property @@ -11,7 +12,7 @@ SignedTransactionAPI, ) -from p2p.abc import ConnectionAPI +from p2p.abc import ConnectionAPI, ProtocolAPI from p2p.exchange import ExchangeAPI, ExchangeLogic from p2p.logic import Application, CommandHandler from p2p.qualifiers import HasProtocol @@ -29,9 +30,9 @@ NewBlockHashes, NodeData, Receipts, - Status, + StatusV63, Transactions, -) + Status) from trinity.rlp.block_body import BlockBody from .exchanges import ( @@ -40,23 +41,28 @@ GetNodeDataExchange, GetReceiptsExchange, ) -from .handshaker import ETHHandshakeReceipt +from .handshaker import ETHV63HandshakeReceipt, ETHHandshakeReceipt, BaseETHHandshakeReceipt from .payloads import ( BlockFields, NewBlockHash, NewBlockPayload, + StatusV63Payload, StatusPayload, ) -from .proto import ETHProtocol +from .proto import ETHProtocolV63, ETHProtocol + +THandshakeReceipt = TypeVar("THandshakeReceipt", bound=BaseETHHandshakeReceipt[Any]) -class HeadInfoTracker(CommandHandler[NewBlock], HeadInfoAPI): +class BaseHeadInfoTracker(CommandHandler[NewBlock], HeadInfoAPI, Generic[THandshakeReceipt]): command_type = NewBlock _head_td: int = None _head_hash: Hash32 = None _head_number: BlockNumber = None + _receipt_type: Type[THandshakeReceipt] + async def handle(self, connection: ConnectionAPI, cmd: NewBlock) -> None: header = cmd.payload.block.header actual_td = cmd.payload.total_difficulty - header.difficulty @@ -70,8 +76,8 @@ async def handle(self, connection: ConnectionAPI, cmd: NewBlock) -> None: # HeadInfoAPI # @cached_property - def _eth_receipt(self) -> ETHHandshakeReceipt: - return self.connection.get_receipt_by_type(ETHHandshakeReceipt) + def _eth_receipt(self) -> THandshakeReceipt: + return self.connection.get_receipt_by_type(self._receipt_type) @property def head_td(self) -> int: @@ -93,11 +99,19 @@ def head_number(self) -> BlockNumber: return self._head_number -class ETHAPI(Application): - name = 'eth' - qualifier = HasProtocol(ETHProtocol) +class ETHV63HeadInfoTracker(BaseHeadInfoTracker[ETHV63HandshakeReceipt]): + + _receipt_type = ETHV63HandshakeReceipt - head_info: HeadInfoTracker + +class ETHHeadInfoTracker(BaseHeadInfoTracker[ETHHandshakeReceipt]): + + _receipt_type = ETHHandshakeReceipt + + +class BaseETHAPI(Application): + name = 'eth' + head_info_tracker_cls = BaseHeadInfoTracker[THandshakeReceipt] get_block_bodies: GetBlockBodiesExchange get_block_headers: GetBlockHeadersExchange @@ -105,7 +119,7 @@ class ETHAPI(Application): get_receipts: GetReceiptsExchange def __init__(self) -> None: - self.head_info = HeadInfoTracker() + self.head_info = self.head_info_tracker_cls() self.add_child_behavior(self.head_info.as_behavior()) # Request/Response API @@ -119,6 +133,16 @@ def __init__(self) -> None: self.add_child_behavior(ExchangeLogic(self.get_node_data).as_behavior()) self.add_child_behavior(ExchangeLogic(self.get_receipts).as_behavior()) + @property + @abstractmethod + def protocol(self) -> ProtocolAPI: + ... + + @property + @abstractmethod + def receipt(self) -> BaseETHHandshakeReceipt[Any]: + ... + @cached_property def exchanges(self) -> Tuple[ExchangeAPI[Any, Any, Any], ...]: return ( @@ -134,14 +158,6 @@ def get_extra_stats(self) -> Tuple[str, ...]: for exchange in self.exchanges ) - @cached_property - def protocol(self) -> ETHProtocol: - return self.connection.get_protocol_by_type(ETHProtocol) - - @cached_property - def receipt(self) -> ETHHandshakeReceipt: - return self.connection.get_receipt_by_type(ETHHandshakeReceipt) - @cached_property def network_id(self) -> int: return self.receipt.network_id @@ -150,9 +166,6 @@ def network_id(self) -> int: def genesis_hash(self) -> Hash32: return self.receipt.genesis_hash - def send_status(self, payload: StatusPayload) -> None: - self.protocol.send(Status(payload)) - def send_get_node_data(self, node_hashes: Sequence[Hash32]) -> None: self.protocol.send(GetNodeData(tuple(node_hashes))) @@ -204,3 +217,38 @@ def send_new_block(self, block_fields = BlockFields(block.header, block.transactions, block.uncles) payload = NewBlockPayload(block_fields, total_difficulty) self.protocol.send(NewBlock(payload)) + + +class ETHV63API(BaseETHAPI): + qualifier = HasProtocol(ETHProtocolV63) + head_info_tracker_cls = ETHV63HeadInfoTracker + + @cached_property + def protocol(self) -> ETHProtocolV63: + return self.connection.get_protocol_by_type(ETHProtocolV63) + + @cached_property + def receipt(self) -> ETHV63HandshakeReceipt: + return self.connection.get_receipt_by_type(ETHV63HandshakeReceipt) + + def send_status(self, payload: StatusV63Payload) -> None: + self.protocol.send(StatusV63(payload)) + + +class ETHAPI(BaseETHAPI): + qualifier = HasProtocol(ETHProtocol) + head_info_tracker_cls = ETHHeadInfoTracker + + @cached_property + def protocol(self) -> ETHProtocol: + return self.connection.get_protocol_by_type(ETHProtocol) + + @cached_property + def receipt(self) -> ETHHandshakeReceipt: + return self.connection.get_receipt_by_type(ETHHandshakeReceipt) + + def send_status(self, payload: StatusPayload) -> None: + self.protocol.send(Status(payload)) + + +AnyETHAPI = Union[ETHV63API, ETHAPI] diff --git a/trinity/protocol/eth/commands.py b/trinity/protocol/eth/commands.py index bee76f7ed9..278326f053 100644 --- a/trinity/protocol/eth/commands.py +++ b/trinity/protocol/eth/commands.py @@ -18,21 +18,43 @@ from trinity.protocol.common.payloads import BlockHeadersQuery from trinity.rlp.block_body import BlockBody from trinity.rlp.sedes import HashOrNumber, hash_sedes +from .forkid import ForkID from .payloads import ( - StatusPayload, + StatusV63Payload, NewBlockHash, BlockFields, NewBlockPayload, + StatusPayload, ) +STATUS_V63_STRUCTURE = sedes.List(( + sedes.big_endian_int, + sedes.big_endian_int, + sedes.big_endian_int, + hash_sedes, + hash_sedes, +)) + + +class StatusV63(BaseCommand[StatusV63Payload]): + protocol_command_id = 0 + serialization_codec: RLPCodec[StatusV63Payload] = RLPCodec( + sedes=STATUS_V63_STRUCTURE, + process_inbound_payload_fn=compose( + lambda args: StatusV63Payload(*args), + ), + ) + + STATUS_STRUCTURE = sedes.List(( sedes.big_endian_int, sedes.big_endian_int, sedes.big_endian_int, hash_sedes, hash_sedes, + ForkID )) diff --git a/trinity/protocol/eth/handshaker.py b/trinity/protocol/eth/handshaker.py index d0e5e33e31..1763e3721b 100644 --- a/trinity/protocol/eth/handshaker.py +++ b/trinity/protocol/eth/handshaker.py @@ -1,28 +1,38 @@ -from typing import cast +from typing import Union, TypeVar, Generic, Tuple from cached_property import cached_property -from eth_typing import Hash32 +from eth_typing import Hash32, BlockNumber from eth_utils import encode_hex -from p2p.abc import MultiplexerAPI, ProtocolAPI +from p2p.abc import MultiplexerAPI, ProtocolAPI, NodeAPI from p2p.exceptions import ( HandshakeFailure, ) from p2p.handshake import Handshaker from p2p.receipt import HandshakeReceipt -from trinity.exceptions import WrongGenesisFailure, WrongNetworkFailure +from trinity.exceptions import ( + WrongForkIDFailure, + WrongGenesisFailure, + WrongNetworkFailure, + BaseForkIDValidationError, +) + + +from .commands import StatusV63, Status +from .forkid import ForkID, validate_forkid +from .payloads import StatusV63Payload, StatusPayload +from .proto import ETHProtocolV63, ETHProtocol -from .commands import Status -from .payloads import StatusPayload -from .proto import ETHProtocol +THandshakeParams = TypeVar("THandshakeParams", bound=Union[StatusPayload, StatusV63Payload]) -class ETHHandshakeReceipt(HandshakeReceipt): - handshake_params: StatusPayload - def __init__(self, protocol: ETHProtocol, handshake_params: StatusPayload) -> None: +class BaseETHHandshakeReceipt(HandshakeReceipt, Generic[THandshakeParams]): + handshake_params: THandshakeParams + + def __init__(self, protocol: ProtocolAPI, handshake_params: THandshakeParams) -> None: super().__init__(protocol) self.handshake_params = handshake_params @@ -47,49 +57,110 @@ def version(self) -> int: return self.handshake_params.version -class ETHHandshaker(Handshaker): +class ETHV63HandshakeReceipt(BaseETHHandshakeReceipt[StatusV63Payload]): + pass + + +class ETHHandshakeReceipt(BaseETHHandshakeReceipt[StatusPayload]): + + @cached_property + def fork_id(self) -> ForkID: + return self.handshake_params.fork_id + + +def validate_base_receipt(remote: NodeAPI, + receipt: Union[ETHV63HandshakeReceipt, ETHHandshakeReceipt], + handshake_params: Union[StatusV63Payload, StatusPayload]) -> None: + if receipt.handshake_params.network_id != handshake_params.network_id: + raise WrongNetworkFailure( + f"{remote} network " + f"({receipt.handshake_params.network_id}) does not match ours " + f"({handshake_params.network_id}), disconnecting" + ) + + if receipt.handshake_params.genesis_hash != handshake_params.genesis_hash: + raise WrongGenesisFailure( + f"{remote} genesis " + f"({encode_hex(receipt.handshake_params.genesis_hash)}) does " + f"not match ours ({encode_hex(handshake_params.genesis_hash)}), " + f"disconnecting" + ) + + +class ETHV63Handshaker(Handshaker[ETHProtocolV63]): + protocol_class = ETHProtocolV63 + + def __init__(self, handshake_params: StatusV63Payload) -> None: + self.handshake_params = handshake_params + + async def do_handshake(self, + multiplexer: MultiplexerAPI, + protocol: ETHProtocolV63) -> ETHV63HandshakeReceipt: + """ + Perform the handshake for the sub-protocol agreed with the remote peer. + + Raise HandshakeFailure if the handshake is not successful. + """ + + protocol.send(StatusV63(self.handshake_params)) + + async for cmd in multiplexer.stream_protocol_messages(protocol): + if not isinstance(cmd, StatusV63): + raise HandshakeFailure(f"Expected a ETH Status msg, got {cmd}, disconnecting") + + receipt = ETHV63HandshakeReceipt(protocol, cmd.payload) + + validate_base_receipt(multiplexer.remote, receipt, self.handshake_params) + + break + else: + raise HandshakeFailure("Message stream exited before finishing handshake") + + return receipt + + +class ETHHandshaker(Handshaker[ETHProtocol]): protocol_class = ETHProtocol - handshake_params: StatusPayload - def __init__(self, handshake_params: StatusPayload) -> None: + def __init__(self, + handshake_params: StatusPayload, + head_number: BlockNumber, + fork_blocks: Tuple[BlockNumber, ...]) -> None: self.handshake_params = handshake_params + self.head_number = head_number + self.fork_blocks = fork_blocks async def do_handshake(self, multiplexer: MultiplexerAPI, - protocol: ProtocolAPI) -> ETHHandshakeReceipt: - """Perform the handshake for the sub-protocol agreed with the remote peer. + protocol: ETHProtocol) -> ETHHandshakeReceipt: + """ + Perform the handshake for the sub-protocol agreed with the remote peer. - Raises HandshakeFailure if the handshake is not successful. + Raise HandshakeFailure if the handshake is not successful. """ - protocol = cast(ETHProtocol, protocol) + protocol.send(Status(self.handshake_params)) async for cmd in multiplexer.stream_protocol_messages(protocol): if not isinstance(cmd, Status): raise HandshakeFailure(f"Expected a ETH Status msg, got {cmd}, disconnecting") - remote_params = StatusPayload( - version=cmd.payload.version, - network_id=cmd.payload.network_id, - total_difficulty=cmd.payload.total_difficulty, - head_hash=cmd.payload.head_hash, - genesis_hash=cmd.payload.genesis_hash, - ) - receipt = ETHHandshakeReceipt(protocol, remote_params) - - if receipt.handshake_params.network_id != self.handshake_params.network_id: - raise WrongNetworkFailure( - f"{multiplexer.remote} network " - f"({receipt.handshake_params.network_id}) does not match ours " - f"({self.handshake_params.network_id}), disconnecting" - ) + receipt = ETHHandshakeReceipt(protocol, cmd.payload) - if receipt.handshake_params.genesis_hash != self.handshake_params.genesis_hash: - raise WrongGenesisFailure( - f"{multiplexer.remote} genesis " - f"({encode_hex(receipt.handshake_params.genesis_hash)}) does " - f"not match ours ({encode_hex(self.handshake_params.genesis_hash)}), " - f"disconnecting" + validate_base_receipt(multiplexer.remote, receipt, self.handshake_params) + + try: + validate_forkid( + receipt.fork_id, + self.handshake_params.genesis_hash, + self.head_number, + self.fork_blocks, + ) + except BaseForkIDValidationError as exc: + raise WrongForkIDFailure( + f"{multiplexer.remote} forkid " + f"({receipt.handshake_params.fork_id}) is incompatible to ours ({exc})" + f"({self.handshake_params.fork_id}), disconnecting" ) break diff --git a/trinity/protocol/eth/payloads.py b/trinity/protocol/eth/payloads.py index 4080d692d5..9e3cb001aa 100644 --- a/trinity/protocol/eth/payloads.py +++ b/trinity/protocol/eth/payloads.py @@ -4,6 +4,16 @@ from eth.abc import BlockHeaderAPI, TransactionFieldsAPI +from trinity.protocol.eth.forkid import ForkID + + +class StatusV63Payload(NamedTuple): + version: int + network_id: int + total_difficulty: int + head_hash: Hash32 + genesis_hash: Hash32 + class StatusPayload(NamedTuple): version: int @@ -11,6 +21,7 @@ class StatusPayload(NamedTuple): total_difficulty: int head_hash: Hash32 genesis_hash: Hash32 + fork_id: ForkID class NewBlockHash(NamedTuple): diff --git a/trinity/protocol/eth/peer.py b/trinity/protocol/eth/peer.py index 4a7e69b26c..be7ab19bf7 100644 --- a/trinity/protocol/eth/peer.py +++ b/trinity/protocol/eth/peer.py @@ -37,8 +37,9 @@ ReceiptsBundles, NodeDataBundles, ) +from . import forkid -from .api import ETHAPI +from .api import ETHV63API, ETHAPI, AnyETHAPI from .commands import ( GetBlockHeaders, GetBlockBodies, @@ -66,24 +67,29 @@ NewBlockHashesEvent, TransactionsEvent, ) -from .payloads import StatusPayload -from .proto import ETHProtocol +from .payloads import StatusV63Payload, StatusPayload +from .proto import ETHProtocolV63, ETHProtocol from .proxy import ProxyETHAPI -from .handshaker import ETHHandshaker +from .handshaker import ETHV63Handshaker, ETHHandshaker class ETHPeer(BaseChainPeer): max_headers_fetch = MAX_HEADERS_FETCH - supported_sub_protocols = (ETHProtocol,) + supported_sub_protocols = (ETHProtocolV63, ETHProtocol) sub_proto: ETHProtocol = None def get_behaviors(self) -> Tuple[BehaviorAPI, ...]: - return super().get_behaviors() + (ETHAPI().as_behavior(),) + return super().get_behaviors() + (ETHV63API().as_behavior(), ETHAPI().as_behavior()) @cached_property - def eth_api(self) -> ETHAPI: - return self.connection.get_logic(ETHAPI.name, ETHAPI) + def eth_api(self) -> AnyETHAPI: + if self.connection.has_protocol(ETHProtocolV63): + return self.connection.get_logic(ETHV63API.name, ETHV63API) + elif self.connection.has_protocol(ETHProtocol): + return self.connection.get_logic(ETHAPI.name, ETHAPI) + else: + raise Exception("Unreachable code") def get_extra_stats(self) -> Tuple[str, ...]: basic_stats = super().get_extra_stats() @@ -94,7 +100,7 @@ def get_extra_stats(self) -> Tuple[str, ...]: class ETHProxyPeer(BaseProxyPeer): """ A ``ETHPeer`` that can be used from any process instead of the actual peer pool peer. - Any action performed on the ``BCCProxyPeer`` is delegated to the actual peer in the pool. + Any action performed on the ``ETHProxyPeer`` is delegated to the actual peer in the pool. This does not yet mimic all APIs of the real peer. """ @@ -121,7 +127,7 @@ def from_session(cls, class ETHPeerFactory(BaseChainPeerFactory): peer_class = ETHPeer - async def get_handshakers(self) -> Tuple[HandshakerAPI, ...]: + async def get_handshakers(self) -> Tuple[HandshakerAPI[Any], ...]: headerdb = self.context.headerdb wait = self.cancel_token.cancellable_wait @@ -131,15 +137,29 @@ async def get_handshakers(self) -> Tuple[HandshakerAPI, ...]: headerdb.coro_get_canonical_block_hash(BlockNumber(GENESIS_BLOCK_NUMBER)) ) + handshake_v63_params = StatusV63Payload( + head_hash=head.hash, + total_difficulty=total_difficulty, + genesis_hash=genesis_hash, + network_id=self.context.network_id, + version=ETHProtocolV63.version, + ) + + fork_blocks = forkid.extract_fork_blocks(self.context.vm_configuration) + our_forkid = forkid.make_forkid(genesis_hash, head.block_number, fork_blocks) + handshake_params = StatusPayload( head_hash=head.hash, total_difficulty=total_difficulty, genesis_hash=genesis_hash, network_id=self.context.network_id, version=ETHProtocol.version, + fork_id=our_forkid ) + return ( - ETHHandshaker(handshake_params), + ETHV63Handshaker(handshake_v63_params), + ETHHandshaker(handshake_params, head.block_number, fork_blocks), ) diff --git a/trinity/protocol/eth/proto.py b/trinity/protocol/eth/proto.py index 3afc1cd507..e4846fdaac 100644 --- a/trinity/protocol/eth/proto.py +++ b/trinity/protocol/eth/proto.py @@ -1,5 +1,7 @@ from typing import ( TYPE_CHECKING, + Union, + Type, ) from eth_utils import ( @@ -7,7 +9,6 @@ ) from p2p.protocol import BaseProtocol - from .commands import ( BlockBodies, BlockHeaders, @@ -19,17 +20,40 @@ NewBlockHashes, NodeData, Receipts, - Status, Transactions, + StatusV63, + Status, ) if TYPE_CHECKING: from .peer import ETHPeer # noqa: F401 -class ETHProtocol(BaseProtocol): +class BaseETHProtocol(BaseProtocol): name = 'eth' + status_command_type: Union[Type[StatusV63], Type[Status]] + + +class ETHProtocolV63(BaseETHProtocol): version = 63 + commands = ( + StatusV63, + NewBlockHashes, + Transactions, + GetBlockHeaders, BlockHeaders, + GetBlockBodies, BlockBodies, + NewBlock, + GetNodeData, NodeData, + GetReceipts, Receipts, + ) + command_length = 17 + + logger = get_extended_debug_logger('trinity.protocol.eth.proto.ETHProtocolV63') + status_command_type = StatusV63 + + +class ETHProtocol(BaseETHProtocol): + version = 64 commands = ( Status, NewBlockHashes, @@ -43,3 +67,4 @@ class ETHProtocol(BaseProtocol): command_length = 17 logger = get_extended_debug_logger('trinity.protocol.eth.proto.ETHProtocol') + status_command_type = Status diff --git a/trinity/protocol/les/handshaker.py b/trinity/protocol/les/handshaker.py index 4b70386291..e3f10f3c12 100644 --- a/trinity/protocol/les/handshaker.py +++ b/trinity/protocol/les/handshaker.py @@ -1,5 +1,4 @@ from typing import ( - cast, Type, Union, ) @@ -9,7 +8,7 @@ from eth_typing import BlockNumber, Hash32 from eth_utils import encode_hex -from p2p.abc import MultiplexerAPI, ProtocolAPI +from p2p.abc import MultiplexerAPI from p2p.exceptions import ( HandshakeFailure, ) @@ -57,7 +56,7 @@ def genesis_hash(self) -> Hash32: return self.handshake_params.genesis_hash -class BaseLESHandshaker(Handshaker): +class BaseLESHandshaker(Handshaker[Union[LESProtocolV1, LESProtocolV2]]): handshake_params: StatusPayload def __init__(self, handshake_params: StatusPayload) -> None: @@ -67,12 +66,12 @@ def __init__(self, handshake_params: StatusPayload) -> None: async def do_handshake(self, multiplexer: MultiplexerAPI, - protocol: ProtocolAPI) -> LESHandshakeReceipt: + protocol: Union[LESProtocolV1, LESProtocolV2]) -> LESHandshakeReceipt: """Perform the handshake for the sub-protocol agreed with the remote peer. Raises HandshakeFailure if the handshake is not successful. """ - protocol = cast(AnyLESProtocol, protocol) + protocol.send(protocol.status_command_type(self.handshake_params)) async for cmd in multiplexer.stream_protocol_messages(protocol): diff --git a/trinity/protocol/les/peer.py b/trinity/protocol/les/peer.py index 19c06c2bb3..1582fe4abd 100644 --- a/trinity/protocol/les/peer.py +++ b/trinity/protocol/les/peer.py @@ -114,7 +114,7 @@ def from_session(cls, class LESPeerFactory(BaseChainPeerFactory): peer_class = LESPeer - async def get_handshakers(self) -> Tuple[HandshakerAPI, ...]: + async def get_handshakers(self) -> Tuple[HandshakerAPI[Any], ...]: headerdb = self.context.headerdb wait = self.cancel_token.cancellable_wait diff --git a/trinity/tools/factories/__init__.py b/trinity/tools/factories/__init__.py index e2aeba8548..dc8407250c 100644 --- a/trinity/tools/factories/__init__.py +++ b/trinity/tools/factories/__init__.py @@ -16,6 +16,7 @@ from .eth.proto import ( # noqa: F401 ETHHandshakerFactory, ETHPeerPairFactory, + ETHV63PeerPairFactory, ) from .headers import BlockHeaderFactory # noqa: F401 from .receipts import ReceiptFactory # noqa: F401 diff --git a/trinity/tools/factories/eth/__init__.py b/trinity/tools/factories/eth/__init__.py index 4401160cd7..324337f0e4 100644 --- a/trinity/tools/factories/eth/__init__.py +++ b/trinity/tools/factories/eth/__init__.py @@ -1,4 +1,5 @@ from .payloads import ( # noqa: F401 + StatusV63PayloadFactory, StatusPayloadFactory, NewBlockHashFactory, NewBlockPayloadFactory, diff --git a/trinity/tools/factories/eth/payloads.py b/trinity/tools/factories/eth/payloads.py index 9940d728bc..df187f4073 100644 --- a/trinity/tools/factories/eth/payloads.py +++ b/trinity/tools/factories/eth/payloads.py @@ -1,3 +1,7 @@ +from eth_typing import BlockNumber +from eth_utils import to_bytes + +from trinity.protocol.eth.forkid import ForkID try: import factory @@ -13,18 +17,42 @@ from trinity.constants import MAINNET_NETWORK_ID from trinity.protocol.eth.payloads import ( - StatusPayload, + StatusV63Payload, NewBlockHash, NewBlockPayload, BlockFields, + StatusPayload, ) -from trinity.protocol.eth.proto import ETHProtocol +from trinity.protocol.eth.proto import ETHProtocolV63, ETHProtocol from trinity.tools.factories.block_hash import BlockHashFactory from trinity.tools.factories.headers import BlockHeaderFactory from trinity.tools.factories.transactions import BaseTransactionFieldsFactory +class StatusV63PayloadFactory(factory.Factory): + class Meta: + model = StatusV63Payload + + version = ETHProtocolV63.version + network_id = MAINNET_NETWORK_ID + total_difficulty = 1 + head_hash = factory.SubFactory(BlockHashFactory) + genesis_hash = factory.SubFactory(BlockHashFactory) + + @classmethod + def from_headerdb(cls, headerdb: HeaderDatabaseAPI, **kwargs: Any) -> StatusV63Payload: + head = headerdb.get_canonical_head() + head_score = headerdb.get_score(head.hash) + genesis = headerdb.get_canonical_block_header_by_number(GENESIS_BLOCK_NUMBER) + return cls( + head_hash=head.hash, + genesis_hash=genesis.hash, + total_difficulty=head_score, + **kwargs + ) + + class StatusPayloadFactory(factory.Factory): class Meta: model = StatusPayload @@ -34,6 +62,7 @@ class Meta: total_difficulty = 1 head_hash = factory.SubFactory(BlockHashFactory) genesis_hash = factory.SubFactory(BlockHashFactory) + fork_id = ForkID(to_bytes(hexstr='0xfc64ec04'), BlockNumber(1150000)) # unsynced @classmethod def from_headerdb(cls, headerdb: HeaderDatabaseAPI, **kwargs: Any) -> StatusPayload: @@ -44,6 +73,7 @@ def from_headerdb(cls, headerdb: HeaderDatabaseAPI, **kwargs: Any) -> StatusPayl head_hash=head.hash, genesis_hash=genesis.hash, total_difficulty=head_score, + fork_id=cls.fork_id, **kwargs ) diff --git a/trinity/tools/factories/eth/proto.py b/trinity/tools/factories/eth/proto.py index 1dc39089ca..cdf6c2bed6 100644 --- a/trinity/tools/factories/eth/proto.py +++ b/trinity/tools/factories/eth/proto.py @@ -1,3 +1,6 @@ +from p2p.abc import HandshakerAPI +from trinity.protocol.eth.proto import ETHProtocolV63 + try: import factory except ImportError: @@ -7,6 +10,7 @@ cast, AsyncContextManager, Tuple, + Any, ) from lahja import EndpointAPI @@ -38,11 +42,65 @@ class Meta: handshake_params = factory.SubFactory(StatusPayloadFactory) +class ETHV63Peer(ETHPeer): + supported_sub_protocols = (ETHProtocolV63,) # type: ignore + + +class ETHV63PeerFactory(ETHPeerFactory): + peer_class = ETHV63Peer + + async def get_handshakers(self) -> Tuple[HandshakerAPI[Any], ...]: + return tuple( + shaker for shaker in await super().get_handshakers() + # mypy doesn't know these have a `handshake_params` property + if shaker.handshake_params.version == ETHProtocolV63.version # type: ignore + ) + + +def ETHV63PeerPairFactory(*, + alice_peer_context: ChainContext = None, + alice_remote: kademlia.Node = None, + alice_private_key: keys.PrivateKey = None, + alice_client_version: str = 'alice', + bob_peer_context: ChainContext = None, + bob_remote: kademlia.Node = None, + bob_private_key: keys.PrivateKey = None, + bob_client_version: str = 'bob', + cancel_token: CancelToken = None, + event_bus: EndpointAPI = None, + ) -> AsyncContextManager[Tuple[ETHPeer, ETHPeer]]: + if alice_peer_context is None: + alice_peer_context = ChainContextFactory() + + if bob_peer_context is None: + alice_genesis = alice_peer_context.headerdb.get_canonical_block_header_by_number( + BlockNumber(GENESIS_BLOCK_NUMBER), + ) + bob_peer_context = ChainContextFactory( + headerdb__genesis_params={'timestamp': alice_genesis.timestamp}, + ) + + return cast(AsyncContextManager[Tuple[ETHPeer, ETHPeer]], PeerPairFactory( + alice_peer_context=alice_peer_context, + alice_peer_factory_class=ETHV63PeerFactory, + bob_peer_context=bob_peer_context, + bob_peer_factory_class=ETHV63PeerFactory, + alice_remote=alice_remote, + alice_private_key=alice_private_key, + alice_client_version=alice_client_version, + bob_remote=bob_remote, + bob_private_key=bob_private_key, + bob_client_version=bob_client_version, + cancel_token=cancel_token, + event_bus=event_bus, + )) + + def ETHPeerPairFactory(*, alice_peer_context: ChainContext = None, alice_remote: kademlia.Node = None, alice_private_key: keys.PrivateKey = None, - alice_client_version: str = 'bob', + alice_client_version: str = 'alice', bob_peer_context: ChainContext = None, bob_remote: kademlia.Node = None, bob_private_key: keys.PrivateKey = None, diff --git a/trinity/tools/factories/les/proto.py b/trinity/tools/factories/les/proto.py index f1abc9c229..b2d8b4bada 100644 --- a/trinity/tools/factories/les/proto.py +++ b/trinity/tools/factories/les/proto.py @@ -7,6 +7,7 @@ cast, AsyncContextManager, Tuple, + Any, ) from lahja import EndpointAPI @@ -54,7 +55,7 @@ class LESV1Peer(LESPeer): class LESV1PeerFactory(LESPeerFactory): peer_class = LESV1Peer - async def get_handshakers(self) -> Tuple[HandshakerAPI, ...]: + async def get_handshakers(self) -> Tuple[HandshakerAPI[Any], ...]: return tuple(filter( # mypy doesn't know these have a `handshake_params` property lambda handshaker: handshaker.handshake_params.version == 1, # type: ignore