From d5c8f5e3d680d7561cb98419881e770fcd03533b Mon Sep 17 00:00:00 2001 From: bdvllrs Date: Tue, 15 Oct 2024 12:11:29 +0200 Subject: [PATCH] feat: more flexibility in loss coefs (#178) We were before limited to the top level losses (`translations`, `contrastives`, ...). We can now define the loss by selecting any metrics directly from the coefs. For example: ```python {"translation_v_to_t": 5.0, "translation_t_to_v": 1.0} ``` will only use these two components for the total loss. I kept the LossCoefs and BroadcastLossCoefs classes to avoid breaking changes, but to use this new behavior, dicts can now be used directly. --- shimmer/__init__.py | 2 + shimmer/modules/__init__.py | 2 + shimmer/modules/global_workspace.py | 13 ++-- shimmer/modules/losses.py | 104 ++++++++++++++++------------ 4 files changed, 70 insertions(+), 51 deletions(-) diff --git a/shimmer/__init__.py b/shimmer/__init__.py index db93b88e..53862080 100644 --- a/shimmer/__init__.py +++ b/shimmer/__init__.py @@ -37,6 +37,7 @@ GWLosses2Domains, GWLossesBase, LossCoefs, + combine_loss, ) from shimmer.modules.selection import ( RandomSelection, @@ -84,6 +85,7 @@ "ContrastiveLoss", "LossCoefs", "BroadcastLossCoefs", + "combine_loss", "GWLossesBase", "GWLosses2Domains", "RepeatedDataset", diff --git a/shimmer/modules/__init__.py b/shimmer/modules/__init__.py index 90f7171b..cd5957e2 100644 --- a/shimmer/modules/__init__.py +++ b/shimmer/modules/__init__.py @@ -31,6 +31,7 @@ GWLosses2Domains, GWLossesBase, LossCoefs, + combine_loss, ) from shimmer.modules.selection import ( RandomSelection, @@ -63,6 +64,7 @@ "ContrastiveLoss", "LossCoefs", "BroadcastLossCoefs", + "combine_loss", "GWLossesBase", "GWLosses2Domains", "RepeatedDataset", diff --git a/shimmer/modules/global_workspace.py b/shimmer/modules/global_workspace.py index 02a37b66..9ca55fe7 100644 --- a/shimmer/modules/global_workspace.py +++ b/shimmer/modules/global_workspace.py @@ -657,7 +657,7 @@ def __init__( gw_encoders: Mapping[str, Module], 
gw_decoders: Mapping[str, Module], workspace_dim: int, - loss_coefs: LossCoefs, + loss_coefs: LossCoefs | Mapping[str, float], optim_lr: float = 1e-3, optim_weight_decay: float = 0.0, scheduler_args: SchedulerArgs | None = None, @@ -682,7 +682,7 @@ def __init__( name to a `torch.nn.Module` class which role is to decode a GW representation into a unimodal latent representations. workspace_dim (`int`): dimension of the GW. - loss_coefs (`LossCoefs`): loss coefficients + loss_coefs (`LossCoefs | Mapping[str, float]`): loss coefficients optim_lr (`float`): learning rate optim_weight_decay (`float`): weight decay scheduler_args (`SchedulerArgs | None`): optimization scheduler's arguments @@ -734,7 +734,7 @@ def __init__( gw_encoders: Mapping[str, Module], gw_decoders: Mapping[str, Module], workspace_dim: int, - loss_coefs: BroadcastLossCoefs, + loss_coefs: BroadcastLossCoefs | Mapping[str, float], selection_temperature: float = 0.2, optim_lr: float = 1e-3, optim_weight_decay: float = 0.0, @@ -760,7 +760,8 @@ def __init__( name to a `torch.nn.Module` class which role is to decode a GW representation into a unimodal latent representations. workspace_dim (`int`): dimension of the GW. - loss_coefs (`BroadcastLossCoefs`): loss coefs for the losses. + loss_coefs (`BroadcastLossCoefs | Mapping[str, float]`): loss coefs for the + losses. selection_temperature (`float`): temperature value for the RandomSelection module. optim_lr (`float`): learning rate @@ -808,7 +809,7 @@ def pretrained_global_workspace( gw_encoders: Mapping[str, Module], gw_decoders: Mapping[str, Module], workspace_dim: int, - loss_coefs: LossCoefs, + loss_coefs: LossCoefs | Mapping[str, float], contrastive_fn: ContrastiveLossType, scheduler: LRScheduler | None @@ -831,7 +832,7 @@ def pretrained_global_workspace( name to a `torch.nn.Module` class which role is to decode a GW representation into a unimodal latent representations. workspace_dim (`int`): dimension of the GW. 
- loss_coefs (`LossCoefs`): loss coefficients + loss_coefs (`LossCoefs | Mapping[str, float]`): loss coefficients contrastive_loss (`ContrastiveLossType`): a contrastive loss function used for alignment. `learn_logit_scale` will not affect custom contrastive losses. diff --git a/shimmer/modules/losses.py b/shimmer/modules/losses.py index d9c95499..ef60f8dc 100644 --- a/shimmer/modules/losses.py +++ b/shimmer/modules/losses.py @@ -304,6 +304,61 @@ class LossCoefs(TypedDict, total=False): """Contrastive loss coefficient.""" +class BroadcastLossCoefs(TypedDict, total=False): + """ + Dict of loss coefficients used in the GWLossesFusion. + + If one is not provided, the coefficient is assumed to be 0 and will not be logged. + If the loss is explicitly set to 0, it will be logged, but not take part in + the total loss. + """ + + contrastives: float + """Contrastive loss coefficient.""" + + fused: float + """fused loss coefficient (encode multiple domains and decode to one of them).""" + + demi_cycles: float + """demi_cycles loss coefficient. Demi-cycles are always one-to-one""" + + cycles: float + """cycles loss coefficient. Cycles can be many-to-one""" + + translations: float + """translation loss coefficient. Translation, like cycles, can be many-to-one.""" + + +def combine_loss( + metrics: dict[str, torch.Tensor], + coefs: Mapping[str, float] | LossCoefs | BroadcastLossCoefs, +) -> torch.Tensor: + """ + Combines the metrics according to the ones selected in coefs + + Args: + metrics (`dict[str, torch.Tensor]`): all metrics to combine + coefs (`Mapping[str, float] | LossCoefs | BroadcastLossCoefs`): coefs for + selected metrics. Note that not every metric needs to be included here. + If not specified, the metric will not count in the final loss. + Also note that some metrics are redundant (e.g. `translations` contains + all of the `translation_{domain_1}_to_{domain_2}`). You can look at the + docs of the different losses for available values. 
+ + Returns: + `torch.Tensor`: the combined loss. + """ + loss = torch.stack( + [ + metrics[name] * coef + for name, coef in coefs.items() + if name in metrics and isinstance(coef, float) and coef > 0 + ], + dim=0, + ).mean() + return loss + + class GWLosses2Domains(GWLossesBase): """ Implementation of `GWLossesBase` used for `GWModule`. @@ -314,7 +369,7 @@ def __init__( gw_mod: GWModule, selection_mod: SelectionBase, domain_mods: dict[str, DomainModule], - loss_coefs: LossCoefs, + loss_coefs: LossCoefs | Mapping[str, float], contrastive_fn: ContrastiveLossType, ): """ @@ -440,16 +495,7 @@ def step( metrics.update(self.translation_loss(domain_latents, raw_data)) metrics.update(self.contrastive_loss(domain_latents)) - loss = torch.stack( - [ - metrics[name] * coef - for name, coef in self.loss_coefs.items() - if isinstance(coef, float) and coef > 0 - ], - dim=0, - ).mean() - - return LossOutput(loss, metrics) + return LossOutput(combine_loss(metrics, self.loss_coefs), metrics) def generate_partitions(n: int) -> Generator[tuple[int, ...], None, None]: @@ -616,31 +662,6 @@ def broadcast_loss( return metrics -class BroadcastLossCoefs(TypedDict, total=False): - """ - Dict of loss coefficients used in the GWLossesFusion. - - If one is not provided, the coefficient is assumed to be 0 and will not be logged. - If the loss is excplicitely set to 0, it will be logged, but not take part in - the total loss. - """ - - contrastives: float - """Contrastive loss coefficient.""" - - fused: float - """fused loss coefficient (encode multiple domains and decode to one of them).""" - - demi_cycles: float - """demi_cycles loss coefficient. Demi-cycles are always one-to-one""" - - cycles: float - """cycles loss coefficient. Cycles can be many-to-one""" - - translations: float - """translation loss coefficient. Translation, like cycles, can be many-to-one.""" - - class GWLosses(GWLossesBase): """ Implementation of `GWLossesBase` for fusion-based models. 
@@ -651,7 +672,7 @@ def __init__( gw_mod: GWModule, selection_mod: SelectionBase, domain_mods: dict[str, DomainModule], - loss_coefs: BroadcastLossCoefs, + loss_coefs: BroadcastLossCoefs | Mapping[str, float], contrastive_fn: ContrastiveLossType, ): """ @@ -716,14 +737,7 @@ def step( metrics.update(self.contrastive_loss(domain_latents)) metrics.update(self.broadcast_loss(domain_latents, raw_data)) - loss = torch.stack( - [ - metrics[name] * coef - for name, coef in self.loss_coefs.items() - if isinstance(coef, float) and coef > 0 - ], - dim=0, - ).mean() + loss = combine_loss(metrics, self.loss_coefs) metrics["broadcast_loss"] = torch.stack( [