Skip to content

Commit

Permalink
Adapt all internal models to new setup (#1301)
Browse files Browse the repository at this point in the history
* adapt LDA

* adapt linearscvi

* remove _get_var_names_from_setup_anndata

* adapt peakvi

* adapt autozi

* adapt scanvi

* fix scanvi test

* fix totalvi test

* fix dataloader tests

* fix multiple cov tests

* adapt condscvi

* adapt destvi

* adapt multivi

* fix setup compat test

* remove get_from_registry util

* fix scanvi and peakvi scarches tests

* fix backwards compat tests and default missing summary stat in models

* address comment

* Adapt all external models to new setup (#1302)

* adapt cellassign

* adapt gimvi

* adapt solo model

* adapt stereoscope
  • Loading branch information
justjhong committed Jan 15, 2022
1 parent 8821b6b commit 3afbf24
Show file tree
Hide file tree
Showing 35 changed files with 864 additions and 654 deletions.
2 changes: 2 additions & 0 deletions scvi/_constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ class _CONSTANTS_NT(NamedTuple):
PROTEIN_EXP_KEY: str = "proteins"
CAT_COVS_KEY: str = "extra_categorical_covs"
CONT_COVS_KEY: str = "extra_continuous_covs"
INDICES_KEY: str = "ind_x"
SIZE_FACTOR_KEY: str = "size_factor"


_CONSTANTS = _CONSTANTS_NT()
92 changes: 71 additions & 21 deletions scvi/data/anndata/_compat.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,32 @@
from copy import deepcopy

import numpy as np
from anndata import AnnData
from sklearn.utils import deprecated

from scvi import _CONSTANTS

from . import _constants
from ._manager import AnnDataManager
from .fields import (
CategoricalJointObsField,
CategoricalObsField,
LayerField,
NumericalJointObsField,
NumericalObsField,
ProteinObsmField,
)

# Translation table from the registry keys found in legacy setup dictionaries
# to the corresponding registry keys defined on ``_CONSTANTS``. The conversion
# loops in this module skip any legacy registry key that is not listed here.
LEGACY_REGISTRY_KEY_MAP = {
"X": _CONSTANTS.X_KEY,
"batch_indices": _CONSTANTS.BATCH_KEY,
"labels": _CONSTANTS.LABELS_KEY,
"cat_covs": _CONSTANTS.CAT_COVS_KEY,
"cont_covs": _CONSTANTS.CONT_COVS_KEY,
"protein_expression": _CONSTANTS.PROTEIN_EXP_KEY,
"ind_x": _CONSTANTS.INDICES_KEY,
}


def registry_from_setup_dict(setup_dict: dict) -> dict:
"""
Expand All @@ -37,15 +52,19 @@ def registry_from_setup_dict(setup_dict: dict) -> dict:
registry_key,
adata_mapping,
) in data_registry.items(): # Note: this does not work for empty fields.
if registry_key not in LEGACY_REGISTRY_KEY_MAP:
continue
new_registry_key = LEGACY_REGISTRY_KEY_MAP[registry_key]

attr_name = adata_mapping[_constants._DR_ATTR_NAME]
attr_key = adata_mapping[_constants._DR_ATTR_KEY]

field_registries[registry_key] = {
field_registries[new_registry_key] = {
_constants._DATA_REGISTRY_KEY: adata_mapping,
_constants._STATE_REGISTRY_KEY: dict(),
_constants._SUMMARY_STATS_KEY: dict(),
}
field_registry = field_registries[registry_key]
field_registry = field_registries[new_registry_key]
field_state_registry = field_registry[_constants._STATE_REGISTRY_KEY]
field_summary_stats = field_registry[_constants._SUMMARY_STATS_KEY]

Expand All @@ -58,21 +77,32 @@ def registry_from_setup_dict(setup_dict: dict) -> dict:
field_state_registry[
CategoricalObsField.CATEGORICAL_MAPPING_KEY
] = categorical_mapping["mapping"]
if attr_key == "_scvi_batch":
field_summary_stats[f"n_{registry_key}"] = summary_stats["n_batch"]
elif attr_key == "_scvi_labels":
field_summary_stats[f"n_{registry_key}"] = summary_stats["n_labels"]
if new_registry_key == _CONSTANTS.BATCH_KEY:
field_summary_stats[f"n_{new_registry_key}"] = summary_stats["n_batch"]
elif new_registry_key == _CONSTANTS.LABELS_KEY:
field_summary_stats[f"n_{new_registry_key}"] = summary_stats["n_labels"]
elif attr_name == _constants._ADATA_ATTRS.OBSM:
if attr_key == "_scvi_extra_continuous":
if new_registry_key == _CONSTANTS.CONT_COVS_KEY:
columns = setup_dict["extra_continuous_keys"].copy()
field_state_registry[NumericalJointObsField.COLUMNS_KEY] = columns
field_summary_stats[f"n_{registry_key}"] = columns.shape[0]
elif attr_key == "_scvi_extra_categoricals":
field_summary_stats[f"n_{new_registry_key}"] = columns.shape[0]
elif new_registry_key == _CONSTANTS.CAT_COVS_KEY:
extra_categoricals_mapping = deepcopy(setup_dict["extra_categoricals"])
field_state_registry.update(deepcopy(setup_dict["extra_categoricals"]))
field_summary_stats[f"n_{registry_key}"] = len(
field_summary_stats[f"n_{new_registry_key}"] = len(
extra_categoricals_mapping["keys"]
)
elif new_registry_key == _CONSTANTS.PROTEIN_EXP_KEY:
field_state_registry[ProteinObsmField.COLUMN_NAMES_KEY] = setup_dict[
"protein_names"
].copy()
if "totalvi_batch_mask" in setup_dict:
field_state_registry[
ProteinObsmField.PROTEIN_BATCH_MASK
] = setup_dict["totalvi_batch_mask"].copy()
field_summary_stats[f"n_{new_registry_key}"] = len(
setup_dict["protein_names"]
)
return registry


Expand Down Expand Up @@ -106,37 +136,57 @@ def manager_from_setup_dict(
data_registry = setup_dict[_constants._DATA_REGISTRY_KEY]
categorical_mappings = setup_dict["categorical_mappings"]
for registry_key, adata_mapping in data_registry.items():
if registry_key not in LEGACY_REGISTRY_KEY_MAP:
continue
new_registry_key = LEGACY_REGISTRY_KEY_MAP[registry_key]

field = None
attr_name = adata_mapping[_constants._DR_ATTR_NAME]
attr_key = adata_mapping[_constants._DR_ATTR_KEY]
if attr_name == _constants._ADATA_ATTRS.X:
field = LayerField(registry_key, None)
field = LayerField(_CONSTANTS.X_KEY, None)
setup_kwargs["layer"] = None
elif attr_name == _constants._ADATA_ATTRS.LAYERS:
field = LayerField(registry_key, attr_key)
field = LayerField(_CONSTANTS.X_KEY, attr_key)
setup_kwargs["layer"] = attr_key
elif attr_name == _constants._ADATA_ATTRS.OBS:
original_key = categorical_mappings[attr_key]["original_key"]
field = CategoricalObsField(registry_key, original_key)
setup_kwargs[f"{registry_key}_key"] = original_key
if new_registry_key in {_CONSTANTS.BATCH_KEY, _CONSTANTS.LABELS_KEY}:
original_key = categorical_mappings[attr_key]["original_key"]
field = CategoricalObsField(new_registry_key, original_key)
setup_kwargs[f"{new_registry_key}_key"] = original_key
elif new_registry_key == _CONSTANTS.INDICES_KEY:
adata.obs[attr_key] = np.arange(adata.n_obs).astype("int64")
field = NumericalObsField(new_registry_key, attr_key)
elif attr_name == _constants._ADATA_ATTRS.OBSM:
if attr_key == "_scvi_extra_continuous":
if new_registry_key == _CONSTANTS.CONT_COVS_KEY:
obs_keys = setup_dict["extra_continuous_keys"]
field = NumericalJointObsField(registry_key, obs_keys)
field = NumericalJointObsField(_CONSTANTS.CONT_COVS_KEY, obs_keys)
setup_kwargs["continuous_covariate_keys"] = obs_keys
elif attr_key == "_scvi_extra_categoricals":
elif new_registry_key == _CONSTANTS.CAT_COVS_KEY:
obs_keys = setup_dict["extra_categoricals"]["keys"]
field = CategoricalJointObsField(registry_key, obs_keys)
field = CategoricalJointObsField(_CONSTANTS.CAT_COVS_KEY, obs_keys)
setup_kwargs["categorical_covariate_keys"] = obs_keys
elif new_registry_key == _CONSTANTS.PROTEIN_EXP_KEY:
protein_names = setup_dict["protein_names"]
adata.uns["_protein_names"] = protein_names
field = ProteinObsmField(
_CONSTANTS.PROTEIN_EXP_KEY,
attr_key,
"_scvi_batch",
colnames_uns_key="_protein_names",
)
setup_kwargs["protein_expression_obsm_key"] = attr_key
setup_kwargs["protein_names_uns_key"] = "_protein_names"
else:
raise NotImplementedError(
f"Unrecognized .obsm attribute {attr_key}. Backwards compatibility unavailable."
f"Unrecognized .obsm attribute {attr_key} registered as {new_registry_key}. Backwards compatibility unavailable."
)
else:
raise NotImplementedError(
f"Backwards compatibility for attribute {attr_name} is not implemented yet."
f"Backwards compatibility for attribute {attr_name} is not implemented."
)
fields.append(field)

setup_method_args = {
_constants._MODEL_NAME_KEY: cls.__name__,
_constants._SETUP_KWARGS_KEY: setup_kwargs,
Expand Down
44 changes: 0 additions & 44 deletions scvi/data/anndata/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,50 +44,6 @@ def get_anndata_attribute(
return field


def get_from_registry(
    adata: anndata.AnnData, key: str
) -> Union[np.ndarray, pd.DataFrame]:
    """
    Return the AnnData object registered under ``key`` in ``.uns['_scvi']['data_registry']``.

    The data registry entry for ``key`` is read to find the attribute name and
    attribute key, which are then resolved with :func:`get_anndata_attribute`.

    Parameters
    ----------
    adata
        anndata object already setup with setup_anndata
    key
        key of object to get from ``adata.uns['_scvi']['data_registry']``

    Returns
    -------
    The requested data

    Raises
    ------
    KeyError
        If the setup dictionary, the data registry or ``key`` is missing.

    Examples
    --------
    >>> import scvi
    >>> adata = scvi.data.cortex()
    >>> adata.uns['_scvi']['data_registry']
    {'X': ['_X', None],
    'batch': ['obs', 'batch'],
    'labels': ['obs', 'labels']}
    >>> batch = get_from_registry(adata, "batch")
    >>> batch
    array([[0],
           [0],
           [0],
           ...,
           [0],
           [0],
           [0]])
    """
    setup_dict = adata.uns[_constants._SETUP_DICT_KEY]
    registry_entry = setup_dict[_constants._DATA_REGISTRY_KEY][key]

    # Each registry entry records which AnnData attribute holds the data and
    # under which key inside that attribute.
    attr_name = registry_entry[_constants._DR_ATTR_NAME]
    attr_key = registry_entry[_constants._DR_ATTR_KEY]

    return get_anndata_attribute(adata, attr_name, attr_key)


@deprecated(
extra="Please use the model-specific setup_anndata methods instead. The global method will be removed in version 0.15.0."
)
Expand Down
2 changes: 2 additions & 0 deletions scvi/data/anndata/fields/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from ._layer_field import LayerField
from ._obs_field import CategoricalObsField, NumericalObsField
from ._obsm_field import CategoricalJointObsField, NumericalJointObsField, ObsmField
from ._scanvi import LabelsWithUnlabeledObsField
from ._totalvi import ProteinObsmField

__all__ = [
Expand All @@ -13,4 +14,5 @@
"CategoricalJointObsField",
"ObsmField",
"ProteinObsmField",
"LabelsWithUnlabeledObsField",
]
11 changes: 9 additions & 2 deletions scvi/data/anndata/fields/_obs_field.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ class CategoricalObsField(BaseObsField):
"""

CATEGORICAL_MAPPING_KEY = "categorical_mapping"
ORIGINAL_ATTR_KEY = "original_key"

def __init__(self, registry_key: str, obs_key: Optional[str]) -> None:
self.is_default = obs_key is None
Expand Down Expand Up @@ -114,7 +115,10 @@ def register_field(self, adata: AnnData) -> dict:
categorical_mapping = _make_obs_column_categorical(
adata, self._original_attr_key, self.attr_key, return_mapping=True
)
return {self.CATEGORICAL_MAPPING_KEY: categorical_mapping}
return {
self.CATEGORICAL_MAPPING_KEY: categorical_mapping,
self.ORIGINAL_ATTR_KEY: self._original_attr_key,
}

def transfer_field(
self,
Expand Down Expand Up @@ -150,7 +154,10 @@ def transfer_field(
categorical_dtype=cat_dtype,
return_mapping=True,
)
return {self.CATEGORICAL_MAPPING_KEY: new_mapping}
return {
self.CATEGORICAL_MAPPING_KEY: new_mapping,
self.ORIGINAL_ATTR_KEY: self._original_attr_key,
}

def get_summary_stats(self, state_registry: dict) -> dict:
categorical_mapping = state_registry[self.CATEGORICAL_MAPPING_KEY]
Expand Down
3 changes: 2 additions & 1 deletion scvi/data/anndata/fields/_obsm_field.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,8 @@ def is_empty(self) -> bool:

def validate_field(self, adata: AnnData) -> None:
super().validate_field(adata)
assert self.attr_key in adata.obsm, f"{self.attr_key} not found in adata.obsm."
if self.attr_key not in adata.obsm:
raise KeyError(f"{self.attr_key} not found in adata.obsm.")

obsm_data = self.get_field_data(adata)

Expand Down
90 changes: 90 additions & 0 deletions scvi/data/anndata/fields/_scanvi.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
from typing import Optional, Union

import numpy as np
from anndata import AnnData
from pandas.api.types import CategoricalDtype

from scvi.data.anndata._utils import _make_obs_column_categorical

from ._obs_field import CategoricalObsField


class LabelsWithUnlabeledObsField(CategoricalObsField):
    """
    An AnnDataField for labels which include explicitly unlabeled cells.

    When the user-specified unlabeled category occurs in the labels, it is
    moved to the final index of the categorical mapping.

    Parameters
    ----------
    registry_key
        Key to register field under in data registry.
    obs_key
        Key to access the field in the AnnData obs mapping. ``None`` is
        forwarded to :class:`CategoricalObsField`, which treats the field as a
        default field.
    unlabeled_category
        Value assigned to unlabeled cells.
    """

    # State registry keys written in addition to the parent class's keys.
    UNLABELED_CATEGORY = "unlabeled_category"
    WAS_REMAPPED = "was_remapped"

    def __init__(
        self,
        registry_key: str,
        obs_key: Optional[str],
        unlabeled_category: Union[str, int, float],
    ) -> None:
        super().__init__(registry_key, obs_key)
        self._unlabeled_category = unlabeled_category

    def _remap_unlabeled_to_final_category(
        self, adata: AnnData, mapping: np.ndarray
    ) -> dict:
        """
        Move the unlabeled category to the last slot of ``mapping`` if present.

        Parameters
        ----------
        adata
            AnnData object whose labels column is re-encoded when a remap occurs.
        mapping
            Current categorical mapping; modified in place when a remap occurs.

        Returns
        -------
        State registry dictionary holding the (possibly re-encoded) mapping, the
        original obs key, the unlabeled category and whether a remap happened.
        """
        original_labels = self._get_original_column(adata)
        was_remapped = self._unlabeled_category in original_labels

        if was_remapped:
            # Position of the first entry equal to the unlabeled category.
            position = np.where(mapping == self._unlabeled_category)[0][0]
            # Swap it with whatever currently occupies the final slot.
            mapping[position], mapping[-1] = mapping[-1], mapping[position]
            ordered_dtype = CategoricalDtype(categories=mapping, ordered=True)
            # Re-encode the labels column so its codes follow the new order.
            mapping = _make_obs_column_categorical(
                adata,
                self._original_attr_key,
                self.attr_key,
                categorical_dtype=ordered_dtype,
                return_mapping=True,
            )

        return {
            self.CATEGORICAL_MAPPING_KEY: mapping,
            self.ORIGINAL_ATTR_KEY: self._original_attr_key,
            self.UNLABELED_CATEGORY: self._unlabeled_category,
            self.WAS_REMAPPED: was_remapped,
        }

    def register_field(self, adata: AnnData) -> dict:
        """Register the labels via the parent class, then remap the unlabeled category."""
        if self.is_default:
            self._setup_default_attr(adata)

        parent_state = super().register_field(adata)
        return self._remap_unlabeled_to_final_category(
            adata, parent_state[self.CATEGORICAL_MAPPING_KEY]
        )

    def transfer_field(
        self,
        state_registry: dict,
        adata_target: AnnData,
        extend_categories: bool = False,
        **kwargs,
    ) -> dict:
        """Transfer the labels via the parent class, then remap the unlabeled category."""
        parent_state = super().transfer_field(
            state_registry, adata_target, extend_categories=extend_categories, **kwargs
        )
        return self._remap_unlabeled_to_final_category(
            adata_target, parent_state[self.CATEGORICAL_MAPPING_KEY]
        )
9 changes: 5 additions & 4 deletions scvi/dataloaders/_data_splitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

from scvi import _CONSTANTS, settings
from scvi.data.anndata import AnnDataManager
from scvi.data.anndata.fields import LabelsWithUnlabeledObsField
from scvi.dataloaders._ann_dataloader import AnnDataLoader, BatchSampler
from scvi.dataloaders._semi_dataloader import SemiSupervisedDataLoader
from scvi.model._utils import parse_use_gpu_arg
Expand Down Expand Up @@ -212,10 +213,10 @@ def __init__(
self.data_loader_kwargs = kwargs
self.n_samples_per_label = n_samples_per_label

setup_dict = adata_manager.get_setup_dict()
key = setup_dict["data_registry"][_CONSTANTS.LABELS_KEY]["attr_key"]
original_key = setup_dict["categorical_mappings"][key]["original_key"]
labels = np.asarray(adata_manager.obs[original_key]).ravel()
original_key = adata_manager.get_state_registry(_CONSTANTS.LABELS_KEY)[
LabelsWithUnlabeledObsField.ORIGINAL_ATTR_KEY
]
labels = np.asarray(adata_manager.adata.obs[original_key]).ravel()
self._unlabeled_indices = np.argwhere(labels == unlabeled_category).ravel()
self._labeled_indices = np.argwhere(labels != unlabeled_category).ravel()

Expand Down
Loading

0 comments on commit 3afbf24

Please sign in to comment.