SCVI get_normalized_expression() error when including continuous covariates #1488

Closed
nrclaudio opened this issue Apr 6, 2022 · 0 comments · Fixed by #1548

nrclaudio commented Apr 6, 2022

If we include a continuous covariate in the model (e.g. percent.mito) and then try to obtain normalized expression values with .get_normalized_expression() using more than one Monte Carlo sample (n_samples > 1), an error occurs.

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import scvi
import scanpy as sc

adata = sc.read(
    "data/lung_atlas.h5ad",
    backup_url="https://figshare.com/ndownloader/files/24539942",
)


adata.raw = adata  # keep full dimension safe
sc.pp.highly_variable_genes(
    adata,
    flavor="seurat_v3",
    n_top_genes=2000,
    layer="counts",
    batch_key="batch",
    subset=True,
)

adata.obs_names_make_unique()

scvi.model.SCVI.setup_anndata(
    adata,
    layer="counts",
    batch_key="batch",
    continuous_covariate_keys=["percent.mito"]
)

vae = scvi.model.SCVI(adata, n_layers=2, n_latent=30, gene_likelihood="nb")

vae.train()

# This chunk runs successfully
rna = vae.get_normalized_expression(n_samples=1)

# This chunk gives error
rna = vae.get_normalized_expression(n_samples=2)
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
Input In [7], in <cell line: 2>()
      1 # This chunk gives error
----> 2 rna = vae.get_normalized_expression(n_samples=2)

File /exports/humgen/cnovellarausell/conda_envs/scvi-tools/lib/python3.10/site-packages/torch/autograd/grad_mode.py:27, in _DecoratorContextManager.__call__.<locals>.decorate_context(*args, **kwargs)
     24 @functools.wraps(func)
     25 def decorate_context(*args, **kwargs):
     26     with self.clone():
---> 27         return func(*args, **kwargs)

File /exports/humgen/cnovellarausell/conda_envs/scvi-tools/lib/python3.10/site-packages/scvi/model/base/_rnamixin.py:130, in RNASeqMixin.get_normalized_expression(self, adata, indices, transform_batch, gene_list, library_size, n_samples, n_samples_overall, batch_size, return_mean, return_numpy)
    128 generative_kwargs = self._get_transform_batch_gen_kwargs(batch)
    129 inference_kwargs = dict(n_samples=n_samples)
--> 130 _, generative_outputs = self.module.forward(
    131     tensors=tensors,
    132     inference_kwargs=inference_kwargs,
    133     generative_kwargs=generative_kwargs,
    134     compute_loss=False,
    135 )
    136 output = generative_outputs[generative_output_key]
    137 output = output[..., gene_mask]

File /exports/humgen/cnovellarausell/conda_envs/scvi-tools/lib/python3.10/site-packages/scvi/module/base/_decorators.py:41, in auto_move_data.<locals>.auto_transfer_args(self, *args, **kwargs)
     39 args = _move_data_to_device(args, device)
     40 kwargs = _move_data_to_device(kwargs, device)
---> 41 return fn(self, *args, **kwargs)

File /exports/humgen/cnovellarausell/conda_envs/scvi-tools/lib/python3.10/site-packages/scvi/module/base/_base_module.py:162, in BaseModuleClass.forward(self, tensors, get_inference_input_kwargs, get_generative_input_kwargs, inference_kwargs, generative_kwargs, loss_kwargs, compute_loss)
    127 @auto_move_data
    128 def forward(
    129     self,
   (...)
    139     Tuple[torch.Tensor, torch.Tensor, LossRecorder],
    140 ]:
    141     """
    142     Forward pass through the network.
    143 
   (...)
    160         another return value.
    161     """
--> 162     return _generic_forward(
    163         self,
    164         tensors,
    165         inference_kwargs,
    166         generative_kwargs,
    167         loss_kwargs,
    168         get_inference_input_kwargs,
    169         get_generative_input_kwargs,
    170         compute_loss,
    171     )

File /exports/humgen/cnovellarausell/conda_envs/scvi-tools/lib/python3.10/site-packages/scvi/module/base/_base_module.py:505, in _generic_forward(module, tensors, inference_kwargs, generative_kwargs, loss_kwargs, get_inference_input_kwargs, get_generative_input_kwargs, compute_loss)
    501 inference_outputs = module.inference(**inference_inputs, **inference_kwargs)
    502 generative_inputs = module._get_generative_input(
    503     tensors, inference_outputs, **get_generative_input_kwargs
    504 )
--> 505 generative_outputs = module.generative(**generative_inputs, **generative_kwargs)
    506 if compute_loss:
    507     losses = module.loss(
    508         tensors, inference_outputs, generative_outputs, **loss_kwargs
    509     )

File /exports/humgen/cnovellarausell/conda_envs/scvi-tools/lib/python3.10/site-packages/scvi/module/base/_decorators.py:41, in auto_move_data.<locals>.auto_transfer_args(self, *args, **kwargs)
     39 args = _move_data_to_device(args, device)
     40 kwargs = _move_data_to_device(kwargs, device)
---> 41 return fn(self, *args, **kwargs)

File /exports/humgen/cnovellarausell/conda_envs/scvi-tools/lib/python3.10/site-packages/scvi/module/_vae.py:332, in VAE.generative(self, z, library, batch_index, cont_covs, cat_covs, size_factor, y, transform_batch)
    330 """Runs the generative model."""
    331 # TODO: refactor forward function to not rely on y
--> 332 decoder_input = z if cont_covs is None else torch.cat([z, cont_covs], dim=-1)
    333 if cat_covs is not None:
    334     categorical_input = torch.split(cat_covs, 1, dim=1)

RuntimeError: Tensors must have same number of dimensions: got 3 and 2
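
For context, the line that raises is in VAE.generative: decoder_input = z if cont_covs is None else torch.cat([z, cont_covs], dim=-1). With n_samples > 1 the latent sample z gains a leading sample dimension (3-D), while cont_covs stays 2-D (one row per cell), so the concatenation fails. Below is a minimal sketch of the shape mismatch, with illustrative sizes, and one way the shapes could be reconciled by broadcasting; the actual change merged in #1548 may differ from this sketch.

import torch

n_samples, n_cells, n_latent, n_cont = 2, 128, 30, 1

# With n_samples > 1, inference returns z with a leading sample dimension...
z = torch.randn(n_samples, n_cells, n_latent)  # shape (2, 128, 30)
# ...but the continuous covariates stay 2-D, one row per cell.
cont_covs = torch.randn(n_cells, n_cont)       # shape (128, 1)

# The failing call: a 3-D and a 2-D tensor cannot be concatenated.
# torch.cat([z, cont_covs], dim=-1)  # RuntimeError: Tensors must have same number of dimensions

# One way to reconcile the shapes is to broadcast cont_covs across the
# sample dimension before concatenating (illustrative only).
if z.dim() != cont_covs.dim():
    cont_covs = cont_covs.unsqueeze(0).expand(z.size(0), -1, -1)
decoder_input = torch.cat([z, cont_covs], dim=-1)
print(decoder_input.shape)  # torch.Size([2, 128, 31])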

Versions:

scvi-tools 0.15.2
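
In the meantime, a possible workaround is to draw the Monte Carlo samples one at a time with n_samples=1 (which works, as shown above) and average the results manually; this approximates n_samples=K with return_mean=True. A sketch, using the return_numpy parameter visible in the get_normalized_expression signature in the traceback:

import numpy as np

K = 2  # number of Monte Carlo samples to average
draws = [
    vae.get_normalized_expression(n_samples=1, return_numpy=True)
    for _ in range(K)
]
rna = np.mean(draws, axis=0)  # cells x genes, averaged over the single-sample draws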
