MKLDNN-based Quantized FullyConnected Operator and its fusion (#14128)
* add MKL-DNN quantized innerproduct

* initial qfc with mkldnn

* Add MKL-DNN quantized_fully_connected

* refactor params order for fullyconnected

* update quantized_fully_connected unittest, force data to uint8 type temporarily

* change mkl based quantized fully_connected to FCompute

* add check data type for mkldnn quantized_fc

* add fuse requantize and dequantize for mkldnn quantized fullyconnected

* add env setting to enable/disable fusing requantize/dequantize for quantized fullyconnected

* fix requantize scaling error

* add fallback when input data is int8

* fix mkl quantized fullyconnected index error

* update quantized fc test cases

* add subgraph node for mkldnn fullyconnected

* fix compiling and lint error

* clean and refactor code

* enable quantized_fc for imagenet

* cleanup code

* Fix StorageType error for non-mkldnn path

* fix pylint

* reverse BUILD_TAG for MKL IGEMM ut, remove IGEMM qfc check

* rename variables and refactor codes according to comments

* add subgraph qfc tests and fix shape error

* remove fuse_requantize and change fuse_dequantize to enable_float_output.

* change to use mxnet::Tuple and update tests

* update description in file header

* update input0 type check for quantized FullyConnected

* fix conflict in mkl/test_subgraph.py

* retrigger CI

* retrigger CI due to hang
ciyongch authored and TaoLv committed Mar 8, 2019
1 parent 30b1cbc commit 8668db7
Showing 12 changed files with 1,679 additions and 224 deletions.
9 changes: 4 additions & 5 deletions example/quantization/imagenet_gen_qsym_mkldnn.py
@@ -180,6 +180,7 @@ def save_params(fname, arg_params, aux_params, logger=None):
sym, arg_params, aux_params = mx.model.load_checkpoint(prefix, epoch)

sym = sym.get_backend_symbol('MKLDNN')
sym = sym.get_backend_symbol('MKLDNN_FC')

# get batch size
batch_size = args.batch_size
@@ -207,19 +208,18 @@ def save_params(fname, arg_params, aux_params, logger=None):
if args.model == 'imagenet1k-resnet-152':
rgb_mean = '0,0,0'
rgb_std = '1,1,1'
excluded_sym_names += ['flatten0', 'fc1']
excluded_sym_names += ['flatten0']
if exclude_first_conv:
excluded_sym_names += ['conv0']
elif args.model == 'imagenet1k-inception-bn':
rgb_mean = '123.68,116.779,103.939'
rgb_std = '1,1,1'
excluded_sym_names += ['flatten', 'fc1']
excluded_sym_names += ['flatten']
if exclude_first_conv:
excluded_sym_names += ['conv_1']
elif args.model in ['resnet50_v1', 'resnet101_v1']:
rgb_mean = '123.68,116.779,103.939'
rgb_std = '58.393, 57.12, 57.375'
excluded_sym_names += ['resnetv10_dense0_fwd']
if exclude_first_conv:
excluded_sym_names += ['resnetv10_conv0_fwd']
elif args.model == 'squeezenet1.0':
@@ -232,14 +232,12 @@ def save_params(fname, arg_params, aux_params, logger=None):
rgb_mean = '123.68,116.779,103.939'
rgb_std = '58.393, 57.12, 57.375'
excluded_sym_names += ['mobilenet0_flatten0_flatten0',
'mobilenet0_dense0_fwd',
'mobilenet0_pool0_fwd']
if exclude_first_conv:
excluded_sym_names += ['mobilenet0_conv0_fwd']
elif args.model == 'inceptionv3':
rgb_mean = '123.68,116.779,103.939'
rgb_std = '58.393, 57.12, 57.375'
excluded_sym_names += ['inception30_dense0_fwd']
if exclude_first_conv:
excluded_sym_names += ['inception30_conv0_fwd']
elif args.model == 'custom':
@@ -305,6 +303,7 @@ def save_params(fname, arg_params, aux_params, logger=None):
% calib_mode)
sym_name = '%s-symbol.json' % (prefix + suffix)
qsym = qsym.get_backend_symbol('MKLDNN_POST_QUANTIZE')
qsym = qsym.get_backend_symbol('MKLDNN_POST_FC_QUANTIZE')
save_symbol(sym_name, qsym, logger)
param_name = '%s-%04d.params' % (prefix + '-quantized', epoch)
save_params(param_name, qarg_params, aux_params, logger)
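The diff above wires the new 'MKLDNN_FC' fusion pass and its post-quantize counterpart into the ImageNet quantization script. A minimal sketch of that flow, assuming an MKL-DNN build of MXNet; the backend pass names come from the diff, while the checkpoint prefix, excluded layer, and calibration settings are purely illustrative:

```python
import mxnet as mx
from mxnet.contrib.quantization import quantize_model

# Hypothetical FP32 checkpoint; any model with FullyConnected layers works.
sym, arg_params, aux_params = mx.model.load_checkpoint('model/resnet-50', 0)

# Fuse generic MKL-DNN subgraphs first, then the FC-specific patterns added here.
sym = sym.get_backend_symbol('MKLDNN')
sym = sym.get_backend_symbol('MKLDNN_FC')

# Quantize the fused symbol; skipping calibration keeps the sketch self-contained.
qsym, qarg_params, aux_params = quantize_model(
    sym=sym, arg_params=arg_params, aux_params=aux_params,
    excluded_sym_names=['flatten0'], calib_mode='none')

# Fold requantize/dequantize into the fused quantized convolution and FC nodes.
qsym = qsym.get_backend_symbol('MKLDNN_POST_QUANTIZE')
qsym = qsym.get_backend_symbol('MKLDNN_POST_FC_QUANTIZE')
```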
13 changes: 13 additions & 0 deletions python/mxnet/initializer.py
@@ -159,6 +159,12 @@ def __call__(self, desc, arr):
elif desc.endswith('max'):
self._init_one(desc, arr)
self._verbose_print(desc, 'max', arr)
elif desc.endswith('weight_quantize'):
self._init_quantized_weight(desc, arr)
self._verbose_print(desc, 'weight_quantize', arr)
elif desc.endswith('bias_quantize'):
self._init_quantized_bias(desc, arr)
self._verbose_print(desc, 'bias_quantize', arr)
else:
self._init_default(desc, arr)

@@ -235,6 +241,9 @@ def _init_one(self, _, arr):
def _init_bias(self, _, arr):
arr[:] = 0.0

def _init_quantized_bias(self, _, arr):
arr[:] = 0

def _init_gamma(self, _, arr):
arr[:] = 1.0

@@ -245,6 +254,10 @@ def _init_weight(self, name, arr):
"""Abstract method to Initialize weight."""
raise NotImplementedError("Must override it")

def _init_quantized_weight(self, _, arr):
_arr = random.randint(-127, 127, dtype='int32').asnumpy()
arr[:] = np.int8(_arr)

def _init_default(self, name, _):
raise ValueError(
'Unknown initialization pattern for %s. ' \
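The initializer change above routes parameter names ending in 'weight_quantize' or 'bias_quantize' to the new int8 and zero initializers instead of the default error path. A small sketch of that dispatch, assuming the patched initializer; the layer name 'fc0' and the shapes are made up:

```python
import mxnet as mx

init = mx.init.Xavier()  # any Initializer subclass; dispatch happens in __call__

weight_q = mx.nd.zeros((8, 16), dtype='int8')
bias_q = mx.nd.zeros((8,), dtype='int8')

# Names ending in '*_quantize' hit the new branches rather than _init_default.
init(mx.init.InitDesc('fc0_weight_quantize'), weight_q)  # _init_quantized_weight: random int8 data
init(mx.init.InitDesc('fc0_bias_quantize'), bias_q)      # _init_quantized_bias: zeros
```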
133 changes: 133 additions & 0 deletions src/operator/nn/mkldnn/mkldnn_fully_connected-inl.h
@@ -0,0 +1,133 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* Copyright (c) 2019 by Contributors
* \file mkldnn_fully_connected-inl.h
* \brief Common functions used by MKLDNN (Quantized) FullyConnected operator
* \author Ciyong Chen
*/

#ifndef MXNET_OPERATOR_NN_MKLDNN_MKLDNN_FULLY_CONNECTED_INL_H_
#define MXNET_OPERATOR_NN_MKLDNN_MKLDNN_FULLY_CONNECTED_INL_H_

#if MXNET_USE_MKLDNN == 1

#include <vector>
#include <string>
#include "../fully_connected-inl.h"
#include "./mkldnn_base-inl.h"

namespace mxnet {
namespace op {

struct MKLDNNFCParam: public dmlc::Parameter<MKLDNNFCParam> {
bool quantized;
bool enable_float_output;
bool with_relu;
dmlc::optional<float> min_calib_range; // min float value calculated from calibration dataset
dmlc::optional<float> max_calib_range; // max float value calculated from calibration dataset

DMLC_DECLARE_PARAMETER(MKLDNNFCParam) {
DMLC_DECLARE_FIELD(quantized).set_default(false)
.describe("Whether it's a quantized FullyConnected operator");
DMLC_DECLARE_FIELD(enable_float_output).set_default(false)
.describe("Whether to enable float32 output");
DMLC_DECLARE_FIELD(with_relu).set_default(false)
.describe("Whether there's a post relu after FullyConnected operator");
DMLC_DECLARE_FIELD(min_calib_range)
.set_default(dmlc::optional<float>())
.describe("The minimum scalar value in the form of float32 obtained "
"through calibration. If present, it will be used to by "
"quantized fullyconnected op to calculate primitive scale");
DMLC_DECLARE_FIELD(max_calib_range)
.set_default(dmlc::optional<float>())
.describe("The maximum scalar value in the form of float32 obtained "
"through calibration. If present, it will be used to by "
"quantized fullyconnected op to calculate primitive scale");
}
};

struct MKLDNNFCFullParam {
FullyConnectedParam default_param;
MKLDNNFCParam mkldnn_param;
std::vector<float> output_scales = {0.0};
std::vector<float> requantize_scales = {0.0};
};

mkldnn::inner_product_forward::primitive_desc GetFCFwdImpl(
const MKLDNNFCFullParam &full_param, const bool is_train,
const NDArray &data, const NDArray &weight, const NDArray *bias,
const mkldnn::memory::desc &out_md);

class MKLDNNFullyConnectedForward {
public:
mkldnn::inner_product_forward::primitive_desc fwd_pd;

MKLDNNFullyConnectedForward(const MKLDNNFCFullParam &full_param, const bool is_train,
const NDArray &data, const NDArray &weight,
const NDArray *bias,
const mkldnn::memory::desc &out_md)
: fwd_pd(GetFCFwdImpl(full_param, is_train, data, weight, bias, out_md)) {}


void SetNewMem(const mkldnn::memory &data, const mkldnn::memory &weight,
const mkldnn::memory *bias, const mkldnn::memory &output);

const mkldnn::inner_product_forward &GetFwd() const {
return *fwd_;
}

private:
std::shared_ptr<mkldnn::inner_product_forward> fwd_;
std::shared_ptr<mkldnn::memory> data_;
std::shared_ptr<mkldnn::memory> weight_;
std::shared_ptr<mkldnn::memory> bias_;
std::shared_ptr<mkldnn::memory> out_;
};

typedef ParamOpSign<FullyConnectedParam> MKLDNNFullyconSignature;

MKLDNNFullyConnectedForward &GetFCFwd(
const FullyConnectedParam &param, const bool is_train,
const NDArray &data, const NDArray &weight,
const NDArray *bias, const mkldnn::memory::desc &out_md);

void MKLDNNFCFlattenData(const FullyConnectedParam &param,
const NDArray &out_data,
NDArray *in_data,
mkldnn::memory::desc *out_md);

void MKLDNNFCForward(const nnvm::NodeAttrs &attrs, const OpContext &ctx,
const std::vector<NDArray> &in_data,
const std::vector<OpReqType> &req,
const std::vector<NDArray> &out_data);

void MKLDNNFCForwardFullFeature(const MKLDNNFCFullParam &param,
const OpContext &ctx,
MKLDNNFullyConnectedForward *fwd,
const std::vector<NDArray> &in_data,
const std::vector<OpReqType> &req,
const std::vector<NDArray> &out_data);

} // namespace op
} // namespace mxnet

#endif // MXNET_USE_MKLDNN == 1
#endif // MXNET_OPERATOR_NN_MKLDNN_MKLDNN_FULLY_CONNECTED_INL_H_
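The header above only declares the shared MKL-DNN FC plumbing; the quantized operator itself is exposed through the existing contrib symbol. A rough sketch of exercising it imperatively, in the spirit of the quantized FC unit tests the commit message mentions, assuming an MKL-DNN build of MXNet; shapes, min/max ranges, and the uint8-data/int8-weight convention shown here are illustrative assumptions, not taken from the diff:

```python
import mxnet as mx
import numpy as np

batch, in_dim, num_hidden = 2, 4, 3
# Assumed layout: uint8 activations, int8 weights and bias.
data   = mx.nd.array(np.random.randint(0, 128, (batch, in_dim)), dtype='uint8')
weight = mx.nd.array(np.random.randint(-127, 128, (num_hidden, in_dim)), dtype='int8')
bias   = mx.nd.array(np.random.randint(-127, 128, (num_hidden,)), dtype='int8')

out, min_out, max_out = mx.nd.contrib.quantized_fully_connected(
    data, weight, bias,
    mx.nd.array([0.0]),  mx.nd.array([63.5]),   # min/max of data (illustrative)
    mx.nd.array([-0.5]), mx.nd.array([0.5]),    # min/max of weight
    mx.nd.array([-0.5]), mx.nd.array([0.5]),    # min/max of bias
    num_hidden=num_hidden)

print(out.dtype, out.shape)  # int32 output plus its min/max range
```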