diff --git a/onnxscript/ir/_core.py b/onnxscript/ir/_core.py index faffde7483..14d07cb9f4 100644 --- a/onnxscript/ir/_core.py +++ b/onnxscript/ir/_core.py @@ -30,6 +30,7 @@ Hashable, Iterable, Iterator, + NamedTuple, OrderedDict, Sequence, SupportsInt, @@ -1055,6 +1056,18 @@ def _quoted(string: str) -> str: return f'"{string}"' +class Usage(NamedTuple): + """A usage of a value in a node. + + Attributes: + node: The node that uses the value. + idx: The input index of the value in the node. + """ + + node: Node + idx: int + + class Node(_protocols.NodeProtocol, _display.PrettyPrintable): """IR Node. @@ -1293,6 +1306,25 @@ def inputs(self, _: Any) -> None: "Directly mutating the input sequence is unsupported. Please use Node.replace_input_with() instead." ) + def predecessors(self) -> Sequence[Node]: + """Return the predecessor nodes of the node, deduplicated, in a deterministic order.""" + # Use the ordered nature of a dictionary to deduplicate the nodes + predecessors: dict[Node, None] = {} + for value in self.inputs: + if value is not None and (producer := value.producer()) is not None: + predecessors[producer] = None + return tuple(predecessors) + + def successors(self) -> Sequence[Node]: + """Return the successor nodes of the node, deduplicated, in a deterministic order.""" + # Use the ordered nature of a dictionary to deduplicate the nodes + successors: dict[Node, None] = {} + for value in self.outputs: + assert value is not None, "Bug: Output values are not expected to be None" + for usage in value.uses(): + successors[usage.node] = None + return tuple(successors) + def replace_input_with(self, index: int, value: Value | None) -> None: """Replace an input with a new value.""" if index < 0 or index >= len(self.inputs): @@ -1564,7 +1596,7 @@ def __init__( # Use a collection of (Node, int) to store uses. This is needed # because a single use can use the same value multiple times. # Use a dictionary to preserve insertion order so that the visiting order is deterministic - self._uses: dict[tuple[Node, int], None] = {} + self._uses: dict[Usage, None] = {} self.doc_string = doc_string def __repr__(self) -> str: @@ -1595,31 +1627,39 @@ def producer(self) -> Node | None: """ return self._producer + def consumers(self) -> Sequence[Node]: + """Return the nodes (deduplicated) that consume this value.""" + return tuple({usage.node: None for usage in self._uses}) + def index(self) -> int | None: """The index of the output of the defining node.""" return self._index - def uses(self) -> Collection[tuple[Node, int]]: + def uses(self) -> Collection[Usage]: """Return a set of uses of the value. The set contains tuples of ``(Node, index)`` where the index is the index of the input of the node. For example, if ``node.inputs[1] == value``, then the use is ``(node, 1)``. """ - return self._uses.keys() + # Create a tuple for the collection so that iteration on will will not + # be affected when the usage changes during graph mutation. + # This adds a small overhead but is better a user experience than + # having users call tuple(). + return tuple(self._uses) def _add_usage(self, use: Node, index: int) -> None: """Add a usage of this value. This is an internal method. It should only be called by the Node class. """ - self._uses[(use, index)] = None + self._uses[Usage(use, index)] = None def _remove_usage(self, use: Node, index: int) -> None: """Remove a node from the uses of this value. This is an internal method. It should only be called by the Node class. """ - self._uses.pop((use, index)) + self._uses.pop(Usage(use, index)) @property def name(self) -> str | None: diff --git a/onnxscript/ir/_core_test.py b/onnxscript/ir/_core_test.py index 8662a8c01b..9b6cc94f6f 100644 --- a/onnxscript/ir/_core_test.py +++ b/onnxscript/ir/_core_test.py @@ -717,6 +717,13 @@ def test_is_dynamic_on_empty_shape(self): class ValueTest(unittest.TestCase): + def setUp(self) -> None: + self.v0 = _core.Value(name="v0") + self.v1 = _core.Value(name="v1") + self.node = _core.Node( + "test", "TestOp", inputs=(self.v0, self.v1, self.v1), num_outputs=2 + ) + def test_initialize(self): _ = _core.Value() @@ -732,14 +739,30 @@ def test_meta(self): value.metadata_props["test"] = "any string" self.assertEqual(value.metadata_props["test"], "any string") + def test_producer(self): + self.assertEqual(self.v0.producer(), None) + self.assertEqual(self.v1.producer(), None) + self.assertEqual(self.node.outputs[0].producer(), self.node) + self.assertEqual(self.node.outputs[1].producer(), self.node) + + def test_consumers(self): + self.assertEqual(self.v0.consumers(), (self.node,)) + self.assertEqual(self.v1.consumers(), (self.node,)) + self.assertEqual(self.node.outputs[0].consumers(), ()) + self.assertEqual(self.node.outputs[1].consumers(), ()) + # TODO(justinchuby): Test all methods class NodeTest(unittest.TestCase): def setUp(self) -> None: - self.v0 = _core.Value() - self.v1 = _core.Value() - self.node = _core.Node("test", "TestOp", inputs=(self.v0, self.v1), num_outputs=3) + self.v0 = _core.Value(name="v0") + self.v1 = _core.Value(name="v1") + self.node = _core.Node( + "test", "TestOp", inputs=(self.v0, self.v1, self.v1), num_outputs=3 + ) + self.node_a = _core.Node("test", "TestOpA", inputs=[self.node.outputs[0]]) + self.node_b = _core.Node("test", "TestOpB", inputs=self.node.outputs) def test_it_is_hashable(self): self.assertIsInstance(hash(self.node), int) @@ -748,7 +771,7 @@ def test_it_is_hashable(self): def test_init_with_values(self): self.assertEqual(self.node.domain, "test") self.assertEqual(self.node.op_type, "TestOp") - self.assertEqual(self.node.inputs, (self.v0, self.v1)) + self.assertEqual(self.node.inputs, (self.v0, self.v1, self.v1)) self.assertEqual(len(self.node.outputs), 3) self.assertEqual(self.node.attributes, {}) @@ -807,6 +830,23 @@ def test_it_is_added_to_a_graph_if_specified(self): ) self.assertIn(self.node, graph) + def test_predecessors(self): + self.assertEqual(self.node.predecessors(), ()) + self.assertEqual(self.node_a.predecessors(), (self.node,)) + self.assertEqual(self.node_b.predecessors(), (self.node,)) + + def test_predecessors_are_unique(self): + # node_b has three inputs from node, but only one predecessor + self.assertEqual(self.node_b.predecessors(), (self.node,)) + + def test_successors(self): + self.assertEqual(self.node.successors(), (self.node_a, self.node_b)) + self.assertEqual(self.node_a.successors(), ()) + self.assertEqual(self.node_b.successors(), ()) + + def test_successors_are_unique(self): + self.assertEqual(self.node.successors(), (self.node_a, self.node_b)) + # TODO(justinchuby): Test all methods