Skip to content

Commit

Permalink
[AnnData setup refactor] Make setup_anndata a static method on model …
Browse files Browse the repository at this point in the history
…classes rather than one global function (#1150)

* preliminary changes

* fix poetry.lock

* fix some bad imports

* address flake8 errors

* update comments and deprecate existing global function

* fix test

* add some coverage

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* add more coverage

* Fix CondSCVI docstring

* address pr feedback

* address pr feedback cntd

* fix some typos

* preliminary changes

* fix some bad imports

* address flake8 errors

* update comments and deprecate existing global function

* fix test

* add some coverage

* add more coverage

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fix CondSCVI docstring

* address pr feedback

* address pr feedback cntd

* make LDA setup_anndata static

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* remove local library stuff

* add docrep and use it for some models

* refactor other models' setup_anndata docstrings

* refactor lda's setup_anndata docstring

* add noqa comments to fix codacy failures

* add abstract setup_anndata on the BaseModelClass

* try again to fix codacy

* add a few more fixes

* remove noqa and add/remove periods to make codacy happy, also use a specific dsp instance

* fix repeated param_cat_conv

* add version to deprecation message

Co-authored-by: Valeh Valiollah Pour Amiri <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Justin Hong <[email protected]>
  • Loading branch information
4 people authored Sep 27, 2021
1 parent c81da01 commit 9855238
Show file tree
Hide file tree
Showing 46 changed files with 933 additions and 196 deletions.
2 changes: 1 addition & 1 deletion docs/contributing/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,7 @@ The mainstream development branch is the master branch. We snap releases off of
We use the MeeseeksDev GitHub bot for automatic backporting. The way it works, in a nutshell, is that the bot listens to certain web events - for example commits containing “@meeseeksdev backport to [BRANCHNAME]” on a PR - and automatically opens a PR to that repo/branch. (Note: They open the PR sourced from a fork of the repo under the `MeeseeksMachine <https://github.com/meeseeksmachine>`_ organization, into the repo/branch of interest. That’s why under MeeseeksMachine you see a collection of repo's that are forks of the repo's that use MeeseeksDev).

For each release, we create a branch [MAJOR].[MINOR].x where MAJOR and MINOR are the Major and Minor version numbers for that release, respectively, and x is the literal “x”. Every time a bug fix PR is merged into master, we evaluate whether it is worthy of being backported into the current release and if so use MeeseeksDev to do it for us if it can. How? Simply leave a comment on the PR that was merged into master that says: “@meeseeksdev backport to [MAJOR].[MINOR].x” (for example “@meeseeksdev backport to 0.14.x” if we are on a release from the 0.14 series.
The PR also needs to be associated with a Milestone the description of which contains “on-merge: backport to [BRANCHNAME]”.
Note: Auto backporting can also be triggered if you associate the PR with a Milestone or Label the description of which contains “on-merge: backport to [BRANCHNAME]”.

.. highlight:: none

Expand Down
53 changes: 49 additions & 4 deletions poetry.lock

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ version = "0.13.0"
anndata = ">=0.7.5"
black = {version = ">=20.8b1", optional = true}
codecov = {version = ">=2.0.8", optional = true}
docrep = ">=0.3.2"
flake8 = {version = ">=3.7.7", optional = true}
h5py = ">=2.9.0"
importlib-metadata = {version = "^1.0", python = "<3.8"}
Expand Down
63 changes: 61 additions & 2 deletions scvi/_docs.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
"""Shared docstrings."""
from docrep import DocstringProcessor

doc_differential_expression = """\
adata
Expand Down Expand Up @@ -37,7 +37,7 @@
batch_correction
Whether to correct for batch effects in DE inference.
batchid1
Subset of categories from `batch_key` registered in :func:`~scvi.data.setup_anndata`,
Subset of categories from `batch_key` registered in ``setup_anndata``,
e.g. [`'batch1'`, `'batch2'`, `'batch3'`], for `group1`. Only used if `batch_correction` is `True`, and
by default all categories are used.
batchid2
Expand All @@ -50,3 +50,62 @@
silent
If True, disables the progress bar. Default: False.
"""

summary = """\
Sets up the :class:`~anndata.AnnData` object for this model.
A mapping will be created between data fields used by this model to their respective locations in adata.
None of the data in adata are modified. Only adds fields to adata"""

param_adata = """\
adata
AnnData object containing raw counts. Rows represent cells, columns represent features."""

param_batch_key = """\
batch_key
key in `adata.obs` for batch information. Categories will automatically be converted into integer
categories and saved to `adata.obs['_scvi_batch']`. If `None`, assigns the same batch to all the data."""

param_labels_key = """\
labels_key
key in `adata.obs` for label information. Categories will automatically be converted into integer
categories and saved to `adata.obs['_scvi_labels']`. If `None`, assigns the same label to all the data."""

param_layer = """\
layer
if not `None`, uses this as the key in `adata.layers` for raw count data."""

param_cat_cov_keys = """\
categorical_covariate_keys
keys in `adata.obs` that correspond to categorical data."""

param_cont_cov_keys = """\
continuous_covariate_keys
keys in `adata.obs` that correspond to continuous data."""

param_copy = """\
copy
if `True`, a copy of adata is returned."""

returns = """\
If ``copy``, will return :class:`~anndata.AnnData`.
Adds the following fields to adata:
.uns['_scvi']
`scvi` setup dictionary
.obs['_scvi_labels']
labels encoded as integers
.obs['_scvi_batch']
batch encoded as integers"""

setup_anndata_dsp = DocstringProcessor(
summary=summary,
param_adata=param_adata,
param_batch_key=param_batch_key,
param_labels_key=param_labels_key,
param_layer=param_layer,
param_cat_cov_keys=param_cat_cov_keys,
param_cont_cov_keys=param_cont_cov_keys,
param_copy=param_copy,
returns=returns,
)
33 changes: 30 additions & 3 deletions scvi/data/_anndata.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from pandas.api.types import CategoricalDtype
from rich.console import Console
from scipy.sparse import isspmatrix
from sklearn.utils import deprecated

import scvi
from scvi import _CONSTANTS
Expand All @@ -30,7 +31,7 @@ def get_from_registry(adata: anndata.AnnData, key: str) -> np.ndarray:
Parameters
----------
adata
anndata object already setup with `scvi.data.setup_anndata()`
anndata object already setup with setup_anndata
key
key of object to get from ``adata.uns['_scvi]['data_registry']``
Expand Down Expand Up @@ -70,6 +71,9 @@ def get_from_registry(adata: anndata.AnnData, key: str) -> np.ndarray:
return data


@deprecated(
extra="Please use the model-specific setup_anndata methods instead. The global method will be removed in version 0.15.0."
)
def setup_anndata(
adata: anndata.AnnData,
batch_key: Optional[str] = None,
Expand All @@ -85,7 +89,6 @@ def setup_anndata(
Sets up :class:`~anndata.AnnData` object for models.
A mapping will be created between data fields used by models to their respective locations in adata.
This method will also compute the log mean and log variance per batch for the library size prior.
None of the data in adata are modified. Only adds fields to adata.
Expand Down Expand Up @@ -139,7 +142,7 @@ def setup_anndata(
uns: 'protein_names'
obsm: 'protein_expression'
Filter cells and run preprocessing before `setup_anndata`
Filter cells and run preprocessing before ``setup_anndata``
>>> sc.pp.filter_cells(adata, min_counts = 0)
Expand All @@ -166,6 +169,30 @@ def setup_anndata(
INFO Registered keys:['X', 'batch_indices', 'labels', 'protein_expression']
INFO Successfully registered anndata object containing 400 cells, 100 vars, 2 batches, 1 labels, and 100 proteins. Also registered 0 extra categorical covariates and 0 extra continuous covariates.
"""
return _setup_anndata(
adata,
batch_key,
labels_key,
layer,
protein_expression_obsm_key,
protein_names_uns_key,
categorical_covariate_keys,
continuous_covariate_keys,
copy,
)


def _setup_anndata(
adata: anndata.AnnData,
batch_key: Optional[str] = None,
labels_key: Optional[str] = None,
layer: Optional[str] = None,
protein_expression_obsm_key: Optional[str] = None,
protein_names_uns_key: Optional[str] = None,
categorical_covariate_keys: Optional[List[str]] = None,
continuous_covariate_keys: Optional[List[str]] = None,
copy: bool = False,
) -> Optional[anndata.AnnData]:
if copy:
adata = adata.copy()

Expand Down
4 changes: 2 additions & 2 deletions scvi/data/_built_in_data/_brain_large.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import numpy as np
import scipy.sparse as sp_sparse

from scvi.data._anndata import setup_anndata
from scvi.data._anndata import _setup_anndata
from scvi.data._built_in_data._download import _download

logger = logging.getLogger(__name__)
Expand All @@ -33,7 +33,7 @@ def _load_brainlarge_dataset(
loading_batch_size=loading_batch_size,
)
if run_setup_anndata:
setup_anndata(adata, batch_key="batch", labels_key="labels")
_setup_anndata(adata, batch_key="batch", labels_key="labels")
return adata


Expand Down
8 changes: 4 additions & 4 deletions scvi/data/_built_in_data/_cite_seq.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import anndata
import pandas as pd

from scvi.data import setup_anndata
from scvi.data._anndata import _setup_anndata
from scvi.data._built_in_data._download import _download


Expand Down Expand Up @@ -65,7 +65,7 @@ def _load_pbmcs_10x_cite_seq(
dataset.obsm["protein_expression"] = dataset.obsm["protein_expression"].fillna(0)

if run_setup_anndata:
setup_anndata(
_setup_anndata(
dataset,
batch_key="batch",
protein_expression_obsm_key="protein_expression",
Expand Down Expand Up @@ -94,7 +94,7 @@ def _load_spleen_lymph_cite_seq(
remove_outliers
Whether to remove clusters annotated as doublet or low quality
run_setup_anndata
If true, runs setup_anndata() on dataset before returning
If true, runs _setup_anndata() on dataset before returning
Returns
-------
Expand Down Expand Up @@ -135,7 +135,7 @@ def _load_spleen_lymph_cite_seq(
dataset = dataset[include_cells].copy()

if run_setup_anndata:
setup_anndata(
_setup_anndata(
dataset,
batch_key="batch",
labels_key="cell_types",
Expand Down
4 changes: 2 additions & 2 deletions scvi/data/_built_in_data/_cortex.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import numpy as np
import pandas as pd

from scvi.data._anndata import setup_anndata
from scvi.data._anndata import _setup_anndata
from scvi.data._built_in_data._download import _download

logger = logging.getLogger(__name__)
Expand All @@ -22,7 +22,7 @@ def _load_cortex(
_download(url, save_path, save_fn)
adata = _load_cortex_txt(os.path.join(save_path, save_fn))
if run_setup_anndata:
setup_anndata(adata, labels_key="labels")
_setup_anndata(adata, labels_key="labels")
return adata


Expand Down
6 changes: 3 additions & 3 deletions scvi/data/_built_in_data/_csv.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import anndata
import numpy as np

from scvi.data._anndata import setup_anndata
from scvi.data._anndata import _setup_anndata
from scvi.data._built_in_data._download import _download

logger = logging.getLogger(__name__)
Expand All @@ -24,7 +24,7 @@ def _load_breast_cancer_dataset(
adata.obs["labels"] = np.zeros(adata.shape[0]).astype(int)

if run_setup_anndata:
setup_anndata(adata, batch_key="batch", labels_key="labels")
_setup_anndata(adata, batch_key="batch", labels_key="labels")
return adata


Expand All @@ -40,7 +40,7 @@ def _load_mouse_ob_dataset(save_path: str = "data/", run_setup_anndata: bool = T
adata.obs["labels"] = np.zeros(adata.shape[0]).astype(int)

if run_setup_anndata:
setup_anndata(adata, batch_key="batch", labels_key="labels")
_setup_anndata(adata, batch_key="batch", labels_key="labels")

return adata

Expand Down
4 changes: 2 additions & 2 deletions scvi/data/_built_in_data/_heartcellatlas.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import anndata

from scvi.data import setup_anndata
from scvi.data._anndata import _setup_anndata
from scvi.data._built_in_data._download import _download


Expand Down Expand Up @@ -50,7 +50,7 @@ def _load_heart_cell_atlas_subsampled(
dataset = dataset[keep, :].copy()

if run_setup_anndata:
setup_anndata(
_setup_anndata(
dataset,
)

Expand Down
10 changes: 5 additions & 5 deletions scvi/data/_built_in_data/_loom.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import pandas as pd
from anndata import AnnData

from scvi.data import setup_anndata
from scvi.data._anndata import _setup_anndata
from scvi.data._built_in_data._download import _download

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -50,7 +50,7 @@ def _load_retina(save_path: str = "data/", run_setup_anndata: bool = True) -> An
adata.obs["batch"] = pd.Categorical(adata.obs["BatchID"].values.copy())
del adata.obs["BatchID"]
if run_setup_anndata:
setup_anndata(adata, batch_key="batch", labels_key="labels")
_setup_anndata(adata, batch_key="batch", labels_key="labels")

return adata

Expand All @@ -75,7 +75,7 @@ def _load_prefrontalcortex_starmap(
adata.obs["x_coord"] = adata.obsm["Spatial_coordinates"][:, 0]
adata.obs["y_coord"] = adata.obsm["Spatial_coordinates"][:, 1]
if run_setup_anndata:
setup_anndata(adata, batch_key="batch", labels_key="labels")
_setup_anndata(adata, batch_key="batch", labels_key="labels")
return adata


Expand All @@ -96,7 +96,7 @@ def _load_frontalcortex_dropseq(
# self.reorder_cell_types(self.cell_types[order_labels])

if run_setup_anndata:
setup_anndata(adata, batch_key="batch", labels_key="labels")
_setup_anndata(adata, batch_key="batch", labels_key="labels")

return adata

Expand Down Expand Up @@ -126,7 +126,7 @@ def _load_annotation_simulation(
del adata.obs["BatchID"]

if run_setup_anndata:
setup_anndata(adata, batch_key="batch", labels_key="labels")
_setup_anndata(adata, batch_key="batch", labels_key="labels")

return adata

Expand Down
6 changes: 3 additions & 3 deletions scvi/data/_built_in_data/_pbmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import numpy as np
import pandas as pd

from scvi.data import setup_anndata
from scvi.data._anndata import _setup_anndata
from scvi.data._built_in_data._dataset_10x import _load_dataset_10x
from scvi.data._built_in_data._download import _download

Expand Down Expand Up @@ -44,7 +44,7 @@ def _load_purified_pbmc_dataset(
adata = adata[row_indices].copy()

if run_setup_anndata:
setup_anndata(adata, batch_key="batch", labels_key="labels")
_setup_anndata(adata, batch_key="batch", labels_key="labels")

return adata

Expand Down Expand Up @@ -126,5 +126,5 @@ def _load_pbmc_dataset(
adata.var["n_counts"] = np.squeeze(np.asarray(np.sum(adata.X, axis=0)))

if run_setup_anndata:
setup_anndata(adata, batch_key="batch", labels_key="labels")
_setup_anndata(adata, batch_key="batch", labels_key="labels")
return adata
6 changes: 3 additions & 3 deletions scvi/data/_built_in_data/_seqfish.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import numpy as np
import pandas as pd

from scvi.data import setup_anndata
from scvi.data._anndata import _setup_anndata
from scvi.data._built_in_data._download import _download

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -41,7 +41,7 @@ def _load_seqfishplus(
adata.obs["labels"] = np.zeros(adata.shape[0], dtype=np.int64)

if run_setup_anndata:
setup_anndata(adata, batch_key="batch", labels_key="labels")
_setup_anndata(adata, batch_key="batch", labels_key="labels")
return adata


Expand Down Expand Up @@ -81,7 +81,7 @@ def _load_seqfish(
adata.obs["batch"] = np.zeros(adata.shape[0], dtype=np.int64)
adata.obs["labels"] = np.zeros(adata.shape[0], dtype=np.int64)
if run_setup_anndata:
setup_anndata(adata, batch_key="batch", labels_key="labels")
_setup_anndata(adata, batch_key="batch", labels_key="labels")
return adata


Expand Down
Loading

0 comments on commit 9855238

Please sign in to comment.