From e8a657a4c3d0cbb5af3763cb0c8f5128c4c38ca3 Mon Sep 17 00:00:00 2001
From: Xinyu Chen
Date: Wed, 6 Mar 2019 15:20:17 +0800
Subject: [PATCH] Register fake grad to subgraph and quantized operators
 (#14275)

* add fake grad
* Skip inference only subgraph pass when gradient is needed.
* add fake grad to quantizev2
* add TODO
* modify prop_name to property_name
* add test case
---
 src/executor/graph_executor.cc                | 34 +++++++++---
 src/operator/quantization/dequantize.cc       |  3 ++
 src/operator/quantization/quantize.cc         |  3 ++
 src/operator/quantization/quantize_v2.cc      |  3 ++
 src/operator/quantization/quantized_concat.cc |  3 ++
 src/operator/quantization/quantized_conv.cc   |  3 ++
 .../quantization/quantized_flatten.cc         |  3 ++
 .../quantization/quantized_fully_connected.cc |  3 ++
 .../quantization/quantized_pooling.cc         |  3 ++
 src/operator/quantization/requantize.cc       |  3 ++
 src/operator/subgraph/mkldnn/mkldnn_conv.cc   |  3 ++
 .../mkldnn_conv_post_quantize_property.cc     |  6 ++-
 .../subgraph/mkldnn/mkldnn_conv_property.cc   |  5 +-
 src/operator/subgraph/subgraph_property.h     |  7 +++
 tests/python/mkl/test_subgraph.py             | 53 ++++++++++++++-----
 15 files changed, 114 insertions(+), 21 deletions(-)

diff --git a/src/executor/graph_executor.cc b/src/executor/graph_executor.cc
index ca2cea093c5d..436eae37d785 100644
--- a/src/executor/graph_executor.cc
+++ b/src/executor/graph_executor.cc
@@ -1506,8 +1506,26 @@ static nnvm::Symbol PartitionGraph(const nnvm::Symbol& src,
                                    const Context& default_ctx,
                                    const std::map<std::string, Context>& ctx_map,
                                    const std::vector<Context>& in_arg_ctxes,
-                                   const std::vector<Context>& aux_state_ctxes) {
+                                   const std::vector<Context>& aux_state_ctxes,
+                                   const std::vector<OpReqType>& grad_req_types) {
   auto subgraph_prop = op::SubgraphPropertyRegistry::Get()->CreateSubgraphProperty(prop_name);
+  bool need_grad = false;
+  for (OpReqType req : grad_req_types) {
+    if (req != kNullOp) {
+      need_grad = true;
+      break;
+    }
+  }
+  if (subgraph_prop->HasAttr("inference_only") &&
+      subgraph_prop->GetAttr<bool>("inference_only") == true) {
+    if (need_grad) {
+      auto full_name = subgraph_prop->HasAttr("property_name")
+                           ? subgraph_prop->GetAttr<std::string>("property_name")
+                           : prop_name;
+      LOG(INFO) << "Skip subgraph " << full_name << " as it requires `grad_req=null`.";
+      return src;
+    }
+  }
   nnvm::Symbol ret = src.Copy();
   nnvm::Graph g;
   g.outputs = ret.outputs;
@@ -1539,7 +1557,8 @@ static nnvm::Symbol PartitionGraph(const nnvm::Symbol& src,
                                    const Context& default_ctx,
                                    const std::map<std::string, Context>& ctx_map,
                                    const std::vector<Context>& in_arg_ctxes,
-                                   const std::vector<Context>& aux_state_ctxes) {
+                                   const std::vector<Context>& aux_state_ctxes,
+                                   const std::vector<OpReqType>& grad_req_types) {
   const std::vector<std::string> input_names = src.ListInputNames(Symbol::kAll);
   mxnet::ShapeVector arg_shapes(input_names.size(), mxnet::TShape());
   nnvm::DTypeVector arg_dtypes(input_names.size(), -1);
@@ -1559,7 +1578,7 @@ static nnvm::Symbol PartitionGraph(const nnvm::Symbol& src,
     }
   }
   return PartitionGraph(src, prop_name, arg_shapes, arg_dtypes, arg_stypes,
-                        default_ctx, ctx_map, in_arg_ctxes, aux_state_ctxes);
+                        default_ctx, ctx_map, in_arg_ctxes, aux_state_ctxes, grad_req_types);
 }
 
 // Given input ndarrays, partition the graph using the backend name equal to prop_name.
