Skip to content

Commit d26e7c0

Browse files
committed
[PASS] Add gradient pass (#28)
1 parent 112b078 commit d26e7c0

File tree

13 files changed

+357
-20
lines changed

13 files changed

+357
-20
lines changed

nnvm/example/src/operator.cc

Lines changed: 69 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,10 @@ using nnvm::FMutateInputs;
1515
using nnvm::FInferShape;
1616
using nnvm::FInferType;
1717
using nnvm::FInplaceOption;
18+
using nnvm::Node;
19+
using nnvm::NodePtr;
20+
using nnvm::NodeEntry;
21+
using nnvm::FGradient;
1822
using nnvm::NodeAttrs;
1923
using nnvm::TShape;
2024
using nnvm::array_view;
@@ -37,6 +41,17 @@ inline std::vector<std::pair<int, int> > InplaceIn0Out0(const NodeAttrs& attrs)
3741
return {{0, 0}};
3842
}
3943

44+
// quick helper to make node
45+
inline NodeEntry MakeNode(const char* op_name,
46+
std::string node_name,
47+
std::vector<NodeEntry> inputs) {
48+
NodePtr p = Node::Create();
49+
p->op = nnvm::Op::Get(op_name);
50+
p->attrs.name = std::move(node_name);
51+
p->inputs = std::move(inputs);
52+
return NodeEntry{p, 0, 0};
53+
}
54+
4055
// simple demonstration of reshape.
4156
NNVM_REGISTER_OP(reshape)
4257
.describe("reshape source to target shape")
@@ -84,21 +99,67 @@ NNVM_REGISTER_OP(cast)
8499
return true;
85100
});
86101

102+
NNVM_REGISTER_OP(exp)
103+
.describe("take exponential")
104+
.set_num_inputs(1)
105+
.attr<FInferShape>("FInferShape", SameShape)
106+
.attr<FGradient>(
107+
"FGradient", [](const NodePtr& n,
108+
const std::vector<NodeEntry>& ograds) {
109+
return std::vector<NodeEntry>{
110+
MakeNode("mul", n->attrs.name + "_grad",
111+
{ograds[0], NodeEntry{n, 0, 0}})
112+
};
113+
});
114+
115+
NNVM_REGISTER_OP(identity)
116+
.describe("identity function")
117+
.set_num_inputs(1)
118+
.attr<FInferShape>("FInferShape", SameShape)
119+
.attr<FGradient>(
120+
"FGradient", [](const NodePtr& n,
121+
const std::vector<NodeEntry>& ograds) {
122+
return std::vector<NodeEntry>{ograds[0]};
123+
});
87124

88125
NNVM_REGISTER_OP(add)
89126
.describe("add two data together")
90127
.set_num_inputs(2)
91128
.attr<FInferShape>("FInferShape", SameShape)
92-
.attr<FInplaceOption>("FInplaceOption", InplaceIn0Out0);
129+
.attr<FInplaceOption>("FInplaceOption", InplaceIn0Out0)
130+
.attr<FGradient>(
131+
"FGradient", [](const NodePtr& n,
132+
const std::vector<NodeEntry>& ograds){
133+
return std::vector<NodeEntry>{ograds[0], ograds[0]};
134+
});
93135

94-
NNVM_REGISTER_OP(__add_symbol__)
95-
.describe("Alias of add")
96-
.set_num_inputs(2);
136+
NNVM_REGISTER_OP(mul)
137+
.describe("multiply two data together")
138+
.set_num_inputs(2)
139+
.attr<FInferShape>("FInferShape", SameShape)
140+
.attr<FInplaceOption>("FInplaceOption", InplaceIn0Out0)
141+
.attr<FGradient>(
142+
"FGradient", [](const NodePtr& n,
143+
const std::vector<NodeEntry>& ograds){
144+
return std::vector<NodeEntry>{
145+
MakeNode("mul", n->attrs.name + "_grad_0",
146+
{ograds[0], n->inputs[1]}),
147+
MakeNode("mul", n->attrs.name + "_grad_1",
148+
{ograds[0], n->inputs[0]})
149+
};
150+
});
97151

