@@ -47,9 +47,10 @@ bool DequantizeRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
4747
4848 const auto input_dtype = data->dtype ;
4949 ICHECK (input_dtype == DataType::Int (8 ) || input_dtype == DataType::UInt (8 ) ||
50- input_dtype == DataType::Int (16 ) || input_dtype == DataType::Int (32 ))
51- << " Input type should be one of the quantized types [unit8, int8, int16, int32] but was "
52- << input_dtype;
50+ input_dtype == DataType::Int (16 ) || input_dtype == DataType::UInt (16 ) ||
51+ input_dtype == DataType::Int (32 ))
52+ << " Input type should be one of the quantized types [int8, unit8, int16, uint16, int32] but "
53+ << " was " << input_dtype;
5354
5455 const auto * dequantize_attrs = attrs.as <DequantizeAttrs>();
5556 int axis = dequantize_attrs->axis ;
@@ -77,18 +78,24 @@ bool DequantizeRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
7778 // Check and assign types for scale and zero points.
7879 AssignType (types[1 ], DataType::Float (32 ), axis_shape, reporter); // scale
7980 AssignType (types[2 ], DataType::Int (32 ), axis_shape, reporter); // zero point
81+
8082 const Array<tvm::PrimExpr> oshape = data->shape ;
81- // assign output type, output will always be float 32.
82- reporter->Assign (types[3 ], TensorType (oshape, DataType::Float (32 )));
83+ const DataType out_dtype = dequantize_attrs->out_dtype ;
84+ ICHECK (out_dtype == DataType::Float (16 ) || out_dtype == DataType::Float (32 ))
85+ << " Output type should be one of [float16, float32] but was " << out_dtype;
86+ // assign output type.
87+ reporter->Assign (types[3 ], TensorType (oshape, out_dtype));
8388 return true ;
8489}
8590
// MakeDequantize: builds the call expression for the dequantize operator.
// It creates a DequantizeAttrs object, stores `axis` and `out_dtype` on it, and
// returns a Call of the looked-up op applied to {data, input_scale, input_zero_point}.
// Diff note: the removed ("-") signature took four arguments; the added ("+")
// signature appends `DataType out_dtype` with no default value, so callers must pass it.
86- Expr MakeDequantize (Expr data, Expr input_scale, Expr input_zero_point, int axis) {
91+ Expr MakeDequantize (Expr data, Expr input_scale, Expr input_zero_point, int axis,
92+ DataType out_dtype) {
8793 // real_value = scale * (quantized_value - zero_point)
8894 // A more detailed explanation can be found here -
8995 // https://github.com/google/gemmlowp/blob/master/doc/quantization.md
9096 auto attrs = make_object<DequantizeAttrs>();
9197 attrs->axis = axis;
// The requested output dtype travels on the attrs; DequantizeRel in this same diff
// reads attrs out_dtype, ICHECKs that it is Float(16) or Float(32), and assigns it
// as the dtype of the output tensor type.
98+ attrs->out_dtype = out_dtype;
// Function-local static reference: the Op::Get lookup is performed on first call and reused.
9299 static const Op& op = Op::Get (" qnn.dequantize" );
// NOTE(review): the trailing `{}` argument to Call is left empty here; its meaning is
// not visible in this chunk -- confirm against the Call constructor.
93100 return Call (op, {data, input_scale, input_zero_point}, Attrs (attrs), {});
94101}
@@ -125,7 +132,14 @@ Expr DequantizeLower(const Expr& input_tensor, const Expr& input_scale,
125132
126133 auto shift = Subtract (Cast (input_tensor, DataType::Int (32 )), expanded_input_zero_point);
127134 auto scaled_output = Multiply (Cast (shift, DataType::Float (32 )), expanded_input_scale);
128- return scaled_output;
135+
136+ const DataType out_dtype = attrs->out_dtype ;
137+ if (out_dtype.is_float () && out_dtype.bits () == 32 ) return scaled_output;
138+
139+ double min_val = tvm::min_value (out_dtype).as <FloatImmNode>()->value ;
140+ double max_val = tvm::max_value (out_dtype).as <FloatImmNode>()->value ;
141+ auto clamped_output = Clip (scaled_output, min_val, max_val);
142+ return Cast (clamped_output, out_dtype);
129143}
130144
131145Expr DequantizeQnnCanonicalize (const Attrs& attrs, const Array<Expr>& new_args,
0 commit comments