From 9baef2cd3e0a535ba2805a1e88d1deccd5211977 Mon Sep 17 00:00:00 2001 From: Till Hoffmann Date: Wed, 10 Apr 2024 21:18:58 -0400 Subject: [PATCH 01/11] Add Wishart distribution. --- docs/source/distributions.rst | 8 ++ numpyro/distributions/__init__.py | 2 + numpyro/distributions/continuous.py | 130 ++++++++++++++++++++++++++++ test/test_distributions.py | 54 +++++++++++- 4 files changed, 192 insertions(+), 2 deletions(-) diff --git a/docs/source/distributions.rst b/docs/source/distributions.rst index 06b51c929..b05ec2d6f 100644 --- a/docs/source/distributions.rst +++ b/docs/source/distributions.rst @@ -380,6 +380,14 @@ Weibull :show-inheritance: :member-order: bysource +Wishart +^^^^^^^ +.. autoclass:: numpyro.distributions.continuous.Wishart + :members: + :undoc-members: + :show-inheritance: + :member-order: bysource + ZeroSumNormal ^^^^^^^^^^^^^ .. autoclass:: numpyro.distributions.continuous.ZeroSumNormal diff --git a/numpyro/distributions/__init__.py b/numpyro/distributions/__init__.py index d05376573..5273b025f 100644 --- a/numpyro/distributions/__init__.py +++ b/numpyro/distributions/__init__.py @@ -47,6 +47,7 @@ StudentT, Uniform, Weibull, + Wishart, ZeroSumNormal, ) from numpyro.distributions.copula import GaussianCopula, GaussianCopulaBeta @@ -194,6 +195,7 @@ "Unit", "VonMises", "Weibull", + "Wishart", "ZeroInflatedDistribution", "ZeroInflatedPoisson", "ZeroInflatedNegativeBinomial2", diff --git a/numpyro/distributions/continuous.py b/numpyro/distributions/continuous.py index 273fa4a43..9c4d1dc12 100644 --- a/numpyro/distributions/continuous.py +++ b/numpyro/distributions/continuous.py @@ -2615,3 +2615,133 @@ def variance(self): theoretical_var *= 1 - 1 / self.event_shape[axis] return jnp.broadcast_to(theoretical_var, self.batch_shape + self.event_shape) + + +class Wishart(Distribution): + """ + Wishart distribution for covariance matrices. + + :param concentration: Positive concentration parameter analogous to the + concentration of a :class:`Gamma` distribution. The concentration must be larger + than the dimensionality of the scale matrix. + :param scale_matrix: Scale matrix analogous to the inverse rate of a :class:`Gamma` + distribution. + :param rate_matrix: Rate matrix anaologous to the rate of a :class:`Gamma` + distribution. + :param scale_tril: Cholesky decomposition of the :code:`scale_matrix`. + """ + + arg_constraints = { + "concentration": constraints.dependent(is_discrete=False), + "scale_matrix": constraints.positive_definite, + "rate_matrix": constraints.positive_definite, + "scale_tril": constraints.lower_cholesky, + } + support = constraints.positive_definite + reparametrized_params = [ + "scale_matrix", + "rate_matrix", + "scale_tril", + ] + + def __init__( + self, + concentration=None, + scale_matrix=None, + rate_matrix=None, + scale_tril=None, + validate_args=None, + ): + # Determine the shapes. + batch_shape = None + event_shape = None + for x in [scale_matrix, rate_matrix, scale_tril]: + if x is not None: + batch_shape = jnp.broadcast_shapes( + jnp.shape(concentration), jnp.shape(x)[:-2] + ) + event_shape = jnp.shape(x)[-2:] + break + if event_shape is None: + raise ValueError( + "One of `scale_matrix`, `rate_matrix`, or `scale_tril` must be " + "specified." + ) + + # Coerce to scale_tril parameter. + if scale_matrix is not None: + self.scale_matrix = jnp.broadcast_to( + scale_matrix, batch_shape + event_shape + ) + self.scale_tril = jnp.linalg.cholesky(scale_matrix) + elif rate_matrix is not None: + self.rate_matrix = jnp.broadcast_to(rate_matrix, batch_shape + event_shape) + self.scale_tril = cholesky_of_inverse(rate_matrix) + elif scale_tril is not None: + self.scale_tril = scale_tril + + self.concentration = jnp.broadcast_to(concentration, batch_shape) + self.scale_tril = jnp.broadcast_to(self.scale_tril, batch_shape + event_shape) + super().__init__( + batch_shape=batch_shape, + event_shape=event_shape, + validate_args=validate_args, + ) + + @validate_sample + def log_prob(self, value): + p = value.shape[-1] + rate_matrix, value = jnp.broadcast_arrays(self.rate_matrix, value) + trace = jnp.trace(lax.batch_matmul(rate_matrix, value), axis1=-1, axis2=-2) + return ( + (self.concentration - p - 1) / 2 * jnp.linalg.slogdet(value).logabsdet + - trace / 2 + - self.concentration * p / 2 * jnp.log(2) + - multigammaln(self.concentration / 2, p) + + self.concentration * jnp.linalg.slogdet(rate_matrix).logabsdet / 2 + ) + + @lazy_property + def mean(self): + return self.concentration[..., None, None] * jnp.linalg.inv(self.rate_matrix) + + @lazy_property + def variance(self): + diag = jnp.diagonal(self.scale_matrix, axis1=-1, axis2=-2) + return self.concentration[..., None, None] * ( + self.scale_matrix**2 + diag[..., :, None] * diag[..., None, :] + ) + + @lazy_property + def scale_matrix(self): + return jnp.matmul(self.scale_tril, self.scale_tril.mT) + + @lazy_property + def rate_matrix(self): + identity = jnp.broadcast_to( + jnp.eye(self.scale_tril.shape[-1]), self.scale_tril.shape + ) + return cho_solve((self.scale_tril, True), identity) + + def sample(self, key, sample_shape=()): + assert is_prng_key(key) + # Sample using the Bartlett decomposition + # (https://en.wikipedia.org/wiki/Wishart_distribution#Bartlett_decomposition). + rng_diag, rng_offdiag = random.split(key) + latent = jnp.zeros(sample_shape + self.batch_shape + self.event_shape) + p = self.event_shape[-1] + i = jnp.arange(p) + latent = latent.at[..., i, i].set( + jnp.sqrt( + random.chisquare( + rng_diag, self.concentration[..., None] - i, latent.shape[:-1] + ) + ) + ) + i, j = jnp.tril_indices(p, -1) + assert i.size == p * (p - 1) // 2 + latent = latent.at[..., i, j].set( + random.normal(rng_offdiag, latent.shape[:-2] + (i.size,)) + ) + factor = lax.batch_matmul(*jnp.broadcast_arrays(self.scale_tril, latent)) + return lax.batch_matmul(factor, factor.mT) diff --git a/test/test_distributions.py b/test/test_distributions.py index 43360b74b..ee961f87a 100644 --- a/test/test_distributions.py +++ b/test/test_distributions.py @@ -131,6 +131,13 @@ def _truncnorm_to_scipy(loc, scale, low, high): return osp.truncnorm(a, b, loc=loc, scale=scale) +def _wishart_to_scipy(conc, scale, rate, tril): + jax_dist = dist.Wishart(conc, scale, rate, tril) + if not np.isscalar(jax_dist.concentration): + pytest.skip("scipy Wishart only supports a single scalar concentration") + return osp.wishart(jax_dist.concentration, jax_dist.scale_matrix) + + def _TruncatedNormal(loc, scale, low, high): return dist.TruncatedNormal(loc=loc, scale=scale, low=low, high=high) @@ -444,6 +451,7 @@ def __init__( c=conc, scale=scale, ), + dist.Wishart: _wishart_to_scipy, _TruncatedNormal: _truncnorm_to_scipy, } @@ -775,6 +783,42 @@ def get_sp_dist(jax_dist): T(dist.Weibull, 0.2, 1.1), T(dist.Weibull, 2.8, np.array([2.0, 2.0])), T(dist.Weibull, 1.8, np.array([[1.0, 1.0], [2.0, 2.0]])), + T(dist.Wishart, 3, 2 * np.eye(2) + 0.1, None, None), + T( + dist.Wishart, + 3.0, + None, + np.array([[1.0, 0.5], [0.5, 1.0]]), + None, + ), + T( + dist.Wishart, + np.array([4.0, 5.0]), + None, + np.array([[[1.0, 0.5], [0.5, 1.0]]]), + None, + ), + T( + dist.Wishart, + np.array([3.0]), + None, + None, + np.array([[1.0, 0.0], [0.5, 1.0]]), + ), + T( + dist.Wishart, + np.arange(3, 9, dtype=np.float32).reshape((3, 2)), + None, + None, + np.array([[1.0, 0.0], [0.0, 1.0]]), + ), + T( + dist.Wishart, + 9.0, + None, + np.broadcast_to(np.identity(3), (2, 3, 3)), + None, + ), T(dist.ZeroSumNormal, 1.0, (5,)), T(dist.ZeroSumNormal, np.array([2.0]), (5,)), T(dist.ZeroSumNormal, 1.0, (4, 5)), @@ -1120,7 +1164,13 @@ def test_dist_shape(jax_dist_cls, sp_dist, params, prepend_shape): and not isinstance(jax_dist, dist.MultivariateStudentT) ): sp_dist = sp_dist(*params) - sp_samples = sp_dist.rvs(size=prepend_shape + jax_dist.batch_shape) + size = prepend_shape + jax_dist.batch_shape + # The scipy implementation of the Wishart distribution cannot handle an empty + # tuple as the sample size so we replace it by `1` which generates a single + # sample without any sample shape. + if isinstance(jax_dist, dist.Wishart): + size = size or 1 + sp_samples = sp_dist.rvs(size=size) assert jnp.shape(sp_samples) == expected_shape elif ( sp_dist @@ -1481,7 +1531,7 @@ def test_cdf_and_icdf(jax_dist, sp_dist, params): def test_gof(jax_dist, sp_dist, params): if "Improper" in jax_dist.__name__: pytest.skip("distribution has improper .log_prob()") - if "LKJ" in jax_dist.__name__: + if "LKJ" in jax_dist.__name__ or "Wishart" in jax_dist.__name__: pytest.xfail("incorrect submanifold scaling") if jax_dist is dist.EulerMaruyama: d = jax_dist(*params) From 6cac84a1ebb9f1bca98645593a8e1db685eb7834 Mon Sep 17 00:00:00 2001 From: Till Hoffmann Date: Fri, 12 Apr 2024 17:38:28 -0400 Subject: [PATCH 02/11] Reduce dimensionality for bijection tests of positive definite matrices. --- test/test_transforms.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/test/test_transforms.py b/test/test_transforms.py index f35345193..e3ada6fd9 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -360,13 +360,13 @@ def test_batched_recursive_linear_transform(): (constraints.circular, (3,)), (constraints.complex, (3,)), (constraints.corr_cholesky, (10, 10)), - (constraints.corr_matrix, (21,)), + (constraints.corr_matrix, (15,)), (constraints.greater_than(3), ()), (constraints.greater_than_eq(3), ()), (constraints.interval(8, 13), (17,)), (constraints.l1_ball, (4,)), (constraints.less_than(-1), ()), - (constraints.lower_cholesky, (21,)), + (constraints.lower_cholesky, (15,)), (constraints.open_interval(3, 4), ()), (constraints.ordered_vector, (5,)), (constraints.positive_definite, (6,)), @@ -376,9 +376,9 @@ def test_batched_recursive_linear_transform(): (constraints.real_matrix, (17,)), (constraints.real_vector, (18,)), (constraints.real, (3,)), - (constraints.scaled_unit_lower_cholesky, (21,)), + (constraints.scaled_unit_lower_cholesky, (15,)), (constraints.simplex, (3,)), - (constraints.softplus_lower_cholesky, (21,)), + (constraints.softplus_lower_cholesky, (15,)), (constraints.softplus_positive, (2,)), (constraints.unit_interval, (4,)), (constraints.nonnegative, (7,)), From 497f056c6912ac3b304ec5808eb6d361c0f2d17b Mon Sep 17 00:00:00 2001 From: Till Hoffmann Date: Sat, 13 Apr 2024 00:57:28 -0400 Subject: [PATCH 03/11] Add `WishartCholesky` distribution and use it as base for `Wishart`. --- docs/source/distributions.rst | 8 ++ numpyro/distributions/__init__.py | 2 + numpyro/distributions/continuous.py | 140 ++++++++++++++++++++++++---- test/test_distributions.py | 41 +++++++- 4 files changed, 170 insertions(+), 21 deletions(-) diff --git a/docs/source/distributions.rst b/docs/source/distributions.rst index b05ec2d6f..f568d458d 100644 --- a/docs/source/distributions.rst +++ b/docs/source/distributions.rst @@ -388,6 +388,14 @@ Wishart :show-inheritance: :member-order: bysource +WishartCholesky +^^^^^^^^^^^^^^^ +.. autoclass:: numpyro.distributions.continuous.WishartCholesky + :members: + :undoc-members: + :show-inheritance: + :member-order: bysource + ZeroSumNormal ^^^^^^^^^^^^^ .. autoclass:: numpyro.distributions.continuous.ZeroSumNormal diff --git a/numpyro/distributions/__init__.py b/numpyro/distributions/__init__.py index 5273b025f..412dd1976 100644 --- a/numpyro/distributions/__init__.py +++ b/numpyro/distributions/__init__.py @@ -48,6 +48,7 @@ Uniform, Weibull, Wishart, + WishartCholesky, ZeroSumNormal, ) from numpyro.distributions.copula import GaussianCopula, GaussianCopulaBeta @@ -196,6 +197,7 @@ "VonMises", "Weibull", "Wishart", + "WishartCholesky", "ZeroInflatedDistribution", "ZeroInflatedPoisson", "ZeroInflatedNegativeBinomial2", diff --git a/numpyro/distributions/continuous.py b/numpyro/distributions/continuous.py index 9c4d1dc12..4a2e220a7 100644 --- a/numpyro/distributions/continuous.py +++ b/numpyro/distributions/continuous.py @@ -55,6 +55,7 @@ from numpyro.distributions.distribution import Distribution, TransformedDistribution from numpyro.distributions.transforms import ( AffineTransform, + CholeskyTransform, CorrMatrixCholeskyTransform, ExpTransform, PowerTransform, @@ -2617,7 +2618,7 @@ def variance(self): return jnp.broadcast_to(theoretical_var, self.batch_shape + self.event_shape) -class Wishart(Distribution): +class Wishart(TransformedDistribution): """ Wishart distribution for covariance matrices. @@ -2650,6 +2651,82 @@ def __init__( scale_matrix=None, rate_matrix=None, scale_tril=None, + *, + validate_args=None, + ): + base_dist = WishartCholesky( + concentration, + scale_matrix, + rate_matrix, + scale_tril, + validate_args=validate_args, + ) + super().__init__( + base_dist, CholeskyTransform().inv, validate_args=validate_args + ) + + @lazy_property + def concentration(self): + return self.base_dist.concentration + + @lazy_property + def scale_matrix(self): + return self.base_dist.scale_matrix + + @lazy_property + def rate_matrix(self): + return self.base_dist.rate_matrix + + @lazy_property + def scale_tril(self): + return self.base_dist.scale_tril + + @lazy_property + def mean(self): + return self.concentration[..., None, None] * self.scale_matrix + + @lazy_property + def variance(self): + diag = jnp.diagonal(self.scale_matrix, axis1=-1, axis2=-2) + return self.concentration[..., None, None] * ( + self.scale_matrix**2 + diag[..., :, None] * diag[..., None, :] + ) + + +class WishartCholesky(Distribution): + """ + Cholesky factor of a Wishart distribution for covariance matrices. + + :param concentration: Positive concentration parameter analogous to the + concentration of a :class:`Gamma` distribution. The concentration must be larger + than the dimensionality of the scale matrix. + :param scale_matrix: Scale matrix analogous to the inverse rate of a :class:`Gamma` + distribution. + :param rate_matrix: Rate matrix anaologous to the rate of a :class:`Gamma` + distribution. + :param scale_tril: Cholesky decomposition of the :code:`scale_matrix`. + """ + + arg_constraints = { + "concentration": constraints.dependent(is_discrete=False), + "scale_matrix": constraints.positive_definite, + "rate_matrix": constraints.positive_definite, + "scale_tril": constraints.lower_cholesky, + } + support = constraints.lower_cholesky + reparametrized_params = [ + "scale_matrix", + "rate_matrix", + "scale_tril", + ] + + def __init__( + self, + concentration=None, + scale_matrix=None, + rate_matrix=None, + scale_tril=None, + *, validate_args=None, ): # Determine the shapes. @@ -2690,26 +2767,29 @@ def __init__( @validate_sample def log_prob(self, value): + # The log density of the Wishart distribution includes a term + # t = trace(rate_matrix @ cov). Here, value = cholesky(cov) such that + # t = trace(value.T @ rate_matrix @ value) by the cyclical property of the + # trace. The rate matrix is the inverse scale matrix with Cholesky decomposition + # scale_tril. Thus, + # t = trace(value.T @ inv(scale_tril).T @ inv(scale_tril) @ value), and we can + # rewrite as t = trace(x.T @ x) for x = inv(scale_tril) @ value which we can + # obtain easily by solving a triangular system. x is again triangular such that + # trace(x @ x.T) is equal to the sum of squares of elements. + x = solve_triangular(*jnp.broadcast_arrays(self.scale_tril, value), lower=True) + trace = jnp.square(x).sum(axis=(-1, -2)) p = value.shape[-1] - rate_matrix, value = jnp.broadcast_arrays(self.rate_matrix, value) - trace = jnp.trace(lax.batch_matmul(rate_matrix, value), axis1=-1, axis2=-2) return ( - (self.concentration - p - 1) / 2 * jnp.linalg.slogdet(value).logabsdet + (self.concentration - p - 1) * jnp.linalg.slogdet(value).logabsdet - trace / 2 - - self.concentration * p / 2 * jnp.log(2) + + p * (1 - self.concentration / 2) * jnp.log(2) - multigammaln(self.concentration / 2, p) - + self.concentration * jnp.linalg.slogdet(rate_matrix).logabsdet / 2 - ) - - @lazy_property - def mean(self): - return self.concentration[..., None, None] * jnp.linalg.inv(self.rate_matrix) - - @lazy_property - def variance(self): - diag = jnp.diagonal(self.scale_matrix, axis1=-1, axis2=-2) - return self.concentration[..., None, None] * ( - self.scale_matrix**2 + diag[..., :, None] * diag[..., None, :] + - self.concentration * jnp.linalg.slogdet(self.scale_tril).logabsdet + # Part of the Jacobian of the Cholesky transformation. + + jnp.sum( + jnp.arange(p, 0, -1) * jnp.log(jnp.diagonal(value, axis1=-2, axis2=-1)), + axis=-1, + ) ) @lazy_property @@ -2743,5 +2823,27 @@ def sample(self, key, sample_shape=()): latent = latent.at[..., i, j].set( random.normal(rng_offdiag, latent.shape[:-2] + (i.size,)) ) - factor = lax.batch_matmul(*jnp.broadcast_arrays(self.scale_tril, latent)) - return lax.batch_matmul(factor, factor.mT) + return jnp.matmul(*jnp.broadcast_arrays(self.scale_tril, latent)) + + @lazy_property + def mean(self): + # The mean follows from the Bartlett decomposition sampling. All off-diagonal + # elements of the latent variable have zero expectation. The diagonal are the + # expected square roots of chi^2 variables which can be expressed in terms of + # gamma functions (see + # https://en.wikipedia.org/wiki/Chi-squared_distribution#Noncentral_moments). + k = self.concentration[..., None] - jnp.arange(self.scale_tril.shape[-1]) + sqrtchi2 = jnp.sqrt(2) * jnp.exp(gammaln((k + 1) / 2) - gammaln(k / 2)) + return self.scale_tril * sqrtchi2[..., None, :] + + @lazy_property + def variance(self): + # We have the same as for the mean except now the lower off-diagonals are one + # due to the standard normal noise, and the diagonals are equal to the dof of + # the chi^2 variables. + i = jnp.arange(self.scale_tril.shape[-1]) + k = self.concentration[..., None] - i + latent = jnp.tril( + jnp.ones_like(k, shape=k.shape + (k.shape[-1],)).at[..., i, i].set(k) + ) + return jnp.square(self.scale_tril) @ latent - jnp.square(self.mean) diff --git a/test/test_distributions.py b/test/test_distributions.py index ee961f87a..a383019ff 100644 --- a/test/test_distributions.py +++ b/test/test_distributions.py @@ -133,9 +133,10 @@ def _truncnorm_to_scipy(loc, scale, low, high): def _wishart_to_scipy(conc, scale, rate, tril): jax_dist = dist.Wishart(conc, scale, rate, tril) - if not np.isscalar(jax_dist.concentration): + if not jnp.isscalar(jax_dist.concentration): pytest.skip("scipy Wishart only supports a single scalar concentration") - return osp.wishart(jax_dist.concentration, jax_dist.scale_matrix) + # Cast to float explicitly because np.isscalar returns False on scalar jax arrays. + return osp.wishart(float(jax_dist.concentration), jax_dist.scale_matrix) def _TruncatedNormal(loc, scale, low, high): @@ -819,6 +820,42 @@ def get_sp_dist(jax_dist): np.broadcast_to(np.identity(3), (2, 3, 3)), None, ), + T(dist.WishartCholesky, 3, 2 * np.eye(2) + 0.1, None, None), + T( + dist.WishartCholesky, + 3.0, + None, + np.array([[1.0, 0.5], [0.5, 1.0]]), + None, + ), + T( + dist.WishartCholesky, + np.array([4.0, 5.0]), + None, + np.array([[[1.0, 0.5], [0.5, 1.0]]]), + None, + ), + T( + dist.WishartCholesky, + np.array([3.0]), + None, + None, + np.array([[1.0, 0.0], [0.5, 1.0]]), + ), + T( + dist.WishartCholesky, + np.arange(3, 9, dtype=np.float32).reshape((3, 2)), + None, + None, + np.array([[1.0, 0.0], [0.0, 1.0]]), + ), + T( + dist.WishartCholesky, + 9.0, + None, + np.broadcast_to(np.identity(3), (2, 3, 3)), + None, + ), T(dist.ZeroSumNormal, 1.0, (5,)), T(dist.ZeroSumNormal, np.array([2.0]), (5,)), T(dist.ZeroSumNormal, 1.0, (4, 5)), From 9a1841a1a24bce3bdd8f474e8cb6ab2333671471 Mon Sep 17 00:00:00 2001 From: Till Hoffmann Date: Sat, 13 Apr 2024 10:51:03 -0400 Subject: [PATCH 04/11] Promote instead of broadcast Wishart parameters. --- numpyro/distributions/continuous.py | 40 ++++++++++------------------- 1 file changed, 14 insertions(+), 26 deletions(-) diff --git a/numpyro/distributions/continuous.py b/numpyro/distributions/continuous.py index 4a2e220a7..7fe4b3ca5 100644 --- a/numpyro/distributions/continuous.py +++ b/numpyro/distributions/continuous.py @@ -2729,36 +2729,24 @@ def __init__( *, validate_args=None, ): - # Determine the shapes. - batch_shape = None - event_shape = None - for x in [scale_matrix, rate_matrix, scale_tril]: - if x is not None: - batch_shape = jnp.broadcast_shapes( - jnp.shape(concentration), jnp.shape(x)[:-2] - ) - event_shape = jnp.shape(x)[-2:] - break - if event_shape is None: - raise ValueError( - "One of `scale_matrix`, `rate_matrix`, or `scale_tril` must be " - "specified." - ) - - # Coerce to scale_tril parameter. + concentration = jnp.asarray(concentration)[..., None, None] if scale_matrix is not None: - self.scale_matrix = jnp.broadcast_to( - scale_matrix, batch_shape + event_shape + concentration, self.scale_matrix = promote_shapes( + concentration, scale_matrix ) - self.scale_tril = jnp.linalg.cholesky(scale_matrix) + self.scale_tril = jnp.linalg.cholesky(self.scale_matrix) elif rate_matrix is not None: - self.rate_matrix = jnp.broadcast_to(rate_matrix, batch_shape + event_shape) - self.scale_tril = cholesky_of_inverse(rate_matrix) + concentration, self.rate_matrix = promote_shapes(concentration, rate_matrix) + self.scale_tril = cholesky_of_inverse(self.rate_matrix) elif scale_tril is not None: - self.scale_tril = scale_tril - - self.concentration = jnp.broadcast_to(concentration, batch_shape) - self.scale_tril = jnp.broadcast_to(self.scale_tril, batch_shape + event_shape) + concentration, self.scale_tril = promote_shapes( + concentration, jnp.asarray(scale_tril) + ) + batch_shape = lax.broadcast_shapes( + jnp.shape(concentration)[:-2], jnp.shape(self.scale_tril)[:-2] + ) + event_shape = jnp.shape(self.scale_tril)[-2:] + self.concentration = concentration[..., 0, 0] super().__init__( batch_shape=batch_shape, event_shape=event_shape, From 231f3a2a3a3fc30836083d22b8ed707272d42199 Mon Sep 17 00:00:00 2001 From: Till Hoffmann Date: Sat, 13 Apr 2024 11:03:50 -0400 Subject: [PATCH 05/11] Assert exactly one of parameters is specified and update shape inference. --- numpyro/distributions/continuous.py | 23 +++++++++++++++++------ numpyro/distributions/directional.py | 5 ++++- numpyro/distributions/discrete.py | 21 +++++++-------------- numpyro/distributions/distribution.py | 5 +++-- numpyro/distributions/util.py | 11 +++++++++++ test/test_distributions.py | 11 +++++++++-- 6 files changed, 51 insertions(+), 25 deletions(-) diff --git a/numpyro/distributions/continuous.py b/numpyro/distributions/continuous.py index 7fe4b3ca5..6d571f407 100644 --- a/numpyro/distributions/continuous.py +++ b/numpyro/distributions/continuous.py @@ -64,6 +64,7 @@ ) from numpyro.distributions.util import ( add_diag, + assert_one_of, betainc, betaincinv, cholesky_of_inverse, @@ -1490,6 +1491,11 @@ def __init__( scale_tril=None, validate_args=None, ): + assert_one_of( + covariance_matrix=covariance_matrix, + precision_matrix=precision_matrix, + scale_tril=scale_tril, + ) if jnp.ndim(loc) == 0: (loc,) = promote_shapes(loc, shape=(1,)) # temporary append a new axis to loc @@ -1502,11 +1508,6 @@ def __init__( self.scale_tril = cholesky_of_inverse(self.precision_matrix) elif scale_tril is not None: loc, self.scale_tril = promote_shapes(loc, scale_tril) - else: - raise ValueError( - "One of `covariance_matrix`, `precision_matrix`, `scale_tril`" - " must be specified." - ) batch_shape = lax.broadcast_shapes( jnp.shape(loc)[:-2], jnp.shape(self.scale_tril)[:-2] ) @@ -1563,12 +1564,17 @@ def variance(self): def infer_shapes( loc=(), covariance_matrix=None, precision_matrix=None, scale_tril=None ): + assert_one_of( + covariance_matrix=covariance_matrix, + precision_matrix=precision_matrix, + scale_tril=scale_tril, + ) batch_shape, event_shape = loc[:-1], loc[-1:] for matrix in [covariance_matrix, precision_matrix, scale_tril]: if matrix is not None: batch_shape = lax.broadcast_shapes(batch_shape, matrix[:-2]) event_shape = lax.broadcast_shapes(event_shape, matrix[-1:]) - return batch_shape, event_shape + return batch_shape, event_shape def entropy(self): (n,) = self.event_shape @@ -2729,6 +2735,11 @@ def __init__( *, validate_args=None, ): + assert_one_of( + scale_matrix=scale_matrix, + rate_matrix=rate_matrix, + scale_tril=scale_tril, + ) concentration = jnp.asarray(concentration)[..., None, None] if scale_matrix is not None: concentration, self.scale_matrix = promote_shapes( diff --git a/numpyro/distributions/directional.py b/numpyro/distributions/directional.py index 841dd2194..1add0fcba 100644 --- a/numpyro/distributions/directional.py +++ b/numpyro/distributions/directional.py @@ -16,6 +16,7 @@ from numpyro.distributions import constraints from numpyro.distributions.distribution import Distribution from numpyro.distributions.util import ( + assert_one_of, lazy_property, promote_shapes, safe_normalize, @@ -349,7 +350,9 @@ def __init__( weighted_correlation=None, validate_args=None, ): - assert (correlation is None) != (weighted_correlation is None) + assert_one_of( + correlation=correlation, weighted_correlation=weighted_correlation + ) if weighted_correlation is not None: correlation = weighted_correlation * jnp.sqrt( diff --git a/numpyro/distributions/discrete.py b/numpyro/distributions/discrete.py index fc0c81dd9..0ee140620 100644 --- a/numpyro/distributions/discrete.py +++ b/numpyro/distributions/discrete.py @@ -37,6 +37,7 @@ from numpyro.distributions import constraints, transforms from numpyro.distributions.distribution import Distribution from numpyro.distributions.util import ( + assert_one_of, binary_cross_entropy_with_logits, binomial, categorical, @@ -160,12 +161,11 @@ def entropy(self): def Bernoulli(probs=None, logits=None, *, validate_args=None): + assert_one_of(probs=probs, logits=logits) if probs is not None: return BernoulliProbs(probs, validate_args=validate_args) elif logits is not None: return BernoulliLogits(logits, validate_args=validate_args) - else: - raise ValueError("One of `probs` or `logits` must be specified.") class BinomialProbs(Distribution): @@ -293,12 +293,11 @@ def support(self): def Binomial(total_count=1, probs=None, logits=None, *, validate_args=None): + assert_one_of(probs=probs, logits=logits) if probs is not None: return BinomialProbs(probs, total_count, validate_args=validate_args) elif logits is not None: return BinomialLogits(logits, total_count, validate_args=validate_args) - else: - raise ValueError("One of `probs` or `logits` must be specified.") class CategoricalProbs(Distribution): @@ -411,12 +410,11 @@ def entropy(self): def Categorical(probs=None, logits=None, *, validate_args=None): + assert_one_of(probs=probs, logits=logits) if probs is not None: return CategoricalProbs(probs, validate_args=validate_args) elif logits is not None: return CategoricalLogits(logits, validate_args=validate_args) - else: - raise ValueError("One of `probs` or `logits` must be specified.") class DiscreteUniform(Distribution): @@ -670,6 +668,7 @@ def Multinomial( :param int total_count_max: the maximum number of trials, i.e. `max(total_count)` """ + assert_one_of(probs=probs, logits=logits) if probs is not None: return MultinomialProbs( probs, @@ -684,8 +683,6 @@ def Multinomial( total_count_max=total_count_max, validate_args=validate_args, ) - else: - raise ValueError("One of `probs` or `logits` must be specified.") class Poisson(Distribution): @@ -837,10 +834,7 @@ def ZeroInflatedDistribution( :param numpy.ndarray gate: probability of extra zeros given via a Bernoulli distribution. :param numpy.ndarray gate_logits: logits of extra zeros given via a Bernoulli distribution. """ - if (gate is None) == (gate_logits is None): - raise ValueError( - "Either `gate` or `gate_logits` must be specified, but not both." - ) + assert_one_of(gate=gate, gate_logits=gate_logits) if gate is not None: return ZeroInflatedProbs(base_dist, gate, validate_args=validate_args) else: @@ -947,9 +941,8 @@ def entropy(self): def Geometric(probs=None, logits=None, *, validate_args=None): + assert_one_of(probs=probs, logits=logits) if probs is not None: return GeometricProbs(probs, validate_args=validate_args) elif logits is not None: return GeometricLogits(logits, validate_args=validate_args) - else: - raise ValueError("One of `probs` or `logits` must be specified.") diff --git a/numpyro/distributions/distribution.py b/numpyro/distributions/distribution.py index fa9e4613c..6eec9f4d3 100644 --- a/numpyro/distributions/distribution.py +++ b/numpyro/distributions/distribution.py @@ -509,8 +509,9 @@ def infer_shapes(cls, *args, **kwargs): # Assumes distribution is univariate. batch_shapes = [] for name, shape in kwargs.items(): - event_dim = cls.arg_constraints.get(name, constraints.real).event_dim - batch_shapes.append(shape[: len(shape) - event_dim]) + if shape is not None: + event_dim = cls.arg_constraints.get(name, constraints.real).event_dim + batch_shapes.append(shape[: len(shape) - event_dim]) batch_shape = lax.broadcast_shapes(*batch_shapes) if batch_shapes else () event_shape = () return batch_shape, event_shape diff --git a/numpyro/distributions/util.py b/numpyro/distributions/util.py index c2e7aafc2..c7cd02da6 100644 --- a/numpyro/distributions/util.py +++ b/numpyro/distributions/util.py @@ -623,6 +623,17 @@ def is_prng_key(key): return False +def assert_one_of(**kwargs): + """ + Assert that exactly one of the keyword arguments is not None. + """ + specified = [key for key, value in kwargs.items() if value is not None] + if len(specified) != 1: + raise ValueError( + f"Exactly one of {list(kwargs)} must be specified; got {specified}." + ) + + # The is sourced from: torch.distributions.util.py # # Copyright (c) 2016- Facebook, Inc (Adam Paszke) diff --git a/test/test_distributions.py b/test/test_distributions.py index a383019ff..785a172cb 100644 --- a/test/test_distributions.py +++ b/test/test_distributions.py @@ -1235,8 +1235,15 @@ def test_dist_shape(jax_dist_cls, sp_dist, params, prepend_shape): "jax_dist, sp_dist, params", CONTINUOUS + DISCRETE + DIRECTIONAL ) def test_infer_shapes(jax_dist, sp_dist, params): - shapes = tuple(getattr(p, "shape", ()) for p in params) - shapes = tuple(x() if callable(x) else x for x in shapes) + shapes = [] + for param in params: + if param is None: + shapes.append(None) + continue + shape = getattr(param, "shape", ()) + if callable(shape): + shape = shape() + shapes.append(shape) jax_dist = jax_dist(*params) try: expected_batch_shape, expected_event_shape = type(jax_dist).infer_shapes( From 43d8e03465ba99e84059f5122c65d91edba4fca9 Mon Sep 17 00:00:00 2001 From: Till Hoffmann Date: Sat, 13 Apr 2024 11:04:02 -0400 Subject: [PATCH 06/11] Implement `infer_shapes` for `Wishart` and `WishartCholesky`. --- numpyro/distributions/continuous.py | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/numpyro/distributions/continuous.py b/numpyro/distributions/continuous.py index 6d571f407..eb15c030d 100644 --- a/numpyro/distributions/continuous.py +++ b/numpyro/distributions/continuous.py @@ -2698,6 +2698,14 @@ def variance(self): self.scale_matrix**2 + diag[..., :, None] * diag[..., None, :] ) + @staticmethod + def infer_shapes( + concentration=(), scale_matrix=None, rate_matrix=None, scale_tril=None + ): + return WishartCholesky.infer_shapes( + concentration, scale_matrix, rate_matrix, scale_tril + ) + class WishartCholesky(Distribution): """ @@ -2846,3 +2854,18 @@ def variance(self): jnp.ones_like(k, shape=k.shape + (k.shape[-1],)).at[..., i, i].set(k) ) return jnp.square(self.scale_tril) @ latent - jnp.square(self.mean) + + @staticmethod + def infer_shapes( + concentration=(), scale_matrix=None, rate_matrix=None, scale_tril=None + ): + assert_one_of( + scale_matrix=scale_matrix, + rate_matrix=rate_matrix, + scale_tril=scale_tril, + ) + for matrix in [scale_matrix, rate_matrix, scale_tril]: + if matrix is not None: + batch_shape = lax.broadcast_shapes(concentration, matrix[:-2]) + event_shape = matrix[-2:] + return batch_shape, event_shape From 538c1915584bd86efab0e3f9d705cec7b340e4e2 Mon Sep 17 00:00:00 2001 From: Till Hoffmann Date: Thu, 2 May 2024 10:53:29 -0400 Subject: [PATCH 07/11] Add entropy for Wishart distribution. --- numpyro/distributions/continuous.py | 11 +++++++++++ numpyro/distributions/util.py | 8 ++++++++ 2 files changed, 19 insertions(+) diff --git a/numpyro/distributions/continuous.py b/numpyro/distributions/continuous.py index eb15c030d..daabb309b 100644 --- a/numpyro/distributions/continuous.py +++ b/numpyro/distributions/continuous.py @@ -71,6 +71,7 @@ gammaincinv, lazy_property, matrix_to_tril_vec, + multidigamma, promote_shapes, signed_stick_breaking_tril, validate_sample, @@ -2706,6 +2707,16 @@ def infer_shapes( concentration, scale_matrix, rate_matrix, scale_tril ) + def entropy(self): + p = self.event_shape[-1] + return ( + (p + 1) * jnp.linalg.slogdet(self.scale_tril).logabsdet + + p * (p + 1) / 2 * jnp.log(2) + + multigammaln(self.concentration / 2, p) + - (self.concentration - p - 1) / 2 * multidigamma(self.concentration / 2, p) + + self.concentration * p / 2 + ) + class WishartCholesky(Distribution): """ diff --git a/numpyro/distributions/util.py b/numpyro/distributions/util.py index c7cd02da6..73d50b519 100644 --- a/numpyro/distributions/util.py +++ b/numpyro/distributions/util.py @@ -12,6 +12,7 @@ from jax import jit, lax, random, vmap import jax.numpy as jnp from jax.scipy.linalg import solve_triangular +from jax.scipy.special import digamma # Parameters for Transformed Rejection with Squeeze (TRS) algorithm - page 3. _tr_params = namedtuple( @@ -634,6 +635,13 @@ def assert_one_of(**kwargs): ) +def multidigamma(a: jnp.ndarray, d: jnp.ndarray) -> jnp.ndarray: + """ + Derivative of the log of multivariate gamma. + """ + return digamma(a[..., None] - 0.5 * jnp.arange(d)).sum(axis=-1) + + # The is sourced from: torch.distributions.util.py # # Copyright (c) 2016- Facebook, Inc (Adam Paszke) From 9ad4ea93b032130a3f42b586bcdd2a7d7884808b Mon Sep 17 00:00:00 2001 From: Till Hoffmann Date: Thu, 2 May 2024 10:53:43 -0400 Subject: [PATCH 08/11] Add sampled entropy test for distribution without `scipy` equivalent. --- test/test_distributions.py | 30 ++++++++++++++++++++++++++---- 1 file changed, 26 insertions(+), 4 deletions(-) diff --git a/test/test_distributions.py b/test/test_distributions.py index 785a172cb..3487ab745 100644 --- a/test/test_distributions.py +++ b/test/test_distributions.py @@ -1488,21 +1488,43 @@ def test_log_prob(jax_dist, sp_dist, params, prepend_shape, jit): @pytest.mark.parametrize( "jax_dist, sp_dist, params", CONTINUOUS + DISCRETE + DIRECTIONAL ) -def test_entropy(jax_dist, sp_dist, params): +def test_entropy_scipy(jax_dist, sp_dist, params): jax_dist = jax_dist(*params) + try: + actual = jax_dist.entropy() + except NotImplementedError: + pytest.skip(reason="distribution does not implement `entropy`") if _is_batched_multivariate(jax_dist): pytest.skip("batching not allowed in multivariate distns.") if sp_dist is None: pytest.skip(reason="no corresponding scipy distribution") + + sp_dist = sp_dist(*params) + expected = sp_dist.entropy() + assert_allclose(actual, expected, atol=1e-5) + + +@pytest.mark.parametrize( + "jax_dist, sp_dist, params", CONTINUOUS + DISCRETE + DIRECTIONAL +) +def test_entropy_samples(jax_dist, sp_dist, params): + jax_dist = jax_dist(*params) + try: actual = jax_dist.entropy() except NotImplementedError: pytest.skip(reason="distribution does not implement `entropy`") - sp_dist = sp_dist(*params) - expected = sp_dist.entropy() - assert_allclose(actual, expected, atol=1e-5) + samples = jax_dist.sample(jax.random.key(8), (1000,)) + neg_log_probs = -jax_dist.log_prob(samples) + mean = neg_log_probs.mean(axis=0) + stderr = neg_log_probs.std(axis=0) / jnp.sqrt(neg_log_probs.shape[-1] - 1) + z = (actual - mean) / stderr + + # Check the z-score is small or that all values are close. This happens, for + # example, for uniform distributions with constant log prob and hence zero stderr. + assert (jnp.abs(z) < 5).all() or jnp.allclose(actual, neg_log_probs, atol=1e-5) def test_entropy_categorical(): From 9737b276761045c8966d7bede18d48b0126faddf Mon Sep 17 00:00:00 2001 From: Till Hoffmann Date: Fri, 10 May 2024 11:58:22 -0600 Subject: [PATCH 09/11] Simplify `logabsdet` evaluation of `scale_tril`. --- numpyro/distributions/continuous.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/numpyro/distributions/continuous.py b/numpyro/distributions/continuous.py index daabb309b..262295b3e 100644 --- a/numpyro/distributions/continuous.py +++ b/numpyro/distributions/continuous.py @@ -2710,7 +2710,8 @@ def infer_shapes( def entropy(self): p = self.event_shape[-1] return ( - (p + 1) * jnp.linalg.slogdet(self.scale_tril).logabsdet + (p + 1) + * jnp.log(jnp.diagonal(self.scale_tril, axis1=-1, axis2=-2)).sum(axis=-1) + p * (p + 1) / 2 * jnp.log(2) + multigammaln(self.concentration / 2, p) - (self.concentration - p - 1) / 2 * multidigamma(self.concentration / 2, p) From 4a471d1388c57bd254fe89a348c5bbfccbe6dde6 Mon Sep 17 00:00:00 2001 From: Till Hoffmann Date: Fri, 10 May 2024 11:58:43 -0600 Subject: [PATCH 10/11] Remove default `None` argument for concentration of Wishart distribution. --- numpyro/distributions/continuous.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/numpyro/distributions/continuous.py b/numpyro/distributions/continuous.py index 262295b3e..231d37fb9 100644 --- a/numpyro/distributions/continuous.py +++ b/numpyro/distributions/continuous.py @@ -2654,7 +2654,7 @@ class Wishart(TransformedDistribution): def __init__( self, - concentration=None, + concentration, scale_matrix=None, rate_matrix=None, scale_tril=None, @@ -2748,7 +2748,7 @@ class WishartCholesky(Distribution): def __init__( self, - concentration=None, + concentration, scale_matrix=None, rate_matrix=None, scale_tril=None, From 34aa3e6cf792d19c726c5e26665e731eed11060d Mon Sep 17 00:00:00 2001 From: Till Hoffmann Date: Mon, 13 May 2024 09:37:29 -0400 Subject: [PATCH 11/11] Add `tri_logabsdet` utility function. --- numpyro/distributions/continuous.py | 30 ++++++++++------------------- numpyro/distributions/util.py | 7 +++++++ 2 files changed, 17 insertions(+), 20 deletions(-) diff --git a/numpyro/distributions/continuous.py b/numpyro/distributions/continuous.py index 231d37fb9..8280482c0 100644 --- a/numpyro/distributions/continuous.py +++ b/numpyro/distributions/continuous.py @@ -74,6 +74,7 @@ multidigamma, promote_shapes, signed_stick_breaking_tril, + tri_logabsdet, validate_sample, vec_to_tril_matrix, ) @@ -1400,12 +1401,8 @@ def sample(self, key, sample_shape=()): def log_prob(self, values): n, p = self.event_shape - row_log_det = jnp.log( - jnp.diagonal(self.scale_tril_row, axis1=-2, axis2=-1) - ).sum(-1) - col_log_det = jnp.log( - jnp.diagonal(self.scale_tril_column, axis1=-2, axis2=-1) - ).sum(-1) + row_log_det = tri_logabsdet(self.scale_tril_row) + col_log_det = tri_logabsdet(self.scale_tril_column) log_det_term = ( p * row_log_det + n * col_log_det + 0.5 * n * p * jnp.log(2 * jnp.pi) ) @@ -1532,9 +1529,7 @@ def sample(self, key, sample_shape=()): @validate_sample def log_prob(self, value): M = _batch_mahalanobis(self.scale_tril, value - self.loc) - half_log_det = jnp.log(jnp.diagonal(self.scale_tril, axis1=-2, axis2=-1)).sum( - -1 - ) + half_log_det = tri_logabsdet(self.scale_tril) normalize_term = half_log_det + 0.5 * self.scale_tril.shape[-1] * jnp.log( 2 * jnp.pi ) @@ -1579,9 +1574,7 @@ def infer_shapes( def entropy(self): (n,) = self.event_shape - half_log_det = jnp.log(jnp.diagonal(self.scale_tril, axis1=-2, axis2=-1)).sum( - -1 - ) + half_log_det = tri_logabsdet(self.scale_tril) return n * (jnp.log(2 * np.pi) + 1) / 2 + half_log_det @@ -1857,7 +1850,7 @@ def sample(self, key, sample_shape=()): def log_prob(self, value): n = self.scale_tril.shape[-1] Z = ( - jnp.log(jnp.diagonal(self.scale_tril, axis1=-2, axis2=-1)).sum(-1) + tri_logabsdet(self.scale_tril) + 0.5 * n * jnp.log(self.df) + 0.5 * n * jnp.log(jnp.pi) + gammaln(0.5 * self.df) @@ -1932,9 +1925,7 @@ def _batch_lowrank_logdet(W, D, capacitance_tril): where :math:`C` is the capacitance matrix :math:`I + W.T @ inv(D) @ W`, to compute the log determinant. """ - return 2 * jnp.sum( - jnp.log(jnp.diagonal(capacitance_tril, axis1=-2, axis2=-1)), axis=-1 - ) + jnp.log(D).sum(-1) + return 2 * tri_logabsdet(capacitance_tril) + jnp.log(D).sum(-1) def _batch_lowrank_mahalanobis(W, D, x, capacitance_tril): @@ -2710,8 +2701,7 @@ def infer_shapes( def entropy(self): p = self.event_shape[-1] return ( - (p + 1) - * jnp.log(jnp.diagonal(self.scale_tril, axis1=-1, axis2=-2)).sum(axis=-1) + (p + 1) * tri_logabsdet(self.scale_tril) + p * (p + 1) / 2 * jnp.log(2) + multigammaln(self.concentration / 2, p) - (self.concentration - p - 1) / 2 * multidigamma(self.concentration / 2, p) @@ -2799,11 +2789,11 @@ def log_prob(self, value): trace = jnp.square(x).sum(axis=(-1, -2)) p = value.shape[-1] return ( - (self.concentration - p - 1) * jnp.linalg.slogdet(value).logabsdet + (self.concentration - p - 1) * tri_logabsdet(value) - trace / 2 + p * (1 - self.concentration / 2) * jnp.log(2) - multigammaln(self.concentration / 2, p) - - self.concentration * jnp.linalg.slogdet(self.scale_tril).logabsdet + - self.concentration * tri_logabsdet(self.scale_tril) # Part of the Jacobian of the Cholesky transformation. + jnp.sum( jnp.arange(p, 0, -1) * jnp.log(jnp.diagonal(value, axis1=-2, axis2=-1)), diff --git a/numpyro/distributions/util.py b/numpyro/distributions/util.py index 73d50b519..aca32b1f7 100644 --- a/numpyro/distributions/util.py +++ b/numpyro/distributions/util.py @@ -642,6 +642,13 @@ def multidigamma(a: jnp.ndarray, d: jnp.ndarray) -> jnp.ndarray: return digamma(a[..., None] - 0.5 * jnp.arange(d)).sum(axis=-1) +def tri_logabsdet(a: jnp.ndarray) -> jnp.ndarray: + """ + Evaluate the `logabsdet` of a triangular positive-definite matrix. + """ + return jnp.log(jnp.diagonal(a, axis1=-1, axis2=-2)).sum(axis=-1) + + # The is sourced from: torch.distributions.util.py # # Copyright (c) 2016- Facebook, Inc (Adam Paszke)