Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions numpyro/distributions/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0

from numpyro.distributions.censored import (
LeftCensoredDistribution,
RightCensoredDistribution,
)
from numpyro.distributions.conjugate import (
BetaBinomial,
DirichletMultinomial,
Expand Down Expand Up @@ -194,6 +198,8 @@
"RelaxedBernoulli",
"RelaxedBernoulliLogits",
"RightTruncatedDistribution",
"LeftCensoredDistribution",
"RightCensoredDistribution",
"SineBivariateVonMises",
"SineSkewed",
"SoftLaplace",
Expand Down
216 changes: 216 additions & 0 deletions numpyro/distributions/censored.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,216 @@
# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0


from typing import Optional

import jax
from jax import lax
import jax.numpy as jnp
from jax.typing import ArrayLike

from numpyro._typing import ConstraintT, DistributionT
from numpyro.distributions import constraints
from numpyro.distributions.distribution import Distribution
from numpyro.distributions.util import (
promote_shapes,
validate_sample,
)


class LeftCensoredDistribution(Distribution):
    r"""
    Distribution wrapper for left-censored outcomes.

    This distribution augments a base distribution with left-censoring,
    so that the likelihood contribution depends on the censoring indicator.

    Parameters
    ----------
    base_dist : numpyro.distributions.Distribution
        Parametric distribution for the *uncensored* values
        (e.g., Exponential, Weibull, LogNormal, Normal, etc.).
        This distribution must implement a `cdf` method.
    censored : array-like of {0,1}
        Censoring indicator per observation (defaults to ``False``, i.e.
        nothing is censored):
          - 0 → value is observed exactly
          - 1 → observation is left-censored at the reported value
                (true value occurred *on or before* the reported value)
    validate_args : bool, optional
        Forwarded unchanged to the ``Distribution`` constructor.

    Raises
    ------
    TypeError
        If `base_dist` does not expose a callable `cdf` attribute.

    Notes
    -----
    - The `log_prob(value)` method expects `value` to be the observed upper bound
      for each observation. The contribution to the log-likelihood is:

          log f(value)   if censored == 0
          log F(value)   if censored == 1

      where f is the density and F the cumulative distribution function of `base_dist`.

    - For numerical stability F(value) is clipped from below before the log is
      taken (at ``1e-12``, or at the smallest normal number of the CDF's dtype
      if that is larger), so `log F` is bounded below instead of reaching
      ``-inf`` in the far lower tail.

    - `sample` draws from `base_dist` directly; no censoring is applied to the
      draws.

    - This is commonly used in survival analysis, where event times are positive,
      but the approach is more general and can be applied to any distribution
      with a cumulative distribution function, regardless of support.

    - In R's **survival** package notation, this corresponds to
      `Surv(time, event, type = "left")`. Note that R's `event` indicator is the
      complement of `censored` here (`event = 1` means observed exactly,
      `event = 0` means left-censored).

      Example:
      `Surv(time = c(2, 4, 6), event = c(1, 0, 1), type = "left")`
      (i.e. `censored = [0, 1, 0]`) means:
        * subject 1 had an event exactly at t=2
        * subject 2 had an event before or at t=4 (left-censored)
        * subject 3 had an event exactly at t=6

    Examples
    --------
    >>> import jax.numpy as jnp
    >>> import numpyro.distributions as dist
    >>> base = dist.LogNormal(0., 1.)
    >>> surv_dist = dist.LeftCensoredDistribution(base, censored=jnp.array([0, 1, 1]))
    >>> loglik = surv_dist.log_prob(jnp.array([2., 4., 6.]))
    >>> # loglik[0] uses density at 2
    >>> # loglik[1] uses CDF at 4
    >>> # loglik[2] uses CDF at 6
    """

    arg_constraints = {"censored": constraints.boolean}
    pytree_data_fields = ("base_dist", "censored", "_support")

    def __init__(
        self,
        base_dist: DistributionT,
        censored: ArrayLike = False,
        *,
        validate_args: Optional[bool] = None,
    ):
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        # NOTE(review): presumably the Distribution base class ships a `cdf` stub
        # that raises NotImplementedError, in which case this only rejects
        # non-distribution objects — confirm against the base class.
        if not callable(getattr(base_dist, "cdf", None)):
            raise TypeError(
                "LeftCensoredDistribution requires a base distribution with a "
                f"`cdf` method, got {type(base_dist).__name__}."
            )
        # Base distribution and censoring mask share one batch shape.
        batch_shape = lax.broadcast_shapes(base_dist.batch_shape, jnp.shape(censored))
        self.base_dist: DistributionT = jax.tree.map(
            lambda p: promote_shapes(p, shape=batch_shape)[0], base_dist
        )
        self.censored = jnp.array(
            promote_shapes(censored, shape=batch_shape)[0], dtype=jnp.bool
        )
        # Cached so that `support` can return it without touching base_dist again.
        self._support = base_dist.support
        super().__init__(batch_shape, validate_args=validate_args)

    def sample(
        self, key: jax.dtypes.prng_key, sample_shape: tuple[int, ...] = ()
    ) -> ArrayLike:
        """Draw samples from the base distribution (censoring is not applied)."""
        return self.base_dist.sample(key, sample_shape)

    @constraints.dependent_property(is_discrete=False, event_dim=0)
    def support(self) -> ConstraintT:
        """Support of the base distribution, captured at construction time."""
        return self._support

    @validate_sample
    def log_prob(self, value: ArrayLike) -> ArrayLike:
        """Return ``log F(value)`` where censored, else ``log f(value)``."""
        cdf = self.base_dist.cdf(value)
        # The floor must be representable in the CDF's dtype: 1e-12 underflows
        # to 0 in half precision, which would defeat the clip and yield -inf.
        # For float32/float64 the floor stays at 1e-12.
        tiny = float(jnp.finfo(jnp.result_type(cdf, float)).tiny)
        floor = max(1e-12, tiny)
        log_cdf = jnp.log(jnp.clip(cdf, floor, 1.0))

        return jnp.where(
            self.censored,
            log_cdf,  # left-censored observations: log F(t)
            self.base_dist.log_prob(value),  # observed values: log f(t)
        )


class RightCensoredDistribution(Distribution):
    r"""
    Distribution wrapper for right-censored outcomes.

    This distribution augments a base distribution with right-censoring,
    so that the likelihood contribution depends on the censoring indicator.

    Parameters
    ----------
    base_dist : numpyro.distributions.Distribution
        Parametric distribution for the *uncensored* values
        (e.g., Exponential, Weibull, LogNormal, Normal, etc.).
        This distribution must implement a `cdf` method.
    censored : array-like of {0,1}
        Censoring indicator per observation (defaults to ``False``, i.e.
        nothing is censored):
          - 0 → value is observed exactly
          - 1 → observation is right-censored at the reported value
                (true value occurred *on or after* the reported value)
    validate_args : bool, optional
        Forwarded unchanged to the ``Distribution`` constructor.

    Raises
    ------
    TypeError
        If `base_dist` does not expose a callable `cdf` attribute.

    Notes
    -----
    - The `log_prob(value)` method expects `value` to be the observed lower bound
      for each observation. The contribution to the log-likelihood is:

          log f(value)         if censored == 0
          log (1 - F(value))   if censored == 1

      where f is the density and F the cumulative distribution function of `base_dist`.

    - For numerical stability F(value) is clipped to ``[0, 1 - delta]`` before
      `log1p(-F)` is evaluated, with ``delta = max(1e-7, eps)`` and ``eps`` the
      machine epsilon of the CDF's dtype, so the log-survival term is bounded
      below instead of reaching ``-inf`` in the far upper tail.

    - `sample` draws from `base_dist` directly; no censoring is applied to the
      draws.

    - This is commonly used in survival analysis, where event times are positive,
      but the approach is more general and can be applied to any distribution
      with a cumulative distribution function, regardless of support.

    - In R's **survival** package notation, this corresponds to
      `Surv(time, event)` with `type = "right"`. Note that R's `event` indicator
      is the complement of `censored` here.

      Example:
      `Surv(time = c(5, 8, 10), event = c(1, 0, 1))`
      (i.e. `censored = [0, 1, 0]`) means:
        * subject 1 had an event at t=5
        * subject 2 was censored at t=8
        * subject 3 had an event at t=10

    Examples
    --------
    >>> import jax.numpy as jnp
    >>> import numpyro.distributions as dist
    >>> base = dist.Exponential(rate=0.1)
    >>> surv_dist = dist.RightCensoredDistribution(base, censored=jnp.array([0, 1, 0]))
    >>> loglik = surv_dist.log_prob(jnp.array([5., 8., 10.]))
    >>> # loglik[0] uses density at 5
    >>> # loglik[1] uses survival at 8
    >>> # loglik[2] uses density at 10
    """

    arg_constraints = {"censored": constraints.boolean}
    pytree_data_fields = ("base_dist", "censored", "_support")

    def __init__(
        self,
        base_dist: DistributionT,
        censored: ArrayLike = False,
        *,
        validate_args: Optional[bool] = None,
    ):
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        # NOTE(review): presumably the Distribution base class ships a `cdf` stub
        # that raises NotImplementedError, in which case this only rejects
        # non-distribution objects — confirm against the base class.
        if not callable(getattr(base_dist, "cdf", None)):
            raise TypeError(
                "RightCensoredDistribution requires a base distribution with a "
                f"`cdf` method, got {type(base_dist).__name__}."
            )
        # Base distribution and censoring mask share one batch shape.
        batch_shape = lax.broadcast_shapes(base_dist.batch_shape, jnp.shape(censored))
        self.base_dist: DistributionT = jax.tree.map(
            lambda p: promote_shapes(p, shape=batch_shape)[0], base_dist
        )
        self.censored = jnp.array(
            promote_shapes(censored, shape=batch_shape)[0], dtype=jnp.bool
        )
        # Cached so that `support` can return it without touching base_dist again.
        self._support = base_dist.support
        super().__init__(batch_shape, validate_args=validate_args)

    def sample(
        self, key: jax.dtypes.prng_key, sample_shape: tuple[int, ...] = ()
    ) -> ArrayLike:
        """Draw samples from the base distribution (censoring is not applied)."""
        return self.base_dist.sample(key, sample_shape)

    @constraints.dependent_property(is_discrete=False, event_dim=0)
    def support(self) -> ConstraintT:
        """Support of the base distribution, captured at construction time."""
        return self._support

    @validate_sample
    def log_prob(self, value: ArrayLike) -> ArrayLike:
        """Return ``log(1 - F(value))`` where censored, else ``log f(value)``."""
        cdf = self.base_dist.cdf(value)
        # Without a margin below 1, censoring times far in the upper tail have
        # F == 1 and log(1 - F) == -inf.
        # NOTE(review): the 1e-7 margin was picked by the original author, who
        # flagged it as the least certain part of this implementation.
        # The margin must be at least the dtype's machine epsilon, otherwise
        # `1 - margin` rounds back to 1.0 and the clip silently does nothing.
        # For float32/float64 the bound is the same as a plain `1 - 1e-7`.
        eps = float(jnp.finfo(jnp.result_type(cdf, float)).eps)
        ceiling = 1.0 - max(1e-7, eps)
        log_survival = jnp.log1p(-jnp.clip(cdf, 0.0, ceiling))

        return jnp.where(
            self.censored,
            log_survival,  # censored observations: log S(t)
            self.base_dist.log_prob(value),  # observed values: log f(t)
        )
Loading