Skip to content

Commit a73bb36

Browse files
committed
Add node param to extract_hugr, and let it handle non-extractable roots
1 parent e3ae683 commit a73bb36

File tree

4 files changed

+76
-21
lines changed

4 files changed

+76
-21
lines changed

hugr-core/src/hugr/patch.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -165,7 +165,7 @@ impl<R: PatchHugrMut> PatchHugrMut for Transactional<R> {
165165
// TODO: This requires a full graph copy on each application.
166166
// Ideally we'd be able to just restore modified nodes, perhaps using a `HugrMut` wrapper
167167
// that keeps track of them.
168-
let (backup, backup_map) = h.with_entrypoint(h.module_root()).extract_hugr();
168+
let (backup, backup_map) = h.extract_hugr(h.module_root());
169169
let backup_root = backup_map.extracted_node(h.module_root());
170170
let backup_entrypoint = backup_map.extracted_node(h.entrypoint());
171171

hugr-core/src/hugr/views.rs

Lines changed: 35 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -456,23 +456,23 @@ pub trait HugrView: HugrInternals {
456456
self.base_hugr().validate()
457457
}
458458

459-
/// Extracts a HUGR containing the current entrypoint and all its
460-
/// descendants.
459+
/// Extracts a HUGR containing the parent node and all its descendants.
461460
///
462-
/// Returns a new HUGR and a map from the nodes in the source HUGR to the nodes
463-
/// in the extracted HUGR.
461+
/// Returns a new HUGR and a map from the nodes in the source HUGR to the
462+
/// nodes in the extracted HUGR. The new HUGR entrypoint corresponds to the
463+
/// extracted `parent` node.
464464
///
465-
/// Edges that connected to nodes outside the entrypoint are not be included
466-
/// in the new HUGR.
465+
/// Edges that connected to nodes outside the parent node are not be
466+
/// included in the new HUGR.
467467
///
468-
/// If the entrypoint is not a module, the returned HUGR will contain some
468+
/// If the parent is not a module, the returned HUGR will contain some
469469
/// additional nodes to contain the new entrypoint. E.g. if the optype must
470470
/// be contained in a dataflow region, a module with a function definition
471471
/// will be created to contain it.
472-
///
473-
/// If you need to extract the complete HUGR, move the entrypoint to the
474-
/// [`HugrView::module_root`] first.
475-
fn extract_hugr(&self) -> (Hugr, impl ExtractionResult<Self::Node> + 'static);
472+
fn extract_hugr(
473+
&self,
474+
parent: Self::Node,
475+
) -> (Hugr, impl ExtractionResult<Self::Node> + 'static);
476476
}
477477

478478
/// Records the result of extracting a Hugr via [HugrView::extract_hugr].
@@ -674,19 +674,36 @@ impl HugrView for Hugr {
674674
}
675675

676676
#[inline]
677-
fn extract_hugr(&self) -> (Hugr, impl ExtractionResult<Node> + 'static) {
677+
fn extract_hugr(&self, target: Node) -> (Hugr, impl ExtractionResult<Node> + 'static) {
678678
// Shortcircuit if the extracted HUGr is the same as the original
679-
if self.entrypoint() == self.module_root().node() {
679+
if target == self.module_root().node() {
680680
return (self.clone(), DefaultNodeMap(HashMap::new()));
681681
}
682682

683-
let new_entrypoint_op = self.entrypoint_optype().clone();
684-
let mut extracted = Hugr::new_with_entrypoint(new_entrypoint_op).unwrap(); // TODO: Handle error
683+
// Initialize a new HUGR with the desired entrypoint operation.
684+
// If we cannot create a new hugr with the parent's optype (e.g. if it's a `BasicBlock`),
685+
// find the first ancestor that can be extracted and use that instead.
686+
//
687+
// The final entrypoint will be set to the original `parent`.
688+
let mut parent = target;
689+
let mut extracted = loop {
690+
let parent_op = self.get_optype(parent).clone();
691+
if let Ok(hugr) = Hugr::new_with_entrypoint(parent_op) {
692+
break hugr;
693+
};
694+
// If the operation is not extractable, try the parent.
695+
// This loop always terminates, since at least the module root is extractable.
696+
parent = self
697+
.get_parent(parent)
698+
.expect("The module root is always extractable");
699+
};
685700

701+
// The entrypoint and its parent in the newly created HUGR.
702+
// These will be replaced with nodes from the original HUGR.
686703
let old_entrypoint = extracted.entrypoint();
687704
let old_parent = extracted.get_parent(old_entrypoint);
688705

689-
let inserted = extracted.insert_from_view(old_entrypoint, self);
706+
let inserted = extracted.insert_from_view(old_entrypoint, &self.with_entrypoint(parent));
690707
let new_entrypoint = inserted.inserted_entrypoint;
691708

692709
match old_parent {
@@ -710,7 +727,7 @@ impl HugrView for Hugr {
710727
})
711728
.collect_vec();
712729
// Replace the node
713-
extracted.set_entrypoint(new_entrypoint);
730+
extracted.set_entrypoint(inserted.node_map[&target]);
714731
extracted.remove_node(old_entrypoint);
715732
extracted.set_parent(new_entrypoint, old_parent);
716733
// Reconnect the inputs and outputs to the new entrypoint
@@ -723,7 +740,7 @@ impl HugrView for Hugr {
723740
}
724741
// The entrypoint a module op
725742
None => {
726-
extracted.set_entrypoint(new_entrypoint);
743+
extracted.set_entrypoint(inserted.node_map[&target]);
727744
extracted.set_module_root(new_entrypoint);
728745
extracted.remove_node(old_entrypoint);
729746
}

hugr-core/src/hugr/views/impls.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ macro_rules! hugr_view_methods {
7171
fn value_types(&self, node: Self::Node, dir: crate::Direction) -> impl Iterator<Item = (crate::Port, crate::types::Type)>;
7272
fn extensions(&self) -> &crate::extension::ExtensionRegistry;
7373
fn validate(&self) -> Result<(), crate::hugr::ValidationError>;
74-
fn extract_hugr(&self) -> (crate::Hugr, impl crate::hugr::views::ExtractionResult<Self::Node> + 'static);
74+
fn extract_hugr(&self, parent: Self::Node) -> (crate::Hugr, impl crate::hugr::views::ExtractionResult<Self::Node> + 'static);
7575
}
7676
}
7777
}

hugr-core/src/hugr/views/rerooted.rs

Lines changed: 39 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,7 @@ impl<H: HugrView> HugrView for Rerooted<H> {
107107
fn value_types(&self, node: Self::Node, dir: crate::Direction) -> impl Iterator<Item = (crate::Port, crate::types::Type)>;
108108
fn extensions(&self) -> &crate::extension::ExtensionRegistry;
109109
fn validate(&self) -> Result<(), crate::hugr::ValidationError>;
110-
fn extract_hugr(&self) -> (crate::Hugr, impl crate::hugr::views::ExtractionResult<Self::Node> + 'static);
110+
fn extract_hugr(&self, parent: Self::Node) -> (crate::Hugr, impl crate::hugr::views::ExtractionResult<Self::Node> + 'static);
111111
}
112112
}
113113
}
@@ -147,8 +147,10 @@ impl<H: HugrMut> HugrMut for Rerooted<H> {
147147

148148
#[cfg(test)]
149149
mod test {
150+
use crate::builder::test::simple_cfg_hugr;
150151
use crate::builder::{Dataflow, FunctionBuilder, HugrBuilder, SubContainer};
151152
use crate::hugr::internal::HugrMutInternals;
153+
use crate::hugr::views::ExtractionResult;
152154
use crate::hugr::HugrMut;
153155
use crate::ops::handle::NodeHandle;
154156
use crate::ops::{DataflowBlock, OpType};
@@ -194,4 +196,40 @@ mod test {
194196
assert!(h.entrypoint_optype().is_func_defn());
195197
assert!(h.get_optype(h.module_root().node()).is_module());
196198
}
199+
200+
#[test]
201+
fn extract_rerooted() {
202+
let mut hugr = simple_cfg_hugr();
203+
let cfg = hugr.entrypoint();
204+
let basic_block = hugr.first_child(cfg).unwrap();
205+
hugr.set_entrypoint(basic_block);
206+
assert!(hugr.get_optype(hugr.entrypoint()).is_dataflow_block());
207+
208+
let rerooted = hugr.with_entrypoint(cfg);
209+
assert!(rerooted.get_optype(rerooted.entrypoint()).is_cfg());
210+
211+
// Extract the basic block
212+
let (extracted_hugr, map) = rerooted.extract_hugr(basic_block);
213+
let extracted_cfg = map.extracted_node(cfg);
214+
let extracted_bb = map.extracted_node(basic_block);
215+
assert_eq!(extracted_hugr.entrypoint(), extracted_bb);
216+
assert!(extracted_hugr.get_optype(extracted_cfg).is_cfg());
217+
assert_eq!(
218+
extracted_hugr.first_child(extracted_cfg),
219+
Some(extracted_bb)
220+
);
221+
assert!(extracted_hugr.get_optype(extracted_bb).is_dataflow_block());
222+
223+
// Extract the cfg (and current entrypoint)
224+
let (extracted_hugr, map) = rerooted.extract_hugr(cfg);
225+
let extracted_cfg = map.extracted_node(cfg);
226+
let extracted_bb = map.extracted_node(basic_block);
227+
assert_eq!(extracted_hugr.entrypoint(), extracted_cfg);
228+
assert!(extracted_hugr.get_optype(extracted_cfg).is_cfg());
229+
assert_eq!(
230+
extracted_hugr.first_child(extracted_cfg),
231+
Some(extracted_bb)
232+
);
233+
assert!(extracted_hugr.get_optype(extracted_bb).is_dataflow_block());
234+
}
197235
}

0 commit comments

Comments
 (0)