Skip to content
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 23 additions & 0 deletions tket/src/serialize/pytket/config/decoder_config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@
//! A configuration struct contains a list of custom decoders that define
//! translations of legacy tket primitives into HUGR operations.

use hugr::builder::DFGBuilder;
use hugr::types::Type;
use hugr::{Hugr, Wire};
use itertools::Itertools;
use std::collections::HashMap;

Expand Down Expand Up @@ -107,4 +109,25 @@ impl PytketDecoderConfig {
pub fn type_to_pytket(&self, typ: &Type) -> Option<RegisterCount> {
self.type_translators.type_to_pytket(typ)
}

/// Returns `true` if the two types are isomorphic. I.e. they can be translated
/// into each other without losing information.
pub fn types_are_isomorphic(&self, typ1: &Type, typ2: &Type) -> bool {
self.type_translators.types_are_isomorphic(typ1, typ2)
}

/// Inserts the necessary operations to translate a type into an isomorphic
/// type.
///
/// This operation fails if [`Self::types_are_isomorphic`] returns `false`.
pub(in crate::serialize::pytket) fn transform_typed_value(
&self,
wire: Wire,
initial_type: &Type,
target_type: &Type,
builder: &mut DFGBuilder<&mut Hugr>,
) -> Result<Wire, PytketDecodeError> {
self.type_translators
.transform_typed_value(wire, initial_type, target_type, builder)
}
}
88 changes: 85 additions & 3 deletions tket/src/serialize/pytket/config/type_translators.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,16 @@
use std::collections::HashMap;
use std::sync::RwLock;

use hugr::builder::{BuildError, DFGBuilder, Dataflow};
use hugr::extension::prelude::bool_t;
use hugr::extension::ExtensionId;
use hugr::types::{Type, TypeEnum};
use hugr::{Hugr, Wire};
use itertools::Itertools;

use crate::extension::bool::BoolOp;
use crate::serialize::pytket::extension::{PytketTypeTranslator, RegisterCount};
use crate::serialize::pytket::{PytketDecodeError, PytketDecodeErrorInner};

