Skip to content
Merged
Show file tree
Hide file tree
Changes from 14 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 16 additions & 2 deletions hugr-py/src/hugr/build/dfg.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,8 +132,15 @@ def new_nested(
"""
new = cls.__new__(cls)

try:
num_outs = parent_op.num_out
except ops.IncompleteOp:
num_outs = None

new.hugr = hugr
new.parent_node = hugr.add_node(parent_op, parent or hugr.entrypoint)
new.parent_node = hugr.add_node(
parent_op, parent or hugr.entrypoint, num_outs=num_outs
)
new._init_io_nodes(parent_op)
return new

Expand Down Expand Up @@ -205,7 +212,14 @@ def add_op(
>>> dfg.add_op(ops.Noop(), dfg.inputs()[0])
Node(3)
"""
new_n = self.hugr.add_node(op, self.parent_node, metadata=metadata)
try:
num_outs = op.num_out
except ops.IncompleteOp:
num_outs = None

new_n = self.hugr.add_node(
op, self.parent_node, metadata=metadata, num_outs=num_outs
)
self._wire_up(new_n, args)
new_n._num_out_ports = op.num_out
return new_n
Expand Down
90 changes: 66 additions & 24 deletions hugr-py/src/hugr/hugr/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,13 +35,16 @@
Conditional,
Const,
Custom,
DataflowBlock,
DataflowOp,
ExitBlock,
FuncDefn,
IncompleteOp,
Module,
Op,
is_dataflow_op,
)
from hugr.tys import Kind, Type, ValueKind
from hugr.tys import Kind, OrderKind, Type, ValueKind
from hugr.utils import BiMap
from hugr.val import Value

Expand Down Expand Up @@ -149,7 +152,9 @@
case None | Module():
pass
case ops.FuncDefn():
self.entrypoint = self.add_node(entrypoint_op, self.module_root)
self.entrypoint = self.add_node(
entrypoint_op, self.module_root, num_outs=1
)
case _:
from hugr.build import Function

Expand Down Expand Up @@ -542,6 +547,12 @@
"""
source = src.out(-1)
target = dst.inp(-1)
assert (
self.port_kind(source) == OrderKind()
), f"Operation {self[src].op.name()} does not support order edges"
assert (
self.port_kind(target) == OrderKind()
), f"Operation {self[dst].op.name()} does not support order edges"
if not self.has_link(source, target):
self.add_link(source, target)

Expand Down Expand Up @@ -587,6 +598,8 @@
Not necessarily the number of connected ports - if port `i` is
connected, then all ports `0..i` are assumed to exist.

This value includes order ports.

Args:
node: Node to query.
direction: Direction of ports to count.
Expand All @@ -596,6 +609,7 @@
>>> n1 = h.add_const(val.TRUE)
>>> n2 = h.add_const(val.FALSE)
>>> h.add_link(n1.out(0), n2.inp(2)) # not a valid link!
>>> h.add_order_link(n1, n2)
>>> h.num_ports(n1, Direction.OUTGOING)
1
>>> h.num_ports(n2, Direction.INCOMING)
Expand All @@ -608,11 +622,17 @@
)

def num_in_ports(self, node: ToNode) -> int:
"""The number of incoming ports of a node. See :meth:`num_ports`."""
"""The number of incoming ports of a node. See :meth:`num_ports`.

This value does not include order ports.
"""
return self[node]._num_inps

def num_out_ports(self, node: ToNode) -> int:
"""The number of outgoing ports of a node. See :meth:`num_ports`."""
"""The number of outgoing ports of a node. See :meth:`num_ports`.

This value cound does not include order ports.
"""
return self[node]._num_outs

def _linked_ports(
Expand Down Expand Up @@ -694,9 +714,16 @@
port = cast("P", node.port(offset, direction))
yield port, list(self._linked_ports(port, links))

order_port = cast("P", node.port(-1, direction))
linked_order = list(self._linked_ports(order_port, links))
if linked_order:
yield order_port, linked_order

def outgoing_links(self, node: ToNode) -> Iterable[tuple[OutPort, list[InPort]]]:
"""Iterator over outgoing links from a given node.

This number includes order ports.

Args:
node: Node to query.

Expand All @@ -708,14 +735,17 @@
>>> df = dfg.Dfg()
>>> df.hugr.add_link(df.input_node.out(0), df.output_node.inp(0))
>>> df.hugr.add_link(df.input_node.out(0), df.output_node.inp(1))
>>> df.hugr.add_order_link(df.input_node, df.output_node)
>>> list(df.hugr.outgoing_links(df.input_node))
[(OutPort(Node(5), 0), [InPort(Node(6), 0), InPort(Node(6), 1)])]
"""
[(OutPort(Node(5), 0), [InPort(Node(6), 0), InPort(Node(6), 1)]), (OutPort(Node(5), -1), [InPort(Node(6), -1)])]
""" # noqa: E501
return self._node_links(node, self._links.fwd)

