Skip to content

Commit b8f8f67

Browse files
tillahoffmannOlaRonning
authored andcommitted
Add RecursiveLinearTransform for linear state space models. (pyro-ppl#1766)
* Format reparam module to comply with style guide. * Add `RealFastFourierTransform` to documentation. * Ignore `venv` directory for `update_headers.py` script. * Ignore autogenerated documentation sources. * Add numerical Jacobian check for bijective transforms. * Add `RecursiveLinearTransform`. * Use matrix multiplication operator and fix Jacobian. * Use non-trivial transition matrix in test. * Specify that transition matrices must (batches of) square matrices. * Fix `scan` implementation for batched transition matrices and add test.
1 parent 91af0ad commit b8f8f67

File tree

6 files changed

+167
-4
lines changed

6 files changed

+167
-4
lines changed

.gitignore

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,3 +35,6 @@ numpyro/examples/.data
3535
# docs
3636
docs/build
3737
docs/.DS_Store
38+
docs/source/examples
39+
docs/source/tutorials
40+
docs/source/getting_started.rst

docs/source/distributions.rst

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -151,7 +151,7 @@ EulerMaruyama
151151
:undoc-members:
152152
:show-inheritance:
153153
:member-order: bysource
154-
154+
155155
Exponential
156156
^^^^^^^^^^^
157157
.. autoclass:: numpyro.distributions.continuous.Exponential
@@ -948,6 +948,24 @@ PowerTransform
948948
:show-inheritance:
949949
:member-order: bysource
950950

951+
RealFastFourierTransform
952+
^^^^^^^^^^^^^^^^^^^^^^^^
953+
954+
.. autoclass:: numpyro.distributions.transforms.RealFastFourierTransform
955+
:members:
956+
:undoc-members:
957+
:show-inheritance:
958+
:member-order: bysource
959+
960+
RecursiveLinearTransform
961+
^^^^^^^^^^^^^^^^^^^^^^^^
962+
963+
.. autoclass:: numpyro.distributions.transforms.RecursiveLinearTransform
964+
:members:
965+
:undoc-members:
966+
:show-inheritance:
967+
:member-order: bysource
968+
951969
ScaledUnitLowerCholeskyTransform
952970
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
953971
.. autoclass:: numpyro.distributions.transforms.ScaledUnitLowerCholeskyTransform

numpyro/distributions/transforms.py

Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1288,6 +1288,98 @@ def __eq__(self, other):
12881288
)
12891289

12901290

