Skip to content

Commit ecc6faf

Browse files
committed
cleanup _argmax_argmin()
1 parent eff6b68 commit ecc6faf

File tree

1 file changed

+13
-26
lines changed

1 file changed

+13
-26
lines changed

python/tvm/relax/frontend/torch/fx_translator.py

Lines changed: 13 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -900,6 +900,19 @@ def _sum(self, node: fx.Node) -> relax.Var:
900900
return self.block_builder.emit(relax.op.sum(args[0], keepdims=keepdim))
901901
return self.block_builder.emit(relax.op.sum(args[0], args[1]))
902902

903+
########## Search ##########
904+
905+
def _argmax_argmin(self, op: Callable) -> Callable:
906+
from torch import fx
907+
908+
def convert(node: fx.Node):
909+
x = self.env[node.args[0]]
910+
dim = node.args[1] if len(node.args) > 1 else node.kwargs.get("dim", None)
911+
keepdim = node.args[2] if len(node.args) > 2 else node.kwargs.get("keepdim", False)
912+
return self.block_builder.emit(op(x, dim, keepdim))
913+
914+
return convert
915+
903916
########## Creation ##########
904917

905918
def _arange(self, node: fx.Node) -> relax.Var:
@@ -1220,32 +1233,6 @@ def _inplace_masked_fill(self, node: fx.Node) -> relax.Var:
12201233
self.env[node.args[0]] = output
12211234
return output
12221235

1223-
########## Search ##########
1224-
1225-
def _argmax_argmin(self, op: Callable) -> Callable:
1226-
from torch import fx
1227-
1228-
def convert(node: fx.Node):
1229-
x = self.env[node.args[0]]
1230-
dim = None
1231-
keepdims = False
1232-
1233-
if len(node.args) > 1:
1234-
dim = node.args[1]
1235-
if len(node.args) > 2:
1236-
keepdims = node.args[2]
1237-
1238-
if "dim" in node.kwargs:
1239-
dim = node.kwargs["dim"]
1240-
if "keepdim" in node.kwargs:
1241-
keepdims = node.kwargs["keepdim"]
1242-
if "keepdims" in node.kwargs:
1243-
keepdims = node.kwargs["keepdims"]
1244-
1245-
return self.block_builder.emit(op(x, dim, keepdims))
1246-
1247-
return convert
1248-
12491236
########## Neural Network ##########
12501237

12511238
def _softmax(self, node: fx.Node) -> relax.Var:

0 commit comments

Comments
 (0)