Skip to content

Commit 343c19a

Browse files
committed
[COMPILER] GraphHash based cache system, allow dump and query duplicated functions. (apache#30)
1 parent 300ae30 commit 343c19a

File tree

19 files changed

+856
-149
lines changed

19 files changed

+856
-149
lines changed

nnvm/include/nnvm/graph.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -63,11 +63,11 @@ class Graph {
6363
* \return The indexed graph.
6464
* \sa IndexedGraph
6565
*/
66-
const IndexedGraph& indexed_graph();
66+
const IndexedGraph& indexed_graph() const;
6767

6868
private:
6969
// internal structure of indexed graph
70-
std::shared_ptr<const IndexedGraph> indexed_graph_;
70+
mutable std::shared_ptr<const IndexedGraph> indexed_graph_;
7171
};
7272

7373
/*!

nnvm/include/nnvm/pass_functions.h

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,17 @@ inline std::string SaveJSON(Graph graph) {
4141
return ret.GetAttr<std::string>("json");
4242
}
4343

44+
45+
/*!
46+
* \brief Print graph ir
47+
* \param graph The graph to be printed
48+
* \return The graph ir string.
49+
*/
50+
inline std::string PrintGraphIR(Graph graph) {
51+
Graph ret = ApplyPass(std::move(graph), "PrintGraphIR");
52+
return ret.GetAttr<std::string>("graphir");
53+
}
54+
4455
/*!
4556
* \brief Add control flow dependencies between nodes.
4657
*

nnvm/python/nnvm/compiler/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
from . import build_module
77
from . build_module import build, optimize, build_config
8+
from . compile_engine import engine, graph_key
89

910
from .. import symbol as _symbol
1011
from .. import graph as _graph
@@ -14,5 +15,6 @@
1415

1516
from .. import top as _top
1617

18+
1719
tvm.register_extension(_symbol.Symbol, _symbol.Symbol)
1820
tvm.register_extension(_graph.Graph, _graph.Graph)

nnvm/python/nnvm/compiler/build_module.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -184,7 +184,7 @@ def build(graph, target, shape, dtype="float32", params=None):
184184
graph._set_json_attr("target", target, "str")
185185
graph._set_json_attr("opt_level", cfg.opt_level, "int")
186186
graph = graph.apply("InferShape").apply("InferType")
187-
graph = graph.apply("GraphFusePartition").apply("GraphFuse")
187+
graph = graph.apply("GraphFusePartition").apply("GraphFuseCompile")
188188
libmod = graph_attr._move_out_module(graph, "module")
189189
return graph, libmod, params
190190

Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,99 @@
1+
# pylint: disable=invalid-name
2+
"""Compiler engine interface to internal engine"""
3+
import tvm
4+
5+
_list_cache_items = tvm.get_global_func("nnvm.compiler.ListCacheItems")
6+
_clear_cache = tvm.get_global_func("nnvm.compiler.ClearCache")
7+
_get_cache_item = tvm.get_global_func("nnvm.compiler.GetCacheItem")
8+
_set_cache_item = tvm.get_global_func("nnvm.compiler.SetCacheItem")
9+
_graph_key_get_graph = tvm.get_global_func("nnvm.compiler.GraphKeyGetGraph")
10+
_make_graph_key = tvm.get_global_func("nnvm.compiler.MakeGraphKey")
11+
12+
@tvm.register_node
13+
class GraphKey(tvm.node.NodeBase):
14+
"""Key of a graph compilation context"""
15+
@property
16+
def graph(self):
17+
return _graph_key_get_graph(self)
18+
19+
20+
@tvm.register_node
21+
class GraphCacheEntry(tvm.node.NodeBase):
22+
"""CacheEntry of compilation into a TVM Function"""
23+
pass
24+
25+
26+
@tvm.register_node
27+
class GraphFunc(tvm.node.NodeBase):
28+
"""Compiled result of a graph into a TVM Function"""
29+
pass
30+
31+
32+
class Engine(object):
33+
"""Global singleton compilation engine."""
34+
def items(self):
35+
"""List the available cache key value pairs.
36+
37+
Returns
38+
-------
39+
item_list : list of (GraphKey, GraphCacheEntry)
40+
The existing cache items
41+
"""
42+
res = _list_cache_items()
43+
assert len(res) % 2 == 0
44+
return [(res[2*i], res[2*i+1]) for i in range(len(res)/2)]
45+
46+
def clear_cache(self):
47+
"""Clear the existing cached functions."""
48+
_clear_cache()
49+
50+
def __setitem__(self, key, value):
51+
"""Clear the existing cached functions."""
52+
if isinstance(value, GraphCacheEntry):
53+
_set_cache_item(key, value.graph_func)
54+
else:
55+
_set_cache_item(key, value)
56+
57+
def __getitem__(self, key):
58+
"""Clear the existing cached functions."""
59+
return _get_cache_item(key)
60+
61+
def dump(self):
62+
"""Return a string representation of engine dump
63+
64+
Returns
65+
-------
66+
dump : str
67+
The dumped string representation
68+
"""
69+
items = self.items()
70+
res = "====================================\n"
71+
res += "CompilerEngine dump, %d items cached\n" % len(items)
72+
for key, value in items:
73+
res += "------------------------------------\n"
74+
res += "target={}\n".format(key.target)
75+
res += "inputs={}\n".format(key.inputs)
76+
res += "use_count={}\n".format(value.use_count)
77+
res += "func_name={}\n".format(value.graph_func.func_name)
78+
res += key.graph.ir() + "\n"
79+
res += "===================================\n"
80+
return res
81+
82+
engine = Engine()
83+
84+
85+
def graph_key(graph, inputs, target):
86+
"""Construct a new graph key.
87+
88+
Parameters
89+
----------
90+
graph : Graph
91+
The computation graph structure
92+
93+
inputs : list of Tensor(placeholder)
94+
The input requirement to the graph.
95+
96+
target : str
97+
The target of compilation.
98+
"""
99+
return _make_graph_key(graph, inputs, target)
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1 +1,3 @@
11
"""Utilities for testcase"""
2+
3+
from .config import ctx_list

nnvm/python/nnvm/testing/config.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import os
33
import tvm
44

5-
def test_ctx_list():
5+
def ctx_list():
66
"""Get context list for testcases"""
77
device_list = os.environ.get("NNVM_TEST_TARGETS", "")
88
device_list = (device_list.split(",") if device_list

0 commit comments

Comments
 (0)