Fix dimension mismatch with continuous covariates in generative (#1548)

* breaking test

* fix issue when cont covs dims do not align

* update release note
justjhong authored Jun 21, 2022
1 parent aabf14d commit f7673a6
Showing 6 changed files with 55 additions and 8 deletions.
2 changes: 2 additions & 0 deletions docs/release_notes/v0.17.0.md
@@ -11,6 +11,7 @@
 ## Breaking changes

 ## Bug Fixes
+- Fix issue with {method}`~scvi.model.SCVI.get_normalized_expression` with multiple samples and additional continuous covariates. This bug originated from {method}`~scvi.module.VAE.generative` failing to match the dimensions of the continuous covariates with the input when `n_samples>1` is passed to {method}`~scvi.module.VAE.inference`; the same issue affected multiple module classes [#1548].

 ## Contributors

@@ -21,6 +22,7 @@
 [#1356]: https://github.com/YosefLab/scvi-tools/pull/1356
 [#1474]: https://github.com/YosefLab/scvi-tools/pull/1474
 [#1542]: https://github.com/YosefLab/scvi-tools/pull/1542
+[#1548]: https://github.com/YosefLab/scvi-tools/pull/1548

 [@jjhong922]: https://github.com/jjhong922
 [@adamgayoso]: https://github.com/adamgayoso
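For context, a minimal sketch (not part of the commit) of the shape mismatch described in the release note above and of the broadcasting fix applied in each `generative` method in the files below. Tensor sizes are illustrative: with `n_samples > 1`, `z` carries a leading sample dimension that `cont_covs` lacks.

import torch

# Illustrative sizes: n_samples latent draws per cell, plus continuous covariates.
n_samples, n_cells, n_latent, n_cont = 2, 5, 10, 3
z = torch.randn(n_samples, n_cells, n_latent)   # (n_samples, n_cells, n_latent)
cont_covs = torch.randn(n_cells, n_cont)        # (n_cells, n_cont) -- one fewer dim

# torch.cat([z, cont_covs], dim=-1) would fail here; broadcast the covariates
# across the sample dimension first, as the fix below does.
if z.dim() != cont_covs.dim():
    cont_covs = cont_covs.unsqueeze(0).expand(z.size(0), -1, -1)
decoder_input = torch.cat([z, cont_covs], dim=-1)
assert decoder_input.shape == (n_samples, n_cells, n_latent + n_cont)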
11 changes: 8 additions & 3 deletions scvi/module/_multivae.py
@@ -403,9 +403,14 @@ def generative(
         categorical_input = tuple()

         latent = z if not use_z_mean else qz_m
-        decoder_input = (
-            latent if cont_covs is None else torch.cat([latent, cont_covs], dim=-1)
-        )
+        if cont_covs is None:
+            decoder_input = latent
+        elif latent.dim() != cont_covs.dim():
+            decoder_input = torch.cat(
+                [latent, cont_covs.unsqueeze(0).expand(latent.size(0), -1, -1)], dim=-1
+            )
+        else:
+            decoder_input = torch.cat([latent, cont_covs], dim=-1)

         # Accessibility Decoder
         p = self.z_decoder_accessibility(decoder_input, batch_index, *categorical_input)
11 changes: 8 additions & 3 deletions scvi/module/_peakvae.py
@@ -301,9 +301,14 @@ def generative(
         categorical_input = tuple()

         latent = z if not use_z_mean else qz_m
-        decoder_input = (
-            latent if cont_covs is None else torch.cat([latent, cont_covs], dim=-1)
-        )
+        if cont_covs is None:
+            decoder_input = latent
+        elif latent.dim() != cont_covs.dim():
+            decoder_input = torch.cat(
+                [latent, cont_covs.unsqueeze(0).expand(latent.size(0), -1, -1)], dim=-1
+            )
+        else:
+            decoder_input = torch.cat([latent, cont_covs], dim=-1)

         p = self.z_decoder(decoder_input, batch_index, *categorical_input)

10 changes: 9 additions & 1 deletion scvi/module/_totalvae.py
@@ -378,7 +378,15 @@ def generative(
         size_factor=None,
         transform_batch: Optional[int] = None,
     ) -> Dict[str, Union[torch.Tensor, Dict[str, torch.Tensor]]]:
-        decoder_input = z if cont_covs is None else torch.cat([z, cont_covs], dim=-1)
+        if cont_covs is None:
+            decoder_input = z
+        elif z.dim() != cont_covs.dim():
+            decoder_input = torch.cat(
+                [z, cont_covs.unsqueeze(0).expand(z.size(0), -1, -1)], dim=-1
+            )
+        else:
+            decoder_input = torch.cat([z, cont_covs], dim=-1)
+
         if cat_covs is not None:
             categorical_input = torch.split(cat_covs, 1, dim=1)
         else:
10 changes: 9 additions & 1 deletion scvi/module/_vae.py
@@ -325,7 +325,15 @@ def generative(
         """Runs the generative model."""
         # TODO: refactor forward function to not rely on y
         # Likelihood distribution
-        decoder_input = z if cont_covs is None else torch.cat([z, cont_covs], dim=-1)
+        if cont_covs is None:
+            decoder_input = z
+        elif z.dim() != cont_covs.dim():
+            decoder_input = torch.cat(
+                [z, cont_covs.unsqueeze(0).expand(z.size(0), -1, -1)], dim=-1
+            )
+        else:
+            decoder_input = torch.cat([z, cont_covs], dim=-1)
+
         if cat_covs is not None:
             categorical_input = torch.split(cat_covs, 1, dim=1)
         else:
19 changes: 19 additions & 0 deletions tests/models/test_models.py
@@ -154,6 +154,7 @@ def test_scvi(save_path):
     model.get_marginal_ll(n_mc_samples=3)
     model.get_reconstruction_error()
     model.get_normalized_expression(transform_batch="batch_1")
+    model.get_normalized_expression(n_samples=2)

     adata2 = synthetic_iid()
     # test view_anndata_setup with different anndata before transfer setup
@@ -1222,6 +1223,12 @@ def test_multiple_covariates_scvi(save_path):
     )
     m = SCVI(adata)
     m.train(1)
+    m.get_latent_representation()
+    m.get_elbo()
+    m.get_marginal_ll(n_mc_samples=3)
+    m.get_reconstruction_error()
+    m.get_normalized_expression(n_samples=1)
+    m.get_normalized_expression(n_samples=2)

     SCANVI.setup_anndata(
         adata,
@@ -1233,6 +1240,12 @@
     )
     m = SCANVI(adata)
     m.train(1)
+    m.get_latent_representation()
+    m.get_elbo()
+    m.get_marginal_ll(n_mc_samples=3)
+    m.get_reconstruction_error()
+    m.get_normalized_expression(n_samples=1)
+    m.get_normalized_expression(n_samples=2)

     TOTALVI.setup_anndata(
         adata,
@@ -1244,6 +1257,12 @@
     )
     m = TOTALVI(adata)
     m.train(1)
+    m.get_latent_representation()
+    m.get_elbo()
+    m.get_marginal_ll(n_mc_samples=3)
+    m.get_reconstruction_error()
+    m.get_normalized_expression(n_samples=1)
+    m.get_normalized_expression(n_samples=2)


 def test_multiple_encoded_covariates_scvi(save_path):
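For reference, a hedged end-to-end sketch of how the fixed code path can be exercised, mirroring the new test lines above. The continuous-covariate column names are illustrative placeholders, not taken from the test suite.

import numpy as np
from scvi.data import synthetic_iid
from scvi.model import SCVI

adata = synthetic_iid()
# Hypothetical continuous covariates added for illustration.
adata.obs["cont_cov_1"] = np.random.normal(size=adata.n_obs)
adata.obs["cont_cov_2"] = np.random.normal(size=adata.n_obs)

SCVI.setup_anndata(
    adata,
    batch_key="batch",
    continuous_covariate_keys=["cont_cov_1", "cont_cov_2"],
)
model = SCVI(adata)
model.train(max_epochs=1)

# Before this fix, n_samples > 1 together with continuous covariates raised a
# shape error in generative(); the covariates are now broadcast and this succeeds.
model.get_normalized_expression(n_samples=2)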
