diff --git a/scvi/module/_amortizedlda.py b/scvi/module/_amortizedlda.py index 60859295d7..70ec6aea1d 100644 --- a/scvi/module/_amortizedlda.py +++ b/scvi/module/_amortizedlda.py @@ -204,7 +204,7 @@ def forward( pyro.sample( "log_cell_topic_dist", dist.Normal( - cell_topic_posterior_mu, F.softplus(cell_topic_posterior_sigma) + cell_topic_posterior_mu, cell_topic_posterior_sigma ).to_event(1), )