diff --git a/include/nnvm/op_attr_types.h b/include/nnvm/op_attr_types.h index 7d3af45e8736..8e15c394be82 100644 --- a/include/nnvm/op_attr_types.h +++ b/include/nnvm/op_attr_types.h @@ -142,6 +142,18 @@ using FGradient = std::function( const NodePtr& nodeptr, const std::vector& out_grads)>; +/*! + * \brief Set the attributes of input variable. + * Usually used for setting initialization or weight decay. + * \param attrs The attributes of this node. + * \param var the input variable + * \param index index of var in all inputs + */ +using FSetInputVarAttrOnCompose = std::function; + } // namespace nnvm #endif // NNVM_OP_ATTR_TYPES_H_ diff --git a/src/core/symbolic.cc b/src/core/symbolic.cc index b88c00d8a89f..a101deed7919 100644 --- a/src/core/symbolic.cc +++ b/src/core/symbolic.cc @@ -259,6 +259,7 @@ void Symbol::Compose(const array_view& args, const std::unordered_map& kwargs, const std::string& name) { static auto& flist_inputs = Op::GetAttr("FListInputNames"); + static auto& fset_attrs = Op::GetAttr("FSetInputVarAttrOnCompose"); CHECK(!outputs[0].node->is_variable()) << "Variable cannot be composed"; // parameter check. @@ -323,6 +324,15 @@ void Symbol::Compose(const array_view& args, } } UpdateNodeVersion(n); + + FSetInputVarAttrOnCompose fn = fset_attrs.get(n->op(), nullptr); + if (fn != nullptr) { + for (size_t i = 0; i < n->inputs.size(); ++i) { + if (n->inputs[i].node->is_variable()) { + fn(n->attrs, n->inputs[i].node, i); + } + } + } } else { // general composition CHECK_EQ(args.size(), 0U)