diff --git a/docs/release_notes/v0.14.5.rst b/docs/release_notes/v0.14.5.rst new file mode 100644 index 0000000000..b5cdade6d1 --- /dev/null +++ b/docs/release_notes/v0.14.5.rst @@ -0,0 +1,20 @@ +New in 0.14.5 +------------- + +Bug fixes. + +Changes +~~~~~~~ +- Fix `kl_weight` floor for Pytorch-based models (`#1269`_). + +Contributors +~~~~~~~~~~~~ +- `@adamgayoso`_ +- `@jjhong922`_ +- `@watiss`_ + +.. _`@adamgayoso`: https://github.com/adamgayoso +.. _`@jjhong922`: https://github.com/jjhong922 +.. _`@watiss`: https://github.com/watiss + +.. _`#1269` : https://github.com/YosefLab/scvi-tools/pull/1269 diff --git a/scvi/train/_trainingplans.py b/scvi/train/_trainingplans.py index 26530ecb14..26db824e61 100644 --- a/scvi/train/_trainingplans.py +++ b/scvi/train/_trainingplans.py @@ -19,6 +19,7 @@ def _compute_kl_weight( step: int, n_epochs_kl_warmup: Optional[int], n_steps_kl_warmup: Optional[int], + min_weight: Optional[float] = None, ) -> float: epoch_criterion = n_epochs_kl_warmup is not None step_criterion = n_steps_kl_warmup is not None @@ -28,7 +29,9 @@ def _compute_kl_weight( kl_weight = min(1.0, step / n_steps_kl_warmup) else: kl_weight = 1.0 - return max(kl_weight, 1e-3) + if min_weight is not None: + kl_weight = max(kl_weight, min_weight) + return kl_weight class TrainingPlan(pl.LightningModule): @@ -734,6 +737,7 @@ def kl_weight(self): self.global_step, self.n_epochs_kl_warmup, self.n_steps_kl_warmup, + min_weight=1e-3, )