diff --git a/VERSION b/VERSION index d1d899fa33..b49b25336d 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -0.5.5 +0.5.6 diff --git a/onnxscript/optimizer/_constant_folding.py b/onnxscript/optimizer/_constant_folding.py index e1bff26791..dfb072417a 100644 --- a/onnxscript/optimizer/_constant_folding.py +++ b/onnxscript/optimizer/_constant_folding.py @@ -1256,10 +1256,20 @@ def replace_node( # Record the names of the values that has contributed to the replacement _record_contributing_values(node, replacement) + + # Obtain the list of non-None inputs to the node before it is cleared by + # replace_nodes_and_values to check for unused initializers later. + node_inputs = [v for v in node.inputs if v is not None] + ir.convenience.replace_nodes_and_values( root, node, [node], replacement.new_nodes, node.outputs, replacement.new_outputs ) + if isinstance(root, ir.Graph): + # The old node should now be detached from the graph + assert node.graph is None + _clear_unused_initializers(node_inputs) + self._modified = True # TODO: what about new opset_imports? @@ -1336,6 +1346,19 @@ def _sym_value_can_replace_graph_output( return True +def _clear_unused_initializers(values: Sequence[ir.Value]) -> None: + # Detach all inputs to the node, then check for unused initializers + for value in values: + if value is None or not value.is_initializer(): + continue + + if not value.uses(): + assert value.is_initializer() + assert value.graph is not None + assert value.name is not None + value.graph.initializers.pop(value.name) + + @dataclasses.dataclass class FoldConstantsResult(ir.passes.PassResult): symbolic_value_map: dict[ir.Value, SymbolicValue] diff --git a/onnxscript/optimizer/_constant_folding_test.py b/onnxscript/optimizer/_constant_folding_test.py index d9395e811c..96a143f81a 100644 --- a/onnxscript/optimizer/_constant_folding_test.py +++ b/onnxscript/optimizer/_constant_folding_test.py @@ -14,13 +14,20 @@ class FoldConstantsTest(unittest.TestCase): - def _fold(self, model: ir.Model | str, onnx_shape_inference=False, **kwargs): + def _fold( + self, + model: ir.Model | str, + onnx_shape_inference: bool = False, + dce: bool = True, + **kwargs, + ): if isinstance(model, str): model = ir.from_onnx_text(model) _constant_folding.fold_constants( model, onnx_shape_inference=onnx_shape_inference, **kwargs ) - optimizer.remove_unused_nodes(model) + if dce: + optimizer.remove_unused_nodes(model) # Ensure the model is valid after optimization onnx.checker.check_model(ir.serde.serialize_model(model)) return model @@ -50,9 +57,16 @@ def test_fold_cast_like(self): } """ - optimized = self._fold(model) - self.assertEqual(len(optimized.graph), 1) + optimized = self._fold(model, dce=False) self.assertIn("four", optimized.graph.initializers) + np.testing.assert_equal( + optimized.graph.initializers["four"].const_value, np.array(4.0) + ) + # Intermediates should be removed + self.assertNotIn("two_float", optimized.graph.initializers) + + optimized = self._fold(model, dce=True) + self.assertEqual(len(optimized.graph), 1) def test_fold_shape(self): model = """ @@ -66,9 +80,18 @@ def test_fold_shape(self): } """ - optimized = self._fold(model) - self.assertEqual(len(optimized.graph), 1) + optimized = self._fold(model, dce=False) self.assertIn("four", optimized.graph.initializers) + np.testing.assert_equal( + optimized.graph.initializers["four"].const_value, np.array(4.0) + ) + # Intermediates should be removed + self.assertNotIn("two_float", optimized.graph.initializers) + self.assertNotIn("rank", optimized.graph.initializers) + self.assertNotIn("shape", optimized.graph.initializers) + + optimized = self._fold(model, dce=True) + self.assertEqual(len(optimized.graph), 1) def test_fold_shape_slice(self): model = """