Skip to content

Commit

Permalink
fix renormalization issue (#6190)
Browse files Browse the repository at this point in the history
Normalization errors were occuring when the state was jax and x64 mode
was not enabled. Simply renormalizing the state did not help, as the
norm only differed from 1 by machine precision.


This PR simply defers sampling to jax whenever the state is jax. If a
prng key is not provided, we just pull a seed from the numpy rng.
Pulling a prng seed from the numpy rng will not be particularily
jit-able, but if the user wants it to be jittable, they can provide a
prng key instead.

Fixes #6100 [sc-71361]

---------

Co-authored-by: Mudit Pandey <[email protected]>
  • Loading branch information
albi3ro and mudit2812 authored Aug 30, 2024
1 parent 60d1c73 commit bb577c4
Show file tree
Hide file tree
Showing 4 changed files with 64 additions and 11 deletions.
4 changes: 4 additions & 0 deletions doc/releases/changelog-0.38.0.md
Original file line number Diff line number Diff line change
Expand Up @@ -382,6 +382,10 @@

<h3>Bug fixes 🐛</h3>

* For `default.qubit`, JAX is now used for sampling whenever the state is a JAX array. This fixes normalization issues
that can occur when the state uses 32-bit precision.
[(#6190)](https://github.com/PennyLaneAI/pennylane/pull/6190)

* Fix Pytree serialization of operators with empty shot vectors
[(#6155)](https://github.com/PennyLaneAI/pennylane/pull/6155)

Expand Down
13 changes: 8 additions & 5 deletions pennylane/devices/qubit/sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -471,9 +471,9 @@ def sample_state(
Returns:
ndarray[int]: Sample values of the shape (shots, num_wires)
"""
if prng_key is not None:
if prng_key is not None or qml.math.get_interface(state) == "jax":
return _sample_state_jax(
state, shots, prng_key, is_state_batched=is_state_batched, wires=wires
state, shots, prng_key, is_state_batched=is_state_batched, wires=wires, seed=rng
)

rng = np.random.default_rng(rng)
Expand Down Expand Up @@ -530,6 +530,7 @@ def _sample_state_jax(
prng_key,
is_state_batched: bool = False,
wires=None,
seed=None,
) -> np.ndarray:
"""
Returns a series of samples of a state for the JAX interface based on the PRNG.
Expand All @@ -541,6 +542,7 @@ def _sample_state_jax(
the key to the JAX pseudo random number generator.
is_state_batched (bool): whether the state is batched or not
wires (Sequence[int]): The wires to sample
seed (numpy.random.Generator): seed to use to generate a key if a ``prng_key`` is not present. ``None`` by default.
Returns:
ndarray[int]: Sample values of the shape (shots, num_wires)
Expand All @@ -549,7 +551,8 @@ def _sample_state_jax(
import jax
import jax.numpy as jnp

key = prng_key
if prng_key is None:
prng_key = jax.random.PRNGKey(np.random.default_rng(seed).integers(100000))

total_indices = len(state.shape) - is_state_batched
state_wires = qml.wires.Wires(range(total_indices))
Expand All @@ -574,6 +577,6 @@ def _sample_state_jax(
_, key = jax_random_split(prng_key)
samples = jax.random.choice(key, basis_states, shape=(shots,), p=probs)

powers_of_two = 1 << np.arange(num_wires, dtype=np.int64)[::-1]
powers_of_two = 1 << np.arange(num_wires, dtype=int)[::-1]
states_sampled_base_ten = samples[..., None] & powers_of_two
return (states_sampled_base_ten > 0).astype(np.int64)
return (states_sampled_base_ten > 0).astype(int)
46 changes: 46 additions & 0 deletions tests/devices/default_qubit/test_default_qubit.py
Original file line number Diff line number Diff line change
Expand Up @@ -2197,3 +2197,49 @@ def test_broadcasted_parameter(max_workers):
results = dev.execute(batch, config)
processed_results = pre_processing_fn(results)
assert qml.math.allclose(processed_results, np.cos(x))


@pytest.mark.jax
def test_renomalization_issue():
"""Test that no normalization error occurs with the following workflow in float32 mode.
Just tests executes without error. Not producing a more minimal example due to difficulty
finding an exact case that leads to renomalization issues.
"""
import jax
from jax import numpy as jnp

initial_mode = jax.config.jax_enable_x64
jax.config.update("jax_enable_x64", False)

def gaussian_fn(p, t):
return p[0] * jnp.exp(-((t - p[1]) ** 2) / (2 * p[2] ** 2))

global_drive = qml.pulse.rydberg_drive(
amplitude=gaussian_fn, phase=0, detuning=0, wires=[0, 1, 2]
)

a = 5

coordinates = [(0, 0), (a, 0), (a / 2, np.sqrt(a**2 - (a / 2) ** 2))]

settings = {"interaction_coeff": 862619.7915580727}

H_interaction = qml.pulse.rydberg_interaction(coordinates, **settings)

max_amplitude = 2.0
displacement = 1.0
sigma = 0.3

amplitude_params = [max_amplitude, displacement, sigma]

params = [amplitude_params]
ts = [0.0, 1.75]

def circuit(params):
qml.evolve(H_interaction + global_drive)(params, ts)
return qml.counts()

circuit_qml = qml.QNode(circuit, qml.device("default.qubit", shots=1000), interface="jax")

circuit_qml(params)
jax.config.update("jax_enable_x64", initial_mode)
12 changes: 6 additions & 6 deletions tests/devices/qubit/test_sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,8 +94,6 @@ def test_prng_key_as_seed_uses_sample_state_jax(self, mocker):

# prng_key specified, should call _sample_state_jax
_ = sample_state(state, 10, prng_key=jax.random.PRNGKey(15))
# prng_key defaults to None, should NOT call _sample_state_jax
_ = sample_state(state, 10, rng=15)

spy.assert_called_once()

Expand Down Expand Up @@ -723,7 +721,7 @@ def test_nan_shadow_expval(self, H, interface, shots):


two_qubit_state_to_be_normalized = np.array([[0, 1.0000000005j], [-1, 0]]) / np.sqrt(2)
two_qubit_state_not_normalized = np.array([[0, 1.0000005j], [-1.00000001, 0]]) / np.sqrt(2)
two_qubit_state_not_normalized = np.array([[0, 1.00005j], [-1.00000001, 0]]) / np.sqrt(2)

batched_state_to_be_normalized = np.stack(
[
Expand Down Expand Up @@ -752,8 +750,9 @@ def test_sample_state_renorm(self, interface):
state = qml.math.array(two_qubit_state_to_be_normalized, like=interface)
_ = sample_state(state, 10)

# jax.random.choice accepts unnormalized probabilities
@pytest.mark.all_interfaces
@pytest.mark.parametrize("interface", ["numpy", "jax", "torch", "tensorflow"])
@pytest.mark.parametrize("interface", ["numpy", "torch", "tensorflow"])
def test_sample_state_renorm_error(self, interface):
"""Test that renormalization does not occur if the error is too large."""

Expand All @@ -762,15 +761,16 @@ def test_sample_state_renorm_error(self, interface):
_ = sample_state(state, 10)

@pytest.mark.all_interfaces
@pytest.mark.parametrize("interface", ["numpy", "jax", "torch", "tensorflow"])
@pytest.mark.parametrize("interface", ["numpy", "torch", "jax", "tensorflow"])
def test_sample_batched_state_renorm(self, interface):
"""Test renormalization for a batched state."""

state = qml.math.array(batched_state_to_be_normalized, like=interface)
_ = sample_state(state, 10, is_state_batched=True)

# jax.random.choices accepts unnormalized probabilities
@pytest.mark.all_interfaces
@pytest.mark.parametrize("interface", ["numpy", "jax", "torch", "tensorflow"])
@pytest.mark.parametrize("interface", ["numpy", "torch", "tensorflow"])
def test_sample_batched_state_renorm_error(self, interface):
"""Test that renormalization does not occur if the error is too large."""

Expand Down

0 comments on commit bb577c4

Please sign in to comment.