Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 16 additions & 3 deletions hugr-py/src/hugr/build/dfg.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,8 +132,15 @@ def new_nested(
"""
new = cls.__new__(cls)

try:
num_outs = parent_op.num_out
except ops.IncompleteOp:
num_outs = None

new.hugr = hugr
new.parent_node = hugr.add_node(parent_op, parent or hugr.entrypoint)
new.parent_node = hugr.add_node(
parent_op, parent or hugr.entrypoint, num_outs=num_outs
)
new._init_io_nodes(parent_op)
return new

Expand Down Expand Up @@ -205,7 +212,14 @@ def add_op(
>>> dfg.add_op(ops.Noop(), dfg.inputs()[0])
Node(3)
"""
new_n = self.hugr.add_node(op, self.parent_node, metadata=metadata)
try:
num_outs = op.num_out
except ops.IncompleteOp:
num_outs = None

new_n = self.hugr.add_node(
op, self.parent_node, metadata=metadata, num_outs=num_outs
)
self._wire_up(new_n, args)
new_n._num_out_ports = op.num_out
return new_n
Expand Down Expand Up @@ -732,7 +746,6 @@ def declare_outputs(self, output_types: TypeRow) -> None:
defined yet. The wires passed to :meth:`set_outputs` must match the
declared output types.
"""
self._set_parent_output_count(len(output_types))
self.parent_op._set_out_types(output_types)

def set_outputs(self, *args: Wire) -> None:
Expand Down
98 changes: 71 additions & 27 deletions hugr-py/src/hugr/hugr/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,13 +35,16 @@
Conditional,
Const,
Custom,
DataflowBlock,
DataflowOp,
ExitBlock,
FuncDefn,
IncompleteOp,
Module,
Op,
is_dataflow_op,
)
from hugr.tys import Kind, Type, ValueKind
from hugr.tys import Kind, OrderKind, Type, ValueKind
from hugr.utils import BiMap
from hugr.val import Value

Expand Down Expand Up @@ -149,7 +152,9 @@ def __init__(self, entrypoint_op: OpVarCov | None = None) -> None:
case None | Module():
pass
case ops.FuncDefn():
self.entrypoint = self.add_node(entrypoint_op, self.module_root)
self.entrypoint = self.add_node(
entrypoint_op, self.module_root, num_outs=1
)
case _:
from hugr.build import Function

Expand Down Expand Up @@ -542,6 +547,12 @@ def add_order_link(self, src: ToNode, dst: ToNode) -> None:
"""
source = src.out(-1)
target = dst.inp(-1)
assert (
self.port_kind(source) == OrderKind()
), f"Operation {self[src].op.name()} does not support order edges"
assert (
self.port_kind(target) == OrderKind()
), f"Operation {self[dst].op.name()} does not support order edges"
if not self.has_link(source, target):
self.add_link(source, target)

Expand Down Expand Up @@ -587,15 +598,20 @@ def num_ports(self, node: ToNode, direction: Direction) -> int:
Not necessarily the number of connected ports - if port `i` is
connected, then all ports `0..i` are assumed to exist.

This value includes order ports.

Args:
node: Node to query.
direction: Direction of ports to count.

Examples:
>>> from hugr.std.logic import Not
>>> h = Hugr()
>>> n1 = h.add_const(val.TRUE)
>>> n2 = h.add_const(val.FALSE)
>>> h.add_link(n1.out(0), n2.inp(2)) # not a valid link!
>>> n1 = h.add_node(Not)
>>> n2 = h.add_node(Not)
>>> # Passing offset `2` here allocates new ports automatically
>>> h.add_link(n1.out(0), n2.inp(2))
>>> h.add_order_link(n1, n2)
>>> h.num_ports(n1, Direction.OUTGOING)
1
>>> h.num_ports(n2, Direction.INCOMING)
Expand All @@ -608,11 +624,17 @@ def num_ports(self, node: ToNode, direction: Direction) -> int:
)

def num_in_ports(self, node: ToNode) -> int:
"""The number of incoming ports of a node. See :meth:`num_ports`."""
"""The number of incoming ports of a node. See :meth:`num_ports`.

This value does not include order ports.
"""
return self[node]._num_inps

def num_out_ports(self, node: ToNode) -> int:
"""The number of outgoing ports of a node. See :meth:`num_ports`."""
"""The number of outgoing ports of a node. See :meth:`num_ports`.

This value cound does not include order ports.
"""
return self[node]._num_outs

