This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

symbolic autodiff #27

Merged (6 commits) on Aug 23, 2015. Changes from all commits:
9 changes: 8 additions & 1 deletion Makefile

@@ -24,7 +24,14 @@ include $(DMLC_CORE)/make/dmlc.mk
 # all the possible warning flags
 WARNFLAGS= -Wall
 CFLAGS = -DMSHADOW_FORCE_STREAM $(WARNFLAGS)
-CFLAGS += -g -O3 -I./mshadow/ -I./dmlc-core/include -fPIC -Iinclude $(MSHADOW_CFLAGS)
+
+# CFLAGS for debug
+ifeq ($(DEBUG),0)
+CFLAGS += -O3
+else
+CFLAGS += -g -O0
+endif
+CFLAGS += -I./mshadow/ -I./dmlc-core/include -fPIC -Iinclude $(MSHADOW_CFLAGS)
 LDFLAGS = -pthread $(MSHADOW_LDFLAGS) $(DMLC_LDFLAGS)
 NVCCFLAGS = --use_fast_math -g -O3 -ccbin $(CXX) $(MSHADOW_NVCCFLAGS)
 ROOTDIR = $(CURDIR)
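With this split, a default build (DEBUG = 0, configurable in make/config.mk below) keeps the optimized -O3 flags, while `make DEBUG=1` switches the C++ build to -g -O0 for debugging. Note that the NVCC flags are untouched and still pass -g -O3 to device code.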
13 changes: 13 additions & 0 deletions include/mxnet/c_api.h

@@ -347,6 +347,19 @@ MXNET_DLL int MXSymbolCompose(SymbolHandle sym,
                               mx_uint num_args,
                               const char** keys,
                               SymbolHandle* args);
+/*!
+ * \brief Get the gradient graph of the symbol.
+ *
+ * \param sym the symbol whose gradient is taken
+ * \param num_wrt number of arguments to take the gradient with respect to
+ * \param wrt names of the arguments to take the gradient with respect to
+ * \param out the returned symbol that holds the gradient graph
+ * \return 0 on success, -1 on failure
+ */
+MXNET_DLL int MXSymbolGrad(SymbolHandle sym,
+                           mx_uint num_wrt,
+                           const char** wrt,
+                           SymbolHandle* out);
 /*!
  * \brief infer shape of unknown input shapes given the known one.
  *  The shapes are packed into a CSR matrix represented by arg_ind_ptr and arg_shape_data
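Like the other symbol-returning entry points, MXSymbolGrad writes a freshly allocated Symbol into *out (see the src/c_api.cc change below), so the caller owns the handle and should release it through the API's usual symbol disposal call (presumably MXSymbolFree). The API_BEGIN()/API_END() macros wrapping the implementation are what convert a thrown C++ exception into the documented -1 failure return.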
11 changes: 9 additions & 2 deletions include/mxnet/symbolic.h

@@ -3,7 +3,7 @@
  * \file symbolic.h
  * \brief Symbolic interface of mxnet.
  * \author Min Lin, Bing Xu
- */
+ */
 #ifndef MXNET_SYMBOLIC_H_
 #define MXNET_SYMBOLIC_H_
 
@@ -161,7 +161,7 @@ class StaticGraph {
   *  \param arg_grads used to store gradients to args, can be multiple one if an argument is used by operator
   */
  void MakeBackwardPass(std::vector<uint32_t> *head_grad_nodes,
-                       std::vector<std::vector<DataEntry> > *arg_grads);
+                       std::vector<DataEntry> *arg_grads);
 
  /*!
   * \brief create a sum node that aggregates gradient together
@@ -254,6 +254,13 @@ class Symbol {
   */
  Symbol operator () (const std::unordered_map<std::string, Symbol>& kwargs,
                      const std::string& name) const;
+ /*!
+  * \brief Get the gradient graph.
+  * \param wrt names of the inputs to take the gradient with respect to
+  * \return a new symbol holding the gradient graph
+  */
+ Symbol Grad(const std::vector<std::string>& wrt) const;
+
 /*!
  * \brief infer the shapes of outputs and unknown input arguments
  *  \param arg_shapes the shape of input arguments of the operator
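Symbol::Grad is the C++ counterpart of the new C entry point: it returns a new Symbol holding the gradient graph with respect to the named inputs. Its implementation is not part of this diff view, but given the StaticGraph changes below it presumably drives MakeBackwardPass under the hood. (The paired `- */` / `+ */` hunks here and in graph_executor.cc appear to be trailing-whitespace fixes; the visible text is otherwise identical.)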
3 changes: 3 additions & 0 deletions make/config.mk

@@ -14,6 +14,9 @@ export CC = gcc
 export CXX = g++
 export NVCC = nvcc
 
+# whether to compile with debug
+DEBUG = 0
+
 # whether use CUDA during compile
 USE_CUDA = 0
15 changes: 15 additions & 0 deletions python/mxnet/symbol.py

@@ -243,6 +243,21 @@ def bind(self, ctx, args, args_grad, reqs):
                                         ctypes.byref(handle)))
         return Executor(handle)
 
+    def grad(self, wrt):
+        """Get the gradient graph of the current symbol via autodiff.
+        Parameters
+        ----------
+        wrt : array of str
+            Names of the arguments with respect to which the gradient is taken.
+        """
+        handle = SymbolHandle()
+        c_wrt = c_array(ctypes.c_char_p, [c_str(key) for key in wrt])
+        check_call(_LIB.MXSymbolGrad(self.handle,
+                                     mx_uint(len(wrt)),
+                                     c_wrt,
+                                     ctypes.byref(handle)))
+        return Symbol(handle)
+
 def Variable(name):
     """Create a symbolic variable with specified name.
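For context, the intended use from the Python frontend looks like this minimal sketch (illustrative only: the FullyConnected operator and the auto-derived fc1_weight argument name are assumptions about the mxnet of that era; only Variable and the new grad method appear in this diff):

import mxnet as mx

# build a small symbolic expression
data = mx.symbol.Variable('data')
net = mx.symbol.FullyConnected(data=data, name='fc1', num_hidden=10)

# take the symbolic gradient with respect to two of the arguments;
# the result is itself a Symbol describing the gradient graph
grad_net = net.grad(['data', 'fc1_weight'])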
14 changes: 13 additions & 1 deletion src/c_api.cc

@@ -484,6 +484,19 @@ int MXSymbolCompose(SymbolHandle sym,
   API_END();
 }
 
+int MXSymbolGrad(SymbolHandle sym, mx_uint num_wrt, const char** wrt, SymbolHandle* out) {
+  API_BEGIN();
+  Symbol* s = static_cast<Symbol*>(sym);
+  std::vector<std::string> wrts(num_wrt);
+  for (mx_uint i = 0; i < num_wrt; ++i) {
+    wrts[i] = wrt[i];
+  }
+  Symbol* ret = new Symbol;
+  *ret = s->Grad(wrts);
+  *out = ret;
+  API_END();
+}
+
 int MXSymbolInferShape(SymbolHandle sym,
                        mx_uint num_args,
                        const char** keys,
@@ -596,4 +609,3 @@ int MXExecutorBind(SymbolHandle symbol_handle,
   *out = Executor::Bind(*symb, ctx, in_args_vec, arg_grad_vec, grad_req_vec);
   API_END();
 }
-
26 changes: 9 additions & 17 deletions src/symbol/graph_executor.cc

@@ -2,7 +2,7 @@
  * Copyright (c) 2015 by Contributors
  * \file graph_executor.cc
  * \brief Executor to execute the Graph.
- */
+ */
 #include <dmlc/logging.h>
 #include <mxnet/symbolic.h>
 #include <memory>
@@ -200,7 +200,7 @@ GraphExecutor::GetOpExecEntry(uint32_t nid) {
 }
 
 void GraphExecutor::InitGraph(Symbol symbol, Context ctx, bool need_backward) {
-  // initialize all internal daa structures
+  // initialize all internal data structures
   symbol.ToStaticGraph(&graph_);
   num_forward_nodes_ = graph_.nodes.size();
   if (need_backward) {
@@ -252,21 +252,13 @@ void GraphExecutor::InitDataEntryInfo(const std::vector<NArray> &in_args,
     if (grad_req_type[i] == kNullOp) continue;
     CHECK_NE(grad_req_type[i], kWriteInplace)
         << "Gradient request can only be nullop, add, write";
-    std::vector<StaticGraph::DataEntry> &grad_source = arg_grads_[i];
-    CHECK_GE(grad_source.size(), 1);
-    // TODO(bing) add a aggregation node here
-    if (grad_source.size() > 1) {
-      CHECK_EQ(grad_req_type[i], kAddTo)
-          << "The gradient contains multiple variables,";
-    }
-    for (StaticGraph::DataEntry e : grad_source) {
-      DataEntryInfo &info = op_nodes_[e.source_id].outputs[e.index];
-      info.type = kBindByExternal;
-      info.op_req = grad_req_type[i];
-      info.data = arg_grad_store[i];
-      ++info.ref_count;
-      op_nodes_[e.source_id].activated = true;
-    }
+    StaticGraph::DataEntry &grad_source = arg_grads_[i];
+    DataEntryInfo &info = op_nodes_[grad_source.source_id].outputs[grad_source.index];
+    info.type = kBindByExternal;
+    info.op_req = grad_req_type[i];
+    info.data = arg_grad_store[i];
+    ++info.ref_count;
+    op_nodes_[grad_source.source_id].activated = true;
   }
   // setup head gradient
   for (uint32_t nid : head_grad_nodes_) {
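This simplification falls out of the static_graph.cc change below: gradients that fan in from several sources are now summed inside the backward graph itself, so by bind time every argument has exactly one gradient entry, and the old kAddTo-only restriction on multi-source gradients is gone (along with the TODO it carried).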
2 changes: 1 addition & 1 deletion src/symbol/graph_executor.h

@@ -177,7 +177,7 @@ class GraphExecutor : public Executor {
   // head gradient node in the graph, if there is backward pass
   std::vector<uint32_t> head_grad_nodes_;
   // argument node in the graph, if there is backward pass
-  std::vector<std::vector<StaticGraph::DataEntry> > arg_grads_;
+  std::vector<StaticGraph::DataEntry> arg_grads_;
   // operational nodes
   std::vector<OpNode> op_nodes_;
   // head NArrays
14 changes: 12 additions & 2 deletions src/symbol/static_graph.cc

@@ -165,7 +165,7 @@ StaticGraph::Node StaticGraph::CreateSumNode(
 }
 
 void StaticGraph::MakeBackwardPass(std::vector<uint32_t> *head_grad_nodes,
-                                   std::vector<std::vector<DataEntry> > *arg_grads) {
+                                   std::vector<DataEntry> *arg_grads) {
   arg_grads->clear();
   head_grad_nodes->clear();
   // get topo order of nodes, before new nodes are added
@@ -254,7 +254,17 @@ void StaticGraph::MakeBackwardPass(std::vector<uint32_t> *head_grad_nodes,
     DataEntry odata(arg_nodes[i], 0);
     auto it = grad_map.find(odata);
     CHECK(it != grad_map.end()) << "bad graph";
-    arg_grads->at(i) = it->second;
+    if (it->second.size() == 1) {
+      arg_grads->at(i) = it->second[0];
+    } else {
+      std::ostringstream os_name;
+      Node agg_node = StaticGraph::CreateSumNode(it->second);
+      os_name << nodes[arg_nodes[i]].name << "_grad_agg";
+      agg_node.name = os_name.str();
+      uint32_t agg_node_id = static_cast<uint32_t>(nodes.size());
+      nodes.push_back(std::move(agg_node));
+      arg_grads->at(i) = DataEntry(agg_node_id, 0);
+    }
   }
 }
 } // namespace mxnet
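The effect of the aggregation branch is easiest to see from Python (an illustrative sketch; it assumes symbols support an elementwise add so that the same input can be consumed twice):

import mxnet as mx

a = mx.symbol.Variable('a')
b = a + a          # 'a' feeds two inputs of the same graph
g = b.grad(['a'])  # the backward pass hits the multi-source branch

# rather than returning two gradient entries for 'a', MakeBackwardPass now
# inserts a single sum node, named '<arg>_grad_agg' ('a_grad_agg' here),
# and arg_grads holds one DataEntry pointing at its output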