Skip to content

Commit b958076

Browse files
committed
Add FP requantize flow. Set it by default for llvm x86 targets
1 parent 74a2fa8 commit b958076

File tree

11 files changed

+1122
-340
lines changed

11 files changed

+1122
-340
lines changed

include/tvm/relay/qnn/attrs.h

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ namespace qnn {
3636
struct RequantizeAttrs : public tvm::AttrsNode<RequantizeAttrs> {
3737
int axis;
3838
std::string rounding;
39+
std::string compute_dtype;
3940
DataType out_dtype;
4041

4142
TVM_DECLARE_ATTRS(RequantizeAttrs, "relay.attrs.RequantizeAttrs") {
@@ -44,7 +45,7 @@ struct RequantizeAttrs : public tvm::AttrsNode<RequantizeAttrs> {
4445
"The output channel axis for channel wise quantization. Default value is -1,"
4546
"which corresponds to the last axis.")
4647
.set_default(-1);
47-
TVM_ATTR_FIELD(rounding).set_default("UPWARD").describe(
48+
TVM_ATTR_FIELD(rounding).set_default("None").describe(
4849
"Defines the rounding direction when the value is midway between"
4950
"two representable values. There are two supported modes - UPWARD"
5051
"or TONEAREST. Both modes behave exactly same except at the"
@@ -54,6 +55,11 @@ struct RequantizeAttrs : public tvm::AttrsNode<RequantizeAttrs> {
5455
"value is rounded away from zero at midpoints (for example, -1.5"
5556
"rounds to -2). More context can be found at following gblic manual"
5657
"https://www.gnu.org/software/libc/manual/html_node/Rounding.html.");
58+
TVM_ATTR_FIELD(compute_dtype)
59+
.set_default("None")
60+
.describe(
61+
"Calculation flow type, specifies the algorithm of requantize calculating. Supported "
62+
"options: \"int64\", \"float32\", \"float64\"");
5763
TVM_ATTR_FIELD(out_dtype)
5864
.set_default(NullValue<DataType>())
5965
.describe("Output data type, set to explicit type under mixed precision setting");
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=unused-argument
"""Internal module for qnn requantization."""
import tvm._ffi

# Expose the C++ registered "relay._requantize.*" functions (config
# scope enter/exit and getters) as attributes of this module.
tvm._ffi._init_api("relay._requantize", __name__)

python/tvm/relay/qnn/op/qnn.py

Lines changed: 98 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,19 +14,109 @@
1414
# KIND, either express or implied. See the License for the
1515
# specific language governing permissions and limitations
1616
# under the License.
17-
# pylint: disable=invalid-name
17+
# pylint: disable=invalid-name,unused-argument, not-context-manager
1818
"""QNN dialect operators."""
1919

2020
from __future__ import absolute_import as _abs
2121

22+
import tvm
23+
import tvm.ir
2224
from tvm import relay
25+
from tvm.runtime import Object
2326
from tvm.relay.expr import Tuple, TupleWrapper
2427
from tvm.relay.op.nn.utils import get_pad_tuple2d
2528
from tvm.topi.nn.qnn import SQNN_DTYPE_TO_CODE
26-
29+
from tvm.target import Target
30+
from tvm.topi.x86.utils import target_is_x86
2731
from ... import op as reg
2832
from ...op import OpPattern
2933
from . import _make
34+
from . import _requantize
35+
36+
37+
@tvm._ffi.register_object("relay.qnn.op.RequantizeConfig")
class RequantizeConfig(Object):
    """Configure the requantization behavior by setting config variables.

    Note
    ----
    This object is backed by node system in C++, with arguments that can be
    exchanged between python and C++.

    Do not construct directly, use requantize_config instead.

    The fields that are backed by the C++ node are immutable once an instance
    is constructed. Use _node_defaults getters to get results for the fields.
    """

    @staticmethod
    def _get_node_default_rounding():
        # Default rounding mode for the fixed-point requantize flow.
        return "UPWARD"

    @staticmethod
    def _get_node_default_compute_dtype():
        # Prefer the float32 requantize flow for llvm x86 targets; all
        # other targets keep the int64 fixed-point flow.
        target = Target.current(True)
        is_llvm_x86 = target and str(target.kind) == "llvm" and target_is_x86(target.mcpu)
        return "float32" if is_llvm_x86 else "int64"

    # Maps each configurable field name to the callable that produces
    # its default value when the user does not override it.
    _node_defaults = {
        "rounding": _get_node_default_rounding.__func__,
        "compute_dtype": _get_node_default_compute_dtype.__func__,
    }

    # pylint: disable=no-member
    def __init__(self, handle):
        """Initialize the function with handle

        Parameters
        ----------
        handle : SymbolHandle
            the handle to the underlying C++ Symbol
        """
        super(RequantizeConfig, self).__init__(handle)
        self.handle = handle

    def __enter__(self):
        # Push this config onto the C++-side scope stack so nested
        # requantize lowering picks it up.
        # pylint: disable=protected-access
        _requantize._EnterRequantizeConfigScope(self)
        return self

    def __exit__(self, ptype, value, trace):
        # Pop the config scope regardless of how the `with` block exited.
        _requantize._ExitRequantizeConfigScope()

    def __setattr__(self, name, value):
        # The C++-backed config fields are immutable after construction;
        # reject any attempt to rebind them from Python.
        if name in RequantizeConfig._node_defaults:
            raise AttributeError("'%s' object cannot set attribute '%s'" % (str(type(self)), name))
        return super(RequantizeConfig, self).__setattr__(name, value)
93+
94+
95+
def current_requantize_config():
    """Return the requantization configuration currently in scope."""
    # Delegates to the C++ scope stack maintained by _requantize.
    return _requantize._GetCurrentRequantizeConfig()
98+
99+
100+
def requantize_config(**kwargs):
    """Configure the requantization behavior by setting config variables.

    Parameters
    ----------
    rounding: "UPWARD" or "TONEAREST"
        Rounding direction for fixed point multiplications.
    compute_dtype:
        Calculation flow type, specifies the algorithm of requantize calculating.
        Supported options: "int64", "float32", "float64"

    Returns
    -------
    config: RequantizeConfig
        The requantization configuration
    """
    # For each known field take the caller's override if given; otherwise
    # invoke the field's default getter (which may inspect the current target).
    node_args = {}
    for field, default_fn in RequantizeConfig._node_defaults.items():
        node_args[field] = kwargs[field] if field in kwargs else default_fn()
    return tvm.ir.make_node("relay.qnn.op.RequantizeConfig", **node_args)
30120

31121

32122
def requantize(
@@ -36,7 +126,8 @@ def requantize(
36126
output_scale,
37127
output_zero_point,
38128
axis=-1,
39-
rounding="UPWARD",
129+
rounding="None",
130+
compute_dtype="None",
40131
out_dtype="int8",
41132
):
42133
r"""Requantized operator.
@@ -70,7 +161,9 @@ def requantize(
70161
rounding : string, optional
71162
Defines the rounding direction when the value is midway between two
72163
representable values.
73-
164+
compute_dtype:
165+
Calculation flow type, specifies the algorithm of requantize calculating.
166+
Supported options: \"int64\", \"float32\", \"float64\"
74167
out_dtype : str, optional
75168
Specifies the output data type.
76169
@@ -88,6 +181,7 @@ def requantize(
88181
output_zero_point,
89182
axis,
90183
rounding,
184+
compute_dtype,
91185
out_dtype,
92186
)
93187

python/tvm/topi/x86/utils.py

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,87 @@
1616
# under the License.
1717
"""Common x86 related utilities"""
1818
import tvm
19+
import tvm._ffi
1920

2021

22+
@tvm._ffi.register_func("tvm.topi.x86.utils.target_is_x86")
def target_is_x86(target):
    """Return True when *target* (an llvm ``-mcpu`` string) names an x86 CPU.

    A CPU qualifies either by supporting one of the tracked SIMD feature
    levels (SSE4.1 up to VNNI) or by appearing in the explicit model list
    below.
    """
    # mcpu names matched directly rather than through the feature helpers.
    base_x86_cpus = {
        "i386", "i486", "i586", "i686", "x86-64", "generic", "lakemont",
        "pentium", "pentium-mmx", "pentiumpro", "pentium2", "pentium3",
        "pentium3m", "pentium-m", "pentium4", "pentium4m", "prescott",
        "nocona", "yonah", "core2", "bonnell", "atom", "alderlake",
        "k6", "k6-2", "k6-3", "athlon", "athlon-tbird", "athlon-4",
        "athlon-xp", "athlon-mp", "athlon-fx", "athlon64", "athlon64-sse3",
        "k8", "k8-sse3", "opteron", "opteron-sse3", "amdfam10", "barcelona",
        "btver1", "geode", "winchip-c6", "winchip2", "c3", "c3-2",
    }
    return (
        target in base_x86_cpus
        or target_has_sse41(target)
        or target_has_sse42(target)
        or target_has_avx(target)
        or target_has_avx2(target)
        or target_has_avx512(target)
        or target_has_vnni(target)
    )
81+
82+
83+
@tvm._ffi.register_func("tvm.topi.x86.utils.target_has_sse41")
def target_has_sse41(target):
    """Return True when the llvm ``-mcpu`` string *target* supports SSE4.1."""
    # Models matched at this feature level directly; anything newer is
    # covered by the higher feature-level helpers below.
    if target in {"btver2", "penryn"}:
        return True
    return (
        target_has_sse42(target)
        or target_has_avx(target)
        or target_has_avx2(target)
        or target_has_avx512(target)
        or target_has_vnni(target)
    )
97+
98+
99+
@tvm._ffi.register_func("tvm.topi.x86.utils.target_has_sse42")
21100
def target_has_sse42(target):
22101
return (
23102
target_has_avx(target)
@@ -42,6 +121,7 @@ def target_has_sse42(target):
42121
)
43122

44123

124+
@tvm._ffi.register_func("tvm.topi.x86.utils.target_has_avx")
45125
def target_has_avx(target):
46126
return (
47127
target_has_avx2(target)
@@ -51,6 +131,7 @@ def target_has_avx(target):
51131
)
52132

53133

134+
@tvm._ffi.register_func("tvm.topi.x86.utils.target_has_avx2")
54135
def target_has_avx2(target):
55136
return (
56137
target_has_avx512(target)
@@ -70,6 +151,7 @@ def target_has_avx2(target):
70151
)
71152

72153

154+
@tvm._ffi.register_func("tvm.topi.x86.utils.target_has_avx512")
73155
def target_has_avx512(target):
74156
return target in {
75157
"skylake-avx512",
@@ -89,6 +171,7 @@ def target_has_avx512(target):
89171
}
90172

91173

174+
@tvm._ffi.register_func("tvm.topi.x86.utils.target_has_vnni")
92175
def target_has_vnni(target):
93176
return target in {
94177
"cascadelake",
@@ -102,6 +185,7 @@ def target_has_vnni(target):
102185
}
103186

104187

188+
@tvm._ffi.register_func("tvm.topi.x86.utils.get_simd_32bit_lanes")
105189
def get_simd_32bit_lanes():
106190
mcpu = tvm.target.Target.current().mcpu
107191
fp32_vec_len = 4

0 commit comments

Comments
 (0)