Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 39 additions & 35 deletions hugr-core/src/export.rs
Original file line number Diff line number Diff line change
Expand Up @@ -627,41 +627,47 @@ impl<'a> Context<'a> {
let children = self.hugr.children(node);
let mut region_children = BumpVec::with_capacity_in(children.size_hint().0 - 2, self.bump);

let mut output_node = None;

for child in children {
match self.hugr.get_optype(child) {
OpType::Input(input) => {
sources = self.make_ports(child, Direction::Outgoing, input.types.len());
input_types = Some(&input.types);

if has_order_edges(self.hugr, child) {
let key = self.make_term(model::Literal::Nat(child.index() as u64).into());
meta.push(self.make_term_apply(model::ORDER_HINT_INPUT_KEY, &[key]));
}
}
OpType::Output(output) => {
targets = self.make_ports(child, Direction::Incoming, output.types.len());
output_types = Some(&output.types);
output_node = Some(child);

if has_order_edges(self.hugr, child) {
let key = self.make_term(model::Literal::Nat(child.index() as u64).into());
meta.push(self.make_term_apply(model::ORDER_HINT_OUTPUT_KEY, &[key]));
}
}
child_optype => {
_ => {
if let Some(child_id) = self.export_node_shallow(child) {
region_children.push(child_id);

// Record all order edges that originate from this node in metadata.
let successors = child_optype
.other_output_port()
.into_iter()
.flat_map(|port| self.hugr.linked_inputs(child, port))
.map(|(successor, _)| successor)
.filter(|successor| Some(*successor) != output_node);

for successor in successors {
let a =
self.make_term(model::Literal::Nat(child.index() as u64).into());
let b = self
.make_term(model::Literal::Nat(successor.index() as u64).into());
meta.push(self.make_term_apply(model::ORDER_HINT_ORDER, &[a, b]));
}
}
}
}

// Record all order edges that originate from this node in metadata.
let successors = self
.hugr
.get_optype(child)
.other_output_port()
.into_iter()
.flat_map(|port| self.hugr.linked_inputs(child, port))
.map(|(successor, _)| successor);

for successor in successors {
let a = self.make_term(model::Literal::Nat(child.index() as u64).into());
let b = self.make_term(model::Literal::Nat(successor.index() as u64).into());
meta.push(self.make_term_apply(model::ORDER_HINT_ORDER, &[a, b]));
}
}

for child_id in &region_children {
Expand Down Expand Up @@ -1100,21 +1106,7 @@ impl<'a> Context<'a> {
}

fn export_node_order_metadata(&mut self, node: Node, meta: &mut Vec<table::TermId>) {
fn is_relevant_node(hugr: &Hugr, node: Node) -> bool {
let optype = hugr.get_optype(node);
!optype.is_input() && !optype.is_output()
}

let optype = self.hugr.get_optype(node);

let has_order_edges = Direction::BOTH
.iter()
.filter(|dir| optype.other_port_kind(**dir) == Some(EdgeKind::StateOrder))
.filter_map(|dir| optype.other_port(*dir))
.flat_map(|port| self.hugr.linked_ports(node, port))
.any(|(other, _)| is_relevant_node(self.hugr, other));

if has_order_edges {
if has_order_edges(self.hugr, node) {
let key = self.make_term(model::Literal::Nat(node.index() as u64).into());
meta.push(self.make_term_apply(model::ORDER_HINT_KEY, &[key]));
}
Expand Down Expand Up @@ -1229,6 +1221,18 @@ impl Links {
}
}

/// Returns `true` if a node has any incident order edges.
fn has_order_edges(hugr: &Hugr, node: Node) -> bool {
let optype = hugr.get_optype(node);
Direction::BOTH
.iter()
.filter(|dir| optype.other_port_kind(**dir) == Some(EdgeKind::StateOrder))
.filter_map(|dir| optype.other_port(*dir))
.flat_map(|port| hugr.linked_ports(node, port))
.next()
.is_some()
}

#[cfg(test)]
mod test {
use rstest::{fixture, rstest};
Expand Down
74 changes: 51 additions & 23 deletions hugr-core/src/import.rs
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ impl From<ExtensionError> for ImportError {
enum OrderHintError {
/// Duplicate order hint key in the same region.
#[error("duplicate order hint key {0}")]
DuplicateKey(table::NodeId, u64),
DuplicateKey(table::RegionId, u64),
/// Order hint including a key not defined in the region.
#[error("order hint with unknown key {0}")]
UnknownKey(u64),
Expand Down Expand Up @@ -608,7 +608,7 @@ impl<'a> Context<'a> {
self.import_node(*child, node)?;
}

self.create_order_edges(region)?;
self.create_order_edges(region, input, output)?;

for meta_item in region_data.meta {
self.import_node_metadata(node, *meta_item)?;
Expand All @@ -622,13 +622,18 @@ impl<'a> Context<'a> {
/// Create order edges between nodes of a dataflow region based on order hint metadata.
///
/// This method assumes that the nodes for the children of the region have already been imported.
fn create_order_edges(&mut self, region_id: table::RegionId) -> Result<(), ImportError> {
fn create_order_edges(
&mut self,
region_id: table::RegionId,
input: Node,
output: Node,
) -> Result<(), ImportError> {
let region_data = self.get_region(region_id)?;
debug_assert_eq!(region_data.kind, model::RegionKind::DataFlow);

// Collect order hint keys
// PERFORMANCE: It might be worthwhile to reuse the map to avoid allocations.
let mut order_keys = FxHashMap::<u64, table::NodeId>::default();
let mut order_keys = FxHashMap::<u64, Node>::default();

for child_id in region_data.children {
let child_data = self.get_node(*child_id)?;
Expand All @@ -642,8 +647,42 @@ impl<'a> Context<'a> {
continue;
};

if order_keys.insert(*key, *child_id).is_some() {
return Err(OrderHintError::DuplicateKey(*child_id, *key).into());
// NOTE: The lookups here are expected to succeed since we only
// process the order metadata after we have imported the nodes.
let child_node = self.nodes[child_id];
let child_optype = self.hugr.get_optype(child_node);

// Check that the node has order ports.
// NOTE: This assumes that a node has an input order port iff it has an output one.
if child_optype.other_output_port().is_none() {
return Err(OrderHintError::NoOrderPort(*child_id).into());
}

if order_keys.insert(*key, child_node).is_some() {
return Err(OrderHintError::DuplicateKey(region_id, *key).into());
}
}
}

// Collect the order hint keys for the input and output nodes
for meta_id in region_data.meta {
if let Some([key]) = self.match_symbol(*meta_id, model::ORDER_HINT_INPUT_KEY)? {
let table::Term::Literal(model::Literal::Nat(key)) = self.get_term(key)? else {
continue;
};

if order_keys.insert(*key, input).is_some() {
return Err(OrderHintError::DuplicateKey(region_id, *key).into());
}
}

if let Some([key]) = self.match_symbol(*meta_id, model::ORDER_HINT_OUTPUT_KEY)? {
let table::Term::Literal(model::Literal::Nat(key)) = self.get_term(key)? else {
continue;
};

if order_keys.insert(*key, output).is_some() {
return Err(OrderHintError::DuplicateKey(region_id, *key).into());
}
}
}
Expand All @@ -665,24 +704,13 @@ impl<'a> Context<'a> {
let a = order_keys.get(a).ok_or(OrderHintError::UnknownKey(*a))?;
let b = order_keys.get(b).ok_or(OrderHintError::UnknownKey(*b))?;

// NOTE: The lookups here are expected to succeed since we only
// process the order metadata after we have imported the nodes.
let a_node = self.nodes[a];
let b_node = self.nodes[b];

let a_port = self
.hugr
.get_optype(a_node)
.other_output_port()
.ok_or(OrderHintError::NoOrderPort(*a))?;

let b_port = self
.hugr
.get_optype(b_node)
.other_input_port()
.ok_or(OrderHintError::NoOrderPort(*b))?;
// NOTE: The unwrap here must succeed:
// - For all ordinary nodes we checked that they have an order port.
// - Input and output nodes always have an order port.
let a_port = self.hugr.get_optype(*a).other_output_port().unwrap();
let b_port = self.hugr.get_optype(*b).other_input_port().unwrap();

self.hugr.connect(a_node, a_port, b_node, b_port);
self.hugr.connect(*a, a_port, *b, b_port);
}

Ok(())
Expand Down
16 changes: 12 additions & 4 deletions hugr-core/tests/snapshots/model__roundtrip_order.snap
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,19 @@ expression: ast

(import core.meta.description)

(import core.order_hint.input_key)

(import core.order_hint.order)

(import arithmetic.int.types.int)

(import core.nat)

(import core.order_hint.key)

(import core.fn)

(import core.order_hint.order)
(import core.order_hint.output_key)

(import arithmetic.int.types.int)
(import core.fn)

(declare-operation
arithmetic.int.ineg
Expand Down Expand Up @@ -48,9 +52,13 @@ expression: ast
(arithmetic.int.types.int 6)
(arithmetic.int.types.int 6)
(arithmetic.int.types.int 6)]))
(meta (core.order_hint.input_key 2))
(meta (core.order_hint.order 2 4))
(meta (core.order_hint.output_key 3))
(meta (core.order_hint.order 4 7))
(meta (core.order_hint.order 5 6))
(meta (core.order_hint.order 5 4))
(meta (core.order_hint.order 5 3))
(meta (core.order_hint.order 6 7))
((arithmetic.int.ineg 6) [%0] [%4]
(signature
Expand Down
20 changes: 20 additions & 0 deletions hugr-model/src/v0/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -278,6 +278,26 @@ pub const COMPAT_CONST_JSON: &str = "compat.const_json";
/// - **Result:** `core.meta`
pub const ORDER_HINT_KEY: &str = "core.order_hint.key";

/// Metadata constructor for order hint keys on input nodes.
///
/// When the sources of a dataflow region are represented by an input operation
/// within the region, this metadata can be attached the region to give the
/// input node an order hint key.
///
/// - **Parameter:** `?key : core.nat`
/// - **Result:** `core.meta`
pub const ORDER_HINT_INPUT_KEY: &str = "core.order_hint.input_key";

/// Metadata constructor for order hint keys on output nodes.
///
/// When the targets of a dataflow region are represented by an output operation
/// within the region, this metadata can be attached the region to give the
/// output node an order hint key.
///
/// - **Parameter:** `?key : core.nat`
/// - **Result:** `core.meta`
pub const ORDER_HINT_OUTPUT_KEY: &str = "core.order_hint.output_key";

/// Metadata constructor for order hints.
///
/// When this metadata is attached to a dataflow region, it can indicate a
Expand Down
4 changes: 4 additions & 0 deletions hugr-model/tests/fixtures/model-order.edn
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,10 @@
(meta (core.order_hint.order 1 0))
(meta (core.order_hint.order 2 3))
(meta (core.order_hint.order 0 3))
(meta (core.order_hint.input_key 4))
(meta (core.order_hint.order 4 0))
(meta (core.order_hint.order 1 5))
(meta (core.order_hint.output_key 5))

((arithmetic.int.ineg 6)
[%0] [%4]
Expand Down
52 changes: 29 additions & 23 deletions hugr-py/src/hugr/model/export.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@
meta = self.export_json_meta(node)

# Add an order hint key to the node if necessary
if _needs_order_key(self.hugr, node):
if _has_order_links(self.hugr, node):

Check warning on line 74 in hugr-py/src/hugr/model/export.py

View check run for this annotation

Codecov / codecov/patch

hugr-py/src/hugr/model/export.py#L74

Added line #L74 was not covered by tests
meta.append(model.Apply("core.order_hint.key", [model.Literal(node.idx)]))

match node_data.op:
Expand Down Expand Up @@ -411,13 +411,27 @@
for i in range(child_data._num_outs)
]

if _has_order_links(self.hugr, child):
meta.append(

Check warning on line 415 in hugr-py/src/hugr/model/export.py

View check run for this annotation

Codecov / codecov/patch

hugr-py/src/hugr/model/export.py#L414-L415

Added lines #L414 - L415 were not covered by tests
model.Apply(
"core.order_hint.input_key", [model.Literal(child.idx)]
)
)

case Output() as op:
target_types = model.List([type.to_model() for type in op.types])
targets = [
self.link_name(InPort(child, i))
for i in range(child_data._num_inps)
]

if _has_order_links(self.hugr, child):
meta.append(

Check warning on line 429 in hugr-py/src/hugr/model/export.py

View check run for this annotation

Codecov / codecov/patch

hugr-py/src/hugr/model/export.py#L428-L429

Added lines #L428 - L429 were not covered by tests
model.Apply(
"core.order_hint.output_key", [model.Literal(child.idx)]
)
)

case _:
child_node = self.export_node(child)

Expand All @@ -426,14 +440,13 @@

children.append(child_node)

meta += [
model.Apply(
"core.order_hint.order",
[model.Literal(child.idx), model.Literal(successor.idx)],
)
for successor in self.hugr.outgoing_order_links(child)
if not isinstance(self.hugr[successor].op, Output)
]
meta += [

Check warning on line 443 in hugr-py/src/hugr/model/export.py

View check run for this annotation

Codecov / codecov/patch

hugr-py/src/hugr/model/export.py#L443

Added line #L443 was not covered by tests
model.Apply(
"core.order_hint.order",
[model.Literal(child.idx), model.Literal(successor.idx)],
)
for successor in self.hugr.outgoing_order_links(child)
]

signature = model.Apply("core.fn", [source_types, target_types])

Expand Down Expand Up @@ -618,19 +631,12 @@
self.sizes[a] += self.sizes[b]


def _needs_order_key(hugr: Hugr, node: Node) -> bool:
"""Checks whether the node has any order links for the purposes of
exporting order hint metadata. Order links to `Input` or `Output`
operations are ignored, since they are not present in the model format.
"""
for succ in hugr.outgoing_order_links(node):
succ_op = hugr[succ].op
if not isinstance(succ_op, Output):
return True

for pred in hugr.incoming_order_links(node):
pred_op = hugr[pred].op
if not isinstance(pred_op, Input):
return True
def _has_order_links(hugr: Hugr, node: Node) -> bool:
"""Checks whether the node has any order links."""
for _succ in hugr.outgoing_order_links(node):
return True

Check warning on line 637 in hugr-py/src/hugr/model/export.py

View check run for this annotation

Codecov / codecov/patch

hugr-py/src/hugr/model/export.py#L636-L637

Added lines #L636 - L637 were not covered by tests

for _pred in hugr.incoming_order_links(node):
return True

Check warning on line 640 in hugr-py/src/hugr/model/export.py

View check run for this annotation

Codecov / codecov/patch

hugr-py/src/hugr/model/export.py#L639-L640

Added lines #L639 - L640 were not covered by tests

return False
Loading