diff --git a/.gitignore b/.gitignore index 8c11624e..053dabda 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,6 @@ +build/ +scratch/ + **.h5ad *test_output* diff --git a/.readthedocs.yaml b/.readthedocs.yaml new file mode 100644 index 00000000..9aae0277 --- /dev/null +++ b/.readthedocs.yaml @@ -0,0 +1,31 @@ +# .readthedocs.yaml +# Read the Docs configuration file +# See https://docs.readthedocs.io/en/stable/config-file/v2.html for details + +# Required +version: 2 + +# Set the version of Python and other tools you might need +build: + os: ubuntu-20.04 + tools: + python: "3.8" + apt_packages: + - r-base + +# Build documentation in the docs/ directory with Sphinx +sphinx: + configuration: docs/source/conf.py + fail_on_warning: false + +# If using Sphinx, optionally build your docs in additional formats such as PDF +# formats: +# - pdf + +# Optionally declare the Python requirements required to build your docs +python: + install: + - method: pip + path: . + extra_requirements: + - docs diff --git a/README.md b/README.md index cc4c94e1..e766747d 100644 --- a/README.md +++ b/README.md @@ -1,156 +1,93 @@ # Benchmarking atlas-level data integration in single-cell genomics -This repository contains the code for our benchmarking study for data integration tools. -In [our study](https://www.biorxiv.org/content/10.1101/2020.05.22.111161v1), we benchmark 16 -methods ([see here](##Tools)) with 4 combinations of preprocessing steps leading to 68 methods combinations on 85 -batches of gene expression and chromatin accessibility data. +This repository contains the code for the `scib` package used in our benchmarking study for data integration tools. +In [our study](https://doi.org/10.1038/s41592-021-01336-8), we benchmark 16 methods (see Tools) with 4 combinations of +preprocessing steps leading to 68 methods combinations on 85 batches of gene expression and chromatin accessibility data. 
-![Workflow](./figure.png) +![Workflow](https://raw.githubusercontent.com/theislab/scib/main/figure.png) ## Resources -+ On our [website](https://theislab.github.io/scib-reproducibility) we visualise the results of the study. - ++ The git repository of the [`scib` package](https://github.com/theislab/scib) and its [documentation](https://scib.readthedocs.io/). + The reusable pipeline we used in the study can be found in the separate [scib pipeline](https://github.com/theislab/scib-pipeline.git) repository. It is reproducible and automates the computation of preprocessing combinations, integration methods and benchmarking metrics. - ++ On our [website](https://theislab.github.io/scib-reproducibility) we visualise the results of the study. + For reproducibility and visualisation we have a dedicated repository: [scib-reproducibility](https://github.com/theislab/scib-reproducibility). ### Please cite: -**Benchmarking atlas-level data integration in single-cell genomics.** -MD Luecken, M Büttner, K Chaichoompu, A Danese, M Interlandi, MF Mueller, DC Strobl, L Zappia, M Dugas, M Colomé-Tatché, -FJ Theis bioRxiv 2020.05.22.111161; doi: https://doi.org/10.1101/2020.05.22.111161_ - -## Package: `scib` - -We created the python package called `scib` that uses `scanpy` to streamline the integration of single-cell datasets and -evaluate the results. For evaluating the integration quality it provides a number of metrics. - -### Requirements - -+ Linux or UNIX system -+ Python >= 3.7 -+ 3.6 <= R <= 4.0 - -We recommend working with environments such as Conda or virtualenv, so that python and R dependencies are in one place. -Please also check out [scib pipeline](https://github.com/theislab/scib-pipeline.git) for ready-to-use environments. -Alternatively, manually install the package on your system using pip, described in the next section. - -### Installation - -The `scib` python package is in the folder scib. 
You can simply install it from the root of this repository using - -``` -pip install . -``` - -Alternatively, you can also install the package directly from GitHub via - -``` -pip install git+https://github.com/theislab/scib.git -``` - -Additionally, in order to run the R package `kBET`, you need to install it through R. - -```R -devtools::install_github('theislab/kBET') -``` - -> **Note:** By default dependencies for integration methods are not installed due to dependency clashes. -> In order to use integration methods, see the next section - -### Installing additional packages - -This package contains code for running integration methods as well as for evaluating their output. However, due to -dependency clashes, `scib` is only installed with the packages needed for the metrics. In order to use the integration -wrapper functions, we recommend to work with different environments for different methods, each with their own -installation of `scib`. You can install optional Python dependencies via pip as follows: - -``` -pip install .[bbknn] # using BBKNN -pip install .[scanorama] # using Scanorama -pip install .[bbknn,scanorama] # Multiple methods in one go -``` +Luecken, M.D., Büttner, M., Chaichoompu, K. et al. Benchmarking atlas-level data integration in single-cell genomics. +Nat Methods 19, 41–50 (2022). [https://doi.org/10.1038/s41592-021-01336-8](https://doi.org/10.1038/s41592-021-01336-8) -The `setup.cfg` for a full list of Python dependencies. For a comprehensive list of supported integration methods, -including R packages, check out the `Tools`. +## Package: scib -## Usage +We created the python package called `scib` that uses `scanpy` to streamline the integration of single-cell datasets +and evaluate the results. +The package contains several modules for preprocessing an ``anndata`` object, running integration methods and +evaluating the results using a number of metrics. 
+For preprocessing, ``scib.preprocessing`` (or ``scib.pp``) contains functions for normalising, scaling or batch-aware +selection of highly variable genes. +Functions for the integration methods are in ``scib.integration`` or for short ``scib.ig`` and metrics are under +``scib.metrics`` (or ``scib.me``). -The package contains several modules for the different steps of the integration and benchmarking pipeline. Functions for -the integration methods are in `scib.integration` or for short `scib.ig`. The methods can be called using +The `scib` python package is available on [PyPI](https://pypi.org/) and can be installed through -```py -scib.integration.(adata, batch=) ``` - -where `` is the name of the integration method and `` is the name of the batch column in `adata.obs`. -For example, in order to run Scanorama, on a dataset with batch key 'batch' call - -```py -scib.integration.scanorama(adata, batch='batch') +pip install scib ``` -> **Warning:** the following notation is deprecated. -> ``` -> scib.integration.run(adata, batch=) -> ``` -> Please use the snake case naming without the `run` prefix. -Some integration methods (`scgen`, `scanvi`) also use cell type labels as input. For these, you need to additionally provide -the corresponding label column. +Import `scib` in python: -```py -scgen(adata, batch=, cell_type=) -scanvi(adata, batch=, labels=) +```python +import scib ``` -`scib.preprocessing` (or `scib.pp`) contains functions for normalising, scaling or selecting highly variable genes per batch -The metrics are under `scib.metrics` (or `scib.me`). - ## Metrics -For a detailed description of the metrics implemented in this package, please see -the [manuscript](https://www.biorxiv.org/content/10.1101/2020.05.22.111161v2). 
- -### Batch removal metrics include: - -- Principal component regression `pcr_comparison()` -- Batch ASW `silhouette()` -- K-nearest neighbour batch effect `kBET()` -- Graph connectivity `graph_connectivity()` -- Graph iLISI `lisi_graph()` - -### Biological conservation metrics include: - -- Normalised mutual information `nmi()` -- Adjusted Rand Index `ari()` -- Cell type ASW `silhouette_batch()` -- Isolated label score F1 `isolated_labels()` -- Isolated label score ASW `isolated_labels()` -- Cell cycle conservation `cell_cycle()` -- Highly variable gene conservation `hvg_overlap()` -- Trajectory conservation `trajectory_conservation()` -- Graph cLISI `lisi_graph()` - -### Metrics Wrapper Functions -We provide wrapper functions to run multiple metrics in one function call. -The `scib.metrics.metrics()` function returns a `pandas.Dataframe` of all metrics specified as parameters. - -```py -scib.metrics.metrics(adata, adata_int, ari=True, nmi=True) -``` - -Furthermore, `scib.metrics.metrics()` is wrapped by convenience functions that only select certain metrics: - -+ `scib.me.metrics_fast()` only computes metrics that require little preprocessing -+ `scib.me.metrics_slim()` includes all functions of `scib.me.metrics_fast()` and adds clustering-based metrics -+ `scib.me.metrics_all()` includes all metrics - -## Tools +We implemented different metrics for evaluating batch correction and biological conservation in the `scib.metrics` +module. + + + + + + + + + + + + + + + + + +
+<table>
+  <thead>
+    <tr>
+      <th>Biological Conservation</th>
+      <th>Batch Correction</th>
+    </tr>
+  </thead>
+  <tbody>
+    <tr>
+      <td>
+        <ul>
+          <li>Cell type ASW</li>
+          <li>Cell cycle conservation</li>
+          <li>Graph cLISI</li>
+          <li>Adjusted rand index (ARI) for cell label</li>
+          <li>Normalised mutual information (NMI) for cell label</li>
+          <li>Highly variable gene conservation</li>
+          <li>Isolated label ASW</li>
+          <li>Isolated label F1</li>
+          <li>Trajectory conservation</li>
+        </ul>
+      </td>
+      <td>
+        <ul>
+          <li>Batch ASW</li>
+          <li>Principal component regression</li>
+          <li>Graph iLISI</li>
+          <li>Graph connectivity</li>
+          <li>kBET (K-nearest neighbour batch effect)</li>
+        </ul>
+      </td>
+    </tr>
+  </tbody>
+</table>
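To build intuition for what the cluster-agreement metrics in this list (ARI, NMI) measure, here is a toy pair-counting computation of the Adjusted Rand Index. This is an illustrative sketch only — the function name `toy_ari` is ours, and it is not scib's implementation, which operates on columns of `adata.obs`:

```python
from collections import Counter
from math import comb

def toy_ari(labels_true, labels_pred):
    """Adjusted Rand Index via pair counting (toy version for small inputs)."""
    n = len(labels_true)
    # Pairs of cells grouped together in both labellings, and in each one alone.
    both = sum(comb(c, 2) for c in Counter(zip(labels_true, labels_pred)).values())
    true = sum(comb(c, 2) for c in Counter(labels_true).values())
    pred = sum(comb(c, 2) for c in Counter(labels_pred).values())
    expected = true * pred / comb(n, 2)  # agreement expected by chance
    max_index = (true + pred) / 2
    return (both - expected) / (max_index - expected)

# Perfect agreement up to relabelling scores 1; independent labellings score ~0.
print(toy_ari(["B", "B", "T", "T"], [0, 0, 1, 1]))  # → 1.0
```

Because the index is adjusted for chance, anti-correlated labellings can even score below zero.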
+ +For a detailed description of the metrics implemented in this package, please see our +[publication](https://doi.org/10.1038/s41592-021-01336-8) and the package [documentation](https://scib.readthedocs.io/). + +## Integration Tools Tools that are compared include: @@ -169,4 +106,4 @@ Tools that are compared include: - [scVI](https://github.com/YosefLab/scVI) 0.6.7 - [Seurat v3](https://github.com/satijalab/seurat) 3.2.0 CCA (default) and RPCA - [TrVae](https://github.com/theislab/trvae) 0.0.1 -- [TrVaep](https://github.com/theislab/trvaep) 0.1.0 +- [TrVaep](https://github.com/theislab/trvaep) 0.1.0 \ No newline at end of file diff --git a/docs/Makefile b/docs/Makefile new file mode 100644 index 00000000..d0c3cbf1 --- /dev/null +++ b/docs/Makefile @@ -0,0 +1,20 @@ +# Minimal makefile for Sphinx documentation +# + +# You can set these variables from the command line, and also +# from the environment for the first two. +SPHINXOPTS ?= +SPHINXBUILD ?= sphinx-build +SOURCEDIR = source +BUILDDIR = build + +# Put it first so that "make" without argument is like "make help". +help: + @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) + +.PHONY: help Makefile + +# Catch-all target: route all unknown targets to Sphinx using the new +# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). +%: Makefile + @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) diff --git a/docs/source/.gitignore b/docs/source/.gitignore new file mode 100644 index 00000000..b8798c54 --- /dev/null +++ b/docs/source/.gitignore @@ -0,0 +1 @@ +api/ \ No newline at end of file diff --git a/docs/source/conf.py b/docs/source/conf.py new file mode 100755 index 00000000..2ab4320a --- /dev/null +++ b/docs/source/conf.py @@ -0,0 +1,64 @@ +# Configuration file for the Sphinx documentation builder. +# +# This file only contains a selection of the most common options. 
For a full +# list see the documentation: +# https://www.sphinx-doc.org/en/master/usage/configuration.html + +# -- Path setup -------------------------------------------------------------- + +# If extensions (or modules to document with autodoc) are in another directory, +# add these directories to sys.path here. If the directory is relative to the +# documentation root, use os.path.abspath to make it absolute, like shown here. +# +import os +import sys +import scib + +sys.path.insert(0, os.path.abspath('../..')) + +# -- Project information ----------------------------------------------------- + +project = 'scib' +copyright = '2021, Malte D. Luecken, Maren Buettner, Daniel C. Strobl, Michaela F. Mueller' +author = 'Malte D. Luecken, Maren Buettner, Daniel C. Strobl, Michaela F. Mueller' +github_url = 'https://github.com/theislab/scib' + +# The full version, including alpha/beta/rc tags +release = scib.__version__ + +# -- General configuration --------------------------------------------------- + +# Add any Sphinx extension module names here, as strings. They can be +# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom +# ones. +extensions = [ + 'sphinx.ext.duration', + 'sphinx.ext.doctest', + 'sphinx.ext.autodoc', + 'sphinx.ext.autosummary', + 'sphinx.ext.intersphinx', + 'sphinx_automodapi.automodapi', + 'sphinx_automodapi.smart_resolver', + 'myst_parser' +] +numpydoc_show_class_members = False + +# Add any paths that contain templates here, relative to this directory. +# templates_path = ['_templates'] + +# List of patterns, relative to source directory, that match files and +# directories to ignore when looking for source files. +# This pattern also affects html_static_path and html_extra_path. +exclude_patterns = [] + +# -- Options for HTML output ------------------------------------------------- + +# The theme to use for HTML and HTML Help pages. See the documentation for +# a list of builtin themes. 
+ +html_theme = 'sphinx_rtd_theme' + +# Add any paths that contain custom static files (such as style sheets) here, +# relative to this directory. They are copied after the builtin static files, +# so a file named "default.css" will overwrite the builtin "default.css". +# html_static_path = ['_static'] diff --git a/docs/source/index.rst b/docs/source/index.rst new file mode 100755 index 00000000..17f5dc96 --- /dev/null +++ b/docs/source/index.rst @@ -0,0 +1,27 @@ +.. scib documentation master file, created by + sphinx-quickstart on Wed Dec 1 14:50:06 2021. + You can adapt this file completely to your liking, but it should at least + contain the root `toctree` directive. + +Single-cell integration benchmark scib +====================================== + +.. toctree:: + :maxdepth: 2 + :caption: Contents: + + installation + scib_preprocessing + scib_integration + scib_metrics + +.. include:: ../../README.md + :parser: myst_parser.sphinx_ + + +Indices and tables +================== + +* :ref:`genindex` +* :ref:`modindex` +* :ref:`search` diff --git a/docs/source/installation.rst b/docs/source/installation.rst new file mode 100644 index 00000000..3fa4cffd --- /dev/null +++ b/docs/source/installation.rst @@ -0,0 +1,70 @@ +Installation +============ + +We recommend working with environments such as Conda or virtualenv, so that python and R dependencies are in one place. +Please also check out `scib-pipeline `_ for ready-to-use environments and +an end-to-end workflow. + +Requirements +------------ + ++ Linux or UNIX system ++ Python >= 3.7 ++ 3.6 <= R <= 4.0 + + +Installation with pip +--------------------- + +The ``scib`` python package is available on `PyPI `_ and can be installed through + +.. code-block:: bash + + pip install scib + + +Alternatively, you can also install the package directly from GitHub via + +.. 
code-block:: bash + + pip install git+https://github.com/theislab/scib.git + + +Additionally, in order to run the R package ``kBET``, you need to install it through R. + +.. code-block:: R + + devtools::install_github('theislab/kBET') + + +.. note:: + + By default dependencies for integration methods are not installed due to dependency clashes. + In order to use integration methods, see + + +Installing additional packages +------------------------------ + +This package contains code for running integration methods as well as for evaluating their output. However, due to +dependency clashes, ``scib`` is only installed with the packages needed for the metrics. In order to use the integration +wrapper functions, we recommend working with different environments for different methods, each with their own +installation of ``scib``. You can install optional Python dependencies via pip as follows: + +.. code-block:: bash + + pip install scib[bbknn] # using BBKNN + pip install scib[scanorama] # using Scanorama + pip install scib[bbknn,scanorama] # Multiple methods in one go + +.. note:: + + Zsh often doesn't like square brackets. If you are a zsh user, use quotation marks around any statements containing + square brackets. For example: + + .. code-block:: bash + + pip install 'scib[bbknn]' + + +See the ``setup.cfg`` of the source code for a full list of Python dependencies. diff --git a/docs/source/scib_integration.rst b/docs/source/scib_integration.rst new file mode 100644 index 00000000..1631495e --- /dev/null +++ b/docs/source/scib_integration.rst @@ -0,0 +1,54 @@ +Integration +=========== + +Integration method functions require the preprocessed ``anndata`` object (here ``adata``) and the name of the batch column +in ``adata.obs`` (here ``'batch'``). +The methods can be called using the following, where ``<method>`` is the name of the integration method. + +.. code-block:: python + + scib.ig.<method>(adata, batch='batch') + + +For example, in order to run Scanorama on a dataset, call: + +.. 
code-block:: python + + scib.ig.scanorama(adata, batch='batch') + +.. warning:: + + The following notation is deprecated. + + .. code-block:: python + + scib.integration.run(adata, batch='batch') + + Please use the snake_case naming without the ``run`` prefix. + +Some integration methods (e.g. :func:`~scib.integration.scgen`, :func:`~scib.integration.scanvi`) also use cell type +labels as input. +For these, you need to additionally provide the corresponding label column of ``adata.obs`` (here ``cell_type``). + +.. code-block:: python + + scib.ig.scgen(adata, batch='batch', cell_type ='cell_type') + scib.ig.scanvi(adata, batch='batch', labels ='cell_type') + + +.. automodapi:: scib.integration + + :no-heading: + + :skip: runBBKNN + :skip: runCombat + :skip: runMNN + :skip: runDESC + :skip: runSaucie + :skip: runScanorama + :skip: runScanvi + :skip: runScGen + :skip: runScvi + :skip: runTrVae + :skip: runTrVaep + :skip: issparse diff --git a/docs/source/scib_metrics.rst b/docs/source/scib_metrics.rst new file mode 100644 index 00000000..8c4fbafd --- /dev/null +++ b/docs/source/scib_metrics.rst @@ -0,0 +1,90 @@ +Metrics +======= + +.. currentmodule:: scib.metrics + +This package contains all the metrics used for benchmarking scRNA-seq data integration performance. +The metrics can be classified into biological conservation and batch removal metrics. +For a detailed description of the metrics implemented in this package, please see our `publication`_. + +.. _publication: https://doi.org/10.1038/s41592-021-01336-8 + + +Biological Conservation Metrics +------------------------------- + +Biological conservation metrics quantify either the integrity of cluster-based metrics based on clustering results of +the integration output, or the difference in the feature spaces of integrated and unintegrated data. +Each metric is scaled to a value ranging from 0 to 1 by default, where larger scores represent better conservation of +the biological aspect that the metric addresses. 
+ +.. autosummary:: + :toctree: api/ + + hvg_overlap + silhouette + isolated_labels + nmi + ari + cell_cycle + trajectory_conservation + clisi_graph + + +Batch Correction Metrics +------------------------ + +Batch correction metric values are scaled by default between 0 and 1, in which larger scores represent better batch +removal. + +.. autosummary:: + :toctree: api/ + + graph_connectivity + silhouette_batch + pcr_comparison + kBET + ilisi_graph + + +Metrics Wrapper Functions +------------------------- + +For convenience, ``scib`` provides wrapper functions that, given integrated and unintegrated adata objects, apply +multiple metrics and return all the results in a ``pandas.DataFrame``. +The main function is :func:`~scib.metrics.metrics`, which provides all the parameters for the different metrics. + +.. code-block:: python + + scib.metrics.metrics(adata, adata_int, ari=True, nmi=True) + + +Furthermore, :func:`~scib.metrics.metrics()` is wrapped by convenience functions with preconfigured subsets of metrics +based on expected computation time: + ++ :func:`~scib.metrics.metrics_fast()` only computes metrics that require little preprocessing ++ :func:`~scib.metrics.metrics_slim()` includes all functions of :func:`~scib.metrics.metrics_fast()` and adds clustering-based metrics ++ :func:`~scib.metrics.metrics_all()` includes all metrics + + +.. raw:: html + 
+
+   <h2>Auxiliary Functions</h2>
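As an illustration of the principal component regression idea behind ``pcr`` and ``pc_regression``, the following toy numpy sketch scores how much of the overall PC variance a numeric covariate explains. The name ``toy_pc_regression`` and the simplifications are ours — scib's implementation additionally supports sparse matrices and categorical covariates:

```python
import numpy as np

def toy_pc_regression(X, covariate, n_pcs=10):
    """Variance-weighted R^2 of a numeric covariate against principal components."""
    X = np.asarray(X, dtype=float)
    c = np.asarray(covariate, dtype=float)
    Xc = X - X.mean(axis=0)                        # centre each feature
    U, s, _ = np.linalg.svd(Xc, full_matrices=False)
    n_pcs = min(n_pcs, s.size)
    pcs = U[:, :n_pcs] * s[:n_pcs]                 # PC scores per cell
    var = s[:n_pcs] ** 2                           # variance carried by each PC
    r2 = np.zeros(n_pcs)
    for i in range(n_pcs):
        if pcs[:, i].std() > 0 and c.std() > 0:    # skip degenerate columns
            r2[i] = np.corrcoef(pcs[:, i], c)[0, 1] ** 2
    # Average R^2 over PCs, weighted by the variance each PC explains.
    return float(np.sum(r2 * var) / np.sum(var))

# A single feature fully driven by a "batch" covariate yields a score of 1.
X = np.array([[0.0], [0.0], [1.0], [1.0]])
score = toy_pc_regression(X, [0, 0, 1, 1])
```

A covariate orthogonal to the data's variation scores close to 0, which is why ``pcr_comparison`` can contrast the batch effect before and after integration.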
+ +.. autosummary:: + :toctree: api/ + + lisi_graph + pcr + pc_regression diff --git a/docs/source/scib_preprocessing.rst b/docs/source/scib_preprocessing.rst new file mode 100644 index 00000000..b990c1f8 --- /dev/null +++ b/docs/source/scib_preprocessing.rst @@ -0,0 +1,22 @@ +Preprocessing +============= + +This module contains helper functions for preparing anndata objects for integration. +The most relevant preprocessing steps are: + ++ normalization ++ scaling, batch-aware ++ highly variable gene selection, batch-aware ++ cell cycle scoring + +.. automodapi:: scib.preprocessing + + :no-heading: + :no-main-docstr: + + :skip: plot_count_filter + :skip: plot_scatter + :skip: readConos + :skip: readSeurat + :skip: saveSeurat + diff --git a/scib/__init__.py b/scib/__init__.py index 17c4aa2d..83bac6b6 100644 --- a/scib/__init__.py +++ b/scib/__init__.py @@ -22,8 +22,9 @@ 'runSaucie': integration.saucie, 'runCombat': integration.combat, 'runDESC': integration.desc, - 'readSeurat': preprocessing.read_seurat, 'readConos': preprocessing.read_conos, + 'readSeurat': preprocessing.read_seurat, + 'saveSeurat': preprocessing.save_seurat, } for alias, func in alias_func_map.items(): diff --git a/scib/integration.py b/scib/integration.py index c2e95ae8..d287a169 100644 --- a/scib/integration.py +++ b/scib/integration.py @@ -1,8 +1,3 @@ -""" -This module provides a toolkit for running a large range of single cell data integration -methods as well as tools and metrics to benchmark them. -""" - import logging import os import tempfile @@ -21,6 +16,16 @@ def scanorama(adata, batch, hvg=None, **kwargs): + """ Scanorama wrapper function + + Based on `scanorama `_ version 1.7.0 + + :param adata: preprocessed ``anndata`` object + :param batch: batch key in ``adata.obs`` + :param hvg: list of highly variables to subset to. 
If ``None``, the full dataset will be used + :return: ``anndata`` object containing the corrected feature matrix as well as an embedding representation of the + corrected data + """ try: import scanorama except ModuleNotFoundError as e: @@ -39,6 +44,16 @@ def scanorama(adata, batch, hvg=None, **kwargs): def trvae(adata, batch, hvg=None): + """trVAE wrapper function + + Based on `trVAE `_ version 1.1.2 + + :param adata: preprocessed ``anndata`` object + :param batch: batch key in ``adata.obs`` + :param hvg: list of highly variable genes to subset to. If ``None``, the full dataset will be used + :return: ``anndata`` object containing the corrected feature matrix as well as an embedding representation of the + corrected data + """ try: import trvae except ModuleNotFoundError as e: @@ -85,6 +100,18 @@ def trvae(adata, batch, hvg=None): def trvaep(adata, batch, hvg=None): + """trVAE wrapper function (``pytorch`` implementation) + + Based on `trvaep`_ version 0.1.0 + + .. _trvaep: https://github.com/theislab/trvaep + + :param adata: preprocessed ``anndata`` object + :param batch: batch key in ``adata.obs`` + :param hvg: list of highly variable genes to subset to. If ``None``, the full dataset will be used + :return: ``anndata`` object containing the corrected feature matrix as well as an embedding representation of the + corrected data + """ try: import trvaep except ModuleNotFoundError as e: @@ -126,9 +153,17 @@ def trvaep(adata, batch, hvg=None): def scgen(adata, batch, cell_type, epochs=100, hvg=None, model_path=None, **kwargs): - """ - Parametrization taken from the tutorial notebook at: - https://nbviewer.jupyter.org/github/M0hammadL/scGen_notebooks/blob/master/notebooks/scgen_batch_removal.ipynb + """scGen wrapper function + + Based on `scgen`_ version 1.1.5 with parametrization taken from the tutorial `notebook`_. + + .. _scgen: https://github.com/theislab/scgen + .. 
_notebook: https://nbviewer.jupyter.org/github/M0hammadL/scGen_notebooks/blob/master/notebooks/scgen_batch_removal.ipynb + + :param adata: preprocessed ``anndata`` object + :param batch: batch key in ``adata.obs`` + :param hvg: list of highly variable genes to subset to. If ``None``, the full dataset will be used + :return: ``anndata`` object containing the corrected feature matrix + """ try: import scgen @@ -165,8 +200,19 @@ def scgen(adata, batch, cell_type, epochs=100, hvg=None, model_path=None, **kwar def scvi(adata, batch, hvg=None): - # Use non-normalized (count) data for scvi! - # Expects data only on HVGs + """scVI wrapper function + + Based on scVI version 0.6.7 (available through `conda `_) + + .. note:: + scVI expects only non-normalized (count) data on highly variable genes! + + :param adata: preprocessed ``anndata`` object + :param batch: batch key in ``adata.obs`` + :param hvg: list of highly variable genes to subset to. If ``None``, the full dataset will be used + :return: ``anndata`` object containing the corrected feature matrix as well as an embedding representation of the + corrected data + """ try: from scvi.dataset import AnnDatasetFromAnnData from scvi.inference import UnsupervisedTrainer @@ -226,7 +272,19 @@ def scvi(adata, batch, hvg=None): def scanvi(adata, batch, labels): - # Use non-normalized (count) data for scanvi! + """scANVI wrapper function + + Based on scVI version 0.6.7 (available through `conda `_) + + .. note:: + Use non-normalized (count) data for scANVI! 
+ + :param adata: preprocessed ``anndata`` object + :param batch: batch key in ``adata.obs`` + :param labels: label key in ``adata.obs`` + :return: ``anndata`` object containing the corrected feature matrix as well as an embedding representation of the + corrected data + """ try: from scvi.dataset import AnnDatasetFromAnnData from scvi.inference import SemiSupervisedTrainer, UnsupervisedTrainer @@ -314,6 +372,19 @@ def scanvi(adata, batch, labels): def mnn(adata, batch, hvg=None, **kwargs): + """MNN wrapper function (``mnnpy`` implementation) + + Based on `mnnpy package `_ version 0.1.9.5 + + .. note:: + + ``mnnpy`` might break with newer versions of ``numpy`` and ``pandas`` + + :param adata: preprocessed ``anndata`` object + :param batch: batch key in ``adata.obs`` + :param hvg: list of highly variable genes to subset to. If ``None``, the full dataset will be used + :return: ``anndata`` object containing the corrected feature matrix + """ try: import mnnpy except ModuleNotFoundError as e: @@ -335,6 +406,16 @@ def mnn(adata, batch, hvg=None, **kwargs): def bbknn(adata, batch, hvg=None, **kwargs): + """BBKNN wrapper function + + Based on `bbknn package `_ version 1.3.9 + + :param adata: preprocessed ``anndata`` object + :param batch: batch key in ``adata.obs`` + :param hvg: list of highly variable genes to subset to. If ``None``, the full dataset will be used + :param \\**kwargs: additional parameters for BBKNN + :return: ``anndata`` object containing the corrected graph + """ try: import bbknn except ModuleNotFoundError as e: @@ -356,8 +437,14 @@ def bbknn(adata, batch, hvg=None, **kwargs): def saucie(adata, batch): - """ - parametrisation from https://github.com/KrishnaswamyLab/SAUCIE/blob/master/scripts/SAUCIE.py + """SAUCIE wrapper function + + Using SAUCIE `source code `_. 
+ Parametrisation from https://github.com/KrishnaswamyLab/SAUCIE/blob/master/scripts/SAUCIE.py + + :param adata: preprocessed ``anndata`` object + :param batch: batch key in ``adata.obs`` + :return: ``anndata`` object containing the corrected embedding """ try: import SAUCIE @@ -383,16 +470,29 @@ def saucie(adata, batch): def combat(adata, batch): + """ComBat wrapper function (``scanpy`` implementation) + + Using scanpy implementation of `Combat `_ + + :param adata: preprocessed ``anndata`` object + :param batch: batch key in ``adata.obs`` + :return: ``anndata`` object containing the corrected feature matrix + """ adata_int = adata.copy() sc.pp.combat(adata_int, key=batch) return adata_int def desc(adata, batch, res=0.8, ncores=None, tmp_dir=None, use_gpu=False): - """ - Convenience function to run DESC. Parametrization was taken from: - https://github.com/eleozzr/desc/issues/28 - as suggested by the developer (rather than from the tutorial notebook). + """DESC wrapper function + + Based on `desc package `_ version 2.0.3. + Parametrization was taken from: https://github.com/eleozzr/desc/issues/28 as suggested by the developer (rather + than from the tutorial notebook). 
+ + :param adata: preprocessed ``anndata`` object + :param batch: batch key in ``adata.obs`` + :return: ``anndata`` object containing the corrected embedding """ try: import desc diff --git a/scib/metrics/__init__.py b/scib/metrics/__init__.py index 355506b2..1d8ec799 100644 --- a/scib/metrics/__init__.py +++ b/scib/metrics/__init__.py @@ -1,16 +1,3 @@ -__all__ = [ - "ari", - "cell_cycle", - "graph_connectivity", - "highly_variable_genes", - "isolated_labels", - "kbet", - "lisi", - "metrics", - "nmi", - "pcr", - "silhouette", - "trajectory" -] - from .metrics import * +from .lisi import lisi_graph +from .pcr import pc_regression, pcr diff --git a/scib/metrics/ari.py b/scib/metrics/ari.py index ab517ed4..61fbd546 100644 --- a/scib/metrics/ari.py +++ b/scib/metrics/ari.py @@ -7,15 +7,16 @@ def ari(adata, group1, group2, implementation=None): - """ Adjusted Rand Index + """Adjusted Rand Index + The function is symmetric, so group1 and group2 can be switched - For single cell integration evaluation the scenario is: - predicted cluster assignments vs. ground-truth (e.g. cell type) assignments + For single cell integration evaluation the comparison is between predicted cluster + assignments and the ground truth (e.g. 
cell type) :param adata: anndata object :param group1: string of column in adata.obs containing labels :param group2: string of column in adata.obs containing labels - :params implementation: of set to 'sklearn', uses sklearns implementation, + :param implementation: if set to 'sklearn', uses sklearn's implementation, otherwise native implementation is taken """ diff --git a/scib/metrics/cell_cycle.py b/scib/metrics/cell_cycle.py index 63da42da..02cfb9e8 100644 --- a/scib/metrics/cell_cycle.py +++ b/scib/metrics/cell_cycle.py @@ -18,23 +18,26 @@ def cell_cycle( recompute_cc=True, precompute_pcr_key=None ): - """ - Cell cycle score based on principle component regression + """Cell cycle conservation score Compare the variance contribution of S-phase and G2/M-phase cell cycle scores before and after integration. Cell cycle scores are computed per batch on the unintegrated data set, - eliminating the batch effect confounded by the `batch_key` variable. + eliminating the batch effect confounded by the ``batch_key`` variable. + + .. math:: + + CC \\, conservation = 1 - \\frac { |Var_{after} - Var_{before}| } {Var_{before}} - This score can be calculated on full corrected feature spaces and latent embeddings as - variance contributions of a fixed score can be obtained via PC regression here. + Variance contribution is obtained through principal component regression using :func:`~scib.metrics.pc_regression`. + The score can be computed on full corrected feature spaces and latent embeddings. :param adata_pre: adata before integration :param adata_post: adata after integration :param embed: Name of embedding in adata_post.obsm. 
- If `embed=None`, use the full expression matrix (`adata.X`), otherwise use the - embedding provided in `adata_post.obsm[embed]` + If ``embed=None``, use the full expression matrix (``adata.X``), otherwise use the + embedding provided in ``adata_post.obsm[embed]`` :param agg_func: any function that takes a list of numbers and aggregates them into - a single value. If `agg_func=None`, all results will be returned + a single value. If ``agg_func=None``, all results will be returned :param organism: 'mouse' or 'human' for choosing cell cycle genes :param recompute_cc: If True, force recompute cell cycle score, otherwise use precomputed scores if available as 'S_score' and 'G2M_score' in adata.obs @@ -125,8 +128,8 @@ def get_pcr_before_after( :param adata_pre: adata before integration :param adata_post: adata after integration :param embed: Name of embedding in adata_post.obsm. - If `embed=None`, use the full expression matrix (`adata.X`), otherwise use the - embedding provided in `adata_post.obsm[embed]` + If ``embed=None``, use the full expression matrix (``adata.X``), otherwise use the + embedding provided in ``adata_post.obsm[embed]`` :param organism: 'mouse' or 'human' for choosing cell cycle genes :param recompute_cc: If True, force recompute cell cycle score, otherwise use precomputed scores if available as 'S_score' and 'G2M_score' in adata.obs diff --git a/scib/metrics/clustering.py b/scib/metrics/clustering.py index e736ce8b..c6acb3ed 100644 --- a/scib/metrics/clustering.py +++ b/scib/metrics/clustering.py @@ -6,27 +6,41 @@ from .nmi import nmi -def opt_louvain(adata, label_key, cluster_key, function=None, resolutions=None, - use_rep=None, - inplace=True, plot=False, force=True, verbose=True, **kwargs): - """ - params: - label_key: name of column in adata.obs containing biological labels to be - optimised against - cluster_key: name of column to be added to adata.obs during clustering. 
- Will be overwritten if exists and `force=True` - function: function that computes the cost to be optimised over. Must take as - arguments (adata, group1, group2, **kwargs) and returns a number for maximising - resolutions: list if resolutions to be optimised over. If `resolutions=None`, - default resolutions of 20 values ranging between 0.1 and 2 will be used - use_rep: key of embedding to use only if adata.uns['neighbors'] is not defined, - otherwise will be ignored - returns: - res_max: resolution of maximum score - score_max: maximum score - score_all: `pd.DataFrame` containing all scores at resolutions. Can be used to plot the score profile. - clustering: only if `inplace=False`, return cluster assignment as `pd.Series` - plot: if `plot=True` plot the score profile over resolution +def opt_louvain( + adata, + label_key, + cluster_key, + function=None, + resolutions=None, + use_rep=None, + inplace=True, + plot=False, + force=True, + verbose=True, + **kwargs +): + """Optimised Louvain clustering + + Louvain clustering with resolution optimised against a metric + + :param adata: anndata object + :param label_key: name of column in adata.obs containing biological labels to be + optimised against + :param cluster_key: name of column to be added to adata.obs during clustering. + Will be overwritten if exists and ``force=True`` + :param function: function that computes the cost to be optimised over. Must take as + arguments ``(adata, group1, group2, **kwargs)`` and returns a number for maximising + :param resolutions: list of resolutions to be optimised over. If ``resolutions=None``, + default resolutions of 20 values ranging between 0.1 and 2 will be used + :param use_rep: key of embedding to use only if ``adata.uns['neighbors']`` is not + defined, otherwise will be ignored + :returns: + Tuple of ``(res_max, score_max, score_all)`` or + ``(res_max, score_max, score_all, clustering)`` if ``inplace=False``. 
+ ``res_max``: resolution of maximum score; + ``score_max``: maximum score; + ``score_all``: ``pd.DataFrame`` containing all scores at resolutions. Can be used to plot the score profile. + ``clustering``: only if ``inplace=False``, return cluster assignment as ``pd.Series`` """ if verbose: diff --git a/scib/metrics/graph_connectivity.py b/scib/metrics/graph_connectivity.py index eefea999..8aee881e 100644 --- a/scib/metrics/graph_connectivity.py +++ b/scib/metrics/graph_connectivity.py @@ -4,10 +4,17 @@ def graph_connectivity(adata, label_key): - """" - Quantify how connected the subgraph corresponding to each batch cluster is. - Calculate per label: #cells_in_largest_connected_component/#all_cells - Final score: Average over labels + """Graph Connectivity + + Quantify the connectivity of the subgraph per cell type label. + The final score is the average for all cell type labels :math:`C`, according to the equation: + + .. math:: + + GC = \\frac {1} {|C|} \\sum_{c \\in C} \\frac {|{LCC(subgraph_c)}|} {|c|} + + where :math:`|LCC(subgraph_c)|` stands for all cells in the largest connected component and :math:`|c|` stands for all cells of + cell type :math:`c`. :param adata: adata with computed neighborhood graph :param label_key: name in adata.obs containing the cell identity labels diff --git a/scib/metrics/highly_variable_genes.py b/scib/metrics/highly_variable_genes.py index a257928c..86668413 100644 --- a/scib/metrics/highly_variable_genes.py +++ b/scib/metrics/highly_variable_genes.py @@ -34,7 +34,8 @@ def precompute_hvg_batch(adata, batch, features, n_hvg=500, save_hvg=False): def hvg_overlap(adata_pre, adata_post, batch, n_hvg=500, verbose=False): - """ + """Highly variable gene overlap + Metric that computes the average percentage of overlapping highly variable genes per batch pre post integration. 
@@ -50,7 +51,7 @@ def hvg_overlap(adata_pre, adata_post, batch, n_hvg=500, verbose=False): adata_post_list = split_batches(adata_post, batch) overlap = [] - hvg_pre_list = precompute_hvg_batch(adata_pre, batch, hvg_post) + hvg_pre_list = precompute_hvg_batch(adata_pre, batch, hvg_post, n_hvg=n_hvg) for ad_post in adata_post_list: # range(len(adata_pre_list)): # remove genes unexpressed (otherwise hvg might break) diff --git a/scib/metrics/isolated_labels.py b/scib/metrics/isolated_labels.py index daf2fe15..db324e42 100644 --- a/scib/metrics/isolated_labels.py +++ b/scib/metrics/isolated_labels.py @@ -15,8 +15,10 @@ def isolated_labels( return_all=False, verbose=True ): - """ + """Isolated label score + Score how well labels of isolated labels are distiguished in the dataset by either + 1. clustering-based approach F1 score, or 2. average-width silhouette score (ASW) on isolated label vs all other labels diff --git a/scib/metrics/kbet.py b/scib/metrics/kbet.py index 7ef77413..d0570832 100644 --- a/scib/metrics/kbet.py +++ b/scib/metrics/kbet.py @@ -15,61 +15,6 @@ rpy2.rinterface_lib.callbacks.logger.setLevel(logging.ERROR) # Ignore R warning messages -def kBET_single( - matrix, - batch, - k0=10, - knn=None, - verbose=False -): - """ - params: - matrix: expression matrix (at the moment: a PCA matrix, so do.pca is set to FALSE - batch: series or list of batch assignemnts - returns: - kBET observed rejection rate - """ - try: - ro.r("library(kBET)") - except Exception as ex: - RLibraryNotFound(ex) - - anndata2ri.activate() - - if verbose: - print("importing expression matrix") - ro.globalenv['data_mtrx'] = matrix - ro.globalenv['batch'] = batch - - if verbose: - print("kBET estimation") - - ro.globalenv['knn_graph'] = knn - ro.globalenv['k0'] = k0 - ro.r( - "batch.estimate <- kBET(" - " data_mtrx," - " batch," - " knn=knn_graph," - " k0=k0," - " plot=FALSE," - " do.pca=FALSE," - " heuristic=FALSE," - " adapt=FALSE," - f" verbose={str(verbose).upper()}" - ")" - ) - - 
try: - score = ro.r("batch.estimate$summary$kBET.observed")[0] - except rpy2.rinterface_lib.embedded.RRuntimeError: - score = np.nan - - anndata2ri.deactivate() - - return score - - def kBET( adata, batch_key, @@ -80,7 +25,11 @@ def kBET( return_df=False, verbose=False ): - """ + """kBET score + + Compute the average of k-nearest neighbour batch effect test (`kBET`_) score per label. + + .. _kBET: https://doi.org/10.1038/s41592-018-0254-1 :param adata: anndata object to compute kBET on :param batch_key: name of batch column in adata.obs @@ -88,9 +37,10 @@ def kBET( :param scaled: whether to scale between 0 and 1 with 0 meaning low batch mixing and 1 meaning optimal batch mixing if scaled=False, 0 means optimal batch mixing and 1 means low batch mixing - return: - kBET score (average of kBET per label) based on observed rejection rate - return_df=True: pd.DataFrame with kBET observed rejection rates per cluster for batch + :return: + kBET score (average of kBET per label) based on observed rejection rate. 
+ If ``return_df=True``, also return a ``pd.DataFrame`` with kBET observed + rejection rate per cluster """ check_adata(adata) @@ -209,3 +159,60 @@ def kBET( final_score = np.nanmean(kBET_scores['kBET']) return 1 - final_score if scaled else final_score + + +def kBET_single( + matrix, + batch, + k0=10, + knn=None, + verbose=False +): + """Single kBET run + + Compute k-nearest neighbour batch effect test (kBET) score as described in + https://doi.org/10.1038/s41592-018-0254-1 + + :param matrix: expression matrix (at the moment: a PCA matrix, so ``do.pca`` is set to ``FALSE``) + :param batch: series or list of batch assignments + :returns: kBET observed rejection rate + """ + try: + ro.r("library(kBET)") + except Exception as ex: + RLibraryNotFound(ex) + + anndata2ri.activate() + + if verbose: + print("importing expression matrix") + ro.globalenv['data_mtrx'] = matrix + ro.globalenv['batch'] = batch + + if verbose: + print("kBET estimation") + + ro.globalenv['knn_graph'] = knn + ro.globalenv['k0'] = k0 + ro.r( + "batch.estimate <- kBET(" + " data_mtrx," + " batch," + " knn=knn_graph," + " k0=k0," + " plot=FALSE," + " do.pca=FALSE," + " heuristic=FALSE," + " adapt=FALSE," + f" verbose={str(verbose).upper()}" + ")" + ) + + try: + score = ro.r("batch.estimate$summary$kBET.observed")[0] + except rpy2.rinterface_lib.embedded.RRuntimeError: + score = np.nan + + anndata2ri.deactivate() + + return score diff --git a/scib/metrics/lisi.py b/scib/metrics/lisi.py index 66dffa4d..9866dd5d 100644 --- a/scib/metrics/lisi.py +++ b/scib/metrics/lisi.py @@ -13,6 +13,7 @@ import rpy2.robjects as ro import scanpy as sc import scipy.sparse +from deprecated import deprecated from scipy.io import mmwrite from ..exceptions import RLibraryNotFound @@ -21,153 +22,6 @@ rpy2.rinterface_lib.callbacks.logger.setLevel(logging.ERROR) # Ignore R warning messages -# Main LISI - -def lisi( - adata, - batch_key, - label_key, - k0=90, - type_=None, - scale=True, - verbose=False -): - """ - Compute lisi 
score (after integration) - params: - matrix: matrix from adata to calculate on - covariate_key: variable to compute iLISI on - cluster_key: variable to compute cLISI on - return: - pd.DataFrame with median cLISI and median iLISI scores (following the harmony paper) - """ - - check_adata(adata) - check_batch(batch_key, adata.obs) - check_batch(label_key, adata.obs) - - # if type_ != 'knn': - # if verbose: - # print("recompute kNN graph with {k0} nearest neighbors.") - # recompute neighbours - if (type_ == 'embed'): - adata_tmp = sc.pp.neighbors(adata, n_neighbors=k0, use_rep='X_emb', copy=True) - elif (type_ == 'full'): - if 'X_pca' not in adata.obsm.keys(): - sc.pp.pca(adata, svd_solver='arpack') - adata_tmp = sc.pp.neighbors(adata, n_neighbors=k0, copy=True) - else: - adata_tmp = adata.copy() - # if knn - do not compute a new neighbourhood graph (it exists already) - - # lisi_score = lisi_knn(adata=adata, batch_key=batch_key, label_key=label_key, verbose=verbose) - lisi_score = lisi_knn_py(adata=adata_tmp, batch_key=batch_key, label_key=label_key, verbose=verbose) - - # iLISI: nbatches good, 1 bad - ilisi_score = np.nanmedian(lisi_score[batch_key]) - # cLISI: 1 good, nbatches bad - clisi_score = np.nanmedian(lisi_score[label_key]) - - if scale: - # get number of batches - nbatches = len(np.unique(adata.obs[batch_key])) - ilisi_score, clisi_score = scale_lisi(ilisi_score, clisi_score, nbatches) - - return ilisi_score, clisi_score - - -def lisi_knn_py( - adata, - batch_key, - label_key, - perplexity=None, - verbose=False -): - """ - Compute LISI score on kNN graph provided in the adata object. By default, perplexity - is chosen as 1/3 * number of nearest neighbours in the knn-graph. - """ - - if 'neighbors' not in adata.uns: - raise AttributeError(f"key 'neighbors' not found. 
Please make sure that a " + - "kNN graph has been computed") - elif verbose: - print("using precomputed kNN graph") - - # get knn index matrix - if verbose: - print("Convert nearest neighbor matrix and distances for LISI.") - dist_mat = scipy.sparse.find(adata.obsp['distances']) - # get number of nearest neighbours parameter - if 'params' not in adata.uns['neighbors']: - # estimate the number of nearest neighbors as the median - # of the distance matrix - _, e = np.unique(dist_mat[0], return_counts=True) - n_nn = np.nanmedian(e) - n_nn = n_nn.astype('int') - else: - n_nn = adata.uns['neighbors']['params']['n_neighbors'] - 1 - # initialise index and fill it with NaN values - nn_index = np.empty(shape=(adata.obsp['distances'].shape[0], - n_nn)) - nn_index[:] = np.NaN - nn_dists = np.empty(shape=(adata.obsp['distances'].shape[0], - n_nn)) - nn_dists[:] = np.NaN - index_out = [] - for cell_id in np.arange(np.min(dist_mat[0]), np.max(dist_mat[0]) + 1): - get_idx = dist_mat[0] == cell_id - num_idx = get_idx.sum() - # in case that get_idx contains more than n_nn neighbours, cut away the outlying ones - fin_idx = np.min([num_idx, n_nn]) - nn_index[cell_id, :fin_idx] = dist_mat[1][get_idx][np.argsort(dist_mat[2][get_idx])][:fin_idx] - nn_dists[cell_id, :fin_idx] = np.sort(dist_mat[2][get_idx])[:fin_idx] - if num_idx < n_nn: - index_out.append(cell_id) - - out_cells = len(index_out) - - if out_cells > 0: - if verbose: - print(f"{out_cells} had less than {n_nn} neighbors.") - - if perplexity is None: - # use LISI default - perplexity = np.floor(nn_index.shape[1] / 3) - - # run LISI in python - if verbose: - print("importing knn-graph") - - batch = adata.obs[batch_key].cat.codes.values - n_batches = len(np.unique(adata.obs[batch_key])) - label = adata.obs[label_key].cat.codes.values - n_labels = len(np.unique(adata.obs[label_key])) - - if verbose: - print("LISI score estimation") - - simpson_estimate_batch = compute_simpson_index(D=nn_dists, - knn_idx=nn_index, - 
batch_labels=batch, - n_batches=n_batches, - perplexity=perplexity, - ) - simpson_estimate_label = compute_simpson_index(D=nn_dists, - knn_idx=nn_index, - batch_labels=label, - n_batches=n_labels, - perplexity=perplexity - ) - simpson_est_batch = 1 / simpson_estimate_batch - simpson_est_label = 1 / simpson_estimate_label - # extract results - d = {batch_key: simpson_est_batch, label_key: simpson_est_label} - lisi_estimate = pd.DataFrame(data=d, index=np.arange(0, len(simpson_est_label))) - - return lisi_estimate - - # Graph LISI (analoguous to lisi function) def lisi_graph( adata, @@ -175,15 +29,18 @@ def lisi_graph( label_key, **kwargs ): - """ - Compute cLISI and iLISI scores on precomputed kNN graph + """cLISI and iLISI scores + + This is a reimplementation of the LISI (Local Inverse Simpson’s Index) metrics + https://doi.org/10.1038/s41592-019-0619-0 + + see :func:`~scib.metrics.clisi_graph` and :func:`~scib.metrics.ilisi_graph` :param adata: adata object to calculate on - :param batch_key: batch column name in adata.obs - :param label_key: label column name in adata.obs - :param **kwargs: arguments to be passed to iLISI and cLISI functions - :return: - Median cLISI and iLISI scores + :param batch_key: batch column name in ``adata.obs`` + :param label_key: label column name in ``adata.obs`` + :params \\**kwargs: arguments to be passed to :func:`~scib.metrics.clisi_graph` and :func:`~scib.metrics.ilisi_graph` + :return: Overall cLISI and iLISI scores """ ilisi = ilisi_graph(adata, batch_key=batch_key, **kwargs) clisi = clisi_graph(adata, batch_key=batch_key, label_key=label_key, **kwargs) @@ -201,11 +58,16 @@ def ilisi_graph( nodes=None, verbose=False ): - """ - Compute iLISI score adapted from Harmony paper (Korsunsky et al, Nat Meth, 2019) + """Integration LISI (iLISI) score + + Local Inverse Simpson’s Index metrics adapted from https://doi.org/10.1038/s41592-019-0619-0 to run on all full + feature, embedding and kNN integration outputs via shortest path-based 
distance computation on single-cell kNN + graphs. + By default, this function returns a value scaled between 0 and 1 instead of the original LISI range of 0 to the + number of batches. :param adata: adata object to calculate on - :param batch_key: batch column name in adata.obs + :param batch_key: batch column name in ``adata.obs`` :param k0: number of nearest neighbors to compute lisi score Please note that the initial neighborhood size that is used to compute shortest paths is 15. @@ -213,11 +75,11 @@ def ilisi_graph( :param subsample: Percentage of observations (integer between 0 and 100) to which lisi scoring should be subsampled :param scale: scale output values between 0 and 1 (True/False) - :param multiprocessing: parallel computation of LISI scores, if None, no parallisation + :param multiprocessing: parallel computation of LISI scores, if None, no parallelisation via multiprocessing is performed :param nodes: number of nodes (i.e. CPUs to use for multiprocessing); ignored, if multiprocessing is set to None - :return: Median of iLISI score + :return: Median of iLISI scores per batch labels """ check_adata(adata) @@ -247,7 +109,7 @@ def ilisi_graph( def clisi_graph( adata, - batch_key, + batch_key, # TODO: remove label_key, k0=90, type_=None, @@ -257,12 +119,17 @@ def clisi_graph( nodes=None, verbose=False ): - """ - Compute cLISI score adapted from Harmony paper (Korsunsky et al, Nat Meth, 2019) + """Cell-type LISI (cLISI) score + + Local Inverse Simpson’s Index metrics adapted from https://doi.org/10.1038/s41592-019-0619-0 to run on all full + feature, embedding and kNN integration outputs via shortest path-based distance computation on single-cell kNN + graphs. + By default, this function returns a value scaled between 0 and 1 instead of the original LISI range of 0 to the + number of labels. 
- :params adata: adata object to calculate on - :param batch_key: batch column name in adata.obs - :param label_key: label column name in adata.obs + :param adata: adata object to calculate on + :param batch_key: batch column name in ``adata.obs`` + :param label_key: label column name in ``adata.obs`` :param k0: number of nearest neighbors to compute lisi score Please note that the initial neighborhood size that is used to compute shortest paths is 15. @@ -270,15 +137,15 @@ def clisi_graph( :param subsample: Percentage of observations (integer between 0 and 100) to which lisi scoring should be subsampled :param scale: scale output values between 0 and 1 (True/False) - :param multiprocessing: parallel computation of LISI scores, if None, no parallisation + :param multiprocessing: parallel computation of LISI scores, if None, no parallelisation via multiprocessing is performed :param nodes: number of nodes (i.e. CPUs to use for multiprocessing); ignored, if multiprocessing is set to None - :return: Median of cLISI score + :return: Median of cLISI scores per cell type labels """ check_adata(adata) - check_batch(batch_key, adata.obs) + check_batch(batch_key, adata.obs) # TODO: remove check_batch(label_key, adata.obs) adata_tmp = recompute_knn(adata, type_) @@ -305,8 +172,7 @@ def clisi_graph( def recompute_knn(adata, type_): - """ - Recompute neighbours + """Recompute neighbours """ if type_ == 'embed': return sc.pp.neighbors(adata, n_neighbors=15, use_rep='X_emb', copy=True) @@ -470,16 +336,15 @@ def compute_simpson_index( tol=1e-5 ): """ - Simpson index of batch labels subsetted for each group. 
- params: - D: distance matrix n_cells x n_nearest_neighbors - knn_idx: index of n_nearest_neighbors of each cell - batch_labels: a vector of length n_cells with batch info - n_batches: number of unique batch labels - perplexity: effective neighborhood size - tol: a tolerance for testing effective neighborhood size - returns: - simpson: the simpson index for the neighborhood of each cell + Simpson index of batch labels subset by group. + + :param D: distance matrix ``n_cells x n_nearest_neighbors`` + :param knn_idx: index of ``n_nearest_neighbors`` of each cell + :param batch_labels: a vector of length n_cells with batch info + :param n_batches: number of unique batch labels + :param perplexity: effective neighborhood size + :param tol: a tolerance for testing effective neighborhood size + :returns: the simpson index for the neighborhood of each cell """ n = D.shape[0] P = np.zeros(D.shape[1]) @@ -527,7 +392,7 @@ def compute_simpson_index( # convertToOneHot omits all nan entries. # Therefore, we run into errors in np.matmul. if len(batch) == len(P): - B = convertToOneHot(batch, n_batches) + B = convert_to_one_hot(batch, n_batches) sumP = np.matmul(P, B) # sum P per batch simpson[i] = np.dot(sumP, sumP) # sum squares else: # assign worst possible score @@ -546,17 +411,16 @@ def compute_simpson_index_graph( tol=1e-5 ): """ - Simpson index of batch labels subsetted for each group. - params: - input_path: file_path to pre-computed index and distance files - batch_labels: a vector of length n_cells with batch info - n_batches: number of unique batch labels - n_neighbors: number of nearest neighbors - perplexity: effective neighborhood size - chunk_no: for parallelisation, chunk id to evaluate - tol: a tolerance for testing effective neighborhood size - returns: - simpson: the simpson index for the neighborhood of each cell + Simpson index of batch labels subset by group. 
+ + :param input_path: file_path to pre-computed index and distance files + :param batch_labels: a vector of length n_cells with batch info + :param n_batches: number of unique batch labels + :param n_neighbors: number of nearest neighbors + :param perplexity: effective neighborhood size + :param chunk_no: for parallelization, chunk id to evaluate + :param tol: a tolerance for testing effective neighborhood size + :returns: the simpson index for the neighborhood of each cell """ # initialize @@ -635,7 +499,7 @@ def compute_simpson_index_graph( continue # then compute Simpson's Index batch = batch_labels[knn_idx] - B = convertToOneHot(batch, n_batches) + B = convert_to_one_hot(batch, n_batches) sumP = np.matmul(P, B) # sum P per batch simpson[i[0]] = np.dot(sumP, sumP) # sum squares @@ -657,17 +521,21 @@ def Hbeta(D_row, beta): return H, P -def convertToOneHot(vector, num_classes=None): +def convert_to_one_hot(vector, num_classes=None): """ - Converts an input 1-D vector of integers into an output - 2-D array of one-hot vectors, where an i'th input value - of j will set a '1' in the i'th row, j'th column of the + Converts an input 1-D vector of integers into an output 2-D array of one-hot vectors, + where an i'th input value of j will set a '1' in the i'th row, j'th column of the output array. Example: + + .. code-block:: python + v = np.array((1, 0, 4)) one_hot_v = convertToOneHot(v) - print one_hot_v + print(one_hot_v) + + .. code-block:: [[0 1 0 0 0] [1 0 0 0 0] @@ -688,16 +556,166 @@ def convertToOneHot(vector, num_classes=None): return result.astype(int) -# DEPRECATED -# This code scales clisi incorrectly! 
+# Deprecated functions
+
+@deprecated
+def lisi(
+        adata,
+        batch_key,
+        label_key,
+        k0=90,
+        type_=None,
+        scale=True,
+        verbose=False
+):
+    """Compute iLISI and cLISI scores
+
+    This is a reimplementation of the LISI (Local Inverse Simpson’s Index) metrics
+    https://doi.org/10.1038/s41592-019-0619-0
+
+    :param adata: adata object to calculate on
+    :param batch_key: batch column name in ``adata.obs``, used to compute iLISI
+    :param label_key: label column name in ``adata.obs``, used to compute cLISI
+    :return: Tuple of median iLISI and median cLISI scores
+    """
+
+    check_adata(adata)
+    check_batch(batch_key, adata.obs)
+    check_batch(label_key, adata.obs)
+
+    # if type_ != 'knn':
+    #     if verbose:
+    #         print("recompute kNN graph with {k0} nearest neighbors.")
+    # recompute neighbours
+    if (type_ == 'embed'):
+        adata_tmp = sc.pp.neighbors(adata, n_neighbors=k0, use_rep='X_emb', copy=True)
+    elif (type_ == 'full'):
+        if 'X_pca' not in adata.obsm.keys():
+            sc.pp.pca(adata, svd_solver='arpack')
+        adata_tmp = sc.pp.neighbors(adata, n_neighbors=k0, copy=True)
+    else:
+        adata_tmp = adata.copy()
+    # if knn - do not compute a new neighbourhood graph (it exists already)
+
+    # lisi_score = lisi_knn(adata=adata, batch_key=batch_key, label_key=label_key, verbose=verbose)
+    lisi_score = lisi_knn_py(adata=adata_tmp, batch_key=batch_key, label_key=label_key, verbose=verbose)
+
+    # iLISI: nbatches good, 1 bad
+    ilisi_score = np.nanmedian(lisi_score[batch_key])
+    # cLISI: 1 good, nbatches bad
+    clisi_score = np.nanmedian(lisi_score[label_key])
+
+    if scale:
+        # get number of batches
+        nbatches = len(np.unique(adata.obs[batch_key]))
+        ilisi_score, clisi_score = scale_lisi(ilisi_score, clisi_score, nbatches)
+
+    return ilisi_score, clisi_score
+
+
+@deprecated
+def lisi_knn_py(
+        adata,
+        batch_key,
+        label_key,
+        perplexity=None,
+        verbose=False
+):
+    """
+    Compute LISI score on kNN graph provided in the adata object. By default, perplexity
+    is chosen as 1/3 * number of nearest neighbours in the knn-graph.
+ """ + + if 'neighbors' not in adata.uns: + raise AttributeError(f"key 'neighbors' not found. Please make sure that a " + + "kNN graph has been computed") + elif verbose: + print("using precomputed kNN graph") + + # get knn index matrix + if verbose: + print("Convert nearest neighbor matrix and distances for LISI.") + dist_mat = scipy.sparse.find(adata.obsp['distances']) + # get number of nearest neighbours parameter + if 'params' not in adata.uns['neighbors']: + # estimate the number of nearest neighbors as the median + # of the distance matrix + _, e = np.unique(dist_mat[0], return_counts=True) + n_nn = np.nanmedian(e) + n_nn = n_nn.astype('int') + else: + n_nn = adata.uns['neighbors']['params']['n_neighbors'] - 1 + # initialise index and fill it with NaN values + nn_index = np.empty(shape=(adata.obsp['distances'].shape[0], + n_nn)) + nn_index[:] = np.NaN + nn_dists = np.empty(shape=(adata.obsp['distances'].shape[0], + n_nn)) + nn_dists[:] = np.NaN + index_out = [] + for cell_id in np.arange(np.min(dist_mat[0]), np.max(dist_mat[0]) + 1): + get_idx = dist_mat[0] == cell_id + num_idx = get_idx.sum() + # in case that get_idx contains more than n_nn neighbours, cut away the outlying ones + fin_idx = np.min([num_idx, n_nn]) + nn_index[cell_id, :fin_idx] = dist_mat[1][get_idx][np.argsort(dist_mat[2][get_idx])][:fin_idx] + nn_dists[cell_id, :fin_idx] = np.sort(dist_mat[2][get_idx])[:fin_idx] + if num_idx < n_nn: + index_out.append(cell_id) + + out_cells = len(index_out) + + if out_cells > 0: + if verbose: + print(f"{out_cells} had less than {n_nn} neighbors.") + + if perplexity is None: + # use LISI default + perplexity = np.floor(nn_index.shape[1] / 3) + + # run LISI in python + if verbose: + print("importing knn-graph") + + batch = adata.obs[batch_key].cat.codes.values + n_batches = len(np.unique(adata.obs[batch_key])) + label = adata.obs[label_key].cat.codes.values + n_labels = len(np.unique(adata.obs[label_key])) + + if verbose: + print("LISI score estimation") + + 
simpson_estimate_batch = compute_simpson_index(D=nn_dists, + knn_idx=nn_index, + batch_labels=batch, + n_batches=n_batches, + perplexity=perplexity, + ) + simpson_estimate_label = compute_simpson_index(D=nn_dists, + knn_idx=nn_index, + batch_labels=label, + n_batches=n_labels, + perplexity=perplexity + ) + simpson_est_batch = 1 / simpson_estimate_batch + simpson_est_label = 1 / simpson_estimate_label + # extract results + d = {batch_key: simpson_est_batch, label_key: simpson_est_label} + lisi_estimate = pd.DataFrame(data=d, index=np.arange(0, len(simpson_est_label))) + + return lisi_estimate + + +@deprecated def scale_lisi(ilisi_score, clisi_score, nbatches): # scale iLISI score to 0 bad 1 good ilisi_score = (ilisi_score - 1) / (nbatches - 1) # scale clisi score to 0 bad 1 good - clisi_score = (nbatches - clisi_score) / (nbatches - 1) # Scaled incorrectly by n_batches + clisi_score = (nbatches - clisi_score) / (nbatches - 1) # Scaled incorrectly by n_batches return ilisi_score, clisi_score +@deprecated def lisi_knn( adata, batch_key, @@ -706,7 +724,6 @@ def lisi_knn( verbose=False ): """ - Deprecated Compute LISI score on kNN graph provided in the adata object. By default, perplexity is chosen as 1/3 * number of nearest neighbours in the knn-graph. """ @@ -798,6 +815,7 @@ def lisi_knn( return lisi_estimate +@deprecated def lisi_matrix( adata, batch_key, @@ -806,7 +824,6 @@ def lisi_matrix( verbose=False ): """ - Deprecated Computes the LISI scores for a given data matrix in adata.X. The scoring function of the LISI R package is called with default parameters. This function takes a data matrix and recomputes nearest neighbours. 
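The `lisi.py` diff above renames `convertToOneHot` to `convert_to_one_hot` without changing its behaviour. As a quick sanity check, here is a standalone NumPy sketch of that helper (an assumed minimal re-implementation for illustration, not the package's exact code); it reproduces the one-hot example shown in the docstring:

```python
import numpy as np

def convert_to_one_hot(vector, num_classes=None):
    # Convert a 1-D integer vector into a 2-D one-hot array:
    # an input value of j in position i sets a 1 at row i, column j.
    vector = np.asarray(vector, dtype=int)
    if num_classes is None:
        # infer the number of classes from the largest label
        num_classes = vector.max() + 1
    result = np.zeros((len(vector), num_classes), dtype=int)
    result[np.arange(len(vector)), vector] = 1
    return result

v = np.array((1, 0, 4))
print(convert_to_one_hot(v))
# [[0 1 0 0 0]
#  [1 0 0 0 0]
#  [0 0 0 0 1]]
```

Each row contains exactly one `1`, which is what `compute_simpson_index` relies on when it sums the neighbour probabilities per batch via `np.matmul(P, B)`.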
diff --git a/scib/metrics/metrics.py b/scib/metrics/metrics.py index df2c133a..6594b1c0 100755 --- a/scib/metrics/metrics.py +++ b/scib/metrics/metrics.py @@ -1,5 +1,6 @@ import numpy as np import pandas as pd +from deprecated import deprecated from ..utils import check_adata, check_batch from .ari import ari @@ -23,18 +24,29 @@ def metrics_fast( label_key, **kwargs ): - """ - Only fast metrics: + """Only metrics with minimal preprocessing and runtime + + + :Biological conservation: + + HVG overlap :func:`~scib.metrics.hvg_overlap` + + Cell type ASW :func:`~scib.metrics.silhouette` + + Isolated label ASW :func:`~scib.metrics.isolated_labels` - Biological conservation - HVG overlap - Cell type ASW - Isolated label ASW + :Batch correction: + + Graph connectivity :func:`~scib.metrics.graph_connectivity` + + Batch ASW :func:`~scib.metrics.silhouette_batch` + + Principal component regression :func:`~scib.metrics.pcr_comparison` - Batch conservation - Graph connectivity - Batch ASW - PC regression + :param adata: unintegrated, preprocessed anndata object + :param adata_int: integrated anndata object + :param batch_key: name of batch column in adata.obs and adata_int.obs + :param label_key: name of biological label (cell type) column in adata.obs and adata_int.obs + :param kwargs: + Parameters to pass on to :func:`~scib.metrics.metrics` function: + + + ``embed`` + + ``si_metric`` + + ``n_isolated`` """ return metrics( adata, @@ -57,22 +69,38 @@ def metrics_slim( label_key, **kwargs ): - """ - All metrics apart from kBET and LISI scores: - - Biological conservation - HVG overlap - Cell type ASW - Isolated label ASW - Isolated label F1 - NMI cluster/label - ARI cluster/label - Cell cycle conservation - - Batch conservation - Graph connectivity - Batch ASW - PC regression + """All metrics apart from kBET and LISI scores + + :Biological conservation: + + HVG overlap :func:`~scib.metrics.hvg_overlap` + + Cell type ASW :func:`~scib.metrics.silhouette` + + Isolated label ASW 
:func:`~scib.metrics.isolated_labels` + + Isolated label F1 :func:`~scib.metrics.isolated_labels` + + NMI cluster/label :func:`~scib.metrics.nmi` + + ARI cluster/label :func:`~scib.metrics.ari` + + Cell cycle conservation :func:`~scib.metrics.cell_cycle` + + Trajectory conservation :func:`~scib.metrics.trajectory_conservation` + + :Batch correction: + + Graph connectivity :func:`~scib.metrics.graph_connectivity` + + Batch ASW :func:`~scib.metrics.silhouette_batch` + + Principal component regression :func:`~scib.metrics.pcr_comparison` + + :param adata: unintegrated, preprocessed anndata object + :param adata_int: integrated anndata object + :param batch_key: name of batch column in adata.obs and adata_int.obs + :param label_key: name of biological label (cell type) column in adata.obs and adata_int.obs + :param kwargs: + Parameters to pass on to :func:`~scib.metrics.metrics` function: + + + ``embed`` + + ``cluster_key`` + + ``cluster_nmi`` + + ``nmi_method`` + + ``nmi_dir`` + + ``si_metric`` + + ``organism`` + + ``n_isolated`` """ return metrics( adata, @@ -100,25 +128,43 @@ def metrics_all( label_key, **kwargs ): - """ - All metrics - - Biological conservation - HVG overlap - Cell type ASW - Isolated label ASW - Isolated label F1 - NMI cluster/label - ARI cluster/label - Cell cycle conservation - cLISI - - Batch conservation - Graph connectivity - Batch ASW - PC regression - kBET - iLISI + """All metrics + + :Biological conservation: + + HVG overlap :func:`~scib.metrics.hvg_overlap` + + Cell type ASW :func:`~scib.metrics.silhouette` + + Isolated label ASW :func:`~scib.metrics.isolated_labels` + + Isolated label F1 :func:`~scib.metrics.isolated_labels` + + NMI cluster/label :func:`~scib.metrics.nmi` + + ARI cluster/label :func:`~scib.metrics.ari` + + Cell cycle conservation :func:`~scib.metrics.cell_cycle` + + cLISI (cell type Local Inverse Simpson's Index) :func:`~scib.metrics.clisi_graph` + + Trajectory conservation :func:`~scib.metrics.trajectory_conservation` + 
+ :Batch correction: + + Graph connectivity :func:`~scib.metrics.graph_connectivity` + + Batch ASW :func:`~scib.metrics.silhouette_batch` + + Principal component regression :func:`~scib.metrics.pcr_comparison` + + kBET (k-nearest neighbour batch effect test) :func:`~scib.metrics.kBET` + + iLISI (integration Local Inverse Simpson's Index) :func:`~scib.metrics.ilisi_graph` + + :param adata: unintegrated, preprocessed anndata object + :param adata_int: integrated anndata object + :param batch_key: name of batch column in adata.obs and adata_int.obs + :param label_key: name of biological label (cell type) column in adata.obs and adata_int.obs + :param kwargs: + Parameters to pass on to :func:`~scib.metrics.metrics` function: + + + ``embed`` + + ``cluster_key`` + + ``cluster_nmi`` + + ``nmi_method`` + + ``nmi_dir`` + + ``si_metric`` + + ``organism`` + + ``n_isolated`` + + ``subsample`` + + ``type_`` """ return metrics( adata, @@ -147,7 +193,7 @@ def metrics( adata_int, batch_key, label_key, - hvg_score_=False, + embed='X_pca', cluster_key='cluster', cluster_nmi=None, ari_=False, @@ -155,28 +201,100 @@ def metrics( nmi_method='arithmetic', nmi_dir=None, silhouette_=False, - embed='X_pca', si_metric='euclidean', pcr_=False, cell_cycle_=False, organism='mouse', + hvg_score_=False, isolated_labels_=False, # backwards compatibility isolated_labels_f1_=False, isolated_labels_asw_=False, n_isolated=None, graph_conn_=False, + trajectory_=False, kBET_=False, - subsample=0.5, lisi_graph_=False, ilisi_=False, clisi_=False, - trajectory_=False, + subsample=0.5, type_=None, verbose=False, ): - """ - Master metrics function: Wrapper for all metrics used in the study - Compute of all metrics given unintegrate and integrated anndata object + """Master metrics function + + Wrapper for all metrics used in the study. 
+ Computes all metrics given an unintegrated and an integrated anndata object
+
+ :param adata:
+ unintegrated, preprocessed anndata object
+ :param adata_int:
+ integrated anndata object
+ :param batch_key:
+ name of batch column in adata.obs and adata_int.obs
+ :param label_key:
+ name of biological label (cell type) column in adata.obs and adata_int.obs
+ :param embed:
+ embedding representation of adata_int
+
+ Used for:
+
+ + silhouette scores (label ASW, batch ASW),
+ + PC regression,
+ + cell cycle conservation,
+ + isolated label scores, and
+ + kBET
+ :param cluster_key:
+ name of column to store cluster assignments. Will be overwritten if it exists
+ :param cluster_nmi:
+ Where to save cluster resolutions and NMI for optimal clustering
+ If None, these results will not be saved
+ :param `ari_`:
+ whether to compute ARI using :func:`~scib.metrics.ari`
+ :param `nmi_`:
+ whether to compute NMI using :func:`~scib.metrics.nmi`
+ :param nmi_method:
+ which implementation of NMI to use
+ :param nmi_dir:
+ directory of NMI code for some implementations of NMI
+ :param `silhouette_`:
+ whether to compute the average silhouette width scores for labels and batch
+ using :func:`~scib.metrics.silhouette` and :func:`~scib.metrics.silhouette_batch`
+ :param si_metric:
+ which distance metric to use for silhouette scores
+ :param `pcr_`:
+ whether to compute principal component regression using :func:`~scib.metrics.pcr_comparison`
+ :param `cell_cycle_`:
+ whether to compute cell cycle score conservation using :func:`~scib.metrics.cell_cycle`
+ :param organism:
+ organism of the datasets, used for computing cell cycle scores on gene names
+ :param `hvg_score_`:
+ whether to compute highly variable gene conservation using :func:`~scib.metrics.hvg_overlap`
+ :param `isolated_labels_`:
+ whether to compute both isolated label scores using :func:`~scib.metrics.isolated_labels`
+ :param `isolated_labels_f1_`:
+ whether to compute isolated label score based on F1 score of clusters
vs labels using + :func:`~scib.metrics.isolated_labels` + :param `isolated_labels_asw_`: + whether to compute isolated label score based on ASW (average silhouette width) using + :func:`~scib.metrics.isolated_labels` + :param `n_isolated`: + maximum number of batches per label for label to be considered as isolated + :param `graph_conn_`: + whether to compute graph connectivity score using :func:`~scib.metrics.graph_connectivity` + :param `trajectory_`: + whether to compute trajectory score using :func:`~scib.metrics.trajectory_conservation` + :param `kBET_`: + whether to compute kBET score using :func:`~scib.metrics.kBET` + :param `lisi_graph_`: + whether to compute both cLISI and iLISI using :func:`~scib.metrics.lisi_graph` + :param `clisi_`: + whether to compute cLISI using :func:`~scib.metrics.clisi_graph` + :param `ilisi_`: + whether to compute iLISI using :func:`~scib.metrics.ilisi_graph` + :param subsample: + subsample fraction for LISI scores + :param `type_`: + one of 'full', 'embed' or 'knn' (used for kBET and LISI scores) """ check_adata(adata) @@ -204,8 +322,6 @@ def metrics( nmi_all.to_csv(cluster_nmi, header=False) print(f'saved clustering NMI values to {cluster_nmi}') - results = {} - if nmi_: print('NMI...') nmi_score = nmi( @@ -396,16 +512,12 @@ def metrics( return pd.DataFrame.from_dict(results, orient='index') -# Deprecated - +@deprecated def measureTM(*args, **kwargs): """ - Deprecated - params: - *args: function to be tested for time and memory - **kwargs: list of function paramters - returns: - tuple : (memory (MB), time (s), list of *args function outputs) + :param *args: function to be tested for time and memory + :param **kwargs: list of function parameters + :returns: (memory (MB), time (s), list of *args function outputs) """ import cProfile from pstats import Stats diff --git a/scib/metrics/nmi.py b/scib/metrics/nmi.py index 6c4f7d66..d1c6e448 100644 --- a/scib/metrics/nmi.py +++ b/scib/metrics/nmi.py @@ -7,23 +7,23 @@ def nmi(adata, 
group1, group2, method="arithmetic", nmi_dir=None): - """ + """Normalized mutual information + Wrapper for normalized mutual information NMI between two different cluster assignments :param adata: Anndata object - :param group1: column name of `adata.obs` - :param group2: column name of `adata.obs` - :param method: NMI implementation - 'max': scikit method with `average_method='max'` - 'min': scikit method with `average_method='min'` - 'geometric': scikit method with `average_method='geometric'` - 'arithmetic': scikit method with `average_method='arithmetic'` - 'Lancichinetti': implementation by A. Lancichinetti 2009 et al. https://sites.google.com/site/andrealancichinetti/mutual + :param group1: column name of ``adata.obs`` + :param group2: column name of ``adata.obs`` + :param method: NMI implementation. + 'max': scikit method with ``average_method='max'``; + 'min': scikit method with ``average_method='min'``; + 'geometric': scikit method with ``average_method='geometric'``; + 'arithmetic': scikit method with ``average_method='arithmetic'``; + 'Lancichinetti': implementation by A. Lancichinetti 2009 et al. https://sites.google.com/site/andrealancichinetti/mutual; 'ONMI': implementation by Aaron F. McDaid et al. https://github.com/aaronmcdaid/Overlapping-NMI - :param nmi_dir: directory of compiled C code if 'Lancichinetti' or 'ONMI' are specified as `method`. + :param nmi_dir: directory of compiled C code if 'Lancichinetti' or 'ONMI' are specified as ``method``. These packages need to be compiled as specified in the corresponding READMEs. - :return: - Normalized mutual information NMI value + :return: Normalized mutual information NMI value """ check_adata(adata) @@ -55,8 +55,10 @@ def onmi(group1, group2, nmi_dir=None, verbose=True): """ Based on implementation https://github.com/aaronmcdaid/Overlapping-NMI publication: Aaron F. 
McDaid, Derek Greene, Neil Hurley 2011
- params:
- nmi_dir: directory of compiled C code
+
+ :param group1: list or series of cell assignments
+ :param group2: list or series of cell assignments
+ :param nmi_dir: directory of compiled C code
"""

if nmi_dir is None:
@@ -122,8 +124,9 @@ def write_tmp_labels(group_assignments, to_int=False, delim='\n'):
"""
write the values of a specific obs column into a temporary file in text format
needed for external C NMI implementations (onmi and nmi_Lanc functions), because they require files as input
- params:
- to_int: rename the unique column entries by integers in range(1,len(group_assignments)+1)
+
+ :param group_assignments: list or series of cell assignments
+ :param to_int: rename the unique column entries by integers in range(1,len(group_assignments)+1)
"""

import tempfile
diff --git a/scib/metrics/pcr.py b/scib/metrics/pcr.py
index cb45657d..da7a0db3 100644
--- a/scib/metrics/pcr.py
+++ b/scib/metrics/pcr.py
@@ -16,11 +16,11 @@ def pcr_comparison(
scale=True,
verbose=False
):
- """
- Compare the explained variance before and after integration
+ """Principal component regression score
+ Compare the explained variance before and after integration using :func:`~scib.metrics.pc_regression`.
Return either the difference of variance contribution before and after integration
- or a score between 0 and 1 (`scaled=True`) with 0 if the variance contribution hasn't
+ or a score between 0 and 1 (``scale=True``) with 0 if the variance contribution hasn't
changed. The larger the score, the more different the variance contributions are before
and after integration.
@@ -28,13 +28,13 @@ def pcr_comparison(
:param adata_post: anndata object after integration
:param covariate: Key for adata.obs column to regress against
:param embed: Embedding to use for principal components.
- If None, use the full expression matrix (`adata.X`), otherwise use the embedding
- provided in `adata_post.obsm[embed]`.
+ If None, use the full expression matrix (``adata.X``), otherwise use the embedding + provided in ``adata_post.obsm[embed]``. :param n_comps: Number of principal components to compute :param scale: If True, scale score between 0 and 1 (default) :param verbose: :return: - Difference of R2Var value of PCR (scaled between 0 and 1 by default) + Difference of variance contribution of PCR (scaled between 0 and 1 by default) """ if embed == 'X_pca': @@ -78,22 +78,22 @@ def pcr( recompute_pca=True, verbose=False ): - """ - Principal component regression for anndata object + """Principal component regression for anndata object + + Wraps :func:`~scib.metrics.pc_regression` while checking whether to: - Checks whether to - + compute PCA on embedding or expression data (set `embed` to name of embedding matrix e.g. `embed='X_emb'`) + + compute PCA on embedding or expression data (set ``embed`` to name of embedding matrix e.g. ``embed='X_emb'``) + use existing PCA (only if PCA entry exists) + recompute PCA on expression matrix (default) :param adata: Anndata object :param covariate: Key for adata.obs column to regress against :param embed: Embedding to use for principal components. - If None, use the full expression matrix (`adata.X`), otherwise use the embedding - provided in `adata_post.obsm[embed]`. + If None, use the full expression matrix (``adata.X``), otherwise use the embedding + provided in ``adata_post.obsm[embed]``. :param n_comps: Number of PCs, if PCA is recomputed :return: - R2Var of regression + Variance contribution of regression """ check_adata(adata) @@ -130,18 +130,30 @@ def pc_regression( svd_solver='arpack', verbose=False ): - """ - :params data: Expression or PC matrix. Assumed to be PC, if pca_sd is given. + """Principal component regression + + Compute the overall variance contribution given a covariate according to the following formula: + + .. 
math::
+
+ Var(C|B) = \\sum^G_{i=1} Var(C|PC_i) \\cdot R^2(PC_i|B)
+
+ for :math:`G` principal components (:math:`PC_i`), where :math:`Var(C|PC_i)` is the variance of the data matrix
+ :math:`C` explained by the i-th principal component, and :math:`R^2(PC_i|B)` is the :math:`R^2` of the i-th
+ principal component regressed against a covariate :math:`B`.
+
+
+ :param data: Expression or PC matrix. Assumed to be PC, if pca_sd is given.
:param covariate: series or list of batch assignments
:param n_comps: number of PCA components for computing PCA, only when pca_sd is not given. If pca_sd is not defined and n_comps=None, compute PCA and don't reduce data
- :param pca_var: Iterable of variances for `n_comps` components.
- If `pca_sd` is not `None`, it is assumed that the matrix contains PC,
- otherwise PCA is computed on `data`.
+ :param pca_var: Iterable of variances for ``n_comps`` components.
+ If ``pca_sd`` is not ``None``, it is assumed that the matrix contains PC,
+ otherwise PCA is computed on ``data``.
:param svd_solver:
:param verbose:
:return:
- R2Var of regression
+ Variance contribution of regression
"""

if isinstance(data, (np.ndarray, sparse.csr_matrix, sparse.csc_matrix)):
diff --git a/scib/metrics/silhouette.py b/scib/metrics/silhouette.py
index d61623e2..37a313c3 100644
--- a/scib/metrics/silhouette.py
+++ b/scib/metrics/silhouette.py
@@ -9,12 +9,15 @@ def silhouette(
metric='euclidean',
scale=True
):
- """
+ """Average silhouette width (ASW)
+
Wrapper for sklearn silhouette function. Values range from [-1, 1] where:
- 1 being an ideal fit
- 0 indicating overlapping clusters and
- -1 indicating misclassified cells
- By default, the score is scaled between 0 and 1. This is controlled `scale=True`
+
+ * 1 indicates distinct, compact clusters
+ * 0 indicates overlapping clusters
+ * -1 indicates core-periphery (non-cluster) structure
+
+ By default, the score is scaled between 0 and 1 (``scale=True``).
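A minimal sketch of the label ASW described above, using scikit-learn's silhouette and rescaling with `(asw + 1) / 2` (an assumed form of the scaling, not scib code):

```python
import numpy as np
from sklearn.metrics import silhouette_score

rng = np.random.default_rng(0)
# two compact, well-separated "cell type" clusters in a toy 2-D embedding
embedding = np.vstack([
    rng.normal(0, 0.1, size=(50, 2)),
    rng.normal(5, 0.1, size=(50, 2)),
])
labels = np.repeat(['type_A', 'type_B'], 50)

asw = silhouette_score(embedding, labels, metric='euclidean')  # in [-1, 1]
scaled_asw = (asw + 1) / 2                                     # rescaled to [0, 1]
```

For well-separated clusters like these, the scaled score is close to 1.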
:param group_key: key in adata.obs of cell labels
:param embed: embedding key in adata.obsm, default: 'X_pca'
@@ -43,21 +46,45 @@ def silhouette_batch(
scale=True,
verbose=True
):
- """
- Absolute silhouette score of batch labels subsetted for each group.
+ """Batch ASW
+
+ Modified average silhouette width (ASW) of batch
+
+ This metric measures the silhouette of a given batch.
+ It assumes that a silhouette width close to 0 represents perfect overlap of the batches, thus the absolute value of
+ the silhouette width is used to measure how well batches are mixed.
+ For all cells :math:`i` of a cell type :math:`C_j`, the batch ASW of that cell type is:
+
+ .. math::
+
+ batch \\, ASW_j = \\frac{1}{|C_j|} \\sum_{i \\in C_j} |silhouette(i)|
+
+ The final score is the average of the absolute silhouette widths over the set of cell types :math:`M`.
+
+ .. math::
+
+ batch \\, ASW = \\frac{1}{|M|} \\sum_{j \\in M} batch \\, ASW_j
+
+ For a scaled metric (which is the default), the absolute ASW per group is subtracted from 1 before averaging, so that
+ 0 indicates strongly separated batches and 1 indicates perfectly mixed batches.
+
+ .. math::
+
+ batch \\, ASW_j = \\frac{1}{|C_j|} \\sum_{i \\in C_j} 1 - |silhouette(i)|
+
- :param batch_key: batches to be compared against
- :param group_key: group labels to be subsetted by e.g. cell type
+ :param batch_key: batch labels to be compared against
+ :param group_key: group labels to be subset by e.g.
cell type :param embed: name of column in adata.obsm :param metric: see sklearn silhouette score :param scale: if True, scale between 0 and 1 :param return_all: if True, return all silhouette scores and label means default False: return average width silhouette (ASW) - :param verbose: + :param verbose: print silhouette score per group :return: - average width silhouette ASW - mean silhouette per group in pd.DataFrame - Absolute silhouette scores per group label + Batch ASW (always) + Mean silhouette per group in pd.DataFrame (additionally, if return_all=True) + Absolute silhouette scores per group label (additionally, if return_all=True) """ if embed not in adata.obsm.keys(): print(adata.obsm.keys()) @@ -97,7 +124,7 @@ def silhouette_batch( asw = sil_means['silhouette_score'].mean() if verbose: - print(f'mean silhouette per cell: {sil_means}') + print(f'mean silhouette per group: {sil_means}') if return_all: return asw, sil_means, sil_all diff --git a/scib/metrics/trajectory.py b/scib/metrics/trajectory.py index ea0a531c..02ca92f9 100644 --- a/scib/metrics/trajectory.py +++ b/scib/metrics/trajectory.py @@ -7,76 +7,29 @@ from .utils import RootCellError -def get_root( +def trajectory_conservation( adata_pre, adata_post, - ct_key, + label_key, pseudotime_key="dpt_pseudotime", - dpt_dim=3 + batch_key=None ): - """ - Determine root cell for integrated adata based on unintegrated adata - - :param adata_pre: unintegrated adata - :param adata_post: integrated adata - :param label_key: column in `adata_pre.obs` of the groups used to precompute the trajectory - :param pseudotime_key: column in `adata_pre.obs` in which the pseudotime is saved in. - Column can contain empty entries, the dataset will be subset to the cells with scores. 
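Independently of scib, the per-cell-type batch ASW formulas above can be sketched with scikit-learn's per-sample silhouette widths, using the scaled `1 - |silhouette(i)|` form from the docstring:

```python
import numpy as np
from sklearn.metrics import silhouette_samples

rng = np.random.default_rng(0)
# toy embedding: two cell types, each containing two perfectly mixed batches
embedding = np.vstack([rng.normal(t, 0.5, size=(40, 2)) for t in (0, 10)])
cell_type = np.repeat(['A', 'B'], 40)
batch = np.tile(np.repeat([0, 1], 20), 2)

scores = []
for ct in np.unique(cell_type):
    mask = cell_type == ct
    # silhouette of the *batch* labels within one cell type
    sil = silhouette_samples(embedding[mask], batch[mask])
    scores.append(np.mean(1 - np.abs(sil)))  # 1 = batches fully mixed
batch_asw = float(np.mean(scores))
```

Because both batches are drawn from the same distribution within each cell type, the per-sample silhouettes sit near 0 and the scaled score lands near 1.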
- :param dpt_dim: number of diffmap dimensions used to determine root
- """
- n_components, adata_post.obs['neighborhood'] = connected_components(
- csgraph=adata_post.obsp['connectivities'],
- directed=False,
- return_labels=True
- )
-
- start_clust = adata_pre.obs.groupby([ct_key]).mean()[pseudotime_key].idxmin()
- min_dpt = adata_pre.obs[adata_pre.obs[ct_key] == start_clust].index
- which_max_neigh = adata_post.obs['neighborhood'] == adata_post.obs['neighborhood'].value_counts().idxmax()
- min_dpt = [value for value in min_dpt if value in adata_post.obs[which_max_neigh].index]
-
- adata_post_ti = adata_post[which_max_neigh]
-
- min_dpt = [adata_post_ti.obs_names.get_loc(i) for i in min_dpt]
-
- # compute Diffmap for adata_post
- sc.tl.diffmap(adata_post_ti)
-
- # determine most extreme cell in adata_post Diffmap
- min_dpt_cell = np.zeros(len(min_dpt))
- for dim in np.arange(dpt_dim):
-
- diffmap_mean = adata_post_ti.obsm["X_diffmap"][:, dim].mean()
- diffmap_min_dpt = adata_post_ti.obsm["X_diffmap"][min_dpt, dim]
+ """Trajectory conservation score

- # count opt cell
- if len(diffmap_min_dpt) == 0:
- raise RootCellError('No root cell in largest component')
-
- # choose optimum function
- if len(diffmap_min_dpt) > 0 and diffmap_min_dpt.mean() < diffmap_mean:
- opt = np.argmin
- else:
- opt = np.argmax
+ Trajectory conservation is measured by Spearman's rank correlation coefficient :math:`s`, between the pseudotime
+ values before and after integration.
+ The final score is scaled to a value between 0 and 1 using the equation

- min_dpt_cell[opt(diffmap_min_dpt)] += 1
+ .. math::

- # root cell is cell with max vote
- return min_dpt[np.argmax(min_dpt_cell)], adata_post_ti
+ trajectory \\, conservation = \\frac {s + 1} {2}

+ This function expects pseudotime values to be precomputed.
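The scaling above can be sketched directly: a Spearman correlation between two pseudotime vectors, rescaled from [-1, 1] to [0, 1] (toy pseudotime values, not scib code):

```python
import numpy as np
import pandas as pd

rng = np.random.default_rng(0)
pseudotime_before = pd.Series(rng.random(200))
# integration distorts pseudotime monotonically and adds a little noise
pseudotime_after = pd.Series(pseudotime_before ** 2 + rng.normal(0, 0.01, 200))

s = pseudotime_before.corr(pseudotime_after, method='spearman')
trajectory_score = (s + 1) / 2  # scaled to [0, 1]
```

Spearman correlation is rank-based, so a purely monotonic distortion of pseudotime still yields a score near 1.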
-def trajectory_conservation(
- adata_pre,
- adata_post,
- label_key,
- pseudotime_key="dpt_pseudotime",
- batch_key=None
-):
- """
:param adata_pre: unintegrated adata
:param adata_post: integrated adata
- :param label_key: column in `adata_pre.obs` of the groups used to precompute the trajectory
- :param pseudotime_key: column in `adata_pre.obs` in which the pseudotime is saved in.
+ :param label_key: column in ``adata_pre.obs`` of the groups used to precompute the trajectory
+ :param pseudotime_key: column in ``adata_pre.obs`` in which the pseudotime is saved.
Column can contain empty entries, the dataset will be subset to the cells with scores.
:param batch_key: set to batch key if you want to compute the trajectory metric by batch
"""
@@ -122,3 +75,60 @@ def trajectory_conservation(
corr[i] = pseudotime_before.corr(pseudotime_after, 'spearman')

return (corr.mean() + 1) / 2 # scaled
+
+
+def get_root(
+ adata_pre,
+ adata_post,
+ ct_key,
+ pseudotime_key="dpt_pseudotime",
+ dpt_dim=3
+):
+ """Determine root cell for integrated adata based on unintegrated adata
+
+ :param adata_pre: unintegrated adata
+ :param adata_post: integrated adata
+ :param ct_key: column in ``adata_pre.obs`` of the groups used to precompute the trajectory
+ :param pseudotime_key: column in ``adata_pre.obs`` in which the pseudotime is saved.
+ Column can contain empty entries, the dataset will be subset to the cells with scores.
+ :param dpt_dim: number of diffmap dimensions used to determine root + """ + n_components, adata_post.obs['neighborhood'] = connected_components( + csgraph=adata_post.obsp['connectivities'], + directed=False, + return_labels=True + ) + + start_clust = adata_pre.obs.groupby([ct_key]).mean()[pseudotime_key].idxmin() + min_dpt = adata_pre.obs[adata_pre.obs[ct_key] == start_clust].index + which_max_neigh = adata_post.obs['neighborhood'] == adata_post.obs['neighborhood'].value_counts().idxmax() + min_dpt = [value for value in min_dpt if value in adata_post.obs[which_max_neigh].index] + + adata_post_ti = adata_post[which_max_neigh] + + min_dpt = [adata_post_ti.obs_names.get_loc(i) for i in min_dpt] + + # compute Diffmap for adata_post + sc.tl.diffmap(adata_post_ti) + + # determine most extreme cell in adata_post Diffmap + min_dpt_cell = np.zeros(len(min_dpt)) + for dim in np.arange(dpt_dim): + + diffmap_mean = adata_post_ti.obsm["X_diffmap"][:, dim].mean() + diffmap_min_dpt = adata_post_ti.obsm["X_diffmap"][min_dpt, dim] + + # count opt cell + if len(diffmap_min_dpt) == 0: + raise RootCellError('No root cell in largest component') + + # choose optimum function + if len(diffmap_min_dpt) > 0 and diffmap_min_dpt.mean() < diffmap_mean: + opt = np.argmin + else: + opt = np.argmax + + min_dpt_cell[opt(diffmap_min_dpt)] += 1 + + # root cell is cell with max vote + return min_dpt[np.argmax(min_dpt_cell)], adata_post_ti diff --git a/scib/preprocessing.py b/scib/preprocessing.py index 1ed1d7ab..ea68a6d3 100644 --- a/scib/preprocessing.py +++ b/scib/preprocessing.py @@ -1,3 +1,5 @@ +# TODO: move util functions e.g. reader functions elsewhere + import logging import tempfile @@ -21,6 +23,19 @@ def summarize_counts(adata, count_matrix=None, mt_gene_regex='^MT-'): + """Summarise counts of the given count matrix + + This function is useful for quality control. + Aggregates counts per cell and per gene as well as mitochondrial fraction. 
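The per-cell aggregation described above amounts to simple row-wise reductions over the count matrix. A self-contained sketch on a hypothetical toy matrix (the real function matches mitochondrial genes with a regex; a prefix test is used here for brevity):

```python
import numpy as np

# hypothetical 3 cells x 4 genes count matrix; last gene is mitochondrial
counts = np.array([
    [5, 0, 2, 3],
    [0, 1, 0, 1],
    [4, 4, 4, 8],
])
var_names = np.array(['GAPDH', 'ACTB', 'CD3E', 'MT-CO1'])
mt_mask = np.char.startswith(var_names, 'MT-')

n_counts = counts.sum(axis=1)            # count depth per cell
log_counts = np.log(n_counts)            # log of count depth
n_genes = (counts > 0).sum(axis=1)       # genes detected per cell
mito_frac = counts[:, mt_mask].sum(axis=1) / n_counts
```

All four summaries are per-cell vectors, which is why they live in `adata.obs`.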
+
+ :param count_matrix: count matrix, by default uses ``adata.X``
+ :param mt_gene_regex: regex string for identifying mitochondrial genes
+ :return: ``adata`` with the following keys added to ``adata.obs``:
+ 'n_counts': number of counts per cell (count depth)
+ 'log_counts': ``np.log`` of counts per cell
+ 'n_genes': number of genes with non-zero counts per cell
+ 'mito_frac': fraction of mitochondrial gene counts as an indicator of cell viability
+ """
utils.check_adata(adata)

if count_matrix is None:
@@ -38,7 +53,7 @@ def summarize_counts(adata, count_matrix=None, mt_gene_regex='^MT-'):
if sparse.issparse(adata.X):
mt_sum = mt_sum.A1
total_sum = total_sum.A1
- adata.obs['percent_mito'] = mt_sum / total_sum
+ adata.obs['mito_frac'] = mt_sum / total_sum

# mt_gene_mask = [gene.startswith('mt-') for gene in adata.var_names]
# mt_count = count_matrix[:, mt_gene_mask].sum(1)
@@ -48,21 +63,50 @@ def summarize_counts(adata, count_matrix=None, mt_gene_regex='^MT-'):

### Quality Control

-def plot_qc(adata, color=None, bins=60, legend_loc='right margin', histogram=True,
- gene_threshold=(0, np.inf),
- gene_filter_threshold=(0, np.inf),
- count_threshold=(0, np.inf),
- count_filter_threshold=(0, np.inf)):
+def plot_qc(
+ adata,
+ color=None,
+ bins=60,
+ legend_loc='right margin',
+ histogram=True,
+ count_threshold=(0, np.inf),
+ count_filter_threshold=(0, np.inf),
+ gene_threshold=(0, np.inf),
+ gene_filter_threshold=(0, np.inf)
+):
+ """Create QC Plots
+
+ Create scatter plot for count depth vs. gene count, and histograms for count depth and gene count.
+ Filtering thresholds are included in all plots for making QC decisions.
+
+ :param adata: ``anndata`` object containing summarised count keys 'n_counts' and 'n_genes' in ``adata.obs``
+ :param color: column in ``adata.obs`` for the scatter plot
+ :param bins: number of bins for the histogram
+ :param legend_loc: location of legend of scatterplot
+ :param histogram: whether to include histograms
+ :param count_threshold: tuple of lower and upper count depth thresholds for zooming into histogram plots.
+ By default, this is unbounded, so no zoomed-in histograms are plotted.
+ :param count_filter_threshold: tuple of lower and upper count depth thresholds visualised as cutoffs.
+ By default, this is set to the values of ``count_threshold``.
+ :param gene_threshold: tuple of lower and upper gene count thresholds for zooming into histogram plots.
+ By default, this is unbounded, so no zoomed-in histograms are plotted.
+ :param gene_filter_threshold: tuple of lower and upper gene count thresholds visualised as cutoffs.
+ By default, this is set to the values of ``gene_threshold``.
+ """ if count_filter_threshold == (0, np.inf): count_filter_threshold = count_threshold if gene_filter_threshold == (0, np.inf): gene_filter_threshold = gene_threshold # 2D scatter plot - plot_scatter(adata, color=color, title=color, - gene_threshold=gene_filter_threshold[0], - count_threshold=count_filter_threshold[0], - legend_loc=legend_loc) + plot_scatter( + adata, + color=color, + title=color, + gene_threshold=gene_filter_threshold[0], + count_threshold=count_filter_threshold[0], + legend_loc=legend_loc + ) if not histogram: return @@ -70,25 +114,41 @@ def plot_qc(adata, color=None, bins=60, legend_loc='right margin', histogram=Tru if count_filter_threshold != (0, np.inf): print(f"Counts Threshold: {count_filter_threshold}") # count filtering - plot_count_filter(adata, obs_col='n_counts', bins=bins, - lower=count_threshold[0], - filter_lower=count_filter_threshold[0], - upper=count_threshold[1], - filter_upper=count_filter_threshold[1]) + plot_count_filter( + adata, + obs_col='n_counts', + bins=bins, + lower=count_threshold[0], + filter_lower=count_filter_threshold[0], + upper=count_threshold[1], + filter_upper=count_filter_threshold[1] + ) if gene_filter_threshold != (0, np.inf): print(f"Gene Threshold: {gene_filter_threshold}") # gene filtering - plot_count_filter(adata, obs_col='n_genes', bins=bins, - lower=gene_threshold[0], - filter_lower=gene_filter_threshold[0], - upper=gene_threshold[1], - filter_upper=gene_filter_threshold[1]) - - -def plot_scatter(adata, count_threshold=0, gene_threshold=0, - color=None, title='', lab_size=15, tick_size=11, legend_loc='right margin', - palette=None): + plot_count_filter( + adata, + obs_col='n_genes', + bins=bins, + lower=gene_threshold[0], + filter_lower=gene_filter_threshold[0], + upper=gene_threshold[1], + filter_upper=gene_filter_threshold[1] + ) + + +def plot_scatter( + adata, + count_threshold=0, + gene_threshold=0, + color=None, + title='Scatter plot of count depth vs gene counts', + lab_size=15, + 
tick_size=11,
+ legend_loc='right margin',
+ palette=None
+):
utils.check_adata(adata)
if color:
utils.check_batch(color, adata.obs)
@@ -148,13 +208,20 @@ def plot_count_filter(adata, obs_col='n_counts', bins=60, lower=0, upper=np.inf,

### Normalisation

-def normalize(
- adata,
- min_mean=0.1,
- log=True,
- precluster=True,
- sparsify=True
-):
+def normalize(adata, sparsify=True, precluster=True, min_mean=0.1, log=True):
+ """Normalise counts using the ``scran`` normalisation method
+
+ Uses the `computeSumFactors`_ function from the `scran`_ package.
+
+ .. _scran: https://rdrr.io/bioc/scran/
+ .. _computeSumFactors: https://rdrr.io/bioc/scran/man/computeSumFactors.html
+
+ :param adata: ``anndata`` object
+ :param sparsify: whether to convert the count matrix into a sparse matrix
+ :param precluster: whether to perform preliminary clustering, so that size factors are estimated per cluster
+ :param min_mean: parameter of ``scran``'s ``computeSumFactors`` function
+ :param log: whether to perform log1p transformation after normalisation
+ """
utils.check_adata(adata)

# Check for 0 count cells
@@ -205,9 +272,15 @@ def normalize(
sc.tl.louvain(adata_pp, key_added='groups', resolution=0.5)
ro.globalenv['input_groups'] = adata_pp.obs['groups']

- size_factors = ro.r('sizeFactors(computeSumFactors(SingleCellExperiment('
- 'list(counts=data_mat)), clusters = input_groups,'
- f' min.mean = {min_mean}))')
+ size_factors = ro.r(
+ 'sizeFactors('
+ ' computeSumFactors('
+ ' SingleCellExperiment(list(counts=data_mat)),'
+ ' clusters = input_groups,'
+ f' min.mean = {min_mean}'
+ ' )'
+ ')'
+ )

del adata_pp

@@ -241,8 +314,12 @@ def normalize(

def scale_batch(adata, batch):
- """
- Function to scale the gene expression values of each batch separately.
+ """Batch-aware scaling of count matrix
+
+ Scaling counts to a mean of 0 and standard deviation of 1 using ``scanpy.pp.scale`` for each batch separately.
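Batch-aware scaling as described above boils down to z-scoring each batch separately; a sketch in plain numpy (the scib implementation uses ``scanpy.pp.scale`` instead):

```python
import numpy as np

rng = np.random.default_rng(0)
# two batches of cells with very different scales (50 cells x 3 genes each)
x = np.vstack([
    rng.normal(0, 1, size=(50, 3)),
    rng.normal(10, 2, size=(50, 3)),
])
batch = np.repeat(['batch1', 'batch2'], 50)

scaled = np.empty_like(x)
for b in np.unique(batch):
    sub = x[batch == b]
    # zero mean, unit standard deviation per gene within each batch
    scaled[batch == b] = (sub - sub.mean(axis=0)) / sub.std(axis=0)
```

After this step, batch-specific location and scale differences per gene are removed, which is exactly why the scaling must be done per batch rather than globally.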
+
+ :param adata: ``anndata`` object with normalised and log-transformed counts
+ :param batch: ``adata.obs`` column
"""
utils.check_adata(adata)
@@ -275,18 +352,27 @@ def normalize(
return adata_scaled

-def hvg_intersect(adata, batch, target_genes=2000, flavor='cell_ranger', n_bins=20, adataOut=False, n_stop=8000,
- min_genes=500, step_size=1000):
- ### Feature Selection
- """
- params:
- adata:
- batch: adata.obs column
- target_genes: maximum number of genes (intersection reduces the number of genes)
- min_genes: minimum number of intersection HVGs targeted
- step_size: step size to increase HVG selection per dataset
- return:
- list of highly variable genes less or equal to `target_genes`
+def hvg_intersect(
+ adata,
+ batch,
+ target_genes=2000,
+ flavor='cell_ranger',
+ n_bins=20,
+ adataOut=False,
+ n_stop=8000,
+ min_genes=500,
+ step_size=1000
+):
+ """Highly variable gene selection
+
+ Legacy approach to HVG selection using only HVG intersections between all batches
+
+ :param adata: ``anndata`` object with preprocessed counts
+ :param batch: ``adata.obs`` column
+ :param target_genes: maximum number of genes (intersection reduces the number of genes)
+ :param min_genes: minimum number of intersection HVGs targeted
+ :param step_size: step size to increase HVG selection per dataset
+ :return: list of at most ``target_genes`` highly variable genes
"""
utils.check_adata(adata)
@@ -301,7 +387,15 @@ def hvg_intersect(adata, batch, target_genes=2000, flavor='cell_ranger', n_bins=
for i in split:
sc.pp.filter_genes(i, min_cells=1) # remove genes unexpressed (otherwise hvg might break)

- hvg_res.append(sc.pp.highly_variable_genes(i, flavor='cell_ranger', n_top_genes=n_hvg, inplace=False))
+ hvg_res.append(
+ sc.pp.highly_variable_genes(
+ i,
+ flavor=flavor,
+ n_top_genes=n_hvg,
+ n_bins=n_bins,
+ inplace=False
+ )
+ )

while not enough:
genes = []
@@ -334,13 +428,20 @@ def hvg_intersect(adata, batch, target_genes=2000,
flavor='cell_ranger', n_bins=

def hvg_batch(adata, batch_key=None, target_genes=2000, flavor='cell_ranger', n_bins=20, adataOut=False):
- """
+ """Batch-aware highly variable gene selection

- Method to select HVGs based on mean dispersions of genes that are highly
+ Method to select HVGs based on mean dispersions of genes that are highly
variable genes in all batches. Uses the top target_genes per batch by average normalised dispersion. If the target number of genes has not been reached, then HVGs in all but one batch are used to fill up. This is continued until HVGs in a single batch are considered.
+
+ :param adata: ``anndata`` object
+ :param batch_key: ``adata.obs`` column
+ :param target_genes: maximum number of genes (intersection reduces the number of genes)
+ :param flavor: parameter for ``scanpy.pp.highly_variable_genes``
+ :param n_bins: parameter for ``scanpy.pp.highly_variable_genes``
+ :param adataOut: whether to return an ``anndata`` object or a list of highly variable genes
"""
utils.check_adata(adata)
@@ -352,14 +453,17 @@ def hvg_batch(adata, batch_key=None, target_genes=2000, flavor='cell_ranger', n_
n_batches = len(adata_hvg.obs[batch_key].cat.categories)

# Calculate double target genes per dataset
- sc.pp.highly_variable_genes(adata_hvg,
- flavor=flavor,
- n_top_genes=target_genes,
- n_bins=n_bins,
- batch_key=batch_key)
-
- nbatch1_dispersions = adata_hvg.var['dispersions_norm'][adata_hvg.var.highly_variable_nbatches >
- len(adata_hvg.obs[batch_key].cat.categories) - 1]
+ sc.pp.highly_variable_genes(
+ adata_hvg,
+ flavor=flavor,
+ n_top_genes=target_genes,
+ n_bins=n_bins,
+ batch_key=batch_key
+ )
+
+ nbatch1_dispersions = adata_hvg.var['dispersions_norm'][
+ adata_hvg.var.highly_variable_nbatches > len(adata_hvg.obs[batch_key].cat.categories) - 1
+ ]

nbatch1_dispersions.sort_values(ascending=False, inplace=True)

@@ -399,17 +503,40 @@ def hvg_batch(adata, batch_key=None, target_genes=2000, flavor='cell_ranger', n_

### Feature Reduction

-def
reduce_data(adata, batch_key=None, subset=False, - filter=True, flavor='cell_ranger', n_top_genes=2000, n_bins=20, - pca=True, pca_comps=50, overwrite_hvg=True, - neighbors=True, use_rep='X_pca', - umap=True): - """ - overwrite_hvg: - if True, ignores any pre-existing 'highly_variable' column in adata.var +def reduce_data( + adata, + batch_key=None, + flavor='cell_ranger', + n_top_genes=2000, + n_bins=20, + pca=True, + pca_comps=50, + overwrite_hvg=True, + neighbors=True, + use_rep='X_pca', + umap=True +): + """Apply feature selection and dimensionality reduction steps. + + Wrapper function of feature selection (highly variable genes), PCA, neighbours computation and dimensionality + reduction. + Highly variable gene selection is batch-aware, when a batch key is given. + + + :param adata: ``anndata`` object with normalised and log-transformed data in ``adata.X`` + :param batch_key: column in ``adata.obs`` containing batch assignment + :param flavor: parameter for ``scanpy.pp.highly_variable_genes`` + :param n_top_genes: parameter for ``scanpy.pp.highly_variable_genes`` + :param n_bins: parameter for ``scanpy.pp.highly_variable_genes`` + :param pca: whether to compute PCA + :param pca_comps: number of principal components + :param overwrite_hvg: if True, ignores any pre-existing 'highly_variable' column in adata.var and recomputes it if `n_top_genes` is specified else calls PCA on full features. 
if False, skips HVG computation even when ``n_top_genes`` is specified and uses the pre-existing HVG column for PCA + :param neighbors: whether to compute the neighbourhood graph + :param use_rep: embedding to use for the neighbourhood graph + :param umap: whether to compute the UMAP representation """ utils.check_adata(adata) @@ -431,10 +558,12 @@ def reduce_data(adata, batch_key=None, subset=False, else: print(f"Calculating {n_top_genes} HVGs for reduce_data.") - sc.pp.highly_variable_genes(adata, - n_top_genes=n_top_genes, - n_bins=n_bins, - flavor=flavor) + sc.pp.highly_variable_genes( + adata, + n_top_genes=n_top_genes, + n_bins=n_bins, + flavor=flavor + ) n_hvg = np.sum(adata.var["highly_variable"]) print(f'Computed {n_hvg} highly variable genes') @@ -459,12 +588,21 @@ def reduce_data(adata, batch_key=None, subset=False, ### Cell Cycle def score_cell_cycle(adata, organism='mouse'): - """ + """Score cell cycle phases for a given organism + + Wrapper function for `scanpy.tl.score_genes_cell_cycle`_ + + .. _scanpy.tl.score_genes_cell_cycle: https://scanpy.readthedocs.io/en/stable/generated/scanpy.tl.score_genes_cell_cycle.html + Tirosh et al. cell cycle marker genes downloaded from https://raw.githubusercontent.com/theislab/scanpy_usage/master/180209_cell_cycle/data/regev_lab_cell_cycle_genes.txt - return: (s_genes, g2m_genes) - s_genes: S-phase genes - g2m_genes: G2- and M-phase genes + + For human, the mouse gene symbols are capitalised and used directly, under the assumption that cell cycle genes are + well conserved across species. 
+ + :param adata: ``anndata`` object containing the gene expression data + :param organism: organism of the gene names, used to match the cell cycle genes + :return: tuple ``(s_genes, g2m_genes)`` of S-phase genes and G2- and M-phase genes """ import pathlib root = pathlib.Path(__file__).parent @@ -487,7 +625,19 @@ def score_cell_cycle(adata, organism='mouse'): sc.tl.score_genes_cell_cycle(adata, s_genes, g2m_genes) -def saveSeurat(adata, path, batch, hvgs=None): +def save_seurat(adata, path, batch, hvgs=None): + """Save an ``anndata`` object to file as a Seurat object + + Convert an ``anndata`` object to a Seurat object through ``rpy2`` and save it to file as an RDS file. + + This function only accounts for batch assignments and highly variable genes. No other keys are transferred to + the Seurat object. + + :param adata: ``anndata`` object to be saved + :param path: file path where the object should be saved + :param batch: key in ``adata.obs`` that holds the batch assignments + :param hvgs: list of highly variable genes + """ import re try: @@ -526,6 +676,12 @@ def saveSeurat(adata, path, batch, hvgs=None): def read_seurat(path): + """Read a ``Seurat`` object from file and convert it to an ``anndata`` object + + Uses ``rpy2`` to read an RDS object and convert it into an ``anndata`` object. + + :param path: file path to the saved RDS file + """ try: ro.r('library(Seurat)') @@ -551,10 +707,14 @@ def read_seurat(path): def read_conos(inPath, dir_path=None): - from os import mkdir, path - from shutil import rmtree - from time import time + """Read a ``conos`` object + Uses ``rpy2`` to read an RDS object and convert it into an ``anndata`` object. 
+ + :param inPath: file path to the saved ``conos`` object (RDS file) + :param dir_path: directory holding intermediate files for the conversion + """ + from shutil import rmtree import pandas as pd from scipy.io import mmread diff --git a/setup.cfg b/setup.cfg index 6c5224d3..cac75d71 100644 --- a/setup.cfg +++ b/setup.cfg @@ -62,6 +62,7 @@ install_requires = pydot python-igraph llvmlite + deprecated zip_safe = False [options.package_data] @@ -72,6 +73,7 @@ scib = [options.extras_require] test = pytest; pytest-runner; pytest-icdiff dev = build; twine; isort; bump2version +docs = sphinx; sphinx_rtd_theme; myst_parser; sphinx-automodapi bbknn = bbknn ==1.3.9 scanorama = scanorama ==1.7.0 mnn = mnnpy ==0.1.9.5 diff --git a/tests/metrics/test_pcr_metrics.py b/tests/metrics/test_pcr_metrics.py index 13918032..85a5cea0 100644 --- a/tests/metrics/test_pcr_metrics.py +++ b/tests/metrics/test_pcr_metrics.py @@ -2,7 +2,7 @@ def test_pc_regression(adata): - scib.me.pcr.pc_regression(adata.X, adata.obs["batch"]) + scib.me.pc_regression(adata.X, adata.obs["batch"]) def test_pcr_batch(adata):
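The fallback strategy described in the `hvg_batch` docstring above — prefer genes that are highly variable in all batches, then relax to all but one batch, and so on, breaking ties by normalised dispersion, until the target count is reached — can be sketched in plain Python. This is a simplified illustration independent of `scanpy`; the function and variable names (`select_hvgs`, `n_batches_per_gene`) are hypothetical and not part of the `scib` API:

```python
def select_hvgs(n_batches_per_gene, dispersion, n_batches, target_genes):
    """Pick up to target_genes genes, preferring genes that are highly
    variable in the most batches; ties broken by normalised dispersion."""
    selected = []
    # Start with genes that are HVGs in all batches, then relax the
    # requirement one batch at a time until the target is reached.
    for required in range(n_batches, 0, -1):
        candidates = [
            g for g, n in n_batches_per_gene.items()
            if n == required and g not in selected
        ]
        # Within each tier, prefer genes with higher normalised dispersion.
        candidates.sort(key=lambda g: dispersion[g], reverse=True)
        selected.extend(candidates[:target_genes - len(selected)])
        if len(selected) >= target_genes:
            break
    return selected

# Toy example: 3 batches, target of 3 genes.
n_hvg_batches = {'geneA': 3, 'geneB': 3, 'geneC': 2, 'geneD': 1}
disp = {'geneA': 1.2, 'geneB': 2.0, 'geneC': 0.9, 'geneD': 3.0}
print(select_hvgs(n_hvg_batches, disp, n_batches=3, target_genes=3))
# → ['geneB', 'geneA', 'geneC']
```

Note that `geneD`, despite its high dispersion, is ranked last because it is highly variable in only one batch — the tier (number of batches) always dominates the dispersion tie-break, which is what makes the selection batch-aware.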