diff --git a/onnxscript/rewriter/ort_fusions/fused_matmul_rule_sets.py b/onnxscript/rewriter/ort_fusions/fused_matmul_rule_sets.py
index cc10297afe..c9c2480428 100644
--- a/onnxscript/rewriter/ort_fusions/fused_matmul_rule_sets.py
+++ b/onnxscript/rewriter/ort_fusions/fused_matmul_rule_sets.py
@@ -5,10 +5,25 @@
 from typing import ClassVar
 
 import onnxscript.rewriter.pattern as orp
+from onnxscript import ir
+from onnxscript.rewriter import _ir_utils
+
+
+def _get_node(value: ir.Value, name: str) -> ir.Node:
+    """Get the node from the output value."""
+    node = value.producer()
+    assert node is not None, f"{name} node should not be None"
+    return node
+
+
+def _get_kwargs(node: ir.Node) -> dict[str, float | int]:
+    """Get the kwargs from the node."""
+    kwargs = {key: val.value for key, val in node.attributes.items()}
+    return kwargs
 
 
 class FusedMatMulDiv1(orp.RewriteRuleClassBase):
-    """Replaces ``MatMul + Div`` by FusedMatMul."""
+    """Replaces ``MatMul + Div`` with FusedMatMul."""
 
     def pattern(self, op, x, y, cst):
         return op.Div(op.MatMul(x, y), cst)
@@ -29,12 +44,12 @@ def rewrite(self, op, x, y, cst):
 
 
 class FusedMatMulDiv2(orp.RewriteRuleClassBase):
-    """Replaces ``FusedMatMul + Div`` by FusedMatMul."""
+    """Replaces ``FusedMatMul + Div`` with FusedMatMul."""
 
     def pattern(self, op, x, y, cst):
-        return op.Div(op.FusedMatMul(x, y, _domain="com.microsoft"), cst)
+        return op.Div(op.FusedMatMul(x, y, _domain="com.microsoft", _outputs=["fused"]), cst)
 
-    def check(self, context, x, y, cst) -> orp.MatchResult:
+    def check(self, context, x, y, cst, **_) -> orp.MatchResult:
         check_result = orp.MatchResult()
         if cst.const_value is None:
             return check_result.fail("Divisor is not a constant value.")
@@ -42,109 +57,273 @@ def check(self, context, x, y, cst) -> orp.MatchResult:
             return check_result.fail("Divisor is not a scalar value.")
         return check_result
 
-    def rewrite(self, op, x, y, cst):
+    def rewrite(self, op, x, y, cst, fused: ir.Value):
         value = cst.const_value.numpy()
         c = float(value[0] if value.shape == (1,) else value)
-        node = list(x.uses())[0][0]  # noqa: RUF015
-
-        kwargs = {}
-        alpha = node.attributes.get("alpha", None)
-        kwargs["alpha"] = alpha.value / c if alpha else 1.0 / c
-        for name in ["transA", "transB", "transBatchA", "transBatchB"]:
-            att = node.attributes.get(name)
-            if att:
-                kwargs[name] = att.value
+        fused_node = _get_node(fused, "FusedMatMul")
+        kwargs = _get_kwargs(fused_node)
+        kwargs["alpha"] = kwargs.get("alpha", 1.0) / c
         return op.FusedMatMul(x, y, **kwargs, _domain="com.microsoft")
 
 
 class _TransposeMatMulBase(orp.RewriteRuleClassBase):
     _pos: ClassVar = 1
 
-    def check(self, context, x, y) -> orp.MatchResult:
+    def check(
+        self, context, x, y, transposed: ir.Value, fused: ir.Value | None = None, **_
+    ) -> orp.MatchResult:
         check_result = orp.MatchResult()
-        perm = list((x if self._pos == 1 else y).uses())[0][0].attributes["perm"].value  # noqa: RUF015
-        expected_perm = list(range(len(perm)))
-        expected_perm[-2], expected_perm[-1] = expected_perm[-1], expected_perm[-2]
-        if perm != expected_perm:
-            return check_result.fail("Permutation values for Transpose are not correct.")
+        transposed_node = _get_node(transposed, "Transpose")
+        perm = transposed_node.attributes.get_ints("perm")
+        if perm:
+            # Check that last two dimensions are swapped
+            expected_perm = list(range(len(perm)))
+            expected_perm[-2], expected_perm[-1] = expected_perm[-1], expected_perm[-2]
+            if perm != expected_perm:
+                return check_result.fail("Permutation values for Transpose are not correct.")
+        elif (self._pos == 1 and not _ir_utils.has_rank(x, 2)) or (
+ transBatchA = 1, transA = 1 applies a batch transpose, then flips the last two dimensions + i.e., equivalent to a Transpose with "perm" [1, 2, ..., N-1, 0]. + + The flipping logic is based on the following cases: + Case 1: transBatchA is 0, Transpose "perm" is [1, 2, ..., N-1, 0] + or transBatchA is 1, Transpose "perm" is [N-1, 0, 1, ..., N-2] + - Then transBatchA and transA can be flipped in FusedMatMul when rewriting. + Case 2: transBatchA is 0, Transpose "perm" is [1, 2, ..., N-2, 0, N-1] + or transBatchA is 1, Transpose "perm" is [N-2, 0, 1, ..., N-3, N-1] + - Then transBatchA can be flipped in FusedMatMul when rewriting. + Case 3: transBatchA is 1, Transpose "perm" is [N-1, 1, ..., N-2, 0] + - Then transA can be flipped in FusedMatMul when rewriting. + The same logic applies for transBatchB and transB, when _pos is set to 2. + The _flip_transpose_batch and _flip_transpose flags are used to control + which case is applied by the rules of inheriting classes that change these class vars. + """ + + _pos: ClassVar = 1 + _flip_transpose_batch: ClassVar = False + _flip_transpose: ClassVar = False + + def check( + self, context, x, y, transposed: ir.Value, fused: ir.Value, **_ + ) -> orp.MatchResult: + check_result = orp.MatchResult() + fused_node = _get_node(fused, "FusedMatMul") + trans_batch_property = "transBatchA" if self._pos == 1 else "transBatchB" + trans_batch = fused_node.attributes.get_int(trans_batch_property, 0) + transposed_node = _get_node(transposed, "Transpose") + perm = transposed_node.attributes["perm"].as_ints() + if not perm: + return check_result.fail("Permutation values for Transpose are not correct.") + + list_perm = list(range(len(perm))) + if self._flip_transpose_batch and self._flip_transpose: + # Case 1: transBatchA/B is 0, Transpose "perm" is [1, 2, ..., N-1, 0] + # or transBatchA/B is 1, Transpose "perm" is [N-1, 0, 1, ..., N-2] + # - Then transBatchA/B and transA/B can be flipped in FusedMatMul when rewriting. + if trans_batch == 0: + expected_perm = [*list_perm[1:], list_perm[0]] + else: + expected_perm = [list_perm[-1], *list_perm[0:-1]] + if expected_perm == perm: + return check_result + elif self._flip_transpose_batch: + # Case 2: transBatchA/B is 0, Transpose "perm" is [1, 2, ..., N-2, 0, N-1] + # or transBatchA/B is 1, Transpose "perm" is [N-2, 0, 1, ..., N-3, N-1] + # - Then transBatchA/B can be flipped in FusedMatMul when rewriting. + if trans_batch == 0: + expected_perm = [*list_perm[1:-1], list_perm[0], list_perm[-1]] + else: + expected_perm = [list_perm[-2], *list_perm[0:-2], list_perm[-1]] + if expected_perm == perm: + return check_result + elif self._flip_transpose: + # Case 3: transBatchA is 1, Transpose "perm" is [N-1, 1, ..., N-2, 0] + # - Then transA can be flipped in FusedMatMul when rewriting. 
+ expected_perm = [list_perm[-1], *list_perm[1:-1], list_perm[0]] + if expected_perm == perm and trans_batch == 1: + return check_result + + return check_result.fail("Permutation values for Transpose are not correct.") + + def rewrite(self, op, x, y, fused: ir.Value, **_): + kwargs = {} + fused_node = _get_node(fused, "FusedMatMul") + kwargs = _get_kwargs(fused_node) + name = "A" if self._pos == 1 else "B" + if self._flip_transpose_batch: + trans_batch_property = f"transBatch{name}" + kwargs[trans_batch_property] = 1 - kwargs.get(trans_batch_property, 0) + if self._flip_transpose: + trans_property = f"trans{name}" + kwargs[trans_property] = 1 - kwargs.get(trans_property, 0) + return op.FusedMatMul(x, y, **kwargs, _domain="com.microsoft") + + def pattern(self, op, x, y): + if self._pos == 1: + return op.FusedMatMul( + op.Transpose(x, _outputs=["transposed"]), + y, + _domain="com.microsoft", + _outputs=["fused"], + ) + else: + return op.FusedMatMul( + x, + op.Transpose(y, _outputs=["transposed"]), + _domain="com.microsoft", + _outputs=["fused"], + ) + + +class TransposeFusedMatMulWithFlippedBatchAndTranspose1(_TransposeFusedMatMulBaseWithBatch): + _flip_transpose = True + _flip_transpose_batch = True + + +class TransposeFusedMatMulWithFlippedBatchAndTranspose2(_TransposeFusedMatMulBaseWithBatch): + _pos = 2 + _flip_transpose = True + _flip_transpose_batch = True + + +class TransposeFusedMatMulWithFlippedBatch1(_TransposeFusedMatMulBaseWithBatch): + _flip_transpose_batch = True + + +class TransposeFusedMatMulWithFlippedBatch2(_TransposeFusedMatMulBaseWithBatch): + _pos = 2 + _flip_transpose_batch = True + + +class TransposeFusedMatMulWithBatchAndTranspose1(_TransposeFusedMatMulBaseWithBatch): + _flip_transpose = True + + +class TransposeFusedMatMulWithBatchAndTranspose2(_TransposeFusedMatMulBaseWithBatch): + _pos = 2 + _flip_transpose = True class MatMulTranspose(orp.RewriteRuleClassBase): - """Replaces ``MatMul + Transpose`` by FusedMatMul.""" + """Replaces ``MatMul + Transpose`` with FusedMatMul.""" def pattern(self, op, x, y): - return op.Transpose(op.MatMul(x, y)) + return op.Transpose(op.MatMul(x, y), _outputs=["transposed"]) - def check(self, context, x, y) -> orp.MatchResult: + def check(self, context, x, y, transposed: ir.Value, **_) -> orp.MatchResult: check_result = orp.MatchResult() - matmul = list(x.uses())[0][0] # noqa: RUF015 - transpose = list(matmul.outputs[0].uses())[0][0] # noqa: RUF015 - perm = transpose.attributes["perm"].value - expected_perm = list(range(len(perm))) - expected_perm[-2], expected_perm[-1] = expected_perm[-1], expected_perm[-2] - if perm != expected_perm: - return check_result.fail("Permutation values for Transpose are not correct.") + transpose_node = _get_node(transposed, "Transpose") + perm = transpose_node.attributes.get_ints("perm") + # transA/transB only work on the last two dimensions of the input, + # so we can only apply this rule if the inputs are rank 2. + if _ir_utils.has_rank(x, 2) and _ir_utils.has_rank(y, 2): + if perm: + # Check that the two dimensions are swapped + if perm != [1, 0]: + return check_result.fail( + "Permutation values for Transpose are not correct." + ) + # If perm is not defined, the default transpose behavior is to swap + # all dimensions, which is correct for MatMul with rank = 2. 
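+            # e.g., with N = 4 this expects perm == [3, 1, 2, 0], matching the
+            # Case 3 comment above.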
+ else: + return check_result.fail("Rank must be 2 for MatMulTranspose rule.") return check_result - def rewrite(self, op, x, y): - node = list(x.uses())[0][0] # noqa: RUF015 + def rewrite(self, op, x, y, fused: ir.Value | None = None, **_): kwargs = {} - for name in ["alpha", "transA", "transB", "transBatchA", "transBatchB"]: - att = node.attributes.get(name) - if att: - kwargs[name] = att.value + if fused: + fused_node = _get_node(fused, "FusedMatMul") + kwargs = _get_kwargs(fused_node) for name in ["transA", "transB"]: kwargs[name] = 1 - kwargs.get(name, 0) return op.FusedMatMul(y, x, **kwargs, _domain="com.microsoft") class FusedMatMulTranspose(MatMulTranspose): - """Replaces ``MatMul + Transpose`` by FusedMatMul.""" + """Replaces ``FusedMatMul + Transpose`` with FusedMatMul.""" def pattern(self, op, x, y): - return op.Transpose(op.FusedMatMul(x, y, _domain="com.microsoft")) + return op.Transpose( + op.FusedMatMul(x, y, _domain="com.microsoft", _outputs=["fused"]), + _outputs=["transposed"], + ) def fused_matmul_rule_sets() -> orp.RewriteRuleSet: @@ -165,5 +344,11 @@ def fused_matmul_rule_sets() -> orp.RewriteRuleSet: TransposeFusedMatMul1.rule(), TransposeMatMul2.rule(), TransposeFusedMatMul2.rule(), + TransposeFusedMatMulWithFlippedBatch1.rule(), + TransposeFusedMatMulWithFlippedBatch2.rule(), + TransposeFusedMatMulWithFlippedBatchAndTranspose1.rule(), + TransposeFusedMatMulWithFlippedBatchAndTranspose2.rule(), + TransposeFusedMatMulWithBatchAndTranspose1.rule(), + TransposeFusedMatMulWithBatchAndTranspose2.rule(), ] ) diff --git a/onnxscript/rewriter/ort_fusions/fused_matmul_rule_sets_test.py b/onnxscript/rewriter/ort_fusions/fused_matmul_rule_sets_test.py index 04210e8537..6bd4b7fe81 100644 --- a/onnxscript/rewriter/ort_fusions/fused_matmul_rule_sets_test.py +++ b/onnxscript/rewriter/ort_fusions/fused_matmul_rule_sets_test.py @@ -3,17 +3,21 @@ from __future__ import annotations import unittest -from typing import Any +from typing import Any, Tuple import numpy as np import onnx import onnx.reference import onnx.reference.op_run +import parameterized +import onnxscript.ir.passes.common as common_passes import onnxscript.rewriter.ort_fusions.fused_matmul_rule_sets as fused_matmul_rule_sets -from onnxscript import ir +from onnxscript import FLOAT, ir, script +from onnxscript.onnx_opset import opset18 as op +from onnxscript.values import Opset -FLOAT = onnx.TensorProto.FLOAT +ms_op = Opset("com.microsoft", 1) class FusedMatMul(onnx.reference.op_run.OpRun): @@ -29,8 +33,23 @@ def _run( transBatchA: int = 0, transBatchB: int = 0, ): - assert transBatchA == 0, f"Not implemented for transBatchA==1 and {A.shape}x{B.shape}" - assert transBatchB == 0, f"Not implemented for transBatchB==1 and {A.shape}x{B.shape}" + if transBatchA != 0 or transBatchB != 0: + assert len(A.shape) >= 3 and len(B.shape) >= 3, ( + f"Batch dimensions must be at least 3 for A: {A.shape} and B: {B.shape}" + ) + assert len(A.shape) == len(B.shape), ( + f"Batch dimensions must match for A: {A.shape} and B: {B.shape}" + ) + if transBatchA: + perm = list(range(len(A.shape))) + dim = len(perm) + perm = [*perm[1 : dim - 1], perm[0], perm[dim - 1]] + A = np.transpose(A, perm) + if transBatchB: + perm = list(range(len(B.shape))) + dim = len(perm) + perm = [*perm[1 : dim - 1], perm[0], perm[dim - 1]] + B = np.transpose(B, perm) if transA: perm = list(range(len(A.shape))) dim = len(perm) @@ -45,7 +64,193 @@ def _run( return (np.matmul(A, B) * a,) -class OrtRuleSetsTest(unittest.TestCase): +@script() +def _fused_matmul_div(A: 
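+            # Rank unknown or != 2: the rewrite below swaps x and y and flips
+            # transA/transB, which is only equivalent to Transpose(MatMul(x, y))
+            # for 2-D inputs.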
+    FLOAT[4, 4], B: FLOAT[4, 4]) -> FLOAT[4, 4]:
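+    # FusedMatMul(alpha=0.4, transA=1) followed by Div by 0.6: FusedMatMulDiv2
+    # should fold the divisor into alpha (0.4 / 0.6), leaving one FusedMatMul.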
+@script() +def _fused_matmul_with_identity_before_matmul(A: FLOAT[4, 4]) -> FLOAT[4, 4]: + B = op.Identity(A) + ab = op.MatMul(A, B) + out = op.Transpose(ab, perm=[1, 0]) + return out + + +@script() +def _fused_matmul_with_identity_before_transpose(A: FLOAT[4, 4]) -> FLOAT[4, 4]: + B = op.Identity(A) + ab = op.Transpose(A, perm=[1, 0]) + out = op.MatMul(ab, B) + return out + + +@script() +def _transpose_fused_matmul_flip_transBatchA_0_and_transA( + X: FLOAT[4, 4, 4, 4], Y: FLOAT[4, 4, 4, 4] +) -> FLOAT[4, 4, 4, 4]: + Xt = op.Transpose(X, perm=[1, 2, 3, 0]) + out = ms_op.FusedMatMul(Xt, Y, alpha=0.5, transA=0, transB=0, transBatchA=0, transBatchB=0) + return out + + +@script() +def _transpose_fused_matmul_flip_transBatchA_1_and_transA( + X: FLOAT[4, 4, 4, 4], Y: FLOAT[4, 4, 4, 4] +) -> FLOAT[4, 4, 4, 4]: + Xt = op.Transpose(X, perm=[3, 0, 1, 2]) + out = ms_op.FusedMatMul(Xt, Y, transBatchA=1) + return out + + +@script() +def _transpose_fused_matmul_flip_transBatchA_0( + X: FLOAT[4, 4, 4, 4], Y: FLOAT[4, 4, 4, 4] +) -> FLOAT[4, 4, 4, 4]: + Xt = op.Transpose(X, perm=[1, 2, 0, 3]) + out = ms_op.FusedMatMul(Xt, Y) + return out + + +@script() +def _transpose_fused_matmul_flip_transBatchA_1( + X: FLOAT[4, 4, 4, 4], Y: FLOAT[4, 4, 4, 4] +) -> FLOAT[4, 4, 4, 4]: + Xt = op.Transpose(X, perm=[2, 0, 1, 3]) + out = ms_op.FusedMatMul(Xt, Y, transBatchA=1) + return out + + +@script() +def _transpose_fused_matmul_flip_transA( + X: FLOAT[4, 4, 4, 4], Y: FLOAT[4, 4, 4, 4] +) -> FLOAT[4, 4, 4, 4]: + Xt = op.Transpose(X, perm=[3, 1, 2, 0]) + out = ms_op.FusedMatMul(Xt, Y, transBatchA=1) + return out + + +@script() +def _transpose_fused_matmul_flip_transBatchB_0_and_transB( + X: FLOAT[4, 4, 4, 4], Y: FLOAT[4, 4, 4, 4] +) -> FLOAT[4, 4, 4, 4]: + Yt = op.Transpose(Y, perm=[1, 2, 3, 0]) + out = ms_op.FusedMatMul(X, Yt) + return out + + +@script() +def _transpose_fused_matmul_flip_transBatchB_1_and_transB( + X: FLOAT[4, 4, 4, 4], Y: FLOAT[4, 4, 4, 4] +) -> FLOAT[4, 4, 4, 4]: + Yt = op.Transpose(Y, perm=[3, 0, 1, 2]) + out = ms_op.FusedMatMul(X, Yt, transBatchB=1) + return out + + +@script() +def _transpose_fused_matmul_flip_transBatchB_0( + X: FLOAT[4, 4, 4, 4], Y: FLOAT[4, 4, 4, 4] +) -> FLOAT[4, 4, 4, 4]: + Yt = op.Transpose(Y, perm=[1, 2, 0, 3]) + out = ms_op.FusedMatMul(X, Yt) + return out + + +@script() +def _transpose_fused_matmul_flip_transBatchB_1( + X: FLOAT[4, 4, 4, 4], Y: FLOAT[4, 4, 4, 4] +) -> FLOAT[4, 4, 4, 4]: + Yt = op.Transpose(Y, perm=[2, 0, 1, 3]) + out = ms_op.FusedMatMul(X, Yt, transBatchB=1) + return out + + +@script() +def _transpose_fused_matmul_flip_transB( + X: FLOAT[4, 4, 4, 4], Y: FLOAT[4, 4, 4, 4] +) -> FLOAT[4, 4, 4, 4]: + Yt = op.Transpose(Y, perm=[3, 1, 2, 0]) + out = ms_op.FusedMatMul(X, Yt, transBatchB=1) + return out + + +class TestFusedMatmulRules(unittest.TestCase): + def _apply_fusion_rules(self, ir_model: ir.Model): + rule_set = fused_matmul_rule_sets.fused_matmul_rule_sets() + rule_set.apply_to_model(ir_model) + def _get_random_inputs(self, model: onnx.ModelProto) -> dict[str, Any]: feeds: dict[str, Any] = {} for i in model.graph.input: @@ -57,7 +262,10 @@ def _get_random_inputs(self, model: onnx.ModelProto) -> dict[str, Any]: (d.dim_value if d.dim_value > 0 else i + 2) for i, d in enumerate(ish) ) if i.type.tensor_type.elem_type == onnx.TensorProto.FLOAT: - feeds[i.name] = np.random.randn(*shape).astype(np.float32) + if shape: + feeds[i.name] = np.random.randn(*shape).astype(np.float32) + else: + feeds[i.name] = np.random.randn(1).astype(np.float32) else: raise 
AssertionError(f"Not implemented for input {i}") return feeds @@ -80,283 +288,160 @@ def _check_model( for a, b in zip(expected, got): np.testing.assert_allclose(a, b, atol=atol, rtol=rtol) - @classmethod - def _fused_matmul_div_models(cls): - models = [ - onnx.helper.make_model( - onnx.helper.make_graph( - [ - onnx.helper.make_node( - "FusedMatMul", - ["X", "Y"], - ["xyc"], - transA=1, - transB=0, - alpha=0.4, - transBatchA=0, - transBatchB=0, - domain="com.microsoft", - ), - onnx.helper.make_node("Div", ["xyc", "D"], ["Z"]), - ], - "name", - [ - onnx.helper.make_tensor_value_info("X", FLOAT, [6, "a"]), - onnx.helper.make_tensor_value_info("Y", FLOAT, [6, "b"]), - ], - [onnx.helper.make_tensor_value_info("Z", FLOAT, [None, None])], - [ - onnx.numpy_helper.from_array( - np.array([0.8], dtype=np.float32), name="D" - ), - ], - ), - opset_imports=[ - onnx.helper.make_opsetid("", 18), - onnx.helper.make_opsetid("com.microsoft", 1), - ], + @parameterized.parameterized.expand( + [ + ( + "fused_matmul_div", + _fused_matmul_div, + [FLOAT[6, "a"], FLOAT[6, "b"]], + [FLOAT[None, None]], ), - onnx.helper.make_model( - onnx.helper.make_graph( - [ - onnx.helper.make_node("MatMul", ["X", "Y"], ["xy"]), - onnx.helper.make_node("Div", ["xy", "C"], ["Z"]), - ], - "name", - [ - onnx.helper.make_tensor_value_info("X", FLOAT, ["a", 6]), - onnx.helper.make_tensor_value_info("Y", FLOAT, [6, "b"]), - ], - [onnx.helper.make_tensor_value_info("Z", FLOAT, [None, None])], - [ - onnx.numpy_helper.from_array( - np.array([0.6], dtype=np.float32), name="C" - ) - ], - ), - opset_imports=[onnx.helper.make_opsetid("", 18)], + ( + "matmul_div", + _matmul_div, + [FLOAT["a", 6], FLOAT[6, "b"]], + [FLOAT[None, None]], ), - onnx.helper.make_model( - onnx.helper.make_graph( - [ - onnx.helper.make_node("MatMul", ["X", "Y"], ["xy"]), - onnx.helper.make_node("Div", ["xy", "C"], ["xyc"]), - onnx.helper.make_node("Div", ["xyc", "D"], ["Z"]), - ], - "name", - [ - onnx.helper.make_tensor_value_info("X", FLOAT, ["a", 6]), - onnx.helper.make_tensor_value_info("Y", FLOAT, [6, "b"]), - ], - [onnx.helper.make_tensor_value_info("Z", FLOAT, [None, None])], - [ - onnx.numpy_helper.from_array( - np.array([0.6], dtype=np.float32), name="C" - ), - onnx.numpy_helper.from_array( - np.array([0.8], dtype=np.float32), name="D" - ), - ], - ), - opset_imports=[ - onnx.helper.make_opsetid("", 18), - ], + ( + "matmul_div_div", + _matmul_div_div, + [FLOAT["a", 6], FLOAT[6, "b"]], + [FLOAT[None, None]], ), ] - return models - - def test_ort_rule_set_fused_matmul_div(self): - for model_proto in self._fused_matmul_div_models(): - ir_model = ir.serde.deserialize_model(model_proto) - rule_set = fused_matmul_rule_sets.fused_matmul_rule_sets() - rule_set.apply_to_model(ir_model) - rewritten_model = ir.serde.serialize_model(ir_model) - - self.assertEqual(["FusedMatMul"], [n.op_type for n in rewritten_model.graph.node]) - self._check_model(model_proto, rewritten_model, atol=1e-6) - - @classmethod - def _transposed_fused_matmul_div_models(cls): - models = [ - onnx.helper.make_model( - onnx.helper.make_graph( - [ - onnx.helper.make_node( - "FusedMatMul", - ["X", "Y"], - ["xy"], - domain="com.microsoft", - alpha=0.5, - ), - onnx.helper.make_node("Transpose", ["xy"], ["Z"], perm=[1, 0]), - ], - "name", - [ - onnx.helper.make_tensor_value_info("X", FLOAT, [4, 4]), - onnx.helper.make_tensor_value_info("Y", FLOAT, [4, 4]), - ], - [onnx.helper.make_tensor_value_info("Z", FLOAT, [None, None])], - ), - opset_imports=[ - onnx.helper.make_opsetid("", 18), - 
onnx.helper.make_opsetid("com.microsoft", 1), - ], + ) + def test_fused_matmul_div_models(self, name, script_func, input_types, output_types): + model_proto = script_func.to_model_proto( + input_types=input_types, + output_types=output_types, + ) + ir_model = ir.serde.deserialize_model(model_proto) + rule_set = fused_matmul_rule_sets.fused_matmul_rule_sets() + rule_set.apply_to_model(ir_model) + rewritten_model = ir.serde.serialize_model(ir_model) + self.assertEqual(["Constant", "FusedMatMul"], [n.op_type for n in ir_model.graph]) + self._check_model(model_proto, rewritten_model, atol=1e-6) + + @parameterized.parameterized.expand( + [ + ( + "fused_matmul_transpose", + _fused_matmul_transpose, ), - onnx.helper.make_model( - onnx.helper.make_graph( - [ - onnx.helper.make_node("MatMul", ["X", "Y"], ["xy"]), - onnx.helper.make_node("Transpose", ["xy"], ["Z"], perm=[1, 0]), - ], - "name", - [ - onnx.helper.make_tensor_value_info("X", FLOAT, [4, 4]), - onnx.helper.make_tensor_value_info("Y", FLOAT, [4, 4]), - ], - [onnx.helper.make_tensor_value_info("Z", FLOAT, [None, None])], - ), - opset_imports=[ - onnx.helper.make_opsetid("", 18), - onnx.helper.make_opsetid("com.microsoft", 1), - ], + ( + "matmul_transpose", + _matmul_transpose, ), - onnx.helper.make_model( - onnx.helper.make_graph( - [ - onnx.helper.make_node("Transpose", ["X"], ["Xt"], perm=[1, 0]), - onnx.helper.make_node("MatMul", ["Xt", "Y"], ["Z"]), - ], - "name", - [ - onnx.helper.make_tensor_value_info("X", FLOAT, [4, 4]), - onnx.helper.make_tensor_value_info("Y", FLOAT, [4, 4]), - ], - [onnx.helper.make_tensor_value_info("Z", FLOAT, [None, None])], - ), - opset_imports=[ - onnx.helper.make_opsetid("", 18), - onnx.helper.make_opsetid("com.microsoft", 1), - ], + ( + "transpose_matmul_1", + _transpose_matmul_1, ), - onnx.helper.make_model( - onnx.helper.make_graph( - [ - onnx.helper.make_node("Transpose", ["X"], ["Xt"], perm=[1, 0]), - onnx.helper.make_node( - "FusedMatMul", - ["Xt", "Y"], - ["Z"], - domain="com.microsoft", - alpha=0.5, - ), - ], - "name", - [ - onnx.helper.make_tensor_value_info("X", FLOAT, [4, 4]), - onnx.helper.make_tensor_value_info("Y", FLOAT, [4, 4]), - ], - [onnx.helper.make_tensor_value_info("Z", FLOAT, [None, None])], - ), - opset_imports=[ - onnx.helper.make_opsetid("", 18), - onnx.helper.make_opsetid("com.microsoft", 1), - ], + ( + "transpose_fused_matmul_1", + _transpose_fused_matmul_1, ), - onnx.helper.make_model( - onnx.helper.make_graph( - [ - onnx.helper.make_node("Transpose", ["Y"], ["Yt"], perm=[1, 0]), - onnx.helper.make_node("MatMul", ["X", "Yt"], ["Z"]), - ], - "name", - [ - onnx.helper.make_tensor_value_info("X", FLOAT, [4, 4]), - onnx.helper.make_tensor_value_info("Y", FLOAT, [4, 4]), - ], - [onnx.helper.make_tensor_value_info("Z", FLOAT, [None, None])], - ), - opset_imports=[ - onnx.helper.make_opsetid("", 18), - onnx.helper.make_opsetid("com.microsoft", 1), - ], + ("transpose_matmul_2", _transpose_matmul_2), + ( + "transpose_fused_matmul_2", + _transpose_fused_matmul_2, ), - onnx.helper.make_model( - onnx.helper.make_graph( - [ - onnx.helper.make_node("Transpose", ["Y"], ["Yt"], perm=[1, 0]), - onnx.helper.make_node( - "FusedMatMul", - ["X", "Yt"], - ["Z"], - domain="com.microsoft", - alpha=0.5, - ), - ], - "name", - [ - onnx.helper.make_tensor_value_info("X", FLOAT, [4, 4]), - onnx.helper.make_tensor_value_info("Y", FLOAT, [4, 4]), - ], - [onnx.helper.make_tensor_value_info("Z", FLOAT, [None, None])], - ), - opset_imports=[ - onnx.helper.make_opsetid("", 18), - 
onnx.helper.make_opsetid("com.microsoft", 1), - ], + ] + ) + def test_fused_matmul_with_transpose(self, _, script_func): + model_proto = script_func.to_model_proto( + input_types=[FLOAT[4, 4], FLOAT[4, 4]], output_types=[FLOAT[4, 4]] + ) + ir_model = ir.serde.deserialize_model(model_proto) + self._apply_fusion_rules(ir_model) + rewritten_model = ir.serde.serialize_model(ir_model) + self.assertEqual(["FusedMatMul"], [n.op_type for n in ir_model.graph]) + self._check_model(model_proto, rewritten_model, atol=1e-6) + + @parameterized.parameterized.expand([("should_not_match", _should_not_match)]) + def test_should_not_match(self, _, script_func): + model_proto = script_func.to_model_proto( + input_types=[FLOAT[4, 4], FLOAT[4, 4]], output_types=[FLOAT[4, 4], FLOAT[4, 4]] + ) + ir_model = ir.serde.deserialize_model(model_proto) + self._apply_fusion_rules(ir_model) + rewritten_model = ir.serde.serialize_model(ir_model) + self.assertEqual( + ["Transpose", "MatMul", "Transpose"], + [n.op_type for n in ir_model.graph], + ) + self._check_model(model_proto, rewritten_model, atol=1e-6) + + @parameterized.parameterized.expand( + [ + ( + "fused_matmul_with_identity_before_matmul", + _fused_matmul_with_identity_before_matmul, + ), + ( + "fused_matmul_with_identity_before_transpose", + _fused_matmul_with_identity_before_transpose, ), ] - return models + ) + def test_fused_matmul_with_other_node_in_middle(self, _, script_func): + model_proto = script_func.to_model_proto( + input_types=[FLOAT[4, 4]], output_types=[FLOAT[4, 4]] + ) + ir_model = ir.serde.deserialize_model(model_proto) + common_passes.ShapeInferencePass()(ir_model) + self._apply_fusion_rules(ir_model) + rewritten_model = ir.serde.serialize_model(ir_model) + self.assertEqual(["Identity", "FusedMatMul"], [n.op_type for n in ir_model.graph]) + self._check_model(model_proto, rewritten_model, atol=1e-6) - def test_ort_rule_set_transpose_fused_matmul_div(self): - rule_set = fused_matmul_rule_sets.fused_matmul_rule_sets() - for model_proto in self._transposed_fused_matmul_div_models(): - ir_model = ir.serde.deserialize_model(model_proto) - rule_set.apply_to_model(ir_model) - rewritten_model = ir.serde.serialize_model(ir_model) - - self.assertEqual(["FusedMatMul"], [n.op_type for n in rewritten_model.graph.node]) - self._check_model(model_proto, rewritten_model, atol=1e-6) - - @classmethod - def _should_not_match(cls): - models = [ - onnx.helper.make_model( - onnx.helper.make_graph( - [ - onnx.helper.make_node("Transpose", ["X"], ["Xt"], perm=[1, 0]), - onnx.helper.make_node("MatMul", ["Xt", "Y"], ["Z"]), - onnx.helper.make_node("Transpose", ["Xt"], ["W"], perm=[1, 0]), - ], - "name", - [ - onnx.helper.make_tensor_value_info("X", FLOAT, [4, 4]), - onnx.helper.make_tensor_value_info("Y", FLOAT, [4, 4]), - ], - [ - onnx.helper.make_tensor_value_info("Z", FLOAT, [None, None]), - onnx.helper.make_tensor_value_info("W", FLOAT, [None, None]), - ], - ), - opset_imports=[ - onnx.helper.make_opsetid("", 18), - onnx.helper.make_opsetid("com.microsoft", 1), - ], + @parameterized.parameterized.expand( + [ + ( + "transpose_fused_matmul_flip_transBatchA_0_and_transA", + _transpose_fused_matmul_flip_transBatchA_0_and_transA, ), + ( + "transpose_fused_matmul_flip_transBatchA_1_and_transA", + _transpose_fused_matmul_flip_transBatchA_1_and_transA, + ), + ( + "transpose_fused_matmul_flip_transBatchA_0", + _transpose_fused_matmul_flip_transBatchA_0, + ), + ( + "transpose_fused_matmul_flip_transBatchA_1", + _transpose_fused_matmul_flip_transBatchA_1, + ), + 
("transpose_fused_matmul_flip_transA", _transpose_fused_matmul_flip_transA), + ( + "transpose_fused_matmul_flip_transBatchB_0_and_transB", + _transpose_fused_matmul_flip_transBatchB_0_and_transB, + ), + ( + "transpose_fused_matmul_flip_transBatchB_1_and_transB", + _transpose_fused_matmul_flip_transBatchB_1_and_transB, + ), + ( + "transpose_fused_matmul_flip_transBatchB_0", + _transpose_fused_matmul_flip_transBatchB_0, + ), + ( + "transpose_fused_matmul_flip_transBatchB_1", + _transpose_fused_matmul_flip_transBatchB_1, + ), + ("transpose_fused_matmul_flip_transB", _transpose_fused_matmul_flip_transB), ] - return models - - def test_should_not_match(self): - for model_proto in self._should_not_match(): - ir_model = ir.serde.deserialize_model(model_proto) - rule_set = fused_matmul_rule_sets.fused_matmul_rule_sets() - rule_set.apply_to_model(ir_model) - rewritten_model = ir.serde.serialize_model(ir_model) - - self.assertEqual( - ["Transpose", "MatMul", "Transpose"], - [n.op_type for n in rewritten_model.graph.node], - ) - self._check_model(model_proto, rewritten_model, atol=1e-6) + ) + def test_transpose_fused_matmul_with_batch(self, _, script_func): + model_proto = script_func.to_model_proto( + input_types=[FLOAT[4, 4, 4, 4], FLOAT[4, 4, 4, 4]], + output_types=[FLOAT[4, 4, 4, 4]], + ) + ir_model = ir.serde.deserialize_model(model_proto) + self._apply_fusion_rules(ir_model) + rewritten_model = ir.serde.serialize_model(ir_model) + self.assertEqual(["FusedMatMul"], [n.op_type for n in ir_model.graph]) + self._check_model(model_proto, rewritten_model, atol=1e-6) if __name__ == "__main__":