2 changes: 2 additions & 0 deletions include/tvm/relay/qnn/attrs.h
@@ -95,9 +95,11 @@ struct SimulatedQuantizeAttrs : public tvm::AttrsNode<SimulatedQuantizeAttrs> {

/*! \brief Attribute for dequantize operator */
struct DequantizeAttrs : public tvm::AttrsNode<DequantizeAttrs> {
DataType out_dtype;
int axis;

TVM_DECLARE_ATTRS(DequantizeAttrs, "relay.attrs.DequantizeAttrs") {
TVM_ATTR_FIELD(out_dtype).describe("Output data type, can be one of [float16, float32].");
TVM_ATTR_FIELD(axis)
.describe(
"The channel axis for channel wise dequantization. Default value is -1,"
21 changes: 12 additions & 9 deletions python/tvm/relay/qnn/op/qnn.py
@@ -186,8 +186,8 @@ def requantize(

def quantize(data, output_scale, output_zero_point, axis=-1, out_dtype="int8"):
r"""Quantize op
This operator takes float32 as input and produces quantized int8 or uint8 as output.
The input tensor can be of any shape. The output shape is the same as input shape.
This operator takes float32 input and produces quantized output. The input
tensor can be of any shape. The output shape is the same as input shape.

Q_output = clamp((round(input_tensor/output_scale) + output_zero_point),
out_dtype::min,
@@ -206,8 +206,9 @@ def quantize(data, output_scale, output_zero_point, axis=-1, out_dtype="int8"):

axis : int
The channel axis for quantization. Default value is -1 which corresponds to the last axis.

out_dtype : str, optional
The data type of the input tensor. Can be [int8, uint8, int32]
The data type of the output tensor. Can be [int8, uint8, int16, uint16, int32].

Returns
-------
@@ -256,16 +257,15 @@ def simulated_quantize(data, output_scale, output_zero_point, axis=-1, out_dtype
return _make.simulated_quantize(data, out_dtype, output_scale, output_zero_point, axis)


def dequantize(data, input_scale, input_zero_point, axis=-1):
def dequantize(data, input_scale, input_zero_point, axis=-1, out_dtype="float32"):
r"""Dequantize op
This operator takes quantized int8 and uint8 as input and produces
dequantized float32 as output. The output shape is the same as input shape. The input
tensor can be of any shape.
This operator takes quantized input and produces dequantized float output.
The output shape is the same as input shape. The input tensor can be of any shape.

Parameters
----------
data : tvm.relay.Expr
The input tensor to be dequantized. Can be of type [int8, uint8, int32].
The input tensor to be dequantized. Can be of type [int8, uint8, int16, uint16, int32].

input_scale : tvm.relay.Expr
The input scale.
@@ -276,13 +276,16 @@ def dequantize(data, input_scale, input_zero_point, axis=-1):
axis : int
The channel axis for dequantization. Default value is -1 which corresponds to the last axis.

out_dtype : str, optional
The data type of the output tensor. Can be [float16, float32].

Returns
-------
result : tvm.relay.Expr
The computed result.
"""

return _make.dequantize(data, input_scale, input_zero_point, axis)
return _make.dequantize(data, input_scale, input_zero_point, axis, out_dtype)


def simulated_dequantize(data, input_scale, input_zero_point, axis=-1, in_dtype="int8"):
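To make the new parameter concrete, here is a minimal usage sketch of the updated Python API (assembled from the test driver further below; assumes a standard TVM build):

import tvm
from tvm import relay

# Dequantize int8 data to float16 via the new out_dtype argument.
# out_dtype defaults to "float32", so existing callers are unaffected.
data = relay.var("data", shape=(2, 5), dtype="int8")
deq = relay.qnn.op.dequantize(
    data,
    input_scale=relay.const(0.5, "float32"),
    input_zero_point=relay.const(-1, "int32"),
    axis=-1,
    out_dtype="float16",
)
mod = tvm.IRModule.from_expr(relay.Function([data], deq))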
28 changes: 21 additions & 7 deletions src/relay/qnn/op/dequantize.cc
@@ -47,9 +47,10 @@ bool DequantizeRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,

const auto input_dtype = data->dtype;
ICHECK(input_dtype == DataType::Int(8) || input_dtype == DataType::UInt(8) ||
input_dtype == DataType::Int(16) || input_dtype == DataType::Int(32))
<< "Input type should be one of the quantized types [unit8, int8, int16, int32] but was "
<< input_dtype;
input_dtype == DataType::Int(16) || input_dtype == DataType::UInt(16) ||
input_dtype == DataType::Int(32))
<< "Input type should be one of the quantized types [int8, unit8, int16, uint16, int32] but "
<< "was " << input_dtype;

const auto* dequantize_attrs = attrs.as<DequantizeAttrs>();
int axis = dequantize_attrs->axis;
@@ -77,18 +78,24 @@ bool DequantizeRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
// Check and assign types for scale and zero points.
AssignType(types[1], DataType::Float(32), axis_shape, reporter); // scale
AssignType(types[2], DataType::Int(32), axis_shape, reporter); // zero point

const Array<tvm::PrimExpr> oshape = data->shape;
// assign output type, output will always be float32.
reporter->Assign(types[3], TensorType(oshape, DataType::Float(32)));
const DataType out_dtype = dequantize_attrs->out_dtype;
ICHECK(out_dtype == DataType::Float(16) || out_dtype == DataType::Float(32))
<< "Output type should be one of [float16, float32] but was " << out_dtype;
// assign output type.
reporter->Assign(types[3], TensorType(oshape, out_dtype));
return true;
}

Expr MakeDequantize(Expr data, Expr input_scale, Expr input_zero_point, int axis) {
Expr MakeDequantize(Expr data, Expr input_scale, Expr input_zero_point, int axis,
DataType out_dtype) {
// real_value = scale * (quantized_value - zero_point)
// A more detailed explanation can be found here -
// https://github.com/google/gemmlowp/blob/master/doc/quantization.md
auto attrs = make_object<DequantizeAttrs>();
attrs->axis = axis;
attrs->out_dtype = out_dtype;
static const Op& op = Op::Get("qnn.dequantize");
return Call(op, {data, input_scale, input_zero_point}, Attrs(attrs), {});
}
@@ -125,7 +132,14 @@ Expr DequantizeLower(const Expr& input_tensor, const Expr& input_scale,

auto shift = Subtract(Cast(input_tensor, DataType::Int(32)), expanded_input_zero_point);
auto scaled_output = Multiply(Cast(shift, DataType::Float(32)), expanded_input_scale);
return scaled_output;

const DataType out_dtype = attrs->out_dtype;
if (out_dtype.is_float() && out_dtype.bits() == 32) return scaled_output;

double min_val = tvm::min_value(out_dtype).as<FloatImmNode>()->value;
double max_val = tvm::max_value(out_dtype).as<FloatImmNode>()->value;
auto clamped_output = Clip(scaled_output, min_val, max_val);
return Cast(clamped_output, out_dtype);
}

Expr DequantizeQnnCanonicalize(const Attrs& attrs, const Array<Expr>& new_args,
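In scalar terms, the lowering above computes real_value = input_scale * (quantized_value - input_zero_point) in float32, and only when a narrower float output is requested does it clamp to that type's range before casting. A minimal numpy sketch of the same arithmetic, for intuition only (not the Relay lowering itself):

import numpy as np

def dequantize_ref(q, scale, zero_point, out_dtype="float32"):
    # Widen to int32 before subtracting the zero point, as DequantizeLower does.
    shifted = q.astype(np.int32) - np.int32(zero_point)
    scaled = shifted.astype(np.float32) * np.float32(scale)
    if out_dtype == "float32":
        return scaled
    # Mirror the new Clip + Cast: clamp to the target float range, then cast.
    info = np.finfo(out_dtype)
    return np.clip(scaled, info.min, info.max).astype(out_dtype)

print(dequantize_ref(np.int8([-128, 127]), 0.5, -1, "float16"))  # [-63.5  64. ]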
5 changes: 3 additions & 2 deletions src/relay/qnn/op/quantize.cc
@@ -91,8 +91,9 @@ bool QuantizeRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
const Array<tvm::PrimExpr> oshape = data->shape;
const DataType out_dtype = quantize_attrs->out_dtype;
ICHECK(out_dtype == DataType::Int(8) || out_dtype == DataType::UInt(8) ||
out_dtype == DataType::Int(16) || out_dtype == DataType::Int(32))
<< "Output type should be one of [int8, unit8, int16, int32] but was " << out_dtype;
out_dtype == DataType::Int(16) || out_dtype == DataType::UInt(16) ||
out_dtype == DataType::Int(32))
<< "Output type should be one of [int8, unit8, int16, uint16, int32] but was " << out_dtype;
// assign output type
reporter->Assign(types[3], TensorType(oshape, out_dtype));
return true;
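For the forward direction, the docstring formula Q_output = clamp(round(input_tensor / output_scale) + output_zero_point, dtype_min, dtype_max) can be checked with an analogous hand-written reference (not the canonicalized lowering; note that numpy rounds ties to even, though the values used here avoid exact halves):

import numpy as np

def quantize_ref(x, scale, zero_point, out_dtype="int8"):
    # Q = clamp(round(x / scale) + zero_point, dtype_min, dtype_max)
    info = np.iinfo(out_dtype)
    q = np.round(x / np.float32(scale)) + zero_point
    return np.clip(q, info.min, info.max).astype(out_dtype)

print(quantize_ref(np.float32([-63.5, 64.0]), 0.5, -1))  # [-128  127]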
3 changes: 2 additions & 1 deletion src/relay/qnn/utils.h
@@ -135,7 +135,8 @@ static inline Expr Dequantize(const Expr& data, const Expr& input_scale,

return DequantizeLower(data, input_scale, input_zero_point, types, attrs.operator->());
}
Expr MakeDequantize(Expr data, Expr input_scale, Expr input_zero_point, int axis);
Expr MakeDequantize(Expr data, Expr input_scale, Expr input_zero_point, int axis,
DataType out_dtype = DataType::Float(32));

Expr QuantizeLower(const Expr& input_tensor, const Expr& output_scale,
const Expr& output_zero_point, const Array<tvm::relay::Type>& types,
35 changes: 32 additions & 3 deletions tests/python/relay/test_op_qnn_dequantize.py
@@ -23,13 +23,19 @@
from tvm.relay.testing import run_infer_type


def dequantize_test_driver(in_dtype, quant_args, in_data, verify_output_data, axis):
def dequantize_test_driver(
in_dtype, quant_args, in_data, verify_output_data, axis, out_dtype="float32"
):
shape = in_data.shape
input_data = relay.var("input_data", shape=shape, dtype=in_dtype)
input_zero_point = relay.const(quant_args["in_zero_point"], "int32")
input_scale = relay.const(quant_args["in_scale"], "float32")
quantized_output = relay.qnn.op.dequantize(
input_data, input_scale=input_scale, input_zero_point=input_zero_point, axis=axis
input_data,
input_scale=input_scale,
input_zero_point=input_zero_point,
axis=axis,
out_dtype=out_dtype,
)
mod = relay.Function(relay.analysis.free_vars(quantized_output), quantized_output)
mod = tvm.IRModule.from_expr(mod)
@@ -41,7 +47,7 @@ def dequantize_test_driver(in_dtype, quant_args, in_data, verify_output_data, ax
rt_mod.run()
res = rt_mod.get_output(0).numpy()
np.testing.assert_equal(res, verify_output_data)
assert res.dtype == np.float32
assert res.dtype == out_dtype


def test_uint8_to_float32():
@@ -74,6 +80,28 @@ def test_int8_to_float32():
)


def test_int8_to_float16():
data = (
np.array([-128, -127, -126, -125, -124, 123, 124, 125, 126, 127])
.astype("int8")
.reshape((2, 5))
)
output = (
np.array([-63.5, -63, -62.5, -62, -61.5, 62, 62.5, 63, 63.5, 64])
.astype("float16")
.reshape((2, 5))
)
quant_args = {"in_zero_point": -1, "in_scale": 0.5}
dequantize_test_driver(
in_dtype="int8",
quant_args=quant_args,
in_data=data,
verify_output_data=output,
axis=-1,
out_dtype="float16",
)


def test_scalar_int8_to_float32():
data = np.array(-128).astype("int8")
output = np.array(-63.5).astype("float32")
@@ -171,6 +199,7 @@ def test_dynamic_dequantize():
if __name__ == "__main__":
test_uint8_to_float32()
test_int8_to_float32()
test_int8_to_float16()
test_scalar_int8_to_float32()
test_int32_to_float32()
test_channelwise_axis_1()
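A note on why np.testing.assert_equal can demand bit-exact results in the new float16 test: every expected value is a small multiple of 0.5, which half precision represents exactly, e.g.:

import numpy as np

vals = np.array([-63.5, -63, -62.5, -62, -61.5, 62, 62.5, 63, 63.5, 64])
# Round-tripping through float16 loses nothing for these values.
assert np.array_equal(vals.astype(np.float16).astype(np.float64), vals)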
23 changes: 23 additions & 0 deletions tests/python/relay/test_op_qnn_quantize.py
@@ -88,6 +88,28 @@ def test_float32_to_int8():
)


def test_float32_to_uint16():
data = (
np.array([-6553, -6552.8, -6552.6, -6552.4, -6552.2, 6553.2, 6553.4, 6553.6, 6553.8, 6554])
.astype("float32")
.reshape((2, 5))
)
output = (
np.array([0, 1, 2, 3, 4, 65531, 65532, 65533, 65534, 65535])
.astype("uint16")
.reshape((2, 5))
)
quant_args = {"out_zero_point": np.int32(32765), "out_scale": np.float32(0.2)}
quantize_test_driver(
in_dtype="float32",
quant_args=quant_args,
axis=-1,
out_dtype="uint16",
in_data=data,
verify_output_data=output,
)


def test_scalar_float32_to_int8():
data = np.array(-63.5).astype("float32")
output = np.array(-128).astype("int8")
@@ -177,6 +199,7 @@ def test_dynamic_quantize():
if __name__ == "__main__":
test_float32_to_uint8()
test_float32_to_int8()
test_float32_to_uint16()
test_scalar_float32_to_int8()
test_channelwise_axis_0()
test_channelwise_axis_1()
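The endpoints of the new uint16 test check out by hand: round(-6553 / 0.2) + 32765 = -32765 + 32765 = 0, and round(6554 / 0.2) + 32765 = 32770 + 32765 = 65535, exactly the uint16 extremes. Reusing the quantize_ref sketch from above:

print(quantize_ref(np.float32([-6553.0, 6554.0]), 0.2, 32765, "uint16"))  # [    0 65535]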