98-
NNVM_REGISTER_OP(exp)
99-
.describe("take exponential")
100-
.set_num_inputs(1)
101-
.attr<FInferShape>("FInferShape", SameShape);
152+
NNVM_REGISTER_OP(__ewise_sum__)
153+
.describe("elementwise sum")
154+
.set_num_inputs(nnvm::kVarg);
155+
156+
NNVM_REGISTER_OP(__zero__)
157+
.describe("set output to zero")
158+
.set_num_inputs(0);
159+
160+
NNVM_REGISTER_OP(__one__)
161+
.describe("set output to one")
162+
.set_num_inputs(0);
102163

103164
NNVM_REGISTER_OP(cross_device_copy)
104165
.describe("Copy data across device.")

nnvm/include/dmlc/base.h

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,11 @@
5858
__cplusplus >= 201103L || defined(_MSC_VER))
5959
#endif
6060

61+
/*! \brief strict CXX11 support */
62+
#ifndef DMLC_STRICT_CXX11
63+
#define DMLC_STRICT_CXX11 (__cplusplus >= 201103L || defined(_MSC_VER))
64+
#endif
65+
6166
/// check if g++ is before 4.6
6267
#if DMLC_USE_CXX11 && defined(__GNUC__) && !defined(__clang_version__)
6368
#if __GNUC__ == 4 && __GNUC_MINOR__ < 6
@@ -69,6 +74,7 @@
6974
#endif
7075
#endif
7176

77+
7278
/*!
7379
* \brief Enable std::thread related modules,
7480
* Used to disable some module in mingw compile.
@@ -82,6 +88,13 @@
8288
#define DMLC_USE_REGEX (__cplusplus >= 201103L || defined(_MSC_VER))
8389
#endif
8490

91+
/*! \brief helper macro to supress unused warning */
92+
#if defined(__GNUC__)
93+
#define DMLC_ATTRIBUTE_UNUSED __attribute__((unused))
94+
#else
95+
#define DMLC_ATTRIBUTE_UNUSED
96+
#endif
97+
8598
/*! \brief helper macro to generate string concat */
8699
#define DMLC_STR_CONCAT_(__x, __y) __x##__y
87100
#define DMLC_STR_CONCAT(__x, __y) DMLC_STR_CONCAT_(__x, __y)

nnvm/include/dmlc/json.h

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,9 @@
2525
#include <typeindex>
2626
#include <typeinfo>
2727
#include <unordered_map>
28+
#if DMLC_STRICT_CXX11
2829
#include "./any.h"
30+
#endif // DMLC_STRICT_CXX11
2931
#endif // DMLC_USE_CXX11
3032

