add int8 bn mkldnn implementation and test (apache#15664)
* add int8 bn mkldnn implementation and test

* fix lint

* fix ci

* enable int8 bn test only in mkldnn backend

* disable int8 bn forward test with gpu backend

* update int8 bn with reference to comments

* fix lint

* disable int8 bn gluon forward test with gpu backend

* disable uint8 bn forward test with mkldnn backend

* restore support mkldnn bn condition

* rm duplicate code
ElaineBao authored and Ubuntu committed Aug 20, 2019
1 parent 3e07e2f commit dfef29f
Showing 7 changed files with 446 additions and 28 deletions.
1 change: 1 addition & 0 deletions cpp-package/scripts/OpWrapperGenerator.py
@@ -93,6 +93,7 @@ class Arg:
'int (non-negative)': 'uint32_t',\
'long (non-negative)': 'uint64_t',\
'int or None':'dmlc::optional<int>',\
+'float or None':'dmlc::optional<float>',\
'long':'int64_t',\
'double':'double',\
'double or None':'dmlc::optional<double>',\
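
This one-line mapping means that "float or None" operator arguments, such as the calibration ranges introduced below, surface in the generated C++ API as dmlc::optional<float>. A minimal sketch of the resulting semantics, assuming only dmlc-core on the include path:

#include <dmlc/optional.h>
#include <iostream>

int main() {
  dmlc::optional<float> min_calib_range;             // default-constructed: 'None'
  std::cout << min_calib_range.has_value() << "\n";  // 0, i.e. not calibrated yet
  min_calib_range = -1.25f;                          // e.g. filled from a calib table
  if (min_calib_range.has_value()) {
    std::cout << min_calib_range.value() << "\n";    // -1.25
  }
  return 0;
}
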
52 changes: 41 additions & 11 deletions src/operator/nn/batch_norm-inl.h
@@ -53,11 +53,19 @@ enum BatchNormOpOutputs {kOut, kMean, kVar};  // req, out_data
enum BatchNormOpResource {kTempSpace};
enum BatchNormOpAuxiliary {kMovingMean, kMovingVar}; // aux_states

-/*! \brief Default channel axis if none specified int he params */
+/*! \brief Default channel axis if none specified in the params */
constexpr int DEFAULT_AXIS = 1;
} // namespace batchnorm

-/*! \brief Parameters for BatchNorm operator */
+namespace quantized_batchnorm {
+enum QuantizedBatchNormOpInputs {kData, kGamma, kBeta, kInMovingMean,
+                                 kInMovingVar, kDataMin, kDataMax};
+enum QuantizedBatchNormOutputs {kOut, kOutMin, kOutMax};
+enum QuantizedBatchNormOpAuxiliary {kMovingMean, kMovingVar};
+}  // namespace quantized_batchnorm
+
+/*! \brief Parameters for BatchNorm operator */
struct BatchNormParam : public dmlc::Parameter<BatchNormParam> {
double eps;
float momentum;
@@ -66,6 +74,10 @@ struct BatchNormParam : public dmlc::Parameter<BatchNormParam> {
bool output_mean_var;
int axis;
bool cudnn_off;

+  dmlc::optional<float> min_calib_range;  // min float value calculated from calibration dataset
+  dmlc::optional<float> max_calib_range;  // max float value calculated from calibration dataset

DMLC_DECLARE_PARAMETER(BatchNormParam) {
DMLC_DECLARE_FIELD(eps).set_default(1e-3f)
.describe("Epsilon to prevent div 0. "
@@ -81,19 +93,37 @@ struct BatchNormParam : public dmlc::Parameter<BatchNormParam> {
DMLC_DECLARE_FIELD(output_mean_var).set_default(false)
.describe("Output the mean and inverse std ");
DMLC_DECLARE_FIELD(axis).set_default(mxnet::op::batchnorm::DEFAULT_AXIS)
.describe("Specify which shape axis the channel is specified");
.describe("Specify which shape axis the channel is specified");
DMLC_DECLARE_FIELD(cudnn_off).set_default(false)
.describe("Do not select CUDNN operator, if available");
.describe("Do not select CUDNN operator, if available");
+    DMLC_DECLARE_FIELD(min_calib_range)
+    .set_default(dmlc::optional<float>())
+    .describe("The minimum scalar value in the form of float32 obtained "
+              "through calibration. If present, it will be used by the "
+              "quantized batch norm op to calculate the primitive scale. "
+              "Note: this calib_range is for calibrating the bn output.");
+    DMLC_DECLARE_FIELD(max_calib_range)
+    .set_default(dmlc::optional<float>())
+    .describe("The maximum scalar value in the form of float32 obtained "
+              "through calibration. If present, it will be used by the "
+              "quantized batch norm op to calculate the primitive scale. "
+              "Note: this calib_range is for calibrating the bn output.");
}

-  bool operator==(const BatchNormParam& other) const {
-    return this->eps == other.eps &&
-           this->momentum == other.momentum &&
-           this->fix_gamma == other.fix_gamma &&
-           this->use_global_stats == other.use_global_stats &&
-           this->output_mean_var == other.output_mean_var &&
-           this->axis == other.axis &&
-           this->cudnn_off == other.cudnn_off;
+  bool operator==(const BatchNormParam &other) const {
+    bool flag = this->eps == other.eps && this->momentum == other.momentum &&
+                this->fix_gamma == other.fix_gamma &&
+                this->use_global_stats == other.use_global_stats &&
+                this->output_mean_var == other.output_mean_var && this->axis == other.axis &&
+                this->cudnn_off == other.cudnn_off &&
+                this->min_calib_range.has_value() == other.min_calib_range.has_value() &&
+                this->max_calib_range.has_value() == other.max_calib_range.has_value();
+    if (this->min_calib_range.has_value() && other.min_calib_range.has_value() &&
+        this->max_calib_range.has_value() && other.max_calib_range.has_value()) {
+      flag = flag && this->min_calib_range.value() == other.min_calib_range.value() &&
+             this->max_calib_range.value() == other.max_calib_range.value();
+    }
+    return flag;
}
};

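
SetCalibTableToQuantizedGraph (later in this commit) delivers these two ranges as string attributes and re-runs the attr parser, which is why they are declared as optional fields with a 'None' default. A self-contained sketch of that string-to-field round trip, using a stripped-down stand-in for the real param struct (ToyBNParam and the sample values are illustrative; only dmlc-core is required):

#include <dmlc/parameter.h>
#include <iostream>
#include <map>
#include <string>

struct ToyBNParam : public dmlc::Parameter<ToyBNParam> {
  dmlc::optional<float> min_calib_range;
  dmlc::optional<float> max_calib_range;
  DMLC_DECLARE_PARAMETER(ToyBNParam) {
    DMLC_DECLARE_FIELD(min_calib_range).set_default(dmlc::optional<float>());
    DMLC_DECLARE_FIELD(max_calib_range).set_default(dmlc::optional<float>());
  }
};
DMLC_REGISTER_PARAMETER(ToyBNParam);

int main() {
  // What SetCalibTableToQuantizedGraph writes into node->attrs.dict.
  std::map<std::string, std::string> attrs = {{"min_calib_range", "-0.8"},
                                              {"max_calib_range", "0.9"}};
  ToyBNParam param;
  param.Init(attrs);  // the attr parser performs the same conversion
  std::cout << param.min_calib_range.value() << " "
            << param.max_calib_range.value() << std::endl;  // -0.8 0.9
  return 0;
}
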
29 changes: 16 additions & 13 deletions src/operator/nn/mkldnn/mkldnn_batch_norm-inl.h
@@ -132,8 +132,8 @@ class MKLDNNBNForward {
return *var_m;
}

-  void SetDataHandle(const NDArray &data, const NDArray &mean,
-                     const NDArray &var, const mkldnn::memory &out) {
+  void SetDataHandle(const NDArray &data, const mkldnn::memory *mean,
+                     const mkldnn::memory *var, const mkldnn::memory *out) {
auto _data = data.GetMKLDNNData();
if (data_m) {
data_m->set_data_handle(_data->get_data_handle());
@@ -142,24 +142,22 @@
_data->get_data_handle()));
}
if (out_m) {
-      out_m->set_data_handle(out.get_data_handle());
+      out_m->set_data_handle(out->get_data_handle());
    } else {
-      out_m.reset(new mkldnn::memory(out.get_primitive_desc(),
-                                     out.get_data_handle()));
+      out_m.reset(new mkldnn::memory(out->get_primitive_desc(),
+                                     out->get_data_handle()));
    }
-    auto mean_ptr = mean.data().dptr_;
    if (mean_m) {
-      mean_m->set_data_handle(mean_ptr);
+      mean_m->set_data_handle(mean->get_data_handle());
    } else {
-      mean_m.reset(new mkldnn::memory(pd.mean_primitive_desc(),
-                                      mean_ptr));
+      mean_m.reset(new mkldnn::memory(mean->get_primitive_desc(),
+                                      mean->get_data_handle()));
    }
-    auto var_ptr = var.data().dptr_;
    if (var_m) {
-      var_m->set_data_handle(var_ptr);
+      var_m->set_data_handle(var->get_data_handle());
    } else {
-      var_m.reset(new mkldnn::memory(pd.variance_primitive_desc(),
-                                     var_ptr));
+      var_m.reset(new mkldnn::memory(var->get_primitive_desc(),
+                                     var->get_data_handle()));
    }
}

if (fwd == nullptr) {
@@ -175,6 +173,11 @@
}
}

+  void SetDataHandle(const NDArray &data, const NDArray &mean,
+                     const NDArray &var, const mkldnn::memory &out) {
+    SetDataHandle(data, mean.GetMKLDNNData(), var.GetMKLDNNData(), &out);
+  }

const mkldnn::batch_normalization_forward &GetFwd() const {
return *fwd;
}
123 changes: 123 additions & 0 deletions src/operator/quantization/mkldnn/mkldnn_quantized_batch_norm.cc
@@ -0,0 +1,123 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file mkldnn_quantized_batch_norm.cc
* \brief Quantized batch normalization implemented with MKL-DNN
* \author Yixin Bao
*/

#if MXNET_USE_MKLDNN == 1
#include "../../nn/mkldnn/mkldnn_batch_norm-inl.h"
#include "../quantization_utils.h"

namespace mxnet {
namespace op {

static void MKLDNNQuantizedBatchNormForward(const nnvm::NodeAttrs &attrs, const OpContext &ctx,
const std::vector<NDArray> &in_data,
const std::vector<OpReqType> &req,
const std::vector<NDArray> &outputs) {
CHECK_EQ(in_data.size(), 7U);
CHECK_EQ(outputs.size(), 3U);

TmpMemMgr::Get()->Init(ctx.requested[batchnorm::kTempSpace]);
const BatchNormParam &param = nnvm::get<BatchNormParam>(attrs.parsed);
const NDArray &data = in_data[quantized_batchnorm::kData];
const size_t channelAxis = static_cast<size_t>(
param.axis < 0 ? static_cast<int>(data.shape().ndim()) + param.axis : param.axis);
const int channel_count = data.shape()[channelAxis];
const float min_data = in_data[quantized_batchnorm::kDataMin].data().dptr<float>()[0];
const float max_data = in_data[quantized_batchnorm::kDataMax].data().dptr<float>()[0];
const float max_abs_data = std::max(std::abs(min_data), std::abs(max_data));

float *min_output_ptr = outputs[quantized_batchnorm::kOutMin].data().dptr<float>();
float *max_output_ptr = outputs[quantized_batchnorm::kOutMax].data().dptr<float>();
if (param.min_calib_range.has_value() && param.max_calib_range.has_value()) {
*max_output_ptr = param.max_calib_range.value();
*min_output_ptr = param.min_calib_range.value();
} else {
    LOG(FATAL) << "min_calib_range or max_calib_range is not available. Quantized BN currently "
                  "doesn't support calib_mode=None";
}
const float max_abs_output = std::max(std::abs(*min_output_ptr), std::abs(*max_output_ptr));

unsigned flags = mkldnn::use_global_stats | mkldnn::use_scale_shift;
auto &fwd = GetBNForward<float>(param, ctx, data, flags);
const mkldnn::memory &weight_mem = fwd.GetWeight();
CHECK_EQ(weight_mem.get_primitive_desc().get_size(), channel_count * sizeof(float) * 2);
float *weight_buf = reinterpret_cast<float *>(weight_mem.get_data_handle());

float *gamma_ptr = in_data[quantized_batchnorm::kGamma].data().dptr<float>();
float *beta_ptr = in_data[quantized_batchnorm::kBeta].data().dptr<float>();

const NDArray &moving_mean = in_data[quantized_batchnorm::kInMovingMean];
const NDArray &moving_var = in_data[quantized_batchnorm::kInMovingVar];
float *moving_mean_ptr = moving_mean.data().dptr<float>();
float *moving_var_ptr = moving_var.data().dptr<float>();

// rescale gamma and beta, to make mean=0 and var=1
auto rescaled_mean_mem =
TmpMemMgr::Get()->Alloc(moving_mean.GetMKLDNNData()->get_primitive_desc());
auto rescaled_var_mem = TmpMemMgr::Get()->Alloc(moving_var.GetMKLDNNData()->get_primitive_desc());
float *rescaled_mean_ptr = reinterpret_cast<float *>(rescaled_mean_mem->get_data_handle());
float *rescaled_var_ptr = reinterpret_cast<float *>(rescaled_var_mem->get_data_handle());

#pragma omp parallel for num_threads(engine::OpenMP::Get()->GetRecommendedOMPThreadCount())
for (int channel = 0; channel < channel_count; ++channel) {
float invstd = 1.0 / std::sqrt(moving_var_ptr[channel] + param.eps);
weight_buf[channel] = gamma_ptr[channel] * invstd * max_abs_data / max_abs_output;
weight_buf[channel_count + channel] =
(beta_ptr[channel] - moving_mean_ptr[channel] * gamma_ptr[channel] * invstd) * kInt8Range /
max_abs_output;
rescaled_mean_ptr[channel] = 0.0f;
rescaled_var_ptr[channel] = 1.0f;
}

auto out_mem = CreateMKLDNNMem(outputs[batchnorm::kOut],
fwd.GetPd().dst_primitive_desc(), req[batchnorm::kOut], &data);
fwd.SetDataHandle(data, rescaled_mean_mem, rescaled_var_mem, out_mem.second);

MKLDNNStream::Get()->RegisterPrim(fwd.GetFwd());
MKLDNNStream::Get()->Submit();
}

inline static bool QuantizedBatchNormStorageType(const nnvm::NodeAttrs &attrs, const int dev_mask,
DispatchMode *dispatch_mode,
std::vector<int> *in_attrs,
std::vector<int> *out_attrs) {
bool dispatched = false;
if (!dispatched) {
dispatched = MKLDNNStorageType(attrs, dev_mask, true, dispatch_mode, in_attrs, out_attrs);
}
return dispatched;
}

NNVM_REGISTER_OP(_contrib_quantized_batch_norm)
.set_attr<FInferStorageType>("FInferStorageType", QuantizedBatchNormStorageType)
.set_attr<FComputeEx>("FComputeEx<cpu>", MKLDNNQuantizedBatchNormForward)
.set_attr<FResourceRequest>("FResourceRequest", [](const NodeAttrs& n) {
return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
})
.set_attr<bool>("TIsMKLDNN", true);

} // namespace op
} // namespace mxnet

#endif // MXNET_USE_MKLDNN == 1
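
The per-channel loop above folds quantization into BN's affine form. With x_q = x_f * kInt8Range / max_abs_data and the target y_q = y_f * kInt8Range / max_abs_output, substituting y_f = gamma * invstd * (x_f - mean) + beta yields exactly the scale gamma * invstd * max_abs_data / max_abs_output and the shift (beta - mean * gamma * invstd) * kInt8Range / max_abs_output written into weight_buf, while the primitive itself sees mean=0 and var=1. A self-contained check of that algebra (plain C++, no MKL-DNN; kInt8Range and all sample numbers are assumptions for illustration):

#include <cmath>
#include <cstdio>

int main() {
  const float kInt8Range = 127.0f;  // assumed here; the real constant lives in quantization_utils.h
  const float eps = 1e-3f;
  // One channel's BN statistics and calibrated ranges (sample values).
  const float gamma = 1.3f, beta = 0.2f, mean = 0.5f, var = 4.0f;
  const float max_abs_data = 6.0f;    // from kDataMin / kDataMax
  const float max_abs_output = 8.0f;  // from min/max_calib_range
  const float invstd = 1.0f / std::sqrt(var + eps);

  // Folded per-channel scale and shift, as written into weight_buf above.
  const float scale = gamma * invstd * max_abs_data / max_abs_output;
  const float shift = (beta - mean * gamma * invstd) * kInt8Range / max_abs_output;

  // Reference path: float BN, then quantize against the output range.
  const float x_f = 3.7f;
  const float y_f = gamma * (x_f - mean) * invstd + beta;
  const float y_q_ref = y_f * kInt8Range / max_abs_output;

  // Folded path: one multiply-add on the already-quantized input.
  const float x_q = std::round(x_f * kInt8Range / max_abs_data);
  const float y_q = scale * x_q + shift;

  // Agreement up to the rounding of x_q (the primitive's own eps on the
  // rescaled var adds only a negligible 1/sqrt(1 + eps) factor, ignored here).
  std::printf("reference %.4f vs folded %.4f\n", y_q_ref, y_q);
  return 0;
}
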
42 changes: 39 additions & 3 deletions src/operator/quantization/quantize_graph_pass.cc
@@ -37,6 +37,19 @@ using nnvm::NodePtr;
using nnvm::NodeEntry;
using nnvm::Graph;

+inline size_t GetNumOutputs(NodePtr node) {
+  // Return the node's visible output count: use FNumVisibleOutputs when the op
+  // registers it, otherwise fall back to num_outputs().
+  size_t num_outputs = node->num_outputs();
+  static const auto& num_visible_outputs_attr =
+      nnvm::Op::GetAttr<nnvm::FNumVisibleOutputs>("FNumVisibleOutputs");
+  auto num_visible_output_func = num_visible_outputs_attr.get(node->op(), nullptr);
+  if (num_visible_output_func != nullptr) {
+    num_outputs = num_visible_output_func(node->attrs);
+  }
+  return num_outputs;
+}

NodePtr CreateNode(std::string op_name, std::string node_name) {
NodePtr node = Node::Create();
node->attrs.name = node_name;
@@ -223,7 +236,7 @@ Graph QuantizeGraph(Graph &&src) {
// calculate min/max index from mirror node, based on the assumption that
// there is only 1 min and 1 max output from mirror node (which is
// currently true)
-          size_t num_outputs = mirror_node->num_outputs() - 2;
+          size_t num_outputs = GetNumOutputs(mirror_node) - 2;
min_index = num_outputs + 2 * e.index;
max_index = num_outputs + 2 * e.index + 1;
} else {
@@ -276,7 +289,7 @@ Graph QuantizeGraph(Graph &&src) {
// calculate min/max index from mirror node, based on the assumption that
// there is only 1 min and 1 max output from mirror node (which is
// currently true)
-          size_t num_outputs = mirror_node->num_outputs() - 2;
+          size_t num_outputs = GetNumOutputs(mirror_node) - 2;
uint32_t min_index = num_outputs + 2 * e.index;
uint32_t max_index = num_outputs + 2 * e.index + 1;
NodePtr dequantize_node = CreateNode("_contrib_dequantize",
@@ -309,7 +322,7 @@ Graph QuantizeGraph(Graph &&src) {
// calculate min/max index from mirror node, based on the assumption that
// there is only 1 min and 1 max output from mirror node (which is
// currently true)
-        size_t num_outputs = e.node->num_outputs();
+        size_t num_outputs = GetNumOutputs(e.node);
uint32_t min_index = num_outputs + 2 * e.index;
uint32_t max_index = num_outputs + 2 * e.index + 1;

@@ -403,6 +416,29 @@ Graph SetCalibTableToQuantizedGraph(Graph&& g) {
<< "` has negative input, consider use `auto` or `int8` as out_type";
}
}
+    } else if (node->op() == Op::Get("_contrib_quantized_batch_norm")) {
+      auto quantized_op_idx = node->inputs[0].index;
+      const std::string prefix = "quantized_";
+      std::string out_data_name = node->attrs.name.substr(prefix.size());
+      if (node->op()) {
+        auto list_output_names_func = flist_outputs.get(node->op(), nullptr);
+        // We want the pre-calculated min_range and max_range from the calibration
+        // table for out_data, so the output data name is constructed here exactly
+        // as it is constructed in GraphExecutor::ExecuteMonCallback.
+        if (list_output_names_func != nullptr) {
+          std::vector<std::string> names = list_output_names_func(node->attrs);
+          out_data_name += "_" + names[quantized_op_idx];
+        } else {
+          out_data_name += "_" + std::to_string(quantized_op_idx);
+        }
+      }
+
+      const auto calib_table_iter = calib_table.find(out_data_name);
+      if (calib_table_iter != calib_table.end()) {
+        node->attrs.dict["min_calib_range"] = std::to_string(calib_table_iter->second.first);
+        node->attrs.dict["max_calib_range"] = std::to_string(calib_table_iter->second.second);
+        node->op()->attr_parser(&(node->attrs));
+      }
+    }
});
return g;
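
The key looked up in calib_table for the quantized BN branch is the fp32 node name (the quantized_ prefix stripped) joined with the op's output name, mirroring the tensor naming in GraphExecutor::ExecuteMonCallback. A self-contained sketch of that key construction (the node name and output list are hypothetical):

#include <iostream>
#include <string>
#include <vector>

int main() {
  const std::string node_name = "quantized_bn0";  // hypothetical quantized node
  const std::string prefix = "quantized_";
  std::string out_data_name = node_name.substr(prefix.size());  // "bn0"
  std::vector<std::string> names = {"output"};  // as FListOutputNames might report
  unsigned quantized_op_idx = 0;                // node->inputs[0].index
  out_data_name += "_" + names[quantized_op_idx];
  std::cout << out_data_name << std::endl;      // "bn0_output"
  return 0;
}
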
(Diffs for the remaining two changed files are not shown.)
