@@ -92,42 +92,60 @@ class DotRenderer:
9292
9393 def __init__ (self , config : RenderConfig | None = None ) -> None :
9494 self .config = config or RenderConfig ()
95+ self .nodes : set [Node ] = set ()
9596
96- def render (self , hugr : Hugr ) -> Digraph :
97- """Render a HUGR to a graphviz dot object."""
97+ def render (self , hugr : Hugr , root : Node | None = None ) -> Digraph :
98+ """Render a HUGR to a graphviz dot object.
99+
100+ Args:
101+ hugr: The HUGR to render.
102+ root: Root node defining the set of nodes to render. By default this is the
103+ module root and all nodes are rendered. If this is a container node, all
104+ nodes under it are rendered. Every incoming edge to the rendered set and
105+ outgoing edge from it is also shown, with its other endpoint labelled
106+ with its node index.
107+ """
108+ root = root or hugr .module_root
98109 graph_attr = {
99110 "rankdir" : "" ,
100111 "ranksep" : "0.1" ,
101112 "nodesep" : "0.15" ,
102113 "margin" : "0" ,
103114 "bgcolor" : self .config .palette .background ,
104115 }
105- if name := hugr [hugr . module_root ].metadata .get ("name" , None ):
116+ if name := hugr [root ].metadata .get ("name" , None ):
106117 name = html .escape (str (name ))
107118 else :
108119 name = ""
109120
110121 graph = gv .Digraph (name , strict = False )
111122 graph .attr (** graph_attr )
112123
113- self ._viz_node (hugr . module_root , hugr , graph )
124+ self ._viz_node (root , hugr , graph )
114125
115126 for src_port , tgt_port in hugr .links ():
116127 kind = hugr .port_kind (src_port )
117128 self ._viz_link (src_port , tgt_port , kind , graph )
118129
119130 return graph
120131
121- def store (self , hugr : Hugr , filename : str , format : str = "svg" ) -> None :
132+ def store (
133+ self , hugr : Hugr , filename : str , format : str = "svg" , root : Node | None = None
134+ ) -> None :
122135 """Render a HUGR and save it to a file.
123136
124137 Args:
125138 hugr: The HUGR to render.
126139 filename: Filename for saving the rendered graph.
127140 format: The format used for rendering ('pdf', 'png', etc.).
128141 Defaults to SVG.
142+ root: Root node defining the set of nodes to render. By default this is the
143+ module root and all nodes are rendered. If this is a container node, all
144+ nodes under it are rendered. Every incoming edge to the rendered set and
145+ outgoing edge from it is also shown, with its other endpoint labelled
146+ with its node index.
129147 """
130- gv_graph = self .render (hugr )
148+ gv_graph = self .render (hugr , root = root )
131149 gv_graph .render (filename , format = format )
132150
133151 _FONTFACE = "monospace"
@@ -275,6 +293,7 @@ def _viz_node(self, node: Node, hugr: Hugr, graph: Digraph) -> None:
275293 else :
276294 html_label = self ._format_html_label (** label_config )
277295 graph .node (f"{ node .idx } " , label = f"<{ html_label } >" , shape = "plain" )
296+ self .nodes .add (node )
278297
279298 def _viz_link (
280299 self , src_port : OutPort , tgt_port : InPort , kind : Kind , graph : Digraph
@@ -302,10 +321,32 @@ def _viz_link(
302321 case _:
303322 assert_never (kind )
304323
305- graph .edge (
306- self ._out_port_name (src_port ),
307- self ._in_port_name (tgt_port ),
308- label = label ,
309- color = color ,
310- ** edge_attr ,
311- )
324+ src = self ._out_port_name (src_port )
325+ tgt = self ._in_port_name (tgt_port )
326+
327+ unknown_src = src_port .node not in self .nodes
328+ unknown_tgt = tgt_port .node not in self .nodes
329+ if unknown_src and unknown_tgt :
330+ return
331+ if unknown_src :
332+ src = f"{ src_port .node .idx } "
333+ html_label = self ._format_html_label (
334+ node_back_color = self .config .palette .node ,
335+ node_label = f"{ src_port .node } " ,
336+ node_data = "" ,
337+ border_colour = self .config .palette .background ,
338+ border_width = "1" ,
339+ )
340+ graph .node (src , label = f"<{ html_label } >" , shape = "plain" )
341+ if unknown_tgt :
342+ tgt = f"{ tgt_port .node .idx } "
343+ html_label = self ._format_html_label (
344+ node_back_color = self .config .palette .node ,
345+ node_label = f"{ tgt_port .node } " ,
346+ node_data = "" ,
347+ border_colour = self .config .palette .background ,
348+ border_width = "1" ,
349+ )
350+ graph .node (tgt , label = f"<{ html_label } >" , shape = "plain" )
351+
352+ graph .edge (src , tgt , label = label , color = color , ** edge_attr )
0 commit comments