Skip to content

Commit 877d0c2

Browse files
committed
[PYTHON] Check in a symbolic construction interface in python, start … (#4)
* [PYTHON] Check in a symbolic construction interface in python, start add graph API * Graph API
1 parent 1800e56 commit 877d0c2

File tree

17 files changed

+1143
-16
lines changed

17 files changed

+1143
-16
lines changed

nnvm/include/nnvm/c_api.h

Lines changed: 70 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -30,8 +30,8 @@ typedef unsigned int nn_uint;
3030
typedef void *AtomicSymbolCreator;
3131
/*! \brief handle to a symbol that can be bind as operator */
3232
typedef void *SymbolHandle;
33-
/*! \brief handle to a AtomicSymbol */
34-
typedef void *AtomicSymbolHandle;
33+
/*! \brief handle to Graph */
34+
typedef void *GraphHandle;
3535

3636
/*!
3737
* \brief return str message of the last error
@@ -71,7 +71,7 @@ NNVM_DLL int NNSymbolGetAtomicSymbolInfo(AtomicSymbolCreator creator,
7171
const char ***arg_names,
7272
const char ***arg_type_infos,
7373
const char ***arg_descriptions,
74-
const char **return_type = NULL);
74+
const char **return_type);
7575
/*!
7676
* \brief Create an AtomicSymbol functor.
7777
* \param creator the AtomicSymbolCreator
@@ -123,7 +123,18 @@ NNVM_DLL int NNSymbolCopy(SymbolHandle symbol, SymbolHandle *out);
123123
* \return 0 when success, -1 when failure happens
124124
*/
125125
NNVM_DLL int NNSymbolPrint(SymbolHandle symbol, const char **out_str);
126-
126+
/*!
127+
* \brief Get string attribute from symbol
128+
* \param symbol the source symbol
129+
* \param key The key of the symbol.
130+
* \param out The result attribute, can be NULL if the attribute do not exist.
131+
* \param success Whether the result is contained in out.
132+
* \return 0 when success, -1 when failure happens
133+
*/
134+
NNVM_DLL int NNSymbolGetAttr(SymbolHandle symbol,
135+
const char* key,
136+
const char** out,
137+
int *success);
127138
/*!
128139
* \brief Set string attribute from symbol.
129140
* NOTE: Setting attribute to a symbol can affect the semantics(mutable/immutable) of symbolic graph.
@@ -216,4 +227,59 @@ NNVM_DLL int NNSymbolCompose(SymbolHandle sym,
216227
const char** keys,
217228
SymbolHandle* args);
218229

230+
// Graph IR API
231+
/*!
232+
* \brief create a graph handle from symbol
233+
* \param symbol The symbol representing the graph.
234+
* \param graph The graph handle created.
235+
* \return 0 when success, -1 when failure happens
236+
*/
237+
NNVM_DLL int NNGraphCreate(SymbolHandle symbol, GraphHandle *graph);
238+
/*!
239+
* \brief free the graph handle
240+
* \param handle The handle to be freed.
241+
*/
242+
NNVM_DLL int NNGraphFree(GraphHandle handle);
243+
/*!
244+
* \brief Get a new symbol from the graph.
245+
* \param graph The graph handle.
246+
* \param symbol The corresponding symbol
247+
* \return 0 when success, -1 when failure happens
248+
*/
249+
NNVM_DLL int NNGraphGetSymbol(GraphHandle graph, SymbolHandle *symbol);
250+
/*!
251+
* \brief Get Set a std::string typed attribute to graph.
252+
* \param handle The graph handle.
253+
* \param key The key to the attribute.
254+
* \param value The value to be exposed.
255+
* \return 0 when success, -1 when failure happens
256+
*/
257+
NNVM_DLL int NNGraphSetStrAttr(GraphHandle handle,
258+
const char* key,
259+
const char* value);
260+
/*!
261+
* \brief Get Set a std::string typed attribute from graph attribute.
262+
* \param handle The graph handle.
263+
* \param key The key to the attribute.
264+
* \param out The result attribute, can be NULL if the attribute do not exist.
265+
* \param success Whether the result is contained in out.
266+
* \return 0 when success, -1 when failure happens
267+
*/
268+
NNVM_DLL int NNGraphGetStrAttr(SymbolHandle handle,
269+
const char* key,
270+
const char** out,
271+
int *success);
272+
/*!
273+
* \brief Apply pass on the src graph.
274+
* \param src The source graph handle.
275+
* \param num_pass The number of pass to be applied.
276+
* \param pass_names The names of the pass.
277+
* \param dst The result graph.
278+
* \return 0 when success, -1 when failure happens
279+
*/
280+
NNVM_DLL int NNGraphApplyPass(GraphHandle src,
281+
nn_uint num_pass,
282+
const char** pass_names,
283+
GraphHandle *dst);
284+
219285
#endif // NNVM_C_API_H_

nnvm/include/nnvm/op.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -323,10 +323,10 @@ inline Op& Op::attr( // NOLINT(*)
323323
vec.resize(index_ + 1,
324324
std::make_pair(ValueType(), 0));
325325
std::pair<ValueType, int>& p = vec[index_];
326-
CHECK(p.second == 0 || p.first == value)
326+
CHECK(p.second == 0)
327327
<< "Attribute " << attr_name
328328
<< " of operator " << this->name
329-
<< " is already registered to a different value";
329+
<< " is already registered.";
330330
vec[index_] = std::make_pair(value, 1);
331331
});
332332
return *this;

