Skip to content

Commit c0bea5e

Browse files
committed
docs: add metrics returned by broadcast_loss in the docs
This is useful for the more flexible coefs now. Note, this also renames the metrics of the broadcast loss (group case is now formatted the same way as the other groups).
1 parent 81be800 commit c0bea5e

File tree

1 file changed

+24
-4
lines changed

1 file changed

+24
-4
lines changed

shimmer/modules/losses.py

+24-4
Original file line numberDiff line numberDiff line change
@@ -479,6 +479,27 @@ def broadcast_loss(
479479
"""
480480
Computes broadcast loss including demi-cycle, cycle, and translation losses.
481481
482+
This return multiple metrics:
483+
* `demi_cycles`
484+
* `cycles`
485+
* `translations`
486+
* `fused`
487+
* `from_{start_group}_to_{domain}_loss` where `{start_group}` is of the form
488+
"{domain1, domain2, domainN}" sorted in alphabetical order
489+
(e.g. "from_{t, v}_to_{t}_loss").
490+
* `from_{start_group}_to_{domain}_{metric}` with
491+
additional metrics provided by the domain_mod's
492+
`compute_broadcast_loss` output
493+
* `from_{start_group}_through_{target_group}_to_{domain}_case_{case_group}_loss`
494+
where `{start_group}`, `{target_group}` and `{case_group}` is of the form
495+
"{domain1, domain2, domainN}" sorted in alphabetical order
496+
(e.g. "from_{t, v}_to_{t}_loss"). `{start_group}` represents the input
497+
domains, `{target_group}` the target domains used for the cycle and
498+
`{case_group}` all available domains participating to the loss.
499+
* `from_{start_group}_through_{target_group}_to_{domain}_case_{case_group}_{metric}`
500+
additional metrics provided by the domain_mod's `compute_broadcast_loss`
501+
output
502+
482503
Args:
483504
gw_mod (`shimmer.modules.gw_module.GWModuleBase`): The GWModule to use
484505
selection_mod (`shimmer.modules.selection.SelectionBase`): Selection mod to use
@@ -488,7 +509,7 @@ def broadcast_loss(
488509
489510
Returns:
490511
A dictionary with the total loss and additional metrics.
491-
"""
512+
""" # noqa: E501
492513
losses: dict[str, torch.Tensor] = {}
493514
metrics: dict[str, torch.Tensor] = {}
494515

@@ -500,13 +521,12 @@ def broadcast_loss(
500521
for group_domains, latents in latent_domains.items():
501522
encoded_latents = gw_mod.encode(latents)
502523
partitions = generate_partitions(len(group_domains))
503-
domain_names = list(latents)
504-
group_name = "-".join(group_domains)
524+
group_name = "{" + ", ".join(sorted(group_domains)) + "}"
505525

506526
for partition in partitions:
507527
selected_latents = {
508528
domain: latents[domain]
509-
for domain, present in zip(domain_names, partition, strict=True)
529+
for domain, present in zip(latents, partition, strict=True)
510530
if present
511531
}
512532
selected_encoded_latents = {

0 commit comments

Comments
 (0)