
Commit 7068e63

Add FP32 requantize flow for llvm target
1 parent 224f8da commit 7068e63

File tree

5 files changed: +362 -12 lines changed

python/tvm/relay/qnn/op/qnn.py

Lines changed: 38 additions & 0 deletions
@@ -92,6 +92,44 @@ def requantize(
     )
 
 
+def upward(data):
+    r"""Upward operator.
+
+    UPWARD is the standard rounding except at midpoints, where the value
+    is rounded towards positive infinity (for example, -1.5 rounds to -1).
+
+    Parameters
+    ----------
+    data : tvm.relay.Expr
+        The input data to the operator.
+
+    Returns
+    -------
+    result : tvm.relay.Expr
+        The computed result.
+    """
+
+    return _make.upward(data)
+
+
+def tonearest(data):
+    r"""Tonearest operator.
+
+    TONEAREST is the standard rounding where the value is rounded away
+    from zero at midpoints (for example, -1.5 rounds to -2).
+
+    Parameters
+    ----------
+    data : tvm.relay.Expr
+        The input data to the operator.
+
+    Returns
+    -------
+    result : tvm.relay.Expr
+        The computed result.
+    """
+
+    return _make.tonearest(data)
+
+
 def quantize(data, output_scale, output_zero_point, axis=-1, out_dtype="int8"):
     r"""Quantize op
     This operator takes float32 as input and produces quantized int8 or uint8 as output.
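
The two rounding modes differ only at exact .5 midpoints. As a reference for the docstrings above, a minimal NumPy sketch of both rules (illustration only, not part of this commit):

import numpy as np

x = np.array([-2.5, -1.5, -0.5, 0.5, 1.5, 2.5], dtype=np.float32)

# TONEAREST: round half away from zero.
print(np.sign(x) * np.floor(np.abs(x) + 0.5))  # [-3. -2. -1.  1.  2.  3.]

# UPWARD: round half towards +inf, i.e. floor(x + 0.5).
print(np.floor(x + 0.5))                       # [-2. -1.  0.  1.  2.  3.]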

python/tvm/topi/x86/utils.py

Lines changed: 25 additions & 2 deletions
@@ -16,8 +16,26 @@
 # under the License.
 """Common x86 related utilities"""
 import tvm
+import tvm._ffi
 
 
+@tvm._ffi.register_func("tvm.topi.x86.utils.target_has_sse41")
+def target_has_sse41(target):
+    return (
+        target_has_sse42(target)
+        or target_has_avx(target)
+        or target_has_avx2(target)
+        or target_has_avx512(target)
+        or target_has_vnni(target)
+        or target
+        in {
+            "btver2",
+            "penryn",
+        }
+    )
+
+
+@tvm._ffi.register_func("tvm.topi.x86.utils.target_has_sse42")
 def target_has_sse42(target):
     return (
         target_has_avx(target)
@@ -42,6 +60,7 @@ def target_has_sse42(target):
     )
 
 
+@tvm._ffi.register_func("tvm.topi.x86.utils.target_has_avx")
 def target_has_avx(target):
     return (
         target_has_avx2(target)
@@ -51,6 +70,7 @@ def target_has_avx(target):
     )
 
 
+@tvm._ffi.register_func("tvm.topi.x86.utils.target_has_avx2")
 def target_has_avx2(target):
     return (
         target_has_avx512(target)
@@ -70,6 +90,7 @@ def target_has_avx2(target):
     )
 
 
+@tvm._ffi.register_func("tvm.topi.x86.utils.target_has_avx512")
 def target_has_avx512(target):
     return target in {
         "skylake-avx512",
@@ -82,26 +103,28 @@ def target_has_avx512(target):
         "cascadelake",
         "icelake-client",
         "rocketlake",
-        "icelake",
+        "icelake-server",
         "tigerlake",
         "cooperlake",
         "sapphirerapids",
     }
 
 
+@tvm._ffi.register_func("tvm.topi.x86.utils.target_has_vnni")
 def target_has_vnni(target):
     return target in {
         "cascadelake",
         "icelake-client",
         "rocketlake",
-        "icelake",
+        "icelake-server",
         "tigerlake",
         "cooperlake",
         "sapphirerapids",
         "alderlake",
     }
 
 
+@tvm._ffi.register_func("tvm.topi.x86.utils.get_simd_32bit_lanes")
 def get_simd_32bit_lanes():
     mcpu = tvm.target.Target.current().mcpu
     fp32_vec_len = 4
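
Because the predicates are now registered via tvm._ffi.register_func, C++ code can look them up by name in the global registry (as requantize.cc does below), and they stay callable from Python. A small usage sketch, assuming this patch is applied:

from tvm.topi.x86.utils import target_has_sse41

print(target_has_sse41("penryn"))          # True: listed explicitly above
print(target_has_sse41("skylake-avx512"))  # True: implied through target_has_avx512
print(target_has_sse41("core2"))           # False, assuming it appears in none of the feature sets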

src/relay/qnn/op/requantize.cc

Lines changed: 232 additions & 5 deletions
@@ -26,6 +26,7 @@
 #include <tvm/relay/op_attr_types.h>
 #include <tvm/relay/qnn/attrs.h>
 
+#include "../../op/op_common.h"
 #include "../../transforms/infer_layout_utils.h"
 #include "../../transforms/pattern_utils.h"
 #include "../utils.h"
@@ -111,6 +112,106 @@ InferCorrectLayoutOutput RequantizeInferCorrectLayout(const Attrs& attrs,
   return InferCorrectLayoutOutput(input_layouts, output_layouts, Attrs(param));
 }
 
+bool has_current_target_sse41_support() {
+  auto target = Target::Current(false);
+  Optional<String> mcpu = target->GetAttr<String>("mcpu");
+  auto target_has_sse41_fn_ptr = tvm::runtime::Registry::Get("tvm.topi.x86.utils.target_has_sse41");
+  ICHECK(target_has_sse41_fn_ptr) << "Function tvm.topi.x86.utils.target_has_sse41 not found";
+  return mcpu && (*target_has_sse41_fn_ptr)(mcpu.value());
+}
+
+/*
+ * \brief TONEAREST is the standard rounding where the value is rounded away
+ *   from zero at midpoints (for example, -1.5 rounds to -2).
+ * \param input_tensor The input tensor to the rounding op.
+ * \return The sequence of existing Relay ops.
+ */
+Expr Tonearest(const Expr& input_tensor) {
+  if (has_current_target_sse41_support()) return Round(input_tensor);
+
+  auto half = MakeConstantScalar(DataType::Float(32), 0.5f);
+  auto zero = MakeConstantScalar(DataType::Float(32), 0.f);
+  auto pos_one = MakeConstantScalar(DataType::Float(32), +1.f);
+  auto neg_one = MakeConstantScalar(DataType::Float(32), -1.f);
+  auto multiplier = Where(Less(input_tensor, zero), neg_one, pos_one);
+  auto half_multiplied = Multiply(half, multiplier);
+  auto input_tensor_biased = Add(input_tensor, half_multiplied);
+  auto input_tensor_biased_multiplied = Multiply(input_tensor_biased, multiplier);
+  auto input_tensor_biased_multiplied_int32 =
+      Cast(input_tensor_biased_multiplied, DataType::Int(32));
+  auto input_tensor_biased_multiplied_float32 =
+      Cast(input_tensor_biased_multiplied_int32, DataType::Float(32));
+  auto input_tensor_rounded = Multiply(input_tensor_biased_multiplied_float32, multiplier);
+  return Where(IsFinite(input_tensor), input_tensor_rounded, input_tensor);
+}
+
+/*
+ * \brief UPWARD is the standard rounding except at midpoints, where the value
+ *   is rounded towards positive infinity (for example, -1.5 rounds to -1).
+ * \param input_tensor The input tensor to the rounding op.
+ * \return The sequence of existing Relay ops.
+ */
+Expr Upward(const Expr& input_tensor) {
+  auto half = MakeConstantScalar(DataType::Float(32), 0.5f);
+  auto input_tensor_biased = Add(input_tensor, half);
+  if (has_current_target_sse41_support()) return Floor(input_tensor_biased);
+
+  auto zero = MakeConstantScalar(DataType::Float(32), 0.f);
+  auto one = MakeConstantScalar(DataType::Float(32), +1.f);
+  auto input_tensor_biased_int_32 = Cast(input_tensor_biased, DataType::Int(32));
+  auto input_tensor_biased_float32 = Cast(input_tensor_biased_int_32, DataType::Float(32));
+  auto is_subtraction_not_necessary =
+      LogicalOr(Equal(input_tensor_biased, input_tensor_biased_float32),
+                GreaterEqual(input_tensor_biased, zero));
+  auto input_tensor_rounded = Where(is_subtraction_not_necessary, input_tensor_biased_float32,
+                                    Subtract(input_tensor_biased_float32, one));
+  return Where(IsFinite(input_tensor), input_tensor_rounded, input_tensor);
+}
+
+// Positional relay function to create the tonearest operator,
+// used by the frontend FFI.
+Expr MakeTonearest(const Attrs& attrs, const Array<Expr>& new_args,
+                   const Array<tvm::relay::Type>& types) {
+  ICHECK_EQ(new_args.size(), 1);
+  auto& data = new_args[0];
+  return Tonearest(data);
+}
+
+RELAY_REGISTER_OP("tonearest")
+    .set_num_inputs(1)
+    .add_argument("data", "Tensor", "The input tensor.")
+    .add_type_rel("Identity", IdentityRel)
+    .set_attr<TOpPattern>("TOpPattern", kElemWise)
+    .set_attr<TOpIsStateful>("TOpIsStateful", false)
+    .set_attr<FTVMLegalize>("FTVMQnnCanonicalize", MakeTonearest);
+
+TVM_REGISTER_GLOBAL("relay.qnn.op._make.tonearest").set_body_typed([](Expr data) {
+  static const Op& op = Op::Get("tonearest");
+  return Call(op, {data}, Attrs(), {});
+});
+
+// Positional relay function to create the upward operator,
+// used by the frontend FFI.
+Expr MakeUpward(const Attrs& attrs, const Array<Expr>& new_args,
+                const Array<tvm::relay::Type>& types) {
+  ICHECK_EQ(new_args.size(), 1);
+  auto& data = new_args[0];
+  return Upward(data);
+}
+
+RELAY_REGISTER_OP("upward")
+    .set_num_inputs(1)
+    .add_argument("data", "Tensor", "The input tensor.")
+    .add_type_rel("Identity", IdentityRel)
+    .set_attr<TOpPattern>("TOpPattern", kElemWise)
+    .set_attr<TOpIsStateful>("TOpIsStateful", false)
+    .set_attr<FTVMLegalize>("FTVMQnnCanonicalize", MakeUpward);
+
+TVM_REGISTER_GLOBAL("relay.qnn.op._make.upward").set_body_typed([](Expr data) {
+  static const Op& op = Op::Get("upward");
+  return Call(op, {data}, Attrs(), {});
+});
+
 // Lowering of qnn.requantize op
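
The sign flip in the non-SSE4.1 branch of Tonearest makes the value non-negative before the int32 cast, so truncation acts as floor. A NumPy reference of that decomposition (illustration only, not part of this commit):

import numpy as np

def tonearest_ref(x):
    m = np.where(x < 0, -1.0, 1.0).astype(np.float32)    # multiplier
    biased = x + np.float32(0.5) * m                     # bias by +/- 0.5
    flipped = biased * m                                 # make non-negative
    trunc = flipped.astype(np.int32).astype(np.float32)  # truncation == floor here
    return np.where(np.isfinite(x), trunc * m, x)        # restore sign, pass through inf/nan

print(tonearest_ref(np.array([-1.5, -0.4, 0.4, 1.5], dtype=np.float32)))
# [-2. -0.  0.  2.]
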
@@ -119,7 +220,7 @@ InferCorrectLayoutOutput RequantizeInferCorrectLayout(const Attrs& attrs,
  * \param param The requantize op attrs.
  * \param input_shape The input tensor shape of the requantize op.
  * \return The sequence of existing Relay ops.
- * \note Requantization using only integer computation. Here, the computation is
+ * \note RequantizationI32 uses only integer computation. Here, the computation is
  *   converted to a fixed point computation by computing output multiplier
  *   and shift. This is useful if the target device does not support, or has
  *   very expensive, floating point computations.
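
For contrast with the FP32 path added below: the fixed point decomposition this note refers to expresses the ratio input_scale / output_scale as an integer multiplier plus a shift. A sketch of the idea (hypothetical helper name, not the code in this commit):

import math

def fixed_point_multiplier_shift(scale):
    # scale == significand * 2**shift, with 0.5 <= significand < 1
    significand, shift = math.frexp(scale)
    q31 = int(round(significand * (1 << 31)))  # Q31 integer multiplier
    return q31, shift

print(fixed_point_multiplier_shift(0.5 / 0.25))  # (1073741824, 2): 2.0 == (2**30 / 2**31) * 2**2
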
@@ -131,10 +232,10 @@ InferCorrectLayoutOutput RequantizeInferCorrectLayout(const Attrs& attrs,
  * 4) Add the output zero point.
  * 5) Cast to the out_dtype.
  */
-Expr RequantizeLower(const Expr& input_tensor, const Expr& input_scale,
-                     const Expr& input_zero_point, const Expr& output_scale,
-                     const Expr& output_zero_point, const RequantizeAttrs* param,
-                     const Array<IndexExpr>& input_shape, const DataType& out_dtype) {
+Expr RequantizeLowerI32(const Expr& input_tensor, const Expr& input_scale,
+                        const Expr& input_zero_point, const Expr& output_scale,
+                        const Expr& output_zero_point, const RequantizeAttrs* param,
+                        const Array<IndexExpr>& input_shape, const DataType& out_dtype) {
   auto tensor = Cast(input_tensor, DataType::Int(32));
   auto zero_scalar = MakeConstantScalar(DataType::Int(32), 0);
   if (!IsEqualScalar(input_zero_point, zero_scalar)) {
@@ -208,6 +309,132 @@ Expr RequantizeLower(const Expr& input_tensor, const Expr& input_scale,
   return Cast(clipped_t, out_dtype);
 }
 
+// Lowering of qnn.requantize op
+
+/*
+ * \brief Lower requantize to a sequence of ops.
+ * \param input_tensor The input tensor to the requantize op.
+ * \param param The requantize op attrs.
+ * \param input_shape The input tensor shape of the requantize op.
+ * \return The sequence of existing Relay ops.
+ * \note RequantizationFP32 uses floating point computation. All multiply/subtract/add
+ *   operations happen in float32; only at the end is the result converted to int32
+ *   and clamped to the range of the output data type.
+ *
+ *   The whole computation can be broken down into the following steps:
+ *   1) Subtract the input zero point.
+ *   2) Perform the multiplication.
+ *   3) Add the output zero point.
+ *   4) Cast to the out_dtype.
+ */
+Expr RequantizeLowerFP32(const Expr& input_tensor, const Expr& input_scale,
+                         const Expr& input_zero_point, const Expr& output_scale,
+                         const Expr& output_zero_point, const RequantizeAttrs* param,
+                         const Array<IndexExpr>& input_shape, const DataType& out_dtype) {
+  auto tensor = Cast(input_tensor, DataType::Int(32));
+  auto zero_scalar = MakeConstantScalar(DataType::Int(32), 0);
+  if (!IsEqualScalar(input_zero_point, zero_scalar)) {
+    // 1) Subtract the input zero point, broadcasting it if needed.
+    int rank = static_cast<int>(input_shape.size());
+    int axis = (param->axis < 0) ? ((rank > 0) ? rank + param->axis : 0) : param->axis;
+    Expr input_zero_broadcast = ExpandBiasToMatchAxis(Reshape(input_zero_point,
+                                                              {
+                                                                  -1,
+                                                              }),
+                                                      rank, {axis});
+    tensor = Subtract(Cast(tensor, DataType::Float(32)),
+                      Cast(input_zero_broadcast, DataType::Float(32)));
+  } else {
+    tensor = Cast(tensor, DataType::Float(32));
+  }
+
+  // 2) If the input and output scales are the same, we can skip the multiplication. Check
+  // if the input scale is per-tensor or per-channel. If it is per-tensor, there is a single
+  // scale for the whole tensor. For per-channel (aka per-axis), there is a vector of scales
+  // for the input tensor. Depending on the quantization type, the corresponding
+  // multiplication routine is called.
+  auto scaled_int32_t = tensor;
+  float output_scale_float = GetScalarFromConstant<float>(output_scale);
+  if (IsConstScalar(input_scale)) {
+    // This is per-tensor quantization. Single scale.
+    float input_scale_float = GetScalarFromConstant<float>(input_scale);
+    double double_multiplier =
+        static_cast<double>(input_scale_float) / static_cast<double>(output_scale_float);
+    // Skip if input and output scales are the same.
+    if (!IsEqualScalar(input_scale, output_scale)) {
+      float multiplier = double_multiplier;
+      auto m_scalar = MakeConstantScalar(DataType::Float(32), multiplier);
+      scaled_int32_t = Multiply(m_scalar, scaled_int32_t);
+    }
+
+  } else {
+    // This is per-channel (per-axis) quantization.
+    std::vector<float> double_multipliers;
+    auto input_axis_scales = GetFloatVectorFromConstant(input_scale);
+    float output_scale_float = GetScalarFromConstant<float>(output_scale);
+    for (auto input_axis_scale : input_axis_scales) {
+      double multiplier =
+          static_cast<double>(input_axis_scale) / static_cast<double>(output_scale_float);
+      double_multipliers.push_back(multiplier);
+    }
+    int axis = param->axis;
+    axis = (axis == -1) ? input_shape.size() - 1 : axis;
+
+    auto fixed_pt_multiplier_expr = MakeConstantTensor(
+        DataType::Float(32), {(int64_t)double_multipliers.size()}, double_multipliers);
+    size_t n_dim = input_shape.size();
+    auto exp_fixed_pt_multiplier_expr =
+        ExpandBiasToMatchAxis(fixed_pt_multiplier_expr, n_dim, {axis});
+
+    scaled_int32_t = Multiply(scaled_int32_t, exp_fixed_pt_multiplier_expr);
+  }
+
+  // 3) Add the output zero point.
+  auto shifted_int32_t = scaled_int32_t;
+  if (!IsEqualScalar(output_zero_point, zero_scalar)) {
+    shifted_int32_t = Add(shifted_int32_t, Cast(output_zero_point, DataType::Float(32)));
+  }
+
+  if (param->rounding == "UPWARD") {
+    shifted_int32_t = Upward(shifted_int32_t);
+  } else /*if (param->rounding == "TONEAREST")*/ {
+    shifted_int32_t = Tonearest(shifted_int32_t);
+  }
+
+  shifted_int32_t = Cast(shifted_int32_t, DataType::Int(32));
+  // 4) Clip to the out_dtype min/max. Skip clipping if out_dtype is Int32: the
+  // multiplication keeps the value in the int32 range.
+  if (out_dtype == DataType::Int(32)) {
+    return shifted_int32_t;
+  }
+
+  auto q_min = GetQmin(out_dtype);
+  auto q_max = GetQmax(out_dtype);
+  auto clipped_t = Clip(shifted_int32_t, q_min, q_max);
+  return Cast(clipped_t, out_dtype);
+}
+
+// Lowering of qnn.requantize op
+/*
+ * \brief Lower requantize to a sequence of ops, dispatching on the current target.
+ * \param input_tensor The input tensor to the requantize op.
+ * \param param The requantize op attrs.
+ * \param input_shape The input tensor shape of the requantize op.
+ * \return The sequence of existing Relay ops.
+ */
+Expr RequantizeLower(const Expr& input_tensor, const Expr& input_scale,
+                     const Expr& input_zero_point, const Expr& output_scale,
+                     const Expr& output_zero_point, const RequantizeAttrs* param,
+                     const Array<IndexExpr>& input_shape, const DataType& out_dtype) {
+  auto target = Target::Current(true);
+  if (target->kind->name == "llvm") {
+    return RequantizeLowerFP32(input_tensor, input_scale, input_zero_point, output_scale,
+                               output_zero_point, param, input_shape, out_dtype);
+  } else {
+    return RequantizeLowerI32(input_tensor, input_scale, input_zero_point, output_scale,
+                              output_zero_point, param, input_shape, out_dtype);
+  }
+}
+
 /*
  * \brief Forward rewrite the requantize op.
  * \param ref_call The original call that will be lowered.
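
Putting the FP32 steps together: a NumPy reference of the per-tensor RequantizeLowerFP32 path with UPWARD rounding and int8 output (illustration only; the helper name and values are ours):

import numpy as np

def requantize_fp32_ref(x_int32, in_scale, in_zp, out_scale, out_zp):
    t = x_int32.astype(np.float32) - np.float32(in_zp)  # 1) subtract input zero point
    t = t * np.float32(in_scale / out_scale)            # 2) single FP32 multiplier
    t = t + np.float32(out_zp)                          # 3) add output zero point
    t = np.floor(t + 0.5)                               # UPWARD rounding
    return np.clip(t, -128, 127).astype(np.int8)        # 4) clip and cast to int8

x = np.array([-100, -1, 0, 5, 100], dtype=np.int32)
print(requantize_fp32_ref(x, in_scale=0.5, in_zp=1, out_scale=0.25, out_zp=-3))
# [-128   -7   -5    5  127]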
