From 9db38b91e20beabea173021991c2a37560799099 Mon Sep 17 00:00:00 2001 From: vanamsterdam Date: Mon, 6 Oct 2025 21:41:45 +0200 Subject: [PATCH 1/6] implement left and right censored --- numpyro/distributions/__init__.py | 6 + numpyro/distributions/censored.py | 217 ++++++++++++++++++++++++++++++ 2 files changed, 223 insertions(+) create mode 100644 numpyro/distributions/censored.py diff --git a/numpyro/distributions/__init__.py b/numpyro/distributions/__init__.py index 9e0e2910f..d60ac788a 100644 --- a/numpyro/distributions/__init__.py +++ b/numpyro/distributions/__init__.py @@ -1,6 +1,10 @@ # Copyright Contributors to the Pyro project. # SPDX-License-Identifier: Apache-2.0 +from numpyro.distributions.censored import ( + LeftCensoredDistribution, + RightCensoredDistribution, +) from numpyro.distributions.conjugate import ( BetaBinomial, DirichletMultinomial, @@ -194,6 +198,8 @@ "RelaxedBernoulli", "RelaxedBernoulliLogits", "RightTruncatedDistribution", + "LeftCensoredDistribution", + "RightCensoredDistribution", "SineBivariateVonMises", "SineSkewed", "SoftLaplace", diff --git a/numpyro/distributions/censored.py b/numpyro/distributions/censored.py new file mode 100644 index 000000000..abe29836b --- /dev/null +++ b/numpyro/distributions/censored.py @@ -0,0 +1,217 @@ +# Copyright Contributors to the Pyro project. +# SPDX-License-Identifier: Apache-2.0 + + +from typing import Optional + +import jax +from jax import lax +import jax.numpy as jnp +from jax.typing import ArrayLike + +from numpyro._typing import ConstraintT, DistributionT +from numpyro.distributions import constraints +from numpyro.distributions.distribution import Distribution +from numpyro.distributions.util import ( + promote_shapes, + validate_sample, +) + + +class LeftCensoredDistribution(Distribution): + r""" + Distribution wrapper for left-censored outcomes. + + This distribution augments a base distribution with left-censoring, + so that the likelihood contribution depends on the censoring indicator. + + Parameters + ---------- + base_dist : numpyro.distributions.Distribution + Parametric distribution for the *uncensored* values + (e.g., Exponential, Weibull, LogNormal, Normal, etc.). + This distribution must implement a `cdf` method. + censored : array-like of {0,1} + Censoring indicator per observation: + - 0 → value is observed exactly + - 1 → observation is left-censored at the reported value + (true value occurred *on or before* the reported value) + + Notes + ----- + - The `log_prob(value)` method expects `value` to be the observed upper bound + for each observation. The contribution to the log-likelihood is: + + log f(value) if censored == 0 + log F(value) if censored == 1 + + where f is the density and F the cumulative distribution function of `base_dist`. + + - This is commonly used in survival analysis, where event times are positive, + but the approach is more general and can be applied to any distribution + with a cumulative distribution function, regardless of support. + + - In R's **survival** package notation, this corresponds to + `Surv(time, event, type = "left")`. + + Example: + `Surv(time = c(2, 4, 6), event = c(1, 0, 1), type="left")` + means: + * subject 1 had an event exactly at t=2 + * subject 2 had an event before or at t=4 (left-censored) + * subject 3 had an event exactly at t=6 + + Examples + -------- + >>> base = dist.LogNormal(0., 1.)
+ >>> surv_dist = LeftCensoredDistribution(base, censored=jnp.array([0, 1, 1])) + >>> loglik = surv_dist.log_prob(jnp.array([2., 4., 6.])) + # loglik[0] uses density at 2 + # loglik[1] uses CDF at 4 + # loglik[2] uses CDF at 6 + """ + + arg_constraints = {"censored": constraints.boolean} + reparametrized_params = ["censored"] + pytree_data_fields = ("base_dist", "censored", "_support") + + def __init__( + self, + base_dist: DistributionT, + censored: ArrayLike = False, + *, + validate_args: Optional[bool] = None, + ): + # test if base_dist has an implemented cdf method + assert hasattr(base_dist, "cdf") + batch_shape = lax.broadcast_shapes(base_dist.batch_shape, jnp.shape(censored)) + self.base_dist: DistributionT = jax.tree.map( + lambda p: promote_shapes(p, shape=batch_shape)[0], base_dist + ) + self.censored = jnp.array( + promote_shapes(censored, shape=batch_shape)[0], dtype=jnp.bool + ) + self._support = base_dist.support + super().__init__(batch_shape, validate_args=validate_args) + + def sample( + self, key: jax.dtypes.prng_key, sample_shape: tuple[int, ...] = () + ) -> ArrayLike: + return self.base_dist.sample(key, sample_shape) + + @constraints.dependent_property(is_discrete=False, event_dim=0) + def support(self) -> ConstraintT: + return self._support + + @validate_sample + def log_prob(self, value: ArrayLike) -> ArrayLike: + minval = 1e-10 + # minval = jnp.finfo(value).tiny + + def logF(x): + # log(F(x)) with stability + return jnp.log(jnp.clip(self.base_dist.cdf(x), minval, 1.0)) + + return jnp.where( + self.censored, + logF(value), # left-censored observations: log F(t) + self.base_dist.log_prob(value), # observed values: log f(t) + ) + + +class RightCensoredDistribution(Distribution): + r""" + Distribution wrapper for right-censored outcomes. + + This distribution augments a base distribution with right-censoring, + so that the likelihood contribution depends on the censoring indicator. + + Parameters + ---------- + base_dist : numpyro.distributions.Distribution + Parametric distribution for the *uncensored* values + (e.g., Exponential, Weibull, LogNormal, Normal, etc.). + This distribution must implement a `cdf` method. + censored : array-like of {0,1} + Censoring indicator per observation: + - 0 → value is observed exactly + - 1 → observation is right-censored at the reported value + (true value occurred *on or after* the reported value) + + Notes + ----- + - The `log_prob(value)` method expects `value` to be the observed lower bound + for each observation. The contribution to the log-likelihood is: + + log f(value) if censored == 0 + log (1 - F(value)) if censored == 1 + + where f is the density and F the cumulative distribution function of `base_dist`. + + - This is commonly used in survival analysis, where event times are positive, + but the approach is more general and can be applied to any distribution + with a cumulative distribution function, regardless of support. + + - In R's **survival** package notation, this corresponds to + `Surv(time, event)` with `type = "right"`. 
+ + Example: + `Surv(time = c(5, 8, 10), event = c(1, 0, 1))` + means: + * subject 1 had an event at t=5 + * subject 2 was censored at t=8 + * subject 3 had an event at t=10 + + Examples + -------- + >>> base = dist.Exponential(rate=0.1) + >>> surv_dist = RightCensoredDistribution(base, censored=jnp.array([0, 1, 0])) + >>> loglik = surv_dist.log_prob(jnp.array([5., 8., 10.])) + # loglik[0] uses density at 5 + # loglik[1] uses survival at 8 + # loglik[2] uses density at 10 + """ + + arg_constraints = {"censored": constraints.boolean} + reparametrized_params = ["censored"] + pytree_data_fields = ("base_dist", "censored", "_support") + + def __init__( + self, + base_dist: DistributionT, + censored: ArrayLike = False, + *, + validate_args: Optional[bool] = None, + ): + # test if base_dist has an implemented cdf method + assert hasattr(base_dist, "cdf") + batch_shape = lax.broadcast_shapes(base_dist.batch_shape, jnp.shape(censored)) + self.base_dist: DistributionT = jax.tree.map( + lambda p: promote_shapes(p, shape=batch_shape)[0], base_dist + ) + self.censored = jnp.array( + promote_shapes(censored, shape=batch_shape)[0], dtype=jnp.bool + ) + self._support = base_dist.support + super().__init__(batch_shape, validate_args=validate_args) + + def sample( + self, key: jax.dtypes.prng_key, sample_shape: tuple[int, ...] = () + ) -> ArrayLike: + return self.base_dist.sample(key, sample_shape) + + @constraints.dependent_property(is_discrete=False, event_dim=0) + def support(self) -> ConstraintT: + return self._support + + @validate_sample + def log_prob(self, value: ArrayLike) -> ArrayLike: + def logS(x): + # log(1 - F(x)) with stability + return jnp.log1p(-self.base_dist.cdf(x)) + + return jnp.where( + self.censored, + logS(value), # censored observations: log S(t) + self.base_dist.log_prob(value), # observed values: log f(t) + ) From e56c42493c4214afb9b372140f49265bc0cb7562 Mon Sep 17 00:00:00 2001 From: vanamsterdam Date: Mon, 6 Oct 2025 21:46:32 +0200 Subject: [PATCH 2/6] add tests --- test/test_distributions.py | 54 ++++++++++++++++++++++++++++++++++++++ test/test_gof.py | 4 +++ 2 files changed, 58 insertions(+) diff --git a/test/test_distributions.py b/test/test_distributions.py index fa5c31f78..4c2464a70 100644 --- a/test/test_distributions.py +++ b/test/test_distributions.py @@ -32,6 +32,10 @@ transforms, ) from numpyro.distributions.batch_util import vmap_over +from numpyro.distributions.censored import ( + LeftCensoredDistribution, + RightCensoredDistribution, +) from numpyro.distributions.discrete import _to_probs_bernoulli, _to_probs_multinom from numpyro.distributions.flows import InverseAutoregressiveTransform from numpyro.distributions.transforms import ( @@ -150,6 +154,27 @@ def _TruncatedCauchy(loc, scale, low, high): return dist.TruncatedCauchy(loc=loc, scale=scale, low=low, high=high) +def _LeftCensoredHalfNormal(scale, censored): + base_dist = dist.HalfNormal(scale) + return LeftCensoredDistribution(base_dist, censored) + + +def _RightCensoredWeibull(scale, concentration, censored): + base_dist = dist.Weibull(scale, concentration) + return RightCensoredDistribution(base_dist, censored) + + +def _LeftCensoredNormal(loc, scale, censored): + base_dist = dist.Normal(loc, scale) + return LeftCensoredDistribution(base_dist, censored) + + +def _RightCensoredNormal(loc, scale, censored): + base_dist = dist.Normal(loc, scale) + return RightCensoredDistribution(base_dist, censored) + + + _TruncatedNormal.arg_constraints = {} _TruncatedNormal.reparametrized_params = [] 
_TruncatedNormal.infer_shapes = lambda *args: (lax.broadcast_shapes(*args), ()) @@ -492,6 +517,14 @@ def get_sp_dist(jax_dist): T(dist.Cauchy, 0.0, 1.0), T(dist.Cauchy, 0.0, np.array([1.0, 2.0])), T(dist.Cauchy, np.array([0.0, 1.0]), np.array([[1.0], [2.0]])), + T(_RightCensoredWeibull, 1.0, 1.0, 0.0), + T(_RightCensoredWeibull, 1.0, 1.0, 1.0), + T(_LeftCensoredHalfNormal, 1.0, 0.0), + T(_LeftCensoredHalfNormal, 1.0, 1.0), + T(_LeftCensoredNormal, 0.0, 1.0, 0.0), + T(_LeftCensoredNormal, 0.0, 1.0, 1.0), + T(_RightCensoredNormal, 0.0, 1.0, 0.0), + T(_RightCensoredNormal, 0.0, 1.0, 1.0), T(dist.CirculantNormal, np.zeros((3, 4)), np.array([0.9, 0.2, 0.1, 0.2]), None), T( dist.CirculantNormal, @@ -1944,6 +1977,15 @@ def test_mean_var(jax_dist, sp_dist, params): dist.TwoSidedTruncatedDistribution, ): pytest.skip("Truncated distributions do not has mean/var implemented") + if jax_dist in ( + _LeftCensoredHalfNormal, + _RightCensoredWeibull, + _LeftCensoredNormal, + _RightCensoredNormal, + dist.LeftCensoredDistribution, + dist.RightCensoredDistribution, + ): + pytest.skip("Censored distributions do not have mean/var implemented") if jax_dist is dist.ProjectedNormal: pytest.skip("Mean is defined in submanifold") if jax_dist in [dist.LowerTruncatedPowerLaw, dist.DoublyTruncatedPowerLaw]: @@ -2106,6 +2148,10 @@ def test_distribution_constraints(jax_dist, sp_dist, params, prepend_shape): if jax_dist in ( _TruncatedNormal, _TruncatedCauchy, + _LeftCensoredHalfNormal, + _RightCensoredWeibull, + _LeftCensoredNormal, + _RightCensoredNormal, _GaussianMixture, _Gaussian2DMixture, _GeneralMixture, @@ -3247,6 +3293,14 @@ def _get_vmappable_dist_init_params(jax_dist): return [2, 3] elif jax_dist.__name__ == ("_TruncatedNormal"): return [2, 3] + elif jax_dist.__name__ == ("_LeftCensoredHalfNormal"): + return [1] + elif jax_dist.__name__ == ("_RightCensoredWeibull"): + return [2] + elif jax_dist.__name__ == ("_LeftCensoredNormal"): + return [2] + elif jax_dist.__name__ == ("_RightCensoredNormal"): + return [2] elif issubclass(jax_dist, dist.Distribution): init_parameters = list(inspect.signature(jax_dist.__init__).parameters.keys())[ 1: diff --git a/test/test_gof.py b/test/test_gof.py index 3ad25d050..d6ab3bef4 100644 --- a/test/test_gof.py +++ b/test/test_gof.py @@ -29,6 +29,10 @@ def test_gof(jax_dist, sp_dist, params): pytest.skip( "skip gof test for MatrixNormal, likely incorrect submanifold scaling" ) + if "Censored" in jax_dist.__name__: + pytest.skip( + "skip gof test for censored distribution as log_prob for censored observations is cdf instead of density" + ) num_samples = 10000 if any( From 8e5192aebb8e25761ab11f97e12f4bf764c938ea Mon Sep 17 00:00:00 2001 From: vanamsterdam Date: Mon, 6 Oct 2025 22:06:05 +0200 Subject: [PATCH 3/6] linting --- test/test_distributions.py | 1 - 1 file changed, 1 deletion(-) diff --git a/test/test_distributions.py b/test/test_distributions.py index 4c2464a70..6f1ae5ead 100644 --- a/test/test_distributions.py +++ b/test/test_distributions.py @@ -174,7 +174,6 @@ def _RightCensoredNormal(loc, scale, censored): return RightCensoredDistribution(base_dist, censored) - _TruncatedNormal.arg_constraints = {} _TruncatedNormal.reparametrized_params = [] _TruncatedNormal.infer_shapes = lambda *args: (lax.broadcast_shapes(*args), ()) From 5d0108001eb2b982b5345204075e687d4105101f Mon Sep 17 00:00:00 2001 From: vanamsterdam Date: Tue, 7 Oct 2025 17:01:59 +0200 Subject: [PATCH 4/6] add censored specific tests and update clipping minval --- numpyro/distributions/censored.py | 7 +- 
test/test_distributions.py | 123 ++++++++++++++++++++++++++++++ 2 files changed, 127 insertions(+), 3 deletions(-) diff --git a/numpyro/distributions/censored.py b/numpyro/distributions/censored.py index abe29836b..c31a10868 100644 --- a/numpyro/distributions/censored.py +++ b/numpyro/distributions/censored.py @@ -105,8 +105,7 @@ def support(self) -> ConstraintT: @validate_sample def log_prob(self, value: ArrayLike) -> ArrayLike: - minval = 1e-10 - # minval = jnp.finfo(value).tiny + minval = 1e-12 def logF(x): # log(F(x)) with stability @@ -206,9 +205,11 @@ def support(self) -> ConstraintT: @validate_sample def log_prob(self, value: ArrayLike) -> ArrayLike: + minval = 1e-7 + def logS(x): # log(1 - F(x)) with stability - return jnp.log1p(-self.base_dist.cdf(x)) + return jnp.log1p(-jnp.clip(self.base_dist.cdf(x), 0.0, 1 - minval)) return jnp.where( self.censored, diff --git a/test/test_distributions.py b/test/test_distributions.py index eec85d020..339b44fcd 100644 --- a/test/test_distributions.py +++ b/test/test_distributions.py @@ -3898,3 +3898,126 @@ def test_truncated_cdf_batch_shapes(batch_shape): for value in [-2.0, 0.0, 2.0]: cdf_value = truncated_dist.cdf(value) assert cdf_value.shape == batch_shape + + +param_cens_dist = pytest.mark.parametrize( + "base_dist_class, base_params", + [ + (dist.Normal, (0.0, 1.0)), + (dist.Normal, (2.0, 0.5)), + (dist.Cauchy, (0.0, 1.0)), + (dist.Laplace, (0.0, 1.0)), + (dist.Logistic, (0.0, 1.0)), + (dist.StudentT, (2.0, 0.0, 1.0)), + (dist.HalfNormal, (1.0,)), + (dist.Poisson, (1.0,)), + (dist.GammaPoisson, (1, 1)), + ], +) +param_censored = pytest.mark.parametrize("censored", [0.0, 1.0]) + + +@param_cens_dist +@param_censored +def test_left_censored_logprob(base_dist_class, base_params, censored): + """Test log_prob for left censored distributions.""" + base_dist = base_dist_class(*base_params) + censored_dist = dist.LeftCensoredDistribution(base_dist, censored) + + # Test points + test_values = base_dist.support.feasible_like(jnp.zeros((1,))) + + # Compute log_prob + logp_values = censored_dist.log_prob(test_values) + + # Basic properties + assert logp_values.shape == test_values.shape + assert jnp.all(jnp.isfinite(logp_values)) + + # for noncensored values, log_prob should match base distribution + base_logp_values = base_dist.log_prob(test_values) + # for censored values, log_prob should be log CDF of base distribution + minval = 1e-10 + cdf_values = jnp.log(jnp.clip(base_dist.cdf(jnp.array(test_values)), minval, 1.0)) + base_diff = jnp.where( + censored, logp_values - cdf_values, logp_values - base_logp_values + ) + assert jnp.abs(base_diff).max() < 1e-6 + + +@param_cens_dist +@param_censored +def test_right_censored_logprob(base_dist_class, base_params, censored): + """Test log_prob for right censored distributions.""" + base_dist = base_dist_class(*base_params) + censored_dist = dist.RightCensoredDistribution(base_dist, censored) + + # Test points + test_values = base_dist.support.feasible_like(jnp.zeros((1,))) + + # Compute log_prob + logp_values = censored_dist.log_prob(test_values) + + # Basic properties + assert logp_values.shape == test_values.shape + assert jnp.all(jnp.isfinite(logp_values)) + + # for noncensored values, log_prob should match base distribution + base_logp_values = base_dist.log_prob(test_values) + # for censored values, log_prob should be log 1 - CDF of base distribution + logS_values = jnp.log1p(-base_dist.cdf(test_values)) + base_diff = jnp.where( + censored, logp_values - logS_values, logp_values - base_logp_values + ) 
+ assert jnp.abs(base_diff).max() < 1e-6 + + +def test_censored_logprob_edge_cases(): + """Test edge cases for censored distributions.""" + base_dist = dist.Normal(0.0, 1.0) + + # Test with extreme censored points + left_censored = dist.LeftCensoredDistribution(base_dist, 1) + right_censored = dist.RightCensoredDistribution(base_dist, 1) + + # Test that logprobs are well-behaved on extreme values + test_values = jnp.array([-10.0, 0.0, 10.0]) + + left_logprob = left_censored.log_prob(test_values) + assert jnp.all(jnp.isfinite(left_logprob)) + + right_logprob = right_censored.log_prob(test_values) + assert jnp.all(jnp.isfinite(right_logprob)) + + +@pytest.mark.parametrize("batch_shape", [(), (3,)]) +def test_censored_logprob_batch_shapes(batch_shape): + """Test that log_prob works correctly with batch shapes.""" + if batch_shape == (): + loc = 0.0 + scale = 1.0 + censored = 1.0 + else: + loc = jnp.zeros(batch_shape) + scale = jnp.ones(batch_shape) + censored = jnp.ones(batch_shape) + + base_dist = dist.Normal(loc, scale) + censored_dist = dist.RightCensoredDistribution(base_dist, censored) + + # Test with single value + value = 0.0 + logp_value = censored_dist.log_prob(value) + assert logp_value.shape == batch_shape + + # Test with multiple values - these should broadcast properly + if batch_shape == (): + values = jnp.array([-2.0, 0.0, 2.0]) + logp_values = censored_dist.log_prob(values) + expected_shape = values.shape + assert logp_values.shape == expected_shape + else: + # For batched case, test with single values to avoid broadcasting issues + for value in [-2.0, 0.0, 2.0]: + logp_value = censored_dist.log_prob(value) + assert logp_value.shape == batch_shape From 1c0dc3d8c906f9c21912cfe229e9984295ae8044 Mon Sep 17 00:00:00 2001 From: vanamsterdam Date: Wed, 8 Oct 2025 16:18:24 +0200 Subject: [PATCH 5/6] update tests --- test/test_distributions.py | 24 ++++++++++++++++-------- 1 file changed, 16 insertions(+), 8 deletions(-) diff --git a/test/test_distributions.py b/test/test_distributions.py index 339b44fcd..5ddbe77e1 100644 --- a/test/test_distributions.py +++ b/test/test_distributions.py @@ -174,6 +174,16 @@ def _RightCensoredNormal(loc, scale, censored): return RightCensoredDistribution(base_dist, censored) +def _LeftCensoredPoisson(rate, censored): + base_dist = dist.Poisson(rate) + return LeftCensoredDistribution(base_dist, censored) + + +def _RightCensoredPoisson(rate, censored): + base_dist = dist.Poisson(rate) + return RightCensoredDistribution(base_dist, censored) + + _TruncatedNormal.arg_constraints = {} _TruncatedNormal.reparametrized_params = [] _TruncatedNormal.infer_shapes = lambda *args: (lax.broadcast_shapes(*args), ()) @@ -516,14 +526,10 @@ def get_sp_dist(jax_dist): T(dist.Cauchy, 0.0, 1.0), T(dist.Cauchy, 0.0, np.array([1.0, 2.0])), T(dist.Cauchy, np.array([0.0, 1.0]), np.array([[1.0], [2.0]])), - T(_RightCensoredWeibull, 1.0, 1.0, 0.0), - T(_RightCensoredWeibull, 1.0, 1.0, 1.0), - T(_LeftCensoredHalfNormal, 1.0, 0.0), - T(_LeftCensoredHalfNormal, 1.0, 1.0), - T(_LeftCensoredNormal, 0.0, 1.0, 0.0), - T(_LeftCensoredNormal, 0.0, 1.0, 1.0), - T(_RightCensoredNormal, 0.0, 1.0, 0.0), - T(_RightCensoredNormal, 0.0, 1.0, 1.0), + T(_RightCensoredWeibull, 1.0, 1.0, np.array([0, 1])), + T(_LeftCensoredHalfNormal, 1.0, np.array([0, 1])), + T(_LeftCensoredNormal, 0.0, 1.0, np.array([0, 1])), + T(_RightCensoredNormal, 0.0, 1.0, np.array([0, 1])), T(dist.CirculantNormal, np.zeros((3, 4)), np.array([0.9, 0.2, 0.1, 0.2]), None), T( dist.CirculantNormal, @@ -1082,6 +1088,8 @@ def 
get_sp_dist(jax_dist): T(dist.GeometricProbs, 0.2), T(dist.GeometricProbs, np.array([0.2, 0.7])), T(dist.GeometricLogits, np.array([-1.0, 3.0])), + T(_LeftCensoredPoisson, 1.0, np.array([0, 1])), + T(_RightCensoredPoisson, 1.0, np.array([0, 1])), T(dist.MultinomialProbs, np.array([0.2, 0.7, 0.1]), 10), T(dist.MultinomialProbs, np.array([0.2, 0.7, 0.1]), np.array([5, 8])), T(dist.MultinomialLogits, np.array([-1.0, 3.0]), np.array([[5], [8]])), From a8aa8d839a309c3d3c077a012aa163d0ccdcff55 Mon Sep 17 00:00:00 2001 From: vanamsterdam Date: Fri, 10 Oct 2025 12:52:08 +0200 Subject: [PATCH 6/6] update tests --- numpyro/distributions/censored.py | 2 -- test/test_distributions.py | 26 ++++++++++++++++++++------ 2 files changed, 20 insertions(+), 8 deletions(-) diff --git a/numpyro/distributions/censored.py b/numpyro/distributions/censored.py index c31a10868..f726a88e2 100644 --- a/numpyro/distributions/censored.py +++ b/numpyro/distributions/censored.py @@ -72,7 +72,6 @@ class LeftCensoredDistribution(Distribution): """ arg_constraints = {"censored": constraints.boolean} - reparametrized_params = ["censored"] pytree_data_fields = ("base_dist", "censored", "_support") def __init__( @@ -172,7 +171,6 @@ class RightCensoredDistribution(Distribution): """ arg_constraints = {"censored": constraints.boolean} - reparametrized_params = ["censored"] pytree_data_fields = ("base_dist", "censored", "_support") def __init__( diff --git a/test/test_distributions.py b/test/test_distributions.py index 5ddbe77e1..2984bee9b 100644 --- a/test/test_distributions.py +++ b/test/test_distributions.py @@ -526,10 +526,14 @@ def get_sp_dist(jax_dist): T(dist.Cauchy, 0.0, 1.0), T(dist.Cauchy, 0.0, np.array([1.0, 2.0])), T(dist.Cauchy, np.array([0.0, 1.0]), np.array([[1.0], [2.0]])), - T(_RightCensoredWeibull, 1.0, 1.0, np.array([0, 1])), - T(_LeftCensoredHalfNormal, 1.0, np.array([0, 1])), - T(_LeftCensoredNormal, 0.0, 1.0, np.array([0, 1])), - T(_RightCensoredNormal, 0.0, 1.0, np.array([0, 1])), + T(_RightCensoredWeibull, 1.0, 1.0, 0), + T(_RightCensoredWeibull, 1.0, 1.0, 1), + T(_LeftCensoredHalfNormal, 1.0, 0), + T(_LeftCensoredHalfNormal, 1.0, 1), + T(_LeftCensoredNormal, 0.0, 1.0, 0), + T(_LeftCensoredNormal, 0.0, 1.0, 1), + T(_RightCensoredNormal, 0.0, 1.0, 0), + T(_RightCensoredNormal, 0.0, 1.0, 1), T(dist.CirculantNormal, np.zeros((3, 4)), np.array([0.9, 0.2, 0.1, 0.2]), None), T( dist.CirculantNormal, @@ -1088,8 +1092,10 @@ def get_sp_dist(jax_dist): T(dist.GeometricProbs, 0.2), T(dist.GeometricProbs, np.array([0.2, 0.7])), T(dist.GeometricLogits, np.array([-1.0, 3.0])), - T(_LeftCensoredPoisson, 1.0, np.array([0, 1])), - T(_RightCensoredPoisson, 1.0, np.array([0, 1])), + T(_LeftCensoredPoisson, 1.0, 0), + T(_LeftCensoredPoisson, 1.0, 1), + T(_RightCensoredPoisson, 1.0, 0), + T(_RightCensoredPoisson, 1.0, 1), T(dist.MultinomialProbs, np.array([0.2, 0.7, 0.1]), 10), T(dist.MultinomialProbs, np.array([0.2, 0.7, 0.1]), np.array([5, 8])), T(dist.MultinomialLogits, np.array([-1.0, 3.0]), np.array([[5], [8]])), @@ -2000,6 +2006,8 @@ def test_mean_var(jax_dist, sp_dist, params): _RightCensoredWeibull, _LeftCensoredNormal, _RightCensoredNormal, + _LeftCensoredPoisson, + _RightCensoredPoisson, dist.LeftCensoredDistribution, dist.RightCensoredDistribution, ): @@ -2170,6 +2178,8 @@ def test_distribution_constraints(jax_dist, sp_dist, params, prepend_shape): _RightCensoredWeibull, _LeftCensoredNormal, _RightCensoredNormal, + _LeftCensoredPoisson, + _RightCensoredPoisson, _GaussianMixture, _Gaussian2DMixture, _GeneralMixture, @@ 
-3319,6 +3329,10 @@ def _get_vmappable_dist_init_params(jax_dist): return [2] elif jax_dist.__name__ == ("_RightCensoredNormal"): return [2] + elif jax_dist.__name__ == ("_LeftCensoredPoisson"): + return [1] + elif jax_dist.__name__ == ("_RightCensoredPoisson"): + return [1] elif issubclass(jax_dist, dist.Distribution): init_parameters = list(inspect.signature(jax_dist.__init__).parameters.keys())[ 1:
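
For readers who want to see the new wrappers in context, the sketch below shows one way the right-censored wrapper could be used inside a NumPyro survival model. It is not part of the patch series: the model, priors, toy data, and sampler settings are illustrative assumptions; only the `LeftCensoredDistribution` / `RightCensoredDistribution` API (a base distribution plus a 0/1 `censored` indicator, exported from `numpyro.distributions`) comes from the diffs above.

# Illustrative sketch (not part of the patch): a Weibull survival model with
# right-censored follow-up times. Priors, data, and MCMC settings are
# assumptions chosen for the example.
import jax.numpy as jnp
from jax import random

import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS


def survival_model(time, censored):
    # Weakly informative priors on the Weibull parameters (illustrative choice).
    concentration = numpyro.sample("concentration", dist.LogNormal(0.0, 1.0))
    scale = numpyro.sample("scale", dist.LogNormal(0.0, 1.0))
    base = dist.Weibull(scale, concentration)
    # Exact event times contribute log f(t); censored rows contribute
    # log S(t) = log(1 - F(t)), as implemented by the wrapper's log_prob.
    numpyro.sample(
        "obs",
        dist.RightCensoredDistribution(base, censored=censored),
        obs=time,
    )


# Toy data: three exact event times and two subjects still event-free at t=5.
time = jnp.array([2.1, 3.4, 0.9, 5.0, 5.0])
censored = jnp.array([0, 0, 0, 1, 1])

mcmc = MCMC(NUTS(survival_model), num_warmup=500, num_samples=500)
mcmc.run(random.PRNGKey(0), time, censored)
mcmc.print_summary()

A left-censored model (e.g. values reported at a detection limit) would look the same with `LeftCensoredDistribution`, whose censored rows contribute log F(t) instead.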