diff --git a/cirq-core/cirq/circuits/circuit.py b/cirq-core/cirq/circuits/circuit.py index fd05394b437..2eda2f8eeb4 100644 --- a/cirq-core/cirq/circuits/circuit.py +++ b/cirq-core/cirq/circuits/circuit.py @@ -143,6 +143,30 @@ class AbstractCircuit(abc.ABC): * get_independent_qubit_sets """ + @classmethod + def from_moments(cls: Type[CIRCUIT_TYPE], *moments: 'cirq.OP_TREE') -> CIRCUIT_TYPE: + """Create a circuit from moment op trees. + + Args: + *moments: Op tree for each moment. + """ + return cls._from_moments( + moment if isinstance(moment, Moment) else Moment(moment) for moment in moments + ) + + @classmethod + @abc.abstractmethod + def _from_moments(cls: Type[CIRCUIT_TYPE], moments: Iterable['cirq.Moment']) -> CIRCUIT_TYPE: + """Create a circuit from moments. + + This must be implemented by subclasses. It provides a more efficient way + to construct a circuit instance since we already have the moments and so + can skip the analysis required to implement various insert strategies. + + Args: + moments: Moments of the circuit. + """ + @property @abc.abstractmethod def moments(self) -> Sequence['cirq.Moment']: @@ -225,8 +249,7 @@ def __getitem__(self: CIRCUIT_TYPE, key: Tuple[slice, Iterable['cirq.Qid']]) -> def __getitem__(self, key): if isinstance(key, slice): - sliced_moments = self.moments[key] - return self._with_sliced_moments(sliced_moments) + return self._from_moments(self.moments[key]) if hasattr(key, '__index__'): return self.moments[key] if isinstance(key, tuple): @@ -239,17 +262,12 @@ def __getitem__(self, key): return selected_moments[qubit_idx] if isinstance(qubit_idx, ops.Qid): qubit_idx = [qubit_idx] - sliced_moments = [moment[qubit_idx] for moment in selected_moments] - return self._with_sliced_moments(sliced_moments) + return self._from_moments(moment[qubit_idx] for moment in selected_moments) raise TypeError('__getitem__ called with key not of type slice, int, or tuple.') # pylint: enable=function-redefined - @abc.abstractmethod - def _with_sliced_moments(self: CIRCUIT_TYPE, moments: Iterable['cirq.Moment']) -> CIRCUIT_TYPE: - """Helper method for constructing circuits from __getitem__.""" - def __str__(self) -> str: return self.to_text_diagram() @@ -909,7 +927,7 @@ def map_moment(moment: 'cirq.Moment') -> 'cirq.Circuit': """Apply func to expand each op into a circuit, then zip up the circuits.""" return Circuit.zip(*[Circuit(func(op)) for op in moment]) - return self._with_sliced_moments(m for moment in self for m in map_moment(moment)) + return self._from_moments(m for moment in self for m in map_moment(moment)) def qid_shape( self, qubit_order: 'cirq.QubitOrderOrList' = ops.QubitOrder.DEFAULT @@ -949,18 +967,16 @@ def _measurement_key_names_(self) -> FrozenSet[str]: return self.all_measurement_key_names() def _with_measurement_key_mapping_(self, key_map: Mapping[str, str]): - return self._with_sliced_moments( - [protocols.with_measurement_key_mapping(moment, key_map) for moment in self.moments] + return self._from_moments( + protocols.with_measurement_key_mapping(moment, key_map) for moment in self.moments ) def _with_key_path_(self, path: Tuple[str, ...]): - return self._with_sliced_moments( - [protocols.with_key_path(moment, path) for moment in self.moments] - ) + return self._from_moments(protocols.with_key_path(moment, path) for moment in self.moments) def _with_key_path_prefix_(self, prefix: Tuple[str, ...]): - return self._with_sliced_moments( - [protocols.with_key_path_prefix(moment, prefix) for moment in self.moments] + return self._from_moments( + protocols.with_key_path_prefix(moment, prefix) for moment in self.moments ) def _with_rescoped_keys_( @@ -971,7 +987,7 @@ def _with_rescoped_keys_( new_moment = protocols.with_rescoped_keys(moment, path, bindable_keys) moments.append(new_moment) bindable_keys |= protocols.measurement_key_objs(new_moment) - return self._with_sliced_moments(moments) + return self._from_moments(moments) def _qid_shape_(self) -> Tuple[int, ...]: return self.qid_shape() @@ -1552,9 +1568,7 @@ def factorize(self: CIRCUIT_TYPE) -> Iterable[CIRCUIT_TYPE]: # the qubits from one factor belong to a specific independent qubit set. # This makes it possible to create independent circuits based on these # moments. - return ( - self._with_sliced_moments([m[qubits] for m in self.moments]) for qubits in qubit_factors - ) + return (self._from_moments(m[qubits] for m in self.moments) for qubits in qubit_factors) def _control_keys_(self) -> FrozenSet['cirq.MeasurementKey']: controls = frozenset(k for op in self.all_operations() for k in protocols.control_keys(op)) @@ -1719,6 +1733,12 @@ def __init__( else: self.append(contents, strategy=strategy) + @classmethod + def _from_moments(cls, moments: Iterable['cirq.Moment']) -> 'Circuit': + new_circuit = Circuit() + new_circuit._moments[:] = moments + return new_circuit + def _load_contents_with_earliest_strategy(self, contents: 'cirq.OP_TREE'): """Optimized algorithm to load contents quickly. @@ -1813,11 +1833,6 @@ def copy(self) -> 'Circuit': copied_circuit._moments = self._moments[:] return copied_circuit - def _with_sliced_moments(self, moments: Iterable['cirq.Moment']) -> 'Circuit': - new_circuit = Circuit() - new_circuit._moments = list(moments) - return new_circuit - # pylint: disable=function-redefined @overload def __setitem__(self, key: int, value: 'cirq.Moment'): diff --git a/cirq-core/cirq/circuits/circuit_test.py b/cirq-core/cirq/circuits/circuit_test.py index 1cb71e7d731..044544a5687 100644 --- a/cirq-core/cirq/circuits/circuit_test.py +++ b/cirq-core/cirq/circuits/circuit_test.py @@ -70,6 +70,23 @@ def validate_moment(self, moment): moment_and_op_type_validating_device = _MomentAndOpTypeValidatingDeviceType() +def test_from_moments(): + a, b, c, d = cirq.LineQubit.range(4) + assert cirq.Circuit.from_moments( + [cirq.X(a), cirq.Y(b)], + [cirq.X(c)], + [], + cirq.Z(d), + [cirq.measure(a, b, key='ab'), cirq.measure(c, d, key='cd')], + ) == cirq.Circuit( + cirq.Moment(cirq.X(a), cirq.Y(b)), + cirq.Moment(cirq.X(c)), + cirq.Moment(), + cirq.Moment(cirq.Z(d)), + cirq.Moment(cirq.measure(a, b, key='ab'), cirq.measure(c, d, key='cd')), + ) + + def test_alignment(): assert repr(cirq.Alignment.LEFT) == 'cirq.Alignment.LEFT' assert repr(cirq.Alignment.RIGHT) == 'cirq.Alignment.RIGHT' @@ -269,6 +286,16 @@ def test_append_control_key_subcircuit(): assert len(c) == 1 +def test_measurement_key_paths(): + a = cirq.LineQubit(0) + circuit1 = cirq.Circuit(cirq.measure(a, key='A')) + assert cirq.measurement_key_names(circuit1) == {'A'} + circuit2 = cirq.with_key_path(circuit1, ('B',)) + assert cirq.measurement_key_names(circuit2) == {'B:A'} + circuit3 = cirq.with_key_path_prefix(circuit2, ('C',)) + assert cirq.measurement_key_names(circuit3) == {'C:B:A'} + + def test_append_moments(): a = cirq.NamedQubit('a') b = cirq.NamedQubit('b') diff --git a/cirq-core/cirq/circuits/frozen_circuit.py b/cirq-core/cirq/circuits/frozen_circuit.py index 51f7dff3285..22c4b69964c 100644 --- a/cirq-core/cirq/circuits/frozen_circuit.py +++ b/cirq-core/cirq/circuits/frozen_circuit.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. """An immutable version of the Circuit data structure.""" -from typing import TYPE_CHECKING, FrozenSet, Iterable, Iterator, Sequence, Tuple, Union +from typing import FrozenSet, Iterable, Iterator, Sequence, Tuple, TYPE_CHECKING, Union import numpy as np @@ -51,6 +51,12 @@ def __init__( base = Circuit(contents, strategy=strategy) self._moments = tuple(base.moments) + @classmethod + def _from_moments(cls, moments: Iterable['cirq.Moment']) -> 'FrozenCircuit': + new_circuit = FrozenCircuit() + new_circuit._moments = tuple(moments) + return new_circuit + @property def moments(self) -> Sequence['cirq.Moment']: return self._moments @@ -143,11 +149,6 @@ def __pow__(self, other) -> 'cirq.FrozenCircuit': except: return NotImplemented - def _with_sliced_moments(self, moments: Iterable['cirq.Moment']) -> 'FrozenCircuit': - new_circuit = FrozenCircuit() - new_circuit._moments = tuple(moments) - return new_circuit - def _resolve_parameters_( self, resolver: 'cirq.ParamResolver', recursive: bool ) -> 'cirq.FrozenCircuit': diff --git a/cirq-core/cirq/circuits/frozen_circuit_test.py b/cirq-core/cirq/circuits/frozen_circuit_test.py index 155e3420f65..8e89332af04 100644 --- a/cirq-core/cirq/circuits/frozen_circuit_test.py +++ b/cirq-core/cirq/circuits/frozen_circuit_test.py @@ -21,6 +21,23 @@ import cirq +def test_from_moments(): + a, b, c, d = cirq.LineQubit.range(4) + assert cirq.FrozenCircuit.from_moments( + [cirq.X(a), cirq.Y(b)], + [cirq.X(c)], + [], + cirq.Z(d), + [cirq.measure(a, b, key='ab'), cirq.measure(c, d, key='cd')], + ) == cirq.FrozenCircuit( + cirq.Moment(cirq.X(a), cirq.Y(b)), + cirq.Moment(cirq.X(c)), + cirq.Moment(), + cirq.Moment(cirq.Z(d)), + cirq.Moment(cirq.measure(a, b, key='ab'), cirq.measure(c, d, key='cd')), + ) + + def test_freeze_and_unfreeze(): a, b = cirq.LineQubit.range(2) c = cirq.Circuit(cirq.X(a), cirq.H(b))