14 changes: 8 additions & 6 deletions python/tvm/relay/frontend/pytorch.py
@@ -972,19 +972,21 @@ def convolution(self, inputs, input_types):
             msg = "Data type %s could not be parsed in conv op" % (type(weight))
             raise AssertionError(msg)

-        # Transposed convolutions have IOHW layout.
-        if use_transpose:
-            weight_shape[0], weight_shape[1] = weight_shape[1], weight_shape[0]
-
-        channels = weight_shape[0]
         groups = int(inputs[8])

+        if use_transpose:
+            channels = weight_shape[1] * groups
+            in_channels = weight_shape[0]
+        else:
+            channels = weight_shape[0]
+            in_channels = weight_shape[1]
+
         # Check if this is depth wise convolution
         # We need to reshape weight so that Relay could recognize this is depth wise
         # weight_shape[1] is always in_channels // groups
         # For depthwise, in_channels == groups, so weight_shape[1] == 1
         # If groups > 1 but weight_shape[1] != 1, this is group convolution
-        if groups > 1 and weight_shape[1] == 1:
+        if groups > 1 and in_channels == 1:
             channel_multiplier = channels // groups
             new_weight_shape = (groups, channel_multiplier) + tuple(weight_shape[2:])
             weight = _op.transform.reshape(weight, new_weight_shape)
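Note: for a transposed convolution, PyTorch stores weights in IOHW layout, i.e. (in_channels, out_channels // groups, kH, kW), so the output channel count has to be reconstructed as weight_shape[1] * groups rather than read from weight_shape[0]. A minimal sketch of the patched logic, using the batch-folded shapes exercised by the new test below (values assumed for illustration only, not part of the patch):

# weight layout of a grouped conv_transpose2d: (in_channels, out_channels // groups, kH, kW)
weight_shape = [2048, 256, 3, 3]  # 4 groups of 512 input channels, 256 output channels per group
groups = 4
use_transpose = True

if use_transpose:
    channels = weight_shape[1] * groups  # 256 * 4 = 1024 output channels
    in_channels = weight_shape[0]        # 2048
else:
    channels = weight_shape[0]
    in_channels = weight_shape[1]

# depthwise reshape path is only taken when each group sees a single channel;
# here in_channels is 2048, so this stays an ordinary group convolution
is_depthwise = groups > 1 and in_channels == 1  # False

The pre-patch code swapped weight_shape[0] and weight_shape[1] and then read channels from weight_shape[0], which in this example yields 256 and drops the group factor.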
28 changes: 27 additions & 1 deletion tests/python/frontend/pytorch/test_forward.py
@@ -1067,6 +1067,32 @@ def test_forward_conv_transpose(
     verify_model(conv1d_transpose, conv1d_input_data)


+def test_forward_conv2d_transpose_group():
+    # https://github.com/apache/tvm/issues/10223
+
+    class ModulatedConvTranspose2D(torch.nn.Module):
+        def forward(self, x, w, s):
+            B, C, H, W = x.shape
+            I, O, KH, KW = w.shape
+
+            # weight is different for each input in batch (this is why we want grouped conv transpose)
+            w = w.unsqueeze(0) * s.reshape(B, 1, 1, 1, 1)
+            w = w.reshape(B * I, O, KH, KW)
+            x = x.reshape(1, B * C, H, W)
+            x = torch.nn.functional.conv_transpose2d(
+                x, w, stride=(2, 2), padding=(1, 1), output_padding=(1, 1), groups=B
+            )
+            return x.reshape(B, O, H * 2, W * 2)
+
+    b, c, h, w, k = 4, 512, 8, 16, 3
+    inputs = torch.rand(b, c, h, w)
+    weights = torch.rand(c, c // 2, k, k)
+    styles = torch.rand(b)
+
+    # cuda not supported for group > 1 conv2d_transpose
+    verify_trace_model(ModulatedConvTranspose2D().eval(), [inputs, weights, styles], ["llvm"])
+
+
 def test_forward_deform_conv():
     torch.set_grad_enabled(False)

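Note: the test folds the batch dimension into the conv groups so each sample is convolved with its own modulated kernel in a single conv_transpose2d call. A small standalone equivalence check of that trick (a sketch with assumed toy shapes and helper names, separate from the test itself):

import torch
import torch.nn.functional as F

B, C, H, W, K = 2, 4, 8, 8, 3
x = torch.rand(B, C, H, W)
w = torch.rand(C, C // 2, K, K)  # conv_transpose2d weight: (in, out // groups, kH, kW)
s = torch.rand(B)                # per-sample modulation

# grouped form: one group per batch element
wg = (w.unsqueeze(0) * s.reshape(B, 1, 1, 1, 1)).reshape(B * C, C // 2, K, K)
xg = x.reshape(1, B * C, H, W)
yg = F.conv_transpose2d(xg, wg, stride=2, padding=1, output_padding=1, groups=B)
yg = yg.reshape(B, C // 2, H * 2, W * 2)

# reference: convolve each sample separately with its own scaled weight
yr = torch.cat(
    [F.conv_transpose2d(x[i : i + 1], w * s[i], stride=2, padding=1, output_padding=1)
     for i in range(B)]
)
print(torch.allclose(yg, yr, atol=1e-5))  # True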
@@ -4115,7 +4141,7 @@ def forward(self, x):

     x = torch.rand([4, 4, 16, 32]).float()
     script_module = torch.jit.trace(List_tuple(), x, strict=False).eval()
-    mod, params = relay.frontend.from_pytorch(script_module, [("x", x.shape)])
+    relay.frontend.from_pytorch(script_module, [("x", x.shape)])


 if __name__ == "__main__":