Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
fix travis
Browse files Browse the repository at this point in the history
  • Loading branch information
antinucleon committed Aug 6, 2015
1 parent 2c9091d commit fead856
Show file tree
Hide file tree
Showing 4 changed files with 67 additions and 43 deletions.
8 changes: 8 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,16 @@
*.out
*.app
*~

# doc
doc/html
doc/latex
doc/doc

#dmlc
dmlc-core
mshadow
rabit
config.mk

*.pyc
Expand Down
74 changes: 44 additions & 30 deletions include/mxnet/static_graph.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,34 +12,48 @@
#include <memory>
#include "./atomic_symbol.h"
namespace mxnet {

struct StaticGraph {
struct StaticNode {
/*! \brief wrapped atomic symbol */
AtomicSymbol* sym_;
/*! \brief name of the node */
std::string name_;
};
std::unordered_map<std::string, int> name_id_map;
std::vector<StaticNode> nodes;
std::vector<std::vector<int> > output_index;
std::vector<std::vector<int> > connected_graph;
int FindNodeByName(const std::string& name, const AtomicSymbol* sym) {
int id = 0;
if (name_id_map.find(name) == name_id_map.end()) {
name_id_map[name] = name_id_map.size();
StaticNode static_node;
static_node.sym_ = sym->Copy();
static_node.name_ = name;
nodes.push_back(static_node);
output_index.push_back(std::vector<int>());
connected_graph.push_back(std::vector<int>());
id = name_id_map.size();
} else {
id = name_id_map[name];
}
return id;
}
/*! \brief static graph interface
* static graph is an internal representation of symbol graph.
*
* The main purpose for static graph for binding a composite operator
*/
struct StaticGraph {
/*! \brief Node in static graph */
struct StaticNode {
/*! \brief wrapped atomic symbol */
AtomicSymbol* sym_;
/*! \brief name of the node */
std::string name_;
};
}
#endif
/*! \brief node name to id dictionary */
std::unordered_map<std::string, int> name_id_map;
/*! \brief all nodes in the graph */
std::vector<StaticNode> nodes;
/*! \brief output id for each node */
std::vector<std::vector<int> > output_index;
/*! \brief connected graph for each node */
std::vector<std::vector<int> > connected_graph;
/*! \brief find node by using name
* \param name node name
* \param sym symbol need to be copied into node
* \return node id
*/
int FindNodeByName(const std::string& name, const AtomicSymbol* sym) {
int id = 0;
if (name_id_map.find(name) == name_id_map.end()) {
name_id_map[name] = name_id_map.size();
StaticNode static_node;
static_node.sym_ = sym->Copy();
static_node.name_ = name;
nodes.push_back(static_node);
output_index.push_back(std::vector<int>());
connected_graph.push_back(std::vector<int>());
id = name_id_map.size();
} else {
id = name_id_map[name];
}
return id;
}
};
} // namespace mxnet
#endif // MXNET_STATIC_GRAPH_H_
14 changes: 8 additions & 6 deletions include/mxnet/symbol.h
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ class Symbol {
}
}
};

protected:
/*! \brief the head node of the Symbol, it could be shared in many graphs */
std::shared_ptr<Node> head_;
Expand All @@ -71,11 +72,12 @@ class Symbol {
void FindArgUsers();
/**
* @brief Recursively parse the symbol to equivalent static graph.
*
*
* @param node The current node in dfs
* @param graph The static graph
*/
void Dfs(const std::shared_ptr<Node> node, StaticGraph& graph);
void Dfs(const std::shared_ptr<Node> node, StaticGraph *graph);

public:
/*!
* \brief declare virtual destructor in case it is subclassed.
Expand All @@ -88,10 +90,10 @@ class Symbol {
*/
virtual CompositeOperator* Bind(Context ctx) const { return nullptr; }
/**
* @brief Bind the symbol to a composite operator
*
* @param in A map denotes name and corresponding NArray for binding
* @return The composite operator
* \brief Bind the symbol to a composite operator
* \param ctx context of the operator
* \param in A map denotes name and corresponding NArray for binding
* \return The composite operator
*/
virtual CompositeOperator* Bind(Context ctx, const std::unordered_map<std::string, NArray>& in);
/*!
Expand Down
14 changes: 7 additions & 7 deletions src/symbol/symbol.cc
Original file line number Diff line number Diff line change
Expand Up @@ -162,23 +162,23 @@ Symbol Symbol::Create(const std::string& type_name,

StaticGraph Symbol::ToStaticGraph() {
StaticGraph graph;
Dfs(this->head_, graph);
Dfs(this->head_, &graph);
return graph;
}

CompositeOperator* Symbol::Bind(Context ctx, const std::unordered_map<std::string, NArray>& in) {
StaticGraph graph = this->ToStaticGraph();
return NULL;
//TODO: pass the graph and in to initlialize a composite op.
// TODO(bing): pass the graph and in to initlialize a composite op.
}

void Symbol::Dfs(const std::shared_ptr<Node> node, StaticGraph& graph) {
int id = graph.FindNodeByName(node->name_, node->sym_);
void Symbol::Dfs(const std::shared_ptr<Node> node, StaticGraph *graph) {
int id = graph->FindNodeByName(node->name_, node->sym_);
for (size_t i = 0; i < node->in_symbol_.size(); ++i) {
std::shared_ptr<Node> parent = node->in_symbol_[i];
int parent_id = graph.FindNodeByName(parent->name_, parent->sym_);
graph.connected_graph[parent_id].push_back(id);
graph.output_index[parent_id].push_back(node->in_index_[i]);
int parent_id = graph->FindNodeByName(parent->name_, parent->sym_);
graph->connected_graph[parent_id].push_back(id);
graph->output_index[parent_id].push_back(node->in_index_[i]);
if (parent->sym_) {
Dfs(parent, graph);
}
Expand Down

0 comments on commit fead856

Please sign in to comment.