Skip to content

Commit 5fced92

Browse files
tqchenicemelon
authored andcommitted
[LANG] Enable json load/save and pickle (#10)
1 parent 7250005 commit 5fced92

File tree

15 files changed

+521
-37
lines changed

15 files changed

+521
-37
lines changed

include/tvm/base.h

Lines changed: 38 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,41 @@ using ::tvm::Node;
2121
using ::tvm::NodeRef;
2222
using ::tvm::AttrVisitor;
2323

24+
/*!
25+
* \brief save the node as well as all the node it depends on as json.
26+
* This can be used to serialize any TVM object
27+
*
28+
* \return the string representation of the node.
29+
*/
30+
std::string SaveJSON(const NodeRef& node);
31+
32+
/*!
33+
* \brief Internal implementation of LoadJSON
34+
* Load tvm Node object from json and return a shared_ptr of Node.
35+
* \param json_str The json string to load from.
36+
*
37+
* \return The shared_ptr of the Node.
38+
*/
39+
std::shared_ptr<Node> LoadJSON_(std::string json_str);
40+
41+
/*!
42+
* \brief Load the node from json string.
43+
* This can be used to deserialize any TVM object.
44+
*
45+
* \param json_str The json string to load from.
46+
*
47+
* \tparam NodeType the nodetype
48+
*
49+
* \code
50+
* Expr e = LoadJSON<Expr>(json_str);
51+
* \endcode
52+
*/
53+
template<typename NodeType,
54+
typename = typename std::enable_if<std::is_base_of<NodeRef, NodeType>::value>::type >
55+
inline NodeType LoadJSON(const std::string& json_str) {
56+
return NodeType(LoadJSON_(json_str));
57+
}
58+
2459
/*! \brief typedef the factory function of data iterator */
2560
using NodeFactory = std::function<std::shared_ptr<Node> ()>;
2661
/*!
@@ -32,8 +67,9 @@ struct NodeFactoryReg
3267
};
3368

3469
#define TVM_REGISTER_NODE_TYPE(TypeName) \
35-
DMLC_REGISTRY_REGISTER(::tvm::NodeFactoryReg, NodeFactoryReg, TypeName) \
36-
.set_body([]() { return std::make_shared<TypeName>(); })
70+
static DMLC_ATTRIBUTE_UNUSED ::tvm::NodeFactoryReg & __make_Node ## _ ## TypeName ## __ = \
71+
::dmlc::Registry<::tvm::NodeFactoryReg>::Get()->__REGISTER__(TypeName::_type_key) \
72+
.set_body([]() { return std::make_shared<TypeName>(); })
3773

3874
} // namespace tvm
3975
#endif // TVM_BASE_H_

include/tvm/c_api.h

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,14 +15,15 @@
1515
/*! \brief TVM_DLL prefix for windows */
1616
#ifdef _WIN32
1717
#ifdef TVM_EXPORTS
18-
#define TVM_DLL TVM_EXTERN_C __declspec(dllexport)
18+
#define TVM_DLL __declspec(dllexport)
1919
#else
20-
#define TVM_DLL TVM_EXTERN_C __declspec(dllimport)
20+
#define TVM_DLL __declspec(dllimport)
2121
#endif
2222
#else
23-
#define TVM_DLL TVM_EXTERN_C
23+
#define TVM_DLL
2424
#endif
2525

26+
TVM_EXTERN_C {
2627
/*! \brief handle to functions */
2728
typedef void* FunctionHandle;
2829
/*! \brief handle to node */
@@ -147,5 +148,5 @@ TVM_DLL int TVMNodeGetAttr(NodeHandle handle,
147148
TVM_DLL int TVMNodeListAttrNames(NodeHandle handle,
148149
int *out_size,
149150
const char*** out_array);
150-
151+
} // TVM_EXTERN_C
151152
#endif // TVM_C_API_H_

python/tvm/_ctypes/_api.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,6 @@ def __getattr__(self, name):
8989
"'%s' object has no attribute '%s'" % (str(type(self)), name))
9090
return value
9191

92-
9392
def __hash__(self):
9493
return _function_internal._raw_ptr(self)
9594

@@ -111,6 +110,29 @@ def __dir__(self):
111110
names.append(py_str(plist[i]))
112111
return names
113112

113+
def __reduce__(self):
114+
return (type(self), (None,), self.__getstate__())
115+
116+
def __getstate__(self):
117+
handle = self.handle
118+
if handle is not None:
119+
return {'handle': _function_internal._save_json(self)}
120+
else:
121+
return {'handle': None}
122+
123+
def __setstate__(self, state):
124+
# pylint: disable=assigning-non-slot
125+
handle = state['handle']
126+
if handle is not None:
127+
json_str = handle
128+
_push_arg(json_str)
129+
other = _function_internal._load_json(json_str)
130+
self.handle = other.handle
131+
other.handle = None
132+
else:
133+
self.handle = None
134+
135+
114136
def const(value, dtype=None):
115137
"""construct a constant"""
116138
if dtype is None:

python/tvm/function.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,38 @@ def const(value, dtype=None):
1919
return _function_internal._const(value, dtype)
2020

2121

22+
def load_json(json_str):
23+
"""Load tvm object from json_str.
24+
25+
Parameters
26+
----------
27+
json_str : str
28+
The json string
29+
30+
Returns
31+
-------
32+
node : Node
33+
The loaded tvm node.
34+
"""
35+
return _function_internal._load_json(json_str)
36+
37+
38+
def save_json(node):
39+
"""Load tvm object as json string.
40+
41+
Parameters
42+
----------
43+
node : Node
44+
A TVM Node object to be saved.
45+
46+
Returns
47+
-------
48+
json_str : str
49+
Saved json string.
50+
"""
51+
return _function_internal._save_json(node)
52+
53+
2254
def Var(name="tindex", dtype=int32):
2355
"""Create a new variable with specified name and dtype
2456

src/base/common.h

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
/*!
2+
* Copyright (c) 2016 by Contributors
3+
* \file common.h
4+
* \brief Common utilities
5+
*/
6+
#ifndef TVM_BASE_COMMON_H_
7+
#define TVM_BASE_COMMON_H_
8+
9+
#include <tvm/base.h>
10+
#include <string>
11+
12+
namespace tvm {
13+
14+
inline std::string Type2String(const Type& t) {
15+
std::ostringstream os;
16+
os << t;
17+
return os.str();
18+
}
19+
20+
inline Type String2Type(std::string s) {
21+
std::istringstream is(s);
22+
halide_type_code_t code = Type::Int;
23+
if (s.substr(0, 3) == "int") {
24+
code = Type::Int; s = s.substr(3);
25+
} else if (s.substr(0, 4) == "uint") {
26+
code = Type::UInt; s = s.substr(4);
27+
} else if (s.substr(0, 5) == "float") {
28+
code = Type::Float; s = s.substr(5);
29+
} else if (s.substr(0, 5) == "float") {
30+
code = Type::Float; s = s.substr(5);
31+
} else {
32+
LOG(FATAL) << "unknown type " << s;
33+
}
34+
int bits = 32, lanes = 1;
35+
if (sscanf(s.c_str(), "%dx%d", &bits, &lanes) == 0) {
36+
LOG(FATAL) << "unknown type " << s;
37+
}
38+
return Type(code, bits, lanes);
39+
}
40+
41+
} // namespace tvm
42+
#endif // TVM_BASE_COMMON_H_

0 commit comments

Comments
 (0)