Skip to content
Merged
Show file tree
Hide file tree
Changes from 13 commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
f6563ad
draft tests
titaiwangms May 13, 2025
2f55000
add more tests
titaiwangms May 13, 2025
d567ba1
Merge branch 'main' into titaiwang/cse_pass
titaiwangms May 13, 2025
d1082d1
Update onnxscript/ir/passes/common/common_subexpression_elimination.py
titaiwangms May 13, 2025
ea873e4
Merge branch 'main' into titaiwang/cse_pass
titaiwangms May 16, 2025
017ef27
inplace
titaiwangms May 16, 2025
2a370e4
add recursive function but one test is still failing
titaiwangms May 16, 2025
d490072
Merge branch 'main' into titaiwang/cse_pass
titaiwangms May 27, 2025
706b86a
revert subgraph cse support
titaiwangms May 27, 2025
dcbc08d
add another test for subgraph
titaiwangms May 27, 2025
55d32c7
add the pass to optimization
titaiwangms May 27, 2025
c5cab5b
make repeated contained attributes hashable
titaiwangms May 27, 2025
be2c008
Merge branch 'main' into titaiwang/cse_pass
titaiwangms May 27, 2025
da05efb
delete previous_node and only delete the node
titaiwangms May 28, 2025
ce2bc54
Merge branch 'main' into titaiwang/cse_pass
titaiwangms May 29, 2025
1d4fd53
create and use a stateless function
titaiwangms May 29, 2025
5cfd94e
keep the names of graph output
titaiwangms May 29, 2025
44f6042
address reviews
titaiwangms May 29, 2025
ab212d6
resolve conflict
titaiwangms May 30, 2025
9c2d134
revert
titaiwangms May 30, 2025
6a43bfb
fix lint
titaiwangms May 30, 2025
3b1b19f
separate import common_subexpression_elimination
titaiwangms May 30, 2025
9fd8948
remove cse from optimizer
titaiwangms May 30, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions onnxscript/ir/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -2798,7 +2798,7 @@ def __init__(
model_version: int | None = None,
doc_string: str | None = None,
functions: Sequence[Function] = (),
meta_data_props: dict[str, str] | None = None,
metadata_props: dict[str, str] | None = None,
) -> None:
self.graph: Graph = graph
self.ir_version = ir_version
Expand All @@ -2809,7 +2809,7 @@ def __init__(
self.doc_string = doc_string
self._functions = {func.identifier(): func for func in functions}
self._metadata: _metadata.MetadataStore | None = None
self._metadata_props: dict[str, str] | None = meta_data_props
self._metadata_props: dict[str, str] | None = metadata_props

@property
def functions(self) -> dict[_protocols.OperatorIdentifier, Function]:
Expand Down
4 changes: 4 additions & 0 deletions onnxscript/ir/passes/common/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
"AddInitializersToInputsPass",
"CheckerPass",
"ClearMetadataAndDocStringPass",
"CommonSubexpressionEliminationPass",
"InlinePass",
"LiftConstantsToInitializersPass",
"LiftSubgraphInitializersToMainGraphPass",
Expand All @@ -19,6 +20,9 @@
from onnxscript.ir.passes.common.clear_metadata_and_docstring import (
ClearMetadataAndDocStringPass,
)
from onnxscript.ir.passes.common.common_subexpression_elimination import (
CommonSubexpressionEliminationPass,
)
from onnxscript.ir.passes.common.constant_manipulation import (
AddInitializersToInputsPass,
LiftConstantsToInitializersPass,
Expand Down
107 changes: 107 additions & 0 deletions onnxscript/ir/passes/common/common_subexpression_elimination.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""Eliminate common subexpression in ONNX graphs."""

from __future__ import annotations

__all__ = [
"CommonSubexpressionEliminationPass",
]

import logging

from onnxscript import ir

logger = logging.getLogger(__name__)


class CommonSubexpressionEliminationPass(ir.passes.InPlacePass):
    """In-place pass that removes duplicated computations from the main graph."""

    def call(self, model: ir.Model) -> ir.passes.PassResult:
        """Apply CSE to ``model.graph`` and report whether anything changed.

        The same ``ir.Model`` object that was passed in is returned inside the
        result; only its main graph is handed to the elimination helper.
        """
        # Start from "not modified"; the helper flips the flag when it
        # replaces at least one node.
        changed = _eliminate_common_subexpression(model.graph, False)
        return ir.passes.PassResult(model, modified=changed)


def _eliminate_common_subexpression(graph: ir.Graph, modified: bool) -> bool:
    """Eliminate common subexpressions among the top-level nodes of ``graph``.

    Two nodes are treated as the same expression when they share the operator
    identifier, the number of outputs, the identity of every input value and
    the attribute values. When a node matches an earlier one, it is replaced
    by that earlier node through ``ir.convenience.replace_nodes_and_values``.

    Nodes that carry a graph-valued attribute (control flow ops such as Loop
    and If) are skipped; only the nodes yielded by iterating ``graph`` are
    examined.

    Args:
        graph: The graph to rewrite in place.
        modified: Flag carried in from the caller. It is returned unchanged
            when no node is eliminated.

    Returns:
        ``True`` if ``modified`` was already true or at least one node was
        replaced, otherwise ``False``.
    """
    # Maps a node signature (operator identifier, number of outputs, input
    # ids and attributes) to the first node seen with that signature.
    existing_node_info_to_the_node: dict[
        tuple[
            ir.OperatorIdentifier,
            int,  # len(outputs)
            tuple[int, ...],  # input ids
            tuple[tuple[str, object], ...],  # attributes
        ],
        ir.Node,
    ] = {}
    graph_attribute_types = (ir.AttributeType.GRAPH, ir.AttributeType.GRAPHS)
    sequence_attribute_types = (
        ir.AttributeType.INTS,
        ir.AttributeType.FLOATS,
        ir.AttributeType.STRINGS,
    )
    previous_node = None

    for node in graph:
        # TODO(exporter team): CSE subgraphs.
        # NOTE: Control flow ops like Loop and If are not CSEd. The entries of
        # ``node.attributes`` are attribute wrappers (they expose ``.type``
        # and ``.value``), so a graph-valued attribute has to be recognized by
        # its type; an isinstance check on the wrapper never matches.
        if any(attr.type in graph_attribute_types for attr in node.attributes.values()):
            logger.debug("Skipping control flow op %s", node)
            previous_node = node
            continue

        # Use equality of plain Python values to compare attributes.
        attributes = {}
        for name, attr in node.attributes.items():
            # The attribute value could be directly taken from the original
            # protobuf, so repeated fields are copied into tuples, which also
            # makes them hashable.
            attr_value = attr.value
            if attr.type in sequence_attribute_types:
                attr_value = tuple(attr_value)
            attributes[name] = attr_value

        node_info = (
            node.op_identifier(),
            len(node.outputs),
            # Inputs are compared by identity: the same value object must feed
            # both nodes.
            tuple(id(node_input) for node_input in node.inputs),
            tuple(sorted(attributes.items())),
        )

        existing_node = existing_node_info_to_the_node.get(node_info)
        if existing_node is None:
            # First occurrence of this signature: remember it.
            existing_node_info_to_the_node[node_info] = node
            previous_node = node
            continue

        # An earlier node has the same operator, number of outputs, inputs
        # and attributes, so this node is replaced with the existing node.
        modified = True
        ir.convenience.replace_nodes_and_values(
            graph,
            insertion_point=previous_node or node,
            old_nodes=[node],
            new_nodes=[existing_node],
            old_values=node.outputs,
            new_values=existing_node.outputs,
        )
        previous_node = existing_node
        logger.debug("Reusing node %s", existing_node)
    return modified
245 changes: 245 additions & 0 deletions onnxscript/ir/passes/common/common_subexpression_elimination_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,245 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from __future__ import annotations

import unittest

import numpy as np
import onnxruntime as ort

from onnxscript import FLOAT, ir, script
from onnxscript import opset18 as op
from onnxscript.ir.passes.common import common_subexpression_elimination


class TestCommonSubexpressionEliminationPass(unittest.TestCase):
    """Tests for CommonSubexpressionEliminationPass.

    Each test builds a small model with ``@script``, runs the pass and checks
    both the node counts and the numerical outputs through ``check_graph``.
    """

    def check_graph(self, model: ir.Model, inputs: list[np.ndarray], delta_nodes: list[int]):
        """Check if the model applied the CSE pass correctly.

        Args:
            model: The model to check. The pass is applied to this object.
            inputs: One array per graph input, matched positionally with
                ``model.graph.inputs``. Only the shapes are used: fresh random
                float32 data of the same shape is fed to both sessions.
            delta_nodes: The expected decrease in the number of nodes, one
                entry per graph yielded by ``model.graphs()`` (to support
                subgraphs in the future).

        Raises:
            AssertionError: If the model does not match the expected number of
                nodes or the outputs differ before and after the pass.
        """
        self.assertEqual(len(list(model.graphs())), len(delta_nodes))
        # Record all results from the original model.
        # 1. model graph node counts
        original_graphs_node_count = np.array([graph.num_nodes() for graph in model.graphs()])
        model_proto = ir.serde.serialize_model(model)

        # 2. model outputs
        ort_inputs = {
            graph_input.name: np.random.rand(*array.shape).astype(np.float32)
            for graph_input, array in zip(model.graph.inputs, inputs)
        }
        original_model_session = ort.InferenceSession(model_proto.SerializeToString())
        original_model_results = original_model_session.run(None, ort_inputs)

        result = common_subexpression_elimination.CommonSubexpressionEliminationPass()(model)

        result_graphs_node_count = np.array([graph.num_nodes() for graph in model.graphs()])
        # Check if the number of nodes in the model is correct.
        # assert_array_equal reports the mismatching counts on failure.
        np.testing.assert_array_equal(
            original_graphs_node_count, np.add(result_graphs_node_count, delta_nodes)
        )
        # The pass must report a modification exactly when some graph shrank.
        self.assertEqual(
            result.modified,
            bool(np.any(original_graphs_node_count > result_graphs_node_count)),
        )

        result_proto = ir.serde.serialize_model(result.model)
        result_session = ort.InferenceSession(result_proto.SerializeToString())
        result_results = result_session.run(None, ort_inputs)

        # Check if the models produce the same outputs with the same inputs.
        self.assertEqual(len(original_model_results), len(result_results))
        for original_output, result_output in zip(original_model_results, result_results):
            np.testing.assert_allclose(original_output, result_output, rtol=1e-5, atol=1e-5)

    def test_two_branches_with_the_same_operations_is_csed(self):
        """Test if two branches with the same operations are CSEd.

        PyTorch equivalent:

            def f(x):
                a = x.cos()
                b = x.cos()
                c = a + a
                d = b + b
                return c + d

            x = torch.randn(2, 2)
        """

        @script()
        def test_model(x: FLOAT[2, 2]) -> FLOAT[2, 2]:
            a = op.Cos(x)
            b = op.Cos(x)
            c = a + a
            d = b + b
            return c + d

        model_proto = test_model.to_model_proto()
        model = ir.serde.deserialize_model(model_proto)

        self.check_graph(model, [np.random.rand(2, 2)], delta_nodes=[2])

    def test_more_operations_in_two_branches_with_the_same_operations_is_csed(self):
        """Test if two longer branches with the same operations are CSEd.

        PyTorch equivalent:

            def f(x):
                a = x.cos().sin()
                b = x.cos().sin()
                c = a + a
                d = b + b
                return c + d

            x = torch.randn(1)
        """

        @script()
        def test_model(x: FLOAT[1]) -> FLOAT[1]:
            a = op.Sin(op.Cos(x))
            b = op.Sin(op.Cos(x))
            c = a + a
            d = b + b
            return c + d

        model_proto = test_model.to_model_proto()
        model = ir.serde.deserialize_model(model_proto)
        self.check_graph(model, [np.random.rand(1)], delta_nodes=[3])

    def test_multiple_same_ops_with_attributes_are_csed(self):
        """Test if multiple same ops are CSEd.

        PyTorch equivalent:

            def f(x):
                a = x.sum()
                b = x.sum()
                c = x.sum()
                d = x.sum()
                return a + b + c + d

            x = torch.randn(2, 2)
        """

        @script()
        def test_model(x: FLOAT[2, 2]) -> FLOAT[2, 2]:
            a = op.ReduceSum(x, keepdims=False)
            b = op.ReduceSum(x, keepdims=False)
            c = op.ReduceSum(x, keepdims=False)
            d = op.ReduceSum(x, keepdims=False)
            return a + b + c + d

        model_proto = test_model.to_model_proto()
        model = ir.serde.deserialize_model(model_proto)
        self.check_graph(model, [np.random.rand(2, 2)], delta_nodes=[3])

    def test_the_ops_with_the_same_inputs_but_different_attributes_are_not_csed(self):
        """Test if the ops with the same inputs but different attributes are not CSEd.

        PyTorch equivalent:

            def f(x):
                a = x.sum()
                b = x.sum(keepdims=True)
                return a + b

            x = torch.randn(2, 2)
        """

        @script()
        def test_model(x: FLOAT[2, 2]) -> FLOAT[2, 2]:
            a = op.ReduceSum(x, keepdims=False)
            b = op.ReduceSum(x, keepdims=True)
            return a + b

        model_proto = test_model.to_model_proto()
        model = ir.serde.deserialize_model(model_proto)
        self.check_graph(model, [np.random.rand(2, 2)], delta_nodes=[0])

    def test_control_flow_if_ops_are_not_csed_as_graph_attr_is_not_matched(self):
        """Test if control flow ops are not CSEd.

        PyTorch equivalent:

            def f(a, b):
                rank = a.rank()
                if rank == 2:
                    result1 = a - b
                else:
                    result1 = a + b
                if rank == 2:
                    result2 = a - b
                else:
                    result2 = a + b
                return result1 + result2

            x = torch.randn(2, 2)
        """

        @script()
        def test_model(a: FLOAT[2, 2], b: FLOAT[2, 2]) -> FLOAT[2, 2]:
            rank = op.Size(op.Shape(a))
            if rank == 2:
                result1 = a - b
            else:
                result1 = a + b
            if rank == 2:
                result2 = a - b
            else:
                result2 = a + b
            return result1 + result2

        model_proto = test_model.to_model_proto()
        model = ir.serde.deserialize_model(model_proto)
        # One entry for the main graph plus one per If branch subgraph.
        self.check_graph(
            model, [np.random.rand(2, 2), np.random.rand(2, 2)], delta_nodes=[0, 0, 0, 0, 0]
        )

    def test_the_nodes_following_control_flow_ops_are_csed(self):
        """Test if the nodes following control flow ops are CSEd.

        PyTorch equivalent:

            def f(a, b):
                rank = a.rank()
                if rank == 2:
                    x = a - b
                else:
                    x = a + b
                a = x.cos().sin()
                b = x.cos().sin()
                c = a + a
                d = b + b
                return c + d

            x = torch.randn(2, 2)
        """

        @script()
        def test_model(a: FLOAT[2, 2], b: FLOAT[2, 2]) -> FLOAT[2, 2]:
            rank = op.Size(op.Shape(a))
            if rank == 2:
                x = a - b
            else:
                x = a + b
            a = op.Sin(op.Cos(x))
            b = op.Sin(op.Cos(x))
            c = a + a
            d = b + b
            return c + d

        model_proto = test_model.to_model_proto()
        model = ir.serde.deserialize_model(model_proto)
        # Only the main graph is expected to shrink; the two If branch
        # subgraphs stay untouched.
        self.check_graph(
            model, [np.random.rand(2, 2), np.random.rand(2, 2)], delta_nodes=[3, 0, 0]
        )
2 changes: 1 addition & 1 deletion onnxscript/ir/serde.py
Original file line number Diff line number Diff line change
Expand Up @@ -512,7 +512,7 @@ def deserialize_model(proto: onnx.ModelProto) -> _core.Model:
model_version=_get_field(proto, "model_version"),
doc_string=_get_field(proto, "doc_string"),
functions=functions,
meta_data_props=deserialize_metadata_props(proto.metadata_props),
metadata_props=deserialize_metadata_props(proto.metadata_props),
)

# Handle experimental value info for functions created by the dynamo exporter in IR version 9
Expand Down
Loading
Loading