This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[mkldnn-1.0] int8 conv quantize dequantize requantize #16283

Merged (4 commits, Oct 11, 2019)
2 changes: 1 addition & 1 deletion include/mxnet/ndarray.h
@@ -787,7 +787,7 @@ class NDArray {
   /*!
    * \ Fix mkldnn memory descriptor mismatch from NDArray.
    */
-  void UpdateMKLDNNMemDesc(mkldnn::memory::format_tag format);
+  void UpdateMKLDNNMemDesc(const mkldnn::memory::desc &desc);
 #endif
 
   /*!
2 changes: 1 addition & 1 deletion src/executor/graph_executor.cc
@@ -42,7 +42,7 @@ namespace exec {
 using namespace mxnet::common;
 
 static const std::string GetDefaultSubgraphBackend() {
-#if MXNET_USE_MKLDNN == 1
+#if MXNET_USE_MKLDNN == 100
   return std::string("MKLDNN");
 #else
   return std::string();
12 changes: 5 additions & 7 deletions src/ndarray/ndarray.cc
@@ -724,16 +724,14 @@ mkldnn::memory *NDArray::CreateMKLDNNData(const mkldnn::memory::desc &desc) {
   return ptr_->mkl_mem_->GetRaw();
 }
 
-void NDArray::UpdateMKLDNNMemDesc(mkldnn::memory::format_tag format) {
-  const mkldnn::memory *mem = GetMKLDNNData();
-  auto mem_desc = mem->get_desc();
+void NDArray::UpdateMKLDNNMemDesc(const mkldnn::memory::desc &desc) {
+  auto new_desc = desc;
   auto this_dtype = get_mkldnn_type(dtype());
-  mkldnn::memory::desc data_md(
-      mkldnn::memory::dims(mem_desc.data.dims, mem_desc.data.dims + mem_desc.data.ndims),
-      this_dtype, format);
-  ptr_->mkl_mem_.reset(new MKLDNNMemory(data_md, ptr_->shandle.dptr));
+  new_desc.data.data_type = static_cast<mkldnn_data_type_t>(this_dtype);
+  ptr_->mkl_mem_.reset(new MKLDNNMemory(new_desc, ptr_->shandle.dptr));
   MKLDNNStream::Get()->RegisterMem(ptr_->mkl_mem_->GetMem());
 }
 
 #endif
 
 void NDArray::SetTBlob() const {
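The rewritten UpdateMKLDNNMemDesc leans on the fact that a mkldnn::memory::desc in MKL-DNN 1.0 exposes its underlying C struct, so only the data type needs to be swapped while dims and blocking are preserved. A minimal standalone sketch of that trick (toy shape assumed, not part of this PR):

#include <mkldnn.hpp>
#include <cassert>

int main() {
  // Build an f32 desc, copy it, then patch only the dtype in place --
  // the same pattern UpdateMKLDNNMemDesc applies with this_dtype.
  mkldnn::memory::desc d({1, 8}, mkldnn::memory::data_type::f32,
                         mkldnn::memory::format_tag::nc);
  mkldnn::memory::desc patched = d;
  patched.data.data_type = mkldnn_s8;          // layout (dims, blocking) untouched
  assert(patched.data.ndims == d.data.ndims);  // only the dtype changed
  return 0;
}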
15 changes: 15 additions & 0 deletions src/operator/nn/mkldnn/mkldnn_base-inl.h
@@ -223,6 +223,21 @@ static inline mkldnn::memory::data_type get_mkldnn_type(int dtype) {
   }
 }
 
+template<typename T>
+static inline mkldnn::memory::data_type get_mkldnn_type() {
+  return static_cast<mkldnn::memory::data_type>(data_type_enum<T>::type);
+}
+
+static inline mkldnn_data_type_t get_mkldnn_type_t(int dtype) {
+  return static_cast<mkldnn_data_type_t>(get_mkldnn_type(dtype));
+}
+
+template<typename T>
+static inline mkldnn_data_type_t get_mkldnn_type_t() {
+  return static_cast<mkldnn_data_type_t>(data_type_enum<T>::type);
+}
+
+
 static inline int get_mxnet_type(mkldnn_data_type_t dtype) {
   auto mkldnn_dtype = static_cast<mkldnn::memory::data_type>(dtype);
   switch (mkldnn_dtype) {
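These helpers only bridge MXNet's data_type_enum to MKL-DNN's C++ and C dtype enums. A small sketch (assuming an MKL-DNN 1.0 build) of why the bare static_cast is safe: the C++ enum values are defined to match the C enum, which is the invariant get_mkldnn_type_t<T>() relies on.

#include <mkldnn.hpp>
#include <cassert>

int main() {
  using dt = mkldnn::memory::data_type;
  // Value-for-value correspondence between the C++ and C dtype enums.
  assert(static_cast<mkldnn_data_type_t>(dt::f32) == mkldnn_f32);
  assert(static_cast<mkldnn_data_type_t>(dt::s8) == mkldnn_s8);
  assert(static_cast<mkldnn_data_type_t>(dt::u8) == mkldnn_u8);
  return 0;
}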
4 changes: 4 additions & 0 deletions src/operator/nn/mkldnn/mkldnn_convolution-inl.h
@@ -99,6 +99,10 @@ class MKLDNNConvForward {
 
 typedef ParamOpSign<ConvolutionParam> MKLDNNConvSignature;
 
+MKLDNNConvForward &GetConvFwd(const MKLDNNConvFullParam &param, const bool is_train,
+                              const NDArray &data, const NDArray &weight, const NDArray *bias,
+                              const NDArray &output);
+
 void MKLDNNConvolutionForwardFullFeature(const MKLDNNConvFullParam &param,
                                          const OpContext &ctx,
                                          MKLDNNConvForward *fwd,
8 changes: 4 additions & 4 deletions src/operator/quantization/dequantize.cc
@@ -23,7 +23,7 @@
  * \brief
  */
 #include "./dequantize-inl.h"
-#if MXNET_USE_MKLDNN == 1
+#if MXNET_USE_MKLDNN == 100
 #include "./mkldnn/mkldnn_dequantize-inl.h"
 #endif
 
@@ -37,7 +37,7 @@ bool DequantizeStorageType(const nnvm::NodeAttrs& attrs,
                            std::vector<int> *in_attrs,
                            std::vector<int> *out_attrs) {
   *dispatch_mode = DispatchMode::kFCompute;
-#if MXNET_USE_MKLDNN == 1
+#if MXNET_USE_MKLDNN == 100
   if (dev_mask == mshadow::cpu::kDevMask) {
     *dispatch_mode = DispatchMode::kFComputeEx;
   }
@@ -55,7 +55,7 @@ static OpStatePtr CreateDequantizeState(const nnvm::NodeAttrs &attrs, Context ct
   if (ctx.dev_type == kGPU) {
     state = OpStatePtr::Create<DequantizeOperator<gpu>>(attrs);
   } else {
-#if MXNET_USE_MKLDNN == 1
+#if MXNET_USE_MKLDNN == 100
     state = OpStatePtr::Create<SgMKLDNNDequantizeOperator>(attrs);
 #else
     state = OpStatePtr::Create<DequantizeOperator<cpu>>(attrs);
@@ -95,7 +95,7 @@ by keep zero centered for the quantized value:
 // will be reverted after the improvement of CachedOP is done.
 .set_attr<nnvm::FGradient>("FGradient", MakeZeroGradNodes)
 .set_attr<FCreateOpState>("FCreateOpState", CreateDequantizeState)
-#if MXNET_USE_MKLDNN == 1
+#if MXNET_USE_MKLDNN == 100
 .set_attr<bool>("TIsMKLDNN", true)
 .set_attr<FStatefulComputeEx>("FStatefulComputeEx<cpu>", SgMKLDNNDequantizeForward)
 #endif
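For reference, dequantization here is a single per-tensor scale that keeps zero centered: out = q * real_range / quantized_range. A scalar sketch of that arithmetic, with an assumed calibration range:

#include <cstdint>
#include <cstdio>

int main() {
  const float real_range = 2.5f;        // max(|min|, |max|) of the data, assumed
  const float quantized_range = 127.f;  // int8 range, zero-centered
  const float scale = real_range / quantized_range;
  const int8_t q = -64;
  std::printf("%f\n", q * scale);       // ~ -1.26, the recovered f32 value
  return 0;
}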
45 changes: 19 additions & 26 deletions src/operator/quantization/mkldnn/mkldnn_dequantize-inl.h
@@ -25,7 +25,7 @@
 
 #ifndef MXNET_OPERATOR_QUANTIZATION_MKLDNN_MKLDNN_DEQUANTIZE_INL_H_
 #define MXNET_OPERATOR_QUANTIZATION_MKLDNN_MKLDNN_DEQUANTIZE_INL_H_
-#if MXNET_USE_MKLDNN == 1
+#if MXNET_USE_MKLDNN == 100
 #include <algorithm>
 #include <string>
 #include <vector>
@@ -48,8 +48,8 @@ class SgMKLDNNDequantizeOperator {
   DequantizeParam param_;
   float cached_data_min_{0.f};
   float cached_data_max_{0.f};
-  std::shared_ptr<mkldnn::memory> i_mem_;
-  std::shared_ptr<mkldnn::memory> o_mem_;
+  mkldnn::memory::desc o_desc_;
+  mkldnn_args_map_t args_;
   std::shared_ptr<mkldnn::reorder> fwd_pd_;
 };
 
@@ -79,37 +79,30 @@ void SgMKLDNNDequantizeOperator::Forward(const OpContext &ctx, const std::vector
       LOG(FATAL) << "mkldnn dequantize op only supports int8 and uint8 as output type";
     }
     float scale = real_range / quantized_range;
-    primitive_attr attr;
+    mkldnn::primitive_attr attr;
     const int mask = 0;
     std::vector<float> scales = {scale};
     attr.set_output_scales(mask, scales);
-    attr.set_int_output_round_mode(round_nearest);
     mkldnn::engine cpu_engine = mxnet::CpuEngine::Get()->get_engine();
-    auto i_mpd = i_mem->get_primitive_desc();
-    auto i_desc = i_mpd.desc();
+    auto i_desc = i_mem->get_desc();
     size_t i_ndim = in_buffer.shape().ndim();
-    mkldnn::memory::dims i_dims = mkldnn::memory::dims(i_ndim);
-    for (size_t i = 0; i < i_ndim; i++) {
-      i_dims[i] = static_cast<int>(in_buffer.shape()[i]);
-    }
-    mkldnn::memory::format o_fmt = static_cast<mkldnn::memory::format>(i_desc.data.format);
-    if (o_fmt == mkldnn::memory::format::nhwc) {
-      // For 4d tensor, nchw is the default format
-      o_fmt = mkldnn::memory::format::nchw;
+    if (i_ndim == 4) {
+      mkldnn::memory::format_tag o_fmt = mkldnn::memory::format_tag::nchw;
+      mkldnn::memory::dims o_dims(i_desc.data.dims, i_desc.data.dims + i_desc.data.ndims);
+      o_desc_ = mkldnn::memory::desc(o_dims, get_mkldnn_type<float>(), o_fmt);
+    } else {
+      o_desc_ = i_desc;
+      o_desc_.data.data_type = get_mkldnn_type_t<float>();
     }
-    auto o_desc =
-        mkldnn::memory::desc(i_dims, (mkldnn::memory::data_type)data_type_enum<float>::type, o_fmt);
-    auto o_mpd = memory::primitive_desc(o_desc, cpu_engine);
-    auto reorder_pd = reorder::primitive_desc(i_mpd, o_mpd, attr);
-    i_mem_ = std::make_shared<mkldnn::memory>(i_mpd, nullptr);
-    o_mem_ = std::make_shared<mkldnn::memory>(o_mpd, nullptr);
-    fwd_pd_ = std::make_shared<mkldnn::reorder>(reorder_pd, *i_mem_, *o_mem_);
+    auto reorder_pd =
+        mkldnn::reorder::primitive_desc(cpu_engine, i_desc, cpu_engine, o_desc_, attr);
+    fwd_pd_ = std::make_shared<mkldnn::reorder>(reorder_pd);
    initialized_ = true;
   }
-  auto o_mem = CreateMKLDNNMem(outputs[0], o_mem_->get_primitive_desc(), req[0]);
-  i_mem_->set_data_handle(i_mem->get_data_handle());
-  o_mem_->set_data_handle(o_mem.second->get_data_handle());
-  MKLDNNStream::Get()->RegisterPrim(*fwd_pd_);
+  auto o_mem = CreateMKLDNNMem(outputs[0], o_desc_, req[0]);
+  args_[MKLDNN_ARG_FROM] = *i_mem;
+  args_[MKLDNN_ARG_TO] = *o_mem.second;
+  MKLDNNStream::Get()->RegisterPrimArgs(*fwd_pd_, args_);
   CommitOutput(outputs[0], o_mem);
   MKLDNNStream::Get()->Submit();
 }
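Stripped of MXNet's NDArray plumbing, the primitive cached above is just an int8-to-f32 reorder with an output scale. A self-contained MKL-DNN 1.0 sketch of that reorder (shape, range, and data assumed for illustration):

#include <mkldnn.hpp>
#include <cstdio>
#include <vector>

int main() {
  using namespace mkldnn;
  engine eng(engine::kind::cpu, 0);
  stream s(eng);

  memory::dims dims = {1, 8};  // toy shape, assumed
  memory::desc i_desc(dims, memory::data_type::s8, memory::format_tag::nc);
  memory::desc o_desc(dims, memory::data_type::f32, memory::format_tag::nc);

  std::vector<int8_t> src = {-127, -64, -1, 0, 1, 32, 64, 127};
  std::vector<float> dst(src.size());
  memory i_mem(i_desc, eng, src.data());
  memory o_mem(o_desc, eng, dst.data());

  // mask = 0: one common scale for the whole tensor, as in the operator above.
  primitive_attr attr;
  attr.set_output_scales(0, {2.5f / 127.f});  // real_range / quantized_range, assumed

  auto pd = reorder::primitive_desc(eng, i_desc, eng, o_desc, attr);
  reorder(pd).execute(s, {{MKLDNN_ARG_FROM, i_mem}, {MKLDNN_ARG_TO, o_mem}});
  s.wait();

  for (float v : dst) std::printf("%f\n", v);  // -2.5 ... 2.5
  return 0;
}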
44 changes: 18 additions & 26 deletions src/operator/quantization/mkldnn/mkldnn_quantize-inl.h
@@ -25,7 +25,7 @@
 
 #ifndef MXNET_OPERATOR_QUANTIZATION_MKLDNN_MKLDNN_QUANTIZE_INL_H_
 #define MXNET_OPERATOR_QUANTIZATION_MKLDNN_MKLDNN_QUANTIZE_INL_H_
-#if MXNET_USE_MKLDNN == 1
+#if MXNET_USE_MKLDNN == 100
 #include <string>
 #include <algorithm>
 #include <vector>
@@ -35,11 +35,11 @@
 namespace mxnet {
 namespace op {
 
-template<typename SrcType, typename DstType>
+template <typename SrcType, typename DstType>
 static void MKLDNNQuantizeComputeKer(const std::vector<NDArray>& inputs,
                                      const std::vector<NDArray>& outputs,
                                      const QuantizeParam& param,
-                                     const std::vector<OpReqType> &req) {
+                                     const std::vector<OpReqType>& req) {
   using namespace mshadow;
   using namespace mxnet_op;
   using red::limits::MaxValue;
@@ -60,38 +60,30 @@ static void MKLDNNQuantizeComputeKer(const std::vector<NDArray>& inputs,
     LOG(FATAL) << "mkldnn quantize op only supports int8 and uint8 as output type";
   }
   float scale = quantized_range / real_range;
-  primitive_attr attr;
+  mkldnn::primitive_attr attr;
   const int mask = 0;
   std::vector<float> scales = {scale};
   attr.set_output_scales(mask, scales);
-  attr.set_int_output_round_mode(round_nearest);
   mkldnn::engine cpu_engine = mxnet::CpuEngine::Get()->get_engine();
 
   NDArray in_buffer = inputs[0];
-  if (inputs[0].IsView() && inputs[0].IsMKLDNNData())
-    in_buffer = inputs[0].Reorder2Default();
+  if (inputs[0].IsView() && inputs[0].IsMKLDNNData()) in_buffer = inputs[0].Reorder2Default();
 
   auto i_mem = in_buffer.GetMKLDNNData();
-  auto i_mpd = i_mem->get_primitive_desc();
-  auto i_desc = i_mpd.desc();
-  mkldnn::memory::format i_fmt = static_cast<mkldnn::memory::format>(i_desc.data.format);
-  if (i_fmt == mkldnn::memory::format::nchw ||
-      i_fmt == mkldnn::memory::format::nChw8c ||
-      i_fmt == mkldnn_nChw16c) {
-    i_fmt = mkldnn::memory::format::nhwc;
-  }
+  auto i_desc = i_mem->get_desc();
   size_t i_ndim = in_buffer.shape().ndim();
-  mkldnn::memory::dims i_dims = mkldnn::memory::dims(i_ndim);
-  for (size_t i = 0; i < i_ndim; i++) {
-    i_dims[i] = static_cast<int>(in_buffer.shape()[i]);
+  mkldnn::memory::desc o_desc;
+  if (i_ndim == 4) {
+    mkldnn::memory::format_tag o_fmt = mkldnn::memory::format_tag::nhwc;
+    mkldnn::memory::dims o_dims(i_desc.data.dims, i_desc.data.dims + i_desc.data.ndims);
+    o_desc = mkldnn::memory::desc(o_dims, get_mkldnn_type<DstType>(), o_fmt);
+  } else {
+    o_desc = i_desc;
+    o_desc.data.data_type = get_mkldnn_type_t<DstType>();
   }
-  auto o_desc = mkldnn::memory::desc(i_dims,
-                                     (mkldnn::memory::data_type)data_type_enum<DstType>::type,
-                                     i_fmt);
-  auto o_mpd = memory::primitive_desc(o_desc, cpu_engine);
-  auto reorder_pd = reorder::primitive_desc(i_mpd, o_mpd, attr);
-  auto o_mem = CreateMKLDNNMem(outputs[0], o_mpd, req[0]);
-  MKLDNNStream::Get()->RegisterPrim(mkldnn::reorder(reorder_pd, *i_mem, *o_mem.second));
+  auto reorder_pd = mkldnn::reorder::primitive_desc(cpu_engine, i_desc, cpu_engine, o_desc, attr);
+  auto o_mem = CreateMKLDNNMem(outputs[0], o_desc, req[0]);
+  MKLDNNStream::Get()->RegisterPrimArgs(
+      mkldnn::reorder(reorder_pd), {{MKLDNN_ARG_FROM, *i_mem}, {MKLDNN_ARG_TO, *o_mem.second}});
   CommitOutput(outputs[0], o_mem);
   MKLDNNStream::Get()->Submit();
 }
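Quantization is the inverse mapping, with scale = quantized_range / real_range and round-to-nearest, which MKL-DNN 1.0 applies by default (hence the dropped set_int_output_round_mode call above). A scalar sketch with an assumed calibration range:

#include <cmath>
#include <cstdio>

int main() {
  const float real_range = 2.5f;        // assumed calibration range
  const float quantized_range = 127.f;  // 255 for a uint8 destination
  const float scale = quantized_range / real_range;
  const float x = -1.26f;
  // Round-to-nearest, matching the reorder's default integer conversion.
  std::printf("%d\n", static_cast<int>(std::nearbyint(x * scale)));  // -64
  return 0;
}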
44 changes: 19 additions & 25 deletions src/operator/quantization/mkldnn/mkldnn_quantize_v2-inl.h
@@ -24,7 +24,7 @@
 
 #ifndef MXNET_OPERATOR_QUANTIZATION_MKLDNN_MKLDNN_QUANTIZE_V2_INL_H_
 #define MXNET_OPERATOR_QUANTIZATION_MKLDNN_MKLDNN_QUANTIZE_V2_INL_H_
-#if MXNET_USE_MKLDNN == 1
+#if MXNET_USE_MKLDNN == 100
 #include <algorithm>
 #include <string>
 #include <vector>
@@ -47,8 +47,8 @@ class SgMKLDNNQuantizeOperator {
   QuantizeV2Param param_;
   float cached_data_min_{0.f};
   float cached_data_max_{0.f};
-  std::shared_ptr<mkldnn::memory> i_mem_;
-  std::shared_ptr<mkldnn::memory> o_mem_;
+  mkldnn::memory::desc o_desc_;
+  mkldnn_args_map_t args_;
   std::shared_ptr<mkldnn::reorder> fwd_pd_;
 };
 
@@ -127,36 +127,30 @@ void SgMKLDNNQuantizeOperator::Forward(const OpContext &ctx, const std::vector<N
     cached_data_max_ = data_max;
     float real_range = MaxAbs(data_min, data_max);
     float scale = quantized_range / real_range;
-    primitive_attr attr;
+    mkldnn::primitive_attr attr;
     const int mask = 0;
     std::vector<float> scales = {scale};
     attr.set_output_scales(mask, scales);
-    attr.set_int_output_round_mode(round_nearest);
     mkldnn::engine cpu_engine = mxnet::CpuEngine::Get()->get_engine();
-    auto i_mpd = i_mem->get_primitive_desc();
-    auto i_desc = i_mpd.desc();
-    mkldnn::memory::format i_fmt = static_cast<mkldnn::memory::format>(i_desc.data.format);
-    if (i_fmt == mkldnn::memory::format::nchw || i_fmt == mkldnn::memory::format::nChw8c ||
-        i_fmt == mkldnn_nChw16c) {
-      i_fmt = mkldnn::memory::format::nhwc;
-    }
+    auto i_desc = i_mem->get_desc();
     size_t i_ndim = in_buffer.shape().ndim();
-    mkldnn::memory::dims i_dims = mkldnn::memory::dims(i_ndim);
-    for (size_t i = 0; i < i_ndim; i++) {
-      i_dims[i] = static_cast<int>(in_buffer.shape()[i]);
+    if (i_ndim == 4) {
+      mkldnn::memory::format_tag o_fmt = mkldnn::memory::format_tag::nhwc;
+      mkldnn::memory::dims o_dims(i_desc.data.dims, i_desc.data.dims + i_desc.data.ndims);
+      o_desc_ = mkldnn::memory::desc(o_dims, get_mkldnn_type(out_type), o_fmt);
+    } else {
+      o_desc_ = i_desc;
+      o_desc_.data.data_type = get_mkldnn_type_t(out_type);
     }
-    auto o_desc = mkldnn::memory::desc(i_dims, get_mkldnn_type(out_type), i_fmt);
-    auto o_mpd = memory::primitive_desc(o_desc, cpu_engine);
-    auto reorder_pd = reorder::primitive_desc(i_mpd, o_mpd, attr);
-    i_mem_ = std::make_shared<mkldnn::memory>(i_mpd, nullptr);
-    o_mem_ = std::make_shared<mkldnn::memory>(o_mpd, nullptr);
-    fwd_pd_ = std::make_shared<mkldnn::reorder>(reorder_pd, *i_mem_, *o_mem_);
+    auto reorder_pd =
+        mkldnn::reorder::primitive_desc(cpu_engine, i_desc, cpu_engine, o_desc_, attr);
+    fwd_pd_ = std::make_shared<mkldnn::reorder>(reorder_pd);
     initalized_ = true;
   }
-  auto o_mem = CreateMKLDNNMem(outputs[0], o_mem_->get_primitive_desc(), req[0]);
-  i_mem_->set_data_handle(i_mem->get_data_handle());
-  o_mem_->set_data_handle(o_mem.second->get_data_handle());
-  MKLDNNStream::Get()->RegisterPrim(*fwd_pd_);
+  auto o_mem = CreateMKLDNNMem(outputs[0], o_desc_, req[0]);
+  args_[MKLDNN_ARG_FROM] = *i_mem;
+  args_[MKLDNN_ARG_TO] = *o_mem.second;
+  MKLDNNStream::Get()->RegisterPrimArgs(*fwd_pd_, args_);
  CommitOutput(outputs[0], o_mem);
  MKLDNNStream::Get()->Submit();
 }
36 changes: 20 additions & 16 deletions src/operator/quantization/mkldnn/mkldnn_quantized_conv.cc
@@ -23,7 +23,7 @@
  * \author Wenting Jiang, Xinyu Chen
  */
 
-#if MXNET_USE_MKLDNN == 1
+#if MXNET_USE_MKLDNN == 100
 #include "../../nn/mkldnn/mkldnn_base-inl.h"
 #include "../../nn/mkldnn/mkldnn_convolution-inl.h"
 #include "../../nn/convolution-inl.h"
@@ -43,32 +43,36 @@ static void MKLDNNQuantizedConvForward(const nnvm::NodeAttrs& attrs,
   TmpMemMgr::Get()->Init(ctx.requested[conv::kTempSpace]);
   NDArray weight = in_data[conv::kWeight];
   ConvolutionParam param = nnvm::get<ConvolutionParam>(attrs.parsed);
-  auto &fwd = GetConvFwd(
-      param, ctx.is_train, in_data[conv::kData], in_data[conv::kWeight],
-      param.no_bias ? nullptr : &in_data[conv::kBias],
-      out_data[conv::kOut]);
-  auto data_mem = in_data[conv::kData].GetMKLDNNDataReorder(fwd.fwd_pd.src_primitive_desc());
+  MKLDNNConvFullParam full_param;
+  full_param.conv_param = param;
+  full_param.mkldnn_param.Init(std::unordered_map<std::string, std::string>());
+  auto &fwd = GetConvFwd(full_param, ctx.is_train, in_data[conv::kData], in_data[conv::kWeight],
+                         param.no_bias ? nullptr : &in_data[conv::kBias], out_data[conv::kOut]);
+  auto data_mem = in_data[conv::kData].GetMKLDNNDataReorder(fwd.GetPd().src_desc());
   const mkldnn::memory *weight_mem;
   // For inference, we want to reorder the weight array so we don't need to
   // reorder data every time.
   if (weight.IsDefaultData()) {
     // We also need to modify the layout on the original weight array.
     // Don't switch below sequence because naive engine will executes
     // pushAsync synchronously.
-    weight.MKLDNNDataReorderAsync(fwd.fwd_pd.weights_primitive_desc());
-    weight_mem = GetWeights(weight, fwd.fwd_pd.weights_primitive_desc(), param.num_group);
+    weight.MKLDNNDataReorderAsync(fwd.GetPd().weights_desc());
+    weight_mem = GetWeights(weight, fwd.GetPd().weights_desc(), param.num_group);
   } else {
     weight_mem = weight.GetMKLDNNData();
-    CHECK(weight_mem->get_primitive_desc() == fwd.fwd_pd.weights_primitive_desc());
   }
-  auto out_mem = CreateMKLDNNMem(out_data[conv::kOut], fwd.fwd_pd.dst_primitive_desc(),
+  auto out_mem = CreateMKLDNNMem(out_data[conv::kOut], fwd.GetPd().dst_desc(),
                                  req[conv::kOut]);
-  const mkldnn::memory *bias_mem = nullptr;
-  if (!param.no_bias)
-    bias_mem = in_data[conv::kBias].GetMKLDNNDataReorder(fwd.fwd_pd.bias_primitive_desc());
-  fwd.SetNewMem(*data_mem, *weight_mem, bias_mem, *out_mem.second);
-  MKLDNNStream::Get()->RegisterPrim(fwd.GetFwd());
+
+  mkldnn_args_map_t net_args;
+  if (!param.no_bias) {
+    const mkldnn::memory *bias_mem =
+        in_data[conv::kBias].GetMKLDNNDataReorder(fwd.GetPd().bias_desc());
+    net_args.insert({MKLDNN_ARG_BIAS, *bias_mem});
+  }
+  net_args.insert({MKLDNN_ARG_SRC, *data_mem});
+  net_args.insert({MKLDNN_ARG_WEIGHTS, *weight_mem});
+  net_args.insert({MKLDNN_ARG_DST, *out_mem.second});
+  MKLDNNStream::Get()->RegisterPrimArgs(fwd.GetFwd(), net_args);
   CommitOutput(out_data[conv::kOut], out_mem);
   MKLDNNStream::Get()->Submit();
   Stream<cpu> *s = ctx.get_stream<cpu>();
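The core API shift in this file is from memory bound at primitive construction (SetNewMem + RegisterPrim) to per-execution argument maps (RegisterPrimArgs). A self-contained MKL-DNN 1.0 sketch of that execution model with a toy f32 convolution (all shapes and values assumed, not the quantized path itself):

#include <mkldnn.hpp>
#include <vector>

int main() {
  using namespace mkldnn;
  engine eng(engine::kind::cpu, 0);
  stream s(eng);

  // 1x1x4x4 input, 1x1x3x3 kernel, stride 1, no padding -> 1x1x2x2 output.
  memory::dims src_dims = {1, 1, 4, 4}, wei_dims = {1, 1, 3, 3}, dst_dims = {1, 1, 2, 2};
  memory::desc src_md(src_dims, memory::data_type::f32, memory::format_tag::nchw);
  memory::desc wei_md(wei_dims, memory::data_type::f32, memory::format_tag::oihw);
  memory::desc dst_md(dst_dims, memory::data_type::f32, memory::format_tag::nchw);

  convolution_forward::desc d(prop_kind::forward_inference,
                              algorithm::convolution_direct, src_md, wei_md,
                              dst_md, {1, 1}, {0, 0}, {0, 0});
  convolution_forward::primitive_desc pd(d, eng);

  std::vector<float> src(16, 1.f), wei(9, 1.f), dst(4, 0.f);
  memory src_m(src_md, eng, src.data()), wei_m(wei_md, eng, wei.data()),
      dst_m(dst_md, eng, dst.data());

  // Memory is bound per call via an argument map -- the same {MKLDNN_ARG_*,
  // memory} pairs that RegisterPrimArgs(fwd.GetFwd(), net_args) forwards above.
  convolution_forward(pd).execute(s, {{MKLDNN_ARG_SRC, src_m},
                                      {MKLDNN_ARG_WEIGHTS, wei_m},
                                      {MKLDNN_ARG_DST, dst_m}});
  s.wait();  // each dst element is 9.0
  return 0;
}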