From 1809d353fba5a9ef82c18515a750f83c3e360bf2 Mon Sep 17 00:00:00 2001 From: Cole Haus Date: Thu, 18 Aug 2022 14:06:25 -0700 Subject: [PATCH] Add draft of SkewMultivariateNormal --- .gitignore | 1 + numpyro/distributions/__init__.py | 2 + numpyro/distributions/continuous.py | 136 +++++++++++++++++++++ setup.py | 1 + test/test_distributions.py | 182 ++++++++++++++++++++++++++++ 5 files changed, 322 insertions(+) diff --git a/.gitignore b/.gitignore index 11292796f..266cd090c 100644 --- a/.gitignore +++ b/.gitignore @@ -9,6 +9,7 @@ build *.pyo /build /dist +/.hypothesis # IDE .idea diff --git a/numpyro/distributions/__init__.py b/numpyro/distributions/__init__.py index 31ac5de31..8542f0144 100644 --- a/numpyro/distributions/__init__.py +++ b/numpyro/distributions/__init__.py @@ -39,6 +39,7 @@ Pareto, RelaxedBernoulli, RelaxedBernoulliLogits, + SkewMultivariateNormal, SoftLaplace, StudentT, Uniform, @@ -158,6 +159,7 @@ "MultivariateStudentT", "LowRankMultivariateNormal", "Normal", + "SkewMultivariateNormal", "NegativeBinomialProbs", "NegativeBinomialLogits", "NegativeBinomial2", diff --git a/numpyro/distributions/continuous.py b/numpyro/distributions/continuous.py index fa05f6951..9bc801948 100644 --- a/numpyro/distributions/continuous.py +++ b/numpyro/distributions/continuous.py @@ -25,7 +25,10 @@ # ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE # POSSIBILITY OF SUCH DAMAGE. +from typing import Union, cast + import numpy as np +from numpy.typing import NDArray from jax import lax from jax.experimental.sparse import BCOO @@ -1731,6 +1734,139 @@ def variance(self): return jnp.broadcast_to(self.scale**2, self.batch_shape) +def skew_delta(skewers_: NDArray[float], cov_: NDArray[float]): + return (jnp.einsum("...ij,...j->...i", cov_, skewers_)) / jnp.sqrt( + 1 + + jnp.einsum("...j,...jk,...k->...", skewers_, cov_, skewers_)[..., jnp.newaxis] + ) + + +# Regularized Multivariate Regression Models with Skew-t Error Distributions +# https://epublications.marquette.edu/cgi/viewcontent.cgi?article=1225&context=mscs_fac +class SkewMultivariateNormal(Distribution): + arg_constraints = { + "loc": constraints.real_vector, + "scale_tril": constraints.lower_cholesky, + "skewers": constraints.real_vector, + } + support = constraints.real_vector + reparametrized_params = ["loc", "scale_tril", "skewers"] + uv_norm = Normal(0.0, 1.0) + + @staticmethod + def mk_big_mv_norm( + loc: NDArray[float], skewers: NDArray[float], scale_tril: NDArray[float] + ): + cov = jnp.einsum("...ij,...hj->...ih", scale_tril, scale_tril) + delta_ = skew_delta(skewers, cov) + cov_star = jnp.block( + [ + [ + jnp.ones(skewers.shape[:-1] + (1, 1)), + jnp.expand_dims(delta_, axis=-2), + ], + [jnp.expand_dims(delta_, axis=-1), cov], + ] + ) + + return MultivariateNormal( + loc=jnp.zeros(loc.shape[-1] + 1), scale_tril=jnp.linalg.cholesky(cov_star) + ) + + def __init__( + self, + loc: Union[NDArray[float], float], + scale_tril: NDArray[float], + skewers: NDArray[float], + validate_args: None = None, + ): + if jnp.ndim(loc) == 0: + (loc_,) = promote_shapes(loc, shape=(1,)) + else: + loc_ = cast(NDArray[float], loc) + batch_shape = lax.broadcast_shapes( + jnp.shape(loc_)[:-1], jnp.shape(scale_tril)[:-2], jnp.shape(skewers)[:-1] + ) + (self.loc,) = promote_shapes(loc_, shape=batch_shape + loc_.shape[-1:]) + (self.skewers,) = promote_shapes( + skewers, shape=batch_shape + skewers.shape[-1:] + ) + (self.scale_tril,) = promote_shapes( + scale_tril, shape=batch_shape + scale_tril.shape[-2:] + ) + cov_batch = jnp.einsum("...ij,...hj->...ih", self.scale_tril, self.scale_tril) + self._std_devs = jnp.sqrt(jnp.diagonal(cov_batch, axis1=-2, axis2=-1)) + + # Used for sampling + self._big_mv_norm = self.mk_big_mv_norm( + # The blog post just uses unstandardized skewers here but that leads to + # a discrepancy between sampling and log_prob + loc=self.loc, + skewers=skewers / self._std_devs, + scale_tril=scale_tril, + ) + # Used for log_prob + self._mv_norm = MultivariateNormal(loc_, scale_tril=scale_tril) + + skew_mean = jnp.sqrt(2 / jnp.pi) * skew_delta( + self.skewers / self._std_devs, cov_batch + ) + self._mean = self.loc + skew_mean + # The paper just uses `mean` here but that's definitely not right because + # it potentially leads to covariance matrices which are not positive semi definite + self._covariance = cov_batch - jnp.einsum( + "...i,...j->...ij", skew_mean, skew_mean + ) + + event_shape = jnp.shape(self.scale_tril)[-1:] + super().__init__( + batch_shape=batch_shape, + event_shape=event_shape, + validate_args=validate_args, + ) + + @validate_sample + def log_prob(self, value: NDArray[float]) -> NDArray[float]: + return ( + jnp.log(2) + + self._mv_norm.log_prob(value) + + jnp.log( + self.uv_norm.cdf( + jnp.einsum( + "...k,...k->...", + (value - self.loc) / self._std_devs, + self.skewers, + ) + ) + ) + ) + + @staticmethod + def infer_shapes( + loc: NDArray[float], scale_tril: NDArray[float], skewers: NDArray[float] + ): + event_shape = (scale_tril[-1],) + batch_shape = lax.broadcast_shapes(loc[:-1], scale_tril[:-2], skewers[:-1]) + return batch_shape, event_shape + + # https://gregorygundersen.com/blog/2020/12/29/multivariate-skew-normal/ + def sample( + self, key: random.PRNGKey, sample_shape: tuple[int, ...] = () + ) -> NDArray[float]: + assert is_prng_key(key) + x = self._big_mv_norm.sample(key, sample_shape=sample_shape) + sign_bit, samples = x[..., 0, jnp.newaxis], x[..., 1:] + return jnp.where(sign_bit <= 0, -1 * samples, samples) + self.loc + + @property + def mean(self): + return jnp.broadcast_to(self._mean, self.shape()) + + @property + def covariance_matrix(self): + return self._covariance + + class Pareto(TransformedDistribution): arg_constraints = {"scale": constraints.positive, "alpha": constraints.positive} reparametrized_params = ["scale", "alpha"] diff --git a/setup.py b/setup.py index f9ccf0901..3211f8f13 100644 --- a/setup.py +++ b/setup.py @@ -51,6 +51,7 @@ "test": [ "black[jupyter]>=21.8b0", "flake8", + "hypothesis[numpy]", "isort>=5.0", "pytest>=4.1", "pyro-api>=0.1.1", diff --git a/test/test_distributions.py b/test/test_distributions.py index f026be77d..61ce79250 100644 --- a/test/test_distributions.py +++ b/test/test_distributions.py @@ -6,9 +6,15 @@ import inspect import math import os +from typing import cast +from hypothesis import given, note, settings +import hypothesis.extra.numpy as hnp +import hypothesis.strategies as st +from hypothesis.strategies import DrawFn, SearchStrategy import numpy as np from numpy.testing import assert_allclose, assert_array_equal +from numpy.typing import NDArray import pytest import scipy import scipy.stats as osp @@ -534,6 +540,12 @@ def get_sp_dist(jax_dist): T(dist.Normal, 0.0, 1.0), T(dist.Normal, 1.0, np.array([1.0, 2.0])), T(dist.Normal, np.array([0.0, 1.0]), np.array([[1.0], [2.0]])), + T( + dist.SkewMultivariateNormal, + np.array([2.0, 0.0]), + np.array([[1.0, 0.0], [0.5, 1.0]]), + np.array([0.0, 0.0]), + ), T(dist.Pareto, 1.0, 2.0), T(dist.Pareto, np.array([1.0, 0.5]), np.array([0.3, 2.0])), T(dist.Pareto, np.array([[1.0], [3.0]]), np.array([1.0, 0.5])), @@ -1502,6 +1514,10 @@ def test_mean_var(jax_dist, sp_dist, params): dist.TwoSidedTruncatedDistribution, ): pytest.skip("Truncated distributions do not has mean/var implemented") + if jax_dist is dist.SkewMultivariateNormal: + pytest.skip( + "We check SkewMultivariateNormal against MultivariateNormal elsewhere" + ) if jax_dist is dist.ProjectedNormal: pytest.skip("Mean is defined in submanifold") @@ -2570,3 +2586,169 @@ def sample_binomial_withp0(key): return dist.Binomial(total_count=n, probs=0).sample(key) jax.vmap(sample_binomial_withp0)(random.split(random.PRNGKey(0), 1)) + + +def locs(size: int) -> SearchStrategy[NDArray[float]]: + return cast( + SearchStrategy[NDArray[float]], + hnp.arrays( + elements=st.floats( + min_value=-1, max_value=1, allow_nan=False, allow_infinity=False + ), + dtype=np.dtype("float"), + shape=size, + ), + ) + + +def skews(size: int) -> SearchStrategy[NDArray[float]]: + return cast( + SearchStrategy[NDArray[float]], + hnp.arrays( + elements=st.floats( + min_value=-4, max_value=4, allow_nan=False, allow_infinity=False + ), + dtype=np.dtype("float"), + shape=size, + ), + ) + + +def variances(size: int) -> SearchStrategy[NDArray[float]]: + return cast( + SearchStrategy[NDArray[float]], + hnp.arrays( + # Variances that are too small make it impossible to test t against normal + elements=st.floats( + min_value=0.1, + max_value=3, + allow_nan=False, + allow_infinity=False, + exclude_min=True, + ), + dtype=np.dtype("float"), + shape=size, + ), + ) + + +def corr_vech_to_matrix(vech: NDArray[float]): + width = (math.isqrt(8 * vech.size + 1) + 1) // 2 + zeros = np.zeros((width, width)) + zeros[np.tril_indices(width, k=-1)] = vech + np.fill_diagonal(zeros, 1) + return zeros + + +def correlation_chols(size: int) -> SearchStrategy[NDArray[float]]: + return hnp.arrays( + # Floating point issues mean we sometimes get arrays which aren't positive semi-definite + # if we allow correlations of exactly 1 and -1 + elements=st.floats( + min_value=-0.99, max_value=0.99, allow_nan=False, allow_infinity=False + ), + dtype=np.dtype("float"), + shape=size * (size - 1) // 2, + ).map( + corr_vech_to_matrix # type: ignore + ) + + +@st.composite +def loc_and_scale(draw: DrawFn): + # Would need to generalize meshgrid to relax this restriction + size = 2 + corr = draw(correlation_chols(size)) + var = draw(variances(size)) + return (draw(locs(size)), jnp.sqrt(var)[..., None] * corr) + + +@st.composite +def loc_and_scale_and_skewers(draw: DrawFn): + # Would need to generalize meshgrid to relax this restriction + size = 2 + corr = draw(correlation_chols(size)) + var = draw(variances(size)) + return ( + draw(locs(size)), + jnp.sqrt(var)[..., None] * corr, + draw(skews(size)), + ) + + +X, Y = np.meshgrid(np.linspace(-3, 3, 100), np.linspace(-3, 3, 100)) +grid = np.dstack((X, Y)) +X_wide, Y_wide = np.meshgrid(np.linspace(-6, 6, 50), np.linspace(-6, 6, 50)) +grid_wide = np.dstack((X_wide, Y_wide)) + + +@settings(deadline=None) +@given(loc_and_scale()) +def test_skew_normal_log_prob_generalizes_normal( + loc_scale_tril: tuple[NDArray[float], NDArray[float]] +): + loc, scale_tril = loc_scale_tril + mvn = dist.MultivariateNormal(loc=loc, scale_tril=scale_tril) + smvn = dist.SkewMultivariateNormal( + loc=loc, scale_tril=scale_tril, skewers=np.zeros(scale_tril.shape[-1]) + ) + assert_allclose(mvn.log_prob(grid), smvn.log_prob(grid), atol=1e-6) + + +@settings(deadline=None) +@given(loc_and_scale()) +def test_skew_normal_moments_generalize_normal( + loc_scale_tril: tuple[NDArray[float], NDArray[float]] +): + loc, scale_tril = loc_scale_tril + mvn = dist.MultivariateNormal(loc=loc, scale_tril=scale_tril) + smvn = dist.SkewMultivariateNormal( + loc=loc, scale_tril=scale_tril, skewers=np.zeros(scale_tril.shape[-1]) + ) + assert_allclose(mvn.mean, smvn.mean, atol=1e-30) + assert_allclose(mvn.covariance_matrix, smvn.covariance_matrix, atol=1e-30) + + +@settings(deadline=None, max_examples=10) +@given(loc_and_scale_and_skewers()) +def test_skew_normal_log_prob_vs_samples( + loc_scale_tril_skewers: tuple[NDArray[float], NDArray[float], NDArray[float]] +): + loc, scale_tril, skewers = loc_scale_tril_skewers + note(f"Covariance: {scale_tril @ scale_tril.T}") + smvn = dist.SkewMultivariateNormal(loc=loc, scale_tril=scale_tril, skewers=skewers) + samples = smvn.sample(random.PRNGKey(0), sample_shape=(50_000,)) + # gaussian_kde needs a different format + grid_ = np.vstack([X_wide.ravel(), Y_wide.ravel()]) + lp = jnp.exp(smvn.log_prob(grid_.T)) + k = osp.gaussian_kde(samples.T, bw_method="scott")(grid_) + + lp_normed = (lp - lp.min()) / (lp.max() - lp.min()) + k_normed = (k - k.min()) / (k.max() - k.min()) + assert_allclose(lp_normed, k_normed, atol=0.07) + + +def split_cov(cov: NDArray[float]) -> tuple[NDArray[float], NDArray[float]]: + std_devs = np.sqrt(np.diag(cov)) + dinv = np.diag(1 / std_devs) + corr = dinv @ cov @ dinv + tril_i = np.tril_indices(len(std_devs), k=-1) + return (std_devs, corr[tril_i]) + + +@settings(deadline=None) +@given(loc_and_scale_and_skewers()) +def test_skew_normal_moments_vs_samples( + loc_scale_tril_skewers: tuple[NDArray[float], NDArray[float], NDArray[float]] +): + loc, scale_tril, skewers = loc_scale_tril_skewers + note(f"Covariance: {scale_tril @ scale_tril.T}") + smvn = dist.SkewMultivariateNormal(loc=loc, scale_tril=scale_tril, skewers=skewers) + samples = smvn.sample(random.PRNGKey(0), sample_shape=(500_000,)) + assert_allclose(np.mean(samples, axis=0), smvn.mean, rtol=0.005, atol=0.001) + + std_devs_sample, corr_sample = split_cov(np.cov(samples.T)) + std_devs_dist, corr_dist = split_cov(smvn.covariance_matrix) + assert_allclose(std_devs_sample, std_devs_dist, rtol=0.003) + note(f"Sample corr: {corr_sample}, Distribution corr: {corr_dist}") + assert_allclose(corr_sample, corr_dist, atol=0.006)