[mkldnn-1.0] mkldnn int8 elemwise_add (#16454)
* add mkldnn int8 elemwise_add

* add a workaround for the mkldnn memory format "any" issue

* code cleanup
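For context, mkldnn 1.0 replaces memory::primitive_desc with memory::desc, drops primitive_attr::set_int_output_round_mode (round-to-nearest is the only remaining behavior), and moves input/output binding from primitive construction to an argument map passed at execution time. The minimal standalone sketch below (assumed mkldnn 1.0 API, not MXNet code) mirrors the reorder-with-output-scales pattern the patch adopts for mixed u8/s8 inputs; the engine/stream setup, tensor shape, scale value, and buffer contents are illustrative assumptions.

// Standalone sketch (assumes mkldnn 1.0): rescale a uint8 buffer into int8
// with a reorder primitive, mirroring the pattern used in this commit.
#include <mkldnn.hpp>
#include <cstdint>
#include <vector>

int main() {
  mkldnn::engine eng(mkldnn::engine::kind::cpu, 0);
  mkldnn::stream strm(eng);

  // Same shape, different data types: u8 source, s8 destination.
  mkldnn::memory::dims shape = {1, 8};
  mkldnn::memory::desc u8_md(shape, mkldnn::memory::data_type::u8,
                             mkldnn::memory::format_tag::nc);
  mkldnn::memory::desc s8_md(shape, mkldnn::memory::data_type::s8,
                             mkldnn::memory::format_tag::nc);
  std::vector<uint8_t> src(8, 200);  // illustrative input values
  std::vector<int8_t> dst(8, 0);
  mkldnn::memory u8_mem(u8_md, eng, src.data());
  mkldnn::memory s8_mem(s8_md, eng, dst.data());

  // v1.0: no set_int_output_round_mode; only output scales on the attr.
  mkldnn::primitive_attr attr;
  attr.set_output_scales(0, {0.5f});  // mask 0 = one scale for whole tensor

  // v1.0 signature: (src_engine, src_md, dst_engine, dst_md, attr).
  mkldnn::reorder::primitive_desc reorder_pd(eng, u8_md, eng, s8_md, attr);

  // Inputs/outputs are passed as an argument map at execution time.
  mkldnn::reorder(reorder_pd).execute(
      strm, {{MKLDNN_ARG_FROM, u8_mem}, {MKLDNN_ARG_TO, s8_mem}});
  strm.wait();  // dst now holds round(200 * 0.5) = 100 per element
  return 0;
}

The sum primitive in the diff follows the same execution-argument pattern (MKLDNN_ARG_MULTIPLE_SRC + i for the inputs, MKLDNN_ARG_DST for the output), and the "format any" workaround copies dataA's existing descriptor and patches only its data type field instead of building a new desc from an explicit format enum, which no longer exists in 1.0.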
rongzha1 authored and pengzhao-intel committed Oct 12, 2019
1 parent 48bfcf9 commit 34239b6
Showing 1 changed file with 29 additions and 27 deletions.
src/operator/quantization/mkldnn/mkldnn_quantized_elemwise_add.cc  (29 additions, 27 deletions)
@@ -23,7 +23,7 @@
  * \brief
  */
 
-#if MXNET_USE_MKLDNN == 1
+#if MXNET_USE_MKLDNN == 100
 #include "../quantized_elemwise_add-inl.h"
 #include "../../nn/mkldnn/mkldnn_ops-inl.h"
 #include "../../nn/mkldnn/mkldnn_base-inl.h"
@@ -73,17 +73,17 @@ static void MKLDNNQuantizedElemwiseAddForward(const nnvm::NodeAttrs& attrs, cons
 
   // output default set as int32
   float output_data_range = kInt32Range;
-  auto output_data_type = mkldnn::memory::s32;
+  auto output_data_type = mkldnn::memory::data_type::s32;
   // dataA && dataB are uint8
   if (out_data[quantized_elemwise_add_enum::kOut].dtype() == mshadow::kInt8) {
     output_data_range = kInt8Range;
-    output_data_type = mkldnn::memory::s8;
+    output_data_type = mkldnn::memory::data_type::s8;
   } else if (out_data[quantized_elemwise_add_enum::kOut].dtype() == mshadow::kUint8) {
     output_data_range = kUint8Range;
-    output_data_type = mkldnn::memory::u8;
+    output_data_type = mkldnn::memory::data_type::u8;
   } else {
     output_data_range = kInt32Range;
-    output_data_type = mkldnn::memory::s32;
+    output_data_type = mkldnn::memory::data_type::s32;
   }
 
   float output_min = 0;
@@ -100,12 +100,13 @@ static void MKLDNNQuantizedElemwiseAddForward(const nnvm::NodeAttrs& attrs, cons
   // 2: scale 0 for dataA, scale 1 for data B
   const int scales_num = 2;
   std::vector<float> scales(scales_num, 1);
+  auto engine = CpuEngine::Get()->get_engine();
   if (in_data[quantized_elemwise_add_enum::kDataA].dtype()
       != in_data[quantized_elemwise_add_enum::kDataB].dtype()) {
-    auto s8_pd = (is_dataA_int8 == true)
-                 ? dataA_mem->get_primitive_desc()
-                 : dataB_mem->get_primitive_desc();
-    rescaled_mem = TmpMemMgr::Get()->Alloc(s8_pd);
+    auto s8_desc = (is_dataA_int8 == true)
+                   ? dataA_mem->get_desc()
+                   : dataB_mem->get_desc();
+    rescaled_mem = TmpMemMgr::Get()->Alloc(s8_desc);
     float u8_reorder_scale = 0;
     if (params.max_calib_range.has_value() && params.min_calib_range.has_value()) {
       if (is_dataA_int8 == true) {
@@ -130,14 +131,16 @@ static void MKLDNNQuantizedElemwiseAddForward(const nnvm::NodeAttrs& attrs, cons
       }
     }
     std::vector<float> reorder_scale = {u8_reorder_scale};
-    primitive_attr reorder_attr;
-    reorder_attr.set_int_output_round_mode(round_mode::round_nearest);
+    mkldnn::primitive_attr reorder_attr;
     reorder_attr.set_output_scales(0, reorder_scale);
     auto u8_mem = (is_dataA_int8 == true) ? dataB_mem : dataA_mem;
-    const auto reorder_pd = mkldnn::reorder::primitive_desc(u8_mem->get_primitive_desc(),
-                                                            s8_pd,
+    const auto reorder_pd = mkldnn::reorder::primitive_desc(engine,
+                                                            u8_mem->get_desc(),
+                                                            engine,
+                                                            s8_desc,
                                                             reorder_attr);
-    MKLDNNStream::Get()->RegisterPrim(mkldnn::reorder(reorder_pd, *u8_mem, *rescaled_mem));
+    mkldnn_args_map_t args({{MKLDNN_ARG_FROM, *u8_mem }, {MKLDNN_ARG_TO, *rescaled_mem}});
+    MKLDNNStream::Get()->RegisterPrimArgs(mkldnn::reorder(reorder_pd), args);
 
     if (is_dataA_int8 == true) {
       dataB_mem = rescaled_mem;
@@ -155,27 +158,26 @@ static void MKLDNNQuantizedElemwiseAddForward(const nnvm::NodeAttrs& attrs, cons
     }
   }
 
-  std::vector<mkldnn::primitive::at> in_prims;
-  std::vector<mkldnn::memory::primitive_desc> in_pds;
-  in_prims.push_back(*dataA_mem);
-  in_prims.push_back(*dataB_mem);
-  in_pds.push_back(dataA_mem->get_primitive_desc());
-  in_pds.push_back(dataB_mem->get_primitive_desc());
+  std::vector<mkldnn::memory::desc> in_desc;
+  in_desc.push_back(dataA_mem->get_desc());
+  in_desc.push_back(dataB_mem->get_desc());
   size_t i_ndim = in_data[quantized_elemwise_add_enum::kDataA].shape().ndim();
   mkldnn::memory::dims i_dims = mkldnn::memory::dims(i_ndim);
   for (size_t i = 0; i < i_ndim; i++) {
     i_dims[i] = static_cast<int>(in_data[quantized_elemwise_add_enum::kDataA].shape()[i]);
   }
-  mkldnn::memory::format i_fmt = static_cast<mkldnn::memory::format>(
-      in_pds[quantized_elemwise_add_enum::kDataA].desc().data.format);
-  auto output_desc = mkldnn::memory::desc(i_dims, output_data_type, i_fmt);
-  mkldnn::sum::primitive_desc pdesc(output_desc, scales, in_pds);
+  auto output_desc = dataA_mem->get_desc();
+  output_desc.data.data_type = static_cast<mkldnn_data_type_t>(output_data_type);
+  mkldnn::sum::primitive_desc pdesc(output_desc, scales, in_desc, engine);
   auto mem = CreateMKLDNNMem(out_data[quantized_elemwise_add_enum::kOut],
-                             pdesc.dst_primitive_desc(),
+                             pdesc.dst_desc(),
                              req[0],
                              &in_data[0]);
+  mkldnn_args_map_t args({{MKLDNN_ARG_MULTIPLE_SRC, *dataA_mem},
+                          {MKLDNN_ARG_MULTIPLE_SRC + 1, *dataB_mem},
+                          {MKLDNN_ARG_DST, *mem.second}});
   MKLDNNStream *stream = MKLDNNStream::Get();
-  stream->RegisterPrim(mkldnn::sum(pdesc, in_prims, *mem.second));
+  stream->RegisterPrimArgs(mkldnn::sum(pdesc), args);
   CommitOutput(out_data[quantized_elemwise_add_enum::kOut], mem);
   stream->Submit();
 
@@ -203,4 +205,4 @@ NNVM_REGISTER_OP(_contrib_quantized_elemwise_add)
 }  // namespace op
 }  // namespace mxnet
 
-#endif  // MXNET_USE_MKLDNN == 1
+#endif  // MXNET_USE_MKLDNN == 100
