Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
238 changes: 200 additions & 38 deletions hugr-passes/src/replace_types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -171,20 +171,39 @@ fn call<H: HugrView<Node = Node>>(
Ok(Call::try_new(func_sig, type_args)?)
}

/// Options for how the replacement for an op is processed.
// TODO also need to apply to replacement consts:
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Or do we, see description

/// Options for how a replacement (op or type) is processed.
///
/// May be specified by [ReplaceTypes::replace_op_with] and [ReplaceTypes::replace_parametrized_op_with].
/// Otherwise (the default), replacements are inserted as is (without further processing).
/// May be specified by
/// [ReplaceTypes::replace_op_with], [ReplaceTypes::replace_parametrized_op_with],
/// [ReplaceTypes::replace_type_opts] or [ReplaceTypes::replace_parametrized_type_opts].
/// Otherwise (the default), replacements are inserted as given (without further processing).
// TODO would be good to migrate to default being process_recursive: true
#[derive(Clone, Default, PartialEq, Eq)] // More derives might inhibit future extension
pub struct ReplacementOptions {
linearize: bool,
process_recursive: bool,
linearize_unchanged: bool,
}

impl ReplacementOptions {
/// Specifies that all operations within the replacement should have their
/// Specifies whether the replacement (op or type) should be processed by the same
/// [ReplaceTypes], including linearization of any changed ops. This increases
/// compositionality (in that replacements for other types/ops do not need to have
/// already been applied to the RHS), but can lead to infinite looping if e.g. the
/// replacement for an op is a DFG containing an instance of the same op.
pub fn with_recursive_replacement(mut self, rec: bool) -> Self {
self.process_recursive = rec;
self
}

/// Specifies whether *all* nodes within the replacement should have their
/// output ports linearized.
///
/// * If [Self::with_recursive_replacement] has been set, this causes linearization
/// to apply even to unchanged ops.
/// * Otherwise, just applies linearization (to all nodes) without changing any ops.
pub fn with_linearization(mut self, lin: bool) -> Self {
self.linearize = lin;
self.linearize_unchanged = lin;
self
}
}
Expand Down Expand Up @@ -218,8 +237,9 @@ impl ReplacementOptions {
/// [monomorphization]: super::monomorphize()
#[derive(Clone)]
pub struct ReplaceTypes {
type_map: HashMap<CustomType, Type>,
param_types: HashMap<ParametricType, Arc<dyn Fn(&[TypeArg]) -> Option<Type>>>,
type_map: HashMap<CustomType, (Type, ReplacementOptions)>,
param_types:
HashMap<ParametricType, (Arc<dyn Fn(&[TypeArg]) -> Option<Type>>, ReplacementOptions)>,
linearize: DelegatingLinearizer,
op_map: HashMap<OpHashWrapper, (NodeTemplate, ReplacementOptions)>,
param_ops: HashMap<
Expand Down Expand Up @@ -255,19 +275,25 @@ impl TypeTransformer for ReplaceTypes {
type Err = ReplaceTypesError;

fn apply_custom(&self, ct: &CustomType) -> Result<Option<Type>, Self::Err> {
Ok(if let Some(res) = self.type_map.get(ct) {
Some(res.clone())
} else if let Some(dest_fn) = self.param_types.get(&ct.into()) {
let mut ty = None;
if let Some(res) = self.type_map.get(ct) {
ty = Some(res.clone())
} else if let Some((dest_fn, opts)) = self.param_types.get(&ct.into()) {
// `ct` has not had args transformed
let mut nargs = ct.args().to_vec();
// We don't care if `nargs` are changed, we're just calling `dest_fn`
nargs
.iter_mut()
.try_for_each(|ta| ta.transform(self).map(|_ch| ()))?;
dest_fn(&nargs)
} else {
None
})
ty = dest_fn(&nargs).map(|ty| (ty, opts.clone()))
};
let Some((mut ty, opts)) = ty else {
return Ok(None);
};
if opts.process_recursive {
ty.transform(self)?;
}
Ok(Some(ty))
}
}

Expand Down Expand Up @@ -304,6 +330,14 @@ impl ReplaceTypes {
}

/// Configures this instance to replace occurrences of type `src` with `dest`.
/// Equivalent to [Self::replace_type_opts] with [ReplacementOptions::default()]
pub fn replace_type(&mut self, src: CustomType, dest: Type) {
self.replace_type_opts(src, dest, ReplacementOptions::default())
}

/// Configures this instance to replace occurrences of type `src` with `dest`,
/// according to the given `ReplacementOptions`.
///
/// Note that if `src` is an instance of a *parametrized* [`TypeDef`], this takes
/// precedence over [`Self::replace_parametrized_type`] where the `src`s overlap. Thus, this
/// should only be used on already-*[monomorphize](super::monomorphize())d* Hugrs, as
Expand All @@ -316,14 +350,28 @@ impl ReplaceTypes {
/// Note that if `src` is Copyable and `dest` is Linear, then (besides linearity violations)
/// [`SignatureError`] will be raised if this leads to an impossible type e.g. ArrayOfCopyables(src).
/// (This can be overridden by an additional [`Self::replace_type`].)
pub fn replace_type(&mut self, src: CustomType, dest: Type) {
// We could check that 'dest' is copyable or 'src' is linear, but since we can't
// check that for parametrized types, we'll be consistent and not check here either.
self.type_map.insert(src, dest);
pub fn replace_type_opts(&mut self, src: CustomType, dest: Type, opts: ReplacementOptions) {
// We could check that 'dest' is copyable, 'src' is linear, or relevant copy and
// discard functions are registered with the linearizer; but since we can't check
// that for parametrized types, we'll be consistent and not check here either.
self.type_map.insert(src, (dest, opts));
}

/// Configures this instance to change occurrences of a parametrized type `src`
/// via a callback that builds the replacement type given the [`TypeArg`]s.
/// Equivalent to [Self::replace_parametrized_type_opts] with [ReplacementOptions::default].
pub fn replace_parametrized_type(
&mut self,
src: &TypeDef,
dest_fn: impl Fn(&[TypeArg]) -> Option<Type> + 'static,
) {
self.replace_parametrized_type_opts(src, dest_fn, ReplacementOptions::default())
}

/// Configures this instance to change occurrences of a parametrized type `src`
/// via a callback that builds the replacement type given the [`TypeArg`]s,
/// and using the given [ReplacementOptions].
///
/// Note that the `TypeArgs` will already have been updated (e.g. they may not
/// fit the bounds of the original type). The callback may return `None` to indicate
/// no change (in which case the supplied `TypeArgs` will be given to `src`).
Expand All @@ -332,10 +380,11 @@ impl ReplaceTypes {
/// [`Self::replace_consts_parametrized`] (or [`Self::replace_consts`]) as the
/// [`LoadConstant`]s will be reparametrized (and this will break the edge from [Const] to
/// [`LoadConstant`]).
pub fn replace_parametrized_type(
pub fn replace_parametrized_type_opts(
&mut self,
src: &TypeDef,
dest_fn: impl Fn(&[TypeArg]) -> Option<Type> + 'static,
opts: ReplacementOptions,
) {
// No way to check that dest_fn never produces a linear type.
// We could require copy/discard-generators if src is Copyable, or *might be*
Expand All @@ -346,7 +395,8 @@ impl ReplaceTypes {
// dest_fn: impl Fn(&TypeArg) -> (Type,
// Fn(&Linearizer) -> NodeTemplate, // copy
// Fn(&Linearizer) -> NodeTemplate)` // discard
self.param_types.insert(src.into(), Arc::new(dest_fn));
self.param_types
.insert(src.into(), (Arc::new(dest_fn), opts));
}

/// Allows to configure how to deal with types/wires that were [Copyable]
Expand Down Expand Up @@ -445,6 +495,26 @@ impl ReplaceTypes {
self.regions = Some(regions.into_iter().collect());
}

fn change_subtree(
&self,
hugr: &mut impl HugrMut<Node = Node>,
root: Node,
linearize_unchanged_ops: bool,
) -> Result<bool, ReplaceTypesError> {
let mut changed = false;
for n in hugr.descendants(root).collect::<Vec<_>>() {
if self.change_node(hugr, n)? {
changed = true;
} else if !linearize_unchanged_ops {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should fix the "bug" in the other comment but tests are now breaking!

continue;
}
if n != hugr.entrypoint() {
self.linearize_outputs(hugr, n)?;
}
}
Ok(changed)
}

fn change_node(
&self,
hugr: &mut impl HugrMut<Node = Node>,
Expand Down Expand Up @@ -527,12 +597,9 @@ impl ReplaceTypes {
replacement
.replace(hugr, n)
.map_err(|e| ReplaceTypesError::AddTemplateError(n, Box::new(e)))?;
if opts.linearize {
for d in hugr.descendants(n).collect::<Vec<_>>() {
if d != n {
self.linearize_outputs(hugr, d)?;
}
}
if opts.process_recursive {
self.change_subtree(hugr, n, opts.linearize_unchanged)?;
// change_subtree does not linearize it's root, but that's done by our caller
}
true
} else {
Expand Down Expand Up @@ -620,12 +687,7 @@ impl<H: HugrMut<Node = Node>> ComposablePass<H> for ReplaceTypes {
};
let mut changed = false;
for region_root in regions {
for n in hugr.descendants(*region_root).collect::<Vec<_>>() {
changed |= self.change_node(hugr, n)?;
if n != hugr.entrypoint() && changed {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Note this is clearly wrong - it means linearization applies to the first node that changes and all nodes coming after that (!)

self.linearize_outputs(hugr, n)?;
}
}
changed |= self.change_subtree(hugr, *region_root, false)?;
}
Ok(changed)
}
Expand Down Expand Up @@ -677,6 +739,7 @@ impl From<&OpDef> for ParametricOp {
mod test {
use std::sync::Arc;

use crate::replace_types::ReplacementOptions;
use crate::replace_types::handlers::generic_array_const;
use hugr_core::builder::{
BuildError, Container, DFGBuilder, Dataflow, DataflowHugr, DataflowSubContainer,
Expand All @@ -686,15 +749,17 @@ mod test {
use hugr_core::extension::prelude::{
ConstUsize, UnwrapBuilder, bool_t, option_type, qb_t, usize_t,
};
use hugr_core::extension::simple_op::MakeOpDef;
use hugr_core::extension::{TypeDefBound, Version, simple_op::MakeExtensionOp};
use hugr_core::hugr::hugrmut::HugrMut;
use hugr_core::hugr::{IdentList, ValidationError};
use hugr_core::ops::constant::CustomConst;
use hugr_core::ops::constant::OpaqueValue;
use hugr_core::ops::{ExtensionOp, OpTrait, OpType, Tag, Value};
use hugr_core::ops::constant::{CustomConst, OpaqueValue};
use hugr_core::ops::{self, ExtensionOp, OpTrait, OpType, Tag, Value};
use hugr_core::std_extensions::arithmetic::conversions::ConvertOpDef;
use hugr_core::std_extensions::arithmetic::int_types::{ConstInt, INT_TYPES};
use hugr_core::std_extensions::collections::array::{Array, ArrayKind, GenericArrayValue};
use hugr_core::std_extensions::collections::array::{
self, Array, ArrayKind, ArrayOpDef, GenericArrayValue, array_type, array_type_def,
};
use hugr_core::std_extensions::collections::list::{
ListOp, ListValue, list_type, list_type_def,
};
Expand All @@ -703,7 +768,7 @@ mod test {
};

use hugr_core::types::{
EdgeKind, PolyFuncType, Signature, SumType, Type, TypeArg, TypeBound, TypeRow,
EdgeKind, PolyFuncType, Signature, SumType, Term, Type, TypeArg, TypeBound, TypeRow,
};
use hugr_core::{Direction, Extension, HugrView, Port, type_row};
use itertools::Itertools;
Expand Down Expand Up @@ -1255,4 +1320,101 @@ mod test {
})
);
}

#[test]
fn compositionality() {
let ext = ext();
let mut lowerer = lowerer(&ext);
// Replace std Array's with 64 elements with PackedVec's
let ext2 = ext.clone();
lowerer.replace_parametrized_type_opts(
array_type_def(),
move |args| {
let [sz, ty] = args else {
panic!("Expected two args to array")
};
(sz == &Term::BoundedNat(64)).then_some(
ext2.get_type(PACKED_VEC)
.unwrap()
.instantiate([ty.clone()])
.unwrap()
.into(),
)
},
ReplacementOptions::default().with_recursive_replacement(true),
);

// Replacement of `get` is complex because we need to wrap result of read into a Some
let ext = ext.clone();
lowerer.replace_parametrized_op_with(
array::EXTENSION
.get_op(ArrayOpDef::get.opdef_id().as_str())
.unwrap()
.as_ref(),
move |args| {
let [sz, Term::Runtime(ty)] = args else {
panic!("Expected two args to array-get")
};
if sz != &Term::BoundedNat(64) {
return None;
}
let pv = ext
.get_type(PACKED_VEC)
.unwrap()
.instantiate([ty.clone().into()])
.unwrap();

let mut dfb = DFGBuilder::new(Signature::new(
vec![pv.clone().into(), usize_t()],
vec![option_type(ty.clone()).into(), pv.into()],
))
.unwrap();
let [pvec, idx] = dfb.input_wires_arr();
let [idx] = dfb
.add_dataflow_op(ConvertOpDef::ifromusize.without_log_width(), [idx])
.unwrap()
.outputs_arr();
let [elem] = dfb
.add_dataflow_op(read_op(&ext, ty.clone()), [pvec, idx])
.unwrap()
.outputs_arr();
let [wrapped_elem] = dfb
.add_dataflow_op(
ops::Tag::new(1, vec![type_row![], ty.clone().into()]),
[elem],
)
.unwrap()
.outputs_arr();
Some(NodeTemplate::CompoundOp(Box::new(
dfb.finish_hugr_with_outputs([wrapped_elem, pvec]).unwrap(),
)))
},
ReplacementOptions::default().with_recursive_replacement(true),
);

// Arrays of 64 bools should thus be transformed into PackedVec<bool> and then to int64s
// Arrays of 64 non-bools should thus become PackedVec<T> and then back to ValueArray<64, T>
let a64 = |t| array_type(64, t);
let opt = |t| Type::from(option_type(t));
let mut dfb = DFGBuilder::new(Signature::new(
vec![a64(bool_t()), a64(usize_t())],
vec![opt(bool_t()), a64(bool_t()), opt(usize_t()), a64(usize_t())],
))
.unwrap();
let [bools, usizes] = dfb.input_wires_arr();
let idx = dfb.add_load_value(ConstUsize::new(5));
let [b, bools] = dfb
.add_dataflow_op(ArrayOpDef::get.to_concrete(bool_t(), 64), [bools, idx])
.unwrap()
.outputs_arr();
let [u, usizes] = dfb
.add_dataflow_op(ArrayOpDef::get.to_concrete(usize_t(), 64), [usizes, idx])
.unwrap()
.outputs_arr();
let mut h = dfb.finish_hugr_with_outputs([b, bools, u, usizes]).unwrap();

lowerer.run(&mut h).unwrap();

h.validate().unwrap();
}
}
Loading