-
Notifications
You must be signed in to change notification settings - Fork 2.9k
[Oxidize BasisTranslator] Add rust-native compose_transforms()
#13137
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
f9e97ea
3a3e734
0589a02
1a49bfd
fbb57b8
e8c957d
a7d2f17
8e96eb7
18591ce
676766d
1b16df6
a2cb862
c9082c7
3d1a6d4
956c425
9c6913b
f818ad2
1e94449
912390e
98cebda
dbdc6a9
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,192 @@ | ||
| // This code is part of Qiskit. | ||
| // | ||
| // (C) Copyright IBM 2024 | ||
| // | ||
| // This code is licensed under the Apache License, Version 2.0. You may | ||
| // obtain a copy of this license in the LICENSE.txt file in the root directory | ||
| // of this source tree or at http://www.apache.org/licenses/LICENSE-2.0. | ||
| // | ||
| // Any modifications or derivative works of this code must retain this | ||
| // copyright notice, and modified files need to carry a notice indicating | ||
| // that they have been altered from the originals. | ||
|
|
||
| use hashbrown::{HashMap, HashSet}; | ||
| use pyo3::prelude::*; | ||
| use qiskit_circuit::circuit_instruction::OperationFromPython; | ||
| use qiskit_circuit::imports::{GATE, PARAMETER_VECTOR, QUANTUM_REGISTER}; | ||
| use qiskit_circuit::parameter_table::ParameterUuid; | ||
| use qiskit_circuit::Qubit; | ||
| use qiskit_circuit::{ | ||
| circuit_data::CircuitData, | ||
| dag_circuit::{DAGCircuit, NodeType}, | ||
| operations::{Operation, Param}, | ||
| }; | ||
| use smallvec::SmallVec; | ||
|
|
||
| use crate::equivalence::CircuitFromPython; | ||
|
|
||
| // Custom types | ||
| pub type GateIdentifier = (String, u32); | ||
| pub type BasisTransformIn = (SmallVec<[Param; 3]>, CircuitFromPython); | ||
| pub type BasisTransformOut = (SmallVec<[Param; 3]>, DAGCircuit); | ||
|
|
||
| #[pyfunction(name = "compose_transforms")] | ||
| pub(super) fn py_compose_transforms( | ||
| py: Python, | ||
| basis_transforms: Vec<(GateIdentifier, BasisTransformIn)>, | ||
| source_basis: HashSet<GateIdentifier>, | ||
| source_dag: &DAGCircuit, | ||
| ) -> PyResult<HashMap<GateIdentifier, BasisTransformOut>> { | ||
| compose_transforms(py, &basis_transforms, &source_basis, source_dag) | ||
| } | ||
|
|
||
| pub(super) fn compose_transforms<'a>( | ||
| py: Python, | ||
| basis_transforms: &'a [(GateIdentifier, BasisTransformIn)], | ||
| source_basis: &'a HashSet<GateIdentifier>, | ||
| source_dag: &'a DAGCircuit, | ||
| ) -> PyResult<HashMap<GateIdentifier, BasisTransformOut>> { | ||
| let mut gate_param_counts: HashMap<GateIdentifier, usize> = HashMap::default(); | ||
| get_gates_num_params(source_dag, &mut gate_param_counts)?; | ||
| let mut mapped_instructions: HashMap<GateIdentifier, BasisTransformOut> = HashMap::new(); | ||
|
|
||
| for (gate_name, gate_num_qubits) in source_basis.iter().cloned() { | ||
| let num_params = gate_param_counts[&(gate_name.clone(), gate_num_qubits)]; | ||
|
|
||
| let placeholder_params: SmallVec<[Param; 3]> = PARAMETER_VECTOR | ||
| .get_bound(py) | ||
| .call1((&gate_name, num_params))? | ||
| .extract()?; | ||
|
|
||
| let mut dag = DAGCircuit::new(py)?; | ||
| // Create the mock gate and add to the circuit, use Python for this. | ||
| let qubits = QUANTUM_REGISTER.get_bound(py).call1((gate_num_qubits,))?; | ||
| dag.add_qreg(py, &qubits)?; | ||
|
|
||
| let gate = GATE.get_bound(py).call1(( | ||
| &gate_name, | ||
| gate_num_qubits, | ||
| placeholder_params | ||
| .iter() | ||
| .map(|x| x.clone_ref(py)) | ||
| .collect::<SmallVec<[Param; 3]>>(), | ||
| ))?; | ||
| let gate_obj: OperationFromPython = gate.extract()?; | ||
| let qubits: Vec<Qubit> = (0..dag.num_qubits() as u32).map(Qubit).collect(); | ||
| dag.apply_operation_back( | ||
| py, | ||
| gate_obj.operation, | ||
| &qubits, | ||
| &[], | ||
| if gate_obj.params.is_empty() { | ||
| None | ||
| } else { | ||
| Some(gate_obj.params) | ||
| }, | ||
| gate_obj.extra_attrs, | ||
| #[cfg(feature = "cache_pygates")] | ||
| Some(gate.into()), | ||
| )?; | ||
| mapped_instructions.insert((gate_name, gate_num_qubits), (placeholder_params, dag)); | ||
|
|
||
| for ((gate_name, gate_num_qubits), (equiv_params, equiv)) in basis_transforms { | ||
| for (_, dag) in &mut mapped_instructions.values_mut() { | ||
| let nodes_to_replace = dag | ||
| .op_nodes(true) | ||
| .filter_map(|node| { | ||
| if let Some(NodeType::Operation(op)) = dag.dag().node_weight(node) { | ||
| if (gate_name.as_str(), *gate_num_qubits) | ||
| == (op.op.name(), op.op.num_qubits()) | ||
| { | ||
| Some(( | ||
| node, | ||
| op.params_view() | ||
| .iter() | ||
| .map(|x| x.clone_ref(py)) | ||
| .collect::<SmallVec<[Param; 3]>>(), | ||
| )) | ||
| } else { | ||
| None | ||
| } | ||
| } else { | ||
| None | ||
| } | ||
| }) | ||
| .collect::<Vec<_>>(); | ||
| for (node, params) in nodes_to_replace { | ||
| let param_mapping: HashMap<ParameterUuid, Param> = equiv_params | ||
| .iter() | ||
| .map(|x| ParameterUuid::from_parameter(x.to_object(py).bind(py))) | ||
| .zip(params) | ||
| .map(|(uuid, param)| -> PyResult<(ParameterUuid, Param)> { | ||
| Ok((uuid?, param.clone_ref(py))) | ||
| }) | ||
| .collect::<PyResult<_>>()?; | ||
| let mut replacement = equiv.clone(); | ||
| replacement | ||
| .0 | ||
| .assign_parameters_from_mapping(py, param_mapping)?; | ||
| let replace_dag: DAGCircuit = | ||
| DAGCircuit::from_circuit_data(py, replacement.0, true)?; | ||
| let op_node = dag.get_node(py, node)?; | ||
| dag.py_substitute_node_with_dag( | ||
| py, | ||
| op_node.bind(py), | ||
| &replace_dag, | ||
| None, | ||
| true, | ||
| )?; | ||
| } | ||
| } | ||
| } | ||
| } | ||
| Ok(mapped_instructions) | ||
| } | ||
|
|
||
| /// `DAGCircuit` variant. | ||
| /// | ||
| /// Gets the identifier of a gate instance (name, number of qubits) mapped to the | ||
| /// number of parameters it contains currently. | ||
| fn get_gates_num_params( | ||
| dag: &DAGCircuit, | ||
| example_gates: &mut HashMap<GateIdentifier, usize>, | ||
| ) -> PyResult<()> { | ||
| for node in dag.op_nodes(true) { | ||
| if let Some(NodeType::Operation(op)) = dag.dag().node_weight(node) { | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think we now have
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Actually, that would come as part of https://github.com/Qiskit/qiskit/pull/13034/files#diff-9d10dd27b7540092c29eb23e0fe83a90b8661fb1f452178b60f0e8fb4fd54d60R85 |
||
| example_gates.insert( | ||
| (op.op.name().to_string(), op.op.num_qubits()), | ||
| op.params_view().len(), | ||
| ); | ||
| if op.op.control_flow() { | ||
| let blocks = op.op.blocks(); | ||
| for block in blocks { | ||
| get_gates_num_params_circuit(&block, example_gates)?; | ||
| } | ||
| } | ||
| } | ||
| } | ||
| Ok(()) | ||
| } | ||
|
|
||
| /// `CircuitData` variant. | ||
| /// | ||
| /// Gets the identifier of a gate instance (name, number of qubits) mapped to the | ||
| /// number of parameters it contains currently. | ||
| fn get_gates_num_params_circuit( | ||
| circuit: &CircuitData, | ||
| example_gates: &mut HashMap<GateIdentifier, usize>, | ||
| ) -> PyResult<()> { | ||
| for inst in circuit.iter() { | ||
| example_gates.insert( | ||
| (inst.op.name().to_string(), inst.op.num_qubits()), | ||
| inst.params_view().len(), | ||
| ); | ||
| if inst.op.control_flow() { | ||
| let blocks = inst.op.blocks(); | ||
| for block in blocks { | ||
| get_gates_num_params_circuit(&block, example_gates)?; | ||
| } | ||
| } | ||
| } | ||
| Ok(()) | ||
| } | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,21 @@ | ||
| // This code is part of Qiskit. | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is there a clear benefit to having a separate folder here for I don't know that I mind it, but it's different from the folder structure we had in Python, and none of the other accelerate modules really go more than one module deep (except for
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This is just to plan ahead, since the
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It suppose it doesn't really matter much, so if you think it helps to organize I'm good with it 🙂. |
||
| // | ||
| // (C) Copyright IBM 2024 | ||
| // | ||
| // This code is licensed under the Apache License, Version 2.0. You may | ||
| // obtain a copy of this license in the LICENSE.txt file in the root directory | ||
| // of this source tree or at http://www.apache.org/licenses/LICENSE-2.0. | ||
| // | ||
| // Any modifications or derivative works of this code must retain this | ||
| // copyright notice, and modified files need to carry a notice indicating | ||
| // that they have been altered from the originals. | ||
|
|
||
| use pyo3::prelude::*; | ||
|
|
||
| mod compose_transforms; | ||
|
|
||
| #[pymodule] | ||
| pub fn basis_translator(m: &Bound<PyModule>) -> PyResult<()> { | ||
| m.add_wrapped(wrap_pyfunction!(compose_transforms::py_compose_transforms))?; | ||
| Ok(()) | ||
| } | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,21 @@ | ||
| // This code is part of Qiskit. | ||
| // | ||
| // (C) Copyright IBM 2024 | ||
| // | ||
| // This code is licensed under the Apache License, Version 2.0. You may | ||
| // obtain a copy of this license in the LICENSE.txt file in the root directory | ||
| // of this source tree or at http://www.apache.org/licenses/LICENSE-2.0. | ||
| // | ||
| // Any modifications or derivative works of this code must retain this | ||
| // copyright notice, and modified files need to carry a notice indicating | ||
| // that they have been altered from the originals. | ||
|
|
||
| use pyo3::{prelude::*, wrap_pymodule}; | ||
|
|
||
| pub mod basis_translator; | ||
|
|
||
| #[pymodule] | ||
| pub fn basis(m: &Bound<PyModule>) -> PyResult<()> { | ||
| m.add_wrapped(wrap_pymodule!(basis_translator::basis_translator))?; | ||
| Ok(()) | ||
| } |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I imagine this is slower than it ought to be, but it'd be non-trivial to expose a native-Rust
substitute_node_with_dag. Maybe we can revisit doing so in a future PR.