feat: more flexibility in loss coefs (#178)
Previously, we were limited to the top-level losses (`translations`,
`contrastives`, ...). We can now define the total loss by selecting any
logged metric directly from the coefs.

For example:
```python
{"translation_v_to_t": 5.0, "translation_t_to_v": 1.0}
```
will only use these two components for the total loss.

The `LossCoefs` and `BroadcastLossCoefs` classes are kept to avoid breaking
changes, but plain dicts can now be used directly to opt into this new
behavior (see the sketch below).
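
As a hedged illustration of the new behavior (the metric names follow the example above, assuming a visual (v) / text (t) setup; `combine_loss` is the helper this commit exports; the tensor values are made up):

```python
import torch

from shimmer import combine_loss

# Metrics as a loss module would log them; values are illustrative.
metrics = {
    "translation_v_to_t": torch.tensor(0.8),
    "translation_t_to_v": torch.tensor(0.5),
    "contrastives": torch.tensor(0.3),
}

# Any logged metric can now be selected directly from the coefs;
# "contrastives" has no coefficient here, so it is left out of the total.
coefs = {"translation_v_to_t": 5.0, "translation_t_to_v": 1.0}

loss = combine_loss(metrics, coefs)
# mean(0.8 * 5.0, 0.5 * 1.0) == 2.25
```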
bdvllrs authored Oct 15, 2024
1 parent 81be800 commit d5c8f5e
Showing 4 changed files with 70 additions and 51 deletions.
2 changes: 2 additions & 0 deletions shimmer/__init__.py
@@ -37,6 +37,7 @@
     GWLosses2Domains,
     GWLossesBase,
     LossCoefs,
+    combine_loss,
 )
 from shimmer.modules.selection import (
     RandomSelection,
@@ -84,6 +85,7 @@
     "ContrastiveLoss",
     "LossCoefs",
     "BroadcastLossCoefs",
+    "combine_loss",
     "GWLossesBase",
     "GWLosses2Domains",
     "RepeatedDataset",
2 changes: 2 additions & 0 deletions shimmer/modules/__init__.py
@@ -31,6 +31,7 @@
     GWLosses2Domains,
     GWLossesBase,
     LossCoefs,
+    combine_loss,
 )
 from shimmer.modules.selection import (
     RandomSelection,
@@ -63,6 +64,7 @@
     "ContrastiveLoss",
     "LossCoefs",
     "BroadcastLossCoefs",
+    "combine_loss",
     "GWLossesBase",
     "GWLosses2Domains",
     "RepeatedDataset",
13 changes: 7 additions & 6 deletions shimmer/modules/global_workspace.py
@@ -657,7 +657,7 @@ def __init__(
         gw_encoders: Mapping[str, Module],
         gw_decoders: Mapping[str, Module],
         workspace_dim: int,
-        loss_coefs: LossCoefs,
+        loss_coefs: LossCoefs | Mapping[str, float],
         optim_lr: float = 1e-3,
         optim_weight_decay: float = 0.0,
         scheduler_args: SchedulerArgs | None = None,
@@ -682,7 +682,7 @@ def __init__(
                 name to a `torch.nn.Module` class which role is to decode a
                 GW representation into a unimodal latent representations.
             workspace_dim (`int`): dimension of the GW.
-            loss_coefs (`LossCoefs`): loss coefficients
+            loss_coefs (`LossCoefs | Mapping[str, float]`): loss coefficients
             optim_lr (`float`): learning rate
             optim_weight_decay (`float`): weight decay
             scheduler_args (`SchedulerArgs | None`): optimization scheduler's arguments
@@ -734,7 +734,7 @@ def __init__(
         gw_encoders: Mapping[str, Module],
         gw_decoders: Mapping[str, Module],
         workspace_dim: int,
-        loss_coefs: BroadcastLossCoefs,
+        loss_coefs: BroadcastLossCoefs | Mapping[str, float],
         selection_temperature: float = 0.2,
         optim_lr: float = 1e-3,
         optim_weight_decay: float = 0.0,
@@ -760,7 +760,8 @@ def __init__(
                 name to a `torch.nn.Module` class which role is to decode a
                 GW representation into a unimodal latent representations.
             workspace_dim (`int`): dimension of the GW.
-            loss_coefs (`BroadcastLossCoefs`): loss coefs for the losses.
+            loss_coefs (`BroadcastLossCoefs | Mapping[str, float]`): loss coefs for the
+                losses.
             selection_temperature (`float`): temperature value for the RandomSelection
                 module.
             optim_lr (`float`): learning rate
@@ -808,7 +809,7 @@ def pretrained_global_workspace(
     gw_encoders: Mapping[str, Module],
     gw_decoders: Mapping[str, Module],
     workspace_dim: int,
-    loss_coefs: LossCoefs,
+    loss_coefs: LossCoefs | Mapping[str, float],
     contrastive_fn: ContrastiveLossType,
     scheduler: LRScheduler
     | None
@@ -831,7 +832,7 @@ def pretrained_global_workspace(
             name to a `torch.nn.Module` class which role is to decode a
             GW representation into a unimodal latent representations.
         workspace_dim (`int`): dimension of the GW.
-        loss_coefs (`LossCoefs`): loss coefficients
+        loss_coefs (`LossCoefs | Mapping[str, float]`): loss coefficients
         contrastive_loss (`ContrastiveLossType`): a contrastive loss
             function used for alignment. `learn_logit_scale` will not affect custom
             contrastive losses.
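
At call sites, these signature changes mean a plain mapping can now be passed where the TypedDicts were previously required. A hedged sketch (the `GlobalWorkspace2Domains` class name and the construction of `domain_mods`, `gw_encoders`, and `gw_decoders` are assumed from the surrounding codebase and not shown):

```python
# Hypothetical call site; only the loss_coefs argument changes with this
# commit. Any metric logged by the losses can be given a coefficient.
gw = GlobalWorkspace2Domains(
    domain_mods,
    gw_encoders,
    gw_decoders,
    workspace_dim=12,
    loss_coefs={"demi_cycles": 1.0, "translation_v_to_t": 5.0},
)
```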
104 changes: 59 additions & 45 deletions shimmer/modules/losses.py
@@ -304,6 +304,61 @@ class LossCoefs(TypedDict, total=False):
     """Contrastive loss coefficient."""
 
 
+class BroadcastLossCoefs(TypedDict, total=False):
+    """
+    Dict of loss coefficients used in the GWLossesFusion.
+    If one is not provided, the coefficient is assumed to be 0 and will not be logged.
+    If the loss is explicitly set to 0, it will be logged, but will not take part in
+    the total loss.
+    """
+
+    contrastives: float
+    """Contrastive loss coefficient."""
+
+    fused: float
+    """Fused loss coefficient (encode multiple domains and decode to one of them)."""
+
+    demi_cycles: float
+    """Demi-cycles loss coefficient. Demi-cycles are always one-to-one."""
+
+    cycles: float
+    """Cycles loss coefficient. Cycles can be many-to-one."""
+
+    translations: float
+    """Translation loss coefficient. Translations, like cycles, can be many-to-one."""
+
+
+def combine_loss(
+    metrics: dict[str, torch.Tensor],
+    coefs: Mapping[str, float] | LossCoefs | BroadcastLossCoefs,
+) -> torch.Tensor:
+    """
+    Combines the metrics according to the ones selected in coefs.
+
+    Args:
+        metrics (`dict[str, torch.Tensor]`): all metrics to combine
+        coefs (`Mapping[str, float] | LossCoefs | BroadcastLossCoefs`): coefs for
+            the selected metrics. Note that not every metric needs to be included
+            here; if not specified, the metric will not count in the final loss.
+            Also note that some metrics are redundant (e.g. `translations` contains
+            all of the `translation_{domain_1}_to_{domain_2}` metrics). You can look
+            at the docs of the different losses for available values.
+
+    Returns:
+        `torch.Tensor`: the combined loss.
+    """
+    loss = torch.stack(
+        [
+            metrics[name] * coef
+            for name, coef in coefs.items()
+            if name in metrics and isinstance(coef, float) and coef > 0
+        ],
+        dim=0,
+    ).mean()
+    return loss
+
+
 class GWLosses2Domains(GWLossesBase):
     """
     Implementation of `GWLossesBase` used for `GWModule`.
@@ -314,7 +369,7 @@ def __init__(
         gw_mod: GWModule,
         selection_mod: SelectionBase,
         domain_mods: dict[str, DomainModule],
-        loss_coefs: LossCoefs,
+        loss_coefs: LossCoefs | Mapping[str, float],
         contrastive_fn: ContrastiveLossType,
     ):
         """
@@ -440,16 +495,7 @@ def step(
         metrics.update(self.translation_loss(domain_latents, raw_data))
         metrics.update(self.contrastive_loss(domain_latents))
 
-        loss = torch.stack(
-            [
-                metrics[name] * coef
-                for name, coef in self.loss_coefs.items()
-                if isinstance(coef, float) and coef > 0
-            ],
-            dim=0,
-        ).mean()
-
-        return LossOutput(loss, metrics)
+        return LossOutput(combine_loss(metrics, self.loss_coefs), metrics)
 
 
 def generate_partitions(n: int) -> Generator[tuple[int, ...], None, None]:
@@ -616,31 +662,6 @@ def broadcast_loss(
     return metrics
 
 
-class BroadcastLossCoefs(TypedDict, total=False):
-    """
-    Dict of loss coefficients used in the GWLossesFusion.
-    If one is not provided, the coefficient is assumed to be 0 and will not be logged.
-    If the loss is explicitly set to 0, it will be logged, but will not take part in
-    the total loss.
-    """
-
-    contrastives: float
-    """Contrastive loss coefficient."""
-
-    fused: float
-    """Fused loss coefficient (encode multiple domains and decode to one of them)."""
-
-    demi_cycles: float
-    """Demi-cycles loss coefficient. Demi-cycles are always one-to-one."""
-
-    cycles: float
-    """Cycles loss coefficient. Cycles can be many-to-one."""
-
-    translations: float
-    """Translation loss coefficient. Translations, like cycles, can be many-to-one."""
-
-
 class GWLosses(GWLossesBase):
     """
     Implementation of `GWLossesBase` for fusion-based models.
@@ -651,7 +672,7 @@ def __init__(
         gw_mod: GWModule,
         selection_mod: SelectionBase,
         domain_mods: dict[str, DomainModule],
-        loss_coefs: BroadcastLossCoefs,
+        loss_coefs: BroadcastLossCoefs | Mapping[str, float],
         contrastive_fn: ContrastiveLossType,
     ):
         """
@@ -716,14 +737,7 @@ def step(
         metrics.update(self.contrastive_loss(domain_latents))
         metrics.update(self.broadcast_loss(domain_latents, raw_data))
 
-        loss = torch.stack(
-            [
-                metrics[name] * coef
-                for name, coef in self.loss_coefs.items()
-                if isinstance(coef, float) and coef > 0
-            ],
-            dim=0,
-        ).mean()
+        loss = combine_loss(metrics, self.loss_coefs)
 
         metrics["broadcast_loss"] = torch.stack(
             [
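
The `BroadcastLossCoefs` docstring above states that a coefficient explicitly set to 0 is logged but excluded from the total. A small runnable sketch of how `combine_loss` realizes that (metric names reuse the `BroadcastLossCoefs` keys; tensor values are illustrative):

```python
import torch

from shimmer import combine_loss

metrics = {
    "fused": torch.tensor(1.2),
    "demi_cycles": torch.tensor(0.4),
    "cycles": torch.tensor(0.9),
}

# "cycles" is explicitly 0: the coef > 0 filter in combine_loss skips it,
# so the metric can still be logged without contributing to the total.
coefs = {"fused": 1.0, "demi_cycles": 0.5, "cycles": 0.0}

total = combine_loss(metrics, coefs)
# mean(1.2 * 1.0, 0.4 * 0.5) == 0.7
```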
