diff --git a/onnxscript/rewriter/__init__.py b/onnxscript/rewriter/__init__.py index fc000dc176..75f43bf3ea 100644 --- a/onnxscript/rewriter/__init__.py +++ b/onnxscript/rewriter/__init__.py @@ -41,6 +41,7 @@ _min_max_to_clip, _no_op, _redundant_scatter_nd, + _remove_optional_bias, ) _ModelProtoOrIr = TypeVar("_ModelProtoOrIr", onnx.ModelProto, ir.Model) @@ -55,6 +56,7 @@ *_redundant_scatter_nd.rules, *_fuse_pad_into_conv.rules, *_fuse_batchnorm.rules, + *_remove_optional_bias.rules, ) diff --git a/onnxscript/rewriter/rules/common/__init__.py b/onnxscript/rewriter/rules/common/__init__.py index 14ed3587f3..76d9e4f4b0 100644 --- a/onnxscript/rewriter/rules/common/__init__.py +++ b/onnxscript/rewriter/rules/common/__init__.py @@ -34,6 +34,10 @@ "normalize_pad_format_conv_integer_rule", "normalize_pad_format_conv_rule", "one_reshape_matmul_reshape_rule", + "remove_optional_bias_from_conv_rule", + "remove_optional_bias_from_conv_transpose_rule", + "remove_optional_bias_from_gemm_rule", + "remove_optional_bias_from_qlinear_conv_rule", "reshape_reshape_rule", "slice_split_rule", "squeeze_reshape_1d_rule", @@ -121,3 +125,9 @@ no_op_dynamic_scatter_nd_rule, no_op_static_scatter_nd_rule, ) +from onnxscript.rewriter.rules.common._remove_optional_bias import ( + remove_optional_bias_from_conv_rule, + remove_optional_bias_from_conv_transpose_rule, + remove_optional_bias_from_gemm_rule, + remove_optional_bias_from_qlinear_conv_rule, +) diff --git a/onnxscript/rewriter/rules/common/_remove_optional_bias.py b/onnxscript/rewriter/rules/common/_remove_optional_bias.py new file mode 100644 index 0000000000..ead8a73eab --- /dev/null +++ b/onnxscript/rewriter/rules/common/_remove_optional_bias.py @@ -0,0 +1,123 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. +"""Remove optional bias when it is all zero from Conv, ConvTranspose, Gemm and QLinearConv operations.""" + +from __future__ import annotations + +from typing import ClassVar + +import numpy as np + +from onnxscript import ir +from onnxscript.rewriter._basics import MatchResult +from onnxscript.rewriter._rewrite_rule import RewriteRuleClassBase, RewriteRuleSet + + +class _RemoveOptionalBias(RewriteRuleClassBase): + def rewrite(self, op: ir.tape.Tape, out: ir.Value, **_) -> ir.Value: + node = out.producer() + + return op.op( + self.op_type, + inputs=node.inputs[:-1], + attributes=node.attributes, + ) + + def check(self, context, b: ir.Value, **_) -> MatchResult: + """Condition to check if we need to replace the pattern. + + The pattern is applied only when the bias is all zeros. The bias should be + a constant value (i.e., provided by Constant nodes or initializers). + + Returns: + MatchResult: + Success if we need to replace the pattern, Failure otherwise. + """ + del context # Unused + check_result = MatchResult() + + # Check if bias is a constant/initializer + bias_tensor = ir.convenience.get_const_tensor(b) + if bias_tensor is None: + return check_result.fail("Bias is not a constant/initializer.") + + # Check if bias is all zeros + bias_array = bias_tensor.numpy() + if not np.equal(bias_array, 0.0).all(): + return check_result.fail("Bias is not all zeros.") + + return check_result + + +class RemoveOptionalBiasFromConv(_RemoveOptionalBias): + """Remove zero bias from Conv operation.""" + + op_type: ClassVar[str] = "Conv" + + def pattern(self, op: ir.tape.Tape, x: ir.Value, w: ir.Value, b: ir.Value) -> ir.Value: + return op.Conv(x, w, b, _outputs=["out"]) + + +class RemoveOptionalBiasFromConvTranspose(_RemoveOptionalBias): + """Remove zero bias from ConvTranspose operation.""" + + op_type: ClassVar[str] = "ConvTranspose" + + def pattern(self, op: ir.tape.Tape, x: ir.Value, w: ir.Value, b: ir.Value) -> ir.Value: + return op.ConvTranspose(x, w, b, _outputs=["out"]) + + +class RemoveOptionalBiasFromQLinearConv(_RemoveOptionalBias): + """Remove zero bias from QLinearConv operation.""" + + op_type: ClassVar[str] = "QLinearConv" + + def pattern( + self, + op: ir.tape.Tape, + x, + x_scale, + x_zero_point, + w, + w_scale, + w_zero_point, + y_scale, + y_zero_point, + b: ir.Value, + ) -> ir.Value: + return op.QLinearConv( + x, + x_scale, + x_zero_point, + w, + w_scale, + w_zero_point, + y_scale, + y_zero_point, + b, + _outputs=["out"], + ) + + +class RemoveOptionalBiasFromGemm(_RemoveOptionalBias): + """Remove zero bias from Gemm operation.""" + + op_type: ClassVar[str] = "Gemm" + + def pattern(self, op: ir.tape.Tape, x: ir.Value, w: ir.Value, b: ir.Value) -> ir.Value: + return op.Gemm(x, w, b, _outputs=["out"]) + + +remove_optional_bias_from_conv_rule = RemoveOptionalBiasFromConv().rule() +remove_optional_bias_from_conv_transpose_rule = RemoveOptionalBiasFromConvTranspose().rule() +remove_optional_bias_from_qlinear_conv_rule = RemoveOptionalBiasFromQLinearConv().rule() +remove_optional_bias_from_gemm_rule = RemoveOptionalBiasFromGemm().rule() + +rules = RewriteRuleSet( + [ + remove_optional_bias_from_conv_rule, + remove_optional_bias_from_conv_transpose_rule, + remove_optional_bias_from_qlinear_conv_rule, + remove_optional_bias_from_gemm_rule, + ] +) diff --git a/onnxscript/rewriter/rules/common/_remove_optional_bias_test.py b/onnxscript/rewriter/rules/common/_remove_optional_bias_test.py new file mode 100644 index 0000000000..4349d7aae3 --- /dev/null +++ b/onnxscript/rewriter/rules/common/_remove_optional_bias_test.py @@ -0,0 +1,237 @@ +# Copyright (c) Microsoft Corporation. +# Licensed under the MIT License. + +import unittest + +import numpy as np +import onnx +import onnx_ir as ir +from onnx_ir.passes.common import onnx_checker + +from onnxscript.rewriter import MatchingTracer, MatchStatus, RewriteRule, testing +from onnxscript.rewriter.rules.common import _remove_optional_bias +from onnxscript.rewriter.rules.common._remove_optional_bias import ( + remove_optional_bias_from_conv_rule, + remove_optional_bias_from_conv_transpose_rule, + remove_optional_bias_from_gemm_rule, + remove_optional_bias_from_qlinear_conv_rule, +) + + +class _RemoveOptionalBiasTestBase(unittest.TestCase): + @property + def rng(self): + return np.random.default_rng(20251016) + + def clone_model(self, model: ir.Model) -> ir.Model: + return ir.from_proto(ir.to_proto(model)) + + def _get_test_model( + self, + op_type: str, + input_shape: ir.Shape, + weight_shape: ir.Shape, + zero_bias: bool, + attributes=None, + ): + tape = ir.tape.Tape() + bias_shape = weight_shape[1] if op_type == "ConvTranspose" else weight_shape[0] + output_shape = ir.Shape(("?",) * input_shape.rank()) + + x = ir.val("X", shape=input_shape, type=ir.TensorType(ir.DataType.FLOAT)) + + w = tape.initializer( + ir.tensor(self.rng.uniform(-0.5, 0.5, weight_shape).astype(np.float32), name="W") + ) + + if zero_bias: + bias = np.zeros(bias_shape, dtype=np.float32) + else: + bias = self.rng.uniform(-0.5, 0.5, bias_shape).astype(np.float32) + + b = tape.initializer(ir.tensor(bias, name="B")) + y = tape.op( + op_type, + inputs=[x, w, b], + attributes=attributes, + output=ir.val("Y", shape=output_shape, type=ir.TensorType(ir.DataType.FLOAT)), + ) + + # Build the model + ir_model = ir.Model( + ir.Graph( + inputs=[x], + outputs=[y], + nodes=tape.nodes, + initializers=tape.initializers, + opset_imports={"": 20}, + name="test_model", + ), + ir_version=10, + ) + onnx_checker.CheckerPass(True)(ir_model) + return ir_model + + def run_test( + self, + base_model: ir.Model, + input_shape: tuple, + input_dtype=np.float32, + ): + updated_model = self.clone_model(base_model) + count = _remove_optional_bias.rules.apply_to_model(updated_model) + + # Check rule is applied + self.assertEqual(count, 1) + + # Check number of inputs is reduced + self.assertEqual( + len(updated_model.graph[0].inputs), len(base_model.graph[0].inputs) - 1 + ) + + # Prepare inputs + inputs = (self.rng.random(input_shape).astype(input_dtype),) + + # Check inference + testing.assert_numerically_equal(base_model, updated_model, inputs) + + # Validate serialized model + output_model_proto = ir.serde.serialize_model(updated_model) + onnx.checker.check_model(output_model_proto, full_check=True) + + def run_failed_condition_test( + self, + base_model: ir.Model, + rewrite_rule: RewriteRule, + expected_message: str, + ): + onnx_checker.CheckerPass(True)(base_model) + + updated_model = self.clone_model(base_model) + tracer = MatchingTracer() + count = rewrite_rule.apply_to_model(updated_model, tracer=tracer) + + # Check that the model is unchanged + self.assertEqual(count, 0) + + # Check that the error message is the expected one + tracer_match = tracer.best_matches_map[rewrite_rule][0] + self.assertEqual(tracer_match.status.value, MatchStatus.CONDITION_FAILED) + self.assertRegex(tracer_match.match_result.reason, expected_message) + + +class RemoveOptionalBiasGemmTest(_RemoveOptionalBiasTestBase): + def test_successful_remove_optional_bias_gemm(self): + input_shape = (512, 256) + base_model = self._get_test_model( + op_type="Gemm", + input_shape=ir.Shape(input_shape), + weight_shape=ir.Shape((64, 256)), + zero_bias=True, + attributes={"transB": 1}, + ) + self.run_test(base_model, input_shape) + + def test_fail_remove_optional_bias_gemm(self): + input_shape = (512, 256) + base_model = self._get_test_model( + op_type="Gemm", + input_shape=ir.Shape(input_shape), + weight_shape=ir.Shape((64, 256)), + zero_bias=False, + attributes={"transB": 1}, + ) + self.run_failed_condition_test( + base_model, remove_optional_bias_from_gemm_rule, "Bias is not all zeros." + ) + + +class RemoveOptionalBiasGonvTest(_RemoveOptionalBiasTestBase): + def test_successful_remove_optional_bias_conv(self): + input_shape = (1, 3, 32, 32) + base_model = self._get_test_model( + op_type="Conv", + input_shape=ir.Shape(input_shape), + weight_shape=ir.Shape((16, 3, 3, 3)), + zero_bias=True, + attributes={"strides": (2, 2)}, + ) + self.run_test(base_model, input_shape) + + def test_fail_remove_optional_bias_conv(self): + input_shape = (1, 3, 32, 32) + base_model = self._get_test_model( + op_type="Conv", + input_shape=ir.Shape(input_shape), + weight_shape=ir.Shape((16, 3, 3, 3)), + zero_bias=False, + ) + self.run_failed_condition_test( + base_model, remove_optional_bias_from_conv_rule, "Bias is not all zeros." + ) + + +class RemoveOptionalBiasGonvTransposeTest(_RemoveOptionalBiasTestBase): + def test_successful_remove_optional_bias_conv_transpose(self): + input_shape = (1, 3, 32, 32) + base_model = self._get_test_model( + op_type="ConvTranspose", + input_shape=ir.Shape(input_shape), + weight_shape=ir.Shape((3, 16, 3, 3)), + zero_bias=True, + ) + self.run_test(base_model, input_shape) + + def test_fail_remove_optional_bias_conv_transpose(self): + input_shape = (1, 3, 32, 32) + base_model = self._get_test_model( + op_type="ConvTranspose", + input_shape=ir.Shape(input_shape), + weight_shape=ir.Shape((3, 16, 3, 3)), + zero_bias=False, + ) + self.run_failed_condition_test( + base_model, remove_optional_bias_from_conv_transpose_rule, "Bias is not all zeros." + ) + + +class RemoveOptionalBiasQLinearConvTest(_RemoveOptionalBiasTestBase): + def _get_test_model(self, zero_bias): + if zero_bias: + bias = np.zeros((16,), dtype=np.int32) + else: + bias = self.rng.uniform(-5, 5, (16,)).astype(np.int32) + + w = ir.tensor(self.rng.uniform(-5, 5, (16, 3, 3, 3)).astype(np.uint8), name="W") + b = ir.tensor(bias, name="B") + + model = ir.from_onnx_text( + """ + < ir_version: 10, opset_import: ["" : 20] > + test_model (uint8[N, 3, 32, 32] X) => (uint8 [N, ?, ?, ?] Y) + + { + Y = QLinearConv(X, x_scale, x_zero_point, W, w_scale, w_zero_point, y_scale, y_zero_point, B) + } + """, + initializers=[w, b], + ) + onnx_checker.CheckerPass(True)(model) + return model + + def test_successful_remove_optional_bias_qlinear_conv(self): + input_shape = (1, 3, 32, 32) + base_model = self._get_test_model(zero_bias=True) + self.run_test(base_model, input_shape, np.uint8) + + def test_fail_remove_optional_bias_qlinear_conv(self): + base_model = self._get_test_model(zero_bias=False) + self.run_failed_condition_test( + base_model, remove_optional_bias_from_qlinear_conv_rule, "Bias is not all zeros." + ) + + +if __name__ == "__main__": + unittest.main()