Skip to content

Commit 111b2da

Browse files
authored
RelayViz Graphviz renderer (#10400)
Following #10085, this PR adds a graphviz backend. It requires python `graphviz` package and `dot` executable in the PATH, similar to `tedd.py`. This implementation is much like a porting of `visualize` function in https://tvm.apache.org/2020/07/14/bert-pytorch-tvm, except that `node_attr_dict` is replaced with a callback `get_node_attr`. `get_node_attr` can be somehow used to emphasize a set of nodes. It might be useful if we encounter problems in inferences and want to find nodes with certain types and attributes. An example is provided in https://github.com/chiwwang/tvm/blob/graphviz_renderer_example/test_viz.py Its outputs are (conv2d with NCHW layout is green-colored): https://github.com/chiwwang/tvm/blob/graphviz_renderer_example/mod_with_subgraph.pdf https://github.com/chiwwang/tvm/blob/graphviz_renderer_example/mod_wo_subgraph.pdf
1 parent 92a80e9 commit 111b2da

File tree

6 files changed

+245
-7
lines changed

6 files changed

+245
-7
lines changed

docs/reference/api/python/contrib.rst

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -97,10 +97,12 @@ tvm.contrib.relay_viz
9797
~~~~~~~~~~~~~~~~~~~~~
9898
.. automodule:: tvm.contrib.relay_viz
9999
:members:
100-
.. automodule:: tvm.contrib.relay_viz.interface
100+
.. automodule:: tvm.contrib.relay_viz.dot
101101
:members:
102102
.. automodule:: tvm.contrib.relay_viz.terminal
103103
:members:
104+
.. automodule:: tvm.contrib.relay_viz.interface
105+
:members:
104106

105107

106108
tvm.contrib.rocblas

gallery/how_to/work_with_relay/using_relay_viz.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,8 @@
3232
Here we use a renderer rendering graph in the text-form.
3333
It is a lightweight, AST-like visualizer, inspired by `clang ast-dump <https://clang.llvm.org/docs/IntroductionToTheClangAST.html>`_.
3434
We will introduce how to implement customized parsers and renderers through interface classes.
35+
36+
For more details, please refer to :py:mod:`tvm.contrib.relay_viz`.
3537
"""
3638
from typing import (
3739
Dict,

python/tvm/contrib/relay_viz/__init__.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,10 @@
2727
TermPlotter,
2828
TermVizParser,
2929
)
30+
from .dot import (
31+
DotPlotter,
32+
DotVizParser,
33+
)
3034

3135

3236
class RelayVisualizer:
@@ -69,12 +73,16 @@ def __init__(
6973

7074
node_to_id = {}
7175
# callback to generate an unique string-ID for nodes.
76+
# node_count_offset ensure each node ID is still unique across subgraph.
77+
node_count_offset = 0
78+
7279
def traverse_expr(node):
7380
if node in node_to_id:
7481
return
75-
node_to_id[node] = str(len(node_to_id))
82+
node_to_id[node] = str(len(node_to_id) + node_count_offset)
7683

7784
for name in graph_names:
85+
node_count_offset += len(node_to_id)
7886
node_to_id.clear()
7987
relay.analysis.post_order_visit(relay_mod[name], traverse_expr)
8088
graph = self._plotter.create_graph(name)
Lines changed: 221 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,221 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
"""Visualize Relay IR by Graphviz DOT language."""
18+
19+
from typing import (
20+
Any,
21+
Callable,
22+
Dict,
23+
)
24+
from .interface import (
25+
DefaultVizParser,
26+
Plotter,
27+
VizEdge,
28+
VizGraph,
29+
VizNode,
30+
)
31+
32+
try:
33+
import graphviz
34+
except ImportError:
35+
# add "from None" to silence
36+
# "During handling of the above exception, another exception occurred"
37+
raise ImportError(
38+
"The graphviz package is required for DOT renderer. "
39+
"Please install it first. For example, pip3 install graphviz"
40+
) from None
41+
42+
DotVizParser = DefaultVizParser
43+
44+
45+
class DotGraph(VizGraph):
46+
"""DOT graph for relay IR.
47+
48+
See also :py:class:`tvm.contrib.relay_viz.dot.DotPlotter`
49+
50+
Parameters
51+
----------
52+
name: str
53+
name of this graph.
54+
graph_attr: Optional[Dict[str, str]]
55+
key-value pairs for the graph.
56+
node_attr: Optional[Dict[str, str]]
57+
key-value pairs for all nodes.
58+
edge_attr: Optional[Dict[str, str]]
59+
key-value pairs for all edges.
60+
get_node_attr: Optional[Callable[[VizNode], Dict[str, str]]]
61+
A callable returning attributes for the node.
62+
"""
63+
64+
def __init__(
65+
self,
66+
name: str,
67+
graph_attr: Dict[str, str] = None,
68+
node_attr: Dict[str, str] = None,
69+
edge_attr: Dict[str, str] = None,
70+
get_node_attr: Callable[[VizNode], Dict[str, str]] = None,
71+
):
72+
self._name = name
73+
self._get_node_attr = self._default_get_node_attr
74+
if get_node_attr is not None:
75+
self._get_node_attr = get_node_attr
76+
77+
# graphviz recognizes the subgraph as a cluster subgraph
78+
# by the name starting with "cluster" (all lowercase)
79+
self._digraph = graphviz.Digraph(
80+
name=f"cluster_{self._name}",
81+
graph_attr=graph_attr,
82+
node_attr=node_attr,
83+
edge_attr=edge_attr,
84+
)
85+
self._digraph.attr(label=self._name)
86+
87+
def node(self, viz_node: VizNode) -> None:
88+
"""Add a node to the underlying graph.
89+
Nodes in a Relay IR Module are expected to be added in the post-order.
90+
91+
Parameters
92+
----------
93+
viz_node : VizNode
94+
A `VizNode` instance.
95+
"""
96+
self._digraph.node(
97+
viz_node.identity,
98+
f"{viz_node.type_name}\n{viz_node.detail}",
99+
**self._get_node_attr(viz_node),
100+
)
101+
102+
def edge(self, viz_edge: VizEdge) -> None:
103+
"""Add an edge to the underlying graph.
104+
105+
Parameters
106+
----------
107+
viz_edge : VizEdge
108+
A `VizEdge` instance.
109+
"""
110+
self._digraph.edge(viz_edge.start, viz_edge.end)
111+
112+
@property
113+
def digraph(self):
114+
return self._digraph
115+
116+
@staticmethod
117+
def _default_get_node_attr(node: VizNode):
118+
if "Var" in node.type_name:
119+
return {"shape": "ellipse"}
120+
return {"shape": "box"}
121+
122+
123+
class DotPlotter(Plotter):
124+
"""DOT language graph plotter
125+
126+
The plotter accepts various graphviz attributes for graphs, nodes, and edges.
127+
Please refer to https://graphviz.org/doc/info/attrs.html for available attributes.
128+
129+
Parameters
130+
----------
131+
graph_attr: Optional[Dict[str, str]]
132+
key-value pairs for all graphs.
133+
node_attr: Optional[Dict[str, str]]
134+
key-value pairs for all nodes.
135+
edge_attr: Optional[Dict[str, str]]
136+
key-value pairs for all edges.
137+
get_node_attr: Optional[Callable[[VizNode], Dict[str, str]]]
138+
A callable returning attributes for a specific node.
139+
render_kwargs: Optional[Dict[str, Any]]
140+
keyword arguments directly passed to `graphviz.Digraph.render()`.
141+
142+
Examples
143+
--------
144+
145+
.. code-block:: python
146+
147+
from tvm.contrib import relay_viz
148+
from tvm.relay.testing import resnet
149+
150+
mod, param = resnet.get_workload(num_layers=18)
151+
# graphviz attributes
152+
graph_attr = {"color": "red"}
153+
node_attr = {"color": "blue"}
154+
edge_attr = {"color": "black"}
155+
156+
# VizNode is passed to the callback.
157+
# We want to color NCHW conv2d nodes. Also give Var a different shape.
158+
def get_node_attr(node):
159+
if "nn.conv2d" in node.type_name and "NCHW" in node.detail:
160+
return {
161+
"fillcolor": "green",
162+
"style": "filled",
163+
"shape": "box",
164+
}
165+
if "Var" in node.type_name:
166+
return {"shape": "ellipse"}
167+
return {"shape": "box"}
168+
169+
# Create plotter and pass it to viz. Then render the graph.
170+
dot_plotter = relay_viz.DotPlotter(
171+
graph_attr=graph_attr,
172+
node_attr=node_attr,
173+
edge_attr=edge_attr,
174+
get_node_attr=get_node_attr)
175+
176+
viz = relay_viz.RelayVisualizer(
177+
mod,
178+
relay_param=param,
179+
plotter=dot_plotter,
180+
parser=relay_viz.DotVizParser())
181+
viz.render("hello")
182+
"""
183+
184+
def __init__(
185+
self,
186+
graph_attr: Dict[str, str] = None,
187+
node_attr: Dict[str, str] = None,
188+
edge_attr: Dict[str, str] = None,
189+
get_node_attr: Callable[[VizNode], Dict[str, str]] = None,
190+
render_kwargs: Dict[str, Any] = None,
191+
):
192+
self._name_to_graph = {}
193+
self._graph_attr = graph_attr
194+
self._node_attr = node_attr
195+
self._edge_attr = edge_attr
196+
self._get_node_attr = get_node_attr
197+
198+
self._render_kwargs = {} if render_kwargs is None else render_kwargs
199+
200+
def create_graph(self, name):
201+
self._name_to_graph[name] = DotGraph(
202+
name, self._graph_attr, self._node_attr, self._edge_attr, self._get_node_attr
203+
)
204+
return self._name_to_graph[name]
205+
206+
def render(self, filename: str = None):
207+
"""render the graph generated from the Relay IR module.
208+
209+
This function is a thin wrapper of `graphviz.Digraph.render()`.
210+
"""
211+
# Create or update the filename
212+
if filename is not None:
213+
self._render_kwargs["filename"] = filename
214+
# default cleanup
215+
if "cleanup" not in self._render_kwargs:
216+
self._render_kwargs["cleanup"] = True
217+
218+
root_graph = graphviz.Digraph()
219+
for graph in self._name_to_graph.values():
220+
root_graph.subgraph(graph.digraph)
221+
root_graph.render(**self._render_kwargs)

python/tvm/contrib/relay_viz/interface.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ def __init__(self, node_id: str, node_type: str, node_detail: str):
4848
self._detail = node_detail
4949

5050
@property
51-
def identity(self) -> Union[int, str]:
51+
def identity(self) -> str:
5252
return self._id
5353

5454
@property
@@ -59,6 +59,10 @@ def type_name(self) -> str:
5959
def detail(self) -> str:
6060
return self._detail
6161

62+
def __repr__(self) -> str:
63+
detail = self._detail.replace("\n", ", ")
64+
return f"VizNode(identity: {self._id}, type_name: {self._type}, detail: {detail}"
65+
6266

6367
class VizEdge:
6468
"""VizEdge connect two `VizNode`.
@@ -139,7 +143,7 @@ def edge(self, viz_edge: VizEdge) -> None:
139143
140144
Parameters
141145
----------
142-
id_start : VizEdge
146+
viz_edge : VizEdge
143147
A `VizEdge` instance.
144148
"""
145149

@@ -277,7 +281,7 @@ def _tuple_get_item(
277281
node_id = node_to_id[node]
278282

279283
# Tuple -> TupleGetItemNode
280-
viz_node = VizNode(node_id, f"TupleGetItem", "idx: {node.index}")
284+
viz_node = VizNode(node_id, f"TupleGetItem", f"idx: {node.index}")
281285
viz_edges = [VizEdge(node_to_id[node.tuple_value], node_id)]
282286
return viz_node, viz_edges
283287

python/tvm/contrib/relay_viz/terminal.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,10 +31,11 @@
3131
VizEdge,
3232
VizGraph,
3333
VizNode,
34+
VizParser,
3435
)
3536

3637

37-
class TermVizParser(DefaultVizParser):
38+
class TermVizParser(VizParser):
3839
"""`TermVizParser` parse nodes and edges for `TermPlotter`."""
3940

4041
def __init__(self):
@@ -166,7 +167,7 @@ def edge(self, viz_edge: VizEdge) -> None:
166167
167168
Parameters
168169
----------
169-
id_start : VizEdge
170+
viz_edge : VizEdge
170171
A `VizEdge` instance.
171172
"""
172173
# Take CallNode as an example, instead of "arguments point to CallNode",

0 commit comments

Comments
 (0)