Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
[BUGFIX] Fix quantized_op + requantize + dequantize fuse (#20323)
Browse files Browse the repository at this point in the history
* Fix quantized_op + requantize + dequantize fuse pass
Fix for elemwise_mul and FC
Fix passing calibration layers required
  • Loading branch information
bgawrych committed Jun 28, 2021
1 parent dc69b04 commit 835e250
Show file tree
Hide file tree
Showing 6 changed files with 127 additions and 56 deletions.
51 changes: 37 additions & 14 deletions src/operator/quantization/quantize_graph_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -287,6 +287,7 @@ Graph QuantizeGraph(Graph &&src) {
static const auto& need_requantize_map = Op::GetAttr<mxnet::FNeedRequantize>("FNeedRequantize");
static const auto& avoid_quantize_input_map =
Op::GetAttr<mxnet::FAvoidQuantizeInput>("FAvoidQuantizeInput");
static const auto& flist_inputs = nnvm::Op::GetAttr<nnvm::FListOutputNames>("FListInputNames");
const auto offline_params = src.GetAttr<std::unordered_set<std::string>>("offline_params");
const auto quantized_dtype = src.GetAttr<std::string>("quantized_dtype");
const auto quantize_granularity = src.GetAttr<std::string>("quantize_granularity");
Expand Down Expand Up @@ -346,7 +347,13 @@ Graph QuantizeGraph(Graph &&src) {
std::string name = GetOutputName(e.node.get(), e.index);
suffix = "_" + name;
} else if (!offline_params.count(new_name)) {
new_name = node->attrs.name + "_" + e.node->attrs.name;
std::string input_name;
if (flist_inputs.count(node->op())) {
input_name = flist_inputs[node->op()](node->attrs)[i];
new_name = node->attrs.name + "_" + input_name;
} else {
new_name = node->attrs.name + "_" + e.node->attrs.name;
}
}

ObjectPtr quantize_node = InsertNode("_contrib_quantize_v2",
Expand Down Expand Up @@ -504,20 +511,33 @@ Graph QuantizeGraph(Graph &&src) {
static const auto& need_calib_output_map =
Op::GetAttr<mxnet::FNeedCalibrateOutput>("FNeedCalibrateOutput");

std::stack<std::string> calib_variables;
std::unordered_set<nnvm::ObjectPtr> calib_variables;
std::vector<std::string> calib_nodes;
DFSVisit(ret.outputs, [&](const ObjectPtr& node) {
if (node->op() && !calib_variables.empty()) {
if (reverse_mirror_map.count(node)) {
const std::string& var_name = calib_variables.top();
const auto& fp32_in_node = reverse_mirror_map[node];
for (const auto &input_node : fp32_in_node->inputs) {
if (var_name == input_node.node->attrs.name) {
calib_nodes.push_back(fp32_in_node->attrs.name + "_" + var_name);
calib_variables.pop();
break;
// find nodes where input is variable node
// and add proper input_name to calib_nodes
for (int i = 0; i < node->inputs.size(); i++) {
const auto &input_node = node->inputs[i];
if (calib_variables.find(input_node.node) != std::end(calib_variables)) {
auto fp32_node = std::find_if(std::begin(quantized_node_map),
std::end(quantized_node_map),
[&](const std::pair<ObjectPtr, ObjectPtr> &pair) {
return pair.second == node;
});
if (fp32_node != std::end(quantized_node_map)) {
const auto& fp32_in_node = fp32_node->first;
std::string node_input_name;
if (flist_inputs.count(fp32_in_node->op())) {
std::string op_input_name = flist_inputs[fp32_in_node->op()](fp32_in_node->attrs)[i];
node_input_name = fp32_in_node->attrs.name + "_" + op_input_name;
} else {
node_input_name = fp32_in_node->attrs.name + "_" + input_node.node->attrs.name;
}
calib_nodes.push_back(node_input_name);
calib_variables.erase(input_node.node);
}
}
}
}
if (need_calib_input_map.count(node->op())) {
Expand All @@ -530,10 +550,13 @@ Graph QuantizeGraph(Graph &&src) {
} else {
const auto& e = node->inputs[idx];
if (e.node->is_variable()) {
// monitor callback join operator name and variable name as observable node,
// utilize fact that we're using DFS and put variable name on stack to
// find operator node name for this variable node
calib_variables.emplace(e.node->attrs.name);
// monitor callback join operator name and variable name as observable node name,
// instead of using variable output we can use op node input
//
// data_output/fc_input
// e.g. data (var.) ----------------------> FC (op)
// remember current node and compare with inputs of next nodes
calib_variables.insert(node);
} else {
if (reverse_mirror_map.count(e.node)) {
const auto& fp32_in_node = reverse_mirror_map.at(e.node);
Expand Down
9 changes: 0 additions & 9 deletions src/operator/quantization/quantized_elemwise_mul.cc
Original file line number Diff line number Diff line change
Expand Up @@ -32,14 +32,6 @@ namespace op {

DMLC_REGISTER_PARAMETER(QuantizeElemwiseMulParam);

static std::vector<std::string> QuantizedElemwiseMulOutputNames(const NodeAttrs &attrs) {
const QuantizeElemwiseMulParam& params = nnvm::get<QuantizeElemwiseMulParam>(attrs.parsed);
if (params.enable_float_output)
return std::vector<std::string>{"output"};
else
return std::vector<std::string>{"output", "min_output", "max_output"};
}

inline bool QuantizedElemwiseMulOpShape(const nnvm::NodeAttrs& attrs,
mxnet::ShapeVector *in_attrs,
mxnet::ShapeVector *out_attrs) {
Expand Down Expand Up @@ -229,7 +221,6 @@ NNVM_REGISTER_OP(_contrib_quantized_elemwise_mul)
[](const NodeAttrs& attrs) {
return std::vector<std::string>{"lhs", "rhs", "lhs_min", "lhs_max", "rhs_min", "rhs_max"};
})
.set_attr<nnvm::FListOutputNames>("FListOutputNames", QuantizedElemwiseMulOutputNames)
.set_attr<mxnet::FInferShape>("FInferShape", QuantizedElemwiseMulOpShape)
.set_attr<nnvm::FInferType>("FInferType", QuantizedElemwiseMulOpType)
.set_attr<FInferStorageType>("FInferStorageType", QuantizedElemwiseMulOpStorageType)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
#define MXNET_OPERATOR_SUBGRAPH_MKLDNN_MKLDNN_ELEMWISEMUL_POST_QUANTIZE_PROPERTY_H_
#if MXNET_USE_ONEDNN == 1

#include <memory>
#include <string>
#include <vector>
#include "../../tensor/elemwise_binary_op-inl.h"
Expand All @@ -40,7 +41,7 @@ namespace op {

#define QUANTIZED_ElemwiseMul_NAME "_contrib_quantized_elemwise_mul"

class ElemwiseMulPostQuantizeSelector : public SubgraphSelector {
class ElemwiseMulPostQuantizeSelector : public SubgraphSelectorV2 {
public:
/*! \brief pattern match status */
enum SelectStatus {
Expand All @@ -54,16 +55,17 @@ class ElemwiseMulPostQuantizeSelector : public SubgraphSelector {
bool disable_all;
bool disable_float_output;
SelectStatus status;
std::vector<const nnvm::Node *> matched_list;
std::vector<const BiDirectedNode *> matched_list;

public:
explicit ElemwiseMulPostQuantizeSelector(const bool dis_all,
const bool dis_float_output)
: disable_all(dis_all),
disable_float_output(dis_float_output) {}

bool Select(const nnvm::Node &n) override {
if ((!disable_all) && n.op() == Op::Get(QUANTIZED_ElemwiseMul_NAME)) {
bool Select(const BiDirectedNode &n) override {
const auto rawnode = n.node;
if ((!disable_all) && rawnode->op() == Op::Get(QUANTIZED_ElemwiseMul_NAME)) {
status = disable_all ? kSuccess : kStart;
matched_list.clear();
matched_list.push_back(&n);
Expand All @@ -72,12 +74,14 @@ class ElemwiseMulPostQuantizeSelector : public SubgraphSelector {
return false;
}

bool SelectInput(const nnvm::Node &n, const nnvm::Node &new_node) override {
// Inputs are never pulled into the subgraph: the fuse pattern grows only
// through outputs (see SelectOutput), starting from the quantized op.
bool SelectInput(const BiDirectedNode &n, const BiDirectedNode &new_node) override {
return false;
}

bool SelectOutput(const nnvm::Node &n, const nnvm::Node &new_node) override {
if (status == kFail || status == kSuccess || new_node.is_variable())
bool SelectOutput(const BiDirectedNode &n, const BiDirectedNode &new_node) override {
const auto raw_node = n.node;
const auto raw_new_node = new_node.node;
if (status == kFail || status == kSuccess || raw_new_node->is_variable())
return false;
// If n isn't the last matched node, then we encountered an internal
// branch, we should pop out the node behind n and stop fusion.
Expand All @@ -95,8 +99,8 @@ class ElemwiseMulPostQuantizeSelector : public SubgraphSelector {

switch (status) {
case kStart:
if (new_node.op() == Op::Get("_contrib_requantize")) {
auto const &param = nnvm::get<RequantizeParam>(new_node.attrs.parsed);
if (raw_new_node->op() == Op::Get("_contrib_requantize")) {
auto const &param = nnvm::get<RequantizeParam>(raw_new_node->attrs.parsed);
if (param.min_calib_range.has_value() &&
param.max_calib_range.has_value()) {
matched_list.push_back(&new_node);
Expand All @@ -105,7 +109,20 @@ class ElemwiseMulPostQuantizeSelector : public SubgraphSelector {
}
}
case kRequantize:
if ((!disable_float_output) && (new_node.op() == Op::Get("_contrib_dequantize"))) {
if ((!disable_float_output) && (raw_new_node->op() == Op::Get("_contrib_dequantize"))) {
CHECK(raw_node->op() == Op::Get("_contrib_requantize"));
if (n.outputs.size() > 1) {
// check if requantize has outputs other than dequantize;
// if it does, we can't fuse dequantize into elemwise_mul
for (auto kv : n.outputs) {
const auto& node = kv.first;
if (node->op() != Op::Get("_contrib_dequantize")) {
status = kSuccess;
return false;
}
}
}

matched_list.push_back(&new_node);
status = kSuccess;
return true;
Expand All @@ -116,14 +133,14 @@ class ElemwiseMulPostQuantizeSelector : public SubgraphSelector {
}
}

std::vector<nnvm::Node *> Filter(
const std::vector<nnvm::Node *> &candidates) override {
std::vector<BiDirectedNode *> Filter(
const std::vector<BiDirectedNode *>& candidates) override {
if ((status != kSuccess) || (matched_list.size() <= 1)) {
return std::vector<nnvm::Node *>(0);
return std::vector<BiDirectedNode *>(0);
} else {
std::vector<nnvm::Node *> ret;
std::vector<BiDirectedNode *> ret;
for (auto i : matched_list) {
auto non_const_i = const_cast<nnvm::Node *>(i);
auto non_const_i = const_cast<BiDirectedNode *>(i);
if (std::find(candidates.begin(), candidates.end(), non_const_i) !=
candidates.end()) {
ret.push_back(non_const_i);
Expand Down Expand Up @@ -194,7 +211,7 @@ class ElemwiseMulPostQuantizeProperty : public SubgraphProperty {
return em_node;
}

SubgraphSelectorPtr CreateSubgraphSelector() const override {
SubgraphSelectorV2Ptr CreateSubgraphSelectorV2() const override {
auto selector =
std::make_shared<ElemwiseMulPostQuantizeSelector>(disable_fuse_all,
disable_float_output);
Expand Down
48 changes: 32 additions & 16 deletions src/operator/subgraph/mkldnn/mkldnn_fc_post_quantize_property.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
#define MXNET_OPERATOR_SUBGRAPH_MKLDNN_MKLDNN_FC_POST_QUANTIZE_PROPERTY_H_
#if MXNET_USE_ONEDNN == 1

#include <memory>
#include <string>
#include <vector>
#include "../../nn/fully_connected-inl.h"
Expand All @@ -40,7 +41,7 @@ namespace op {

#define QUANTIZED_FC_NAME "_sg_mkldnn_fully_connected"

class SgMKLDNNFCPostQuantizeSelector : public SubgraphSelector {
class SgMKLDNNFCPostQuantizeSelector : public SubgraphSelectorV2 {
public:
/*! \brief pattern match status */
enum SelectStatus {
Expand All @@ -54,16 +55,17 @@ class SgMKLDNNFCPostQuantizeSelector : public SubgraphSelector {
bool disable_all;
bool disable_float_output;
SelectStatus status;
std::vector<const nnvm::Node *> matched_list;
std::vector<const BiDirectedNode *> matched_list;

public:
explicit SgMKLDNNFCPostQuantizeSelector(const bool dis_all,
const bool dis_float_output)
: disable_all(dis_all),
disable_float_output(dis_float_output) {}

bool Select(const nnvm::Node &n) override {
if ((!disable_all) && n.op() == Op::Get(QUANTIZED_FC_NAME)) {
bool Select(const BiDirectedNode &n) override {
const auto rawnode = n.node;
if ((!disable_all) && rawnode->op() == Op::Get(QUANTIZED_FC_NAME)) {
status = disable_all ? kSuccess : kStart;
matched_list.clear();
matched_list.push_back(&n);
Expand All @@ -72,12 +74,14 @@ class SgMKLDNNFCPostQuantizeSelector : public SubgraphSelector {
return false;
}

bool SelectInput(const nnvm::Node &n, const nnvm::Node &new_node) override {
// Inputs are never pulled into the subgraph: the fuse pattern grows only
// through outputs (see SelectOutput), starting from the quantized FC op.
bool SelectInput(const BiDirectedNode &n, const BiDirectedNode &new_node) override {
return false;
}

bool SelectOutput(const nnvm::Node &n, const nnvm::Node &new_node) override {
if (status == kFail || status == kSuccess || new_node.is_variable())
bool SelectOutput(const BiDirectedNode &n, const BiDirectedNode &new_node) override {
const auto raw_node = n.node;
const auto raw_new_node = new_node.node;
if (status == kFail || status == kSuccess || raw_new_node->is_variable())
return false;
// If n isn't the last matched node, then we encountered an internal
// branch, we should pop out the node behind n and stop fusion.
Expand All @@ -95,8 +99,8 @@ class SgMKLDNNFCPostQuantizeSelector : public SubgraphSelector {

switch (status) {
case kStart:
if (new_node.op() == Op::Get("_contrib_requantize")) {
auto const &param = nnvm::get<RequantizeParam>(new_node.attrs.parsed);
if (raw_new_node->op() == Op::Get("_contrib_requantize")) {
auto const &param = nnvm::get<RequantizeParam>(raw_new_node->attrs.parsed);
if (param.min_calib_range.has_value() &&
param.max_calib_range.has_value()) {
matched_list.push_back(&new_node);
Expand All @@ -105,7 +109,19 @@ class SgMKLDNNFCPostQuantizeSelector : public SubgraphSelector {
}
}
case kRequantize:
if ((!disable_float_output) && (new_node.op() == Op::Get("_contrib_dequantize"))) {
if ((!disable_float_output) && (raw_new_node->op() == Op::Get("_contrib_dequantize"))) {
CHECK(raw_node->op() == Op::Get("_contrib_requantize"));
if (n.outputs.size() > 1) {
// check if requantize has outputs other than dequantize;
// if it does, we can't fuse dequantize into FC
for (auto kv : n.outputs) {
const auto& node = kv.first;
if (node->op() != Op::Get("_contrib_dequantize")) {
status = kSuccess;
return false;
}
}
}
matched_list.push_back(&new_node);
status = kSuccess;
return true;
Expand All @@ -116,14 +132,14 @@ class SgMKLDNNFCPostQuantizeSelector : public SubgraphSelector {
}
}

std::vector<nnvm::Node *> Filter(
const std::vector<nnvm::Node *> &candidates) override {
std::vector<BiDirectedNode *> Filter(
const std::vector<BiDirectedNode *>& candidates) override {
if ((status != kSuccess) || (matched_list.size() <= 1)) {
return std::vector<nnvm::Node *>(0);
return std::vector<BiDirectedNode *>(0);
} else {
std::vector<nnvm::Node *> ret;
std::vector<BiDirectedNode *> ret;
for (auto i : matched_list) {
auto non_const_i = const_cast<nnvm::Node *>(i);
auto non_const_i = const_cast<BiDirectedNode *>(i);
if (std::find(candidates.begin(), candidates.end(), non_const_i) !=
candidates.end()) {
ret.push_back(non_const_i);
Expand Down Expand Up @@ -194,7 +210,7 @@ class SgMKLDNNFCPostQuantizeProperty : public SubgraphProperty {
return fc_node;
}

SubgraphSelectorPtr CreateSubgraphSelector() const override {
SubgraphSelectorV2Ptr CreateSubgraphSelectorV2() const override {
auto selector =
std::make_shared<SgMKLDNNFCPostQuantizeSelector>(disable_fuse_all,
disable_float_output);
Expand Down
2 changes: 1 addition & 1 deletion src/operator/subgraph/mkldnn/mkldnn_fc_property.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ namespace op {

class SgMKLDNNFCSelector : public SubgraphSelector {
public:
/*! \brief pattern match status */
/* pattern match status */
enum SelectStatus {
kFail = 0,
kStart,
Expand Down
24 changes: 24 additions & 0 deletions tests/python/mkl/subgraphs/test_fc_subgraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,3 +173,27 @@ def infer_shape(self, x, *args):
out_quantized = qnet(data_nd)
assert_almost_equal_with_err(out.asnumpy(), out_quantized.asnumpy(),
rtol=1e-2, atol=1e-2, etol=0.01)


@pytest.mark.parametrize('data_shape', DATA_SHAPE)
def test_fc_int8_and_fp32_outputs(data_shape):
    """Quantize an FC whose output feeds both a quantizable and a
    non-quantizable consumer.

    This exercises the requantize+dequantize fuse pass: when the requantize
    result has a consumer other than dequantize, the dequantize must not be
    fused into FC (both an int8 and an fp32 path stay live).
    """

    #                 /---> Quantizable op
    # Input ---> FC -|
    #                 \---> Non quantizable op

    class MultiOutputFC(nn.HybridBlock):
        # dense0's output branches into dense1 (quantizable) and softmax
        # (non-quantizable), producing the fan-out shown above.
        def __init__(self, **kwargs):
            super(MultiOutputFC, self).__init__(**kwargs)
            self.dense0 = nn.Dense(64)
            self.dense1 = nn.Dense(64)

        def hybrid_forward(self, F, x):
            x = self.dense0(x)
            y = self.dense1(x)  # quantizable
            z = F.softmax(x)  # non quantizable
            return y + z

    attrs = {'fc': {}}
    net = MultiOutputFC()
    check_fusion(net, data_shape, attrs, check_quantization=True)

0 comments on commit 835e250

Please sign in to comment.