def _linked_ports(
Expand Down Expand Up @@ -694,9 +716,16 @@ def _node_links(
port = cast("P", node.port(offset, direction))
yield port, list(self._linked_ports(port, links))

order_port = cast("P", node.port(-1, direction))
linked_order = list(self._linked_ports(order_port, links))
if linked_order:
yield order_port, linked_order

def outgoing_links(self, node: ToNode) -> Iterable[tuple[OutPort, list[InPort]]]:
"""Iterator over outgoing links from a given node.

This number includes order ports.

Args:
node: Node to query.

Expand All @@ -708,14 +737,17 @@ def outgoing_links(self, node: ToNode) -> Iterable[tuple[OutPort, list[InPort]]]
>>> df = dfg.Dfg()
>>> df.hugr.add_link(df.input_node.out(0), df.output_node.inp(0))
>>> df.hugr.add_link(df.input_node.out(0), df.output_node.inp(1))
>>> df.hugr.add_order_link(df.input_node, df.output_node)
>>> list(df.hugr.outgoing_links(df.input_node))
[(OutPort(Node(5), 0), [InPort(Node(6), 0), InPort(Node(6), 1)])]
"""
[(OutPort(Node(5), 0), [InPort(Node(6), 0), InPort(Node(6), 1)]), (OutPort(Node(5), -1), [InPort(Node(6), -1)])]
""" # noqa: E501
return self._node_links(node, self._links.fwd)

def incoming_links(self, node: ToNode) -> Iterable[tuple[InPort, list[OutPort]]]:
"""Iterator over incoming links to a given node.

This number includes order ports.

Args:
node: Node to query.

Expand All @@ -727,8 +759,9 @@ def incoming_links(self, node: ToNode) -> Iterable[tuple[InPort, list[OutPort]]]
>>> df = dfg.Dfg()
>>> df.hugr.add_link(df.input_node.out(0), df.output_node.inp(0))
>>> df.hugr.add_link(df.input_node.out(0), df.output_node.inp(1))
>>> df.hugr.add_order_link(df.input_node, df.output_node)
>>> list(df.hugr.incoming_links(df.output_node))
[(InPort(Node(6), 0), [OutPort(Node(5), 0)]), (InPort(Node(6), 1), [OutPort(Node(5), 0)])]
[(InPort(Node(6), 0), [OutPort(Node(5), 0)]), (InPort(Node(6), 1), [OutPort(Node(5), 0)]), (InPort(Node(6), -1), [OutPort(Node(5), -1)])]
""" # noqa: E501
return self._node_links(node, self._links.bck)

Expand Down Expand Up @@ -810,7 +843,7 @@ def num_incoming(self, node: Node) -> int:
>>> df.hugr.num_incoming(df.output_node)
1
"""
return sum(1 for _ in self.incoming_links(node))
return sum(len(links) for (_, links) in self.incoming_links(node))

def num_outgoing(self, node: ToNode) -> int:
"""The number of outgoing links from a `node`.
Expand All @@ -821,7 +854,7 @@ def num_outgoing(self, node: ToNode) -> int:
>>> df.hugr.num_outgoing(df.input_node)
1
"""
return sum(1 for _ in self.outgoing_links(node))
return sum(len(links) for (_, links) in self.outgoing_links(node))

# TODO: num_links and _linked_ports

Expand Down Expand Up @@ -906,7 +939,9 @@ def _to_serial(self) -> SerialHugr:

def _serialize_link(
link: tuple[_SO, _SI],
) -> tuple[tuple[NodeIdx, PortOffset], tuple[NodeIdx, PortOffset]]:
) -> tuple[
tuple[NodeIdx, PortOffset | None], tuple[NodeIdx, PortOffset | None]
]:
src, dst = link
s, d = self._constrain_offset(src.port), self._constrain_offset(dst.port)
return (src.port.node.idx, s), (dst.port.node.idx, d)
Expand All @@ -933,16 +968,16 @@ def _serialize_link(
entrypoint=entrypoint,
)

def _constrain_offset(self, p: P) -> PortOffset:
# An offset of -1 is a special case, indicating an order edge,
# not counted in the number of ports.
def _constrain_offset(self, p: P) -> PortOffset | None:
"""Constrain an offset to be a valid encoded port offset.

Order edges and control flow edges should be encoded without an offset.
"""
if p.offset < 0:
assert p.offset == -1, "Only order edges are allowed with offset < 0"
offset = self.num_ports(p.node, p.direction)
return None
else:
offset = p.offset

return offset
return p.offset

def resolve_extensions(self, registry: ext.ExtensionRegistry) -> Hugr:
"""Resolve extension types and operations in the HUGR by matching them to
Expand All @@ -962,7 +997,7 @@ def _connect_df_entrypoint_outputs(self) -> None:
"""
from hugr.build import Function

if not isinstance(self.entrypoint_op(), DataflowOp):
if not is_dataflow_op(self.entrypoint_op()):
return

func_node = self[self.entrypoint].parent
Expand Down Expand Up @@ -1007,9 +1042,8 @@ def get_meta(idx: int) -> dict[str, Any]:
parent: Node | None = Node(serial_node.root.parent)

serial_node.root.parent = -1
n = hugr._add_node(
serial_node.root.deserialize(), parent, metadata=node_meta
)
op = serial_node.root.deserialize()
n = hugr._add_node(op, parent, metadata=node_meta, num_outs=op.num_out)
assert (
n.idx == idx + boilerplate_nodes
), "Nodes should be added contiguously"
Expand All @@ -1018,11 +1052,21 @@ def get_meta(idx: int) -> dict[str, Any]:
hugr.entrypoint = n

for (src_node, src_offset), (dst_node, dst_offset) in serial.edges:
src = Node(src_node, _metadata=get_meta(src_node))
dst = Node(dst_node, _metadata=get_meta(dst_node))
if src_offset is None or dst_offset is None:
continue
src_op = hugr[src].op
if isinstance(src_op, DataflowBlock | ExitBlock):
# Control flow edge
src_offset = 0
dst_offset = 0
else:
# Order edge
hugr.add_order_link(src, dst)
continue
hugr.add_link(
Node(src_node, _metadata=get_meta(src_node)).out(src_offset),
Node(dst_node, _metadata=get_meta(dst_node)).inp(dst_offset),
src.out(src_offset),
dst.inp(dst_offset),
)

return hugr
Expand Down
Loading
Loading