numpyro/distributions/constraints.py (4 changes: 2 additions & 2 deletions)
@@ -168,9 +168,9 @@ class _CorrMatrix(_SingletonConstraint):
     def __call__(self, x):
         jnp = np if isinstance(x, (np.ndarray, np.generic)) else jax.numpy
         # check for symmetric
-        symmetric = jnp.all(jnp.all(x == jnp.swapaxes(x, -2, -1), axis=-1), axis=-1)
+        symmetric = jnp.all(jnp.isclose(x, jnp.swapaxes(x, -2, -1)), axis=(-2, -1))
         # check for the smallest eigenvalue is positive
-        positive = jnp.linalg.eigh(x)[0][..., 0] > 0
+        positive = jnp.linalg.eigvalsh(x)[..., 0] > 0
         # check for diagonal equal to 1
         unit_variance = jnp.all(
             jnp.abs(jnp.diagonal(x, axis1=-2, axis2=-1) - 1) < 1e-6, axis=-1
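Aside (illustration only, not part of the diff): the exact `==` symmetry test can reject correlation matrices that are valid only up to floating-point rounding, for example ones obtained by inverting a precision matrix, whereas `jnp.isclose` tolerates that rounding; `jnp.linalg.eigvalsh` also skips the eigenvector computation that the old `eigh` call discarded. A minimal NumPy sketch of the same checks:

```python
# Illustrative only; mirrors the checks in _CorrMatrix.__call__ but is not numpyro code.
import numpy as np

rng = np.random.default_rng(0)
a = rng.normal(size=(5, 5))
prec = a @ a.T + 5 * np.eye(5)      # positive-definite "precision" matrix
cov = np.linalg.inv(prec)           # LU-based inverse is typically not exactly symmetric
d = np.sqrt(np.diag(cov))
corr = cov / np.outer(d, d)         # normalize to (approximately) unit diagonal

print(np.all(corr == np.swapaxes(corr, -2, -1)))            # often False: exact check is too strict
print(np.all(np.isclose(corr, np.swapaxes(corr, -2, -1))))  # True within tolerance
print(np.linalg.eigvalsh(corr)[..., 0] > 0)                 # smallest eigenvalue only, no eigenvectors
```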
test/test_transforms.py (67 changes: 51 additions & 16 deletions)
@@ -5,11 +5,13 @@
 from functools import partial
 import math

+import numpy as np
 import pytest

 from jax import jacfwd, jit, random, tree_map, vmap
 import jax.numpy as jnp

+from numpyro.distributions import constraints
 from numpyro.distributions.flows import (
     BlockNeuralAutoregressiveTransform,
     InverseAutoregressiveTransform,
@@ -48,9 +50,6 @@ def _unpack(x):
     return (x,)


-_a = jnp.asarray
-
-
 def _smoke_neural_network():
     return None, None

@@ -60,31 +59,29 @@ class T(namedtuple("TestCase", ["transform_cls", "params", "kwargs"])):


 TRANSFORMS = {
-    "affine": T(
-        AffineTransform, (jnp.array([1.0, 2.0]), jnp.array([3.0, 4.0])), dict()
-    ),
+    "affine": T(AffineTransform, (np.array([1.0, 2.0]), np.array([3.0, 4.0])), dict()),
     "compose": T(
         ComposeTransform,
         (
             [
-                AffineTransform(jnp.array([1.0, 2.0]), jnp.array([3.0, 4.0])),
+                AffineTransform(np.array([1.0, 2.0]), np.array([3.0, 4.0])),
                 ExpTransform(),
             ],
         ),
         dict(),
     ),
     "independent": T(
         IndependentTransform,
-        (AffineTransform(jnp.array([1.0, 2.0]), jnp.array([3.0, 4.0])),),
+        (AffineTransform(np.array([1.0, 2.0]), np.array([3.0, 4.0])),),
         dict(reinterpreted_batch_ndims=1),
     ),
     "lower_cholesky_affine": T(
-        LowerCholeskyAffine, (jnp.array([1.0, 2.0]), jnp.eye(2)), dict()
+        LowerCholeskyAffine, (np.array([1.0, 2.0]), np.eye(2)), dict()
     ),
-    "permute": T(PermuteTransform, (jnp.array([1, 0]),), dict()),
+    "permute": T(PermuteTransform, (np.array([1, 0]),), dict()),
     "power": T(
         PowerTransform,
-        (_a(2.0),),
+        (np.array(2.0),),
         dict(),
     ),
     "rfft": T(
@@ -94,12 +91,12 @@ class T(namedtuple("TestCase", ["transform_cls", "params", "kwargs"])):
     ),
     "recursive_linear": T(
         RecursiveLinearTransform,
-        (jnp.eye(5),),
+        (np.eye(5),),
         dict(),
     ),
     "simplex_to_ordered": T(
         SimplexToOrderedTransform,
-        (_a(1.0),),
+        (np.array(1.0),),
         dict(),
     ),
     "unpack": T(UnpackTransform, (), dict(unpack_fn=_unpack)),
@@ -123,7 +120,7 @@ class T(namedtuple("TestCase", ["transform_cls", "params", "kwargs"])):
         # autoregressive_nn is a non-jittable arg, which does not fit well with
         # the current test pipeline, which assumes jittable args, and non-jittable kwargs
         partial(InverseAutoregressiveTransform, _smoke_neural_network),
-        (_a(-1.0), _a(1.0)),
+        (np.array(-1.0), np.array(1.0)),
         dict(),
     ),
     "bna": T(
@@ -277,10 +274,10 @@ def test_real_fast_fourier_transform(input_shape, shape, ndims):
         (IdentityTransform(), ()),
         (IndependentTransform(ExpTransform(), 2), (3, 4)),
         (L1BallTransform(), (9,)),
-        (LowerCholeskyAffine(jnp.ones(3), jnp.eye(3)), (3,)),
+        (LowerCholeskyAffine(np.ones(3), np.eye(3)), (3,)),
         (LowerCholeskyTransform(), (10,)),
         (OrderedTransform(), (5,)),
-        (PermuteTransform(jnp.roll(jnp.arange(7), 2)), (7,)),
+        (PermuteTransform(np.roll(np.arange(7), 2)), (7,)),
         (PowerTransform(2.5), ()),
         (RealFastFourierTransform(7), (7,)),
         (RealFastFourierTransform((8, 9), 2), (8, 9)),
@@ -351,3 +348,41 @@ def test_batched_recursive_linear_transform():
     y = transform(x)
     assert y.shape == x.shape
     assert jnp.allclose(x, transform.inv(y), atol=1e-6)
+
+
+@pytest.mark.parametrize(
+    "constraint, shape",
+    [
+        (constraints.circular, (3,)),
+        (constraints.complex, (3,)),
+        (constraints.corr_cholesky, (10, 10)),
+        (constraints.corr_matrix, (21,)),
[Review comment, Member]: @tillahoffmann This test is failing in my system with jax dev. Could we relax this to 15 (which passes)?

[Reply, Collaborator and PR author]: Sounds good. I can send a mini PR or integrate it into #1538?

[Reply, Member]: Either of them sounds good to me. Thanks, Till!
+        (constraints.greater_than(3), ()),
+        (constraints.interval(8, 13), (17,)),
+        (constraints.l1_ball, (4,)),
+        (constraints.less_than(-1), ()),
+        (constraints.lower_cholesky, (21,)),
+        (constraints.open_interval(3, 4), ()),
+        (constraints.ordered_vector, (5,)),
+        (constraints.positive_definite, (6,)),
+        (constraints.positive_ordered_vector, (7,)),
+        (constraints.positive, (7,)),
+        (constraints.real_matrix, (17,)),
+        (constraints.real_vector, (18,)),
+        (constraints.real, (3,)),
+        (constraints.scaled_unit_lower_cholesky, (21,)),
+        (constraints.simplex, (3,)),
+        (constraints.softplus_lower_cholesky, (21,)),
+        (constraints.softplus_positive, (2,)),
+        (constraints.unit_interval, (4,)),
+    ],
+    ids=str,
+)
+def test_biject_to(constraint, shape):
+    batch_shape = (13, 19)
+    unconstrained = random.normal(random.key(93), batch_shape + shape)
+    constrained = biject_to(constraint)(unconstrained)
+    passed = constraint.check(constrained)
+    expected_shape = constrained.shape[: constrained.ndim - constraint.event_dim]
+    assert passed.shape == expected_shape
+    assert jnp.all(passed)
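For context, a hedged usage sketch (assumptions: the public biject_to/constraints API as exercised by the test above; the parameter values are illustrative, not from the PR). `biject_to(constraint)` maps unconstrained real arrays into the constraint's support, and `constraint.check` collapses the rightmost `event_dim` dimensions, which is why the test compares `passed.shape` against the batch shape. The 21 in the corr_matrix case is the unconstrained vector length n(n-1)/2 for a 7x7 matrix; the 15 discussed in the review thread would correspond to 6x6.

```python
# Hedged sketch, not part of the PR; shapes and seed are illustrative.
import jax.numpy as jnp
from jax import random
from numpyro.distributions import constraints
from numpyro.distributions.transforms import biject_to

transform = biject_to(constraints.corr_matrix)    # R^21 -> 7x7 correlation matrices
unconstrained = random.normal(random.key(0), (13, 19, 21))
corr = transform(unconstrained)
print(corr.shape)                                 # (13, 19, 7, 7)

ok = constraints.corr_matrix.check(corr)          # event_dim == 2 is collapsed
print(ok.shape)                                   # (13, 19)
print(bool(jnp.all(ok)))                          # True
```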