Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

LDA implementation #1132

Merged
merged 56 commits into from
Sep 23, 2021
Merged
Show file tree
Hide file tree
Changes from 54 commits
Commits
Show all changes
56 commits
Select commit Hold shift + click to select a range
3b1f28d
first draft lda model and module
justjhong Sep 2, 2021
f40cf02
working sanity test
justjhong Sep 2, 2021
94a9fce
nits
justjhong Sep 2, 2021
47dda75
make model and guide PyroModules
justjhong Sep 2, 2021
72ac8b5
renaming, remove alpha posterior, add optional prior params, use FCLa…
justjhong Sep 3, 2021
12110c5
auto move data
justjhong Sep 7, 2021
cea593a
fix device logic
justjhong Sep 7, 2021
8b62a62
automove for encoder
justjhong Sep 7, 2021
e923510
add components, transform, and perplexity to mirror sklearn
justjhong Sep 8, 2021
98afa7e
scale encoder output
justjhong Sep 8, 2021
f20a3fa
debug statement
justjhong Sep 8, 2021
2a5ccfd
softplus in encoder
justjhong Sep 9, 2021
c8da603
log(1+x) before passing x into encoder
justjhong Sep 9, 2021
1ea3376
remove debug
justjhong Sep 9, 2021
da87395
fix for non np array X
justjhong Sep 9, 2021
fc4dc05
get componenets returns dataframe
justjhong Sep 9, 2021
ec46ea8
fix bugs in perplexity computation
justjhong Sep 9, 2021
72095db
add docstrings to lda
justjhong Sep 9, 2021
df01916
remove use of dataset
justjhong Sep 10, 2021
c83d36d
codacy
justjhong Sep 10, 2021
48d5c8b
codacy fix
justjhong Sep 10, 2021
9ddf9db
add plate subsampling for proper elbo computation
justjhong Sep 10, 2021
f0b8905
debug stuff
justjhong Sep 14, 2021
f53f478
change to PyroParam
justjhong Sep 14, 2021
53f7d48
add get_elbo function
justjhong Sep 14, 2021
8b05db6
change to torch Parameter
justjhong Sep 14, 2021
63c381d
try aggregating elbos
justjhong Sep 15, 2021
40a509a
fix prior assignment
justjhong Sep 15, 2021
372fe91
categorical bow and scale based on input size
justjhong Sep 15, 2021
1fb55e1
remove sklearn implementations
justjhong Sep 15, 2021
ad7b2d0
add docstrings
justjhong Sep 15, 2021
7f34fcb
fix test
justjhong Sep 15, 2021
68cdee0
address comments
justjhong Sep 16, 2021
0dafed0
save load test
justjhong Sep 16, 2021
6a0efdf
add layer to encoder
justjhong Sep 16, 2021
d6e99b4
clamp dirichlet parameters
justjhong Sep 16, 2021
0bd349e
add kl annealing
justjhong Sep 17, 2021
a449277
minor fixes
justjhong Sep 17, 2021
ba08237
fix tests
justjhong Sep 17, 2021
c016281
change topic gene dist to use MAP inference
justjhong Sep 17, 2021
7adb29b
fix event dim
justjhong Sep 17, 2021
2f869bb
fix softmax dim
justjhong Sep 17, 2021
a22c04d
fix dtypes
justjhong Sep 17, 2021
f339d92
switch to logistic normal
justjhong Sep 17, 2021
0a41189
softplus variances
justjhong Sep 17, 2021
cc5f496
remove unnecessary softplus
justjhong Sep 17, 2021
385c768
add note about progress bar elbo
justjhong Sep 20, 2021
a36cff7
factor out kl weight logic
justjhong Sep 20, 2021
f8efc39
do not anneal emprical sample statement
justjhong Sep 20, 2021
54f2fe7
add lda to api docs
justjhong Sep 20, 2021
055d6e1
add blei reference
justjhong Sep 20, 2021
a505a62
improve api docs
justjhong Sep 20, 2021
a563ecb
rename to amortized lda
justjhong Sep 20, 2021
c158cd8
move check if trained to base
justjhong Sep 21, 2021
9517ae7
missing param doc
justjhong Sep 22, 2021
3ccbabb
add docstring
justjhong Sep 23, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/api/developer.rst
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ Existing module classes with respective generative and inference procedures.
module.TOTALVAE
module.VAE
module.VAEC
module.AmortizedLDAPyroModule


