Skip to content

Commit

Permalink
Backport PR #1269: Min kl weight only for pyro (#1270)
Browse files Browse the repository at this point in the history
Co-authored-by: Justin Hong <[email protected]>
  • Loading branch information
meeseeksmachine and justjhong authored Nov 19, 2021
1 parent 28394bc commit d74be44
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 1 deletion.
20 changes: 20 additions & 0 deletions docs/release_notes/v0.14.5.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
New in 0.14.5
-------------

Bug fixes.

Changes
~~~~~~~
- Fix `kl_weight` floor for Pytorch-based models (`#1269`_).

Contributors
~~~~~~~~~~~~
- `@adamgayoso`_
- `@jjhong922`_
- `@watiss`_

.. _`@adamgayoso`: https://github.com/adamgayoso
.. _`@jjhong922`: https://github.com/jjhong922
.. _`@watiss`: https://github.com/watiss

.. _`#1269` : https://github.com/YosefLab/scvi-tools/pull/1269
6 changes: 5 additions & 1 deletion scvi/train/_trainingplans.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ def _compute_kl_weight(
step: int,
n_epochs_kl_warmup: Optional[int],
n_steps_kl_warmup: Optional[int],
min_weight: Optional[float] = None,
) -> float:
epoch_criterion = n_epochs_kl_warmup is not None
step_criterion = n_steps_kl_warmup is not None
Expand All @@ -28,7 +29,9 @@ def _compute_kl_weight(
kl_weight = min(1.0, step / n_steps_kl_warmup)
else:
kl_weight = 1.0
return max(kl_weight, 1e-3)
if min_weight is not None:
kl_weight = max(kl_weight, min_weight)
return kl_weight


class TrainingPlan(pl.LightningModule):
Expand Down Expand Up @@ -734,6 +737,7 @@ def kl_weight(self):
self.global_step,
self.n_epochs_kl_warmup,
self.n_steps_kl_warmup,
min_weight=1e-3,
)


Expand Down

0 comments on commit d74be44

Please sign in to comment.