Add option to switch off bayesian cont loss in bayesian GW (#71)
bdvllrs authored May 21, 2024

1 parent d2b67c0 commit c0eb83b
Showing 2 changed files with 17 additions and 3 deletions.
4 changes: 4 additions & 0 deletions shimmer/modules/global_workspace.py
@@ -682,6 +682,7 @@ def __init__(
optim_weight_decay: float = 0.0,
scheduler_args: SchedulerArgs | None = None,
learn_logit_scale: bool = False,
use_normalized_constrastive: bool = True,
contrastive_loss: ContrastiveLossType | None = None,
) -> None:
"""
@@ -706,6 +707,8 @@ def __init__(
scheduler_args (`SchedulerArgs | None`): optimization scheduler's arguments
learn_logit_scale (`bool`): whether to learn the logit scale of the
contrastive loss when using the default contrastive loss.
use_normalized_constrastive (`bool`): whether to normalize the contrastive
loss by the precision coefficients
contrastive_loss (`ContrastiveLossType | None`): a contrastive loss
function used for alignment. `learn_logit_scale` will not affect custom
contrastive losses.
@@ -733,6 +736,7 @@ def __init__(
domain_mods,
loss_coefs,
contrastive_loss,
use_normalized_constrastive,
)

super().__init__(
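
For orientation, a minimal sketch of the pattern this file change implements: the new flag is accepted by the top-level Bayesian global workspace and forwarded unchanged to its loss module. The class names below are hypothetical stand-ins, not shimmer's real classes; only the keyword (whose spelling matches the identifier introduced in the commit) mirrors the diff.

```python
# Toy sketch of the wiring added above: the flag is accepted by the
# top-level module and forwarded, untouched, to the loss module.
class ToyBayesianLosses:
    def __init__(self, use_normalized_constrastive: bool = True) -> None:
        self.use_normalized_constrastive = use_normalized_constrastive


class ToyBayesianGW:
    def __init__(self, use_normalized_constrastive: bool = True) -> None:
        # Forward the flag to the loss module, as the commit does.
        self.loss_mod = ToyBayesianLosses(use_normalized_constrastive)


gw = ToyBayesianGW(use_normalized_constrastive=False)
assert gw.loss_mod.use_normalized_constrastive is False
```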
16 changes: 13 additions & 3 deletions shimmer/modules/losses.py
@@ -316,10 +316,13 @@ def contrastive_loss_bayesian(
loss_output = contrastive_fn(
z1 * coef[0] * coef[1], z2 * coef[0] * coef[1]
)
loss_output_no_norm = contrastive_fn(z1, z2)

losses[loss_name] = loss_output.loss
metrics.update(
{f"{loss_name}_{k}": v for k, v in loss_output.metrics.items()}
)
metrics[f"unnorm_{loss_name}"] = loss_output_no_norm.loss

losses["contrastives"] = torch.stack(list(losses.values()), dim=0).mean()
losses.update(metrics)
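
A self-contained sketch of what this hunk computes, using a stub contrastive function and random tensors (the pair name, tensor shapes, and `coef` layout are illustrative assumptions): the precision-weighted loss is what gets optimized, while the unweighted loss is logged under an `unnorm_` key for monitoring.

```python
import torch


def contrastive_fn(z1: torch.Tensor, z2: torch.Tensor) -> torch.Tensor:
    # Stub standing in for the real contrastive loss; returns a scalar.
    return (z1 - z2).pow(2).mean()


z1, z2 = torch.randn(4, 8), torch.randn(4, 8)  # latents of two domains
coef = torch.rand(2, 4, 8)                     # hypothetical precision coefficients

losses: dict[str, torch.Tensor] = {}
metrics: dict[str, torch.Tensor] = {}
loss_name = "contrastive_v_to_t"               # hypothetical pair name

# Optimize the precision-weighted loss; log the unweighted one for comparison.
losses[loss_name] = contrastive_fn(z1 * coef[0] * coef[1], z2 * coef[0] * coef[1])
metrics[f"unnorm_{loss_name}"] = contrastive_fn(z1, z2)

losses["contrastives"] = torch.stack(list(losses.values()), dim=0).mean()
losses.update(metrics)
```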
@@ -763,6 +766,7 @@ def __init__(
domain_mods: dict[str, DomainModule],
loss_coefs: BroadcastLossCoefs,
contrastive_fn: ContrastiveLossType,
use_normalized_constrastive: bool = True,
):
"""
Loss module with uncertainty prediction to use with the GlobalWorkspaceBayesian
@@ -775,6 +779,8 @@ def __init__(
loss_coefs (`BroadcastLossCoefs`): loss coefficients
contrastive_fn (`ContrastiveLossType`): the contrastive function
to use in contrastive loss
use_normalized_constrastive (`bool`): whether to normalize the contrastive
loss by the precision coefficients
"""
super().__init__()

@@ -795,6 +801,8 @@ def __init__(
Contrastive loss to use.
"""

self.use_normalized_constrastive = use_normalized_constrastive

def contrastive_loss(
self, latent_domains: LatentsDomainGroupsT
) -> dict[str, torch.Tensor]:
@@ -807,9 +815,11 @@ def contrastive_loss(
Returns:
`dict[str, torch.Tensor]`: a dict of metrics.
"""
return contrastive_loss_bayesian(
self.gw_mod, latent_domains, self.contrastive_fn
)
if self.use_normalized_constrastive:
return contrastive_loss_bayesian(
self.gw_mod, latent_domains, self.contrastive_fn
)
return contrastive_loss(self.gw_mod, latent_domains, self.contrastive_fn)

def broadcast_loss(
self, latent_domains: LatentsDomainGroupsT
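
Putting the pieces together, here is a toy version of the new branch in `contrastive_loss`. The InfoNCE-style loss and the `precision` factor are stand-ins (the real code uses `self.contrastive_fn` and per-domain `coef` tensors); only the flag-based dispatch mirrors the commit.

```python
import torch
import torch.nn.functional as F


def info_nce(z1: torch.Tensor, z2: torch.Tensor) -> torch.Tensor:
    # Toy InfoNCE-style loss standing in for self.contrastive_fn.
    logits = F.normalize(z1, dim=-1) @ F.normalize(z2, dim=-1).t()
    targets = torch.arange(z1.size(0))
    return F.cross_entropy(logits, targets)


def contrastive(
    z1: torch.Tensor,
    z2: torch.Tensor,
    precision: torch.Tensor,
    use_normalized_constrastive: bool,
) -> torch.Tensor:
    # The branch this commit adds: weight the latents by the precision
    # coefficients only when the flag is on; otherwise use the plain loss.
    if use_normalized_constrastive:
        return info_nce(z1 * precision, z2 * precision)
    return info_nce(z1, z2)


z1, z2, precision = torch.randn(4, 8), torch.randn(4, 8), torch.rand(4, 8)
print(contrastive(z1, z2, precision, use_normalized_constrastive=True))
print(contrastive(z1, z2, precision, use_normalized_constrastive=False))
```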
