diff --git a/pyquil/quilatom.py b/pyquil/quilatom.py index a2eea641d..4aefd05a5 100644 --- a/pyquil/quilatom.py +++ b/pyquil/quilatom.py @@ -378,30 +378,31 @@ def _substitute(self, d: Any) -> ExpressionDesignator: return self -ParamSubstitutionsMapDesignator = Mapping["Parameter", ExpressionValueDesignator] +ParameterSubstitutionsMapDesignator = Mapping[Union["Parameter", "MemoryReference"], ExpressionValueDesignator] -def substitute(expr: ExpressionDesignator, d: ParamSubstitutionsMapDesignator) -> ExpressionDesignator: +def substitute(expr: ExpressionDesignator, d: ParameterSubstitutionsMapDesignator) -> ExpressionDesignator: """ - Using a dictionary of substitutions ``d`` try and explicitly evaluate as much of ``expr`` as - possible. + Using a dictionary of substitutions ``d``, try and explicitly evaluate as much of ``expr`` as + possible. This supports substitution of both parameters and memory references. Each memory + reference must be individually assigned a value at each memory offset to be substituted. - :param expr: The expression whose parameters are substituted. - :param d: Numerical substitutions for parameters. - :return: A partially simplified Expression or a number. + :param expr: The expression whose parameters or memory references are to be substituted. + :param d: Numerical substitutions for parameters or memory references. + :return: A partially simplified Expression, or a number. """ if isinstance(expr, Expression): return expr._substitute(d) return expr -def substitute_array(a: Union[Sequence[Expression], np.ndarray], d: ParamSubstitutionsMapDesignator) -> np.ndarray: +def substitute_array(a: Union[Sequence[Expression], np.ndarray], d: ParameterSubstitutionsMapDesignator) -> np.ndarray: """ Apply ``substitute`` to all elements of an array ``a`` and return the resulting array. - :param a: The expression array to substitute. - :param d: Numerical substitutions for parameters. - :return: An array of partially substituted Expressions or numbers. + :param a: The array of expressions whose parameters or memory references are to be substituted. + :param d: Numerical substitutions for parameters or memory references, for all array elements. + :return: An array of partially substituted Expressions, or numbers. """ a = np.asarray(a, order="C") return np.array([substitute(v, d) for v in a.flat]).reshape(a.shape) # type: ignore @@ -418,7 +419,7 @@ def __init__(self, name: str): def out(self) -> str: return "%" + self.name - def _substitute(self, d: ParamSubstitutionsMapDesignator) -> Union["Parameter", ExpressionValueDesignator]: + def _substitute(self, d: ParameterSubstitutionsMapDesignator) -> Union["Parameter", ExpressionValueDesignator]: return d.get(self, self) def __str__(self) -> str: @@ -446,7 +447,7 @@ def __init__( self.expression = expression self.fn = fn - def _substitute(self, d: ParamSubstitutionsMapDesignator) -> Union["Function", ExpressionValueDesignator]: + def _substitute(self, d: ParameterSubstitutionsMapDesignator) -> Union["Function", ExpressionValueDesignator]: sop = substitute(self.expression, d) if isinstance(sop, Expression): return Function(self.name, sop, self.fn) @@ -497,7 +498,7 @@ def __init__(self, op1: ExpressionDesignator, op2: ExpressionDesignator): self.op1 = op1 self.op2 = op2 - def _substitute(self, d: ParamSubstitutionsMapDesignator) -> Union["BinaryExp", ExpressionValueDesignator]: + def _substitute(self, d: ParameterSubstitutionsMapDesignator) -> Union["BinaryExp", ExpressionValueDesignator]: sop1, sop2 = substitute(self.op1, d), substitute(self.op2, d) return self.fn(sop1, sop2) @@ -709,6 +710,12 @@ def __getitem__(self, offset: int) -> "MemoryReference": return MemoryReference(name=self.name, offset=offset) + def _substitute(self, parameter_memory_map) -> ExpressionDesignator: + if self in parameter_memory_map: + return parameter_memory_map[self] + + return self + def _contained_mrefs(expression: ExpressionDesignator) -> Set[MemoryReference]: """ diff --git a/test/unit/test_memory.py b/test/unit/test_memory.py index 1dade5bf3..ec0bbd23a 100644 --- a/test/unit/test_memory.py +++ b/test/unit/test_memory.py @@ -6,6 +6,16 @@ pauli_term_to_measurement_memory_map, ) from pyquil.paulis import sX, sY +from pyquil.quilatom import ( + MemoryReference, + quil_cis, + quil_cos, + quil_exp, + quil_sin, + quil_sqrt, + substitute, + substitute_array, +) def test_merge_memory_map_lists(): @@ -34,3 +44,40 @@ def test_pauli_term_to_measurement_memory_map(): "measurement_beta": [0.0, np.pi / 2], "measurement_gamma": [0.0, -np.pi / 2], } + + +def test_substitute_memory_reference(): + x_0 = MemoryReference("x", 0, declared_size=2) + x_1 = MemoryReference("x", 1, declared_size=2) + + # complete substitutions + + assert substitute(x_0, {x_0: 5}) == 5 + + assert substitute(x_0 + x_1, {x_0: +5, x_1: -5}) == 0 + assert substitute(x_0 - x_1, {x_0: +5, x_1: -5}) == 10 + assert substitute(x_0 * x_1, {x_0: +5, x_1: -5}) == -25 + assert substitute(x_0 / x_1, {x_0: +5, x_1: -5}) == -1 + + assert substitute(x_0 * x_0 ** 2 / x_1, {x_0: 5, x_1: 10}) == 12.5 + + assert np.isclose(substitute(quil_exp(x_0), {x_0: 5, x_1: 10}), np.exp(5)) + assert np.isclose(substitute(quil_sin(x_0), {x_0: 5, x_1: 10}), np.sin(5)) + assert np.isclose(substitute(quil_cos(x_0), {x_0: 5, x_1: 10}), np.cos(5)) + assert np.isclose(substitute(quil_sqrt(x_0), {x_0: 5, x_1: 10}), np.sqrt(5)) + assert np.isclose(substitute(quil_cis(x_0), {x_0: 5, x_1: 10}), np.exp(1j * 5.0)) + + # incomplete substitutions + + y = MemoryReference("y", 0, declared_size=1) + z = MemoryReference("z", 0, declared_size=1) + + assert substitute(y + z, {y: 5}) == 5 + z + + assert substitute(quil_cis(z), {y: 5}) == quil_cis(z) + + # array substitution pass-through + + a = MemoryReference("a", 0, declared_size=1) + + assert np.allclose(substitute_array([quil_sin(a), quil_cos(a)], {a: 5}), [np.sin(5), np.cos(5)])