Skip to content

Commit 2c2745e

Browse files
committed
Update composable passes
1 parent 10e9a75 commit 2c2745e

File tree

7 files changed

+40
-20
lines changed

7 files changed

+40
-20
lines changed

hugr-passes/src/composable.rs

Lines changed: 29 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -2,43 +2,48 @@
22
33
use std::{error::Error, marker::PhantomData};
44

5+
use hugr_core::core::HugrNode;
56
use hugr_core::hugr::{hugrmut::HugrMut, ValidationError};
67
use hugr_core::HugrView;
78
use itertools::Either;
89

910
/// An optimization pass that can be sequenced with another and/or wrapped
1011
/// e.g. by [ValidatingPass]
1112
pub trait ComposablePass: Sized {
13+
type Node: HugrNode;
1214
type Error: Error;
1315
type Result; // Would like to default to () but currently unstable
1416

15-
fn run(&self, hugr: &mut impl HugrMut) -> Result<Self::Result, Self::Error>;
17+
fn run(&self, hugr: &mut impl HugrMut<Node = Self::Node>) -> Result<Self::Result, Self::Error>;
1618

1719
fn map_err<E2: Error>(
1820
self,
1921
f: impl Fn(Self::Error) -> E2,
20-
) -> impl ComposablePass<Result = Self::Result, Error = E2> {
22+
) -> impl ComposablePass<Result = Self::Result, Error = E2, Node = Self::Node> {
2123
ErrMapper::new(self, f)
2224
}
2325

2426
/// Returns a [ComposablePass] that does "`self` then `other`", so long as
2527
/// `other::Err` can be combined with ours.
26-
fn then<P: ComposablePass, E: ErrorCombiner<Self::Error, P::Error>>(
28+
fn then<P: ComposablePass<Node = Self::Node>, E: ErrorCombiner<Self::Error, P::Error>>(
2729
self,
2830
other: P,
29-
) -> impl ComposablePass<Result = (Self::Result, P::Result), Error = E> {
31+
) -> impl ComposablePass<Result = (Self::Result, P::Result), Error = E, Node = Self::Node> {
3032
struct Sequence<E, P1, P2>(P1, P2, PhantomData<E>);
3133
impl<E, P1, P2> ComposablePass for Sequence<E, P1, P2>
3234
where
3335
P1: ComposablePass,
34-
P2: ComposablePass,
36+
P2: ComposablePass<Node = P1::Node>,
3537
E: ErrorCombiner<P1::Error, P2::Error>,
3638
{
39+
type Node = P1::Node;
3740
type Error = E;
38-
3941
type Result = (P1::Result, P2::Result);
4042

41-
fn run(&self, hugr: &mut impl HugrMut) -> Result<Self::Result, Self::Error> {
43+
fn run(
44+
&self,
45+
hugr: &mut impl HugrMut<Node = Self::Node>,
46+
) -> Result<Self::Result, Self::Error> {
4247
let res1 = self.0.run(hugr).map_err(E::from_first)?;
4348
let res2 = self.1.run(hugr).map_err(E::from_second)?;
4449
Ok((res1, res2))
@@ -95,10 +100,11 @@ impl<P: ComposablePass, E: Error, F: Fn(P::Error) -> E> ErrMapper<P, E, F> {
95100
}
96101

97102
impl<P: ComposablePass, E: Error, F: Fn(P::Error) -> E> ComposablePass for ErrMapper<P, E, F> {
103+
type Node = P::Node;
98104
type Error = E;
99105
type Result = P::Result;
100106

101-
fn run(&self, hugr: &mut impl HugrMut) -> Result<P::Result, Self::Error> {
107+
fn run(&self, hugr: &mut impl HugrMut<Node = Self::Node>) -> Result<P::Result, Self::Error> {
102108
self.0.run(hugr).map_err(&self.1)
103109
}
104110
}
@@ -157,10 +163,11 @@ impl<P: ComposablePass> ValidatingPass<P> {
157163
}
158164

159165
impl<P: ComposablePass> ComposablePass for ValidatingPass<P> {
166+
type Node = P::Node;
160167
type Error = ValidatePassError<P::Error>;
161168
type Result = P::Result;
162169

163-
fn run(&self, hugr: &mut impl HugrMut) -> Result<P::Result, Self::Error> {
170+
fn run(&self, hugr: &mut impl HugrMut<Node = Self::Node>) -> Result<P::Result, Self::Error> {
164171
self.validation_impl(hugr, |err, pretty_hugr| ValidatePassError::Input {
165172
err,
166173
pretty_hugr,
@@ -180,8 +187,11 @@ impl<P: ComposablePass> ComposablePass for ValidatingPass<P> {
180187
/// executes a second pass
181188
pub struct IfThen<E, A, B>(A, B, PhantomData<E>);
182189

183-
impl<A: ComposablePass<Result = bool>, B: ComposablePass, E: ErrorCombiner<A::Error, B::Error>>
184-
IfThen<E, A, B>
190+
impl<
191+
A: ComposablePass<Result = bool>,
192+
B: ComposablePass<Node = A::Node>,
193+
E: ErrorCombiner<A::Error, B::Error>,
194+
> IfThen<E, A, B>
185195
{
186196
/// Make a new instance given the [ComposablePass] to run first
187197
/// and (maybe) second
@@ -190,14 +200,17 @@ impl<A: ComposablePass<Result = bool>, B: ComposablePass, E: ErrorCombiner<A::Er
190200
}
191201
}
192202

193-
impl<A: ComposablePass<Result = bool>, B: ComposablePass, E: ErrorCombiner<A::Error, B::Error>>
194-
ComposablePass for IfThen<E, A, B>
203+
impl<
204+
A: ComposablePass<Result = bool>,
205+
B: ComposablePass<Node = A::Node>,
206+
E: ErrorCombiner<A::Error, B::Error>,
207+
> ComposablePass for IfThen<E, A, B>
195208
{
209+
type Node = A::Node;
196210
type Error = E;
197-
198211
type Result = Option<B::Result>;
199212

200-
fn run(&self, hugr: &mut impl HugrMut) -> Result<Self::Result, Self::Error> {
213+
fn run(&self, hugr: &mut impl HugrMut<Node = Self::Node>) -> Result<Self::Result, Self::Error> {
201214
let res: bool = self.0.run(hugr).map_err(ErrorCombiner::from_first)?;
202215
res.then(|| self.1.run(hugr).map_err(ErrorCombiner::from_second))
203216
.transpose()
@@ -206,7 +219,7 @@ impl<A: ComposablePass<Result = bool>, B: ComposablePass, E: ErrorCombiner<A::Er
206219

207220
pub(crate) fn validate_if_test<P: ComposablePass>(
208221
pass: P,
209-
hugr: &mut impl HugrMut,
222+
hugr: &mut impl HugrMut<Node = P::Node>,
210223
) -> Result<P::Result, ValidatePassError<P::Error>> {
211224
if cfg!(test) {
212225
ValidatingPass::new_default(pass).run(hugr)

hugr-passes/src/const_fold.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,7 @@ impl ConstantFoldPass {
7979
}
8080

8181
impl ComposablePass for ConstantFoldPass {
82+
type Node = Node;
8283
type Error = ConstFoldError;
8384
type Result = ();
8485

@@ -88,7 +89,7 @@ impl ComposablePass for ConstantFoldPass {
8889
///
8990
/// [ConstFoldError::InvalidEntryPoint] if an entry-point added by [Self::with_inputs]
9091
/// was of an invalid [OpType]
91-
fn run(&self, hugr: &mut impl HugrMut<Node = Node>) -> Result<(), ConstFoldError> {
92+
fn run(&self, hugr: &mut impl HugrMut<Node = Self::Node>) -> Result<(), ConstFoldError> {
9293
let fresh_node = Node::from(portgraph::NodeIndex::new(
9394
hugr.nodes().max().map_or(0, |n| n.index() + 1),
9495
));

hugr-passes/src/dead_code.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -158,6 +158,7 @@ impl DeadCodeElimPass {
158158
}
159159

160160
impl ComposablePass for DeadCodeElimPass {
161+
type Node = Node;
161162
type Error = Infallible;
162163
type Result = ();
163164

hugr-passes/src/dead_funcs.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,7 @@ impl RemoveDeadFuncsPass {
8383
}
8484

8585
impl ComposablePass for RemoveDeadFuncsPass {
86+
type Node = Node;
8687
type Error = RemoveDeadFuncsError;
8788
type Result = ();
8889
fn run(&self, hugr: &mut impl HugrMut<Node = Node>) -> Result<(), RemoveDeadFuncsError> {

hugr-passes/src/monomorphize.rs

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,9 @@ use crate::ComposablePass;
3333
/// children of the root node. We make best effort to ensure that names (derived
3434
/// from parent function names and concrete type args) of new functions are unique
3535
/// whenever the names of their parents are unique, but this is not guaranteed.
36-
pub fn monomorphize(hugr: &mut impl HugrMut) -> Result<(), ValidatePassError<Infallible>> {
36+
pub fn monomorphize(
37+
hugr: &mut impl HugrMut<Node = Node>,
38+
) -> Result<(), ValidatePassError<Infallible>> {
3739
validate_if_test(MonomorphizePass, hugr)
3840
}
3941

@@ -258,6 +260,7 @@ fn instantiate(
258260
pub struct MonomorphizePass;
259261

260262
impl ComposablePass for MonomorphizePass {
263+
type Node = Node;
261264
type Error = Infallible;
262265
type Result = ();
263266

hugr-passes/src/replace_types.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -513,10 +513,11 @@ impl ReplaceTypes {
513513
}
514514

515515
impl ComposablePass for ReplaceTypes {
516+
type Node = Node;
516517
type Error = ReplaceTypesError;
517518
type Result = bool;
518519

519-
fn run(&self, hugr: &mut impl HugrMut) -> Result<bool, ReplaceTypesError> {
520+
fn run(&self, hugr: &mut impl HugrMut<Node = Self::Node>) -> Result<bool, ReplaceTypesError> {
520521
let mut changed = false;
521522
for n in hugr.nodes().collect::<Vec<_>>() {
522523
changed |= self.change_node(hugr, n)?;

hugr-passes/src/untuple.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -122,8 +122,8 @@ impl UntuplePass {
122122
}
123123

124124
impl ComposablePass for UntuplePass {
125+
type Node = Node;
125126
type Error = UntupleError;
126-
127127
type Result = UntupleResult;
128128

129129
fn run(&self, hugr: &mut impl HugrMut<Node = Node>) -> Result<Self::Result, Self::Error> {

0 commit comments

Comments
 (0)