Commit 116a60d

Author: ibsidorenko
[QNN][Relay][Topi] Add qnn.dense_pack with weight layout
This commit adds the new Relay operation "qnn.dense_pack", which supports different weight layouts ("nn.dense" and "qnn.dense" do not support this attribute). The new operation is a full analog of the "nn.contrib_dense_pack" operation, but in QNN space.
Parent: fc98e9c
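
For context, a minimal usage sketch (not part of the diff) of the new operation; the signature comes from the qnn.py change below, while the shapes, scales, and zero points are illustrative only:

from tvm import relay

# 128x64 uint8 activations; a 64x64 kernel pre-packed into the blocked
# "NC32n4c" layout: (64/32 outer rows, 64/4 outer cols, 32 inner, 4 inner).
data = relay.var("data", shape=(128, 64), dtype="uint8")
weight = relay.var("weight", shape=(2, 16, 32, 4), dtype="uint8")

out = relay.qnn.op.dense_pack(
    data,
    weight,
    input_zero_point=relay.const(0, "int32"),
    kernel_zero_point=relay.const(0, "int32"),
    input_scale=relay.const(0.15, "float32"),
    kernel_scale=relay.const(0.15, "float32"),
    kernel_layout="NC32n4c",
    units=None,
    out_dtype="int32",
)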

File tree

13 files changed: +680 additions, -80 deletions


python/tvm/relay/qnn/op/_qnn.py

Lines changed: 10 additions & 1 deletion
@@ -93,7 +93,16 @@ def alter_op_layout_qnn_conv2d(attrs, inputs, tinfos, out_type):
 
 # qnn.dense
 register_strategy("qnn.dense", strategy.qnn_dense_strategy)
-register_pattern("qnn.dense", OpPattern.OUT_ELEMWISE_FUSABLE)
+
+
+@register_alter_op_layout("qnn.dense")
+def alter_op_layout_qnn_dense(attrs, inputs, tinfos, out_type):
+    """Alternate the layout of qnn.dense"""
+    return topi.nn.qnn_dense_alter_layout(attrs, inputs, tinfos, out_type)
+
+
+# qnn.dense_pack
+register_strategy("qnn.dense_pack", strategy.qnn_dense_pack_strategy)
 
 # qnn.batch_matmul
 register_strategy("qnn.batch_matmul", strategy.qnn_batch_matmul_strategy)

python/tvm/relay/qnn/op/legalizations.py

Lines changed: 107 additions & 0 deletions
@@ -340,6 +340,62 @@ def helper_change_dtypes_to_int8(attrs, inputs, types, relay_op):
     )
 
 
+def helper_change_dtypes_to_uint8(attrs, inputs, types, relay_op):
+    """Helper function to change dtypes to uint8 x uint8.
+    Legalizes the QNN dense op for the Hexagon DSP, which supports the fast u8 x u8 vrmpy instruction.
+
+    Converting from int8 to uint8 can be done in the following manner:
+
+    Original equation:
+        scale * (QA - zp_a)
+        scale * (QA + 128 - 128 - zp_a)
+        scale * ((QA + 128) - (zp_a + 128))
+
+    Replacing QA + 128 with QA' and (zp_a + 128) with zp_a',
+    we get the new quantized uint8 tensor: scale * (QA' - zp_a')
+
+    Parameters
+    ----------
+    attrs : tvm.ir.Attrs
+        Attributes of the current operation
+    inputs : list of tvm.relay.Expr
+        The args of the Relay expr to be legalized
+    types : list of types
+        List of input and output types
+
+    Returns
+    -------
+    result : tvm.relay.Expr
+        The legalized expr
+    """
+    # Collect the dtypes.
+    data_dtype = types[0].dtype
+    kernel_dtype = types[1].dtype
+
+    # Do nothing since it is already uint8.
+    if data_dtype == "uint8" and kernel_dtype == "uint8":
+        return None
+
+    # Collect the input exprs.
+    data, kernel, input_zero_point, kernel_zero_point, input_scale, kernel_scale = inputs
+
+    # Shift input if necessary.
+    if data_dtype == "int8":
+        # Compute (QA + 128) and (zp_a + 128)
+        data, input_zero_point = _shift(data, input_zero_point, "uint8")
+
+    # Shift kernel if necessary.
+    if kernel_dtype == "int8":
+        # Compute (QW + 128) and (zp_w + 128)
+        kernel, kernel_zero_point = _shift(kernel, kernel_zero_point, "uint8")
+
+    # Call qnn.conv2d/qnn.dense with modified inputs and zero points.
+    new_attrs = dict(attrs)
+    return relay_op(
+        data, kernel, input_zero_point, kernel_zero_point, input_scale, kernel_scale, **new_attrs
+    )
+
+
 # Helper function to change dtypes to be same. ARM dotprod instructions prefer this setting.
 def helper_change_dtypes_to_be_same(attrs, inputs, types, relay_op):
     """Sometimes MxNet + MLDNN can lead to uint8 x int8 datatypes for the conv inputs. However,
@@ -555,3 +611,54 @@ def _qnn_conv2d_legalize_hexagon(attrs, inputs, types):
         return out
 
     return None
+
+
+@qnn_dense_legalize.register("hexagon")
+def _qnn_dense_legalize_hexagon(attrs, inputs, types):
+    """Legalize the qnn.dense op for vrmpy tensorization.
+
+    The N dimension of the weights should be aligned to the vector length. If it is not,
+    the N dimension is padded to be a multiple of 32.
+    """
+    assert len(types) == 7
+    assert len(inputs) == 6
+
+    data_tensor, kernel_tensor = types[0], types[1]
+    if "int8" not in data_tensor.dtype or "int8" not in kernel_tensor.dtype:
+        return None
+
+    N, _ = kernel_tensor.shape
+
+    if N % OUT_CHANNEL_VECTOR_LENGTH != 0:
+        N_padded = helper_align_up(N, OUT_CHANNEL_VECTOR_LENGTH)
+        diff = N_padded - N
+
+        # Pad weights by 'diff'
+        padded_kernel = relay.nn.pad(inputs[1], pad_width=((0, diff), (0, 0)))
+
+        # If units is explicitly specified, it is used to compute the output shape.
+        # We need to update units after padding to prevent a type error.
+        new_attrs = dict(attrs)
+        if attrs["units"] is not None:
+            new_attrs["units"] = N + diff
+
+        new_inputs = (inputs[0], padded_kernel, *inputs[2:])
+
+        # TODO: enable legalization u8i8i32 -> u8u8i32 for qnn.dense. Code:
+        # """
+        # new_types = (types[0], relay.TensorType([N + diff, C], types[1].dtype), *types[2:])
+        # out = helper_change_dtypes_to_uint8(new_attrs, new_inputs, new_types, relay.qnn.op.dense)
+        # if out is None:
+        #     out = relay.qnn.op.dense(*new_inputs, **new_attrs)
+        # """
+        out = relay.qnn.op.dense(*new_inputs, **new_attrs)
+
+        output_tensor = types[6]
+        out = relay.strided_slice(out, begin=[0, 0], end=list(output_tensor.shape))
+        return out
+
+    # TODO: enable legalization u8i8i32 -> u8u8i32 for qnn.dense. Code:
+    # """
+    # return helper_change_dtypes_to_uint8(attrs, inputs, types, relay.qnn.op.dense)
+    # """
+    return None
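
To make the padding path concrete, a small sketch of the alignment arithmetic; align_up is a local stand-in for the helper_align_up used above, and 32 is the OUT_CHANNEL_VECTOR_LENGTH assumed by the vrmpy schedule:

def align_up(value, align):
    # Round value up to the nearest multiple of align.
    return ((value + align - 1) // align) * align

N = 50                      # weight rows before legalization
N_padded = align_up(N, 32)  # 64
diff = N_padded - N         # 14 zero rows appended by nn.pad and
                            # later removed by strided_slice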

python/tvm/relay/qnn/op/qnn.py

Lines changed: 64 additions & 0 deletions
@@ -718,6 +718,70 @@ def dense(
     )
 
 
+def dense_pack(
+    data,
+    weight,
+    input_zero_point,
+    kernel_zero_point,
+    input_scale,
+    kernel_scale,
+    kernel_layout="NC",
+    units=None,
+    out_dtype="int32",
+):
+    """Qnn dense_pack operator.
+    Applies a quantized linear transformation
+
+    .. math::
+
+        Y = X * W
+
+    If per-channel quantization is used, qnn expects kernel_scale and
+    optionally kernel_zero_point to be 1-D vectors instead of scalars.
+
+    Parameters
+    ----------
+    data : tvm.relay.Expr
+        The quantized input data to the operator.
+    weight : tvm.relay.Expr
+        The quantized weight expressions.
+    input_zero_point: tvm.relay.Expr
+        The input zero point.
+    kernel_zero_point: tvm.relay.Expr
+        The kernel zero point.
+    input_scale: tvm.relay.Expr
+        The scale for the input tensor.
+    kernel_scale: tvm.relay.Expr
+        The scale for the weight tensor, stored so that it is accessible
+        during Relay passes. This information is not needed in the pass
+        pipeline after qnn.dense is lowered to the sequence of steps as in
+        nn.dense. See also input_scale in Requantize.
+    kernel_layout: str
+        The layout of the weight, such as "NC" or "NC32n4c".
+    units : int, optional
+        Number of hidden units of the dense transformation.
+    out_dtype : str, optional
+        Specifies the output data type for mixed precision dense; can be int32 or int16.
+
+    Returns
+    -------
+    result : tvm.relay.Expr
+        The computed result.
+    """
+
+    return _make.dense_pack(
+        data,
+        weight,
+        input_zero_point,
+        kernel_zero_point,
+        input_scale,
+        kernel_scale,
+        kernel_layout,
+        units,
+        out_dtype,
+    )
+
+
 def mul(
     lhs,
     rhs,
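
As a hedged NumPy sketch (not part of the commit), one way a plain "NC" weight could be re-blocked into the "NC32n4c" layout that dense_pack consumes:

import numpy as np

N, C = 64, 64
w = np.arange(N * C).reshape(N, C).astype("uint8")  # layout "NC"
# Split N into (N/32, 32) and C into (C/4, 4), then order the axes as
# N-outer, C-outer, n-inner(32), c-inner(4).
w_packed = w.reshape(N // 32, 32, C // 4, 4).transpose(0, 2, 1, 3)
assert w_packed.shape == (2, 16, 32, 4)  # "NC32n4c"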

python/tvm/relay/qnn/strategy/generic.py

Lines changed: 6 additions & 0 deletions
@@ -267,6 +267,12 @@ def qnn_dense_strategy(attrs, inputs, out_type, target):
     )
 
 
+@override_native_generic_func("qnn_dense_pack_strategy")
+def qnn_dense_pack_strategy(attrs, inputs, out_type, target):
+    """qnn.dense_pack generic strategy"""
+    raise RuntimeError("qnn.dense_pack is currently only supported with Hexagon.")
+
+
 @override_native_generic_func("qnn_batch_matmul_strategy")
 def qnn_batch_matmul_strategy(attrs, inputs, out_type, target):
     """qnn.batch_matmul generic strategy"""

python/tvm/relay/qnn/strategy/hexagon.py

Lines changed: 18 additions & 0 deletions
@@ -173,6 +173,24 @@ def qnn_dense_strategy_hexagon(attrs, inputs, out_type, target):
     return strategy
 
 
+@qnn_dense_pack_strategy.register("hexagon")
+def qnn_dense_pack_strategy_hexagon(attrs, inputs, out_type, target):
+    """qnn.dense_pack strategy for Hexagon"""
+    strategy = _op.OpStrategy()
+    if (
+        "uint8" in inputs[0].dtype
+        and "int8" in inputs[1].dtype
+        and attrs["weight_layout"] == "NC32n4c"
+    ):
+        # uint8 + uint8|int8 case
+        strategy.add_implementation(
+            wrap_topi_qnn_dense(topi.hexagon.qnn_dense_pack_vrmpy),
+            wrap_topi_schedule(topi.hexagon.schedule_qnn_dense_pack_vrmpy),
+            name="qnn_dense_pack_vrmpy.hexagon",
+        )
+    return strategy
+
+
 @qnn_batch_matmul_strategy.register("hexagon")
 def qnn_batch_matmul_strategy_hexagon(attrs, inputs, out_type, target):
     """qnn.batch_matmul strategy for Hexagon"""

python/tvm/topi/hexagon/qnn/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -20,6 +20,7 @@
 from .adaptive_avg_pool1d import *
 from .avg_pool2d import qnn_avg_pool2d_compute, qnn_avg_pool2d_schedule
 from .conv2d_alter_op import *
+from .dense_alter_op import *
 from .dequantize import dequantize_compute, dequantize_schedule
 from .global_avg_pool2d import *
 from .nn import *
python/tvm/topi/hexagon/qnn/dense_alter_op.py (new file)

Lines changed: 33 additions & 0 deletions
@@ -0,0 +1,33 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""QNN Dense alter op functions for Hexagon"""
+
+from tvm import relay
+from ..dense_alter_op import check_vrmpy_applicable
+from ...nn import qnn_dense_alter_layout
+
+
+@qnn_dense_alter_layout.register("hexagon")
+def _alter_qnn_dense_layout(_attrs, inputs, tinfos, out_type):
+    data_tensor = tinfos[0]
+    weight_tensor = tinfos[1]
+
+    if check_vrmpy_applicable(data_tensor, weight_tensor):
+        weight_layout = "NC32n4c"
+        return relay.qnn.op.dense_pack(*inputs, weight_layout, None, out_type.dtype)
+    else:
+        return None
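
Putting the pieces together, a hedged end-to-end sketch: under a Hexagon target, AlterOpLayout rewrites a qnn.dense whose tensors satisfy check_vrmpy_applicable (its exact conditions live elsewhere in topi) into qnn.dense_pack with the "NC32n4c" weight layout:

import tvm
from tvm import relay

data = relay.var("data", shape=(128, 64), dtype="uint8")
weight = relay.var("weight", shape=(64, 64), dtype="uint8")
zp = relay.const(0, "int32")
scale = relay.const(0.1, "float32")

mod = tvm.IRModule.from_expr(
    relay.qnn.op.dense(data, weight, zp, zp, scale, scale, units=64)
)
with tvm.target.Target("hexagon"):
    mod = relay.transform.InferType()(mod)
    mod = relay.transform.AlterOpLayout()(mod)  # qnn.dense -> qnn.dense_pack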
