Updated destVI #1457

Merged: 45 commits merged into scverse:master from cane11:can_destvi on Apr 6, 2022
Changes from 7 commits

Commits (45)
d3334f3
Fix destVI cell type amortization, Vamp prior from overclustering, pa…
canergen Mar 18, 2022
9e5acd9
Registry wrong in condSCVI in last commit
canergen Mar 18, 2022
e1bab17
REGISTRY key in condSCVI
canergen Mar 18, 2022
4d47171
Updated training parameters of destVI
canergen Mar 18, 2022
fe93c0a
Learning rate adapted for destVI
canergen Mar 19, 2022
f3e4bcb
Merge branch 'scverse:master' into can_destvi
canergen Mar 22, 2022
b9dfef1
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 22, 2022
618a136
Update import of scanpy
canergen Mar 22, 2022
3c8aca4
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 22, 2022
f4ea841
Added leidenalg to test dependencies
canergen Mar 22, 2022
5af2ec2
Revert "Added leidenalg to test dependencies"
canergen Mar 22, 2022
6b6c957
Changed scanpy import to give error message
canergen Mar 22, 2022
5d7125b
Added leiden to scanpy requirements
canergen Mar 22, 2022
a78be90
Defaults to no l1_sparsity.
canergen Mar 23, 2022
2bafa45
Style changes to condSCVI.
canergen Mar 23, 2022
15d6bec
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 23, 2022
4e0b4c7
Space in empty line removed
canergen Mar 23, 2022
22da8a5
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 23, 2022
2f098fe
Changed vamp overclustering to kmeans, minor comments, added test cases.
canergen Mar 28, 2022
16eefab
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 28, 2022
b523918
Merge branch 'scverse:master' into can_destvi
canergen Mar 28, 2022
e3a5f66
Error in test function. Added check for custom clustering
canergen Mar 28, 2022
43ab4e5
Merge branch 'can_destvi' of https://github.com/cane11/scvi-tools int…
canergen Mar 28, 2022
4000f21
Merge conflict
canergen Mar 28, 2022
2bd2a7f
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 28, 2022
5f98667
Edited dataset after setup of CondSCVI
canergen Mar 28, 2022
af8a341
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 28, 2022
1e18097
Error in test function.
canergen Mar 28, 2022
7e18f34
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 28, 2022
9d0fc4f
Update test function
canergen Mar 28, 2022
e9637ff
Ran precommit
canergen Mar 28, 2022
fffd0b4
Merge pull request #1 from cane11/destvi_test
canergen Mar 28, 2022
8a0d587
Changed eta_prior_weight, exposed beta_prior_weight. Small adjustments.
canergen Mar 30, 2022
b7b11b3
Merge pull request #2 from cane11/destvi_test
canergen Mar 30, 2022
6969321
Doc changes, weight_vprior->mp_vprior
canergen Apr 4, 2022
0e9f24b
Updated documentation. Style changes. Renamed weighting parameters.
canergen Apr 4, 2022
b23ffe8
L1_reg added to test function
canergen Apr 4, 2022
9848765
Updated logsumexp. Small fixes
canergen Apr 4, 2022
2c53029
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 4, 2022
ced6122
Precommit edit.
canergen Apr 4, 2022
d948388
Merge branch 'can_destvi' of https://github.com/cane11/scvi-tools int…
canergen Apr 4, 2022
605ea78
np.array -> torch.cat
canergen Apr 4, 2022
23c0dd6
Merge branch 'can_destvi' of https://github.com/cane11/scvi-tools int…
canergen Apr 4, 2022
624325a
Updated references. Updated dropout_decoder
canergen Apr 6, 2022
4176ab3
Merge branch 'master' into can_destvi
adamgayoso Apr 6, 2022
84 changes: 60 additions & 24 deletions scvi/model/_condscvi.py
@@ -3,6 +3,7 @@
from typing import Optional, Union

import numpy as np
import scanpy
import torch
from anndata import AnnData

@@ -57,7 +58,7 @@ def __init__(
n_hidden: int = 128,
n_latent: int = 5,
n_layers: int = 2,
dropout_rate: float = 0.1,
dropout_rate: float = 0.05,
weight_obs: bool = False,
**module_kwargs,
):
@@ -92,9 +93,7 @@ def __init__(

@torch.no_grad()
def get_vamp_prior(
self,
adata: Optional[AnnData] = None,
p: int = 50,
self, adata: Optional[AnnData] = None, p: int = 50, resolution: float = 10.0
) -> np.ndarray:
r"""
Return an empirical prior over the cell-type specific latent space (vamp prior) that may be used for deconvolution.
@@ -106,6 +105,8 @@ def get_vamp_prior(
AnnData object used to initialize the model.
p
number of components in the mixture model underlying the empirical prior
resolution
resolution of overclustering in leiden used for metapoints in empirical prior

Returns
-------
@@ -128,34 +129,69 @@ def get_vamp_prior(
)
key = labels_state_registry.original_key
mapping = labels_state_registry.categorical_mapping
for ct in range(self.summary_stats.n_labels):
# pick p cells
local_indices = np.random.choice(
np.where(adata.obs[key] == mapping[ct])[0], p

scdl = self._make_data_loader(adata=adata, batch_size=p)

mean = []
var = []
for tensors in scdl:
x = tensors[REGISTRY_KEYS.X_KEY]
y = tensors[REGISTRY_KEYS.LABELS_KEY]
out = self.module.inference(x, y)
mean_, var_ = out["qz_m"], out["qz_v"]
mean += [mean_.cpu()]
var += [var_.cpu()]

mean_cat, var_cat = np.array(torch.cat(mean)), np.array(torch.cat(var))
adata.obsm["X_CondSCVI"] = mean_cat

for ct in range(self.summary_stats["n_labels"]):
local_indices = np.where(adata.obs[key] == mapping[ct])[0]
sub_adata = adata[local_indices, :].copy()
scanpy.pp.neighbors(sub_adata, use_rep="X_CondSCVI")
if "overclustering_vamp" not in sub_adata.obs.columns:
scanpy.tl.leiden(
sub_adata, resolution=resolution, key_added="overclustering_vamp"
)

var_cluster = np.zeros(
[
len(np.unique(sub_adata.obs.overclustering_vamp)),
self.module.n_latent,
]
)
mean_cluster = np.zeros(
[
len(np.unique(sub_adata.obs.overclustering_vamp)),
self.module.n_latent,
]
)
# get mean and variance from posterior
scdl = self._make_data_loader(
adata=adata, indices=local_indices, batch_size=p

for j in np.unique(sub_adata.obs.overclustering_vamp):
indices_curr = local_indices[
np.where(sub_adata.obs["overclustering_vamp"] == j)[0]
]
var_cluster[int(j), :] = (
np.mean(var_cat[indices_curr], axis=0)
+ np.mean(mean_cat[indices_curr] ** 2, axis=0)
- (np.mean(mean_cat[indices_curr], axis=0)) ** 2
)
mean_cluster[int(j), :] = np.mean(mean_cat[indices_curr], axis=0)

unique, counts = np.unique(
sub_adata.obs.overclustering_vamp, return_counts=True
)
mean = []
var = []
for tensors in scdl:
x = tensors[REGISTRY_KEYS.X_KEY]
y = tensors[REGISTRY_KEYS.LABELS_KEY]
out = self.module.inference(x, y)
mean_, var_ = out["qz_m"], out["qz_v"]
mean += [mean_.cpu()]
var += [var_.cpu()]

mean_vprior[ct], var_vprior[ct] = np.array(torch.cat(mean)), np.array(
torch.cat(var)
selection = np.random.choice(
np.arange(len(unique)), size=p, p=counts / sum(counts)
)
mean_vprior[ct, :, :] = mean_cluster[selection, :]
var_vprior[ct, :, :] = var_cluster[selection, :]

return mean_vprior, var_vprior

def train(
self,
max_epochs: int = 400,
max_epochs: int = 300,
lr: float = 0.001,
use_gpu: Optional[Union[str, int, bool]] = None,
train_size: float = 1,
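A note on the rewritten get_vamp_prior above: instead of sampling p random cells per cell type, it now overclusters the CondSCVI latent space with Leiden and collapses each cluster's posterior moments into one mixture component, where the cluster variance follows the law of total variance, Var(z) = E[Var(z | cell)] + Var(E[z | cell]). Below is a minimal NumPy sketch of that aggregation; the function name and arguments are illustrative, not part of the PR, and cluster labels are assumed to be 0..n_clusters-1.

import numpy as np

def cluster_moments(mean_cat, var_cat, labels):
    # mean_cat, var_cat: per-cell posterior means/variances, shape (n_cells, n_latent)
    # labels: integer cluster id per cell, assumed 0..n_clusters-1
    n_clusters = len(np.unique(labels))
    n_latent = mean_cat.shape[1]
    mean_cluster = np.zeros((n_clusters, n_latent))
    var_cluster = np.zeros((n_clusters, n_latent))
    for j in range(n_clusters):
        idx = np.where(labels == j)[0]
        mean_cluster[j] = mean_cat[idx].mean(axis=0)
        # law of total variance: E[var] + E[mean^2] - (E[mean])^2
        var_cluster[j] = (
            var_cat[idx].mean(axis=0)
            + (mean_cat[idx] ** 2).mean(axis=0)
            - mean_cat[idx].mean(axis=0) ** 2
        )
    return mean_cluster, var_cluster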
19 changes: 14 additions & 5 deletions scvi/model/_destvi.py
@@ -72,6 +72,7 @@ def __init__(
n_hidden: int,
n_latent: int,
n_layers: int,
l1_sparsity: float,
**module_kwargs,
):
super(DestVI, self).__init__(st_adata)
@@ -85,6 +86,7 @@ def __init__(
n_latent=n_latent,
n_layers=n_layers,
n_hidden=n_hidden,
l1_sparsity=l1_sparsity,
**module_kwargs,
)
self.cell_type_mapping = cell_type_mapping
@@ -96,7 +98,9 @@ def from_rna_model(
cls,
st_adata: AnnData,
sc_model: CondSCVI,
vamp_prior_p: int = 50,
vamp_prior_p: int = 500,
vamp_prior_resolution: float = 10.0,
l1_sparsity: float = 60.0,
**module_kwargs,
):
"""
@@ -110,6 +114,10 @@ def from_rna_model(
trained CondSCVI model
vamp_prior_p
number of mixture parameter for VampPrior calculations
vamp_prior_resolution
cluster resolution for VampPrior calculations
l1_sparsity
sparsity constraint for cell type proportions
**model_kwargs
Keyword args for :class:`~scvi.model.DestVI`
"""
@@ -124,7 +132,7 @@ def from_rna_model(
var_vprior = None
else:
mean_vprior, var_vprior = sc_model.get_vamp_prior(
sc_model.adata, p=vamp_prior_p
sc_model.adata, p=vamp_prior_p, resolution=vamp_prior_resolution
)

return cls(
@@ -138,6 +146,7 @@ def from_rna_model(
sc_model.module.n_layers,
mean_vprior=mean_vprior,
var_vprior=var_vprior,
l1_sparsity=l1_sparsity,
**module_kwargs,
)

Expand Down Expand Up @@ -298,13 +307,13 @@ def get_scale_for_ct(

def train(
self,
max_epochs: int = 400,
lr: float = 0.005,
max_epochs: int = 2000,
lr: float = 0.003,
use_gpu: Optional[Union[str, int, bool]] = None,
train_size: float = 1.0,
validation_size: Optional[float] = None,
batch_size: int = 128,
n_epochs_kl_warmup: int = 50,
n_epochs_kl_warmup: int = 200,
plan_kwargs: Optional[dict] = None,
**kwargs,
):
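Putting the _destvi.py changes together, a hypothetical end-to-end use of the updated API. The setup_anndata registration and data loading are omitted, sc_adata and st_adata are placeholder names, and the argument values are the new defaults shown in this diff:

from scvi.model import CondSCVI, DestVI

sc_model = CondSCVI(sc_adata)   # single-cell reference, already set up
sc_model.train(max_epochs=300)  # new CondSCVI default

st_model = DestVI.from_rna_model(
    st_adata,
    sc_model,
    vamp_prior_p=500,            # mixture components per cell type
    vamp_prior_resolution=10.0,  # leiden overclustering resolution
    l1_sparsity=60.0,            # L1 penalty on cell type proportions
)
st_model.train(max_epochs=2000, lr=0.003, n_epochs_kl_warmup=200)  # new DestVI defaults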
52 changes: 36 additions & 16 deletions scvi/module/_mrdeconv.py
@@ -60,7 +60,8 @@ def __init__(
px_r: np.ndarray,
mean_vprior: np.ndarray = None,
var_vprior: np.ndarray = None,
amortization: Literal["none", "latent", "proportion", "both"] = "latent",
amortization: Literal["none", "latent", "proportion", "both"] = "both",
l1_sparsity: float = 60.0,
):
super().__init__()
self.n_spots = n_spots
@@ -69,14 +70,15 @@ def __init__(
self.n_latent = n_latent
self.n_genes = n_genes
self.amortization = amortization
self.l1_sparsity = l1_sparsity
# unpack and copy parameters
self.decoder = FCLayers(
n_in=n_latent,
n_out=n_hidden,
n_cat_list=[n_labels],
n_layers=n_layers,
n_hidden=n_hidden,
dropout_rate=0,
dropout_rate=0.05,
use_layer_norm=True,
use_batch_norm=False,
)
@@ -121,17 +123,24 @@ def __init__(
n_cat_list=None,
n_layers=2,
n_hidden=n_hidden,
dropout_rate=0.1,
dropout_rate=0.05,
use_layer_norm=True,
use_batch_norm=False,
),
torch.nn.Linear(n_hidden, n_latent * n_labels),
)
# cell type loadings
self.V_encoder = FCLayers(
n_in=self.n_genes,
n_out=self.n_labels + 1,
n_layers=2,
n_hidden=n_hidden,
dropout_rate=0.1,
self.V_encoder = torch.nn.Sequential(
FCLayers(
n_in=self.n_genes,
n_out=n_hidden,
n_layers=2,
n_hidden=n_hidden,
dropout_rate=0.05,
use_layer_norm=True,
use_batch_norm=False,
),
torch.nn.Linear(n_hidden, n_labels + 1),
)

def _get_inference_input(self, tensors):
@@ -155,7 +164,7 @@ def generative(self, x, ind_x):
m = x.shape[0]
library = torch.sum(x, dim=1, keepdim=True)
# setup all non-linearities
beta = torch.nn.functional.softplus(self.beta) # n_genes
beta = torch.exp(self.beta) # n_genes
eps = torch.nn.functional.softplus(self.eta) # n_genes
x_ = torch.log(1 + x)
# subsample parameters
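On the beta reparameterization just above (softplus replaced with exp): both map the unconstrained parameter to positive gene weights, but under exp the raw parameter is the log-weight, so the torch.var(self.beta) penalty in the loss regularizes the spread of log-scale weights. A quick comparison, with outputs shown rounded in the comments:

import torch

b = torch.tensor([-2.0, 0.0, 2.0])
print(torch.nn.functional.softplus(b))  # tensor([0.1269, 0.6931, 2.1269])
print(torch.exp(b))                     # tensor([0.1353, 1.0000, 7.3891])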
@@ -217,14 +226,21 @@ def loss(
px_rate = generative_outputs["px_rate"]
px_o = generative_outputs["px_o"]
gamma = generative_outputs["gamma"]
v = generative_outputs["v"]

reconst_loss = -NegativeBinomial(px_rate, logits=px_o).log_prob(x).sum(-1)

# eta prior likelihood
mean = torch.zeros_like(self.eta)
scale = torch.ones_like(self.eta)
glo_neg_log_likelihood_prior = -Normal(mean, scale).log_prob(self.eta).sum()
glo_neg_log_likelihood_prior += torch.var(self.beta)
glo_neg_log_likelihood_prior = (
-1e-10 * Normal(mean, scale).log_prob(self.eta).sum()
)
glo_neg_log_likelihood_prior += 5.0 * torch.var(self.beta)

Review thread on the -1e-10 weighting (resolved):
Member: that seems extreme, what's supporting this choice?
Member Author: 1e-4 also works (everything above leads to wrong cell type proportions). I don't have a strong reason other than that it helps with the three datasets. 1e-4 * n_obs is close to the old term. Beta_weighting is much bigger now and added as an input parameter.

v_sparsity_loss = (
self.l1_sparsity * torch.mean(torch.exp(self.beta)) * torch.abs(v).mean(-1)
)

# gamma prior likelihood
if self.mean_vprior is None:
Expand All @@ -245,16 +261,20 @@ def loss(
0
) # 1, p, n_labels, n_latent
pre_lse = (
Normal(mean_vprior, torch.sqrt(var_vprior)).log_prob(gamma).sum(-1)
Normal(mean_vprior, torch.sqrt(var_vprior) + 1e-4)
.log_prob(gamma)
.sum(-1)
) # minibatch, p, n_labels
log_likelihood_prior = torch.logsumexp(pre_lse, 1) - np.log(
self.p
) # minibatch, n_labels
neg_log_likelihood_prior = -log_likelihood_prior.sum(1) # minibatch
# mean_vprior is of shape n_labels, p, n_latent

loss = (
n_obs * torch.mean(reconst_loss + kl_weight * neg_log_likelihood_prior)
loss = n_obs * (
torch.mean(
reconst_loss + kl_weight * (neg_log_likelihood_prior + v_sparsity_loss)
)
+ glo_neg_log_likelihood_prior
)
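For readers following the VampPrior term in this loss: the prior over gamma is a uniform mixture of p Gaussian components per cell type, evaluated with a logsumexp for numerical stability, and the 1e-4 added to the scale guards against near-zero component variances. A self-contained sketch using the shapes from the comments above (an illustration, not the module's exact code):

import numpy as np
import torch
from torch.distributions import Normal

def vamp_mixture_log_prob(gamma, mean_vprior, var_vprior):
    # gamma:       (minibatch, 1, n_labels, n_latent)
    # mean_vprior: (1, p, n_labels, n_latent)
    # var_vprior:  (1, p, n_labels, n_latent)
    p = mean_vprior.shape[1]
    pre_lse = (
        Normal(mean_vprior, torch.sqrt(var_vprior) + 1e-4)
        .log_prob(gamma)
        .sum(-1)
    )  # (minibatch, p, n_labels)
    # uniform mixture weights 1/p: logsumexp over components minus log(p)
    return torch.logsumexp(pre_lse, 1) - np.log(p)  # (minibatch, n_labels)

# toy shapes: minibatch=2, p=3, n_labels=4, n_latent=5
gamma = torch.randn(2, 1, 4, 5)
log_prob = vamp_mixture_log_prob(gamma, torch.randn(1, 3, 4, 5), torch.rand(1, 3, 4, 5))
assert log_prob.shape == (2, 4)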

@@ -329,7 +349,7 @@ def get_ct_specific_expression(
integer for cell types
"""
# cell-type specific gene expression, shape (minibatch, celltype, gene).
beta = torch.nn.functional.softplus(self.beta) # n_genes
beta = torch.exp(self.beta) # n_genes
y_torch = (y * torch.ones_like(ind_x)).ravel()
# obtain the relevant gammas
if self.amortization in ["both", "latent"]:
6 changes: 3 additions & 3 deletions scvi/module/_vaec.py
@@ -44,7 +44,7 @@ def __init__(
n_hidden: int = 128,
n_latent: int = 5,
n_layers: int = 2,
dropout_rate: float = 0.1,
dropout_rate: float = 0.05,
log_variational: bool = True,
ct_weight: np.ndarray = None,
**module_kwargs,
@@ -84,7 +84,7 @@ def __init__(
n_cat_list=[n_labels],
n_layers=n_layers,
n_hidden=n_hidden,
dropout_rate=0,
dropout_rate=dropout_rate,
inject_covariates=True,
use_batch_norm=False,
use_layer_norm=True,
@@ -129,7 +129,7 @@ def inference(self, x, y, n_samples=1):
Runs the inference (encoder) model.
"""
x_ = x
library = torch.log(x.sum(1)).unsqueeze(1)
library = x.sum(1).unsqueeze(1)
if self.log_variational:
x_ = torch.log(1 + x_)

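Finally, on the _vaec.py change above: the library size computed in inference is now the raw count sum rather than its log, while the log1p transform of the encoder input is unchanged. A minimal sketch of the new preprocessing on a toy tensor (not the module's code):

import torch

x = torch.tensor([[0.0, 3.0, 7.0], [1.0, 0.0, 4.0]])  # toy raw counts
library = x.sum(1).unsqueeze(1)  # [[10.], [5.]]; no log under this PR
x_ = torch.log(1 + x)            # log1p-transformed encoder input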