Skip to content

Commit fce8084

Browse files
authored
feat: Add bindings for tket1-passes from python (#1225)
Adds some pass definitions to `tket.passes`: - `normalize_guppy(Tk2Circuit, ...) -> Tk2Circuit` - `clifford_simp(Tk2Circuit, *, allow_swaps = True, traverse_subcircuits = True) -> Tk2Circuit` - `squash_phased_rz(Tk2Circuit, *, traverse_subcircuits = True) -> Tk2Circuit` The last two use the new bridge to the old tket C++ codebase to run optimisation on regions of the hugr. The API here not too ergonomic, we just expose it as a MVP while working on the API refresh. To call this from a guppy program output, we'll need to convert the hugr to the rust representation and back, ```py from guppylang import guppy from guppylang.std.angles import angle from guppylang.std.builtins import result from guppylang.std.quantum import h, rx, rz, measure, qubit from hugr import Hugr from tket.circuit import Tk2Circuit from tket.passes import normalize_guppy, squash_phasedx_rz @guppy def main() -> None: q = qubit() rz(q, angle(0.1)) h(q) rx(q, angle(-0.1)) h(q) b = measure(q) result("b", b) program = main.compile() compiler_state = Tk2Circuit.from_bytes(program.to_bytes()) compiler_state = normalize_guppy(compiler_state) compiler_state = squash_phasedx_rz(compiler_state) hugr = Hugr.from_str(compiler_state.to_str()) print(hugr.render_dot()) ``` drive-by: Add `Tk2Circuit.render_mermaid` method
1 parent 5a458cf commit fce8084

File tree

11 files changed

+289
-12
lines changed

11 files changed

+289
-12
lines changed

Cargo.lock

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

tket-py/Cargo.toml

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -27,16 +27,17 @@ tket = { path = "../tket", version = "0.16.0", features = [
2727
tket-qsystem = { path = "../tket-qsystem", version = "0.22.0" }
2828
tket1-passes = { path = "../tket1-passes", version = "0.0.0" }
2929

30-
serde = { workspace = true, features = ["derive"] }
31-
serde_json = { workspace = true }
32-
tket-json-rs = { workspace = true, features = ["pyo3"] }
33-
hugr = { workspace = true }
34-
pyo3 = { workspace = true, features = ["py-clone", "abi3-py310"] }
35-
num_cpus = { workspace = true }
3630
derive_more = { workspace = true, features = ["into", "from"] }
31+
hugr = { workspace = true }
3732
itertools = { workspace = true }
33+
num_cpus = { workspace = true }
3834
portmatching = { workspace = true }
35+
pyo3 = { workspace = true, features = ["py-clone", "abi3-py310"] }
36+
rayon = { workspace = true }
37+
serde = { workspace = true, features = ["derive"] }
38+
serde_json = { workspace = true }
3939
strum = { workspace = true }
40+
tket-json-rs = { workspace = true, features = ["pyo3"] }
4041

4142
[dev-dependencies]
4243
rstest = { workspace = true }

tket-py/src/circuit.rs

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ pub fn module(py: Python<'_>) -> PyResult<Bound<'_, PyModule>> {
4747
py.get_type::<PyHUGRSerializationError>(),
4848
)?;
4949
m.add("TK1EncodeError", py.get_type::<PyTk1EncodeError>())?;
50+
m.add("TK1DecodeError", py.get_type::<PyTK1DecodeError>())?;
5051

5152
Ok(m)
5253
}
@@ -78,13 +79,13 @@ create_py_exception!(
7879
create_py_exception!(
7980
tket::serialize::pytket::PytketEncodeError,
8081
PyTk1EncodeError,
81-
"Error type for the conversion between tket and tket1 operations."
82+
"Error encoding a HUGR region into a pytket circuit."
8283
);
8384

8485
create_py_exception!(
8586
tket::serialize::pytket::PytketDecodeError,
8687
PyTK1DecodeError,
87-
"Error type for the conversion between tket1 and tket operations."
88+
"Error decoding a HUGR region from a pytket circuit."
8889
);
8990

9091
/// Run the validation checks on a circuit.

tket-py/src/circuit/tk2circuit.rs

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ use derive_more::From;
2525
use hugr::{Hugr, HugrView, Wire};
2626
use serde::Serialize;
2727
use tket::circuit::CircuitHash;
28+
use tket::modifier::qubit_types_utils::contain_qubit_term;
2829
use tket::passes::pytket::lower_to_pytket;
2930
use tket::passes::CircuitChunks;
3031
use tket::serialize::pytket::{DecodeOptions, EncodeOptions};
@@ -96,11 +97,17 @@ impl Tk2Circuit {
9697
}
9798

9899
/// Encode the circuit as a HUGR envelope.
99-
pub fn to_bytes(&self, config: Bound<'_, PyAny>) -> PyResult<Vec<u8>> {
100+
///
101+
/// If no config is given, it defaults to the default binary envelope.
102+
#[pyo3(signature = (config = None))]
103+
pub fn to_bytes(&self, config: Option<Bound<'_, PyAny>>) -> PyResult<Vec<u8>> {
100104
fn err(e: impl Display) -> PyErr {
101105
PyErr::new::<PyAttributeError, _>(format!("Could not encode circuit: {e}"))
102106
};
103-
let config = envelope_config_from_py(config)?;
107+
let config = match config {
108+
Some(cfg) => envelope_config_from_py(cfg)?,
109+
None => EnvelopeConfig::binary(),
110+
};
104111
let mut buf = Vec::new();
105112
self.circ.store(&mut buf, config).map_err(err)?;
106113
Ok(buf)
@@ -221,6 +228,7 @@ impl Tk2Circuit {
221228
pub fn circuit_cost<'py>(&self, cost_fn: &Bound<'py, PyAny>) -> PyResult<Bound<'py, PyAny>> {
222229
let py = cost_fn.py();
223230
let cost_fn = |op: &OpType| -> PyResult<PyCircuitCost> {
231+
// TODO: We should ignore non-tket operations instead.
224232
let Some(tk2_op) = op.cast::<TketOp>() else {
225233
let op_name = op.to_string();
226234
return Err(PyErr::new::<PyValueError, _>(format!(
@@ -303,6 +311,10 @@ impl Tk2Circuit {
303311
fn output_node(&self) -> PyNode {
304312
self.circ.output_node().into()
305313
}
314+
315+
fn render_mermaid(&self) -> String {
316+
self.circ.mermaid_string()
317+
}
306318
}
307319
impl Tk2Circuit {
308320
/// Tries to extract a Tk2Circuit from a python object.

tket-py/src/passes.rs

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
//! Passes for optimising circuits.
22
33
pub mod chunks;
4+
pub mod tket1;
45

56
use std::{cmp::min, convert::TryInto, fs, num::NonZeroUsize, path::PathBuf};
67

8+
use hugr::algorithms::ComposablePass;
79
use pyo3::{prelude::*, types::IntoPyDict};
810
use tket::optimiser::badger::BadgerOptions;
911
use tket::passes;
@@ -24,9 +26,13 @@ pub fn module(py: Python<'_>) -> PyResult<Bound<'_, PyModule>> {
2426
m.add_function(wrap_pyfunction!(greedy_depth_reduce, &m)?)?;
2527
m.add_function(wrap_pyfunction!(lower_to_pytket, &m)?)?;
2628
m.add_function(wrap_pyfunction!(badger_optimise, &m)?)?;
29+
m.add_function(wrap_pyfunction!(normalize_guppy, &m)?)?;
2730
m.add_class::<self::chunks::PyCircuitChunks>()?;
2831
m.add_function(wrap_pyfunction!(self::chunks::chunks, &m)?)?;
32+
m.add_function(wrap_pyfunction!(self::tket1::clifford_simp, &m)?)?;
33+
m.add_function(wrap_pyfunction!(self::tket1::squash_phasedx_rz, &m)?)?;
2934
m.add("PullForwardError", py.get_type::<PyPullForwardError>())?;
35+
m.add("TK1PassError", py.get_type::<tket1::PytketPassError>())?;
3036
Ok(m)
3137
}
3238

@@ -42,6 +48,50 @@ create_py_exception!(
4248
"Errors that can occur while removing high-level operations from HUGR intended to be encoded as a pytket circuit."
4349
);
4450

51+
create_py_exception!(
52+
tket::passes::guppy::NormalizeGuppyErrors,
53+
PyNormalizeGuppyError,
54+
"Errors from the Guppy normalization pass."
55+
);
56+
57+
/// Flatten the structure of a Guppy-generated program to enable additional optimisations.
58+
///
59+
/// This should normally be called first before other optimisations.
60+
///
61+
/// Parameters:
62+
/// - simplify_cfgs: Whether to simplify CFG control flow.
63+
/// - remove_tuple_untuple: Whether to remove tuple/untuple operations.
64+
/// - constant_folding: Whether to constant fold the program.
65+
/// - remove_dead_funcs: Whether to remove dead functions.
66+
/// - inline_dfgs: Whether to inline DFG operations.
67+
#[pyfunction]
68+
#[pyo3(signature = (circ, *, simplify_cfgs = true, remove_tuple_untuple = true, constant_folding = false, remove_dead_funcs = true, inline_dfgs = true))]
69+
fn normalize_guppy<'py>(
70+
circ: &Bound<'py, PyAny>,
71+
simplify_cfgs: bool,
72+
remove_tuple_untuple: bool,
73+
constant_folding: bool,
74+
remove_dead_funcs: bool,
75+
inline_dfgs: bool,
76+
) -> PyResult<Bound<'py, PyAny>> {
77+
let py = circ.py();
78+
try_with_circ(circ, |mut circ, typ| {
79+
let mut pass = tket::passes::NormalizeGuppy::default();
80+
81+
pass.simplify_cfgs(simplify_cfgs)
82+
.remove_tuple_untuple(remove_tuple_untuple)
83+
.constant_folding(constant_folding)
84+
.remove_dead_funcs(remove_dead_funcs)
85+
.inline_dfgs(inline_dfgs);
86+
87+
pass.run(circ.hugr_mut()).convert_pyerrs()?;
88+
89+
let circ = typ.convert(py, circ)?;
90+
PyResult::Ok(circ)
91+
})
92+
}
93+
94+
/// Pass which greedily commutes operations forwards in order to reduce depth.
4595
#[pyfunction]
4696
fn greedy_depth_reduce<'py>(circ: &Bound<'py, PyAny>) -> PyResult<(Bound<'py, PyAny>, u32)> {
4797
let py = circ.py();

tket-py/src/passes/tket1.rs

Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,102 @@
1+
//! Passes that call to tket1-passes using the tket-c-api.
2+
3+
use rayon::iter::ParallelIterator;
4+
use std::sync::Arc;
5+
6+
use pyo3::prelude::*;
7+
use tket::serialize::pytket::{EncodeOptions, EncodedCircuit};
8+
use tket::Circuit;
9+
use tket_qsystem::pytket::{qsystem_decoder_config, qsystem_encoder_config};
10+
11+
use crate::circuit::try_with_circ;
12+
use crate::utils::{create_py_exception, ConvertPyErr};
13+
14+
/// An optimisation pass that applies a number of rewrite rules for simplifying
15+
/// Clifford gate sequences, similar to Duncan & Fagan
16+
/// (https://arxiv.org/abs/1901.10114). Produces a circuit comprising TK1 gates
17+
/// and the two-qubit gate specified as the target.
18+
///
19+
/// Parameters:
20+
/// - allow_swaps: whether the rewriting may introduce implicit wire swaps.
21+
/// - traverse_subcircuits: Whether to apply the optimisation to nested
22+
/// subregions in the hugr too, rather than just the top-level region.
23+
//
24+
// TODO: We should also expose `target_gate` here, but the most appropriate
25+
// parameter type [`crate::ops::PyTketOp`] doesn't include `TK2` -.-
26+
#[pyfunction]
27+
#[pyo3(signature = (circ, *, allow_swaps = true, traverse_subcircuits = true))]
28+
pub fn clifford_simp<'py>(
29+
circ: &Bound<'py, PyAny>,
30+
allow_swaps: bool,
31+
traverse_subcircuits: bool,
32+
) -> PyResult<Bound<'py, PyAny>> {
33+
let py = circ.py();
34+
35+
try_with_circ(circ, |circ, typ| {
36+
let circ = run_tket1_pass(circ, traverse_subcircuits, |tk1_circ| {
37+
tk1_circ.clifford_simp(tket_json_rs::OpType::CX, allow_swaps)
38+
})?;
39+
40+
let circ = typ.convert(py, circ)?;
41+
PyResult::Ok(circ)
42+
})
43+
}
44+
45+
/// Squash single qubit gates into PhasedX and Rz gates. Also remove identity
46+
/// gates. Commute Rz gates to the back if possible.
47+
///
48+
/// Parameters:
49+
/// - traverse_subcircuits: Whether to apply the optimisation to nested
50+
/// subregions in the hugr too, rather than just the top-level region.
51+
#[pyfunction]
52+
#[pyo3(signature = (circ, *, traverse_subcircuits = true))]
53+
pub fn squash_phasedx_rz<'py>(
54+
circ: &Bound<'py, PyAny>,
55+
traverse_subcircuits: bool,
56+
) -> PyResult<Bound<'py, PyAny>> {
57+
let py = circ.py();
58+
59+
try_with_circ(circ, |circ, typ| {
60+
let circ = run_tket1_pass(circ, traverse_subcircuits, |tk1_circ| {
61+
tk1_circ.squash_phasedx_rz()
62+
})?;
63+
64+
let circ = typ.convert(py, circ)?;
65+
PyResult::Ok(circ)
66+
})
67+
}
68+
69+
fn run_tket1_pass<F>(mut circ: Circuit, traverse_subcircuits: bool, pass: F) -> PyResult<Circuit>
70+
where
71+
F: Fn(&mut tket1_passes::Tket1Circuit) -> Result<(), tket1_passes::PassError> + Send + Sync,
72+
{
73+
let mut encoded_circ = EncodedCircuit::new(
74+
&circ,
75+
EncodeOptions::new()
76+
.with_config(qsystem_encoder_config())
77+
.with_subcircuits(traverse_subcircuits),
78+
)
79+
.convert_pyerrs()?;
80+
81+
encoded_circ
82+
.par_iter_mut()
83+
.try_for_each(|(_, circ)| -> Result<(), tket1_passes::PassError> {
84+
let mut tk1_circ = tket1_passes::Tket1Circuit::from_serial_circuit(circ)?;
85+
pass(&mut tk1_circ)?;
86+
*circ = tk1_circ.to_serial_circuit()?;
87+
Ok(())
88+
})
89+
.convert_pyerrs()?;
90+
91+
encoded_circ
92+
.reassemble_inplace(circ.hugr_mut(), Some(Arc::new(qsystem_decoder_config())))
93+
.convert_pyerrs()?;
94+
95+
Ok(circ)
96+
}
97+
98+
create_py_exception!(
99+
tket1_passes::PassError,
100+
PytketPassError,
101+
"Error from a call to tket-c-api"
102+
);

tket-py/test/test_pass.py

Lines changed: 41 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,15 @@
11
from pytket import Circuit, OpType
22
from dataclasses import dataclass
33
from typing import Callable, Any
4-
from tket.passes import badger_pass, greedy_depth_reduce, chunks
4+
from tket.ops import TketOp
5+
from tket.passes import (
6+
badger_pass,
7+
greedy_depth_reduce,
8+
chunks,
9+
clifford_simp,
10+
normalize_guppy,
11+
squash_phasedx_rz,
12+
)
513
from tket.circuit import Tk2Circuit
614
from tket.pattern import Rule, RuleMatcher
715
import hypothesis.strategies as st
@@ -138,3 +146,35 @@ def test_multiple_rules():
138146

139147
out = circ.to_tket1()
140148
assert out == Circuit(3).CX(0, 1).X(0)
149+
150+
151+
def test_clifford_simp():
152+
c = Tk2Circuit(Circuit(4).CX(0, 2).CX(1, 2).CX(1, 2))
153+
154+
c = clifford_simp(c, allow_swaps=False)
155+
156+
assert c.circuit_cost(lambda op: int(op == TketOp.CX)) == 1
157+
158+
159+
def test_squash_phasedx_rz():
160+
c = Tk2Circuit(Circuit(1).Rz(0.25, 0).Rz(0.75, 0).Rz(0.25, 0).Rz(-1.25, 0))
161+
162+
c = squash_phasedx_rz(c)
163+
164+
# TODO: We cannot use circuit_cost due to a panic on non-tket ops and there
165+
# being some parameter loads...
166+
assert c.num_operations() == 0
167+
168+
169+
def test_normalize_guppy():
170+
"""Test the normalize_guppy pass.
171+
172+
This won't actually do anything useful, we just want to check that the pass
173+
runs without errors.
174+
"""
175+
176+
c = Tk2Circuit(Circuit(4).CX(0, 2).CX(1, 2).CX(1, 2))
177+
178+
c = normalize_guppy(c)
179+
180+
assert c.circuit_cost(lambda op: int(op == TketOp.CX)) == 3

tket-py/tket/_tket/circuit.pyi

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ class Tk2Circuit:
7171
def from_hugr_json(json: str) -> Tk2Circuit:
7272
"""Decode a HUGR json string to a Tk2Circuit."""
7373

74-
def to_bytes(self, config: EnvelopeConfig) -> bytes:
74+
def to_bytes(self, config: EnvelopeConfig | None = None) -> bytes:
7575
"""Encode the circuit as a HUGR envelope, according to the given config.
7676
7777
Some envelope formats can be encoded into a string. See :meth:`to_str`.
@@ -150,6 +150,9 @@ class Tk2Circuit:
150150
def from_tket1_json(json: str) -> Tk2Circuit:
151151
"""Decode a pytket json string to a Tk2Circuit."""
152152

153+
def render_mermaid(self) -> str:
154+
"""Render the circuit as a Mermaid graph."""
155+
153156
class Node:
154157
"""Handle to node in HUGR."""
155158

0 commit comments

Comments
 (0)