diff --git a/hathor/cli/quick_test.py b/hathor/cli/quick_test.py index 2bf6f16fe..0701ab7a9 100644 --- a/hathor/cli/quick_test.py +++ b/hathor/cli/quick_test.py @@ -12,14 +12,20 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import os from argparse import ArgumentParser -from typing import Any +from typing import TYPE_CHECKING, Any, Generator from structlog import get_logger +from twisted.internet.defer import inlineCallbacks from hathor.cli.run_node import RunNode +if TYPE_CHECKING: + from hathor.transaction import BaseTransaction, Block, Transaction + logger = get_logger() @@ -28,35 +34,24 @@ def __init__(self, vertex_handler, manager, n_blocks): self.log = logger.new() self._vertex_handler = vertex_handler self._manager = manager - self._n_blocks = n_blocks - - def on_new_vertex(self, *args: Any, **kwargs: Any) -> bool: - from hathor.transaction import Block - from hathor.transaction.base_transaction import GenericVertex - - msg: str | None = None - res = self._vertex_handler.on_new_vertex(*args, **kwargs) - - if self._n_blocks is None: - should_quit = res - msg = 'added a tx' - else: - vertex = args[0] - should_quit = False - assert isinstance(vertex, GenericVertex) - - if isinstance(vertex, Block): - should_quit = vertex.get_height() >= self._n_blocks - msg = f'reached height {vertex.get_height()}' - - if should_quit: - assert msg is not None - self.log.info(f'successfully {msg}, exit now') + self._n_blocks = n_blocks or 0 + + @inlineCallbacks + def on_new_block(self, block: Block, *args: Any, **kwargs: Any) -> Generator[Any, Any, bool]: + res = yield self._vertex_handler.on_new_block(block, *args, **kwargs) + if block.get_height() >= self._n_blocks: + self.log.info(f'successfully reached height {block.get_height()}, exit now') self._manager.connections.disconnect_all_peers(force=True) self._manager.reactor.fireSystemEvent('shutdown') os._exit(0) return res + def on_new_mempool_transaction(self, tx: Transaction) -> bool: + return self._vertex_handler.on_new_mempool_transaction(tx) + + def on_new_relayed_vertex(self, vertex: BaseTransaction, *args: Any, **kwargs: Any) -> bool: + return self._vertex_handler.on_new_mempool_transaction(vertex, *args, **kwargs) + class QuickTest(RunNode): @classmethod