|
| 1 | +# pylint: disable=invalid-name |
| 2 | +"""Compiler engine interface to internal engine""" |
| 3 | +import tvm |
| 4 | + |
| 5 | +_list_cache_items = tvm.get_global_func("nnvm.compiler.ListCacheItems") |
| 6 | +_clear_cache = tvm.get_global_func("nnvm.compiler.ClearCache") |
| 7 | +_get_cache_item = tvm.get_global_func("nnvm.compiler.GetCacheItem") |
| 8 | +_set_cache_item = tvm.get_global_func("nnvm.compiler.SetCacheItem") |
| 9 | +_graph_key_get_graph = tvm.get_global_func("nnvm.compiler.GraphKeyGetGraph") |
| 10 | +_make_graph_key = tvm.get_global_func("nnvm.compiler.MakeGraphKey") |
| 11 | + |
@tvm.register_node
class GraphKey(tvm.node.NodeBase):
    """Key of a graph compilation context"""
    @property
    def graph(self):
        """The nnvm Graph this key was built from."""
        wrapped = _graph_key_get_graph(self)
        return wrapped
| 18 | + |
| 19 | + |
@tvm.register_node
class GraphCacheEntry(tvm.node.NodeBase):
    """CacheEntry of compilation into a TVM Function.

    Pure node wrapper: all fields live on the C++ side and are
    accessed through NodeBase attribute lookup.
    """
| 24 | + |
| 25 | + |
@tvm.register_node
class GraphFunc(tvm.node.NodeBase):
    """Compiled result of a graph into a TVM Function.

    Pure node wrapper: all fields live on the C++ side and are
    accessed through NodeBase attribute lookup.
    """
| 30 | + |
| 31 | + |
class Engine(object):
    """Global singleton compilation engine.

    Thin Python facade over the C++ compilation cache exposed via the
    ``nnvm.compiler.*`` global functions.
    """
    def items(self):
        """List the available cache key value pairs.

        Returns
        -------
        item_list : list of (GraphKey, GraphCacheEntry)
            The existing cache items
        """
        res = _list_cache_items()
        # The C++ side returns a flat [key0, entry0, key1, entry1, ...] list.
        assert len(res) % 2 == 0
        # Floor division is required: len(res) / 2 is a float on Python 3
        # and range() raises TypeError on non-integer arguments.
        return [(res[2*i], res[2*i+1]) for i in range(len(res) // 2)]

    def clear_cache(self):
        """Clear the existing cached functions."""
        _clear_cache()

    def __setitem__(self, key, value):
        """Store a cache entry for the given graph key.

        Parameters
        ----------
        key : GraphKey
            The key of the compilation context.
        value : GraphCacheEntry or GraphFunc
            The value to cache. A GraphCacheEntry is unwrapped to its
            underlying graph_func before being stored.
        """
        if isinstance(value, GraphCacheEntry):
            _set_cache_item(key, value.graph_func)
        else:
            _set_cache_item(key, value)

    def __getitem__(self, key):
        """Fetch the cache entry stored for the given graph key.

        Parameters
        ----------
        key : GraphKey
            The key of the compilation context.

        Returns
        -------
        entry : GraphCacheEntry
            The cached compilation result.
        """
        return _get_cache_item(key)

    def dump(self):
        """Return a string representation of engine dump

        Returns
        -------
        dump : str
            The dumped string representation
        """
        items = self.items()
        res = "====================================\n"
        res += "CompilerEngine dump, %d items cached\n" % len(items)
        for key, value in items:
            res += "------------------------------------\n"
            res += "target={}\n".format(key.target)
            res += "inputs={}\n".format(key.inputs)
            res += "use_count={}\n".format(value.use_count)
            res += "func_name={}\n".format(value.graph_func.func_name)
            res += key.graph.ir() + "\n"
        res += "===================================\n"
        return res
| 81 | + |
| 82 | +engine = Engine() |
| 83 | + |
| 84 | + |
def graph_key(graph, inputs, target):
    """Construct a new graph key.

    Parameters
    ----------
    graph : Graph
        The computation graph structure

    inputs : list of Tensor(placeholder)
        The input requirement to the graph.

    target : str
        The target of compilation.

    Returns
    -------
    key : GraphKey
        The constructed key for the compilation context.
    """
    key = _make_graph_key(graph, inputs, target)
    return key