Skip to content

Commit

Permalink
Move and reorganize setup dict out of AnnData into Manager (#1285)
Browse files Browse the repository at this point in the history
* new base field API for setup dict

* adapt manager to new registry

* backwards compatibility implementation

* train works

* adapt get_from_registry function

* simplify needs transfer logic, working test_scvi

* change _REGISTRY_KEYS back to _CONSTANTS

* address comments, fix compat test

* codacy
  • Loading branch information
justjhong authored Dec 14, 2021
1 parent a8e4720 commit 5c3777f
Show file tree
Hide file tree
Showing 26 changed files with 642 additions and 561 deletions.
4 changes: 2 additions & 2 deletions scvi/_constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@ class _CONSTANTS_NT(NamedTuple):
BATCH_KEY: str = "batch"
LABELS_KEY: str = "labels"
PROTEIN_EXP_KEY: str = "protein_expression"
CAT_COVS_KEY: str = "extra_categoricals"
CONT_COVS_KEY: str = "extra_continuous"
CAT_COVS_KEY: str = "extra_categorical_covs"
CONT_COVS_KEY: str = "extra_continuous_covs"


_CONSTANTS = _CONSTANTS_NT()
8 changes: 1 addition & 7 deletions scvi/data/anndata/__init__.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,7 @@
from ._utils import (
get_from_registry,
register_tensor_from_anndata,
setup_anndata,
transfer_anndata_setup,
)
from ._utils import register_tensor_from_anndata, setup_anndata, transfer_anndata_setup

__all__ = [
"setup_anndata",
"get_from_registry",
"transfer_anndata_setup",
"register_tensor_from_anndata",
]
83 changes: 74 additions & 9 deletions scvi/data/anndata/_compat.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from copy import deepcopy

from anndata import AnnData

from . import _constants
Expand All @@ -10,6 +12,69 @@
from .manager import AnnDataManager


def registry_from_setup_dict(setup_dict: dict) -> dict:
    """
    Converts old setup dict format to new registry dict format.

    Only to be used for backwards compatibility when loading setup dictionaries for models.
    Takes old hard-coded setup dictionary structure and fills in the analogous registry structure.

    Parameters
    ----------
    setup_dict
        Setup dictionary created after registering an AnnData with former `setup_anndata(...)` implementation.

    Returns
    -------
    Registry dictionary containing the scvi version of ``setup_dict`` and one field registry
    per entry of the old data registry. Each field registry holds the original data registry
    mapping, a state registry, and the summary stats attributable to that field.
    """
    registry = {
        _constants._SCVI_VERSION_KEY: setup_dict[_constants._SCVI_VERSION_KEY],
        _constants._FIELD_REGISTRIES_KEY: {},
    }
    data_registry = setup_dict[_constants._DATA_REGISTRY_KEY]
    categorical_mappings = setup_dict["categorical_mappings"]
    summary_stats = setup_dict[_constants._SUMMARY_STATS_KEY]
    field_registries = registry[_constants._FIELD_REGISTRIES_KEY]
    for (
        registry_key,
        adata_mapping,
    ) in data_registry.items():  # Note: this does not work for empty fields.
        attr_name = adata_mapping[_constants._DR_ATTR_NAME]
        attr_key = adata_mapping[_constants._DR_ATTR_KEY]

        field_registries[registry_key] = {
            _constants._DATA_REGISTRY_KEY: adata_mapping,
            _constants._STATE_REGISTRY_KEY: dict(),
            _constants._SUMMARY_STATS_KEY: dict(),
        }
        field_registry = field_registries[registry_key]
        field_state_registry = field_registry[_constants._STATE_REGISTRY_KEY]
        field_summary_stats = field_registry[_constants._SUMMARY_STATS_KEY]

        if attr_name in (_constants._ADATA_ATTRS.X, _constants._ADATA_ATTRS.LAYERS):
            # Layer-like fields: state and summary stats are the same two counts.
            field_state_registry[LayerField.N_CELLS_KEY] = summary_stats["n_cells"]
            field_state_registry[LayerField.N_VARS_KEY] = summary_stats["n_vars"]
            field_summary_stats.update(field_state_registry)
        elif attr_name == _constants._ADATA_ATTRS.OBS:
            categorical_mapping = categorical_mappings[attr_key]
            field_state_registry[
                CategoricalObsField.CATEGORICAL_MAPPING_KEY
            ] = categorical_mapping["mapping"]
            # Only the batch and labels columns had summary stats in the old format.
            if attr_key == "_scvi_batch":
                field_summary_stats[f"n_{registry_key}"] = summary_stats["n_batch"]
            elif attr_key == "_scvi_labels":
                field_summary_stats[f"n_{registry_key}"] = summary_stats["n_labels"]
        elif attr_name == _constants._ADATA_ATTRS.OBSM:
            if attr_key == "_scvi_extra_continuous":
                columns = setup_dict["extra_continuous_keys"].copy()
                field_state_registry[NumericalJointObsField.COLUMNS_KEY] = columns
                # len() matches shape[0] for array-likes and also accepts plain lists.
                field_summary_stats[f"n_{registry_key}"] = len(columns)
            elif attr_key == "_scvi_extra_categoricals":
                # A single deep copy is enough: it is stored in the state registry and
                # only read (never mutated) to derive the summary stat below.
                extra_categoricals_mapping = deepcopy(setup_dict["extra_categoricals"])
                field_state_registry.update(extra_categoricals_mapping)
                field_summary_stats[f"n_{registry_key}"] = len(
                    extra_categoricals_mapping["keys"]
                )
    return registry


def manager_from_setup_dict(
adata: AnnData, setup_dict: dict, **transfer_kwargs
) -> AnnDataManager:
Expand All @@ -25,13 +90,13 @@ def manager_from_setup_dict(
adata
AnnData object to be registered.
setup_dict
Setup dictionary created after registering an AnnData using an AnnDataManager object.
Setup dictionary created after registering an AnnData with former `setup_anndata(...)` implementation.
**kwargs
Keyword arguments to modify transfer behavior.
"""
source_adata_manager = AnnDataManager()
data_registry = setup_dict[_constants._DATA_REGISTRY_KEY]
categorical_mappings = setup_dict[_constants._CATEGORICAL_MAPPINGS_KEY]
categorical_mappings = setup_dict["categorical_mappings"]
for registry_key, adata_mapping in data_registry.items():
field = None
attr_name = adata_mapping[_constants._DR_ATTR_NAME]
Expand All @@ -41,15 +106,14 @@ def manager_from_setup_dict(
elif attr_name == _constants._ADATA_ATTRS.LAYERS:
field = LayerField(registry_key, attr_key)
elif attr_name == _constants._ADATA_ATTRS.OBS:
original_key = categorical_mappings[attr_key][_constants._CM_ORIGINAL_KEY]
original_key = categorical_mappings[attr_key]["original_key"]
field = CategoricalObsField(registry_key, original_key)
elif attr_name == _constants._ADATA_ATTRS.OBSM:
cont_cov_column_key = f"{registry_key}_keys"
if cont_cov_column_key in setup_dict:
obs_keys = setup_dict[cont_cov_column_key]
if attr_key == "_scvi_extra_continuous":
obs_keys = setup_dict["extra_continuous_keys"]
field = NumericalJointObsField(registry_key, obs_keys)
elif registry_key in setup_dict:
obs_keys = setup_dict[registry_key][_constants._JO_CM_KEYS_KEY]
elif attr_key == "_scvi_extra_categoricals":
obs_keys = setup_dict["extra_categoricals"]["keys"]
field = CategoricalJointObsField(registry_key, obs_keys)
else:
raise NotImplementedError(
Expand All @@ -60,6 +124,7 @@ def manager_from_setup_dict(
f"Backwards compatibility for attribute {attr_name} is not implemented yet."
)
source_adata_manager.add_field(field)
source_registry = registry_from_setup_dict(setup_dict)
return source_adata_manager.transfer_setup(
adata, source_setup_dict=setup_dict, **transfer_kwargs
adata, source_registry=source_registry, **transfer_kwargs
)
20 changes: 4 additions & 16 deletions scvi/data/anndata/_constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,17 @@
################################

_SCVI_UUID_KEY = "_scvi_uuid"
_SOURCE_SCVI_UUID_KEY = "_source_scvi_uuid"

#############################
# scVI Setup Dict Constants #
#############################

_SETUP_DICT_KEY = "_scvi"
_SCVI_VERSION_KEY = "scvi_version"
_FIELD_REGISTRIES_KEY = "field_registries"
_DATA_REGISTRY_KEY = "data_registry"
_CATEGORICAL_MAPPINGS_KEY = "categorical_mappings"
_STATE_REGISTRY_KEY = "state_registry"
_SUMMARY_STATS_KEY = "summary_stats"

################################
Expand All @@ -22,21 +25,6 @@
_DR_ATTR_NAME = "attr_name"
_DR_ATTR_KEY = "attr_key"

#######################################
# scVI Categorical Mappings Constants #
#######################################

_CM_ORIGINAL_KEY = "original_key"
_CM_MAPPING_KEY = "mapping"

#######################################
# scVI Joint Obs Categorical Mappings Constants #
#######################################

_JO_CM_MAPPINGS_KEY = "mappings"
_JO_CM_KEYS_KEY = "keys"
_JO_CM_N_CATS_PER_KEY = "n_cats_per_key"


############################
# AnnData Object Constants #
Expand Down
Loading

0 comments on commit 5c3777f

Please sign in to comment.