-
Notifications
You must be signed in to change notification settings - Fork 13
feat: ReplaceTypes: recursively replace on types too #2442
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Draft
acl-cqc
wants to merge
6
commits into
main
Choose a base branch
from
acl/replace_recursive2
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Draft
Changes from 3 commits
Commits
Show all changes
6 commits
Select commit
Hold shift + click to select a range
1c9857f
ReplaceTypes: option to recurse on replacement, including for types
acl-cqc 52e31f2
formatting
acl-cqc 082eddb
clippy/reword comment
acl-cqc 9392eca
apply_custom: use mutation
acl-cqc 6799dc0
with_recursive also sets linearize, TODO update docs
acl-cqc 9a0bec2
WIP: try to make linearization behaviour consistent, but breaks test
acl-cqc File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -171,18 +171,37 @@ fn call<H: HugrView<Node = Node>>( | |
| Ok(Call::try_new(func_sig, type_args)?) | ||
| } | ||
|
|
||
| /// Options for how the replacement for an op is processed. | ||
| // TODO also need to apply to replacement consts: | ||
| /// Options for how a replacement (op or type) is processed. | ||
| /// | ||
| /// May be specified by [ReplaceTypes::replace_op_with] and [ReplaceTypes::replace_parametrized_op_with]. | ||
| /// Otherwise (the default), replacements are inserted as is (without further processing). | ||
| /// May be specified by | ||
| /// [ReplaceTypes::replace_op_with], [ReplaceTypes::replace_parametrized_op_with], | ||
| /// [ReplaceTypes::replace_type_opts] or [ReplaceTypes::replace_parametrized_type_opts]. | ||
| /// Otherwise (the default), replacements are inserted as given (without further processing). | ||
| // TODO would be good to migrate to default being process_recursive: true | ||
| #[derive(Clone, Default, PartialEq, Eq)] // More derives might inhibit future extension | ||
| pub struct ReplacementOptions { | ||
| process_recursive: bool, | ||
| linearize: bool, | ||
| } | ||
|
|
||
| impl ReplacementOptions { | ||
| /// Specifies that all operations within the replacement should have their | ||
| /// Specifies that the replacement (op or type) should be processed by the same | ||
| /// [ReplaceTypes]. This increases compositionality (in that replacements for | ||
| /// other types/ops do not need to have already been applied to the RHS), but | ||
| /// would lead to an infinite loop if e.g. changing an op for a DFG containing | ||
| /// an instance of the same op. | ||
| pub fn with_recursive_replacement(mut self, rec: bool) -> Self { | ||
| self.process_recursive = rec; | ||
| self | ||
| } | ||
|
|
||
| /// Specifies that all nodes within the replacement should have their | ||
| /// output ports linearized. | ||
| /// | ||
| /// * If [Self::with_recursive_replacement] has been set, this applies linearization | ||
| /// after recursive processing. | ||
| /// * Otherwise, just applies linearization (to all nodes) without changing any ops. | ||
| pub fn with_linearization(mut self, lin: bool) -> Self { | ||
| self.linearize = lin; | ||
| self | ||
|
|
@@ -218,8 +237,9 @@ impl ReplacementOptions { | |
| /// [monomorphization]: super::monomorphize() | ||
| #[derive(Clone)] | ||
| pub struct ReplaceTypes { | ||
| type_map: HashMap<CustomType, Type>, | ||
| param_types: HashMap<ParametricType, Arc<dyn Fn(&[TypeArg]) -> Option<Type>>>, | ||
| type_map: HashMap<CustomType, (Type, ReplacementOptions)>, | ||
| param_types: | ||
| HashMap<ParametricType, (Arc<dyn Fn(&[TypeArg]) -> Option<Type>>, ReplacementOptions)>, | ||
| linearize: DelegatingLinearizer, | ||
| op_map: HashMap<OpHashWrapper, (NodeTemplate, ReplacementOptions)>, | ||
| param_ops: HashMap< | ||
|
|
@@ -255,19 +275,26 @@ impl TypeTransformer for ReplaceTypes { | |
| type Err = ReplaceTypesError; | ||
|
|
||
| fn apply_custom(&self, ct: &CustomType) -> Result<Option<Type>, Self::Err> { | ||
| Ok(if let Some(res) = self.type_map.get(ct) { | ||
| let next = if let Some(res) = self.type_map.get(ct) { | ||
| Some(res.clone()) | ||
| } else if let Some(dest_fn) = self.param_types.get(&ct.into()) { | ||
| } else if let Some((dest_fn, opts)) = self.param_types.get(&ct.into()) { | ||
| // `ct` has not had args transformed | ||
| let mut nargs = ct.args().to_vec(); | ||
| // We don't care if `nargs` are changed, we're just calling `dest_fn` | ||
| nargs | ||
| .iter_mut() | ||
| .try_for_each(|ta| ta.transform(self).map(|_ch| ()))?; | ||
| dest_fn(&nargs) | ||
| dest_fn(&nargs).map(|ty| (ty, opts.clone())) | ||
| } else { | ||
| None | ||
| }) | ||
| }; | ||
| let Some((mut ty, opts)) = next else { | ||
| return Ok(None); | ||
| }; | ||
| if opts.process_recursive { | ||
| ty.transform(self)?; | ||
| } | ||
| Ok(Some(ty)) | ||
| } | ||
| } | ||
|
|
||
|
|
@@ -304,6 +331,14 @@ impl ReplaceTypes { | |
| } | ||
|
|
||
| /// Configures this instance to replace occurrences of type `src` with `dest`. | ||
| /// Equivalent to [Self::replace_type_opts] with [ReplacementOptions::default()] | ||
| pub fn replace_type(&mut self, src: CustomType, dest: Type) { | ||
| self.replace_type_opts(src, dest, ReplacementOptions::default()) | ||
| } | ||
|
|
||
| /// Configures this instance to replace occurrences of type `src` with `dest`, | ||
| /// according to the given `ReplacementOptions`. | ||
| /// | ||
| /// Note that if `src` is an instance of a *parametrized* [`TypeDef`], this takes | ||
| /// precedence over [`Self::replace_parametrized_type`] where the `src`s overlap. Thus, this | ||
| /// should only be used on already-*[monomorphize](super::monomorphize())d* Hugrs, as | ||
|
|
@@ -316,14 +351,28 @@ impl ReplaceTypes { | |
| /// Note that if `src` is Copyable and `dest` is Linear, then (besides linearity violations) | ||
| /// [`SignatureError`] will be raised if this leads to an impossible type e.g. ArrayOfCopyables(src). | ||
| /// (This can be overridden by an additional [`Self::replace_type`].) | ||
| pub fn replace_type(&mut self, src: CustomType, dest: Type) { | ||
| // We could check that 'dest' is copyable or 'src' is linear, but since we can't | ||
| // check that for parametrized types, we'll be consistent and not check here either. | ||
| self.type_map.insert(src, dest); | ||
| pub fn replace_type_opts(&mut self, src: CustomType, dest: Type, opts: ReplacementOptions) { | ||
| // We could check that 'dest' is copyable, 'src' is linear, or relevant copy and | ||
| // discard functions are registered with the linearizer; but since we can't check | ||
| // that for parametrized types, we'll be consistent and not check here either. | ||
| self.type_map.insert(src, (dest, opts)); | ||
| } | ||
|
|
||
| /// Configures this instance to change occurrences of a parametrized type `src` | ||
| /// via a callback that builds the replacement type given the [`TypeArg`]s. | ||
| /// Equivalent to [Self::replace_parametrized_type_opts] with [ReplacementOptions::default]. | ||
| pub fn replace_parametrized_type( | ||
| &mut self, | ||
| src: &TypeDef, | ||
| dest_fn: impl Fn(&[TypeArg]) -> Option<Type> + 'static, | ||
| ) { | ||
| self.replace_parametrized_type_opts(src, dest_fn, ReplacementOptions::default()) | ||
| } | ||
|
|
||
| /// Configures this instance to change occurrences of a parametrized type `src` | ||
| /// via a callback that builds the replacement type given the [`TypeArg`]s, | ||
| /// and using the given [ReplacementOptions]. | ||
| /// | ||
| /// Note that the `TypeArgs` will already have been updated (e.g. they may not | ||
| /// fit the bounds of the original type). The callback may return `None` to indicate | ||
| /// no change (in which case the supplied `TypeArgs` will be given to `src`). | ||
|
|
@@ -332,10 +381,11 @@ impl ReplaceTypes { | |
| /// [`Self::replace_consts_parametrized`] (or [`Self::replace_consts`]) as the | ||
| /// [`LoadConstant`]s will be reparametrized (and this will break the edge from [Const] to | ||
| /// [`LoadConstant`]). | ||
| pub fn replace_parametrized_type( | ||
| pub fn replace_parametrized_type_opts( | ||
| &mut self, | ||
| src: &TypeDef, | ||
| dest_fn: impl Fn(&[TypeArg]) -> Option<Type> + 'static, | ||
| opts: ReplacementOptions, | ||
| ) { | ||
| // No way to check that dest_fn never produces a linear type. | ||
| // We could require copy/discard-generators if src is Copyable, or *might be* | ||
|
|
@@ -346,7 +396,8 @@ impl ReplaceTypes { | |
| // dest_fn: impl Fn(&TypeArg) -> (Type, | ||
| // Fn(&Linearizer) -> NodeTemplate, // copy | ||
| // Fn(&Linearizer) -> NodeTemplate)` // discard | ||
| self.param_types.insert(src.into(), Arc::new(dest_fn)); | ||
| self.param_types | ||
| .insert(src.into(), (Arc::new(dest_fn), opts)); | ||
| } | ||
|
|
||
| /// Allows to configure how to deal with types/wires that were [Copyable] | ||
|
|
@@ -445,6 +496,21 @@ impl ReplaceTypes { | |
| self.regions = Some(regions.into_iter().collect()); | ||
| } | ||
|
|
||
| fn change_subtree( | ||
| &self, | ||
| hugr: &mut impl HugrMut<Node = Node>, | ||
| root: Node, | ||
| ) -> Result<bool, ReplaceTypesError> { | ||
| let mut changed = false; | ||
| for n in hugr.descendants(root).collect::<Vec<_>>() { | ||
| changed |= self.change_node(hugr, n)?; | ||
| if n != hugr.entrypoint() { | ||
| self.linearize_outputs(hugr, n)?; | ||
| } | ||
| } | ||
| Ok(changed) | ||
| } | ||
|
|
||
| fn change_node( | ||
| &self, | ||
| hugr: &mut impl HugrMut<Node = Node>, | ||
|
|
@@ -527,6 +593,9 @@ impl ReplaceTypes { | |
| replacement | ||
| .replace(hugr, n) | ||
| .map_err(|e| ReplaceTypesError::AddTemplateError(n, Box::new(e)))?; | ||
| if opts.process_recursive { | ||
| self.change_subtree(hugr, n)?; | ||
| } | ||
| if opts.linearize { | ||
| for d in hugr.descendants(n).collect::<Vec<_>>() { | ||
| if d != n { | ||
|
|
@@ -620,12 +689,7 @@ impl<H: HugrMut<Node = Node>> ComposablePass<H> for ReplaceTypes { | |
| }; | ||
| let mut changed = false; | ||
| for region_root in regions { | ||
| for n in hugr.descendants(*region_root).collect::<Vec<_>>() { | ||
| changed |= self.change_node(hugr, n)?; | ||
| if n != hugr.entrypoint() && changed { | ||
|
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Note this is clearly wrong - it means linearization applies to the first node that changes and all nodes coming after that (!) |
||
| self.linearize_outputs(hugr, n)?; | ||
| } | ||
| } | ||
| changed |= self.change_subtree(hugr, *region_root)?; | ||
| } | ||
| Ok(changed) | ||
| } | ||
|
|
@@ -677,6 +741,7 @@ impl From<&OpDef> for ParametricOp { | |
| mod test { | ||
| use std::sync::Arc; | ||
|
|
||
| use crate::replace_types::ReplacementOptions; | ||
| use crate::replace_types::handlers::generic_array_const; | ||
| use hugr_core::builder::{ | ||
| BuildError, Container, DFGBuilder, Dataflow, DataflowHugr, DataflowSubContainer, | ||
|
|
@@ -686,15 +751,17 @@ mod test { | |
| use hugr_core::extension::prelude::{ | ||
| ConstUsize, UnwrapBuilder, bool_t, option_type, qb_t, usize_t, | ||
| }; | ||
| use hugr_core::extension::simple_op::MakeOpDef; | ||
| use hugr_core::extension::{TypeDefBound, Version, simple_op::MakeExtensionOp}; | ||
| use hugr_core::hugr::hugrmut::HugrMut; | ||
| use hugr_core::hugr::{IdentList, ValidationError}; | ||
| use hugr_core::ops::constant::CustomConst; | ||
| use hugr_core::ops::constant::OpaqueValue; | ||
| use hugr_core::ops::{ExtensionOp, OpTrait, OpType, Tag, Value}; | ||
| use hugr_core::ops::constant::{CustomConst, OpaqueValue}; | ||
| use hugr_core::ops::{self, ExtensionOp, OpTrait, OpType, Tag, Value}; | ||
| use hugr_core::std_extensions::arithmetic::conversions::ConvertOpDef; | ||
| use hugr_core::std_extensions::arithmetic::int_types::{ConstInt, INT_TYPES}; | ||
| use hugr_core::std_extensions::collections::array::{Array, ArrayKind, GenericArrayValue}; | ||
| use hugr_core::std_extensions::collections::array::{ | ||
| self, Array, ArrayKind, ArrayOpDef, GenericArrayValue, array_type, array_type_def, | ||
| }; | ||
| use hugr_core::std_extensions::collections::list::{ | ||
| ListOp, ListValue, list_type, list_type_def, | ||
| }; | ||
|
|
@@ -703,7 +770,7 @@ mod test { | |
| }; | ||
|
|
||
| use hugr_core::types::{ | ||
| EdgeKind, PolyFuncType, Signature, SumType, Type, TypeArg, TypeBound, TypeRow, | ||
| EdgeKind, PolyFuncType, Signature, SumType, Term, Type, TypeArg, TypeBound, TypeRow, | ||
| }; | ||
| use hugr_core::{Direction, Extension, HugrView, Port, type_row}; | ||
| use itertools::Itertools; | ||
|
|
@@ -1255,4 +1322,101 @@ mod test { | |
| }) | ||
| ); | ||
| } | ||
|
|
||
| #[test] | ||
| fn compositionality() { | ||
| let ext = ext(); | ||
| let mut lowerer = lowerer(&ext); | ||
| // Replace std Array's with 64 elements with PackedVec's | ||
| let ext2 = ext.clone(); | ||
| lowerer.replace_parametrized_type_opts( | ||
| array_type_def(), | ||
| move |args| { | ||
| let [sz, ty] = args else { | ||
| panic!("Expected two args to array") | ||
| }; | ||
| (sz == &Term::BoundedNat(64)).then_some( | ||
| ext2.get_type(PACKED_VEC) | ||
| .unwrap() | ||
| .instantiate([ty.clone()]) | ||
| .unwrap() | ||
| .into(), | ||
| ) | ||
| }, | ||
| ReplacementOptions::default().with_recursive_replacement(true), | ||
| ); | ||
|
|
||
| // Replacement of `get` is complex because we need to wrap result of read into a Some | ||
| let ext = ext.clone(); | ||
| lowerer.replace_parametrized_op_with( | ||
| array::EXTENSION | ||
| .get_op(ArrayOpDef::get.opdef_id().as_str()) | ||
| .unwrap() | ||
| .as_ref(), | ||
| move |args| { | ||
| let [sz, Term::Runtime(ty)] = args else { | ||
| panic!("Expected two args to array-get") | ||
| }; | ||
| if sz != &Term::BoundedNat(64) { | ||
| return None; | ||
| } | ||
| let pv = ext | ||
| .get_type(PACKED_VEC) | ||
| .unwrap() | ||
| .instantiate([ty.clone().into()]) | ||
| .unwrap(); | ||
|
|
||
| let mut dfb = DFGBuilder::new(Signature::new( | ||
| vec![pv.clone().into(), usize_t()], | ||
| vec![option_type(ty.clone()).into(), pv.into()], | ||
| )) | ||
| .unwrap(); | ||
| let [pvec, idx] = dfb.input_wires_arr(); | ||
| let [idx] = dfb | ||
| .add_dataflow_op(ConvertOpDef::ifromusize.without_log_width(), [idx]) | ||
| .unwrap() | ||
| .outputs_arr(); | ||
| let [elem] = dfb | ||
| .add_dataflow_op(read_op(&ext, ty.clone()), [pvec, idx]) | ||
| .unwrap() | ||
| .outputs_arr(); | ||
| let [wrapped_elem] = dfb | ||
| .add_dataflow_op( | ||
| ops::Tag::new(1, vec![type_row![], ty.clone().into()]), | ||
| [elem], | ||
| ) | ||
| .unwrap() | ||
| .outputs_arr(); | ||
| Some(NodeTemplate::CompoundOp(Box::new( | ||
| dfb.finish_hugr_with_outputs([wrapped_elem, pvec]).unwrap(), | ||
| ))) | ||
| }, | ||
| ReplacementOptions::default().with_recursive_replacement(true), | ||
| ); | ||
|
|
||
| // Arrays of 64 bools should thus be transformed into PackedVec<bool> and then to int64s | ||
| // Arrays of 64 non-bools should thus become PackedVec<T> and then back to ValueArray<64, T> | ||
| let a64 = |t| array_type(64, t); | ||
| let opt = |t| Type::from(option_type(t)); | ||
| let mut dfb = DFGBuilder::new(Signature::new( | ||
| vec![a64(bool_t()), a64(usize_t())], | ||
| vec![opt(bool_t()), a64(bool_t()), opt(usize_t()), a64(usize_t())], | ||
| )) | ||
| .unwrap(); | ||
| let [bools, usizes] = dfb.input_wires_arr(); | ||
| let idx = dfb.add_load_value(ConstUsize::new(5)); | ||
| let [b, bools] = dfb | ||
| .add_dataflow_op(ArrayOpDef::get.to_concrete(bool_t(), 64), [bools, idx]) | ||
| .unwrap() | ||
| .outputs_arr(); | ||
| let [u, usizes] = dfb | ||
| .add_dataflow_op(ArrayOpDef::get.to_concrete(usize_t(), 64), [usizes, idx]) | ||
| .unwrap() | ||
| .outputs_arr(); | ||
| let mut h = dfb.finish_hugr_with_outputs([b, bools, u, usizes]).unwrap(); | ||
|
|
||
| lowerer.run(&mut h).unwrap(); | ||
|
|
||
| h.validate().unwrap(); | ||
| } | ||
| } | ||
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Or do we, see description