diff --git a/src/operator/quantization/quantize_graph_pass.cc b/src/operator/quantization/quantize_graph_pass.cc index 624ad047e061..7e1a9288c99b 100644 --- a/src/operator/quantization/quantize_graph_pass.cc +++ b/src/operator/quantization/quantize_graph_pass.cc @@ -287,6 +287,7 @@ Graph QuantizeGraph(Graph &&src) { static const auto& need_requantize_map = Op::GetAttr("FNeedRequantize"); static const auto& avoid_quantize_input_map = Op::GetAttr("FAvoidQuantizeInput"); + static const auto& flist_inputs = nnvm::Op::GetAttr("FListInputNames"); const auto offline_params = src.GetAttr>("offline_params"); const auto quantized_dtype = src.GetAttr("quantized_dtype"); const auto quantize_granularity = src.GetAttr("quantize_granularity"); @@ -346,7 +347,13 @@ Graph QuantizeGraph(Graph &&src) { std::string name = GetOutputName(e.node.get(), e.index); suffix = "_" + name; } else if (!offline_params.count(new_name)) { - new_name = node->attrs.name + "_" + e.node->attrs.name; + std::string input_name; + if (flist_inputs.count(node->op())) { + input_name = flist_inputs[node->op()](node->attrs)[i]; + new_name = node->attrs.name + "_" + input_name; + } else { + new_name = node->attrs.name + "_" + e.node->attrs.name; + } } ObjectPtr quantize_node = InsertNode("_contrib_quantize_v2", @@ -504,20 +511,33 @@ Graph QuantizeGraph(Graph &&src) { static const auto& need_calib_output_map = Op::GetAttr("FNeedCalibrateOutput"); - std::stack calib_variables; + std::unordered_set calib_variables; std::vector calib_nodes; DFSVisit(ret.outputs, [&](const ObjectPtr& node) { if (node->op() && !calib_variables.empty()) { - if (reverse_mirror_map.count(node)) { - const std::string& var_name = calib_variables.top(); - const auto& fp32_in_node = reverse_mirror_map[node]; - for (const auto &input_node : fp32_in_node->inputs) { - if (var_name == input_node.node->attrs.name) { - calib_nodes.push_back(fp32_in_node->attrs.name + "_" + var_name); - calib_variables.pop(); - break; + // find nodes where input is variable node + // and add proper input_name to calib_nodes + for (int i = 0; i < node->inputs.size(); i++) { + const auto &input_node = node->inputs[i]; + if (calib_variables.find(input_node.node) != std::end(calib_variables)) { + auto fp32_node = std::find_if(std::begin(quantized_node_map), + std::end(quantized_node_map), + [&](const std::pair &pair) { + return pair.second == node; + }); + if (fp32_node != std::end(quantized_node_map)) { + const auto& fp32_in_node = fp32_node->first; + std::string node_input_name; + if (flist_inputs.count(fp32_in_node->op())) { + std::string op_input_name = flist_inputs[fp32_in_node->op()](fp32_in_node->attrs)[i]; + node_input_name = fp32_in_node->attrs.name + "_" + op_input_name; + } else { + node_input_name = fp32_in_node->attrs.name + "_" + input_node.node->attrs.name; } + calib_nodes.push_back(node_input_name); + calib_variables.erase(input_node.node); } + } } } if (need_calib_input_map.count(node->op())) { @@ -530,10 +550,13 @@ Graph QuantizeGraph(Graph &&src) { } else { const auto& e = node->inputs[idx]; if (e.node->is_variable()) { - // monitor callback join operator name and variable name as observable node, - // utilize fact that we're using DFS and put variable name on stack to - // find operator node name for this variable node - calib_variables.emplace(e.node->attrs.name); + // monitor callback join operator name and variable name as observable node name, + // instead of using variable output we can use op node input + // + // data_output/fc_input + // e.g. data (var.) ----------------------> FC (op) + // remember current node and compare with inputs of next nodes + calib_variables.insert(node); } else { if (reverse_mirror_map.count(e.node)) { const auto& fp32_in_node = reverse_mirror_map.at(e.node); diff --git a/src/operator/quantization/quantized_elemwise_mul.cc b/src/operator/quantization/quantized_elemwise_mul.cc index 6d112af34418..41727d9f0ed0 100644 --- a/src/operator/quantization/quantized_elemwise_mul.cc +++ b/src/operator/quantization/quantized_elemwise_mul.cc @@ -32,14 +32,6 @@ namespace op { DMLC_REGISTER_PARAMETER(QuantizeElemwiseMulParam); -static std::vector QuantizedElemwiseMulOutputNames(const NodeAttrs &attrs) { - const QuantizeElemwiseMulParam& params = nnvm::get(attrs.parsed); - if (params.enable_float_output) - return std::vector{"output"}; - else - return std::vector{"output", "min_output", "max_output"}; -} - inline bool QuantizedElemwiseMulOpShape(const nnvm::NodeAttrs& attrs, mxnet::ShapeVector *in_attrs, mxnet::ShapeVector *out_attrs) { @@ -229,7 +221,6 @@ NNVM_REGISTER_OP(_contrib_quantized_elemwise_mul) [](const NodeAttrs& attrs) { return std::vector{"lhs", "rhs", "lhs_min", "lhs_max", "rhs_min", "rhs_max"}; }) -.set_attr("FListOutputNames", QuantizedElemwiseMulOutputNames) .set_attr("FInferShape", QuantizedElemwiseMulOpShape) .set_attr("FInferType", QuantizedElemwiseMulOpType) .set_attr("FInferStorageType", QuantizedElemwiseMulOpStorageType) diff --git a/src/operator/subgraph/mkldnn/mkldnn_elemwisemul_post_quantize_property.h b/src/operator/subgraph/mkldnn/mkldnn_elemwisemul_post_quantize_property.h index ad816755a8dd..e551cfd40117 100644 --- a/src/operator/subgraph/mkldnn/mkldnn_elemwisemul_post_quantize_property.h +++ b/src/operator/subgraph/mkldnn/mkldnn_elemwisemul_post_quantize_property.h @@ -28,6 +28,7 @@ #define MXNET_OPERATOR_SUBGRAPH_MKLDNN_MKLDNN_ELEMWISEMUL_POST_QUANTIZE_PROPERTY_H_ #if MXNET_USE_ONEDNN == 1 +#include #include #include #include "../../tensor/elemwise_binary_op-inl.h" @@ -40,7 +41,7 @@ namespace op { #define QUANTIZED_ElemwiseMul_NAME "_contrib_quantized_elemwise_mul" -class ElemwiseMulPostQuantizeSelector : public SubgraphSelector { +class ElemwiseMulPostQuantizeSelector : public SubgraphSelectorV2 { public: /*! \brief pattern match status */ enum SelectStatus { @@ -54,7 +55,7 @@ class ElemwiseMulPostQuantizeSelector : public SubgraphSelector { bool disable_all; bool disable_float_output; SelectStatus status; - std::vector matched_list; + std::vector matched_list; public: explicit ElemwiseMulPostQuantizeSelector(const bool dis_all, @@ -62,8 +63,9 @@ class ElemwiseMulPostQuantizeSelector : public SubgraphSelector { : disable_all(dis_all), disable_float_output(dis_float_output) {} - bool Select(const nnvm::Node &n) override { - if ((!disable_all) && n.op() == Op::Get(QUANTIZED_ElemwiseMul_NAME)) { + bool Select(const BiDirectedNode &n) override { + const auto rawnode = n.node; + if ((!disable_all) && rawnode->op() == Op::Get(QUANTIZED_ElemwiseMul_NAME)) { status = disable_all ? kSuccess : kStart; matched_list.clear(); matched_list.push_back(&n); @@ -72,12 +74,14 @@ class ElemwiseMulPostQuantizeSelector : public SubgraphSelector { return false; } - bool SelectInput(const nnvm::Node &n, const nnvm::Node &new_node) override { + bool SelectInput(const BiDirectedNode &n, const BiDirectedNode &new_node) override { return false; } - bool SelectOutput(const nnvm::Node &n, const nnvm::Node &new_node) override { - if (status == kFail || status == kSuccess || new_node.is_variable()) + bool SelectOutput(const BiDirectedNode &n, const BiDirectedNode &new_node) override { + const auto raw_node = n.node; + const auto raw_new_node = new_node.node; + if (status == kFail || status == kSuccess || raw_new_node->is_variable()) return false; // If n isn't the last matched node, then we encoutered a internal // branch, we should pop out the node behind n and stop fusion. @@ -95,8 +99,8 @@ class ElemwiseMulPostQuantizeSelector : public SubgraphSelector { switch (status) { case kStart: - if (new_node.op() == Op::Get("_contrib_requantize")) { - auto const ¶m = nnvm::get(new_node.attrs.parsed); + if (raw_new_node->op() == Op::Get("_contrib_requantize")) { + auto const ¶m = nnvm::get(raw_new_node->attrs.parsed); if (param.min_calib_range.has_value() && param.max_calib_range.has_value()) { matched_list.push_back(&new_node); @@ -105,7 +109,20 @@ class ElemwiseMulPostQuantizeSelector : public SubgraphSelector { } } case kRequantize: - if ((!disable_float_output) && (new_node.op() == Op::Get("_contrib_dequantize"))) { + if ((!disable_float_output) && (raw_new_node->op() == Op::Get("_contrib_dequantize"))) { + CHECK(raw_node->op() == Op::Get("_contrib_requantize")); + if (n.outputs.size() > 1) { + // check if requantize have other outputs than dequantize + // if it has we can't fuse dequantize into elemwise_mul + for (auto kv : n.outputs) { + const auto& node = kv.first; + if (node->op() != Op::Get("_contrib_dequantize")) { + status = kSuccess; + return false; + } + } + } + matched_list.push_back(&new_node); status = kSuccess; return true; @@ -116,14 +133,14 @@ class ElemwiseMulPostQuantizeSelector : public SubgraphSelector { } } - std::vector Filter( - const std::vector &candidates) override { + std::vector Filter( + const std::vector& candidates) override { if ((status != kSuccess) || (matched_list.size() <= 1)) { - return std::vector(0); + return std::vector(0); } else { - std::vector ret; + std::vector ret; for (auto i : matched_list) { - auto non_const_i = const_cast(i); + auto non_const_i = const_cast(i); if (std::find(candidates.begin(), candidates.end(), non_const_i) != candidates.end()) { ret.push_back(non_const_i); @@ -194,7 +211,7 @@ class ElemwiseMulPostQuantizeProperty : public SubgraphProperty { return em_node; } - SubgraphSelectorPtr CreateSubgraphSelector() const override { + SubgraphSelectorV2Ptr CreateSubgraphSelectorV2() const override { auto selector = std::make_shared(disable_fuse_all, disable_float_output); diff --git a/src/operator/subgraph/mkldnn/mkldnn_fc_post_quantize_property.h b/src/operator/subgraph/mkldnn/mkldnn_fc_post_quantize_property.h index 3404fdb4478a..72e6dc553679 100644 --- a/src/operator/subgraph/mkldnn/mkldnn_fc_post_quantize_property.h +++ b/src/operator/subgraph/mkldnn/mkldnn_fc_post_quantize_property.h @@ -28,6 +28,7 @@ #define MXNET_OPERATOR_SUBGRAPH_MKLDNN_MKLDNN_FC_POST_QUANTIZE_PROPERTY_H_ #if MXNET_USE_ONEDNN == 1 +#include #include #include #include "../../nn/fully_connected-inl.h" @@ -40,7 +41,7 @@ namespace op { #define QUANTIZED_FC_NAME "_sg_mkldnn_fully_connected" -class SgMKLDNNFCPostQuantizeSelector : public SubgraphSelector { +class SgMKLDNNFCPostQuantizeSelector : public SubgraphSelectorV2 { public: /*! \brief pattern match status */ enum SelectStatus { @@ -54,7 +55,7 @@ class SgMKLDNNFCPostQuantizeSelector : public SubgraphSelector { bool disable_all; bool disable_float_output; SelectStatus status; - std::vector matched_list; + std::vector matched_list; public: explicit SgMKLDNNFCPostQuantizeSelector(const bool dis_all, @@ -62,8 +63,9 @@ class SgMKLDNNFCPostQuantizeSelector : public SubgraphSelector { : disable_all(dis_all), disable_float_output(dis_float_output) {} - bool Select(const nnvm::Node &n) override { - if ((!disable_all) && n.op() == Op::Get(QUANTIZED_FC_NAME)) { + bool Select(const BiDirectedNode &n) override { + const auto rawnode = n.node; + if ((!disable_all) && rawnode->op() == Op::Get(QUANTIZED_FC_NAME)) { status = disable_all ? kSuccess : kStart; matched_list.clear(); matched_list.push_back(&n); @@ -72,12 +74,14 @@ class SgMKLDNNFCPostQuantizeSelector : public SubgraphSelector { return false; } - bool SelectInput(const nnvm::Node &n, const nnvm::Node &new_node) override { + bool SelectInput(const BiDirectedNode &n, const BiDirectedNode &new_node) override { return false; } - bool SelectOutput(const nnvm::Node &n, const nnvm::Node &new_node) override { - if (status == kFail || status == kSuccess || new_node.is_variable()) + bool SelectOutput(const BiDirectedNode &n, const BiDirectedNode &new_node) override { + const auto raw_node = n.node; + const auto raw_new_node = new_node.node; + if (status == kFail || status == kSuccess || raw_new_node->is_variable()) return false; // If n isn't the last matched node, then we encoutered a internal // branch, we should pop out the node behind n and stop fusion. @@ -95,8 +99,8 @@ class SgMKLDNNFCPostQuantizeSelector : public SubgraphSelector { switch (status) { case kStart: - if (new_node.op() == Op::Get("_contrib_requantize")) { - auto const ¶m = nnvm::get(new_node.attrs.parsed); + if (raw_new_node->op() == Op::Get("_contrib_requantize")) { + auto const ¶m = nnvm::get(raw_new_node->attrs.parsed); if (param.min_calib_range.has_value() && param.max_calib_range.has_value()) { matched_list.push_back(&new_node); @@ -105,7 +109,19 @@ class SgMKLDNNFCPostQuantizeSelector : public SubgraphSelector { } } case kRequantize: - if ((!disable_float_output) && (new_node.op() == Op::Get("_contrib_dequantize"))) { + if ((!disable_float_output) && (raw_new_node->op() == Op::Get("_contrib_dequantize"))) { + CHECK(raw_node->op() == Op::Get("_contrib_requantize")); + if (n.outputs.size() > 1) { + // check if requantize have other outputs than dequantize + // if it has we can't fuse dequantize into FC + for (auto kv : n.outputs) { + const auto& node = kv.first; + if (node->op() != Op::Get("_contrib_dequantize")) { + status = kSuccess; + return false; + } + } + } matched_list.push_back(&new_node); status = kSuccess; return true; @@ -116,14 +132,14 @@ class SgMKLDNNFCPostQuantizeSelector : public SubgraphSelector { } } - std::vector Filter( - const std::vector &candidates) override { + std::vector Filter( + const std::vector& candidates) override { if ((status != kSuccess) || (matched_list.size() <= 1)) { - return std::vector(0); + return std::vector(0); } else { - std::vector ret; + std::vector ret; for (auto i : matched_list) { - auto non_const_i = const_cast(i); + auto non_const_i = const_cast(i); if (std::find(candidates.begin(), candidates.end(), non_const_i) != candidates.end()) { ret.push_back(non_const_i); @@ -194,7 +210,7 @@ class SgMKLDNNFCPostQuantizeProperty : public SubgraphProperty { return fc_node; } - SubgraphSelectorPtr CreateSubgraphSelector() const override { + SubgraphSelectorV2Ptr CreateSubgraphSelectorV2() const override { auto selector = std::make_shared(disable_fuse_all, disable_float_output); diff --git a/src/operator/subgraph/mkldnn/mkldnn_fc_property.h b/src/operator/subgraph/mkldnn/mkldnn_fc_property.h index 9a0c7770a22d..756b69df87d2 100644 --- a/src/operator/subgraph/mkldnn/mkldnn_fc_property.h +++ b/src/operator/subgraph/mkldnn/mkldnn_fc_property.h @@ -40,7 +40,7 @@ namespace op { class SgMKLDNNFCSelector : public SubgraphSelector { public: - /*! \brief pattern match status */ + /* pattern match status */ enum SelectStatus { kFail = 0, kStart, diff --git a/tests/python/mkl/subgraphs/test_fc_subgraph.py b/tests/python/mkl/subgraphs/test_fc_subgraph.py index 1bcd332e3b8c..6351bfe0bbeb 100644 --- a/tests/python/mkl/subgraphs/test_fc_subgraph.py +++ b/tests/python/mkl/subgraphs/test_fc_subgraph.py @@ -173,3 +173,27 @@ def infer_shape(self, x, *args): out_quantized = qnet(data_nd) assert_almost_equal_with_err(out.asnumpy(), out_quantized.asnumpy(), rtol=1e-2, atol=1e-2, etol=0.01) + + +@pytest.mark.parametrize('data_shape', DATA_SHAPE) +def test_fc_int8_and_fp32_outputs(data_shape): + +# /---> Quantizable op +# Input ---> FC -| +# \---> Non quantizable op + + class MultiOutputFC(nn.HybridBlock): + def __init__(self, **kwargs): + super(MultiOutputFC, self).__init__(**kwargs) + self.dense0 = nn.Dense(64) + self.dense1 = nn.Dense(64) + + def hybrid_forward(self, F, x): + x = self.dense0(x) + y = self.dense1(x) # quantizable + z = F.softmax(x) # non quantizable + return y + z + + attrs = {'fc': {}} + net = MultiOutputFC() + check_fusion(net, data_shape, attrs, check_quantization=True)