Skip to content

Commit c3370fe

Browse files
authored
feat: Add option to render a subgraph of a hugr. (#2497)
For moderately sized hugrs, rendering the whole hugr produces sprawling images that are difficult to analyse. This feature allows one to select a region to render. Example output: [h_10.pdf](https://github.com/user-attachments/files/21509633/h_10.pdf)
1 parent 9dd4af8 commit c3370fe

File tree

4 files changed

+341
-17
lines changed

4 files changed

+341
-17
lines changed

hugr-py/src/hugr/hugr/base.py

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1152,21 +1152,32 @@ def load_json(cls, json_str: str) -> Hugr:
11521152
serial = SerialHugr.load_json(json_dict)
11531153
return cls._from_serial(serial)
11541154

1155-
def render_dot(self, config: RenderConfig | None = None) -> gv.Digraph:
1155+
def render_dot(
1156+
self, config: RenderConfig | None = None, root: Node | None = None
1157+
) -> gv.Digraph:
11561158
"""Render the HUGR to a graphviz Digraph.
11571159
11581160
Args:
11591161
config: Render configuration.
1162+
root: Root node defining the set of nodes to render. By default this is the
1163+
module root and all nodes are rendered. If this is a container node, all
1164+
nodes under it are rendered. Every incoming edge to the rendered set and
1165+
outgoing edge from it is also shown, with its other endpoint labelled
1166+
with its node index.
11601167
11611168
Returns:
11621169
The graphviz Digraph.
11631170
"""
11641171
from .render import DotRenderer
11651172

1166-
return DotRenderer(config).render(self)
1173+
return DotRenderer(config).render(self, root)
11671174

11681175
def store_dot(
1169-
self, filename: str, format: str = "svg", config: RenderConfig | None = None
1176+
self,
1177+
filename: str,
1178+
format: str = "svg",
1179+
config: RenderConfig | None = None,
1180+
root: Node | None = None,
11701181
) -> None:
11711182
"""Render the HUGR to a graphviz dot file.
11721183
@@ -1175,7 +1186,12 @@ def store_dot(
11751186
format: The format used for rendering ('pdf', 'png', etc.).
11761187
Defaults to SVG.
11771188
config: Render configuration.
1189+
root: Root node defining the set of nodes to render. By default this is the
1190+
module root and all nodes are rendered. If this is a container node, all
1191+
nodes under it are rendered. Every incoming edge to the rendered set and
1192+
outgoing edge from it is also shown, with its other endpoint labelled
1193+
with its node index.
11781194
"""
11791195
from .render import DotRenderer
11801196

1181-
DotRenderer(config).store(self, filename=filename, format=format)
1197+
DotRenderer(config).store(self, filename=filename, format=format, root=root)

hugr-py/src/hugr/hugr/render.py

Lines changed: 54 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -92,42 +92,60 @@ class DotRenderer:
9292

9393
def __init__(self, config: RenderConfig | None = None) -> None:
9494
self.config = config or RenderConfig()
95+
self.nodes: set[Node] = set()
9596

96-
def render(self, hugr: Hugr) -> Digraph:
97-
"""Render a HUGR to a graphviz dot object."""
97+
def render(self, hugr: Hugr, root: Node | None = None) -> Digraph:
98+
"""Render a HUGR to a graphviz dot object.
99+
100+
Args:
101+
hugr: The HUGR to render.
102+
root: Root node defining the set of nodes to render. By default this is the
103+
module root and all nodes are rendered. If this is a container node, all
104+
nodes under it are rendered. Every incoming edge to the rendered set and
105+
outgoing edge from it is also shown, with its other endpoint labelled
106+
with its node index.
107+
"""
108+
root = root or hugr.module_root
98109
graph_attr = {
99110
"rankdir": "",
100111
"ranksep": "0.1",
101112
"nodesep": "0.15",
102113
"margin": "0",
103114
"bgcolor": self.config.palette.background,
104115
}
105-
if name := hugr[hugr.module_root].metadata.get("name", None):
116+
if name := hugr[root].metadata.get("name", None):
106117
name = html.escape(str(name))
107118
else:
108119
name = ""
109120

110121
graph = gv.Digraph(name, strict=False)
111122
graph.attr(**graph_attr)
112123

113-
self._viz_node(hugr.module_root, hugr, graph)
124+
self._viz_node(root, hugr, graph)
114125

115126
for src_port, tgt_port in hugr.links():
116127
kind = hugr.port_kind(src_port)
117128
self._viz_link(src_port, tgt_port, kind, graph)
118129

119130
return graph
120131

121-
def store(self, hugr: Hugr, filename: str, format: str = "svg") -> None:
132+
def store(
133+
self, hugr: Hugr, filename: str, format: str = "svg", root: Node | None = None
134+
) -> None:
122135
"""Render a HUGR and save it to a file.
123136
124137
Args:
125138
hugr: The HUGR to render.
126139
filename: Filename for saving the rendered graph.
127140
format: The format used for rendering ('pdf', 'png', etc.).
128141
Defaults to SVG.
142+
root: Root node defining the set of nodes to render. By default this is the
143+
module root and all nodes are rendered. If this is a container node, all
144+
nodes under it are rendered. Every incoming edge to the rendered set and
145+
outgoing edge from it is also shown, with its other endpoint labelled
146+
with its node index.
129147
"""
130-
gv_graph = self.render(hugr)
148+
gv_graph = self.render(hugr, root=root)
131149
gv_graph.render(filename, format=format)
132150

133151
_FONTFACE = "monospace"
@@ -275,6 +293,7 @@ def _viz_node(self, node: Node, hugr: Hugr, graph: Digraph) -> None:
275293
else:
276294
html_label = self._format_html_label(**label_config)
277295
graph.node(f"{node.idx}", label=f"<{html_label}>", shape="plain")
296+
self.nodes.add(node)
278297

279298
def _viz_link(
280299
self, src_port: OutPort, tgt_port: InPort, kind: Kind, graph: Digraph
@@ -302,10 +321,32 @@ def _viz_link(
302321
case _:
303322
assert_never(kind)
304323

305-
graph.edge(
306-
self._out_port_name(src_port),
307-
self._in_port_name(tgt_port),
308-
label=label,
309-
color=color,
310-
**edge_attr,
311-
)
324+
src = self._out_port_name(src_port)
325+
tgt = self._in_port_name(tgt_port)
326+
327+
unknown_src = src_port.node not in self.nodes
328+
unknown_tgt = tgt_port.node not in self.nodes
329+
if unknown_src and unknown_tgt:
330+
return
331+
if unknown_src:
332+
src = f"{src_port.node.idx}"
333+
html_label = self._format_html_label(
334+
node_back_color=self.config.palette.node,
335+
node_label=f"{src_port.node}",
336+
node_data="",
337+
border_colour=self.config.palette.background,
338+
border_width="1",
339+
)
340+
graph.node(src, label=f"<{html_label}>", shape="plain")
341+
if unknown_tgt:
342+
tgt = f"{tgt_port.node.idx}"
343+
html_label = self._format_html_label(
344+
node_back_color=self.config.palette.node,
345+
node_label=f"{tgt_port.node}",
346+
node_data="",
347+
border_colour=self.config.palette.background,
348+
border_width="1",
349+
)
350+
graph.node(tgt, label=f"<{html_label}>", shape="plain")
351+
352+
graph.edge(src, tgt, label=label, color=color, **edge_attr)

0 commit comments

Comments
 (0)