@@ -1569,7 +1588,8 @@ static nnvm::Symbol PartitionGraph(const nnvm::Symbol& src,
                                    std::vector<NDArray> *in_args,
                                    const std::vector<NDArray> &aux_states,
                                    const Context& default_ctx,
-                                   const std::map<std::string, Context>& ctx_map) {
+                                   const std::map<std::string, Context>& ctx_map,
+                                   const std::vector<OpReqType>& grad_req_types) {
   const std::vector<std::string> input_names = src.ListInputNames(Symbol::kAll);
   const std::vector<std::string> arg_names = src.ListInputNames(nnvm::Symbol::kReadOnlyArgs);
   const std::vector<std::string> aux_names = src.ListInputNames(nnvm::Symbol::kAuxiliaryStates);
@@ -1609,7 +1629,7 @@ static nnvm::Symbol PartitionGraph(const nnvm::Symbol& src,
     in_args_map[arg_names[i]] = in_args->at(i);
   }
   auto result = PartitionGraph(src, prop_name, arg_shapes, arg_dtypes, arg_stypes, default_ctx,
-                               ctx_map, in_arg_ctxes, aux_state_ctxes);
+                               ctx_map, in_arg_ctxes, aux_state_ctxes, grad_req_types);
   // Reorder in_args into new_in_args according to partitioned symbol input sequence
   std::vector<NDArray> new_in_args(in_args->size());
   // get new symbol in_arg names
