-
Notifications
You must be signed in to change notification settings - Fork 273
Add initial value to Gaussian state space distribution (fixes #2098). #2104
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add initial value to Gaussian state space distribution (fixes #2098). #2104
Conversation
70ba63c to
624ee8b
Compare
|
I was playing around with this simple example and I like it 👍 import jax
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS
def model(data):
reversion_speed = numpyro.sample("k",dist.HalfNormal(1.0))
sigma = numpyro.sample("sigma",dist.HalfNormal(1.0))
x0 = numpyro.sample(
"x0",
dist.Normal(0.0, jnp.divide(jnp.power(sigma, 2), 2 * reversion_speed)),
obs=data[0]
)
numpyro.sample(
"x",
dist.GaussianStateSpace(
num_steps=jnp.shape(data)[-1]-1,
transition_matrix=jnp.array([[ jnp.exp(- jnp.multiply(reversion_speed, 1.0)) - 1 ]]),
covariance_matrix=jnp.array([[ jnp.multiply(jnp.divide(jnp.power(sigma, 2), 2), jnp.divide(1 - jnp.exp(- jnp.multiply(reversion_speed, 1.0)), reversion_speed)) ]]),
initial_value = jnp.stack([x0])
),
obs=jnp.stack([data[1:]], axis=-1)
)
mcmc = MCMC(
NUTS(
model=model
),
num_warmup=500,
num_samples=1_000
)
mcmc.run(
rng_key=jax.random.PRNGKey(2),
data=(
lambda reversion_speed, sigma, std_norm: jnp.concatenate([
jnp.array([jnp.sqrt(jnp.divide(jnp.power(sigma, 2), 2 * reversion_speed)) * std_norm[0]]),
jax.lax.scan(
lambda y, x: (jnp.multiply(y, jnp.exp(- jnp.multiply(reversion_speed, 1.0)) - 1) + x * jnp.sqrt(jnp.multiply(jnp.divide(jnp.power(sigma, 2), 2), jnp.divide(1 - jnp.exp(- jnp.multiply(reversion_speed, 1.0)), reversion_speed))),) * 2,
init=jnp.sqrt(jnp.divide(jnp.power(sigma, 2), 2 * reversion_speed)) * std_norm[0],
xs=std_norm[1:]
)[1]
])
)(
0.1, 0.5, jax.random.normal(key=jax.random.PRNGKey(10), shape=(20,))
)
)
mcmc.print_summary() |
numpyro/distributions/continuous.py
Outdated
| # The mean of the base distribution is zero and it has the right shape. | ||
| return self.base_dist.mean | ||
| # If there's no initial value, the mean is zero (base distribution mean). | ||
| if self.initial_value is None: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
how about checking self._initial_value is None and setting the property self.initial_value to zero if self._initial_value is None
numpyro/distributions/continuous.py
Outdated
| "precision_matrix": constraints.positive_definite, | ||
| "scale_tril": constraints.lower_cholesky, | ||
| "transition_matrix": constraints.real_matrix, | ||
| "initial_value": constraints.optional(constraints.real_vector), |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think optional is unnecessary: just use constraints.real_vector here and make initial_value a property like in the above comment.
624ee8b to
7fe8456
Compare
7fe8456 to
ef0952e
Compare
fehiepsi
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks @tillahoffmann!
This PR adds an
initial_valueargument to theGaussianStateSpacedistribution as suggested in #2098.As part of this change, I added an
optionalconstraint. I'm a bit torn on whether that's the right choice, and we could instead promote 0 to the right shape. However, that would make evaluating themeanof the distribution relatively inefficient when there is no initial value: We'd stillscanover the sequence even though we should really just return zeros (although maybejax.jitamortizes that?). Open to suggestions.