Skip to content

Commit 71c1d5c

Browse files
committed
Use producer syntax, simplify, rm unnecessary checks
1 parent 8540282 commit 71c1d5c

File tree

1 file changed

+53
-38
lines changed

1 file changed

+53
-38
lines changed

onnxscript/rewriter/ort_fusions/fused_matmul_rule_sets.py

Lines changed: 53 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,11 @@
55
from typing import ClassVar
66

77
import onnxscript.rewriter.pattern as orp
8+
from onnxscript import ir
89

910

1011
class FusedMatMulDiv1(orp.RewriteRuleClassBase):
11-
"""Replaces ``MatMul + Div`` by FusedMatMul."""
12+
"""Replaces ``MatMul + Div`` with FusedMatMul."""
1213

1314
def pattern(self, op, x, y, cst):
1415
return op.Div(op.MatMul(x, y), cst)
@@ -29,27 +30,26 @@ def rewrite(self, op, x, y, cst):
2930

3031

3132
class FusedMatMulDiv2(orp.RewriteRuleClassBase):
32-
"""Replaces ``FusedMatMul + Div`` by FusedMatMul."""
33+
"""Replaces ``FusedMatMul + Div`` with FusedMatMul."""
3334

3435
def pattern(self, op, x, y, cst):
35-
return op.Div(op.FusedMatMul(x, y, _domain="com.microsoft"), cst)
36+
return op.Div(op.FusedMatMul(x, y, _domain="com.microsoft", _outputs=["fused"]), cst)
3637

37-
def check(self, context, x, y, cst) -> orp.MatchResult:
38+
def check(self, context, x, y, cst, fused: ir.Value) -> orp.MatchResult:
3839
check_result = orp.MatchResult()
3940
if cst.const_value is None:
4041
return check_result.fail("Divisor is not a constant value.")
4142
if cst.const_value.numpy().size > 1:
4243
return check_result.fail("Divisor is not a scalar value.")
4344
return check_result
4445

45-
def rewrite(self, op, x, y, cst):
46+
def rewrite(self, op, x, y, cst, fused: ir.Value):
4647
value = cst.const_value.numpy()
4748
c = float(value[0] if value.shape == (1,) else value)
48-
node = list(x.uses())[0][0] # noqa: RUF015
49+
node: ir.Node = fused.producer() # type: ignore[assignment]
4950

5051
kwargs = {}
51-
alpha = node.attributes.get("alpha", None)
52-
kwargs["alpha"] = alpha.value / c if alpha else 1.0 / c
52+
kwargs["alpha"] = node.attributes["alpha"].as_float() / c
5353
for name in ["transA", "transB", "transBatchA", "transBatchB"]:
5454
att = node.attributes.get(name)
5555
if att:
@@ -60,91 +60,106 @@ def rewrite(self, op, x, y, cst):
6060
class _TransposeMatMulBase(orp.RewriteRuleClassBase):
6161
_pos: ClassVar = 1
6262

63-
def check(self, context, x, y) -> orp.MatchResult:
63+
def check(self, context, x, y, transposed: ir.Value, **_) -> orp.MatchResult:
6464
check_result = orp.MatchResult()
65-
perm = list((x if self._pos == 1 else y).uses())[0][0].attributes["perm"].value # noqa: RUF015
65+
node: ir.Node = transposed.producer() # type: ignore[assignment]
66+
perm = node.attributes["perm"].as_ints()
6667
expected_perm = list(range(len(perm)))
6768
expected_perm[-2], expected_perm[-1] = expected_perm[-1], expected_perm[-2]
6869
if perm != expected_perm:
6970
return check_result.fail("Permutation values for Transpose are not correct.")
7071
return check_result
7172

72-
def rewrite(self, op, x, y):
73-
node = list((x if self._pos == 2 else y).uses())[0][0] # noqa: RUF015
73+
def rewrite(self, op, x, y, fused: ir.Value | None = None, **_):
7474
kwargs = {}
75-
for name in ["alpha", "transA", "transB", "transBatchA", "transBatchB"]:
76-
att = node.attributes.get(name)
77-
if att:
78-
kwargs[name] = att.value
75+
if fused:
76+
node: ir.Node = fused.producer() # type: ignore[assignment]
77+
for name in ["alpha", "transA", "transB", "transBatchA", "transBatchB"]:
78+
att = node.attributes.get(name)
79+
if att:
80+
kwargs[name] = att.value
7981
name = "transA" if self._pos == 1 else "transB"
8082
kwargs[name] = 1 - kwargs.get(name, 0)
8183
return op.FusedMatMul(x, y, **kwargs, _domain="com.microsoft")
8284

8385

8486
class TransposeMatMul1(_TransposeMatMulBase):
85-
"""Replaces ``Transpose + (Fused)MatMul`` by FusedMatMul."""
87+
"""Replaces ``Transpose + MatMul`` with FusedMatMul."""
8688

