@@ -456,23 +456,23 @@ pub trait HugrView: HugrInternals {
456456 self . base_hugr ( ) . validate ( )
457457 }
458458
459- /// Extracts a HUGR containing the current entrypoint and all its
460- /// descendants.
459+ /// Extracts a HUGR containing the parent node and all its descendants.
461460 ///
462- /// Returns a new HUGR and a map from the nodes in the source HUGR to the nodes
463- /// in the extracted HUGR.
461+ /// Returns a new HUGR and a map from the nodes in the source HUGR to the
462+ /// nodes in the extracted HUGR. The new HUGR entrypoint corresponds to the
463+ /// extracted `parent` node.
464464 ///
465- /// Edges that connected to nodes outside the entrypoint are not be included
466- /// in the new HUGR.
465+ /// Edges that connected to nodes outside the parent node are not be
466+ /// included in the new HUGR.
467467 ///
468- /// If the entrypoint is not a module, the returned HUGR will contain some
468+ /// If the parent is not a module, the returned HUGR will contain some
469469 /// additional nodes to contain the new entrypoint. E.g. if the optype must
470470 /// be contained in a dataflow region, a module with a function definition
471471 /// will be created to contain it.
472- ///
473- /// If you need to extract the complete HUGR, move the entrypoint to the
474- /// [`HugrView::module_root`] first.
475- fn extract_hugr ( & self ) -> ( Hugr , impl ExtractionResult < Self :: Node > + ' static ) ;
472+ fn extract_hugr (
473+ & self ,
474+ parent : Self :: Node ,
475+ ) -> ( Hugr , impl ExtractionResult < Self :: Node > + ' static ) ;
476476}
477477
478478/// Records the result of extracting a Hugr via [HugrView::extract_hugr].
@@ -674,19 +674,36 @@ impl HugrView for Hugr {
674674 }
675675
676676 #[ inline]
677- fn extract_hugr ( & self ) -> ( Hugr , impl ExtractionResult < Node > + ' static ) {
677+ fn extract_hugr ( & self , target : Node ) -> ( Hugr , impl ExtractionResult < Node > + ' static ) {
678678 // Shortcircuit if the extracted HUGr is the same as the original
679- if self . entrypoint ( ) == self . module_root ( ) . node ( ) {
679+ if target == self . module_root ( ) . node ( ) {
680680 return ( self . clone ( ) , DefaultNodeMap ( HashMap :: new ( ) ) ) ;
681681 }
682682
683- let new_entrypoint_op = self . entrypoint_optype ( ) . clone ( ) ;
684- let mut extracted = Hugr :: new_with_entrypoint ( new_entrypoint_op) . unwrap ( ) ; // TODO: Handle error
683+ // Initialize a new HUGR with the desired entrypoint operation.
684+ // If we cannot create a new hugr with the parent's optype (e.g. if it's a `BasicBlock`),
685+ // find the first ancestor that can be extracted and use that instead.
686+ //
687+ // The final entrypoint will be set to the original `parent`.
688+ let mut parent = target;
689+ let mut extracted = loop {
690+ let parent_op = self . get_optype ( parent) . clone ( ) ;
691+ if let Ok ( hugr) = Hugr :: new_with_entrypoint ( parent_op) {
692+ break hugr;
693+ } ;
694+ // If the operation is not extractable, try the parent.
695+ // This loop always terminates, since at least the module root is extractable.
696+ parent = self
697+ . get_parent ( parent)
698+ . expect ( "The module root is always extractable" ) ;
699+ } ;
685700
701+ // The entrypoint and its parent in the newly created HUGR.
702+ // These will be replaced with nodes from the original HUGR.
686703 let old_entrypoint = extracted. entrypoint ( ) ;
687704 let old_parent = extracted. get_parent ( old_entrypoint) ;
688705
689- let inserted = extracted. insert_from_view ( old_entrypoint, self ) ;
706+ let inserted = extracted. insert_from_view ( old_entrypoint, & self . with_entrypoint ( parent ) ) ;
690707 let new_entrypoint = inserted. inserted_entrypoint ;
691708
692709 match old_parent {
@@ -710,7 +727,7 @@ impl HugrView for Hugr {
710727 } )
711728 . collect_vec ( ) ;
712729 // Replace the node
713- extracted. set_entrypoint ( new_entrypoint ) ;
730+ extracted. set_entrypoint ( inserted . node_map [ & target ] ) ;
714731 extracted. remove_node ( old_entrypoint) ;
715732 extracted. set_parent ( new_entrypoint, old_parent) ;
716733 // Reconnect the inputs and outputs to the new entrypoint
@@ -723,7 +740,7 @@ impl HugrView for Hugr {
723740 }
724741 // The entrypoint a module op
725742 None => {
726- extracted. set_entrypoint ( new_entrypoint ) ;
743+ extracted. set_entrypoint ( inserted . node_map [ & target ] ) ;
727744 extracted. set_module_root ( new_entrypoint) ;
728745 extracted. remove_node ( old_entrypoint) ;
729746 }
0 commit comments