diff --git a/shimmer/modules/losses.py b/shimmer/modules/losses.py index d9c9549..4c66ccf 100644 --- a/shimmer/modules/losses.py +++ b/shimmer/modules/losses.py @@ -479,6 +479,27 @@ def broadcast_loss( """ Computes broadcast loss including demi-cycle, cycle, and translation losses. + This returns multiple metrics: + * `demi_cycles` + * `cycles` + * `translations` + * `fused` + * `from_{start_group}_to_{domain}_loss` where `{start_group}` is of the form + "{domain1,domain2,domainN}" sorted in alphabetical order + (e.g. "from_{t,v}_to_t_loss"). + * `from_{start_group}_to_{domain}_{metric}` with + additional metrics provided by the domain_mod's + `compute_broadcast_loss` output + * `from_{start_group}_through_{target_group}_to_{domain}_case_{case_group}_loss` + where `{start_group}`, `{target_group}` and `{case_group}` are of the form + "{domain1,domain2,domainN}" sorted in alphabetical order + (e.g. "from_{t}_through_{v}_to_t_case_{t,v}_loss"). `{start_group}` represents the input + domains, `{target_group}` the target domains used for the cycle and + `{case_group}` all available domains participating in the loss. + * `from_{start_group}_through_{target_group}_to_{domain}_case_{case_group}_{metric}` + with additional metrics provided by the domain_mod's `compute_broadcast_loss` + output + Args: gw_mod (`shimmer.modules.gw_module.GWModuleBase`): The GWModule to use selection_mod (`shimmer.modules.selection.SelectionBase`): Selection mod to use @@ -488,7 +509,7 @@ def broadcast_loss( Returns: A dictionary with the total loss and additional metrics. 
- """ + """ # noqa: E501 losses: dict[str, torch.Tensor] = {} metrics: dict[str, torch.Tensor] = {} @@ -500,19 +521,18 @@ def broadcast_loss( for group_domains, latents in latent_domains.items(): encoded_latents = gw_mod.encode(latents) partitions = generate_partitions(len(group_domains)) - domain_names = list(latents) - group_name = "-".join(group_domains) + group_name = "{" + ",".join(sorted(group_domains)) + "}" for partition in partitions: selected_latents = { domain: latents[domain] - for domain, present in zip(domain_names, partition, strict=True) + for domain, present in zip(latents, partition, strict=True) if present } selected_encoded_latents = { domain: encoded_latents[domain] for domain in selected_latents } - selected_group_label = "{" + ", ".join(sorted(selected_latents)) + "}" + selected_group_label = "{" + ",".join(sorted(selected_latents)) + "}" selection_scores = selection_mod(selected_latents, selected_encoded_latents) fused_latents = gw_mod.fuse(selected_encoded_latents, selection_scores) @@ -560,7 +580,7 @@ def broadcast_loss( } inverse_selected_group_label = ( - "{" + ", ".join(sorted(inverse_selected_latents)) + "}" + "{" + ",".join(sorted(inverse_selected_latents)) + "}" ) re_encoded_latents = gw_mod.encode(inverse_selected_latents)