Skip to content
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
c09a884
Add peel_loop.rs
acl-cqc Jun 3, 2025
49de094
Test, fix Dfg -> Cond wiring
acl-cqc Jun 3, 2025
1839cc1
Test ext edge, fix handling of order input
acl-cqc Jun 4, 2025
535148d
Add TailLoop::control_variants
acl-cqc Jun 4, 2025
abf7bcf
Test order output, avoid unnecessary wire
acl-cqc Jun 4, 2025
6d4f4e2
Extend incoming-order-edges test to handle !=1 input-only types
acl-cqc Jun 4, 2025
15c153e
Replace TailLoop with DFG 'in-place'...easier, can leave edges intact
acl-cqc Jun 4, 2025
81c15ef
Make NotTailLoop into a struct with named fields
acl-cqc Jun 11, 2025
45c9e0b
fix comments, use optype_mut
acl-cqc Jun 11, 2025
9e97421
new -> try_new so can fix invalidation_set
acl-cqc Jun 11, 2025
8fbc8fa
Return proper error from constructor, Infallible later; rm hugr_modif…
acl-cqc Jun 11, 2025
508d8f3
imports
acl-cqc Jun 16, 2025
e08a8ce
invalidated_nodes
acl-cqc Jun 18, 2025
0dc9d98
Merge branch 'acl/invalidated_nodes' into acl/peel_tail_loop
acl-cqc Jun 18, 2025
a6712f7
Go back to single member, new(), fallible apply
acl-cqc Jun 18, 2025
30aebf1
Add SimpleReplacement-specific fn invalidation_set w/out HugrView
acl-cqc Jun 18, 2025
16ce795
Revert whitespace, oops
acl-cqc Jun 18, 2025
d952906
hugr-persistent doc
acl-cqc Jun 18, 2025
048f7e0
Merge remote-tracking branch 'origin/main' into acl/invalidated_nodes
acl-cqc Jun 18, 2025
c553122
Merge branch 'acl/invalidated_nodes' into acl/peel_tail_loop
acl-cqc Jun 18, 2025
a2a3c01
fmt
acl-cqc Jun 18, 2025
82b4221
std::mem::swap -> replace, much neater
acl-cqc Jun 18, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions hugr-core/src/hugr/patch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ pub mod inline_dfg;
pub mod insert_cut;
pub mod insert_identity;
pub mod outline_cfg;
pub mod peel_loop;
mod port_types;
pub mod replace;
pub mod simple_replace;
Expand Down
287 changes: 287 additions & 0 deletions hugr-core/src/hugr/patch/peel_loop.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,287 @@
//! Rewrite to peel one iteration of a [TailLoop], creating a [DFG] containing a copy of
//! the loop body, and a [Conditional] containing the original `TailLoop` node.
use derive_more::{Display, Error};

use crate::core::HugrNode;
use crate::ops::{
Case, Conditional, DFG, DataflowOpTrait, Input, OpTrait, OpType, Output, TailLoop,
};
use crate::types::Signature;
use crate::{Direction, HugrView, Node};

use super::{HugrMut, PatchHugrMut, PatchVerification};

/// Rewrite peel one iteration of a [TailLoop] to a known [`FuncDefn`](OpType::FuncDefn)
pub struct PeelTailLoop<N = Node>(N);

