diff --git a/qiskit/pulse/library/symbolic_pulses.py b/qiskit/pulse/library/symbolic_pulses.py index 713265c47a46..3e9ac89aac2a 100644 --- a/qiskit/pulse/library/symbolic_pulses.py +++ b/qiskit/pulse/library/symbolic_pulses.py @@ -20,7 +20,7 @@ import functools import warnings -from typing import Any, Dict, List, Optional, Union, Callable +from typing import Any, Dict, List, Optional, Union, Callable, Tuple import numpy as np @@ -383,6 +383,8 @@ def Sawtooth(duration, amp, freq, name): "_envelope", "_constraints", "_valid_amp_conditions", + "_canonical_params", + "_excluded_params", ) # Lambdify caches keyed on sympy expressions. Returns the corresponding callable. @@ -400,6 +402,8 @@ def __init__( envelope: Optional[sym.Expr] = None, constraints: Optional[sym.Expr] = None, valid_amp_conditions: Optional[sym.Expr] = None, + canonical_params: Optional[List[Union[ParameterExpression, complex]]] = None, + excluded_params: Optional[Tuple[str]] = None, ): """Create a parametric pulse. @@ -417,6 +421,11 @@ def __init__( will investigate the full-waveform and raise an error when the amplitude norm of any data point exceeds 1.0. If not provided, the validation always creates a full-waveform. + canonical_params: List of parameters for the equating operation of symbolic + pulses. When two pulses are compared, the two lists have to be identical to + yield `True`. + excluded_params: Tuple of strings matching keys in `parameters` which are to be + ignored when two symbolic pulses are ignored. Raises: PulseError: When not all parameters are listed in the attribute :attr:`PARAM_DEF`. @@ -436,6 +445,13 @@ def __init__( self._constraints = constraints self._valid_amp_conditions = valid_amp_conditions + if canonical_params is None: + canonical_params = [] + self._canonical_params = canonical_params + if excluded_params is None: + excluded_params = () + self._excluded_params = excluded_params + def __getattr__(self, item): # Get pulse parameters with attribute-like access. params = object.__getattribute__(self, "_params") @@ -536,6 +552,31 @@ def parameters(self) -> Dict[str, Any]: params.update(self._params) return params + def _equate_parameters(self, other): + """Helper function which compares the parameters of two pulses, taking into account + _canonical_params and _excluded_params.""" + if len(self._canonical_params) != len(other._canonical_params): + return False + + for param1, param2 in zip(self._canonical_params, other._canonical_params): + # Because the values are calculated, we need to compare to within numerical precision, + # and can't use a simple comparison of the lists. + if isinstance(param1, ParameterExpression) or isinstance(param2, ParameterExpression): + if param1 != param2: + return False + else: + if not np.isclose(param1, param2): + return False + + if self.parameters.keys() != other.parameters.keys(): + return False + + for key in self.parameters: + if key not in self._excluded_params and self.parameters[key] != other.parameters[key]: + return False + + return True + def __eq__(self, other: "SymbolicPulse") -> bool: if not isinstance(other, SymbolicPulse): @@ -547,8 +588,12 @@ def __eq__(self, other: "SymbolicPulse") -> bool: if self._envelope != other._envelope: return False + # _canonical_params is assumed to be a function of parameters. If parameters are the same, + # we don't need to check the _canonical_params. (Also solves the edge case of a pulse with + # no parameters) if self.parameters != other.parameters: - return False + if not self._equate_parameters(other): + return False return True @@ -658,6 +703,8 @@ def __new__( angle = 0 parameters = {"amp": amp, "sigma": sigma, "angle": angle} + canonical_params = [amp * np.exp(1j * angle)] + excluded_params = ("amp", "angle") # Prepare symbolic expressions _t, _duration, _amp, _sigma, _angle = sym.symbols("t, duration, amp, sigma, angle") @@ -679,6 +726,8 @@ def __new__( envelope=envelope_expr, constraints=consts_expr, valid_amp_conditions=valid_amp_conditions_expr, + canonical_params=canonical_params, + excluded_params=excluded_params, ) instance.validate_parameters() @@ -787,6 +836,8 @@ def __new__( angle = 0 parameters = {"amp": amp, "sigma": sigma, "width": width, "angle": angle} + canonical_params = [amp * np.exp(1j * angle)] + excluded_params = ("amp", "angle") # Prepare symbolic expressions _t, _duration, _amp, _sigma, _width, _angle = sym.symbols( @@ -820,6 +871,8 @@ def __new__( envelope=envelope_expr, constraints=consts_expr, valid_amp_conditions=valid_amp_conditions_expr, + canonical_params=canonical_params, + excluded_params=excluded_params, ) instance.validate_parameters() @@ -911,6 +964,8 @@ def __new__( angle = 0 parameters = {"amp": amp, "sigma": sigma, "beta": beta, "angle": angle} + canonical_params = [amp * np.exp(1j * angle)] + excluded_params = ("amp", "angle") # Prepare symbolic expressions _t, _duration, _amp, _sigma, _beta, _angle = sym.symbols( @@ -935,6 +990,8 @@ def __new__( envelope=envelope_expr, constraints=consts_expr, valid_amp_conditions=valid_amp_conditions_expr, + canonical_params=canonical_params, + excluded_params=excluded_params, ) instance.validate_parameters() @@ -992,6 +1049,8 @@ def __new__( angle = 0 parameters = {"amp": amp, "angle": angle} + canonical_params = [amp * np.exp(1j * angle)] + excluded_params = ("amp", "angle") # Prepare symbolic expressions _t, _amp, _duration, _angle = sym.symbols("t, amp, duration, angle") @@ -1019,6 +1078,8 @@ def __new__( limit_amplitude=limit_amplitude, envelope=envelope_expr, valid_amp_conditions=valid_amp_conditions_expr, + canonical_params=canonical_params, + excluded_params=excluded_params, ) instance.validate_parameters() diff --git a/qiskit/pulse/parameter_manager.py b/qiskit/pulse/parameter_manager.py index d1db89ab1bed..5f999b2984a9 100644 --- a/qiskit/pulse/parameter_manager.py +++ b/qiskit/pulse/parameter_manager.py @@ -232,6 +232,13 @@ def visit_SymbolicPulse(self, node: SymbolicPulse): if isinstance(pval, ParameterExpression): new_val = self._assign_parameter_expression(pval) node._params[name] = new_val + # Assign canonical parameters + for i in range(len(node._canonical_params)): + pval = node._canonical_params[i] + if isinstance(pval, ParameterExpression): + new_val = self._assign_parameter_expression(pval) + node._canonical_params[i] = new_val + node.validate_parameters() return node diff --git a/test/python/pulse/test_pulse_lib.py b/test/python/pulse/test_pulse_lib.py index 428d8094a2f7..27723cb40034 100644 --- a/test/python/pulse/test_pulse_lib.py +++ b/test/python/pulse/test_pulse_lib.py @@ -28,6 +28,7 @@ gaussian_square, drag as pl_drag, ) +from qiskit.pulse import build, play, DriveChannel from qiskit.pulse import functional_pulse, PulseError from qiskit.test import QiskitTestCase @@ -542,6 +543,31 @@ def local_gaussian(duration, amp, t0, sig): pulse_wf_inst = local_gaussian(duration=_duration, amp=1, t0=5, sig=1) self.assertEqual(len(pulse_wf_inst.samples), _duration) + def test_comparison_parameters(self): + """Test equating of pulses with comparison_parameters.""" + # amp,angle comparison for library pulses + gaussian_negamp = Gaussian(duration=25, sigma=4, amp=-0.5, angle=0) + gaussian_piphase = Gaussian(duration=25, sigma=4, amp=0.5, angle=np.pi) + self.assertEqual(gaussian_negamp, gaussian_piphase) + + # Parameterized library pulses + amp = Parameter("amp") + gaussian1 = Gaussian(duration=25, sigma=4, amp=amp, angle=0) + gaussian2 = Gaussian(duration=25, sigma=4, amp=amp, angle=0) + self.assertEqual(gaussian1, gaussian2) + + # pulses with different parameters + gaussian1._params["sigma"] = 10 + self.assertNotEqual(gaussian1, gaussian2) + + # Assignment of parameters (to verify computation of comparison_parameters) + angle = Parameter("angle") + with build() as sc: + play(Gaussian(duration=160, amp=amp, sigma=40, angle=angle), DriveChannel(0)) + sc_piphase = sc.assign_parameters({amp: 1, angle: np.pi}, inplace=False) + sc_negamp = sc.assign_parameters({amp: -1, angle: 0}, inplace=False) + self.assertEqual(sc_piphase, sc_negamp) + if __name__ == "__main__": unittest.main()