
[mkldnn-v1.0] add quantized bn #16458

Merged
merged 1 commit on Oct 12, 2019
50 changes: 28 additions & 22 deletions src/operator/quantization/mkldnn/mkldnn_quantized_batch_norm.cc
@@ -23,7 +23,7 @@
* \author Yixin Bao
*/

#if MXNET_USE_MKLDNN == 1
#if MXNET_USE_MKLDNN == 100
#include "../../nn/mkldnn/mkldnn_batch_norm-inl.h"
#include "../quantization_utils.h"

@@ -44,21 +44,22 @@ static void MKLDNNQuantizedBatchNormForward(const nnvm::NodeAttrs &attrs, const

// reorder if data type = uint8
if (in_data[quantized_batchnorm::kData].dtype() == mshadow::kUint8) {
auto u8_pd = data_mem->get_primitive_desc();
auto u8_md = u8_pd.desc();
mkldnn::memory::desc s8_md(
mkldnn::memory::dims(u8_md.data.dims, u8_md.data.dims + u8_md.data.ndims),
mkldnn::memory::data_type::s8, static_cast<mkldnn::memory::format>(u8_md.data.format));
auto s8_pd = mkldnn::memory::primitive_desc(s8_md, CpuEngine::Get()->get_engine());
auto data_reorder_mem = TmpMemMgr::Get()->Alloc(s8_pd);
auto u8_md = data_mem->get_desc();
auto s8_md = u8_md;
s8_md.data.data_type = static_cast<mkldnn_data_type_t>(mkldnn::memory::data_type::s8);
auto data_reorder_mem = TmpMemMgr::Get()->Alloc(s8_md);

std::vector<float> reorder_scale;
reorder_scale = {static_cast<float>(kInt8Range) / kUint8Range};
primitive_attr reorder_attr;
reorder_attr.set_int_output_round_mode(round_mode::round_nearest);
mkldnn::primitive_attr reorder_attr;
reorder_attr.set_output_scales(0, reorder_scale);
const auto reorder_pd = mkldnn::reorder::primitive_desc(u8_pd, s8_pd, reorder_attr);
MKLDNNStream::Get()->RegisterPrim(mkldnn::reorder(reorder_pd, *data_mem, *data_reorder_mem));
mkldnn::engine cpu_engine = CpuEngine::Get()->get_engine();
const auto reorder_pd =
mkldnn::reorder::primitive_desc(cpu_engine, u8_md, cpu_engine, s8_md, reorder_attr);
mkldnn_args_map_t reorder_args;
reorder_args[MKLDNN_ARG_SRC] = *data_mem;
reorder_args[MKLDNN_ARG_DST] = *data_reorder_mem;
MKLDNNStream::Get()->RegisterPrimArgs(mkldnn::reorder(reorder_pd), reorder_args);
data_mem = data_reorder_mem;
}
const size_t channelAxis = static_cast<size_t>(
Expand All @@ -79,10 +80,11 @@ static void MKLDNNQuantizedBatchNormForward(const nnvm::NodeAttrs &attrs, const
}
const float max_abs_output = std::max(std::abs(*min_output_ptr), std::abs(*max_output_ptr));

unsigned flags = mkldnn::use_global_stats | mkldnn::use_scale_shift;
mkldnn::normalization_flags flags =
mkldnn::normalization_flags::use_global_stats | mkldnn::normalization_flags::use_scale_shift;
auto &fwd = GetBNForward<float>(param, ctx, data_mem, flags);
const mkldnn::memory &weight_mem = fwd.GetWeight();
CHECK_EQ(weight_mem.get_primitive_desc().get_size(), channel_count * sizeof(float) * 2);
CHECK_EQ(weight_mem.get_desc().get_size(), channel_count * sizeof(float) * 2);
float *weight_buf = reinterpret_cast<float *>(weight_mem.get_data_handle());

float *gamma_ptr = in_data[quantized_batchnorm::kGamma].data().dptr<float>();
Expand All @@ -94,9 +96,8 @@ static void MKLDNNQuantizedBatchNormForward(const nnvm::NodeAttrs &attrs, const
float *moving_var_ptr = moving_var.data().dptr<float>();

// rescale gamma and beta, to make mean=0 and var=1
auto rescaled_mean_mem =
TmpMemMgr::Get()->Alloc(moving_mean.GetMKLDNNData()->get_primitive_desc());
auto rescaled_var_mem = TmpMemMgr::Get()->Alloc(moving_var.GetMKLDNNData()->get_primitive_desc());
auto rescaled_mean_mem = TmpMemMgr::Get()->Alloc(moving_mean.GetMKLDNNData()->get_desc());
auto rescaled_var_mem = TmpMemMgr::Get()->Alloc(moving_var.GetMKLDNNData()->get_desc());
float *rescaled_mean_ptr = reinterpret_cast<float *>(rescaled_mean_mem->get_data_handle());
float *rescaled_var_ptr = reinterpret_cast<float *>(rescaled_var_mem->get_data_handle());

Expand All @@ -111,11 +112,16 @@ static void MKLDNNQuantizedBatchNormForward(const nnvm::NodeAttrs &attrs, const
rescaled_var_ptr[channel] = 1.0f;
}

auto out_mem = CreateMKLDNNMem(outputs[batchnorm::kOut],
fwd.GetPd().dst_primitive_desc(), req[batchnorm::kOut], &data);
fwd.SetDataHandle(data_mem, rescaled_mean_mem, rescaled_var_mem, out_mem.second);
const NDArray &out = outputs[batchnorm::kOut];
auto out_mem = const_cast<NDArray &>(out).CreateMKLDNNData(fwd.GetPd().dst_desc());
mkldnn_args_map_t net_args;
net_args[MKLDNN_ARG_SRC] = *data_mem;
net_args[MKLDNN_ARG_SCALE_SHIFT] = weight_mem;
net_args[MKLDNN_ARG_DST] = *out_mem;
net_args[MKLDNN_ARG_MEAN] = *rescaled_mean_mem;
net_args[MKLDNN_ARG_VARIANCE] = *rescaled_var_mem;

MKLDNNStream::Get()->RegisterPrim(fwd.GetFwd());
MKLDNNStream::Get()->RegisterPrimArgs(fwd.GetFwd(), net_args);
MKLDNNStream::Get()->Submit();
}

Expand All @@ -141,4 +147,4 @@ NNVM_REGISTER_OP(_contrib_quantized_batch_norm)
} // namespace op
} // namespace mxnet

#endif // MXNET_USE_MKLDNN == 1
#endif // MXNET_USE_MKLDNN == 100
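
The change above ports the kernel from the MKL-DNN 0.x primitive_desc API to the 1.0 memory-descriptor API, but the numerics are unchanged: uint8 input is reordered onto the int8 grid with a 127/255 scale, and the running statistics are folded into the scale/shift weights so the primitive runs with mean = 0 and variance = 1. The NumPy sketch below illustrates that arithmetic only; it is not part of the PR, the epsilon value and the folding formulas are assumptions based on standard batch-norm algebra, and the real kernel additionally folds the output requantization scale (max_abs_output) into the same weights.

    import numpy as np

    kUint8Range, kInt8Range = 255.0, 127.0   # names follow the C++ constants

    def reorder_u8_to_s8(x_u8):
        # uint8 activations are rescaled by 127/255 so they fit the signed int8
        # range expected by the s8 batch-norm primitive (mirrors the reorder above).
        return np.rint(x_u8.astype(np.float32) * (kInt8Range / kUint8Range)).astype(np.int8)

    def fold_bn_stats(gamma, beta, moving_mean, moving_var, eps=1e-5):
        # Fold the running statistics into gamma/beta so the primitive can be fed
        # rescaled_mean = 0 and rescaled_var = 1, as the kernel does.
        inv_std = 1.0 / np.sqrt(moving_var + eps)
        gamma_folded = gamma * inv_std
        beta_folded = beta - gamma * moving_mean * inv_std
        return gamma_folded, beta_folded
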
4 changes: 2 additions & 2 deletions src/operator/quantization/quantized_batch_norm.cc
@@ -25,7 +25,7 @@
*/
#include <mxnet/op_attr_types.h>
#include "../nn/batch_norm-inl.h"
#if MXNET_USE_MKLDNN == 1
#if MXNET_USE_MKLDNN == 100
#include "../nn/mkldnn/mkldnn_batch_norm-inl.h"
#endif

Expand Down Expand Up @@ -67,7 +67,7 @@ bool QuantizedBatchNormType(const nnvm::NodeAttrs& attrs, std::vector<int>* in_t
CHECK_EQ(in_type->size(), 7U);
CHECK_EQ(out_type->size(), 3U);

#if MXNET_USE_MKLDNN == 1
#if MXNET_USE_MKLDNN == 100
CHECK(in_type->at(0) == mshadow::kInt8 || in_type->at(0) == mshadow::kUint8)
<< "QuantizedBatchNorm with MKLDNN backend only supports int8/uint8 input, while "
<< in_type->at(0) << " is given.";
11 changes: 5 additions & 6 deletions tests/python/quantization/test_quantization.py
@@ -617,12 +617,11 @@ def check_quantized_bn(data_shape, qdtype):
# qdtype = uint8
if qdtype == 'uint8':
data_low = 0.0
data_high = 127.0
data_high = 255.0
else:
data_low = -127.0
data_high = 127.0
# output type = int8
quantized_range = 127.0

# run fp32 bn
data_sym = mx.sym.Variable(name='data', shape=data_shape, dtype='float32')
bn_fp32 = mx.sym.BatchNorm(data=data_sym, name='bn', use_global_stats=True, fix_gamma=False)
Expand Down Expand Up @@ -653,12 +652,12 @@ def check_quantized_bn(data_shape, qdtype):

calib_data = NDArrayIter(data=data, batch_size=data_shape[0])
calib_data = DummyIter(calib_data)
# quantize bn with quantized_type = int8: MKLDNN BN only support int8 output
qsym, qarg_params, qaux_params = mx.contrib.quant.quantize_model(sym=bn_fp32,
arg_params=arg_params,
aux_params=bn_fp32_exe.aux_dict,
ctx=mx.current_context(),
quantized_dtype='int8',
quantized_dtype=qdtype,
quantize_mode='full',
calib_mode='naive',
calib_data=calib_data,
num_calib_examples=20)
Expand All @@ -670,7 +669,7 @@ def check_quantized_bn(data_shape, qdtype):
mod.forward(batch, is_train=False)
output_int8_to_fp32 = mod.get_outputs()[0]

assert_almost_equal(output.asnumpy(), output_int8_to_fp32.asnumpy(), rtol=1e-1, atol=4)
assert_almost_equal(output.asnumpy(), output_int8_to_fp32.asnumpy(), rtol=1e-1, atol=8)

for qdtype in ['int8', 'uint8']:
check_quantized_bn((32, 512, 4, 4), qdtype)
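
The looser atol=8 reflects int8 output granularity: the quantized output lives on a symmetric grid with step max_abs_output / 127, so even a perfect kernel can differ from the fp32 reference by up to half a step per value, before any input-quantization or calibration error is added. A rough NumPy sketch of that round-trip error (not from the PR; the data range is arbitrary):

    import numpy as np

    def int8_round_trip(x_fp32):
        # Symmetric int8 quantize/dequantize, as the quantized graph output is
        # converted back to fp32 before comparison against the fp32 reference.
        max_abs = np.abs(x_fp32).max()
        scale = 127.0 / max_abs
        q = np.clip(np.rint(x_fp32 * scale), -127, 127).astype(np.int8)
        return q.astype(np.float32) / scale

    x = np.random.uniform(-400.0, 400.0, size=(32, 512, 4, 4)).astype(np.float32)
    err = np.abs(int8_round_trip(x) - x).max()
    assert err <= (np.abs(x).max() / 127.0) / 2 + 1e-3   # at most half a quantization step
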