Skip to content

Commit de2e0c9

Browse files
committed
feat: more flexibility in loss coefs
We were previously limited to the top-level losses (`translations`, `contrastives`, ...). We can now define the loss by selecting any metrics directly from the coefs. For example: ```python {"translation_v_to_t": 5.0, "translation_t_to_v": 1.0} ``` will only use these two components for the total loss. I kept the LossCoefs and BroadcastLossCoefs classes to avoid breaking changes, but to use this new behavior, dicts can now be used directly.
1 parent 81be800 commit de2e0c9

File tree

3 files changed

+48
-45
lines changed

3 files changed

+48
-45
lines changed

shimmer/__init__.py

+2
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
GWLosses2Domains,
3838
GWLossesBase,
3939
LossCoefs,
40+
combine_loss,
4041
)
4142
from shimmer.modules.selection import (
4243
RandomSelection,
@@ -84,6 +85,7 @@
8485
"ContrastiveLoss",
8586
"LossCoefs",
8687
"BroadcastLossCoefs",
88+
"combine_loss",
8789
"GWLossesBase",
8890
"GWLosses2Domains",
8991
"RepeatedDataset",

shimmer/modules/__init__.py

+2
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
GWLosses2Domains,
3232
GWLossesBase,
3333
LossCoefs,
34+
combine_loss,
3435
)
3536
from shimmer.modules.selection import (
3637
RandomSelection,
@@ -63,6 +64,7 @@
6364
"ContrastiveLoss",
6465
"LossCoefs",
6566
"BroadcastLossCoefs",
67+
"combine_loss",
6668
"GWLossesBase",
6769
"GWLosses2Domains",
6870
"RepeatedDataset",

shimmer/modules/losses.py

+44-45
Original file line numberDiff line numberDiff line change
@@ -304,6 +304,46 @@ class LossCoefs(TypedDict, total=False):
304304
"""Contrastive loss coefficient."""
305305

306306

307+
class BroadcastLossCoefs(TypedDict, total=False):
    """
    Dict of loss coefficients used in the GWLossesFusion.

    If one is not provided, the coefficient is assumed to be 0 and will not be logged.
    If the loss is explicitly set to 0, it will be logged, but not take part in
    the total loss.
    """

    contrastives: float
    """Contrastive loss coefficient."""

    fused: float
    """fused loss coefficient (encode multiple domains and decode to one of them)."""

    demi_cycles: float
    """demi_cycles loss coefficient. Demi-cycles are always one-to-one"""

    cycles: float
    """cycles loss coefficient. Cycles can be many-to-one"""

    translations: float
    """translation loss coefficient. Translation, like cycles, can be many-to-one."""
def combine_loss(
    metrics: dict[str, torch.Tensor],
    coefs: Mapping[str, float],
) -> torch.Tensor:
    """
    Combine loss metrics into a single total loss.

    The total is the mean of ``metrics[name] * coef`` over every entry of
    ``coefs`` whose name exists in ``metrics`` and whose coefficient is a
    strictly positive number. A coefficient explicitly set to 0 (or a missing
    metric) is excluded from the total.

    Args:
        metrics: mapping from metric name to its loss tensor.
        coefs: mapping from metric name to its coefficient. ``LossCoefs`` and
            ``BroadcastLossCoefs`` TypedDicts are valid inputs (TypedDicts are
            Mappings), as is any plain dict selecting arbitrary metrics.

    Returns:
        The combined loss tensor.

    Raises:
        RuntimeError: if no metric is selected (``torch.stack`` of an empty
            sequence).
    """
    terms = [
        metrics[name] * coef
        for name, coef in coefs.items()
        if name in metrics
        # Accept ints as well as floats (e.g. {"translations": 5}); the old
        # `isinstance(coef, float)` check silently dropped integer coefs.
        # Exclude bool: isinstance(True, int) is True, but a bool coefficient
        # is surely a mistake rather than a weight.
        and isinstance(coef, (int, float))
        and not isinstance(coef, bool)
        and coef > 0
    ]
    return torch.stack(terms, dim=0).mean()
307347
class GWLosses2Domains(GWLossesBase):
308348
"""
309349
Implementation of `GWLossesBase` used for `GWModule`.
@@ -314,7 +354,7 @@ def __init__(
314354
gw_mod: GWModule,
315355
selection_mod: SelectionBase,
316356
domain_mods: dict[str, DomainModule],
317-
loss_coefs: LossCoefs,
357+
loss_coefs: LossCoefs | Mapping[str, float],
318358
contrastive_fn: ContrastiveLossType,
319359
):
320360
"""
@@ -440,16 +480,7 @@ def step(
440480
metrics.update(self.translation_loss(domain_latents, raw_data))
441481
metrics.update(self.contrastive_loss(domain_latents))
442482

443-
loss = torch.stack(
444-
[
445-
metrics[name] * coef
446-
for name, coef in self.loss_coefs.items()
447-
if isinstance(coef, float) and coef > 0
448-
],
449-
dim=0,
450-
).mean()
451-
452-
return LossOutput(loss, metrics)
483+
return LossOutput(combine_loss(metrics, self.loss_coefs), metrics)
453484

454485

455486
def generate_partitions(n: int) -> Generator[tuple[int, ...], None, None]:
@@ -616,31 +647,6 @@ def broadcast_loss(
616647
return metrics
617648

618649

619-
class BroadcastLossCoefs(TypedDict, total=False):
620-
"""
621-
Dict of loss coefficients used in the GWLossesFusion.
622-
623-
If one is not provided, the coefficient is assumed to be 0 and will not be logged.
624-
If the loss is excplicitely set to 0, it will be logged, but not take part in
625-
the total loss.
626-
"""
627-
628-
contrastives: float
629-
"""Contrastive loss coefficient."""
630-
631-
fused: float
632-
"""fused loss coefficient (encode multiple domains and decode to one of them)."""
633-
634-
demi_cycles: float
635-
"""demi_cycles loss coefficient. Demi-cycles are always one-to-one"""
636-
637-
cycles: float
638-
"""cycles loss coefficient. Cycles can be many-to-one"""
639-
640-
translations: float
641-
"""translation loss coefficient. Translation, like cycles, can be many-to-one."""
642-
643-
644650
class GWLosses(GWLossesBase):
645651
"""
646652
Implementation of `GWLossesBase` for fusion-based models.
@@ -651,7 +657,7 @@ def __init__(
651657
gw_mod: GWModule,
652658
selection_mod: SelectionBase,
653659
domain_mods: dict[str, DomainModule],
654-
loss_coefs: BroadcastLossCoefs,
660+
loss_coefs: BroadcastLossCoefs | Mapping[str, float],
655661
contrastive_fn: ContrastiveLossType,
656662
):
657663
"""
@@ -716,14 +722,7 @@ def step(
716722
metrics.update(self.contrastive_loss(domain_latents))
717723
metrics.update(self.broadcast_loss(domain_latents, raw_data))
718724

719-
loss = torch.stack(
720-
[
721-
metrics[name] * coef
722-
for name, coef in self.loss_coefs.items()
723-
if isinstance(coef, float) and coef > 0
724-
],
725-
dim=0,
726-
).mean()
725+
loss = combine_loss(metrics, self.loss_coefs)
727726

728727
metrics["broadcast_loss"] = torch.stack(
729728
[

0 commit comments

Comments
 (0)