Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
static graph
Browse files Browse the repository at this point in the history
  • Loading branch information
antinucleon committed Aug 11, 2015
1 parent e50817d commit 6e538fc
Show file tree
Hide file tree
Showing 7 changed files with 224 additions and 155 deletions.
1 change: 1 addition & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ static_operator.o: src/static_operator/static_operator.cc
static_operator_cpu.o: src/static_operator/static_operator_cpu.cc
static_operator_gpu.o: src/static_operator/static_operator_gpu.cu
symbol.o: src/symbol/symbol.cc
static_graph.o : src/symbol/static_graph.cc
registry.o: src/registry.cc
c_api.o: src/c_api.cc
operator.o: src/operator/static_operator_wrapper.cc
Expand Down
4 changes: 4 additions & 0 deletions include/mxnet/atomic_symbol.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,10 @@ class AtomicSymbol {
// default implementation returns "output"
return std::vector<std::string>(1, std::string("output"));
}
/*! \brief number of outputs of the symbol */
virtual int NumReturns() const {
return 1;
}
/*!
* \brief set param for the symbol from string
* \param name parameter name
Expand Down
88 changes: 62 additions & 26 deletions include/mxnet/static_graph.h
Original file line number Diff line number Diff line change
@@ -1,48 +1,84 @@
/*!
* Copyright (c) 2015 by Contributors
* \file static_graph.h
* \brief the static graph of symbols
* \brief The static graph of symbols
*/
#ifndef MXNET_STATIC_GRAPH_H_
#define MXNET_STATIC_GRAPH_H_

#include <vector>
#include <unordered_map>
#include <string>
#include <memory>
#include "./base.h"
#include "./atomic_symbol.h"