1291+
class RecursiveLinearTransform(Transform):
1292+
"""
1293+
Apply a linear transformation recursively such that
1294+
:math:`y_t = A y_{t - 1} + x_t` for :math:`t > 0`, where :math:`x_t` and :math:`y_t`
1295+
are vectors and :math:`A` is a square transition matrix. The series is initialized
1296+
by :math:`y_0 = 0`.
1297+
1298+
:param transition_matrix: Squared transition matrix :math:`A` for successive states
1299+
or a batch of transition matrices.
1300+
1301+
**Example:**
1302+
1303+
.. doctest::
1304+
1305+
>>> from jax import random
1306+
>>> from jax import numpy as jnp
1307+
>>> import numpyro
1308+
>>> from numpyro import distributions as dist
1309+
>>>
1310+
>>> def cauchy_random_walk():
1311+
... return numpyro.sample(
1312+
... "x",
1313+
... dist.TransformedDistribution(
1314+
... dist.Cauchy(0, 1).expand([10, 1]).to_event(1),
1315+
... dist.transforms.RecursiveLinearTransform(jnp.eye(1)),
1316+
... ),
1317+
... )
1318+
>>>
1319+
>>> numpyro.handlers.seed(cauchy_random_walk, 0)().shape
1320+
(10, 1)
1321+
>>>
1322+
>>> def rocket_trajectory():
1323+
... scale = numpyro.sample(
1324+
... "scale",
1325+
... dist.HalfCauchy(1).expand([2]).to_event(1),
1326+
... )
1327+
... transition_matrix = jnp.array([[1, 1], [0, 1]])
1328+
... return numpyro.sample(
1329+
... "x",
1330+
... dist.TransformedDistribution(
1331+
... dist.Normal(0, scale).expand([10, 2]).to_event(1),
1332+
... dist.transforms.RecursiveLinearTransform(transition_matrix),
1333+
... ),
1334+
... )
1335+
>>>
1336+
>>> numpyro.handlers.seed(rocket_trajectory, 0)().shape
1337+
(10, 2)
1338+
"""
1339+
1340+
domain = constraints.real_matrix
1341+
codomain = constraints.real_matrix
1342+
1343+
def __init__(self, transition_matrix: jnp.ndarray) -> None:
1344+
self.transition_matrix = transition_matrix
1345+
1346+
def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
1347+
# Move the time axis to the first position so we can scan over it.
1348+
x = jnp.moveaxis(x, -2, 0)
1349+
1350+
def f(y, x):
1351+
y = jnp.einsum("...ij,...j->...i", self.transition_matrix, y) + x
1352+
return y, y
1353+
1354+
_, y = lax.scan(f, jnp.zeros_like(x, shape=x.shape[1:]), x)
1355+
return jnp.moveaxis(y, 0, -2)
1356+
1357+
def _inverse(self, y: jnp.ndarray) -> jnp.ndarray:
1358+
# Move the time axis to the first position so we can scan over it in reverse.
1359+
y = jnp.moveaxis(y, -2, 0)
1360+
1361+
def f(y, prev):
1362+
x = y - jnp.einsum("...ij,...j->...i", self.transition_matrix, prev)
1363+
return prev, x
1364+
1365+
_, x = lax.scan(f, y[-1], jnp.roll(y, 1, axis=0).at[0].set(0), reverse=True)
1366+
return jnp.moveaxis(x, 0, -2)
1367+
1368+
def log_abs_det_jacobian(self, x: jnp.ndarray, y: jnp.ndarray, intermediates=None):
1369+
return jnp.zeros_like(x, shape=x.shape[:-2])
1370+
1371+
def tree_flatten(self):
1372+
return (self.transition_matrix,), (
1373+
("transition_matrix",),
1374+
{},
1375+
)
1376+
1377+
def __eq__(self, other):
1378+
if not isinstance(other, RecursiveLinearTransform):
1379+
return False
1380+
return jnp.array_equal(self.transition_matrix, other.transition_matrix)
1381+
1382+
12911383
##########################################################
12921384
# CONSTRAINT_REGISTRY
12931385
##########################################################

scripts/update_headers.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import sys
88

99
root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
10-
blacklist = ["/build/", "/dist/", "/pyro_api.egg"]
10+
blacklist = ["/build/", "/dist/", "/pyro_api.egg", "/venv/"]
1111
file_types = [("*.py", "# {}"), ("*.cpp", "// {}")]
1212

1313
parser = argparse.ArgumentParser()

test/test_distributions.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3222,3 +3222,15 @@ def test_lowrank_mvn_19885(capfd: pytest.CaptureFixture) -> None:
32223222
assert x.shape == (sample_size, batch_size, event_size)
32233223
log_prob = _assert_not_jax_issue_19885(capfd, distribution.log_prob, x)
32243224
assert log_prob.shape == (sample_size, batch_size)
3225+
3226+
3227+
def test_gaussian_random_walk_linear_recursive_equivalence():
3228+
dist1 = dist.GaussianRandomWalk(3.7, 15)
3229+
dist2 = dist.TransformedDistribution(
3230+
dist.Normal(0, 3.7).expand([15, 1]).to_event(2),
3231+
dist.transforms.RecursiveLinearTransform(jnp.eye(1)),
3232+
)
3233+
x1 = dist1.sample(random.PRNGKey(7))
3234+
x2 = dist2.sample(random.PRNGKey(7))
3235+
assert jnp.allclose(x1, x2.squeeze())
3236+
assert jnp.allclose(dist1.log_prob(x1), dist2.log_prob(x2))

