Skip to content

Commit 6e61a14

Browse files
kylejcarontillahoffmann
authored and committed
Add ZeroSumNormal distribution (pyro-ppl#1751)
* added zerosumnormal and tests * added edge case handling for support shape * removed commented out functions * added zerosumnormal to docs * fixed zerosumnormal support shape default * Added v1 of docstrings for zerosumnormal * updated zsn docstring * improved init shape handling for zerosumnormal * improved docstrings * added ZeroSumTransform * made n_zerosum_axes an attribute for the zerosumtransform * removed commented out lines * added zerosumtransform class * switched zsn from ParameterFreeTransform to Transform * changed ZeroSumNormal to transformed distribution * changed input to tuple for _transform_to_zero_sum * added forward and inverse shape to transform, fixed zero_sum constraint handling * fixed failing zsn tests * added docstring, removed whitespace, fixed missing import * fixed allclose to be assert allclose * linted and formatted * added sample code to docstring for zsn * updated docstring * removed list from ZeroSum constraint call * removed unneeded iteration, updated docstring * updated constraint code * added ZeroSumTransform to docs * fixed transform shapes * added doctest example for zsn * added constraint test * added zero_sum constraint to docs * added type hinting to transforms file * fixed docs formatting * moved skip zsn from test_gof earlier * reversed zerosumtransform * broadcasted mean and var of zsn * added stricter zero_sum constraint tol, improved mean and var functions * fixed _transform_to_zero_sum * removed shape promote from zsn, changed broadcast to zeros_like * chose better zsn test cases * Update zero_sum constraint feasible_like Co-authored-by: Till Hoffmann <[email protected]> * fixed docstring for doctests --------- Co-authored-by: Till Hoffmann <[email protected]>
1 parent 6140916 commit 6e61a14

File tree

8 files changed

+259
-1
lines changed

8 files changed

+259
-1
lines changed

docs/source/distributions.rst

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -380,6 +380,13 @@ Weibull
380380
:show-inheritance:
381381
:member-order: bysource
382382

383+
ZeroSumNormal
384+
^^^^^^^^^^^^^
385+
.. autoclass:: numpyro.distributions.continuous.ZeroSumNormal
386+
:members:
387+
:undoc-members:
388+
:show-inheritance:
389+
:member-order: bysource
383390

384391
Discrete Distributions
385392
----------------------
@@ -820,6 +827,9 @@ unit_interval
820827
^^^^^^^^^^^^^
821828
.. autodata:: numpyro.distributions.constraints.unit_interval
822829

830+
zero_sum
831+
^^^^^^^^
832+
.. autodata:: numpyro.distributions.constraints.zero_sum
823833

824834
Transforms
825835
----------
@@ -1014,6 +1024,15 @@ StickBreakingTransform
10141024
:show-inheritance:
10151025
:member-order: bysource
10161026

1027+
ZeroSumTransform
1028+
^^^^^^^^^^^^^^^^
1029+
1030+
.. autoclass:: numpyro.distributions.transforms.ZeroSumTransform
1031+
:members:
1032+
:undoc-members:
1033+
:show-inheritance:
1034+
:member-order: bysource
1035+
10171036

10181037
Flows
10191038
-----

numpyro/distributions/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@
4747
StudentT,
4848
Uniform,
4949
Weibull,
50+
ZeroSumNormal,
5051
)
5152
from numpyro.distributions.copula import GaussianCopula, GaussianCopulaBeta
5253
from numpyro.distributions.directional import (
@@ -196,4 +197,5 @@
196197
"ZeroInflatedDistribution",
197198
"ZeroInflatedPoisson",
198199
"ZeroInflatedNegativeBinomial2",
200+
"ZeroSumNormal",
199201
]

numpyro/distributions/constraints.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@
5555
"softplus_lower_cholesky",
5656
"softplus_positive",
5757
"unit_interval",
58+
"zero_sum",
5859
"Constraint",
5960
]
6061

@@ -697,6 +698,29 @@ def feasible_like(self, prototype):
697698
return jax.numpy.full_like(prototype, prototype.shape[-1] ** (-0.5))
698699

699700

