55from typing import ClassVar
66
77import onnxscript .rewriter .pattern as orp
8+ from onnxscript import ir
9+ from onnxscript .rewriter import _ir_utils
10+
11+
12+ def _get_node (value : ir .Value , name : str ) -> ir .Node :
13+ """Get the node from the output value."""
14+ node = value .producer ()
15+ assert node is not None , f"{ name } node should not be None"
16+ return node
17+
18+
19+ def _get_kwargs (node : ir .Node ) -> dict [str , float | int ]:
20+ """Get the kwargs from the node."""
21+ kwargs = {key : val .value for key , val in node .attributes .items ()}
22+ return kwargs
823
924
1025class FusedMatMulDiv1 (orp .RewriteRuleClassBase ):
11- """Replaces ``MatMul + Div`` by FusedMatMul ."""
26+ """Replaces ``MatMul + Div`` with MatMul ."""
1227
1328 def pattern (self , op , x , y , cst ):
1429 return op .Div (op .MatMul (x , y ), cst )
@@ -29,122 +44,286 @@ def rewrite(self, op, x, y, cst):
2944
3045
3146class FusedMatMulDiv2 (orp .RewriteRuleClassBase ):
32- """Replaces ``FusedMatMul + Div`` by FusedMatMul."""
47+ """Replaces ``FusedMatMul + Div`` with FusedMatMul."""
3348
3449 def pattern (self , op , x , y , cst ):
35- return op .Div (op .FusedMatMul (x , y , _domain = "com.microsoft" ), cst )
50+ return op .Div (op .FusedMatMul (x , y , _domain = "com.microsoft" , _outputs = [ "fused" ] ), cst )
3651
37- def check (self , context , x , y , cst ) -> orp .MatchResult :
52+ def check (self , context , x , y , cst , ** _ ) -> orp .MatchResult :
3853 check_result = orp .MatchResult ()
3954 if cst .const_value is None :
4055 return check_result .fail ("Divisor is not a constant value." )
4156 if cst .const_value .numpy ().size > 1 :
4257 return check_result .fail ("Divisor is not a scalar value." )
4358 return check_result
4459
45- def rewrite (self , op , x , y , cst ):
60+ def rewrite (self , op , x , y , cst , fused : ir . Value ):
4661 value = cst .const_value .numpy ()
4762 c = float (value [0 ] if value .shape == (1 ,) else value )
48- node = list (x .uses ())[0 ][0 ] # noqa: RUF015
49-
50- kwargs = {}
51- alpha = node .attributes .get ("alpha" , None )
52- kwargs ["alpha" ] = alpha .value / c if alpha else 1.0 / c
53- for name in ["transA" , "transB" , "transBatchA" , "transBatchB" ]:
54- att = node .attributes .get (name )
55- if att :
56- kwargs [name ] = att .value
63+ fused_node = _get_node (fused , "FusedMatMul" )
64+ kwargs = _get_kwargs (fused_node )
65+ kwargs ["alpha" ] = kwargs .get ("alpha" , 1.0 ) / c
5766 return op .FusedMatMul (x , y , ** kwargs , _domain = "com.microsoft" )
5867
5968
6069class _TransposeMatMulBase (orp .RewriteRuleClassBase ):
6170 _pos : ClassVar = 1
6271
63- def check (self , context , x , y ) -> orp .MatchResult :
72+ def check (
73+ self , context , x , y , transposed : ir .Value , fused : ir .Value | None = None , ** _
74+ ) -> orp .MatchResult :
6475 check_result = orp .MatchResult ()
65- perm = list ((x if self ._pos == 1 else y ).uses ())[0 ][0 ].attributes ["perm" ].value # noqa: RUF015
66- expected_perm = list (range (len (perm )))
67- expected_perm [- 2 ], expected_perm [- 1 ] = expected_perm [- 1 ], expected_perm [- 2 ]
68- if perm != expected_perm :
69- return check_result .fail ("Permutation values for Transpose are not correct." )
76+ transposed_node = _get_node (transposed , "Transpose" )
77+ perm = transposed_node .attributes .get_ints ("perm" )
78+ if perm :
79+ # Check that last two dimensions are swapped
80+ expected_perm = list (range (len (perm )))
81+ expected_perm [- 2 ], expected_perm [- 1 ] = expected_perm [- 1 ], expected_perm [- 2 ]
82+ if perm != expected_perm :
83+ return check_result .fail ("Permutation values for Transpose are not correct." )
84+ elif (self ._pos == 1 and not _ir_utils .has_rank (x , 2 )) or (
85+ self ._pos == 2 and not _ir_utils .has_rank (y , 2 )
86+ ):
87+ # If perm is not defined, the default transpose behavior is to swap
88+ # all dimensions, which is correct for MatMul with rank = 2.
89+ return check_result .fail (
90+ "If perm is not defined, rank must be 2 for TransposeMatMul rule."
91+ )
92+ if fused :
93+ fused_node = _get_node (fused , "FusedMatMul" )
94+ trans_batch_property = "transBatchA" if self ._pos == 1 else "transBatchB"
95+ if fused_node .attributes .get_int (trans_batch_property , 0 ):
96+ return check_result .fail (
97+ "FusedMatMul with transposed batch cannot be used with op.Transpose in this rule."
98+ )
7099 return check_result
71100
72- def rewrite (self , op , x , y ):
73- node = list ((x if self ._pos == 2 else y ).uses ())[0 ][0 ] # noqa: RUF015
101+ def rewrite (self , op , x , y , fused : ir .Value | None = None , ** _ ):
74102 kwargs = {}
75- for name in ["alpha" , "transA" , "transB" , "transBatchA" , "transBatchB" ]:
76- att = node .attributes .get (name )
77- if att :
78- kwargs [name ] = att .value
79- name = "transA" if self ._pos == 1 else "transB"
80- kwargs [name ] = 1 - kwargs .get (name , 0 )
103+ if fused :
104+ fused_node = _get_node (fused , "FusedMatMul" )
105+ kwargs = _get_kwargs (fused_node )
106+ trans_name = "transA" if self ._pos == 1 else "transB"
107+ kwargs [trans_name ] = 1 - kwargs .get (trans_name , 0 )
81108 return op .FusedMatMul (x , y , ** kwargs , _domain = "com.microsoft" )
82109
83110
84111class TransposeMatMul1 (_TransposeMatMulBase ):
85- """Replaces ``Transpose + (Fused) MatMul`` by FusedMatMul."""
112+ """Replaces ``Transpose + MatMul`` with FusedMatMul."""
86113
87114 def pattern (self , op , x , y ):
88- return op .MatMul (op .Transpose (x ), y )
115+ return op .MatMul (op .Transpose (x , _outputs = [ "transposed" ] ), y )
89116
90117
91118class TransposeFusedMatMul1 (TransposeMatMul1 ):
92- """Replaces ``Transpose + (Fused)MatMul `` by FusedMatMul."""
119+ """Replaces ``Transpose + FusedMatMul `` with FusedMatMul."""
93120
94121 def pattern (self , op , x , y ):
95- return op .FusedMatMul (op .Transpose (x ), y , _domain = "com.microsoft" )
122+ return op .FusedMatMul (
123+ op .Transpose (x , _outputs = ["transposed" ]),
124+ y ,
125+ _domain = "com.microsoft" ,
126+ _outputs = ["fused" ],
127+ )
96128
97129
98130class TransposeMatMul2 (_TransposeMatMulBase ):
99- """Replaces ``Transpose + (Fused) MatMul`` by FusedMatMul."""
131+ """Replaces ``Transpose + MatMul`` with FusedMatMul."""
100132
101133 _pos : ClassVar = 2
102134
103135 def pattern (self , op , x , y ):
104- return op .MatMul (x , op .Transpose (y ))
136+ return op .MatMul (x , op .Transpose (y , _outputs = [ "transposed" ] ))
105137
106138
107139class TransposeFusedMatMul2 (TransposeMatMul2 ):
108- """Replaces ``Transpose + (Fused)MatMul `` by FusedMatMul."""
140+ """Replaces ``Transpose + FusedMatMul `` with FusedMatMul."""
109141
110142 def pattern (self , op , x , y ):
111- return op .FusedMatMul (x , op .Transpose (y ), _domain = "com.microsoft" )
143+ return op .FusedMatMul (
144+ x ,
145+ op .Transpose (y , _outputs = ["transposed" ]),
146+ _domain = "com.microsoft" ,
147+ _outputs = ["fused" ],
148+ )
149+
150+
151+ class _TransposeFusedMatMulBaseWithBatch (orp .RewriteRuleClassBase ):
152+ """Replaces ``Transpose + FusedMatMul`` with FusedMatMul, either
153+ when transBatchA or transBatchB in FusedMatMul is 1, or
154+ can be inverted based on the permutation dims of the Transpose, in
155+ contrast to the original FusedMatMul rule which assumes that
156+ transBatchA and transBatchB are always 0 before and after rewriting.
157+
158+ transBatchA = 1, transA = 0 applies a batch transpose by moving the first dimension to the second-to-last position
159+ i.e., equivalent to a Transpose with "perm" [1, 2, ..., N-2, 0, N-1].
160+ transBatchA = 0, transA = 1 flips the last two dimensions
161+ i.e., equivalent to a Transpose with "perm" [0, 1, ... N-3, N-1, N-2].
162+ transBatchA = 1, transA = 1 applies a batch transpose, then flips the last two dimensions
163+ i.e., equivalent to a Transpose with "perm" [1, 2, ..., N-1, 0].
164+
165+ The flipping logic is based on the following cases:
166+ Case 1: transBatchA is 0, Transpose "perm" is [1, 2, ..., N-1, 0]
167+ or transBatchA is 1, Transpose "perm" is [N-1, 0, 1, ..., N-2]
168+ - Then transBatchA and transA can be flipped in FusedMatMul when rewriting.
169+ Case 2: transBatchA is 0, Transpose "perm" is [1, 2, ..., N-2, 0, N-1]
170+ or transBatchA is 1, Transpose "perm" is [N-2, 0, 1, ..., N-3, N-1]
171+ - Then transBatchA can be flipped in FusedMatMul when rewriting.
172+ Case 3: transBatchA is 1, Transpose "perm" is [N-1, 1, ..., N-2, 0]
173+ - Then transA can be flipped in FusedMatMul when rewriting.
174+ The same logic applies for transBatchB and transB, when _pos is set to 2.
175+ The _flip_transpose_batch and _flip_transpose flags are used to control
176+ which case is applied by the rules of inheriting classes that change these class vars.
177+ """
178+
179+ _pos : ClassVar = 1
180+ _flip_transpose_batch : ClassVar = False
181+ _flip_transpose : ClassVar = False
182+
183+ def check (
184+ self , context , x , y , transposed : ir .Value , fused : ir .Value , ** _
185+ ) -> orp .MatchResult :
186+ check_result = orp .MatchResult ()
187+ fused_node = _get_node (fused , "FusedMatMul" )
188+ trans_batch_property = "transBatchA" if self ._pos == 1 else "transBatchB"
189+ trans_batch = fused_node .attributes .get_int (trans_batch_property , 0 )
190+ transposed_node = _get_node (transposed , "Transpose" )
191+ perm = transposed_node .attributes ["perm" ].as_ints ()
192+ if not perm :
193+ return check_result .fail ("Permutation values for Transpose are not correct." )
194+
195+ list_perm = list (range (len (perm )))
196+ if self ._flip_transpose_batch and self ._flip_transpose :
197+ # Case 1: transBatchA/B is 0, Transpose "perm" is [1, 2, ..., N-1, 0]
198+ # or transBatchA/B is 1, Transpose "perm" is [N-1, 0, 1, ..., N-2]
199+ # - Then transBatchA/B and transA/B can be flipped in FusedMatMul when rewriting.
200+ if trans_batch == 0 :
201+ expected_perm = [* list_perm [1 :], list_perm [0 ]]
202+ else :
203+ expected_perm = [list_perm [- 1 ], * list_perm [0 :- 1 ]]
204+ if expected_perm == perm :
205+ return check_result
206+ elif self ._flip_transpose_batch :
207+ # Case 2: transBatchA/B is 0, Transpose "perm" is [1, 2, ..., N-2, 0, N-1]
208+ # or transBatchA/B is 1, Transpose "perm" is [N-2, 0, 1, ..., N-3, N-1]
209+ # - Then transBatchA/B can be flipped in FusedMatMul when rewriting.
210+ if trans_batch == 0 :
211+ expected_perm = [* list_perm [1 :- 1 ], list_perm [0 ], list_perm [- 1 ]]
212+ else :
213+ expected_perm = [list_perm [- 2 ], * list_perm [0 :- 2 ], list_perm [- 1 ]]
214+ if expected_perm == perm :
215+ return check_result
216+ elif self ._flip_transpose :
217+ # Case 3: transBatchA is 1, Transpose "perm" is [N-1, 1, ..., N-2, 0]
218+ # - Then transA can be flipped in FusedMatMul when rewriting.
219+ expected_perm = [list_perm [- 1 ], * list_perm [1 :- 1 ], list_perm [0 ]]
220+ if expected_perm == perm and trans_batch == 1 :
221+ return check_result
222+
223+ return check_result .fail ("Permutation values for Transpose are not correct." )
224+
225+ def rewrite (self , op , x , y , fused : ir .Value , ** _ ):
226+ kwargs = {}
227+ fused_node = _get_node (fused , "FusedMatMul" )
228+ kwargs = _get_kwargs (fused_node )
229+ name = "A" if self ._pos == 1 else "B"
230+ if self ._flip_transpose_batch :
231+ trans_batch_property = f"transBatch{ name } "
232+ kwargs [trans_batch_property ] = 1 - kwargs .get (trans_batch_property , 0 )
233+ if self ._flip_transpose :
234+ trans_property = f"trans{ name } "
235+ kwargs [trans_property ] = 1 - kwargs .get (trans_property , 0 )
236+ return op .FusedMatMul (x , y , ** kwargs , _domain = "com.microsoft" )
237+
238+ def pattern (self , op , x , y ):
239+ if self ._pos == 1 :
240+ return op .FusedMatMul (
241+ op .Transpose (x , _outputs = ["transposed" ]),
242+ y ,
243+ _domain = "com.microsoft" ,
244+ _outputs = ["fused" ],
245+ )
246+ else :
247+ return op .FusedMatMul (
248+ x ,
249+ op .Transpose (y , _outputs = ["transposed" ]),
250+ _domain = "com.microsoft" ,
251+ _outputs = ["fused" ],
252+ )
253+
254+
255+ class TransposeFusedMatMulWithFlippedBatchAndTranspose1 (_TransposeFusedMatMulBaseWithBatch ):
256+ _flip_transpose = True
257+ _flip_transpose_batch = True
258+
259+
260+ class TransposeFusedMatMulWithFlippedBatchAndTranspose2 (_TransposeFusedMatMulBaseWithBatch ):
261+ _pos = 2
262+ _flip_transpose = True
263+ _flip_transpose_batch = True
264+
265+
266+ class TransposeFusedMatMulWithFlippedBatch1 (_TransposeFusedMatMulBaseWithBatch ):
267+ _flip_transpose_batch = True
268+
269+
270+ class TransposeFusedMatMulWithFlippedBatch2 (_TransposeFusedMatMulBaseWithBatch ):
271+ _pos = 2
272+ _flip_transpose_batch = True
273+
274+
275+ class TransposeFusedMatMulWithBatchAndTranspose1 (_TransposeFusedMatMulBaseWithBatch ):
276+ _flip_transpose = True
277+
278+
279+ class TransposeFusedMatMulWithBatchAndTranspose2 (_TransposeFusedMatMulBaseWithBatch ):
280+ _pos = 2
281+ _flip_transpose = True
112282
113283
114284class MatMulTranspose (orp .RewriteRuleClassBase ):
115- """Replaces ``MatMul + Transpose`` by FusedMatMul."""
285+ """Replaces ``MatMul + Transpose`` with FusedMatMul."""
116286
117287 def pattern (self , op , x , y ):
118- return op .Transpose (op .MatMul (x , y ))
288+ return op .Transpose (op .MatMul (x , y ), _outputs = [ "transposed" ] )
119289
120- def check (self , context , x , y ) -> orp .MatchResult :
290+ def check (self , context , x , y , transposed : ir . Value , ** _ ) -> orp .MatchResult :
121291 check_result = orp .MatchResult ()
122- matmul = list (x .uses ())[0 ][0 ] # noqa: RUF015
123- transpose = list (matmul .outputs [0 ].uses ())[0 ][0 ] # noqa: RUF015
124- perm = transpose .attributes ["perm" ].value
125- expected_perm = list (range (len (perm )))
126- expected_perm [- 2 ], expected_perm [- 1 ] = expected_perm [- 1 ], expected_perm [- 2 ]
127- if perm != expected_perm :
128- return check_result .fail ("Permutation values for Transpose are not correct." )
292+ transpose_node = _get_node (transposed , "Transpose" )
293+ perm = transpose_node .attributes .get_ints ("perm" )
294+ # transA/transB only work on the last two dimensions of the input,
295+ # so we can only apply this rule if the inputs are rank 2.
296+ if _ir_utils .has_rank (x , 2 ) and _ir_utils .has_rank (y , 2 ):
297+ if perm :
298+ # Check that the two dimensions are swapped
299+ if perm != [1 , 0 ]:
300+ return check_result .fail (
301+ "Permutation values for Transpose are not correct."
302+ )
303+ # If perm is not defined, the default transpose behavior is to swap
304+ # all dimensions, which is correct for MatMul with rank = 2.
305+ else :
306+ return check_result .fail ("Rank must be 2 for MatMulTranspose rule." )
129307 return check_result
130308
131- def rewrite (self , op , x , y ):
132- node = list (x .uses ())[0 ][0 ] # noqa: RUF015
309+ def rewrite (self , op , x , y , fused : ir .Value | None = None , ** _ ):
133310 kwargs = {}
134- for name in ["alpha" , "transA" , "transB" , "transBatchA" , "transBatchB" ]:
135- att = node .attributes .get (name )
136- if att :
137- kwargs [name ] = att .value
311+ if fused :
312+ fused_node = _get_node (fused , "FusedMatMul" )
313+ kwargs = _get_kwargs (fused_node )
138314 for name in ["transA" , "transB" ]:
139315 kwargs [name ] = 1 - kwargs .get (name , 0 )
140316 return op .FusedMatMul (y , x , ** kwargs , _domain = "com.microsoft" )
141317
142318
143319class FusedMatMulTranspose (MatMulTranspose ):
144- """Replaces ``MatMul + Transpose`` by FusedMatMul."""
320+ """Replaces ``FusedMatMul + Transpose`` with FusedMatMul."""
145321
146322 def pattern (self , op , x , y ):
147- return op .Transpose (op .FusedMatMul (x , y , _domain = "com.microsoft" ))
323+ return op .Transpose (
324+ op .FusedMatMul (x , y , _domain = "com.microsoft" , _outputs = ["fused" ]),
325+ _outputs = ["transposed" ],
326+ )
148327
149328
150329def fused_matmul_rule_sets () -> orp .RewriteRuleSet :
@@ -165,5 +344,11 @@ def fused_matmul_rule_sets() -> orp.RewriteRuleSet:
165344 TransposeFusedMatMul1 .rule (),
166345 TransposeMatMul2 .rule (),
167346 TransposeFusedMatMul2 .rule (),
347+ TransposeFusedMatMulWithFlippedBatch1 .rule (),
348+ TransposeFusedMatMulWithFlippedBatch2 .rule (),
349+ TransposeFusedMatMulWithFlippedBatchAndTranspose1 .rule (),
350+ TransposeFusedMatMulWithFlippedBatchAndTranspose2 .rule (),
351+ TransposeFusedMatMulWithBatchAndTranspose1 .rule (),
352+ TransposeFusedMatMulWithBatchAndTranspose2 .rule (),
168353 ]
169354 )
0 commit comments