diff --git a/include/mxnet/static_graph.h b/include/mxnet/static_graph.h new file mode 100644 index 000000000000..747090695b99 --- /dev/null +++ b/include/mxnet/static_graph.h @@ -0,0 +1,45 @@ +/*! + * Copyright (c) 2015 by Contributors + * \file static_graph.h + * \brief the static graph of symbols + */ +#ifndef MXNET_STATIC_GRAPH_H_ +#define MXNET_STATIC_GRAPH_H_ + +#include +#include +#include +#include +#include "./atomic_symbol.h" +namespace mxnet { + + struct StaticGraph { + struct StaticNode { + /*! \brief wrapped atomic symbol */ + AtomicSymbol* sym_; + /*! \brief name of the node */ + std::string name_; + }; + std::unordered_map name_id_map; + std::vector nodes; + std::vector > output_index; + std::vector > connected_graph; + int FindNodeByName(const std::string& name, const AtomicSymbol* sym) { + int id = 0; + if (name_id_map.find(name) == name_id_map.end()) { + name_id_map[name] = name_id_map.size(); + StaticNode static_node; + static_node.sym_ = sym->Copy(); + static_node.name_ = name; + nodes.push_back(static_node); + output_index.push_back(std::vector()); + connected_graph.push_back(std::vector()); + id = name_id_map.size(); + } else { + id = name_id_map[name]; + } + return id; + } + }; +} +#endif diff --git a/include/mxnet/symbol.h b/include/mxnet/symbol.h index 0b69005f7a16..d72410809731 100644 --- a/include/mxnet/symbol.h +++ b/include/mxnet/symbol.h @@ -16,8 +16,10 @@ #include "./base.h" #include "./tensor_blob.h" #include "./operator.h" +#include "./static_graph.h" namespace mxnet { +class CompositeOperator; /*! * \brief Symbol is the wrapper of AtomicSymbol, the reason for this indirection is that Symbol * should support expressions and often passed by value. While AtomicSymbol have many subclasses, @@ -27,7 +29,7 @@ namespace mxnet { * A atomic symbol can be seen as a special case of the composite symbol with only the head node. */ class Symbol { - protected: + public: /*! * \brief Node is the container of AtomicSymbol, it also stores the connection of the AtomicSymbol * with input symbols. @@ -46,12 +48,19 @@ class Symbol { /*! * \brief constructor */ - explicit Node(AtomicSymbol* sym = nullptr, const std::string& name = ""); + explicit Node(AtomicSymbol* sym = nullptr, const std::string& name = "") : + sym_(sym), name_(name) { + } /*! * \brief destructor */ - ~Node(); + ~Node() { + if (sym_) { + delete sym_; + } + } }; + protected: /*! \brief the head node of the Symbol, it could be shared in many graphs */ std::shared_ptr head_; /*! \brief if the head has multiple return values, index is used to specify */ @@ -60,7 +69,13 @@ class Symbol { std::shared_ptr > > arg_users_; /*! \brief find arg users */ void FindArgUsers(); - + /** + * @brief Recursively parse the symbol to equivalent static graph. + * + * @param node The current node in dfs + * @param graph The static graph + */ + void Dfs(const std::shared_ptr node, StaticGraph& graph); public: /*! * \brief declare virtual destructor in case it is subclassed. @@ -71,7 +86,14 @@ class Symbol { * \param ctx context of the operator * \return returns the pointer to a created operator. It is on the user to delete. */ - virtual Operator* Bind(Context ctx) const { return nullptr; } + virtual CompositeOperator* Bind(Context ctx) const { return nullptr; } + /** + * @brief Bind the symbol to a composite operator + * + * @param in A map denotes name and corresponding NArray for binding + * @return The composite operator + */ + virtual CompositeOperator* Bind(Context ctx, const std::unordered_map& in); /*! * \brief copy the symbol * \return a deep copy of the graph @@ -98,6 +120,11 @@ class Symbol { * \return the arguments list of this symbol, they can be either named or unnamed (empty string). */ virtual std::vector ListArgs(); + /** + * @brief Convert current symbol to its equivalent static graph representation. + * @return the static graph + */ + virtual StaticGraph ToStaticGraph(); /*! * \brief create Symbol by wrapping AtomicSymbol */ diff --git a/src/symbol/symbol.cc b/src/symbol/symbol.cc index a4d966ba422f..2506b49af65f 100644 --- a/src/symbol/symbol.cc +++ b/src/symbol/symbol.cc @@ -6,20 +6,11 @@ #include #include #include +#include #include namespace mxnet { -Symbol::Node::Node(AtomicSymbol* sym, const std::string& name) - : sym_(sym), name_(name) { -} - -Symbol::Node::~Node() { - if (sym_) { - delete sym_; - } -} - void Symbol::FindArgUsers() { arg_users_.reset(new std::vector >); // depth first traversing @@ -144,14 +135,14 @@ Symbol Symbol::Create(AtomicSymbol *atomic_symbol) { std::vector args = atomic_symbol->DescribeArguments(); std::vector rets = atomic_symbol->DescribeReturns(); // set head_ - s.head_ = std::make_shared(atomic_symbol, ""); + s.head_ = std::make_shared(atomic_symbol, ""); // set index_ s.index_ = rets.size() > 1 ? -1 : 0; // set head_->in_index_ s.head_->in_index_ = std::vector(args.size(), 0); // set head_->in_symbol_ for (auto name : args) { - s.head_->in_symbol_.push_back(std::make_shared(nullptr, name)); + s.head_->in_symbol_.push_back(std::make_shared(nullptr, name)); } // set head_->out_shape_ s.head_->out_shape_ = std::vector(rets.size()); @@ -169,4 +160,29 @@ Symbol Symbol::Create(const std::string& type_name, return Create(atomic_symbol); } +StaticGraph Symbol::ToStaticGraph() { + StaticGraph graph; + Dfs(this->head_, graph); + return graph; +} + +CompositeOperator* Symbol::Bind(Context ctx, const std::unordered_map& in) { + StaticGraph graph = this->ToStaticGraph(); + return NULL; + //TODO: pass the graph and in to initlialize a composite op. +} + +void Symbol::Dfs(const std::shared_ptr node, StaticGraph& graph) { + int id = graph.FindNodeByName(node->name_, node->sym_); + for (size_t i = 0; i < node->in_symbol_.size(); ++i) { + std::shared_ptr parent = node->in_symbol_[i]; + int parent_id = graph.FindNodeByName(parent->name_, parent->sym_); + graph.connected_graph[parent_id].push_back(id); + graph.output_index[parent_id].push_back(node->in_index_[i]); + if (parent->sym_) { + Dfs(parent, graph); + } + } +} + } // namespace mxnet