From 2dbdc0fdca8c785b8036366f6d63a463271bd837 Mon Sep 17 00:00:00 2001
From: linmin
Date: Sat, 22 Aug 2015 10:48:06 +0800
Subject: [PATCH 1/6] dfs visit use const shared_ptr&

---
 src/symbol/symbol.cc | 36 +++++++++++++++++++-----------------
 1 file changed, 19 insertions(+), 17 deletions(-)

diff --git a/src/symbol/symbol.cc b/src/symbol/symbol.cc
index aecac3dda487..8bf36d309018 100644
--- a/src/symbol/symbol.cc
+++ b/src/symbol/symbol.cc
@@ -52,24 +52,24 @@ inline bool Symbol::is_atomic() const {
 // implementation of template functions
 template <typename FVisit>
 inline void Symbol::DFSVisit(FVisit fvisit) const {
-  std::vector<Node*> stack;
+  std::vector<const std::shared_ptr<Node>*> stack;
   std::unordered_set<Node*> visited;
   // put the head into the graph
   for (auto &head : heads_) {
     Node *ptr = head.source.get();
     if (visited.count(ptr) == 0) {
-      stack.push_back(ptr);
+      stack.push_back(&head.source);
       visited.insert(ptr);
     }
   }
   while (!stack.empty()) {
-    Node *back = stack.back();
+    const std::shared_ptr<Node> *back = stack.back();
     stack.pop_back();
-    fvisit(back);
-    for (auto it = back->inputs.rbegin(); it != back->inputs.rend(); ++it) {
+    fvisit(*back);
+    for (auto it = back->get()->inputs.rbegin(); it != back->get()->inputs.rend(); ++it) {
       Node *ptr = it->source.get();
       if (visited.count(ptr) == 0) {
-        stack.push_back(ptr);
+        stack.push_back(&it->source);
         visited.insert(ptr);
       }
     }
@@ -101,7 +101,7 @@ inline void KeywordArgumentMismatch(const char *source,
 int Symbol::FindDuplicateArgs(std::unordered_map<std::string, int> *out) const {
   out->clear();
   int max_dup = 1;
-  this->DFSVisit([out, &max_dup](Node *node) {
+  this->DFSVisit([out, &max_dup](const std::shared_ptr<Node> &node) {
     if (node->is_variable()) {
       auto iter = out->find(node->name);
       if (iter == out->end()) {
@@ -119,11 +119,11 @@ int Symbol::FindDuplicateArgs(std::unordered_map<std::string, int> *out) const {
 Symbol Symbol::Copy() const {
   std::unordered_map<Node*, std::shared_ptr<Node> > old_new;
   // use DFSVisit to copy all the nodes
-  this->DFSVisit([&old_new](Node *node) {
+  this->DFSVisit([&old_new](const std::shared_ptr<Node> &node) {
     if (node->op == nullptr) {
-      old_new[node] = std::make_shared<Node>(nullptr, node->name);
+      old_new[node.get()] = std::make_shared<Node>(nullptr, node->name);
     } else {
-      old_new[node] = std::make_shared<Node>(node->op->Copy(), node->name);
+      old_new[node.get()] = std::make_shared<Node>(node->op->Copy(), node->name);
     }
   });
   // connect nodes of new graph
@@ -156,7 +156,7 @@ void Symbol::Print(std::ostream &os) const {
       os << "\toutput[" << i << "]=" << heads_[i].source->name
          << '(' << heads_[i].index << ")\n";
     }
-    this->DFSVisit([&os](Node *node) {
+    this->DFSVisit([&os](const std::shared_ptr<Node> &node) {
       if (node->is_variable()) {
         os << "Variable:" << node->name << '\n';
       } else {
@@ -176,7 +176,7 @@ std::vector<std::string> Symbol::ListArguments() const {
   if (this->is_atomic()) {
     return heads_[0].source->op->ListArguments();
   } else {
-    this->DFSVisit([&ret](Node *node) {
+    this->DFSVisit([&ret](const std::shared_ptr<Node> &node) {
       if (node->is_variable()) {
         ret.push_back(node->name);
       }
@@ -243,7 +243,8 @@ void Symbol::Compose(const std::vector<Symbol>& args,
   std::unordered_map<Node*, const DataEntry*> replace_map;
   std::vector<std::pair<DataEntry*, const DataEntry*> > replace_plan;
   // replace map stores the existing replacement plan for arguments node
-  this->DFSVisit([&arg_counter, &replace_map, &replace_plan, &args](Node *node) {
+  this->DFSVisit([&arg_counter, &replace_map, &replace_plan, &args]
+                 (const std::shared_ptr<Node> &node) {
     // visit all the childs, find possible replacement
     for (size_t i = 0; i < node->inputs.size(); ++i) {
       DataEntry *e = &(node->inputs[i]);
@@ -324,7 +325,8 @@ void Symbol::Compose(const std::unordered_map<std::string, Symbol>& kwargs,
   std::vector<std::pair<DataEntry*, const DataEntry*> > replace_plan;
   std::unordered_set<Node*> visited;
   // replace map stores the existing replacement plan for arguments node
-  this->DFSVisit([&nmatched, &visited, &kwargs, &replace_plan](Node *node) {
+  this->DFSVisit([&nmatched, &visited, &kwargs, &replace_plan]
+                 (const std::shared_ptr<Node> &node) {
     // visit all the childs, find possible replacement
     for (size_t i = 0; i < node->inputs.size(); ++i) {
       DataEntry *e = &(node->inputs[i]);
@@ -431,13 +433,13 @@ void Symbol::ToStaticGraph(StaticGraph *out_graph) const {
   auto &arg_nodes = out_graph->arg_nodes;
   arg_nodes.clear();
-  this->DFSVisit([&node_order, &node_index, &arg_nodes](Node *n) {
+  this->DFSVisit([&node_order, &node_index, &arg_nodes](const std::shared_ptr<Node> &n) {
     uint32_t nid = static_cast<uint32_t>(node_index.size());
-    node_index[n] = nid;
+    node_index[n.get()] = nid;
     if (n->is_variable()) {
       arg_nodes.push_back(nid);
     }
-    node_order.push_back(n);
+    node_order.push_back(n.get());
   });
   // setup nodes
   out_graph->nodes.resize(node_index.size());
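
The pattern patch 1 adopts, as a standalone sketch (a simplified Node type is assumed here; the real mxnet Node also carries an operator and the heads are DataEntry objects): the explicit DFS stack stores const std::shared_ptr<Node>* rather than raw Node*, so traversal itself never bumps reference counts, yet the visitor receives a shared_ptr it can copy, which is exactly what the later symbolic-gradient patch relies on to collect nodes.

    // Illustrative sketch only, not part of the patch.
    #include <iostream>
    #include <memory>
    #include <string>
    #include <unordered_set>
    #include <vector>

    struct Node {
      std::string name;
      std::vector<std::shared_ptr<Node> > inputs;
    };

    template <typename FVisit>
    void DFSVisit(const std::vector<std::shared_ptr<Node> >& heads, FVisit fvisit) {
      std::vector<const std::shared_ptr<Node>*> stack;  // pointers: no refcount churn
      std::unordered_set<Node*> visited;
      for (const auto& head : heads) {
        if (visited.insert(head.get()).second) stack.push_back(&head);
      }
      while (!stack.empty()) {
        const std::shared_ptr<Node>* back = stack.back();
        stack.pop_back();
        fvisit(*back);  // visitor gets a const shared_ptr<Node>& it may copy
        for (auto it = (*back)->inputs.rbegin(); it != (*back)->inputs.rend(); ++it) {
          if (visited.insert(it->get()).second) stack.push_back(&*it);
        }
      }
    }

    int main() {
      auto a = std::make_shared<Node>();
      auto b = std::make_shared<Node>();
      a->name = "a"; b->name = "b";
      b->inputs.push_back(a);
      DFSVisit({b}, [](const std::shared_ptr<Node>& n) { std::cout << n->name << '\n'; });
      return 0;
    }

The stored pointers reference shared_ptr objects living inside the heads and inputs vectors, which stay alive and unmodified for the duration of the traversal.
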
From 2cc53b48c61103e1a5875a35659a92c51c2f2bb1 Mon Sep 17 00:00:00 2001
From: linmin
Date: Sun, 23 Aug 2015 21:15:59 +0800
Subject: [PATCH 2/6] arg mismatch code

---
 src/symbol/symbol.cc | 23 +++++++++++++----------
 1 file changed, 13 insertions(+), 10 deletions(-)

diff --git a/src/symbol/symbol.cc b/src/symbol/symbol.cc
index 8bf36d309018..3d3de8dfb697 100644
--- a/src/symbol/symbol.cc
+++ b/src/symbol/symbol.cc
@@ -78,10 +78,9 @@ inline void Symbol::DFSVisit(FVisit fvisit) const {
 
 // helper function to handle keyword argument mismatch
 // throw approperiate messages
-template<typename TMap>
 inline void KeywordArgumentMismatch(const char *source,
-                                    const TMap &kwargs,
-                                    const std::vector<std::string> args) {
+                                    const std::vector<std::string> &user_args,
+                                    const std::vector<std::string> &args) {
   std::unordered_set<std::string> keys(args.begin(), args.end());
   std::ostringstream head, msg;
   msg << "\nCandidate arguments:\n";
@@ -89,10 +88,10 @@ inline void KeywordArgumentMismatch(const char *source,
     msg << "\t[" << i << ']' << args[i] << '\n';
   }
 
-  for (const auto& kv : kwargs) {
-    if (keys.count(kv.first) == 0) {
+  for (const auto& key : user_args) {
+    if (keys.count(key) == 0) {
       LOG(FATAL) << source
-                 << "Keyword argument name " << kv.first << " not found."
+                 << "Keyword argument name " << key << " not found."
                  << msg.str();
     }
   }
@@ -352,8 +351,10 @@ void Symbol::Compose(const std::unordered_map<std::string, Symbol>& kwargs,
     }
   }
   if (nmatched != kwargs.size()) {
-    KeywordArgumentMismatch(
-        "Symbol.Compose", kwargs, ListArguments());
+    std::vector<std::string> keys(kwargs.size());
+    std::transform(kwargs.begin(), kwargs.end(), keys.begin(),
+                   [](decltype(*kwargs.begin())& kv)->std::string { return kv.first; });
+    KeywordArgumentMismatch("Symbol.Compose", keys, ListArguments());
   }
 }
 
@@ -395,8 +396,10 @@ bool Symbol::InferShape(const std::unordered_map<std::string, TShape>& known_arg
     }
   }
   if (nmatched != known_arg_shapes.size()) {
-    KeywordArgumentMismatch(
-        "Symbol.InterShape", known_arg_shapes, ListArguments());
+    std::vector<std::string> keys(known_arg_shapes.size());
+    std::transform(known_arg_shapes.begin(), known_arg_shapes.end(), keys.begin(),
+                   [](decltype(*known_arg_shapes.begin())& kv)->std::string { return kv.first; });
+    KeywordArgumentMismatch("Symbol.InferShape", keys, ListArguments());
   }
   return g.InferShape(arg_shapes, out_shapes);
 }
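
The call-site idiom this patch repeats for Compose and InferShape, as a minimal standalone sketch (a placeholder int value type stands in for Symbol/TShape): flattening the map's keys into a vector of strings lets one non-template KeywordArgumentMismatch overload serve every keyword-taking API.

    #include <algorithm>
    #include <iostream>
    #include <string>
    #include <unordered_map>
    #include <vector>

    int main() {
      std::unordered_map<std::string, int> kwargs = {{"data", 1}, {"weight", 2}};
      std::vector<std::string> keys(kwargs.size());
      std::transform(kwargs.begin(), kwargs.end(), keys.begin(),
                     [](const std::pair<const std::string, int>& kv) { return kv.first; });
      for (const auto& k : keys) std::cout << k << '\n';  // iteration order is unspecified
      return 0;
    }
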
From d8bcd68fc0b575e6ea89d9e8469905d48026ccd8 Mon Sep 17 00:00:00 2001
From: linmin
Date: Sun, 23 Aug 2015 21:18:23 +0800
Subject: [PATCH 3/6] agg arg_grads

---
 include/mxnet/symbolic.h     |  2 +-
 src/symbol/graph_executor.cc | 22 +++++++---------------
 src/symbol/graph_executor.h  |  2 +-
 src/symbol/static_graph.cc   | 14 ++++++++++++--
 4 files changed, 21 insertions(+), 19 deletions(-)

diff --git a/include/mxnet/symbolic.h b/include/mxnet/symbolic.h
index df06c4913de8..8c3a6dfe7426 100644
--- a/include/mxnet/symbolic.h
+++ b/include/mxnet/symbolic.h
@@ -161,7 +161,7 @@ class StaticGraph {
    * \param arg_grads used to store gradients to args, can be multiple one if an argument is used by operator
    */
   void MakeBackwardPass(std::vector<uint32_t> *head_grad_nodes,
-                        std::vector<std::vector<DataEntry> > *arg_grads);
+                        std::vector<DataEntry> *arg_grads);
   /*!
    * \brief create a sum node that aggregates gradient together

diff --git a/src/symbol/graph_executor.cc b/src/symbol/graph_executor.cc
index 8dbadb34e24e..9cd01ff8756c 100644
--- a/src/symbol/graph_executor.cc
+++ b/src/symbol/graph_executor.cc
@@ -252,21 +252,13 @@ void GraphExecutor::InitDataEntryInfo(const std::vector<NArray> &in_args,
     if (grad_req_type[i] == kNullOp) continue;
     CHECK_NE(grad_req_type[i], kWriteInplace)
         << "Gradient request can only be nullop, add, write";
-    std::vector<StaticGraph::DataEntry> &grad_source = arg_grads_[i];
-    CHECK_GE(grad_source.size(), 1);
-    // TODO(bing) add a aggregation node here
-    if (grad_source.size() > 1) {
-      CHECK_EQ(grad_req_type[i], kAddTo)
-          << "The gradient contains multiple variables,";
-    }
-    for (StaticGraph::DataEntry e : grad_source) {
-      DataEntryInfo &info = op_nodes_[e.source_id].outputs[e.index];
-      info.type = kBindByExternal;
-      info.op_req = grad_req_type[i];
-      info.data = arg_grad_store[i];
-      ++info.ref_count;
-      op_nodes_[e.source_id].activated = true;
-    }
+    StaticGraph::DataEntry &grad_source = arg_grads_[i];
+    DataEntryInfo &info = op_nodes_[grad_source.source_id].outputs[grad_source.index];
+    info.type = kBindByExternal;
+    info.op_req = grad_req_type[i];
+    info.data = arg_grad_store[i];
+    ++info.ref_count;
+    op_nodes_[grad_source.source_id].activated = true;
   }
   // setup head gradient
   for (uint32_t nid : head_grad_nodes_) {

diff --git a/src/symbol/graph_executor.h b/src/symbol/graph_executor.h
index d2bc84d0733d..ff4eb19dc410 100644
--- a/src/symbol/graph_executor.h
+++ b/src/symbol/graph_executor.h
@@ -177,7 +177,7 @@ class GraphExecutor : public Executor {
   // head gradient node in the graph, if there is backward pass
   std::vector<uint32_t> head_grad_nodes_;
   // argument node in the graph, if there is backward pass
-  std::vector<std::vector<StaticGraph::DataEntry> > arg_grads_;
+  std::vector<StaticGraph::DataEntry> arg_grads_;
   // operational nodes
   std::vector<OpNode> op_nodes_;
   // head NArrays

diff --git a/src/symbol/static_graph.cc b/src/symbol/static_graph.cc
index c9ed278b8f7e..53213e2aced7 100644
--- a/src/symbol/static_graph.cc
+++ b/src/symbol/static_graph.cc
@@ -165,7 +165,7 @@ StaticGraph::Node StaticGraph::CreateSumNode(
 }
 
 void StaticGraph::MakeBackwardPass(std::vector<uint32_t> *head_grad_nodes,
-                                   std::vector<std::vector<DataEntry> > *arg_grads) {
+                                   std::vector<DataEntry> *arg_grads) {
   arg_grads->clear();
   head_grad_nodes->clear();
   // get topo order of nodes, before new nodes are added
@@ -254,7 +254,17 @@ void StaticGraph::MakeBackwardPass(std::vector<uint32_t> *head_grad_nodes,
     DataEntry odata(arg_nodes[i], 0);
     auto it = grad_map.find(odata);
     CHECK(it != grad_map.end()) << "bad graph";
-    arg_grads->at(i) = it->second;
+    if (it->second.size() == 1) {
+      arg_grads->at(i) = it->second[0];
+    } else {
+      std::ostringstream os_name;
+      Node agg_node = StaticGraph::CreateSumNode(it->second);
+      os_name << nodes[arg_nodes[i]].name << "_grad_agg";
+      agg_node.name = os_name.str();
+      uint32_t agg_node_id = static_cast<uint32_t>(nodes.size());
+      nodes.push_back(std::move(agg_node));
+      arg_grads->at(i) = DataEntry(agg_node_id, 0);
+    }
   }
 }
 }  // namespace mxnet
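
The aggregation rule patch 3 introduces, reduced to a toy sketch (plain structs stand in for StaticGraph::Node and DataEntry; the real CreateSumNode builds an elementwise-sum operator): a single gradient contribution passes through unchanged, while several contributions get an explicit sum node appended to the graph, so the executor always binds exactly one entry per argument and the old kAddTo-only restriction disappears.

    #include <cstdint>
    #include <string>
    #include <vector>

    struct DataEntry { uint32_t source_id; uint32_t index; };
    struct Node { std::string name; std::vector<DataEntry> inputs; };

    // Returns the single entry the executor binds for one argument's gradient.
    DataEntry AggregateGrads(std::vector<Node>* nodes, const std::string& arg_name,
                             const std::vector<DataEntry>& contribs) {
      if (contribs.size() == 1) return contribs[0];  // nothing to aggregate
      Node agg;                                      // stands in for CreateSumNode(contribs)
      agg.name = arg_name + "_grad_agg";
      agg.inputs = contribs;                         // its output is the sum of all contributions
      nodes->push_back(agg);
      return DataEntry{static_cast<uint32_t>(nodes->size() - 1), 0};
    }
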
From 686c61908312a93f09a337748173dbc168d050e8 Mon Sep 17 00:00:00 2001
From: linmin
Date: Sun, 23 Aug 2015 21:42:33 +0800
Subject: [PATCH 4/6] symbolic gradient

---
 include/mxnet/c_api.h        | 13 +++++++++
 include/mxnet/symbolic.h     |  9 +++++-
 python/mxnet/symbol.py       |  8 ++++++
 src/c_api.cc                 | 14 ++++++++-
 src/symbol/graph_executor.cc |  4 +--
 src/symbol/symbol.cc         | 56 ++++++++++++++++++++++++++++++++++--
 6 files changed, 98 insertions(+), 6 deletions(-)

diff --git a/include/mxnet/c_api.h b/include/mxnet/c_api.h
index 1178af9db5fd..5802c32cf75c 100644
--- a/include/mxnet/c_api.h
+++ b/include/mxnet/c_api.h
@@ -347,6 +347,19 @@ MXNET_DLL int MXSymbolCompose(SymbolHandle sym,
                               mx_uint num_args,
                               const char** keys,
                               SymbolHandle* args);
+/*!
+ * \brief Get the gradient graph of the symbol
+ *
+ * \param sym the symbol to get gradient
+ * \param num_wrt number of arguments to get gradient
+ * \param wrt the name of the arguments to get gradient
+ * \param out the returned symbol that has gradient
+ * \return 0 when success, -1 when failure happens
+ */
+MXNET_DLL int MXSymbolGrad(SymbolHandle sym,
+                           mx_uint num_wrt,
+                           const char** wrt,
+                           SymbolHandle* out);
 /*!
  * \brief infer shape of unknown input shapes given the known one.
  *  The shapes are packed into a CSR matrix represented by arg_ind_ptr and arg_shape_data

diff --git a/include/mxnet/symbolic.h b/include/mxnet/symbolic.h
index 8c3a6dfe7426..15b143c87bf7 100644
--- a/include/mxnet/symbolic.h
+++ b/include/mxnet/symbolic.h
@@ -3,7 +3,7 @@
  * \file symbolic.h
  * \brief Symbolic interface of mxnet.
  * \author Min Lin, Bing Xu
-*/
+ */
 #ifndef MXNET_SYMBOLIC_H_
 #define MXNET_SYMBOLIC_H_
 
@@ -254,6 +254,13 @@ class Symbol {
    */
   Symbol operator () (const std::unordered_map<std::string, Symbol>& kwargs,
                       const std::string& name) const;
+  /*!
+   * \brief get the gradient graph
+   * \param wrt with respect to the input
+   * \return the new symbol with gradient graph
+   */
+  Symbol Grad(const std::vector<std::string>& wrt) const;
+
   /*!
    * \brief infer the shapes of outputs and unknown input arguments
    * \param arg_shapes the shape of input arguments of the operator

diff --git a/python/mxnet/symbol.py b/python/mxnet/symbol.py
index 5818d0f53305..0f3acf2afddd 100644
--- a/python/mxnet/symbol.py
+++ b/python/mxnet/symbol.py
@@ -243,6 +243,14 @@ def bind(self, ctx, args, args_grad, reqs):
                                        ctypes.byref(handle)))
         return Executor(handle)
 
+    def grad(self, wrt):
+        handle = SymbolHandle()
+        c_wrt = c_array(ctypes.c_char_p, [c_str(key) for key in wrt])
+        check_call(_LIB.MXSymbolGrad(self.handle,
+                                     mx_uint(len(wrt)),
+                                     c_wrt,
+                                     ctypes.byref(handle)))
+        return Symbol(handle)
 
 def Variable(name):
     """Create a symbolic variable with specified name.
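
A hypothetical session against the new Python binding (the FullyConnected operator, its num_hidden parameter, and the fc1_weight auto-naming convention are assumed for illustration; only grad itself is added by this patch):

    import mxnet as mx

    data = mx.symbol.Variable('data')
    net = mx.symbol.FullyConnected(data=data, name='fc1', num_hidden=10)
    # gnet's outputs are d(net)/d(data) and d(net)/d(fc1_weight)
    gnet = net.grad(['data', 'fc1_weight'])
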
diff --git a/src/c_api.cc b/src/c_api.cc
index a5ed648469e1..b251ba578743 100644
--- a/src/c_api.cc
+++ b/src/c_api.cc
@@ -484,6 +484,19 @@ int MXSymbolCompose(SymbolHandle sym,
   API_END();
 }
 
+int MXSymbolGrad(SymbolHandle sym, mx_uint num_wrt, const char** wrt, SymbolHandle* out) {
+  API_BEGIN();
+  Symbol* s = static_cast<Symbol*>(sym);
+  std::vector<std::string> wrts(num_wrt);
+  for (mx_uint i = 0; i < num_wrt; ++i) {
+    wrts[i] = wrt[i];
+  }
+  Symbol* ret = new Symbol;
+  *ret = s->Grad(wrts);
+  *out = ret;
+  API_END();
+}
+
 int MXSymbolInferShape(SymbolHandle sym,
                        mx_uint num_args,
                        const char** keys,
@@ -596,4 +609,3 @@ int MXExecutorBind(SymbolHandle symbol_handle,
   *out = Executor::Bind(*symb, ctx, in_args_vec, arg_grad_vec, grad_req_vec);
   API_END();
 }
-

diff --git a/src/symbol/graph_executor.cc b/src/symbol/graph_executor.cc
index 9cd01ff8756c..68de552e7f21 100644
--- a/src/symbol/graph_executor.cc
+++ b/src/symbol/graph_executor.cc
@@ -2,7 +2,7 @@
  * Copyright (c) 2015 by Contributors
  * \file graph_executor.cc
  * \brief Executor to execute the Graph.
-*/
+ */
 #include
 #include
 #include
@@ -200,7 +200,7 @@ GraphExecutor::GetOpExecEntry(uint32_t nid) {
 }
 
 void GraphExecutor::InitGraph(Symbol symbol, Context ctx, bool need_backward) {
-  // initialize all internal daa structures
+  // initialize all internal data structures
   symbol.ToStaticGraph(&graph_);
   num_forward_nodes_ = graph_.nodes.size();
   if (need_backward) {
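
Caller-side sketch for the new C entry point (MXSymbolFree and MXGetLastError are assumed here to be the usual mxnet handle destructor and error accessor; neither is part of this patch):

    #include <cstdio>
    #include <mxnet/c_api.h>

    void grad_demo(SymbolHandle net) {
      const char* wrt[] = {"data", "fc1_weight"};
      SymbolHandle dnet;
      if (MXSymbolGrad(net, 2, wrt, &dnet) != 0) {
        std::printf("MXSymbolGrad failed: %s\n", MXGetLastError());
        return;
      }
      // MXSymbolGrad allocates a fresh Symbol (note the `new Symbol` above),
      // so the caller owns the returned handle and must eventually release it.
      MXSymbolFree(dnet);
    }
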
diff --git a/src/symbol/symbol.cc b/src/symbol/symbol.cc
index 3d3de8dfb697..897b22663730 100644
--- a/src/symbol/symbol.cc
+++ b/src/symbol/symbol.cc
@@ -21,6 +21,8 @@ namespace mxnet {
 * - Variable: the sym_ is nullptr, represents an named Variable of tensors that can be composed.
 */
 struct Symbol::Node {
+  /*! \brief source node of the current node */
+  std::shared_ptr<Node> backward_source_node;
   /*! \brief Operator of this node */
   std::unique_ptr op;
   /*! \brief name of the node */
@@ -41,7 +43,7 @@ struct Symbol::Node {
   }
   /*! \return Whether it is unit variable */
   inline bool is_variable() const {
-    return op == nullptr;
+    return op == nullptr && !backward_source_node;
   }
 };
@@ -159,7 +161,13 @@ void Symbol::Print(std::ostream &os) const {
       if (node->is_variable()) {
         os << "Variable:" << node->name << '\n';
       } else {
-        os << "Name: " << node->name << " Type:" << node->op->TypeString() << '\n'
+        std::string type_string;
+        if (!node->backward_source_node) {
+          type_string = node->op->TypeString();
+        } else {
+          type_string = node->backward_source_node->op->TypeString();
+        }
+        os << "Name: " << node->name << " Type:" << type_string << '\n'
            << "Inputs:\n";
         for (size_t i = 0; i < node->inputs.size(); ++i) {
           os << "\targ[" << i << "]=" << node->inputs[i].source->name
@@ -372,6 +380,50 @@ Symbol Symbol::operator () (const std::unordered_map<std::string, Symbol>& kwargs,
                             const std::string& name) const {
   return s;
 }
 
+Symbol Symbol::Grad(const std::vector<std::string>& wrt) const {
+  StaticGraph g;
+  this->ToStaticGraph(&g);
+  uint32_t num_nodes = g.nodes.size();
+  std::vector<uint32_t> head_grad_nodes;
+  std::vector<StaticGraph::DataEntry> arg_grads;
+  g.MakeBackwardPass(&head_grad_nodes, &arg_grads);
+  std::vector<std::shared_ptr<Node> > shared_node;
+  this->DFSVisit([&shared_node](const std::shared_ptr<Node> &n) {
+    shared_node.push_back(n);
+  });
+  for (std::vector<StaticGraph::Node>::const_iterator it = g.nodes.begin() + num_nodes;
+       it != g.nodes.end(); ++it) {
+    auto sym_node = std::make_shared<Symbol::Node>();
+    sym_node->name = it->name;
+    if (it->backward_source_id != -1) {
+      sym_node->backward_source_node = shared_node[it->backward_source_id];
+    }
+    shared_node.push_back(sym_node);
+    for (auto e : it->inputs) {
+      Symbol::DataEntry entry(shared_node[e.source_id], e.index);
+      sym_node->inputs.push_back(std::move(entry));
+    }
+  }
+  // make arg lookup dict
+  auto arg_list = ListArguments();
+  std::unordered_map<std::string, uint32_t> arg_index;
+  for (uint32_t i = 0; i < arg_list.size(); ++i) {
+    arg_index[arg_list[i]] = i;
+  }
+  // generate the heads
+  Symbol ret;
+  for (const std::string& name : wrt) {
+    if (arg_index.find(name) != arg_index.end()) {
+      uint32_t index = arg_index[name];
+      Symbol::DataEntry entry(shared_node[arg_grads[index].source_id], arg_grads[index].index);
+      ret.heads_.push_back(entry);
+    } else {
+      KeywordArgumentMismatch("Symbol.Grad ", wrt, arg_list);
+    }
+  }
+  return ret;
+}
+
 bool Symbol::InferShape(std::vector<TShape> *arg_shapes,
                         std::vector<TShape> *out_shapes) const {
   StaticGraph g;

From 01362f8629fe29f428462b79e4dbd2982067a37d Mon Sep 17 00:00:00 2001
From: linmin
Date: Sun, 23 Aug 2015 21:26:16 +0800
Subject: [PATCH 5/6] compile debug

---
 Makefile       | 9 ++++++++-
 make/config.mk | 3 +++
 2 files changed, 11 insertions(+), 1 deletion(-)

diff --git a/Makefile b/Makefile
index fe906dcd1a1e..e95ee067980f 100644
--- a/Makefile
+++ b/Makefile
@@ -24,7 +24,14 @@ include $(DMLC_CORE)/make/dmlc.mk
 # all tge possible warning tread
 WARNFLAGS= -Wall
 CFLAGS = -DMSHADOW_FORCE_STREAM $(WARNFLAGS)
-CFLAGS += -g -O3 -I./mshadow/ -I./dmlc-core/include -fPIC -Iinclude $(MSHADOW_CFLAGS)
+
+# CFLAGS for debug
+ifeq ($(DEBUG),0)
+	CFLAGS += -O3
+else
+	CFLAGS += -g -O0
+endif
+CFLAGS += -I./mshadow/ -I./dmlc-core/include -fPIC -Iinclude $(MSHADOW_CFLAGS)
 LDFLAGS = -pthread $(MSHADOW_LDFLAGS) $(DMLC_LDFLAGS)
 NVCCFLAGS = --use_fast_math -g -O3 -ccbin $(CXX) $(MSHADOW_NVCCFLAGS)
 ROOTDIR = $(CURDIR)

diff --git a/make/config.mk b/make/config.mk
index a23ff2147612..cd04b146180c 100644
--- a/make/config.mk
+++ b/make/config.mk
@@ -14,6 +14,9 @@ export CC = gcc
 export CXX = g++
 export NVCC = nvcc
 
+# whether compile with debug
+DEBUG = 0
+
 # whether use CUDA during compile
 USE_CUDA = 0
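
Assumed usage of the new switch (set DEBUG in make/config.mk or override it on the command line). Note that NVCCFLAGS above keeps -g -O3, so the switch only affects the C++ flags:

    make DEBUG=1    # debug build: -g -O0
    make            # default optimized build: -O3
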
From 1e0459c1e3223aed3cc18ff454c80732aac39948 Mon Sep 17 00:00:00 2001
From: linmin
Date: Sun, 23 Aug 2015 21:58:29 +0800
Subject: [PATCH 6/6] fix docstring

---
 python/mxnet/symbol.py | 7 +++++++
 1 file changed, 7 insertions(+)

diff --git a/python/mxnet/symbol.py b/python/mxnet/symbol.py
index 0f3acf2afddd..899f78ec22b8 100644
--- a/python/mxnet/symbol.py
+++ b/python/mxnet/symbol.py
@@ -244,6 +244,13 @@ def bind(self, ctx, args, args_grad, reqs):
         return Executor(handle)
 
     def grad(self, wrt):
+        """Get the autodiff of the current symbol.
+
+        Parameters
+        ----------
+        wrt : list of str
+            Names of the arguments with respect to which the gradients are taken.
+        """
         handle = SymbolHandle()
         c_wrt = c_array(ctypes.c_char_p, [c_str(key) for key in wrt])
         check_call(_LIB.MXSymbolGrad(self.handle,