Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
fix static graph
Browse files Browse the repository at this point in the history
  • Loading branch information
winstywang committed Aug 6, 2015
1 parent d75d0ef commit bccfbeb
Show file tree
Hide file tree
Showing 3 changed files with 81 additions and 56 deletions.
54 changes: 13 additions & 41 deletions include/mxnet/static_graph.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,54 +12,26 @@
#include <memory>
#include "./atomic_symbol.h"
namespace mxnet {
struct NodeMetaInfo{
/*! \brief wrapped atomic symbol */
AtomicSymbol* sym_;
/*! \brief name of the node */
std::string name_;
};

/*!
* \brief Node is the container of AtomicSymbol, it also stores the connection of the AtomicSymbol
* with input symbols.
*/
struct Node {

NodeMetaInfo info_;
/*! \brief inputs to this node */
std::vector<std::shared_ptr<Node> > in_symbol_;
/*! \brief index of the inputs if the inputs are tuple */
std::vector<int> in_index_;
/*! \brief the output shape of the wrapped symbol */
std::vector<TShape> out_shape_;
/*!
* \brief constructor
*/
explicit Node(AtomicSymbol* sym = nullptr, const std::string& name = "") {
info_.sym_ = sym;
info_.name_ = name;
}
/*!
* \brief destructor
*/
~Node() {
if (info_.sym_) {
delete info_.sym_;
}
}
};


struct StaticGraph {
struct StaticNode {
/*! \brief wrapped atomic symbol */
AtomicSymbol* sym_;
/*! \brief name of the node */
std::string name_;
};
std::unordered_map<std::string, int> name_id_map;
std::vector<NodeMetaInfo> nodes;
std::vector<StaticNode> nodes;
std::vector<std::vector<int> > output_index;
std::vector<std::vector<int> > connected_graph;

int FindNodeByName(const std::string& name, const std::shared_ptr<Node> node) {
int FindNodeByName(const std::string& name, const AtomicSymbol* sym) {
int id = 0;
if (name_id_map.find(name) == name_id_map.end()) {
name_id_map[name] = name_id_map.size();
nodes.push_back(node->info_);
StaticNode static_node;
static_node.sym_ = sym->Copy();
static_node.name_ = name;
nodes.push_back(static_node);
output_index.push_back(std::vector<int>());
connected_graph.push_back(std::vector<int>());
id = name_id_map.size();
Expand Down
54 changes: 51 additions & 3 deletions include/mxnet/symbol.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
#include "./static_graph.h"

namespace mxnet {
class CompositeOperator;
/*!
* \brief Symbol is the wrapper of AtomicSymbol, the reason for this indirection is that Symbol
* should support expressions and often passed by value. While AtomicSymbol have many subclasses,
Expand All @@ -28,6 +29,37 @@ namespace mxnet {
* A atomic symbol can be seen as a special case of the composite symbol with only the head node.
*/
class Symbol {
public:
/*!
* \brief Node is the container of AtomicSymbol, it also stores the connection of the AtomicSymbol
* with input symbols.
*/
struct Node {
/*! \brief wrapped atomic symbol */
AtomicSymbol* sym_;
/*! \brief name of the node */
std::string name_;
/*! \brief inputs to this node */
std::vector<std::shared_ptr<Node> > in_symbol_;
/*! \brief index of the inputs if the inputs are tuple */
std::vector<int> in_index_;
/*! \brief the output shape of the wrapped symbol */
std::vector<TShape> out_shape_;
/*!
* \brief constructor
*/
explicit Node(AtomicSymbol* sym = nullptr, const std::string& name = "") :
sym_(sym), name_(name) {
}
/*!
* \brief destructor
*/
~Node() {
if (sym_) {
delete sym_;
}
}
};
protected:
/*! \brief the head node of the Symbol, it could be shared in many graphs */
std::shared_ptr<Node> head_;
Expand All @@ -37,7 +69,13 @@ class Symbol {
std::shared_ptr<std::vector<std::pair<Node*, int> > > arg_users_;
/*! \brief find arg users */
void FindArgUsers();
void dfs_(const std::shared_ptr<Node> node, StaticGraph& graph);
/**
* @brief Recursively parse the symbol to equivalent static graph.
*
* @param node The current node in dfs
* @param graph The static graph
*/
void Dfs(const std::shared_ptr<Node> node, StaticGraph& graph);
public:
/*!
* \brief declare virtual destructor in case it is subclassed.
Expand All @@ -48,7 +86,14 @@ class Symbol {
* \param ctx context of the operator
* \return returns the pointer to a created operator. It is on the user to delete.
*/
virtual Operator* Bind(Context ctx) const { return nullptr; }
virtual CompositeOperator* Bind(Context ctx) const { return nullptr; }
/**
* @brief Bind the symbol to a composite operator
*
* @param in A map denotes name and corresponding NArray for binding
* @return The composite operator
*/
virtual CompositeOperator* Bind(Context ctx, const std::unordered_map<std::string, NArray>& in);
/*!
* \brief copy the symbol
* \return a deep copy of the graph
Expand All @@ -75,7 +120,10 @@ class Symbol {
* \return the arguments list of this symbol, they can be either named or unnamed (empty string).
*/
virtual std::vector<std::string> ListArgs();

/**
* @brief Convert current symbol to its equivalent static graph representation.
* @return the static graph
*/
virtual StaticGraph ToStaticGraph();
/*!
* \brief create Symbol by wrapping AtomicSymbol
Expand Down
29 changes: 17 additions & 12 deletions src/symbol/symbol.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ void Symbol::FindArgUsers() {
stk.pop_back();
} else {
Node* next_level = back.first->in_symbol_[back.second].get();
if (next_level->info_.sym_) {
if (next_level->sym_) {
stk.push_back({next_level, 0});
} else { // back uses next_level which is a placeholder
arg_users_->push_back({back.first, back.second});
Expand All @@ -42,10 +42,10 @@ Symbol Symbol::Copy() const {
Node* back = stk.back();
stk.pop_back();
if (old_new.count(back) == 0) {
if (back->info_.sym_) {
old_new[back] = std::make_shared<Node>(back->info_.sym_->Copy(), back->info_.name_);
if (back->sym_) {
old_new[back] = std::make_shared<Node>(back->sym_->Copy(), back->name_);
} else {
old_new[back] = std::make_shared<Node>(nullptr, back->info_.name_);
old_new[back] = std::make_shared<Node>(nullptr, back->name_);
}
}
for (const std::shared_ptr<Node>& n : back->in_symbol_) {
Expand Down Expand Up @@ -98,7 +98,7 @@ Symbol Symbol::operator () (const std::unordered_map<std::string, Symbol>& kwarg
<< s.arg_users_->size() << " provided " << kwargs.size();
for (size_t i = 0; i < s.arg_users_->size(); ++i) {
const std::pair<Node*, int>& arg_user = (*s.arg_users_)[i];
const std::string& name = arg_user.first->info_.name_;
const std::string& name = arg_user.first->name_;
if (!(name == "") && kwargs.count(name) != 0) {
const Symbol& bind = kwargs.at(name);
arg_user.first->in_symbol_[arg_user.second] = bind.head_;
Expand All @@ -125,7 +125,7 @@ std::vector<std::string> Symbol::ListArgs() {
}
std::transform(arg_users_->begin(), arg_users_->end(), std::back_inserter(ret),
[&](const std::pair<Node*, int>& n) -> std::string {
return n.first->in_symbol_[n.second]->info_.name_;
return n.first->in_symbol_[n.second]->name_;
});
return ret;
}
Expand Down Expand Up @@ -162,20 +162,25 @@ Symbol Symbol::Create(const std::string& type_name,

StaticGraph Symbol::ToStaticGraph() {
StaticGraph graph;
dfs_(this->head_, graph);
Dfs(this->head_, graph);
return graph;
}

CompositeOperator* Symbol::Bind(Context ctx, const std::unordered_map<std::string, NArray>& in) {
StaticGraph graph = this->ToStaticGraph();
return NULL;
//TODO: pass the graph and in to initlialize a composite op.
}

void Symbol::dfs_(const std::shared_ptr<Node> node, StaticGraph& graph) {
int id = graph.FindNodeByName(node->info_.name_, node);
void Symbol::Dfs(const std::shared_ptr<Node> node, StaticGraph& graph) {
int id = graph.FindNodeByName(node->name_, node->sym_);
for (size_t i = 0; i < node->in_symbol_.size(); ++i) {
std::shared_ptr<Node> parent = node->in_symbol_[i];
int parent_id = graph.FindNodeByName(parent->info_.name_, node);
int parent_id = graph.FindNodeByName(parent->name_, parent->sym_);
graph.connected_graph[parent_id].push_back(id);
graph.output_index[parent_id].push_back(node->in_index_[i]);
if (parent->info_.sym_) {
dfs_(parent, graph);
if (parent->sym_) {
Dfs(parent, graph);
}
}
}
Expand Down

0 comments on commit bccfbeb

Please sign in to comment.