-
Notifications
You must be signed in to change notification settings - Fork 271
Add ZeroSumNormal distribution #1751
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from 42 commits
Commits
Show all changes
43 commits
Select commit
Hold shift + click to select a range
fa63f9a
added zerosumnormal and tests
kylejcaron c28fd0c
added edge case handling for support shape
kylejcaron 93bcf0f
removed commented out functions
kylejcaron d9f2b4e
added zerosumnormal to docs
kylejcaron 0abb60b
fixed zerosumnormal support shape default
kylejcaron b28f38c
Added v1 of docstrings for zerosumnormal
kylejcaron 4e1dd16
updated zsn docstring
kylejcaron 8cd792c
improved init shape handling for zerosumnormal
kylejcaron dcbdd85
improved docstrings
kylejcaron 13fff40
added ZeroSumTransform
kylejcaron 514000c
made n_zerosum_axes an attribute for the zerosumtransform
kylejcaron d6315c3
removed commented out lines
kylejcaron 907cd2e
added zerosumtransform class
kylejcaron fc3f053
switched zsn from ParameterFreeTransform to Transform
kylejcaron 8187421
changed ZeroSumNormal to transformed distibutrion
kylejcaron 0051342
changed input to tuple for _transform_to_zero_sum
kylejcaron 1820a74
added forward and inverse shape to transform, fixed zero_sum constrai…
kylejcaron ee227bf
fixed failing zsn tests
kylejcaron bb4880c
added docstring, removed whitespace, fixed missing import
kylejcaron 38b8f56
fixed allclose to be assert allclose
kylejcaron 54533ff
Merge branch 'master' into zsn-dist
kylejcaron c8af390
linted and formatted
kylejcaron 3034f4a
added sample code to docstring for zsn
kylejcaron ebdd309
updated docstring
kylejcaron 8cb7a5f
removed list from ZeroSum constraint call
kylejcaron ae1586f
removed unneeded iteration, updated docstring
kylejcaron ab58216
updated constraint code
kylejcaron ad4e7c2
added ZeroSumTransform to docs
kylejcaron 54547f2
fixed transform shapes
kylejcaron bdc6480
added doctest example for zsn
kylejcaron 0b5070b
added constraint test
kylejcaron b1129bf
added zero_sum constraint to docs
kylejcaron 5fcaf68
added type hinting to transforms file
kylejcaron 619f90b
fixed docs formatting
kylejcaron 2e79677
moved skip zsn from test_gof earlier
kylejcaron da382f5
reversed zerosumtransform
kylejcaron 5aa5aeb
broadcasted mean and var of zsn
kylejcaron f7992d1
added stricter zero_sum constraint tol, improved mean and var functions
kylejcaron 1e77815
fixed _transform_to_zero_sum
kylejcaron 98f32f9
removed shape promote from zsn, changed broadcast to zeros_like
kylejcaron c639e70
chose better zsn test cases
kylejcaron 8a7a905
Update zero_sum constraint feasible_like
kylejcaron d7f05ff
fixed docstring for doctests
kylejcaron File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -6,6 +6,7 @@ | |
| import weakref | ||
|
|
||
| import numpy as np | ||
| from numpy.core.numeric import normalize_axis_tuple | ||
|
|
||
| from jax import lax, vmap | ||
| from jax.flatten_util import ravel_pytree | ||
|
|
@@ -50,6 +51,7 @@ | |
| "StickBreakingTransform", | ||
| "Transform", | ||
| "UnpackTransform", | ||
| "ZeroSumTransform", | ||
| ] | ||
|
|
||
|
|
||
|
|
@@ -1380,6 +1382,92 @@ def __eq__(self, other): | |
| return jnp.array_equal(self.transition_matrix, other.transition_matrix) | ||
|
|
||
|
|
||
| class ZeroSumTransform(Transform): | ||
| """A transform that constrains an array to sum to zero, adapted from PyMC [1] as described in [2,3] | ||
|
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @AlexAndorra @aseyboldt @ricardoV94 same as I said above, this PR is nearing ready to go - let me know if there's more I can add to properly credit all of you and pymc |
||
|
|
||
| :param transform_ndims: Number of trailing dimensions to transform. | ||
|
|
||
| **References** | ||
| [1] https://github.com/pymc-devs/pymc/blob/244fb97b01ad0f3dadf5c3837b65839e2a59a0e8/pymc/distributions/transforms.py#L266 | ||
| [2] https://www.pymc.io/projects/docs/en/stable/api/distributions/generated/pymc.ZeroSumNormal.html | ||
| [3] https://learnbayesstats.com/episode/74-optimizing-nuts-developing-zerosumnormal-distribution-adrian-seyboldt/ | ||
| """ | ||
|
|
||
| def __init__(self, transform_ndims: int = 1) -> None: | ||
| self.transform_ndims = transform_ndims | ||
|
|
||
| @property | ||
| def domain(self) -> constraints.Constraint: | ||
| return constraints.independent(constraints.real, self.transform_ndims) | ||
|
|
||
| @property | ||
| def codomain(self) -> constraints.Constraint: | ||
| return constraints.zero_sum(self.transform_ndims) | ||
|
|
||
| def __call__(self, x: jnp.ndarray) -> jnp.ndarray: | ||
| zero_sum_axes = tuple(range(-self.transform_ndims, 0)) | ||
| for axis in zero_sum_axes: | ||
| x = self.extend_axis(x, axis=axis) | ||
| return x | ||
|
|
||
| def _inverse(self, y: jnp.ndarray) -> jnp.ndarray: | ||
| zero_sum_axes = tuple(range(-self.transform_ndims, 0)) | ||
| for axis in zero_sum_axes: | ||
| y = self.extend_axis_rev(y, axis=axis) | ||
| return y | ||
|
|
||
| def extend_axis_rev(self, array: jnp.ndarray, axis: int) -> jnp.ndarray: | ||
| normalized_axis = normalize_axis_tuple(axis, array.ndim)[0] | ||
|
|
||
| n = array.shape[normalized_axis] | ||
| last = jnp.take(array, jnp.array([-1]), axis=normalized_axis) | ||
|
|
||
| sum_vals = -last * jnp.sqrt(n) | ||
| norm = sum_vals / (jnp.sqrt(n) + n) | ||
| slice_before = (slice(None, None),) * normalized_axis | ||
| return array[(*slice_before, slice(None, -1))] + norm | ||
|
|
||
| def extend_axis(self, array: jnp.ndarray, axis: int) -> jnp.ndarray: | ||
| n = array.shape[axis] + 1 | ||
|
|
||
| sum_vals = array.sum(axis, keepdims=True) | ||
| norm = sum_vals / (jnp.sqrt(n) + n) | ||
| fill_val = norm - sum_vals / jnp.sqrt(n) | ||
|
|
||
| out = jnp.concatenate([array, fill_val], axis=axis) | ||
| return out - norm | ||
|
|
||
| def log_abs_det_jacobian( | ||
| self, x: jnp.ndarray, y: jnp.ndarray, intermediates: None = None | ||
| ) -> jnp.ndarray: | ||
| shape = jnp.broadcast_shapes( | ||
| x.shape[: -self.transform_ndims], y.shape[: -self.transform_ndims] | ||
| ) | ||
| return jnp.zeros_like(x, shape=shape) | ||
|
|
||
| def forward_shape(self, shape: tuple) -> tuple: | ||
| return shape[: -self.transform_ndims] + tuple( | ||
| s + 1 for s in shape[-self.transform_ndims :] | ||
| ) | ||
|
|
||
| def inverse_shape(self, shape: tuple) -> tuple: | ||
| return shape[: -self.transform_ndims] + tuple( | ||
| s - 1 for s in shape[-self.transform_ndims :] | ||
| ) | ||
|
|
||
| def tree_flatten(self): | ||
| aux_data = { | ||
| "transform_ndims": self.transform_ndims, | ||
| } | ||
| return (), ((), aux_data) | ||
|
|
||
| def __eq__(self, other): | ||
| return ( | ||
| isinstance(other, ZeroSumTransform) | ||
| and self.transform_ndims == other.transform_ndims | ||
| ) | ||
|
|
||
|
|
||
| ########################################################## | ||
| # CONSTRAINT_REGISTRY | ||
| ########################################################## | ||
|
|
@@ -1530,3 +1618,8 @@ def _transform_to_softplus_lower_cholesky(constraint): | |
| @biject_to.register(constraints.simplex) | ||
| def _transform_to_simplex(constraint): | ||
| return StickBreakingTransform() | ||
|
|
||
|
|
||
| @biject_to.register(constraints.zero_sum) | ||
| def _transform_to_zero_sum(constraint): | ||
| return ZeroSumTransform(constraint.event_dim) | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.