Skip to content

Commit

Permalink
fix dtypes
Browse files Browse the repository at this point in the history
  • Loading branch information
justjhong committed Sep 17, 2021
1 parent 2f869bb commit a22c04d
Showing 1 changed file with 10 additions and 6 deletions.
16 changes: 10 additions & 6 deletions scvi/module/_lda.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,8 @@ def __init__(
self,
n_input: int,
n_topics: int,
cell_topic_prior: Sequence[float],
topic_gene_prior: Sequence[float],
cell_topic_prior: torch.Tensor,
topic_gene_prior: torch.Tensor,
):
super().__init__(_LDA_PYRO_MODULE_NAME)

Expand All @@ -61,9 +61,9 @@ def __init__(

self.register_buffer(
"cell_topic_prior",
torch.FloatTensor(cell_topic_prior),
cell_topic_prior.clone().detach(),
)
self.register_buffer("topic_gene_prior", torch.FloatTensor(topic_gene_prior))
self.register_buffer("topic_gene_prior", topic_gene_prior.clone().detach())

# Hack: to allow auto_move_data to infer device.
self._dummy = torch.nn.Parameter(torch.zeros(1), requires_grad=False)
Expand Down Expand Up @@ -166,10 +166,14 @@ def __init__(self, n_input: int, n_topics: int, n_hidden: int):
self.n_obs = None

self.encoder = CellTopicDistPriorEncoder(n_input, n_topics, n_hidden)
self.topic_gene_posterior = torch.nn.Parameter(
self.unconstrained_topic_gene_posterior = torch.nn.Parameter(
torch.ones(self.n_topics, self.n_input),
)

@property
def topic_gene_posterior(self):
return F.softmax(self.unconstrained_topic_gene_posterior, dim=1)

@auto_move_data
def forward(
self,
Expand All @@ -182,7 +186,7 @@ def forward(
with pyro.plate("topics", self.n_topics), poutine.scale(None, kl_weight):
pyro.sample(
"topic_gene_dist",
dist.Delta(F.softmax(self.topic_gene_posterior, dim=1), event_dim=1),
dist.Delta(self.topic_gene_posterior, event_dim=1),
)

# Cell topic distributions guide.
Expand Down

0 comments on commit a22c04d

Please sign in to comment.