55from typing import ClassVar
66
77import onnxscript .rewriter .pattern as orp
8+ from onnxscript import ir
89
910
1011class FusedMatMulDiv1 (orp .RewriteRuleClassBase ):
11- """Replaces ``MatMul + Div`` by FusedMatMul."""
12+ """Replaces ``MatMul + Div`` with FusedMatMul."""
1213
1314 def pattern (self , op , x , y , cst ):
1415 return op .Div (op .MatMul (x , y ), cst )
@@ -29,27 +30,26 @@ def rewrite(self, op, x, y, cst):
2930
3031
3132class FusedMatMulDiv2 (orp .RewriteRuleClassBase ):
32- """Replaces ``FusedMatMul + Div`` by FusedMatMul."""
33+ """Replaces ``FusedMatMul + Div`` with FusedMatMul."""
3334
3435 def pattern (self , op , x , y , cst ):
35- return op .Div (op .FusedMatMul (x , y , _domain = "com.microsoft" ), cst )
36+ return op .Div (op .FusedMatMul (x , y , _domain = "com.microsoft" , _outputs = [ "fused" ] ), cst )
3637
37- def check (self , context , x , y , cst ) -> orp .MatchResult :
38+ def check (self , context , x , y , cst , fused : ir . Value ) -> orp .MatchResult :
3839 check_result = orp .MatchResult ()
3940 if cst .const_value is None :
4041 return check_result .fail ("Divisor is not a constant value." )
4142 if cst .const_value .numpy ().size > 1 :
4243 return check_result .fail ("Divisor is not a scalar value." )
4344 return check_result
4445
45- def rewrite (self , op , x , y , cst ):
46+ def rewrite (self , op , x , y , cst , fused : ir . Value ):
4647 value = cst .const_value .numpy ()
4748 c = float (value [0 ] if value .shape == (1 ,) else value )
48- node = list ( x . uses ())[ 0 ][ 0 ] # noqa: RUF015
49+ node : ir . Node = fused . producer () # type: ignore[assignment]
4950
5051 kwargs = {}
51- alpha = node .attributes .get ("alpha" , None )
52- kwargs ["alpha" ] = alpha .value / c if alpha else 1.0 / c
52+ kwargs ["alpha" ] = node .attributes ["alpha" ].as_float () / c
5353 for name in ["transA" , "transB" , "transBatchA" , "transBatchB" ]:
5454 att = node .attributes .get (name )
5555 if att :
@@ -60,91 +60,106 @@ def rewrite(self, op, x, y, cst):
6060class _TransposeMatMulBase (orp .RewriteRuleClassBase ):
6161 _pos : ClassVar = 1
6262
63- def check (self , context , x , y ) -> orp .MatchResult :
63+ def check (self , context , x , y , transposed : ir . Value , ** _ ) -> orp .MatchResult :
6464 check_result = orp .MatchResult ()
65- perm = list ((x if self ._pos == 1 else y ).uses ())[0 ][0 ].attributes ["perm" ].value # noqa: RUF015
65+ node : ir .Node = transposed .producer () # type: ignore[assignment]
66+ perm = node .attributes ["perm" ].as_ints ()
6667 expected_perm = list (range (len (perm )))
6768 expected_perm [- 2 ], expected_perm [- 1 ] = expected_perm [- 1 ], expected_perm [- 2 ]
6869 if perm != expected_perm :
6970 return check_result .fail ("Permutation values for Transpose are not correct." )
7071 return check_result
7172
72- def rewrite (self , op , x , y ):
73- node = list ((x if self ._pos == 2 else y ).uses ())[0 ][0 ] # noqa: RUF015
73+ def rewrite (self , op , x , y , fused : ir .Value | None = None , ** _ ):
7474 kwargs = {}
75- for name in ["alpha" , "transA" , "transB" , "transBatchA" , "transBatchB" ]:
76- att = node .attributes .get (name )
77- if att :
78- kwargs [name ] = att .value
75+ if fused :
76+ node : ir .Node = fused .producer () # type: ignore[assignment]
77+ for name in ["alpha" , "transA" , "transB" , "transBatchA" , "transBatchB" ]:
78+ att = node .attributes .get (name )
79+ if att :
80+ kwargs [name ] = att .value
7981 name = "transA" if self ._pos == 1 else "transB"
8082 kwargs [name ] = 1 - kwargs .get (name , 0 )
8183 return op .FusedMatMul (x , y , ** kwargs , _domain = "com.microsoft" )
8284
8385
8486class TransposeMatMul1 (_TransposeMatMulBase ):
85- """Replaces ``Transpose + (Fused) MatMul`` by FusedMatMul."""
87+ """Replaces ``Transpose + MatMul`` with FusedMatMul."""
8688
8789 def pattern (self , op , x , y ):
88- return op .MatMul (op .Transpose (x ), y )
90+ return op .MatMul (op .Transpose (x , _outputs = [ "transposed" ] ), y )
8991
9092
9193class TransposeFusedMatMul1 (TransposeMatMul1 ):
92- """Replaces ``Transpose + (Fused)MatMul`` by FusedMatMul."""
94+ """Replaces ``Transpose + (Fused)MatMul`` with FusedMatMul."""
9395
9496 def pattern (self , op , x , y ):
95- return op .FusedMatMul (op .Transpose (x ), y , _domain = "com.microsoft" )
97+ return op .FusedMatMul (
98+ op .Transpose (x , _outputs = ["transposed" ]),
99+ y ,
100+ _domain = "com.microsoft" ,
101+ _outputs = ["fused" ],
102+ )
96103
97104
98105class TransposeMatMul2 (_TransposeMatMulBase ):
99- """Replaces ``Transpose + (Fused) MatMul`` by FusedMatMul."""
106+ """Replaces ``Transpose + MatMul`` with FusedMatMul."""
100107
101108 _pos : ClassVar = 2
102109
103110 def pattern (self , op , x , y ):
104- return op .MatMul (x , op .Transpose (y ))
111+ return op .MatMul (x , op .Transpose (y , _outputs = [ "transposed" ] ))
105112
106113
107114class TransposeFusedMatMul2 (TransposeMatMul2 ):
108- """Replaces ``Transpose + (Fused)MatMul`` by FusedMatMul."""
115+ """Replaces ``Transpose + (Fused)MatMul`` with FusedMatMul."""
109116
110117 def pattern (self , op , x , y ):
111- return op .FusedMatMul (x , op .Transpose (y ), _domain = "com.microsoft" )
118+ return op .FusedMatMul (
119+ x ,
120+ op .Transpose (y , _outputs = ["transposed" ]),
121+ _domain = "com.microsoft" ,
122+ _outputs = ["fused" ],
123+ )
112124
113125
114126class MatMulTranspose (orp .RewriteRuleClassBase ):
115- """Replaces ``MatMul + Transpose`` by FusedMatMul."""
127+ """Replaces ``MatMul + Transpose`` with FusedMatMul."""
116128
117129 def pattern (self , op , x , y ):
118- return op .Transpose (op .MatMul (x , y ))
130+ return op .Transpose (op .MatMul (x , y ), _outputs = [ "transposed" ] )
119131
120- def check (self , context , x , y ) -> orp .MatchResult :
132+ def check (self , context , x , y , transposed : ir . Value , ** _ ) -> orp .MatchResult :
121133 check_result = orp .MatchResult ()
122- matmul = list (x .uses ())[0 ][0 ] # noqa: RUF015
123- transpose = list (matmul .outputs [0 ].uses ())[0 ][0 ] # noqa: RUF015
124- perm = transpose .attributes ["perm" ].value
134+ transpose : ir .Node = transposed .producer () # type: ignore[assignment]
135+ perm = transpose .attributes ["perm" ].as_ints ()
125136 expected_perm = list (range (len (perm )))
126137 expected_perm [- 2 ], expected_perm [- 1 ] = expected_perm [- 1 ], expected_perm [- 2 ]
127138 if perm != expected_perm :
128139 return check_result .fail ("Permutation values for Transpose are not correct." )
129140 return check_result
130141
131- def rewrite (self , op , x , y ):
132- node = list (x .uses ())[0 ][0 ] # noqa: RUF015
142+ def rewrite (self , op , x , y , fused : ir .Value | None = None , ** _ ):
133143 kwargs = {}
134- for name in ["alpha" , "transA" , "transB" , "transBatchA" , "transBatchB" ]:
135- att = node .attributes .get (name )
136- if att :
137- kwargs [name ] = att .value
144+ if fused :
145+ node : ir .Node = fused .producer () # type: ignore[assignment]
146+ for name in ["alpha" , "transA" , "transB" , "transBatchA" , "transBatchB" ]:
147+ att = node .attributes .get (name )
148+ if att :
149+ kwargs [name ] = att .value
138150 for name in ["transA" , "transB" ]:
139151 kwargs [name ] = 1 - kwargs .get (name , 0 )
140152 return op .FusedMatMul (y , x , ** kwargs , _domain = "com.microsoft" )
141153
142154
143155class FusedMatMulTranspose (MatMulTranspose ):
144- """Replaces ``MatMul + Transpose`` by FusedMatMul."""
156+ """Replaces ``FusedMatMul + Transpose`` with FusedMatMul."""
145157
146158 def pattern (self , op , x , y ):
147- return op .Transpose (op .FusedMatMul (x , y , _domain = "com.microsoft" ))
159+ return op .Transpose (
160+ op .FusedMatMul (x , y , _domain = "com.microsoft" , _outputs = ["fused" ]),
161+ _outputs = ["transposed" ],
162+ )
148163
149164
150165def fused_matmul_rule_sets () -> orp .RewriteRuleSet :
0 commit comments