-
Notifications
You must be signed in to change notification settings - Fork 94
[Rewriter] Implement zero bias removal for Conv operations and related rules #2555
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 9 commits
6cfd6fd
48037a8
fd028ee
3742e7b
8bfa65f
6fc3ca7
b93a56c
e322625
121360e
ee7dafa
8ef6c41
2b9dda4
153b4e7
ce64fb7
a94f8b9
86de85f
d62eafb
d4f73dd
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,133 @@ | ||
| # Copyright (c) Microsoft Corporation. | ||
|
||
| # Licensed under the MIT License. | ||
| """Remove optional bias when it is all zero from Conv and related operations.""" | ||
|
|
||
| from __future__ import annotations | ||
|
|
||
| from typing import ClassVar | ||
|
|
||
| import numpy as np | ||
|
|
||
| from onnxscript import ir | ||
| from onnxscript.ir import convenience | ||
| from onnxscript.rewriter._basics import MatchResult | ||
| from onnxscript.rewriter._rewrite_rule import RewriteRuleClassBase, RewriteRuleSet | ||
|
|
||
|
|
||
| class _RemoveZeroBiasBase(RewriteRuleClassBase): | ||
| """Base class for removing zero bias from operations.""" | ||
|
|
||
| def rewrite(self, op: ir.tape.Tape, out: ir.Value, **_) -> ir.Value: | ||
| """Remove the bias input from the operation.""" | ||
| node = out.producer() | ||
|
|
||
| original_inputs = list(node.inputs) | ||
| inputs_without_bias = original_inputs[:-1] | ||
|
||
|
|
||
| return op.op( | ||
| self.op_type, | ||
| inputs=inputs_without_bias, | ||
| attributes=node.attributes, | ||
| domain=node.domain, | ||
justinchuby marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| ) | ||
|
|
||
| def _check_bias_is_zero(self, bias_value: ir.Value) -> MatchResult: | ||
| """Check if the bias value is present and is all zeros.""" | ||
| check_result = MatchResult() | ||
|
|
||
| # Check if bias is a constant/initializer | ||
| bias_tensor = convenience.get_const_tensor(bias_value) | ||
| if bias_tensor is None: | ||
| return check_result.fail("Bias is not a constant/initializer.") | ||
|
|
||
| # Check if bias is all zeros | ||
| bias_array = bias_tensor.numpy() | ||
| if not np.allclose(bias_array, 0.0, atol=1e-8): | ||
| return check_result.fail("Bias is not all zeros.") | ||
|
|
||
| return check_result | ||
|
|
||
| def check(self, context, x: ir.Value, w: ir.Value, b: ir.Value, **_) -> MatchResult: | ||
| """Check if the bias is present and is all zeros.""" | ||
| del context # Unused | ||
| return self._check_bias_is_zero(b) | ||
|
|
||
|
|
||
| class RemoveZeroBiasFromConv(_RemoveZeroBiasBase): | ||
| """Remove zero bias from Conv operations.""" | ||
|
|
||
| op_type: ClassVar = "Conv" | ||
|
|
||
| def pattern(self, op: ir.tape.Tape, x: ir.Value, w: ir.Value, b: ir.Value) -> ir.Value: | ||
| return op.Conv(x, w, b, _outputs=["out"]) | ||
|
|
||
|
|
||
| class RemoveZeroBiasFromConvTranspose(_RemoveZeroBiasBase): | ||
| """Remove zero bias from ConvTranspose operations.""" | ||
|
|
||
| op_type: ClassVar = "ConvTranspose" | ||
|
|
||
| def pattern(self, op: ir.tape.Tape, x: ir.Value, w: ir.Value, b: ir.Value) -> ir.Value: | ||
| return op.ConvTranspose(x, w, b, _outputs=["out"]) | ||
|
|
||
|
|
||
| class RemoveZeroBiasFromQLinearConv(_RemoveZeroBiasBase): | ||
| """Remove zero bias from QLinearConv operations.""" | ||
|
|
||
| op_type: ClassVar = "QLinearConv" | ||
|
|
||
| def pattern( | ||
| self, | ||
| op: ir.tape.Tape, | ||
| x, | ||
| x_scale, | ||
| x_zero_point, | ||
| w, | ||
| w_scale, | ||
| w_zero_point, | ||
| y_scale, | ||
| y_zero_point, | ||
| b: ir.Value, | ||
| ) -> ir.Value: | ||
| return op.QLinearConv( | ||
| x, | ||
| x_scale, | ||
| x_zero_point, | ||
| w, | ||
| w_scale, | ||
| w_zero_point, | ||
| y_scale, | ||
| y_zero_point, | ||
| b, | ||
| _outputs=["out"], | ||
| ) | ||
|
|
||
|
|
||
| class RemoveZeroBiasFromGemm(_RemoveZeroBiasBase): | ||
| """Remove zero bias from Gemm operations.""" | ||
|
|
||
| op_type: ClassVar = "Gemm" | ||
|
|
||
| def pattern(self, op: ir.tape.Tape, a: ir.Value, b: ir.Value, c: ir.Value) -> ir.Value: | ||
| return op.Gemm(a, b, c, _outputs=["out"]) | ||
|
|
||
| def check(self, context, a: ir.Value, b: ir.Value, c: ir.Value, **_) -> MatchResult: | ||
| """Check if the bias (c parameter) is present and is all zeros.""" | ||
| del context # Unused | ||
| return self._check_bias_is_zero(c) | ||
justinchuby marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
|
|
||
|
|
||
| # Create rule instances | ||
| remove_zero_bias_from_conv_rule = RemoveZeroBiasFromConv().rule() | ||
| remove_zero_bias_from_conv_transpose_rule = RemoveZeroBiasFromConvTranspose().rule() | ||
| remove_zero_bias_from_qlinear_conv_rule = RemoveZeroBiasFromQLinearConv().rule() | ||
| remove_zero_bias_from_gemm_rule = RemoveZeroBiasFromGemm().rule() | ||
|
|
||
| rules = RewriteRuleSet( | ||
| [ | ||
| remove_zero_bias_from_conv_rule, | ||
| remove_zero_bias_from_conv_transpose_rule, | ||
| remove_zero_bias_from_qlinear_conv_rule, | ||
| remove_zero_bias_from_gemm_rule, | ||
| ] | ||
| ) | ||
Uh oh!
There was an error while loading. Please reload this page.