[MKLDNN] Enable more convolution + activation fusion (#14819)
* conv+act fusion

* fix mobilenetv2

* fix lint

* Add comment

* Fix build

* trigger ci

* Fix CutGraphInputs

* Run CI

* run ci

* trigger
ZhennanQin authored and pengzhao-intel committed May 17, 2019
1 parent 8d6ac4a commit 0d77947
Showing 19 changed files with 545 additions and 254 deletions.
4 changes: 2 additions & 2 deletions example/quantization/imagenet_gen_qsym_mkldnn.py
@@ -187,7 +187,7 @@ def save_params(fname, arg_params, aux_params, logger=None):
prefix, epoch = download_model(model_name=args.model, logger=logger)
sym, arg_params, aux_params = mx.model.load_checkpoint(prefix, epoch)

sym = sym.get_backend_symbol('MKLDNN')
sym = sym.get_backend_symbol('MKLDNN_QUANTIZE')

# get batch size
batch_size = args.batch_size
@@ -315,7 +315,7 @@ def save_params(fname, arg_params, aux_params, logger=None):
raise ValueError('unknow calibration mode %s received, only supports `none`, `naive`, and `entropy`'
% calib_mode)
sym_name = '%s-symbol.json' % (prefix + suffix)
qsym = qsym.get_backend_symbol('MKLDNN_POST_QUANTIZE')
qsym = qsym.get_backend_symbol('MKLDNN_QUANTIZE')
save_symbol(sym_name, qsym, logger)
param_name = '%s-%04d.params' % (prefix + '-quantized', epoch)
save_params(param_name, qarg_params, aux_params, logger)
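For reference, here is a minimal sketch of the calibration flow these quantization scripts now share, with the single 'MKLDNN_QUANTIZE' backend replacing the old 'MKLDNN' / 'MKLDNN_POST_QUANTIZE' pair. The checkpoint name and the quantize_model arguments are illustrative only and are not part of this change.

import mxnet as mx
from mxnet.contrib.quantization import quantize_model

# Hypothetical FP32 checkpoint; any model saved via mx.model works the same way.
sym, arg_params, aux_params = mx.model.load_checkpoint('resnet50_v1', 0)

# First pass: fuse convolution + activation (and BN / sum) subgraphs before quantization.
sym = sym.get_backend_symbol('MKLDNN_QUANTIZE')

qsym, qarg_params, aux_params = quantize_model(
    sym=sym, arg_params=arg_params, aux_params=aux_params,
    ctx=mx.cpu(), calib_mode='none')

# Second pass: the same backend name also folds requantize/dequantize into the
# fused quantized convolutions, replacing the old 'MKLDNN_POST_QUANTIZE' step.
qsym = qsym.get_backend_symbol('MKLDNN_QUANTIZE')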
4 changes: 2 additions & 2 deletions example/ssd/quantization.py
@@ -101,7 +101,7 @@ def save_params(fname, arg_params, aux_params, logger=None):
label = mx.sym.Variable(name='label')
sym = mx.sym.Group([sym, label])

sym = sym.get_backend_symbol('MKLDNN')
sym = sym.get_backend_symbol('MKLDNN_QUANTIZE')

# get batch size
batch_size = args.batch_size
@@ -163,6 +163,6 @@ def calib_layer(name): return not (name.endswith('_data') or
label_names=(label_name,), logger=logger)
sym_name = '%s-symbol.json' % ('./model/cqssd_vgg16_reduced_300')
param_name = '%s-%04d.params' % ('./model/cqssd_vgg16_reduced_300', epoch)
qsym = qsym.get_backend_symbol('MKLDNN_POST_QUANTIZE')
qsym = qsym.get_backend_symbol('MKLDNN_QUANTIZE')
save_symbol(sym_name, qsym, logger)
save_params(param_name, qarg_params, aux_params, logger)
3 changes: 2 additions & 1 deletion python/mxnet/gluon/parameter.py
@@ -900,8 +900,9 @@ def load(self, filename, ctx=None, allow_missing=False,
"restore_prefix is '%s' but Parameters name '%s' does not start " \
"with '%s'"%(restore_prefix, name, restore_prefix)
lprefix = len(restore_prefix)
ndarray_load = ndarray.load(filename)
loaded = [(k[4:] if k.startswith('arg:') or k.startswith('aux:') else k, v) \
for k, v in ndarray.load(filename).items()]
for k, v in ndarray_load.items()] if isinstance(ndarray_load, dict) else ndarray_load
arg_dict = {restore_prefix+k: v for k, v in loaded}
if not allow_missing:
for name in self.keys():
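The isinstance guard added above matters because ndarray.load returns a different container depending on how the parameter file was saved. A small self-contained illustration, with hypothetical file names:

from mxnet import ndarray as nd

nd.save('named.params', {'arg:weight': nd.zeros((2, 2))})  # dict of named arrays
nd.save('unnamed.params', [nd.zeros((2, 2))])              # plain list of arrays

print(type(nd.load('named.params')))    # <class 'dict'>, keys keep their 'arg:'/'aux:' prefix
print(type(nd.load('unnamed.params')))  # <class 'list'>, nothing to strip, passed through as-is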
49 changes: 49 additions & 0 deletions python/mxnet/test_utils.py
@@ -80,6 +80,11 @@ def get_rtol(rtol=None):
# be needed for different device and dtype
return 1e-5 if rtol is None else rtol

def get_etol(etol=None):
"""Get default numerical threshold for regression test."""
# _TODO: get from env variable, different threshold might
# be needed for different device and dtype
return 0 if etol is None else etol

def random_arrays(*shapes):
"""Generate some random numpy arrays."""
@@ -494,6 +499,50 @@ def assert_almost_equal(a, b, rtol=None, atol=None, names=('a', 'b'), equal_nan=
names=names)
raise AssertionError(msg)

def assert_almost_equal_with_err(a, b, rtol=None, atol=None, etol=None, names=('a', 'b'), equal_nan=False):
"""Test that two numpy arrays are almost equal within given error rate. Raise exception message if not.
Parameters
----------
a : np.ndarray
b : np.ndarray
rtol : None or float
The relative tolerance. Default threshold will be used if set to ``None``.
atol : None or float
The absolute tolerance. Default threshold will be used if set to ``None``.
etol : None or float
The error rate threshold. If etol is float, return true if error_rate < etol even if
any error is found.
"""
rtol = get_rtol(rtol)
atol = get_atol(atol)
etol = get_etol(etol)
if etol:
equals = np.isclose(a, b, rtol=rtol, atol=atol)
err = 1 - np.count_nonzero(equals) / equals.size
if err > etol:
#if True:
index, rel = find_max_violation(a, b, rtol, atol)
np.set_printoptions(threshold=4, suppress=True)
msg = npt.build_err_msg([a, b],
err_msg="Error %f exceeds tolerance rtol=%f, atol=%f, etol=%f."
" Error_rate=%f. Location of maximum error:%s, a=%f, b=%f"
% (rel, rtol, atol, etol, err, str(index), a[index], b[index]),
names=names)
raise AssertionError(msg)

if almost_equal(a, b, rtol, atol, equal_nan=equal_nan):
return
else:
if almost_equal(a, b, rtol, atol, equal_nan=equal_nan):
return
index, rel = find_max_violation(a, b, rtol, atol)
np.set_printoptions(threshold=4, suppress=True)
msg = npt.build_err_msg([a, b],
err_msg="Error %f exceeds tolerance rtol=%f, atol=%f. "
" Location of maximum error:%s, a=%f, b=%f"
% (rel, rtol, atol, str(index), a[index], b[index]),
names=names)
raise AssertionError(msg)


def almost_equal_ignore_nan(a, b, rtol=None, atol=None):
"""Test that two NumPy arrays are almost equal (ignoring NaN in either array).
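A short usage sketch for the new helper (arrays and tolerances are made up): with etol set, the assertion tolerates a bounded fraction of mismatching elements instead of failing on the first one, which is what the fused INT8 convolution tests rely on.

import numpy as np
from mxnet.test_utils import assert_almost_equal_with_err

a = np.random.rand(1000)
b = a.copy()
b[:5] += 1.0  # corrupt 5 of 1000 elements, an error rate of 0.005

# Passes: the mismatch rate 0.005 stays below etol=0.01.
assert_almost_equal_with_err(a, b, rtol=1e-5, atol=1e-5, etol=0.01)

# Raises AssertionError: the same mismatch rate exceeds etol=0.001.
# assert_almost_equal_with_err(a, b, rtol=1e-5, atol=1e-5, etol=0.001)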
6 changes: 6 additions & 0 deletions src/operator/nn/mkldnn/mkldnn_act.cc
@@ -57,6 +57,12 @@ bool SupportMKLDNNAct(const ActivationParam& param, const NDArray &input) {
return SupportMKLDNNAct(param);
}

bool SupportQuantizedMKLDNNAct(const ActivationParam &param) {
// TODO(zhennan): Add more activation type when mkldnn supports.
// Remove this when it's identity to SupportMKLDNNAct.
return param.act_type == activation::kReLU;
}

mkldnn::algorithm GetMKLDNNActAlgo(const ActivationParam& param) {
switch (param.act_type) {
case activation::kReLU:
3 changes: 2 additions & 1 deletion src/operator/nn/mkldnn/mkldnn_base-inl.h
@@ -178,7 +178,8 @@ struct SoftmaxOutputParam;
struct TransposeParam;
bool SupportMKLDNNAct(const ActivationParam& param);
bool SupportMKLDNNAct(const ActivationParam& param, const NDArray &input);
bool SupportMKLDNNConv(const ConvolutionParam& params, const NDArray &input);
bool SupportQuantizedMKLDNNAct(const ActivationParam &param);
bool SupportMKLDNNConv(const ConvolutionParam &params, const NDArray &input);
bool SupportMKLDNNDeconv(const DeconvolutionParam& params, const NDArray &input);
bool SupportMKLDNNSoftmax(const SoftmaxParam& param, const NDArray &input, const NDArray &output);
bool SupportMKLDNNSoftmaxOutput(const SoftmaxOutputParam &param);
28 changes: 16 additions & 12 deletions src/operator/nn/mkldnn/mkldnn_convolution-inl.h
@@ -38,9 +38,9 @@ namespace op {

struct MKLDNNConvParam : public dmlc::Parameter<MKLDNNConvParam> {
bool with_bn;
bool with_relu;
bool with_act;
bool with_sum;
bool with_postsum_relu;
bool with_postsum_act;
bool quantized;

dmlc::optional<float> min_calib_range; // min float value calculated from calibration dataset
@@ -49,12 +49,12 @@ struct MKLDNNConvParam : public dmlc::Parameter<MKLDNNConvParam> {
DMLC_DECLARE_PARAMETER(MKLDNNConvParam) {
DMLC_DECLARE_FIELD(with_bn).set_default(false)
.describe("Add post batchnorm.");
DMLC_DECLARE_FIELD(with_relu).set_default(false)
.describe("Add post relu");
DMLC_DECLARE_FIELD(with_act).set_default(false)
.describe("Add post activation");
DMLC_DECLARE_FIELD(with_sum).set_default(false)
.describe("Add post sum");
DMLC_DECLARE_FIELD(with_postsum_relu).set_default(false)
.describe("Add post relu after sum");
DMLC_DECLARE_FIELD(with_postsum_act).set_default(false)
.describe("Add post activation after sum");
DMLC_DECLARE_FIELD(quantized).set_default(false)
.describe("enable quantization");
DMLC_DECLARE_FIELD(min_calib_range)
@@ -70,18 +70,22 @@ struct MKLDNNConvParam : public dmlc::Parameter<MKLDNNConvParam> {
}
};

struct MKLDNNPostActParam {
mkldnn::algorithm alg = mkldnn::algorithm::algorithm_undef;
float scale = 1.f;
float alpha = 0.f;
float beta = 1.f;
};

struct MKLDNNConvFullParam {
ConvolutionParam conv_param;
MKLDNNConvParam mkldnn_param;
float sum_scale;
float sum_scale = 1.f;
std::vector<float> requantize_scales;
MKLDNNPostActParam act_param;
MKLDNNPostActParam postsum_act_param;
};

static inline bool IsOutputUInt8(const MKLDNNConvParam &mkldnn_param) {
return ((!mkldnn_param.with_sum) && mkldnn_param.with_relu) ||
mkldnn_param.with_postsum_relu;
}

mkldnn::convolution_forward::primitive_desc GetConvFwdImpl(const MKLDNNConvFullParam &param,
const bool is_train,
const NDArray &data,
16 changes: 6 additions & 10 deletions src/operator/nn/mkldnn/mkldnn_convolution.cc
@@ -82,20 +82,16 @@ mkldnn::convolution_forward::primitive_desc GetConvFwdImpl(const MKLDNNConvFullP
}
mkldnn::primitive_attr attr;
mkldnn::post_ops ops;
if (param.mkldnn_param.with_relu) {
float scale = 1.0f; // for fp32, scale is 1.
float alpha = 0.0f; // negative slope for mkldnn_eltwise_relu.
float beta = 1.0f; // ignored for mkldnn_eltwise_relu.
ops.append_eltwise(scale, eltwise_relu, alpha, beta);
if (param.mkldnn_param.with_act) {
const auto &act_param = param.act_param;
ops.append_eltwise(act_param.scale, act_param.alg, act_param.alpha, act_param.beta);
}
if (param.mkldnn_param.with_sum) {
ops.append_sum(param.sum_scale);
}
if (param.mkldnn_param.with_postsum_relu) {
float scale = 1.0f; // for fp32, scale is 1.
float alpha = 0.0f; // negative slope for mkldnn_eltwise_relu.
float beta = 1.0f; // ignored for mkldnn_eltwise_relu.
ops.append_eltwise(scale, eltwise_relu, alpha, beta);
if (param.mkldnn_param.with_postsum_act) {
const auto &act_param = param.postsum_act_param;
ops.append_eltwise(act_param.scale, act_param.alg, act_param.alpha, act_param.beta);
}
attr.set_post_ops(ops);

17 changes: 10 additions & 7 deletions src/operator/subgraph/build_subgraph.cc
@@ -299,6 +299,7 @@ void PreSelectSubgraphNodes(const nnvm::Graph& g, SubgraphSelectorV2Ptr subgraph
}
LOG(INFO) << "Found a cycle when BFS from node " << simple_nodes[snid]->node->attrs.name
<< ". Excluding nodes " << excluded_node_names << "and retrying";
subgraph_selector->Reset();
}
++count;
}
@@ -509,27 +510,29 @@ void FindOutputEntries(nnvm::Graph* g,
void CutGraphInputs(const std::vector<nnvm::NodeEntry*> &input_entries,
std::vector<nnvm::NodeEntry> *orig_entries,
const bool skip_var = false) {
orig_entries->reserve(input_entries.size());
orig_entries->resize(input_entries.size());
// map for creating unique var nodes for deduplicating entries from the same node
std::unordered_map<std::string, nnvm::NodePtr> new_node_map;
std::unordered_map<std::string, int> name_count_map;
for (size_t i = 0; i < input_entries.size(); ++i) {
nnvm::NodeEntry *e = input_entries[i];
// If the node is a variable itself, we may want to skip the node.
if (e->node->is_variable() && skip_var) {
continue;
}

orig_entries->at(i) = *e;
nnvm::Symbol sym;
sym.outputs.push_back(*e);
const auto output_names = sym.ListOutputNames();
CHECK_EQ(output_names.size(), 1U);
const std::string& var_name = output_names[0];
auto it = new_node_map.find(var_name);
if (it == new_node_map.end()) {
orig_entries->push_back(*e);
new_node_map[var_name] = nnvm::CreateVariableNode(var_name);
auto it = name_count_map.find(var_name);
if (name_count_map.end() == it) {
name_count_map.emplace(var_name, 0);
} else {
++(it->second);
}
nnvm::NodePtr n = new_node_map[var_name];
nnvm::NodePtr n = nnvm::CreateVariableNode(var_name + std::to_string(name_count_map[var_name]));
*e = nnvm::NodeEntry{n, 0, 0};
}
}
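The CutGraphInputs change replaces the shared variable-node map with a per-name counter, so two input entries that originate from the same source node now become distinct variable nodes instead of being collapsed into one. A hypothetical Python mirror of the renaming scheme, for illustration only:

name_count = {}

def unique_input_name(var_name):
    # First occurrence gets suffix 0, later occurrences get 1, 2, ...
    name_count[var_name] = name_count.get(var_name, -1) + 1
    return var_name + str(name_count[var_name])

print(unique_input_name('conv0_output'))  # conv0_output0
print(unique_input_name('conv0_output'))  # conv0_output1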
24 changes: 22 additions & 2 deletions src/operator/subgraph/mkldnn/mkldnn_conv-inl.h
@@ -21,11 +21,12 @@
#define MXNET_OPERATOR_SUBGRAPH_MKLDNN_MKLDNN_CONV_INL_H_
#if MXNET_USE_MKLDNN == 1

#include <string>
#include <utility>
#include <vector>
#include <string>
#include "../../nn/convolution-inl.h"
#include "../../nn/activation-inl.h"
#include "../../nn/batch_norm-inl.h"
#include "../../nn/convolution-inl.h"
#include "../../nn/mkldnn/mkldnn_convolution-inl.h"

namespace mxnet {
@@ -36,6 +37,25 @@ struct MKLDNNConvFusionParam {
std::shared_ptr<BatchNormParam> bn_param;
};

static inline bool IsOutputUInt8(const MKLDNNConvFusionParam& param) {
bool result = false;
const auto& mkldnn_param = param.full_conv_param.mkldnn_param;
auto IsOutputUInt8Helper = [](const mkldnn::algorithm& act_alg) {
return (act_alg == mkldnn::algorithm::eltwise_relu ||
act_alg == mkldnn::algorithm::eltwise_logistic ||
act_alg == mkldnn::algorithm::eltwise_soft_relu ||
act_alg == mkldnn::algorithm::eltwise_bounded_relu);
};
if ((!mkldnn_param.with_sum) && mkldnn_param.with_act) {
CHECK(param.full_conv_param.act_param.alg != mkldnn::algorithm::algorithm_undef);
result = IsOutputUInt8Helper(param.full_conv_param.act_param.alg);
} else if (mkldnn_param.with_postsum_act) {
CHECK(param.full_conv_param.postsum_act_param.alg != mkldnn::algorithm::algorithm_undef);
result = IsOutputUInt8Helper(param.full_conv_param.postsum_act_param.alg);
}
return result;
}

enum MKLDNNConvOpOutputs { kOut, kMin, kMax };

} // namespace op
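A hypothetical Python mirror of the IsOutputUInt8 helper above, for illustration only: the fused convolution output can be kept as uint8 exactly when the activation that runs last is one whose output is known to be non-negative.

# Algorithms with non-negative output, matching the helper above.
NON_NEGATIVE_ACTS = {'eltwise_relu', 'eltwise_logistic',
                     'eltwise_soft_relu', 'eltwise_bounded_relu'}

def is_output_uint8(with_sum, with_act, act_alg=None,
                    with_postsum_act=False, postsum_act_alg=None):
    if not with_sum and with_act:
        return act_alg in NON_NEGATIVE_ACTS
    if with_postsum_act:
        return postsum_act_alg in NON_NEGATIVE_ACTS
    return False

print(is_output_uint8(with_sum=False, with_act=True, act_alg='eltwise_relu'))  # True
print(is_output_uint8(with_sum=True, with_act=True, act_alg='eltwise_relu'))   # False, the sum runs after the activation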