test/test_transforms.py

Lines changed: 40 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,11 @@
33

44
from collections import namedtuple
55
from functools import partial
6+
import math
67

78
import pytest
89

9-
from jax import jit, random, tree_map, vmap
10+
from jax import jacfwd, jit, random, tree_map, vmap
1011
import jax.numpy as jnp
1112

1213
from numpyro.distributions.flows import (
@@ -30,6 +31,7 @@
3031
PermuteTransform,
3132
PowerTransform,
3233
RealFastFourierTransform,
34+
RecursiveLinearTransform,
3335
ReshapeTransform,
3436
ScaledUnitLowerCholeskyTransform,
3537
SigmoidTransform,
@@ -90,6 +92,11 @@ class T(namedtuple("TestCase", ["transform_cls", "params", "kwargs"])):
9092
(),
9193
dict(transform_shape=(3, 4, 5), transform_ndims=3),
9294
),
95+
"recursive_linear": T(
96+
RecursiveLinearTransform,
97+
(jnp.eye(5),),
98+
dict(),
99+
),
93100
"simplex_to_ordered": T(
94101
SimplexToOrderedTransform,
95102
(_a(1.0),),
@@ -277,6 +284,10 @@ def test_real_fast_fourier_transform(input_shape, shape, ndims):
277284
(PowerTransform(2.5), ()),
278285
(RealFastFourierTransform(7), (7,)),
279286
(RealFastFourierTransform((8, 9), 2), (8, 9)),
287+
(
288+
RecursiveLinearTransform(random.normal(random.key(17), (4, 4))),
289+
(7, 4),
290+
),
280291
(ReshapeTransform((5, 2), (10,)), (10,)),
281292
(ReshapeTransform((15,), (3, 5)), (3, 5)),
282293
(ScaledUnitLowerCholeskyTransform(), (6,)),
@@ -312,4 +323,31 @@ def test_bijective_transforms(transform, shape):
312323
atol = 1e-2
313324
assert jnp.allclose(x1, x2, atol=atol)
314325

315-
assert transform.log_abs_det_jacobian(x1, y).shape == batch_shape
326+
log_abs_det_jacobian = transform.log_abs_det_jacobian(x1, y)
327+
assert log_abs_det_jacobian.shape == batch_shape
328+
329+
# Also check the Jacobian numerically for transforms with the same input and output
330+
# size, unless they are explicitly excluded. E.g., the upper triangular of the
331+
# CholeskyTransform is zero, giving rise to a singular Jacobian.
332+
skip_jacobian_check = (CholeskyTransform,)
333+
size_x = int(x1.size / math.prod(batch_shape))
334+
size_y = int(y.size / math.prod(batch_shape))
335+
if size_x == size_y and not isinstance(transform, skip_jacobian_check):
336+
jac = (
337+
vmap(jacfwd(transform))(x1)
338+
.reshape((-1,) + x1.shape[len(batch_shape) :])
339+
.reshape(batch_shape + (size_y, size_x))
340+
)
341+
slogdet = jnp.linalg.slogdet(jac)
342+
assert jnp.allclose(log_abs_det_jacobian, slogdet.logabsdet, atol=atol)
343+
344+
345+
def test_batched_recursive_linear_transform():
346+
batch_shape = (4, 17)
347+
x = random.normal(random.key(8), batch_shape + (10, 3))
348+
# Get a batch of matrices with eigenvalues that don't blow up the sequence.
349+
A = CorrCholeskyTransform()(random.normal(random.key(7), batch_shape + (3,)))
350+
transform = RecursiveLinearTransform(A)
351+
y = transform(x)
352+
assert y.shape == x.shape
353+
assert jnp.allclose(x, transform.inv(y), atol=1e-6)

0 commit comments

Comments
 (0)