Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Enhance gpu quantization #14094

Merged
merged 12 commits into from
Mar 6, 2019
3 changes: 3 additions & 0 deletions python/mxnet/contrib/quantization.py
Original file line number Diff line number Diff line change
Expand Up @@ -493,6 +493,9 @@ def quantize_model(sym, arg_params, aux_params,
if quantized_dtype not in ('int8', 'uint8', 'auto'):
raise ValueError('unknown quantized_dtype %s received,'
' expected `int8`, `uint8` or `auto`' % quantized_dtype)
if quantized_dtype == 'uint8' and ctx != cpu():
raise ValueError('currently, uint8 quantization is only supported by CPU,'
' please switch to the context of CPU or int8 data type for GPU')
Copy link
Member

@anirudh2290 anirudh2290 Feb 9, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

better to add this error to backend like in the case for MKLDNN with int8 so that we dont have to add error handling to other frontends when we support quantization.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

currently, only python frontend support quantization and in fact calibration progress will not use backend specific quantized operator. So I think it's good to add error message in this place currently.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In QuantizeCompute (quantize-inl.h) you can check if std::is_same<xpu,gpu>::value and check for param.out_type and throw exception.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think this modification can work since infer type error mxnet.base.MXNetError: [02:07:55] /home/ubuntu/experimentals/1.4_release/src/operator/quantization/../tensor/matrix_op-inl.h:250: Check failed: src.type_flag_ == ret.type_flag_ (3 vs. 5) will occur before QuantizeCompute and we cannot get the ctx information during infer stage. So I think it's good to interrupt this action during the calibration stage.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

isnt that called from the forward pass of quantized_conv ? The quantize forward pass should execute before this.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

add check src_type in quantized_conv.cu, please take a review again.

qsym = _quantize_symbol(sym, excluded_symbols=excluded_sym_names,
offline_params=list(arg_params.keys()),
quantized_dtype=quantized_dtype)
Expand Down
3 changes: 3 additions & 0 deletions src/operator/quantization/quantized_conv.cu
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,9 @@ class QuantizedCuDNNConvOp {
if (param_.pad.ndim() == 0U) param_.pad = mshadow::Shape2(0, 0);
N = 0, H = 2, W = 3, C = 1;
src_type_ = mshadow::DataType<SrcType>::kCudnnFlag;
CHECK_EQ(src_type_, 5U)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What's the 5U here?

<< "currently, uint8 quantization is only supported by CPU, "
"please switch to the context of CPU or int8 data type for GPU.";
dst_type_ = mshadow::DataType<DstType>::kCudnnFlag;
cmp_type_ = mshadow::DataType<CmpType>::kCudnnFlag;
algo_ = CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM;
Expand Down
10 changes: 10 additions & 0 deletions tests/python/quantization/test_quantization.py
Original file line number Diff line number Diff line change
Expand Up @@ -450,6 +450,16 @@ def get_fp32_sym_with_multiple_outputs(length=1):
@with_seed()
def test_quantize_model():
def check_quantize_model(qdtype):
if is_test_for_native_cpu():
print('skipped testing quantize_model for native cpu since it is not supported yet')
return
elif qdtype == 'int8' and is_test_for_mkldnn():
print('skipped testing quantize_model for mkldnn cpu int8 since it is not supported yet')
return
elif qdtype == 'uint8' and is_test_for_gpu():
print('skipped testing quantize_model for gpu uint8 since it is not supported yet')
return
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please add else clause.


def check_params(params, qparams, qsym=None):
if qsym is None:
assert len(params) == len(qparams)
Expand Down