Skip to content

Commit 2da3798

Browse files
authored
[Relay][Frontend][Torch] add aten:broadcast_to (#16319)
Recently, I worked with the Stable Video Diffusion model, which contains the `aten::broadcast_to` op, but TVM does not support it. Add support for it here.
1 parent 506eff2 commit 2da3798

File tree

2 files changed

+41
-0
lines changed

2 files changed

+41
-0
lines changed

python/tvm/relay/frontend/pytorch.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2328,6 +2328,21 @@ def broadcast_tensors(self, inputs, input_types):
23282328
res_shape = list(torch.broadcast_tensors(*map(torch.empty, infer_shape_value))[0].shape)
23292329
return [_op.broadcast_to(tensor, res_shape) for tensor in tensor_list]
23302330

2331+
def broadcast_to(self, inputs, input_types):
2332+
tensor = inputs[0]
2333+
new_shape = inputs[1]
2334+
import torch
2335+
2336+
if not isinstance(new_shape, (list, tuple, torch.Size)):
2337+
msg = f"Data type {type(new_shape)} could not be parsed in broadcast_to op"
2338+
raise AssertionError(msg)
2339+
2340+
for i, dim in enumerate(new_shape):
2341+
if not isinstance(dim, int):
2342+
new_shape[i] = int(_infer_value(dim, {}).numpy())
2343+
2344+
return _op.broadcast_to(tensor, new_shape)
2345+
23312346
def Bool(self, inputs, input_types):
23322347
assert len(inputs) == 1
23332348
return inputs[0]
@@ -4190,6 +4205,7 @@ def create_convert_map(self):
41904205
"aten::upsample_nearest3d": self.make_upsample3d("nearest_neighbor"),
41914206
"aten::expand_as": self.expand_as,
41924207
"aten::broadcast_tensors": self.broadcast_tensors,
4208+
"aten::broadcast_to": self.broadcast_to,
41934209
"aten::lt": self.make_elemwise("less"),
41944210
"aten::gt": self.make_elemwise("greater"),
41954211
"aten::le": self.make_elemwise("less_equal"),

tests/python/frontend/pytorch/test_forward.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2162,6 +2162,31 @@ def forward(self, x, y, z):
21622162
verify_model(BroadCastTensors2().float().eval(), input_data=[x, y, z])
21632163

21642164

2165+
@tvm.testing.uses_gpu
2166+
def test_forward_broadcast_to():
2167+
"""test_forward_broadcast_to"""
2168+
torch.set_grad_enabled(False)
2169+
2170+
class BroadCastTo1(Module):
2171+
def forward(self, x):
2172+
return torch.broadcast_to(x, (3, 3))
2173+
2174+
x = torch.tensor([1, 2, 3])
2175+
verify_model(BroadCastTo1().float().eval(), input_data=[x])
2176+
2177+
class BroadCastTo2(Module):
2178+
def __init__(self):
2179+
super().__init__()
2180+
self.y = torch.tensor(1)
2181+
self.z = torch.tensor(2)
2182+
2183+
def forward(self, x):
2184+
return torch.broadcast_to(x, (self.y + self.z, 3))
2185+
2186+
x = torch.tensor([1, 2, 3])
2187+
verify_model(BroadCastTo2().float().eval(), input_data=[x])
2188+
2189+
21652190
@tvm.testing.uses_gpu
21662191
def test_forward_pow():
21672192
"""test_forward_pow"""

0 commit comments

Comments
 (0)