nnvm/include/nnvm/symbolic.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,15 @@ class Symbol {
111111
* \param attrs The attributes to set.
112112
*/
113113
void SetAttrs(const std::vector<std::pair<std::string, std::string> >& attrs);
114+
/*!
115+
* \brief Get attributes from the symbol.
116+
* This only works for symbol with outputs from single operators.
117+
* For grouped sybmbol, an error will be raised.
118+
* \param key Key of the attribute. When key == "name", it returns the name attirbute.
119+
* \param out the output value of the attribute.
120+
* \return true if the attribute exists, false if the attribute do not exist.
121+
*/
122+
bool GetAttr(const std::string& key, std::string* out) const;
114123
/*!
115124
* \brief Get attribute dictionary from the symbol.
116125
* For grouped sybmbol, an error will be raised.

nnvm/python/nnvm/__init__.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
#!/usr/bin/env python
2+
# coding: utf-8
3+
"""NNVM python API for ease of use and help new framework establish python API. """
4+
from __future__ import absolute_import
5+
6+
from . import base
7+
from . import symbol as sym
8+
from . import symbol
9+
10+
__version__ = base.__version__

nnvm/python/nnvm/attribute.py

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
# coding: utf-8
2+
"""Attribute scoping support for symbolic API."""
3+
from __future__ import absolute_import
4+
5+
from .base import string_types
6+
7+
class AttrScope(object):
8+
"""Attribute manager for scoping.
9+
10+
User can also inherit this object to change naming behavior.
11+
12+
Parameters
13+
----------
14+
kwargs
15+
The attributes to set for all symbol creations in the scope.
16+
"""
17+
current = None
18+
19+
def __init__(self, **kwargs):
20+
self._old_scope = None
21+
for value in kwargs.values():
22+
if not isinstance(value, string_types):
23+
raise ValueError("Attributes need to be string")
24+
self._attr = kwargs
25+
26+
def get(self, attr):
27+
"""
28+
Get the attribute dict given the attribute set by the symbol.
29+
30+
Parameters
31+
----------
32+
attr : dict of string to string
33+
The attribute passed in by user during symbol creation.
34+
35+
Returns
36+
-------
37+
attr : dict of string to string
38+
Updated attributes to add other scope related attributes.
39+
"""
40+
if self._attr:
41+
ret = self._attr.copy()
42+
if attr:
43+
ret.update(attr)
44+
return ret
45+
else:
46+
return attr
47+
48+
def __enter__(self):
49+
# pylint: disable=protected-access
50+
self._old_scope = AttrScope.current
51+
attr = AttrScope.current._attr.copy()
52+
attr.update(self._attr)
53+
self._attr = attr
54+
AttrScope.current = self
55+
return self
56+
57+
def __exit__(self, ptype, value, trace):
58+
assert self._old_scope
59+
AttrScope.current = self._old_scope
60+
61+
AttrScope.current = AttrScope()
62+

nnvm/python/nnvm/base.py

Lines changed: 189 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,189 @@
1+
# coding: utf-8
2+
# pylint: disable=invalid-name
3+
""" ctypes library of nnvm and helper functions """
4+
from __future__ import absolute_import
5+
6+
import sys
7+
import ctypes
8+
import numpy as np
9+
from . import libinfo
10+
11+
__all__ = ['NNNetError']
12+
#----------------------------
13+
# library loading
14+
#----------------------------
15+
if sys.version_info[0] == 3:
16+
string_types = str,
17+
numeric_types = (float, int, np.float32, np.int32)
18+
# this function is needed for python3
19+
# to convert ctypes.char_p .value back to python str
20+
py_str = lambda x: x.decode('utf-8')
21+
else:
22+
string_types = basestring,
23+
numeric_types = (float, int, long, np.float32, np.int32)
24+
py_str = lambda x: x
25+
26+
27+
class NNVMError(Exception):
28+
"""Error that will be throwed by all nnvm functions"""
29+
pass
30+
31+
def _load_lib():
32+
"""Load libary by searching possible path."""
33+
lib_path = libinfo.find_lib_path()
34+
lib = ctypes.cdll.LoadLibrary(lib_path[0])
35+
# DMatrix functions
36+
lib.NNGetLastError.restype = ctypes.c_char_p
37+
return lib
38+
39+
# version number
40+
__version__ = libinfo.__version__
41+
# library instance of nnvm
42+
_LIB = _load_lib()
43+
44+
# type definitions
45+
nn_uint = ctypes.c_uint
46+
SymbolCreatorHandle = ctypes.c_void_p
47+
SymbolHandle = ctypes.c_void_p
48+
GraphHandle = ctypes.c_void_p
49+
50+
#----------------------------
51+
# helper function definition
52+
#----------------------------
53+
def check_call(ret):
54+
"""Check the return value of C API call
55+
56+
This function will raise exception when error occurs.
57+
Wrap every API call with this function
58+
59+
Parameters
60+
----------
61+
ret : int
62+
return value from API calls
63+
"""
64+
if ret != 0:
65+
raise NNVMError(py_str(_LIB.NNGetLastError()))
66+
67+
def c_str(string):
68+
"""Create ctypes char * from a python string
69+
Parameters
70+
----------
71+
string : string type
72+
python string
73+
74+
Returns
75+
-------
76+
str : c_char_p
77+
A char pointer that can be passed to C API
78+
"""
79+
return ctypes.c_char_p(string.encode('utf-8'))
80+
81+
82+
def c_array(ctype, values):
83+
"""Create ctypes array from a python array
84+
85+
Parameters
86+
----------
87+
ctype : ctypes data type
88+
data type of the array we want to convert to
89+
90+
values : tuple or list
91+
data content
92+
93+
Returns
94+
-------
95+
out : ctypes array
96+
Created ctypes array
97+
"""
98+
return (ctype * len(values))(*values)
99+
100+
def ctypes2buffer(cptr, length):
101+
"""Convert ctypes pointer to buffer type.
102+
103+
Parameters
104+
----------
105+
cptr : ctypes.POINTER(ctypes.c_char)
106+
pointer to the raw memory region
107+
length : int
108+
the length of the buffer
109+
110+
Returns
111+
-------
112+
buffer : bytearray
113+
The raw byte memory buffer
114+
"""
115+
if not isinstance(cptr, ctypes.POINTER(ctypes.c_char)):
116+
raise TypeError('expected char pointer')
117+
res = bytearray(length)
118+
rptr = (ctypes.c_char * length).from_buffer(res)
119+
if not ctypes.memmove(rptr, cptr, length):
120+
raise RuntimeError('memmove failed')
121+
return res
122+
123+
def ctypes2numpy_shared(cptr, shape):
124+
"""Convert a ctypes pointer to a numpy array
125+
126+
The result numpy array shares the memory with the pointer
127+
128+
Parameters
129+
----------
130+
cptr : ctypes.POINTER(mx_float)
131+
pointer to the memory region
132+
133+
shape : tuple
134+
shape of target ndarray
135+
136+
Returns
137+
-------
138+
out : numpy_array
139+
A numpy array : numpy array
140+
"""
141+
if not isinstance(cptr, ctypes.POINTER(mx_float)):
142+
raise RuntimeError('expected float pointer')
143+
size = 1
144+
for s in shape:
145+
size *= s
146+
dbuffer = (mx_float * size).from_address(ctypes.addressof(cptr.contents))
147+
return np.frombuffer(dbuffer, dtype=np.float32).reshape(shape)
148+
149+
150+
def ctypes2docstring(num_args, arg_names, arg_types, arg_descs, remove_dup=True):
151+
"""Convert ctypes returned doc string information into parameters docstring.
152+
153+
num_args : nn_uint
154+
Number of arguments.
155+
156+
arg_names : ctypes.POINTER(ctypes.c_char_p)
157+
Argument names.
158+
159+
arg_types : ctypes.POINTER(ctypes.c_char_p)
160+
Argument type information.
161+
162+
arg_descs : ctypes.POINTER(ctypes.c_char_p)
163+
Argument description information.
164+
165+
remove_dup : boolean, optional
166+
Whether remove duplication or not.
167+
168+
Returns
169+
-------
170+
docstr : str
171+
Python docstring of parameter sections.
172+
"""
173+
param_keys = set()
174+
param_str = []
175+
for i in range(num_args.value):
176+
key = py_str(arg_names[i])
177+
if key in param_keys and remove_dup:
178+
continue
179+
param_keys.add(key)
180+
type_info = py_str(arg_types[i])
181+
ret = '%s : %s' % (key, type_info)
182+
if len(arg_descs[i]) != 0:
183+
ret += '\n ' + py_str(arg_descs[i])
184+
param_str.append(ret)
185+
doc_str = ('Parameters\n' +
186+
'----------\n' +
187+
'%s\n')
188+
doc_str = doc_str % ('\n'.join(param_str))
189+
return doc_str

0 commit comments

Comments
 (0)