Skip to content

Commit 660ebde

Browse files
committed
Add peel_loop.rs
1 parent f8b6d25 commit 660ebde

File tree

2 files changed

+146
-0
lines changed

2 files changed

+146
-0
lines changed

hugr-core/src/hugr/patch.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ pub mod inline_dfg;
66
pub mod insert_cut;
77
pub mod insert_identity;
88
pub mod outline_cfg;
9+
pub mod peel_loop;
910
mod port_types;
1011
pub mod replace;
1112
pub mod simple_replace;
Lines changed: 145 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,145 @@
1+
//! Rewrite to peel one iteration of a [TailLoop], creating a [DFG] containing a copy of
2+
//! the loop body, and a [Conditional] containing the original `TailLoop` node.
3+
use derive_more::{Display, Error};
4+
5+
use crate::core::HugrNode;
6+
use crate::ops::{Case, Conditional, DFG, DataflowOpTrait, Input, OpType, Output, TailLoop};
7+
use crate::types::Signature;
8+
use crate::{HugrView, Node, PortIndex};
9+
10+
use super::{HugrMut, PatchHugrMut, PatchVerification};
11+
12+
/// Rewrite peel one iteration of a [TailLoop] to a known [`FuncDefn`](OpType::FuncDefn)
13+
pub struct PeelTailLoop<N = Node>(N);
14+
15+
/// Error in performing [`PeelTailLoop`] rewrite.
16+
#[derive(Clone, Debug, Display, Error, PartialEq)]
17+
#[non_exhaustive]
18+
pub enum PeelTailLoopError<N = Node> {
19+
/// The specified Node was not a [TailLoop]
20+
#[display("Node to inline {_0} expected to be a TailLoop but actually {_1}")]
21+
NotTailLoop(N, OpType),
22+
}
23+
24+
impl<N> PeelTailLoop<N> {
25+
/// Create a new instance that will inline the specified node
26+
/// (i.e. that should be a [TailLoop])
27+
pub fn new(node: N) -> Self {
28+
Self(node)
29+
}
30+
}
31+
32+
impl<N: HugrNode> PatchVerification for PeelTailLoop<N> {
33+
type Error = PeelTailLoopError<N>;
34+
type Node = N;
35+
fn verify(&self, h: &impl HugrView<Node = N>) -> Result<(), Self::Error> {
36+
let opty = h.get_optype(self.0);
37+
if !opty.is_tail_loop() {
38+
return Err(PeelTailLoopError::NotTailLoop(self.0, opty.clone()));
39+
}
40+
Ok(())
41+
}
42+
43+
fn invalidation_set(&self) -> impl Iterator<Item = N> {
44+
Some(self.0).into_iter()
45+
}
46+
}
47+
48+
impl<N: HugrNode> PatchHugrMut for PeelTailLoop<N> {
49+
type Outcome = ();
50+
fn apply_hugr_mut(self, h: &mut impl HugrMut<Node = N>) -> Result<(), Self::Error> {
51+
self.verify(h)?; // Now we know we have a TailLoop.
52+
let tl = h.get_optype(self.0).as_tail_loop().unwrap();
53+
54+
let Signature {
55+
input: loop_in,
56+
output: loop_out,
57+
} = tl.signature().into_owned();
58+
let iter_outputs = tl.body_output_row().into_owned();
59+
let num_iter_outputs = iter_outputs.len();
60+
let dfg = h.add_node_before(
61+
self.0,
62+
DFG {
63+
signature: Signature::new(loop_in, iter_outputs.clone()),
64+
},
65+
);
66+
67+
h.copy_descendants(self.0, dfg, None);
68+
69+
let mut other_inputs = iter_outputs;
70+
let sum_rows = other_inputs
71+
.remove(0)
72+
.as_sum()
73+
.unwrap()
74+
.variants()
75+
.map(|r| r.clone().try_into().unwrap())
76+
.collect();
77+
78+
let cond_n = h.add_node_after(
79+
dfg,
80+
Conditional {
81+
sum_rows,
82+
other_inputs: other_inputs.into(),
83+
outputs: loop_out.clone(),
84+
},
85+
);
86+
debug_assert_eq!(
87+
h.signature(dfg).unwrap().output_types(),
88+
h.signature(cond_n).unwrap().input_types()
89+
);
90+
91+
for i in 0..num_iter_outputs {
92+
h.connect(cond_n, i, dfg, i);
93+
}
94+
let cond = h.get_optype(cond_n).as_conditional().unwrap();
95+
let case_in_rows = [0, 1].map(|i| cond.case_input_row(i).unwrap());
96+
// Stop borrowing `cond` as it borrows `h`
97+
let cases = case_in_rows.map(|in_row| {
98+
let n = h.add_node_with_parent(
99+
cond_n,
100+
Case {
101+
signature: Signature::new(in_row.clone(), loop_out.clone()),
102+
},
103+
);
104+
h.add_node_with_parent(n, Input { types: in_row });
105+
h.add_node_with_parent(
106+
n,
107+
Output {
108+
types: loop_out.clone(),
109+
},
110+
);
111+
n
112+
});
113+
114+
let [i, o] = h.get_io(cases[TailLoop::BREAK_TAG]).unwrap();
115+
for p in 0..loop_out.len() {
116+
h.connect(i, p, o, p);
117+
}
118+
119+
h.set_parent(self.0, cases[TailLoop::CONTINUE_TAG]);
120+
let [i, o] = h.get_io(cases[TailLoop::CONTINUE_TAG]).unwrap();
121+
// Inputs to original TailLoop are fed to DFG; TailLoop now takes inputs from Case(.Input)
122+
for inport in h.node_inputs(self.0).collect::<Vec<_>>() {
123+
for (src_n, src_p) in h.linked_outputs(self.0, inport).collect::<Vec<_>>() {
124+
h.connect(src_n, src_p, dfg, inport);
125+
}
126+
h.disconnect(self.0, inport);
127+
// Note this also creates an Order edge from Case.Input -> TailLoop if the loop had any Order predecessors
128+
h.connect(i, inport.index(), self.0, inport);
129+
}
130+
// Outputs from original TailLoop come from Conditional; TailLoop outputs go to Case(.Output)
131+
for outport in h.node_outputs(self.0).collect::<Vec<_>>() {
132+
for (tgt_n, tgt_p) in h.linked_inputs(self.0, outport).collect::<Vec<_>>() {
133+
h.connect(cond_n, outport, tgt_n, tgt_p);
134+
}
135+
h.disconnect(self.0, outport);
136+
// Note this also creates an Order edge from TailLoop -> Case.Output if the loop had any Order successors
137+
h.connect(self.0, outport, o, outport.index());
138+
}
139+
Ok(())
140+
}
141+
142+
/// Failure only occurs if the node is not a [TailLoop].
143+
/// (Any later failure means an invalid Hugr and `panic`.)
144+
const UNCHANGED_ON_FAILURE: bool = true;
145+
}

0 commit comments

Comments
 (0)