From f34c2de6bec0600e6658aecd1ff5606f1a97e7e8 Mon Sep 17 00:00:00 2001 From: Gabriel Levcovitz Date: Thu, 8 Jan 2026 14:02:55 -0300 Subject: [PATCH] refactor(traversal): make neighbors traversal explicitly required --- hathor/consensus/block_consensus.py | 12 +++--- hathor/consensus/consensus.py | 6 ++- hathor/consensus/transaction_consensus.py | 4 +- hathor/indexes/mempool_tips_index.py | 5 ++- hathor/p2p/sync_v2/streamers.py | 7 ++-- hathor/transaction/base_transaction.py | 1 + hathor/transaction/block.py | 3 +- hathor/transaction/storage/traversal.py | 49 +++++++++++++++++------ hathor_tests/tx/test_traversal.py | 6 +++ 9 files changed, 66 insertions(+), 27 deletions(-) diff --git a/hathor/consensus/block_consensus.py b/hathor/consensus/block_consensus.py index 568406103..4e55d82ce 100644 --- a/hathor/consensus/block_consensus.py +++ b/hathor/consensus/block_consensus.py @@ -734,12 +734,12 @@ def remove_first_block_markers(self, block: Block) -> None: bfs = BFSTimestampWalk(storage, is_dag_verifications=True, is_dag_funds=True, is_left_to_right=False) for tx in bfs.run(block, skip_root=True): if tx.is_block: - bfs.skip_neighbors(tx) + bfs.skip_neighbors() continue meta = tx.get_metadata() if meta.first_block != block.hash: - bfs.skip_neighbors(tx) + bfs.skip_neighbors() continue if tx.is_nano_contract(): @@ -757,6 +757,7 @@ def remove_first_block_markers(self, block: Block) -> None: meta.voided_by = None meta.first_block = None self.context.save(tx) + bfs.add_neighbors() def _score_block_dfs(self, block: BaseTransaction, used: set[bytes], mark_as_best_chain: bool, newest_timestamp: int) -> int: @@ -785,11 +786,11 @@ def _score_block_dfs(self, block: BaseTransaction, used: set[bytes], for tx in bfs.run(parent, skip_root=False): assert tx.hash is not None if tx.is_block: - bfs.skip_neighbors(tx) + bfs.skip_neighbors() continue if tx.hash in used: - bfs.skip_neighbors(tx) + bfs.skip_neighbors() continue used.add(tx.hash) @@ -797,7 +798,7 @@ def _score_block_dfs(self, block: BaseTransaction, used: set[bytes], if meta.first_block: first_block = storage.get_transaction(meta.first_block) if first_block.timestamp <= newest_timestamp: - bfs.skip_neighbors(tx) + bfs.skip_neighbors() continue if mark_as_best_chain: @@ -806,6 +807,7 @@ def _score_block_dfs(self, block: BaseTransaction, used: set[bytes], self.context.save(tx) score += weight_to_work(tx.weight) + bfs.add_neighbors() # Always save the score when it is calculated. meta = block.get_metadata() diff --git a/hathor/consensus/consensus.py b/hathor/consensus/consensus.py index 5d7077b45..5eb9bf7b6 100644 --- a/hathor/consensus/consensus.py +++ b/hathor/consensus/consensus.py @@ -360,16 +360,17 @@ def _compute_vertices_that_became_invalid( # Run a right-to-left BFS starting from the mempool tips. for tx in find_invalid_bfs.run(mempool_tips, skip_root=False): if not isinstance(tx, Transaction): - find_invalid_bfs.skip_neighbors(tx) + find_invalid_bfs.skip_neighbors() continue if tx.get_metadata().first_block is not None: - find_invalid_bfs.skip_neighbors(tx) + find_invalid_bfs.skip_neighbors() continue # At this point, it's a mempool tx, so we have to re-verify it. if not all(rule(tx) for rule in mempool_rules): invalid_txs.add(tx) + find_invalid_bfs.add_neighbors() # From the invalid txs, mark all vertices to the right as invalid. This includes both txs and blocks. to_remove: list[BaseTransaction] = [] @@ -379,6 +380,7 @@ def _compute_vertices_that_became_invalid( for vertex in find_to_remove_bfs.run(invalid_txs, skip_root=False): vertex.set_validation(ValidationState.INVALID) to_remove.append(vertex) + find_to_remove_bfs.add_neighbors() to_remove.reverse() return to_remove diff --git a/hathor/consensus/transaction_consensus.py b/hathor/consensus/transaction_consensus.py index 6187c256c..855464680 100644 --- a/hathor/consensus/transaction_consensus.py +++ b/hathor/consensus/transaction_consensus.py @@ -391,7 +391,7 @@ def remove_voided_by(self, tx: Transaction, voided_hash: bytes) -> bool: meta2 = tx2.get_metadata() if not (meta2.voided_by and voided_hash in meta2.voided_by): - bfs.skip_neighbors(tx2) + bfs.skip_neighbors() continue if meta2.voided_by: meta2.voided_by.discard(voided_hash) @@ -402,6 +402,7 @@ def remove_voided_by(self, tx: Transaction, voided_hash: bytes) -> bool: tx.storage.add_to_indexes(tx2) self.context.save(tx2) self.assert_valid_consensus(tx2) + bfs.add_neighbors() from hathor.transaction import Transaction for tx2 in check_list: @@ -504,6 +505,7 @@ def add_voided_by(self, tx: Transaction, voided_hash: bytes, *, is_dag_verificat self.context.save(tx2) tx2.storage.del_from_indexes(tx2, relax_assert=True) self.assert_valid_consensus(tx2) + bfs.add_neighbors() for tx2 in check_list: self.check_conflicts(tx2) diff --git a/hathor/indexes/mempool_tips_index.py b/hathor/indexes/mempool_tips_index.py index fb3b8e5b2..1d605e4c2 100644 --- a/hathor/indexes/mempool_tips_index.py +++ b/hathor/indexes/mempool_tips_index.py @@ -202,12 +202,13 @@ def iter_all(self, tx_storage: 'TransactionStorage') -> Iterator[Transaction]: bfs = BFSTimestampWalk(tx_storage, is_dag_verifications=True, is_dag_funds=True, is_left_to_right=False) for tx in bfs.run(self.iter(tx_storage), skip_root=False): if not isinstance(tx, Transaction): - bfs.skip_neighbors(tx) + bfs.skip_neighbors() continue if tx.get_metadata().first_block is not None: - bfs.skip_neighbors(tx) + bfs.skip_neighbors() else: yield tx + bfs.add_neighbors() def get(self) -> set[bytes]: return set(iter(self._index)) diff --git a/hathor/p2p/sync_v2/streamers.py b/hathor/p2p/sync_v2/streamers.py index b7c4f7363..9dfa2a220 100644 --- a/hathor/p2p/sync_v2/streamers.py +++ b/hathor/p2p/sync_v2/streamers.py @@ -282,7 +282,7 @@ def send_next(self) -> None: # Skip blocks. if cur.is_block: - self.bfs.skip_neighbors(cur) + self.bfs.skip_neighbors() return assert isinstance(cur, Transaction) @@ -291,6 +291,7 @@ def send_next(self) -> None: if cur_metadata.first_block is None: self.log.debug('reached a tx that is not confirmed, stopping streaming') self.sync_agent.stop_tx_streaming_server(StreamEnd.TX_NOT_CONFIRMED) + self.bfs.add_neighbors() return # Check if tx is confirmed by the `self.current_block` or any next block. @@ -299,7 +300,7 @@ def send_next(self) -> None: first_block = self.tx_storage.get_block(cur_metadata.first_block) if not_none(first_block.static_metadata.height) < not_none(self.current_block.static_metadata.height): self.log.debug('skipping tx: out of current block') - self.bfs.skip_neighbors(cur) + self.bfs.skip_neighbors() return self.log.debug('send next transaction', tx_id=cur.hash.hex()) @@ -309,4 +310,4 @@ def send_next(self) -> None: if self.counter >= self.limit: self.log.debug('limit exceeded, stopping streaming') self.sync_agent.stop_tx_streaming_server(StreamEnd.LIMIT_EXCEEDED) - return + self.bfs.add_neighbors() diff --git a/hathor/transaction/base_transaction.py b/hathor/transaction/base_transaction.py index 165ffc200..37392b12b 100644 --- a/hathor/transaction/base_transaction.py +++ b/hathor/transaction/base_transaction.py @@ -744,6 +744,7 @@ def update_accumulated_weight( work += weight_to_work(tx.weight) if stop_value is not None and work > stop_value: break + bfs_walk.add_neighbors() metadata.accumulated_weight = work if save_file: diff --git a/hathor/transaction/block.py b/hathor/transaction/block.py index 64a3aabab..99f8d3abd 100644 --- a/hathor/transaction/block.py +++ b/hathor/transaction/block.py @@ -369,10 +369,11 @@ def iter_transactions_in_this_block(self) -> Iterator[Transaction]: for tx in bfs.run(self, skip_root=True): tx_meta = tx.get_metadata() if tx_meta.first_block != self.hash: - bfs.skip_neighbors(tx) + bfs.skip_neighbors() continue assert isinstance(tx, Transaction) yield tx + bfs.add_neighbors() @override def init_static_metadata_from_storage(self, settings: HathorSettings, storage: 'TransactionStorage') -> None: diff --git a/hathor/transaction/storage/traversal.py b/hathor/transaction/storage/traversal.py index b12c81de5..1022ac4d0 100644 --- a/hathor/transaction/storage/traversal.py +++ b/hathor/transaction/storage/traversal.py @@ -17,8 +17,11 @@ import heapq from abc import ABC, abstractmethod from collections import deque +from enum import StrEnum, auto from itertools import chain -from typing import TYPE_CHECKING, Iterable, Iterator, Optional, Union +from typing import TYPE_CHECKING, Iterable, Iterator, Union + +from typing_extensions import assert_never if TYPE_CHECKING: from hathor.transaction import BaseTransaction # noqa: F401 @@ -43,6 +46,11 @@ def __le__(self, other: 'HeapItem') -> bool: return self.key <= other.key +class _WalkOp(StrEnum): + ADD_NEIGHBORS = auto() + SKIP_NEIGHBORS = auto() + + class GenericWalk(ABC): """ A helper class to walk on the DAG. """ @@ -72,7 +80,7 @@ def __init__( self.is_left_to_right = is_left_to_right self._reverse_heap: bool = not self.is_left_to_right - self._ignore_neighbors: Optional['BaseTransaction'] = None + self._walk_op: _WalkOp | None = None @abstractmethod def _push_visit(self, tx: 'BaseTransaction') -> None: @@ -111,7 +119,7 @@ def _get_iterator(self, tx: 'BaseTransaction', *, is_left_to_right: bool) -> Ite return it - def add_neighbors(self, tx: 'BaseTransaction') -> None: + def _add_neighbors(self, tx: 'BaseTransaction') -> None: """ Add neighbors of `tx` to be visited later according to the configuration. """ it = self._get_iterator(tx, is_left_to_right=self.is_left_to_right) @@ -121,11 +129,21 @@ def add_neighbors(self, tx: 'BaseTransaction') -> None: neighbor = self.storage.get_vertex(_hash) self._push_visit(neighbor) - def skip_neighbors(self, tx: 'BaseTransaction') -> None: - """ Mark `tx` to have its neighbors skipped, i.e., they will not be added to be - visited later. `tx` must be equal to the current yielded transaction. + def _set_walk_op(self, op: _WalkOp) -> None: + assert self._walk_op is None, 'walk op is already set' + self._walk_op = op + + def add_neighbors(self) -> None: + """ Mark current item to have its neighbors added, i.e., they will be added to be + visited later. """ - self._ignore_neighbors = tx + self._set_walk_op(_WalkOp.ADD_NEIGHBORS) + + def skip_neighbors(self) -> None: + """ Mark current item to have its neighbors skipped, i.e., they will not be added to be + visited later. + """ + self._set_walk_op(_WalkOp.SKIP_NEIGHBORS) def run(self, root: Union['BaseTransaction', Iterable['BaseTransaction']], *, skip_root: bool = False) -> Iterator['BaseTransaction']: @@ -144,16 +162,21 @@ def run(self, root: Union['BaseTransaction', Iterable['BaseTransaction']], *, if not skip_root: self._push_visit(root) else: - self.add_neighbors(root) + self._add_neighbors(root) while not self._is_empty(): tx = self._pop_visit() yield tx - if not self._ignore_neighbors: - self.add_neighbors(tx) - else: - assert self._ignore_neighbors == tx - self._ignore_neighbors = None + match self._walk_op: + case None: + raise ValueError('you must explicitly add or skip neighbors') + case _WalkOp.ADD_NEIGHBORS: + self._add_neighbors(tx) + self._walk_op = None + case _WalkOp.SKIP_NEIGHBORS: + self._walk_op = None + case _: + assert_never(self._walk_op) class BFSTimestampWalk(GenericWalk): diff --git a/hathor_tests/tx/test_traversal.py b/hathor_tests/tx/test_traversal.py index 6041082f2..d8a538e78 100644 --- a/hathor_tests/tx/test_traversal.py +++ b/hathor_tests/tx/test_traversal.py @@ -100,6 +100,7 @@ def _run_lr(self, walk, skip_root=True): seen.add(tx.hash) self.assertGreaterEqual(tx.timestamp, last_timestamp) last_timestamp = tx.timestamp + walk.add_neighbors() return seen def _run_rl(self, walk): @@ -109,6 +110,7 @@ def _run_rl(self, walk): seen.add(tx.hash) self.assertLessEqual(tx.timestamp, last_timestamp) last_timestamp = tx.timestamp + walk.add_neighbors() return seen @@ -131,6 +133,7 @@ def _run_lr(self, walk, skip_root=True): distance[tx.hash] = dist self.assertGreaterEqual(dist, last_dist) last_dist = dist + walk.add_neighbors() return seen def _run_rl(self, walk): @@ -146,6 +149,7 @@ def _run_rl(self, walk): distance[tx.hash] = dist self.assertGreaterEqual(dist, last_dist) last_dist = dist + walk.add_neighbors() return seen @@ -159,10 +163,12 @@ def _run_lr(self, walk, skip_root=True): seen = set() for tx in walk.run(self.root_tx, skip_root=skip_root): seen.add(tx.hash) + walk.add_neighbors() return seen def _run_rl(self, walk): seen = set() for tx in walk.run(self.root_tx, skip_root=True): seen.add(tx.hash) + walk.add_neighbors() return seen