diff --git a/hugr-passes/src/linearize_array.rs b/hugr-passes/src/linearize_array.rs index dbec6b8c03..07fbc6e958 100644 --- a/hugr-passes/src/linearize_array.rs +++ b/hugr-passes/src/linearize_array.rs @@ -141,7 +141,7 @@ impl LinearizeArrayPass { #[cfg(test)] mod test { - use hugr_core::builder::{FunctionBuilder, ModuleBuilder}; + use hugr_core::builder::ModuleBuilder; use hugr_core::extension::prelude::{ConstUsize, Noop}; use hugr_core::ops::handle::NodeHandle; use hugr_core::ops::{Const, OpType}; @@ -287,7 +287,7 @@ mod test { ), }; let sig = Signature::new(src, tgt); - let mut builder = FunctionBuilder::new("main", sig).unwrap(); + let mut builder = DFGBuilder::new(sig).unwrap(); let [arr] = builder.input_wires_arr(); let op: OpType = match dir { INTO => VArrayToArray::new(elem_ty.clone(), size).into(), @@ -313,7 +313,7 @@ mod test { #[case(value_array_type(2, Type::new_tuple(vec![usize_t(), value_array_type(4, usize_t())])))] fn implicit_clone(#[case] array_ty: Type) { let sig = Signature::new(array_ty.clone(), vec![array_ty; 2]); - let mut builder = FunctionBuilder::new("main", sig).unwrap(); + let mut builder = DFGBuilder::new(sig).unwrap(); let [arr] = builder.input_wires_arr(); builder.set_outputs(vec![arr, arr]).unwrap(); @@ -329,7 +329,7 @@ mod test { #[case(value_array_type(2, Type::new_tuple(vec![usize_t(), value_array_type(4, usize_t())])))] fn implicit_discard(#[case] array_ty: Type) { let sig = Signature::new(array_ty, Type::EMPTY_TYPEROW); - let mut builder = FunctionBuilder::new("main", sig).unwrap(); + let mut builder = DFGBuilder::new(sig).unwrap(); builder.set_outputs(vec![]).unwrap(); let mut hugr = builder.finish_hugr().unwrap(); diff --git a/hugr-passes/src/replace_types.rs b/hugr-passes/src/replace_types.rs index e6c8d63a5e..6e26e628b5 100644 --- a/hugr-passes/src/replace_types.rs +++ b/hugr-passes/src/replace_types.rs @@ -237,6 +237,7 @@ pub struct ReplaceTypes { ParametricType, Arc Result, ReplaceTypesError>>, >, + regions: Option>, } impl Default for ReplaceTypes { @@ -298,6 +299,7 @@ impl ReplaceTypes { param_ops: Default::default(), consts: Default::default(), param_consts: Default::default(), + regions: None, } } @@ -435,6 +437,14 @@ impl ReplaceTypes { self.param_consts.insert(src_ty.into(), Arc::new(const_fn)); } + /// Set the regions of the Hugr to which this pass should be applied. + /// + /// If not set, the pass is applied to the whole Hugr. + /// Each call to overwrites any previous calls to `set_regions`. + pub fn set_regions(&mut self, regions: impl IntoIterator) { + self.regions = Some(regions.into_iter().collect()); + } + fn change_node( &self, hugr: &mut impl HugrMut, @@ -600,11 +610,21 @@ impl> ComposablePass for ReplaceTypes { type Result = bool; fn run(&self, hugr: &mut H) -> Result { + let temp: Vec; // keep alive + let regions = match self.regions { + Some(ref regs) => regs, + None => { + temp = vec![hugr.module_root()]; + &temp + } + }; let mut changed = false; - for n in hugr.entry_descendants().collect::>() { - changed |= self.change_node(hugr, n)?; - if n != hugr.entrypoint() && changed { - self.linearize_outputs(hugr, n)?; + for region_root in regions { + for n in hugr.descendants(*region_root).collect::>() { + changed |= self.change_node(hugr, n)?; + if n != hugr.entrypoint() && changed { + self.linearize_outputs(hugr, n)?; + } } } Ok(changed) @@ -660,7 +680,8 @@ mod test { use crate::replace_types::handlers::generic_array_const; use hugr_core::builder::{ BuildError, Container, DFGBuilder, Dataflow, DataflowHugr, DataflowSubContainer, - FunctionBuilder, HugrBuilder, ModuleBuilder, SubContainer, TailLoopBuilder, inout_sig, + FunctionBuilder, HugrBuilder, ModuleBuilder, SubContainer, TailLoopBuilder, endo_sig, + inout_sig, }; use hugr_core::extension::prelude::{ ConstUsize, UnwrapBuilder, bool_t, option_type, qb_t, usize_t, @@ -681,8 +702,10 @@ mod test { VArrayOp, VArrayOpDef, VArrayValue, ValueArray, value_array_type, }; - use hugr_core::types::{PolyFuncType, Signature, SumType, Type, TypeArg, TypeBound, TypeRow}; - use hugr_core::{Extension, HugrView, type_row}; + use hugr_core::types::{ + EdgeKind, PolyFuncType, Signature, SumType, Type, TypeArg, TypeBound, TypeRow, + }; + use hugr_core::{Direction, Extension, HugrView, Port, type_row}; use itertools::Itertools; use rstest::rstest; @@ -1117,8 +1140,8 @@ mod test { where GenericArrayValue: CustomConst, { - let sig = inout_sig(type_row![], AK::ty(vals.len() as _, usize_t())); - let mut dfb = FunctionBuilder::new("main", sig).unwrap(); + let mut dfb = + DFGBuilder::new(inout_sig(type_row![], AK::ty(vals.len() as _, usize_t()))).unwrap(); let c = dfb.add_load_value(GenericArrayValue::::new( usize_t(), vals.iter().map(|u| ConstUsize::new(*u).into()), @@ -1192,4 +1215,44 @@ mod test { .collect_vec(); assert_eq!(ext_op_names, ["get", "itousize", "panic",]); } + + #[test] + fn regions() { + let ext = ext(); + let coln = ext.get_type(PACKED_VEC).unwrap(); + let c_u = Type::new_extension(coln.instantiate(&[usize_t().into()]).unwrap()); + let mut h = { + let db = DFGBuilder::new(endo_sig(c_u.clone())).unwrap(); + let inps = db.input_wires(); + db.finish_hugr_with_outputs(inps) + } + .unwrap(); + let mut lowerer = lowerer(&ext); + + { + let backup = h.clone(); + lowerer.set_regions(vec![]); + assert!(!lowerer.run(&mut h).unwrap()); + assert_eq!(h, backup); + } + + let ep = h.entrypoint(); + lowerer.set_regions(vec![h.entrypoint()]); + assert!(lowerer.run(&mut h).unwrap()); + let v_u = value_array_type(64, usize_t()); + assert_eq!(h.signature(ep).unwrap().as_ref(), &endo_sig(v_u.clone())); + assert_eq!(h.num_nodes(), h.num_nodes()); + let [f_in, _] = h.get_io(h.get_parent(ep).unwrap()).unwrap(); + assert_eq!( + h.validate(), + Err(ValidationError::IncompatiblePorts { + from: f_in, + from_port: Port::new(Direction::Outgoing, 0), + to: ep, + to_port: Port::new(Direction::Incoming, 0), + from_kind: Box::new(EdgeKind::Value(c_u)), + to_kind: Box::new(EdgeKind::Value(v_u)) + }) + ); + } } diff --git a/hugr-passes/src/replace_types/linearize.rs b/hugr-passes/src/replace_types/linearize.rs index 6bf78f0383..edc6128813 100644 --- a/hugr-passes/src/replace_types/linearize.rs +++ b/hugr-passes/src/replace_types/linearize.rs @@ -375,7 +375,7 @@ mod test { use hugr_core::builder::{ BuildError, Container, DFGBuilder, Dataflow, DataflowHugr, DataflowSubContainer, - FunctionBuilder, HugrBuilder, inout_sig, + HugrBuilder, inout_sig, }; use hugr_core::extension::prelude::{option_type, qb_t, usize_t}; @@ -912,7 +912,7 @@ mod test { ); let build_hugr = |ty: Type| { - let mut dfb = FunctionBuilder::new("main", Signature::new(ty.clone(), vec![])).unwrap(); + let mut dfb = DFGBuilder::new(Signature::new(ty.clone(), vec![])).unwrap(); let [inp] = dfb.input_wires_arr(); let drop_op = drop_ext .instantiate_extension_op("drop", [ty.into()])