diff --git a/hugr-core/src/export.rs b/hugr-core/src/export.rs index ecd6506933..b2595c7a69 100644 --- a/hugr-core/src/export.rs +++ b/hugr-core/src/export.rs @@ -627,41 +627,47 @@ impl<'a> Context<'a> { let children = self.hugr.children(node); let mut region_children = BumpVec::with_capacity_in(children.size_hint().0 - 2, self.bump); - let mut output_node = None; - for child in children { match self.hugr.get_optype(child) { OpType::Input(input) => { sources = self.make_ports(child, Direction::Outgoing, input.types.len()); input_types = Some(&input.types); + + if has_order_edges(self.hugr, child) { + let key = self.make_term(model::Literal::Nat(child.index() as u64).into()); + meta.push(self.make_term_apply(model::ORDER_HINT_INPUT_KEY, &[key])); + } } OpType::Output(output) => { targets = self.make_ports(child, Direction::Incoming, output.types.len()); output_types = Some(&output.types); - output_node = Some(child); + + if has_order_edges(self.hugr, child) { + let key = self.make_term(model::Literal::Nat(child.index() as u64).into()); + meta.push(self.make_term_apply(model::ORDER_HINT_OUTPUT_KEY, &[key])); + } } - child_optype => { + _ => { if let Some(child_id) = self.export_node_shallow(child) { region_children.push(child_id); - - // Record all order edges that originate from this node in metadata. - let successors = child_optype - .other_output_port() - .into_iter() - .flat_map(|port| self.hugr.linked_inputs(child, port)) - .map(|(successor, _)| successor) - .filter(|successor| Some(*successor) != output_node); - - for successor in successors { - let a = - self.make_term(model::Literal::Nat(child.index() as u64).into()); - let b = self - .make_term(model::Literal::Nat(successor.index() as u64).into()); - meta.push(self.make_term_apply(model::ORDER_HINT_ORDER, &[a, b])); - } } } } + + // Record all order edges that originate from this node in metadata. + let successors = self + .hugr + .get_optype(child) + .other_output_port() + .into_iter() + .flat_map(|port| self.hugr.linked_inputs(child, port)) + .map(|(successor, _)| successor); + + for successor in successors { + let a = self.make_term(model::Literal::Nat(child.index() as u64).into()); + let b = self.make_term(model::Literal::Nat(successor.index() as u64).into()); + meta.push(self.make_term_apply(model::ORDER_HINT_ORDER, &[a, b])); + } } for child_id in ®ion_children { @@ -1100,21 +1106,7 @@ impl<'a> Context<'a> { } fn export_node_order_metadata(&mut self, node: Node, meta: &mut Vec) { - fn is_relevant_node(hugr: &Hugr, node: Node) -> bool { - let optype = hugr.get_optype(node); - !optype.is_input() && !optype.is_output() - } - - let optype = self.hugr.get_optype(node); - - let has_order_edges = Direction::BOTH - .iter() - .filter(|dir| optype.other_port_kind(**dir) == Some(EdgeKind::StateOrder)) - .filter_map(|dir| optype.other_port(*dir)) - .flat_map(|port| self.hugr.linked_ports(node, port)) - .any(|(other, _)| is_relevant_node(self.hugr, other)); - - if has_order_edges { + if has_order_edges(self.hugr, node) { let key = self.make_term(model::Literal::Nat(node.index() as u64).into()); meta.push(self.make_term_apply(model::ORDER_HINT_KEY, &[key])); } @@ -1229,6 +1221,18 @@ impl Links { } } +/// Returns `true` if a node has any incident order edges. +fn has_order_edges(hugr: &Hugr, node: Node) -> bool { + let optype = hugr.get_optype(node); + Direction::BOTH + .iter() + .filter(|dir| optype.other_port_kind(**dir) == Some(EdgeKind::StateOrder)) + .filter_map(|dir| optype.other_port(*dir)) + .flat_map(|port| hugr.linked_ports(node, port)) + .next() + .is_some() +} + #[cfg(test)] mod test { use rstest::{fixture, rstest}; diff --git a/hugr-core/src/import.rs b/hugr-core/src/import.rs index 3815dcc82a..9594e68690 100644 --- a/hugr-core/src/import.rs +++ b/hugr-core/src/import.rs @@ -114,7 +114,7 @@ impl From for ImportError { enum OrderHintError { /// Duplicate order hint key in the same region. #[error("duplicate order hint key {0}")] - DuplicateKey(table::NodeId, u64), + DuplicateKey(table::RegionId, u64), /// Order hint including a key not defined in the region. #[error("order hint with unknown key {0}")] UnknownKey(u64), @@ -608,7 +608,7 @@ impl<'a> Context<'a> { self.import_node(*child, node)?; } - self.create_order_edges(region)?; + self.create_order_edges(region, input, output)?; for meta_item in region_data.meta { self.import_node_metadata(node, *meta_item)?; @@ -622,13 +622,18 @@ impl<'a> Context<'a> { /// Create order edges between nodes of a dataflow region based on order hint metadata. /// /// This method assumes that the nodes for the children of the region have already been imported. - fn create_order_edges(&mut self, region_id: table::RegionId) -> Result<(), ImportError> { + fn create_order_edges( + &mut self, + region_id: table::RegionId, + input: Node, + output: Node, + ) -> Result<(), ImportError> { let region_data = self.get_region(region_id)?; debug_assert_eq!(region_data.kind, model::RegionKind::DataFlow); // Collect order hint keys // PERFORMANCE: It might be worthwhile to reuse the map to avoid allocations. - let mut order_keys = FxHashMap::::default(); + let mut order_keys = FxHashMap::::default(); for child_id in region_data.children { let child_data = self.get_node(*child_id)?; @@ -642,8 +647,42 @@ impl<'a> Context<'a> { continue; }; - if order_keys.insert(*key, *child_id).is_some() { - return Err(OrderHintError::DuplicateKey(*child_id, *key).into()); + // NOTE: The lookups here are expected to succeed since we only + // process the order metadata after we have imported the nodes. + let child_node = self.nodes[child_id]; + let child_optype = self.hugr.get_optype(child_node); + + // Check that the node has order ports. + // NOTE: This assumes that a node has an input order port iff it has an output one. + if child_optype.other_output_port().is_none() { + return Err(OrderHintError::NoOrderPort(*child_id).into()); + } + + if order_keys.insert(*key, child_node).is_some() { + return Err(OrderHintError::DuplicateKey(region_id, *key).into()); + } + } + } + + // Collect the order hint keys for the input and output nodes + for meta_id in region_data.meta { + if let Some([key]) = self.match_symbol(*meta_id, model::ORDER_HINT_INPUT_KEY)? { + let table::Term::Literal(model::Literal::Nat(key)) = self.get_term(key)? else { + continue; + }; + + if order_keys.insert(*key, input).is_some() { + return Err(OrderHintError::DuplicateKey(region_id, *key).into()); + } + } + + if let Some([key]) = self.match_symbol(*meta_id, model::ORDER_HINT_OUTPUT_KEY)? { + let table::Term::Literal(model::Literal::Nat(key)) = self.get_term(key)? else { + continue; + }; + + if order_keys.insert(*key, output).is_some() { + return Err(OrderHintError::DuplicateKey(region_id, *key).into()); } } } @@ -665,24 +704,13 @@ impl<'a> Context<'a> { let a = order_keys.get(a).ok_or(OrderHintError::UnknownKey(*a))?; let b = order_keys.get(b).ok_or(OrderHintError::UnknownKey(*b))?; - // NOTE: The lookups here are expected to succeed since we only - // process the order metadata after we have imported the nodes. - let a_node = self.nodes[a]; - let b_node = self.nodes[b]; - - let a_port = self - .hugr - .get_optype(a_node) - .other_output_port() - .ok_or(OrderHintError::NoOrderPort(*a))?; - - let b_port = self - .hugr - .get_optype(b_node) - .other_input_port() - .ok_or(OrderHintError::NoOrderPort(*b))?; + // NOTE: The unwrap here must succeed: + // - For all ordinary nodes we checked that they have an order port. + // - Input and output nodes always have an order port. + let a_port = self.hugr.get_optype(*a).other_output_port().unwrap(); + let b_port = self.hugr.get_optype(*b).other_input_port().unwrap(); - self.hugr.connect(a_node, a_port, b_node, b_port); + self.hugr.connect(*a, a_port, *b, b_port); } Ok(()) diff --git a/hugr-core/tests/snapshots/model__roundtrip_order.snap b/hugr-core/tests/snapshots/model__roundtrip_order.snap index ae92aa3ab2..a0b613fd81 100644 --- a/hugr-core/tests/snapshots/model__roundtrip_order.snap +++ b/hugr-core/tests/snapshots/model__roundtrip_order.snap @@ -8,15 +8,19 @@ expression: ast (import core.meta.description) +(import core.order_hint.input_key) + +(import core.order_hint.order) + +(import arithmetic.int.types.int) + (import core.nat) (import core.order_hint.key) -(import core.fn) - -(import core.order_hint.order) +(import core.order_hint.output_key) -(import arithmetic.int.types.int) +(import core.fn) (declare-operation arithmetic.int.ineg @@ -48,9 +52,14 @@ expression: ast (arithmetic.int.types.int 6) (arithmetic.int.types.int 6) (arithmetic.int.types.int 6)])) + (meta (core.order_hint.input_key 2)) + (meta (core.order_hint.order 2 4)) + (meta (core.order_hint.order 2 3)) + (meta (core.order_hint.output_key 3)) (meta (core.order_hint.order 4 7)) (meta (core.order_hint.order 5 6)) (meta (core.order_hint.order 5 4)) + (meta (core.order_hint.order 5 3)) (meta (core.order_hint.order 6 7)) ((arithmetic.int.ineg 6) [%0] [%4] (signature diff --git a/hugr-model/src/v0/mod.rs b/hugr-model/src/v0/mod.rs index 27e6605ae7..12a64ea0ef 100644 --- a/hugr-model/src/v0/mod.rs +++ b/hugr-model/src/v0/mod.rs @@ -278,6 +278,26 @@ pub const COMPAT_CONST_JSON: &str = "compat.const_json"; /// - **Result:** `core.meta` pub const ORDER_HINT_KEY: &str = "core.order_hint.key"; +/// Metadata constructor for order hint keys on input nodes. +/// +/// When the sources of a dataflow region are represented by an input operation +/// within the region, this metadata can be attached the region to give the +/// input node an order hint key. +/// +/// - **Parameter:** `?key : core.nat` +/// - **Result:** `core.meta` +pub const ORDER_HINT_INPUT_KEY: &str = "core.order_hint.input_key"; + +/// Metadata constructor for order hint keys on output nodes. +/// +/// When the targets of a dataflow region are represented by an output operation +/// within the region, this metadata can be attached the region to give the +/// output node an order hint key. +/// +/// - **Parameter:** `?key : core.nat` +/// - **Result:** `core.meta` +pub const ORDER_HINT_OUTPUT_KEY: &str = "core.order_hint.output_key"; + /// Metadata constructor for order hints. /// /// When this metadata is attached to a dataflow region, it can indicate a diff --git a/hugr-model/tests/fixtures/model-order.edn b/hugr-model/tests/fixtures/model-order.edn index 57cae40b86..354007f7c4 100644 --- a/hugr-model/tests/fixtures/model-order.edn +++ b/hugr-model/tests/fixtures/model-order.edn @@ -28,6 +28,11 @@ (meta (core.order_hint.order 1 0)) (meta (core.order_hint.order 2 3)) (meta (core.order_hint.order 0 3)) + (meta (core.order_hint.input_key 4)) + (meta (core.order_hint.order 4 0)) + (meta (core.order_hint.order 4 5)) + (meta (core.order_hint.order 1 5)) + (meta (core.order_hint.output_key 5)) ((arithmetic.int.ineg 6) [%0] [%4] diff --git a/hugr-py/src/hugr/model/export.py b/hugr-py/src/hugr/model/export.py index a1b3fd532a..a652683ba3 100644 --- a/hugr-py/src/hugr/model/export.py +++ b/hugr-py/src/hugr/model/export.py @@ -71,7 +71,7 @@ def export_node( meta = self.export_json_meta(node) # Add an order hint key to the node if necessary - if _needs_order_key(self.hugr, node): + if _has_order_links(self.hugr, node): meta.append(model.Apply("core.order_hint.key", [model.Literal(node.idx)])) match node_data.op: @@ -411,6 +411,13 @@ def export_region_dfg(self, node: Node) -> model.Region: for i in range(child_data._num_outs) ] + if _has_order_links(self.hugr, child): + meta.append( + model.Apply( + "core.order_hint.input_key", [model.Literal(child.idx)] + ) + ) + case Output() as op: target_types = model.List([type.to_model() for type in op.types]) targets = [ @@ -418,6 +425,13 @@ def export_region_dfg(self, node: Node) -> model.Region: for i in range(child_data._num_inps) ] + if _has_order_links(self.hugr, child): + meta.append( + model.Apply( + "core.order_hint.output_key", [model.Literal(child.idx)] + ) + ) + case _: child_node = self.export_node(child) @@ -426,14 +440,13 @@ def export_region_dfg(self, node: Node) -> model.Region: children.append(child_node) - meta += [ - model.Apply( - "core.order_hint.order", - [model.Literal(child.idx), model.Literal(successor.idx)], - ) - for successor in self.hugr.outgoing_order_links(child) - if not isinstance(self.hugr[successor].op, Output) - ] + meta += [ + model.Apply( + "core.order_hint.order", + [model.Literal(child.idx), model.Literal(successor.idx)], + ) + for successor in self.hugr.outgoing_order_links(child) + ] signature = model.Apply("core.fn", [source_types, target_types]) @@ -618,19 +631,12 @@ def union(self, a: T, b: T): self.sizes[a] += self.sizes[b] -def _needs_order_key(hugr: Hugr, node: Node) -> bool: - """Checks whether the node has any order links for the purposes of - exporting order hint metadata. Order links to `Input` or `Output` - operations are ignored, since they are not present in the model format. - """ - for succ in hugr.outgoing_order_links(node): - succ_op = hugr[succ].op - if not isinstance(succ_op, Output): - return True - - for pred in hugr.incoming_order_links(node): - pred_op = hugr[pred].op - if not isinstance(pred_op, Input): - return True +def _has_order_links(hugr: Hugr, node: Node) -> bool: + """Checks whether the node has any order links.""" + for _succ in hugr.outgoing_order_links(node): + return True + + for _pred in hugr.incoming_order_links(node): + return True return False