Description
I'm trying to convert a PyTorch model that uses torch.nn.functional.conv_transpose2d and am running into issues with my converter to the corresponding tvm.relay.op.nn.conv2d_transpose operation.
I've done a little monkey patching on the PyTorchOpConverter because the operations that torch.nn.functional.conv2d/conv_transpose2d trace to (aten::conv2d and aten::conv_transpose2d) aren't covered by default. I've added a converter method for each op to the PyTorchOpConverter so that I have access to self.infer_shape(weight) inside them:
Converter implementation
import tvm
from tvm import relay
from tvm.relay.frontend.pytorch import PyTorchOpConverter


class MyPyTorchOpConverter(PyTorchOpConverter):
    def __init__(self, prelude, default_dtype):
        super().__init__(prelude, default_dtype)
        self.update_convert_map(
            {
                "aten::conv2d": self.convert_conv2d,
                "aten::conv_transpose2d": self.convert_conv_transpose2d,
            }
        )

    def convert_conv2d(self, inputs, input_types):
        data = inputs[0]
        weight = inputs[1]
        bias = inputs[2]
        strides = inputs[3]
        padding = inputs[4]
        dilation = inputs[5]
        groups = inputs[6]

        channels, input_channels, kh, kw = self.infer_shape(weight)  # OIHW
        if groups > 1 and input_channels == 1:
            # depthwise case: reshape the weight to (groups, channel_multiplier, kh, kw)
            channel_multiplier = channels // groups
            new_weight_shape = (groups, channel_multiplier, kh, kw)
            weight = relay.op.transform.reshape(weight, new_weight_shape)

        res = relay.op.nn.conv2d(
            data,
            weight,
            strides=strides,
            padding=padding,
            dilation=dilation,
            groups=groups,
            channels=channels,
        )
        if bias is not None:
            res = relay.op.nn.bias_add(res, bias)
        return res

    def convert_conv_transpose2d(self, inputs, input_types):
        data = inputs[0]
        weight = inputs[1]
        bias = inputs[2]
        strides = inputs[3]
        padding = inputs[4]
        output_padding = inputs[5]
        groups = inputs[6]
        dilation = inputs[7]

        input_channels, channels, kh, kw = list(self.infer_shape(weight))  # IOHW
        if groups > 1 and channels == 1:
            channel_multiplier = channels // groups
            new_weight_shape = (groups, channel_multiplier, kh, kw)
            weight = relay.op.transform.reshape(weight, new_weight_shape)

        res = relay.op.nn.conv2d_transpose(
            data,
            weight,
            strides=strides,
            padding=padding,
            output_padding=output_padding,
            dilation=dilation,
            groups=groups,
        )
        if bias is not None:
            res = relay.op.nn.bias_add(res, bias)
        return res


tvm.relay.frontend.pytorch.PyTorchOpConverter = MyPyTorchOpConverter

The implementations of the converters are adapted from tvm.relay.frontend.pytorch.PyTorchOpConverter.convolution(inputs, input_types) but updated to support the call signature of torch.nn.functional.conv2d/conv_transpose2d.
The problem I'm seeing is that tvm.relay.op.nn.conv2d_transpose() doesn't seem to respect the groups argument. When I print the inputs and outputs of the first four conv(_transpose) ops in my network, the PyTorch shapes are the following:
PyTorch shapes
conv2d
input (1, 1536, 4, 4)
weight (1536, 512, 3, 3)
groups 3
out (1, 1536, 4, 4)
conv2d
input (1, 1536, 4, 4)
weight (9, 512, 1, 1)
groups 3
out (1, 9, 4, 4)
conv_transpose2d
input (1, 1536, 4, 4)
weight (1536, 512, 3, 3)
groups 3
out (1, 1536, 9, 9)
conv2d
input (1, 1536, 11, 11)
weight (1536, 1, 4, 4)
groups 1536
out (1, 1536, 8, 8)
While the TVM shapes are:
TVM shapes
conv2d
input [1, 1536, 4, 4]
weight [1536, 512, 3, 3]
groups 3
out (1, 1536, 4, 4)
conv2d
input [1, 1536, 4, 4]
weight [9, 512, 1, 1]
groups 3
out (1, 9, 4, 4)
conv2d_transpose
input [1, 1536, 4, 4]
weight [1536, 512, 3, 3]
groups 3
out (1, 512, 9, 9)
conv2d
input [1, 512, 11, 11]
weight [512, 1, 4, 4]
groups 1536
out TVMError (type inference fails, see traceback below)
Notice that the output shape of tvm.relay.op.nn.conv2d_transpose() does not have the correct number of channels: it produces 512 channels instead of 1536, as if groups = 1. This leads to the error in the next conv2d operation:
Error traceback
Traceback (most recent call last):
File "/home/hans/code/stylegan3/func.py", line 217, in <module>
Gtvm, tvm_params = relay.frontend.pytorch.from_pytorch(
File "/home/hans/code/tvm/python/tvm/relay/frontend/pytorch.py", line 4010, in from_pytorch
outputs = converter.convert_operators(_get_operator_nodes(graph.nodes()), outputs, ret_name)
File "/home/hans/code/tvm/python/tvm/relay/frontend/pytorch.py", line 3385, in convert_operators
relay_out = relay_op(
File "/home/hans/code/stylegan3/func.py", line 101, in convert_conv2d
print("out", self.infer_shape(res))
File "/home/hans/code/tvm/python/tvm/relay/frontend/pytorch.py", line 204, in infer_shape
typ = self.infer_type(inputs, mod=mod)
File "/home/hans/code/tvm/python/tvm/relay/frontend/pytorch.py", line 162, in infer_type
new_mod = transform.InferType()(new_mod)
File "/home/hans/code/tvm/python/tvm/ir/transform.py", line 161, in __call__
return _ffi_transform_api.RunPass(self, mod)
File "/home/hans/code/tvm/python/tvm/_ffi/_ctypes/packed_func.py", line 237, in __call__
raise get_last_ffi_error()
tvm._ffi.base.TVMError: Traceback (most recent call last):
7: TVMFuncCall
6: tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<tvm::runtime::TypedPackedFunc<tvm::IRModule (tvm::transform::Pass, tvm::IRModule)>::AssignTypedLambda<tvm::transform::{lambda(tvm::transform::Pass, tvm::IRModule)#7}>(tvm::transform::{lambda(tvm::transform::Pass, tvm::IRModule)#7}, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >)::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}> >::Call(tvm::runtime::PackedFuncObj const*, tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)
5: tvm::transform::Pass::operator()(tvm::IRModule) const
4: tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
3: tvm::transform::ModulePassNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
2: tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<tvm::runtime::TypedPackedFunc<tvm::IRModule (tvm::IRModule, tvm::transform::PassContext)>::AssignTypedLambda<tvm::relay::transform::InferType()::{lambda(tvm::IRModule, tvm::transform::PassContext const&)#1}>(tvm::relay::transform::InferType()::{lambda(tvm::IRModule, tvm::transform::PassContext const&)#1})::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}> >::Call(tvm::runtime::PackedFuncObj const*, tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)
1: tvm::relay::TypeInferencer::Infer(tvm::GlobalVar, tvm::relay::Function)
0: tvm::relay::TypeSolver::Solve() [clone .cold]
9: TVMFuncCall
8: tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<tvm::runtime::TypedPackedFunc<tvm::IRModule (tvm::transform::Pass, tvm::IRModule)>::AssignTypedLambda<tvm::transform::{lambda(tvm::transform::Pass, tvm::IRModule)#7}>(tvm::transform::{lambda(tvm::transform::Pass, tvm::IRModule)#7}, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >)::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}> >::Call(tvm::runtime::PackedFuncObj const*, tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)
7: tvm::transform::Pass::operator()(tvm::IRModule) const
6: tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
5: tvm::transform::ModulePassNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
4: tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<tvm::runtime::TypedPackedFunc<tvm::IRModule (tvm::IRModule, tvm::transform::PassContext)>::AssignTypedLambda<tvm::relay::transform::InferType()::{lambda(tvm::IRModule, tvm::transform::PassContext const&)#1}>(tvm::relay::transform::InferType()::{lambda(tvm::IRModule, tvm::transform::PassContext const&)#1})::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}> >::Call(tvm::runtime::PackedFuncObj const*, tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)
3: tvm::relay::TypeInferencer::Infer(tvm::GlobalVar, tvm::relay::Function)
2: tvm::relay::TypeSolver::Solve()
1: tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<tvm::runtime::TypedPackedFunc<bool (tvm::runtime::Array<tvm::Type, void> const&, int, tvm::Attrs const&, tvm::TypeReporter const&)>::AssignTypedLambda<bool (*)(tvm::runtime::Array<tvm::Type, void> const&, int, tvm::Attrs const&, tvm::TypeReporter const&)>(bool (*)(tvm::runtime::Array<tvm::Type, void> const&, int, tvm::Attrs const&, tvm::TypeReporter const&))::{lambda(tvm::runtime::TVMArgs const&, tvm::runtime::TVMRetValue*)#1}> >::Call(tvm::runtime::PackedFuncObj const*, tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)
0: tvm::relay::ReshapeRel(tvm::runtime::Array<tvm::Type, void> const&, int, tvm::Attrs const&, tvm::TypeReporter const&)
File "/home/hans/code/tvm/src/relay/analysis/type_solver.cc", line 624
TVMError:
---------------------------------------------------------------
An error occurred during the execution of TVM.
For more information, please see: https://tvm.apache.org/docs/errors.html
---------------------------------------------------------------
Check failed: (false) is false: [14:35:14] /home/hans/code/tvm/src/relay/op/tensor/transform.cc:787:
---------------------------------------------------------------
An error occurred during the execution of TVM.
For more information, please see: https://tvm.apache.org/docs/errors.html
---------------------------------------------------------------
Check failed: oshape_sum == data_shape_sum (24576 vs. 8192) : Input tensor shape(1536,1,4,4) and reshaped shape(512,1,4,4) are not compatible!
As a workaround, I've rewritten my conv_transpose2d converter to manually split the data and weights into groups, perform each transposed conv separately, and then concatenate the results. This converter does seem to give the correct output shape, although I haven't yet tested the outputs for correctness; I might have just gotten lucky with the shapes.
Workaround converter implementation (manual grouping)
def convert_conv_transpose2d_workaround(self, inputs, input_types):
    data = inputs[0]
    weight = inputs[1]
    bias = inputs[2]
    strides = inputs[3]
    padding = inputs[4]
    output_padding = inputs[5]
    groups = inputs[6]
    dilation = inputs[7]

    input_channels, channels, kh, kw = list(self.infer_shape(weight))
    if groups > 1 and channels == 1:
        channel_multiplier = channels // groups
        new_weight_shape = (groups, channel_multiplier, kh, kw)
        weight = relay.op.transform.reshape(weight, new_weight_shape)

    # Split the input along its channel axis and the weight along its input-channel
    # axis, run an ungrouped conv2d_transpose per group, then concatenate the outputs.
    datas = relay.op.split(data, groups, axis=1)
    weights = relay.op.split(weight, groups, axis=0)
    rs = []
    for d, w in zip(datas, weights):
        r = relay.op.nn.conv2d_transpose(
            d,
            w,
            strides=strides,
            padding=padding,
            output_padding=output_padding,
            dilation=dilation,
            groups=1,
        )
        if bias is not None:
            r = relay.op.nn.bias_add(r, bias)
        rs.append(r)
    res = relay.op.concatenate(rs, axis=1)
    return res

Expected behavior
The groups argument of tvm.relay.op.nn.conv2d_transpose should be respected, just like it is for tvm.relay.op.nn.conv2d.
Actual behavior
The transposed convolution seems to only be applied to a single group: the output has one group's worth of channels, as if groups = 1.
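To take my converters out of the equation, the wrong channel count also shows up in a Relay-only type-inference check (a minimal sketch; the strides/padding here are my guess at what matches the first conv_transpose2d listed above):

import tvm
from tvm import relay

# Shapes from the first conv_transpose2d above: weight is IOHW, groups = 3.
data = relay.var("data", shape=(1, 1536, 4, 4))
weight = relay.var("weight", shape=(1536, 512, 3, 3))
out = relay.op.nn.conv2d_transpose(
    data, weight, strides=(2, 2), padding=(0, 0), output_padding=(0, 0), groups=3
)
mod = relay.transform.InferType()(tvm.IRModule.from_expr(out))
print(mod["main"].body.checked_type)
# expected: Tensor[(1, 1536, 9, 9), float32]
# observed: Tensor[(1, 512, 9, 9), float32], i.e. channels as if groups = 1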
Environment
Ubuntu 20.04
PyTorch 1.12.0.dev20220210
TVM 0.9.dev525+g8aeb72265 (compiled from main a couple hours ago)
CUDA 11.4
Steps to reproduce
from copy import deepcopy

import torch
from torch.nn.functional import conv_transpose2d

import tvm.relay

_original_get_constant = deepcopy(tvm.relay.frontend.pytorch._get_constant)


def _my_get_constant(node):
    """Monkey patch in support for prim::Constant lists, I guess torch.jit.optimize_for_inference introduces these?"""
    if node.output().type().kind() == "ListType":
        print("WARNING: Encountered ListType in _get_constant, doing weird eval stuff to get the list value:", end=" ")
        lst = eval(node.__repr__().split("value=")[1].replace("]()", ""))
        print(lst)
        return lst
    else:
        return _original_get_constant(node)


tvm.relay.frontend.pytorch._get_constant = _my_get_constant


def convert_conv_transpose2d(inputs, input_types):
    data = inputs[0]
    weight = inputs[1]
    bias = inputs[2]
    strides = inputs[3]
    padding = inputs[4]
    output_padding = inputs[5]
    groups = inputs[6]
    dilation = inputs[7]
    res = tvm.relay.op.nn.conv2d_transpose(
        data,
        weight,
        strides=strides,
        padding=padding,
        output_padding=output_padding,
        dilation=dilation,
        groups=groups,
    )
    if bias is not None:
        res = tvm.relay.op.nn.bias_add(res, bias)
    return res


class ModulatedConvTranspose2D(torch.nn.Module):
    def forward(self, x, w, s):
        B, C, H, W = x.shape
        I, O, KH, KW = w.shape
        # weight is different for each input in batch (this is why we want grouped conv transpose)
        w = w.unsqueeze(0) * s.reshape(B, 1, 1, 1, 1)
        w = w.reshape(B * I, O, KH, KW)
        x = x.reshape(1, B * C, H, W)
        x = conv_transpose2d(x, w, stride=(2, 2), padding=(1, 1), output_padding=(1, 1), groups=B)
        # Check failed: oshape_sum == data_shape_sum (524288 vs. 131072) : Input tensor shape(4,256,16,32) and reshaped shape(1,256,16,32) are not compatible!
        x = x.reshape(B, O, H * 2, W * 2)
        return x


with torch.inference_mode():
    b, c, h, w, k = 4, 512, 8, 16, 3
    inputs = torch.rand(b, c, h, w)
    weights = torch.rand(c, c // 2, k, k)
    styles = torch.rand(b)

    torch_mod = torch.jit.optimize_for_inference(
        torch.jit.trace(ModulatedConvTranspose2D().eval(), (inputs, weights, styles))
    )
    outputs_torch = torch_mod(inputs, weights, styles)
    print("Torch output shape", outputs_torch.shape)  # torch.Size([4, 256, 16, 32])

    tvm_mod, params = tvm.relay.frontend.pytorch.from_pytorch(
        torch_mod,
        [("inputs", inputs.shape), ("weights", weights.shape), ("styles", styles.shape)],
        {"aten::conv_transpose2d": convert_conv_transpose2d},
    )
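For reference, once from_pytorch succeeds I'd compare the TVM output against PyTorch roughly like this (just a sketch, assuming an llvm target and the graph executor; the input names match the shape list above):

from tvm.contrib import graph_executor

target = "llvm"
with tvm.transform.PassContext(opt_level=3):
    lib = tvm.relay.build(tvm_mod, target=target, params=params)

m = graph_executor.GraphModule(lib["default"](tvm.device(target, 0)))
m.set_input("inputs", tvm.nd.array(inputs.numpy()))
m.set_input("weights", tvm.nd.array(weights.numpy()))
m.set_input("styles", tvm.nd.array(styles.numpy()))
m.run()
out_tvm = m.get_output(0).numpy()
print("TVM output shape", out_tvm.shape)  # should match torch.Size([4, 256, 16, 32])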