diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py index 61478219908c..d534166481e5 100644 --- a/python/tvm/relay/frontend/pytorch.py +++ b/python/tvm/relay/frontend/pytorch.py @@ -972,19 +972,21 @@ def convolution(self, inputs, input_types): msg = "Data type %s could not be parsed in conv op" % (type(weight)) raise AssertionError(msg) - # Transposed convolutions have IOHW layout. - if use_transpose: - weight_shape[0], weight_shape[1] = weight_shape[1], weight_shape[0] - - channels = weight_shape[0] groups = int(inputs[8]) + if use_transpose: + channels = weight_shape[1] * groups + in_channels = weight_shape[0] + else: + channels = weight_shape[0] + in_channels = weight_shape[1] + # Check if this is depth wise convolution # We need to reshape weight so that Relay could recognize this is depth wise # weight_shape[1] is always in_channels // groups # For depthwise, in_channels == groups, so weight_shape[1] == 1 # If groups > 1 but weight_shape[1] != 1, this is group convolution - if groups > 1 and weight_shape[1] == 1: + if groups > 1 and in_channels == 1: channel_multiplier = channels // groups new_weight_shape = (groups, channel_multiplier) + tuple(weight_shape[2:]) weight = _op.transform.reshape(weight, new_weight_shape) diff --git a/tests/python/frontend/pytorch/test_forward.py b/tests/python/frontend/pytorch/test_forward.py index 97ef08f7b8a9..c240a19c9730 100644 --- a/tests/python/frontend/pytorch/test_forward.py +++ b/tests/python/frontend/pytorch/test_forward.py @@ -1067,6 +1067,32 @@ def test_forward_conv_transpose( verify_model(conv1d_transpose, conv1d_input_data) +def test_forward_conv2d_transpose_group(): + # https://github.com/apache/tvm/issues/10223 + + class ModulatedConvTranspose2D(torch.nn.Module): + def forward(self, x, w, s): + B, C, H, W = x.shape + I, O, KH, KW = w.shape + + # weight is different for each input in batch (this is why we want grouped conv transpose) + w = w.unsqueeze(0) * s.reshape(B, 1, 1, 1, 1) + w = w.reshape(B * I, O, KH, KW) + x = x.reshape(1, B * C, H, W) + x = torch.nn.functional.conv_transpose2d( + x, w, stride=(2, 2), padding=(1, 1), output_padding=(1, 1), groups=B + ) + return x.reshape(B, O, H * 2, W * 2) + + b, c, h, w, k = 4, 512, 8, 16, 3 + inputs = torch.rand(b, c, h, w) + weights = torch.rand(c, c // 2, k, k) + styles = torch.rand(b) + + # cuda not supported for group > 1 conv2d_transpose + verify_trace_model(ModulatedConvTranspose2D().eval(), [inputs, weights, styles], ["llvm"]) + + def test_forward_deform_conv(): torch.set_grad_enabled(False) @@ -4115,7 +4141,7 @@ def forward(self, x): x = torch.rand([4, 4, 16, 32]).float() script_module = torch.jit.trace(List_tuple(), x, strict=False).eval() - mod, params = relay.frontend.from_pytorch(script_module, [("x", x.shape)]) + relay.frontend.from_pytorch(script_module, [("x", x.shape)]) if __name__ == "__main__":