/// Error in performing [`PeelTailLoop`] rewrite.
#[derive(Clone, Debug, Display, Error, PartialEq)]
#[non_exhaustive]
pub enum PeelTailLoopError<N = Node> {
/// The specified Node was not a [TailLoop]
#[display("Node to inline {_0} expected to be a TailLoop but actually {_1}")]
NotTailLoop(N, OpType),
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: Please use named fields for variants unless that aren't direct wrappers over another error.
derive_more's Error derive assumes things about where's the source / backtrace in the positional fields of tuple structs. Naming things instead is more readable.

Suggested change
NotTailLoop(N, OpType),
#[display("Node to inline {node} expected to be a TailLoop but actually got a {op}")]
NotTailLoop {
/// The node being inlined.
node: N,
/// The actual non-tail loop operations.
op: OpType,
},

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ugh, this is the kind of thing that makes me wonder whether we should just use thiserror instead of derive_more - IIUC the argument against is just dependencies? This was +10 lines for, IMHO, no gain in readability :-(

However, you are right to flag my display message - "inline" a TailLoop, oops ;-)

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Mind you, tuple enums do work fine with derive_more (the special case is only for single element tuples) but named fields give a bit of info and specific docs to their fields (remember they are all public).

It just makes things easier to read from the docs, and avoids confusing error.0 lines (what's 0?).

}

impl<N> PeelTailLoop<N> {
/// Create a new instance that will inline the specified node
/// (i.e. that should be a [TailLoop])
pub fn new(node: N) -> Self {
Self(node)
}
}

impl<N: HugrNode> PatchVerification for PeelTailLoop<N> {
type Error = PeelTailLoopError<N>;
type Node = N;
fn verify(&self, h: &impl HugrView<Node = N>) -> Result<(), Self::Error> {
let opty = h.get_optype(self.0);
if !opty.is_tail_loop() {
return Err(PeelTailLoopError::NotTailLoop(self.0, opty.clone()));
}
Ok(())
}

fn invalidation_set(&self) -> impl Iterator<Item = N> {
Some(self.0).into_iter()
}
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we need to include other nodes here too?

From the trait definition:

A set of nodes referenced by the rewrite. Modifying any of these nodes will invalidate it.
Two impl Rewrites can be composed if their invalidation sets are disjoint.

Changing a descendant does not invalidate this rewrite, but applying the patch would invalidate any rewrite that references the tailLoop output node. I think all other nodes are safe to modify.

Copy link
Contributor Author

@acl-cqc acl-cqc Jun 11, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, excellent spot, thanks @aborgna-q !

However, fixing this turned out to be pretty painful. Neither invalidation_set nor new previously took the Hugr (the output node had not been available until verify/apply). So I added that to new. Now it seems that new needs to fail if the op isn't a Tailloop (or at least, not a dataflow container with an Output node - ok to panic if invalid, but not if we've pointed the rewrite at say a leaf op). So try_new.

Then apply really has no error to report (unless the Hugr has changed since the rewrite was constructed - I mean, that could be a panic). For the time being I have kept the error struct, but this is a bit of a PITA; to ease use (clarify errors that should be handled), I am tempted to remove the PeelTailLoopError::NotTailLoop variant. This leaves PeelTailLoop empty - well it is non-exhaustive, but it being non-exhaustive means you can't match it away, so I'm then further tempted to change the error type to just Infallible so we can use something like unwrap_infallible...

What do you think?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ugh, that's quite annoying.

I think try_new should return the error rather than a random optype, as it gives context about why it failed.

Setting the patch error to Infallible seems right. Rather than adding a new dependency you can just do an infallible pattern match.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yah....so I've done that. ("Never" gets closer? cool.) Seems a bit off for try_new to return a non_exhaustive error but there.

I think this is still not great tho. The right thing to do is probably to make invalidation_set take the Hugr as argument. Or for this rewrite/PR, we could require the output node to be passed into new (and then error in verify if it isn't the output node)...I mean obviously not the great user experience we'd like but it might do until we can get through deprecation/change to invalidation_set ?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Adding the hugr parameter to invalidation_set seems reasonable.

We should do a non-breaking bootstrap though, something like

    #[deprecated(since = "0.20.2", note = "Implement `invalidation_nodes` instead")]
    fn invalidation_set(&self) -> impl Iterator<Item = Self::Node> { }

    /// ...
    fn invalidation_nodes(
        &self,
        h: &impl HugrView<Node = Self::Node>,
    ) -> impl Iterator<Item = Self::Node> {
        let _ = h;
        #[expect(deprecated)]
        self.invalidation_set()
    }

add to the docs of invalidation_set that it will no longer be called directly so new traits should implement the new method instead and ignore the former.

}

impl<N: HugrNode> PatchHugrMut for PeelTailLoop<N> {
type Outcome = ();
fn apply_hugr_mut(self, h: &mut impl HugrMut<Node = N>) -> Result<(), Self::Error> {
self.verify(h)?; // Now we know we have a TailLoop!
let signature = h
.get_optype(self.0)
.dataflow_signature()
.unwrap()
.into_owned();
// Replace the TailLoop with a DFG - this maintains all external connections
let OpType::TailLoop(tl) = h.replace_op(self.0, DFG { signature }) else {
panic!("Wasn't a TailLoop ?!")
};
let sum_rows = Vec::from(tl.control_variants());
let rest = tl.rest.clone();
let Signature {
input: loop_in,
output: loop_out,
} = tl.signature().into_owned();

// Copy the DFG (ex-TailLoop) children into a new TailLoop *before* we add any more
let new_loop = h.add_node_after(self.0, tl); // Temporary parent
h.copy_descendants(self.0, new_loop, None);

// Add conditional inside DFG.
let [_, dfg_out] = h.get_io(self.0).unwrap();
let cond = Conditional {
sum_rows,
other_inputs: rest,
outputs: loop_out.clone(),
};
let case_in_rows = [0, 1].map(|i| cond.case_input_row(i).unwrap());
// This preserves all edges from the end of the loop body to the conditional:
h.replace_op(dfg_out, cond);
let cond_n = dfg_out;
h.add_ports(cond_n, Direction::Outgoing, loop_out.len() as isize + 1);
let dfg_out = h.add_node_before(
cond_n,
Output {
types: loop_out.clone(),
},
);
for p in 0..loop_out.len() {
h.connect(cond_n, p, dfg_out, p)
}

// Now wire up the internals of the Conditional
let cases = case_in_rows.map(|in_row| {
let signature = Signature::new(in_row.clone(), loop_out.clone());
let n = h.add_node_with_parent(cond_n, Case { signature });
h.add_node_with_parent(n, Input { types: in_row });
let types = loop_out.clone();
h.add_node_with_parent(n, Output { types });
n
});

h.set_parent(new_loop, cases[TailLoop::CONTINUE_TAG]);
let [ctn_in, ctn_out] = h.get_io(cases[TailLoop::CONTINUE_TAG]).unwrap();
let [brk_in, brk_out] = h.get_io(cases[TailLoop::BREAK_TAG]).unwrap();
for p in 0..loop_out.len() {
h.connect(brk_in, p, brk_out, p);
h.connect(new_loop, p, ctn_out, p)
}
for p in 0..loop_in.len() {
h.connect(ctn_in, p, new_loop, p);
}
Ok(())
}

/// Failure only occurs if the node is not a [TailLoop].
/// (Any later failure means an invalid Hugr and `panic`.)
const UNCHANGED_ON_FAILURE: bool = true;
}

#[cfg(test)]
mod test {
use itertools::Itertools;

use crate::builder::test::simple_dfg_hugr;
use crate::builder::{
Dataflow, DataflowHugr, DataflowSubContainer, FunctionBuilder, HugrBuilder,
};
use crate::extension::prelude::{bool_t, usize_t};
use crate::ops::{OpTag, OpTrait, Tag, TailLoop, handle::NodeHandle};
use crate::std_extensions::arithmetic::int_types::INT_TYPES;
use crate::types::{Signature, Type, TypeRow};
use crate::{HugrView, hugr::HugrMut};

use super::{PeelTailLoop, PeelTailLoopError};

#[test]
fn bad_peel() {
let backup = simple_dfg_hugr();
let opty = backup.entrypoint_optype().clone();
assert!(!opty.is_tail_loop());
let mut h = backup.clone();
let r = h.apply_patch(PeelTailLoop::new(h.entrypoint()));
assert_eq!(
r,
Err(PeelTailLoopError::NotTailLoop(backup.entrypoint(), opty))
);
assert_eq!(h, backup);
}

#[test]
fn peel_loop_incoming_edges() {
let i32_t = || INT_TYPES[5].clone();
let mut fb = FunctionBuilder::new(
"main",
Signature::new(vec![bool_t(), usize_t(), i32_t()], usize_t()),
)
.unwrap();
let helper = fb
.module_root_builder()
.declare(
"helper",
Signature::new(
vec![bool_t(), usize_t(), i32_t()],
vec![Type::new_sum([vec![bool_t(); 2], vec![]]), usize_t()],
)
.into(),
)
.unwrap();
let [b, u, i] = fb.input_wires_arr();
let (tl, call) = {
let mut tlb = fb
.tail_loop_builder(
[(bool_t(), b), (bool_t(), b)],
[(usize_t(), u)],
TypeRow::new(),
)
.unwrap();
let [b, _, u] = tlb.input_wires_arr();
// Static edge from FuncDecl, and 'ext' edge from function Input:
let c = tlb.call(&helper, &[], [b, u, i]).unwrap();
let [pred, other] = c.outputs_arr();
(tlb.finish_with_outputs(pred, [other]).unwrap(), c.node())
};
let mut h = fb.finish_hugr_with_outputs(tl.outputs()).unwrap();

h.apply_patch(PeelTailLoop::new(tl.node())).unwrap();
h.validate().unwrap();

assert_eq!(
h.nodes()
.filter(|n| h.get_optype(*n).is_tail_loop())
.count(),
1
);
use OpTag::*;
assert_eq!(tags(&h, call), [FnCall, Dfg, FuncDefn, ModuleRoot]);
let [c1, c2] = h
.all_linked_inputs(helper.node())
.map(|(n, _p)| n)
.collect_array()
.unwrap();
assert!([c1, c2].contains(&call));
let other = if call == c1 { c2 } else { c1 };
assert_eq!(
tags(&h, other),
[
FnCall,
TailLoop,
Case,
Conditional,
Dfg,
FuncDefn,
ModuleRoot
]
);
}

fn tags<H: HugrView>(h: &H, n: H::Node) -> Vec<OpTag> {
let mut v = Vec::new();
let mut o = Some(n);
while let Some(n) = o {
v.push(h.get_optype(n).tag());
o = h.get_parent(n);
}
v
}

#[test]
fn peel_loop_order_output() {
let i16_t = || INT_TYPES[4].clone();
let mut fb =
FunctionBuilder::new("main", Signature::new(vec![i16_t(), bool_t()], i16_t())).unwrap();

let [i, b] = fb.input_wires_arr();
let tl = {
let mut tlb = fb
.tail_loop_builder([(i16_t(), i), (bool_t(), b)], [], i16_t().into())
.unwrap();
let [i, _b] = tlb.input_wires_arr();
// This loop only goes round once. However, we do not expect this to affect
// peeling: *dataflow analysis* can tell us that the conditional will always
// take one Case (that does not contain the TailLoop), we do not do that here.
let [cont] = tlb
.add_dataflow_op(
Tag::new(
TailLoop::BREAK_TAG,
tlb.loop_signature().unwrap().control_variants().into(),
),
[i],
)
.unwrap()
.outputs_arr();
tlb.finish_with_outputs(cont, []).unwrap()
};
let [i2] = tl.outputs_arr();
// Create a DFG (no inputs, one output) that reads the result of the TailLoop via an 'ext` edge
let dfg = fb
.dfg_builder(Signature::new(vec![], i16_t()), [])
.unwrap()
.finish_with_outputs([i2])
.unwrap();
let mut h = fb.finish_hugr_with_outputs(dfg.outputs()).unwrap();
let tl = tl.node();

h.apply_patch(PeelTailLoop::new(tl)).unwrap();
h.validate().unwrap();
let [tl] = h
.nodes()
.filter(|n| h.get_optype(*n).is_tail_loop())
.collect_array()
.unwrap();
{
use OpTag::*;
assert_eq!(
tags(&h, tl),
[TailLoop, Case, Conditional, Dfg, FuncDefn, ModuleRoot]
);
}
let [out_n] = h.output_neighbours(tl).collect_array().unwrap();
assert!(h.get_optype(out_n).is_output());
assert_eq!(h.get_parent(tl), h.get_parent(out_n));
}
}
8 changes: 6 additions & 2 deletions hugr-core/src/ops/controlflow.rs
Original file line number Diff line number Diff line change
Expand Up @@ -59,12 +59,16 @@ impl TailLoop {

/// Build the output `TypeRow` of the child graph of a `TailLoop` node.
pub(crate) fn body_output_row(&self) -> TypeRow {
let sum_type = Type::new_sum([self.just_inputs.clone(), self.just_outputs.clone()]);
let mut outputs = vec![sum_type];
let mut outputs = vec![Type::new_sum(self.control_variants())];
outputs.extend_from_slice(&self.rest);
outputs.into()
}

/// The variants (continue / break) of the first output from the child graph
pub(crate) fn control_variants(&self) -> [TypeRow; 2] {
[self.just_inputs.clone(), self.just_outputs.clone()]
}

/// Build the input `TypeRow` of the child graph of a `TailLoop` node.
pub(crate) fn body_input_row(&self) -> TypeRow {
self.just_inputs.extend(self.rest.iter())
Expand Down
Loading