Skip to content

Commit

Permalink
debug stuff
Browse files Browse the repository at this point in the history
  • Loading branch information
justjhong committed Sep 14, 2021
1 parent f407dbe commit 7218655
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 14 deletions.
24 changes: 24 additions & 0 deletions scvi/model/_lda.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,30 @@ def transform(self, adata: Optional[AnnData] = None) -> pd.DataFrame:
transformed_x = torch.cat(transformed_xs).numpy()
return pd.DataFrame(data=transformed_x, index=user_adata.obs_names)

def test_perplexity(self, adata: Optional[AnnData] = None) -> float:
    """
    Estimate perplexity for `adata` from the module's ELBO loss.

    The loss returned by ``self.module.elbo`` is accumulated over every
    minibatch and normalized by the total number of counts in the data,
    i.e. ``exp(sum(batch_losses) / total_counts)``.

    Parameters
    ----------
    adata
        AnnData to evaluate. If `None`, the AnnData the model was set up
        with (``self.adata``) is used.

    Returns
    -------
    Perplexity estimate as a float.
    """
    if adata is not None:
        self._check_var_equality(adata)
    self._check_if_not_trained()

    # Explicit None check: `adata or self.adata` would depend on the
    # truthiness of the AnnData object rather than on whether one was passed.
    user_adata = adata if adata is not None else self.adata
    dl = self._make_data_loader(
        adata=user_adata, indices=np.arange(user_adata.n_obs)
    )

    batch_losses = []
    total_counts = 0
    for tensors in dl:
        x = tensors[_CONSTANTS.X_KEY]
        total_counts += x.sum().item()
        batch_losses.append(self.module.elbo(x))

    # The losses must be summed, not averaged: the denominator is the count
    # total over *all* batches, so averaging would shrink the exponent by a
    # factor equal to the number of batches.
    return np.exp(np.sum(batch_losses) / total_counts)

def perplexity(self, adata: Optional[AnnData] = None) -> float:
"""
Computes approximate perplexity for `adata`.
Expand Down
21 changes: 9 additions & 12 deletions scvi/module/_lda.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from pyro.infer import Trace_ELBO
from pyro.nn import PyroModule
from scipy.special import gammaln, logsumexp, psi
from torch.distributions import constraints
Expand Down Expand Up @@ -261,6 +262,12 @@ def transform(self, x: torch.Tensor) -> torch.Tensor:
/ cell_component_unnormalized_dist.sum(axis=1)[:, np.newaxis]
)

def elbo(self, x: torch.Tensor) -> float:
    """
    Evaluate Pyro's ``Trace_ELBO`` loss of this module on `x`.

    The model and guide are called with `x` and the per-row sums of `x`
    (``x.sum(dim=1)``).

    NOTE(review): Pyro's ``loss`` is the quantity to minimize (an estimate
    of the negative ELBO), not the ELBO itself -- confirm callers treat the
    sign accordingly.

    Parameters
    ----------
    x
        Counts tensor for a minibatch.

    Returns
    -------
    The value returned by ``Trace_ELBO().loss`` as a float.
    """
    return Trace_ELBO().loss(self.model, self.guide, x, x.sum(dim=1))

def perplexity(self, x: torch.Tensor) -> float:
"""
Computes the approximate perplexity for `x`.
Expand All @@ -273,18 +280,6 @@ def perplexity(self, x: torch.Tensor) -> float:
Perplexity as a float.
"""

def simple_elbo(model, guide, *args, **kwargs):
# run the guide and trace its execution
guide_trace = pyro.poutine.trace(guide).get_trace(*args, **kwargs)
# run the model and replay it against the samples from the guide
model_trace = pyro.poutine.trace(
pyro.poutine.replay(model, trace=guide_trace)
).get_trace(*args, **kwargs)
# construct the elbo loss function
return -1 * (model_trace.log_prob_sum() - guide_trace.log_prob_sum())

print(simple_elbo(self.model, self.guide, x, x.sum(dim=1)))

def dirichlet_log_coef(dirichlet_dist: np.ndarray) -> np.ndarray:
return (
psi(dirichlet_dist) - psi(np.sum(dirichlet_dist, axis=1))[:, np.newaxis]
Expand Down Expand Up @@ -326,5 +321,7 @@ def dirichlet_ll(prior: float, dist: torch.Tensor, size: int) -> float:
# E[log p(components | component_gene_prior) - log q(components | component_gene_posterior)]
score += dirichlet_ll(self.component_gene_prior, self.components, self.n_input)

print("score")
print(score)

return np.exp(-1.0 * score / total_count)
6 changes: 4 additions & 2 deletions tests/models/test_pyro.py
Original file line number Diff line number Diff line change
Expand Up @@ -395,11 +395,13 @@ def test_lda_model():
assert adata_lda.shape == (adata.n_obs, n_topics) and np.all(
(adata_lda <= 1) & (adata_lda >= 0)
)
mod.perplexity()
print(mod.perplexity())
print(mod.test_perplexity())

adata2 = synthetic_iid()
adata2_lda = mod.transform(adata2).to_numpy()
assert adata2_lda.shape == (adata2.n_obs, n_topics) and np.all(
(adata2_lda <= 1) & (adata2_lda >= 0)
)
mod.perplexity(adata2)
print(mod.perplexity(adata2))
print(mod.test_perplexity(adata2))

0 comments on commit 7218655

Please sign in to comment.