External module
Expand Down
1 change: 1 addition & 0 deletions docs/api/user.rst
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ Model
model.SCVI
model.TOTALVI
model.MULTIVI
model.AmortizedLDA



Expand Down
4 changes: 4 additions & 0 deletions docs/references.rst
Original file line number Diff line number Diff line change
Expand Up @@ -60,3 +60,7 @@ References
.. [Lopez18] Romain Lopez, Jeffrey Regier, Michael Cole, Michael I. Jordan, Nir Yosef (2018),
*Deep generative modeling for single-cell transcriptomics*,
`Nature Methods <https://www.nature.com/articles/s41592-018-0229-2.epdf?author_access_token=5sMbnZl1iBFitATlpKkddtRgN0jAjWel9jnR3ZoTv0P1-tTjoP-mBfrGiMqpQx63aBtxToJssRfpqQ482otMbBw2GIGGeinWV4cULBLPg4L4DpCg92dEtoMaB1crCRDG7DgtNrM_1j17VfvHfoy1cQ%3D%3D>`__.

.. [Blei03] David M. Blei, Andrew Y. Ng, Michael I. Jordan (2003),
*Latent Dirichlet Allocation*,
`Journal of Machine Learning Research <https://www.jmlr.org/papers/volume3/blei03a/blei03a.pdf>`__.
3 changes: 1 addition & 2 deletions scvi/dataloaders/_anntorchdataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
import h5py
import numpy as np
import pandas as pd
import torch
from anndata._core.sparse_dataset import SparseDataset
from torch.utils.data import Dataset

Expand Down Expand Up @@ -93,7 +92,7 @@ def setup_getitem(self):

self.attributes_and_types = keys_to_type

def __getitem__(self, idx: List[int]) -> Dict[str, torch.Tensor]:
def __getitem__(self, idx: List[int]) -> Dict[str, np.ndarray]:
"""Get tensors in dictionary from anndata at idx."""
data_numpy = {}
for key, dtype in self.attributes_and_types.items():
Expand Down
11 changes: 2 additions & 9 deletions scvi/external/stereoscope/_model.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import logging
import warnings
from typing import Optional, Tuple, Union

import numpy as np
Expand Down Expand Up @@ -220,10 +219,7 @@ def get_proportions(self, keep_noise=False) -> pd.DataFrame:
keep_noise
whether to account for the noise term as a standalone cell type in the proportion estimate.
"""
if self.is_trained_ is False:
warnings.warn(
"Trying to query inferred values from an untrained model. Please train the model first."
)
self._check_if_trained()

column_names = self.cell_type_mapping
if keep_noise:
Expand All @@ -249,10 +245,7 @@ def get_scale_for_ct(
-------
gene_expression
"""
if self.is_trained_ is False:
warnings.warn(
"Trying to query inferred values from an untrained model. Please train the model first."
)
self._check_if_trained()
ind_y = np.array([np.where(ct == self.cell_type_mapping)[0][0] for ct in y])
if ind_y.shape != y.shape:
raise ValueError(
Expand Down
2 changes: 2 additions & 0 deletions scvi/model/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from ._amortizedlda import AmortizedLDA
from ._autozi import AUTOZI
from ._condscvi import CondSCVI
from ._destvi import DestVI
Expand All @@ -18,4 +19,5 @@
"CondSCVI",
"DestVI",
"MULTIVI",
"AmortizedLDA",
]
240 changes: 240 additions & 0 deletions scvi/model/_amortizedlda.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,240 @@
import collections.abc
import logging
from typing import Optional, Sequence, Union

import numpy as np
import pandas as pd
import pyro
import torch
from anndata import AnnData

from scvi._constants import _CONSTANTS
from scvi.module import AmortizedLDAPyroModule

from .base import BaseModelClass, PyroSviTrainMixin

logger = logging.getLogger(__name__)


