Thorough use of distributions to clean module-level code #1356

Merged: 29 commits, May 31, 2022

Commits (changes from all commits):
9884cda
replace normal parameters in inference by distribution
PierreBoyeau Aug 30, 2021
998c345
totalVI changes
PierreBoyeau Sep 16, 2021
2a0c1d1
Merge branch 'master' of https://github.com/YosefLab/scvi-tools into …
PierreBoyeau Sep 16, 2021
c9bf152
merge fixes
PierreBoyeau Sep 16, 2021
c1ebcd4
fix totalvi
PierreBoyeau Sep 17, 2021
493ff5d
fixes
PierreBoyeau Sep 17, 2021
0e1d07f
fix gimvi
PierreBoyeau Sep 17, 2021
3761bb6
merge from main
PierreBoyeau Feb 15, 2022
abffd8f
pyro
PierreBoyeau Feb 15, 2022
ce1b4b4
clean
PierreBoyeau Feb 15, 2022
1227323
clean
PierreBoyeau Feb 15, 2022
7d9be42
dummy
PierreBoyeau Feb 16, 2022
41d9b0f
log prob sum
PierreBoyeau Feb 17, 2022
4868898
fix conversion
PierreBoyeau Feb 17, 2022
6566c0e
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 17, 2022
65df0aa
simplication log_prob_sum
PierreBoyeau Feb 17, 2022
de172e7
rename px
PierreBoyeau Feb 18, 2022
ea225a5
px
PierreBoyeau Feb 18, 2022
fbae55f
Optional return_dist
PierreBoyeau Feb 18, 2022
04c4138
Merge branch 'master' of https://github.com/YosefLab/scvi-tools into …
PierreBoyeau Apr 4, 2022
fedf805
feat generative refactor
PierreBoyeau Apr 4, 2022
5f48f0e
updates
PierreBoyeau Apr 4, 2022
dd2cd8b
precommit
PierreBoyeau Apr 4, 2022
a480774
Merge branch 'master' of https://github.com/scverse/scvi-tools into d…
PierreBoyeau Apr 25, 2022
af84491
docstring
PierreBoyeau May 3, 2022
6e205ea
docstring
PierreBoyeau May 3, 2022
63ee149
docstring
PierreBoyeau May 13, 2022
43b74ee
Merge branch 'master' of https://github.com/scverse/scvi-tools into d…
PierreBoyeau May 13, 2022
c581d14
release note
PierreBoyeau May 13, 2022
4 changes: 4 additions & 0 deletions docs/release_notes/v0.17.0.md
@@ -3,6 +3,7 @@
## Changes

- Experimental MuData support for {class}`~scvi.model.TOTALVI` via the method {meth}`~scvi.model.TOTALVI.setup_mudata`. For several of the existing `AnnDataField` classes, there is now a MuData counterpart with an additional `mod_key` argument used to indicate the modality where the data lives (e.g. {class}`~scvi.data.fields.LayerField` to {class}`~scvi.data.fields.MuDataLayerField`). These modified classes are simply wrapped versions of the original `AnnDataField` code via the new {meth}`scvi.data.fields.MuDataWrapper` method [#1474].
- Modification of the `generative` method's outputs to return prior and likelihood properties as `torch.Distribution` objects. Concerned modules are `_amortizedlda.py`, `_autozivae.py`, `multivae.py`, `_peakvae.py`, `_scanvae.py`, `_vae.py`, and `_vaec.py`. This simplifies manipulating these distributions during model training and inference [#1356].
> **Review (Contributor):** Use the api doc notation, e.g. {class}`~scvi.module.VAE`, instead of `_vae.py`.
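
To make the release-note item above concrete, here is a minimal sketch (illustrative, not part of this diff) of what the change means for callers: where inference and generative methods previously returned raw tensors such as `qz_m`/`qz_v`, they now return `torch.Distribution` objects, and the old tensors remain reachable as attributes.

```python
import torch
from torch.distributions import Normal

# Stand-in for what an inference method now returns, e.g.
# inference_outputs["qz"]; previously callers received qz_m and qz_v tensors.
qz = Normal(loc=torch.zeros(3, 10), scale=torch.ones(3, 10))

z = qz.rsample()                  # reparameterized sample, differentiable
log_q = qz.log_prob(z).sum(-1)    # no need to rebuild Normal(qz_m, qz_v.sqrt())
qz_m, qz_v = qz.loc, qz.scale**2  # old tensor outputs recoverable as attributes
```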


## Breaking changes

@@ -12,8 +13,11 @@

- [@jjhong922]
- [@adamgayoso]
- [@PierreBoyeau]

[#1474]: https://github.com/YosefLab/scvi-tools/pull/1474
[#1356]: https://github.com/YosefLab/scvi-tools/pull/1356

[@jjhong922]: https://github.com/jjhong922
[@adamgayoso]: https://github.com/adamgayoso
[@PierreBoyeau]: https://github.com/PierreBoyeau
2 changes: 2 additions & 0 deletions scvi/distributions/__init__.py
@@ -2,6 +2,7 @@
JaxNegativeBinomialMeanDisp,
NegativeBinomial,
NegativeBinomialMixture,
Poisson,
ZeroInflatedNegativeBinomial,
)

@@ -10,4 +11,5 @@
"NegativeBinomialMixture",
"ZeroInflatedNegativeBinomial",
"JaxNegativeBinomialMeanDisp",
"Poisson",
]
45 changes: 41 additions & 4 deletions scvi/distributions/_negative_binomial.py
@@ -8,7 +8,9 @@
import torch.nn.functional as F
from numpyro.distributions import constraints as numpyro_constraints
from numpyro.distributions.util import promote_shapes, validate_sample
from torch.distributions import Distribution, Gamma, Poisson, constraints
from torch.distributions import Distribution, Gamma
from torch.distributions import Poisson as PoissonTorch
from torch.distributions import constraints
from torch.distributions.utils import (
broadcast_all,
lazy_property,
@@ -236,6 +238,33 @@ def _gamma(theta, mu):
return gamma_d


class Poisson(PoissonTorch):
"""
Poisson distribution.

Parameters
----------
rate
rate of the Poisson distribution.
> **Review (Contributor):** improve docs on optional args
> **Author (PierreBoyeau):** Done!
> **Review (Contributor):** Sorry, wasn't clear; just meant to expand on the definitions, which look good now. You can remove ": optional" because Sphinx takes care of that for us.
validate_args : optional
whether to validate input.
scale : optional
Normalized mean expression of the distribution.
This optional parameter is not used in any computations, but allows storing
normalized expression levels.
> Comment on lines +249 to +254
> **Review (Member):** @PierreBoyeau @jjhong922 can we fix this?
> **Author (Contributor):** Are you talking about the docs or the optional parameter here? I can remove the ":optional" in the docs now.

"""

def __init__(
self,
rate: torch.Tensor,
validate_args: Optional[bool] = None,
scale: Optional[torch.Tensor] = None,
):
super().__init__(rate=rate, validate_args=validate_args)
self.scale = scale
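
A short usage sketch of the wrapper (illustrative values; `log_prob` and `sample` are inherited unchanged from `torch.distributions.Poisson`, and `scale` is only stored):

```python
import torch

from scvi.distributions import Poisson  # the new export added in this PR

rate = torch.rand(2, 5) * 10.0
scale = rate / rate.sum(dim=-1, keepdim=True)  # illustrative normalized means

px = Poisson(rate=rate, scale=scale)
counts = px.sample()           # inherited sampling behavior
log_lik = px.log_prob(counts)  # inherited log-density
assert px.scale is scale       # scale is stored, not used in computations
```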


class NegativeBinomial(Distribution):
r"""
Negative binomial distribution.
@@ -262,7 +291,9 @@ class NegativeBinomial(Distribution):
Mean of the distribution.
theta
Inverse dispersion.
validate_args
scale : optional
Normalized mean expression of the distribution.
validate_args : optional
Raise ValueError if arguments do not match constraints
"""

@@ -279,6 +310,7 @@ def __init__(
logits: Optional[torch.Tensor] = None,
mu: Optional[torch.Tensor] = None,
theta: Optional[torch.Tensor] = None,
scale: Optional[torch.Tensor] = None,
validate_args: bool = False,
):
self._eps = 1e-8
@@ -299,6 +331,7 @@ def __init__(
mu, theta = broadcast_all(mu, theta)
self.mu = mu
self.theta = theta
self.scale = scale
super().__init__(validate_args=validate_args)

@property
@@ -319,7 +352,7 @@ def sample(
# Clamping as distributions objects can have buggy behaviors when
# their parameters are too high
l_train = torch.clamp(p_means, max=1e8)
counts = Poisson(
counts = PoissonTorch(
l_train
).sample() # Shape : (n_samples, n_cells_batch, n_vars)
return counts
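
For context on the `PoissonTorch` alias used above: `sample` draws from the Gamma-Poisson representation of the negative binomial, and the rename keeps the raw torch distribution distinct from the new `Poisson` wrapper. A self-contained sketch of that sampling path (parameters invented for illustration):

```python
import torch
from torch.distributions import Gamma
from torch.distributions import Poisson as PoissonTorch

# NB(mu, theta) as a Gamma-Poisson mixture: lambda ~ Gamma(theta, theta / mu)
# has mean mu, and counts ~ Poisson(lambda) are negative-binomially distributed.
mu = torch.full((4,), 5.0)
theta = torch.full((4,), 2.0)
p_means = Gamma(concentration=theta, rate=theta / mu).sample()

# Clamp as in the diff: Poisson can misbehave for extremely large rates.
l_train = torch.clamp(p_means, max=1e8)
counts = PoissonTorch(l_train).sample()
```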
@@ -368,6 +401,8 @@ class ZeroInflatedNegativeBinomial(NegativeBinomial):
Inverse dispersion.
zi_logits
Logits scale of zero inflation probability.
scale : optional
Normalized mean expression of the distribution.
validate_args
Raise ValueError if arguments do not match constraints
"""
@@ -388,6 +423,7 @@ def __init__(
mu: Optional[torch.Tensor] = None,
theta: Optional[torch.Tensor] = None,
zi_logits: Optional[torch.Tensor] = None,
scale: Optional[torch.Tensor] = None,
validate_args: bool = False,
):

@@ -397,6 +433,7 @@ def __init__(
logits=logits,
mu=mu,
theta=theta,
scale=scale,
validate_args=validate_args,
)
self.zi_logits, self.mu, self.theta = broadcast_all(
@@ -522,7 +559,7 @@ def sample(
# Clamping as distributions objects can have buggy behaviors when
# their parameters are too high
l_train = torch.clamp(p_means, max=1e8)
counts = Poisson(
counts = PoissonTorch(
l_train
).sample() # Shape : (n_samples, n_cells_batch, n_features)
return counts
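
To see the new `scale` pass-through end to end, a hypothetical construction of the zero-inflated distribution (all values invented for the example):

```python
import torch

from scvi.distributions import ZeroInflatedNegativeBinomial

mu = torch.rand(2, 5) * 10.0
theta = torch.ones(2, 5)
zi_logits = torch.zeros(2, 5)              # logit 0, i.e. 50% dropout probability
scale = mu / mu.sum(dim=-1, keepdim=True)  # illustrative normalized means

px = ZeroInflatedNegativeBinomial(
    mu=mu, theta=theta, zi_logits=zi_logits, scale=scale
)
x = px.sample()
log_lik = px.log_prob(x).sum(dim=-1)  # the reconstruction term used in losses
```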
34 changes: 16 additions & 18 deletions scvi/external/gimvi/_module.py
@@ -134,6 +134,7 @@ def __init__(
n_layers_individual=n_layers_encoder_individual,
n_layers_shared=n_layers_encoder_shared,
dropout_rate=dropout_rate_encoder,
return_dist=True,
)

self.l_encoders = ModuleList(
@@ -143,6 +144,7 @@
1,
n_layers=1,
dropout_rate=dropout_rate_encoder,
return_dist=True,
)
if self.model_library_bools[i]
else None
@@ -197,7 +199,7 @@ def sample_from_posterior_z(
else:
raise Exception("Must provide a mode when having multiple datasets")
outputs = self.inference(x, mode)
qz_m = outputs["qz_m"]
qz_m = outputs["qz"].loc
z = outputs["z"]
if deterministic:
z = qz_m
@@ -268,9 +270,9 @@ def sample_scale(
decode_mode = mode
inference_out = self.inference(x, mode)
if deterministic:
z = inference_out["qz_m"]
if inference_out["ql_m"] is not None:
library = inference_out["ql_m"]
z = inference_out["qz"].loc
if inference_out["ql"] is not None:
library = inference_out["ql"].loc
else:
library = inference_out["library"]
else:
@@ -372,14 +374,14 @@ def inference(self, x: torch.Tensor, mode: Optional[int] = None) -> dict:
if self.log_variational:
x_ = torch.log(1 + x_)

qz_m, qz_v, z = self.z_encoder(x_, mode)
ql_m, ql_v, library = None, None, None
qz, z = self.z_encoder(x_, mode)
ql, library = None, None
if self.model_library_bools[mode]:
ql_m, ql_v, library = self.l_encoders[mode](x_)
ql, library = self.l_encoders[mode](x_)
else:
library = torch.log(torch.sum(x, dim=1)).view(-1, 1)

return dict(qz_m=qz_m, qz_v=qz_v, z=z, ql_m=ql_m, ql_v=ql_v, library=library)
return dict(qz=qz, z=z, ql=ql, library=library)

@auto_move_data
def generative(
@@ -447,10 +449,8 @@ def loss(
x = tensors[REGISTRY_KEYS.X_KEY]
batch_index = tensors[REGISTRY_KEYS.BATCH_KEY]

qz_m = inference_outputs["qz_m"]
qz_v = inference_outputs["qz_v"]
ql_m = inference_outputs["ql_m"]
ql_v = inference_outputs["ql_v"]
qz = inference_outputs["qz"]
ql = inference_outputs["ql"]
px_rate = generative_outputs["px_rate"]
px_r = generative_outputs["px_r"]
px_dropout = generative_outputs["px_dropout"]
@@ -466,11 +466,9 @@
)

# KL Divergence
mean = torch.zeros_like(qz_m)
scale = torch.ones_like(qz_v)
kl_divergence_z = kl(Normal(qz_m, torch.sqrt(qz_v)), Normal(mean, scale)).sum(
dim=1
)
mean = torch.zeros_like(qz.loc)
scale = torch.ones_like(qz.scale)
kl_divergence_z = kl(qz, Normal(mean, scale)).sum(dim=1)

if self.model_library_bools[mode]:
library_log_means = getattr(self, f"library_log_means_{mode}")
@@ -483,7 +481,7 @@
one_hot(batch_index, self.n_batch), library_log_vars
)
kl_divergence_l = kl(
Normal(ql_m, torch.sqrt(ql_v)),
ql,
Normal(local_library_log_means, local_library_log_vars.sqrt()),
).sum(dim=1)
else:
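The loss rewrite above is behavior-preserving: calling `kl` on the returned `qz` equals the old hand-built `kl(Normal(qz_m, qz_v.sqrt()), ...)`. A minimal check with stand-in tensors:

```python
import torch
from torch.distributions import Normal, kl_divergence as kl

qz_m = torch.randn(4, 10)       # stand-in posterior mean
qz_v = torch.rand(4, 10) + 0.1  # stand-in posterior variance

qz = Normal(qz_m, qz_v.sqrt())  # what inference now returns as "qz"
prior = Normal(torch.zeros_like(qz.loc), torch.ones_like(qz.scale))

old = kl(Normal(qz_m, torch.sqrt(qz_v)), prior).sum(dim=1)
new = kl(qz, prior).sum(dim=1)
assert torch.allclose(old, new)  # identical by construction
```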
43 changes: 14 additions & 29 deletions scvi/model/_autozi.py
@@ -5,7 +5,7 @@
import torch
from anndata import AnnData
from torch import logsumexp
from torch.distributions import Beta, Normal
from torch.distributions import Beta

from scvi import REGISTRY_KEYS
from scvi._compat import Literal
@@ -212,13 +212,12 @@ def get_marginal_ll(
# Distribution parameters and sampled variables
inf_outputs, gen_outputs, _ = self.module.forward(tensors)

px_r = gen_outputs["px_r"]
px_rate = gen_outputs["px_rate"]
px_dropout = gen_outputs["px_dropout"]
qz_m = inf_outputs["qz_m"]
qz_v = inf_outputs["qz_v"]
px = gen_outputs["px"]
px_r = px.theta
px_rate = px.mu
px_dropout = px.zi_logits
qz = inf_outputs["qz"]
z = inf_outputs["z"]
library = inf_outputs["library"]

# Reconstruction Loss
bernoulli_params_batch = self.module.reshape_bernoulli(
@@ -235,36 +234,22 @@
)

# Log-probabilities
log_prob_sum = torch.zeros(qz_m.shape[0]).to(self.device)
p_z = (
Normal(torch.zeros_like(qz_m), torch.ones_like(qz_v))
.log_prob(z)
.sum(dim=-1)
)
p_z = gen_outputs["pz"].log_prob(z).sum(dim=-1)
p_x_zld = -reconst_loss
log_prob_sum += p_z + p_x_zld

q_z_x = Normal(qz_m, qz_v.sqrt()).log_prob(z).sum(dim=-1)
log_prob_sum -= q_z_x
q_z_x = qz.log_prob(z).sum(dim=-1)
log_prob_sum = p_z + p_x_zld - q_z_x

if not self.use_observed_lib_size:
ql = inf_outputs["ql"]
library = inf_outputs["library"]
(
local_library_log_means,
local_library_log_vars,
) = self.module._compute_local_library_params(batch_index)

p_l = (
Normal(
local_library_log_means.to(self.device),
local_library_log_vars.to(self.device).sqrt(),
)
.log_prob(library)
.sum(dim=-1)
)

ql_m = inf_outputs["ql_m"]
ql_v = inf_outputs["ql_v"]
q_l_x = Normal(ql_m, ql_v.sqrt()).log_prob(library).sum(dim=-1)
p_l = gen_outputs["pl"].log_prob(library).sum(dim=-1)

q_l_x = ql.log_prob(library).sum(dim=-1)

log_prob_sum += p_l - q_l_x

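For context, these per-sample terms feed the usual importance-weighted estimator of log p(x), averaged in probability space with `logsumexp` (a schematic of the final aggregation, shapes illustrative; the method's exact batching may differ):

```python
import torch
from torch import logsumexp

n_mc_samples, n_cells = 100, 128
# Stand-in for the accumulated log p(z) + log p(x|z) + log p(l)
# - log q(z|x) - log q(l|x), one row per Monte Carlo sample.
log_prob_sum = torch.randn(n_mc_samples, n_cells)

# log p(x) is approximated by log( (1/n) * sum_i exp(log_prob_sum_i) ).
log_px = logsumexp(log_prob_sum, dim=0) - torch.log(
    torch.tensor(float(n_mc_samples))
)
```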
2 changes: 1 addition & 1 deletion scvi/model/_condscvi.py
@@ -139,7 +139,7 @@ def get_vamp_prior(
x = tensors[REGISTRY_KEYS.X_KEY]
y = tensors[REGISTRY_KEYS.LABELS_KEY]
out = self.module.inference(x, y)
mean_, var_ = out["qz_m"], out["qz_v"]
mean_, var_ = out["qz"].loc, (out["qz"].scale ** 2)
mean += [mean_.cpu()]
var += [var_.cpu()]

6 changes: 3 additions & 3 deletions scvi/model/_totalvi.py
@@ -330,14 +330,14 @@ def get_latent_library_size(
post = self._make_data_loader(
adata=adata, indices=indices, batch_size=batch_size
)

libraries = []
for tensors in post:
inference_inputs = self.module._get_inference_input(tensors)
outputs = self.module.inference(**inference_inputs)
if give_mean:
ql_m = outputs["ql_m"]
ql_v = outputs["ql_v"]
library = torch.exp(ql_m + 0.5 * ql_v)
ql = outputs["ql"]
library = torch.exp(ql.loc + 0.5 * (ql.scale**2))
else:
library = outputs["library_gene"]
libraries += [library.cpu()]
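The `give_mean` branch uses the closed-form mean of a log-normal: if the log-library is Normal(mu, sigma), the library itself has mean exp(mu + sigma^2 / 2), which is exactly `torch.exp(ql.loc + 0.5 * (ql.scale**2))` above. A quick numerical sanity check:

```python
import torch
from torch.distributions import LogNormal, Normal

ql = Normal(torch.tensor([1.0]), torch.tensor([0.5]))  # stand-in for outputs["ql"]

closed_form = torch.exp(ql.loc + 0.5 * ql.scale**2)
empirical = LogNormal(ql.loc, ql.scale).sample((200_000,)).mean(dim=0)

print(closed_form, empirical)  # agree up to Monte Carlo error
```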