From 2b62e2960ab279d4d9929c558a2bf27a8ed7e2cb Mon Sep 17 00:00:00 2001 From: Justin Hong Date: Thu, 16 Dec 2021 17:17:41 -0800 Subject: [PATCH] Add anndata registration classes to developer API (#1292) * add anndata registration classes to developer API * move manager to anndata submodule to fix docs --- .gitignore | 1 + docs/api/developer.rst | 20 ++++++++++++++++++++ scvi/data/anndata/__init__.py | 2 ++ scvi/data/anndata/fields/_base_field.py | 4 ++-- scvi/data/anndata/fields/_layer_field.py | 6 +++--- scvi/data/anndata/fields/_obs_field.py | 6 +++--- scvi/data/anndata/fields/_obsm_field.py | 9 +++++---- 7 files changed, 36 insertions(+), 12 deletions(-) diff --git a/.gitignore b/.gitignore index 3fa983d3ba..4983bacdaf 100644 --- a/.gitignore +++ b/.gitignore @@ -68,6 +68,7 @@ instance/ # Sphinx documentation docs/_build/ +docs/api/ # PyBuilder target/ diff --git a/docs/api/developer.rst b/docs/api/developer.rst index 4a7ce88aec..50fa442ab9 100644 --- a/docs/api/developer.rst +++ b/docs/api/developer.rst @@ -10,6 +10,26 @@ Import scvi-tools as:: .. currentmodule:: scvi +Data Registration +----------------- + +.. currentmodule:: scvi + +AnnDataFields delineate how scvi-tools refers to fields in AnnData objects. The AnnDataManager provides an interface +for operating over a collection of AnnDataFields and an AnnData object. + + +.. autosummary:: + :toctree: reference/ + :nosignatures: + + data.anndata.AnnDataManager + data.anndata.fields.BaseAnnDataField + data.anndata.fields.LayerField + data.anndata.fields.CategoricalObsField + data.anndata.fields.NumericalJointObsField + data.anndata.fields.CategoricalJointObsField + Data Loaders ------------ diff --git a/scvi/data/anndata/__init__.py b/scvi/data/anndata/__init__.py index 89b9a74ffa..c988351483 100644 --- a/scvi/data/anndata/__init__.py +++ b/scvi/data/anndata/__init__.py @@ -1,7 +1,9 @@ from ._utils import register_tensor_from_anndata, setup_anndata, transfer_anndata_setup +from .manager import AnnDataManager __all__ = [ "setup_anndata", "transfer_anndata_setup", "register_tensor_from_anndata", + "AnnDataManager", ] diff --git a/scvi/data/anndata/fields/_base_field.py b/scvi/data/anndata/fields/_base_field.py index a47ed2d74a..5d7f8f494a 100644 --- a/scvi/data/anndata/fields/_base_field.py +++ b/scvi/data/anndata/fields/_base_field.py @@ -23,14 +23,14 @@ def __init__(self) -> None: @property @abstractmethod - def registry_key(self): + def registry_key(self) -> str: """The key that is referenced by models via a data loader.""" pass @property @abstractmethod def attr_name(self) -> str: - """The name of the AnnData attribute where the data is stored (e.g. obs).""" + """The name of the AnnData attribute where the data is stored.""" pass @property diff --git a/scvi/data/anndata/fields/_layer_field.py b/scvi/data/anndata/fields/_layer_field.py index 116e08282e..e23fb09059 100644 --- a/scvi/data/anndata/fields/_layer_field.py +++ b/scvi/data/anndata/fields/_layer_field.py @@ -40,15 +40,15 @@ def __init__( self.is_count_data = is_count_data @property - def registry_key(self): + def registry_key(self) -> str: return self._registry_key @property - def attr_name(self): + def attr_name(self) -> str: return self._attr_name @property - def attr_key(self): + def attr_key(self) -> Optional[str]: return self._attr_key @property diff --git a/scvi/data/anndata/fields/_obs_field.py b/scvi/data/anndata/fields/_obs_field.py index 5a7516c434..6995049ea0 100644 --- a/scvi/data/anndata/fields/_obs_field.py +++ b/scvi/data/anndata/fields/_obs_field.py @@ -24,15 +24,15 @@ def __init__(self, registry_key: str, obs_key: str) -> None: self._attr_key = obs_key @property - def registry_key(self): + def registry_key(self) -> str: return self._registry_key @property - def attr_name(self): + def attr_name(self) -> str: return self._attr_name @property - def attr_key(self): + def attr_key(self) -> str: return self._attr_key @property diff --git a/scvi/data/anndata/fields/_obsm_field.py b/scvi/data/anndata/fields/_obsm_field.py index 9247d0810d..fc0d3643ca 100644 --- a/scvi/data/anndata/fields/_obsm_field.py +++ b/scvi/data/anndata/fields/_obsm_field.py @@ -25,11 +25,11 @@ def __init__( self._registry_key = registry_key @property - def registry_key(self): + def registry_key(self) -> str: return self._registry_key @property - def attr_name(self): + def attr_name(self) -> str: return self._attr_name @@ -62,11 +62,12 @@ def _combine_obs_fields(self, adata: AnnData) -> None: adata.obsm[self.attr_key] = adata.obs[self.obs_keys].copy() @property - def obs_keys(self): + def obs_keys(self) -> List[str]: + """List of .obs keys that make up this joint field.""" return self._obs_keys @property - def attr_key(self): + def attr_key(self) -> str: return self._attr_key @property