Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Merge pull request #8 from winstywang/master
Browse files Browse the repository at this point in the history
static graph
  • Loading branch information
winstywang committed Aug 6, 2015
2 parents 22ebc27 + bccfbeb commit 2c9091d
Show file tree
Hide file tree
Showing 3 changed files with 105 additions and 17 deletions.
45 changes: 45 additions & 0 deletions include/mxnet/static_graph.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
/*!
* Copyright (c) 2015 by Contributors
* \file static_graph.h
* \brief the static graph of symbols
*/
#ifndef MXNET_STATIC_GRAPH_H_
#define MXNET_STATIC_GRAPH_H_

#include <vector>
#include <unordered_map>
#include <string>
#include <memory>
#include "./atomic_symbol.h"
namespace mxnet {

struct StaticGraph {
struct StaticNode {
/*! \brief wrapped atomic symbol */
AtomicSymbol* sym_;
/*! \brief name of the node */
std::string name_;
};
std::unordered_map<std::string, int> name_id_map;
std::vector<StaticNode> nodes;
std::vector<std::vector<int> > output_index;
std::vector<std::vector<int> > connected_graph;
int FindNodeByName(const std::string& name, const AtomicSymbol* sym) {
int id = 0;
if (name_id_map.find(name) == name_id_map.end()) {
name_id_map[name] = name_id_map.size();
StaticNode static_node;
static_node.sym_ = sym->Copy();
static_node.name_ = name;
nodes.push_back(static_node);
output_index.push_back(std::vector<int>());
connected_graph.push_back(std::vector<int>());
id = name_id_map.size();
} else {
id = name_id_map[name];
}
return id;
}
};
}
#endif
37 changes: 32 additions & 5 deletions include/mxnet/symbol.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,10 @@
#include "./base.h"
#include "./tensor_blob.h"
#include "./operator.h"
#include "./static_graph.h"

namespace mxnet {
class CompositeOperator;
/*!
* \brief Symbol is the wrapper of AtomicSymbol, the reason for this indirection is that Symbol
* should support expressions and often passed by value. While AtomicSymbol have many subclasses,
Expand All @@ -27,7 +29,7 @@ namespace mxnet {
* A atomic symbol can be seen as a special case of the composite symbol with only the head node.
*/
class Symbol {
protected:
public:
/*!
* \brief Node is the container of AtomicSymbol, it also stores the connection of the AtomicSymbol
* with input symbols.
Expand All @@ -46,12 +48,19 @@ class Symbol {
/*!
* \brief constructor
*/
explicit Node(AtomicSymbol* sym = nullptr, const std::string& name = "");
explicit Node(AtomicSymbol* sym = nullptr, const std::string& name = "") :
sym_(sym), name_(name) {
}
/*!
* \brief destructor
*/
~Node();
~Node() {
if (sym_) {
delete sym_;
}
}
};
protected:
/*! \brief the head node of the Symbol, it could be shared in many graphs */
std::shared_ptr<Node> head_;
/*! \brief if the head has multiple return values, index is used to specify */
Expand All @@ -60,7 +69,13 @@ class Symbol {
std::shared_ptr<std::vector<std::pair<Node*, int> > > arg_users_;
/*! \brief find arg users */
void FindArgUsers();

/**
* @brief Recursively parse the symbol to equivalent static graph.
*
* @param node The current node in dfs
* @param graph The static graph
*/
void Dfs(const std::shared_ptr<Node> node, StaticGraph& graph);
public:
/*!
* \brief declare virtual destructor in case it is subclassed.
Expand All @@ -71,7 +86,14 @@ class Symbol {
* \param ctx context of the operator
* \return returns the pointer to a created operator. It is on the user to delete.
*/
virtual Operator* Bind(Context ctx) const { return nullptr; }
virtual CompositeOperator* Bind(Context ctx) const { return nullptr; }
/**
* @brief Bind the symbol to a composite operator
*
* @param in A map denotes name and corresponding NArray for binding
* @return The composite operator
*/
virtual CompositeOperator* Bind(Context ctx, const std::unordered_map<std::string, NArray>& in);
/*!
* \brief copy the symbol
* \return a deep copy of the graph
Expand All @@ -98,6 +120,11 @@ class Symbol {
* \return the arguments list of this symbol, they can be either named or unnamed (empty string).
*/
virtual std::vector<std::string> ListArgs();
/**
* @brief Convert current symbol to its equivalent static graph representation.
* @return the static graph
*/
virtual StaticGraph ToStaticGraph();
/*!
* \brief create Symbol by wrapping AtomicSymbol
*/
Expand Down
40 changes: 28 additions & 12 deletions src/symbol/symbol.cc
Original file line number Diff line number Diff line change
Expand Up @@ -6,20 +6,11 @@
#include <dmlc/logging.h>
#include <mxnet/symbol.h>
#include <mxnet/registry.h>
#include <mxnet/static_graph.h>
#include <iterator>

namespace mxnet {

Symbol::Node::Node(AtomicSymbol* sym, const std::string& name)
: sym_(sym), name_(name) {
}

Symbol::Node::~Node() {
if (sym_) {
delete sym_;
}
}

void Symbol::FindArgUsers() {
arg_users_.reset(new std::vector<std::pair<Node*, int> >);
// depth first traversing
Expand Down Expand Up @@ -144,14 +135,14 @@ Symbol Symbol::Create(AtomicSymbol *atomic_symbol) {
std::vector<std::string> args = atomic_symbol->DescribeArguments();
std::vector<std::string> rets = atomic_symbol->DescribeReturns();
// set head_
s.head_ = std::make_shared<Symbol::Node>(atomic_symbol, "");
s.head_ = std::make_shared<Node>(atomic_symbol, "");
// set index_
s.index_ = rets.size() > 1 ? -1 : 0;
// set head_->in_index_
s.head_->in_index_ = std::vector<int>(args.size(), 0);
// set head_->in_symbol_
for (auto name : args) {
s.head_->in_symbol_.push_back(std::make_shared<Symbol::Node>(nullptr, name));
s.head_->in_symbol_.push_back(std::make_shared<Node>(nullptr, name));
}
// set head_->out_shape_
s.head_->out_shape_ = std::vector<TShape>(rets.size());
Expand All @@ -169,4 +160,29 @@ Symbol Symbol::Create(const std::string& type_name,
return Create(atomic_symbol);
}

StaticGraph Symbol::ToStaticGraph() {
StaticGraph graph;
Dfs(this->head_, graph);
return graph;
}

CompositeOperator* Symbol::Bind(Context ctx, const std::unordered_map<std::string, NArray>& in) {
StaticGraph graph = this->ToStaticGraph();
return NULL;
//TODO: pass the graph and in to initlialize a composite op.
}

void Symbol::Dfs(const std::shared_ptr<Node> node, StaticGraph& graph) {
int id = graph.FindNodeByName(node->name_, node->sym_);
for (size_t i = 0; i < node->in_symbol_.size(); ++i) {
std::shared_ptr<Node> parent = node->in_symbol_[i];
int parent_id = graph.FindNodeByName(parent->name_, parent->sym_);
graph.connected_graph[parent_id].push_back(id);
graph.output_index[parent_id].push_back(node->in_index_[i]);
if (parent->sym_) {
Dfs(parent, graph);
}
}
}

} // namespace mxnet

0 comments on commit 2c9091d

Please sign in to comment.