fix: fixed a bug in KL loss aggregation (LVAE) (#277)
### Description

Found a bug in the KL loss aggregation performed in the `LadderVAE`
model's `training_step()`.
Specifically, the free-bits operation (`free_bits_kl()`, which clamps each
KL entry to a lower threshold) was applied after the KL entries had been
rescaled. As a result, with a free-bits threshold of 1, all KL entries were
clamped to 1, since after rescaling they are normally much smaller than
that.

- **What**: See above.
- **Why**: Clear bug in the code.
- **How**: Inverted the order of operations in `get_kl_divergence_loss()`
  so that free bits are applied before rescaling, and adjusted the
  surrounding code to reflect the change (see the sketch below).
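
A minimal, self-contained sketch of the issue (the numbers and the
simplified `free_bits_kl` below are illustrative only, not the actual
CAREamics code):

```python
import torch


def free_bits_kl(kl: torch.Tensor, free_bits: float) -> torch.Tensor:
    """Mimic the clamp-then-batch-average behaviour of the real helper."""
    # kl has shape (batch, n_layers): clamp each entry, then average over the batch.
    return kl.clamp(min=free_bits).mean(0)


kl = torch.tensor([[3.0, 5.0],
                   [4.0, 6.0]])  # per-sample, per-layer KL
norm_factor = 100.0              # e.g. product of the latent dimensions

# Before the fix: rescale first, then apply free bits with threshold 1.
buggy = free_bits_kl(kl / norm_factor, free_bits=1.0)  # tensor([1., 1.]): all layers clamped
# After the fix: apply free bits first, then rescale.
fixed = free_bits_kl(kl, free_bits=1.0) / norm_factor  # tensor([0.0350, 0.0550])
```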
---

**Please ensure your PR meets the following requirements:**

- [x] Code builds and passes tests locally, including doctests
- [ ] New tests have been added (for bug fixes/features)
- [x] Pre-commit passes
- [ ] PR to the documentation exists (for bug fixes / features)

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Joran Deschamps <[email protected]>
3 people authored Nov 23, 2024
1 parent 240e4d3 commit 0e0bc28
Showing 2 changed files with 9 additions and 46 deletions.
18 changes: 9 additions & 9 deletions src/careamics/losses/lvae/losses.py
```diff
@@ -168,30 +168,30 @@ def get_kl_divergence_loss(
         dim=1,
     )  # shape: (B, n_layers)
 
+    # Apply free bits (& batch average)
+    kl = free_bits_kl(kl, free_bits_coeff)  # shape: (n_layers,)
+
     # In 3D case, rescale by Z dim
     # TODO If we have downsampling in Z dimension, then this needs to change.
     if len(img_shape) == 3:
         kl = kl / img_shape[0]
 
     # Rescaling
     if rescaling == "latent_dim":
-        for i in range(kl.shape[1]):
+        for i in range(len(kl)):
             latent_dim = topdown_data["z"][i].shape[1:]
             norm_factor = np.prod(latent_dim)
-            kl[:, i] = kl[:, i] / norm_factor
+            kl[i] = kl[i] / norm_factor
     elif rescaling == "image_dim":
         kl = kl / np.prod(img_shape[-2:])
 
-    # Apply free bits
-    kl_loss = free_bits_kl(kl, free_bits_coeff)  # shape: (n_layers,)
-
     # Aggregation
     if aggregation == "mean":
-        kl_loss = kl_loss.mean()  # shape: (1,)
+        kl = kl.mean()  # shape: (1,)
     elif aggregation == "sum":
-        kl_loss = kl_loss.sum()  # shape: (1,)
+        kl = kl.sum()  # shape: (1,)
 
-    return kl_loss
+    return kl
 
 
 def _get_kl_divergence_loss_musplit(
@@ -220,7 +220,7 @@ def _get_kl_divergence_loss_musplit(
         The KL divergence loss for the muSplit case. Shape is (1, ).
     """
     return get_kl_divergence_loss(
-        kl_type=kl_type,
+        kl_type="kl",  # TODO: hardcoded, deal in future PR
         topdown_data=topdown_data,
         rescaling="latent_dim",
         aggregation="mean",
```
37 changes: 0 additions & 37 deletions src/careamics/models/lvae/utils.py
```diff
@@ -402,40 +402,3 @@ def kl_normal_mc(z, p_mulv, q_mulv):
     p_distrib = Normal(p_mu.get(), p_std)
     q_distrib = Normal(q_mu.get(), q_std)
     return q_distrib.log_prob(z) - p_distrib.log_prob(z)
-
-
-def free_bits_kl(
-    kl: torch.Tensor, free_bits: float, batch_average: bool = False, eps: float = 1e-6
-) -> torch.Tensor:
-    """
-    Computes free-bits version of KL divergence.
-    Ensures that the KL doesn't go to zero for any latent dimension.
-    Hence, it contributes to use latent variables more efficiently,
-    leading to better representation learning.
-    NOTE:
-    Takes in the KL with shape (batch size, layers), returns the KL with
-    free bits (for optimization) with shape (layers,), which is the average
-    free-bits KL per layer in the current batch.
-    If batch_average is False (default), the free bits are per layer and
-    per batch element. Otherwise, the free bits are still per layer, but
-    are assigned on average to the whole batch. In both cases, the batch
-    average is returned, so it's simply a matter of doing mean(clamp(KL))
-    or clamp(mean(KL)).
-    Args:
-        kl (torch.Tensor)
-        free_bits (float)
-        batch_average (bool, optional))
-        eps (float, optional)
-    Returns
-    -------
-    The KL with free bits
-    """
-    assert kl.dim() == 2
-    if free_bits < eps:
-        return kl.mean(0)
-    if batch_average:
-        return kl.mean(0).clamp(min=free_bits)
-    return kl.clamp(min=free_bits).mean(0)
```
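For reference, a toy illustration (made-up values) of the distinction the
removed docstring describes between the default per-element clamping and
`batch_average=True`:

```python
import torch

kl = torch.tensor([[0.5, 3.0],
                   [2.5, 3.0]])  # per-sample, per-layer KL, shape (batch=2, layers=2)
free_bits = 1.0

# Default (batch_average=False): clamp each entry, then average over the batch.
per_element = kl.clamp(min=free_bits).mean(0)  # tensor([1.7500, 3.0000])

# batch_average=True: average over the batch first, then clamp the layer means.
per_batch = kl.mean(0).clamp(min=free_bits)    # tensor([1.5000, 3.0000])
```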