-
Notifications
You must be signed in to change notification settings - Fork 93
[Rewriter] Implement zero bias removal for Conv operations and related rules #2555
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 2 commits
6cfd6fd
48037a8
fd028ee
3742e7b
8bfa65f
6fc3ca7
b93a56c
e322625
121360e
ee7dafa
8ef6c41
2b9dda4
153b4e7
ce64fb7
a94f8b9
86de85f
d62eafb
d4f73dd
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -21,12 +21,15 @@ def rewrite(self, op: ir.tape.Tape, out: ir.Value, **_) -> ir.Value: | |
| """Remove the bias input from the operation.""" | ||
| node = out.producer() | ||
|
|
||
| original_inputs = list(node.inputs) | ||
| inputs_without_bias = original_inputs[:-1] | ||
| # Filter out the bias parameter and keep all other inputs | ||
| inputs = [] | ||
| for param_name, param_value in _.items(): | ||
|
||
| if param_name != "b": # 'b' is the bias parameter | ||
| inputs.append(param_value) | ||
|
|
||
| return op.op( | ||
| self.op_type, | ||
| inputs=inputs_without_bias, | ||
| inputs=inputs, | ||
| attributes=node.attributes, | ||
| domain=node.domain, | ||
justinchuby marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| ) | ||
|
|
@@ -108,13 +111,8 @@ class RemoveZeroBiasFromGemm(_RemoveZeroBiasBase): | |
|
|
||
| op_type: ClassVar = "Gemm" | ||
|
|
||
| def pattern(self, op: ir.tape.Tape, a: ir.Value, b: ir.Value, c: ir.Value) -> ir.Value: | ||
| return op.Gemm(a, b, c, _outputs=["out"]) | ||
|
|
||
| def check(self, context, a: ir.Value, b: ir.Value, c: ir.Value, **_) -> MatchResult: | ||
| """Check if the bias (c parameter) is present and is all zeros.""" | ||
| del context # Unused | ||
| return self._check_bias_is_zero(c) | ||
| def pattern(self, op: ir.tape.Tape, x: ir.Value, w: ir.Value, b: ir.Value) -> ir.Value: | ||
| return op.Gemm(x, w, b, _outputs=["out"]) | ||
|
|
||
|
|
||
| # Create rule instances | ||
|
|
||
whyvineet marked this conversation as resolved.
Show resolved
Hide resolved
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Maybe you can pass some attributes when testing to check that every info is correctly transferred (e.g. stride, transA...) |
Uh oh!
There was an error while loading. Please reload this page.