This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[MKLDNN] Enable more convolution + activation fusion #14819

Merged
merged 12 commits on May 17, 2019
4 changes: 2 additions & 2 deletions example/quantization/imagenet_gen_qsym_mkldnn.py
@@ -179,7 +179,7 @@ def save_params(fname, arg_params, aux_params, logger=None):
prefix, epoch = download_model(model_name=args.model, logger=logger)
sym, arg_params, aux_params = mx.model.load_checkpoint(prefix, epoch)

sym = sym.get_backend_symbol('MKLDNN')
sym = sym.get_backend_symbol('MKLDNN_QUANTIZE')

# get batch size
batch_size = args.batch_size
@@ -301,7 +301,7 @@ def save_params(fname, arg_params, aux_params, logger=None):
raise ValueError('unknow calibration mode %s received, only supports `none`, `naive`, and `entropy`'
% calib_mode)
sym_name = '%s-symbol.json' % (prefix + suffix)
qsym = qsym.get_backend_symbol('MKLDNN_POST_QUANTIZE')
qsym = qsym.get_backend_symbol('MKLDNN_QUANTIZE')
save_symbol(sym_name, qsym, logger)
param_name = '%s-%04d.params' % (prefix + '-quantized', epoch)
save_params(param_name, qarg_params, aux_params, logger)
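For context, here is a minimal sketch (not part of this diff) of the flow these example scripts now follow: the single MKLDNN_QUANTIZE backend name is applied once before quantization and again after calibration, replacing the old MKLDNN / MKLDNN_POST_QUANTIZE pair. The checkpoint name and calibration arguments are illustrative only.

import mxnet as mx
from mxnet.contrib.quantization import quantize_model

# Load an FP32 checkpoint and run the MKLDNN fusion pass used by the quantization flow.
sym, arg_params, aux_params = mx.model.load_checkpoint('resnet50_v1', 0)
sym = sym.get_backend_symbol('MKLDNN_QUANTIZE')

# Quantize; calib_mode='none' skips calibration to keep the sketch short.
qsym, qarg_params, aux_params = quantize_model(
    sym=sym, arg_params=arg_params, aux_params=aux_params,
    ctx=mx.cpu(), calib_mode='none', quantized_dtype='int8')

# The post-quantization fusion pass now reuses the same backend name.
qsym = qsym.get_backend_symbol('MKLDNN_QUANTIZE')
mx.model.save_checkpoint('resnet50_v1-quantized', 0, qsym, qarg_params, aux_params)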
4 changes: 2 additions & 2 deletions example/ssd/quantization.py
@@ -101,7 +101,7 @@ def save_params(fname, arg_params, aux_params, logger=None):
label = mx.sym.Variable(name='label')
sym = mx.sym.Group([sym, label])

sym = sym.get_backend_symbol('MKLDNN')
sym = sym.get_backend_symbol('MKLDNN_QUANTIZE')

# get batch size
batch_size = args.batch_size
@@ -163,6 +163,6 @@ def calib_layer(name): return not (name.endswith('_data') or
label_names=(label_name,), logger=logger)
sym_name = '%s-symbol.json' % ('./model/cqssd_vgg16_reduced_300')
param_name = '%s-%04d.params' % ('./model/cqssd_vgg16_reduced_300', epoch)
qsym = qsym.get_backend_symbol('MKLDNN_POST_QUANTIZE')
qsym = qsym.get_backend_symbol('MKLDNN_QUANTIZE')
save_symbol(sym_name, qsym, logger)
save_params(param_name, qarg_params, aux_params, logger)
3 changes: 2 additions & 1 deletion python/mxnet/gluon/parameter.py
@@ -900,8 +900,9 @@ def load(self, filename, ctx=None, allow_missing=False,
"restore_prefix is '%s' but Parameters name '%s' does not start " \
"with '%s'"%(restore_prefix, name, restore_prefix)
lprefix = len(restore_prefix)
ndarray_load = ndarray.load(filename)
loaded = [(k[4:] if k.startswith('arg:') or k.startswith('aux:') else k, v) \
for k, v in ndarray.load(filename).items()]
for k, v in ndarray_load.items()] if isinstance(ndarray_load, dict) else ndarray_load
arg_dict = {restore_prefix+k: v for k, v in loaded}
if not allow_missing:
for name in self.keys():
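For reference, ndarray.load returns a dict when the file was saved with string keys (the 'arg:' / 'aux:' prefixes stripped above) but a plain list when it was saved from a list of arrays, which is why the new isinstance guard is needed. A small sketch with made-up file names:

import mxnet as mx
from mxnet import ndarray

# Saved from a dict: load() gives back a dict with the prefixed keys.
ndarray.save('with_keys.params', {'arg:weight': mx.nd.zeros((2, 2))})
print(type(ndarray.load('with_keys.params')))  # <class 'dict'>

# Saved from a list: load() gives back a list, so calling .items() on it would fail.
ndarray.save('no_keys.params', [mx.nd.zeros((2, 2))])
print(type(ndarray.load('no_keys.params')))    # <class 'list'>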
49 changes: 49 additions & 0 deletions python/mxnet/test_utils.py
@@ -80,6 +80,11 @@ def get_rtol(rtol=None):
# be needed for different device and dtype
return 1e-5 if rtol is None else rtol

def get_etol(etol=None):
"""Get default numerical threshold for regression test."""
# _TODO: get from env variable, different threshold might
# be needed for different device and dtype
return 0 if etol is None else etol

def random_arrays(*shapes):
"""Generate some random numpy arrays."""
@@ -494,6 +499,50 @@ def assert_almost_equal(a, b, rtol=None, atol=None, names=('a', 'b'), equal_nan=
names=names)
raise AssertionError(msg)

def assert_almost_equal_with_err(a, b, rtol=None, atol=None, etol=None, names=('a', 'b'), equal_nan=False):
"""Test that two numpy arrays are almost equal within given error rate. Raise exception message if not.

Parameters
----------
a : np.ndarray
b : np.ndarray
rtol : None or float
The relative tolerance. The default rtol is used if set to ``None``.
atol : None or float
The absolute tolerance. The default atol is used if set to ``None``.
etol : None or float
The error-rate tolerance. If etol is a float, the check passes as long as the
fraction of elements violating rtol/atol stays below etol.
"""
rtol = get_rtol(rtol)
atol = get_atol(atol)
etol = get_etol(etol)
if etol:
equals = np.isclose(a, b, rtol=rtol, atol=atol)
err = 1 - np.count_nonzero(equals) / equals.size
if err > etol:
index, rel = find_max_violation(a, b, rtol, atol)
np.set_printoptions(threshold=4, suppress=True)
msg = npt.build_err_msg([a, b],
err_msg="Error %f exceeds tolerance rtol=%f, atol=%f, etol=%f."
" Error_rate=%f. Location of maximum error:%s, a=%f, b=%f"
% (rel, rtol, atol, etol, err, str(index), a[index], b[index]),
names=names)
raise AssertionError(msg)

if almost_equal(a, b, rtol, atol, equal_nan=equal_nan):
return
else:
if almost_equal(a, b, rtol, atol, equal_nan=equal_nan):
return
index, rel = find_max_violation(a, b, rtol, atol)
np.set_printoptions(threshold=4, suppress=True)
msg = npt.build_err_msg([a, b],
err_msg="Error %f exceeds tolerance rtol=%f, atol=%f. "
" Location of maximum error:%s, a=%f, b=%f"
% (rel, rtol, atol, str(index), a[index], b[index]),
names=names)
raise AssertionError(msg)


def almost_equal_ignore_nan(a, b, rtol=None, atol=None):
"""Test that two NumPy arrays are almost equal (ignoring NaN in either array).
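A quick usage sketch of the new helper (not part of the diff): with etol set, a small fraction of mismatching elements is tolerated, which is what the quantized-fusion tests rely on. The arrays here are made up.

import numpy as np
from mxnet.test_utils import assert_almost_equal_with_err

a = np.zeros(100)
b = np.zeros(100)
b[:2] = 1.0  # 2% of the elements differ

# assert_almost_equal(a, b) would raise here, but allowing up to a 5% error rate passes:
assert_almost_equal_with_err(a, b, rtol=1e-3, atol=1e-3, etol=0.05)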
4 changes: 4 additions & 0 deletions src/operator/nn/mkldnn/mkldnn_act.cc
@@ -57,6 +57,10 @@ bool SupportMKLDNNAct(const ActivationParam& param, const NDArray &input) {
return SupportMKLDNNAct(param);
}

bool SupportQuantizedMKLDNNAct(const ActivationParam &param) {
return param.act_type == activation::kReLU;
}

mkldnn::algorithm GetMKLDNNActAlgo(const ActivationParam& param) {
switch (param.act_type) {
case activation::kReLU:
3 changes: 2 additions & 1 deletion src/operator/nn/mkldnn/mkldnn_base-inl.h
@@ -178,7 +178,8 @@ struct SoftmaxOutputParam;
struct TransposeParam;
bool SupportMKLDNNAct(const ActivationParam& param);
bool SupportMKLDNNAct(const ActivationParam& param, const NDArray &input);
bool SupportMKLDNNConv(const ConvolutionParam& params, const NDArray &input);
bool SupportQuantizedMKLDNNAct(const ActivationParam &param);
bool SupportMKLDNNConv(const ConvolutionParam &params, const NDArray &input);
bool SupportMKLDNNDeconv(const DeconvolutionParam& params, const NDArray &input);
bool SupportMKLDNNSoftmax(const SoftmaxParam& param);
bool SupportMKLDNNSoftmaxOutput(const SoftmaxOutputParam &param);
28 changes: 16 additions & 12 deletions src/operator/nn/mkldnn/mkldnn_convolution-inl.h
@@ -38,9 +38,9 @@ namespace op {

struct MKLDNNConvParam : public dmlc::Parameter<MKLDNNConvParam> {
bool with_bn;
bool with_relu;
bool with_act;
bool with_sum;
bool with_postsum_relu;
bool with_postsum_act;
bool quantized;

dmlc::optional<float> min_calib_range; // min float value calculated from calibration dataset
@@ -49,12 +49,12 @@ struct MKLDNNConvParam : public dmlc::Parameter<MKLDNNConvParam> {
DMLC_DECLARE_PARAMETER(MKLDNNConvParam) {
DMLC_DECLARE_FIELD(with_bn).set_default(false)
.describe("Add post batchnorm.");
DMLC_DECLARE_FIELD(with_relu).set_default(false)
.describe("Add post relu");
DMLC_DECLARE_FIELD(with_act).set_default(false)
.describe("Add post activation");
DMLC_DECLARE_FIELD(with_sum).set_default(false)
.describe("Add post sum");
DMLC_DECLARE_FIELD(with_postsum_relu).set_default(false)
.describe("Add post relu after sum");
DMLC_DECLARE_FIELD(with_postsum_act).set_default(false)
.describe("Add post activation after sum");
DMLC_DECLARE_FIELD(quantized).set_default(false)
.describe("enable quantization");
DMLC_DECLARE_FIELD(min_calib_range)
@@ -70,18 +70,22 @@ struct MKLDNNConvParam : public dmlc::Parameter<MKLDNNConvParam> {
}
};

struct MKLDNNPostActParam {
mkldnn::algorithm alg = mkldnn::algorithm::algorithm_undef;
float scale = 1.f;
float alpha = 0.f;
float beta = 1.f;
};

struct MKLDNNConvFullParam {
ConvolutionParam conv_param;
MKLDNNConvParam mkldnn_param;
float sum_scale;
float sum_scale = 1.f;
std::vector<float> requantize_scales;
MKLDNNPostActParam act_param;
MKLDNNPostActParam postsum_act_param;
};

static inline bool IsOutputUInt8(const MKLDNNConvParam &mkldnn_param) {
return ((!mkldnn_param.with_sum) && mkldnn_param.with_relu) ||
mkldnn_param.with_postsum_relu;
}

mkldnn::convolution_forward::primitive_desc GetConvFwdImpl(const MKLDNNConvFullParam &param,
const bool is_train,
const NDArray &data,
16 changes: 6 additions & 10 deletions src/operator/nn/mkldnn/mkldnn_convolution.cc
@@ -82,20 +82,16 @@ mkldnn::convolution_forward::primitive_desc GetConvFwdImpl(const MKLDNNConvFullP
}
mkldnn::primitive_attr attr;
mkldnn::post_ops ops;
if (param.mkldnn_param.with_relu) {
float scale = 1.0f; // for fp32, scale is 1.
float alpha = 0.0f; // negative slope for mkldnn_eltwise_relu.
float beta = 1.0f; // ignored for mkldnn_eltwise_relu.
ops.append_eltwise(scale, eltwise_relu, alpha, beta);
if (param.mkldnn_param.with_act) {
const auto &act_param = param.act_param;
ops.append_eltwise(act_param.scale, act_param.alg, act_param.alpha, act_param.beta);
}
if (param.mkldnn_param.with_sum) {
ops.append_sum(param.sum_scale);
}
if (param.mkldnn_param.with_postsum_relu) {
float scale = 1.0f; // for fp32, scale is 1.
float alpha = 0.0f; // negative slope for mkldnn_eltwise_relu.
float beta = 1.0f; // ignored for mkldnn_eltwise_relu.
ops.append_eltwise(scale, eltwise_relu, alpha, beta);
if (param.mkldnn_param.with_postsum_act) {
const auto &act_param = param.postsum_act_param;
ops.append_eltwise(act_param.scale, act_param.alg, act_param.alpha, act_param.beta);
}
attr.set_post_ops(ops);

1 change: 1 addition & 0 deletions src/operator/subgraph/build_subgraph.cc
@@ -299,6 +299,7 @@ void PreSelectSubgraphNodes(const nnvm::Graph& g, SubgraphSelectorV2Ptr subgraph
}
LOG(INFO) << "Found a cycle when BFS from node " << simple_nodes[snid]->node->attrs.name
<< ". Excluding nodes " << excluded_node_names << "and retrying";
subgraph_selector->Reset();
}
++count;
}
24 changes: 22 additions & 2 deletions src/operator/subgraph/mkldnn/mkldnn_conv-inl.h
@@ -21,11 +21,12 @@
#define MXNET_OPERATOR_SUBGRAPH_MKLDNN_MKLDNN_CONV_INL_H_
#if MXNET_USE_MKLDNN == 1

#include <string>
#include <utility>
#include <vector>
#include <string>
#include "../../nn/convolution-inl.h"
#include "../../nn/activation-inl.h"
#include "../../nn/batch_norm-inl.h"
#include "../../nn/convolution-inl.h"
#include "../../nn/mkldnn/mkldnn_convolution-inl.h"

namespace mxnet {
@@ -36,6 +37,25 @@ struct MKLDNNConvFusionParam {
std::shared_ptr<BatchNormParam> bn_param;
};

static inline bool IsOutputUInt8(const MKLDNNConvFusionParam& param) {
bool result = false;
const auto& mkldnn_param = param.full_conv_param.mkldnn_param;
auto IsOutputUInt8Helper = [](const mkldnn::algorithm& act_alg) {
return (act_alg == mkldnn::algorithm::eltwise_relu ||
act_alg == mkldnn::algorithm::eltwise_logistic ||
act_alg == mkldnn::algorithm::eltwise_soft_relu ||
act_alg == mkldnn::algorithm::eltwise_bounded_relu);
};
if ((!mkldnn_param.with_sum) && mkldnn_param.with_act) {
CHECK(param.full_conv_param.act_param.alg != mkldnn::algorithm::algorithm_undef);
result = IsOutputUInt8Helper(param.full_conv_param.act_param.alg);
} else if (mkldnn_param.with_postsum_act) {
CHECK(param.full_conv_param.postsum_act_param.alg != mkldnn::algorithm::algorithm_undef);
result = IsOutputUInt8Helper(param.full_conv_param.postsum_act_param.alg);
}
return result;
}

enum MKLDNNConvOpOutputs { kOut, kMin, kMax };

} // namespace op
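To make the rule above concrete, here is a small Python sketch (an illustration, not code from this PR) of what IsOutputUInt8 decides: a relu-family post activation (relu, sigmoid, soft_relu, bounded_relu) guarantees a non-negative output that can be requantized to uint8, but once a sum is fused only a relu-like post-sum activation keeps that guarantee.

def expected_out_dtype(with_sum, act_is_relu_like, postsum_act_is_relu_like):
    # Mirrors IsOutputUInt8: the pre-sum activation only counts when there is no sum;
    # with a sum, the post-sum activation decides whether the result stays non-negative.
    if (not with_sum and act_is_relu_like) or postsum_act_is_relu_like:
        return 'uint8'
    return 'int8'

assert expected_out_dtype(False, True, False) == 'uint8'  # conv + relu
assert expected_out_dtype(True, True, False) == 'int8'    # conv + relu + sum (sum may add negatives)
assert expected_out_dtype(True, False, True) == 'uint8'   # conv + sum + relu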
71 changes: 68 additions & 3 deletions src/operator/subgraph/mkldnn/mkldnn_conv.cc
@@ -27,6 +27,8 @@
#include "../../nn/mkldnn/mkldnn_ops-inl.h"
#include "../../quantization/quantization_utils.h"
#include "mkldnn_conv-inl.h"
#include "../../nn/mkldnn/mkldnn_act-inl.h"
#include "../../tensor/matrix_op-inl.h"

namespace mxnet {
namespace op {
@@ -294,7 +296,6 @@ void SgMKLDNNConvOperator::Forward(const OpContext &ctx,
cached_data_max_ = data_max;
cached_sum_min_ = sum_min;
cached_sum_max_ = sum_max;
full_conv_param.sum_scale = 1.0;
cached_weight_ = inputs[in_weight].Reorder2Default();
weight_ver_ = inputs[in_weight].version();
if (!conv_param.no_bias) {
@@ -348,7 +349,7 @@ void SgMKLDNNConvOperator::Forward(const OpContext &ctx,
sum_in_scale = quantized_sum_range / MaxAbs(cached_sum_min_, cached_sum_max_);
}
if (post_requantize_) {
quantized_out_range = IsOutputUInt8(mkldnn_param) ? kUint8Range : kInt8Range;
quantized_out_range = IsOutputUInt8(param_) ? kUint8Range : kInt8Range;
out_range = MaxAbs(cached_output_min_, cached_output_max_);
output_scale = quantized_out_range / out_range;
full_conv_param.requantize_scales.resize(weight_channelwise_scale ? channel : 1);
@@ -373,6 +374,19 @@ void SgMKLDNNConvOperator::Forward(const OpContext &ctx,
if (mkldnn_param.with_sum) {
full_conv_param.sum_scale = output_scale / sum_in_scale;
}
if (mkldnn_param.with_act &&
full_conv_param.act_param.alg == mkldnn::algorithm::eltwise_bounded_relu) {
if (mkldnn_param.with_sum) {
LOG(ERROR) << "mkldnn doesn't support conv + relu + sum fusion yet.";
full_conv_param.act_param.alpha *= output_scale;
} else {
// For conv+relu6 without sum, we don't need post_ops as output_scale can do the cut off.
mkldnn_param.with_act = false;
}
}
if (mkldnn_param.with_postsum_act) {
CHECK(full_conv_param.postsum_act_param.alg == mkldnn::algorithm::eltwise_relu);
}
}
fwd_.reset(new MKLDNNConvForward(
full_conv_param, ctx.is_train, data, cached_weight_,
@@ -385,6 +399,25 @@ void SgMKLDNNConvOperator::Forward(const OpContext &ctx,
initialized_ = true;
}

if (mkldnn_param.with_sum) {
const auto output_mem = output.GetMKLDNNData();
const auto out_mem_desc = output_mem->get_primitive_desc().desc();
const auto dst_format = fwd_->fwd_pd.dst_primitive_desc().desc().data.format;
if (out_mem_desc.data.format != dst_format) {
auto tmp_out_mem = output.GetMKLDNNDataReorder(fwd_->fwd_pd.dst_primitive_desc());
mkldnn::memory::desc data_md(
mkldnn::memory::dims(out_mem_desc.data.dims,
out_mem_desc.data.dims + out_mem_desc.data.ndims),
static_cast<mkldnn::memory::data_type>(out_mem_desc.data.data_type),
static_cast<memory::format>(dst_format));
mkldnn::memory::primitive_desc pd(data_md, CpuEngine::Get()->get_engine());
mkldnn_mem_ptr new_out_mem(new mkldnn::memory(pd, output_mem->get_data_handle()));
MKLDNNStream::Get()->RegisterMem(new_out_mem);
mxnet::MKLDNNCopy(*tmp_out_mem, new_out_mem.get());
output = NDArray(new_out_mem);
}
}

if (mkldnn_param.quantized) {
auto data_mem = data.GetMKLDNNDataReorder(fwd_->fwd_pd.src_primitive_desc());
mkldnn::memory *mem = output.CreateMKLDNNData(fwd_->fwd_pd.dst_primitive_desc());
@@ -437,6 +470,23 @@ static uint32_t SgMKLDNNConvNumInputs(const NodeAttrs &attrs) {

static void SgMKLDNNConvParamParser(nnvm::NodeAttrs *attrs) {
MKLDNNConvFusionParam param_;

// For backward compatibility, rename
// with_relu -> with_act
// with_postsum_relu -> with_postsum_act

auto old = attrs->dict.find("with_relu");
if (old != attrs->dict.end()) {
attrs->dict["with_act"] = old->second;
attrs->dict.erase(old);
}

old = attrs->dict.find("with_postsum_relu");
if (old != attrs->dict.end()) {
attrs->dict["with_postsum_act"] = old->second;
attrs->dict.erase(old);
}

try {
param_.full_conv_param.mkldnn_param.Init(attrs->dict);
} catch (const dmlc::ParamError &e) {
@@ -452,6 +502,7 @@ static void SgMKLDNNConvParamParser(nnvm::NodeAttrs *attrs) {
}
CHECK_EQ(attrs->subgraphs.size(), 1);
auto subgraph_sym = attrs->subgraphs[0];
bool with_act = false;
DFSVisit(subgraph_sym->outputs, [&](const nnvm::NodePtr &node) {
if (node->is_variable()) return;
auto &node_name = node->op()->name;
@@ -463,6 +514,20 @@ static void SgMKLDNNConvParamParser(nnvm::NodeAttrs *attrs) {
} else if (node_name == "Convolution") {
param_.full_conv_param.conv_param =
nnvm::get<ConvolutionParam>(node->attrs.parsed);
} else if (node_name == "Activation" || node_name == "clip") {
auto &post_act_param =
(param_.full_conv_param.mkldnn_param.with_act && !with_act)
? param_.full_conv_param.act_param
: param_.full_conv_param.postsum_act_param;
with_act = true;
if (node_name == "Activation") {
const auto act_param = nnvm::get<ActivationParam>(node->attrs.parsed);
post_act_param.alg = GetMKLDNNActAlgo(act_param);
} else {
const auto clip_param = nnvm::get<ClipParam>(node->attrs.parsed);
post_act_param.alg = mkldnn::algorithm::eltwise_bounded_relu;
post_act_param.alpha = clip_param.a_max;
}
}
});
attrs->parsed = std::move(param_);
@@ -605,7 +670,7 @@ static bool SgMKLDNNConvInferType(const nnvm::NodeAttrs &attrs,
}
if (param.full_conv_param.mkldnn_param.min_calib_range.has_value() &&
param.full_conv_param.mkldnn_param.max_calib_range.has_value()) {
if (IsOutputUInt8(param.full_conv_param.mkldnn_param)) {
if (IsOutputUInt8(param)) {
TYPE_ASSIGN_CHECK(*out_types, 0, mshadow::kUint8);
} else {
TYPE_ASSIGN_CHECK(*out_types, 0, mshadow::kInt8);
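To illustrate the parser changes, a hedged end-to-end sketch (not taken from this PR's tests): a Convolution followed by clip(0, 6) is the relu6 pattern that maps onto eltwise_bounded_relu, and older fused graphs that still carry with_relu / with_postsum_relu attributes keep loading because the parser rewrites them to with_act / with_postsum_act before Init(). Whether a given pattern is actually picked up is decided by the subgraph property rules, which are outside this diff, so the printed attributes below are only what one would expect on an MKLDNN build.

import json
import mxnet as mx

data = mx.sym.Variable('data')
conv = mx.sym.Convolution(data=data, num_filter=16, kernel=(3, 3),
                          no_bias=True, name='conv0')
relu6 = mx.sym.clip(conv, a_min=0.0, a_max=6.0, name='relu6')

# Run the MKLDNN fusion pass and inspect the fused node, if any.
fused = relu6.get_backend_symbol('MKLDNN_QUANTIZE')
for node in json.loads(fused.tojson())['nodes']:
    if node['op'] == '_sg_mkldnn_conv':
        print(node.get('attrs', {}))  # expect 'with_act': 'true' when the clip was fused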