Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion VERSION
Original file line number Diff line number Diff line change
@@ -1 +1 @@
0.5.5
0.5.6
23 changes: 23 additions & 0 deletions onnxscript/optimizer/_constant_folding.py
Original file line number Diff line number Diff line change
Expand Up @@ -1256,10 +1256,20 @@ def replace_node(

# Record the names of the values that have contributed to the replacement
_record_contributing_values(node, replacement)

# Obtain the list of non-None inputs to the node before it is cleared by
# replace_nodes_and_values to check for unused initializers later.
node_inputs = [v for v in node.inputs if v is not None]

ir.convenience.replace_nodes_and_values(
root, node, [node], replacement.new_nodes, node.outputs, replacement.new_outputs
)

if isinstance(root, ir.Graph):
# The old node should now be detached from the graph
assert node.graph is None
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need this assertion? This might be true in the current implementation ... but just wondering if we might ever want to replace [oldnode] by [newnode1, oldnode, newnode2], for example ... not sure if replace_nodes_and_values supports such a scenario ... but the logic here seems to be robust even in that case. But the assert would fail in that case. Anyway ... just a minor suggestion.

Also: in principle, this logic could go into replace_nodes_and_values, can't it?

Copy link
Collaborator Author

@justinchuby justinchuby Oct 31, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this logic could go into replace_nodes_and_values, can't it?

Good point - I agree. But also replace_nodes_and_values is more general. So checking value users may be more expensive there.

replace [oldnode] by [newnode1, oldnode, newnode2]

Not sure? replace_nodes_and_values will disconnect the old nodes (which breaks all the value connections) so I don't think it will work?

_clear_unused_initializers(node_inputs)

self._modified = True

# TODO: what about new opset_imports?
Expand Down Expand Up @@ -1336,6 +1346,19 @@ def _sym_value_can_replace_graph_output(
return True


def _clear_unused_initializers(values: Sequence[ir.Value]) -> None:
    """Drop initializers among ``values`` that are no longer referenced.

    Meant to be called after a node has been replaced, with the inputs that
    node had before it was detached. Any of those inputs that is an
    initializer and has no remaining consumer is removed from its owning
    graph's ``initializers`` mapping.

    An initializer that is also a graph input or a graph output is kept even
    when no node consumes it: it is part of the graph's external interface,
    and removing it would change the required inputs or leave an output
    without a producer.

    Args:
        values: Candidate values to examine. ``None`` entries and values that
            are not initializers are ignored.
    """
    for value in values:
        if value is None or not value.is_initializer():
            continue

        if value.uses():
            # Still consumed by at least one node; keep it.
            continue

        if value.is_graph_input() or value.is_graph_output():
            # Part of the graph interface; keep it despite having no uses.
            continue

        # Invariants for a value registered as an initializer. The asserts
        # also narrow the Optional types for static type checkers.
        assert value.graph is not None
        assert value.name is not None
        value.graph.initializers.pop(value.name)


@dataclasses.dataclass
class FoldConstantsResult(ir.passes.PassResult):
symbolic_value_map: dict[ir.Value, SymbolicValue]
Expand Down
35 changes: 29 additions & 6 deletions onnxscript/optimizer/_constant_folding_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,20 @@


class FoldConstantsTest(unittest.TestCase):
def _fold(self, model: ir.Model | str, onnx_shape_inference=False, **kwargs):
def _fold(
self,
model: ir.Model | str,
onnx_shape_inference: bool = False,
dce: bool = True,
**kwargs,
):
if isinstance(model, str):
model = ir.from_onnx_text(model)
_constant_folding.fold_constants(
model, onnx_shape_inference=onnx_shape_inference, **kwargs
)
optimizer.remove_unused_nodes(model)
if dce:
optimizer.remove_unused_nodes(model)
# Ensure the model is valid after optimization
onnx.checker.check_model(ir.serde.serialize_model(model))
return model
Expand Down Expand Up @@ -50,9 +57,16 @@ def test_fold_cast_like(self):
}
"""

optimized = self._fold(model)
self.assertEqual(len(optimized.graph), 1)
optimized = self._fold(model, dce=False)
self.assertIn("four", optimized.graph.initializers)
np.testing.assert_equal(
optimized.graph.initializers["four"].const_value, np.array(4.0)
)
# Intermediates should be removed
self.assertNotIn("two_float", optimized.graph.initializers)

optimized = self._fold(model, dce=True)
self.assertEqual(len(optimized.graph), 1)

def test_fold_shape(self):
model = """
Expand All @@ -66,9 +80,18 @@ def test_fold_shape(self):
}
"""

optimized = self._fold(model)
self.assertEqual(len(optimized.graph), 1)
optimized = self._fold(model, dce=False)
self.assertIn("four", optimized.graph.initializers)
np.testing.assert_equal(
optimized.graph.initializers["four"].const_value, np.array(4.0)
)
# Intermediates should be removed
self.assertNotIn("two_float", optimized.graph.initializers)
self.assertNotIn("rank", optimized.graph.initializers)
self.assertNotIn("shape", optimized.graph.initializers)

optimized = self._fold(model, dce=True)
self.assertEqual(len(optimized.graph), 1)

def test_fold_shape_slice(self):
model = """
Expand Down
Loading