class AmortizedLDA(PyroSviTrainMixin, BaseModelClass):
"""
Amortized Latent Dirichlet Allocation [Blei03]_.

Parameters
----------
adata
AnnData object that has been registered via :func:`~scvi.data.setup_anndata`.
n_topics
Number of topics to model.
n_hidden
Number of nodes in the hidden layer of the encoder.
cell_topic_prior
Prior of cell topic distribution. If `None`, defaults to `1 / n_topics`.
topic_gene_prior
Prior of topic gene distribution. If `None`, defaults to `1 / n_topics`.

Examples
--------
>>> adata = anndata.read_h5ad(path_to_anndata)
>>> scvi.data.setup_anndata(adata)
>>> model = scvi.model.AmortizedLDA(adata)
>>> model.train()
>>> gene_by_topic = model.get_gene_by_topic()
>>> adata.obsm["X_LDA"] = model.get_latent_representation()
"""

def __init__(
self,
adata: AnnData,
n_topics: int = 20,
n_hidden: int = 128,
cell_topic_prior: Optional[Union[float, Sequence[float]]] = None,
topic_gene_prior: Optional[Union[float, Sequence[float]]] = None,
):
# in case any other model was created before that shares the same parameter names.
pyro.clear_param_store()

super().__init__(adata)

n_input = self.summary_stats["n_vars"]

if (
cell_topic_prior is not None
and not isinstance(cell_topic_prior, float)
and (
not isinstance(cell_topic_prior, collections.abc.Sequence)
or len(cell_topic_prior) != n_topics
)
):
raise ValueError(
f"cell_topic_prior, {cell_topic_prior}, must be None, "
f"a float or a Sequence of length n_topics."
)
if (
topic_gene_prior is not None
and not isinstance(topic_gene_prior, float)
and (
not isinstance(topic_gene_prior, collections.abc.Sequence)
or len(topic_gene_prior) != n_input
)
):
raise ValueError(
f"topic_gene_prior, {topic_gene_prior}, must be None, "
f"a float or a Sequence of length n_input."
)

self.module = AmortizedLDAPyroModule(
n_input=n_input,
n_topics=n_topics,
n_hidden=n_hidden,
cell_topic_prior=cell_topic_prior,
topic_gene_prior=topic_gene_prior,
)

self.init_params_ = self._get_init_params(locals())

def get_gene_by_topic(self, give_mean=True) -> pd.DataFrame:
"""
Gets the gene by topic matrix.

Parameters
----------
adata
AnnData to transform. If None, returns the gene by topic matrix for
the source AnnData.
give_mean
Give mean of distribution if True or sample from it.

Returns
-------
A `n_var x n_topics` Pandas DataFrame containing the gene by topic matrix.
"""
self._check_if_trained(warn=False)

topic_by_gene = self.module.topic_by_gene(give_mean=give_mean)

return pd.DataFrame(
data=topic_by_gene.numpy().T,
index=self.adata.var_names,
columns=[f"topic_{i}" for i in range(topic_by_gene.shape[0])],
)

def get_latent_representation(
self,
adata: Optional[AnnData] = None,
indices: Optional[Sequence[int]] = None,
batch_size: Optional[int] = None,
give_mean: bool = True,
) -> pd.DataFrame:
"""
Converts a count matrix to an inferred topic distribution.

Parameters
----------
adata
AnnData object with equivalent structure to initial AnnData. If `None`, defaults to the
AnnData object used to initialize the model.
indices
Indices of cells in adata to use. If `None`, all cells are used.
batch_size
Minibatch size for data loading into model. Defaults to `scvi.settings.batch_size`.
give_mean
Give mean of distribution or sample from it.

Returns
-------
A `n_obs x n_topics` Pandas DataFrame containing the normalized estimate
of the topic distribution for each observation.
"""
self._check_if_trained(warn=False)
adata = self._validate_anndata(adata)

dl = self._make_data_loader(adata=adata, indices=indices, batch_size=batch_size)

transformed_xs = []
for tensors in dl:
x = tensors[_CONSTANTS.X_KEY]
transformed_xs.append(
self.module.get_topic_distribution(x, give_mean=give_mean)
)
transformed_x = torch.cat(transformed_xs).numpy()

return pd.DataFrame(
data=transformed_x,
index=adata.obs_names,
columns=[f"topic_{i}" for i in range(transformed_x.shape[1])],
)

