diff --git a/hathor/indexes/mempool_tips_index.py b/hathor/indexes/mempool_tips_index.py index 210844a70..fb3b8e5b2 100644 --- a/hathor/indexes/mempool_tips_index.py +++ b/hathor/indexes/mempool_tips_index.py @@ -202,6 +202,7 @@ def iter_all(self, tx_storage: 'TransactionStorage') -> Iterator[Transaction]: bfs = BFSTimestampWalk(tx_storage, is_dag_verifications=True, is_dag_funds=True, is_left_to_right=False) for tx in bfs.run(self.iter(tx_storage), skip_root=False): if not isinstance(tx, Transaction): + bfs.skip_neighbors(tx) continue if tx.get_metadata().first_block is not None: bfs.skip_neighbors(tx) diff --git a/hathor/manager.py b/hathor/manager.py index 9b9bdd9d9..e48249c2d 100644 --- a/hathor/manager.py +++ b/hathor/manager.py @@ -579,7 +579,7 @@ def generate_parent_txs(self, timestamp: Optional[float]) -> 'ParentTxs': best_block = self.tx_storage.get_best_block() assert timestamp >= best_block.timestamp - def get_tx_parents(tx: BaseTransaction) -> list[Transaction]: + def get_tx_parents(tx: BaseTransaction, *, with_inputs: bool = False) -> list[Transaction]: if tx.is_genesis: genesis_txs = [self._settings.GENESIS_TX1_HASH, self._settings.GENESIS_TX2_HASH] if tx.is_transaction: @@ -590,34 +590,38 @@ def get_tx_parents(tx: BaseTransaction) -> list[Transaction]: parents = tx.get_tx_parents() assert len(parents) == 2 - return list(parents) - unconfirmed_tips = [tx for tx in self.tx_storage.iter_mempool_tips() if tx.timestamp < timestamp] - unconfirmed_extras = sorted( - (tx for tx in self.tx_storage.iter_mempool() if tx.timestamp < timestamp and tx not in unconfirmed_tips), - key=lambda tx: tx.timestamp, - ) + txs = list(parents) + if with_inputs: + input_tx_ids = set(i.tx_id for i in tx.inputs) + inputs = (self.tx_storage.get_transaction(tx_id) for tx_id in input_tx_ids) + input_txs = (tx for tx in inputs if isinstance(tx, Transaction)) + txs.extend(input_txs) - # mix the blocks tx-parents, with their own tx-parents to avoid carrying one of the genesis tx over - best_block_tx_parents = get_tx_parents(best_block) - tx1_tx_grandparents = get_tx_parents(best_block_tx_parents[0]) - tx2_tx_grandparents = get_tx_parents(best_block_tx_parents[1]) - confirmed_tips = sorted( - set(best_block_tx_parents) | set(tx1_tx_grandparents) | set(tx2_tx_grandparents), - key=lambda tx: tx.timestamp, - ) + return txs + unconfirmed_tips = [tx for tx in self.tx_storage.iter_mempool_tips() if tx.timestamp < timestamp] match unconfirmed_tips: case []: + # mix the blocks tx-parents, with their own tx-parents to avoid carrying one of the genesis tx over + best_block_tx_parents = get_tx_parents(best_block) + tx1_tx_grandparents = get_tx_parents(best_block_tx_parents[0], with_inputs=True) + tx2_tx_grandparents = get_tx_parents(best_block_tx_parents[1], with_inputs=True) + confirmed_tips = sorted( + set(best_block_tx_parents) | set(tx1_tx_grandparents) | set(tx2_tx_grandparents), + key=lambda tx: tx.timestamp, + ) self.log.debug('generate_parent_txs: empty mempool, repeat parents') return ParentTxs.from_txs(can_include=confirmed_tips[-2:], must_include=()) case [tip_tx]: - if unconfirmed_extras: - self.log.debug('generate_parent_txs: one tx tip and at least one other mempool tx') - return ParentTxs.from_txs(can_include=unconfirmed_extras[-1:], must_include=(tip_tx,)) - else: - self.log.debug('generate_parent_txs: one tx in mempool, fill with one repeated parent') - return ParentTxs.from_txs(can_include=confirmed_tips[-1:], must_include=(tip_tx,)) + best_block_tx_parents = get_tx_parents(best_block) + repeated_parents = get_tx_parents(tip_tx, with_inputs=True) + confirmed_tips = sorted( + set(best_block_tx_parents) | set(repeated_parents), + key=lambda tx: tx.timestamp, + ) + self.log.debug('generate_parent_txs: one tx in mempool, fill with one repeated parent') + return ParentTxs.from_txs(can_include=confirmed_tips[-1:], must_include=(tip_tx,)) case _: self.log.debug('generate_parent_txs: multiple unconfirmed mempool tips') return ParentTxs.from_txs(can_include=unconfirmed_tips, must_include=()) diff --git a/hathor_tests/tx/test_mempool_iter_all.py b/hathor_tests/tx/test_mempool_iter_all.py new file mode 100644 index 000000000..85ad15df8 --- /dev/null +++ b/hathor_tests/tx/test_mempool_iter_all.py @@ -0,0 +1,61 @@ +# Copyright 2025 Hathor Labs +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from unittest.mock import patch + +from hathor.simulator.utils import add_new_blocks, gen_new_tx +from hathor_tests import unittest + + +class MempoolIterAllTraversalTestCase(unittest.TestCase): + """Regression helpers for ByteCollectionMempoolTipsIndex.iter_all.""" + + def setUp(self) -> None: + super().setUp() + self.manager = self.create_peer('testnet', unlock_wallet=True) + + def test_iter_mempool_walks_block_chain_via_inputs(self) -> None: + # Mine enough blocks so at least one reward is spendable by the wallet. + num_blocks = self._settings.REWARD_SPEND_MIN_BLOCKS + 2 + add_new_blocks(self.manager, num_blocks, advance_clock=1) + self.run_to_completion() + + address = self.get_address(0) + assert address is not None + tx = gen_new_tx(self.manager, address, value=10) + self.manager.propagate_tx(tx) + self.run_to_completion() + + # Capture which vertices iter_mempool touches while walking dependencies. + with patch.object(self.manager.tx_storage, 'get_vertex', + wraps=self.manager.tx_storage.get_vertex) as get_vertex: + mempool = list(self.manager.tx_storage.iter_mempool()) + + self.assertEqual({tx.hash}, {t.hash for t in mempool}) + + tx_storage = self.manager.tx_storage + expected_blocks = { + txin.tx_id + for txin in tx.inputs + if tx_storage.get_transaction(txin.tx_id).is_block + } + visited_blocks = { + call.args[0] + for call in get_vertex.call_args_list + if tx_storage.get_transaction(call.args[0]).is_block + } + + # iter_mempool should only touch the blocks whose outputs are being spent in the mempool. + self.assertTrue(expected_blocks, 'at least one block reward should be spent') + self.assertEqual(expected_blocks, visited_blocks)