|
| 1 | +//! Rewrite to peel one iteration of a [TailLoop], creating a [DFG] containing a copy of |
| 2 | +//! the loop body, and a [Conditional] containing the original `TailLoop` node. |
| 3 | +use derive_more::{Display, Error}; |
| 4 | + |
| 5 | +use crate::core::HugrNode; |
| 6 | +use crate::ops::{ |
| 7 | + Case, Conditional, DFG, DataflowOpTrait, Input, OpTrait, OpType, Output, TailLoop, |
| 8 | +}; |
| 9 | +use crate::types::Signature; |
| 10 | +use crate::{Direction, HugrView, Node}; |
| 11 | + |
| 12 | +use super::{HugrMut, PatchHugrMut, PatchVerification}; |
| 13 | + |
| 14 | +/// Rewrite that peels one iteration of a [TailLoop] by turning the |
| 15 | +/// iteration test into a [Conditional]. |
| 16 | +#[derive(Clone, Debug, PartialEq)] |
| 17 | +pub struct PeelTailLoop<N = Node>(N); |
| 18 | + |
| 19 | +/// Error in performing [`PeelTailLoop`] rewrite. |
| 20 | +#[derive(Clone, Debug, Display, Error, PartialEq)] |
| 21 | +#[non_exhaustive] |
| 22 | +pub enum PeelTailLoopError<N = Node> { |
| 23 | + /// The specified Node was not a [`TailLoop`] |
| 24 | + #[display("Node to peel {node} expected to be a TailLoop but actually {op}")] |
| 25 | + NotTailLoop { |
| 26 | + /// The node requested to peel |
| 27 | + node: N, |
| 28 | + /// The actual (non-tail-loop) operation |
| 29 | + op: OpType, |
| 30 | + }, |
| 31 | +} |
| 32 | + |
| 33 | +impl<N> PeelTailLoop<N> { |
| 34 | + /// Create a new instance that will peel the specified [TailLoop] node |
| 35 | + pub fn new(node: N) -> Self { |
| 36 | + Self(node) |
| 37 | + } |
| 38 | +} |
| 39 | + |
| 40 | +impl<N: HugrNode> PatchVerification for PeelTailLoop<N> { |
| 41 | + type Error = PeelTailLoopError<N>; |
| 42 | + type Node = N; |
| 43 | + fn verify(&self, h: &impl HugrView<Node = N>) -> Result<(), Self::Error> { |
| 44 | + let opty = h.get_optype(self.0); |
| 45 | + if !opty.is_tail_loop() { |
| 46 | + return Err(PeelTailLoopError::NotTailLoop { |
| 47 | + node: self.0, |
| 48 | + op: opty.clone(), |
| 49 | + }); |
| 50 | + } |
| 51 | + Ok(()) |
| 52 | + } |
| 53 | + |
| 54 | + fn invalidated_nodes(&self, h: &impl HugrView<Node = N>) -> impl Iterator<Item = N> { |
| 55 | + h.get_io(self.0) |
| 56 | + .into_iter() |
| 57 | + .flat_map(|[_, output]| [self.0, output].into_iter()) |
| 58 | + } |
| 59 | +} |
| 60 | + |
| 61 | +impl<N: HugrNode> PatchHugrMut for PeelTailLoop<N> { |
| 62 | + type Outcome = (); |
| 63 | + fn apply_hugr_mut(self, h: &mut impl HugrMut<Node = N>) -> Result<(), Self::Error> { |
| 64 | + self.verify(h)?; // Now we know we have a TailLoop! |
| 65 | + let loop_ty = h.optype_mut(self.0); |
| 66 | + let signature = loop_ty.dataflow_signature().unwrap().into_owned(); |
| 67 | + // Replace the TailLoop with a DFG - this maintains all external connections |
| 68 | + let OpType::TailLoop(tl) = std::mem::replace(loop_ty, DFG { signature }.into()) else { |
| 69 | + panic!("Wasn't a TailLoop ?!") |
| 70 | + }; |
| 71 | + let sum_rows = Vec::from(tl.control_variants()); |
| 72 | + let rest = tl.rest.clone(); |
| 73 | + let Signature { |
| 74 | + input: loop_in, |
| 75 | + output: loop_out, |
| 76 | + } = tl.signature().into_owned(); |
| 77 | + |
| 78 | + // Copy the DFG (ex-TailLoop) children into a new TailLoop *before* we add any more |
| 79 | + let new_loop = h.add_node_after(self.0, tl); // Temporary parent |
| 80 | + h.copy_descendants(self.0, new_loop, None); |
| 81 | + |
| 82 | + // Add conditional inside DFG. |
| 83 | + let [_, dfg_out] = h.get_io(self.0).unwrap(); |
| 84 | + let cond = Conditional { |
| 85 | + sum_rows, |
| 86 | + other_inputs: rest, |
| 87 | + outputs: loop_out.clone(), |
| 88 | + }; |
| 89 | + let case_in_rows = [0, 1].map(|i| cond.case_input_row(i).unwrap()); |
| 90 | + // This preserves all edges from the end of the loop body to the conditional: |
| 91 | + h.replace_op(dfg_out, cond); |
| 92 | + let cond_n = dfg_out; |
| 93 | + h.add_ports(cond_n, Direction::Outgoing, loop_out.len() as isize + 1); |
| 94 | + let dfg_out = h.add_node_before( |
| 95 | + cond_n, |
| 96 | + Output { |
| 97 | + types: loop_out.clone(), |
| 98 | + }, |
| 99 | + ); |
| 100 | + for p in 0..loop_out.len() { |
| 101 | + h.connect(cond_n, p, dfg_out, p) |
| 102 | + } |
| 103 | + |
| 104 | + // Now wire up the internals of the Conditional |
| 105 | + let cases = case_in_rows.map(|in_row| { |
| 106 | + let signature = Signature::new(in_row.clone(), loop_out.clone()); |
| 107 | + let n = h.add_node_with_parent(cond_n, Case { signature }); |
| 108 | + h.add_node_with_parent(n, Input { types: in_row }); |
| 109 | + let types = loop_out.clone(); |
| 110 | + h.add_node_with_parent(n, Output { types }); |
| 111 | + n |
| 112 | + }); |
| 113 | + |
| 114 | + h.set_parent(new_loop, cases[TailLoop::CONTINUE_TAG]); |
| 115 | + let [ctn_in, ctn_out] = h.get_io(cases[TailLoop::CONTINUE_TAG]).unwrap(); |
| 116 | + let [brk_in, brk_out] = h.get_io(cases[TailLoop::BREAK_TAG]).unwrap(); |
| 117 | + for p in 0..loop_out.len() { |
| 118 | + h.connect(brk_in, p, brk_out, p); |
| 119 | + h.connect(new_loop, p, ctn_out, p) |
| 120 | + } |
| 121 | + for p in 0..loop_in.len() { |
| 122 | + h.connect(ctn_in, p, new_loop, p); |
| 123 | + } |
| 124 | + Ok(()) |
| 125 | + } |
| 126 | + |
| 127 | + /// Failure only occurs if the node is not a [TailLoop]. |
| 128 | + /// (Any later failure means an invalid Hugr and `panic`.) |
| 129 | + const UNCHANGED_ON_FAILURE: bool = true; |
| 130 | +} |
| 131 | + |
| 132 | +#[cfg(test)] |
| 133 | +mod test { |
| 134 | + use itertools::Itertools; |
| 135 | + |
| 136 | + use crate::builder::test::simple_dfg_hugr; |
| 137 | + use crate::builder::{ |
| 138 | + Dataflow, DataflowHugr, DataflowSubContainer, FunctionBuilder, HugrBuilder, |
| 139 | + }; |
| 140 | + use crate::extension::prelude::{bool_t, usize_t}; |
| 141 | + use crate::ops::{OpTag, OpTrait, Tag, TailLoop, handle::NodeHandle}; |
| 142 | + use crate::std_extensions::arithmetic::int_types::INT_TYPES; |
| 143 | + use crate::types::{Signature, Type, TypeRow}; |
| 144 | + use crate::{HugrView, hugr::HugrMut}; |
| 145 | + |
| 146 | + use super::{PeelTailLoop, PeelTailLoopError}; |
| 147 | + |
| 148 | + #[test] |
| 149 | + fn bad_peel() { |
| 150 | + let backup = simple_dfg_hugr(); |
| 151 | + let op = backup.entrypoint_optype().clone(); |
| 152 | + assert!(!op.is_tail_loop()); |
| 153 | + let mut h = backup.clone(); |
| 154 | + let r = h.apply_patch(PeelTailLoop::new(h.entrypoint())); |
| 155 | + assert_eq!( |
| 156 | + r, |
| 157 | + Err(PeelTailLoopError::NotTailLoop { |
| 158 | + node: backup.entrypoint(), |
| 159 | + op |
| 160 | + }) |
| 161 | + ); |
| 162 | + assert_eq!(h, backup); |
| 163 | + } |
| 164 | + |
| 165 | + #[test] |
| 166 | + fn peel_loop_incoming_edges() { |
| 167 | + let i32_t = || INT_TYPES[5].clone(); |
| 168 | + let mut fb = FunctionBuilder::new( |
| 169 | + "main", |
| 170 | + Signature::new(vec![bool_t(), usize_t(), i32_t()], usize_t()), |
| 171 | + ) |
| 172 | + .unwrap(); |
| 173 | + let helper = fb |
| 174 | + .module_root_builder() |
| 175 | + .declare( |
| 176 | + "helper", |
| 177 | + Signature::new( |
| 178 | + vec![bool_t(), usize_t(), i32_t()], |
| 179 | + vec![Type::new_sum([vec![bool_t(); 2], vec![]]), usize_t()], |
| 180 | + ) |
| 181 | + .into(), |
| 182 | + ) |
| 183 | + .unwrap(); |
| 184 | + let [b, u, i] = fb.input_wires_arr(); |
| 185 | + let (tl, call) = { |
| 186 | + let mut tlb = fb |
| 187 | + .tail_loop_builder( |
| 188 | + [(bool_t(), b), (bool_t(), b)], |
| 189 | + [(usize_t(), u)], |
| 190 | + TypeRow::new(), |
| 191 | + ) |
| 192 | + .unwrap(); |
| 193 | + let [b, _, u] = tlb.input_wires_arr(); |
| 194 | + // Static edge from FuncDecl, and 'ext' edge from function Input: |
| 195 | + let c = tlb.call(&helper, &[], [b, u, i]).unwrap(); |
| 196 | + let [pred, other] = c.outputs_arr(); |
| 197 | + (tlb.finish_with_outputs(pred, [other]).unwrap(), c.node()) |
| 198 | + }; |
| 199 | + let mut h = fb.finish_hugr_with_outputs(tl.outputs()).unwrap(); |
| 200 | + |
| 201 | + h.apply_patch(PeelTailLoop::new(tl.node())).unwrap(); |
| 202 | + h.validate().unwrap(); |
| 203 | + |
| 204 | + assert_eq!( |
| 205 | + h.nodes() |
| 206 | + .filter(|n| h.get_optype(*n).is_tail_loop()) |
| 207 | + .count(), |
| 208 | + 1 |
| 209 | + ); |
| 210 | + use OpTag::*; |
| 211 | + assert_eq!(tags(&h, call), [FnCall, Dfg, FuncDefn, ModuleRoot]); |
| 212 | + let [c1, c2] = h |
| 213 | + .all_linked_inputs(helper.node()) |
| 214 | + .map(|(n, _p)| n) |
| 215 | + .collect_array() |
| 216 | + .unwrap(); |
| 217 | + assert!([c1, c2].contains(&call)); |
| 218 | + let other = if call == c1 { c2 } else { c1 }; |
| 219 | + assert_eq!( |
| 220 | + tags(&h, other), |
| 221 | + [ |
| 222 | + FnCall, |
| 223 | + TailLoop, |
| 224 | + Case, |
| 225 | + Conditional, |
| 226 | + Dfg, |
| 227 | + FuncDefn, |
| 228 | + ModuleRoot |
| 229 | + ] |
| 230 | + ); |
| 231 | + } |
| 232 | + |
| 233 | + fn tags<H: HugrView>(h: &H, n: H::Node) -> Vec<OpTag> { |
| 234 | + let mut v = Vec::new(); |
| 235 | + let mut o = Some(n); |
| 236 | + while let Some(n) = o { |
| 237 | + v.push(h.get_optype(n).tag()); |
| 238 | + o = h.get_parent(n); |
| 239 | + } |
| 240 | + v |
| 241 | + } |
| 242 | + |
| 243 | + #[test] |
| 244 | + fn peel_loop_order_output() { |
| 245 | + let i16_t = || INT_TYPES[4].clone(); |
| 246 | + let mut fb = |
| 247 | + FunctionBuilder::new("main", Signature::new(vec![i16_t(), bool_t()], i16_t())).unwrap(); |
| 248 | + |
| 249 | + let [i, b] = fb.input_wires_arr(); |
| 250 | + let tl = { |
| 251 | + let mut tlb = fb |
| 252 | + .tail_loop_builder([(i16_t(), i), (bool_t(), b)], [], i16_t().into()) |
| 253 | + .unwrap(); |
| 254 | + let [i, _b] = tlb.input_wires_arr(); |
| 255 | + // This loop only goes round once. However, we do not expect this to affect |
| 256 | + // peeling: *dataflow analysis* can tell us that the conditional will always |
| 257 | + // take one Case (that does not contain the TailLoop), we do not do that here. |
| 258 | + let [cont] = tlb |
| 259 | + .add_dataflow_op( |
| 260 | + Tag::new( |
| 261 | + TailLoop::BREAK_TAG, |
| 262 | + tlb.loop_signature().unwrap().control_variants().into(), |
| 263 | + ), |
| 264 | + [i], |
| 265 | + ) |
| 266 | + .unwrap() |
| 267 | + .outputs_arr(); |
| 268 | + tlb.finish_with_outputs(cont, []).unwrap() |
| 269 | + }; |
| 270 | + let [i2] = tl.outputs_arr(); |
| 271 | + // Create a DFG (no inputs, one output) that reads the result of the TailLoop via an 'ext` edge |
| 272 | + let dfg = fb |
| 273 | + .dfg_builder(Signature::new(vec![], i16_t()), []) |
| 274 | + .unwrap() |
| 275 | + .finish_with_outputs([i2]) |
| 276 | + .unwrap(); |
| 277 | + let mut h = fb.finish_hugr_with_outputs(dfg.outputs()).unwrap(); |
| 278 | + let tl = tl.node(); |
| 279 | + |
| 280 | + h.apply_patch(PeelTailLoop::new(tl)).unwrap(); |
| 281 | + h.validate().unwrap(); |
| 282 | + let [tl] = h |
| 283 | + .nodes() |
| 284 | + .filter(|n| h.get_optype(*n).is_tail_loop()) |
| 285 | + .collect_array() |
| 286 | + .unwrap(); |
| 287 | + { |
| 288 | + use OpTag::*; |
| 289 | + assert_eq!( |
| 290 | + tags(&h, tl), |
| 291 | + [TailLoop, Case, Conditional, Dfg, FuncDefn, ModuleRoot] |
| 292 | + ); |
| 293 | + } |
| 294 | + let [out_n] = h.output_neighbours(tl).collect_array().unwrap(); |
| 295 | + assert!(h.get_optype(out_n).is_output()); |
| 296 | + assert_eq!(h.get_parent(tl), h.get_parent(out_n)); |
| 297 | + } |
| 298 | +} |
0 commit comments