Skip to content

Commit

Permalink
Remove all references to Variational
Browse files Browse the repository at this point in the history
  • Loading branch information
bdvllrs committed Mar 8, 2024
1 parent 7b20b20 commit c9eca21
Show file tree
Hide file tree
Showing 5 changed files with 49 additions and 42 deletions.
4 changes: 2 additions & 2 deletions shimmer/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
ContrastiveLoss,
ContrastiveLossType,
ContrastiveLossWithUncertainty,
VarContrastiveLossType,
ContrastiveLossWithUncertaintyType,
contrastive_loss,
contrastive_loss_with_uncertainty,
)
Expand Down Expand Up @@ -71,7 +71,7 @@
"GWModule",
"GWModuleWithUncertainty",
"ContrastiveLossType",
"VarContrastiveLossType",
"ContrastiveLossWithUncertaintyType",
"contrastive_loss",
"ContrastiveLoss",
"contrastive_loss_with_uncertainty",
Expand Down
4 changes: 2 additions & 2 deletions shimmer/modules/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
ContrastiveLoss,
ContrastiveLossType,
ContrastiveLossWithUncertainty,
VarContrastiveLossType,
ContrastiveLossWithUncertaintyType,
contrastive_loss,
contrastive_loss_with_uncertainty,
)
Expand Down Expand Up @@ -57,7 +57,7 @@
"GWModule",
"GWModuleWithUncertainty",
"ContrastiveLossType",
"VarContrastiveLossType",
"ContrastiveLossWithUncertaintyType",
"contrastive_loss",
"ContrastiveLoss",
"contrastive_loss_with_uncertainty",
Expand Down
10 changes: 5 additions & 5 deletions shimmer/modules/contrastive_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,10 @@
A function taking the prediction and targets and returning a LossOutput.
"""

VarContrastiveLossType = Callable[
ContrastiveLossWithUncertaintyType = Callable[
[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor], LossOutput
]
"""Contrastive loss function type for variational GlobalWorkspace.
"""Contrastive loss function type for GlobalWorkspaceWithUncertainty.
A function taking the prediction mean, prediction std, target mean and target std and
returns a LossOutput.
Expand Down Expand Up @@ -81,7 +81,7 @@ def contrastive_loss_with_uncertainty(
reduction: Literal["mean", "sum", "none"] = "mean",
) -> torch.Tensor:
"""CLIP-like contrastive loss with uncertainty.
This is used in Variational Global Workspaces.
This is used in Global Workspaces with uncertainty.
Args:
x (`torch.Tensor`): prediction
Expand Down Expand Up @@ -151,7 +151,7 @@ def forward(self, x: torch.Tensor, y: torch.Tensor) -> LossOutput:
class ContrastiveLossWithUncertainty(torch.nn.Module):
"""CLIP-like contrastive loss with uncertainty module.
This is used in Variational Global Workspaces.
This is used in Global Workspaces with uncertainty.
"""

