Skip to content

Commit

Permalink
quantize add inplace opt
Browse files Browse the repository at this point in the history
  • Loading branch information
anirudh2290 committed Feb 20, 2019
1 parent 2305e22 commit cc6bf9e
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 1 deletion.
14 changes: 13 additions & 1 deletion src/operator/quantization/quantize_v2-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ struct QuantizeV2Param : public dmlc::Parameter<QuantizeV2Param> {
int out_type;
dmlc::optional<float> min_calib_range;
dmlc::optional<float> max_calib_range;
int in_type;
DMLC_DECLARE_PARAMETER(QuantizeV2Param) {
DMLC_DECLARE_FIELD(out_type)
.add_enum("auto", kAuto)
Expand All @@ -58,6 +59,8 @@ struct QuantizeV2Param : public dmlc::Parameter<QuantizeV2Param> {
.set_default(dmlc::optional<float>())
.describe("The maximum scalar value in the form of float32. If present, it will be used to "
"quantize the fp32 data into int8 or uint8.");
DMLC_DECLARE_FIELD(in_type)
.describe("in type");
}
};

Expand Down Expand Up @@ -201,7 +204,16 @@ static inline bool QuantizeV2Type(const nnvm::NodeAttrs &attrs, std::vector<int>
CHECK_EQ(in_attrs->size(), 1U);
CHECK_EQ(out_attrs->size(), 3U);
const QuantizeV2Param &param = nnvm::get<QuantizeV2Param>(attrs.parsed);
TYPE_ASSIGN_CHECK(*in_attrs, 0, mshadow::kFloat32);
if ((*in_attrs)[0] == mshadow::kUint8) {
TYPE_ASSIGN_CHECK(*in_attrs, 0, mshadow::kUint8);
param.in_type = mshadow::kUint8;
} else if ((*in_attrs)[0] == mshadow::kInt8) {
TYPE_ASSIGN_CHECK(*in_attrs, 0, mshadow::kInt8);
param.in_type = mshadow::kInt8;
} else {
param.in_type = mshadow::kFloat32;
TYPE_ASSIGN_CHECK(*in_attrs, 0, mshadow::kFloat32);
}
auto out_type = GetOutputType(param);
if (out_type == mshadow::kUint8) {
TYPE_ASSIGN_CHECK(*out_attrs, 0, mshadow::kUint8);
Expand Down
13 changes: 13 additions & 0 deletions src/operator/quantization/quantize_v2.cc
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,18 @@ static bool QuantizeV2StorageType(const nnvm::NodeAttrs& attrs, const int dev_ma
return true;
}

std::vector<std::pair<int, int>> QuantizeInPlaceOption(
const NodeAttrs &attrs) {
auto const &param = nnvm::get<QuantizeV2Param>(attrs.parsed);
if ((param.out_type == mshadow::kUint8 && param.in_type == mshadow::kUint8)
|| (param.out_type == mshadow::kInt8 && param.in_type == mshadow::kInt8)) {
return std::vector<std::pair<int, int>>{{0, 0}};
} else {
return std::vector<std::pair<int, int>>();
}
}


NNVM_REGISTER_OP(_contrib_quantize_v2)
.describe(R"code(Quantize a input tensor from float to `out_type`,
with user-specified `min_calib_range` and `max_calib_range` or the input range collected at runtime.
Expand Down Expand Up @@ -87,6 +99,7 @@ If min_calib_range isn't presented, the output type will be int8.
.set_attr<bool>("TIsMKLDNN", true)
.set_attr<FComputeEx>("FComputeEx<cpu>", MKLDNNQuantizeV2Compute)
#endif
.set_attr<nnvm::FInplaceOption>("FInplaceOption", QuantizeInPlaceOption)
.set_attr<FCompute>("FCompute<cpu>", QuantizeV2Compute<cpu>)
.set_attr<FResourceRequest>("FResourceRequest", [](const NodeAttrs& attrs) {
const QuantizeV2Param &param = nnvm::get<QuantizeV2Param>(attrs.parsed);
Expand Down

0 comments on commit cc6bf9e

Please sign in to comment.