Skip to content
Closed
Show file tree
Hide file tree
Changes from 2 commits
Commits
Show all changes
18 commits
Select commit Hold shift + click to select a range
6cfd6fd
[Rewriter] Implement zero bias removal for Conv operations and relate…
whyvineet Sep 9, 2025
48037a8
Merge branch 'microsoft:main' into remove-optional-bias
whyvineet Sep 9, 2025
fd028ee
Merge branch 'main' of github-personal:whyvineet/onnxscript into remo…
whyvineet Sep 10, 2025
3742e7b
[Rewriter] Enhance zero bias removal for Conv, ConvTranspose, Gemm, a…
whyvineet Sep 10, 2025
8bfa65f
Refactor zero bias removal tests to use helper function and improve s…
whyvineet Sep 10, 2025
6fc3ca7
Merge branch 'main' of github-personal:whyvineet/onnxscript into remo…
whyvineet Sep 10, 2025
b93a56c
Refactor test cases for zero bias removal to improve readability and …
whyvineet Sep 11, 2025
e322625
Remove duplicate import of _fuse_batchnorm in rewriter module
whyvineet Sep 11, 2025
121360e
Refactor zero bias removal logic to streamline input handling and enh…
whyvineet Sep 11, 2025
ee7dafa
Refactor Gemm operation pattern and check method to align with zero b…
whyvineet Sep 14, 2025
8ef6c41
Enhance zero bias removal logic to filter bias parameters and preserv…
whyvineet Sep 14, 2025
2b9dda4
Refactor bias removal logic to directly use operation inputs, improvi…
whyvineet Sep 14, 2025
153b4e7
Remove redundant domain attribute from operation inputs in _RemoveZer…
whyvineet Sep 15, 2025
ce64fb7
Merge branch 'main' into remove-optional-bias
justinchuby Sep 16, 2025
a94f8b9
Merge HEAD, branch 'remove-optional-bias' of github-personal:whyvinee…
whyvineet Sep 21, 2025
86de85f
Refactor IR value creation in tests to use `ir.Value` for consistency…
whyvineet Sep 21, 2025
d62eafb
Revert "Refactor IR value creation in tests to use `ir.Value` for con…
whyvineet Sep 21, 2025
d4f73dd
Enhance attribute comparison in optimization tests to handle list vs …
whyvineet Sep 21, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 8 additions & 10 deletions onnxscript/rewriter/rules/common/_remove_zero_bias.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,15 @@ def rewrite(self, op: ir.tape.Tape, out: ir.Value, **_) -> ir.Value:
"""Remove the bias input from the operation."""
node = out.producer()

original_inputs = list(node.inputs)
inputs_without_bias = original_inputs[:-1]
# Filter out the bias parameter and keep all other inputs
inputs = []
for param_name, param_value in _.items():
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I understand now... There was just a misunderstanding on my part!

if param_name != "b": # 'b' is the bias parameter
inputs.append(param_value)

return op.op(
self.op_type,
inputs=inputs_without_bias,
inputs=inputs,
attributes=node.attributes,
domain=node.domain,
)
Expand Down Expand Up @@ -108,13 +111,8 @@ class RemoveZeroBiasFromGemm(_RemoveZeroBiasBase):

op_type: ClassVar = "Gemm"

def pattern(self, op: ir.tape.Tape, a: ir.Value, b: ir.Value, c: ir.Value) -> ir.Value:
return op.Gemm(a, b, c, _outputs=["out"])

def check(self, context, a: ir.Value, b: ir.Value, c: ir.Value, **_) -> MatchResult:
"""Check if the bias (c parameter) is present and is all zeros."""
del context # Unused
return self._check_bias_is_zero(c)
def pattern(self, op: ir.tape.Tape, x: ir.Value, w: ir.Value, b: ir.Value) -> ir.Value:
    """Match a 3-input ``Gemm(x, w, b)`` node, binding its output as ``out``.

    NOTE(review): the parameter name ``b`` appears to be significant — the shared
    rewrite() filters captured values by ``param_name != "b"`` to drop the bias,
    so renaming this parameter would break bias removal. Confirm against the
    rewriter framework's name-binding semantics.
    """
    return op.Gemm(x, w, b, _outputs=["out"])


