-
Notifications
You must be signed in to change notification settings - Fork 346
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Thorough use of distributions to clean module-level code #1356
Changes from all commits
9884cda
998c345
2a0c1d1
c9bf152
c1ebcd4
493ff5d
0e1d07f
3761bb6
abffd8f
ce1b4b4
1227323
7d9be42
41d9b0f
4868898
6566c0e
65df0aa
de172e7
ea225a5
fbae55f
04c4138
fedf805
5f48f0e
dd2cd8b
a480774
af84491
6e205ea
63ee149
43b74ee
c581d14
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -8,7 +8,9 @@ | |
import torch.nn.functional as F | ||
from numpyro.distributions import constraints as numpyro_constraints | ||
from numpyro.distributions.util import promote_shapes, validate_sample | ||
from torch.distributions import Distribution, Gamma, Poisson, constraints | ||
from torch.distributions import Distribution, Gamma | ||
from torch.distributions import Poisson as PoissonTorch | ||
from torch.distributions import constraints | ||
from torch.distributions.utils import ( | ||
broadcast_all, | ||
lazy_property, | ||
|
@@ -236,6 +238,33 @@ def _gamma(theta, mu): | |
return gamma_d | ||
|
||
|
||
class Poisson(PoissonTorch): | ||
""" | ||
Poisson distribution. | ||
|
||
Parameters | ||
---------- | ||
rate | ||
rate of the Poisson distribution. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. improve docs on optional args There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Done! There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. sorry wasnt clear, just meant to expand on the definitions which looks good now. You can remove ": optional" because sphinx takes care of that for us |
||
validate_args : optional | ||
whether to validate input. | ||
scale : optional | ||
Normalized mean expression of the distribution. | ||
This optional parameter is not used in any computations, but allows storing | ||
normalized expression levels. | ||
Comment on lines
+249
to
+254
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. @PierreBoyeau @jjhong922 can we fix this? There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. are you talking about the docs or the optional parameter here? I can remove the :optional in the docs now |
||
|
||
""" | ||
|
||
def __init__( | ||
self, | ||
rate: torch.Tensor, | ||
validate_args: Optional[bool] = None, | ||
scale: Optional[torch.Tensor] = None, | ||
): | ||
super().__init__(rate=rate, validate_args=validate_args) | ||
self.scale = scale | ||
|
||
|
||
class NegativeBinomial(Distribution): | ||
r""" | ||
Negative binomial distribution. | ||
|
@@ -262,7 +291,9 @@ class NegativeBinomial(Distribution): | |
Mean of the distribution. | ||
theta | ||
Inverse dispersion. | ||
validate_args | ||
scale : optional | ||
Normalized mean expression of the distribution. | ||
validate_args : optional | ||
Raise ValueError if arguments do not match constraints | ||
""" | ||
|
||
|
@@ -279,6 +310,7 @@ def __init__( | |
logits: Optional[torch.Tensor] = None, | ||
mu: Optional[torch.Tensor] = None, | ||
theta: Optional[torch.Tensor] = None, | ||
scale: Optional[torch.Tensor] = None, | ||
validate_args: bool = False, | ||
): | ||
self._eps = 1e-8 | ||
|
@@ -299,6 +331,7 @@ def __init__( | |
mu, theta = broadcast_all(mu, theta) | ||
self.mu = mu | ||
self.theta = theta | ||
self.scale = scale | ||
super().__init__(validate_args=validate_args) | ||
|
||
@property | ||
|
@@ -319,7 +352,7 @@ def sample( | |
# Clamping as distributions objects can have buggy behaviors when | ||
# their parameters are too high | ||
l_train = torch.clamp(p_means, max=1e8) | ||
counts = Poisson( | ||
counts = PoissonTorch( | ||
l_train | ||
).sample() # Shape : (n_samples, n_cells_batch, n_vars) | ||
return counts | ||
|
@@ -368,6 +401,8 @@ class ZeroInflatedNegativeBinomial(NegativeBinomial): | |
Inverse dispersion. | ||
zi_logits | ||
Logits scale of zero inflation probability. | ||
scale : optional | ||
Normalized mean expression of the distribution. | ||
validate_args | ||
Raise ValueError if arguments do not match constraints | ||
""" | ||
|
@@ -388,6 +423,7 @@ def __init__( | |
mu: Optional[torch.Tensor] = None, | ||
theta: Optional[torch.Tensor] = None, | ||
zi_logits: Optional[torch.Tensor] = None, | ||
scale: Optional[torch.Tensor] = None, | ||
validate_args: bool = False, | ||
): | ||
|
||
|
@@ -397,6 +433,7 @@ def __init__( | |
logits=logits, | ||
mu=mu, | ||
theta=theta, | ||
scale=scale, | ||
validate_args=validate_args, | ||
) | ||
self.zi_logits, self.mu, self.theta = broadcast_all( | ||
|
@@ -522,7 +559,7 @@ def sample( | |
# Clamping as distributions objects can have buggy behaviors when | ||
# their parameters are too high | ||
l_train = torch.clamp(p_means, max=1e8) | ||
counts = Poisson( | ||
counts = PoissonTorch( | ||
l_train | ||
).sample() # Shape : (n_samples, n_cells_batch, n_features) | ||
return counts | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Use the API doc notation, e.g. {class}
~scvi.module.VAE
instead of _vae.py