Skip to content

Commit

Permalink
Change outlier detection for MRVI to ball admissibility calculation (#…
Browse files Browse the repository at this point in the history
…3007)

The outlier detection was mistakenly set to the "ap" (aggregated
posterior) thresholding setting when it should have been the "ball"
thresholding calculation.

The only changes in this PR are in `get_outlier_cell_sample_pairs`;
everything else is auto-linted.

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
justjhong and pre-commit-ci[bot] authored Oct 2, 2024
1 parent b8bc970 commit b5d81d4
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 12 deletions.
14 changes: 14 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,20 @@ to [Semantic Versioning]. Full commit history is available in the

## Version 1.2

### 1.2.1 (2024-XX-XX)

#### Added

#### Fixed

- Breaking Change: Fix `get_outlier_cell_sample_pairs` function in {class}`scvi.external.MRVI`
to correctly compute the maximum log-density across in-sample cells rather than the
aggregated posterior log-density {pr}`3007`.

#### Changed

#### Removed

### 1.2.0 (2024-09-26)

#### Added
Expand Down
42 changes: 30 additions & 12 deletions src/scvi/external/mrvi/_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from scvi.data import AnnDataManager, fields
from scvi.external.mrvi._module import MRVAE
from scvi.external.mrvi._types import MRVIReduction
from scvi.external.mrvi._utils import rowwise_max_excluding_diagonal
from scvi.model.base import BaseModelClass, JaxTrainingMixin
from scvi.utils import setup_anndata_dsp
from scvi.utils._docstrings import devices_dsp
Expand Down Expand Up @@ -745,7 +746,10 @@ def get_aggregated_posterior(
indices: npt.ArrayLike | None = None,
batch_size: int = 256,
) -> Distribution:
"""Compute the aggregated posterior over the ``u`` latent representations.
"""Computes the aggregated posterior over the ``u`` latent representations.
For the specified samples, it computes the aggregated posterior over the ``u`` latent
representations. Returns a NumPyro MixtureSameFamily distribution.
Parameters
----------
Expand Down Expand Up @@ -959,12 +963,13 @@ def get_outlier_cell_sample_pairs(
admissibility_threshold: float = 0.0,
batch_size: int = 256,
) -> xr.Dataset:
"""Compute outlier cell-sample pairs.
"""Compute admissibility scores for cell-sample pairs.
This function fits a GMM for each sample based on the latent representation of the cells in
the sample or computes an approximate aggregated posterior for each sample. Then, for every
cell, it computes the log-probability of the cell under the approximated posterior of each
sample as a measure of admissibility.
This function computes the posterior distribution for u for each cell. Then, for every
cell, it computes the log-probability of the cell under the posterior of each cell in
each sample and takes the maximum value for a given sample as a measure of admissibility
for that sample. Additionally, it computes a threshold that determines if
a cell-sample pair is admissible based on the within-sample admissibility scores.
Parameters
----------
Expand Down Expand Up @@ -995,21 +1000,34 @@ def get_outlier_cell_sample_pairs(
adata_s = adata[sample_idxs]

ap = self.get_aggregated_posterior(adata=adata, indices=sample_idxs)
log_probs_s = jnp.quantile(
ap.log_prob(adata_s.obsm["U"]).sum(axis=1), q=quantile_threshold
)
n_splits = adata.n_obs // batch_size
in_max_comp_log_probs = ap.component_distribution.log_prob(
np.expand_dims(adata_s.obsm["U"], ap.mixture_dim)
).sum(axis=1)
log_probs_s = rowwise_max_excluding_diagonal(in_max_comp_log_probs)

log_probs_ = []
n_splits = adata.n_obs // batch_size
for u_rep in np.array_split(adata.obsm["U"], n_splits):
log_probs_.append(jax.device_get(ap.log_prob(u_rep).sum(-1, keepdims=True)))
log_probs_.append(
jax.device_get(
ap.component_distribution.log_prob(
np.expand_dims(u_rep, ap.mixture_dim)
) # (n_cells_batch, n_cells_ap, n_latent_dim)
.sum(axis=1) # (n_cells_batch, n_latent_dim)
.max(axis=1, keepdims=True) # (n_cells_batch, 1)
)
)

log_probs_ = np.concatenate(log_probs_, axis=0) # (n_cells, 1)

threshs.append(np.array(log_probs_s))
log_probs.append(np.array(log_probs_))

threshs_all = np.concatenate(threshs)
global_thresh = np.quantile(threshs_all, q=quantile_threshold)
threshs = np.array(len(log_probs) * [global_thresh])

log_probs = np.concatenate(log_probs, 1)
threshs = np.array(threshs)
log_ratios = log_probs - threshs

coords = {
Expand Down

0 comments on commit b5d81d4

Please sign in to comment.