40 changes: 40 additions & 0 deletions numpyro/distributions/constraints.py
@@ -45,6 +45,8 @@
"multinomial",
"nonnegative",
"nonnegative_integer",
"optional",
"ordered_vector",
"positive",
"positive_definite",
"positive_definite_circulant_vector",
@@ -398,6 +400,43 @@ def __eq__(self, other: ConstraintT) -> bool:
)


class _OptionalConstraint(Constraint):
def __init__(self, base_constraint: Constraint) -> None:
assert isinstance(base_constraint, Constraint)
self.base_constraint = base_constraint
super().__init__()

@property
def is_discrete(self) -> bool:
return self.base_constraint.is_discrete

@property
def event_dim(self) -> int:
return self.base_constraint.event_dim

def __call__(self, value: Optional[ArrayLike]) -> ArrayLike:
if value is None:
return True
return self.base_constraint(value)

def __repr__(self) -> str:
return f"{self.__class__.__name__[1:]}({repr(self.base_constraint)})"

def feasible_like(self, prototype: ArrayLike) -> ArrayLike:
return self.base_constraint.feasible_like(prototype)

def tree_flatten(self):
return (self.base_constraint,), (
("base_constraint",),
{},
)

def __eq__(self, other: ConstraintT) -> bool:
if not isinstance(other, _OptionalConstraint):
return False
return self.base_constraint == other.base_constraint


class _RealVector(_IndependentConstraint, _SingletonConstraint):
def __init__(self) -> None:
super().__init__(_Real(), 1)
@@ -819,6 +858,7 @@ def tree_flatten(self):
multinomial = _Multinomial
nonnegative: ConstraintT = _Nonnegative()
nonnegative_integer: ConstraintT = _IntegerNonnegative()
optional = _OptionalConstraint
ordered_vector: ConstraintT = _OrderedVector()
positive: ConstraintT = _Positive()
positive_definite: ConstraintT = _PositiveDefinite()
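For context, a minimal usage sketch of the new wrapper (not part of the diff; it assumes a NumPyro build with this change applied). The public alias `constraints.optional` added above wraps any constraint so that `None` always passes, while any other value is delegated to the wrapped constraint:

```python
import jax.numpy as jnp
from numpyro.distributions import constraints

# Allow a parameter to be either absent (None) or a real-valued vector.
opt_real_vector = constraints.optional(constraints.real_vector)

print(opt_real_vector(None))                    # True: None is always feasible
print(opt_real_vector(jnp.array([0.5, -1.0])))  # delegated to real_vector
print(opt_real_vector == constraints.optional(constraints.real_vector))  # True: compares base constraints
```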
44 changes: 36 additions & 8 deletions numpyro/distributions/continuous.py
@@ -648,10 +648,11 @@ class GaussianStateSpace(TransformedDistribution):

.. math::
\mathbf{z}_{t} &= \mathbf{A} \mathbf{z}_{t - 1} + \boldsymbol{\epsilon}_t\\
&=\sum_{k=1} \mathbf{A}^{t-k} \boldsymbol{\epsilon}_t,
&= \mathbf{A}^t \mathbf{z}_0 + \sum_{k=1}^{t} \mathbf{A}^{t-k} \boldsymbol{\epsilon}_k,

where :math:`\mathbf{z}_t` is the state vector at step :math:`t`, :math:`\mathbf{A}`
is the transition matrix, and :math:`\boldsymbol\epsilon` is the innovation noise.
is the transition matrix, :math:`\mathbf{z}_0` is the initial value, and
:math:`\boldsymbol\epsilon` is the innovation noise.


:param num_steps: Number of steps.
@@ -662,13 +663,16 @@ class GaussianStateSpace(TransformedDistribution):
:math:`\boldsymbol\epsilon`.
:param scale_tril: Scale matrix of the innovation noise
:math:`\boldsymbol\epsilon`.
:param initial_value: Initial state vector :math:`\mathbf{z}_0`. If ``None``,
defaults to zero.
"""

arg_constraints = {
"covariance_matrix": constraints.positive_definite,
"precision_matrix": constraints.positive_definite,
"scale_tril": constraints.lower_cholesky,
"transition_matrix": constraints.real_matrix,
"initial_value": constraints.optional(constraints.real_vector),
}
support = constraints.real_matrix
pytree_aux_fields = ("num_steps",)
@@ -680,6 +684,7 @@ def __init__(
covariance_matrix: Optional[Array] = None,
precision_matrix: Optional[Array] = None,
scale_tril: Optional[Array] = None,
initial_value: Optional[Array] = None,
*,
validate_args: Optional[bool] = None,
) -> None:
@@ -691,6 +696,7 @@ def __init__(
"`transition_matrix` argument should be a square matrix"
)
self.transition_matrix = transition_matrix
self.initial_value = initial_value
# Expand the covariance/precision/scale matrices to the right number of steps.
args = {
"covariance_matrix": covariance_matrix,
@@ -705,23 +711,45 @@ def __init__(
base_distribution = MultivariateNormal(**args)
self.scale_tril = base_distribution.scale_tril[..., 0, :, :]
base_distribution = base_distribution.to_event(1)
transform = RecursiveLinearTransform(transition_matrix)

# The base distribution must have at least the same batch shape as the initial
# value.
if initial_value is not None:
batch_shape = initial_value.shape[:-1]
base_distribution = base_distribution.expand(batch_shape)

transform = RecursiveLinearTransform(
transition_matrix, initial_value=initial_value
)
super().__init__(base_distribution, transform, validate_args=validate_args)

@property
def mean(self) -> ArrayLike:
# The mean of the base distribution is zero and it has the right shape.
return self.base_dist.mean
# If there's no initial value, the mean is zero (base distribution mean).
if self.initial_value is None:
Review comment (Member): how about checking `self._initial_value is None` here and setting the property `self.initial_value` to zero if `self._initial_value` is None?
return self.base_dist.mean

# Otherwise, we need to compute A^t @ z_0 for each time step t.
# z_t = A @ z_{t-1} for the deterministic part with z_0 = initial_value
def propagate(z, _):
z_next = jnp.einsum("...ij,...j->...i", self.transition_matrix, z)
return z_next, z_next

_, means = scan(propagate, self.initial_value, jnp.arange(self.num_steps))
# means has shape (num_steps, ..., state_dim)
# We need to move num_steps to axis -2 to match base_dist.mean shape
return jnp.moveaxis(means, 0, -2)

@property
def variance(self) -> ArrayLike:
# Given z_t = \sum_{k=1}^t A^{t-k} \epsilon_t, the covariance of the state
# Given z_t = A^t z_0 + \sum_{k=1}^t A^{t-k} \epsilon_k, the covariance of the state
# vector at step t is E[z_t transpose(z_t)] = \sum_{k,k'}^t A^{t-k}
# E[\epsilon_k transpose(\epsilon_{k'})] transpose(A^{t-k'}). We only have
# contributions for k = k' because innovations at different steps are
# independent such that E[z_t transpose(z_t)] = \sum_k^t A^{t-k} @
# @ covariance_matrix @ transpose(A^{t-k}). Using `scan` is an easy way to
# evaluate this expression.
# @ covariance_matrix @ transpose(A^{t-k}). The initial value is deterministic,
# and we don't need to consider it here. Using `scan` is an easy way to evaluate
# this expression.
_, scale_tril = scan(
lambda carry, _: (self.transition_matrix @ carry, carry),
self.scale_tril,
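As a rough illustration of the new parameter (a sketch, not part of the diff; the values are made up and it assumes a NumPyro build with this change), passing `initial_value` makes the deterministic part A^t @ z_0 show up in the mean:

```python
import jax
import jax.numpy as jnp
import numpyro.distributions as dist

transition_matrix = jnp.array([[0.8, 0.1], [-0.1, 0.9]])
scale_tril = 0.5 * jnp.eye(2)
initial_value = jnp.array([1.0, 2.0])

d = dist.GaussianStateSpace(
    num_steps=5,
    transition_matrix=transition_matrix,
    scale_tril=scale_tril,
    initial_value=initial_value,
)

samples = d.sample(jax.random.PRNGKey(0), (10_000,))
print(samples.shape)  # (10000, 5, 2): trajectories with 5 steps of a 2-d state
print(d.mean)         # A^t @ z_0 for t = 1, ..., 5; the sample mean should be close
```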
35 changes: 32 additions & 3 deletions test/test_distributions.py
@@ -631,6 +631,33 @@ def get_sp_dist(jax_dist):
np.array([[0.8, 0.2], [-0.1, 1.1]]),
np.array([0.1, 0.3, 0.25])[:, None, None] * np.array([[0.8, 0.2], [0.2, 0.7]]),
),
T(
dist.GaussianStateSpace,
5,
np.array([[0.8, 0.1], [-0.1, 0.9]]),
None,
None,
np.array([[0.5, 0.0], [0.0, 0.5]]),
np.array([1.0, 2.0]),
),
T(
dist.GaussianStateSpace,
5,
np.array([[0.8, 0.1], [-0.1, 0.9]]),
None,
None,
np.array([[0.5, 0.0], [0.0, 0.5]]),
np.array([[1.0, 2.0], [0.5, 1.5], [-1.0, 0.0]]),
),
T(
dist.GaussianStateSpace,
4,
np.array([[0.9, 0.0], [0.0, 0.9]]),
None,
None,
np.array([[0.3, 0.0], [0.0, 0.3]]),
np.array([[[1.0, 0.0]], [[0.0, 1.0]]]),
),
pytest.param(
*T(
dist.GaussianCopulaBeta,
@@ -1328,7 +1355,7 @@ def gen_values_within_bounds(constraint, size, key=None):
elif constraint is constraints.ordered_vector:
x = jnp.cumsum(random.exponential(key, size), -1)
return x - random.normal(key, size[:-1] + (1,))
elif isinstance(constraint, constraints.independent):
elif isinstance(constraint, (constraints.independent, constraints.optional)):
return gen_values_within_bounds(constraint.base_constraint, size, key)
elif constraint is constraints.sphere:
x = random.normal(key, size)
@@ -1404,7 +1431,7 @@ def gen_values_outside_bounds(constraint, size, key=None):
elif constraint is constraints.ordered_vector:
x = jnp.cumsum(random.exponential(key, size), -1)
return x[..., ::-1]
elif isinstance(constraint, constraints.independent):
elif isinstance(constraint, (constraints.independent, constraints.optional)):
return gen_values_outside_bounds(constraint.base_constraint, size, key)
elif constraint is constraints.sphere:
x = random.normal(key, size)
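Both wrappers expose a `base_constraint` attribute, which is why the test helpers above can simply recurse into them. A small sketch of the same unwrapping idea (`unwrap` is a hypothetical helper, not part of the diff):

```python
from numpyro.distributions import constraints

def unwrap(constraint):
    # Peel off independent/optional wrappers until a base constraint remains.
    while isinstance(constraint, (constraints.independent, constraints.optional)):
        constraint = constraint.base_constraint
    return constraint

base = unwrap(constraints.optional(constraints.real_vector))
print(base)  # the underlying scalar real constraint
```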
@@ -3276,9 +3303,11 @@ def f(x):
# Test that parameters do not change after flattening.
expected_dist = f(0)
actual_dist = jax.jit(f)(0)
for name in expected_dist.arg_constraints:
for name, constraint in expected_dist.arg_constraints.items():
expected_arg = getattr(expected_dist, name)
actual_arg = getattr(actual_dist, name)
if actual_arg is None and isinstance(constraint, constraints.optional):
continue
assert actual_arg is not None, f"arg {name} is None"
if np.issubdtype(np.asarray(expected_arg).dtype, np.number):
assert_allclose(actual_arg, expected_arg, atol=1e-7)
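Finally, the adjusted pytree test above guards the behaviour sketched below (an illustration under the assumption that this branch is installed): a parameter whose constraint is `optional` may legitimately be `None`, and it should remain `None` after the distribution round-trips through `jax.jit`:

```python
import jax
import jax.numpy as jnp
import numpyro.distributions as dist

def make_dist(_):
    # No initial_value supplied, so the optional parameter stays None.
    return dist.GaussianStateSpace(3, jnp.eye(2), scale_tril=jnp.eye(2))

eager = make_dist(0)
jitted = jax.jit(make_dist)(0)
assert eager.initial_value is None
assert jitted.initial_value is None
```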