From 6e538fcc839a043a8544fad2659a13e7de73c1e4 Mon Sep 17 00:00:00 2001 From: Bing Xu Date: Mon, 10 Aug 2015 21:17:16 -0600 Subject: [PATCH] static graph --- Makefile | 1 + include/mxnet/atomic_symbol.h | 4 + include/mxnet/static_graph.h | 88 ++++++++++++++------ include/mxnet/static_operator.h | 2 +- include/mxnet/symbol.h | 57 +++++++------ src/symbol/static_graph.cc | 85 +++++++++++++++++++ src/symbol/symbol.cc | 142 ++++++++++---------------------- 7 files changed, 224 insertions(+), 155 deletions(-) create mode 100644 src/symbol/static_graph.cc diff --git a/Makefile b/Makefile index c94c2705f886..9d5740ff0c7e 100644 --- a/Makefile +++ b/Makefile @@ -85,6 +85,7 @@ static_operator.o: src/static_operator/static_operator.cc static_operator_cpu.o: src/static_operator/static_operator_cpu.cc static_operator_gpu.o: src/static_operator/static_operator_gpu.cu symbol.o: src/symbol/symbol.cc +static_graph.o : src/symbol/static_graph.cc registry.o: src/registry.cc c_api.o: src/c_api.cc operator.o: src/operator/static_operator_wrapper.cc diff --git a/include/mxnet/atomic_symbol.h b/include/mxnet/atomic_symbol.h index bdeac84048a6..cfc8a2eb6c20 100644 --- a/include/mxnet/atomic_symbol.h +++ b/include/mxnet/atomic_symbol.h @@ -37,6 +37,10 @@ class AtomicSymbol { // default implementation returns "output" return std::vector(1, std::string("output")); } + /*! \brief number of outputs of the symbol */ + virtual int NumReturns() const { + return 1; + } /*! * \brief set param for the symbol from string * \param name parameter name diff --git a/include/mxnet/static_graph.h b/include/mxnet/static_graph.h index 4aadbe2fb631..1e3b8352de83 100644 --- a/include/mxnet/static_graph.h +++ b/include/mxnet/static_graph.h @@ -1,48 +1,84 @@ /*! * Copyright (c) 2015 by Contributors * \file static_graph.h - * \brief the static graph of symbols + * \brief The static graph of symbols */ #ifndef MXNET_STATIC_GRAPH_H_ #define MXNET_STATIC_GRAPH_H_ #include -#include #include #include +#include "./base.h" #include "./atomic_symbol.h" + namespace mxnet { -/*! \brief static graph interface - * static graph is an internal representation of symbol graph. - * - * The main purpose for static graph for binding a composite operator +/*! + * \brief StaticGraph is the configuration of computation graphs. + * This is the "configuration file" of mxnet. + * It can be converted to/from Symbol, and can be used to bind to operators. */ -struct StaticGraph { - /*! \brief Node in static graph */ - struct StaticNode { +class StaticGraph { + public: + /*! \brief represents a data in the graph */ + struct DataEntry { + /*! \brief the source node id in the computation graph */ + uint32_t source_id; + /*! + * \brief index of output from the source. + * If index == -1, it represents all the outputs. + */ + int32_t index; + }; + /*! \brief Operation Node in static graph */ + struct Node { /*! \brief wrapped atomic symbol */ std::unique_ptr sym; /*! \brief name of the node */ std::string name; - /*! \brief index of output from the source. */ - int index; - /*! \brief output shape for node */ - std::vector in_shape; - /*! \brief output shape for node */ - std::vector out_shape; - /*! \brief input id for each node */ - std::vector inputs_index; - /*! \brief output id for each node */ - std::vector outputs_index; + /*! \brief inputs (node_id, index) for of the nodes*/ + std::vector inputs; }; - /*! \brief head node (need input from outside) */ - std::vector in_args_node_id; - /*! \brief tail node (generate data to outside) */ - std::vector return_node_id; - /*! \brief node name to id dictionary */ - std::unordered_map name_id_map; /*! \brief all nodes in the graph */ - std::vector nodes; + std::vector nodes; + /*! \brief index is nodes that correspods to arguments */ + std::vector arg_nodes; + /*! \brief outputs(heads) of the graph */ + std::vector outputs; + // funtions to help inference in static graph + /*! + * \brief Perform a topological sort on the graph + * \return a topological order of node indices. + */ + std::vector TopoSort() const; + /*! + * \brief infer the node shapes in the computation graph. + * + * When calling this function, user can setup the shape information known into right position. + * Unknown shape are indicated by shape.ndim() == 0. + * + * \param topo_order The topological order of node index, as created by TopoSort. + * \param node_out_shapes The shapes of the each outputs of nodes in the graph. + * \return if the shape inference is successful, return true, else return false. + */ + bool InferNodeShapes(const std::vector &topo_order, + std::vector > *node_out_shapes) const; + /*! + * \brief infer the shapes of outputs and unknown input arguments + * \param in_shape the shape of input arguments of the operator + * this should be of same length as the vector returned by ListArguments + * in_shape allows unknown elements, which are checked by shape.ndim() == 0. + * For unknown shapes, InferShape will try to fill in the correct Shape in in_shape + * For known shapes, InferShape will check shape consistency + * + * common practice: set the shape of data input, and usually weight's shape can be infered + * + * \param out_shape the shape of outputs of the operator + * InferShape will modify the vector to fill output TShape + * \return if the shape inference is successful, return true, else return false. + */ + bool InferShape(std::vector *in_shape, + std::vector *out_shape) const; }; } // namespace mxnet #endif // MXNET_STATIC_GRAPH_H_ diff --git a/include/mxnet/static_operator.h b/include/mxnet/static_operator.h index 27efccd2ce58..e3d4d68d9d85 100644 --- a/include/mxnet/static_operator.h +++ b/include/mxnet/static_operator.h @@ -13,7 +13,7 @@ namespace mxnet { /*! - * \brief static StaticOperator interface (current interface have not yet todo with scheduler), + * \brief StaticOperator interface * StaticOperator is a stateful object that can be used to call forward and backprop * * This interface relies on pre-allocated memory in TBlob, the caller need to set diff --git a/include/mxnet/symbol.h b/include/mxnet/symbol.h index e7869e6b89d2..18b4466706d4 100644 --- a/include/mxnet/symbol.h +++ b/include/mxnet/symbol.h @@ -61,22 +61,6 @@ class Symbol { * \return the symbol corresponds to the indexed element. */ Symbol operator[] (int index) const; - /*! - * \brief infer the shapes of outputs and unknown input arguments - * \param in_shape the shape of input arguments of the operator - * this should be of same length as the vector returned by ListArguments - * in_shape allows unknown elements, which are checked by shape.ndim() == 0. - * For unknown shapes, InferShape will try to fill in the correct Shape in in_shape - * For known shapes, InferShape will check shape consistency - * - * common practice: set the shape of data input, and usually weight's shape can be infered - * - * \param out_shape the shape of outputs of the operator - * InferShape will modify the vector to fill output TShape - * \return if the shape inference is successful, return true, else return false. - */ - bool InferShape(std::vector *in_shape, - std::vector *out_shape); /*! * \brief Compose the symbol with arguments, this changes current symbol. * @@ -122,6 +106,26 @@ class Symbol { s.Compose(kwargs, name); return s; } + /*! + * \brief infer the shapes of outputs and unknown input arguments + * \param in_shape the shape of input arguments of the operator + * this should be of same length as the vector returned by ListArguments + * in_shape allows unknown elements, which are checked by shape.ndim() == 0. + * For unknown shapes, InferShape will try to fill in the correct Shape in in_shape + * For known shapes, InferShape will check shape consistency + * + * common practice: set the shape of data input, and usually weight's shape can be infered + * + * \param out_shape the shape of outputs of the operator + * InferShape will modify the vector to fill output TShape + * \return if the shape inference is successful, return true, else return false. + */ + inline bool InferShape(std::vector *in_shape, + std::vector *out_shape) { + StaticGraph g; + Symbol::Convert({*this}, &g); + return g.InferShape(in_shape, out_shape); + } /*! * \brief create Symbol by wrapping AtomicSymbol * This function takes the ownership of atomic_symbol. @@ -219,8 +223,7 @@ class Symbol { */ template inline void DFSVisit(FVisit fvisit) const { - std::vector tmp = { head_ }; - DFSVisit(tmp, fvisit); + DFSVisit({*this}, fvisit); } /*! * \brief Visit all the nodes in left-to-right depth first order. @@ -232,15 +235,17 @@ class Symbol { * \tparam FVisit visiting function type */ template - static inline void DFSVisit(const std::vector &heads, - FVisit fvisit) { + static inline void DFSVisit(const std::vector &heads, + FVisit fvisit) { std::vector stack; std::unordered_set visited; // put the head into the graph for (auto &head : heads) { - Node *ptr = head.source.get(); - stack.push_back(ptr); - visited.insert(ptr); + Node *ptr = head.head_.source.get(); + if (visited.count(ptr) == 0) { + stack.push_back(ptr); + visited.insert(ptr); + } } while (!stack.empty()) { Node* back = stack.back(); @@ -255,12 +260,6 @@ class Symbol { } } } - /*! \brief Toposort the symbol - * \prarm heads symbol's head - * \prarm ret sorted nodes - */ - static inline void Toposort(const std::vector &heads, - std::vector *ret); /*! * \brief Find duplicate arguments in the composition * \param out the map of argument-name -> occurence count diff --git a/src/symbol/static_graph.cc b/src/symbol/static_graph.cc new file mode 100644 index 000000000000..9175fa9d55b0 --- /dev/null +++ b/src/symbol/static_graph.cc @@ -0,0 +1,85 @@ +/*! + * Copyright (c) 2015 by Contributors + * \file static_graph.cc + * \brief static graph of mxnet + */ +#include +#include +#include +#include + +std::vector StaticGraph::TopoSort() const { + std::vector out_degree(nodes.size(), 0); + for (const Node &n : nodes) { + for (const DataEntry &e : n.inputs) { + ++out_degree[e.source_id]; + } + } + std::vector ret(nodes.size()); + auto result = ret.rbegin(); + std::queue queue; + for (size_t i = 0; i < nodes.size(); ++i) { + if (out_degree[i] == 0) { + queue.push(static_cast(i)); + } + } + while (!queue.empty()) { + uint32_t node_id = queue.front(); + queue.pop(); + *result = node_id; + ++result; + for (const DataEntry &e : nodes[node_id].inputs) { + out_degree[e.source_id] -= 1; + if (out_degree[e.source_id] == 0) { + queue.push(e.source_id); + } + } + } + return std::move(ret); +} + +bool StaticGraph::InferShape(const std::vector &topo_order, + std::vector > *node_out_shapes) const { + bool success = true; + for (uint32_t nid : topo_order) { + const Node &node = nodes[nid]; + if (node.sym != nullptr) { + std::vector in_shape; + for (const DataEntry &e : node.inputs) { + in_shape.push_back(node_out_shapes[e.source_id][e.index]); + } + if (!node.sym->InferShape(&in_shape, &(*node_out_shapes)[nid])) return false; + for (size_t i = 0; i < node.inputs.size(); ++i) { + const DataEntry &e = node.inputs[i]; + node_out_shapes[e.source_id][e.index] = in_shape[i]; + } + } + } + return true; +} + +bool StaticGraph::InferShape(std::vector *in_shape, + std::vector *out_shape) const { + std::vector > node_out_shapes(nodes.size()); + for (size_t i = 0; i < nodes.size(); ++i) { + int nout = 1; + if (nodes[i].sym != nullptr) { + nout = nodes[i].sym->NumReturns(); + } + node_out_shapes[i].resize(nout); + } + CHECK(in_shape->size() == arg_nodes.size()) + << "Wrong number of inputs to infer shape"; + for (size_t i = 0; i < arg_nodes.size(); ++i) { + node_out_shapes[nid][0] = (*in_shape)[i]; + } + if (!InferNodeShapes(this->TopoSort(), + &node_out_shapes)) return false; + for (size_t i = 0; i < arg_nodes.size(); ++i) { + (*in_shape)[i] = node_out_shapes[nid][0]; + } + for (size_t i = 0; i < outputs.size(); ++i) { + DataEntry e = outputs[i]; + (*out_shape)[i] = node_out_shapes[e.source_id][e.index]; + } +} diff --git a/src/symbol/symbol.cc b/src/symbol/symbol.cc index 7a1e932ca872..2a4ed45df1a8 100644 --- a/src/symbol/symbol.cc +++ b/src/symbol/symbol.cc @@ -250,71 +250,6 @@ std::vector Symbol::ListArguments() const { } } -inline void Symbol::Toposort(const std::vector &heads, - std::vector *ret) { - std::unordered_map out_degree; - std::queue queue; - ret->clear(); - size_t idx = 0; - DFSVisit(heads, - [&out_degree](Node* node) { - for (auto &entry : node->inputs) { - Node *ptr = entry.source.get(); - auto iter = out_degree.find(ptr); - if (iter == out_degree.end()) { - out_degree[ptr] = 0; - } else { - iter->second += 1; - } - } - }); - for (auto &entry : heads) { - queue.push(entry.source.get()); - } - idx = out_degree.size(); - ret->resize(idx); - --idx; - while (queue.size() > 0) { - Node *node = queue.front(); - queue.pop(); - ret->at(idx--) = node; - for (auto it = node->inputs.rbegin(); it != node->inputs.rend(); ++it) { - Node *ptr = it->source.get(); - out_degree[ptr] -= 1; - if (out_degree[ptr] == 0) { - queue.push(ptr); - } - } - } -} - -bool Symbol::InferShape(std::vector *in_shape, - std::vector *out_shape) { - bool success = true; - StaticGraph graph; - auto input_args = this->ListArguments(); - std::vector tmp_arg = {*this}; - CHECK(in_shape->size() == input_args.size()) << "Input shape should be same to arguments"; - out_shape->clear(); - Convert(tmp_arg, &graph); - for (size_t i = 0; i < in_shape->size(); ++i) { - graph.nodes[graph.in_args_node_id[i]].in_shape.push_back(in_shape->at(i)); - } - for (auto &nd : graph.nodes) { - success &= nd.sym->InferShape(&nd.in_shape, &nd.out_shape); - } - // copy result back - for (size_t i = 0; i < in_shape->size(); ++i) { - in_shape->at(i) = graph.nodes[graph.in_args_node_id[i]].in_shape[0]; - } - for (auto i : graph.return_node_id) { - for (auto sp : graph.nodes[i].out_shape) { - out_shape->push_back(sp); - } - } - return success; -} - std::vector Symbol::ListReturns() const { return head_.source->sym->ListReturns(); } @@ -323,49 +258,58 @@ Symbol Symbol::Create(AtomicSymbol *atomic_symbol) { // use special representation for atomic symbol Symbol s; s.head_ = DataEntry(std::make_shared(atomic_symbol, ""), - atomic_symbol->ListReturns().size() > 1 ? -1 : 0); + atomic_symbol->NumReturns() > 1 ? -1 : 0); return s; } void Symbol::Convert(const std::vector &heads, StaticGraph *out_graph) { // TODO(bing): Check unique name - std::vector nodes; - std::unordered_map node_id_dic; - std::vector arg(heads.size()); - for (size_t i = 0; i < heads.size(); ++i) { - arg[i] = heads[i].head_; - } - Toposort(arg, &nodes); - out_graph->nodes.resize(nodes.size()); - // set up dict - for (size_t i = 0; i < nodes.size(); ++i) { - node_id_dic[nodes[i]] = i; - } - // copy - for (size_t i = 0; i < nodes.size(); ++i) { - out_graph->name_id_map[nodes[i]->name] = i; - if (!nodes[i]->is_variable()) { - out_graph->nodes[i].sym.reset(nodes[i]->sym->Copy()); + std::vector node_order; + std::unordered_map node_index; + auto &arg_nodes = out_graph->arg_nodes; + arg_nodes.clear(); + + DFSVisit(heads, [&node_order, &node_index, &arg_nodes](Node *n) { + uint32_t nid = static_cast(node_index.size()); + node_index[n] = nid; + if (n->is_variable()) { + arg_nodes.push_back(nid); + } + node_order.push_back(n); + }); + // setup nodes + out_graph->nodes.resize(node_index.size()); + for (uint32_t nid = 0; nid < node_order.size(); ++nid) { + if (node_order[nid]->sym != nullptr) { + out_graph->nodes[nid].sym.reset(node_order[nid]->sym->Copy()); + } else { + out_graph->nodes[nid].sym.reset(nullptr); } - out_graph->nodes[i].name = nodes[i]->name; - for (auto &entry : nodes[i]->inputs) { - out_graph->nodes[i].inputs_index.push_back(node_id_dic[entry.source.get()]); - out_graph->nodes[node_id_dic[entry.source.get()]].outputs_index.push_back(i); + out_graph->nodes[nid].name = node_order[nid]->name; + auto &inputs = out_graph->nodes[nid].inputs; + inputs.clear(); + for (const DataEntry &src : node_order[nid]->inputs) { + StaticGraph::DataEntry e; + e.index = src.index; + e.source_id = node_index[src.source.get()]; + inputs.push_back(e); } } - // set input map - for (auto const &head : heads) { - auto input_args = head.ListArguments(); - out_graph->in_args_node_id.resize(input_args.size()); - for (size_t i = 0; i < input_args.size(); ++i) { - out_graph->in_args_node_id[i] = out_graph->name_id_map[input_args[i]]; + // setup heads + out_graph->outputs.clear(); + for (auto &head : heads) { + StaticGraph::DataEntry e; + e.source_id = node_index[head.head_.source.get()]; + if (head.head_.index == -1) { + int nout = head.head_.source->sym->NumReturns(); + for (int i = 0; i < nout; ++i) { + e.index = i; + out_graph->outputs.push_back(e); + } + } else { + e.index = head.head_.index; + out_graph->outputs.push_back(e); } } - // set output map - out_graph->return_node_id.resize(heads.size()); - for (size_t i = 0; i < heads.size(); ++i) { - out_graph->return_node_id[i] = out_graph->name_id_map[heads[i].head_.source->name]; - } } - } // namespace mxnet