diff --git a/include/nnvm/symbolic.h b/include/nnvm/symbolic.h index 3d7f94ce7b0e..4d26947ca646 100644 --- a/include/nnvm/symbolic.h +++ b/include/nnvm/symbolic.h @@ -12,6 +12,7 @@ #include #include +#include #include #include "./base.h" @@ -168,6 +169,15 @@ class Symbol { * \return The created attribute. */ std::unordered_map ListAttrs(ListAttrOption option) const; + /*! + * \brief Get attribute dictionary from the symbol and all children. + * + * For grouped symbol, an error will be raised. + * + * \return The created attribute in format . + */ + std::vector > + ListAttrsRecursive() const; /*! * \brief Create symbolic functor(AtomicSymbol) by given operator and attributes. * \param op The operator. diff --git a/src/core/symbolic.cc b/src/core/symbolic.cc index a35218839296..2b89bf2a446b 100644 --- a/src/core/symbolic.cc +++ b/src/core/symbolic.cc @@ -472,6 +472,17 @@ std::unordered_map Symbol::ListAttrs(ListAttrOption op } } +std::vector > + Symbol::ListAttrsRecursive() const { + std::vector > ret; + DFSVisit(this->outputs, [&ret](const NodePtr& n) { + for (const auto& it : n->attrs.dict) { + ret.emplace_back(std::make_tuple(n->attrs.name, it.first, it.second)); + } + }); + return ret; +} + Symbol Symbol::CreateFunctor(const Op* op, std::unordered_map attrs) { static auto& fnum_vis_output = Op::GetAttr("FNumVisibleOutputs");