Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 10 additions & 6 deletions guppylang/std/qsystem/random.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,12 +37,12 @@ def __new__(seed: int) -> "RNG":

@guppy
def random_angle(self: "RNG") -> angle:
"""Generate a random angle in the range [-pi, pi)."""
r"""Generate a random angle in the range :math:`[-\pi, \pi)`."""
return (2.0 * self.random_float() - 1.0) * pi

@guppy
def random_clifford_angle(self: "RNG") -> angle:
"""Generate a random Clifford angle (multiple of pi/2)."""
r"""Generate a random Clifford angle (multiple of :math:`\pi/2`)."""
return self.random_int_bounded(4) * pi / 2

@guppy.hugr_op(external_op("DeleteRNGContext", [], ext=QSYSTEM_RANDOM_EXTENSION))
Expand Down Expand Up @@ -85,18 +85,22 @@ def shuffle(self: "RNG", array: array[SHUFFLE_T, SHUFFLE_N]) -> None:

@guppy.struct
class DiscreteDistribution(Generic[DISCRETE_N]): # type: ignore[misc]
"""A generic probability distribution over the set {0, 1, 2, ... DISCRETE_N - 1}.
"""A generic probability distribution over a set of the form {0, 1, ..., N-1}.

The `sums` array represents the cumulative probability distribution. That is,
sums[i] is the probability of drawing a value <= i from the distribution.
Objects of this class should be generated using
:py:meth:`make_discrete_distribution`.
"""

# The `sums` array represents the cumulative probability distribution. That is,
# sums[i] is the probability of drawing a value <= i from the distribution.
sums: array[float, DISCRETE_N] # type: ignore[valid-type]

@guppy
@no_type_check
def sample(self: "DiscreteDistribution[DISCRETE_N]", rng: RNG) -> int:
"""Return a sample value from the distribution."""
"""Return a sample value from the distribution, using the provided
:py:class:`RNG`.
"""
x = rng.random_float()
# Use binary search to find the least i s.t. sums[i] >= x.
i_min = 0
Expand Down
Loading