@@ -89,7 +89,7 @@ impl<N: HugrNode> PatchHugrMut for PeelTailLoop<N> {
8989 ) ;
9090
9191 for i in 0 ..num_iter_outputs {
92- h. connect ( cond_n , i, dfg , i) ;
92+ h. connect ( dfg , i, cond_n , i) ;
9393 }
9494 let cond = h. get_optype ( cond_n) . as_conditional ( ) . unwrap ( ) ;
9595 let case_in_rows = [ 0 , 1 ] . map ( |i| cond. case_input_row ( i) . unwrap ( ) ) ;
@@ -143,3 +143,109 @@ impl<N: HugrNode> PatchHugrMut for PeelTailLoop<N> {
143143 /// (Any later failure means an invalid Hugr and `panic`.)
144144 const UNCHANGED_ON_FAILURE : bool = true ;
145145}
146+
147+ #[ cfg( test) ]
148+ mod test {
149+ use itertools:: Itertools ;
150+
151+ use super :: { PeelTailLoop , PeelTailLoopError } ;
152+ use crate :: builder:: { Container , Dataflow , DataflowSubContainer , HugrBuilder } ;
153+ use crate :: builder:: { ModuleBuilder , test:: simple_dfg_hugr} ;
154+ use crate :: extension:: prelude:: { bool_t, usize_t} ;
155+ use crate :: ops:: handle:: NodeHandle ;
156+ use crate :: ops:: { OpTag , OpTrait } ;
157+ use crate :: types:: { Signature , Type , TypeRow } ;
158+ use crate :: { HugrView , hugr:: HugrMut } ;
159+
160+ #[ test]
161+ fn bad_peel ( ) {
162+ let backup = simple_dfg_hugr ( ) ;
163+ let opty = backup. entrypoint_optype ( ) . clone ( ) ;
164+ assert ! ( !opty. is_tail_loop( ) ) ;
165+ let mut h = backup. clone ( ) ;
166+ let r = h. apply_patch ( PeelTailLoop :: new ( h. entrypoint ( ) ) ) ;
167+ assert_eq ! (
168+ r,
169+ Err ( PeelTailLoopError :: NotTailLoop ( backup. entrypoint( ) , opty) )
170+ ) ;
171+ assert_eq ! ( h, backup) ;
172+ }
173+
174+ #[ test]
175+ fn peel_loop ( ) {
176+ // Note: after PR#2256, this can be updated to use `module_root_builder`
177+ let mut mb = ModuleBuilder :: new ( ) ;
178+ let helper = mb
179+ . declare (
180+ "helper" ,
181+ Signature :: new (
182+ vec ! [ bool_t( ) , usize_t( ) ] ,
183+ vec ! [ Type :: new_sum( [ vec![ bool_t( ) ] , vec![ ] ] ) , usize_t( ) ] ,
184+ )
185+ . into ( ) ,
186+ )
187+ . unwrap ( ) ;
188+ let mut fb = mb
189+ . define_function ( "main" , Signature :: new ( vec ! [ bool_t( ) , usize_t( ) ] , usize_t ( ) ) )
190+ . unwrap ( ) ;
191+ let [ b, u] = fb. input_wires_arr ( ) ;
192+ let ( tl, call) = {
193+ let mut tlb = fb
194+ . tail_loop_builder ( [ ( bool_t ( ) , b) ] , [ ( usize_t ( ) , u) ] , TypeRow :: new ( ) )
195+ . unwrap ( ) ;
196+ let [ b, u] = tlb. input_wires_arr ( ) ;
197+ let c = tlb. call ( & helper, & [ ] , [ b, u] ) . unwrap ( ) ;
198+ let [ pred, other] = c. outputs_arr ( ) ;
199+ ( tlb. finish_with_outputs ( pred, [ other] ) . unwrap ( ) , c. node ( ) )
200+ } ;
201+ fb. finish_with_outputs ( tl. outputs ( ) ) . unwrap ( ) ;
202+ let mut h = mb. finish_hugr ( ) . unwrap ( ) ;
203+
204+ h. apply_patch ( PeelTailLoop :: new ( tl. node ( ) ) ) . unwrap ( ) ;
205+ eprintln ! ( "ALAN {}" , h. mermaid_string( ) ) ;
206+ h. validate ( ) . unwrap ( ) ;
207+ let tags = |n| {
208+ let mut v = Vec :: new ( ) ;
209+ let mut o = Some ( n) ;
210+ while let Some ( n) = o {
211+ v. push ( h. get_optype ( n) . tag ( ) ) ;
212+ o = h. get_parent ( n) ;
213+ }
214+ v
215+ } ;
216+
217+ assert_eq ! (
218+ h. nodes( )
219+ . filter( |n| h. get_optype( * n) . is_tail_loop( ) )
220+ . collect_vec( ) ,
221+ [ tl. node( ) ]
222+ ) ;
223+ assert_eq ! (
224+ tags( call) ,
225+ [
226+ OpTag :: FnCall ,
227+ OpTag :: TailLoop ,
228+ OpTag :: Case ,
229+ OpTag :: Conditional ,
230+ OpTag :: FuncDefn ,
231+ OpTag :: ModuleRoot
232+ ]
233+ ) ;
234+ let [ c1, c2] = h
235+ . all_linked_inputs ( helper. node ( ) )
236+ . map ( |( n, _p) | n)
237+ . collect_array ( )
238+ . unwrap ( ) ;
239+ assert ! ( [ c1, c2] . contains( & call) ) ;
240+ let other = if call == c1 { c2 } else { c1 } ;
241+ assert_eq ! (
242+ tags( other) ,
243+ [
244+ OpTag :: FnCall ,
245+ OpTag :: Dfg ,
246+ OpTag :: FuncDefn ,
247+ OpTag :: ModuleRoot
248+ ]
249+ ) ;
250+ }
251+ }
0 commit comments