Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

static graph #11

Merged
merged 1 commit into from
Aug 11, 2015
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ static_operator.o: src/static_operator/static_operator.cc
static_operator_cpu.o: src/static_operator/static_operator_cpu.cc
static_operator_gpu.o: src/static_operator/static_operator_gpu.cu
symbol.o: src/symbol/symbol.cc
static_graph.o : src/symbol/static_graph.cc
registry.o: src/registry.cc
c_api.o: src/c_api.cc
operator.o: src/operator/static_operator_wrapper.cc
Expand Down
4 changes: 4 additions & 0 deletions include/mxnet/atomic_symbol.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,10 @@ class AtomicSymbol {
// default implementation returns "output"
return std::vector<std::string>(1, std::string("output"));
}
/*!
 * \brief Get the number of outputs this symbol produces.
 * \return the output count; this default implementation reports a single output.
 */
virtual int NumReturns() const { return 1; }
/*!
* \brief set param for the symbol from string
* \param name parameter name
Expand Down
88 changes: 62 additions & 26 deletions include/mxnet/static_graph.h
Original file line number Diff line number Diff line change
@@ -1,48 +1,84 @@
/*!
* Copyright (c) 2015 by Contributors
* \file static_graph.h
* \brief the static graph of symbols
* \brief The static graph of symbols
*/
#ifndef MXNET_STATIC_GRAPH_H_
#define MXNET_STATIC_GRAPH_H_

#include <vector>
#include <unordered_map>
#include <string>
#include <memory>
#include "./base.h"
#include "./atomic_symbol.h"

