Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 21 additions & 14 deletions pyquil/quilatom.py
Original file line number Diff line number Diff line change
Expand Up @@ -378,30 +378,31 @@ def _substitute(self, d: Any) -> ExpressionDesignator:
return self


ParamSubstitutionsMapDesignator = Mapping["Parameter", ExpressionValueDesignator]
ParameterSubstitutionsMapDesignator = Mapping[Union["Parameter", "MemoryReference"], ExpressionValueDesignator]


def substitute(expr: ExpressionDesignator, d: ParamSubstitutionsMapDesignator) -> ExpressionDesignator:
def substitute(expr: ExpressionDesignator, d: ParameterSubstitutionsMapDesignator) -> ExpressionDesignator:
"""
Using a dictionary of substitutions ``d`` try and explicitly evaluate as much of ``expr`` as
possible.
Using a dictionary of substitutions ``d``, try and explicitly evaluate as much of ``expr`` as
possible. This supports substitution of both parameters and memory references. Each memory
reference must be individually assigned a value at each memory offset to be substituted.

:param expr: The expression whose parameters are substituted.
:param d: Numerical substitutions for parameters.
:return: A partially simplified Expression or a number.
:param expr: The expression whose parameters or memory references are to be substituted.
:param d: Numerical substitutions for parameters or memory references.
:return: A partially simplified Expression, or a number.
"""
if isinstance(expr, Expression):
return expr._substitute(d)
return expr


def substitute_array(a: Union[Sequence[Expression], np.ndarray], d: ParamSubstitutionsMapDesignator) -> np.ndarray:
def substitute_array(a: Union[Sequence[Expression], np.ndarray], d: ParameterSubstitutionsMapDesignator) -> np.ndarray:
"""
Apply ``substitute`` to all elements of an array ``a`` and return the resulting array.

:param a: The expression array to substitute.
:param d: Numerical substitutions for parameters.
:return: An array of partially substituted Expressions or numbers.
:param a: The array of expressions whose parameters or memory references are to be substituted.
:param d: Numerical substitutions for parameters or memory references, for all array elements.
:return: An array of partially substituted Expressions, or numbers.
"""
a = np.asarray(a, order="C")
return np.array([substitute(v, d) for v in a.flat]).reshape(a.shape) # type: ignore
Expand All @@ -418,7 +419,7 @@ def __init__(self, name: str):
def out(self) -> str:
return "%" + self.name

def _substitute(self, d: ParamSubstitutionsMapDesignator) -> Union["Parameter", ExpressionValueDesignator]:
def _substitute(self, d: ParameterSubstitutionsMapDesignator) -> Union["Parameter", ExpressionValueDesignator]:
return d.get(self, self)

def __str__(self) -> str:
Expand Down Expand Up @@ -446,7 +447,7 @@ def __init__(
self.expression = expression
self.fn = fn

def _substitute(self, d: ParamSubstitutionsMapDesignator) -> Union["Function", ExpressionValueDesignator]:
def _substitute(self, d: ParameterSubstitutionsMapDesignator) -> Union["Function", ExpressionValueDesignator]:
sop = substitute(self.expression, d)
if isinstance(sop, Expression):
return Function(self.name, sop, self.fn)
Expand Down Expand Up @@ -497,7 +498,7 @@ def __init__(self, op1: ExpressionDesignator, op2: ExpressionDesignator):
self.op1 = op1
self.op2 = op2

def _substitute(self, d: ParamSubstitutionsMapDesignator) -> Union["BinaryExp", ExpressionValueDesignator]:
def _substitute(self, d: ParameterSubstitutionsMapDesignator) -> Union["BinaryExp", ExpressionValueDesignator]:
sop1, sop2 = substitute(self.op1, d), substitute(self.op2, d)
return self.fn(sop1, sop2)

Expand Down Expand Up @@ -709,6 +710,12 @@ def __getitem__(self, offset: int) -> "MemoryReference":

return MemoryReference(name=self.name, offset=offset)

def _substitute(self, parameter_memory_map) -> ExpressionDesignator:
if self in parameter_memory_map:
return parameter_memory_map[self]

return self


def _contained_mrefs(expression: ExpressionDesignator) -> Set[MemoryReference]:
"""
Expand Down
47 changes: 47 additions & 0 deletions test/unit/test_memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,16 @@
pauli_term_to_measurement_memory_map,
)
from pyquil.paulis import sX, sY
from pyquil.quilatom import (
MemoryReference,
quil_cis,
quil_cos,
quil_exp,
quil_sin,
quil_sqrt,
substitute,
substitute_array,
)


def test_merge_memory_map_lists():
Expand Down Expand Up @@ -34,3 +44,40 @@ def test_pauli_term_to_measurement_memory_map():
"measurement_beta": [0.0, np.pi / 2],
"measurement_gamma": [0.0, -np.pi / 2],
}


def test_substitute_memory_reference():
x_0 = MemoryReference("x", 0, declared_size=2)
x_1 = MemoryReference("x", 1, declared_size=2)

# complete substitutions

assert substitute(x_0, {x_0: 5}) == 5

assert substitute(x_0 + x_1, {x_0: +5, x_1: -5}) == 0
assert substitute(x_0 - x_1, {x_0: +5, x_1: -5}) == 10
assert substitute(x_0 * x_1, {x_0: +5, x_1: -5}) == -25
assert substitute(x_0 / x_1, {x_0: +5, x_1: -5}) == -1

assert substitute(x_0 * x_0 ** 2 / x_1, {x_0: 5, x_1: 10}) == 12.5

assert np.isclose(substitute(quil_exp(x_0), {x_0: 5, x_1: 10}), np.exp(5))
assert np.isclose(substitute(quil_sin(x_0), {x_0: 5, x_1: 10}), np.sin(5))
assert np.isclose(substitute(quil_cos(x_0), {x_0: 5, x_1: 10}), np.cos(5))
assert np.isclose(substitute(quil_sqrt(x_0), {x_0: 5, x_1: 10}), np.sqrt(5))
assert np.isclose(substitute(quil_cis(x_0), {x_0: 5, x_1: 10}), np.exp(1j * 5.0))

# incomplete substitutions

y = MemoryReference("y", 0, declared_size=1)
z = MemoryReference("z", 0, declared_size=1)

assert substitute(y + z, {y: 5}) == 5 + z

assert substitute(quil_cis(z), {y: 5}) == quil_cis(z)

# array substitution pass-through

a = MemoryReference("a", 0, declared_size=1)

assert np.allclose(substitute_array([quil_sin(a), quil_cos(a)], {a: 5}), [np.sin(5), np.cos(5)])