Commit

refactor codes and add an option to skip/check weight's version to reduce overhead (apache#17707) (apache#18039)

ciyongch authored Apr 17, 2020
1 parent 8cfc64a commit 2e22b5e
Showing 2 changed files with 67 additions and 59 deletions.
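
The heart of the change, in the mkldnn_fc.cc diff below, is to remember the version of the weight and bias NDArrays alongside the already-converted, cached MKL-DNN weights, and to redo the expensive conversion only when those versions or the runtime min/max values change; the per-call re-check is itself gated behind the MXNET_MKLDNN_QFC_DYNAMIC_PARAMS environment variable. A minimal sketch of the version-check idea, using hypothetical stand-in types rather than MXNet's actual classes:

```cpp
#include <cstdint>
#include <iostream>
#include <vector>

// Hypothetical stand-in for an array type that bumps a version counter
// whenever its contents are rewritten.
struct Array {
  std::vector<float> data;
  uint64_t version = 0;
  void Write(const std::vector<float> &v) { data = v; ++version; }
};

// Caches an expensive conversion and redoes it only when the source
// array's version has moved on since the last call.
struct ConvertedWeightCache {
  uint64_t cached_version = UINT64_MAX;
  std::vector<float> converted;

  const std::vector<float> &Get(const Array &weight) {
    if (weight.version != cached_version) {
      // Expensive path: re-convert (here just a copy plus scaling).
      converted = weight.data;
      for (auto &x : converted) x *= 0.5f;
      cached_version = weight.version;
    }
    return converted;  // Cheap path: reuse the cached result.
  }
};

int main() {
  Array w;
  w.Write({1.f, 2.f, 3.f});
  ConvertedWeightCache cache;
  cache.Get(w);          // converts
  cache.Get(w);          // cache hit, no conversion
  w.Write({4.f, 5.f});   // version bump, next Get() converts again
  std::cout << cache.Get(w)[0] << std::endl;  // prints 2
  return 0;
}
```

The real code makes the same kind of comparison against NDArray::version() and the cached min/max calibration values before deciding whether to rebuild the cached weights and primitive.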
2 changes: 1 addition & 1 deletion src/operator/subgraph/mkldnn/mkldnn_common.h
@@ -87,7 +87,7 @@ static std::vector<float> GetWeightScales(const NDArray &weight, const NDArray *
}

static void ConvertWeightBias2MKLDNN(NDArray *weight, NDArray *bias, bool has_bias,
const mkldnn::memory::desc weight_md,
const mkldnn::memory::desc &weight_md,
const mkldnn::memory::desc *bias_md,
const int num_group, float data_scale,
const std::vector<float> &weight_scales,
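
The only functional change in this header is that weight_md is now passed by const reference instead of by value, so the descriptor is no longer copied on every call to ConvertWeightBias2MKLDNN. A trivial sketch of the pattern, with hypothetical names:

```cpp
#include <vector>

// Hypothetical descriptor-like type; copying it is not free.
struct Desc { std::vector<int> dims; };

// Before: the argument is copied on every call.
void ConvertByValue(const Desc weight_md);

// After: binds to the caller's descriptor, no copy.
void ConvertByRef(const Desc &weight_md);
```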
124 changes: 66 additions & 58 deletions src/operator/subgraph/mkldnn/mkldnn_fc.cc
@@ -63,10 +63,13 @@ class SgMKLDNNFCOp {

private:
bool initialized_{false};
bool channel_wise_runtime_{false};
bool reorder_data_{false};
nnvm::Symbol subgraph_sym_;
MKLDNNFCFullParam full_param_;
mkldnn_args_map_t args_;
std::shared_ptr<MKLDNNFullyConnectedForward> fwd_;
std::shared_ptr<mkldnn::memory> cached_data_mem_;
std::shared_ptr<mkldnn::memory> cached_out_mem_;
NDArray cached_weight_;
NDArray cached_bias_;
@@ -82,28 +85,10 @@ class SgMKLDNNFCOp {
float cached_max_output_;
float data_scale_{0.0f};
std::vector<float> weight_scales_;
size_t total_num_inputs_;
size_t total_num_outputs_;
};

static inline void MKLDNNFCFlattenData(const FullyConnectedParam &param,
NDArray *in_data) {
const mxnet::TShape ishape = in_data->shape();

// If the input data is a view of an MKLDNN array, we should create a new
// NDArray with reordered data.
if (in_data->IsMKLDNNData() && in_data->IsView())
*in_data = in_data->Reorder2Default();

auto data_ndim = ishape.ndim();
if (data_ndim != 2) {
if (!param.flatten) {
*in_data = in_data->MKLDNNDataReshape(
Shape2(ishape.ProdShape(0, data_ndim - 1), ishape[data_ndim - 1]));
} else {
*in_data = in_data->MKLDNNDataReshape(Shape2(ishape[0], ishape.ProdShape(1, data_ndim)));
}
}
}
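
The standalone MKLDNNFCFlattenData helper above is removed; its reshape-to-2-D logic is folded into the initialization path of Forward (see the hunk that introduces the reorder_data_ flag below). A small sketch of the shape computation it performs, with illustrative names and types:

```cpp
#include <cstdint>
#include <functional>
#include <iostream>
#include <numeric>
#include <utility>
#include <vector>

// Compute the 2-D shape an n-D input is flattened to before the
// fully-connected primitive, mirroring the flatten logic above.
std::pair<int64_t, int64_t> FlattenTo2D(const std::vector<int64_t> &shape,
                                        bool flatten) {
  const auto ndim = static_cast<int64_t>(shape.size());
  if (ndim == 2) return {shape[0], shape[1]};
  auto prod = [&](int64_t begin, int64_t end) {
    return std::accumulate(shape.begin() + begin, shape.begin() + end,
                           int64_t{1}, std::multiplies<int64_t>());
  };
  if (!flatten)  // collapse all leading axes, keep the last one
    return {prod(0, ndim - 1), shape[ndim - 1]};
  return {shape[0], prod(1, ndim)};  // keep batch, collapse the rest
}

int main() {
  const auto s = FlattenTo2D({8, 3, 4}, /*flatten=*/true);
  std::cout << s.first << " x " << s.second << std::endl;  // 8 x 12
  return 0;
}
```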

void SgMKLDNNFCOp::Forward(const OpContext &ctx,
const std::vector<NDArray> &in_data,
const std::vector<OpReqType> &req,
@@ -112,9 +97,7 @@ void SgMKLDNNFCOp::Forward(const OpContext &ctx,
auto &default_param = full_param_.default_param;
bool has_bias = !default_param.no_bias;
size_t base_num_inputs = has_bias ? 3 : 2;
size_t total_num_inputs = base_num_inputs;
size_t base_num_outputs = 1;
size_t total_num_outputs = base_num_outputs;

float min_data = 0.0f;
float max_data = 0.0f;
@@ -123,17 +106,29 @@ void SgMKLDNNFCOp::Forward(const OpContext &ctx,
float min_bias = 0.0f;
float max_bias = 0.0f;

bool channel_wise = false;
if (mkldnn_param.channel_wise_quantize.has_value() &&
mkldnn_param.channel_wise_quantize) {
channel_wise = true;
if (!initialized_) {
if (mkldnn_param.channel_wise_quantize.has_value() &&
mkldnn_param.channel_wise_quantize) {
channel_wise_runtime_ = true;
}

total_num_inputs_ = base_num_inputs;
total_num_outputs_ = base_num_outputs;
if (mkldnn_param.quantized) {
total_num_inputs_ = channel_wise_runtime_ ? (base_num_inputs + 2) : (base_num_inputs * 3);
total_num_outputs_ =
mkldnn_param.enable_float_output ? base_num_outputs : (base_num_outputs * 3);
}
}
CHECK_EQ(in_data.size(), total_num_inputs_);
CHECK_EQ(out_data.size(), total_num_outputs_);

NDArray data = in_data[fullc::kData];
const NDArray &weight = in_data[fullc::kWeight];
const NDArray &output = out_data[fullc::kOut];

if (mkldnn_param.quantized) {
if (channel_wise) {
total_num_inputs = base_num_inputs + 2;
} else {
total_num_inputs = base_num_inputs * 3;
if (!channel_wise_runtime_) {
min_weight = in_data[base_num_inputs + quantized_fullc::kWeightMin].data().dptr<float>()[0];
max_weight = in_data[base_num_inputs + quantized_fullc::kWeightMax].data().dptr<float>()[0];
if (has_bias) {
@@ -143,20 +138,11 @@ void SgMKLDNNFCOp::Forward(const OpContext &ctx,
}
min_data = in_data[base_num_inputs + quantized_fullc::kDataMin].data().dptr<float>()[0];
max_data = in_data[base_num_inputs + quantized_fullc::kDataMax].data().dptr<float>()[0];
if (!mkldnn_param.enable_float_output) {
total_num_outputs = base_num_outputs * 3;
}
}
CHECK_EQ(in_data.size(), total_num_inputs);
CHECK_EQ(out_data.size(), total_num_outputs);

NDArray data = in_data[fullc::kData];
NDArray weight = in_data[fullc::kWeight];
NDArray output = out_data[fullc::kOut];
MKLDNNFCFlattenData(default_param, &data);

if (initialized_ && mkldnn_param.quantized) {
if (channel_wise) {
if (initialized_ && mkldnn_param.quantized &&
dmlc::GetEnv("MXNET_MKLDNN_QFC_DYNAMIC_PARAMS", 0)) {
if (channel_wise_runtime_) {
if (cached_min_data_ != min_data || cached_max_data_ != max_data ||
weight_ver_ != weight.version() ||
(has_bias && (bias_ver_ != in_data[fullc::kBias].version()))) {
Expand All @@ -173,6 +159,7 @@ void SgMKLDNNFCOp::Forward(const OpContext &ctx,

if (!initialized_) {
const auto nthreads = engine::OpenMP::Get()->GetRecommendedOMPThreadCount();
const auto engine = CpuEngine::Get()->get_engine();
cached_min_data_ = min_data;
cached_max_data_ = max_data;
cached_min_weight_ = min_weight;
@@ -187,9 +174,22 @@ void SgMKLDNNFCOp::Forward(const OpContext &ctx,
} else {
cached_bias_ = NDArray();
}
const mxnet::TShape ishape = data.shape();
const auto data_ndim = ishape.ndim();
if (data.IsMKLDNNData()) {
reorder_data_ = true;
data = data.Reorder2Default();
}
if (data_ndim != 2) {
if (!default_param.flatten) {
data = data.MKLDNNDataReshape(
Shape2(ishape.ProdShape(0, data_ndim - 1), ishape[data_ndim - 1]));
} else {
data = data.MKLDNNDataReshape(Shape2(ishape[0], ishape.ProdShape(1, data_ndim)));
}
}

// create cached out_md
const mxnet::TShape ishape = data.shape();
const mxnet::TShape oshape = output.shape();
mkldnn::memory::dims out_dims(2);
if (oshape.ndim() == 2) {
Expand All @@ -206,7 +206,7 @@ void SgMKLDNNFCOp::Forward(const OpContext &ctx,
}
mkldnn::memory::desc out_md = mkldnn::memory::desc(out_dims, get_mkldnn_type(output.dtype()),
static_cast<mkldnn::memory::format_tag>(GetDefaultFormat(2)));
cached_out_mem_ = std::make_shared<mkldnn::memory>(out_md, CpuEngine::Get()->get_engine());
cached_out_mem_ = std::make_shared<mkldnn::memory>(out_md, engine);

bool support_channelwise_scale = false;
if (mkldnn_param.quantized) {
@@ -229,15 +229,15 @@ void SgMKLDNNFCOp::Forward(const OpContext &ctx,
// True True True
// True False Error
// False True/False False
if (channel_wise && !support_channelwise_scale) {
if (channel_wise_runtime_ && !support_channelwise_scale) {
LOG(FATAL)
<< "Currently, channel-wise quantization requires fuse requantize or dequantize."
<< " Please make sure the `min_calib_range` and `max_calib_range` are set when only"
<< " fuse requantize (outputs of FullyConnected are collected during calibration phase),"
<< " or the env var of `MXNET_DISABLE_MKLDNN_QFC_FLOAT_OUTPUT` and "
<< " `MXNET_DISABLE_MKLDNN_QFC_FUSE_ALL` are not set to true (default is false)";
}
support_channelwise_scale = support_channelwise_scale && channel_wise;
support_channelwise_scale = support_channelwise_scale && channel_wise_runtime_;

if (support_channelwise_scale) {
MSHADOW_REAL_TYPE_SWITCH(cached_weight_.dtype(), DType, {
@@ -329,30 +329,38 @@ void SgMKLDNNFCOp::Forward(const OpContext &ctx,
has_bias ? &bias_md : nullptr,
1, data_scale_, weight_scales_, false);
} else {
cached_weight_ = NDArray(fwd_->fwd_pd.weights_desc());
auto cached_weight_mem = cached_weight_.GetMKLDNNData();
auto def_weight_mem = weight.GetMKLDNNData();
std::unordered_map<int, mkldnn::memory> args(
{{MKLDNN_ARG_FROM, *def_weight_mem},
{MKLDNN_ARG_TO, *cached_weight_mem}});
MKLDNNStream::Get()->RegisterPrimArgs(
mkldnn::reorder(*def_weight_mem, *cached_weight_mem), args);
const auto def_weight_mem = weight.GetMKLDNNData();
if (def_weight_mem->get_desc() != fwd_->fwd_pd.weights_desc()) {
cached_weight_ = NDArray(fwd_->fwd_pd.weights_desc());
auto cached_weight_mem = cached_weight_.GetMKLDNNData();
std::unordered_map<int, mkldnn::memory> args(
{{MKLDNN_ARG_FROM, *def_weight_mem},
{MKLDNN_ARG_TO, *cached_weight_mem}});
MKLDNNStream::Get()->RegisterPrimArgs(
mkldnn::reorder(*def_weight_mem, *cached_weight_mem), args);
}
}

args_[MKLDNN_ARG_SRC] = *data.GetMKLDNNData();
const auto data_mem = data.GetMKLDNNData();
cached_data_mem_ = std::make_shared<mkldnn::memory>(data_mem->get_desc(), engine);

args_[MKLDNN_ARG_SRC] = *cached_data_mem_;
args_[MKLDNN_ARG_WEIGHTS] = *cached_weight_.GetMKLDNNData();
if (has_bias)
args_[MKLDNN_ARG_BIAS] = *cached_bias_.GetMKLDNNData();
args_[MKLDNN_ARG_DST] = *cached_out_mem_;
initialized_ = true;
}

auto data_mem = data.GetMKLDNNDataReorder(fwd_->fwd_pd.src_desc());
if (reorder_data_) {
data = data.Reorder2Default();
}
MSHADOW_TYPE_SWITCH(data.dtype(), DType, {
cached_data_mem_->set_data_handle(reinterpret_cast<void *>(data.data().dptr<DType>()));
});
MSHADOW_TYPE_SWITCH(output.dtype(), DType, {
cached_out_mem_->set_data_handle(reinterpret_cast<void *>(output.data().dptr<DType>()));
});
args_[MKLDNN_ARG_SRC] = *data_mem;
args_[MKLDNN_ARG_DST] = *cached_out_mem_;
MKLDNNStream::Get()->RegisterPrimArgs(fwd_->GetFwd(), args_);
MKLDNNStream::Get()->Submit();

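A second pattern running through this diff is to create the MKL-DNN memory objects once during initialization (cached_data_mem_, cached_out_mem_) and, on later calls, only re-point them at the current buffers via set_data_handle instead of rebuilding them. A rough sketch of that idea, assuming the MKL-DNN 1.x C++ API this file uses and with illustrative names:

```cpp
#include <memory>
#include <mkldnn.hpp>

// Sketch only: build the memory object once from a fixed descriptor, then
// swap the underlying data pointer on each call instead of recreating it.
struct CachedMem {
  std::shared_ptr<mkldnn::memory> mem;

  void Init(const mkldnn::memory::desc &md, const mkldnn::engine &eng) {
    mem = std::make_shared<mkldnn::memory>(md, eng);
  }

  // `buffer` is assumed to match the cached descriptor's size and layout.
  void Bind(void *buffer) { mem->set_data_handle(buffer); }
};
```

Avoiding the per-call rebuild of these memory objects and weight reorders is part of the overhead reduction the commit title refers to.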
