
Enhance gpu quantization #14094

Merged: 12 commits, Mar 6, 2019
4 changes: 4 additions & 0 deletions src/operator/quantization/quantize-inl.h
@@ -95,6 +95,10 @@ void QuantizeCompute(const nnvm::NodeAttrs& attrs,

  const QuantizeParam& param = nnvm::get<QuantizeParam>(attrs.parsed);
  if (param.out_type == mshadow::kUint8) {
    if (std::is_same<xpu, gpu>::value) {
      LOG(FATAL) << "currently, uint8 quantization is only supported by CPU, "
                    "please switch to the context of CPU or int8 data type for GPU.";
    }
    Kernel<quantize_unsigned, xpu>::Launch(s, outputs[0].Size(),
        outputs[0].dptr<uint8_t>(), outputs[1].dptr<float>(), outputs[2].dptr<float>(),
        inputs[0].dptr<float>(), inputs[1].dptr<float>(), inputs[2].dptr<float>(),
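For context: MXNet compute functions such as QuantizeCompute are templated on a device tag type xpu (mshadow::cpu or mshadow::gpu), so a check on std::is_same<xpu, gpu>::value makes only the GPU instantiation fail fast on the unsupported uint8 path, while the CPU instantiation proceeds to the kernel launch. A minimal standalone sketch of the pattern; the tag structs, function name, and stub body here are hypothetical stand-ins, not MXNet's actual types:

#include <cstdio>
#include <cstdlib>
#include <type_traits>

struct cpu {};  // hypothetical stand-in for mshadow::cpu
struct gpu {};  // hypothetical stand-in for mshadow::gpu

// One templated body serves both devices; the guard fires only when
// the function is instantiated with the gpu tag.
template <typename xpu>
void QuantizeUint8Sketch() {
  if (std::is_same<xpu, gpu>::value) {
    // Plays the role of LOG(FATAL) in the diff: bail out before any kernel launch.
    std::fprintf(stderr, "uint8 quantization is only supported by CPU\n");
    std::abort();
  }
  std::puts("launching uint8 quantization kernel on CPU");
}

int main() {
  QuantizeUint8Sketch<cpu>();  // prints the launch message
  QuantizeUint8Sketch<gpu>();  // aborts with the error message
  return 0;
}

A plain runtime if suffices here because both branches compile under either tag; std::is_same resolves at compile time, so the untaken branch is trivially dead-code-eliminated.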
8 changes: 8 additions & 0 deletions src/operator/quantization/quantize_v2-inl.h
@@ -139,6 +139,10 @@ void QuantizeV2Compute(const nnvm::NodeAttrs &attrs, const OpContext &ctx,
  auto out_type = GetOutputType(param);
  if (param.min_calib_range.has_value() && param.max_calib_range.has_value()) {
    if (out_type == mshadow::kUint8) {
      if (std::is_same<xpu, gpu>::value) {
        LOG(FATAL) << "currently, uint8 quantization is only supported by CPU, "
                      "please switch to the context of CPU or int8 data type for GPU.";
      }
      Kernel<quantize_v2_unsigned, xpu>::Launch(
          s, outputs[0].Size(), outputs[0].dptr<uint8_t>(), outputs[1].dptr<float>(),
          outputs[2].dptr<float>(), inputs[0].dptr<SrcDType>(), param.min_calib_range.value(),
@@ -170,6 +174,10 @@ void QuantizeV2Compute(const nnvm::NodeAttrs &attrs, const OpContext &ctx,
      broadcast::Reduce<red::maximum, 2, SrcDType, mshadow::op::identity>(
          s, in_max_t.reshape(dst_shape), kWriteTo, workspace, inputs[0].reshape(src_shape));
      if (out_type == mshadow::kUint8) {
        if (std::is_same<xpu, gpu>::value) {
          LOG(FATAL) << "currently, uint8 quantization is only supported by CPU, "
                        "please switch to the context of CPU or int8 data type for GPU.";
        }
        Kernel<quantize_v2_unsigned, xpu>::Launch(
            s, outputs[0].Size(), outputs[0].dptr<uint8_t>(), outputs[1].dptr<float>(),
            outputs[2].dptr<float>(), inputs[0].dptr<SrcDType>(), in_min_t.dptr<float>(),
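Both guarded call sites ultimately hand a (min, max) float range — taken from the calibration parameters in the first hunk, or reduced from the input on the fly in the second — to the quantize_v2_unsigned kernel. As a rough scalar illustration of what an unsigned quantization kernel computes per element, here is a sketch assuming the standard affine mapping onto [0, 255]; MXNet's actual rounding and range handling may differ:

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstdio>

// Affine uint8 quantization: map [min_range, max_range] onto [0, 255].
uint8_t QuantizeToUint8(float x, float min_range, float max_range) {
  const float scale = 255.0f / (max_range - min_range);
  const float q = std::round((x - min_range) * scale);
  return static_cast<uint8_t>(std::min(255.0f, std::max(0.0f, q)));
}

int main() {
  // With a calibrated range of [0, 2], 1.0f lands at 128 (127.5 rounded up).
  std::printf("%u\n", static_cast<unsigned>(QuantizeToUint8(1.0f, 0.0f, 2.0f)));
  return 0;
}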
10 changes: 10 additions & 0 deletions tests/python/quantization/test_quantization.py
@@ -450,6 +450,16 @@ def get_fp32_sym_with_multiple_outputs(length=1):
@with_seed()
def test_quantize_model():
    def check_quantize_model(qdtype):
        if is_test_for_native_cpu():
            print('skipped testing quantize_model for native cpu since it is not supported yet')
            return
        elif qdtype == 'int8' and is_test_for_mkldnn():
            print('skipped testing quantize_model for mkldnn cpu int8 since it is not supported yet')
            return
        elif qdtype == 'uint8' and is_test_for_gpu():
            print('skipped testing quantize_model for gpu uint8 since it is not supported yet')
            return
Review comment (Member): Please add an else clause.


    def check_params(params, qparams, qsym=None):
        if qsym is None:
            assert len(params) == len(qparams)