Skip to content

Commit 59f63aa

Browse files
committed
Test, fix Dfg -> Cond wiring
1 parent 660ebde commit 59f63aa

File tree

1 file changed

+107
-1
lines changed

1 file changed

+107
-1
lines changed

hugr-core/src/hugr/patch/peel_loop.rs

Lines changed: 107 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ impl<N: HugrNode> PatchHugrMut for PeelTailLoop<N> {
8989
);
9090

9191
for i in 0..num_iter_outputs {
92-
h.connect(cond_n, i, dfg, i);
92+
h.connect(dfg, i, cond_n, i);
9393
}
9494
let cond = h.get_optype(cond_n).as_conditional().unwrap();
9595
let case_in_rows = [0, 1].map(|i| cond.case_input_row(i).unwrap());
@@ -143,3 +143,109 @@ impl<N: HugrNode> PatchHugrMut for PeelTailLoop<N> {
143143
/// (Any later failure means an invalid Hugr and `panic`.)
144144
const UNCHANGED_ON_FAILURE: bool = true;
145145
}
146+
147+
#[cfg(test)]
148+
mod test {
149+
use itertools::Itertools;
150+
151+
use super::{PeelTailLoop, PeelTailLoopError};
152+
use crate::builder::{Container, Dataflow, DataflowSubContainer, HugrBuilder};
153+
use crate::builder::{ModuleBuilder, test::simple_dfg_hugr};
154+
use crate::extension::prelude::{bool_t, usize_t};
155+
use crate::ops::handle::NodeHandle;
156+
use crate::ops::{OpTag, OpTrait};
157+
use crate::types::{Signature, Type, TypeRow};
158+
use crate::{HugrView, hugr::HugrMut};
159+
160+
#[test]
161+
fn bad_peel() {
162+
let backup = simple_dfg_hugr();
163+
let opty = backup.entrypoint_optype().clone();
164+
assert!(!opty.is_tail_loop());
165+
let mut h = backup.clone();
166+
let r = h.apply_patch(PeelTailLoop::new(h.entrypoint()));
167+
assert_eq!(
168+
r,
169+
Err(PeelTailLoopError::NotTailLoop(backup.entrypoint(), opty))
170+
);
171+
assert_eq!(h, backup);
172+
}
173+
174+
#[test]
175+
fn peel_loop() {
176+
// Note: after PR#2256, this can be updated to use `module_root_builder`
177+
let mut mb = ModuleBuilder::new();
178+
let helper = mb
179+
.declare(
180+
"helper",
181+
Signature::new(
182+
vec![bool_t(), usize_t()],
183+
vec![Type::new_sum([vec![bool_t()], vec![]]), usize_t()],
184+
)
185+
.into(),
186+
)
187+
.unwrap();
188+
let mut fb = mb
189+
.define_function("main", Signature::new(vec![bool_t(), usize_t()], usize_t()))
190+
.unwrap();
191+
let [b, u] = fb.input_wires_arr();
192+
let (tl, call) = {
193+
let mut tlb = fb
194+
.tail_loop_builder([(bool_t(), b)], [(usize_t(), u)], TypeRow::new())
195+
.unwrap();
196+
let [b, u] = tlb.input_wires_arr();
197+
let c = tlb.call(&helper, &[], [b, u]).unwrap();
198+
let [pred, other] = c.outputs_arr();
199+
(tlb.finish_with_outputs(pred, [other]).unwrap(), c.node())
200+
};
201+
fb.finish_with_outputs(tl.outputs()).unwrap();
202+
let mut h = mb.finish_hugr().unwrap();
203+
204+
h.apply_patch(PeelTailLoop::new(tl.node())).unwrap();
205+
eprintln!("ALAN {}", h.mermaid_string());
206+
h.validate().unwrap();
207+
let tags = |n| {
208+
let mut v = Vec::new();
209+
let mut o = Some(n);
210+
while let Some(n) = o {
211+
v.push(h.get_optype(n).tag());
212+
o = h.get_parent(n);
213+
}
214+
v
215+
};
216+
217+
assert_eq!(
218+
h.nodes()
219+
.filter(|n| h.get_optype(*n).is_tail_loop())
220+
.collect_vec(),
221+
[tl.node()]
222+
);
223+
assert_eq!(
224+
tags(call),
225+
[
226+
OpTag::FnCall,
227+
OpTag::TailLoop,
228+
OpTag::Case,
229+
OpTag::Conditional,
230+
OpTag::FuncDefn,
231+
OpTag::ModuleRoot
232+
]
233+
);
234+
let [c1, c2] = h
235+
.all_linked_inputs(helper.node())
236+
.map(|(n, _p)| n)
237+
.collect_array()
238+
.unwrap();
239+
assert!([c1, c2].contains(&call));
240+
let other = if call == c1 { c2 } else { c1 };
241+
assert_eq!(
242+
tags(other),
243+
[
244+
OpTag::FnCall,
245+
OpTag::Dfg,
246+
OpTag::FuncDefn,
247+
OpTag::ModuleRoot
248+
]
249+
);
250+
}
251+
}

0 commit comments

Comments
 (0)