Skip to content

Commit

Permalink
Merge pull request #474 from ivirshup/downsample_experiment
Browse files Browse the repository at this point in the history
Downsample total counts
  • Loading branch information
falexwolf committed Mar 21, 2019
2 parents 4a244c7 + 34096ce commit a9b0175
Show file tree
Hide file tree
Showing 3 changed files with 137 additions and 35 deletions.
105 changes: 75 additions & 30 deletions scanpy/preprocessing/_simple.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

from .. import settings as sett
from .. import logging as logg
from ..utils import sanitize_anndata
from ..utils import sanitize_anndata, deprecated_arg_names
from ._distributed import materialize_as_ndarray
from ._utils import _get_mean_var

Expand Down Expand Up @@ -910,25 +910,37 @@ def subsample(data, fraction=None, n_obs=None, random_state=0, copy=False):
return X[obs_indices], obs_indices


def downsample_counts(adata, target_counts=20000, random_state=0,
replace=True, copy=False):
"""Downsample counts so that each cell has no more than `target_counts`.
@deprecated_arg_names({"target_counts": "counts_per_cell"})
def downsample_counts(
adata: AnnData,
counts_per_cell: Optional[int] = None,
total_counts: Optional[int] = None,
random_state: Optional[int] = 0,
replace: bool = False,
copy: bool = False,
) -> Optional[AnnData]:
"""
Downsample counts from count matrix.
Cells with fewer counts than `target_counts` are unaffected by this. This
has been implemented by M. D. Luecken.
If `counts_per_cell` is specified, each cell will be downsampled to at most
that many counts. If `total_counts` is specified, the expression matrix will
be downsampled to contain at most `total_counts`.
Parameters
----------
adata : :class:`~anndata.AnnData`
adata
Annotated data matrix.
target_counts : `int` (default: 20,000)
Target number of counts for downsampling. Cells with more counts than
'target_counts' will be downsampled to have 'target_counts' counts.
random_state : `int` or `None`, optional (default: 0)
Random seed to change subsampling.
replace : `bool`, optional (default: `True`)
counts_per_cell
Target total counts per cell. If a cell has more than 'counts_per_cell',
it will be downsampled to this number.
total_counts
Target total counts. If the count matrix has more than `total_counts`
it will be downsampled to have this number.
random_state
Random seed for subsampling.
replace
Whether to sample the counts with replacement.
copy : `bool`, optional (default: `False`)
copy
If an :class:`~anndata.AnnData` is passed, determines whether a copy
is returned.
Expand All @@ -937,34 +949,67 @@ def downsample_counts(adata, target_counts=20000, random_state=0,
AnnData, None
Depending on `copy` returns or updates an `adata` with downsampled `.X`.
"""
if type(total_counts) == type(counts_per_cell):
raise ValueError("Must specify exactly one of `total_counts` or `counts_per_cell`.")
if copy:
adata = adata.copy()
adata.X = adata.X.astype(np.integer) # Numba doesn't want floats
if issparse(adata.X):
X = adata.X
if total_counts:
adata.X = _downsample_total_counts(adata.X, total_counts, random_state, replace)
elif counts_per_cell:
adata.X = _downsample_per_cell(adata.X, counts_per_cell, random_state, replace)
if copy:
return adata


def _downsample_per_cell(X, counts_per_cell, random_state, replace):
if issparse(X):
original_type = type(X)
if not isspmatrix_csr(X):
X = csr_matrix(X)
totals = np.ravel(X.sum(axis=1))
under_target = np.nonzero(totals > target_counts)[0]
under_target = np.nonzero(totals > counts_per_cell)[0]
cols = np.split(X.data.view(), X.indptr[1:-1])
for colidx in under_target:
col = cols[colidx]
downsample_cell(col, target_counts, random_state=random_state,
replace=replace, inplace=True)
if not isspmatrix_csr(adata.X): # Put it back
adata.X = type(adata.X)(X)
_downsample_array(col, counts_per_cell, random_state=random_state,
replace=replace, inplace=True)
X.eliminate_zeros()
if original_type is not csr_matrix: # Put it back
X = original_type(X)
else:
totals = np.ravel(X.sum(axis=1))
under_target = np.nonzero(totals > counts_per_cell)[0]
X[under_target, :] = \
np.apply_along_axis(_downsample_array, 1, X[under_target, :],
counts_per_cell, random_state=random_state,
replace=replace)
return X


def _downsample_total_counts(X, total_counts, random_state, replace):
total = X.sum()
if total < total_counts:
return X
if issparse(X):
original_type = type(X)
if not isspmatrix_csr(X):
X = csr_matrix(X)
_downsample_array(X.data, total_counts, random_state=random_state,
replace=replace, inplace=True)
X.eliminate_zeros()
if original_type is not csr_matrix:
X = original_type(X)
else:
totals = np.ravel(adata.X.sum(axis=1))
under_target = np.nonzero(totals > target_counts)[0]
adata.X[under_target, :] = \
np.apply_along_axis(downsample_cell, 1, adata.X[under_target, :],
target_counts, random_state=random_state, replace=replace)
if copy: return adata
v = X.view().reshape(np.multiply(*X.shape))
_downsample_array(v, total_counts, random_state, replace=replace,
inplace=True)
return X


@numba.njit
def downsample_cell(col: np.array, target: int, random_state: int=0,
replace: bool=True, inplace: bool=False):
@numba.njit(cache=True)
def _downsample_array(col: np.array, target: int, random_state: int=0,
replace: bool = True, inplace: bool=False):
"""
Evenly reduce counts in cell to target amount.
Expand Down
28 changes: 26 additions & 2 deletions scanpy/tests/test_preprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ def test_regress_out_categorical():
multi = sc.pp.regress_out(adata, keys='batch', n_jobs=8, copy=True)
assert adata.X.shape == multi.X.shape

def test_downsample_counts():
def test_downsample_counts_per_cell():
TARGET = 1000
X = np.random.randint(0, 100, (1000, 100)) * \
np.random.binomial(1, .3, (1000, 100))
Expand All @@ -97,7 +97,7 @@ def test_downsample_counts():
adata_csc = AnnData(X=sp.csc_matrix(X))
for adata, replace in product((adata_dense, adata_csr, adata_csc), (True, False)):
initial_totals = np.ravel(adata.X.sum(axis=1))
adata = sc.pp.downsample_counts(adata, target_counts=TARGET, replace=replace, copy=True)
adata = sc.pp.downsample_counts(adata, counts_per_cell=TARGET, replace=replace, copy=True)
new_totals = np.ravel(adata.X.sum(axis=1))
if sp.issparse(adata.X):
assert all(adata.X.toarray()[X == 0] == 0)
Expand All @@ -109,3 +109,27 @@ def test_downsample_counts():
== new_totals[initial_totals <= TARGET])
if not replace:
assert np.all(X >= adata.X)

