diff --git a/scvi/utils/_differential.py b/scvi/utils/_differential.py index 5442bd8498..f8870a5cdf 100644 --- a/scvi/utils/_differential.py +++ b/scvi/utils/_differential.py @@ -6,8 +6,12 @@ import numpy as np import pandas as pd import torch +from scipy.sparse import issparse +from sklearn.mixture import GaussianMixture +from scvi import _CONSTANTS from scvi._compat import Literal +from scvi.data import get_from_registry logger = logging.getLogger(__name__) @@ -48,6 +52,7 @@ def get_bayes_factors( change_fn: Optional[Union[str, Callable]] = None, m1_domain_fn: Optional[Callable] = None, delta: Optional[float] = 0.5, + pseudocounts: Union[float, None] = 0.0, cred_interval_lvls: Optional[Union[List[float], np.ndarray]] = None, ) -> Dict[str, np.ndarray]: r""" @@ -150,7 +155,11 @@ def get_bayes_factors( delta specific case of region inducing differential expression. In this case, we suppose that :math:`R \setminus [-\delta, \delta]` does not induce differential expression - (LFC case) + (LFC case). If the provided value is `None`, then a proper threshold is determined + from the distribution of LFCs across genes. + pseudocounts + pseudocount offset used for the mode `change`. + When None, observations from non-expressed genes are used to estimate its value. cred_interval_lvls List of credible interval levels to compute for the posterior LFC distribution @@ -164,8 +173,7 @@ def get_bayes_factors( # warnings.warn( # "Differential expression requires a Posterior object created with all indices."
# ) - - eps = 1e-8 # used for numerical stability + eps = 1e-8 # Normalized means sampling for both populations scales_batches_1 = self.scale_sampler( selection=idx1, @@ -236,6 +244,19 @@ def get_bayes_factors( m_permutation=m_permutation, ) + # Adding pseudocounts to the scales + if pseudocounts is None: + logger.debug("Estimating pseudocounts offset from the data") + x = get_from_registry(self.adata, _CONSTANTS.X_KEY) + where_zero_a = densify(np.max(x[idx1], 0)) == 0 + where_zero_b = densify(np.max(x[idx2], 0)) == 0 + pseudocounts = estimate_pseudocounts_offset( + scales_a=scales_1, + scales_b=scales_2, + where_zero_a=where_zero_a, + where_zero_b=where_zero_b, + ) + logger.debug("Using pseudocounts ~ {}".format(pseudocounts)) # Core of function: hypotheses testing based on the posterior samples we obtained above if mode == "vanilla": logger.debug("Differential expression using vanilla mode") @@ -254,7 +275,7 @@ def get_bayes_factors( # step 1: Construct the change function def lfc(x, y): - return np.log2(x) - np.log2(y) + return np.log2(x + pseudocounts) - np.log2(y + pseudocounts) if change_fn == "log-fold" or change_fn is None: change_fn = lfc @@ -263,10 +284,15 @@ def lfc(x, y): # step2: Construct the DE area function if m1_domain_fn is None: - delta = delta if delta is not None else 0.5 def m1_domain_fn(samples): - return np.abs(samples) >= delta + delta_ = ( + delta + if delta is not None + else estimate_delta(lfc_means=samples.mean(0)) + ) + logger.debug("Using delta ~ {:.2f}".format(delta_)) + return np.abs(samples) >= delta_ change_fn_specs = inspect.getfullargspec(change_fn) domain_fn_specs = inspect.getfullargspec(m1_domain_fn) @@ -277,6 +303,11 @@ def m1_domain_fn(samples): try: change_distribution = change_fn(scales_1, scales_2) is_de = m1_domain_fn(change_distribution) + delta_ = ( + estimate_delta(lfc_means=change_distribution.mean(0)) + if delta is None + else delta + ) except TypeError: raise TypeError( "change_fn or m1_domain_fn have has wrong
properties." @@ -298,6 +329,8 @@ def m1_domain_fn(samples): bayes_factor=np.log(proba_m1 + eps) - np.log(1.0 - proba_m1 + eps), scale1=px_scale_mean1, scale2=px_scale_mean2, + pseudocounts=pseudocounts, + delta=delta_, **change_distribution_props, ) else: @@ -403,6 +436,78 @@ def scale_sampler( return dict(scale=px_scales, batch=batch_ids) +def estimate_delta(lfc_means: List[np.ndarray], coef=0.6, min_thres=0.3): + """ + Computes a threshold LFC value based on means of LFCs. + + Parameters + ---------- + lfc_means + LFC means for each gene, should be 1d. + coef + Tunable hyperparameter to choose the threshold based on estimated modes, defaults to 0.6 + min_thres + Minimum returned threshold value, defaults to 0.3 + """ + logger.debug("Estimating delta from effect size samples") + if lfc_means.ndim >= 2: + raise ValueError("lfc_means should be 1-dimensional of shape: (n_genes,).") + gmm = GaussianMixture(n_components=3) + gmm.fit(lfc_means[:, None]) + vals = np.sort(gmm.means_.squeeze()) + res = coef * np.abs(vals[[0, -1]]).mean() + res = np.maximum(min_thres, res) + return res + + +def estimate_pseudocounts_offset( + scales_a: List[np.ndarray], + scales_b: List[np.ndarray], + where_zero_a: List[np.ndarray], + where_zero_b: List[np.ndarray], + percentile: Optional[float] = 0.9, +): + """ + Determines pseudocount offset. + + This shrinks LFCs associated with non-expressed genes to zero. + + Parameters + ---------- + scales_a + Scales in first population + scales_b + Scales in second population + where_zero_a + mask where no observed counts + where_zero_b + mask where no observed counts + """ + max_scales_a = np.max(scales_a, 0) + max_scales_b = np.max(scales_b, 0) + asserts = ( + (max_scales_a.shape == where_zero_a.shape) + and (max_scales_b.shape == where_zero_b.shape) + ) and (where_zero_a.shape == where_zero_b.shape) + if not asserts: + raise ValueError( + "Dimension mismatch between scales and/or masks to compute the pseudocounts offset."
) + if where_zero_a.sum() >= 1: + artefact_scales_a = max_scales_a[where_zero_a] + eps_a = np.quantile(artefact_scales_a, q=percentile) + else: + eps_a = 1e-10 + + if where_zero_b.sum() >= 1: + artefact_scales_b = max_scales_b[where_zero_b] + eps_b = np.quantile(artefact_scales_b, q=percentile) + else: + eps_b = 1e-10 + res = np.maximum(eps_a, eps_b) + return res + + def pairs_sampler( arr1: Union[List[float], np.ndarray, torch.Tensor], arr2: Union[List[float], np.ndarray, torch.Tensor], @@ -577,3 +682,9 @@ def save_cluster_xlsx( for i, x in enumerate(cluster_names): de_results[i].to_excel(writer, sheet_name=str(x)) writer.close() + + +def densify(arr): + if issparse(arr): + return np.asarray(arr.todense()).squeeze() + return arr diff --git a/tests/core/test_differential.py b/tests/core/test_differential.py index 1df8027758..7f4a52c015 100644 --- a/tests/core/test_differential.py +++ b/tests/core/test_differential.py @@ -7,6 +7,43 @@ from scvi.model import SCVI from scvi.model.base._utils import _prepare_obs from scvi.utils import DifferentialComputation +from scvi.utils._differential import estimate_delta, estimate_pseudocounts_offset + + +def test_features(): + a = np.random.randn( + 100, + ) + b = 3 + np.random.randn( + 100, + ) + c = -3 + np.random.randn( + 100, + ) + alls = np.concatenate([a, b, c]) + delta = estimate_delta(alls) + expected_range = (delta >= 0.4 * 3) and (delta <= 6) + if not expected_range: + raise ValueError("The effect-size threshold was not properly estimated.") + + scales_a = np.random.rand(100, 50) + where_zero_a = np.zeros(50, dtype=bool) + where_zero_a[:10] = True + scales_a[:, :10] = 1e-6 + + scales_b = np.random.rand(100, 50) + where_zero_b = np.zeros(50, dtype=bool) + where_zero_b[-10:] = True + scales_b[:, -10:] = 1e-7 + offset = estimate_pseudocounts_offset( + scales_a=scales_a, + scales_b=scales_b, + where_zero_a=where_zero_a, + where_zero_b=where_zero_b, + ) + expected_off_range = offset <= 1e-6 + if not
expected_off_range: + raise ValueError("The pseudocount offset was not properly estimated.") def test_differential_computation(save_path): @@ -23,7 +60,21 @@ def test_differential_computation(save_path): cell_idx2 = ~cell_idx1 dc.get_bayes_factors(cell_idx1, cell_idx2, mode="vanilla", use_permutation=True) - dc.get_bayes_factors(cell_idx1, cell_idx2, mode="change", use_permutation=False) + res = dc.get_bayes_factors( + cell_idx1, cell_idx2, mode="change", use_permutation=False + ) + assert (res["delta"] == 0.5) and (res["pseudocounts"] == 0.0) + res = dc.get_bayes_factors( + cell_idx1, cell_idx2, mode="change", use_permutation=False, delta=None + ) + dc.get_bayes_factors( + cell_idx1, + cell_idx2, + mode="change", + use_permutation=False, + delta=None, + pseudocounts=None, + ) dc.get_bayes_factors(cell_idx1, cell_idx2, mode="change", cred_interval_lvls=[0.75]) delta = 0.5