Skip to content

Commit fd89386

Browse files
committed
Enable use json for graph attr exchange (#5)
1 parent 877d0c2 commit fd89386

File tree

6 files changed

+71
-29
lines changed

6 files changed

+71
-29
lines changed

nnvm/include/nnvm/c_api.h

Lines changed: 18 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -248,27 +248,34 @@ NNVM_DLL int NNGraphFree(GraphHandle handle);
248248
*/
249249
NNVM_DLL int NNGraphGetSymbol(GraphHandle graph, SymbolHandle *symbol);
250250
/*!
251-
* \brief Get Set a std::string typed attribute to graph.
251+
* \brief Get Set a attribute in json format.
252+
* This feature allows pass graph attributes back and forth in reasonable speed.
253+
*
252254
* \param handle The graph handle.
253255
* \param key The key to the attribute.
254-
* \param value The value to be exposed.
256+
* \param json_value The value need to be in format [type_name, value],
257+
* Where type_name is a registered type string in C++ side via DMLC_JSON_ENABLE_ANY.
255258
* \return 0 when success, -1 when failure happens
256259
*/
257-
NNVM_DLL int NNGraphSetStrAttr(GraphHandle handle,
258-
const char* key,
259-
const char* value);
260+
NNVM_DLL int NNGraphSetJSONAttr(GraphHandle handle,
261+
const char* key,
262+
const char* json_value);
260263
/*!
261-
* \brief Get Set a std::string typed attribute from graph attribute.
264+
* \brief Get a serialized attrirbute from graph.
265+
* This feature allows pass graph attributes back and forth in reasonable speed.
266+
*
262267
* \param handle The graph handle.
263268
* \param key The key to the attribute.
264-
* \param out The result attribute, can be NULL if the attribute do not exist.
269+
* \param json_out The result attribute, can be NULL if the attribute do not exist.
270+
* The json_out is an array of [type_name, value].
271+
* Where the type_name is a registered type string in C++ side via DMLC_JSON_ENABLE_ANY.
265272
* \param success Whether the result is contained in out.
266273
* \return 0 when success, -1 when failure happens
267274
*/
268-
NNVM_DLL int NNGraphGetStrAttr(SymbolHandle handle,
269-
const char* key,
270-
const char** out,
271-
int *success);
275+
NNVM_DLL int NNGraphGetJSONAttr(SymbolHandle handle,
276+
const char* key,
277+
const char** json_out,
278+
int *success);
272279
/*!
273280
* \brief Apply pass on the src graph.
274281
* \param src The source graph handle.

nnvm/python/nnvm/base.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ def _load_lib():
4747
SymbolHandle = ctypes.c_void_p
4848
GraphHandle = ctypes.c_void_p
4949

50+
5051
#----------------------------
5152
# helper function definition
5253
#----------------------------

nnvm/python/nnvm/graph.py

Lines changed: 20 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,14 @@
55

66
import ctypes
77
import sys
8+
import json
89
from .base import _LIB
910
from .base import c_array, c_str, nn_uint, py_str, string_types
1011
from .base import GraphHandle, SymbolHandle
1112
from .base import check_call
1213
from .symbol import Symbol
1314

15+
1416
class Graph(object):
1517
"""Graph is the graph object that can be used to apply optimization pass.
1618
It contains additional graphwise attribute besides the internal symbol.
@@ -31,7 +33,7 @@ def __init__(self, handle):
3133
def __del__(self):
3234
check_call(_LIB.NNGraphFree(self.handle))
3335

34-
def attr(self, key):
36+
def json_attr(self, key):
3537
"""Get attribute string from the graph.
3638
3739
Parameters
@@ -46,24 +48,33 @@ def attr(self, key):
4648
"""
4749
ret = ctypes.c_char_p()
4850
success = ctypes.c_int()
49-
check_call(_LIB.NNGraphGetStrAttr(
51+
check_call(_LIB.NNGraphGetJSONAttr(
5052
self.handle, c_str(key), ctypes.byref(ret), ctypes.byref(success)))
5153
if success.value != 0:
52-
return py_str(ret.value)
54+
json_str = py_str(ret.value)
55+
return json.loads(json_str)[1]
5356
else:
5457
return None
5558

56-
def _set_attr(self, **kwargs):
59+
def _set_json_attr(self, key, value, type_name=None):
5760
"""Set the attribute of the symbol.
5861
5962
Parameters
6063
----------
61-
**kwargs
62-
The attributes to set
64+
key : string
65+
The key of the attribute
66+
value : value
67+
The any type that can be dumped to json
68+
type_name : string
69+
The typename registered on c++ side.
6370
"""
64-
for k, v in kwargs.items():
65-
check_call(_LIB.NNGraphSetStrAttr(
66-
self.handle, c_str(k), c_str(v)))
71+
if isinstance(value, string_types):
72+
type_name = 'str'
73+
elif type_name is None:
74+
raise ValueError("Need to specify type_name")
75+
json_value = json.dumps([type_name, value])
76+
check_call(_LIB.NNGraphSetJSONAttr(
77+
self.handle, c_str(key), c_str(json_value)))
6778

6879
@property
6980
def symbol(self):

nnvm/src/c_api/c_api_graph.cc

Lines changed: 18 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
#include <nnvm/symbolic.h>
99
#include <nnvm/graph.h>
1010
#include <nnvm/pass.h>
11+
#include <dmlc/json.h>
1112
#include "./c_api_common.h"
1213

1314
using namespace nnvm;
@@ -34,26 +35,35 @@ int NNGraphGetSymbol(GraphHandle graph, SymbolHandle *symbol) {
3435
API_END_HANDLE_ERROR(delete s);
3536
}
3637

37-
int NNGraphSetStrAttr(GraphHandle handle,
38-
const char* key,
39-
const char* value) {
38+
int NNGraphSetJSONAttr(GraphHandle handle,
39+
const char* key,
40+
const char* json_value) {
4041
API_BEGIN();
4142
Graph* g = static_cast<Graph*>(handle);
42-
g->attrs[std::string(key)] = std::make_shared<any>(std::string(value));
43+
std::string temp(json_value);
44+
std::istringstream is(temp);
45+
dmlc::JSONReader reader(&is);
46+
nnvm::any value;
47+
reader.Read(&value);
48+
g->attrs[std::string(key)] = std::make_shared<any>(std::move(value));
4349
API_END();
4450
}
4551

46-
int NNGraphGetStrAttr(GraphHandle handle,
52+
int NNGraphGetJSONAttr(GraphHandle handle,
4753
const char* key,
48-
const char** out,
54+
const char** json_out,
4955
int *success) {
56+
NNAPIThreadLocalEntry *ret = NNAPIThreadLocalStore::Get();
5057
API_BEGIN();
5158
Graph* g = static_cast<Graph*>(handle);
5259
std::string skey(key);
5360
auto it = g->attrs.find(skey);
5461
if (it != g->attrs.end()) {
55-
const std::string& str = nnvm::get<std::string>(*it->second.get());
56-
*out = str.c_str();
62+
std::ostringstream os;
63+
dmlc::JSONWriter writer(&os);
64+
writer.Write(*it->second.get());
65+
ret->ret_str = os.str();
66+
*json_out = (ret->ret_str).c_str();
5767
*success = 1;
5868
} else {
5969
*success = 0;

nnvm/src/pass/saveload_json.cc

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -203,5 +203,9 @@ NNVM_REGISTER_PASS(SaveJSON)
203203
.set_change_graph(true)
204204
.provide_graph_attr("json");
205205

206+
207+
DMLC_JSON_ENABLE_ANY(std::string, str);
208+
DMLC_JSON_ENABLE_ANY(std::vector<int>, list_int);
209+
206210
} // namespace pass
207211
} // namespace nnvm

nnvm/tests/python/test_graph.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,18 @@ def test_json_pass():
66
y = sym.conv2d(data=x, name='conv', stride=(2,2))
77
g = graph.create(y)
88
ret = g.apply('SaveJSON')
9+
ret._set_json_attr('json', ret.json_attr('json'))
910
g2 = ret.apply('LoadJSON')
10-
assert g2.apply('SaveJSON').attr('json') == ret.attr('json')
11+
assert g2.apply('SaveJSON').json_attr('json') == ret.json_attr('json')
12+
13+
def test_graph_json_attr():
14+
x = sym.Variable('x')
15+
y = sym.conv2d(data=x, name='conv', stride=(2,2))
16+
g = graph.create(y)
17+
g._set_json_attr('ilist', [1,2,3], 'list_int')
18+
assert g.json_attr('ilist') == [1,2,3]
1119

1220

1321
if __name__ == "__main__":
22+
test_graph_json_attr()
1423
test_json_pass()

0 commit comments

Comments
 (0)