Skip to content

Commit ce5a763

Browse files
authored
test: Better python roundtrip tests, and lots of fixes (#2436)
The python validation tests did a roundtrip serialization check by encoding the HUGR, loading it, encoding it again, and checking for differences in the serialization. This ignored any information loss in the first encoding, and resulted in a lot of hidden bugs. This PR changes the test in `conftest.py` to instead compare the original HUGR and the loaded one directly, using a "node hash" that computes the main properties of each node and its children in an index-independent way. (we do not traverse graph edges to avoid having a graph isomorphism problem). The node hash is defined as follows, and should be easily extensible if needed. We compare hugrs by checking the hashes of their root modules. ```python @DataClass class _NodeHash: op: str entrypoint: bool input_neighbours: int output_neighbours: int input_ports: int output_ports: int input_order_edges: int output_order_edges: int is_region: bool node_depth: int children_hashes: list[_NodeHash] # sorted metadata: dict[str, str] ``` This revealed a bunch of bugs with the json serialization, hugr builders, node iterators, ...: - **Order edges were serialized incorrectly**. The json encoding uses `null` for order edges, but python emitted `#ports` instead (and this caused problems when combined with the out port inconsistencies below). - Standardize string/repr formatting for specialized types and values, so e.g. after roundtriping a `Some(TRUE)` it still shows as `Some` instead of `Sum(tag=1, tys=[[], [Bool]], val=[TRUE])`. - `Hugr.num_outgoing` and `num_incoming` were counting ports instead of edges... - Fixed many inconsistencies with the number of output ports in `FuncDefn`, `Case`, and the likes. - The output port count wasn't being set in a lot of cases (this worked fine in the tests because adding an edge auto-allocates the ports). - Order edges were inconsistently reported in the neighbours / links iterators. I fixed it and added new tests. - Probably more bugfixes I'm missing. The roundtrip checks also have options for checking all the format combinations (one variable for the encoding format and another for converting it with `hugr convert` before loading if necessary). This only does `json-json` for the moment, as `hugr-model` detects multiple errors.
1 parent 2e11266 commit ce5a763

File tree

11 files changed

+839
-95
lines changed

11 files changed

+839
-95
lines changed

hugr-py/src/hugr/build/dfg.py

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -132,8 +132,15 @@ def new_nested(
132132
"""
133133
new = cls.__new__(cls)
134134

135+
try:
136+
num_outs = parent_op.num_out
137+
except ops.IncompleteOp:
138+
num_outs = None
139+
135140
new.hugr = hugr
136-
new.parent_node = hugr.add_node(parent_op, parent or hugr.entrypoint)
141+
new.parent_node = hugr.add_node(
142+
parent_op, parent or hugr.entrypoint, num_outs=num_outs
143+
)
137144
new._init_io_nodes(parent_op)
138145
return new
139146

@@ -205,7 +212,14 @@ def add_op(
205212
>>> dfg.add_op(ops.Noop(), dfg.inputs()[0])
206213
Node(3)
207214
"""
208-
new_n = self.hugr.add_node(op, self.parent_node, metadata=metadata)
215+
try:
216+
num_outs = op.num_out
217+
except ops.IncompleteOp:
218+
num_outs = None
219+
220+
new_n = self.hugr.add_node(
221+
op, self.parent_node, metadata=metadata, num_outs=num_outs
222+
)
209223
self._wire_up(new_n, args)
210224
new_n._num_out_ports = op.num_out
211225
return new_n
@@ -732,7 +746,6 @@ def declare_outputs(self, output_types: TypeRow) -> None:
732746
defined yet. The wires passed to :meth:`set_outputs` must match the
733747
declared output types.
734748
"""
735-
self._set_parent_output_count(len(output_types))
736749
self.parent_op._set_out_types(output_types)
737750

738751
def set_outputs(self, *args: Wire) -> None:

hugr-py/src/hugr/hugr/base.py

Lines changed: 71 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -35,13 +35,16 @@
3535
Conditional,
3636
Const,
3737
Custom,
38+
DataflowBlock,
3839
DataflowOp,
40+
ExitBlock,
3941
FuncDefn,
4042
IncompleteOp,
4143
Module,
4244
Op,
45+
is_dataflow_op,
4346
)
44-
from hugr.tys import Kind, Type, ValueKind
47+
from hugr.tys import Kind, OrderKind, Type, ValueKind
4548
from hugr.utils import BiMap
4649
from hugr.val import Value
4750

@@ -149,7 +152,9 @@ def __init__(self, entrypoint_op: OpVarCov | None = None) -> None:
149152
case None | Module():
150153
pass
151154
case ops.FuncDefn():
152-
self.entrypoint = self.add_node(entrypoint_op, self.module_root)
155+
self.entrypoint = self.add_node(
156+
entrypoint_op, self.module_root, num_outs=1
157+
)
153158
case _:
154159
from hugr.build import Function
155160

@@ -542,6 +547,12 @@ def add_order_link(self, src: ToNode, dst: ToNode) -> None:
542547
"""
543548
source = src.out(-1)
544549
target = dst.inp(-1)
550+
assert (
551+
self.port_kind(source) == OrderKind()
552+
), f"Operation {self[src].op.name()} does not support order edges"
553+
assert (
554+
self.port_kind(target) == OrderKind()
555+
), f"Operation {self[dst].op.name()} does not support order edges"
545556
if not self.has_link(source, target):
546557
self.add_link(source, target)
547558

@@ -587,15 +598,20 @@ def num_ports(self, node: ToNode, direction: Direction) -> int:
587598
Not necessarily the number of connected ports - if port `i` is
588599
connected, then all ports `0..i` are assumed to exist.
589600
601+
This value includes order ports.
602+
590603
Args:
591604
node: Node to query.
592605
direction: Direction of ports to count.
593606
594607
Examples:
608+
>>> from hugr.std.logic import Not
595609
>>> h = Hugr()
596-
>>> n1 = h.add_const(val.TRUE)
597-
>>> n2 = h.add_const(val.FALSE)
598-
>>> h.add_link(n1.out(0), n2.inp(2)) # not a valid link!
610+
>>> n1 = h.add_node(Not)
611+
>>> n2 = h.add_node(Not)
612+
>>> # Passing offset `2` here allocates new ports automatically
613+
>>> h.add_link(n1.out(0), n2.inp(2))
614+
>>> h.add_order_link(n1, n2)
599615
>>> h.num_ports(n1, Direction.OUTGOING)
600616
1
601617
>>> h.num_ports(n2, Direction.INCOMING)
@@ -608,11 +624,17 @@ def num_ports(self, node: ToNode, direction: Direction) -> int:
608624
)
609625

610626
def num_in_ports(self, node: ToNode) -> int:
611-
"""The number of incoming ports of a node. See :meth:`num_ports`."""
627+
"""The number of incoming ports of a node. See :meth:`num_ports`.
628+
629+
This value does not include order ports.
630+
"""
612631
return self[node]._num_inps
613632

614633
def num_out_ports(self, node: ToNode) -> int:
615-
"""The number of outgoing ports of a node. See :meth:`num_ports`."""
634+
"""The number of outgoing ports of a node. See :meth:`num_ports`.
635+
636+
This value cound does not include order ports.
637+
"""
616638
return self[node]._num_outs
617639

618640
def _linked_ports(
@@ -694,9 +716,16 @@ def _node_links(
694716
port = cast("P", node.port(offset, direction))
695717
yield port, list(self._linked_ports(port, links))
696718

719+
order_port = cast("P", node.port(-1, direction))
720+
linked_order = list(self._linked_ports(order_port, links))
721+
if linked_order:
722+
yield order_port, linked_order
723+
697724
def outgoing_links(self, node: ToNode) -> Iterable[tuple[OutPort, list[InPort]]]:
698725
"""Iterator over outgoing links from a given node.
699726
727+
This number includes order ports.
728+
700729
Args:
701730
node: Node to query.
702731
@@ -708,14 +737,17 @@ def outgoing_links(self, node: ToNode) -> Iterable[tuple[OutPort, list[InPort]]]
708737
>>> df = dfg.Dfg()
709738
>>> df.hugr.add_link(df.input_node.out(0), df.output_node.inp(0))
710739
>>> df.hugr.add_link(df.input_node.out(0), df.output_node.inp(1))
740+
>>> df.hugr.add_order_link(df.input_node, df.output_node)
711741
>>> list(df.hugr.outgoing_links(df.input_node))
712-
[(OutPort(Node(5), 0), [InPort(Node(6), 0), InPort(Node(6), 1)])]
713-
"""
742+
[(OutPort(Node(5), 0), [InPort(Node(6), 0), InPort(Node(6), 1)]), (OutPort(Node(5), -1), [InPort(Node(6), -1)])]
743+
""" # noqa: E501
714744
return self._node_links(node, self._links.fwd)
715745

716746
def incoming_links(self, node: ToNode) -> Iterable[tuple[InPort, list[OutPort]]]:
717747
"""Iterator over incoming links to a given node.
718748
749+
This number includes order ports.
750+
719751
Args:
720752
node: Node to query.
721753
@@ -727,8 +759,9 @@ def incoming_links(self, node: ToNode) -> Iterable[tuple[InPort, list[OutPort]]]
727759
>>> df = dfg.Dfg()
728760
>>> df.hugr.add_link(df.input_node.out(0), df.output_node.inp(0))
729761
>>> df.hugr.add_link(df.input_node.out(0), df.output_node.inp(1))
762+
>>> df.hugr.add_order_link(df.input_node, df.output_node)
730763
>>> list(df.hugr.incoming_links(df.output_node))
731-
[(InPort(Node(6), 0), [OutPort(Node(5), 0)]), (InPort(Node(6), 1), [OutPort(Node(5), 0)])]
764+
[(InPort(Node(6), 0), [OutPort(Node(5), 0)]), (InPort(Node(6), 1), [OutPort(Node(5), 0)]), (InPort(Node(6), -1), [OutPort(Node(5), -1)])]
732765
""" # noqa: E501
733766
return self._node_links(node, self._links.bck)
734767

@@ -810,7 +843,7 @@ def num_incoming(self, node: Node) -> int:
810843
>>> df.hugr.num_incoming(df.output_node)
811844
1
812845
"""
813-
return sum(1 for _ in self.incoming_links(node))
846+
return sum(len(links) for (_, links) in self.incoming_links(node))
814847

815848
def num_outgoing(self, node: ToNode) -> int:
816849
"""The number of outgoing links from a `node`.
@@ -821,7 +854,7 @@ def num_outgoing(self, node: ToNode) -> int:
821854
>>> df.hugr.num_outgoing(df.input_node)
822855
1
823856
"""
824-
return sum(1 for _ in self.outgoing_links(node))
857+
return sum(len(links) for (_, links) in self.outgoing_links(node))
825858

826859
# TODO: num_links and _linked_ports
827860

@@ -906,7 +939,9 @@ def _to_serial(self) -> SerialHugr:
906939

907940
def _serialize_link(
908941
link: tuple[_SO, _SI],
909-
) -> tuple[tuple[NodeIdx, PortOffset], tuple[NodeIdx, PortOffset]]:
942+
) -> tuple[
943+
tuple[NodeIdx, PortOffset | None], tuple[NodeIdx, PortOffset | None]
944+
]:
910945
src, dst = link
911946
s, d = self._constrain_offset(src.port), self._constrain_offset(dst.port)
912947
return (src.port.node.idx, s), (dst.port.node.idx, d)
@@ -933,16 +968,16 @@ def _serialize_link(
933968
entrypoint=entrypoint,
934969
)
935970

936-
def _constrain_offset(self, p: P) -> PortOffset:
937-
# An offset of -1 is a special case, indicating an order edge,
938-
# not counted in the number of ports.
971+
def _constrain_offset(self, p: P) -> PortOffset | None:
972+
"""Constrain an offset to be a valid encoded port offset.
973+
974+
Order edges and control flow edges should be encoded without an offset.
975+
"""
939976
if p.offset < 0:
940977
assert p.offset == -1, "Only order edges are allowed with offset < 0"
941-
offset = self.num_ports(p.node, p.direction)
978+
return None
942979
else:
943-
offset = p.offset
944-
945-
return offset
980+
return p.offset
946981

947982
def resolve_extensions(self, registry: ext.ExtensionRegistry) -> Hugr:
948983
"""Resolve extension types and operations in the HUGR by matching them to
@@ -962,7 +997,7 @@ def _connect_df_entrypoint_outputs(self) -> None:
962997
"""
963998
from hugr.build import Function
964999

965-
if not isinstance(self.entrypoint_op(), DataflowOp):
1000+
if not is_dataflow_op(self.entrypoint_op()):
9661001
return
9671002

9681003
func_node = self[self.entrypoint].parent
@@ -1007,9 +1042,8 @@ def get_meta(idx: int) -> dict[str, Any]:
10071042
parent: Node | None = Node(serial_node.root.parent)
10081043

10091044
serial_node.root.parent = -1
1010-
n = hugr._add_node(
1011-
serial_node.root.deserialize(), parent, metadata=node_meta
1012-
)
1045+
op = serial_node.root.deserialize()
1046+
n = hugr._add_node(op, parent, metadata=node_meta, num_outs=op.num_out)
10131047
assert (
10141048
n.idx == idx + boilerplate_nodes
10151049
), "Nodes should be added contiguously"
@@ -1018,11 +1052,21 @@ def get_meta(idx: int) -> dict[str, Any]:
10181052
hugr.entrypoint = n
10191053

10201054
for (src_node, src_offset), (dst_node, dst_offset) in serial.edges:
1055+
src = Node(src_node, _metadata=get_meta(src_node))
1056+
dst = Node(dst_node, _metadata=get_meta(dst_node))
10211057
if src_offset is None or dst_offset is None:
1022-
continue
1058+
src_op = hugr[src].op
1059+
if isinstance(src_op, DataflowBlock | ExitBlock):
1060+
# Control flow edge
1061+
src_offset = 0
1062+
dst_offset = 0
1063+
else:
1064+
# Order edge
1065+
hugr.add_order_link(src, dst)
1066+
continue
10231067
hugr.add_link(
1024-
Node(src_node, _metadata=get_meta(src_node)).out(src_offset),
1025-
Node(dst_node, _metadata=get_meta(dst_node)).inp(dst_offset),
1068+
src.out(src_offset),
1069+
dst.inp(dst_offset),
10261070
)
10271071

10281072
return hugr

0 commit comments

Comments
 (0)