Skip to content

Commit

Permalink
Fix equality check on nodes
Browse files Browse the repository at this point in the history
  • Loading branch information
aandres committed Aug 21, 2023
1 parent 34cadb6 commit fa1a09f
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 6 deletions.
12 changes: 6 additions & 6 deletions beavers/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,7 +181,7 @@ def input_nodes(self) -> typing.Tuple[Node]:
return self.nodes


NO_INPUTS = _NodeInputs.create([], {})
_NO_INPUTS = _NodeInputs.create([], {})


@dataclasses.dataclass
Expand All @@ -193,7 +193,7 @@ class _RuntimeNodeData(typing.Generic[T]):
cycle_id: int


@dataclasses.dataclass(frozen=True)
@dataclasses.dataclass(frozen=True, eq=False)
class Node(typing.Generic[T]):
"""
Represent an element in a `Dag`.
Expand All @@ -219,7 +219,7 @@ class Node(typing.Generic[T]):
def _create(
value: T = None,
function: typing.Optional[typing.Callable[[...], T]] = None,
inputs: _NodeInputs = NO_INPUTS,
inputs: _NodeInputs = _NO_INPUTS,
empty: typing.Any = _STATE_EMPTY,
notifications: int = 1,
) -> Node:
Expand Down Expand Up @@ -354,7 +354,7 @@ def const(self, value: T) -> Node[T]:
return self._add_node(
Node._create(
function=_unchanged_callback,
inputs=NO_INPUTS,
inputs=_NO_INPUTS,
value=value,
notifications=0,
)
Expand Down Expand Up @@ -386,7 +386,7 @@ def source_stream(
node = self._add_stream(
function=_SourceStreamFunction(empty, name),
empty=empty,
inputs=NO_INPUTS,
inputs=_NO_INPUTS,
)
if name:
self._sources[name] = node
Expand Down Expand Up @@ -476,7 +476,7 @@ def timer_manager(self) -> Node[TimerManager]:
Node._create(
value=function.timer_manager,
function=function,
inputs=NO_INPUTS,
inputs=_NO_INPUTS,
notifications=1,
)
)
Expand Down
17 changes: 17 additions & 0 deletions tests/test_engine.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import asyncio
import operator
import time

import pandas as pd
Expand Down Expand Up @@ -574,5 +575,21 @@ def test_recalculate_clean_node():
node._recalculate(2)


def test_can_add_node_copy():
dag = Dag()
source = dag.source_stream([])
node_one = dag.stream(operator.__add__).map(source, source)
node_two = dag.stream(operator.__add__).map(source, source)
assert node_one is not node_two
assert node_one != node_two


def test_can_not_add_node_back():
dag = Dag()
source = dag.source_stream([])
with pytest.raises(ValueError, match="Node already in dag"):
dag._add_node(source)


def test_unchanged_callback():
assert _unchanged_callback() is _STATE_UNCHANGED

0 comments on commit fa1a09f

Please sign in to comment.