@@ -730,6 +730,51 @@ def convert(node: fx.Node):
 
     ########## Manipulation ##########
 
+    def _cat(self, node: fx.Node) -> relax.Var:
+        args = self.retrieve_args(node)
+        axis = args[1] if len(node.args) > 1 else node.kwargs.get("dim", 0)
+        return self.block_builder.emit(relax.op.concat(args[0], axis=axis))
+
+    def _cumsum(self, node: fx.Node) -> relax.Var:
+        x = self.env[node.args[0]]
+
+        dim = node.args[1] if len(node.args) > 1 else node.kwargs.get("dim", None)
+        if "dtype" in node.kwargs:
+            dtype = self._convert_data_type(str(node.kwargs["dtype"]), self.env)
+        else:
+            dtype = None
+        if "out" in node.kwargs:
+            raise ValueError("specifying out for cumsum is not supported yet")
+
+        return self.block_builder.emit(relax.op.cumsum(x, dim, dtype))
+
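+    # torch.Tensor.expand: a size of -1 keeps the corresponding input dimension,
+    # so it is replaced with the known extent before lowering to broadcast_to.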
+    def _expand(self, node: fx.Node) -> relax.Var:
+        args = self.retrieve_args(node)
+        sizes = args[1:] if len(args) > 2 else args[1]
+        broadcast_shape, in_shape = [], self.shape_of(args[0])
+        for idx, i in enumerate(sizes):
+            if isinstance(i, int) and i == -1:
+                broadcast_shape.append(in_shape[idx])
+            else:
+                broadcast_shape.append(i)
+        return self.block_builder.emit(relax.op.broadcast_to(args[0], broadcast_shape))
+
+    def _permute(self, node: fx.Node) -> relax.Var:
+        import torch  # type: ignore
+
+        args = self.retrieve_args(node)
+        x = args[0]
+        dims = args[1] if isinstance(args[1], (torch.Size, tuple, list)) else args[1:]
+        return self.block_builder.emit(relax.op.permute_dims(x, dims))
+
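+    # torch.Tensor.repeat tiles the tensor along each dimension, which maps
+    # directly onto relax.op.tile.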
+    def _repeat(self, node: fx.Node) -> relax.Var:
+        import torch  # type: ignore
+
+        args = self.retrieve_args(node)
+        x = args[0]
+        dims = args[1] if isinstance(args[1], (torch.Size, tuple, list)) else args[1:]
+        return self.block_builder.emit(relax.op.tile(x, dims))
+
     def _reshape(self, node: fx.Node) -> relax.Var:
         import torch  # type: ignore
 
@@ -738,6 +783,122 @@ def _reshape(self, node: fx.Node) -> relax.Var:
         dims = args[1] if isinstance(args[1], (torch.Size, tuple, list)) else args[1:]
         return self.block_builder.emit(relax.op.reshape(x, dims))
 
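+    # torch.split takes either a chunk size (int) or a list of section sizes;
+    # relax.op.split expects either the number of equal sections or the split
+    # indices, so section sizes are converted to cumulative offsets below.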
+    def _split(self, node: fx.Node) -> relax.Var:
+        x = self.env[node.args[0]]
+        split_size = node.args[1]
+        dim = node.args[2] if len(node.args) > 2 else node.kwargs.get("dim", 0)
+        if isinstance(split_size, (list, tuple)):
+            n_section = []
+            for s in split_size[:-1]:
+                cum_sum = 0 if not n_section else n_section[-1]
+                n_section.append(s + cum_sum)
+        else:
+            n_section = (self.shape_of(x)[dim].value + split_size - 1) // split_size
+        return self.block_builder.emit(relax.op.split(x, n_section, dim))
+
+    def _squeeze(self, node: fx.Node) -> relax.Var:
+        x = self.env[node.args[0]]
+        dim = node.args[1] if len(node.args) > 1 else node.kwargs.get("dim", None)
+        return self.block_builder.emit(relax.op.squeeze(x, dim))
+
+    def _tile(self, node: fx.Node) -> relax.Var:
+        import torch  # type: ignore
+
+        args = self.retrieve_args(node)
+        x = args[0]
+        dims = args[1] if isinstance(args[1], (torch.Size, tuple, list)) else args[1:]
+        return self.block_builder.emit(relax.op.tile(x, dims))
+
+    def _transpose(self, node: fx.Node) -> relax.Var:
+        args = self.retrieve_args(node)
+        full_idx = list(range(len(self.shape_of(args[0]))))
+        full_idx[args[1]], full_idx[args[2]] = full_idx[args[2]], full_idx[args[1]]
+        return self.block_builder.emit(relax.op.permute_dims(args[0], full_idx))
+
+    ########## Creation ##########
+
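+    # Cast to the requested dtype when one is given (positionally or via kwargs);
+    # otherwise the input is returned unchanged.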
+    def _to_copy(self, node: fx.Node) -> relax.Var:
+        import torch  # type: ignore
+
+        x = self.env[node.args[0]]
+        if len(node.args) == 2:
+            if isinstance(node.args[1], torch.dtype):
+                dtype = self._convert_data_type(node.args[1], self.env)
+                return self.block_builder.emit(relax.op.astype(x, dtype))
+        elif "dtype" in node.kwargs:
+            dtype = self._convert_data_type(node.kwargs["dtype"], self.env)
+            return self.block_builder.emit(relax.op.astype(x, dtype))
+        return x
+
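+    # torch.arange accepts (end), (start, end) or (start, end, step) positionally,
+    # possibly mixed with keyword arguments; the three slots are resolved below,
+    # and the dtype falls back to the torch default for float inputs, else int64.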
+    def _arange(self, node: fx.Node) -> relax.Var:
+        import torch  # type: ignore
+
+        start_end_step = [None, None, None]
+        if "start" in node.kwargs:
+            start_end_step[0] = node.kwargs["start"]
+        if "end" in node.kwargs:
+            start_end_step[1] = node.kwargs["end"]
+        if "step" in node.kwargs:
+            start_end_step[2] = node.kwargs["step"]
+
+        if len(node.args) == 1:
+            assert start_end_step[1] is None
+            start_end_step[1] = node.args[0]
+        elif len(node.args) == 2:
+            assert start_end_step[0] is None
+            assert start_end_step[1] is None
+            start_end_step[0] = node.args[0]
+            start_end_step[1] = node.args[1]
+        elif len(node.args) == 3:
+            assert start_end_step[0] is None
+            assert start_end_step[1] is None
+            assert start_end_step[2] is None
+            start_end_step[0] = node.args[0]
+            start_end_step[1] = node.args[1]
+            start_end_step[2] = node.args[2]
+
+        if start_end_step[0] is None:
+            start_end_step[0] = 0
+        if start_end_step[2] is None:
+            start_end_step[2] = 1
+
+        if "dtype" in node.kwargs:
+            dtype = self._convert_data_type(str(node.kwargs["dtype"]), self.env)
+        elif any([isinstance(x, float) for x in start_end_step]):
+            dtype = self._convert_data_type(torch.get_default_dtype())
+        else:
+            dtype = "int64"
+        start_end_step = [
+            self.env[x] if isinstance(x, torch.fx.Node) else x for x in start_end_step
+        ]
+        return self.block_builder.emit(relax.op.arange(*start_end_step, dtype=dtype))
+
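+    # torch.empty returns uninitialized memory; since the values are unspecified,
+    # it is lowered to a zero-filled tensor of the requested shape and dtype.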
+    def _empty(self, node: fx.Node) -> relax.Var:
+        dtype = self._convert_data_type(str(node.kwargs["dtype"]), self.env)
+        return self.block_builder.emit(relax.op.zeros(node.args[0], dtype))
+
+    def _fill(self, node: fx.Node) -> relax.Var:
+        args = self.retrieve_args(node)
+        x = args[0]
+        dtype = x.struct_info.dtype
+        value = args[1] if isinstance(args[1], relax.Expr) else relax.const(args[1], dtype)
+        return self.block_builder.emit(relax.op.full(x.struct_info.shape, value, dtype))
+
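+    # torch.Tensor.new_ones: build a ones-filled tensor of the given size that
+    # inherits the dtype of the source tensor.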
+    def _new_ones(self, node: fx.Node) -> relax.Var:
+        args = self.retrieve_args(node)
+        self_var = args[0]
+        size = args[1] if isinstance(args[1], (list, tuple)) else args[1:]
+        if not isinstance(size, (list, tuple)):
+            size = (size,)
+        size = relax.ShapeExpr(size)
+        return self.block_builder.emit(
+            relax.op.full(
+                size,
+                relax.const(1, self_var.struct_info.dtype),
+                self_var.struct_info.dtype,
+            )
+        )
+
     ########## Others ##########
 
     def _getitem(self, node: fx.Node) -> relax.Var: