@@ -900,6 +900,19 @@ def _sum(self, node: fx.Node) -> relax.Var:
             return self.block_builder.emit(relax.op.sum(args[0], keepdims=keepdim))
         return self.block_builder.emit(relax.op.sum(args[0], args[1]))
 
+    ########## Search ##########
+
+    def _argmax_argmin(self, op: Callable) -> Callable:
+        from torch import fx
+
+        def convert(node: fx.Node):
+            x = self.env[node.args[0]]
+            dim = node.args[1] if len(node.args) > 1 else node.kwargs.get("dim", None)
+            keepdim = node.args[2] if len(node.args) > 2 else node.kwargs.get("keepdim", False)
+            return self.block_builder.emit(op(x, dim, keepdim))
+
+        return convert
+
     ########## Creation ##########
 
     def _arange(self, node: fx.Node) -> relax.Var:
@@ -1220,32 +1233,6 @@ def _inplace_masked_fill(self, node: fx.Node) -> relax.Var:
         self.env[node.args[0]] = output
         return output
 
-    ########## Search ##########
-
-    def _argmax_argmin(self, op: Callable) -> Callable:
-        from torch import fx
-
-        def convert(node: fx.Node):
-            x = self.env[node.args[0]]
-            dim = None
-            keepdims = False
-
-            if len(node.args) > 1:
-                dim = node.args[1]
-            if len(node.args) > 2:
-                keepdims = node.args[2]
-
-            if "dim" in node.kwargs:
-                dim = node.kwargs["dim"]
-            if "keepdim" in node.kwargs:
-                keepdims = node.kwargs["keepdim"]
-            if "keepdims" in node.kwargs:
-                keepdims = node.kwargs["keepdims"]
-
-            return self.block_builder.emit(op(x, dim, keepdims))
-
-        return convert
-
     ########## Neural Network ##########
 
     def _softmax(self, node: fx.Node) -> relax.Var:
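For readers following the refactor: the new `_argmax_argmin` collapses the old positional/keyword branching into two one-liners. A minimal standalone sketch of that normalization pattern is below (not part of the diff; `normalize` and the `SimpleNamespace` stand-in for `torch.fx.Node` are illustrative assumptions):

```python
# Sketch of the args/kwargs normalization the refactored helper uses:
# a positional argument wins, otherwise fall back to the kwarg (with a default).
# SimpleNamespace stands in for torch.fx.Node, which exposes .args and .kwargs.
from types import SimpleNamespace

def normalize(node):
    dim = node.args[1] if len(node.args) > 1 else node.kwargs.get("dim", None)
    keepdim = node.args[2] if len(node.args) > 2 else node.kwargs.get("keepdim", False)
    return dim, keepdim

# torch.argmax(x, 1, True) traced with positional arguments
print(normalize(SimpleNamespace(args=("x", 1, True), kwargs={})))  # -> (1, True)
# torch.argmax(x, dim=1) traced with keyword arguments
print(normalize(SimpleNamespace(args=("x",), kwargs={"dim": 1})))  # -> (1, False)
```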