
Commit fc103f5

masahi authored and xinetzone committed
[Hexagon] vrmpy tensorization for e2e compilation of int8 models (apache#12911)
* [Hexagon] Support vrmpy tensorization for conv2d and dense schedules
* update
* clean up
* migrate tests to test_launcher.py
* remove vrmpy test files
* use generic int8 conv2d schedule
* clean up
* doc update
* pylint fix
* parametrize dtype in test
* doc update
* add missing parallelization for dense
* more pylint
* fixed for fp32 dense
1 parent 318305a commit fc103f5
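For context, a minimal end-to-end sketch (not part of this diff) of how an int8 conv2d Relay module can be compiled for Hexagon so that the strategies and schedules added below get selected. The shapes, dtypes, and the "v68" CPU version are illustrative assumptions.

import numpy as np
import tvm
from tvm import relay

# A small uint8 conv2d; out_dtype="int32" matches the vrmpy-friendly path.
data = relay.var("data", shape=(1, 64, 56, 56), dtype="uint8")
weight_np = np.random.randint(0, 255, size=(128, 64, 3, 3)).astype("uint8")
conv = relay.nn.conv2d(
    data,
    relay.const(weight_np),
    kernel_size=(3, 3),
    padding=(1, 1),
    channels=128,
    out_dtype="int32",
)
mod = tvm.IRModule.from_expr(conv)

target_hexagon = tvm.target.hexagon("v68")
with tvm.transform.PassContext(opt_level=3):
    # At opt_level 3, AlterOpLayout can rewrite nn.conv2d into
    # nn.contrib_conv2d_nchwc via the registrations added in this commit.
    lib = relay.build(mod, target=tvm.target.Target(target_hexagon, host=target_hexagon))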

File tree

10 files changed: +662 −6 lines

python/tvm/relay/op/strategy/hexagon.py

Lines changed: 36 additions & 1 deletion
@@ -30,7 +30,7 @@ def batch_matmul_strategy_hexagon(attrs, inputs, out_type, target):
     """batch_matmul strategy for Hexagon"""
     strategy = _op.OpStrategy()
     strategy.add_implementation(
-        wrap_compute_batch_matmul(topi.nn.batch_matmul),
+        wrap_compute_batch_matmul(topi.nn.batch_matmul, need_out_dtype=True),
         wrap_topi_schedule(topi.hexagon.schedule_batch_matmul),
         name="batch_matmul.hexagon",
     )
@@ -187,3 +187,38 @@ def schedule_reduce_hexagon(attrs, outs, target):
     """Schedule reduction ops for Hexagon"""
     with target:
         return topi.hexagon.schedule_reduce(outs)
+
+
+@conv2d_NCHWc_strategy.register("hexagon")
+def conv2d_NCHWc_strategy_hexagon(attrs, inputs, out_type, target):
+    """conv2d_NCHWc_ hexagon strategy"""
+    strategy = _op.OpStrategy()
+    strategy.add_implementation(
+        wrap_compute_conv2d(
+            topi.hexagon.conv2d_NCHWc_int8, need_data_layout=True, need_out_layout=True
+        ),
+        wrap_topi_schedule(topi.hexagon.schedule_conv2d_NCHWc_int8),
+        name="conv2d_NCHWc_int8.hexagon",
+    )
+    return strategy
+
+
+@dense_pack_strategy.register("hexagon")
+def dense_pack_strategy_hexagon(attrs, inputs, out_type, target):
+    """dense_pack hexagon strategy"""
+    strategy = _op.OpStrategy()
+
+    if (
+        inputs[0].dtype == "uint8"
+        and inputs[1].dtype == "uint8"
+        and out_type.dtype == "int32"
+        and attrs["weight_layout"] == "NC32n4c"
+    ):
+        strategy.add_implementation(
+            wrap_compute_dense(topi.hexagon.dense.dense_u8u8i32_vrmpy_compute),
+            wrap_topi_schedule(topi.hexagon.dense.dense_u8u8i32_vrmpy_schedule),
+            name="dense_uint8.hexagon",
+            plevel=12,
+        )
+
+    return strategy
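Illustration (not from the diff): the new dense_pack branch fires only for uint8 x uint8 -> int32 with the weight already packed as NC32n4c. A hand-built Relay expression of that shape might look like the sketch below; in practice the packed form is normally produced by the dense alter-op pass rather than written by hand, and the exact contrib_dense_pack keyword usage here is an assumption.

from tvm import relay

m, n, k = 128, 128, 256  # assumed sizes with n % 32 == 0 and k % 4 == 0
data = relay.var("data", shape=(m, k), dtype="uint8")
# weight packed as NC32n4c: (N // 32, K // 4, 32, 4)
packed_w = relay.var("weight", shape=(n // 32, k // 4, 32, 4), dtype="uint8")
out = relay.nn.contrib_dense_pack(data, packed_w, weight_layout="NC32n4c", out_dtype="int32")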

python/tvm/topi/generic/conv2d.py

Lines changed: 10 additions & 1 deletion
@@ -139,7 +139,16 @@ def schedule_conv_NCHWc_cpu_common_int8(
     More details - https://software.intel.com/en-us/articles/
     lower-numerical-precision-deep-learning-inference-and-training
     """
-    reg_n, unroll_kw = cfg["tile_ow"].size[-1], cfg["unroll_kw"].val
+    if isinstance(cfg["tile_ow"], int):
+        reg_n = cfg["tile_ow"]
+    else:
+        reg_n = cfg["tile_ow"].size[-1]
+
+    if isinstance(cfg["unroll_kw"], (int, bool)):
+        unroll_kw = cfg["unroll_kw"]
+    else:
+        unroll_kw = cfg["unroll_kw"].val
+
     _, _, _, _, ic_bn = get_const_tuple(data_vec.shape)
     _, _, _, _, oc_bn = get_const_tuple(conv_out.shape)
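A minimal sketch (not part of the change) of the two cfg shapes this dispatch now tolerates: a plain Python dict, as passed by the Hexagon schedule further below, or an AutoTVM-style config whose entries expose .size / .val. The stand-in classes are hypothetical, for illustration only.

plain_cfg = {"tile_ow": 4, "unroll_kw": False}  # the form the Hexagon schedule passes

class FakeSplitEntity:  # stand-in for an AutoTVM split entity
    def __init__(self, size):
        self.size = size

class FakeOtherEntity:  # stand-in for an AutoTVM other-option entity
    def __init__(self, val):
        self.val = val

autotvm_like_cfg = {"tile_ow": FakeSplitEntity([-1, 4]), "unroll_kw": FakeOtherEntity(True)}

def read_cfg(cfg):
    # Mirrors the isinstance-based extraction above.
    reg_n = cfg["tile_ow"] if isinstance(cfg["tile_ow"], int) else cfg["tile_ow"].size[-1]
    unroll_kw = (
        cfg["unroll_kw"]
        if isinstance(cfg["unroll_kw"], (int, bool))
        else cfg["unroll_kw"].val
    )
    return reg_n, unroll_kw

assert read_cfg(plain_cfg) == (4, False)
assert read_cfg(autotvm_like_cfg) == (4, True)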

python/tvm/topi/hexagon/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -29,3 +29,5 @@
 from .resize2d import *
 from .tensor_intrin import *
 from .qnn import *
+from .dense_alter_op import *
+from .conv2d_alter_op import *

python/tvm/topi/hexagon/conv2d.py

Lines changed: 48 additions & 1 deletion
@@ -14,11 +14,15 @@
 # KIND, either express or implied. See the License for the
 # specific language governing permissions and limitations
 # under the License.
-
+# pylint: disable=invalid-name
 """Schedule for conv2d"""

 import tvm
+from tvm import te
+from .. import nn
 from ..utils import traverse_inline
+from .tensor_intrin import dot_vrmpy
+from ..generic import conv2d as conv2d_generic


 def schedule_conv2d_nhwc(outs):
@@ -86,3 +90,46 @@ def _callback(op):

     traverse_inline(s, outs[0].op, _callback)
     return s
+
+
+def conv2d_NCHWc_int8(
+    data, kernel, stride, padding, dilation, layout, out_layout, out_dtype="int32"
+):
+    """Compute definition for int8 conv2d in NCHWc layout"""
+    n_elems = int(kernel.shape[-1])
+    return nn.conv2d_NCHWc_int8(
+        data, kernel, stride, padding, dilation, layout, out_layout, out_dtype, n_elems=n_elems
+    )
+
+
+def schedule_conv2d_NCHWc_int8(outs):
+    """Schedule for int8 conv2d in NCHWc layout using vrmpy tensorization"""
+    s = te.create_schedule([x.op for x in outs])
+
+    def _callback(op):
+        if "conv2d_NCHWc_int8" in op.tag:
+            conv_out = op.output(0)
+            kernel_vec = conv_out.op.input_tensors[1]
+            data_vec = conv_out.op.input_tensors[0]
+            out_width = conv_out.shape[3]
+
+            reg_n = 1
+            for n in range(31, 0, -1):
+                if out_width % n == 0:
+                    reg_n = n
+                    break
+
+            cfg = {"tile_ow": reg_n, "unroll_kw": False}
+            args = [s, cfg, data_vec, kernel_vec, conv_out, outs[0]]
+            intrin = dot_vrmpy(data_vec.dtype, kernel_vec.dtype)
+
+            conv2d_generic.schedule_conv_NCHWc_cpu_common_int8(
+                *args,
+                int32_lanes=32,
+                int8_elems=4,
+                intrin=intrin,
+                inline_fused=True,
+            )
+
+    traverse_inline(s, outs[0].op, _callback)
+    return s
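A tiny standalone sketch of the fallback ow-tiling choice made in schedule_conv2d_NCHWc_int8 above: without a tuned config, the schedule takes the largest divisor of the output width that does not exceed 31. The example values below are assumptions, not from the diff.

def pick_reg_n(out_width):
    # Largest divisor of out_width in [1, 31], mirroring the loop above.
    for n in range(31, 0, -1):
        if out_width % n == 0:
            return n
    return 1

assert pick_reg_n(56) == 28
assert pick_reg_n(7) == 7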
python/tvm/topi/hexagon/conv2d_alter_op.py

Lines changed: 111 additions & 0 deletions
@@ -0,0 +1,111 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name,unused-variable,unused-argument,no-member
+"""Conv2d alter op functions for Hexagon"""
+
+from tvm import relay
+from ..utils import get_const_tuple
+from .. import nn
+from ..nn import conv2d_alter_layout
+from ..generic.conv2d import conv2d_alter_int8_common
+
+
+@conv2d_alter_layout.register("hexagon")
+def _alter_conv2d_layout(attrs, inputs, tinfos, out_type):
+    """Convert nn.conv2d into nn.contrib_conv2d_nchwc if vrmpy is applicable."""
+    new_attrs = {k: attrs[k] for k in attrs.keys()}
+
+    data_layout = attrs["data_layout"]
+    kernel_layout = attrs["kernel_layout"]
+    data_tensor, kernel_tensor = tinfos
+    out_channel, in_channel, _, _ = get_const_tuple(kernel_tensor.shape)
+
+    if (
+        "int8" in data_tensor.dtype
+        and "int8" in kernel_tensor.dtype
+        and out_channel % 32 == 0
+        and in_channel % 4 == 0
+        and data_layout == "NCHW"
+        and kernel_layout == "OIHW"
+    ):
+        out_channel, in_channel, _, _ = get_const_tuple(kernel_tensor.shape)
+
+        n_elems = 4
+        oc_bn = 32
+        ic_bn = min(in_channel, 32)
+
+        new_attrs = {k: attrs[k] for k in attrs.keys()}
+
+        new_attrs["channels"] = out_channel
+        new_attrs["data_layout"] = "NCHW%dc" % ic_bn
+        new_attrs["kernel_layout"] = "OIHW{:n}i{:n}o{:n}i".format(ic_bn // n_elems, oc_bn, n_elems)
+        new_attrs["out_layout"] = "NCHW%dc" % oc_bn
+
+        return relay.nn.contrib_conv2d_nchwc(*inputs, **new_attrs)
+
+    return None
+
+
+@nn.conv2d_legalize.register("hexagon")
+def _conv2d_legalize(attrs, inputs, arg_types):
+    """Legalize conv2d op for vrmpy tensorization.
+
+    If the inputs are signed or unsigned int8, the input and output channels are padded to be
+    a multiple of 4 and 32, respectively.
+
+    If the input data types are (int8, int8), they are converted to (uint8, int8) and
+    the vector-by-vector variant of vrmpy is applied.
+    If the input data types are (uint8, uint8), the more efficient vector-by-scalar variant of
+    vrmpy is applied.
+
+    Unlike the nn.dense case (see dense_alter_op.py), we do not convert (uint8, int8) to
+    (uint8, uint8). That would introduce another convolution by a constant (128 or 1) filter,
+    to compensate for the dtype legalization. In the nn.dense case, such a compensation factor
+    is just a sum over the K axis.
+    """
+    data_layout = attrs["data_layout"]
+    kernel_layout = attrs["kernel_layout"]
+
+    output_tensor = arg_types[2]
+
+    data, kernel = inputs
+
+    if data_layout != "NCHW" or kernel_layout != "OIHW":
+        return None
+
+    data_tensor, kernel_tensor = arg_types[0], arg_types[1]
+
+    if "int8" in data_tensor.dtype and "int8" in kernel_tensor.dtype:
+        output_tensor = arg_types[2]
+        data, kernel = inputs
+        desired_data_dtype = "uint8"
+        in_channel_vector_length = 4
+        out_channel_vector_length = 32
+
+        return conv2d_alter_int8_common(
+            data,
+            data_tensor,
+            kernel,
+            kernel_tensor,
+            output_tensor,
+            attrs,
+            desired_data_dtype,
+            in_channel_vector_length,
+            out_channel_vector_length,
+        )
+
+    return None
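A small sketch (not from the commit) of the applicability test encoded in _alter_conv2d_layout above: the NCHWc conversion targeted at vrmpy is only attempted when both dtypes contain "int8" (so int8 or uint8) and the channel counts fit the 32-output / 4-input packing.

def vrmpy_conv2d_applicable(
    data_dtype, kernel_dtype, out_channel, in_channel,
    data_layout="NCHW", kernel_layout="OIHW",
):
    # Same predicate as the if-condition in _alter_conv2d_layout.
    return (
        "int8" in data_dtype
        and "int8" in kernel_dtype
        and out_channel % 32 == 0
        and in_channel % 4 == 0
        and data_layout == "NCHW"
        and kernel_layout == "OIHW"
    )

# e.g. a 128x64x3x3 OIHW kernel with uint8 data qualifies ("int8" is a
# substring of both "int8" and "uint8"):
assert vrmpy_conv2d_applicable("uint8", "int8", out_channel=128, in_channel=64)
assert not vrmpy_conv2d_applicable("float32", "float32", 128, 64)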

python/tvm/topi/hexagon/dense.py

Lines changed: 72 additions & 1 deletion
@@ -14,10 +14,14 @@
 # KIND, either express or implied. See the License for the
 # specific language governing permissions and limitations
 # under the License.
-
+# pylint: disable=invalid-name
 """Schedule for dense operator"""

 import tvm
+from tvm.topi.utils import traverse_inline
+from tvm import te
+from .. import tag
+from .tensor_intrin import dot_vrmpy


 def schedule_dense(outs):
@@ -38,3 +42,70 @@ def schedule_dense(outs):
     s = tvm.te.create_schedule([x.op for x in outs])
     tvm.te.schedule.AutoInlineInjective(s)
     return s
+
+
+def dense_u8u8i32_vrmpy_compute(X, packed_w, bias, out_dtype):
+    """Compute for uint8 x uint8 -> int32 dense using vrmpy"""
+    assert X.dtype == "uint8" and packed_w.dtype == "uint8" and out_dtype == "int32"
+    m, k = X.shape
+    n_o, _, n_i, _ = packed_w.shape
+    assert n_i == 32
+    ak = te.reduce_axis((0, k), name="k")
+
+    C = te.compute(
+        (m, n_o * n_i),
+        lambda i, j: te.sum(
+            X[i, ak].astype("int32")
+            * packed_w[tvm.tir.indexdiv(j, 32), tvm.tir.indexdiv(ak, 4), j % 32, ak % 4].astype(
+                "int32"
+            ),
+            axis=ak,
+        ),
+        tag="dense_u8u8i32_vrmpy",
+        name="compute",
+    )
+
+    if bias is not None:
+        C = te.compute(C.shape, lambda i, j: C[i, j] + bias[j], tag=tag.BROADCAST)
+
+    return C
+
+
+def dense_u8u8i32_vrmpy_schedule(outs):
+    """Schedule for vrmpy dense"""
+    s = te.create_schedule([x.op for x in outs])
+    # O: The output of the fused op
+    O = outs[0]
+
+    def _schedule_dense(s, C, O):
+        (a_k,) = C.op.reduce_axis
+        a_y = C.op.axis[-2]
+        a_yo, a_yi = s[C].split(a_y, factor=32)
+        a_xo, a_xi = s[C].split(C.op.axis[-1], factor=32)
+        a_ko, a_ki = s[C].split(a_k, factor=4)
+
+        s[C].reorder(a_yo, a_xo, a_yi, a_ko, a_xi, a_ki)
+
+        pc = dot_vrmpy("uint8", "uint8")
+        s[C].tensorize(a_xi, pc)
+        s[C].parallel(s[C].fuse(a_yo, a_xo))
+
+        if C != O:
+            a_y = O.op.axis[-2]
+            a_yo, a_yi = s[O].split(a_y, factor=32)
+            a_xo, a_xi = s[O].split(O.op.axis[-1], factor=32)
+
+            s[O].reorder(a_yo, a_xo, a_yi, a_xi)
+            s[O].vectorize(a_xi)
+            s[C].compute_at(s[O], a_yi)
+            s[O].parallel(s[O].fuse(a_yo, a_xo))
+
+    def _callback(op):
+        if "u8u8i32_vrmpy" in op.tag:
+            # C: The output of GEMM
+            C = op.output(0)
+            _schedule_dense(s, C, O)
+
+    traverse_inline(s, outs[0].op, _callback)
+
+    return s
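A NumPy reference (illustration only, not part of the diff) for what dense_u8u8i32_vrmpy_compute evaluates: the weight is packed into the NC32n4c layout, and the packed index packed_w[j // 32, k // 4, j % 32, k % 4] recovers the plain (N, K) weight, so the result equals an ordinary int32 matmul. The sizes below are assumptions chosen to satisfy the 32/4 packing.

import numpy as np

m, n, k = 8, 64, 128  # assumes n % 32 == 0 and k % 4 == 0
x = np.random.randint(0, 255, (m, k)).astype(np.uint8)
w = np.random.randint(0, 255, (n, k)).astype(np.uint8)

# Pack (N, K) -> (N // 32, K // 4, 32, 4), i.e. the "NC32n4c" layout.
packed_w = w.reshape(n // 32, 32, k // 4, 4).transpose(0, 2, 1, 3)

ref = x.astype(np.int32) @ w.astype(np.int32).T
out = np.zeros((m, n), dtype=np.int32)
for i in range(m):
    for j in range(n):
        for kk in range(k):
            out[i, j] += int(x[i, kk]) * int(packed_w[j // 32, kk // 4, j % 32, kk % 4])

np.testing.assert_array_equal(out, ref)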
