From f8a0dbc1cc8eb694b5959f2d2dee310e571b20ad Mon Sep 17 00:00:00 2001 From: Zhennan Qin Date: Tue, 26 Mar 2019 15:32:49 +0800 Subject: [PATCH] Enhance PartitionGraph (#14277) * Enhance PartitionGraph * Fix lint * Fix test * Run CI * Change subgraph property register * Change doc * Fix name * Run CI * Add env var in doc * Address comments. * run CI --- docs/tutorials/c++/subgraphAPI.md | 40 ++- docs/tutorials/mkldnn/MKLDNN_README.md | 8 + .../quantization/imagenet_gen_qsym_mkldnn.py | 2 - src/c_api/c_api_symbolic.cc | 17 +- src/c_api/c_api_test.cc | 19 +- src/executor/graph_executor.cc | 324 ++++++++++++------ ...c => mkldnn_conv_post_quantize_property.h} | 35 +- ...onv_property.cc => mkldnn_conv_property.h} | 22 +- src/operator/subgraph/mkldnn/mkldnn_fc.cc | 3 + ....cc => mkldnn_fc_post_quantize_property.h} | 28 +- ...nn_fc_property.cc => mkldnn_fc_property.h} | 39 +-- .../mkldnn/mkldnn_subgraph_property.cc | 38 ++ src/operator/subgraph/partition_graph.cc | 4 + src/operator/subgraph/subgraph_property.h | 42 ++- tests/python/mkl/test_subgraph.py | 4 +- 15 files changed, 407 insertions(+), 218 deletions(-) rename src/operator/subgraph/mkldnn/{mkldnn_conv_post_quantize_property.cc => mkldnn_conv_post_quantize_property.h} (84%) rename src/operator/subgraph/mkldnn/{mkldnn_conv_property.cc => mkldnn_conv_property.h} (93%) rename src/operator/subgraph/mkldnn/{mkldnn_fc_post_quantize_property.cc => mkldnn_fc_post_quantize_property.h} (91%) rename src/operator/subgraph/mkldnn/{mkldnn_fc_property.cc => mkldnn_fc_property.h} (85%) create mode 100644 src/operator/subgraph/mkldnn/mkldnn_subgraph_property.cc diff --git a/docs/tutorials/c++/subgraphAPI.md b/docs/tutorials/c++/subgraphAPI.md index b834df8741b5..6b1b477b8021 100644 --- a/docs/tutorials/c++/subgraphAPI.md +++ b/docs/tutorials/c++/subgraphAPI.md @@ -97,23 +97,57 @@ class SgProperty : public SubgraphProperty { return n; } SubgraphSelectorPtr CreateSubgraphSelector() const override { - return std::make_shared(); + auto property = std::make_shared(); + property->SetAttr("property_name", "subgraph example pass"); // Optional, better to have it. + property->SetAttr("inference_only", true); // Optional, only for inference_only pass. + return property; } }; ``` +`SetAttr` is optional and developer can define their own attributes to control property behavior. +There're 2 built-in attributes that used by MXNet executor. -After defining the subgraph property, we need to register it. +`property_name` : std::string, name of this property. + +`inference_only` : bool, apply this property only for inference. Property will be skiped when need_grad=True. Default `false` if this attribute isn't defined. + +After defining the subgraph property, we need to register it in .cc file. ```C++ MXNET_REGISTER_SUBGRAPH_PROPERTY(SgTest, SgProperty); ``` -After compiling this subgraph mechanism into MXNet, we can use the environment variable `MXNET_SUBGRAPH_BACKEND` to activate it. +It's possible to register multiple properties for same backend. In practice, we recommend to put each property definition into .h file, and register backend in single .cc file. Property will be executed according to the register order. + +```C++ +#include "SgProperty.h" // Define SgProperty class +#include "SgProperty2.h" // Define SgProperty2 class +#include "SgProperty3.h" // Define SgProperty3 class + +MXNET_REGISTER_SUBGRAPH_PROPERTY(SgTest, SgProperty); // Execution order 1. +MXNET_REGISTER_SUBGRAPH_PROPERTY(SgTest, SgProperty2); // Execution order 2. +MXNET_REGISTER_SUBGRAPH_PROPERTY(SgTest, SgProperty3); // Execution order 3. +``` + +After compiling this subgraph mechanism into MXNet, we can use the environment variable `MXNET_SUBGRAPH_BACKEND` to activate it during symbol bind. ```bash export MXNET_SUBGRAPH_BACKEND=SgTest ``` +Or you can use python symbol API `get_backend_symbol` to run all properties registered for this backend and get returned symbol. + +```Python +sym, arg_params, aux_params = mx.model.load_checkpoint(prefix, epoch) +sym = sym.get_backend_symbol('SgTest') +``` + +When `SgProperty` is activated, a message will be shown in terminal as + +```bash +start to execute subgraph example pass. +``` + This tutorial shows a simple example of how to use the subgraph API to search for patterns in an NNVM graph. Intested users can try different pattern matching rules (i.e., define their own `SubgraphSelector`) and attach different operators to execute the subgraphs. diff --git a/docs/tutorials/mkldnn/MKLDNN_README.md b/docs/tutorials/mkldnn/MKLDNN_README.md index 189de3f3a0d4..c5779670cd87 100644 --- a/docs/tutorials/mkldnn/MKLDNN_README.md +++ b/docs/tutorials/mkldnn/MKLDNN_README.md @@ -306,6 +306,14 @@ Graph optimization by subgraph feature are available in master branch. You can b ``` export MXNET_SUBGRAPH_BACKEND=MKLDNN ``` + +When `MKLDNN` backend is enabled, advanced control options are avaliable: + +``` +export MXNET_DISABLE_MKLDNN_CONV_OPT=1 # disable MKLDNN convolution optimization pass +export MXNET_DISABLE_MKLDNN_FC_OPT=1 # disable MKLDNN FullyConnected optimization pass +``` + This limitations of this experimental feature are: diff --git a/example/quantization/imagenet_gen_qsym_mkldnn.py b/example/quantization/imagenet_gen_qsym_mkldnn.py index 3f644fc771a7..2ef137273cca 100644 --- a/example/quantization/imagenet_gen_qsym_mkldnn.py +++ b/example/quantization/imagenet_gen_qsym_mkldnn.py @@ -180,7 +180,6 @@ def save_params(fname, arg_params, aux_params, logger=None): sym, arg_params, aux_params = mx.model.load_checkpoint(prefix, epoch) sym = sym.get_backend_symbol('MKLDNN') - sym = sym.get_backend_symbol('MKLDNN_FC') # get batch size batch_size = args.batch_size @@ -303,7 +302,6 @@ def save_params(fname, arg_params, aux_params, logger=None): % calib_mode) sym_name = '%s-symbol.json' % (prefix + suffix) qsym = qsym.get_backend_symbol('MKLDNN_POST_QUANTIZE') - qsym = qsym.get_backend_symbol('MKLDNN_POST_FC_QUANTIZE') save_symbol(sym_name, qsym, logger) param_name = '%s-%04d.params' % (prefix + '-quantized', epoch) save_params(param_name, qarg_params, aux_params, logger) diff --git a/src/c_api/c_api_symbolic.cc b/src/c_api/c_api_symbolic.cc index e07716267288..a3a0b0ca16f9 100644 --- a/src/c_api/c_api_symbolic.cc +++ b/src/c_api/c_api_symbolic.cc @@ -722,14 +722,15 @@ int MXGenBackendSubgraph(SymbolHandle sym_handle, const char *backend, API_BEGIN(); nnvm::Symbol *sym = static_cast(sym_handle); *s = sym->Copy(); - nnvm::Graph g = Symbol2Graph(*s); - mxnet::op::SubgraphPropertyPtr property = - mxnet::op::SubgraphPropertyRegistry::Get()->CreateSubgraphProperty( - backend); - g.attrs["subgraph_property"] = - std::make_shared(std::move(property)); - g = ApplyPass(std::move(g), "PartitionGraph"); - s->outputs = g.outputs; + std::vector properties = + mxnet::op::SubgraphPropertyRegistry::Get()->CreateSubgraphProperty(backend); + for (auto property : properties) { + nnvm::Graph g = Symbol2Graph(*s); + property->SetAttr("graph", g); + g.attrs["subgraph_property"] = std::make_shared(std::move(property)); + g = nnvm::ApplyPass(std::move(g), "PartitionGraph"); + s->outputs = g.outputs; + } *ret_sym_handle = s; API_END_HANDLE_ERROR(delete s); } diff --git a/src/c_api/c_api_test.cc b/src/c_api/c_api_test.cc index 623faa71adc9..70829db3d4a5 100644 --- a/src/c_api/c_api_test.cc +++ b/src/c_api/c_api_test.cc @@ -40,16 +40,19 @@ int MXPartitionGraphByOpNames(SymbolHandle sym_handle, } nnvm::Symbol* sym = static_cast(sym_handle); *s = sym->Copy(); - nnvm::Graph g; - g.outputs = s->outputs; if (!op_name_set.empty()) { - mxnet::op::SubgraphPropertyPtr property - = mxnet::op::SubgraphPropertyRegistry::Get()->CreateSubgraphProperty(prop_name); - property->SetAttr("op_names", op_name_set); - g.attrs["subgraph_property"] = std::make_shared(std::move(property)); + std::vector properties = + mxnet::op::SubgraphPropertyRegistry::Get()->CreateSubgraphProperty(prop_name); + for (auto property : properties) { + nnvm::Graph g; + g.outputs = s->outputs; + property->SetAttr("graph", g); + property->SetAttr("op_names", op_name_set); + g.attrs["subgraph_property"] = std::make_shared(std::move(property)); + g = nnvm::ApplyPass(std::move(g), "PartitionGraph"); + s->outputs = g.outputs; + } } - g = nnvm::ApplyPass(std::move(g), "PartitionGraph"); - s->outputs = g.outputs; *ret_sym_handle = s; API_END_HANDLE_ERROR(delete s); } diff --git a/src/executor/graph_executor.cc b/src/executor/graph_executor.cc index 3d74bfb9a663..28862a49fae1 100644 --- a/src/executor/graph_executor.cc +++ b/src/executor/graph_executor.cc @@ -1443,47 +1443,20 @@ static nnvm::Graph InferForwardAttrs(nnvm::Graph g, // Given input attr arrays, partition the graph using the backend name equal to prop_name. // This is a common function for bind and simple_bind flows. static nnvm::Symbol PartitionGraph(const nnvm::Symbol& src, - const std::string& prop_name, + mxnet::op::SubgraphPropertyPtr subgraph_prop, const mxnet::ShapeVector& arg_shapes, const nnvm::DTypeVector& arg_dtypes, const StorageTypeVector& arg_stypes, const Context& default_ctx, const std::map& ctx_map, const std::vector& in_arg_ctxes, - const std::vector& aux_state_ctxes, - const std::vector& grad_req_types) { - auto subgraph_prop = op::SubgraphPropertyRegistry::Get()->CreateSubgraphProperty(prop_name); - bool need_grad = false; - for (OpReqType req : grad_req_types) { - if (req != kNullOp) { - need_grad = true; - break; - } - } - if (subgraph_prop->HasAttr("inference_only") && - subgraph_prop->GetAttr("inference_only") == true) { - if (need_grad) { - auto full_name = subgraph_prop->HasAttr("prop_name") - ? subgraph_prop->GetAttr("prop_name") - : prop_name; - LOG(INFO) << "Skip subgraph " << full_name << " as it requires `grad_req=null`."; - return src; - } - } + const std::vector& aux_state_ctxes) { nnvm::Symbol ret = src.Copy(); nnvm::Graph g; g.outputs = ret.outputs; - g = InferForwardAttrs(g, arg_shapes, arg_dtypes, arg_stypes, default_ctx, - ctx_map, in_arg_ctxes, aux_state_ctxes); + g = InferForwardAttrs(g, arg_shapes, arg_dtypes, arg_stypes, default_ctx, ctx_map, in_arg_ctxes, + aux_state_ctxes); subgraph_prop->SetAttr("graph", g); - auto it = op::SubgraphPropertyOpNameSet::Get()->find(prop_name); - // assign a op name set to the subgraph property if it has been provided by users - if (it != op::SubgraphPropertyOpNameSet::Get()->end()) { - LOG(INFO) << "SubgraphPropertyOpNameSet for subgraph property " << prop_name - << " has been assigned a value. Please make sure it is initialized" - " only for the testing purpose."; - subgraph_prop->SetAttr("op_names", it->second); - } g.attrs["subgraph_property"] = std::make_shared(std::move(subgraph_prop)); g = ApplyPass(std::move(g), "PartitionGraph"); ret.outputs = g.outputs; @@ -1500,91 +1473,221 @@ static nnvm::Symbol PartitionGraph(const nnvm::Symbol& src, const std::unordered_map& arg_stype_map, const Context& default_ctx, const std::map& ctx_map, - const std::vector& in_arg_ctxes, - const std::vector& aux_state_ctxes, - const std::vector& grad_req_types) { - const std::vector input_names = src.ListInputNames(Symbol::kAll); - mxnet::ShapeVector arg_shapes(input_names.size(), mxnet::TShape()); - nnvm::DTypeVector arg_dtypes(input_names.size(), -1); - StorageTypeVector arg_stypes(input_names.size(), kUndefinedStorage); - for (size_t i = 0; i < input_names.size(); ++i) { - auto it1 = arg_shape_map.find(input_names[i]); - if (arg_shape_map.end() != it1) { - arg_shapes[i] = it1->second; + std::vector* in_arg_ctxes, + std::vector* arg_grad_ctxes, + std::vector* grad_req_types, + std::vector* aux_state_ctxes) { + // setup map for in_arg_ctxes, arg_grad_ctxes, aux_state_ctxes and grad_req_types + std::unordered_map in_arg_ctx_map; + std::unordered_map arg_grad_ctx_map; + std::unordered_map aux_state_ctx_map; + std::unordered_map grad_req_type_map; + + auto arg_names = src.ListInputNames(nnvm::Symbol::kReadOnlyArgs); + auto aux_names = src.ListInputNames(nnvm::Symbol::kAuxiliaryStates); + for (size_t i = 0; i < arg_names.size(); ++i) { + auto name = arg_names[i]; + in_arg_ctx_map[name] = in_arg_ctxes->at(i); + arg_grad_ctx_map[name] = arg_grad_ctxes->at(i); + grad_req_type_map[name] = grad_req_types->at(i); + } + + for (size_t i = 0; i < aux_names.size(); ++i) { + aux_state_ctx_map[aux_names[i]] = aux_state_ctxes->at(i); + } + + bool need_grad = false; + for (OpReqType req : *grad_req_types) { + if (req != kNullOp) { + need_grad = true; + break; } - auto it2 = arg_dtype_map.find(input_names[i]); - if (arg_dtype_map.end() != it2) { - arg_dtypes[i] = it2->second; + } + nnvm::Symbol ret = src.Copy(); + std::unordered_set op_names_set; + auto it = op::SubgraphPropertyOpNameSet::Get()->find(prop_name); + // assign a op name set to the subgraph property if it has been provided by users + if (it != op::SubgraphPropertyOpNameSet::Get()->end()) { + LOG(INFO) << "SubgraphPropertyOpNameSet for subgraph property " << prop_name + << " has been assigned a value. Please make sure it is initialized" + " only for the testing purpose."; + op_names_set = it->second; + } + std::vector properties = + op::SubgraphPropertyRegistry::Get()->CreateSubgraphProperty(prop_name); + for (auto subgraph_prop : properties) { + if (subgraph_prop->HasAttr("inference_only") && + subgraph_prop->GetAttr("inference_only") == true) { + if (need_grad) { + auto full_name = subgraph_prop->HasAttr("property_name") + ? subgraph_prop->GetAttr("property_name") + : prop_name; + LOG(INFO) << "skip partitioning graph with subgraph property " << full_name + << " as it requires `grad_req=null`."; + continue; + } } - auto it3 = arg_stype_map.find(input_names[i]); - if (arg_stype_map.end() != it3) { - arg_stypes[i] = it3->second; + subgraph_prop->SetAttr("op_names", op_names_set); + const std::vector input_names = ret.ListInputNames(Symbol::kAll); + mxnet::ShapeVector arg_shapes(input_names.size(), mxnet::TShape()); + nnvm::DTypeVector arg_dtypes(input_names.size(), -1); + StorageTypeVector arg_stypes(input_names.size(), kUndefinedStorage); + for (size_t i = 0; i < input_names.size(); ++i) { + const auto& input_name = input_names[i]; + auto it1 = arg_shape_map.find(input_name); + if (arg_shape_map.end() != it1) { + arg_shapes[i] = it1->second; + } + auto it2 = arg_dtype_map.find(input_name); + if (arg_dtype_map.end() != it2) { + arg_dtypes[i] = it2->second; + } + auto it3 = arg_stype_map.find(input_name); + if (arg_stype_map.end() != it3) { + arg_stypes[i] = it3->second; + } + } + ret = PartitionGraph(ret, subgraph_prop, arg_shapes, arg_dtypes, arg_stypes, default_ctx, + ctx_map, *in_arg_ctxes, *aux_state_ctxes); + // Reorder in_arg_ctxes, arg_grad_ctxes, aux_state_ctxes and grad_req_types according to + // partitioned symbol input sequence + in_arg_ctxes->clear(); + arg_grad_ctxes->clear(); + aux_state_ctxes->clear(); + grad_req_types->clear(); + auto new_arg_names = ret.ListInputNames(nnvm::Symbol::kReadOnlyArgs); + auto new_aux_names = ret.ListInputNames(nnvm::Symbol::kAuxiliaryStates); + for (auto arg_name : new_arg_names) { + CHECK(in_arg_ctx_map.count(arg_name)); + in_arg_ctxes->push_back(in_arg_ctx_map[arg_name]); + arg_grad_ctxes->push_back(arg_grad_ctx_map[arg_name]); + grad_req_types->push_back(grad_req_type_map[arg_name]); + } + for (auto arg_name : new_aux_names) { + CHECK(aux_state_ctx_map.count(arg_name)); + aux_state_ctxes->push_back(aux_state_ctx_map[arg_name]); } } - return PartitionGraph(src, prop_name, arg_shapes, arg_dtypes, arg_stypes, - default_ctx, ctx_map, in_arg_ctxes, aux_state_ctxes, grad_req_types); + return ret; } // Given input ndarrays, partition the graph using the backend name equal to prop_name. // This is for bind flow. -static nnvm::Symbol PartitionGraph(const nnvm::Symbol& src, - const std::string& prop_name, - std::vector *in_args, - const std::vector &aux_states, +static nnvm::Symbol PartitionGraph(const nnvm::Symbol& src, const std::string& prop_name, const Context& default_ctx, const std::map& ctx_map, - const std::vector& grad_req_types) { - const std::vector input_names = src.ListInputNames(Symbol::kAll); + std::vector* in_args, + std::vector* arg_grad_store, + std::vector* grad_req_type, + std::vector* aux_states) { + // setup map for in_args, arg_grad_store, grad_req_type and aux_states + std::unordered_map in_args_map; + std::unordered_map arg_grad_store_map; + std::unordered_map grad_req_type_map; + std::unordered_map aux_states_map; const std::vector arg_names = src.ListInputNames(nnvm::Symbol::kReadOnlyArgs); const std::vector aux_names = src.ListInputNames(nnvm::Symbol::kAuxiliaryStates); - CHECK_EQ(arg_names.size(), in_args->size()); - CHECK_EQ(aux_names.size(), aux_states.size()); - mxnet::ShapeVector arg_shapes; // all input shapes - arg_shapes.reserve(input_names.size()); - nnvm::DTypeVector arg_dtypes; // all input dtypes - arg_dtypes.reserve(input_names.size()); - StorageTypeVector arg_stypes; // all input stypes - arg_stypes.reserve(input_names.size()); - std::vector in_arg_ctxes(in_args->size()); - std::vector aux_state_ctxes(aux_states.size()); + for (size_t i = 0; i < arg_names.size(); ++i) { + auto name = arg_names[i]; + in_args_map[name] = in_args->at(i); + arg_grad_store_map[name] = arg_grad_store->at(i); + grad_req_type_map[name] = grad_req_type->at(i); + } - size_t i1 = 0, i2 = 0; - for (const auto& input_name : input_names) { - if (i2 < aux_names.size() && aux_names[i2] == input_name) { - arg_shapes.push_back(aux_states[i2].shape()); - arg_dtypes.push_back(aux_states[i2].dtype()); - arg_stypes.push_back(aux_states[i2].storage_type()); - aux_state_ctxes[i2] = aux_states[i2].ctx(); - ++i2; - } else { - CHECK(i1 < arg_names.size()); - CHECK_EQ(arg_names[i1], input_name); - arg_shapes.push_back(in_args->at(i1).shape()); - arg_dtypes.push_back(in_args->at(i1).dtype()); - arg_stypes.push_back(in_args->at(i1).storage_type()); - in_arg_ctxes[i1] = in_args->at(i1).ctx(); - ++i1; + for (size_t i = 0; i < aux_names.size(); ++i) { + aux_states_map[aux_names[i]] = aux_states->at(i); + } + + bool need_grad = false; + for (OpReqType req : *grad_req_type) { + if (req != kNullOp) { + need_grad = true; + break; } } + nnvm::Symbol ret = src.Copy(); + std::unordered_set op_names_set; + auto it = op::SubgraphPropertyOpNameSet::Get()->find(prop_name); + // assign a op name set to the subgraph property if it has been provided by users + if (it != op::SubgraphPropertyOpNameSet::Get()->end()) { + LOG(INFO) << "SubgraphPropertyOpNameSet for subgraph property " << prop_name + << " has been assigned a value. Please make sure it is initialized" + " only for the testing purpose."; + op_names_set = it->second; + } + std::vector properties = + op::SubgraphPropertyRegistry::Get()->CreateSubgraphProperty(prop_name); + for (auto subgraph_prop : properties) { + if (subgraph_prop->HasAttr("inference_only") && + subgraph_prop->GetAttr("inference_only") == true) { + if (need_grad) { + auto full_name = subgraph_prop->HasAttr("property_name") + ? subgraph_prop->GetAttr("property_name") + : prop_name; + LOG(INFO) << "Skip subgraph " << full_name << " as it requires `grad_req=null`."; + continue; + } + } + subgraph_prop->SetAttr("op_names", op_names_set); + const std::vector input_names = ret.ListInputNames(Symbol::kAll); + const std::vector arg_names = ret.ListInputNames(nnvm::Symbol::kReadOnlyArgs); + const std::vector aux_names = ret.ListInputNames(nnvm::Symbol::kAuxiliaryStates); + CHECK_EQ(arg_names.size(), in_args_map.size()); + CHECK_EQ(aux_names.size(), aux_states_map.size()); + mxnet::ShapeVector arg_shapes; // all input shapes + arg_shapes.reserve(input_names.size()); + nnvm::DTypeVector arg_dtypes; // all input dtypes + arg_dtypes.reserve(input_names.size()); + StorageTypeVector arg_stypes; // all input stypes + arg_stypes.reserve(input_names.size()); + std::vector in_arg_ctxes(in_args_map.size()); + std::vector aux_state_ctxes(aux_states_map.size()); + + size_t i1 = 0, i2 = 0; + for (const auto& input_name : input_names) { + if (i2 < aux_names.size() && aux_names[i2] == input_name) { + const auto &aux_st = aux_states_map[input_name]; + arg_shapes.push_back(aux_st.shape()); + arg_dtypes.push_back(aux_st.dtype()); + arg_stypes.push_back(aux_st.storage_type()); + aux_state_ctxes[i2] = aux_st.ctx(); + ++i2; + } else { + CHECK(i1 < arg_names.size()); + CHECK_EQ(arg_names[i1], input_name); + const auto &in_arg = in_args_map[input_name]; + arg_shapes.push_back(in_arg.shape()); + arg_dtypes.push_back(in_arg.dtype()); + arg_stypes.push_back(in_arg.storage_type()); + in_arg_ctxes[i1] = in_arg.ctx(); + ++i1; + } + } - // setup in_args_map - std::unordered_map in_args_map; - for (size_t i = 0; i < in_args->size(); ++i) { - in_args_map[arg_names[i]] = in_args->at(i); - } - auto result = PartitionGraph(src, prop_name, arg_shapes, arg_dtypes, arg_stypes, default_ctx, - ctx_map, in_arg_ctxes, aux_state_ctxes, grad_req_types); - // Reorder in_args into new_in_args according to partitioned symbol input sequence - std::vector new_in_args(in_args->size()); - // get new symbol in_arg names - std::vector new_arg_names = result.ListInputNames(nnvm::Symbol::kReadOnlyArgs); + ret = PartitionGraph(ret, subgraph_prop, arg_shapes, arg_dtypes, arg_stypes, default_ctx, + ctx_map, in_arg_ctxes, aux_state_ctxes); + } + // Reorder in_args, arg_grad_store, grad_req_type and aux_states according to partitioned symbol + // input sequence + const auto new_arg_names = ret.ListInputNames(nnvm::Symbol::kReadOnlyArgs); + const auto new_aux_names = ret.ListInputNames(nnvm::Symbol::kAuxiliaryStates); + CHECK_EQ(arg_names.size(), new_arg_names.size()); CHECK_EQ(arg_names.size(), new_arg_names.size()); in_args->clear(); + arg_grad_store->clear(); + grad_req_type->clear(); + aux_states->clear(); for (auto arg_name : new_arg_names) { CHECK(in_args_map.count(arg_name)); in_args->push_back(in_args_map[arg_name]); + arg_grad_store->push_back(arg_grad_store_map[arg_name]); + grad_req_type->push_back(grad_req_type_map[arg_name]); } - return result; + for (auto arg_name : new_aux_names) { + CHECK(aux_states_map.count(arg_name)); + aux_states->push_back(aux_states_map[arg_name]); + } + return ret; } } // namespace exec @@ -1605,17 +1708,18 @@ Executor *Executor::SimpleBind(nnvm::Symbol symbol, std::unordered_map* shared_buffer, Executor* shared_exec) { auto exec = new exec::GraphExecutor(); + std::vector tmp_in_arg_ctxes = in_arg_ctxes; + std::vector tmp_arg_grad_ctxes = arg_grad_ctxes; + std::vector tmp_aux_state_ctxes = aux_state_ctxes; + std::vector tmp_grad_req_types = grad_req_types; if (!exec->subgraph_property().empty()) { symbol = exec::PartitionGraph(symbol, exec->subgraph_property(), arg_shape_map, arg_dtype_map, - arg_stype_map, default_ctx, group2ctx, in_arg_ctxes, - aux_state_ctxes, grad_req_types); - } - exec->Init(symbol, default_ctx, group2ctx, - in_arg_ctxes, arg_grad_ctxes, aux_state_ctxes, - arg_shape_map, arg_dtype_map, arg_stype_map, - grad_req_types, shared_arg_names, - in_args, arg_grads, aux_states, - shared_buffer, shared_exec); + arg_stype_map, default_ctx, group2ctx, &tmp_in_arg_ctxes, + &tmp_arg_grad_ctxes, &tmp_grad_req_types, &tmp_aux_state_ctxes); + } + exec->Init(symbol, default_ctx, group2ctx, tmp_in_arg_ctxes, tmp_arg_grad_ctxes, + tmp_aux_state_ctxes, arg_shape_map, arg_dtype_map, arg_stype_map, tmp_grad_req_types, + shared_arg_names, in_args, arg_grads, aux_states, shared_buffer, shared_exec); return exec; } @@ -1629,13 +1733,17 @@ Executor *Executor::Bind(nnvm::Symbol symbol, Executor* shared_exec) { auto exec = new exec::GraphExecutor(); std::vector tmp_in_args = in_args; + std::vector tmp_arg_grad_store = arg_grad_store; + std::vector tmp_grad_req_type = grad_req_type; + std::vector tmp_aux_states = aux_states; + if (!exec->subgraph_property().empty()) { - symbol = exec::PartitionGraph(symbol, exec->subgraph_property(), &tmp_in_args, aux_states, - default_ctx, group2ctx, grad_req_type); + symbol = exec::PartitionGraph(symbol, exec->subgraph_property(), default_ctx, group2ctx, + &tmp_in_args, &tmp_arg_grad_store, &tmp_grad_req_type, + &tmp_aux_states); } - exec->Init(symbol, default_ctx, group2ctx, - tmp_in_args, arg_grad_store, grad_req_type, aux_states, - reinterpret_cast(shared_exec)); + exec->Init(symbol, default_ctx, group2ctx, tmp_in_args, tmp_arg_grad_store, tmp_grad_req_type, + tmp_aux_states, reinterpret_cast(shared_exec)); return exec; } } // namespace mxnet diff --git a/src/operator/subgraph/mkldnn/mkldnn_conv_post_quantize_property.cc b/src/operator/subgraph/mkldnn/mkldnn_conv_post_quantize_property.h similarity index 84% rename from src/operator/subgraph/mkldnn/mkldnn_conv_post_quantize_property.cc rename to src/operator/subgraph/mkldnn/mkldnn_conv_post_quantize_property.h index 654f6e763972..f9033f48d413 100644 --- a/src/operator/subgraph/mkldnn/mkldnn_conv_post_quantize_property.cc +++ b/src/operator/subgraph/mkldnn/mkldnn_conv_post_quantize_property.h @@ -16,9 +16,12 @@ * specific language governing permissions and limitations * under the License. */ - +#ifndef MXNET_OPERATOR_SUBGRAPH_MKLDNN_MKLDNN_CONV_POST_QUANTIZE_PROPERTY_H_ +#define MXNET_OPERATOR_SUBGRAPH_MKLDNN_MKLDNN_CONV_POST_QUANTIZE_PROPERTY_H_ #if MXNET_USE_MKLDNN == 1 +#include +#include #include "../common.h" #include "../subgraph_property.h" #include "../../nn/mkldnn/mkldnn_convolution-inl.h" @@ -38,16 +41,14 @@ class SgMKLDNNConvPostQuantizeSelector : public SubgraphSelector { }; private: - bool disable_all; SelectStatus status; std::vector matched_list; public: - explicit SgMKLDNNConvPostQuantizeSelector(int dis_all) - : disable_all(dis_all) {} + SgMKLDNNConvPostQuantizeSelector() {} bool Select(const nnvm::Node &n) override { - if ((!disable_all) && n.op() && n.op()->name == "_sg_mkldnn_conv") { + if (n.op() && n.op()->name == "_sg_mkldnn_conv") { auto const ¶m = nnvm::get(n.attrs.parsed); if (param.full_conv_param.mkldnn_param.quantized) { status = kStart; @@ -98,21 +99,16 @@ class SgMKLDNNConvPostQuantizeSelector : public SubgraphSelector { class SgMKLDNNConvPostQuantizeProperty : public SubgraphProperty { public: - SgMKLDNNConvPostQuantizeProperty() { - disable_all = dmlc::GetEnv("MXNET_DISABLE_MKLDNN_OPT", 0); - if (disable_all) { - LOG(INFO) << "MKLDNN Convolution post-quantization optimization pass is disabled."; - } else { - LOG(INFO) << "Start to execute MKLDNN Convolution post-quantization optimization pass."; - } - } + SgMKLDNNConvPostQuantizeProperty() {} + static SubgraphPropertyPtr Create() { + static const std::string &name = "MKLDNN Convolution post-quantization optimization pass"; auto property = std::make_shared(); - property->SetAttr("property_name", - "MKLDNN Convolution post-quantization optimization pass"); + property->SetAttr("property_name", name); property->SetAttr("inference_only", true); return property; } + nnvm::NodePtr CreateSubgraphNode(const nnvm::Symbol &sym, const int subgraph_id = 0) const override { nnvm::NodePtr conv_node = nullptr; @@ -141,8 +137,7 @@ class SgMKLDNNConvPostQuantizeProperty : public SubgraphProperty { } SubgraphSelectorPtr CreateSubgraphSelector() const override { - auto selector = - std::make_shared(disable_all); + auto selector = std::make_shared(); return selector; } @@ -154,14 +149,10 @@ class SgMKLDNNConvPostQuantizeProperty : public SubgraphProperty { *entry_ptr = nnvm::NodeEntry{n, entry_ptr->index, 0}; } } - - private: - int disable_all; }; -MXNET_REGISTER_SUBGRAPH_PROPERTY(MKLDNN_POST_QUANTIZE, SgMKLDNNConvPostQuantizeProperty); - } // namespace op } // namespace mxnet #endif // if MXNET_USE_MKLDNN == 1 +#endif // MXNET_OPERATOR_SUBGRAPH_MKLDNN_MKLDNN_CONV_POST_QUANTIZE_PROPERTY_H_ diff --git a/src/operator/subgraph/mkldnn/mkldnn_conv_property.cc b/src/operator/subgraph/mkldnn/mkldnn_conv_property.h similarity index 93% rename from src/operator/subgraph/mkldnn/mkldnn_conv_property.cc rename to src/operator/subgraph/mkldnn/mkldnn_conv_property.h index 56ce72961e44..7fe4727a4990 100644 --- a/src/operator/subgraph/mkldnn/mkldnn_conv_property.cc +++ b/src/operator/subgraph/mkldnn/mkldnn_conv_property.h @@ -17,8 +17,12 @@ * under the License. */ +#ifndef MXNET_OPERATOR_SUBGRAPH_MKLDNN_MKLDNN_CONV_PROPERTY_H_ +#define MXNET_OPERATOR_SUBGRAPH_MKLDNN_MKLDNN_CONV_PROPERTY_H_ #if MXNET_USE_MKLDNN == 1 +#include +#include #include "../common.h" #include "../subgraph_property.h" #include "../../nn/activation-inl.h" @@ -136,21 +140,20 @@ class SgMKLDNNConvSelector : public SubgraphSelector { class SgMKLDNNConvProperty : public SubgraphProperty { public: SgMKLDNNConvProperty() { - disable_all = dmlc::GetEnv("MXNET_DISABLE_MKLDNN_OPT", 0); disable_conv_bn = dmlc::GetEnv("MXNET_DISABLE_MKLDNN_FUSE_CONV_BN", 0); disable_conv_relu = dmlc::GetEnv("MXNET_DISABLE_MKLDNN_FUSE_CONV_RELU", 0); disable_conv_sum = dmlc::GetEnv("MXNET_DISABLE_MKLDNN_FUSE_CONV_SUM", 0); - disable_all = disable_all || (disable_conv_bn && disable_conv_relu && disable_conv_sum); - if (disable_all) { - LOG(INFO) << "MKLDNN Convolution optimization pass is disabled."; - } else { - LOG(INFO) << "Start to execute MKLDNN Convolution optimization pass."; - } + disable_all = disable_conv_bn && disable_conv_relu && disable_conv_sum; } static SubgraphPropertyPtr Create() { + static const std::string &name = "MKLDNN convolution optimization pass"; + if (dmlc::GetEnv("MXNET_DISABLE_MKLDNN_CONV_OPT", 0)) { + LOG(INFO) << name << " is disabled."; + return nullptr; + } auto property = std::make_shared(); - property->SetAttr("prop_name", "MKLDNN Convolution optimization pass"); + property->SetAttr("property_name", name); property->SetAttr("inference_only", true); return property; } @@ -244,9 +247,8 @@ class SgMKLDNNConvProperty : public SubgraphProperty { int disable_conv_sum; }; -MXNET_REGISTER_SUBGRAPH_PROPERTY(MKLDNN, SgMKLDNNConvProperty); - } // namespace op } // namespace mxnet #endif // if MXNET_USE_MKLDNN == 1 +#endif // MXNET_OPERATOR_SUBGRAPH_MKLDNN_MKLDNN_CONV_PROPERTY_H_ diff --git a/src/operator/subgraph/mkldnn/mkldnn_fc.cc b/src/operator/subgraph/mkldnn/mkldnn_fc.cc index 8829404b9576..c9e1e1c79244 100644 --- a/src/operator/subgraph/mkldnn/mkldnn_fc.cc +++ b/src/operator/subgraph/mkldnn/mkldnn_fc.cc @@ -423,6 +423,9 @@ NNVM_REGISTER_OP(_sg_mkldnn_fully_connected) .set_attr("FCreateOpState", CreateSgMKLDNNFCState) .set_attr("FStatefulComputeEx", SgMKLDNNFCForward) .set_attr("TIsMKLDNN", true) +// TODO(Xinyu): a temp solution to enable GluonCV INT8 flow, +// will be reverted after the improvement of CachedOP is done. +.set_attr("FGradient", MakeZeroGradNodes) .set_attr("FResourceRequest", [](const NodeAttrs& n) { return std::vector{ResourceRequest::kTempSpace}; }) diff --git a/src/operator/subgraph/mkldnn/mkldnn_fc_post_quantize_property.cc b/src/operator/subgraph/mkldnn/mkldnn_fc_post_quantize_property.h similarity index 91% rename from src/operator/subgraph/mkldnn/mkldnn_fc_post_quantize_property.cc rename to src/operator/subgraph/mkldnn/mkldnn_fc_post_quantize_property.h index d2d176fadbb6..f8d7ee1da6c9 100644 --- a/src/operator/subgraph/mkldnn/mkldnn_fc_post_quantize_property.cc +++ b/src/operator/subgraph/mkldnn/mkldnn_fc_post_quantize_property.h @@ -24,12 +24,16 @@ * \author Ciyong Chen */ +#ifndef MXNET_OPERATOR_SUBGRAPH_MKLDNN_MKLDNN_FC_POST_QUANTIZE_PROPERTY_H_ +#define MXNET_OPERATOR_SUBGRAPH_MKLDNN_MKLDNN_FC_POST_QUANTIZE_PROPERTY_H_ #if MXNET_USE_MKLDNN == 1 -#include "../common.h" -#include "../subgraph_property.h" +#include +#include #include "../../nn/fully_connected-inl.h" #include "../../quantization/requantize-inl.h" +#include "../common.h" +#include "../subgraph_property.h" namespace mxnet { namespace op { @@ -133,20 +137,16 @@ class SgMKLDNNFCPostQuantizeSelector : public SubgraphSelector { class SgMKLDNNFCPostQuantizeProperty : public SubgraphProperty { public: SgMKLDNNFCPostQuantizeProperty() { - disable_all = dmlc::GetEnv("MXNET_DISABLE_MKLDNN_POST_OPT", false); disable_fuse_all = dmlc::GetEnv("MXNET_DISABLE_MKLDNN_QFC_FUSE_ALL", false); disable_float_output = dmlc::GetEnv("MXNET_DISABLE_MKLDNN_QFC_FLOAT_OUTPUT", false); - - disable_all = disable_all || disable_fuse_all; - if (disable_all) { - LOG(INFO) << "MKLDNN FullyConnected post-quantization optimization pass is disabled."; - } else { - LOG(INFO) << "Start to execute MKLDNN FullyConected post-quantization optimization pass."; - } } static SubgraphPropertyPtr Create() { - return std::make_shared(); + static const std::string &name = "MKLDNN FullyConected post-quantization optimization pass"; + auto property = std::make_shared(); + property->SetAttr("property_name", name); + property->SetAttr("inference_only", true); + return property; } nnvm::NodePtr CreateSubgraphNode(const nnvm::Symbol &sym, @@ -189,7 +189,7 @@ class SgMKLDNNFCPostQuantizeProperty : public SubgraphProperty { SubgraphSelectorPtr CreateSubgraphSelector() const override { auto selector = - std::make_shared(disable_all, + std::make_shared(disable_fuse_all, disable_float_output); return selector; } @@ -204,14 +204,12 @@ class SgMKLDNNFCPostQuantizeProperty : public SubgraphProperty { } private: - bool disable_all; bool disable_fuse_all; bool disable_float_output; }; -MXNET_REGISTER_SUBGRAPH_PROPERTY(MKLDNN_POST_FC_QUANTIZE, SgMKLDNNFCPostQuantizeProperty); - } // namespace op } // namespace mxnet #endif // if MXNET_USE_MKLDNN == 1 +#endif // MXNET_OPERATOR_SUBGRAPH_MKLDNN_MKLDNN_FC_POST_QUANTIZE_PROPERTY_H_ diff --git a/src/operator/subgraph/mkldnn/mkldnn_fc_property.cc b/src/operator/subgraph/mkldnn/mkldnn_fc_property.h similarity index 85% rename from src/operator/subgraph/mkldnn/mkldnn_fc_property.cc rename to src/operator/subgraph/mkldnn/mkldnn_fc_property.h index e4fa02d4e713..04e140c72d86 100644 --- a/src/operator/subgraph/mkldnn/mkldnn_fc_property.cc +++ b/src/operator/subgraph/mkldnn/mkldnn_fc_property.h @@ -24,8 +24,12 @@ * \author Ciyong Chen */ +#ifndef MXNET_OPERATOR_SUBGRAPH_MKLDNN_MKLDNN_FC_PROPERTY_H_ +#define MXNET_OPERATOR_SUBGRAPH_MKLDNN_MKLDNN_FC_PROPERTY_H_ #if MXNET_USE_MKLDNN == 1 +#include +#include #include "../common.h" #include "../subgraph_property.h" @@ -42,19 +46,16 @@ class SgMKLDNNFCSelector : public SubgraphSelector { }; private: - bool disable_all; bool disable_fc_relu; SelectStatus status; std::vector matched_list; public: - SgMKLDNNFCSelector(const bool dis_all, const bool dis_fc_relu) - : disable_all(dis_all), - disable_fc_relu(dis_fc_relu) {} + explicit SgMKLDNNFCSelector(const bool dis_fc_relu) : disable_fc_relu(dis_fc_relu) {} bool Select(const nnvm::Node &n) override { if (n.op() == Op::Get("FullyConnected")) { - status = disable_all ? kSuccess : kStart; + status = disable_fc_relu ? kSuccess : kStart; matched_list.clear(); matched_list.push_back(&n); return true; @@ -86,8 +87,7 @@ class SgMKLDNNFCSelector : public SubgraphSelector { switch (status) { case kStart: - if ((!disable_fc_relu) && - new_node.op() == Op::Get("Activation") && + if (new_node.op() == Op::Get("Activation") && new_node.attrs.dict.at("act_type") == "relu") { matched_list.push_back(&new_node); status = kSuccess; @@ -120,19 +120,19 @@ class SgMKLDNNFCSelector : public SubgraphSelector { class SgMKLDNNFCProperty : public SubgraphProperty { public: SgMKLDNNFCProperty() { - disable_all = dmlc::GetEnv("MXNET_DISABLE_MKLDNN_OPT", false); disable_fc_relu = dmlc::GetEnv("MXNET_DISABLE_MKLDNN_FUSE_FC_RELU", false); - - disable_all = disable_all || disable_fc_relu; - if (disable_all) { - LOG(INFO) << "MKLDNN FullyConnected optimization pass is disabled."; - } else { - LOG(INFO) << "Start to execute MKLDNN FullyConnected optimization pass."; - } } static SubgraphPropertyPtr Create() { - return std::make_shared(); + static const std::string &name = "MKLDNN FullyConnected optimization pass"; + if (dmlc::GetEnv("MXNET_DISABLE_MKLDNN_FC_OPT", 0)) { + LOG(INFO) << name << " is disabled."; + return nullptr; + } + auto property = std::make_shared(); + property->SetAttr("property_name", name); + property->SetAttr("inference_only", true); + return property; } nnvm::NodePtr CreateSubgraphNode(const nnvm::Symbol &sym, @@ -165,8 +165,7 @@ class SgMKLDNNFCProperty : public SubgraphProperty { } SubgraphSelectorPtr CreateSubgraphSelector() const override { - auto selector = std::make_shared( - disable_all, disable_fc_relu); + auto selector = std::make_shared(disable_fc_relu); return selector; } @@ -181,13 +180,11 @@ class SgMKLDNNFCProperty : public SubgraphProperty { } private: - bool disable_all; bool disable_fc_relu; }; -MXNET_REGISTER_SUBGRAPH_PROPERTY(MKLDNN_FC, SgMKLDNNFCProperty); - } // namespace op } // namespace mxnet #endif // if MXNET_USE_MKLDNN == 1 +#endif // MXNET_OPERATOR_SUBGRAPH_MKLDNN_MKLDNN_FC_PROPERTY_H_ diff --git a/src/operator/subgraph/mkldnn/mkldnn_subgraph_property.cc b/src/operator/subgraph/mkldnn/mkldnn_subgraph_property.cc new file mode 100644 index 000000000000..4d54f02a5844 --- /dev/null +++ b/src/operator/subgraph/mkldnn/mkldnn_subgraph_property.cc @@ -0,0 +1,38 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#if MXNET_USE_MKLDNN == 1 + +#include "mkldnn_conv_property.h" +#include "mkldnn_fc_property.h" +#include "mkldnn_conv_post_quantize_property.h" +#include "mkldnn_fc_post_quantize_property.h" + +namespace mxnet { +namespace op { + +MXNET_REGISTER_SUBGRAPH_PROPERTY(MKLDNN, SgMKLDNNConvProperty); +MXNET_REGISTER_SUBGRAPH_PROPERTY(MKLDNN, SgMKLDNNFCProperty); +MXNET_REGISTER_SUBGRAPH_PROPERTY(MKLDNN_POST_QUANTIZE, SgMKLDNNConvPostQuantizeProperty); +MXNET_REGISTER_SUBGRAPH_PROPERTY(MKLDNN_POST_QUANTIZE, SgMKLDNNFCPostQuantizeProperty); + +} // namespace op +} // namespace mxnet + +#endif // MXNET_USE_MKLDNN == 1 diff --git a/src/operator/subgraph/partition_graph.cc b/src/operator/subgraph/partition_graph.cc index 90a14caa510b..4c7552ec9bfe 100644 --- a/src/operator/subgraph/partition_graph.cc +++ b/src/operator/subgraph/partition_graph.cc @@ -740,6 +740,10 @@ Graph PartitionGraph(Graph&& g) { } using namespace sg; const SubgraphPropertyPtr& subg_prop = g.GetAttr("subgraph_property"); + const std::string& prop_name = subg_prop->HasAttr("property_name") + ? subg_prop->GetAttr("property_name") + : "partition graph"; + LOG(INFO) << "start to execute " << prop_name << "."; // top sort NodeEntry of all the nodes' inputs std::unordered_map entry_top_order_map; TopSortEntries(g, &entry_top_order_map); diff --git a/src/operator/subgraph/subgraph_property.h b/src/operator/subgraph/subgraph_property.h index d115d3498e86..4c05c089426f 100644 --- a/src/operator/subgraph/subgraph_property.h +++ b/src/operator/subgraph/subgraph_property.h @@ -146,7 +146,7 @@ class SubgraphProperty { return nnvm::get(*it->second); } /*! - * \brief Check if the attr exist. + * \brief Check if the attr exists. */ bool HasAttr(const std::string& name) const { auto it = attrs_.find(name); @@ -167,35 +167,34 @@ class SubgraphPropertyRegistry { return &inst; } - SubgraphPropertyPtr CreateSubgraphProperty(const std::string& name) { + std::vector CreateSubgraphProperty(const std::string& name) { auto it = prop_fn_map_.find(name); CHECK(it != prop_fn_map_.end()) << "SubgraphProperty " << name << " is not found in SubgraphPropertyRegistry"; - return it->second(); - } - - SubgraphPropertyCreateFn __REGISTER_OR_GET__(const std::string& name, - SubgraphPropertyCreateFn fn) { - if (prop_fn_map_.count(name) == 0U) { - return __REGISTER__(name, fn); - } else { - return prop_fn_map_.at(name); + std::vector ret; + ret.reserve(it->second.size()); + for (auto i : it->second) { + auto ptr_it = prop_ptr_map_.find(i); + if (ptr_it == prop_ptr_map_.end()) { + prop_ptr_map_[i] = i(); + ptr_it = prop_ptr_map_.find(i); + } + if (ptr_it->second) ret.emplace_back(ptr_it->second); } + return ret; } - private: SubgraphPropertyCreateFn __REGISTER__(const std::string& name, SubgraphPropertyCreateFn fn) { - CHECK_EQ(prop_fn_map_.count(name), 0U) << "Subgraph property " << name - << " has been registered"; - prop_fn_map_[name] = fn; - return prop_fn_map_[name]; + prop_fn_map_[name].push_back(fn); + return fn; } SubgraphPropertyRegistry() = default; SubgraphPropertyRegistry(const SubgraphPropertyRegistry&) = delete; SubgraphPropertyRegistry(SubgraphPropertyRegistry&&) = delete; SubgraphPropertyRegistry& operator=(const SubgraphPropertyRegistry&) = delete; - std::unordered_map prop_fn_map_; + std::unordered_map> prop_fn_map_; + std::unordered_map prop_ptr_map_; }; // This op name set is for setting the names of operators that should be grouped into @@ -205,9 +204,14 @@ class SubgraphPropertyRegistry { typedef dmlc::ThreadLocalStore>> SubgraphPropertyOpNameSet; +#define DECLARE_PROPERTY_EX(NAME, SubgraphPropertyType, X) \ + static const DMLC_ATTRIBUTE_UNUSED auto __make_##SubgraphPropertyType##_##Name##_##X##__ +#define DECLARE_PROPERTY(NAME, SubgraphPropertyType, X) \ + DECLARE_PROPERTY_EX(NAME, SubgraphPropertyType, X) + #define MXNET_REGISTER_SUBGRAPH_PROPERTY(Name, SubgraphPropertyType) \ - static DMLC_ATTRIBUTE_UNUSED auto __make_ ## SubgraphPropertyType ## _ ## Name ## __ = \ - SubgraphPropertyRegistry::Get()->__REGISTER_OR_GET__(#Name, &SubgraphPropertyType::Create) + DECLARE_PROPERTY(Name, SubgraphPropertyType, __LINE__) = \ + SubgraphPropertyRegistry::Get()->__REGISTER__(#Name, &SubgraphPropertyType::Create) } // namespace op } // namespace mxnet diff --git a/tests/python/mkl/test_subgraph.py b/tests/python/mkl/test_subgraph.py index 871c1e3d566a..3213fb13a218 100644 --- a/tests/python/mkl/test_subgraph.py +++ b/tests/python/mkl/test_subgraph.py @@ -48,8 +48,8 @@ 'fc': { OP_NAME: 'sg_mkldnn_fully_connected', QUANTIZED_OP_NAME: 'quantized_sg_mkldnn_fully_connected', - SG_PASS_NAME: 'MKLDNN_FC', - POST_SG_PASS_NAME: 'MKLDNN_POST_FC_QUANTIZE' + SG_PASS_NAME: 'MKLDNN', + POST_SG_PASS_NAME: 'MKLDNN_POST_QUANTIZE' } }