Skip to content

Commit

Permalink
Get first test passing
Browse files Browse the repository at this point in the history
  • Loading branch information
Rigidity committed Jun 2, 2024
1 parent 70b5215 commit 17421aa
Show file tree
Hide file tree
Showing 5 changed files with 98 additions and 65 deletions.
53 changes: 18 additions & 35 deletions chia/_tests/wallet/test_new_wallet_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,13 @@
from random import Random
from typing import AsyncGenerator, Dict, List, Optional, OrderedDict, Set, Tuple

from anyio import sleep
import pytest
from chia_rs import AugSchemeMPL, Coin, CoinSpend, CoinState, FullBlock, G2Element, Program
from chia_rs import AugSchemeMPL, Coin, CoinSpend, CoinState, FullBlock, Program

from chia._tests.connection_utils import add_dummy_connection
from chia.full_node.coin_store import CoinStore
from chia.protocols import wallet_protocol
from chia.protocols.protocol_message_types import ProtocolMessageTypes
from chia.protocols.shared_protocol import Capability
from chia.server.outbound_message import Message, NodeType
from chia.server.ws_connection import WSChiaConnection
Expand All @@ -22,12 +22,10 @@
from chia.types.aliases import WalletService
from chia.types.blockchain_format.sized_bytes import bytes32
from chia.types.coin_record import CoinRecord
from chia.types.peer_info import PeerInfo
from chia.types.spend_bundle import SpendBundle
from chia.util.hash import std_hash
from chia.util.ints import uint8, uint16, uint32, uint64


IDENTITY_PUZZLE = Program.to(1)
IDENTITY_PUZZLE_HASH = IDENTITY_PUZZLE.get_tree_hash()

