Skip to content

Conversation

@tillahoffmann
Copy link
Collaborator

@tillahoffmann tillahoffmann commented Dec 4, 2025

This PR adds an initial_value argument to the GaussianStateSpace distribution as suggested in #2098.

As part of this change, I added an optional constraint. I'm a bit torn on whether that's the right choice, and we could instead promote 0 to the right shape. However, that would make evaluating the mean of the distribution relatively inefficient when there is no initial value: We'd still scan over the sequence even though we should really just return zeros (although maybe jax.jit amortizes that?). Open to suggestions.

@tillahoffmann tillahoffmann added enhancement New feature or request question Further information is requested labels Dec 4, 2025
@tillahoffmann tillahoffmann force-pushed the init-gaussian-state-space branch 2 times, most recently from 70ba63c to 624ee8b Compare December 4, 2025 02:48
@juanitorduz juanitorduz requested a review from fehiepsi December 5, 2025 12:52
@javier-garcia-tilburg
Copy link

javier-garcia-tilburg commented Dec 6, 2025

I was playing around with this simple example and I like it 👍

import jax
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS

def model(data):
    reversion_speed = numpyro.sample("k",dist.HalfNormal(1.0))
    sigma = numpyro.sample("sigma",dist.HalfNormal(1.0))
    x0 = numpyro.sample(
        "x0",
        dist.Normal(0.0, jnp.divide(jnp.power(sigma, 2), 2 * reversion_speed)),
        obs=data[0]
    )

    numpyro.sample(
        "x",
        dist.GaussianStateSpace(
            num_steps=jnp.shape(data)[-1]-1,
            transition_matrix=jnp.array([[ jnp.exp(- jnp.multiply(reversion_speed, 1.0)) - 1 ]]),
            covariance_matrix=jnp.array([[ jnp.multiply(jnp.divide(jnp.power(sigma, 2), 2), jnp.divide(1 - jnp.exp(- jnp.multiply(reversion_speed, 1.0)), reversion_speed)) ]]),
            initial_value = jnp.stack([x0])
        ),
        obs=jnp.stack([data[1:]], axis=-1)
    )

mcmc = MCMC(
    NUTS(
        model=model
    ), 
    num_warmup=500, 
    num_samples=1_000
)
mcmc.run(
    rng_key=jax.random.PRNGKey(2),
    data=(
        lambda reversion_speed, sigma, std_norm: jnp.concatenate([
            jnp.array([jnp.sqrt(jnp.divide(jnp.power(sigma, 2), 2 * reversion_speed)) * std_norm[0]]),
            jax.lax.scan(
                lambda y, x: (jnp.multiply(y, jnp.exp(- jnp.multiply(reversion_speed, 1.0)) - 1) + x * jnp.sqrt(jnp.multiply(jnp.divide(jnp.power(sigma, 2), 2), jnp.divide(1 - jnp.exp(- jnp.multiply(reversion_speed, 1.0)), reversion_speed))),) * 2,
                init=jnp.sqrt(jnp.divide(jnp.power(sigma, 2), 2 * reversion_speed)) * std_norm[0],
                xs=std_norm[1:]
            )[1]
        ])
    )(
        0.1, 0.5, jax.random.normal(key=jax.random.PRNGKey(10), shape=(20,))
    )
)
mcmc.print_summary()

# The mean of the base distribution is zero and it has the right shape.
return self.base_dist.mean
# If there's no initial value, the mean is zero (base distribution mean).
if self.initial_value is None:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

how about checking self._initial_value is None and setting the property self.initial_value to zero if self._initial_value is None

"precision_matrix": constraints.positive_definite,
"scale_tril": constraints.lower_cholesky,
"transition_matrix": constraints.real_matrix,
"initial_value": constraints.optional(constraints.real_vector),
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think optional is unnecessary: just use constraints.real_vector here and make initial_value a property like in the above comment.

@tillahoffmann tillahoffmann force-pushed the init-gaussian-state-space branch from 624ee8b to 7fe8456 Compare January 20, 2026 15:47
@tillahoffmann tillahoffmann force-pushed the init-gaussian-state-space branch from 7fe8456 to ef0952e Compare January 20, 2026 15:50
@tillahoffmann tillahoffmann added awaiting review and removed question Further information is requested awaiting response labels Jan 20, 2026
Copy link
Member

@fehiepsi fehiepsi left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks @tillahoffmann!

@fehiepsi fehiepsi merged commit 931d67d into pyro-ppl:master Jan 20, 2026
8 checks passed
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

awaiting review enhancement New feature or request

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants