diff --git a/Cargo.lock b/Cargo.lock index 1385ae530..235701764 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2679,11 +2679,14 @@ dependencies = [ "bindgen", "cc", "conan2", + "hugr", + "itertools 0.14.0", "libc", "rstest 0.26.1", "serde", "serde_json", "thiserror 2.0.17", + "tket", "tket-json-rs", ] diff --git a/tket/src/passes/pytket.rs b/tket/src/passes/pytket.rs index f68933eb8..3eaf72e02 100644 --- a/tket/src/passes/pytket.rs +++ b/tket/src/passes/pytket.rs @@ -7,7 +7,7 @@ use derive_more::{Display, Error, From}; use hugr::{HugrView, Node}; use itertools::Itertools; -use crate::serialize::pytket::OpConvertError; +use crate::serialize::pytket::PytketEncodeOpError; use crate::Circuit; use super::find_tuple_unpack_rewrites; @@ -37,7 +37,7 @@ pub enum PytketLoweringError { /// An error occurred during the conversion of an operation. #[display("operation conversion error: {_0}")] #[from] - OpConversionError(OpConvertError), + OpConversionError(PytketEncodeOpError), /// The circuit is not fully-contained in a region. /// Function calls are not supported. #[display("Non-local operations found. Function calls are not supported.")] diff --git a/tket/src/serialize/pytket.rs b/tket/src/serialize/pytket.rs index 8d6ae77d0..b8f8f7ddf 100644 --- a/tket/src/serialize/pytket.rs +++ b/tket/src/serialize/pytket.rs @@ -15,7 +15,9 @@ pub use config::{ TypeTranslatorSet, }; pub use encoder::PytketEncoderContext; -pub use error::{OpConvertError, PytketDecodeError, PytketDecodeErrorInner, PytketEncodeError}; +pub use error::{ + PytketDecodeError, PytketDecodeErrorInner, PytketEncodeError, PytketEncodeOpError, +}; pub use extension::PytketEmitter; pub use options::{DecodeInsertionTarget, DecodeOptions, EncodeOptions}; @@ -115,7 +117,11 @@ impl TKETDecode for SerialCircuit { fn decode(&self, options: DecodeOptions) -> Result { let mut hugr = Hugr::new(); - let main_func = self.decode_inplace(&mut hugr, DecodeInsertionTarget::Function, options)?; + let main_func = self.decode_inplace( + &mut hugr, + DecodeInsertionTarget::Function { fn_name: None }, + options, + )?; hugr.set_entrypoint(main_func); Ok(hugr.into()) } @@ -126,25 +132,14 @@ impl TKETDecode for SerialCircuit { target: DecodeInsertionTarget, options: DecodeOptions, ) -> Result { - let config = options - .config - .unwrap_or_else(|| default_decoder_config().into()); - - let mut decoder = PytketDecoderContext::new( - self, - hugr, - target, - options.fn_name, - options.signature, - options.input_params, - config, - )?; + let mut decoder = PytketDecoderContext::new(self, hugr, target, options)?; decoder.run_decoder(&self.commands)?; Ok(decoder.finish()?.node()) } fn encode(circuit: &Circuit, options: EncodeOptions) -> Result { - EncodedCircuit::from_hugr(circuit, options)?.extract_standalone() + let mut encoded = EncodedCircuit::new_standalone(circuit, options)?; + Ok(std::mem::take(&mut encoded[circuit.parent()])) } } diff --git a/tket/src/serialize/pytket/circuit.rs b/tket/src/serialize/pytket/circuit.rs index ffbf3b40e..a0c0b1e95 100644 --- a/tket/src/serialize/pytket/circuit.rs +++ b/tket/src/serialize/pytket/circuit.rs @@ -2,15 +2,23 @@ use std::collections::{HashMap, VecDeque}; use std::ops::{Index, IndexMut}; +use std::sync::Arc; use hugr::core::HugrNode; -use hugr::ops::{OpTag, OpTrait}; -use hugr::{Hugr, HugrView}; -use rayon::iter::{IntoParallelIterator, IntoParallelRefIterator, IntoParallelRefMutIterator}; +use hugr::hugr::hugrmut::HugrMut; +use hugr::ops::handle::NodeHandle; +use hugr::ops::{OpParent, OpTag, OpTrait}; +use hugr::{Hugr, HugrView, Node}; +use hugr_core::hugr::internal::HugrMutInternals; +use itertools::Itertools; +use rayon::iter::{IntoParallelRefIterator, IntoParallelRefMutIterator, ParallelIterator}; use tket_json_rs::circuit_json::{Command as PytketCommand, SerialCircuit}; +use crate::serialize::pytket::decoder::PytketDecoderContext; +use crate::serialize::pytket::opaque::SubgraphId; use crate::serialize::pytket::{ - default_encoder_config, EncodeOptions, PytketEncodeError, PytketEncoderContext, + default_encoder_config, DecodeInsertionTarget, DecodeOptions, EncodeOptions, PytketDecodeError, + PytketDecodeErrorInner, PytketDecoderConfig, PytketEncodeError, PytketEncoderContext, }; use crate::Circuit; @@ -22,59 +30,179 @@ use super::opaque::OpaqueSubgraphs; /// circuit, so we can reconstruct the HUGR if needed. /// /// Serial circuits in this structure are intended to be transient, only alive -/// while this structure is in memory. -/// To obtain a fully standalone pytket circuit that can be used independently, -/// and stored permanently, use [`EncodedCircuit::extract_standalone`]. -pub struct EncodedCircuit<'a, H: HugrView = Hugr> { - /// Region in the HUGR that was encoded as the main circuit. - /// - /// If [`EncodeOptions::encode_subcircuits`] was set during the encoding - /// process, `circuits` will contain entries for some dataflow regions that - /// descendants of this node. - /// - /// If [`EncodeOptions::encode_subcircuits`] was not set, `circuits` will - /// only contain an entry for this region if it was a dataflow container, or - /// no entries if it was not. - head_region: H::Node, +/// while this structure is in memory. To obtain a fully standalone pytket +/// circuit that can be used independently, and stored permanently, use +/// [`EncodedCircuit::new_standalone`] or call +/// [`EncodedCircuit::ensure_standalone`]. +#[derive(Debug, Clone)] +pub struct EncodedCircuit { /// Circuits encoded from independent dataflow regions in the HUGR. /// /// These correspond to sections of the HUGR that can be optimized /// independently. - circuits: HashMap, + circuits: HashMap, /// Sets of subgraphs in the HUGR that have been encoded as opaque barriers /// in the pytket circuit. /// /// Subcircuits are identified in the barrier metadata by their ID in this /// vector. See [`SubgraphId`]. - opaque_subgraphs: OpaqueSubgraphs, - /// The HUGR from where the pytket circuits were encoded. - hugr: &'a H, + opaque_subgraphs: OpaqueSubgraphs, +} + +/// Information stored about a pytket circuit encoded from a HUGR region. +#[derive(Debug, Clone)] +pub(super) struct EncodedCircuitInfo { + /// The serial circuit encoded from the region. + pub serial_circuit: SerialCircuit, + /// A subgraph of the region that does not contain any operation encodable + /// as a pytket command, and hence was not encoded in [`serial_circuit`]. + #[expect(unused)] + pub extra_subgraph: Option, + /// List of parameters in the pytket circuit in the order they appear in the + /// hugr input. + /// + /// We require this to correctly reconstruct the input order in the reassembled hugr, + /// since parameters in pytket are unordered. + pub input_params: Vec, + /// List of output parameter expressions found at the end of the encoded region. + // + // TODO: The decoder does not currently connect these. + pub output_params: Vec, } -impl<'a, H: HugrView> EncodedCircuit<'a, H> { - /// Encode a Hugr into a [`EncodedCircuit`]. +impl EncodedCircuit { + /// Encode a HugrView into a [`EncodedCircuit`]. /// /// The HUGR's entrypoint must be a dataflow region that will be encoded as /// the main circuit. Additional circuits may be encoded if /// [`EncodeOptions::encode_subcircuits`] is set. /// /// The circuit may contain opaque barriers referencing subgraphs in the - /// original HUGR. To extract a fully standalone pytket circuit that can be - /// used independently, use [`EncodedCircuit::extract_standalone`]. + /// original HUGR. To obtain a fully standalone pytket circuit that can be + /// used independently, and stored permanently, use + /// [`EncodedCircuit::new_standalone`] or call + /// [`EncodedCircuit::ensure_standalone`]. + /// + /// See [`EncodeOptions`] for the options used by the encoder. + pub fn new + AsMut + HugrView>( + circuit: &Circuit, + options: EncodeOptions, + ) -> Result> { + let mut enc = Self { + circuits: HashMap::new(), + opaque_subgraphs: OpaqueSubgraphs::new(0), + }; + + enc.encode_circuits(circuit, options)?; + + Ok(enc) + } + + /// Reassemble the encoded circuits into the original [`Hugr`], replacing + /// the existing regions that were encoded as subcircuits. + /// + /// + /// + /// # Arguments + /// + /// - `hugr`: The [`Hugr`] to reassemble the circuits in. + /// - `config`: The set of extension decoders used to convert the pytket + /// commands into HUGR operations. + /// + /// # Returns + /// + /// A list of region parents whose contents were replaced by the updated + /// circuits. + /// + /// # Errors + /// + /// Returns a [`PytketDecodeErrorInner::IncompatibleTargetRegion`] error if + /// the source region of an encoded circuit does not match the circuit + /// signature. This is likely caused by the original hugr being modified + /// since the circuit was encoded. + /// + /// Returns an error if a circuit being decoded is invalid. See + /// [`PytketDecodeErrorInner`][super::error::PytketDecodeErrorInner] for + /// more details. + pub fn reassemble_inline( + &self, + hugr: &mut Hugr, + config: Option>, + ) -> Result, PytketDecodeError> { + let options = match &config { + Some(config) => DecodeOptions::new().with_config(config.clone()), + None => DecodeOptions::new().with_default_config(), + }; + + for (&original_region, encoded) in &self.circuits { + // Decode the circuit into a temporary function node. + let Some(signature) = hugr.get_optype(original_region).inner_function_type() else { + return Err(PytketDecodeErrorInner::IncompatibleTargetRegion { + region: original_region, + new_optype: hugr.get_optype(original_region).clone(), + } + .wrap()); + }; + let options = options + .clone() + .with_signature(signature.into_owned()) + .with_input_params(encoded.input_params.iter().cloned()); + + // Run the decoder, generating a new function with the extracted definition. + // + // Unsupported subgraphs of the original region will be transplanted here. + let mut decoder = PytketDecoderContext::new( + &encoded.serial_circuit, + hugr, + DecodeInsertionTarget::Function { fn_name: None }, + options, + )?; + decoder.register_opaque_subgraphs(&self.opaque_subgraphs); + decoder.run_decoder(&encoded.serial_circuit.commands)?; + let decoded_node = decoder.finish()?.node(); + + // Replace the region with the decoded function. + // + // All descendant nodes that were re-used by the decoded circuit got + // re-parented at this point, so we can just do a full clear here. + while let Some(child) = hugr.first_child(original_region) { + hugr.remove_subtree(child); + } + while let Some(child) = hugr.first_child(decoded_node) { + hugr.set_parent(child, original_region); + } + hugr.remove_node(decoded_node); + } + Ok(self.circuits.keys().copied().collect_vec()) + } +} + +impl EncodedCircuit { + /// Encode a HugrView into a [`EncodedCircuit`]. + /// + /// The HUGR's entrypoint must be a dataflow region that will be encoded as + /// the main circuit. Additional circuits may be encoded if + /// [`EncodeOptions::encode_subcircuits`] is set. + /// + /// The circuit may contain opaque barriers encoding opaque subgraphs in the + /// original HUGR. These are encoded completely as Hugr envelopes in the + /// barrier operations metadata. + /// + /// When encoding a `Hugr`, prefer using [`EncodedCircuit::new`] instead to + /// avoid unnecessary copying of the opaque subgraphs. /// /// See [`EncodeOptions`] for the options used by the encoder. - pub fn from_hugr( - circuit: &'a Circuit, + pub fn new_standalone>( + circuit: &Circuit, options: EncodeOptions, ) -> Result> { let mut enc = Self { - head_region: circuit.parent(), circuits: HashMap::new(), opaque_subgraphs: OpaqueSubgraphs::new(0), - hugr: circuit.hugr(), }; enc.encode_circuits(circuit, options)?; + enc.ensure_standalone(circuit.hugr())?; Ok(enc) } @@ -82,10 +210,10 @@ impl<'a, H: HugrView> EncodedCircuit<'a, H> { /// Encode the circuits for the entrypoint region to the hugr, and if [`EncodeOptions::encode_subcircuits`] is set, /// for the descendants of any unsupported node in the main circuit. /// - /// Auxiliary method for [`Self::from_hugr`]. + /// Auxiliary method for [`Self::new`] and [`Self::new_standalone`]. /// /// TODO: Add an option in [EncodeOptions] to run the subcircuit encoders in parallel. - fn encode_circuits( + fn encode_circuits>( &mut self, // This is already in [`self.hugr`], but we pass it since wrapping it // again results in a `Circuit<&H>`, which doesn't play well with @@ -98,7 +226,7 @@ impl<'a, H: HugrView> EncodedCircuit<'a, H> { // These may be either dataflow region parents that we can encode, or // any node with children that we should traverse recursively until we // find a dataflow region. - let mut candidate_nodes = VecDeque::from([self.head_region]); + let mut candidate_nodes = VecDeque::from([circuit.parent()]); let config = match options.config.take() { Some(config) => config, None => default_encoder_config().into(), @@ -136,46 +264,76 @@ impl<'a, H: HugrView> EncodedCircuit<'a, H> { let mut encoder: PytketEncoderContext = PytketEncoderContext::new(circuit, node, opaque_subgraphs, config.clone())?; encoder.run_encoder(circuit, node)?; - let (serial, _, opaque_subgraphs) = encoder.finish(circuit, node)?; + let (encoded, opaque_subgraphs) = encoder.finish(circuit, node)?; if options.encode_subcircuits { add_subgraph_candidates(&opaque_subgraphs, &mut candidate_nodes); } - self.circuits.insert(node, serial); + self.circuits.insert(node, encoded); self.opaque_subgraphs.merge(opaque_subgraphs); } Ok(()) } - /// Extract the top-level pytket circuit as a standalone definition - /// containing the whole original HUGR. + /// Reassemble the encoded circuits into a new [`Hugr`], containing a + /// function with the decoded circuit originally corresponding to `region`. /// - /// Traverses the commands in `head_circuit` and replaces - /// [`OpaqueSubgraphPayloadType::External`][super::opaque::OpaqueSubgraphPayloadType::External] - /// pointers in opaque barriers with inline payloads. + /// # Arguments /// - /// Discards any changes to the internal subcircuits, as they are not part - /// of the top-level circuit. + /// - `fn_name`: The name of the function to create. If `None`, we will use + /// the name of the circuit, or "main" if the circuit has no name. + /// - `options`: The options for the decoder. /// /// # Errors /// - /// Returns a [`PytketEncodeError::InvalidStandaloneHeadRegion`] error if - /// [`Self::head_region`] is not a dataflow container in the hugr. + /// Returns a [`PytketDecodeErrorInner::NotAnEncodedRegion`] error if + /// there is no encoded circuit for `region`. + pub fn reassemble( + &self, + region: Node, + fn_name: Option, + options: DecodeOptions, + ) -> Result { + if !self.contains_circuit(region) { + return Err(PytketDecodeErrorInner::NotAnEncodedRegion { + region: region.to_string(), + } + .wrap()); + } + let serial_circuit = &self[region]; + + if self.len() > 1 { + unimplemented!( + "Reassembling an `EncodedCircuit` with nested subcircuits is not yet implemented." + ); + }; + + let mut hugr = Hugr::new(); + let target = DecodeInsertionTarget::Function { fn_name }; + + let mut decoder = PytketDecoderContext::new(serial_circuit, &mut hugr, target, options)?; + decoder.run_decoder(&serial_circuit.commands)?; + decoder.finish()?; + Ok(hugr) + } + + /// Ensure that none of the encoded circuits contain references to opaque subgraphs in the original HUGR. + /// + /// Traverses the commands in the encoded circuits and replaces + /// [`OpaqueSubgraphPayload::External`][super::opaque::OpaqueSubgraphPayload::External] + /// payloads in opaque barriers with inline payloads. + /// + /// # Errors /// /// Returns an error if a barrier operation with the /// [`OPGROUP_OPAQUE_HUGR`][super::opaque::OPGROUP_OPAQUE_HUGR] /// opgroup has an invalid payload. - // - // TODO: We'll need to handle non-local edges and function definitions in this step. - pub fn extract_standalone(mut self) -> Result> { - if !self.check_dataflow_head_region() { - let head_op = self.hugr.get_optype(self.head_region).to_string(); - return Err(PytketEncodeError::InvalidStandaloneHeadRegion { head_op }); - }; - let mut serial_circuit = self.circuits.remove(&self.head_region).unwrap(); - + pub fn ensure_standalone( + &mut self, + hugr: &impl HugrView, + ) -> Result<(), PytketEncodeError> { /// Replace references to the `EncodedCircuit` context from the circuit commands. /// /// Replaces [`OpaqueSubgraphPayloadType::External`][super::opaque::OpaqueSubgraphPayloadType::External] @@ -196,49 +354,19 @@ impl<'a, H: HugrView> EncodedCircuit<'a, H> { } Ok(()) } - make_commands_standalone( - &mut serial_circuit.commands, - &self.opaque_subgraphs, - self.hugr, - )?; - - Ok(serial_circuit) - } - - /// Checks if [`Self::head_region`] was a dataflow container in the original hugr, - /// and therefore has an encoded circuit in this structure. - fn check_dataflow_head_region(&self) -> bool { - self.circuits.contains_key(&self.head_region) - } - - /// Returns the HUGR from where the circuit was encoded. - pub fn hugr(&self) -> &H { - self.hugr - } - /// Returns the region node from which the main circuit was encoded. - pub fn head_region(&self) -> H::Node { - self.head_region - } - - /// Returns an iterator over all the encoded pytket circuits. - pub fn circuits(&self) -> impl Iterator { - self.into_iter().map(|(&n, circ)| (n, circ)) - } - - /// Returns an iterator over all the encoded pytket circuits as mutable - /// references. - /// - /// The circuits may be modified arbitrarily, as long as - /// [`OpaqueSubgraphPayloadType::External`][super::opaque::OpaqueSubgraphPayloadType::External] - /// pointers to HUGR subgraphs in opaque barriers remain valid and - /// topologically consistent with the original circuit. - pub fn circuits_mut(&mut self) -> impl Iterator { - self.into_iter().map(|(&n, circ)| (n, circ)) + for encoded in self.circuits.values_mut() { + make_commands_standalone( + &mut encoded.serial_circuit.commands, + &self.opaque_subgraphs, + hugr, + )?; + } + Ok(()) } /// Returns `true` if there is an encoded pytket circuit for the given region. - pub fn contains_circuit(&self, region: H::Node) -> bool { + pub fn contains_circuit(&self, region: Node) -> bool { self.circuits.contains_key(®ion) } @@ -251,64 +379,52 @@ impl<'a, H: HugrView> EncodedCircuit<'a, H> { pub fn is_empty(&self) -> bool { self.circuits.is_empty() } -} - -impl<'a, H: HugrView> Index for EncodedCircuit<'a, H> { - type Output = SerialCircuit; - fn index(&self, index: H::Node) -> &Self::Output { - &self.circuits[&index] + /// Returns an iterator over the encoded pytket circuits. + pub fn iter(&self) -> impl Iterator { + self.circuits + .iter() + .map(|(&n, circ)| (n, &circ.serial_circuit)) } -} -impl<'a, H: HugrView> IndexMut for EncodedCircuit<'a, H> { - fn index_mut(&mut self, index: H::Node) -> &mut Self::Output { + /// Returns a mutable iterator over the encoded pytket circuits. + pub fn iter_mut(&mut self) -> impl Iterator { self.circuits - .get_mut(&index) - .unwrap_or_else(|| panic!("Indexing into a circuit that was not encoded: {index}")) + .iter_mut() + .map(|(&n, circ)| (n, &mut circ.serial_circuit)) } } -impl<'c, 'a, H: HugrView> IntoIterator for &'c EncodedCircuit<'a, H> { - type Item = (&'c H::Node, &'c SerialCircuit); - type IntoIter = <&'c HashMap as IntoIterator>::IntoIter; - - fn into_iter(self) -> Self::IntoIter { - self.circuits.iter() +impl EncodedCircuit { + /// Returns a parallel iterator over the encoded pytket circuits. + pub fn par_iter(&self) -> impl ParallelIterator { + self.circuits + .par_iter() + .map(|(&n, circ)| (n, &circ.serial_circuit)) } -} - -impl<'c, 'a, H: HugrView> IntoIterator for &'c mut EncodedCircuit<'a, H> { - type Item = (&'c H::Node, &'c mut SerialCircuit); - type IntoIter = <&'c mut HashMap as IntoIterator>::IntoIter; - fn into_iter(self) -> Self::IntoIter { - self.circuits.iter_mut() + /// Returns a parallel mutable iterator over the encoded pytket circuits. + pub fn par_iter_mut(&mut self) -> impl ParallelIterator { + self.circuits + .par_iter_mut() + .map(|(&n, circ)| (n, &mut circ.serial_circuit)) } } -impl<'c, 'a, H> IntoParallelIterator for &'c EncodedCircuit<'a, H> -where - H: HugrView, - H::Node: Send + Sync, -{ - type Item = (&'c H::Node, &'c SerialCircuit); - type Iter = <&'c HashMap as IntoParallelIterator>::Iter; +impl Index for EncodedCircuit { + type Output = SerialCircuit; - fn into_par_iter(self) -> Self::Iter { - self.circuits.par_iter() + fn index(&self, index: Node) -> &Self::Output { + &self.circuits[&index].serial_circuit } } -impl<'c, 'a, H> IntoParallelIterator for &'c mut EncodedCircuit<'a, H> -where - H: HugrView, - H::Node: Send + Sync, -{ - type Item = (&'c H::Node, &'c mut SerialCircuit); - type Iter = <&'c mut HashMap as IntoParallelIterator>::Iter; - - fn into_par_iter(self) -> Self::Iter { - self.circuits.par_iter_mut() +impl IndexMut for EncodedCircuit { + fn index_mut(&mut self, index: Node) -> &mut Self::Output { + &mut self + .circuits + .get_mut(&index) + .unwrap_or_else(|| panic!("Indexing into a circuit that was not encoded: {index}")) + .serial_circuit } } diff --git a/tket/src/serialize/pytket/decoder.rs b/tket/src/serialize/pytket/decoder.rs index 2ba4042a7..3fac62315 100644 --- a/tket/src/serialize/pytket/decoder.rs +++ b/tket/src/serialize/pytket/decoder.rs @@ -1,6 +1,7 @@ //! Intermediate structure for decoding [`SerialCircuit`]s into [`Hugr`]s. mod param; +mod subgraph; mod tracked_elem; mod wires; @@ -9,6 +10,8 @@ pub use param::{LoadedParameter, ParameterType}; pub use tracked_elem::{TrackedBit, TrackedQubit}; pub use wires::TrackedWires; +pub(super) use wires::FoundWire; + use std::sync::Arc; use hugr::builder::{ @@ -34,23 +37,33 @@ use crate::extension::rotation::rotation_type; use crate::serialize::pytket::config::PytketDecoderConfig; use crate::serialize::pytket::decoder::wires::WireTracker; use crate::serialize::pytket::extension::{build_opaque_tket_op, RegisterCount}; -use crate::serialize::pytket::{DecodeInsertionTarget, PytketDecodeErrorInner}; +use crate::serialize::pytket::opaque::OpaqueSubgraphs; +use crate::serialize::pytket::{DecodeInsertionTarget, DecodeOptions, PytketDecodeErrorInner}; use crate::TketOp; /// State of the tket circuit being decoded. /// -/// The state of an in-progress [`FunctionBuilder`] being built from a [`SerialCircuit`]. +/// The state of an in-progress [`FunctionBuilder`] being built from a +/// [`SerialCircuit`]. +/// +/// The generic parameter `H` is the HugrView type of the Hugr that was encoded +/// into the circuit, if any. This is required when the encoded pytket circuit +/// contains opaque barriers that reference subgraphs in the original HUGR. See +/// [`OpaqueSubgraphPayload`][super::opaque::OpaqueSubgraphPayload] for more details. #[derive(Debug)] pub struct PytketDecoderContext<'h> { /// The Hugr being built. pub builder: DFGBuilder<&'h mut Hugr>, /// A tracker keeping track of the generated wires and their corresponding types. - pub(super) wire_tracker: WireTracker, - /// Configuration for decoding commands. - /// - /// Contains custom operation decoders, that define translation of legacy tket - /// commands into HUGR operations. - config: Arc, + pub(super) wire_tracker: Box, + /// Options used when decoding the circuit. + /// + /// This contains the decoding parameters specific to the current circuit + /// being decoded. + options: DecodeOptions, + /// A registry of opaque subgraphs from `original_hugr`, that are referenced by opaque barriers in the pytket circuit + /// via their [`SubgraphId`]. + pub(super) opaque_subgraphs: Option<&'h OpaqueSubgraphs>, } impl<'h> PytketDecoderContext<'h> { @@ -94,13 +107,16 @@ impl<'h> PytketDecoderContext<'h> { serialcirc: &SerialCircuit, hugr: &'h mut Hugr, target: DecodeInsertionTarget, - fn_name: Option, - signature: Option, - input_params: impl IntoIterator, - config: impl Into>, + mut options: DecodeOptions, ) -> Result { - let config: Arc = config.into(); - let signature = signature.unwrap_or_else(|| { + // Ensure that the set of decoders is present, use a default one if not. + if options.config.is_none() { + options = options.with_default_config(); + } + + // Compute the signature of the decoded region, if not provided, and + // initialize the DFG builder. + let signature = options.signature.clone().unwrap_or_else(|| { let num_qubits = serialcirc.qubits.len(); let num_bits = serialcirc.bits.len(); let types: TypeRow = [vec![qb_t(); num_qubits], vec![bool_t(); num_bits]] @@ -108,11 +124,11 @@ impl<'h> PytketDecoderContext<'h> { .into(); Signature::new(types.clone(), types) }); - let name = fn_name - .or_else(|| serialcirc.name.clone()) - .unwrap_or_default(); let mut dfg: DFGBuilder<&mut Hugr> = match target { - DecodeInsertionTarget::Function => { + DecodeInsertionTarget::Function { fn_name } => { + let name = fn_name + .or_else(|| serialcirc.name.clone()) + .unwrap_or_default(); FunctionBuilder::with_hugr(hugr, name, signature.clone()) .unwrap() .into_dfg_builder() @@ -131,8 +147,8 @@ impl<'h> PytketDecoderContext<'h> { serialcirc, &mut dfg, &signature.input, - input_params, - &config, + options.input_params.iter().cloned(), + options.get_config(), )?; if !serialcirc.phase.is_empty() { @@ -144,8 +160,9 @@ impl<'h> PytketDecoderContext<'h> { Ok(PytketDecoderContext { builder: dfg, - wire_tracker, - config, + wire_tracker: Box::new(wire_tracker), + options, + opaque_subgraphs: None, }) } @@ -314,8 +331,12 @@ impl<'h> PytketDecoderContext<'h> { }; e.hugr_op("Output") })?; + let output_wire_count = output_wires.register_count(); let output_wires = output_wires.wires(); + // Qubits not in the output need to be freed. + self.add_implicit_qfree_operations(&qubits[output_wire_count.qubits..]); + // Store the name for the input parameter wires let input_params = self.wire_tracker.finish(); if !input_params.is_empty() { @@ -332,12 +353,52 @@ impl<'h> PytketDecoderContext<'h> { .node()) } + /// Add the implicit QFree operations for a list of qubits that are not in the hugr output. + /// + /// We only do this if there's a wire with type `qb_t` containing the qubit. + fn add_implicit_qfree_operations(&mut self, qubits: &[TrackedQubit]) { + let qb_type = qb_t(); + let mut bit_args: &[TrackedBit] = &[]; + let mut params: &[LoadedParameter] = &[]; + for q in qubits.iter() { + let mut qubit_args: &[TrackedQubit] = std::slice::from_ref(q); + let Ok(FoundWire::Register(wire)) = self.wire_tracker.find_typed_wire( + self.config(), + &qb_type, + &mut qubit_args, + &mut bit_args, + &mut params, + None, + ) else { + continue; + }; + + self.builder + .add_dataflow_op(TketOp::QFree, [wire.wire()]) + .unwrap() + .out_wire(0); + } + } + + /// Register the set of opaque subgraphs that are present in the HUGR being decoded. + /// + /// # Arguments + /// - `opaque_subgraphs`: A registry of opaque subgraphs from + /// `self.builder.hugr()`, that are referenced by opaque barriers in the + /// pytket circuit via their [`SubgraphId`]. + pub(super) fn register_opaque_subgraphs( + &mut self, + opaque_subgraphs: &'h OpaqueSubgraphs, + ) { + self.opaque_subgraphs = Some(opaque_subgraphs); + } + /// Decode a list of pytket commands. pub(super) fn run_decoder( &mut self, commands: &[circuit_json::Command], ) -> Result<(), PytketDecodeError> { - let config = self.config.clone(); + let config = self.config().clone(); for com in commands { let op_type = com.op.op_type; self.process_command(com, config.as_ref()) @@ -378,11 +439,6 @@ impl<'h> PytketDecoderContext<'h> { } Ok(()) } - - /// Returns the configuration used by the decoder. - pub fn config(&self) -> &Arc { - &self.config - } } /// Public API, used by the [`PytketDecoder`][super::extension::PytketDecoder] implementers. @@ -420,7 +476,7 @@ impl<'h> PytketDecoderContext<'h> { params: &[LoadedParameter], ) -> Result { self.wire_tracker - .find_typed_wires(&self.config, types, qubit_args, bit_args, params) + .find_typed_wires(self.config(), types, qubit_args, bit_args, params) } /// Connects the input ports of a node using a list of input qubits, bits, @@ -486,12 +542,12 @@ impl<'h> PytketDecoderContext<'h> { let op_input_count: RegisterCount = sig .input_types() .iter() - .map(|ty| self.config.type_to_pytket(ty).unwrap_or_default()) + .map(|ty| self.config().type_to_pytket(ty).unwrap_or_default()) .sum(); let op_output_count: RegisterCount = sig .output_types() .iter() - .map(|ty| self.config.type_to_pytket(ty).unwrap_or_default()) + .map(|ty| self.config().type_to_pytket(ty).unwrap_or_default()) .sum(); // Validate input counts @@ -643,7 +699,7 @@ impl<'h> PytketDecoderContext<'h> { let mut port_types = sig.output_ports().zip(sig.output_types().iter()); while let Some((port, ty)) = port_types.next() { let wire = Wire::new(node, port); - let counts = self.config.type_to_pytket(ty).unwrap_or_default(); + let counts = self.config().type_to_pytket(ty).unwrap_or_default(); reg_count += counts; // Get the qubits and bits for this wire. @@ -653,7 +709,7 @@ impl<'h> PytketDecoderContext<'h> { let expected_qubits = reg_count.qubits - counts.qubits + wire_qubits.len(); let expected_bits = reg_count.bits - counts.bits + wire_bits.len(); return Err(make_unexpected_node_out_error( - &self.config, + self.config(), port_types, reg_count, expected_qubits, @@ -700,6 +756,16 @@ impl<'h> PytketDecoderContext<'h> { .load_half_turns_parameter(&mut self.builder, param, Some(typ)) .with_type(typ, &mut self.builder) } + + /// Returns the configuration used by the decoder. + pub fn config(&self) -> &Arc { + self.options.get_config() + } + + /// Returns the options used by the decoder. + pub fn options(&self) -> &DecodeOptions { + &self.options + } } /// Result of trying to decode pytket operation into a HUGR definition. @@ -721,7 +787,12 @@ pub enum DecodeStatus { Unsupported, } -/// Helper to continue exhausting the iterators in [`PytketDecoderContext::register_node_outputs`] until we have the total number of elements to report. +/// Helper to continue exhausting the iterators in +/// [`PytketDecoderContext::register_node_outputs`] until we have the total +/// number of elements to report. +/// +/// Processes remaining port types and adds them to the partial count of the +/// number of qubits and bits we expected to have available. fn make_unexpected_node_out_error<'ty>( config: &PytketDecoderConfig, port_types: impl IntoIterator, diff --git a/tket/src/serialize/pytket/decoder/subgraph.rs b/tket/src/serialize/pytket/decoder/subgraph.rs new file mode 100644 index 000000000..22aafe9f0 --- /dev/null +++ b/tket/src/serialize/pytket/decoder/subgraph.rs @@ -0,0 +1,407 @@ +//! Methods to decode opaque subgraphs from a pytket barrier operation. + +use std::collections::HashMap; +use std::ops::RangeTo; +use std::sync::Arc; + +use hugr::builder::Container; +use hugr::hugr::hugrmut::HugrMut; +use hugr::types::Type; +use hugr::{Hugr, HugrView, IncomingPort, Node, OutgoingPort, PortIndex, Wire}; +use hugr_core::hugr::internal::HugrMutInternals; +use itertools::Itertools; + +use crate::serialize::pytket::decoder::{ + DecodeStatus, FoundWire, LoadedParameter, PytketDecoderContext, TrackedBit, TrackedQubit, +}; +use crate::serialize::pytket::extension::RegisterCount; +use crate::serialize::pytket::opaque::{EncodedEdgeID, OpaqueSubgraphPayload, SubgraphId}; +use crate::serialize::pytket::{PytketDecodeError, PytketDecodeErrorInner, PytketDecoderConfig}; + +impl<'h> PytketDecoderContext<'h> { + /// Insert a subgraph encoded in the payload of a pytket barrier operation into + /// the Hugr being decoded. + pub(in crate::serialize::pytket) fn insert_subgraph_from_payload( + &mut self, + qubits: &[TrackedQubit], + bits: &[TrackedBit], + params: &[LoadedParameter], + payload: &OpaqueSubgraphPayload, + ) -> Result { + match payload { + OpaqueSubgraphPayload::External { id } => { + self.insert_external_subgraph(*id, qubits, bits, params) + } + OpaqueSubgraphPayload::Inline { + hugr_envelope, + inputs, + outputs, + } => self.insert_inline_subgraph(hugr_envelope, inputs, outputs, qubits, bits, params), + } + } + + /// Move the subgraph nodes referenced by an + /// [`OpaqueSubgraphPayload::External`] into the region being decoded. + fn insert_external_subgraph( + &mut self, + id: SubgraphId, + qubits: &[TrackedQubit], + bits: &[TrackedBit], + params: &[LoadedParameter], + ) -> Result { + fn mk_subgraph_error(id: SubgraphId, context: impl ToString) -> PytketDecodeError { + PytketDecodeErrorInner::InvalidExternalSubgraph { + id, + context: context.to_string(), + } + .wrap() + } + + let Some(subgraph) = self.opaque_subgraphs.and_then(|s| s.get(id)) else { + return Err(PytketDecodeErrorInner::OpaqueSubgraphNotFound { id }.wrap()); + }; + let signature = subgraph.signature(self.builder.hugr()); + + let old_parent = self + .builder + .hugr() + .get_parent(subgraph.nodes()[0]) + .ok_or_else(|| mk_subgraph_error(id, "Subgraph must contain dataflow nodes."))?; + let [old_input, old_output] = self + .builder + .hugr() + .get_io(old_parent) + .ok_or_else(|| mk_subgraph_error(id, "Subgraph must be in a dataflow region."))?; + let new_parent = self.builder.container_node(); + + // Re-parent the nodes in the subgraph. + for &node in subgraph.nodes() { + self.builder.hugr_mut().set_parent(node, new_parent); + } + + // Re-wire the input wires that should be connected to nodes in the new region. + let mut input_qubits = qubits; + let mut input_bits = bits; + let mut input_params = params; + for (ty, targets) in signature.input().iter().zip_eq(subgraph.incoming_ports()) { + let wire = match self.wire_tracker.find_typed_wire( + self.config(), + ty, + &mut input_qubits, + &mut input_bits, + &mut input_params, + None, + ) { + Ok(FoundWire::Register(wire_data)) => wire_data.wire(), + Ok(FoundWire::Parameter(param)) => param.wire(), + Ok(FoundWire::Unsupported { .. }) => { + unreachable!("`unsupported_wire` not passed to `find_typed_wire`."); + } + Err(PytketDecodeError { + inner: + PytketDecodeErrorInner::NoMatchingWire { .. } + | PytketDecodeErrorInner::NoMatchingParameter { .. }, + .. + }) => { + // Not a qubit or bit wire. + // If it was linked to the old circuit, we need to re-wire it. + // Otherwise we just leave it connected to the node it was linked to. + let Some((neigh, neigh_port)) = targets.first().and_then(|(tgt, port)| { + self.builder.hugr().single_linked_output(*tgt, *port) + }) else { + continue; + }; + if neigh != old_input { + continue; + } + Wire::new(neigh, neigh_port) + } + Err(e) => return Err(e), + }; + + for (tgt, port) in targets { + self.builder + .hugr_mut() + .connect(wire.node(), wire.source(), *tgt, *port); + } + } + + // Register the output wires that should be connected to nodes in the new region. + // + // Re-wire wires from the subgraph to the old region's outputs. + let mut output_qubits = qubits; + let mut output_bits = bits; + for (ty, (src, src_port)) in signature.output().iter().zip_eq(subgraph.outgoing_ports()) { + let wire = Wire::new(*src, *src_port); + match self.config().type_to_pytket(ty) { + Some(counts) => { + // Make sure to disconnect the old wire. + self.builder.hugr_mut().disconnect(*src, *src_port); + + let wire_qubits = split_off(&mut output_qubits, ..counts.qubits); + let wire_bits = split_off(&mut output_bits, ..counts.bits); + if wire_qubits.is_none() || wire_bits.is_none() { + return Err(make_unexpected_node_out_error( + self.config(), + signature.output().iter(), + qubits.len(), + bits.len(), + )); + } + self.wire_tracker.track_wire( + wire, + Arc::new(ty.clone()), + wire_qubits.unwrap().iter().cloned(), + wire_bits.unwrap().iter().cloned(), + )?; + } + None => { + // This is an unsupported wire. + // If it was connected to the old region output, rewire it. + // Otherwise leave it connected. + for (tgt, tgt_port) in self + .builder + .hugr() + .linked_inputs(*src, *src_port) + .collect_vec() + { + if tgt == old_output { + self.builder.hugr_mut().disconnect(tgt, tgt_port); + self.builder + .hugr_mut() + .connect(*src, *src_port, tgt, tgt_port); + } + } + } + } + } + + // Mark the used qubits and bits as outdated. + qubits.iter().for_each(|q| { + self.wire_tracker.mark_qubit_outdated(q.clone()); + }); + bits.iter().for_each(|b| { + self.wire_tracker.mark_bit_outdated(b.clone()); + }); + + Ok(DecodeStatus::Success) + } + + /// Insert an [`OpaqueSubgraphPayload::Inline`] into the Hugr being decoded. + fn insert_inline_subgraph( + &mut self, + hugr_envelope: &str, + payload_inputs: &[(Type, EncodedEdgeID)], + payload_outputs: &[(Type, EncodedEdgeID)], + qubits: &[TrackedQubit], + bits: &[TrackedBit], + params: &[LoadedParameter], + ) -> Result { + let to_insert_hugr = Hugr::load_str(hugr_envelope, Some(self.options.extension_registry())) + .map_err(|e| PytketDecodeErrorInner::UnsupportedSubgraphPayload { source: e })?; + let to_insert_signature = to_insert_hugr.inner_function_type().unwrap(); + + let module = self.builder.hugr().module_root(); + let region = self.builder.container_node(); + + // Collect the non-IO nodes in the hugr we plan to insert. + let Some([to_insert_input, to_insert_output]) = + to_insert_hugr.get_io(to_insert_hugr.entrypoint()) + else { + return Err(PytketDecodeError::custom( + "Opaque subgraph payload has a non-dataflow parent as entrypoint", + )); + }; + let entrypoint_children = to_insert_hugr + .children(to_insert_hugr.entrypoint()) + .filter(|c| *c != to_insert_input && *c != to_insert_output) + .map(|c| (c, region)); + + // Compute the inputs and output ports of the subgraph. + // Since we insert the nodes inside the region directly (without the I/O nodes), + // we need to do some bookkeeping to match the ports. + // + // `to_insert_inputs` is a vector of vectors, where each first-level entry corresponds to + // an input to the subgraph / element in `payload.inputs`, and the second-level + // list is all the target ports that connect to that input. + // + // `to_insert_outputs` is just a vector of node+outgoing ports. + let to_insert_inputs: hugr::hugr::views::sibling_subgraph::IncomingPorts = + to_insert_signature + .input_ports() + .map(|p| { + to_insert_hugr + .linked_inputs(to_insert_input, p.index()) + .collect_vec() + }) + .collect_vec(); + let to_insert_outputs: Vec<(Node, OutgoingPort)> = to_insert_signature + .output_ports() + .map(|p| { + to_insert_hugr + .single_linked_output(to_insert_output, p.index()) + .unwrap() + }) + .collect_vec(); + + // Collect any module child that does not contain the entrypoint function. + // + // These are global functions or constant definitions. + let entrypoint_function = + std::iter::successors(Some(to_insert_hugr.entrypoint()), |&node| { + let parent = to_insert_hugr.get_parent(node)?; + if parent == module { + None + } else { + Some(parent) + } + }) + .last() + .unwrap(); + let module_children = to_insert_hugr + .children(to_insert_hugr.module_root()) + .filter(|c| *c != entrypoint_function) + .map(|c| (c, module)); + + // Insert the hugr's entrypoint region directly into the region being built, + // and any other function in the HUGR module into the module being built. + let insertion_roots = entrypoint_children.chain(module_children).collect_vec(); + let insertion_result = self + .builder + .hugr_mut() + .insert_forest(to_insert_hugr, insertion_roots) + .unwrap_or_else(|e| panic!("Invalid `insertion_roots`. {e}")); + + // Gather and connect the wires between the previously decoded nodes and the + // inserted subgraph inputs, using the types and edge IDs from the payload. + let mut input_qubits = qubits; + let mut input_bits = bits; + let mut input_params = params; + // A list of incoming ports corresponding to [`EncodedEdgeID`]s that must be + // connected once the outgoing port is created. + // + // This handles the case where unsupported subgraphs in opaque barriers on + // the pytket circuit get reordered and input ports are seen before their + // outputs. + let mut pending_encoded_edge_connections: HashMap< + EncodedEdgeID, + Vec<(Node, IncomingPort)>, + > = HashMap::new(); + for ((ty, edge_id), targets) in payload_inputs.iter().zip_eq(to_insert_inputs) { + let found_wire = self.wire_tracker.find_typed_wire( + self.config(), + ty, + &mut input_qubits, + &mut input_bits, + &mut input_params, + Some(*edge_id), + )?; + + let wire = match found_wire { + FoundWire::Register(wire_data) => wire_data.wire(), + FoundWire::Parameter(param) => param.wire(), + FoundWire::Unsupported { id } => { + let Some(wire) = self.wire_tracker.get_unsupported_wire(id) else { + // The corresponding outgoing port has not been created yet, so we + // register the edge id and the targets to be connected later. + pending_encoded_edge_connections + .entry(id) + .or_default() + .extend(targets); + // TODO: We have to store this list somewhere so we + // connect ports from subgraph that get added later. + continue; + }; + *wire + } + }; + + for (to_insert_node, port) in targets { + let node = *insertion_result.node_map.get(&to_insert_node).unwrap(); + self.builder + .hugr_mut() + .connect(wire.node(), wire.source(), node, port); + } + } + + // Register the subgraph outputs in the wire tracker. + let mut output_qubits = qubits; + let mut output_bits = bits; + for ((ty, edge_id), (to_insert_node, port)) in + payload_outputs.iter().zip_eq(to_insert_outputs) + { + let node = *insertion_result.node_map.get(&to_insert_node).unwrap(); + let wire = Wire::new(node, port); + match self.config().type_to_pytket(ty) { + Some(counts) => { + let wire_qubits = split_off(&mut output_qubits, ..counts.qubits); + let wire_bits = split_off(&mut output_bits, ..counts.bits); + if wire_qubits.is_none() || wire_bits.is_none() { + return Err(make_unexpected_node_out_error( + self.config(), + payload_outputs.iter().map(|(ty, _)| ty), + qubits.len(), + bits.len(), + )); + } + self.wire_tracker.track_wire( + wire, + Arc::new(ty.clone()), + wire_qubits.unwrap().iter().cloned(), + wire_bits.unwrap().iter().cloned(), + )?; + } + None => { + // This is an unsupported wire, so we just register the edge id to the wire. + self.wire_tracker.register_unsupported_wire(*edge_id, wire); + } + } + } + + // Mark the used qubits and bits as outdated. + qubits.iter().for_each(|q| { + self.wire_tracker.mark_qubit_outdated(q.clone()); + }); + bits.iter().for_each(|b| { + self.wire_tracker.mark_bit_outdated(b.clone()); + }); + + Ok(DecodeStatus::Success) + } +} + +// TODO: Replace with array's `split_off` method once MSRV is ≥1.87 +fn split_off<'a, T>(slice: &mut &'a [T], range: RangeTo) -> Option<&'a [T]> { + let split_index = range.end; + if split_index > slice.len() { + return None; + } + let (front, back) = slice.split_at(split_index); + *slice = back; + Some(front) +} + +/// Helper to compute the expected register counts before generating a +/// [`PytketDecodeErrorInner::UnexpectedNodeOutput`] error when registering the +/// outputs of an unsupported subgraph. +/// +/// Processes all the output types to compute the number of qubits and bits we +/// required to have available. +fn make_unexpected_node_out_error<'ty>( + config: &PytketDecoderConfig, + output_types: impl IntoIterator, + available_qubits: usize, + available_bits: usize, +) -> PytketDecodeError { + let mut expected_count = RegisterCount::default(); + for ty in output_types { + expected_count += config.type_to_pytket(ty).unwrap_or_default(); + } + PytketDecodeErrorInner::UnexpectedNodeOutput { + expected_qubits: expected_count.qubits, + expected_bits: expected_count.bits, + circ_qubits: available_qubits, + circ_bits: available_bits, + } + .wrap() +} diff --git a/tket/src/serialize/pytket/decoder/wires.rs b/tket/src/serialize/pytket/decoder/wires.rs index a08afb287..440ff8d4f 100644 --- a/tket/src/serialize/pytket/decoder/wires.rs +++ b/tket/src/serialize/pytket/decoder/wires.rs @@ -20,6 +20,7 @@ use crate::serialize::pytket::decoder::{ TrackedQubitId, }; use crate::serialize::pytket::extension::RegisterCount; +use crate::serialize::pytket::opaque::EncodedEdgeID; use crate::serialize::pytket::{ PytketDecodeError, PytketDecodeErrorInner, PytketDecoderConfig, RegisterHash, }; @@ -350,6 +351,10 @@ pub(crate) struct WireTracker { /// Registers outside the range of the array are not affected, and will /// appear in the same order as they were added to `latest_qubit_tracker`. output_qubit_permutation: Vec, + /// Wires with unsupported types, created from the input node or from decoded opaque barriers. + /// + /// See [`EncodedEdgeID`]. + unsupported_wires: IndexMap, } impl WireTracker { @@ -367,6 +372,7 @@ impl WireTracker { unused_parameter_inputs: VecDeque::new(), parameter_vars: IndexSet::new(), output_qubit_permutation: Vec::with_capacity(qubit_count), + unsupported_wires: IndexMap::new(), } } @@ -561,6 +567,151 @@ impl WireTracker { Ok((qubit_args, bit_args)) } + /// Returns a new set of [TrackedWires] for a list of [`TrackedQubit`]s, + /// [`TrackedBit`]s, and [`LoadedParameter`]s following the required types. + /// + /// Returns an error if a valid set of wires with the given types cannot be + /// found. + /// + /// The qubit and bit arguments are only consumed as required by the types. + /// Some registers may be left unused. + /// + /// # Arguments + /// + /// * `config` - The configuration for the decoder, used to count the qubits + /// and bits required by each type. + /// * `ty` - The type of the arguments we require in the wire. + /// * `qubit_args` - The list of tracked qubits we require in the wire. + /// Values are consumed from the front and removed from the slice. + /// * `bit_args` - The list of tracked bits we require in the wire. + /// * `params` - The list of parameters to load to wire. See + /// [`WireTracker::load_half_turns_parameter`] for more details. Values + /// are consumed from the front and removed from the slice. + /// * `unsupported_wire` - The id of an unsupported wire, if known. + /// + /// # Errors + /// + /// - [`PytketDecodeErrorInner::OutdatedQubit`] if a qubit in `qubit_args` + /// was marked as outdated. + /// - [`PytketDecodeErrorInner::OutdatedBit`] if a bit in `bit_args` was + /// marked as outdated. + /// - [`PytketDecodeErrorInner::UnexpectedInputType`] if a type in `types` + /// cannot be mapped to a [`RegisterCount`] and `unsupported_wire` was not + /// provided. + /// - [`PytketDecodeErrorInner::NoMatchingWire`] if there is no wire with + /// the requested type for the given qubit/bit arguments. + pub(in crate::serialize::pytket) fn find_typed_wire( + &self, + config: &PytketDecoderConfig, + ty: &Type, + qubit_args: &mut &[TrackedQubit], + bit_args: &mut &[TrackedBit], + params: &mut &[LoadedParameter], + unsupported_wire: Option, + ) -> Result { + // TODO: Use the slice `split_off_first` method once MSRV is ≥1.87 + fn split_off_first<'a, T>(slice: &mut &'a [T]) -> Option<&'a T> { + let (first, rem) = slice.split_first()?; + *slice = rem; + Some(first) + } + + // Return a parameter input if the type is a float or rotation. + if [float64_type(), rotation_type()].contains(ty) { + let Some(param) = split_off_first(params) else { + return Err( + PytketDecodeErrorInner::NoMatchingParameter { ty: ty.to_string() }.wrap(), + ); + }; + return Ok(FoundWire::Parameter(*param)); + } + + // Translate the wire type to a pytket register count. + let reg_count = match (config.type_to_pytket(ty), unsupported_wire) { + (Some(reg_count), _) => reg_count, + (None, Some(id)) => return Ok(FoundWire::Unsupported { id }), + (None, None) => { + let err = PytketDecodeErrorInner::UnexpectedInputType { + unknown_type: ty.to_string(), + all_types: vec![ty.to_string()], + }; + return Err(err.wrap()); + } + }; + + // List candidate wires that contain the qubits and bits we need. + let qubit_candidates = qubit_args + .first() + .into_iter() + .flat_map(|qb| self.qubit_wires(qb)); + let bit_candidates = bit_args + .first() + .into_iter() + .flat_map(|bit| self.bit_wires(bit)); + let mut candidate = qubit_candidates.chain(bit_candidates); + + // The bits and qubits we expect the wire to contain. + let wire_qubits = qubit_args + .iter() + .take(reg_count.qubits) + .map(|q| q.id()) + .collect_vec(); + let wire_bits = bit_args + .iter() + .take(reg_count.bits) + .map(|bit| bit.id()) + .collect_vec(); + + // Find a wire that contains the correct type.. + let check_wire = |w: &Wire| { + let wire_data = &self.wires[w]; + wire_data.ty() == ty && wire_data.qubits == wire_qubits && wire_data.bits == wire_bits + }; + let Some(wire) = candidate.find(check_wire) else { + return Err(PytketDecodeErrorInner::NoMatchingWire { + ty: ty.to_string(), + qubit_args: qubit_args + .iter() + .map(|q| q.pytket_register().to_string()) + .collect(), + bit_args: bit_args + .iter() + .map(|bit| bit.pytket_register().to_string()) + .collect(), + } + .wrap()); + }; + + // Check that none of the selected qubit or bit has been marked as outdated. + if let Some(qubit) = qubit_args + .iter() + .take(reg_count.qubits) + .find(|q| q.is_outdated()) + { + return Err(PytketDecodeErrorInner::OutdatedQubit { + qubit: qubit.pytket_register().to_string(), + } + .wrap()); + } + if let Some(bit) = bit_args + .iter() + .take(reg_count.bits) + .find(|b| b.is_outdated()) + { + return Err(PytketDecodeErrorInner::OutdatedBit { + bit: bit.pytket_register().to_string(), + } + .wrap()); + } + + // Mark the qubits and bits as used. + // TODO: We can use the slice `split_off` method once MSRV is ≥1.87 + *qubit_args = &qubit_args[reg_count.qubits..]; + *bit_args = &bit_args[reg_count.bits..]; + + Ok(FoundWire::Register(self.wires[&wire].clone())) + } + /// Returns a new set of [TrackedWires] for a list of [`TrackedQubit`]s, /// [`TrackedBit`]s, and [`LoadedParameter`]s following the required types. /// @@ -573,12 +724,11 @@ impl WireTracker { /// # Arguments /// /// * `config` - The configuration for the decoder, used to count the qubits and bits required by each type. - /// * `hugr` - The hugr to load the parameters to. /// * `types` - The types of the arguments we require in the wires. /// * `qubit_args` - The list of tracked qubits we require in the wires. /// * `bit_args` - The list of tracked bits we require in the wire. /// * `params` - The list of parameters to load to wires. See - /// [`WireTracker::load_parameter`] for more details. + /// [`WireTracker::load_half_turns_parameter`] for more details. /// /// # Errors /// @@ -590,94 +740,52 @@ impl WireTracker { &self, config: &PytketDecoderConfig, types: &[Type], - qubit_args: &[TrackedQubit], - bit_args: &[TrackedBit], - params: &[LoadedParameter], + mut qubit_args: &[TrackedQubit], + mut bit_args: &[TrackedBit], + mut params: &[LoadedParameter], ) -> Result { - // We need to return a set of wires that contain all the arguments. - // - // We collect this by checking the wires where each element is present, - // and collecting them in order. - let mut qubit_args: VecDeque<&TrackedQubit> = qubit_args.iter().collect(); - let mut bit_args: VecDeque<&TrackedBit> = bit_args.iter().collect(); - - // Check that no qubit or bit has been marked as outdated. - if qubit_args.iter().any(|q| q.is_outdated()) { - return Err(PytketDecodeErrorInner::OutdatedQubit { - qubit: qubit_args.front().unwrap().pytket_register().to_string(), - } - .wrap()); - } - if bit_args.iter().any(|b| b.is_outdated()) { - return Err(PytketDecodeErrorInner::OutdatedBit { - bit: bit_args.front().unwrap().pytket_register().to_string(), - } - .wrap()); - } - // Map each requested type to a wire. // // Ignore parameter inputs. - let param_types = [float64_type(), rotation_type()]; - let value_wires = types - .iter() - .filter(|ty| !param_types.contains(ty)) - .map(|ty| { - let Some(reg_count) = config.type_to_pytket(ty) else { - return Err(PytketDecodeErrorInner::UnexpectedInputType { - unknown_type: ty.to_string(), + let mut tracked_wires = TrackedWires { + value_wires: Vec::with_capacity(types.len() - params.len()), + parameter_wires: Vec::with_capacity(params.len()), + }; + for ty in types { + match self.find_typed_wire( + config, + ty, + &mut qubit_args, + &mut bit_args, + &mut params, + None, + ) { + Ok(FoundWire::Register(wire)) => tracked_wires.value_wires.push(wire), + Ok(FoundWire::Parameter(param)) => tracked_wires.parameter_wires.push(param), + Ok(FoundWire::Unsupported { .. }) => { + unreachable!("unsupported_wire was not defined") + } + // Add additional context to errors UnexpectedInputType errors. + Err(PytketDecodeError { + inner: PytketDecodeErrorInner::UnexpectedInputType { unknown_type, .. }, + pytket_op, + hugr_op, + }) => { + let inner = PytketDecodeErrorInner::UnexpectedInputType { + unknown_type, all_types: types.iter().map(ToString::to_string).collect(), - } - .wrap()); - }; - - // List candidate wires that contain the qubits and bits we need. - let qubit_candidates = qubit_args - .front() - .into_iter() - .flat_map(|qb| self.qubit_wires(qb)); - let bit_candidates = bit_args - .front() - .into_iter() - .flat_map(|bit| self.bit_wires(bit)); - let mut candidate = qubit_candidates.chain(bit_candidates); - - // Find a wire that contains the correct type.. - let check_wire = |w: &Wire| { - let wire_data = &self.wires[w]; - let qubits = qubit_args.iter().take(reg_count.qubits).map(|q| q.id()); - let bits = bit_args.iter().take(reg_count.bits).map(|bit| bit.id()); - wire_data.ty() == ty - && itertools::equal(wire_data.qubits.iter().copied(), qubits) - && itertools::equal(wire_data.bits.iter().copied(), bits) - }; - let Some(wire) = candidate.find(check_wire) else { - return Err(PytketDecodeErrorInner::NoMatchingWire { - ty: ty.to_string(), - qubit_args: qubit_args - .iter() - .map(|q| q.pytket_register().to_string()) - .collect(), - bit_args: bit_args - .iter() - .map(|bit| bit.pytket_register().to_string()) - .collect(), - } - .wrap()); - }; - - // Mark the qubits and bits as used. - qubit_args.drain(..reg_count.qubits); - bit_args.drain(..reg_count.bits); - - Ok(self.wires[&wire].clone()) - }) - .collect::, _>>()?; + }; + return Err(PytketDecodeError { + inner, + pytket_op, + hugr_op, + }); + } + Err(e) => return Err(e), + }; + } - Ok(TrackedWires { - value_wires, - parameter_wires: params.to_vec(), - }) + Ok(tracked_wires) } /// Loads the given parameter half-turns expression as a [`LoadedParameter`] @@ -875,6 +983,25 @@ impl WireTracker { pub(super) fn register_unused_parameter_input(&mut self, loaded: LoadedParameter) { self.unused_parameter_inputs.push_back(loaded); } + + /// Returns a tracked unsupported wire by its [`EncodedEdgeID`]. + /// + /// These are **not** associated with pytket registers or parameters, and + /// are used to track connections between unsupported subgraphs and HUGR + /// input/output nodes that got encoded as opaque barriers in the pytket + /// circuit. + /// + /// See [`EncodedEdgeID`]. + pub fn get_unsupported_wire(&self, id: EncodedEdgeID) -> Option<&Wire> { + self.unsupported_wires.get(&id) + } + + /// Register a new unsupported wire. + /// + /// See [`WireTracker::get_unsupported_wire`]. + pub fn register_unsupported_wire(&mut self, id: EncodedEdgeID, wire: Wire) { + self.unsupported_wires.insert(id, wire); + } } /// Only single-indexed registers are supported. @@ -889,6 +1016,31 @@ fn check_register(register: &PytketRegister) -> Result<(), PytketDecodeError> { } } +/// Result type of [`WireTracker::find_typed_wire`]. +/// +/// Returns either a value to append to a [`TrackedWires`] instance, or a wire +/// for an edge in an unsupported subgraph. +/// +/// The latter is only used internally when decoding unsupported subgraphs from +/// opaque pytket barriers. Users will see +/// [`PytketDecodeErrorInner::UnexpectedInputType`] if they try to decode such a +/// wire. +#[derive(Debug, Clone, PartialEq)] +pub(in crate::serialize::pytket) enum FoundWire { + /// Found a type carrying bit/qubit registers. + Register(WireData), + /// Found a parameter input. + Parameter(LoadedParameter), + /// Found an unsupported wire, registered to an existing wire. + /// + /// This variant is only used when decoding unsupported subgraphs from + /// opaque pytket barriers. + Unsupported { + /// The id of the unsupported wire. + id: EncodedEdgeID, + }, +} + #[cfg(test)] mod tests { use super::*; diff --git a/tket/src/serialize/pytket/encoder.rs b/tket/src/serialize/pytket/encoder.rs index e81c08c25..909eefd1f 100644 --- a/tket/src/serialize/pytket/encoder.rs +++ b/tket/src/serialize/pytket/encoder.rs @@ -5,6 +5,7 @@ mod unsupported_tracker; mod value_tracker; use hugr::core::HugrNode; +use hugr::hugr::views::sibling_subgraph::{IncomingPorts, OutgoingPorts}; use hugr::hugr::views::SiblingSubgraph; use hugr_core::hugr::internal::PortgraphNodeMap; use tket_json_rs::clexpr::InputClRegister; @@ -17,24 +18,23 @@ use hugr::ops::{OpTrait, OpType}; use hugr::types::EdgeKind; use std::borrow::Cow; -use std::collections::{BTreeSet, HashMap, HashSet}; +use std::collections::{BTreeSet, HashMap}; use std::sync::{Arc, RwLock}; -use hugr::{HugrView, Wire}; +use hugr::{HugrView, OutgoingPort, Wire}; use itertools::Itertools; use tket_json_rs::circuit_json::{self, SerialCircuit}; use unsupported_tracker::UnsupportedTracker; use super::opaque::OpaqueSubgraphs; use super::{ - OpConvertError, PytketEncodeError, METADATA_OPGROUP, METADATA_PHASE, METADATA_Q_REGISTERS, + PytketEncodeError, PytketEncodeOpError, METADATA_OPGROUP, METADATA_PHASE, METADATA_Q_REGISTERS, }; use crate::circuit::Circuit; +use crate::serialize::pytket::circuit::EncodedCircuitInfo; use crate::serialize::pytket::config::PytketEncoderConfig; use crate::serialize::pytket::extension::RegisterCount; -use crate::serialize::pytket::opaque::{ - OpaqueSubgraphPayload, OpaqueSubgraphPayloadType, OPGROUP_OPAQUE_HUGR, -}; +use crate::serialize::pytket::opaque::{OpaqueSubgraphPayload, OPGROUP_OPAQUE_HUGR}; /// The state of an in-progress [`SerialCircuit`] being built from a [`Circuit`]. #[derive(derive_more::Debug)] @@ -216,25 +216,40 @@ impl PytketEncoderContext { /// /// # Returns /// - /// * the final [`SerialCircuit`] - /// * any parameter expressions at the circuit's output - /// * the set of unsupported subgraphs that were referenced (from/inside) pytket barriers. + /// * An [`EncodedCircuitInfo`] containing the final [`SerialCircuit`] and some additional metadata. + /// * The set of opaque subgraphs that were referenced (from/inside) pytket barriers. #[allow(clippy::type_complexity)] pub(super) fn finish( mut self, circ: &Circuit, region: H::Node, - ) -> Result<(SerialCircuit, Vec, OpaqueSubgraphs), PytketEncodeError> - { + ) -> Result<(EncodedCircuitInfo, OpaqueSubgraphs), PytketEncodeError> { // Add any remaining unsupported nodes // - // TODO: Test that unsupported subgraphs that don't affect any qubit/bit registers + // TODO: Test that opaque subgraphs that don't affect any qubit/bit registers // are correctly encoded in pytket commands. + let mut extra_subgraph: Option> = None; while !self.unsupported.is_empty() { let node = self.unsupported.iter().next().unwrap(); let opaque_subgraphs = self.unsupported.extract_component(node); - self.emit_unsupported(opaque_subgraphs, circ)?; + match self.emit_unsupported(opaque_subgraphs.clone(), circ) { + Ok(()) => (), + Err(PytketEncodeError::UnsupportedSubgraphHasNoRegisters {}) => { + // We'll store the nodes in the `extra_subgraph` field of the `EncodedCircuitInfo`. + // So the decoder can reconstruct the original subgraph. + extra_subgraph + .get_or_insert_default() + .extend(opaque_subgraphs); + } + Err(e) => return Err(e), + } } + let extra_subgraph = extra_subgraph.map(|nodes| { + let subgraph = + SiblingSubgraph::try_from_nodes(nodes.into_iter().collect_vec(), circ.hugr()) + .expect("Failed to create subgraph from unsupported nodes"); + self.opaque_subgraphs.register_opaque_subgraph(subgraph) + }); let final_values = self.values.finish(circ, region)?; @@ -245,7 +260,15 @@ impl PytketEncoderContext { ser.bits = final_values.bits.into_iter().map_into().collect(); ser.implicit_permutation = final_values.qubit_permutation; ser.number_of_ws = None; - Ok((ser, final_values.params, self.opaque_subgraphs)) + + let info = EncodedCircuitInfo { + serial_circuit: ser, + input_params: final_values.params, + output_params: vec![], + extra_subgraph, + }; + + Ok((info, self.opaque_subgraphs)) } /// Returns a reference to this encoder's configuration. @@ -269,8 +292,8 @@ impl PytketEncoderContext { /// /// ### Errors /// - /// - [`OpConvertError::WireHasNoValues`] if the wire is not tracked or has - /// a type that cannot be converted to pytket values. + /// - [`PytketEncodeOpError::WireHasNoValues`] if the wire is not tracked or + /// has a type that cannot be converted to pytket values. pub fn get_wire_values( &mut self, wire: Wire, @@ -291,7 +314,7 @@ impl PytketEncoderContext { return self.get_wire_values(wire, circ); } - Err(OpConvertError::WireHasNoValues { wire }.into()) + Err(PytketEncodeOpError::WireHasNoValues { wire }.into()) } /// Given a node in the HUGR, returns all the [`TrackedValue`]s associated @@ -356,7 +379,7 @@ impl PytketEncoderContext { match self.get_wire_values(wire, circ) { Ok(values) => tracked_values.extend(values.iter().copied()), - Err(PytketEncodeError::OpConversionError(OpConvertError::WireHasNoValues { + Err(PytketEncodeError::OpEncoding(PytketEncodeOpError::WireHasNoValues { wire, })) => unknown_values.push(wire), Err(e) => panic!( @@ -554,7 +577,8 @@ impl PytketEncoderContext { /// /// ## Arguments /// - /// - `unsupported_nodes`: The list of nodes to encode as an unsupported subgraph. + /// - `unsupported_nodes`: The list of nodes to encode as an opaque subgraph. + /// - `circ`: The circuit containing the unsupported nodes. fn emit_unsupported( &mut self, unsupported_nodes: BTreeSet, @@ -569,46 +593,43 @@ impl PytketEncoderContext { unsupported_nodes.iter().cloned().collect_vec(), circ.hugr(), ) - .unwrap_or_else(|_| { + .unwrap_or_else(|e| { panic!( - "Failed to create subgraph from unsupported nodes [{}]", + "Failed to create subgraph from unsupported nodes [{}]: {e}", unsupported_nodes.iter().join(", ") ) }); - let input_nodes: HashSet<_> = subgraph - .incoming_ports() - .iter() - .flat_map(|inp| inp.iter().map(|(n, _)| *n)) - .collect(); - let output_nodes: HashSet<_> = subgraph.outgoing_ports().iter().map(|(n, _)| *n).collect(); + let subgraph_incoming_ports: IncomingPorts = subgraph.incoming_ports().clone(); + let subgraph_outgoing_ports: OutgoingPorts = subgraph.outgoing_ports().clone(); // Encode a payload referencing the subgraph in the Hugr. let subgraph_id = self.opaque_subgraphs.register_opaque_subgraph(subgraph); - let subgraph = &self.opaque_subgraphs[subgraph_id]; - let payload = OpaqueSubgraphPayload::new( - subgraph, - circ.hugr(), - OpaqueSubgraphPayloadType::External { id: subgraph_id }, - ); + let payload = OpaqueSubgraphPayload::new_external(subgraph_id); // Collects the input values for the subgraph. // // The [`UnsupportedTracker`] ensures that at this point all local input wires must come from // already-encoded nodes, and not from other unsupported nodes not in `unsupported_nodes`. - // - // Non-local incoming edges (e.g. function references) must be marked so they can be recovered when - // decoding the circuit. let mut op_values = TrackedValues::default(); - let mut external_edges = Vec::new(); - for node in &input_nodes { - let NodeInputValues { - tracked_values, - unknown_values, - } = self - .get_input_values_internal(*node, circ, |w| !unsupported_nodes.contains(&w.node())); - op_values.append(tracked_values); - external_edges.extend(unknown_values); + for incoming in subgraph_incoming_ports.iter() { + let Some((first_node, first_port)) = incoming.first() else { + continue; + }; + + let (neigh, neigh_out) = circ + .hugr() + .single_linked_output(*first_node, *first_port) + .expect("Dataflow input port should have a single neighbour"); + let wire = Wire::new(neigh, neigh_out); + + let Ok(tracked_values) = self.get_wire_values(wire, circ) else { + // If the wire is not tracked, no need to consume it. + continue; + }; + + op_values.extend(tracked_values.iter().cloned()); } + let input_param_exprs: Vec = std::mem::take(&mut op_values.params) .into_iter() .map(|p| self.values.param_expression(p).to_owned()) @@ -619,23 +640,33 @@ impl PytketEncoderContext { // // Output parameters are mapped to a fresh variable, that can be tracked // back to the encoded subcircuit's function name. - for &node in &output_nodes { - let new_outputs = self.register_node_outputs( - node, + let mut out_param_count = 0; + for (out_node, out_port) in &subgraph_outgoing_ports { + let new_outputs = self.register_port_output( + *out_node, + *out_port, circ, &op_values.qubits, &op_values.bits, &input_param_exprs, EmitCommandOptions::new().output_params(|p| { - (0..p.expected_count) + let range = out_param_count..out_param_count + p.expected_count; + out_param_count += p.expected_count; + range .map(|i| format!("{subcircuit_id}_out{i}")) .collect_vec() }), - |_| true, )?; op_values.append(new_outputs); } + // Check that we have qubits or bits to attach the barrier command to. + // + // This should only fail when looking at the "leftover" unsupported nodes at the end of the decoding process. + if op_values.qubits.is_empty() && op_values.bits.is_empty() { + return Err(PytketEncodeError::UnsupportedSubgraphHasNoRegisters {}); + } + // Create pytket operation, and add the subcircuit as hugr let args = MakeOperationArgs { num_qubits: op_values.qubits.len(), @@ -716,13 +747,13 @@ impl PytketEncoderContext { subencoder.function_cache = self.function_cache.clone(); subencoder.run_encoder(circ, node)?; - let (serial_subcirc, output_params, opaque_subgraphs) = subencoder.finish(circ, node)?; - if !output_params.is_empty() { + let (info, opaque_subgraphs) = subencoder.finish(circ, node)?; + if !info.output_params.is_empty() { return Ok(EncodeStatus::Unsupported); } self.opaque_subgraphs = opaque_subgraphs; - self.emit_circ_box(node, serial_subcirc, circ)?; + self.emit_circ_box(node, info.serial_circuit, circ)?; Ok(EncodeStatus::Success) } @@ -765,15 +796,14 @@ impl PytketEncoderContext { let mut subencoder = PytketEncoderContext::new(circ, function, opaque_subgraphs, config)?; subencoder.function_cache = self.function_cache.clone(); subencoder.run_encoder(circ, function)?; - let (serial_subcirc, output_params, opaque_subgraphs) = - subencoder.finish(circ, function)?; + let (info, opaque_subgraphs) = subencoder.finish(circ, function)?; self.opaque_subgraphs = opaque_subgraphs; - let (result, cached_fn) = match output_params.is_empty() { + let (result, cached_fn) = match info.output_params.is_empty() { true => ( EncodeStatus::Success, CachedEncodedFunction::Encoded { - serial_circuit: serial_subcirc.clone(), + serial_circuit: info.serial_circuit.clone(), }, ), false => ( @@ -789,7 +819,7 @@ impl PytketEncoderContext { } if result == EncodeStatus::Success { - self.emit_circ_box(node, serial_subcirc, circ)?; + self.emit_circ_box(node, info.serial_circuit, circ)?; } Ok(result) } @@ -849,9 +879,20 @@ impl PytketEncoderContext { return Ok(EncodeStatus::Success); } } - OpType::LoadConstant(_) => { - self.emit_transparent_node(node, circ, |ps| ps.input_params.to_owned())?; - return Ok(EncodeStatus::Success); + OpType::LoadConstant(constant) => { + // If we are loading a supported type, emit a transparent node + // by reassigning the input values to the new outputs. + // + // Otherwise, if we're loading an unsupported type, this node + // should be part of an unsupported subgraph. + if self + .config() + .type_to_pytket(constant.constant_type()) + .is_some() + { + self.emit_transparent_node(node, circ, |ps| ps.input_params.to_owned())?; + return Ok(EncodeStatus::Success); + } } OpType::Const(op) => { let config = Arc::clone(&self.config); @@ -1050,6 +1091,124 @@ impl PytketEncoderContext { Ok(new_outputs) } + /// Helper to register values for a singular output wire. + /// + /// In general, you should prefer + /// [`PytketEncoderContext::register_node_outputs`] to register values for a + /// node's multiple output wires at once. + /// + /// Returns any new value associated with the output wire. + /// + /// ## Arguments + /// + /// - `node`: The node to register the outputs for. + /// - `circ`: The circuit containing the node. + /// - `input_qubits`: The qubit inputs to the operation. + /// - `input_bits`: The bit inputs to the operation. + /// - `input_params`: The list of input parameter expressions. + /// - `options`: Options for controlling the output qubit, bits, and + /// parameter expressions. + #[allow(clippy::too_many_arguments)] + fn register_port_output( + &mut self, + node: H::Node, + port: OutgoingPort, + circ: &Circuit, + input_qubits: &[TrackedQubit], + input_bits: &[TrackedBit], + input_params: &[String], + options: EmitCommandOptions, + ) -> Result> { + let wire = Wire::new(node, port); + + let Some(ty) = circ + .hugr() + .signature(node) + .and_then(|s| s.out_port_type(port).cloned()) + else { + return Ok(TrackedValues::default()); + }; + + let Some(count) = self.config().type_to_pytket(&ty) else { + return Err(PytketEncodeError::custom(format!( + "Found an unsupported type {ty} while encoding {port} of {node}." + ))); + }; + + let output_qubits = match options.reuse_qubits_fn { + Some(f) => f(input_qubits), + None => input_qubits.to_vec(), + }; + let output_bits = match options.reuse_bits_fn { + Some(f) => f(input_bits), + None => input_bits.to_vec(), + }; + + // Compute all the output parameters at once + let out_params = match options.output_params_fn { + Some(f) => f(OutputParamArgs { + expected_count: count.params, + input_params, + }), + None => Vec::new(), + }; + + // Check that we got the expected number of outputs. + if out_params.len() != count.params { + return Err(PytketEncodeError::custom(format!( + "Expected {} parameters in the input values for a {} at {port} of {node}, but got {}.", + count.params, + circ.hugr().get_optype(node), + out_params.len() + ))); + } + + // Update the values in the node's outputs. + // + // We preserve the order of linear values in the input + let mut new_outputs = TrackedValues::default(); + let mut out_wire_values = Vec::with_capacity(count.total()); + + // Qubits + out_wire_values.extend( + output_qubits + .into_iter() + .take(count.qubits) + .map(TrackedValue::Qubit), + ); + for _ in out_wire_values.len()..count.qubits { + // If we already assigned all input qubit ids, get a fresh one. + let qb = self.values.new_qubit(); + new_outputs.qubits.push(qb); + out_wire_values.push(TrackedValue::Qubit(qb)); + } + + // Bits + let non_bit_count = out_wire_values.len(); + out_wire_values.extend( + output_bits + .into_iter() + .take(count.bits) + .map(TrackedValue::Bit), + ); + let reused_bit_count = out_wire_values.len() - non_bit_count; + for _ in reused_bit_count..count.bits { + let b = self.values.new_bit(); + new_outputs.bits.push(b); + out_wire_values.push(TrackedValue::Bit(b)); + } + + // Parameters + for expr in out_params.into_iter().take(count.params) { + let p = self.values.new_param(expr); + new_outputs.params.push(p); + out_wire_values.push(p.into()); + } + self.values.register_wire(wire, out_wire_values, circ)?; + + Ok(new_outputs) + } + /// Return the output wires of a node that have an associated pytket [`RegisterCount`]. #[allow(clippy::type_complexity)] fn node_output_values( @@ -1162,7 +1321,7 @@ impl NodeInputValues { pub fn try_into_tracked_values(self) -> Result> { match self.unknown_values.is_empty() { true => Ok(self.tracked_values), - false => Err(OpConvertError::WireHasNoValues { + false => Err(PytketEncodeOpError::WireHasNoValues { wire: self.unknown_values[0], } .into()), diff --git a/tket/src/serialize/pytket/encoder/value_tracker.rs b/tket/src/serialize/pytket/encoder/value_tracker.rs index 3e7edfb35..67eee6393 100644 --- a/tket/src/serialize/pytket/encoder/value_tracker.rs +++ b/tket/src/serialize/pytket/encoder/value_tracker.rs @@ -22,7 +22,7 @@ use tket_json_rs::register::ElementId as RegisterUnit; use crate::circuit::Circuit; use crate::serialize::pytket::extension::RegisterCount; use crate::serialize::pytket::{ - OpConvertError, PytketEncodeError, RegisterHash, METADATA_B_REGISTERS, + PytketEncodeError, PytketEncodeOpError, RegisterHash, METADATA_B_REGISTERS, METADATA_INPUT_PARAMETERS, }; @@ -315,7 +315,7 @@ impl ValueTracker { wire: Wire, values: impl IntoIterator, circ: &Circuit>, - ) -> Result<(), OpConvertError> { + ) -> Result<(), PytketEncodeOpError> { let values = values.into_iter().map(|v| v.into()).collect_vec(); // Remove any qubit/bit used here from the unused set. @@ -337,7 +337,7 @@ impl ValueTracker { unexplored_neighbours, }; if self.wires.insert(wire, tracked).is_some() { - return Err(OpConvertError::WireAlreadyHasValues { wire }); + return Err(PytketEncodeOpError::WireAlreadyHasValues { wire }); } if unexplored_neighbours == 0 { @@ -429,7 +429,7 @@ impl ValueTracker { self, circ: &Circuit>, region: N, - ) -> Result> { + ) -> Result> { let output_node = circ.hugr().get_io(region).unwrap()[1]; // Ordered list of qubits and bits at the output of the circuit. @@ -440,7 +440,7 @@ impl ValueTracker { let wire = Wire::new(node, port); let values = self .peek_wire_values(wire) - .ok_or_else(|| OpConvertError::WireHasNoValues { wire })?; + .ok_or_else(|| PytketEncodeOpError::WireHasNoValues { wire })?; for value in values { match value { TrackedValue::Qubit(qb) => qubit_outputs.push(self.qubit_register(*qb).clone()), diff --git a/tket/src/serialize/pytket/error.rs b/tket/src/serialize/pytket/error.rs index 2739a80ee..e1c40c52c 100644 --- a/tket/src/serialize/pytket/error.rs +++ b/tket/src/serialize/pytket/error.rs @@ -2,18 +2,20 @@ use derive_more::{Display, Error, From}; use hugr::core::HugrNode; +use hugr::envelope::EnvelopeError; use hugr::ops::OpType; use hugr::Wire; use itertools::Itertools; use tket_json_rs::register::ElementId; use crate::serialize::pytket::extension::RegisterCount; +use crate::serialize::pytket::opaque::SubgraphId; /// Error type for conversion between pytket operations and tket ops. #[derive(Display, derive_more::Debug, Error)] #[non_exhaustive] #[debug(bounds(N: HugrNode))] -pub enum OpConvertError { +pub enum PytketEncodeOpError { /// Tried to decode a tket1 operation with not enough parameters. #[display( "Operation {} is missing encoded parameters. Expected at least {expected} but only \"{}\" were specified.", @@ -46,7 +48,7 @@ pub enum OpConvertError { /// Tried to query the values associated with an unexplored wire. /// /// This reflects a bug in the operation encoding logic of an operation. - #[display("Could not find values associated with wire {wire}.")] + #[display("Could not find values associated with {wire}.")] WireHasNoValues { /// The wire that has no values. wire: Wire, @@ -59,13 +61,21 @@ pub enum OpConvertError { /// The wire that already has values. wire: Wire, }, + /// Cannot encode subgraphs with nested structure or non-local edges in an standalone circuit. + #[display("Cannot encode subgraphs with nested structure or non-local edges in an standalone circuit. Unsupported nodes: {}", + nodes.iter().join(", "), + )] + UnsupportedStandaloneSubgraph { + /// The nodes that are part of the unsupported subgraph. + nodes: Vec, + }, } /// Error type for conversion between tket ops and pytket operations. #[derive(derive_more::Debug, Display, Error, From)] #[non_exhaustive] #[debug(bounds(N: HugrNode))] -pub enum PytketEncodeError { +pub enum PytketEncodeError { /// Tried to encode a non-dataflow region. #[display("Cannot encode non-dataflow region at {region} with type {optype}.")] NonDataflowRegion { @@ -76,7 +86,7 @@ pub enum PytketEncodeError { }, /// Operation conversion error. #[from] - OpConversionError(OpConvertError), + OpEncoding(PytketEncodeOpError), /// Custom user-defined error raised while encoding an operation. #[display("Error while encoding operation: {msg}")] CustomError { @@ -91,9 +101,12 @@ pub enum PytketEncodeError { /// The head region operation that is not a dataflow container. head_op: String, }, + /// No qubits or bits to attach the barrier command to for unsupported nodes. + #[display("An unsupported subgraph has no qubits or bits to attach the barrier command to.")] + UnsupportedSubgraphHasNoRegisters {}, } -impl PytketEncodeError { +impl PytketEncodeError { /// Create a new error with a custom message. pub fn custom(msg: impl ToString) -> Self { Self::CustomError { @@ -103,7 +116,7 @@ impl PytketEncodeError { } /// Error type for conversion between tket2 ops and pytket operations. -#[derive(derive_more::Debug, Display, Error, Clone)] +#[derive(derive_more::Debug, Display, Error)] #[non_exhaustive] #[display( "{inner}{context}", @@ -175,7 +188,7 @@ impl From for PytketDecodeError { /// Error variants of [`PytketDecodeError`], signalling errors during the /// conversion between tket2 ops and pytket operations. -#[derive(derive_more::Debug, Display, Error, Clone)] +#[derive(derive_more::Debug, Display, Error)] #[non_exhaustive] pub enum PytketDecodeErrorInner { /// The pytket circuit uses multi-indexed registers. @@ -342,6 +355,12 @@ pub enum PytketDecodeErrorInner { /// The bit registers expected in the wire. bit_args: Vec, }, + /// We couldn't find a parameter for the required input type. + #[display("Could not find a parameter for the required input type '{ty}'")] + NoMatchingParameter { + /// The type that couldn't be found. + ty: String, + }, /// The number of pytket registers passed to /// `PytketDecodeContext::wire_up_node` or `add_node_with_wires` does not /// match the number of registers required by the operation. @@ -384,11 +403,40 @@ pub enum PytketDecodeErrorInner { /// The bit that was marked as outdated. bit: String, }, - /// Tried to reassemble an [`EncodedCircuit`][super::circuit::EncodedCircuit] whose head region is not a dataflow container in the original hugr. - #[display("Tried to reassemble an `EncodedCircuit` whose head region is not a dataflow container in the original hugr. Head operation {head_op}")] - NonDataflowHeadRegion { - /// The head region operation that is not a dataflow container. - head_op: String, + /// Tried to reassemble a circuit from a region that was not contained in the [`EncodedCircuit`][super::circuit::EncodedCircuit]. + #[display("Tried to reassemble a circuit from region {region}, but the circuit was not found in the `EncodedCircuit`")] + NotAnEncodedRegion { + /// The region we tried to decode + region: String, + }, + /// Tried to decode a circuit into an existing region, but the region was modified since creating the [`EncodedCircuit`][super::circuit::EncodedCircuit]. + #[display("Tried to decode a circuit into region {region}, but the region was modified since creating the `EncodedCircuit`. New region optype: {new_optype}")] + IncompatibleTargetRegion { + /// The region we tried to decode + region: hugr::Node, + /// The new region optype + new_optype: OpType, + }, + /// The pytket circuit contains an opaque barrier representing a unsupported subgraph in the original HUGR, + /// but the corresponding subgraph is not present in the [`EncodedCircuit`][super::circuit::EncodedCircuit] structure. + #[display("The pytket circuit contains a barrier representing an opaque subgraph in the original HUGR, but the corresponding subgraph is not present in the `EncodedCircuit` structure. Subgraph ID {id}")] + OpaqueSubgraphNotFound { + /// The ID of the opaque subgraph. + id: SubgraphId, + }, + /// The stored subgraph payload was not a valid flat subgraph in a dataflow region of the target hugr. + #[display("The stored subgraph {id} was not a valid flat subgraph in a dataflow region of the target hugr. {context}")] + InvalidExternalSubgraph { + /// The ID of the opaque subgraph. + id: SubgraphId, + /// Additional context about the error. + context: String, + }, + /// Cannot decode Hugr from an unsupported subgraph payload in a pytket barrier operation. + #[display("Cannot decode Hugr from an opaque subgraph payload in a pytket barrier operation. {source}")] + UnsupportedSubgraphPayload { + /// The envelope decoding error. + source: EnvelopeError, }, } diff --git a/tket/src/serialize/pytket/extension/core.rs b/tket/src/serialize/pytket/extension/core.rs index 84ecf025b..296a010fc 100644 --- a/tket/src/serialize/pytket/extension/core.rs +++ b/tket/src/serialize/pytket/extension/core.rs @@ -12,9 +12,9 @@ use crate::serialize::pytket::decoder::{ use crate::serialize::pytket::extension::PytketDecoder; use crate::serialize::pytket::opaque::{OpaqueSubgraphPayload, OPGROUP_OPAQUE_HUGR}; use crate::serialize::pytket::{DecodeInsertionTarget, DecodeOptions, PytketDecodeError}; -use crate::serialize::TKETDecode; use hugr::builder::Container; use hugr::extension::prelude::{bool_t, qb_t}; +use hugr::ops::handle::NodeHandle; use hugr::types::{Signature, Type}; use itertools::Itertools; use tket_json_rs::circuit_json::Operation as PytketOperation; @@ -45,23 +45,23 @@ impl PytketDecoder for CoreDecoder { data: Some(payload), .. } if opgroup == Some(OPGROUP_OPAQUE_HUGR) => { - let Ok(payload): Result = serde_json::from_str(payload) - else { + let Ok(payload) = OpaqueSubgraphPayload::load_str( + payload, + decoder.options().extension_registry(), + ) else { // Payload is invalid. We don't error here to avoid // panicking on corrupted/old user submissions. return Ok(DecodeStatus::Unsupported); }; - if payload.is_external() { - unimplemented!("Extract external unsupported hugr subgraphs."); - } - // TODO: Extract inline unsupported hugr subgraphs. - // - // For now we keep the old behaviour of producing opaque TKET1.tk1op operations. - Ok(DecodeStatus::Unsupported) + decoder.insert_subgraph_from_payload(qubits, bits, params, &payload) } PytketOperation { op_type: PytketOptype::CircBox, - op_box: Some(OpBox::CircBox { id: _id, circuit }), + op_box: + Some(OpBox::CircBox { + id: _id, + circuit: serial_circuit, + }), .. } => { // We have no way to distinguish between input and output bits @@ -83,8 +83,19 @@ impl PytketDecoder for CoreDecoder { let target = DecodeInsertionTarget::Region { parent: decoder.builder.container_node(), }; - let internal = - circuit.decode_inplace(decoder.builder.hugr_mut(), target, options)?; + + // Decode the circuit box into a DFG node in the region. + let mut nested_decoder = PytketDecoderContext::new( + serial_circuit, + decoder.builder.hugr_mut(), + target, + options, + )?; + if let Some(opaque_subgraphs) = decoder.opaque_subgraphs { + nested_decoder.register_opaque_subgraphs(opaque_subgraphs); + } + nested_decoder.run_decoder(&serial_circuit.commands)?; + let internal = nested_decoder.finish()?.node(); decoder .wire_up_node(internal, qubits, qubits, bits, bits, params) diff --git a/tket/src/serialize/pytket/opaque.rs b/tket/src/serialize/pytket/opaque.rs index ad1aefaa9..520731154 100644 --- a/tket/src/serialize/pytket/opaque.rs +++ b/tket/src/serialize/pytket/opaque.rs @@ -3,9 +3,7 @@ mod payload; -pub use payload::{ - EncodedEdgeID, OpaqueSubgraphPayload, OpaqueSubgraphPayloadType, OPGROUP_OPAQUE_HUGR, -}; +pub use payload::{EncodedEdgeID, OpaqueSubgraphPayload, OPGROUP_OPAQUE_HUGR}; use std::collections::BTreeMap; use std::ops::Index; @@ -17,12 +15,12 @@ use hugr::HugrView; /// The ID of a subgraph in the Hugr. #[derive(Debug, derive_more::Display, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] -#[display("{local_id}.{tracker_id}")] +#[display("{tracker_id}.{local_id}")] pub struct SubgraphId { - /// A locally unique ID in the [`OpaqueSubgraphs`] instance. - local_id: usize, /// The unique ID of the [`OpaqueSubgraphs`] instance that generated this ID. tracker_id: usize, + /// A locally unique ID in the [`OpaqueSubgraphs`] instance. + local_id: usize, } /// A set of subgraphs a HUGR that have been marked as _unsupported_ during a @@ -46,16 +44,16 @@ pub(super) struct OpaqueSubgraphs { impl serde::Serialize for SubgraphId { fn serialize(&self, s: S) -> Result { - (&self.local_id, &self.tracker_id).serialize(s) + (&self.tracker_id, &self.local_id).serialize(s) } } impl<'de> serde::Deserialize<'de> for SubgraphId { fn deserialize>(d: D) -> Result { - let (local_id, tracker_id) = serde::Deserialize::deserialize(d)?; + let (tracker_id, local_id) = serde::Deserialize::deserialize(d)?; Ok(Self { - local_id, tracker_id, + local_id, }) } } @@ -86,7 +84,7 @@ impl OpaqueSubgraphs { id } - /// Returns the unsupported subgraph with the given ID. + /// Returns the opaque subgraph with the given ID. /// /// # Panics /// @@ -132,18 +130,15 @@ impl OpaqueSubgraphs { ))); }; - let Some(mut payload) = parse_external_payload(&payload)? else { + let Some(subgraph_id) = parse_external_payload(&payload)? else { // Inline payload, nothing to do. return Ok(()); }; - let OpaqueSubgraphPayloadType::External { id: subgraph_id } = payload.typ else { - unreachable!("Checked by `parse_external_payload`"); - }; if !self.contains(subgraph_id) { return Err(PytketEncodeError::custom(format!("Barrier operation with opgroup {OPGROUP_OPAQUE_HUGR} points to an unknown subgraph: {subgraph_id}"))); } - payload.typ = OpaqueSubgraphPayloadType::inline(&self[subgraph_id], hugr); + let payload = OpaqueSubgraphPayload::new_inline(&self[subgraph_id], hugr)?; command.op.data = Some(serde_json::to_string(&payload).unwrap()); Ok(()) @@ -176,25 +171,28 @@ impl Default for OpaqueSubgraphs { /// # Errors /// /// Returns an error if the payload is invalid. -fn parse_external_payload( +fn parse_external_payload( payload: &str, -) -> Result, PytketEncodeError> { - let mk_serde_error = |e: serde_json::Error| { - PytketEncodeError::custom(format!( - "Barrier operation with opgroup {OPGROUP_OPAQUE_HUGR} has corrupt data payload: {e}" - )) - }; - +) -> Result, PytketEncodeError> { // Check if the payload is inline, without fully copying it to memory. #[derive(serde::Deserialize)] struct PartialPayload { pub typ: String, - } - let partial_payload: PartialPayload = serde_json::from_str(payload).map_err(mk_serde_error)?; - if partial_payload.typ == "Inline" { - return Ok(None); + pub id: Option, } - let payload: OpaqueSubgraphPayload = serde_json::from_str(payload).map_err(mk_serde_error)?; - Ok(Some(payload)) + let partial_payload: PartialPayload = + serde_json::from_str(payload).map_err(|e: serde_json::Error| { + PytketEncodeError::custom(format!( + "Barrier operation with opgroup {OPGROUP_OPAQUE_HUGR} has corrupt data payload: {e}" + )) + })?; + + match (partial_payload.typ.as_str(), partial_payload.id) { + ("Inline", None) => Ok(None), + ("External", Some(id)) => Ok(Some(id)), + _ => Err(PytketEncodeError::custom(format!( + "Barrier operation with opgroup {OPGROUP_OPAQUE_HUGR} has invalid data payload: {payload:?}" + ))), + } } diff --git a/tket/src/serialize/pytket/opaque/payload.rs b/tket/src/serialize/pytket/opaque/payload.rs index a9f3245f8..9d6a25c59 100644 --- a/tket/src/serialize/pytket/opaque/payload.rs +++ b/tket/src/serialize/pytket/opaque/payload.rs @@ -1,12 +1,18 @@ //! Definitions of the payloads for opaque barrier metadata in pytket circuits. use hugr::core::HugrNode; -use hugr::envelope::EnvelopeConfig; +use hugr::envelope::{EnvelopeConfig, EnvelopeError}; +use hugr::extension::resolution::{resolve_type_extensions, WeakExtensionRegistry}; +use hugr::extension::{ExtensionRegistry, ExtensionRegistryLoadError}; use hugr::hugr::views::SiblingSubgraph; use hugr::package::Package; use hugr::types::Type; use hugr::{HugrView, Wire}; +use crate::serialize::pytket::{ + PytketDecodeError, PytketDecodeErrorInner, PytketEncodeError, PytketEncodeOpError, +}; + use super::SubgraphId; /// Pytket opgroup used to identify opaque barrier operations that encode opaque HUGR subgraphs. @@ -17,7 +23,7 @@ pub const OPGROUP_OPAQUE_HUGR: &str = "OPAQUE_HUGR"; /// Identifier for a wire in the Hugr, encoded as a 64-bit hash that is /// detached from the node IDs of the in-memory Hugr. /// -/// These are used to identify edges in the [`OpaqueSubgraphPayload`] +/// These are used to identify edges in the [`OpaqueSubgraphPayload::Inline`] /// payloads encoded in opaque barriers on the encoded pytket circuits. /// /// We require them to reconstruct the edges of the hugr that are not reflected @@ -55,47 +61,15 @@ impl EncodedEdgeID { /// envelope in the operation's date, or be a reference to a subgraph tracked /// inside a [`EncodedCircuit`][super::super::circuit::EncodedCircuit] /// structure. -#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] -pub struct OpaqueSubgraphPayload { - /// The type of payload. - /// - /// Either an inline hugr envelope or a reference to a subgraph tracked - /// inside a [`OpaqueSubgraphs`][super::OpaqueSubgraphs] structure. - #[serde(flatten)] - pub(super) typ: OpaqueSubgraphPayloadType, - /// Input types of the subgraph. - /// - /// Each input is assigned a unique edge identifier, so we can reconstruct - /// the connections that are not encoded in the pytket circuit. - /// - /// The types can also be inferred from the encoded hugr or linked - /// subcircuit, but we store them here to be robust. - inputs: Vec<(Type, EncodedEdgeID)>, - /// Output types of the subgraph. - /// - /// Each output is assigned a unique edge identifier, so we can reconstruct - /// the connections that are not encoded in the pytket circuit. - /// - /// The types can also be inferred from the encoded hugr or linked - /// subcircuit, but we store them here for robustness. - outputs: Vec<(Type, EncodedEdgeID)>, -} - -/// Payload for a pytket barrier metadata that indicates the barrier represents -/// an opaque HUGR subgraph. /// -/// The payload may be encoded inline, embedding the HUGR subgraph as an -/// envelope in the operation's date, or be a reference to a subgraph tracked -/// inside a [`EncodedCircuit`][super::super::circuit::EncodedCircuit] -/// structure. +/// Inline payloads encode their input and output boundaries that cannot be +/// encoded as pytket qubit/bit registers using [`EncodedEdgeID`]s independent +/// from the hugr. A circuit may mix barriers with both inline and external +/// payloads, as long as there are no edges requiring a [`EncodedEdgeID`] +/// between them. #[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] -#[serde(tag = "typ", content = "subgraph")] -pub enum OpaqueSubgraphPayloadType { - /// An inline payload, carrying the encoded envelope for the HUGR subgraph. - Inline { - /// A string envelope containing the encoded HUGR subgraph. - hugr_envelope: String, - }, +#[serde(tag = "typ")] +pub enum OpaqueSubgraphPayload { /// A reference to a subgraph tracked by an `OpaqueSubgraphs` registry /// in an [`EncodedCircuit`][super::super::circuit::EncodedCircuit] /// structure. @@ -103,38 +77,61 @@ pub enum OpaqueSubgraphPayloadType { /// The ID of the subgraph in the `OpaqueSubgraphs` registry. id: SubgraphId, }, + /// An inline payload, carrying the encoded envelope for the HUGR subgraph. + Inline { + /// A string envelope containing the encoded HUGR subgraph. + hugr_envelope: String, + /// Input types of the subgraph. + /// + /// Each input is assigned a unique edge identifier, so we can reconstruct + /// the connections that are not encoded in the pytket circuit. + /// + /// The types can also be inferred from the encoded hugr or linked + /// subcircuit, but we store them here to be robust. + inputs: Vec<(Type, EncodedEdgeID)>, + /// Output types of the subgraph. + /// + /// Each output is assigned a unique edge identifier, so we can reconstruct + /// the connections that are not encoded in the pytket circuit. + /// + /// The types can also be inferred from the encoded hugr or linked + /// subcircuit, but we store them here for robustness. + outputs: Vec<(Type, EncodedEdgeID)>, + }, } -impl OpaqueSubgraphPayloadType { - /// Create an inline payload by encoding the subgraph as an envelope. - // - // TODO: Detect and deal with non-local edges. Include global fn/const - // definitions, and reject other non-local edges. - // - // TODO: This should include descendants of the subgraph. It doesn't. - pub(super) fn inline( - subgraph: &SiblingSubgraph, - hugr: &impl HugrView, - ) -> Self { - let opaque_hugr = subgraph.extract_subgraph(hugr, ""); - let payload = Package::from_hugr(opaque_hugr) - .store_str(EnvelopeConfig::text()) - .unwrap(); - Self::Inline { - hugr_envelope: payload, - } +impl OpaqueSubgraphPayload { + /// Create an external payload by referencing a subgraph in the tracked by + /// an [`EncodedCircuit`][super::super::EncodedCircuit]. + pub fn new_external(subgraph_id: SubgraphId) -> Self { + Self::External { id: subgraph_id } } -} -impl OpaqueSubgraphPayload { /// Create a new payload for an opaque subgraph in the Hugr. - pub fn new( + /// + /// Encodes the subgraph into a hugr envelope. + /// + /// # Errors + /// + /// Returns an error if a node in the subgraph has children or non-local const edges. + pub fn new_inline( subgraph: &SiblingSubgraph, hugr: &impl HugrView, - typ: OpaqueSubgraphPayloadType, - ) -> Self { + ) -> Result> { let signature = subgraph.signature(hugr); + if !subgraph.function_calls().is_empty() + || subgraph + .nodes() + .iter() + .any(|n| hugr.children(*n).next().is_some()) + { + return Err(PytketEncodeOpError::UnsupportedStandaloneSubgraph { + nodes: subgraph.nodes().to_vec(), + } + .into()); + } + let mut inputs = Vec::with_capacity(subgraph.incoming_ports().iter().map(Vec::len).sum()); for subgraph_inputs in subgraph.incoming_ports() { let Some((inp_node, inp_port0)) = subgraph_inputs.first() else { @@ -150,35 +147,60 @@ impl OpaqueSubgraphPayload { .iter() .map(|(n, p)| EncodedEdgeID::new(Wire::new(*n, *p))); - Self { - typ, + // TODO: This should include descendants of the subgraph. It doesn't. + let opaque_hugr = subgraph.extract_subgraph(hugr, ""); + let hugr_envelope = Package::from_hugr(opaque_hugr) + .store_str(EnvelopeConfig::text()) + .unwrap(); + + Ok(Self::Inline { + hugr_envelope, inputs: signature.input().iter().cloned().zip(inputs).collect(), outputs: signature.output().iter().cloned().zip(outputs).collect(), - } - } - - /// Returns the inputs types and internal edge IDs of the payload. - pub fn inputs(&self) -> impl Iterator + '_ { - self.inputs.iter().map(|(t, e)| (t, *e)) + }) } - /// Returns the outputs types and internal edge IDs of the payload. - pub fn outputs(&self) -> impl Iterator + '_ { - self.outputs.iter().map(|(t, e)| (t, *e)) - } + /// Load a payload encoded in a json string. + /// + /// Updates weak extension references inside the definition after loading. + pub fn load_str(json: &str, extensions: &ExtensionRegistry) -> Result { + let mut payload: Self = serde_json::from_str(json).map_err(|e| { + PytketDecodeErrorInner::UnsupportedSubgraphPayload { + source: EnvelopeError::SerdeError { source: e }, + } + .wrap() + })?; + + // Resolve the extension ops and types in the inline payload. + if let Self::Inline { + inputs, outputs, .. + } = &mut payload + { + let extensions: WeakExtensionRegistry = extensions.into(); + + // Resolve the cached input/output types. + for (ty, _) in inputs.iter_mut().chain(outputs.iter_mut()) { + resolve_type_extensions(ty, &extensions).map_err(|e| { + let registry_load_e = + ExtensionRegistryLoadError::ExtensionResolutionError(Box::new(e)); + let envelope_e = EnvelopeError::ExtensionLoad { + source: registry_load_e, + }; + PytketDecodeErrorInner::UnsupportedSubgraphPayload { source: envelope_e }.wrap() + })?; + } + } - /// Returns the type of the payload. - pub fn typ(&self) -> &OpaqueSubgraphPayloadType { - &self.typ + Ok(payload) } /// Returns `true` if the payload is an inline payload. pub fn is_inline(&self) -> bool { - matches!(self.typ, OpaqueSubgraphPayloadType::Inline { .. }) + matches!(self, Self::Inline { .. }) } /// Returns `true` if the payload is an external payload. pub fn is_external(&self) -> bool { - matches!(self.typ, OpaqueSubgraphPayloadType::External { .. }) + matches!(self, Self::External { .. }) } } diff --git a/tket/src/serialize/pytket/options.rs b/tket/src/serialize/pytket/options.rs index f342c4d8f..5f2a3e47b 100644 --- a/tket/src/serialize/pytket/options.rs +++ b/tket/src/serialize/pytket/options.rs @@ -2,11 +2,14 @@ use std::sync::Arc; +use hugr::extension::ExtensionRegistry; use hugr::types::Signature; use hugr::{Hugr, HugrView, Node}; use crate::serialize::pytket::{PytketDecoderConfig, PytketEncoderConfig}; +use super::default_decoder_config; + /// Options used when decoding a pytket /// [`SerialCircuit`][tket_json_rs::circuit_json::SerialCircuit] into a HUGR. /// @@ -14,7 +17,14 @@ use crate::serialize::pytket::{PytketDecoderConfig, PytketEncoderConfig}; /// /// In contrast to [PytketDecoderConfig] which is normally statically defined by /// a library, these options may vary between calls. -#[derive(Default, Clone)] +/// +/// The generic parameter `H` is the HugrView type of the Hugr that was encoded +/// into the pytket circuit, if any. This is required when the encoded pytket +/// circuit contains opaque barriers that reference subgraphs in the original +/// HUGR. See +/// [`OpaqueSubgraphPayload`][super::opaque::OpaqueSubgraphPayload] +/// for more details. +#[derive(Clone, Debug, Default)] #[non_exhaustive] pub struct DecodeOptions { /// The configuration for the decoder, containing custom @@ -22,11 +32,6 @@ pub struct DecodeOptions { /// /// When `None`, we will use [`default_decoder_config`][super::default_decoder_config]. pub config: Option>, - /// The name of the function to create. - /// - /// If `None`, we will use the name of the circuit, or "main" if the circuit - /// has no name. - pub fn_name: Option, /// The signature of the function to create. /// /// The number of qubits in the input types must be less than or equal to the @@ -49,6 +54,11 @@ pub struct DecodeOptions { /// If additional parameters are found in the circuit, they will be added /// after these using generic names. pub input_params: Vec, + /// The extensions to use when loading the HUGR envelope. + /// + /// When `None`, we will use a default registry that includes the prelude, + /// std, TKET1, and TketOps extensions. + pub extensions: Option, } impl DecodeOptions { @@ -58,45 +68,104 @@ impl DecodeOptions { } /// Set a decoder configuration. + #[must_use] pub fn with_config(mut self, config: impl Into>) -> Self { self.config = Some(config.into()); self } - /// Set the name of the function to create. - pub fn with_fn_name(mut self, fn_name: impl ToString) -> Self { - self.fn_name = Some(fn_name.to_string()); + /// Set `DecodeOptions::config` to use [`default_decoder_config`]. + #[must_use] + pub fn with_default_config(mut self) -> Self { + self.config = Some(Arc::new(default_decoder_config())); self } /// Set the signature of the function to create. + #[must_use] pub fn with_signature(mut self, signature: Signature) -> Self { self.signature = Some(signature); self } /// Set the input parameter names. + #[must_use] pub fn with_input_params(mut self, input_params: impl IntoIterator) -> Self { self.input_params = input_params.into_iter().collect(); self } + + /// Set the extensions to use when loading the HUGR envelope. + #[must_use] + pub fn with_extensions(mut self, extensions: ExtensionRegistry) -> Self { + self.extensions = Some(extensions); + self + } + + /// Returns the extensions to use when loading the HUGR envelope. + /// + /// If the option is `None`, we will use a default registry that includes + /// the prelude, std, TKET1, and TketOps extensions. + pub fn extension_registry(&self) -> &ExtensionRegistry { + self.extensions + .as_ref() + .unwrap_or(&crate::extension::REGISTRY) + } + + /// Returns the [`PytketDecoderConfig`] to use when decoding the circuit. + /// + /// # Panics + /// + /// Panics if the option is `None`. Use [`DecodeOptions::with_config`] or + /// [`DecodeOptions::with_default_config`] to set it. + pub(super) fn get_config(&self) -> &Arc { + self.config + .as_ref() + .expect("DecodeOptions::config is not set") + } } /// Where to insert the decoded circuit when calling /// [`TKETDecode::decode_inplace`][super::TKETDecode::decode_inplace]. -#[derive(Debug, derive_more::Display, Default, Clone, Copy, PartialEq, Eq)] +#[derive(Debug, derive_more::Display, Clone, PartialEq, Eq)] #[non_exhaustive] pub enum DecodeInsertionTarget { /// Insert the decoded circuit as a new function in the HUGR. - #[default] - Function, + #[display("{}", + match fn_name { + Some(fn_name) => format!("Function({fn_name})"), + None => "Function".to_string(), + } + )] + Function { + /// The name of the function to create. + /// + /// If `None`, we will use the encoded circuit's name, or "main" if the circuit has no name. + fn_name: Option, + }, /// Insert the decoded circuit as a dataflow region in the HUGR under the given parent. + #[display("Region({parent})")] Region { /// The parent node that will contain the circuit's decoded DFG. parent: Node, }, } +impl DecodeInsertionTarget { + /// Create a new [`DecodeInsertionTarget::Function`] with the default values. + pub fn function(fn_name: impl ToString) -> Self { + Self::Function { + fn_name: Some(fn_name.to_string()), + } + } +} + +impl Default for DecodeInsertionTarget { + fn default() -> Self { + Self::Function { fn_name: None } + } +} + /// Options used when encoding a HUGR into a pytket /// [`SerialCircuit`][tket_json_rs::circuit_json::SerialCircuit]. /// diff --git a/tket/src/serialize/pytket/tests.rs b/tket/src/serialize/pytket/tests.rs index 7d807cd1c..cb900bc60 100644 --- a/tket/src/serialize/pytket/tests.rs +++ b/tket/src/serialize/pytket/tests.rs @@ -3,15 +3,17 @@ use std::collections::{HashMap, HashSet}; use std::io::BufReader; +use cool_asserts::assert_matches; use hugr::builder::{ Container, Dataflow, DataflowHugr, DataflowSubContainer, FunctionBuilder, HugrBuilder, ModuleBuilder, SubContainer, }; -use hugr::extension::prelude::{bool_t, qb_t}; +use hugr::extension::prelude::{bool_t, option_type, qb_t, UnwrapBuilder}; +use std::sync::Arc; use hugr::hugr::hugrmut::HugrMut; use hugr::ops::handle::FuncID; -use hugr::ops::{OpParent, Value}; +use hugr::ops::{OpParent, OpType, Value}; use hugr::std_extensions::arithmetic::float_ops::FloatOps; use hugr::types::Signature; use hugr::HugrView; @@ -27,8 +29,12 @@ use crate::extension::bool::BoolOp; use crate::extension::rotation::{rotation_type, ConstRotation, RotationOp}; use crate::extension::sympy::SympyOpDef; use crate::extension::TKET1_EXTENSION_ID; +use crate::serialize::pytket::extension::{CoreDecoder, OpaqueTk1Op, PreludeEmitter}; +use crate::serialize::pytket::PytketEncodeError; use crate::serialize::pytket::{ - DecodeInsertionTarget, DecodeOptions, EncodeOptions, EncodedCircuit, + default_decoder_config, default_encoder_config, DecodeInsertionTarget, DecodeOptions, + EncodeOptions, EncodedCircuit, PytketDecodeError, PytketDecodeErrorInner, PytketDecoderConfig, + PytketEncodeOpError, PytketEncoderConfig, }; use crate::TketOp; @@ -277,6 +283,48 @@ fn circ_parameterized() -> Circuit { hugr.into() } +/// A circuit with a TK1 opaque operation. +#[fixture] +fn circ_tk1_ops() -> Circuit { + let input_t = vec![qb_t(), qb_t()]; + let output_t = vec![qb_t(), qb_t()]; + let mut h = FunctionBuilder::new("tk1_ops", Signature::new(input_t, output_t)).unwrap(); + + let [q1, q2] = h.input_wires_arr(); + + // An unsupported tk1-only operation. + let mut tk1op = tket_json_rs::circuit_json::Operation::default(); + tk1op.op_type = tket_json_rs::optype::OpType::CH; + tk1op.n_qb = Some(2); + let op: OpType = OpaqueTk1Op::new_from_op(&tk1op, 2, 0) + .as_extension_op() + .into(); + let [q1, q2] = h.add_dataflow_op(op, [q1, q2]).unwrap().outputs_arr(); + + let hugr = h.finish_hugr_with_outputs([q1, q2]).unwrap(); + hugr.into() +} + +/// A circuit with a nested unsupported operation. +/// +/// Tries to allocate a qubit, and panics if it fails. +/// This creates an unsupported conditional inside the region. +#[fixture] +fn circ_nested_opaque() -> Circuit { + let input_t = vec![]; + let output_t = vec![qb_t()]; + let mut h = FunctionBuilder::new("nested_opaque", Signature::new(input_t, output_t)).unwrap(); + + let [maybe_q] = h + .add_dataflow_op(TketOp::TryQAlloc, []) + .unwrap() + .outputs_arr(); + let [q] = h.build_unwrap_sum(1, option_type(qb_t()), maybe_q).unwrap(); + + let hugr = h.finish_hugr_with_outputs([q]).unwrap(); + hugr.into() +} + /// A circuit with a recursive function call. #[fixture] fn circ_recursive() -> Circuit { @@ -529,6 +577,33 @@ fn circ_nested_dfgs() -> Circuit { h.finish_hugr_with_outputs([bool]).unwrap().into() } +// A circuit with some simple circuit and an unsupported subgraph that does not interact with it. +#[fixture] +fn circ_independent_subgraph() -> Circuit { + let input_t = vec![ + qb_t(), + qb_t(), + option_type(rotation_type()).into(), + option_type(qb_t()).into(), + ]; + let output_t = vec![qb_t(), qb_t(), rotation_type(), option_type(qb_t()).into()]; + let mut h = + FunctionBuilder::new("independent_subgraph", Signature::new(input_t, output_t)).unwrap(); + + let [q1, q2, maybe_rot, maybe_q] = h.input_wires_arr(); + + let [q1, q2] = h + .add_dataflow_op(TketOp::CX, [q1, q2]) + .unwrap() + .outputs_arr(); + let [rot] = h + .build_unwrap_sum(1, option_type(rotation_type()), maybe_rot) + .unwrap(); + + let hugr = h.finish_hugr_with_outputs([q1, q2, rot, maybe_q]).unwrap(); + hugr.into() +} + /// Check that all circuit ops have been translated to a native gate. /// /// Panics if there are tk1 ops in the circuit. @@ -594,34 +669,91 @@ fn json_file_roundtrip(#[case] circ: impl AsRef) { compare_serial_circs(&ser, &reser); } -/// Test the serialisation roundtrip from a tket circuit. +/// Test parameter to select which decoders/encoders to enable. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum CircuitRoundtripTestConfig { + // Use the default decoder/encoder configuration. + Default, + // Use only the prelude and core decoders/encoders, with no std ones. + NoStd, +} + +impl CircuitRoundtripTestConfig { + fn decoder_config(&self) -> PytketDecoderConfig { + match self { + CircuitRoundtripTestConfig::Default => default_decoder_config(), + CircuitRoundtripTestConfig::NoStd => { + let mut config = PytketDecoderConfig::new(); + config.add_decoder(CoreDecoder); + config.add_decoder(PreludeEmitter); + config.add_type_translator(PreludeEmitter); + config + } + } + } + + fn encoder_config(&self) -> PytketEncoderConfig { + match self { + CircuitRoundtripTestConfig::Default => default_encoder_config(), + CircuitRoundtripTestConfig::NoStd => { + let mut config = PytketEncoderConfig::new(); + config.add_emitter(PreludeEmitter); + config.add_type_translator(PreludeEmitter); + config + } + } + } +} + +/// Test the standalone serialisation roundtrip from a tket circuit. +/// +/// This is not a pure roundtrip as the encoder may add internal qubits/bits to +/// the circuit. /// -/// Note: this is not a pure roundtrip as the encoder may add internal qubits/bits to the circuit. +/// Standalone circuit do not currently support unsupported subgraphs with +/// nested structure or non-local edges. #[rstest] -#[case::meas_ancilla(circ_measure_ancilla(), 1)] -#[case::preset_qubits(circ_preset_qubits(), 1)] -#[case::preset_parameterized(circ_parameterized(), 1)] -#[case::nested_dfgs(circ_nested_dfgs(), 1)] -#[case::global_defs(circ_global_defs(), 1)] -#[case::recursive(circ_recursive(), 1)] -#[case::non_local(circ_non_local(), 1)] -fn circuit_roundtrip(#[case] circ: Circuit, #[case] num_circuits: usize) { +#[case::meas_ancilla(circ_measure_ancilla(), 1, CircuitRoundtripTestConfig::Default)] +#[case::preset_qubits(circ_preset_qubits(), 1, CircuitRoundtripTestConfig::Default)] +#[case::preset_parameterized(circ_parameterized(), 1, CircuitRoundtripTestConfig::Default)] +#[case::nested_dfgs(circ_nested_dfgs(), 1, CircuitRoundtripTestConfig::Default)] +#[case::tk1_ops(circ_tk1_ops(), 1, CircuitRoundtripTestConfig::Default)] +#[case::missing_decoders(circ_measure_ancilla(), 1, CircuitRoundtripTestConfig::NoStd)] +fn circuit_standalone_roundtrip( + #[case] circ: Circuit, + #[case] num_circuits: usize, + #[case] config: CircuitRoundtripTestConfig, +) { let circ_signature = circ.circuit_signature().into_owned(); + let decode_options = DecodeOptions::new() + .with_signature(circ_signature.clone()) + .with_config(config.decoder_config()); + let encode_options = EncodeOptions::new_with_subcircuits().with_config(config.encoder_config()); - let encoded: EncodedCircuit = - EncodedCircuit::from_hugr(&circ, EncodeOptions::new_with_subcircuits()) - .unwrap_or_else(|e| panic!("{e}")); + let encoded = EncodedCircuit::new_standalone(&circ, encode_options.clone()) + .unwrap_or_else(|e| panic!("{e}")); assert!(encoded.contains_circuit(circ.parent())); assert_eq!(encoded.len(), num_circuits); - let ser: SerialCircuit = encoded - .extract_standalone() + // Re-encode the EncodedCircuit + let extracted_from_circ = encoded + .reassemble( + circ.parent(), + Some("main".to_string()), + decode_options.clone(), + ) .unwrap_or_else(|e| panic!("{e}")); - let deser: Circuit = ser - .decode(DecodeOptions::new().with_signature(circ_signature.clone())) + extracted_from_circ + .validate() .unwrap_or_else(|e| panic!("{e}")); + // Extract the head pytket circuit, and re-encode it on its own. + let ser: &SerialCircuit = &encoded[circ.parent()]; + let deser: Circuit = ser.decode(decode_options).unwrap_or_else(|e| panic!("{e}")); + + deser.hugr().validate().unwrap_or_else(|e| panic!("{e}")); + let deser_sig = deser.circuit_signature(); assert_eq!( &circ_signature.input, &deser_sig.input, @@ -634,9 +766,99 @@ fn circuit_roundtrip(#[case] circ: Circuit, #[case] num_circuits: usize) { &circ_signature, &deser_sig ); - let reser = SerialCircuit::encode(&deser, EncodeOptions::new()).unwrap(); + let reser = SerialCircuit::encode(&deser, encode_options).unwrap(); + validate_serial_circ(&reser); - compare_serial_circs(&ser, &reser); + compare_serial_circs(ser, &reser); +} + +/// Test that more complex unsupported subgraphs (nested structure, non-local edges) are rejected when encoding a standalone circuit. +#[rstest] +#[case::nested_opaque(circ_nested_opaque())] +#[case::global_defs(circ_global_defs())] +#[case::recursive(circ_recursive())] +fn reject_standalone_complex_subgraphs(#[case] circ: Circuit) { + let try_encoded = EncodedCircuit::new_standalone(&circ, EncodeOptions::new()); + assert_matches!( + try_encoded, + Err(PytketEncodeError::OpEncoding( + PytketEncodeOpError::UnsupportedStandaloneSubgraph { .. } + )) + ); +} + +/// Test that modifying the hugr before reassembling an EncodedCircuit fails. +#[rstest] +fn fail_on_modified_hugr(circ_tk1_ops: Circuit) { + let encoded = EncodedCircuit::new(&circ_tk1_ops, EncodeOptions::new_with_subcircuits()) + .unwrap_or_else(|e| panic!("{e}")); + + let mut a_new_hugr = ModuleBuilder::new(); + a_new_hugr + .declare("decl", Signature::new_endo(vec![qb_t()]).into()) + .unwrap(); + let mut a_new_hugr = a_new_hugr.finish_hugr().unwrap(); + + let try_reassemble = encoded.reassemble_inline(&mut a_new_hugr, None); + + assert_matches!( + try_reassemble, + Err(PytketDecodeError { + inner: PytketDecodeErrorInner::IncompatibleTargetRegion { .. }, + .. + }) + ); +} + +/// Test the serialisation roundtrip from a tket circuit into an EncodedCircuit and back. +#[rstest] +#[case::meas_ancilla(circ_measure_ancilla(), 1, CircuitRoundtripTestConfig::Default)] +#[case::preset_qubits(circ_preset_qubits(), 1, CircuitRoundtripTestConfig::Default)] +#[case::preset_parameterized(circ_parameterized(), 1, CircuitRoundtripTestConfig::Default)] +#[case::nested_dfgs(circ_nested_dfgs(), 1, CircuitRoundtripTestConfig::Default)] +#[case::flat_opaque(circ_tk1_ops(), 1, CircuitRoundtripTestConfig::Default)] +// TODO: Fail due to eagerly emitting QAllocs that never get consumed. We should do that lazily. +// Also requires to be published in hugr 0.24.1 +//#[case::nested_opaque(circ_nested_opaque(), 3, CircuitRoundtripTestConfig::Default)] +#[case::global_defs(circ_global_defs(), 1, CircuitRoundtripTestConfig::Default)] +#[case::recursive(circ_recursive(), 1, CircuitRoundtripTestConfig::Default)] +// TODO: Encoding of independent subgraphs needs more debugging. +// Also requires to be published in hugr 0.24.1 +//#[case::independent_subgraph(circ_independent_subgraph(), 3, CircuitRoundtripTestConfig::Default)] +// TODO: fix edge case: non-local edge from an unsupported node inside a nested CircBox +// to/from the input of the head region being encoded... +//#[case::non_local(circ_non_local(), 1)] +fn encoded_circuit_roundtrip( + #[case] circ: Circuit, + #[case] num_circuits: usize, + #[case] config: CircuitRoundtripTestConfig, +) { + let circ_signature = circ.circuit_signature().into_owned(); + let encode_options = EncodeOptions::new_with_subcircuits().with_config(config.encoder_config()); + + let encoded = EncodedCircuit::new(&circ, encode_options).unwrap_or_else(|e| panic!("{e}")); + + assert!(encoded.contains_circuit(circ.parent())); + assert_eq!(encoded.len(), num_circuits); + + let mut deser = circ.clone(); + encoded + .reassemble_inline(deser.hugr_mut(), Some(Arc::new(config.decoder_config()))) + .unwrap_or_else(|e| panic!("{e}")); + + deser.hugr().validate().unwrap_or_else(|e| panic!("{e}")); + + let deser_sig = deser.circuit_signature(); + assert_eq!( + &circ_signature.input, &deser_sig.input, + "Input signature mismatch\n Expected: {}\n Actual: {}", + &circ_signature, &deser_sig + ); + assert_eq!( + &circ_signature.output, &deser_sig.output, + "Output signature mismatch\n Expected: {}\n Actual: {}", + &circ_signature, &deser_sig + ); } /// Test serialisation of circuits with a symbolic expression. @@ -672,7 +894,7 @@ fn test_inplace_decoding() { let func1 = serial .decode_inplace( builder.hugr_mut(), - DecodeInsertionTarget::Function, + DecodeInsertionTarget::Function { fn_name: None }, DecodeOptions::new(), ) .unwrap(); diff --git a/tket1-passes/Cargo.toml b/tket1-passes/Cargo.toml index f580239e5..4fbdb7e82 100644 --- a/tket1-passes/Cargo.toml +++ b/tket1-passes/Cargo.toml @@ -25,6 +25,11 @@ conan2 = "0.1.8" [dev-dependencies] rstest.workspace = true +itertools.workspace = true + +# Used for testing +hugr = { workspace = true } +tket = { path = "../tket", version = "0.15.0" } [lints] workspace = true diff --git a/tket1-passes/tests/tket1-on-hugr.rs b/tket1-passes/tests/tket1-on-hugr.rs new file mode 100644 index 000000000..94823c38e --- /dev/null +++ b/tket1-passes/tests/tket1-on-hugr.rs @@ -0,0 +1,82 @@ +//! Test running tket1 passes on hugr circuit. + +use tket1_passes::Tket1Circuit; + +use hugr::builder::{BuildError, Dataflow, DataflowHugr, FunctionBuilder}; +use hugr::extension::prelude::qb_t; +use hugr::types::Signature; +use hugr::{HugrView, Node}; +use rstest::{fixture, rstest}; +use tket::extension::{TKET1_EXTENSION_ID, TKET_EXTENSION_ID}; +use tket::serialize::pytket::{EncodeOptions, EncodedCircuit}; +use tket::{Circuit, TketOp}; + +/// A flat quantum circuit inside a function. +/// +/// This should optimize to the identity. +#[fixture] +fn circ_flat_quantum() -> Circuit { + fn build() -> Result { + let input_t = vec![qb_t(), qb_t()]; + let output_t = vec![qb_t(), qb_t()]; + let mut h = + FunctionBuilder::new("preset_qubits", Signature::new(input_t, output_t)).unwrap(); + + let mut circ = h.as_circuit(h.input_wires()); + + circ.append(TketOp::X, [0])?; + circ.append(TketOp::CX, [0, 1])?; + circ.append(TketOp::X, [0])?; + circ.append(TketOp::CX, [1, 0])?; + circ.append(TketOp::X, [0])?; + circ.append(TketOp::X, [1])?; + circ.append(TketOp::CX, [0, 1])?; + + let wires = circ.finish(); + // Implicit swap + let wires = [wires[1], wires[0]]; + + let hugr = h.finish_hugr_with_outputs(wires).unwrap(); + + Ok(hugr.into()) + } + build().unwrap() +} + +#[rstest] +#[case(circ_flat_quantum(), 0)] +fn test_clifford_simp(#[case] circ: Circuit, #[case] num_remaining_gates: usize) { + let mut encoded = EncodedCircuit::new(&circ, EncodeOptions::new_with_subcircuits()).unwrap(); + + for (_region, serial_circuit) in encoded.iter_mut() { + let mut circuit_ptr = Tket1Circuit::from_serial_circuit(serial_circuit).unwrap(); + circuit_ptr + .clifford_simp(tket_json_rs::OpType::CX, true) + .unwrap(); + *serial_circuit = circuit_ptr.to_serial_circuit().unwrap(); + } + + let mut new_circ = circ.clone(); + let updated_regions = encoded + .reassemble_inline(new_circ.hugr_mut(), None) + .unwrap(); + + let quantum_ops: usize = updated_regions + .iter() + .map(|region| count_quantum_gates(&new_circ, *region)) + .sum(); + assert_eq!(quantum_ops, num_remaining_gates); +} + +/// Helper method to count the number of quantum operations in a hugr region. +fn count_quantum_gates(circuit: &Circuit, region: Node) -> usize { + circuit + .hugr() + .children(region) + .filter(|child| { + let op = circuit.hugr().get_optype(*child); + op.as_extension_op() + .is_some_and(|e| [TKET_EXTENSION_ID, TKET1_EXTENSION_ID].contains(e.extension_id())) + }) + .count() +}