Expand Down Expand Up @@ -783,47 +781,32 @@ async def test_subscribed_mempool_items(
simulator, queue, peer = await connect_to_simulator(one_node, self_hostname, mempool_updates=True)
subs = simulator.full_node.subscriptions
coin_store = simulator.full_node.coin_store
genesis_challenge = simulator.full_node.constants.GENESIS_CHALLENGE

print("ZZZ")

await simulator.full_node.add_block(default_400_blocks[0])

print("YYY")

await simulator.full_node.add_block_batch(default_400_blocks[1:], PeerInfo("0.0.0.0", 0), None)

print("XXX")

ph1 = IDENTITY_PUZZLE_HASH
coin1 = Coin(bytes32(b"\0" * 32), ph1, uint64(1000))

# Add coin and subscription
await coin_store._add_coin_records([CoinRecord(coin1, uint32(1), uint32(0), False, uint64(10000))])

print("AAA")

print("BBB")

# Request coin state
resp = await simulator.request_coin_state(
wallet_protocol.RequestCoinState([coin1.name()], None, genesis_challenge, True), peer
)
assert resp is not None

print("CCC")

response = wallet_protocol.RespondCoinState.from_bytes(resp.data)
assert response.coin_ids == [coin1.name()]
subs.add_coin_subscriptions(peer.peer_node_id, [coin1.name()], 1)

# Add mempool item
solution = Program.to([])
bundle = SpendBundle([CoinSpend(coin1, IDENTITY_PUZZLE, solution)], AugSchemeMPL.aggregate([]))
result = await simulator.full_node.add_transaction(bundle, bundle.name())
print("X", result)
await simulator.full_node.add_transaction(bundle, bundle.name())

msg1 = await queue.get()
msg2 = await queue.get()

found: Optional[Message] = None

for msg in [msg1, msg2]:
if msg.type == ProtocolMessageTypes.mempool_item_added.value:
found = msg
break

print("Y", simulator.auto_farm)
print("Z", simulator.full_node.mempool_manager.get_spendbundle(bundle.name()))
await sleep(5.0)
assert found is not None

print(queue)
assert False
update = wallet_protocol.MempoolItemAdded.from_bytes(found.data)
assert update.transaction_id == bundle.name()
73 changes: 54 additions & 19 deletions chia/full_node/full_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -1602,7 +1602,7 @@ async def peak_post_processing_2(

if self.sync_store.get_sync_mode() is False:
await self.send_peak_to_timelords(block)
await self.broadcast_removed_tx(ppp_result.mempool_removals, peer)
await self.broadcast_removed_tx(ppp_result.mempool_removals)

# Tell full nodes about the new peak
msg = make_msg(
Expand Down Expand Up @@ -2373,7 +2373,7 @@ async def add_transaction(
# vector.
mempool_item = self.mempool_manager.get_mempool_item(spend_name)
assert mempool_item is not None
await self.broadcast_removed_tx(info.removals, current_peer=peer)
await self.broadcast_removed_tx(info.removals)
await self.broadcast_added_tx(mempool_item, current_peer=peer)

if self.simulator_transaction_callback is not None: # callback
Expand All @@ -2389,6 +2389,7 @@ async def broadcast_added_tx(
) -> None:
assert mempool_item.fee >= 0
assert mempool_item.cost is not None

new_tx = full_node_protocol.NewTransaction(
mempool_item.name,
mempool_item.cost,
Expand All @@ -2405,45 +2406,72 @@ async def broadcast_added_tx(
if conds is None:
return

hints_for_removals = await self.hint_store.get_hints([bytes32(spend.coin_id) for spend in conds.spends])
peers = peers_for_spend_bundle(self.subscriptions, conds, set(hints_for_removals))
start_time = time.monotonic()

for peer_id in peers:
peer = self.server.all_connections.get(peer_id)
hints_for_removals = await self.hint_store.get_hints([bytes32(spend.coin_id) for spend in conds.spends])
peer_ids = peers_for_spend_bundle(self.subscriptions, conds, set(hints_for_removals))

if peer is None or not peer.has_capability(Capability.MEMPOOL_UPDATES):
continue
peers = [
peer
for peer_id in peer_ids
if (peer := self.server.all_connections.get(peer_id)) is not None
and peer.has_capability(Capability.MEMPOOL_UPDATES)
]

for peer in peers:
msg = make_msg(ProtocolMessageTypes.mempool_item_added, wallet_protocol.MempoolItemAdded(mempool_item.name))
await peer.send_message(msg)

async def broadcast_removed_tx(
self, mempool_removals: List[MempoolRemoveInfo], current_peer: Optional[WSChiaConnection]
) -> None:
total_time = time.monotonic() - start_time

self.log.log(
logging.DEBUG if total_time < 0.5 else logging.WARNING,
f"Broadcasting added transaction {mempool_item.name} to {len(peers)} peers took {total_time:.4f}s",
)

async def broadcast_removed_tx(self, mempool_removals: List[MempoolRemoveInfo]) -> None:
total_removals = sum([len(r.items) for r in mempool_removals])
if total_removals == 0:
return

start_time = time.monotonic()

self.log.debug(f"Broadcasting {total_removals} removed transactions to peers")

removals_to_send: Dict[bytes32, List[bytes32]] = dict()

for removal_info in mempool_removals:
for internal_mempool_item in removal_info.items:
conds = internal_mempool_item.npc_result.conds
if conds is None:
return
continue

hints_for_removals = await self.hint_store.get_hints([bytes32(spend.coin_id) for spend in conds.spends])
peers = peers_for_spend_bundle(self.subscriptions, conds, set(hints_for_removals))
peer_ids = peers_for_spend_bundle(self.subscriptions, conds, set(hints_for_removals))

transaction_id: Optional[bytes32] = None
peers = [
peer
for peer_id in peer_ids
if (peer := self.server.all_connections.get(peer_id)) is not None
and peer.has_capability(Capability.MEMPOOL_UPDATES)
]

if len(peers) > 0:
transaction_id = internal_mempool_item.spend_bundle.name()
if len(peers) == 0:
continue

for peer_id in peers:
peer = self.server.all_connections.get(peer_id)
transaction_id = internal_mempool_item.spend_bundle.name()

self.log.debug(
f"Broadcasting removed transaction {transaction_id} to "
f"wallet peers {[peer.peer_node_id for peer in peers]}"
)

for peer in peers:
if peer is None or not peer.has_capability(Capability.MEMPOOL_UPDATES):
continue

assert transaction_id is not None
removals_to_send.get(peer_id, []).append(transaction_id)
removals_to_send.setdefault(peer.peer_node_id, []).append(transaction_id)

for peer_id, removals in removals_to_send.items():
peer = self.server.all_connections.get(peer_id)
Expand All @@ -2454,6 +2482,13 @@ async def broadcast_removed_tx(
msg = make_msg(ProtocolMessageTypes.mempool_items_removed, wallet_protocol.MempoolItemsRemoved(removals))
await peer.send_message(msg)

total_time = time.monotonic() - start_time

self.log.log(
logging.DEBUG if total_time < 0.5 else logging.WARNING,
f"Broadcasting {total_removals} removed transactions to {len(peers)} peers took {total_time:.4f}s",
)

async def _needs_compact_proof(
self, vdf_info: VDFInfo, header_block: HeaderBlock, field_vdf: CompressibleVDFField
) -> bool:
Expand Down
31 changes: 23 additions & 8 deletions chia/full_node/full_node_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -1900,6 +1900,8 @@ async def mempool_updates_for_puzzle_hashes(
if Capability.MEMPOOL_UPDATES not in peer.peer_capabilities:
return

start_time = time.monotonic()

async with self.full_node.db_wrapper.reader() as conn:
transaction_ids = set(
self.full_node.mempool_manager.mempool.transaction_ids_matching_puzzle_hashes(
Expand All @@ -1924,22 +1926,35 @@ async def mempool_updates_for_puzzle_hashes(
self.full_node.mempool_manager.mempool.transaction_ids_matching_coin_ids(hinted_coin_ids)
)

if len(transaction_ids) == 0:
return
if len(transaction_ids) > 0:
message = wallet_protocol.SubscribedMempoolItems(list(transaction_ids))
await peer.send_message(make_msg(ProtocolMessageTypes.subscribed_mempool_items, message))

message = wallet_protocol.SubscribedMempoolItems(list(transaction_ids))
await peer.send_message(make_msg(ProtocolMessageTypes.subscribed_mempool_items, message))
total_time = time.monotonic() - start_time

self.log.log(
logging.DEBUG if total_time < 2.0 else logging.WARNING,
f"Sending initial mempool items to peer {peer.peer_node_id} took {total_time:.4f}s",
)

async def mempool_updates_for_coin_ids(self, peer: WSChiaConnection, coin_ids: Set[bytes32]) -> None:
if Capability.MEMPOOL_UPDATES not in peer.peer_capabilities:
return

start_time = time.monotonic()

transaction_ids = self.full_node.mempool_manager.mempool.transaction_ids_matching_coin_ids(coin_ids)
if len(transaction_ids) == 0:
return

message = wallet_protocol.SubscribedMempoolItems(list(transaction_ids))
await peer.send_message(make_msg(ProtocolMessageTypes.subscribed_mempool_items, message))
if len(transaction_ids) > 0:
message = wallet_protocol.SubscribedMempoolItems(list(transaction_ids))
await peer.send_message(make_msg(ProtocolMessageTypes.subscribed_mempool_items, message))

total_time = time.monotonic() - start_time

self.log.log(
logging.DEBUG if total_time < 2.0 else logging.WARNING,
f"Sending initial mempool items to peer {peer.peer_node_id} took {total_time:.4f}s",
)

def max_subscriptions(self, peer: WSChiaConnection) -> int:
if self.is_trusted(peer):
Expand Down
2 changes: 1 addition & 1 deletion chia/full_node/mempool.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,7 @@ def transaction_ids_matching_puzzle_hashes(self, puzzle_hashes: Set[bytes32], in
found_addition = False

for puzzle_hash, _amount, memo in spend.create_coin:
if puzzle_hash in puzzle_hash or (include_hints and memo is not None and memo in puzzle_hashes):
if puzzle_hash in puzzle_hashes or (include_hints and memo is not None and memo in puzzle_hashes):
transaction_ids.append(transaction_id)
found_addition = True
break
Expand Down
4 changes: 2 additions & 2 deletions chia/full_node/subscriptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,10 +226,10 @@ def peers_for_spend_bundle(

for puzzle_hash, amount, memo in spend.create_coin:
coin_ids.add(Coin(spend.coin_id, puzzle_hash, uint64(amount)).name())
coin_ids.add(bytes32(puzzle_hash))
puzzle_hashes.add(bytes32(puzzle_hash))

if memo is not None and len(memo) == 32:
coin_ids.add(bytes32(memo))
puzzle_hashes.add(bytes32(memo))

peers: Set[bytes32] = set()

Expand Down

0 comments on commit 17421aa

Please sign in to comment.