diff --git a/beavers/engine.py b/beavers/engine.py index 5aeb2e2..527e096 100644 --- a/beavers/engine.py +++ b/beavers/engine.py @@ -181,7 +181,7 @@ def input_nodes(self) -> typing.Tuple[Node]: return self.nodes -NO_INPUTS = _NodeInputs.create([], {}) +_NO_INPUTS = _NodeInputs.create([], {}) @dataclasses.dataclass @@ -193,7 +193,7 @@ class _RuntimeNodeData(typing.Generic[T]): cycle_id: int -@dataclasses.dataclass(frozen=True) +@dataclasses.dataclass(frozen=True, eq=False) class Node(typing.Generic[T]): """ Represent an element in a `Dag`. @@ -219,7 +219,7 @@ class Node(typing.Generic[T]): def _create( value: T = None, function: typing.Optional[typing.Callable[[...], T]] = None, - inputs: _NodeInputs = NO_INPUTS, + inputs: _NodeInputs = _NO_INPUTS, empty: typing.Any = _STATE_EMPTY, notifications: int = 1, ) -> Node: @@ -354,7 +354,7 @@ def const(self, value: T) -> Node[T]: return self._add_node( Node._create( function=_unchanged_callback, - inputs=NO_INPUTS, + inputs=_NO_INPUTS, value=value, notifications=0, ) @@ -386,7 +386,7 @@ def source_stream( node = self._add_stream( function=_SourceStreamFunction(empty, name), empty=empty, - inputs=NO_INPUTS, + inputs=_NO_INPUTS, ) if name: self._sources[name] = node @@ -476,7 +476,7 @@ def timer_manager(self) -> Node[TimerManager]: Node._create( value=function.timer_manager, function=function, - inputs=NO_INPUTS, + inputs=_NO_INPUTS, notifications=1, ) ) diff --git a/tests/test_engine.py b/tests/test_engine.py index 678ee31..1ac4bce 100644 --- a/tests/test_engine.py +++ b/tests/test_engine.py @@ -1,4 +1,5 @@ import asyncio +import operator import time import pandas as pd @@ -574,5 +575,21 @@ def test_recalculate_clean_node(): node._recalculate(2) +def test_can_add_node_copy(): + dag = Dag() + source = dag.source_stream([]) + node_one = dag.stream(operator.__add__).map(source, source) + node_two = dag.stream(operator.__add__).map(source, source) + assert node_one is not node_two + assert node_one != node_two + + +def test_can_not_add_node_back(): + dag = Dag() + source = dag.source_stream([]) + with pytest.raises(ValueError, match="Node already in dag"): + dag._add_node(source) + + def test_unchanged_callback(): assert _unchanged_callback() is _STATE_UNCHANGED