Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
[SYMBOL] move static graph to internal
Browse files Browse the repository at this point in the history
  • Loading branch information
tqchen committed Sep 21, 2015
1 parent 38c406a commit 7f52e0d
Show file tree
Hide file tree
Showing 8 changed files with 282 additions and 245 deletions.
261 changes: 19 additions & 242 deletions include/mxnet/symbolic.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,11 @@

#include <dmlc/base.h>
#include <dmlc/json.h>
#include <algorithm>
#include <vector>
#include <memory>
#include <map>
#include <string>
#include <utility>
#include <functional>
#include <unordered_map>
#include <unordered_set>
#include "./base.h"
#include "./ndarray.h"
#include "./operator.h"
Expand All @@ -29,223 +25,16 @@

namespace mxnet {
/*!
* \brief StaticGraph is the configuration of computation graphs.
* This is the "configuration file" of mxnet.
* It can be converted to/from Symbol, and can be used to bind to operators.
* \brief Internal data structure used for
* graph serializaion and graph algorithms.
*/
class StaticGraph {
public:
/*! \brief represents a data in the graph */
struct DataEntry {
/*! \brief the source node id in the computation graph */
uint32_t source_id;
/*! \brief index of output from the source. */
uint32_t index;
/*! \brief default constructor */
DataEntry() {}
/*!
* \brief constructor with source and index
* \param source_id source id
* \param index node index
*/
DataEntry(uint32_t source_id, uint32_t index)
: source_id(source_id), index(index) {}
/*!
* \brief compare equality
* \param other the other entry to compare
* \return whether two entries equals to each other
*/
inline bool operator==(const DataEntry &other) const {
return source_id == other.source_id && index == other.index;
}
/*!
* \brief comparator, allows to use map
* \param other the other entry to compare
* \return whether two entries is smaller than the other
*/
inline bool operator<(const DataEntry &other) const {
if (source_id == other.source_id) return index < other.index;
return source_id < other.source_id;
}
/*!
* \brief interface for json serialization.
* \param writer the JSON writer to write json into.
*/
inline void Save(dmlc::JSONWriter *writer) const {
writer->BeginArray(false);
writer->WriteArrayItem(source_id);
writer->WriteArrayItem(index);
writer->EndArray();
}
/*!
* \brief interface for json serialization.
* \param reader the JSON reader to read json from.
*/
inline void Load(dmlc::JSONReader *reader) {
std::pair<uint32_t, uint32_t> p;
reader->Read(&p);
*this = DataEntry(p.first, p.second);
}
};
/*!
* \brief Operation Node in static graphs.
* There are two types of node, Forward and Backward Node.
*
* - Forward node corresponds to the op.Forward
* - Backward node corresponds to the Backward pass,
* where the corresponding forward node is indicated by backward_source_id.
* The op field in Backward node is nullptr
*
* The reason we explicit support Backward node is to allow special treatment
* such as shape inference and state sharing with Forward pass.
*/
struct Node {
/*! \brief wrapped operator property */
std::unique_ptr<OperatorProperty> op;
/*! \brief name of the node */
std::string name;
/*! \brief inputs (node_id, index) for of the nodes*/
std::vector<DataEntry> inputs;
/*!
* \brief If this field is nonnegative, this indicates this
* Node is corresponds to a Backward Operation of Operator.
* backward_source_id will points to the corresponding Forward Node.
*
* For normal node, this field is -1.
* When the node is a Backward node, the op field will be nullptr
*/
int32_t backward_source_id;
/*! \brief default constructor */
Node() : backward_source_id(-1) {}

friend void swap(Node& lhs, Node& rhs) {
std::swap(lhs.op, rhs.op);
std::swap(lhs.name, rhs.name);
std::swap(lhs.inputs, rhs.inputs);
std::swap(lhs.backward_source_id, rhs.backward_source_id);
}
/*! \brief copy constructor in favor of serialization. */
Node(const Node& another) : op(another.op.get() ? another.op.get()->Copy() : nullptr),
name(another.name),
inputs(another.inputs),
backward_source_id(another.backward_source_id) {}

inline Node& operator=(Node another) {
swap(*this, another);
return *this;
}
/*! \return whether the node is forward op node */
inline bool is_forward() const {
return op != nullptr;
}
/*! \return whether the node is backward op node */
inline bool is_backward() const {
return backward_source_id != -1;
}
/*! \return whether the node is variable node */
inline bool is_variable() const {
return op == nullptr && !is_backward();
}
/*!
* \brief interface for json serialization.
* \param writer the JSON writer write json.
*/
void Save(dmlc::JSONWriter *writer) const;
/*!
* \brief interface for json serialization.
* \param reader the JSON read to read json.
*/
void Load(dmlc::JSONReader *reader);
};
/*! \brief all nodes in the graph */
std::vector<Node> nodes;
/*! \brief index of nodes that correspods to arguments */
std::vector<uint32_t> arg_nodes;
/*! \brief heads outputs of the graph */
std::vector<DataEntry> heads;
/*!
* \brief interface for json serialization.
* \param writer the JSON writer write json.
*/
void Save(dmlc::JSONWriter *writer) const;
/*!
* \brief interface for json serialization.
* \param reader the JSON read to read json.
*/
void Load(dmlc::JSONReader *reader);
// funtions to help inference in static graph
/*!
* \brief Perform a topological sort on the graph
* \return a topological order of node indices.
*/
std::vector<uint32_t> TopoSort() const;
/*!
* \brief infer the node shapes in the computation graph.
*
* When calling this function, user can setup the shape information known into right position.
* Unknown shape are indicated by shape.ndim() == 0.
*
* \param topo_order The topological order of node index, as created by TopoSort.
* \param node_out_shapes The shapes of the each outputs of nodes in the graph.
* \param node_aux_shapes The shapes of the each auxiliary states of nodes in the graph.
* \return if the shape inference is successful, return true, else return false.
*/
bool InferNodeShapes(const std::vector<uint32_t> &topo_order,
std::vector<std::vector<TShape> > *node_out_shapes,
std::vector<std::vector<TShape> > *node_aux_shapes) const;
/*!
* \brief infer the shapes of outputs and unknown input arguments
* \param in_shape the shape of input arguments of the operator
* this should be of same length as the vector returned by ListArguments
* in_shape allows unknown elements, which are checked by shape.ndim() == 0.
* For unknown shapes, InferShape will try to fill in the correct Shape in in_shape
* For known shapes, InferShape will check shape consistency
*
* common practice: set the shape of data input, and usually weight's shape can be infered
*
* \param out_shape the shape of outputs of the operator
* InferShape will modify the vector to fill output TShape
* \param aux_shape the shape of auxiliary states of the operator
* InferShape will modify the vector to fill output TShape
* \return if the shape inference is successful, return true, else return false.
*/
bool InferShape(std::vector<TShape>* in_shape,
std::vector<TShape>* out_shape,
std::vector<TShape>* aux_shape) const;
/*!
* \brief Add a full backward pass in the static graph.
* This function will add gradient nodes for each heads,
* and add the backward pass to backprop the gradients all
* the way to the arguments.
*
* This will change the nodes field in the StaticGraph, but will not change other fields.
* The head and input of Backward pass will be returned by head_grad_nodes and arg_grads.
*
* \param head_grad_nodes used to store the created head gradient inputs for backward pass.
* \param arg_grads used to store gradients to args, can be multiple one if an argument is used by operator
*/
void MakeBackwardPass(std::vector<uint32_t> *head_grad_nodes,
std::vector<DataEntry> *arg_grads);

/*!
* \brief create a sum node that aggregates gradient together
* \param grad_source the source of the inputs.
* \return a created ElementWiseSum node
*/
static Node CreateSumNode(const std::vector<DataEntry> &grad_source);
};

class StaticGraph;
/*!
* \brief Symbol is used to represent dynamically generated symbolic computation graph.
*
* This class is used as a tool to generate computation graphs(aka. configuration) of the network.
* Symbol is always composite, the head Node is the output node of the symbol.
* An atomic symbol can be seen as a special case of the composite symbol with only the head node.
*
* The symbol can be converted from/to StaticGraph, the actual configuration used by mxnet.
* Symbol offers more flexible way to composite nodes than StaticGraph, which makes it good
* tool to generate configurations from language bindings such as python.
* \sa StaticGraph
*/
class Symbol {
public:
Expand Down Expand Up @@ -297,20 +86,6 @@ class Symbol {
*/
void Compose(const std::unordered_map<std::string, Symbol>& kwargs,
const std::string& name);
/*!
* \brief Convert a list of symbols into static graph
*
* The user can go further to call bind function on static graph
*
* \param out_graph the pointer holder of the output graph
*/
void ToStaticGraph(StaticGraph *out_graph) const;
/*!
* \brief create equivalence of symbol from static graphs.
* This operation will change the content of current symbol.
* \param graph the static graph
*/
void FromStaticGraph(const StaticGraph &graph);
/*!
* \brief Apply the symbol as a function, compose with arguments
* \param args positional arguments for the symbol
Expand Down Expand Up @@ -367,20 +142,12 @@ class Symbol {
* \brief interface for json serialization.
* \param writer the JSON writer write json.
*/
inline void Save(dmlc::JSONWriter *writer) const {
StaticGraph g;
this->ToStaticGraph(&g);
g.Save(writer);
}
void Save(dmlc::JSONWriter *writer) const;
/*!
* \brief interface for json serialization.
* \param reader the JSON read to read json.
*/
inline void Load(dmlc::JSONReader *reader) {
StaticGraph g;
g.Load(reader);
this->FromStaticGraph(g);
}
void Load(dmlc::JSONReader *reader);
/*!
* \brief get number of outputs of this symbol
* \return number of outputs
Expand Down Expand Up @@ -451,6 +218,20 @@ class Symbol {
* \return maximum number of duplication factor
*/
int FindDuplicateArgs(std::unordered_map<std::string, int> *out) const;
/*!
* \brief Convert symbol into internal static graph
*
* \param out_graph the pointer holder of the output graph
*/
void ToStaticGraph(StaticGraph *out_graph) const;
/*!
* \brief create equivalence of symbol from static graphs.
* This operation will change the content of current symbol.
* \param graph the static graph
*/
void FromStaticGraph(const StaticGraph &graph);
/*! \brief let static graph know the contents */
friend class StaticGraph;
};

/*!
Expand Down Expand Up @@ -506,8 +287,4 @@ class Executor {
const std::vector<NDArray> &aux_states);
}; // class operator
} // namespace mxnet

namespace dmlc {
DMLC_DECLARE_TRAITS(is_pod, ::mxnet::StaticGraph::DataEntry, true);
}
#endif // MXNET_SYMBOLIC_H_
1 change: 1 addition & 0 deletions src/symbol/graph_algorithm.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
#include <dmlc/logging.h>
#include <mxnet/symbolic.h>
#include <vector>
#include "./static_graph.h"

namespace mxnet {
namespace graph {
Expand Down
4 changes: 2 additions & 2 deletions src/symbol/graph_executor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -239,9 +239,9 @@ GraphExecutor::~GraphExecutor() {
Engine::Get()->WaitForAll();
}

void GraphExecutor::InitGraph(Symbol symbol, Context ctx, bool need_backward) {
void GraphExecutor::InitGraph(const Symbol &symbol, Context ctx, bool need_backward) {
// initialize all internal data structures
symbol.ToStaticGraph(&graph_);
graph_.FromSymbol(symbol);
num_forward_nodes_ = graph_.nodes.size();
if (need_backward) {
graph_.MakeBackwardPass(&head_grad_nodes_, &arg_grads_);
Expand Down
3 changes: 2 additions & 1 deletion src/symbol/graph_executor.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
#include <string>
#include <vector>
#include <utility>
#include "./static_graph.h"
#include "./graph_memory_allocator.h"

namespace mxnet {
Expand Down Expand Up @@ -159,7 +160,7 @@ class GraphExecutor : public Executor {
*/
inline OpExecEntry GetOpExecEntry(uint32_t node_id);
// initialize the internal graph structure
void InitGraph(Symbol symbol, Context ctx, bool need_backward);
void InitGraph(const Symbol &symbol, Context ctx, bool need_backward);
// initialize internal DataEntryInfo, reference counting
void InitDataEntryInfo(const std::vector<NDArray> &in_args,
const std::vector<NDArray> &arg_grad_store,
Expand Down
1 change: 1 addition & 0 deletions src/symbol/graph_memory_allocator.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
#include <map>
#include <vector>
#include <algorithm>
#include "./static_graph.h"
#include "./graph_algorithm.h"

namespace mxnet {
Expand Down
1 change: 1 addition & 0 deletions src/symbol/static_graph.cc
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
#include <vector>
#include <queue>
#include <map>
#include "./static_graph.h"
#include "../operator/operator_common.h"

namespace mxnet {
Expand Down
Loading

0 comments on commit 7f52e0d

Please sign in to comment.