Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Enhance subgraph API #14113

Merged
merged 18 commits into from
Mar 31, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion include/mxnet/c_api_test.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ extern "C" {
* to the input graph for partitioning. This function should be
* used only for the testing purpose.
*/
MXNET_DLL int MXPartitionGraphByOpNames(SymbolHandle sym_handle,
MXNET_DLL int MXBuildSubgraphByOpNames(SymbolHandle sym_handle,
const char* prop_name,
const mx_uint num_ops,
const char** op_names,
Expand Down
2 changes: 1 addition & 1 deletion src/c_api/c_api_symbolic.cc
Original file line number Diff line number Diff line change
Expand Up @@ -728,7 +728,7 @@ int MXGenBackendSubgraph(SymbolHandle sym_handle, const char *backend,
nnvm::Graph g = Symbol2Graph(*s);
property->SetAttr("graph", g);
g.attrs["subgraph_property"] = std::make_shared<nnvm::any>(std::move(property));
g = nnvm::ApplyPass(std::move(g), "PartitionGraph");
g = ApplyPass(std::move(g), "BuildSubgraph");
s->outputs = g.outputs;
}
*ret_sym_handle = s;
Expand Down
4 changes: 2 additions & 2 deletions src/c_api/c_api_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
#include "./c_api_common.h"
#include "../operator/subgraph/subgraph_property.h"

int MXPartitionGraphByOpNames(SymbolHandle sym_handle,
int MXBuildSubgraphByOpNames(SymbolHandle sym_handle,
const char* prop_name,
const mx_uint num_ops,
const char** op_names,
Expand All @@ -49,7 +49,7 @@ int MXPartitionGraphByOpNames(SymbolHandle sym_handle,
property->SetAttr("graph", g);
property->SetAttr("op_names", op_name_set);
g.attrs["subgraph_property"] = std::make_shared<nnvm::any>(std::move(property));
g = nnvm::ApplyPass(std::move(g), "PartitionGraph");
g = nnvm::ApplyPass(std::move(g), "BuildSubgraph");
s->outputs = g.outputs;
}
}
Expand Down
53 changes: 26 additions & 27 deletions src/executor/graph_executor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1442,30 +1442,29 @@ static nnvm::Graph InferForwardAttrs(nnvm::Graph g,

// Given input attr arrays, partition the graph using the backend name equal to prop_name.
// This is a common function for bind and simple_bind flows.
static nnvm::Symbol PartitionGraph(const nnvm::Symbol& src,
mxnet::op::SubgraphPropertyPtr subgraph_prop,
const mxnet::ShapeVector& arg_shapes,
const nnvm::DTypeVector& arg_dtypes,
const StorageTypeVector& arg_stypes,
const Context& default_ctx,
const std::map<std::string, Context>& ctx_map,
const std::vector<Context>& in_arg_ctxes,
const std::vector<Context>& aux_state_ctxes) {
static nnvm::Symbol BuildSubgraph(const nnvm::Symbol& src,
mxnet::op::SubgraphPropertyPtr subgraph_prop,
const mxnet::ShapeVector& arg_shapes,
const nnvm::DTypeVector& arg_dtypes,
const StorageTypeVector& arg_stypes, const Context& default_ctx,
const std::map<std::string, Context>& ctx_map,
const std::vector<Context>& in_arg_ctxes,
const std::vector<Context>& aux_state_ctxes) {
nnvm::Symbol ret = src.Copy();
nnvm::Graph g;
g.outputs = ret.outputs;
g = InferForwardAttrs(g, arg_shapes, arg_dtypes, arg_stypes, default_ctx, ctx_map, in_arg_ctxes,
aux_state_ctxes);
subgraph_prop->SetAttr("graph", g);
g.attrs["subgraph_property"] = std::make_shared<nnvm::any>(std::move(subgraph_prop));
g = ApplyPass(std::move(g), "PartitionGraph");
g = ApplyPass(std::move(g), "BuildSubgraph");
ret.outputs = g.outputs;
return ret;
}

// Given input attr dicts, partition the graph using the backend name equal to prop_name.
// This is for simple_bind flow.
static nnvm::Symbol PartitionGraph(const nnvm::Symbol& src,
static nnvm::Symbol BuildSubgraph(const nnvm::Symbol& src,
const std::string& prop_name,
const std::unordered_map<std::string, mxnet::TShape>
& arg_shape_map,
Expand Down Expand Up @@ -1547,7 +1546,7 @@ static nnvm::Symbol PartitionGraph(const nnvm::Symbol& src,
arg_stypes[i] = it3->second;
}
}
ret = PartitionGraph(ret, subgraph_prop, arg_shapes, arg_dtypes, arg_stypes, default_ctx,
ret = BuildSubgraph(ret, subgraph_prop, arg_shapes, arg_dtypes, arg_stypes, default_ctx,
ctx_map, *in_arg_ctxes, *aux_state_ctxes);
// Reorder in_arg_ctxes, arg_grad_ctxes, aux_state_ctxes and grad_req_types according to
// partitioned symbol input sequence
Expand All @@ -1573,13 +1572,13 @@ static nnvm::Symbol PartitionGraph(const nnvm::Symbol& src,

// Given input ndarrays, partition the graph using the backend name equal to prop_name.
// This is for bind flow.
static nnvm::Symbol PartitionGraph(const nnvm::Symbol& src, const std::string& prop_name,
const Context& default_ctx,
const std::map<std::string, Context>& ctx_map,
std::vector<NDArray>* in_args,
std::vector<NDArray>* arg_grad_store,
std::vector<OpReqType>* grad_req_type,
std::vector<NDArray>* aux_states) {
static nnvm::Symbol BuildSubgraph(const nnvm::Symbol& src, const std::string& prop_name,
const Context& default_ctx,
const std::map<std::string, Context>& ctx_map,
std::vector<NDArray>* in_args,
std::vector<NDArray>* arg_grad_store,
std::vector<OpReqType>* grad_req_type,
std::vector<NDArray>* aux_states) {
// setup map for in_args, arg_grad_store, grad_req_type and aux_states
std::unordered_map<std::string, NDArray> in_args_map;
std::unordered_map<std::string, NDArray> arg_grad_store_map;
Expand Down Expand Up @@ -1664,8 +1663,8 @@ static nnvm::Symbol PartitionGraph(const nnvm::Symbol& src, const std::string& p
}
}

ret = PartitionGraph(ret, subgraph_prop, arg_shapes, arg_dtypes, arg_stypes, default_ctx,
ctx_map, in_arg_ctxes, aux_state_ctxes);
ret = BuildSubgraph(ret, subgraph_prop, arg_shapes, arg_dtypes, arg_stypes, default_ctx,
ctx_map, in_arg_ctxes, aux_state_ctxes);
}
// Reorder in_args, arg_grad_store, grad_req_type and aux_states according to partitioned symbol
// input sequence
Expand Down Expand Up @@ -1713,9 +1712,9 @@ Executor *Executor::SimpleBind(nnvm::Symbol symbol,
std::vector<Context> tmp_aux_state_ctxes = aux_state_ctxes;
std::vector<OpReqType> tmp_grad_req_types = grad_req_types;
if (!exec->subgraph_property().empty()) {
symbol = exec::PartitionGraph(symbol, exec->subgraph_property(), arg_shape_map, arg_dtype_map,
arg_stype_map, default_ctx, group2ctx, &tmp_in_arg_ctxes,
&tmp_arg_grad_ctxes, &tmp_grad_req_types, &tmp_aux_state_ctxes);
symbol = exec::BuildSubgraph(symbol, exec->subgraph_property(), arg_shape_map, arg_dtype_map,
arg_stype_map, default_ctx, group2ctx, &tmp_in_arg_ctxes,
&tmp_arg_grad_ctxes, &tmp_grad_req_types, &tmp_aux_state_ctxes);
}
exec->Init(symbol, default_ctx, group2ctx, tmp_in_arg_ctxes, tmp_arg_grad_ctxes,
tmp_aux_state_ctxes, arg_shape_map, arg_dtype_map, arg_stype_map, tmp_grad_req_types,
Expand All @@ -1738,9 +1737,9 @@ Executor *Executor::Bind(nnvm::Symbol symbol,
std::vector<NDArray> tmp_aux_states = aux_states;

if (!exec->subgraph_property().empty()) {
symbol = exec::PartitionGraph(symbol, exec->subgraph_property(), default_ctx, group2ctx,
&tmp_in_args, &tmp_arg_grad_store, &tmp_grad_req_type,
&tmp_aux_states);
symbol =
exec::BuildSubgraph(symbol, exec->subgraph_property(), default_ctx, group2ctx, &tmp_in_args,
&tmp_arg_grad_store, &tmp_grad_req_type, &tmp_aux_states);
}
exec->Init(symbol, default_ctx, group2ctx, tmp_in_args, tmp_arg_grad_store, tmp_grad_req_type,
tmp_aux_states, reinterpret_cast<Executor*>(shared_exec));
Expand Down
Loading