def incoming_links(self, node: ToNode) -> Iterable[tuple[InPort, list[OutPort]]]:
"""Iterator over incoming links to a given node.

This number includes order ports.

Args:
node: Node to query.

Expand All @@ -727,8 +757,9 @@
>>> df = dfg.Dfg()
>>> df.hugr.add_link(df.input_node.out(0), df.output_node.inp(0))
>>> df.hugr.add_link(df.input_node.out(0), df.output_node.inp(1))
>>> df.hugr.add_order_link(df.input_node, df.output_node)
>>> list(df.hugr.incoming_links(df.output_node))
[(InPort(Node(6), 0), [OutPort(Node(5), 0)]), (InPort(Node(6), 1), [OutPort(Node(5), 0)])]
[(InPort(Node(6), 0), [OutPort(Node(5), 0)]), (InPort(Node(6), 1), [OutPort(Node(5), 0)]), (InPort(Node(6), -1), [OutPort(Node(5), -1)])]
""" # noqa: E501
return self._node_links(node, self._links.bck)

Expand Down Expand Up @@ -810,7 +841,7 @@
>>> df.hugr.num_incoming(df.output_node)
1
"""
return sum(1 for _ in self.incoming_links(node))
return sum(len(links) for (_, links) in self.incoming_links(node))

def num_outgoing(self, node: ToNode) -> int:
"""The number of outgoing links from a `node`.
Expand All @@ -821,7 +852,7 @@
>>> df.hugr.num_outgoing(df.input_node)
1
"""
return sum(1 for _ in self.outgoing_links(node))
return sum(len(links) for (_, links) in self.outgoing_links(node))

# TODO: num_links and _linked_ports

Expand Down Expand Up @@ -906,7 +937,9 @@

def _serialize_link(
link: tuple[_SO, _SI],
) -> tuple[tuple[NodeIdx, PortOffset], tuple[NodeIdx, PortOffset]]:
) -> tuple[
tuple[NodeIdx, PortOffset | None], tuple[NodeIdx, PortOffset | None]
]:
src, dst = link
s, d = self._constrain_offset(src.port), self._constrain_offset(dst.port)
return (src.port.node.idx, s), (dst.port.node.idx, d)
Expand All @@ -933,16 +966,16 @@
entrypoint=entrypoint,
)

def _constrain_offset(self, p: P) -> PortOffset:
# An offset of -1 is a special case, indicating an order edge,
# not counted in the number of ports.
def _constrain_offset(self, p: P) -> PortOffset | None:
"""Constrain an offset to be a valid encoded port offset.

Order edges and control flow edges should be encoded without an offset.
"""
if p.offset < 0:
assert p.offset == -1, "Only order edges are allowed with offset < 0"
offset = self.num_ports(p.node, p.direction)
return None
else:
offset = p.offset

return offset
return p.offset

def resolve_extensions(self, registry: ext.ExtensionRegistry) -> Hugr:
"""Resolve extension types and operations in the HUGR by matching them to
Expand All @@ -962,7 +995,7 @@
"""
from hugr.build import Function

if not isinstance(self.entrypoint_op(), DataflowOp):
if not is_dataflow_op(self.entrypoint_op()):
return

func_node = self[self.entrypoint].parent
Expand Down Expand Up @@ -1007,9 +1040,8 @@
parent: Node | None = Node(serial_node.root.parent)

serial_node.root.parent = -1
n = hugr._add_node(
serial_node.root.deserialize(), parent, metadata=node_meta
)
op = serial_node.root.deserialize()
n = hugr._add_node(op, parent, metadata=node_meta, num_outs=op.num_out)
assert (
n.idx == idx + boilerplate_nodes
), "Nodes should be added contiguously"
Expand All @@ -1018,11 +1050,21 @@
hugr.entrypoint = n

for (src_node, src_offset), (dst_node, dst_offset) in serial.edges:
src = Node(src_node, _metadata=get_meta(src_node))
dst = Node(dst_node, _metadata=get_meta(dst_node))
if src_offset is None or dst_offset is None:
continue
src_op = hugr[src].op
if isinstance(src_op, DataflowBlock | ExitBlock):
# Control flow edge
src_offset = 0
dst_offset = 0

Check warning on line 1060 in hugr-py/src/hugr/hugr/base.py

View check run for this annotation

Codecov / codecov/patch

hugr-py/src/hugr/hugr/base.py#L1059-L1060

Added lines #L1059 - L1060 were not covered by tests
else:
# Order edge
hugr.add_order_link(src, dst)
continue
hugr.add_link(
Node(src_node, _metadata=get_meta(src_node)).out(src_offset),
Node(dst_node, _metadata=get_meta(dst_node)).inp(dst_offset),
src.out(src_offset),
dst.inp(dst_offset),
)

return hugr
Expand Down
Loading
Loading