diff --git a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
index d5cad2381b49..a0f00e1f4b9d 100644
--- a/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
+++ b/python/tvm/relax/frontend/torch/base_fx_graph_translator.py
@@ -883,6 +883,14 @@ def _expand(self, node: fx.Node) -> relax.Var:
                 broadcast_shape.append(i)
         return self.block_builder.emit(relax.op.broadcast_to(args[0], broadcast_shape))
 
+    def _expand_as(self, node: fx.Node) -> relax.Var:
+        # Convert torch.Tensor.expand_as(other): broadcast 'self' to the
+        # shape of 'other'.  args[0] is the 'self' tensor, args[1] is 'other'.
+        args = self.retrieve_args(node)
+        data = args[0]
+        other_shape = self.shape_of(args[1])  # static shape of 'other'
+        return self.block_builder.emit(relax.op.broadcast_to(data, other_shape))
+
     def _flip(self, node: fx.Node) -> relax.Var:
         x = self.env[node.args[0]]
         dims = node.args[1] if len(node.args) > 1 else node.kwargs.get("dims", None)
diff --git a/python/tvm/relax/frontend/torch/exported_program_translator.py b/python/tvm/relax/frontend/torch/exported_program_translator.py
index 4ff31ea1d772..2103365c6c60 100644
--- a/python/tvm/relax/frontend/torch/exported_program_translator.py
+++ b/python/tvm/relax/frontend/torch/exported_program_translator.py
@@ -298,6 +298,7 @@ def create_convert_map(
             "copy_.default": self._copy_,
             "cumsum.default": self._cumsum,
             "expand.default": self._expand,
+            "expand_as.default": self._expand_as,
             "permute.default": self._permute,
             "repeat.default": self._repeat,
             "select.int": self._select,
diff --git a/python/tvm/relax/frontend/torch/fx_translator.py b/python/tvm/relax/frontend/torch/fx_translator.py
index 29d959818f21..abda5088db4d 100644
--- a/python/tvm/relax/frontend/torch/fx_translator.py
+++ b/python/tvm/relax/frontend/torch/fx_translator.py
@@ -749,6 +749,7 @@ def create_convert_map(
             "contiguous": lambda node: self.env[node.args[0]],
             "cumsum": self._cumsum,
             "expand": self._expand,
+            "expand_as": self._expand_as,
             "flatten": self._flatten,
             "flip": self._flip,
             "gather": self._gather,
diff --git a/tests/python/relax/test_from_exported_to_cuda.py b/tests/python/relax/test_from_exported_to_cuda.py
index bd4bdcf61770..e8b5da0dc2ab 100644
--- a/tests/python/relax/test_from_exported_to_cuda.py
+++ b/tests/python/relax/test_from_exported_to_cuda.py
@@ -56,6 +56,53 @@ def assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, target, dev):
     np.testing.assert_allclose(actual=actual, desired=desired, rtol=1e-5, atol=1e-5)
 
 
+@tvm.testing.parametrize_targets("cuda")
+def test_tensor_expand_as(target, dev):
+    class ExpandAs0(torch.nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.template = torch.ones((1, 1, 1, 1))
+
+        def forward(self, x):
+            return self.template.expand_as(x)
+
+    class ExpandAs1(torch.nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.template = torch.ones((2, 1, 4, 1))
+
+        def forward(self, x):
+            return self.template.expand_as(x)
+
+    class ExpandAs2(torch.nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.template = torch.ones((2, 1, 1, 10))
+
+        def forward(self, x):
+            return self.template.expand_as(x)
+
+    class ExpandAs3(torch.nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.template = torch.ones((2, 3, 1, 1))
+
+        def forward(self, x):
+            return self.template.expand_as(x)
+
+    raw_data = np.random.randn(2, 3, 4, 10).astype(np.float32)
+
+    torch_module0 = ExpandAs0().eval()
+    torch_module1 = ExpandAs1().eval()
+    torch_module2 = ExpandAs2().eval()
+    torch_module3 = ExpandAs3().eval()
+
+    assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module0, target, dev)
+    assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module1, target, dev)
+    assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module2, target, dev)
+    assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module3, target, dev)
+
+
 @tvm.testing.parametrize_targets("cuda")
 def test_copy_(target, dev):
     class CopyTester(nn.Module):