@@ -30,18 +30,34 @@ class Graph {
3030 std::vector<NodeEntry> outputs;
3131 /* !
3232 * \brief attributes of a graph
33- * Each attribute is immutable,
34- * and can be shared across multiple Instance of graph
33+ * Note that attribute is shared pointer and can be shared across graphs.
34+ *
35+ * It is highly recommended to keep each attribute immutable.
36+ * It is also safe to implement an copy-on-write semnatics.
37+ *
38+ * Copy when shared_ptr.unique is not true, while reuse original space
39+ * when shared_ptr.unique is true.
3540 */
36- std::unordered_map<std::string, std::shared_ptr<const any> > attrs;
41+ std::unordered_map<std::string, std::shared_ptr<any> > attrs;
3742 /* !
38- * \brief Get the attribute from attrs.
43+ * \brief Get the immutable attribute from attrs.
3944 * \param attr_name the name of the attribute
4045 * \return the reference to corresponding attribute
4146 * \tparam T the type of the attribute.
4247 */
4348 template <typename T>
4449 inline const T& GetAttr (const std::string& attr_name);
50+ /* !
51+ * \brief Get a move copy of the attribute, implement copy on write semantics.
52+ * The content is moved if the reference counter of shared_ptr is 1.
53+ * The attribute is erased from attrs after the call.
54+ *
55+ * \param attr_name the name of the attribute
56+ * \return a new copy of the corresponding attribute.
57+ * \tparam T the type of the attribute.
58+ */
59+ template <typename T>
60+ inline T MoveCopyAttr (const std::string& attr_name);
4561 /* !
4662 * \brief get a indexed graph of current graph, if not exist, create it on demand
4763 * \return The indexed graph.
@@ -200,6 +216,20 @@ inline const T& Graph::GetAttr(const std::string& attr_name) {
200216 return nnvm::get<T>(*it->second );
201217}
202218
219+ template <typename T>
220+ inline T Graph::MoveCopyAttr (const std::string& attr_name) {
221+ auto it = attrs.find (attr_name);
222+ CHECK (it != attrs.end ())
223+ << " Cannot find attribute " << attr_name << " in the graph" ;
224+ std::shared_ptr<any> sptr = it->second ;
225+ attrs.erase (it);
226+ if (sptr.unique ()) {
227+ return std::move (nnvm::get<T>(*sptr));
228+ } else {
229+ return nnvm::get<T>(*sptr);
230+ }
231+ }
232+
203233template <typename GNode, typename HashType,
204234 typename FVisit, typename HashFunc,
205235 typename InDegree, typename GetInput>
0 commit comments