/// A set of [`PytketTypeTranslator`]s that can be used to translate HUGR types
/// into pytket registers (qubits, bits, and parameter expressions).
Expand Down Expand Up @@ -61,6 +65,13 @@ impl TypeTranslatorSet {
/// Only tuple sums, bools, and custom types are supported.
/// Other types will return `None`.
pub fn type_to_pytket(&self, typ: &Type) -> Option<RegisterCount> {
self.type_to_pytket_internal(typ).filter(|c| !c.is_empty())
}

/// Recursive call for [`Self::type_to_pytket`].
///
/// This allows returning empty register counts, for types that may be included inside other types.
fn type_to_pytket_internal(&self, typ: &Type) -> Option<RegisterCount> {
let cache = self.type_cache.read().ok();
if let Some(count) = cache.and_then(|c| c.get(typ).cloned()) {
return count;
Expand All @@ -79,7 +90,7 @@ impl TypeTranslatorSet {
.iter()
.map(|ty| {
match ty.clone().try_into() {
Ok(ty) => self.type_to_pytket(&ty),
Ok(ty) => self.type_to_pytket_internal(&ty),
// Sum types with row variables (variable tuple lengths) are not supported.
Err(_) => None,
}
Expand Down Expand Up @@ -121,6 +132,77 @@ impl TypeTranslatorSet {
.flatten()
.map(move |idx| &self.type_translators[*idx])
}

/// Returns `true` if the two types are isomorphic. I.e. they can be translated
/// into each other without losing information.
//
// TODO: We should allow custom TypeTranslators to expand this checks,
// and implement their own translations.
pub fn types_are_isomorphic(&self, typ1: &Type, typ2: &Type) -> bool {
if typ1 == typ2 {
return true;
}

// For now, we just hard-code this to the two kind of bits we support.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This only works if the sum-bool is used linearly, right? Is this a problem?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Or do we assume that pytket extraction only runs after bool linearisation?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sum bool is marked as unsupported, so we won't see them here.

let native_bool = bool_t();
let tket_bool = crate::extension::bool::bool_type();
if (typ1 == &native_bool && typ2 == &tket_bool)
|| (typ1 == &tket_bool && typ2 == &native_bool)
{
return true;
}

false
}

/// Inserts the necessary operations to translate a type into an isomorphic
/// type.
///
/// This operation fails if [`Self::types_are_isomorphic`] returns `false`.
pub(super) fn transform_typed_value(
&self,
wire: Wire,
initial_type: &Type,
target_type: &Type,
builder: &mut DFGBuilder<&mut Hugr>,
) -> Result<Wire, PytketDecodeError> {
if initial_type == target_type {
return Ok(wire);
}

let map_build_error = |e: BuildError| PytketDecodeErrorInner::CannotTranslateWire {
wire,
initial_type: initial_type.to_string(),
target_type: target_type.to_string(),
context: Some(e.to_string()),
};

// Hard-coded transformations until customs calls are added to [`PytketTypeTranslator`].
let native_bool = bool_t();
let tket_bool = crate::extension::bool::bool_type();
if initial_type == &native_bool && target_type == &tket_bool {
let [wire] = builder
.add_dataflow_op(BoolOp::make_opaque, [wire])
.map_err(map_build_error)?
.outputs_arr();
return Ok(wire);
}
if initial_type == &tket_bool && target_type == &native_bool {
let [wire] = builder
.add_dataflow_op(BoolOp::read, [wire])
.map_err(map_build_error)?
.outputs_arr();
return Ok(wire);
}

Err(PytketDecodeErrorInner::CannotTranslateWire {
wire,
initial_type: initial_type.to_string(),
target_type: target_type.to_string(),
context: None,
}
.wrap())
}
}

#[cfg(test)]
Expand Down Expand Up @@ -161,10 +243,10 @@ mod tests {
}

#[rstest::rstest]
#[case::empty(SumType::new_unary(0).into(), Some(RegisterCount::default()))]
#[case::empty(SumType::new_unary(0).into(), None)]
#[case::native_bool(SumType::new_unary(2).into(), Some(RegisterCount::only_bits(1)))]
#[case::simple(bool_t(), Some(RegisterCount::only_bits(1)))]
#[case::tuple(SumType::new_tuple(vec![bool_t(), qb_t(), bool_t()]).into(), Some(RegisterCount::new(1, 2, 0)))]
#[case::tuple(SumType::new_tuple(vec![bool_t(), qb_t(), bool_t(), SumType::new_unary(1).into()]).into(), Some(RegisterCount::new(1, 2, 0)))]
#[case::unsupported(SumType::new([vec![bool_t(), qb_t()], vec![bool_t()]]).into(), None)]
fn test_translations(
translator_set: TypeTranslatorSet,
Expand Down
46 changes: 32 additions & 14 deletions tket/src/serialize/pytket/decoder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ mod wires;

use hugr::extension::ExtensionRegistry;
use hugr::hugr::hugrmut::HugrMut;
use hugr::std_extensions::arithmetic::float_types::float64_type;
pub use param::{LoadedParameter, ParameterType};
pub use tracked_elem::{TrackedBit, TrackedQubit};
pub use wires::TrackedWires;
Expand All @@ -18,7 +19,7 @@ use std::sync::Arc;
use hugr::builder::{BuildHandle, Container, DFGBuilder, Dataflow, FunctionBuilder, SubContainer};
use hugr::extension::prelude::{bool_t, qb_t};
use hugr::ops::handle::{DataflowOpID, NodeHandle};
use hugr::ops::{OpParent, OpTrait, OpType, Value, DFG};
use hugr::ops::{OpParent, OpTrait, OpType, DFG};
use hugr::types::{Signature, Type, TypeRow};
use hugr::{Hugr, HugrView, Node, OutgoingPort, Wire};
use tracked_elem::{TrackedBitId, TrackedQubitId};
Expand Down Expand Up @@ -272,14 +273,16 @@ impl<'h> PytketDecoderContext<'h> {
wire_tracker.register_input_parameter(LoadedParameter::rotation(wire), param)?;
}

// Any additional qubits or bits required by the circuit get initialized to |0> / false.
// Any additional qubits or bits required by the circuit are registered
// in the tracker without a wire being created.
//
// We'll lazily initialize them with a QAlloc or a LoadConstant
// operation if necessary.
for q in qubits {
let q_wire = dfg.add_dataflow_op(TketOp::QAlloc, []).unwrap().out_wire(0);
wire_tracker.track_wire(q_wire, q.ty(), [q], [])?;
wire_tracker.track_qubit(q.pytket_register_arc(), Some(q.reg_hash()))?;
}
for b in bits {
let b_wire = dfg.add_load_value(Value::false_val());
wire_tracker.track_wire(b_wire, b.ty(), [], [b])?;
wire_tracker.track_bit(b.pytket_register_arc(), Some(b.reg_hash()))?;
}

wire_tracker.compute_output_permutation(&serialcirc.implicit_permutation);
Expand Down Expand Up @@ -341,7 +344,8 @@ impl<'h> PytketDecoderContext<'h> {
let found_wire = self
.wire_tracker
.find_typed_wire(
self.config(),
&self.config,
&mut self.builder,
ty,
&mut qubits,
&mut bits,
Expand Down Expand Up @@ -369,7 +373,14 @@ impl<'h> PytketDecoderContext<'h> {
let wire = match found_wire {
FoundWire::Register(wire) => wire.wire(),

FoundWire::Parameter(param) => param.wire(),
FoundWire::Parameter(param) => {
let param_ty = if ty == &float64_type() {
ParameterType::FloatHalfTurns
} else {
ParameterType::Rotation
};
param.with_type(param_ty, &mut self.builder).wire()
}
FoundWire::Unsupported { .. } => {
// Disconnected port with an unsupported type. We just skip
// it, since it must have been disconnected in the original
Expand Down Expand Up @@ -417,7 +428,8 @@ impl<'h> PytketDecoderContext<'h> {
for q in qubits.iter() {
let mut qubit_args: &[TrackedQubit] = std::slice::from_ref(q);
let Ok(FoundWire::Register(wire)) = self.wire_tracker.find_typed_wire(
self.config(),
&self.config,
&mut self.builder,
&qb_type,
&mut qubit_args,
&mut bit_args,
Expand Down Expand Up @@ -540,14 +552,20 @@ impl<'h> PytketDecoderContext<'h> {
/// - [`PytketDecodeErrorInner::UnexpectedInputType`] if a type in `types` cannot be mapped to a [`RegisterCount`]
/// - [`PytketDecodeErrorInner::NoMatchingWire`] if there is no wire with the requested type for the given qubit/bit arguments.
pub fn find_typed_wires(
&self,
&mut self,
types: &[Type],
qubit_args: &[TrackedQubit],
bit_args: &[TrackedBit],
params: &[LoadedParameter],
) -> Result<TrackedWires, PytketDecodeError> {
self.wire_tracker
.find_typed_wires(self.config(), types, qubit_args, bit_args, params)
self.wire_tracker.find_typed_wires(
&self.config,
&mut self.builder,
types,
qubit_args,
bit_args,
params,
)
}

/// Connects the input ports of a node using a list of input qubits, bits,
Expand Down Expand Up @@ -663,8 +681,8 @@ impl<'h> PytketDecoderContext<'h> {
}

// Gather the input wires, with the types needed by the operation.
let input_wires =
self.find_typed_wires(sig.input_types(), input_qubits, input_bits, params)?;
let input_types = sig.input_types().to_vec();
let input_wires = self.find_typed_wires(&input_types, input_qubits, input_bits, params)?;
debug_assert_eq!(op_input_count, input_wires.register_count());

for (input_idx, wire) in input_wires.wires().enumerate() {
Expand Down
Loading
Loading