@@ -22,6 +22,7 @@ class Node;
2222struct NodeAttrs ;
2323template <typename ValueType>
2424class OpMap ;
25+ class OpGroup ;
2526class OpRegistryEntry ;
2627using dmlc::ParamFieldInfo;
2728
@@ -44,7 +45,13 @@ static const uint32_t kVarg = std::numeric_limits<uint32_t>::max();
4445 * NNVM_REGISTER_OP(add)
4546 * .describe("add two inputs together")
4647 * .set_num_inputs(2)
47- * .set_attr<OpKernel>("gpu_kernel", AddKernel);
48+ * .set_attr<OpKernel>("OpKernel<gpu>", AddKernel)
49+ * .include("ElementwiseOpAttr");
50+ *
51+ * // can register attribute by group
52+ * // all the ops that include the group get the attribute.
53+ * NNVM_REGISTER_OP_GROUP(ElementwiseOpAttr)
54+ * .set_attr<FInferShape>("FInferShape", ElementwiseInferShape);
4855 *
4956 * NNVM_REGISTER_OP(sub)
5057 * .describe("substract one tensor from another")
@@ -53,7 +60,8 @@ static const uint32_t kVarg = std::numeric_limits<uint32_t>::max();
5360 * // Can call regster multiple times in different files
5461 * // to register different part of information
5562 * NNVM_REGISTER_OP(sub)
56- * .set_attr<OpKernel>("gpu_kernel", SubKernel);
63+ * .set_attr<OpKernel>("OpKernel<gpu>", SubKernel);
64+ * .include("ElementwiseOpAttr");
5765 *
5866 * // get operators from registry.
5967 * void my_function() {
@@ -65,7 +73,7 @@ static const uint32_t kVarg = std::numeric_limits<uint32_t>::max();
6573 *
6674 * // get additional registered information,
6775 * // Assume user registered a OpKernel type attribute as gpu_kernel on each operator.
68- * const OpMap<OpKernel>& kernel = Op::GetAttr<OpKernel>("gpu_kernel ");
76+ * const OpMap<OpKernel>& kernel = Op::GetAttr<OpKernel>("OpKernel<gpu> ");
6977 * // we can get the kernel functions by using operator as key.
7078 * auto add_kernel = kernel[add];
7179 * auto sub_kernel = kernel[sub];
@@ -199,6 +207,23 @@ class Op {
199207 * \return reference to self.
200208 */
201209 inline Op& set_attr_parser (std::function<void (NodeAttrs* attrs)> fn); // NOLINT(*)
210+ /* !
211+ * \brief Register additional attributes to operator.
212+ * \param attr_name The name of the attribute.
213+ * \param value The value to be set.
214+ * \param plevel The priority level of this set,
215+ * an higher priority level attribute
216+ * will replace lower priority level attribute.
217+ * Must be bigger than 0.
218+ *
219+ * Cannot set with same plevel twice in the code.
220+ *
221+ * \tparam ValueType The type of the value to be set.
222+ */
223+ template <typename ValueType>
224+ inline Op& set_attr (const std::string& attr_name, // NOLINT(*)
225+ const ValueType& value,
226+ int plevel = 10 );
202227 /* !
203228 * \brief Add another alias to this operator.
204229 * The same Op can be queried with Op::Get(alias)
@@ -207,14 +232,13 @@ class Op {
207232 */
208233 Op& add_alias (const std::string& alias); // NOLINT(*)
209234 /* !
210- * \brief Register additional attributes to operator.
211- * \param attr_name The name of the attribute.
212- * \param value The value to be set.
213- * \tparam ValueType The type of the value to be set.
235+ * \brief Include all the attributes from an registered op group.
236+ * \param group_name The name of the group.
237+ * \return reference to self.
238+ *
239+ * \sa NNVM_REGISTER_OP_GROUP
214240 */
215- template <typename ValueType>
216- inline Op& set_attr (const std::string& attr_name, // NOLINT(*)
217- const ValueType& value);
241+ Op& include (const std::string& group_name);
218242 /* !
219243 * \brief Get an Op for a given operator name.
220244 * Will raise an error if the op has not been registered.
@@ -235,6 +259,7 @@ class Op {
235259 private:
236260 template <typename ValueType>
237261 friend class OpMap ;
262+ friend class OpGroup ;
238263 friend class dmlc ::Registry<Op>;
239264 // Program internal unique index of operator.
240265 // Used to help index the program.
@@ -246,6 +271,13 @@ class Op {
246271 // update the attribute OpMap
247272 static void UpdateAttrMap (const std::string& key,
248273 std::function<void (any*)> updater);
274+ // add a trigger based on tag matching on certain tag attribute
275+ // This will apply trigger on all the op such that
276+ // include the corresponding group.
277+ // The trigger will also be applied to all future registrations
278+ // that calls include
279+ static void AddGroupTrigger (const std::string& group_name,
280+ std::function<void (Op*)> trigger);
249281};
250282
251283/* !
@@ -285,14 +317,44 @@ class OpMap {
285317 OpMap () = default ;
286318};
287319
320+ /* !
321+ * \brief auxiliary data structure used to
322+ * set attributes to a group of operators
323+ */
324+ class OpGroup {
325+ public:
326+ /* ! \brief the tag key to be matched */
327+ std::string group_name;
328+ /* !
329+ * \brief Register additional attributes to operator group.
330+ * \param attr_name The name of the attribute.
331+ * \param value The value to be set.
332+ * \param plevel The priority level of this set,
333+ * an higher priority level attribute
334+ * will replace lower priority level attribute.
335+ * Must be bigger than 0.
336+ *
337+ * Cannot set with same plevel twice in the code.
338+ *
339+ * \tparam ValueType The type of the value to be set.
340+ */
341+ template <typename ValueType>
342+ inline OpGroup& set_attr (const std::string& attr_name, // NOLINT(*)
343+ const ValueType& value,
344+ int plevel = 1 );
345+ };
346+
288347// internal macros to make
289- #define NNVM_REGISTER_VAR_DEF (OpName ) \
348+ #define NNVM_REGISTER_VAR_DEF (OpName ) \
290349 static DMLC_ATTRIBUTE_UNUSED ::nnvm::Op & __make_ ## NnvmOp ## _ ## OpName
291350
351+ #define NNVM_REGISTER_GVAR_DEF (TagName ) \
352+ static DMLC_ATTRIBUTE_UNUSED ::nnvm::OpGroup __make_ ## NnvmOpGroup ## _ ## TagName
353+
292354/* !
293355 * \def NNVM_REGISTER_OP
294- * \brief Register
295- * This macro must be used under namespace dmlc, and only used once in cc file.
356+ * \brief Register a new operator, or set attribute of the corresponding op.
357+ *
296358 * \param OpName The name of registry
297359 *
298360 * \code
@@ -308,6 +370,31 @@ class OpMap {
308370 DMLC_STR_CONCAT (NNVM_REGISTER_VAR_DEF(OpName), __COUNTER__) = \
309371 ::dmlc::Registry<::nnvm::Op>::Get()->__REGISTER_OR_GET__(#OpName)
310372
373+ /* !
374+ * \def NNVM_REGISTER_OP_GROUP
375+ * \brief Register attribute to a group of operators.
376+ * These attributes will be registered to Op that include the group.
377+ *
378+ * \param GroupName The name of the group.
379+ *
380+ * \code
381+ *
382+ * NNVM_REGISTER_OP(add)
383+ * .include("ElementwiseOpAttr");
384+ *
385+ * // register same attributes to all the ops that include the group
386+ * NNVM_REGISTER_OP_GROUP(ElementwiseOpAttr)
387+ * .set_attr<FInferShape>("FInferShape", ElementwiseInferShape);
388+ *
389+ * NNVM_REGISTER_OP(mul)
390+ * .include("ElementwiseOpAttr");
391+ *
392+ * \endcode
393+ */
394+ #define NNVM_REGISTER_OP_GROUP (GroupName ) \
395+ DMLC_STR_CONCAT (NNVM_REGISTER_GVAR_DEF(GroupName), __COUNTER__) = \
396+ ::nnvm::OpGroup {#GroupName}
397+
311398// implementations of template functions after this.
312399// member function of Op
313400template <typename ValueType>
@@ -330,9 +417,14 @@ inline const OpMap<ValueType>& Op::GetAttr(const std::string& key) {
330417
331418template <typename ValueType>
332419inline Op& Op::set_attr ( // NOLINT(*)
333- const std::string& attr_name, const ValueType& value) {
420+ const std::string& attr_name,
421+ const ValueType& value,
422+ int plevel) {
423+ CHECK_GT (plevel, 0 )
424+ << " plevel in set_attr must be greater than 0" ;
334425 // update the attribute map of the key by creating new empty if needed.
335- UpdateAttrMap (attr_name, [this , attr_name, value](any* pmap) {
426+ UpdateAttrMap (attr_name,
427+ [this , attr_name, value, plevel](any* pmap) {
336428 // the callback is in lockscope so is threadsafe.
337429 if (pmap->empty ()) {
338430 OpMap<ValueType> pm;
@@ -353,15 +445,18 @@ inline Op& Op::set_attr( // NOLINT(*)
353445 std::make_pair (ValueType (), 0 ));
354446 }
355447 std::pair<ValueType, int >& p = vec[index_];
356- CHECK (p.second == 0 )
448+ CHECK (p.second != plevel )
357449 << " Attribute " << attr_name
358450 << " of operator " << this ->name
359- << " is already registered." ;
360- vec[index_] = std::make_pair (value, 1 );
451+ << " is already registered with same plevel=" << plevel;
452+ if (p.second < plevel) {
453+ vec[index_] = std::make_pair (value, plevel);
454+ }
361455 });
362456 return *this ;
363457}
364458
459+
365460inline Op& Op::describe (const std::string& descr) { // NOLINT(*)
366461 this ->description = descr;
367462 return *this ;
@@ -409,7 +504,7 @@ template<typename ValueType>
409504inline int OpMap<ValueType>::count(const Op* op) const {
410505 if (op == nullptr ) return 0 ;
411506 const uint32_t idx = op->index_ ;
412- return idx < data_.size () ? data_[idx].second : 0 ;
507+ return idx < data_.size () ? ( data_[idx].second != 0 ) : 0 ;
413508}
414509
415510template <typename ValueType>
@@ -433,6 +528,17 @@ inline const ValueType& OpMap<ValueType>::get(const Op* op, const ValueType& def
433528 }
434529}
435530
531+ template <typename ValueType>
532+ inline OpGroup& OpGroup::set_attr (const std::string& attr_name,
533+ const ValueType& value,
534+ int plevel) {
535+ auto trigger = [attr_name, value, plevel](Op* op) {
536+ op->set_attr <ValueType>(attr_name, value, plevel);
537+ };
538+ Op::AddGroupTrigger (group_name, trigger);
539+ return *this ;
540+ }
541+
436542} // namespace nnvm
437543
438544#endif // NNVM_OP_H_
0 commit comments