3133
namespace dmlc {
@@ -320,7 +322,8 @@ class JSONObjectReadHelper {
320322
};
321323

322324
#define DMLC_JSON_ENABLE_ANY_VAR_DEF(KeyName) \
323-
static ::dmlc::json::AnyJSONManager& __make_AnyJSONType ## _ ## KeyName ## __
325+
static DMLC_ATTRIBUTE_UNUSED ::dmlc::json::AnyJSONManager& \
326+
__make_AnyJSONType ## _ ## KeyName ## __
324327

325328
/*!
326329
* \def DMLC_JSON_ENABLE_ANY
@@ -475,7 +478,7 @@ struct Handler {
475478
}
476479
};
477480

478-
#if DMLC_USE_CXX11
481+
#if DMLC_STRICT_CXX11
479482
// Manager to store json serialization strategy.
480483
class AnyJSONManager {
481484
public:
@@ -561,7 +564,7 @@ struct Handler<any> {
561564
CHECK(!reader->NextArrayItem()) << "invalid any json format";
562565
}
563566
};
564-
#endif // DMLC_USE_CXX11
567+
#endif // DMLC_STRICT_CXX11
565568

566569
} // namespace json
567570

nnvm/include/dmlc/parameter.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -251,7 +251,8 @@ struct Parameter {
251251
static ::dmlc::parameter::ParamManagerSingleton<PType> inst(#PType); \
252252
return &inst.manager; \
253253
} \
254-
static ::dmlc::parameter::ParamManager &__make__ ## PType ## ParamManager__ = \
254+
static DMLC_ATTRIBUTE_UNUSED ::dmlc::parameter::ParamManager& \
255+
__make__ ## PType ## ParamManager__ = \
255256
(*PType::__MANAGER__()) \
256257

257258
//! \endcond

nnvm/include/dmlc/registry.h

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -216,7 +216,7 @@ class FunctionRegEntryBase {
216216
* \sa FactoryRegistryEntryBase
217217
*/
218218
#define DMLC_REGISTRY_REGISTER(EntryType, EntryTypeName, Name) \
219-
static EntryType & __make_ ## EntryTypeName ## _ ## Name ## __ = \
219+
static DMLC_ATTRIBUTE_UNUSED EntryType & __make_ ## EntryTypeName ## _ ## Name ## __ = \
220220
::dmlc::Registry<EntryType>::Get()->__REGISTER__(#Name) \
221221

222222
/*!
@@ -272,6 +272,7 @@ class FunctionRegEntryBase {
272272
*/
273273
#define DMLC_REGISTRY_LINK_TAG(UniqueTag) \
274274
int __dmlc_registry_file_tag_ ## UniqueTag ## __(); \
275-
static int __reg_file_tag_ ## UniqueTag ## __ = __dmlc_registry_file_tag_ ## UniqueTag ## __();
275+
static int DMLC_ATTRIBUTE_UNUSED __reg_file_tag_ ## UniqueTag ## __ = \
276+
__dmlc_registry_file_tag_ ## UniqueTag ## __();
276277
} // namespace dmlc
277278
#endif // DMLC_REGISTRY_H_

nnvm/include/nnvm/c_api.h

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -260,6 +260,7 @@ NNVM_DLL int NNGraphFree(GraphHandle handle);
260260
* \return 0 when success, -1 when failure happens
261261
*/
262262
NNVM_DLL int NNGraphGetSymbol(GraphHandle graph, SymbolHandle *symbol);
263+
263264
/*!
264265
* \brief Get Set a attribute in json format.
265266
* This feature allows pass graph attributes back and forth in reasonable speed.
@@ -273,6 +274,7 @@ NNVM_DLL int NNGraphGetSymbol(GraphHandle graph, SymbolHandle *symbol);
273274
NNVM_DLL int NNGraphSetJSONAttr(GraphHandle handle,
274275
const char* key,
275276
const char* json_value);
277+
276278
/*!
277279
* \brief Get a serialized attrirbute from graph.
278280
* This feature allows pass graph attributes back and forth in reasonable speed.
@@ -289,6 +291,21 @@ NNVM_DLL int NNGraphGetJSONAttr(SymbolHandle handle,
289291
const char* key,
290292
const char** json_out,
291293
int *success);
294+
295+
/*!
296+
* \brief Set a attribute whose type is std::vector<NodeEntry> in c++
297+
* This feature allows pass List of symbolic variables for gradient request.
298+
*
299+
* \note This is beta feature only used for test purpos
300+
*
301+
* \param handle The graph handle.
302+
* \param key The key to the attribute.
303+
* \param list The symbol whose outputs represents the list of NodeEntry to be passed.
304+
* \return 0 when success, -1 when failure happens
305+
*/
306+
NNVM_DLL int NNGraphSetNodeEntryListAttr_(GraphHandle handle,
307+
const char* key,
308+
SymbolHandle list);
292309
/*!
293310
* \brief Apply pass on the src graph.
294311
* \param src The source graph handle.

nnvm/include/nnvm/op.h

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -279,10 +279,8 @@ class OpMap {
279279
};
280280

281281
// internal macros to make
282-
#define NNVM_STR_CONCAT_(__x, __y) __x##__y
283-
#define NNVM_STR_CONCAT(__x, __y) NNVM_STR_CONCAT_(__x, __y)
284282
#define NNVM_REGISTER_VAR_DEF(OpName) \
285-
static ::nnvm::Op & __make_ ## NnvmOp ## _ ## OpName
283+
static DMLC_ATTRIBUTE_UNUSED ::nnvm::Op & __make_ ## NnvmOp ## _ ## OpName
286284

287285
/*!
288286
* \def NNVM_REGISTER_OP
@@ -300,7 +298,7 @@ class OpMap {
300298
* \endcode
301299
*/
302300
#define NNVM_REGISTER_OP(OpName) \
303-
NNVM_STR_CONCAT(NNVM_REGISTER_VAR_DEF(OpName), __COUNTER__) = \
301+
DMLC_STR_CONCAT(NNVM_REGISTER_VAR_DEF(OpName), __COUNTER__) = \
304302
::dmlc::Registry<::nnvm::Op>::Get()->__REGISTER_OR_GET__(#OpName)
305303

306304
// implementations of template functions after this.

nnvm/include/nnvm/op_attr_types.h

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
#include <utility>
1212
#include <functional>
1313
#include "./base.h"
14+
#include "./node.h"
1415
#include "./tuple.h"
1516

1617
namespace nnvm {
@@ -107,6 +108,19 @@ using TIsBackwardOp = bool;
107108
using FInplaceOption = std::function<
108109
std::vector<std::pair<int, int> > (const NodeAttrs& attrs)>;
109110

111+
/*!
112+
* \brief Get the gradient node of the op node
113+
* This function generates the backward graph of the node
114+
* \param nodeptr The node to take gradient
115+
* \param out_grads Gradient of current node's outputs
116+
* \return gradients of the inputs
117+
*
118+
* \note Register under "FGradient"
119+
*/
120+
using FGradient = std::function<std::vector<NodeEntry>(
121+
const NodePtr& nodeptr,
122+
const std::vector<NodeEntry>& out_grads)>;
123+
110124
} // namespace nnvm
111125

112126
#endif // NNVM_OP_ATTR_TYPES_H_

nnvm/include/nnvm/pass_functions.h

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,37 @@ inline Graph PlaceDevice(Graph graph,
109109
return ApplyPass(std::move(graph), {"PlaceDevice"});
110110
}
111111

112+
/*!
113+
* \brief Get the gradient graph whose outputs are gradients of xs wrt to ys.
114+
* \param graph source graph
115+
* \param ys The entries we want to take gradient from.
116+
* \param xs The input to take gradient with respect to.
117+
* \param ys_out_grad The symbol for additional gradient to be propagate back to y.
118+
* \param aggregate_fun aggregation function applied to aggregate the inputs
119+
* \param mirror_fun Optional mirror function to do mirror optimization and save memory.
120+
* \return A new graph, whose outputs corresponds to inputs of xs.
121+
*/
122+
inline Graph Gradient(
123+
Graph graph,
124+
std::vector<NodeEntry> ys,
125+
std::vector<NodeEntry> xs,
126+
std::vector<NodeEntry> ys_out_grad,
127+
std::function<NodeEntry(std::vector<NodeEntry>&& inputs)> aggregate_fun = nullptr,
128+
std::function<int(const Node& node)> mirror_fun = nullptr) {
129+
graph.attrs["grad_ys"] = std::make_shared<any>(std::move(ys));
130+
131+
graph.attrs["grad_xs"] = std::make_shared<any>(std::move(xs));
132+
graph.attrs["grad_ys_out_grad"] = std::make_shared<any>(std::move(ys_out_grad));
133+
if (aggregate_fun != nullptr) {
134+
graph.attrs["grad_aggregate_fun"] = std::make_shared<any>(aggregate_fun);
135+
}
136+
if (mirror_fun != nullptr) {
137+
graph.attrs["grad_mirror_fun"] = std::make_shared<any>(mirror_fun);
138+
}
139+
140+
return ApplyPass(std::move(graph), {"Gradient"});
141+
}
142+
112143
} // namespace pass
113144
} // namespace nnvm
114145
#endif // NNVM_PASS_FUNCTIONS_H_

