diff --git a/Cargo.lock b/Cargo.lock index 97ca32bae717..78421193ca79 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -442,6 +442,7 @@ dependencies = [ "rand_pcg", "rayon", "rustworkx-core", + "smallvec", ] [[package]] diff --git a/crates/accelerate/Cargo.toml b/crates/accelerate/Cargo.toml index a0b59e6460de..a5db75b52102 100644 --- a/crates/accelerate/Cargo.toml +++ b/crates/accelerate/Cargo.toml @@ -25,6 +25,10 @@ num-complex = "0.4" num-bigint = "0.4" rustworkx-core = "0.13" +[dependencies.smallvec] +version = "1.11" +features = ["union"] + [dependencies.pyo3] workspace = true features = ["hashbrown", "indexmap", "num-complex", "num-bigint"] diff --git a/crates/accelerate/src/sabre_swap/neighbor_table.rs b/crates/accelerate/src/sabre_swap/neighbor_table.rs index 6cb44536dc0e..577466b514db 100644 --- a/crates/accelerate/src/sabre_swap/neighbor_table.rs +++ b/crates/accelerate/src/sabre_swap/neighbor_table.rs @@ -14,8 +14,10 @@ use crate::getenv_use_multiple_threads; use ndarray::prelude::*; use numpy::PyReadonlyArray2; use pyo3::prelude::*; +use pyo3::types::PyList; use rayon::prelude::*; use rustworkx_core::petgraph::prelude::*; +use smallvec::SmallVec; use crate::nlayout::PhysicalQubit; @@ -32,7 +34,11 @@ use crate::nlayout::PhysicalQubit; #[pyclass(module = "qiskit._accelerate.sabre_swap")] #[derive(Clone, Debug)] pub struct NeighborTable { - neighbors: Vec>, + // The choice of 4 `PhysicalQubit`s in the stack-allocated region is because a) this causes the + // `SmallVec` to be the same width as a `Vec` on 64-bit systems (three machine words == 24 + // bytes); b) the majority of coupling maps we're likely to encounter have a degree of 3 (heavy + // hex) or 4 (grid / heavy square). + neighbors: Vec>, } impl NeighborTable { @@ -63,21 +69,22 @@ impl NeighborTable { let neighbors = match adjacency_matrix { Some(adjacency_matrix) => { let adj_mat = adjacency_matrix.as_array(); - let build_neighbors = |row: ArrayView1| -> PyResult> { - row.iter() - .enumerate() - .filter_map(|(row_index, value)| { - if *value == 0. { - None - } else { - Some(match row_index.try_into() { - Ok(index) => Ok(PhysicalQubit::new(index)), - Err(err) => Err(err.into()), - }) - } - }) - .collect() - }; + let build_neighbors = + |row: ArrayView1| -> PyResult> { + row.iter() + .enumerate() + .filter_map(|(row_index, value)| { + if *value == 0. { + None + } else { + Some(match row_index.try_into() { + Ok(index) => Ok(PhysicalQubit::new(index)), + Err(err) => Err(err.into()), + }) + } + }) + .collect() + }; if run_in_parallel { adj_mat .axis_iter(Axis(0)) @@ -96,11 +103,26 @@ impl NeighborTable { Ok(NeighborTable { neighbors }) } - fn __getstate__(&self) -> Vec> { - self.neighbors.clone() + fn __getstate__(&self, py: Python<'_>) -> Py { + PyList::new( + py, + self.neighbors + .iter() + .map(|v| PyList::new(py, v.iter()).to_object(py)), + ) + .into() } - fn __setstate__(&mut self, state: Vec>) { + fn __setstate__(&mut self, state: &PyList) -> PyResult<()> { self.neighbors = state + .iter() + .map(|v| { + v.downcast::()? + .iter() + .map(PyAny::extract) + .collect::>() + }) + .collect::>()?; + Ok(()) } }