namespace mxnet {
/*! \brief static graph interface
* static graph is an internal representation of symbol graph.
*
* The main purpose for static graph for binding a composite operator
/*!
* \brief StaticGraph is the configuration of computation graphs.
* This is the "configuration file" of mxnet.
* It can be converted to/from Symbol, and can be used to bind to operators.
*/
struct StaticGraph {
/*! \brief Node in static graph */
struct StaticNode {
class StaticGraph {
public:
/*! \brief represents a data in the graph */
struct DataEntry {
/*! \brief the source node id in the computation graph */
uint32_t source_id;
/*!
* \brief index of output from the source.
* If index == -1, it represents all the outputs.
*/
int32_t index;
};
/*! \brief Operation Node in static graph */
struct Node {
/*! \brief wrapped atomic symbol */
std::unique_ptr<AtomicSymbol> sym;
/*! \brief name of the node */
std::string name;
/*! \brief index of output from the source. */
int index;
/*! \brief output shape for node */
std::vector<TShape> in_shape;
/*! \brief output shape for node */
std::vector<TShape> out_shape;
/*! \brief input id for each node */
std::vector<int> inputs_index;
/*! \brief output id for each node */
std::vector<int> outputs_index;
/*! \brief inputs (node_id, index) for of the nodes*/
std::vector<DataEntry> inputs;
};
/*! \brief head node (need input from outside) */
std::vector<int> in_args_node_id;
/*! \brief tail node (generate data to outside) */
std::vector<int> return_node_id;
/*! \brief node name to id dictionary */
std::unordered_map<std::string, int> name_id_map;
/*! \brief all nodes in the graph */
std::vector<StaticNode> nodes;
std::vector<Node> nodes;
/*! \brief index is nodes that correspods to arguments */
std::vector<uint32_t> arg_nodes;
/*! \brief outputs(heads) of the graph */
std::vector<DataEntry> outputs;
// funtions to help inference in static graph
/*!
* \brief Perform a topological sort on the graph
* \return a topological order of node indices.
*/
std::vector<uint32_t> TopoSort() const;
/*!
* \brief infer the node shapes in the computation graph.
*
* When calling this function, user can setup the shape information known into right position.
* Unknown shape are indicated by shape.ndim() == 0.
*
* \param topo_order The topological order of node index, as created by TopoSort.
* \param node_out_shapes The shapes of the each outputs of nodes in the graph.
* \return if the shape inference is successful, return true, else return false.
*/
bool InferNodeShapes(const std::vector<uint32_t> &topo_order,
std::vector<std::vector<TShape> > *node_out_shapes) const;
/*!
* \brief infer the shapes of outputs and unknown input arguments
* \param in_shape the shape of input arguments of the operator
* this should be of same length as the vector returned by ListArguments
* in_shape allows unknown elements, which are checked by shape.ndim() == 0.
* For unknown shapes, InferShape will try to fill in the correct Shape in in_shape
* For known shapes, InferShape will check shape consistency
*
* common practice: set the shape of data input, and usually weight's shape can be infered
*
* \param out_shape the shape of outputs of the operator
* InferShape will modify the vector to fill output TShape
* \return if the shape inference is successful, return true, else return false.
*/
bool InferShape(std::vector<TShape> *in_shape,
std::vector<TShape> *out_shape) const;
};
} // namespace mxnet
#endif // MXNET_STATIC_GRAPH_H_
2 changes: 1 addition & 1 deletion include/mxnet/static_operator.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

namespace mxnet {
/*!
* \brief static StaticOperator interface (current interface have not yet todo with scheduler),
* \brief StaticOperator interface
* StaticOperator is a stateful object that can be used to call forward and backprop
*
* This interface relies on pre-allocated memory in TBlob, the caller need to set
Expand Down
57 changes: 28 additions & 29 deletions include/mxnet/symbol.h
Original file line number Diff line number Diff line change
Expand Up @@ -61,22 +61,6 @@ class Symbol {
* \return the symbol corresponds to the indexed element.
*/
Symbol operator[] (int index) const;
/*!
* \brief infer the shapes of outputs and unknown input arguments
* \param in_shape the shape of input arguments of the operator
* this should be of same length as the vector returned by ListArguments
* in_shape allows unknown elements, which are checked by shape.ndim() == 0.
* For unknown shapes, InferShape will try to fill in the correct Shape in in_shape
* For known shapes, InferShape will check shape consistency
*
* common practice: set the shape of data input, and usually weight's shape can be infered
*
* \param out_shape the shape of outputs of the operator
* InferShape will modify the vector to fill output TShape
* \return if the shape inference is successful, return true, else return false.
*/
bool InferShape(std::vector<TShape> *in_shape,
std::vector<TShape> *out_shape);
/*!
* \brief Compose the symbol with arguments, this changes current symbol.
*
Expand Down Expand Up @@ -122,6 +106,26 @@ class Symbol {
s.Compose(kwargs, name);
return s;
}
/*!
* \brief infer the shapes of outputs and unknown input arguments
* \param in_shape the shape of input arguments of the operator
* this should be of same length as the vector returned by ListArguments
* in_shape allows unknown elements, which are checked by shape.ndim() == 0.
* For unknown shapes, InferShape will try to fill in the correct Shape in in_shape
* For known shapes, InferShape will check shape consistency
*
* common practice: set the shape of data input, and usually weight's shape can be infered
*
* \param out_shape the shape of outputs of the operator
* InferShape will modify the vector to fill output TShape
* \return if the shape inference is successful, return true, else return false.
*/
inline bool InferShape(std::vector<TShape> *in_shape,
std::vector<TShape> *out_shape) {
StaticGraph g;
Symbol::Convert({*this}, &g);
return g.InferShape(in_shape, out_shape);
}
/*!
* \brief create Symbol by wrapping AtomicSymbol
* This function takes the ownership of atomic_symbol.
Expand Down Expand Up @@ -219,8 +223,7 @@ class Symbol {
*/
template<typename FVisit>
inline void DFSVisit(FVisit fvisit) const {
std::vector<DataEntry> tmp = { head_ };
DFSVisit(tmp, fvisit);
DFSVisit({*this}, fvisit);
}
/*!
* \brief Visit all the nodes in left-to-right depth first order.
Expand All @@ -232,15 +235,17 @@ class Symbol {
* \tparam FVisit visiting function type
*/
template<typename FVisit>
static inline void DFSVisit(const std::vector<DataEntry> &heads,
FVisit fvisit) {
static inline void DFSVisit(const std::vector<Symbol> &heads,
FVisit fvisit) {
std::vector<Node*> stack;
std::unordered_set<Node*> visited;
// put the head into the graph
for (auto &head : heads) {
Node *ptr = head.source.get();
stack.push_back(ptr);
visited.insert(ptr);
Node *ptr = head.head_.source.get();
if (visited.count(ptr) == 0) {
stack.push_back(ptr);
visited.insert(ptr);
}
}
while (!stack.empty()) {
Node* back = stack.back();
Expand All @@ -255,12 +260,6 @@ class Symbol {
}
}
}
/*! \brief Toposort the symbol
* \prarm heads symbol's head
* \prarm ret sorted nodes
*/
static inline void Toposort(const std::vector<DataEntry> &heads,
std::vector<Node*> *ret);
/*!
* \brief Find duplicate arguments in the composition
* \param out the map of argument-name -> occurence count
Expand Down
85 changes: 85 additions & 0 deletions src/symbol/static_graph.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
/*!
* Copyright (c) 2015 by Contributors
* \file static_graph.cc
* \brief static graph of mxnet
*/
#include <dmlc/logging.h>
#include <mxnet/static_graph.h>
#include <vector>
#include <queue>

std::vector<uint32_t> StaticGraph::TopoSort() const {
std::vector<int> out_degree(nodes.size(), 0);
for (const Node &n : nodes) {
for (const DataEntry &e : n.inputs) {
++out_degree[e.source_id];
}
}
std::vector<uint32_t> ret(nodes.size());
auto result = ret.rbegin();
std::queue<uint32_t> queue;
for (size_t i = 0; i < nodes.size(); ++i) {
if (out_degree[i] == 0) {
queue.push(static_cast<uint32_t>(i));
}
}
while (!queue.empty()) {
uint32_t node_id = queue.front();
queue.pop();
*result = node_id;
++result;
for (const DataEntry &e : nodes[node_id].inputs) {
out_degree[e.source_id] -= 1;
if (out_degree[e.source_id] == 0) {
queue.push(e.source_id);
}
}
}
return std::move(ret);
}

bool StaticGraph::InferShape(const std::vector<uint32_t> &topo_order,
std::vector<std::vector<TShape> > *node_out_shapes) const {
bool success = true;
for (uint32_t nid : topo_order) {
const Node &node = nodes[nid];
if (node.sym != nullptr) {
std::vector<TShape> in_shape;
for (const DataEntry &e : node.inputs) {
in_shape.push_back(node_out_shapes[e.source_id][e.index]);
}
if (!node.sym->InferShape(&in_shape, &(*node_out_shapes)[nid])) return false;
for (size_t i = 0; i < node.inputs.size(); ++i) {
const DataEntry &e = node.inputs[i];
node_out_shapes[e.source_id][e.index] = in_shape[i];
}
}
}
return true;
}

bool StaticGraph::InferShape(std::vector<TShape> *in_shape,
std::vector<TShape> *out_shape) const {
std::vector<std::vector<TShape> > node_out_shapes(nodes.size());
for (size_t i = 0; i < nodes.size(); ++i) {
int nout = 1;
if (nodes[i].sym != nullptr) {
nout = nodes[i].sym->NumReturns();
}
node_out_shapes[i].resize(nout);
}
CHECK(in_shape->size() == arg_nodes.size())
<< "Wrong number of inputs to infer shape";
for (size_t i = 0; i < arg_nodes.size(); ++i) {
node_out_shapes[nid][0] = (*in_shape)[i];
}
if (!InferNodeShapes(this->TopoSort(),
&node_out_shapes)) return false;
for (size_t i = 0; i < arg_nodes.size(); ++i) {
(*in_shape)[i] = node_out_shapes[nid][0];
}
for (size_t i = 0; i < outputs.size(); ++i) {
DataEntry e = outputs[i];
(*out_shape)[i] = node_out_shapes[e.source_id][e.index];
}
}
Loading

0 comments on commit 6e538fc

Please sign in to comment.