From fead856f5ed0925c687efe9b827a636f190d057f Mon Sep 17 00:00:00 2001 From: Bing Xu Date: Thu, 6 Aug 2015 12:05:43 -0600 Subject: [PATCH] fix travis --- .gitignore | 8 ++++ include/mxnet/static_graph.h | 74 +++++++++++++++++++++--------------- include/mxnet/symbol.h | 14 ++++--- src/symbol/symbol.cc | 14 +++---- 4 files changed, 67 insertions(+), 43 deletions(-) diff --git a/.gitignore b/.gitignore index fe1264c2b747..1c1cd99ab7a4 100644 --- a/.gitignore +++ b/.gitignore @@ -27,8 +27,16 @@ *.out *.app *~ + +# doc +doc/html +doc/latex +doc/doc + +#dmlc dmlc-core mshadow +rabit config.mk *.pyc diff --git a/include/mxnet/static_graph.h b/include/mxnet/static_graph.h index 747090695b99..e06cbf4f9c42 100644 --- a/include/mxnet/static_graph.h +++ b/include/mxnet/static_graph.h @@ -12,34 +12,48 @@ #include #include "./atomic_symbol.h" namespace mxnet { - - struct StaticGraph { - struct StaticNode { - /*! \brief wrapped atomic symbol */ - AtomicSymbol* sym_; - /*! \brief name of the node */ - std::string name_; - }; - std::unordered_map name_id_map; - std::vector nodes; - std::vector > output_index; - std::vector > connected_graph; - int FindNodeByName(const std::string& name, const AtomicSymbol* sym) { - int id = 0; - if (name_id_map.find(name) == name_id_map.end()) { - name_id_map[name] = name_id_map.size(); - StaticNode static_node; - static_node.sym_ = sym->Copy(); - static_node.name_ = name; - nodes.push_back(static_node); - output_index.push_back(std::vector()); - connected_graph.push_back(std::vector()); - id = name_id_map.size(); - } else { - id = name_id_map[name]; - } - return id; - } +/*! \brief static graph interface + * static graph is an internal representation of symbol graph. + * + * The main purpose for static graph for binding a composite operator + */ +struct StaticGraph { + /*! \brief Node in static graph */ + struct StaticNode { + /*! \brief wrapped atomic symbol */ + AtomicSymbol* sym_; + /*! \brief name of the node */ + std::string name_; }; -} -#endif + /*! \brief node name to id dictionary */ + std::unordered_map name_id_map; + /*! \brief all nodes in the graph */ + std::vector nodes; + /*! \brief output id for each node */ + std::vector > output_index; + /*! \brief connected graph for each node */ + std::vector > connected_graph; + /*! \brief find node by using name + * \param name node name + * \param sym symbol need to be copied into node + * \return node id + */ + int FindNodeByName(const std::string& name, const AtomicSymbol* sym) { + int id = 0; + if (name_id_map.find(name) == name_id_map.end()) { + name_id_map[name] = name_id_map.size(); + StaticNode static_node; + static_node.sym_ = sym->Copy(); + static_node.name_ = name; + nodes.push_back(static_node); + output_index.push_back(std::vector()); + connected_graph.push_back(std::vector()); + id = name_id_map.size(); + } else { + id = name_id_map[name]; + } + return id; + } +}; +} // namespace mxnet +#endif // MXNET_STATIC_GRAPH_H_ diff --git a/include/mxnet/symbol.h b/include/mxnet/symbol.h index d72410809731..df1e78438560 100644 --- a/include/mxnet/symbol.h +++ b/include/mxnet/symbol.h @@ -60,6 +60,7 @@ class Symbol { } } }; + protected: /*! \brief the head node of the Symbol, it could be shared in many graphs */ std::shared_ptr head_; @@ -71,11 +72,12 @@ class Symbol { void FindArgUsers(); /** * @brief Recursively parse the symbol to equivalent static graph. - * + * * @param node The current node in dfs * @param graph The static graph */ - void Dfs(const std::shared_ptr node, StaticGraph& graph); + void Dfs(const std::shared_ptr node, StaticGraph *graph); + public: /*! * \brief declare virtual destructor in case it is subclassed. @@ -88,10 +90,10 @@ class Symbol { */ virtual CompositeOperator* Bind(Context ctx) const { return nullptr; } /** - * @brief Bind the symbol to a composite operator - * - * @param in A map denotes name and corresponding NArray for binding - * @return The composite operator + * \brief Bind the symbol to a composite operator + * \param ctx context of the operator + * \param in A map denotes name and corresponding NArray for binding + * \return The composite operator */ virtual CompositeOperator* Bind(Context ctx, const std::unordered_map& in); /*! diff --git a/src/symbol/symbol.cc b/src/symbol/symbol.cc index 2506b49af65f..a81b7ce0cccd 100644 --- a/src/symbol/symbol.cc +++ b/src/symbol/symbol.cc @@ -162,23 +162,23 @@ Symbol Symbol::Create(const std::string& type_name, StaticGraph Symbol::ToStaticGraph() { StaticGraph graph; - Dfs(this->head_, graph); + Dfs(this->head_, &graph); return graph; } CompositeOperator* Symbol::Bind(Context ctx, const std::unordered_map& in) { StaticGraph graph = this->ToStaticGraph(); return NULL; - //TODO: pass the graph and in to initlialize a composite op. + // TODO(bing): pass the graph and in to initlialize a composite op. } -void Symbol::Dfs(const std::shared_ptr node, StaticGraph& graph) { - int id = graph.FindNodeByName(node->name_, node->sym_); +void Symbol::Dfs(const std::shared_ptr node, StaticGraph *graph) { + int id = graph->FindNodeByName(node->name_, node->sym_); for (size_t i = 0; i < node->in_symbol_.size(); ++i) { std::shared_ptr parent = node->in_symbol_[i]; - int parent_id = graph.FindNodeByName(parent->name_, parent->sym_); - graph.connected_graph[parent_id].push_back(id); - graph.output_index[parent_id].push_back(node->in_index_[i]); + int parent_id = graph->FindNodeByName(parent->name_, parent->sym_); + graph->connected_graph[parent_id].push_back(id); + graph->output_index[parent_id].push_back(node->in_index_[i]); if (parent->sym_) { Dfs(parent, graph); }