def get_elbo(
self,
adata: Optional[AnnData] = None,
indices: Optional[Sequence[int]] = None,
batch_size: Optional[int] = None,
) -> float:
"""
Return the ELBO for the data.

The ELBO is a lower bound on the log likelihood of the data used for optimization
of VAEs. Note, this is not the negative ELBO, higher is better.

Parameters
----------
adata
AnnData object with equivalent structure to initial AnnData. If `None`, defaults to the
AnnData object used to initialize the model.
indices
Indices of cells in adata to use. If `None`, all cells are used.
batch_size
Minibatch size for data loading into model. Defaults to `scvi.settings.batch_size`.

Returns
-------
The positive ELBO.
"""
self._check_if_trained(warn=False)
adata = self._validate_anndata(adata)

dl = self._make_data_loader(adata=adata, indices=indices, batch_size=batch_size)

elbos = []
for tensors in dl:
x = tensors[_CONSTANTS.X_KEY]
library = x.sum(dim=1)
elbos.append(self.module.get_elbo(x, library, len(dl.indices)))
return np.mean(elbos)

def get_perplexity(
self,
adata: Optional[AnnData] = None,
indices: Optional[Sequence[int]] = None,
batch_size: Optional[int] = None,
) -> float:
"""
Computes approximate perplexity for `adata`.

Perplexity is defined as exp(-1 * log-likelihood per count).

Parameters
----------
adata
AnnData object with equivalent structure to initial AnnData. If `None`, defaults to the
AnnData object used to initialize the model.
indices
Indices of cells in adata to use. If `None`, all cells are used.
batch_size
Minibatch size for data loading into model. Defaults to `scvi.settings.batch_size`.

Returns
-------
Perplexity.
"""
self._check_if_trained(warn=False)
adata = self._validate_anndata(adata)

dl = self._make_data_loader(adata=adata, indices=indices, batch_size=batch_size)
total_counts = sum(tensors[_CONSTANTS.X_KEY].sum().item() for tensors in dl)

return np.exp(
self.get_elbo(adata=adata, indices=indices, batch_size=batch_size)
/ total_counts
)
16 changes: 3 additions & 13 deletions scvi/model/_destvi.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import logging
import warnings
from collections import OrderedDict
from typing import Dict, Optional, Sequence, Union

Expand Down Expand Up @@ -161,10 +160,7 @@ def get_proportions(
batch_size
Minibatch size for data loading into model. Only used if amortization. Defaults to `scvi.settings.batch_size`.
"""
if self.is_trained_ is False:
warnings.warn(
"Trying to query inferred values from an untrained model. Please train the model first."
)
self._check_if_trained()

column_names = self.cell_type_mapping
index_names = self.adata.obs.index
Expand Down Expand Up @@ -216,10 +212,7 @@ def get_gamma(
return_numpy
if activated, will return a numpy array of shape is n_spots x n_latent x n_labels.
"""
if self.is_trained_ is False:
warnings.warn(
"Trying to query inferred values from an untrained model. Please train the model first."
)
self._check_if_trained()

column_names = np.arange(self.module.n_latent)
index_names = self.adata.obs.index
Expand Down Expand Up @@ -276,10 +269,7 @@ def get_scale_for_ct(
-------
Pandas dataframe of gene_expression
"""
if self.is_trained_ is False:
warnings.warn(
"Trying to query inferred values from an untrained model. Please train the model first."
)
self._check_if_trained()

if label not in self.cell_type_mapping:
raise ValueError("Unknown cell type")
Expand Down
5 changes: 3 additions & 2 deletions scvi/model/_scanvi.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,8 +176,9 @@ def from_scvi_model(
scanvi_kwargs
kwargs for scanVI model
"""
if scvi_model.is_trained_ is False:
warnings.warn("Passed in scvi model hasn't been trained yet.")
scvi_model._check_if_trained(
message="Passed in scvi model hasn't been trained yet."
)

scanvi_kwargs = dict(scanvi_kwargs)
init_params = scvi_model.init_params_
Expand Down
3 changes: 1 addition & 2 deletions scvi/model/_totalvi.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,8 +295,7 @@ def get_latent_library_size(
batch_size
Minibatch size for data loading into model. Defaults to `scvi.settings.batch_size`.
"""
if not self.is_trained_:
raise RuntimeError("Please train the model first.")
self._check_if_trained(warn=False)

adata = self._validate_anndata(adata)
post = self._make_data_loader(
Expand Down
Loading