@@ -1644,7 +1664,7 @@ Executor *Executor::SimpleBind(nnvm::Symbol symbol,
   if (!exec->subgraph_property().empty()) {
     symbol = exec::PartitionGraph(symbol, exec->subgraph_property(), arg_shape_map, arg_dtype_map,
                                   arg_stype_map, default_ctx, group2ctx, in_arg_ctxes,
-                                  aux_state_ctxes);
+                                  aux_state_ctxes, grad_req_types);
   }
   exec->Init(symbol, default_ctx, group2ctx, in_arg_ctxes, arg_grad_ctxes, aux_state_ctxes,
@@ -1667,7 +1687,7 @@ Executor *Executor::Bind(nnvm::Symbol symbol,
   std::vector<NDArray> tmp_in_args = in_args;
   if (!exec->subgraph_property().empty()) {
     symbol = exec::PartitionGraph(symbol, exec->subgraph_property(), &tmp_in_args, aux_states,
-                                  default_ctx, group2ctx);
+                                  default_ctx, group2ctx, grad_req_type);
   }
   exec->Init(symbol, default_ctx, group2ctx, tmp_in_args, arg_grad_store, grad_req_type, aux_states,
diff --git a/src/operator/quantization/dequantize.cc b/src/operator/quantization/dequantize.cc
index a4d57b9b4461..7c84673095f0 100644
--- a/src/operator/quantization/dequantize.cc
+++ b/src/operator/quantization/dequantize.cc
@@ -71,6 +71,9 @@ by keep zero centered for the quantized value:
 .set_attr<mxnet::FInferShape>("FInferShape", DequantizeShape)
 .set_attr<nnvm::FInferType>("FInferType", DequantizeType)
 .set_attr<FInferStorageType>("FInferStorageType", DequantizeStorageType)
+// TODO(Xinyu): a temp solution to enable GluonCV INT8 flow,
+// will be reverted after the improvement of CachedOP is done.
+.set_attr<nnvm::FGradient>("FGradient", MakeZeroGradNodes)
 #if MXNET_USE_MKLDNN == 1
 .set_attr<bool>("TIsMKLDNN", true)
 .set_attr<FComputeEx>("FComputeEx<cpu>", MKLDNNDequantizeCompute)
diff --git a/src/operator/quantization/quantize.cc b/src/operator/quantization/quantize.cc
index c28d8c860924..63467506b99b 100644
--- a/src/operator/quantization/quantize.cc
+++ b/src/operator/quantization/quantize.cc
@@ -82,6 +82,9 @@ where
 .set_attr<mxnet::FInferShape>("FInferShape", QuantizeShape)
 .set_attr<nnvm::FInferType>("FInferType", QuantizeType)
 .set_attr<FInferStorageType>("FInferStorageType", QuantizeStorageType)
+// TODO(Xinyu): a temp solution to enable GluonCV INT8 flow,
+// will be reverted after the improvement of CachedOP is done.
+.set_attr<nnvm::FGradient>("FGradient", MakeZeroGradNodes)
 #if MXNET_USE_MKLDNN == 1
 .set_attr<bool>("TIsMKLDNN", true)
 .set_attr<FComputeEx>("FComputeEx<cpu>", MKLDNNQuantizeCompute)
diff --git a/src/operator/quantization/quantize_v2.cc b/src/operator/quantization/quantize_v2.cc
index 300cdfe3b751..920100bc9f8b 100644
--- a/src/operator/quantization/quantize_v2.cc
+++ b/src/operator/quantization/quantize_v2.cc
@@ -83,6 +83,9 @@ If min_calib_range isn't presented, the output type will be int8.
.set_attr("FInferShape", QuantizeV2Shape) .set_attr("FInferType", QuantizeV2Type) .set_attr("FInferStorageType", QuantizeV2StorageType) +// TODO(Xinyu): a temp solution to enable GluonCV INT8 flow, +// will be reverted after the improvement of CachedOP is done. +.set_attr("FGradient", MakeZeroGradNodes) #if MXNET_USE_MKLDNN == 1 .set_attr("TIsMKLDNN", true) .set_attr("FComputeEx", MKLDNNQuantizeV2Compute) diff --git a/src/operator/quantization/quantized_concat.cc b/src/operator/quantization/quantized_concat.cc index f5c1e8e6ceae..e32bb5a18e1a 100644 --- a/src/operator/quantization/quantized_concat.cc +++ b/src/operator/quantization/quantized_concat.cc @@ -127,6 +127,9 @@ If any input holds int8, then the output will be int8. Otherwise output will be .set_attr("FListOutputNames", [](const NodeAttrs& attrs) { return std::vector{"output", "min_output", "max_output"}; }) +// TODO(Xinyu): a temp solution to enable GluonCV INT8 flow, +// will be reverted after the improvement of CachedOP is done. +.set_attr("FGradient", MakeZeroGradNodes) .set_attr("FInferType", ConcatType) .set_attr("FInferShape", ConcatShape) .set_attr("key_var_num_args", "num_args") diff --git a/src/operator/quantization/quantized_conv.cc b/src/operator/quantization/quantized_conv.cc index 7841c3acb47c..1a801ee50744 100644 --- a/src/operator/quantization/quantized_conv.cc +++ b/src/operator/quantization/quantized_conv.cc @@ -160,6 +160,9 @@ and max thresholds representing the threholds for quantizing the float32 output .set_attr("FInferShape", QuantizedConvShape) .set_attr("FInferType", QuantizedConvType) .set_attr("FInferStorageType", QuantizedConvStorageType) +// TODO(Xinyu): a temp solution to enable GluonCV INT8 flow, +// will be reverted after the improvement of CachedOP is done. +.set_attr("FGradient", MakeZeroGradNodes) .set_attr("FResourceRequest", [](const NodeAttrs& attrs) { return std::vector(1, ResourceRequest::kTempSpace); diff --git a/src/operator/quantization/quantized_flatten.cc b/src/operator/quantization/quantized_flatten.cc index f283d98cf10b..7e6d27b256d4 100644 --- a/src/operator/quantization/quantized_flatten.cc +++ b/src/operator/quantization/quantized_flatten.cc @@ -34,6 +34,9 @@ NNVM_REGISTER_OP(_contrib_quantized_flatten) .set_attr("FInferShape", QuantizedFlattenShape) .set_attr("FInferType", QuantizedFlattenType) .set_attr("FCompute", QuantizedFlattenCompute) +// TODO(Xinyu): a temp solution to enable GluonCV INT8 flow, +// will be reverted after the improvement of CachedOP is done. +.set_attr("FGradient", MakeZeroGradNodes) .set_attr("FListInputNames", [](const NodeAttrs& attrs) { return std::vector{"data", "min_data", "max_data"}; diff --git a/src/operator/quantization/quantized_fully_connected.cc b/src/operator/quantization/quantized_fully_connected.cc index f51b6fdd1798..3b18e6591afc 100644 --- a/src/operator/quantization/quantized_fully_connected.cc +++ b/src/operator/quantization/quantized_fully_connected.cc @@ -264,6 +264,9 @@ and max thresholds representing the threholds for quantizing the float32 output .set_attr("FInferShape", QuantizedFullyConnectedShape) .set_attr("FInferType", QuantizedFullyConnectedType) .set_attr("FInferStorageType", QuantizedFullyConnectedStorageType) +// TODO(Xinyu): a temp solution to enable GluonCV INT8 flow, +// will be reverted after the improvement of CachedOP is done. 
+.set_attr("FGradient", MakeZeroGradNodes) .set_attr("FNeedRequantize", [](const NodeAttrs& attrs) { return true; }) .set_attr("FComputeEx", QuantizedFullyConnectedForward) diff --git a/src/operator/quantization/quantized_pooling.cc b/src/operator/quantization/quantized_pooling.cc index cdc98eeac6f6..af604080a756 100644 --- a/src/operator/quantization/quantized_pooling.cc +++ b/src/operator/quantization/quantized_pooling.cc @@ -157,6 +157,9 @@ the float32 data into int8. .set_attr("FInferShape", QuantizedPoolingShape) .set_attr("FInferType", QuantizedPoolingType) .set_attr("FInferStorageType", QuantizedPoolingStorageType) +// TODO(Xinyu): a temp solution to enable GluonCV INT8 flow, +// will be reverted after the improvement of CachedOP is done. +.set_attr("FGradient", MakeZeroGradNodes) .set_attr("FNeedRequantize", [](const NodeAttrs& attrs) { const PoolingParam& param = nnvm::get(attrs.parsed); diff --git a/src/operator/quantization/requantize.cc b/src/operator/quantization/requantize.cc index edfb58e5cbd5..4807226e464c 100644 --- a/src/operator/quantization/requantize.cc +++ b/src/operator/quantization/requantize.cc @@ -64,6 +64,9 @@ inference accuracy. .set_attr("FInferShape", QuantizeShape) .set_attr("FInferType", RequantizeType) .set_attr("FInferStorageType", RequantizeStorageType) +// TODO(Xinyu): a temp solution to enable GluonCV INT8 flow, +// will be reverted after the improvement of CachedOP is done. +.set_attr("FGradient", MakeZeroGradNodes) #if MXNET_USE_MKLDNN == 1 .set_attr("TIsMKLDNN", true) .set_attr("FComputeEx", MKLDNNRequantizeForward) diff --git a/src/operator/subgraph/mkldnn/mkldnn_conv.cc b/src/operator/subgraph/mkldnn/mkldnn_conv.cc index e53ab2538a90..d61b4613602a 100644 --- a/src/operator/subgraph/mkldnn/mkldnn_conv.cc +++ b/src/operator/subgraph/mkldnn/mkldnn_conv.cc @@ -689,6 +689,9 @@ NNVM_REGISTER_OP(_sg_mkldnn_conv) .set_attr("FInferStorageType", SgMKLDNNConvOpStorageType) .set_attr("FStatefulComputeEx", SgMKLDNNConvOpForward) .set_attr("TIsMKLDNN", true) +// TODO(Xinyu): a temp solution to enable GluonCV INT8 flow, +// will be reverted after the improvement of CachedOP is done. 
+.set_attr("FGradient", MakeZeroGradNodes) .set_attr("FResourceRequest", [](const NodeAttrs& n) { return std::vector{ResourceRequest::kTempSpace}; }) diff --git a/src/operator/subgraph/mkldnn/mkldnn_conv_post_quantize_property.cc b/src/operator/subgraph/mkldnn/mkldnn_conv_post_quantize_property.cc index fc68287b039d..654f6e763972 100644 --- a/src/operator/subgraph/mkldnn/mkldnn_conv_post_quantize_property.cc +++ b/src/operator/subgraph/mkldnn/mkldnn_conv_post_quantize_property.cc @@ -107,7 +107,11 @@ class SgMKLDNNConvPostQuantizeProperty : public SubgraphProperty { } } static SubgraphPropertyPtr Create() { - return std::make_shared(); + auto property = std::make_shared(); + property->SetAttr("property_name", + "MKLDNN Convolution post-quantization optimization pass"); + property->SetAttr("inference_only", true); + return property; } nnvm::NodePtr CreateSubgraphNode(const nnvm::Symbol &sym, const int subgraph_id = 0) const override { diff --git a/src/operator/subgraph/mkldnn/mkldnn_conv_property.cc b/src/operator/subgraph/mkldnn/mkldnn_conv_property.cc index e462191c2898..56ce72961e44 100644 --- a/src/operator/subgraph/mkldnn/mkldnn_conv_property.cc +++ b/src/operator/subgraph/mkldnn/mkldnn_conv_property.cc @@ -149,7 +149,10 @@ class SgMKLDNNConvProperty : public SubgraphProperty { } } static SubgraphPropertyPtr Create() { - return std::make_shared(); + auto property = std::make_shared(); + property->SetAttr("prop_name", "MKLDNN Convolution optimization pass"); + property->SetAttr("inference_only", true); + return property; } nnvm::NodePtr CreateSubgraphNode(const nnvm::Symbol &sym, const int subgraph_id = 0) const override { diff --git a/src/operator/subgraph/subgraph_property.h b/src/operator/subgraph/subgraph_property.h index e9fdd6619275..d115d3498e86 100644 --- a/src/operator/subgraph/subgraph_property.h +++ b/src/operator/subgraph/subgraph_property.h @@ -145,6 +145,13 @@ class SubgraphProperty { CHECK(it != attrs_.end()) << "Cannot find attribute " << name << " in SubgraphProperty"; return nnvm::get(*it->second); } + /*! + * \brief Check if the attr exist. 
+   */
+  bool HasAttr(const std::string& name) const {
+    auto it = attrs_.find(name);
+    return it != attrs_.end();
+  }
 
  protected:
   std::unordered_map<std::string, std::shared_ptr<nnvm::any>> attrs_;
diff --git a/tests/python/mkl/test_subgraph.py b/tests/python/mkl/test_subgraph.py
index 313668cb56f9..8de854cc290d 100644
--- a/tests/python/mkl/test_subgraph.py
+++ b/tests/python/mkl/test_subgraph.py
@@ -66,15 +66,37 @@ def check_qsym_dummy_forward(qsym, batch, data_shape, label_shape):
     output.wait_to_read()
   return mod.get_outputs()
 
-def check_quantize(sym, data_shape, out_type, check_conv=True):
+def check_qsym_gluon_forward(qsym, qarg_params, qaux_params, data_shape):
+  # save qsym to JSON file
+  qsym.save('quantized-symbol.json')
+  # save params
+  save_dict = {('arg:%s' % k): v.as_in_context(mx.current_context()) for k, v in qarg_params.items()}
+  save_dict.update({('aux:%s' % k): v.as_in_context(mx.current_context()) for k, v in qaux_params.items()})
+  mx.nd.save('quantized-0000.params', save_dict)
+  # load back with SymbolBlock
+  net = mx.gluon.SymbolBlock.imports('quantized-symbol.json', ['data'], 'quantized-0000.params')
+  net.collect_params().reset_ctx(ctx = mx.current_context())
+  net.hybridize()
+
+  data = mx.random.uniform(-1.0, 1.0, shape=data_shape)
+  net(data)
+
+def check_quantize(sym, data_shape, out_type, check_conv=True, gluon_forward=False):
   fc = mx.sym.FullyConnected(data=sym, num_hidden=10, flatten=True, name='fc')
-  sym = mx.sym.SoftmaxOutput(data=fc, name='softmax')
-  sym_sg = sym.get_backend_symbol("MKLDNN")
-  label_shape = (data_shape[0], 10)
-  mod = Module(symbol=sym)
-  mod.bind(for_training=False,
-           data_shapes=[('data', data_shape)],
-           label_shapes=[('softmax_label', label_shape)])
+  if gluon_forward == True:
+    sym = fc
+    sym_sg = fc.get_backend_symbol("MKLDNN")
+    mod = Module(symbol=sym, label_names=[])
+    mod.bind(for_training=False,
+             data_shapes=[('data', data_shape)])
+  else:
+    sym = mx.sym.SoftmaxOutput(data=fc, name='softmax')
+    sym_sg = sym.get_backend_symbol("MKLDNN")
+    label_shape = (data_shape[0], 10)
+    mod = Module(symbol=sym)
+    mod.bind(for_training=False,
+             data_shapes=[('data', data_shape)],
+             label_shapes=[('softmax_label', label_shape)])
   mod.init_params(mx.init.Normal(0.5))
   arg_params, aux_params = mod.get_params()
@@ -107,10 +129,13 @@ def check_quantize(sym, data_shape, out_type, check_conv=True):
   qsym = qsym.get_backend_symbol("MKLDNN_POST_QUANTIZE")
   if check_conv:
     check_qsym_calibrated(qsym, out_type)
-  quantized_out = check_qsym_forward(qsym, qarg_params, qaux_params, batch, data_shape, label_shape)
-  for i in range(len(ref_out)):
-    assert_almost_equal(ref_out[i].asnumpy(), quantized_out[i].asnumpy(), atol = 1)
-  check_qsym_dummy_forward(qsym, batch, data_shape, label_shape)
+  if gluon_forward == True:
+    check_qsym_gluon_forward(qsym, qarg_params, qaux_params, data_shape)
+  else:
+    check_qsym_dummy_forward(qsym, batch, data_shape, label_shape)
+    quantized_out = check_qsym_forward(qsym, qarg_params, qaux_params, batch, data_shape, label_shape)
+    for i in range(len(ref_out)):
+      assert_almost_equal(ref_out[i].asnumpy(), quantized_out[i].asnumpy(), atol = 1)
 
 @with_seed()
@@ -137,6 +162,7 @@ def check_fusion(sym, data_shape, attrs_op):
   # fp32 to int8
   for out_type in ('uint8', 'int8', 'auto'):
     check_quantize(sym, data_shape, out_type)
+    check_quantize(sym, data_shape, out_type, gluon_forward=True)
 
 def check_neg_fusion(syms, attrs_name=None, excluded_attrs=None, date_shape=(4,4,10,10)):
   for sym, attrs, excluded_attr in zip(syms, attrs_name, excluded_attrs):
@@ -478,10 +504,13 @@ def test_pos_single_concat():
     for out_type in ('uint8', 'int8', 'auto'):
       net = single_concat(data_shape, 2, 1)
       check_quantize(net, data_shape, out_type, False)
+      check_quantize(net, data_shape, out_type, False, True)
       net = single_concat(data_shape, 4, 2)
       check_quantize(net, data_shape, out_type, False)
+      check_quantize(net, data_shape, out_type, False, True)
       net = single_concat(data_shape, 4, 3)
       check_quantize(net, data_shape, out_type, False)
+      check_quantize(net, data_shape, out_type, False, True)
 
 @with_seed()
 def test_neg_conv_bn():