Skip to content

Commit 939c082

Browse files
acl-cqcss2165
authored andcommitted
feat: Rewrite for peeling a TailLoop (#2290)
Replaces the TailLoop with a DFG containing the body plus a Conditional, inside one Case of which there is a copy of the original TailLoop. (Uses `HugrMut::copy_descendants` to copy incoming edges without copying their sources.) Letting the new DFG take the Node of the old TailLoop allows to preserve all existing wiring (including e.g. order edges for nonlocals). closes #2107
1 parent d195340 commit 939c082

File tree

3 files changed

+305
-2
lines changed

3 files changed

+305
-2
lines changed

hugr-core/src/hugr/patch.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ pub mod inline_dfg;
66
pub mod insert_cut;
77
pub mod insert_identity;
88
pub mod outline_cfg;
9+
pub mod peel_loop;
910
mod port_types;
1011
pub mod replace;
1112
pub mod simple_replace;
Lines changed: 298 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,298 @@
1+
//! Rewrite to peel one iteration of a [TailLoop], creating a [DFG] containing a copy of
2+
//! the loop body, and a [Conditional] containing the original `TailLoop` node.
3+
use derive_more::{Display, Error};
4+
5+
use crate::core::HugrNode;
6+
use crate::ops::{
7+
Case, Conditional, DFG, DataflowOpTrait, Input, OpTrait, OpType, Output, TailLoop,
8+
};
9+
use crate::types::Signature;
10+
use crate::{Direction, HugrView, Node};
11+
12+
use super::{HugrMut, PatchHugrMut, PatchVerification};
13+
14+
/// Rewrite that peels one iteration of a [TailLoop] by turning the
15+
/// iteration test into a [Conditional].
16+
#[derive(Clone, Debug, PartialEq)]
17+
pub struct PeelTailLoop<N = Node>(N);
18+
19+
/// Error in performing [`PeelTailLoop`] rewrite.
20+
#[derive(Clone, Debug, Display, Error, PartialEq)]
21+
#[non_exhaustive]
22+
pub enum PeelTailLoopError<N = Node> {
23+
/// The specified Node was not a [`TailLoop`]
24+
#[display("Node to peel {node} expected to be a TailLoop but actually {op}")]
25+
NotTailLoop {
26+
/// The node requested to peel
27+
node: N,
28+
/// The actual (non-tail-loop) operation
29+
op: OpType,
30+
},
31+
}
32+
33+
impl<N> PeelTailLoop<N> {
34+
/// Create a new instance that will peel the specified [TailLoop] node
35+
pub fn new(node: N) -> Self {
36+
Self(node)
37+
}
38+
}
39+
40+
impl<N: HugrNode> PatchVerification for PeelTailLoop<N> {
41+
type Error = PeelTailLoopError<N>;
42+
type Node = N;
43+
fn verify(&self, h: &impl HugrView<Node = N>) -> Result<(), Self::Error> {
44+
let opty = h.get_optype(self.0);
45+
if !opty.is_tail_loop() {
46+
return Err(PeelTailLoopError::NotTailLoop {
47+
node: self.0,
48+
op: opty.clone(),
49+
});
50+
}
51+
Ok(())
52+
}
53+
54+
fn invalidated_nodes(&self, h: &impl HugrView<Node = N>) -> impl Iterator<Item = N> {
55+
h.get_io(self.0)
56+
.into_iter()
57+
.flat_map(|[_, output]| [self.0, output].into_iter())
58+
}
59+
}
60+
61+
impl<N: HugrNode> PatchHugrMut for PeelTailLoop<N> {
62+
type Outcome = ();
63+
fn apply_hugr_mut(self, h: &mut impl HugrMut<Node = N>) -> Result<(), Self::Error> {
64+
self.verify(h)?; // Now we know we have a TailLoop!
65+
let loop_ty = h.optype_mut(self.0);
66+
let signature = loop_ty.dataflow_signature().unwrap().into_owned();
67+
// Replace the TailLoop with a DFG - this maintains all external connections
68+
let OpType::TailLoop(tl) = std::mem::replace(loop_ty, DFG { signature }.into()) else {
69+
panic!("Wasn't a TailLoop ?!")
70+
};
71+
let sum_rows = Vec::from(tl.control_variants());
72+
let rest = tl.rest.clone();
73+
let Signature {
74+
input: loop_in,
75+
output: loop_out,
76+
} = tl.signature().into_owned();
77+
78+
// Copy the DFG (ex-TailLoop) children into a new TailLoop *before* we add any more
79+
let new_loop = h.add_node_after(self.0, tl); // Temporary parent
80+
h.copy_descendants(self.0, new_loop, None);
81+
82+
// Add conditional inside DFG.
83+
let [_, dfg_out] = h.get_io(self.0).unwrap();
84+
let cond = Conditional {
85+
sum_rows,
86+
other_inputs: rest,
87+
outputs: loop_out.clone(),
88+
};
89+
let case_in_rows = [0, 1].map(|i| cond.case_input_row(i).unwrap());
90+
// This preserves all edges from the end of the loop body to the conditional:
91+
h.replace_op(dfg_out, cond);
92+
let cond_n = dfg_out;
93+
h.add_ports(cond_n, Direction::Outgoing, loop_out.len() as isize + 1);
94+
let dfg_out = h.add_node_before(
95+
cond_n,
96+
Output {
97+
types: loop_out.clone(),
98+
},
99+
);
100+
for p in 0..loop_out.len() {
101+
h.connect(cond_n, p, dfg_out, p)
102+
}
103+
104+
// Now wire up the internals of the Conditional
105+
let cases = case_in_rows.map(|in_row| {
106+
let signature = Signature::new(in_row.clone(), loop_out.clone());
107+
let n = h.add_node_with_parent(cond_n, Case { signature });
108+
h.add_node_with_parent(n, Input { types: in_row });
109+
let types = loop_out.clone();
110+
h.add_node_with_parent(n, Output { types });
111+
n
112+
});
113+
114+
h.set_parent(new_loop, cases[TailLoop::CONTINUE_TAG]);
115+
let [ctn_in, ctn_out] = h.get_io(cases[TailLoop::CONTINUE_TAG]).unwrap();
116+
let [brk_in, brk_out] = h.get_io(cases[TailLoop::BREAK_TAG]).unwrap();
117+
for p in 0..loop_out.len() {
118+
h.connect(brk_in, p, brk_out, p);
119+
h.connect(new_loop, p, ctn_out, p)
120+
}
121+
for p in 0..loop_in.len() {
122+
h.connect(ctn_in, p, new_loop, p);
123+
}
124+
Ok(())
125+
}
126+
127+
/// Failure only occurs if the node is not a [TailLoop].
128+
/// (Any later failure means an invalid Hugr and `panic`.)
129+
const UNCHANGED_ON_FAILURE: bool = true;
130+
}
131+
132+
#[cfg(test)]
133+
mod test {
134+
use itertools::Itertools;
135+
136+
use crate::builder::test::simple_dfg_hugr;
137+
use crate::builder::{
138+
Dataflow, DataflowHugr, DataflowSubContainer, FunctionBuilder, HugrBuilder,
139+
};
140+
use crate::extension::prelude::{bool_t, usize_t};
141+
use crate::ops::{OpTag, OpTrait, Tag, TailLoop, handle::NodeHandle};
142+
use crate::std_extensions::arithmetic::int_types::INT_TYPES;
143+
use crate::types::{Signature, Type, TypeRow};
144+
use crate::{HugrView, hugr::HugrMut};
145+
146+
use super::{PeelTailLoop, PeelTailLoopError};
147+
148+
#[test]
149+
fn bad_peel() {
150+
let backup = simple_dfg_hugr();
151+
let op = backup.entrypoint_optype().clone();
152+
assert!(!op.is_tail_loop());
153+
let mut h = backup.clone();
154+
let r = h.apply_patch(PeelTailLoop::new(h.entrypoint()));
155+
assert_eq!(
156+
r,
157+
Err(PeelTailLoopError::NotTailLoop {
158+
node: backup.entrypoint(),
159+
op
160+
})
161+
);
162+
assert_eq!(h, backup);
163+
}
164+
165+
#[test]
166+
fn peel_loop_incoming_edges() {
167+
let i32_t = || INT_TYPES[5].clone();
168+
let mut fb = FunctionBuilder::new(
169+
"main",
170+
Signature::new(vec![bool_t(), usize_t(), i32_t()], usize_t()),
171+
)
172+
.unwrap();
173+
let helper = fb
174+
.module_root_builder()
175+
.declare(
176+
"helper",
177+
Signature::new(
178+
vec![bool_t(), usize_t(), i32_t()],
179+
vec![Type::new_sum([vec![bool_t(); 2], vec![]]), usize_t()],
180+
)
181+
.into(),
182+
)
183+
.unwrap();
184+
let [b, u, i] = fb.input_wires_arr();
185+
let (tl, call) = {
186+
let mut tlb = fb
187+
.tail_loop_builder(
188+
[(bool_t(), b), (bool_t(), b)],
189+
[(usize_t(), u)],
190+
TypeRow::new(),
191+
)
192+
.unwrap();
193+
let [b, _, u] = tlb.input_wires_arr();
194+
// Static edge from FuncDecl, and 'ext' edge from function Input:
195+
let c = tlb.call(&helper, &[], [b, u, i]).unwrap();
196+
let [pred, other] = c.outputs_arr();
197+
(tlb.finish_with_outputs(pred, [other]).unwrap(), c.node())
198+
};
199+
let mut h = fb.finish_hugr_with_outputs(tl.outputs()).unwrap();
200+
201+
h.apply_patch(PeelTailLoop::new(tl.node())).unwrap();
202+
h.validate().unwrap();
203+
204+
assert_eq!(
205+
h.nodes()
206+
.filter(|n| h.get_optype(*n).is_tail_loop())
207+
.count(),
208+
1
209+
);
210+
use OpTag::*;
211+
assert_eq!(tags(&h, call), [FnCall, Dfg, FuncDefn, ModuleRoot]);
212+
let [c1, c2] = h
213+
.all_linked_inputs(helper.node())
214+
.map(|(n, _p)| n)
215+
.collect_array()
216+
.unwrap();
217+
assert!([c1, c2].contains(&call));
218+
let other = if call == c1 { c2 } else { c1 };
219+
assert_eq!(
220+
tags(&h, other),
221+
[
222+
FnCall,
223+
TailLoop,
224+
Case,
225+
Conditional,
226+
Dfg,
227+
FuncDefn,
228+
ModuleRoot
229+
]
230+
);
231+
}
232+
233+
fn tags<H: HugrView>(h: &H, n: H::Node) -> Vec<OpTag> {
234+
let mut v = Vec::new();
235+
let mut o = Some(n);
236+
while let Some(n) = o {
237+
v.push(h.get_optype(n).tag());
238+
o = h.get_parent(n);
239+
}
240+
v
241+
}
242+
243+
#[test]
244+
fn peel_loop_order_output() {
245+
let i16_t = || INT_TYPES[4].clone();
246+
let mut fb =
247+
FunctionBuilder::new("main", Signature::new(vec![i16_t(), bool_t()], i16_t())).unwrap();
248+
249+
let [i, b] = fb.input_wires_arr();
250+
let tl = {
251+
let mut tlb = fb
252+
.tail_loop_builder([(i16_t(), i), (bool_t(), b)], [], i16_t().into())
253+
.unwrap();
254+
let [i, _b] = tlb.input_wires_arr();
255+
// This loop only goes round once. However, we do not expect this to affect
256+
// peeling: *dataflow analysis* can tell us that the conditional will always
257+
// take one Case (that does not contain the TailLoop), we do not do that here.
258+
let [cont] = tlb
259+
.add_dataflow_op(
260+
Tag::new(
261+
TailLoop::BREAK_TAG,
262+
tlb.loop_signature().unwrap().control_variants().into(),
263+
),
264+
[i],
265+
)
266+
.unwrap()
267+
.outputs_arr();
268+
tlb.finish_with_outputs(cont, []).unwrap()
269+
};
270+
let [i2] = tl.outputs_arr();
271+
// Create a DFG (no inputs, one output) that reads the result of the TailLoop via an 'ext` edge
272+
let dfg = fb
273+
.dfg_builder(Signature::new(vec![], i16_t()), [])
274+
.unwrap()
275+
.finish_with_outputs([i2])
276+
.unwrap();
277+
let mut h = fb.finish_hugr_with_outputs(dfg.outputs()).unwrap();
278+
let tl = tl.node();
279+
280+
h.apply_patch(PeelTailLoop::new(tl)).unwrap();
281+
h.validate().unwrap();
282+
let [tl] = h
283+
.nodes()
284+
.filter(|n| h.get_optype(*n).is_tail_loop())
285+
.collect_array()
286+
.unwrap();
287+
{
288+
use OpTag::*;
289+
assert_eq!(
290+
tags(&h, tl),
291+
[TailLoop, Case, Conditional, Dfg, FuncDefn, ModuleRoot]
292+
);
293+
}
294+
let [out_n] = h.output_neighbours(tl).collect_array().unwrap();
295+
assert!(h.get_optype(out_n).is_output());
296+
assert_eq!(h.get_parent(tl), h.get_parent(out_n));
297+
}
298+
}

hugr-core/src/ops/controlflow.rs

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -59,12 +59,16 @@ impl TailLoop {
5959

6060
/// Build the output `TypeRow` of the child graph of a `TailLoop` node.
6161
pub(crate) fn body_output_row(&self) -> TypeRow {
62-
let sum_type = Type::new_sum([self.just_inputs.clone(), self.just_outputs.clone()]);
63-
let mut outputs = vec![sum_type];
62+
let mut outputs = vec![Type::new_sum(self.control_variants())];
6463
outputs.extend_from_slice(&self.rest);
6564
outputs.into()
6665
}
6766

67+
/// The variants (continue / break) of the first output from the child graph
68+
pub(crate) fn control_variants(&self) -> [TypeRow; 2] {
69+
[self.just_inputs.clone(), self.just_outputs.clone()]
70+
}
71+
6872
/// Build the input `TypeRow` of the child graph of a `TailLoop` node.
6973
pub(crate) fn body_input_row(&self) -> TypeRow {
7074
self.just_inputs.extend(self.rest.iter())

0 commit comments

Comments
 (0)