diff --git a/include/mxnet/static_graph.h b/include/mxnet/static_graph.h index 86db91546d32..747090695b99 100644 --- a/include/mxnet/static_graph.h +++ b/include/mxnet/static_graph.h @@ -12,54 +12,26 @@ #include #include "./atomic_symbol.h" namespace mxnet { - struct NodeMetaInfo{ - /*! \brief wrapped atomic symbol */ - AtomicSymbol* sym_; - /*! \brief name of the node */ - std::string name_; - }; - - /*! - * \brief Node is the container of AtomicSymbol, it also stores the connection of the AtomicSymbol - * with input symbols. - */ - struct Node { - - NodeMetaInfo info_; - /*! \brief inputs to this node */ - std::vector > in_symbol_; - /*! \brief index of the inputs if the inputs are tuple */ - std::vector in_index_; - /*! \brief the output shape of the wrapped symbol */ - std::vector out_shape_; - /*! - * \brief constructor - */ - explicit Node(AtomicSymbol* sym = nullptr, const std::string& name = "") { - info_.sym_ = sym; - info_.name_ = name; - } - /*! - * \brief destructor - */ - ~Node() { - if (info_.sym_) { - delete info_.sym_; - } - } - }; - + struct StaticGraph { + struct StaticNode { + /*! \brief wrapped atomic symbol */ + AtomicSymbol* sym_; + /*! \brief name of the node */ + std::string name_; + }; std::unordered_map name_id_map; - std::vector nodes; + std::vector nodes; std::vector > output_index; std::vector > connected_graph; - - int FindNodeByName(const std::string& name, const std::shared_ptr node) { + int FindNodeByName(const std::string& name, const AtomicSymbol* sym) { int id = 0; if (name_id_map.find(name) == name_id_map.end()) { name_id_map[name] = name_id_map.size(); - nodes.push_back(node->info_); + StaticNode static_node; + static_node.sym_ = sym->Copy(); + static_node.name_ = name; + nodes.push_back(static_node); output_index.push_back(std::vector()); connected_graph.push_back(std::vector()); id = name_id_map.size(); diff --git a/include/mxnet/symbol.h b/include/mxnet/symbol.h index 71c73f54a4c4..d72410809731 100644 --- a/include/mxnet/symbol.h +++ b/include/mxnet/symbol.h @@ -19,6 +19,7 @@ #include "./static_graph.h" namespace mxnet { +class CompositeOperator; /*! * \brief Symbol is the wrapper of AtomicSymbol, the reason for this indirection is that Symbol * should support expressions and often passed by value. While AtomicSymbol have many subclasses, @@ -28,6 +29,37 @@ namespace mxnet { * A atomic symbol can be seen as a special case of the composite symbol with only the head node. */ class Symbol { + public: + /*! + * \brief Node is the container of AtomicSymbol, it also stores the connection of the AtomicSymbol + * with input symbols. + */ + struct Node { + /*! \brief wrapped atomic symbol */ + AtomicSymbol* sym_; + /*! \brief name of the node */ + std::string name_; + /*! \brief inputs to this node */ + std::vector > in_symbol_; + /*! \brief index of the inputs if the inputs are tuple */ + std::vector in_index_; + /*! \brief the output shape of the wrapped symbol */ + std::vector out_shape_; + /*! + * \brief constructor + */ + explicit Node(AtomicSymbol* sym = nullptr, const std::string& name = "") : + sym_(sym), name_(name) { + } + /*! + * \brief destructor + */ + ~Node() { + if (sym_) { + delete sym_; + } + } + }; protected: /*! \brief the head node of the Symbol, it could be shared in many graphs */ std::shared_ptr head_; @@ -37,7 +69,13 @@ class Symbol { std::shared_ptr > > arg_users_; /*! \brief find arg users */ void FindArgUsers(); - void dfs_(const std::shared_ptr node, StaticGraph& graph); + /** + * @brief Recursively parse the symbol to equivalent static graph. + * + * @param node The current node in dfs + * @param graph The static graph + */ + void Dfs(const std::shared_ptr node, StaticGraph& graph); public: /*! * \brief declare virtual destructor in case it is subclassed. @@ -48,7 +86,14 @@ class Symbol { * \param ctx context of the operator * \return returns the pointer to a created operator. It is on the user to delete. */ - virtual Operator* Bind(Context ctx) const { return nullptr; } + virtual CompositeOperator* Bind(Context ctx) const { return nullptr; } + /** + * @brief Bind the symbol to a composite operator + * + * @param in A map denotes name and corresponding NArray for binding + * @return The composite operator + */ + virtual CompositeOperator* Bind(Context ctx, const std::unordered_map& in); /*! * \brief copy the symbol * \return a deep copy of the graph @@ -75,7 +120,10 @@ class Symbol { * \return the arguments list of this symbol, they can be either named or unnamed (empty string). */ virtual std::vector ListArgs(); - + /** + * @brief Convert current symbol to its equivalent static graph representation. + * @return the static graph + */ virtual StaticGraph ToStaticGraph(); /*! * \brief create Symbol by wrapping AtomicSymbol diff --git a/src/symbol/symbol.cc b/src/symbol/symbol.cc index 7c35804e7f38..2506b49af65f 100644 --- a/src/symbol/symbol.cc +++ b/src/symbol/symbol.cc @@ -22,7 +22,7 @@ void Symbol::FindArgUsers() { stk.pop_back(); } else { Node* next_level = back.first->in_symbol_[back.second].get(); - if (next_level->info_.sym_) { + if (next_level->sym_) { stk.push_back({next_level, 0}); } else { // back uses next_level which is a placeholder arg_users_->push_back({back.first, back.second}); @@ -42,10 +42,10 @@ Symbol Symbol::Copy() const { Node* back = stk.back(); stk.pop_back(); if (old_new.count(back) == 0) { - if (back->info_.sym_) { - old_new[back] = std::make_shared(back->info_.sym_->Copy(), back->info_.name_); + if (back->sym_) { + old_new[back] = std::make_shared(back->sym_->Copy(), back->name_); } else { - old_new[back] = std::make_shared(nullptr, back->info_.name_); + old_new[back] = std::make_shared(nullptr, back->name_); } } for (const std::shared_ptr& n : back->in_symbol_) { @@ -98,7 +98,7 @@ Symbol Symbol::operator () (const std::unordered_map& kwarg << s.arg_users_->size() << " provided " << kwargs.size(); for (size_t i = 0; i < s.arg_users_->size(); ++i) { const std::pair& arg_user = (*s.arg_users_)[i]; - const std::string& name = arg_user.first->info_.name_; + const std::string& name = arg_user.first->name_; if (!(name == "") && kwargs.count(name) != 0) { const Symbol& bind = kwargs.at(name); arg_user.first->in_symbol_[arg_user.second] = bind.head_; @@ -125,7 +125,7 @@ std::vector Symbol::ListArgs() { } std::transform(arg_users_->begin(), arg_users_->end(), std::back_inserter(ret), [&](const std::pair& n) -> std::string { - return n.first->in_symbol_[n.second]->info_.name_; + return n.first->in_symbol_[n.second]->name_; }); return ret; } @@ -162,20 +162,25 @@ Symbol Symbol::Create(const std::string& type_name, StaticGraph Symbol::ToStaticGraph() { StaticGraph graph; - dfs_(this->head_, graph); + Dfs(this->head_, graph); return graph; } +CompositeOperator* Symbol::Bind(Context ctx, const std::unordered_map& in) { + StaticGraph graph = this->ToStaticGraph(); + return NULL; + //TODO: pass the graph and in to initlialize a composite op. +} -void Symbol::dfs_(const std::shared_ptr node, StaticGraph& graph) { - int id = graph.FindNodeByName(node->info_.name_, node); +void Symbol::Dfs(const std::shared_ptr node, StaticGraph& graph) { + int id = graph.FindNodeByName(node->name_, node->sym_); for (size_t i = 0; i < node->in_symbol_.size(); ++i) { std::shared_ptr parent = node->in_symbol_[i]; - int parent_id = graph.FindNodeByName(parent->info_.name_, node); + int parent_id = graph.FindNodeByName(parent->name_, parent->sym_); graph.connected_graph[parent_id].push_back(id); graph.output_index[parent_id].push_back(node->in_index_[i]); - if (parent->info_.sym_) { - dfs_(parent, graph); + if (parent->sym_) { + Dfs(parent, graph); } } }