This repository has been archived by the owner on Nov 17, 2023. It is now read-only.
MKLDNN based Quantized FullyConnected Operator and its fusion #14128
Merged
+1,679 −224
30 commits
f034b86 add MKL-DNN quantized innerproduct (ciyongch)
4ee837e initial qfc with mkldnn (ciyongch)
006d2d8 Add MKL-DNN quantized_fully_connected (ciyongch)
88e3a89 refactor params order for fullyconnected (ciyongch)
e188f63 update quantized_fully_connected unittest, force data to uint8 type t… (ciyongch)
af4132d change mkl based quantized fully_connected to FCompute (ciyongch)
989901c add check data type for mkldnn quantized_fc (ciyongch)
9b3d96c add fuse requantize and dequantize for mkldnn quantized fullyconnected (ciyongch)
314a667 add env setting for enable/disable fuse requantize/dequantize for qua… (ciyongch)
6df747e fix requantize scaling error (ciyongch)
6d4883a add fallback when input data is int8 (ciyongch)
bd6f313 fix mkl quantized fullyconnected index error (ciyongch)
cb0bcfa update quantized fc test cases (ciyongch)
ce44bd6 add subgraph node for mkldnn fullyconnected (ciyongch)
3532bd5 fix compiling and lint error (ciyongch)
678d555 clean and refactor code (ciyongch)
95dfffe enable quantized_fc for imagenet (ciyongch)
8a9e2f9 cleanup code (ciyongch)
d891e0b Fix StorageType error for non-mkldnn path (ciyongch)
4da6f5a fix pylint (ciyongch)
40039bd reverse BUILD_TAG for MKL IGEMM ut, remove IGEMM qfc check (ciyongch)
68be291 rename variables and refactor codes according to comments (ciyongch)
ca6a427 add subgraph qfc tests and fix shape error (ciyongch)
517f55d remove fuse_requantize and change fuse_dequantize to enable_float_out… (ciyongch)
16f0b07 change to use mxnet::Tuple and update tests (ciyongch)
bb8a294 update description in file header (ciyongch)
9ec3cf9 update input0 type check for quantized FullyConnected (ciyongch)
b8edfb5 fix conflit of mkl/test_subgraph.py (ciyongch)
145f454 retrigger CI (ciyongch)
35a711a retrigger CI due to hang (ciyongch)
@@ -0,0 +1,133 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * Copyright (c) 2019 by Contributors
 * \file mkldnn_fully_connected-inl.h
 * \brief Common functions used by MKLDNN (Quantized) FullyConnected operator
 * \author Ciyong Chen
 */

#ifndef MXNET_OPERATOR_NN_MKLDNN_MKLDNN_FULLY_CONNECTED_INL_H_
#define MXNET_OPERATOR_NN_MKLDNN_MKLDNN_FULLY_CONNECTED_INL_H_

#if MXNET_USE_MKLDNN == 1

#include <vector>
#include <string>
#include "../fully_connected-inl.h"
#include "./mkldnn_base-inl.h"

namespace mxnet {
namespace op {

struct MKLDNNFCParam: public dmlc::Parameter<MKLDNNFCParam> {
  bool quantized;
  bool enable_float_output;
  bool with_relu;
  dmlc::optional<float> min_calib_range;  // min float value calculated from calibration dataset
  dmlc::optional<float> max_calib_range;  // max float value calculated from calibration dataset

  DMLC_DECLARE_PARAMETER(MKLDNNFCParam) {
    DMLC_DECLARE_FIELD(quantized).set_default(false)
    .describe("Whether it's a quantized FullyConnected operator");
    DMLC_DECLARE_FIELD(enable_float_output).set_default(false)
    .describe("Whether to enable float32 output");
    DMLC_DECLARE_FIELD(with_relu).set_default(false)
    .describe("Whether there's a post relu after FullyConnected operator");
    DMLC_DECLARE_FIELD(min_calib_range)
    .set_default(dmlc::optional<float>())
    .describe("The minimum scalar value in the form of float32 obtained "
              "through calibration. If present, it will be used to by "
              "quantized fullyconnected op to calculate primitive scale");
    DMLC_DECLARE_FIELD(max_calib_range)
    .set_default(dmlc::optional<float>())
    .describe("The maximum scalar value in the form of float32 obtained "
              "through calibration. If present, it will be used to by "
              "quantized fullyconnected op to calculate primitive scale");
  }
};

struct MKLDNNFCFullParam {
  FullyConnectedParam default_param;
  MKLDNNFCParam mkldnn_param;
  std::vector<float> output_scales = {0.0};
  std::vector<float> requantize_scales = {0.0};
};

mkldnn::inner_product_forward::primitive_desc GetFCFwdImpl(
    const MKLDNNFCFullParam &full_param, const bool is_train,
    const NDArray &data, const NDArray &weight, const NDArray *bias,
    const mkldnn::memory::desc &out_md);

class MKLDNNFullyConnectedForward {
 public:
  mkldnn::inner_product_forward::primitive_desc fwd_pd;

  MKLDNNFullyConnectedForward(const MKLDNNFCFullParam &full_param, const bool is_train,
                              const NDArray &data, const NDArray &weight,
                              const NDArray *bias,
                              const mkldnn::memory::desc &out_md)
      : fwd_pd(GetFCFwdImpl(full_param, is_train, data, weight, bias, out_md)) {}


  void SetNewMem(const mkldnn::memory &data, const mkldnn::memory &weight,
                 const mkldnn::memory *bias, const mkldnn::memory &output);

  const mkldnn::inner_product_forward &GetFwd() const {
    return *fwd_;
  }

 private:
  std::shared_ptr<mkldnn::inner_product_forward> fwd_;
  std::shared_ptr<mkldnn::memory> data_;
  std::shared_ptr<mkldnn::memory> weight_;
  std::shared_ptr<mkldnn::memory> bias_;
  std::shared_ptr<mkldnn::memory> out_;
};

typedef ParamOpSign<FullyConnectedParam> MKLDNNFullyconSignature;

MKLDNNFullyConnectedForward &GetFCFwd(
    const FullyConnectedParam &param, const bool is_train,
    const NDArray &data, const NDArray &weight,
    const NDArray *bias, const mkldnn::memory::desc &out_md);

void MKLDNNFCFlattenData(const FullyConnectedParam &param,
                         const NDArray &out_data,
                         NDArray *in_data,
                         mkldnn::memory::desc *out_md);

void MKLDNNFCForward(const nnvm::NodeAttrs &attrs, const OpContext &ctx,
                     const std::vector<NDArray> &in_data,
                     const std::vector<OpReqType> &req,
                     const std::vector<NDArray> &out_data);

void MKLDNNFCForwardFullFeature(const MKLDNNFCFullParam &param,
                                const OpContext &ctx,
                                MKLDNNFullyConnectedForward *fwd,
                                const std::vector<NDArray> &in_data,
                                const std::vector<OpReqType> &req,
                                const std::vector<NDArray> &out_data);

}  // namespace op
}  // namespace mxnet

#endif  // MXNET_USE_MKLDNN == 1
#endif  // MXNET_OPERATOR_NN_MKLDNN_MKLDNN_FULLY_CONNECTED_INL_H_
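
The `min_calib_range` / `max_calib_range` fields above are described as inputs used "to calculate primitive scale", but this header does not show the calculation itself. The sketch below illustrates one common symmetric int8 scheme for turning a calibrated float range into a scale. It is an illustration under stated assumptions, not the code from this PR: `OptionalFloat`, `Int8Scale`, `QuantizeInt8`, `CalibratedOutputScale`, and the constant `kInt8Range` are hypothetical names, `OptionalFloat` is only a stand-in for `dmlc::optional<float>`, and the formula `127 / max(|min|, |max|)` is assumed rather than taken from the diff.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstdint>

// Hypothetical stand-in for dmlc::optional<float>, so the sketch has no
// dependency on dmlc-core.
struct OptionalFloat {
  bool has_value;
  float value;
};

// Assumption: symmetric int8 quantization uses the range [-127, 127].
const float kInt8Range = 127.0f;

// Scale that maps floats in [min_range, max_range] onto the int8 grid.
// A degenerate all-zero range falls back to a scale of 1 (a choice made
// for this sketch only).
inline float Int8Scale(float min_range, float max_range) {
  const float abs_max = std::max(std::fabs(min_range), std::fabs(max_range));
  return abs_max > 0.0f ? kInt8Range / abs_max : 1.0f;
}

// Quantize one value: multiply by the scale, round to nearest, then
// saturate to [-127, 127].
inline int8_t QuantizeInt8(float value, float scale) {
  const float scaled = std::round(value * scale);
  const float clamped = std::max(-kInt8Range, std::min(kInt8Range, scaled));
  return static_cast<int8_t>(clamped);
}

// Produce an output scale only when both calibration bounds are present,
// mirroring the "If present, ..." wording of the parameter descriptions.
inline OptionalFloat CalibratedOutputScale(const OptionalFloat &min_calib,
                                           const OptionalFloat &max_calib) {
  OptionalFloat result = {false, 0.0f};
  if (min_calib.has_value && max_calib.has_value) {
    result.has_value = true;
    result.value = Int8Scale(min_calib.value, max_calib.value);
  }
  return result;
}
```

With a calibrated range of `[-2.0, 1.0]`, `Int8Scale` gives `127 / 2 = 63.5`, so `2.0` quantizes to `127` and values outside the calibrated range saturate. When either bound is absent, `CalibratedOutputScale` reports no value, which in this sketch would correspond to the operator having to obtain the output range some other way.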
Ummm, it seems `randint` needs to be extended to support `dtype='int8'`.

Yes, the `int8` dtype is limited by the current `randint`.