@@ -350,14 +350,28 @@ def is_slice_view(self, node: torch.fx.Node) -> bool:
350350 def is_cat_along_outermost_dim (
351351 self , graph_module : torch .fx .GraphModule , cat_node : torch .fx .Node
352352 ) -> bool :
353+ assert len (cat_node .args ) > 0
354+ cat_tensors = cat_node .args [0 ]
355+ if not isinstance (cat_tensors , Sequence ) or not all (
356+ isinstance (t , torch .fx .Node ) for t in cat_tensors
357+ ):
358+ raise ValueError ("cat_tensors must be a sequence of torch.fx.Node objects." )
359+
360+ if len (cat_node .args ) > 1 :
361+ cat_dim = cat_node .args [1 ]
362+ else :
363+ cat_dim = cat_node .kwargs .get ("dim" , None )
364+ if not isinstance (cat_dim , int ):
365+ raise ValueError ("cat_dim must be an integer." )
366+
353367 # If the cat op has default dim, then the concat dim is 0
354- if len (cat_node . args ) == 1 or cat_node . args [ 1 ] == 0 :
368+ if len (cat_tensors ) == 1 or cat_dim == 0 :
355369 return True
356- # Get the concatenation dimension and concatenated tensors
357- (cat_tensors , cat_dim ) = cast (
358- tuple [Sequence [torch .fx .Node ], int ], cat_node .args
359- )
370+
371+ # Make sure all dimes before cat_dim are 1.
360372 for tensor in cat_tensors :
373+ if not isinstance (tensor , torch .fx .Node ):
374+ continue
361375 shape = get_shape (graph_module , tensor )
362376 if shape is None or not all (dim == 1 for dim in shape [0 :cat_dim ]):
363377 return False
0 commit comments