Skip to content

Commit 84ee90c

Browse files
[ONNX] Fix onnx convtranspose error (#9938)
* fix mix up of channels with conv2d-transpose * add grouped convtranspose tests * turn off groups for non-llvm test
1 parent 1b1cfb3 commit 84ee90c

File tree

2 files changed

+18
-9
lines changed

2 files changed

+18
-9
lines changed

python/tvm/relay/frontend/onnx.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -38,9 +38,9 @@
3838
from .. import ty as _ty
3939
from .. import vision as _vision
4040
from .common import (
41-
autopad,
4241
AttrCvt,
4342
Renamer,
43+
autopad,
4444
ensure_scalar_shape,
4545
fold_constant,
4646
get_name,
@@ -557,13 +557,13 @@ class ConvTranspose(OnnxOpConverter):
557557
def _impl_v1(cls, inputs, attr, params):
558558
# get number of channels
559559
out_type = infer_type(inputs[1])
560-
out_shapes = [get_const_tuple(out_type.checked_type.shape)]
561-
channels = out_shapes[0][1]
562-
attr["channels"] = channels
560+
kernel_shape = [get_const_tuple(out_type.checked_type.shape)]
561+
out_channels = kernel_shape[0][1] * attr.get("group", 1)
562+
attr["channels"] = out_channels
563563
groups = attr.get("group", 1)
564564

565565
if "kernel_shape" not in attr:
566-
attr["kernel_shape"] = out_shapes[0][2:]
566+
attr["kernel_shape"] = kernel_shape[0][2:]
567567

568568
attr["groups"] = groups
569569
# infer pads for auto_pad
@@ -612,13 +612,13 @@ def _impl_v1(cls, inputs, attr, params):
612612
def _impl_v11(cls, inputs, attr, params):
613613
# get number of channels
614614
out_type = infer_type(inputs[1])
615-
out_shapes = [get_const_tuple(out_type.checked_type.shape)]
616-
channels = out_shapes[0][1]
617-
attr["channels"] = channels
615+
kernel_shape = [get_const_tuple(out_type.checked_type.shape)]
616+
out_channels = kernel_shape[0][1] * attr.get("group", 1)
617+
attr["channels"] = out_channels
618618
groups = attr.get("group", 1)
619619

620620
if "kernel_shape" not in attr:
621-
attr["kernel_shape"] = out_shapes[0][2:]
621+
attr["kernel_shape"] = kernel_shape[0][2:]
622622

623623
attr["groups"] = groups
624624
# infer pads for auto_pad

tests/python/frontend/onnx/test_forward.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -189,6 +189,7 @@ def verify_with_ort_with_inputs(
189189
opt_level=opt_level,
190190
convert_config=convert_config,
191191
)
192+
192193
if not isinstance(tvm_out, list):
193194
tvm_out = [tvm_out]
194195
if not isinstance(ort_out, list):
@@ -2892,6 +2893,14 @@ def verify_convtranspose(x_shape, w_shape, y_shape, p, group=1):
28922893
# Test undefined groups.
28932894
verify_convtranspose((1, 1, 3, 3), (1, 2, 3, 3), (1, 2, 7, 3), [1, 2, 1, 2], group=None)
28942895

2896+
if "llvm" in target:
2897+
# GPU does not support groups != 1 for convtranspose, so only test llvm
2898+
# Test depthwise-convolution
2899+
verify_convtranspose((1, 10, 3, 3), (10, 1, 3, 3), (1, 10, 7, 3), [1, 2, 1, 2], group=10)
2900+
2901+
# Test grouped-convolution
2902+
verify_convtranspose((1, 10, 3, 3), (10, 1, 3, 3), (1, 5, 7, 3), [1, 2, 1, 2], group=5)
2903+
28952904
def repeat(N, D):
28962905
return tuple([N for _ in range(D)])
28972906

0 commit comments

Comments
 (0)