|
3 | 3 |
|
4 | 4 | from collections import namedtuple |
5 | 5 | from functools import partial |
| 6 | +import math |
6 | 7 |
|
7 | 8 | import pytest |
8 | 9 |
|
9 | | -from jax import jit, random, tree_map, vmap |
| 10 | +from jax import jacfwd, jit, random, tree_map, vmap |
10 | 11 | import jax.numpy as jnp |
11 | 12 |
|
12 | 13 | from numpyro.distributions.flows import ( |
|
30 | 31 | PermuteTransform, |
31 | 32 | PowerTransform, |
32 | 33 | RealFastFourierTransform, |
| 34 | + RecursiveLinearTransform, |
33 | 35 | ReshapeTransform, |
34 | 36 | ScaledUnitLowerCholeskyTransform, |
35 | 37 | SigmoidTransform, |
@@ -90,6 +92,11 @@ class T(namedtuple("TestCase", ["transform_cls", "params", "kwargs"])): |
90 | 92 | (), |
91 | 93 | dict(transform_shape=(3, 4, 5), transform_ndims=3), |
92 | 94 | ), |
| 95 | + "recursive_linear": T( |
| 96 | + RecursiveLinearTransform, |
| 97 | + (jnp.eye(5),), |
| 98 | + dict(), |
| 99 | + ), |
93 | 100 | "simplex_to_ordered": T( |
94 | 101 | SimplexToOrderedTransform, |
95 | 102 | (_a(1.0),), |
@@ -277,6 +284,10 @@ def test_real_fast_fourier_transform(input_shape, shape, ndims): |
277 | 284 | (PowerTransform(2.5), ()), |
278 | 285 | (RealFastFourierTransform(7), (7,)), |
279 | 286 | (RealFastFourierTransform((8, 9), 2), (8, 9)), |
| 287 | + ( |
| 288 | + RecursiveLinearTransform(random.normal(random.key(17), (4, 4))), |
| 289 | + (7, 4), |
| 290 | + ), |
280 | 291 | (ReshapeTransform((5, 2), (10,)), (10,)), |
281 | 292 | (ReshapeTransform((15,), (3, 5)), (3, 5)), |
282 | 293 | (ScaledUnitLowerCholeskyTransform(), (6,)), |
@@ -312,4 +323,31 @@ def test_bijective_transforms(transform, shape): |
312 | 323 | atol = 1e-2 |
313 | 324 | assert jnp.allclose(x1, x2, atol=atol) |
314 | 325 |
|
315 | | - assert transform.log_abs_det_jacobian(x1, y).shape == batch_shape |
| 326 | + log_abs_det_jacobian = transform.log_abs_det_jacobian(x1, y) |
| 327 | + assert log_abs_det_jacobian.shape == batch_shape |
| 328 | + |
| 329 | + # Also check the Jacobian numerically for transforms with the same input and output |
| 330 | + # size, unless they are explicitly excluded. E.g., the upper triangular of the |
| 331 | + # CholeskyTransform is zero, giving rise to a singular Jacobian. |
| 332 | + skip_jacobian_check = (CholeskyTransform,) |
| 333 | + size_x = int(x1.size / math.prod(batch_shape)) |
| 334 | + size_y = int(y.size / math.prod(batch_shape)) |
| 335 | + if size_x == size_y and not isinstance(transform, skip_jacobian_check): |
| 336 | + jac = ( |
| 337 | + vmap(jacfwd(transform))(x1) |
| 338 | + .reshape((-1,) + x1.shape[len(batch_shape) :]) |
| 339 | + .reshape(batch_shape + (size_y, size_x)) |
| 340 | + ) |
| 341 | + slogdet = jnp.linalg.slogdet(jac) |
| 342 | + assert jnp.allclose(log_abs_det_jacobian, slogdet.logabsdet, atol=atol) |
| 343 | + |
| 344 | + |
def test_batched_recursive_linear_transform():
    """Check that a batched RecursiveLinearTransform round-trips exactly.

    Uses a (4, 17) batch of 3x3 matrices and verifies that applying the
    transform and then its inverse recovers the original input.
    """
    leading = (4, 17)
    sample = random.normal(random.key(8), leading + (10, 3))
    # CorrCholeskyTransform produces matrices whose eigenvalues keep the
    # linear recursion from blowing up over the sequence dimension.
    matrices = CorrCholeskyTransform()(random.normal(random.key(7), leading + (3,)))
    recursive = RecursiveLinearTransform(matrices)
    transformed = recursive(sample)
    assert transformed.shape == sample.shape
    assert jnp.allclose(sample, recursive.inv(transformed), atol=1e-6)
0 commit comments