From db475d90aed1e0ce5ad9152f169c4f9886fb3afc Mon Sep 17 00:00:00 2001 From: Johansmm Date: Tue, 13 May 2025 23:41:35 +0200 Subject: [PATCH 1/4] [IR] extract update_graph_outputs in a helper (#2294) --- onnxscript/ir/_convenience/__init__.py | 17 +++++++++++++---- 1 file changed, 13 insertions(+), 4 deletions(-) diff --git a/onnxscript/ir/_convenience/__init__.py b/onnxscript/ir/_convenience/__init__.py index 839c5d330b..87a64bf5e4 100644 --- a/onnxscript/ir/_convenience/__init__.py +++ b/onnxscript/ir/_convenience/__init__.py @@ -336,6 +336,18 @@ def create_value_mapping(graph: _core.Graph) -> dict[str, _core.Value]: return values +def _update_graph_or_function_outputs( + graph_or_function: _core.Graph | _core.Function, + old_values: Sequence[_core.Value], + new_values: Sequence[_core.Value], +): + """Update graph/function outputs""" + replacement_mapping = dict(zip(old_values, new_values)) + for idx, graph_or_function_output in enumerate(graph_or_function.outputs): + if graph_or_function_output in replacement_mapping: + graph_or_function.outputs[idx] = replacement_mapping[graph_or_function_output] + + def replace_nodes_and_values( graph_or_function: _core.Graph | _core.Function, /, @@ -368,10 +380,7 @@ def replace_nodes_and_values( # Reconnect the users of the deleted values to use the new values replace_all_uses_with(old_values, new_values) # Update graph/function outputs if the node generates output - replacement_mapping = dict(zip(old_values, new_values)) - for idx, graph_or_function_output in enumerate(graph_or_function.outputs): - if graph_or_function_output in replacement_mapping: - graph_or_function.outputs[idx] = replacement_mapping[graph_or_function_output] + _update_graph_or_function_outputs(graph_or_function, old_values, new_values) # insert new nodes after the index node graph_or_function.insert_after(insertion_point, new_nodes) From d24025542a6f4e548d080bbced0816f01f0291f8 Mon Sep 17 00:00:00 2001 From: Johansmm Date: Tue, 13 May 2025 00:58:54 +0200 Subject: [PATCH 2/4] [IR] introduce insert_nodes_before_value (#2294) Convenience function to insert a set of nodes in value(s). --- onnxscript/ir/_convenience/__init__.py | 102 ++++++++++++++++ onnxscript/ir/_convenience/_init_test.py | 146 +++++++++++++++++++++++ onnxscript/ir/convenience.py | 2 + 3 files changed, 250 insertions(+) create mode 100644 onnxscript/ir/_convenience/_init_test.py diff --git a/onnxscript/ir/_convenience/__init__.py b/onnxscript/ir/_convenience/__init__.py index 87a64bf5e4..c7aac90123 100644 --- a/onnxscript/ir/_convenience/__init__.py +++ b/onnxscript/ir/_convenience/__init__.py @@ -14,6 +14,7 @@ "replace_all_uses_with", "create_value_mapping", "replace_nodes_and_values", + "insert_nodes_in_value", ] from typing import Mapping, Sequence, Union @@ -385,3 +386,104 @@ def replace_nodes_and_values( # insert new nodes after the index node graph_or_function.insert_after(insertion_point, new_nodes) graph_or_function.remove(old_nodes, safe=True) + + +def _find_inputs_outputs( + nodes: Sequence[_core.Node], +) -> tuple[Sequence[_core.Value], Sequence[_core.Value]]: + """Find the values that are considered as inputs and outputs in a sequence of nodes""" + # Search the unique inputs/outputs in new_nodes, keeping the order. + all_inputs = dict.fromkeys(sum([node.inputs for node in nodes], ())) + all_outputs = dict.fromkeys(sum([node.outputs for node in nodes], ())) + # A value is considered as input if it is not any output. + inputs = tuple(val for val in all_inputs if val not in all_outputs) + # A value is considered as output if it is not any input. + outputs = tuple(val for val in all_outputs if val not in all_inputs) + return inputs, outputs + + +def insert_nodes_in_value( + values: _core.Value | Sequence[_core.Value], new_nodes: Sequence[_core.Node] +) -> None: + """Inserts a sequence of nodes into the provided value(s). + + This allows to insert a list of LINKED nodes (over the same context) at + a specific point in the graph. + + For example, suppose we have the following graph:: + + input -> A := node_A(input) -> B := node_B(A) -> C := node_C(B) -> output + + We want to insert [node_M, node_N] at B value:: + + >>> from onnxscript import ir + >>> input = ir.Input("input") + >>> node_A = ir.node("op_A", [input]) + >>> B = ir.Value(name="B") + >>> node_B = ir.node("op_B", node_A.outputs, outputs=[B]) + >>> node_C = ir.node("op_C", node_B.outputs) + >>> # Create a new sequence to insert + >>> input_2 = ir.Input("input_2") + >>> node_M = ir.node("op_M", [input_2]) + >>> node_N = ir.node("op_N", node_M.outputs) + >>> # Insert nodes in B + >>> insert_nodes_before_value(node_B.outputs, [node_M, node_N]) + >>> len(node_B.outputs) + 1 + >>> node_B.outputs[0].consumers()[0].op_type + 'op_M' + >>> len(node_C.inputs) + 1 + >>> node_C.inputs[0].producer().op_type + 'op_N' + >>> node_C.inputs[0].name + 'B' + + When values is a sequence, the set of nodes must have the same number + of inputs and outputs, then they are zipped into pairs: first value is + replaced with the first input/output, and so on. + + Args: + values: The value(s) where to insert the nodes. + new_nodes: The nodes to insert in the graph. + """ + if not isinstance(values, Sequence): + values = (values,) + + # Search the unique inputs/outputs in new_nodes, keeping the order. + inputs, outputs = _find_inputs_outputs(new_nodes) + + # Sanity check. + if len(values) != len(inputs): + raise ValueError(f"The number of values and inputs ({inputs}) in new_nodes must match.") + if len(values) != len(outputs): + raise ValueError(f"The number of values and outputs ({outputs}) in new_nodes must match.") + + # Propagate relevant info. + for val, in_val, out_val in zip(values, inputs, outputs): + # Propagate relevant info from value to out_value. + # TODO(Rama): Perhaps this should be a separate utility function. + out_val.type = val.type + out_val.shape = val.shape + out_val.name = val.name + # Propagate relevant info from value to in_value. + # TODO(Rama): Perhaps this should be a separate utility function. + in_val.type = val.type + in_val.shape = val.shape + # Rename each value, following each input. + val.name = in_val.name + + # Insert the new nodes in two steps: + # 1. Reconnect the users of values to the outputs + replace_all_uses_with(values, outputs) + # 2. Reconnect the users of inputs to values + replace_all_uses_with(inputs, values) + + # Update graph if there is one: + if (graph := values[-1].graph) is not None: + # Update graph/function outputs if the node generates output + _update_graph_or_function_outputs(graph, values, outputs) + + # Insert new nodes if there is a graph + graph.extend(new_nodes) + graph.sort() diff --git a/onnxscript/ir/_convenience/_init_test.py b/onnxscript/ir/_convenience/_init_test.py new file mode 100644 index 0000000000..64fb80890c --- /dev/null +++ b/onnxscript/ir/_convenience/_init_test.py @@ -0,0 +1,146 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +"""Unit tests for the _constructors module.""" + +import onnx + +import unittest + +from onnxscript import ir +from onnxscript.ir._convenience import insert_nodes_in_value + + +def _create_model(model_text: str) -> ir.Model: + model = onnx.parser.parse_model(model_text) + return ir.serde.deserialize_model(model) + + +class ConvenienceTest(unittest.TestCase): + def test_insert_nodes_in_value(self): + # Main graph + input = ir.Input("input") + node_A = ir.node("op_A", [input]) + node_B = ir.node("op_B", node_A.outputs, outputs=[ir.Value(name="B")]) + node_C = ir.node("op_C", node_B.outputs) + + # New sequence to insert + input_2 = ir.Input("input_2") + node_M = ir.node("op_M", [input_2]) + node_N = ir.node("op_N", node_M.outputs) + + # Insert nodes in B + insert_nodes_in_value(node_B.outputs[0], [node_M, node_N]) + self.assertEqual(len(node_B.outputs), 1) + self.assertEqual(node_B.outputs[0].consumers()[0].op_type, "op_M") + self.assertEqual(len(node_C.inputs), 1) + self.assertEqual(node_C.inputs[0].producer().op_type, "op_N") + self.assertEqual(node_C.inputs[0].name, "B") + + def test_insert_nodes_in_value_in_graph(self): + ir_model = _create_model( + """ + + agraph (float[N] x) => (float[N] z) { + two = Constant() + a, b = SplitNode(x) + z = MergeNode(a, b, two) + } + """ + ) + + # Sequence to insert. + # Note inputs = [i1, i2] and outputs = [b.outputs[1], c.outputs[0]]. + i1, i2 = ir.Input("i1"), ir.Input("i2") + a = ir.node("op_1", [i1, i2]) + b = ir.node("op_2", [a.outputs[0], i1], num_outputs=2) + c = ir.node("op_3", [i2, b.outputs[0]]) + + # Insert nodes in SplitNode.outputs + target_node = ir_model.graph[1] + insert_nodes_in_value(target_node.outputs, [a, b, c]) + + # Check target_node outputs have been renamed + new_i1, new_i2 = target_node.outputs + self.assertEqual(new_i1.name, "i1") + self.assertEqual(new_i2.name, "i2") + + # Check i1 and i2 have new users + self.assertEqual(tuple(node.op_type for node in new_i1.consumers()), ("op_1", "op_2")) + self.assertEqual(tuple(node.op_type for node in new_i2.consumers()), ("op_1", "op_3")) + + # Check outputs have been correctly renamed as previous values + self.assertEqual(b.outputs[1].name, "a") + self.assertEqual(c.outputs[0].name, "b") + + # Check nodes have been inserted in the graph + self.assertEqual(len(ir_model.graph), 6) + + def test_insert_nodes_in_input(self): + ir_model = _create_model( + """ + + agraph (float[N] x) => (float[N] z) { + two = Constant() + z = Add(x, two) + } + """ + ) + + # Sequence to insert. + x = ir.Input("new_x") + node = ir.node("Mul", [x, x]) + + # Insert nodes in graph.inputs + insert_nodes_in_value(ir_model.graph[1].inputs[0], [node]) + self.assertEqual(node.outputs[0].name, "x") + + # Check input has been renamed + self.assertEqual(ir_model.graph.inputs[0].name, "new_x") + + # Finally, check new graph is valid + proto = ir.to_proto(ir_model) + onnx.checker.check_model(proto, full_check=True) + + def test_insert_nodes_in_output(self): + ir_model = _create_model( + """ + + agraph (float[N] x) => (float[N] z) { + two = Constant() + z = Add(x, two) + } + """ + ) + + # Sequence to insert. + x = ir.Input("new_z") + node = ir.node("Mul", [x, x]) + + # Insert nodes in graph.inputs + insert_nodes_in_value(ir_model.graph.outputs[0], [node]) + self.assertEqual(ir_model.graph[1].outputs[0].name, "new_z") + + # Check output name is preserved + self.assertEqual(ir_model.graph.outputs[0].name, "z") + + def test_value_error_for_wrong_number_of_points(self): + ir_model = _create_model( + """ + + agraph (float[N] x) => (float[N] z) { + two = Constant() + a, b = SplitNode(x) + z = MergeNode(a, b, two) + } + """ + ) + node = ir.node("op_M", [ir.Input("new_x"), ir.Input("new_y")]) + with self.assertRaisesRegex(ValueError, "The number of values and inputs"): + insert_nodes_in_value(ir_model.graph[0].outputs, [node]) + + with self.assertRaisesRegex(ValueError, "The number of values and outputs"): + insert_nodes_in_value(ir_model.graph[1].outputs, [node]) + + +if __name__ == "__main__": + unittest.main() diff --git a/onnxscript/ir/convenience.py b/onnxscript/ir/convenience.py index 480ff603b0..eb9b458925 100644 --- a/onnxscript/ir/convenience.py +++ b/onnxscript/ir/convenience.py @@ -10,6 +10,7 @@ "replace_all_uses_with", "replace_nodes_and_values", "create_value_mapping", + "insert_nodes_in_value", ] from onnxscript.ir._convenience import ( @@ -18,6 +19,7 @@ create_value_mapping, replace_all_uses_with, replace_nodes_and_values, + insert_nodes_in_value, ) # NOTE: Do not implement any other functions in this module. From 142e47fbe14abb16b4d4ffcb56c56fa9f0dcedc8 Mon Sep 17 00:00:00 2001 From: Johansmm Date: Tue, 13 May 2025 23:36:20 +0200 Subject: [PATCH 3/4] [IR] introduce remove_nodes (#2294) Convenience function to remove a set of nodes. --- onnxscript/ir/_convenience/__init__.py | 63 ++++++++++++++ onnxscript/ir/_convenience/_init_test.py | 104 ++++++++++++++++++++++- onnxscript/ir/convenience.py | 2 + 3 files changed, 168 insertions(+), 1 deletion(-) diff --git a/onnxscript/ir/_convenience/__init__.py b/onnxscript/ir/_convenience/__init__.py index c7aac90123..47ef65a147 100644 --- a/onnxscript/ir/_convenience/__init__.py +++ b/onnxscript/ir/_convenience/__init__.py @@ -15,6 +15,7 @@ "create_value_mapping", "replace_nodes_and_values", "insert_nodes_in_value", + "remove_nodes", ] from typing import Mapping, Sequence, Union @@ -487,3 +488,65 @@ def insert_nodes_in_value( # Insert new nodes if there is a graph graph.extend(new_nodes) graph.sort() + + +def remove_nodes(nodes: Sequence[_core.Node]) -> None: + """Remove a sequence of nodes. + + This allows to delete a list of LINKED nodes (over the same context). + + For example, suppose we have the following graph:: + + input -> A := node_A(input) -> B := node_B(A) -> C := node_C(B) -> output + + We want to prune [node_B]:: + + >>> from onnxscript import ir + >>> input = ir.Input("input") + >>> node_A = ir.node("op_A", [input]) + >>> node_B = ir.node("op_B", node_A.outputs) + >>> node_C = ir.node("op_C", node_B.outputs) + >>> # Delete node_B + >>> remove_nodes([node_B]) + >>> len(node_A.outputs[0].consumers()) + 1 + >>> node_A.outputs[0].consumers()[0].op_type + 'op_C' + >>> len(node_C.inputs) + 1 + >>> node_C.inputs[0].producer().op_type + 'op_A' + >>> node_B.inputs + (None,) + >>> len(node_B.outputs) + 1 + >>> len(node_B.outputs[0].consumers()) + 0 + + Args: + nodes: The nodes to remove. + """ + # Search the unique inputs/outputs in new_nodes, keeping the order. + inputs, outputs = _find_inputs_outputs(nodes) + + # Sanity check. + if len(inputs) != len(outputs): + raise ValueError( + f"The number of inputs ({inputs}) and outputs ({outputs}) in nodes must match." + ) + + # Remove nodes, in several steps: + # 1. Reconnect the users of outputs with inputs + replace_all_uses_with(outputs, inputs) + # 2. Detach nodes for their inputs + for node in nodes: + for i in range(len(node.inputs)): + node.replace_input_with(i, None) + + # Update graph if there is one: + if (graph := inputs[-1].graph) is not None: + # Update graph/function outputs if the node generates output + _update_graph_or_function_outputs(graph, outputs, inputs) + + # Drop nodes from graph + graph.remove(nodes, safe=True) diff --git a/onnxscript/ir/_convenience/_init_test.py b/onnxscript/ir/_convenience/_init_test.py index 64fb80890c..0dfe97cbc0 100644 --- a/onnxscript/ir/_convenience/_init_test.py +++ b/onnxscript/ir/_convenience/_init_test.py @@ -7,7 +7,7 @@ import unittest from onnxscript import ir -from onnxscript.ir._convenience import insert_nodes_in_value +from onnxscript.ir._convenience import insert_nodes_in_value, remove_nodes def _create_model(model_text: str) -> ir.Model: @@ -141,6 +141,108 @@ def test_value_error_for_wrong_number_of_points(self): with self.assertRaisesRegex(ValueError, "The number of values and outputs"): insert_nodes_in_value(ir_model.graph[1].outputs, [node]) + def test_remove_nodes(self): + # Main graph + input = ir.Input("input") + node_A = ir.node("op_A", [input]) + node_B = ir.node("op_B", node_A.outputs) + node_C = ir.node("op_C", node_B.outputs) + + # Delete node_B + remove_nodes([node_B]) + self.assertEqual(len(node_A.outputs[0].consumers()), 1) + self.assertEqual(node_A.outputs[0].consumers()[0].op_type, "op_C") + self.assertEqual(len(node_C.inputs), 1) + self.assertEqual(node_C.inputs[0].producer().op_type, "op_A") + + self.assertEqual((len(node_B.inputs), len(node_B.outputs)), (1, 1)) + self.assertEqual(node_B.inputs, (None,)) + self.assertEqual(len(node_B.outputs[0].consumers()), 0) + + def test_remove_nodes_in_graph(self): + ir_model = _create_model( + """ + + agraph (float[N] x) => (float[N] z) { + two = Constant() + a, b = MergeAndSplit(x, two) + z = MergeNode(a, b, two) + } + """ + ) + # Sanity check previous to delete nodes + x, two = ir_model.graph.inputs[0], ir_model.graph[0].outputs[0] + self.assertEqual(len(x.consumers()), 1) + self.assertEqual(len(two.consumers()), 2) + + # Delete 'MergeAndSplit' + target_node = ir_model.graph[1] + remove_nodes([target_node]) + + # Check 'MergeNode' has new inputs + a, b, _ = ir_model.graph[-1].inputs + self.assertEqual(a.name, "x") + self.assertEqual(b.name, "two") + + # Check x/two consumers have been updated + self.assertEqual(len(x.consumers()), 1) + self.assertEqual(len(two.consumers()), 1) + + # Check nodes have been deleted in the graph + self.assertEqual(len(ir_model.graph), 2) + + def test_remove_nodes_in_input(self): + ir_model = _create_model( + """ + + agraph (float[N] x) => (float[N] z) { + y = Sigmoid(x) + z = Mul(y, y) + } + """ + ) + # Remove the node linked to the input + remove_nodes([ir_model.graph[0]]) + self.assertEqual(len(ir_model.graph), 1) + self.assertEqual(ir_model.graph[0].op_type, "Mul") + self.assertEqual(ir_model.graph[0].inputs[0].name, "x") + self.assertEqual(ir_model.graph[0].inputs[1].name, "x") + self.assertEqual(ir_model.graph[0].outputs[0].name, "z") + + def test_remove_nodes_in_output(self): + ir_model = _create_model( + """ + + agraph (float[N] x) => (float[N] z) { + y = Mul(x, x) + z = Sigmoid(y) + } + """ + ) + # Remove the node linked to the input + remove_nodes([ir_model.graph[-1]]) + self.assertEqual(len(ir_model.graph), 1) + self.assertEqual(ir_model.graph[0].op_type, "Mul") + self.assertEqual(ir_model.graph[0].outputs[0].name, "y") + self.assertEqual(ir_model.graph.outputs[0].name, "y") + + def test_remove_nodes_error_for_wrong_number_of_inputs_and_outputs(self): + ir_model = _create_model( + """ + + agraph (float[N] x) => (float[N] z) { + two = Constant() + a, b = SplitNode(x) + z = MergeNode(a, b, two) + } + """ + ) + with self.assertRaisesRegex(ValueError, "The number of inputs"): + remove_nodes([ir_model.graph[0]]) + + with self.assertRaisesRegex(ValueError, "The number of inputs"): + remove_nodes([ir_model.graph[1]]) + if __name__ == "__main__": unittest.main() diff --git a/onnxscript/ir/convenience.py b/onnxscript/ir/convenience.py index eb9b458925..28468e6294 100644 --- a/onnxscript/ir/convenience.py +++ b/onnxscript/ir/convenience.py @@ -11,6 +11,7 @@ "replace_nodes_and_values", "create_value_mapping", "insert_nodes_in_value", + "remove_nodes", ] from onnxscript.ir._convenience import ( @@ -20,6 +21,7 @@ replace_all_uses_with, replace_nodes_and_values, insert_nodes_in_value, + remove_nodes, ) # NOTE: Do not implement any other functions in this module. From e89df33e728d02664ffd0e8f1460614c86c3eec5 Mon Sep 17 00:00:00 2001 From: Johansmm Date: Tue, 13 May 2025 23:45:34 +0200 Subject: [PATCH 4/4] [IR] include insert and remove nodes in doc (#2294) --- docs/ir/ir_api/ir_convenience.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/docs/ir/ir_api/ir_convenience.md b/docs/ir/ir_api/ir_convenience.md index 77f09bfe81..ad654cd412 100644 --- a/docs/ir/ir_api/ir_convenience.md +++ b/docs/ir/ir_api/ir_convenience.md @@ -12,4 +12,6 @@ .. autofunction:: replace_all_uses_with .. autofunction:: replace_nodes_and_values .. autofunction:: create_value_mapping +.. autofunction:: insert_nodes_in_value +.. autofunction:: remove_nodes ```