From 80db7ce492a1ce7c7121404a0b67ed5c8ac7fee2 Mon Sep 17 00:00:00 2001 From: bdvllrs Date: Tue, 15 Oct 2024 11:52:21 +0000 Subject: [PATCH] docs: add metrics returned by `broadcast_loss` in the docs This is useful for the more flexible coefs now. Note, this also renames the metrics of the broadcast loss (spaces and group case are now formatted the same way as the other groups). --- shimmer/modules/losses.py | 32 ++++++++++++++++++++++++++------ 1 file changed, 26 insertions(+), 6 deletions(-) diff --git a/shimmer/modules/losses.py b/shimmer/modules/losses.py index d9c95499..4c66ccf4 100644 --- a/shimmer/modules/losses.py +++ b/shimmer/modules/losses.py @@ -479,6 +479,27 @@ def broadcast_loss( """ Computes broadcast loss including demi-cycle, cycle, and translation losses. + This returns multiple metrics: + * `demi_cycles` + * `cycles` + * `translations` + * `fused` + * `from_{start_group}_to_{domain}_loss` where `{start_group}` is of the form + "{domain1,domain2,domainN}" sorted in alphabetical order + (e.g. "from_{t,v}_to_t_loss"). + * `from_{start_group}_to_{domain}_{metric}` with + additional metrics provided by the domain_mod's + `compute_broadcast_loss` output + * `from_{start_group}_through_{target_group}_to_{domain}_case_{case_group}_loss` + where `{start_group}`, `{target_group}` and `{case_group}` are of the form + "{domain1,domain2,domainN}" sorted in alphabetical order + (e.g. "from_{t}_through_{v}_to_t_case_{t,v}_loss"). `{start_group}` represents the input + domains, `{target_group}` the target domains used for the cycle and + `{case_group}` all available domains participating in the loss. 
+ * `from_{start_group}_through_{target_group}_to_{domain}_case_{case_group}_{metric}` + additional metrics provided by the domain_mod's `compute_broadcast_loss` + output + Args: gw_mod (`shimmer.modules.gw_module.GWModuleBase`): The GWModule to use selection_mod (`shimmer.modules.selection.SelectionBase`): Selection mod to use @@ -488,7 +509,7 @@ def broadcast_loss( Returns: A dictionary with the total loss and additional metrics. - """ + """ # noqa: E501 losses: dict[str, torch.Tensor] = {} metrics: dict[str, torch.Tensor] = {} @@ -500,19 +521,18 @@ def broadcast_loss( for group_domains, latents in latent_domains.items(): encoded_latents = gw_mod.encode(latents) partitions = generate_partitions(len(group_domains)) - domain_names = list(latents) - group_name = "-".join(group_domains) + group_name = "{" + ",".join(sorted(group_domains)) + "}" for partition in partitions: selected_latents = { domain: latents[domain] - for domain, present in zip(domain_names, partition, strict=True) + for domain, present in zip(latents, partition, strict=True) if present } selected_encoded_latents = { domain: encoded_latents[domain] for domain in selected_latents } - selected_group_label = "{" + ", ".join(sorted(selected_latents)) + "}" + selected_group_label = "{" + ",".join(sorted(selected_latents)) + "}" selection_scores = selection_mod(selected_latents, selected_encoded_latents) fused_latents = gw_mod.fuse(selected_encoded_latents, selection_scores) @@ -560,7 +580,7 @@ def broadcast_loss( } inverse_selected_group_label = ( - "{" + ", ".join(sorted(inverse_selected_latents)) + "}" + "{" + ",".join(sorted(inverse_selected_latents)) + "}" ) re_encoded_latents = gw_mod.encode(inverse_selected_latents)