Skip to content

Commit 09de9e3

Browse files
authored
feat: Represent order edges in hugr-model as metadata. (#2027)
This PR introduces metadata to encode order edges in `hugr-model`. The children of a dataflow region can be assigned a key via `order_hint.key` metadata. Then `order_hint.order` metadata on the dataflow region encodes order edges between keys. The PR includes the necessary import and export code in `hugr-core`.
1 parent a16389f commit 09de9e3

File tree

8 files changed

+377
-54
lines changed

8 files changed

+377
-54
lines changed

hugr-core/src/export.rs

Lines changed: 57 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
//! Exporting HUGR graphs to their `hugr-model` representation.
22
use crate::{
33
extension::{ExtensionId, OpDef, SignatureFunc},
4-
hugr::{IdentList, NodeMetadataMap},
4+
hugr::IdentList,
55
ops::{
66
constant::CustomSerialized, DataflowBlock, DataflowOpTrait, OpName, OpTrait, OpType, Value,
77
},
@@ -13,10 +13,10 @@ use crate::{
1313
types::{
1414
type_param::{TypeArgVariable, TypeParam},
1515
type_row::TypeRowBase,
16-
CustomType, FuncTypeBase, MaybeRV, PolyFuncTypeBase, RowVariable, SumType, TypeArg,
17-
TypeBase, TypeBound, TypeEnum, TypeRow,
16+
CustomType, EdgeKind, FuncTypeBase, MaybeRV, PolyFuncTypeBase, RowVariable, SumType,
17+
TypeArg, TypeBase, TypeBound, TypeEnum, TypeRow,
1818
},
19-
Direction, Hugr, HugrView, IncomingPort, Node, Port,
19+
Direction, Hugr, HugrView, IncomingPort, Node, NodeIndex as _, Port,
2020
};
2121

2222
use fxhash::{FxBuildHasher, FxHashMap};
@@ -478,10 +478,7 @@ impl<'a> Context<'a> {
478478
let inputs = self.make_ports(node, Direction::Incoming, num_inputs);
479479
let outputs = self.make_ports(node, Direction::Outgoing, num_outputs);
480480

481-
let meta = match self.hugr.get_node_metadata(node) {
482-
Some(metadata_map) => self.export_node_metadata(metadata_map),
483-
None => &[],
484-
};
481+
let meta = self.export_node_metadata(node);
485482

486483
self.module.nodes[node_id.index()] = table::Node {
487484
operation,
@@ -588,10 +585,13 @@ impl<'a> Context<'a> {
588585
let mut targets: &[_] = &[];
589586
let mut input_types = None;
590587
let mut output_types = None;
588+
let mut meta = Vec::new();
591589

592590
let children = self.hugr.children(node);
593591
let mut region_children = BumpVec::with_capacity_in(children.size_hint().0 - 2, self.bump);
594592

593+
let mut output_node = None;
594+
595595
for child in children {
596596
match self.hugr.get_optype(child) {
597597
OpType::Input(input) => {
@@ -601,10 +601,27 @@ impl<'a> Context<'a> {
601601
OpType::Output(output) => {
602602
targets = self.make_ports(child, Direction::Incoming, output.types.len());
603603
output_types = Some(&output.types);
604+
output_node = Some(child);
604605
}
605-
_ => {
606+
child_optype => {
606607
if let Some(child_id) = self.export_node_shallow(child) {
607608
region_children.push(child_id);
609+
610+
// Record all order edges that originate from this node in metadata.
611+
let successors = child_optype
612+
.other_output_port()
613+
.into_iter()
614+
.flat_map(|port| self.hugr.linked_inputs(child, port))
615+
.map(|(successor, _)| successor)
616+
.filter(|successor| Some(*successor) != output_node);
617+
618+
for successor in successors {
619+
let a =
620+
self.make_term(model::Literal::Nat(child.index() as u64).into());
621+
let b = self
622+
.make_term(model::Literal::Nat(successor.index() as u64).into());
623+
meta.push(self.make_term_apply(model::ORDER_HINT_ORDER, &[a, b]));
624+
}
608625
}
609626
}
610627
}
@@ -634,7 +651,7 @@ impl<'a> Context<'a> {
634651
sources,
635652
targets,
636653
children: region_children.into_bump_slice(),
637-
meta: &[], // TODO: Export metadata
654+
meta: self.bump.alloc_slice_copy(&meta),
638655
signature,
639656
scope,
640657
};
@@ -1013,11 +1030,37 @@ impl<'a> Context<'a> {
10131030
}
10141031
}
10151032

1016-
pub fn export_node_metadata(&mut self, metadata_map: &NodeMetadataMap) -> &'a [table::TermId] {
1017-
let mut meta = BumpVec::with_capacity_in(metadata_map.len(), self.bump);
1033+
pub fn export_node_metadata(&mut self, node: Node) -> &'a [table::TermId] {
1034+
let metadata_map = self.hugr.get_node_metadata(node);
1035+
1036+
let has_order_edges = {
1037+
fn is_relevant_node(hugr: &Hugr, node: Node) -> bool {
1038+
let optype = hugr.get_optype(node);
1039+
!optype.is_input() && !optype.is_output()
1040+
}
1041+
1042+
let optype = self.hugr.get_optype(node);
1043+
1044+
Direction::BOTH
1045+
.iter()
1046+
.filter(|dir| optype.other_port_kind(**dir) == Some(EdgeKind::StateOrder))
1047+
.filter_map(|dir| optype.other_port(*dir))
1048+
.flat_map(|port| self.hugr.linked_ports(node, port))
1049+
.any(|(other, _)| is_relevant_node(self.hugr, other))
1050+
};
1051+
1052+
let meta_capacity = metadata_map.map_or(0, |map| map.len()) + has_order_edges as usize;
1053+
let mut meta = BumpVec::with_capacity_in(meta_capacity, self.bump);
1054+
1055+
if let Some(metadata_map) = metadata_map {
1056+
for (name, value) in metadata_map {
1057+
meta.push(self.export_json_meta(name, value));
1058+
}
1059+
}
10181060

1019-
for (name, value) in metadata_map {
1020-
meta.push(self.export_json_meta(name, value));
1061+
if has_order_edges {
1062+
let key = self.make_term(model::Literal::Nat(node.index() as u64).into());
1063+
meta.push(self.make_term_apply(model::ORDER_HINT_KEY, &[key]));
10211064
}
10221065

10231066
meta.into_bump_slice()

hugr-core/src/import.rs

Lines changed: 109 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ use std::sync::Arc;
77

88
use crate::{
99
extension::{ExtensionId, ExtensionRegistry, ExtensionSet, SignatureError},
10-
hugr::HugrMut,
10+
hugr::{HugrMut, NodeMetadata},
1111
ops::{
1212
constant::{CustomConst, CustomSerialized, OpaqueValue},
1313
AliasDecl, AliasDefn, Call, CallIndirect, Case, Conditional, Const, DataflowBlock,
@@ -68,6 +68,23 @@ pub enum ImportError {
6868
/// The model is not well-formed.
6969
#[error("validate error: {0}")]
7070
Model(#[from] table::ModelError),
71+
/// Incorrect order hints.
72+
#[error("incorrect order hint: {0}")]
73+
OrderHint(#[from] OrderHintError),
74+
}
75+
76+
/// Import error caused by incorrect order hints.
77+
#[derive(Debug, Clone, Error)]
78+
pub enum OrderHintError {
79+
/// Duplicate order hint key in the same region.
80+
#[error("duplicate order hint key {0}")]
81+
DuplicateKey(table::NodeId, u64),
82+
/// Order hint including a key not defined in the region.
83+
#[error("order hint with unknown key {0}")]
84+
UnknownKey(u64),
85+
/// Order hint involving a node with no order port.
86+
#[error("order hint on node with no order port: {0}")]
87+
NoOrderPort(table::NodeId),
7188
}
7289

7390
/// Helper macro to create an `ImportError::Unsupported` error with a formatted message.
@@ -197,11 +214,27 @@ impl<'a> Context<'a> {
197214
self.record_links(node, Direction::Incoming, node_data.inputs);
198215
self.record_links(node, Direction::Outgoing, node_data.outputs);
199216

217+
// Import the JSON metadata
200218
for meta_item in node_data.meta {
201-
// TODO: For now we expect all metadata to be JSON since this is how
202-
// it is handled in `hugr-core`.
203-
let (name, value) = self.import_json_meta(*meta_item)?;
204-
self.hugr.set_metadata(node, name, value);
219+
let Some([name_arg, json_arg]) =
220+
self.match_symbol(*meta_item, model::COMPAT_META_JSON)?
221+
else {
222+
continue;
223+
};
224+
225+
let table::Term::Literal(model::Literal::Str(name)) = self.get_term(name_arg)? else {
226+
return Err(table::ModelError::TypeError(*meta_item).into());
227+
};
228+
229+
let table::Term::Literal(model::Literal::Str(json_str)) = self.get_term(json_arg)?
230+
else {
231+
return Err(table::ModelError::TypeError(*meta_item).into());
232+
};
233+
234+
let json_value: NodeMetadata = serde_json::from_str(json_str)
235+
.map_err(|_| table::ModelError::TypeError(*meta_item))?;
236+
237+
self.hugr.set_metadata(node, name, json_value);
205238
}
206239

207240
Ok(node)
@@ -617,11 +650,82 @@ impl<'a> Context<'a> {
617650
self.import_node(*child, node)?;
618651
}
619652

653+
self.create_order_edges(region)?;
654+
620655
self.region_scope = prev_region;
621656

622657
Ok(())
623658
}
624659

660+
/// Create order edges between nodes of a dataflow region based on order hint metadata.
661+
///
662+
/// This method assumes that the nodes for the children of the region have already been imported.
663+
fn create_order_edges(&mut self, region_id: table::RegionId) -> Result<(), ImportError> {
664+
let region_data = self.get_region(region_id)?;
665+
debug_assert_eq!(region_data.kind, model::RegionKind::DataFlow);
666+
667+
// Collect order hint keys
668+
// PERFORMANCE: It might be worthwhile to reuse the map to avoid allocations.
669+
let mut order_keys = FxHashMap::<u64, table::NodeId>::default();
670+
671+
for child_id in region_data.children {
672+
let child_data = self.get_node(*child_id)?;
673+
674+
for meta_id in child_data.meta {
675+
let Some([key]) = self.match_symbol(*meta_id, model::ORDER_HINT_KEY)? else {
676+
continue;
677+
};
678+
679+
let table::Term::Literal(model::Literal::Nat(key)) = self.get_term(key)? else {
680+
continue;
681+
};
682+
683+
if order_keys.insert(*key, *child_id).is_some() {
684+
return Err(OrderHintError::DuplicateKey(*child_id, *key).into());
685+
}
686+
}
687+
}
688+
689+
// Insert order edges
690+
for meta_id in region_data.meta {
691+
let Some([a, b]) = self.match_symbol(*meta_id, model::ORDER_HINT_ORDER)? else {
692+
continue;
693+
};
694+
695+
let table::Term::Literal(model::Literal::Nat(a)) = self.get_term(a)? else {
696+
continue;
697+
};
698+
699+
let table::Term::Literal(model::Literal::Nat(b)) = self.get_term(b)? else {
700+
continue;
701+
};
702+
703+
let a = order_keys.get(a).ok_or(OrderHintError::UnknownKey(*a))?;
704+
let b = order_keys.get(b).ok_or(OrderHintError::UnknownKey(*b))?;
705+
706+
// NOTE: The lookups here are expected to succeed since we only
707+
// process the order metadata after we have imported the nodes.
708+
let a_node = self.nodes[a];
709+
let b_node = self.nodes[b];
710+
711+
let a_port = self
712+
.hugr
713+
.get_optype(a_node)
714+
.other_output_port()
715+
.ok_or(OrderHintError::NoOrderPort(*a))?;
716+
717+
let b_port = self
718+
.hugr
719+
.get_optype(b_node)
720+
.other_input_port()
721+
.ok_or(OrderHintError::NoOrderPort(*b))?;
722+
723+
self.hugr.connect(a_node, a_port, b_node, b_port);
724+
}
725+
726+
Ok(())
727+
}
728+
625729
fn import_adt_and_rest(
626730
&mut self,
627731
node_id: table::NodeId,
@@ -1358,28 +1462,6 @@ impl<'a> Context<'a> {
13581462
}
13591463
}
13601464

1361-
fn import_json_meta(
1362-
&mut self,
1363-
term_id: table::TermId,
1364-
) -> Result<(&'a str, serde_json::Value), ImportError> {
1365-
let [name_arg, json_arg] = self
1366-
.match_symbol(term_id, model::COMPAT_META_JSON)?
1367-
.ok_or(table::ModelError::TypeError(term_id))?;
1368-
1369-
let table::Term::Literal(model::Literal::Str(name)) = self.get_term(name_arg)? else {
1370-
return Err(table::ModelError::TypeError(term_id).into());
1371-
};
1372-
1373-
let table::Term::Literal(model::Literal::Str(json_str)) = self.get_term(json_arg)? else {
1374-
return Err(table::ModelError::TypeError(term_id).into());
1375-
};
1376-
1377-
let json_value =
1378-
serde_json::from_str(json_str).map_err(|_| table::ModelError::TypeError(term_id))?;
1379-
1380-
Ok((name, json_value))
1381-
}
1382-
13831465
fn import_value(
13841466
&mut self,
13851467
term_id: table::TermId,

hugr-core/tests/model.rs

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,3 +78,10 @@ pub fn test_roundtrip_const() {
7878
"../../hugr-model/tests/fixtures/model-const.edn"
7979
)));
8080
}
81+
82+
#[test]
83+
pub fn test_roundtrip_order() {
84+
insta::assert_snapshot!(roundtrip(include_str!(
85+
"../../hugr-model/tests/fixtures/model-order.edn"
86+
)));
87+
}
Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
---
2+
source: hugr-core/tests/model.rs
3+
expression: "roundtrip(include_str!(\"../../hugr-model/tests/fixtures/model-order.edn\"))"
4+
---
5+
(hugr 0)
6+
7+
(mod)
8+
9+
(import core.order_hint.key)
10+
11+
(import core.fn)
12+
13+
(import core.order_hint.order)
14+
15+
(import arithmetic.int.types.int)
16+
17+
(import arithmetic.int.ineg)
18+
19+
(define-func
20+
main
21+
(core.fn
22+
[arithmetic.int.types.int
23+
arithmetic.int.types.int
24+
arithmetic.int.types.int
25+
arithmetic.int.types.int]
26+
[arithmetic.int.types.int
27+
arithmetic.int.types.int
28+
arithmetic.int.types.int
29+
arithmetic.int.types.int])
30+
(dfg [%0 %1 %2 %3] [%4 %5 %6 %7]
31+
(signature
32+
(core.fn
33+
[arithmetic.int.types.int
34+
arithmetic.int.types.int
35+
arithmetic.int.types.int
36+
arithmetic.int.types.int]
37+
[arithmetic.int.types.int
38+
arithmetic.int.types.int
39+
arithmetic.int.types.int
40+
arithmetic.int.types.int]))
41+
(meta (core.order_hint.order 4 7))
42+
(meta (core.order_hint.order 5 6))
43+
(meta (core.order_hint.order 5 4))
44+
(meta (core.order_hint.order 6 7))
45+
(arithmetic.int.ineg [%0] [%4]
46+
(signature
47+
(core.fn [arithmetic.int.types.int] [arithmetic.int.types.int]))
48+
(meta (core.order_hint.key 4)))
49+
(arithmetic.int.ineg [%1] [%5]
50+
(signature
51+
(core.fn [arithmetic.int.types.int] [arithmetic.int.types.int]))
52+
(meta (core.order_hint.key 5)))
53+
(arithmetic.int.ineg [%2] [%6]
54+
(signature
55+
(core.fn [arithmetic.int.types.int] [arithmetic.int.types.int]))
56+
(meta (core.order_hint.key 6)))
57+
(arithmetic.int.ineg [%3] [%7]
58+
(signature
59+
(core.fn [arithmetic.int.types.int] [arithmetic.int.types.int]))
60+
(meta (core.order_hint.key 7)))))

0 commit comments

Comments
 (0)