From f63bdb00da627389d1dc8cd8d4c61bc3446dacd9 Mon Sep 17 00:00:00 2001 From: Justin Hong Date: Thu, 24 Mar 2022 11:28:04 -0400 Subject: [PATCH] SCANVI bug fixes (#1458) * address issue 1449 * fix scanvi covariate handling * fix tests * raise error on extra cat covs * release note --- docs/release_notes/v0.15.3.md | 18 ++++++++++ scvi/data/fields/_scanvi.py | 23 +++++++++--- scvi/model/_scanvi.py | 11 +++++- scvi/model/base/_archesmixin.py | 8 +++++ scvi/module/_scanvae.py | 28 +++++++++++++-- scvi/module/_vae.py | 4 +-- tests/models/test_models.py | 40 +++++++++++++++++++++ tests/models/test_scarches.py | 63 ++++++++++++++++++++++++++++++++- 8 files changed, 184 insertions(+), 11 deletions(-) create mode 100644 docs/release_notes/v0.15.3.md diff --git a/docs/release_notes/v0.15.3.md b/docs/release_notes/v0.15.3.md new file mode 100644 index 0000000000..7b952f261b --- /dev/null +++ b/docs/release_notes/v0.15.3.md @@ -0,0 +1,18 @@ +# New in 0.15.3 (2022-MM-DD) + +## Changes + +## Bug fixes + +- Raise `NotImplementedError` when `categorical_covariate_keys` are used with {meth}`scvi.model.SCANVI.load_query_data`. ([#1458]). +- Fix behavior when `continuous_covariate_keys` are used with {meth}`scvi.model.SCANVI.classify`. ([#1458]). +- Unlabeled category values are automatically populated when {meth}`scvi.model.SCANVI.load_query_data` is run on an `adata_target` that is missing the labels column. ([#1458]). 
+ +## Contributors + +- [@jjhong922] +- [@adamgayoso] + +[#1458]: https://github.com/YosefLab/scvi-tools/pull/1458 +[@adamgayoso]: https://github.com/adamgayoso +[@jjhong922]: https://github.com/jjhong922 diff --git a/scvi/data/fields/_scanvi.py b/scvi/data/fields/_scanvi.py index 34651fb2f5..101f6ae626 100644 --- a/scvi/data/fields/_scanvi.py +++ b/scvi/data/fields/_scanvi.py @@ -1,10 +1,11 @@ +import warnings from typing import Optional, Union import numpy as np from anndata import AnnData from pandas.api.types import CategoricalDtype -from scvi.data._utils import _make_column_categorical +from scvi.data._utils import _make_column_categorical, _set_data_in_registry from ._obs_field import CategoricalObsField @@ -68,9 +69,6 @@ def _remap_unlabeled_to_final_category( } def register_field(self, adata: AnnData) -> dict: - if self.is_default: - self._setup_default_attr(adata) - state_registry = super().register_field(adata) mapping = state_registry[self.CATEGORICAL_MAPPING_KEY] return self._remap_unlabeled_to_final_category(adata, mapping) @@ -80,8 +78,25 @@ def transfer_field( state_registry: dict, adata_target: AnnData, extend_categories: bool = False, + allow_missing_labels: bool = False, **kwargs, ) -> dict: + if ( + allow_missing_labels + and self.attr_key is not None + and self.attr_key not in adata_target.obs + ): + # Fill in original .obs attribute with unlabeled_category values. + warnings.warn( + f"Missing labels key {self.attr_key}. Filling in with unlabeled category {self._unlabeled_category}." 
+ ) + _set_data_in_registry( + adata_target, + self._unlabeled_category, + self.attr_name, + self._original_attr_key, + ) + transfer_state_registry = super().transfer_field( state_registry, adata_target, extend_categories=extend_categories, **kwargs ) diff --git a/scvi/model/_scanvi.py b/scvi/model/_scanvi.py index b22401fcb0..c1c24028d6 100644 --- a/scvi/model/_scanvi.py +++ b/scvi/model/_scanvi.py @@ -291,7 +291,16 @@ def predict( for _, tensors in enumerate(scdl): x = tensors[REGISTRY_KEYS.X_KEY] batch = tensors[REGISTRY_KEYS.BATCH_KEY] - pred = self.module.classify(x, batch) + + cont_key = REGISTRY_KEYS.CONT_COVS_KEY + cont_covs = tensors[cont_key] if cont_key in tensors.keys() else None + + cat_key = REGISTRY_KEYS.CAT_COVS_KEY + cat_covs = tensors[cat_key] if cat_key in tensors.keys() else None + + pred = self.module.classify( + x, batch_index=batch, cat_covs=cat_covs, cont_covs=cont_covs + ) if not soft: pred = pred.argmax(dim=1) y_pred.append(pred.detach().cpu()) diff --git a/scvi/model/base/_archesmixin.py b/scvi/model/base/_archesmixin.py index 92446f45af..f6f98f610f 100644 --- a/scvi/model/base/_archesmixin.py +++ b/scvi/model/base/_archesmixin.py @@ -10,6 +10,7 @@ from anndata import AnnData from scipy.sparse import csr_matrix +from scvi import REGISTRY_KEYS from scvi.data import _constants from scvi.data._compat import manager_from_setup_dict from scvi.data._constants import _MODEL_NAME_KEY, _SETUP_ARGS_KEY @@ -97,6 +98,7 @@ def load_query_data( adata, scvi_setup_dict, extend_categories=True, + allow_missing_labels=True, unlabeled_category=unlabeled_category, ) ) @@ -120,12 +122,18 @@ def load_query_data( adata, source_registry=registry, extend_categories=True, + allow_missing_labels=True, **registry[_SETUP_ARGS_KEY], ) model = _initialize_model(cls, adata, attr_dict) adata_manager = model.get_anndata_manager(adata, required=True) + if REGISTRY_KEYS.CAT_COVS_KEY in adata_manager.data_registry: + raise NotImplementedError( + "scArches currently does not 
support models with extra categorical covariates." + ) + version_split = adata_manager.registry[_constants._SCVI_VERSION_KEY].split(".") if int(version_split[1]) < 8 and int(version_split[0]) == 0: warnings.warn( diff --git a/scvi/module/_scanvae.py b/scvi/module/_scanvae.py index 0b515fc8b9..6701a6d123 100644 --- a/scvi/module/_scanvae.py +++ b/scvi/module/_scanvae.py @@ -185,10 +185,20 @@ def __init__( ) @auto_move_data - def classify(self, x, batch_index=None): + def classify(self, x, batch_index=None, cont_covs=None, cat_covs=None): if self.log_variational: x = torch.log(1 + x) - qz_m, _, z = self.z_encoder(x, batch_index) + + if cont_covs is not None and self.encode_covariates: + encoder_input = torch.cat((x, cont_covs), dim=-1) + else: + encoder_input = x + if cat_covs is not None and self.encode_covariates: + categorical_input = torch.split(cat_covs, 1, dim=1) + else: + categorical_input = tuple() + + qz_m, _, z = self.z_encoder(encoder_input, batch_index, *categorical_input) # We classify using the inferred mean parameter of z_1 in the latent space z = qz_m if self.use_labels_groups: @@ -210,8 +220,20 @@ def classification_loss(self, labelled_dataset): x = labelled_dataset[REGISTRY_KEYS.X_KEY] y = labelled_dataset[REGISTRY_KEYS.LABELS_KEY] batch_idx = labelled_dataset[REGISTRY_KEYS.BATCH_KEY] + cont_key = REGISTRY_KEYS.CONT_COVS_KEY + cont_covs = ( + labelled_dataset[cont_key] if cont_key in labelled_dataset.keys() else None + ) + + cat_key = REGISTRY_KEYS.CAT_COVS_KEY + cat_covs = ( + labelled_dataset[cat_key] if cat_key in labelled_dataset.keys() else None + ) classification_loss = F.cross_entropy( - self.classify(x, batch_idx), y.view(-1).long() + self.classify( + x, batch_index=batch_idx, cat_covs=cat_covs, cont_covs=cont_covs + ), + y.view(-1).long(), ) return classification_loss diff --git a/scvi/module/_vae.py b/scvi/module/_vae.py index 6ce8f83da0..4946f89482 100644 --- a/scvi/module/_vae.py +++ b/scvi/module/_vae.py @@ -280,11 +280,11 @@ def 
inference(self, x, batch_index, cont_covs=None, cat_covs=None, n_samples=1): if self.log_variational: x_ = torch.log(1 + x_) - if cont_covs is not None and self.encode_covariates is True: + if cont_covs is not None and self.encode_covariates: encoder_input = torch.cat((x_, cont_covs), dim=-1) else: encoder_input = x_ - if cat_covs is not None and self.encode_covariates is True: + if cat_covs is not None and self.encode_covariates: categorical_input = torch.split(cat_covs, 1, dim=1) else: categorical_input = tuple() diff --git a/tests/models/test_models.py b/tests/models/test_models.py index 52e541afed..83c5f1ff7c 100644 --- a/tests/models/test_models.py +++ b/tests/models/test_models.py @@ -1178,6 +1178,46 @@ def test_multiple_covariates_scvi(save_path): m.train(1) +def test_multiple_encoded_covariates_scvi(save_path): + adata = synthetic_iid() + adata.obs["cont1"] = np.random.normal(size=(adata.shape[0],)) + adata.obs["cont2"] = np.random.normal(size=(adata.shape[0],)) + adata.obs["cat1"] = np.random.randint(0, 5, size=(adata.shape[0],)) + adata.obs["cat2"] = np.random.randint(0, 5, size=(adata.shape[0],)) + + SCVI.setup_anndata( + adata, + batch_key="batch", + labels_key="labels", + continuous_covariate_keys=["cont1", "cont2"], + categorical_covariate_keys=["cat1", "cat2"], + ) + m = SCVI(adata, encode_covariates=True) + m.train(1) + + SCANVI.setup_anndata( + adata, + "labels", + "Unknown", + batch_key="batch", + continuous_covariate_keys=["cont1", "cont2"], + categorical_covariate_keys=["cat1", "cat2"], + ) + m = SCANVI(adata, encode_covariates=True) + m.train(1) + + TOTALVI.setup_anndata( + adata, + batch_key="batch", + protein_expression_obsm_key="protein_expression", + protein_names_uns_key="protein_names", + continuous_covariate_keys=["cont1", "cont2"], + categorical_covariate_keys=["cat1", "cat2"], + ) + m = TOTALVI(adata, encode_covariates=True) + m.train(1) + + def test_peakvi(): data = synthetic_iid() PEAKVI.setup_anndata( diff --git 
a/tests/models/test_scarches.py b/tests/models/test_scarches.py index 30a941b0f9..d3e5cffd1c 100644 --- a/tests/models/test_scarches.py +++ b/tests/models/test_scarches.py @@ -202,7 +202,16 @@ def test_scanvi_online_update(save_path): new_labels = adata1.obs.labels.to_numpy() new_labels[0] = "Unknown" adata1.obs["labels"] = pd.Categorical(new_labels) - SCANVI.setup_anndata(adata1, "labels", "Unknown", batch_key="batch") + adata1.obs["cont1"] = np.random.normal(size=(adata1.shape[0],)) + adata1.obs["cont2"] = np.random.normal(size=(adata1.shape[0],)) + + SCANVI.setup_anndata( + adata1, + "labels", + "Unknown", + batch_key="batch", + continuous_covariate_keys=["cont1", "cont2"], + ) model = SCANVI( adata1, n_latent=n_latent, @@ -216,6 +225,20 @@ def test_scanvi_online_update(save_path): adata2 = synthetic_iid() adata2.obs["batch"] = adata2.obs.batch.cat.rename_categories(["batch_2", "batch_3"]) adata2.obs["labels"] = "Unknown" + adata2.obs["cont1"] = np.random.normal(size=(adata2.shape[0],)) + adata2.obs["cont2"] = np.random.normal(size=(adata2.shape[0],)) + + model = SCANVI.load_query_data(adata2, dir_path, freeze_batchnorm_encoder=True) + model.train(max_epochs=1) + model.get_latent_representation() + model.predict() + + # query has all missing labels and no labels key + adata2 = synthetic_iid() + adata2.obs["batch"] = adata2.obs.batch.cat.rename_categories(["batch_2", "batch_3"]) + del adata2.obs["labels"] + adata2.obs["cont1"] = np.random.normal(size=(adata2.shape[0],)) + adata2.obs["cont2"] = np.random.normal(size=(adata2.shape[0],)) model = SCANVI.load_query_data(adata2, dir_path, freeze_batchnorm_encoder=True) model.train(max_epochs=1) @@ -225,12 +248,50 @@ def test_scanvi_online_update(save_path): # query has no missing labels adata2 = synthetic_iid() adata2.obs["batch"] = adata2.obs.batch.cat.rename_categories(["batch_2", "batch_3"]) + adata2.obs["cont1"] = np.random.normal(size=(adata2.shape[0],)) + adata2.obs["cont2"] = 
np.random.normal(size=(adata2.shape[0],)) model = SCANVI.load_query_data(adata2, dir_path, freeze_batchnorm_encoder=True) model.train(max_epochs=1) model.get_latent_representation() model.predict() + # Test error on extra categoricals + adata1 = synthetic_iid() + new_labels = adata1.obs.labels.to_numpy() + new_labels[0] = "Unknown" + adata1.obs["labels"] = pd.Categorical(new_labels) + adata1.obs["cont1"] = np.random.normal(size=(adata1.shape[0],)) + adata1.obs["cont2"] = np.random.normal(size=(adata1.shape[0],)) + adata1.obs["cat1"] = np.random.randint(0, 5, size=(adata1.shape[0],)) + adata1.obs["cat2"] = np.random.randint(0, 5, size=(adata1.shape[0],)) + SCANVI.setup_anndata( + adata1, + "labels", + "Unknown", + batch_key="batch", + continuous_covariate_keys=["cont1", "cont2"], + categorical_covariate_keys=["cat1", "cat2"], + ) + model = SCANVI( + adata1, + n_latent=n_latent, + encode_covariates=True, + ) + model.train(max_epochs=1, check_val_every_n_epoch=1) + dir_path = os.path.join(save_path, "saved_model/") + model.save(dir_path, overwrite=True) + + adata2 = synthetic_iid() + adata2.obs["batch"] = adata2.obs.batch.cat.rename_categories(["batch_2", "batch_3"]) + adata2.obs["labels"] = "Unknown" + adata2.obs["cont1"] = np.random.normal(size=(adata2.shape[0],)) + adata2.obs["cont2"] = np.random.normal(size=(adata2.shape[0],)) + adata2.obs["cat1"] = np.random.randint(0, 5, size=(adata2.shape[0],)) + adata2.obs["cat2"] = np.random.randint(0, 5, size=(adata2.shape[0],)) + with pytest.raises(NotImplementedError): + SCANVI.load_query_data(adata2, dir_path, freeze_batchnorm_encoder=True) + # ref has fully-observed labels n_latent = 5 adata1 = synthetic_iid()