nnvm/python/nnvm/graph.py

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from ._base import c_array, c_str, nn_uint, py_str, string_types
1111
from ._base import GraphHandle, SymbolHandle
1212
from ._base import check_call
13-
from .symbol import Symbol
13+
from .symbol import Symbol, Group as _Group
1414

1515

1616
class Graph(object):
@@ -56,8 +56,27 @@ def json_attr(self, key):
5656
else:
5757
return None
5858

59+
def _set_symbol_list_attr(self, key, value):
60+
"""Set the attribute of the graph.
61+
62+
Parameters
63+
----------
64+
key : string
65+
The key of the attribute
66+
value : value
67+
The any type that can be dumped to json
68+
type_name : string
69+
The typename registered on c++ side.
70+
"""
71+
if isinstance(value, list):
72+
value = _Group(value)
73+
if not isinstance(value, Symbol):
74+
raise ValueError("value need to be grouped symbol")
75+
check_call(_LIB.NNGraphSetNodeEntryListAttr_(
76+
self.handle, c_str(key), value.handle))
77+
5978
def _set_json_attr(self, key, value, type_name=None):
60-
"""Set the attribute of the symbol.
79+
"""Set the attribute of the graph.
6180
6281
Parameters
6382
----------

0 commit comments

Comments
 (0)