701+
class _ZeroSum(Constraint):
    """Constraint satisfied by arrays whose trailing axes each sum to zero.

    :param int event_dim: Number of trailing axes of a value that must each
        sum to zero (within a floating-point tolerance).
    """

    def __init__(self, event_dim=1):
        self.event_dim = event_dim
        super().__init__()

    def __call__(self, x):
        """Check whether ``x`` sums to zero along each of its event axes.

        :param x: Array to check; NumPy inputs are handled with NumPy, anything
            else with ``jax.numpy``.
        :return: Boolean mask with the event axes reduced away, i.e. one entry
            per batch element (a scalar for unbatched input). With
            ``event_dim == 0`` nothing is checked and ``True`` is returned.
        """
        jnp = np if isinstance(x, (np.ndarray, np.generic)) else jax.numpy
        eps = jnp.finfo(x.dtype).eps
        # Summing out one event axis leaves ``event_dim - 1`` trailing event
        # axes; reduce over exactly those so batch elements stay separate.
        remaining_event_axes = tuple(range(1 - self.event_dim, 0))
        is_zero_sum = True
        for dim in range(-self.event_dim, 0):
            # Scale the tolerance with the length of the axis actually summed.
            tol = eps * x.shape[dim] * 10
            within_tol = jnp.abs(x.sum(dim)) <= tol
            if remaining_event_axes:
                within_tol = jnp.all(within_tol, axis=remaining_event_axes)
            is_zero_sum = is_zero_sum & within_tol
        return is_zero_sum

    def __eq__(self, other):
        return type(self) is type(other) and self.event_dim == other.event_dim

    def feasible_like(self, prototype):
        """Return an all-zeros array shaped like ``prototype`` (trivially zero-sum)."""
        return jax.numpy.zeros_like(prototype)

    def tree_flatten(self):
        # ``event_dim`` drives Python-level control flow (``range`` in
        # ``__call__``) and equality, so it must stay a static value rather
        # than a traced pytree leaf; keep it in the aux-data dict.
        # NOTE(review): assumes the base class's ``tree_unflatten`` restores
        # aux-data dict entries as attributes — confirm against ``Constraint``.
        return (), ((), {"event_dim": self.event_dim})
722+
723+
700724
# TODO: Make types consistent
701725
# See https://github.com/pytorch/pytorch/issues/50616
702726

@@ -731,3 +755,4 @@ def feasible_like(self, prototype):
731755
sphere = _Sphere()
732756
unit_interval = _UnitInterval()
733757
open_interval = _OpenInterval
758+
zero_sum = _ZeroSum

