44
55from typing import ClassVar , Optional , Sequence
66
7- from onnxscript .rewriter import _ir_utils
87import onnxscript .rewriter .pattern as orp
98from onnxscript import ir
9+ from onnxscript .rewriter import _ir_utils
1010
1111
1212def _get_node (value : ir .Value , name : str ) -> ir .Node :
@@ -22,15 +22,6 @@ def _get_kwargs(node: ir.Node) -> dict[str, float | int]:
2222 return kwargs
2323
2424
25- def _get_int_or_default (node : ir .Node , name : str , default : int = 0 ) -> int :
26- """Get the int value from the node attribute dictionary or return default."""
27- if name in node .attributes :
28- value = node .attributes [name ].as_int ()
29- else :
30- value = default
31- return value
32-
33-
3425def _get_ints_or_default (
3526 node : ir .Node , name : str , default : Optional [Sequence [int ]] = None
3627) -> Sequence [int ]:
@@ -103,14 +94,18 @@ def check(
10394 expected_perm [- 2 ], expected_perm [- 1 ] = expected_perm [- 1 ], expected_perm [- 2 ]
10495 if perm != expected_perm :
10596 return check_result .fail ("Permutation values for Transpose are not correct." )
106- elif not (self ._pos == 1 and _ir_utils .has_rank (x , 2 )) and (self ._pos == 2 and _ir_utils .has_rank (y , 2 )):
97+ elif (self ._pos == 1 and not _ir_utils .has_rank (x , 2 )) or (
98+ self ._pos == 2 and not _ir_utils .has_rank (y , 2 )
99+ ):
107100 # If perm is not defined, the default transpose behavior is to swap
108101 # all dimensions, which is correct for MatMul with rank = 2.
109- return check_result .fail ("Permutation values for Transpose are not correct." )
102+ return check_result .fail (
103+ "If perm is not defined, rank must be 2 for TransposeMatMul rule."
104+ )
110105 if fused :
111106 fused_node = _get_node (fused , "FusedMatMul" )
112107 trans_batch_property = "transBatchA" if self ._pos == 1 else "transBatchB"
113- if _get_int_or_default ( fused_node , trans_batch_property ):
108+ if fused_node . attributes . get_int ( trans_batch_property , 0 ):
114109 return check_result .fail (
115110 "FusedMatMul with transposed batch cannot be used with op.Transpose in this rule."
116111 )
@@ -204,7 +199,7 @@ def check(
204199 check_result = orp .MatchResult ()
205200 fused_node = _get_node (fused , "FusedMatMul" )
206201 trans_batch_property = "transBatchA" if self ._pos == 1 else "transBatchB"
207- trans_batch = _get_int_or_default ( fused_node , trans_batch_property )
202+ trans_batch = fused_node . attributes . get_int ( trans_batch_property , 0 )
208203 transposed_node = _get_node (transposed , "Transpose" )
209204 perm = transposed_node .attributes ["perm" ].as_ints ()
210205 if not perm :
@@ -312,16 +307,21 @@ def check(self, context, x, y, transposed: ir.Value, **_) -> orp.MatchResult:
312307 check_result = orp .MatchResult ()
313308 transpose_node = _get_node (transposed , "Transpose" )
314309 perm = _get_ints_or_default (transpose_node , "perm" )
315- if perm :
316- # Check that last two dimensions are swapped
317- expected_perm = list (range (len (perm )))
318- expected_perm [- 2 ], expected_perm [- 1 ] = expected_perm [- 1 ], expected_perm [- 2 ]
319- if perm != expected_perm :
320- return check_result .fail ("Permutation values for Transpose are not correct." )
321- elif not (self ._pos == 1 and _ir_utils .has_rank (x , 2 )) and (self ._pos == 2 and _ir_utils .has_rank (y , 2 )):
322- # If perm is not defined, the default transpose behavior is to swap
323- # all dimensions, which is correct for MatMul with rank = 2.
324- return check_result .fail ("Permutation values for Transpose are not correct." )
310+ # transA/transB only work on the last two dimensions of the input,
311+ # so we can only apply this rule if the inputs are rank 2.
312+ if _ir_utils .has_rank (x , 2 ) and _ir_utils .has_rank (y , 2 ):
313+ if perm :
314+ # Check that last two dimensions are swapped
315+ expected_perm = list (range (len (perm )))
316+ expected_perm [- 2 ], expected_perm [- 1 ] = expected_perm [- 1 ], expected_perm [- 2 ]
317+ if perm != expected_perm :
318+ return check_result .fail (
319+ "Permutation values for Transpose are not correct."
320+ )
321+ # If perm is not defined, the default transpose behavior is to swap
322+ # all dimensions, which is correct for MatMul with rank = 2.
323+ else :
324+ return check_result .fail ("Rank must be 2 for MatMulTranspose rule." )
325325 return check_result
326326
327327 def rewrite (self , op , x , y , fused : ir .Value | None = None , ** _ ):
0 commit comments