diff --git a/qiskit/circuit/parametertable.py b/qiskit/circuit/parametertable.py index 3d127de8649f..f6cf022b0744 100644 --- a/qiskit/circuit/parametertable.py +++ b/qiskit/circuit/parametertable.py @@ -12,6 +12,7 @@ """ Look-up table for variable parameters in QuantumCircuit. """ +import operator from collections.abc import MappingView, MutableMapping, MutableSet @@ -21,6 +22,9 @@ class ParameterReferences(MutableSet): testing is overridden such that items that are otherwise value-wise equal are still considered distinct if their ``instruction``\\ s are referentially distinct. + + In the case of the special value :attr:`.ParameterTable.GLOBAL_PHASE` for ``instruction``, the + ``param_index`` should be ``None``. """ def _instance_key(self, ref): @@ -83,6 +87,24 @@ class ParameterTable(MutableMapping): __slots__ = ["_table", "_keys", "_names"] + class _GlobalPhaseSentinel: + __slots__ = () + + def __copy__(self): + return self + + def __deepcopy__(self, memo=None): + return self + + def __reduce__(self): + return (operator.attrgetter("GLOBAL_PHASE"), (ParameterTable,)) + + def __repr__(self): + return "" + + GLOBAL_PHASE = _GlobalPhaseSentinel() + """Tracking object to indicate that a reference refers to the global phase of a circuit.""" + def __init__(self, mapping=None): """Create a new instance, initialized with ``mapping`` if provided. @@ -145,6 +167,17 @@ def get_names(self): """ return self._names + def discard_references(self, expression, key): + """Remove all references to parameters contained within ``expression`` at the given table + ``key``. This also discards parameter entries from the table if they have no further + references. No action is taken if the object is not tracked.""" + for parameter in expression.parameters: + if (refs := self._table.get(parameter)) is not None: + if len(refs) == 1: + del self[parameter] + else: + refs.discard(key) + def __delitem__(self, key): del self._table[key] self._keys.discard(key) diff --git a/qiskit/circuit/quantumcircuit.py b/qiskit/circuit/quantumcircuit.py index 91cfac5c4264..723276184b89 100644 --- a/qiskit/circuit/quantumcircuit.py +++ b/qiskit/circuit/quantumcircuit.py @@ -437,7 +437,10 @@ def data(self, data_input: Iterable): else: data_input = list(data_input) self._data.clear() + self._parameters = None self._parameter_table = ParameterTable() + # Repopulate the parameter table with any global-phase entries. + self.global_phase = self.global_phase if not data_input: return if isinstance(data_input[0], CircuitInstruction): @@ -2411,6 +2414,9 @@ def copy(self, name: str | None = None) -> "QuantumCircuit": operation_copies = { id(instruction.operation): instruction.operation.copy() for instruction in self._data } + # The special global-phase sentinel doesn't need copying, but this ensures that it'll get + # recognised. The global phase itself was already copied over in 'copy_empty_like`. + operation_copies[id(ParameterTable.GLOBAL_PHASE)] = ParameterTable.GLOBAL_PHASE cpy._parameter_table = ParameterTable( { @@ -2473,6 +2479,10 @@ def copy_empty_like(self, name: str | None = None) -> "QuantumCircuit": cpy._vars_capture = self._vars_capture.copy() cpy._parameter_table = ParameterTable() + for parameter in getattr(cpy.global_phase, "parameters", ()): + cpy._parameter_table[parameter] = ParameterReferences( + [(ParameterTable.GLOBAL_PHASE, None)] + ) cpy._data = CircuitData(self._data.qubits, self._data.clbits) cpy._calibrations = copy.deepcopy(self._calibrations) @@ -2489,6 +2499,8 @@ def clear(self) -> None: """ self._data.clear() self._parameter_table.clear() + # Repopulate the parameter table with any phase symbols. + self.global_phase = self.global_phase def _create_creg(self, length: int, name: str) -> ClassicalRegister: """Creates a creg, checking if ClassicalRegister with same name exists""" @@ -2825,9 +2837,20 @@ def global_phase(self, angle: ParameterValueType): Args: angle (float, ParameterExpression): radians """ - if not (isinstance(angle, ParameterExpression) and angle.parameters): - # Set the phase to the [0, 2π) interval - angle = float(angle) % (2 * np.pi) + # If we're currently parametric, we need to throw away the references. This setter is + # called by some subclasses before the inner `_global_phase` is initialised. + global_phase_reference = (ParameterTable.GLOBAL_PHASE, None) + if isinstance(previous := getattr(self, "_global_phase", None), ParameterExpression): + self._parameter_table.discard_references(previous, global_phase_reference) + + if isinstance(angle, ParameterExpression) and angle.parameters: + for parameter in angle.parameters: + if parameter not in self._parameter_table: + self._parameters = None + self._parameter_table[parameter] = ParameterReferences(()) + self._parameter_table[parameter].add(global_phase_reference) + else: + angle = _normalize_global_phase(angle) if self._control_flow_scopes: self._control_flow_scopes[-1].global_phase = angle else: @@ -2901,10 +2924,7 @@ def parameters(self) -> ParameterView: @property def num_parameters(self) -> int: """The number of parameter objects in the circuit.""" - # Avoid a (potential) object creation if we can. - if self._parameters is not None: - return len(self._parameters) - return len(self._unsorted_parameters()) + return len(self._parameter_table) def _unsorted_parameters(self) -> set[Parameter]: """Efficiently get all parameters in the circuit, without any sorting overhead. @@ -2917,11 +2937,7 @@ def _unsorted_parameters(self) -> set[Parameter]: """ # This should be free, by accessing the actual backing data structure of the table, but that # means that we need to copy it if adding keys from the global phase. - parameters = self._parameter_table.get_keys() - if isinstance(self.global_phase, ParameterExpression): - # Deliberate copy. - parameters = parameters | self.global_phase.parameters - return parameters + return self._parameter_table.get_keys() @overload def assign_parameters( @@ -3073,7 +3089,12 @@ def assign_parameters( # pylint: disable=missing-raises-doc ) for operation, index in references: seen_operations[id(operation)] = operation - assignee = operation.params[index] + if operation is ParameterTable.GLOBAL_PHASE: + assignee = target.global_phase + validate = _normalize_global_phase + else: + assignee = operation.params[index] + validate = operation.validate_parameter if isinstance(assignee, ParameterExpression): new_parameter = assignee.assign(to_bind, bound_value) for parameter in update_parameters: @@ -3081,7 +3102,7 @@ def assign_parameters( # pylint: disable=missing-raises-doc target._parameter_table[parameter] = ParameterReferences(()) target._parameter_table[parameter].add((operation, index)) if not new_parameter.parameters: - new_parameter = operation.validate_parameter(new_parameter.numeric()) + new_parameter = validate(new_parameter.numeric()) elif isinstance(assignee, QuantumCircuit): new_parameter = assignee.assign_parameters( {to_bind: bound_value}, inplace=False, flat_input=True @@ -3091,7 +3112,12 @@ def assign_parameters( # pylint: disable=missing-raises-doc f"Saw an unknown type during symbolic binding: {assignee}." " This may indicate an internal logic error in symbol tracking." ) - operation.params[index] = new_parameter + if operation is ParameterTable.GLOBAL_PHASE: + # We've already handled parameter table updates in bulk, so we need to skip the + # public setter trying to do it again. + target._global_phase = new_parameter + else: + operation.params[index] = new_parameter # After we've been through everything at the top level, make a single visit to each # operation we've seen, rebinding its definition if necessary. @@ -3103,12 +3129,6 @@ def assign_parameters( # pylint: disable=missing-raises-doc parameter_binds.mapping, inplace=True, flat_input=True, strict=False ) - if isinstance(target.global_phase, ParameterExpression): - new_phase = target.global_phase - for parameter in new_phase.parameters & parameter_binds.mapping.keys(): - new_phase = new_phase.assign(parameter, parameter_binds.mapping[parameter]) - target.global_phase = new_phase - # Finally, assign the parameters inside any of the calibrations. We don't track these in # the `ParameterTable`, so we manually reconstruct things. def map_calibration(qubits, parameters, schedule): @@ -6064,3 +6084,11 @@ def _bit_argument_conversion_scalar(specifier, bit_sequence, bit_set, type_): else f"Invalid bit index: '{specifier}' of type '{type(specifier)}'" ) raise CircuitError(message) + + +def _normalize_global_phase(angle): + """Return the normalized form of an angle for use in the global phase. This coerces to float if + possible, and fixes to the interval :math:`[0, 2\\pi)`.""" + if isinstance(angle, ParameterExpression) and angle.parameters: + return angle + return float(angle) % (2.0 * np.pi) diff --git a/test/python/circuit/test_circuit_operations.py b/test/python/circuit/test_circuit_operations.py index af94a1a290e1..7d7a930f7cec 100644 --- a/test/python/circuit/test_circuit_operations.py +++ b/test/python/circuit/test_circuit_operations.py @@ -370,6 +370,25 @@ def test_copy_copies_registers(self): self.assertEqual(len(qc.cregs), 1) self.assertEqual(len(copied.cregs), 2) + def test_copy_handles_global_phase(self): + """Test that the global phase is included in the copy, including parameters.""" + a, b = Parameter("a"), Parameter("b") + + nonparametric = QuantumCircuit(global_phase=1.0).copy() + self.assertEqual(nonparametric.global_phase, 1.0) + self.assertEqual(set(nonparametric.parameters), set()) + + parameter_phase = QuantumCircuit(global_phase=a).copy() + self.assertEqual(parameter_phase.global_phase, a) + self.assertEqual(set(parameter_phase.parameters), {a}) + # The `assign_parameters` is an indirect test that the `ParameterTable` is fully valid. + self.assertEqual(parameter_phase.assign_parameters({a: 1.0}).global_phase, 1.0) + + expression_phase = QuantumCircuit(global_phase=a - b).copy() + self.assertEqual(expression_phase.global_phase, a - b) + self.assertEqual(set(expression_phase.parameters), {a, b}) + self.assertEqual(expression_phase.assign_parameters({a: 3, b: 2}).global_phase, 1.0) + def test_copy_empty_like_circuit(self): """Test copy_empty_like method makes a clear copy.""" qr = QuantumRegister(2) @@ -463,6 +482,24 @@ def test_copy_empty_variables(self): self.assertEqual({b, d}, set(copied.iter_captured_vars())) self.assertEqual({b}, set(qc.iter_captured_vars())) + def test_copy_empty_like_parametric_phase(self): + """Test that the parameter table of an empty circuit remains valid after copying a circuit + with a parametric global phase.""" + a, b = Parameter("a"), Parameter("b") + + single = QuantumCircuit(global_phase=a).copy_empty_like() + self.assertEqual(single.global_phase, a) + self.assertEqual(set(single.parameters), {a}) + # The `assign_parameters` is an indirect test that the `ParameterTable` is fully valid. + self.assertEqual(single.assign_parameters({a: 1.0}).global_phase, 1.0) + + stripped_instructions = QuantumCircuit(1, global_phase=a - b) + stripped_instructions.rz(a, 0) + stripped_instructions = stripped_instructions.copy_empty_like() + self.assertEqual(stripped_instructions.global_phase, a - b) + self.assertEqual(set(stripped_instructions.parameters), {a, b}) + self.assertEqual(stripped_instructions.assign_parameters({a: 3, b: 2}).global_phase, 1.0) + def test_circuit_copy_rejects_invalid_types(self): """Test copy method rejects argument with type other than 'string' and 'None' type.""" qc = QuantumCircuit(1, 1)