Skip to content

Commit 5eb42ea

Browse files
committed
Add option to render a subgraph of a hugr.
1 parent 9dd4af8 commit 5eb42ea

File tree

2 files changed

+54
-17
lines changed

2 files changed

+54
-17
lines changed

hugr-py/src/hugr/hugr/base.py

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1152,21 +1152,32 @@ def load_json(cls, json_str: str) -> Hugr:
11521152
serial = SerialHugr.load_json(json_dict)
11531153
return cls._from_serial(serial)
11541154

1155-
def render_dot(self, config: RenderConfig | None = None) -> gv.Digraph:
1155+
def render_dot(
1156+
self, config: RenderConfig | None = None, root: Node | None = None
1157+
) -> gv.Digraph:
11561158
"""Render the HUGR to a graphviz Digraph.
11571159
11581160
Args:
11591161
config: Render configuration.
1162+
root: Root node defining the set of nodes to render. By default this is the
1163+
module root and all nodes are rendered. If this is a container node, all
1164+
nodes under it are rendered. Every incoming edge to the rendered set and
1165+
outgoing edge from it is also shown, with its other endpoint labelled
1166+
with its node index.
11601167
11611168
Returns:
11621169
The graphviz Digraph.
11631170
"""
11641171
from .render import DotRenderer
11651172

1166-
return DotRenderer(config).render(self)
1173+
return DotRenderer(config).render(self, root)
11671174

11681175
def store_dot(
1169-
self, filename: str, format: str = "svg", config: RenderConfig | None = None
1176+
self,
1177+
filename: str,
1178+
format: str = "svg",
1179+
config: RenderConfig | None = None,
1180+
root: Node | None = None,
11701181
) -> None:
11711182
"""Render the HUGR to a graphviz dot file.
11721183
@@ -1175,7 +1186,12 @@ def store_dot(
11751186
format: The format used for rendering ('pdf', 'png', etc.).
11761187
Defaults to SVG.
11771188
config: Render configuration.
1189+
root: Root node defining the set of nodes to render. By default this is the
1190+
module root and all nodes are rendered. If this is a container node, all
1191+
nodes under it are rendered. Every incoming edge to the rendered set and
1192+
outgoing edge from it is also shown, with its other endpoint labelled
1193+
with its node index.
11781194
"""
11791195
from .render import DotRenderer
11801196

1181-
DotRenderer(config).store(self, filename=filename, format=format)
1197+
DotRenderer(config).store(self, filename=filename, format=format, root=root)

hugr-py/src/hugr/hugr/render.py

Lines changed: 34 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -92,42 +92,60 @@ class DotRenderer:
9292

9393
def __init__(self, config: RenderConfig | None = None) -> None:
9494
self.config = config or RenderConfig()
95+
self.nodes = set()
9596

96-
def render(self, hugr: Hugr) -> Digraph:
97-
"""Render a HUGR to a graphviz dot object."""
97+
def render(self, hugr: Hugr, root: Node | None = None) -> Digraph:
98+
"""Render a HUGR to a graphviz dot object.
99+
100+
Args:
101+
hugr: The HUGR to render.
102+
root: Root node defining the set of nodes to render. By default this is the
103+
module root and all nodes are rendered. If this is a container node, all
104+
nodes under it are rendered. Every incoming edge to the rendered set and
105+
outgoing edge from it is also shown, with its other endpoint labelled
106+
with its node index.
107+
"""
108+
root = root or hugr.module_root
98109
graph_attr = {
99110
"rankdir": "",
100111
"ranksep": "0.1",
101112
"nodesep": "0.15",
102113
"margin": "0",
103114
"bgcolor": self.config.palette.background,
104115
}
105-
if name := hugr[hugr.module_root].metadata.get("name", None):
116+
if name := hugr[root].metadata.get("name", None):
106117
name = html.escape(str(name))
107118
else:
108119
name = ""
109120

110121
graph = gv.Digraph(name, strict=False)
111122
graph.attr(**graph_attr)
112123

113-
self._viz_node(hugr.module_root, hugr, graph)
124+
self._viz_node(root, hugr, graph)
114125

115126
for src_port, tgt_port in hugr.links():
116127
kind = hugr.port_kind(src_port)
117128
self._viz_link(src_port, tgt_port, kind, graph)
118129

119130
return graph
120131

121-
def store(self, hugr: Hugr, filename: str, format: str = "svg") -> None:
132+
def store(
133+
self, hugr: Hugr, filename: str, format: str = "svg", root: Node | None = None
134+
) -> None:
122135
"""Render a HUGR and save it to a file.
123136
124137
Args:
125138
hugr: The HUGR to render.
126139
filename: Filename for saving the rendered graph.
127140
format: The format used for rendering ('pdf', 'png', etc.).
128141
Defaults to SVG.
142+
root: Root node defining the set of nodes to render. By default this is the
143+
module root and all nodes are rendered. If this is a container node, all
144+
nodes under it are rendered. Every incoming edge to the rendered set and
145+
outgoing edge from it is also shown, with its other endpoint labelled
146+
with its node index.
129147
"""
130-
gv_graph = self.render(hugr)
148+
gv_graph = self.render(hugr, root=root)
131149
gv_graph.render(filename, format=format)
132150

133151
_FONTFACE = "monospace"
@@ -266,6 +284,7 @@ def _viz_node(self, node: Node, hugr: Hugr, graph: Digraph) -> None:
266284
self._viz_node(child, hugr, sub)
267285
html_label = self._format_html_label(**label_config)
268286
sub.node(f"{node.idx}", shape="plain", label=f"<{html_label}>")
287+
self.nodes.add(node)
269288
sub.attr(
270289
label="",
271290
margin="10",
@@ -275,6 +294,7 @@ def _viz_node(self, node: Node, hugr: Hugr, graph: Digraph) -> None:
275294
else:
276295
html_label = self._format_html_label(**label_config)
277296
graph.node(f"{node.idx}", label=f"<{html_label}>", shape="plain")
297+
self.nodes.add(node)
278298

279299
def _viz_link(
280300
self, src_port: OutPort, tgt_port: InPort, kind: Kind, graph: Digraph
@@ -302,10 +322,11 @@ def _viz_link(
302322
case _:
303323
assert_never(kind)
304324

305-
graph.edge(
306-
self._out_port_name(src_port),
307-
self._in_port_name(tgt_port),
308-
label=label,
309-
color=color,
310-
**edge_attr,
311-
)
325+
if src_port.node in self.nodes or tgt_port.node in self.nodes:
326+
graph.edge(
327+
self._out_port_name(src_port),
328+
self._in_port_name(tgt_port),
329+
label=label,
330+
color=color,
331+
**edge_attr,
332+
)

0 commit comments

Comments
 (0)