Skip to content

Commit 5a458cf

Browse files
authored
fix(encoded-circ): Track unsupported wires between input and output (#1224)
Depends on #1211 These are not represented in the pytket circuit, and cannot be encoded into a SiblingSubgraph. We just track them as an additional item in the `EncodedCircuit`'s `EncodedCircuitInfo`. Note that the info is only store for circuits we encode directly, and not for nested regions inside circuit boxes. This means that the info is lost when encoding those. I added a commented-out test with a TODO to fix that (we'll need some extra plumbing to match circ boxes to external metadata). drive-by: Make sure the encoder's WireTracker stores the input parameter names, and passes it along. I'm ignoring breaking changes since they only affect unpublished code.
1 parent ad16531 commit 5a458cf

File tree

7 files changed

+210
-46
lines changed

7 files changed

+210
-46
lines changed

tket/src/serialize/pytket.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -133,8 +133,8 @@ impl TKETDecode for SerialCircuit {
133133
options: DecodeOptions,
134134
) -> Result<Node, Self::DecodeError> {
135135
let mut decoder = PytketDecoderContext::new(self, hugr, target, options, None)?;
136-
decoder.run_decoder(&self.commands, None)?;
137-
Ok(decoder.finish()?.node())
136+
decoder.run_decoder(&self.commands, None, &[])?;
137+
Ok(decoder.finish(&[])?.node())
138138
}
139139

140140
fn encode(circuit: &Circuit, options: EncodeOptions) -> Result<Self, Self::EncodeError> {

tket/src/serialize/pytket/circuit.rs

Lines changed: 27 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ use std::collections::{HashMap, VecDeque};
44
use std::ops::{Index, IndexMut};
55
use std::sync::Arc;
66

7-
use hugr::core::HugrNode;
7+
use hugr::core::{HugrNode, IncomingPort, OutgoingPort};
88
use hugr::hugr::hugrmut::HugrMut;
99
use hugr::ops::handle::NodeHandle;
1010
use hugr::ops::{OpParent, OpTag, OpTrait};
@@ -59,6 +59,11 @@ pub(super) struct EncodedCircuitInfo {
5959
/// as a pytket command, and has no qubit/bits in its boundary that could be
6060
/// used to emit an opaque barrier command in the [`serial_circuit`].
6161
pub extra_subgraph: Option<SubgraphId>,
62+
/// List of wires that directly connected the input node to the output node in the encoded region,
63+
/// and were not encoded in [`serial_circuit`].
64+
///
65+
/// We just store the input nodes's output port and output node's input port here.
66+
pub straight_through_wires: Vec<StraightThroughWire>,
6267
/// List of parameters in the pytket circuit in the order they appear in the
6368
/// hugr input.
6469
///
@@ -67,10 +72,22 @@ pub(super) struct EncodedCircuitInfo {
6772
pub input_params: Vec<String>,
6873
/// List of output parameter expressions found at the end of the encoded region.
6974
//
70-
// TODO: The decoder does not currently connect these.
75+
// TODO: The decoder does not currently connect these, everything that
76+
// _produces_ a parameter gets included in unsupported subgraphs instead.
7177
pub output_params: Vec<String>,
7278
}
7379

80+
/// A wire stored in the [`EncodedCircuitInfo`] that directly connected the
81+
/// input node to the output node in the encoded region, and was not encoded in
82+
/// the pytket circuit.
83+
#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq)]
84+
pub(super) struct StraightThroughWire {
85+
/// Source port of the wire in the input node.
86+
pub input_source: OutgoingPort,
87+
/// Target port of the wire in the output node.
88+
pub output_target: IncomingPort,
89+
}
90+
7491
impl EncodedCircuit<Node> {
7592
/// Encode a HugrView into a [`EncodedCircuit`].
7693
///
@@ -162,8 +179,12 @@ impl EncodedCircuit<Node> {
162179
options,
163180
Some(&self.opaque_subgraphs),
164181
)?;
165-
decoder.run_decoder(&encoded.serial_circuit.commands, encoded.extra_subgraph)?;
166-
let decoded_node = decoder.finish()?.node();
182+
decoder.run_decoder(
183+
&encoded.serial_circuit.commands,
184+
encoded.extra_subgraph,
185+
&encoded.straight_through_wires,
186+
)?;
187+
let decoded_node = decoder.finish(&encoded.output_params)?.node();
167188

168189
// Replace the region with the decoded function.
169190
//
@@ -312,8 +333,8 @@ impl<Node: HugrNode> EncodedCircuit<Node> {
312333

313334
let mut decoder =
314335
PytketDecoderContext::new(serial_circuit, &mut hugr, target, options, None)?;
315-
decoder.run_decoder(&serial_circuit.commands, None)?;
316-
decoder.finish()?;
336+
decoder.run_decoder(&serial_circuit.commands, None, &[])?;
337+
decoder.finish(&[])?;
317338
Ok(hugr)
318339
}
319340

tket/src/serialize/pytket/decoder.rs

Lines changed: 36 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ use super::{
3333
METADATA_Q_REGISTERS,
3434
};
3535
use crate::extension::rotation::rotation_type;
36+
use crate::serialize::pytket::circuit::StraightThroughWire;
3637
use crate::serialize::pytket::config::PytketDecoderConfig;
3738
use crate::serialize::pytket::decoder::wires::WireTracker;
3839
use crate::serialize::pytket::extension::{build_opaque_tket_op, RegisterCount};
@@ -293,7 +294,12 @@ impl<'h> PytketDecoderContext<'h> {
293294
///
294295
/// The original Hugr entrypoint is _not_ modified, it must be set by the
295296
/// caller if required.
296-
pub(super) fn finish(mut self) -> Result<Node, PytketDecodeError> {
297+
///
298+
/// # Arguments
299+
///
300+
/// - `output_params`: A list of output parameter expressions to associate
301+
/// with the region's outputs.
302+
pub(super) fn finish(mut self, output_params: &[String]) -> Result<Node, PytketDecodeError> {
297303
// Order the final wires according to the serial circuit register order.
298304
let known_qubits = self
299305
.wire_tracker
@@ -304,27 +310,34 @@ impl<'h> PytketDecoderContext<'h> {
304310
let mut qubits = known_qubits.as_slice();
305311
let mut bits = known_bits.as_slice();
306312

313+
// Load the output parameter expressions.
314+
let output_params = output_params
315+
.iter()
316+
.map(|p| self.load_half_turns(p))
317+
.collect_vec();
318+
let mut params: &[LoadedParameter] = &output_params;
319+
307320
let function_type = self
308321
.builder
309322
.hugr()
310323
.get_optype(self.builder.container_node())
311324
.inner_function_type()
312325
.unwrap();
313326
let expected_output_types = function_type.output_types().iter().cloned().collect_vec();
314-
let output_node = self.builder.output().node();
327+
let [_, output_node] = self.builder.io();
315328

316329
for (ty, port) in expected_output_types
317330
.iter()
318331
.zip(self.builder.hugr().node_inputs(output_node).collect_vec())
319332
{
320333
// If the region's output is already connected, leave it alone.
321-
// (It's a wire from an unsupported operation)
334+
// (It's a wire from an unsupported operation, or was a connected
335+
// straight through wire)
322336
if self.builder.hugr().is_linked(output_node, port) {
323337
continue;
324338
}
325339

326340
// Otherwise, get the tracked wire.
327-
let mut params: &[LoadedParameter] = &[];
328341
let found_wire = self
329342
.wire_tracker
330343
.find_typed_wire(
@@ -428,21 +441,40 @@ impl<'h> PytketDecoderContext<'h> {
428441
/// - `commands`: The list of pytket commands to decode.
429442
/// - `extra_subgraph`: An additional subgraph of the original Hugr that was
430443
/// not encoded as a pytket command, and must be decoded independently.
444+
/// - `straight_through_wires`: A list of wires that directly connected the
445+
/// input node to the output node in the original region, and were not
446+
/// encoded in the pytket circuit or unsupported graphs.
447+
/// (They cannot be encoded in `extra_subgraph`).
431448
pub(super) fn run_decoder(
432449
&mut self,
433450
commands: &[circuit_json::Command],
434451
extra_subgraph: Option<SubgraphId>,
452+
straight_through_wires: &[StraightThroughWire],
435453
) -> Result<(), PytketDecodeError> {
436454
let config = self.config().clone();
437455
for com in commands {
438456
let op_type = com.op.op_type;
439457
self.process_command(com, config.as_ref())
440458
.map_err(|e| e.pytket_op(&op_type))?;
441459
}
460+
461+
// Add additional subgraphs if not encoded in commands.
442462
if let Some(subgraph_id) = extra_subgraph {
443463
self.insert_external_subgraph(subgraph_id, &[], &[], &[])
444464
.map_err(|e| e.hugr_op("External subgraph"))?;
445465
}
466+
467+
// Add wires from the input node to the output node that didn't get encoded in commands.
468+
let [input_node, output_node] = self.builder.io();
469+
for StraightThroughWire {
470+
input_source,
471+
output_target,
472+
} in straight_through_wires
473+
{
474+
self.builder
475+
.hugr_mut()
476+
.connect(input_node, *input_source, output_node, *output_target);
477+
}
446478
Ok(())
447479
}
448480

tket/src/serialize/pytket/encoder.rs

Lines changed: 28 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -225,9 +225,6 @@ impl<H: HugrView> PytketEncoderContext<H> {
225225
region: H::Node,
226226
) -> Result<(EncodedCircuitInfo, OpaqueSubgraphs<H::Node>), PytketEncodeError<H::Node>> {
227227
// Add any remaining unsupported nodes
228-
//
229-
// TODO: Test that opaque subgraphs that don't affect any qubit/bit registers
230-
// are correctly encoded in pytket commands.
231228
let mut extra_subgraph: Option<BTreeSet<H::Node>> = None;
232229
while !self.unsupported.is_empty() {
233230
let node = self.unsupported.iter().next().unwrap();
@@ -251,21 +248,22 @@ impl<H: HugrView> PytketEncoderContext<H> {
251248
self.opaque_subgraphs.register_opaque_subgraph(subgraph)
252249
});
253250

254-
let final_values = self.values.finish(circ, region)?;
251+
let tracker_result = self.values.finish(circ, region)?;
255252

256253
let mut ser = SerialCircuit::new(self.name, self.phase);
257254

258255
ser.commands = self.commands;
259-
ser.qubits = final_values.qubits.into_iter().map_into().collect();
260-
ser.bits = final_values.bits.into_iter().map_into().collect();
261-
ser.implicit_permutation = final_values.qubit_permutation;
256+
ser.qubits = tracker_result.qubits.into_iter().map_into().collect();
257+
ser.bits = tracker_result.bits.into_iter().map_into().collect();
258+
ser.implicit_permutation = tracker_result.qubit_permutation;
262259
ser.number_of_ws = None;
263260

264261
let info = EncodedCircuitInfo {
265262
serial_circuit: ser,
266-
input_params: final_values.params,
267-
output_params: vec![],
263+
input_params: tracker_result.input_params,
264+
output_params: tracker_result.params,
268265
extra_subgraph,
266+
straight_through_wires: tracker_result.straight_through_wires,
269267
};
270268

271269
Ok((info, self.opaque_subgraphs))
@@ -601,6 +599,7 @@ impl<H: HugrView> PytketEncoderContext<H> {
601599
});
602600
let subgraph_incoming_ports: IncomingPorts<H::Node> = subgraph.incoming_ports().clone();
603601
let subgraph_outgoing_ports: OutgoingPorts<H::Node> = subgraph.outgoing_ports().clone();
602+
let subgraph_signature = subgraph.signature(circ.hugr());
604603

605604
// Encode a payload referencing the subgraph in the Hugr.
606605
let subgraph_id = self.opaque_subgraphs.register_opaque_subgraph(subgraph);
@@ -641,7 +640,14 @@ impl<H: HugrView> PytketEncoderContext<H> {
641640
// Output parameters are mapped to a fresh variable, that can be tracked
642641
// back to the encoded subcircuit's function name.
643642
let mut out_param_count = 0;
644-
for (out_node, out_port) in &subgraph_outgoing_ports {
643+
for ((out_node, out_port), ty) in subgraph_outgoing_ports
644+
.iter()
645+
.zip(subgraph_signature.output().iter())
646+
{
647+
if self.config().type_to_pytket(ty).is_none() {
648+
// Do not try to register ports with unsupported types.
649+
continue;
650+
}
645651
let new_outputs = self.register_port_output(
646652
*out_node,
647653
*out_port,
@@ -734,6 +740,7 @@ impl<H: HugrView> PytketEncoderContext<H> {
734740
///
735741
// TODO: Support output parameters in subcircuits. This may require
736742
// substituting variables in the parameter expressions.
743+
#[expect(unused)]
737744
fn emit_subcircuit(
738745
&mut self,
739746
node: H::Node,
@@ -767,6 +774,7 @@ impl<H: HugrView> PytketEncoderContext<H> {
767774
///
768775
// TODO: Support output parameters in subcircuits. This may require
769776
// substituting variables in the parameter expressions.
777+
#[expect(unused)]
770778
fn emit_function_call(
771779
&mut self,
772780
node: H::Node,
@@ -902,6 +910,15 @@ impl<H: HugrView> PytketEncoderContext<H> {
902910
return Ok(EncodeStatus::Success);
903911
}
904912
}
913+
// TODO: DFG and function call emissions are temporarily disabled,
914+
// since we cannot track additional metadata associated with the
915+
// nested circuit in a `CircuitBox` as we'd do for the root one in
916+
// [`EncodedCircuitInfo`].
917+
//
918+
// See the `unsupported_extras_in_circ_box` case in
919+
// `tests::encoded_circuit_roundtrip` for a failing case when this
920+
// is enabled.
921+
/*
905922
OpType::DFG(_) => return self.emit_subcircuit(node, circ),
906923
OpType::Call(call) => {
907924
let (fn_node, _) = circ
@@ -914,6 +931,7 @@ impl<H: HugrView> PytketEncoderContext<H> {
914931
return Ok(EncodeStatus::Success);
915932
}
916933
}
934+
*/
917935
OpType::Input(_) | OpType::Output(_) => {
918936
// I/O nodes are handled by the container's encoder.
919937
return Ok(EncodeStatus::Success);

0 commit comments

Comments
 (0)