@@ -913,6 +913,32 @@ def convert(node: fx.Node):
913913
914914 return convert
915915
916+ ########## DataType ##########
917+
918+ def _float (self , node : fx .Node ) -> relax .Var :
919+ return self .block_builder .emit (relax .op .astype (self .env [node .args [0 ]], "float32" ))
920+
921+ def _half (self , node : fx .Node ) -> relax .Var :
922+ return self .block_builder .emit (relax .op .astype (self .env [node .args [0 ]], "float16" ))
923+
924+ def _to (self , node : fx .Node ) -> relax .Var :
925+ import torch
926+
927+ x = self .env [node .args [0 ]]
928+ if len (node .args ) == 2 :
929+ if isinstance (node .args [1 ], torch .dtype ):
930+ dtype = TorchFXImporter ._convert_data_type (node .args [1 ], self .env )
931+ return self .block_builder .emit (relax .op .astype (x , dtype ))
932+ elif "dtype" in node .kwargs :
933+ dtype = TorchFXImporter ._convert_data_type (node .kwargs ["dtype" ], self .env )
934+ return self .block_builder .emit (relax .op .astype (x , dtype ))
935+ return x
936+
937+ def _type (self , node : fx .Node ) -> relax .Var :
938+ x = self .env [node .args [0 ]]
939+ dtype = TorchFXImporter ._convert_data_type (node .args [1 ], self .env )
940+ return self .block_builder .emit (relax .op .astype (x , dtype ))
941+
916942 ########## Creation ##########
917943
918944 def _arange (self , node : fx .Node ) -> relax .Var :
@@ -1051,32 +1077,6 @@ def _full(self, node: fx.Node) -> relax.Var:
10511077 )
10521078 )
10531079
1054- ########## DataType ##########
1055-
1056- def _float (self , node : fx .Node ) -> relax .Var :
1057- return self .block_builder .emit (relax .op .astype (self .env [node .args [0 ]], "float32" ))
1058-
1059- def _half (self , node : fx .Node ) -> relax .Var :
1060- return self .block_builder .emit (relax .op .astype (self .env [node .args [0 ]], "float16" ))
1061-
1062- def _type (self , node : fx .Node ) -> relax .Var :
1063- x = self .env [node .args [0 ]]
1064- dtype = TorchFXImporter ._convert_data_type (node .args [1 ], self .env )
1065- return self .block_builder .emit (relax .op .astype (x , dtype ))
1066-
1067- def _to (self , node : fx .Node ) -> relax .Var :
1068- import torch
1069-
1070- x = self .env [node .args [0 ]]
1071- if len (node .args ) == 2 :
1072- if isinstance (node .args [1 ], torch .dtype ):
1073- dtype = TorchFXImporter ._convert_data_type (node .args [1 ], self .env )
1074- return self .block_builder .emit (relax .op .astype (x , dtype ))
1075- elif "dtype" in node .kwargs :
1076- dtype = TorchFXImporter ._convert_data_type (node .kwargs ["dtype" ], self .env )
1077- return self .block_builder .emit (relax .op .astype (x , dtype ))
1078- return x
1079-
10801080 ########## Manipulation ##########
10811081
10821082 def _cat (self , node : fx .Node ) -> relax .Var :
0 commit comments