Skip to content

Commit

Permalink
view_setup_args for saved model paths (#1355)
Browse files Browse the repository at this point in the history
* rename setup_kwargs constant and add view_setup_method_args fcn to manager

* add view_setup_args method for basemodel

* address comment
  • Loading branch information
justjhong authored Feb 15, 2022
1 parent d77d5c8 commit e20c69b
Show file tree
Hide file tree
Showing 9 changed files with 92 additions and 31 deletions.
18 changes: 9 additions & 9 deletions scvi/data/anndata/_compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ def manager_from_setup_dict(
Keyword arguments to modify transfer behavior.
"""
fields = []
setup_kwargs = dict()
setup_args = dict()
data_registry = setup_dict[_constants._DATA_REGISTRY_KEY]
categorical_mappings = setup_dict["categorical_mappings"]
for registry_key, adata_mapping in data_registry.items():
Expand All @@ -138,27 +138,27 @@ def manager_from_setup_dict(
attr_key = adata_mapping[_constants._DR_ATTR_KEY]
if attr_name == _constants._ADATA_ATTRS.X:
field = LayerField(REGISTRY_KEYS.X_KEY, None)
setup_kwargs["layer"] = None
setup_args["layer"] = None
elif attr_name == _constants._ADATA_ATTRS.LAYERS:
field = LayerField(REGISTRY_KEYS.X_KEY, attr_key)
setup_kwargs["layer"] = attr_key
setup_args["layer"] = attr_key
elif attr_name == _constants._ADATA_ATTRS.OBS:
if new_registry_key in {REGISTRY_KEYS.BATCH_KEY, REGISTRY_KEYS.LABELS_KEY}:
original_key = categorical_mappings[attr_key]["original_key"]
field = CategoricalObsField(new_registry_key, original_key)
setup_kwargs[f"{new_registry_key}_key"] = original_key
setup_args[f"{new_registry_key}_key"] = original_key
elif new_registry_key == REGISTRY_KEYS.INDICES_KEY:
adata.obs[attr_key] = np.arange(adata.n_obs).astype("int64")
field = NumericalObsField(new_registry_key, attr_key)
elif attr_name == _constants._ADATA_ATTRS.OBSM:
if new_registry_key == REGISTRY_KEYS.CONT_COVS_KEY:
obs_keys = setup_dict["extra_continuous_keys"]
field = NumericalJointObsField(REGISTRY_KEYS.CONT_COVS_KEY, obs_keys)
setup_kwargs["continuous_covariate_keys"] = obs_keys
setup_args["continuous_covariate_keys"] = obs_keys
elif new_registry_key == REGISTRY_KEYS.CAT_COVS_KEY:
obs_keys = setup_dict["extra_categoricals"]["keys"]
field = CategoricalJointObsField(REGISTRY_KEYS.CAT_COVS_KEY, obs_keys)
setup_kwargs["categorical_covariate_keys"] = obs_keys
setup_args["categorical_covariate_keys"] = obs_keys
elif new_registry_key == REGISTRY_KEYS.PROTEIN_EXP_KEY:
protein_names = setup_dict["protein_names"]
adata.uns["_protein_names"] = protein_names
Expand All @@ -169,8 +169,8 @@ def manager_from_setup_dict(
batch_key="_scvi_batch",
colnames_uns_key="_protein_names",
)
setup_kwargs["protein_expression_obsm_key"] = attr_key
setup_kwargs["protein_names_uns_key"] = "_protein_names"
setup_args["protein_expression_obsm_key"] = attr_key
setup_args["protein_names_uns_key"] = "_protein_names"
else:
raise NotImplementedError(
f"Unrecognized .obsm attribute {attr_key} registered as {new_registry_key}. Backwards compatibility unavailable."
Expand All @@ -183,7 +183,7 @@ def manager_from_setup_dict(

setup_method_args = {
_constants._MODEL_NAME_KEY: cls.__name__,
_constants._SETUP_KWARGS_KEY: setup_kwargs,
_constants._SETUP_ARGS_KEY: setup_args,
}
adata_manager = AnnDataManager(fields=fields, setup_method_args=setup_method_args)

Expand Down
2 changes: 1 addition & 1 deletion scvi/data/anndata/_constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

_SCVI_VERSION_KEY = "scvi_version"
_MODEL_NAME_KEY = "model_name"
_SETUP_KWARGS_KEY = "setup_kwargs"
_SETUP_ARGS_KEY = "setup_args"
_FIELD_REGISTRIES_KEY = "field_registries"
_DATA_REGISTRY_KEY = "data_registry"
_STATE_REGISTRY_KEY = "state_registry"
Expand Down
34 changes: 30 additions & 4 deletions scvi/data/anndata/_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def __init__(
self._registry = {
_constants._SCVI_VERSION_KEY: scvi.__version__,
_constants._MODEL_NAME_KEY: None,
_constants._SETUP_KWARGS_KEY: None,
_constants._SETUP_ARGS_KEY: None,
_constants._FIELD_REGISTRIES_KEY: defaultdict(dict),
}
if setup_method_args is not None:
Expand Down Expand Up @@ -88,7 +88,7 @@ def _get_setup_method_args(self) -> dict:
return {
k: v
for k, v in self._registry.items()
if k in {_constants._MODEL_NAME_KEY, _constants._SETUP_KWARGS_KEY}
if k in {_constants._MODEL_NAME_KEY, _constants._SETUP_ARGS_KEY}
}

def _assign_uuid(self):
Expand Down Expand Up @@ -328,10 +328,36 @@ def _view_data_registry(self) -> rich.table.Table:

return t

@staticmethod
def view_setup_method_args(registry: dict) -> None:
"""
Prints setup kwargs used to produce a given registry.
Parameters
----------
registry
Registry produced by an AnnDataManager.
"""
model_name = registry[_constants._MODEL_NAME_KEY]
setup_args = registry[_constants._SETUP_ARGS_KEY]
if model_name is not None and setup_args is not None:
rich.print(f"Setup via `{model_name}.setup_anndata` with arguments:")
rich.pretty.pprint(setup_args)
rich.print()

def view_registry(self, hide_state_registries: bool = False) -> None:
"""Prints summary of the registry."""
"""
Prints summary of the registry.
Parameters
----------
hide_state_registries
If True, prints a shortened summary without details of each state registry.
"""
version = self._registry[_constants._SCVI_VERSION_KEY]
rich.print("Anndata setup with scvi-tools version {}.".format(version))
rich.print(f"Anndata setup with scvi-tools version {version}.")
rich.print()
self.view_setup_method_args(self._registry)

in_colab = "google.colab" in sys.modules
force_jupyter = None if not in_colab else True
Expand Down
6 changes: 3 additions & 3 deletions scvi/external/gimvi/_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from scvi import REGISTRY_KEYS
from scvi.data.anndata import AnnDataManager
from scvi.data.anndata._compat import manager_from_setup_dict
from scvi.data.anndata._constants import _MODEL_NAME_KEY, _SETUP_KWARGS_KEY
from scvi.data.anndata._constants import _MODEL_NAME_KEY, _SETUP_ARGS_KEY
from scvi.data.anndata.fields import CategoricalObsField, LayerField
from scvi.dataloaders import DataSplitter
from scvi.model._utils import _init_library_size, parse_use_gpu_arg
Expand Down Expand Up @@ -525,14 +525,14 @@ def load(
"It appears you are loading a model from a different class."
)

if _SETUP_KWARGS_KEY not in registry:
if _SETUP_ARGS_KEY not in registry:
raise ValueError(
"Saved model does not contain original setup inputs. "
"Cannot load the original setup."
)

cls.setup_anndata(
adata, source_registry=registry, **registry[_SETUP_KWARGS_KEY]
adata, source_registry=registry, **registry[_SETUP_ARGS_KEY]
)

# get the parameters for the class init signiture
Expand Down
6 changes: 3 additions & 3 deletions scvi/model/_scanvi.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from scvi import REGISTRY_KEYS
from scvi._compat import Literal
from scvi.data.anndata import AnnDataManager
from scvi.data.anndata._constants import _SETUP_KWARGS_KEY
from scvi.data.anndata._constants import _SETUP_ARGS_KEY
from scvi.data.anndata.fields import (
CategoricalJointObsField,
CategoricalObsField,
Expand Down Expand Up @@ -207,9 +207,9 @@ def from_scvi_model(
if adata is None:
adata = scvi_model.adata

scvi_setup_kwargs = scvi_model.adata_manager.registry[_SETUP_KWARGS_KEY]
scvi_setup_args = scvi_model.adata_manager.registry[_SETUP_ARGS_KEY]
cls.setup_anndata(
adata, unlabeled_category=unlabeled_category, **scvi_setup_kwargs
adata, unlabeled_category=unlabeled_category, **scvi_setup_args
)
scanvi_model = cls(adata, **non_kwargs, **kwargs, **scanvi_kwargs)
scvi_state_dict = scvi_model.module.state_dict()
Expand Down
6 changes: 3 additions & 3 deletions scvi/model/base/_archesmixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

from scvi.data.anndata import _constants
from scvi.data.anndata._compat import manager_from_setup_dict
from scvi.data.anndata._constants import _MODEL_NAME_KEY, _SETUP_KWARGS_KEY
from scvi.data.anndata._constants import _MODEL_NAME_KEY, _SETUP_ARGS_KEY
from scvi.model._utils import parse_use_gpu_arg
from scvi.nn import FCLayers

Expand Down Expand Up @@ -102,7 +102,7 @@ def load_query_data(
"It appears you are loading a model from a different class."
)

if _SETUP_KWARGS_KEY not in registry:
if _SETUP_ARGS_KEY not in registry:
raise ValueError(
"Saved model does not contain original setup inputs. "
"Cannot load the original setup."
Expand All @@ -112,7 +112,7 @@ def load_query_data(
adata,
source_registry=registry,
extend_categories=True,
**registry[_SETUP_KWARGS_KEY]
**registry[_SETUP_ARGS_KEY]
)

model = _initialize_model(cls, adata, attr_dict)
Expand Down
39 changes: 32 additions & 7 deletions scvi/model/base/_base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from scvi.data.anndata._constants import (
_MODEL_NAME_KEY,
_SCVI_UUID_KEY,
_SETUP_KWARGS_KEY,
_SETUP_ARGS_KEY,
)
from scvi.data.anndata._utils import _assign_adata_uuid
from scvi.dataloaders import AnnDataLoader
Expand Down Expand Up @@ -134,14 +134,13 @@ def _get_setup_method_args(**setup_locals) -> dict:
Must be called with ``**locals()`` at the start of the ``setup_anndata`` method
to avoid the inclusion of any extraneous variables.
"""
setup_locals.pop("adata")
cls = setup_locals.pop("cls")
model_name = cls.__name__
setup_kwargs = dict()
setup_args = dict()
for k, v in setup_locals.items():
if k not in _SETUP_INPUTS_EXCLUDED_PARAMS:
setup_kwargs[k] = v
return {_MODEL_NAME_KEY: model_name, _SETUP_KWARGS_KEY: setup_kwargs}
setup_args[k] = v
return {_MODEL_NAME_KEY: model_name, _SETUP_ARGS_KEY: setup_args}

@classmethod
def register_manager(cls, adata_manager: AnnDataManager):
Expand Down Expand Up @@ -590,7 +589,7 @@ def load(
"It appears you are loading a model from a different class."
)

if _SETUP_KWARGS_KEY not in registry:
if _SETUP_ARGS_KEY not in registry:
raise ValueError(
"Saved model does not contain original setup inputs. "
"Cannot load the original setup."
Expand All @@ -600,7 +599,7 @@ def load(
# the saved model. This enables simple backwards compatibility in the case of
# newly introduced fields or parameters.
cls.setup_anndata(
adata, source_registry=registry, **registry[_SETUP_KWARGS_KEY]
adata, source_registry=registry, **registry[_SETUP_ARGS_KEY]
)

model = _initialize_model(cls, adata, attr_dict)
Expand Down Expand Up @@ -653,6 +652,30 @@ def setup_anndata(
on a model-specific instance of :class:`~scvi.data.anndata.AnnDataManager`.
"""

@staticmethod
def view_setup_args(dir_path: str, prefix: Optional[str] = None) -> None:
"""
Print args used to setup a saved model.
Parameters
----------
dir_path
Path to saved outputs.
prefix
Prefix of saved file names.
"""
attr_dict = _load_saved_files(dir_path, False, prefix=prefix)[0]

# Legacy support for old setup dict format.
if "scvi_setup_dict_" in attr_dict:
raise NotImplementedError(
"Viewing setup args for pre v0.15.0 models is unsupported. "
"Load and resave the model to use this function."
)

registry = attr_dict.pop("registry_")
AnnDataManager.view_setup_method_args(registry)

def view_anndata_setup(
self, adata: Optional[AnnData] = None, hide_state_registries: bool = False
) -> None:
Expand All @@ -664,6 +687,8 @@ def view_anndata_setup(
adata
AnnData object setup with ``setup_anndata`` or
:meth:`~scvi.data.anndata.AnnDataManager.transfer_setup`.
hide_state_registries
If True, prints a shortened summary without details of each state registry.
"""
if adata is None:
adata = self.adata
Expand Down
9 changes: 8 additions & 1 deletion tests/dataset/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from scvi import REGISTRY_KEYS
from scvi.data.anndata import AnnDataManager
from scvi.data.anndata._constants import _MODEL_NAME_KEY, _SETUP_ARGS_KEY
from scvi.data.anndata.fields import (
CategoricalJointObsField,
CategoricalObsField,
Expand Down Expand Up @@ -39,6 +40,10 @@ def generic_setup_adata_manager(
protein_expression_obsm_key: Optional[str] = None,
protein_names_uns_key: Optional[str] = None,
) -> AnnDataManager:
setup_args = locals()
setup_args.pop("adata")
setup_method_args = {_MODEL_NAME_KEY: "TestModel", _SETUP_ARGS_KEY: setup_args}

batch_field = CategoricalObsField(REGISTRY_KEYS.BATCH_KEY, batch_key)
anndata_fields = [
batch_field,
Expand All @@ -60,6 +65,8 @@ def generic_setup_adata_manager(
is_count_data=True,
)
)
adata_manager = AnnDataManager(fields=anndata_fields)
adata_manager = AnnDataManager(
fields=anndata_fields, setup_method_args=setup_method_args
)
adata_manager.register_fields(adata)
return adata_manager
3 changes: 3 additions & 0 deletions tests/models/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -372,6 +372,7 @@ def test_save_load_model(cls, adata, save_path, prefix=None, legacy=False):
)
else:
model.save(save_path, overwrite=True, save_anndata=True, prefix=prefix)
model.view_setup_args(save_path, prefix=prefix)
model = cls.load(save_path, prefix=prefix)
model.get_latent_representation()

Expand Down Expand Up @@ -428,6 +429,7 @@ def test_save_load_autozi(legacy=False):
)
else:
model.save(save_path, overwrite=True, save_anndata=True, prefix=prefix)
model.view_setup_args(save_path, prefix=prefix)
model = AUTOZI.load(save_path, prefix=prefix)
model.get_latent_representation()
tmp_adata = scvi.data.synthetic_iid(n_genes=200)
Expand Down Expand Up @@ -463,6 +465,7 @@ def test_save_load_scanvi(legacy=False):
)
else:
model.save(save_path, overwrite=True, save_anndata=True, prefix=prefix)
model.view_setup_args(save_path, prefix=prefix)
model = SCANVI.load(save_path, prefix=prefix)
model.get_latent_representation()
tmp_adata = scvi.data.synthetic_iid(n_genes=200)
Expand Down

0 comments on commit e20c69b

Please sign in to comment.