Skip to content

Commit d9d6a88

Browse files
authored
[QNN] Support Dequantize to "float16" and Quantize to "uint16" (#15235)
1 parent 3a33771 commit d9d6a88

File tree

7 files changed

+95
-22
lines changed

7 files changed

+95
-22
lines changed

include/tvm/relay/qnn/attrs.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,9 +95,11 @@ struct SimulatedQuantizeAttrs : public tvm::AttrsNode<SimulatedQuantizeAttrs> {
9595

9696
/*! \brief Attribute for dequantize operator */
9797
struct DequantizeAttrs : public tvm::AttrsNode<DequantizeAttrs> {
98+
DataType out_dtype;
9899
int axis;
99100

100101
TVM_DECLARE_ATTRS(DequantizeAttrs, "relay.attrs.DequantizeAttrs") {
102+
TVM_ATTR_FIELD(out_dtype).describe("Output data type, can be one of [float16, float32].");
101103
TVM_ATTR_FIELD(axis)
102104
.describe(
103105
"The channel axis for channel wise dequantization. Default value is -1,"

python/tvm/relay/qnn/op/qnn.py

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -186,8 +186,8 @@ def requantize(
186186

187187
def quantize(data, output_scale, output_zero_point, axis=-1, out_dtype="int8"):
188188
r"""Quantize op
189-
This operator takes float32 as input and produces quantized int8 or unit8 as output.
190-
The input tensor can be of any shape. The output shape is the same as input shape.
189+
This operator takes float32 input and produces quantized output. The input
190+
tensor can be of any shape. The output shape is the same as input shape.
191191
192192
Q_output = clamp((round(input_tensor/output_scale) + output_zero_point),
193193
out_dtype::min,
@@ -206,8 +206,9 @@ def quantize(data, output_scale, output_zero_point, axis=-1, out_dtype="int8"):
206206
207207
axis : int
208208
The channel axis for quantization. Default value is -1 which corresponds to the last axis.
209+
209210
out_dtype : str, optional
210-
The data type of the input tensor. Can be [int8, uint8, int32]
211+
The data type of the output tensor. Can be [int8, uint8, int16, uint16, int32].
211212
212213
Returns
213214
-------
@@ -256,16 +257,15 @@ def simulated_quantize(data, output_scale, output_zero_point, axis=-1, out_dtype
256257
return _make.simulated_quantize(data, out_dtype, output_scale, output_zero_point, axis)
257258

258259

259-
def dequantize(data, input_scale, input_zero_point, axis=-1):
260+
def dequantize(data, input_scale, input_zero_point, axis=-1, out_dtype="float32"):
260261
r"""Dequantize op
261-
This operator takes quantized int8 and unit8 as input and produces
262-
dequantized float32 as output. The output shape is the same as input shape. The input
263-
tensor can be of any shape.
262+
This operator takes quantized input and produces dequantized float output.
263+
The output shape is the same as input shape. The input tensor can be of any shape.
264264
265265
Parameters
266266
----------
267267
data : tvm.relay.Expr
268-
The input tensor to be dequantized. Can be of type [int8, uint8, int32].
268+
The input tensor to be dequantized. Can be of type [int8, uint8, int16, uint16, int32].
269269
270270
input_scale : tvm.relay.Expr
271271
The input scale.
@@ -276,13 +276,16 @@ def dequantize(data, input_scale, input_zero_point, axis=-1):
276276
axis : int
277277
The channel axis for quantization. Default value is -1 which corresponds to the last axis.
278278
279+
out_dtype : str, optional
280+
The data type of the output tensor. Can be [float16, float32].
281+
279282
Returns
280283
-------
281284
result : tvm.relay.Expr
282285
The computed result.
283286
"""
284287

285-
return _make.dequantize(data, input_scale, input_zero_point, axis)
288+
return _make.dequantize(data, input_scale, input_zero_point, axis, out_dtype)
286289

287290

288291
def simulated_dequantize(data, input_scale, input_zero_point, axis=-1, in_dtype="int8"):

src/relay/qnn/op/dequantize.cc

Lines changed: 21 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -47,9 +47,10 @@ bool DequantizeRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
4747

4848
const auto input_dtype = data->dtype;
4949
ICHECK(input_dtype == DataType::Int(8) || input_dtype == DataType::UInt(8) ||
50-
input_dtype == DataType::Int(16) || input_dtype == DataType::Int(32))
51-
<< "Input type should be one of the quantized types [unit8, int8, int16, int32] but was "
52-
<< input_dtype;
50+
input_dtype == DataType::Int(16) || input_dtype == DataType::UInt(16) ||
51+
input_dtype == DataType::Int(32))
52+
<< "Input type should be one of the quantized types [int8, uint8, int16, uint16, int32] but "
53+
<< "was " << input_dtype;
5354

5455
const auto* dequantize_attrs = attrs.as<DequantizeAttrs>();
5556
int axis = dequantize_attrs->axis;
@@ -77,18 +78,24 @@ bool DequantizeRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
7778
// Check and assign types for scale and zero points.
7879
AssignType(types[1], DataType::Float(32), axis_shape, reporter); // scale
7980
AssignType(types[2], DataType::Int(32), axis_shape, reporter); // zero point
81+
8082
const Array<tvm::PrimExpr> oshape = data->shape;
81-
// assign output type, output will always be float 32.
82-
reporter->Assign(types[3], TensorType(oshape, DataType::Float(32)));
83+
const DataType out_dtype = dequantize_attrs->out_dtype;
84+
ICHECK(out_dtype == DataType::Float(16) || out_dtype == DataType::Float(32))
85+
<< "Output type should be one of [float16, float32] but was " << out_dtype;
86+
// assign output type.
87+
reporter->Assign(types[3], TensorType(oshape, out_dtype));
8388
return true;
8489
}
8590

86-
Expr MakeDequantize(Expr data, Expr input_scale, Expr input_zero_point, int axis) {
91+
Expr MakeDequantize(Expr data, Expr input_scale, Expr input_zero_point, int axis,
92+
DataType out_dtype) {
8793
// real_value = scale * (quantized_value - zero_point)
8894
// A more detailed explanation can be found here -
8995
// https://github.com/google/gemmlowp/blob/master/doc/quantization.md
9096
auto attrs = make_object<DequantizeAttrs>();
9197
attrs->axis = axis;
98+
attrs->out_dtype = out_dtype;
9299
static const Op& op = Op::Get("qnn.dequantize");
93100
return Call(op, {data, input_scale, input_zero_point}, Attrs(attrs), {});
94101
}
@@ -125,7 +132,14 @@ Expr DequantizeLower(const Expr& input_tensor, const Expr& input_scale,
125132

126133
auto shift = Subtract(Cast(input_tensor, DataType::Int(32)), expanded_input_zero_point);
127134
auto scaled_output = Multiply(Cast(shift, DataType::Float(32)), expanded_input_scale);
128-
return scaled_output;
135+
136+
const DataType out_dtype = attrs->out_dtype;
137+
if (out_dtype.is_float() && out_dtype.bits() == 32) return scaled_output;
138+
139+
double min_val = tvm::min_value(out_dtype).as<FloatImmNode>()->value;
140+
double max_val = tvm::max_value(out_dtype).as<FloatImmNode>()->value;
141+
auto clamped_output = Clip(scaled_output, min_val, max_val);
142+
return Cast(clamped_output, out_dtype);
129143
}
130144

131145
Expr DequantizeQnnCanonicalize(const Attrs& attrs, const Array<Expr>& new_args,

src/relay/qnn/op/quantize.cc

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -91,8 +91,9 @@ bool QuantizeRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
9191
const Array<tvm::PrimExpr> oshape = data->shape;
9292
const DataType out_dtype = quantize_attrs->out_dtype;
9393
ICHECK(out_dtype == DataType::Int(8) || out_dtype == DataType::UInt(8) ||
94-
out_dtype == DataType::Int(16) || out_dtype == DataType::Int(32))
95-
<< "Output type should be one of [int8, unit8, int16, int32] but was " << out_dtype;
94+
out_dtype == DataType::Int(16) || out_dtype == DataType::UInt(16) ||
95+
out_dtype == DataType::Int(32))
96+
<< "Output type should be one of [int8, uint8, int16, uint16, int32] but was " << out_dtype;
9697
// assign output type
9798
reporter->Assign(types[3], TensorType(oshape, out_dtype));
9899
return true;

src/relay/qnn/utils.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -135,7 +135,8 @@ static inline Expr Dequantize(const Expr& data, const Expr& input_scale,
135135

136136
return DequantizeLower(data, input_scale, input_zero_point, types, attrs.operator->());
137137
}
138-
Expr MakeDequantize(Expr data, Expr input_scale, Expr input_zero_point, int axis);
138+
Expr MakeDequantize(Expr data, Expr input_scale, Expr input_zero_point, int axis,
139+
DataType out_dtype = DataType::Float(32));
139140

140141
Expr QuantizeLower(const Expr& input_tensor, const Expr& output_scale,
141142
const Expr& output_zero_point, const Array<tvm::relay::Type>& types,

tests/python/relay/test_op_qnn_dequantize.py

Lines changed: 32 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,13 +23,19 @@
2323
from tvm.relay.testing import run_infer_type
2424

2525

26-
def dequantize_test_driver(in_dtype, quant_args, in_data, verify_output_data, axis):
26+
def dequantize_test_driver(
27+
in_dtype, quant_args, in_data, verify_output_data, axis, out_dtype="float32"
28+
):
2729
shape = in_data.shape
2830
input_data = relay.var("input_data", shape=shape, dtype=in_dtype)
2931
input_zero_point = relay.const(quant_args["in_zero_point"], "int32")
3032
input_scale = relay.const(quant_args["in_scale"], "float32")
3133
quantized_output = relay.qnn.op.dequantize(
32-
input_data, input_scale=input_scale, input_zero_point=input_zero_point, axis=axis
34+
input_data,
35+
input_scale=input_scale,
36+
input_zero_point=input_zero_point,
37+
axis=axis,
38+
out_dtype=out_dtype,
3339
)
3440
mod = relay.Function(relay.analysis.free_vars(quantized_output), quantized_output)
3541
mod = tvm.IRModule.from_expr(mod)
@@ -41,7 +47,7 @@ def dequantize_test_driver(in_dtype, quant_args, in_data, verify_output_data, ax
4147
rt_mod.run()
4248
res = rt_mod.get_output(0).numpy()
4349
np.testing.assert_equal(res, verify_output_data)
44-
assert res.dtype == np.float32
50+
assert res.dtype == out_dtype
4551

4652

4753
def test_uint8_to_float32():
@@ -74,6 +80,28 @@ def test_int8_to_float32():
7480
)
7581

7682

83+
def test_int8_to_float16():
84+
data = (
85+
np.array([-128, -127, -126, -125, -124, 123, 124, 125, 126, 127])
86+
.astype("int8")
87+
.reshape((2, 5))
88+
)
89+
output = (
90+
np.array([-63.5, -63, -62.5, -62, -61.5, 62, 62.5, 63, 63.5, 64])
91+
.astype("float16")
92+
.reshape((2, 5))
93+
)
94+
quant_args = {"in_zero_point": -1, "in_scale": 0.5}
95+
dequantize_test_driver(
96+
in_dtype="int8",
97+
quant_args=quant_args,
98+
in_data=data,
99+
verify_output_data=output,
100+
axis=-1,
101+
out_dtype="float16",
102+
)
103+
104+
77105
def test_scalar_int8_to_float32():
78106
data = np.array(-128).astype("int8")
79107
output = np.array(-63.5).astype("float32")
@@ -171,6 +199,7 @@ def test_dynamic_dequantize():
171199
if __name__ == "__main__":
172200
test_uint8_to_float32()
173201
test_int8_to_float32()
202+
test_int8_to_float16()
174203
test_scalar_int8_to_float32()
175204
test_int32_to_float32()
176205
test_channelwise_axis_1()

tests/python/relay/test_op_qnn_quantize.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,28 @@ def test_float32_to_int8():
8888
)
8989

9090

91+
def test_float32_to_uint16():
92+
data = (
93+
np.array([-6553, -6552.8, -6552.6, -6552.4, -6552.2, 6553.2, 6553.4, 6553.6, 6553.8, 6554])
94+
.astype("float32")
95+
.reshape((2, 5))
96+
)
97+
output = (
98+
np.array([0, 1, 2, 3, 4, 65531, 65532, 65533, 65534, 65535])
99+
.astype("uint16")
100+
.reshape((2, 5))
101+
)
102+
quant_args = {"out_zero_point": np.int32(32765), "out_scale": np.float32(0.2)}
103+
quantize_test_driver(
104+
in_dtype="float32",
105+
quant_args=quant_args,
106+
axis=-1,
107+
out_dtype="uint16",
108+
in_data=data,
109+
verify_output_data=output,
110+
)
111+
112+
91113
def test_scalar_float32_to_int8():
92114
data = np.array(-63.5).astype("float32")
93115
output = np.array(-128).astype("int8")
@@ -177,6 +199,7 @@ def test_dynamic_quantize():
177199
if __name__ == "__main__":
178200
test_float32_to_uint8()
179201
test_float32_to_int8()
202+
test_float32_to_uint16()
180203
test_scalar_float32_to_int8()
181204
test_channelwise_axis_0()
182205
test_channelwise_axis_1()

0 commit comments

Comments
 (0)