
[Bug] Conv2DTranspose with groups not working correctly #10223

@JCBrouwer

Description


I'm trying to convert a PyTorch model that uses torch.nn.functional.conv_transpose2d and am running into issues converting it to the corresponding tvm.relay.op.nn.conv2d_transpose operation.

I've done a little monkey patching on the PyTorchOpConverter, since the operations that torch.nn.functional.conv2d/conv_transpose2d trace to (aten::conv2d and aten::conv_transpose2d) aren't covered by default. I've added a converter function for each op to the PyTorchOpConverter so that I have access to self.infer_shape(weight) inside them, as follows:

Converter implementation
import tvm
from tvm import relay
from tvm.relay.frontend.pytorch import PyTorchOpConverter


class MyPyTorchOpConverter(PyTorchOpConverter):
    def __init__(self, prelude, default_dtype):
        super().__init__(prelude, default_dtype)
        self.update_convert_map(
            {"aten::conv2d": self.convert_conv2d, "aten::conv_transpose2d": self.convert_conv_transpose2d}
        )

    def convert_conv2d(self, inputs, input_types):
        data = inputs[0]
        weight = inputs[1]
        bias = inputs[2]
        strides = inputs[3]
        padding = inputs[4]
        dilation = inputs[5]
        groups = inputs[6]

        channels, input_channels, kh, kw = self.infer_shape(weight)  # OIHW

        if groups > 1 and input_channels == 1:
            channel_multiplier = channels // groups
            new_weight_shape = (groups, channel_multiplier, kh, kw)
            weight = relay.op.transform.reshape(weight, new_weight_shape)

        res = relay.op.nn.conv2d(
            data, weight, strides=strides, padding=padding, dilation=dilation, groups=groups, channels=channels
        )
        if bias is not None:
            res = relay.op.nn.bias_add(res, bias)

        return res

    def convert_conv_transpose2d(self, inputs, input_types):
        data = inputs[0]
        weight = inputs[1]
        bias = inputs[2]
        strides = inputs[3]
        padding = inputs[4]
        output_padding = inputs[5]
        groups = inputs[6]
        dilation = inputs[7]

        input_channels, channels, kh, kw = list(self.infer_shape(weight))  # IOHW

        if groups > 1 and channels == 1:
            channel_multiplier = channels // groups
            new_weight_shape = (groups, channel_multiplier, kh, kw)
            weight = relay.op.transform.reshape(weight, new_weight_shape)

        res = relay.op.nn.conv2d_transpose(
            data,
            weight,
            strides=strides,
            padding=padding,
            output_padding=output_padding,
            dilation=dilation,
            groups=groups,
        )
        if bias is not None:
            res = relay.op.nn.bias_add(res, bias)

        return res

tvm.relay.frontend.pytorch.PyTorchOpConverter = MyPyTorchOpConverter

The implementations of the converters are adapted from tvm.relay.frontend.pytorch.PyTorchOpConverter.convolution(inputs, input_types), but updated to match the call signature of torch.nn.functional.conv2d/conv_transpose2d.
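
For reference, the input indexing used above mirrors the functional signatures from the PyTorch docs; a short summary follows (the list names are just illustrative, and the assumption that the traced aten ops present their inputs in exactly this order is worth verifying against the TorchScript graph):

# Argument order the converters above assume for the traced aten ops, mirroring the
# functional signatures documented in PyTorch (this ordering is an assumption):
CONV2D_INPUTS = ["data", "weight", "bias", "stride", "padding", "dilation", "groups"]
CONV_TRANSPOSE2D_INPUTS = [
    "data", "weight", "bias", "stride", "padding", "output_padding", "groups", "dilation"
]

# The weight layouts also differ between the two ops:
#   conv2d           weight: (out_channels, in_channels // groups, kH, kW)   # OIHW
#   conv_transpose2d weight: (in_channels,  out_channels // groups, kH, kW)  # IOHW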

The problem I'm seeing is that it seems like tvm.relay.op.nn.conv2d_transpose() doesn't respect the groups argument. When I print the input and outputs of the first 4 conv(_transpose) ops in my network, the PyTorch shapes are the following:

PyTorch shapes
conv2d
input (1, 1536, 4, 4)
weight (1536, 512, 3, 3)
groups 3
out (1, 1536, 4, 4)

conv2d
input (1, 1536, 4, 4)
weight (9, 512, 1, 1)
groups 3
out (1, 9, 4, 4)

conv_transpose2d
input (1, 1536, 4, 4)
weight (1536, 512, 3, 3)
groups 3
out (1, 1536, 9, 9)

conv2d
data (1, 1536, 11, 11)
weight (1536, 1, 4, 4)
groups 1536
out (1, 1536, 8, 8)

While the TVM shapes are:

TVM shapes
conv2d
input [1, 1536, 4, 4]
weight [1536, 512, 3, 3]
groups 3
out (1, 1536, 4, 4)

conv2d
input [1, 1536, 4, 4]
weight [9, 512, 1, 1]
groups 3
out (1, 9, 4, 4)

conv2d_transpose
input [1, 1536, 4, 4]
weight [1536, 512, 3, 3]
groups 3
out (1, 512, 9, 9)

conv2d
input [1, 512, 11, 11]
weight [512, 1, 4, 4]
groups 1536
TVMError

Notice that the output shape of tvm.relay.op.nn.conv2d_transpose() does not have the correct number of channels: with groups = 3 the output should have 3 × 512 = 1536 channels, but TVM infers 512, as if groups were 1. This leads to the error in the next conv2d operation:

Error traceback
Traceback (most recent call last):
  File "/home/hans/code/stylegan3/func.py", line 217, in <module>
    Gtvm, tvm_params = relay.frontend.pytorch.from_pytorch(
  File "/home/hans/code/tvm/python/tvm/relay/frontend/pytorch.py", line 4010, in from_pytorch
    outputs = converter.convert_operators(_get_operator_nodes(graph.nodes()), outputs, ret_name)
  File "/home/hans/code/tvm/python/tvm/relay/frontend/pytorch.py", line 3385, in convert_operators
    relay_out = relay_op(
  File "/home/hans/code/stylegan3/func.py", line 101, in convert_conv2d
    print("out", self.infer_shape(res))
  File "/home/hans/code/tvm/python/tvm/relay/frontend/pytorch.py", line 204, in infer_shape
    typ = self.infer_type(inputs, mod=mod)
  File "/home/hans/code/tvm/python/tvm/relay/frontend/pytorch.py", line 162, in infer_type
    new_mod = transform.InferType()(new_mod)
  File "/home/hans/code/tvm/python/tvm/ir/transform.py", line 161, in __call__
    return _ffi_transform_api.RunPass(self, mod)
  File "/home/hans/code/tvm/python/tvm/_ffi/_ctypes/packed_func.py", line 237, in __call__
    raise get_last_ffi_error()
tvm._ffi.base.TVMError: Traceback (most recent call last):
  7: TVMFuncCall
  6: tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<tvm::runtime::TypedPackedFunc<tvm::IRModule (tvm::transform::Pass, tvm::IRModule)>::AssignTypedLambda<tvm::transform::{lambda(tvm::transform::Pass, tvm::IRModule)#7}>(tvm::transform::{lambda(tvm::transform::Pass, tvm::IRModule)#7}, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >)::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}> >::Call(tvm::runtime::PackedFuncObj const*, tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)
  5: tvm::transform::Pass::operator()(tvm::IRModule) const
  4: tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  3: tvm::transform::ModulePassNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  2: tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<tvm::runtime::TypedPackedFunc<tvm::IRModule (tvm::IRModule, tvm::transform::PassContext)>::AssignTypedLambda<tvm::relay::transform::InferType()::{lambda(tvm::IRModule, tvm::transform::PassContext const&)#1}>(tvm::relay::transform::InferType()::{lambda(tvm::IRModule, tvm::transform::PassContext const&)#1})::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}> >::Call(tvm::runtime::PackedFuncObj const*, tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)
  1: tvm::relay::TypeInferencer::Infer(tvm::GlobalVar, tvm::relay::Function)
  0: tvm::relay::TypeSolver::Solve() [clone .cold]
  9: TVMFuncCall
  8: tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<tvm::runtime::TypedPackedFunc<tvm::IRModule (tvm::transform::Pass, tvm::IRModule)>::AssignTypedLambda<tvm::transform::{lambda(tvm::transform::Pass, tvm::IRModule)#7}>(tvm::transform::{lambda(tvm::transform::Pass, tvm::IRModule)#7}, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >)::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}> >::Call(tvm::runtime::PackedFuncObj const*, tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)
  7: tvm::transform::Pass::operator()(tvm::IRModule) const
  6: tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  5: tvm::transform::ModulePassNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  4: tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<tvm::runtime::TypedPackedFunc<tvm::IRModule (tvm::IRModule, tvm::transform::PassContext)>::AssignTypedLambda<tvm::relay::transform::InferType()::{lambda(tvm::IRModule, tvm::transform::PassContext const&)#1}>(tvm::relay::transform::InferType()::{lambda(tvm::IRModule, tvm::transform::PassContext const&)#1})::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}> >::Call(tvm::runtime::PackedFuncObj const*, tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)
  3: tvm::relay::TypeInferencer::Infer(tvm::GlobalVar, tvm::relay::Function)
  2: tvm::relay::TypeSolver::Solve()
  1: tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<tvm::runtime::TypedPackedFunc<bool (tvm::runtime::Array<tvm::Type, void> const&, int, tvm::Attrs const&, tvm::TypeReporter const&)>::AssignTypedLambda<bool (*)(tvm::runtime::Array<tvm::Type, void> const&, int, tvm::Attrs const&, tvm::TypeReporter const&)>(bool (*)(tvm::runtime::Array<tvm::Type, void> const&, int, tvm::Attrs const&, tvm::TypeReporter const&))::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}> >::Call(tvm::runtime::PackedFuncObj const*, tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)
  0: tvm::relay::ReshapeRel(tvm::runtime::Array<tvm::Type, void> const&, int, tvm::Attrs const&, tvm::TypeReporter const&)
  File "/home/hans/code/tvm/src/relay/analysis/type_solver.cc", line 624
TVMError: 
---------------------------------------------------------------
An error occurred during the execution of TVM.
For more information, please see: https://tvm.apache.org/docs/errors.html
---------------------------------------------------------------
  Check failed: (false) is false: [14:35:14] /home/hans/code/tvm/src/relay/op/tensor/transform.cc:787: 
---------------------------------------------------------------
An error occurred during the execution of TVM.
For more information, please see: https://tvm.apache.org/docs/errors.html
---------------------------------------------------------------

  Check failed: oshape_sum == data_shape_sum (24576 vs. 8192) : Input tensor shape(1536,1,4,4) and reshaped shape(512,1,4,4) are not compatible!
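
For reference, the expected grouped output shape can be checked in PyTorch alone. A minimal sketch using the same shapes as the third op above (the stride/padding here are guesses chosen only to reproduce the reported 4 → 9 spatial size; the real values aren't shown above):

import torch
from torch.nn.functional import conv_transpose2d

# Same shapes as the third op above; stride/padding are assumptions picked only so
# that the spatial size goes 4 -> 9 as reported.
x = torch.rand(1, 1536, 4, 4)
w = torch.rand(1536, 512, 3, 3)  # IOHW: out_channels = 512 * groups = 1536
y = conv_transpose2d(x, w, stride=2, padding=0, groups=3)
print(y.shape)  # torch.Size([1, 1536, 9, 9]); TVM infers 512 output channels instead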

As a workaround, I've rewritten my conv_transpose2d converter to manually split the data and weights into groups, perform each transposed conv, and then concatenate the results. This converter does seem to give the correct output shape, although I haven't yet tested the outputs for correctness; I might have just gotten lucky with the shapes (a PyTorch-only sanity check of the split-and-concatenate idea is sketched after the converter below).

Workaround converter implementation (manual grouping)
    def convert_conv_transpose2d_workaround(self, inputs, input_types):
        data = inputs[0]
        weight = inputs[1]
        bias = inputs[2]
        strides = inputs[3]
        padding = inputs[4]
        output_padding = inputs[5]
        groups = inputs[6]
        dilation = inputs[7]

        input_channels, channels, kh, kw = list(self.infer_shape(weight))

        if groups > 1 and channels == 1:
            channel_multiplier = channels // groups
            new_weight_shape = (groups, channel_multiplier, kh, kw)
            weight = relay.op.transform.reshape(weight, new_weight_shape)

        datas = relay.op.split(data, groups, axis=1)
        weights = relay.op.split(weight, groups, axis=0)

        rs = []
        for d, w in zip(datas, weights):
            r = relay.op.nn.conv2d_transpose(
                d, w, strides=strides, padding=padding, output_padding=output_padding, dilation=dilation, groups=1
            )
            if bias is not None:
                r = relay.op.nn.bias_add(r, bias)
            rs.append(r)
        res = relay.op.concatenate(rs, axis=1)

        return res
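
Independent of TVM, the grouping logic of this decomposition can be sanity-checked against PyTorch's own grouped conv_transpose2d. A minimal sketch with made-up shapes (note that if bias were present, it would presumably also need to be split per group rather than added whole to each chunk):

import torch
from torch.nn.functional import conv_transpose2d

# Minimal check (PyTorch only, shapes made up) that splitting data/weights per group,
# running ungrouped transposed convs, and concatenating matches grouped conv_transpose2d.
groups = 3
x = torch.rand(1, 6, 4, 4)   # 6 input channels, 2 per group
w = torch.rand(6, 2, 3, 3)   # IOHW: (in_channels, out_channels // groups, kH, kW)

ref = conv_transpose2d(x, w, stride=2, padding=1, output_padding=1, groups=groups)

parts = [
    conv_transpose2d(xg, wg, stride=2, padding=1, output_padding=1, groups=1)
    for xg, wg in zip(x.chunk(groups, dim=1), w.chunk(groups, dim=0))
]
out = torch.cat(parts, dim=1)

print(torch.allclose(ref, out))  # should print True -- the decomposition itself is sound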

Expected behavior

The groups argument of tvm.relay.op.nn.conv2d_transpose should work correctly like tvm.relay.op.nn.conv2d does.

Actual behavior

The transposed convolution seems to be applied as if there were only a single group: the inferred output has 512 channels instead of 1536.

Environment

Ubuntu 20.04
PyTorch 1.12.0.dev20220210
TVM 0.9.dev525+g8aeb72265 (compiled from main a couple hours ago)
CUDA 11.4

Steps to reproduce

from copy import deepcopy

import torch
import tvm.relay
from torch.nn.functional import conv_transpose2d

_original_get_constant = deepcopy(tvm.relay.frontend.pytorch._get_constant)


def _my_get_constant(node):
    """Monkey patch in support for prim::Constant lists, I guess torch.jit.optimize_for_inference introduces these?"""
    if node.output().type().kind() == "ListType":
        print("WARNING: Encountered ListType in _get_constant, doing weird eval stuff to get the list value:", end=" ")
        lst = eval(node.__repr__().split("value=")[1].replace("]()", ""))
        print(lst)
        return lst
    else:
        return _original_get_constant(node)


tvm.relay.frontend.pytorch._get_constant = _my_get_constant


def convert_conv_transpose2d(inputs, input_types):
    data = inputs[0]
    weight = inputs[1]
    bias = inputs[2]
    strides = inputs[3]
    padding = inputs[4]
    output_padding = inputs[5]
    groups = inputs[6]
    dilation = inputs[7]

    res = tvm.relay.op.nn.conv2d_transpose(
        data,
        weight,
        strides=strides,
        padding=padding,
        output_padding=output_padding,
        dilation=dilation,
        groups=groups,
    )
    if bias is not None:
        res = tvm.relay.op.nn.bias_add(res, bias)

    return res


class ModulatedConvTranspose2D(torch.nn.Module):
    def forward(self, x, w, s):
        B, C, H, W = x.shape
        I, O, KH, KW = w.shape

        # weight is different for each input in batch (this is why we want grouped conv transpose)
        w = w.unsqueeze(0) * s.reshape(B, 1, 1, 1, 1)
        w = w.reshape(B * I, O, KH, KW)

        x = x.reshape(1, B * C, H, W)

        x = conv_transpose2d(x, w, stride=(2, 2), padding=(1, 1), output_padding=(1, 1), groups=B)

        # Check failed: oshape_sum == data_shape_sum (524288 vs. 131072) : Input tensor shape(4,256,16,32) and reshaped shape(1,256,16,32) are not compatible!
        x = x.reshape(B, O, H * 2, W * 2)

        return x


with torch.inference_mode():
    b, c, h, w, k = 4, 512, 8, 16, 3
    inputs = torch.rand(b, c, h, w)
    weights = torch.rand(c, c // 2, k, k)
    styles = torch.rand(b)

    torch_mod = torch.jit.optimize_for_inference(
        torch.jit.trace(ModulatedConvTranspose2D().eval(), (inputs, weights, styles))
    )

    outputs_torch = torch_mod(inputs, weights, styles)
    print("Torch output shape", outputs_torch.shape)  # torch.Size([4, 256, 16, 32])

    tvm_mod, params = tvm.relay.frontend.pytorch.from_pytorch(
        torch_mod,
        [("inputs", inputs.shape), ("weights", weights.shape), ("styles", styles.shape)],
        {"aten::conv_transpose2d": convert_conv_transpose2d},
    )
