Skip to content

Commit

Permalink
Backport PR scverse#1702: Quick fix in poisson sample() function for vae
Browse files Browse the repository at this point in the history
  • Loading branch information
ricomnl authored and meeseeksmachine committed Sep 20, 2022
1 parent 3e2525b commit e490809
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 1 deletion.
2 changes: 1 addition & 1 deletion scvi/module/_vae.py
Original file line number Diff line number Diff line change
Expand Up @@ -461,7 +461,7 @@ def sample(

dist = generative_outputs["px"]
if self.gene_likelihood == "poisson":
l_train = generative_outputs["px"].mu
l_train = generative_outputs["px"].rate
l_train = torch.clamp(l_train, max=1e8)
dist = torch.distributions.Poisson(
l_train
Expand Down
8 changes: 8 additions & 0 deletions tests/models/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -362,6 +362,14 @@ def test_scvi(save_path):
model = SCVI(adata, gene_likelihood="nb")
model.get_likelihood_parameters()

# test different gene_likelihoods
for gene_likelihood in ["zinb", "nb", "poisson"]:
model = SCVI(adata, gene_likelihood=gene_likelihood)
model.train(1, check_val_every_n_epoch=1, train_size=0.5)
model.posterior_predictive_sample()
model.get_latent_representation()
model.get_normalized_expression()

# test train callbacks work
a = synthetic_iid()
SCVI.setup_anndata(
Expand Down

0 comments on commit e490809

Please sign in to comment.