This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

MKLDNN based Quantized FullyConnected Operator and its fusion #14128

Merged
merged 30 commits on Mar 8, 2019

Commits (30)
f034b86
add MKL-DNN quantized innerproduct
ciyongch Jan 10, 2019
4ee837e
initial qfc with mkldnn
ciyongch Jan 14, 2019
006d2d8
Add MKL-DNN quantized_fully_connected
ciyongch Jan 15, 2019
88e3a89
refactor params order for fullyconnected
ciyongch Jan 16, 2019
e188f63
update quantized_fully_connected unittest, force data to uint8 type t…
ciyongch Jan 17, 2019
af4132d
change mkl based quantized fully_connected to FCompute
ciyongch Jan 17, 2019
989901c
add check data type for mkldnn quantized_fc
ciyongch Jan 18, 2019
9b3d96c
add fuse requantize and dequantize for mkldnn quantized fullyconnected
ciyongch Jan 22, 2019
314a667
add env setting for enable/disable fuse requantize/dequantize for qua…
ciyongch Jan 23, 2019
6df747e
fix requantize scaling error
ciyongch Jan 24, 2019
6d4883a
add fallback when input data is int8
ciyongch Jan 24, 2019
bd6f313
fix mkl quantized fullyconnected index error
ciyongch Jan 28, 2019
cb0bcfa
update quantized fc test cases
ciyongch Jan 28, 2019
ce44bd6
add subgraph node for mkldnn fullyconnected
ciyongch Feb 1, 2019
3532bd5
fix compiling and lint error
ciyongch Feb 1, 2019
678d555
clean and refactor code
ciyongch Feb 2, 2019
95dfffe
enable quantized_fc for imagenet
ciyongch Feb 11, 2019
8a9e2f9
cleanup code
ciyongch Feb 12, 2019
d891e0b
Fix StorageType error for non-mkldnn path
ciyongch Feb 12, 2019
4da6f5a
fix pylint
ciyongch Feb 12, 2019
40039bd
reverse BUILD_TAG for MKL IGEMM ut, remove IGEMM qfc check
ciyongch Feb 13, 2019
68be291
rename variables and refactor codes according to comments
ciyongch Feb 23, 2019
ca6a427
add subgraph qfc tests and fix shape error
ciyongch Feb 23, 2019
517f55d
remove fuse_requantize and change fuse_dequantize to enable_float_out…
ciyongch Mar 2, 2019
16f0b07
change to use mxnet::Tuple and update tests
ciyongch Mar 2, 2019
bb8a294
update description in file header
ciyongch Mar 4, 2019
9ec3cf9
update input0 type check for quantized FullyConnected
ciyongch Mar 4, 2019
b8edfb5
fix conflict of mkl/test_subgraph.py
ciyongch Mar 7, 2019
145f454
retrigger CI
ciyongch Mar 7, 2019
35a711a
retrigger CI due to hang
ciyongch Mar 8, 2019
9 changes: 4 additions & 5 deletions example/quantization/imagenet_gen_qsym_mkldnn.py
@@ -180,6 +180,7 @@ def save_params(fname, arg_params, aux_params, logger=None):
sym, arg_params, aux_params = mx.model.load_checkpoint(prefix, epoch)

sym = sym.get_backend_symbol('MKLDNN')
sym = sym.get_backend_symbol('MKLDNN_FC')

# get batch size
batch_size = args.batch_size
@@ -207,19 +208,18 @@ def save_params(fname, arg_params, aux_params, logger=None):
if args.model == 'imagenet1k-resnet-152':
rgb_mean = '0,0,0'
rgb_std = '1,1,1'
excluded_sym_names += ['flatten0', 'fc1']
excluded_sym_names += ['flatten0']
if exclude_first_conv:
excluded_sym_names += ['conv0']
elif args.model == 'imagenet1k-inception-bn':
rgb_mean = '123.68,116.779,103.939'
rgb_std = '1,1,1'
excluded_sym_names += ['flatten', 'fc1']
excluded_sym_names += ['flatten']
if exclude_first_conv:
excluded_sym_names += ['conv_1']
elif args.model in ['resnet50_v1', 'resnet101_v1']:
rgb_mean = '123.68,116.779,103.939'
rgb_std = '58.393, 57.12, 57.375'
excluded_sym_names += ['resnetv10_dense0_fwd']
if exclude_first_conv:
excluded_sym_names += ['resnetv10_conv0_fwd']
elif args.model == 'squeezenet1.0':
@@ -232,14 +232,12 @@ def save_params(fname, arg_params, aux_params, logger=None):
rgb_mean = '123.68,116.779,103.939'
rgb_std = '58.393, 57.12, 57.375'
excluded_sym_names += ['mobilenet0_flatten0_flatten0',
'mobilenet0_dense0_fwd',
'mobilenet0_pool0_fwd']
if exclude_first_conv:
excluded_sym_names += ['mobilenet0_conv0_fwd']
elif args.model == 'inceptionv3':
rgb_mean = '123.68,116.779,103.939'
rgb_std = '58.393, 57.12, 57.375'
excluded_sym_names += ['inception30_dense0_fwd']
if exclude_first_conv:
excluded_sym_names += ['inception30_conv0_fwd']
elif args.model == 'custom':
@@ -305,6 +303,7 @@ def save_params(fname, arg_params, aux_params, logger=None):
% calib_mode)
sym_name = '%s-symbol.json' % (prefix + suffix)
qsym = qsym.get_backend_symbol('MKLDNN_POST_QUANTIZE')
qsym = qsym.get_backend_symbol('MKLDNN_POST_FC_QUANTIZE')
save_symbol(sym_name, qsym, logger)
param_name = '%s-%04d.params' % (prefix + '-quantized', epoch)
save_params(param_name, qarg_params, aux_params, logger)
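
For context, the change above wires two new graph passes into the quantization script: 'MKLDNN_FC' is applied to the FP32 symbol alongside the existing 'MKLDNN' pass, and 'MKLDNN_POST_FC_QUANTIZE' is applied to the quantized symbol after 'MKLDNN_POST_QUANTIZE'. A minimal sketch of that flow follows; the checkpoint prefix and the quantize_model arguments are illustrative placeholders rather than values taken from this diff.

import mxnet as mx
from mxnet.contrib.quantization import quantize_model

# Load an FP32 checkpoint (prefix/epoch are placeholders).
sym, arg_params, aux_params = mx.model.load_checkpoint('model', 0)

# FP32 fusion passes: generic MKLDNN fusion first, then the new FC fusion.
sym = sym.get_backend_symbol('MKLDNN')
sym = sym.get_backend_symbol('MKLDNN_FC')

# Quantize the fused graph (calibration options omitted for brevity).
qsym, qarg_params, aux_params = quantize_model(
    sym=sym, arg_params=arg_params, aux_params=aux_params,
    excluded_sym_names=[], calib_mode='none', quantized_dtype='uint8')

# Post-quantization passes: the existing pass first, then the new FC pass
# that folds requantize/dequantize into the fused FullyConnected node.
qsym = qsym.get_backend_symbol('MKLDNN_POST_QUANTIZE')
qsym = qsym.get_backend_symbol('MKLDNN_POST_FC_QUANTIZE')
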
13 changes: 13 additions & 0 deletions python/mxnet/initializer.py
@@ -159,6 +159,12 @@ def __call__(self, desc, arr):
elif desc.endswith('max'):
self._init_one(desc, arr)
self._verbose_print(desc, 'max', arr)
elif desc.endswith('weight_quantize'):
self._init_quantized_weight(desc, arr)
self._verbose_print(desc, 'weight_quantize', arr)
elif desc.endswith('bias_quantize'):
self._init_quantized_bias(desc, arr)
self._verbose_print(desc, 'bias_quantize', arr)
else:
self._init_default(desc, arr)

@@ -235,6 +241,9 @@ def _init_one(self, _, arr):
def _init_bias(self, _, arr):
arr[:] = 0.0

def _init_quantized_bias(self, _, arr):
arr[:] = 0

def _init_gamma(self, _, arr):
arr[:] = 1.0

@@ -245,6 +254,10 @@ def _init_weight(self, name, arr):
"""Abstract method to Initialize weight."""
raise NotImplementedError("Must override it")

def _init_quantized_weight(self, _, arr):
_arr = random.randint(-127, 127, dtype='int32').asnumpy()
Review comment (Member): Ummm, seems we need to extend randint to support dtype='int8'.

Reply (Contributor Author): Yes, int8 dtype is a limitation of the current randint.

arr[:] = np.int8(_arr)

def _init_default(self, name, _):
raise ValueError(
'Unknown initialization pattern for %s. ' \
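
As a usage sketch of the suffix dispatch added above: any Initializer subclass now routes parameters whose names end with 'weight_quantize' or 'bias_quantize' to the new helpers. The parameter names and shapes below are made up for illustration.

import mxnet as mx

init = mx.init.Xavier()  # dispatch happens in Initializer.__call__, so any subclass works

qweight = mx.nd.zeros((64, 128)).astype('int8')
init(mx.init.InitDesc('fc_weight_quantize'), qweight)  # -> _init_quantized_weight (random int8 values)

qbias = mx.nd.zeros((64,)).astype('int8')
init(mx.init.InitDesc('fc_bias_quantize'), qbias)      # -> _init_quantized_bias (zeros)
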
133 changes: 133 additions & 0 deletions src/operator/nn/mkldnn/mkldnn_fully_connected-inl.h
@@ -0,0 +1,133 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* Copyright (c) 2019 by Contributors
* \file mkldnn_fully_connected-inl.h
* \brief Common functions used by MKLDNN (Quantized) FullyConnected operator
* \author Ciyong Chen
*/

#ifndef MXNET_OPERATOR_NN_MKLDNN_MKLDNN_FULLY_CONNECTED_INL_H_
#define MXNET_OPERATOR_NN_MKLDNN_MKLDNN_FULLY_CONNECTED_INL_H_

#if MXNET_USE_MKLDNN == 1

#include <vector>
#include <string>
#include "../fully_connected-inl.h"
#include "./mkldnn_base-inl.h"

namespace mxnet {
namespace op {

struct MKLDNNFCParam: public dmlc::Parameter<MKLDNNFCParam> {
bool quantized;
bool enable_float_output;
bool with_relu;
dmlc::optional<float> min_calib_range; // min float value calculated from calibration dataset
dmlc::optional<float> max_calib_range; // max float value calculated from calibration dataset

DMLC_DECLARE_PARAMETER(MKLDNNFCParam) {
DMLC_DECLARE_FIELD(quantized).set_default(false)
.describe("Whether it's a quantized FullyConnected operator");
DMLC_DECLARE_FIELD(enable_float_output).set_default(false)
.describe("Whether to enable float32 output");
DMLC_DECLARE_FIELD(with_relu).set_default(false)
.describe("Whether there's a post relu after FullyConnected operator");
DMLC_DECLARE_FIELD(min_calib_range)
.set_default(dmlc::optional<float>())
.describe("The minimum scalar value in the form of float32 obtained "
"through calibration. If present, it will be used to by "
"quantized fullyconnected op to calculate primitive scale");
DMLC_DECLARE_FIELD(max_calib_range)
.set_default(dmlc::optional<float>())
.describe("The maximum scalar value in the form of float32 obtained "
"through calibration. If present, it will be used to by "
"quantized fullyconnected op to calculate primitive scale");
}
};

struct MKLDNNFCFullParam {
FullyConnectedParam default_param;
MKLDNNFCParam mkldnn_param;
std::vector<float> output_scales = {0.0};
std::vector<float> requantize_scales = {0.0};
};

mkldnn::inner_product_forward::primitive_desc GetFCFwdImpl(
const MKLDNNFCFullParam &full_param, const bool is_train,
const NDArray &data, const NDArray &weight, const NDArray *bias,
const mkldnn::memory::desc &out_md);

class MKLDNNFullyConnectedForward {
public:
mkldnn::inner_product_forward::primitive_desc fwd_pd;

MKLDNNFullyConnectedForward(const MKLDNNFCFullParam &full_param, const bool is_train,
const NDArray &data, const NDArray &weight,
const NDArray *bias,
const mkldnn::memory::desc &out_md)
: fwd_pd(GetFCFwdImpl(full_param, is_train, data, weight, bias, out_md)) {}


void SetNewMem(const mkldnn::memory &data, const mkldnn::memory &weight,
const mkldnn::memory *bias, const mkldnn::memory &output);

const mkldnn::inner_product_forward &GetFwd() const {
return *fwd_;
}

private:
std::shared_ptr<mkldnn::inner_product_forward> fwd_;
std::shared_ptr<mkldnn::memory> data_;
std::shared_ptr<mkldnn::memory> weight_;
std::shared_ptr<mkldnn::memory> bias_;
std::shared_ptr<mkldnn::memory> out_;
};

typedef ParamOpSign<FullyConnectedParam> MKLDNNFullyconSignature;

MKLDNNFullyConnectedForward &GetFCFwd(
const FullyConnectedParam &param, const bool is_train,
const NDArray &data, const NDArray &weight,
const NDArray *bias, const mkldnn::memory::desc &out_md);

void MKLDNNFCFlattenData(const FullyConnectedParam &param,
const NDArray &out_data,
NDArray *in_data,
mkldnn::memory::desc *out_md);

void MKLDNNFCForward(const nnvm::NodeAttrs &attrs, const OpContext &ctx,
const std::vector<NDArray> &in_data,
const std::vector<OpReqType> &req,
const std::vector<NDArray> &out_data);

void MKLDNNFCForwardFullFeature(const MKLDNNFCFullParam &param,
const OpContext &ctx,
MKLDNNFullyConnectedForward *fwd,
const std::vector<NDArray> &in_data,
const std::vector<OpReqType> &req,
const std::vector<NDArray> &out_data);

} // namespace op
} // namespace mxnet

#endif // MXNET_USE_MKLDNN == 1
#endif // MXNET_OPERATOR_NN_MKLDNN_MKLDNN_FULLY_CONNECTED_INL_H_
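
For a quick way to exercise the new subgraph property from Python (a sketch assuming an MKL-DNN build of MXNet; the toy symbol and shapes are arbitrary), the backend added by this PR can be applied directly with get_backend_symbol:

import mxnet as mx

data = mx.sym.Variable('data')
fc = mx.sym.FullyConnected(data, num_hidden=64, name='fc0')

# Partition the FP32 graph with the FC-specific backend added by this PR.
fused = fc.get_backend_symbol('MKLDNN_FC')

# The FullyConnected node is expected to be wrapped into a single MKLDNN FC
# subgraph node; the quantized flow then sets the MKLDNNFCParam fields declared
# above (quantized, enable_float_output, with_relu, min/max_calib_range) on it.
print(fused.tojson())
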