8789
def pattern(self, op, x, y):
88-
return op.MatMul(op.Transpose(x), y)
90+
return op.MatMul(op.Transpose(x, _outputs=["transposed"]), y)
8991

9092

9193
class TransposeFusedMatMul1(TransposeMatMul1):
92-
"""Replaces ``Transpose + (Fused)MatMul`` by FusedMatMul."""
94+
"""Replaces ``Transpose + (Fused)MatMul`` with FusedMatMul."""
9395

9496
def pattern(self, op, x, y):
95-
return op.FusedMatMul(op.Transpose(x), y, _domain="com.microsoft")
97+
return op.FusedMatMul(
98+
op.Transpose(x, _outputs=["transposed"]),
99+
y,
100+
_domain="com.microsoft",
101+
_outputs=["fused"],
102+
)
96103

97104

98105
class TransposeMatMul2(_TransposeMatMulBase):
99-
"""Replaces ``Transpose + (Fused)MatMul`` by FusedMatMul."""
106+
"""Replaces ``Transpose + MatMul`` with FusedMatMul."""
100107

101108
_pos: ClassVar = 2
102109

103110
def pattern(self, op, x, y):
104-
return op.MatMul(x, op.Transpose(y))
111+
return op.MatMul(x, op.Transpose(y, _outputs=["transposed"]))
105112

106113

107114
class TransposeFusedMatMul2(TransposeMatMul2):
108-
"""Replaces ``Transpose + (Fused)MatMul`` by FusedMatMul."""
115+
"""Replaces ``Transpose + (Fused)MatMul`` with FusedMatMul."""
109116

110117
def pattern(self, op, x, y):
111-
return op.FusedMatMul(x, op.Transpose(y), _domain="com.microsoft")
118+
return op.FusedMatMul(
119+
x,
120+
op.Transpose(y, _outputs=["transposed"]),
121+
_domain="com.microsoft",
122+
_outputs=["fused"],
123+
)
112124

113125

114126
class MatMulTranspose(orp.RewriteRuleClassBase):
115-
"""Replaces ``MatMul + Transpose`` by FusedMatMul."""
127+
"""Replaces ``MatMul + Transpose`` with FusedMatMul."""
116128

117129
def pattern(self, op, x, y):
118-
return op.Transpose(op.MatMul(x, y))
130+
return op.Transpose(op.MatMul(x, y), _outputs=["transposed"])
119131

120-
def check(self, context, x, y) -> orp.MatchResult:
132+
def check(self, context, x, y, transposed: ir.Value, **_) -> orp.MatchResult:
121133
check_result = orp.MatchResult()
122-
matmul = list(x.uses())[0][0] # noqa: RUF015
123-
transpose = list(matmul.outputs[0].uses())[0][0] # noqa: RUF015
124-
perm = transpose.attributes["perm"].value
134+
transpose: ir.Node = transposed.producer() # type: ignore[assignment]
135+
perm = transpose.attributes["perm"].as_ints()
125136
expected_perm = list(range(len(perm)))
126137
expected_perm[-2], expected_perm[-1] = expected_perm[-1], expected_perm[-2]
127138
if perm != expected_perm:
128139
return check_result.fail("Permutation values for Transpose are not correct.")
129140
return check_result
130141

131-
def rewrite(self, op, x, y):
132-
node = list(x.uses())[0][0] # noqa: RUF015
142+
def rewrite(self, op, x, y, fused: ir.Value | None = None, **_):
133143
kwargs = {}
134-
for name in ["alpha", "transA", "transB", "transBatchA", "transBatchB"]:
135-
att = node.attributes.get(name)
136-
if att:
137-
kwargs[name] = att.value
144+
if fused:
145+
node: ir.Node = fused.producer() # type: ignore[assignment]
146+
for name in ["alpha", "transA", "transB", "transBatchA", "transBatchB"]:
147+
att = node.attributes.get(name)
148+
if att:
149+
kwargs[name] = att.value
138150
for name in ["transA", "transB"]:
139151
kwargs[name] = 1 - kwargs.get(name, 0)
140152
return op.FusedMatMul(y, x, **kwargs, _domain="com.microsoft")
141153

142154

143155
class FusedMatMulTranspose(MatMulTranspose):
144-
"""Replaces ``MatMul + Transpose`` by FusedMatMul."""
156+
"""Replaces ``FusedMatMul + Transpose`` with FusedMatMul."""
145157

146158
def pattern(self, op, x, y):
147-
return op.Transpose(op.FusedMatMul(x, y, _domain="com.microsoft"))
159+
return op.Transpose(
160+
op.FusedMatMul(x, y, _domain="com.microsoft", _outputs=["fused"]),
161+
_outputs=["transposed"],
162+
)
148163

149164

150165
def fused_matmul_rule_sets() -> orp.RewriteRuleSet:

0 commit comments

Comments
 (0)