diff --git a/chia/_tests/wallet/test_wallet_node.py b/chia/_tests/wallet/test_wallet_node.py index c8d32cedae5b..f9287ab2acac 100644 --- a/chia/_tests/wallet/test_wallet_node.py +++ b/chia/_tests/wallet/test_wallet_node.py @@ -5,7 +5,7 @@ import time import types from pathlib import Path -from typing import List, Optional +from typing import Any, List, Optional import pytest @@ -17,6 +17,7 @@ from chia.protocols.wallet_protocol import CoinState from chia.server.outbound_message import Message, make_msg from chia.simulator.block_tools import test_constants +from chia.types.blockchain_format.coin import Coin from chia.types.blockchain_format.sized_bytes import bytes32 from chia.types.full_block import FullBlock from chia.types.mempool_inclusion_status import MempoolInclusionStatus @@ -28,6 +29,7 @@ from chia.util.keychain import Keychain, KeyData, generate_mnemonic from chia.util.misc import to_batches from chia.wallet.util.tx_config import DEFAULT_TX_CONFIG +from chia.wallet.util.wallet_sync_utils import PeerRequestException from chia.wallet.wallet_node import Balance, WalletNode @@ -578,3 +580,43 @@ def check_wallet_cache_empty() -> bool: # Disconnect from the peer to make sure their entry in the cache is also deleted await simulator_and_wallet[1][0][0]._server.get_connections()[0].close(120) await time_out_assert(5, check_wallet_cache_empty, True) + + +@pytest.mark.limit_consensus_modes(reason="consensus rules irrelevant") +@pytest.mark.anyio +async def test_wallet_node_bad_coin_state_ignore( + self_hostname: str, simulator_and_wallet: OldSimulatorsAndWallets, monkeypatch: pytest.MonkeyPatch +) -> None: + [full_node_api], [(wallet_node, wallet_server)], _ = simulator_and_wallet + + await wallet_server.start_client(PeerInfo(self_hostname, full_node_api.server.get_port()), None) + + @api_request() + async def register_interest_in_coin( + self: Self, request: wallet_protocol.RegisterForCoinUpdates, *, test: bool = False + ) -> Optional[Message]: + return make_msg( + ProtocolMessageTypes.respond_to_coin_update, + wallet_protocol.RespondToCoinUpdates( + [], uint32(0), [CoinState(Coin(bytes32([0] * 32), bytes32([0] * 32), uint64(0)), uint32(0), uint32(0))] + ), + ) + + async def validate_received_state_from_peer(*args: Any) -> bool: + # It's an interesting case here where we don't hit this unless something is broken + return True # pragma: no cover + + assert full_node_api.full_node._server is not None + monkeypatch.setattr( + full_node_api.full_node._server.get_connections()[0].api, + "register_interest_in_coin", + types.MethodType(register_interest_in_coin, full_node_api.full_node._server.get_connections()[0].api), + ) + monkeypatch.setattr( + wallet_node, + "validate_received_state_from_peer", + types.MethodType(validate_received_state_from_peer, wallet_node), + ) + + with pytest.raises(PeerRequestException): + await wallet_node.get_coin_state([], wallet_node.get_full_node_peer()) diff --git a/chia/wallet/wallet_node.py b/chia/wallet/wallet_node.py index 23db895ae7a7..a7a2f94fc1e5 100644 --- a/chia/wallet/wallet_node.py +++ b/chia/wallet/wallet_node.py @@ -1634,6 +1634,10 @@ async def get_coin_state( if not self.is_trusted(peer): valid_list = [] for coin in coin_state.coin_states: + if coin.coin.name() not in coin_names: + await peer.close(9999) + self.log.warning(f"Peer {peer.peer_node_id} sent us an unrequested coin state. Banning.") + raise PeerRequestException(f"Peer sent us unrequested coin state {coin}") valid = await self.validate_received_state_from_peer( coin, peer, self.get_cache_for_peer(peer), fork_height )