Skip to content

Commit eea05d7

Browse files
committed
[SYMBOL] Change list_input->list_input_names, add list_input_variables (apache#59)
* [SYMBOL] Change list_input->list_input_names, add list_input_variables * fix
1 parent e8fee6d commit eea05d7

File tree

7 files changed

+98
-28
lines changed

7 files changed

+98
-28
lines changed

nnvm/include/nnvm/c_api.h

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -205,8 +205,25 @@ NNVM_DLL int NNSymbolListAttrs(SymbolHandle symbol,
205205
int recursive_option,
206206
nn_uint *out_size,
207207
const char*** out);
208+
209+
/*!
210+
* \brief List inputs variables in the symbol.
211+
* \param symbol the symbol
212+
* \param option The option to list the inputs
213+
* option=0 means list all arguments.
214+
* option=1 means list arguments that are readed only by the graph.
215+
* option=2 means list arguments that are mutated by the graph.
216+
* \param out_size output size
217+
* \param out_sym_array the output array.
218+
* \return 0 when success, -1 when failure happens
219+
*/
220+
NNVM_DLL int NNSymbolListInputVariables(SymbolHandle symbol,
221+
int option,
222+
nn_uint *out_size,
223+
SymbolHandle** out_sym_array);
224+
208225
/*!
209-
* \brief List inputs in the symbol.
226+
* \brief List input names in the symbol.
210227
* \param symbol the symbol
211228
* \param option The option to list the inputs
212229
* option=0 means list all arguments.

nnvm/python/nnvm/symbol.py

Lines changed: 38 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,7 @@ def __deepcopy__(self, _=None):
108108
def __getitem__(self, index):
109109
if isinstance(index, _base.string_types):
110110
idx = None
111-
for i, name in enumerate(self.list_outputs()):
111+
for i, name in enumerate(self.list_output_names()):
112112
if name == index:
113113
if idx is not None:
114114
raise ValueError('There are multiple outputs with name \"%s\"' % index)
@@ -177,7 +177,40 @@ def get_internals(self):
177177
self.handle, _ctypes.byref(handle)))
178178
return Symbol(handle=handle)
179179

180-
def list_inputs(self, option='all'):
180+
def _get_list_copt(self, option):
181+
"""internal function to get list option"""
182+
if option == 'all':
183+
return _ctypes.c_int(0)
184+
elif option == 'read_only':
185+
return _ctypes.c_int(1)
186+
elif option == 'aux_state':
187+
return _ctypes.c_int(2)
188+
else:
189+
raise ValueError("option need to be in {'all', 'read_only, 'aux_state'}")
190+
191+
def list_input_variables(self, option='all'):
192+
"""List all the input variables in the symbol.
193+
194+
Parameters
195+
----------
196+
option : {'all', 'read_only', 'aux_state'}, optional
197+
The listing option
198+
- 'all' will list all the arguments.
199+
- 'read_only' lists arguments that are readed by the graph.
200+
- 'aux_state' lists arguments that are mutated by the graph as state.
201+
Returns
202+
-------
203+
vars : list of symbol
204+
List of all the variables
205+
"""
206+
size = _ctypes.c_uint()
207+
sarr = _ctypes.POINTER(_base.SymbolHandle)()
208+
_check_call(_LIB.NNSymbolListInputVariables(
209+
self.handle, self._get_list_copt(option),
210+
_ctypes.byref(size), _ctypes.byref(sarr)))
211+
return [Symbol(_base.SymbolHandle(sarr[i])) for i in range(size.value)]
212+
213+
def list_input_names(self, option='all'):
181214
"""List all the inputs in the symbol.
182215
183216
Parameters
@@ -194,19 +227,12 @@ def list_inputs(self, option='all'):
194227
"""
195228
size = _ctypes.c_uint()
196229
sarr = _ctypes.POINTER(_ctypes.c_char_p)()
197-
if option == 'all':
198-
copt = _ctypes.c_int(0)
199-
elif option == 'read_only':
200-
copt = _ctypes.c_int(1)
201-
elif option == 'aux_state':
202-
copt = _ctypes.c_int(2)
203-
else:
204-
raise ValueError("option need to be in {'all', 'read_only, 'aux_state'}")
205230
_check_call(_LIB.NNSymbolListInputNames(
206-
self.handle, copt, _ctypes.byref(size), _ctypes.byref(sarr)))
231+
self.handle, self._get_list_copt(option),
232+
_ctypes.byref(size), _ctypes.byref(sarr)))
207233
return [_base.py_str(sarr[i]) for i in range(size.value)]
208234

209-
def list_outputs(self):
235+
def list_output_names(self):
210236
"""List all outputs in the symbol.
211237
212238
Returns

nnvm/src/c_api/c_api_symbolic.cc

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -221,6 +221,25 @@ int NNSymbolListAttrs(SymbolHandle symbol,
221221
API_END();
222222
}
223223

224+
int NNSymbolListInputVariables(SymbolHandle symbol,
225+
int option,
226+
nn_uint *out_size,
227+
SymbolHandle** out_sym_array) {
228+
Symbol *s = static_cast<Symbol*>(symbol);
229+
NNAPIThreadLocalEntry *ret = NNAPIThreadLocalStore::Get();
230+
API_BEGIN();
231+
std::vector<NodePtr> vs = s->ListInputs(Symbol::ListInputOption(option));
232+
ret->ret_handles.clear();
233+
for (size_t i = 0; i < vs.size(); ++i) {
234+
nnvm::Symbol* rs = new nnvm::Symbol();
235+
rs->outputs.push_back(NodeEntry{vs[i], 0, 0});
236+
ret->ret_handles.push_back(rs);
237+
}
238+
*out_size = static_cast<nn_uint>(vs.size());
239+
*out_sym_array = dmlc::BeginPtr(ret->ret_handles);
240+
API_END();
241+
}
242+
224243
int NNSymbolListInputNames(SymbolHandle symbol,
225244
int option,
226245
nn_uint *out_size,

nnvm/src/core/symbolic.cc

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,8 @@ inline std::vector<std::string> GetKeys(
8787

8888
// whether the symbol is atomic functor
8989
inline bool IsAtomic(const std::vector<NodeEntry>& outputs) {
90-
return outputs[0].node->inputs.size() == 0;
90+
return outputs[0].node->inputs.size() == 0 &&
91+
outputs[0].node->control_deps.size() == 0;
9192
}
9293

9394
// public functions
@@ -118,7 +119,9 @@ Symbol Symbol::Copy() const {
118119
}
119120

120121
void Symbol::Print(std::ostream &os) const {
121-
if (outputs.size() == 1 && outputs[0].node->inputs.size() == 0) {
122+
if (outputs.size() == 1 &&
123+
outputs[0].node->inputs.size() == 0 &&
124+
outputs[0].node->control_deps.size() == 0) {
122125
if (outputs[0].node->is_variable()) {
123126
os << "Variable:" << outputs[0].node->attrs.name << '\n';
124127
} else {

nnvm/src/pass/gradient.cc

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,7 @@ Graph Gradient(Graph src) {
6969
// topo sort
7070
std::vector<NodePtr> topo_order;
7171
std::unordered_map<Node*, std::vector<GradEntry> > output_grads;
72+
7273
DFSVisit(ys, [&](const NodePtr& node) {
7374
if (output_grads.count(node.get()) == 0) {
7475
output_grads[node.get()].resize(node->num_outputs());
@@ -113,13 +114,15 @@ Graph Gradient(Graph src) {
113114
e.sum = agg_fun(std::move(e.grads));
114115
out_agg_grads.push_back(e.sum);
115116
}
116-
std::vector<NodeEntry> input_grads = grad_fun_map[ptr->op()]
117-
(mirror_map.size() == 0 ? ptr : mirror_map.at(ptr.get()), out_agg_grads);
118-
CHECK_EQ((*rit)->inputs.size(), input_grads.size())
119-
<< "Gradient function not returning enough gradient";
120-
auto git = input_grads.begin();
121-
for (auto it = (*rit)->inputs.begin(); it != (*rit)->inputs.end(); ++it, ++git) {
122-
output_grads[it->node.get()][it->index].grads.emplace_back(std::move(*git));
117+
if ((*rit)->inputs.size() != 0) {
118+
std::vector<NodeEntry> input_grads = grad_fun_map[ptr->op()]
119+
(mirror_map.size() == 0 ? ptr : mirror_map.at(ptr.get()), out_agg_grads);
120+
CHECK_EQ((*rit)->inputs.size(), input_grads.size())
121+
<< "Gradient function not returning enough gradient";
122+
auto git = input_grads.begin();
123+
for (auto it = (*rit)->inputs.begin(); it != (*rit)->inputs.end(); ++it, ++git) {
124+
output_grads[it->node.get()][it->index].grads.emplace_back(std::move(*git));
125+
}
123126
}
124127
}
125128
// take out the xs' grads

nnvm/tests/python/test_graph.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,8 +42,8 @@ def test_list_args():
4242
y = sym.add(y, z, name='add1')
4343
# write after read
4444
z = sym.assign(x, y, name='assign')
45-
assert z.list_inputs('read_only') == ['conv_weight', 'z']
46-
assert z.list_inputs('aux_state') == ['x']
45+
assert z.list_input_names('read_only') == ['conv_weight', 'z']
46+
assert z.list_input_names('aux_state') == ['x']
4747

4848
def test_infer_shape():
4949
x = sym.Variable('x', shape=(4, 2))

nnvm/tests/python/test_symbol.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,17 +7,19 @@ def test_compose():
77
y = sym.exp(sym.add(x, x, name='add', gpu=2),
88
name='exp', gpu=1, attr={"kk": "1"})
99

10-
assert y.list_inputs() == ['x']
11-
assert y.list_outputs() == ["exp_output"]
10+
assert y.list_input_names() == ['x']
11+
assert y.list_output_names() == ["exp_output"]
1212
assert y.list_attr()['gpu'] == '1'
1313
z = y.get_internals()
14-
assert z['add_output'].list_outputs() == ['add_output']
14+
assert z['add_output'].list_output_names() == ['add_output']
1515
assert y.list_attr(recursive=True)['add_gpu'] == '2'
1616

1717
def test_default_input():
1818
x = sym.Variable('x')
1919
y = sym.conv2d(data=x, name='conv')
20-
assert y.list_inputs() == ['x', 'conv_weight']
20+
assert y.list_input_names() == ['x', 'conv_weight']
21+
tname = [z.list_output_names()[0] for z in y.list_input_variables()]
22+
assert tname == y.list_input_names()
2123
try:
2224
z = sym.add(x)
2325
assert False

0 commit comments

Comments
 (0)