Commit 6159b8e
[Topi][Op][PyTorch][Vitas] Fix inconsistent kernel layout conventions for conv2d_transpose (#9336)
* fix a lot of initial tests
* make pytorch tests pass
* lint
* add test
* fix bug with layout transform
* change layouts for conv2d_transpose too
* fix vitis tests
* fix qnn conv2d transpose tests
* fix fake quantization pass
* add todo
* lint
* undo just formatting changes
* remove formatting only change
* remove f2qi for later pr
* more frontend tests fixes
* jostle
* fix keras
* fix another frontend test
* fix things
* jostle ci
1 parent 4bebfd8 commit 6159b8e
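
For orientation: the convention this commit standardizes on is that, with NCHW data, Relay's conv2d_transpose takes its kernel in IOHW order by default. A minimal sketch with arbitrary shapes (assuming a TVM build that includes this change; the variable names and sizes are made up for illustration):

import tvm
from tvm import relay

data = relay.var("data", shape=(1, 4, 16, 16), dtype="float32")    # NCHW
weight = relay.var("weight", shape=(4, 8, 3, 3), dtype="float32")  # IOHW: (in, out, kH, kW)

out = relay.nn.conv2d_transpose(
    data,
    weight,
    channels=8,
    kernel_size=(3, 3),
    data_layout="NCHW",
    kernel_layout="IOHW",  # new default for NCHW data after this commit
)
mod = tvm.IRModule.from_expr(relay.Function([data, weight], out))
print(mod)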

File tree

17 files changed: +223 additions, -155 deletions

python/tvm/relay/frontend/caffe.py

Lines changed: 5 additions & 2 deletions
@@ -21,11 +21,12 @@
 import numpy as np
 import tvm
 from tvm.ir import IRModule
+
+from ... import nd as _nd
 from .. import analysis
 from .. import expr as _expr
 from .. import function as _function
 from .. import op as _op
-from ... import nd as _nd
 from .common import ExprTable
 from .common import infer_shape as _infer_shape

@@ -514,14 +515,16 @@ def convert_deconv(self, op):
             weight_shape = [-1, conv_params.num_output, kh, kw]
             weight_value = np.asarray(weight.data, np.float32)
             weight_value = np.reshape(weight_value, weight_shape)
+
+            # weight shape is in relay's IOHW format rn, we need it to be OIHW
+            weight_value = np.transpose(weight_value, [1, 0, 2, 3])
         else:
             raise Exception("No weight value of layer {} in caffemodel".format(op.name))

         weight_expr = self.exp_tab.new_const(weight_value, dtype="float32")
         in_expr = self.exp_tab.get_expr(inputs[0])
         out = _op.nn.conv2d_transpose(data=in_expr, weight=weight_expr, **params)
         if bias:
-
             bias_value = np.asarray(bias.data, np.float32)
             bias_expr = self.exp_tab.new_const(bias_value, dtype="float32")
             out = _op.nn.bias_add(out, bias_expr)
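
A NumPy sketch of the transpose added above: the reshape yields weights of shape (-1, num_output, kh, kw), and [1, 0, 2, 3] swaps the first two axes. The channel counts here are made up:

import numpy as np

iohw = np.zeros((3, 8, 2, 2), dtype="float32")  # (in_channels, num_output, kh, kw)
oihw = np.transpose(iohw, [1, 0, 2, 3])         # -> (num_output, in_channels, kh, kw)
print(oihw.shape)  # (8, 3, 2, 2)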

python/tvm/relay/frontend/keras.py

Lines changed: 5 additions & 2 deletions
@@ -355,11 +355,14 @@ def _convert_convolution(inexpr, keras_layer, etab):
         else:
             kernel_layout = "HWIO"
     else:
-        kernel_layout = "OIHW"
+        if is_deconv:
+            kernel_layout = "IOHW"
+        else:
+            kernel_layout = "OIHW"

     if is_deconv:
         kernel_h, kernel_w, n_filters, in_channels = weight.shape
-        if kernel_layout == "OIHW":
+        if kernel_layout == "IOHW":
             weight = weight.transpose([3, 2, 0, 1])
     elif is_depthconv:
         kernel_h, kernel_w, in_channels, depth_mult = weight.shape
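
A quick shape check of the [3, 2, 0, 1] transpose above: the Keras deconvolution weight is unpacked as (kernel_h, kernel_w, n_filters, in_channels), and the transpose reorders it into IOHW. Sizes here are arbitrary:

import numpy as np

hwoi = np.zeros((3, 3, 16, 4), dtype="float32")  # (kernel_h, kernel_w, n_filters, in_channels)
iohw = hwoi.transpose([3, 2, 0, 1])              # -> (in_channels, n_filters, kernel_h, kernel_w)
print(iohw.shape)  # (4, 16, 3, 3)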

python/tvm/relay/frontend/mxnet.py

Lines changed: 28 additions & 18 deletions
@@ -18,40 +18,50 @@
 """MXNet symbol frontend."""
 import json
 import math
+
 import numpy as np
 import tvm
-from tvm.ir import IRModule
-
 from tvm import relay
+from tvm.ir import IRModule
 from tvm.topi.utils import get_const_tuple
+
+from ... import nd as _nd
 from .. import analysis
 from .. import expr as _expr
 from .. import function as _function
 from .. import op as _op
 from .. import scope_builder as _scope_builder
-from ... import nd as _nd
-
 from .common import StrAttrsDict
-from .common import infer_type as _infer_type
+from .common import get_name as _get_name
 from .common import infer_shape as _infer_shape
+from .common import infer_type as _infer_type
 from .common import infer_value as _infer_value
-from .common import get_name as _get_name
-from .nnvm_common import _rename, _binop_scalar, _rbinop_scalar, _reduce
-from .nnvm_common import _arg_reduce, _init_op, _softmax_op, _cast
-from .nnvm_common import _clip, _transpose, _upsampling
-from .nnvm_common import _elemwise_sum, _reshape
-from .nnvm_common import _warn_not_used
 from .mxnet_qnn_op_utils import (
-    quantize_mxnet_min_max,
-    quantize_conv_weights_bias_channel_mkldnn_from_var,
-    quantize_conv_bias_mkldnn_from_var,
-    get_conv_mkldnn_requantized_scale_outDtype,
     dequantize_mxnet_min_max,
+    get_conv_mkldnn_requantized_scale_outDtype,
     get_mkldnn_int8_scale,
-    get_mkldnn_uint8_scale,
     get_mkldnn_requantize_scale_outDtype,
+    get_mkldnn_uint8_scale,
+    quantize_conv_bias_mkldnn_from_var,
+    quantize_conv_weights_bias_channel_mkldnn_from_var,
+    quantize_mxnet_min_max,
+)
+from .nnvm_common import (
+    _arg_reduce,
+    _binop_scalar,
+    _cast,
+    _clip,
+    _elemwise_sum,
+    _init_op,
+    _rbinop_scalar,
+    _reduce,
+    _rename,
+    _reshape,
+    _softmax_op,
+    _transpose,
+    _upsampling,
+    _warn_not_used,
 )
-

 __all__ = ["from_mxnet"]

@@ -329,7 +339,7 @@ def _mx_conv2d_transpose(inputs, attrs):
     if "kernel_layout" in attrs.attrs:
         kernel_layout = attrs.get_str("kernel_layout")
     else:
-        kernel_layout = "HWIO" if data_layout == "NHWC" else "OIHW"
+        kernel_layout = "HWIO" if data_layout == "NHWC" else "IOHW"

     new_attrs = {}
     new_attrs["channels"] = attrs.get_int("num_filter")

python/tvm/relay/frontend/pytorch.py

Lines changed: 7 additions & 4 deletions
@@ -19,8 +19,8 @@
 # pylint: disable=import-outside-toplevel, simplifiable-if-expression, cell-var-from-loop, unnecessary-lambda
 # pylint: disable=missing-function-docstring
 """PT: PyTorch frontend."""
-import itertools
 import functools
+import itertools
 import logging
 import math
 import sys
@@ -40,11 +40,11 @@
 from ..prelude import Prelude, StaticTensorArrayOps
 from ..ty import Any, TensorType, TupleType
 from . import qnn_torch
-from .common import AttrCvt, get_relay_op, unbind, lstm_cell, gru_cell
-from .common import infer_value as _infer_value
+from .common import AttrCvt, get_relay_op, gru_cell
 from .common import infer_shape as _infer_shape
+from .common import infer_value as _infer_value
 from .common import infer_value_simulated as _infer_value_simulated
-from .common import try_infer_value
+from .common import lstm_cell, try_infer_value, unbind
 from .pytorch_utils import is_version_greater_than

 __all__ = ["from_pytorch"]
@@ -1010,6 +1010,9 @@ def convolution(self, inputs, input_types):
         elif len(kernel_size) == 2:
             data_layout = "NCHW"
             kernel_layout = "OIHW"
+            if use_transpose:
+                # Transposed convolutions have IOHW layout.
+                kernel_layout = "IOHW"
         else:
             data_layout = "NCW"
             kernel_layout = "OIW"
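
One way to exercise this path end to end, as a sketch assuming torch and tvm are both installed; the module, input name, and shapes are arbitrary:

import torch
from tvm import relay

# Trace a transposed convolution so the frontend hits the use_transpose branch.
model = torch.nn.ConvTranspose2d(4, 8, kernel_size=3).eval()
inp = torch.randn(1, 4, 16, 16)
scripted = torch.jit.trace(model, inp)

# The imported conv2d_transpose should carry kernel_layout="IOHW".
mod, params = relay.frontend.from_pytorch(scripted, [("input0", (1, 4, 16, 16))])
print(mod)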

python/tvm/relay/frontend/qnn_torch.py

Lines changed: 2 additions & 6 deletions
@@ -19,7 +19,6 @@
 import logging

 import numpy as np
-
 import tvm
 from tvm import relay
 from tvm.relay import expr as _expr
@@ -1043,11 +1042,8 @@ def _impl(inputs, _):

         weight_shape = list(infer_shape(weight))

-        # Swap I and O dims to match shape relay expects for OIHW
-        weight_shape[0], weight_shape[1] = weight_shape[1], weight_shape[0]
-
         kernel_size = (weight_shape[2], weight_shape[3])
-        out_channels = weight_shape[0]
+        out_channels = weight_shape[1]

         conv_out = relay.qnn.op.conv2d_transpose(
             inputs[0],
@@ -1064,7 +1060,7 @@ def _impl(inputs, _):
             channels=out_channels,
             output_padding=output_padding,
             out_dtype="int32",
-            kernel_layout="OIHW",
+            kernel_layout="IOHW",
         )

         return _do_bias_and_requantize(
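
With the axis swap removed, the weights keep PyTorch's native (in_channels, out_channels, kH, kW) order, so the output-channel count is read from dimension 1. A toy check with a made-up shape:

weight_shape = [4, 8, 3, 3]  # (in_channels, out_channels, kH, kW)
kernel_size = (weight_shape[2], weight_shape[3])
out_channels = weight_shape[1]
print(kernel_size, out_channels)  # (3, 3) 8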

python/tvm/relay/frontend/tensorflow_ops.py

Lines changed: 4 additions & 1 deletion
@@ -461,8 +461,11 @@ def _impl(inputs, attr, params, mod):
            raise tvm.error.OpAttributeInvalid(msg.format(attr["padding"]))

        if "kernel_layout" not in attr:
-            if opname in ["conv", "conv_transpose"]:
+            if opname == "conv":
                attr["kernel_layout"] = "HWIO" if attr["data_format"] == "NHWC" else "OIHW"
+            elif opname == "conv_transpose":
+                # conv_transpose in TVM has weights be IOHW for NCHW
+                attr["kernel_layout"] = "HWIO" if attr["data_format"] == "NHWC" else "IOHW"
            else:
                attr["kernel_layout"] = "HWOI" if attr["data_format"] == "NHWC" else "OIHW"

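
The same default-layout rule, restated as a standalone helper (plain Python mirroring the branches above, not TVM API code) so the new conv_transpose case is easy to eyeball:

def default_kernel_layout(opname, data_format):
    if opname == "conv":
        return "HWIO" if data_format == "NHWC" else "OIHW"
    if opname == "conv_transpose":
        return "HWIO" if data_format == "NHWC" else "IOHW"
    return "HWOI" if data_format == "NHWC" else "OIHW"  # remaining branch above

print(default_kernel_layout("conv_transpose", "NCHW"))  # IOHW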

python/tvm/relay/frontend/tflite.py

Lines changed: 20 additions & 19 deletions
@@ -16,24 +16,25 @@
 # under the License.
 # pylint: disable=invalid-name, unused-argument, too-many-lines, import-outside-toplevel
 """Tensorflow lite frontend."""
-import math
 import itertools
+import math
+
 import numpy as np
 import tvm
+from tvm import relay
 from tvm.ir import IRModule

-from tvm import relay
+from ... import nd as _nd
 from .. import analysis
 from .. import expr as _expr
 from .. import function as _function
 from .. import op as _op
 from .. import qnn as _qnn
-from ... import nd as _nd
 from .common import ExprTable
-from .common import infer_shape as _infer_shape, to_int_list
+from .common import infer_shape as _infer_shape
+from .common import to_int_list
 from .tflite_flexbuffer import FlexBufferDecoder

-
 __all__ = ["from_tflite"]


@@ -53,9 +54,9 @@ class OperatorConverter(object):
    def __init__(self, model, subgraph, exp_tab):

        try:
+            from tflite.ActivationFunctionType import ActivationFunctionType
            from tflite.BuiltinOperator import BuiltinOperator
            from tflite.BuiltinOptions import BuiltinOptions
-            from tflite.ActivationFunctionType import ActivationFunctionType
        except ImportError:
            raise ImportError("The tflite package must be installed")

@@ -1061,8 +1062,8 @@ def convert_log_softmax(self, op):
    def convert_concatenation(self, op):
        """Convert TFLite concatenation"""
        try:
-            from tflite.ConcatenationOptions import ConcatenationOptions
            from tflite.BuiltinOptions import BuiltinOptions
+            from tflite.ConcatenationOptions import ConcatenationOptions
        except ImportError:
            raise ImportError("The tflite package must be installed")

@@ -1242,10 +1243,10 @@ def _convert_elemwise(self, relay_op, op, ignore_qnn_params=False):
        """Generic method to Convert TFLite elemwise"""
        try:
            from tflite.AddOptions import AddOptions
-            from tflite.SubOptions import SubOptions
-            from tflite.MulOptions import MulOptions
-            from tflite.DivOptions import DivOptions
            from tflite.BuiltinOptions import BuiltinOptions
+            from tflite.DivOptions import DivOptions
+            from tflite.MulOptions import MulOptions
+            from tflite.SubOptions import SubOptions
        except ImportError:
            raise ImportError("The tflite package must be installed")

@@ -1804,9 +1805,9 @@ def convert_reduce_any(self, op):
    def _convert_arg_min_max(self, relay_op, op):
        """Generic method converting TFLite arg_min_max"""
        try:
-            from tflite.BuiltinOptions import BuiltinOptions
-            from tflite.ArgMinOptions import ArgMinOptions
            from tflite.ArgMaxOptions import ArgMaxOptions
+            from tflite.ArgMinOptions import ArgMinOptions
+            from tflite.BuiltinOptions import BuiltinOptions
        except ImportError:
            raise ImportError("The tflite package must be installed")

@@ -1853,8 +1854,8 @@ def convert_arg_max(self, op):
    def convert_fully_connected(self, op):
        """Convert TFLite fully connected"""
        try:
-            from tflite.FullyConnectedOptions import FullyConnectedOptions
            from tflite.BuiltinOptions import BuiltinOptions
+            from tflite.FullyConnectedOptions import FullyConnectedOptions
            from tflite.TensorType import TensorType
        except ImportError:
            raise ImportError("The tflite package must be installed")
@@ -2024,10 +2025,10 @@ def convert_conv(self, op, conv_type):
        """convolution implementation."""
        try:
            from tflite.BuiltinOptions import BuiltinOptions
-            from tflite.TensorType import TensorType
            from tflite.Conv2DOptions import Conv2DOptions
            from tflite.DepthwiseConv2DOptions import DepthwiseConv2DOptions
            from tflite.Padding import Padding
+            from tflite.TensorType import TensorType
        except ImportError:
            raise ImportError("The tflite package must be installed")

@@ -2434,8 +2435,8 @@ def convert_pool2d(self, op, pool_type):
        """pool2d implementation."""
        try:
            from tflite.BuiltinOptions import BuiltinOptions
-            from tflite.Pool2DOptions import Pool2DOptions
            from tflite.Padding import Padding
+            from tflite.Pool2DOptions import Pool2DOptions
        except ImportError:
            raise ImportError("The tflite package must be installed")

@@ -2850,9 +2851,9 @@ def convert_transpose_conv(self, op):
        """Convert TFLite TRANSPOSE_CONV"""
        try:
            from tflite.BuiltinOptions import BuiltinOptions
+            from tflite.Padding import Padding
            from tflite.TensorType import TensorType
            from tflite.TransposeConvOptions import TransposeConvOptions
-            from tflite.Padding import Padding
        except ImportError:
            raise ImportError("The tflite package must be installed")

@@ -2946,7 +2947,7 @@ def convert_transpose_conv(self, op):
                channels=int(out_channels),
                kernel_size=(int(kernel_h), int(kernel_w)),
                data_layout="NHWC",
-                kernel_layout="OIHW",
+                kernel_layout="IOHW",
                out_dtype="int32",
            )
        else:
@@ -2958,7 +2959,7 @@ def convert_transpose_conv(self, op):
                channels=int(out_channels),
                kernel_size=(int(kernel_h), int(kernel_w)),
                data_layout="NHWC",
-                kernel_layout="OIHW",
+                kernel_layout="IOHW",
                out_dtype=output_tensor_type_str,
            )

@@ -3717,8 +3718,8 @@ def from_tflite(model, shape_dict=None, dtype_dict=None, op_converter=OperatorConverter):
        The parameter dict to be used by relay
    """
    try:
-        import tflite.SubGraph
        import tflite.BuiltinOperator
+        import tflite.SubGraph
    except ImportError:
        raise ImportError("The tflite package must be installed")


python/tvm/relay/op/nn/nn.py

Lines changed: 1 addition & 1 deletion
@@ -522,7 +522,7 @@ def conv2d_transpose(
    channels=None,
    kernel_size=None,
    data_layout="NCHW",
-    kernel_layout="OIHW",
+    kernel_layout="IOHW",
    out_layout="",
    output_padding=(0, 0),
    out_dtype="",

python/tvm/relay/qnn/op/layout_conversions.py

Lines changed: 1 addition & 1 deletion
@@ -119,7 +119,7 @@ def convert_qnn_conv2d_transpose(attrs, inputs, tinfos, desired_layouts):

    # Handle default kernel layouts
    if desired_data_layout == "NCHW":
-        new_attrs["kernel_layout"] = "OIHW"
+        new_attrs["kernel_layout"] = "IOHW"
        return relay.qnn.op.conv2d_transpose(*inputs, **new_attrs)
    if desired_data_layout == "NHWC":
        new_attrs["kernel_layout"] = "HWIO"
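
A hedged usage sketch of how this converter gets invoked: request NCHW data layout via ConvertLayout and leave the kernel layout as "default" so the handler above picks IOHW. The module `mod` is assumed to already contain a qnn.conv2d_transpose call:

from tvm import relay

convert = relay.transform.ConvertLayout({"qnn.conv2d_transpose": ["NCHW", "default"]})
# mod = convert(mod)  # apply to a module with a qnn.conv2d_transpose op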
