-
Notifications
You must be signed in to change notification settings - Fork 271
Add ZeroSumNormal distribution #1751
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from 9 commits
Commits
Show all changes
43 commits
Select commit
Hold shift + click to select a range
fa63f9a
added zerosumnormal and tests
kylejcaron c28fd0c
added edge case handling for support shape
kylejcaron 93bcf0f
removed commented out functions
kylejcaron d9f2b4e
added zerosumnormal to docs
kylejcaron 0abb60b
fixed zerosumnormal support shape default
kylejcaron b28f38c
Added v1 of docstrings for zerosumnormal
kylejcaron 4e1dd16
updated zsn docstring
kylejcaron 8cd792c
improved init shape handling for zerosumnormal
kylejcaron dcbdd85
improved docstrings
kylejcaron 13fff40
added ZeroSumTransform
kylejcaron 514000c
made n_zerosum_axes an attribute for the zerosumtransform
kylejcaron d6315c3
removed commented out lines
kylejcaron 907cd2e
added zerosumtransform class
kylejcaron fc3f053
switched zsn from ParameterFreeTransform to Transform
kylejcaron 8187421
changed ZeroSumNormal to transformed distibutrion
kylejcaron 0051342
changed input to tuple for _transform_to_zero_sum
kylejcaron 1820a74
added forward and inverse shape to transform, fixed zero_sum constrai…
kylejcaron ee227bf
fixed failing zsn tests
kylejcaron bb4880c
added docstring, removed whitespace, fixed missing import
kylejcaron 38b8f56
fixed allclose to be assert allclose
kylejcaron 54533ff
Merge branch 'master' into zsn-dist
kylejcaron c8af390
linted and formatted
kylejcaron 3034f4a
added sample code to docstring for zsn
kylejcaron ebdd309
updated docstring
kylejcaron 8cb7a5f
removed list from ZeroSum constraint call
kylejcaron ae1586f
removed unneeded iteration, updated docstring
kylejcaron ab58216
updated constraint code
kylejcaron ad4e7c2
added ZeroSumTransform to docs
kylejcaron 54547f2
fixed transform shapes
kylejcaron bdc6480
added doctest example for zsn
kylejcaron 0b5070b
added constraint test
kylejcaron b1129bf
added zero_sum constraint to docs
kylejcaron 5fcaf68
added type hinting to transforms file
kylejcaron 619f90b
fixed docs formatting
kylejcaron 2e79677
moved skip zsn from test_gof earlier
kylejcaron da382f5
reversed zerosumtransform
kylejcaron 5aa5aeb
broadcasted mean and var of zsn
kylejcaron f7992d1
added stricter zero_sum constraint tol, improved mean and var functions
kylejcaron 1e77815
fixed _transform_to_zero_sum
kylejcaron 98f32f9
removed shape promote from zsn, changed broadcast to zeros_like
kylejcaron c639e70
chose better zsn test cases
kylejcaron 8a7a905
Update zero_sum constraint feasible_like
kylejcaron d7f05ff
fixed docstring for doctests
kylejcaron File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -2444,3 +2444,125 @@ def cdf(self, value): | |
|
|
||
| def icdf(self, value): | ||
| return self._ald.icdf(value) | ||
|
|
||
|
|
||
| class ZeroSumNormal(Distribution): | ||
kylejcaron marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| r""" | ||
| Zero Sum Normal distribution adapted from PyMC [1] as described in [2]. This is a Normal distribution where one or | ||
| more axes are constrained to sum to zero (the last axis by default). | ||
|
|
||
| :param array_like scale: Standard deviation of the underlying normal distribution before the zerosum constraint is | ||
| enforced. | ||
| :param int n_zerosum_axes: The number of axes to enforce a zerosum constraint. | ||
| :param tuple support_shape: The event shape of the distribution. | ||
|
|
||
| .. math:: | ||
| \begin{align*} | ||
| ZSN(\sigma) = N(0, \sigma^2 (I - \tfrac{1}{n}J)) \\ | ||
| \text{where} \ ~ J_{ij} = 1 \ ~ \text{and} \\ | ||
| n = \text{number of zero-sum axes} | ||
| \end{align*} | ||
|
|
||
| **References** | ||
| [1] https://github.com/pymc-devs/pymc/blob/6252d2e58dc211c913ee2e652a4058d271d48bbd/pymc/distributions/multivariate.py#L2637 | ||
| [2] https://www.pymc.io/projects/docs/en/stable/api/distributions/generated/pymc.ZeroSumNormal.html | ||
| """ | ||
| arg_constraints = {"scale": constraints.positive} | ||
| support = constraints.real | ||
| reparametrized_params = ["scale"] | ||
| pytree_aux_fields = ("n_zerosum_axes","support_shape",) | ||
|
|
||
| def __init__(self, scale=1.0, n_zerosum_axes=None, support_shape=None, *, validate_args=None): | ||
| if not all(tuple(i == 1 for i in jnp.shape( scale ))): | ||
| raise ValueError("scale must have length one across the zero-sum axes") | ||
|
|
||
| self.n_zerosum_axes = self.check_zerosum_axes(n_zerosum_axes) | ||
| support_shape = self.check_support_shape(support_shape, self.n_zerosum_axes) | ||
| if jnp.ndim(scale) == 0: | ||
kylejcaron marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
| (scale,) = promote_shapes(scale, shape=(1,)) | ||
|
|
||
| batch_shape = jnp.shape(scale)[:-1] | ||
| self.scale = scale | ||
kylejcaron marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
|
||
| super(ZeroSumNormal, self).__init__( | ||
| batch_shape=batch_shape, | ||
| event_shape=support_shape, | ||
| validate_args=validate_args | ||
| ) | ||
|
|
||
| def sample(self, key, sample_shape=()): | ||
| assert is_prng_key(key) | ||
| zerosum_rv_ = random.normal( | ||
| key, shape=sample_shape + self.batch_shape + self.event_shape | ||
| ) * self.scale | ||
|
|
||
| if not zerosum_rv_.shape: | ||
| return jnp.zeros(zerosum_rv_.shape) | ||
|
|
||
| for axis in range(self.n_zerosum_axes): | ||
| zerosum_rv_ -= zerosum_rv_.mean(axis=-axis - 1, keepdims=True) | ||
| return zerosum_rv_ | ||
|
|
||
| @validate_sample | ||
| def log_prob(self, value): | ||
| shape = jnp.array(value.shape) | ||
| _deg_free_support_shape = shape.at[-self.n_zerosum_axes:].set( shape[-self.n_zerosum_axes:] - 1 ) | ||
| _full_size = jnp.prod(shape).astype(float) | ||
| _degrees_of_freedom = jnp.prod(_deg_free_support_shape).astype(float) | ||
|
|
||
| if not value.shape or self.batch_shape: | ||
| value = jnp.expand_dims(value, -1) | ||
|
|
||
| log_pdf = jnp.sum( | ||
| -0.5 * jnp.pow(value / self.scale, 2) | ||
| - (jnp.log(jnp.sqrt(2.0 * jnp.pi)) + jnp.log(self.scale)) * _degrees_of_freedom / _full_size, | ||
| axis=tuple(np.arange(-self.n_zerosum_axes, 0)), | ||
| ) | ||
| return log_pdf | ||
|
|
||
| @property | ||
| def mean(self): | ||
| return jnp.broadcast_to(0, self.batch_shape) | ||
kylejcaron marked this conversation as resolved.
Outdated
Show resolved
Hide resolved
|
||
|
|
||
| @property | ||
| def variance(self): | ||
| theoretical_var = self.scale.astype(float)**2 | ||
| for axis in range(1,self.n_zerosum_axes+1): | ||
| theoretical_var *= (1 - 1 / self.event_shape[-axis]) | ||
|
|
||
| return theoretical_var | ||
|
||
|
|
||
| def check_zerosum_axes(self, n_zerosum_axes): | ||
| if n_zerosum_axes is None: | ||
| n_zerosum_axes = 1 | ||
|
|
||
| is_integer = isinstance(n_zerosum_axes, int) | ||
| is_jax_int_array = isinstance(n_zerosum_axes, jnp.ndarray) and jnp.issubdtype(n_zerosum_axes.dtype, jnp.integer) | ||
| if not (is_integer or is_jax_int_array): | ||
| raise TypeError("n_zerosum_axes has to be an integer") | ||
| if not n_zerosum_axes > 0: | ||
| raise ValueError("n_zerosum_axes has to be > 0") | ||
| return n_zerosum_axes | ||
|
|
||
| def check_support_shape(self, support_shape, n_zerosum_axes): | ||
| if support_shape is None: | ||
| return () | ||
| assert n_zerosum_axes <= len(support_shape), "support_shape has to be as long as n_zerosum_axes" | ||
| assert all(shape > 0 for shape in support_shape), "support_shape must be a valid shape" | ||
| assert len(support_shape) > 0, "support_shape must be a valid shape" | ||
| return support_shape | ||
|
|
||
| @staticmethod | ||
| def infer_shapes(scale=1.0, n_zerosum_axes=None, support_shape=(1,)): | ||
| '''Numpyro assumes that the event and batch shape can be entirely | ||
| determined by the shapes of the distribution inputs. This distribution | ||
| doesn't follow those conventions, so the `infer_shapes` method cant be implemented. | ||
| ''' | ||
| raise NotImplementedError() | ||
|
|
||
| def _validate_sample(self, value): | ||
| mask = super(ZeroSumNormal, self)._validate_sample(value) | ||
| batch_dim = jnp.ndim(value) - len(self.event_shape) | ||
| if batch_dim < jnp.ndim(mask): | ||
| mask = jnp.all(jnp.reshape(mask, jnp.shape(mask)[:batch_dim] + (-1,)), -1) | ||
| return mask | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.