This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

add uint8 bn mkldnn implementation #16003

Merged
merged 7 commits on Aug 26, 2019
3 changes: 2 additions & 1 deletion example/quantization/imagenet_gen_qsym_mkldnn.py
@@ -216,7 +216,8 @@ def save_params(fname, arg_params, aux_params, logger=None):
if exclude_first_conv:
excluded_sym_names += ['resnetv10_conv0_fwd']
elif args.model.find('resnet') != -1 and args.model.find('v2') != -1:
excluded_sym_names += ['resnetv20_flatten0_flatten0']
# resnetv20_stage1_batchnorm0_fwd is excluded for the sake of accuracy
excluded_sym_names += ['resnetv20_flatten0_flatten0', 'resnetv20_stage1_batchnorm0_fwd']
Contributor

why exclude the first one?

Contributor Author

This is for the sake of accuracy: if this layer is not excluded, top-1 accuracy drops to 52.3. The reason for the drop is still under investigation.

if exclude_first_conv:
excluded_sym_names += ['resnetv20_conv0_fwd']
elif args.model.find('vgg') != -1:
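For context on the exchange above: symbols listed in excluded_sym_names stay in FP32 when the model is quantized, which is how the problematic batchnorm layer is skipped. Below is a minimal, hypothetical sketch of passing such a list to mx.contrib.quant.quantize_model; the toy network, layer names, and calib_mode='none' are illustrative and not part of this PR.

```python
import mxnet as mx

# Toy FP32 network; names are placeholders.
data = mx.sym.Variable('data')
conv = mx.sym.Convolution(data, num_filter=8, kernel=(3, 3), name='conv0')
bn = mx.sym.BatchNorm(conv, name='bn0')
sym = mx.sym.relu(bn, name='relu0')

# Dummy parameters so quantize_model has something to convert.
mod = mx.mod.Module(sym, label_names=None)
mod.bind(data_shapes=[('data', (1, 3, 32, 32))], for_training=False)
mod.init_params()
arg_params, aux_params = mod.get_params()

# 'bn0' is kept in FP32, just like the excluded resnet layers in the script above.
qsym, qarg_params, qaux_params = mx.contrib.quant.quantize_model(
    sym=sym, arg_params=arg_params, aux_params=aux_params, ctx=mx.cpu(),
    excluded_sym_names=['bn0'], quantized_dtype='int8', calib_mode='none')
```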
11 changes: 5 additions & 6 deletions src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h
@@ -132,14 +132,13 @@ class MKLDNNBNForward {
return *var_m;
}

void SetDataHandle(const NDArray &data, const mkldnn::memory *mean,
void SetDataHandle(const mkldnn::memory *data, const mkldnn::memory *mean,
const mkldnn::memory *var, const mkldnn::memory *out) {
auto _data = data.GetMKLDNNData();
if (data_m) {
data_m->set_data_handle(_data->get_data_handle());
data_m->set_data_handle(data->get_data_handle());
} else {
data_m.reset(new mkldnn::memory(_data->get_primitive_desc(),
_data->get_data_handle()));
data_m.reset(new mkldnn::memory(data->get_primitive_desc(),
data->get_data_handle()));
}
if (out_m) {
out_m->set_data_handle(out->get_data_handle());
@@ -175,7 +174,7 @@ class MKLDNNBNForward {

void SetDataHandle(const NDArray &data, const NDArray &mean,
const NDArray &var, const mkldnn::memory &out) {
SetDataHandle(data, mean.GetMKLDNNData(), var.GetMKLDNNData(), &out);
SetDataHandle(data.GetMKLDNNData(), mean.GetMKLDNNData(), var.GetMKLDNNData(), &out);
}

const mkldnn::batch_normalization_forward &GetFwd() const {
23 changes: 22 additions & 1 deletion src/operator/quantization/mkldnn/mkldnn_quantized_batch_norm.cc
@@ -40,6 +40,27 @@ static void MKLDNNQuantizedBatchNormForward(const nnvm::NodeAttrs &attrs, const
TmpMemMgr::Get()->Init(ctx.requested[batchnorm::kTempSpace]);
const BatchNormParam &param = nnvm::get<BatchNormParam>(attrs.parsed);
const NDArray &data = in_data[quantized_batchnorm::kData];
auto data_mem = data.GetMKLDNNData();

// If the input is uint8, reorder (rescale) it to int8 so the existing int8 BN primitive can be reused
if (in_data[quantized_batchnorm::kData].dtype() == mshadow::kUint8) {
auto u8_pd = data_mem->get_primitive_desc();
auto u8_md = u8_pd.desc();
mkldnn::memory::desc s8_md(
mkldnn::memory::dims(u8_md.data.dims, u8_md.data.dims + u8_md.data.ndims),
mkldnn::memory::data_type::s8, static_cast<mkldnn::memory::format>(u8_md.data.format));
auto s8_pd = mkldnn::memory::primitive_desc(s8_md, CpuEngine::Get()->get_engine());
auto data_reorder_mem = TmpMemMgr::Get()->Alloc(s8_pd);

std::vector<float> reorder_scale;
reorder_scale = {static_cast<float>(kInt8Range) / kUint8Range};
primitive_attr reorder_attr;
reorder_attr.set_int_output_round_mode(round_mode::round_nearest);
reorder_attr.set_output_scales(0, reorder_scale);
const auto reorder_pd = mkldnn::reorder::primitive_desc(u8_pd, s8_pd, reorder_attr);
MKLDNNStream::Get()->RegisterPrim(mkldnn::reorder(reorder_pd, *data_mem, *data_reorder_mem));
data_mem = data_reorder_mem;
}
const size_t channelAxis = static_cast<size_t>(
param.axis < 0 ? static_cast<int>(data.shape().ndim()) + param.axis : param.axis);
const int channel_count = data.shape()[channelAxis];
@@ -92,7 +113,7 @@ static void MKLDNNQuantizedBatchNormForward(const nnvm::NodeAttrs &attrs, const

auto out_mem = CreateMKLDNNMem(outputs[batchnorm::kOut],
fwd.GetPd().dst_primitive_desc(), req[batchnorm::kOut], &data);
fwd.SetDataHandle(data, rescaled_mean_mem, rescaled_var_mem, out_mem.second);
fwd.SetDataHandle(data_mem, rescaled_mean_mem, rescaled_var_mem, out_mem.second);

MKLDNNStream::Get()->RegisterPrim(fwd.GetFwd());
MKLDNNStream::Get()->Submit();
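For reference, the block added above rescales a uint8 input into the int8 domain so the existing int8 batch-norm primitive can be reused: every element is multiplied by kInt8Range / kUint8Range and rounded to nearest. A minimal NumPy sketch of that mapping, assuming the two ranges are 127 and 255 (the actual constants live in the quantization headers) and with an illustrative helper name:

```python
import numpy as np

INT8_RANGE = 127.0    # assumed value of kInt8Range
UINT8_RANGE = 255.0   # assumed value of kUint8Range

def reorder_u8_to_s8(u8_data):
    """Illustrative equivalent of the u8 -> s8 reorder: scale by
    INT8_RANGE / UINT8_RANGE and round to nearest, so [0, 255]
    maps onto [0, 127] and the result stays non-negative."""
    scaled = u8_data.astype(np.float32) * (INT8_RANGE / UINT8_RANGE)
    return np.rint(scaled).astype(np.int8)

x_u8 = np.array([0, 1, 128, 255], dtype=np.uint8)
print(reorder_u8_to_s8(x_u8))  # -> [  0   0  64 127]
```

This keeps the change local: only the input memory is rewritten, while the batch-norm primitive itself still consumes int8 data as before.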
6 changes: 6 additions & 0 deletions src/operator/quantization/quantized_batch_norm.cc
@@ -67,7 +67,13 @@ bool QuantizedBatchNormType(const nnvm::NodeAttrs& attrs, std::vector<int>* in_t
CHECK_EQ(in_type->size(), 7U);
CHECK_EQ(out_type->size(), 3U);

#if MXNET_USE_MKLDNN == 1
CHECK(in_type->at(0) == mshadow::kInt8 || in_type->at(0) == mshadow::kUint8)
<< "QuantizedBatchNorm with MKLDNN backend only supports int8/uint8 input, while "
<< in_type->at(0) << " is given.";
#else
TYPE_ASSIGN_CHECK(*in_type, 0, mshadow::kInt8);
#endif
for (size_t i = 1; i < 7; ++i) {
TYPE_ASSIGN_CHECK(*in_type, i, mshadow::kFloat32);
}
49 changes: 23 additions & 26 deletions tests/python/quantization/test_quantization.py
@@ -607,19 +607,21 @@ def get_mean_var(data):
return mean, var

def check_quantized_bn(data_shape, qdtype):
if qdtype == 'uint8':
print('skipped testing quantize_bn for uint8 since it is not supported yet')
return
elif is_test_for_native_cpu():
if is_test_for_native_cpu():
print('skipped testing quantize_bn for native cpu since it is not supported yet')
return
elif is_test_for_gpu():
print('skipped testing quantize_bn for gpu since it is not supported yet')
return

# qdtype = int8
data_low = -127.0
data_high = 127.0
# qdtype = uint8
if qdtype == 'uint8':
data_low = 0.0
data_high = 127.0
else:
data_low = -127.0
data_high = 127.0
# output type = int8
quantized_range = 127.0
# run fp32 bn
data_sym = mx.sym.Variable(name='data', shape=data_shape, dtype='float32')
@@ -639,9 +641,6 @@ def check_quantized_bn(data_shape, qdtype):
bn_fp32_exe.arg_dict[arg_names[2]][:] = beta
bn_fp32_exe.aux_dict[aux_names[0]][:] = moving_mean
bn_fp32_exe.aux_dict[aux_names[1]][:] = moving_var
min_data = mx.nd.min(data)
max_data = mx.nd.max(data)
data_range = mx.nd.maximum(mx.nd.abs(min_data), mx.nd.abs(max_data))

output= bn_fp32_exe.forward()[0]

@@ -654,11 +653,12 @@ def check_quantized_bn(data_shape, qdtype):

calib_data = NDArrayIter(data=data, batch_size=data_shape[0])
calib_data = DummyIter(calib_data)
# quantize bn with quantized_dtype = int8: MKLDNN BN only supports int8 output
qsym, qarg_params, qaux_params = mx.contrib.quant.quantize_model(sym=bn_fp32,
arg_params=arg_params,
aux_params=bn_fp32_exe.aux_dict,
ctx=mx.current_context(),
quantized_dtype=qdtype,
quantized_dtype='int8',
calib_mode='naive',
calib_data=calib_data,
num_calib_examples=20)
@@ -668,13 +668,14 @@ def check_quantized_bn(data_shape, qdtype):
mod.set_params(qarg_params, qaux_params)
batch = mx.io.DataBatch([data], [])
mod.forward(batch, is_train=False)
output_int8_to_fp32= mod.get_outputs()[0]
output_int8_to_fp32 = mod.get_outputs()[0]

assert_almost_equal(output.asnumpy(), output_int8_to_fp32.asnumpy(), rtol=1e-1, atol=3)
assert_almost_equal(output.asnumpy(), output_int8_to_fp32.asnumpy(), rtol=1e-1, atol=4)

check_quantized_bn((32, 512, 4, 4), 'int8')
check_quantized_bn((32, 1024, 8, 8), 'int8')
check_quantized_bn((32, 3, 224, 224), 'int8')
for qdtype in ['int8', 'uint8']:
check_quantized_bn((32, 512, 4, 4), qdtype)
check_quantized_bn((32, 1024, 8, 8), qdtype)
check_quantized_bn((32, 3, 224, 224), qdtype)

@with_seed()
def test_quantize_params():
@@ -918,16 +919,12 @@ def check_qsym_forward(qsym, qarg_params, qaux_params, data_shape, label_shape=N
lshape_list.append(None)

for s, dshape, lshape, name in zip(sym_list, dshape_list, lshape_list, name_list):
if qdtype == 'int8' and is_test_for_mkldnn() and name in ['sym1', 'sym2', 'sym3']:
print('skipped testing test_quantize_model_with_forward for mkldnn cpu int8 since it is not supported yet')
continue
elif qdtype == 'uint8' and is_test_for_mkldnn() and name in ['sym1']:
print('skipping test_quantize_model_with_forward for mkldnn cpu uint8 since it is not supported yet')
continue
elif qdtype == 'int8' and is_test_for_gpu() and name in ['sym1']:
print('skipped testing test_quantize_model_with_forward for gpu int8 since it is not supported yet')
continue

if qdtype == 'int8' and name in ['sym1','sym2','sym3']:
print('mkldnn_quantized_conv op only supports uint8 as input type, skip test with int8.')
continue
if qdtype == 'uint8' and name in ['sym1']:
print('mkldnn_quantized_bn doesn\'t support calib_mode=None')
continue
if lshape is None:
mod = Module(symbol=s, label_names=None)
mod.bind(for_training=False,