@@ -13,6 +13,43 @@ namespace symbol_constants {
1313const char *kNamespaceSeparator = " _" ;
1414} // namespace symbol_constants
1515
16+ // auxililary version attribute in variable.
17+ struct VariableParam {
18+ uint32_t version{0 };
19+ };
20+
21+ std::shared_ptr<Node> CreateVariableNode (const std::string& name) {
22+ std::shared_ptr<Node> n = Node::Create ();
23+ n->op = nullptr ;
24+ n->attrs .name = name;
25+ n->attrs .parsed = VariableParam ();
26+ return n;
27+ }
28+
29+ // scan over a node's input, update the version to latest
30+ // If the node's op mutates a certain input variable,
31+ // The version of that varaible will increase
32+ // version is used to implicitly order the mutation sequences
33+ inline void UpdateNodeVersion (Node *n) {
34+ static auto & fmutate_inputs = Op::GetAttr<FMutateInput>(" FMutateInput" );
35+ for (NodeEntry& e : n->inputs ) {
36+ if (e.node ->is_variable ()) {
37+ e.version = nnvm::get<VariableParam>(e.node ->attrs .parsed ).version ;
38+ }
39+ }
40+ if (fmutate_inputs.count (n->op ) != 0 ) {
41+ FMutateInput fmutate = fmutate_inputs[n->op ];
42+ for (uint32_t i = 0 ; i < n->inputs .size (); ++i) {
43+ if (fmutate (n->attrs , i)) {
44+ NodeEntry& e = n->inputs [i];
45+ CHECK (e.node ->is_variable ())
46+ << " Mutation target can only be Variable" ;
47+ // increase the version of the variable.
48+ ++nnvm::get<VariableParam>(e.node ->attrs .parsed ).version ;
49+ }
50+ }
51+ }
52+ }
1653
1754inline std::string DefaultVarName (const std::string &op_name,
1855 const std::string &arg_name) {
@@ -67,13 +104,13 @@ Symbol Symbol::Copy() const {
67104 for (const auto &kv : old_new) {
68105 for (const NodeEntry& e : kv.first ->inputs ) {
69106 Node *ptr = e.node .get ();
70- kv.second ->inputs .emplace_back (NodeEntry{old_new[ptr], e.index });
107+ kv.second ->inputs .emplace_back (NodeEntry{old_new[ptr], e.index , e. version });
71108 }
72109 }
73110 // set the head
74111 Symbol ret;
75112 for (const NodeEntry &e : outputs) {
76- ret.outputs .emplace_back (NodeEntry{old_new[e.node .get ()], e.index });
113+ ret.outputs .emplace_back (NodeEntry{old_new[e.node .get ()], e.index , e. version });
77114 }
78115 return ret;
79116}
@@ -95,8 +132,14 @@ void Symbol::Print(std::ostream &os) const {
95132 os << " Name: " << node->attrs .name << " Op:" << node->op ->name << ' \n '
96133 << " Inputs:\n " ;
97134 for (size_t i = 0 ; i < node->inputs .size (); ++i) {
98- os << " \t arg[" << i << " ]=" << node->inputs [i].node ->attrs .name
99- << ' (' << node->inputs [i].index << " )\n " ;
135+ const NodeEntry& e = node->inputs [i];
136+ os << " \t arg[" << i << " ]=" << e.node ->attrs .name
137+ << ' (' << e.index << " )" ;
138+ if (e.node ->is_variable ()) {
139+ os << " version=" << e.version << ' \n ' ;
140+ } else {
141+ os << ' \n ' ;
142+ }
100143 }
101144 os << " Attrs:\n " ;
102145 for (auto &kv : node->attrs .dict ) {
@@ -163,6 +206,8 @@ std::vector<std::string> Symbol::ListOutputs() const {
163206void Symbol::Compose (const std::vector<Symbol>& args,
164207 const std::unordered_map<std::string, Symbol>& kwargs,
165208 const std::string& name) {
209+ static auto & flist_inputs = Op::GetAttr<FListInputNames>(" FListInputNames" );
210+
166211 CHECK_EQ (outputs.size (), 1 )
167212 << " Only composition of value function is supported currently" ;
168213 CHECK (!outputs[0 ].node ->is_variable ()) << " Variable cannot be composed" ;
@@ -193,7 +238,6 @@ void Symbol::Compose(const std::vector<Symbol>& args,
193238 }
194239 // switch to keyword argument matching
195240 if (args.size () != n_req) {
196- static auto & flist_inputs = Op::GetAttr<FListInputNames>(" FListInputNames" );
197241 FListInputNames fn = flist_inputs.get (n->op , nullptr );
198242 auto arg_names = (fn == nullptr ) ? std::vector<std::string>{" data" } : fn (n->attrs );
199243 if (arg_names.size () != n_req) {
@@ -206,8 +250,8 @@ void Symbol::Compose(const std::vector<Symbol>& args,
206250 n->inputs [i] = it->second .outputs [0 ];
207251 ++nmatched;
208252 } else {
209- n->inputs [i] = NodeEntry{Node::Create (), 0 };
210- n-> inputs [i]. node -> attrs . name = DefaultVarName (name, arg_names[i]);
253+ n->inputs [i] = NodeEntry{
254+ CreateVariableNode ( DefaultVarName (name, arg_names[i])), 0 , 0 } ;
211255 }
212256 }
213257
@@ -226,6 +270,7 @@ void Symbol::Compose(const std::vector<Symbol>& args,
226270 n->inputs .push_back (s.outputs [0 ]);
227271 }
228272 }
273+ UpdateNodeVersion (n);
229274 } else {
230275 // general composition
231276 CHECK_EQ (args.size (), 0 )
@@ -253,25 +298,32 @@ void Symbol::Compose(const std::vector<Symbol>& args,
253298 DFSVisit (this ->outputs , find_replace_map);
254299
255300 if (nmatched == kwargs.size () && arg_counter < args.size ()) {
301+ std::vector<Node*> update_nodes;
256302 std::vector<std::pair<NodeEntry*, const NodeEntry*> > replace_plan;
257- auto find_replace_plan = [&replace_map, &replace_plan]
303+ auto find_replace_plan = [&replace_map, &replace_plan, &update_nodes ]
258304 (const std::shared_ptr<Node> &node) {
259305 // visit all the childs, find possible replacement
306+ bool repl = false ;
260307 for (size_t i = 0 ; i < node->inputs .size (); ++i) {
261308 NodeEntry *e = &(node->inputs [i]);
262309 if (e->node ->is_variable ()) {
263310 auto iter = replace_map.find (e->node .get ());
264311 if (iter != replace_map.end ()) {
265312 replace_plan.push_back (std::make_pair (e, iter->second ));
313+ repl = true ;
266314 }
267315 }
268316 }
317+ if (repl) update_nodes.push_back (node.get ());
269318 };
270319 DFSVisit (this ->outputs , find_replace_plan);
271320
272321 for (const auto & kv : replace_plan) {
273322 *(kv.first ) = *(kv.second );
274323 }
324+ for (Node* n : update_nodes) {
325+ UpdateNodeVersion (n);
326+ }
275327 } else {
276328 std::vector<std::string> keys = GetKeys (kwargs);
277329 std::vector<std::string> arg_names = ListArguments ();
@@ -303,9 +355,15 @@ Symbol Symbol::GetInternals() const {
303355 Symbol ret;
304356 DFSVisit (this ->outputs , [&ret](const std::shared_ptr<Node>& node) {
305357 Node* n = node.get ();
306- uint32_t nout = n->num_outputs ();
307- for (uint32_t i = 0 ; i < nout; ++i) {
308- ret.outputs .emplace_back (NodeEntry{node, i});
358+ if (n->is_variable ()) {
359+ // grab version from variable.
360+ VariableParam& param = nnvm::get<VariableParam>(n->attrs .parsed );
361+ ret.outputs .emplace_back (NodeEntry{node, 0 , param.version });
362+ } else {
363+ uint32_t nout = n->num_outputs ();
364+ for (uint32_t i = 0 ; i < nout; ++i) {
365+ ret.outputs .emplace_back (NodeEntry{node, i, 0 });
366+ }
309367 }
310368 });
311369 return ret;
@@ -325,7 +383,7 @@ void Symbol::SetAttrs(const std::vector<std::pair<std::string, std::string> >& a
325383 }
326384 }
327385 if (node->op != nullptr && node->op ->attr_parser != nullptr ) {
328- (* node->op ->attr_parser ) (&(node->attrs ));
386+ node->op ->attr_parser (&(node->attrs ));
329387 }
330388}
331389
@@ -366,9 +424,9 @@ Symbol Symbol::CreateFunctor(const Op* op,
366424 n->op = op;
367425 n->attrs .dict = std::move (attrs);
368426 if (n->op ->attr_parser != nullptr ) {
369- (* n->op ->attr_parser ) (&(n->attrs ));
427+ n->op ->attr_parser (&(n->attrs ));
370428 }
371- s.outputs .emplace_back (NodeEntry{std::move (n), 0 });
429+ s.outputs .emplace_back (NodeEntry{std::move (n), 0 , 0 });
372430 return s;
373431}
374432
@@ -382,10 +440,7 @@ Symbol Symbol::CreateGroup(const std::vector<Symbol> &symbols) {
382440
383441Symbol Symbol::CreateVariable (const std::string& name) {
384442 Symbol s;
385- std::shared_ptr<Node> n = Node::Create ();
386- n->op = nullptr ;
387- n->attrs .name = name;
388- s.outputs .emplace_back (NodeEntry{std::move (n), 0 });
443+ s.outputs .emplace_back (NodeEntry{CreateVariableNode (name), 0 , 0 });
389444 return s;
390445}
391446
0 commit comments