Skip to content
Merged
Show file tree
Hide file tree
Changes from 13 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
193 changes: 141 additions & 52 deletions hugr-passes/src/replace_types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -107,9 +107,9 @@ impl NodeTemplate {
}
}

fn replace(&self, hugr: &mut impl HugrMut<Node = Node>, n: Node) -> Result<(), BuildError> {
fn replace(self, hugr: &mut impl HugrMut<Node = Node>, n: Node) -> Result<(), BuildError> {
assert_eq!(hugr.children(n).count(), 0);
let new_optype = match self.clone() {
let new_optype = match self {
NodeTemplate::SingleOp(op_type) => op_type,
NodeTemplate::CompoundOp(new_h) => {
let new_entrypoint = hugr.insert_hugr(n, *new_h).inserted_entrypoint;
Expand Down Expand Up @@ -171,6 +171,37 @@ fn call<H: HugrView<Node = Node>>(
Ok(Call::try_new(func_sig, type_args)?)
}

/// Options for how the replacement for an op is processed. May be specified by
/// [ReplaceTypes::replace_op_with] and [ReplaceTypes::replace_parametrized_op_with].
/// Otherwise (the default), replacements are inserted as is (without further processing).
// TODO would be good to migrate to default being process_recursive: true
#[derive(Clone, Default, PartialEq, Eq)] // More derives might inhibit future extension
pub struct ReplacementOptions {
process_recursive: bool,
linearize: bool,
}

impl ReplacementOptions {
/// Specifies that the replacement should be processed by the same [ReplaceTypes].
/// This increases compositionality (in that replacements for different ops do not
/// need to account for each other), but would lead to an infinite loop if e.g.
/// changing an op for a DFG containing an instance of the same op.
pub fn with_recursive_replacement(mut self, rec: bool) -> Self {
self.process_recursive = rec;
self
}

/// Specifies that the replacement should be linearized.
/// If [Self::with_recursive_replacement] has been set, this applies linearization
/// even to ops (within the original replacement) that are not altered by the
/// recursive processing. Otherwise, can be used to apply linearization without
/// changing any other ops.
pub fn with_linearization(mut self, lin: bool) -> Self {
self.linearize = lin;
self
}
}

/// A configuration of what types, ops, and constants should be replaced with what.
/// May be applied to a Hugr via [`Self::run`].
///
Expand Down Expand Up @@ -203,8 +234,14 @@ pub struct ReplaceTypes {
type_map: HashMap<CustomType, Type>,
param_types: HashMap<ParametricType, Arc<dyn Fn(&[TypeArg]) -> Option<Type>>>,
linearize: DelegatingLinearizer,
op_map: HashMap<OpHashWrapper, NodeTemplate>,
param_ops: HashMap<ParametricOp, Arc<dyn Fn(&[TypeArg]) -> Option<NodeTemplate>>>,
op_map: HashMap<OpHashWrapper, (NodeTemplate, ReplacementOptions)>,
param_ops: HashMap<
ParametricOp,
(
Arc<dyn Fn(&[TypeArg]) -> Option<NodeTemplate>>,
ReplacementOptions,
),
>,
consts: HashMap<
CustomType,
Arc<dyn Fn(&OpaqueValue, &ReplaceTypes) -> Result<Value, ReplaceTypesError>>,
Expand Down Expand Up @@ -337,13 +374,39 @@ impl ReplaceTypes {
}

/// Configures this instance to change occurrences of `src` to `dest`.
/// Equivalent to [Self::replace_op_with] with default [ReplacementOptions].
pub fn replace_op(&mut self, src: &ExtensionOp, dest: NodeTemplate) {
self.replace_op_with(src, dest, ReplacementOptions::default())
}

/// Configures this instance to change occurrences of `src` to `dest`.
///
/// Note that if `src` is an instance of a *parametrized* [`OpDef`], this takes
/// precedence over [`Self::replace_parametrized_op`] where the `src`s overlap. Thus,
/// this should only be used on already-*[monomorphize](super::monomorphize())d*
/// Hugrs, as substitution (parametric polymorphism) happening later will not respect
/// this replacement.
pub fn replace_op(&mut self, src: &ExtensionOp, dest: NodeTemplate) {
self.op_map.insert(OpHashWrapper::from(src), dest);
pub fn replace_op_with(
&mut self,
src: &ExtensionOp,
dest: NodeTemplate,
opts: ReplacementOptions,
) {
self.op_map.insert(OpHashWrapper::from(src), (dest, opts));
}

/// Configures this instance to change occurrences of a parametrized op `src`
/// via a callback that builds the replacement type given the [`TypeArg`]s.
/// Equivalent to [Self::replace_parametrized_op_with] with default [ReplacementOptions].
pub fn replace_parametrized_op(
&mut self,
src: &OpDef,
dest_fn: impl Fn(&[TypeArg]) -> Option<NodeTemplate> + 'static,
) {
self.param_ops.insert(
src.into(),
(Arc::new(dest_fn), ReplacementOptions::default()),
);
}

/// Configures this instance to change occurrences of a parametrized op `src`
Expand All @@ -352,12 +415,13 @@ impl ReplaceTypes {
/// fit the bounds of the original op).
///
/// If the Callback returns None, the new typeargs will be applied to the original op.
pub fn replace_parametrized_op(
pub fn replace_parametrized_op_with(
&mut self,
src: &OpDef,
dest_fn: impl Fn(&[TypeArg]) -> Option<NodeTemplate> + 'static,
opts: ReplacementOptions,
) {
self.param_ops.insert(src.into(), Arc::new(dest_fn));
self.param_ops.insert(src.into(), (Arc::new(dest_fn), opts));
}

/// Configures this instance to change [Const]s of type `src_ty`, using
Expand Down Expand Up @@ -447,34 +511,42 @@ impl ReplaceTypes {
| rest.transform(self)?),

OpType::Const(Const { value, .. }) => self.change_value(value),
OpType::ExtensionOp(ext_op) => Ok(
// Copy/discard insertion done by caller
if let Some(replacement) = self.op_map.get(&OpHashWrapper::from(&*ext_op)) {
OpType::ExtensionOp(ext_op) => Ok({
let def = ext_op.def_arc();
let mut changed = false;
let replacement = match self.op_map.get(&OpHashWrapper::from(&*ext_op)) {
r @ Some(_) => r.cloned(),
None => {
let mut args = ext_op.args().to_vec();
changed = args.transform(self)?;
let r2 = self
.param_ops
.get(&def.as_ref().into())
.and_then(|(rep_fn, opts)| rep_fn(&args).map(|nt| (nt, opts.clone())));
if r2.is_none() && changed {
*ext_op = ExtensionOp::new(def.clone(), args)?;
}
r2
}
};
if let Some((replacement, opts)) = replacement {
replacement
.replace(hugr, n)
.map_err(|e| ReplaceTypesError::AddTemplateError(n, Box::new(e)))?;
true
} else {
let def = ext_op.def_arc();
let mut args = ext_op.args().to_vec();
let ch = args.transform(self)?;
if let Some(replacement) = self
.param_ops
.get(&def.as_ref().into())
.and_then(|rep_fn| rep_fn(&args))
{
replacement
.replace(hugr, n)
.map_err(|e| ReplaceTypesError::AddTemplateError(n, Box::new(e)))?;
true
} else {
if ch {
*ext_op = ExtensionOp::new(def.clone(), args)?;
if opts.process_recursive {
self.change_subtree(hugr, n, opts.linearize)?;
} else if opts.linearize {
for d in hugr.descendants(n).collect::<Vec<_>>() {
if d != n {
self.linearize_outputs(hugr, d)?;
}
}
ch
}
},
),
true
} else {
changed
}
}),

OpType::OpaqueOp(_) => panic!("OpaqueOp should not be in a Hugr"),

Expand Down Expand Up @@ -518,34 +590,51 @@ impl ReplaceTypes {
Value::Function { hugr } => self.run(&mut **hugr),
}
}
}

impl<H: HugrMut<Node = Node>> ComposablePass<H> for ReplaceTypes {
type Error = ReplaceTypesError;
type Result = bool;

fn run(&self, hugr: &mut H) -> Result<bool, ReplaceTypesError> {
fn change_subtree<H: HugrMut<Node = Node>>(
&self,
hugr: &mut H,
root: H::Node,
linearize_if_no_change: bool,
) -> Result<bool, ReplaceTypesError> {
let mut changed = false;
for n in hugr.entry_descendants().collect::<Vec<_>>() {
for n in hugr.descendants(root).collect::<Vec<_>>() {
changed |= self.change_node(hugr, n)?;
let new_dfsig = hugr.get_optype(n).dataflow_signature();
if let Some(new_sig) = new_dfsig
.filter(|_| changed && n != hugr.entrypoint())
.map(Cow::into_owned)
{
for outp in new_sig.output_ports() {
if !new_sig.out_port_type(outp).unwrap().copyable() {
let targets = hugr.linked_inputs(n, outp).collect::<Vec<_>>();
if targets.len() != 1 {
hugr.disconnect(n, outp);
let src = Wire::new(n, outp);
self.linearize.insert_copy_discard(hugr, src, &targets)?;
}
if n != root && (changed || linearize_if_no_change) {
self.linearize_outputs(hugr, n)?;
}
}
Ok(changed)
}

fn linearize_outputs<H: HugrMut<Node = Node>>(
&self,
hugr: &mut H,
n: H::Node,
) -> Result<(), LinearizeError> {
if let Some(new_sig) = hugr.get_optype(n).dataflow_signature() {
let new_sig = new_sig.into_owned();
for outp in new_sig.output_ports() {
if !new_sig.out_port_type(outp).unwrap().copyable() {
let targets = hugr.linked_inputs(n, outp).collect::<Vec<_>>();
if targets.len() != 1 {
hugr.disconnect(n, outp);
let src = Wire::new(n, outp);
self.linearize.insert_copy_discard(hugr, src, &targets)?;
}
}
}
}
Ok(changed)
Ok(())
}
}

impl<H: HugrMut<Node = Node>> ComposablePass<H> for ReplaceTypes {
type Error = ReplaceTypesError;
type Result = bool;

fn run(&self, hugr: &mut H) -> Result<bool, ReplaceTypesError> {
self.change_subtree(hugr, hugr.entrypoint(), false)
}
}

Expand Down
82 changes: 72 additions & 10 deletions hugr-passes/src/replace_types/linearize.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,12 +52,10 @@ pub trait Linearizer {
src: Wire,
targets: &[(Node, IncomingPort)],
) -> Result<(), LinearizeError> {
let sig = hugr.signature(src.node()).unwrap();
let typ = sig.port_type(src.source()).unwrap();
let (tgt_node, tgt_inport) = if targets.len() == 1 {
*targets.first().unwrap()
} else {
// Fail fast if the edges are nonlocal. (TODO transform to local edges!)
// Fail fast if the edges are nonlocal.
let src_parent = hugr
.get_parent(src.node())
.expect("Root node cannot have out edges");
Expand All @@ -74,7 +72,8 @@ pub trait Linearizer {
tgt_parent,
});
}
let typ = typ.clone(); // Stop borrowing hugr in order to add_hugr to it
let sig = hugr.signature(src.node()).unwrap();
let typ = sig.port_type(src.source()).unwrap().clone();
let copy_discard_op = self
.copy_discard_op(&typ, targets.len())?
.add_hugr(hugr, src_parent)
Expand Down Expand Up @@ -148,7 +147,8 @@ pub enum LinearizeError {
sig: Option<Box<Signature>>,
},
#[error(
"Cannot add nonlocal edge for linear type from {src} (with parent {src_parent}) to {tgt} (with parent {tgt_parent})"
"Cannot add nonlocal edge for linear type from {src} (with parent {src_parent}) to {tgt} (with parent {tgt_parent}).
Try using LocalizeEdges pass first."
)]
NoLinearNonLocalEdges {
src: Node,
Expand Down Expand Up @@ -367,11 +367,11 @@ mod test {
use std::sync::Arc;

use hugr_core::builder::{
BuildError, DFGBuilder, Dataflow, DataflowHugr, DataflowSubContainer, HugrBuilder,
inout_sig,
BuildError, Container, DFGBuilder, Dataflow, DataflowHugr, DataflowSubContainer,
HugrBuilder, inout_sig,
};

use hugr_core::extension::prelude::{option_type, usize_t};
use hugr_core::extension::prelude::{option_type, qb_t, usize_t};
use hugr_core::extension::simple_op::MakeExtensionOp;
use hugr_core::extension::{
CustomSignatureFunc, OpDef, SignatureError, SignatureFunc, TypeDefBound, Version,
Expand All @@ -385,14 +385,16 @@ mod test {
};
use hugr_core::types::type_param::TypeParam;
use hugr_core::types::{
FuncValueType, PolyFuncTypeRV, Signature, Type, TypeArg, TypeEnum, TypeRow,
FuncValueType, PolyFuncTypeRV, Signature, Type, TypeArg, TypeBound, TypeEnum, TypeRow,
};
use hugr_core::{Extension, Hugr, HugrView, Node, hugr::IdentList, type_row};
use itertools::Itertools;
use rstest::rstest;

use crate::replace_types::handlers::linearize_value_array;
use crate::replace_types::{LinearizeError, NodeTemplate, ReplaceTypesError};
use crate::replace_types::{
LinearizeError, NodeTemplate, ReplaceTypesError, ReplacementOptions,
};
use crate::{ComposablePass, ReplaceTypes};

const LIN_T: &str = "Lin";
Expand Down Expand Up @@ -855,4 +857,64 @@ mod test {
panic!("Expected error");
}
}

#[test]
fn use_in_op_callback() {
let (e, mut lowerer) = ext_lowerer();
let drop_ext = Extension::new_arc(
IdentList::new_unchecked("DropExt"),
Version::new(0, 0, 0),
|e, w| {
e.add_op(
"drop".into(),
String::new(),
PolyFuncTypeRV::new(
[TypeBound::Linear.into()], // It won't *lower* for any type tho!
Signature::new(Type::new_var_use(0, TypeBound::Linear), vec![]),
),
w,
)
.unwrap();
},
);
let drop_op = drop_ext.get_op("drop").unwrap();
lowerer.replace_parametrized_op_with(
drop_op,
|args| {
let [TypeArg::Runtime(ty)] = args else {
panic!("Expected just one type")
};
// The Hugr here is invalid, so we have to pull it out manually
let mut dfb = DFGBuilder::new(Signature::new(ty.clone(), vec![])).unwrap();
let h = std::mem::take(dfb.hugr_mut());
Comment on lines +887 to +889
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah this feels a bit odd, however, I see your point about the explicit version becoming quite cumbersome if we want to compose stuff. Overall, the approach in this PR feels like the better way to go 👍

Some(NodeTemplate::CompoundOp(Box::new(h)))
},
ReplacementOptions::default().with_linearization(true),
);

let build_hugr = |ty: Type| {
let mut dfb = DFGBuilder::new(Signature::new(ty.clone(), vec![])).unwrap();
let [inp] = dfb.input_wires_arr();
let drop_op = drop_ext
.instantiate_extension_op("drop", [ty.into()])
.unwrap();
dfb.add_dataflow_op(drop_op, [inp]).unwrap();
dfb.finish_hugr().unwrap()
};
// We can drop a tuple of 2* lin_t
let lin_t = Type::from(e.get_type(LIN_T).unwrap().instantiate([]).unwrap());
let mut h = build_hugr(Type::new_tuple(vec![lin_t; 2]));
lowerer.run(&mut h).unwrap();
h.validate().unwrap();
let mut exts = h.nodes().filter_map(|n| h.get_optype(n).as_extension_op());
assert_eq!(exts.clone().count(), 2);
assert!(exts.all(|eo| eo.qualified_id() == "TestExt.discard"));

// We cannot drop a qubit
let mut h = build_hugr(qb_t());
assert_eq!(
lowerer.run(&mut h).unwrap_err(),
ReplaceTypesError::LinearizeError(LinearizeError::NeedCopyDiscard(Box::new(qb_t())))
);
}
}
Loading