Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions python/tvm/relax/frontend/torch/base_fx_graph_translator.py
Original file line number Diff line number Diff line change
Expand Up @@ -883,6 +883,14 @@ def _expand(self, node: fx.Node) -> relax.Var:
broadcast_shape.append(i)
return self.block_builder.emit(relax.op.broadcast_to(args[0], broadcast_shape))

def _expand_as(self, node: fx.Node) -> relax.Var:
args = self.retrieve_args(node)
# args[0] is the 'self' tensor
# args[1] is the 'other' tensor
data = args[0]
other_shape = self.shape_of(args[1]) # the shape of 'other'
return self.block_builder.emit(relax.op.broadcast_to(data, other_shape))

def _flip(self, node: fx.Node) -> relax.Var:
x = self.env[node.args[0]]
dims = node.args[1] if len(node.args) > 1 else node.kwargs.get("dims", None)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -298,6 +298,7 @@ def create_convert_map(
"copy_.default": self._copy_,
"cumsum.default": self._cumsum,
"expand.default": self._expand,
"expand_as.default": self._expand_as,
"permute.default": self._permute,
"repeat.default": self._repeat,
"select.int": self._select,
Expand Down
1 change: 1 addition & 0 deletions python/tvm/relax/frontend/torch/fx_translator.py
Original file line number Diff line number Diff line change
Expand Up @@ -749,6 +749,7 @@ def create_convert_map(
"contiguous": lambda node: self.env[node.args[0]],
"cumsum": self._cumsum,
"expand": self._expand,
"expand_as.default": self._expand_as,
"flatten": self._flatten,
"flip": self._flip,
"gather": self._gather,
Expand Down
47 changes: 47 additions & 0 deletions tests/python/relax/test_from_exported_to_cuda.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,53 @@ def assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module, tar
np.testing.assert_allclose(actual=actual, desired=desired, rtol=1e-5, atol=1e-5)


@tvm.testing.parametrize_targets("cuda")
def test_tensor_expand_as(target, dev):
class ExpandAs0(torch.nn.Module):
def __init__(self):
super().__init__()
self.template = torch.ones((1, 1, 1, 1))

def forward(self, x):
return self.template.expand_as(x)

class ExpandAs1(torch.nn.Module):
def __init__(self):
super().__init__()
self.template = torch.ones((2, 1, 4, 1))

def forward(self, x):
return self.template.expand_as(x)

class ExpandAs2(torch.nn.Module):
def __init__(self):
super().__init__()
self.template = torch.ones((2, 1, 1, 10))

def forward(self, x):
return self.template.expand_as(x)

class ExpandAs3(torch.nn.Module):
def __init__(self):
super().__init__()
self.template = torch.ones((2, 3, 1, 1))

def forward(self, x):
return self.template.expand_as(x)

raw_data = np.random.randn(2, 3, 4, 10).astype(np.float32)

torch_module0 = ExpandAs0().eval()
torch_module1 = ExpandAs1().eval()
torch_module2 = ExpandAs2().eval()
torch_module3 = ExpandAs3().eval()

assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module0, target, dev)
assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module1, target, dev)
assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module2, target, dev)
assert_torch_output_vs_tvm_from_exported_to_cuda(raw_data, torch_module3, target, dev)


@tvm.testing.parametrize_targets("cuda")
def test_copy_(target, dev):
class CopyTester(nn.Module):
Expand Down