def __init__(
Expand All @@ -161,7 +161,7 @@ def __init__(
learn_logit_scale: bool = False,
) -> None:
"""
ContrastiveLoss used for VariationalGlobalWorkspace
ContrastiveLoss used for GlobalWorkspaceWithUncertainty
Args:
logit_scale (`torch.Tensor`): logit_scale tensor.
Expand Down
27 changes: 15 additions & 12 deletions shimmer/modules/global_workspace.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
ContrastiveLoss,
ContrastiveLossType,
ContrastiveLossWithUncertainty,
VarContrastiveLossType,
ContrastiveLossWithUncertaintyType,
)
from shimmer.modules.domain import DomainModule
from shimmer.modules.gw_module import (
Expand Down Expand Up @@ -553,13 +553,13 @@ def __init__(
gw_decoders: Mapping[str, Module],
workspace_dim: int,
loss_coefs: LossCoefs,
use_var_contrastive_loss: bool = False,
use_cont_loss_with_uncertainty: bool = False,
optim_lr: float = 1e-3,
optim_weight_decay: float = 0.0,
scheduler_args: SchedulerArgs | None = None,
learn_logit_scale: bool = False,
contrastive_loss: ContrastiveLossType | None = None,
var_contrastive_loss: VarContrastiveLossType | None = None,
cont_loss_with_uncertainty: ContrastiveLossWithUncertaintyType | None = None,
) -> None:
"""Initializes a Global Workspace
Expand All @@ -575,8 +575,8 @@ def __init__(
GW representation into a unimodal latent representations.
workspace_dim (`int`): dimension of the GW.
loss_coefs (`LossCoefs`): loss coefficients
use_var_contrastive_loss (`bool`): whether to use the variational
contrastive loss which uses means and log variance for computations.
use_cont_loss_with_uncertainty (`bool`): whether to use the contrastive
loss with uncertainty which uses means and log variance for computations.
optim_lr (`float`): learning rate
optim_weight_decay (`float`): weight decay
scheduler_args (`SchedulerArgs | None`): optimization scheduler's arguments
Expand All @@ -585,23 +585,26 @@ def __init__(
contrastive_loss (`ContrastiveLossType | None`): a contrastive loss
function used for alignment. `learn_logit_scale` will not affect custom
contrastive losses.
var_contrastive_loss (`VarContrastiveLossType | None`): a variational
contrastive loss. Only used if `use_var_contrastive_loss` is set to
`True`.
cont_loss_with_uncertainty (`ContrastiveLossWithUncertaintyType | None`): a
contrastive loss with uncertainty.
Only used if `use_cont_loss_with_uncertainty` is set to `True`.
"""
domain_mods = freeze_domain_modules(domain_mods)

gw_mod = GWModuleWithUncertainty(
domain_mods, workspace_dim, gw_encoders, gw_decoders
)

if use_var_contrastive_loss:
if var_contrastive_loss is None:
var_contrastive_loss = ContrastiveLossWithUncertainty(
if use_cont_loss_with_uncertainty:
if cont_loss_with_uncertainty is None:
cont_loss_with_uncertainty = ContrastiveLossWithUncertainty(
torch.tensor([1]).log(), "mean", learn_logit_scale
)
loss_mod = GWLossesWithUncertainty(
gw_mod, domain_mods, loss_coefs, var_contrastive_fn=var_contrastive_loss
gw_mod,
domain_mods,
loss_coefs,
cont_fn_with_uncertainty=cont_loss_with_uncertainty,
)
else:
if contrastive_loss is None:
Expand Down
46 changes: 25 additions & 21 deletions shimmer/modules/losses.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,10 @@
import torch
import torch.nn.functional as F

from shimmer.modules.contrastive_loss import ContrastiveLossType, VarContrastiveLossType
from shimmer.modules.contrastive_loss import (
ContrastiveLossType,
ContrastiveLossWithUncertaintyType,
)
from shimmer.modules.domain import DomainModule, LossOutput
from shimmer.modules.gw_module import GWModule, GWModuleBase, GWModuleWithUncertainty
from shimmer.types import LatentsDomainGroupsT, ModelModeT
Expand Down Expand Up @@ -253,9 +256,9 @@ def contrastive_loss(
def contrastive_loss_with_uncertainty(
gw_mod: GWModuleWithUncertainty,
latent_domains: LatentsDomainGroupsT,
contrastive_fn: VarContrastiveLossType,
contrastive_fn: ContrastiveLossWithUncertaintyType,
) -> dict[str, torch.Tensor]:
"""Computes the variational contrastive loss with uncertainty.
"""Computes the contrastive loss with uncertainty.
This return multiple metrics:
* `contrastive_{domain_1}_and_{domain_2}` with the contrastive
Expand All @@ -267,9 +270,9 @@ def contrastive_loss_with_uncertainty(
`contrastive_{domain_1}_and_{domain_2}` values.
Args:
gw_mod (`VariationalGWModule`): The GWModule to use
gw_mod (`GWModuleWithUncertainty`): The GWModule to use
latent_domains (`LatentsDomainGroupsT`): the latent unimodal groups
contrastive_fn (`VarContrastiveLossType`): the variational contrastive function
contrastive_fn (`ContrastiveLossWithUncertaintyType`): the contrastive function
to apply
Returns:
Expand Down Expand Up @@ -474,19 +477,19 @@ def __init__(
domain_mods: dict[str, DomainModule],
loss_coefs: LossCoefs,
contrastive_fn: ContrastiveLossType | None = None,
var_contrastive_fn: VarContrastiveLossType | None = None,
cont_fn_with_uncertainty: ContrastiveLossWithUncertaintyType | None = None,
):
"""
Loss module with uncertainty to use with the GlobalWorkspaceWithUncertainty
Args:
gw_mod (`VariationalGWModule`): the GWModule
gw_mod (`GWModuleWithUncertainty`): the GWModule
domain_mods (`dict[str, DomainModule]`): a dict where the key is the
domain name and value is the DomainModule
loss_coefs (`VariationalLossCoefs`): loss coefficients
loss_coefs (`LossCoefsWithUncertainty`): loss coefficients
contrastive_fn (`ContrastiveLossType | None`): the contrastive function
to use in contrastive loss
var_contrastive_fn (`VarContrastiveLossType | None`): a contrastive
cont_fn_with_uncertainty (`ContrastiveLossWithUncertaintyType | None`): a contrastive
function that uses uncertainty
"""

Expand All @@ -502,15 +505,16 @@ def __init__(
"""The loss coefficients."""

assert (contrastive_fn is not None) != (
var_contrastive_fn is not None
), "Should either have contrastive_fn or var_contrastive_fn"
cont_fn_with_uncertainty is not None
), "Should either have contrastive_fn or cont_fn_with_uncertainty"

self.contrastive_fn = contrastive_fn
"""Contrastive loss to use without the use of uncertainty. This is only
used in `VariationalGWLosses.step` if `VariationalGWLosses.var_contrastive_fn`
is not set."""
used in `GWLossesWithUncertainty.step` if
`GWLossesWithUncertainty.cont_fn_with_uncertainty` is not set.
"""

self.var_contrastive_fn = var_contrastive_fn
self.cont_fn_with_uncertainty = cont_fn_with_uncertainty
"""Contrastive loss to use with the use of uncertainty."""

def demi_cycle_loss(
Expand Down Expand Up @@ -557,7 +561,7 @@ def contrastive_loss(
) -> dict[str, torch.Tensor]:
"""Contrastive loss.
If `VariationalGWLosses.var_contrastive_fn` is set, will use the
If `GWLossesWithUncertainty.cont_fn_with_uncertainty` is set, will use the
contrastive loss with uncertainty. Otherwise, use the traditional
contrastive loss (see `GWLosses.contrastive_loss`).
Expand All @@ -567,9 +571,9 @@ def contrastive_loss(
Returns:
`dict[str, torch.Tensor]`: a dict of metrics.
"""
if self.var_contrastive_fn is not None:
if self.cont_fn_with_uncertainty is not None:
return contrastive_loss_with_uncertainty(
self.gw_mod, latent_domains, self.var_contrastive_fn
self.gw_mod, latent_domains, self.cont_fn_with_uncertainty
)

assert self.contrastive_fn is not None
Expand All @@ -582,10 +586,10 @@ def step(
Computes and returns the losses
Contains:
- Demi-cycle metrics (see `VariationalGWLosses.demi_cycle_loss`)
- Cycle metrics (see `VariationalGWLosses.cycle_loss`)
- Translation metrics (see `VariationalGWLosses.translation_loss`)
- Contrastive metrics (see `VariationalGWLosses.contrastive_loss`)
- Demi-cycle metrics (see `GWLossesWithUncertainty.demi_cycle_loss`)
- Cycle metrics (see `GWLossesWithUncertainty.cycle_loss`)
- Translation metrics (see `GWLossesWithUncertainty.translation_loss`)
- Contrastive metrics (see `GWLossesWithUncertainty.contrastive_loss`)
Args:
domain_latents (`LatentsDomainGroupsT`): All latent groups
Expand Down

0 comments on commit c9eca21

Please sign in to comment.