@@ -304,6 +304,46 @@ class LossCoefs(TypedDict, total=False):
     """Contrastive loss coefficient."""


+class BroadcastLossCoefs(TypedDict, total=False):
+    """
+    Dict of loss coefficients used in the GWLossesFusion.
+
+    If one is not provided, the coefficient is assumed to be 0 and will not be logged.
+    If the loss is explicitly set to 0, it will be logged, but not take part in
+    the total loss.
+    """
+
+    contrastives: float
+    """Contrastive loss coefficient."""
+
+    fused: float
+    """fused loss coefficient (encode multiple domains and decode to one of them)."""
+
+    demi_cycles: float
+    """demi_cycles loss coefficient. Demi-cycles are always one-to-one"""
+
+    cycles: float
+    """cycles loss coefficient. Cycles can be many-to-one"""
+
+    translations: float
+    """translation loss coefficient. Translation, like cycles, can be many-to-one."""
+
+
+def combine_loss(
+    metrics: dict[str, torch.Tensor],
+    coefs: Mapping[str, float] | LossCoefs | BroadcastLossCoefs,
+) -> torch.Tensor:
+    loss = torch.stack(
+        [
+            metrics[name] * coef
+            for name, coef in coefs.items()
+            if name in metrics and isinstance(coef, float) and coef > 0
+        ],
+        dim=0,
+    ).mean()
+    return loss
+
+
 class GWLosses2Domains(GWLossesBase):
     """
     Implementation of `GWLossesBase` used for `GWModule`.
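The module-level `combine_loss` helper introduced above centralizes the reduction that both `step` implementations previously inlined: it keeps every metric whose coefficient is a strictly positive float, scales it, stacks the results and averages them. A minimal usage sketch, assuming `combine_loss` from this module is in scope; the metric names and tensor values below are placeholders, not real training output:

```python
import torch

# Placeholder per-loss metrics, as a `step` method might produce them.
metrics = {
    "demi_cycles": torch.tensor(0.5),
    "cycles": torch.tensor(0.3),
    "contrastives": torch.tensor(1.2),
}

# A plain dict works as well as a LossCoefs/BroadcastLossCoefs TypedDict.
# "cycles" is explicitly 0: it stays available in `metrics` for logging,
# but the `coef > 0` filter keeps it out of the total.
coefs = {"demi_cycles": 1.0, "cycles": 0.0, "contrastives": 0.1}

loss = combine_loss(metrics, coefs)
# mean of [0.5 * 1.0, 1.2 * 0.1] -> tensor(0.3100)
```

As with the inline code it replaces, at least one present metric needs a strictly positive coefficient, otherwise `torch.stack` receives an empty list and raises.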
@@ -314,7 +354,7 @@ def __init__(
         gw_mod: GWModule,
         selection_mod: SelectionBase,
         domain_mods: dict[str, DomainModule],
-        loss_coefs: LossCoefs,
+        loss_coefs: LossCoefs | Mapping[str, float],
         contrastive_fn: ContrastiveLossType,
     ):
         """
@@ -440,16 +480,7 @@ def step(
         metrics.update(self.translation_loss(domain_latents, raw_data))
         metrics.update(self.contrastive_loss(domain_latents))

-        loss = torch.stack(
-            [
-                metrics[name] * coef
-                for name, coef in self.loss_coefs.items()
-                if isinstance(coef, float) and coef > 0
-            ],
-            dim=0,
-        ).mean()
-
-        return LossOutput(loss, metrics)
+        return LossOutput(combine_loss(metrics, self.loss_coefs), metrics)


 def generate_partitions(n: int) -> Generator[tuple[int, ...], None, None]:
@@ -616,31 +647,6 @@ def broadcast_loss(
     return metrics


-class BroadcastLossCoefs(TypedDict, total=False):
-    """
-    Dict of loss coefficients used in the GWLossesFusion.
-
-    If one is not provided, the coefficient is assumed to be 0 and will not be logged.
-    If the loss is excplicitely set to 0, it will be logged, but not take part in
-    the total loss.
-    """
-
-    contrastives: float
-    """Contrastive loss coefficient."""
-
-    fused: float
-    """fused loss coefficient (encode multiple domains and decode to one of them)."""
-
-    demi_cycles: float
-    """demi_cycles loss coefficient. Demi-cycles are always one-to-one"""
-
-    cycles: float
-    """cycles loss coefficient. Cycles can be many-to-one"""
-
-    translations: float
-    """translation loss coefficient. Translation, like cycles, can be many-to-one."""
-
-
 class GWLosses(GWLossesBase):
     """
     Implementation of `GWLossesBase` for fusion-based models.
@@ -651,7 +657,7 @@ def __init__(
         gw_mod: GWModule,
         selection_mod: SelectionBase,
         domain_mods: dict[str, DomainModule],
-        loss_coefs: BroadcastLossCoefs,
+        loss_coefs: BroadcastLossCoefs | Mapping[str, float],
         contrastive_fn: ContrastiveLossType,
     ):
         """
@@ -716,14 +722,7 @@ def step(
         metrics.update(self.contrastive_loss(domain_latents))
         metrics.update(self.broadcast_loss(domain_latents, raw_data))

-        loss = torch.stack(
-            [
-                metrics[name] * coef
-                for name, coef in self.loss_coefs.items()
-                if isinstance(coef, float) and coef > 0
-            ],
-            dim=0,
-        ).mean()
+        loss = combine_loss(metrics, self.loss_coefs)

         metrics["broadcast_loss"] = torch.stack(
             [