diff --git a/scvi/data/_compat.py b/scvi/data/_compat.py index 759b7b64e1..014e8f5a5f 100644 --- a/scvi/data/_compat.py +++ b/scvi/data/_compat.py @@ -1,4 +1,5 @@ from copy import deepcopy +from typing import Optional import numpy as np from anndata import AnnData @@ -10,6 +11,7 @@ from .fields import ( CategoricalJointObsField, CategoricalObsField, + LabelsWithUnlabeledObsField, LayerField, NumericalJointObsField, NumericalObsField, @@ -27,7 +29,9 @@ } -def registry_from_setup_dict(setup_dict: dict) -> dict: +def registry_from_setup_dict( + setup_dict: dict, unlabeled_category: Optional[str] = None +) -> dict: """ Converts old setup dict format to new registry dict format. @@ -80,6 +84,10 @@ def registry_from_setup_dict(setup_dict: dict) -> dict: field_summary_stats[f"n_{new_registry_key}"] = summary_stats["n_batch"] elif new_registry_key == REGISTRY_KEYS.LABELS_KEY: field_summary_stats[f"n_{new_registry_key}"] = summary_stats["n_labels"] + if unlabeled_category is not None: + field_state_registry[ + LabelsWithUnlabeledObsField.UNLABELED_CATEGORY + ] = unlabeled_category elif attr_name == _constants._ADATA_ATTRS.OBSM: if new_registry_key == REGISTRY_KEYS.CONT_COVS_KEY: columns = setup_dict["extra_continuous_keys"].copy() @@ -106,7 +114,11 @@ def registry_from_setup_dict(setup_dict: dict) -> dict: def manager_from_setup_dict( - cls, adata: AnnData, setup_dict: dict, **transfer_kwargs + cls, + adata: AnnData, + setup_dict: dict, + unlabeled_category: Optional[str] = None, + **transfer_kwargs, ) -> AnnDataManager: """ Creates an :class:`~scvi.data.AnnDataManager` given only a scvi-tools setup dictionary. @@ -145,7 +157,15 @@ def manager_from_setup_dict( elif attr_name == _constants._ADATA_ATTRS.OBS: if new_registry_key in {REGISTRY_KEYS.BATCH_KEY, REGISTRY_KEYS.LABELS_KEY}: original_key = categorical_mappings[attr_key]["original_key"] - field = CategoricalObsField(new_registry_key, original_key) + if ( + unlabeled_category is not None + and new_registry_key == REGISTRY_KEYS.LABELS_KEY + ): + field = LabelsWithUnlabeledObsField( + new_registry_key, original_key, unlabeled_category + ) + else: + field = CategoricalObsField(new_registry_key, original_key) setup_args[f"{new_registry_key}_key"] = original_key elif new_registry_key == REGISTRY_KEYS.INDICES_KEY: adata.obs[attr_key] = np.arange(adata.n_obs).astype("int64") @@ -187,7 +207,9 @@ def manager_from_setup_dict( } adata_manager = AnnDataManager(fields=fields, setup_method_args=setup_method_args) - source_registry = registry_from_setup_dict(setup_dict) + source_registry = registry_from_setup_dict( + setup_dict, unlabeled_category=unlabeled_category + ) adata_manager.register_fields( adata, source_registry=source_registry, **transfer_kwargs ) diff --git a/scvi/data/fields/_scanvi.py b/scvi/data/fields/_scanvi.py index 22cb0f0032..34651fb2f5 100644 --- a/scvi/data/fields/_scanvi.py +++ b/scvi/data/fields/_scanvi.py @@ -27,7 +27,6 @@ class LabelsWithUnlabeledObsField(CategoricalObsField): """ UNLABELED_CATEGORY = "unlabeled_category" - WAS_REMAPPED = "was_remapped" def __init__( self, diff --git a/scvi/model/base/_archesmixin.py b/scvi/model/base/_archesmixin.py index 09af6f9e50..8ba780ed98 100644 --- a/scvi/model/base/_archesmixin.py +++ b/scvi/model/base/_archesmixin.py @@ -3,8 +3,12 @@ from copy import deepcopy from typing import Optional, Union +import anndata +import numpy as np +import pandas as pd import torch from anndata import AnnData +from scipy.sparse import csr_matrix from scvi.data import _constants from scvi.data._compat import manager_from_setup_dict @@ -17,6 +21,8 @@ logger = logging.getLogger(__name__) +MIN_VAR_NAME_RATIO = 0.8 + class ArchesMixin: """Universal scArches implementation.""" @@ -44,7 +50,7 @@ def load_query_data( adata AnnData organized in the same way as data used to train model. It is not necessary to run setup_anndata, - as AnnData is validated against the saved `scvi` setup dictionary. + as AnnData is validated against the ``registry``. reference_model Either an already instantiated model of the same class, or a path to saved outputs for reference model. @@ -70,15 +76,10 @@ def load_query_data( Whether to freeze classifier completely. Only applies to `SCANVI`. """ use_gpu, device = parse_use_gpu_arg(use_gpu) - if isinstance(reference_model, str): - attr_dict, var_names, load_state_dict, _ = _load_saved_files( - reference_model, load_adata=False, map_location=device - ) - else: - attr_dict = reference_model._get_user_attributes() - attr_dict = {a[0]: a[1] for a in attr_dict if a[0][-1] == "_"} - var_names = reference_model.adata.var_names - load_state_dict = deepcopy(reference_model.module.state_dict()) + + attr_dict, var_names, load_state_dict = _get_loaded_data( + reference_model, device=device + ) if inplace_subset_query_vars: logger.debug("Subsetting query vars to reference vars.") @@ -87,9 +88,16 @@ def load_query_data( if "scvi_setup_dict_" in attr_dict: scvi_setup_dict = attr_dict.pop("scvi_setup_dict_") + # scanvi case + unlabeled_category_key = "unlabeled_category_" + unlabeled_category = attr_dict.get(unlabeled_category_key, None) cls.register_manager( manager_from_setup_dict( - cls, adata, scvi_setup_dict, extend_categories=True + cls, + adata, + scvi_setup_dict, + extend_categories=True, + unlabeled_category=unlabeled_category, ) ) else: @@ -112,7 +120,7 @@ def load_query_data( adata, source_registry=registry, extend_categories=True, - **registry[_SETUP_ARGS_KEY] + **registry[_SETUP_ARGS_KEY], ) model = _initialize_model(cls, adata, attr_dict) @@ -155,6 +163,79 @@ def load_query_data( return model + @staticmethod + def prepare_query_anndata( + adata: AnnData, + reference_model: Union[str, BaseModelClass], + return_reference_var_names: bool = False, + inplace: bool = True, + ) -> Optional[Union[AnnData, pd.Index]]: + """ + Prepare data for query integration. + + This function will return a new AnnData object with padded zeros + for missing features, as well as correctly sorted features. + + Parameters + ---------- + adata + AnnData organized in the same way as data used to train model. + It is not necessary to run setup_anndata, + as AnnData is validated against the ``registry``. + reference_model + Either an already instantiated model of the same class, or a path to + saved outputs for reference model. + return_reference_var_names + Only load and return reference var names if True. + inplace + Whether to subset and rearrange query vars inplace or return new AnnData. + + Returns + ------- + Query adata ready to use in `load_query_data` unless `return_reference_var_names` + in which case a pd.Index of reference var names is returned. + """ + _, var_names, _ = _get_loaded_data(reference_model) + var_names = pd.Index(var_names) + + if return_reference_var_names: + return var_names + + intersection = adata.var_names.intersection(var_names) + inter_len = len(intersection) + if inter_len == 0: + raise ValueError( + "No reference var names found in query data. " + "Please rerun with return_reference_var_names=True " + "to see reference var names." + ) + + ratio = inter_len / len(var_names) + logger.info("Found {}% reference vars in query data.".format(ratio * 100)) + if ratio < MIN_VAR_NAME_RATIO: + warnings.warn( + f"Query data contains less than {MIN_VAR_NAME_RATIO:.0f}% of reference var names. " + "This may result in poor performance." + ) + genes_to_add = var_names.difference(adata.var_names) + adata_padding = AnnData(csr_matrix(np.zeros((adata.n_obs, len(genes_to_add))))) + adata_padding.var_names = genes_to_add + adata_padding.obs_names = adata.obs_names + # Concatenate object + adata_out = anndata.concat( + [adata, adata_padding], + axis=1, + join="outer", + index_unique=None, + merge="unique", + ) + adata_out._inplace_subset_var(var_names) + + if inplace: + adata._init_as_actual(adata_out, dtype=adata._X.dtype) + else: + return adata_out + def _set_params_online_update( module, @@ -229,3 +310,17 @@ def requires_grad(key): par.requires_grad = True else: par.requires_grad = False + + +def _get_loaded_data(reference_model, device=None): + if isinstance(reference_model, str): + attr_dict, var_names, load_state_dict, _ = _load_saved_files( + reference_model, load_adata=False, map_location=device + ) + else: + attr_dict = reference_model._get_user_attributes() + attr_dict = {a[0]: a[1] for a in attr_dict if a[0][-1] == "_"} + var_names = reference_model.adata.var_names + load_state_dict = deepcopy(reference_model.module.state_dict()) + + return attr_dict, var_names, load_state_dict diff --git a/scvi/model/base/_base_model.py b/scvi/model/base/_base_model.py index 98f8b5e634..902e7ea1da 100644 --- a/scvi/model/base/_base_model.py +++ b/scvi/model/base/_base_model.py @@ -591,7 +591,16 @@ def load( # Legacy support for old setup dict format. if "scvi_setup_dict_" in attr_dict: scvi_setup_dict = attr_dict.pop("scvi_setup_dict_") - cls.register_manager(manager_from_setup_dict(cls, adata, scvi_setup_dict)) + unlabeled_category_key = "unlabeled_category_" + unlabeled_category = attr_dict.get(unlabeled_category_key, None) + cls.register_manager( + manager_from_setup_dict( + cls, + adata, + scvi_setup_dict, + unlabeled_category=unlabeled_category, + ) + ) else: registry = attr_dict.pop("registry_") if ( diff --git a/scvi/model/base/_utils.py b/scvi/model/base/_utils.py index 2d29a11f8f..d30c5da149 100644 --- a/scvi/model/base/_utils.py +++ b/scvi/model/base/_utils.py @@ -114,6 +114,12 @@ def _initialize_model(cls, adata, attr_dict): kwargs = {k: v for (i, j) in kwargs.items() for (k, v) in j.items()} non_kwargs.pop("use_cuda") + # backwards compat for scANVI + if "unlabeled_category" in non_kwargs.keys(): + non_kwargs.pop("unlabeled_category") + if "pretrained_model" in non_kwargs.keys(): + non_kwargs.pop("pretrained_model") + model = cls(adata, **non_kwargs, **kwargs) for attr, val in attr_dict.items(): setattr(model, attr, val) diff --git a/tests/models/test_scarches.py b/tests/models/test_scarches.py index 0880f43e88..30a941b0f9 100644 --- a/tests/models/test_scarches.py +++ b/tests/models/test_scarches.py @@ -15,6 +15,42 @@ def single_pass_for_online_update(model): scvi_loss.loss.backward() +def test_data_prep(save_path): + n_latent = 5 + adata1 = synthetic_iid() + SCVI.setup_anndata(adata1, batch_key="batch", labels_key="labels") + model = SCVI(adata1, n_latent=n_latent) + model.train(1, check_val_every_n_epoch=1) + dir_path = os.path.join(save_path, "saved_model/") + model.save(dir_path, overwrite=True) + + # adata2 has more genes and a perfect subset of adata1 + adata2 = synthetic_iid(n_genes=110) + adata2.obs["batch"] = adata2.obs.batch.cat.rename_categories(["batch_2", "batch_3"]) + SCVI.prepare_query_anndata(adata2, dir_path) + SCVI.load_query_data(adata2, dir_path) + + adata3 = SCVI.prepare_query_anndata(adata2, dir_path, inplace=False) + SCVI.load_query_data(adata3, dir_path) + + # adata4 has more genes and missing 10 genes from adata1 + adata4 = synthetic_iid(n_genes=110) + new_var_names_init = [f"Random {i}" for i in range(10)] + new_var_names = new_var_names_init + adata4.var_names[10:].to_list() + adata4.var_names = new_var_names + + SCVI.prepare_query_anndata(adata4, dir_path) + # should be padded 0s + assert np.sum(adata4[:, adata4.var_names[:10]].X) == 0 + np.testing.assert_equal( + adata4.var_names[:10].to_numpy(), adata1.var_names[:10].to_numpy() + ) + SCVI.load_query_data(adata4, dir_path) + + adata5 = SCVI.prepare_query_anndata(adata4, dir_path, inplace=False) + SCVI.load_query_data(adata5, dir_path) + + def test_scvi_online_update(save_path): n_latent = 5 adata1 = synthetic_iid()