diff --git a/numpyro/distributions/constraints.py b/numpyro/distributions/constraints.py index 7b8b59340..ca278c725 100644 --- a/numpyro/distributions/constraints.py +++ b/numpyro/distributions/constraints.py @@ -168,9 +168,9 @@ class _CorrMatrix(_SingletonConstraint): def __call__(self, x): jnp = np if isinstance(x, (np.ndarray, np.generic)) else jax.numpy # check for symmetric - symmetric = jnp.all(jnp.all(x == jnp.swapaxes(x, -2, -1), axis=-1), axis=-1) + symmetric = jnp.all(jnp.isclose(x, jnp.swapaxes(x, -2, -1)), axis=(-2, -1)) # check for the smallest eigenvalue is positive - positive = jnp.linalg.eigh(x)[0][..., 0] > 0 + positive = jnp.linalg.eigvalsh(x)[..., 0] > 0 # check for diagonal equal to 1 unit_variance = jnp.all( jnp.abs(jnp.diagonal(x, axis1=-2, axis2=-1) - 1) < 1e-6, axis=-1 diff --git a/test/test_transforms.py b/test/test_transforms.py index 1a706bbc6..4a35a7351 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -5,11 +5,13 @@ from functools import partial import math +import numpy as np import pytest from jax import jacfwd, jit, random, tree_map, vmap import jax.numpy as jnp +from numpyro.distributions import constraints from numpyro.distributions.flows import ( BlockNeuralAutoregressiveTransform, InverseAutoregressiveTransform, @@ -48,9 +50,6 @@ def _unpack(x): return (x,) -_a = jnp.asarray - - def _smoke_neural_network(): return None, None @@ -60,14 +59,12 @@ class T(namedtuple("TestCase", ["transform_cls", "params", "kwargs"])): TRANSFORMS = { - "affine": T( - AffineTransform, (jnp.array([1.0, 2.0]), jnp.array([3.0, 4.0])), dict() - ), + "affine": T(AffineTransform, (np.array([1.0, 2.0]), np.array([3.0, 4.0])), dict()), "compose": T( ComposeTransform, ( [ - AffineTransform(jnp.array([1.0, 2.0]), jnp.array([3.0, 4.0])), + AffineTransform(np.array([1.0, 2.0]), np.array([3.0, 4.0])), ExpTransform(), ], ), @@ -75,16 +72,16 @@ class T(namedtuple("TestCase", ["transform_cls", "params", "kwargs"])): ), "independent": T( IndependentTransform, - (AffineTransform(jnp.array([1.0, 2.0]), jnp.array([3.0, 4.0])),), + (AffineTransform(np.array([1.0, 2.0]), np.array([3.0, 4.0])),), dict(reinterpreted_batch_ndims=1), ), "lower_cholesky_affine": T( - LowerCholeskyAffine, (jnp.array([1.0, 2.0]), jnp.eye(2)), dict() + LowerCholeskyAffine, (np.array([1.0, 2.0]), np.eye(2)), dict() ), - "permute": T(PermuteTransform, (jnp.array([1, 0]),), dict()), + "permute": T(PermuteTransform, (np.array([1, 0]),), dict()), "power": T( PowerTransform, - (_a(2.0),), + (np.array(2.0),), dict(), ), "rfft": T( @@ -94,12 +91,12 @@ class T(namedtuple("TestCase", ["transform_cls", "params", "kwargs"])): ), "recursive_linear": T( RecursiveLinearTransform, - (jnp.eye(5),), + (np.eye(5),), dict(), ), "simplex_to_ordered": T( SimplexToOrderedTransform, - (_a(1.0),), + (np.array(1.0),), dict(), ), "unpack": T(UnpackTransform, (), dict(unpack_fn=_unpack)), @@ -123,7 +120,7 @@ class T(namedtuple("TestCase", ["transform_cls", "params", "kwargs"])): # autoregressive_nn is a non-jittable arg, which does not fit well with # the current test pipeline, which assumes jittable args, and non-jittable kwargs partial(InverseAutoregressiveTransform, _smoke_neural_network), - (_a(-1.0), _a(1.0)), + (np.array(-1.0), np.array(1.0)), dict(), ), "bna": T( @@ -277,10 +274,10 @@ def test_real_fast_fourier_transform(input_shape, shape, ndims): (IdentityTransform(), ()), (IndependentTransform(ExpTransform(), 2), (3, 4)), (L1BallTransform(), (9,)), - (LowerCholeskyAffine(jnp.ones(3), jnp.eye(3)), (3,)), + (LowerCholeskyAffine(np.ones(3), np.eye(3)), (3,)), (LowerCholeskyTransform(), (10,)), (OrderedTransform(), (5,)), - (PermuteTransform(jnp.roll(jnp.arange(7), 2)), (7,)), + (PermuteTransform(np.roll(np.arange(7), 2)), (7,)), (PowerTransform(2.5), ()), (RealFastFourierTransform(7), (7,)), (RealFastFourierTransform((8, 9), 2), (8, 9)), @@ -351,3 +348,41 @@ def test_batched_recursive_linear_transform(): y = transform(x) assert y.shape == x.shape assert jnp.allclose(x, transform.inv(y), atol=1e-6) + + +@pytest.mark.parametrize( + "constraint, shape", + [ + (constraints.circular, (3,)), + (constraints.complex, (3,)), + (constraints.corr_cholesky, (10, 10)), + (constraints.corr_matrix, (21,)), + (constraints.greater_than(3), ()), + (constraints.interval(8, 13), (17,)), + (constraints.l1_ball, (4,)), + (constraints.less_than(-1), ()), + (constraints.lower_cholesky, (21,)), + (constraints.open_interval(3, 4), ()), + (constraints.ordered_vector, (5,)), + (constraints.positive_definite, (6,)), + (constraints.positive_ordered_vector, (7,)), + (constraints.positive, (7,)), + (constraints.real_matrix, (17,)), + (constraints.real_vector, (18,)), + (constraints.real, (3,)), + (constraints.scaled_unit_lower_cholesky, (21,)), + (constraints.simplex, (3,)), + (constraints.softplus_lower_cholesky, (21,)), + (constraints.softplus_positive, (2,)), + (constraints.unit_interval, (4,)), + ], + ids=str, +) +def test_biject_to(constraint, shape): + batch_shape = (13, 19) + unconstrained = random.normal(random.key(93), batch_shape + shape) + constrained = biject_to(constraint)(unconstrained) + passed = constraint.check(constrained) + expected_shape = constrained.shape[: constrained.ndim - constraint.event_dim] + assert passed.shape == expected_shape + assert jnp.all(passed)