@@ -479,6 +479,27 @@ def broadcast_loss(
479
479
"""
480
480
Computes broadcast loss including demi-cycle, cycle, and translation losses.
481
481
482
+ This return multiple metrics:
483
+ * `demi_cycles`
484
+ * `cycles`
485
+ * `translations`
486
+ * `fused`
487
+ * `from_{start_group}_to_{domain}_loss` where `{start_group}` is of the form
488
+ "{domain1, domain2, domainN}" sorted in alphabetical order
489
+ (e.g. "from_{t, v}_to_{t}_loss").
490
+ * `from_{start_group}_to_{domain}_{metric}` with
491
+ additional metrics provided by the domain_mod's
492
+ `compute_broadcast_loss` output
493
+ * `from_{start_group}_through_{target_group}_to_{domain}_case_{case_group}_loss`
494
+ where `{start_group}`, `{target_group}` and `{case_group}` is of the form
495
+ "{domain1, domain2, domainN}" sorted in alphabetical order
496
+ (e.g. "from_{t, v}_to_{t}_loss"). `{start_group}` represents the input
497
+ domains, `{target_group}` the target domains used for the cycle and
498
+ `{case_group}` all available domains participating to the loss.
499
+ * `from_{start_group}_through_{target_group}_to_{domain}_case_{case_group}_{metric}`
500
+ additional metrics provided by the domain_mod's `compute_broadcast_loss`
501
+ output
502
+
482
503
Args:
483
504
gw_mod (`shimmer.modules.gw_module.GWModuleBase`): The GWModule to use
484
505
selection_mod (`shimmer.modules.selection.SelectionBase`): Selection mod to use
@@ -488,7 +509,7 @@ def broadcast_loss(
488
509
489
510
Returns:
490
511
A dictionary with the total loss and additional metrics.
491
- """
512
+ """ # noqa: E501
492
513
losses : dict [str , torch .Tensor ] = {}
493
514
metrics : dict [str , torch .Tensor ] = {}
494
515
@@ -500,13 +521,12 @@ def broadcast_loss(
500
521
for group_domains , latents in latent_domains .items ():
501
522
encoded_latents = gw_mod .encode (latents )
502
523
partitions = generate_partitions (len (group_domains ))
503
- domain_names = list (latents )
504
- group_name = "-" .join (group_domains )
524
+ group_name = "{" + ", " .join (sorted (group_domains )) + "}"
505
525
506
526
for partition in partitions :
507
527
selected_latents = {
508
528
domain : latents [domain ]
509
- for domain , present in zip (domain_names , partition , strict = True )
529
+ for domain , present in zip (latents , partition , strict = True )
510
530
if present
511
531
}
512
532
selected_encoded_latents = {
0 commit comments