numpyro/distributions/continuous.py

Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@
5858
ExpTransform,
5959
PowerTransform,
6060
SigmoidTransform,
61+
ZeroSumTransform,
6162
)
6263
from numpyro.distributions.util import (
6364
add_diag,
@@ -2438,3 +2439,97 @@ def cdf(self, value):
24382439

24392440
def icdf(self, value):
24402441
return self._ald.icdf(value)
2442+
2443+
2444+
class ZeroSumNormal(TransformedDistribution):
    r"""
    Normal distribution whose event axes are each constrained to sum to zero. The implementation is adapted from
    PyMC [1] and follows the description in [2,3].

    Along a single zero-sum axis the distribution is

    .. math::
        \begin{align*}
        ZSN(\sigma) = N(0, \sigma^2 (I - \tfrac{1}{n}J)) \\
        \text{where} \ ~ J_{ij} = 1 \ ~ \text{and} \\
        n = \text{length of the zero-sum axis}
        \end{align*}

    :param array_like scale: Standard deviation of the underlying normal distribution before the zero-sum
        constraint is enforced.
    :param tuple event_shape: The event shape of the distribution; every axis of the event is constrained to sum
        to zero.

    **Example:**

    .. doctest::

        >>> from numpy.testing import assert_allclose
        >>> from jax import random
        >>> import jax.numpy as jnp
        >>> import numpyro
        >>> import numpyro.distributions as dist
        >>> from numpyro.infer import MCMC, NUTS

        >>> N = 1000
        >>> n_categories = 20
        >>> rng_key = random.PRNGKey(0)
        >>> key1, key2, key3 = random.split(rng_key, 3)
        >>> category_ind = random.choice(key1, jnp.arange(n_categories), shape=(N,))
        >>> beta = random.normal(key2, shape=(n_categories,))
        >>> beta -= beta.mean(-1)
        >>> y = 5 + beta[category_ind] + random.normal(key3, shape=(N,))

        >>> def model(category_ind, y): # category_ind is an indexed categorical variable with 20 categories
        ...     N = len(category_ind)
        ...     alpha = numpyro.sample("alpha", dist.Normal(0, 2.5))
        ...     beta = numpyro.sample("beta", dist.ZeroSumNormal(1, event_shape=(n_categories,)))
        ...     sigma =  numpyro.sample("sigma", dist.Exponential(1))
        ...     with numpyro.plate("observations", N):
        ...         mu = alpha + beta[category_ind]
        ...         obs = numpyro.sample("obs", dist.Normal(mu, sigma), obs=y)
        ...     return obs

        >>> nuts_kernel = NUTS(model=model, target_accept_prob=0.9)
        >>> mcmc = MCMC(
        ...     sampler=nuts_kernel,
        ...     num_samples=1_000, num_warmup=1_000, num_chains=4
        ... )
        >>> mcmc.run(random.PRNGKey(0), category_ind=category_ind, y=y)
        >>> posterior_samples = mcmc.get_samples()
        >>> # Confirm everything along last axis sums to zero
        >>> assert_allclose(posterior_samples['beta'].sum(-1), 0, atol=1e-3)

    **References**
    [1] https://github.com/pymc-devs/pymc/blob/6252d2e58dc211c913ee2e652a4058d271d48bbd/pymc/distributions/multivariate.py#L2637
    [2] https://www.pymc.io/projects/docs/en/stable/api/distributions/generated/pymc.ZeroSumNormal.html
    [3] https://learnbayesstats.com/episode/74-optimizing-nuts-developing-zerosumnormal-distribution-adrian-seyboldt/
    """

    arg_constraints = {"scale": constraints.positive}
    reparametrized_params = ["scale"]

    def __init__(self, scale, event_shape, *, validate_args=None):
        n_zero_sum_axes = len(event_shape)
        # The base normal has one entry fewer along every constrained axis;
        # ZeroSumTransform grows each of those axes back by one.
        unconstrained_shape = tuple(axis_size - 1 for axis_size in event_shape)
        self.scale = scale
        base_normal = Normal(0, scale).expand(unconstrained_shape)
        super().__init__(
            base_normal.to_event(n_zero_sum_axes),
            ZeroSumTransform(n_zero_sum_axes),
            validate_args=validate_args,
        )

    @constraints.dependent_property(is_discrete=False)
    def support(self):
        n_zero_sum_axes = len(self.event_shape)
        return constraints.zero_sum(n_zero_sum_axes)

    @property
    def mean(self):
        full_shape = self.batch_shape + self.event_shape
        return jnp.zeros(full_shape)

    @property
    def variance(self):
        # Each zero-sum axis of length n shrinks the marginal variance by
        # a factor (1 - 1/n); apply the factors axis by axis.
        marginal_var = jnp.square(self.scale)
        for axis_size in self.event_shape:
            marginal_var = marginal_var * (1 - 1 / axis_size)

        full_shape = self.batch_shape + self.event_shape
        return jnp.broadcast_to(marginal_var, full_shape)

numpyro/distributions/transforms.py

Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import weakref
77

88
import numpy as np
9+
from numpy.core.numeric import normalize_axis_tuple
910

1011
from jax import lax, vmap
1112
from jax.flatten_util import ravel_pytree
@@ -50,6 +51,7 @@
5051
"StickBreakingTransform",
5152
"Transform",
5253
"UnpackTransform",
54+
"ZeroSumTransform",
5355
]
5456

5557

@@ -1380,6 +1382,92 @@ def __eq__(self, other):
13801382
return jnp.array_equal(self.transition_matrix, other.transition_matrix)
13811383

13821384

