fix: fixed a bug in KL loss aggregation (LVAE) (#277)
### Description

Found a bug in the KL loss aggregation performed in the `LadderVAE` model's `training_step()`. Specifically, free bits (`free_bits_kl()`, which clamps each KL entry to a lower threshold) were applied *after* the KL entries had been rescaled. Since the rescaled entries are normally much smaller than 1, setting the free-bits threshold to 1 clamped all of them to 1.

- **What**: See above.
- **Why**: Clear bug in the code.
- **How**: Inverted the order of calls in the `get_kl_divergence_loss()` function and adjusted the affected parts of the code accordingly. A minimal sketch of the corrected ordering is included at the end of this description.

---

**Please ensure your PR meets the following requirements:**

- [x] Code builds and passes tests locally, including doctests
- [ ] New tests have been added (for bug fixes/features)
- [x] Pre-commit passes
- [ ] PR to the documentation exists (for bug fixes / features)

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Joran Deschamps <[email protected]>
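
The sketch below illustrates the corrected ordering only: free bits are applied to the raw per-level KL values first, and rescaling/aggregation happens afterwards. It is not the actual `careamics` implementation; the `free_bits_kl` reduction, the per-level rescaling factors, and the final summation are assumptions for illustration.

```python
# Minimal sketch of the corrected ordering (illustrative, not the actual code).
import torch


def free_bits_kl(kl: torch.Tensor, free_bits: float) -> torch.Tensor:
    """Clamp each KL entry (averaged over the batch) to a lower threshold."""
    return kl.mean(dim=0).clamp(min=free_bits)


def get_kl_divergence_loss(
    kl_entries: list[torch.Tensor],  # one (batch,) tensor per hierarchy level
    rescaling: list[float],          # e.g. 1 / number of latent units per level
    free_bits: float = 1.0,
) -> torch.Tensor:
    # Stack per-level KL values: shape (batch, n_levels).
    kl = torch.stack(kl_entries, dim=1)

    # 1) Apply free bits FIRST, on the raw KL values (the order fixed by this PR).
    kl = free_bits_kl(kl, free_bits)  # shape (n_levels,)

    # 2) Only then rescale and aggregate. Rescaled entries are typically far
    #    below 1, so clamping after this step would have overwritten them all.
    kl = kl * torch.tensor(rescaling)
    return kl.sum()


# Example with fake KL values for a batch of 8 and three hierarchy levels.
kl_per_level = [torch.rand(8) * 5 for _ in range(3)]
loss = get_kl_divergence_loss(kl_per_level, rescaling=[0.1, 0.05, 0.01])
```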