# Create rule instances
Expand Down
166 changes: 166 additions & 0 deletions onnxscript/rewriter/rules/common/_remove_zero_bias_test.py
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe you can pass some attributes when testing to check that every info is correctly transferred (e.g. stride, transA...)

Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
"""Tests for removing zero bias from Conv and related operations."""

import unittest
from typing import Optional

import onnx_ir as ir

Expand All @@ -21,11 +22,19 @@ def _apply_rule_and_check_optimization(
expected_count: int,
target_op_type: str,
expected_inputs_after: int,
expected_attributes: Optional[dict] = None,
) -> None:
"""Helper function to test bias removal rules."""
# Make a copy for comparison
original_model = ir.from_proto(ir.to_proto(model))

# Get original attributes for comparison
original_target_node = None
for node in original_model.graph:
if node.op_type == target_op_type:
original_target_node = node
break

# Apply the rule
count = rule.apply_to_model(model)

Expand All @@ -45,6 +54,29 @@ def _apply_rule_and_check_optimization(
f"got {len(target_node.inputs)}"
)

# Check that attributes are preserved if the rule was applied
if expected_count > 0 and original_target_node is not None:
# All original attributes should be preserved
for attr_name, attr_value in original_target_node.attributes.items():
assert attr_name in target_node.attributes, f"Attribute {attr_name} was lost"
original_value = attr_value.value
new_value = target_node.attributes[attr_name].value
assert new_value == original_value, (
f"Attribute {attr_name} value changed from {original_value} to {new_value}"
)

# Check specific expected attributes if provided
if expected_attributes:
for attr_name, expected_value in expected_attributes.items():
assert attr_name in target_node.attributes, (
f"Expected attribute {attr_name} not found"
)
actual_attr = target_node.attributes[attr_name]
actual_value = actual_attr.value
assert actual_value == expected_value, (
f"Expected attribute {attr_name} to be {expected_value}, got {actual_value}"
)

# Compare outputs to ensure correctness (only for supported input types)
if expected_count > 0:
try:
Expand Down Expand Up @@ -231,6 +263,140 @@ def test_remove_zero_bias_from_qlinear_conv(self):
expected_inputs_after=8,
)

def test_remove_zero_bias_from_conv_with_attributes(self):
    """Conv: removing an all-zero bias must leave every node attribute intact."""
    # A Conv node carrying a zero-valued bias and a full attribute set,
    # constructed from ONNX text format.
    conv_model = ir.from_onnx_text(
        """
<ir_version: 7, opset_import: [ "" : 17]>
agraph (float[1, 2, 6, 6] x) => (float[1, 2, 2, 2] y)
{
weight = Constant <value = float[2, 2, 3, 3] {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36}>()
bias = Constant <value = float[2] {0, 0}>()
y = Conv <dilations = [1, 1], group = 1, kernel_shape = [3, 3], pads = [0, 0, 0, 0], strides = [2, 2]> (x, weight, bias)
}
"""
    )

    # Attributes the rewritten Conv node is still expected to carry.
    preserved_attributes = dict(
        dilations=[1, 1],
        group=1,
        kernel_shape=[3, 3],
        pads=[0, 0, 0, 0],
        strides=[2, 2],
    )

    _apply_rule_and_check_optimization(
        conv_model,
        remove_zero_bias_from_conv_rule,
        expected_count=1,
        target_op_type="Conv",
        expected_inputs_after=2,
        expected_attributes=preserved_attributes,
    )

