diff --git a/hathor/dag_builder/artifacts.py b/hathor/dag_builder/artifacts.py index c183f4358..afa930473 100644 --- a/hathor/dag_builder/artifacts.py +++ b/hathor/dag_builder/artifacts.py @@ -14,7 +14,7 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Iterator, NamedTuple, Sequence, TypeVar +from typing import TYPE_CHECKING, Callable, Iterator, NamedTuple, Sequence, TypeVar from hathor.dag_builder.types import DAGNode from hathor.manager import HathorManager @@ -42,6 +42,11 @@ def __init__(self, items: Iterator[tuple[DAGNode, BaseTransaction]]) -> None: self.list: tuple[_Pair, ...] = tuple(v) self._last_propagated: str | None = None + self._step_fns: list[Callable[[DAGNode, BaseTransaction], None]] = [] + + def register_step_fn(self, step_fn: Callable[[DAGNode, BaseTransaction], None]) -> None: + """Register a new step function to be called between vertex propagations.""" + self._step_fns.append(step_fn) def get_typed_vertex(self, name: str, type_: type[T]) -> T: """Get a vertex by name, asserting it is of the provided type.""" @@ -83,6 +88,8 @@ def propagate_with( assert manager.vertex_handler.on_new_relayed_vertex(vertex) except Exception as e: raise Exception(f'failed on_new_tx({node.name})') from e + for step_fn in self._step_fns: + step_fn(node, vertex) self._last_propagated = node.name if node.name == self._last_propagated: