From 25d43d87d0be48601162f012e15181cf3f633c74 Mon Sep 17 00:00:00 2001 From: Ti-Tai Wang Date: Tue, 3 Jun 2025 22:29:29 +0000 Subject: [PATCH] enable cse in optimizer --- onnxscript/ir/passes/common/__init__.py | 5 +- .../common_subexpression_elimination.py | 153 --------- .../common_subexpression_elimination_test.py | 303 ------------------ onnxscript/optimizer/_optimizer.py | 1 + 4 files changed, 2 insertions(+), 460 deletions(-) delete mode 100644 onnxscript/ir/passes/common/common_subexpression_elimination.py delete mode 100644 onnxscript/ir/passes/common/common_subexpression_elimination_test.py diff --git a/onnxscript/ir/passes/common/__init__.py b/onnxscript/ir/passes/common/__init__.py index 3f6f55ee1d..5a5ddbe52f 100644 --- a/onnxscript/ir/passes/common/__init__.py +++ b/onnxscript/ir/passes/common/__init__.py @@ -21,6 +21,7 @@ AddInitializersToInputsPass, CheckerPass, ClearMetadataAndDocStringPass, + CommonSubexpressionEliminationPass, InlinePass, LiftConstantsToInitializersPass, LiftSubgraphInitializersToMainGraphPass, @@ -31,7 +32,3 @@ ShapeInferencePass, TopologicalSortPass, ) - -from onnxscript.ir.passes.common.common_subexpression_elimination import ( - CommonSubexpressionEliminationPass, -) diff --git a/onnxscript/ir/passes/common/common_subexpression_elimination.py b/onnxscript/ir/passes/common/common_subexpression_elimination.py deleted file mode 100644 index 4fce1250a0..0000000000 --- a/onnxscript/ir/passes/common/common_subexpression_elimination.py +++ /dev/null @@ -1,153 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. -"""Eliminate common subexpression in ONNX graphs.""" - -from __future__ import annotations - -__all__ = [ - "CommonSubexpressionEliminationPass", -] - -import logging -from typing import Sequence - -from onnxscript import ir - -logger = logging.getLogger(__name__) - - -class CommonSubexpressionEliminationPass(ir.passes.InPlacePass): - """Eliminate common subexpression in ONNX graphs.""" - - def call(self, model: ir.Model) -> ir.passes.PassResult: - """Return the same ir.Model but with CSE applied to the graph.""" - modified = False - graph = model.graph - - modified = _eliminate_common_subexpression(graph, modified) - - return ir.passes.PassResult( - model, - modified=modified, - ) - - -def _eliminate_common_subexpression(graph: ir.Graph, modified: bool) -> bool: - """Eliminate common subexpression in ONNX graphs.""" - - # node to node identifier, length of outputs, inputs, and attributes - existing_node_info_to_the_node: dict[ - tuple[ - ir.OperatorIdentifier, - int, # len(outputs) - tuple[int, ...], # input ids - tuple[tuple[str, object], ...], # attributes - ], - ir.Node, - ] = {} - - for node in graph: - # Skip control flow ops like Loop and If. - control_flow_op: bool = False - # Use equality to check if the node is a common subexpression. - attributes = {} - for k, v in node.attributes.items(): - # TODO(exporter team): CSE subgraphs. - # NOTE: control flow ops like Loop and If won't be CSEd - # because attribute: graph won't match. - if v.type in (ir.AttributeType.GRAPH, ir.AttributeType.GRAPHS): - control_flow_op = True - logger.debug("Skipping control flow op %s", node) - # The attribute value could be directly taken from the original - # protobuf, so we need to make a copy of it. - value = v.value - if v.type in ( - ir.AttributeType.INTS, - ir.AttributeType.FLOATS, - ir.AttributeType.STRINGS, - ): - # For INT, FLOAT and STRING attributes, we convert them to tuples - # to ensure they are hashable. - value = tuple(value) - attributes[k] = value - - if control_flow_op: - # If the node is a control flow op, we skip it. - continue - - node_info = ( - node.op_identifier(), - len(node.outputs), - tuple(id(input) for input in node.inputs), - tuple(sorted(attributes.items())), - ) - # Check if the node is a common subexpression. - if node_info in existing_node_info_to_the_node: - # If it is, this node has an existing node with the same - # operator, number of outputs, inputs, and attributes. - # We replace the node with the existing node. - modified = True - existing_node = existing_node_info_to_the_node[node_info] - _remove_node_and_replace_values( - graph, - remove_node=node, - remove_values=node.outputs, - new_values=existing_node.outputs, - ) - logger.debug("Reusing node %s", existing_node) - else: - # If it is not, add to the mapping. - existing_node_info_to_the_node[node_info] = node - return modified - - -def _remove_node_and_replace_values( - graph: ir.Graph, - /, - remove_node: ir.Node, - remove_values: Sequence[ir.Value], - new_values: Sequence[ir.Value], -) -> None: - """Replaces nodes and values in the graph or function. - - Args: - graph: The graph to replace nodes and values in. - remove_node: The node to remove. - remove_values: The values to replace. - new_values: The values to replace with. - """ - # Reconnect the users of the deleted values to use the new values - ir.convenience.replace_all_uses_with(remove_values, new_values) - # Update graph/function outputs if the node generates output - if any(remove_value.is_graph_output() for remove_value in remove_values): - replacement_mapping = dict(zip(remove_values, new_values)) - for idx, graph_output in enumerate(graph.outputs): - if graph_output in replacement_mapping: - new_value = replacement_mapping[graph_output] - if new_value.is_graph_output(): - # If the new value is also a graph output, we need to - # create a Identity node to preserve the remove_value. - identity_node = ir.node( - "Identity", - inputs=[new_value], - outputs=[ - ir.Value( - name=graph_output.name, - type=graph_output.type, - shape=graph_output.shape, - ) - ], - ) - # reuse the name of the graph output - graph.outputs[idx] = identity_node.outputs[0] - graph.insert_before( - remove_node, - identity_node, - ) - else: - # if new_value is not graph output, we just - # update it to use old_value name. - new_value.name = graph_output.name - graph.outputs[idx] = new_value - - graph.remove(remove_node, safe=True) diff --git a/onnxscript/ir/passes/common/common_subexpression_elimination_test.py b/onnxscript/ir/passes/common/common_subexpression_elimination_test.py deleted file mode 100644 index 461af36fc8..0000000000 --- a/onnxscript/ir/passes/common/common_subexpression_elimination_test.py +++ /dev/null @@ -1,303 +0,0 @@ -# Copyright (c) Microsoft Corporation. -# Licensed under the MIT License. -from __future__ import annotations - -import unittest - -import numpy as np -import onnxruntime as ort - -from onnxscript import FLOAT, ir, script -from onnxscript import opset18 as op -from onnxscript.ir.passes.common import common_subexpression_elimination - - -class TestCommonSubexpressionEliminationPass(unittest.TestCase): - def check_graph(self, model: ir.Model, inputs: list[ir.Value], delta_nodes: list[int]): - """Check if the model applied the CSE pass correctly. - - Args: - model: The model to check. - inputs: The inputs to the model. - delta_nodes: The expected change in the number of nodes in the model. - The length of this list should match the number of graphs - in the model. (to support subgraphs in the future) - - Raises: - AssertionError: If the model does not match the expected number of nodes or outputs. - - """ - assert len(list(model.graphs())) == len(delta_nodes) - # Log all results from the original model. - # 1. model graph node counts - original_graphs_node_count = np.array([graph.num_nodes() for graph in model.graphs()]) - model_proto = ir.serde.serialize_model(model) - - # 2. model outputs - ort_inputs = { - k.name: np.random.rand(*v.shape).astype(np.float32) - for k, v in zip(model.graph.inputs, inputs) - } - original_model_session = ort.InferenceSession(model_proto.SerializeToString()) - original_model_results = original_model_session.run(None, ort_inputs) - - result = common_subexpression_elimination.CommonSubexpressionEliminationPass()(model) - - result_graphs_node_count = np.array([graph.num_nodes() for graph in model.graphs()]) - # Check if the number of nodes in the model is correct - self.assertTrue( - np.array_equal( - original_graphs_node_count, np.add(result_graphs_node_count, delta_nodes) - ) - ) - self.assertEqual( - result.modified, any(original_graphs_node_count > result_graphs_node_count) - ) - - result_proto = ir.serde.serialize_model(result.model) - result_session = ort.InferenceSession(result_proto.SerializeToString()) - result_results = result_session.run(None, ort_inputs) - - # Check if the models produce the same output - # with the same inputs - for idx, original_model_result in enumerate(original_model_results): - np.testing.assert_allclose( - original_model_result, result_results[idx], rtol=1e-5, atol=1e-5 - ) - - def test_duplicate_operations_are_csed(self): - """Test if the same operations are CSEd. - - def test_simple(self): - def f(x): - a = x.cos() - b = x.cos() - c = a + a - d = b + b - return c + d - - x = torch.randn(2, 2) - """ - - @script() - def test_model(x: FLOAT[2, 2]) -> FLOAT[2, 2]: - a = op.Cos(x) - b = op.Cos(x) - c = a + a - d = b + b - return c + d - - model_proto = test_model.to_model_proto() - model = ir.serde.deserialize_model(model_proto) - - self.check_graph(model, [np.random.rand(2, 2)], delta_nodes=[2]) - - def test_more_operations_in_duplicated_operations_is_csed(self): - """Test if the same operations are CSEd. - - def test_simple(self): - def f(x): - a = x.cos().sin() - b = x.cos().sin() - c = a + a - d = b + b - return c + d - - x = torch.randn(2, 2) - """ - - @script() - def test_model(x: FLOAT[1]) -> FLOAT[1]: - a = op.Sin(op.Cos(x)) - b = op.Sin(op.Cos(x)) - c = a + a - d = b + b - return c + d - - model_proto = test_model.to_model_proto() - model = ir.serde.deserialize_model(model_proto) - self.check_graph(model, [np.random.rand(1)], delta_nodes=[3]) - - def test_multiple_same_ops_with_attributes_are_csed(self): - """Test if multiple same ops are CSEd. - - def f(x): - a = x.sum() - b = x.sum() - c = x.sum() - d = x.sum() - return a + b + c + d - - x = torch.randn(2, 2) - - """ - - @script() - def test_model(x: FLOAT[2, 2]) -> FLOAT[2, 2]: - a = op.ReduceSum(x, keepdims=False) - b = op.ReduceSum(x, keepdims=False) - c = op.ReduceSum(x, keepdims=False) - d = op.ReduceSum(x, keepdims=False) - return a + b + c + d - - model_proto = test_model.to_model_proto() - model = ir.serde.deserialize_model(model_proto) - self.check_graph(model, [np.random.rand(2, 2)], delta_nodes=[3]) - - def test_the_ops_with_the_same_inputs_but_different_attributes_are_not_csed(self): - """Test if the ops with the same inputs but different attributes are not CSEd. - - def f(x): - a = x.sum() - b = x.sum(keepdims=True) - c = x.sum() - d = x.sum(keepdims=True) - return a + b + c + d - - x = torch.randn(2, 2) - - """ - - @script() - def test_model(x: FLOAT[2, 2]) -> FLOAT[2, 2]: - a = op.ReduceSum(x, keepdims=False) - b = op.ReduceSum(x, keepdims=True) - return a + b - - model_proto = test_model.to_model_proto() - model = ir.serde.deserialize_model(model_proto) - self.check_graph(model, [np.random.rand(2, 2)], delta_nodes=[0]) - - def test_control_flow_if_ops_are_not_csed_as_graph_attr_is_not_matched(self): - """Test if control flow ops are not CSEd. - - def f(a, b): - rank = a.rank() - if rank == 2: - result1 = a - b - else: - result1 = a + b - if rank == 2: - result2 = a - b - else: - result2 = a + b - return result1 + result2 - - x = torch.randn(2, 2) - - """ - - @script() - def test_model(a: FLOAT[2, 2], b: FLOAT[2, 2]) -> FLOAT[2, 2]: - rank = op.Size(op.Shape(a)) - if rank == 2: - result1 = a - b - else: - result1 = a + b - if rank == 2: - result2 = a - b - else: - result2 = a + b - return result1 + result2 - - model_proto = test_model.to_model_proto() - model = ir.serde.deserialize_model(model_proto) - self.check_graph( - model, [np.random.rand(2, 2), np.random.rand(2, 2)], delta_nodes=[0, 0, 0, 0, 0] - ) - - def test_the_nodes_following_control_flow_ops_are_csed(self): - """Test if the nodes following control flow ops are CSEd. - - def f(a, b): - rank = a.rank() - if rank == 2: - x = a - b - else: - x = a + b - a = x.cos().sin() - b = x.cos().sin() - c = a + a - d = b + b - return c + d - - x = torch.randn(2, 2) - - """ - - @script() - def test_model(a: FLOAT[2, 2], b: FLOAT[2, 2]) -> FLOAT[2, 2]: - rank = op.Size(op.Shape(a)) - if rank == 2: - x = a - b - else: - x = a + b - a = op.Sin(op.Cos(x)) - b = op.Sin(op.Cos(x)) - c = a + a - d = b + b - return c + d - - model_proto = test_model.to_model_proto() - model = ir.serde.deserialize_model(model_proto) - self.check_graph( - model, [np.random.rand(2, 2), np.random.rand(2, 2)], delta_nodes=[3, 0, 0] - ) - - def test_graph_output_value_replacement_preserves_name(self): - @script() - def test_model(x: FLOAT[2, 2]) -> (FLOAT[2, 2], FLOAT[2, 2]): - a = op.Cos(x) - b = op.Cos(x) - return a + b, b - - model_proto = test_model.to_model_proto() - model = ir.serde.deserialize_model(model_proto) - # Set custom output names - output_name_0 = "my_output_0" - output_name_1 = "my_output_1" - model.graph.outputs[0].name = output_name_0 - model.graph.outputs[1].name = output_name_1 - original_output_value_0 = model.graph.outputs[0] - original_output_value_1 = model.graph.outputs[1] - - # Run CSE pass - result = common_subexpression_elimination.CommonSubexpressionEliminationPass()(model) - new_output_value_0 = result.model.graph.outputs[0] - new_output_value_1 = result.model.graph.outputs[1] - - # The Value objects should be replaced (different id) - self.assertIs(original_output_value_0, new_output_value_0) - self.assertIsNot(original_output_value_1, new_output_value_1) - # But the names should be preserved - self.assertEqual(new_output_value_0.name, output_name_0) - self.assertEqual(new_output_value_1.name, output_name_1) - - def test_identity_inserted_when_both_outputs_are_graph_outputs(self): - @script() - def test_model(x: FLOAT[2, 2]) -> (FLOAT[2, 2], FLOAT[2, 2]): - a = op.Cos(x) - b = op.Cos(x) - return a, b - - model_proto = test_model.to_model_proto() - model = ir.serde.deserialize_model(model_proto) - # Set custom output names - output_name_0 = "output0" - output_name_1 = "output1" - model.graph.outputs[0].name = output_name_0 - model.graph.outputs[1].name = output_name_1 - - # Run CSE pass - result = common_subexpression_elimination.CommonSubexpressionEliminationPass()(model) - new_graph = result.model.graph - - # There should be an Identity node in the graph - identity_nodes = [node for node in new_graph if node.op_type == "Identity"] - self.assertTrue( - identity_nodes, "No Identity node inserted for duplicated graph outputs." - ) - - # The outputs should still have the correct names - self.assertEqual(new_graph.outputs[0].name, output_name_0) - self.assertEqual(new_graph.outputs[1].name, output_name_1) diff --git a/onnxscript/optimizer/_optimizer.py b/onnxscript/optimizer/_optimizer.py index 40787c6e74..6044f35424 100644 --- a/onnxscript/optimizer/_optimizer.py +++ b/onnxscript/optimizer/_optimizer.py @@ -51,6 +51,7 @@ def optimize_ir( early_stop=stop_if_no_change, ), onnxscript.ir.passes.common.RemoveUnusedNodesPass(), + onnxscript.ir.passes.common.CommonSubexpressionEliminationPass(), onnxscript.ir.passes.common.LiftConstantsToInitializersPass(), onnxscript.ir.passes.common.LiftSubgraphInitializersToMainGraphPass(), ]