Skip to content
Closed
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions docs/ir/ir_api/ir_convenience.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,4 +12,6 @@
.. autofunction:: replace_all_uses_with
.. autofunction:: replace_nodes_and_values
.. autofunction:: create_value_mapping
.. autofunction:: insert_nodes_in_value
.. autofunction:: remove_nodes
```
182 changes: 178 additions & 4 deletions onnxscript/ir/_convenience/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
"replace_all_uses_with",
"create_value_mapping",
"replace_nodes_and_values",
"insert_nodes_in_value",
"remove_nodes",
]

from typing import Mapping, Sequence, Union
Expand Down Expand Up @@ -336,6 +338,18 @@
return values


def _update_graph_or_function_outputs(
graph_or_function: _core.Graph | _core.Function,
old_values: Sequence[_core.Value],
new_values: Sequence[_core.Value],
):
"""Update graph/function outputs"""
replacement_mapping = dict(zip(old_values, new_values))
for idx, graph_or_function_output in enumerate(graph_or_function.outputs):
if graph_or_function_output in replacement_mapping:
graph_or_function.outputs[idx] = replacement_mapping[graph_or_function_output]


def replace_nodes_and_values(
graph_or_function: _core.Graph | _core.Function,
/,
Expand Down Expand Up @@ -368,11 +382,171 @@
# Reconnect the users of the deleted values to use the new values
replace_all_uses_with(old_values, new_values)
# Update graph/function outputs if the node generates output
replacement_mapping = dict(zip(old_values, new_values))
for idx, graph_or_function_output in enumerate(graph_or_function.outputs):
if graph_or_function_output in replacement_mapping:
graph_or_function.outputs[idx] = replacement_mapping[graph_or_function_output]
_update_graph_or_function_outputs(graph_or_function, old_values, new_values)

# insert new nodes after the index node
graph_or_function.insert_after(insertion_point, new_nodes)
graph_or_function.remove(old_nodes, safe=True)


def _find_inputs_outputs(
nodes: Sequence[_core.Node],
) -> tuple[Sequence[_core.Value], Sequence[_core.Value]]:
"""Find the values that are considered as inputs and outputs in a sequence of nodes"""
# Search the unique inputs/outputs in new_nodes, keeping the order.
all_inputs = dict.fromkeys(sum([node.inputs for node in nodes], ()))
all_outputs = dict.fromkeys(sum([node.outputs for node in nodes], ()))

Check notice

Code scanning / lintrunner

RUFF/C419 Note

# A value is considered as input if it is not any output.

Check notice

Code scanning / lintrunner

RUFF/C419 Note

inputs = tuple(val for val in all_inputs if val not in all_outputs)
# A value is considered as output if it is not any input.
outputs = tuple(val for val in all_outputs if val not in all_inputs)
return inputs, outputs


def insert_nodes_in_value(
values: _core.Value | Sequence[_core.Value], new_nodes: Sequence[_core.Node]
) -> None:
"""Inserts a sequence of nodes into the provided value(s).
This allows to insert a list of LINKED nodes (over the same context) at
a specific point in the graph.
For example, suppose we have the following graph::
input -> A := node_A(input) -> B := node_B(A) -> C := node_C(B) -> output
We want to insert [node_M, node_N] at B value::
>>> from onnxscript import ir
>>> input = ir.Input("input")
>>> node_A = ir.node("op_A", [input])
>>> B = ir.Value(name="B")
>>> node_B = ir.node("op_B", node_A.outputs, outputs=[B])
>>> node_C = ir.node("op_C", node_B.outputs)
>>> # Create a new sequence to insert
>>> input_2 = ir.Input("input_2")
>>> node_M = ir.node("op_M", [input_2])
>>> node_N = ir.node("op_N", node_M.outputs)
>>> # Insert nodes in B
>>> insert_nodes_before_value(node_B.outputs, [node_M, node_N])
>>> len(node_B.outputs)
1
>>> node_B.outputs[0].consumers()[0].op_type
'op_M'
>>> len(node_C.inputs)
1
>>> node_C.inputs[0].producer().op_type
'op_N'
>>> node_C.inputs[0].name
'B'
When values is a sequence, the set of nodes must have the same number
of inputs and outputs, then they are zipped into pairs: first value is
replaced with the first input/output, and so on.
Args:
values: The value(s) where to insert the nodes.
new_nodes: The nodes to insert in the graph.
"""
if not isinstance(values, Sequence):
values = (values,)

# Search the unique inputs/outputs in new_nodes, keeping the order.
inputs, outputs = _find_inputs_outputs(new_nodes)

# Sanity check.
if len(values) != len(inputs):
raise ValueError(f"The number of values and inputs ({inputs}) in new_nodes must match.")
if len(values) != len(outputs):
raise ValueError(f"The number of values and outputs ({outputs}) in new_nodes must match.")

# Propagate relevant info.
for val, in_val, out_val in zip(values, inputs, outputs):
# Propagate relevant info from value to out_value.
# TODO(Rama): Perhaps this should be a separate utility function.
out_val.type = val.type
out_val.shape = val.shape
out_val.name = val.name
# Propagate relevant info from value to in_value.
# TODO(Rama): Perhaps this should be a separate utility function.
in_val.type = val.type
in_val.shape = val.shape
# Rename each value, following each input.
val.name = in_val.name

# Insert the new nodes in two steps:
# 1. Reconnect the users of values to the outputs
replace_all_uses_with(values, outputs)
# 2. Reconnect the users of inputs to values
replace_all_uses_with(inputs, values)

# Update graph if there is one:
if (graph := values[-1].graph) is not None:
# Update graph/function outputs if the node generates output
_update_graph_or_function_outputs(graph, values, outputs)

# Insert new nodes if there is a graph
graph.extend(new_nodes)
graph.sort()


def remove_nodes(nodes: Sequence[_core.Node]) -> None:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is very useful, thanks

Copy link
Collaborator

@justinchuby justinchuby May 30, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe call it remove_connected_nodes ? Also could you move this PR to https://github.com/onnx/ir-py now that we finished migration? (sorry about the extra effort) I recommend creating two PRs for the two functions so they can be reviewed individually

"""Remove a sequence of nodes.
This allows to delete a list of LINKED nodes (over the same context).
For example, suppose we have the following graph::
input -> A := node_A(input) -> B := node_B(A) -> C := node_C(B) -> output
We want to prune [node_B]::
>>> from onnxscript import ir
>>> input = ir.Input("input")
>>> node_A = ir.node("op_A", [input])
>>> node_B = ir.node("op_B", node_A.outputs)
>>> node_C = ir.node("op_C", node_B.outputs)
>>> # Delete node_B
>>> remove_nodes([node_B])
>>> len(node_A.outputs[0].consumers())
1
>>> node_A.outputs[0].consumers()[0].op_type
'op_C'
>>> len(node_C.inputs)
1
>>> node_C.inputs[0].producer().op_type
'op_A'
>>> node_B.inputs
(None,)
>>> len(node_B.outputs)
1
>>> len(node_B.outputs[0].consumers())
0
Args:
nodes: The nodes to remove.
"""
# Search the unique inputs/outputs in new_nodes, keeping the order.
inputs, outputs = _find_inputs_outputs(nodes)

# Sanity check.
if len(inputs) != len(outputs):
raise ValueError(
f"The number of inputs ({inputs}) and outputs ({outputs}) in nodes must match."
)

# Remove nodes, in several steps:
# 1. Reconnect the users of outputs with inputs
replace_all_uses_with(outputs, inputs)
# 2. Detach nodes for their inputs
for node in nodes:
for i in range(len(node.inputs)):
node.replace_input_with(i, None)

# Update graph if there is one:
if (graph := inputs[-1].graph) is not None:
# Update graph/function outputs if the node generates output
_update_graph_or_function_outputs(graph, outputs, inputs)

# Drop nodes from graph
graph.remove(nodes, safe=True)
Loading
Loading