diff --git a/src/operator/quantization/quantize_v2-inl.h b/src/operator/quantization/quantize_v2-inl.h
index 5ae10a7e4fa8..7f1c9d3f9e0e 100644
--- a/src/operator/quantization/quantize_v2-inl.h
+++ b/src/operator/quantization/quantize_v2-inl.h
@@ -42,6 +42,7 @@ struct QuantizeV2Param : public dmlc::Parameter<QuantizeV2Param> {
   int out_type;
   dmlc::optional<float> min_calib_range;
   dmlc::optional<float> max_calib_range;
+  int in_type;
   DMLC_DECLARE_PARAMETER(QuantizeV2Param) {
     DMLC_DECLARE_FIELD(out_type)
     .add_enum("auto", kAuto)
@@ -58,6 +59,9 @@ struct QuantizeV2Param : public dmlc::Parameter<QuantizeV2Param> {
     .set_default(dmlc::optional<float>())
     .describe("The maximum scalar value in the form of float32. If present, it will be used to "
               "quantize the fp32 data into int8 or uint8.");
+    DMLC_DECLARE_FIELD(in_type)
+    .set_default(mshadow::kFloat32)
+    .describe("Input data type. Recorded during type inference; not a user-facing option.");
   }
 };
 
@@ -201,7 +205,20 @@ static inline bool QuantizeV2Type(const nnvm::NodeAttrs &attrs, std::vector<int>
   CHECK_EQ(in_attrs->size(), 1U);
   CHECK_EQ(out_attrs->size(), 3U);
   const QuantizeV2Param &param = nnvm::get<QuantizeV2Param>(attrs.parsed);
-  TYPE_ASSIGN_CHECK(*in_attrs, 0, mshadow::kFloat32);
+  // Record the inferred input dtype on the parameter so that FInplaceOption
+  // can later detect the identity (u8->u8 / s8->s8) case.  nnvm::get returns
+  // a const reference, so the write needs an explicit const_cast.
+  QuantizeV2Param &mutable_param = const_cast<QuantizeV2Param &>(param);
+  if ((*in_attrs)[0] == mshadow::kUint8) {
+    TYPE_ASSIGN_CHECK(*in_attrs, 0, mshadow::kUint8);
+    mutable_param.in_type = mshadow::kUint8;
+  } else if ((*in_attrs)[0] == mshadow::kInt8) {
+    TYPE_ASSIGN_CHECK(*in_attrs, 0, mshadow::kInt8);
+    mutable_param.in_type = mshadow::kInt8;
+  } else {
+    mutable_param.in_type = mshadow::kFloat32;
+    TYPE_ASSIGN_CHECK(*in_attrs, 0, mshadow::kFloat32);
+  }
   auto out_type = GetOutputType(param);
   if (out_type == mshadow::kUint8) {
     TYPE_ASSIGN_CHECK(*out_attrs, 0, mshadow::kUint8);
diff --git a/src/operator/quantization/quantize_v2.cc b/src/operator/quantization/quantize_v2.cc
index 21410933d35e..f1e8182b67c7 100644
--- a/src/operator/quantization/quantize_v2.cc
+++ b/src/operator/quantization/quantize_v2.cc
@@ -47,6 +47,21 @@ static bool QuantizeV2StorageType(const nnvm::NodeAttrs& attrs, const int dev_ma
   return true;
 }
 
+// When the input dtype recorded at type-inference time already matches the
+// requested output dtype (u8->u8 or s8->s8), quantization is the identity and
+// the output may safely share the input buffer.
+std::vector<std::pair<int, int>> QuantizeInPlaceOption(const NodeAttrs &attrs) {
+  const QuantizeV2Param &param = nnvm::get<QuantizeV2Param>(attrs.parsed);
+  const bool identity =
+      (param.out_type == mshadow::kUint8 && param.in_type == mshadow::kUint8) ||
+      (param.out_type == mshadow::kInt8 && param.in_type == mshadow::kInt8);
+  if (identity) {
+    return std::vector<std::pair<int, int>>{{0, 0}};
+  }
+  return std::vector<std::pair<int, int>>();
+}
+
 NNVM_REGISTER_OP(_contrib_quantize_v2)
 .describe(R"code(Quantize a input tensor from float to `out_type`,
 with user-specified `min_calib_range` and `max_calib_range` or the input range collected at runtime.
@@ -87,6 +102,7 @@ If min_calib_range isn't presented, the output type will be int8.
 .set_attr<bool>("TIsMKLDNN", true)
 .set_attr<FComputeEx>("FComputeEx", MKLDNNQuantizeV2Compute)
 #endif
+.set_attr<nnvm::FInplaceOption>("FInplaceOption", QuantizeInPlaceOption)
 .set_attr<FCompute>("FCompute", QuantizeV2Compute)
 .set_attr<FResourceRequest>("FResourceRequest", [](const NodeAttrs& attrs) {
     const QuantizeV2Param &param = nnvm::get<QuantizeV2Param>(attrs.parsed);