Commit 15e185d

[Hexagon][QNN] Improve performance wo QNN canonicalization (#13734)
This commit improves the performance of several models tuned with MetaScheduler (MS) for the Hexagon target when QNN canonicalization is disabled. Benchmarks on Snapdragon 8gen1, tuned with MS:

model            | QNN canon enabled, ms | QNN canon disabled, ms | speedup
-----------------|-----------------------|------------------------|--------
ResNet, int8     | 50                    | 48                     | +4.2%
Inception, int8  | 103                   | 106                    | -2.8%
SRGAN, int8      | 348                   | 431                    | -19.3%

What was done:
1) Added 2 new passes: QnnLegalize and QnnCanonicalize. These are just wrappers around Legalize("FTVMQnnLegalize") and Legalize("FTVMQnnCanonicalize"), as sketched below.
2) Added the ability to disable inlining of specific blocks in the MetaSchedule AutoInline rule, e.g. via T.block_attr({"meta_schedule.inline_rule": "disable"}).
3) Implemented compute, alter-op, and legalization functions for the qnn.conv2d operation on the Hexagon target.
1 parent 77b6f0e commit 15e185d
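As a rough sketch of point 1) of the commit message above — an assumed equivalent for illustration, not the actual implementation in this commit — the two new passes are described as thin wrappers around the generic Legalize pass, keyed by the per-op attribute name:

    from tvm import relay

    def qnn_legalize():
        # assumed equivalent of the new QnnLegalize pass (Legalize over "FTVMQnnLegalize")
        return relay.transform.Legalize("FTVMQnnLegalize")

    def qnn_canonicalize():
        # assumed equivalent of the new QnnCanonicalize pass (Legalize over "FTVMQnnCanonicalize")
        return relay.transform.Legalize("FTVMQnnCanonicalize")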

File tree

14 files changed: +499 / -42 lines

include/tvm/relay/transform.h

Lines changed: 4 additions & 0 deletions
@@ -710,6 +710,10 @@ TVM_DLL Function UnCPS(const Function& f);
  */
 TVM_DLL Expr DeDup(const Expr& e);
 
+namespace legalize {
+TVM_DLL Expr Legalize(const Expr& expr, const std::string& legalize_map_attr_name);
+}  // namespace legalize
+
 }  // namespace relay
 }  // namespace tvm

include/tvm/tir/stmt.h

Lines changed: 3 additions & 0 deletions
@@ -1613,6 +1613,9 @@ constexpr const char* meta_schedule_auto_tensorize_init = "meta_schedule.auto_te
  */
 constexpr const char* warp_execution = "warp_execution";
 
+/*! \brief Mark that a block is disallowed in auto inline. */
+constexpr const char* meta_schedule_inline_rule = "meta_schedule.inline_rule";
+
 /*!
  * \brief Check if attr_key is a pragma key extension
  * \param attr_key The attr key to be compared
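A minimal TVMScript sketch (hypothetical kernel, not part of this commit) of how a block opts out of the MetaSchedule AutoInline rule via this new attribute:

    from tvm.script import tir as T

    @T.prim_func
    def add_one(a: T.handle, b: T.handle) -> None:
        A = T.match_buffer(a, (16,), "float32")
        B = T.match_buffer(b, (16,), "float32")
        for i in range(16):
            with T.block("compute"):
                vi = T.axis.spatial(16, i)
                # keep this block out of AutoInline during MetaSchedule tuning
                T.block_attr({"meta_schedule.inline_rule": "disable"})
                B[vi] = A[vi] + T.float32(1)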

python/tvm/relay/qnn/op/_qnn.py

Lines changed: 8 additions & 2 deletions
@@ -22,7 +22,7 @@
 from .. import strategy
 from ...op.op import register_compute
 from ...op.op import register_injective_schedule
-from ...op.op import register_strategy, register_pattern, OpPattern
+from ...op.op import register_strategy, register_pattern, register_alter_op_layout, OpPattern
 
 
 @register_compute("qnn.simulated_quantize")
@@ -83,7 +83,13 @@ def simulated_dequantize_compute(attrs, inputs, output_type):
 
 # qnn.conv2d
 register_strategy("qnn.conv2d", strategy.qnn_conv2d_strategy)
-register_pattern("qnn.conv2d", OpPattern.OUT_ELEMWISE_FUSABLE)
+
+
+@register_alter_op_layout("qnn.conv2d")
+def alter_op_layout_qnn_conv2d(attrs, inputs, tinfos, out_type):
+    """Alternate the layout of qnn.conv2d"""
+    return topi.nn.qnn_conv2d_alter_layout(attrs, inputs, tinfos, out_type)
+
 
 # qnn.dense
 register_strategy("qnn.dense", strategy.qnn_dense_strategy)

python/tvm/relay/qnn/op/legalizations.py

Lines changed: 70 additions & 0 deletions
@@ -405,6 +405,11 @@ def is_fast_int8_on_intel():
     return target_has_sse42(target.mcpu)
 
 
+# Helper function to align up given value.
+def helper_align_up(value, aligner):
+    return ((value + aligner) // aligner) * aligner
+
+
 ########################
 # ARM CPU legalizations.
 ########################
@@ -483,3 +488,68 @@ def _qnn_dense_legalize_cuda(attrs, inputs, types):
         # CUDA prefers both datatypes to be the int8.
         return helper_change_dtypes_to_int8(attrs, inputs, types, relay.qnn.op.dense)
     return None
+
+
+########################
+# Hexagon legalizations.
+########################
+
+IN_CHANNEL_VECTOR_LENGTH = 4
+OUT_CHANNEL_VECTOR_LENGTH = 32
+
+
+@qnn_conv2d_legalize.register("hexagon")
+def _qnn_conv2d_legalize_hexagon(attrs, inputs, types):
+    """Legalize qnn.conv2d op for vrmpy tensorization.
+
+    If the inputs are signed or unsigned int8 and data/kernel layouts are NCHW/OIHW, then the input
+    and output channels are padded to be a multiple of 4 and 32 respectively.
+    """
+    data_layout = attrs["data_layout"]
+    kernel_layout = attrs["kernel_layout"]
+
+    if data_layout != "NCHW" or kernel_layout != "OIHW":
+        return None
+
+    data_tensor, kernel_tensor = types[0], types[1]
+
+    if "int8" in data_tensor.dtype and "int8" in kernel_tensor.dtype:
+        in_channel = data_tensor.shape[1].value
+        out_channel = kernel_tensor.shape[0].value
+        ic_modified = False
+        oc_modified = False
+        data, kernel, input_zp, output_zp, input_scale, output_scale = inputs
+
+        if in_channel % IN_CHANNEL_VECTOR_LENGTH != 0:
+            new_in_channel = helper_align_up(in_channel, IN_CHANNEL_VECTOR_LENGTH)
+            diff = new_in_channel - in_channel
+            pad_width = ((0, 0), (0, diff), (0, 0), (0, 0))
+            data = relay.nn.pad(data, pad_width=pad_width)
+            kernel = relay.nn.pad(kernel, pad_width=pad_width)
+            ic_modified = True
+
+        new_out_channel = out_channel
+        if out_channel % OUT_CHANNEL_VECTOR_LENGTH != 0:
+            new_out_channel = helper_align_up(out_channel, OUT_CHANNEL_VECTOR_LENGTH)
+            diff = new_out_channel - out_channel
+            kernel = relay.nn.pad(kernel, pad_width=((0, diff), (0, 0), (0, 0), (0, 0)))
+            oc_modified = True
+
+        if ic_modified is True or oc_modified is True:
+            new_attrs = dict(attrs)
+            if oc_modified:
+                new_attrs["channels"] = new_out_channel
+                out = relay.qnn.op.conv2d(
+                    data, kernel, input_zp, output_zp, input_scale, output_scale, **new_attrs
+                )
+                output_tensor = types[6]
+                original_out_shape = list(output_tensor.shape)
+                out = relay.strided_slice(out, begin=[0, 0, 0, 0], end=original_out_shape)
+            else:
+                out = relay.qnn.op.conv2d(
+                    data, kernel, input_zp, output_zp, input_scale, output_scale, **new_attrs
+                )
+
+            return out
+
+    return None
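A small worked example (illustrative channel counts, not from the commit) of the padding arithmetic used by the Hexagon legalization above:

    IN_CHANNEL_VECTOR_LENGTH = 4
    OUT_CHANNEL_VECTOR_LENGTH = 32

    def helper_align_up(value, aligner):
        return ((value + aligner) // aligner) * aligner

    # helper_align_up is only called when the channel count is not already a multiple
    # of the vector length, so the simple formula above is sufficient.
    assert helper_align_up(3, IN_CHANNEL_VECTOR_LENGTH) == 4     # pad data/kernel by 1 input channel
    assert helper_align_up(30, OUT_CHANNEL_VECTOR_LENGTH) == 32  # pad kernel by 2 output channels

The extra output channels introduced by the padding are cropped off again with relay.strided_slice on the conv2d result, so the legalized graph keeps the original output shape.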

python/tvm/relay/qnn/strategy/hexagon.py

Lines changed: 13 additions & 0 deletions
@@ -17,12 +17,18 @@
 """Definition of Hexagon operator strategy."""
 # pylint: disable=unused-argument,wildcard-import,unused-wildcard-import
 
+import re
+
 from tvm import topi
 from .generic import *
 from ... import op as _op
 from ...op.strategy.generic import is_depthwise_conv2d
 
 
+NCHWC_MATCHER = re.compile("^NCHW[0-9]+c$")
+OIHWIOI_MATCHER = re.compile("^OIHW[0-9]+i[0-9]+o[0-9]+i$")
+
+
 @qnn_quantize_strategy.register("hexagon")
 def qnn_quantize_strategy_hexagon(attrs, inputs, out_type, target):
     """qnn.quantize strategy for Hexagon"""
@@ -135,6 +141,13 @@ def qnn_conv2d_strategy_hexagon(attrs, inputs, out_type, target):
                 wrap_topi_schedule(topi.hexagon.schedule_qnn_conv2d),
                 name="qnn_conv2d.hexagon",
             )
+        elif NCHWC_MATCHER.match(data_layout) and OIHWIOI_MATCHER.match(kernel_layout):
+            if data.dtype == "uint8" and kernel.dtype == "int8":
+                strategy.add_implementation(
+                    wrap_topi_qnn_conv2d(topi.hexagon.qnn_conv2d_NCHWc_int8),
+                    wrap_topi_schedule(topi.hexagon.schedule_qnn_conv2d_NCHWc_int8),
+                    name="qnn_conv2d_NCHWc_int8.hexagon",
+                )
     elif is_depthwise_conv2d(data.shape, data_layout, kernel.shape, kernel_layout, groups):
         if data_layout == "NCHW" and kernel_layout == "OIHW":
             strategy.add_implementation(
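For reference, a quick illustration (not part of the commit) of which layout strings the two matchers accept:

    import re

    NCHWC_MATCHER = re.compile("^NCHW[0-9]+c$")
    OIHWIOI_MATCHER = re.compile("^OIHW[0-9]+i[0-9]+o[0-9]+i$")

    assert NCHWC_MATCHER.match("NCHW32c")        # blocked data layout produced by the alter-op pass
    assert OIHWIOI_MATCHER.match("OIHW8i32o4i")  # blocked kernel layout produced by the alter-op pass
    assert not NCHWC_MATCHER.match("NCHW")       # plain NCHW is handled by the existing branch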

python/tvm/topi/hexagon/qnn/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -29,3 +29,4 @@
 from .qdepthwise_conv2d_slice import qdepthwise_conv2d_compute, qdepthwise_conv2d_schedule
 from .adaptive_avg_pool1d import *
 from .global_avg_pool2d import *
+from .conv2d_alter_op import *
python/tvm/topi/hexagon/qnn/conv2d_alter_op.py

Lines changed: 53 additions & 0 deletions
@@ -0,0 +1,53 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""QNN Conv2d alter op functions for Hexagon"""
+
+from tvm import relay
+from ...nn import qnn_conv2d_alter_layout
+from ...utils import get_const_tuple
+
+
+@qnn_conv2d_alter_layout.register("hexagon")
+def _alter_qnn_conv2d_layout(attrs, inputs, tinfos, _out_type):
+    data_layout = attrs["data_layout"]
+    kernel_layout = attrs["kernel_layout"]
+    data_tensor, kernel_tensor, _, _, _, _ = tinfos
+
+    if (
+        "int8" in data_tensor.dtype
+        and "int8" in kernel_tensor.dtype
+        and data_layout == "NCHW"
+        and kernel_layout == "OIHW"
+    ):
+        out_channel, in_channel, _, _ = get_const_tuple(kernel_tensor.shape)
+
+        if out_channel % 32 != 0 or in_channel % 4 != 0:
+            return None
+
+        n_elems = 4
+        oc_bn = 32
+        ic_bn = min(in_channel, 32)
+
+        new_attrs = dict(attrs)
+        new_attrs["channels"] = out_channel
+        new_attrs["data_layout"] = "NCHW%dc" % ic_bn
+        new_attrs["kernel_layout"] = "OIHW{:n}i{:n}o{:n}i".format(ic_bn // n_elems, oc_bn, n_elems)
+        new_attrs["out_layout"] = "NCHW%dc" % oc_bn
+
+        return relay.qnn.op.conv2d(*inputs, **new_attrs)
+
+    return None
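As an illustration (assumed kernel shape, not from the commit), for an int8 OIHW kernel of shape (64, 64, 3, 3) the function above picks ic_bn = min(64, 32) = 32, oc_bn = 32 and n_elems = 4, giving:

    ic_bn, oc_bn, n_elems = 32, 32, 4
    print("NCHW%dc" % ic_bn)                                               # data_layout:   NCHW32c
    print("OIHW{:n}i{:n}o{:n}i".format(ic_bn // n_elems, oc_bn, n_elems))  # kernel_layout: OIHW8i32o4i
    print("NCHW%dc" % oc_bn)                                               # out_layout:    NCHW32c

These blocked layouts are exactly the forms matched by NCHWC_MATCHER and OIHWIOI_MATCHER in the strategy file above, so (for uint8 data and int8 kernels) the rewritten qnn.conv2d is routed to the qnn_conv2d_NCHWc_int8 implementation.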
