Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Enhance PartitionGraph #14277

Merged
merged 17 commits into from
Mar 26, 2019
Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 37 additions & 3 deletions docs/tutorials/c++/subgraphAPI.md
Original file line number Diff line number Diff line change
Expand Up @@ -97,23 +97,57 @@ class SgProperty : public SubgraphProperty {
return n;
}
SubgraphSelectorPtr CreateSubgraphSelector() const override {
return std::make_shared<SgSelector>();
auto property = std::make_shared<CreateSubgraphSelector>();
property->SetAttr<std::string>("property_name", "subgraph example pass"); // Optional, better to have it.
property->SetAttr<bool>("inference_only", true); // Optional, only for inference_only pass.
return property;
}
};
```
`SetAttr` is optional and developer can define their own attributes to control property behavior.
There're 2 built-in attributes that used by MXNet executor.

After defining the subgraph property, we need to register it.
`property_name` : std::string, name of this property.

`inference_only` : bool, apply this property only for inference. Property will be skiped when need_grad=True. Default `false` if this attribute isn't defined.

After defining the subgraph property, we need to register it in .cc file.

```C++
MXNET_REGISTER_SUBGRAPH_PROPERTY(SgTest, SgProperty);
```

After compiling this subgraph mechanism into MXNet, we can use the environment variable `MXNET_SUBGRAPH_BACKEND` to activate it.
It's possible to register multiple properties for same backend. In practice, we recommend to put each property definition into .h file, and register backend in single .cc file. Property will be executed according to the register order.

```C++
#include "SgProperty.h" // Define SgProperty class
#include "SgProperty2.h" // Define SgProperty2 class
#include "SgProperty3.h" // Define SgProperty3 class

MXNET_REGISTER_SUBGRAPH_PROPERTY(SgTest, SgProperty); // Execution order 1.
MXNET_REGISTER_SUBGRAPH_PROPERTY(SgTest, SgProperty2); // Execution order 2.
MXNET_REGISTER_SUBGRAPH_PROPERTY(SgTest, SgProperty3); // Execution order 3.
```

After compiling this subgraph mechanism into MXNet, we can use the environment variable `MXNET_SUBGRAPH_BACKEND` to activate it during symbol bind.

```bash
export MXNET_SUBGRAPH_BACKEND=SgTest
```

Or you can use python symbol API `get_backend_symbol` to run all properties registered for this backend and get returned symbol.

```Python
sym, arg_params, aux_params = mx.model.load_checkpoint(prefix, epoch)
sym = sym.get_backend_symbol('SgTest')
```

When `SgProperty` is activated, a message will be shown in terminal as

```bash
start to execute subgraph example pass.
```

This tutorial shows a simple example of how to use the subgraph API to search for patterns in an NNVM graph.
Intested users can try different pattern matching rules (i.e., define their own `SubgraphSelector`) and
attach different operators to execute the subgraphs.
Expand Down
2 changes: 0 additions & 2 deletions example/quantization/imagenet_gen_qsym_mkldnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,6 @@ def save_params(fname, arg_params, aux_params, logger=None):
sym, arg_params, aux_params = mx.model.load_checkpoint(prefix, epoch)

sym = sym.get_backend_symbol('MKLDNN')
sym = sym.get_backend_symbol('MKLDNN_FC')

# get batch size
batch_size = args.batch_size
Expand Down Expand Up @@ -303,7 +302,6 @@ def save_params(fname, arg_params, aux_params, logger=None):
% calib_mode)
sym_name = '%s-symbol.json' % (prefix + suffix)
qsym = qsym.get_backend_symbol('MKLDNN_POST_QUANTIZE')
qsym = qsym.get_backend_symbol('MKLDNN_POST_FC_QUANTIZE')
save_symbol(sym_name, qsym, logger)
param_name = '%s-%04d.params' % (prefix + '-quantized', epoch)
save_params(param_name, qarg_params, aux_params, logger)
17 changes: 9 additions & 8 deletions src/c_api/c_api_symbolic.cc
Original file line number Diff line number Diff line change
Expand Up @@ -722,14 +722,15 @@ int MXGenBackendSubgraph(SymbolHandle sym_handle, const char *backend,
API_BEGIN();
nnvm::Symbol *sym = static_cast<nnvm::Symbol *>(sym_handle);
*s = sym->Copy();
nnvm::Graph g = Symbol2Graph(*s);
mxnet::op::SubgraphPropertyPtr property =
mxnet::op::SubgraphPropertyRegistry::Get()->CreateSubgraphProperty(
backend);
g.attrs["subgraph_property"] =
std::make_shared<nnvm::any>(std::move(property));
g = ApplyPass(std::move(g), "PartitionGraph");
s->outputs = g.outputs;
std::vector<mxnet::op::SubgraphPropertyPtr> properties =
mxnet::op::SubgraphPropertyRegistry::Get()->CreateSubgraphProperty(backend);
for (auto property : properties) {
nnvm::Graph g = Symbol2Graph(*s);
property->SetAttr("graph", g);
g.attrs["subgraph_property"] = std::make_shared<nnvm::any>(std::move(property));
g = nnvm::ApplyPass(std::move(g), "PartitionGraph");
s->outputs = g.outputs;
}
*ret_sym_handle = s;
API_END_HANDLE_ERROR(delete s);
}
19 changes: 11 additions & 8 deletions src/c_api/c_api_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -40,16 +40,19 @@ int MXPartitionGraphByOpNames(SymbolHandle sym_handle,
}
nnvm::Symbol* sym = static_cast<nnvm::Symbol*>(sym_handle);
*s = sym->Copy();
nnvm::Graph g;
g.outputs = s->outputs;
if (!op_name_set.empty()) {
mxnet::op::SubgraphPropertyPtr property
= mxnet::op::SubgraphPropertyRegistry::Get()->CreateSubgraphProperty(prop_name);
property->SetAttr("op_names", op_name_set);
g.attrs["subgraph_property"] = std::make_shared<nnvm::any>(std::move(property));
std::vector<mxnet::op::SubgraphPropertyPtr> properties =
mxnet::op::SubgraphPropertyRegistry::Get()->CreateSubgraphProperty(prop_name);
for (auto property : properties) {
nnvm::Graph g;
g.outputs = s->outputs;
property->SetAttr("graph", g);
property->SetAttr("op_names", op_name_set);
g.attrs["subgraph_property"] = std::make_shared<nnvm::any>(std::move(property));
g = nnvm::ApplyPass(std::move(g), "PartitionGraph");
s->outputs = g.outputs;
}
}
g = nnvm::ApplyPass(std::move(g), "PartitionGraph");
s->outputs = g.outputs;
*ret_sym_handle = s;
API_END_HANDLE_ERROR(delete s);
}
Expand Down
Loading