diff --git a/qiskit/pulse/library/symbolic_pulses.py b/qiskit/pulse/library/symbolic_pulses.py index 713265c47a46..83e0a4c5c925 100644 --- a/qiskit/pulse/library/symbolic_pulses.py +++ b/qiskit/pulse/library/symbolic_pulses.py @@ -160,14 +160,23 @@ def __init__(self, attribute: str): self.lambda_funcs = dict() def __get__(self, instance, owner) -> Callable: - expr = getattr(instance, self.attribute, None) - if expr is None: + exprs = getattr(instance, self.attribute, None) + if exprs is None: raise PulseError(f"'{self.attribute}' of '{instance.pulse_type}' is not assigned.") - key = hash(expr) - if key not in self.lambda_funcs: - self.__set__(instance, expr) - - return self.lambda_funcs[key] + if isinstance(exprs, list): + is_list = True + else: + exprs = [exprs] + is_list = False + funcs = [] + for expr in exprs: + key = hash(expr) + if key not in self.lambda_funcs: + self.__set__(instance, expr) + funcs.append(self.lambda_funcs[key]) + if not is_list: + funcs = funcs[0] + return funcs def __set__(self, instance, value): key = hash(value) @@ -383,12 +392,14 @@ def Sawtooth(duration, amp, freq, name): "_envelope", "_constraints", "_valid_amp_conditions", + "_params_compare_exp", ) # Lambdify caches keyed on sympy expressions. Returns the corresponding callable. _envelope_lam = LambdifiedExpression("_envelope") _constraints_lam = LambdifiedExpression("_constraints") _valid_amp_conditions_lam = LambdifiedExpression("_valid_amp_conditions") + _params_compare_exp_lam = LambdifiedExpression("_params_compare_exp") def __init__( self, @@ -400,6 +411,7 @@ def __init__( envelope: Optional[sym.Expr] = None, constraints: Optional[sym.Expr] = None, valid_amp_conditions: Optional[sym.Expr] = None, + params_compare_exp: Optional[Union[List[sym.Expr], sym.Expr]] = None, ): """Create a parametric pulse. @@ -417,6 +429,10 @@ def __init__( will investigate the full-waveform and raise an error when the amplitude norm of any data point exceeds 1.0. If not provided, the validation always creates a full-waveform. + params_compare_exp: Symbolic expression or a list of symbolic expressions describing + functions of the pulse parameters. When two pulses are equated, parameters which + appear in any of the expressions will not be equated individually. instead, the + expressions will be equated. Raises: PulseError: When not all parameters are listed in the attribute :attr:`PARAM_DEF`. @@ -436,6 +452,13 @@ def __init__( self._constraints = constraints self._valid_amp_conditions = valid_amp_conditions + if params_compare_exp is None: + self._params_compare_exp = None + else: + if not isinstance(params_compare_exp, list): + params_compare_exp = [params_compare_exp] + self._params_compare_exp = params_compare_exp + def __getattr__(self, item): # Get pulse parameters with attribute-like access. params = object.__getattribute__(self, "_params") @@ -547,9 +570,33 @@ def __eq__(self, other: "SymbolicPulse") -> bool: if self._envelope != other._envelope: return False - if self.parameters != other.parameters: + if self._params_compare_exp != other._params_compare_exp: return False + excluded_params = [] + if self._params_compare_exp is not None: + for i, expr in enumerate(self._params_compare_exp): + expr_parameters = [str(s) for s in expr.free_symbols] + # Only handle expressions which are not parameterized + if not any( + isinstance(self.parameters[val], ParameterExpression) for val in expr_parameters + ): + excluded_params.extend(expr_parameters) + + args_self = _get_expression_args(self._params_compare_exp[i], self.parameters) + args_other = _get_expression_args(self._params_compare_exp[i], other.parameters) + if not np.isclose( + self._params_compare_exp_lam[i](*args_self), + self._params_compare_exp_lam[i](*args_other), + ): + return False + + # Compare the parameters which were not in the expressions: + for parameter in self.parameters.keys(): + if parameter not in excluded_params: + if self.parameters[parameter] != other.parameters[parameter]: + return False + return True def __hash__(self) -> int: @@ -669,6 +716,7 @@ def __new__( consts_expr = _sigma > 0 valid_amp_conditions_expr = sym.Abs(_amp) <= 1.0 + params_compare_exp = _amp * sym.exp(sym.I * _angle) instance = SymbolicPulse( pulse_type=cls.alias, @@ -679,6 +727,7 @@ def __new__( envelope=envelope_expr, constraints=consts_expr, valid_amp_conditions=valid_amp_conditions_expr, + params_compare_exp=params_compare_exp, ) instance.validate_parameters() @@ -810,6 +859,7 @@ def __new__( consts_expr = sym.And(_sigma > 0, _width >= 0, _duration >= _width) valid_amp_conditions_expr = sym.Abs(_amp) <= 1.0 + params_compare_exp = _amp * sym.exp(sym.I * _angle) instance = SymbolicPulse( pulse_type=cls.alias, @@ -820,6 +870,7 @@ def __new__( envelope=envelope_expr, constraints=consts_expr, valid_amp_conditions=valid_amp_conditions_expr, + params_compare_exp =params_compare_exp, ) instance.validate_parameters() @@ -925,6 +976,7 @@ def __new__( consts_expr = _sigma > 0 valid_amp_conditions_expr = sym.And(sym.Abs(_amp) <= 1.0, sym.Abs(_beta) < _sigma) + params_compare_exp = _amp * sym.exp(sym.I * _angle) instance = SymbolicPulse( pulse_type="Drag", @@ -935,6 +987,7 @@ def __new__( envelope=envelope_expr, constraints=consts_expr, valid_amp_conditions=valid_amp_conditions_expr, + params_compare_exp=params_compare_exp, ) instance.validate_parameters() @@ -1010,6 +1063,7 @@ def __new__( ) valid_amp_conditions_expr = sym.Abs(_amp) <= 1.0 + params_compare_exp = _amp * sym.exp(sym.I * _angle) instance = SymbolicPulse( pulse_type="Constant", @@ -1019,6 +1073,7 @@ def __new__( limit_amplitude=limit_amplitude, envelope=envelope_expr, valid_amp_conditions=valid_amp_conditions_expr, + params_compare_exp =params_compare_exp, ) instance.validate_parameters() diff --git a/test/python/pulse/test_pulse_lib.py b/test/python/pulse/test_pulse_lib.py index 428d8094a2f7..59eeb368fad6 100644 --- a/test/python/pulse/test_pulse_lib.py +++ b/test/python/pulse/test_pulse_lib.py @@ -542,6 +542,39 @@ def local_gaussian(duration, amp, t0, sig): pulse_wf_inst = local_gaussian(duration=_duration, amp=1, t0=5, sig=1) self.assertEqual(len(pulse_wf_inst.samples), _duration) + def test_params_comp_exp(self): + """Test equating of pulses with params_comp_exp.""" + # amp,angle comparison for library pulses + gaussian_negamp = Gaussian(duration=25, sigma=4, amp=-0.5, angle=0) + gaussian_piphase = Gaussian(duration=25, sigma=4, amp=0.5, angle=np.pi) + self.assertEqual(gaussian_negamp, gaussian_piphase) + + # Parameterized library pulses + amp = Parameter("amp") + gaussian1 = Gaussian(duration=25, sigma=4, amp=amp, angle=0) + gaussian2 = Gaussian(duration=25, sigma=4, amp=amp, angle=0) + self.assertEqual(gaussian1, gaussian2) + + # # Custom pulse with two expressions + p1, p2, p3, p4 = sym.symbols("p1, p2, p3, p4") + envelope = (p1 + p2) * (p3 + p4) + + custom_pulse1 = SymbolicPulse( + pulse_type="Custom", + duration=100, + parameters={"p1": 1, "p2": 0, "p3": 1, "p4": 0}, + envelope=envelope, + params_compare_exp=[p1 + p2, p3 + p4], + ) + custom_pulse2 = SymbolicPulse( + pulse_type="Custom", + duration=100, + parameters={"p1": 0, "p2": 1, "p3": 0, "p4": 1}, + envelope=envelope, + params_compare_exp=[p1 + p2, p3 + p4], + ) + self.assertEqual(custom_pulse1, custom_pulse2) + if __name__ == "__main__": unittest.main()