|
6 | 6 | import weakref |
7 | 7 |
|
8 | 8 | import numpy as np |
| 9 | +from numpy.core.numeric import normalize_axis_tuple |
9 | 10 |
|
10 | 11 | from jax import lax, vmap |
11 | 12 | from jax.flatten_util import ravel_pytree |
|
50 | 51 | "StickBreakingTransform", |
51 | 52 | "Transform", |
52 | 53 | "UnpackTransform", |
| 54 | + "ZeroSumTransform", |
53 | 55 | ] |
54 | 56 |
|
55 | 57 |
|
@@ -1380,6 +1382,92 @@ def __eq__(self, other): |
1380 | 1382 | return jnp.array_equal(self.transition_matrix, other.transition_matrix) |
1381 | 1383 |
|
1382 | 1384 |
|
class ZeroSumTransform(Transform):
    """A transform that constrains an array to sum to zero, adapted from PyMC [1]
    as described in [2,3].

    The forward map appends one entry along each of the last ``transform_ndims``
    axes (so each of those axes grows by one) and shifts the values so that the
    result sums to zero along that axis. The inverse drops the appended entry
    and undoes the shift.

    :param transform_ndims: Number of trailing dimensions to transform.

    **References**

    [1] https://github.com/pymc-devs/pymc/blob/244fb97b01ad0f3dadf5c3837b65839e2a59a0e8/pymc/distributions/transforms.py#L266
    [2] https://www.pymc.io/projects/docs/en/stable/api/distributions/generated/pymc.ZeroSumNormal.html
    [3] https://learnbayesstats.com/episode/74-optimizing-nuts-developing-zerosumnormal-distribution-adrian-seyboldt/
    """

    def __init__(self, transform_ndims: int = 1) -> None:
        self.transform_ndims = transform_ndims

    @property
    def domain(self) -> constraints.Constraint:
        """Unconstrained reals over the last ``transform_ndims`` dimensions."""
        return constraints.independent(constraints.real, self.transform_ndims)

    @property
    def codomain(self) -> constraints.Constraint:
        """Zero-sum constraint over the last ``transform_ndims`` dimensions."""
        return constraints.zero_sum(self.transform_ndims)

    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
        """Extend each of the last ``transform_ndims`` axes of ``x`` by one entry.

        :param x: unconstrained input array.
        :return: array whose trailing axes are each one element longer.
        """
        for axis in range(-self.transform_ndims, 0):
            x = self.extend_axis(x, axis=axis)
        return x

    def _inverse(self, y: jnp.ndarray) -> jnp.ndarray:
        """Shrink each of the last ``transform_ndims`` axes of ``y`` by one entry.

        :param y: array produced by the forward transform.
        :return: array whose trailing axes are each one element shorter.
        """
        for axis in range(-self.transform_ndims, 0):
            y = self.extend_axis_rev(y, axis=axis)
        return y

    def extend_axis_rev(self, array: jnp.ndarray, axis: int) -> jnp.ndarray:
        """Undo :meth:`extend_axis` along ``axis``.

        :param array: array whose ``axis`` was previously extended.
        :param axis: axis to shrink; negative values count from the end.
        :return: ``array`` with the final entry along ``axis`` removed and the
            remaining entries shifted back.
        """
        n = array.shape[axis]
        # ``jnp.split`` understands negative axes, so no axis normalisation
        # helper and no hand-built slice tuple are needed. ``last`` keeps a
        # length-one ``axis`` and therefore broadcasts against ``head``.
        head, last = jnp.split(array, [n - 1], axis=axis)

        sqrt_n = jnp.sqrt(n)
        # The appended entry equals ``-sum / sqrt(n)`` (see ``extend_axis``),
        # so the original sum can be recovered from it.
        sum_vals = -last * sqrt_n
        norm = sum_vals / (sqrt_n + n)
        return head + norm

    def extend_axis(self, array: jnp.ndarray, axis: int) -> jnp.ndarray:
        """Append one entry along ``axis`` so that the result sums to zero there.

        :param array: input array.
        :param axis: axis to extend; negative values count from the end.
        :return: array that is one element longer along ``axis``.
        """
        # Length of the axis *after* the extra entry has been appended.
        n = array.shape[axis] + 1
        sqrt_n = jnp.sqrt(n)

        sum_vals = array.sum(axis, keepdims=True)
        norm = sum_vals / (sqrt_n + n)
        # After ``norm`` is subtracted below, the appended entry becomes
        # ``-sum_vals / sqrt(n)``.
        fill_val = norm - sum_vals / sqrt_n

        out = jnp.concatenate([array, fill_val], axis=axis)
        return out - norm

    def log_abs_det_jacobian(
        self, x: jnp.ndarray, y: jnp.ndarray, intermediates: None = None
    ) -> jnp.ndarray:
        """Return zeros shaped like the broadcast batch dimensions of ``x`` and ``y``.

        The last ``transform_ndims`` dimensions of both arguments are treated as
        event dimensions and dropped before broadcasting.
        """
        shape = jnp.broadcast_shapes(
            x.shape[: -self.transform_ndims], y.shape[: -self.transform_ndims]
        )
        return jnp.zeros_like(x, shape=shape)

    def forward_shape(self, shape: tuple) -> tuple:
        """Shape after the forward map: each transformed axis grows by one."""
        return shape[: -self.transform_ndims] + tuple(
            s + 1 for s in shape[-self.transform_ndims :]
        )

    def inverse_shape(self, shape: tuple) -> tuple:
        """Shape after the inverse map: each transformed axis shrinks by one."""
        return shape[: -self.transform_ndims] + tuple(
            s - 1 for s in shape[-self.transform_ndims :]
        )

    def tree_flatten(self):
        """Flatten for pytree handling: no array leaves, only static aux data."""
        aux_data = {
            "transform_ndims": self.transform_ndims,
        }
        return (), ((), aux_data)

    def __eq__(self, other):
        """Equal when ``other`` is a ``ZeroSumTransform`` with the same ``transform_ndims``."""
        return (
            isinstance(other, ZeroSumTransform)
            and self.transform_ndims == other.transform_ndims
        )
| 1470 | + |
1383 | 1471 | ########################################################## |
1384 | 1472 | # CONSTRAINT_REGISTRY |
1385 | 1473 | ########################################################## |
@@ -1530,3 +1618,8 @@ def _transform_to_softplus_lower_cholesky(constraint): |
@biject_to.register(constraints.simplex)
def _transform_to_simplex(constraint):
    """Return the bijection registered for the ``simplex`` constraint."""
    # The constraint argument is not consulted; a fresh transform is built
    # on every call.
    transform = StickBreakingTransform()
    return transform
| 1621 | + |
| 1622 | + |
@biject_to.register(constraints.zero_sum)
def _transform_to_zero_sum(constraint):
    """Return the bijection registered for the ``zero_sum`` constraint.

    The constraint's ``event_dim`` becomes the number of trailing dimensions
    the transform operates on.
    """
    event_dim = constraint.event_dim
    return ZeroSumTransform(event_dim)
0 commit comments