From 9884cda5cddade035c842e8652de42b813f5d68d Mon Sep 17 00:00:00 2001 From: Pierre Boyeau Date: Mon, 30 Aug 2021 11:45:15 -0700 Subject: [PATCH 01/24] replace normal parameters in inference by distribution --- scvi/external/gimvi/_module.py | 32 ++++++++--------- scvi/model/_autozi.py | 12 +++---- scvi/model/_condscvi.py | 2 +- scvi/model/_totalvi.py | 5 ++- scvi/model/base/_rnamixin.py | 4 +-- scvi/model/base/_vaemixin.py | 8 ++--- scvi/module/_autozivae.py | 16 ++++----- scvi/module/_multivae.py | 26 +++++--------- scvi/module/_peakvae.py | 15 ++++---- scvi/module/_scanvae.py | 24 ++++++------- scvi/module/_vae.py | 64 ++++++++++++++++------------------ scvi/module/_vaec.py | 21 ++++------- scvi/nn/_base_components.py | 10 +++--- 13 files changed, 100 insertions(+), 139 deletions(-) diff --git a/scvi/external/gimvi/_module.py b/scvi/external/gimvi/_module.py index 9a9bfbc546..7eafdbb186 100644 --- a/scvi/external/gimvi/_module.py +++ b/scvi/external/gimvi/_module.py @@ -179,7 +179,7 @@ def sample_from_posterior_z( else: raise Exception("Must provide a mode when having multiple datasets") outputs = self.inference(x, mode) - qz_m = outputs["qz_m"] + qz_m = outputs["qz"].loc z = outputs["z"] if deterministic: z = qz_m @@ -250,9 +250,9 @@ def sample_scale( decode_mode = mode inference_out = self.inference(x, mode) if deterministic: - z = inference_out["qz_m"] - if inference_out["ql_m"] is not None: - library = inference_out["ql_m"] + z = inference_out["qz"].loc + if inference_out["ql"] is not None: + library = inference_out["ql"].loc else: library = inference_out["library"] else: @@ -354,14 +354,14 @@ def inference(self, x: torch.Tensor, mode: Optional[int] = None) -> dict: if self.log_variational: x_ = torch.log(1 + x_) - qz_m, qz_v, z = self.z_encoder(x_, mode) - ql_m, ql_v, library = None, None, None + qz, z = self.z_encoder(x_, mode) + ql, library = None, None if self.model_library_bools[mode]: - ql_m, ql_v, library = self.l_encoders[mode](x_) + ql, library = self.l_encoders[mode](x_) else: library = torch.log(torch.sum(x, dim=1)).view(-1, 1) - return dict(qz_m=qz_m, qz_v=qz_v, z=z, ql_m=ql_m, ql_v=ql_v, library=library) + return dict(qz=qz, z=z, ql=ql, library=library) @auto_move_data def generative( @@ -436,10 +436,8 @@ def loss( local_l_mean = tensors[_CONSTANTS.LOCAL_L_MEAN_KEY] local_l_var = tensors[_CONSTANTS.LOCAL_L_VAR_KEY] - qz_m = inference_outputs["qz_m"] - qz_v = inference_outputs["qz_v"] - ql_m = inference_outputs["ql_m"] - ql_v = inference_outputs["ql_v"] + qz = inference_outputs["qz"] + ql = inference_outputs["ql"] px_rate = generative_outputs["px_rate"] px_r = generative_outputs["px_r"] px_dropout = generative_outputs["px_dropout"] @@ -455,15 +453,13 @@ def loss( ) # KL Divergence - mean = torch.zeros_like(qz_m) - scale = torch.ones_like(qz_v) - kl_divergence_z = kl(Normal(qz_m, torch.sqrt(qz_v)), Normal(mean, scale)).sum( - dim=1 - ) + mean = torch.zeros_like(qz.loc) + scale = torch.ones_like(qz.scale) + kl_divergence_z = kl(qz, Normal(mean, scale)).sum(dim=1) if self.model_library_bools[mode]: kl_divergence_l = kl( - Normal(ql_m, torch.sqrt(ql_v)), + ql, Normal(local_l_mean, torch.sqrt(local_l_var)), ).sum(dim=1) else: diff --git a/scvi/model/_autozi.py b/scvi/model/_autozi.py index 5ef6df4238..feb889faba 100644 --- a/scvi/model/_autozi.py +++ b/scvi/model/_autozi.py @@ -201,11 +201,9 @@ def get_marginal_ll( px_r = gen_outputs["px_r"] px_rate = gen_outputs["px_rate"] px_dropout = gen_outputs["px_dropout"] - qz_m = inf_outputs["qz_m"] - qz_v = inf_outputs["qz_v"] + qz = 
inf_outputs["qz"] z = inf_outputs["z"] - ql_m = inf_outputs["ql_m"] - ql_v = inf_outputs["ql_v"] + ql = inf_outputs["ql"] library = inf_outputs["library"] # Reconstruction Loss @@ -232,13 +230,13 @@ def get_marginal_ll( .sum(dim=-1) ) p_z = ( - Normal(torch.zeros_like(qz_m), torch.ones_like(qz_v)) + Normal(torch.zeros_like(qz.loc), torch.ones_like(qz.scale)) .log_prob(z) .sum(dim=-1) ) p_x_zld = -reconst_loss.to(p_z.device) - q_z_x = Normal(qz_m, qz_v.sqrt()).log_prob(z).sum(dim=-1) - q_l_x = Normal(ql_m, ql_v.sqrt()).log_prob(library).sum(dim=-1) + q_z_x = qz.log_prob(z).sum(dim=-1) + q_l_x = ql.log_prob(library).sum(dim=-1) batch_log_lkl = torch.sum(p_x_zld + p_l + p_z - q_z_x - q_l_x, dim=0) to_sum[i] += batch_log_lkl.cpu() diff --git a/scvi/model/_condscvi.py b/scvi/model/_condscvi.py index ce2aa6aace..97f1aec1fb 100644 --- a/scvi/model/_condscvi.py +++ b/scvi/model/_condscvi.py @@ -140,7 +140,7 @@ def get_vamp_prior( x = tensors[_CONSTANTS.X_KEY] y = tensors[_CONSTANTS.LABELS_KEY] out = self.module.inference(x, y) - mean_, var_ = out["qz_m"], out["qz_v"] + mean_, var_ = out["qz"].loc, (out["qz"].scale ** 2) mean += [mean_.cpu()] var += [var_.cpu()] diff --git a/scvi/model/_totalvi.py b/scvi/model/_totalvi.py index 3a4f6da211..4ddc8319fd 100644 --- a/scvi/model/_totalvi.py +++ b/scvi/model/_totalvi.py @@ -299,10 +299,9 @@ def get_latent_library_size( for tensors in post: inference_inputs = self.module._get_inference_input(tensors) outputs = self.module.inference(**inference_inputs) - ql_m = outputs["ql_m"] - ql_v = outputs["ql_v"] + ql = outputs["ql"] if give_mean is True: - library = torch.exp(ql_m + 0.5 * ql_v) + library = torch.exp(ql.loc + 0.5 * (ql.scale ** 2)) else: library = outputs["library_gene"] libraries += [library.cpu()] diff --git a/scvi/model/base/_rnamixin.py b/scvi/model/base/_rnamixin.py index e76d8d8333..f2ca8abab4 100644 --- a/scvi/model/base/_rnamixin.py +++ b/scvi/model/base/_rnamixin.py @@ -552,8 +552,8 @@ def get_latent_library_size( inference_inputs = self.module._get_inference_input(tensors) outputs = self.module.inference(**inference_inputs) - ql_m = outputs["ql_m"] - ql_v = outputs["ql_v"] + ql_m = outputs["ql"].loc + ql_v = outputs["ql"].scale library = outputs["library"] if give_mean is False: library = torch.exp(library) diff --git a/scvi/model/base/_vaemixin.py b/scvi/model/base/_vaemixin.py index 3e4c01dd53..7a2b2d7d91 100644 --- a/scvi/model/base/_vaemixin.py +++ b/scvi/model/base/_vaemixin.py @@ -4,7 +4,6 @@ import numpy as np import torch from anndata import AnnData -from torch.distributions import Normal from ._log_likelihood import compute_elbo, compute_reconstruction_error @@ -163,18 +162,17 @@ def get_latent_representation( for tensors in scdl: inference_inputs = self.module._get_inference_input(tensors) outputs = self.module.inference(**inference_inputs) - qz_m = outputs["qz_m"] - qz_v = outputs["qz_v"] + qz = outputs["qz"] z = outputs["z"] if give_mean: # does each model need to have this latent distribution param? 
if self.module.latent_distribution == "ln": - samples = Normal(qz_m, qz_v.sqrt()).sample([mc_samples]) + samples = qz.sample([mc_samples]) z = torch.nn.functional.softmax(samples, dim=-1) z = z.mean(dim=0) else: - z = qz_m + z = qz.loc latent += [z.cpu()] return torch.cat(latent).numpy() diff --git a/scvi/module/_autozivae.py b/scvi/module/_autozivae.py index 4d2b22c1e5..d30a6b1589 100644 --- a/scvi/module/_autozivae.py +++ b/scvi/module/_autozivae.py @@ -363,10 +363,8 @@ def loss( n_obs: int = 1.0, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: # Parameters for z latent distribution - qz_m = inference_outputs["qz_m"] - qz_v = inference_outputs["qz_v"] - ql_m = inference_outputs["ql_m"] - ql_v = inference_outputs["ql_v"] + qz = inference_outputs["qz"] + ql = inference_outputs["ql"] px_rate = generative_outputs["px_rate"] px_r = generative_outputs["px_r"] px_dropout = generative_outputs["px_dropout"] @@ -376,14 +374,12 @@ def loss( local_l_var = tensors[_CONSTANTS.LOCAL_L_VAR_KEY] # KL divergences wrt z_n,l_n - mean = torch.zeros_like(qz_m) - scale = torch.ones_like(qz_v) + mean = torch.zeros_like(qz.loc) + scale = torch.ones_like(qz.scale) - kl_divergence_z = kl(Normal(qz_m, torch.sqrt(qz_v)), Normal(mean, scale)).sum( - dim=1 - ) + kl_divergence_z = kl(qz, Normal(mean, scale)).sum(dim=1) kl_divergence_l = kl( - Normal(ql_m, torch.sqrt(ql_v)), + ql, Normal(local_l_mean, torch.sqrt(local_l_var)), ).sum(dim=1) diff --git a/scvi/module/_multivae.py b/scvi/module/_multivae.py index 2bbdb07b5c..af512ff497 100644 --- a/scvi/module/_multivae.py +++ b/scvi/module/_multivae.py @@ -291,13 +291,16 @@ def inference( categorical_input = tuple() # Z Encoders - qzm_acc, qzv_acc, z_acc = self.z_encoder_accessibility( + qz_acc, z_acc = self.z_encoder_accessibility( encoder_input_accessibility, batch_index, *categorical_input ) - qzm_expr, qzv_expr, z_expr = self.z_encoder_expression( + qz_expr, z_expr = self.z_encoder_expression( encoder_input_expression, batch_index, *categorical_input ) - + qzm_acc = qz_acc.loc + qzm_expr = qz_expr.loc + qzv_acc = qz_acc.scale ** 2 + qzv_expr = qz_expr.scale ** 2 # L encoders libsize_expr = self.l_encoder_expression( encoder_input_expression, batch_index, *categorical_input @@ -308,22 +311,9 @@ def inference( # ReFormat Outputs if n_samples > 1: - qzm_acc = qzm_acc.unsqueeze(0).expand( - (n_samples, qzm_acc.size(0), qzm_acc.size(1)) - ) - qzv_acc = qzv_acc.unsqueeze(0).expand( - (n_samples, qzv_acc.size(0), qzv_acc.size(1)) - ) - untran_za = Normal(qzm_acc, qzv_acc.sqrt()).sample() + untran_za = qz_acc.sample((n_samples,)) z_acc = self.z_encoder_accessibility.z_transformation(untran_za) - - qzm_expr = qzm_expr.unsqueeze(0).expand( - (n_samples, qzm_expr.size(0), qzm_expr.size(1)) - ) - qzv_expr = qzv_expr.unsqueeze(0).expand( - (n_samples, qzv_expr.size(0), qzv_expr.size(1)) - ) - untran_zr = Normal(qzm_expr, qzv_expr.sqrt()).sample() + untran_zr = qz_expr.sample((n_samples,)) z_expr = self.z_encoder_expression.z_transformation(untran_zr) libsize_expr = libsize_expr.unsqueeze(0).expand( diff --git a/scvi/module/_peakvae.py b/scvi/module/_peakvae.py index 2f50c74906..a0000d3b28 100644 --- a/scvi/module/_peakvae.py +++ b/scvi/module/_peakvae.py @@ -227,7 +227,7 @@ def _get_inference_input(self, tensors): def _get_generative_input(self, tensors, inference_outputs, transform_batch=None): z = inference_outputs["z"] - qz_m = inference_outputs["qz_m"] + qz_m = inference_outputs["qz"].loc batch_index = tensors[_CONSTANTS.BATCH_KEY] cont_covs = 
tensors.get(_CONSTANTS.CONT_COVS_KEY) @@ -269,7 +269,7 @@ def inference( encoder_input = x # if encode_covariates is False, cat_list to init encoder is None, so # batch_index is not used (or categorical_input, but it's empty) - qz_m, qz_v, z = self.z_encoder(encoder_input, batch_index, *categorical_input) + qz, z = self.z_encoder(encoder_input, batch_index, *categorical_input) d = ( self.d_encoder(encoder_input, batch_index, *categorical_input) if self.model_depth @@ -277,13 +277,11 @@ def inference( ) if n_samples > 1: - qz_m = qz_m.unsqueeze(0).expand((n_samples, qz_m.size(0), qz_m.size(1))) - qz_v = qz_v.unsqueeze(0).expand((n_samples, qz_v.size(0), qz_v.size(1))) # when z is normal, untran_z == z - untran_z = Normal(qz_m, qz_v.sqrt()).sample() + untran_z = qz.sample((n_samples,)) z = self.z_encoder.z_transformation(untran_z) - return dict(d=d, qz_m=qz_m, qz_v=qz_v, z=z) + return dict(d=d, qz=qz, z=z) @auto_move_data def generative( @@ -315,13 +313,12 @@ def loss( self, tensors, inference_outputs, generative_outputs, kl_weight: float = 1.0 ): x = tensors[_CONSTANTS.X_KEY] - qz_m = inference_outputs["qz_m"] - qz_v = inference_outputs["qz_v"] + qz = inference_outputs["qz"] d = inference_outputs["d"] p = generative_outputs["p"] kld = kl_divergence( - Normal(qz_m, torch.sqrt(qz_v)), + qz, Normal(0, 1), ).sum(dim=1) diff --git a/scvi/module/_scanvae.py b/scvi/module/_scanvae.py index 3fdaa846d0..0231017055 100644 --- a/scvi/module/_scanvae.py +++ b/scvi/module/_scanvae.py @@ -188,9 +188,9 @@ def __init__( def classify(self, x, batch_index=None): if self.log_variational: x = torch.log(1 + x) - qz_m, _, z = self.z_encoder(x, batch_index) + qz, z = self.z_encoder(x, batch_index) # We classify using the inferred mean parameter of z_1 in the latent space - z = qz_m + z = qz.loc if self.use_labels_groups: w_g = self.classifier_groups(z) unw_y = self.classifier(z) @@ -228,11 +228,9 @@ def loss( px_r = generative_ouputs["px_r"] px_rate = generative_ouputs["px_rate"] px_dropout = generative_ouputs["px_dropout"] - qz1_m = inference_outputs["qz_m"] - qz1_v = inference_outputs["qz_v"] + qz1 = inference_outputs["qz"] z1 = inference_outputs["z"] - ql_m = inference_outputs["ql_m"] - ql_v = inference_outputs["ql_v"] + ql = inference_outputs["ql"] x = tensors[_CONSTANTS.X_KEY] local_l_mean = tensors[_CONSTANTS.LOCAL_L_MEAN_KEY] local_l_var = tensors[_CONSTANTS.LOCAL_L_VAR_KEY] @@ -245,23 +243,21 @@ def loss( # Enumerate choices of label ys, z1s = broadcast_labels(y, z1, n_broadcast=self.n_labels) - qz2_m, qz2_v, z2 = self.encoder_z2_z1(z1s, ys) + qz2, z2 = self.encoder_z2_z1(z1s, ys) pz1_m, pz1_v = self.decoder_z1_z2(z2, ys) reconst_loss = self.get_reconstruction_loss(x, px_rate, px_r, px_dropout) # KL Divergence - mean = torch.zeros_like(qz2_m) - scale = torch.ones_like(qz2_v) + mean = torch.zeros_like(qz2.loc) + scale = torch.ones_like(qz2.scale) - kl_divergence_z2 = kl( - Normal(qz2_m, torch.sqrt(qz2_v)), Normal(mean, scale) - ).sum(dim=1) + kl_divergence_z2 = kl(qz2, Normal(mean, scale)).sum(dim=1) loss_z1_unweight = -Normal(pz1_m, torch.sqrt(pz1_v)).log_prob(z1s).sum(dim=-1) - loss_z1_weight = Normal(qz1_m, torch.sqrt(qz1_v)).log_prob(z1).sum(dim=-1) + loss_z1_weight = qz1.log_prob(z1).sum(dim=-1) if not self.use_observed_lib_size: kl_divergence_l = kl( - Normal(ql_m, torch.sqrt(ql_v)), + ql, Normal(local_l_mean, torch.sqrt(local_l_var)), ).sum(dim=1) else: diff --git a/scvi/module/_vae.py b/scvi/module/_vae.py index fe1914dd6d..2d71495253 100644 --- a/scvi/module/_vae.py +++ b/scvi/module/_vae.py @@ 
-234,8 +234,8 @@ def inference(self, x, batch_index, cont_covs=None, cat_covs=None, n_samples=1): categorical_input = torch.split(cat_covs, 1, dim=1) else: categorical_input = tuple() - qz_m, qz_v, z = self.z_encoder(encoder_input, batch_index, *categorical_input) - ql_m, ql_v, library_encoded = self.l_encoder( + qz, z = self.z_encoder(encoder_input, batch_index, *categorical_input) + ql, library_encoded = self.l_encoder( encoder_input, batch_index, *categorical_input ) @@ -243,21 +243,16 @@ def inference(self, x, batch_index, cont_covs=None, cat_covs=None, n_samples=1): library = library_encoded if n_samples > 1: - qz_m = qz_m.unsqueeze(0).expand((n_samples, qz_m.size(0), qz_m.size(1))) - qz_v = qz_v.unsqueeze(0).expand((n_samples, qz_v.size(0), qz_v.size(1))) - # when z is normal, untran_z == z - untran_z = Normal(qz_m, qz_v.sqrt()).sample() + untran_z = qz.sample((n_samples,)) z = self.z_encoder.z_transformation(untran_z) - ql_m = ql_m.unsqueeze(0).expand((n_samples, ql_m.size(0), ql_m.size(1))) - ql_v = ql_v.unsqueeze(0).expand((n_samples, ql_v.size(0), ql_v.size(1))) if self.use_observed_lib_size: library = library.unsqueeze(0).expand( (n_samples, library.size(0), library.size(1)) ) else: - library = Normal(ql_m, ql_v.sqrt()).sample() + library = ql.sample((n_samples,)) - outputs = dict(z=z, qz_m=qz_m, qz_v=qz_v, ql_m=ql_m, ql_v=ql_v, library=library) + outputs = dict(z=z, qz=qz, ql=ql, library=library) return outputs @auto_move_data @@ -296,8 +291,20 @@ def generative( px_r = torch.exp(px_r) + if self.gene_likelihood == "zinb": + px_latents = ZeroInflatedNegativeBinomial( + mu=px_rate, theta=px_r, zi_logits=px_dropout + ) + elif self.gene_likelihood == "nb": + px_latents = NegativeBinomial(mu=px_rate, theta=px_r) + elif self.gene_likelihood == "poisson": + px_latents = Poisson(px_rate) return dict( - px_scale=px_scale, px_r=px_r, px_rate=px_rate, px_dropout=px_dropout + px_latents=px_latents, + px_scale=px_scale, + px_r=px_r, + px_rate=px_rate, + px_dropout=px_dropout, ) def loss( @@ -311,30 +318,22 @@ def loss( local_l_mean = tensors[_CONSTANTS.LOCAL_L_MEAN_KEY] local_l_var = tensors[_CONSTANTS.LOCAL_L_VAR_KEY] - qz_m = inference_outputs["qz_m"] - qz_v = inference_outputs["qz_v"] - ql_m = inference_outputs["ql_m"] - ql_v = inference_outputs["ql_v"] - px_rate = generative_outputs["px_rate"] - px_r = generative_outputs["px_r"] - px_dropout = generative_outputs["px_dropout"] + qz = inference_outputs["qz"] + ql = inference_outputs["ql"] - mean = torch.zeros_like(qz_m) - scale = torch.ones_like(qz_v) - - kl_divergence_z = kl(Normal(qz_m, torch.sqrt(qz_v)), Normal(mean, scale)).sum( - dim=1 - ) + mean = torch.zeros_like(qz.loc) + scale = torch.ones_like(qz.scale) + kl_divergence_z = kl(qz, Normal(mean, scale)).sum(dim=1) if not self.use_observed_lib_size: kl_divergence_l = kl( - Normal(ql_m, torch.sqrt(ql_v)), + ql, Normal(local_l_mean, torch.sqrt(local_l_var)), ).sum(dim=1) else: kl_divergence_l = 0.0 - reconst_loss = self.get_reconstruction_loss(x, px_rate, px_r, px_dropout) + reconst_loss = generative_outputs["px_latents"].log_prob(x).sum(-1) kl_local_for_warmup = kl_divergence_z kl_local_no_warmup = kl_divergence_l @@ -442,11 +441,9 @@ def marginal_ll(self, tensors, n_mc_samples): for i in range(n_mc_samples): # Distribution parameters and sampled variables inference_outputs, generative_outputs, losses = self.forward(tensors) - qz_m = inference_outputs["qz_m"] - qz_v = inference_outputs["qz_v"] + qz = inference_outputs["qz"] z = inference_outputs["z"] - ql_m = 
inference_outputs["ql_m"] - ql_v = inference_outputs["ql_v"] + ql = inference_outputs["ql"] library = inference_outputs["library"] # Reconstruction Loss @@ -455,14 +452,13 @@ def marginal_ll(self, tensors, n_mc_samples): # Log-probabilities p_l = Normal(local_l_mean, local_l_var.sqrt()).log_prob(library).sum(dim=-1) p_z = ( - Normal(torch.zeros_like(qz_m), torch.ones_like(qz_v)) + Normal(torch.zeros_like(qz.loc), torch.ones_like(qz.scale)) .log_prob(z) .sum(dim=-1) ) p_x_zl = -reconst_loss - q_z_x = Normal(qz_m, qz_v.sqrt()).log_prob(z).sum(dim=-1) - q_l_x = Normal(ql_m, ql_v.sqrt()).log_prob(library).sum(dim=-1) - + q_z_x = qz.log_prob(z).sum(dim=-1) + q_l_x = ql.log_prob(library).sum(dim=-1) to_sum[:, i] = p_z + p_l + p_x_zl - q_z_x - q_l_x batch_log_lkl = logsumexp(to_sum, dim=-1) - np.log(n_mc_samples) diff --git a/scvi/module/_vaec.py b/scvi/module/_vaec.py index 21ce18208b..d6f4c267ba 100644 --- a/scvi/module/_vaec.py +++ b/scvi/module/_vaec.py @@ -133,19 +133,16 @@ def inference(self, x, y, n_samples=1): if self.log_variational: x_ = torch.log(1 + x_) - qz_m, qz_v, z = self.z_encoder(x_, y) + qz, z = self.z_encoder(x_, y) if n_samples > 1: - qz_m = qz_m.unsqueeze(0).expand((n_samples, qz_m.size(0), qz_m.size(1))) - qz_v = qz_v.unsqueeze(0).expand((n_samples, qz_v.size(0), qz_v.size(1))) - # when z is normal, untran_z == z - untran_z = Normal(qz_m, qz_v.sqrt()).sample() + untran_z = qz.sample((n_samples,)) z = self.z_encoder.z_transformation(untran_z) library = library.unsqueeze(0).expand( (n_samples, library.size(0), library.size(1)) ) - outputs = dict(z=z, qz_m=qz_m, qz_v=qz_v, library=library) + outputs = dict(z=z, qz=qz, library=library) return outputs @auto_move_data @@ -166,22 +163,18 @@ def loss( ): x = tensors[_CONSTANTS.X_KEY] y = tensors[_CONSTANTS.LABELS_KEY] - qz_m = inference_outputs["qz_m"] - qz_v = inference_outputs["qz_v"] + qz = inference_outputs["qz"] px_rate = generative_outputs["px_rate"] px_r = generative_outputs["px_r"] - mean = torch.zeros_like(qz_m) - scale = torch.ones_like(qz_v) + mean = torch.zeros_like(qz.loc) + scale = torch.ones_like(qz.scale) - kl_divergence_z = kl(Normal(qz_m, torch.sqrt(qz_v)), Normal(mean, scale)).sum( - dim=1 - ) + kl_divergence_z = kl(qz, Normal(mean, scale)).sum(dim=1) reconst_loss = -NegativeBinomial(px_rate, logits=px_r).log_prob(x).sum(-1) scaling_factor = self.ct_weight[y.long()[:, 0]] loss = torch.mean(scaling_factor * (reconst_loss + kl_weight * kl_divergence_z)) - return LossRecorder(loss, reconst_loss, kl_divergence_z, 0.0) @torch.no_grad() diff --git a/scvi/nn/_base_components.py b/scvi/nn/_base_components.py index 7adc8b1b40..477c8d65f2 100644 --- a/scvi/nn/_base_components.py +++ b/scvi/nn/_base_components.py @@ -291,8 +291,9 @@ def forward(self, x: torch.Tensor, *cat_list: int): q = self.encoder(x, *cat_list) q_m = self.mean_encoder(q) q_v = self.var_activation(self.var_encoder(q)) + self.var_eps - latent = self.z_transformation(reparameterize_gaussian(q_m, q_v)) - return q_m, q_v, latent + dist = Normal(q_m, q_v.sqrt()) + latent = self.z_transformation(dist.rsample()) + return dist, latent # Decoder @@ -582,8 +583,9 @@ def forward(self, x: torch.Tensor, head_id: int, *cat_list: int): q_m = self.mean_encoder(q) q_v = torch.exp(self.var_encoder(q)) latent = reparameterize_gaussian(q_m, q_v) - - return q_m, q_v, latent + dist = Normal(q_m, q_v.sqrt()) + latent = dist.rsample() + return dist, latent class MultiDecoder(nn.Module): From 998c3458fe42e4ad7c2ead3114b475f2aa796c6d Mon Sep 17 00:00:00 2001 From: Pierre Boyeau 
Date: Thu, 16 Sep 2021 09:58:57 -0700 Subject: [PATCH 02/24] totalVI changes --- scvi/module/_totalvae.py | 37 ++++++++++++++----------------------- scvi/nn/_base_components.py | 9 +++++++-- 2 files changed, 21 insertions(+), 25 deletions(-) diff --git a/scvi/module/_totalvae.py b/scvi/module/_totalvae.py index c9308760f4..aaae47f741 100644 --- a/scvi/module/_totalvae.py +++ b/scvi/module/_totalvae.py @@ -443,7 +443,7 @@ def inference( categorical_input = torch.split(cat_covs, 1, dim=1) else: categorical_input = tuple() - qz_m, qz_v, ql_m, ql_v, latent, untran_latent = self.encoder( + qz, ql, latent, untran_latent = self.encoder( encoder_input, batch_index, *categorical_input ) z = latent["z"] @@ -453,13 +453,10 @@ def inference( library_gene = latent["l"] if n_samples > 1: - qz_m = qz_m.unsqueeze(0).expand((n_samples, qz_m.size(0), qz_m.size(1))) - qz_v = qz_v.unsqueeze(0).expand((n_samples, qz_v.size(0), qz_v.size(1))) - untran_z = Normal(qz_m, qz_v.sqrt()).sample() + z = qz.sample((n_samples,)) z = self.encoder.z_transformation(untran_z) - ql_m = ql_m.unsqueeze(0).expand((n_samples, ql_m.size(0), ql_m.size(1))) - ql_v = ql_v.unsqueeze(0).expand((n_samples, ql_v.size(0), ql_v.size(1))) - untran_l = Normal(ql_m, ql_v.sqrt()).sample() + + untran_l = ql.sample((n_samples,)) if self.use_observed_lib_size: library_gene = library_gene.unsqueeze(0).expand( (n_samples, library_gene.size(0), library_gene.size(1)) @@ -499,12 +496,10 @@ def inference( self.back_mean_prior = Normal(py_back_alpha_prior, py_back_beta_prior) return dict( - qz_m=qz_m, - qz_v=qz_v, + qz=qz, z=z, untran_z=untran_z, - ql_m=ql_m, - ql_v=ql_v, + ql=ql, library_gene=library_gene, untran_l=untran_l, ) @@ -544,10 +539,8 @@ def loss( type the reconstruction loss and the Kullback divergences """ - qz_m = inference_outputs["qz_m"] - qz_v = inference_outputs["qz_v"] - ql_m = inference_outputs["ql_m"] - ql_v = inference_outputs["ql_v"] + qz = inference_outputs["qz"] + ql = inference_outputs["ql"] px_ = generative_outputs["px_"] py_ = generative_outputs["py_"] @@ -573,10 +566,10 @@ def loss( ) # KL Divergence - kl_div_z = kl(Normal(qz_m, torch.sqrt(qz_v)), Normal(0, 1)).sum(dim=1) + kl_div_z = kl(qz, Normal(0, 1)).sum(dim=1) if not self.use_observed_lib_size: kl_div_l_gene = kl( - Normal(ql_m, torch.sqrt(ql_v)), + ql, Normal(local_l_mean_gene, torch.sqrt(local_l_var_gene)), ).sum(dim=1) else: @@ -648,10 +641,8 @@ def marginal_ll(self, tensors, n_mc_samples): # Distribution parameters and sampled variables inference_outputs, generative_outputs, losses = self.forward(tensors) # outputs = self.module.inference(x, y, batch_index, labels) - qz_m = inference_outputs["qz_m"] - qz_v = inference_outputs["qz_v"] - ql_m = inference_outputs["ql_m"] - ql_v = inference_outputs["ql_v"] + qz = inference_outputs["qz"] + ql = inference_outputs["ql"] py_ = generative_outputs["py_"] log_library = inference_outputs["untran_l"] # really need not softmax transformed random variable @@ -672,8 +663,8 @@ def marginal_ll(self, tensors, n_mc_samples): p_z = Normal(0, 1).log_prob(z).sum(dim=-1) p_mu_back = self.back_mean_prior.log_prob(log_pro_back_mean).sum(dim=-1) p_xy_zl = -(reconst_loss_gene + reconst_loss_protein) - q_z_x = Normal(qz_m, qz_v.sqrt()).log_prob(z).sum(dim=-1) - q_l_x = Normal(ql_m, ql_v.sqrt()).log_prob(log_library).sum(dim=-1) + q_z_x = qz.log_prob(z).sum(dim=-1) + q_l_x = ql.log_prob(log_library).sum(dim=-1) q_mu_back = ( Normal(py_["back_alpha"], py_["back_beta"]) .log_prob(log_pro_back_mean) diff --git a/scvi/nn/_base_components.py 
b/scvi/nn/_base_components.py index 477c8d65f2..ad074152b6 100644 --- a/scvi/nn/_base_components.py +++ b/scvi/nn/_base_components.py @@ -993,12 +993,17 @@ def forward(self, data: torch.Tensor, *cat_list: int): q = self.encoder(data, *cat_list) qz_m = self.z_mean_encoder(q) qz_v = torch.exp(self.z_var_encoder(q)) + 1e-4 - z, untran_z = self.reparameterize_transformation(qz_m, qz_v) + q_z = Normal(qz_m, qz_v.sqrt()) + untran_z = q_z.rsample() + z = self.z_transformation(untran_z) ql_gene = self.l_gene_encoder(data, *cat_list) ql_m = self.l_gene_mean_encoder(ql_gene) ql_v = torch.exp(self.l_gene_var_encoder(ql_gene)) + 1e-4 log_library_gene = torch.clamp(reparameterize_gaussian(ql_m, ql_v), max=15) + q_l = Normal(ql_m, ql_v.sqrt()) + log_library_gene = q_l.rsample() + log_library_gene = torch.clamp(log_library_gene, max=15) library_gene = self.l_transformation(log_library_gene) latent = {} @@ -1008,4 +1013,4 @@ def forward(self, data: torch.Tensor, *cat_list: int): untran_latent["z"] = untran_z untran_latent["l"] = log_library_gene - return qz_m, qz_v, ql_m, ql_v, latent, untran_latent + return q_z, q_l, latent, untran_latent From c9bf152f08270c5cd0660dc78fc624bb8706854a Mon Sep 17 00:00:00 2001 From: Pierre Boyeau Date: Thu, 16 Sep 2021 10:47:56 -0700 Subject: [PATCH 03/24] merge fixes --- scvi/module/_totalvae.py | 4 ++-- scvi/module/_vae.py | 10 ++-------- scvi/nn/_base_components.py | 3 +-- 3 files changed, 5 insertions(+), 12 deletions(-) diff --git a/scvi/module/_totalvae.py b/scvi/module/_totalvae.py index 1205563f8c..33ea4a5276 100644 --- a/scvi/module/_totalvae.py +++ b/scvi/module/_totalvae.py @@ -674,7 +674,7 @@ def marginal_ll(self, tensors, n_mc_samples): reconst_loss_protein = reconst_loss["reconst_loss_protein"] # Log-probabilities - log_prob_sum = torch.zeros(qz_m.shape[0]).to(self.device) + log_prob_sum = torch.zeros(ql.loc.shape[0]).to(self.device) if not self.use_observed_lib_size: n_batch = self.library_log_means.shape[1] @@ -689,7 +689,7 @@ def marginal_ll(self, tensors, n_mc_samples): .log_prob(log_library) .sum(dim=-1) ) - q_l_x = Normal(ql_m, ql_v.sqrt()).log_prob(log_library).sum(dim=-1) + q_l_x = ql.log_prob(log_library).sum(dim=-1) log_prob_sum += p_l_gene - q_l_x diff --git a/scvi/module/_vae.py b/scvi/module/_vae.py index 24c93fa214..9a40013206 100644 --- a/scvi/module/_vae.py +++ b/scvi/module/_vae.py @@ -274,9 +274,9 @@ def inference(self, x, batch_index, cont_covs=None, cat_covs=None, n_samples=1): else: categorical_input = tuple() qz, z = self.z_encoder(encoder_input, batch_index, *categorical_input) - ql_m, ql_v = None, None + ql = None if not self.use_observed_lib_size: - ql_m, ql_v, library_encoded = self.l_encoder( + ql, library_encoded = self.l_encoder( encoder_input, batch_index, *categorical_input ) library = library_encoded @@ -289,9 +289,6 @@ def inference(self, x, batch_index, cont_covs=None, cat_covs=None, n_samples=1): (n_samples, library.size(0), library.size(1)) ) else: - ql, library_encoded = self.l_encoder( - encoder_input, batch_index, *categorical_input - ) library = ql.sample((n_samples,)) outputs = dict(z=z, qz=qz, ql=ql, library=library) return outputs @@ -359,9 +356,6 @@ def loss( batch_index = tensors[_CONSTANTS.BATCH_KEY] qz = inference_outputs["qz"] - px_rate = generative_outputs["px_rate"] - px_r = generative_outputs["px_r"] - px_dropout = generative_outputs["px_dropout"] mean = torch.zeros_like(qz.loc) scale = torch.ones_like(qz.scale) diff --git a/scvi/nn/_base_components.py b/scvi/nn/_base_components.py index 
ad074152b6..483331a046 100644 --- a/scvi/nn/_base_components.py +++ b/scvi/nn/_base_components.py @@ -582,9 +582,8 @@ def forward(self, x: torch.Tensor, head_id: int, *cat_list: int): q_m = self.mean_encoder(q) q_v = torch.exp(self.var_encoder(q)) - latent = reparameterize_gaussian(q_m, q_v) dist = Normal(q_m, q_v.sqrt()) - latent = dist.rsample() + latent = reparameterize_gaussian(dist.rsample()) return dist, latent From c1ebcd4b400ef28925df91376486333295079b67 Mon Sep 17 00:00:00 2001 From: Pierre Boyeau Date: Thu, 16 Sep 2021 17:56:41 -0700 Subject: [PATCH 04/24] fix totalvi --- scvi/module/_totalvae.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scvi/module/_totalvae.py b/scvi/module/_totalvae.py index 33ea4a5276..da9bceddc4 100644 --- a/scvi/module/_totalvae.py +++ b/scvi/module/_totalvae.py @@ -474,7 +474,7 @@ def inference( library_gene = latent["l"] if n_samples > 1: - z = qz.sample((n_samples,)) + untran_z = qz.sample((n_samples,)) z = self.encoder.z_transformation(untran_z) untran_l = ql.sample((n_samples,)) From 493ff5dcaa8231deeb783f3ddf61a15e2152d53b Mon Sep 17 00:00:00 2001 From: Pierre Boyeau Date: Thu, 16 Sep 2021 18:10:00 -0700 Subject: [PATCH 05/24] fixes --- scvi/model/_totalvi.py | 2 +- scvi/model/base/_rnamixin.py | 4 ++-- scvi/module/_vae.py | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/scvi/model/_totalvi.py b/scvi/model/_totalvi.py index ec380f0b24..f0cb226063 100644 --- a/scvi/model/_totalvi.py +++ b/scvi/model/_totalvi.py @@ -306,8 +306,8 @@ def get_latent_library_size( for tensors in post: inference_inputs = self.module._get_inference_input(tensors) outputs = self.module.inference(**inference_inputs) - ql = outputs["ql"] if give_mean is True: + ql = outputs["ql"] library = torch.exp(ql.loc + 0.5 * (ql.scale ** 2)) else: library = outputs["library_gene"] diff --git a/scvi/model/base/_rnamixin.py b/scvi/model/base/_rnamixin.py index 2f69837e77..67e55ea73a 100644 --- a/scvi/model/base/_rnamixin.py +++ b/scvi/model/base/_rnamixin.py @@ -552,12 +552,12 @@ def get_latent_library_size( inference_inputs = self.module._get_inference_input(tensors) outputs = self.module.inference(**inference_inputs) - ql_m = outputs["ql"].loc - ql_v = outputs["ql"].scale library = outputs["library"] if not give_mean: library = torch.exp(library) else: + ql_m = outputs["ql"].loc + ql_v = outputs["ql"].scale if ql_m is None or ql_v is None: raise RuntimeError( "The module for this model does not compute the posterior distribution " diff --git a/scvi/module/_vae.py b/scvi/module/_vae.py index 9a40013206..ce9e809175 100644 --- a/scvi/module/_vae.py +++ b/scvi/module/_vae.py @@ -376,7 +376,7 @@ def loss( else: kl_divergence_l = 0.0 - reconst_loss = generative_outputs["px_latents"].log_prob(x).sum(-1) + reconst_loss = -generative_outputs["px_latents"].log_prob(x).sum(-1) kl_local_for_warmup = kl_divergence_z kl_local_no_warmup = kl_divergence_l From 0e1d07f36146e230379cbf8fe2cdaf8f0ad36813 Mon Sep 17 00:00:00 2001 From: Pierre Boyeau Date: Thu, 16 Sep 2021 18:35:12 -0700 Subject: [PATCH 06/24] fix gimvi --- scvi/nn/_base_components.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/scvi/nn/_base_components.py b/scvi/nn/_base_components.py index 483331a046..5451054074 100644 --- a/scvi/nn/_base_components.py +++ b/scvi/nn/_base_components.py @@ -583,7 +583,7 @@ def forward(self, x: torch.Tensor, head_id: int, *cat_list: int): q_m = self.mean_encoder(q) q_v = torch.exp(self.var_encoder(q)) dist = Normal(q_m, q_v.sqrt()) - latent = 
reparameterize_gaussian(dist.rsample()) + latent = dist.rsample() return dist, latent @@ -999,7 +999,6 @@ def forward(self, data: torch.Tensor, *cat_list: int): ql_gene = self.l_gene_encoder(data, *cat_list) ql_m = self.l_gene_mean_encoder(ql_gene) ql_v = torch.exp(self.l_gene_var_encoder(ql_gene)) + 1e-4 - log_library_gene = torch.clamp(reparameterize_gaussian(ql_m, ql_v), max=15) q_l = Normal(ql_m, ql_v.sqrt()) log_library_gene = q_l.rsample() log_library_gene = torch.clamp(log_library_gene, max=15) From abffd8f5d246c8640114f144b22b61b6eb98ed4b Mon Sep 17 00:00:00 2001 From: Pierre Boyeau Date: Tue, 15 Feb 2022 15:12:39 -0800 Subject: [PATCH 07/24] pyro --- scvi/module/_amortizedlda.py | 13 ++++++------- tests/models/test_pyro.py | 4 ++-- 2 files changed, 8 insertions(+), 9 deletions(-) diff --git a/scvi/module/_amortizedlda.py b/scvi/module/_amortizedlda.py index 60859295d7..849b7e0e42 100644 --- a/scvi/module/_amortizedlda.py +++ b/scvi/module/_amortizedlda.py @@ -200,7 +200,9 @@ def forward( with pyro.plate( "cells", size=n_obs or self.n_obs, subsample_size=x.shape[0] ), poutine.scale(None, kl_weight): - cell_topic_posterior_mu, cell_topic_posterior_sigma, _ = self.encoder(x) + cell_topic_posterior, _ = self.encoder(x) + cell_topic_posterior_mu = cell_topic_posterior.loc + cell_topic_posterior_sigma = 2.0 * cell_topic_posterior.scale.log() pyro.sample( "log_cell_topic_dist", dist.Normal( @@ -327,12 +329,9 @@ def get_topic_distribution(self, x: torch.Tensor, n_samples: int) -> torch.Tenso ------- A `x.shape[0] x n_topics` tensor containing the normalized topic distribution. """ - ( - cell_topic_dist_mu, - cell_topic_dist_sigma, - _, - ) = self.guide.encoder(x) - cell_topic_dist_mu = cell_topic_dist_mu.detach().cpu() + cell_topic_dist, _ = self.guide.encoder(x) + cell_topic_dist_mu = cell_topic_dist.loc.detach().cpu() + cell_topic_dist_sigma = 2.0 * cell_topic_dist.scale.log() cell_topic_dist_sigma = F.softplus(cell_topic_dist_sigma.detach().cpu()) return torch.mean( F.softmax( diff --git a/tests/models/test_pyro.py b/tests/models/test_pyro.py index c69f7fed26..1c5a89a646 100644 --- a/tests/models/test_pyro.py +++ b/tests/models/test_pyro.py @@ -457,9 +457,9 @@ def guide(self, x, log_library): with pyro.plate("data", x.shape[0]): # use the encoder to get the parameters used to define q(z|x) x_ = torch.log(1 + x) - z_loc, z_scale, _ = self.encoder(x_) + qz, _ = self.encoder(x_) # sample the latent code z - pyro.sample("latent", dist.Normal(z_loc, z_scale).to_event(1)) + pyro.sample("latent", dist.Normal(qz.loc, qz.scale).to_event(1)) class FunctionBasedPyroModel(PyroSviTrainMixin, PyroSampleMixin, BaseModelClass): From ce1b4b437dc7a43698c16bc3e8fd33210c7ba866 Mon Sep 17 00:00:00 2001 From: Pierre Boyeau Date: Tue, 15 Feb 2022 15:36:27 -0800 Subject: [PATCH 08/24] clean --- scvi/model/_autozi.py | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/scvi/model/_autozi.py b/scvi/model/_autozi.py index 4ea0298f91..7f07f9b8d9 100644 --- a/scvi/model/_autozi.py +++ b/scvi/model/_autozi.py @@ -217,8 +217,6 @@ def get_marginal_ll( px_dropout = gen_outputs["px_dropout"] qz = inf_outputs["qz"] z = inf_outputs["z"] - ql = inf_outputs["ql"] - library = inf_outputs["library"] # Reconstruction Loss bernoulli_params_batch = self.module.reshape_bernoulli( @@ -235,19 +233,18 @@ def get_marginal_ll( ) # Log-probabilities - log_prob_sum = torch.zeros(qz.loc.shape[0]).to(self.device) p_z = ( Normal(torch.zeros_like(qz.loc), torch.ones_like(qz.scale)) .log_prob(z) .sum(dim=-1) ) p_x_zld 
= -reconst_loss - log_prob_sum += p_z + p_x_zld - q_z_x = qz.log_prob(z).sum(dim=-1) - log_prob_sum -= q_z_x + log_prob_sum = p_z + p_x_zld - q_z_x if not self.use_observed_lib_size: + ql = inf_outputs["ql"] + library = inf_outputs["library"] ( local_library_log_means, local_library_log_vars, From 1227323524c9567c5ad3819f73e15b5491048281 Mon Sep 17 00:00:00 2001 From: Pierre Boyeau Date: Tue, 15 Feb 2022 15:43:23 -0800 Subject: [PATCH 09/24] clean --- scvi/model/_totalvi.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scvi/model/_totalvi.py b/scvi/model/_totalvi.py index 7c2f243ba8..2f9e83067c 100644 --- a/scvi/model/_totalvi.py +++ b/scvi/model/_totalvi.py @@ -339,7 +339,7 @@ def get_latent_library_size( for tensors in post: inference_inputs = self.module._get_inference_input(tensors) outputs = self.module.inference(**inference_inputs) - if give_mean is True: + if give_mean: ql = outputs["ql"] library = torch.exp(ql.loc + 0.5 * (ql.scale**2)) else: From 7d9be4289230c27245552f0b78e500a8ee7d9f8e Mon Sep 17 00:00:00 2001 From: Pierre Boyeau Date: Tue, 15 Feb 2022 16:25:09 -0800 Subject: [PATCH 10/24] dummy --- scvi/module/_autozivae.py | 1 - 1 file changed, 1 deletion(-) diff --git a/scvi/module/_autozivae.py b/scvi/module/_autozivae.py index 5f5a0175c3..a1b575c7a6 100644 --- a/scvi/module/_autozivae.py +++ b/scvi/module/_autozivae.py @@ -372,7 +372,6 @@ def loss( bernoulli_params = generative_outputs["bernoulli_params"] x = tensors[REGISTRY_KEYS.X_KEY] batch_index = tensors[REGISTRY_KEYS.BATCH_KEY] - # KL divergences wrt z_n,l_n mean = torch.zeros_like(qz.loc) scale = torch.ones_like(qz.scale) From 41d9b0f31ff9a864f49677cf5fe186b36ece79ff Mon Sep 17 00:00:00 2001 From: Pierre Boyeau Date: Thu, 17 Feb 2022 11:10:33 -0800 Subject: [PATCH 11/24] log prob sum --- scvi/module/_totalvae.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scvi/module/_totalvae.py b/scvi/module/_totalvae.py index 06523cb555..16ad1f8a11 100644 --- a/scvi/module/_totalvae.py +++ b/scvi/module/_totalvae.py @@ -692,7 +692,7 @@ def marginal_ll(self, tensors, n_mc_samples): reconst_loss_protein = reconst_loss["reconst_loss_protein"] # Log-probabilities - log_prob_sum = torch.zeros(ql.loc.shape[0]).to(self.device) + log_prob_sum = torch.zeros(qz.loc.shape[0]).to(self.device) if not self.use_observed_lib_size: n_batch = self.library_log_means.shape[1] From 486889881668aab1feff7c050532a0d276c63b6f Mon Sep 17 00:00:00 2001 From: Pierre Boyeau Date: Thu, 17 Feb 2022 11:24:36 -0800 Subject: [PATCH 12/24] fix conversion --- scvi/module/_amortizedlda.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scvi/module/_amortizedlda.py b/scvi/module/_amortizedlda.py index 849b7e0e42..10586d0d87 100644 --- a/scvi/module/_amortizedlda.py +++ b/scvi/module/_amortizedlda.py @@ -202,7 +202,7 @@ def forward( ), poutine.scale(None, kl_weight): cell_topic_posterior, _ = self.encoder(x) cell_topic_posterior_mu = cell_topic_posterior.loc - cell_topic_posterior_sigma = 2.0 * cell_topic_posterior.scale.log() + cell_topic_posterior_sigma = cell_topic_posterior.scale ** 2 pyro.sample( "log_cell_topic_dist", dist.Normal( From 6566c0ee5564af28bc3995dec09df98a574a9a2e Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 17 Feb 2022 19:25:06 +0000 Subject: [PATCH 13/24] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- scvi/module/_amortizedlda.py | 2 +- 1 file changed, 1 
insertion(+), 1 deletion(-) diff --git a/scvi/module/_amortizedlda.py b/scvi/module/_amortizedlda.py index 10586d0d87..9c5d56a1bc 100644 --- a/scvi/module/_amortizedlda.py +++ b/scvi/module/_amortizedlda.py @@ -202,7 +202,7 @@ def forward( ), poutine.scale(None, kl_weight): cell_topic_posterior, _ = self.encoder(x) cell_topic_posterior_mu = cell_topic_posterior.loc - cell_topic_posterior_sigma = cell_topic_posterior.scale ** 2 + cell_topic_posterior_sigma = cell_topic_posterior.scale**2 pyro.sample( "log_cell_topic_dist", dist.Normal( From 65df0aaafecc72faa09fe596a985b7e4b651b0b3 Mon Sep 17 00:00:00 2001 From: Pierre Boyeau Date: Thu, 17 Feb 2022 12:32:06 -0800 Subject: [PATCH 14/24] simplication log_prob_sum --- scvi/module/_vae.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/scvi/module/_vae.py b/scvi/module/_vae.py index 668e7320c3..9d2903df2d 100644 --- a/scvi/module/_vae.py +++ b/scvi/module/_vae.py @@ -516,18 +516,14 @@ def marginal_ll(self, tensors, n_mc_samples): reconst_loss = losses.reconstruction_loss # Log-probabilities - log_prob_sum = torch.zeros(qz.loc.shape[0]).to(self.device) - p_z = ( Normal(torch.zeros_like(qz.loc), torch.ones_like(qz.scale)) .log_prob(z) .sum(dim=-1) ) p_x_zl = -reconst_loss - log_prob_sum += p_z + p_x_zl - q_z_x = qz.log_prob(z).sum(dim=-1) - log_prob_sum -= q_z_x + log_prob_sum = p_z + p_x_zl - q_z_x if not self.use_observed_lib_size: ( From de172e728b31ddedcd5d8d88f66973d1198d4dcd Mon Sep 17 00:00:00 2001 From: Pierre Boyeau Date: Fri, 18 Feb 2022 11:49:11 -0800 Subject: [PATCH 15/24] rename px --- scvi/module/_vae.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/scvi/module/_vae.py b/scvi/module/_vae.py index 9d2903df2d..64b676bdfb 100644 --- a/scvi/module/_vae.py +++ b/scvi/module/_vae.py @@ -354,15 +354,15 @@ def generative( px_r = torch.exp(px_r) if self.gene_likelihood == "zinb": - px_latents = ZeroInflatedNegativeBinomial( + px = ZeroInflatedNegativeBinomial( mu=px_rate, theta=px_r, zi_logits=px_dropout ) elif self.gene_likelihood == "nb": - px_latents = NegativeBinomial(mu=px_rate, theta=px_r) + px = NegativeBinomial(mu=px_rate, theta=px_r) elif self.gene_likelihood == "poisson": - px_latents = Poisson(px_rate) + px = Poisson(px_rate) return dict( - px_latents=px_latents, + px=px, px_scale=px_scale, px_r=px_r, px_rate=px_rate, @@ -400,7 +400,7 @@ def loss( else: kl_divergence_l = 0.0 - reconst_loss = -generative_outputs["px_latents"].log_prob(x).sum(-1) + reconst_loss = -generative_outputs["px"].log_prob(x).sum(-1) kl_local_for_warmup = kl_divergence_z kl_local_no_warmup = kl_divergence_l From ea225a5034ef9d40f97dc4c16f59d082a14a6602 Mon Sep 17 00:00:00 2001 From: Pierre Boyeau Date: Fri, 18 Feb 2022 12:18:51 -0800 Subject: [PATCH 16/24] px --- scvi/module/_scanvae.py | 7 ++----- scvi/module/_vae.py | 37 ++++--------------------------------- 2 files changed, 6 insertions(+), 38 deletions(-) diff --git a/scvi/module/_scanvae.py b/scvi/module/_scanvae.py index 1e195e14b0..5e4133677b 100644 --- a/scvi/module/_scanvae.py +++ b/scvi/module/_scanvae.py @@ -225,9 +225,7 @@ def loss( labelled_tensors=None, classification_ratio=None, ): - px_r = generative_ouputs["px_r"] - px_rate = generative_ouputs["px_rate"] - px_dropout = generative_ouputs["px_dropout"] + px = generative_ouputs["px"] qz1 = inference_outputs["qz"] z1 = inference_outputs["z"] x = tensors[REGISTRY_KEYS.X_KEY] @@ -243,8 +241,7 @@ def loss( ys, z1s = broadcast_labels(y, z1, n_broadcast=self.n_labels) qz2, z2 = 
self.encoder_z2_z1(z1s, ys) pz1_m, pz1_v = self.decoder_z1_z2(z2, ys) - - reconst_loss = self.get_reconstruction_loss(x, px_rate, px_r, px_dropout) + reconst_loss = -px.log_prob(x).sum(-1) # KL Divergence mean = torch.zeros_like(qz2.loc) diff --git a/scvi/module/_vae.py b/scvi/module/_vae.py index 64b676bdfb..245d3ed1fe 100644 --- a/scvi/module/_vae.py +++ b/scvi/module/_vae.py @@ -101,7 +101,7 @@ def __init__( dropout_rate: float = 0.1, dispersion: str = "gene", log_variational: bool = True, - gene_likelihood: str = "zinb", + gene_likelihood: Literal["zinb", "nb", "poisson"] = "zinb", latent_distribution: str = "normal", encode_covariates: bool = False, deeply_inject_covariates: bool = True, @@ -442,34 +442,21 @@ def sample( tensor with shape (n_cells, n_genes, n_samples) """ inference_kwargs = dict(n_samples=n_samples) - inference_outputs, generative_outputs, = self.forward( + _, generative_outputs, = self.forward( tensors, inference_kwargs=inference_kwargs, compute_loss=False, ) - px_r = generative_outputs["px_r"] px_rate = generative_outputs["px_rate"] - px_dropout = generative_outputs["px_dropout"] + dist = generative_outputs["px"] if self.gene_likelihood == "poisson": l_train = px_rate l_train = torch.clamp(l_train, max=1e8) dist = torch.distributions.Poisson( l_train ) # Shape : (n_samples, n_cells_batch, n_genes) - elif self.gene_likelihood == "nb": - dist = NegativeBinomial(mu=px_rate, theta=px_r) - elif self.gene_likelihood == "zinb": - dist = ZeroInflatedNegativeBinomial( - mu=px_rate, theta=px_r, zi_logits=px_dropout - ) - else: - raise ValueError( - "{} reconstruction error not handled right now".format( - self.module.gene_likelihood - ) - ) if n_samples > 1: exprs = dist.sample().permute( [1, 2, 0] @@ -479,23 +466,6 @@ def sample( return exprs.cpu() - def get_reconstruction_loss(self, x, px_rate, px_r, px_dropout) -> torch.Tensor: - if self.gene_likelihood == "zinb": - reconst_loss = ( - -ZeroInflatedNegativeBinomial( - mu=px_rate, theta=px_r, zi_logits=px_dropout - ) - .log_prob(x) - .sum(dim=-1) - ) - elif self.gene_likelihood == "nb": - reconst_loss = ( - -NegativeBinomial(mu=px_rate, theta=px_r).log_prob(x).sum(dim=-1) - ) - elif self.gene_likelihood == "poisson": - reconst_loss = -Poisson(px_rate).log_prob(x).sum(dim=-1) - return reconst_loss - @torch.no_grad() @auto_move_data def marginal_ll(self, tensors, n_mc_samples): @@ -592,6 +562,7 @@ class LDVAE(VAE): * ``'nb'`` - Negative binomial distribution * ``'zinb'`` - Zero-inflated negative binomial distribution + * ``'poisson'`` - Poisson distribution use_batch_norm Bool whether to use batch norm in decoder bias From fbae55f3e3951ae064e12679ce2a63e827dc5d4c Mon Sep 17 00:00:00 2001 From: Pierre Boyeau Date: Fri, 18 Feb 2022 15:53:51 -0800 Subject: [PATCH 17/24] Optional return_dist --- scvi/external/gimvi/_module.py | 2 ++ scvi/module/_amortizedlda.py | 2 +- scvi/module/_multivae.py | 2 ++ scvi/module/_peakvae.py | 1 + scvi/module/_scanvae.py | 1 + scvi/module/_vae.py | 4 ++++ scvi/module/_vaec.py | 1 + scvi/nn/_base_components.py | 12 ++++++++++-- tests/models/test_pyro.py | 1 + 9 files changed, 23 insertions(+), 3 deletions(-) diff --git a/scvi/external/gimvi/_module.py b/scvi/external/gimvi/_module.py index 484036b206..a18ca693fb 100644 --- a/scvi/external/gimvi/_module.py +++ b/scvi/external/gimvi/_module.py @@ -134,6 +134,7 @@ def __init__( n_layers_individual=n_layers_encoder_individual, n_layers_shared=n_layers_encoder_shared, dropout_rate=dropout_rate_encoder, + return_dist=True, ) self.l_encoders = ModuleList( 
@@ -143,6 +144,7 @@ def __init__( 1, n_layers=1, dropout_rate=dropout_rate_encoder, + return_dist=True, ) if self.model_library_bools[i] else None diff --git a/scvi/module/_amortizedlda.py b/scvi/module/_amortizedlda.py index 9c5d56a1bc..34a00de1e2 100644 --- a/scvi/module/_amortizedlda.py +++ b/scvi/module/_amortizedlda.py @@ -162,7 +162,7 @@ def __init__(self, n_input: int, n_topics: int, n_hidden: int): # Populated by PyroTrainingPlan. self.n_obs = None - self.encoder = Encoder(n_input, n_topics, distribution="ln") + self.encoder = Encoder(n_input, n_topics, distribution="ln", return_dist=True) ( topic_feature_posterior_mu, topic_feature_posterior_sigma, diff --git a/scvi/module/_multivae.py b/scvi/module/_multivae.py index ffbf540cce..4372529a33 100644 --- a/scvi/module/_multivae.py +++ b/scvi/module/_multivae.py @@ -180,6 +180,7 @@ def __init__( var_eps=0, use_batch_norm=self.use_batch_norm_encoder, use_layer_norm=self.use_layer_norm_encoder, + return_dist=True, ) ## expression encoder @@ -195,6 +196,7 @@ def __init__( var_eps=0, use_batch_norm=self.use_batch_norm_encoder, use_layer_norm=self.use_layer_norm_encoder, + return_dist=True, ) # expression decoder diff --git a/scvi/module/_peakvae.py b/scvi/module/_peakvae.py index 078a651141..52e76f9411 100644 --- a/scvi/module/_peakvae.py +++ b/scvi/module/_peakvae.py @@ -185,6 +185,7 @@ def __init__( var_eps=0, use_batch_norm=self.use_batch_norm_encoder, use_layer_norm=self.use_layer_norm_encoder, + return_dist=True, ) self.z_decoder = Decoder( diff --git a/scvi/module/_scanvae.py b/scvi/module/_scanvae.py index 5e4133677b..4e50ea98c5 100644 --- a/scvi/module/_scanvae.py +++ b/scvi/module/_scanvae.py @@ -140,6 +140,7 @@ def __init__( dropout_rate=dropout_rate, use_batch_norm=use_batch_norm_encoder, use_layer_norm=use_layer_norm_encoder, + return_dist=True, ) self.decoder_z1_z2 = Decoder( n_latent, diff --git a/scvi/module/_vae.py b/scvi/module/_vae.py index 245d3ed1fe..74bb4dcf46 100644 --- a/scvi/module/_vae.py +++ b/scvi/module/_vae.py @@ -177,6 +177,7 @@ def __init__( use_batch_norm=use_batch_norm_encoder, use_layer_norm=use_layer_norm_encoder, var_activation=var_activation, + return_dist=True, ) # l encoder goes from n_input-dimensional data to 1-d library size self.l_encoder = Encoder( @@ -190,6 +191,7 @@ def __init__( use_batch_norm=use_batch_norm_encoder, use_layer_norm=use_layer_norm_encoder, var_activation=var_activation, + return_dist=True, ) # decoder goes from n_latent-dimensional space to n_input-d data n_input_decoder = n_latent + n_continuous_cov @@ -611,6 +613,7 @@ def __init__( distribution=latent_distribution, use_batch_norm=True, use_layer_norm=False, + return_dist=True, ) self.l_encoder = Encoder( n_input, @@ -620,6 +623,7 @@ def __init__( dropout_rate=dropout_rate, use_batch_norm=True, use_layer_norm=False, + return_dist=True, ) self.decoder = LinearDecoderSCVI( n_latent, diff --git a/scvi/module/_vaec.py b/scvi/module/_vaec.py index c6f6f3b0b9..e3072eaf56 100644 --- a/scvi/module/_vaec.py +++ b/scvi/module/_vaec.py @@ -75,6 +75,7 @@ def __init__( inject_covariates=True, use_batch_norm=False, use_layer_norm=True, + return_dist=True, ) # decoder goes from n_latent-dimensional space to n_input-d data diff --git a/scvi/nn/_base_components.py b/scvi/nn/_base_components.py index 9b8b9035ca..90d63b886b 100644 --- a/scvi/nn/_base_components.py +++ b/scvi/nn/_base_components.py @@ -244,6 +244,7 @@ def __init__( distribution: str = "normal", var_eps: float = 1e-4, var_activation: Optional[Callable] = None, + return_dist: bool 
= False, **kwargs, ): super().__init__() @@ -261,6 +262,7 @@ def __init__( ) self.mean_encoder = nn.Linear(n_hidden, n_output) self.var_encoder = nn.Linear(n_hidden, n_output) + self.return_dist = return_dist if distribution == "ln": self.z_transformation = nn.Softmax(dim=-1) @@ -295,7 +297,9 @@ def forward(self, x: torch.Tensor, *cat_list: int): q_v = self.var_activation(self.var_encoder(q)) + self.var_eps dist = Normal(q_m, q_v.sqrt()) latent = self.z_transformation(dist.rsample()) - return dist, latent + if self.return_dist: + return dist, latent + return q_m, q_v, latent # Decoder @@ -559,6 +563,7 @@ def __init__( n_layers_shared: int = 2, n_cat_list: Iterable[int] = None, dropout_rate: float = 0.1, + return_dist: bool = False, ): super().__init__() @@ -588,6 +593,7 @@ def __init__( self.mean_encoder = nn.Linear(n_hidden, n_output) self.var_encoder = nn.Linear(n_hidden, n_output) + self.return_dist = return_dist def forward(self, x: torch.Tensor, head_id: int, *cat_list: int): q = self.encoders[head_id](x, *cat_list) @@ -597,7 +603,9 @@ def forward(self, x: torch.Tensor, head_id: int, *cat_list: int): q_v = torch.exp(self.var_encoder(q)) dist = Normal(q_m, q_v.sqrt()) latent = dist.rsample() - return dist, latent + if self.return_dist: + return dist, latent + return q_m, q_v, latent class MultiDecoder(nn.Module): diff --git a/tests/models/test_pyro.py b/tests/models/test_pyro.py index 1c5a89a646..094b7f3cf1 100644 --- a/tests/models/test_pyro.py +++ b/tests/models/test_pyro.py @@ -413,6 +413,7 @@ def __init__(self, n_input: int, n_latent: int, n_hidden: int, n_layers: int): n_layers=n_layers, n_hidden=n_hidden, dropout_rate=0.1, + return_dist=True, ) # decoder goes from n_latent-dimensional space to n_input-d data self.decoder = DecoderSCVI( From fedf8055ff82750737a041ac40e9063b3431507d Mon Sep 17 00:00:00 2001 From: Pierre Boyeau Date: Mon, 4 Apr 2022 10:16:56 -0700 Subject: [PATCH 18/24] feat generative refactor --- scvi/distributions/__init__.py | 2 + scvi/distributions/_negative_binomial.py | 27 +++++++++-- scvi/external/gimvi/_module.py | 6 +-- scvi/model/_autozi.py | 6 +-- scvi/model/base/_rnamixin.py | 16 +++---- scvi/module/_autozivae.py | 16 ++++--- scvi/module/_vae.py | 57 ++++++++++++------------ 7 files changed, 79 insertions(+), 51 deletions(-) diff --git a/scvi/distributions/__init__.py b/scvi/distributions/__init__.py index 884550bb3a..bea07334ca 100644 --- a/scvi/distributions/__init__.py +++ b/scvi/distributions/__init__.py @@ -2,6 +2,7 @@ JaxNegativeBinomialMeanDisp, NegativeBinomial, NegativeBinomialMixture, + Poisson, ZeroInflatedNegativeBinomial, ) @@ -10,4 +11,5 @@ "NegativeBinomialMixture", "ZeroInflatedNegativeBinomial", "JaxNegativeBinomialMeanDisp", + "Poisson", ] diff --git a/scvi/distributions/_negative_binomial.py b/scvi/distributions/_negative_binomial.py index 21b2d47958..8319cba7d1 100644 --- a/scvi/distributions/_negative_binomial.py +++ b/scvi/distributions/_negative_binomial.py @@ -8,7 +8,9 @@ import torch.nn.functional as F from numpyro.distributions import constraints as numpyro_constraints from numpyro.distributions.util import promote_shapes, validate_sample -from torch.distributions import Distribution, Gamma, Poisson, constraints +from torch.distributions import Distribution, Gamma +from torch.distributions import Poisson as PoissonTorch +from torch.distributions import constraints from torch.distributions.utils import ( broadcast_all, lazy_property, @@ -236,6 +238,21 @@ def _gamma(theta, mu): return gamma_d +class Poisson(PoissonTorch): + r""" 
+ Poisson distribution. + + Parameters + ---------- + rate + rate of the Poisson distribution. + """ + + def __init__(self, rate, validate_args=None, scale: Optional[torch.Tensor] = None): + super().__init__(rate=rate, validate_args=validate_args) + self.scale = scale + + class NegativeBinomial(Distribution): r""" Negative binomial distribution. @@ -279,6 +296,7 @@ def __init__( logits: Optional[torch.Tensor] = None, mu: Optional[torch.Tensor] = None, theta: Optional[torch.Tensor] = None, + scale: Optional[torch.Tensor] = None, validate_args: bool = False, ): self._eps = 1e-8 @@ -299,6 +317,7 @@ def __init__( mu, theta = broadcast_all(mu, theta) self.mu = mu self.theta = theta + self.scale = scale super().__init__(validate_args=validate_args) @property @@ -319,7 +338,7 @@ def sample( # Clamping as distributions objects can have buggy behaviors when # their parameters are too high l_train = torch.clamp(p_means, max=1e8) - counts = Poisson( + counts = PoissonTorch( l_train ).sample() # Shape : (n_samples, n_cells_batch, n_vars) return counts @@ -388,6 +407,7 @@ def __init__( mu: Optional[torch.Tensor] = None, theta: Optional[torch.Tensor] = None, zi_logits: Optional[torch.Tensor] = None, + scale: Optional[torch.Tensor] = None, validate_args: bool = False, ): @@ -397,6 +417,7 @@ def __init__( logits=logits, mu=mu, theta=theta, + scale=scale, validate_args=validate_args, ) self.zi_logits, self.mu, self.theta = broadcast_all( @@ -522,7 +543,7 @@ def sample( # Clamping as distributions objects can have buggy behaviors when # their parameters are too high l_train = torch.clamp(p_means, max=1e8) - counts = Poisson( + counts = PoissonTorch( l_train ).sample() # Shape : (n_samples, n_cells_batch, n_features) return counts diff --git a/scvi/external/gimvi/_module.py b/scvi/external/gimvi/_module.py index a18ca693fb..f55034c9b2 100644 --- a/scvi/external/gimvi/_module.py +++ b/scvi/external/gimvi/_module.py @@ -451,9 +451,9 @@ def loss( qz = inference_outputs["qz"] ql = inference_outputs["ql"] - px_rate = generative_outputs["px_rate"] - px_r = generative_outputs["px_r"] - px_dropout = generative_outputs["px_dropout"] + px_rate = generative_outputs["px"].mu + px_r = generative_outputs["px"].theta + px_dropout = generative_outputs["px"].zi_logits # mask loss to observed genes mapping_indices = self.indices_mappings[mode] diff --git a/scvi/model/_autozi.py b/scvi/model/_autozi.py index 32d0a82552..976c0195b3 100644 --- a/scvi/model/_autozi.py +++ b/scvi/model/_autozi.py @@ -212,9 +212,9 @@ def get_marginal_ll( # Distribution parameters and sampled variables inf_outputs, gen_outputs, _ = self.module.forward(tensors) - px_r = gen_outputs["px_r"] - px_rate = gen_outputs["px_rate"] - px_dropout = gen_outputs["px_dropout"] + px_r = gen_outputs["px"].theta + px_rate = gen_outputs["px"].mu + px_dropout = gen_outputs["px"].zi_logits qz = inf_outputs["qz"] z = inf_outputs["z"] diff --git a/scvi/model/base/_rnamixin.py b/scvi/model/base/_rnamixin.py index 54d0fca9ea..01937fce1e 100644 --- a/scvi/model/base/_rnamixin.py +++ b/scvi/model/base/_rnamixin.py @@ -114,10 +114,10 @@ def get_normalized_expression( ) return_numpy = True if library_size == "latent": - generative_output_key = "px_rate" + generative_output_key = "rate" scaling = 1 else: - generative_output_key = "px_scale" + generative_output_key = "scale" scaling = library_size exprs = [] @@ -132,7 +132,7 @@ def get_normalized_expression( generative_kwargs=generative_kwargs, compute_loss=False, ) - output = generative_outputs[generative_output_key] + output = 
getattr(generative_outputs["px"], generative_output_key) output = output[..., gene_mask] output *= scaling output = output.cpu().numpy() @@ -338,8 +338,8 @@ def _get_denoised_samples( generative_kwargs=generative_kwargs, compute_loss=False, ) - px_scale = generative_outputs["px_scale"] - px_r = generative_outputs["px_r"] + px_scale = generative_outputs["px"].scale + px_r = generative_outputs["px"].theta device = px_r.device rate = rna_size_factor * px_scale @@ -483,9 +483,9 @@ def get_likelihood_parameters( inference_kwargs=inference_kwargs, compute_loss=False, ) - px_r = generative_outputs["px_r"] - px_rate = generative_outputs["px_rate"] - px_dropout = generative_outputs["px_dropout"] + px_r = generative_outputs["px"].theta + px_rate = generative_outputs["px"].mu + px_dropout = generative_outputs["px"].zi_probs n_batch = px_rate.size(0) if n_samples == 1 else px_rate.size(1) diff --git a/scvi/module/_autozivae.py b/scvi/module/_autozivae.py index a1b575c7a6..4c6fef71a2 100644 --- a/scvi/module/_autozivae.py +++ b/scvi/module/_autozivae.py @@ -302,8 +302,14 @@ def generative( size_factor=size_factor, ) # Rescale dropout - outputs["px_dropout"] = self.rescale_dropout( - outputs["px_dropout"], eps_log=eps_log + rescaled_dropout = self.rescale_dropout( + outputs["px"].zi_logits, eps_log=eps_log + ) + outputs["px"] = ZeroInflatedNegativeBinomial( + mu=outputs["px"].mu, + theta=outputs["px"].theta, + zi_logits=rescaled_dropout, + scale=outputs["px"].scale, ) # Bernoulli parameters @@ -366,9 +372,9 @@ def loss( ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: # Parameters for z latent distribution qz = inference_outputs["qz"] - px_rate = generative_outputs["px_rate"] - px_r = generative_outputs["px_r"] - px_dropout = generative_outputs["px_dropout"] + px_rate = generative_outputs["px"].mu + px_r = generative_outputs["px"].theta + px_dropout = generative_outputs["px"].zi_logits bernoulli_params = generative_outputs["bernoulli_params"] x = tensors[REGISTRY_KEYS.X_KEY] batch_index = tensors[REGISTRY_KEYS.BATCH_KEY] diff --git a/scvi/module/_vae.py b/scvi/module/_vae.py index d0ff5aeabd..ab92fd35f1 100644 --- a/scvi/module/_vae.py +++ b/scvi/module/_vae.py @@ -6,12 +6,12 @@ import torch import torch.nn.functional as F from torch import logsumexp -from torch.distributions import Normal, Poisson +from torch.distributions import Normal from torch.distributions import kl_divergence as kl from scvi import REGISTRY_KEYS from scvi._compat import Literal -from scvi.distributions import NegativeBinomial, ZeroInflatedNegativeBinomial +from scvi.distributions import NegativeBinomial, Poisson, ZeroInflatedNegativeBinomial from scvi.module.base import BaseModuleClass, LossRecorder, auto_move_data from scvi.nn import DecoderSCVI, Encoder, LinearDecoderSCVI, one_hot @@ -324,6 +324,7 @@ def generative( ): """Runs the generative model.""" # TODO: refactor forward function to not rely on y + # Likelihood distribution decoder_input = z if cont_covs is None else torch.cat([z, cont_covs], dim=-1) if cat_covs is not None: categorical_input = torch.split(cat_covs, 1, dim=1) @@ -357,18 +358,30 @@ def generative( if self.gene_likelihood == "zinb": px = ZeroInflatedNegativeBinomial( - mu=px_rate, theta=px_r, zi_logits=px_dropout + mu=px_rate, + theta=px_r, + zi_logits=px_dropout, + scale=px_scale, ) elif self.gene_likelihood == "nb": - px = NegativeBinomial(mu=px_rate, theta=px_r) + px = NegativeBinomial(mu=px_rate, theta=px_r, scale=px_scale) elif self.gene_likelihood == "poisson": - px = Poisson(px_rate) + px = 
Poisson(px_rate, scale=px_scale) + + # Priors + if self.use_observed_lib_size: + pl = None + else: + ( + local_library_log_means, + local_library_log_vars, + ) = self._compute_local_library_params(batch_index) + pl = Normal(local_library_log_means, local_library_log_vars.sqrt()) + pz = Normal(torch.zeros_like(z), torch.ones_like(z)) return dict( px=px, - px_scale=px_scale, - px_r=px_r, - px_rate=px_rate, - px_dropout=px_dropout, + pl=pl, + pz=pz, ) def loss( @@ -379,25 +392,13 @@ def loss( kl_weight: float = 1.0, ): x = tensors[REGISTRY_KEYS.X_KEY] - batch_index = tensors[REGISTRY_KEYS.BATCH_KEY] - - qz = inference_outputs["qz"] - - mean = torch.zeros_like(qz.loc) - scale = torch.ones_like(qz.scale) - - kl_divergence_z = kl(qz, Normal(mean, scale)).sum(dim=1) - + kl_divergence_z = kl(inference_outputs["qz"], generative_outputs["pz"]).sum( + dim=1 + ) if not self.use_observed_lib_size: - ql = inference_outputs["ql"] - ( - local_library_log_means, - local_library_log_vars, - ) = self._compute_local_library_params(batch_index) - kl_divergence_l = kl( - ql, - Normal(local_library_log_means, local_library_log_vars.sqrt()), + inference_outputs["ql"], + generative_outputs["pl"], ).sum(dim=1) else: kl_divergence_l = 0.0 @@ -450,11 +451,9 @@ def sample( compute_loss=False, ) - px_rate = generative_outputs["px_rate"] - dist = generative_outputs["px"] if self.gene_likelihood == "poisson": - l_train = px_rate + l_train = generative_outputs["px"].mu l_train = torch.clamp(l_train, max=1e8) dist = torch.distributions.Poisson( l_train From 5f48f0e6cb76aaabe2366ec32274dd532348db09 Mon Sep 17 00:00:00 2001 From: Pierre Boyeau Date: Mon, 4 Apr 2022 10:40:26 -0700 Subject: [PATCH 19/24] updates --- scvi/external/gimvi/_module.py | 6 +++--- scvi/model/_autozi.py | 15 ++------------- scvi/model/base/_rnamixin.py | 2 +- scvi/module/_vaec.py | 9 ++++----- 4 files changed, 10 insertions(+), 22 deletions(-) diff --git a/scvi/external/gimvi/_module.py b/scvi/external/gimvi/_module.py index f55034c9b2..a18ca693fb 100644 --- a/scvi/external/gimvi/_module.py +++ b/scvi/external/gimvi/_module.py @@ -451,9 +451,9 @@ def loss( qz = inference_outputs["qz"] ql = inference_outputs["ql"] - px_rate = generative_outputs["px"].mu - px_r = generative_outputs["px"].theta - px_dropout = generative_outputs["px"].zi_logits + px_rate = generative_outputs["px_rate"] + px_r = generative_outputs["px_r"] + px_dropout = generative_outputs["px_dropout"] # mask loss to observed genes mapping_indices = self.indices_mappings[mode] diff --git a/scvi/model/_autozi.py b/scvi/model/_autozi.py index 976c0195b3..075d08066c 100644 --- a/scvi/model/_autozi.py +++ b/scvi/model/_autozi.py @@ -233,11 +233,7 @@ def get_marginal_ll( ) # Log-probabilities - p_z = ( - Normal(torch.zeros_like(qz.loc), torch.ones_like(qz.scale)) - .log_prob(z) - .sum(dim=-1) - ) + p_z = gen_outputs["pz"].log_prob(z).sum(dim=-1) p_x_zld = -reconst_loss q_z_x = qz.log_prob(z).sum(dim=-1) log_prob_sum = p_z + p_x_zld - q_z_x @@ -250,14 +246,7 @@ def get_marginal_ll( local_library_log_vars, ) = self.module._compute_local_library_params(batch_index) - p_l = ( - Normal( - local_library_log_means.to(self.device), - local_library_log_vars.to(self.device).sqrt(), - ) - .log_prob(library) - .sum(dim=-1) - ) + p_l = gen_outputs["pl"].log_prob(library).sum(dim=-1) q_l_x = ql.log_prob(library).sum(dim=-1) diff --git a/scvi/model/base/_rnamixin.py b/scvi/model/base/_rnamixin.py index 01937fce1e..501d9cbb00 100644 --- a/scvi/model/base/_rnamixin.py +++ b/scvi/model/base/_rnamixin.py @@ 
-114,7 +114,7 @@ def get_normalized_expression( ) return_numpy = True if library_size == "latent": - generative_output_key = "rate" + generative_output_key = "mu" scaling = 1 else: generative_output_key = "scale" diff --git a/scvi/module/_vaec.py b/scvi/module/_vaec.py index e3072eaf56..0c568b9279 100644 --- a/scvi/module/_vaec.py +++ b/scvi/module/_vaec.py @@ -152,8 +152,8 @@ def generative(self, z, library, y): h = self.decoder(z, y) px_scale = self.px_decoder(h) px_rate = library * px_scale - - return dict(px_scale=px_scale, px_r=self.px_r, px_rate=px_rate) + px = NegativeBinomial(px_rate, logits=self.px_r) + return dict(px=px) def loss( self, @@ -165,15 +165,14 @@ def loss( x = tensors[REGISTRY_KEYS.X_KEY] y = tensors[REGISTRY_KEYS.LABELS_KEY] qz = inference_outputs["qz"] - px_rate = generative_outputs["px_rate"] - px_r = generative_outputs["px_r"] + px = generative_outputs["px"] mean = torch.zeros_like(qz.loc) scale = torch.ones_like(qz.scale) kl_divergence_z = kl(qz, Normal(mean, scale)).sum(dim=1) - reconst_loss = -NegativeBinomial(px_rate, logits=px_r).log_prob(x).sum(-1) + reconst_loss = -px.log_prob(x).sum(-1) scaling_factor = self.ct_weight[y.long()[:, 0]] loss = torch.mean(scaling_factor * (reconst_loss + kl_weight * kl_divergence_z)) From dd2cd8bb0fe71a835ee1a2fecb68e83e07af00e8 Mon Sep 17 00:00:00 2001 From: Pierre Boyeau Date: Mon, 4 Apr 2022 10:43:38 -0700 Subject: [PATCH 20/24] precommit --- scvi/model/_autozi.py | 2 +- scvi/model/_totalvi.py | 1 + scvi/model/base/_rnamixin.py | 1 + scvi/module/_scanvae.py | 1 + scvi/module/_totalvae.py | 1 + tests/models/test_pyro.py | 1 + 6 files changed, 6 insertions(+), 1 deletion(-) diff --git a/scvi/model/_autozi.py b/scvi/model/_autozi.py index 075d08066c..3313c9dcbb 100644 --- a/scvi/model/_autozi.py +++ b/scvi/model/_autozi.py @@ -5,7 +5,7 @@ import torch from anndata import AnnData from torch import logsumexp -from torch.distributions import Beta, Normal +from torch.distributions import Beta from scvi import REGISTRY_KEYS from scvi._compat import Literal diff --git a/scvi/model/_totalvi.py b/scvi/model/_totalvi.py index b5e0fcae15..5cb5645757 100644 --- a/scvi/model/_totalvi.py +++ b/scvi/model/_totalvi.py @@ -335,6 +335,7 @@ def get_latent_library_size( post = self._make_data_loader( adata=adata, indices=indices, batch_size=batch_size ) + libraries = [] for tensors in post: inference_inputs = self.module._get_inference_input(tensors) diff --git a/scvi/model/base/_rnamixin.py b/scvi/model/base/_rnamixin.py index 501d9cbb00..205e5d7d21 100644 --- a/scvi/model/base/_rnamixin.py +++ b/scvi/model/base/_rnamixin.py @@ -557,6 +557,7 @@ def get_latent_library_size( else: ql = outputs["ql"] if ql is None: + raise RuntimeError( "The module for this model does not compute the posterior distribution " "for the library size. Set `give_mean` to False to use the observed library size instead." 
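Taken together, the hunks above converge on one pattern: downstream code consumes the single `px` distribution object returned by `generative()` (reading parameters such as `mu`, `theta`, or `scale` as attributes, and computing reconstruction losses via `log_prob`) instead of unpacking separate `px_rate`/`px_r`/`px_dropout` tensors. The following is a minimal standalone sketch of that pattern, not part of the patches themselves; it assumes an scvi-tools build that includes these changes (in particular the `scale` keyword on the distributions), and the tensor shapes and values are invented purely for illustration.

import torch
from scvi.distributions import NegativeBinomial

# Toy decoder outputs for 4 cells x 3 genes (arbitrary values).
mu = torch.rand(4, 3) * 10 + 1.0        # mean of the negative binomial likelihood
theta = 2.0 * torch.ones(3)             # inverse dispersion, one value per gene
scale = mu / mu.sum(-1, keepdim=True)   # normalized expression, stored on the distribution but not used in computations

# After the refactor, generative() returns the likelihood as a Distribution object.
px = NegativeBinomial(mu=mu, theta=theta, scale=scale)

# Reconstruction loss is taken directly from the distribution.
x = torch.poisson(mu)                   # stand-in for observed counts
reconst_loss = -px.log_prob(x).sum(-1)  # one value per cell

# Consumers such as get_normalized_expression pick a parameter by name.
normalized = getattr(px, "scale")       # or "mu" when library_size == "latent"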
diff --git a/scvi/module/_scanvae.py b/scvi/module/_scanvae.py index 080edd26e5..b5603a9509 100644 --- a/scvi/module/_scanvae.py +++ b/scvi/module/_scanvae.py @@ -142,6 +142,7 @@ def __init__( use_layer_norm=use_layer_norm_encoder, return_dist=True, ) + self.decoder_z1_z2 = Decoder( n_latent, n_latent, diff --git a/scvi/module/_totalvae.py b/scvi/module/_totalvae.py index 16ad1f8a11..8beb7ad08a 100644 --- a/scvi/module/_totalvae.py +++ b/scvi/module/_totalvae.py @@ -483,6 +483,7 @@ def inference( qz, ql, latent, untran_latent = self.encoder( encoder_input, batch_index, *categorical_input ) + z = latent["z"] untran_z = untran_latent["z"] untran_l = untran_latent["l"] diff --git a/tests/models/test_pyro.py b/tests/models/test_pyro.py index 701277405d..5864e14024 100644 --- a/tests/models/test_pyro.py +++ b/tests/models/test_pyro.py @@ -414,6 +414,7 @@ def __init__(self, n_input: int, n_latent: int, n_hidden: int, n_layers: int): dropout_rate=0.1, return_dist=True, ) + # decoder goes from n_latent-dimensional space to n_input-d data self.decoder = DecoderSCVI( n_latent, From af84491072357673c8bb62c9b7ce6d21e536cc07 Mon Sep 17 00:00:00 2001 From: Pierre Boyeau Date: Mon, 2 May 2022 18:44:21 -0700 Subject: [PATCH 21/24] docstring --- scvi/distributions/_negative_binomial.py | 6 +++++- scvi/nn/_base_components.py | 2 ++ 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/scvi/distributions/_negative_binomial.py b/scvi/distributions/_negative_binomial.py index 8319cba7d1..c8021bf45d 100644 --- a/scvi/distributions/_negative_binomial.py +++ b/scvi/distributions/_negative_binomial.py @@ -239,13 +239,17 @@ def _gamma(theta, mu): class Poisson(PoissonTorch): - r""" + """ Poisson distribution. Parameters ---------- rate rate of the Poisson distribution. + validate_args + whether to validate input. + scale + Normalized mean expression of the Poisson distribution. """ def __init__(self, rate, validate_args=None, scale: Optional[torch.Tensor] = None): diff --git a/scvi/nn/_base_components.py b/scvi/nn/_base_components.py index 90d63b886b..b47c6732cb 100644 --- a/scvi/nn/_base_components.py +++ b/scvi/nn/_base_components.py @@ -229,6 +229,8 @@ class Encoder(nn.Module): var_activation Callable used to ensure positivity of the variance. When `None`, defaults to `torch.exp`. + return_dist + If `True`, returns directly the distribution of z instead of its parameters. **kwargs Keyword args for :class:`~scvi.module._base.FCLayers` """ From 6e205ead50fb101989687fb52ffb8bac280b74f2 Mon Sep 17 00:00:00 2001 From: Pierre Boyeau Date: Mon, 2 May 2022 19:02:31 -0700 Subject: [PATCH 22/24] docstring --- scvi/distributions/_negative_binomial.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/scvi/distributions/_negative_binomial.py b/scvi/distributions/_negative_binomial.py index c8021bf45d..4eb74cefc8 100644 --- a/scvi/distributions/_negative_binomial.py +++ b/scvi/distributions/_negative_binomial.py @@ -249,7 +249,7 @@ class Poisson(PoissonTorch): validate_args whether to validate input. scale - Normalized mean expression of the Poisson distribution. + Normalized mean expression of the distribution. """ def __init__(self, rate, validate_args=None, scale: Optional[torch.Tensor] = None): @@ -283,6 +283,8 @@ class NegativeBinomial(Distribution): Mean of the distribution. theta Inverse dispersion. + scale + Normalized mean expression of the distribution. 
validate_args Raise ValueError if arguments do not match constraints """ @@ -391,6 +393,8 @@ class ZeroInflatedNegativeBinomial(NegativeBinomial): Inverse dispersion. zi_logits Logits scale of zero inflation probability. + scale + Normalized mean expression of the distribution. validate_args Raise ValueError if arguments do not match constraints """ From 63ee1494ca8307643feae37f2285f99742459fdd Mon Sep 17 00:00:00 2001 From: Pierre Boyeau Date: Thu, 12 May 2022 18:25:04 -0700 Subject: [PATCH 23/24] docstring --- scvi/distributions/_negative_binomial.py | 20 ++++++++++++++------ scvi/model/_autozi.py | 7 ++++--- scvi/model/base/_rnamixin.py | 7 ++++--- scvi/module/_autozivae.py | 7 ++++--- 4 files changed, 26 insertions(+), 15 deletions(-) diff --git a/scvi/distributions/_negative_binomial.py b/scvi/distributions/_negative_binomial.py index 4eb74cefc8..8a9b490b8d 100644 --- a/scvi/distributions/_negative_binomial.py +++ b/scvi/distributions/_negative_binomial.py @@ -246,13 +246,21 @@ class Poisson(PoissonTorch): ---------- rate rate of the Poisson distribution. - validate_args + validate_args : optional whether to validate input. - scale + scale : optional Normalized mean expression of the distribution. + This optional parameter is not used in any computations, but allows to store + normalization expression levels. + """ - def __init__(self, rate, validate_args=None, scale: Optional[torch.Tensor] = None): + def __init__( + self, + rate: torch.Tensor, + validate_args: Optional[bool] = None, + scale: Optional[torch.Tensor] = None, + ): super().__init__(rate=rate, validate_args=validate_args) self.scale = scale @@ -283,9 +291,9 @@ class NegativeBinomial(Distribution): Mean of the distribution. theta Inverse dispersion. - scale + scale : optional Normalized mean expression of the distribution. - validate_args + validate_args : optional Raise ValueError if arguments do not match constraints """ @@ -393,7 +401,7 @@ class ZeroInflatedNegativeBinomial(NegativeBinomial): Inverse dispersion. zi_logits Logits scale of zero inflation probability. - scale + scale : optional Normalized mean expression of the distribution. 
validate_args Raise ValueError if arguments do not match constraints diff --git a/scvi/model/_autozi.py b/scvi/model/_autozi.py index 3313c9dcbb..08edc5eeba 100644 --- a/scvi/model/_autozi.py +++ b/scvi/model/_autozi.py @@ -212,9 +212,10 @@ def get_marginal_ll( # Distribution parameters and sampled variables inf_outputs, gen_outputs, _ = self.module.forward(tensors) - px_r = gen_outputs["px"].theta - px_rate = gen_outputs["px"].mu - px_dropout = gen_outputs["px"].zi_logits + px = gen_outputs["px"] + px_r = px.theta + px_rate = px.mu + px_dropout = px.zi_logits qz = inf_outputs["qz"] z = inf_outputs["z"] diff --git a/scvi/model/base/_rnamixin.py b/scvi/model/base/_rnamixin.py index ab4a7de593..f070616601 100644 --- a/scvi/model/base/_rnamixin.py +++ b/scvi/model/base/_rnamixin.py @@ -483,9 +483,10 @@ def get_likelihood_parameters( inference_kwargs=inference_kwargs, compute_loss=False, ) - px_r = generative_outputs["px"].theta - px_rate = generative_outputs["px"].mu - px_dropout = generative_outputs["px"].zi_probs + px = generative_outputs["px"] + px_r = px.theta + px_rate = px.mu + px_dropout = px.zi_probs n_batch = px_rate.size(0) if n_samples == 1 else px_rate.size(1) diff --git a/scvi/module/_autozivae.py b/scvi/module/_autozivae.py index 4c6fef71a2..172a0053eb 100644 --- a/scvi/module/_autozivae.py +++ b/scvi/module/_autozivae.py @@ -372,9 +372,10 @@ def loss( ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: # Parameters for z latent distribution qz = inference_outputs["qz"] - px_rate = generative_outputs["px"].mu - px_r = generative_outputs["px"].theta - px_dropout = generative_outputs["px"].zi_logits + px = generative_outputs["px"] + px_rate = px.mu + px_r = px.theta + px_dropout = px.zi_logits bernoulli_params = generative_outputs["bernoulli_params"] x = tensors[REGISTRY_KEYS.X_KEY] batch_index = tensors[REGISTRY_KEYS.BATCH_KEY] From c581d145a1b7059bd78812c10243053947374338 Mon Sep 17 00:00:00 2001 From: Pierre Boyeau Date: Thu, 12 May 2022 18:33:30 -0700 Subject: [PATCH 24/24] release note --- docs/release_notes/v0.17.0.md | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/docs/release_notes/v0.17.0.md b/docs/release_notes/v0.17.0.md index 606fc8be9e..6fc7cd5fdc 100644 --- a/docs/release_notes/v0.17.0.md +++ b/docs/release_notes/v0.17.0.md @@ -3,6 +3,7 @@ ## Changes - Experimental MuData support for {class}`~scvi.model.TOTALVI` via the method {meth}`~scvi.model.TOTALVI.setup_mudata`. For several of the existing `AnnDataField` classes, there is now a MuData counterpart with an additional `mod_key` argument used to indicate the modality where the data lives (e.g. {class}`~scvi.data.fields.LayerField` to {class}`~scvi.data.fields.MuDataLayerField`). These modified classes are simply wrapped versions of the original `AnnDataField` code via the new {method}`scvi.data.fields.MuDataWrapper` method [#1474]. +- The `generative` method now returns the model's prior and likelihood as `torch.Distribution` objects rather than as separate parameter tensors. The affected modules are `_amortizedlda.py`, `_autozivae.py`, `_multivae.py`, `_peakvae.py`, `_scanvae.py`, `_vae.py`, and `_vaec.py`. This makes the distributions easier to manipulate during model training and inference [#1356].
## Breaking changes @@ -12,8 +13,11 @@ - [@jjhong922] - [@adamgayoso] +- [@PierreBoyeau] [#1474]: https://github.com/YosefLab/scvi-tools/pull/1474 +[#1356]: https://github.com/YosefLab/scvi-tools/pull/1356 [@jjhong922]: https://github.com/jjhong922 [@adamgayoso]: https://github.com/adamgayoso +[@pierreboyeau]: https://github.com/PierreBoyeau
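As the release note above describes, losses can now be written directly against the returned `Distribution` objects rather than against unpacked mean/variance tensors. The shape of the refactored KL terms in `VAE.loss` is roughly the following; this is a self-contained sketch using only `torch.distributions` with dummy tensors, not the scvi modules themselves, and the sizes below are arbitrary.

import torch
from torch.distributions import Normal, kl_divergence as kl

n_cells, n_latent = 8, 10

# Stand-ins for what inference() and generative() now return as Normal objects.
qz = Normal(torch.randn(n_cells, n_latent), torch.rand(n_cells, n_latent) + 0.1)  # variational posterior over z
pz = Normal(torch.zeros(n_cells, n_latent), torch.ones(n_cells, n_latent))        # standard normal prior over z
ql = Normal(torch.randn(n_cells, 1), torch.rand(n_cells, 1) + 0.1)                # posterior over the log library size
pl = Normal(torch.zeros(n_cells, 1), torch.ones(n_cells, 1))                      # e.g. a per-batch library-size prior

# KL terms are taken directly between distribution objects, with no manual re-parameterization.
kl_z = kl(qz, pz).sum(dim=1)  # shape: (n_cells,)
kl_l = kl(ql, pl).sum(dim=1)  # shape: (n_cells,)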