def test_remove_zero_bias_from_conv_transpose_with_attributes(self):
    """ConvTranspose: removing an all-zero bias must leave every attribute intact."""
    # A ConvTranspose node carrying a zero-valued bias and a full attribute set,
    # constructed from ONNX text format.
    deconv_model = ir.from_onnx_text(
        """
<ir_version: 7, opset_import: [ "" : 17]>
agraph (float[1, 2, 2, 2] x) => (float[1, 2, 6, 6] y)
{
weight = Constant <value = float[2, 2, 3, 3] {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36}>()
bias = Constant <value = float[2] {0, 0}>()
y = ConvTranspose <dilations = [1, 1], group = 1, kernel_shape = [3, 3], output_padding = [0, 0], pads = [0, 0, 0, 0], strides = [2, 2]> (x, weight, bias)
}
"""
    )

    # Attributes the rewritten ConvTranspose node is still expected to carry.
    preserved_attributes = dict(
        dilations=[1, 1],
        group=1,
        kernel_shape=[3, 3],
        output_padding=[0, 0],
        pads=[0, 0, 0, 0],
        strides=[2, 2],
    )

    _apply_rule_and_check_optimization(
        deconv_model,
        remove_zero_bias_from_conv_transpose_rule,
        expected_count=1,
        target_op_type="ConvTranspose",
        expected_inputs_after=2,
        expected_attributes=preserved_attributes,
    )

def test_remove_zero_bias_from_gemm_with_attributes(self):
    """Gemm: removing an all-zero C input must leave alpha/beta/trans* intact."""
    # A Gemm node whose C input is all zeros, with non-default attributes,
    # constructed from ONNX text format.
    gemm_model = ir.from_onnx_text(
        """
<ir_version: 7, opset_import: [ "" : 17]>
agraph (float[2, 3] a) => (float[2, 4] y)
{
b = Constant <value = float[4, 3] {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}>()
c = Constant <value = float[4] {0, 0, 0, 0}>()
y = Gemm <alpha = 2.0, beta = 1.0, transA = 0, transB = 1> (a, b, c)
}
"""
    )

    # Attributes the rewritten Gemm node is still expected to carry.
    preserved_attributes = dict(alpha=2.0, beta=1.0, transA=0, transB=1)

    _apply_rule_and_check_optimization(
        gemm_model,
        remove_zero_bias_from_gemm_rule,
        expected_count=1,
        target_op_type="Gemm",
        expected_inputs_after=2,
        expected_attributes=preserved_attributes,
    )

def test_remove_zero_bias_from_qlinear_conv_with_attributes(self):
    """QLinearConv: removing an all-zero int32 bias must leave attributes intact."""
    # A QLinearConv node with a zero-valued int32 bias (9th input) and a full
    # attribute set, constructed from ONNX text format.
    qconv_model = ir.from_onnx_text(
        """
<ir_version: 7, opset_import: [ "" : 17]>
agraph (uint8[1, 2, 6, 6] x) => (uint8[1, 2, 2, 2] y)
{
x_scale = Constant <value = float {0.1}>()
x_zero_point = Constant <value = uint8 {128}>()
weight = Constant <value = uint8[2, 2, 3, 3] {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36}>()
w_scale = Constant <value = float {0.05}>()
w_zero_point = Constant <value = uint8 {64}>()
y_scale = Constant <value = float {0.2}>()
y_zero_point = Constant <value = uint8 {192}>()
bias = Constant <value = int32[2] {0, 0}>()
y = QLinearConv <dilations = [1, 1], group = 1, kernel_shape = [3, 3], pads = [0, 0, 0, 0], strides = [2, 2]> (x, x_scale, x_zero_point, weight, w_scale, w_zero_point, y_scale, y_zero_point, bias)
}
"""
    )

    # Attributes the rewritten QLinearConv node is still expected to carry.
    preserved_attributes = dict(
        dilations=[1, 1],
        group=1,
        kernel_shape=[3, 3],
        pads=[0, 0, 0, 0],
        strides=[2, 2],
    )

    _apply_rule_and_check_optimization(
        qconv_model,
        remove_zero_bias_from_qlinear_conv_rule,
        expected_count=1,
        target_op_type="QLinearConv",
        expected_inputs_after=8,
        expected_attributes=preserved_attributes,
    )

def test_qlinear_conv_with_non_zero_bias_unchanged(self):
"""Test that QLinearConv with non-zero bias is not modified."""
# Create a QLinearConv with non-zero bias using ONNX text format
Expand Down