diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
index 47eb66621008..f7d54a6216a7 100644
--- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
+++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
@@ -1628,10 +1628,12 @@ def _std(self, node: fx.Node) -> relax.Var:
 
     def _sum(self, node: fx.Node) -> relax.Var:
         args = self.retrieve_args(node)
-        keepdim = node.kwargs["keepdim"] if "keepdim" in node.kwargs else False
-        if len(args) == 1:
-            return self.block_builder.emit(relax.op.sum(args[0], keepdims=keepdim))
-        return self.block_builder.emit(relax.op.sum(args[0], args[1]))
+        x = args[0]
+        dim = args[1] if len(node.args) > 1 else node.kwargs.get("dim", None)
+        if isinstance(dim, (list, tuple)) and len(dim) == 0:
+            dim = None
+        keepdim = args[2] if len(node.args) > 2 else node.kwargs.get("keepdim", False)
+        return self.block_builder.emit(relax.op.sum(x, dim, keepdims=keepdim))
 
     def _var(self, node: fx.Node) -> relax.Var:
         args = self.retrieve_args(node)
diff --git a/tests/python/relax/test_frontend_from_exported_program.py b/tests/python/relax/test_frontend_from_exported_program.py
index 01e16e7564ac..4a84b50cc9d9 100644
--- a/tests/python/relax/test_frontend_from_exported_program.py
+++ b/tests/python/relax/test_frontend_from_exported_program.py
@@ -4945,6 +4945,14 @@ class Sum(Module):
         def forward(self, x):
             return torch.sum(x, (2, 1))
 
+    class SumKeepDim(Module):
+        def forward(self, x):
+            return torch.sum(x, (2, 1), keepdim=True)
+
+    class SumWithoutDim(Module):
+        def forward(self, x):
+            return torch.sum(x)
+
     @tvm.script.ir_module
     class expected1:
         @R.function
@@ -4958,8 +4966,36 @@ def main(
             R.output(gv)
             return gv
 
+    @tvm.script.ir_module
+    class expected2:
+        @R.function
+        def main(
+            inp_0: R.Tensor((1, 2, 3, 4), dtype="float32")
+        ) -> R.Tuple(R.Tensor((1, 1, 1, 4), dtype="float32")):
+            with R.dataflow():
+                lv: R.Tensor((1, 1, 1, 4), dtype="float32") = R.sum(
+                    inp_0, axis=[2, 1], keepdims=True
+                )
+                gv: R.Tuple(R.Tensor((1, 1, 1, 4), dtype="float32")) = (lv,)
+                R.output(gv)
+            return gv
+
+    @tvm.script.ir_module
+    class expected3:
+        @R.function
+        def main(
+            inp_0: R.Tensor((1, 2, 3, 4), dtype="float32")
+        ) -> R.Tuple(R.Tensor((), dtype="float32")):
+            with R.dataflow():
+                lv: R.Tensor((), dtype="float32") = R.sum(inp_0, axis=None, keepdims=False)
+                gv: R.Tuple(R.Tensor((), dtype="float32")) = (lv,)
+                R.output(gv)
+            return gv
+
     example_args = (torch.randn(1, 2, 3, 4, dtype=torch.float32),)
     verify_model(Sum(), example_args, {}, expected1)
+    verify_model(SumKeepDim(), example_args, {}, expected2)
+    verify_model(SumWithoutDim(), example_args, {}, expected3)
 
 
 def test_argmax_argmin():
@@ -7840,7 +7876,7 @@ def forward(self, x):
     @tvm.script.ir_module
     class Expected1:
         @R.function
-        def main(x: R.Tensor((4, 3), dtype="float32")) -> R.Tuple(R.Tensor((4,), dtype="float32")):
+        def main(x: R.Tensor((4, 3), dtype="float32")) -> R.Tuple(R.Tensor((), dtype="float32")):
             with R.dataflow():
                 lv: R.Tensor((4, 3), dtype="float32") = R.astype(x, dtype="float32")
                 lv1: R.Tensor((4, 3), dtype="float32") = R.nn.log_softmax(lv, axis=1)
@@ -7863,11 +7899,11 @@ def main(x: R.Tensor((4, 3), dtype="float32")) -> R.Tuple(R.Tensor((4,), dtype="
                 lv12: R.Tensor((4,), dtype="bool") = R.not_equal(
                     R.const([0, 1, 2, 1], dtype="int64"), R.const(-100, "int64")
                 )
-                lv13: R.Tensor((4,), dtype="bool") = R.sum(lv12, axis=[], keepdims=False)
-                lv14: R.Tensor((4,), dtype="float32") = R.astype(lv13, dtype="float32")
-                lv15: R.Tensor((4,), dtype="float32") = R.sum(lv11, axis=[], keepdims=False)
-                lv16: R.Tensor((4,), dtype="float32") = R.divide(lv15, lv14)
-                gv: R.Tuple(R.Tensor((4,), dtype="float32")) = (lv16,)
+                lv13: R.Tensor((), dtype="bool") = R.sum(lv12, axis=None, keepdims=False)
+                lv14: R.Tensor((), dtype="float32") = R.astype(lv13, dtype="float32")
+                lv15: R.Tensor((), dtype="float32") = R.sum(lv11, axis=None, keepdims=False)
+                lv16: R.Tensor((), dtype="float32") = R.divide(lv15, lv14)
+                gv: R.Tuple(R.Tensor((), dtype="float32")) = (lv16,)
                 R.output(gv)
             return gv
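
For context, a minimal PyTorch-only sketch (plain torch, no TVM; not part of the patch) of the three call shapes the updated _sum handler now distinguishes. The (1, 2, 3, 4) input mirrors example_args in the test above. An empty dim list, which torch.export can emit for full reductions (e.g. inside the nll_loss lowering updated above), is folded to dim=None, so the result is a 0-d scalar of shape ():

    import torch

    x = torch.randn(1, 2, 3, 4)

    # dim passed positionally: reduce dims 2 and 1
    print(torch.sum(x, (2, 1)).shape)                # torch.Size([1, 4])
    # keepdim=True keeps the reduced axes as size-1 dims
    print(torch.sum(x, (2, 1), keepdim=True).shape)  # torch.Size([1, 1, 1, 4])
    # no dim (or dim=[] after export): reduce everything to a scalar
    print(torch.sum(x).shape)                        # torch.Size([])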