@@ -15,6 +15,10 @@ using nnvm::FMutateInputs;
1515using nnvm::FInferShape;
1616using nnvm::FInferType;
1717using nnvm::FInplaceOption;
18+ using nnvm::Node;
19+ using nnvm::NodePtr;
20+ using nnvm::NodeEntry;
21+ using nnvm::FGradient;
1822using nnvm::NodeAttrs;
1923using nnvm::TShape;
2024using nnvm::array_view;
@@ -37,6 +41,17 @@ inline std::vector<std::pair<int, int> > InplaceIn0Out0(const NodeAttrs& attrs)
3741 return {{0 , 0 }};
3842}
3943
44+ // quick helper to make node
45+ inline NodeEntry MakeNode (const char * op_name,
46+ std::string node_name,
47+ std::vector<NodeEntry> inputs) {
48+ NodePtr p = Node::Create ();
49+ p->op = nnvm::Op::Get (op_name);
50+ p->attrs .name = std::move (node_name);
51+ p->inputs = std::move (inputs);
52+ return NodeEntry{p, 0 , 0 };
53+ }
54+
4055// simple demonstration of reshape.
4156NNVM_REGISTER_OP (reshape)
4257.describe(" reshape source to target shape" )
@@ -84,21 +99,67 @@ NNVM_REGISTER_OP(cast)
8499 return true ;
85100 });
86101
102+ NNVM_REGISTER_OP (exp)
103+ .describe(" take exponential" )
104+ .set_num_inputs(1 )
105+ .attr<FInferShape>(" FInferShape" , SameShape)
106+ .attr<FGradient>(
107+ " FGradient" , [](const NodePtr& n,
108+ const std::vector<NodeEntry>& ograds) {
109+ return std::vector<NodeEntry>{
110+ MakeNode (" mul" , n->attrs .name + " _grad" ,
111+ {ograds[0 ], NodeEntry{n, 0 , 0 }})
112+ };
113+ });
114+
115+ NNVM_REGISTER_OP (identity)
116+ .describe(" identity function" )
117+ .set_num_inputs(1 )
118+ .attr<FInferShape>(" FInferShape" , SameShape)
119+ .attr<FGradient>(
120+ " FGradient" , [](const NodePtr& n,
121+ const std::vector<NodeEntry>& ograds) {
122+ return std::vector<NodeEntry>{ograds[0 ]};
123+ });
87124
88125NNVM_REGISTER_OP (add)
89126.describe(" add two data together" )
90127.set_num_inputs(2 )
91128.attr<FInferShape>(" FInferShape" , SameShape)
92- .attr<FInplaceOption>(" FInplaceOption" , InplaceIn0Out0);
129+ .attr<FInplaceOption>(" FInplaceOption" , InplaceIn0Out0)
130+ .attr<FGradient>(
131+ " FGradient" , [](const NodePtr& n,
132+ const std::vector<NodeEntry>& ograds){
133+ return std::vector<NodeEntry>{ograds[0 ], ograds[0 ]};
134+ });
93135
94- NNVM_REGISTER_OP (__add_symbol__)
95- .describe(" Alias of add" )
96- .set_num_inputs(2 );
136+ NNVM_REGISTER_OP (mul)
137+ .describe(" multiply two data together" )
138+ .set_num_inputs(2 )
139+ .attr<FInferShape>(" FInferShape" , SameShape)
140+ .attr<FInplaceOption>(" FInplaceOption" , InplaceIn0Out0)
141+ .attr<FGradient>(
142+ " FGradient" , [](const NodePtr& n,
143+ const std::vector<NodeEntry>& ograds){
144+ return std::vector<NodeEntry>{
145+ MakeNode (" mul" , n->attrs .name + " _grad_0" ,
146+ {ograds[0 ], n->inputs [1 ]}),
147+ MakeNode (" mul" , n->attrs .name + " _grad_1" ,
148+ {ograds[0 ], n->inputs [0 ]})
149+ };
150+ });
97151
98- NNVM_REGISTER_OP (exp)
99- .describe(" take exponential" )
100- .set_num_inputs(1 )
101- .attr<FInferShape>(" FInferShape" , SameShape);
152+ NNVM_REGISTER_OP (__ewise_sum__)
153+ .describe(" elementwise sum" )
154+ .set_num_inputs(nnvm::kVarg );
155+
156+ NNVM_REGISTER_OP (__zero__)
157+ .describe(" set output to zero" )
158+ .set_num_inputs(0 );
159+
160+ NNVM_REGISTER_OP (__one__)
161+ .describe(" set output to one" )
162+ .set_num_inputs(0 );
102163
103164NNVM_REGISTER_OP (cross_device_copy)
104165.describe(" Copy data across device." )
0 commit comments