Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 30 additions & 29 deletions tket/src/serialize/pytket/decoder/subgraph.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,8 @@ use std::sync::Arc;

use hugr::builder::Container;
use hugr::hugr::hugrmut::{HugrMut, InsertedForest};
use hugr::hugr::views::SiblingSubgraph;
use hugr::ops::{OpTag, OpTrait};
use hugr::types::{Signature, Type};
use hugr::types::Type;
use hugr::{Hugr, HugrView, Node, OutgoingPort, PortIndex, Wire};
use hugr_core::hugr::internal::HugrMutInternals;
use itertools::Itertools;
Expand All @@ -16,7 +15,9 @@ use crate::serialize::pytket::decoder::{
DecodeStatus, FoundWire, LoadedParameter, PytketDecoderContext, TrackedBit, TrackedQubit,
};
use crate::serialize::pytket::extension::RegisterCount;
use crate::serialize::pytket::opaque::{EncodedEdgeID, OpaqueSubgraphPayload, SubgraphId};
use crate::serialize::pytket::opaque::{
EncodedEdgeID, OpaqueSubgraph, OpaqueSubgraphPayload, SubgraphId,
};
use crate::serialize::pytket::{PytketDecodeError, PytketDecodeErrorInner, PytketDecoderConfig};

impl<'h> PytketDecoderContext<'h> {
Expand Down Expand Up @@ -64,13 +65,8 @@ impl<'h> PytketDecoderContext<'h> {
let Some(subgraph) = self.opaque_subgraphs.and_then(|s| s.get(id)) else {
return Err(PytketDecodeErrorInner::OpaqueSubgraphNotFound { id }.wrap());
};
let signature = subgraph.signature(self.builder.hugr());

let old_parent = self
.builder
.hugr()
.get_parent(subgraph.nodes()[0])
.ok_or_else(|| PytketDecodeErrorInner::ExternalSubgraphWasModified { id }.wrap())?;
let old_parent = subgraph.region();
if !OpTag::DataflowParent.is_superset(self.builder.hugr().get_optype(old_parent).tag()) {
return Err(PytketDecodeErrorInner::ExternalSubgraphWasModified { id }.wrap());
}
Expand All @@ -82,35 +78,36 @@ impl<'h> PytketDecoderContext<'h> {
}

self.rewire_external_subgraph_inputs(
subgraph, qubits, bits, params, old_parent, new_parent, &signature,
subgraph, qubits, bits, params, old_parent, new_parent,
)?;

self.rewire_external_subgraph_outputs(
subgraph, qubits, bits, old_parent, new_parent, &signature,
)?;
self.rewire_external_subgraph_outputs(subgraph, qubits, bits, old_parent, new_parent)?;

Ok(DecodeStatus::Success)
}

/// Rewire the inputs of an external subgraph moved to the new region.
///
/// Helper for [`Self::insert_external_subgraph`].
#[expect(clippy::too_many_arguments)]
fn rewire_external_subgraph_inputs(
&mut self,
subgraph: &SiblingSubgraph<Node>,
subgraph: &OpaqueSubgraph<Node>,
mut input_qubits: &[TrackedQubit],
mut input_bits: &[TrackedBit],
mut input_params: &[LoadedParameter],
old_parent: Node,
new_parent: Node,
signature: &Signature,
) -> Result<(), PytketDecodeError> {
let old_input = self.builder.hugr().get_io(old_parent).unwrap()[0];
let new_input = self.builder.hugr().get_io(new_parent).unwrap()[0];

// Reconnect input wires from parts of/nodes in the region that have been encoded into pytket.
for (ty, targets) in signature.input().iter().zip_eq(subgraph.incoming_ports()) {
for (ty, (tgt_node, tgt_port)) in subgraph
.signature()
.input()
.iter()
.zip_eq(subgraph.incoming_ports())
{
let found_wire = self.wire_tracker.find_typed_wire(
self.config(),
ty,
Expand All @@ -125,9 +122,11 @@ impl<'h> PytketDecoderContext<'h> {
FoundWire::Parameter(param) => param.wire(),
FoundWire::Unsupported { .. } => {
// Input port with an unsupported type.
let Some((neigh, neigh_port)) = targets.first().and_then(|(tgt, port)| {
self.builder.hugr().single_linked_output(*tgt, *port)
}) else {
let Some((neigh, neigh_port)) = self
.builder
.hugr()
.single_linked_output(*tgt_node, *tgt_port)
else {
// The input was disconnected. We just skip it.
// (This is the case for unused other-ports)
continue;
Expand All @@ -143,11 +142,9 @@ impl<'h> PytketDecoderContext<'h> {
}
};

for (tgt, port) in targets {
self.builder
.hugr_mut()
.connect(wire.node(), wire.source(), *tgt, *port);
}
self.builder
.hugr_mut()
.connect(wire.node(), wire.source(), *tgt_node, *tgt_port);
}

Ok(())
Expand All @@ -161,20 +158,24 @@ impl<'h> PytketDecoderContext<'h> {
/// Helper for [`Self::insert_external_subgraph`].
fn rewire_external_subgraph_outputs(
&mut self,
subgraph: &SiblingSubgraph<Node>,
subgraph: &OpaqueSubgraph<Node>,
qubits: &[TrackedQubit],
bits: &[TrackedBit],
old_parent: Node,
new_parent: Node,
signature: &Signature,
) -> Result<(), PytketDecodeError> {
let old_output = self.builder.hugr().get_io(old_parent).unwrap()[1];
let new_output = self.builder.hugr().get_io(new_parent).unwrap()[1];

let mut output_qubits = qubits;
let mut output_bits = bits;

for (ty, (src, src_port)) in signature.output().iter().zip_eq(subgraph.outgoing_ports()) {
for (ty, (src, src_port)) in subgraph
.signature()
.output()
.iter()
.zip_eq(subgraph.outgoing_ports())
{
// Output wire from the subgraph. Depending on the type, we may need
// to track new qubits and bits, re-connect it to some output, or
// leave it untouched.
Expand All @@ -191,7 +192,7 @@ impl<'h> PytketDecoderContext<'h> {
if wire_qubits.is_none() || wire_bits.is_none() {
return Err(make_unexpected_node_out_error(
self.config(),
signature.output().iter(),
subgraph.signature().output().iter(),
qubits.len(),
bits.len(),
));
Expand Down
86 changes: 32 additions & 54 deletions tket/src/serialize/pytket/encoder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,6 @@ mod unsupported_tracker;
mod value_tracker;

use hugr::core::HugrNode;
use hugr::hugr::views::sibling_subgraph::{IncomingPorts, OutgoingPorts};
use hugr::hugr::views::SiblingSubgraph;
use hugr_core::hugr::internal::PortgraphNodeMap;
use tket_json_rs::clexpr::InputClRegister;
use tket_json_rs::opbox::BoxID;
Expand Down Expand Up @@ -34,7 +32,9 @@ use crate::circuit::Circuit;
use crate::serialize::pytket::circuit::EncodedCircuitInfo;
use crate::serialize::pytket::config::PytketEncoderConfig;
use crate::serialize::pytket::extension::RegisterCount;
use crate::serialize::pytket::opaque::{OpaqueSubgraphPayload, OPGROUP_OPAQUE_HUGR};
use crate::serialize::pytket::opaque::{
OpaqueSubgraph, OpaqueSubgraphPayload, OPGROUP_OPAQUE_HUGR,
};

/// The state of an in-progress [`SerialCircuit`] being built from a [`Circuit`].
#[derive(derive_more::Debug)]
Expand Down Expand Up @@ -228,25 +228,25 @@ impl<H: HugrView> PytketEncoderContext<H> {
let mut extra_subgraph: Option<BTreeSet<H::Node>> = None;
while !self.unsupported.is_empty() {
let node = self.unsupported.iter().next().unwrap();
let opaque_subgraphs = self.unsupported.extract_component(node);
match self.emit_unsupported(opaque_subgraphs.clone(), circ) {
let opaque_subgraphs = self.unsupported.extract_component(node, circ.hugr())?;
match self.emit_unsupported(&opaque_subgraphs, circ) {
Ok(()) => (),
Err(PytketEncodeError::UnsupportedSubgraphHasNoRegisters {}) => {
// We'll store the nodes in the `extra_subgraph` field of the `EncodedCircuitInfo`.
// So the decoder can reconstruct the original subgraph.
extra_subgraph
.get_or_insert_default()
.extend(opaque_subgraphs);
.extend(opaque_subgraphs.nodes().iter().cloned());
}
Err(e) => return Err(e),
}
}
let extra_subgraph = extra_subgraph.map(|nodes| {
let subgraph =
SiblingSubgraph::try_from_nodes(nodes.into_iter().collect_vec(), circ.hugr())
.expect("Failed to create subgraph from unsupported nodes");
self.opaque_subgraphs.register_opaque_subgraph(subgraph)
});
let extra_subgraph = extra_subgraph
.map(|nodes| -> Result<_, PytketEncodeError<H::Node>> {
let subgraph = OpaqueSubgraph::try_from_nodes(nodes, circ.hugr())?;
Ok(self.opaque_subgraphs.register_opaque_subgraph(subgraph))
})
.transpose()?;

let tracker_result = self.values.finish(circ, region)?;

Expand Down Expand Up @@ -306,8 +306,10 @@ impl<H: HugrView> PytketEncoderContext<H> {
//
// We need to emit the unsupported node here before returning the values.
if self.unsupported.is_unsupported(wire.node()) {
let unsupported_nodes = self.unsupported.extract_component(wire.node());
self.emit_unsupported(unsupported_nodes, circ)?;
let unsupported_nodes = self
.unsupported
.extract_component(wire.node(), circ.hugr())?;
self.emit_unsupported(&unsupported_nodes, circ)?;
debug_assert!(!self.unsupported.is_unsupported(wire.node()));
return self.get_wire_values(wire, circ);
}
Expand Down Expand Up @@ -340,7 +342,7 @@ impl<H: HugrView> PytketEncoderContext<H> {
node: H::Node,
circ: &Circuit<H>,
) -> Result<TrackedValues, PytketEncodeError<H::Node>> {
self.get_input_values_internal(node, circ, |_| true)
self.get_input_values_internal(node, circ, |_| true)?
.try_into_tracked_values()
}

Expand All @@ -354,7 +356,7 @@ impl<H: HugrView> PytketEncoderContext<H> {
node: H::Node,
circ: &Circuit<H>,
wire_filter: impl Fn(Wire<H::Node>) -> bool,
) -> NodeInputValues<H::Node> {
) -> Result<NodeInputValues<H::Node>, PytketEncodeError<H::Node>> {
let mut tracked_values = TrackedValues::default();
let mut unknown_values = Vec::new();

Expand All @@ -380,15 +382,13 @@ impl<H: HugrView> PytketEncoderContext<H> {
Err(PytketEncodeError::OpEncoding(PytketEncodeOpError::WireHasNoValues {
wire,
})) => unknown_values.push(wire),
Err(e) => panic!(
"get_wire_values should only return WireHasNoValues errors, but got: {e}"
),
Err(e) => return Err(e),
}
}
NodeInputValues {
Ok(NodeInputValues {
tracked_values,
unknown_values,
}
})
}

/// Helper to emit a new tket1 command corresponding to a single HUGR node.
Expand Down Expand Up @@ -575,49 +575,28 @@ impl<H: HugrView> PytketEncoderContext<H> {
///
/// ## Arguments
///
/// - `unsupported_nodes`: The list of nodes to encode as an opaque subgraph.
/// - `subgraph`: The subgraph of unsupported nodes to encode as an opaque subgraph.
/// - `circ`: The circuit containing the unsupported nodes.
fn emit_unsupported(
&mut self,
unsupported_nodes: BTreeSet<H::Node>,
subgraph: &OpaqueSubgraph<H::Node>,
circ: &Circuit<H>,
) -> Result<(), PytketEncodeError<H::Node>> {
let subcircuit_id = format!("tk{}", unsupported_nodes.iter().min().unwrap());

// TODO: Use a cached topo checker here instead of traversing the full graph each time we create a `SiblingSubgraph`.
//
// TopoConvexChecker likes to borrow the hugr, so it'd be too invasive to store in the `Context`.
let subgraph = SiblingSubgraph::try_from_nodes(
unsupported_nodes.iter().cloned().collect_vec(),
circ.hugr(),
)
.unwrap_or_else(|e| {
panic!(
"Failed to create subgraph from unsupported nodes [{}]: {e}",
unsupported_nodes.iter().join(", ")
)
});
let subgraph_incoming_ports: IncomingPorts<H::Node> = subgraph.incoming_ports().clone();
let subgraph_outgoing_ports: OutgoingPorts<H::Node> = subgraph.outgoing_ports().clone();
let subgraph_signature = subgraph.signature(circ.hugr());

// Encode a payload referencing the subgraph in the Hugr.
let subgraph_id = self.opaque_subgraphs.register_opaque_subgraph(subgraph);
let subgraph_id = self
.opaque_subgraphs
.register_opaque_subgraph(subgraph.clone());
let payload = OpaqueSubgraphPayload::new_external(subgraph_id);

// Collects the input values for the subgraph.
//
// The [`UnsupportedTracker`] ensures that at this point all local input wires must come from
// already-encoded nodes, and not from other unsupported nodes not in `unsupported_nodes`.
let mut op_values = TrackedValues::default();
for incoming in subgraph_incoming_ports.iter() {
let Some((first_node, first_port)) = incoming.first() else {
continue;
};

for (node, port) in subgraph.incoming_ports().iter() {
let (neigh, neigh_out) = circ
.hugr()
.single_linked_output(*first_node, *first_port)
.single_linked_output(*node, *port)
.expect("Dataflow input port should have a single neighbour");
let wire = Wire::new(neigh, neigh_out);

Expand All @@ -640,9 +619,10 @@ impl<H: HugrView> PytketEncoderContext<H> {
// Output parameters are mapped to a fresh variable, that can be tracked
// back to the encoded subcircuit's function name.
let mut out_param_count = 0;
for ((out_node, out_port), ty) in subgraph_outgoing_ports
for ((out_node, out_port), ty) in subgraph
.outgoing_ports()
.iter()
.zip(subgraph_signature.output().iter())
.zip(subgraph.signature().output().iter())
{
if self.config().type_to_pytket(ty).is_none() {
// Do not try to register ports with unsupported types.
Expand All @@ -658,9 +638,7 @@ impl<H: HugrView> PytketEncoderContext<H> {
EmitCommandOptions::new().output_params(|p| {
let range = out_param_count..out_param_count + p.expected_count;
out_param_count += p.expected_count;
range
.map(|i| format!("{subcircuit_id}_out{i}"))
.collect_vec()
range.map(|i| format!("{subgraph_id}_out{i}")).collect_vec()
}),
)?;
op_values.append(new_outputs);
Expand Down
10 changes: 8 additions & 2 deletions tket/src/serialize/pytket/encoder/unsupported_tracker.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@ use hugr::core::HugrNode;
use hugr::HugrView;
use petgraph::unionfind::UnionFind;

use crate::serialize::pytket::opaque::OpaqueSubgraph;
use crate::serialize::pytket::PytketEncodeError;
use crate::Circuit;

/// A structure for tracking nodes in the hugr that cannot be encoded as TKET1
Expand Down Expand Up @@ -75,7 +77,11 @@ impl<N: HugrNode> UnsupportedTracker<N> {
/// Once a component has been extracted, no new nodes can be added to it and
/// calling [`UnsupportedTracker::record_node`] will use a new component
/// instead.
pub fn extract_component(&mut self, node: N) -> BTreeSet<N> {
pub fn extract_component(
&mut self,
node: N,
hugr: &impl HugrView<Node = N>,
) -> Result<OpaqueSubgraph<N>, PytketEncodeError<N>> {
let node_data = self.nodes.remove(&node).unwrap();
let component = node_data.component;
let representative = self.components.find_mut(component);
Expand All @@ -95,7 +101,7 @@ impl<N: HugrNode> UnsupportedTracker<N> {
self.nodes.remove(n);
}

nodes
OpaqueSubgraph::try_from_nodes(nodes, hugr)
}

/// Returns an iterator over the unextracted nodes in the tracker.
Expand Down
Loading