55
66import ctypes
77import sys
8+ import json
89from .base import _LIB
910from .base import c_array , c_str , nn_uint , py_str , string_types
1011from .base import GraphHandle , SymbolHandle
1112from .base import check_call
1213from .symbol import Symbol
1314
15+
1416class Graph (object ):
1517 """Graph is the graph object that can be used to apply optimization pass.
1618 It contains additional graphwise attribute besides the internal symbol.
@@ -31,7 +33,7 @@ def __init__(self, handle):
3133 def __del__ (self ):
3234 check_call (_LIB .NNGraphFree (self .handle ))
3335
34- def attr (self , key ):
36+ def json_attr (self , key ):
3537 """Get attribute string from the graph.
3638
3739 Parameters
@@ -46,24 +48,33 @@ def attr(self, key):
4648 """
4749 ret = ctypes .c_char_p ()
4850 success = ctypes .c_int ()
49- check_call (_LIB .NNGraphGetStrAttr (
51+ check_call (_LIB .NNGraphGetJSONAttr (
5052 self .handle , c_str (key ), ctypes .byref (ret ), ctypes .byref (success )))
5153 if success .value != 0 :
52- return py_str (ret .value )
54+ json_str = py_str (ret .value )
55+ return json .loads (json_str )[1 ]
5356 else :
5457 return None
5558
56- def _set_attr (self , ** kwargs ):
59+ def _set_json_attr (self , key , value , type_name = None ):
5760 """Set the attribute of the symbol.
5861
5962 Parameters
6063 ----------
61- **kwargs
62- The attributes to set
64+ key : string
65+ The key of the attribute
66+ value : value
67+ The any type that can be dumped to json
68+ type_name : string
69+ The typename registered on c++ side.
6370 """
64- for k , v in kwargs .items ():
65- check_call (_LIB .NNGraphSetStrAttr (
66- self .handle , c_str (k ), c_str (v )))
71+ if isinstance (value , string_types ):
72+ type_name = 'str'
73+ elif type_name is None :
74+ raise ValueError ("Need to specify type_name" )
75+ json_value = json .dumps ([type_name , value ])
76+ check_call (_LIB .NNGraphSetJSONAttr (
77+ self .handle , c_str (key ), c_str (json_value )))
6778
6879 @property
6980 def symbol (self ):
0 commit comments