diff --git a/scvi/data/anndata/_compat.py b/scvi/data/anndata/_compat.py index 8c5966c20e..1656252243 100644 --- a/scvi/data/anndata/_compat.py +++ b/scvi/data/anndata/_compat.py @@ -125,7 +125,7 @@ def manager_from_setup_dict( Keyword arguments to modify transfer behavior. """ fields = [] - setup_kwargs = dict() + setup_args = dict() data_registry = setup_dict[_constants._DATA_REGISTRY_KEY] categorical_mappings = setup_dict["categorical_mappings"] for registry_key, adata_mapping in data_registry.items(): @@ -138,15 +138,15 @@ def manager_from_setup_dict( attr_key = adata_mapping[_constants._DR_ATTR_KEY] if attr_name == _constants._ADATA_ATTRS.X: field = LayerField(REGISTRY_KEYS.X_KEY, None) - setup_kwargs["layer"] = None + setup_args["layer"] = None elif attr_name == _constants._ADATA_ATTRS.LAYERS: field = LayerField(REGISTRY_KEYS.X_KEY, attr_key) - setup_kwargs["layer"] = attr_key + setup_args["layer"] = attr_key elif attr_name == _constants._ADATA_ATTRS.OBS: if new_registry_key in {REGISTRY_KEYS.BATCH_KEY, REGISTRY_KEYS.LABELS_KEY}: original_key = categorical_mappings[attr_key]["original_key"] field = CategoricalObsField(new_registry_key, original_key) - setup_kwargs[f"{new_registry_key}_key"] = original_key + setup_args[f"{new_registry_key}_key"] = original_key elif new_registry_key == REGISTRY_KEYS.INDICES_KEY: adata.obs[attr_key] = np.arange(adata.n_obs).astype("int64") field = NumericalObsField(new_registry_key, attr_key) @@ -154,11 +154,11 @@ def manager_from_setup_dict( if new_registry_key == REGISTRY_KEYS.CONT_COVS_KEY: obs_keys = setup_dict["extra_continuous_keys"] field = NumericalJointObsField(REGISTRY_KEYS.CONT_COVS_KEY, obs_keys) - setup_kwargs["continuous_covariate_keys"] = obs_keys + setup_args["continuous_covariate_keys"] = obs_keys elif new_registry_key == REGISTRY_KEYS.CAT_COVS_KEY: obs_keys = setup_dict["extra_categoricals"]["keys"] field = CategoricalJointObsField(REGISTRY_KEYS.CAT_COVS_KEY, obs_keys) - setup_kwargs["categorical_covariate_keys"] = obs_keys + setup_args["categorical_covariate_keys"] = obs_keys elif new_registry_key == REGISTRY_KEYS.PROTEIN_EXP_KEY: protein_names = setup_dict["protein_names"] adata.uns["_protein_names"] = protein_names @@ -169,8 +169,8 @@ def manager_from_setup_dict( batch_key="_scvi_batch", colnames_uns_key="_protein_names", ) - setup_kwargs["protein_expression_obsm_key"] = attr_key - setup_kwargs["protein_names_uns_key"] = "_protein_names" + setup_args["protein_expression_obsm_key"] = attr_key + setup_args["protein_names_uns_key"] = "_protein_names" else: raise NotImplementedError( f"Unrecognized .obsm attribute {attr_key} registered as {new_registry_key}. Backwards compatibility unavailable." @@ -183,7 +183,7 @@ def manager_from_setup_dict( setup_method_args = { _constants._MODEL_NAME_KEY: cls.__name__, - _constants._SETUP_KWARGS_KEY: setup_kwargs, + _constants._SETUP_ARGS_KEY: setup_args, } adata_manager = AnnDataManager(fields=fields, setup_method_args=setup_method_args) diff --git a/scvi/data/anndata/_constants.py b/scvi/data/anndata/_constants.py index fc677e4b87..14f191b3ff 100644 --- a/scvi/data/anndata/_constants.py +++ b/scvi/data/anndata/_constants.py @@ -13,7 +13,7 @@ _SCVI_VERSION_KEY = "scvi_version" _MODEL_NAME_KEY = "model_name" -_SETUP_KWARGS_KEY = "setup_kwargs" +_SETUP_ARGS_KEY = "setup_args" _FIELD_REGISTRIES_KEY = "field_registries" _DATA_REGISTRY_KEY = "data_registry" _STATE_REGISTRY_KEY = "state_registry" diff --git a/scvi/data/anndata/_manager.py b/scvi/data/anndata/_manager.py index 23fde10834..2653a4bf58 100644 --- a/scvi/data/anndata/_manager.py +++ b/scvi/data/anndata/_manager.py @@ -58,7 +58,7 @@ def __init__( self._registry = { _constants._SCVI_VERSION_KEY: scvi.__version__, _constants._MODEL_NAME_KEY: None, - _constants._SETUP_KWARGS_KEY: None, + _constants._SETUP_ARGS_KEY: None, _constants._FIELD_REGISTRIES_KEY: defaultdict(dict), } if setup_method_args is not None: @@ -88,7 +88,7 @@ def _get_setup_method_args(self) -> dict: return { k: v for k, v in self._registry.items() - if k in {_constants._MODEL_NAME_KEY, _constants._SETUP_KWARGS_KEY} + if k in {_constants._MODEL_NAME_KEY, _constants._SETUP_ARGS_KEY} } def _assign_uuid(self): @@ -328,10 +328,36 @@ def _view_data_registry(self) -> rich.table.Table: return t + @staticmethod + def view_setup_method_args(registry: dict) -> None: + """ + Prints setup kwargs used to produce a given registry. + + Parameters + ---------- + registry + Registry produced by an AnnDataManager. + """ + model_name = registry[_constants._MODEL_NAME_KEY] + setup_args = registry[_constants._SETUP_ARGS_KEY] + if model_name is not None and setup_args is not None: + rich.print(f"Setup via `{model_name}.setup_anndata` with arguments:") + rich.pretty.pprint(setup_args) + rich.print() + def view_registry(self, hide_state_registries: bool = False) -> None: - """Prints summary of the registry.""" + """ + Prints summary of the registry. + + Parameters + ---------- + hide_state_registries + If True, prints a shortened summary without details of each state registry. + """ version = self._registry[_constants._SCVI_VERSION_KEY] - rich.print("Anndata setup with scvi-tools version {}.".format(version)) + rich.print(f"Anndata setup with scvi-tools version {version}.") + rich.print() + self.view_setup_method_args(self._registry) in_colab = "google.colab" in sys.modules force_jupyter = None if not in_colab else True diff --git a/scvi/external/gimvi/_model.py b/scvi/external/gimvi/_model.py index 328cd56db3..29efd36651 100644 --- a/scvi/external/gimvi/_model.py +++ b/scvi/external/gimvi/_model.py @@ -12,7 +12,7 @@ from scvi import REGISTRY_KEYS from scvi.data.anndata import AnnDataManager from scvi.data.anndata._compat import manager_from_setup_dict -from scvi.data.anndata._constants import _MODEL_NAME_KEY, _SETUP_KWARGS_KEY +from scvi.data.anndata._constants import _MODEL_NAME_KEY, _SETUP_ARGS_KEY from scvi.data.anndata.fields import CategoricalObsField, LayerField from scvi.dataloaders import DataSplitter from scvi.model._utils import _init_library_size, parse_use_gpu_arg @@ -525,14 +525,14 @@ def load( "It appears you are loading a model from a different class." ) - if _SETUP_KWARGS_KEY not in registry: + if _SETUP_ARGS_KEY not in registry: raise ValueError( "Saved model does not contain original setup inputs. " "Cannot load the original setup." ) cls.setup_anndata( - adata, source_registry=registry, **registry[_SETUP_KWARGS_KEY] + adata, source_registry=registry, **registry[_SETUP_ARGS_KEY] ) # get the parameters for the class init signiture diff --git a/scvi/model/_scanvi.py b/scvi/model/_scanvi.py index d052375efb..534ce32d21 100644 --- a/scvi/model/_scanvi.py +++ b/scvi/model/_scanvi.py @@ -10,7 +10,7 @@ from scvi import REGISTRY_KEYS from scvi._compat import Literal from scvi.data.anndata import AnnDataManager -from scvi.data.anndata._constants import _SETUP_KWARGS_KEY +from scvi.data.anndata._constants import _SETUP_ARGS_KEY from scvi.data.anndata.fields import ( CategoricalJointObsField, CategoricalObsField, @@ -207,9 +207,9 @@ def from_scvi_model( if adata is None: adata = scvi_model.adata - scvi_setup_kwargs = scvi_model.adata_manager.registry[_SETUP_KWARGS_KEY] + scvi_setup_args = scvi_model.adata_manager.registry[_SETUP_ARGS_KEY] cls.setup_anndata( - adata, unlabeled_category=unlabeled_category, **scvi_setup_kwargs + adata, unlabeled_category=unlabeled_category, **scvi_setup_args ) scanvi_model = cls(adata, **non_kwargs, **kwargs, **scanvi_kwargs) scvi_state_dict = scvi_model.module.state_dict() diff --git a/scvi/model/base/_archesmixin.py b/scvi/model/base/_archesmixin.py index c47c545b9f..6611fa8d66 100644 --- a/scvi/model/base/_archesmixin.py +++ b/scvi/model/base/_archesmixin.py @@ -8,7 +8,7 @@ from scvi.data.anndata import _constants from scvi.data.anndata._compat import manager_from_setup_dict -from scvi.data.anndata._constants import _MODEL_NAME_KEY, _SETUP_KWARGS_KEY +from scvi.data.anndata._constants import _MODEL_NAME_KEY, _SETUP_ARGS_KEY from scvi.model._utils import parse_use_gpu_arg from scvi.nn import FCLayers @@ -102,7 +102,7 @@ def load_query_data( "It appears you are loading a model from a different class." ) - if _SETUP_KWARGS_KEY not in registry: + if _SETUP_ARGS_KEY not in registry: raise ValueError( "Saved model does not contain original setup inputs. " "Cannot load the original setup." @@ -112,7 +112,7 @@ def load_query_data( adata, source_registry=registry, extend_categories=True, - **registry[_SETUP_KWARGS_KEY] + **registry[_SETUP_ARGS_KEY] ) model = _initialize_model(cls, adata, attr_dict) diff --git a/scvi/model/base/_base_model.py b/scvi/model/base/_base_model.py index d107e4b8e1..1a72481436 100644 --- a/scvi/model/base/_base_model.py +++ b/scvi/model/base/_base_model.py @@ -18,7 +18,7 @@ from scvi.data.anndata._constants import ( _MODEL_NAME_KEY, _SCVI_UUID_KEY, - _SETUP_KWARGS_KEY, + _SETUP_ARGS_KEY, ) from scvi.data.anndata._utils import _assign_adata_uuid from scvi.dataloaders import AnnDataLoader @@ -134,14 +134,13 @@ def _get_setup_method_args(**setup_locals) -> dict: Must be called with ``**locals()`` at the start of the ``setup_anndata`` method to avoid the inclusion of any extraneous variables. """ - setup_locals.pop("adata") cls = setup_locals.pop("cls") model_name = cls.__name__ - setup_kwargs = dict() + setup_args = dict() for k, v in setup_locals.items(): if k not in _SETUP_INPUTS_EXCLUDED_PARAMS: - setup_kwargs[k] = v - return {_MODEL_NAME_KEY: model_name, _SETUP_KWARGS_KEY: setup_kwargs} + setup_args[k] = v + return {_MODEL_NAME_KEY: model_name, _SETUP_ARGS_KEY: setup_args} @classmethod def register_manager(cls, adata_manager: AnnDataManager): @@ -590,7 +589,7 @@ def load( "It appears you are loading a model from a different class." ) - if _SETUP_KWARGS_KEY not in registry: + if _SETUP_ARGS_KEY not in registry: raise ValueError( "Saved model does not contain original setup inputs. " "Cannot load the original setup." @@ -600,7 +599,7 @@ def load( # the saved model. This enables simple backwards compatibility in the case of # newly introduced fields or parameters. cls.setup_anndata( - adata, source_registry=registry, **registry[_SETUP_KWARGS_KEY] + adata, source_registry=registry, **registry[_SETUP_ARGS_KEY] ) model = _initialize_model(cls, adata, attr_dict) @@ -653,6 +652,30 @@ def setup_anndata( on a model-specific instance of :class:`~scvi.data.anndata.AnnDataManager`. """ + @staticmethod + def view_setup_args(dir_path: str, prefix: Optional[str] = None) -> None: + """ + Print args used to setup a saved model. + + Parameters + ---------- + dir_path + Path to saved outputs. + prefix + Prefix of saved file names. + """ + attr_dict = _load_saved_files(dir_path, False, prefix=prefix)[0] + + # Legacy support for old setup dict format. + if "scvi_setup_dict_" in attr_dict: + raise NotImplementedError( + "Viewing setup args for pre v0.15.0 models is unsupported. " + "Load and resave the model to use this function." + ) + + registry = attr_dict.pop("registry_") + AnnDataManager.view_setup_method_args(registry) + def view_anndata_setup( self, adata: Optional[AnnData] = None, hide_state_registries: bool = False ) -> None: @@ -664,6 +687,8 @@ def view_anndata_setup( adata AnnData object setup with ``setup_anndata`` or :meth:`~scvi.data.anndata.AnnDataManager.transfer_setup`. + hide_state_registries + If True, prints a shortened summary without details of each state registry. """ if adata is None: adata = self.adata diff --git a/tests/dataset/utils.py b/tests/dataset/utils.py index 14189d6820..2be33fafe1 100644 --- a/tests/dataset/utils.py +++ b/tests/dataset/utils.py @@ -5,6 +5,7 @@ from scvi import REGISTRY_KEYS from scvi.data.anndata import AnnDataManager +from scvi.data.anndata._constants import _MODEL_NAME_KEY, _SETUP_ARGS_KEY from scvi.data.anndata.fields import ( CategoricalJointObsField, CategoricalObsField, @@ -39,6 +40,10 @@ def generic_setup_adata_manager( protein_expression_obsm_key: Optional[str] = None, protein_names_uns_key: Optional[str] = None, ) -> AnnDataManager: + setup_args = locals() + setup_args.pop("adata") + setup_method_args = {_MODEL_NAME_KEY: "TestModel", _SETUP_ARGS_KEY: setup_args} + batch_field = CategoricalObsField(REGISTRY_KEYS.BATCH_KEY, batch_key) anndata_fields = [ batch_field, @@ -60,6 +65,8 @@ def generic_setup_adata_manager( is_count_data=True, ) ) - adata_manager = AnnDataManager(fields=anndata_fields) + adata_manager = AnnDataManager( + fields=anndata_fields, setup_method_args=setup_method_args + ) adata_manager.register_fields(adata) return adata_manager diff --git a/tests/models/test_models.py b/tests/models/test_models.py index a07cc89ed4..17dfb1ff82 100644 --- a/tests/models/test_models.py +++ b/tests/models/test_models.py @@ -372,6 +372,7 @@ def test_save_load_model(cls, adata, save_path, prefix=None, legacy=False): ) else: model.save(save_path, overwrite=True, save_anndata=True, prefix=prefix) + model.view_setup_args(save_path, prefix=prefix) model = cls.load(save_path, prefix=prefix) model.get_latent_representation() @@ -428,6 +429,7 @@ def test_save_load_autozi(legacy=False): ) else: model.save(save_path, overwrite=True, save_anndata=True, prefix=prefix) + model.view_setup_args(save_path, prefix=prefix) model = AUTOZI.load(save_path, prefix=prefix) model.get_latent_representation() tmp_adata = scvi.data.synthetic_iid(n_genes=200) @@ -463,6 +465,7 @@ def test_save_load_scanvi(legacy=False): ) else: model.save(save_path, overwrite=True, save_anndata=True, prefix=prefix) + model.view_setup_args(save_path, prefix=prefix) model = SCANVI.load(save_path, prefix=prefix) model.get_latent_representation() tmp_adata = scvi.data.synthetic_iid(n_genes=200)