From de75f058dd2650508e29e460d275cd35f9e03f2a Mon Sep 17 00:00:00 2001 From: Rigidity <35380458+Rigidity@users.noreply.github.com> Date: Thu, 6 Jun 2024 05:59:44 -0400 Subject: [PATCH] CHIA-194: CHIP-0026 Mempool Updates (#17980) * Initial draft of mempool updates [skip ci] * More implementation work, but capability issue [skip ci] * Temp (will probably revert) * Get first test passing * Setup spent coin tests * Add unrelated test data to make tests better * Refactor test * Finish testing * Update tests * Include removal reason and block inclusion tests * Remove unnecessary dict * Add missing message to test * Fix after rebase * Broadcast support from node for mempool updates * Remove redundant code * Use new_block * Filter and early return suggestions * Reword comments on PeakPostProcessingResult * Improve mempool fn name and doc comment * Bump wallet protocol version * Add ratelimits * Add tests for querying mempool items * Asserts and more tests * Add peers for spend bundle tests --- benchmarks/mempool.py | 6 +- chia/_tests/connection_utils.py | 21 +- .../core/full_node/test_subscriptions.py | 72 ++- chia/_tests/core/mempool/test_mempool.py | 2 +- .../core/mempool/test_mempool_item_queries.py | 191 ++++++++ .../core/mempool/test_mempool_manager.py | 2 +- .../util/build_network_protocol_files.py | 2 + chia/_tests/util/network_protocol_data.py | 21 + chia/_tests/util/protocol_messages_bytes-v1.0 | Bin 50511 -> 50560 bytes chia/_tests/util/protocol_messages_json.py | 11 + .../util/test_network_protocol_files.py | 208 ++++----- .../_tests/util/test_network_protocol_json.py | 4 + .../_tests/util/test_network_protocol_test.py | 9 +- .../_tests/wallet/test_new_wallet_protocol.py | 420 +++++++++++++++++- chia/clvm/spend_sim.py | 4 +- chia/full_node/full_node.py | 178 ++++++-- chia/full_node/full_node_api.py | 77 ++++ chia/full_node/hint_store.py | 16 + chia/full_node/mempool.py | 98 +++- chia/full_node/mempool_manager.py | 74 ++- chia/full_node/subscriptions.py | 38 ++ chia/protocols/protocol_message_types.py | 6 + chia/protocols/protocol_state_machine.py | 3 + chia/protocols/shared_protocol.py | 11 +- chia/protocols/wallet_protocol.py | 36 ++ chia/server/rate_limit_numbers.py | 4 + 26 files changed, 1324 insertions(+), 190 deletions(-) create mode 100644 chia/_tests/core/mempool/test_mempool_item_queries.py diff --git a/benchmarks/mempool.py b/benchmarks/mempool.py index e0852b4b1212..90c6c7605ae6 100644 --- a/benchmarks/mempool.py +++ b/benchmarks/mempool.py @@ -177,9 +177,9 @@ async def add_spend_bundles(spend_bundles: List[SpendBundle]) -> None: spend_bundle_id = tx.name() npc = await mempool.pre_validate_spendbundle(tx, None, spend_bundle_id) assert npc is not None - _, status, error = await mempool.add_spend_bundle(tx, npc, spend_bundle_id, height) - assert status == MempoolInclusionStatus.SUCCESS - assert error is None + info = await mempool.add_spend_bundle(tx, npc, spend_bundle_id, height) + assert info.status == MempoolInclusionStatus.SUCCESS + assert info.error is None suffix = "st" if single_threaded else "mt" diff --git a/chia/_tests/connection_utils.py b/chia/_tests/connection_utils.py index 12d6a05cdf7b..1ae27a1a0193 100644 --- a/chia/_tests/connection_utils.py +++ b/chia/_tests/connection_utils.py @@ -3,7 +3,7 @@ import asyncio import logging from pathlib import Path -from typing import Set, Tuple +from typing import List, Set, Tuple import aiohttp from cryptography import x509 @@ -39,15 +39,26 @@ async def disconnect_all_and_reconnect(server: ChiaServer, reconnect_to: ChiaSer 
async def add_dummy_connection( - server: ChiaServer, self_hostname: str, dummy_port: int, type: NodeType = NodeType.FULL_NODE + server: ChiaServer, + self_hostname: str, + dummy_port: int, + type: NodeType = NodeType.FULL_NODE, + *, + additional_capabilities: List[Tuple[uint16, str]] = [], ) -> Tuple[asyncio.Queue, bytes32]: - wsc, peer_id = await add_dummy_connection_wsc(server, self_hostname, dummy_port, type) + wsc, peer_id = await add_dummy_connection_wsc( + server, self_hostname, dummy_port, type, additional_capabilities=additional_capabilities + ) return wsc.incoming_queue, peer_id async def add_dummy_connection_wsc( - server: ChiaServer, self_hostname: str, dummy_port: int, type: NodeType = NodeType.FULL_NODE + server: ChiaServer, + self_hostname: str, + dummy_port: int, + type: NodeType = NodeType.FULL_NODE, + additional_capabilities: List[Tuple[uint16, str]] = [], ) -> Tuple[WSChiaConnection, bytes32]: timeout = aiohttp.ClientTimeout(total=10) session = aiohttp.ClientSession(timeout=timeout) @@ -86,7 +97,7 @@ async def add_dummy_connection_wsc( peer_id, 100, 30, - local_capabilities_for_handshake=default_capabilities[type], + local_capabilities_for_handshake=default_capabilities[type] + additional_capabilities, ) await wsc.perform_handshake(server._network_id, dummy_port, type) if wsc.incoming_message_task is not None: diff --git a/chia/_tests/core/full_node/test_subscriptions.py b/chia/_tests/core/full_node/test_subscriptions.py index 21d7c8d98476..7ecdd68d1405 100644 --- a/chia/_tests/core/full_node/test_subscriptions.py +++ b/chia/_tests/core/full_node/test_subscriptions.py @@ -1,10 +1,34 @@ from __future__ import annotations -from chia.full_node.subscriptions import PeerSubscriptions +from chia_rs import AugSchemeMPL, Coin, CoinSpend, Program +from chia_rs.sized_ints import uint32, uint64 + +from chia.consensus.default_constants import DEFAULT_CONSTANTS +from chia.full_node.bundle_tools import simple_solution_generator +from chia.full_node.mempool_check_conditions import get_name_puzzle_conditions +from chia.full_node.subscriptions import PeerSubscriptions, peers_for_spend_bundle +from chia.types.blockchain_format.program import INFINITE_COST from chia.types.blockchain_format.sized_bytes import bytes32 +from chia.types.spend_bundle import SpendBundle + +IDENTITY_PUZZLE = Program.to(1) +IDENTITY_PUZZLE_HASH = IDENTITY_PUZZLE.get_tree_hash() + +OTHER_PUZZLE = Program.to(2) +OTHER_PUZZLE_HASH = OTHER_PUZZLE.get_tree_hash() + +HINT_PUZZLE = Program.to(3) +HINT_PUZZLE_HASH = HINT_PUZZLE.get_tree_hash() + +IDENTITY_COIN = Coin(bytes32(b"0" * 32), IDENTITY_PUZZLE_HASH, uint64(1000)) +OTHER_COIN = Coin(bytes32(b"3" * 32), OTHER_PUZZLE_HASH, uint64(1000)) + +EMPTY_SIGNATURE = AugSchemeMPL.aggregate([]) peer1 = bytes32(b"1" * 32) peer2 = bytes32(b"2" * 32) +peer3 = bytes32(b"3" * 32) +peer4 = bytes32(b"4" * 32) coin1 = bytes32(b"a" * 32) coin2 = bytes32(b"b" * 32) @@ -420,3 +444,49 @@ def test_clear_subscriptions() -> None: subs.clear_puzzle_subscriptions(peer1) assert subs.peer_subscription_count(peer1) == 0 + + +def test_peers_for_spent_coin() -> None: + subs = PeerSubscriptions() + + subs.add_puzzle_subscriptions(peer1, [IDENTITY_PUZZLE_HASH], 1) + subs.add_puzzle_subscriptions(peer2, [HINT_PUZZLE_HASH], 1) + subs.add_coin_subscriptions(peer3, [IDENTITY_COIN.name()], 1) + subs.add_coin_subscriptions(peer4, [OTHER_COIN.name()], 1) + + coin_spends = [CoinSpend(IDENTITY_COIN, IDENTITY_PUZZLE, Program.to([]))] + + spend_bundle = SpendBundle(coin_spends, AugSchemeMPL.aggregate([])) + 
generator = simple_solution_generator(spend_bundle) + npc_result = get_name_puzzle_conditions( + generator=generator, max_cost=INFINITE_COST, mempool_mode=True, height=uint32(0), constants=DEFAULT_CONSTANTS + ) + assert npc_result.conds is not None + + peers = peers_for_spend_bundle(subs, npc_result.conds, {HINT_PUZZLE_HASH}) + assert peers == {peer1, peer2, peer3} + + +def test_peers_for_created_coin() -> None: + subs = PeerSubscriptions() + + new_coin = Coin(IDENTITY_COIN.name(), OTHER_PUZZLE_HASH, uint64(1000)) + + subs.add_puzzle_subscriptions(peer1, [OTHER_PUZZLE_HASH], 1) + subs.add_puzzle_subscriptions(peer2, [HINT_PUZZLE_HASH], 1) + subs.add_coin_subscriptions(peer3, [new_coin.name()], 1) + subs.add_coin_subscriptions(peer4, [OTHER_COIN.name()], 1) + + coin_spends = [ + CoinSpend(IDENTITY_COIN, IDENTITY_PUZZLE, Program.to([[51, OTHER_PUZZLE_HASH, 1000, [HINT_PUZZLE_HASH]]])) + ] + + spend_bundle = SpendBundle(coin_spends, AugSchemeMPL.aggregate([])) + generator = simple_solution_generator(spend_bundle) + npc_result = get_name_puzzle_conditions( + generator=generator, max_cost=INFINITE_COST, mempool_mode=True, height=uint32(0), constants=DEFAULT_CONSTANTS + ) + assert npc_result.conds is not None + + peers = peers_for_spend_bundle(subs, npc_result.conds, set()) + assert peers == {peer1, peer2, peer3} diff --git a/chia/_tests/core/mempool/test_mempool.py b/chia/_tests/core/mempool/test_mempool.py index 71bca0d4fcf1..dc752bc24834 100644 --- a/chia/_tests/core/mempool/test_mempool.py +++ b/chia/_tests/core/mempool/test_mempool.py @@ -2863,7 +2863,7 @@ def test_limit_expiring_transactions(height: bool, items: List[int], expected: L invariant_check_mempool(mempool) if increase_fee: fee_rate += 0.1 - assert ret is None + assert ret.error is None else: fee_rate -= 0.1 diff --git a/chia/_tests/core/mempool/test_mempool_item_queries.py b/chia/_tests/core/mempool/test_mempool_item_queries.py new file mode 100644 index 000000000000..bd41b3f8006f --- /dev/null +++ b/chia/_tests/core/mempool/test_mempool_item_queries.py @@ -0,0 +1,191 @@ +from __future__ import annotations + +from typing import List + +from chia_rs import AugSchemeMPL, Coin, Program +from chia_rs.sized_bytes import bytes32 +from chia_rs.sized_ints import uint32, uint64 + +from chia._tests.core.mempool.test_mempool_manager import TEST_HEIGHT, make_bundle_spends_map_and_fee +from chia.consensus.default_constants import DEFAULT_CONSTANTS +from chia.full_node.bitcoin_fee_estimator import create_bitcoin_fee_estimator +from chia.full_node.bundle_tools import simple_solution_generator +from chia.full_node.fee_estimation import MempoolInfo +from chia.full_node.mempool import Mempool +from chia.full_node.mempool_check_conditions import get_name_puzzle_conditions +from chia.types.blockchain_format.program import INFINITE_COST +from chia.types.clvm_cost import CLVMCost +from chia.types.coin_spend import CoinSpend +from chia.types.fee_rate import FeeRate +from chia.types.mempool_item import MempoolItem +from chia.types.spend_bundle import SpendBundle + +MEMPOOL_INFO = MempoolInfo( + max_size_in_cost=CLVMCost(uint64(INFINITE_COST * 10)), + minimum_fee_per_cost_to_replace=FeeRate(uint64(5)), + max_block_clvm_cost=CLVMCost(uint64(INFINITE_COST)), +) + +IDENTITY_PUZZLE = Program.to(1) +IDENTITY_PUZZLE_HASH = IDENTITY_PUZZLE.get_tree_hash() + +OTHER_PUZZLE = Program.to(2) +OTHER_PUZZLE_HASH = OTHER_PUZZLE.get_tree_hash() + +IDENTITY_COIN_1 = Coin(bytes32(b"0" * 32), IDENTITY_PUZZLE_HASH, uint64(1000)) +IDENTITY_COIN_2 = Coin(bytes32(b"1" * 32), 
IDENTITY_PUZZLE_HASH, uint64(1000)) +IDENTITY_COIN_3 = Coin(bytes32(b"2" * 32), IDENTITY_PUZZLE_HASH, uint64(1000)) + +OTHER_COIN_1 = Coin(bytes32(b"3" * 32), OTHER_PUZZLE_HASH, uint64(1000)) +OTHER_COIN_2 = Coin(bytes32(b"4" * 32), OTHER_PUZZLE_HASH, uint64(1000)) +OTHER_COIN_3 = Coin(bytes32(b"5" * 32), OTHER_PUZZLE_HASH, uint64(1000)) + +EMPTY_SIGNATURE = AugSchemeMPL.aggregate([]) + + +def make_item(coin_spends: List[CoinSpend]) -> MempoolItem: + spend_bundle = SpendBundle(coin_spends, EMPTY_SIGNATURE) + generator = simple_solution_generator(spend_bundle) + npc_result = get_name_puzzle_conditions( + generator=generator, max_cost=INFINITE_COST, mempool_mode=True, height=uint32(0), constants=DEFAULT_CONSTANTS + ) + bundle_coin_spends, fee = make_bundle_spends_map_and_fee(spend_bundle, npc_result) + return MempoolItem( + spend_bundle=spend_bundle, + fee=fee, + npc_result=npc_result, + spend_bundle_name=spend_bundle.name(), + height_added_to_mempool=TEST_HEIGHT, + bundle_coin_spends=bundle_coin_spends, + ) + + +def test_empty_pool() -> None: + fee_estimator = create_bitcoin_fee_estimator(uint64(INFINITE_COST)) + mempool = Mempool(MEMPOOL_INFO, fee_estimator) + assert mempool.items_with_coin_ids({IDENTITY_COIN_1.name()}) == [] + assert mempool.items_with_puzzle_hashes({IDENTITY_PUZZLE_HASH}, False) == [] + + +def test_by_spent_coin_ids() -> None: + fee_estimator = create_bitcoin_fee_estimator(uint64(INFINITE_COST)) + mempool = Mempool(MEMPOOL_INFO, fee_estimator) + + # Add an item with both queried coins, to ensure there are no duplicates in the response. + item_1 = make_item( + [ + CoinSpend(IDENTITY_COIN_1, IDENTITY_PUZZLE, Program.to([])), + CoinSpend(IDENTITY_COIN_2, IDENTITY_PUZZLE, Program.to([])), + ] + ) + mempool.add_to_pool(item_1) + + # Another coin with the same puzzle hash shouldn't match. + other = make_item( + [ + CoinSpend(IDENTITY_COIN_3, IDENTITY_PUZZLE, Program.to([])), + ] + ) + mempool.add_to_pool(other) + + # And this coin is completely unrelated. + other = make_item([CoinSpend(OTHER_COIN_1, OTHER_PUZZLE, Program.to([[]]))]) + mempool.add_to_pool(other) + + # Only the first transaction includes these coins. + assert mempool.items_with_coin_ids({IDENTITY_COIN_1.name(), IDENTITY_COIN_2.name()}) == [item_1.spend_bundle_name] + assert mempool.items_with_coin_ids({IDENTITY_COIN_1.name()}) == [item_1.spend_bundle_name] + assert mempool.items_with_coin_ids({OTHER_COIN_2.name(), OTHER_COIN_3.name()}) == [] + + +def test_by_spend_puzzle_hashes() -> None: + fee_estimator = create_bitcoin_fee_estimator(uint64(INFINITE_COST)) + mempool = Mempool(MEMPOOL_INFO, fee_estimator) + + # Add a transaction with the queried puzzle hashes. + item_1 = make_item( + [ + CoinSpend(IDENTITY_COIN_1, IDENTITY_PUZZLE, Program.to([])), + CoinSpend(IDENTITY_COIN_2, IDENTITY_PUZZLE, Program.to([])), + ] + ) + mempool.add_to_pool(item_1) + + # Another coin with the same puzzle hash should match. + item_2 = make_item( + [ + CoinSpend(IDENTITY_COIN_3, IDENTITY_PUZZLE, Program.to([])), + ] + ) + mempool.add_to_pool(item_2) + + # But this coin has a different puzzle hash. + other = make_item([CoinSpend(OTHER_COIN_1, OTHER_PUZZLE, Program.to([[]]))]) + mempool.add_to_pool(other) + + # Only the first two transactions include the puzzle hash. + assert mempool.items_with_puzzle_hashes({IDENTITY_PUZZLE_HASH}, False) == [ + item_1.spend_bundle_name, + item_2.spend_bundle_name, + ] + + # Test the other puzzle hash as well. 
+ assert mempool.items_with_puzzle_hashes({OTHER_PUZZLE_HASH}, False) == [ + other.spend_bundle_name, + ] + + # And an unrelated puzzle hash. + assert mempool.items_with_puzzle_hashes({bytes32(b"0" * 32)}, False) == [] + + +def test_by_created_coin_id() -> None: + fee_estimator = create_bitcoin_fee_estimator(uint64(INFINITE_COST)) + mempool = Mempool(MEMPOOL_INFO, fee_estimator) + + # Add a transaction that creates the queried coin id. + item = make_item( + [ + CoinSpend(IDENTITY_COIN_1, IDENTITY_PUZZLE, Program.to([[51, IDENTITY_PUZZLE_HASH, 1000]])), + ] + ) + mempool.add_to_pool(item) + + # Test that the transaction is found. + assert mempool.items_with_coin_ids({Coin(IDENTITY_COIN_1.name(), IDENTITY_PUZZLE_HASH, uint64(1000)).name()}) == [ + item.spend_bundle_name + ] + + +def test_by_created_puzzle_hash() -> None: + fee_estimator = create_bitcoin_fee_estimator(uint64(INFINITE_COST)) + mempool = Mempool(MEMPOOL_INFO, fee_estimator) + + # Add a transaction that creates the queried puzzle hash. + item_1 = make_item( + [ + CoinSpend( + IDENTITY_COIN_1, + IDENTITY_PUZZLE, + Program.to([[51, OTHER_PUZZLE_HASH, 400], [51, OTHER_PUZZLE_HASH, 600]]), + ), + ] + ) + mempool.add_to_pool(item_1) + + # This one is hinted. + item_2 = make_item( + [ + CoinSpend( + IDENTITY_COIN_2, + IDENTITY_PUZZLE, + Program.to([[51, IDENTITY_PUZZLE_HASH, 1000, [OTHER_PUZZLE_HASH]]]), + ), + ] + ) + mempool.add_to_pool(item_2) + + # Test that the transactions are both found. + assert mempool.items_with_puzzle_hashes({OTHER_PUZZLE_HASH}, include_hints=True) == [ + item_1.spend_bundle_name, + item_2.spend_bundle_name, + ] diff --git a/chia/_tests/core/mempool/test_mempool_manager.py b/chia/_tests/core/mempool/test_mempool_manager.py index 039af2e91020..acc2ebda8a29 100644 --- a/chia/_tests/core/mempool/test_mempool_manager.py +++ b/chia/_tests/core/mempool/test_mempool_manager.py @@ -388,7 +388,7 @@ async def add_spendbundle( npc_result = await mempool_manager.pre_validate_spendbundle(sb, None, sb_name) ret = await mempool_manager.add_spend_bundle(sb, npc_result, sb_name, TEST_HEIGHT) invariant_check_mempool(mempool_manager.mempool) - return ret + return ret.cost, ret.status, ret.error async def generate_and_add_spendbundle( diff --git a/chia/_tests/util/build_network_protocol_files.py b/chia/_tests/util/build_network_protocol_files.py index cf8d285b6a37..f83f995cd731 100644 --- a/chia/_tests/util/build_network_protocol_files.py +++ b/chia/_tests/util/build_network_protocol_files.py @@ -107,6 +107,8 @@ def visit_wallet_protocol(visitor: Callable[[Any, str], None]) -> None: visitor(request_coin_state, "request_coin_state") visitor(respond_coin_state, "respond_coin_state") visitor(reject_coin_state, "reject_coin_state") + visitor(request_cost_info, "request_cost_info") + visitor(respond_cost_info, "respond_cost_info") def visit_harvester_protocol(visitor: Callable[[Any, str], None]) -> None: diff --git a/chia/_tests/util/network_protocol_data.py b/chia/_tests/util/network_protocol_data.py index dca478b2168e..76b0c60cd738 100644 --- a/chia/_tests/util/network_protocol_data.py +++ b/chia/_tests/util/network_protocol_data.py @@ -733,6 +733,27 @@ uint8(wallet_protocol.RejectStateReason.EXCEEDED_SUBSCRIPTION_LIMIT) ) +removed_mempool_item = wallet_protocol.RemovedMempoolItem( + bytes32(bytes.fromhex("59710628755b6d7f7d0b5d84d5c980e7a1c52e55f5a43b531312402bd9045da7")), uint8(1) +) + +mempool_items_added = wallet_protocol.MempoolItemsAdded( + 
[bytes32(bytes.fromhex("59710628755b6d7f7d0b5d84d5c980e7a1c52e55f5a43b531312402bd9045da7"))] +) + +mempool_items_removed = wallet_protocol.MempoolItemsRemoved([removed_mempool_item]) + +request_cost_info = wallet_protocol.RequestCostInfo() + +respond_cost_info = wallet_protocol.RespondCostInfo( + max_transaction_cost=uint64(100000), + max_block_cost=uint64(1000000), + max_mempool_cost=uint64(10000000), + mempool_cost=uint64(50000), + mempool_fee=uint64(500000), + bump_fee_per_cost=uint8(10), +) + ### HARVESTER PROTOCOL pool_difficulty = harvester_protocol.PoolDifficulty( diff --git a/chia/_tests/util/protocol_messages_bytes-v1.0 b/chia/_tests/util/protocol_messages_bytes-v1.0 index cb8bf94650abd32592e15167707313d8dad9f326..30d2eacdab5d304af0175b5544bc5bdd27c83581 100644 GIT binary patch delta 61 zcmX@##oW-%ydhwQ9s>|)f(Qo2wgq6C-^l??&zRN#rWp None: assert bytes(message_74) == bytes(reject_coin_state) message_bytes, input_bytes = parse_blob(input_bytes) - message_75 = type(pool_difficulty).from_bytes(message_bytes) - assert message_75 == pool_difficulty - assert bytes(message_75) == bytes(pool_difficulty) + message_75 = type(request_cost_info).from_bytes(message_bytes) + assert message_75 == request_cost_info + assert bytes(message_75) == bytes(request_cost_info) message_bytes, input_bytes = parse_blob(input_bytes) - message_76 = type(harvester_handhsake).from_bytes(message_bytes) - assert message_76 == harvester_handhsake - assert bytes(message_76) == bytes(harvester_handhsake) + message_76 = type(respond_cost_info).from_bytes(message_bytes) + assert message_76 == respond_cost_info + assert bytes(message_76) == bytes(respond_cost_info) message_bytes, input_bytes = parse_blob(input_bytes) - message_77 = type(new_signage_point_harvester).from_bytes(message_bytes) - assert message_77 == new_signage_point_harvester - assert bytes(message_77) == bytes(new_signage_point_harvester) + message_77 = type(pool_difficulty).from_bytes(message_bytes) + assert message_77 == pool_difficulty + assert bytes(message_77) == bytes(pool_difficulty) message_bytes, input_bytes = parse_blob(input_bytes) - message_78 = type(new_proof_of_space).from_bytes(message_bytes) - assert message_78 == new_proof_of_space - assert bytes(message_78) == bytes(new_proof_of_space) + message_78 = type(harvester_handhsake).from_bytes(message_bytes) + assert message_78 == harvester_handhsake + assert bytes(message_78) == bytes(harvester_handhsake) message_bytes, input_bytes = parse_blob(input_bytes) - message_79 = type(request_signatures).from_bytes(message_bytes) - assert message_79 == request_signatures - assert bytes(message_79) == bytes(request_signatures) + message_79 = type(new_signage_point_harvester).from_bytes(message_bytes) + assert message_79 == new_signage_point_harvester + assert bytes(message_79) == bytes(new_signage_point_harvester) message_bytes, input_bytes = parse_blob(input_bytes) - message_80 = type(respond_signatures).from_bytes(message_bytes) - assert message_80 == respond_signatures - assert bytes(message_80) == bytes(respond_signatures) + message_80 = type(new_proof_of_space).from_bytes(message_bytes) + assert message_80 == new_proof_of_space + assert bytes(message_80) == bytes(new_proof_of_space) message_bytes, input_bytes = parse_blob(input_bytes) - message_81 = type(plot).from_bytes(message_bytes) - assert message_81 == plot - assert bytes(message_81) == bytes(plot) + message_81 = type(request_signatures).from_bytes(message_bytes) + assert message_81 == request_signatures + assert bytes(message_81) == 
bytes(request_signatures) message_bytes, input_bytes = parse_blob(input_bytes) - message_82 = type(request_plots).from_bytes(message_bytes) - assert message_82 == request_plots - assert bytes(message_82) == bytes(request_plots) + message_82 = type(respond_signatures).from_bytes(message_bytes) + assert message_82 == respond_signatures + assert bytes(message_82) == bytes(respond_signatures) message_bytes, input_bytes = parse_blob(input_bytes) - message_83 = type(respond_plots).from_bytes(message_bytes) - assert message_83 == respond_plots - assert bytes(message_83) == bytes(respond_plots) + message_83 = type(plot).from_bytes(message_bytes) + assert message_83 == plot + assert bytes(message_83) == bytes(plot) message_bytes, input_bytes = parse_blob(input_bytes) - message_84 = type(request_peers_introducer).from_bytes(message_bytes) - assert message_84 == request_peers_introducer - assert bytes(message_84) == bytes(request_peers_introducer) + message_84 = type(request_plots).from_bytes(message_bytes) + assert message_84 == request_plots + assert bytes(message_84) == bytes(request_plots) message_bytes, input_bytes = parse_blob(input_bytes) - message_85 = type(respond_peers_introducer).from_bytes(message_bytes) - assert message_85 == respond_peers_introducer - assert bytes(message_85) == bytes(respond_peers_introducer) + message_85 = type(respond_plots).from_bytes(message_bytes) + assert message_85 == respond_plots + assert bytes(message_85) == bytes(respond_plots) message_bytes, input_bytes = parse_blob(input_bytes) - message_86 = type(authentication_payload).from_bytes(message_bytes) - assert message_86 == authentication_payload - assert bytes(message_86) == bytes(authentication_payload) + message_86 = type(request_peers_introducer).from_bytes(message_bytes) + assert message_86 == request_peers_introducer + assert bytes(message_86) == bytes(request_peers_introducer) message_bytes, input_bytes = parse_blob(input_bytes) - message_87 = type(get_pool_info_response).from_bytes(message_bytes) - assert message_87 == get_pool_info_response - assert bytes(message_87) == bytes(get_pool_info_response) + message_87 = type(respond_peers_introducer).from_bytes(message_bytes) + assert message_87 == respond_peers_introducer + assert bytes(message_87) == bytes(respond_peers_introducer) message_bytes, input_bytes = parse_blob(input_bytes) - message_88 = type(post_partial_payload).from_bytes(message_bytes) - assert message_88 == post_partial_payload - assert bytes(message_88) == bytes(post_partial_payload) + message_88 = type(authentication_payload).from_bytes(message_bytes) + assert message_88 == authentication_payload + assert bytes(message_88) == bytes(authentication_payload) message_bytes, input_bytes = parse_blob(input_bytes) - message_89 = type(post_partial_request).from_bytes(message_bytes) - assert message_89 == post_partial_request - assert bytes(message_89) == bytes(post_partial_request) + message_89 = type(get_pool_info_response).from_bytes(message_bytes) + assert message_89 == get_pool_info_response + assert bytes(message_89) == bytes(get_pool_info_response) message_bytes, input_bytes = parse_blob(input_bytes) - message_90 = type(post_partial_response).from_bytes(message_bytes) - assert message_90 == post_partial_response - assert bytes(message_90) == bytes(post_partial_response) + message_90 = type(post_partial_payload).from_bytes(message_bytes) + assert message_90 == post_partial_payload + assert bytes(message_90) == bytes(post_partial_payload) message_bytes, input_bytes = parse_blob(input_bytes) - 
message_91 = type(get_farmer_response).from_bytes(message_bytes) - assert message_91 == get_farmer_response - assert bytes(message_91) == bytes(get_farmer_response) + message_91 = type(post_partial_request).from_bytes(message_bytes) + assert message_91 == post_partial_request + assert bytes(message_91) == bytes(post_partial_request) message_bytes, input_bytes = parse_blob(input_bytes) - message_92 = type(post_farmer_payload).from_bytes(message_bytes) - assert message_92 == post_farmer_payload - assert bytes(message_92) == bytes(post_farmer_payload) + message_92 = type(post_partial_response).from_bytes(message_bytes) + assert message_92 == post_partial_response + assert bytes(message_92) == bytes(post_partial_response) message_bytes, input_bytes = parse_blob(input_bytes) - message_93 = type(post_farmer_request).from_bytes(message_bytes) - assert message_93 == post_farmer_request - assert bytes(message_93) == bytes(post_farmer_request) + message_93 = type(get_farmer_response).from_bytes(message_bytes) + assert message_93 == get_farmer_response + assert bytes(message_93) == bytes(get_farmer_response) message_bytes, input_bytes = parse_blob(input_bytes) - message_94 = type(post_farmer_response).from_bytes(message_bytes) - assert message_94 == post_farmer_response - assert bytes(message_94) == bytes(post_farmer_response) + message_94 = type(post_farmer_payload).from_bytes(message_bytes) + assert message_94 == post_farmer_payload + assert bytes(message_94) == bytes(post_farmer_payload) message_bytes, input_bytes = parse_blob(input_bytes) - message_95 = type(put_farmer_payload).from_bytes(message_bytes) - assert message_95 == put_farmer_payload - assert bytes(message_95) == bytes(put_farmer_payload) + message_95 = type(post_farmer_request).from_bytes(message_bytes) + assert message_95 == post_farmer_request + assert bytes(message_95) == bytes(post_farmer_request) message_bytes, input_bytes = parse_blob(input_bytes) - message_96 = type(put_farmer_request).from_bytes(message_bytes) - assert message_96 == put_farmer_request - assert bytes(message_96) == bytes(put_farmer_request) + message_96 = type(post_farmer_response).from_bytes(message_bytes) + assert message_96 == post_farmer_response + assert bytes(message_96) == bytes(post_farmer_response) message_bytes, input_bytes = parse_blob(input_bytes) - message_97 = type(put_farmer_response).from_bytes(message_bytes) - assert message_97 == put_farmer_response - assert bytes(message_97) == bytes(put_farmer_response) + message_97 = type(put_farmer_payload).from_bytes(message_bytes) + assert message_97 == put_farmer_payload + assert bytes(message_97) == bytes(put_farmer_payload) message_bytes, input_bytes = parse_blob(input_bytes) - message_98 = type(error_response).from_bytes(message_bytes) - assert message_98 == error_response - assert bytes(message_98) == bytes(error_response) + message_98 = type(put_farmer_request).from_bytes(message_bytes) + assert message_98 == put_farmer_request + assert bytes(message_98) == bytes(put_farmer_request) message_bytes, input_bytes = parse_blob(input_bytes) - message_99 = type(new_peak_timelord).from_bytes(message_bytes) - assert message_99 == new_peak_timelord - assert bytes(message_99) == bytes(new_peak_timelord) + message_99 = type(put_farmer_response).from_bytes(message_bytes) + assert message_99 == put_farmer_response + assert bytes(message_99) == bytes(put_farmer_response) message_bytes, input_bytes = parse_blob(input_bytes) - message_100 = type(new_unfinished_block_timelord).from_bytes(message_bytes) - assert 
message_100 == new_unfinished_block_timelord - assert bytes(message_100) == bytes(new_unfinished_block_timelord) + message_100 = type(error_response).from_bytes(message_bytes) + assert message_100 == error_response + assert bytes(message_100) == bytes(error_response) message_bytes, input_bytes = parse_blob(input_bytes) - message_101 = type(new_infusion_point_vdf).from_bytes(message_bytes) - assert message_101 == new_infusion_point_vdf - assert bytes(message_101) == bytes(new_infusion_point_vdf) + message_101 = type(new_peak_timelord).from_bytes(message_bytes) + assert message_101 == new_peak_timelord + assert bytes(message_101) == bytes(new_peak_timelord) message_bytes, input_bytes = parse_blob(input_bytes) - message_102 = type(new_signage_point_vdf).from_bytes(message_bytes) - assert message_102 == new_signage_point_vdf - assert bytes(message_102) == bytes(new_signage_point_vdf) + message_102 = type(new_unfinished_block_timelord).from_bytes(message_bytes) + assert message_102 == new_unfinished_block_timelord + assert bytes(message_102) == bytes(new_unfinished_block_timelord) message_bytes, input_bytes = parse_blob(input_bytes) - message_103 = type(new_end_of_sub_slot_bundle).from_bytes(message_bytes) - assert message_103 == new_end_of_sub_slot_bundle - assert bytes(message_103) == bytes(new_end_of_sub_slot_bundle) + message_103 = type(new_infusion_point_vdf).from_bytes(message_bytes) + assert message_103 == new_infusion_point_vdf + assert bytes(message_103) == bytes(new_infusion_point_vdf) message_bytes, input_bytes = parse_blob(input_bytes) - message_104 = type(request_compact_proof_of_time).from_bytes(message_bytes) - assert message_104 == request_compact_proof_of_time - assert bytes(message_104) == bytes(request_compact_proof_of_time) + message_104 = type(new_signage_point_vdf).from_bytes(message_bytes) + assert message_104 == new_signage_point_vdf + assert bytes(message_104) == bytes(new_signage_point_vdf) message_bytes, input_bytes = parse_blob(input_bytes) - message_105 = type(respond_compact_proof_of_time).from_bytes(message_bytes) - assert message_105 == respond_compact_proof_of_time - assert bytes(message_105) == bytes(respond_compact_proof_of_time) + message_105 = type(new_end_of_sub_slot_bundle).from_bytes(message_bytes) + assert message_105 == new_end_of_sub_slot_bundle + assert bytes(message_105) == bytes(new_end_of_sub_slot_bundle) message_bytes, input_bytes = parse_blob(input_bytes) - message_106 = type(error_without_data).from_bytes(message_bytes) - assert message_106 == error_without_data - assert bytes(message_106) == bytes(error_without_data) + message_106 = type(request_compact_proof_of_time).from_bytes(message_bytes) + assert message_106 == request_compact_proof_of_time + assert bytes(message_106) == bytes(request_compact_proof_of_time) message_bytes, input_bytes = parse_blob(input_bytes) - message_107 = type(error_with_data).from_bytes(message_bytes) - assert message_107 == error_with_data - assert bytes(message_107) == bytes(error_with_data) + message_107 = type(respond_compact_proof_of_time).from_bytes(message_bytes) + assert message_107 == respond_compact_proof_of_time + assert bytes(message_107) == bytes(respond_compact_proof_of_time) + + message_bytes, input_bytes = parse_blob(input_bytes) + message_108 = type(error_without_data).from_bytes(message_bytes) + assert message_108 == error_without_data + assert bytes(message_108) == bytes(error_without_data) + + message_bytes, input_bytes = parse_blob(input_bytes) + message_109 = 
type(error_with_data).from_bytes(message_bytes) + assert message_109 == error_with_data + assert bytes(message_109) == bytes(error_with_data) assert input_bytes == b"" diff --git a/chia/_tests/util/test_network_protocol_json.py b/chia/_tests/util/test_network_protocol_json.py index 7f7b2e38da0e..d68571a3bde9 100644 --- a/chia/_tests/util/test_network_protocol_json.py +++ b/chia/_tests/util/test_network_protocol_json.py @@ -180,6 +180,10 @@ def test_protocol_json() -> None: assert type(respond_coin_state).from_json_dict(respond_coin_state_json) == respond_coin_state assert str(reject_coin_state_json) == str(reject_coin_state.to_json_dict()) assert type(reject_coin_state).from_json_dict(reject_coin_state_json) == reject_coin_state + assert str(request_cost_info_json) == str(request_cost_info.to_json_dict()) + assert type(request_cost_info).from_json_dict(request_cost_info_json) == request_cost_info + assert str(respond_cost_info_json) == str(respond_cost_info.to_json_dict()) + assert type(respond_cost_info).from_json_dict(respond_cost_info_json) == respond_cost_info assert str(pool_difficulty_json) == str(pool_difficulty.to_json_dict()) assert type(pool_difficulty).from_json_dict(pool_difficulty_json) == pool_difficulty assert str(harvester_handhsake_json) == str(harvester_handhsake.to_json_dict()) diff --git a/chia/_tests/util/test_network_protocol_test.py b/chia/_tests/util/test_network_protocol_test.py index 2d721429565f..766f80efd823 100644 --- a/chia/_tests/util/test_network_protocol_test.py +++ b/chia/_tests/util/test_network_protocol_test.py @@ -41,10 +41,10 @@ def test_missing_messages_state_machine() -> None: # to the visitor in build_network_protocol_files.py and rerun it. Then # update this test assert ( - len(VALID_REPLY_MESSAGE_MAP) == 25 + len(VALID_REPLY_MESSAGE_MAP) == 26 ), "A message was added to the protocol state machine. Make sure to update the protocol message regression test to include the new message" assert ( - len(NO_REPLY_EXPECTED) == 8 + len(NO_REPLY_EXPECTED) == 10 ), "A message was added to the protocol state machine. 
Make sure to update the protocol message regression test to include the new message" @@ -77,6 +77,8 @@ def test_missing_messages() -> None: "CoinState", "CoinStateFilters", "CoinStateUpdate", + "MempoolItemsAdded", + "MempoolItemsRemoved", "NewPeakWallet", "PuzzleSolutionResponse", "RegisterForCoinUpdates", @@ -90,11 +92,13 @@ def test_missing_messages() -> None: "RejectPuzzleState", "RejectRemovalsRequest", "RejectStateReason", + "RemovedMempoolItem", "RequestAdditions", "RequestBlockHeader", "RequestBlockHeaders", "RequestChildren", "RequestCoinState", + "RequestCostInfo", "RequestFeeEstimates", "RequestHeaderBlocks", "RequestPuzzleSolution", @@ -108,6 +112,7 @@ def test_missing_messages() -> None: "RespondBlockHeaders", "RespondChildren", "RespondCoinState", + "RespondCostInfo", "RespondFeeEstimates", "RespondHeaderBlocks", "RespondPuzzleSolution", diff --git a/chia/_tests/wallet/test_new_wallet_protocol.py b/chia/_tests/wallet/test_new_wallet_protocol.py index d8593554da8b..c59560e24cdd 100644 --- a/chia/_tests/wallet/test_new_wallet_protocol.py +++ b/chia/_tests/wallet/test_new_wallet_protocol.py @@ -6,11 +6,15 @@ from typing import AsyncGenerator, Dict, List, Optional, OrderedDict, Set, Tuple import pytest -from chia_rs import Coin, CoinState +from chia_rs import AugSchemeMPL, Coin, CoinSpend, CoinState, Program from chia._tests.connection_utils import add_dummy_connection from chia.full_node.coin_store import CoinStore +from chia.full_node.full_node import FullNode +from chia.full_node.mempool import MempoolRemoveReason from chia.protocols import wallet_protocol +from chia.protocols.protocol_message_types import ProtocolMessageTypes +from chia.protocols.shared_protocol import Capability from chia.server.outbound_message import Message, NodeType from chia.server.ws_connection import WSChiaConnection from chia.simulator import simulator_protocol @@ -20,21 +24,34 @@ from chia.types.aliases import WalletService from chia.types.blockchain_format.sized_bytes import bytes32 from chia.types.coin_record import CoinRecord +from chia.types.mempool_inclusion_status import MempoolInclusionStatus +from chia.types.spend_bundle import SpendBundle from chia.util.hash import std_hash -from chia.util.ints import uint8, uint32, uint64 +from chia.util.ints import uint8, uint16, uint32, uint64 + +IDENTITY_PUZZLE = Program.to(1) +IDENTITY_PUZZLE_HASH = IDENTITY_PUZZLE.get_tree_hash() OneNode = Tuple[List[SimulatorFullNodeService], List[WalletService], BlockTools] +ALL_FILTER = wallet_protocol.CoinStateFilters(True, True, True, uint64(0)) + async def connect_to_simulator( - one_node: OneNode, self_hostname: str + one_node: OneNode, self_hostname: str, mempool_updates: bool = True ) -> Tuple[FullNodeSimulator, Queue[Message], WSChiaConnection]: [full_node_service], _, _ = one_node full_node_api = full_node_service._api fn_server = full_node_api.server - incoming_queue, peer_id = await add_dummy_connection(fn_server, self_hostname, 41723, NodeType.WALLET) + incoming_queue, peer_id = await add_dummy_connection( + fn_server, + self_hostname, + 41723, + NodeType.WALLET, + additional_capabilities=[(uint16(Capability.MEMPOOL_UPDATES), "1")] if mempool_updates else [], + ) peer = fn_server.all_connections[peer_id] return full_node_api, incoming_queue, peer @@ -760,3 +777,398 @@ async def run_test(include_spent: bool, include_unspent: bool, include_hinted: b for include_hinted in [True, False]: for min_amount in [0, 100000, 500000000]: await run_test(include_spent, include_unspent, include_hinted, 
uint64(min_amount)) + + +async def assert_mempool_added(queue: Queue[Message], transaction_ids: Set[bytes32]) -> None: + message = await queue.get() + assert message.type == ProtocolMessageTypes.mempool_items_added.value + + update = wallet_protocol.MempoolItemsAdded.from_bytes(message.data) + assert set(update.transaction_ids) == transaction_ids + + +async def assert_mempool_removed( + queue: Queue[Message], + removed_items: Set[wallet_protocol.RemovedMempoolItem], +) -> None: + message = await queue.get() + assert message.type == ProtocolMessageTypes.mempool_items_removed.value + + update = wallet_protocol.MempoolItemsRemoved.from_bytes(message.data) + assert set(update.removed_items) == removed_items + + +Mpu = Tuple[FullNodeSimulator, Queue[Message], WSChiaConnection] + + +@pytest.fixture +async def mpu_setup(one_node: OneNode, self_hostname: str) -> Mpu: + return await raw_mpu_setup(one_node, self_hostname) + + +@pytest.fixture +async def mpu_setup_no_capability(one_node: OneNode, self_hostname: str) -> Mpu: + return await raw_mpu_setup(one_node, self_hostname, no_capability=True) + + +async def raw_mpu_setup(one_node: OneNode, self_hostname: str, no_capability: bool = False) -> Mpu: + simulator, queue, peer = await connect_to_simulator(one_node, self_hostname, mempool_updates=not no_capability) + await simulator.farm_blocks_to_puzzlehash(1) + await queue.get() + + new_coins: List[Tuple[Coin, bytes32]] = [] + + for i in range(10): + puzzle = Program.to(2) + ph = puzzle.get_tree_hash() + coin = Coin(std_hash(b"unrelated coin id" + i.to_bytes(4, "big")), ph, uint64(1)) + hint = std_hash(b"unrelated hint" + i.to_bytes(4, "big")) + new_coins.append((coin, hint)) + + reward_1 = Coin(std_hash(b"reward 1"), std_hash(b"reward puzzle hash"), uint64(1000)) + reward_2 = Coin(std_hash(b"reward 2"), std_hash(b"reward puzzle hash"), uint64(2000)) + await simulator.full_node.coin_store.new_block( + uint32(2), uint64(10000), [reward_1, reward_2], [coin for coin, _ in new_coins], [] + ) + await simulator.full_node.hint_store.add_hints([(coin.name(), hint) for coin, hint in new_coins]) + + for coin, hint in new_coins: + solution = Program.to([[]]) + bundle = SpendBundle([CoinSpend(coin, puzzle, solution)], AugSchemeMPL.aggregate([])) + tx_resp = await simulator.send_transaction(wallet_protocol.SendTransaction(bundle)) + assert tx_resp is not None + + ack = wallet_protocol.TransactionAck.from_bytes(tx_resp.data) + assert ack.error is None + assert ack.status == int(MempoolInclusionStatus.SUCCESS) + + return simulator, queue, peer + + +async def make_coin(full_node: FullNode) -> Tuple[Coin, bytes32]: + ph = IDENTITY_PUZZLE_HASH + coin = Coin(bytes32(b"\0" * 32), ph, uint64(1000)) + hint = bytes32(b"\0" * 32) + + height = full_node.blockchain.get_peak_height() + assert height is not None + + reward_1 = Coin(std_hash(b"reward 1"), std_hash(b"reward puzzle hash"), uint64(3000)) + reward_2 = Coin(std_hash(b"reward 2"), std_hash(b"reward puzzle hash"), uint64(4000)) + await full_node.coin_store.new_block(uint32(height + 1), uint64(200000), [reward_1, reward_2], [coin], []) + await full_node.hint_store.add_hints([(coin.name(), hint)]) + + return coin, hint + + +async def subscribe_coin( + simulator: FullNodeSimulator, coin_id: bytes32, peer: WSChiaConnection, *, existing_coin_states: int = 1 +) -> None: + genesis = simulator.full_node.blockchain.constants.GENESIS_CHALLENGE + assert genesis is not None + + msg = await simulator.request_coin_state(wallet_protocol.RequestCoinState([coin_id], None, genesis, True), 
peer) + assert msg is not None + + response = wallet_protocol.RespondCoinState.from_bytes(msg.data) + assert response.coin_ids == [coin_id] + assert len(response.coin_states) == existing_coin_states + + +async def subscribe_puzzle( + simulator: FullNodeSimulator, puzzle_hash: bytes32, peer: WSChiaConnection, *, existing_coin_states: int = 1 +) -> None: + genesis = simulator.full_node.blockchain.constants.GENESIS_CHALLENGE + assert genesis is not None + + msg = await simulator.request_puzzle_state( + wallet_protocol.RequestPuzzleState([puzzle_hash], None, genesis, ALL_FILTER, True), peer + ) + assert msg is not None + + response = wallet_protocol.RespondPuzzleState.from_bytes(msg.data) + assert response.puzzle_hashes == [puzzle_hash] + assert len(response.coin_states) == existing_coin_states + + +async def spend_coin(simulator: FullNodeSimulator, coin: Coin, solution: Optional[Program] = None) -> bytes32: + bundle = SpendBundle( + [CoinSpend(coin, IDENTITY_PUZZLE, Program.to([]) if solution is None else solution)], AugSchemeMPL.aggregate([]) + ) + tx_resp = await simulator.send_transaction(wallet_protocol.SendTransaction(bundle)) + assert tx_resp is not None + + ack = wallet_protocol.TransactionAck.from_bytes(tx_resp.data) + assert ack.error is None + assert ack.status == int(MempoolInclusionStatus.SUCCESS) + + transaction_id = bundle.name() + assert ack.txid == transaction_id + + return transaction_id + + +@pytest.mark.anyio +async def test_spent_coin_id_mempool_update(mpu_setup: Mpu) -> None: + simulator, queue, peer = mpu_setup + + # Make a coin and spend it + coin, _ = await make_coin(simulator.full_node) + await subscribe_coin(simulator, coin.name(), peer) + transaction_id = await spend_coin(simulator, coin) + + # We should have gotten a mempool update for this transaction + await assert_mempool_added(queue, {transaction_id}) + + # Check the mempool to make sure the transaction is there + await simulator.wait_bundle_ids_in_mempool([transaction_id]) + assert simulator.full_node.mempool_manager.mempool.get_item_by_id(transaction_id) is not None + + # The mempool item should now be in the initial update + await subscribe_coin(simulator, coin.name(), peer) + await assert_mempool_added(queue, {transaction_id}) + + # Farm a block and listen for the mempool removal event + await simulator.farm_blocks_to_puzzlehash(1) + await assert_mempool_removed( + queue, {wallet_protocol.RemovedMempoolItem(transaction_id, uint8(MempoolRemoveReason.BLOCK_INCLUSION.value))} + ) + + +@pytest.mark.anyio +async def test_spent_puzzle_hash_mempool_update(mpu_setup: Mpu) -> None: + simulator, queue, peer = mpu_setup + + # Make a coin and spend it + coin, _ = await make_coin(simulator.full_node) + await subscribe_puzzle(simulator, coin.puzzle_hash, peer) + transaction_id = await spend_coin(simulator, coin) + + # We should have gotten a mempool update for this transaction + await assert_mempool_added(queue, {transaction_id}) + + # Check the mempool to make sure the transaction is there + await simulator.wait_bundle_ids_in_mempool([transaction_id]) + assert simulator.full_node.mempool_manager.mempool.get_item_by_id(transaction_id) is not None + + # The mempool item should now be in the initial update + await subscribe_puzzle(simulator, coin.puzzle_hash, peer) + await assert_mempool_added(queue, {transaction_id}) + + # Farm a block and listen for the mempool removal event + await simulator.farm_blocks_to_puzzlehash(1) + await assert_mempool_removed( + queue, {wallet_protocol.RemovedMempoolItem(transaction_id, 
uint8(MempoolRemoveReason.BLOCK_INCLUSION.value))} + ) + + +@pytest.mark.anyio +async def test_spent_hint_mempool_update(mpu_setup: Mpu) -> None: + simulator, queue, peer = mpu_setup + + # Make a coin and spend it + coin, hint = await make_coin(simulator.full_node) + await subscribe_puzzle(simulator, hint, peer) + transaction_id = await spend_coin(simulator, coin) + + # We should have gotten a mempool update for this transaction + await assert_mempool_added(queue, {transaction_id}) + + # Check the mempool to make sure the transaction is there + await simulator.wait_bundle_ids_in_mempool([transaction_id]) + assert simulator.full_node.mempool_manager.mempool.get_item_by_id(transaction_id) is not None + + # The mempool item should now be in the initial update + await subscribe_puzzle(simulator, hint, peer) + await assert_mempool_added(queue, {transaction_id}) + + # Farm a block and listen for the mempool removal event + await simulator.farm_blocks_to_puzzlehash(1) + await assert_mempool_removed( + queue, {wallet_protocol.RemovedMempoolItem(transaction_id, uint8(MempoolRemoveReason.BLOCK_INCLUSION.value))} + ) + + +@pytest.mark.anyio +async def test_created_coin_id_mempool_update(mpu_setup: Mpu) -> None: + simulator, queue, peer = mpu_setup + + # Make a coin and spend it to create a child coin + coin, _ = await make_coin(simulator.full_node) + child_coin = Coin(coin.name(), std_hash(b"new puzzle hash"), coin.amount) + await subscribe_coin(simulator, child_coin.name(), peer, existing_coin_states=0) + transaction_id = await spend_coin( + simulator, coin, solution=Program.to([[51, child_coin.puzzle_hash, child_coin.amount]]) + ) + + # We should have gotten a mempool update for this transaction + await assert_mempool_added(queue, {transaction_id}) + + # Check the mempool to make sure the transaction is there + await simulator.wait_bundle_ids_in_mempool([transaction_id]) + assert simulator.full_node.mempool_manager.mempool.get_item_by_id(transaction_id) is not None + + # The mempool item should now be in the initial update + await subscribe_coin(simulator, child_coin.name(), peer, existing_coin_states=0) + await assert_mempool_added(queue, {transaction_id}) + + # Farm a block and listen for the mempool removal event + await simulator.farm_blocks_to_puzzlehash(1) + await assert_mempool_removed( + queue, {wallet_protocol.RemovedMempoolItem(transaction_id, uint8(MempoolRemoveReason.BLOCK_INCLUSION.value))} + ) + + +@pytest.mark.anyio +async def test_created_puzzle_hash_mempool_update(mpu_setup: Mpu) -> None: + simulator, queue, peer = mpu_setup + + # Make a coin and spend it to create a child coin + coin, _ = await make_coin(simulator.full_node) + child_coin = Coin(coin.name(), std_hash(b"new puzzle hash"), coin.amount) + await subscribe_puzzle(simulator, child_coin.puzzle_hash, peer, existing_coin_states=0) + transaction_id = await spend_coin( + simulator, coin, solution=Program.to([[51, child_coin.puzzle_hash, child_coin.amount]]) + ) + + # We should have gotten a mempool update for this transaction + await assert_mempool_added(queue, {transaction_id}) + + # Check the mempool to make sure the transaction is there + await simulator.wait_bundle_ids_in_mempool([transaction_id]) + assert simulator.full_node.mempool_manager.mempool.get_item_by_id(transaction_id) is not None + + # The mempool item should now be in the initial update + await subscribe_puzzle(simulator, child_coin.puzzle_hash, peer, existing_coin_states=0) + await assert_mempool_added(queue, {transaction_id}) + + # Farm a block and listen 
for the mempool removal event + await simulator.farm_blocks_to_puzzlehash(1) + await assert_mempool_removed( + queue, {wallet_protocol.RemovedMempoolItem(transaction_id, uint8(MempoolRemoveReason.BLOCK_INCLUSION.value))} + ) + + +@pytest.mark.anyio +async def test_created_hint_mempool_update(mpu_setup: Mpu) -> None: + simulator, queue, peer = mpu_setup + + # Make a coin and spend it to create a child coin + coin, _ = await make_coin(simulator.full_node) + child_coin = Coin(coin.name(), std_hash(b"new puzzle hash"), coin.amount) + hint = std_hash(b"new hint") + await subscribe_puzzle(simulator, hint, peer, existing_coin_states=0) + transaction_id = await spend_coin( + simulator, coin, solution=Program.to([[51, child_coin.puzzle_hash, child_coin.amount, [hint]]]) + ) + + # We should have gotten a mempool update for this transaction + await assert_mempool_added(queue, {transaction_id}) + + # Check the mempool to make sure the transaction is there + await simulator.wait_bundle_ids_in_mempool([transaction_id]) + assert simulator.full_node.mempool_manager.mempool.get_item_by_id(transaction_id) is not None + + # The mempool item should now be in the initial update + await subscribe_puzzle(simulator, hint, peer, existing_coin_states=0) + await assert_mempool_added(queue, {transaction_id}) + + # Farm a block and listen for the mempool removal event + await simulator.farm_blocks_to_puzzlehash(1) + await assert_mempool_removed( + queue, {wallet_protocol.RemovedMempoolItem(transaction_id, uint8(MempoolRemoveReason.BLOCK_INCLUSION.value))} + ) + + +@pytest.mark.anyio +async def test_missing_capability_coin_id(mpu_setup_no_capability: Mpu) -> None: + simulator, queue, peer = mpu_setup_no_capability + + # Make a coin and spend it + coin, _ = await make_coin(simulator.full_node) + await subscribe_coin(simulator, coin.name(), peer) + transaction_id = await spend_coin(simulator, coin) + + # There is no mempool update for this transaction since the peer doesn't have the capability + assert queue.empty() + + # Check the mempool to make sure the transaction is there + await simulator.wait_bundle_ids_in_mempool([transaction_id]) + assert simulator.full_node.mempool_manager.mempool.get_item_by_id(transaction_id) is not None + + # There is no initial mempool update since the peer doesn't have the capability + await subscribe_coin(simulator, coin.name(), peer) + assert queue.empty() + + # Farm a block and there's still no mempool update + await simulator.farm_blocks_to_puzzlehash(1) + assert queue.empty() + + +@pytest.mark.anyio +async def test_missing_capability_puzzle_hash(mpu_setup_no_capability: Mpu) -> None: + simulator, queue, peer = mpu_setup_no_capability + + # Make a coin and spend it + coin, _ = await make_coin(simulator.full_node) + await subscribe_puzzle(simulator, coin.puzzle_hash, peer) + transaction_id = await spend_coin(simulator, coin) + + # There is no mempool update for this transaction since the peer doesn't have the capability + assert queue.empty() + + # Check the mempool to make sure the transaction is there + await simulator.wait_bundle_ids_in_mempool([transaction_id]) + assert simulator.full_node.mempool_manager.mempool.get_item_by_id(transaction_id) is not None + + # There is no initial mempool update since the peer doesn't have the capability + await subscribe_puzzle(simulator, coin.puzzle_hash, peer) + assert queue.empty() + + # Farm a block and there's still no mempool update + await simulator.farm_blocks_to_puzzlehash(1) + assert queue.empty() + + +@pytest.mark.anyio +async def 
test_missing_capability_hint(mpu_setup_no_capability: Mpu) -> None: + simulator, queue, peer = mpu_setup_no_capability + + # Make a coin and spend it + coin, hint = await make_coin(simulator.full_node) + await subscribe_puzzle(simulator, hint, peer) + transaction_id = await spend_coin(simulator, coin) + + # There is no mempool update for this transaction since the peer doesn't have the capability + assert queue.empty() + + # Check the mempool to make sure the transaction is there + await simulator.wait_bundle_ids_in_mempool([transaction_id]) + assert simulator.full_node.mempool_manager.mempool.get_item_by_id(transaction_id) is not None + + # There is no initial mempool update since the peer doesn't have the capability + await subscribe_puzzle(simulator, hint, peer) + assert queue.empty() + + # Farm a block and there's still no mempool update + await simulator.farm_blocks_to_puzzlehash(1) + assert queue.empty() + + +@pytest.mark.anyio +async def test_cost_info(one_node: OneNode, self_hostname: str) -> None: + simulator, _, _ = await connect_to_simulator(one_node, self_hostname) + + msg = await simulator.request_cost_info(wallet_protocol.RequestCostInfo()) + assert msg is not None + + response = wallet_protocol.RespondCostInfo.from_bytes(msg.data) + mempool_manager = simulator.full_node.mempool_manager + assert response == wallet_protocol.RespondCostInfo( + max_transaction_cost=mempool_manager.max_tx_clvm_cost, + max_block_cost=mempool_manager.max_block_clvm_cost, + max_mempool_cost=uint64(mempool_manager.mempool_max_total_cost), + mempool_cost=uint64(mempool_manager.mempool._total_cost), + mempool_fee=uint64(mempool_manager.mempool._total_fee), + bump_fee_per_cost=uint8(mempool_manager.nonzero_fee_minimum_fpc), + ) diff --git a/chia/clvm/spend_sim.py b/chia/clvm/spend_sim.py index ea6654ac76b3..eba6e99d5804 100644 --- a/chia/clvm/spend_sim.py +++ b/chia/clvm/spend_sim.py @@ -343,10 +343,10 @@ async def push_tx(self, spend_bundle: SpendBundle) -> Tuple[MempoolInclusionStat except ValidationError as e: return MempoolInclusionStatus.FAILED, e.code assert self.service.mempool_manager.peak is not None - cost, status, error = await self.service.mempool_manager.add_spend_bundle( + info = await self.service.mempool_manager.add_spend_bundle( spend_bundle, cost_result, spend_bundle_id, self.service.mempool_manager.peak.height ) - return status, error + return info.status, info.error async def get_coin_record_by_name(self, name: bytes32) -> Optional[CoinRecord]: return await self.service.coin_store.get_coin_record(name) diff --git a/chia/full_node/full_node.py b/chia/full_node/full_node.py index c5e178f069ed..2e9a09a39cea 100644 --- a/chia/full_node/full_node.py +++ b/chia/full_node/full_node.py @@ -50,9 +50,10 @@ from chia.full_node.full_node_store import FullNodeStore, FullNodeStorePeakResult, UnfinishedBlockEntry from chia.full_node.hint_management import get_hints_and_subscription_coin_ids from chia.full_node.hint_store import HintStore -from chia.full_node.mempool_manager import MempoolManager +from chia.full_node.mempool import MempoolRemoveInfo +from chia.full_node.mempool_manager import MempoolManager, NewPeakItem from chia.full_node.signage_point import SignagePoint -from chia.full_node.subscriptions import PeerSubscriptions +from chia.full_node.subscriptions import PeerSubscriptions, peers_for_spend_bundle from chia.full_node.sync_store import Peak, SyncStore from chia.full_node.tx_processing_queue import TransactionQueue from chia.full_node.weight_proof import WeightProofHandler @@ -60,7 
+61,8 @@ from chia.protocols.farmer_protocol import SignagePointSourceData, SPSubSlotSourceData, SPVDFSourceData from chia.protocols.full_node_protocol import RequestBlocks, RespondBlock, RespondBlocks, RespondSignagePoint from chia.protocols.protocol_message_types import ProtocolMessageTypes -from chia.protocols.wallet_protocol import CoinState, CoinStateUpdate +from chia.protocols.shared_protocol import Capability +from chia.protocols.wallet_protocol import CoinState, CoinStateUpdate, RemovedMempoolItem from chia.rpc.rpc_server import StateChangedProtocol from chia.server.node_discovery import FullNodePeers from chia.server.outbound_message import Message, NodeType, make_msg @@ -77,6 +79,7 @@ from chia.types.generator_types import BlockGenerator from chia.types.header_block import HeaderBlock from chia.types.mempool_inclusion_status import MempoolInclusionStatus +from chia.types.mempool_item import MempoolItem from chia.types.peer_info import PeerInfo from chia.types.spend_bundle import SpendBundle from chia.types.transaction_queue_entry import TransactionQueueEntry @@ -102,7 +105,8 @@ # This is the result of calling peak_post_processing, which is then fed into peak_post_processing_2 @dataclasses.dataclass class PeakPostProcessingResult: - mempool_peak_result: List[Tuple[SpendBundle, NPCResult, bytes32]] # The result of calling MempoolManager.new_peak + mempool_peak_result: List[NewPeakItem] # The new items from calling MempoolManager.new_peak + mempool_removals: List[MempoolRemoveInfo] # The removed mempool items from calling MempoolManager.new_peak fns_peak_result: FullNodeStorePeakResult # The result of calling FullNodeStore.new_peak hints: List[Tuple[bytes32, bytes]] # The hints added to the DB lookup_coin_ids: List[bytes32] # The coin IDs that we need to look up to notify wallets of changes @@ -319,7 +323,7 @@ async def manage(self) -> AsyncIterator[None]: ) async with self.blockchain.priority_mutex.acquire(priority=BlockchainMutexPriority.high): pending_tx = await self.mempool_manager.new_peak(self.blockchain.get_tx_peak(), None) - assert len(pending_tx) == 0 # no pending transactions when starting up + assert len(pending_tx.items) == 0 # no pending transactions when starting up full_peak: Optional[FullBlock] = await self.blockchain.get_full_peak() assert full_peak is not None @@ -1544,7 +1548,6 @@ async def peak_post_processing( # Update the mempool (returns successful pending transactions added to the mempool) spent_coins: List[bytes32] = [coin_id for coin_id, _ in state_change_summary.removals] - mempool_new_peak_result: List[Tuple[SpendBundle, NPCResult, bytes32]] mempool_new_peak_result = await self.mempool_manager.new_peak(self.blockchain.get_tx_peak(), spent_coins) # Check if we detected a spent transaction, to load up our generator cache @@ -1554,7 +1557,13 @@ async def peak_post_processing( self.log.info(f"Saving previous generator for height {block.height}") self.full_node_store.previous_generator = generator_arg - return PeakPostProcessingResult(mempool_new_peak_result, fns_peak_result, hints_to_add, lookup_coin_ids) + return PeakPostProcessingResult( + mempool_new_peak_result.items, + mempool_new_peak_result.removals, + fns_peak_result, + hints_to_add, + lookup_coin_ids, + ) async def peak_post_processing_2( self, @@ -1568,20 +1577,11 @@ async def peak_post_processing_2( with peers """ record = state_change_summary.peak - for bundle, result, spend_name in ppp_result.mempool_peak_result: - self.log.debug(f"Added transaction to mempool: {spend_name}") - mempool_item = 
self.mempool_manager.get_mempool_item(spend_name) + for new_peak_item in ppp_result.mempool_peak_result: + self.log.debug(f"Added transaction to mempool: {new_peak_item.transaction_id}") + mempool_item = self.mempool_manager.get_mempool_item(new_peak_item.transaction_id) assert mempool_item is not None - fees = mempool_item.fee - assert fees >= 0 - assert mempool_item.cost is not None - new_tx = full_node_protocol.NewTransaction( - spend_name, - mempool_item.cost, - fees, - ) - msg = make_msg(ProtocolMessageTypes.new_transaction, new_tx) - await self.server.send_to_all([msg], NodeType.FULL_NODE) + await self.broadcast_added_tx(mempool_item) # If there were pending end of slots that happen after this peak, broadcast them if they are added if ppp_result.fns_peak_result.added_eos is not None: @@ -1602,6 +1602,7 @@ async def peak_post_processing_2( if self.sync_store.get_sync_mode() is False: await self.send_peak_to_timelords(block) + await self.broadcast_removed_tx(ppp_result.mempool_removals) # Tell full nodes about the new peak msg = make_msg( @@ -2356,9 +2357,11 @@ async def add_transaction( return MempoolInclusionStatus.SUCCESS, None if self.mempool_manager.peak is None: return MempoolInclusionStatus.FAILED, Err.MEMPOOL_NOT_INITIALIZED - cost, status, error = await self.mempool_manager.add_spend_bundle( + info = await self.mempool_manager.add_spend_bundle( transaction, cost_result, spend_name, self.mempool_manager.peak.height ) + status = info.status + error = info.error if status == MempoolInclusionStatus.SUCCESS: self.log.debug( f"Added transaction to mempool: {spend_name} mempool size: " @@ -2370,19 +2373,8 @@ async def add_transaction( # vector. mempool_item = self.mempool_manager.get_mempool_item(spend_name) assert mempool_item is not None - fees = mempool_item.fee - assert fees >= 0 - assert cost is not None - new_tx = full_node_protocol.NewTransaction( - spend_name, - cost, - fees, - ) - msg = make_msg(ProtocolMessageTypes.new_transaction, new_tx) - if peer is None: - await self.server.send_to_all([msg], NodeType.FULL_NODE) - else: - await self.server.send_to_all([msg], NodeType.FULL_NODE, peer.peer_node_id) + await self.broadcast_removed_tx(info.removals) + await self.broadcast_added_tx(mempool_item, current_peer=peer) if self.simulator_transaction_callback is not None: # callback await self.simulator_transaction_callback(spend_name) # pylint: disable=E1102 @@ -2392,6 +2384,124 @@ async def add_transaction( self.log.debug(f"Wasn't able to add transaction with id {spend_name}, status {status} error: {error}") return status, error + async def broadcast_added_tx( + self, mempool_item: MempoolItem, current_peer: Optional[WSChiaConnection] = None + ) -> None: + assert mempool_item.fee >= 0 + assert mempool_item.cost is not None + + new_tx = full_node_protocol.NewTransaction( + mempool_item.name, + mempool_item.cost, + mempool_item.fee, + ) + msg = make_msg(ProtocolMessageTypes.new_transaction, new_tx) + if current_peer is None: + await self.server.send_to_all([msg], NodeType.FULL_NODE) + else: + await self.server.send_to_all([msg], NodeType.FULL_NODE, current_peer.peer_node_id) + + conds = mempool_item.npc_result.conds + assert conds is not None + + all_peers = { + peer_id + for peer_id, peer in self.server.all_connections.items() + if peer.has_capability(Capability.MEMPOOL_UPDATES) + } + + if len(all_peers) == 0: + return + + start_time = time.monotonic() + + hints_for_removals = await self.hint_store.get_hints([bytes32(spend.coin_id) for spend in conds.spends]) + peer_ids = 
all_peers.intersection(peers_for_spend_bundle(self.subscriptions, conds, set(hints_for_removals))) + + for peer_id in peer_ids: + peer = self.server.all_connections.get(peer_id) + + if peer is None: + continue + + msg = make_msg( + ProtocolMessageTypes.mempool_items_added, wallet_protocol.MempoolItemsAdded([mempool_item.name]) + ) + await peer.send_message(msg) + + total_time = time.monotonic() - start_time + + self.log.log( + logging.DEBUG if total_time < 0.5 else logging.WARNING, + f"Broadcasting added transaction {mempool_item.name} to {len(peer_ids)} peers took {total_time:.4f}s", + ) + + async def broadcast_removed_tx(self, mempool_removals: List[MempoolRemoveInfo]) -> None: + total_removals = sum([len(r.items) for r in mempool_removals]) + if total_removals == 0: + return + + start_time = time.monotonic() + + self.log.debug(f"Broadcasting {total_removals} removed transactions to peers") + + all_peers = { + peer_id + for peer_id, peer in self.server.all_connections.items() + if peer.has_capability(Capability.MEMPOOL_UPDATES) + } + + if len(all_peers) == 0: + return + + removals_to_send: Dict[bytes32, List[RemovedMempoolItem]] = dict() + + for removal_info in mempool_removals: + for internal_mempool_item in removal_info.items: + conds = internal_mempool_item.npc_result.conds + assert conds is not None + + hints_for_removals = await self.hint_store.get_hints([bytes32(spend.coin_id) for spend in conds.spends]) + peer_ids = all_peers.intersection( + peers_for_spend_bundle(self.subscriptions, conds, set(hints_for_removals)) + ) + + if len(peer_ids) == 0: + continue + + transaction_id = internal_mempool_item.spend_bundle.name() + + self.log.debug(f"Broadcasting removed transaction {transaction_id} to " f"wallet peers {peer_ids}") + + for peer_id in peer_ids: + peer = self.server.all_connections.get(peer_id) + + if peer is None: + continue + + removal = wallet_protocol.RemovedMempoolItem(transaction_id, uint8(removal_info.reason.value)) + removals_to_send.setdefault(peer.peer_node_id, []).append(removal) + + for peer_id, removals in removals_to_send.items(): + peer = self.server.all_connections.get(peer_id) + + if peer is None: + continue + + msg = make_msg( + ProtocolMessageTypes.mempool_items_removed, + wallet_protocol.MempoolItemsRemoved(removals), + ) + await peer.send_message(msg) + + total_time = time.monotonic() - start_time + + self.log.log( + logging.DEBUG if total_time < 0.5 else logging.WARNING, + f"Broadcasting {total_removals} removed transactions " + f"to {len(removals_to_send)} peers took {total_time:.4f}s", + ) + async def _needs_compact_proof( self, vdf_info: VDFInfo, header_block: HeaderBlock, field_vdf: CompressibleVDFField ) -> bool: diff --git a/chia/full_node/full_node_api.py b/chia/full_node/full_node_api.py index 20533b4964ea..eaf308610cbe 100644 --- a/chia/full_node/full_node_api.py +++ b/chia/full_node/full_node_api.py @@ -31,6 +31,7 @@ from chia.protocols import farmer_protocol, full_node_protocol, introducer_protocol, timelord_protocol, wallet_protocol from chia.protocols.full_node_protocol import RejectBlock, RejectBlocks from chia.protocols.protocol_message_types import ProtocolMessageTypes +from chia.protocols.shared_protocol import Capability from chia.protocols.wallet_protocol import ( CoinState, PuzzleSolutionResponse, @@ -61,11 +62,13 @@ from chia.types.transaction_queue_entry import TransactionQueueEntry from chia.types.unfinished_block import UnfinishedBlock from chia.util.api_decorators import api_request +from chia.util.db_wrapper import 
SQLITE_MAX_VARIABLE_NUMBER from chia.util.full_block_utils import header_block_from_block from chia.util.generator_tools import get_block_header, tx_removals_and_additions from chia.util.hash import std_hash from chia.util.ints import uint8, uint32, uint64, uint128 from chia.util.limited_semaphore import LimitedSemaphoreFullError +from chia.util.misc import to_batches if TYPE_CHECKING: from chia.full_node.full_node import FullNode @@ -1827,6 +1830,7 @@ def check_subscription_limit() -> Optional[Message]: if is_done and request.subscribe_when_finished: subs.add_puzzle_subscriptions(peer.peer_node_id, puzzle_hashes, max_subscriptions) + await self.mempool_updates_for_puzzle_hashes(peer, set(puzzle_hashes), request.filters.include_hinted) response = wallet_protocol.RespondPuzzleState(puzzle_hashes, height, header_hash, is_done, coin_states) msg = make_msg(ProtocolMessageTypes.respond_puzzle_state, response) @@ -1884,11 +1888,84 @@ def check_subscription_limit() -> Optional[Message]: if request.subscribe: subs.add_coin_subscriptions(peer.peer_node_id, coin_ids, max_subscriptions) + await self.mempool_updates_for_coin_ids(peer, set(coin_ids)) response = wallet_protocol.RespondCoinState(coin_ids, coin_states) msg = make_msg(ProtocolMessageTypes.respond_coin_state, response) return msg + @api_request(reply_types=[ProtocolMessageTypes.respond_cost_info]) + async def request_cost_info(self, _request: wallet_protocol.RequestCostInfo) -> Optional[Message]: + mempool_manager = self.full_node.mempool_manager + response = wallet_protocol.RespondCostInfo( + max_transaction_cost=mempool_manager.max_tx_clvm_cost, + max_block_cost=mempool_manager.max_block_clvm_cost, + max_mempool_cost=uint64(mempool_manager.mempool_max_total_cost), + mempool_cost=uint64(mempool_manager.mempool._total_cost), + mempool_fee=uint64(mempool_manager.mempool._total_fee), + bump_fee_per_cost=uint8(mempool_manager.nonzero_fee_minimum_fpc), + ) + msg = make_msg(ProtocolMessageTypes.respond_cost_info, response) + return msg + + async def mempool_updates_for_puzzle_hashes( + self, peer: WSChiaConnection, puzzle_hashes: Set[bytes32], include_hints: bool + ) -> None: + if Capability.MEMPOOL_UPDATES not in peer.peer_capabilities: + return + + start_time = time.monotonic() + + async with self.full_node.db_wrapper.reader() as conn: + transaction_ids = set( + self.full_node.mempool_manager.mempool.items_with_puzzle_hashes(puzzle_hashes, include_hints) + ) + + hinted_coin_ids: Set[bytes32] = set() + + for batch in to_batches(puzzle_hashes, SQLITE_MAX_VARIABLE_NUMBER): + hints_db: Tuple[bytes, ...] 
= tuple(batch.entries) + cursor = await conn.execute( + f"SELECT coin_id from hints INDEXED BY hint_index " + f'WHERE hint IN ({"?," * (len(batch.entries) - 1)}?)', + hints_db, + ) + for row in await cursor.fetchall(): + hinted_coin_ids.add(bytes32(row[0])) + await cursor.close() + + transaction_ids |= set(self.full_node.mempool_manager.mempool.items_with_coin_ids(hinted_coin_ids)) + + if len(transaction_ids) > 0: + message = wallet_protocol.MempoolItemsAdded(list(transaction_ids)) + await peer.send_message(make_msg(ProtocolMessageTypes.mempool_items_added, message)) + + total_time = time.monotonic() - start_time + + self.log.log( + logging.DEBUG if total_time < 2.0 else logging.WARNING, + f"Sending initial mempool items to peer {peer.peer_node_id} took {total_time:.4f}s", + ) + + async def mempool_updates_for_coin_ids(self, peer: WSChiaConnection, coin_ids: Set[bytes32]) -> None: + if Capability.MEMPOOL_UPDATES not in peer.peer_capabilities: + return + + start_time = time.monotonic() + + transaction_ids = self.full_node.mempool_manager.mempool.items_with_coin_ids(coin_ids) + + if len(transaction_ids) > 0: + message = wallet_protocol.MempoolItemsAdded(list(transaction_ids)) + await peer.send_message(make_msg(ProtocolMessageTypes.mempool_items_added, message)) + + total_time = time.monotonic() - start_time + + self.log.log( + logging.DEBUG if total_time < 2.0 else logging.WARNING, + f"Sending initial mempool items to peer {peer.peer_node_id} took {total_time:.4f}s", + ) + def max_subscriptions(self, peer: WSChiaConnection) -> int: if self.is_trusted(peer): return cast(int, self.full_node.config.get("trusted_max_subscribe_items", 2000000)) diff --git a/chia/full_node/hint_store.py b/chia/full_node/hint_store.py index 5adc45246058..01385e8a3c2c 100644 --- a/chia/full_node/hint_store.py +++ b/chia/full_node/hint_store.py @@ -56,6 +56,22 @@ async def get_coin_ids_multi(self, hints: Set[bytes], *, max_items: int = 50000) return coin_ids + async def get_hints(self, coin_ids: List[bytes32]) -> List[bytes32]: + hints: List[bytes32] = [] + + async with self.db_wrapper.reader_no_transaction() as conn: + for batch in to_batches(coin_ids, SQLITE_MAX_VARIABLE_NUMBER): + coin_ids_db: Tuple[bytes32, ...] 
= tuple(batch.entries) + cursor = await conn.execute( + f'SELECT hint from hints WHERE coin_id IN ({"?," * (len(batch.entries) - 1)}?)', + coin_ids_db, + ) + rows = await cursor.fetchall() + hints.extend([bytes32(row[0]) for row in rows if len(row[0]) == 32]) + await cursor.close() + + return hints + async def add_hints(self, coin_hint_list: List[Tuple[bytes32, bytes]]) -> None: if len(coin_hint_list) == 0: return None diff --git a/chia/full_node/mempool.py b/chia/full_node/mempool.py index 13ea7ef415fb..df10af5c06cc 100644 --- a/chia/full_node/mempool.py +++ b/chia/full_node/mempool.py @@ -2,9 +2,10 @@ import logging import sqlite3 +from dataclasses import dataclass from datetime import datetime from enum import Enum -from typing import Awaitable, Callable, Dict, Iterator, List, Optional, Tuple +from typing import Awaitable, Callable, Dict, Iterator, List, Optional, Set, Tuple from chia_rs import AugSchemeMPL, Coin, G2Element @@ -46,6 +47,18 @@ MEMPOOL_ITEM_FEE_LIMIT = 2**50 +@dataclass +class MempoolRemoveInfo: + items: List[InternalMempoolItem] + reason: MempoolRemoveReason + + +@dataclass +class MempoolAddInfo: + removals: List[MempoolRemoveInfo] + error: Optional[Err] + + class MempoolRemoveReason(Enum): CONFLICT = 1 BLOCK_INCLUSION = 2 @@ -155,6 +168,63 @@ def all_item_ids(self) -> List[bytes32]: cursor = self._db_conn.execute("SELECT name FROM tx") return [bytes32(row[0]) for row in cursor] + def items_with_coin_ids(self, coin_ids: Set[bytes32]) -> List[bytes32]: + """ + Returns a list of transaction ids that spend or create any coins with the provided coin ids. + This iterates over the internal items instead of using a query. + """ + + transaction_ids: List[bytes32] = [] + + for transaction_id, item in self._items.items(): + conds = item.npc_result.conds + assert conds is not None + + for spend in conds.spends: + if spend.coin_id in coin_ids: + transaction_ids.append(transaction_id) + break + + for puzzle_hash, amount, _memo in spend.create_coin: + if Coin(spend.coin_id, puzzle_hash, uint64(amount)).name() in coin_ids: + transaction_ids.append(transaction_id) + break + else: + continue + + break + + return transaction_ids + + def items_with_puzzle_hashes(self, puzzle_hashes: Set[bytes32], include_hints: bool) -> List[bytes32]: + """ + Returns a list of transaction ids that spend or create any coins + with the provided puzzle hashes (or hints, if enabled). + This iterates over the internal items instead of using a query. + """ + + transaction_ids: List[bytes32] = [] + + for transaction_id, item in self._items.items(): + conds = item.npc_result.conds + assert conds is not None + + for spend in conds.spends: + if spend.puzzle_hash in puzzle_hashes: + transaction_ids.append(transaction_id) + break + + for puzzle_hash, _amount, memo in spend.create_coin: + if puzzle_hash in puzzle_hashes or (include_hints and memo is not None and memo in puzzle_hashes): + transaction_ids.append(transaction_id) + break + else: + continue + + break + + return transaction_ids + # TODO: move "process_mempool_items()" into this class in order to do this a # bit more efficiently def items_by_feerate(self) -> Iterator[MempoolItem]: @@ -224,7 +294,7 @@ def get_min_fee_rate(self, cost: int) -> Optional[float]: ) return None - def new_tx_block(self, block_height: uint32, timestamp: uint64) -> None: + def new_tx_block(self, block_height: uint32, timestamp: uint64) -> MempoolRemoveInfo: """ Remove all items that became invalid because of this new height and timestamp. 
(we don't know about which coins were spent in this new block @@ -237,16 +307,17 @@ def new_tx_block(self, block_height: uint32, timestamp: uint64) -> None: ) to_remove = [bytes32(row[0]) for row in cursor] - self.remove_from_pool(to_remove, MempoolRemoveReason.EXPIRED) self._block_height = block_height self._timestamp = timestamp - def remove_from_pool(self, items: List[bytes32], reason: MempoolRemoveReason) -> None: + return self.remove_from_pool(to_remove, MempoolRemoveReason.EXPIRED) + + def remove_from_pool(self, items: List[bytes32], reason: MempoolRemoveReason) -> MempoolRemoveInfo: """ Removes an item from the mempool. """ if items == []: - return + return MempoolRemoveInfo([], reason) removed_items: List[MempoolItemInfo] = [] if reason != MempoolRemoveReason.BLOCK_INCLUSION: @@ -262,8 +333,7 @@ def remove_from_pool(self, items: List[bytes32], reason: MempoolRemoveReason) -> item = MempoolItemInfo(int(row[1]), int(row[2]), internal_item.height_added_to_mempool) removed_items.append(item) - for name in items: - self._items.pop(name) + removed_internal_items = [self._items.pop(name) for name in items] for batch in to_batches(items, SQLITE_MAX_VARIABLE_NUMBER): args = ",".join(["?"] * len(batch.entries)) @@ -288,7 +358,9 @@ def remove_from_pool(self, items: List[bytes32], reason: MempoolRemoveReason) -> for iteminfo in removed_items: self.fee_estimator.remove_mempool_item(info, iteminfo) - def add_to_pool(self, item: MempoolItem) -> Optional[Err]: + return MempoolRemoveInfo(removed_internal_items, reason) + + def add_to_pool(self, item: MempoolItem) -> MempoolAddInfo: """ Adds an item to the mempool by kicking out transactions (if it doesn't fit), in order of increasing fee per cost """ @@ -297,6 +369,8 @@ def add_to_pool(self, item: MempoolItem) -> Optional[Err]: assert item.npc_result.conds is not None assert item.cost <= self.mempool_info.max_block_clvm_cost + removals: List[MempoolRemoveInfo] = [] + with self._db_conn: # we have certain limits on transactions that will expire soon # (in the next 15 minutes) @@ -331,10 +405,10 @@ def add_to_pool(self, item: MempoolItem) -> Optional[Err]: # we can't evict any more transactions, abort (and don't # evict what we put aside in "to_remove" list) if fee_per_cost > item.fee_per_cost: - return Err.INVALID_FEE_LOW_FEE + return MempoolAddInfo([], Err.INVALID_FEE_LOW_FEE) to_remove.append(name) - self.remove_from_pool(to_remove, MempoolRemoveReason.EXPIRED) + removals.append(self.remove_from_pool(to_remove, MempoolRemoveReason.EXPIRED)) # if we don't find any entries, it's OK to add this entry @@ -353,7 +427,7 @@ def add_to_pool(self, item: MempoolItem) -> Optional[Err]: ) to_remove = [bytes32(row[0]) for row in cursor] - self.remove_from_pool(to_remove, MempoolRemoveReason.POOL_FULL) + removals.append(self.remove_from_pool(to_remove, MempoolRemoveReason.POOL_FULL)) # TODO: In the future, for the "fee_per_cost" field, opt for # "GENERATED ALWAYS AS (CAST(fee AS REAL) / cost) VIRTUAL" @@ -384,7 +458,7 @@ def add_to_pool(self, item: MempoolItem) -> Optional[Err]: info = FeeMempoolInfo(self.mempool_info, self.total_mempool_cost(), self.total_mempool_fees(), datetime.now()) self.fee_estimator.add_mempool_item(info, MempoolItemInfo(item.cost, item.fee, item.height_added_to_mempool)) - return None + return MempoolAddInfo(removals, None) def at_full_capacity(self, cost: int) -> bool: """ diff --git a/chia/full_node/mempool_manager.py b/chia/full_node/mempool_manager.py index 79569bd70088..9ce593809397 100644 --- a/chia/full_node/mempool_manager.py 
+++ b/chia/full_node/mempool_manager.py @@ -19,7 +19,7 @@ from chia.full_node.bundle_tools import simple_solution_generator from chia.full_node.fee_estimation import FeeBlockInfo, MempoolInfo, MempoolItemInfo from chia.full_node.fee_estimator_interface import FeeEstimatorInterface -from chia.full_node.mempool import MEMPOOL_ITEM_FEE_LIMIT, Mempool, MempoolRemoveReason +from chia.full_node.mempool import MEMPOOL_ITEM_FEE_LIMIT, Mempool, MempoolRemoveInfo, MempoolRemoveReason from chia.full_node.mempool_check_conditions import get_name_puzzle_conditions, mempool_check_time_locks from chia.full_node.pending_tx_cache import ConflictTxCache, PendingTxCache from chia.types.blockchain_format.coin import Coin @@ -146,6 +146,27 @@ def compute_assert_height( return ret +@dataclass +class SpendBundleAddInfo: + cost: Optional[uint64] + status: MempoolInclusionStatus + removals: List[MempoolRemoveInfo] + error: Optional[Err] + + +@dataclass +class NewPeakInfo: + items: List[NewPeakItem] + removals: List[MempoolRemoveInfo] + + +@dataclass +class NewPeakItem: + transaction_id: bytes32 + spend_bundle: SpendBundle + npc_result: NPCResult + + class MempoolManager: pool: Executor constants: ConsensusConstants @@ -335,7 +356,7 @@ async def add_spend_bundle( spend_name: bytes32, first_added_height: uint32, get_coin_records: Optional[Callable[[Collection[bytes32]], Awaitable[List[CoinRecord]]]] = None, - ) -> Tuple[Optional[uint64], MempoolInclusionStatus, Optional[Err]]: + ) -> SpendBundleAddInfo: """ Validates and adds to mempool a new_spend with the given NPCResult, and spend_name, and the current mempool. The mempool should be locked during this call (blockchain lock). If there are mempool conflicts, the conflicting @@ -350,13 +371,14 @@ async def add_spend_bundle( Returns: Optional[uint64]: cost of the entire transaction, None iff status is FAILED MempoolInclusionStatus: SUCCESS (should add to pool), FAILED (cannot add), and PENDING (can add later) + List[MempoolRemoveInfo]: conflicting mempool items which were removed, if no Err Optional[Err]: Err is set iff status is FAILED """ # Skip if already added existing_item = self.mempool.get_item_by_id(spend_name) if existing_item is not None: - return existing_item.cost, MempoolInclusionStatus.SUCCESS, None + return SpendBundleAddInfo(existing_item.cost, MempoolInclusionStatus.SUCCESS, [], None) if get_coin_records is None: get_coin_records = self.get_coin_records @@ -370,24 +392,24 @@ async def add_spend_bundle( if err is None: # No error, immediately add to mempool, after removing conflicting TXs. 
assert item is not None - self.mempool.remove_from_pool(remove_items, MempoolRemoveReason.CONFLICT) - err = self.mempool.add_to_pool(item) - if err is not None: - return item.cost, MempoolInclusionStatus.FAILED, err - return item.cost, MempoolInclusionStatus.SUCCESS, None + conflict = self.mempool.remove_from_pool(remove_items, MempoolRemoveReason.CONFLICT) + info = self.mempool.add_to_pool(item) + if info.error is not None: + return SpendBundleAddInfo(item.cost, MempoolInclusionStatus.FAILED, [], info.error) + return SpendBundleAddInfo(item.cost, MempoolInclusionStatus.SUCCESS, info.removals + [conflict], None) elif err is Err.MEMPOOL_CONFLICT and item is not None: # The transaction has a conflict with another item in the # mempool, put it aside and re-try it later self._conflict_cache.add(item) - return item.cost, MempoolInclusionStatus.PENDING, err + return SpendBundleAddInfo(item.cost, MempoolInclusionStatus.PENDING, [], err) elif item is not None: # The transasction has a height assertion and is not yet valid. # remember it to try it again later self._pending_cache.add(item) - return item.cost, MempoolInclusionStatus.PENDING, err + return SpendBundleAddInfo(item.cost, MempoolInclusionStatus.PENDING, [], err) else: # Cannot add to the mempool or pending pool. - return None, MempoolInclusionStatus.FAILED, err + return SpendBundleAddInfo(None, MempoolInclusionStatus.FAILED, [], err) async def validate_spend_bundle( self, @@ -655,7 +677,7 @@ def get_mempool_item(self, bundle_hash: bytes32, include_pending: bool = False) async def new_peak( self, new_peak: Optional[BlockRecordProtocol], spent_coins: Optional[List[bytes32]] - ) -> List[Tuple[SpendBundle, NPCResult, bytes32]]: + ) -> NewPeakInfo: """ Called when a new peak is available, we try to recreate a mempool for the new tip. new_peak should always be the most recent *transaction* block of the chain. Since @@ -665,17 +687,18 @@ async def new_peak( block. """ if new_peak is None: - return [] + return NewPeakInfo([], []) # we're only interested in transaction blocks if new_peak.is_transaction_block is False: - return [] + return NewPeakInfo([], []) if self.peak == new_peak: - return [] + return NewPeakInfo([], []) assert new_peak.timestamp is not None self.fee_estimator.new_block_height(new_peak.height) included_items: List[MempoolItemInfo] = [] - self.mempool.new_tx_block(new_peak.height, new_peak.timestamp) + expired = self.mempool.new_tx_block(new_peak.height, new_peak.timestamp) + mempool_item_removals: List[MempoolRemoveInfo] = [expired] use_optimization: bool = self.peak is not None and new_peak.prev_transaction_block_hash == self.peak.header_hash self.peak = new_peak @@ -693,7 +716,9 @@ async def new_peak( included_items.append(MempoolItemInfo(item.cost, item.fee, item.height_added_to_mempool)) self.remove_seen(item.name) spendbundle_ids_to_remove.add(item.name) - self.mempool.remove_from_pool(list(spendbundle_ids_to_remove), MempoolRemoveReason.BLOCK_INCLUSION) + mempool_item_removals.append( + self.mempool.remove_from_pool(list(spendbundle_ids_to_remove), MempoolRemoveReason.BLOCK_INCLUSION) + ) else: log.warning( "updating the mempool using the slow-path. 
" @@ -727,7 +752,7 @@ async def local_get_coin_records(names: Collection[bytes32]) -> List[CoinRecord] return ret for item in old_pool.all_items(): - _, result, err = await self.add_spend_bundle( + info = await self.add_spend_bundle( item.spend_bundle, item.npc_result, item.spend_bundle_name, @@ -735,11 +760,11 @@ async def local_get_coin_records(names: Collection[bytes32]) -> List[CoinRecord] local_get_coin_records, ) # Only add to `seen` if inclusion worked, so it can be resubmitted in case of a reorg - if result == MempoolInclusionStatus.SUCCESS: + if info.status == MempoolInclusionStatus.SUCCESS: self.add_and_maybe_pop_seen(item.spend_bundle_name) # If the spend bundle was confirmed or conflicting (can no longer be in mempool), it won't be # successfully added to the new mempool. - if result == MempoolInclusionStatus.FAILED and err == Err.DOUBLE_SPEND: + if info.status == MempoolInclusionStatus.FAILED and info.error == Err.DOUBLE_SPEND: # Item was in mempool, but after the new block it's a double spend. # Item is most likely included in the block. included_items.append(MempoolItemInfo(item.cost, item.fee, item.height_added_to_mempool)) @@ -748,22 +773,23 @@ async def local_get_coin_records(names: Collection[bytes32]) -> List[CoinRecord] potential_txs.update(self._conflict_cache.drain()) txs_added = [] for item in potential_txs.values(): - cost, status, error = await self.add_spend_bundle( + info = await self.add_spend_bundle( item.spend_bundle, item.npc_result, item.spend_bundle_name, item.height_added_to_mempool, self.get_coin_records, ) - if status == MempoolInclusionStatus.SUCCESS: - txs_added.append((item.spend_bundle, item.npc_result, item.spend_bundle_name)) + if info.status == MempoolInclusionStatus.SUCCESS: + txs_added.append(NewPeakItem(item.spend_bundle_name, item.spend_bundle, item.npc_result)) + mempool_item_removals.extend(info.removals) log.info( f"Size of mempool: {self.mempool.size()} spends, " f"cost: {self.mempool.total_mempool_cost()} " f"minimum fee rate (in FPC) to get in for 5M cost tx: {self.mempool.get_min_fee_rate(5000000)}" ) self.mempool.fee_estimator.new_block(FeeBlockInfo(new_peak.height, included_items)) - return txs_added + return NewPeakInfo(txs_added, mempool_item_removals) def get_items_not_in_filter(self, mempool_filter: PyBIP158, limit: int = 100) -> List[SpendBundle]: items: List[SpendBundle] = [] diff --git a/chia/full_node/subscriptions.py b/chia/full_node/subscriptions.py index a4c7d9bc5254..46db31bb177c 100644 --- a/chia/full_node/subscriptions.py +++ b/chia/full_node/subscriptions.py @@ -4,7 +4,11 @@ from dataclasses import dataclass, field from typing import Dict, List, Set +from chia_rs import Coin + from chia.types.blockchain_format.sized_bytes import bytes32 +from chia.types.spend_bundle_conditions import SpendBundleConditions +from chia.util.ints import uint64 log = logging.getLogger(__name__) @@ -208,3 +212,37 @@ def coin_subscription_count(self) -> int: def puzzle_subscription_count(self) -> int: return self._puzzle_subscriptions.total_count() + + +def peers_for_spend_bundle( + peer_subscriptions: PeerSubscriptions, conds: SpendBundleConditions, hints_for_removals: Set[bytes32] +) -> Set[bytes32]: + """ + Returns a list of peer ids that are subscribed to any of the created or + spent coins, puzzle hashes, or hints in the spend bundle. To avoid repeated + lookups, `hints_for_removals` should be a set of all puzzle hashes that are being removed. 
+ """ + + coin_ids: Set[bytes32] = set() + puzzle_hashes: Set[bytes32] = hints_for_removals.copy() + + for spend in conds.spends: + coin_ids.add(bytes32(spend.coin_id)) + puzzle_hashes.add(bytes32(spend.puzzle_hash)) + + for puzzle_hash, amount, memo in spend.create_coin: + coin_ids.add(Coin(spend.coin_id, puzzle_hash, uint64(amount)).name()) + puzzle_hashes.add(bytes32(puzzle_hash)) + + if memo is not None and len(memo) == 32: + puzzle_hashes.add(bytes32(memo)) + + peers: Set[bytes32] = set() + + for coin_id in coin_ids: + peers |= peer_subscriptions.peers_for_coin_id(coin_id) + + for puzzle_hash in puzzle_hashes: + peers |= peer_subscriptions.peers_for_puzzle_hash(puzzle_hash) + + return peers diff --git a/chia/protocols/protocol_message_types.py b/chia/protocols/protocol_message_types.py index 448d2521254f..56c5e729336e 100644 --- a/chia/protocols/protocol_message_types.py +++ b/chia/protocols/protocol_message_types.py @@ -130,4 +130,10 @@ class ProtocolMessageTypes(Enum): respond_coin_state = 102 reject_coin_state = 103 + # Wallet protocol mempool updates + mempool_items_added = 104 + mempool_items_removed = 105 + request_cost_info = 106 + respond_cost_info = 107 + error = 255 diff --git a/chia/protocols/protocol_state_machine.py b/chia/protocols/protocol_state_machine.py index d6c9fc72a0d4..800c8d0761f3 100644 --- a/chia/protocols/protocol_state_machine.py +++ b/chia/protocols/protocol_state_machine.py @@ -13,6 +13,8 @@ pmt.request_mempool_transactions, pmt.new_compact_vdf, pmt.coin_state_update, + pmt.mempool_items_added, + pmt.mempool_items_removed, ] """ @@ -51,6 +53,7 @@ pmt.request_remove_coin_subscriptions: [pmt.respond_remove_coin_subscriptions], pmt.request_puzzle_state: [pmt.respond_puzzle_state, pmt.reject_puzzle_state], pmt.request_coin_state: [pmt.respond_coin_state, pmt.reject_coin_state], + pmt.request_cost_info: [pmt.respond_cost_info], } diff --git a/chia/protocols/shared_protocol.py b/chia/protocols/shared_protocol.py index dac1eef6f2b7..a2ad7419a62c 100644 --- a/chia/protocols/shared_protocol.py +++ b/chia/protocols/shared_protocol.py @@ -14,7 +14,7 @@ NodeType.FARMER: "0.0.36", NodeType.TIMELORD: "0.0.36", NodeType.INTRODUCER: "0.0.36", - NodeType.WALLET: "0.0.37", + NodeType.WALLET: "0.0.38", NodeType.DATA_LAYER: "0.0.36", } @@ -40,6 +40,10 @@ class Capability(IntEnum): # capability removed but functionality is still supported NONE_RESPONSE = 4 + # Opts in to receiving mempool updates for subscribed transactions + # This is between a full node and receiving wallet + MEMPOOL_UPDATES = 5 + # These are the default capabilities used in all outgoing handshakes. # "1" means the capability is supported and enabled. 
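A minimal sketch of how a peer could opt in to these pushes: the full node only sends mempool_items_added / mempool_items_removed to connections whose handshake advertises Capability.MEMPOOL_UPDATES (and wallets do not advertise it by default in this change), so a client wanting the new messages would append it to its handshake capability list. The `wallet_capabilities` name below is illustrative, not part of this patch.

from chia.protocols.shared_protocol import Capability, default_capabilities
from chia.server.outbound_message import NodeType
from chia.util.ints import uint16

# Hypothetical opt-in: extend the default wallet capabilities with the new
# MEMPOOL_UPDATES entry, using the same (uint16, "1") tuple form as above.
wallet_capabilities = default_capabilities[NodeType.WALLET] + [
    (uint16(Capability.MEMPOOL_UPDATES.value), "1"),
]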
@@ -48,9 +52,12 @@ class Capability(IntEnum): (uint16(Capability.BLOCK_HEADERS.value), "1"), (uint16(Capability.RATE_LIMITS_V2.value), "1"), ] +_mempool_updates = [ + (uint16(Capability.MEMPOOL_UPDATES.value), "1"), +] default_capabilities = { - NodeType.FULL_NODE: _capabilities, + NodeType.FULL_NODE: _capabilities + _mempool_updates, NodeType.HARVESTER: _capabilities, NodeType.FARMER: _capabilities, NodeType.TIMELORD: _capabilities, diff --git a/chia/protocols/wallet_protocol.py b/chia/protocols/wallet_protocol.py index dd70b0f2e6d8..22a44ca26f04 100644 --- a/chia/protocols/wallet_protocol.py +++ b/chia/protocols/wallet_protocol.py @@ -363,3 +363,39 @@ class RejectCoinState(Streamable): class RejectStateReason(IntEnum): REORG = 0 EXCEEDED_SUBSCRIPTION_LIMIT = 1 + + +@streamable +@dataclass(frozen=True) +class RemovedMempoolItem(Streamable): + transaction_id: bytes32 + reason: uint8 # MempoolRemoveReason + + +@streamable +@dataclass(frozen=True) +class MempoolItemsAdded(Streamable): + transaction_ids: List[bytes32] + + +@streamable +@dataclass(frozen=True) +class MempoolItemsRemoved(Streamable): + removed_items: List[RemovedMempoolItem] + + +@streamable +@dataclass(frozen=True) +class RequestCostInfo(Streamable): + pass + + +@streamable +@dataclass(frozen=True) +class RespondCostInfo(Streamable): + max_transaction_cost: uint64 + max_block_cost: uint64 + max_mempool_cost: uint64 + mempool_cost: uint64 + mempool_fee: uint64 + bump_fee_per_cost: uint8 diff --git a/chia/server/rate_limit_numbers.py b/chia/server/rate_limit_numbers.py index 0a9416827e3d..1bef1d0abff9 100644 --- a/chia/server/rate_limit_numbers.py +++ b/chia/server/rate_limit_numbers.py @@ -158,6 +158,10 @@ def compose_rate_limits(old_rate_limits: Dict[str, Any], new_rate_limits: Dict[s ProtocolMessageTypes.request_coin_state: RLSettings(1000, 100 * 1024 * 1024), ProtocolMessageTypes.respond_coin_state: RLSettings(1000, 100 * 1024 * 1024), ProtocolMessageTypes.reject_coin_state: RLSettings(200, 100), + ProtocolMessageTypes.mempool_items_added: RLSettings(1000, 100 * 1024 * 1024), + ProtocolMessageTypes.mempool_items_removed: RLSettings(1000, 100 * 1024 * 1024), + ProtocolMessageTypes.request_cost_info: RLSettings(1000, 100), + ProtocolMessageTypes.respond_cost_info: RLSettings(1000, 1024), ProtocolMessageTypes.request_ses_hashes: RLSettings(2000, 1 * 1024 * 1024), ProtocolMessageTypes.respond_ses_hashes: RLSettings(2000, 1 * 1024 * 1024), ProtocolMessageTypes.request_children: RLSettings(2000, 1024 * 1024),
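The new wallet-protocol messages are ordinary streamable types, so a subscriber can decode them directly from the message payload. A minimal, hypothetical wallet-side sketch follows; the handler names and print logging are illustrative only, and the real wallet integration is not part of this patch.

from chia.protocols import wallet_protocol

def on_mempool_items_added(data: bytes) -> None:
    # Transaction ids newly added to (or already present in) the mempool that
    # touch one of this peer's coin or puzzle-hash subscriptions.
    update = wallet_protocol.MempoolItemsAdded.from_bytes(data)
    for transaction_id in update.transaction_ids:
        print("mempool add:", transaction_id.hex())

def on_mempool_items_removed(data: bytes) -> None:
    # Each removal carries a uint8 reason mirroring MempoolRemoveReason
    # (CONFLICT, BLOCK_INCLUSION, POOL_FULL, EXPIRED).
    update = wallet_protocol.MempoolItemsRemoved.from_bytes(data)
    for item in update.removed_items:
        print("mempool remove:", item.transaction_id.hex(), "reason:", item.reason)

def on_respond_cost_info(data: bytes) -> None:
    # Reply to request_cost_info: the node's cost limits plus its current
    # mempool cost and fee totals.
    info = wallet_protocol.RespondCostInfo.from_bytes(data)
    print("mempool cost:", info.mempool_cost, "of", info.max_mempool_cost)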