-
Notifications
You must be signed in to change notification settings - Fork 93
Fix fused matmul check/rewrite functions #2331
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
bmehta001
merged 39 commits into
microsoft:main
from
bmehta001:bhamehta/fusedmatmul_find_ops
Jun 6, 2025
Merged
Changes from 35 commits
Commits
Show all changes
39 commits
Select commit
Hold shift + click to select a range
71c1d5c
Use producer syntax, simplify, rm unnecessary checks
bmehta001 77b39f6
Simplify assert, assigning attributes
bmehta001 af0abbd
Add test to ensure fusion rules do not rely on position of node
bmehta001 65f4637
Merge branch 'main' into bhamehta/fusedmatmul_find_ops
bmehta001 8946821
Add checking for transBatch
bmehta001 0f2b287
Merge branch 'main' into bhamehta/fusedmatmul_find_ops
bmehta001 cb5a192
Fix condition, formatting, and add default
bmehta001 b2e9737
Fix formatting
bmehta001 af9c064
Simplify syntax w/ functions
bmehta001 4611fdc
Condense rules using type function and classVars
bmehta001 bee83a8
Rm unused/fix comment
bmehta001 9f81fe9
Merge branch 'main' into bhamehta/fusedmatmul_find_ops
bmehta001 b9abd98
Address comments
bmehta001 10162db
Fix None error + rm type-ignore
bmehta001 d06f388
Add more tests
bmehta001 a065541
Handle defaults and add docstring
bmehta001 1a90317
Add clarifying comment
bmehta001 d417c02
Fix correct default behavior for transpose
bmehta001 a23ee07
Formally drop python 3.8 support (#2354)
justinchuby d1eb856
Implement `__repr__` for MatchResult (#2353)
justinchuby 19b7f6a
Use onnx_ir as a dependency (#2324)
justinchuby 3fd79be
Support common subexpression elimination pass (CSE) (#2304)
titaiwangms 11075ee
Fix pytest for TestCosSinCacheTransform (#2358)
justinchuby 9b81926
SDPA fusion cleanup (#2352)
gramalingam 7553ce1
Require onnx-ir 0.1.1 (#2360)
justinchuby 73432e5
Enable CSE in optimizer (#2361)
titaiwangms ccce52e
Rewrite tests and address comments
bmehta001 3654fa8
Merge branch 'main' into bhamehta/fusedmatmul_find_ops
bmehta001 12b4cc0
Support common subexpression elimination pass (CSE) (#2304)
titaiwangms 2276a16
Enable CSE in optimizer (#2361)
titaiwangms 2a0a798
Revert changes
bmehta001 5bcf2b1
Fix errors/simplify
bmehta001 a94c295
Update onnxscript/rewriter/ort_fusions/fused_matmul_rule_sets.py
bmehta001 e976fb1
Update onnxscript/rewriter/ort_fusions/fused_matmul_rule_sets.py
bmehta001 2adf8ea
Iterate through IR Model instead of ModelProto
bmehta001 73889d3
Addressed comments
bmehta001 841c49b
Simplify
bmehta001 04e6955
Simplify use of constants
bmehta001 514649f
Merge branch 'main' into bhamehta/fusedmatmul_find_ops
bmehta001 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -2,13 +2,30 @@ | |
| # Licensed under the MIT License. | ||
| from __future__ import annotations | ||
|
|
||
| from typing import ClassVar | ||
| from typing import ClassVar, Optional, Sequence | ||
|
|
||
| import onnxscript.rewriter.pattern as orp | ||
| from onnxscript import ir | ||
| from onnxscript.rewriter import _ir_utils | ||
|
|
||
|
|
||
| def _get_node(value: ir.Value, name: str) -> ir.Node: | ||
| """Get the node from the output value.""" | ||
| node = value.producer() | ||
| assert node is not None, f"{name} node should not be None" | ||
| return node | ||
|
|
||
|
|
||
| def _get_kwargs(node: ir.Node) -> dict[str, float | int]: | ||
| """Get the kwargs from the node.""" | ||
| kwargs = {key: val.value for key, val in node.attributes.items()} | ||
justinchuby marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| return kwargs | ||
|
|
||
|
|
||
|
|
||
|
|
||
| class FusedMatMulDiv1(orp.RewriteRuleClassBase): | ||
| """Replaces ``MatMul + Div`` by FusedMatMul.""" | ||
| """Replaces ``MatMul + Div`` with MatMul.""" | ||
|
|
||
| def pattern(self, op, x, y, cst): | ||
| return op.Div(op.MatMul(x, y), cst) | ||
|
|
@@ -29,122 +46,289 @@ | |
|
|
||
|
|
||
| class FusedMatMulDiv2(orp.RewriteRuleClassBase): | ||
| """Replaces ``FusedMatMul + Div`` by FusedMatMul.""" | ||
| """Replaces ``FusedMatMul + Div`` with FusedMatMul.""" | ||
|
|
||
| def pattern(self, op, x, y, cst): | ||
| return op.Div(op.FusedMatMul(x, y, _domain="com.microsoft"), cst) | ||
| return op.Div(op.FusedMatMul(x, y, _domain="com.microsoft", _outputs=["fused"]), cst) | ||
|
|
||
| def check(self, context, x, y, cst) -> orp.MatchResult: | ||
| def check(self, context, x, y, cst, **_) -> orp.MatchResult: | ||
| check_result = orp.MatchResult() | ||
| if cst.const_value is None: | ||
| return check_result.fail("Divisor is not a constant value.") | ||
| if cst.const_value.numpy().size > 1: | ||
| return check_result.fail("Divisor is not a scalar value.") | ||
| return check_result | ||
|
|
||
| def rewrite(self, op, x, y, cst): | ||
| def rewrite(self, op, x, y, cst, fused: ir.Value): | ||
| value = cst.const_value.numpy() | ||
| c = float(value[0] if value.shape == (1,) else value) | ||
| node = list(x.uses())[0][0] # noqa: RUF015 | ||
|
|
||
| kwargs = {} | ||
| alpha = node.attributes.get("alpha", None) | ||
| kwargs["alpha"] = alpha.value / c if alpha else 1.0 / c | ||
| for name in ["transA", "transB", "transBatchA", "transBatchB"]: | ||
| att = node.attributes.get(name) | ||
| if att: | ||
| kwargs[name] = att.value | ||
| fused_node = _get_node(fused, "FusedMatMul") | ||
| kwargs = _get_kwargs(fused_node) | ||
| kwargs["alpha"] = kwargs.get("alpha", 1.0) / c | ||
| return op.FusedMatMul(x, y, **kwargs, _domain="com.microsoft") | ||
|
|
||
|
|
||
| class _TransposeMatMulBase(orp.RewriteRuleClassBase): | ||
| _pos: ClassVar = 1 | ||
|
|
||
| def check(self, context, x, y) -> orp.MatchResult: | ||
| def check( | ||
| self, context, x, y, transposed: ir.Value, fused: ir.Value | None = None, **_ | ||
| ) -> orp.MatchResult: | ||
| check_result = orp.MatchResult() | ||
| perm = list((x if self._pos == 1 else y).uses())[0][0].attributes["perm"].value # noqa: RUF015 | ||
| expected_perm = list(range(len(perm))) | ||
| expected_perm[-2], expected_perm[-1] = expected_perm[-1], expected_perm[-2] | ||
| if perm != expected_perm: | ||
| return check_result.fail("Permutation values for Transpose are not correct.") | ||
| transposed_node = _get_node(transposed, "Transpose") | ||
| perm = transposed_node.attributes.get_ints("perm") | ||
| if perm: | ||
| # Check that last two dimensions are swapped | ||
| expected_perm = list(range(len(perm))) | ||
| expected_perm[-2], expected_perm[-1] = expected_perm[-1], expected_perm[-2] | ||
| if perm != expected_perm: | ||
| return check_result.fail("Permutation values for Transpose are not correct.") | ||
| elif (self._pos == 1 and not _ir_utils.has_rank(x, 2)) or ( | ||
| self._pos == 2 and not _ir_utils.has_rank(y, 2) | ||
| ): | ||
| # If perm is not defined, the default transpose behavior is to swap | ||
| # all dimensions, which is correct for MatMul with rank = 2. | ||
| return check_result.fail( | ||
| "If perm is not defined, rank must be 2 for TransposeMatMul rule." | ||
| ) | ||
| if fused: | ||
| fused_node = _get_node(fused, "FusedMatMul") | ||
| trans_batch_property = "transBatchA" if self._pos == 1 else "transBatchB" | ||
| if fused_node.attributes.get_int(trans_batch_property, 0): | ||
| return check_result.fail( | ||
| "FusedMatMul with transposed batch cannot be used with op.Transpose in this rule." | ||
| ) | ||
| return check_result | ||
|
|
||
| def rewrite(self, op, x, y): | ||
| node = list((x if self._pos == 2 else y).uses())[0][0] # noqa: RUF015 | ||
| def rewrite(self, op, x, y, fused: ir.Value | None = None, **_): | ||
| kwargs = {} | ||
| for name in ["alpha", "transA", "transB", "transBatchA", "transBatchB"]: | ||
| att = node.attributes.get(name) | ||
| if att: | ||
| kwargs[name] = att.value | ||
| name = "transA" if self._pos == 1 else "transB" | ||
| kwargs[name] = 1 - kwargs.get(name, 0) | ||
| if fused: | ||
| fused_node = _get_node(fused, "FusedMatMul") | ||
| kwargs = _get_kwargs(fused_node) | ||
| trans_name = "transA" if self._pos == 1 else "transB" | ||
| kwargs[trans_name] = 1 - kwargs.get(trans_name, 0) | ||
| return op.FusedMatMul(x, y, **kwargs, _domain="com.microsoft") | ||
|
|
||
|
|
||
| class TransposeMatMul1(_TransposeMatMulBase): | ||
| """Replaces ``Transpose + (Fused)MatMul`` by FusedMatMul.""" | ||
| """Replaces ``Transpose + MatMul`` with FusedMatMul.""" | ||
|
|
||
| def pattern(self, op, x, y): | ||
| return op.MatMul(op.Transpose(x), y) | ||
| return op.MatMul(op.Transpose(x, _outputs=["transposed"]), y) | ||
|
|
||
|
|
||
| class TransposeFusedMatMul1(TransposeMatMul1): | ||
| """Replaces ``Transpose + (Fused)MatMul`` by FusedMatMul.""" | ||
| """Replaces ``Transpose + FusedMatMul`` with FusedMatMul.""" | ||
|
|
||
| def pattern(self, op, x, y): | ||
| return op.FusedMatMul(op.Transpose(x), y, _domain="com.microsoft") | ||
| return op.FusedMatMul( | ||
| op.Transpose(x, _outputs=["transposed"]), | ||
| y, | ||
| _domain="com.microsoft", | ||
| _outputs=["fused"], | ||
| ) | ||
|
|
||
|
|
||
| class TransposeMatMul2(_TransposeMatMulBase): | ||
| """Replaces ``Transpose + (Fused)MatMul`` by FusedMatMul.""" | ||
| """Replaces ``Transpose + MatMul`` with FusedMatMul.""" | ||
|
|
||
| _pos: ClassVar = 2 | ||
|
|
||
| def pattern(self, op, x, y): | ||
| return op.MatMul(x, op.Transpose(y)) | ||
| return op.MatMul(x, op.Transpose(y, _outputs=["transposed"])) | ||
|
|
||
|
|
||
| class TransposeFusedMatMul2(TransposeMatMul2): | ||
| """Replaces ``Transpose + (Fused)MatMul`` by FusedMatMul.""" | ||
| """Replaces ``Transpose + FusedMatMul`` with FusedMatMul.""" | ||
|
|
||
| def pattern(self, op, x, y): | ||
| return op.FusedMatMul(x, op.Transpose(y), _domain="com.microsoft") | ||
| return op.FusedMatMul( | ||
| x, | ||
| op.Transpose(y, _outputs=["transposed"]), | ||
| _domain="com.microsoft", | ||
| _outputs=["fused"], | ||
| ) | ||
|
|
||
|
|
||
| class _TransposeFusedMatMulBaseWithBatch(orp.RewriteRuleClassBase): | ||
gramalingam marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| """Replaces ``Transpose + FusedMatMul`` with FusedMatMul, either | ||
| when transBatchA or transBatchB in FusedMatMul is 1, or | ||
| can be inverted based on the permutation dims of the Transpose, in | ||
| contrast to the original FusedMatMul rule which assumes that | ||
| transBatchA and transBatchB are always 0 before and after rewriting. | ||
| transBatchA = 1, transA = 0 applies a batch transpose by moving the first dimension to the second-to-last position | ||
| i.e., equivalent to a Transpose with "perm" [1, 2, ..., N-2, 0, N-1]. | ||
| transBatchA = 0, transA = 1 flips the last two dimensions | ||
| i.e., equivalent to a Transpose with "perm" [0, 1, ... N-3, N-1, N-2]. | ||
| transBatchA = 1, transA = 1 applies a batch transpose, then flips the last two dimensions | ||
| i.e., equivalent to a Transpose with "perm" [1, 2, ..., N-1, 0]. | ||
| The flipping logic is based on the following cases: | ||
| Case 1: transBatchA is 0, Transpose "perm" is [1, 2, ..., N-1, 0] | ||
| or transBatchA is 1, Transpose "perm" is [N-1, 0, 1, ..., N-2] | ||
| - Then transBatchA and transA can be flipped in FusedMatMul when rewriting. | ||
| Case 2: transBatchA is 0, Transpose "perm" is [1, 2, ..., N-2, 0, N-1] | ||
| or transBatchA is 1, Transpose "perm" is [N-2, 0, 1, ..., N-3, N-1] | ||
| - Then transBatchA can be flipped in FusedMatMul when rewriting. | ||
| Case 3: transBatchA is 1, Transpose "perm" is [N-1, 1, ..., N-2, 0] | ||
| - Then transA can be flipped in FusedMatMul when rewriting. | ||
| The same logic applies for transBatchB and transB, when _pos is set to 2. | ||
| The _flip_transpose_batch and _flip_transpose flags are used to control | ||
| which case is applied by the rules of inheriting classes that change these class vars. | ||
| """ | ||
|
|
||
| _pos: ClassVar = 1 | ||
| _flip_transpose_batch: ClassVar = False | ||
| _flip_transpose: ClassVar = False | ||
|
|
||
| def check( | ||
| self, context, x, y, transposed: ir.Value, fused: ir.Value, **_ | ||
| ) -> orp.MatchResult: | ||
| check_result = orp.MatchResult() | ||
| fused_node = _get_node(fused, "FusedMatMul") | ||
| trans_batch_property = "transBatchA" if self._pos == 1 else "transBatchB" | ||
| trans_batch = fused_node.attributes.get_int(trans_batch_property, 0) | ||
| transposed_node = _get_node(transposed, "Transpose") | ||
| perm = transposed_node.attributes["perm"].as_ints() | ||
| if not perm: | ||
| return check_result.fail("Permutation values for Transpose are not correct.") | ||
|
|
||
| list_perm = list(range(len(perm))) | ||
| if self._flip_transpose_batch and self._flip_transpose: | ||
| # Case 1: transBatchA/B is 0, Transpose "perm" is [1, 2, ..., N-1, 0] | ||
| # or transBatchA/B is 1, Transpose "perm" is [N-1, 0, 1, ..., N-2] | ||
| # - Then transBatchA/B and transA/B can be flipped in FusedMatMul when rewriting. | ||
| if trans_batch == 0: | ||
| expected_perm = [*list_perm[1:], list_perm[0]] | ||
| else: | ||
| expected_perm = [list_perm[-1], *list_perm[0:-1]] | ||
| if expected_perm == perm: | ||
| return check_result | ||
| elif self._flip_transpose_batch: | ||
| # Case 2: transBatchA/B is 0, Transpose "perm" is [1, 2, ..., N-2, 0, N-1] | ||
| # or transBatchA/B is 1, Transpose "perm" is [N-2, 0, 1, ..., N-3, N-1] | ||
| # - Then transBatchA/B can be flipped in FusedMatMul when rewriting. | ||
| if trans_batch == 0: | ||
| expected_perm = [*list_perm[1:-1], list_perm[0], list_perm[-1]] | ||
| else: | ||
| expected_perm = [list_perm[-2], *list_perm[0:-2], list_perm[-1]] | ||
| if expected_perm == perm: | ||
| return check_result | ||
| elif self._flip_transpose: | ||
| # Case 3: transBatchA is 1, Transpose "perm" is [N-1, 1, ..., N-2, 0] | ||
| # - Then transA can be flipped in FusedMatMul when rewriting. | ||
| expected_perm = [list_perm[-1], *list_perm[1:-1], list_perm[0]] | ||
| if expected_perm == perm and trans_batch == 1: | ||
| return check_result | ||
|
|
||
| return check_result.fail("Permutation values for Transpose are not correct.") | ||
|
|
||
| def rewrite(self, op, x, y, fused: ir.Value, **_): | ||
| kwargs = {} | ||
| fused_node = _get_node(fused, "FusedMatMul") | ||
| kwargs = _get_kwargs(fused_node) | ||
| name = "A" if self._pos == 1 else "B" | ||
| if self._flip_transpose_batch: | ||
| trans_batch_property = f"transBatch{name}" | ||
| kwargs[trans_batch_property] = 1 - kwargs.get(trans_batch_property, 0) | ||
| if self._flip_transpose: | ||
| trans_property = f"trans{name}" | ||
| kwargs[trans_property] = 1 - kwargs.get(trans_property, 0) | ||
| return op.FusedMatMul(x, y, **kwargs, _domain="com.microsoft") | ||
bmehta001 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
|
||
| def pattern(self, op, x, y): | ||
| if self._pos == 1: | ||
| return op.FusedMatMul( | ||
| op.Transpose(x, _outputs=["transposed"]), | ||
| y, | ||
| _domain="com.microsoft", | ||
| _outputs=["fused"], | ||
| ) | ||
| else: | ||
| return op.FusedMatMul( | ||
| x, | ||
| op.Transpose(y, _outputs=["transposed"]), | ||
| _domain="com.microsoft", | ||
| _outputs=["fused"], | ||
| ) | ||
|
|
||
|
|
||
| TransposeFusedMatMulWithFlippedBatchAndTranspose1 = type( | ||
| "TransposeFusedMatMulWithFlippedBatchAndTranspose1", | ||
| (_TransposeFusedMatMulBaseWithBatch,), | ||
| {"_flip_transpose": True, "_flip_transpose_batch": True}, | ||
| ) | ||
| TransposeFusedMatMulWithFlippedBatchAndTranspose2 = type( | ||
| "TransposeFusedMatMulWithFlippedBatchAndTranspose2", | ||
| (_TransposeFusedMatMulBaseWithBatch,), | ||
| {"_pos": 2, "_flip_transpose": True, "_flip_transpose_batch": True}, | ||
| ) | ||
| TransposeFusedMatMulWithFlippedBatch1 = type( | ||
| "TransposeFusedMatMulWithFlippedBatch1", | ||
| (_TransposeFusedMatMulBaseWithBatch,), | ||
| {"_flip_transpose_batch": True}, | ||
| ) | ||
| TransposeFusedMatMulWithFlippedBatch2 = type( | ||
| "TransposeFusedMatMulWithFlippedBatch2", | ||
| (_TransposeFusedMatMulBaseWithBatch,), | ||
| {"_pos": 2, "_flip_transpose_batch": True}, | ||
| ) | ||
| TransposeFusedMatMulWithBatchAndTranspose1 = type( | ||
| "TransposeFusedMatMulWithBatchAndTranspose1", | ||
| (_TransposeFusedMatMulBaseWithBatch,), | ||
| {"_flip_transpose": True}, | ||
| ) | ||
| TransposeFusedMatMulWithBatchAndTranspose2 = type( | ||
| "TransposeFusedMatMulWithBatchAndTranspose2", | ||
| (_TransposeFusedMatMulBaseWithBatch,), | ||
| {"_pos": 2, "_flip_transpose": True}, | ||
| ) | ||
|
|
||
|
|
||
| class MatMulTranspose(orp.RewriteRuleClassBase): | ||
| """Replaces ``MatMul + Transpose`` by FusedMatMul.""" | ||
| """Replaces ``MatMul + Transpose`` with FusedMatMul.""" | ||
|
|
||
| def pattern(self, op, x, y): | ||
| return op.Transpose(op.MatMul(x, y)) | ||
| return op.Transpose(op.MatMul(x, y), _outputs=["transposed"]) | ||
|
|
||
| def check(self, context, x, y) -> orp.MatchResult: | ||
| def check(self, context, x, y, transposed: ir.Value, **_) -> orp.MatchResult: | ||
| check_result = orp.MatchResult() | ||
| matmul = list(x.uses())[0][0] # noqa: RUF015 | ||
| transpose = list(matmul.outputs[0].uses())[0][0] # noqa: RUF015 | ||
| perm = transpose.attributes["perm"].value | ||
| expected_perm = list(range(len(perm))) | ||
| expected_perm[-2], expected_perm[-1] = expected_perm[-1], expected_perm[-2] | ||
| if perm != expected_perm: | ||
| return check_result.fail("Permutation values for Transpose are not correct.") | ||
| transpose_node = _get_node(transposed, "Transpose") | ||
| perm = _get_ints_or_default(transpose_node, "perm") | ||
Check failureCode scanning / lintrunner RUFF/F821 Error
Undefined name _get_ints_or_default.
See https://docs.astral.sh/ruff/rules/undefined-name |
||
| # transA/transB only work on the last two dimensions of the input, | ||
| # so we can only apply this rule if the inputs are rank 2. | ||
| if _ir_utils.has_rank(x, 2) and _ir_utils.has_rank(y, 2): | ||
| if perm: | ||
| # Check that the two dimensions are swapped | ||
| if perm != [1, 0]: | ||
| return check_result.fail( | ||
| "Permutation values for Transpose are not correct." | ||
| ) | ||
| # If perm is not defined, the default transpose behavior is to swap | ||
| # all dimensions, which is correct for MatMul with rank = 2. | ||
| else: | ||
| return check_result.fail("Rank must be 2 for MatMulTranspose rule.") | ||
| return check_result | ||
|
|
||
| def rewrite(self, op, x, y): | ||
| node = list(x.uses())[0][0] # noqa: RUF015 | ||
| def rewrite(self, op, x, y, fused: ir.Value | None = None, **_): | ||
| kwargs = {} | ||
| for name in ["alpha", "transA", "transB", "transBatchA", "transBatchB"]: | ||
| att = node.attributes.get(name) | ||
| if att: | ||
| kwargs[name] = att.value | ||
| if fused: | ||
| fused_node = _get_node(fused, "FusedMatMul") | ||
| kwargs = _get_kwargs(fused_node) | ||
| for name in ["transA", "transB"]: | ||
| kwargs[name] = 1 - kwargs.get(name, 0) | ||
| return op.FusedMatMul(y, x, **kwargs, _domain="com.microsoft") | ||
|
|
||
|
|
||
| class FusedMatMulTranspose(MatMulTranspose): | ||
| """Replaces ``MatMul + Transpose`` by FusedMatMul.""" | ||
| """Replaces ``FusedMatMul + Transpose`` with FusedMatMul.""" | ||
|
|
||
| def pattern(self, op, x, y): | ||
| return op.Transpose(op.FusedMatMul(x, y, _domain="com.microsoft")) | ||
| return op.Transpose( | ||
| op.FusedMatMul(x, y, _domain="com.microsoft", _outputs=["fused"]), | ||
| _outputs=["transposed"], | ||
| ) | ||
|
|
||
|
|
||
| def fused_matmul_rule_sets() -> orp.RewriteRuleSet: | ||
|
|
@@ -165,5 +349,11 @@ | |
| TransposeFusedMatMul1.rule(), | ||
| TransposeMatMul2.rule(), | ||
| TransposeFusedMatMul2.rule(), | ||
| TransposeFusedMatMulWithFlippedBatch1.rule(), # type: ignore[attr-defined] | ||
bmehta001 marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| TransposeFusedMatMulWithFlippedBatch2.rule(), # type: ignore[attr-defined] | ||
| TransposeFusedMatMulWithFlippedBatchAndTranspose1.rule(), # type: ignore[attr-defined] | ||
| TransposeFusedMatMulWithFlippedBatchAndTranspose2.rule(), # type: ignore[attr-defined] | ||
| TransposeFusedMatMulWithBatchAndTranspose1.rule(), # type: ignore[attr-defined] | ||
| TransposeFusedMatMulWithBatchAndTranspose2.rule(), # type: ignore[attr-defined] | ||
| ] | ||
| ) | ||
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.