diff --git a/numpyro/distributions/constraints.py b/numpyro/distributions/constraints.py index 417b5f260..edcbce939 100644 --- a/numpyro/distributions/constraints.py +++ b/numpyro/distributions/constraints.py @@ -45,6 +45,8 @@ "multinomial", "nonnegative", "nonnegative_integer", + "optional", + "ordered_vector", "positive", "positive_definite", "positive_definite_circulant_vector", @@ -398,6 +400,43 @@ def __eq__(self, other: ConstraintT) -> bool: ) +class _OptionalConstraint(Constraint): + def __init__(self, base_constraint: Constraint) -> None: + assert isinstance(base_constraint, Constraint) + self.base_constraint = base_constraint + super().__init__() + + @property + def is_discrete(self) -> bool: + return self.base_constraint.is_discrete + + @property + def event_dim(self) -> int: + return self.base_constraint.event_dim + + def __call__(self, value: Optional[ArrayLike]) -> ArrayLike: + if value is None: + return True + return self.base_constraint(value) + + def __repr__(self) -> str: + return f"{self.__class__.__name__[1:]}({repr(self.base_constraint)})" + + def feasible_like(self, prototype: ArrayLike) -> ArrayLike: + return self.base_constraint.feasible_like(prototype) + + def tree_flatten(self): + return (self.base_constraint,), ( + ("base_constraint",), + {}, + ) + + def __eq__(self, other: ConstraintT) -> bool: + if not isinstance(other, _OptionalConstraint): + return False + return self.base_constraint == other.base_constraint + + class _RealVector(_IndependentConstraint, _SingletonConstraint): def __init__(self) -> None: super().__init__(_Real(), 1) @@ -819,6 +858,7 @@ def tree_flatten(self): multinomial = _Multinomial nonnegative: ConstraintT = _Nonnegative() nonnegative_integer: ConstraintT = _IntegerNonnegative() +optional = _OptionalConstraint ordered_vector: ConstraintT = _OrderedVector() positive: ConstraintT = _Positive() positive_definite: ConstraintT = _PositiveDefinite() diff --git a/numpyro/distributions/continuous.py b/numpyro/distributions/continuous.py index 254a8e60d..906647ee2 100644 --- a/numpyro/distributions/continuous.py +++ b/numpyro/distributions/continuous.py @@ -648,10 +648,11 @@ class GaussianStateSpace(TransformedDistribution): .. math:: \mathbf{z}_{t} &= \mathbf{A} \mathbf{z}_{t - 1} + \boldsymbol{\epsilon}_t\\ - &=\sum_{k=1} \mathbf{A}^{t-k} \boldsymbol{\epsilon}_t, + &= \mathbf{A}^t \mathbf{z}_0 + \sum_{k=1}^{t} \mathbf{A}^{t-k} \boldsymbol{\epsilon}_k, where :math:`\mathbf{z}_t` is the state vector at step :math:`t`, :math:`\mathbf{A}` - is the transition matrix, and :math:`\boldsymbol\epsilon` is the innovation noise. + is the transition matrix, :math:`\mathbf{z}_0` is the initial value, and + :math:`\boldsymbol\epsilon` is the innovation noise. :param num_steps: Number of steps. @@ -662,6 +663,8 @@ class GaussianStateSpace(TransformedDistribution): :math:`\boldsymbol\epsilon`. :param scale_tril: Scale matrix of the innovation noise :math:`\boldsymbol\epsilon`. + :param initial_value: Initial state vector :math:`\mathbf{z}_0`. If ``None``, + defaults to zero. """ arg_constraints = { @@ -669,6 +672,7 @@ class GaussianStateSpace(TransformedDistribution): "precision_matrix": constraints.positive_definite, "scale_tril": constraints.lower_cholesky, "transition_matrix": constraints.real_matrix, + "initial_value": constraints.optional(constraints.real_vector), } support = constraints.real_matrix pytree_aux_fields = ("num_steps",) @@ -680,6 +684,7 @@ def __init__( covariance_matrix: Optional[Array] = None, precision_matrix: Optional[Array] = None, scale_tril: Optional[Array] = None, + initial_value: Optional[Array] = None, *, validate_args: Optional[bool] = None, ) -> None: @@ -691,6 +696,7 @@ def __init__( "`transition_matrix` argument should be a square matrix" ) self.transition_matrix = transition_matrix + self.initial_value = initial_value # Expand the covariance/precision/scale matrices to the right number of steps. args = { "covariance_matrix": covariance_matrix, @@ -705,23 +711,45 @@ def __init__( base_distribution = MultivariateNormal(**args) self.scale_tril = base_distribution.scale_tril[..., 0, :, :] base_distribution = base_distribution.to_event(1) - transform = RecursiveLinearTransform(transition_matrix) + + # The base distribution must have at least the same batch shape as the initial + # value. + if initial_value is not None: + batch_shape = initial_value.shape[:-1] + base_distribution = base_distribution.expand(batch_shape) + + transform = RecursiveLinearTransform( + transition_matrix, initial_value=initial_value + ) super().__init__(base_distribution, transform, validate_args=validate_args) @property def mean(self) -> ArrayLike: - # The mean of the base distribution is zero and it has the right shape. - return self.base_dist.mean + # If there's no initial value, the mean is zero (base distribution mean). + if self.initial_value is None: + return self.base_dist.mean + + # Otherwise, we need to compute A^t @ z_0 for each time step t. + # z_t = A @ z_{t-1} for the deterministic part with z_0 = initial_value + def propagate(z, _): + z_next = jnp.einsum("...ij,...j->...i", self.transition_matrix, z) + return z_next, z_next + + _, means = scan(propagate, self.initial_value, jnp.arange(self.num_steps)) + # means has shape (num_steps, ..., state_dim) + # We need to move num_steps to axis -2 to match base_dist.mean shape + return jnp.moveaxis(means, 0, -2) @property def variance(self) -> ArrayLike: - # Given z_t = \sum_{k=1}^t A^{t-k} \epsilon_t, the covariance of the state + # Given z_t = z_0 + \sum_{k=1}^t A^{t-k} \epsilon_t, the covariance of the state # vector at step t is E[z_t transpose(z_t)] = \sum_{k,k'}^t A^{t-k} # E[\epsilon_k transpose(\epsilon_{k'})] transpose(A^{t-k'}). We only have # contributions for k = k' because innovations at different steps are # independent such that E[z_t transpose(z_t)] = \sum_k^t A^{t-k} @ - # @ covariance_matrix @ transpose(A^{t-k}). Using `scan` is an easy way to - # evaluate this expression. + # @ covariance_matrix @ transpose(A^{t-k}). The initial value is deterministic, + # and we don't need to consider it here. Using `scan` is an easy way to evaluate + # this expression. _, scale_tril = scan( lambda carry, _: (self.transition_matrix @ carry, carry), self.scale_tril, diff --git a/test/test_distributions.py b/test/test_distributions.py index 4c0daea3d..d4ad28946 100644 --- a/test/test_distributions.py +++ b/test/test_distributions.py @@ -631,6 +631,33 @@ def get_sp_dist(jax_dist): np.array([[0.8, 0.2], [-0.1, 1.1]]), np.array([0.1, 0.3, 0.25])[:, None, None] * np.array([[0.8, 0.2], [0.2, 0.7]]), ), + T( + dist.GaussianStateSpace, + 5, + np.array([[0.8, 0.1], [-0.1, 0.9]]), + None, + None, + np.array([[0.5, 0.0], [0.0, 0.5]]), + np.array([1.0, 2.0]), + ), + T( + dist.GaussianStateSpace, + 5, + np.array([[0.8, 0.1], [-0.1, 0.9]]), + None, + None, + np.array([[0.5, 0.0], [0.0, 0.5]]), + np.array([[1.0, 2.0], [0.5, 1.5], [-1.0, 0.0]]), + ), + T( + dist.GaussianStateSpace, + 4, + np.array([[0.9, 0.0], [0.0, 0.9]]), + None, + None, + np.array([[0.3, 0.0], [0.0, 0.3]]), + np.array([[[1.0, 0.0]], [[0.0, 1.0]]]), + ), pytest.param( *T( dist.GaussianCopulaBeta, @@ -1328,7 +1355,7 @@ def gen_values_within_bounds(constraint, size, key=None): elif constraint is constraints.ordered_vector: x = jnp.cumsum(random.exponential(key, size), -1) return x - random.normal(key, size[:-1] + (1,)) - elif isinstance(constraint, constraints.independent): + elif isinstance(constraint, (constraints.independent, constraints.optional)): return gen_values_within_bounds(constraint.base_constraint, size, key) elif constraint is constraints.sphere: x = random.normal(key, size) @@ -1404,7 +1431,7 @@ def gen_values_outside_bounds(constraint, size, key=None): elif constraint is constraints.ordered_vector: x = jnp.cumsum(random.exponential(key, size), -1) return x[..., ::-1] - elif isinstance(constraint, constraints.independent): + elif isinstance(constraint, (constraints.independent, constraints.optional)): return gen_values_outside_bounds(constraint.base_constraint, size, key) elif constraint is constraints.sphere: x = random.normal(key, size) @@ -3276,9 +3303,11 @@ def f(x): # Test that parameters do not change after flattening. expected_dist = f(0) actual_dist = jax.jit(f)(0) - for name in expected_dist.arg_constraints: + for name, constraint in expected_dist.arg_constraints.items(): expected_arg = getattr(expected_dist, name) actual_arg = getattr(actual_dist, name) + if actual_arg is None and isinstance(constraint, constraints.optional): + continue assert actual_arg is not None, f"arg {name} is None" if np.issubdtype(np.asarray(expected_arg).dtype, np.number): assert_allclose(actual_arg, expected_arg, atol=1e-7)