diff --git a/src/c_api/c_api_symbolic.cc b/src/c_api/c_api_symbolic.cc index 19a6f35b50c2..647521e0109d 100644 --- a/src/c_api/c_api_symbolic.cc +++ b/src/c_api/c_api_symbolic.cc @@ -31,6 +31,7 @@ #include "./c_api_common.h" #include "../operator/operator_common.h" #include "../executor/exec_pass.h" +#include "../operator/subgraph/subgraph_op.h" namespace mxnet { namespace op { @@ -639,7 +640,11 @@ int MXPartitionGraph(SymbolHandle sym_handle, nnvm::Symbol* sym = static_cast(sym_handle); *s = sym->Copy(); nnvm::Graph g = Symbol2Graph(*s); - g.attrs["subgraph_op_names"] = std::make_shared(std::move(op_name_set)); + if (!op_name_set.empty()) { + mxnet::op::SubgraphPropertyPtr property + = std::make_shared(op_name_set); + g.attrs["subgraph_property"] = std::make_shared(std::move(property)); + } g = ApplyPass(std::move(g), "PartitionGraph"); s->outputs = g.outputs; *ret_sym_handle = s; diff --git a/src/operator/subgraph/partition_graph.cc b/src/operator/subgraph/partition_graph.cc index 71fe7654ab6c..4a27ba5c7bfc 100644 --- a/src/operator/subgraph/partition_graph.cc +++ b/src/operator/subgraph/partition_graph.cc @@ -28,7 +28,14 @@ #include #include +#include "./subgraph_op.h" + +namespace nnvm { +NodePtr CreateVariableNode(const std::string& name); +} + namespace mxnet { + namespace op { using nnvm::Symbol; @@ -50,24 +57,6 @@ NodePtr CloneVariableNode(const nnvm::Node& src) { namespace sg { // sg stands for subgraph -struct SimpleNode; -using SimpleNodePtr = std::shared_ptr; - -struct SimpleNode { - static SimpleNodePtr Create() { - return std::make_shared(); - } - SimpleNode() : label(-1), node(nullptr) {} - int label; - nnvm::Node* node; - // key is node ptr - // value is the index array standing for the entry indices - // in key->inputs that use this->node as input node - std::unordered_map> outputs; - //std::unordered_map inputs; - //std::unordered_map outputs; -}; - void CreateSimpleGraph(const Graph& g, std::vector* simple_nodes) { const auto& indexed_graph = g.indexed_graph(); @@ -111,8 +100,13 @@ void PrintSubgraph(const std::vector& simple_nodes) { LOG(INFO) << "Subgraph node names: " << op_names; } +/* + * This function traverses the nodes in a computation graph from a starting + * node following the input links and output links, and marks all nodes that + * can be accessed from the starting node. + */ void LabelSubgraph(const Graph&g, - const std::unordered_set& op_names, + SubgraphSelectorPtr select_func, const int label, const size_t snid, // simple node id const std::vector& simple_nodes, @@ -126,36 +120,40 @@ void LabelSubgraph(const Graph&g, cur_node->label = label; subgraph_nodes->push_back(cur_node); // get qualified adjacent input nodes - for (auto& e : cur_node->node->inputs) { - if (!e.node->is_variable() && op_names.count(e.node->op()->name)) { - const auto nid = indexed_graph.node_id(e.node.get()); - CHECK_LT(nid, simple_nodes.size()); - if (simple_nodes[nid]->label == -1) { // this node has not been visited yet - node_queue.push(simple_nodes[nid].get()); - } else { - CHECK_EQ(simple_nodes[nid]->label, label); + if (select_func->UseIncomingEdges()) { + for (auto& e : cur_node->node->inputs) { + if (select_func->Select(*e.node)) { + const auto nid = indexed_graph.node_id(e.node.get()); + CHECK_LT(nid, simple_nodes.size()); + // this node has not been visited yet + if (simple_nodes[nid]->label == -1) + node_queue.push(simple_nodes[nid].get()); } } } // get qualified output nodes - for (auto it = cur_node->outputs.begin(); it != cur_node->outputs.end(); ++it) { - CHECK(!it->first->is_variable()); - if (op_names.count(it->first->op()->name)) { - const auto nid = indexed_graph.node_id(it->first); - CHECK_LT(nid, simple_nodes.size()); - if (simple_nodes[nid]->label == -1) { // this node has not been visited yet - node_queue.push(simple_nodes[nid].get()); - } else { - CHECK_EQ(simple_nodes[nid]->label, label); + if (select_func->UseOutgoingEdges()) { + for (auto it = cur_node->outputs.begin(); it != cur_node->outputs.end(); ++it) { + if (select_func->Select(*it->first)) { + const auto nid = indexed_graph.node_id(it->first); + CHECK_LT(nid, simple_nodes.size()); + // this node has not been visited yet + if (simple_nodes[nid]->label == -1) + node_queue.push(simple_nodes[nid].get()); } } } } } -// number of subgraphs found +/* + * This function finds subgraphs with all nodes that meet certain criteria. + * All nodes in a subgraph are marked with the same label. + * All nodes in a subgraph have to be connected with each other. If a node + * doesn't meet the given criteria, it will be marked with a separate label. + */ void FindSubgraphs(const Graph& g, - const std::unordered_set& op_names, + const SubgraphProperty &subg_prop, const std::vector& simple_nodes, std::vector>* subgraph_nodes) { //CHECK(simple_nodes != nullptr); @@ -163,9 +161,11 @@ void FindSubgraphs(const Graph& g, CHECK_EQ(indexed_graph.num_nodes(), simple_nodes.size()); for (size_t i = 0; i < simple_nodes.size(); ++i) { nnvm::Node* node = simple_nodes[i]->node; - if (!node->is_variable() && simple_nodes[i]->label == -1 && op_names.count(node->op()->name)) { + auto select_func = subg_prop.CreateSubgraphSelector(); + if (select_func->Select(*node) && simple_nodes[i]->label == -1) { subgraph_nodes->emplace_back(); - LabelSubgraph(g, op_names, subgraph_nodes->size() - 1, i, simple_nodes, &subgraph_nodes->back()); + LabelSubgraph(g, select_func, subgraph_nodes->size() - 1, i, simple_nodes, + &subgraph_nodes->back()); } } } @@ -185,11 +185,9 @@ void FindInputEntries(const Graph& g, } for (auto& e : subgraph_nodes[i]->node->inputs) { const auto nid = indexed_graph.node_id(e.node.get()); - if (simple_nodes[nid]->label == -1) { // this is a node not belonging to the subgraph + // this is a node not belonging to the subgraph + if (simple_nodes[nid]->label != label) input_entries->push_back(&e); - } else { - CHECK_EQ(simple_nodes[nid]->label, label); - } } } } @@ -208,14 +206,15 @@ void FindOutputEntries(Graph* g, } else { CHECK_EQ(subgraph_nodes[i]->label, label); } - for (auto it = subgraph_nodes[i]->outputs.begin(); it != subgraph_nodes[i]->outputs.end(); ++it) { + for (auto it = subgraph_nodes[i]->outputs.begin(); + it != subgraph_nodes[i]->outputs.end(); ++it) { const auto nid = indexed_graph.node_id(it->first); - if (simple_nodes[nid]->label == -1) { // this is a node not belonging to the subgraph + // this is a node not belonging to the subgraph + if (simple_nodes[nid]->label != label) { + // TODO(zhengda) I need to test this. for (int idx : it->second) { output_entries->push_back(&simple_nodes[nid]->node->inputs[idx]); } - } else { - CHECK_EQ(simple_nodes[nid]->label, label); } } } @@ -241,6 +240,29 @@ void PrintNodeEntries(const std::vector& entries) { } } +/* + * Given a computation graph and a set of input node entries, this function cuts + * the node entries and creates new variable nodes as the input nodes of the + * subgraph. It returns the nodes that connect to the subgraph directly and + * the names of the new variable nodes. + */ +void CutGraphInputs(const std::vector &input_entries, + bool skip_var, std::vector *orig_entries) { + orig_entries->reserve(input_entries.size()); + for (size_t i = 0; i < input_entries.size(); i++) { + nnvm::NodeEntry *e = input_entries[i]; + // If the node is a variable itself, we may want to skip the node. + if (e->node->is_variable() && skip_var) + continue; + + orig_entries->push_back(*e); + nnvm::Symbol sym; + sym.outputs.push_back(*e); + nnvm::NodePtr n = nnvm::CreateVariableNode(sym.ListOutputNames()[0]); + *e = nnvm::NodeEntry{n, 0, 0}; + } +} + } // namespace sg Graph PartitionGraph(Graph&& g) { @@ -252,8 +274,7 @@ Graph PartitionGraph(Graph&& g) { } }); #endif - const std::unordered_set& op_names = g.GetAttr>("subgraph_op_names"); - if (op_names.empty()) { // treat the whole graph as a subgraph + if (!g.HasAttr("subgraph_property")) { // treat the whole graph as a subgraph Symbol whole_graph_sym; whole_graph_sym.outputs = g.outputs; // DO NOT define node name for subgraph op because it would serve @@ -283,23 +304,40 @@ Graph PartitionGraph(Graph&& g) { return ret; } else { using namespace sg; + SubgraphPropertyPtr subg_prop = g.GetAttr("subgraph_property"); std::vector simple_nodes; CreateSimpleGraph(g, &simple_nodes); std::vector> subgraph_nodes; - FindSubgraphs(g, op_names, simple_nodes, &subgraph_nodes); + FindSubgraphs(g, *subg_prop, simple_nodes, &subgraph_nodes); std::vector entries; // TODO(junwu): take care of the situation when the op is the last op for (size_t i = 0; i < subgraph_nodes.size(); ++i) { PrintSubgraph(subgraph_nodes[i]); + // Break the input links. LOG(INFO) << "Searching for input entries..."; entries.clear(); FindInputEntries(g, simple_nodes, subgraph_nodes[i], &entries); + std::vector orig_input_entries; + sg::CutGraphInputs(entries, false, &orig_input_entries); PrintNodeEntries(entries); LOG(INFO) << "Searching for output entries..."; entries.clear(); FindOutputEntries(&g, simple_nodes, subgraph_nodes[i], &entries); + + // Create a subgraph. + nnvm::Symbol sym; + sym.outputs.resize(entries.size()); + for (size_t i = 0; i < entries.size(); i++) + sym.outputs[i] = *entries[i]; + nnvm::NodePtr n = subg_prop->CreateSubgraphNode(sym); + + // Connect the external nodes to the subgraph node. + for (uint32_t i = 0; i < entries.size(); i++) + *entries[i] = nnvm::NodeEntry{n, i, 0}; + // TODO(zhengda) this may not be the right order for input entries of a subgraph? + n->inputs = orig_input_entries; PrintNodeEntries(entries); } return g; diff --git a/src/operator/subgraph/subgraph_op.cc b/src/operator/subgraph/subgraph_op.cc index 637e3326eda9..326ee34cfe8c 100644 --- a/src/operator/subgraph/subgraph_op.cc +++ b/src/operator/subgraph/subgraph_op.cc @@ -27,20 +27,29 @@ #include "../operator_common.h" #include "../elemwise_op_common.h" #include "../../imperative/imperative_utils.h" +#include "./subgraph_op.h" namespace mxnet { namespace op { -struct SubgraphOpState { +static std::unordered_map subg_props; + +void RegisterSubgraphProperty(SubgraphPropertyPtr property) { + auto ret = subg_props.insert(std::pair( + property->GetType(), property)); + CHECK(!ret.second) << "The subgraph property for " << property->GetType() + << " has been registered"; +} + +class DefaultSubgraphOperator: public SubgraphOperator { + public: // TODO: initialize uuid - SubgraphOpState(const Symbol& sym) : - subgraph_sym_(&sym), + DefaultSubgraphOperator(const Symbol& sym) : SubgraphOperator(sym), subgraph_uuid_("dfasdfadsmxdfw324"), immutable_data_names_(sym.ListInputNames(Symbol::kReadOnlyArgs)), mutable_data_names_(sym.ListInputNames(Symbol::kAuxiliaryStates)), //input_data_names_(sym.ListInputNames(Symbol::kAll)), - output_data_names_(sym.ListOutputNames()), - subgraph_executor_(nullptr) { + output_data_names_(sym.ListOutputNames()) { const std::vector input_data_names = sym.ListInputNames(Symbol::kAll); //const std::vector immutable_data_names = sym.ListInputNames(Symbol::kReadOnlyArgs); //const std::vector mutable_data_names = sym.ListInputNames(Symbol::kAuxiliaryStates); @@ -60,8 +69,20 @@ struct SubgraphOpState { ndarray_var_versions_.resize(input_data_names.size(), -1); } - // arguments should have the same order as NDArrays in FCompute - const Symbol* subgraph_sym_; + void Forward(const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs); + void Backward(const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + // TODO we don't support backward yet. + // Many of the subgraphs don't need backward computation. + } + + private: + std::string subgraph_uuid_; // this variable records the NDArrays' var versions of the last run. std::vector ndarray_var_versions_; std::vector immutable_data_indices_; @@ -70,16 +91,88 @@ struct SubgraphOpState { std::vector mutable_data_names_; //std::vector input_data_names_; std::vector output_data_names_; - std::string subgraph_uuid_; std::shared_ptr subgraph_executor_; -}; // SubgraphOpState +}; + +SubgraphOperatorPtr SimpleSubgraphProperty::CreateSubgraphOperator(const nnvm::Symbol &sym) const { + return std::make_shared(sym); +} + +void DefaultSubgraphOperator::Forward(const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + // We can create an executor to run this subgraph op + if (this->subgraph_executor_.get() == nullptr) { + std::vector arg_arrays; + std::vector aux_arrays; + for (size_t i = 0; i < this->immutable_data_indices_.size(); ++i) { + arg_arrays.push_back(inputs[this->immutable_data_indices_[i]]); + } + for (size_t i = 0; i < this->mutable_data_indices_.size(); ++i) { + aux_arrays.push_back(inputs[this->mutable_data_indices_[i]]); + } + std::vector grad_store(arg_arrays.size()); + std::vector grad_req(arg_arrays.size(), kNullOp); + this->subgraph_executor_.reset(Executor::Bind(this->GetSubgraph(), + ctx.run_ctx.ctx, std::map(), arg_arrays, grad_store, + grad_req, aux_arrays)); + } + // TODO: replace the hard-coded integer with inputs[i].var().version + // If var version is old, need to update ndarray in the executor. + const int64_t max_var_version = 12034324324; + for (size_t i = 0; i < this->immutable_data_names_.size(); ++i) { + if (this->ndarray_var_versions_[this->immutable_data_indices_[i]] < max_var_version) { + auto it = this->subgraph_executor_->in_arg_map().find(this->immutable_data_names_[i]); + CHECK(it != this->subgraph_executor_->in_arg_map().end()); + // Commented out because we don't have interface to do it yet + it->second = inputs[this->immutable_data_indices_[i]]; + ++this->ndarray_var_versions_[this->immutable_data_indices_[i]]; + } + } + for (size_t i = 0; i < this->mutable_data_names_.size(); ++i) { + if (this->ndarray_var_versions_[this->mutable_data_indices_[i]] < max_var_version) { + auto it = this->subgraph_executor_->aux_state_map().find(this->mutable_data_names_[i]); + CHECK(it != this->subgraph_executor_->aux_state_map().end()); + // Commented out because we don't have interface to do it yet + it->second = inputs[this->mutable_data_indices_[i]]; + ++this->ndarray_var_versions_[this->mutable_data_indices_[i]]; + } + } + this->subgraph_executor_->Forward(false); + for (size_t i = 0; i < outputs.size(); ++i) { + //NDArray tmp = outputs[i]; + //this->subgraph_executor_->output_arrays()[i].WaitToRead(); + CopyFromTo(this->subgraph_executor_->output_arrays()[i], &outputs[i]); + //tmp = this->subgraph_executor_->output_arrays()[i]; + } +} + +struct SubgraphOpState { + SubgraphOperatorPtr op; + + SubgraphOpState(SubgraphOperatorPtr op) { + this->op = op; + } +}; OpStatePtr CreateSubgraphOpState(const NodeAttrs& attrs, Context ctx, const std::vector& in_shapes, const std::vector& in_types) { const Symbol& subgraph_sym = nnvm::get(attrs.parsed); - return OpStatePtr::Create(subgraph_sym); + auto it = attrs.dict.find("exec_type"); + if (it == attrs.dict.end()) { + auto op = std::make_shared(subgraph_sym); + return OpStatePtr::Create(op); + } + + std::string exec_name = it->second; + auto prop_iter = subg_props.find(exec_name); + CHECK(prop_iter != subg_props.end()) << "We don't support the execution type: " + << exec_name; + auto op = prop_iter->second->CreateSubgraphOperator(subgraph_sym); + return OpStatePtr::Create(op); } bool SubgraphOpShape(const nnvm::NodeAttrs& attrs, @@ -175,62 +268,7 @@ void SubgraphOpForward(const OpStatePtr& state_ptr, const std::vector& req, const std::vector& outputs) { SubgraphOpState& state = state_ptr.get_state(); - // We can create an executor to run this subgraph op - if (state.subgraph_executor_.get() == nullptr) { - std::vector arg_arrays; - std::vector aux_arrays; - for (size_t i = 0; i < state.immutable_data_indices_.size(); ++i) { - arg_arrays.push_back(inputs[state.immutable_data_indices_[i]]); - } - for (size_t i = 0; i < state.mutable_data_indices_.size(); ++i) { - aux_arrays.push_back(inputs[state.mutable_data_indices_[i]]); - } - std::vector grad_store(arg_arrays.size()); - std::vector grad_req(arg_arrays.size(), kNullOp); - state.subgraph_executor_.reset(Executor::Bind(*state.subgraph_sym_, - ctx.run_ctx.ctx, std::map(), arg_arrays, grad_store, - grad_req, aux_arrays)); - } - // TODO: replace the hard-coded integer with inputs[i].var().version - // If var version is old, need to update ndarray in the executor. - const int64_t max_var_version = 12034324324; - for (size_t i = 0; i < state.immutable_data_names_.size(); ++i) { - if (state.ndarray_var_versions_[state.immutable_data_indices_[i]] < max_var_version) { - auto it = state.subgraph_executor_->in_arg_map().find(state.immutable_data_names_[i]); - CHECK(it != state.subgraph_executor_->in_arg_map().end()); - // Commented out because we don't have interface to do it yet - it->second = inputs[state.immutable_data_indices_[i]]; - ++state.ndarray_var_versions_[state.immutable_data_indices_[i]]; - } - } - for (size_t i = 0; i < state.mutable_data_names_.size(); ++i) { - if (state.ndarray_var_versions_[state.mutable_data_indices_[i]] < max_var_version) { - auto it = state.subgraph_executor_->aux_state_map().find(state.mutable_data_names_[i]); - CHECK(it != state.subgraph_executor_->aux_state_map().end()); - // Commented out because we don't have interface to do it yet - it->second = inputs[state.mutable_data_indices_[i]]; - ++state.ndarray_var_versions_[state.mutable_data_indices_[i]]; - } - } - state.subgraph_executor_->Forward(false); - for (size_t i = 0; i < outputs.size(); ++i) { - //NDArray tmp = outputs[i]; - //state.subgraph_executor_->output_arrays()[i].WaitToRead(); - CopyFromTo(state.subgraph_executor_->output_arrays()[i], &outputs[i]); - //tmp = state.subgraph_executor_->output_arrays()[i]; - } - -#if 0 - for (const auto& name : state.immutable_data_names_) { - LOG(INFO) << "SubgraphOpForward: input_data_name = " << name; - } - for (const auto& name : state.output_data_names_) { - LOG(INFO) << "SubgraphOpForward: output_data_name = " << name; - } - for (const auto v : state.ndarray_var_versions_) { - LOG(INFO) << "SubgraphOpForward: var_version = " << v; - } -#endif + state.op->Forward(ctx, inputs, req, outputs); } NNVM_REGISTER_OP(_subgraph_op) diff --git a/src/operator/subgraph/subgraph_op.h b/src/operator/subgraph/subgraph_op.h new file mode 100644 index 000000000000..a6621509f5c3 --- /dev/null +++ b/src/operator/subgraph/subgraph_op.h @@ -0,0 +1,181 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#ifndef MXNET_OPERATOR_SUBGRAPH_SUBGRAPH_OP_H_ + +#include +#include + +namespace mxnet { + +namespace op { + +namespace sg { // sg stands for subgraph + +struct SimpleNode; +using SimpleNodePtr = std::shared_ptr; + +struct SimpleNode { + static SimpleNodePtr Create() { + return std::make_shared(); + } + SimpleNode() : label(-1), node(nullptr) {} + int label; + nnvm::Node* node; + // key is node ptr + // value is the index array standing for the entry indices + // in key->inputs that use this->node as input node + std::unordered_map> outputs; +}; + +} + +/* + * This provides criteria for selecting nodes in a subgraph. + * When a node is passed to this object, the selection criteria may be changed. + * We can also specify what links we should use when traversing the neighbor + * nodes. + */ +class SubgraphSelector { + public: + virtual ~SubgraphSelector() { + } + /* + * Given a set of nodes that have been selected so far for a subgraph, determine + * if the input node should be selected for a subgraph. + */ + virtual bool Select(const nnvm::Node &n) = 0; + virtual bool UseIncomingEdges() const = 0; + virtual bool UseOutgoingEdges() const = 0; +}; + +using SubgraphSelectorPtr = std::shared_ptr; + +/* + * This is the interface of the subgraph operator that executes the computation + * in the subgraph. + */ +class SubgraphOperator { +public: + SubgraphOperator(const nnvm::Symbol &sym) { + this->subgraph_sym_ = sym; + } + + virtual ~SubgraphOperator() { + } + + const nnvm::Symbol &GetSubgraph() const { + return subgraph_sym_; + } + + virtual void Forward(const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) = 0; + virtual void Backward(const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) = 0; +private: + nnvm::Symbol subgraph_sym_; +}; + +using SubgraphOperatorPtr = std::shared_ptr; + +/* + * This provides a set of properties for partitioning a graph into subgraphs, + * reconstructing a new graph from the subgraphs and creating a subgraph + * operator to execute the subgraph. + */ +class SubgraphProperty { + public: + // the criteria of selecting the subgraph nodes. + virtual SubgraphSelectorPtr CreateSubgraphSelector() const = 0; + // create an nnvm node for a given subgraph. Here users can customize how to + // execute the operators in the subgraph. + virtual nnvm::NodePtr CreateSubgraphNode(const nnvm::Symbol &s) const = 0; + // Create a subgraph operator for execution. + virtual SubgraphOperatorPtr CreateSubgraphOperator(const nnvm::Symbol &sym) const = 0; + // The type of the subgraph. + virtual std::string GetType() const = 0; +}; + +using SubgraphPropertyPtr = std::shared_ptr; + +void RegisterSubgraphProperty(SubgraphPropertyPtr property); + +/* + * This selects nodes for a subgraph that only contains operators + * in a given set and it visits nodes via both input and output links. + */ +class ContainOpSelector: public SubgraphSelector { + std::shared_ptr> op_names; + + public: + ContainOpSelector(std::shared_ptr> op_names) { + this->op_names = op_names; + } + + virtual bool UseIncomingEdges() const { + return true; + } + + virtual bool UseOutgoingEdges() const { + return true; + } + + virtual bool Select(const nnvm::Node &n) { + return !n.is_variable() && op_names->count(n.op()->name); + } +}; + +/* + * This subgraph property finds a subgraph whose nodes have only operators + * within a set. The operators in the subgraph will be executed by _subgraph_op. + */ +class SimpleSubgraphProperty: public SubgraphProperty { + public: + SimpleSubgraphProperty(const std::unordered_set &op_names) { + this->op_names = std::make_shared>(op_names); + } + virtual nnvm::NodePtr CreateSubgraphNode(const nnvm::Symbol &sym) const { + nnvm::NodePtr n = nnvm::Node::Create(); + n->attrs.op = Op::Get("_subgraph_op"); + n->attrs.name = "_subgraph_op"; + n->attrs.dict.insert(std::pair("exec_type", GetType())); + n->attrs.parsed = sym; + return n; + } + virtual SubgraphSelectorPtr CreateSubgraphSelector() const { + return std::make_shared(op_names); + } + + virtual SubgraphOperatorPtr CreateSubgraphOperator(const nnvm::Symbol &sym) const; + virtual std::string GetType() const { + return "default"; + } + + private: + std::shared_ptr> op_names; +}; + +} +} + +#endif // MXNET_OPERATOR_SUBGRAPH_SUBGRAPH_OP_H_