Skip to content

Commit 869a953

Browse files
committed
[OP] Enable register via match tag (apache#57)
* [OP] Enable register via match tag * more docs on usage
1 parent fa5c588 commit 869a953

File tree

3 files changed

+187
-37
lines changed

3 files changed

+187
-37
lines changed

nnvm/example/src/operator.cc

Lines changed: 22 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,7 @@ NNVM_REGISTER_OP(reshape)
8484
NNVM_REGISTER_OP(cast)
8585
.describe("cast source type to target")
8686
.set_num_inputs(1)
87+
.include("ElementwiseOpAttr")
8788
.set_attr_parser(
8889
[](NodeAttrs* attrs) {
8990
// parse attr parser to get target attribute
@@ -92,7 +93,6 @@ NNVM_REGISTER_OP(cast)
9293
CHECK(is >> dtype);
9394
attrs->parsed = std::move(dtype);
9495
})
95-
.set_attr<FInferShape>("FInferShape", SameShape)
9696
.set_attr<FInferType>(
9797
"FInferType", [](const NodeAttrs& attrs,
9898
std::vector<int> *itype,
@@ -101,23 +101,10 @@ NNVM_REGISTER_OP(cast)
101101
return true;
102102
});
103103

104-
NNVM_REGISTER_OP(exp)
105-
.describe("take exponential")
106-
.set_num_inputs(1)
107-
.set_attr<FInferShape>("FInferShape", SameShape)
108-
.set_attr<FGradient>(
109-
"FGradient", [](const NodePtr& n,
110-
const std::vector<NodeEntry>& ograds) {
111-
return std::vector<NodeEntry>{
112-
MakeNode("mul", n->attrs.name + "_grad",
113-
{ograds[0], NodeEntry{n, 0, 0}})
114-
};
115-
});
116-
117104
NNVM_REGISTER_OP(identity)
118105
.describe("identity function")
119106
.set_num_inputs(1)
120-
.set_attr<FInferShape>("FInferShape", SameShape)
107+
.include("ElementwiseOpAttr")
121108
.set_attr<FGradient>(
122109
"FGradient", [](const NodePtr& n,
123110
const std::vector<NodeEntry>& ograds) {
@@ -128,7 +115,7 @@ NNVM_REGISTER_OP(add)
128115
.describe("add two data together")
129116
.set_num_inputs(2)
130117
.add_alias("__add_symbol__")
131-
.set_attr<FInferShape>("FInferShape", SameShape)
118+
.include("ElementwiseOpAttr")
132119
.set_attr<FInplaceOption>("FInplaceOption", InplaceIn0Out0)
133120
.set_attr<FGradient>(
134121
"FGradient", [](const NodePtr& n,
@@ -139,6 +126,7 @@ NNVM_REGISTER_OP(add)
139126
NNVM_REGISTER_OP(mul)
140127
.describe("multiply two data together")
141128
.set_num_inputs(2)
129+
.include("ElementwiseOpAttr")
142130
.set_attr<FInferShape>("FInferShape", SameShape)
143131
.set_attr<FInplaceOption>("FInplaceOption", InplaceIn0Out0)
144132
.set_attr<FGradient>(
@@ -187,4 +175,22 @@ NNVM_REGISTER_OP(assign)
187175
return std::vector<uint32_t>{0};
188176
});
189177

178+
NNVM_REGISTER_OP_GROUP(ElementwiseOpAttr)
179+
.set_attr<FInferShape>("FInferShape", SameShape);
180+
181+
182+
NNVM_REGISTER_OP(exp)
183+
.describe("take exponential")
184+
.set_num_inputs(1)
185+
.include("ElementwiseOpAttr")
186+
.set_attr<FGradient>(
187+
"FGradient", [](const NodePtr& n,
188+
const std::vector<NodeEntry>& ograds) {
189+
return std::vector<NodeEntry>{
190+
MakeNode("mul", n->attrs.name + "_grad",
191+
{ograds[0], NodeEntry{n, 0, 0}})
192+
};
193+
});
194+
195+
190196
} // namespace myproject

nnvm/include/nnvm/op.h

Lines changed: 125 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ class Node;
2222
struct NodeAttrs;
2323
template<typename ValueType>
2424
class OpMap;
25+
class OpGroup;
2526
class OpRegistryEntry;
2627
using dmlc::ParamFieldInfo;
2728

@@ -44,7 +45,13 @@ static const uint32_t kVarg = std::numeric_limits<uint32_t>::max();
4445
* NNVM_REGISTER_OP(add)
4546
* .describe("add two inputs together")
4647
* .set_num_inputs(2)
47-
* .set_attr<OpKernel>("gpu_kernel", AddKernel);
48+
* .set_attr<OpKernel>("OpKernel<gpu>", AddKernel)
49+
* .include("ElementwiseOpAttr");
50+
*
51+
* // can register attribute by group
52+
* // all the ops that include the group get the attribute.
53+
* NNVM_REGISTER_OP_GROUP(ElementwiseOpAttr)
54+
* .set_attr<FInferShape>("FInferShape", ElementwiseInferShape);
4855
*
4956
* NNVM_REGISTER_OP(sub)
5057
* .describe("substract one tensor from another")
@@ -53,7 +60,8 @@ static const uint32_t kVarg = std::numeric_limits<uint32_t>::max();
5360
* // Can call regster multiple times in different files
5461
* // to register different part of information
5562
* NNVM_REGISTER_OP(sub)
56-
* .set_attr<OpKernel>("gpu_kernel", SubKernel);
63+
* .set_attr<OpKernel>("OpKernel<gpu>", SubKernel);
64+
* .include("ElementwiseOpAttr");
5765
*
5866
* // get operators from registry.
5967
* void my_function() {
@@ -65,7 +73,7 @@ static const uint32_t kVarg = std::numeric_limits<uint32_t>::max();
6573
*
6674
* // get additional registered information,
6775
* // Assume user registered a OpKernel type attribute as gpu_kernel on each operator.
68-
* const OpMap<OpKernel>& kernel = Op::GetAttr<OpKernel>("gpu_kernel");
76+
* const OpMap<OpKernel>& kernel = Op::GetAttr<OpKernel>("OpKernel<gpu>");
6977
* // we can get the kernel functions by using operator as key.
7078
* auto add_kernel = kernel[add];
7179
* auto sub_kernel = kernel[sub];
@@ -199,6 +207,23 @@ class Op {
199207
* \return reference to self.
200208
*/
201209
inline Op& set_attr_parser(std::function<void (NodeAttrs* attrs)> fn); // NOLINT(*)
210+
/*!
211+
* \brief Register additional attributes to operator.
212+
* \param attr_name The name of the attribute.
213+
* \param value The value to be set.
214+
* \param plevel The priority level of this set,
215+
* an higher priority level attribute
216+
* will replace lower priority level attribute.
217+
* Must be bigger than 0.
218+
*
219+
* Cannot set with same plevel twice in the code.
220+
*
221+
* \tparam ValueType The type of the value to be set.
222+
*/
223+
template<typename ValueType>
224+
inline Op& set_attr(const std::string& attr_name, // NOLINT(*)
225+
const ValueType& value,
226+
int plevel = 10);
202227
/*!
203228
* \brief Add another alias to this operator.
204229
* The same Op can be queried with Op::Get(alias)
@@ -207,14 +232,13 @@ class Op {
207232
*/
208233
Op& add_alias(const std::string& alias); // NOLINT(*)
209234
/*!
210-
* \brief Register additional attributes to operator.
211-
* \param attr_name The name of the attribute.
212-
* \param value The value to be set.
213-
* \tparam ValueType The type of the value to be set.
235+
* \brief Include all the attributes from an registered op group.
236+
* \param group_name The name of the group.
237+
* \return reference to self.
238+
*
239+
* \sa NNVM_REGISTER_OP_GROUP
214240
*/
215-
template<typename ValueType>
216-
inline Op& set_attr(const std::string& attr_name, // NOLINT(*)
217-
const ValueType& value);
241+
Op& include(const std::string& group_name);
218242
/*!
219243
* \brief Get an Op for a given operator name.
220244
* Will raise an error if the op has not been registered.
@@ -235,6 +259,7 @@ class Op {
235259
private:
236260
template<typename ValueType>
237261
friend class OpMap;
262+
friend class OpGroup;
238263
friend class dmlc::Registry<Op>;
239264
// Program internal unique index of operator.
240265
// Used to help index the program.
@@ -246,6 +271,13 @@ class Op {
246271
// update the attribute OpMap
247272
static void UpdateAttrMap(const std::string& key,
248273
std::function<void(any*)> updater);
274+
// add a trigger based on tag matching on certain tag attribute
275+
// This will apply trigger on all the op such that
276+
// include the corresponding group.
277+
// The trigger will also be applied to all future registrations
278+
// that calls include
279+
static void AddGroupTrigger(const std::string& group_name,
280+
std::function<void(Op*)> trigger);
249281
};
250282

251283
/*!
@@ -285,14 +317,44 @@ class OpMap {
285317
OpMap() = default;
286318
};
287319

320+
/*!
321+
* \brief auxiliary data structure used to
322+
* set attributes to a group of operators
323+
*/
324+
class OpGroup {
325+
public:
326+
/*! \brief the tag key to be matched */
327+
std::string group_name;
328+
/*!
329+
* \brief Register additional attributes to operator group.
330+
* \param attr_name The name of the attribute.
331+
* \param value The value to be set.
332+
* \param plevel The priority level of this set,
333+
* an higher priority level attribute
334+
* will replace lower priority level attribute.
335+
* Must be bigger than 0.
336+
*
337+
* Cannot set with same plevel twice in the code.
338+
*
339+
* \tparam ValueType The type of the value to be set.
340+
*/
341+
template<typename ValueType>
342+
inline OpGroup& set_attr(const std::string& attr_name, // NOLINT(*)
343+
const ValueType& value,
344+
int plevel = 1);
345+
};
346+
288347
// internal macros to make
289-
#define NNVM_REGISTER_VAR_DEF(OpName) \
348+
#define NNVM_REGISTER_VAR_DEF(OpName) \
290349
static DMLC_ATTRIBUTE_UNUSED ::nnvm::Op & __make_ ## NnvmOp ## _ ## OpName
291350

351+
#define NNVM_REGISTER_GVAR_DEF(TagName) \
352+
static DMLC_ATTRIBUTE_UNUSED ::nnvm::OpGroup __make_ ## NnvmOpGroup ## _ ## TagName
353+
292354
/*!
293355
* \def NNVM_REGISTER_OP
294-
* \brief Register
295-
* This macro must be used under namespace dmlc, and only used once in cc file.
356+
* \brief Register a new operator, or set attribute of the corresponding op.
357+
*
296358
* \param OpName The name of registry
297359
*
298360
* \code
@@ -308,6 +370,31 @@ class OpMap {
308370
DMLC_STR_CONCAT(NNVM_REGISTER_VAR_DEF(OpName), __COUNTER__) = \
309371
::dmlc::Registry<::nnvm::Op>::Get()->__REGISTER_OR_GET__(#OpName)
310372

373+
/*!
374+
* \def NNVM_REGISTER_OP_GROUP
375+
* \brief Register attribute to a group of operators.
376+
* These attributes will be registered to Op that include the group.
377+
*
378+
* \param GroupName The name of the group.
379+
*
380+
* \code
381+
*
382+
* NNVM_REGISTER_OP(add)
383+
* .include("ElementwiseOpAttr");
384+
*
385+
* // register same attributes to all the ops that include the group
386+
* NNVM_REGISTER_OP_GROUP(ElementwiseOpAttr)
387+
* .set_attr<FInferShape>("FInferShape", ElementwiseInferShape);
388+
*
389+
* NNVM_REGISTER_OP(mul)
390+
* .include("ElementwiseOpAttr");
391+
*
392+
* \endcode
393+
*/
394+
#define NNVM_REGISTER_OP_GROUP(GroupName) \
395+
DMLC_STR_CONCAT(NNVM_REGISTER_GVAR_DEF(GroupName), __COUNTER__) = \
396+
::nnvm::OpGroup {#GroupName}
397+
311398
// implementations of template functions after this.
312399
// member function of Op
313400
template<typename ValueType>
@@ -330,9 +417,14 @@ inline const OpMap<ValueType>& Op::GetAttr(const std::string& key) {
330417

331418
template<typename ValueType>
332419
inline Op& Op::set_attr( // NOLINT(*)
333-
const std::string& attr_name, const ValueType& value) {
420+
const std::string& attr_name,
421+
const ValueType& value,
422+
int plevel) {
423+
CHECK_GT(plevel, 0)
424+
<< "plevel in set_attr must be greater than 0";
334425
// update the attribute map of the key by creating new empty if needed.
335-
UpdateAttrMap(attr_name, [this, attr_name, value](any* pmap) {
426+
UpdateAttrMap(attr_name,
427+
[this, attr_name, value, plevel](any* pmap) {
336428
// the callback is in lockscope so is threadsafe.
337429
if (pmap->empty()) {
338430
OpMap<ValueType> pm;
@@ -353,15 +445,18 @@ inline Op& Op::set_attr( // NOLINT(*)
353445
std::make_pair(ValueType(), 0));
354446
}
355447
std::pair<ValueType, int>& p = vec[index_];
356-
CHECK(p.second == 0)
448+
CHECK(p.second != plevel)
357449
<< "Attribute " << attr_name
358450
<< " of operator " << this->name
359-
<< " is already registered.";
360-
vec[index_] = std::make_pair(value, 1);
451+
<< " is already registered with same plevel=" << plevel;
452+
if (p.second < plevel) {
453+
vec[index_] = std::make_pair(value, plevel);
454+
}
361455
});
362456
return *this;
363457
}
364458

459+
365460
inline Op& Op::describe(const std::string& descr) { // NOLINT(*)
366461
this->description = descr;
367462
return *this;
@@ -409,7 +504,7 @@ template<typename ValueType>
409504
inline int OpMap<ValueType>::count(const Op* op) const {
410505
if (op == nullptr) return 0;
411506
const uint32_t idx = op->index_;
412-
return idx < data_.size() ? data_[idx].second : 0;
507+
return idx < data_.size() ? (data_[idx].second != 0) : 0;
413508
}
414509

415510
template<typename ValueType>
@@ -433,6 +528,17 @@ inline const ValueType& OpMap<ValueType>::get(const Op* op, const ValueType& def
433528
}
434529
}
435530

531+
template<typename ValueType>
532+
inline OpGroup& OpGroup::set_attr(const std::string& attr_name,
533+
const ValueType& value,
534+
int plevel) {
535+
auto trigger = [attr_name, value, plevel](Op* op) {
536+
op->set_attr<ValueType>(attr_name, value, plevel);
537+
};
538+
Op::AddGroupTrigger(group_name, trigger);
539+
return *this;
540+
}
541+
436542
} // namespace nnvm
437543

438544
#endif // NNVM_OP_H_

0 commit comments

Comments
 (0)