def test_downsample_total_counts():
    """Downsampling with `total_counts` must hit the requested matrix total."""
    # Random counts where roughly 70% of the entries are zeroed out.
    X = np.random.randint(0, 100, (1000, 100)) * \
        np.random.binomial(1, .3, (1000, 100))
    total = X.sum()
    target = np.floor_divide(total, 10)
    adata_dense = AnnData(X=X.copy())
    adata_csr = AnnData(X=sp.csr_matrix(X))
    inputs = (adata_dense, adata_csr)

    for source, replace in product(inputs, (True, False)):
        initial_totals = np.ravel(source.X.sum(axis=1))
        result = sc.pp.downsample_counts(
            source, total_counts=target, replace=replace, copy=True
        )
        new_totals = np.ravel(result.X.sum(axis=1))
        # Entries that were zero before must still be zero afterwards.
        values = result.X.toarray() if sp.issparse(result.X) else result.X
        assert all(values[X == 0] == 0)
        assert result.X.sum() == target
        assert all(initial_totals >= new_totals)
        if not replace:
            # Without replacement no entry can grow.
            assert np.all(X >= result.X)

    # When the specified total is greater than the current total,
    # the matrix must come back unchanged.
    for source in inputs:
        result = sc.pp.downsample_counts(
            source, total_counts=total + 10, replace=False, copy=True
        )
        assert (result.X == X).all()

39 changes: 36 additions & 3 deletions scanpy/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import inspect
from weakref import WeakSet
from collections import namedtuple
from functools import partial
from functools import partial, wraps
from types import ModuleType
from typing import Union, Callable, Optional

Expand All @@ -16,12 +16,12 @@
from pandas.api.types import CategoricalDtype

from . import settings, logging as logg
import warnings

EPS = 1e-15


def check_versions():
import warnings
from distutils.version import LooseVersion

if sys.version_info < (3, 0):
Expand Down Expand Up @@ -61,6 +61,40 @@ def type_doc(name: str):
)


def deprecated_arg_names(arg_mapping):
    """
    Decorator which marks a function's keyword arguments as deprecated.

    When the decorated function is called with a deprecated keyword argument,
    a ``DeprecationWarning`` is emitted and the value is forwarded to the
    function under the current argument name.

    Parameters
    ----------
    arg_mapping : dict[str, str]
        Mapping from deprecated argument name to current argument name.

    Returns
    -------
    A decorator that wraps a function with the renaming behaviour described
    above. The wrapper raises ``TypeError`` at call time if both a deprecated
    keyword and its replacement are supplied, because the intended value
    would be ambiguous.
    """
    def decorator(func):
        @wraps(func)
        def func_wrapper(*args, **kwargs):
            for old, new in arg_mapping.items():
                if old not in kwargs:
                    continue
                if new in kwargs:
                    # Silently dropping one of the two values would hide a
                    # caller bug, so refuse the call instead.
                    raise TypeError(
                        "Got both deprecated keyword argument '{0}' and its "
                        "replacement '{1}'; pass only '{1}'."
                        .format(old, new)
                    )
                # Force the warning to be shown, but only inside this context
                # so the caller's global warning filters are left untouched.
                with warnings.catch_warnings():
                    warnings.simplefilter('always', DeprecationWarning)
                    warnings.warn(
                        "Keyword argument '{0}' has been deprecated in favour "
                        "of '{1}'. '{0}' will be removed in a future version."
                        .format(old, new),
                        category=DeprecationWarning,
                        stacklevel=2,
                    )
                kwargs[new] = kwargs.pop(old)
            return func(*args, **kwargs)
        return func_wrapper
    return decorator


def descend_classes_and_funcs(mod: ModuleType, root: str, encountered=None):
if encountered is None:
encountered = WeakSet()
Expand Down Expand Up @@ -774,7 +808,6 @@ def warn_with_traceback(message, category, filename, lineno, file=None, line=Non
--------
http://stackoverflow.com/questions/22373927/get-traceback-of-warnings
"""
import warnings
import traceback
traceback.print_stack()
log = file if hasattr(file, 'write') else sys.stderr
Expand Down

0 comments on commit a9b0175

Please sign in to comment.