@@ -6,6 +6,7 @@ from .._base import NNVMError
66from ..name import NameManager
77from ..attribute import AttrScope
88from libcpp.vector cimport vector
9+ from libcpp.string cimport string
910from cpython.version cimport PY_MAJOR_VERSION
1011
1112include " ./base.pyi"
@@ -110,7 +111,7 @@ cdef class Symbol:
110111 CALL(NNSymbolGetOutput(self .handle, c_index, & handle))
111112 return NewSymbol(handle)
112113
113- def attr (self , const char* key ):
114+ def attr (self , key ):
114115 """ Get attribute string from the symbol, this function only works for non-grouped symbol.
115116
116117 Parameters
@@ -125,6 +126,8 @@ cdef class Symbol:
125126 """
126127 cdef const char * ret
127128 cdef int success
129+ key = c_str(key)
130+
128131 CALL(NNSymbolGetAttr(
129132 self .handle, key, & ret, & success))
130133 if success != 0 :
@@ -203,16 +206,19 @@ cdef class Symbol:
203206 def debug_str (self ):
204207 cdef const char * out_str
205208 CALL(NNSymbolPrint(self .handle, & out_str))
206- return str (out_str)
209+ return py_str (out_str)
207210
208211
209212cdef SymbolSetAttr(SymbolHandle handle, dict kwargs):
210- cdef vector[const char * ] param_keys
211- cdef vector[const char * ] param_vals
213+ cdef vector[string] sparam_keys
214+ cdef vector[string] sparam_vals
212215 cdef nn_uint num_args
213216 for k, v in kwargs.items():
214- param_keys.push_back(k)
215- param_vals.push_back(str (v))
217+ sparam_keys.push_back(c_str(k))
218+ sparam_vals.push_back(c_str(str (v)))
219+ # keep strings in vector
220+ cdef vector[const char * ] param_keys = SVec2Ptr(sparam_keys)
221+ cdef vector[const char * ] param_vals = SVec2Ptr(sparam_vals)
216222 num_args = param_keys.size()
217223 CALL(NNSymbolSetAttrs(
218224 handle, num_args, CBeginPtr(param_keys), CBeginPtr(param_vals)))
@@ -225,7 +231,7 @@ cdef NewSymbol(SymbolHandle handle):
225231 return sym
226232
227233
228- def Variable (const char* name , **kwargs ):
234+ def Variable (name , **kwargs ):
229235 """ Create a symbolic variable with specified name.
230236
231237 Parameters
@@ -241,6 +247,7 @@ def Variable(const char* name, **kwargs):
241247 The created variable symbol.
242248 """
243249 cdef SymbolHandle handle
250+ name = c_str(name)
244251 CALL(NNSymbolCreateVariable(name, & handle))
245252 return NewSymbol(handle)
246253
@@ -274,10 +281,10 @@ cdef _make_atomic_symbol_function(AtomicSymbolCreator handle):
274281 func_hint = func_name.lower()
275282
276283 def creator (*args , **kwargs ):
277- cdef vector[const char * ] param_keys
278- cdef vector[const char * ] param_vals
284+ cdef vector[string] sparam_keys
285+ cdef vector[string] sparam_vals
279286 cdef vector[SymbolHandle] symbol_args
280- cdef vector[const char * ] symbol_keys
287+ cdef vector[string] ssymbol_keys
281288 cdef SymbolHandle ret_handle
282289
283290 name = kwargs.pop(" name" , None )
@@ -286,11 +293,11 @@ cdef _make_atomic_symbol_function(AtomicSymbolCreator handle):
286293 if len (kwargs) != 0 :
287294 for k, v in kwargs.items():
288295 if isinstance (v, Symbol):
289- symbol_keys .push_back(k )
296+ ssymbol_keys .push_back(c_str(k) )
290297 symbol_args.push_back((< Symbol> v).handle)
291298 else :
292- param_keys .push_back(k )
293- param_vals .push_back(str (v))
299+ sparam_keys .push_back(c_str(k) )
300+ sparam_vals .push_back(c_str( str (v) ))
294301
295302 if len (args) != 0 :
296303 if symbol_args.size() != 0 :
@@ -301,6 +308,10 @@ cdef _make_atomic_symbol_function(AtomicSymbolCreator handle):
301308 raise TypeError (' Compose expect `Symbol` as arguments' )
302309 symbol_args.push_back((< Symbol> v).handle)
303310
311+ cdef vector[const char * ] param_keys = SVec2Ptr(sparam_keys)
312+ cdef vector[const char * ] param_vals = SVec2Ptr(sparam_vals)
313+ cdef vector[const char * ] symbol_keys = SVec2Ptr(ssymbol_keys)
314+
304315 CALL(NNSymbolCreateAtomicSymbol(
305316 handle,
306317 < nn_uint> param_keys.size(),
@@ -315,7 +326,9 @@ cdef _make_atomic_symbol_function(AtomicSymbolCreator handle):
315326 name = NameManager.current.get(name, func_hint)
316327
317328 cdef const char * c_name = NULL
329+
318330 if name:
331+ name = c_str(name)
319332 c_name = name
320333
321334 CALL(NNSymbolCompose(
0 commit comments