@@ -108,7 +108,7 @@ def __deepcopy__(self, _=None):
108108 def __getitem__ (self , index ):
109109 if isinstance (index , _base .string_types ):
110110 idx = None
111- for i , name in enumerate (self .list_outputs ()):
111+ for i , name in enumerate (self .list_output_names ()):
112112 if name == index :
113113 if idx is not None :
114114 raise ValueError ('There are multiple outputs with name \" %s\" ' % index )
@@ -177,7 +177,40 @@ def get_internals(self):
177177 self .handle , _ctypes .byref (handle )))
178178 return Symbol (handle = handle )
179179
180- def list_inputs (self , option = 'all' ):
180+ def _get_list_copt (self , option ):
181+ """internal function to get list option"""
182+ if option == 'all' :
183+ return _ctypes .c_int (0 )
184+ elif option == 'read_only' :
185+ return _ctypes .c_int (1 )
186+ elif option == 'aux_state' :
187+ return _ctypes .c_int (2 )
188+ else :
189+ raise ValueError ("option need to be in {'all', 'read_only, 'aux_state'}" )
190+
191+ def list_input_variables (self , option = 'all' ):
192+ """List all the input variables in the symbol.
193+
194+ Parameters
195+ ----------
196+ option : {'all', 'read_only', 'aux_state'}, optional
197+ The listing option
198+ - 'all' will list all the arguments.
199+ - 'read_only' lists arguments that are readed by the graph.
200+ - 'aux_state' lists arguments that are mutated by the graph as state.
201+ Returns
202+ -------
203+ vars : list of symbol
204+ List of all the variables
205+ """
206+ size = _ctypes .c_uint ()
207+ sarr = _ctypes .POINTER (_base .SymbolHandle )()
208+ _check_call (_LIB .NNSymbolListInputVariables (
209+ self .handle , self ._get_list_copt (option ),
210+ _ctypes .byref (size ), _ctypes .byref (sarr )))
211+ return [Symbol (_base .SymbolHandle (sarr [i ])) for i in range (size .value )]
212+
213+ def list_input_names (self , option = 'all' ):
181214 """List all the inputs in the symbol.
182215
183216 Parameters
@@ -194,19 +227,12 @@ def list_inputs(self, option='all'):
194227 """
195228 size = _ctypes .c_uint ()
196229 sarr = _ctypes .POINTER (_ctypes .c_char_p )()
197- if option == 'all' :
198- copt = _ctypes .c_int (0 )
199- elif option == 'read_only' :
200- copt = _ctypes .c_int (1 )
201- elif option == 'aux_state' :
202- copt = _ctypes .c_int (2 )
203- else :
204- raise ValueError ("option need to be in {'all', 'read_only, 'aux_state'}" )
205230 _check_call (_LIB .NNSymbolListInputNames (
206- self .handle , copt , _ctypes .byref (size ), _ctypes .byref (sarr )))
231+ self .handle , self ._get_list_copt (option ),
232+ _ctypes .byref (size ), _ctypes .byref (sarr )))
207233 return [_base .py_str (sarr [i ]) for i in range (size .value )]
208234
209- def list_outputs (self ):
235+ def list_output_names (self ):
210236 """List all outputs in the symbol.
211237
212238 Returns
0 commit comments