diff --git a/hathor/consensus/block_consensus.py b/hathor/consensus/block_consensus.py index 41056f11d..9ad148de4 100644 --- a/hathor/consensus/block_consensus.py +++ b/hathor/consensus/block_consensus.py @@ -432,8 +432,8 @@ def remove_first_block_markers(self, block: Block) -> None: assert block.storage is not None storage = block.storage - from hathor.transaction.storage.traversal import BFSWalk - bfs = BFSWalk(storage, is_dag_verifications=True, is_left_to_right=False) + from hathor.transaction.storage.traversal import BFSTimestampWalk + bfs = BFSTimestampWalk(storage, is_dag_verifications=True, is_left_to_right=False) for tx in bfs.run(block, skip_root=True): if tx.is_block: bfs.skip_neighbors(tx) @@ -470,8 +470,8 @@ def _score_block_dfs(self, block: BaseTransaction, used: set[bytes], score = sum_weights(score, x) else: - from hathor.transaction.storage.traversal import BFSWalk - bfs = BFSWalk(storage, is_dag_verifications=True, is_left_to_right=False) + from hathor.transaction.storage.traversal import BFSTimestampWalk + bfs = BFSTimestampWalk(storage, is_dag_verifications=True, is_left_to_right=False) for tx in bfs.run(parent, skip_root=False): assert tx.hash is not None assert not tx.is_block diff --git a/hathor/consensus/transaction_consensus.py b/hathor/consensus/transaction_consensus.py index dd0e97808..78c4454d9 100644 --- a/hathor/consensus/transaction_consensus.py +++ b/hathor/consensus/transaction_consensus.py @@ -334,7 +334,7 @@ def remove_voided_by(self, tx: Transaction, voided_hash: bytes) -> bool: """ Remove a hash from `meta.voided_by` and its descendants (both from verification DAG and funds tree). """ - from hathor.transaction.storage.traversal import BFSWalk + from hathor.transaction.storage.traversal import BFSTimestampWalk assert tx.hash is not None assert tx.storage is not None @@ -347,7 +347,7 @@ def remove_voided_by(self, tx: Transaction, voided_hash: bytes) -> bool: self.log.debug('remove_voided_by', tx=tx.hash_hex, voided_hash=voided_hash.hex()) - bfs = BFSWalk(tx.storage, is_dag_funds=True, is_dag_verifications=True, is_left_to_right=True) + bfs = BFSTimestampWalk(tx.storage, is_dag_funds=True, is_dag_verifications=True, is_left_to_right=True) check_list: list[BaseTransaction] = [] for tx2 in bfs.run(tx, skip_root=False): assert tx2.storage is not None @@ -404,8 +404,9 @@ def add_voided_by(self, tx: Transaction, voided_hash: bytes) -> bool: # If tx is soft voided, we can only walk through the DAG of funds. is_dag_verifications = False - from hathor.transaction.storage.traversal import BFSWalk - bfs = BFSWalk(tx.storage, is_dag_funds=True, is_dag_verifications=is_dag_verifications, is_left_to_right=True) + from hathor.transaction.storage.traversal import BFSTimestampWalk + bfs = BFSTimestampWalk(tx.storage, is_dag_funds=True, is_dag_verifications=is_dag_verifications, + is_left_to_right=True) check_list: list[Transaction] = [] for tx2 in bfs.run(tx, skip_root=False): assert tx2.storage is not None diff --git a/hathor/indexes/mempool_tips_index.py b/hathor/indexes/mempool_tips_index.py index 784327f69..08c79dd46 100644 --- a/hathor/indexes/mempool_tips_index.py +++ b/hathor/indexes/mempool_tips_index.py @@ -185,8 +185,8 @@ def iter(self, tx_storage: 'TransactionStorage', max_timestamp: Optional[float] yield from cast(Iterator[Transaction], it) def iter_all(self, tx_storage: 'TransactionStorage') -> Iterator[Transaction]: - from hathor.transaction.storage.traversal import BFSWalk - bfs = BFSWalk(tx_storage, is_dag_verifications=True, is_left_to_right=False) + from hathor.transaction.storage.traversal import BFSTimestampWalk + bfs = BFSTimestampWalk(tx_storage, is_dag_verifications=True, is_left_to_right=False) for tx in bfs.run(self.iter(tx_storage), skip_root=False): assert isinstance(tx, Transaction) if tx.get_metadata().first_block is not None: diff --git a/hathor/transaction/base_transaction.py b/hathor/transaction/base_transaction.py index 75bb3de43..f87f1193e 100644 --- a/hathor/transaction/base_transaction.py +++ b/hathor/transaction/base_transaction.py @@ -951,8 +951,8 @@ def update_accumulated_weight(self, *, stop_value: float = inf, save_file: bool # reduce the number of visits in the BFS. We need to specially handle when a transaction is not # directly verified by a block. - from hathor.transaction.storage.traversal import BFSWalk - bfs_walk = BFSWalk(self.storage, is_dag_funds=True, is_dag_verifications=True, is_left_to_right=True) + from hathor.transaction.storage.traversal import BFSTimestampWalk + bfs_walk = BFSTimestampWalk(self.storage, is_dag_funds=True, is_dag_verifications=True, is_left_to_right=True) for tx in bfs_walk.run(self, skip_root=True): accumulated_weight = sum_weights(accumulated_weight, tx.weight) if accumulated_weight > stop_value: diff --git a/hathor/transaction/storage/transaction_storage.py b/hathor/transaction/storage/transaction_storage.py index 126eccdf3..485483732 100644 --- a/hathor/transaction/storage/transaction_storage.py +++ b/hathor/transaction/storage/transaction_storage.py @@ -1026,10 +1026,10 @@ def iter_mempool_from_tx_tips(self) -> Iterator[Transaction]: This method requires indexes to be enabled. """ - from hathor.transaction.storage.traversal import BFSWalk + from hathor.transaction.storage.traversal import BFSTimestampWalk root = self.iter_mempool_tips_from_tx_tips() - walk = BFSWalk(self, is_dag_funds=True, is_dag_verifications=True, is_left_to_right=False) + walk = BFSTimestampWalk(self, is_dag_funds=True, is_dag_verifications=True, is_left_to_right=False) for tx in walk.run(root): tx_meta = tx.get_metadata() # XXX: skip blocks and tx-tips that have already been confirmed diff --git a/hathor/transaction/storage/traversal.py b/hathor/transaction/storage/traversal.py index aef55069e..fc6bbc110 100644 --- a/hathor/transaction/storage/traversal.py +++ b/hathor/transaction/storage/traversal.py @@ -15,12 +15,14 @@ import heapq from abc import ABC, abstractmethod +from collections import deque from itertools import chain -from typing import TYPE_CHECKING, Any, Iterable, Iterator, Optional, Union +from typing import TYPE_CHECKING, Iterable, Iterator, Optional, Union if TYPE_CHECKING: from hathor.transaction import BaseTransaction # noqa: F401 from hathor.transaction.storage import TransactionStorage # noqa: F401 + from hathor.types import VertexId class HeapItem: @@ -43,8 +45,7 @@ def __le__(self, other: 'HeapItem') -> bool: class GenericWalk(ABC): """ A helper class to walk on the DAG. """ - seen: set[bytes] - to_visit: list[Any] + seen: set['VertexId'] def __init__(self, storage: 'TransactionStorage', *, is_dag_funds: bool = False, is_dag_verifications: bool = False, is_left_to_right: bool = True): @@ -58,7 +59,6 @@ def __init__(self, storage: 'TransactionStorage', *, is_dag_funds: bool = False, """ self.storage = storage self.seen = set() - self.to_visit = [] self.is_dag_funds = is_dag_funds self.is_dag_verifications = is_dag_verifications @@ -79,26 +79,36 @@ def _pop_visit(self) -> 'BaseTransaction': """ raise NotImplementedError - def add_neighbors(self, tx: 'BaseTransaction') -> None: - """ Add neighbors of `tx` to be visited later according to the configuration. + @abstractmethod + def _is_empty(self) -> bool: + """ Return true if there aren't any txs left to be visited. """ + raise NotImplementedError + + def _get_iterator(self, tx: 'BaseTransaction', *, is_left_to_right: bool) -> Iterator['VertexId']: meta = None - it: Iterator[bytes] = chain() + it: Iterator['VertexId'] = chain() if self.is_dag_verifications: - if self.is_left_to_right: + if is_left_to_right: meta = meta or tx.get_metadata() it = chain(it, meta.children) else: it = chain(it, tx.parents) if self.is_dag_funds: - if self.is_left_to_right: + if is_left_to_right: meta = meta or tx.get_metadata() it = chain(it, *meta.spent_outputs.values()) else: it = chain(it, [txin.tx_id for txin in tx.inputs]) + return it + + def add_neighbors(self, tx: 'BaseTransaction') -> None: + """ Add neighbors of `tx` to be visited later according to the configuration. + """ + it = self._get_iterator(tx, is_left_to_right=self.is_left_to_right) for _hash in it: if _hash not in self.seen: self.seen.add(_hash) @@ -131,7 +141,7 @@ def run(self, root: Union['BaseTransaction', Iterable['BaseTransaction']], *, else: self.add_neighbors(root) - while self.to_visit: + while not self._is_empty(): tx = self._pop_visit() assert tx.hash is not None yield tx @@ -142,16 +152,23 @@ def run(self, root: Union['BaseTransaction', Iterable['BaseTransaction']], *, self._ignore_neighbors = None -class BFSWalk(GenericWalk): - """ A help to walk in the DAG using a BFS. +class BFSTimestampWalk(GenericWalk): + """ A help to walk in the DAG using a BFS that prioritizes by timestamp. """ - to_visit: list[HeapItem] + _to_visit: list[HeapItem] + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._to_visit = [] + + def _is_empty(self) -> bool: + return not self._to_visit def _push_visit(self, tx: 'BaseTransaction') -> None: - heapq.heappush(self.to_visit, HeapItem(tx, reverse=self._reverse_heap)) + heapq.heappush(self._to_visit, HeapItem(tx, reverse=self._reverse_heap)) def _pop_visit(self) -> 'BaseTransaction': - item = heapq.heappop(self.to_visit) + item = heapq.heappop(self._to_visit) tx = item.tx # We can safely remove it because we are walking in topological order # and it won't appear again in the future because this would be a cycle. @@ -160,13 +177,39 @@ def _pop_visit(self) -> 'BaseTransaction': return tx +class BFSOrderWalk(GenericWalk): + """ A help to walk in the DAG using a BFS. + """ + _to_visit: deque['BaseTransaction'] + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._to_visit = deque() + + def _is_empty(self) -> bool: + return not self._to_visit + + def _push_visit(self, tx: 'BaseTransaction') -> None: + self._to_visit.append(tx) + + def _pop_visit(self) -> 'BaseTransaction': + return self._to_visit.popleft() + + class DFSWalk(GenericWalk): """ A help to walk in the DAG using a DFS. """ - to_visit: list['BaseTransaction'] + _to_visit: list['BaseTransaction'] + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._to_visit = [] + + def _is_empty(self) -> bool: + return not self._to_visit def _push_visit(self, tx: 'BaseTransaction') -> None: - self.to_visit.append(tx) + self._to_visit.append(tx) def _pop_visit(self) -> 'BaseTransaction': - return self.to_visit.pop() + return self._to_visit.pop() diff --git a/tests/tx/test_traversal.py b/tests/tx/test_traversal.py index a8a99e930..a4ca58732 100644 --- a/tests/tx/test_traversal.py +++ b/tests/tx/test_traversal.py @@ -1,6 +1,6 @@ from math import inf -from hathor.transaction.storage.traversal import BFSWalk, DFSWalk +from hathor.transaction.storage.traversal import BFSOrderWalk, BFSTimestampWalk, DFSWalk from tests import unittest from tests.utils import add_blocks_unlock_reward, add_new_blocks, add_new_transactions, add_new_tx @@ -86,9 +86,9 @@ def test_right_to_left(self): self.assertTrue(seen_v.union(seen_f).issubset(seen_vf)) -class BaseBFSWalkTestCase(_TraversalTestCase): +class BaseBFSTimestampWalkTestCase(_TraversalTestCase): def gen_walk(self, **kwargs): - return BFSWalk(self.manager.tx_storage, **kwargs) + return BFSTimestampWalk(self.manager.tx_storage, **kwargs) def _run_lr(self, walk, skip_root=True): seen = set() @@ -109,16 +109,59 @@ def _run_rl(self, walk): return seen -class SyncV1BFSWalkTestCase(unittest.SyncV1Params, BaseBFSWalkTestCase): +class SyncV1BFSTimestampWalkTestCase(unittest.SyncV1Params, BaseBFSTimestampWalkTestCase): __test__ = True -class SyncV2BFSWalkTestCase(unittest.SyncV2Params, BaseBFSWalkTestCase): +class SyncV2BFSTimestampWalkTestCase(unittest.SyncV2Params, BaseBFSTimestampWalkTestCase): + __test__ = True + + +class BaseBFSOrderWalkTestCase(_TraversalTestCase): + def gen_walk(self, **kwargs): + return BFSOrderWalk(self.manager.tx_storage, **kwargs) + + def _run_lr(self, walk, skip_root=True): + seen = set() + distance = {} + distance[self.root_tx.hash] = 0 + last_dist = 0 + for tx in walk.run(self.root_tx, skip_root=skip_root): + seen.add(tx.hash) + it = walk._get_iterator(tx, is_left_to_right=False) + dist = 1 + min(distance.get(_hash, inf) for _hash in it) + self.assertIsInstance(dist, int) + distance[tx.hash] = dist + self.assertGreaterEqual(dist, last_dist) + last_dist = dist + return seen + + def _run_rl(self, walk): + seen = set() + distance = {} + distance[self.root_tx.hash] = 0 + last_dist = 0 + for tx in walk.run(self.root_tx, skip_root=True): + seen.add(tx.hash) + it = walk._get_iterator(tx, is_left_to_right=True) + dist = 1 + min(distance.get(_hash, inf) for _hash in it) + self.assertIsInstance(dist, int) + distance[tx.hash] = dist + self.assertGreaterEqual(dist, last_dist) + last_dist = dist + return seen + + +class SyncV1BFSOrderWalkTestCase(unittest.SyncV1Params, BaseBFSOrderWalkTestCase): + __test__ = True + + +class SyncV2BFSOrderWalkTestCase(unittest.SyncV2Params, BaseBFSOrderWalkTestCase): __test__ = True # sync-bridge should behave like sync-v2 -class SyncBridgeBFSWalkTestCase(unittest.SyncBridgeParams, SyncV2BFSWalkTestCase): +class SyncBridgeBFSOrderWalkTestCase(unittest.SyncBridgeParams, SyncV2BFSOrderWalkTestCase): pass