namespace mxnet {
/*! \brief static graph interface
* static graph is an internal representation of symbol graph.
*
* The main purpose for static graph for binding a composite operator
/*!
* \brief StaticGraph is the configuration of computation graphs.
* This is the "configuration file" of mxnet.
* It can be converted to/from Symbol, and can be used to bind to operators.
*/
struct StaticGraph {
/*! \brief Node in static graph */
struct StaticNode {
class StaticGraph {
public:
/*! \brief represents a data in the graph */
struct DataEntry {
/*! \brief the source node id in the computation graph */
uint32_t source_id;
/*!
* \brief index of output from the source.
* If index == -1, it represents all the outputs.
*/
int32_t index;
};
/*! \brief Operation Node in static graph */
struct Node {
/*! \brief wrapped atomic symbol */
std::unique_ptr<AtomicSymbol> sym;
/*! \brief name of the node */
std::string name;
/*! \brief index of output from the source. */
int index;
/*! \brief output shape for node */
std::vector<TShape> in_shape;
/*! \brief output shape for node */
std::vector<TShape> out_shape;
/*! \brief input id for each node */
std::vector<int> inputs_index;
/*! \brief output id for each node */
std::vector<int> outputs_index;
/*! \brief inputs of the node, each given as a (node_id, index) pair */
std::vector<DataEntry> inputs;
};
/*! \brief head node (need input from outside) */
std::vector<int> in_args_node_id;
/*! \brief tail node (generate data to outside) */
std::vector<int> return_node_id;
/*! \brief node name to id dictionary */
std::unordered_map<std::string, int> name_id_map;
/*! \brief all nodes in the graph */
std::vector<StaticNode> nodes;
std::vector<Node> nodes;
/*! \brief indices of the nodes that correspond to arguments */
std::vector<uint32_t> arg_nodes;
/*! \brief outputs(heads) of the graph */
std::vector<DataEntry> outputs;
// functions to help inference in static graph
/*!
* \brief Perform a topological sort on the graph
* \return a topological order of node indices.
*/
std::vector<uint32_t> TopoSort() const;
/*!
* \brief infer the node shapes in the computation graph.
*
* When calling this function, user can setup the shape information known into right position.
* Unknown shape are indicated by shape.ndim() == 0.
*
* \param topo_order The topological order of node index, as created by TopoSort.
* \param node_out_shapes The shapes of the each outputs of nodes in the graph.
* \return if the shape inference is successful, return true, else return false.
*/
bool InferNodeShapes(const std::vector<uint32_t> &topo_order,
std::vector<std::vector<TShape> > *node_out_shapes) const;
/*!
* \brief infer the shapes of outputs and unknown input arguments
* \param in_shape the shape of input arguments of the operator
* this should be of same length as the vector returned by ListArguments
* in_shape allows unknown elements, which are checked by shape.ndim() == 0.
* For unknown shapes, InferShape will try to fill in the correct Shape in in_shape
* For known shapes, InferShape will check shape consistency
*
* common practice: set the shape of data input, and usually weight's shape can be inferred
*
* \param out_shape the shape of outputs of the operator
* InferShape will modify the vector to fill output TShape
* \return if the shape inference is successful, return true, else return false.
*/
bool InferShape(std::vector<TShape> *in_shape,
std::vector<TShape> *out_shape) const;
};
} // namespace mxnet
#endif // MXNET_STATIC_GRAPH_H_
2 changes: 1 addition & 1 deletion include/mxnet/static_operator.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

namespace mxnet {
/*!
* \brief static StaticOperator interface (current interface have not yet todo with scheduler),
* \brief StaticOperator interface
* StaticOperator is a stateful object that can be used to call forward and backprop
*
* This interface relies on pre-allocated memory in TBlob, the caller need to set
Expand Down
57 changes: 28 additions & 29 deletions include/mxnet/symbol.h
Original file line number Diff line number Diff line change
Expand Up @@ -61,22 +61,6 @@ class Symbol {
* \return the symbol corresponds to the indexed element.
*/
Symbol operator[] (int index) const;
/*!
* \brief infer the shapes of outputs and unknown input arguments
* \param in_shape the shape of input arguments of the operator
* this should be of same length as the vector returned by ListArguments
* in_shape allows unknown elements, which are checked by shape.ndim() == 0.
* For unknown shapes, InferShape will try to fill in the correct Shape in in_shape
* For known shapes, InferShape will check shape consistency
*
* common practice: set the shape of data input, and usually weight's shape can be infered
*
* \param out_shape the shape of outputs of the operator
* InferShape will modify the vector to fill output TShape
* \return if the shape inference is successful, return true, else return false.
*/
bool InferShape(std::vector<TShape> *in_shape,
std::vector<TShape> *out_shape);
/*!
* \brief Compose the symbol with arguments, this changes current symbol.
*
Expand Down Expand Up @@ -122,6 +106,26 @@ class Symbol {
s.Compose(kwargs, name);
return s;
}
/*!
* \brief infer the shapes of outputs and unknown input arguments
* \param in_shape the shape of input arguments of the operator
* this should be of same length as the vector returned by ListArguments
* in_shape allows unknown elements, which are checked by shape.ndim() == 0.
* For unknown shapes, InferShape will try to fill in the correct Shape in in_shape
* For known shapes, InferShape will check shape consistency
*
* common practice: set the shape of data input, and usually weight's shape can be infered
*
* \param out_shape the shape of outputs of the operator
* InferShape will modify the vector to fill output TShape
* \return if the shape inference is successful, return true, else return false.
*/
inline bool InferShape(std::vector<TShape> *in_shape,
std::vector<TShape> *out_shape) {
StaticGraph g;
Symbol::Convert({*this}, &g);
return g.InferShape(in_shape, out_shape);
}
/*!
* \brief create Symbol by wrapping AtomicSymbol
* This function takes the ownership of atomic_symbol.
Expand Down Expand Up @@ -219,8 +223,7 @@ class Symbol {
*/
template<typename FVisit>
inline void DFSVisit(FVisit fvisit) const {
std::vector<DataEntry> tmp = { head_ };
DFSVisit(tmp, fvisit);
DFSVisit({*this}, fvisit);
}
/*!
* \brief Visit all the nodes in left-to-right depth first order.
Expand All @@ -232,15 +235,17 @@ class Symbol {
* \tparam FVisit visiting function type
*/
template<typename FVisit>
static inline void DFSVisit(const std::vector<DataEntry> &heads,
FVisit fvisit) {
static inline void DFSVisit(const std::vector<Symbol> &heads,
FVisit fvisit) {
std::vector<Node*> stack;
std::unordered_set<Node*> visited;
// put the head into the graph
for (auto &head : heads) {
Node *ptr = head.source.get();
stack.push_back(ptr);
visited.insert(ptr);
Node *ptr = head.head_.source.get();
if (visited.count(ptr) == 0) {
stack.push_back(ptr);
visited.insert(ptr);
}
}
while (!stack.empty()) {
Node* back = stack.back();
Expand All @@ -255,12 +260,6 @@ class Symbol {
}
}
}
/*! \brief Toposort the symbol
* \prarm heads symbol's head
* \prarm ret sorted nodes
*/
static inline void Toposort(const std::vector<DataEntry> &heads,
std::vector<Node*> *ret);
/*!
* \brief Find duplicate arguments in the composition
* \param out the map of argument-name -> occurence count
Expand Down
85 changes: 85 additions & 0 deletions src/symbol/static_graph.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
/*!
* Copyright (c) 2015 by Contributors
* \file static_graph.cc
* \brief static graph of mxnet
*/
#include <dmlc/logging.h>
#include <mxnet/static_graph.h>
#include <vector>
#include <queue>

std::vector<uint32_t> StaticGraph::TopoSort() const {
std::vector<int> out_degree(nodes.size(), 0);
for (const Node &n : nodes) {
for (const DataEntry &e : n.inputs) {
++out_degree[e.source_id];
}
}
std::vector<uint32_t> ret(nodes.size());
auto result = ret.rbegin();
std::queue<uint32_t> queue;
for (size_t i = 0; i < nodes.size(); ++i) {
if (out_degree[i] == 0) {
queue.push(static_cast<uint32_t>(i));
}
}
while (!queue.empty()) {
uint32_t node_id = queue.front();
queue.pop();
*result = node_id;
++result;
for (const DataEntry &e : nodes[node_id].inputs) {
out_degree[e.source_id] -= 1;
if (out_degree[e.source_id] == 0) {
queue.push(e.source_id);
}
}
}
return std::move(ret);
}

bool StaticGraph::InferShape(const std::vector<uint32_t> &topo_order,
std::vector<std::vector<TShape> > *node_out_shapes) const {
bool success = true;
for (uint32_t nid : topo_order) {
const Node &node = nodes[nid];
if (node.sym != nullptr) {
std::vector<TShape> in_shape;
for (const DataEntry &e : node.inputs) {
in_shape.push_back(node_out_shapes[e.source_id][e.index]);
}
if (!node.sym->InferShape(&in_shape, &(*node_out_shapes)[nid])) return false;
for (size_t i = 0; i < node.inputs.size(); ++i) {
const DataEntry &e = node.inputs[i];
node_out_shapes[e.source_id][e.index] = in_shape[i];
}
}
}
return true;
}

bool StaticGraph::InferShape(std::vector<TShape> *in_shape,
std::vector<TShape> *out_shape) const {
std::vector<std::vector<TShape> > node_out_shapes(nodes.size());
for (size_t i = 0; i < nodes.size(); ++i) {
int nout = 1;
if (nodes[i].sym != nullptr) {
nout = nodes[i].sym->NumReturns();
}
node_out_shapes[i].resize(nout);
}
CHECK(in_shape->size() == arg_nodes.size())
<< "Wrong number of inputs to infer shape";
for (size_t i = 0; i < arg_nodes.size(); ++i) {
node_out_shapes[nid][0] = (*in_shape)[i];
}
if (!InferNodeShapes(this->TopoSort(),
&node_out_shapes)) return false;
for (size_t i = 0; i < arg_nodes.size(); ++i) {
(*in_shape)[i] = node_out_shapes[nid][0];
}
for (size_t i = 0; i < outputs.size(); ++i) {
DataEntry e = outputs[i];
(*out_shape)[i] = node_out_shapes[e.source_id][e.index];
}
}
Loading