-
Notifications
You must be signed in to change notification settings - Fork 13
feat: ReplaceTypes: recursively replace on types too #2442
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
1c9857f
52e31f2
082eddb
9392eca
6799dc0
9a0bec2
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -171,20 +171,39 @@ fn call<H: HugrView<Node = Node>>( | |
| Ok(Call::try_new(func_sig, type_args)?) | ||
| } | ||
|
|
||
| /// Options for how the replacement for an op is processed. | ||
| // TODO also need to apply to replacement consts: | ||
| /// Options for how a replacement (op or type) is processed. | ||
| /// | ||
| /// May be specified by [ReplaceTypes::replace_op_with] and [ReplaceTypes::replace_parametrized_op_with]. | ||
| /// Otherwise (the default), replacements are inserted as is (without further processing). | ||
| /// May be specified by | ||
| /// [ReplaceTypes::replace_op_with], [ReplaceTypes::replace_parametrized_op_with], | ||
| /// [ReplaceTypes::replace_type_opts] or [ReplaceTypes::replace_parametrized_type_opts]. | ||
| /// Otherwise (the default), replacements are inserted as given (without further processing). | ||
| // TODO would be good to migrate to default being process_recursive: true | ||
| #[derive(Clone, Default, PartialEq, Eq)] // More derives might inhibit future extension | ||
| pub struct ReplacementOptions { | ||
| linearize: bool, | ||
| process_recursive: bool, | ||
| linearize_unchanged: bool, | ||
| } | ||
|
|
||
| impl ReplacementOptions { | ||
| /// Specifies that all operations within the replacement should have their | ||
| /// Specifies whether the replacement (op or type) should be processed by the same | ||
| /// [ReplaceTypes], including linearization of any changed ops. This increases | ||
| /// compositionality (in that replacements for other types/ops do not need to have | ||
| /// already been applied to the RHS), but can lead to infinite looping if e.g. the | ||
| /// replacement for an op is a DFG containing an instance of the same op. | ||
| pub fn with_recursive_replacement(mut self, rec: bool) -> Self { | ||
| self.process_recursive = rec; | ||
| self | ||
| } | ||
|
|
||
| /// Specifies whether *all* nodes within the replacement should have their | ||
| /// output ports linearized. | ||
| /// | ||
| /// * If [Self::with_recursive_replacement] has been set, this causes linearization | ||
| /// to apply even to unchanged ops. | ||
| /// * Otherwise, just applies linearization (to all nodes) without changing any ops. | ||
| pub fn with_linearization(mut self, lin: bool) -> Self { | ||
| self.linearize = lin; | ||
| self.linearize_unchanged = lin; | ||
| self | ||
| } | ||
| } | ||
|
|
@@ -218,8 +237,9 @@ impl ReplacementOptions { | |
| /// [monomorphization]: super::monomorphize() | ||
| #[derive(Clone)] | ||
| pub struct ReplaceTypes { | ||
| type_map: HashMap<CustomType, Type>, | ||
| param_types: HashMap<ParametricType, Arc<dyn Fn(&[TypeArg]) -> Option<Type>>>, | ||
| type_map: HashMap<CustomType, (Type, ReplacementOptions)>, | ||
| param_types: | ||
| HashMap<ParametricType, (Arc<dyn Fn(&[TypeArg]) -> Option<Type>>, ReplacementOptions)>, | ||
| linearize: DelegatingLinearizer, | ||
| op_map: HashMap<OpHashWrapper, (NodeTemplate, ReplacementOptions)>, | ||
| param_ops: HashMap< | ||
|
|
@@ -255,19 +275,25 @@ impl TypeTransformer for ReplaceTypes { | |
| type Err = ReplaceTypesError; | ||
|
|
||
| fn apply_custom(&self, ct: &CustomType) -> Result<Option<Type>, Self::Err> { | ||
| Ok(if let Some(res) = self.type_map.get(ct) { | ||
| Some(res.clone()) | ||
| } else if let Some(dest_fn) = self.param_types.get(&ct.into()) { | ||
| let mut ty = None; | ||
| if let Some(res) = self.type_map.get(ct) { | ||
| ty = Some(res.clone()) | ||
| } else if let Some((dest_fn, opts)) = self.param_types.get(&ct.into()) { | ||
| // `ct` has not had args transformed | ||
| let mut nargs = ct.args().to_vec(); | ||
| // We don't care if `nargs` are changed, we're just calling `dest_fn` | ||
| nargs | ||
| .iter_mut() | ||
| .try_for_each(|ta| ta.transform(self).map(|_ch| ()))?; | ||
| dest_fn(&nargs) | ||
| } else { | ||
| None | ||
| }) | ||
| ty = dest_fn(&nargs).map(|ty| (ty, opts.clone())) | ||
| }; | ||
| let Some((mut ty, opts)) = ty else { | ||
| return Ok(None); | ||
| }; | ||
| if opts.process_recursive { | ||
| ty.transform(self)?; | ||
| } | ||
| Ok(Some(ty)) | ||
| } | ||
| } | ||
|
|
||
|
|
@@ -304,6 +330,14 @@ impl ReplaceTypes { | |
| } | ||
|
|
||
| /// Configures this instance to replace occurrences of type `src` with `dest`. | ||
| /// Equivalent to [Self::replace_type_opts] with [ReplacementOptions::default()] | ||
| pub fn replace_type(&mut self, src: CustomType, dest: Type) { | ||
| self.replace_type_opts(src, dest, ReplacementOptions::default()) | ||
| } | ||
|
|
||
| /// Configures this instance to replace occurrences of type `src` with `dest`, | ||
| /// according to the given `ReplacementOptions`. | ||
| /// | ||
| /// Note that if `src` is an instance of a *parametrized* [`TypeDef`], this takes | ||
| /// precedence over [`Self::replace_parametrized_type`] where the `src`s overlap. Thus, this | ||
| /// should only be used on already-*[monomorphize](super::monomorphize())d* Hugrs, as | ||
|
|
@@ -316,14 +350,28 @@ impl ReplaceTypes { | |
| /// Note that if `src` is Copyable and `dest` is Linear, then (besides linearity violations) | ||
| /// [`SignatureError`] will be raised if this leads to an impossible type e.g. ArrayOfCopyables(src). | ||
| /// (This can be overridden by an additional [`Self::replace_type`].) | ||
| pub fn replace_type(&mut self, src: CustomType, dest: Type) { | ||
| // We could check that 'dest' is copyable or 'src' is linear, but since we can't | ||
| // check that for parametrized types, we'll be consistent and not check here either. | ||
| self.type_map.insert(src, dest); | ||
| pub fn replace_type_opts(&mut self, src: CustomType, dest: Type, opts: ReplacementOptions) { | ||
| // We could check that 'dest' is copyable, 'src' is linear, or relevant copy and | ||
| // discard functions are registered with the linearizer; but since we can't check | ||
| // that for parametrized types, we'll be consistent and not check here either. | ||
| self.type_map.insert(src, (dest, opts)); | ||
| } | ||
|
|
||
| /// Configures this instance to change occurrences of a parametrized type `src` | ||
| /// via a callback that builds the replacement type given the [`TypeArg`]s. | ||
| /// Equivalent to [Self::replace_parametrized_type_opts] with [ReplacementOptions::default]. | ||
| pub fn replace_parametrized_type( | ||
| &mut self, | ||
| src: &TypeDef, | ||
| dest_fn: impl Fn(&[TypeArg]) -> Option<Type> + 'static, | ||
| ) { | ||
| self.replace_parametrized_type_opts(src, dest_fn, ReplacementOptions::default()) | ||
| } | ||
|
|
||
| /// Configures this instance to change occurrences of a parametrized type `src` | ||
| /// via a callback that builds the replacement type given the [`TypeArg`]s, | ||
| /// and using the given [ReplacementOptions]. | ||
| /// | ||
| /// Note that the `TypeArgs` will already have been updated (e.g. they may not | ||
| /// fit the bounds of the original type). The callback may return `None` to indicate | ||
| /// no change (in which case the supplied `TypeArgs` will be given to `src`). | ||
|
|
@@ -332,10 +380,11 @@ impl ReplaceTypes { | |
| /// [`Self::replace_consts_parametrized`] (or [`Self::replace_consts`]) as the | ||
| /// [`LoadConstant`]s will be reparametrized (and this will break the edge from [Const] to | ||
| /// [`LoadConstant`]). | ||
| pub fn replace_parametrized_type( | ||
| pub fn replace_parametrized_type_opts( | ||
| &mut self, | ||
| src: &TypeDef, | ||
| dest_fn: impl Fn(&[TypeArg]) -> Option<Type> + 'static, | ||
| opts: ReplacementOptions, | ||
| ) { | ||
| // No way to check that dest_fn never produces a linear type. | ||
| // We could require copy/discard-generators if src is Copyable, or *might be* | ||
|
|
@@ -346,7 +395,8 @@ impl ReplaceTypes { | |
| // dest_fn: impl Fn(&TypeArg) -> (Type, | ||
| // Fn(&Linearizer) -> NodeTemplate, // copy | ||
| // Fn(&Linearizer) -> NodeTemplate)` // discard | ||
| self.param_types.insert(src.into(), Arc::new(dest_fn)); | ||
| self.param_types | ||
| .insert(src.into(), (Arc::new(dest_fn), opts)); | ||
| } | ||
|
|
||
| /// Allows to configure how to deal with types/wires that were [Copyable] | ||
|
|
@@ -445,6 +495,26 @@ impl ReplaceTypes { | |
| self.regions = Some(regions.into_iter().collect()); | ||
| } | ||
|
|
||
| fn change_subtree( | ||
| &self, | ||
| hugr: &mut impl HugrMut<Node = Node>, | ||
| root: Node, | ||
| linearize_unchanged_ops: bool, | ||
| ) -> Result<bool, ReplaceTypesError> { | ||
| let mut changed = false; | ||
| for n in hugr.descendants(root).collect::<Vec<_>>() { | ||
| if self.change_node(hugr, n)? { | ||
| changed = true; | ||
| } else if !linearize_unchanged_ops { | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This should fix the "bug" in the other comment but tests are now breaking! |
||
| continue; | ||
| } | ||
| if n != hugr.entrypoint() { | ||
| self.linearize_outputs(hugr, n)?; | ||
| } | ||
| } | ||
| Ok(changed) | ||
| } | ||
|
|
||
| fn change_node( | ||
| &self, | ||
| hugr: &mut impl HugrMut<Node = Node>, | ||
|
|
@@ -527,12 +597,9 @@ impl ReplaceTypes { | |
| replacement | ||
| .replace(hugr, n) | ||
| .map_err(|e| ReplaceTypesError::AddTemplateError(n, Box::new(e)))?; | ||
| if opts.linearize { | ||
| for d in hugr.descendants(n).collect::<Vec<_>>() { | ||
| if d != n { | ||
| self.linearize_outputs(hugr, d)?; | ||
| } | ||
| } | ||
| if opts.process_recursive { | ||
| self.change_subtree(hugr, n, opts.linearize_unchanged)?; | ||
| // change_subtree does not linearize it's root, but that's done by our caller | ||
| } | ||
| true | ||
| } else { | ||
|
|
@@ -620,12 +687,7 @@ impl<H: HugrMut<Node = Node>> ComposablePass<H> for ReplaceTypes { | |
| }; | ||
| let mut changed = false; | ||
| for region_root in regions { | ||
| for n in hugr.descendants(*region_root).collect::<Vec<_>>() { | ||
| changed |= self.change_node(hugr, n)?; | ||
| if n != hugr.entrypoint() && changed { | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Note this is clearly wrong - it means linearization applies to the first node that changes and all nodes coming after that (!) |
||
| self.linearize_outputs(hugr, n)?; | ||
| } | ||
| } | ||
| changed |= self.change_subtree(hugr, *region_root, false)?; | ||
| } | ||
| Ok(changed) | ||
| } | ||
|
|
@@ -677,6 +739,7 @@ impl From<&OpDef> for ParametricOp { | |
| mod test { | ||
| use std::sync::Arc; | ||
|
|
||
| use crate::replace_types::ReplacementOptions; | ||
| use crate::replace_types::handlers::generic_array_const; | ||
| use hugr_core::builder::{ | ||
| BuildError, Container, DFGBuilder, Dataflow, DataflowHugr, DataflowSubContainer, | ||
|
|
@@ -686,15 +749,17 @@ mod test { | |
| use hugr_core::extension::prelude::{ | ||
| ConstUsize, UnwrapBuilder, bool_t, option_type, qb_t, usize_t, | ||
| }; | ||
| use hugr_core::extension::simple_op::MakeOpDef; | ||
| use hugr_core::extension::{TypeDefBound, Version, simple_op::MakeExtensionOp}; | ||
| use hugr_core::hugr::hugrmut::HugrMut; | ||
| use hugr_core::hugr::{IdentList, ValidationError}; | ||
| use hugr_core::ops::constant::CustomConst; | ||
| use hugr_core::ops::constant::OpaqueValue; | ||
| use hugr_core::ops::{ExtensionOp, OpTrait, OpType, Tag, Value}; | ||
| use hugr_core::ops::constant::{CustomConst, OpaqueValue}; | ||
| use hugr_core::ops::{self, ExtensionOp, OpTrait, OpType, Tag, Value}; | ||
| use hugr_core::std_extensions::arithmetic::conversions::ConvertOpDef; | ||
| use hugr_core::std_extensions::arithmetic::int_types::{ConstInt, INT_TYPES}; | ||
| use hugr_core::std_extensions::collections::array::{Array, ArrayKind, GenericArrayValue}; | ||
| use hugr_core::std_extensions::collections::array::{ | ||
| self, Array, ArrayKind, ArrayOpDef, GenericArrayValue, array_type, array_type_def, | ||
| }; | ||
| use hugr_core::std_extensions::collections::list::{ | ||
| ListOp, ListValue, list_type, list_type_def, | ||
| }; | ||
|
|
@@ -703,7 +768,7 @@ mod test { | |
| }; | ||
|
|
||
| use hugr_core::types::{ | ||
| EdgeKind, PolyFuncType, Signature, SumType, Type, TypeArg, TypeBound, TypeRow, | ||
| EdgeKind, PolyFuncType, Signature, SumType, Term, Type, TypeArg, TypeBound, TypeRow, | ||
| }; | ||
| use hugr_core::{Direction, Extension, HugrView, Port, type_row}; | ||
| use itertools::Itertools; | ||
|
|
@@ -1255,4 +1320,101 @@ mod test { | |
| }) | ||
| ); | ||
| } | ||
|
|
||
| #[test] | ||
| fn compositionality() { | ||
| let ext = ext(); | ||
| let mut lowerer = lowerer(&ext); | ||
| // Replace std Array's with 64 elements with PackedVec's | ||
| let ext2 = ext.clone(); | ||
| lowerer.replace_parametrized_type_opts( | ||
| array_type_def(), | ||
| move |args| { | ||
| let [sz, ty] = args else { | ||
| panic!("Expected two args to array") | ||
| }; | ||
| (sz == &Term::BoundedNat(64)).then_some( | ||
| ext2.get_type(PACKED_VEC) | ||
| .unwrap() | ||
| .instantiate([ty.clone()]) | ||
| .unwrap() | ||
| .into(), | ||
| ) | ||
| }, | ||
| ReplacementOptions::default().with_recursive_replacement(true), | ||
| ); | ||
|
|
||
| // Replacement of `get` is complex because we need to wrap result of read into a Some | ||
| let ext = ext.clone(); | ||
| lowerer.replace_parametrized_op_with( | ||
| array::EXTENSION | ||
| .get_op(ArrayOpDef::get.opdef_id().as_str()) | ||
| .unwrap() | ||
| .as_ref(), | ||
| move |args| { | ||
| let [sz, Term::Runtime(ty)] = args else { | ||
| panic!("Expected two args to array-get") | ||
| }; | ||
| if sz != &Term::BoundedNat(64) { | ||
| return None; | ||
| } | ||
| let pv = ext | ||
| .get_type(PACKED_VEC) | ||
| .unwrap() | ||
| .instantiate([ty.clone().into()]) | ||
| .unwrap(); | ||
|
|
||
| let mut dfb = DFGBuilder::new(Signature::new( | ||
| vec![pv.clone().into(), usize_t()], | ||
| vec![option_type(ty.clone()).into(), pv.into()], | ||
| )) | ||
| .unwrap(); | ||
| let [pvec, idx] = dfb.input_wires_arr(); | ||
| let [idx] = dfb | ||
| .add_dataflow_op(ConvertOpDef::ifromusize.without_log_width(), [idx]) | ||
| .unwrap() | ||
| .outputs_arr(); | ||
| let [elem] = dfb | ||
| .add_dataflow_op(read_op(&ext, ty.clone()), [pvec, idx]) | ||
| .unwrap() | ||
| .outputs_arr(); | ||
| let [wrapped_elem] = dfb | ||
| .add_dataflow_op( | ||
| ops::Tag::new(1, vec![type_row![], ty.clone().into()]), | ||
| [elem], | ||
| ) | ||
| .unwrap() | ||
| .outputs_arr(); | ||
| Some(NodeTemplate::CompoundOp(Box::new( | ||
| dfb.finish_hugr_with_outputs([wrapped_elem, pvec]).unwrap(), | ||
| ))) | ||
| }, | ||
| ReplacementOptions::default().with_recursive_replacement(true), | ||
| ); | ||
|
|
||
| // Arrays of 64 bools should thus be transformed into PackedVec<bool> and then to int64s | ||
| // Arrays of 64 non-bools should thus become PackedVec<T> and then back to ValueArray<64, T> | ||
| let a64 = |t| array_type(64, t); | ||
| let opt = |t| Type::from(option_type(t)); | ||
| let mut dfb = DFGBuilder::new(Signature::new( | ||
| vec![a64(bool_t()), a64(usize_t())], | ||
| vec![opt(bool_t()), a64(bool_t()), opt(usize_t()), a64(usize_t())], | ||
| )) | ||
| .unwrap(); | ||
| let [bools, usizes] = dfb.input_wires_arr(); | ||
| let idx = dfb.add_load_value(ConstUsize::new(5)); | ||
| let [b, bools] = dfb | ||
| .add_dataflow_op(ArrayOpDef::get.to_concrete(bool_t(), 64), [bools, idx]) | ||
| .unwrap() | ||
| .outputs_arr(); | ||
| let [u, usizes] = dfb | ||
| .add_dataflow_op(ArrayOpDef::get.to_concrete(usize_t(), 64), [usizes, idx]) | ||
| .unwrap() | ||
| .outputs_arr(); | ||
| let mut h = dfb.finish_hugr_with_outputs([b, bools, u, usizes]).unwrap(); | ||
|
|
||
| lowerer.run(&mut h).unwrap(); | ||
|
|
||
| h.validate().unwrap(); | ||
| } | ||
| } | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Or do we, see description