Updated destVI #1457

Merged
merged 45 commits into from
Apr 6, 2022
Changes from 37 commits
Commits
d3334f3
Fix destVI cell type amortization, Vamp prior from overclustering, pa…
canergen Mar 18, 2022
9e5acd9
Registry wrong in condSCVI in last commit
canergen Mar 18, 2022
e1bab17
REGISTRY key in condSCVI
canergen Mar 18, 2022
4d47171
Updated training parameters of destVI
canergen Mar 18, 2022
fe93c0a
Learned rate adapted for destVI
canergen Mar 19, 2022
f3e4bcb
Merge branch 'scverse:master' into can_destvi
canergen Mar 22, 2022
b9dfef1
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 22, 2022
618a136
Update import of scanty
canergen Mar 22, 2022
3c8aca4
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 22, 2022
f4ea841
Added leidenalg to test dependencies
canergen Mar 22, 2022
5af2ec2
Revert "Added leidenalg to test dependencies"
canergen Mar 22, 2022
6b6c957
Changed scanpy import to give error message
canergen Mar 22, 2022
5d7125b
Added leiden to scanpy requirements
canergen Mar 22, 2022
a78be90
Defaults to no l1_sparsity.
canergen Mar 23, 2022
2bafa45
Style changes to condSCVI.
canergen Mar 23, 2022
15d6bec
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 23, 2022
4e0b4c7
Space in empty line removed
canergen Mar 23, 2022
22da8a5
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 23, 2022
2f098fe
Changed vamp overclustering to kmeans, minor comments, added test cases.
canergen Mar 28, 2022
16eefab
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 28, 2022
b523918
Merge branch 'scverse:master' into can_destvi
canergen Mar 28, 2022
e3a5f66
Error in test function. Added check for custom clustering
canergen Mar 28, 2022
43ab4e5
Merge branch 'can_destvi' of https://github.com/cane11/scvi-tools int…
canergen Mar 28, 2022
4000f21
Merge conflict
canergen Mar 28, 2022
2bd2a7f
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 28, 2022
5f98667
Edited dataset after setup of CondSCVI
canergen Mar 28, 2022
af8a341
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 28, 2022
1e18097
Error in test function.
canergen Mar 28, 2022
7e18f34
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 28, 2022
9d0fc4f
Update test function
canergen Mar 28, 2022
e9637ff
Ran precommit
canergen Mar 28, 2022
fffd0b4
Merge pull request #1 from cane11/destvi_test
canergen Mar 28, 2022
8a0d587
Changed eta_prior_weight, exposed beta_prior_weight. Small adjustments.
canergen Mar 30, 2022
b7b11b3
Merge pull request #2 from cane11/destvi_test
canergen Mar 30, 2022
6969321
Doc changes, weight_vprior->mp_vprior
canergen Apr 4, 2022
0e9f24b
Updated documentation. Style changes. Renamed weighting parameters.
canergen Apr 4, 2022
b23ffe8
L1_reg added to test function
canergen Apr 4, 2022
9848765
Updated logsumexp. Small fixes
canergen Apr 4, 2022
2c53029
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 4, 2022
ced6122
Precommit edit.
canergen Apr 4, 2022
d948388
Merge branch 'can_destvi' of https://github.com/cane11/scvi-tools int…
canergen Apr 4, 2022
605ea78
np.array -> torch.cat
canergen Apr 4, 2022
23c0dd6
Merge branch 'can_destvi' of https://github.com/cane11/scvi-tools int…
canergen Apr 4, 2022
624325a
Updated references. Updated dropout_decoder
canergen Apr 6, 2022
4176ab3
Merge branch 'master' into can_destvi
adamgayoso Apr 6, 2022
28 changes: 20 additions & 8 deletions docs/user_guide/models/destvi.md
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ expression matrix of UMI counts $X$ with $N$ cells and $G$ genes, along with
a vector of cell type labels $\vec{c}$. Subsequently, the stLVM takes in the trained scLVM,
along a spatial gene expression matrix $Y$ with $S$ spots and $G$ genes.
Optionally, the user can specify the number of components used for the mixture model underlying the
emprical prior.
empirical prior.

## Generative process

Expand Down Expand Up @@ -95,7 +95,7 @@ and gene $g$, the observation is generated as a function of the latent variables
:nowrap: true

\begin{align}
\gamma_x^c &\sim \frac{1}{K} \sum_{k=1}^K q_\Phi(\gamma^c \mid u_{kc}, c) \tag{4} \\
\gamma_x^c &\sim \sum_{k=1}^K m_{kc} q_\Phi(\gamma^c \mid u_{kc}, c) \tag{4} \\
x_{sg} &\sim \mathrm{NegativeBinomial}(l_s\alpha_g\sum_{c=1}^{C}\beta_{sc}f^g(c, \gamma_s^c), p_g) \tag{5} \\
\end{align}
```
Expand All @@ -106,7 +106,10 @@ $p_g$ is the rate parameter for the negative binomial distribution.

To avoid the latent variable $\gamma_s^c$ from incorporating variation attributed to experimental
assay differences, we assign an empirical prior informed by the scLVM and a corresponding set of
cells of the same cell type in the scRNA-seq dataset.
cells of the same cell type in the scRNA-seq dataset. To compute this prior, we subcluster the scLVM latent space
of each cell type into $K$ cell-type-specific clusters. For each cluster, we compute an empirical mean and variance.
Each cluster's contribution to the loss is weighted by the probability that a random cell of this cell type falls in
that cluster in the scRNA-seq dataset (the mixture probability, $m_{kc}$).
Above, $\{u_{kc}\}_{k=1}^K$ designates a set of cells from cell type $c$ in the scRNA-seq dataset, and
$q_\Phi$ designates the variational distribution from the scLVM.
In literature, the prior is referred to as a VampPrior ("variational aggregated mixture of posteriors" prior) [^ref2].
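
The per-cluster summary described above can be sketched in NumPy as follows. This is a minimal illustration, not the library's implementation: the function name `summarize_clusters` and its arguments are hypothetical, the cluster assignments are assumed to come from a prior clustering step (for example k-means on the posterior means), and unused mixture slots are padded with mean 0, variance 1, and weight 0.

```python
import numpy as np


def summarize_clusters(mean_cat, var_cat, assignments, p):
    """Summarize one cell type's posteriors into p mixture slots.

    mean_cat, var_cat: (n_cells, n_latent) posterior means and variances.
    assignments: (n_cells,) cluster label per cell (assumed precomputed).
    p: number of mixture slots; unused slots keep mean 0, variance 1, weight 0.
    """
    keys, counts = np.unique(assignments, return_counts=True)
    if len(keys) > p:
        raise ValueError("more clusters than mixture slots")
    n_latent = mean_cat.shape[1]
    mean_k = np.zeros((p, n_latent))
    var_k = np.ones((p, n_latent))
    weight_k = np.zeros(p)
    for i, key in enumerate(keys):
        members = assignments == key
        mean_k[i] = mean_cat[members].mean(axis=0)
        # Law of total variance: average posterior variance plus spread of the means.
        var_k[i] = var_cat[members].mean(axis=0) + mean_cat[members].var(axis=0)
        # Mixture probability: fraction of this cell type's cells in the cluster.
        weight_k[i] = counts[i] / counts.sum()
    return mean_k, var_k, weight_k


# Toy data: 6 cells, 2 latent dimensions, 3 clusters, 4 mixture slots.
rng = np.random.default_rng(0)
mean_cat = rng.normal(size=(6, 2))
var_cat = np.full((6, 2), 0.5)
assignments = np.array([0, 0, 0, 1, 1, 2])
mean_k, var_k, weight_k = summarize_clusters(mean_cat, var_cat, assignments, p=4)
```

In this toy call the three clusters hold 3, 2, and 1 cells, so the mixture probabilities are 1/2, 1/3, and 1/6, and the fourth slot stays at weight 0.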
Expand All @@ -116,6 +119,9 @@ Lastly, an additional latent variable, $\eta_g$, is incorporated into the aggreg
as a dummy cell type to represent gene specific noise. The dummy cell type's expression profile is distributed
as $\epsilon_g := \mathrm{Softplus}(\eta_g)$ where $\eta_g \sim \mathrm{Normal}(0, 1)$.
Like the other cell types, there is an associated cell type abundance parameter $\beta_{sc}$ associated with $\eta$.
We expect each spot to contain only a subset of the cell types. To encourage sparse cell type
proportions, the stLVM supports L1 regularization on the cell type proportions $\beta_{sc}$. By default, this
penalty is not applied.

This generative process is also summarized in the following graphical model:

Expand Down Expand Up @@ -176,8 +182,8 @@ The loss is defined as:
:nowrap: true

\begin{align}
L(l, \alpha, \beta, f^g, \gamma, p, \eta) := &-\log p(X \mid l, \alpha, \beta, f^g, \gamma, p, \eta) - \log p(\eta) \\
&+ \mathrm{Var}(\alpha) - \log p(\gamma \mid \mathrm{VampPrior}) \tag{6} \\
L(l, \alpha, \beta, f^g, \gamma, p, \eta) := &-\log p(X \mid l, \alpha, \beta, f^g, \gamma, p, \eta) - \mathrm{eta}_{reg} \log p(\eta) \\
&+ \mathrm{beta}_{reg} \mathrm{Var}(\alpha) - \log p(\gamma \mid \mathrm{VampPrior}) + \mathrm{l1}_{reg} \sum_{c=1}^{C}\beta_{sc} \\
\end{align}
```
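
As a rough illustration of how the terms of this loss fit together, the sketch below evaluates the VampPrior term as a log-sum-exp over weighted diagonal Gaussian components and then combines the terms with their weights. The function names and argument names are hypothetical, the scalar inputs are assumed to be precomputed elsewhere, and this is not the module's actual code.

```python
import numpy as np


def log_vamp_prior(gamma, mean_k, var_k, weight_k):
    """Log density of gamma (n_latent,) under a diagonal Gaussian mixture."""
    active = weight_k > 0  # skip padded mixture slots with zero weight
    mu, var, w = mean_k[active], var_k[active], weight_k[active]
    # Log density of each diagonal Gaussian component, summed over latent dims.
    log_normal = -0.5 * (np.log(2 * np.pi * var) + (gamma - mu) ** 2 / var).sum(axis=1)
    scores = np.log(w) + log_normal
    top = scores.max()
    return top + np.log(np.exp(scores - top).sum())  # numerically stable log-sum-exp


def total_loss(nll, log_p_eta, var_alpha, log_p_gamma, beta, eta_reg, beta_reg, l1_reg):
    """Combine the loss terms; beta holds non-negative cell type proportions."""
    return (
        nll
        - eta_reg * log_p_eta
        + beta_reg * var_alpha
        - log_p_gamma
        + l1_reg * beta.sum()
    )


# Single standard-normal component in 2 dimensions, evaluated at the origin.
lp = log_vamp_prior(np.zeros(2), np.zeros((1, 2)), np.ones((1, 2)), np.ones(1))
loss = total_loss(10.0, -2.0, 0.5, -3.0, np.array([0.2, 0.8]), 1.0, 1.0, 50.0)
```

With `l1_reg` set to 0 the last term vanishes, which matches the default of no L1 penalty.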

Expand All @@ -203,7 +209,8 @@ Subsequently for a given cell type, users can plot a heatmap of the cell type pr

```
>>> import scanpy as sc
>>> sc.pl.embedding(st_adata, basis="location", color="B cells")
>>> st_adata.obs['B cells'] = st_adata.obsm['proportions']['B cells']
>>> sc.pl.spatial(st_adata, color="B cells", spot_size=130)
```

### Intra cell type variation
Expand All @@ -230,8 +237,13 @@ impute the spatial pattern of the cell-type-specific gene expression with:
### Comparative analysis between samples

To perform differential expression across samples, one can apply a frequentist test by taking samples
from the parameters of the generative distribution predicted for each spot in question. More details
can be found in the DestVI paper.
from the parameters of the generative distribution predicted for each spot in question.

### Utility functions

To explore the output of the stLVM, we published a utilities package that covers automatic thresholding
of cell type proportions, a spatial PCA analysis to find the main axes of variation in spatial gene
expression, and the frequentist test for differential expression described above. Further information
can be found in [destvi_utils](https://destvi-utils.readthedocs.io/en/latest/installation.html).

[^ref1]: Romain Lopez, Baoguo Li, Hadas Keren-Shaul, Pierre Boyeau, Merav Kedmi, David Pilzer, Adam Jelinski, Eyal David, Allon Wagner, Yoseph Addad, Michael I. Jordan, Ido Amit, Nir Yosef (2021),
Member


ah, we can change the title here

Member


Romain Lopez, Baoguo Li, Hadas Keren-Shaul, Pierre Boyeau, Merav Kedmi, David Pilzer, Adam Jelinski, Ido Yofe, Eyal David, Allon Wagner, Can Ergen, Yoseph Addadi, Ofra Golani, Franca Ronchese, Michael I Jordan, Ido Amit, Nir Yosef. DestVI identifies continuums of cell types in spatial transcriptomics data. Nature Biotechnology (in press), 2022.

Member Author


I made it consistent with other packages and also added link to VAMP prior paper

*Multi-resolution deconvolution of spatial transcriptomics data reveals continuous patterns of inflammation*,
102 changes: 69 additions & 33 deletions scvi/model/_condscvi.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import numpy as np
import torch
from anndata import AnnData
from sklearn.cluster import KMeans

from scvi import REGISTRY_KEYS
from scvi.data import AnnDataManager
Expand Down Expand Up @@ -35,8 +36,6 @@ class CondSCVI(RNASeqMixin, VAEMixin, UnsupervisedTrainingMixin, BaseModelClass)
Dimensionality of the latent space.
n_layers
Number of hidden layers used for encoder and decoder NNs.
dropout_rate
Dropout rate for the encoder neural networks.
weight_obs
Whether to reweight observations by their inverse proportion (useful for lowly abundant cell types)
**module_kwargs
Expand All @@ -57,7 +56,6 @@ def __init__(
n_hidden: int = 128,
n_latent: int = 5,
n_layers: int = 2,
dropout_rate: float = 0.1,
weight_obs: bool = False,
**module_kwargs,
):
Expand All @@ -82,19 +80,16 @@ def __init__(
n_hidden=n_hidden,
n_latent=n_latent,
n_layers=n_layers,
dropout_rate=dropout_rate,
**module_kwargs,
)
self._model_summary_string = (
"Conditional SCVI Model with the following params: \nn_hidden: {}, n_latent: {}, n_layers: {}, dropout_rate: {}, weight_obs: {}"
).format(n_hidden, n_latent, n_layers, dropout_rate, weight_obs)
).format(n_hidden, n_latent, n_layers, 0.05, weight_obs)
self.init_params_ = self._get_init_params(locals())

@torch.no_grad()
def get_vamp_prior(
self,
adata: Optional[AnnData] = None,
p: int = 50,
self, adata: Optional[AnnData] = None, p: int = 10
) -> np.ndarray:
r"""
Return an empirical prior over the cell-type specific latent space (vamp prior) that may be used for deconvolution.
Expand All @@ -105,7 +100,7 @@ def get_vamp_prior(
AnnData object with equivalent structure to initial AnnData. If `None`, defaults to the
AnnData object used to initialize the model.
p
number of components in the mixture model underlying the empirical prior
number of clusters in kmeans clustering for cell-type sub-clustering for empirical prior

Returns
-------
Expand All @@ -121,41 +116,82 @@ def get_vamp_prior(

adata = self._validate_anndata(adata)

# Extracting latent representation of adata including variances.
mean_vprior = np.zeros((self.summary_stats.n_labels, p, self.module.n_latent))
var_vprior = np.zeros((self.summary_stats.n_labels, p, self.module.n_latent))
var_vprior = np.ones((self.summary_stats.n_labels, p, self.module.n_latent))
mp_vprior = np.zeros((self.summary_stats.n_labels, p))

labels_state_registry = self.adata_manager.get_state_registry(
REGISTRY_KEYS.LABELS_KEY
)
key = labels_state_registry.original_key
mapping = labels_state_registry.categorical_mapping
for ct in range(self.summary_stats.n_labels):
# pick p cells
local_indices = np.random.choice(
np.where(adata.obs[key] == mapping[ct])[0], p
)
# get mean and variance from posterior
scdl = self._make_data_loader(
adata=adata, indices=local_indices, batch_size=p
)
mean = []
var = []
for tensors in scdl:
x = tensors[REGISTRY_KEYS.X_KEY]
y = tensors[REGISTRY_KEYS.LABELS_KEY]
out = self.module.inference(x, y)
mean_, var_ = out["qz_m"], out["qz_v"]
mean += [mean_.cpu()]
var += [var_.cpu()]

mean_vprior[ct], var_vprior[ct] = np.array(torch.cat(mean)), np.array(
torch.cat(var)

scdl = self._make_data_loader(adata=adata, batch_size=p)

mean = []
var = []
for tensors in scdl:
x = tensors[REGISTRY_KEYS.X_KEY]
y = tensors[REGISTRY_KEYS.LABELS_KEY]
out = self.module.inference(x, y)
mean_, var_ = out["qz_m"], out["qz_v"]
mean += [mean_.cpu()]
var += [var_.cpu()]

mean_cat, var_cat = np.array(torch.cat(mean)), np.array(torch.cat(var))

for ct in range(self.summary_stats["n_labels"]):
local_indices = np.where(adata.obs[key] == mapping[ct])[0]
n_local_indices = len(local_indices)
if "overclustering_vamp" not in adata.obs.columns:
if p < n_local_indices and p > 0:
overclustering_vamp = KMeans(n_clusters=p, n_init=30).fit_predict(
mean_cat[local_indices]
)
else:
# Every cell is its own cluster
overclustering_vamp = np.arange(n_local_indices)
else:
overclustering_vamp = adata[local_indices, :].obs["overclustering_vamp"]
Member


Is it explained somewhere how user would know to set this a priori?

Member Author


I put a paragraph about it in the tutorial.


keys, counts = np.unique(overclustering_vamp, return_counts=True)

n_labels_overclustering = len(keys)
if n_labels_overclustering > p:
error_mess = """
Given cell type specific clustering contains more clusters than vamp_prior_p.
Increase value of vamp_prior_p to largest number of cell type specific clusters."""

raise ValueError(error_mess)

var_cluster = np.ones(
[
n_labels_overclustering,
self.module.n_latent,
]
)
mean_cluster = np.zeros_like(var_cluster)

for index, cluster in enumerate(keys):
indices_curr = local_indices[
np.where(overclustering_vamp == cluster)[0]
]
var_cluster[index, :] = np.mean(var_cat[indices_curr], axis=0) + np.var(
mean_cat[indices_curr], axis=0
)
mean_cluster[index, :] = np.mean(mean_cat[indices_curr], axis=0)

slicing = slice(n_labels_overclustering)
mean_vprior[ct, slicing, :] = mean_cluster
var_vprior[ct, slicing, :] = var_cluster
mp_vprior[ct, slicing] = counts / sum(counts)

return mean_vprior, var_vprior
return mean_vprior, var_vprior, mp_vprior

def train(
self,
max_epochs: int = 400,
max_epochs: int = 300,
lr: float = 0.001,
use_gpu: Optional[Union[str, int, bool]] = None,
train_size: float = 1,
18 changes: 13 additions & 5 deletions scvi/model/_destvi.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ def __init__(
n_hidden: int,
n_latent: int,
n_layers: int,
l1_reg: float,
**module_kwargs,
):
super(DestVI, self).__init__(st_adata)
Expand All @@ -85,6 +86,7 @@ def __init__(
n_latent=n_latent,
n_layers=n_layers,
n_hidden=n_hidden,
l1_reg=l1_reg,
**module_kwargs,
)
self.cell_type_mapping = cell_type_mapping
Expand All @@ -96,7 +98,8 @@ def from_rna_model(
cls,
st_adata: AnnData,
sc_model: CondSCVI,
vamp_prior_p: int = 50,
vamp_prior_p: int = 15,
l1_reg: float = 0.0,
**module_kwargs,
):
"""
Expand All @@ -110,6 +113,9 @@ def from_rna_model(
trained CondSCVI model
vamp_prior_p
number of mixture components for VampPrior calculations
l1_reg
Scalar parameter indicating the strength of L1 regularization on cell type proportions.
A value of 50 leads to sparser results.
**model_kwargs
Keyword args for :class:`~scvi.model.DestVI`
"""
Expand All @@ -123,7 +129,7 @@ def from_rna_model(
mean_vprior = None
var_vprior = None
mp_vprior = None  # keep mp_vprior defined when no VampPrior is used
else:
mean_vprior, var_vprior = sc_model.get_vamp_prior(
mean_vprior, var_vprior, mp_vprior = sc_model.get_vamp_prior(
sc_model.adata, p=vamp_prior_p
)

Expand All @@ -138,6 +144,8 @@ def from_rna_model(
sc_model.module.n_layers,
mean_vprior=mean_vprior,
var_vprior=var_vprior,
mp_vprior=mp_vprior,
l1_reg=l1_reg,
**module_kwargs,
)

Expand Down Expand Up @@ -298,13 +306,13 @@ def get_scale_for_ct(

def train(
self,
max_epochs: int = 400,
lr: float = 0.005,
max_epochs: int = 2000,
lr: float = 0.003,
use_gpu: Optional[Union[str, int, bool]] = None,
train_size: float = 1.0,
validation_size: Optional[float] = None,
batch_size: int = 128,
n_epochs_kl_warmup: int = 50,
n_epochs_kl_warmup: int = 200,
plan_kwargs: Optional[dict] = None,
**kwargs,
):
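
A minimal end-to-end sketch of how the changed defaults might be used is given below. The wrapper function `deconvolve`, the variable names, and the `labels_key` argument are illustrative assumptions; the epoch counts and the `vamp_prior_p` and `l1_reg` defaults are the values shown in this diff. The scvi-tools import is deferred so the sketch loads without the package installed.

```python
def deconvolve(sc_adata, st_adata, labels_key, vamp_prior_p=15, l1_reg=0.0):
    """Train CondSCVI on scRNA-seq data, then DestVI on spatial data."""
    # Imported lazily so the sketch can be loaded without scvi-tools installed.
    from scvi.model import CondSCVI, DestVI

    # Single-cell model: conditioned on the cell type labels in sc_adata.obs.
    CondSCVI.setup_anndata(sc_adata, labels_key=labels_key)
    sc_model = CondSCVI(sc_adata, weight_obs=False)
    sc_model.train(max_epochs=300)

    # Spatial model: initialized from the trained single-cell model.
    DestVI.setup_anndata(st_adata)
    st_model = DestVI.from_rna_model(
        st_adata, sc_model, vamp_prior_p=vamp_prior_p, l1_reg=l1_reg
    )
    st_model.train(max_epochs=2000)

    # Store per-spot cell type proportions for plotting.
    st_adata.obsm["proportions"] = st_model.get_proportions()
    return st_model
```

Passing a larger `l1_reg` (the docstring mentions 50) would be expected to give sparser proportions, at the cost of tuning that value per dataset.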