1385+
class ZeroSumTransform(Transform):
    """A transform that constrains an array to sum to zero, adapted from PyMC [1] as described in [2,3]

    The forward map appends one entry to each of the last ``transform_ndims``
    axes so that every such axis sums to zero; the inverse removes that entry.

    :param transform_ndims: Number of trailing dimensions to transform.

    **References**
    [1] https://github.com/pymc-devs/pymc/blob/244fb97b01ad0f3dadf5c3837b65839e2a59a0e8/pymc/distributions/transforms.py#L266
    [2] https://www.pymc.io/projects/docs/en/stable/api/distributions/generated/pymc.ZeroSumNormal.html
    [3] https://learnbayesstats.com/episode/74-optimizing-nuts-developing-zerosumnormal-distribution-adrian-seyboldt/
    """

    def __init__(self, transform_ndims: int = 1) -> None:
        self.transform_ndims = transform_ndims

    @property
    def domain(self) -> constraints.Constraint:
        return constraints.independent(constraints.real, self.transform_ndims)

    @property
    def codomain(self) -> constraints.Constraint:
        return constraints.zero_sum(self.transform_ndims)

    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
        """Extend each of the trailing ``transform_ndims`` axes of ``x`` by one zero-sum entry."""
        for axis in range(-self.transform_ndims, 0):
            x = self.extend_axis(x, axis=axis)
        return x

    def _inverse(self, y: jnp.ndarray) -> jnp.ndarray:
        """Shrink each of the trailing ``transform_ndims`` axes of ``y`` by one entry."""
        for axis in range(-self.transform_ndims, 0):
            y = self.extend_axis_rev(y, axis=axis)
        return y

    def extend_axis_rev(self, array: jnp.ndarray, axis: int) -> jnp.ndarray:
        """Undo :meth:`extend_axis` along ``axis``.

        Drops the last entry along ``axis`` and adds back an offset computed
        from that entry, returning an array one element shorter on ``axis``.
        """
        normalized_axis = normalize_axis_tuple(axis, array.ndim)[0]

        n = array.shape[normalized_axis]
        # Indexing with a length-1 array keeps ``axis`` so the offset broadcasts.
        last = jnp.take(array, jnp.array([-1]), axis=normalized_axis)

        sum_vals = -last * jnp.sqrt(n)
        norm = sum_vals / (jnp.sqrt(n) + n)
        # Full slices for the leading axes, then drop the last entry on ``axis``.
        slice_before = (slice(None, None),) * normalized_axis
        return array[(*slice_before, slice(None, -1))] + norm

    def extend_axis(self, array: jnp.ndarray, axis: int) -> jnp.ndarray:
        """Append one entry along ``axis`` so that the result sums to zero there.

        With ``n`` the extended axis length and ``S`` the sum of ``array``
        along ``axis``, every original entry is shifted by ``-S / (sqrt(n) + n)``
        and the appended entry is ``-S / sqrt(n)``.
        """
        n = array.shape[axis] + 1

        sum_vals = array.sum(axis, keepdims=True)
        norm = sum_vals / (jnp.sqrt(n) + n)
        fill_val = norm - sum_vals / jnp.sqrt(n)

        out = jnp.concatenate([array, fill_val], axis=axis)
        return out - norm

    def _split_shape(self, shape: tuple) -> tuple:
        """Split ``shape`` into ``(batch_shape, event_shape)``.

        An explicit split index is used instead of ``shape[:-k]`` /
        ``shape[-k:]`` because ``k == 0`` would turn those into ``shape[:0]``
        (empty batch) and ``shape[0:]`` (everything treated as event axes).
        """
        split = max(len(shape) - self.transform_ndims, 0)
        return shape[:split], shape[split:]

    def log_abs_det_jacobian(
        self, x: jnp.ndarray, y: jnp.ndarray, intermediates: None = None
    ) -> jnp.ndarray:
        """Return zeros of the broadcast batch shape of ``x`` and ``y``."""
        x_batch_shape, _ = self._split_shape(x.shape)
        y_batch_shape, _ = self._split_shape(y.shape)
        shape = jnp.broadcast_shapes(x_batch_shape, y_batch_shape)
        return jnp.zeros_like(x, shape=shape)

    def forward_shape(self, shape: tuple) -> tuple:
        """Shape after the forward map: each transformed axis grows by one."""
        batch_shape, event_shape = self._split_shape(shape)
        return batch_shape + tuple(s + 1 for s in event_shape)

    def inverse_shape(self, shape: tuple) -> tuple:
        """Shape after the inverse map: each transformed axis shrinks by one."""
        batch_shape, event_shape = self._split_shape(shape)
        return batch_shape + tuple(s - 1 for s in event_shape)

    def tree_flatten(self):
        # ``transform_ndims`` is static configuration, so it travels as aux
        # data and the transform contributes no array leaves.
        aux_data = {
            "transform_ndims": self.transform_ndims,
        }
        return (), ((), aux_data)

    def __eq__(self, other):
        return (
            isinstance(other, ZeroSumTransform)
            and self.transform_ndims == other.transform_ndims
        )
1470+
13831471
##########################################################
13841472
# CONSTRAINT_REGISTRY
13851473
##########################################################
@@ -1530,3 +1618,8 @@ def _transform_to_softplus_lower_cholesky(constraint):
15301618
@biject_to.register(constraints.simplex)
15311619
def _transform_to_simplex(constraint):
15321620
return StickBreakingTransform()
1621+
1622+
1623+
@biject_to.register(constraints.zero_sum)
def _transform_to_zero_sum(constraint):
    """Return the ``ZeroSumTransform`` covering ``constraint.event_dim`` trailing axes."""
    n_zero_sum_axes = constraint.event_dim
    return ZeroSumTransform(n_zero_sum_axes)

test/test_constraints.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@ class T(namedtuple("TestCase", ["constraint_cls", "params", "kwargs"])):
6262
dict(),
6363
),
6464
"open_interval": T(constraints.open_interval, (_a(-1.0), _a(1.0)), dict()),
65+
"zero_sum": T(constraints.zero_sum, (), dict(event_dim=1)),
6566
}
6667

6768
# TODO: BijectorConstraint

0 commit comments

Comments
 (0)