Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 6 additions & 4 deletions python/tvm/relax/frontend/torch/base_fx_graph_translator.py
Original file line number Diff line number Diff line change
Expand Up @@ -1628,10 +1628,12 @@ def _std(self, node: fx.Node) -> relax.Var:

def _sum(self, node: fx.Node) -> relax.Var:
args = self.retrieve_args(node)
keepdim = node.kwargs["keepdim"] if "keepdim" in node.kwargs else False
if len(args) == 1:
return self.block_builder.emit(relax.op.sum(args[0], keepdims=keepdim))
return self.block_builder.emit(relax.op.sum(args[0], args[1]))
x = args[0]
dim = args[1] if len(node.args) > 1 else node.kwargs.get("dim", None)
if isinstance(dim, (list, tuple)) and len(dim) == 0:
dim = None
keepdim = args[2] if len(node.args) > 2 else node.kwargs.get("keepdim", False)
return self.block_builder.emit(relax.op.sum(x, dim, keepdims=keepdim))

def _var(self, node: fx.Node) -> relax.Var:
args = self.retrieve_args(node)
Expand Down
48 changes: 42 additions & 6 deletions tests/python/relax/test_frontend_from_exported_program.py
Original file line number Diff line number Diff line change
Expand Up @@ -4945,6 +4945,14 @@ class Sum(Module):
def forward(self, x):
return torch.sum(x, (2, 1))

class SumKeepDim(Module):
def forward(self, x):
return torch.sum(x, (2, 1), keepdim=True)

class SumWithoutDim(Module):
def forward(self, x):
return torch.sum(x)

@tvm.script.ir_module
class expected1:
@R.function
Expand All @@ -4958,8 +4966,36 @@ def main(
R.output(gv)
return gv

@tvm.script.ir_module
class expected2:
@R.function
def main(
inp_0: R.Tensor((1, 2, 3, 4), dtype="float32")
) -> R.Tuple(R.Tensor((1, 1, 1, 4), dtype="float32")):
with R.dataflow():
lv: R.Tensor((1, 1, 1, 4), dtype="float32") = R.sum(
inp_0, axis=[2, 1], keepdims=True
)
gv: R.Tuple(R.Tensor((1, 1, 1, 4), dtype="float32")) = (lv,)
R.output(gv)
return gv

@tvm.script.ir_module
class expected3:
@R.function
def main(
inp_0: R.Tensor((1, 2, 3, 4), dtype="float32")
) -> R.Tuple(R.Tensor((), dtype="float32")):
with R.dataflow():
lv: R.Tensor((), dtype="float32") = R.sum(inp_0, axis=None, keepdims=False)
gv: R.Tuple(R.Tensor((), dtype="float32")) = (lv,)
R.output(gv)
return gv

example_args = (torch.randn(1, 2, 3, 4, dtype=torch.float32),)
verify_model(Sum(), example_args, {}, expected1)
verify_model(SumKeepDim(), example_args, {}, expected2)
verify_model(SumWithoutDim(), example_args, {}, expected3)


def test_argmax_argmin():
Expand Down Expand Up @@ -7840,7 +7876,7 @@ def forward(self, x):
@tvm.script.ir_module
class Expected1:
@R.function
def main(x: R.Tensor((4, 3), dtype="float32")) -> R.Tuple(R.Tensor((4,), dtype="float32")):
def main(x: R.Tensor((4, 3), dtype="float32")) -> R.Tuple(R.Tensor((), dtype="float32")):
with R.dataflow():
lv: R.Tensor((4, 3), dtype="float32") = R.astype(x, dtype="float32")
lv1: R.Tensor((4, 3), dtype="float32") = R.nn.log_softmax(lv, axis=1)
Expand All @@ -7863,11 +7899,11 @@ def main(x: R.Tensor((4, 3), dtype="float32")) -> R.Tuple(R.Tensor((4,), dtype="
lv12: R.Tensor((4,), dtype="bool") = R.not_equal(
R.const([0, 1, 2, 1], dtype="int64"), R.const(-100, "int64")
)
lv13: R.Tensor((4,), dtype="bool") = R.sum(lv12, axis=[], keepdims=False)
lv14: R.Tensor((4,), dtype="float32") = R.astype(lv13, dtype="float32")
lv15: R.Tensor((4,), dtype="float32") = R.sum(lv11, axis=[], keepdims=False)
lv16: R.Tensor((4,), dtype="float32") = R.divide(lv15, lv14)
gv: R.Tuple(R.Tensor((4,), dtype="float32")) = (lv16,)
lv13: R.Tensor((), dtype="bool") = R.sum(lv12, axis=None, keepdims=False)
lv14: R.Tensor((), dtype="float32") = R.astype(lv13, dtype="float32")
lv15: R.Tensor((), dtype="float32") = R.sum(lv11, axis=None, keepdims=False)
lv16: R.Tensor((), dtype="float32") = R.divide(lv15, lv14)
gv: R.Tuple(R.Tensor((), dtype="float32")) = (lv16,)
R.output(gv)
return gv

Expand Down